| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import ctypes |
| from typing import Tuple, Callable |
| |
| import numpy as np |
| import pytest |
| |
| import tvm |
| import tvm.script |
| import tvm.testing |
| from tvm import relax, rpc, te, tir, topi |
| from tvm.contrib import utils, cc, popen_pool |
| from tvm.relax.testing import nn |
| from tvm.script import relax as R, tir as T, ir as I |
| from tvm.relax.testing.vm import check_saved_func |
| from tvm.runtime import ShapeTuple |
| |
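# Every test below runs under both VM execution modes: "bytecode", where the
# VM interprets serialized instructions at runtime, and "compiled", where the
# control flow is compiled ahead of time.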
| EXEC_MODE = ["bytecode", "compiled"] |
| |
| |
| @pytest.fixture(params=EXEC_MODE) |
| def exec_mode(request): |
| return request.param |
| |
| |
| def test_vm_compile_simple(exec_mode): |
| @tvm.script.ir_module |
| class TestVMCompileStage0: |
| @R.function |
| def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): |
| z = R.call_pure_packed( |
| "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) |
| ) |
| return y |
| |
| mod = TestVMCompileStage0 |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| inp1 = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32)) |
| inp2 = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32)) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
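    # test.vm.identity copies its first argument into its second, so the call
    # below writes the contents of inp1 into inp2 in place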
| vm["foo"](inp1, inp2) |
| tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_vm_compile_without_target_arg(exec_mode): |
| """Like test_vm_compile_simple, but with a default target""" |
| |
| @tvm.script.ir_module |
| class mod: |
| @R.function |
| def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): |
| z = R.call_pure_packed( |
| "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) |
| ) |
| return y |
| |
| ex = relax.build(mod, exec_mode=exec_mode) |
| inp1 = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32)) |
| inp2 = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32)) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| vm["foo"](inp1, inp2) |
| tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_match_check(exec_mode): |
| @tvm.script.ir_module |
| class TestMatchCheck: |
| @R.function |
| def foo(x: R.Tensor(["n", "m"], "int32"), y: R.Object) -> R.Tensor(["m", "n"], dtype=None): |
| return y |
| |
| mod = TestMatchCheck |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| x0 = tvm.runtime.tensor(np.zeros((1, 2)).astype("int32")) |
| y0 = tvm.runtime.tensor(np.zeros((2, 1)).astype("float32")) |
| y1 = tvm.runtime.tensor(np.zeros((1, 2)).astype("float32")) |
| y2 = tvm.runtime.tensor(np.zeros((2, 1, 1)).astype("float32")) |
| |
| vm["foo"](x0, y0) |
| |
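    # y1's shape (1, 2) contradicts the declared return shape (m, n) = (2, 1)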
| with pytest.raises(RuntimeError, match=".*return.*"): |
| vm["foo"](x0, y1) |
| |
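    # y2 has ndim 3, contradicting the declared return ndim of 2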
| with pytest.raises(ValueError, match=".*return.*"): |
| vm["foo"](x0, y2) |
| |
| |
| def test_vm_compile_stage2(exec_mode): |
| @tvm.script.ir_module |
| class TestVMCompileStage2: |
| @R.function |
| def foo(x: R.Tensor(dtype="float32")) -> R.Shape: |
| n, m = T.int64(), T.int64() |
| _ = R.match_cast(x, R.Tensor((n, m), "float32")) |
| return R.shape([n * 2, m * 3]) |
| |
| mod = TestVMCompileStage2 |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| shape = (32, 16) |
| arr = tvm.runtime.tensor(np.random.rand(*shape).astype("float32")) |
| res = vm["foo"](arr) |
| assert res[0] == shape[0] * 2 |
| assert res[1] == shape[1] * 3 |
| |
| # dtype mismatch |
| with pytest.raises(ValueError, match=".*dtype.*"): |
| vm["foo"](tvm.runtime.tensor(np.zeros((1, 2)).astype("int32"))) |
| |
| # ndim mismatch |
| with pytest.raises(ValueError, match=".*match_cast.*ndim.*"): |
| vm["foo"](tvm.runtime.tensor(np.zeros((1,)).astype("float32"))) |
| |
    # type mismatch
| with pytest.raises(TypeError): |
| vm["foo"]([]) |
| |
| |
| def test_vm_compile_stage3(exec_mode): |
| @tvm.script.ir_module |
| class TestVMCompileStage3: |
| @R.function |
| def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor: |
| with R.dataflow(): |
                y = R.call_dps_packed("test.vm.identity", (x,), R.Tensor((32, 16), dtype="float32"))
| R.output(y) |
| return y |
| |
| mod = TestVMCompileStage3 |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| shape = (32, 16) |
| inp = tvm.runtime.tensor(np.random.rand(*shape).astype(np.float32)) |
| res = vm["foo"](inp) |
| tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_vm_compile_e2e(exec_mode): |
| @tvm.script.ir_module |
| class TestVMCompileE2E: |
| @R.function |
| def foo(x: R.Tensor(dtype="float32")) -> R.Tensor: |
| with R.dataflow(): |
| n, m = T.int64(), T.int64() |
| _ = R.match_cast(x, R.Tensor((n, m), "float32")) |
                y = R.call_dps_packed("test.vm.tile", (x,), R.Tensor((n, m * 2), dtype="float32"))
| R.output(y) |
| return y |
| |
| mod = TestVMCompileE2E |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| shape = (32, 16) |
| inp = tvm.runtime.tensor(np.random.rand(*shape).astype(np.float32)) |
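    # check_saved_func (from tvm.relax.testing.vm) calls the function both
    # directly and through a closure saved via save_function, checking that
    # the two results agree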
| res = check_saved_func(vm, "foo", inp) |
| tvm.testing.assert_allclose(res.numpy(), np.tile(inp.numpy(), (1, 2)), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_vm_compile_e2e_func_param_with_shape(exec_mode): |
| @tvm.script.ir_module |
| class TestVMCompileE2E2: |
| @T.prim_func |
| def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: |
| T.func_attr({"global_symbol": "tir_matmul"}) |
| m = T.int32() |
| n = T.int32() |
| k = T.int32() |
| A = T.match_buffer(x, (m, n)) |
| B = T.match_buffer(y, (n, k)) |
| C = T.match_buffer(z, (m, k)) |
| |
| for i, j, k in T.grid(m, k, n): |
| with T.block("matmul"): |
| vi, vj, vk = T.axis.remap("SSR", [i, j, k]) |
| with T.init(): |
| C[vi, vj] = T.float32(0) |
| C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] |
| |
| @R.function |
| def func( |
| x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") |
| ) -> R.Tensor: |
| m, k = T.int64(), T.int64() |
| cls = TestVMCompileE2E2 |
| gv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) |
| return gv0 |
| |
| mod = TestVMCompileE2E2 |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| data = tvm.runtime.tensor(np.random.rand(32, 16).astype(np.float32)) |
| weight = tvm.runtime.tensor(np.random.rand(16, 32).astype(np.float32)) |
| res = check_saved_func(vm, "func", data, weight) |
| expected = np.dot(data.numpy(), weight.numpy()) |
| tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) |
| |
| |
| def test_call_tir_inplace_e2e_simple(exec_mode): |
| @tvm.script.ir_module |
| class TestCallTIRInplaceE2ESimple: |
| @T.prim_func |
| def copy( |
| A: T.Buffer((2, 3), "int32"), |
| B: T.Buffer((2, 3), "int32"), |
| C: T.Buffer((2, 3), "int32"), |
| out1: T.Buffer((2, 3), "int32"), |
| ): |
| # copies the contents of C into A, B, and out1 |
| T.func_attr({"tir.noalias": True}) |
| for i0, i1 in T.grid(T.int64(2), T.int64(3)): |
| with T.block("T_zeros"): |
| ax0, ax1 = T.axis.remap("SS", [i0, i1]) |
| T.reads(C[ax0, ax1]) |
| T.writes(A[ax0, ax1], B[ax0, ax1], out1[ax0, ax1]) |
| A[ax0, ax1] = C[ax0, ax1] |
| B[ax0, ax1] = C[ax0, ax1] |
| out1[ax0, ax1] = C[ax0, ax1] |
| |
| @R.function |
| def main( |
| x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32"), z: R.Tensor((2, 3), "int32") |
| ) -> R.Tuple( |
| R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32") |
| ): |
| res = R.call_tir_inplace( |
| TestCallTIRInplaceE2ESimple.copy, |
| (x, y, z), |
| [0, 1, -1], |
| [R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], |
| ) |
| return res |
| |
| mod = TestCallTIRInplaceE2ESimple |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| x = tvm.runtime.tensor(np.zeros((2, 3)).astype(np.int32)) |
| y = tvm.runtime.tensor(np.zeros((2, 3)).astype(np.int32)) |
| z = tvm.runtime.tensor(np.ones((2, 3)).astype(np.int32)) |
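    # stateful API: set_input binds the arguments, invoke_stateful executes,
    # and get_outputs retrieves the results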
| vm.set_input("main", x, y, z) |
| vm.invoke_stateful("main") |
| outs = vm.get_outputs("main") |
| # check the expected aliasing (the last result is newly allocated) |
| assert x == outs[0] |
| assert y == outs[1] |
| assert x != y |
| assert x != outs[2] |
| assert y != outs[2] |
| tvm.testing.assert_allclose(x.numpy(), z.numpy(), rtol=1e-7, atol=1e-7) |
| tvm.testing.assert_allclose(y.numpy(), z.numpy(), rtol=1e-7, atol=1e-7) |
| tvm.testing.assert_allclose(outs[2].numpy(), z.numpy(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_call_tir_inplace_e2e_rw(exec_mode): |
    # reads from and writes to the same tensor
| @tvm.script.ir_module |
| class TestCallTIRInplaceE2ERW: |
| @T.prim_func |
| def inplace_add(A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32")): |
| # sums A and B, storing the result in A |
| T.func_attr({"tir.noalias": True}) |
| for i0, i1 in T.grid(T.int64(2), T.int64(3)): |
| with T.block("T_add"): |
| ax0, ax1 = T.axis.remap("SS", [i0, i1]) |
| T.reads(A[ax0, ax1], B[ax0, ax1]) |
| T.writes(A[ax0, ax1]) |
| A[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] |
| |
| @R.function |
| def main( |
| x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") |
| ) -> R.Tensor((2, 3), "int32"): |
| res = R.call_tir_inplace( |
| TestCallTIRInplaceE2ERW.inplace_add, (x, y), [0], R.Tensor((2, 3), "int32") |
| ) |
| return res |
| |
| mod = TestCallTIRInplaceE2ERW |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| x = tvm.runtime.tensor(np.ones((2, 3)).astype(np.int32)) |
| y = tvm.runtime.tensor(np.ones((2, 3)).astype(np.int32)) |
| vm.set_input("main", x, y) |
| vm.invoke_stateful("main") |
| out = vm.get_outputs("main") |
| expected = tvm.runtime.tensor(np.full((2, 3), 2).astype(np.int32)) |
| |
| assert x == out |
| tvm.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_vm_emit_te_extern(exec_mode): |
    if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
        pytest.skip("cblas extern function is not available")
| bb = relax.BlockBuilder() |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x = relax.Var("x", R.Tensor([n, m], "float32")) |
| y = relax.Var("y", R.Tensor([m, n], "float32")) |
| |
| with bb.function("rx_cblas_matmul", [x, y]): |
| out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) |
| bb.emit_func_output(out) |
| |
| mod = bb.get() |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| data = tvm.runtime.tensor(np.random.rand(16, 32).astype(np.float32)) |
| weight = tvm.runtime.tensor(np.random.rand(32, 16).astype(np.float32)) |
| res = check_saved_func(vm, "rx_cblas_matmul", data, weight) |
| expected = np.dot(data.numpy(), weight.numpy()) |
| tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) |
| |
| |
| def test_vm_emit_te_concat(exec_mode): |
    # concatenation of two vectors of sizes (n,) and (m,)
| bb = relax.BlockBuilder() |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x = relax.Var("x", R.Tensor([n], "float32")) |
| y = relax.Var("y", R.Tensor([m], "float32")) |
| |
| def te_func(A, B): |
        C = te.compute((n + m,), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i - n]))
| return C |
| |
| with bb.function("rx_func", [x, y]): |
| x1 = bb.emit_te(te_func, x, y) |
| bb.emit_func_output(x1) |
| |
| mod = bb.get() |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
    inp = tvm.runtime.tensor(np.random.rand(1).astype(np.float32))
    inp2 = tvm.runtime.tensor(np.random.rand(2).astype(np.float32))
| res = check_saved_func(vm, "rx_func", inp, inp2) |
| tvm.testing.assert_allclose( |
| res.numpy(), np.append(inp.numpy(), inp2.numpy()), rtol=1e-7, atol=1e-7 |
| ) |
| |
| |
| def test_vm_emit_te_dtype_change(exec_mode): |
| bb = relax.BlockBuilder() |
| n = tir.Var("n", "int64") |
| x = relax.Var("x", R.Tensor([n], "float32")) |
| |
    # convert a tensor from float32 to int16
| def te_func(A): |
| B = te.compute((n,), lambda i: A[i].astype("int16")) |
| return B |
| |
| with bb.function("rx_func", [x]): |
| y = bb.emit_te(te_func, x) |
| bb.emit_func_output(y) |
| |
| mod = bb.get() |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
    inp = tvm.runtime.tensor(np.random.rand(1).astype(np.float32))
| res = check_saved_func(vm, "rx_func", inp) |
| np.testing.assert_allclose(res.numpy(), inp.numpy().astype("int16")) |
| |
| |
| def test_vm_emit_te_floor_symbolic_shape(exec_mode): |
| bb = relax.BlockBuilder() |
| n = tir.Var("n", "int64") |
| x = relax.Var("x", R.Tensor([n], "float32")) |
| |
| def te_func(A): |
| C = te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1) |
| return C |
| |
| with bb.function("rx_func", [x]): |
| x1 = bb.emit_te(te_func, x) |
| bb.emit_func_output(x1) |
| |
| mod = bb.get() |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| shape = (9,) |
| inp = tvm.runtime.tensor(np.random.rand(*shape).astype(np.float32)) |
| res = check_saved_func(vm, "rx_func", inp) |
| |
| def expected_output(): |
| output_shape = (shape[0] // 2,) |
| return inp.numpy()[: output_shape[0]] + 1 |
| |
| tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_vm_emit_te_constant_param_cpu(exec_mode): |
| x_np = np.random.rand(2, 2).astype("float32") |
| c_np = np.random.rand(2, 2).astype("float32") |
| |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", R.Tensor((2, 2), "float32")) |
| c = relax.const(c_np, "float32") |
| with bb.function("main", [x]): |
| with bb.dataflow(): |
| lv0 = bb.emit_te(topi.add, x, c) |
| gv = bb.emit_output(lv0) |
| bb.emit_func_output(gv) |
| |
| mod = bb.get() |
    ex = relax.build(mod, "llvm", exec_mode=exec_mode)
    dev = tvm.cpu()
    vm = relax.VirtualMachine(ex, dev)
| |
| add_res = check_saved_func(vm, "main", tvm.runtime.tensor(x_np, dev)) |
| tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) |
| |
| |
| @tvm.testing.requires_gpu |
| def test_vm_emit_te_constant_param_gpu(exec_mode): |
| x_np = np.random.rand(2, 2).astype("float32") |
| c_np = np.random.rand(2, 2).astype("float32") |
| |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", R.Tensor((2, 2), "float32")) |
| c = relax.const(c_np, "float32") |
| with bb.function("main", [x]): |
| with bb.dataflow(): |
| lv0 = bb.emit_te(topi.add, x, c) |
| gv = bb.emit_output(lv0) |
| bb.emit_func_output(gv) |
| |
| mod = bb.get() |
| sch = tvm.tir.Schedule(mod, debug_mask="all") |
| loops = sch.get_loops(sch.get_block(name="T_add", func_name="add")) |
| sch.bind(loops[0], "threadIdx.x") |
| |
    ex = relax.build(sch.mod, "cuda", exec_mode=exec_mode)
    dev = tvm.cuda()
    vm = relax.VirtualMachine(ex, dev)
| |
| add_res = check_saved_func(vm, "main", tvm.runtime.tensor(x_np, dev)) |
| tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) |
| |
| |
| def test_vm_relax_symbolic_shape(exec_mode): |
| bb = relax.BlockBuilder() |
| n = tir.Var("n", "int64") |
| x = relax.Var("x", R.Tensor([n], "float32")) |
| y = relax.Var("y", R.Tensor([(n // 2) + 1], "float32")) |
| |
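    # each output element adds A[i] to the element of B at index i // 2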
| def te_func(A, B): |
| C = te.compute((n,), lambda i: A[i] + B[i // 2]) |
| return C |
| |
| with bb.function("rx_func", [x, y]): |
| x1 = bb.emit_te(te_func, x, y) |
| bb.emit_func_output(x1) |
| |
| mod = bb.get() |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| shape1 = (5,) |
| shape2 = (3,) |
| inp = tvm.runtime.tensor(np.random.rand(*shape1).astype(np.float32)) |
| inp2 = tvm.runtime.tensor(np.random.rand(*shape2).astype(np.float32)) |
| res = check_saved_func(vm, "rx_func", inp, inp2) |
| |
| def expected_output(): |
| return inp.numpy() + np.repeat(inp2.numpy(), 2)[:5] |
| |
| tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_vm_relax_symbolic_shape_tuple(exec_mode): |
| @I.ir_module |
| class mod: |
| @R.function |
| def main(shape: R.Shape(["m", "n"])): |
| m = T.int64() |
| n = T.int64() |
| return R.shape([2 * m, 3 * n]) |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| func = vm["main"] |
| |
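    # with m = 2 and n = 3, the result is (2 * m, 3 * n) = (4, 9)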
| assert func(ShapeTuple([2, 3])) == (4, 9) |
| |
| with pytest.raises(ValueError): |
| func(ShapeTuple([2, 3, 4])) |
| |
| with pytest.raises(TypeError): |
| func(R.prim_value(2)) |
| |
| |
| def test_vm_relax_symbolic_prim_value(exec_mode): |
| @I.ir_module |
| class mod: |
| @R.function |
| def main(shape: R.Prim(value="n")): |
| n = T.int64() |
| return R.prim_value(n * n) |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| func = vm["main"] |
| |
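    # with n = 2, the result is n * n = 4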
| assert func(2) == 4 |
| |
| with pytest.raises(TypeError): |
| func(ShapeTuple([2])) |
| |
| |
| def test_vm_relax_multiple_symbolic_prim_value(exec_mode): |
| """Like test_vm_relax_symbolic_prim_value, but with multiple variables""" |
| |
| @I.ir_module |
| class mod: |
| @R.function |
| def main( |
| # Provides definition of "n" |
| _n: R.Prim(value="n"), |
| # Requires definitions of both "n" and "m", but cannot |
| # provide either. |
| _shape: R.Shape(["n*2", "m*2"]), |
| # Provides definition of "m" |
| _m: R.Prim(value="m"), |
| ): |
| n = T.int64() |
| m = T.int64() |
| return R.shape([n * n, m + 1]) |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| func = vm["main"] |
| |
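    # n = 2 and m = 6, so the shape argument must be (n*2, m*2) = (4, 12),
    # and the result is (n * n, m + 1) = (4, 7)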
| assert func(2, ShapeTuple([4, 12]), 6) == (4, 7) |
| |
| with pytest.raises(RuntimeError): |
| func(2, ShapeTuple([4, 12]), 1) |
| |
| with pytest.raises(tvm.TVMError): |
| func(ShapeTuple([2])) |
| |
| |
| @pytest.mark.xfail(reason="Current support for R.Prim with known value is primarily for int64") |
| @pytest.mark.parametrize("exec_mode", EXEC_MODE) |
| def test_vm_relax_prim_value_fp32(exec_mode): |
| """A PrimValue may be R.prim('float32') |
| |
| Unlike shape tuples, which must contain int64, a PrimValue may be |
| any type that can be represented as a single primitive value. |
| """ |
| |
| @I.ir_module |
| class mod: |
| @R.function |
| def main( |
| # First failure occurs during parsing. The syntactic |
| # sugar for symbolic variables assumes that all symbolic |
| # variables are int64, rather than using the type that is |
| # later declared. |
| _x: R.Prim(value="half_fill_value"), |
| ): |
| half_fill_value = T.float32() |
| # Second failure occurs when calling `relax.op.full`. The |
| # `fill_value` is expected to be a scalar constant |
| # (R.Tensor with 0-dim shape), not a primitive value, even |
| # though these are semantically the same. |
| return R.full(shape=[16, 16], fill_value=R.prim_value(2 * half_fill_value)) |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| # Third failure occurs here. The current codegen assumes that all |
| # symbolic variables are int64. |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| func = vm["main"] |
| |
| res = func(16.0).numpy() |
| assert np.all(res == 32.0) |
| |
| |
| def test_vm_relax_dyn_tir_shape(exec_mode): |
| # case where TIR variables are unbound in generated PrimFunc |
| bb = relax.BlockBuilder() |
| n = tir.Var("n", "int64") |
| |
| def te_func(A): |
        C = te.compute((n + 1,), lambda i: A[i])
| return C |
| |
| with bb.function("rx_func"): |
| x = nn.Placeholder((n,), dtype="float32", name="x") |
| y = nn.Placeholder((n + 1,), dtype="float32", name="y") |
| |
| x1 = bb.emit_te(te_func, y) |
| bb.emit_func_output(x1, params=[x, y]) |
| |
| mod = bb.get() |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| |
| with utils.tempdir() as temp: |
| ex.export_library(temp.relpath("exec.so")) |
| vm = relax.VirtualMachine(tvm.runtime.load_module(temp.relpath("exec.so")), tvm.cpu()) |
| |
| inp = tvm.runtime.tensor(np.random.rand(2).astype(np.float32)) |
| inp2 = tvm.runtime.tensor(np.random.rand(3).astype(np.float32)) |
| |
| res = check_saved_func(vm, "rx_func", inp, inp2) |
| |
| tvm.testing.assert_allclose(res.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_vm_tuple(exec_mode): |
| bb = relax.BlockBuilder() |
| n = tir.Var("n", "int64") |
| |
| with bb.function("rx_func"): |
| x = nn.Placeholder((n,), dtype="float32", name="x") |
| y = nn.Placeholder((n,), dtype="float32", name="y") |
| tup = relax.Tuple([x, y]) |
| item = tup[0] |
| bb.emit_func_output([tup, item], params=[x, y]) |
| |
| mod = bb.get() |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| shape = (5,) |
| inp = tvm.runtime.tensor(np.random.rand(*shape).astype(np.float32)) |
| inp2 = tvm.runtime.tensor(np.random.rand(*shape).astype(np.float32)) |
| (res1, res2), res3 = vm["rx_func"](inp, inp2) |
| |
| tvm.testing.assert_allclose(res1.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) |
| tvm.testing.assert_allclose(res2.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) |
| tvm.testing.assert_allclose(res3.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_vm_tuplegetitem(exec_mode): |
| @tvm.script.ir_module |
| class TestVMTupleGetItem: |
| @R.function |
| def tuple_get_item( |
| x: R.Tensor(ndim=2, dtype="float32"), |
| y: R.Tensor(ndim=2, dtype="float32"), |
| ): |
| t = (x, y) |
| a = t[0] |
| b = t[1] |
| c = R.call_pure_packed( |
| "test.vm.add", a, b, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) |
| ) |
| return c |
| |
| mod = TestVMTupleGetItem |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| x_inp = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) |
| y_inp = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) |
| res = check_saved_func(vm, "tuple_get_item", x_inp, y_inp) |
| tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_lower_memory_alloc_storage_tensor(exec_mode): |
| @tvm.script.ir_module |
| class TestMemoryAllocStorageTensor: |
| @R.function |
| def main(x: R.Tensor((2, 3), dtype="float32")): |
| R.func_attr({"relax.force_pure": True}) |
| cls = TestMemoryAllocStorageTensor |
| storage = R.memory.alloc_storage( |
| R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" |
| ) |
| y = R.memory.alloc_tensor(storage, 0, R.shape([2, 3]), dtype="float32") |
| # this is an impure operation, but the overall function is pure so we force purity |
| _ = cls.copy(x, y) |
| return y |
| |
| @T.prim_func |
| def copy(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): |
| for i0, i1 in T.grid(2, 3): |
| with T.block("block"): |
| vi0, vi1 = T.axis.remap("SS", [i0, i1]) |
| B[vi0, vi1] = A[vi0, vi1] |
| |
| mod = TestMemoryAllocStorageTensor |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| x = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) |
| y = vm["main"](x) |
| tvm.testing.assert_allclose(y.numpy(), x.numpy(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_sub_func_call(exec_mode): |
| @tvm.script.ir_module |
| class TestVMSubFunction: |
| @T.prim_func |
| def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: |
| T.func_attr({"global_symbol": "tir_matmul"}) |
| m = T.int32() |
| n = T.int32() |
| k = T.int32() |
| A = T.match_buffer(x, (m, n)) |
| B = T.match_buffer(y, (n, k)) |
| C = T.match_buffer(z, (m, k)) |
| |
| for i, j, k in T.grid(m, k, n): |
| with T.block("matmul"): |
| vi, vj, vk = T.axis.remap("SSR", [i, j, k]) |
| with T.init(): |
| C[vi, vj] = T.float32(0) |
| C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] |
| |
| @R.function |
| def relax_matmul_tir( |
| x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") |
| ) -> R.Tensor((32, 32), dtype="float32"): |
| cls = TestVMSubFunction |
| with R.dataflow(): |
| gv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) |
| R.output(gv0) |
| return gv0 |
| |
| @R.function |
| def relax_matmul_packed( |
| x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") |
| ) -> R.Object: |
| gv0 = R.call_pure_packed( |
| "test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) |
| ) |
| return gv0 |
| |
| @R.function |
| def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Object: |
| cls = TestVMSubFunction |
| gv0 = cls.relax_matmul_tir(x, w) |
| gv1 = cls.relax_matmul_packed(gv0, gv0) |
| return gv1 |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(TestVMSubFunction, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| x_inp = tvm.runtime.tensor(np.random.rand(32, 32).astype(np.float32)) |
| y_inp = tvm.runtime.tensor(np.random.rand(32, 32).astype(np.float32)) |
| res = check_saved_func(vm, "main", x_inp, y_inp) |
| product = np.dot(x_inp.numpy(), y_inp.numpy()) |
| expected = product * product |
| tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) |
| |
| |
| def test_recursion(exec_mode): |
| @tvm.script.ir_module |
| class TestVMRecursion: |
| @R.function |
| def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor: |
| cond = R.call_pure_packed( |
| "test.vm.equal_zero", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) |
| ) |
| if cond: |
| res = R.const(1.0) |
| else: |
| gv0 = R.call_pure_packed( |
| "test.vm.subtract_one", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) |
| ) |
| tmp = TestVMRecursion.recursion(gv0) |
| res = R.call_pure_packed( |
| "test.vm.add", tmp, tmp, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) |
| ) |
| return res |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(TestVMRecursion, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| inp = np.empty(1).astype("float32") |
| recursion_runs = np.random.randint(1, 10) |
| inp.fill(recursion_runs) |
| inp = tvm.runtime.tensor(inp) |
| res = check_saved_func(vm, "recursion", inp) |
| tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), rtol=1e-7, atol=1e-7) |
| |
| |
| @tvm.testing.requires_gpu |
| def test_vm_to_device(exec_mode): |
| @tvm.script.ir_module |
| class TestToVDevice: |
| @R.function |
| def foo1( |
| x: R.Tensor((2, 3), "float32"), |
| ) -> R.Tensor((2, 3), "float32"): |
| copied = R.to_vdevice(x, tvm.ir.VDevice("cuda", 0, "global")) |
| return copied |
| |
| @R.function |
| def foo2( |
| x: R.Tensor((2, 3), "float32"), |
| ) -> R.Tensor((2, 3), "float32"): |
| copied = R.to_vdevice(x, tvm.ir.VDevice("llvm", 0, "global")) |
| return copied |
| |
| mod = TestToVDevice |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| x_inp = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) |
| res_1 = check_saved_func(vm, "foo1", x_inp) |
| res_2 = check_saved_func(vm, "foo2", x_inp) |
| |
    # check the copied tensors' devices
| assert res_1.device == tvm.cuda(0) |
| assert res_2.device == tvm.cpu(0) |
| |
| tvm.testing.assert_allclose(res_1.numpy(), x_inp.numpy()) |
| tvm.testing.assert_allclose(res_2.numpy(), x_inp.numpy()) |
| |
| |
| def test_vm_closure(exec_mode): |
| @tvm.script.ir_module |
| class TestClosure: |
| @R.function |
| def lifted_func_1(x: R.Tensor((2, 3), "float32"), env: R.Tensor((2, 3), "float32")): |
| return R.call_pure_packed("test.vm.add", x, env, sinfo_args=(R.Tensor())) |
| |
| @R.function |
| def main( |
| x: R.Tensor((2, 3), "float32"), |
| y: R.Tensor((2, 3), "float32"), |
| ): |
| cls = TestClosure |
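            # make_closure captures x; invoking the closure with (y,) computes x + y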
| clo = R.make_closure(cls.lifted_func_1, (x,)) |
| res = R.invoke_pure_closure(clo, (y,), sinfo_args=(R.Tensor())) |
| return res |
| |
| mod = TestClosure |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(mod, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| x_inp = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) |
| y_inp = tvm.runtime.tensor(np.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]], dtype="float32")) |
| res = check_saved_func(vm, "main", x_inp, y_inp) |
| tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy()) |
| |
| |
| def test_time_evaluator(exec_mode): |
| @tvm.script.ir_module |
| class TestTimeEvaluator: |
| @R.function |
| def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")): |
| return R.call_pure_packed( |
| "test.vm.add", x, y, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) |
| ) |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| ex = relax.build(TestTimeEvaluator, target, exec_mode=exec_mode) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| x = tvm.runtime.tensor(np.random.rand(1).astype("float32")) |
| y = tvm.runtime.tensor(np.random.rand(1).astype("float32")) |
| |
| # ensure we can use time_evaluator with the stateful API |
| vm.set_input("main", x, y) |
| timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("main") |
| # just checking that it has some results at all |
| assert timing_res.results |
| |
| # ensure we can use it with a closure |
| vm.save_function("main", "saved_main", x, y) |
| timing_res = vm.time_evaluator("saved_main", tvm.cpu())() |
| assert timing_res.results |
| |
| |
| @tvm.script.ir_module |
| class TestVMSetInput: |
| @T.prim_func |
| def test_vm_mul(x: T.handle, y: T.handle, z: T.handle): |
| T.func_attr({"global_symbol": "test_vm_mul"}) |
| m = T.int32() |
| n = T.int32() |
| A = T.match_buffer(x, (m, n)) |
| B = T.match_buffer(y, (m, n)) |
| C = T.match_buffer(z, (m, n)) |
| |
| for i, j in T.grid(m, n): |
| with T.block("mul"): |
| vi = T.axis.spatial(m, i) |
| vj = T.axis.spatial(n, j) |
| with T.init(): |
| C[vi, vj] = T.float32(0) |
| C[vi, vj] = A[vi, vj] * B[vi, vj] |
| |
| # test returning a tuple |
| @R.function |
| def test_vm_tuple( |
| x: R.Tensor((), "int32"), |
| ) -> R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")): |
| return (x, x) |
| |
    # test returning a nested tuple as well
| @R.function |
| def test_vm_nested_tuple( |
| x: R.Tensor((), "int32") |
| ) -> R.Tuple( |
| R.Tuple( |
| R.Tensor((), "int32"), |
| R.Tuple( |
| R.Tensor((), "int32"), |
| ), |
| ), |
| R.Tensor((), "int32"), |
| ): |
| return ((x, (x,)), x) |
| |
| @R.function |
| def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| cls = TestVMSetInput |
| gv0 = R.call_tir(cls.test_vm_mul, (x, w), R.Tensor((32, 32), dtype="float32")) |
| return gv0 |
| |
| |
| def test_multi_systemlib(exec_mode): |
| @tvm.script.ir_module |
| class ModA: |
| I.module_attrs({"system_lib_prefix": "libA_"}) |
| |
| @T.prim_func |
| def tir_init(x_handle: T.handle): |
| N = T.int64() |
| x = T.match_buffer(x_handle, [N], "float32") |
| for i in range(N): |
| x[i] = T.float32(0) |
| |
| @R.function |
| def main(s: R.Shape(["m"])) -> R.Tensor: |
| m = T.int64() |
| gv0 = R.call_tir(ModA.tir_init, (), R.Tensor((m + 1,), dtype="float32")) |
| return gv0 |
| |
| @tvm.script.ir_module |
| class ModB: |
| I.module_attrs({"system_lib_prefix": "libB_"}) |
| |
| @T.prim_func |
| def tir_init(x_handle: T.handle): |
| N = T.int64() |
| x = T.match_buffer(x_handle, [N], "float32") |
| for i in range(N): |
| x[i] = T.float32(1) |
| |
| @R.function |
| def main(s: R.Shape(["m"])) -> R.Tensor: |
| m = T.int64() |
| gv0 = R.call_tir(ModB.tir_init, (), R.Tensor((m,), dtype="float32")) |
| return gv0 |
| |
| target = tvm.target.Target("llvm", host="llvm") |
| libA = relax.build(ModA, target, exec_mode=exec_mode) |
| libB = relax.build(ModB, target, exec_mode=exec_mode) |
| |
| temp = utils.tempdir() |
| pathA = temp.relpath("libA.a") |
| pathB = temp.relpath("libB.a") |
| path_dso = temp.relpath("mylibAll.so") |
| libA.export_library(pathA, fcompile=cc.create_staticlib) |
| libB.export_library(pathB, fcompile=cc.create_staticlib) |
| |
    # Package the two static libs together and check that they do not
    # interfere with each other, even though they intentionally define the
    # same global var names with different behaviors.
| cc.create_shared(path_dso, ["-Wl,--whole-archive", pathA, pathB, "-Wl,--no-whole-archive"]) |
| |
| def popen_check(): |
        # Loading the DLL triggers system library registration
        ctypes.CDLL(path_dso)
        # Load the system-wide libraries
        vmA = relax.VirtualMachine(tvm.runtime.system_lib("libA_"), tvm.cpu())
        vmB = relax.VirtualMachine(tvm.runtime.system_lib("libB_"), tvm.cpu())
| |
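        # libA_'s main fills a tensor of shape (m + 1,) with zeros, while
        # libB_'s main fills a tensor of shape (m,) with ones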
| retA = vmA["main"](tvm.runtime.ShapeTuple([1])) |
| retB = vmB["main"](tvm.runtime.ShapeTuple([2])) |
| np.testing.assert_equal(retA.numpy(), np.array([0, 0]).astype("float32")) |
| np.testing.assert_equal(retB.numpy(), np.array([1, 1]).astype("float32")) |
| |
    # the system lib should be loaded in a different process
| worker = popen_pool.PopenWorker() |
| worker.send(popen_check) |
| |
| |
| def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: |
| a = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) |
| b = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) |
| vm.set_input("main", a, b) |
| vm.invoke_stateful("main") |
| res0 = vm.get_outputs("main") |
| |
| data_dict = {"x": a, "w": b} |
| vm.set_input("main", **data_dict) |
| vm.invoke_stateful("main") |
| res1 = vm.get_outputs("main") |
| tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) |
| tvm.testing.assert_allclose(res0.numpy(), res1.numpy(), rtol=1e-7, atol=1e-7) |
| |
    # Known bug: if the Tensor is not bound to a variable, its memory can be
    # corrupted, possibly due to object lifetimes and other FFI issues.
| a = tvm.runtime.tensor(np.array(2).astype("int32"), device) |
| vm.set_input("test_vm_tuple", a) |
| vm.invoke_stateful("test_vm_tuple") |
| res2 = vm.get_outputs("test_vm_tuple") |
    # the results are scalar Tensors, so we extract the scalar values
| assert tuple(map(lambda a: int(a.numpy()), res2)) == (2, 2) |
| |
| b = tvm.runtime.tensor(np.array(1).astype("int32"), device) |
| vm.set_input("test_vm_nested_tuple", b) |
| vm.invoke_stateful("test_vm_nested_tuple") |
| res3 = vm.get_outputs("test_vm_nested_tuple") |
| assert len(res3) == 2 and len(res3[0]) == 2 and len(res3[0][1]) == 1 |
| result_cast = ((int(res3[0][0].numpy()), (int(res3[0][1][0].numpy()),)), int(res3[1].numpy())) |
| assert result_cast == ((1, (1,)), 1) |
| |
| |
| def set_input_attempt_stateless(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: |
| # this should fail: once you set inputs, you cannot run statelessly |
| a = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) |
| b = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) |
| vm.set_input("main", a, b) |
    # must use invoke_stateful instead!
| vm["main"]() |
| |
| |
| def set_input_attempt_invoke(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: |
| # this should fail: if the function needs inputs, you can't invoke directly |
| vm.invoke_stateful("main") |
| |
| |
| def set_input_attempt_get(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: |
| # this should fail: you can't get outputs without invoking the function first |
| a = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) |
| b = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) |
| vm.set_input("main", a, b) |
| _ = vm.get_outputs("main") |
| |
| |
| def make_vm(mod, exec_mode, temp) -> Tuple[relax.VirtualMachine, tvm.runtime.Device]: |
| """Returns a local VM for the given mod and the device""" |
| target = tvm.target.Target("llvm", host="llvm") |
    ex = relax.build(mod, target, exec_mode=exec_mode)
    libname = temp.relpath("exec.so")
    ex.export_library(libname)
    exec_loaded = tvm.runtime.load_module(libname)
| device = tvm.cpu() |
| return relax.VirtualMachine(exec_loaded, device), device |
| |
| |
| def run_on_rpc( |
| mod: tvm.IRModule, |
| trial_func: Callable[[relax.VirtualMachine, tvm.runtime.Device], None], |
| exec_mode: str, |
| ): |
| """ |
| Sets up a VM over localhost using the given mod and runs the given trial function. |
| The trial function should take a VM and a device |
| """ |
| target = tvm.target.Target("llvm", host="llvm") |
    ex = relax.build(mod, target, exec_mode=exec_mode)
    temp = utils.tempdir()
    path = temp.relpath("vm_library.so")
    ex.export_library(path)
| |
| # Use local rpc server for testing. |
| # Server must use popen so it doesn't inherit the current process state. It |
| # will crash otherwise. |
| def check_remote(server): |
| remote = rpc.connect(server.host, server.port, session_timeout=10) |
| |
| # Upload the serialized Executable. |
| remote.upload(path) |
| # Get a handle to remote Executable. |
| rexec = remote.load_module("vm_library.so") |
| |
| device = remote.cpu() |
| # Build a VM out of the executable and context. |
| vm = relax.VirtualMachine(rexec, device=device) |
| trial_func(vm, device) |
| |
| check_remote(rpc.Server("127.0.0.1")) |
| |
| |
| def test_set_input(exec_mode): |
| temp = utils.tempdir() |
| set_input_trial(*make_vm(TestVMSetInput, exec_mode, temp)) |
| |
| |
| def test_set_input_tuple(exec_mode): |
| @tvm.script.ir_module |
| class MyMod: |
| @R.function |
| def main(x: R.Tuple([R.Tensor((32,), "float32"), R.Tensor((32,), "float32")])) -> R.Tensor: |
| y = x[0] |
| return y |
| |
| temp = utils.tempdir() |
| vm, device = make_vm(MyMod, exec_mode, temp) |
| a = tvm.runtime.empty((32,), "float32", device=device) |
| b = tvm.runtime.empty((32,), "float32", device=device) |
| vm.set_input("main", (a, b)) |
| vm.invoke_stateful("main") |
| |
| |
| def save_function_kwargs_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: |
| # just checking that we can use kwargs for the args when saving a function |
| a = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) |
| b = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) |
| vm.save_function("main", "saved_main", x=a, w=b) |
| res0 = vm["saved_main"]() |
| tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) |
| |
| |
| def test_save_function_kwargs(exec_mode): |
| temp = utils.tempdir() |
| save_function_kwargs_trial(*make_vm(TestVMSetInput, exec_mode, temp)) |
| |
| |
| def test_save_function_kwargs_rpc(exec_mode): |
| run_on_rpc(TestVMSetInput, save_function_kwargs_trial, exec_mode) |
| |
| |
| def save_function_time_evaluator_trial( |
| vm: relax.VirtualMachine, device: tvm.runtime.Device |
| ) -> None: |
| # just checking that the saved function can be called in the time evaluator |
| a = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) |
| b = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) |
| vm.save_function("main", "saved_main", a, b) |
| vm.time_evaluator("saved_main", device)() |
| |
| |
| def test_save_function_time_evaluator(exec_mode): |
| temp = utils.tempdir() |
| save_function_time_evaluator_trial(*make_vm(TestVMSetInput, exec_mode, temp)) |
| |
| |
| def test_save_function_time_evaluator_rpc(exec_mode): |
| run_on_rpc(TestVMSetInput, save_function_time_evaluator_trial, exec_mode) |
| |
| |
| # if you set an input, you should not be able to call statelessly |
| |
| |
| def test_set_input_stateless_failure(exec_mode): |
| temp = utils.tempdir() |
| args = make_vm(TestVMSetInput, exec_mode, temp) |
| with pytest.raises(RuntimeError): |
| set_input_attempt_stateless(*args) |
| |
| |
| def test_set_input_stateless_failure_rpc(exec_mode): |
| with pytest.raises(RuntimeError): |
| run_on_rpc(TestVMSetInput, set_input_attempt_stateless, exec_mode) |
| |
| |
| def test_set_input_invoke_failure(exec_mode): |
| temp = utils.tempdir() |
| args = make_vm(TestVMSetInput, exec_mode, temp) |
| with pytest.raises(ValueError): |
| set_input_attempt_invoke(*args) |
| |
| |
| def test_set_input_invoke_failure_rpc(exec_mode): |
| with pytest.raises(RuntimeError): |
| run_on_rpc(TestVMSetInput, set_input_attempt_invoke, exec_mode) |
| |
| |
| def test_set_input_get_failure(exec_mode): |
| temp = utils.tempdir() |
| args = make_vm(TestVMSetInput, exec_mode, temp) |
| with pytest.raises(ValueError): |
| set_input_attempt_get(*args) |
| |
| |
| def test_set_input_get_failure_rpc(exec_mode): |
| with pytest.raises(RuntimeError): |
| run_on_rpc(TestVMSetInput, set_input_attempt_get, exec_mode) |
| |
| |
| @tvm.testing.requires_gpu |
| def test_relax_module_with_multiple_targets(exec_mode): |
| """Relax functions may contain kernels for multiple targets |
| |
| In this example, the module contains one function to execute on |
| LLVM, and one function to execute on CUDA. |
| |
| """ |
| |
| @I.ir_module |
| class Module: |
| I.module_global_infos({"vdevice": [I.vdevice("llvm")]}) |
| |
| @R.function |
| def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")): |
| C = R.add(A, B) |
| return C |
| |
| @R.function |
| def func_llvm( |
| A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm") |
| ): |
| C = R.add(A, B) |
| return C |
| |
| seq = tvm.ir.transform.Sequential( |
| [ |
| tvm.relax.transform.LegalizeOps(), |
| tvm.dlight.ApplyDefaultSchedule(tvm.dlight.gpu.Fallback()), |
| ], |
| name="LegalizeAndSchedule", |
| ) |
| with tvm.target.Target("cuda"): |
| built = tvm.relax.build(seq(Module)) |
| |
| np_A = np.random.random([32, 32]).astype("float32") |
| np_B = np.random.random([32, 32]).astype("float32") |
| |
| dev_llvm = tvm.device("llvm") |
| vm_llvm = tvm.relax.VirtualMachine(built, device=dev_llvm) |
| llvm_output = vm_llvm["func_llvm"]( |
| tvm.runtime.tensor(np_A, dev_llvm), |
| tvm.runtime.tensor(np_B, dev_llvm), |
| ) |
| |
| dev_cuda = tvm.device("cuda") |
| vm_cuda = tvm.relax.VirtualMachine(built, device=dev_cuda) |
| |
| cuda_output = vm_cuda["func_cuda"]( |
| tvm.runtime.tensor(np_A, dev_cuda), |
| tvm.runtime.tensor(np_B, dev_cuda), |
| ) |
| |
| np_C = np_A + np_B |
| |
| tvm.testing.assert_allclose(llvm_output.numpy(), np_C) |
| tvm.testing.assert_allclose(cuda_output.numpy(), np_C) |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |