# blob: 18246d224b65c6cec7ff1328c7f913f0e75536c5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
import tvm.testing
from tvm.script import relax as R, tir as T
# Test parameter selecting how the replacement map names a symbolic variable:
# by its string name (False) or by the TIR variable itself (True).
replace_by_tir_var = tvm.testing.parameter(
    by_dict={
        "replace-by-string": False,
        "replace-by-tir-var": True,
    },
)
def test_bind_static_value(replace_by_tir_var):
    """Bind every symbolic dimension of a matmul to a static value.

    The replacement map is keyed either by variable name or by the TIR
    variables taken from the function's parameter shapes, as selected by
    `replace_by_tir_var`.  Either form must yield the statically-shaped
    function.
    """

    @R.function(private=True)
    def before(A: R.Tensor(("M", "K")), B: R.Tensor(("K", "N"))) -> R.Tensor(("M", "N")):
        return R.matmul(A, B)

    @R.function(private=True)
    def expected(A: R.Tensor((128, 64)), B: R.Tensor((64, 32))) -> R.Tensor((128, 32)):
        return R.matmul(A, B)

    if not replace_by_tir_var:
        bindings = {"M": 128, "K": 64, "N": 32}
    else:
        # Recover the TIR variables from the shapes of the two parameters.
        lhs_shape = before.params[0].struct_info.shape
        rhs_shape = before.params[1].struct_info.shape
        dim_m, dim_k = lhs_shape
        _, dim_n = rhs_shape
        bindings = {dim_m: 128, dim_k: 64, dim_n: 32}

    bound = before.bind_symbolic_vars(bindings)
    tvm.ir.assert_structural_equal(expected, bound)
def test_error_with_duplicate_var_names():
    """Binding by string must fail when the name is ambiguous.

    Distinct TIR variables may carry the same name.  The function below uses
    two different variables that are both named "N", so a replacement map
    that refers to "N" by string is expected to raise an error.
    """
    N1 = tvm.tir.Var("N", "int64")
    N2 = tvm.tir.Var("N", "int64")

    @R.function(private=True)
    def func(A: R.Tensor((N1, N1)), B: R.Tensor((N1, N2))) -> R.Tensor((N1, N2)):
        out: R.Tensor((N1, N2)) = R.matmul(A, B)
        return out

    ambiguous_bindings = {"N": 64}
    with pytest.raises(tvm.TVMError):
        func.bind_symbolic_vars(ambiguous_bindings)
def test_string_var_when_other_var_has_duplicate_var_names():
    """A uniquely-named variable may be bound by string despite other duplicates.

    Companion to test_error_with_duplicate_var_names.  Two variables share the
    name "N", but the replacement targets "BatchSize", whose name is unique.
    The restriction on string lookup applies only to the duplicated names, so
    this binding is expected to succeed.
    """
    N1 = tvm.tir.Var("N", "int64")
    N2 = tvm.tir.Var("N", "int64")
    BatchSize = tvm.tir.Var("BatchSize", "int64")

    @R.function(private=True)
    def before(
        A: R.Tensor((BatchSize, N1, N1)), B: R.Tensor((N1, N2))
    ) -> R.Tensor((BatchSize, N1, N2)):
        out: R.Tensor((BatchSize, N1, N2)) = R.matmul(A, B)
        return out

    @R.function(private=True)
    def expected(A: R.Tensor((16, N1, N1)), B: R.Tensor((N1, N2))) -> R.Tensor((16, N1, N2)):
        out: R.Tensor((16, N1, N2)) = R.matmul(A, B)
        return out

    bindings = {"BatchSize": 16}
    bound = before.bind_symbolic_vars(bindings)
    tvm.ir.assert_structural_equal(expected, bound)
def test_error_with_nonexisting_var_name():
    """Binding a string name that the function does not use must raise."""

    @R.function(private=True)
    def func(A: R.Tensor(("M", "N"))):
        return A

    unknown_bindings = {"non_existing_symbolic_var": 64}
    with pytest.raises(tvm.TVMError):
        func.bind_symbolic_vars(unknown_bindings)
