[DataType] Update to use explicit Bool Type Aligning with DLPack (#18453)

This PR updates the project to use an explicit bool type (kDLBool), which
helps it align with DLPack and streamlines the explicit use of bool types.
diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi
index f703a0c..ae346ec 160000
--- a/3rdparty/tvm-ffi
+++ b/3rdparty/tvm-ffi
@@ -1 +1 @@
-Subproject commit f703a0cf9358fa30d8faee719f905c58d8ca6ee3
+Subproject commit ae346ec92a3c386f1376064ae086aae72947c329
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 0af3022..0c69833 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -60,6 +60,7 @@
     kFloat = kDLFloat,
     kHandle = kDLOpaqueHandle,
     kBFloat = kDLBfloat,
+    kBool = kDLBool,
     kFloat8_e3m4 = kDLFloat8_e3m4,
     kFloat8_e4m3 = kDLFloat8_e4m3,
     kFloat8_e4m3b11fnuz = kDLFloat8_e4m3b11fnuz,
@@ -137,8 +138,10 @@
   }
   /*! \return whether type is a scalar type. */
   bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
-  /*! \return whether type is a scalar type. */
-  bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
+  /*! \return whether type is a bool type. */
+  bool is_bool() const { return code() == DataType::kBool; }
+  /*! \return whether type can be used in a predicate expression. */
+  bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits() == 1); }
   /*! \return whether type is a float type. */
   bool is_float() const { return code() == DataType::kFloat; }
   /*! \return whether type is a bfloat type. */
@@ -204,7 +207,7 @@
   /*! \return whether type is a vector type. */
   bool is_vector() const { return lanes() > 1; }
   /*! \return whether type is a bool vector type. */
-  bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; }
+  bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && is_bool(); }
   /*! \return whether type is a Void type. */
   bool is_void() const {
     return code() == DataType::kHandle && bits() == 0 && static_cast<int16_t>(data_.lanes) == 0;
@@ -381,7 +384,7 @@
    * \return The constructed data type.
    */
   static DataType Bool(int lanes = 1, bool is_scalable = false) {
-    return DataType::UInt(1, lanes, is_scalable);
+    return DataType(kDLBool, 8, lanes, is_scalable);
   }
   /*!
    * \brief Construct a handle type.
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 6a0f427..57f8681 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -816,7 +816,7 @@
  * \return The result expression.
  */
 inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
-  return make_const(DataType::UInt(1, lanes), 1);
+  return make_const(DataType::Bool(lanes), 1);
 }
 /*!
  * \brief Make a constant false expression.
@@ -825,7 +825,7 @@
  * \return The result expression.
  */
 inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
