[SVE] Add get_active_lane_mask builtin (#16965)

Adds a `get_active_lane_mask` builtin and its lowering to the
`llvm.get.active.lane.mask` intrinsic. This will be used in subsequent
patches for expressing predicated buffer loads/stores in TIR. Further
information can be found in the [RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0104-scalable-vectors-in-tir.md#predication).

Co-authored-by: Elen Kalda <elen.kalda@arm.com>
Co-authored-by: Neil Hickey <neil.hickey@arm.com>

Change-Id: Id9d65f9f11503ad35dd0b3db4bfc81249a76f701
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 10e5b46..5836eb8 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -915,6 +915,14 @@
  */
 TVM_DLL const Op& vscale();
 
+/*!
+ * \brief Calculate a predicate mask given an upper bound (limit) and a current value (base).
+ *
+ * Takes two arguments, (base, limit), and is lowered to the llvm.get.active.lane.mask
+ * intrinsic (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics).
+ */
+TVM_DLL const Op& get_active_lane_mask();
+
 /*! \brief The kind of structure field info used in intrinsic */
 enum TVMStructFieldKind : int {
   // array head address
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index c04ac78..5a0a564 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1903,6 +1903,7 @@
 vectorlow = _dtype_forward(_tir_op.vectorlow)
 vectorhigh = _dtype_forward(_tir_op.vectorhigh)
 vectorcombine = _dtype_forward(_tir_op.vectorcombine)
+get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask)
 
 
 broadcast = Broadcast
@@ -2219,4 +2220,5 @@
     "CommReducer",
     "Range",
     "vscale",
+    "get_active_lane_mask",
 ]
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 1723804..24ba4cc 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -88,7 +88,7 @@
 from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
 from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
 from .op import start_profile_intrinsic, end_profile_intrinsic
-from .op import vscale
+from .op import vscale, get_active_lane_mask
 from .generic import add, subtract, multiply
 
 from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 6b72e63..db52bec 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -3349,6 +3349,32 @@
     return call_intrin("int32", "tir.vscale")
 
 
+def get_active_lane_mask(dtype, base, limit):
+    """
+    Calculate a predicate mask given an upper bound (limit) and a current value (base).
+
+    It will be lowered to the llvm.get.active.lane.mask intrinsic.
+    (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics)
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result (e.g. "int1xvscalex4").
+
+    base : PrimExpr
+        An expression representing the base.
+
+    limit : PrimExpr
+        An expression representing the limit.
+
+    Returns
+    -------
+    call : PrimExpr
+        Call to the get_active_lane_mask intrinsic.
+    """
+    return call_intrin(dtype, "tir.get_active_lane_mask", base, limit)
+
+
 # pylint: disable=unnecessary-lambda
 sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
 min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min")  # type: ignore
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 95512a0..6566bb4 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -1478,6 +1478,11 @@
     llvm::Intrinsic::ID id = llvm::Intrinsic::vscale;
     llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {});
     return builder_->CreateCall(f);
+  } else if (op->op.same_as(builtin::get_active_lane_mask())) {
+    llvm::Intrinsic::ID id = llvm::Intrinsic::get_active_lane_mask;
+    llvm::Function* f = GetIntrinsicDecl(id, DTypeToLLVMType(op->dtype),
+                                         {builder_->getInt32Ty(), builder_->getInt32Ty()});
+    return builder_->CreateCall(f, {MakeValue(op->args[0]), MakeValue(op->args[1])});
 #endif
   } else {
     LOG(FATAL) << "unknown intrinsic " << op->op;
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index fbe31c8..cf82eb0 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -397,6 +397,13 @@
 
 TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr<TCallEffectKind>("TCallEffectKind",
                                                           Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_BUILTIN_FUNC(get_active_lane_mask)
+    .set_num_inputs(2)  // args: (base, limit)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         Integer(ScriptDtypePrintLocation::kFirst));
+
 }  // namespace builtin
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 8f22ba5..452638b 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -680,5 +680,25 @@
         check_correct_assembly(dtype=dtype)
 
 
+@pytest.mark.skipif(
+    llvm_version_major() < 11,
+    reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
+)
+def test_get_active_lane_mask():
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+
+    @T.prim_func
+    def before(a: T.handle):
+        A = T.match_buffer(a, (30,), "int1")
+        for i in range(T.ceildiv(30, T.vscale() * 4)):
+            A[i : i + T.vscale() * 4] = T.get_active_lane_mask("int1xvscalex4", i, 30)
+
+    with tvm.target.Target(target):
+        out = tvm.build(before)
+    # The builtin must have been lowered: the emitted LLVM IR should reference the intrinsic.
+    ll = out.get_source("ll")
+    assert "get.active.lane.mask" in ll
+
+
 if __name__ == "__main__":
     tvm.testing.main()