blob: a8232fbc8f7ddde15e4f59c4450af4e5f2b6ffae [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relax as rx
from tvm import tir
from tvm.script import relax as R
import pytest
def _check_equal(x, y, map_free_vars=False):
    """Assert that ``x`` and ``y`` are structurally equal in both directions
    and that their structural hashes agree.

    ``map_free_vars`` is forwarded unchanged to both the equality check and
    the hash computation.
    """
    # Check both argument orders so an asymmetric comparison cannot slip by.
    for lhs, rhs in ((x, y), (y, x)):
        tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars)
    lhs_hash = tvm.ir.structural_hash(x, map_free_vars)
    rhs_hash = tvm.ir.structural_hash(y, map_free_vars)
    assert lhs_hash == rhs_hash
def _check_json_roundtrip(x):
    """Serialize ``x`` to JSON and load it back, then check the result is
    structurally equal to ``x`` (with free variables mapped).

    Returns the deserialized object.
    """
    serialized = tvm.ir.save_json(x)
    restored = tvm.ir.load_json(serialized)
    _check_equal(x, restored, map_free_vars=True)
    return restored
def test_var() -> None:
    """Var stores its name hint and an optional struct-info annotation."""
    untyped = rx.Var("v0")
    assert untyped.name_hint == "v0"
    assert untyped.struct_info_ is None

    dims = [54, 96]
    typed = rx.Var("v1", R.Tensor(dims, "float32"))
    assert typed.name_hint == "v1"
    for actual, expected in zip(typed.struct_info.shape, dims):
        assert actual == expected
    expected_sinfo = rx.TensorStructInfo(dims, "float32")
    tvm.ir.assert_structural_equal(typed.struct_info, expected_sinfo)
def test_dataflow_var() -> None:
    """DataflowVar stores its name hint and optional struct info and keeps its type."""
    bare = rx.DataflowVar("v0")
    assert bare.name_hint == "v0"
    assert bare.struct_info_ is None

    dims = [54, 96]
    annotated = rx.DataflowVar("v1", R.Tensor(dims, "float16"))
    assert annotated.name_hint == "v1"
    assert isinstance(annotated, rx.DataflowVar)
    expected_sinfo = rx.TensorStructInfo(dims, "float16")
    tvm.ir.assert_structural_equal(annotated.struct_info, expected_sinfo)
def test_tuple() -> None:
    """Tuple exposes its fields and supports positive and negative indexing."""
    first = rx.Var("v0")
    second = rx.Var("v1")
    tup = rx.Tuple((first, second))
    assert tup.fields[0] == first
    assert tup.fields[1] == second

    # Python-style indexing, including negative indices.
    for index, expected in ((0, first), (1, second), (-1, second), (-2, first)):
        assert tup[index] == expected

    # Indices outside the two-element range raise IndexError.
    for bad_index in (2, -3):
        with pytest.raises(IndexError, match="Tuple index out of range"):
            tup[bad_index]
def test_tuple_sinfo_inferred_on_construction():
    """A Tuple whose fields all carry struct info gets a TupleStructInfo on construction."""
    fields = [rx.Var(name, rx.ObjectStructInfo()) for name in ("v0", "v1")]
    tup = rx.Tuple(tuple(fields))
    assert tup.struct_info_ is not None
    expected = rx.TupleStructInfo([rx.ObjectStructInfo() for _ in fields])
    tvm.ir.assert_structural_equal(tup.struct_info, expected)
def test_tuple_sinfo_requires_fields_with_known_sinfo():
    """A Tuple's struct info stays undefined when any field lacks struct info."""
    known = rx.Var("v0", rx.ObjectStructInfo())
    unknown = rx.Var("v1")
    assert rx.Tuple((known, unknown)).struct_info_ is None