-  return make_const(DataType::UInt(1, lanes), 0);
+  return make_const(DataType::Bool(lanes), 0);
 }
 /*!
  * \brief Get x as constant int expression.
@@ -957,7 +957,7 @@
 
 template <typename ValueType>
 inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
-  if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span);
+  if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), span);
   if (t.is_uint()) {
     // Use IntImm if it is a small integer
     uint64_t uval = static_cast<uint64_t>(value);
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
index 22f996a..b22b0a7 100644
--- a/python/tvm/script/parser/tir/operation.py
+++ b/python/tvm/script/parser/tir/operation.py
@@ -61,6 +61,7 @@
                 if (
                     DataType(b.dtype).type_code == DataTypeCode.INT
                     or DataType(b.dtype).type_code == DataTypeCode.UINT
+                    or DataType(b.dtype).type_code == DataTypeCode.BOOL
                 ):
                     a = IntImm(_get_type_str(b.dtype), a)
                 elif DataType(b.dtype).type_code == DataTypeCode.FLOAT:
@@ -80,6 +81,7 @@
             if (
                 DataType(a.dtype).type_code == DataTypeCode.INT
                 or DataType(a.dtype).type_code == DataTypeCode.UINT
+                or DataType(a.dtype).type_code == DataTypeCode.BOOL
             ):
                 b = IntImm(_get_type_str(a.dtype), b)
             elif DataType(a.dtype).type_code == DataTypeCode.FLOAT:
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index d6466b0..a6313ae 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -448,7 +448,7 @@
         )
 
         buffer_var = buffer.data
-        self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
+        self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="bool"), x))
         return BufferVar(self, buffer, dtype)
 
     def pointer(self, content_type, name="ptr", scope=""):
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index dda7f67..5118204 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -349,8 +349,8 @@
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value);
   });
   return std::nullopt;
 }
@@ -358,8 +358,8 @@
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value);
   });
   return std::nullopt;
 }
@@ -367,8 +367,8 @@
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value);
   });
   return std::nullopt;
 }
@@ -376,8 +376,8 @@
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value <= pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value <= fb->value);
   });
   return std::nullopt;
 }
@@ -385,8 +385,8 @@
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value == pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value == fb->value);
   });
   return std::nullopt;
 }
@@ -394,8 +394,8 @@
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value != pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value != fb->value);
   });
   return std::nullopt;
 }
@@ -426,7 +426,7 @@
 inline ffi::Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) {
   const IntImmNode* pa = a.as<IntImmNode>();
   if (pa) {
-    return IntImm(DataType::UInt(1), !(pa->value));
+    return IntImm(DataType::Bool(), !(pa->value));
   }
   return std::nullopt;
 }
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 7e1d8fb..d8296ba 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -735,9 +735,12 @@
    * \return Bound that represent everything dtype can represent.
    */
   static Entry Everything(DataType dtype) {
-    if (!dtype.is_int() && !dtype.is_uint()) {
+    if (!dtype.is_int() && !dtype.is_uint() && !dtype.is_bool()) {
       return MakeBound(kNegInf, kPosInf);
     }
+    if (dtype.is_bool()) {
+      return MakeBound(0, 1);
+    }
     Entry ret;
     int64_t vbits = dtype.bits() - static_cast<int>(dtype.is_int());
     if (dtype.is_uint()) {
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 6c0065c..b856854 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -53,8 +53,9 @@
 IntImm::IntImm(DataType dtype, int64_t value, Span span) {
   ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype
                             << " was supplied.";
-  ICHECK(dtype.is_int() || dtype.is_uint())
-      << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied.";
+  ICHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool())
+      << "ValueError: IntImm supports only int or uint or bool type, but " << dtype
+      << " was supplied.";
   if (dtype.is_uint()) {
     ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
                          << " is negative for unsigned integer type " << dtype;
@@ -62,7 +63,7 @@
       ICHECK_LT(value, 1LL << dtype.bits())
           << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
     }
-  } else if (dtype.bits() == 1) {
+  } else if (dtype.bits() == 1 || dtype.is_bool()) {
     // int(1)
     ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype;
   } else if (dtype.bits() < 64) {
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index ff8596c..5bcb5f2 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -328,7 +328,7 @@
     *static_cast<int32_t*>(arr->data) = static_cast<int32_t>(value);
   } else if (dtype == DataType::Int(64)) {
     *static_cast<int64_t*>(arr->data) = static_cast<int64_t>(value);
-  } else if (dtype == DataType::UInt(1)) {
+  } else if (dtype == DataType::Bool()) {
     *static_cast<bool*>(arr->data) = static_cast<bool>(value);
   } else if (dtype == DataType::UInt(8)) {
     *static_cast<uint8_t*>(arr->data) = static_cast<uint8_t>(value);
diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc
index 13446a1..1bd3084 100644
--- a/src/runtime/vm/builtin.cc
+++ b/src/runtime/vm/builtin.cc
@@ -535,7 +535,7 @@
   if (arr->device.device_type != kDLCPU) {
     arr = arr.CopyTo(DLDevice{kDLCPU, 0});
   }
-  ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt);
+  ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt || arr->dtype.code == kDLBool);
   int64_t result;
   switch (arr->dtype.bits) {
     case 1: {
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index bdb0c6b..5f8b599 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -148,6 +148,7 @@
   // types
   t_void_ = llvm::Type::getVoidTy(*ctx);
   t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), GetGlobalAddressSpace());
+  t_int1_ = llvm::Type::getInt1Ty(*ctx);
   t_int_ = llvm::Type::getInt32Ty(*ctx);
   t_char_ = llvm::Type::getInt8Ty(*ctx);
   t_int8_ = llvm::Type::getInt8Ty(*ctx);
@@ -576,6 +577,8 @@
   llvm::LLVMContext* ctx = llvm_target_->GetContext();
   if (dtype.is_int() || dtype.is_uint()) {
     etype = llvm::Type::getIntNTy(*ctx, dtype.bits());
+  } else if (dtype.is_bool()) {
+    etype = t_int1_;
   } else if (dtype.is_float()) {
     switch (dtype.bits()) {
       case 16:
@@ -922,7 +925,7 @@
 
   if (to.is_handle()) {
     return builder_->CreateBitCast(value, target);
-  } else if (to.is_uint() && to.bits() == 1) {
+  } else if (to.is_predicate_dtype()) {  // bool or legacy uint1: compare against zero
     if (from.is_float()) {
       llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
       return builder_->CreateFCmpONE(value, zero);
@@ -943,7 +946,7 @@
     }
   } else if (from.is_int() && to.is_float()) {
     return builder_->CreateSIToFP(value, target);
-  } else if (from.is_uint() && to.is_float()) {
+  } else if ((from.is_uint() || from.is_bool()) && to.is_float()) {
     return builder_->CreateUIToFP(value, target);
   } else {
     ICHECK(from.is_float() && to.is_float());
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 5cf053c..efec7ad 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -536,6 +536,7 @@
   llvm::Type* t_void_{nullptr};
   llvm::PointerType* t_void_p_{nullptr};
   llvm::Type* t_int_{nullptr};
+  llvm::Type* t_int1_{nullptr};
   llvm::Type* t_char_{nullptr};
   llvm::Type* t_int8_{nullptr};
   llvm::Type* t_int16_{nullptr};
diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index 769401c..8ea55b8 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -230,6 +230,12 @@
       os << lanes;
       return;
     }
+  } else if (t.is_bool()) {
+    os << "uint";
+    if (!fail && ((lanes >= 2 && lanes <= 4) || lanes == 8 || lanes == 16)) {
+      os << lanes;
+      return;
+    }
   } else if (t.is_uint() || t.is_int()) {
     if (t.is_uint()) {
       os << 'u';
diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc
index 60fa786..917036b 100644
--- a/src/target/source/codegen_source_base.cc
+++ b/src/target/source/codegen_source_base.cc
@@ -109,6 +109,11 @@
     os << "void";
     return;
   }
+  // Plain C may lack a native bool type, so emit int here; subclasses can override this.
+  if (type.is_bool()) {
+    os << "int";
+    return;
+  }
   if (type.is_float()) {
     if (type.bits() == 32) {
       os << "float";
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index ddbc22d..c062926 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -430,7 +430,7 @@
     spirv::Value dst_ptr =
         builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index));
     spirv::Value src_ptr = VisitExpr(op->args[5]);
-    spirv::SType type_bool = builder_->GetSType(DataType::UInt(1));
+    spirv::SType type_bool = builder_->GetSType(DataType::Bool());
     spirv::Value t_val = builder_->UIntImm(type_bool, 1);
     spirv::Value f_val = builder_->UIntImm(type_bool, 0);
     spirv::Value loaded =
@@ -492,7 +492,7 @@
         builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index));
     uint32_t mask = spv::MemoryAccessMaskNone;
     spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, mask);