def test_error_with_nonexisting_tir_var():
    """Binding a TIR variable that is not one of the function's own must raise."""

    @R.function(private=True)
    def func(A: R.Tensor(["M", "N"])):
        return A

    # A freshly created variable: it has the name "M" but is a different
    # object from the variable created for the function's parameter shape.
    unrelated_var = tvm.tir.Var("M", "int64")
    with pytest.raises(tvm.TVMError):
        func.bind_symbolic_vars({unrelated_var: 64})
def test_error_with_multiple_definitions():
    """One variable may not be bound through both its TIR var and its name."""

    @R.function(private=True)
    def func(A: R.Tensor(["M", "N"])):
        return A

    # The first dimension of the parameter is the variable named "M", so the
    # map below refers to it twice: once by object and once by string.
    first_dim = func.params[0].struct_info.shape[0]
    conflicting_bindings = {first_dim: 0, "M": 0}

    with pytest.raises(tvm.TVMError):
        func.bind_symbolic_vars(conflicting_bindings)
def test_error_if_output_has_undefined():
    """A replacement must not leave an undefined symbolic var in the result.

    Only "M" is replaced, and its replacement is an expression over
    `outside_var`, which no parameter of the function defines.
    """

    @R.function(private=True)
    def func(A: R.Tensor(["M", "N"])):
        return A

    outside_var = tvm.tir.Var("outside_var", "int64")
    undefined_expr = outside_var * 2

    with pytest.raises(tvm.TVMError):
        func.bind_symbolic_vars({"M": undefined_expr})
def test_replacements_may_produce_new_symbolic_vars():
    """Replacements may introduce a new symbolic var, provided it ends up bound.

    Both "M" and "N" are rewritten in terms of `outside_var`.  The expected
    function has the parameter shape `[outside_var * 2, outside_var]`.
    """

    @R.function(private=True)
    def before(A: R.Tensor(["M", "N"])):
        return A

    @R.function(private=True)
    def expected(A: R.Tensor(["outside_var * 2", "outside_var"])):
        return A

    outside_var = tvm.tir.Var("outside_var", "int64")
    bindings = {"M": outside_var * 2, "N": outside_var}

    bound = before.bind_symbolic_vars(bindings)
    tvm.ir.assert_structural_equal(expected, bound)
def test_bind_symbolic_vars_in_tensor_shape():
    """Bound variables are replaced inside struct info within the body.

    Binding N to 16 must update both the parameter shape and the `out_sinfo`
    of the call, where `2 * M * N` is expected to become `M * 32`.
    """

    @R.function(private=True)
    def before(A: R.Tensor(["M", "N"])):
        M = T.int64()
        N = T.int64()
        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N]))
        return B

    @R.function(private=True)
    def expected(A: R.Tensor(["M", 16])):
        M = T.int64()
        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32]))
        return B

    bindings = {"N": 16}
    bound = before.bind_symbolic_vars(bindings)
    tvm.ir.assert_structural_equal(expected, bound)
def test_bind_symbolic_vars_in_shape_expr():
    """Bound variables are replaced when they appear in an R.Shape parameter.

    Binding N to 16 must update the tensor parameter's shape, the R.Shape
    parameter, and the `out_sinfo` of the call in the body.
    """

    @R.function(private=True)
    def before(A: R.Tensor(["M * N"]), x: R.Shape(["M", "N"])):
        M = T.int64()
        N = T.int64()
        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N]))
        return B

    # `M` is declared explicitly in the body, as `before` and the other tests
    # in this file do, so that the name used in `M * 32` is not an undefined
    # Python identifier for readers and linters.
    @R.function(private=True)
    def expected(A: R.Tensor(["M * 16"]), x: R.Shape(["M", 16])):
        M = T.int64()
        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32]))
        return B

    after = before.bind_symbolic_vars({"N": 16})
    tvm.ir.assert_structural_equal(expected, after)
