# Source blob: e3274aea886a49dad67cab6a1cce6ee44eebbd64
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
from tvm import relax
import tvm.script
from tvm.script import ir as I, tir as T, relax as R
def test_to_non_dataflow():
    """Check how ToNonDataflow treats the variables of a dataflow block.

    The relax vars reachable from ``foo`` are collected before and after the
    pass. The function parameter and the block output must be the very same
    vars afterwards, while the dataflow-local var must have been replaced by
    a var that is no longer a ``relax.DataflowVar``.
    """

    @tvm.script.ir_module
    class TestToNonDataflow:
        @R.function
        def foo(x: R.Tensor(("m", "n"), "float32")):
            m, n = T.int64(), T.int64()
            with R.dataflow():
                lv0 = R.call_dps_packed(
                    "test.op.identity",
                    (x,),
                    R.Tensor(
                        (m, n),
                        dtype="float32",
                    ),
                )
                gv0 = R.call_dps_packed(
                    "test.op.identity",
                    (lv0,),
                    R.Tensor(
                        (m, n),
                        dtype="float32",
                    ),
                )
                R.output(gv0)
            return gv0

    def collect_vars(func):
        # Gather every relax.Var in the order post_order_visit reports them.
        found = []

        def on_visit(expr):
            if isinstance(expr, relax.Var):
                found.append(expr)

        relax.analysis.post_order_visit(func, on_visit)
        return found

    before = TestToNonDataflow
    x, lv0, gv0 = collect_vars(before["foo"])

    after = relax.transform.ToNonDataflow()(before)
    rewritten = collect_vars(after["foo"])

    # The function parameter is untouched by the pass.
    assert x == rewritten[0]

    # The dataflow-local binding is replaced by a non-dataflow var.
    assert lv0 != rewritten[1]
    assert isinstance(lv0, relax.DataflowVar)
    assert not isinstance(rewritten[1], relax.DataflowVar)

    # The block output was already an ordinary var and is preserved.
    assert isinstance(gv0, relax.Var)
    assert isinstance(rewritten[2], relax.Var)
    assert gv0 == rewritten[2]
def test_call_tir_rewrite():
    """Check that CallTIRRewrite lowers ``R.call_tir`` to explicit allocation.

    Before the pass the first binding of ``foo`` is a ``relax.call_tir`` call.
    Afterwards the first binding must be a ``relax.builtin.alloc_tensor`` call
    whose shape equals the shape in the original call's ``sinfo_args``, and
    the second binding must call the ``exp`` PrimFunc through a GlobalVar.
    """

    @tvm.script.ir_module
    class TestCallTIRRewrite:
        @T.prim_func
        def exp(A_handle: T.handle, B_handle: T.handle):
            m = T.int64()
            n = T.int64()
            A = T.match_buffer(A_handle, (m, n), "float32")
            B = T.match_buffer(B_handle, (m, n), "float32")
            T.evaluate(0)

        @R.function
        def foo(x: R.Tensor(("m", "n"), "float32")):
            # we expect RemovePurityChecking to have been used before this point
            R.func_attr({"relax.force_pure": True})
            m, n = T.int64(), T.int64()
            gv0 = R.call_tir(TestCallTIRRewrite.exp, (x,), R.Tensor((m, n), dtype="float32"))
            return gv0

    mod = TestCallTIRRewrite

    # before rewrite
    s0 = mod["foo"].body.blocks[0].bindings[0].value
    assert isinstance(s0, relax.Call)
    assert s0.op.name == "relax.call_tir"

    # after rewrite
    new_mod = relax.transform.CallTIRRewrite()(mod)
    func = new_mod["foo"]

    block = func.body.blocks[0]
    assert not isinstance(block, relax.DataflowBlock)

    # First binding: the output tensor is allocated explicitly.
    s1 = block.bindings[0].value
    assert isinstance(s1, relax.Call)
    assert s1.op.name == "relax.builtin.alloc_tensor"
    assert isinstance(s1.args[0], relax.ShapeExpr)
    tvm.ir.assert_structural_equal(s1.args[0], s0.sinfo_args[0].shape)

    # Second binding: direct call to the PrimFunc. The type check used to be a
    # bare expression statement that asserted nothing.
    s2 = block.bindings[1].value
    assert isinstance(s2.op, tvm.ir.expr.GlobalVar)
    assert s2.op.name_hint == "exp"