-    spirv::SType type_bool = builder_->GetSType(DataType::UInt(1));
+    spirv::SType type_bool = builder_->GetSType(DataType::Bool());
     spirv::Value t_val = builder_->UIntImm(type_bool, 1);
     spirv::Value f_val = builder_->UIntImm(type_bool, 0);
     builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, stride_val,
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index 545e677..bac66a3 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -76,7 +76,7 @@
   ext_glsl450_ = ExtInstImport("GLSL.std.450");
   t_int32_ = DeclareType(DataType::Int(32));
   t_uint32_ = DeclareType(DataType::UInt(32));
-  t_bool_ = DeclareType(DataType::UInt(1));
+  t_bool_ = DeclareType(DataType::Bool());
   t_fp32_ = DeclareType(DataType::Float(32));
   const_i32_zero_ = IntImm(t_int32_, 0);
 
@@ -115,7 +115,7 @@
 SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) {
   if (dtype == DataType::Int(32)) {
     return t_int32_;
-  } else if (dtype == DataType::UInt(1)) {
+  } else if (dtype == DataType::Bool()) {
     return t_bool_;
   } else if (dtype == DataType::Float(32)) {
     return t_fp32_;
@@ -467,7 +467,7 @@
   }
   ICHECK_LE(dtype.type.bits(), 64);
   Value ret = NewValue(dtype, kConstant);
-  if (dtype.type == DataType::UInt(1)) {
+  if (dtype.type == DataType::Bool()) {
     // bool types.
     if (*pvalue) {
       ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret);
@@ -501,8 +501,7 @@
     SType t;
     t.id = id_counter_++;
     t.type = dtype;
-    if (dtype.bits() == 1) {
-      ICHECK(dtype.is_uint());
+    if (dtype.is_bool()) {
       ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_);
     } else if (dtype.is_int()) {
       ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_);
@@ -584,7 +583,7 @@
   // future.  Requiring StorageBuffer8BitAccess in order to declare an
   // Int8 prevents use of an 8-bit loop iterator on a device that
   // supports Int8 but doesn't support 8-bit buffer access.
-  if (dtype.bits() == 8) {
+  if (dtype.bits() == 8 && !dtype.is_bool()) {
     ICHECK(spirv_support_.supports_storage_buffer_8bit_access)
         << "Vulkan target does not support StorageBuffer8BitAccess.  "
         << "If your device supports 8-bit buffer access, "
@@ -822,19 +821,19 @@
   }
 }
 
-#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                                                     \
-  Value IRBuilder::_OpName(Value a, Value b) {                                                  \
-    ICHECK_EQ(a.stype.id, b.stype.id);                                                          \
-    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                                      \
-    const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
-    if (a.stype.type.is_int()) {                                                                \
-      return MakeValue(spv::OpS##_Op, bool_type, a, b);                                         \
-    } else if (a.stype.type.is_uint()) {                                                        \
-      return MakeValue(spv::OpU##_Op, bool_type, a, b);                                         \
-    } else {                                                                                    \
-      ICHECK(a.stype.type.is_float());                                                          \
-      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                                      \
-    }                                                                                           \
+#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                                                    \
+  Value IRBuilder::_OpName(Value a, Value b) {                                                 \
+    ICHECK_EQ(a.stype.id, b.stype.id);                                                         \
+    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                                     \
+    const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \
+    if (a.stype.type.is_int()) {                                                               \
+      return MakeValue(spv::OpS##_Op, bool_type, a, b);                                        \
+    } else if (a.stype.type.is_uint()) {                                                       \
+      return MakeValue(spv::OpU##_Op, bool_type, a, b);                                        \
+    } else {                                                                                   \
+      ICHECK(a.stype.type.is_float());                                                         \
+      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                                     \
+    }                                                                                          \
   }
 
 DEFINE_BUILDER_CMP_OP(LT, LessThan);
@@ -842,17 +841,17 @@
 DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
 DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
 
-#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                                                    \
-  Value IRBuilder::_OpName(Value a, Value b) {                                                  \
-    ICHECK_EQ(a.stype.id, b.stype.id);                                                          \
-    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                                      \
-    const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
-    if (a.stype.type.is_int() || a.stype.type.is_uint()) {                                      \
-      return MakeValue(spv::OpI##_Op, bool_type, a, b);                                         \
-    } else {                                                                                    \
-      ICHECK(a.stype.type.is_float());                                                          \
-      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                                      \
-    }                                                                                           \
+#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                                                   \
+  Value IRBuilder::_OpName(Value a, Value b) {                                                 \
+    ICHECK_EQ(a.stype.id, b.stype.id);                                                         \
+    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                                     \
+    const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \
+    if (a.stype.type.is_int() || a.stype.type.is_uint()) {                                     \
+      return MakeValue(spv::OpI##_Op, bool_type, a, b);                                        \
+    } else {                                                                                   \
+      ICHECK(a.stype.type.is_float());                                                         \
+      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                                     \
+    }                                                                                          \
   }
 
 DEFINE_BUILDER_CMP_UOP(EQ, Equal);
@@ -860,7 +859,7 @@
 
 Value IRBuilder::Select(Value cond, Value a, Value b) {
   ICHECK_EQ(a.stype.id, b.stype.id);
-  ICHECK_EQ(cond.stype.type.element_of(), DataType::UInt(1));
+  ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool());
   return MakeValue(spv::OpSelect, a.stype, cond, a, b);
 }
 
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 252b869..5eee4ff 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -840,7 +840,7 @@
         << " lanes. The number of lanes must match.";
 
     DataType predicate_element_dtype = predicate_dtype.element_of();
-    ICHECK(predicate_element_dtype.is_bool())
+    ICHECK(predicate_element_dtype.is_predicate_dtype())
         << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype
         << ".";
   }
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index d33a013..4762275 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -485,7 +485,7 @@
         << " lanes. The number of lanes must match.";
 
     DataType predicate_element_dtype = predicate_dtype.element_of();
-    ICHECK(predicate_element_dtype.is_bool())
+    ICHECK(predicate_element_dtype.is_predicate_dtype())
         << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype
         << ".";
   }
@@ -687,7 +687,8 @@
                            Span span) {
   CHECK_EQ(block->iter_vars.size(), values.size())
       << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values";
-  CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression";
+  CHECK(predicate.dtype().is_predicate_dtype())
+      << "TypeError: Expect Block.predicate to be a bool expression";
   ObjectPtr<BlockRealizeNode> node = ffi::make_object<BlockRealizeNode>();
   node->iter_values = std::move(values);
   node->predicate = std::move(predicate);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 935f992..51c0b64 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -214,6 +214,12 @@
   } else if (ltype.is_float4() && !rtype.is_float4()) {
     // Cast int->float4 for rhs when lhs is a float4
     rhs = cast(ltype, rhs);
+  } else if (ltype.is_bool() && (rtype.is_int() || rtype.is_uint())) {
+    // Cast the bool lhs to the integer type of rhs (int or uint)
+    lhs = cast(rtype, lhs);
+  } else if ((ltype.is_int() || ltype.is_uint()) && rtype.is_bool()) {
+    // Cast the bool rhs to the integer type of lhs (int or uint)
+    rhs = cast(ltype, rhs);
   } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
     // Promote int to higher bits e.g. int8 + int16 --> int16 + int16
     if (ltype.bits() < rtype.bits()) {
@@ -621,7 +627,7 @@
 
 // if_then_else
 PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
-  ICHECK(cond.dtype() == DataType::Bool(1))
+  ICHECK(cond.dtype() == DataType::Bool())
       << "if_then_else only accept the condition to be boolean type.";
   BinaryOpMatchTypes(true_value, false_value, span);
   if (const IntImmNode* op = cond.as<IntImmNode>()) {
@@ -698,10 +704,10 @@
                                 << rhs << " of type " << rhs.dtype();
 }
 
-void type_check_integer_args(const PrimExpr& arg, const char* op) {
-  ICHECK(arg.dtype().is_int() || arg.dtype().is_uint())
-      << "Expected integer argument for " << op << ", but received " << arg << " of type "
-      << arg.dtype();
+void type_check_int_or_bool_args(const PrimExpr& arg, const char* op) {
+  ICHECK(arg.dtype().is_int() || arg.dtype().is_uint() || arg.dtype().is_bool())
+      << "Expected integer or boolean argument for " << op << ", but received " << arg
+      << " of type " << arg.dtype();
 }
 
 void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
@@ -712,6 +718,15 @@
       << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type "
       << rhs.dtype();
 }
+
+void type_check_int_or_bool_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
+  ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint() || lhs.dtype().is_bool())
+      << "Expected integer or boolean argument as LHS of " << op << ", but received " << lhs
+      << " of type " << lhs.dtype();
+  ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint() || rhs.dtype().is_bool())
+      << "Expected integer or boolean argument as RHS of " << op << ", but received " << rhs
+      << " of type " << rhs.dtype();
+}
 }  // namespace
 
 PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
@@ -781,7 +796,7 @@
 // bitwise and
 PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); }
 PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
-  type_check_integer_args(a, b, "& operator (bitwise AND)");
+  type_check_int_or_bool_args(a, b, "& operator (bitwise AND)");
   BinaryOpMatchTypes(a, b, span);
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
@@ -793,7 +808,7 @@
 // bitwise_or
 PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); }
 PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