def test_bind_defining_of_symbolic_vars_in_prim_value():
    """Bind a symbolic variable that is defined by an R.Prim parameter.

    This case is a bit odd, because it always results in a fully-constrained
    parameter at the relax level.  After binding in this test case, we have a
    function that accepts three parameters, and the third parameter must
    always be the number 16.

    However, this provides the most consistent behavior with other uses of
    `relax.Function.bind_symbolic_vars`, which restricts the allowed values
    for each parameter, but does not alter the number of parameters.  This is
    in contrast to the `BindParams` pass, which provides a known value for
    relax parameters, removing them from the function signature.

    This convention also prevents surprise changes to the function signature,
    such as shown in `test_bind_usage_of_symbolic_vars_in_prim_value`.
    """

    @R.function(private=True)
    def before(A: R.Tensor(["M * N"]), x: R.Prim(value="M"), y: R.Prim(value="N")):
        M = T.int64()
        N = T.int64()
        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N]))
        return B

    # `M` is declared explicitly in the body, as `before` and the other tests
    # in this file do, so that the name used in `M * 32` is not an undefined
    # Python identifier for readers and linters.
    @R.function(private=True)
    def expected(A: R.Tensor(["M * 16"]), x: R.Prim(value="M"), y: R.Prim(value=16)):
        M = T.int64()
        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32]))
        return B

    after = before.bind_symbolic_vars({"N": 16})
    tvm.ir.assert_structural_equal(expected, after)
def test_bind_usage_of_symbolic_vars_in_prim_value():
    """Bind symbolic variables that an R.Prim parameter uses but does not define.

    Counterpart to test_bind_defining_of_symbolic_vars_in_prim_value: here the
    R.Prim value is the expression "M*N", while M and N come from the shape of
    the tensor parameter.

    This also demonstrates why fully-constrained R.Prim parameters should not
    be removed.  The function accepts two parameters and only the shape of the
    first one is specialized.  It would be unexpected for specialization of
    the first parameter to remove a different parameter altogether.
    """

    @R.function(private=True)
    def before(A: R.Tensor(["M", "N"]), x: R.Prim(value="M*N")):
        M = T.int64()
        N = T.int64()
        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N]))
        return B

    @R.function(private=True)
    def expected(A: R.Tensor([16, 16]), x: R.Prim(value=256)):
        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([512]))
        return B

    bindings = {"M": 16, "N": 16}
    bound = before.bind_symbolic_vars(bindings)
    tvm.ir.assert_structural_equal(expected, bound)
def test_bind_strided_slice():
    """Bound variables are replaced in the arguments of R.strided_slice.

    The slice's last argument is `N // 4`.  With N bound to 32, the expected
    function has the constant 8 in that position.
    """

    @R.function(private=True)
    def before(A: R.Tensor(["M", "N"])):
        N = T.int64()
        B = R.strided_slice(A, [1], [0], [N // 4])
        return B

    @R.function(private=True)
    def expected(A: R.Tensor(["M", 32])):
        B = R.strided_slice(A, [1], [0], [8])
        return B

    bindings = {"N": 32}
    bound = before.bind_symbolic_vars(bindings)
    tvm.ir.assert_structural_equal(expected, bound)
def test_bind_inside_match_cast():
    """Bound variables are replaced inside the target of R.match_cast.

    With N bound to 32, the struct info passed to `R.match_cast` is expected
    to change from `R.Tensor([M, N])` to `R.Tensor([M, 32])`.
    """

    @R.function(private=True)
    def before(A: R.Tensor(["M", "N"]), B: R.Tensor(ndim=2)):
        M = T.int64()
        N = T.int64()
        C = R.match_cast(B, R.Tensor([M, N]))
        D = R.add(A, C)
        return D

    @R.function(private=True)
    def expected(A: R.Tensor(["M", 32]), B: R.Tensor(ndim=2)):
        M = T.int64()
        C = R.match_cast(B, R.Tensor([M, 32]))
        D = R.add(A, C)
        return D

    bindings = {"N": 32}
    bound = before.bind_symbolic_vars(bindings)
    tvm.ir.assert_structural_equal(expected, bound)
# When executed as a script, hand control to TVM's test entry point.
if __name__ == "__main__":
    tvm.testing.main()