[DataType] Update to use explicit Bool Type Aligning with DLPack (#18453)
This PR updates the project to use an explicit bool type (kDLBool), which aligns
TVM with DLPack. It also streamlines the explicit use of bool types.
diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi
index f703a0c..ae346ec 160000
--- a/3rdparty/tvm-ffi
+++ b/3rdparty/tvm-ffi
@@ -1 +1 @@
-Subproject commit f703a0cf9358fa30d8faee719f905c58d8ca6ee3
+Subproject commit ae346ec92a3c386f1376064ae086aae72947c329
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 0af3022..0c69833 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -60,6 +60,7 @@
kFloat = kDLFloat,
kHandle = kDLOpaqueHandle,
kBFloat = kDLBfloat,
+ kBool = kDLBool,
kFloat8_e3m4 = kDLFloat8_e3m4,
kFloat8_e4m3 = kDLFloat8_e4m3,
kFloat8_e4m3b11fnuz = kDLFloat8_e4m3b11fnuz,
@@ -137,8 +138,10 @@
}
/*! \return whether type is a scalar type. */
bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
- /*! \return whether type is a scalar type. */
- bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
+ /*! \return whether type is a bool type. */
+ bool is_bool() const { return code() == DataType::kBool; }
+ /*! \return whether type can be used in a predicate expression. */
+ bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits() == 1); }
/*! \return whether type is a float type. */
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a bfloat type. */
@@ -204,7 +207,7 @@
/*! \return whether type is a vector type. */
bool is_vector() const { return lanes() > 1; }
/*! \return whether type is a bool vector type. */
- bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; }
+ bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && is_bool(); }
/*! \return whether type is a Void type. */
bool is_void() const {
return code() == DataType::kHandle && bits() == 0 && static_cast<int16_t>(data_.lanes) == 0;
@@ -381,7 +384,7 @@
* \return The constructed data type.
*/
static DataType Bool(int lanes = 1, bool is_scalable = false) {
- return DataType::UInt(1, lanes, is_scalable);
+ return DataType(kDLBool, 8, lanes, is_scalable);
}
/*!
* \brief Construct a handle type.
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 6a0f427..57f8681 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -816,7 +816,7 @@
* \return The result expression.
*/
inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
- return make_const(DataType::UInt(1, lanes), 1);
+ return make_const(DataType::Bool(lanes), 1);
}
/*!
* \brief Make a constant false expression.
@@ -825,7 +825,7 @@
* \return The result expression.
*/
inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
- return make_const(DataType::UInt(1, lanes), 0);
+ return make_const(DataType::Bool(lanes), 0);
}
/*!
* \brief Get x as constant int expression.
@@ -957,7 +957,7 @@
template <typename ValueType>
inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
- if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span);
+ if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), span);
if (t.is_uint()) {
// Use IntImm if it is a small integer
uint64_t uval = static_cast<uint64_t>(value);
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
index 22f996a..b22b0a7 100644
--- a/python/tvm/script/parser/tir/operation.py
+++ b/python/tvm/script/parser/tir/operation.py
@@ -61,6 +61,7 @@
if (
DataType(b.dtype).type_code == DataTypeCode.INT
or DataType(b.dtype).type_code == DataTypeCode.UINT
+ or DataType(b.dtype).type_code == DataTypeCode.BOOL
):
a = IntImm(_get_type_str(b.dtype), a)
elif DataType(b.dtype).type_code == DataTypeCode.FLOAT:
@@ -80,6 +81,7 @@
if (
DataType(a.dtype).type_code == DataTypeCode.INT
or DataType(a.dtype).type_code == DataTypeCode.UINT
+ or DataType(a.dtype).type_code == DataTypeCode.BOOL
):
b = IntImm(_get_type_str(a.dtype), b)
elif DataType(a.dtype).type_code == DataTypeCode.FLOAT:
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index d6466b0..a6313ae 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -448,7 +448,7 @@
)
buffer_var = buffer.data
- self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
+ self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="bool"), x))
return BufferVar(self, buffer, dtype)
def pointer(self, content_type, name="ptr", scope=""):
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index dda7f67..5118204 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -349,8 +349,8 @@
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value);
});
return std::nullopt;
}
@@ -358,8 +358,8 @@
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value);
});
return std::nullopt;
}
@@ -367,8 +367,8 @@
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value);
});
return std::nullopt;
}
@@ -376,8 +376,8 @@
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value <= pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value <= fb->value);
});
return std::nullopt;
}
@@ -385,8 +385,8 @@
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value == pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value == fb->value);
});
return std::nullopt;
}
@@ -394,8 +394,8 @@
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value != pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value != fb->value);
});
return std::nullopt;
}
@@ -426,7 +426,7 @@
inline ffi::Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) {
const IntImmNode* pa = a.as<IntImmNode>();
if (pa) {
- return IntImm(DataType::UInt(1), !(pa->value));
+ return IntImm(DataType::Bool(), !(pa->value));
}
return std::nullopt;
}
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 7e1d8fb..d8296ba 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -735,9 +735,12 @@
* \return Bound that represent everything dtype can represent.
*/
static Entry Everything(DataType dtype) {
- if (!dtype.is_int() && !dtype.is_uint()) {
+ if (!dtype.is_int() && !dtype.is_uint() && !dtype.is_bool()) {
return MakeBound(kNegInf, kPosInf);
}
+ if (dtype.is_bool()) {
+ return MakeBound(0, 1);
+ }
Entry ret;
int64_t vbits = dtype.bits() - static_cast<int>(dtype.is_int());
if (dtype.is_uint()) {
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 6c0065c..b856854 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -53,8 +53,9 @@
IntImm::IntImm(DataType dtype, int64_t value, Span span) {
ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype
<< " was supplied.";
- ICHECK(dtype.is_int() || dtype.is_uint())
- << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied.";
+ ICHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool())
+ << "ValueError: IntImm supports only int or uint or bool type, but " << dtype
+ << " was supplied.";
if (dtype.is_uint()) {
ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
<< " is negative for unsigned integer type " << dtype;
@@ -62,7 +63,7 @@
ICHECK_LT(value, 1LL << dtype.bits())
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
- } else if (dtype.bits() == 1) {
+ } else if (dtype.bits() == 1 || dtype.is_bool()) {
// int(1)
ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype;
} else if (dtype.bits() < 64) {
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index ff8596c..5bcb5f2 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -328,7 +328,7 @@
*static_cast<int32_t*>(arr->data) = static_cast<int32_t>(value);
} else if (dtype == DataType::Int(64)) {
*static_cast<int64_t*>(arr->data) = static_cast<int64_t>(value);
- } else if (dtype == DataType::UInt(1)) {
+ } else if (dtype == DataType::Bool()) {
*static_cast<bool*>(arr->data) = static_cast<bool>(value);
} else if (dtype == DataType::UInt(8)) {
*static_cast<uint8_t*>(arr->data) = static_cast<uint8_t>(value);
diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc
index 13446a1..1bd3084 100644
--- a/src/runtime/vm/builtin.cc
+++ b/src/runtime/vm/builtin.cc
@@ -535,7 +535,7 @@
if (arr->device.device_type != kDLCPU) {
arr = arr.CopyTo(DLDevice{kDLCPU, 0});
}
- ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt);
+ ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt || arr->dtype.code == kDLBool);
int64_t result;
switch (arr->dtype.bits) {
case 1: {
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index bdb0c6b..5f8b599 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -148,6 +148,7 @@
// types
t_void_ = llvm::Type::getVoidTy(*ctx);
t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), GetGlobalAddressSpace());
+ t_int1_ = llvm::Type::getInt1Ty(*ctx);
t_int_ = llvm::Type::getInt32Ty(*ctx);
t_char_ = llvm::Type::getInt8Ty(*ctx);
t_int8_ = llvm::Type::getInt8Ty(*ctx);
@@ -576,6 +577,8 @@
llvm::LLVMContext* ctx = llvm_target_->GetContext();
if (dtype.is_int() || dtype.is_uint()) {
etype = llvm::Type::getIntNTy(*ctx, dtype.bits());
+ } else if (dtype.is_bool()) {
+ etype = t_int1_;
} else if (dtype.is_float()) {
switch (dtype.bits()) {
case 16:
@@ -922,7 +925,7 @@
if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
- } else if (to.is_uint() && to.bits() == 1) {
+ } else if (to.is_bool()) {
if (from.is_float()) {
llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
return builder_->CreateFCmpONE(value, zero);
@@ -943,7 +946,7 @@
}
} else if (from.is_int() && to.is_float()) {
return builder_->CreateSIToFP(value, target);
- } else if (from.is_uint() && to.is_float()) {
+ } else if ((from.is_uint() || from.is_bool()) && to.is_float()) {
return builder_->CreateUIToFP(value, target);
} else {
ICHECK(from.is_float() && to.is_float());
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 5cf053c..efec7ad 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -536,6 +536,7 @@
llvm::Type* t_void_{nullptr};
llvm::PointerType* t_void_p_{nullptr};
llvm::Type* t_int_{nullptr};
+ llvm::Type* t_int1_{nullptr};
llvm::Type* t_char_{nullptr};
llvm::Type* t_int8_{nullptr};
llvm::Type* t_int16_{nullptr};
diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index 769401c..8ea55b8 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -230,6 +230,12 @@
os << lanes;
return;
}
+ } else if (t.is_bool()) {
+ os << "uint";
+ if (!fail && ((lanes >= 2 && lanes <= 4) || lanes == 8 || lanes == 16)) {
+ os << lanes;
+ return;
+ }
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
os << 'u';
diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc
index 60fa786..917036b 100644
--- a/src/target/source/codegen_source_base.cc
+++ b/src/target/source/codegen_source_base.cc
@@ -109,6 +109,11 @@
os << "void";
return;
}
+ // Plain C may lack a native bool type, so emit int here; subclasses can override this.
+ if (type.is_bool()) {
+ os << "int";
+ return;
+ }
if (type.is_float()) {
if (type.bits() == 32) {
os << "float";
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index ddbc22d..c062926 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -430,7 +430,7 @@
spirv::Value dst_ptr =
builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index));
spirv::Value src_ptr = VisitExpr(op->args[5]);
- spirv::SType type_bool = builder_->GetSType(DataType::UInt(1));
+ spirv::SType type_bool = builder_->GetSType(DataType::Bool());
spirv::Value t_val = builder_->UIntImm(type_bool, 1);
spirv::Value f_val = builder_->UIntImm(type_bool, 0);
spirv::Value loaded =
@@ -492,7 +492,7 @@
builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index));
uint32_t mask = spv::MemoryAccessMaskNone;
spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, mask);
- spirv::SType type_bool = builder_->GetSType(DataType::UInt(1));
+ spirv::SType type_bool = builder_->GetSType(DataType::Bool());
spirv::Value t_val = builder_->UIntImm(type_bool, 1);
spirv::Value f_val = builder_->UIntImm(type_bool, 0);
builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, stride_val,
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index 545e677..bac66a3 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -76,7 +76,7 @@
ext_glsl450_ = ExtInstImport("GLSL.std.450");
t_int32_ = DeclareType(DataType::Int(32));
t_uint32_ = DeclareType(DataType::UInt(32));
- t_bool_ = DeclareType(DataType::UInt(1));
+ t_bool_ = DeclareType(DataType::Bool());
t_fp32_ = DeclareType(DataType::Float(32));
const_i32_zero_ = IntImm(t_int32_, 0);
@@ -115,7 +115,7 @@
SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) {
if (dtype == DataType::Int(32)) {
return t_int32_;
- } else if (dtype == DataType::UInt(1)) {
+ } else if (dtype == DataType::Bool()) {
return t_bool_;
} else if (dtype == DataType::Float(32)) {
return t_fp32_;
@@ -467,7 +467,7 @@
}
ICHECK_LE(dtype.type.bits(), 64);
Value ret = NewValue(dtype, kConstant);
- if (dtype.type == DataType::UInt(1)) {
+ if (dtype.type == DataType::Bool()) {
// bool types.
if (*pvalue) {
ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret);
@@ -501,8 +501,7 @@
SType t;
t.id = id_counter_++;
t.type = dtype;
- if (dtype.bits() == 1) {
- ICHECK(dtype.is_uint());
+ if (dtype.is_bool()) {
ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_);
} else if (dtype.is_int()) {
ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_);
@@ -584,7 +583,7 @@
// future. Requiring StorageBuffer8BitAccess in order to declare an
// Int8 prevents use of an 8-bit loop iterator on a device that
// supports Int8 but doesn't support 8-bit buffer access.
- if (dtype.bits() == 8) {
+ if (dtype.bits() == 8 && !dtype.is_bool()) {
ICHECK(spirv_support_.supports_storage_buffer_8bit_access)
<< "Vulkan target does not support StorageBuffer8BitAccess. "
<< "If your device supports 8-bit buffer access, "
@@ -822,19 +821,19 @@
}
}
-#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
- Value IRBuilder::_OpName(Value a, Value b) { \
- ICHECK_EQ(a.stype.id, b.stype.id); \
- ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
- const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
- if (a.stype.type.is_int()) { \
- return MakeValue(spv::OpS##_Op, bool_type, a, b); \
- } else if (a.stype.type.is_uint()) { \
- return MakeValue(spv::OpU##_Op, bool_type, a, b); \
- } else { \
- ICHECK(a.stype.type.is_float()); \
- return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
- } \
+#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
+ Value IRBuilder::_OpName(Value a, Value b) { \
+ ICHECK_EQ(a.stype.id, b.stype.id); \
+ ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
+ const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \
+ if (a.stype.type.is_int()) { \
+ return MakeValue(spv::OpS##_Op, bool_type, a, b); \
+ } else if (a.stype.type.is_uint()) { \
+ return MakeValue(spv::OpU##_Op, bool_type, a, b); \
+ } else { \
+ ICHECK(a.stype.type.is_float()); \
+ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
+ } \
}
DEFINE_BUILDER_CMP_OP(LT, LessThan);
@@ -842,17 +841,17 @@
DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
-#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
- Value IRBuilder::_OpName(Value a, Value b) { \
- ICHECK_EQ(a.stype.id, b.stype.id); \
- ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
- const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
- if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
- return MakeValue(spv::OpI##_Op, bool_type, a, b); \
- } else { \
- ICHECK(a.stype.type.is_float()); \
- return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
- } \
+#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
+ Value IRBuilder::_OpName(Value a, Value b) { \
+ ICHECK_EQ(a.stype.id, b.stype.id); \
+ ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
+ const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \
+ if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
+ return MakeValue(spv::OpI##_Op, bool_type, a, b); \
+ } else { \
+ ICHECK(a.stype.type.is_float()); \
+ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
+ } \
}
DEFINE_BUILDER_CMP_UOP(EQ, Equal);
@@ -860,7 +859,7 @@
Value IRBuilder::Select(Value cond, Value a, Value b) {
ICHECK_EQ(a.stype.id, b.stype.id);
- ICHECK_EQ(cond.stype.type.element_of(), DataType::UInt(1));
+ ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool());
return MakeValue(spv::OpSelect, a.stype, cond, a, b);
}
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 252b869..5eee4ff 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -840,7 +840,7 @@
<< " lanes. The number of lanes must match.";
DataType predicate_element_dtype = predicate_dtype.element_of();
- ICHECK(predicate_element_dtype.is_bool())
+ ICHECK(predicate_element_dtype.is_predicate_dtype())
<< "Predicate mask elements must be boolean values, but got " << predicate_element_dtype
<< ".";
}
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index d33a013..4762275 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -485,7 +485,7 @@
<< " lanes. The number of lanes must match.";
DataType predicate_element_dtype = predicate_dtype.element_of();
- ICHECK(predicate_element_dtype.is_bool())
+ ICHECK(predicate_element_dtype.is_predicate_dtype())
<< "Predicate mask elements must be boolean values, but got " << predicate_element_dtype
<< ".";
}
@@ -687,7 +687,8 @@
Span span) {
CHECK_EQ(block->iter_vars.size(), values.size())
<< "ValueError: BlockRealize needs to have the same number of iter_vars and binding values";
- CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression";
+ CHECK(predicate.dtype().is_bool() || predicate.dtype() == DataType::UInt(1))
+ << "TypeError: Expect Block.predicate to be a bool expression";
ObjectPtr<BlockRealizeNode> node = ffi::make_object<BlockRealizeNode>();
node->iter_values = std::move(values);
node->predicate = std::move(predicate);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 935f992..51c0b64 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -214,6 +214,12 @@
} else if (ltype.is_float4() && !rtype.is_float4()) {
// Cast int->float4 for rhs when lhs is a float4
rhs = cast(ltype, rhs);
+ } else if (ltype.is_bool() && (rtype.is_int() || rtype.is_uint())) {
+ // Cast the bool lhs to the integer type of rhs when rhs is an int or uint
+ lhs = cast(rtype, lhs);
+ } else if ((ltype.is_int() || ltype.is_uint()) && rtype.is_bool()) {
+ // Cast the bool rhs to the integer type of lhs when lhs is an int or uint
+ rhs = cast(ltype, rhs);
} else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
// Promote int to higher bits e.g. int8 + int16 --> int16 + int16
if (ltype.bits() < rtype.bits()) {
@@ -621,7 +627,7 @@
// if_then_else
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
- ICHECK(cond.dtype() == DataType::Bool(1))
+ ICHECK(cond.dtype() == DataType::Bool())
<< "if_then_else only accept the condition to be boolean type.";
BinaryOpMatchTypes(true_value, false_value, span);
if (const IntImmNode* op = cond.as<IntImmNode>()) {
@@ -698,10 +704,10 @@
<< rhs << " of type " << rhs.dtype();
}
-void type_check_integer_args(const PrimExpr& arg, const char* op) {
- ICHECK(arg.dtype().is_int() || arg.dtype().is_uint())
- << "Expected integer argument for " << op << ", but received " << arg << " of type "
- << arg.dtype();
+void type_check_int_or_bool_args(const PrimExpr& arg, const char* op) {
+ ICHECK(arg.dtype().is_int() || arg.dtype().is_uint() || arg.dtype().is_bool())
+ << "Expected integer or boolean argument for " << op << ", but received " << arg
+ << " of type " << arg.dtype();
}
void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
@@ -712,6 +718,15 @@
<< "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type "
<< rhs.dtype();
}
+
+void type_check_int_or_bool_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
+ ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint() || lhs.dtype().is_bool())
+ << "Expected integer or boolean argument as LHS of " << op << ", but received " << lhs
+ << " of type " << lhs.dtype();
+ ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint() || rhs.dtype().is_bool())
+ << "Expected integer or boolean argument as RHS of " << op << ", but received " << rhs
+ << " of type " << rhs.dtype();
+}
} // namespace
PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
@@ -781,7 +796,7 @@
// bitwise and
PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); }
PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
- type_check_integer_args(a, b, "& operator (bitwise AND)");
+ type_check_int_or_bool_args(a, b, "& operator (bitwise AND)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
@@ -793,7 +808,7 @@
// bitwise_or
PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); }
PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
- type_check_integer_args(a, b, "| operator (bitwise OR)");
+ type_check_int_or_bool_args(a, b, "| operator (bitwise OR)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
@@ -805,7 +820,7 @@
// bitwise_xor
PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); }
PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
- type_check_integer_args(a, b, "^ operator (bitwise XOR)");
+ type_check_int_or_bool_args(a, b, "^ operator (bitwise XOR)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
@@ -818,7 +833,7 @@
PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); }
PrimExpr bitwise_neg(PrimExpr a, Span span) {
- type_check_integer_args(a, "~ operator (bitwise NOT)");
+ type_check_int_or_bool_args(a, "~ operator (bitwise NOT)");
return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span);
}
@@ -935,7 +950,7 @@
PrimExpr result = tir::Add(x, y, span);
PrimExpr identity_element = make_zero(source.dtype(), span);
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
}
PrimExpr all(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> init, Span span) {
@@ -944,7 +959,7 @@
PrimExpr result = tir::And(x, y, span);
PrimExpr identity_element = make_const(source.dtype(), true, span);
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
}
PrimExpr any(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> init, Span span) {
@@ -953,7 +968,7 @@
PrimExpr result = tir::Or(x, y, span);
PrimExpr identity_element = make_const(source.dtype(), false, span);
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
}
PrimExpr max(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> init, Span span) {
@@ -961,7 +976,7 @@
PrimExpr result = tir::Max(x, y, span);
PrimExpr identity_element = min_value(source.dtype(), span);
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
}
PrimExpr min(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> init, Span span) {
@@ -969,7 +984,7 @@
PrimExpr result = tir::Min(x, y, span);
PrimExpr identity_element = max_value(source.dtype(), span);
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
}
PrimExpr prod(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> init, Span span) {
@@ -977,7 +992,7 @@
PrimExpr result = tir::Mul(x, y, span);
PrimExpr identity_element = make_const(source.dtype(), 1, span);
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
}
// fmod
@@ -992,7 +1007,7 @@
// floor
PrimExpr floor(PrimExpr x, Span span) {
- if (x.dtype().is_int() || x.dtype().is_uint()) {
+ if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
return x;
}
using tir::FloatImmNode;
@@ -1006,7 +1021,7 @@
// ceil
PrimExpr ceil(PrimExpr x, Span span) {
- if (x.dtype().is_int() || x.dtype().is_uint()) {
+ if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
return x;
}
using tir::FloatImmNode;
@@ -1020,7 +1035,7 @@
// round
PrimExpr round(PrimExpr x, Span span) {
- if (x.dtype().is_int() || x.dtype().is_uint()) {
+ if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
return x;
}
using tir::FloatImmNode;
@@ -1034,7 +1049,7 @@
// nearbyint
PrimExpr nearbyint(PrimExpr x, Span span) {
- if (x.dtype().is_int() || x.dtype().is_uint()) {
+ if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
return x;
}
using tir::FloatImmNode;
@@ -1048,7 +1063,7 @@
// trunc
PrimExpr trunc(PrimExpr x, Span span) {
- if (x.dtype().is_int() || x.dtype().is_uint()) {
+ if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
return x;
}
using tir::FloatImmNode;
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 8a5d39e..1b85d7d 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -218,7 +218,7 @@
init_nest_.emplace_back(LetStmt(
buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
- PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});
+ PrimExpr v_strides_is_null = Call(DataType::Bool(), builtin::isnullptr(), {buf_strides->data});
if (buffer->strides.size() == 0) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
diff --git a/src/tir/transforms/inject_ptx_ldg32.cc b/src/tir/transforms/inject_ptx_ldg32.cc
index 1b4bd7b..8cdef1b 100644
--- a/src/tir/transforms/inject_ptx_ldg32.cc
+++ b/src/tir/transforms/inject_ptx_ldg32.cc
@@ -41,7 +41,7 @@
// addr[0] -> global_addr / addr[1] -> local_addr
addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, DataType::Int(32), "addr", "local");
predicate_buffer =
- decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(1), "predicate", "local");
+ decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), "predicate", "local");
}
Stmt result = StmtMutator::VisitStmt_(allocate);
if (!has_buffer_2) {
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index f6df6c8..66e1379 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -256,7 +256,7 @@
Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));
Stmt alloc_nullptr_check = IfThenElse(
- Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), throw_last_error);
+ Call(DataType::Bool(), builtin::isnullptr(), {op->buffer_var}), throw_last_error);
PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
{cast(DataType::Int(32), device_type_.value()),
cast(DataType::Int(32), device_id_.value()), op->buffer_var});
@@ -617,7 +617,7 @@
Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
Stmt body = SeqStmt(
- {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error),
+ {IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), throw_last_error),
let->body, free_stmt});
DataType dtype =
diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc
index 6c42972..6ae6deb 100644
--- a/tests/cpp/tir_scalable_datatype.cc
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -167,8 +167,8 @@
TEST(ScalableDataType, TestScalableBool) {
tvm::DataType scalable_type = tvm::DataType::Bool(4, true);
- ASSERT_EQ(scalable_type.code(), kDLUInt);
- ASSERT_EQ(scalable_type.bits(), 1);
+ ASSERT_EQ(scalable_type.code(), kDLBool);
+ ASSERT_EQ(scalable_type.bits(), 8);
ASSERT_EQ(scalable_type.vscale_factor(), 4);
ASSERT_TRUE(scalable_type.is_scalable_vector());
}
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py
index 6954cf4..5eaaac6 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -93,7 +93,7 @@
x, y, z = te.var("x"), te.var("y"), te.var("z")
x64 = te.var("x", dtype="int64")
vx = te.var("vx", dtype="int32x2")
- vc = te.var("vc", dtype="uint1")
+ vc = te.var("vc", dtype="bool")
test_case = tvm.testing.parameter(
# Add rules
TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x + y, 3, 4)),
@@ -285,22 +285,22 @@
tvm.te.max(vx, tvm.te.max(y, x).astype("int32x2")),
),
## Logical rules
- TestCase(y.astype("int32x2").equal(x.astype("int32x2")), (y.equal(x)).astype("uint1x2")),
+ TestCase(y.astype("int32x2").equal(x.astype("int32x2")), (y.equal(x)).astype("boolx2")),
TestCase(
tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))),
- (tvm.tir.NE(y, x)).astype("uint1x2"),
+ (tvm.tir.NE(y, x)).astype("boolx2"),
),
- TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < y).astype("uint1x2")),
- TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= y).astype("uint1x2")),
- TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < x).astype("uint1x2")),
- TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= x).astype("uint1x2")),
+ TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < y).astype("boolx2")),
+ TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= y).astype("boolx2")),
+ TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < x).astype("boolx2")),
+ TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= x).astype("boolx2")),
TestCase(
- tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
- (tvm.tir.And(y <= x, vc)).astype("uint1x2"),
+ tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("boolx2")),
+ (tvm.tir.And(y <= x, vc)).astype("boolx2"),
),
TestCase(
- tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
- (tvm.tir.Or(y <= x, vc)).astype("uint1x2"),
+ tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("boolx2")),
+ (tvm.tir.Or(y <= x, vc)).astype("boolx2"),
),
)
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index a0ff507..b076827 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -1721,7 +1721,6 @@
w = relax.Var("w", R.Tensor((5,), "float32"))
targets0 = relax.Var("targets", R.Tensor((3, 10, 10), "float32"))
targets1 = relax.Var("targets", R.Tensor((3, 10, 10), "float64"))
- targets2 = relax.Var("targets", R.Tensor((3, 10, 10), "bool"))
targets3 = relax.Var("targets", R.Tensor((3, 10, 10), "int32"))
targets4 = relax.Var("targets", R.Tensor((3, 10, 10), "int64"))
targets5 = relax.Var("targets", R.Tensor((3, 10, 10), "uint32"))
@@ -1733,7 +1732,6 @@
bb.normalize(relax.op.nn.nll_loss(x, targets1, w))
# correct cases
- bb.normalize(relax.op.nn.nll_loss(x, targets2, w)) # bool is uint1
bb.normalize(relax.op.nn.nll_loss(x, targets3, w))
bb.normalize(relax.op.nn.nll_loss(x, targets4, w))
bb.normalize(relax.op.nn.nll_loss(x, targets5, w))
diff --git a/tests/python/tir-base/test_tir_constructor.py b/tests/python/tir-base/test_tir_constructor.py
index 42c2998..4076070 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -140,7 +140,7 @@
assert isinstance(x, tvm.tir.AttrStmt)
assert x.value.value == 1
- x = tvm.tir.AssertStmt(tvm.tir.const(1, "uint1"), tvm.runtime.convert("hellow"), nop)
+ x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"), tvm.runtime.convert("hellow"), nop)
assert isinstance(x, tvm.tir.AssertStmt)
assert x.body == nop
@@ -150,8 +150,8 @@
assert x.extent.value == 10
assert x.body == nop
- buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("uint1")))
- buffer = tvm.tir.decl_buffer([16], "uint1", data=buffer_var)
+ buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("bool")))
+ buffer = tvm.tir.decl_buffer([16], "bool", data=buffer_var)
x = tvm.tir.BufferStore(buffer, tvm.tir.IntImm("bool", 1), [10])
assert isinstance(x, tvm.tir.BufferStore)
assert x.buffer == buffer
@@ -160,7 +160,7 @@
assert x.value.value == 1
buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))
- x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
+ x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "bool"), nop)
assert isinstance(x, tvm.tir.Allocate)
assert x.dtype == "float32"
assert x.buffer_var == buffer_var
@@ -168,7 +168,7 @@
storage_scope = "global.texture"
buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope))
- x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
+ x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "bool"), nop)
assert isinstance(x, tvm.tir.Allocate)
assert x.dtype == "float32"
assert x.buffer_var == buffer_var
@@ -181,7 +181,7 @@
assert x.attr_key == "xyz"
assert x.body == nop
- x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11), nop)
+ x = tvm.tir.IfThenElse(tvm.tir.const(1, "bool"), tvm.tir.Evaluate(11), nop)
assert isinstance(x, tvm.tir.IfThenElse)
assert x.then_case.value.value == 11
assert x.else_case == nop
diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py
index 5e1d25e..bc7cfea 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -302,7 +302,7 @@
z = te.var("z", "int32")
assert str(tvm.tir.isnan(z)) == "T.bool(False)"
k = te.var("k", "int8x2")
- assert str(tvm.tir.isnan(k).dtype) == "uint1x2"
+ assert str(tvm.tir.isnan(k).dtype) == "boolx2"
def test_equality():
diff --git a/tests/python/tir-base/test_tir_ops.py b/tests/python/tir-base/test_tir_ops.py
index dfa5cba..cb7d8c5 100644
--- a/tests/python/tir-base/test_tir_ops.py
+++ b/tests/python/tir-base/test_tir_ops.py
@@ -69,8 +69,8 @@
x = te.var("x")
for val in [0, 1]:
for func in [tvm.tir.all, tvm.tir.any]:
- check_throws(lambda: func(tvm.tir.const(val, "uint1"), x))
- check_throws(lambda: func(x, tvm.tir.const(val, "uint1")))
+ check_throws(lambda: func(tvm.tir.const(val, "bool"), x))
+ check_throws(lambda: func(x, tvm.tir.const(val, "bool")))
# Test const folding when both arguments are const
for tvm_func, py_func in [
@@ -80,13 +80,13 @@
for v1 in [0, 1]:
for v2 in [0, 1]:
tvm.ir.assert_structural_equal(
- tvm_func(tvm.tir.const(v1, "uint1"), tvm.tir.const(v2, "uint1")),
- tvm.tir.const(py_func(v1, v2), "uint1"),
+ tvm_func(tvm.tir.const(v1, "bool"), tvm.tir.const(v2, "bool")),
+ tvm.tir.const(py_func(v1, v2), "bool"),
)
- x = te.var("x", "uint1")
- true = tvm.tir.const(1, "uint1")
- false = tvm.tir.const(0, "uint1")
+ x = te.var("x", "bool")
+ true = tvm.tir.const(1, "bool")
+ false = tvm.tir.const(0, "bool")
assert tvm.tir.all(x, true).same_as(x)
assert tvm.tir.all(true, x).same_as(x)
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index db6f4ba..8352b11 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -366,7 +366,7 @@
# the expected allocate
buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "local"))
ir_expected = tir.Allocate(
- buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1)
+ buffer_var, "float32", [10], tvm.tir.const(1, "bool"), tir.Evaluate(1)
)
# Check if the generated ir is expected
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index fc7deac..e4af158 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -961,13 +961,13 @@
buffer_load = tir.BufferLoad(
buffer=buffer_map[b],
indices=[0, tir.Ramp(0, 4, 4)],
- predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+ predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
)
body = tir.BufferStore(
buffer=buffer_map[a],
value=buffer_load,
indices=[0, tir.Ramp(0, 2, 4)],
- predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+ predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
)
func = tir.PrimFunc(
params=[a, b],