blob: a1079288af2948a1ab3155dfc41f71d2575dced9 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional, Union
import pytest
import tvm
import tvm.script
import tvm.testing
from tvm import IRModule, relax, tir, topi
from tvm.ir import Range
from tvm.relax import SeqExpr, VarBinding, Call
from tvm.relax.distributed import DeviceMesh
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T
def _check(
    parsed: Union[relax.Function, IRModule],
    expect: Optional[Union[relax.Function, IRModule]] = None,
):
    """Assert that ``parsed`` survives a TVMScript print/parse roundtrip.

    Parameters
    ----------
    parsed : Union[relax.Function, IRModule]
        The object produced by the TVMScript parser.
    expect : Optional[Union[relax.Function, IRModule]]
        When given, ``parsed`` must additionally be structurally equal to it.

    Raises
    ------
    Whatever ``tvm.ir.assert_structural_equal`` raises on a mismatch.
    """
    # Print including the metadata section (show_meta=True), presumably so
    # anything stored there can be rebuilt by from_source.
    script_text = parsed.script(show_meta=True)
    reparsed = tvm.script.from_source(script_text)
    tvm.ir.assert_structural_equal(parsed, reparsed)
    # Explicit None check instead of truthiness: `expect` is Optional and may
    # be a TVM object, so only the "not supplied" case should skip this.
    if expect is not None:
        tvm.ir.assert_structural_equal(parsed, expect)
