blob: fd7c70c6148cf2b02e5b6ed623f1420b42397c7b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
import tvm.testing
from tvm import relax as rx
from tvm import tir
from tvm.script import ir as I, relax as R, tir as T
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", R.Tensor([m, n], "float32"))
cond = rx.Var("cond", R.Tensor([], "bool"))
def build_function(blocks, params=[]):
"""Returns relax.function with given blocks"""
seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var)
func = rx.Function([x, cond] + params, seq_expr, R.Tensor("float32")).with_attr(
"global_symbol", "foo"
)
return func
def test_var():
# Error: Var gv0 is not defined
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
call_node = rx.op.add(x, gv0)
bindings = [rx.VarBinding(gv1, call_node)]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
# Error: Var gv0 is defined more than once
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
call_node = rx.op.add(x, x)
call_node2 = rx.op.multiply(x, x)
bindings = [rx.VarBinding(gv0, call_node), rx.VarBinding(gv0, call_node2)]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_dataflow_var():
# Error: DataflowVar lv0 is not defined
lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
call_node = rx.op.add(x, lv0)
bindings = [rx.VarBinding(gv0, call_node)]
blocks = [rx.DataflowBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
# Error: DataflowVar gv0 is defined more than once
lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
call_node = rx.op.add(x, x)
call_node2 = rx.op.multiply(x, x)
bindings = [rx.VarBinding(lv0, call_node), rx.VarBinding(lv0, call_node2)]
blocks = [rx.DataflowBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
# Error: DataflowVar lv0 is defined outside DataflowBlock
lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
call_node = rx.op.add(x, x)
bindings = [rx.VarBinding(lv0, call_node)]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
# Error: DataflowVar lv0 is used outside DataflowBlock
lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
call_node = rx.op.add(lv0, x)
bindings = [rx.VarBinding(lv0, call_node)]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_param_var():
v0 = rx.Var("v0", R.Tensor([m, n], "float32"))
v1 = rx.Var("v1", R.Tensor([m, n], "float32"))
v2 = rx.Var("v2", R.Tensor([m, n], "float32"))
bb = rx.BlockBuilder()
with bb.function("func1", [v0, v1]):
gv0 = bb.emit(rx.op.add(v0, v1))
bb.emit_func_output(gv0)
with bb.function("func2", [v0, v2]):
gv0 = bb.emit(rx.op.add(v2, v1))
bb.emit_func_output(gv0)
mod = bb.get()
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_global_var():
# Error: GlobalVar GlobalVar0 is not defined
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
globalvar = rx.GlobalVar("GlobalVar0")
call_node = rx.Call(
op=tvm.ir.Op.get("relax.call_tir"),
args=[globalvar, rx.Tuple([x]), rx.ShapeExpr([m, n])],
)
bindings = [rx.VarBinding(gv0, call_node)]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_symbolic_var():
# Error: Symbolic Var new_s is not defined
new_s = tir.Var("new_s", "int64")
gv0 = rx.Var("gv0", R.Tensor([m, new_s], "int64"))
call_node = rx.op.add(x, x)
bindings = [rx.VarBinding(gv0, call_node)]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_symbolic_var_across_functions():
# Error: Symbolic Var s presents across different functions
s = tir.Var("s", "int64")
v0 = rx.Var("v0", R.Tensor([5, s], "float32"))
v1 = rx.Var("v1", R.Tensor([s, 7], "float32"))
bb = rx.BlockBuilder()
with bb.function("func1", [v0]):
bb.emit_func_output(v0)
with bb.function("func2", [v1]):
bb.emit_func_output(v1)
mod = bb.get()
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_symbolic_var_invalid_type():
with pytest.raises(
tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64"
):
dim = tir.Var("dim", "float32")
y = rx.Var("y", R.Tensor([dim], "float32"))
gv0 = rx.Var("gv0", R.Tensor([dim], "float32"))
call_node = rx.op.add(y, y)
bindings = [rx.VarBinding(gv0, call_node)]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks, [y])
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_seq_expr():
# Error: SeqExpr in VarBinding
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
# build a SeqExpr
gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
call_node = rx.op.add(x, gv0)
_bindings = [rx.VarBinding(gv1, call_node)]
_blocks = [rx.BindingBlock(_bindings)]
_seq_expr = rx.SeqExpr(_blocks, gv1)
# build a Binding with the SeqExpr as value
bindings = [rx.VarBinding(gv0, _seq_expr)]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_recursive():
scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32")
gv0 = rx.Var("gv0", scalar_struct_info)
f = rx.Var("f", rx.FuncStructInfo([scalar_struct_info], scalar_struct_info))
ipt = rx.Var("ipt", scalar_struct_info)
x0 = rx.Var("x0", scalar_struct_info)
x1 = rx.Var("x1", scalar_struct_info)
x2 = rx.Var("x2", scalar_struct_info)
y = rx.Var("y", scalar_struct_info)
inner_block = rx.BindingBlock(
[rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, rx.Call(f, [x0]))]
)
inner_func = rx.Function([ipt], rx.SeqExpr([inner_block], y), scalar_struct_info)
outer_block = rx.BindingBlock(
[
rx.VarBinding(f, inner_func),
rx.VarBinding(x1, rx.const(1, "int32")),
rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, [x1]))),
rx.VarBinding(gv0, x2),
]
)
func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info)
mod = tvm.IRModule.from_expr(func)
normalized = rx.transform.Normalize()(mod)
assert rx.analysis.well_formed(normalized)
def test_if():
# Error: Var defined in true/false branch is invisible in the outer scope
# except the return Var, i.e the var in the last stmt
# v_in_if is invisible in the outer scope
v_in_if = rx.Var("v_in_if", R.Tensor([m, n], "float32"))
# gv0 is visible in the outer scope
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
# build true branch
true_bindings = [
rx.VarBinding(v_in_if, rx.op.add(x, x)),
rx.VarBinding(gv0, rx.op.multiply(x, x)),
]
true_blocks = [rx.BindingBlock(true_bindings)]
true_seq_expr = rx.SeqExpr(true_blocks, true_blocks[-1].bindings[-1].var)
# build false branch
false_bindings = [
rx.VarBinding(v_in_if, rx.op.multiply(x, x)),
rx.VarBinding(gv0, rx.op.add(x, x)),
]
false_blocks = [rx.BindingBlock(false_bindings)]
false_seq_expr = rx.SeqExpr(false_blocks, false_blocks[-1].bindings[-1].var)
# build If node
if_node = rx.If(cond=cond, true_branch=true_seq_expr, false_branch=false_seq_expr)
gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
# try to call v_in_if defined in the true/false branch
bindings = [rx.VarBinding(gv0, if_node), rx.VarBinding(gv1, v_in_if)]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=True)
def test_if_non_seq_body():
# Error: If node has a body that is not a seq node
if_node = rx.If(cond=cond, true_branch=x, false_branch=x)
blocks = [
rx.BindingBlock(
[
rx.VarBinding(
rx.Var("gv1", R.Tensor([m, n], "float32")),
if_node,
)
]
)
]
func = build_function(blocks)
mod = tvm.IRModule.from_expr(func)
assert not rx.analysis.well_formed(mod, check_struct_info=False)
# on the other hand, if they're wrapped in a seq node, it's fine
seq = rx.SeqExpr([], x)
new_if_node = rx.If(cond=cond, true_branch=seq, false_branch=seq)
new_blocks = [
rx.BindingBlock(
[
rx.VarBinding(
rx.Var("gv1", R.Tensor([m, n], "float32")),
new_if_node,
)
]
)
]
new_func = build_function(new_blocks)
new_mod = tvm.IRModule.from_expr(new_func)
# apply normalization to fill in struct_info_
normalized = rx.transform.Normalize()(new_mod)
assert rx.analysis.well_formed(normalized, check_struct_info=True)
def test_if_complex_condition():
# Error: If condition must be a leaf expression
cond_tuple = rx.Tuple([cond])
cond_idx = rx.TupleGetItem(cond_tuple, 0)
if_node = rx.If(cond_idx, rx.SeqExpr([], x), rx.SeqExpr([], x))
blocks = [
rx.BindingBlock(
[
rx.VarBinding(
rx.Var("gv1", R.Tensor([m, n], "float32")),
if_node,
)
]
)
]
func = build_function(blocks)
mod = tvm.IRModule.from_expr(func)
assert not rx.analysis.well_formed(mod, check_struct_info=False)
cond_var = rx.Var("q", R.Tensor([], "bool"))
new_if = rx.If(cond_var, rx.SeqExpr([], x), rx.SeqExpr([], x))
blocks = [
rx.BindingBlock(
[
rx.VarBinding(cond_var, cond_idx),
rx.VarBinding(
rx.Var("gv1", R.Tensor([m, n], "float32")),
new_if,
),
]
)
]
func = build_function(blocks)
mod = tvm.IRModule.from_expr(func)
# apply normalization to fill in struct_info_
normalized = rx.transform.Normalize()(mod)
assert rx.analysis.well_formed(normalized, check_struct_info=True)
def test_tuple_get_item_nested():
# Error: The tuple value in tuple get item must be a leaf expression
nested_tup = rx.Var(
"t", rx.TupleStructInfo([rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])])
)
double_idx = rx.TupleGetItem(rx.TupleGetItem(nested_tup, 0), 0)
ret_var = rx.Var("r", R.Tensor([], "int32"))
f = rx.Function(
[nested_tup],
rx.SeqExpr([rx.BindingBlock([rx.VarBinding(ret_var, double_idx)])], ret_var),
ret_struct_info=R.Tensor(ndim=0, dtype="int32"),
)
f = f.with_attr("global_symbol", "f")
mod = tvm.IRModule.from_expr(f)
assert not rx.analysis.well_formed(mod, check_struct_info=False)
# okay with an intermediate binding
first_idx = rx.TupleGetItem(nested_tup, 0)
idx_var = rx.Var("v", rx.TupleStructInfo([rx.TensorStructInfo([], "int32")]))
second_idx = rx.TupleGetItem(idx_var, 0)
new_f = rx.Function(
[nested_tup],
rx.SeqExpr(
[
rx.BindingBlock(
[rx.VarBinding(idx_var, first_idx), rx.VarBinding(ret_var, second_idx)]
)
],
ret_var,
),
ret_struct_info=R.Tensor(ndim=0, dtype="int32"),
)
new_f = new_f.with_attr("global_symbol", "new_f")
mod = tvm.IRModule.from_expr(new_f)
# normalize in order to fill in checked type
normalized = rx.transform.Normalize()(mod)
assert rx.analysis.well_formed(normalized, check_struct_info=True)
def test_complex_seq_body():
# Error: seq expr with a body that is not a leaf expression is not permitted
x = rx.Var("x", R.Tensor([], "int32"))
y = rx.Var("y", R.Tensor([], "int32"))
func = rx.Function(
[x, y],
rx.SeqExpr([], rx.op.add(x, y)),
R.Tensor(ndim=0, dtype="int32"),
).with_attr("global_symbol", "foo")
mod = tvm.IRModule.from_expr(func)
assert not rx.analysis.well_formed(mod, check_struct_info=False)
# but if the result is bound, then it's okay
z = rx.Var("z", R.Tensor([], "int32"))
new_func = rx.Function(
[x, y],
rx.SeqExpr(
[
rx.BindingBlock(
[
rx.VarBinding(
var=z,
value=rx.op.add(x, y),
)
]
)
],
z,
),
R.Tensor(ndim=0, dtype="int32"),
).with_attr("global_symbol", "foo")
new_mod = tvm.IRModule.from_expr(new_func)
# normalize in order to fill in checked type
normalized = rx.transform.Normalize()(new_mod)
assert rx.analysis.well_formed(normalized, check_struct_info=True)
def test_inline_prim_func():
# Error: inline prim_func is disallowed in Relax IR
x = rx.Var("x", R.Tensor([], "int32"))
y = rx.Var("y", R.Tensor([], "int32"))
new_func = rx.Function(
[],
rx.SeqExpr(
[
rx.BindingBlock(
[
rx.VarBinding(
var=x,
value=tir.PrimFunc([], tir.Evaluate(0)),
),
rx.VarBinding(
var=y,
value=rx.Call(
op=tvm.ir.Op.get("relax.call_tir"),
args=[
rx.GlobalVar("GlobalVar0"),
rx.Tuple([x, tir.PrimFunc([], tir.Evaluate(0))]),
rx.ShapeExpr([]),
],
),
),
]
)
],
y,
),
R.Tensor(ndim=0, dtype="int32"),
).with_attr("global_symbol", "foo")
new_mod = tvm.IRModule.from_expr(new_func)
assert not rx.analysis.well_formed(new_mod, check_struct_info=False)
def test_ANF():
# Error: Nested Call
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
call_node = rx.op.add(x, rx.op.add(x, x))
bindings = [rx.VarBinding(gv0, call_node)]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
# Error: Call Node in Tuple
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
bindings = [rx.VarBinding(gv0, rx.Tuple((x, rx.op.add(x, x))))]
blocks = [rx.BindingBlock(bindings)]
func = build_function(blocks)
mod = tvm.IRModule({rx.GlobalVar("foo"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_global_var_vs_gsymbol():
# Error: gsymbol "main1" not equals to the name in global var "main"
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
bindings = [rx.VarBinding(gv0, x)]
blocks = [rx.DataflowBlock(bindings)]
func = rx.Function(
[x],
rx.SeqExpr(blocks, gv0),
R.Tensor(ndim=2, dtype="float32"),
).with_attr("global_symbol", "main1")
mod = tvm.IRModule({rx.GlobalVar("main"): func})
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_nested_dataflow():
scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32")
gv0 = rx.Var("gv0", scalar_struct_info)
f = rx.DataflowVar("f", rx.FuncStructInfo([], scalar_struct_info))
x0 = rx.DataflowVar("x0", scalar_struct_info)
x1 = rx.DataflowVar("x1", scalar_struct_info)
x2 = rx.DataflowVar("x2", scalar_struct_info)
y = rx.Var("y", scalar_struct_info)
inner_block = rx.DataflowBlock([rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, x0)])
inner_func = rx.Function([], rx.SeqExpr([inner_block], y), scalar_struct_info)
outer_block = rx.DataflowBlock(
[
rx.VarBinding(x1, rx.const(1, "int32")),
rx.VarBinding(f, inner_func),
rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, []))),
rx.VarBinding(gv0, x2),
]
)
func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info)
mod = tvm.IRModule.from_expr(func)
normalized = rx.transform.Normalize()(mod)
assert rx.analysis.well_formed(normalized)
def test_sinfo_args_tir_var_used_before_define_call_packed():
# Error: Symbolic Var m1, n1 are not defined
m1 = tir.Var("m1", "int64")
n1 = tir.Var("n1", "int64")
call = R.call_packed("my_func", x, sinfo_args=R.Tensor((m1, n1), "float32"))
func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])])
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_sinfo_args_tir_var_used_before_define_call_tir():
# Error: Symbolic Var m1, n1 are not defined
m1 = tir.Var("m1", "int64")
n1 = tir.Var("n1", "int64")
call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32"))
func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])])
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert not rx.analysis.well_formed(mod, check_struct_info=False)
def test_sinfo_erase_to_well_formed():
# Error: The return sinfo contains undefined symbolic vars
"""
@R.function
def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtype="float32"):
m = T.int64()
n = T.int64()
gv = R.call_dps_packed("my_func", (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
return gv
"""
m1 = tir.Var("m1", "int64")
n1 = tir.Var("n1", "int64")
call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m, n), "float32"))
blocks = [rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]
seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var)
func = rx.Function([x], seq_expr, R.Tensor((m1, n1), "float32")).with_attr(
"global_symbol", "foo"
)
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert not rx.analysis.well_formed(mod)
def test_func_sinfo_well_formed():
@R.function
def foo():
@R.function
def local(x: R.Tensor(["m", "n"], "float32")):
return x
return local
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(foo))
assert rx.analysis.well_formed(mod)
def test_conditional_in_dataflow_block():
# error: not allowed to have a conditional inside a dataflow block
x = rx.Var("x", rx.TensorStructInfo([], dtype="int32"))
y = rx.Var("y", rx.TensorStructInfo([], dtype="int32"))
block = rx.DataflowBlock([rx.VarBinding(y, rx.If(rx.const(True, dtype="bool"), x, x))])
func = rx.Function([x], rx.SeqExpr([block], y), R.Tensor((), dtype="int32")).with_attr(
"global_symbol", "foo"
)
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert not rx.analysis.well_formed(mod)
def test_unlabeled_impure():
x = rx.Var("x", R.Tensor((), dtype="int32"))
y = rx.Var("y")
block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
# print is impure, but the function is not labeled as impure
func = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32")).with_attr(
"global_symbol", "foo"
)
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert not rx.analysis.well_formed(mod)
def test_labeled_impure():
# the function is labeled impure so the impure operation is permitted
x = rx.Var("x", R.Tensor((), dtype="int32"))
y = rx.Var("y")
block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
# print is impure, but the function is not labeled as impure
func = rx.Function(
[x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32"), is_pure=False
).with_attrs({"global_symbol": "foo"})
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert rx.analysis.well_formed(mod)
def test_force_pure():
x = rx.Var("x", R.Tensor((), dtype="int32"))
y = rx.Var("y")
block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
# print is impure, but force_pure overrides the judgment
func = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32")).with_attrs(
{"global_symbol": "foo", "relax.force_pure": True}
)
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert rx.analysis.well_formed(mod)
def test_force_pure_improper():
# we require both the is_pure and force_pure flags to be set together
x = rx.Var("x", R.Tensor((), dtype="int32"))
# otherwise inoffensive, but the flags are wrong
func = rx.Function(
[x], rx.SeqExpr([], x), R.Tensor((), dtype="int32"), is_pure=False
).with_attrs({"global_symbol": "foo", "relax.force_pure": True})
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert not rx.analysis.well_formed(mod)
def test_impure_in_dataflow_block(capfd):
# even if force_pure is set, an impure operation cannot appear in a dataflow block
x = rx.Var("x", R.Tensor((), dtype="int32"))
y = rx.DataflowVar("y")
block = rx.DataflowBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
func = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32")).with_attrs(
{"global_symbol": "foo", "relax.force_pure": True}
)
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert not rx.analysis.well_formed(mod)
_stdout, stderr = capfd.readouterr()
assert "R.print" in stderr
def test_well_formed_function():
"""Relax's well-formed check can be applied on a function"""
@R.function
def func(A: R.Tensor([16, 32], "float32"), B: R.Tensor([32, 64], "float32")):
return R.matmul(A, B)
assert rx.analysis.well_formed(func)
def test_well_formed_function_referencing_global_var():
"""GlobalVar may refer to other functions in the module
If validating that a IRModule is well-formed, the GlobalVar must
have a definition. If validating that a relax.Function is
well-formed, no GlobalVar definitions are available.
"""
@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([16, 32], "float32"), B: R.Tensor([32, 64], "float32")):
return Module.subroutine(A, B)
@R.function(private=True)
def subroutine(A: R.Tensor([16, 32], "float32"), B: R.Tensor([32, 64], "float32")):
return R.matmul(A, B)
assert rx.analysis.well_formed(Module)
assert rx.analysis.well_formed(Module["main"])
assert rx.analysis.well_formed(Module["subroutine"])
def test_pass_dltensor_arg_to_tir():
"""Relax may pass R.Tensor as DLTensor
In TIR, a `DLTensor*` argument with unknown shape and dtype is
represented as a `tir.Var` with
`tvm::PrimType(DataType::Handle())`, and with no entry in the
`PrimFuncNode::buffer_map`. In Relax, this is represented as
`R.Tensor`. Calls from Relax to TIR that pass a tensor of unknown
rank/shape are well-formed.
In the test case below, a TIR function accepts an arbitrary
`R.Tensor`, and returns a boolean value based on inspection of the
runtime datatype.
"""
@I.ir_module
class Module:
@R.function
def main(A: R.Tensor) -> R.Prim("bool"):
return Module.is_bfloat16_dtype(A)
@T.prim_func(private=True)
def is_bfloat16_dtype(tensor: T.handle) -> T.bool:
T.func_attr({"tir.is_scheduled": True, "tir.is_host_func": True})
# From #include <tvm/tir/builtin.h>
kArrTypeCode = T.meta_var(5)
kArrTypeBits = T.meta_var(6)
kArrTypeLanes = T.meta_var(7)
# From #include <dlpack/dlpack.h>
kDLBfloat = T.meta_var(4)
type_code = T.tvm_struct_get(tensor, 0, kArrTypeCode, dtype="uint8")
type_bits = T.tvm_struct_get(tensor, 0, kArrTypeBits, dtype="uint8")
type_lanes = T.tvm_struct_get(tensor, 0, kArrTypeLanes, dtype="uint16")
is_bfloat16: T.bool = (
(type_code == kDLBfloat) and (type_bits == 16) and (type_lanes == 1)
)
return is_bfloat16
assert rx.analysis.well_formed(Module)
def test_call_tir_with_matching_arguments():
"""R.call_tir is well-formed when called with matching arguments"""
@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([16], "float16")):
B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16"))
return B
@T.prim_func
def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
for i in range(16):
with T.block("compute"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi] + T.float16(1.0)
assert rx.analysis.well_formed(Module)
def test_call_tir_input_ndim():
"""Arguments to R.call_tir must have the correct dimensionality
Here, the `add_one` function expects a 1-d input tensor, but is
called with a 2-d tensor.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([4, 4], "float16")):
B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16"))
return B
@T.prim_func
def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
for i in range(16):
with T.block("compute"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi] + T.float16(1.0)
assert not rx.analysis.well_formed(Module)
def test_call_tir_output_ndim():
"""Output shape R.call_tir must have the correct dimensionality
Here, the `add_one` function requires a 1-d output tensor, but is
provided with a 2-d tensor.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([16], "float16")):
B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([4, 4], "float16"))
return B
@T.prim_func
def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
for i in range(16):
with T.block("compute"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi] + T.float16(1.0)
assert not rx.analysis.well_formed(Module)
def test_call_tir_input_shape():
"""Arguments to R.call_tir must have the correct shape
Here, the `add_one` function expects an input tensor with 16
elements, but is called with an input tensor with 32 elements.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([32], "float16")):
B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16"))
return B
@T.prim_func
def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
for i in range(16):
with T.block("compute"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi] + T.float16(1.0)
assert not rx.analysis.well_formed(Module)
def test_call_tir_output_shape():
"""Output shape R.call_tir must have the correct shape
Here, the `add_one` function requires an output tensor with 16
elements, but is provided an output tensor with 32 elements.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([16], "float16")):
B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([32], "float16"))
return B
@T.prim_func
def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
for i in range(16):
with T.block("compute"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi] + T.float16(1.0)
assert not rx.analysis.well_formed(Module)
def test_call_tir_input_dtype():
"""Arguments to R.call_tir must have the correct dtype
Here, the `add_one` function expects an input tensor containing
float16 value, but is called with an input tensor containing
float32 values.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([16], "float32")):
B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16"))
return B
@T.prim_func
def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
for i in range(16):
with T.block("compute"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi] + T.float16(1.0)
assert not rx.analysis.well_formed(Module)
def test_call_tir_output_dtype():
"""Output shape R.call_tir must have the correct shape
Here, the `add_one` function requires an output tensor that may be
populated with float16 values, but is provided an output tensor
that may be populated with float32 elements.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([16], "float16")):
B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float32"))
return B
@T.prim_func
def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
for i in range(16):
with T.block("compute"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi] + T.float16(1.0)
assert not rx.analysis.well_formed(Module)
def test_call_tir_with_correct_dynamic_output_shape():
"""Output shape R.call_tir may not be verifiable
Here, the input arguments to the `reshape` function are not
sufficient to infer the shape of the outputs. This is legal,
since the output shape is determined by the `out_sinfo` parameter.
Inability to verify the output shape does not mean that the output
shape is invalid.
"""
@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([16], "float16")):
B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 8], "float16"))
return B
@T.prim_func
def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle):
M = T.int64()
N = T.int64()
B = T.match_buffer(B_handle, [M, N], dtype="float16")
for i, j in T.grid(M, N):
with T.block("compute"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi * N + vj]
assert rx.analysis.well_formed(Module)
@pytest.mark.xfail(reason="Not supported")
def test_call_tir_with_incorrect_dynamic_output_shape():
"""Output shape R.call_tir may not be verifiable
Here, the input arguments to the `reshape` function are not
sufficient to infer the shape of the outputs. Even though the
IRModule will not provide well-defined output due to the
out-of-bounds read from buffer A, catching this error is beyond
the current scope of the Relax well-formed checker.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([16], "float16")):
B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([16, 16], "float16"))
return B
@T.prim_func
def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle):
M = T.int64()
N = T.int64()
B = T.match_buffer(B_handle, [M, N], dtype="float16")
for i, j in T.grid(M, N):
with T.block("compute"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi * N + vj]
assert not rx.analysis.well_formed(Module)
def test_call_tir_incorrect_dimensionality_of_output_shape():
"""Dimensionality may be verified
Here, the input arguments to the `reshape` function are not
sufficient to infer the shape of the outputs.
Even though the output shape may not be inferred from the input
arguments, the output dimensionality can still be inferred from
the PrimFunc signature. The IRModule below is ill-formed, because
the PrimFunc requires a 2-d output argument, but is provided with
a 3-d output argument.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([16], "float16")):
B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 4, 2], "float16"))
return B
@T.prim_func
def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle):
M = T.int64()
N = T.int64()
B = T.match_buffer(B_handle, [M, N], dtype="float16")
for i, j in T.grid(M, N):
with T.block("compute"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi * N + vj]
assert not rx.analysis.well_formed(Module)
@pytest.mark.xfail(reason="Not yet supported")
def test_call_tir_output_shape_with_mixed_static_and_dynamic():
"""Some dimensions of the R.call_tir output shape may be verifiable
Here, the input arguments to the `reshape` function are not
sufficient to infer the shape of the outputs. This is legal,
since the output shape is taken from the `out_sinfo` parameter.
Identifying this failure mode is not yet supported in the current
implementation. This is because the output is inferred as
`R.Tensor(ndim=3, dtype="float16")`, and the explicit `out_sinfo`
is a 3-d tensor. The mismatch in the first dimension is not yet
counted, because the entire tensor shape is removed by
`EraseToWellDefined`.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([256], "float16")):
B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([8, 16, 2], "float16"))
return B
@T.prim_func
def reshape(A: T.Buffer(256, "float16"), B_handle: T.handle):
M = T.int64()
N = T.int64()
B = T.match_buffer(B_handle, [16, M, N], dtype="float16")
for i, j, k in T.grid(16, M, N):
with T.block("compute"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
B[vi, vj, vk] = A[vi * N * M + vj * N + vk]
assert not rx.analysis.well_formed(Module)
def test_call_tir_with_correct_inferred_dynamic_output_shape():
"""Some dynamic output shapes of R.call_tir may be inferred
Here, the `flatten` function is dynamic, and will flatten any 2-d
TIR buffer. Even though it is dynamic, the input shapes are
sufficient to infer that `M==8` and `N==4`. As a result, the
output shape of `[M*N]` can be inferred to be `[32]`, and the
shape specified in `out_sinfo` can be validated.
"""
@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([8, 4], "float16")):
B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([32], "float16"))
return B
@T.prim_func
def flatten(A_handle: T.handle, B_handle: T.handle):
M = T.int64()
N = T.int64()
A = T.match_buffer(A_handle, [M, N], dtype="float16")
B = T.match_buffer(B_handle, [M * N], dtype="float16")
for i in T.grid(M * N):
with T.block("compute"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi // N, vi % N]
assert rx.analysis.well_formed(Module)
def test_call_tir_with_incorrect_inferred_dynamic_output_shape():
"""Some dynamic output shapes of R.call_tir may be inferred
Here, the `flatten` function is dynamic, and will flatten any 2-d
TIR buffer. Even though it is dynamic, the input shapes are
sufficient to infer that `M==8` and `N==4`. As a result, the
output shape of `[M*N]` can be inferred to be `[32]`, and the
shape specified in `out_sinfo` can be validated.
This unit test is identical to the above test
`test_call_tir_with_correct_inferred_dynamic_output_shape`, except
that the output shape is explicitly specified as `[64]`, which is
caught as a mismatch from the expected output shape.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([8, 4], "float16")):
B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([64], "float16"))
return B
@T.prim_func
def flatten(A_handle: T.handle, B_handle: T.handle):
M = T.int64()
N = T.int64()
A = T.match_buffer(A_handle, [M, N], dtype="float16")
B = T.match_buffer(B_handle, [M * N], dtype="float16")
for i in T.grid(M * N):
with T.block("compute"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi // N, vi % N]
assert not rx.analysis.well_formed(Module)
def test_call_tir_with_dtensor_arguments():
"""R.call_tir and R.dist.call_tir share the same operation
Both `R.call_tir` and `R.dist.call_tir` produce the same
"relax.call_tir" operation, differing only in the StructInfo of
their arguments. Normalization of "relax.call_tir" must handle
`R.DTensor` arguments.
"""
# from tvm.script.parser import relax as R
@I.ir_module
class Module:
I.module_attrs({"device_num": 4})
I.module_global_infos({"mesh": [R.dist.device_mesh([4], I.Range(0, 4))]})
@R.function
def main(A: R.dist.DTensor([8, 4], "float16", "mesh[0]", "S[0]")):
B = R.dist.call_tir(
Module.flatten, A, out_sinfo=R.dist.DTensor([64], "float16", "mesh[0]", "S[0]")
)
return B
@T.prim_func
def flatten(A_handle: T.handle, B_handle: T.handle):
M = T.int64()
N = T.int64()
A = T.match_buffer(A_handle, [M, N], dtype="float16")
B = T.match_buffer(B_handle, [M * N], dtype="float16")
for i in T.grid(M * N):
with T.block("compute"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi // N, vi % N]
assert rx.analysis.well_formed(Module)
def test_call_tir_inplace_with_correct_shapes():
"""R.call_tir_inplace is well-formed when called with matching arguments"""
@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([16], "float16")):
B = R.call_tir_inplace(
Module.add_one,
A,
inplace_indices=[0],
out_sinfo=R.Tensor([16], "float16"),
)
return B
@T.prim_func
def add_one(A: T.Buffer(16, "float16")):
for i in range(16):
with T.block("compute"):
vi = T.axis.remap("S", [i])
A[vi] = A[vi] + T.float16(1.0)
assert rx.analysis.well_formed(Module)
def test_call_tir_inplace_with_incorrect_shapes():
"""R.call_tir_inplace is ill-formed when output shape does not match input"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(A: R.Tensor([16], "float16")):
B = R.call_tir_inplace(
Module.add_one,
A,
inplace_indices=[0],
out_sinfo=R.Tensor([32], "float16"),
)
return B
@T.prim_func
def add_one(A: T.Buffer(16, "float16")):
for i in range(16):
with T.block("compute"):
vi = T.axis.remap("S", [i])
A[vi] = A[vi] + T.float16(1.0)
assert not rx.analysis.well_formed(Module)
def test_call_tir_inplace_with_some_allocated_outputs():
"""R.call_tir_inplace may contain some non-inplace outputs"""
@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([16], "float16"), B: R.Tensor([32], "float16")):
out = R.call_tir_inplace(
Module.add_one,
(A, B),
inplace_indices=[-1, 1],
out_sinfo=[
R.Tensor([16], "float16"),
R.Tensor([32], "float16"),
],
)
return out
@T.prim_func
def add_one(
A: T.Buffer(16, "float16"),
B: T.Buffer(32, "float16"),
C: T.Buffer(16, "float16"),
):
for i in range(32):
with T.block("inplace_B"):
vi = T.axis.remap("S", [i])
B[vi] = B[vi] + T.float16(1.0)
for i in range(16):
with T.block("output_C"):
vi = T.axis.remap("S", [i])
C[vi] = A[vi] + T.float16(1.0)
assert rx.analysis.well_formed(Module)
def test_var_binding_must_have_compatible_struct_info():
"""Variables must accurately describe their contents
To be well-formed, the inferred struct info must not conflict with
the StructInfo annotations.
"""
# The function is equivalent to the TVMScript below. However,
# TVMScript applies additional checks that would catch this error
# while parsing. In order to validate the well-formed checker
# itself, this test directly constructs the function withoutusing
# TVMScript, skipping the TVMScript-specific checks.
#
# @R.function
# def main(
# A: R.Tensor(shape=[128, 32], dtype="float32"),
# ):
# B: R.Tensor(shape=[128, 32], dtype="int32") = A
# return B
param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32"))
var = tvm.relax.Var("B", R.Tensor(shape=[128, 32], dtype="int32"))
binding = tvm.relax.VarBinding(var, param)
body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var)
tvm.relax.expr._update_struct_info(body, var.struct_info)
main = tvm.relax.Function([param], body)
assert not rx.analysis.well_formed(main)
def test_var_binding_may_have_less_constrained_struct_info():
"""StructInfo of variable may be less specific than expression
The StructInfo annotation of a variable is not required to be an
exact match to the expression's StructInfo, and may provide less
specific information than the inference would provide.
"""
@I.ir_module
class Module:
@R.function
def main(
A: R.Tensor(shape=[128, 32], dtype="float32"),
):
B: R.Object = R.add(A, A)
return B
assert isinstance(
Module["main"].body.blocks[0].bindings[0].var.struct_info, tvm.relax.ObjectStructInfo
), "Validity of this test requires a variable with R.Object struct info"
assert rx.analysis.well_formed(Module)
def test_var_binding_with_incomplete_struct_info_must_be_consistent():
"""StructInfo of variable must be accurate
Even though StructInfo annotation may be less specific, the
information that they do contain must be correct.
"""
# The function is equivalent to the TVMScript below. However,
# TVMScript applies additional checks that would catch this error
# while parsing. In order to validate the well-formed checker
# itself, this test directly constructs the function withoutusing
# TVMScript, skipping the TVMScript-specific checks.
#
# @R.function
# def main(
# A: R.Tensor(shape=[128, 32], dtype="float32"),
# ):
# B: R.Tensor(ndim=3) = A
# return B
param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32"))
var = tvm.relax.Var("B", R.Tensor(ndim=3, dtype="int32"))
binding = tvm.relax.VarBinding(var, param)
body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var)
tvm.relax.expr._update_struct_info(body, var.struct_info)
main = tvm.relax.Function([param], body)
assert not rx.analysis.well_formed(main)
def test_incomplete_struct_info_must_be_consistent():
"""StructInfo annotations must be accurate
Even though StructInfo annotation may be less specific, the
information that they do contain must be correct.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(
A: R.Tensor(shape=[128, 32], dtype="float32"),
B: R.Tensor(shape=[128, 32], dtype="float32"),
):
C: R.Tensor(ndim=3) = R.add(A, B)
return C
assert not rx.analysis.well_formed(Module)
def test_struct_info_annotations_must_be_correct():
"""StructInfo annotations must be correct
To be well-formed, the inferred struct info must not conflict with
the StructInfo annotations.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(
A: R.Tensor(shape=[128, 32], dtype="float32"),
B: R.Tensor(shape=[128, 32], dtype="float32"),
):
C: R.Tensor(shape=[128, 32], dtype="int32") = R.add(A, B)
return C
assert not rx.analysis.well_formed(Module)
def test_struct_info_may_be_incomplete():
"""StructInfo annotations may be less specific
The StructInfo annotations are not required to be an exact match
to the inferred StructInfo, and may provide less specific
information than the inference would provide.
"""
@I.ir_module
class Module:
@R.function
def main(
A: R.Tensor(shape=[128, 32], dtype="float32"),
B: R.Tensor(shape=[128, 32], dtype="float32"),
):
C: R.Object = R.add(A, B)
return C
assert rx.analysis.well_formed(Module)
def test_incomplete_struct_info_must_be_consistent():
"""StructInfo annotations must be accurate
Even though StructInfo annotation may be less specific, the
information that they do contain must be correct.
"""
@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(
A: R.Tensor(shape=[128, 32], dtype="float32"),
B: R.Tensor(shape=[128, 32], dtype="float32"),
):
C: R.Tensor(ndim=3) = R.add(A, B)
return C
assert not rx.analysis.well_formed(Module)
if __name__ == "__main__":
tvm.testing.main()