def test_match_cast_fields() -> None:
    """MatchCast exposes its bound var, its value, and the shape of its target struct info.

    This test must not be named ``test_match_cast``: a second test with that
    name exists in this module and would shadow this one, so it would never run.
    """
    m = tir.Var("m", dtype="int64")
    n = tir.Var("n", dtype="int64")

    # match_cast of an int32 constant against R.Tensor((m, n), "int32")
    shape = rx.const([16, 8], "int32")
    var = rx.Var("v0", R.Shape())
    b0 = rx.MatchCast(var, shape, R.Tensor([m, n], "int32"))
    assert b0.value == shape
    # The symbolic dims are read from the target struct info's shape.
    assert b0.struct_info.shape[0] == m
    assert b0.struct_info.shape[1] == n
    assert b0.var.same_as(var)

    # var1: R.Tensor((m, n), "float32") =
    #     match_cast(var0: R.Tensor("float32", ndim=-1), R.Tensor((m, n), "float32"))
    value = rx.Var("value", R.Tensor("float32", ndim=-1))
    var = rx.Var("v1", R.Tensor([m, n], "float32"))
    b1 = rx.MatchCast(var, value, R.Tensor([m, n], "float32"))
    assert b1.value == value
    assert b1.struct_info.shape[0] == m
    assert b1.struct_info.shape[1] == n
    assert b1.var.same_as(var)
def test_match_cast() -> None:
    """MatchCast keeps its value and struct info and survives a JSON round trip."""
    m = tir.Var("m", dtype="int64")
    n = tir.Var("n", dtype="int64")
    input_value = rx.Var("input_value")
    target_sinfo = rx.TensorStructInfo([n, m], "float32")
    binding = rx.MatchCast(rx.Var("v"), input_value, target_sinfo)
    assert binding.value.same_as(input_value)
    assert binding.struct_info == target_sinfo
    _check_json_roundtrip(binding)
def test_var_binding() -> None:
    """VarBinding records the bound variable and its value."""
    bound = rx.Var("v0")
    const_value = rx.const(np.random.rand(24, 56))
    binding = rx.VarBinding(bound, const_value)
    assert binding.var.name_hint == "v0"
    assert binding.value == const_value
def test_binding_block() -> None:
    """BindingBlock keeps a mix of MatchCast and VarBinding entries in order."""
    m, n = (tir.Var(name, dtype="int64") for name in ("m", "n"))
    shape_const = rx.const([16, 8], "int32")
    match_binding = rx.MatchCast(rx.Var("v0"), shape_const, R.Tensor([m, n], "int32"))

    var_binding = rx.VarBinding(rx.Var("v0"), rx.const(np.random.rand(24, 56)))

    block = rx.BindingBlock([match_binding, var_binding])
    for position, expected in enumerate((match_binding, var_binding)):
        assert block.bindings[position] == expected
def test_dataflow_block() -> None:
    """DataflowBlock keeps its bindings in order and is a DataflowBlock instance."""
    m, n = (tir.Var(name, dtype="int64") for name in ("m", "n"))
    shape_const = rx.const([16, 8], "int32")
    match_binding = rx.MatchCast(rx.Var("v0"), shape_const, R.Tensor([m, n], "int32"))

    var_binding = rx.VarBinding(rx.Var("v0"), rx.const(np.random.rand(24, 56)))

    block = rx.DataflowBlock([match_binding, var_binding])
    for position, expected in enumerate((match_binding, var_binding)):
        assert block.bindings[position] == expected
    assert isinstance(block, rx.DataflowBlock)
def test_seq_expr() -> None:
    """SeqExpr stores its binding blocks and its body expression."""
    foo = rx.Var("foo")
    block = rx.BindingBlock([rx.VarBinding(foo, rx.const(1))])
    seq = rx.SeqExpr([block], foo)
    assert seq.blocks[0] == block
    assert seq.body == foo
def test_func():
    """Function exposes its params, body, return struct info, and attributes."""
    param = rx.Var("foo", R.Tensor(dtype="float32", ndim=2))
    block = rx.BindingBlock([rx.VarBinding(param, rx.const(1))])
    body = rx.SeqExpr([block], param)
    ret_sinfo = R.Tensor(dtype="float32", ndim=-1)

    func = rx.Function([param], body, ret_sinfo).with_attr("global_symbol", "func")

    assert func.params[0] == param
    assert func.body == body
    assert func.ret_struct_info == ret_sinfo
    assert func.attrs["global_symbol"] == "func"
def test_shape_of():
    """get_shape_of on an annotated tensor Var yields the annotated dimensions."""
    dims = [96, 54]
    tensor_var = rx.Var("v1", R.Tensor(dims))
    shape_of = rx.get_shape_of(tensor_var)
    for expected, actual in zip(dims, shape_of):
        assert expected == actual
