blob: d11774d73797d8379bd8a21244c7886effbc4bff [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm.testing
from tvm.ir import Range
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T
from tvm.relax.distributed import DeviceMesh, DTensorStructInfo, Placement
from tvm.relax import TensorStructInfo
def _assert_print(obj, expected):
if not isinstance(obj, str):
obj = obj.script(verbose_expr=True)
obj = obj.strip()
assert obj == expected.strip(), "\n" + obj
def test_constant():
constant = R.dist.const(
1,
struct_info=R.DTensor(
(), "float32", device_mesh=DeviceMesh((2, 2), Range(0, 4)), placement="R, R"
),
)
assert (
constant.__str__()
== """R.dist.const(1.0, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))"""
)
def test_dtensor_struct_info():
tensor_sinfo1 = TensorStructInfo((32, 32), "float32")
tensor_sinfo2 = TensorStructInfo((32, 32), "void")
obj0 = DTensorStructInfo(
tensor_sinfo1, DeviceMesh((2, 2), Range(0, 4)), Placement.from_text("S[1], R")
)
assert (
obj0.__str__()
== """R.DTensor((32, 32), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[1], R")"""
)
obj1 = DTensorStructInfo(
tensor_sinfo2, DeviceMesh((2, 2), Range(0, 4)), Placement.from_text("S[1], R")
)
assert (
obj1.__str__()
== """R.DTensor((32, 32), device_mesh=R.device_mesh((2, 2), R.Range(0, 4)), placement="S[1], R")"""
)
obj2 = DTensorStructInfo(
tensor_sinfo2, DeviceMesh((2, 2), [0, 1, 2, 3]), Placement.from_text("S[1], R")
)
assert (
obj2.__str__()
== """R.DTensor((32, 32), device_mesh=R.device_mesh((2, 2), [0, 1, 2, 3]), placement="S[1], R")"""
)
@I.ir_module
class TestModule:
I.module_attrs({"device_num": 10})
I.module_global_infos(
{
"mesh": [
R.device_mesh((2, 2), I.Range(0, 4)), # mesh[0]
R.device_mesh((1,), I.Range(4, 5)), # mesh[1]
]
}
)
@T.prim_func
def tir_func(
x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
):
T.func_attr({"tir.noalias": True})
for i, j in T.grid(T.int64(128), T.int64(128)):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
y[vi, vj] = x[vi, vj] + 1.0
@R.function
def foo(
x: R.DTensor((128, 128), "float32", device_mesh="mesh[0]", placement="S[0], R"),
) -> R.DTensor((128, 128), "float32", device_mesh="mesh[0]", placement="S[0], R"):
gv0 = R.dist.call_tir(
TestModule.tir_func,
x,
R.DTensor(
shape=(128, 128), dtype="float32", device_mesh="mesh[0]", placement="S[0], R"
),
)
return gv0
def test_func():
_assert_print(
TestModule["foo"],
"""
# from tvm.script import relax as R
@R.function
def foo(x: R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R")) -> R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R"):
gv0 = R.dist.call_tir(tir_func, (x,), out_sinfo=R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R"))
return gv0
""",
)
def test_module():
_assert_print(
TestModule,
"""
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
I.module_attrs({"device_num": 10})
I.module_global_infos({"mesh": [R.device_mesh((2, 2), I.Range(0, 4)), R.device_mesh((1,), I.Range(4, 5))]})
@T.prim_func
def tir_func(x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32")):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for i, j in T.grid(T.int64(128), T.int64(128)):
with T.block(""):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(x[vi, vj])
T.writes(y[vi, vj])
y[vi, vj] = x[vi, vj] + T.float32(1.0)
@R.function
def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R")) -> R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R"):
cls = Module
gv0 = R.dist.call_tir(cls.tir_func, (x,), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R"))
return gv0
""",
)
if __name__ == "__main__":
tvm.testing.main()