blob: 1d982b0972ed80dbdc4ef545c30906f6c26a8c04 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.relax.transform.transform import CanonicalizeBindings
import tvm.script
import tvm.testing
import pytest
from tvm import relax
from tvm.ir.base import assert_structural_equal
from tvm.script import ir as I, relax as R, tir as T
def verify(input, expected):
tvm.ir.assert_structural_equal(CanonicalizeBindings()(input), expected)
def test_simple_assignments():
@I.ir_module
class TestChainAssignments:
@R.function
def main(x: R.Tensor):
y = x
z = y
q = z
p = q
o = p
return o
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
return x
verify(TestChainAssignments, Expected)
def test_dataflow_block():
@I.ir_module
class TestDataflowAssignments:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.const(1)
z = y
o = z
p = o
m = p
n = m
R.output(n)
return n
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
n = R.const(1)
R.output(n)
return n
verify(TestDataflowAssignments, Expected)
def test_assign_to_output_in_dataflow_block():
@I.ir_module
class TestDataflowAssignments:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = x # is not a dataflow var
z = y
o = z
p = o
m = p
n = m
R.output(n)
return n
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
# we get a dataflow block where the
# only assignment is n = x, which we can eliminate,
# resulting in an empty block that is normalized away
return x
verify(TestDataflowAssignments, Expected)
def test_ops():
@I.ir_module
class TestOps:
@R.function
def main(x: R.Tensor, y: R.Tensor):
w = y
q = x
z = R.add(w, q)
return R.add(q, z)
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor):
z = R.add(y, x)
return R.add(x, z)
verify(TestOps, Expected)
@pytest.mark.xfail(reason="The lhs and rhs of an assignment should have the same struct info.")
def test_casting():
@I.ir_module
class TestCasting:
@R.function
def main(x: R.Tensor) -> R.Object:
y = x
# z will be treated as object type even though it's a tensor
z: R.Object = y
return z
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor) -> R.Object:
# Cannot unify because the cast indicates user intent
z: R.Object = x
return z
verify(TestCasting, Expected)
def test_match_cast():
@I.ir_module
class TestMatchCast:
@R.function
def main(x: R.Tensor):
q = x
m, n = T.int64(), T.int64()
z = R.match_cast(q, R.Tensor((m, n)))
w = z
return w
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
# can't get rid of z because its struct_info is different from x's
m, n = T.int64(), T.int64()
z = R.match_cast(x, R.Tensor((m, n)))
return z
verify(TestMatchCast, Expected)
def test_same_shape():
@I.ir_module
class TestSameShape:
@R.function
def main(x: R.Tensor(("m", "n"), "float32")):
m, n = T.int64(), T.int64()
y = x
# trivial check
z = R.match_cast(x, R.Tensor((m, n), "float32"))
w = z
q = R.add(w, y)
return R.add(q, w)
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"), "float32")):
# the trivial check is canonicalized into a var binding
# and then eliminated
q = R.add(x, x)
return R.add(q, x)
verify(TestSameShape, Expected)
def test_change_shape():
@I.ir_module
class TestChangeShape:
@R.function
def main(x: R.Tensor(ndim=2)):
y = x
# The MatchCast is non-trivial, as it introduces new shape
# vars. Because the input tensor has an unknown shape
# rather than a symbolic shape, these new shape vars
# cannot be expressed in terms of previous variables.
# Therefore, the match cast must be retained.
o, p = T.int64(), T.int64()
z = R.match_cast(x, R.Tensor((o, p)))
w = z
q = R.add(w, y)
return R.add(q, w)
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(ndim=2)):
o, p = T.int64(), T.int64()
z = R.match_cast(x, R.Tensor((o, p)))
# the struct_info field on q will need to be updated
q = R.add(z, x)
return R.add(q, z)
verify(TestChangeShape, Expected)
def test_replace_symbolic_variable_and_remove_match_cast():
@I.ir_module
class TestChangeShape:
@R.function
def main(x: R.Tensor(("m", "n"))):
y = x
# The MatchCast is non-trivial, as it introduces new shape
# vars. However, the new shape vars are redundant, and
# are replaced by canonicalization. After replacing the
# new shape vars, the MatchCast is trivial and may be
# removed.
o, p = T.int64(), T.int64()
z = R.match_cast(x, R.Tensor((o, p)))
w = z
q = R.add(w, y)
return R.add(q, w)
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"))):
m = T.int64()
n = T.int64()
q: R.Tensor([m, n]) = R.add(x, x)
return R.add(q, x)
verify(TestChangeShape, Expected)
def test_replace_symbolic_variable_and_remove_match_cast_of_tuple():
"""Symbolic variables may be defined in R.match_cast of tuple
This test is similar to
`test_replace_symbolic_variable_and_remove_match_cast`, except
that the MatchCast is performed on a Relax tuple.
This is a regression test. Earlier implementations only inferred
TIR variables from `R.match_cast` of tensors, shapes, and prim
values, but omitted tuples.
"""
@I.ir_module
class Before:
@R.function
def main(x: R.Tuple(R.Tensor(("m", "n")))):
y = x
o, p = T.int64(), T.int64()
z = R.match_cast(x, R.Tuple(R.Tensor((o, p))))
w = z
q = R.add(w[0], y[0])
return R.add(q, w[0])
@I.ir_module
class Expected:
@R.function
def main(x: R.Tuple(R.Tensor(("m", "n")))):
q = R.add(x[0], x[0])
return R.add(q, x[0])
verify(Before, Expected)
def test_unwrap_tuple():
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor, y: R.Tensor):
tuple_var = (x, y)
w = tuple_var[0]
q = tuple_var[1]
z = R.add(w, q)
return R.add(q, z)
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor):
tuple_var = (x, y)
z = R.add(x, y)
return R.add(y, z)
verify(Before, Expected)
def test_basic_folding_example():
@I.ir_module
class Input:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
y = R.const(1)
n = y
R.output(n)
return n
@I.ir_module
class Expected:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
n = R.const(1)
R.output(n)
return n
verify(Input, Expected)
def test_fold_match_cast():
@I.ir_module
class Input:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
y = R.const(1)
n = R.match_cast(y, R.Tensor((), "int32"))
R.output(n)
return n
@I.ir_module
class Expected:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
# the cast is trivial, so it is removed
n = R.const(1)
R.output(n)
return n
verify(Input, Expected)
def test_fold_variables_from_match_cast():
"""Symbolic variables in R.match_cast may be inferred
If the argument to `R.match_cast` has known shape parameters, they
may be used to infer symbolic shape parameters.
"""
@I.ir_module
class Before:
@R.function
def main(
state: R.Tensor([16], dtype="float32"),
A: R.Tensor([16, 16], dtype="float32"),
B: R.Tensor([16, 16], dtype="float32"),
):
N1 = T.int64()
M = T.int64()
N2 = T.int64()
# The symbolic variables `N1`, `N2` and `M` are defined by
# these `R.match_cast` statements. Since the inputs have
# a known shape, the values of these symbolic variables
# may be inferred.
lhs_A = R.match_cast(A, R.Tensor([N1, M], dtype="float32"))
lhs_B = R.match_cast(B, R.Tensor([N2, M], dtype="float32"))
rhs = R.match_cast(state, R.Tensor([M], dtype="float32"))
# The symbolic shapes propagate downstream.
lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0)
proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void")
proj_A = R.strided_slice(
proj_concat,
(R.prim_value(0),),
(R.prim_value(0),),
(R.prim_value(N1),),
assume_inbound=False,
)
proj_B = R.strided_slice(
proj_concat,
[R.prim_value(0)],
[R.prim_value(N1)],
[R.prim_value(N1 + N2)],
assume_inbound=False,
)
return (proj_A, proj_B)
@I.ir_module
class Expected:
@R.function
def main(
state: R.Tensor([16], dtype="float32"),
A: R.Tensor([16, 16], dtype="float32"),
B: R.Tensor([16, 16], dtype="float32"),
):
# The function no longer depends on symbolic variables.
# Shape inference is now propagated using the
# statically-known shapes.
lhs: R.Tensor([32, 16], dtype="float32") = R.concat((A, B), axis=0)
proj_concat: R.Tensor([32], dtype="float32") = R.matmul(lhs, state, out_dtype="void")
proj_A: R.Tensor([16], dtype="float32") = R.strided_slice(
proj_concat,
[R.prim_value(0)],
[R.prim_value(0)],
[R.prim_value(16)],
assume_inbound=False,
)
proj_B: R.Tensor([16], dtype="float32") = R.strided_slice(
proj_concat,
[R.prim_value(0)],
[R.prim_value(16)],
[R.prim_value(32)],
assume_inbound=False,
)
return (proj_A, proj_B)
verify(Before, Expected)
def test_inconsistent_match_cast_raises_error():
"""Symbolic variables from R.match_cast must be consistent
All match cast statements must provide consistent definitions for
symbolic variables. In this test, the value of `M` would be
inferred as 16 from either `state` or `A`, but would be inferred
as 32 from `B`.
"""
@I.ir_module
class Before:
@R.function
def main(
state: R.Tensor([16], dtype="float32"),
A: R.Tensor([16, 16], dtype="float32"),
B: R.Tensor([32, 32], dtype="float32"),
):
N1 = T.int64()
M = T.int64()
N2 = T.int64()
# These R.match_cast statements define inconsistent values
# for the symbolic shape parameters.
lhs_A = R.match_cast(A, R.Tensor([N1, M], dtype="float32"))
lhs_B = R.match_cast(B, R.Tensor([N2, M], dtype="float32"))
rhs = R.match_cast(state, R.Tensor([M], dtype="float32"))
lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0)
proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void")
proj_A = R.strided_slice(
proj_concat,
(R.prim_value(0),),
(R.prim_value(0),),
(R.prim_value(N1),),
assume_inbound=False,
)
proj_B = R.strided_slice(
proj_concat,
[R.prim_value(0)],
[R.prim_value(N1)],
[R.prim_value(N1 + N2)],
assume_inbound=False,
)
return (proj_A, proj_B)
with pytest.raises(ValueError, match="MatchCast statements must be consistent"):
CanonicalizeBindings()(Before)
def test_match_cast_may_have_distinct_values_in_branches():
"""Conditional branches may have different values of symbolic variables
Here, the value of `N` can be inferred as 16 within the `if`
branch and as 32 within the `else` branch.
"""
@I.ir_module
class Before:
@R.function
def main(
state: R.Tensor(["N"], dtype="float32"),
A: R.Tensor(["M", 16], dtype="float32"),
B: R.Tensor(["M", 32], dtype="float32"),
scale: R.Prim("float32"),
):
N = T.int64()
M = T.int64()
if N == 16:
weights: R.Tensor([M, 16], "float32") = A * scale
weights: R.Tensor([M, N], "float32") = R.match_cast(
weights, R.Tensor([M, N], "float32")
)
weights: R.Tensor([M, N], "float32") = weights * scale
else:
weights: R.Tensor([M, 32], "float32") = B * scale
weights: R.Tensor([M, N], "float32") = R.match_cast(
weights, R.Tensor([M, N], "float32")
)
weights: R.Tensor([M, N], "float32") = weights * scale
weights: R.Tensor([M, N], "float32") = weights * scale
out: R.Tensor([M], "float32") = R.matmul(weights, state)
return out
@I.ir_module
class Expected:
@R.function
def main(
state: R.Tensor(["N"], dtype="float32"),
A: R.Tensor(["M", 16], dtype="float32"),
B: R.Tensor(["M", 32], dtype="float32"),
scale: R.Prim("float32"),
):
N = T.int64()
M = T.int64()
if N == 16:
# Prior to the R.match_cast, the
weights: R.Tensor([M, 16], "float32") = A * scale
# The scaled weights within the branch may perform
# shape inference knowing that N==16.
weights: R.Tensor([M, 16], "float32") = weights * scale
# The match cast on exiting the if branch restores the
weights = R.match_cast(weights, R.Tensor([M, N], "float32"))
else:
# Prior to the R.match_cast, the
weights: R.Tensor([M, 32], "float32") = B * scale
# Within the else-branch, the R.match_cast implies
# that N==32. While this conflicts with the earlier
# definition, the two occur in separate branches, so
# this is legal.
# The scaled weights within the branch may perform
# shape inference knowing that N==32.
weights: R.Tensor([M, 32], "float32") = weights * scale
weights = R.match_cast(weights, R.Tensor([M, N], "float32"))
# Outside of the conditional, we no longer have a known
# value for N, so this shape inference must be done using
# a dynamic shape for `N`.
weights: R.Tensor([M, N], "float32") = weights * scale
# After the conditional branch, we no longer have a known
# value of N, so this shape inference must use the dynamic
# shape.
out: R.Tensor([M], "float32") = R.matmul(weights, state)
return out
verify(Before, Expected)
def test_multiple_outputs():
@I.ir_module
class Input:
@R.function
def main():
with R.dataflow():
x = R.const(1)
y = R.const(1)
z = R.const(1)
l = x
m = y
n = z
R.output(l, m, n)
return (l, m, n)
@I.ir_module
class Expected:
@R.function
def main():
with R.dataflow():
l = R.const(1)
m = R.const(1)
n = R.const(1)
R.output(l, m, n)
return (l, m, n)
verify(Input, Expected)
def test_single_output_multiple_nondataflow():
"""Non-dataflow vars being updated may also be part trivial bindings
Like `test_multiple_outputs`, but only `n` is used in the return
statement.
"""
@I.ir_module
class Input:
@R.function
def main():
with R.dataflow():
x = R.const(1)
y = R.const(1)
z = R.const(1)
l = x
m = y
n = z
R.output(l, m, n)
return n
@I.ir_module
class Expected:
@R.function
def main():
with R.dataflow():
l = R.const(1)
m = R.const(1)
n = R.const(1)
R.output(n)
return n
verify(Input, Expected)
def test_fold_const_to_output():
@I.ir_module
class Before:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
n = R.const(1)
R.output(n)
return n
@I.ir_module
class Expected:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
n = R.const(1)
R.output(n)
return R.const(1)
verify(Before, Expected)
def test_canonicalize_var_to_dataflow_var_if_legal():
"""Canonicalize Var to DataflowVar inside DataflowBlock
DataflowVar instances may only be used inside a DataflowBlock. If
a trivial binding `y = x` occurs, where `x` is a `DataflowVar` and
`y` is a `Var`, replacing `y` with `x` may result in usage of a
`DataflowVar` outside of a `DataflowBlock`.
"""
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.add(x, R.const(1))
z = R.add(y, R.const(1))
R.output(y, z)
return z
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.add(x, R.const(1))
z = R.add(y, R.const(1))
R.output(z)
return z
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_update_dataflow_computations_if_var_replacement_occurs():
"""Canonicalize Var to DataflowVar inside DataflowBlock
DataflowBlocks may produce additional outputs after the first
output Var, and these additional outputs may be in terms of the
first output. Computations that depend on a replaced var must be
updated to remain well-formed.
"""
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
with R.dataflow():
lv1 = R.add(x, R.const(1))
gv1 = lv1
gv2 = R.add(lv1, R.const(1))
R.output(gv1, gv2)
return (gv1, gv2)
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
# lv1 has been replaced with gv1
gv1 = R.add(x, R.const(1))
# So gv1 must be used in the computation of gv2
gv2 = R.add(gv1, R.const(1))
R.output(gv1, gv2)
return (gv1, gv2)
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_update_dataflow_computations_if_var_replacement_occurs_after_usage():
"""Canonicalize Var to DataflowVar inside DataflowBlock
Like test_update_dataflow_computations_if_var_replacement_occurs,
but the usage of a DataflowVar occurs before the trivial binding
that causes it to be replaced.
"""
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
with R.dataflow():
lv1 = R.add(x, R.const(1))
gv2 = R.add(lv1, R.const(1))
gv1 = lv1
R.output(gv1, gv2)
return (gv1, gv2)
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
# lv1 has been replaced with gv1
gv1 = R.add(x, R.const(1))
# So gv1 must be used in the computation of gv2
gv2 = R.add(gv1, R.const(1))
# Even though the trivial binding of "gv1 = lv1"
# occurred in this position.
R.output(gv1, gv2)
return (gv1, gv2)
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_replace_var_with_dataflow_if_all_usage_within_dataflow_block():
"""Canonicalize Var to DataflowVar inside DataflowBlock
Like `test_update_dataflow_computations_if_var_replacement_occurs`,
except that `gv1` is not part of the function's return value. When
deciding which variable to replace, the following logic is applied:
1. Normally, when encountering `x = y`, replace usage of `x` with `y`.
2. Unless the trivial binding is a `var_x = dataflow_y`, in which case
replace `dataflow_y` with `var_x` at the point of definition. This
prevents usage of `dataflow_y` from escaping the dataflow block.
3. Unless `var_x` has no usage outside the dataflow block, in which
case we replace usage of `var_x` with `dataflow_y`.
The third rule ensures that canonicalization can occur in a single
step. Otherwise, the output of this test case would contain a
non-dataflow var defined within a dataflow block, and only used within
that dataflow block. (Equivalent to the input for the test case
`test_canonicalize_var_to_dataflow_var_if_legal`.)
"""
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
with R.dataflow():
lv1 = R.add(x, R.const(1))
gv1 = lv1
gv2 = R.add(lv1, R.const(1))
R.output(gv1, gv2)
return gv2
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
gv1 = R.add(x, R.const(1))
gv2 = R.add(gv1, R.const(1))
R.output(gv2)
return gv2
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_canonicalize_var_to_dataflow_with_trivial_binding():
"""Canonicalize Var to DataflowVar inside DataflowBlock
Like
`test_replace_var_with_dataflow_if_all_usage_within_dataflow_block`,
except the non-DataflowVar is on the right-hand side of the trivial
binding.
"""
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
with R.dataflow():
gv1 = R.add(x, R.const(1))
lv1 = gv1
gv2 = R.add(lv1, R.const(1))
R.output(gv1, gv2)
return gv2
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
gv1 = R.add(x, R.const(1))
gv2 = R.add(gv1, R.const(1))
R.output(gv2)
return gv2
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_canonicalize_with_updated_struct_info():
"""CanonicalizeBindings and Normalizer may both replace a Var
If the CanonicalizeBindings pass has no replacements to make for a
variable, it must still delegate to the ExprMutator. This is because
a variable replacement may have occurred as part of the IRNormalizer,
in order to provide better struct info.
"""
@I.ir_module
class Before:
@R.function(private=True)
def main(A: R.Tensor(("n", 16), dtype="int32")) -> R.Tensor(("n", 16), dtype="int32"):
# CanonicalizeBindings recognizes this trivial binding, and
# replaces `B` with `A`.
B = A
# The value is updated from `R.add(B,B)` to `R.add(A,A)`.
# Changing the value triggers struct inference, allowing the
# shape to be updated to `[n,16]`. This requires a variable
# replacement, which is tracked by the `ExprMutator`.
C: R.Tensor(dtype="int32", ndim=2) = R.add(B, B)
# Replacement of `C` is not explicitly tracked by
# CanonicalizeBindings. However, if CanonicalizeBindings just
# returns `GetRef<Var>(var)`, `ExprMutator` cannot apply the
# replacement, and this will try to return the old
# version of `C` with `ndim=2`.
return C
@I.ir_module
class Expected:
@R.function(private=True)
def main(A: R.Tensor(("n", 16), dtype="int32")) -> R.Tensor(("n", 16), dtype="int32"):
n = T.int64()
C: R.Tensor([n, 16], "int32") = R.add(A, A)
return C
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_canonicalize_trivial_binding_to_dataflow_var():
"""Canonicalize Var to DataflowVar inside DataflowBlock
DataflowVar instances may only be used inside a DataflowBlock. If
a trivial binding `y = x` occurs, where `x` is a `DataflowVar` and
`y` is a `Var`, replacing `y` with `x` may result in usage of a
`DataflowVar` outside of a `DataflowBlock`.
If a binding exists solely to convert from DataflowVar into Var,
then canonicalization replaces the earlier DataflowVar with a Var.
"""
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.add(x, R.const(1))
z = y
R.output(z)
return z
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.add(x, R.const(1))
R.output(y)
return y
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_canonicalize_multiple_trivial_binding_to_dataflow_var():
"""Canonicalize Var to DataflowVar inside DataflowBlock
Like test_canonicalize_trivial_binding_to_dataflow_var, but there
exist multiple trivial bindings to the DataflowVar.
"""
@I.ir_module
class Before:
@R.function
def main(w: R.Tensor):
with R.dataflow():
x = R.add(w, R.const(1))
y = x
z = x
R.output(y, z)
return (y, z)
@I.ir_module
class Expected:
@R.function
def main(w: R.Tensor):
with R.dataflow():
x = R.add(w, R.const(1))
R.output(x)
return (x, x)
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_canonicalize_trivial_var_binding_inside_dataflow_block():
"""Canonicalize Var to DataflowVar inside DataflowBlock
Canonicalization handles cases where a Var could be replaced by a
DataflowVar, and where a Var is a trivial binding. If these two
cases both occur, should produce reasonable results.
"""
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.add(x, R.const(1))
z = y
R.output(y, z)
return z
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.add(x, R.const(1))
R.output(y)
return y
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_canonicalize_across_non_dataflow_tuple():
"""Canonicalize Var to DataflowVar inside DataflowBlock"""
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.add(x, R.const(1))
z = (y,)
gv = R.add(z[0], R.const(1))
R.output(z, gv)
return gv
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.add(x, R.const(1))
z = (y,)
gv = R.add(y, R.const(1))
R.output(gv)
return gv
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_var_used_in_distinct_df_blocks():
"""If a var is used only in dataflow blocks,
but outside of the one where it was originally defined,
it should be exposed as an output."""
@I.ir_module
class Before:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(z, y)
v = R.add(w, x)
# v must remain exposed!
R.output(v)
_ = R.print(format="Hi mom!")
with R.dataflow():
a = R.multiply(v, v)
b = R.add(a, a)
c = R.subtract(b, a)
d = R.add(c, c)
R.output(d)
return d
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Before, after)
def test_inner_function():
@I.ir_module
class Before:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
@R.function(pure=False)
def inner_func(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(x, z)
v = R.add(y, w)
R.output(z, w, v)
_ = R.print(format="oops")
with R.dataflow():
a = R.multiply(v, v)
b = R.add(a, a)
c = R.multiply(a, b)
R.output(a, b, c)
return c
z = R.add(x, y)
w = R.multiply(z, z)
v = R.divide(w, z)
R.output(inner_func, z, v, w)
q = inner_func(w, v)
with R.dataflow():
a = R.multiply(q, q)
b = R.add(a, a)
c = R.multiply(b, a)
R.output(a, b, c)
return c
# expected: we do not need to expose all the outputs
@I.ir_module
class Expected:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
@R.function(pure=False)
def inner_func(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(x, z)
v = R.add(y, w)
R.output(v)
_ = R.print(format="oops")
with R.dataflow():
a = R.multiply(v, v)
b = R.add(a, a)
c = R.multiply(a, b)
R.output(c)
return c
z = R.add(x, y)
w = R.multiply(z, z)
v = R.divide(w, z)
R.output(inner_func, v, w)
q = inner_func(w, v)
with R.dataflow():
a = R.multiply(q, q)
b = R.add(a, a)
c = R.multiply(b, a)
R.output(c)
return c
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_canonicalize_inside_branches():
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
R.output(z)
if R.const(True):
with R.dataflow():
w = R.add(z, z)
v = R.multiply(w, w)
# w does not need to be output
R.output(w, v)
q = v
else:
with R.dataflow():
w = R.multiply(z, z)
v = R.add(w, w)
R.output(w, v)
q = v
return q
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
R.output(z)
if R.const(True):
with R.dataflow():
w = R.add(z, z)
v = R.multiply(w, w)
R.output(v)
q = v
else:
with R.dataflow():
w = R.multiply(z, z)
v = R.add(w, w)
R.output(v)
q = v
return q
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_canonicalization_causes_struct_info_update():
"""Regression test for failure mode causing undefined variable
The ExprMutator is only allowed to update a variable's struct info
if the value bound to it has new struct info. When
CanonicalizeBindings replaces a trivial binding, this may provide
better struct info as a result. If this happens, the
In previous implementations, ExprMutator::ReEmitBinding defined a
remap for `binding->var->vid`, even if the derived class defined a
replacement by overriding `VisitVarDef`. If the derived class
defines a new variable binding by overriding `VisitVarDef`, and
also causes a variable replacement by overriding `VisitExpr` and
returning a type with different struct info, then `ExprMutator`
must check for both `binding->var->vid` *AND* `new_var->vid`. The
former may be present in the unmodified graph, and the latter may
be produced by the derived class before delegating to the base
class.
"""
@I.ir_module
class Before:
@R.function
def transform_params(
A: R.Tensor(("vocab_size", 4096), dtype="float16"),
B: R.Tensor((6144, 4096), dtype="float16"),
):
with R.dataflow():
# Trivial binding of `DataFlow = NonDataFlow`.
# Wherever `C` is used, Canonicalization will attempt
# to replace it with `B`.
C = B
# RHS contains `(A,C)`, which CanonicalizeBindings
# replaces with `(A,B)`. Because this changes the
# RHS, a new LHS (and new struct info!) will be
# generated.
D: R.Tuple(
R.Tensor(dtype="float16", ndim=2),
R.Tensor((6144, 4096), dtype="float16"),
) = (A, C)
# Trivial binding of `NonDataFlow = DataFlow`. The
# definition of `D` will be replaced with a definition
# of `E`. This definition of `E` will then be updated
# to have a known shape.
E = D
R.output(E)
# By the time `E` is encountered at a usage site, the
# `ExprMutator` must have a replacement for the old
# version of `E` with `ndim=2` to the new versions of `E`
# with `shape=[vocab_size,4096]`.
return E
@I.ir_module
class Expected:
@R.function
def transform_params(
A: R.Tensor(("vocab_size", 4096), dtype="float16"),
B: R.Tensor((6144, 4096), dtype="float16"),
):
vocab_size = T.int64()
with R.dataflow():
E: R.Tuple(
R.Tensor((vocab_size, 4096), dtype="float16"),
R.Tensor((6144, 4096), dtype="float16"),
) = (A, B)
R.output(E)
return E
after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)
def test_unwrap_tuple_of_constant():
@I.ir_module
class TestChainAssignments:
@R.function
def main():
tup = (R.const(0, "int64"), R.const(1, "int64"))
x = tup[0]
y = tup[1]
z = R.add(x, y)
return z
@I.ir_module
class Expected:
@R.function
def main():
tup = (R.const(0, "int64"), R.const(1, "int64"))
x = tup[0]
y = tup[1]
z = R.add(R.const(0, "int64"), R.const(1, "int64"))
return z
verify(TestChainAssignments, Expected)
def test_trivial_binding_of_replaced_non_dataflow_var():
@I.ir_module
class Before:
@R.function
def main(param_tuple: R.Tuple([R.Tensor])):
with R.dataflow():
A = param_tuple[0]
B = A
C = R.add(A, B)
R.output(A, B, C)
return C
@I.ir_module
class Expected:
@R.function
def main(param_tuple: R.Tuple([R.Tensor])):
with R.dataflow():
A = param_tuple[0]
C = R.add(A, A)
R.output(C)
return C
After = CanonicalizeBindings()(Before)
tvm.ir.assert_structural_equal(After, Expected)
def _get_binding_names(mod):
return [binding.var.name_hint for binding in mod["main"].body.blocks[0].bindings]
expected_names = _get_binding_names(Expected)
after_names = _get_binding_names(After)
assert after_names == expected_names
def test_trace_tuple_through_round_trip():
"""Canonicalize to the orignal tuple, without unwrap/rewrap."""
@I.ir_module
class Before:
@R.function
def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
with R.dataflow():
A = param_tuple[0]
B = param_tuple[1]
C = param_tuple[2]
output = (A, B, C)
R.output(output)
return output
@I.ir_module
class Expected:
@R.function
def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
with R.dataflow():
A = param_tuple[0]
B = param_tuple[1]
C = param_tuple[2]
R.output()
return param_tuple
After = CanonicalizeBindings()(Before)
tvm.ir.assert_structural_equal(After, Expected)
def test_trace_partial_tuple_through_round_trip():
"""Canonicalize to the orignal tuple, without unwrap/rewrap."""
@I.ir_module
class Before:
@R.function
def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
with R.dataflow():
A = param_tuple[0]
B = param_tuple[1]
output = (A, B)
R.output(output)
return output
Expected = Before
After = CanonicalizeBindings()(Before)
tvm.ir.assert_structural_equal(After, Expected)
if __name__ == "__main__":
tvm.testing.main()