# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

import tvm.testing
from tvm.script import ir as I, relax as R, tir as T
def test_rewrite_defined_by_ir_module():
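    # `R.rewriter` builds a pattern-matching rewrite rule from a class
    # containing a `pattern` function and a `replacement` function.
    # The function parameters (here, `A` and `B`) act as wildcards,
    # matching any Relax expression whose shape and dtype are
    # compatible.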
@R.rewriter
class Rewriter:
@R.function
def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = R.add(A, B)
return C
@R.function
def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
return C
@R.function
def before(x: R.Tensor([32], "float32")):
R.func_attr({"global_symbol": "main"})
split = R.split(x, 2)
lhs = split[0]
rhs = split[1]
out = lhs + rhs
return out
@R.function
def expected(x: R.Tensor([32], "float32")):
R.func_attr({"global_symbol": "main"})
split = R.split(x, 2)
lhs = split[0]
rhs = split[1]
out = R.call_pure_packed(
"my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32")
)
return out
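    # Calling the rewriter on a Relax function returns the rewritten
    # function. Note that `lhs + rhs` matches the pattern's `R.add`,
    # since the operator syntax is sugar for the same operation.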
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_missing_pattern_raises_error():
"""The rewriter must define a pattern to be matched"""
with pytest.raises(KeyError, match="pattern"):
@R.rewriter
class Rewriter:
@R.function
def replacement():
return R.tuple()
def test_incorrect_function_type_of_pattern_raises_error():
"""The rewriter's pattern must be a Relax function"""
with pytest.raises(TypeError, match="pattern"):
@R.rewriter
class Rewriter:
@T.prim_func
def pattern():
pass
@R.function
def replacement():
return R.tuple()
def test_missing_replacement_raises_error():
"""The rewriter must define a replacement"""
with pytest.raises(KeyError, match="replacement"):
@R.rewriter
class Rewriter:
@R.function
def pattern():
return R.tuple()
def test_incorrect_function_type_of_replacement_raises_error():
"""The rewriter's replacement must be a Relax function"""
with pytest.raises(TypeError, match="replacement"):
@R.rewriter
class Rewriter:
@R.function
def pattern():
return R.tuple()
@T.prim_func
def replacement():
pass
def test_mismatch_of_static_shapes_raises_error():
"""The pattern and replacement must accept the same shapes"""
with pytest.raises(ValueError, match="must have the same signature"):
@R.rewriter
class Rewriter:
@R.function
def pattern(A: R.Tensor([32])):
return A
@R.function
def replacement(A: R.Tensor([16])):
return A
def test_rewriter_may_be_applied_to_ir_module():
"""A rewriter may mutate an IRModule
The `PatternMatchingRewriter.__call__` implementation may accept
either a single Relax function, or an entire IRModule. If it is
passed an IRModule, then all functions in the `IRModule` are
updated.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = R.add(A, B)
return C
@R.function
def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
return C
@I.ir_module
class Before:
@R.function
def func_a(x: R.Tensor([32], "float32")):
split = R.split(x, 2)
lhs = split[0]
rhs = split[1]
out = lhs + rhs
return out
@R.function
def func_b(x: R.Tensor([16], "float32")):
out = x + x
return out
@I.ir_module
class Expected:
@R.function
def func_a(x: R.Tensor([32], "float32")):
split = R.split(x, 2)
lhs = split[0]
rhs = split[1]
out = R.call_pure_packed(
"my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32")
)
return out
@R.function
def func_b(x: R.Tensor([16], "float32")):
out = R.call_pure_packed(
"my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32")
)
return out
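    # When applied to an IRModule, the rewriter updates every Relax
    # function in the module.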
After = Rewriter(Before)
tvm.ir.assert_structural_equal(Expected, After)
def test_rewriter_may_be_used_as_ir_transform():
"""A rewriter may be used as a tvm.ir.transform.Pass"""
@R.rewriter
class Rewriter:
@R.function
def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = R.add(A, B)
return C
@R.function
def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
return C
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor([16], "float32")):
y = x + x
return y
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor([16], "float32")):
out = R.call_pure_packed(
"my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32")
)
return out
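    # A rewriter is also a `tvm.ir.transform.Pass`, so it may be used
    # anywhere a pass is expected, such as in a `Sequential` pipeline.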
After = tvm.ir.transform.Sequential([Rewriter])(Before)
tvm.ir.assert_structural_equal(Expected, After)
def test_same_pattern_applied_multiple_times():
"""The pattern-match may apply multiple times"""
@R.rewriter
class Rewriter:
@R.function
def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = R.add(A, B)
return C
@R.function
def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
return C
@R.function(private=True)
def before(x: R.Tensor([16], "float32")):
y = x + x
z = y + y
return z
@R.function(private=True)
def expected(x: R.Tensor([16], "float32")):
y = R.call_pure_packed("my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32"))
z = R.call_pure_packed("my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32"))
return z
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_composition_of_rewrite_rules():
"""Rewrite rules may be composed together"""
@R.rewriter
class RewriteAdd:
@R.function
def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = A + B
return C
@R.function
def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
return C
@R.rewriter
class RewriteMultiply:
@R.function
def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = A * B
return C
@R.function
def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
C = R.call_pure_packed(
"my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
return C
@R.function(private=True)
def before(
A: R.Tensor([16], "float32"),
B: R.Tensor([16], "float32"),
C: R.Tensor([16], "float32"),
):
D = A + B
E = C * D
return E
@R.function(private=True)
def expected(
A: R.Tensor([16], "float32"),
B: R.Tensor([16], "float32"),
C: R.Tensor([16], "float32"),
):
D = R.call_pure_packed("my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32"))
E = R.call_pure_packed("my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32"))
return E
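    # The `|` operator composes rewriters into a single rule that
    # applies each component rewrite.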
rewriter = RewriteAdd | RewriteMultiply
after = rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_recursive_rewrite_rules():
"""Rewrite rules are applied until convergence
In this test, both the `RewriteAdd` and `RewriteMultiply` patterns
must be applied in order to produce the expected output. However,
the `RewriteMultiply` pattern relies on the expression produced by
    the `RewriteAdd` rewriter.
"""
@R.rewriter
class RewriteAdd:
@R.function
def pattern(A: R.Tensor([16], "float32")):
return A + A
@R.function
def replacement(A: R.Tensor([16], "float32")):
return A * R.const(2.0, "float32")
@R.rewriter
class RewriteMultiply:
@R.function
def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")):
C = A * B
return C
@R.function
def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")):
C = R.call_pure_packed(
"my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
return C
@R.function(private=True)
def before(A: R.Tensor([16], "float32")):
B = A + A
return B
@R.function(private=True)
def expected(A: R.Tensor([16], "float32")):
B = R.call_pure_packed(
"my_optimized_mul_impl",
A,
R.const(2.0, "float32"),
sinfo_args=R.Tensor([16], "float32"),
)
return B
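    # The composed rules are applied until convergence: `RewriteAdd`
    # first produces `A * R.const(2.0)`, which `RewriteMultiply` then
    # rewrites into the packed call.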
rewriter = RewriteAdd | RewriteMultiply
after = rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_rewrite_of_arbitrary_dtype():
"""A pattern-match may apply to a tensor with unknown dtype
In this test case, a pattern identifies `R.strided_slice` usage
which returns the last slice of an array, and replaces it with a
view into the input array.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]):
M = T.int64()
N = T.int64()
last_slice_2d: R.Tensor([1, N]) = R.strided_slice(A, axes=[0], begin=[M - 1], end=[M])
last_slice_1d: R.Tensor([N]) = R.squeeze(last_slice_2d, axis=0)
return last_slice_1d
@R.function
def replacement(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]):
M = T.int64()
N = T.int64()
# TODO(Lunderberg): Improve this syntax. A Relax
# PrimValue (e.g. `A.dtype.bits`) should be usable in any
# Relax context that accepts a `PrimExpr`. Currently,
# this requires `R.match_cast` to produce a TIR symbolic
# variable from the Relax PrimValue.
bits_per_element = T.uint8()
_ = R.match_cast(
A.dtype.bits,
R.Prim(value=bits_per_element),
)
lanes_per_element = T.uint16()
_ = R.match_cast(
A.dtype.lanes,
R.Prim(value=lanes_per_element),
)
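            # The offset of the last row is `(M-1) * N` elements, each
            # occupying `ceildiv(bits * lanes, 8)` bytes. For example,
            # a float16 element (bits=16, lanes=1) occupies 2 bytes.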
last_slice = R.memory.view(
A,
[N],
relative_byte_offset=(M - 1)
* N
* T.ceildiv(
bits_per_element.astype("int64") * lanes_per_element.astype("int64"), 8
),
)
return last_slice
@I.ir_module
class Before:
@R.function
def main(
A: R.Tensor([32, 16], "float16"),
B: R.Tensor(["P", "Q"], "int4x8"),
C: R.Tensor([16, 32]),
):
P = T.int64()
Q = T.int64()
A_slice_2d = R.strided_slice(A, axes=[0], begin=[31], end=[32])
A_slice_1d = R.squeeze(A_slice_2d, axis=0)
B_slice_2d = R.strided_slice(B, axes=[0], begin=[P - 1], end=[P])
B_slice_1d = R.squeeze(B_slice_2d, axis=0)
C_slice_2d = R.strided_slice(C, axes=[0], begin=[15], end=[16])
C_slice_1d = R.squeeze(C_slice_2d, axis=0)
return (A_slice_1d, B_slice_1d, C_slice_1d)
@I.ir_module
class Expected:
@R.function
def main(
A: R.Tensor([32, 16], "float16"),
B: R.Tensor(["P", "Q"], "int4x8"),
C: R.Tensor([16, 32]),
):
P = T.int64()
Q = T.int64()
# The pattern matches any 2-d tensor, with any data type.
# When the match's shape and dtype are both known,
# normalization and canonicalization produces a statically
# known value for `relative_byte_offset`.
#
# Relative offset is `(31 rows) *
# (16 elements/row) *
# (2 bytes/element)`
A_slice_1d = R.memory.view(A, shape=[16], relative_byte_offset=992)
# The pattern can also match a 2-d tensor with dynamic
# shape. The `relative_byte_offset` uses the known
# datatype (4 bytes for each int4x8), but with dynamic
# shape variables substituted in where required.
#
# Relative offset is `((P-1) rows) *
# (Q elements/row) *
# (4 bytes/element)`
B_slice_1d = R.memory.view(B, shape=[Q], relative_byte_offset=(P - 1) * Q * 4)
# The pattern can also match a 2-d tensor with static
# shape, but unknown data type. The
# `relative_byte_offset` is determined based on the known
# number of elements, and the dynamic size of each
# element.
#
# Relative offset is `(15 rows) *
# (32 elements/row) *
# (ceildiv(bits*lanes,8) bytes/element)`
C_bits_per_element = T.uint8()
C_bits_prim_value = C.dtype.bits
_ = R.match_cast(
C_bits_prim_value,
R.Prim(value=C_bits_per_element),
)
C_lanes_per_element = T.uint16()
C_lanes_prim_value = C.dtype.lanes
_ = R.match_cast(
C_lanes_prim_value,
R.Prim(value=C_lanes_per_element),
)
C_slice_1d = R.memory.view(
C,
shape=[32],
relative_byte_offset=(
(C_bits_per_element.astype("int64") * C_lanes_per_element.astype("int64") + 7)
// 8
)
* 480,
)
return (A_slice_1d, B_slice_1d, C_slice_1d)
after = Rewriter(Before)
tvm.ir.assert_structural_equal(Expected, after)
def test_rewrite_may_introduce_private_relax_subroutines():
"""The replacement may contain subroutines"""
@R.rewriter
class Rewriter:
@R.function
def pattern(A: R.Tensor([16], "float32")):
return A + A
@R.function
def replacement(A: R.Tensor([16], "float32")):
return Rewriter.subroutine(A)
@R.function(private=True)
def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
return A * R.const(2.0, "float32")
@I.ir_module
class Before:
@R.function
def main(A: R.Tensor([16], "float32")):
B = A + A
C = B + B
return C
@I.ir_module
class Expected:
@R.function
def main(A: R.Tensor([16], "float32")):
B = Expected.subroutine(A)
C = Expected.subroutine(B)
return C
@R.function(private=True)
def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
return A * R.const(2.0, "float32")
After = Rewriter(Before)
tvm.ir.assert_structural_equal(Expected, After)
def test_rewrite_only_introduces_private_subroutines_when_required():
"""Only subroutines that are used will be added to the module
Like `test_rewrite_may_introduce_private_relax_subroutines`, but
the rewritten function only requires some of the subroutines
provided by the rewriter.
"""
@R.rewriter
class RewriteAdd:
@R.function
def pattern(A: R.Tensor([16], "float32")):
return A + A
@R.function
def replacement(A: R.Tensor([16], "float32")):
return RewriteAdd.subroutine_add(A)
@R.function(private=True)
def subroutine_add(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
return A * R.const(2.0, "float32")
@R.rewriter
class RewriteMul:
@R.function
def pattern(A: R.Tensor([16], "float32")):
return A * A
@R.function
def replacement(A: R.Tensor([16], "float32")):
return R.call_tir(RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32"))
@T.prim_func(private=True)
def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
for i in range(16):
B[i] = A[i] * A[i]
@I.ir_module
class Before:
@R.function
def main(A: R.Tensor([16], "float32")):
B = A + A
C = B + B
return C
@I.ir_module
class Expected:
@R.function
def main(A: R.Tensor([16], "float32")):
B = Expected.subroutine_add(A)
C = Expected.subroutine_add(B)
return C
@R.function(private=True)
def subroutine_add(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
return A * R.const(2.0, "float32")
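    # Only `RewriteAdd` matches within `main`, so only its
    # `subroutine_add` is copied into the output module. `RewriteMul`
    # never matches, and `subroutine_mul` is omitted.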
rewriter = RewriteAdd | RewriteMul
After = rewriter(Before)
tvm.ir.assert_structural_equal(Expected, After)
def test_rewriter_may_not_introduce_public_subroutines():
"""The rewriter may only introduce private functions"""
with pytest.raises(ValueError, match="is publicly exposed"):
@R.rewriter
class Rewriter:
@R.function
def pattern(A: R.Tensor([16], "float32")):
return A + A
@R.function
def replacement(A: R.Tensor([16], "float32")):
return Rewriter.subroutine(A)
@R.function
def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
return A * R.const(2.0, "float32")
def test_rewrite_branches_may_reuse_subroutine_name():
"""Each rewriter is independent, and may reuse subroutine names"""
@R.rewriter
class RewriteAdd:
@R.function
def pattern(A: R.Tensor([16], "float32")):
return A + A
@R.function
def replacement(A: R.Tensor([16], "float32")):
return RewriteAdd.subroutine(A)
@R.function(private=True)
def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
return A * R.const(2.0, "float32")
@R.rewriter
class RewriteMul:
@R.function
def pattern(A: R.Tensor([16], "float32")):
return A * A
@R.function
def replacement(A: R.Tensor([16], "float32")):
return R.call_tir(RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32"))
@T.prim_func(private=True)
def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
for i in range(16):
B[i] = A[i] * A[i]
@I.ir_module
class Before:
@R.function
def main(A: R.Tensor([16], "float32")):
B = A + A
C = B * B
return C
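    # Both rewriters provide a subroutine named `subroutine`. The
    # second subroutine to be added is renamed to `subroutine_1` to
    # avoid a name collision within the output module.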
@I.ir_module
class Expected:
@R.function
def main(A: R.Tensor([16], "float32")):
B = Expected.subroutine(A)
C = R.call_tir(Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32"))
return C
@R.function(private=True)
def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
return A * R.const(2.0, "float32")
@T.prim_func(private=True)
def subroutine_1(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
for i in range(16):
B[i] = A[i] * A[i]
rewriter = RewriteAdd | RewriteMul
After = rewriter(Before)
tvm.ir.assert_structural_equal(Expected, After)
def test_rewrite_of_explicit_relax_tuple():
"""The rewriter function may return a tuple
When it occurs explicitly within the Relax function, the tuple
pattern matches against the Relax tuple, and the Relax tuple is
replaced.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(
lhs_A: R.Tensor([16, 16], "float32"),
lhs_B: R.Tensor([16, 16], "float32"),
rhs: R.Tensor([16], "float32"),
):
proj_A = R.matmul(lhs_A, rhs)
proj_B = R.matmul(lhs_B, rhs)
proj_tuple = (proj_A, proj_B)
return proj_tuple
@R.function
def replacement(
lhs_A: R.Tensor([16, 16], "float32"),
lhs_B: R.Tensor([16, 16], "float32"),
rhs: R.Tensor([16], "float32"),
):
lhs = R.concat([lhs_A, lhs_B])
proj_concat = R.matmul(lhs, rhs)
proj_tuple = R.split(proj_concat, 2)
return proj_tuple
@R.function(private=True)
def before(
state: R.Tensor([16], "float32"),
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
):
proj_A = R.matmul(A, state)
proj_B = R.matmul(B, state)
proj_tuple = (proj_A, proj_B)
out = proj_tuple[0] + proj_tuple[1]
return out
@R.function(private=True)
def expected(
state: R.Tensor([16], "float32"),
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
):
concat_AB = R.concat([A, B])
proj_concat = R.matmul(concat_AB, state)
proj_tuple = R.split(proj_concat, 2)
out = proj_tuple[0] + proj_tuple[1]
return out
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_rewrite_of_output_relax_tuple():
"""The rewriter may update a tuple being returned
    Unlike most Relax expressions, tuples may appear as nested
expressions. Pattern-matching should be aware of this option.
Like `test_rewrite_of_explicit_relax_tuple`, but the tuple appears
as the return value in the function being modified.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(
lhs_A: R.Tensor([16, 16], "float32"),
lhs_B: R.Tensor([16, 16], "float32"),
rhs: R.Tensor([16], "float32"),
):
proj_A = R.matmul(lhs_A, rhs)
proj_B = R.matmul(lhs_B, rhs)
proj_tuple = (proj_A, proj_B)
return proj_tuple
@R.function
def replacement(
lhs_A: R.Tensor([16, 16], "float32"),
lhs_B: R.Tensor([16, 16], "float32"),
rhs: R.Tensor([16], "float32"),
):
lhs = R.concat([lhs_A, lhs_B])
proj_concat = R.matmul(lhs, rhs)
proj_tuple = R.split(proj_concat, 2)
return proj_tuple
@R.function(private=True)
def before(
state: R.Tensor([16], "float32"),
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
):
proj_A = R.matmul(A, state)
proj_B = R.matmul(B, state)
return (proj_A, proj_B)
@R.function(private=True)
def expected(
state: R.Tensor([16], "float32"),
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
):
concat_AB = R.concat([A, B])
proj_concat = R.matmul(concat_AB, state)
proj_tuple = R.split(proj_concat, 2)
return proj_tuple
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_rewrite_of_implicit_tuple():
"""The rewriter function may return a tuple
The tuple being replaced does not need to explicitly exist within
the updated Relax function. So long as each element of the tuple
pattern matches a Relax expression, the pattern match can apply.
This rule ensures that pattern-matching is never broken when
`CanonicalizeBindings` is applied.
This test is identical to `test_rewrite_of_explicit_relax_tuple`,
except that the function does not contain the round trip of
packing `proj_A` and `proj_B` into a tuple, then immediately
unpacking them from the tuple.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(
lhs_A: R.Tensor([16, 16], "float32"),
lhs_B: R.Tensor([16, 16], "float32"),
rhs: R.Tensor([16], "float32"),
):
proj_A = R.matmul(lhs_A, rhs)
proj_B = R.matmul(lhs_B, rhs)
proj_tuple = (proj_A, proj_B)
return proj_tuple
@R.function
def replacement(
lhs_A: R.Tensor([16, 16], "float32"),
lhs_B: R.Tensor([16, 16], "float32"),
rhs: R.Tensor([16], "float32"),
):
lhs = R.concat([lhs_A, lhs_B])
proj_concat = R.matmul(lhs, rhs)
proj_tuple = R.split(proj_concat, 2)
return proj_tuple
@R.function(private=True)
def before(
state: R.Tensor([16], "float32"),
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
):
proj_A = R.matmul(A, state)
proj_B = R.matmul(B, state)
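        # `proj_A` and `proj_B` are never packed into an explicit
        # tuple. The tuple in the pattern matches them implicitly.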
out = proj_A + proj_B
return out
@R.function(private=True)
def expected(
state: R.Tensor([16], "float32"),
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
):
concat_AB = R.concat([A, B])
proj_concat = R.matmul(concat_AB, state)
proj_tuple = R.split(proj_concat, 2)
out = proj_tuple[0] + proj_tuple[1]
return out
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_rewrite_of_implicit_tuple_with_shared_wildcard():
"""Tuple elements may depend on the same input
Here, both elements of the tuple depend on `y`.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(
x: R.Tensor([16], "float32"),
y: R.Tensor([16], "float32"),
z: R.Tensor([16], "float32"),
):
lhs = x + y
rhs = y + z
return (lhs, rhs)
@R.function
def replacement(
x: R.Tensor([16], "float32"),
y: R.Tensor([16], "float32"),
z: R.Tensor([16], "float32"),
):
return R.call_pure_packed(
"optimized_impl",
x,
y,
z,
sinfo_args=R.Tuple(
[
R.Tensor([16], "float32"),
R.Tensor([16], "float32"),
]
),
)
@R.function(private=True)
def before(
A: R.Tensor([16], "float32"),
B: R.Tensor([16], "float32"),
C: R.Tensor([16], "float32"),
):
lhs = A + B
rhs = B + C
out = R.multiply(lhs, rhs)
return out
@R.function(private=True)
def expected(
A: R.Tensor([16], "float32"),
B: R.Tensor([16], "float32"),
C: R.Tensor([16], "float32"),
):
lhs_rhs = R.call_pure_packed(
"optimized_impl",
A,
B,
C,
sinfo_args=R.Tuple(
[
R.Tensor([16], "float32"),
R.Tensor([16], "float32"),
]
),
)
out = R.multiply(lhs_rhs[0], lhs_rhs[1])
return out
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_no_rewrite_of_implicit_tuple_when_shared_wildcard_is_mismatched():
"""Tuple elements must match simultaneously
Each element of the tuple matches individually, but the two
elements both depend on `B`. Because the first tuple element
would require `y = B`, while the second tuple element would
require `y = C`, the match fails.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(
x: R.Tensor([16], "float32"),
y: R.Tensor([16], "float32"),
z: R.Tensor([16], "float32"),
):
lhs = x + y
rhs = y + z
return (lhs, rhs)
@R.function
def replacement(
A: R.Tensor([16], "float32"),
B: R.Tensor([16], "float32"),
C: R.Tensor([16], "float32"),
):
return R.call_pure_packed(
"optimized_impl",
A,
B,
C,
sinfo_args=R.Tuple(
[
R.Tensor([16], "float32"),
R.Tensor([16], "float32"),
]
),
)
@R.function(private=True)
def before(
A: R.Tensor([16], "float32"),
B: R.Tensor([16], "float32"),
C: R.Tensor([16], "float32"),
D: R.Tensor([16], "float32"),
):
lhs = A + B
rhs = C + D
out = R.multiply(lhs, rhs)
return out
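    # No rewrite occurs: matching `lhs` would require `y = B`, while
    # matching `rhs` would require `y = C`.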
expected = before
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_implicit_tuple_may_not_introduce_extra_compute():
"""Matching of implicit tuple may not cause extra compute
Here, the `(proj_A, proj_B)` tuple could be an implcit tuple
match, but that would repeat the computation of `proj_A`. It
would be computed once on its own, to be used for `proj_A_on_B`,
and once for computing `(proj_A, proj_B)`.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(
lhs_A: R.Tensor([16, 16], "float32"),
lhs_B: R.Tensor([16, 16], "float32"),
rhs: R.Tensor([16, 16], "float32"),
):
proj_A = R.matmul(lhs_A, rhs)
proj_B = R.matmul(lhs_B, rhs)
proj_tuple = (proj_A, proj_B)
return proj_tuple
@R.function
def replacement(
lhs_A: R.Tensor([16, 16], "float32"),
lhs_B: R.Tensor([16, 16], "float32"),
rhs: R.Tensor([16, 16], "float32"),
):
lhs = R.concat([lhs_A, lhs_B])
proj_concat = R.matmul(lhs, rhs)
proj_tuple = R.split(proj_concat, 2)
return proj_tuple
@R.function(private=True)
def before(
state: R.Tensor([16, 16], "float32"),
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
):
# This function has no location at which a tuple
        # `(proj_A, proj_B)` could be constructed, then unpacked.
proj_A = R.matmul(A, state)
# A tuple `(proj_A, proj_B)` could not be constructed at this
# location, because `proj_B` has not yet been computed.
proj_A_on_B = R.matmul(proj_A, B)
proj_B = R.matmul(proj_A_on_B, state)
# A tuple `(proj_A, proj_B)` could be constructed here, but a
# use-site of `proj_A` has already occurred. Implicit
# matching of a tuple is only allowed if it would replace
# every use-site of a variable.
out = proj_A + proj_B
return out
expected = before
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_rewrite_of_implicit_tuple_with_three_elements():
"""Implicit tuples may contain three elements"""
@R.rewriter
class Rewriter:
@R.function
def pattern(qkv: R.Tensor([12288], "float32")):
qkv_tuple = R.split(qkv, 3, axis=0)
q = qkv_tuple[0]
k = qkv_tuple[1]
v = qkv_tuple[2]
q_embed = R.call_pure_packed(
"rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32")
)
k_embed = R.call_pure_packed(
"rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32")
)
return (q_embed, k_embed, v)
@R.function
def replacement(qkv: R.Tensor([12288], "float32")):
return R.call_pure_packed(
"split_rotary_embedding",
[qkv],
sinfo_args=[
R.Tensor([4096], "float32"),
R.Tensor([4096], "float32"),
R.Tensor([4096], "float32"),
],
)
@R.function(private=True)
def before(
state: R.Tensor([4096], "float32"),
proj_qkv: R.Tensor([12288, 4096], "float32"),
kv_cache: R.Object,
):
qkv = R.matmul(proj_qkv, state)
qkv_tuple = R.split(qkv, 3, axis=0)
q = qkv_tuple[0]
k = qkv_tuple[1]
v = qkv_tuple[2]
q_embed = R.call_pure_packed(
"rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32")
)
k_embed = R.call_pure_packed(
"rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32")
)
attention = R.call_pure_packed(
"compute_self_attention",
[q_embed, k_embed, v, kv_cache],
sinfo_args=R.Tensor([4096]),
)
return attention
@R.function(private=True)
def expected(
state: R.Tensor([4096], "float32"),
proj_qkv: R.Tensor([12288, 4096], "float32"),
kv_cache: R.Object,
):
qkv = R.matmul(proj_qkv, state)
embedded_qkv_tuple = R.call_pure_packed(
"split_rotary_embedding",
[qkv],
sinfo_args=[
R.Tensor([4096], "float32"),
R.Tensor([4096], "float32"),
R.Tensor([4096], "float32"),
],
)
v = embedded_qkv_tuple[2]
q_embed = embedded_qkv_tuple[0]
k_embed = embedded_qkv_tuple[1]
attention = R.call_pure_packed(
"compute_self_attention",
[q_embed, k_embed, v, kv_cache],
sinfo_args=R.Tensor([4096]),
)
return attention
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_pattern_matching_may_not_reorder_across_impure_functions():
"""Matched pattern must be ordered with respect to impure functions
To ensure that debug printouts, memory management, performance
    measurements, etc. are not impacted by a pattern match, a pattern
    must occur entirely before, or entirely after, an impure function. A
pattern match in which some parts of the matched expression are
performed before an impure function, while others are performed
afterwards, is not allowed.
In this test, the matmul and the add may not be fused, because the
impure print statement occurs between them.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(
state: R.Tensor([16], "float32"),
weights: R.Tensor([16, 16], "float32"),
bias: R.Tensor([16], "float32"),
):
state = R.matmul(weights, state)
state = R.add(bias, state)
return state
@R.function
def replacement(
state: R.Tensor([16], "float32"),
weights: R.Tensor([16, 16], "float32"),
bias: R.Tensor([16], "float32"),
):
return R.call_pure_packed(
"my_optimized_fma_impl",
state,
weights,
bias,
sinfo_args=R.Tensor([16], "float32"),
)
@R.function(private=True, pure=False)
def before(
state: R.Tensor([16], "float32"),
weights: R.Tensor([16, 16], "float32"),
bias: R.Tensor([16], "float32"),
):
R.print(format="Start of function")
state = R.matmul(weights, state)
R.print(format="After matmul, before add")
state = R.add(bias, state)
R.print(format="End of function")
return state
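    # The impure `R.print` between the matmul and the add prevents the
    # pattern from matching, so the function is left unchanged.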
expected = before
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_pattern_matching_may_occur_between_impure_functions():
"""Matched pattern may be adjacent to impure functions
To ensure that debug printouts, memory management, performance
    measurements, etc. are not impacted by a pattern match, a pattern
    must occur entirely before, or entirely after, an impure function. A
pattern match in which some parts of the matched expression are
performed before an impure function, while others are performed
afterwards, is not allowed.
In this test, the matmul and the add may be fused, because the
pattern occurs without an impure print statement in-between.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(
state: R.Tensor([16], "float32"),
weights: R.Tensor([16, 16], "float32"),
bias: R.Tensor([16], "float32"),
):
state = R.matmul(weights, state)
state = R.add(bias, state)
return state
@R.function
def replacement(
state: R.Tensor([16], "float32"),
weights: R.Tensor([16, 16], "float32"),
bias: R.Tensor([16], "float32"),
):
return R.call_pure_packed(
"my_optimized_fma_impl",
state,
weights,
bias,
sinfo_args=R.Tensor([16], "float32"),
)
@R.function(private=True, pure=False)
def before(
state: R.Tensor([16], "float32"),
weights: R.Tensor([16, 16], "float32"),
bias: R.Tensor([16], "float32"),
):
R.print(format="Start of function")
state = R.matmul(weights, state)
state = R.add(bias, state)
R.print(format="End of function")
return state
@R.function(private=True, pure=False)
def expected(
state: R.Tensor([16], "float32"),
weights: R.Tensor([16, 16], "float32"),
bias: R.Tensor([16], "float32"),
):
R.print(format="Start of function")
state = R.call_pure_packed(
"my_optimized_fma_impl",
state,
weights,
bias,
sinfo_args=R.Tensor([16], "float32"),
)
R.print(format="End of function")
return state
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_rewrite_may_apply_within_conditional():
"""Rewrites may apply within to inner dataflow regions
While dataflow regions may not contain conditionals, they may
occur within the body of conditionals.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
return A + B
@R.function
def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
return R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
@R.function(private=True)
def before(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")):
if cond:
out = A + B
else:
C = A + B
out = C + B
return out
@R.function(private=True)
def expected(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")):
if cond:
out = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
else:
C = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
out = R.call_pure_packed(
"my_optimized_add_impl", C, B, sinfo_args=R.Tensor([16], "float32")
)
return out
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_match_dynamic_shape():
"""Pattern match/rewrites may be dynamic
The tuple being replaced does not need to explicitly exist within
the updated Relax function. So long as each element of the tuple
pattern matches a Relax expression, the pattern match can apply.
This rule ensures that pattern-matching is never broken when
`CanonicalizeBindings` is applied.
This test is identical to `test_rewrite_of_explicit_relax_tuple`,
except that the function does not contain the round trip of
packing `proj_A` and `proj_B` into a tuple, then immediately
unpacking them from the tuple.
"""
@R.rewriter
class Rewriter:
@R.function
def pattern(
lhs_A: R.Tensor(["N1", "M"], "float32"),
lhs_B: R.Tensor(["N2", "M"], "float32"),
rhs: R.Tensor(["M"], "float32"),
):
proj_A = R.matmul(lhs_A, rhs)
proj_B = R.matmul(lhs_B, rhs)
return (proj_A, proj_B)
@R.function
def replacement(
lhs_A: R.Tensor(["N1", "M"], "float32"),
lhs_B: R.Tensor(["N2", "M"], "float32"),
rhs: R.Tensor(["M"], "float32"),
):
N1 = T.int64()
N2 = T.int64()
lhs = R.concat([lhs_A, lhs_B])
proj_concat = R.matmul(lhs, rhs)
proj_A: R.Tensor([N1], "float32") = R.strided_slice(
proj_concat, axes=[0], begin=[0], end=[N1]
)
proj_B: R.Tensor([N2], "float32") = R.strided_slice(
proj_concat, axes=[0], begin=[N1], end=[N2 + N1]
)
return (proj_A, proj_B)
@R.function(private=True)
def before(
state: R.Tensor([16], "float32"),
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
):
proj_A = R.matmul(A, state)
proj_B = R.matmul(B, state)
out = proj_A + proj_B
return out
@R.function(private=True)
def expected(
state: R.Tensor([16], "float32"),
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
):
concat_AB = R.concat([A, B])
proj_concat = R.matmul(concat_AB, state)
proj_A = R.strided_slice(proj_concat, axes=[0], begin=[0], end=[16])
proj_B = R.strided_slice(proj_concat, axes=[0], begin=[16], end=[32])
out = proj_A + proj_B
return out
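    # The pattern's symbolic shapes `N1`, `N2`, and `M` each bind to
    # 16, so the `strided_slice` bounds in the replacement become the
    # static values seen in `expected`.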
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
def test_match_dynamic_pattern_against_dynamic_shape():
"""A dynamic pattern may match a static shape"""
@R.rewriter
class Rewriter:
@R.function
def pattern(
A: R.Tensor(["M", "N"], "float32"),
B: R.Tensor(["N", "N"], "float32"),
):
return R.matmul(A, B)
@R.function
def replacement(
A: R.Tensor(["M", "N"], "float32"),
B: R.Tensor(["N", "N"], "float32"),
):
M = T.int64()
N = T.int64()
return R.call_pure_packed(
"my_optimized_square_matmul",
A,
B,
sinfo_args=R.Tensor([M, N], "float32"),
)
@R.function(private=True)
def before(
A: R.Tensor(["N", "N*2"], "float32"),
B: R.Tensor(["N*2", "N*2"], "float32"),
C: R.Tensor(["N", "N"], "float32"),
):
N = T.int64()
D: R.Tensor([N, N * 2], "float32") = R.matmul(A, B)
E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D)
F: R.Tensor([N * 2, N], "float32") = R.matmul(E, C)
return F
@R.function(private=True)
def expected(
A: R.Tensor(["N", "N*2"], "float32"),
B: R.Tensor(["N*2", "N*2"], "float32"),
C: R.Tensor(["N", "N"], "float32"),
):
N = T.int64()
D: R.Tensor([N, N * 2], "float32") = R.call_pure_packed(
"my_optimized_square_matmul",
A,
B,
sinfo_args=R.Tensor([N, N * 2], "float32"),
)
E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D)
F: R.Tensor([N * 2, N], "float32") = R.call_pure_packed(
"my_optimized_square_matmul",
E,
C,
sinfo_args=R.Tensor([N * 2, N], "float32"),
)
return F
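    # The pattern's symbolic dimensions may bind to shape expressions:
    # for the first matmul, `M` binds to `N` and the pattern's `N`
    # binds to `N*2`.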
after = Rewriter(before)
tvm.ir.assert_structural_equal(expected, after)
if __name__ == "__main__":
tvm.testing.main()