def test_call_tir_dtensor():
    """Parse a module that uses DTensor struct info with ``R.dist.call_tir``.

    Checks that the parsed ``foo`` carries the DTensor annotation (resolved
    against the first declared device mesh) on its parameter, on its return
    type, and in the struct-info argument of its single call binding, then
    verifies a TVMScript print/parse roundtrip via ``_check``.
    """

    @I.ir_module
    class TestModule:
        I.module_attrs({"device_num": 10})
        # Device meshes registered under the "mesh" global info; the relax
        # function below refers to them by string, e.g. device_mesh="mesh[0]".
        I.module_global_infos(
            {
                "mesh": [
                    R.device_mesh((2, 2), I.Range(0, 4)),  # mesh[0]
                    R.device_mesh((1,), I.Range(4, 5)),  # mesh[1]
                ]
            }
        )

        @T.prim_func
        def tir_func(
            x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
            y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # Elementwise kernel: y[i, j] = x[i, j] + 1 over the 128x128 grid.
            for i, j in T.grid(T.int64(128), T.int64(128)):
                with T.block():
                    vi, vj = T.axis.remap("SS", [i, j])
                    y[vi, vj] = x[vi, vj] + 1.0

        @R.function
        def foo(
            x: R.DTensor((128, 128), "float32", device_mesh="mesh[0]", placement="S[0], R"),
        ) -> R.DTensor((128, 128), "float32", device_mesh="mesh[0]", placement="S[0], R"):
            # The trailing R.DTensor argument gives the result's struct info;
            # the assertions below expect it as the call's sinfo_args[0].
            gv0 = R.dist.call_tir(
                TestModule.tir_func,
                x,
                R.DTensor(
                    shape=(128, 128), dtype="float32", device_mesh="mesh[0]", placement="S[0], R"
                ),
            )
            return gv0

    # Python-side meshes mirroring the two declared above; only
    # device_mesh_list[0] is used to build the expected struct info.
    device_mesh_list = [DeviceMesh((2, 2), Range(0, 4)), DeviceMesh((1,), Range(4, 5))]
    foo_func = TestModule["foo"]

    # Signature: exactly one parameter, and both it and the return type carry
    # the DTensor annotation with "mesh[0]" resolved to device_mesh_list[0].
    params = foo_func.params
    assert len(params) == 1
    assert params[0].struct_info == R.DTensor(
        (128, 128), "float32", device_mesh_list[0], placement="S[0], R"
    )
    assert foo_func.ret_struct_info == R.DTensor(
        (128, 128), "float32", device_mesh_list[0], placement="S[0], R"
    )

    # Body: a SeqExpr whose first binding block holds a single VarBinding
    # whose value is a Call (the R.dist.call_tir above).
    assert isinstance(foo_func.body, SeqExpr)
    assert len(foo_func.body.blocks[0].bindings) == 1
    assert isinstance(foo_func.body.blocks[0].bindings[0], VarBinding)
    value = foo_func.body.blocks[0].bindings[0].value
    assert isinstance(value, Call)
    assert value.sinfo_args[0] == R.DTensor(
        (128, 128), "float32", device_mesh_list[0], placement="S[0], R"
    )

    # Finally, the module must survive a print/parse roundtrip.
    _check(TestModule)
def test_explicit_device_id():
    """Parse a module whose device meshes list explicit device ids.

    The meshes are declared with device-id lists ([0, 1, 2, 3] and [4])
    instead of an ``I.Range``; only the print/parse roundtrip is checked.
    """

    @I.ir_module
    class TestModule:
        I.module_attrs({"device_num": 10})
        # Device meshes given by explicit device-id lists rather than ranges.
        I.module_global_infos(
            {
                "mesh": [
                    R.device_mesh((2, 2), [0, 1, 2, 3]),  # mesh[0]
                    R.device_mesh(
                        (1,),
                        [
                            4,
                        ],
                    ),  # mesh[1]
                ]
            }
        )

        @T.prim_func
        def tir_func(
            x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
            y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # Elementwise kernel: y[i, j] = x[i, j] + 1 over the 128x128 grid.
            for i, j in T.grid(T.int64(128), T.int64(128)):
                with T.block():
                    vi, vj = T.axis.remap("SS", [i, j])
                    y[vi, vj] = x[vi, vj] + 1.0

        @R.function
        def foo(
            x: R.DTensor((128, 128), "float32", device_mesh="mesh[0]", placement="S[0], R"),
        ) -> R.DTensor((128, 128), "float32", device_mesh="mesh[0]", placement="S[0], R"):
            # The trailing R.DTensor argument gives the call's result struct info.
            gv0 = R.dist.call_tir(
                TestModule.tir_func,
                x,
                R.DTensor(
                    shape=(128, 128), dtype="float32", device_mesh="mesh[0]", placement="S[0], R"
                ),
            )
            return gv0

    # The module must survive a print/parse roundtrip.
    _check(TestModule)
def test_constant():
    """Parse a function that adds a DTensor constant to a call_tir result.

    The constant is built with ``R.dist.const`` and an explicit DTensor
    struct info; only the print/parse roundtrip is checked.
    """

    @I.ir_module
    class TestModule:
        I.module_attrs({"device_num": 10})
        # Device meshes registered under the "mesh" global info; the relax
        # function below refers to them by string, e.g. "mesh[0]".
        I.module_global_infos(
            {
                "mesh": [
                    R.device_mesh((2, 2), I.Range(0, 4)),  # mesh[0]
                    R.device_mesh((1,), I.Range(4, 5)),  # mesh[1]
                ]
            }
        )

        @T.prim_func
        def tir_func(
            x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
            y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # Elementwise kernel: y[i, j] = x[i, j] + 1 over the 128x128 grid.
            for i, j in T.grid(T.int64(128), T.int64(128)):
                with T.block():
                    vi, vj = T.axis.remap("SS", [i, j])
                    y[vi, vj] = x[vi, vj] + 1.0

        @R.function
        def foo(
            x: R.DTensor((128, 128), "float32", device_mesh="mesh[0]", placement="S[0], R"),
        ) -> R.DTensor((128, 128), "float32", device_mesh="mesh[0]", placement="S[0], R"):
            gv0 = R.dist.call_tir(
                TestModule.tir_func,
                x,
                R.DTensor(
                    shape=(128, 128), dtype="float32", device_mesh="mesh[0]", placement="S[0], R"
                ),
            )
            # Add a shape-() float32 DTensor constant on mesh[0]; here the mesh
            # and placement are passed positionally. NOTE(review): "R, R" looks
            # like one placement entry per axis of the 2-D mesh[0] — confirm.
            gv1 = R.add(
                gv0, R.dist.const(1.0, struct_info=R.DTensor((), "float32", "mesh[0]", "R, R"))
            )
            return gv1

    # The module (including the constant) must survive a print/parse roundtrip.
    _check(TestModule)
if __name__ == "__main__":
    # Running this file directly hands control to tvm.testing.main().
    tvm.testing.main()