# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# type: ignore
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T

import tvm
from tvm import relax
from tvm.ir import assert_structural_equal
import tvm.testing
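

# Verify that LowerDistIR lowers a tensor-parallel MLP (matmul -> gelu -> matmul)
# over a two-device mesh into single-device Relax: the DTensor annotations are
# erased and explicit ccl collectives (broadcast, scatter, allreduce) appear.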
def test_mlp():
    @I.ir_module
    class MLP:
        I.module_attrs({"device_num": 10})
        I.module_global_infos(
            {"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
        )

        @T.prim_func(private=True)
        def gelu1(
            A: T.Buffer((T.int64(128), T.int64(64)), "float32"),
            T_multiply: T.Buffer((T.int64(128), T.int64(64)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            T_multiply_1 = T.alloc_buffer((T.int64(128), T.int64(64)))
            compute = T.alloc_buffer((T.int64(128), T.int64(64)))
            T_multiply_2 = T.alloc_buffer((T.int64(128), T.int64(64)))
            T_add = T.alloc_buffer((T.int64(128), T.int64(64)))
            for ax0, ax1 in T.grid(T.int64(128), T.int64(64)):
                with T.block("T_multiply"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1])
                    T.writes(T_multiply_1[v_ax0, v_ax1])
                    T_multiply_1[v_ax0, v_ax1] = A[v_ax0, v_ax1] * T.float32(0.70710678118654757)
            for i0, i1 in T.grid(T.int64(128), T.int64(64)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_multiply_1[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1])
            for ax0, ax1 in T.grid(T.int64(128), T.int64(64)):
                with T.block("T_multiply_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(compute[v_ax0, v_ax1])
                    T.writes(T_multiply_2[v_ax0, v_ax1])
                    T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * T.float32(0.5)
            for ax0, ax1 in T.grid(T.int64(128), T.int64(64)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_2[v_ax0, v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(T.int64(128), T.int64(64)):
                with T.block("T_multiply_2"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
                    T.writes(T_multiply[v_ax0, v_ax1])
                    T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * T_add[v_ax0, v_ax1]

        @T.prim_func(private=True)
        def matmul1(
            A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
            B: T.Buffer((T.int64(128), T.int64(64)), "float32"),
            matmul_1: T.Buffer((T.int64(128), T.int64(64)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)):
                with T.block("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(A[v_i0, v_k], B[v_k, v_i1])
                    T.writes(matmul_1[v_i0, v_i1])
                    with T.init():
                        matmul_1[v_i0, v_i1] = T.float32(0)
                    matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]

        @T.prim_func(private=True)
        def matmul2(
            A: T.Buffer((T.int64(128), T.int64(64)), "float32"),
            B: T.Buffer((T.int64(64), T.int64(128)), "float32"),
            matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)):
                with T.block("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(A[v_i0, v_k], B[v_k, v_i1])
                    T.writes(matmul_1[v_i0, v_i1])
                    with T.init():
                        matmul_1[v_i0, v_i1] = T.float32(0)
                    matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
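
        # Placement notation in the DTensor annotations below: "R" = replicated
        # across mesh[0]; "S[i]" = sharded along dimension i of the global shape.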
        @R.function
        def foo(
            x: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
            weight1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
            weight2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
        ) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
            R.func_attr({"num_input": 1})
            cls = MLP
            lv0: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.dist.call_tir_local_view(
                cls.matmul1,
                (x, weight1),
                out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
            )
            lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.dist.call_tir_local_view(
                cls.gelu1, (lv0,), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]")
            )
            lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = lv1
            gv: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.dist.call_tir_local_view(
                cls.matmul2,
                (lv2, weight2),
                out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"),
            )
            lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.ccl.allreduce(
                gv, op_type="sum"
            )
            return lv3
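
    # Expected module after LowerDistIR: DTensor types become plain Tensors and
    # the sharding plan materializes as broadcast/scatter/allreduce collectives.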
    @I.ir_module(check_well_formed=False)
    class LoweredMLP:
        I.module_attrs({"device_num": 10})
        I.module_global_infos(
            {"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
        )

        @R.function
        def foo(
            x: R.Tensor((128, 128), dtype="float32"),
            weight1: R.Tensor((128, 128), dtype="float32"),
            weight2: R.Tensor((128, 128), dtype="float32"),
        ) -> R.Tensor((128, 128), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = LoweredMLP
            gv: R.Tensor((128, 128), dtype="float32") = R.ccl.broadcast_from_worker0(x)
            gv1: R.Tensor((128, 64), dtype="float32") = R.ccl.scatter_from_worker0(
                weight1, num_workers=2, axis=1
            )
            gv2: R.Tensor((64, 128), dtype="float32") = R.ccl.scatter_from_worker0(
                weight2, num_workers=2, axis=0
            )
            lv0 = R.call_tir(
                MLP.get_global_var("matmul1"),
                (gv, gv1),
                out_sinfo=R.Tensor((128, 64), dtype="float32"),
            )
            lv1 = R.call_tir(
                MLP.get_global_var("gelu1"), (lv0,), out_sinfo=R.Tensor((128, 64), dtype="float32")
            )
            lv2: R.Tensor((128, 64), dtype="float32") = lv1
            gv_1 = R.call_tir(
                MLP.get_global_var("matmul2"),
                (lv2, gv2),
                out_sinfo=R.Tensor((128, 128), dtype="float32"),
            )
            lv3: R.Tensor((128, 128), dtype="float32") = R.ccl.allreduce(gv_1, op_type="sum")
            return lv3
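
    # The TIR functions are untouched by the pass, so copy them over verbatim to
    # make LoweredMLP structurally comparable with the lowered module.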
    for gv, func in MLP.functions_items():
        if gv.name_hint != "foo":
            LoweredMLP[gv] = func

    mod = MLP
    mod = relax.distributed.transform.LowerDistIR()(mod)
    tvm.ir.assert_structural_equal(mod, LoweredMLP)
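

# Same lowering check, but the sharded weights arrive packed in a Relax tuple and
# an intermediate is split into two DTensors, exercising tuple handling.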
def test_mlp_with_tuple():
    @I.ir_module
    class MLPWithTuple:
        I.module_attrs({"device_num": 10})
        I.module_global_infos(
            {"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
        )

        @T.prim_func(private=True)
        def gelu1(
            A: T.Buffer((T.int64(128), T.int64(64)), "float32"),
            T_multiply: T.Buffer((T.int64(128), T.int64(64)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            T_multiply_1 = T.alloc_buffer((T.int64(128), T.int64(64)))
            compute = T.alloc_buffer((T.int64(128), T.int64(64)))
            T_multiply_2 = T.alloc_buffer((T.int64(128), T.int64(64)))
            T_add = T.alloc_buffer((T.int64(128), T.int64(64)))
            for ax0, ax1 in T.grid(T.int64(128), T.int64(64)):
                with T.block("T_multiply"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1])
                    T.writes(T_multiply_1[v_ax0, v_ax1])
                    T_multiply_1[v_ax0, v_ax1] = A[v_ax0, v_ax1] * T.float32(0.70710678118654757)
            for i0, i1 in T.grid(T.int64(128), T.int64(64)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_multiply_1[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1])
            for ax0, ax1 in T.grid(T.int64(128), T.int64(64)):
                with T.block("T_multiply_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(compute[v_ax0, v_ax1])
                    T.writes(T_multiply_2[v_ax0, v_ax1])
                    T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * T.float32(0.5)
            for ax0, ax1 in T.grid(T.int64(128), T.int64(64)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_2[v_ax0, v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(T.int64(128), T.int64(64)):
                with T.block("T_multiply_2"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
                    T.writes(T_multiply[v_ax0, v_ax1])
                    T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * T_add[v_ax0, v_ax1]

        @T.prim_func(private=True)
        def matmul11(
            A: T.Buffer((T.int64(64), T.int64(64)), "float32"),
            B: T.Buffer((T.int64(64), T.int64(128)), "float32"),
            matmul: T.Buffer((T.int64(64), T.int64(128)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            for i0, i1, k in T.grid(T.int64(64), T.int64(128), T.int64(64)):
                with T.block("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(A[v_i0, v_k], B[v_k, v_i1])
                    T.writes(matmul[v_i0, v_i1])
                    with T.init():
                        matmul[v_i0, v_i1] = T.float32(0)
                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]

        @T.prim_func(private=True)
        def matmul2(
            A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
            B: T.Buffer((T.int64(128), T.int64(64)), "float32"),
            matmul: T.Buffer((T.int64(128), T.int64(64)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)):
                with T.block("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(A[v_i0, v_k], B[v_k, v_i1])
                    T.writes(matmul[v_i0, v_i1])
                    with T.init():
                        matmul[v_i0, v_i1] = T.float32(0)
                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
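
        # split11 splits its (128, 64) local input into two (64, 64) halves
        # along axis 0.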
        @T.prim_func(private=True)
        def split11(
            A: T.Buffer((128, 64), "float32"),
            T_split: T.Buffer((64, 64), "float32"),
            T_split_1: T.Buffer((64, 64), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            for ax1, ax2 in T.grid(64, 64):
                with T.block("T_split"):
                    v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2])
                    T.reads(A[v_ax1, v_ax2])
                    T.writes(T_split[v_ax1, v_ax2])
                    T_split[v_ax1, v_ax2] = A[v_ax1, v_ax2]
            for ax1, ax2 in T.grid(64, 64):
                with T.block("T_split_1"):
                    v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2])
                    T.reads(A[v_ax1 + 64, v_ax2])
                    T.writes(T_split_1[v_ax1, v_ax2])
                    T_split_1[v_ax1, v_ax2] = A[v_ax1 + 64, v_ax2]

        @R.function
        def foo(
            x: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
            weight_packed: R.Tuple(
                R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
                R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
            ),
        ) -> R.DTensor((64, 128), "float32", "mesh[0]", "R"):
            cls = MLPWithTuple
            weight1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = weight_packed[0]
            lv0: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.dist.call_tir_local_view(
                cls.matmul2,
                (x, weight1),
                out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
            )
            lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.dist.call_tir_local_view(
                cls.gelu1, (lv0,), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]")
            )
            gv: R.Tuple(
                R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"),
                R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"),
            ) = R.dist.call_tir_local_view(
                cls.split11,
                (lv1,),
                out_sinfo=[
                    R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"),
                    R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"),
                ],
            )
            lv2: R.DTensor((64, 128), "float32", "mesh[0]", "S[1]") = gv[0]
            lv3: R.DTensor((64, 128), "float32", "mesh[0]", "S[1]") = lv2
            weight2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]") = weight_packed[1]
            gv_1: R.DTensor((64, 128), "float32", "mesh[0]", "R") = R.dist.call_tir_local_view(
                cls.matmul11,
                (lv3, weight2),
                out_sinfo=R.DTensor((64, 128), "float32", "mesh[0]", "R"),
            )
            lv4: R.DTensor((64, 128), "float32", "mesh[0]", "R") = R.ccl.allreduce(
                gv_1, op_type="sum"
            )
            return lv4
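
    # Expected module after LowerDistIR: each tuple element is scattered
    # separately, and the split result becomes a plain Relax tuple.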
    @I.ir_module(check_well_formed=False)
    class LoweredMLPWithTuple:
        I.module_attrs({"device_num": 10})
        I.module_global_infos(
            {"mesh": [R.device_mesh((2,), I.Range(0, 2)), R.device_mesh((1,), I.Range(4, 5))]}
        )

        @R.function
        def foo(
            x: R.Tensor((128, 128), dtype="float32"),
            weight_packed: R.Tuple(
                R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")
            ),
        ) -> R.Tensor((64, 128), dtype="float32"):
            cls = LoweredMLPWithTuple
            gv: R.Tensor((128, 128), dtype="float32") = R.ccl.broadcast_from_worker0(x)
            gv1: R.Tensor((128, 128), dtype="float32") = weight_packed[0]
            gv2: R.Tensor((128, 64), dtype="float32") = R.ccl.scatter_from_worker0(
                gv1, num_workers=2, axis=1
            )
            gv3: R.Tensor((128, 128), dtype="float32") = weight_packed[1]
            gv4: R.Tensor((64, 128), dtype="float32") = R.ccl.scatter_from_worker0(
                gv3, num_workers=2, axis=0
            )
            lv0 = R.call_tir(
                MLPWithTuple.get_global_var("matmul2"),
                (gv, gv2),
                out_sinfo=R.Tensor((128, 64), dtype="float32"),
            )
            lv1 = R.call_tir(
                MLPWithTuple.get_global_var("gelu1"),
                (lv0,),
                out_sinfo=R.Tensor((128, 64), dtype="float32"),
            )
            gv_1 = R.call_tir(
                MLPWithTuple.get_global_var("split11"),
                (lv1,),
                out_sinfo=[
                    R.Tensor((64, 64), dtype="float32"),
                    R.Tensor((64, 64), dtype="float32"),
                ],
            )
            lv2: R.Tensor((64, 64), dtype="float32") = gv_1[0]
            lv3: R.Tensor((64, 64), dtype="float32") = lv2
            gv_1_1 = R.call_tir(
                MLPWithTuple.get_global_var("matmul11"),
                (lv3, gv4),
                out_sinfo=R.Tensor((64, 128), dtype="float32"),
            )
            lv4: R.Tensor((64, 128), dtype="float32") = R.ccl.allreduce(gv_1_1, op_type="sum")
            return lv4
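
    # As above, copy the unchanged TIR functions into the expected module.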
    for gv, func in MLPWithTuple.functions_items():
        if gv.name_hint != "foo":
            LoweredMLPWithTuple[gv] = func

    mod = MLPWithTuple
    mod = relax.distributed.transform.LowerDistIR()(mod)
    tvm.ir.assert_structural_equal(mod, LoweredMLPWithTuple)


if __name__ == "__main__":
    tvm.testing.main()