def test_shape_expr():
    """ShapeExpr indexing, struct info, and dtype validation."""
    # Symbolic shape: values, Python-style indexing, and out-of-range errors.
    m = tir.Var("m", dtype="int64")
    n = tir.Var("n", dtype="int64")
    sym_shape = rx.ShapeExpr([m, n])
    assert sym_shape.values[0] == m
    assert sym_shape.values[1] == n
    for index, expected in ((0, m), (1, n), (-1, n), (-2, m)):
        assert sym_shape[index] == expected
    assert isinstance(sym_shape.struct_info, rx.ShapeStructInfo)
    for bad_index in (2, -3):
        with pytest.raises(IndexError, match="ShapeExpr index out of range"):
            sym_shape[bad_index]

    # Static shape: concrete values reflected in its struct info.
    static_shape = rx.ShapeExpr([10, 20])
    assert static_shape.values[0] == 10
    assert static_shape.values[1] == 20
    tvm.ir.assert_structural_equal(static_shape.struct_info, R.Shape((10, 20)))

    # A tensor annotation's shape carries matching ShapeStructInfo.
    tensor_var = rx.Var("v0", R.Tensor((10, 20), "float32"))
    tensor_shape = tensor_var.struct_info.shape
    assert tensor_shape[0] == 10
    assert tensor_shape[1] == 20
    tvm.ir.assert_structural_equal(tensor_shape.struct_info, R.Shape((10, 20)))

    # Non-int64 values are rejected.
    m_int32 = tir.Var("m", "int32")
    with pytest.raises(
        tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64"
    ):
        rx.ShapeExpr([m_int32, 3])
def test_prim_value():
    """PrimValue wraps an IntImm, compares structurally, and round-trips through JSON."""
    prim = rx.PrimValue(tir.IntImm("int64", 1))
    assert prim.value.value == 1
    twin = rx.PrimValue(tir.IntImm("int64", 1))
    _check_equal(prim, twin)
    _check_json_roundtrip(prim)
def test_prim_value_with_var():
    """A PrimValue over a tir.Var keeps that Var and records it in PrimStructInfo."""
    sym = tir.Var("n", "int64")
    prim = rx.PrimValue(sym)
    assert prim.value.same_as(sym)
    expected_sinfo = rx.PrimStructInfo(value=sym)
    tvm.ir.assert_structural_equal(prim.struct_info, expected_sinfo)
    _check_equal(prim, rx.PrimValue(sym))
    _check_json_roundtrip(prim)
def test_prim_value_with_expr():
    """A PrimValue over a TIR expression records that expression in its struct info."""
    sym = tir.Var("n", "int64")

    def n_plus_one():
        # Build a fresh expression on each call so every comparison is structural.
        return sym + 1

    prim = rx.PrimValue(n_plus_one())
    tvm.ir.assert_structural_equal(prim.struct_info, rx.PrimStructInfo(value=n_plus_one()))
    _check_equal(prim, rx.PrimValue(n_plus_one()))
    _check_json_roundtrip(prim)
def test_string_imm():
    """StringImm stores its value; two equal strings are structurally equal."""
    first, second = (rx.StringImm("hello") for _ in range(2))
    assert first.value == "hello"
    _check_equal(first, second)
    _check_json_roundtrip(first)
def test_datatype_imm():
    """DataTypeImm stores its dtype; two equal dtypes are structurally equal."""
    first, second = (rx.DataTypeImm("int32") for _ in range(2))
    assert first.value == "int32"
    _check_equal(first, second)
    _check_json_roundtrip(first)
def test_call():
    """Call stores its callee and its argument list."""
    int32_sinfo = rx.PrimStructInfo("int32")
    callee = rx.Var("func", rx.FuncStructInfo([int32_sinfo], int32_sinfo))
    argument = rx.Var("arg", int32_sinfo)

    call = rx.Call(callee, [argument])

    assert call.op.same_as(callee)
    assert len(call.args) == 1
    assert call.args[0].same_as(argument)
def test_call_raises_error_for_invalid_function():
    """relax::Call requires the function to have FuncStructInfo"""
    int32_sinfo = rx.PrimStructInfo("int32")
    not_a_function = rx.Var("func", int32_sinfo)
    argument = rx.Var("arg", int32_sinfo)
    with pytest.raises(ValueError):
        rx.Call(not_a_function, [argument])
if __name__ == "__main__":
    # tvm.testing is a submodule that `import tvm` at the top of this file does
    # not import explicitly; import it here so running the file as a script works.
    import tvm.testing

    tvm.testing.main()