-  type_check_integer_args(a, b, "| operator (bitwise OR)");
+  type_check_int_or_bool_args(a, b, "| operator (bitwise OR)");
   BinaryOpMatchTypes(a, b, span);
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
@@ -805,7 +820,7 @@
 // bitwise_xor
 PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); }
 PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
-  type_check_integer_args(a, b, "^ operator (bitwise XOR)");
+  type_check_int_or_bool_args(a, b, "^ operator (bitwise XOR)");
   BinaryOpMatchTypes(a, b, span);
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
@@ -818,7 +833,7 @@
 PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); }
 
 PrimExpr bitwise_neg(PrimExpr a, Span span) {
-  type_check_integer_args(a, "~ operator (bitwise NOT)");
+  type_check_int_or_bool_args(a, "~ operator (bitwise NOT)");
   return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span);
 }
 
@@ -935,7 +950,7 @@
   PrimExpr result = tir::Add(x, y, span);
   PrimExpr identity_element = make_zero(source.dtype(), span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
 }
 
 PrimExpr all(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> init, Span span) {
@@ -944,7 +959,7 @@
   PrimExpr result = tir::And(x, y, span);
   PrimExpr identity_element = make_const(source.dtype(), true, span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
 }
 
 PrimExpr any(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> init, Span span) {
@@ -953,7 +968,7 @@
   PrimExpr result = tir::Or(x, y, span);
   PrimExpr identity_element = make_const(source.dtype(), false, span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
 }
 
 PrimExpr max(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> init, Span span) {
@@ -961,7 +976,7 @@
   PrimExpr result = tir::Max(x, y, span);
   PrimExpr identity_element = min_value(source.dtype(), span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
 }
 
 PrimExpr min(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> init, Span span) {
@@ -969,7 +984,7 @@
   PrimExpr result = tir::Min(x, y, span);
   PrimExpr identity_element = max_value(source.dtype(), span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
 }
 
 PrimExpr prod(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> init, Span span) {
@@ -977,7 +992,7 @@
   PrimExpr result = tir::Mul(x, y, span);
   PrimExpr identity_element = make_const(source.dtype(), 1, span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span);
 }
 
 // fmod
@@ -992,7 +1007,7 @@
 
 // floor
 PrimExpr floor(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1006,7 +1021,7 @@
 
 // ceil
 PrimExpr ceil(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1020,7 +1035,7 @@
 
 // round
 PrimExpr round(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1034,7 +1049,7 @@
 
 // nearbyint
 PrimExpr nearbyint(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1048,7 +1063,7 @@
 
 // trunc
 PrimExpr trunc(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 8a5d39e..1b85d7d 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -218,7 +218,7 @@
   init_nest_.emplace_back(LetStmt(
       buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
   init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
-  PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});
+  PrimExpr v_strides_is_null = Call(DataType::Bool(), builtin::isnullptr(), {buf_strides->data});
   if (buffer->strides.size() == 0) {
     // Assert the buffer is compact
     DataType stype = buffer->DefaultIndexType();
diff --git a/src/tir/transforms/inject_ptx_ldg32.cc b/src/tir/transforms/inject_ptx_ldg32.cc
index 1b4bd7b..8cdef1b 100644
--- a/src/tir/transforms/inject_ptx_ldg32.cc
+++ b/src/tir/transforms/inject_ptx_ldg32.cc
@@ -41,7 +41,7 @@
       // addr[0] -> global_addr /  addr[1] -> local_addr
       addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, DataType::Int(32), "addr", "local");
       predicate_buffer =
-          decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(1), "predicate", "local");
+          decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), "predicate", "local");
     }
     Stmt result = StmtMutator::VisitStmt_(allocate);
     if (!has_buffer_2) {
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index f6df6c8..66e1379 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -256,7 +256,7 @@
     Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));
 
     Stmt alloc_nullptr_check = IfThenElse(
-        Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), throw_last_error);
+        Call(DataType::Bool(), builtin::isnullptr(), {op->buffer_var}), throw_last_error);
     PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
                             {cast(DataType::Int(32), device_type_.value()),
                              cast(DataType::Int(32), device_id_.value()), op->buffer_var});
@@ -617,7 +617,7 @@
     Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
 
     Stmt body = SeqStmt(
-        {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error),
+        {IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), throw_last_error),
          let->body, free_stmt});
 
     DataType dtype =
diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc
index 6c42972..6ae6deb 100644
--- a/tests/cpp/tir_scalable_datatype.cc
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -167,8 +167,8 @@
 
 TEST(ScalableDataType, TestScalableBool) {
   tvm::DataType scalable_type = tvm::DataType::Bool(4, true);
-  ASSERT_EQ(scalable_type.code(), kDLUInt);
-  ASSERT_EQ(scalable_type.bits(), 1);
+  ASSERT_EQ(scalable_type.code(), kDLBool);
+  ASSERT_EQ(scalable_type.bits(), 8);
   ASSERT_EQ(scalable_type.vscale_factor(), 4);
   ASSERT_TRUE(scalable_type.is_scalable_vector());
 }
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py
index 6954cf4..5eaaac6 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -93,7 +93,7 @@
     x, y, z = te.var("x"), te.var("y"), te.var("z")
     x64 = te.var("x", dtype="int64")
     vx = te.var("vx", dtype="int32x2")
-    vc = te.var("vc", dtype="uint1")
+    vc = te.var("vc", dtype="bool")
     test_case = tvm.testing.parameter(
         # Add rules
         TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x + y, 3, 4)),
@@ -285,22 +285,22 @@
             tvm.te.max(vx, tvm.te.max(y, x).astype("int32x2")),
         ),
         ## Logical rules
