blob: bed44c4a6ac21d3206005a6f7f8ea376087bf897 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.script
import tvm.testing
from tvm import relax, tir
from tvm.script import relax as R
import numpy as np
import pytest
# How the parameter to bind is identified: by its name string, or by the
# parameter object taken from ``before.params``.
param_specification = tvm.testing.parameter(
    "by_string",
    "by_var",
)
# How the struct info of the parameter being bound is declared: a static
# shape, a symbolic shape, a rank-only constraint, or no constraint at all.
param_shape = tvm.testing.parameter(
    "static_shape",
    "dynamic_shape",
    "ndim",
    "arbitrary",
)
# dtype annotation for the tensor parameter; None omits an explicit dtype.
tensor_param_dtype = tvm.testing.parameter(
    "float32",
    None,
)
def test_bind_tensor_param(param_specification, param_shape, tensor_param_dtype):
    """Binding a tensor parameter inlines it as a relax constant.

    The parameter ``A`` of ``before`` is declared with struct info selected by
    ``param_shape`` (static shape, symbolic shape, rank-only, or unconstrained)
    and an optional dtype.  It is bound either by name or by the parameter
    object itself (``param_specification``).  The result must be structurally
    equal to a parameter-less function that uses the constant directly.
    """
    if param_shape == "static_shape":
        shape = [16]
        ndim = -1
    elif param_shape == "dynamic_shape":
        shape = [tir.Var("N", "int64")]
        ndim = -1
    elif param_shape == "ndim":
        shape = None
        ndim = 1
    elif param_shape == "arbitrary":
        shape = None
        ndim = -1
    else:
        raise ValueError(f"Unknown param_shape: {param_shape}")

    @R.function
    def before(A: R.Tensor(shape, ndim=ndim, dtype=tensor_param_dtype)):
        R.func_attr({"global_symbol": "main"})
        B: R.Tensor(shape=shape, ndim=ndim, dtype=tensor_param_dtype) = A
        out = R.add(B, B)
        return out

    np_data = np.arange(16).astype("float32")
    inlined_relax_const = relax.const(np_data)

    @R.function
    def expected() -> R.Tensor([16], "float32"):
        R.func_attr({"global_symbol": "main"})
        B = inlined_relax_const
        out = R.add(B, B)
        return out

    if param_specification == "by_string":
        var = "A"
    elif param_specification == "by_var":
        var = before.params[0]
    else:
        # f-string so the unexpected value actually appears in the message.
        raise ValueError(f"Unknown param_specification: {param_specification}")

    # Bind the very array used to build ``expected`` so the two cannot drift apart.
    after = before.bind_params({var: np_data})
    tvm.ir.assert_structural_equal(expected, after)
def test_bind_shape_param(param_shape):
    """Binding an ``R.Shape`` parameter by name replaces it with a ShapeExpr.

    ``param_shape`` selects how the parameter's shape struct info is declared.
    In every case, binding ``A`` to ``[16]`` must produce a function with no
    parameters that returns ``R.ShapeExpr([16])``.
    """
    # (shape, ndim) arguments for R.Shape, keyed by the parametrized variant.
    shape_and_ndim_by_variant = {
        "static_shape": ([16], -1),
        "dynamic_shape": ([tir.Var("N", "int64")], -1),
        "ndim": (None, 1),
        "arbitrary": (None, -1),
    }
    if param_shape not in shape_and_ndim_by_variant:
        raise ValueError(f"Unknown param_shape: {param_shape}")
    shape, ndim = shape_and_ndim_by_variant[param_shape]

    @R.function
    def before(A: R.Shape(shape, ndim=ndim)):
        R.func_attr({"global_symbol": "main"})
        B: R.Shape(shape, ndim=ndim) = A
        return B

    @R.function
    def expected() -> R.Shape([16]):
        R.func_attr({"global_symbol": "main"})
        B = R.ShapeExpr([16])
        return B

    bindings = {"A": relax.ShapeExpr([16])}
    after = before.bind_params(bindings)
    tvm.ir.assert_structural_equal(expected, after)
# dtypes for the symbolic PrimValue parameter; non-int64 cases are marked
# xfail inside the test below.
prim_value_dtype = tvm.testing.parameter(
    "int64",
    "int32",
    "float32",
)


def test_bind_prim_value(prim_value_dtype):
    """Binding an ``R.Prim`` parameter replaces it with a concrete PrimValue.

    ``before`` takes a PrimValue whose value is the symbolic variable ``N``.
    Binding ``A`` to the constant 16 must yield a parameter-less function
    that returns ``R.prim_value(16)`` in the same dtype.
    """
    is_supported_dtype = prim_value_dtype == "int64"
    if not is_supported_dtype:
        pytest.xfail(reason="Currently, only support int64 as known symbolic value")

    N = tir.Var("N", prim_value_dtype)
    value = tir.const(16, prim_value_dtype)

    @R.function
    def before(A: R.Prim(value=N)) -> R.Prim(value=N):
        R.func_attr({"global_symbol": "main"})
        B: R.Prim(value=N) = A
        return B

    @R.function
    def expected() -> R.Prim(value=value):
        R.func_attr({"global_symbol": "main"})
        B = R.prim_value(value)
        return B

    bound_prim = relax.PrimValue(value)
    after = before.bind_params({"A": bound_prim})
    tvm.ir.assert_structural_equal(expected, after)
def test_error_on_unknown_var():
    """``bind_params`` raises when keyed by a Var that is not a parameter."""

    @R.function
    def before(A: R.Tensor([16], dtype="float32")):
        R.func_attr({"global_symbol": "main"})
        return A

    # A freshly created Var, unrelated to ``before``'s parameter ``A``.
    stranger = relax.Var("unknown_var")
    bindings = {stranger: np.arange(16, dtype="float32")}
    with pytest.raises(tvm.TVMError):
        before.bind_params(bindings)
def test_error_on_unknown_var_name():
    """``bind_params`` raises when keyed by a name that matches no parameter."""

    @R.function
    def before(A: R.Tensor([16], dtype="float32")):
        R.func_attr({"global_symbol": "main"})
        return A

    bindings = {"unknown_var_name": np.arange(16, dtype="float32")}
    with pytest.raises(tvm.TVMError):
        before.bind_params(bindings)
if __name__ == "__main__":
    # Allow running this test module directly via TVM's test entry point.
    tvm.testing.main()