# blob: c60b307d68f727425e7caca9654eeabd3e5da913 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test the TIR codegen path of VM compiled mode.
Restrictions: all shapes are lowered and allocation is explicit.
"""
import tvm
import tvm.testing
from tvm import relax
from tvm.ir import assert_structural_equal
from tvm.script import relax as R
from tvm.script import tir as T
def get_tir_mod(mod):
    """Run the Relax VM code generator on ``mod`` in ``"compiled"`` exec mode.

    A fresh ``relax.ExecBuilder`` is created for every call and handed to
    ``relax.vm_build._vmcodegen`` together with ``mod``.

    Parameters
    ----------
    mod
        The module to lower. The tests in this file pass a TVMScript
        ``ir_module`` class.

    Returns
    -------
    Whatever ``_vmcodegen`` returns; the tests in this file compare it
    structurally against a hand-written module of ``__vmtir__*`` prim_funcs.
    """
    return relax.vm_build._vmcodegen(relax.ExecBuilder(), mod, exec_mode="compiled")
def test_add():
    """Check the TIR generated for a single packed call.

    ``foo`` passes its parameter twice to ``test.vm.add`` and returns the
    result. The expected ``__vmtir__foo`` writes the packed-call result to
    slot 2 of ``r`` and then copies slot 2 into slot 1 via ``vm.builtin.copy``.
    """

    @tvm.script.ir_module
    class Before:
        @R.function(pure=False)
        def foo(x: R.Tensor):
            R.func_attr({"global_symbol": "foo"})
            z = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
            return z

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
            T.func_attr({"global_symbol": "__vmtir__foo"})
            # r[2] <- test.vm.add(r[0], r[0])
            T.anylist_setitem_call_packed(
                r,
                T.int32(2),
                "test.vm.add",
                T.anylist_getitem(r, T.int32(0)),
                T.anylist_getitem(r, T.int32(0)),
            )
            # r[1] <- vm.builtin.copy(r[2])
            T.anylist_setitem_call_packed(
                r,
                T.int32(1),
                "vm.builtin.copy",
                T.anylist_getitem(r, T.int32(2)),
            )

    # Lower the Relax module and compare it with the hand-written TIR module.
    assert_structural_equal(Expected, get_tir_mod(Before))
def test_tir_call():
    """Check the TIR generated when a Relax function calls a prim_func.

    ``foo`` calls ``shape_func`` from the same module on its parameter and then
    returns the parameter. The expected module keeps ``shape_func`` as written
    and adds ``__vmtir__foo``, which invokes ``shape_func`` through
    ``T.call_cpacked`` on slot 0 of ``r`` and copies slot 0 into slot 1.
    """

    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def shape_func(H: T.Buffer(T.int64(4), "int64")):
            T.func_attr({"global_symbol": "shape_func"})
            # Increment the first element of H in place.
            H[T.int64(0)] = H[T.int64(0)] + T.int64(1)

        @R.function(pure=False)
        def foo(x: R.Tensor([4], "int64")):
            R.func_attr({"global_symbol": "foo"})
            _ = Before.shape_func(x)
            return x

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def shape_func(H: T.Buffer(T.int64(4), "int64")):
            T.func_attr({"global_symbol": "shape_func"})
            # Same body as Before.shape_func.
            H[T.int64(0)] = H[T.int64(0)] + T.int64(1)

        @T.prim_func
        def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
            T.func_attr({"global_symbol": "__vmtir__foo"})
            # Call the prim_func by its global symbol on r[0].
            T.call_cpacked("shape_func", T.anylist_getitem(r, T.int32(0)))
            # r[1] <- vm.builtin.copy(r[0])
            T.anylist_setitem_call_packed(
                r,
                T.int32(1),
                "vm.builtin.copy",
                T.anylist_getitem(r, T.int32(0)),
            )

    # Lower the Relax module and compare it with the hand-written TIR module.
    assert_structural_equal(Expected, get_tir_mod(Before))
def test_if_cond():
    """Check the TIR generated for a Relax ``if``/``else``.

    ``ife`` calls ``test.vm.add`` when ``cond`` is true and ``test.vm.mul``
    otherwise, both on ``(x, x)``. The expected ``__vmtir__ife`` branches on a
    ``tir.tvm_call_packed`` call to ``vm.builtin.read_if_cond`` over slot 0 of
    ``r``; each branch stores its packed-call result in its own slot (4 or 5)
    and copies it into slot 3, and after the branch slot 3 is copied into
    slot 2.
    """

    @tvm.script.ir_module
    class Before:
        @R.function(pure=False)
        def ife(cond: R.Tensor((), "bool"), x: R.Tensor) -> R.Tensor:
            R.func_attr({"global_symbol": "ife"})
            if cond:
                w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
            else:
                w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor))
            return w

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
            T.func_attr({"global_symbol": "__vmtir__ife"})
            # The condition is read from r[0] through a bool-typed packed call.
            if T.Call(
                "bool",
                tvm.ir.Op.get("tir.tvm_call_packed"),
                ["vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))],
            ):
                # True branch: r[4] <- test.vm.add(r[1], r[1]); r[3] <- copy(r[4])
                T.anylist_setitem_call_packed(
                    r,
                    T.int32(4),
                    "test.vm.add",
                    T.anylist_getitem(r, T.int32(1)),
                    T.anylist_getitem(r, T.int32(1)),
                )
                T.anylist_setitem_call_packed(
                    r,
                    T.int32(3),
                    "vm.builtin.copy",
                    T.anylist_getitem(r, T.int32(4)),
                )
            else:
                # False branch: r[5] <- test.vm.mul(r[1], r[1]); r[3] <- copy(r[5])
                T.anylist_setitem_call_packed(
                    r,
                    T.int32(5),
                    "test.vm.mul",
                    T.anylist_getitem(r, T.int32(1)),
                    T.anylist_getitem(r, T.int32(1)),
                )
                T.anylist_setitem_call_packed(
                    r,
                    T.int32(3),
                    "vm.builtin.copy",
                    T.anylist_getitem(r, T.int32(5)),
                )
            # Both branches leave their value in r[3]; r[2] <- copy(r[3]).
            T.anylist_setitem_call_packed(
                r,
                T.int32(2),
                "vm.builtin.copy",
                T.anylist_getitem(r, T.int32(3)),
            )

    # Lower the Relax module and compare it with the hand-written TIR module.
    assert_structural_equal(Expected, get_tir_mod(Before))
def test_const():
    """Check the TIR generated for a tuple built from constants and a parameter.

    ``main`` returns the tuple ``(R.const([1, 2]), R.const([3, 4]), x)``. The
    expected ``__vmtir__main`` calls ``vm.builtin.make_tuple`` on slots 0 and 1
    of ``c`` plus slot 0 of ``r``, stores the tuple in slot 2 of ``r``, and
    copies slot 2 into slot 1.
    """

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor):
            R.func_attr({"global_symbol": "main"})
            y = R.const([1, 2])
            z = (y, R.const([3, 4]), x)
            return z

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
            T.func_attr({"global_symbol": "__vmtir__main"})
            # r[2] <- vm.builtin.make_tuple(c[0], c[1], r[0])
            # NOTE(review): `c` presumably holds the two R.const values -- confirm.
            T.anylist_setitem_call_packed(
                r,
                T.int32(2),
                "vm.builtin.make_tuple",
                T.anylist_getitem(c, T.int32(0)),
                T.anylist_getitem(c, T.int32(1)),
                T.anylist_getitem(r, T.int32(0)),
            )
            # r[1] <- vm.builtin.copy(r[2])
            T.anylist_setitem_call_packed(
                r,
                T.int32(1),
                "vm.builtin.copy",
                T.anylist_getitem(r, T.int32(2)),
            )

    # Lower the Relax module and compare it with the hand-written TIR module.
    assert_structural_equal(Expected, get_tir_mod(Before))
def test_const_call():
    """Check the TIR generated when a constant is an argument of a packed call.

    ``main`` calls ``test.vm.add`` on its parameter and ``R.const([1, 2])``.
    The expected ``__vmtir__main`` passes slot 0 of ``r`` and slot 0 of ``c``
    to ``test.vm.add``, stores the result in slot 2 of ``r``, and copies
    slot 2 into slot 1.
    """

    @tvm.script.ir_module
    class Before:
        @R.function(pure=False)
        def main(x: R.Tensor):
            R.func_attr({"global_symbol": "main"})
            y = R.const([1, 2])
            z = R.call_packed("test.vm.add", x, y, sinfo_args=(R.Tensor))
            return z

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
            T.func_attr({"global_symbol": "__vmtir__main"})
            # NOTE(review): slot indices are bare int literals here, whereas the
            # other tests in this file spell them as T.int32(...); presumably the
            # parser yields the same constants -- confirm.
            # r[2] <- test.vm.add(r[0], c[0])
            T.anylist_setitem_call_packed(
                r,
                2,
                "test.vm.add",
                T.anylist_getitem(r, 0),
                T.anylist_getitem(c, 0),
            )
            # r[1] <- vm.builtin.copy(r[2])
            T.anylist_setitem_call_packed(
                r,
                1,
                "vm.builtin.copy",
                T.anylist_getitem(r, 2),
            )

    # Lower the Relax module and compare it with the hand-written TIR module.
    assert_structural_equal(Expected, get_tir_mod(Before))
# When this file is executed as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()