-        TestCase(y.astype("int32x2").equal(x.astype("int32x2")), (y.equal(x)).astype("uint1x2")),
+        TestCase(y.astype("int32x2").equal(x.astype("int32x2")), (y.equal(x)).astype("boolx2")),
         TestCase(
             tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))),
-            (tvm.tir.NE(y, x)).astype("uint1x2"),
+            (tvm.tir.NE(y, x)).astype("boolx2"),
         ),
-        TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < y).astype("uint1x2")),
-        TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= y).astype("uint1x2")),
-        TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < x).astype("uint1x2")),
-        TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= x).astype("uint1x2")),
+        TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < y).astype("boolx2")),
+        TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= y).astype("boolx2")),
+        TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < x).astype("boolx2")),
+        TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= x).astype("boolx2")),
         TestCase(
-            tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
-            (tvm.tir.And(y <= x, vc)).astype("uint1x2"),
+            tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("boolx2")),
+            (tvm.tir.And(y <= x, vc)).astype("boolx2"),
         ),
         TestCase(
-            tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
-            (tvm.tir.Or(y <= x, vc)).astype("uint1x2"),
+            tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("boolx2")),
+            (tvm.tir.Or(y <= x, vc)).astype("boolx2"),
         ),
     )
 
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index a0ff507..b076827 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -1721,7 +1721,6 @@
     w = relax.Var("w", R.Tensor((5,), "float32"))
     targets0 = relax.Var("targets", R.Tensor((3, 10, 10), "float32"))
     targets1 = relax.Var("targets", R.Tensor((3, 10, 10), "float64"))
