As part of the work to improve support for vector length agnostic (VLA) programming in TVM, we have been prototyping scalable vectors in TIR. The codegen discussion in this RFC is specific to the Scalable Vector Extension (SVE) in the Arm(R) architecture, but the TIR enhancements will be architecture agnostic. To read more about SVE, see the SVE overview.
This RFC corresponds to the second part of the meta-RFC (with some changes to the initial plan).
Due to the significant improvements in LLVM over the past few years, the full scale of changes described in the first RFC is no longer necessary. Since LLVM can lower fixed-bound loops to SVE, there wasn't a need to create explicit SVE loops at TVM's codegen level.
Expressing scalability at the TIR level would still be valuable, though, as it gives the schedule author more control over using scalable vectors and makes it possible to explicitly use features present in SVE, such as gather load and scatter store operations. Moreover, it would open up the path to supporting the Scalable Matrix Extension (SME) in TVM.
Finally, some of the proposed changes could also benefit vector-length-specific vectorization, especially the changes around dealing with loops that are not entirely vectorizable (this is discussed in Predication).
In SVE, the size of scalable vectors is expressed through a compile-time-unknown value called `vscale` (a name that has also been adopted by LLVM). It essentially represents how many 128-bit chunks fit into a given SVE hardware implementation, so e.g. `vscale = 1` implies 128-bit vectors and `vscale = 2` implies 256-bit vectors. The same SVE code will run on both of these implementations.
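To make the arithmetic concrete, here is a small illustration (plain Python, purely for intuition) of how `vscale` relates to the hardware vector width and the number of float32 lanes:

```python
# Illustrative only: three hypothetical SVE implementations.
for vscale in (1, 2, 4):
    bits = 128 * vscale          # hardware vector width in bits
    lanes = bits // 32           # float32 is 32 bits wide
    print(f"vscale={vscale}: {bits}-bit vectors, {lanes} float32 lanes")
```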
We can introduce a TIR intrinsic for `vscale` that we can use in `Ramp` and `Broadcast` nodes to represent scalable vectors in TIR. In the backend codegen, `tir.vscale` can be directly mapped to `llvm.vscale`.
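As a sketch of what this would enable (note that `tir.vscale()` and `PrimExpr` lane counts are proposals of this RFC, not existing TVM APIs), scalable `Ramp` and `Broadcast` nodes could be constructed directly:

```python
from tvm import tir

vscale = tir.vscale()             # proposed intrinsic, maps to llvm.vscale
lanes = 4 * vscale                # scalable lane count, a PrimExpr
# [0, 1, ..., 4 * vscale - 1]
ramp = tir.Ramp(tir.const(0, "int32"), tir.const(1, "int32"), lanes)
# 2.0 repeated 4 * vscale times
bcast = tir.Broadcast(tir.const(2.0, "float32"), lanes)
```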
Here's how we could create scalable TIR by using the scheduling primitives:
```python
@T.prim_func
def before_split(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
    for i in range(128):
        with T.block("B"):
            vi = T.axis.spatial(128, i)
            T.reads(A[vi])
            T.writes(B[vi])
            B[vi] = A[vi] * 2.0

sch = tir.Schedule(before_split)
i, = sch.get_loops(sch.get_block("B"))
vscale = tvm.tir.vscale()
sch.split(i, factors=[tvm.tir.ceildiv(128, 4 * vscale), 4 * vscale])
```
This would result in:
```python
@T.prim_func
def after_split(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
    for i0, i1 in T.grid(tvm.tir.ceildiv(128, 4 * vscale), 4 * vscale):
        with T.block("B"):
            vi = T.axis.spatial(128, i0 * 4 * vscale + i1)
            T.reads(A[vi])
            T.writes(B[vi])
            B[vi] = A[vi] * 2.0
```
The multiplier 4 in front of `vscale` comes from `128 / size_in_bits(float32)`. It encodes the minimum vector length of the architecture, which is known at compile time. After vectorizing the inner loop and lowering the TIR, we get:
```python
@T.prim_func
def lowered(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
    for i_0 in range(tvm.tir.ceildiv(128, 4 * vscale)):
        B_1 = T.Buffer((128,), data=B.data)
        A_1 = T.Buffer((128,), data=A.data)
        B_1[i_0 * 4 * vscale:i_0 * 4 * vscale + 4 * vscale] = A_1[i_0 * 4 * vscale:i_0 * 4 * vscale + 4 * vscale] * T.Broadcast(T.float32(2), 4 * vscale)
```
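To make the semantics of the lowered loop concrete, here is a NumPy emulation for one possible hardware width, `vscale = 2` (an illustration only, not generated code):

```python
import numpy as np

vscale = 2                                   # fixed by the hardware at run time
step = 4 * vscale                            # lanes processed per iteration
A = np.arange(128, dtype="float32")
B = np.empty_like(A)
for i_0 in range(-(-128 // step)):           # ceildiv(128, 4 * vscale)
    B[i_0 * step : i_0 * step + step] = A[i_0 * step : i_0 * step + step] * np.float32(2)
assert np.allclose(B, A * 2.0)
```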
For simplicity, this example corresponds to a case where we know that the axis dimension is divisible by the real hardware vector length. In general, we can't make this assumption - see Predication for more discussion.
Currently the LLVM backend maps ramps and broadcasts to LLVM vectors. Since LLVM represents scalable vectors the same way as fixed-length vectors, we can create scalable vector operations in LLVM with minimal changes to the codegen. Here's a (shortened) example of LLVM IR with scalable vectors:
```llvm
entry:
  tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 4 dereferenceable(4000) %agg.result, i8 0, i64 4000, i1 false)
  %0 = tail call i64 @llvm.vscale.i64()
  %.neg = mul nuw nsw i64 %0, 1016
  %n.vec = and i64 %.neg, 1000
  %1 = tail call i64 @llvm.vscale.i64()
  %2 = shl nuw nsw i64 %1, 2
  %3 = tail call i64 @llvm.vscale.i64()
  %4 = shl nuw nsw i64 %3, 2
  %5 = tail call i64 @llvm.vscale.i64()
  %6 = shl nuw nsw i64 %5, 2
  %7 = tail call i64 @llvm.vscale.i64()
  %8 = shl nuw nsw i64 %7, 3
  br label %vector.body

vector.body:
  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
  %9 = getelementptr inbounds i32, ptr %arr0, i64 %index
  %wide.load = load <vscale x 4 x i32>, ptr %9, align 4
  %10 = getelementptr inbounds i32, ptr %9, i64 %2
  %wide.load9 = load <vscale x 4 x i32>, ptr %10, align 4
  %11 = getelementptr inbounds i32, ptr %arr1, i64 %index
  %wide.load10 = load <vscale x 4 x i32>, ptr %11, align 4
  %12 = getelementptr inbounds i32, ptr %11, i64 %4
  %wide.load11 = load <vscale x 4 x i32>, ptr %12, align 4
  %13 = add nsw <vscale x 4 x i32> %wide.load10, %wide.load
  %14 = add nsw <vscale x 4 x i32> %wide.load11, %wide.load9
  %15 = getelementptr inbounds [1000 x i32], ptr %agg.result, i64 0, i64 %index
  store <vscale x 4 x i32> %13, ptr %15, align 4
  %16 = getelementptr inbounds i32, ptr %15, i64 %6
  store <vscale x 4 x i32> %14, ptr %16, align 4
  %index.next = add nuw i64 %index, %8
  %17 = icmp eq i64 %index.next, %n.vec
  br i1 %17, label %middle.block, label %vector.body
```
Making ramps and broadcasts scalable would mean changing the data types of `ramp->lanes` and `broadcast->lanes` from `int` to `PrimExpr` and correctly dealing with the accesses to the lanes throughout the codebase. E.g. for `Ramp`, the current node definition
```c++
class RampNode : public PrimExprNode {
 public:
  /*! \brief The base value. */
  PrimExpr base;
  /*! \brief The stride of each step. */
  PrimExpr stride;
  /*! \brief Total number of lanes. */
  int lanes;
  ...
```
would change to
```c++
class RampNode : public PrimExprNode {
 public:
  /*! \brief The base value. */
  PrimExpr base;
  /*! \brief The stride of each step. */
  PrimExpr stride;
  /*! \brief Total number of lanes. */
  PrimExpr lanes;
  ...
```
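As a sketch of the kind of change this implies for the rest of the codebase (the helper below is hypothetical and assumes the proposed `PrimExpr` lanes field): any code that previously read `ramp->lanes` as a plain `int` now has to branch on whether the lanes expression is a compile-time constant:

```python
from tvm import tir

def fixed_lane_count(ramp):
    """Hypothetical helper, assuming the proposed PrimExpr `lanes` field:
    return the lane count as an int if fixed, or None if scalable."""
    if isinstance(ramp.lanes, tir.IntImm):
        return int(ramp.lanes.value)   # fixed-length vector, e.g. 4
    return None                        # scalable, e.g. the expression 4 * vscale
```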
There are some advantages to representing scalable vectors through modified `Ramp` and `Broadcast` nodes:

- `IRBuilder` has the same APIs for scalable and fixed-length vectors, so basic support would require minimal changes in the LLVM backend

The main areas that need changes are:
- `VectorizeLoop` pass - We need to handle creating scalable vectors from loops and correctly vectorizing the associated broadcasts.
- `tir.split` and `tir.tile` (and the TE equivalents) - These primitives would be the main interface through which we create scalable vectors, so they should support compile-time-unknown factors.
- `runtime::DataType` - We need to express scalable lanes in `runtime::DataType` in a way that doesn't break the DLPack standard ABI. It was suggested by @tqchen to define `int kScalableVectorLaneMark = -1` to represent scalable lanes in `DataType`. The lanes of scalable vectors are usually expressed as a multiple of `vscale`. In order to make the lanes value (including the scalar multiplier) recoverable from `DataType`, we will use the corresponding negative value, e.g. `T.ramp(0, 1, 4 * vscale())` would have its `runtime::DataType` lanes set to -4 (see the encoding sketch after this list).
- Codegen - Most of the codegen changes would entail extending the current vector support to scalable vectors. Besides handling the `tir.vscale` -> `llvm.vscale` conversion, an important enhancement would be support for strided ramps and broadcasts. Unlike the Arm(R) Neon(TM) instruction set, SVE supports gather load and scatter store operations, which in TIR could be represented either by a `BufferLoad`/`BufferStore` indexed by a `Ramp` with `stride != 1`, e.g. `A[T.ramp(0, 2, 4 * tir.vscale())]`, or by an index that is a non-linear expression involving a `Ramp`, e.g. `A[(10 * T.ramp(0, 1, 8 * tir.vscale())) % 4]`. The current CPU codegen supports strided ramps by scalarising, while with SVE these could be turned into `llvm.masked.gather` and `llvm.masked.scatter` intrinsics (an illustration follows this list).
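To illustrate the proposed `runtime::DataType` lane encoding described above (a minimal sketch of the convention; the helper names are hypothetical): scalable lane counts of the form `k * vscale` are stored as `-k`, so both the scalability flag and the multiplier are recoverable:

```python
def encode_lanes(multiplier: int, scalable: bool) -> int:
    # Proposed convention: k * vscale lanes are stored as -k.
    return -multiplier if scalable else multiplier

def decode_lanes(lanes: int):
    if lanes < 0:
        return ("scalable", -lanes)   # lanes == multiplier * vscale
    return ("fixed", lanes)

assert decode_lanes(encode_lanes(4, scalable=True)) == ("scalable", 4)   # T.ramp(0, 1, 4 * vscale())
assert decode_lanes(encode_lanes(8, scalable=False)) == ("fixed", 8)
```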
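To make the strided case concrete, here is a NumPy emulation for one concrete hardware width (an illustration of the semantics, not the actual codegen):

```python
import numpy as np

vscale = 2                               # example hardware: 256-bit vectors
lanes = 4 * vscale
A = np.arange(32, dtype="float32")
indices = 0 + 2 * np.arange(lanes)       # the indices of A[T.ramp(0, 2, 4 * vscale)]
gathered = A[indices]                    # a single llvm.masked.gather on SVE,
print(gathered)                          # vs. 8 scalar loads when scalarising
# [ 0.  2.  4.  6.  8. 10. 12. 14.]
```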
Predication is used in vectorization when the axis dimension is not divisible by the vector length.
Let's start with a simple loop, then split and vectorize it:
```python
@T.prim_func
def main(A: T.Buffer((60,), "float32"), B: T.Buffer((60,), "float32")):
    for i in range(60):
        with T.block("B"):
            vi = T.axis.spatial(60, i)
            T.reads(A[vi])
            T.writes(B[vi])
            B[vi] = A[vi]

sch = tvm.tir.Schedule(main)
l, = sch.get_loops(sch.get_block("B"))
vscale = tvm.tir.vscale()
sch.split(l, factors=[tvm.tir.ceildiv(60, 4 * vscale), 4 * vscale])
l0, l1 = sch.get_loops(sch.get_block("B"))
sch.vectorize(l1)
```
This will give us:
```python
@T.prim_func
def main(A: T.Buffer((60,), "float32"), B: T.Buffer((60,), "float32")):
    for i_0 in range(tvm.tir.ceildiv(60, 4 * vscale)):
        for i_1 in T.vectorized(4 * vscale):
            with T.block("B"):
                vi = T.axis.spatial(60, i_0 * 4 * vscale + i_1)
                T.where(i_0 * 4 * vscale + i_1 < 60)
                T.reads(A[vi])
                T.writes(B[vi])
                B[vi] = A[vi]
```
Actually vectorizing the TIR (i.e. creating the `Ramp` nodes for the buffer operation indices) will be left to the `VectorizeLoop` pass in the TIR lowering pipeline. Currently this pass keeps a loop as a scalar loop if it can't be exactly vectorized. In the case of SVE we can use predication instead: since we don't know the vector length at compile time, `VectorizeLoop` will always emit predicated loads and stores for SVE.
This would entail introducing another TIR intrinsic, `tir.get_active_lane_mask()`, which is analogous to `llvm.get.active.lane.mask.*`, i.e. it would take a variable (the loop induction variable) and a bound and produce a bit mask corresponding to the active lanes. We'd also need to implement support for predication in `BufferLoad` and `BufferStore` operations. In TVMScript we can support predicates by adding `Buffer::store` and `Buffer::load` methods which accept a predicate and create `BufferStore` and `BufferLoad` objects. Expressing predicated operations in TVMScript would then look like:
```python
@T.prim_func
def main(A: T.Buffer((60,), "float32"), B: T.Buffer((60,), "float32")):
    for i0 in range(ceildiv(60, 4 * vscale)):
        let pred = T.get_active_lane_mask(i0 * 4 * vscale, 60)
        B.store(
            A.load([T.ramp(i0 * 4 * vscale, 1, 4 * vscale)], predicate=pred),
            [T.ramp(i0 * 4 * vscale, 1, 4 * vscale)],
            predicate=pred)
```
These loads and stores will then be lowered to `llvm.masked.*` intrinsics.
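The semantics of the proposed `tir.get_active_lane_mask` intrinsic can be mirrored in plain NumPy (an illustration only): lane `i` of the predicate is active iff `base + i < limit`:

```python
import numpy as np

def get_active_lane_mask(base, limit, lanes):
    # Lane i of the predicate is True iff base + i < limit.
    return base + np.arange(lanes) < limit

# Final iteration of the 60-element example on hardware with vscale = 2
# (8 lanes): only the first four lanes are active.
print(get_active_lane_mask(56, 60, 8))
# [ True  True  True  True False False False False]
```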
Predication is not exclusive to VLA programming, so the implementation could also benefit fixed-length vectors.
When using the cleanup loop, the TIR example above would instead lower to:
```python
@T.prim_func
def main(A: T.Buffer((60,), "float32"), B: T.Buffer((60,), "float32")):
    for i0 in range(floordiv(60, 4 * vscale)):
        B[T.ramp(i0 * 4 * vscale, 1, 4 * vscale)] = A[T.ramp(i0 * 4 * vscale, 1, 4 * vscale)]
    for j0 in range(60 % (4 * vscale)):
        B[4 * vscale * floordiv(60, 4 * vscale) + j0] = A[4 * vscale * floordiv(60, 4 * vscale) + j0]
```
These loads and stores can be lowered to "regular" LLVM vectorized load and store instructions with a scalable vector type.
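As a concrete check of the trip counts (a NumPy emulation with `vscale` fixed to 2, purely illustrative): the main loop handles `floordiv(60, 8) = 7` full vectors and the scalar tail copies the remaining `60 % 8 = 4` elements:

```python
import numpy as np

vscale = 2
step = 4 * vscale                            # 8 elements per vector iteration
A = np.arange(60, dtype="float32")
B = np.empty_like(A)
for i0 in range(60 // step):                 # floordiv(60, 4 * vscale) = 7
    B[i0 * step : (i0 + 1) * step] = A[i0 * step : (i0 + 1) * step]
tail_start = (60 // step) * step
for j0 in range(60 % step):                  # 4 remaining elements
    B[tail_start + j0] = A[tail_start + j0]
assert np.array_equal(A, B)
```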
There is ongoing work to improve VLA support in the MLIR stack - see this RFC for a nice summary of where it is at and what the plan is for the future. The changes are conceptually similar to the ideas outlined in this PR, making changes to the `linalg` and `transform` dialects.
Halide supports 128-bit SVE vectors - see the PR.
SME is an architecture extension targeting matrix-multiply-style operations by computing the outer product of two scalable vectors. See the introduction to SME for more information.
Unlike for SVE, LLVM will not support generating SME code from fixed-bound loops, so the plan is to target the SME LLVM intrinsics from `tensorize` and define the corresponding loop nests through scalable loop bounds.
In the future, we could include scalable vector support in the meta schedule, making it possible for the tuner to choose between scalable and fixed-length vectors or mix them in the same program.
New APIs similar to `sample_perfect_tile` could be added that would sample `vscale`-dependent factors, e.g.
```python
# sampling scalable tiles only
sch.sample_scalable_tile(l0, n=3, max_innermost_vscale_multiplier=2)

# sampling scalable and fixed length tiles
sch.sample_mixed_tile(l0, n=6, max_innermost_fixed_size_factor=8, max_innermost_vscale_multiplier=2)
```
Splitting a loop and vectorizing the innermost dimension is a very common pattern in scheduling for scalable vectors. In the presence of sampling primitives that support `vscale`-dependent splitting, we can create additional built-in composite schedule rules similar to the `ScheduleRuleMultiLevelTiling` family. These could create a wide variety of programs with scalable vectors for the tuner to trial.