def test_transform_remove_purity_checking():
    """Compare the output of RemovePurityChecking against a hand-written module.

    In ``Expected``, every function that is pure in ``Before`` carries the
    ``relax.force_pure`` attribute (including the nested pure function),
    ``R.call_pure_packed`` becomes ``R.call_packed`` and
    ``R.invoke_pure_closure`` becomes ``R.invoke_closure``. The functions
    declared with ``pure=False`` are identical in both modules.
    """

    @tvm.script.ir_module
    class Before:
        @R.function
        def base(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.add(x, x)
            z = R.add(x, y)
            return z

        @R.function
        def use_call_pure_packed(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.add(x, x)
            z = R.call_pure_packed("vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32")))
            return z

        @R.function
        def use_invoke_pure_closure(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            closure = R.make_closure(Before.base, ())
            res = R.invoke_pure_closure(closure, (x,), sinfo_args=R.Tensor((), "int32"))
            return res

        @R.function(pure=False)
        def impure_func() -> R.Object:
            y = R.print(format="I am impure!")
            return y

        @R.function
        def nested_pure_func() -> R.Tensor((), "int32"):
            @R.function
            def nested(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
                y = R.add(x, x)
                q = R.call_pure_packed(
                    "vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32"))
                )
                return q

            z = R.const(1, dtype="int32")
            w = nested(z)
            return w

        @R.function(pure=False)
        def nested_impure_func() -> R.Tensor((), "int32"):
            @R.function(pure=False)
            def nested() -> R.Object:
                x = R.print(format="Oops!")
                return x

            y = R.const(1, dtype="int32")
            z = nested()
            return y

    @tvm.script.ir_module
    class Expected:
        @R.function
        def base(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            R.func_attr({"relax.force_pure": True})
            y = R.add(x, x)
            z = R.add(x, y)
            return z

        @R.function
        def use_call_pure_packed(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            R.func_attr({"relax.force_pure": True})
            y = R.add(x, x)
            # call_pure_packed in Before is expected as a plain call_packed here
            z = R.call_packed("vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32")))
            return z

        @R.function
        def use_invoke_pure_closure(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            R.func_attr({"relax.force_pure": True})
            closure = R.make_closure(Expected.base, ())
            # invoke_pure_closure in Before is expected as a plain invoke_closure here
            res = R.invoke_closure(closure, (x,), sinfo_args=R.Tensor((), "int32"))
            return res

        # Declared impure: expected to be identical to Before.impure_func.
        @R.function(pure=False)
        def impure_func() -> R.Object:
            y = R.print(format="I am impure!")
            return y

        @R.function
        def nested_pure_func() -> R.Tensor((), "int32"):
            R.func_attr({"relax.force_pure": True})

            # The nested pure function is expected to carry the attribute as well.
            @R.function
            def nested(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
                R.func_attr({"relax.force_pure": True})
                y = R.add(x, x)
                q = R.call_packed("vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32")))
                return q

            z = R.const(1, dtype="int32")
            w = nested(z)
            return w

        # Declared impure (outer and nested): expected to be identical to Before.
        @R.function(pure=False)
        def nested_impure_func() -> R.Tensor((), "int32"):
            @R.function(pure=False)
            def nested() -> R.Object:
                x = R.print(format="Oops!")
                return x

            y = R.const(1, dtype="int32")
            z = nested()
            return y

    new_mod = relax.transform.RemovePurityChecking()(Before)
    tvm.ir.assert_structural_equal(new_mod, Expected)
def test_call_dps_packed_rewrite():
    """Check that CallTIRRewrite also lowers ``R.call_dps_packed``.

    Before the pass the first binding of ``foo`` is a
    ``relax.call_dps_packed`` call. Afterwards the first binding must be a
    ``relax.builtin.alloc_tensor`` call whose shape equals the shape in the
    original call's ``sinfo_args``, and the second binding must call the
    packed function whose global symbol is ``"test.op.identity"``.
    """

    @tvm.script.ir_module
    class TestCallDPSPackedRewrite:
        @R.function
        def foo(x: R.Tensor(("m", "n"), "float32")):
            # we expect RemovePurityChecking to have been used before this point
            R.func_attr({"relax.force_pure": True})
            m, n = T.int64(), T.int64()
            gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
            return gv0

    mod = TestCallDPSPackedRewrite

    # before rewrite
    s0 = mod["foo"].body.blocks[0].bindings[0].value
    assert isinstance(s0, relax.Call)
    assert s0.op.name == "relax.call_dps_packed"

    # CallTIRRewrite also works for call_dps_packed
    new_mod = relax.transform.CallTIRRewrite()(mod)
    func = new_mod["foo"]

    block = func.body.blocks[0]
    assert not isinstance(block, relax.DataflowBlock)

    # First binding: the output tensor is allocated explicitly.
    s1 = block.bindings[0].value
    assert isinstance(s1, relax.Call)
    assert s1.op.name == "relax.builtin.alloc_tensor"
    assert isinstance(s1.args[0], relax.ShapeExpr)
    tvm.ir.assert_structural_equal(s1.args[0], s0.sinfo_args[0].shape)

    # Second binding: call to the packed function itself.
    s2 = block.bindings[1].value
    assert s2.op.global_symbol == "test.op.identity"