-    targets2 = relax.Var("targets", R.Tensor((3, 10, 10), "bool"))
     targets3 = relax.Var("targets", R.Tensor((3, 10, 10), "int32"))
     targets4 = relax.Var("targets", R.Tensor((3, 10, 10), "int64"))
     targets5 = relax.Var("targets", R.Tensor((3, 10, 10), "uint32"))
@@ -1733,7 +1732,6 @@
         bb.normalize(relax.op.nn.nll_loss(x, targets1, w))
 
     # correct cases
-    bb.normalize(relax.op.nn.nll_loss(x, targets2, w))  # bool is uint1
     bb.normalize(relax.op.nn.nll_loss(x, targets3, w))
     bb.normalize(relax.op.nn.nll_loss(x, targets4, w))
     bb.normalize(relax.op.nn.nll_loss(x, targets5, w))
diff --git a/tests/python/tir-base/test_tir_constructor.py b/tests/python/tir-base/test_tir_constructor.py
index 42c2998..4076070 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -140,7 +140,7 @@
     assert isinstance(x, tvm.tir.AttrStmt)
     assert x.value.value == 1
 
-    x = tvm.tir.AssertStmt(tvm.tir.const(1, "uint1"), tvm.runtime.convert("hellow"), nop)
+    x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"), tvm.runtime.convert("hellow"), nop)
     assert isinstance(x, tvm.tir.AssertStmt)
     assert x.body == nop
 
@@ -150,8 +150,8 @@
     assert x.extent.value == 10
     assert x.body == nop
 
-    buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("uint1")))
-    buffer = tvm.tir.decl_buffer([16], "uint1", data=buffer_var)
+    buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("bool")))
+    buffer = tvm.tir.decl_buffer([16], "bool", data=buffer_var)
     x = tvm.tir.BufferStore(buffer, tvm.tir.IntImm("bool", 1), [10])
     assert isinstance(x, tvm.tir.BufferStore)
     assert x.buffer == buffer
@@ -160,7 +160,7 @@
     assert x.value.value == 1
 
     buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))
-    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
+    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "bool"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
@@ -168,7 +168,7 @@
 
     storage_scope = "global.texture"
     buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope))
-    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
+    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "bool"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
@@ -181,7 +181,7 @@
     assert x.attr_key == "xyz"
     assert x.body == nop
 
-    x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11), nop)
+    x = tvm.tir.IfThenElse(tvm.tir.const(1, "bool"), tvm.tir.Evaluate(11), nop)
     assert isinstance(x, tvm.tir.IfThenElse)
     assert x.then_case.value.value == 11
     assert x.else_case == nop
diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py
index 5e1d25e..bc7cfea 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -302,7 +302,7 @@
     z = te.var("z", "int32")
     assert str(tvm.tir.isnan(z)) == "T.bool(False)"
     k = te.var("k", "int8x2")
-    assert str(tvm.tir.isnan(k).dtype) == "uint1x2"
+    assert str(tvm.tir.isnan(k).dtype) == "boolx2"
 
 
 def test_equality():
diff --git a/tests/python/tir-base/test_tir_ops.py b/tests/python/tir-base/test_tir_ops.py
index dfa5cba..cb7d8c5 100644
--- a/tests/python/tir-base/test_tir_ops.py
+++ b/tests/python/tir-base/test_tir_ops.py
@@ -69,8 +69,8 @@
     x = te.var("x")
     for val in [0, 1]:
         for func in [tvm.tir.all, tvm.tir.any]:
-            check_throws(lambda: func(tvm.tir.const(val, "uint1"), x))
-            check_throws(lambda: func(x, tvm.tir.const(val, "uint1")))
+            check_throws(lambda: func(tvm.tir.const(val, "bool"), x))
+            check_throws(lambda: func(x, tvm.tir.const(val, "bool")))
 
     # Test const folding when both arguments are const
     for tvm_func, py_func in [
@@ -80,13 +80,13 @@
         for v1 in [0, 1]:
             for v2 in [0, 1]:
                 tvm.ir.assert_structural_equal(
-                    tvm_func(tvm.tir.const(v1, "uint1"), tvm.tir.const(v2, "uint1")),
-                    tvm.tir.const(py_func(v1, v2), "uint1"),
+                    tvm_func(tvm.tir.const(v1, "bool"), tvm.tir.const(v2, "bool")),
+                    tvm.tir.const(py_func(v1, v2), "bool"),
                 )
 
-    x = te.var("x", "uint1")
-    true = tvm.tir.const(1, "uint1")
-    false = tvm.tir.const(0, "uint1")
+    x = te.var("x", "bool")
+    true = tvm.tir.const(1, "bool")
+    false = tvm.tir.const(0, "bool")
 
     assert tvm.tir.all(x, true).same_as(x)
     assert tvm.tir.all(true, x).same_as(x)
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index db6f4ba..8352b11 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -366,7 +366,7 @@
     # the expected allocate
     buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "local"))
     ir_expected = tir.Allocate(
-        buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1)
+        buffer_var, "float32", [10], tvm.tir.const(1, "bool"), tir.Evaluate(1)
     )
 
     # Check if the generated ir is expected
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index fc7deac..e4af158 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -961,13 +961,13 @@
     buffer_load = tir.BufferLoad(
         buffer=buffer_map[b],
         indices=[0, tir.Ramp(0, 4, 4)],
-        predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+        predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
     )
     body = tir.BufferStore(
         buffer=buffer_map[a],
         value=buffer_load,
         indices=[0, tir.Ramp(0, 2, 4)],
-        predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+        predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
     )
     func = tir.PrimFunc(
         params=[a, b],