| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import sys |
| from typing import Optional, Union |
| |
| import pytest |
| import tvm |
| import tvm.script |
| import tvm.testing |
| from tvm import IRModule, relax, tir, topi |
| from tvm.ir import VDevice, DummyGlobalInfo |
| from tvm.script.parser import ir as I |
| from tvm.script.parser import relax as R |
| from tvm.script.parser import tir as T |
| |
| |
| def _check( |
| parsed: Union[relax.Function, IRModule], |
| expect: Optional[Union[relax.Function, IRModule]] = None, |
| ): |
    script = parsed.script(show_meta=True)
    roundtrip_mod = tvm.script.from_source(script)
| tvm.ir.assert_structural_equal(parsed, roundtrip_mod) |
| if isinstance(parsed, IRModule) and isinstance(roundtrip_mod, IRModule): |
| assert relax.analysis.well_formed(parsed) |
| assert relax.analysis.well_formed(roundtrip_mod) |
| if expect: |
| tvm.ir.assert_structural_equal(parsed, expect) |
| |
| |
| def test_simple_func(): |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): |
| R.func_attr({"Primitive": True}) |
| gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| gv1 = R.call_dps_packed("extern_dps_func", gv0, R.Tensor((128, 128), dtype="float32")) |
| return gv1 |
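
    # Build the same function with the BlockBuilder API and compare it
    # against the parsed function.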
| |
| x = relax.Var("x", R.Tensor((128, 128), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,), attrs={"Primitive": True}): |
| y = bb.emit(relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32"))) |
| out = bb.emit( |
| relax.call_dps_packed("extern_dps_func", y, R.Tensor((128, 128), dtype="float32")) |
| ) |
| bb.emit_func_output(out) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_error_report(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): |
| # error: a = b = c is not allowed. |
| gv0 = gv1 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| return gv0 |
| |
| |
| def test_mismatch_cast_dims_and_ndim(): |
| with pytest.raises(Exception): |
| |
| @R.function |
| def f( |
| x: R.Tensor((2, 3), "float32", ndim=3), |
| ): # error: ndim and the shape dims are mismatch |
| return x |
| |
| |
| def test_unexpected_num_kw_args(): |
| with pytest.raises(Exception): |
| |
| @R.function |
| def f(x: R.Tensor(dtype="float32", ndim=1, foo=2)): # error: unexpected kw args foo |
| return x |
| |
| |
| def test_unexpected_ndim(): |
| with pytest.raises(Exception): |
| |
| @R.function |
        # error: ndim is expected to be a non-negative int, or -1 for unknown
| def f(x: R.Tensor(dtype="float32", ndim=-2)): |
| return x |
| |
| |
| def test_unexpected_ndim_type(): |
| with pytest.raises(Exception): |
| |
| @R.function |
| def f(x: R.Tensor(dtype="float32", ndim="1")): # error: dim is expected to be int |
| return x |
| |
| |
| def test_unexpected_tir_cast_args(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def f(x: R.Tensor(("m",), "float32")): |
| m = T.int64() |
| # tir.cast expects 2 arguments, but got 3 |
| return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32")) |
| |
| |
| def test_unexpected_tir_args(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @tvm.script.ir_module |
| class TestWellCallTIR: |
| @T.prim_func |
| def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: |
                T.func_attr({"global_symbol": "tir_addone"})
| for i, j in T.grid(16, 16): |
| with T.block("tir_addone"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] + T.int32(1) |
| |
| @R.function |
| def foo(x: R.Tensor(("m", "m"), "float32")): |
| m = T.int64() |
| # tir.max expects 2 arguments, but got 1 |
| gv = R.call_tir(tir_addone, (x,), R.Tensor((T.max(16),), dtype="float32")) |
| return gv |
| |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def f(x: R.Tensor(("m", "n"), "float32")): |
| m = T.int64() |
            # error: call_tir expects a TIR PrimFunc
| return relax.call_tir("extern_func", (x,), R.Tensor((T.max(m),), dtype="float32")) |
| |
| |
| def test_func_type_annotation_fail(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def f(x, y): # error: the parameter type annotation is missing |
| z = R.add(x, y) |
| y = z |
| return y |
| |
| |
| def test_if_mismatch_var_fail(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): |
| if cond: |
| w = R.add(x, x) |
| y = R.multiply(w, w) |
| else: |
| w = R.multiply(x, x) |
                z = R.add(w, w)  # error: the binding var is expected to be `y`
| return z |
| |
| |
| def test_unassigned_call_fail(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def f(x: R.Tensor): |
| R.add(x, x) |
| return x |
| |
| |
| def test_incorrect_tensor_shape(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def f(x: R.Tensor([16])): |
            y: R.Tensor(16) = R.add(x, x)  # error: the shape must be a list or tuple
| return y |
| |
| |
| def test_simple_module(): |
| @I.ir_module |
| class TestModule: |
| @T.prim_func(private=True) |
| def tir_func( |
| x: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| y: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| ): |
| T.func_attr({"tir.noalias": True}) |
| for i, j in T.grid(T.int64(128), T.int64(128)): |
| with T.block(): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| y[vi, vj] = x[vi, vj] + 1.0 |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): |
| cls = TestModule |
| gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), dtype="float32")) |
| return gv0 |
| |
| x = relax.Var("x", R.Tensor((128, 128), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,), {"global_symbol": "foo"}): |
| out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") |
| bb.emit_func_output(out) |
| |
| _check(TestModule, bb.get()) |
| |
| |
| def test_emit_te_primfunc_attrs(): |
| @I.ir_module |
| class TestModule: |
| @T.prim_func(private=True) |
| def plus_one( |
| x: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| y: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| ): |
| T.func_attr({"some_attr": "foo", "another_attr": True, "tir.noalias": True}) |
| for i, j in T.grid(T.int64(128), T.int64(128)): |
| with T.block(): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| y[vi, vj] = x[vi, vj] + 1.0 |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): |
| cls = TestModule |
| gv0 = R.call_tir(cls.plus_one, x, R.Tensor((128, 128), dtype="float32")) |
| return gv0 |
| |
| x = relax.Var("x", R.Tensor((128, 128), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,), {"global_symbol": "foo"}): |
| out = bb.emit_te( |
| lambda x: x + 1, |
| x, |
| primfunc_name_hint="plus_one", |
| primfunc_attrs={"some_attr": "foo", "another_attr": True}, |
| ) |
| bb.emit_func_output(out) |
| _check(TestModule, bb.get()) |
| |
| |
| def test_emit_te(): |
| @I.ir_module |
| class EmitTE: |
| @R.function |
| def main(x: R.Tensor((10, 20), "float32")) -> R.Tensor((10, 20), dtype="float32"): |
| lv1 = R.emit_te(topi.add, x, x) |
| out = R.emit_te(topi.multiply, lv1, lv1) |
| return out |
| |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32")) |
| with bb.function("main", [x], {"global_symbol": "main"}): |
| lv1 = bb.emit_te(topi.add, x, x) |
| out = bb.emit_te(topi.multiply, lv1, lv1) |
| bb.emit_func_output(out) |
| |
| _check(EmitTE, bb.get()) |
| |
| |
| def test_module_with_attr_and_global_info(): |
| @I.ir_module |
| class TestModule: |
| I.module_attrs({"attr": 10}) |
| I.module_global_infos( |
| { |
| "dummy": [ |
| I.dummy_global_info(), # dummy[0] |
| I.dummy_global_info(), # dummy[1] |
| ] |
| } |
| ) |
| |
| @T.prim_func(private=True) |
| def tir_func( |
| x: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| y: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| ): |
| T.func_attr({"tir.noalias": True}) |
| for i, j in T.grid(T.int64(128), T.int64(128)): |
| with T.block(): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| y[vi, vj] = x[vi, vj] + 1.0 |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): |
| cls = TestModule |
| gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), dtype="float32")) |
| return gv0 |
| |
| x = relax.Var("x", R.Tensor((128, 128), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,), {"global_symbol": "foo"}): |
| out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") |
| bb.emit_func_output(out) |
| mod = bb.get() |
| mod.update_global_info("dummy", [DummyGlobalInfo(), DummyGlobalInfo()]) |
| mod = mod.with_attr("attr", 10) |
| _check(TestModule, mod) |
| |
| |
| def test_global_info_vdevice(): |
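    # VDevice takes a target string, an optional device index, and an
    # optional memory scope.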
| vdevices = [ |
| VDevice("llvm"), |
| VDevice("cuda", 0), |
| VDevice("cuda -arch=sm_80", 0), |
| VDevice("metal", 0, "global"), |
| ] |
| |
| @I.ir_module |
| class TestModule: |
| I.module_attrs({"attr": 10}) |
| I.module_global_infos( |
| { |
| "vdevice": [ |
| I.vdevice("llvm"), |
| I.vdevice("cuda", 0), |
| I.vdevice("cuda -arch=sm_80", 0), |
| I.vdevice("metal", 0, "global"), |
| ] |
| } |
| ) |
| |
| @T.prim_func(private=True) |
| def tir_func( |
| x: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| y: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| ): |
| T.func_attr({"tir.noalias": True}) |
| for i, j in T.grid(T.int64(128), T.int64(128)): |
| with T.block(): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| y[vi, vj] = x[vi, vj] + 1.0 |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): |
| cls = TestModule |
| gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), dtype="float32")) |
| return gv0 |
| |
| x = relax.Var("x", R.Tensor((128, 128), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") |
| bb.emit_func_output(out) |
| mod = bb.get() |
| mod.update_global_info("vdevice", vdevices) |
| mod = mod.with_attr("attr", 10) |
| _check(TestModule, mod) |
| |
| |
| def test_relax_tensor_op(): |
| @R.function |
| def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"): |
| y = R.add(x, x) |
| z = R.multiply(x, y) |
| return z |
| |
| x = relax.Var("x", R.Tensor((4, 4), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| y = bb.emit(relax.op.add(x, x)) |
| z = bb.emit(relax.op.multiply(x, y)) |
| bb.emit_func_output(z) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_relax_base_op(): |
| @R.function |
| def foo(x: R.Tensor((4, 4), "float32")): |
| alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype="float32") |
| shape = R.shape_of(alloc) |
| return shape |
| |
| x = relax.Var("x", R.Tensor((4, 4), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0)) |
| shape = bb.emit(relax.op.shape_of(alloc)) |
| bb.emit_func_output(shape) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_relax_shape_to_tensor(): |
| @R.function |
| def foo(x: R.Shape((4, 4))): |
| tensor = R.shape_to_tensor(x) |
| return tensor |
| |
| x = relax.Var("x", R.Shape((4, 4))) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| tensor = bb.emit(relax.op.shape_to_tensor(x)) |
| bb.emit_func_output(tensor) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_symbolic_shape(): |
| @R.function |
| def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): |
| m = T.int64() |
| n = T.int64() |
| gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) |
| return gv0 |
| |
| @R.function |
| def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): |
| m = T.int64() |
| n = T.int64() |
| gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) |
| return gv0 |
| |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): |
| m = T.int64() |
| n = T.int32() # The shape dtype should be int64 |
| gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) |
| return gv0 |
| |
| def _expected(name: str): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x = relax.Var("x", R.Tensor([m, n], "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function(name, (x,)): |
| out = bb.emit( |
| relax.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) |
| ) |
| bb.emit_func_output(out) |
| return bb.get()[name] |
| |
| _check(foo, _expected("foo")) |
| _check(bar, _expected("bar")) |
| |
| |
| def test_shadowing(): |
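    # Re-assigning a Python name creates a new binding to a fresh relax.Var
    # with the same name hint, mirroring the emit sequence below.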
| @R.function |
| def foo(x: R.Tensor((4, 4), "float32")): |
| y = R.add(x, x) |
| z = R.multiply(x, y) |
| y = R.add(x, y) |
| y = z |
| y = R.multiply(y, x) |
| z = y |
| return z |
| |
| x = relax.Var("x", R.Tensor((4, 4), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| y = bb.emit(relax.op.add(x, x)) |
| z = bb.emit(relax.op.multiply(x, y)) |
| y = bb.emit(relax.op.add(x, y)) |
| y = bb.emit(z) |
| y = bb.emit(relax.op.multiply(y, x)) |
| z = bb.emit(y) |
| bb.emit_func_output(z) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_match_cast(): |
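    # R.match_cast both refines the struct info of its argument and defines
    # the symbolic variables (m, n) used later in the function.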
| @R.function |
| def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): |
| m = T.int64() |
| n = T.int64() |
| x0 = R.match_cast(x, R.Tensor([m], "float32")) |
| with R.dataflow(): |
| y0 = R.match_cast(y, R.Tensor([n], "float32")) |
| gv = y0 |
| R.output(gv) |
| return (x0, R.shape([m, n * 2])) |
| |
| x = relax.Var("x", R.Tensor("float32")) |
| y = relax.Var("y", R.Tensor("float32")) |
| m = tir.Var("m", dtype="int64") |
| n = tir.Var("n", dtype="int64") |
| y2 = relax.Var("y", R.Tensor([n], "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x, y)): |
| x0 = bb.match_cast(x, R.Tensor([m], "float32")) |
| with bb.dataflow(): |
| y0 = bb.match_cast(y, R.Tensor([n], "float32")) |
| bb.emit_output(y0) |
| bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([m, n * 2])])) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_tuple_return(): |
| @R.function |
| def foo(x: R.Tensor((4, 4), "float32")): |
| gv0 = R.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), dtype="float32")) |
| gv1 = R.call_dps_packed("extern_func_1", x, R.Tensor((4, 4), dtype="float32")) |
| return (gv0, gv1) |
| |
| x = relax.Var("x", R.Tensor((4, 4), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| gv0 = bb.emit(relax.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))) |
| gv1 = bb.emit(relax.call_dps_packed("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))) |
| bb.emit_func_output(relax.Tuple((gv0, gv1))) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_tuple_return_2(): |
| @R.function |
| def foo(x: R.Tensor("float32", ndim=2)): |
| n, m = T.int64(), T.int64() |
| x0 = R.match_cast(x, R.Tensor((n, m), "float32")) |
| return (x0, R.shape([n + 1, m, 1])) |
| |
| x = relax.Var("x", R.Tensor("float32", ndim=2)) |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) |
| bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([n + 1, m, 1])])) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_tuple_binding(): |
| @R.function |
| def foo(x: R.Tensor("float32", ndim=2)): |
| n, m = T.int64(), T.int64() |
| x0 = R.match_cast(x, R.Tensor((n, m), "float32")) |
| t0 = (x, x0) |
| t1 = (x, R.shape([n, m]), t0) |
| return t1 |
| |
| x = relax.Var("x", R.Tensor("float32", ndim=2)) |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) |
| t0 = bb.emit(relax.Tuple([x, x0])) |
| t1 = bb.emit(relax.Tuple([x, relax.ShapeExpr([n, m]), t0])) |
| bb.emit_func_output(t1) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_tuple_get_item(): |
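    # Both R.tuple(...) and a parenthesized tuple literal parse to
    # relax.Tuple; subscripting and R.TupleGetItem are equivalent.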
| @R.function |
| def foo(x: R.Tensor, y: R.Tensor): |
| t1 = R.tuple(x, y) |
| t2 = (x, y) |
| a = t1[0] |
| b = R.TupleGetItem(t2, 1) |
| c = R.add(a, b) |
| return c |
| |
| x = relax.Var("x", R.Tensor()) |
| y = relax.Var("y", R.Tensor()) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x, y)): |
| t1 = bb.emit(relax.Tuple([x, y])) |
| t2 = bb.emit(relax.Tuple([x, y])) |
| a = bb.emit(relax.TupleGetItem(t1, 0)) |
| b = bb.emit(relax.TupleGetItem(t2, 1)) |
| c = bb.emit(relax.op.add(a, b)) |
| bb.emit_func_output(c) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_dataflow_block(): |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): |
| with R.dataflow(): |
| lv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| lv1 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) |
| gv = lv1 |
| R.output(gv) |
| return gv |
| |
| x = relax.Var("x", R.Tensor((128, 128), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| with bb.dataflow(): |
| lv0 = bb.emit( |
| relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| ) |
| lv1 = bb.emit( |
| relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) |
| ) |
| gv = bb.emit_output(lv1) |
| bb.emit_func_output(gv) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_dataflow_block_advanced(): |
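    # Variables re-bound inside the dataflow block (gv2, gv3) become distinct
    # vars in ANF, and only vars passed to R.output are visible afterwards.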
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): |
| gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| gv1 = R.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) |
| with R.dataflow(): |
| m = T.int64() |
| n = T.int64() |
| lv0 = R.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) |
| lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32")) |
| gv2 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) |
| gv2 = R.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) |
| gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32")) |
| gv3 = R.match_cast(lv0, R.Tensor((m, n), "float32")) |
| gv4 = gv3 |
| gv5 = gv2 |
| R.output(gv5, gv4) |
| gv6 = R.call_dps_packed("extern_func", gv5, R.Tensor((128, 128), dtype="float32")) |
| gv7 = R.call_dps_packed("extern_func", gv6, R.Tensor((128, 128), dtype="float32")) |
| return gv7 |
| |
| x = relax.Var("x", R.Tensor((128, 128), "float32")) |
| bb = relax.BlockBuilder() |
| m = tir.Var("m", dtype="int64") |
| n = tir.Var("n", dtype="int64") |
| with bb.function("foo", (x,)): |
| gv0 = bb.emit( |
| relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| ) |
| gv1 = bb.emit( |
| relax.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) |
| ) |
| with bb.dataflow(): |
| lv0 = bb.emit( |
| relax.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) |
| ) |
| lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) |
| gv2 = bb.emit( |
| relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) |
| ) |
| gv21 = bb.emit( |
| relax.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) |
| ) |
| gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32")) |
| gv31 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) |
| gv32 = bb.emit_output(gv31) |
| gv22 = bb.emit_output(gv21) |
| gv4 = bb.emit( |
| relax.call_dps_packed("extern_func", gv22, R.Tensor((128, 128), dtype="float32")) |
| ) |
| gv5 = bb.emit( |
| relax.call_dps_packed("extern_func", gv4, R.Tensor((128, 128), dtype="float32")) |
| ) |
| bb.emit_func_output(gv5) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_dataflow_binding_after_output(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): |
| with R.dataflow(): |
| gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| R.output(gv) |
| lv = R.call_tir("extern_func", gv, R.Tensor((128, 128), dtype="float32")) |
| return gv |
| |
| |
| def test_dataflow_output_global_var(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): |
| gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| with R.dataflow(): |
| gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) |
| R.output(gv0, gv1) |
| return gv1 |
| |
| |
| def test_dataflow_multiple_output(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): |
| with R.dataflow(): |
| gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| R.output(gv) |
| R.output(gv) |
| return gv |
| |
| |
| def test_dataflow_output_outside_dataflow_block(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): |
| gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| R.output(gv) |
| return gv |
| |
| |
| def test_dataflow_scope_fail(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def f(x: R.Tensor(ndim=2)): |
| with R.dataflow(): |
| y = R.add(x, x) |
| z = R.multiply(y, x) |
| w = R.add(z, x) |
| R.output(y, w) |
| t = R.multiply(y, z) # z is not in the outer scope |
| return t |
| |
| |
| def test_return_without_binding(): |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")): |
| return x |
| |
| x = relax.Var("x", R.Tensor((128, 128), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| bb.emit_func_output(x) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_multiple_return(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")): |
| return x |
| return x |
| |
| |
| def test_function_without_return(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")): |
| gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) |
| |
| |
| def test_tensor_type_without_args(): |
| @R.function |
| def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| v = R.call_dps_packed("extern_relu", x, R.Tensor((32, 32), dtype="float32")) |
| return v |
| |
| x = relax.Var("x", R.Tensor((32, 32), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x)): |
| v = bb.emit(relax.call_dps_packed("extern_relu", x, R.Tensor((32, 32), dtype="float32"))) |
| bb.emit_func_output(v) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_tensor_with_vdevice(): |
| vdevices = [ |
| VDevice("llvm"), |
| VDevice("cuda", 0), |
| VDevice("metal", 0, "global"), |
| VDevice("cuda -arch=sm_80", 0), |
| ] |
| |
| @I.ir_module |
| class TestModule: |
| I.module_attrs({"attr": 10}) |
| I.module_global_infos( |
| { |
| "vdevice": [ |
| I.vdevice("llvm"), |
| I.vdevice("cuda", 0), |
| I.vdevice("metal", 0, "global"), |
| I.vdevice("cuda -arch=sm_80", 0), |
| ] |
| } |
| ) |
| |
| @R.function |
| def foo( |
| a: R.Tensor((128, 128), "float32", "cuda:1"), # noqa: F722 |
| b: R.Tensor((128, 128), "float32", "llvm"), |
| c: R.Tensor((128, 128), "float32", "vdevice:3"), # noqa: F722 |
| ) -> R.Tensor((128, 128), "float32", "cuda:1"): # noqa: F722 |
| s = R.add(a, c) |
| return s |
| |
| a = relax.Var("a", R.Tensor((128, 128), "float32", vdevices[3])) |
| b = relax.Var("b", R.Tensor((128, 128), "float32", vdevices[0])) |
| c = relax.Var("c", R.Tensor((128, 128), "float32", vdevices[3])) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (a, b, c)): |
| out = bb.emit(relax.op.add(a, c)) |
| bb.emit_func_output(out) |
| mod = bb.get() |
| mod = mod.with_attr("attr", 10) |
| mod.update_global_info("vdevice", vdevices) |
| |
| _check(TestModule, mod) |
| |
| |
| def test_direct_return(): |
| @R.function |
| def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): |
| return x |
| |
| x = relax.Var("x", R.Tensor((32, 32), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x)): |
| bb.emit_func_output(x) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_call_packed(): |
| @R.function(pure=False) |
| def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| z = R.call_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32")) |
| return z |
| |
| x = relax.Var("x", R.Tensor((32, 32), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x), pure=False): |
| z = bb.emit( |
| relax.Call( |
| relax.ExternFunc("vm.builtin.copy"), |
| (x,), |
| None, |
| sinfo_args=[R.Tensor((32, 32), "float32")], |
| ) |
| ) |
| bb.emit_func_output(z) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_call_packed_without_sinfo_args(): |
| @R.function(pure=False) |
| def foo(x: R.Object) -> R.Object: |
| z = R.call_packed("test", x) |
| return z |
| |
| x = relax.Var("x", R.Object()) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x), pure=False): |
| z = bb.emit( |
| relax.Call( |
| relax.ExternFunc("test"), |
| (x,), |
| None, |
| sinfo_args=[], |
| ) |
| ) |
| bb.emit_func_output(z) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_annotation(): |
| @R.function(pure=False) |
| def foo( |
| x: R.Tensor((32, "m"), "float32"), |
| y: R.Tensor(("m",), "float32"), |
| r: R.Tensor(dtype="int64"), |
| ) -> R.Object: |
| m = T.int64() |
| z: R.Tensor((32, m), "float32") = R.multiply(x, y) |
| w: R.Tensor(ndim=2) = R.multiply(z, z) |
| q: R.Tensor = R.add(w, w) |
| t = R.add(w, z) |
| sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape) |
| lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh) |
| o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object) |
| return o |
| |
| def _check_struct_info(binding, expected_sinfo): |
| tvm.ir.assert_structural_equal(binding.var.struct_info, expected_sinfo) |
| tvm.ir.assert_structural_equal(binding.value.struct_info, expected_sinfo) |
| |
    # Cannot use the block builder here because we need to check the
    # annotated struct info, which may differ from the deduced struct info.
| assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) |
| m = relax.get_shape_of(foo.params[0])[1] |
| bindings = foo.body.blocks[0].bindings |
| sh = bindings[4].var |
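    # In binding order: z, w, q, t, sh, lv, o.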
| |
| _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32")) |
| _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=2)) |
| _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=-1)) |
| _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=2)) |
| _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1)) |
| _check_struct_info(bindings[5], relax.TensorStructInfo(sh)) |
| _check_struct_info(bindings[6], relax.ObjectStructInfo()) |
| |
| |
| def test_annotate_override(): |
| @R.function |
| def foo(x: R.Tensor): |
| y = x |
| # z will be treated as object type even though it's a tensor |
| z: R.Object = R.add(x, y) |
| return z |
| |
| assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) |
| y_bind, z_bind = foo.body.blocks[0].bindings |
| assert isinstance(y_bind.var.struct_info, relax.TensorStructInfo) |
| assert isinstance(z_bind.var.struct_info, relax.ObjectStructInfo) |
| |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def test(x: R.Tensor): |
            # Error: x has Tensor StructInfo, which cannot be annotated as R.Shape.
| z: R.Shape = x |
| return z |
| |
| @R.function |
| def bar(x: R.Tensor): |
        # x has Tensor StructInfo, so the looser R.Object annotation on `z` is ignored.
| z: R.Object = x |
| return z |
| |
| assert isinstance(bar.ret_struct_info, relax.TensorStructInfo) |
| (z_bind,) = bar.body.blocks[0].bindings |
| assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo) |
| |
| |
| def test_call_dps_packed_empty_shape(): |
| @R.function |
| def foo(x: R.Tensor((), "float32")): |
| z = R.call_dps_packed("scalar_add", x, R.Tensor((), dtype="float32")) |
| return z |
| |
| (z_bind,) = foo.body.blocks[0].bindings |
| shape_expr = z_bind.value.sinfo_args[0].shape |
| |
| assert isinstance(shape_expr, relax.ShapeExpr) |
| assert len(shape_expr.values) == 0 |
| |
| |
| def test_call_tir_empty_tuple_arg(): |
| bb = relax.BlockBuilder() |
| dummy_param = relax.Var("dummy_param", R.Tensor(())) |
| with bb.function("foo", [dummy_param], {"global_symbol": "foo"}): |
| output = bb.emit_te(topi.full, shape=(16, 32), dtype="float32", fill_value=1.0) |
| bb.emit_func_output(output) |
| |
| _check(bb.get()) |
| |
| |
| def test_call_tir_with_tir_var(): |
| @I.ir_module |
| class Module: |
| @R.function |
| def main( |
| dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2",), "float32") |
| ) -> R.Tensor(("n * 2",), "float32"): |
| n = T.int64() |
| cls = Module |
| y = R.call_tir(cls.copy, x, R.Tensor((n * 2,), dtype="float32"), tir_vars=(n,)) |
| return y |
| |
| @T.prim_func |
| def copy(var_x: T.handle, var_y: T.handle, n: T.int64): |
| X = T.match_buffer(var_x, (n * 2,), dtype="float32") |
| Y = T.match_buffer(var_y, (n * 2,), dtype="float32") |
| for i in T.grid(n * 2): |
| with T.block("block"): |
| vi = T.axis.remap("S", [i]) |
| Y[vi] = X[vi] |
| |
| _check(Module) |
| |
| |
| def test_call_tir_with_grad(): |
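    # call_tir_with_grad records the te gradient function name and kwargs
    # as attributes on the call, for later use by the gradient pass.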
| @I.ir_module |
| class Module: |
| @T.prim_func |
| def identity_tir(a: T.handle, b: T.handle) -> None: |
| A = T.match_buffer(a, [54, 96]) |
| B = T.match_buffer(b, [54, 96]) |
| |
| for i, j in T.grid(54, 96): |
| with T.block("compute"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] |
| |
| @R.function |
| def main(v0: R.Tensor([54, 96], "float32")): |
| cls = Module |
| out = R.call_tir_with_grad( |
| cls.identity_tir, |
| (v0,), |
| R.Tensor((54, 96), "float32"), |
| te_grad_name="identity_k_grad", |
| te_grad_kwargs={"k": 1.0}, |
| ) |
| return out |
| |
| _check(Module) |
| |
| |
| def test_call_tir_inplace(): |
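    # inplace_indices=[0, -1]: the first output aliases input 0 (written
    # in place), while -1 allocates a fresh tensor for the second output.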
| @tvm.script.ir_module |
| class Module: |
| @T.prim_func |
| def copy( |
| A: T.Buffer((2, 3), "int32"), |
| B: T.Buffer((2, 3), "int32"), |
| out1: T.Buffer((2, 3), "int32"), |
| ): |
| # copies the contents of B into A and out1 |
| T.func_attr({"tir.noalias": True}) |
| for i0, i1 in T.grid(T.int64(2), T.int64(3)): |
| with T.block("T_zeros"): |
| ax0, ax1 = T.axis.remap("SS", [i0, i1]) |
| T.reads(B[ax0, ax1]) |
| T.writes(A[ax0, ax1], out1[ax0, ax1]) |
| A[ax0, ax1] = B[ax0, ax1] |
| out1[ax0, ax1] = B[ax0, ax1] |
| |
| @R.function |
| def main( |
| x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") |
| ) -> R.Tuple( |
| R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32") |
| ): |
| res = R.call_tir_inplace( |
| Module.copy, |
| (x, y), |
| [0, -1], |
| [R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], |
| ) |
| return res |
| |
| _check(Module) |
| |
| |
| def test_call_tir_inplace_with_tuple_var_raises_error(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @tvm.script.ir_module |
| class Module: |
| @R.function |
| def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")): |
| cls = Module |
| args = (x, y) |
| res = R.call_tir_inplace( |
| cls.copy, |
| # The `args` tuple must be an in-line tuple, not a |
| # reference to a tuple. This error should be |
| # caught and raised during parsing. |
| args, |
| inplace_indices=[0, -1], |
| out_sinfo=[R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], |
| ) |
| return res |
| |
| @T.prim_func |
| def copy( |
| A: T.Buffer((2, 3), "int32"), |
| B: T.Buffer((2, 3), "int32"), |
| out1: T.Buffer((2, 3), "int32"), |
| ): |
| # copies the contents of B into A and out1 |
| T.func_attr({"tir.noalias": True}) |
| for iters in T.grid(T.int64(2), T.int64(3)): |
| with T.block("T_zeros"): |
| i, j = T.axis.remap("SS", iters) |
| A[i, j] = B[i, j] |
| out1[i, j] = B[i, j] |
| |
| |
| def test_local_function(): |
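    # inner_func closes over c1 from the enclosing outer_func; local
    # functions are bound to ordinary vars like any other value.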
| @R.function |
| def main( |
| x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") |
| ) -> R.Tensor((2, 3), "float32"): |
| @R.function |
| def outer_func( |
| c1: R.Tensor((2, 3), "float32") |
| ) -> R.Callable((R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2)): |
| @R.function |
| def inner_func(x1: R.Tensor((2, 3), "float32")): |
| s: R.Tensor((2, 3), "float32") = R.add(x1, c1) |
| return s |
| |
| return inner_func |
| |
| in_call = outer_func(x) |
| res = in_call(y) |
| return res |
| |
| main_bindings = main.body.blocks[0].bindings |
| assert len(main_bindings) == 3 |
| outer_func = main_bindings[0].value |
| assert isinstance(outer_func, relax.Function) |
| |
| outer_func_bindings = outer_func.body.blocks[0].bindings |
| assert len(outer_func_bindings) == 1 |
| inner_func = outer_func_bindings[0].value |
| assert isinstance(inner_func, relax.Function) |
| |
| |
| def test_inline_prim_func(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @I.ir_module |
| class TestModule: |
| @R.function |
| def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")): |
| @T.prim_func |
| def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| B = T.match_buffer(b, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| |
| for i, j, k in T.grid(128, 128, 128): |
| with T.block(): |
| vi, vj, vk = T.axis.remap("SSR", [i, j, k]) |
| with T.init(): |
| C[vi, vj] = 0.0 |
| C[vi, vj] += A[vi, vk] * B[vj, vk] |
| |
| z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32")) |
| return z |
| |
| |
| def test_cross_function_call(): |
| @I.ir_module |
| class Mod0: |
| @R.function |
| def foo(x: R.Tensor((10, 5), "float32")): |
| s = R.add(x, x) |
| return s |
| |
| @R.function |
| def main(x: R.Tensor((10, 5), "float32")): |
| cls = Mod0 |
| inner = cls.foo |
| gv1 = inner(x) |
| gv2 = Mod0.foo(x) |
| return (inner, gv1, gv2) |
| |
| @I.ir_module |
| class Mod1: |
| @R.function |
| def main(x: R.Tensor((10, 5), "float32")): |
| cls = Mod1 |
| inner = cls.foo |
| gv1 = inner(x) |
| gv2 = Mod1.foo(x) |
| return (inner, gv1, gv2) |
| |
| @R.function |
| def foo(x: R.Tensor((10, 5), "float32")) -> R.Tensor((10, 5), "float32"): |
| s = R.add(x, x) |
| return s |
| |
| |
| def test_if_branch(): |
| @R.function |
| def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor((1,), "float32"): |
| if cond: |
| w = R.add(x, x) |
| y = R.multiply(w, w) |
| else: |
| w = R.multiply(x, x) |
| y = R.add(w, w) |
| return y |
| |
| cond, x = foo.params |
| y_bind = foo.body.blocks[0].bindings[0] |
| y, ite = y_bind.var, y_bind.value |
| |
| assert isinstance(y, relax.Var) |
| assert y.name_hint == "y" |
| |
| assert isinstance(ite, relax.If) |
| assert isinstance(ite.true_branch, relax.SeqExpr) |
| assert isinstance(ite.false_branch, relax.SeqExpr) |
| |
| def check_call(call, op, args): |
| assert isinstance(call, relax.Call) |
| if isinstance(op, str): |
| assert call.op.name == op |
| else: |
| assert call.op == op |
| tvm.ir.assert_structural_equal(call.args, args) |
| |
| w_bind = ite.true_branch.blocks[0].bindings[0] |
    # The seq exprs in the branches are normalized so that any call in the
    # seq expr "body" is bound to a var.
| y_bind = ite.true_branch.blocks[-1].bindings[-1] |
| assert w_bind.var.name_hint == "w" |
| check_call(w_bind.value, "relax.add", [x, x]) |
| check_call(y_bind.value, "relax.multiply", [w_bind.var, w_bind.var]) |
| |
| w_bind = ite.false_branch.blocks[0].bindings[0] |
| y_bind = ite.false_branch.blocks[-1].bindings[-1] |
| assert w_bind.var.name_hint == "w" |
| check_call(w_bind.value, "relax.multiply", [x, x]) |
| check_call(y_bind.value, "relax.add", [w_bind.var, w_bind.var]) |
| |
| |
| def test_if_branch_with_match_cast(): |
| """The last branch of a relax::If node may be a MatchCast |
| |
| This is a regression test. In previous implementations, using |
| R.match_cast as the last binding would cause a segfault while |
| parsing. |
| """ |
| |
| @R.function |
| def func(A: R.Tensor([16, 16]), is_bfloat16: R.Prim("bool")): |
| if is_bfloat16: |
| A = R.match_cast(A, R.Tensor([16, 16], "bfloat16")) |
| B = A.astype("float16") |
| else: |
| B = R.match_cast(A, R.Tensor([16, 16], "float16")) |
| return B |
| |
| A, is_bfloat16 = func.params |
| (block,) = func.body.blocks |
| (B_binding,) = block.bindings |
| |
| B_var = B_binding.var |
| assert isinstance(B_var, relax.Var) |
| assert B_var.name_hint == "B" |
| |
| if_then_else = B_binding.value |
| assert isinstance(if_then_else, relax.If) |
| assert isinstance(if_then_else.true_branch, relax.SeqExpr) |
| assert isinstance(if_then_else.false_branch, relax.SeqExpr) |
| |
| else_branch = if_then_else.false_branch |
| (else_block,) = else_branch.blocks |
| |
| assert isinstance(else_block.bindings[-1], relax.MatchCast) |
| |
| # If the `R.match_cast` were removed, the function would infer the |
| # return value as `R.Tensor([16,16])`, with an unknown dtype. |
| # With the `R.match_cast` retained, the output dtype is known. |
| tvm.ir.assert_structural_equal(func.ret_struct_info, R.Tensor([16, 16], "float16")) |
| |
| |
| def test_if_inside_dataflow(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): |
| with R.dataflow(): |
| if cond: |
| w = R.add(x, x) |
| y = R.multiply(w, w) |
| else: |
| w = R.multiply(x, x) |
| y = R.add(w, w) |
| R.output(y) |
| return y |
| |
| |
| def test_var_if_scoping_fail(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): |
| if cond: |
| w = R.add(x, x) |
| y = R.multiply(w, w) |
| else: |
| w = R.multiply(x, x) |
| y = R.add(w, w) |
            return w  # error: `w` is not defined in the outer scope
| |
| |
| def test_if_branch_var_scope(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): |
| if cond: |
| w = R.add(x, x) |
| y = R.multiply(w, w) |
| else: |
| w = R.multiply(x, x) |
| y = R.add(w, w) |
            return w  # error: `w` is not defined in the outer scope
| |
| |
| def test_scalar_tensor_as_branch_condition(): |
| """Branch condition can be 0-d tensor""" |
| |
| @R.function |
| def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")): |
| if cond: |
| out = R.add(x, x) |
| else: |
| out = R.multiply(x, x) |
| return out |
| |
| if_else = func.body.blocks[0].bindings[0].value |
| assert isinstance(if_else.cond, relax.Var) |
| tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Tensor([], "bool")) |
| |
| |
| def test_prim_value_as_branch_condition(): |
| """In addition to scalar tensor, can use R.Prim condition""" |
| |
| @R.function |
| def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")): |
| if cond: |
| out = R.add(x, x) |
| else: |
| out = R.multiply(x, x) |
| return out |
| |
| if_else = func.body.blocks[0].bindings[0].value |
| assert isinstance(if_else.cond, relax.Var) |
| tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim("bool")) |
| |
| |
| def test_computed_prim_value_as_branch_condition(): |
| """The R.Prim condition may be computed within the function""" |
| |
| @R.function |
| def func(x: R.Tensor(["N"], "float32")): |
| N = T.int64() |
| if R.prim_value(N % 16 == 0): |
| out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) |
| else: |
| out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info]) |
| return out |
| |
| N = func.params[0].struct_info.shape[0] |
| if_else = func.body.blocks[0].bindings[0].value |
| assert isinstance(if_else.cond, relax.PrimValue) |
| tvm.ir.assert_structural_equal(N % 16 == 0, if_else.cond.value) |
| tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim(value=N % 16 == 0)) |
| |
| |
| def test_tir_expr_as_branch_condition(): |
| """Syntactic sugar, wrap PrimExpr as PrimValue""" |
| |
| @R.function(private=True) |
| def sugared(x: R.Tensor(["N"], "float32")): |
| N = T.int64() |
| if N % 16 == 0: |
| out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) |
| else: |
| out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info]) |
| return out |
| |
| @R.function(private=True) |
| def unsugared(x: R.Tensor(["N"], "float32")): |
| N = T.int64() |
| if R.prim_value(N % 16 == 0): |
| out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) |
| else: |
| out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info]) |
| return out |
| |
| tvm.ir.assert_structural_equal(unsugared, sugared) |
| |
| |
| def test_scalar_tensor_as_assert_condition(): |
| """Branch condition can be 0-d tensor""" |
| |
| @R.function(pure=False) |
| def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")): |
| _ = R.assert_op(cond) |
| out = R.add(x, x) |
| return out |
| |
| assert_op = func.body.blocks[0].bindings[0].value |
| condition = assert_op.args[0] |
| assert isinstance(condition, relax.Var) |
| tvm.ir.assert_structural_equal(condition.struct_info, R.Tensor([], "bool")) |
| |
| |
| def test_prim_value_as_assert_condition(): |
| """In addition to scalar tensor, can use R.Prim condition""" |
| |
| @R.function(pure=False) |
| def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")): |
| _ = R.assert_op(cond) |
| out = R.add(x, x) |
| return out |
| |
| assert_op = func.body.blocks[0].bindings[0].value |
| condition = assert_op.args[0] |
| assert isinstance(condition, relax.Var) |
| tvm.ir.assert_structural_equal(condition.struct_info, R.Prim("bool")) |
| |
| |
| def test_computed_prim_value_as_assert_condition(): |
| """The R.Prim condition may be computed within the function""" |
| |
| @R.function(pure=False) |
| def func(x: R.Tensor(["N"], "float32")): |
| N = T.int64() |
| _ = R.assert_op(R.prim_value(N % 16 == 0)) |
| out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) |
| return out |
| |
| N = func.params[0].struct_info.shape[0] |
| assert_op = func.body.blocks[0].bindings[0].value |
| condition = assert_op.args[0] |
| assert isinstance(condition, relax.PrimValue) |
| tvm.ir.assert_structural_equal(N % 16 == 0, condition.value) |
| tvm.ir.assert_structural_equal(condition.struct_info, R.Prim(value=N % 16 == 0)) |
| |
| |
| def test_tir_expr_as_assert_condition(): |
| """Syntactic sugar, wrap PrimExpr as PrimValue""" |
| |
| @R.function(pure=False, private=True) |
| def sugared(x: R.Tensor(["N"], "float32")): |
| N = T.int64() |
| _ = R.assert_op(N % 16 == 0) |
| out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) |
| return out |
| |
| @R.function(pure=False, private=True) |
| def unsugared(x: R.Tensor(["N"], "float32")): |
| N = T.int64() |
| _ = R.assert_op(R.prim_value(N % 16 == 0)) |
| out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) |
| return out |
| |
| tvm.ir.assert_structural_equal(unsugared, sugared) |
| |
| |
| def test_erase_to_well_defined_removes_internal_vars(): |
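    # m and n are defined only inside the function body, so they are erased
    # from the return struct info, leaving a shapeless R.Tensor(ndim=2).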
| @R.function |
| def foo(x: R.Tensor): |
| q = x |
| m, n = T.int64(), T.int64() |
| z = R.match_cast(q, R.Tensor((m, n))) |
| w = z |
| return w |
| |
| tvm.ir.assert_structural_equal(foo.ret_struct_info, R.Tensor(ndim=2)) |
| assert foo.ret_struct_info.shape is None |
| _check(foo) |
| |
| |
| def test_erase_to_well_defined_keeps_variables_exposed_by_tensor_shape(): |
| @R.function |
| def foo(x: R.Tensor(["m", "n"])): |
| q = x |
| m, n = T.int64(), T.int64() |
| z = R.match_cast(q, R.Tensor((m, n))) |
| w = z |
| return w |
| |
| assert foo.ret_struct_info.shape is not None |
| _check(foo) |
| |
| |
| def test_erase_to_well_defined_keeps_variants_exposed_by_shape_expr(): |
| @R.function |
| def foo(x: R.Tensor, _: R.Shape(["m", "n"])): |
| q = x |
| m, n = T.int64(), T.int64() |
| z = R.match_cast(q, R.Tensor((m, n))) |
| w = z |
| return w |
| |
| assert foo.ret_struct_info.shape is not None |
| _check(foo) |
| |
| |
| def test_erase_to_well_defined_keeps_variants_exposed_by_prim_value(): |
| @R.function |
| def foo(x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")): |
| q = x |
| m, n = T.int64(), T.int64() |
| z = R.match_cast(q, R.Tensor((m, n))) |
| w = z |
| return w |
| |
| assert foo.ret_struct_info.shape is not None |
| _check(foo) |
| |
| |
| def test_erase_to_well_defined_infers_from_shape_expr(): |
| @I.ir_module |
| class Module: |
| # The subroutine's symbolic variables are only in-scope for the subroutine. |
| @R.function |
| def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]): |
| q = x |
| m, n = T.int64(), T.int64() |
| z = R.match_cast(q, R.Tensor((m, n))) |
| w = z |
| return w |
| |
        # However, struct info inference can map the symbolic variables
        # in the main function to the symbolic variables in the
        # subroutine. Therefore, the tensor returned from main can have
        # a well-defined shape.
| @R.function |
| def main(x: R.Tensor, shape: R.Shape(["m", "n"])): |
| output = Module.subroutine(x, shape) |
| return output |
| |
| assert Module["main"].ret_struct_info.shape is not None |
| _check(Module) |
| |
| |
| def test_erase_to_well_defined_infers_from_prim_value(): |
| @I.ir_module |
| class Module: |
| # The subroutine's symbolic variables are only in-scope for the subroutine. |
| @R.function |
| def subroutine( |
| x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n") |
| ) -> R.Tensor(["m", "n"]): |
| q = x |
| m, n = T.int64(), T.int64() |
| z = R.match_cast(q, R.Tensor((m, n))) |
| w = z |
| return w |
| |
        # However, struct info inference can map the symbolic variables
        # in the main function to the symbolic variables in the
        # subroutine. Therefore, the tensor returned from main can have
        # a well-defined shape.
| @R.function |
| def main(x: R.Tensor, relax_m: R.Prim(value="m"), relax_n: R.Prim(value="n")): |
| output = Module.subroutine(x, relax_m, relax_n) |
| return output |
| |
| assert Module["main"].ret_struct_info.shape is not None |
| _check(Module) |
| |
| |
| def test_empty_tuple(): |
| @R.function |
| def foo(x: R.Tuple()): |
| y: R.Tuple() = R.tuple() |
| return y |
| |
| x = relax.Var("x", relax.TupleStructInfo([])) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x,)): |
| y = bb.emit(relax.Tuple([])) |
| bb.emit_func_output(y) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_symbolic_vars_in_tensor_shape_with_usage_first(): |
| """First param may use symbolic variable defined in second param""" |
| |
| @R.function |
| def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")): |
| z = R.add(x, y) |
| return z |
| |
| m = tir.Var("m", "int64") |
| x = relax.Var("x", relax.TensorStructInfo([m + 1], "float32")) |
| y = relax.Var("y", relax.TensorStructInfo([m, 1], "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x, y)): |
| z = bb.emit(relax.op.add(x, y)) |
| bb.emit_func_output(z) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_symbolic_vars_in_tensor_shape_with_definition_first(): |
| """Second param may use symbolic variable defined in first param""" |
| |
| @R.function |
| def bar( |
| x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") |
| ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): |
| m = T.int64() |
| z = R.call_dps_packed("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) |
| return z |
| |
| m = tir.Var("m", "int64") |
| x = relax.Var("x", relax.TensorStructInfo([m], "float32")) |
| y = relax.Var("y", relax.TensorStructInfo([tir.max(m, 20)], "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("bar", (x, y)): |
| z = bb.emit( |
| relax.call_dps_packed( |
| "test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32") |
| ) |
| ) |
| bb.emit_func_output(z) |
| |
| _check(bar, bb.get()["bar"]) |
| |
| |
| def test_symbolic_vars_in_shape(): |
| """Symbolic variable may be defined in R.Shape""" |
| |
| @R.function |
| def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): |
| m = T.int64() |
| z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) |
| return z |
| |
| m = tir.Var("m", "int64") |
| x = relax.Var("x", relax.ShapeStructInfo([m])) |
| y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("baz", (x, y)): |
| z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) |
| bb.emit_func_output(z) |
| |
| _check(baz, bb.get()["baz"]) |
| |
| |
| def test_symbolic_vars_in_prim_value(): |
| """Symbolic variable may be defined in R.Prim""" |
| |
| @R.function |
| def baz(x: R.Prim(value="m"), y: R.Tensor(("m * 2",), "float32")): |
| m = T.int64() |
| z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) |
| return z |
| |
| m = tir.Var("m", "int64") |
| x = relax.Var("x", relax.PrimStructInfo(value=m)) |
| y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("baz", (x, y)): |
| z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) |
| bb.emit_func_output(z) |
| |
| _check(baz, bb.get()["baz"]) |
| |
| |
| def test_undefined_symbolic_var_raises_error(): |
| """An undefined symbolic variable in an error |
| |
| A symbolic variables is defined at the first site where it appears |
| as a shape parameter without any modification. TVMScript does not |
| support solving for a symbolic variable in terms of the argument |
| shape. That is, this test case raises an error, and will not |
| attempt to define `m` as either `x.shape[0]-1` or `x.shape[1]//2`. |
| """ |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function |
| def foo(x: R.Tensor(("m + 1", "m * 2"), "float32")): # name 'm' is not defined |
| z = R.add(x, x) |
| return z |
| |
| |
| def test_arith_operators(): |
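    # Python operators on relax expressions are sugar for the corresponding
    # relax ops (e.g. `+` -> relax.op.add, `**` -> relax.op.power).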
| @R.function |
| def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): |
| a0 = -x |
| a1 = x + y |
| a2 = x - y |
| a3 = x * y |
| a4 = x / y |
| a5 = x // y |
| a6 = x**y |
| |
| c0 = x > y |
| c1 = x < y |
| c2 = x >= y |
| c3 = x <= y |
| |
| tuple_expr = ((x, x), y) |
| t0 = tuple_expr[0] |
| t1 = tuple_expr[1] |
| t2 = tuple_expr[0][0] # <= Will normalize to two bindings |
| return (a0, a1, a2, a3, a4, a5, a6, c0, c1, c2, c3, t0, t1, t2) |
| |
| m = tir.Var("m", "int64") |
| n = tir.Var("n", "int64") |
| x = relax.Var("x", relax.TensorStructInfo([m, n], "float32")) |
| y = relax.Var("y", relax.TensorStructInfo([m, n], "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x, y)): |
| a0 = bb.emit(relax.op.negative(x)) |
| a1 = bb.emit(relax.op.add(x, y)) |
| a2 = bb.emit(relax.op.subtract(x, y)) |
| a3 = bb.emit(relax.op.multiply(x, y)) |
| a4 = bb.emit(relax.op.divide(x, y)) |
| a5 = bb.emit(relax.op.floor_divide(x, y)) |
| a6 = bb.emit(relax.op.power(x, y)) |
| |
| c0 = bb.emit(relax.op.greater(x, y)) |
| c1 = bb.emit(relax.op.less(x, y)) |
| c2 = bb.emit(relax.op.greater_equal(x, y)) |
| c3 = bb.emit(relax.op.less_equal(x, y)) |
| |
| tuple_expr = bb.emit(relax.Tuple((relax.Tuple((x, x)), y))) |
| t0 = bb.emit(relax.TupleGetItem(tuple_expr, 0)) |
| t1 = bb.emit(relax.TupleGetItem(tuple_expr, 1)) |
| tmp = bb.emit(relax.TupleGetItem(tuple_expr, 0)) |
| t2 = bb.emit(relax.TupleGetItem(tmp, 0)) |
| bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, a6, c0, c1, c2, c3, t0, t1, t2))) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_memory_ops(): |
| @R.function |
| def foo(x: R.Tensor(("m", "n"), dtype="float32")): |
| m = T.int64() |
| n = T.int64() |
| storage = R.memory.alloc_storage( |
| R.shape([4 * m * n]), virtual_device_index=0, storage_scope="global", dtype="float32" |
| ) |
| alloc = R.memory.alloc_tensor(storage, offset=0, shape=R.shape([m, n]), dtype="float32") |
| tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0) |
| gv = tensor |
| return alloc, gv |
| |
| _check(foo) |
| |
| |
| def test_vm_ops(): |
| @R.function(pure=False) |
| def foo(x: R.Tensor(("m", "n"), dtype="float32")): |
| m = T.int64() |
| n = T.int64() |
| storage = R.vm.alloc_storage(R.shape([4 * m * n]), runtime_device_index=0, dtype="uint8") |
| alloc = R.vm.alloc_tensor(storage, offset=0, shape=R.shape([m, n]), dtype="float32") |
| tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0) |
| tir_dym = R.vm.call_tir_dyn("te_func", (x, tensor, R.ShapeExpr((m, n)))) |
| return alloc, tir_dym |
| |
| _check(foo) |
| |
| |
| def test_builtin_ops(): |
| @R.function |
| def foo(x: R.Tensor(("m", "n"), dtype="float32")): |
| tensor = R.builtin.stop_lift_params(x) |
| gv = tensor |
| return gv |
| |
| _check(foo) |
| |
| |
| def test_prim_value(): |
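    # The bare integer argument 1 is parsed as a relax.PrimValue.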
| @R.function(pure=False) |
| def foo(): |
| gv = R.call_packed("test", 1, sinfo_args=R.Tensor((32, 32), "float32")) |
| return gv |
| |
| _check(foo) |
| |
| |
| def test_string_imm(): |
| @R.function(pure=False) |
| def foo(): |
| gv = R.call_packed("test", "hello", sinfo_args=R.Tensor((32, 32), "float32")) |
| return gv |
| |
| _check(foo) |
| |
| |
| def test_datatype_imm(): |
| @R.function(pure=False) |
| def foo(): |
| gv = R.call_packed("test", R.dtype("float32"), sinfo_args=R.Tensor((32, 32), "float32")) |
| return gv |
| |
| _check(foo) |
| |
| |
| def test_function_void_return_type(): |
| @tvm.script.ir_module |
| class Foo: |
| @R.function |
| def main(x: R.Tensor((3, 3), dtype="float32")): |
| res = Foo.mul(x) |
| return res |
| |
| @R.function |
| def mul(x: R.Tensor((3, 3), dtype="float32")): |
| res = R.multiply(x, x) |
| return res |
| |
| _check(Foo) |
| # Since the return type of function `mul` is not annotated, |
| # the function `main` regards it as a generic return type. |
| assert isinstance(Foo["main"].ret_struct_info, relax.ObjectStructInfo) |
| assert isinstance(Foo["mul"].ret_struct_info, relax.TensorStructInfo) |
| |
| @tvm.script.ir_module |
| class Bar: |
| @R.function |
| def main(x1: R.Tensor((3, 3), dtype="float32")): |
| res1 = Bar.mul(x1) |
| return res1 |
| |
| @R.function |
| def mul(x: R.Tensor((3, 3), dtype="float32")) -> None: |
| res = R.multiply(x, x) |
| return res |
| |
    # Since the return type of function `mul` is annotated as `None`,
    # both `main` and `mul` have an empty-tuple (void) return type.
| _check(Bar) |
| tvm.ir.assert_structural_equal(Bar["main"].ret_struct_info, relax.TupleStructInfo([])) |
| tvm.ir.assert_structural_equal(Bar["mul"].ret_struct_info, relax.TupleStructInfo([])) |
| |
| |
| def test_class_normalize(): |
| @tvm.script.ir_module |
| class InputModule: |
| @R.function |
| def mul_add(x: R.Tensor) -> R.Tensor: |
| return R.multiply(R.add(x, x), R.add(x, x)) |
| |
| # The parser automatically normalizes the input AST to the following ANF form |
| @tvm.script.ir_module |
| class OutputModule: |
| @R.function |
| def mul_add(x: R.Tensor) -> R.Tensor: |
| gv = R.add(x, x) |
| gv1 = R.add(x, x) |
| return R.multiply(gv, gv1) |
| |
| _check(InputModule, OutputModule) |
| |
| |
| def test_context_aware_parsing(monkeypatch): |
| @tvm.script.ir_module |
| class Module: |
| @T.prim_func |
| def add( |
| X: T.Buffer([T.int64(2), T.int64(4)], "float32"), |
| Y: T.Buffer((), "float32"), |
| Z: T.Buffer([T.int64(2), T.int64(4)], "float32"), |
| ): |
| T.evaluate(0) |
| |
| @R.function |
| def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): |
| R.func_attr({"relax.force_pure": True}) |
| cls = Module |
| alloc = R.builtin.alloc_tensor(R.shape([2, 4]), dtype="float32", runtime_device_index=0) |
| _: R.Tuple() = cls.add(x, R.const(1, "float32"), alloc) |
| return alloc |
| |
| _check(Module) |
| |
    # Break the environment so that calling a GlobalVar raises, then verify
    # that context-aware parsing still succeeds.
| def _break_env(self, *args): |
| raise RuntimeError("Fail to pass context-aware parsing") |
| |
| monkeypatch.setattr(tvm.ir.GlobalVar, "__call__", _break_env) |
| |
| _check(Module) |
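
    # Parsing still succeeds, which shows that the parser does not rely on
    # `GlobalVar.__call__` to resolve the `cls.add(...)` call.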
| |
| |
| def test_unit_tuple_on_rhs_of_assign(): |
| @I.ir_module |
| class Module: |
| @R.function |
| def main(input: R.Tensor((5, 5))) -> R.Tuple(R.Tensor((5, 5))): |
| gv = (input,) |
| return gv |
| |
| _check(Module) |
| |
| |
| def test_empty_tuple_on_rhs_of_assign(): |
| @I.ir_module |
| class Module: |
| @R.function |
| def main(input: R.Tensor((5, 5))) -> R.Tuple(): |
| gv = () |
| return gv |
| |
| _check(Module) |
| |
| |
| def test_global_var_sinfo(): |
| @I.ir_module |
| class Module: |
| @R.function |
| def foo(x: R.Tensor((128, 128), "float32")): |
| gv0 = R.emit_te(topi.add, x, x) |
| return gv0 |
| |
| target_sinfo = R.Callable( |
| (R.Tensor((128, 128), dtype="float32"),), R.Tensor((128, 128), dtype="float32") |
| ) |
| gv = Module.get_global_var("foo") |
| tvm.ir.assert_structural_equal(gv.struct_info, target_sinfo) |
| tvm.ir.assert_structural_equal(Module["foo"].struct_info, target_sinfo) |
| _check(Module) |
| |
| |
| def test_assert_op(): |
| @I.ir_module |
| class AssertOp: |
| @R.function(pure=False) |
| def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}") |
| return x |
| |
| _check(AssertOp) |
| |
| |
| def test_assert_outside_of_class(): |
| @R.function(pure=False) |
| def func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}") |
| return x |
| |
    # This just makes sure that the purity-attribute machinery also parses
    # when the function is defined outside of a class.
| _check(func) |
| |
| |
| def test_impure_inner_function(): |
| @R.function |
| def f(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| # we will not actually call it |
| @R.function(pure=False) |
| def g(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| z = R.assert_op(R.const(False, dtype="bool"), y, format="y: {}") |
| return y |
| |
| return x |
| |
| assert f.is_pure |
| # definition of g |
| assert not f.body.blocks[0].bindings[0].value.is_pure |
| |
    # make sure purity state is not incorrectly carried over from the inner function
| _check(f) |
| |
| |
| def test_impure_inner_function_in_class(): |
| @I.ir_module |
| class ImpureInner: |
| @R.function |
| def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| # we will not actually call it |
| @R.function(pure=False) |
| def g(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| z = R.assert_op(R.const(False, dtype="bool"), y, format="y: {}") |
| return y |
| |
| return x |
| |
| assert ImpureInner["main"].is_pure |
| # definition of g |
| assert not ImpureInner["main"].body.blocks[0].bindings[0].value.is_pure |
| |
    # make sure purity state is not incorrectly carried over from the inner function
| _check(ImpureInner) |
| |
| |
| def test_print(): |
| @I.ir_module |
| class Print: |
| @R.function(pure=False) |
| def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| y = R.print(x, format="x: {}") |
| return x |
| |
| _check(Print) |
| |
| |
| def test_parse_multiple_pure_and_impure_funcs(): |
| @I.ir_module |
| class Mixture: |
| @R.function(pure=False) |
| def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| y = R.print(x, format="x: {}") |
| return x |
| |
| @R.function(pure=False) |
| def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}") |
| return x |
| |
| @R.function |
| def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| return x |
| |
| assert not Mixture["print"].is_pure |
| assert not Mixture["assert_func"].is_pure |
| assert Mixture["main"].is_pure |
| _check(Mixture) |
| |
| |
| def test_function_with_void_return_type_may_be_used_as_statements(): |
| """Void return of calls do not need to be assigned""" |
| |
| @I.ir_module |
| class Unsugared: |
| @R.function(pure=False) |
| def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| y = R.print(x, format="x: {}") |
| return x |
| |
| @R.function(pure=False) |
| def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}") |
| return x |
| |
| @I.ir_module |
| class Sugared: |
| @R.function(pure=False) |
| def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| R.print(x, format="x: {}") |
| return x |
| |
| @R.function(pure=False) |
| def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| R.assert_op(R.const(False, dtype="bool"), x, format="x: {}") |
| return x |
| |
| tvm.ir.assert_structural_equal(Unsugared, Sugared) |
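
    # Structural equality maps variables by position rather than by name,
    # so the implicit binding the parser creates for each bare void call
    # matches the explicit `y = ...` binding in the unsugared form.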
| |
| |
| def test_function_with_non_void_return_type_must_be_assigned(): |
| """Non-void results must be assigned to a variable""" |
| |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function(pure=False) |
| def func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| R.add(x, x) |
| return x |
| |
| |
| def test_function_with_void_return_type_in_if_else(): |
| """Last statement in if/else may be a void return""" |
| |
| @I.ir_module |
| class Unsugared: |
| @R.function(pure=False) |
| def conditional( |
| x: R.Tensor((), "int32"), condition: R.Tensor((), "bool") |
| ) -> R.Tensor((), "int32"): |
| if condition: |
| y = R.print(x, format="True condition: {}") |
| else: |
| y = R.print(x, format="False condition: {}") |
| return x |
| |
| @I.ir_module |
| class Sugared: |
| @R.function(pure=False) |
| def conditional( |
| x: R.Tensor((), "int32"), condition: R.Tensor((), "bool") |
| ) -> R.Tensor((), "int32"): |
| if condition: |
| R.print(x, format="True condition: {}") |
| else: |
| R.print(x, format="False condition: {}") |
| return x |
| |
| _check(Sugared, Unsugared) |
| |
| |
| def test_call_pure_packed(): |
| @R.function |
| def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| z = R.call_pure_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32")) |
| return z |
| |
| x = relax.Var("x", R.Tensor((32, 32), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("foo", (x)): |
| z = bb.emit( |
| R.call_pure_packed("vm.builtin.copy", x, sinfo_args=[R.Tensor((32, 32), "float32")]) |
| ) |
| bb.emit_func_output(z) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_call_pure_packed_returning_object(): |
| @R.function |
| def foo() -> R.Object: |
| z = R.call_pure_packed("dummy_func", sinfo_args=R.Object) |
| return z |
| |
| bb = relax.BlockBuilder() |
| with bb.function("foo", params=[]): |
| z = bb.emit(R.call_pure_packed("dummy_func", sinfo_args=[relax.ObjectStructInfo()])) |
| bb.emit_func_output(z) |
| |
| _check(foo, bb.get()["foo"]) |
| |
| |
| def test_private_function(): |
| @I.ir_module |
| class Addition: |
| @R.function(private=True) |
| def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| y = R.add(x, x) |
| return y |
| |
| x = relax.Var("x", R.Tensor((), "int32")) |
| bb = relax.BlockBuilder() |
| with bb.function("main", (x), private=True): |
| y = bb.emit(R.add(x, x)) |
| bb.emit_func_output(y) |
| |
| _check(Addition, bb.get()) |
| |
| |
| def test_private_function_with_global_symbol_fail(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @I.ir_module |
| class Addition: |
| @R.function(private=True) |
| def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| # it is an error to simultaneously mark a function private |
| # and give it a global symbol manually |
| R.func_attr({"global_symbol": "main"}) |
| y = R.add(x, x) |
| return y |
| |
| # should not execute |
| _check(Addition) |
| |
| |
| def test_private_function_with_global_symbol_no_module_fail(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.function(private=True) |
| def func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): |
| R.func_attr({"global_symbol": "main"}) |
| y = R.add(x, x) |
| return y |
| |
| # should not execute |
| _check(func) |
| |
| |
| def test_macro_hygienic(): |
| x = R.prim_value(2) |
| |
| @R.macro(hygienic=True) |
| def alloc_and_shape(dtype: str): |
| alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=x, dtype=dtype) |
| shape = R.shape_of(alloc) |
| return shape |
| |
| x = R.prim_value(1) |
| |
| @R.function(private=True) |
| def func(z: R.Tensor((4, 4), "float32")): |
| shape = alloc_and_shape(dtype="float32") |
| return shape |
| |
| @R.function(private=True) |
| def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]): |
| alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor( |
| R.shape([4, 4]), |
| R.dtype("float32"), |
| R.prim_value(2), # Make sure prim_value is 2 |
| ) |
| shape: R.Shape([4, 4]) = R.shape_of(alloc) |
| shape_1: R.Shape([4, 4]) = shape |
| return shape_1 |
| |
| _check(func, expect) |
| |
| |
| def test_macro_non_hygienic(): |
    global global_x_var  # The macro's name lookup only finds this variable if it is global
| |
| global_x_var = R.prim_value(2) |
| |
| @R.macro(hygienic=False) |
| def alloc_and_shape(dtype: str): |
| alloc = R.builtin.alloc_tensor( |
| R.shape([4, 4]), runtime_device_index=global_x_var, dtype=dtype |
| ) |
| shape = R.shape_of(alloc) |
| return shape |
| |
| global_x_var = R.prim_value(1) |
| |
| @R.function(private=True) |
| def func(z: R.Tensor((4, 4), "float32")): |
| shape = alloc_and_shape(dtype="float32") |
| return shape |
| |
| @R.function(private=True) |
| def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]): |
| alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor( |
| R.shape([4, 4]), |
| R.dtype("float32"), |
| R.prim_value(1), # Make sure prim_value is 1 |
| ) |
| shape: R.Shape([4, 4]) = R.shape_of(alloc) |
| shape_1: R.Shape([4, 4]) = shape |
| return shape_1 |
| |
| _check(func, expect) |
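
    # Taken together with test_macro_hygienic: a hygienic macro captures
    # `x` at macro definition time (prim_value 2), while a non-hygienic
    # macro re-resolves `global_x_var` at expansion time and therefore
    # sees the later value (prim_value 1).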
| |
| |
| def test_macro_no_variable_leak(): |
| with pytest.raises(tvm.error.DiagnosticError): |
| |
| @R.macro(hygienic=True) |
| def add_two(value): |
| x = value + R.const(1) # `x` defined in macro |
| y = x + R.const(1) |
| return y |
| |
| @R.function(private=True) |
| def func(t: R.Tensor((), "int32")): |
| u = add_two(t) |
| return x # Should be undefined here |
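
    # Even though the macro body is inlined at the call site, the binding
    # for `x` stays local to the expansion, so referencing it afterwards
    # raises a diagnostic error rather than silently capturing the value.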
| |
| |
| def test_reused_extern_func(): |
| """ExternFunc lookups can become bindings in EliminateCommonSubexpr""" |
| |
| @R.function(private=True) |
| def parsed(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): |
| func = R.ExternFunc("extern_func") |
| gv0 = R.call_dps_packed(func, x, R.Tensor((128, 128), dtype="float32")) |
| gv1 = R.call_dps_packed(func, gv0, R.Tensor((128, 128), dtype="float32")) |
| return gv1 |
| |
| x = relax.Var("x", R.Tensor((128, 128), "float32")) |
| bb = relax.BlockBuilder() |
| with bb.function("main", [x], private=True): |
| func = bb.emit(relax.ExternFunc("extern_func")) |
| y = bb.emit(relax.call_dps_packed(func, x, out_sinfo=R.Tensor((128, 128), "float32"))) |
| z = bb.emit(relax.call_dps_packed(func, y, out_sinfo=R.Tensor((128, 128), "float32"))) |
| bb.emit_func_output(z) |
| |
| expected = bb.get()["main"] |
| |
| _check(parsed, expected) |
| |
| |
| def test_extern_func_in_module(): |
| """Module-level parsing may produce function bindings""" |
| |
| @I.ir_module |
| class parsed_module: |
| my_ext = R.ExternFunc("my_ext") |
| |
| @R.function |
| def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): |
| return a |
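
    # A structurally identical function defined at test scope (outside the
    # module above), used below to construct the expected IRModule.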
| |
| @R.function |
| def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): |
| return a |
| |
| expected = tvm.IRModule({"my_ext": relax.ExternFunc("my_ext"), "func": func}) |
| |
| _check(parsed_module, expected) |
| |
| |
| def test_define_relax_function_using_global_var(): |
| """A @R.function may call a GlobalVar |
| |
| When parsing a @R.function, the function's body may reference |
| GlobalVar instances available in the calling python scope. The |
| resulting function should pass TVMScript's well-formed check, as |
| the GlobalVar may be available in the IRModule for which the |
| function is being defined. |
| """ |
| |
| @I.ir_module |
| class DefinedAllAtOnce: |
| @R.function |
| def main(A: R.Tensor, B: R.Tensor): |
| return DefinedAllAtOnce.subroutine(A, B) |
| |
| @R.function(private=True) |
| def subroutine(A: R.Tensor, B: R.Tensor) -> R.Tensor: |
| return R.matmul(A, B) |
| |
| @I.ir_module |
| class MainDefinedLater: |
| @R.function(private=True) |
| def subroutine(A: R.Tensor, B: R.Tensor) -> R.Tensor: |
| return R.matmul(A, B) |
| |
| subroutine_gvar = MainDefinedLater.get_global_var("subroutine") |
| |
| @R.function |
| def main(A: R.Tensor, B: R.Tensor): |
| return subroutine_gvar(A, B) |
| |
| MainDefinedLater["main"] = main |
| |
| tvm.ir.assert_structural_equal(DefinedAllAtOnce, MainDefinedLater) |
| |
| |
| def test_function_attributes_are_defined(): |
| """func.attrs defaults to an empty DictAttrs""" |
| |
| @I.ir_module |
| class Module: |
| @R.function |
| def main(x: R.Tensor, shape: R.Shape(["m", "n"])): |
| output = Module.subroutine(x, shape) |
| return output |
| |
| @R.function |
| def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]): |
| q = x |
| m, n = T.int64(), T.int64() |
| z = R.match_cast(q, R.Tensor((m, n))) |
| w = z |
| return w |
| |
| for gvar, func in Module.functions.items(): |
| assert func.attrs is not None |
| |
| |
| @pytest.mark.xfail(reason="Bug: Implicit bounds not provided when parsing") |
| def test_function_symbolic_variables_are_annotated(): |
| """Symbolic variables must be exposed for struct inference |
| |
| Because Relax struct inference is performed while the function is |
| being built, all constraints on symbolic variables that are used |
| for simplifications must be provided to the analyzer. |
| """ |
| |
| @R.function(private=True) |
| def inferred_sinfo(A: R.Tensor(["extent"])): |
| extent = T.int64() |
| output = R.strided_slice(A, [0], [0], [extent - 1]) |
| return output |
| |
| @R.function(private=True) |
| def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent-1"]): |
| extent = T.int64() |
| output: R.Tensor([extent - 1]) = R.strided_slice(A, [0], [0], [extent - 1]) |
| return output |
| |
| tvm.ir.assert_structural_equal(inferred_sinfo, expected) |
| |
| |
| def test_conditional_may_use_symbolic_variables_from_function_scope(): |
| """Symbolic variables from function scope may be used in branch |
| |
| This is a regression test. In earlier implementations, the |
| branches of `relax::If` were normalized with |
| `EraseToWellDefinedInScope`, using a fresh variable scope. While |
| this had the intended behavior of preventing variables defined in |
| a single branch from being usable outside of the conditional, it |
| also caused the conditional's branches to treat function-scope |
| symbolic variables as if they were undefined. |
| |
| """ |
| |
| @R.function(private=True) |
| def explicit_sinfo( |
| A: R.Tensor(["N"], "float32"), |
| B: R.Tensor(["N"], "float32"), |
| cond: R.Prim("bool"), |
| ) -> R.Tensor(["N"], "float32"): |
| N = T.int64() |
| |
| if cond: |
| out: R.Tensor([N], "float32") = A + B |
| else: |
| out: R.Tensor([N], "float32") = A * B |
| |
| return out |
| |
| @R.function(private=True) |
| def inferred_sinfo( |
| A: R.Tensor(["N"], "float32"), |
| B: R.Tensor(["N"], "float32"), |
| cond: R.Prim("bool"), |
| ): |
| N = T.int64() |
| if cond: |
| out = A + B |
| else: |
| out = A * B |
| |
| return out |
| |
| tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) |
| |
| |
| def test_return_from_dataflow_block(): |
| """Return statements imply |
| |
| The `R.output` statement in a `R.dataflow()` block marks a |
| variable that should be a `relax.Var` instead of a |
| `relax.DataflowVar`, allowing it to be used outside of the |
| `DataflowBlock` that defined it. A relax function's output is not |
| part of any binding, and must not contain any `DataflowVar`, so |
| these are exposed implicitly. |
| |
| """ |
| |
| @R.function(private=True) |
| def output_then_return(A: R.Tensor([16], "float16")): |
| with R.dataflow(): |
| B = R.add(A, A) |
| C = R.multiply(B, B) |
| R.output(C) |
| |
| return C |
| |
| @R.function(private=True) |
| def return_inside_dataflow(A: R.Tensor([16], "float16")): |
| with R.dataflow(): |
| B = R.add(A, A) |
| C = R.multiply(B, B) |
| return C |
| |
| tvm.ir.assert_structural_equal(output_then_return, return_inside_dataflow) |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |