[SVE] Add get_active_lane_mask builtin (#16965)
Adds a `get_active_lane_mask` builtin and its lowering to the
`llvm.get.active.lane.mask` intrinsic. This will be used in subsequent
patches for expressing predicated buffer loads/stores in TIR. Further
information can be found in the [RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0104-scalable-vectors-in-tir.md#predication).
Co-authored-by: Elen Kalda <elen.kalda@arm.com>
Co-authored-by: Neil Hickey <neil.hickey@arm.com>
Change-Id: Id9d65f9f11503ad35dd0b3db4bfc81249a76f701
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 10e5b46..5836eb8 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -915,6 +915,14 @@
*/
TVM_DLL const Op& vscale();
+/*!
+ * \brief Calculate a predicate mask given an upper bound (limit) and a current value (base).
+ *
+ * It will be lowered to the llvm.get.active.lane.mask intrinsic.
+ * (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics)
+ */
+TVM_DLL const Op& get_active_lane_mask();
+
/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index c04ac78..5a0a564 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1903,6 +1903,7 @@
vectorlow = _dtype_forward(_tir_op.vectorlow)
vectorhigh = _dtype_forward(_tir_op.vectorhigh)
vectorcombine = _dtype_forward(_tir_op.vectorcombine)
+get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask)
broadcast = Broadcast
@@ -2219,4 +2220,5 @@
"CommReducer",
"Range",
"vscale",
+ "get_active_lane_mask",
]
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 1723804..24ba4cc 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -88,7 +88,7 @@
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .op import start_profile_intrinsic, end_profile_intrinsic
-from .op import vscale
+from .op import vscale, get_active_lane_mask
from .generic import add, subtract, multiply
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 6b72e63..db52bec 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -3349,6 +3349,27 @@
return call_intrin("int32", "tir.vscale")
+def get_active_lane_mask(dtype, base, limit):
+ """
+ Calculate a predicate mask given an upper bound (limit) and a current value (base).
+
+ It will be lowered to the llvm.get.active.lane.mask intrinsic.
+ (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics)
+
+ Parameters
+ ----------
+ dtype : str
+ The data type of the result.
+
+ base : PrimExpr
+ An expression representing the base.
+
+ limit : PrimExpr
+ An expression representing the limit.
+ """
+ return call_intrin(dtype, "tir.get_active_lane_mask", base, limit)
+
+
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 95512a0..6566bb4 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -1478,6 +1478,11 @@
llvm::Intrinsic::ID id = llvm::Intrinsic::vscale;
llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {});
return builder_->CreateCall(f);
+ } else if (op->op.same_as(builtin::get_active_lane_mask())) {
+ llvm::Intrinsic::ID id = llvm::Intrinsic::get_active_lane_mask;
+ llvm::Function* f = GetIntrinsicDecl(id, DTypeToLLVMType(op->dtype),
+ {builder_->getInt32Ty(), builder_->getInt32Ty()});
+ return builder_->CreateCall(f, {MakeValue(op->args[0]), MakeValue(op->args[1])});
#endif
} else {
LOG(FATAL) << "unknown intrinsic " << op->op;
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index fbe31c8..cf82eb0 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -397,6 +397,13 @@
TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_BUILTIN_FUNC(get_active_lane_mask)
+ .set_num_inputs(2)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+ .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+ Integer(ScriptDtypePrintLocation::kFirst));
+
} // namespace builtin
} // namespace tir
} // namespace tvm
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 8f22ba5..452638b 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -680,5 +680,25 @@
check_correct_assembly(dtype=dtype)
+@pytest.mark.skipif(
+ llvm_version_major() < 11,
+ reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
+)
+def test_get_active_lane_mask():
+ target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+
+ @T.prim_func
+ def before(a: T.handle):
+ A = T.match_buffer(a, (30,), "int1")
+ for i in range(T.ceildiv(30, T.vscale() * 4)):
+ A[i : i + T.vscale() * 4] = T.get_active_lane_mask("int1xvscalex4", i, 30)
+
+ with tvm.target.Target(target):
+ out = tvm.build(before)
+
+ ll = out.get_source("ll")
+ assert "get.active.lane.mask" in ll
+
+
if __name__ == "__main__":
tvm.testing.main()