# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# type: ignore
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T
import tvm
from tvm import relax
from tvm.ir import assert_structural_equal
import tvm.testing
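# Tests for relax.distributed.transform.PropagateSharding. Each test builds a module
# whose Relax function carries a few R.dist.annotate_sharding hints, runs the pass,
# and checks structural equality against an expected module in which every binding
# has been given a DTensor type (device mesh + placement) and the hints have been
# resolved away.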
def test_mlp():
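    # Two-layer MLP: a single S[1] annotation on the gelu output is enough for the
    # pass to infer weight1 as S[1], weight2 as S[0], and a replicated output.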
@I.ir_module
class MLP:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{
"mesh": [
R.device_mesh((2,), I.Range(0, 2)), # mesh[0]
R.device_mesh((1,), I.Range(4, 5)), # mesh[1]
]
}
)
@R.function
def foo(
x: R.Tensor((128, 128), "float32"),
weight1: R.Tensor((128, 128), "float32"),
weight2: R.Tensor((128, 128), "float32"),
) -> R.Tensor((128, 128), "float32"):
lv0 = R.matmul(x, weight1)
lv1 = R.nn.gelu(lv0)
lv2 = R.dist.annotate_sharding(lv1, device_mesh="mesh[0]", placement="S[1]")
lv3 = R.matmul(lv2, weight2)
return lv3
@I.ir_module
class ShardedMLP:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
)
@R.function
def foo(
x: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
weight1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
weight2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
lv0: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.matmul(
x, weight1, out_dtype="void"
)
lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.matmul(
lv1, weight2, out_dtype="void"
)
return lv3
after = relax.distributed.transform.PropagateSharding()(MLP)
assert_structural_equal(after, ShardedMLP)
def test_mlp_with_tuple():
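    # Same MLP, but the weights arrive packed in a tuple and an intermediate is
    # produced by a TIR split kernel via R.call_tir. The pass must shard the tuple
    # elements individually and rewrite R.call_tir into R.dist.call_tir with
    # DTensor out_sinfo.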
@I.ir_module
class MLPWithTuple:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{
"mesh": [
R.device_mesh((2,), I.Range(0, 2)), # mesh[0]
R.device_mesh((1,), I.Range(4, 5)), # mesh[1]
]
}
)
@T.prim_func(private=True)
def split1(var_A: T.handle, var_T_split: T.handle, var_T_split_1: T.handle):
T.func_attr({"tir.noalias": True})
A = T.match_buffer(var_A, (128, 128), "float32")
T_split = T.match_buffer(var_T_split, (64, 128), "float32")
T_split_1 = T.match_buffer(var_T_split_1, (64, 128), "float32")
# with T.block("root"):
for ax1, ax2 in T.grid(64, 128):
with T.block("T_split"):
v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2])
T_split[v_ax1, v_ax2] = A[v_ax1, v_ax2]
for ax1, ax2 in T.grid(64, 128):
with T.block("T_split_1"):
v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2])
T_split_1[v_ax1, v_ax2] = A[v_ax1 + 64, v_ax2]
@R.function
def foo(
x: R.Tensor((128, 128), "float32"),
weight_packed: R.Tuple(
R.Tensor((128, 128), "float32"), R.Tensor((128, 128), "float32")
),
) -> R.Tensor((64, 128), "float32"):
cls = MLPWithTuple
weight1 = weight_packed[0]
lv0 = R.matmul(x, weight1)
lv1 = R.nn.gelu(lv0)
lv_tuple = R.call_tir(
cls.split1,
[lv1],
out_sinfo=[R.Tensor((64, 128), "float32"), R.Tensor((64, 128), "float32")],
)
lv2 = lv_tuple[0]
lv3 = R.dist.annotate_sharding(lv2, device_mesh="mesh[0]", placement="S[1]")
weight2 = weight_packed[1]
lv4 = R.matmul(lv3, weight2)
return lv4
@I.ir_module
class ShardedMLPWithTuple:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
)
@T.prim_func(private=True)
def split1(
A: T.Buffer((128, 128), "float32"),
T_split: T.Buffer((64, 128), "float32"),
T_split_1: T.Buffer((64, 128), "float32"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax1, ax2 in T.grid(64, 128):
with T.block("T_split"):
v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2])
T.reads(A[v_ax1, v_ax2])
T.writes(T_split[v_ax1, v_ax2])
T_split[v_ax1, v_ax2] = A[v_ax1, v_ax2]
for ax1, ax2 in T.grid(64, 128):
with T.block("T_split_1"):
v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2])
T.reads(A[v_ax1 + 64, v_ax2])
T.writes(T_split_1[v_ax1, v_ax2])
T_split_1[v_ax1, v_ax2] = A[v_ax1 + 64, v_ax2]
@R.function
def foo(
x: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
weight_packed: R.Tuple(
R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
),
) -> R.DTensor((64, 128), "float32", "mesh[0]", "R"):
cls = ShardedMLPWithTuple
weight1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = weight_packed[0]
lv0: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.matmul(
x, weight1, out_dtype="void"
)
lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
gv = R.dist.call_tir(
cls.split1,
(lv1,),
out_sinfo=[
R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"),
R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"),
],
)
lv2: R.DTensor((64, 128), "float32", "mesh[0]", "S[1]") = gv[0]
weight2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]") = weight_packed[1]
lv4: R.DTensor((64, 128), "float32", "mesh[0]", "R") = R.matmul(
lv2, weight2, out_dtype="void"
)
return lv4
after = relax.distributed.transform.PropagateSharding()(MLPWithTuple)
assert_structural_equal(after, ShardedMLPWithTuple)
def test_mlp_const():
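    # An R.const scalar operand must be rewritten into R.dist.const with a
    # replicated zero-dim DTensor type once its consumer is sharded.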
@I.ir_module
class MLPWithConst:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{
"mesh": [
R.device_mesh((2,), I.Range(0, 2)), # mesh[0]
R.device_mesh((1,), I.Range(4, 5)), # mesh[1]
]
}
)
@R.function
def foo(
x: R.Tensor((128, 128), "float32"),
weight1: R.Tensor((128, 128), "float32"),
weight2: R.Tensor((128, 128), "float32"),
) -> R.Tensor((128, 128), "float32"):
lv0 = R.matmul(x, weight1)
lv1 = R.nn.gelu(lv0)
lv2 = R.add(lv1, R.const(2, "float32"))
lv3 = R.dist.annotate_sharding(lv2, device_mesh="mesh[0]", placement="S[1]")
lv4 = R.matmul(lv3, weight2)
return lv4
@I.ir_module
class ShardedMLPWithConst:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
)
@R.function
def foo(
x: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
weight1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
weight2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
lv0: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.matmul(
x, weight1, out_dtype="void"
)
lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.add(
lv1, R.dist.const(2, R.DTensor((), "float32", "mesh[0]", "R"))
)
lv4: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.matmul(
lv2, weight2, out_dtype="void"
)
return lv4
after = relax.distributed.transform.PropagateSharding()(MLPWithConst)
assert_structural_equal(after, ShardedMLPWithConst)
def test_mlp_dynamic_shape():
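    # Dynamic-shape variant of the MLP: the propagated DTensor types keep the
    # symbolic dims (m, k0, k1, n) instead of concrete shapes.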
@I.ir_module
class MLPDynamicShape:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{
"mesh": [
R.device_mesh((2,), I.Range(0, 2)), # mesh[0]
R.device_mesh((1,), I.Range(4, 5)), # mesh[1]
]
}
)
@R.function
def foo(
x: R.Tensor(("m", "k0"), "float32"),
weight1: R.Tensor(("k0", "k1"), "float32"),
weight2: R.Tensor(("k1", "n"), "float32"),
) -> R.Tensor(("m", "n"), "float32"):
lv0 = R.matmul(x, weight1)
lv1 = R.nn.gelu(lv0)
lv2 = R.dist.annotate_sharding(lv1, device_mesh="mesh[0]", placement="S[1]")
lv3 = R.matmul(lv2, weight2)
return lv3
@I.ir_module
class ShardedMLPDynamicShape:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
)
@R.function
def foo(
x: R.DTensor(("m", "k0"), "float32", "mesh[0]", "R"),
weight1: R.DTensor(("k0", "k1"), "float32", "mesh[0]", "S[1]"),
weight2: R.DTensor(("k1", "n"), "float32", "mesh[0]", "S[0]"),
) -> R.DTensor(("m", "n"), "float32", "mesh[0]", "R"):
m = T.int64()
n = T.int64()
k0 = T.int64()
k1 = T.int64()
lv0: R.DTensor((m, k1), "float32", "mesh[0]", "S[1]") = R.matmul(
x, weight1, out_dtype="void"
)
lv1: R.DTensor((m, k1), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
lv3: R.DTensor((m, n), "float32", "mesh[0]", "R") = R.matmul(
lv1, weight2, out_dtype="void"
)
return lv3
after = relax.distributed.transform.PropagateSharding()(MLPDynamicShape)
assert_structural_equal(after, ShardedMLPDynamicShape)
def test_mlp_pipeline_parallelism():
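    # Pipeline parallelism over two device meshes: re-annotating an intermediate
    # onto mesh[1] should be lowered to an explicit R.dist.redistribute between
    # the two stages.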
@I.ir_module
class PipelineMLP:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{
"mesh": [
R.device_mesh((2,), I.Range(0, 2)), # mesh[0]
R.device_mesh((2,), I.Range(4, 6)), # mesh[1]
]
}
)
@R.function
def foo(
x: R.Tensor((128, 128), "float32"),
weight1: R.Tensor((128, 128), "float32"),
weight2: R.Tensor((128, 128), "float32"),
weight3: R.Tensor((128, 128), "float32"),
weight4: R.Tensor((128, 128), "float32"),
) -> R.Tensor((128, 128), "float32"):
lv0 = R.matmul(x, weight1)
lv1 = R.nn.gelu(lv0)
lv2 = R.dist.annotate_sharding(lv1, device_mesh="mesh[0]", placement="S[1]")
lv3 = R.matmul(lv2, weight2)
lv4 = R.dist.annotate_sharding(lv3, device_mesh="mesh[1]", placement="R")
lv5 = R.matmul(lv4, weight3)
lv6 = R.nn.gelu(lv5)
lv7 = R.dist.annotate_sharding(lv6, device_mesh="mesh[1]", placement="S[1]")
lv8 = R.matmul(lv7, weight4)
return lv8
@I.ir_module
class ShardedPipelineMLP:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((2,), I.Range(4, 6))]}
)
@R.function
def foo(
x: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
weight1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
weight2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
weight3: R.DTensor((128, 128), "float32", "mesh[1]", "S[1]"),
weight4: R.DTensor((128, 128), "float32", "mesh[1]", "S[0]"),
) -> R.DTensor((128, 128), "float32", "mesh[1]", "R"):
lv0: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.matmul(
x, weight1, out_dtype="void"
)
lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.matmul(
lv1, weight2, out_dtype="void"
)
lv4: R.DTensor((128, 128), "float32", "mesh[1]", "R") = R.dist.redistribute(
lv3, device_mesh="mesh[1]", placement="R"
)
lv5: R.DTensor((128, 128), "float32", "mesh[1]", "S[1]") = R.matmul(
lv4, weight3, out_dtype="void"
)
lv6: R.DTensor((128, 128), "float32", "mesh[1]", "S[1]") = R.nn.gelu(lv5)
lv8: R.DTensor((128, 128), "float32", "mesh[1]", "R") = R.matmul(
lv6, weight4, out_dtype="void"
)
return lv8
after = relax.distributed.transform.PropagateSharding()(PipelineMLP)
assert_structural_equal(after, ShardedPipelineMLP)
def test_decoder_layer():
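    # A Llama attention layer written with Relax ops: S[1] annotations on the
    # transposed Q/K/V projection weights must propagate through reshape,
    # permute_dims, matmul and softmax, and the kv-cache packed calls must be
    # rewritten to their vm.builtin.distributed.* counterparts.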
@I.ir_module
class LlamaAttentionLayer:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{
"mesh": [
R.device_mesh((2,), I.Range(0, 2)), # mesh[0]
R.device_mesh((1,), I.Range(4, 5)), # mesh[1]
]
}
)
@T.prim_func
def rms_norm(
var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle
):
T.func_attr({"tir.noalias": True})
A = T.match_buffer(var_A, (T.int64(1), 256, T.int64(4096)), "float16")
rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), 256, T.int64(4096)), "float16")
# with T.block("root"):
Ared_temp = T.alloc_buffer((T.int64(1), 256))
for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)):
with T.block("Ared_temp"):
v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
T.reads(A[v_bsz, v_i, v_k])
T.writes(Ared_temp[v_bsz, v_i])
with T.init():
Ared_temp[v_bsz, v_i] = T.float32(0)
Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast(
"float32", A[v_bsz, v_i, v_k]
) * T.Cast("float32", A[v_bsz, v_i, v_k])
for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)):
with T.block("rms_norm"):
v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
T.writes(rms_norm_1[v_bsz, v_i, v_k])
rms_norm_1[v_bsz, v_i, v_k] = T.Cast(
"float16",
T.Cast("float32", B[v_k])
* (
T.Cast("float32", A[v_bsz, v_i, v_k])
/ T.sqrt(
Ared_temp[v_bsz, v_i] * T.float32(0.000244140625)
+ T.float32(9.9999999999999995e-07)
)
),
)
@T.prim_func
def rotary_embedding(
var_A: T.handle,
B: T.Buffer((T.int64(2048), T.int64(128)), "float16"),
C: T.Buffer((T.int64(2048), T.int64(128)), "float16"),
var_rotary: T.handle,
):
T.func_attr({"tir.noalias": True})
A = T.match_buffer(var_A, (T.int64(1), 256, T.int64(32), T.int64(128)), "float16")
rotary = T.match_buffer(
var_rotary, (T.int64(1), 256, T.int64(32), T.int64(128)), "float16"
)
# with T.block("root"):
for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)):
with T.block("rotary"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(
B[256 + v_i1 - 256, v_i3],
A[v_i0, v_i1, v_i2, v_i3 - T.int64(64) : v_i3 - T.int64(64) + T.int64(129)],
C[256 + v_i1 - 256, v_i3],
)
T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
rotary[v_i0, v_i1, v_i2, v_i3] = B[256 + v_i1 - 256, v_i3] * A[
v_i0, v_i1, v_i2, v_i3
] + C[256 + v_i1 - 256, v_i3] * T.Select(
T.int64(64) <= v_i3,
A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)],
A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1),
)
@R.function(pure=False)
def foo(
input_tokens: R.Tensor((1, 256, 4096), dtype="float16"),
mask: R.Tensor((1, 1, 256, 256), dtype="float16"),
div_const: R.Tensor((1, 32, 256, 256), dtype="float16"),
maximum_const: R.Tensor((1, 32, 256, 256), dtype="float16"),
kv_cache: R.Tuple(R.Object, R.Object),
linear_weight: R.Tensor((4096, 4096), dtype="float16"),
linear_weight1: R.Tensor((4096, 4096), dtype="float16"),
linear_weight2: R.Tensor((4096, 4096), dtype="float16"),
linear_weight3: R.Tensor((4096, 4096), dtype="float16"),
rms_norm_weight: R.Tensor((4096,), dtype="float16"),
cos_cached: R.Tensor((2048, 128), dtype="float16"),
sin_cached: R.Tensor((2048, 128), dtype="float16"),
):
cls = LlamaAttentionLayer
lv6 = R.call_tir(
cls.rms_norm,
(input_tokens, rms_norm_weight),
out_sinfo=R.Tensor((1, 256, 4096), dtype="float16"),
)
lv7: R.Tensor((4096, 4096), dtype="float16") = R.permute_dims(linear_weight, axes=None)
lv7_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding(
lv7, "mesh[0]", "S[1]"
)
lv8: R.Tensor((1, 256, 4096), dtype="float16") = R.matmul(
lv6, lv7_copy, out_dtype="void"
)
lv9: R.Tensor((1, 256, 32, 128), dtype="float16") = R.reshape(
lv8, R.shape([1, 256, 32, 128])
)
lv10: R.Tensor((4096, 4096), dtype="float16") = R.permute_dims(
linear_weight1, axes=None
)
lv10_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding(
lv10, "mesh[0]", "S[1]"
)
lv11: R.Tensor((1, 256, 4096), dtype="float16") = R.matmul(
lv6, lv10_copy, out_dtype="void"
)
lv12: R.Tensor((1, 256, 32, 128), dtype="float16") = R.reshape(
lv11, R.shape([1, 256, 32, 128])
)
lv13: R.Tensor((4096, 4096), dtype="float16") = R.permute_dims(
linear_weight2, axes=None
)
lv13_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding(
lv13, "mesh[0]", "S[1]"
)
lv14: R.Tensor((1, 256, 4096), dtype="float16") = R.matmul(
lv6, lv13_copy, out_dtype="void"
)
lv15: R.Tensor((1, 256, 32, 128), dtype="float16") = R.reshape(
lv14, R.shape([1, 256, 32, 128])
)
lv16 = R.call_tir(
cls.rotary_embedding,
(lv9, cos_cached, sin_cached),
out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
)
lv17 = R.call_tir(
cls.rotary_embedding,
(lv12, cos_cached, sin_cached),
out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
)
lv18: R.Tensor((256, 32, 128), dtype="float16") = R.reshape(
lv17, R.shape([256, 32, 128])
)
lv19: R.Tensor((256, 32, 128), dtype="float16") = R.reshape(
lv15, R.shape([256, 32, 128])
)
lv20: R.Object = kv_cache[0]
lv21: R.Object = R.call_packed(
"vm.builtin.attention_kv_cache_append", lv20, lv18, sinfo_args=(R.Object,)
)
lv22: R.Object = kv_cache[1]
lv23: R.Object = R.call_packed(
"vm.builtin.attention_kv_cache_append", lv22, lv19, sinfo_args=(R.Object,)
)
lv24: R.Tensor((256, 32, 128), dtype="float16") = R.call_packed(
"vm.builtin.attention_kv_cache_view",
lv21,
R.shape([256, 32, 128]),
sinfo_args=(R.Tensor((256, 32, 128), dtype="float16"),),
)
lv25: R.Tensor((256, 32, 128), dtype="float16") = R.call_packed(
"vm.builtin.attention_kv_cache_view",
lv23,
R.shape([256, 32, 128]),
sinfo_args=(R.Tensor((256, 32, 128), dtype="float16"),),
)
lv26: R.Tensor((1, 256, 32, 128), dtype="float16") = R.reshape(
lv24, R.shape([1, 256, 32, 128])
)
lv27: R.Tensor((1, 256, 32, 128), dtype="float16") = R.reshape(
lv25, R.shape([1, 256, 32, 128])
)
lv28: R.Tensor((1, 32, 256, 128), dtype="float16") = R.permute_dims(
lv16, axes=[0, 2, 1, 3]
)
lv29: R.Tensor((1, 32, 256, 128), dtype="float16") = R.permute_dims(
lv26, axes=[0, 2, 1, 3]
)
lv30: R.Tensor((1, 32, 256, 128), dtype="float16") = R.permute_dims(
lv27, axes=[0, 2, 1, 3]
)
lv31: R.Tensor((1, 32, 128, 256), dtype="float16") = R.permute_dims(
lv29, axes=[0, 1, 3, 2]
)
lv32: R.Tensor((1, 32, 256, 256), dtype="float16") = R.matmul(
lv28, lv31, out_dtype="void"
)
lv33: R.Tensor((1, 32, 256, 256), dtype="float16") = R.divide(lv32, div_const)
lv34: R.Tensor((1, 32, 256, 256), dtype="float16") = R.maximum(lv33, maximum_const)
lv35: R.Tensor((1, 32, 256, 256), dtype="float16") = R.minimum(lv34, mask)
# lv36: R.Tensor((1, 32, 256, 256), dtype="float32") = R.astype(lv35, dtype="float32")
lv37: R.Tensor((1, 32, 256, 256), dtype="float16") = R.nn.softmax(lv35, axis=-1)
# lv38: R.Tensor((1, 32, 256, 256), dtype="float16") = R.astype(lv37, dtype="float16")
lv39: R.Tensor((1, 32, 256, 128), dtype="float16") = R.matmul(
lv37, lv30, out_dtype="void"
)
lv40: R.Tensor((1, 256, 32, 128), dtype="float16") = R.permute_dims(
lv39, axes=[0, 2, 1, 3]
)
lv41: R.Tensor((1, 256, 4096), dtype="float16") = R.reshape(
lv40, R.shape([1, 256, 4096])
)
lv42: R.Tensor((4096, 4096), dtype="float16") = R.permute_dims(
linear_weight3, axes=None
)
lv43: R.Tensor((1, 256, 4096), dtype="float16") = R.matmul(lv41, lv42, out_dtype="void")
lv44: R.Tensor((1, 256, 4096), dtype="float16") = R.add(input_tokens, lv43)
gv = lv44
return gv
@I.ir_module
class ShardedLlamaAttentionLayer:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
)
@T.prim_func
def rms_norm(
A: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"),
B: T.Buffer((T.int64(4096),), "float16"),
rms_norm_1: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
Ared_temp = T.alloc_buffer((T.int64(1), 256))
for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)):
with T.block("Ared_temp"):
v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
T.reads(A[v_bsz, v_i, v_k])
T.writes(Ared_temp[v_bsz, v_i])
with T.init():
Ared_temp[v_bsz, v_i] = T.float32(0)
Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast(
"float32", A[v_bsz, v_i, v_k]
) * T.Cast("float32", A[v_bsz, v_i, v_k])
for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)):
with T.block("rms_norm"):
v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
T.writes(rms_norm_1[v_bsz, v_i, v_k])
rms_norm_1[v_bsz, v_i, v_k] = T.Cast(
"float16",
T.Cast("float32", B[v_k])
* (
T.Cast("float32", A[v_bsz, v_i, v_k])
/ T.sqrt(
Ared_temp[v_bsz, v_i] * T.float32(0.000244140625)
+ T.float32(9.9999999999999995e-07)
)
),
)
@T.prim_func
def rotary_embedding(
A: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"),
B: T.Buffer((T.int64(2048), T.int64(128)), "float16"),
C: T.Buffer((T.int64(2048), T.int64(128)), "float16"),
rotary: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)):
with T.block("rotary"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(
B[256 + v_i1 - 256, v_i3],
A[v_i0, v_i1, v_i2, v_i3 - T.int64(64) : v_i3 - T.int64(64) + T.int64(129)],
C[256 + v_i1 - 256, v_i3],
)
T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
rotary[v_i0, v_i1, v_i2, v_i3] = B[256 + v_i1 - 256, v_i3] * A[
v_i0, v_i1, v_i2, v_i3
] + C[256 + v_i1 - 256, v_i3] * T.Select(
T.int64(64) <= v_i3,
A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)],
A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1),
)
@R.function(pure=False)
def foo(
input_tokens: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"),
mask: R.DTensor((1, 1, 256, 256), "float16", "mesh[0]", "R"),
div_const: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"),
maximum_const: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"),
kv_cache: R.Tuple(R.Object, R.Object),
linear_weight: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"),
linear_weight1: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"),
linear_weight2: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"),
linear_weight3: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"),
rms_norm_weight: R.DTensor((4096,), "float16", "mesh[0]", "R"),
cos_cached: R.DTensor((2048, 128), "float16", "mesh[0]", "R"),
sin_cached: R.DTensor((2048, 128), "float16", "mesh[0]", "R"),
) -> R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"):
cls = ShardedLlamaAttentionLayer
lv6 = R.dist.call_tir(
cls.rms_norm,
(input_tokens, rms_norm_weight),
out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"),
)
lv7: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = R.permute_dims(
linear_weight, axes=None
)
lv8: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]") = R.matmul(
lv6, lv7, out_dtype="void"
)
lv9: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape(
lv8, R.shape([1, 256, 32, 128])
)
lv10: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = R.permute_dims(
linear_weight1, axes=None
)
lv11: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]") = R.matmul(
lv6, lv10, out_dtype="void"
)
lv12: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape(
lv11, R.shape([1, 256, 32, 128])
)
lv13: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = R.permute_dims(
linear_weight2, axes=None
)
lv14: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]") = R.matmul(
lv6, lv13, out_dtype="void"
)
lv15: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape(
lv14, R.shape([1, 256, 32, 128])
)
lv16 = R.dist.call_tir(
cls.rotary_embedding,
(lv9, cos_cached, sin_cached),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"),
)
lv17 = R.dist.call_tir(
cls.rotary_embedding,
(lv12, cos_cached, sin_cached),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"),
)
lv18: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.reshape(
lv17, R.shape([256, 32, 128])
)
lv19: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.reshape(
lv15, R.shape([256, 32, 128])
)
lv20: R.Object = kv_cache[0]
lv21: R.Object = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_append",
lv20,
lv18,
sinfo_args=(R.Object,),
)
lv22: R.Object = kv_cache[1]
lv23: R.Object = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_append",
lv22,
lv19,
sinfo_args=(R.Object,),
)
lv24: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_view",
lv21,
R.shape([256, 32, 128]),
sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),),
)
lv25: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_view",
lv23,
R.shape([256, 32, 128]),
sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),),
)
lv26: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape(
lv24, R.shape([1, 256, 32, 128])
)
lv27: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape(
lv25, R.shape([1, 256, 32, 128])
)
lv28: R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]") = R.permute_dims(
lv16, axes=[0, 2, 1, 3]
)
lv29: R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]") = R.permute_dims(
lv26, axes=[0, 2, 1, 3]
)
lv30: R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]") = R.permute_dims(
lv27, axes=[0, 2, 1, 3]
)
lv31: R.DTensor((1, 32, 128, 256), "float16", "mesh[0]", "S[1]") = R.permute_dims(
lv29, axes=[0, 1, 3, 2]
)
lv32: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]") = R.matmul(
lv28, lv31, out_dtype="void"
)
lv33: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]") = R.divide(
lv32, div_const
)
lv34: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]") = R.maximum(
lv33, maximum_const
)
lv35: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]") = R.minimum(lv34, mask)
lv37: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]") = R.nn.softmax(
lv35, axis=-1
)
lv39: R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]") = R.matmul(
lv37, lv30, out_dtype="void"
)
lv40: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = R.permute_dims(
lv39, axes=[0, 2, 1, 3]
)
lv41: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]") = R.reshape(
lv40, R.shape([1, 256, 4096])
)
lv42: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]") = R.permute_dims(
linear_weight3, axes=None
)
lv43: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.matmul(
lv41, lv42, out_dtype="void"
)
lv44: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.add(input_tokens, lv43)
gv: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = lv44
return gv
after = relax.distributed.transform.PropagateSharding()(LlamaAttentionLayer)
assert_structural_equal(after, ShardedLlamaAttentionLayer)
# PropagateSharding should analyze TIR functions
# and successfully propagate sharding annotations through them
def test_decoder_layer_tir():
@I.ir_module
class LlamaAttentionLayerTIR:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
)
@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"),
B: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"),
T_add: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)):
with T.block("T_add"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2])
T.reads(A[T.int64(0), v_ax1, v_ax2], B[T.int64(0), v_ax1, v_ax2])
T.writes(T_add[T.int64(0), v_ax1, v_ax2])
T_add[T.int64(0), v_ax1, v_ax2] = (
A[T.int64(0), v_ax1, v_ax2] + B[T.int64(0), v_ax1, v_ax2]
)
@T.prim_func(private=True)
def divide(
A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
B: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
T_divide: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)):
with T.block("T_divide"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3])
T.reads(A[T.int64(0), v_ax1, v_ax2, v_ax3], B[T.int64(0), v_ax1, v_ax2, v_ax3])
T.writes(T_divide[T.int64(0), v_ax1, v_ax2, v_ax3])
T_divide[T.int64(0), v_ax1, v_ax2, v_ax3] = (
A[T.int64(0), v_ax1, v_ax2, v_ax3] / B[T.int64(0), v_ax1, v_ax2, v_ax3]
)
@T.prim_func(private=True)
def matmul(
A: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"),
B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
matmul: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(256), T.int64(4096), T.int64(4096)):
with T.block("matmul"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1, v_i2, v_k = T.axis.remap("SSR", [i1, i2, k])
T.reads(A[T.int64(0), v_i1, v_k], B[v_k, v_i2])
T.writes(matmul[T.int64(0), v_i1, v_i2])
with T.init():
matmul[T.int64(0), v_i1, v_i2] = T.float16(0)
matmul[T.int64(0), v_i1, v_i2] = (
matmul[T.int64(0), v_i1, v_i2] + A[T.int64(0), v_i1, v_k] * B[v_k, v_i2]
)
@T.prim_func(private=True)
def matmul1(
A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"),
B: T.Buffer((T.int64(1), T.int64(32), T.int64(128), T.int64(256)), "float16"),
matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for i0, i1, i2, i3, k in T.grid(
T.int64(1), T.int64(32), T.int64(256), T.int64(256), T.int64(128)
):
with T.block("matmul"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSR", [i1, i2, i3, k])
T.reads(A[T.int64(0), v_i1, v_i2, v_k], B[T.int64(0), v_i1, v_k, v_i3])
T.writes(matmul[T.int64(0), v_i1, v_i2, v_i3])
with T.init():
matmul[T.int64(0), v_i1, v_i2, v_i3] = T.float16(0)
matmul[T.int64(0), v_i1, v_i2, v_i3] = (
matmul[T.int64(0), v_i1, v_i2, v_i3]
+ A[T.int64(0), v_i1, v_i2, v_k] * B[T.int64(0), v_i1, v_k, v_i3]
)
@T.prim_func(private=True)
def matmul2(
A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
B: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"),
matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for i0, i1, i2, i3, k in T.grid(
T.int64(1), T.int64(32), T.int64(256), T.int64(128), T.int64(256)
):
with T.block("matmul"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSR", [i1, i2, i3, k])
T.reads(A[T.int64(0), v_i1, v_i2, v_k], B[T.int64(0), v_i1, v_k, v_i3])
T.writes(matmul[T.int64(0), v_i1, v_i2, v_i3])
with T.init():
matmul[T.int64(0), v_i1, v_i2, v_i3] = T.float16(0)
matmul[T.int64(0), v_i1, v_i2, v_i3] = (
matmul[T.int64(0), v_i1, v_i2, v_i3]
+ A[T.int64(0), v_i1, v_i2, v_k] * B[T.int64(0), v_i1, v_k, v_i3]
)
@T.prim_func(private=True)
def maximum(
A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
B: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
T_maximum: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)):
with T.block("T_maximum"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3])
T.reads(A[T.int64(0), v_ax1, v_ax2, v_ax3], B[T.int64(0), v_ax1, v_ax2, v_ax3])
T.writes(T_maximum[T.int64(0), v_ax1, v_ax2, v_ax3])
T_maximum[T.int64(0), v_ax1, v_ax2, v_ax3] = T.max(
A[T.int64(0), v_ax1, v_ax2, v_ax3], B[T.int64(0), v_ax1, v_ax2, v_ax3]
)
@T.prim_func(private=True)
def minimum(
A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
B: T.Buffer((T.int64(1), T.int64(1), T.int64(256), T.int64(256)), "float16"),
T_minimum: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)):
with T.block("T_minimum"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3])
T.reads(
A[T.int64(0), v_ax1, v_ax2, v_ax3], B[T.int64(0), T.int64(0), v_ax2, v_ax3]
)
T.writes(T_minimum[T.int64(0), v_ax1, v_ax2, v_ax3])
T_minimum[T.int64(0), v_ax1, v_ax2, v_ax3] = T.min(
A[T.int64(0), v_ax1, v_ax2, v_ax3], B[T.int64(0), T.int64(0), v_ax2, v_ax3]
)
@T.prim_func(private=True)
def reshape(
A: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"),
T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3])
T.reads(A[T.int64(0), v_ax1, v_ax2 * T.int64(128) + v_ax3])
T.writes(T_reshape[T.int64(0), v_ax1, v_ax2, v_ax3])
T_reshape[T.int64(0), v_ax1, v_ax2, v_ax3] = A[
T.int64(0), v_ax1, v_ax2 * T.int64(128) + v_ax3
]
@T.prim_func(private=True)
def reshape1(
A: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"),
T_reshape: T.Buffer((T.int64(256), T.int64(32), T.int64(128)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2 in T.grid(T.int64(256), T.int64(32), T.int64(128)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]
@T.prim_func(private=True)
def reshape2(
A: T.Buffer((T.int64(256), T.int64(32), T.int64(128)), "float16"),
T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3])
T.reads(A[v_ax1, v_ax2, v_ax3])
T.writes(T_reshape[T.int64(0), v_ax1, v_ax2, v_ax3])
T_reshape[T.int64(0), v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]
@T.prim_func(private=True)
def reshape3(
A: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"),
T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2])
T.reads(A[T.int64(0), v_ax1, v_ax2 // T.int64(128), v_ax2 % T.int64(128)])
T.writes(T_reshape[T.int64(0), v_ax1, v_ax2])
T_reshape[T.int64(0), v_ax1, v_ax2] = A[
T.int64(0), v_ax1, v_ax2 // T.int64(128), v_ax2 % T.int64(128)
]
@T.prim_func
def rms_norm(
A: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"),
B: T.Buffer((T.int64(4096),), "float16"),
rms_norm_1: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
Ared_temp = T.alloc_buffer((T.int64(1), 256))
for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)):
with T.block("Ared_temp"):
v_bsz = T.axis.spatial(T.int64(1), T.int64(0))
v_i, v_k = T.axis.remap("SR", [i, k])
T.reads(A[T.int64(0), v_i, v_k])
T.writes(Ared_temp[T.int64(0), v_i])
with T.init():
Ared_temp[T.int64(0), v_i] = T.float32(0)
Ared_temp[T.int64(0), v_i] = Ared_temp[T.int64(0), v_i] + T.Cast(
"float32", A[T.int64(0), v_i, v_k]
) * T.Cast("float32", A[T.int64(0), v_i, v_k])
for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)):
with T.block("rms_norm"):
v_bsz = T.axis.spatial(T.int64(1), T.int64(0))
v_i, v_k = T.axis.remap("SS", [i, k])
T.reads(B[v_k], A[T.int64(0), v_i, v_k], Ared_temp[T.int64(0), v_i])
T.writes(rms_norm_1[T.int64(0), v_i, v_k])
rms_norm_1[T.int64(0), v_i, v_k] = T.Cast(
"float16",
T.Cast("float32", B[v_k])
* (
T.Cast("float32", A[T.int64(0), v_i, v_k])
/ T.sqrt(
Ared_temp[T.int64(0), v_i] * T.float32(0.000244140625)
+ T.float32(9.9999999999999995e-07)
)
),
)
@T.prim_func
def rotary_embedding(
A: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"),
B: T.Buffer((T.int64(2048), T.int64(128)), "float16"),
C: T.Buffer((T.int64(2048), T.int64(128)), "float16"),
rotary: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)):
with T.block("rotary"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1, v_i2, v_i3 = T.axis.remap("SSS", [i1, i2, i3])
T.reads(
B[v_i1, v_i3],
A[
T.int64(0),
v_i1,
v_i2,
v_i3 - T.int64(64) : v_i3 - T.int64(64) + T.int64(129),
],
C[v_i1, v_i3],
)
T.writes(rotary[T.int64(0), v_i1, v_i2, v_i3])
rotary[T.int64(0), v_i1, v_i2, v_i3] = B[v_i1, v_i3] * A[
T.int64(0), v_i1, v_i2, v_i3
] + C[v_i1, v_i3] * T.Select(
T.int64(64) <= v_i3,
A[T.int64(0), v_i1, v_i2, v_i3 - T.int64(64)],
A[T.int64(0), v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1),
)
@T.prim_func(private=True)
def softmax(
A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"),
T_softmax_norm: T.Buffer(
(T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"
),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(256)), "float16")
T_softmax_exp = T.alloc_buffer(
(T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"
)
T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(256)), "float16")
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)):
with T.block("T_softmax_maxelem"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1, v_i2, v_k = T.axis.remap("SSR", [i1, i2, k])
T.reads(A[T.int64(0), v_i1, v_i2, v_k])
T.writes(T_softmax_maxelem[T.int64(0), v_i1, v_i2])
with T.init():
T_softmax_maxelem[T.int64(0), v_i1, v_i2] = T.float16(-65504)
T_softmax_maxelem[T.int64(0), v_i1, v_i2] = T.max(
T_softmax_maxelem[T.int64(0), v_i1, v_i2], A[T.int64(0), v_i1, v_i2, v_k]
)
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)):
with T.block("T_softmax_exp"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1, v_i2, v_i3 = T.axis.remap("SSS", [i1, i2, i3])
T.reads(
A[T.int64(0), v_i1, v_i2, v_i3], T_softmax_maxelem[T.int64(0), v_i1, v_i2]
)
T.writes(T_softmax_exp[T.int64(0), v_i1, v_i2, v_i3])
T_softmax_exp[T.int64(0), v_i1, v_i2, v_i3] = T.exp(
A[T.int64(0), v_i1, v_i2, v_i3] - T_softmax_maxelem[T.int64(0), v_i1, v_i2]
)
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)):
with T.block("T_softmax_expsum"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1, v_i2, v_k = T.axis.remap("SSR", [i1, i2, k])
T.reads(T_softmax_exp[T.int64(0), v_i1, v_i2, v_k])
T.writes(T_softmax_expsum[T.int64(0), v_i1, v_i2])
with T.init():
T_softmax_expsum[T.int64(0), v_i1, v_i2] = T.float16(0)
T_softmax_expsum[T.int64(0), v_i1, v_i2] = (
T_softmax_expsum[T.int64(0), v_i1, v_i2]
+ T_softmax_exp[T.int64(0), v_i1, v_i2, v_k]
)
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)):
with T.block("T_softmax_norm"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1, v_i2, v_i3 = T.axis.remap("SSS", [i1, i2, i3])
T.reads(
T_softmax_exp[T.int64(0), v_i1, v_i2, v_i3],
T_softmax_expsum[T.int64(0), v_i1, v_i2],
)
T.writes(T_softmax_norm[T.int64(0), v_i1, v_i2, v_i3])
T.block_attr({"axis": 3})
T_softmax_norm[T.int64(0), v_i1, v_i2, v_i3] = (
T_softmax_exp[T.int64(0), v_i1, v_i2, v_i3]
/ T_softmax_expsum[T.int64(0), v_i1, v_i2]
)
@T.prim_func(private=True)
def transpose(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = A[v_ax1, v_ax0]
@T.prim_func(private=True)
def transpose1(
A: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"),
T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(128)):
with T.block("T_transpose"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3])
T.reads(A[T.int64(0), v_ax2, v_ax1, v_ax3])
T.writes(T_transpose[T.int64(0), v_ax1, v_ax2, v_ax3])
T_transpose[T.int64(0), v_ax1, v_ax2, v_ax3] = A[
T.int64(0), v_ax2, v_ax1, v_ax3
]
@T.prim_func(private=True)
def transpose2(
A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"),
T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(128), T.int64(256)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(128), T.int64(256)):
with T.block("T_transpose"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3])
T.reads(A[T.int64(0), v_ax1, v_ax3, v_ax2])
T.writes(T_transpose[T.int64(0), v_ax1, v_ax2, v_ax3])
T_transpose[T.int64(0), v_ax1, v_ax2, v_ax3] = A[
T.int64(0), v_ax1, v_ax3, v_ax2
]
@T.prim_func(private=True)
def transpose3(
A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"),
T_transpose: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)):
with T.block("T_transpose"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3])
T.reads(A[T.int64(0), v_ax2, v_ax1, v_ax3])
T.writes(T_transpose[T.int64(0), v_ax1, v_ax2, v_ax3])
T_transpose[T.int64(0), v_ax1, v_ax2, v_ax3] = A[
T.int64(0), v_ax2, v_ax1, v_ax3
]
@R.function(pure=False)
def foo(
input_tokens: R.Tensor((1, 256, 4096), dtype="float16"),
mask: R.Tensor((1, 1, 256, 256), dtype="float16"),
div_const: R.Tensor((1, 32, 256, 256), dtype="float16"),
maximum_const: R.Tensor((1, 32, 256, 256), dtype="float16"),
kv_cache: R.Tuple(R.Object, R.Object),
linear_weight: R.Tensor((4096, 4096), dtype="float16"),
linear_weight1: R.Tensor((4096, 4096), dtype="float16"),
linear_weight2: R.Tensor((4096, 4096), dtype="float16"),
linear_weight3: R.Tensor((4096, 4096), dtype="float16"),
rms_norm_weight: R.Tensor((4096,), dtype="float16"),
cos_cached: R.Tensor((2048, 128), dtype="float16"),
sin_cached: R.Tensor((2048, 128), dtype="float16"),
) -> R.Tensor((1, 256, 4096), dtype="float16"):
cls = LlamaAttentionLayerTIR
lv6 = R.call_tir(
cls.rms_norm,
(input_tokens, rms_norm_weight),
out_sinfo=R.Tensor((1, 256, 4096), dtype="float16"),
)
lv7 = R.call_tir(
cls.transpose, (linear_weight,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
)
lv7_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding(
lv7, device_mesh="mesh[0]", placement="S[1]"
)
lv8 = R.call_tir(
cls.matmul, (lv6, lv7_copy), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16")
)
lv9 = R.call_tir(
cls.reshape, (lv8,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16")
)
lv10 = R.call_tir(
cls.transpose, (linear_weight1,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
)
lv10_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding(
lv10, device_mesh="mesh[0]", placement="S[1]"
)
lv11 = R.call_tir(
cls.matmul, (lv6, lv10_copy), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16")
)
lv12 = R.call_tir(
cls.reshape, (lv11,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16")
)
lv13 = R.call_tir(
cls.transpose, (linear_weight2,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
)
lv13_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding(
lv13, device_mesh="mesh[0]", placement="S[1]"
)
lv14 = R.call_tir(
cls.matmul, (lv6, lv13_copy), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16")
)
lv15 = R.call_tir(
cls.reshape, (lv14,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16")
)
lv16 = R.call_tir(
cls.rotary_embedding,
(lv9, cos_cached, sin_cached),
out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
)
lv17 = R.call_tir(
cls.rotary_embedding,
(lv12, cos_cached, sin_cached),
out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
)
lv18 = R.call_tir(
cls.reshape1, (lv17,), out_sinfo=R.Tensor((256, 32, 128), dtype="float16")
)
lv19 = R.call_tir(
cls.reshape1, (lv15,), out_sinfo=R.Tensor((256, 32, 128), dtype="float16")
)
lv20: R.Object = kv_cache[0]
lv21: R.Object = R.call_packed(
"vm.builtin.attention_kv_cache_append", lv20, lv18, sinfo_args=(R.Object,)
)
lv22: R.Object = kv_cache[1]
lv23: R.Object = R.call_packed(
"vm.builtin.attention_kv_cache_append", lv22, lv19, sinfo_args=(R.Object,)
)
lv24: R.Tensor((256, 32, 128), dtype="float16") = R.call_packed(
"vm.builtin.attention_kv_cache_view",
lv21,
R.shape([256, 32, 128]),
sinfo_args=(R.Tensor((256, 32, 128), dtype="float16"),),
)
lv25: R.Tensor((256, 32, 128), dtype="float16") = R.call_packed(
"vm.builtin.attention_kv_cache_view",
lv23,
R.shape([256, 32, 128]),
sinfo_args=(R.Tensor((256, 32, 128), dtype="float16"),),
)
lv26 = R.call_tir(
cls.reshape2, (lv24,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16")
)
lv27 = R.call_tir(
cls.reshape2, (lv25,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16")
)
lv28 = R.call_tir(
cls.transpose1, (lv16,), out_sinfo=R.Tensor((1, 32, 256, 128), dtype="float16")
)
lv29 = R.call_tir(
cls.transpose1, (lv26,), out_sinfo=R.Tensor((1, 32, 256, 128), dtype="float16")
)
lv30 = R.call_tir(
cls.transpose1, (lv27,), out_sinfo=R.Tensor((1, 32, 256, 128), dtype="float16")
)
lv31 = R.call_tir(
cls.transpose2, (lv29,), out_sinfo=R.Tensor((1, 32, 128, 256), dtype="float16")
)
lv32 = R.call_tir(
cls.matmul1, (lv28, lv31), out_sinfo=R.Tensor((1, 32, 256, 256), dtype="float16")
)
lv33 = R.call_tir(
cls.divide,
(lv32, div_const),
out_sinfo=R.Tensor((1, 32, 256, 256), dtype="float16"),
)
lv34 = R.call_tir(
cls.maximum,
(lv33, maximum_const),
out_sinfo=R.Tensor((1, 32, 256, 256), dtype="float16"),
)
lv35 = R.call_tir(
cls.minimum, (lv34, mask), out_sinfo=R.Tensor((1, 32, 256, 256), dtype="float16")
)
lv37 = R.call_tir(
cls.softmax, (lv35,), out_sinfo=R.Tensor((1, 32, 256, 256), dtype="float16")
)
lv39 = R.call_tir(
cls.matmul2, (lv37, lv30), out_sinfo=R.Tensor((1, 32, 256, 128), dtype="float16")
)
lv40 = R.call_tir(
cls.transpose3, (lv39,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16")
)
lv41 = R.call_tir(
cls.reshape3, (lv40,), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16")
)
lv42 = R.call_tir(
cls.transpose, (linear_weight3,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
)
lv43 = R.call_tir(
cls.matmul, (lv41, lv42), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16")
)
lv44 = R.call_tir(
cls.add, (input_tokens, lv43), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16")
)
gv: R.Tensor((1, 256, 4096), dtype="float16") = lv44
return gv
        # The expected module below refers to TIR global vars that are not yet defined
        # in it; the definitions are copied over from LlamaAttentionLayerTIR after the
        # class body, so well-formed checking is disabled here.
@I.ir_module(check_well_formed=False)
class ShardedLlamaAttentionLayerTIR:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
)
@R.function(pure=False)
def foo(
input_tokens: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"),
mask: R.DTensor((1, 1, 256, 256), "float16", "mesh[0]", "R"),
div_const: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"),
maximum_const: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"),
kv_cache: R.Tuple(R.Object, R.Object),
linear_weight: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"),
linear_weight1: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"),
linear_weight2: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"),
linear_weight3: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"),
rms_norm_weight: R.DTensor((4096,), "float16", "mesh[0]", "R"),
cos_cached: R.DTensor((2048, 128), "float16", "mesh[0]", "R"),
sin_cached: R.DTensor((2048, 128), "float16", "mesh[0]", "R"),
) -> R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"):
cls = ShardedLlamaAttentionLayerTIR
lv6 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("rms_norm"),
(input_tokens, rms_norm_weight),
out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"),
)
lv7 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("transpose"),
(linear_weight,),
out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"),
)
lv8 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("matmul"),
(lv6, lv7),
out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"),
)
lv9 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("reshape"),
(lv8,),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"),
)
lv10 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("transpose"),
(linear_weight1,),
out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"),
)
lv11 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("matmul"),
(lv6, lv10),
out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"),
)
lv12 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("reshape"),
(lv11,),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"),
)
lv13 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("transpose"),
(linear_weight2,),
out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"),
)
lv14 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("matmul"),
(lv6, lv13),
out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"),
)
lv15 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("reshape"),
(lv14,),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"),
)
lv16 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("rotary_embedding"),
(lv9, cos_cached, sin_cached),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"),
)
lv17 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("rotary_embedding"),
(lv12, cos_cached, sin_cached),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"),
)
lv18 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("reshape1"),
(lv17,),
out_sinfo=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),
)
lv19 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("reshape1"),
(lv15,),
out_sinfo=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),
)
lv20: R.Object = kv_cache[0]
lv21: R.Object = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_append",
lv20,
lv18,
sinfo_args=(R.Object,),
)
lv22: R.Object = kv_cache[1]
lv23: R.Object = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_append",
lv22,
lv19,
sinfo_args=(R.Object,),
)
lv24: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_view",
lv21,
R.shape([256, 32, 128]),
sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),),
)
lv25: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_view",
lv23,
R.shape([256, 32, 128]),
sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),),
)
lv26 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("reshape2"),
(lv24,),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"),
)
lv27 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("reshape2"),
(lv25,),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"),
)
lv28 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("transpose1"),
(lv16,),
out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"),
)
lv29 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("transpose1"),
(lv26,),
out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"),
)
lv30 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("transpose1"),
(lv27,),
out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"),
)
lv31 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("transpose2"),
(lv29,),
out_sinfo=R.DTensor((1, 32, 128, 256), "float16", "mesh[0]", "S[1]"),
)
lv32 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("matmul1"),
(lv28, lv31),
out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"),
)
lv33 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("divide"),
(lv32, div_const),
out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"),
)
lv34 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("maximum"),
(lv33, maximum_const),
out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"),
)
lv35 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("minimum"),
(lv34, mask),
out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"),
)
lv37 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("softmax"),
(lv35,),
out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"),
)
lv39 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("matmul2"),
(lv37, lv30),
out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"),
)
lv40 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("transpose3"),
(lv39,),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"),
)
lv41 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("reshape3"),
(lv40,),
out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"),
)
lv42 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("transpose"),
(linear_weight3,),
out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"),
)
lv43 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("matmul"),
(lv41, lv42),
out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"),
)
lv44 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("add"),
(input_tokens, lv43),
out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"),
)
gv: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = lv44
return gv
for gv, func in LlamaAttentionLayerTIR.functions_items():
if gv.name_hint != "foo":
ShardedLlamaAttentionLayerTIR[gv] = func
after = relax.distributed.transform.PropagateSharding()(LlamaAttentionLayerTIR)
assert_structural_equal(after, ShardedLlamaAttentionLayerTIR)
def test_decoder_layer_dynamic_shape():
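    # Dynamic-shape variant of the attention layer: the sequence lengths n and m are
    # symbolic, and rotary_embedding receives m through tir_vars.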
@I.ir_module
class LlamaAttentionLayerDynamicShape:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{
"mesh": [
R.device_mesh((2,), I.Range(0, 2)), # mesh[0]
R.device_mesh((1,), I.Range(4, 5)), # mesh[1]
]
}
)
@T.prim_func
def rms_norm(
var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle
):
T.func_attr({"tir.noalias": True})
n = T.int64()
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16")
# with T.block("root"):
Ared_temp = T.alloc_buffer((T.int64(1), n))
for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
with T.block("Ared_temp"):
v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
T.reads(A[v_bsz, v_i, v_k])
T.writes(Ared_temp[v_bsz, v_i])
with T.init():
Ared_temp[v_bsz, v_i] = T.float32(0)
Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast(
"float32", A[v_bsz, v_i, v_k]
) * T.Cast("float32", A[v_bsz, v_i, v_k])
for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
with T.block("rms_norm"):
v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
T.writes(rms_norm_1[v_bsz, v_i, v_k])
rms_norm_1[v_bsz, v_i, v_k] = T.Cast(
"float16",
T.Cast("float32", B[v_k])
* (
T.Cast("float32", A[v_bsz, v_i, v_k])
/ T.sqrt(
Ared_temp[v_bsz, v_i] * T.float32(0.000244140625)
+ T.float32(9.9999999999999995e-07)
)
),
)
@T.prim_func
def rotary_embedding(
var_A: T.handle,
B: T.Buffer((T.int64(2048), T.int64(128)), "float16"),
C: T.Buffer((T.int64(2048), T.int64(128)), "float16"),
var_rotary: T.handle,
m: T.int64,
):
T.func_attr({"tir.noalias": True})
n = T.int64()
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
rotary = T.match_buffer(
var_rotary, (T.int64(1), n, T.int64(32), T.int64(128)), "float16"
)
# with T.block("root"):
for i0, i1, i2, i3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
with T.block("rotary"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(
B[m + v_i1 - n, v_i3],
A[v_i0, v_i1, v_i2, v_i3 - T.int64(64) : v_i3 - T.int64(64) + T.int64(129)],
C[m + v_i1 - n, v_i3],
)
T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
rotary[v_i0, v_i1, v_i2, v_i3] = B[m + v_i1 - n, v_i3] * A[
v_i0, v_i1, v_i2, v_i3
] + C[m + v_i1 - n, v_i3] * T.Select(
T.int64(64) <= v_i3,
A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)],
A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1),
)
@R.function(pure=False)
def foo(
input_tokens: R.Tensor((1, "n", 4096), dtype="float16"),
mask: R.Tensor((1, 1, "n", "m"), dtype="float16"),
kv_cache: R.Tuple(R.Object, R.Object),
linear_weight: R.Tensor((4096, 4096), dtype="float16"),
linear_weight1: R.Tensor((4096, 4096), dtype="float16"),
linear_weight2: R.Tensor((4096, 4096), dtype="float16"),
linear_weight3: R.Tensor((4096, 4096), dtype="float16"),
rms_norm_weight: R.Tensor((4096,), dtype="float16"),
cos_cached: R.Tensor((2048, 128), dtype="float16"),
sin_cached: R.Tensor((2048, 128), dtype="float16"),
):
n = T.int64()
m = T.int64()
cls = LlamaAttentionLayerDynamicShape
lv6 = R.call_tir(
cls.rms_norm,
(input_tokens, rms_norm_weight),
out_sinfo=R.Tensor((1, n, 4096), dtype="float16"),
)
lv7: R.Tensor((4096, 4096), dtype="float16") = R.permute_dims(linear_weight, axes=None)
lv7_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding(
lv7, "mesh[0]", "S[1]"
)
lv8: R.Tensor((1, n, 4096), dtype="float16") = R.matmul(lv6, lv7_copy, out_dtype="void")
lv9: R.Tensor((1, n, 32, 128), dtype="float16") = R.reshape(
lv8, R.shape([1, n, 32, 128])
)
lv10: R.Tensor((4096, 4096), dtype="float16") = R.permute_dims(
linear_weight1, axes=None
)
lv10_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding(
lv10, "mesh[0]", "S[1]"
)
lv11: R.Tensor((1, n, 4096), dtype="float16") = R.matmul(
lv6, lv10_copy, out_dtype="void"
)
lv12: R.Tensor((1, n, 32, 128), dtype="float16") = R.reshape(
lv11, R.shape([1, n, 32, 128])
)
lv13: R.Tensor((4096, 4096), dtype="float16") = R.permute_dims(
linear_weight2, axes=None
)
lv13_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding(
lv13, "mesh[0]", "S[1]"
)
lv14: R.Tensor((1, n, 4096), dtype="float16") = R.matmul(
lv6, lv13_copy, out_dtype="void"
)
lv15: R.Tensor((1, n, 32, 128), dtype="float16") = R.reshape(
lv14, R.shape([1, n, 32, 128])
)
lv16 = R.call_tir(
cls.rotary_embedding,
(lv9, cos_cached, sin_cached),
out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"),
tir_vars=R.shape([m]),
)
lv17 = R.call_tir(
cls.rotary_embedding,
(lv12, cos_cached, sin_cached),
out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"),
tir_vars=R.shape([m]),
)
lv18: R.Tensor((n, 32, 128), dtype="float16") = R.reshape(lv17, R.shape([n, 32, 128]))
lv19: R.Tensor((n, 32, 128), dtype="float16") = R.reshape(lv15, R.shape([n, 32, 128]))
lv20: R.Object = kv_cache[0]
lv21: R.Object = R.call_packed(
"vm.builtin.attention_kv_cache_append", lv20, lv18, sinfo_args=(R.Object,)
)
lv22: R.Object = kv_cache[1]
lv23: R.Object = R.call_packed(
"vm.builtin.attention_kv_cache_append", lv22, lv19, sinfo_args=(R.Object,)
)
lv24: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed(
"vm.builtin.attention_kv_cache_view",
lv21,
R.shape([m, 32, 128]),
sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),),
)
lv25: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed(
"vm.builtin.attention_kv_cache_view",
lv23,
R.shape([m, 32, 128]),
sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),),
)
lv26: R.Tensor((1, m, 32, 128), dtype="float16") = R.reshape(
lv24, R.shape([1, m, 32, 128])
)
lv27: R.Tensor((1, m, 32, 128), dtype="float16") = R.reshape(
lv25, R.shape([1, m, 32, 128])
)
lv28: R.Tensor((1, 32, n, 128), dtype="float16") = R.permute_dims(
lv16, axes=[0, 2, 1, 3]
)
lv29: R.Tensor((1, 32, m, 128), dtype="float16") = R.permute_dims(
lv26, axes=[0, 2, 1, 3]
)
lv30: R.Tensor((1, 32, m, 128), dtype="float16") = R.permute_dims(
lv27, axes=[0, 2, 1, 3]
)
lv31: R.Tensor((1, 32, 128, m), dtype="float16") = R.permute_dims(
lv29, axes=[0, 1, 3, 2]
)
lv32: R.Tensor((1, 32, n, m), dtype="float16") = R.matmul(lv28, lv31, out_dtype="void")
            lv33: R.Tensor((1, 32, n, m), dtype="float16") = R.divide(
                lv32, R.const(8, dtype="float16")
            )  # arbitrary scaling constant, chosen only for the test
            lv34: R.Tensor((1, 32, n, m), dtype="float16") = R.maximum(
                lv33, R.const(1, dtype="float16")
            )  # arbitrary clamp value, chosen only for the test
lv35: R.Tensor((1, 32, n, m), dtype="float16") = R.minimum(lv34, mask)
# lv36: R.Tensor((1, 32, n, m), dtype="float32") = R.astype(lv35, dtype="float32")
lv37: R.Tensor((1, 32, n, m), dtype="float16") = R.nn.softmax(lv35, axis=-1)
# lv38: R.Tensor((1, 32, n, m), dtype="float16") = R.astype(lv37, dtype="float16")
lv39: R.Tensor((1, 32, n, 128), dtype="float16") = R.matmul(
lv37, lv30, out_dtype="void"
)
lv40: R.Tensor((1, n, 32, 128), dtype="float16") = R.permute_dims(
lv39, axes=[0, 2, 1, 3]
)
lv41: R.Tensor((1, n, 4096), dtype="float16") = R.reshape(lv40, R.shape([1, n, 4096]))
lv42: R.Tensor((4096, 4096), dtype="float16") = R.permute_dims(
linear_weight3, axes=None
)
lv43: R.Tensor((1, n, 4096), dtype="float16") = R.matmul(lv41, lv42, out_dtype="void")
lv44: R.Tensor((1, n, 4096), dtype="float16") = R.add(input_tokens, lv43)
gv = lv44
return gv
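
    # Expected module after PropagateSharding: every tensor becomes a DTensor with an
    # explicit device mesh and placement, while the symbolic shapes n and m are kept.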
@I.ir_module
class ShardedLlamaAttentionLayerDynamicShape:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
        )

        @T.prim_func
def rms_norm(
var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle
):
T.func_attr({"tir.noalias": True})
n = T.int64()
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16")
# with T.block("root"):
Ared_temp = T.alloc_buffer((T.int64(1), n))
for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
with T.block("Ared_temp"):
v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
T.reads(A[v_bsz, v_i, v_k])
T.writes(Ared_temp[v_bsz, v_i])
with T.init():
Ared_temp[v_bsz, v_i] = T.float32(0)
Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast(
"float32", A[v_bsz, v_i, v_k]
) * T.Cast("float32", A[v_bsz, v_i, v_k])
for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
with T.block("rms_norm"):
v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
T.writes(rms_norm_1[v_bsz, v_i, v_k])
rms_norm_1[v_bsz, v_i, v_k] = T.Cast(
"float16",
T.Cast("float32", B[v_k])
* (
T.Cast("float32", A[v_bsz, v_i, v_k])
/ T.sqrt(
Ared_temp[v_bsz, v_i] * T.float32(0.000244140625)
+ T.float32(9.9999999999999995e-07)
)
),
                    )

        @T.prim_func
def rotary_embedding(
var_A: T.handle,
B: T.Buffer((T.int64(2048), T.int64(128)), "float16"),
C: T.Buffer((T.int64(2048), T.int64(128)), "float16"),
var_rotary: T.handle,
m: T.int64,
):
T.func_attr({"tir.noalias": True})
n = T.int64()
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
rotary = T.match_buffer(
var_rotary, (T.int64(1), n, T.int64(32), T.int64(128)), "float16"
)
# with T.block("root"):
for i0, i1, i2, i3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
with T.block("rotary"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(
B[m + v_i1 - n, v_i3],
A[v_i0, v_i1, v_i2, v_i3 - T.int64(64) : v_i3 - T.int64(64) + T.int64(129)],
C[m + v_i1 - n, v_i3],
)
T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
rotary[v_i0, v_i1, v_i2, v_i3] = B[m + v_i1 - n, v_i3] * A[
v_i0, v_i1, v_i2, v_i3
] + C[m + v_i1 - n, v_i3] * T.Select(
T.int64(64) <= v_i3,
A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)],
A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1),
                    )

        @R.function(pure=False)
def foo(
input_tokens: R.DTensor((1, "n", 4096), "float16", "mesh[0]", "R"),
mask: R.DTensor((1, 1, "n", "m"), "float16", "mesh[0]", "R"),
kv_cache: R.Tuple(R.Object, R.Object),
linear_weight: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"),
linear_weight1: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"),
linear_weight2: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"),
linear_weight3: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"),
rms_norm_weight: R.DTensor((4096,), "float16", "mesh[0]", "R"),
cos_cached: R.DTensor((2048, 128), "float16", "mesh[0]", "R"),
sin_cached: R.DTensor((2048, 128), "float16", "mesh[0]", "R"),
) -> R.DTensor((1, "n", 4096), "float16", "mesh[0]", "R"):
n = T.int64()
m = T.int64()
cls = ShardedLlamaAttentionLayerDynamicShape
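            # Expected placements: the transposed QKV weights are column-sharded (S[1]),
            # so the projected activations are sharded along the hidden/head axis (S[2],
            # becoming S[1] once the head axis is moved in front); the transposed output
            # projection weight is row-sharded (S[0]) and its matmul result is replicated.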
lv6 = R.dist.call_tir(
cls.rms_norm,
(input_tokens, rms_norm_weight),
out_sinfo=R.DTensor((1, n, 4096), "float16", "mesh[0]", "R"),
)
lv7: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = R.permute_dims(
linear_weight, axes=None
)
lv8: R.DTensor((1, n, 4096), "float16", "mesh[0]", "S[2]") = R.matmul(
lv6, lv7, out_dtype="void"
)
lv9: R.DTensor((1, n, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape(
lv8, R.shape([1, n, 32, 128])
)
lv10: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = R.permute_dims(
linear_weight1, axes=None
)
lv11: R.DTensor((1, n, 4096), "float16", "mesh[0]", "S[2]") = R.matmul(
lv6, lv10, out_dtype="void"
)
lv12: R.DTensor((1, n, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape(
lv11, R.shape([1, n, 32, 128])
)
lv13: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = R.permute_dims(
linear_weight2, axes=None
)
lv14: R.DTensor((1, n, 4096), "float16", "mesh[0]", "S[2]") = R.matmul(
lv6, lv13, out_dtype="void"
)
lv15: R.DTensor((1, n, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape(
lv14, R.shape([1, n, 32, 128])
)
lv16 = R.dist.call_tir(
cls.rotary_embedding,
(lv9, cos_cached, sin_cached),
out_sinfo=R.DTensor((1, n, 32, 128), "float16", "mesh[0]", "S[2]"),
tir_vars=R.shape([m]),
)
lv17 = R.dist.call_tir(
cls.rotary_embedding,
(lv12, cos_cached, sin_cached),
out_sinfo=R.DTensor((1, n, 32, 128), "float16", "mesh[0]", "S[2]"),
tir_vars=R.shape([m]),
)
lv18: R.DTensor((n, 32, 128), "float16", "mesh[0]", "S[1]") = R.reshape(
lv17, R.shape([n, 32, 128])
)
lv19: R.DTensor((n, 32, 128), "float16", "mesh[0]", "S[1]") = R.reshape(
lv15, R.shape([n, 32, 128])
)
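            # The KV-cache packed calls use the vm.builtin.distributed.* variants here,
            # and their sinfo_args are sharded DTensors instead of plain tensors.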
lv20: R.Object = kv_cache[0]
lv21: R.Object = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_append",
lv20,
lv18,
sinfo_args=(R.Object,),
)
lv22: R.Object = kv_cache[1]
lv23: R.Object = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_append",
lv22,
lv19,
sinfo_args=(R.Object,),
)
lv24: R.DTensor((m, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_view",
lv21,
R.shape([m, 32, 128]),
sinfo_args=(R.DTensor((m, 32, 128), "float16", "mesh[0]", "S[1]"),),
)
lv25: R.DTensor((m, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed(
"vm.builtin.distributed.attention_kv_cache_view",
lv23,
R.shape([m, 32, 128]),
sinfo_args=(R.DTensor((m, 32, 128), "float16", "mesh[0]", "S[1]"),),
)
lv26: R.DTensor((1, m, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape(
lv24, R.shape([1, m, 32, 128])
)
lv27: R.DTensor((1, m, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape(
lv25, R.shape([1, m, 32, 128])
)
lv28: R.DTensor((1, 32, n, 128), "float16", "mesh[0]", "S[1]") = R.permute_dims(
lv16, axes=[0, 2, 1, 3]
)
lv29: R.DTensor((1, 32, m, 128), "float16", "mesh[0]", "S[1]") = R.permute_dims(
lv26, axes=[0, 2, 1, 3]
)
lv30: R.DTensor((1, 32, m, 128), "float16", "mesh[0]", "S[1]") = R.permute_dims(
lv27, axes=[0, 2, 1, 3]
)
lv31: R.DTensor((1, 32, 128, m), "float16", "mesh[0]", "S[1]") = R.permute_dims(
lv29, axes=[0, 1, 3, 2]
)
lv32: R.DTensor((1, 32, n, m), "float16", "mesh[0]", "S[1]") = R.matmul(
lv28, lv31, out_dtype="void"
)
lv33: R.DTensor((1, 32, n, m), "float16", "mesh[0]", "S[1]") = R.divide(
lv32, R.dist.const(8, R.DTensor((), "float16", "mesh[0]", "R"))
)
lv34: R.DTensor((1, 32, n, m), "float16", "mesh[0]", "S[1]") = R.maximum(
lv33, R.dist.const(1, R.DTensor((), "float16", "mesh[0]", "R"))
)
lv35: R.DTensor((1, 32, n, m), "float16", "mesh[0]", "S[1]") = R.minimum(lv34, mask)
lv37: R.DTensor((1, 32, n, m), "float16", "mesh[0]", "S[1]") = R.nn.softmax(
lv35, axis=-1
)
lv39: R.DTensor((1, 32, n, 128), "float16", "mesh[0]", "S[1]") = R.matmul(
lv37, lv30, out_dtype="void"
)
lv40: R.DTensor((1, n, 32, 128), "float16", "mesh[0]", "S[2]") = R.permute_dims(
lv39, axes=[0, 2, 1, 3]
)
lv41: R.DTensor((1, n, 4096), "float16", "mesh[0]", "S[2]") = R.reshape(
lv40, R.shape([1, n, 4096])
)
lv42: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]") = R.permute_dims(
linear_weight3, axes=None
)
lv43: R.DTensor((1, n, 4096), "float16", "mesh[0]", "R") = R.matmul(
lv41, lv42, out_dtype="void"
)
lv44: R.DTensor((1, n, 4096), "float16", "mesh[0]", "R") = R.add(input_tokens, lv43)
gv: R.DTensor((1, n, 4096), "float16", "mesh[0]", "R") = lv44
return gv
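
    # Running the pass on the annotated module should reproduce the hand-written
    # sharded module above.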
after = relax.distributed.transform.PropagateSharding()(LlamaAttentionLayerDynamicShape)
    assert_structural_equal(after, ShardedLlamaAttentionLayerDynamicShape)


if __name__ == "__main__":
tvm.testing.main()