def test_call_tir_inplace_simple():
    """CallTIRRewrite on ``R.call_tir_inplace`` with a single in-place argument.

    The expected ``foo`` calls the PrimFunc directly on ``x`` and then binds
    the result to ``x`` itself; no tensor is allocated.
    """

    # simple case: one inplace argument
    @tvm.script.ir_module
    class Input:
        @T.prim_func
        def zeros(A: T.Buffer((2, 3), "int32")):
            # just overwrites A with 0s
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_zeros"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.writes(A[ax0, ax1])
                    A[ax0, ax1] = T.int32(0)

        @R.function
        def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
            # we expect RemovePurityChecking to have been used before this point
            R.func_attr({"relax.force_pure": True})
            gv0 = R.call_tir_inplace(Input.zeros, x, 0, R.Tensor((2, 3), dtype="int32"))
            return gv0

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def zeros(A: T.Buffer((2, 3), "int32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_zeros"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.writes(A[ax0, ax1])
                    A[ax0, ax1] = T.int32(0)

        @R.function
        def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
            R.func_attr({"relax.force_pure": True})
            # direct PrimFunc call mutating x, then x is rebound as the result
            _ = Expected.zeros(x)
            gv0 = x
            return gv0

    new_mod = relax.transform.CallTIRRewrite()(Input)
    # Only `foo` is compared. NOTE(review): map_free_vars=True is presumably
    # needed so each module's own `zeros` GlobalVar can be matched - confirm.
    tvm.ir.assert_structural_equal(Expected["foo"], new_mod["foo"], map_free_vars=True)
def test_call_tir_inplace_multiple_args():
    """CallTIRRewrite on ``R.call_tir_inplace`` with two in-place arguments.

    With in-place indices ``[0, 1]`` the expected ``foo`` calls the PrimFunc
    directly on ``(x, y, z)`` and returns the tuple ``(x, y)``; no tensor is
    allocated.
    """

    @tvm.script.ir_module
    class Input:
        @T.prim_func
        def copy(
            A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32")
        ):
            # copies the contents of C into A and B
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_zeros"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(C[ax0, ax1])
                    T.writes(A[ax0, ax1], B[ax0, ax1])
                    A[ax0, ax1] = C[ax0, ax1]
                    B[ax0, ax1] = C[ax0, ax1]

        @R.function
        def foo(
            x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32"), z: R.Tensor((2, 3), "int32")
        ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
            R.func_attr({"relax.force_pure": True})
            gv0 = R.call_tir_inplace(
                Input.copy,
                (x, y, z),
                [0, 1],
                [R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")],
            )
            return gv0

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def copy(
            A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32")
        ):
            # copies the contents of C into A and B
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_zeros"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(C[ax0, ax1])
                    T.writes(A[ax0, ax1], B[ax0, ax1])
                    A[ax0, ax1] = C[ax0, ax1]
                    B[ax0, ax1] = C[ax0, ax1]

        @R.function
        def foo(
            x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32"), z: R.Tensor((2, 3), "int32")
        ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
            R.func_attr({"relax.force_pure": True})
            # direct PrimFunc call, then the in-place arguments form the result tuple
            _ = Expected.copy(x, y, z)
            gv0 = (x, y)
            return gv0

    new_mod = relax.transform.CallTIRRewrite()(Input)
    # Only `foo` is compared between the two modules.
    tvm.ir.assert_structural_equal(Expected["foo"], new_mod["foo"], map_free_vars=True)
def test_call_tir_inplace_some_new():
    """CallTIRRewrite on ``R.call_tir_inplace`` mixing in-place and new outputs.

    With in-place indices ``[0, -1, -1]`` the expected ``foo`` allocates two
    tensors with ``R.builtin.alloc_tensor``, passes them to the PrimFunc after
    ``(x, y, z)``, and returns ``x`` together with the two allocated tensors.
    """

    @tvm.script.ir_module
    class Input:
        @T.prim_func
        def copy(
            A: T.Buffer((2, 3), "int32"),
            B: T.Buffer((2, 3), "int32"),
            C: T.Buffer((2, 3), "int32"),
            out1: T.Buffer((2, 3), "int32"),
            out2: T.Buffer((2, 3), "int32"),
        ):
            # copies the contents of C into A, out1, and out2
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_zeros"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(C[ax0, ax1])
                    T.writes(A[ax0, ax1], out1[ax0, ax1], out2[ax0, ax1])
                    A[ax0, ax1] = C[ax0, ax1]
                    out1[ax0, ax1] = C[ax0, ax1]
                    out2[ax0, ax1] = C[ax0, ax1]

        @R.function
        def foo(
            x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32"), z: R.Tensor((2, 3), "int32")
        ) -> R.Tuple(
            R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), dtype="int32")
        ):
            R.func_attr({"relax.force_pure": True})
            gv0 = R.call_tir_inplace(
                Input.copy,
                (x, y, z),
                [0, -1, -1],
                [
                    R.Tensor((2, 3), dtype="int32"),
                    R.Tensor((2, 3), dtype="int32"),
                    R.Tensor((2, 3), dtype="int32"),
                ],
            )
            return gv0

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def copy(
            A: T.Buffer((2, 3), "int32"),
            B: T.Buffer((2, 3), "int32"),
            C: T.Buffer((2, 3), "int32"),
            out1: T.Buffer((2, 3), "int32"),
            out2: T.Buffer((2, 3), "int32"),
        ):
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_zeros"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(C[ax0, ax1])
                    T.writes(A[ax0, ax1], out1[ax0, ax1], out2[ax0, ax1])
                    A[ax0, ax1] = C[ax0, ax1]
                    out1[ax0, ax1] = C[ax0, ax1]
                    out2[ax0, ax1] = C[ax0, ax1]

        @R.function
        def foo(
            x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32"), z: R.Tensor((2, 3), "int32")
        ) -> R.Tuple(
            R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), dtype="int32")
        ):
            R.func_attr({"relax.force_pure": True})
            # one allocation per -1 entry in the in-place index list
            gv0 = R.builtin.alloc_tensor(R.shape([2, 3]), "int32", R.prim_value(0))
            gv1 = R.builtin.alloc_tensor(R.shape([2, 3]), "int32", R.prim_value(0))
            _ = Expected.copy(x, y, z, gv0, gv1)
            # result tuple: the in-place argument followed by the fresh tensors
            gv2 = (x, gv0, gv1)
            return gv2

    new_mod = relax.transform.CallTIRRewrite()(Input)
    # Only `foo` is compared between the two modules.
    tvm.ir.assert_structural_equal(Expected["foo"], new_mod["foo"], map_free_vars=True)
def test_call_tir_inplace_repeated_input():
    """Repeating an in-place index must raise ``tvm.error.DiagnosticError``.

    The error is expected while the module is being defined, which is why the
    whole TVMScript class sits inside the ``pytest.raises`` block.
    """
    with pytest.raises(tvm.error.DiagnosticError):

        @tvm.script.ir_module
        class Input:
            @T.prim_func
            def func(
                A: T.Buffer((2, 3), "int32"),
                B: T.Buffer((2, 3), "int32"),
                C: T.Buffer((2, 3), "int32"),
            ):
                T.evaluate(0)

            @R.function
            def foo(
                x: R.Tensor((2, 3), "int32"),
                y: R.Tensor((2, 3), "int32"),
                z: R.Tensor((2, 3), "int32"),
            ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
                R.func_attr({"relax.force_pure": True})
                gv0 = R.call_tir_inplace(
                    Input.func,
                    (x, y, z),
                    # repeated 0 -> that's an error
                    [0, 0],
                    [R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")],
                )
                return gv0
def test_call_tir_inplace_all_new():
    """An in-place index of ``-1`` for the only output must raise an error.

    The ``tvm.error.DiagnosticError`` is expected while the module is being
    defined, which is why the whole TVMScript class sits inside the
    ``pytest.raises`` block.
    """
    with pytest.raises(tvm.error.DiagnosticError):

        @tvm.script.ir_module
        class Input:
            @T.prim_func
            def func(A: T.Buffer((2, 3), "int32")):
                T.evaluate(0)

            @R.function
            def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
                R.func_attr({"relax.force_pure": True})
                # cannot make the only output a fresh one
                gv0 = R.call_tir_inplace(Input.func, x, -1, R.Tensor((2, 3), dtype="int32"))
                return gv0
def test_inplace_mutation_with_tuple_argument_raises_error():
    """TIR PrimFuncs do not support Tuple arguments

    The `R.call_tir_inplace` operator must receive an in-line tuple of
    arguments, where each argument in the tuple may be expressed in
    TIR.  Here, `[[A]]` specifies a tuple of arguments, where the
    first argument is itself a tuple.  Since PrimFuncs do not support
    Tuple arguments, this is invalid.

    This is a regression test.  In previous implementations, this
    triggered a segfault rather than raising an exception.
    """
    # The error is expected while the module is being defined.
    with pytest.raises(tvm.error.DiagnosticError):

        @I.ir_module
        class Module:
            @R.function
            def main(A: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
                cls = Module
                gv1 = R.call_tir_inplace(
                    cls.multiply_by_two,
                    [[A]],
                    out_sinfo=R.Tensor((16,), dtype="float32"),
                    inplace_indices=[0],
                )
                return gv1

            @T.prim_func(private=True)
            def multiply_by_two(A: T.Buffer((16,), "float32")):
                for i in range(16):
                    A[i] = A[i] * T.float32(2)
def test_inplace_mutation_with_non_tensor_argument_raises_error():
    """In-place argument must be a tensor

    The `R.call_tir_inplace` operator must receive an in-line tuple of
    arguments, where each argument in the tuple may be expressed in
    TIR.  Here, the argument `A` is not a tensor.

    This is a regression test.  In previous implementations, this
    triggered a segfault rather than raising an exception.
    """
    # The error is expected while the module is being defined.
    with pytest.raises(tvm.error.DiagnosticError):

        @I.ir_module
        class Module:
            @R.function
            def main(A: R.Object):
                gv1 = R.call_tir_inplace(
                    Module.multiply_by_two,
                    [A],
                    out_sinfo=R.Tensor((16,), dtype="float32"),
                    inplace_indices=[0],
                )
                return gv1

            @T.prim_func(private=True)
            def multiply_by_two(A: T.Buffer((16,), "float32")):
                for i in range(16):
                    A[i] = A[i] * T.float32(2)
def test_inplace_mutation_with_incompatible_tensor_shape_raises_error():
    """In-place argument must have compatible shape

    The `R.call_tir_inplace` operator must receive an in-line tuple of
    arguments, where the shape of each in-place argument is compatible
    with the corresponding output.  Here, the shape of argument `A` is
    different than the output's shape (`[32]` as opposed to `[16]`).
    """
    # The error is expected while the module is being defined.
    with pytest.raises(tvm.error.DiagnosticError):

        @I.ir_module
        class Module:
            @R.function
            def main(A: R.Tensor([32], dtype="float32")):
                gv1 = R.call_tir_inplace(
                    Module.multiply_by_two,
                    [A],
                    out_sinfo=R.Tensor((16,), dtype="float32"),
                    inplace_indices=[0],
                )
                return gv1

            @T.prim_func(private=True)
            def multiply_by_two(A: T.Buffer((16,), "float32")):
                for i in range(16):
                    A[i] = A[i] * T.float32(2)
def test_inplace_mutation_with_incompatible_tensor_dtype_raises_error():
    """In-place argument must have compatible dtype

    The `R.call_tir_inplace` operator must receive an in-line tuple of
    arguments, where the dtype of each in-place argument is compatible
    with the corresponding output.  Here, the dtype of argument `A` is
    different than the output's dtype (`int32` as opposed to `float32`).
    """
    # The error is expected while the module is being defined.
    with pytest.raises(tvm.error.DiagnosticError):

        @I.ir_module
        class Module:
            @R.function
            def main(A: R.Tensor([16], dtype="int32")):
                gv1 = R.call_tir_inplace(
                    Module.multiply_by_two,
                    [A],
                    out_sinfo=R.Tensor((16,), dtype="float32"),
                    inplace_indices=[0],
                )
                return gv1

            @T.prim_func(private=True)
            def multiply_by_two(A: T.Buffer((16,), "float32")):
                for i in range(16):
                    A[i] = A[i] * T.float32(2)
if __name__ == "__main__":
    # Import the submodule explicitly: the visible imports of this file never
    # import `tvm.testing`, so relying on `tvm.testing` being reachable as an
    # attribute would depend on some other module having imported it first.
    import tvm.testing

    tvm.testing.main()