| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import tvm.testing |
| from tvm.ir import Range |
| from tvm.script.parser import ir as I |
| from tvm.script.parser import relax as R |
| from tvm.script.parser import tir as T |
| from tvm.relax.distributed import DeviceMesh, DTensorStructInfo, Placement |
| from tvm.relax import TensorStructInfo |
| |
| |
def _assert_print(obj, expected):
    """Assert that ``obj`` prints as ``expected``, ignoring leading/trailing whitespace.

    ``obj`` may already be a string, in which case it is compared as-is;
    otherwise it is rendered with ``obj.script(verbose_expr=True)``. When the
    comparison fails, the (stripped) rendered text is used as the assertion
    message so the actual output is visible in the failure report.
    """
    rendered = obj if isinstance(obj, str) else obj.script(verbose_expr=True)
    actual = rendered.strip()
    assert actual == expected.strip(), "\n" + actual
| |
| |
def test_constant():
    # A scalar float32 constant annotated with a DTensor struct info on a 2x2
    # device mesh (devices 0..3) with placement "R, R". The printed form shows
    # the integer literal 1 rendered as 1.0 and the mesh written out in full.
    mesh = DeviceMesh((2, 2), Range(0, 4))
    dtensor_sinfo = R.DTensor((), "float32", device_mesh=mesh, placement="R, R")
    constant = R.dist.const(1, struct_info=dtensor_sinfo)

    expected = """R.dist.const(1.0, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))"""
    assert str(constant) == expected
| |
| |
def test_dtensor_struct_info():
    # Every case wraps a (32, 32) tensor struct info in a DTensorStructInfo on a
    # 2x2 mesh with placement "S[1], R". The cases differ in the wrapped dtype and
    # in how the mesh's device ids are given (a Range vs. an explicit list).
    placement_text = "S[1], R"
    typed_sinfo = TensorStructInfo((32, 32), "float32")
    # With a "void" dtype the printed form drops the dtype and switches the mesh
    # and placement to keyword arguments (compare the first two expected strings).
    untyped_sinfo = TensorStructInfo((32, 32), "void")

    # (wrapped tensor struct info, device ids for the 2x2 mesh, expected print)
    cases = [
        (
            typed_sinfo,
            Range(0, 4),
            """R.DTensor((32, 32), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[1], R")""",
        ),
        (
            untyped_sinfo,
            Range(0, 4),
            """R.DTensor((32, 32), device_mesh=R.device_mesh((2, 2), R.Range(0, 4)), placement="S[1], R")""",
        ),
        (
            untyped_sinfo,
            [0, 1, 2, 3],
            """R.DTensor((32, 32), device_mesh=R.device_mesh((2, 2), [0, 1, 2, 3]), placement="S[1], R")""",
        ),
    ]

    for tensor_sinfo, device_ids, expected in cases:
        dtensor_sinfo = DTensorStructInfo(
            tensor_sinfo,
            DeviceMesh((2, 2), device_ids),
            Placement.from_text(placement_text),
        )
        assert str(dtensor_sinfo) == expected
| |
| |
# IRModule fixture printed by test_func and test_module below. It declares two
# device meshes as module-level global infos, an elementwise TIR kernel, and a
# Relax function whose DTensor annotations refer to the first mesh by name.
# NOTE: only `#` comments belong here -- this class is parsed as TVMScript, so
# docstrings or other extra statements could change what gets parsed/printed.
@I.ir_module
class TestModule:
    I.module_attrs({"device_num": 10})
    # Meshes are addressed elsewhere by index into this list, e.g. "mesh[0]".
    I.module_global_infos(
        {
            "mesh": [
                R.device_mesh((2, 2), I.Range(0, 4)),  # mesh[0]
                R.device_mesh((1,), I.Range(4, 5)),  # mesh[1]
            ]
        }
    )

    # Elementwise kernel over 128x128 float32 buffers: y[i, j] = x[i, j] + 1.
    @T.prim_func
    def tir_func(
        x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        for i, j in T.grid(T.int64(128), T.int64(128)):
            # Unnamed block; printed back as T.block("") with inferred reads/writes.
            with T.block():
                vi, vj = T.axis.remap("SS", [i, j])
                y[vi, vj] = x[vi, vj] + 1.0

    # Relax entry point: invokes tir_func through R.dist.call_tir. The output
    # DTensor uses the same shape, dtype, mesh ("mesh[0]") and placement
    # ("S[0], R") as the input.
    @R.function
    def foo(
        x: R.DTensor((128, 128), "float32", device_mesh="mesh[0]", placement="S[0], R"),
    ) -> R.DTensor((128, 128), "float32", device_mesh="mesh[0]", placement="S[0], R"):
        # The module's own class name is used to reference the sibling
        # prim_func; the printer renders it as `cls.tir_func`.
        gv0 = R.dist.call_tir(
            TestModule.tir_func,
            x,
            R.DTensor(
                shape=(128, 128), dtype="float32", device_mesh="mesh[0]", placement="S[0], R"
            ),
        )
        return gv0
| |
| |
def test_func():
    # When a single function is printed on its own (outside its module), the
    # "mesh[0]" name cannot be referenced, so every DTensor annotation spells
    # the mesh out in full as R.device_mesh(...); the TIR callee is referenced
    # by its bare name.
    printed_func = TestModule["foo"]
    expected = """
# from tvm.script import relax as R

@R.function
def foo(x: R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R")) -> R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R"):
    gv0 = R.dist.call_tir(tir_func, (x,), out_sinfo=R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R"))
    return gv0
"""
    _assert_print(printed_func, expected)
| |
| |
def test_module():
    # Print the whole TestModule and compare against the exact expected text.
    # Compared with test_func: the global infos are re-emitted via
    # I.module_global_infos, DTensor annotations keep the short "mesh[0]" name,
    # and the Relax function references the TIR kernel through `cls = Module`.
    # The prim_func output also includes T.reads/T.writes and T.block("") even
    # though the fixture's source does not write them.
    # NOTE: the triple-quoted literal below is printer-exact (including its
    # internal indentation); only leading/trailing whitespace is ignored.
    _assert_print(
        TestModule,
        """
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"device_num": 10})
    I.module_global_infos({"mesh": [R.device_mesh((2, 2), I.Range(0, 4)), R.device_mesh((1,), I.Range(4, 5))]})
    @T.prim_func
    def tir_func(x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i, j in T.grid(T.int64(128), T.int64(128)):
            with T.block(""):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(x[vi, vj])
                T.writes(y[vi, vj])
                y[vi, vj] = x[vi, vj] + T.float32(1.0)

    @R.function
    def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R")) -> R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R"):
        cls = Module
        gv0 = R.dist.call_tir(cls.tir_func, (x,), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R"))
        return gv0
""",
    )
| |
| |
if __name__ == "__main__":
    # Allow running this file directly; hands off to tvm.testing.main(),
    # presumably a pytest-based runner for this file's tests.
    tvm.testing.main()