# blob: 9633244c67fb39fb055a0cdeabe4e073745f55fb [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test the last stage of VM codegen.
Restrictions: all shapes are lowered and allocations are explicit.
"""
import numpy as np
import pytest
import tvm
import tvm.testing
from tvm import relax
from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode
from tvm.relax.testing.vm import check_saved_func
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
# Values fed to ``pytest.mark.parametrize`` for the ``exec_mode`` argument of
# the tests below; each value is forwarded to ``codegen``.
EXEC_MODE = ["bytecode", "compiled"]
def codegen(mod, target, exec_mode="bytecode"):
    """Run only the VM codegen and link steps on ``mod``.

    Parameters
    ----------
    mod
        The module handed to ``relax.vm_build._vmcodegen``.
    target
        The target handed to ``relax.vm_build._vmlink``.
    exec_mode
        Forwarded to ``_vmcodegen``; the tests in this file use
        ``"bytecode"`` and ``"compiled"``.

    Returns
    -------
    The object produced by ``relax.vm_build._vmlink``; the tests pass it to
    ``relax.VirtualMachine``.
    """
    exec_builder = relax.ExecBuilder()
    # The codegen step fills the builder and returns the TIR module that the
    # link step consumes.
    lowered_tir = relax.vm_build._vmcodegen(exec_builder, mod, exec_mode=exec_mode)
    linked = relax.vm_build._vmlink(exec_builder, target, lowered_tir)
    return linked
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_vm_copy(exec_mode):
    """Calling ``vm.builtin.copy`` must return a tensor equal to the input."""

    @tvm.script.ir_module
    class TestVMMove:
        @R.function(pure=False)
        def foo(x: R.Tensor((3, 4), "float32")):
            R.func_attr({"global_symbol": "foo"})
            z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32")))
            return z

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(TestVMMove, llvm_target, exec_mode)

    data = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32))
    machine = relax.VirtualMachine(executable, tvm.cpu())

    copied = check_saved_func(machine, "foo", data)
    tvm.testing.assert_allclose(copied.numpy(), data.numpy(), rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_vm_to_device(exec_mode):
    """``vm.builtin.to_device`` must keep the values and land the result on cpu:0."""

    @tvm.script.ir_module
    class TestVMToDevice:
        @R.function(pure=False)
        def foo(x: R.Tensor((3, 4), "float32")):
            R.func_attr({"global_symbol": "foo"})
            # Copy x to the first cpu: device_type=1 and device_id=0.
            z = R.call_packed(
                "vm.builtin.to_device", x, 1, 0, sinfo_args=(R.Tensor((3, 4), dtype="float32"))
            )
            return z

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(TestVMToDevice, llvm_target, exec_mode)

    data = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32))
    machine = relax.VirtualMachine(executable, tvm.cpu())
    moved = check_saved_func(machine, "foo", data)

    # The values must be unchanged by the transfer.
    tvm.testing.assert_allclose(moved.numpy(), data.numpy(), rtol=1e-7, atol=1e-7)

    # The resulting tensor must be on cpu:0.
    result_device = moved.device
    assert result_device == tvm.cpu(0)
    assert result_device.dlpack_device_type() == 1
    assert result_device.index == 0
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_if_cond_const(exec_mode):
    """An ``if`` whose condition is a constant ``True`` must still return the input."""

    @tvm.script.ir_module
    class TestVMIfCondConst:
        @R.function
        def main(x: R.Tensor(ndim=2, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"):
            R.func_attr({"global_symbol": "main"})
            if relax.const(True, dtype="bool"):
                ret = x
            else:
                ret = x
            return ret

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(TestVMIfCondConst, llvm_target, exec_mode)
    machine = relax.VirtualMachine(executable, tvm.cpu())

    data = tvm.runtime.tensor(np.random.rand(3, 4))
    output = machine["main"](data)

    # Both branches bind ``ret`` to ``x``, so the output must equal the input.
    tvm.testing.assert_allclose(output.numpy(), data.numpy())
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_vm_exec_serialize_export_library(exec_mode):
    """An executable exported to a shared library and loaded back keeps its text form.

    The executable is built with the parametrized ``exec_mode``, exported to
    ``exec.so`` in a temporary directory, reloaded with
    ``tvm.runtime.load_module``, and the ``as_text`` output of both is compared.
    """

    @tvm.script.ir_module
    class TestVMMove:
        @R.function(pure=False)
        def foo(x: R.Tensor((3, 4), "float32")):
            R.func_attr({"global_symbol": "foo"})
            z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32")))
            return z

    mod = TestVMMove
    target = tvm.target.Target("llvm", host="llvm")
    # Forward exec_mode: without it both parametrizations would build the
    # default "bytecode" executable and the "compiled" path would go untested.
    ex = codegen(mod, target, exec_mode)

    from tvm.contrib import utils

    temp_dir = utils.tempdir()
    path_exec = temp_dir.relpath("exec.so")
    ex.export_library(path_exec)

    loaded_exec = tvm.runtime.load_module(path_exec)
    assert ex.as_text() == loaded_exec["as_text"]()
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_if_cond(exec_mode):
    """A runtime condition must select between the add and the mul branch.

    Truthy conditions (``1``, ``True``) are expected to yield ``x + x``;
    falsy conditions (``0``, ``False``) are expected to yield ``x * x``.
    """

    @tvm.script.ir_module
    class TestVMCompileIf:
        @R.function(pure=False)
        def ife(cond: R.Tensor((), "bool"), x: R.Tensor((3, 4), "float32")) -> R.Tensor:
            R.func_attr({"global_symbol": "ife"})
            if cond:
                w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
            else:
                w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor))
            return w

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(TestVMCompileIf, llvm_target, exec_mode)
    machine = relax.VirtualMachine(executable, tvm.cpu())
    data = tvm.runtime.tensor(np.random.rand(3, 4))

    # (condition value, expected result) pairs, checked in the listed order.
    cases = (
        (1, lambda: data.numpy() + data.numpy()),
        (True, lambda: data.numpy() + data.numpy()),
        (0, lambda: data.numpy() * data.numpy()),
        (False, lambda: data.numpy() * data.numpy()),
    )
    for flag, make_expected in cases:
        output = machine["ife"](tvm.runtime.tensor(flag), data)
        tvm.testing.assert_allclose(output.numpy(), make_expected(), rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_vm_return_const_tuple(exec_mode):
    """A returned tuple mixing constants and the input must come back element by element."""

    @tvm.script.ir_module
    class ReturnConstTuple:
        @R.function
        def main(x: R.Tensor(ndim=2, dtype="float32")):
            R.func_attr({"global_symbol": "main"})
            y = R.const([1, 2])
            z = (y, R.const([3, 4]), x)
            return z

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(ReturnConstTuple, llvm_target, exec_mode)
    machine = relax.VirtualMachine(executable, tvm.cpu())

    data = tvm.runtime.tensor(np.random.rand(2, 3))
    first, second, third = machine["main"](data)

    # Compare each tuple field against its expected value, in order.
    checks = (
        (first, lambda: np.array([1, 2])),
        (second, lambda: np.array([3, 4])),
        (third, lambda: data.numpy()),
    )
    for actual, make_expected in checks:
        tvm.testing.assert_allclose(actual.numpy(), make_expected())
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_vm_const_as_call_arg(exec_mode):
    """Constants passed directly as packed-call arguments must be usable by the callee.

    ``main`` adds the constants ``[1, 2]`` and ``[3, 4]`` and then adds the
    input, so the expected output is ``[4, 6] + x``.
    """

    @tvm.script.ir_module
    class TestVMConstAsCallArg:
        @R.function(pure=False)
        def main(x: R.Tensor(ndim=2, dtype="float32")):
            R.func_attr({"global_symbol": "main"})
            a = R.call_packed(
                "test.vm.add",
                relax.const([1, 2]),
                relax.const([3, 4]),
                sinfo_args=(R.Tensor(ndim=2, dtype="float32")),
            )
            b = R.call_packed(
                "test.vm.add",
                a,
                x,
                sinfo_args=(R.Tensor(ndim=2, dtype="float32")),
            )
            return b

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(TestVMConstAsCallArg, llvm_target, exec_mode)
    machine = relax.VirtualMachine(executable, tvm.cpu())

    data = tvm.runtime.tensor(np.random.rand(1, 2))
    output = machine["main"](data)

    expected = np.array([4, 6]) + data.numpy()
    tvm.testing.assert_allclose(output.numpy(), expected)
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_shape_check_builtin(exec_mode):
    """Exercise the shape builtins: heap allocation, tensor check, match and make shape.

    ``main`` stores the two dimensions of ``x`` in a shape heap and returns the
    shape ``(m, n, 2)``. Mismatched inputs are expected to raise: a non-tensor
    argument gives ``TypeError``; a wrong ndim or dtype gives ``ValueError``.
    """
    MS = MatchShapeCode
    MK = MakeShapeCode
    # slot assignment:
    # 0: n, 1: m
    sindex = {"n": 0, "m": 1}

    @tvm.script.ir_module
    class TestVMShapeCheck:
        @R.function(pure=False)
        def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3):
            R.func_attr({"global_symbol": "main"})
            n = T.int64()
            k = T.int64()
            shape_heap = R.call_builtin_with_ctx(
                "vm.builtin.alloc_shape_heap",
                [R.prim_value(3)],
                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
            )
            _ = R.call_packed(
                "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()]
            )
            _ = R.call_packed(
                "vm.builtin.match_shape",
                x,
                shape_heap,
                2,
                MS.STORE_TO_HEAP,
                sindex["n"],
                MS.STORE_TO_HEAP,
                sindex["m"],
                "",
                sinfo_args=[R.Tuple()],
            )
            # construct shape value for return
            s = R.call_packed(
                "vm.builtin.make_shape",
                shape_heap,
                3,
                MK.LOAD_SHAPE,
                sindex["m"],
                MK.LOAD_SHAPE,
                sindex["n"],
                MK.USE_IMM,
                2,
                sinfo_args=[R.Shape(ndim=3)],
            )
            return s

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(TestVMShapeCheck, llvm_target, exec_mode)
    machine = relax.VirtualMachine(executable, tvm.cpu())
    entry = machine["main"]

    # Valid input of shape (n=1, m=2): the result is (m, n, 2).
    good_input = tvm.runtime.tensor(np.zeros((1, 2)).astype("float32"))
    returned_shape = entry(good_input)
    assert returned_shape == tvm.runtime.container.ShapeTuple([2, 1, 2])

    # wrong input type
    with pytest.raises(TypeError):
        machine["main"]([])

    # wrong ndim
    with pytest.raises(ValueError, match=r".*ndim.*"):
        machine["main"](tvm.runtime.tensor(np.zeros(1).astype("float32")))

    # wrong dtype
    with pytest.raises(ValueError, match=r".*dtype.*"):
        machine["main"](tvm.runtime.tensor(np.zeros((1, 2)).astype("int32")))
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_prim_value(exec_mode):
    """A function returning ``R.prim_value(T.int64(1))`` must yield ``1``."""

    @tvm.script.ir_module
    class TestVMPrimValue:
        @R.function
        def main():
            R.func_attr({"global_symbol": "main"})
            ret = R.prim_value(T.int64(1))
            return ret

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(TestVMPrimValue, llvm_target, exec_mode)
    machine = relax.VirtualMachine(executable, tvm.cpu())

    returned = machine["main"]()
    assert returned == 1
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_string_imm(exec_mode):
    """A function returning ``R.str("hello")`` must yield ``"hello"``."""

    @tvm.script.ir_module
    class TestVMStringImm:
        @R.function
        def main():
            R.func_attr({"global_symbol": "main"})
            ret = R.str("hello")
            return ret

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(TestVMStringImm, llvm_target, exec_mode)
    machine = relax.VirtualMachine(executable, tvm.cpu())

    returned = machine["main"]()
    assert returned == "hello"
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_datatype_imm(exec_mode):
    """A function returning ``R.dtype("float32")`` must yield a value equal to ``"float32"``."""

    @tvm.script.ir_module
    class TestDataTypeImm:
        @R.function
        def main():
            R.func_attr({"global_symbol": "main"})
            ret = R.dtype("float32")
            return ret

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(TestDataTypeImm, llvm_target, exec_mode)
    machine = relax.VirtualMachine(executable, tvm.cpu())

    returned = machine["main"]()
    assert returned == "float32"
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_vm_builtin_reshape(exec_mode):
    """``vm.builtin.reshape`` from (3, 4) to (6, 2) must match ``numpy.reshape``."""

    @tvm.script.ir_module
    class TestVMBuiltinReshape:
        @R.function(pure=False)
        def main(x: R.Tensor((3, 4), "float32")):
            R.func_attr({"global_symbol": "main"})
            y = R.call_packed(
                "vm.builtin.reshape", x, R.shape([6, 2]), sinfo_args=R.Tensor((6, 2), "float32")
            )
            return y

    mod = TestVMBuiltinReshape
    target = tvm.target.Target("llvm", host="llvm")
    ex = codegen(mod, target, exec_mode)
    dev = tvm.cpu()
    vm = relax.VirtualMachine(ex, dev)

    input_np = np.random.rand(3, 4).astype("float32")
    # Named ``input_tensor`` rather than ``input`` so the Python builtin is not shadowed.
    input_tensor = tvm.runtime.tensor(input_np, dev)
    res = vm["main"](input_tensor)

    expected = input_np.reshape(6, 2)
    tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_vm_kill_object(exec_mode):
    """A program that interleaves allocations with ``R.vm.kill_object`` must still run.

    ``main`` fills two tensors via ``full`` (zeros) and one via ``full1`` (ones),
    passing tensors and storage to ``R.vm.kill_object`` along the way. It
    returns the tensor filled by ``full1``, so the result is expected to be all ones.
    """

    @I.ir_module
    class TestKillObject:
        @T.prim_func
        def full(T_full: T.Buffer((T.int64(4),), "float32")):
            T.func_attr({"global_symbol": "full", "tir.noalias": True})
            for ax0 in range(T.int64(4)):
                with T.block("T_full"):
                    v_ax0 = T.axis.spatial(T.int64(4), ax0)
                    T.reads()
                    T.writes(T_full[v_ax0])
                    T_full[v_ax0] = T.float32(0)

        @T.prim_func
        def full1(T_full: T.Buffer((T.int64(4),), "float32")):
            T.func_attr({"global_symbol": "full1", "tir.noalias": True})
            for ax0 in range(T.int64(4)):
                with T.block("T_full"):
                    v_ax0 = T.axis.spatial(T.int64(4), ax0)
                    T.reads()
                    T.writes(T_full[v_ax0])
                    T_full[v_ax0] = T.float32(1)

        # PrimFuncs called directly are treated as impure
        @R.function(pure=False)
        def main() -> R.Tensor((4,), dtype="float32"):
            R.func_attr({"global_symbol": "main"})
            cls = TestKillObject
            storage: R.Object = R.vm.alloc_storage(R.shape([16]), R.prim_value(0), R.dtype("uint8"))
            alloc: R.Tensor((4,), dtype="float32") = R.vm.alloc_tensor(
                storage, R.prim_value(0), R.shape([4]), R.dtype("float32")
            )
            _: R.Tuple = cls.full(alloc)
            __1: R.Tuple = R.vm.kill_object(alloc)
            x: R.Tensor((4,), dtype="float32") = alloc
            alloc1: R.Tensor((4,), dtype="float32") = R.vm.alloc_tensor(
                storage, R.prim_value(0), R.shape([4]), R.dtype("float32")
            )
            _1: R.Tuple = cls.full(alloc1)
            _1_1: R.Tuple = R.vm.kill_object(alloc1)
            y: R.Tensor((4,), dtype="float32") = alloc1
            storage_1: R.Object = R.vm.alloc_storage(
                R.shape([16]), R.prim_value(0), R.dtype("uint8")
            )
            alloc2: R.Tensor((4,), dtype="float32") = R.vm.alloc_tensor(
                storage_1, R.prim_value(0), R.shape([4]), R.dtype("float32")
            )
            _2: R.Tuple = cls.full1(alloc2)
            z: R.Tensor((4,), dtype="float32") = alloc2
            _2_1: R.Tuple = R.vm.kill_object(storage)
            return z

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(TestKillObject, llvm_target, exec_mode)
    device = tvm.cpu()
    machine = relax.VirtualMachine(executable, device)

    output = machine["main"]()

    # ``z`` was filled by ``full1``, which writes 1.0 into every element.
    all_ones = np.ones((4,), "float32")
    tvm.testing.assert_allclose(output.numpy(), all_ones)
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_preserve_trivial_bindings(exec_mode):
    """Aliases created by trivial bindings must survive ``kill_object`` on the original.

    ``main`` queries ``test.vm.check_if_defined`` for a storage, a tensor and an
    alias of each, both before and after the originals are passed to
    ``R.vm.kill_object``. The test expects the originals to report false
    afterwards while the aliases keep reporting true.
    """

    @I.ir_module
    class mod:
        @R.function(pure=False)
        def main():
            callback = R.ExternFunc("test.vm.check_if_defined")

            storage = R.vm.alloc_storage(R.shape([16]), R.prim_value(0), R.dtype("uint8"))
            alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([4]), R.dtype("float32"))
            storage_alias = storage
            alloc_alias = alloc

            storage_before = callback(storage)
            alloc_before = callback(alloc)
            storage_alias_before = callback(storage_alias)
            alloc_alias_before = callback(alloc_alias)

            _ = R.vm.kill_object(storage)
            _ = R.vm.kill_object(alloc)

            storage_after = callback(storage)
            alloc_after = callback(alloc)
            storage_alias_after = callback(storage_alias)
            alloc_alias_after = callback(alloc_alias)

            return (
                storage_before,
                alloc_before,
                storage_alias_before,
                alloc_alias_before,
                storage_after,
                alloc_after,
                storage_alias_after,
                alloc_alias_after,
            )

    llvm_target = tvm.target.Target("llvm", host="llvm")
    executable = codegen(mod, llvm_target, exec_mode)
    device = tvm.cpu()
    machine = relax.VirtualMachine(executable, device)
    raw_flags = machine["main"]()

    # The expectations are keyed by name purely so that a failure is easy to
    # read. Comparing two dictionaries is equivalent to asserting on each
    # element of the returned array, but pytest can then show a per-key diff.
    expected_results = {
        "storage_before": True,
        "alloc_before": True,
        "storage_alias_before": True,
        "alloc_alias_before": True,
        "storage_after": False,
        "alloc_after": False,
        "storage_alias_after": True,
        "alloc_alias_after": True,
    }

    # Pair each returned flag with its name, in the order ``main`` returns them.
    observed_results = {}
    for label, flag in zip(expected_results.keys(), raw_flags):
        observed_results[label] = bool(flag)

    assert observed_results == expected_results
# When executed as a script, hand control to TVM's test entry point.
if __name__ == "__main__":
    tvm.testing.main()