# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import functools
import math
import pytest
import tvm.testing
from tvm import relax as rx
from tvm import tir
from tvm.relax.analysis import get_var2val
from tvm.relax.dpl import *
from tvm.script import relax as R
from tvm.script import tir as T
@tvm.script.ir_module
class Module:
@T.prim_func
def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
T.func_attr({"global_symbol": "tir_matmul"})
k = T.int32()
A = T.match_buffer(x, (32, 32))
B = T.match_buffer(y, (32, 32))
C = T.match_buffer(z, (32, 32))
for i0, j0, k0 in T.grid(32, 32, 32):
with T.block():
i, j, k = T.axis.remap("SSR", [i0, j0, k0])
with T.init():
C[i, j] = 0.0
C[i, j] += A[i, k] * B[j, k]
@T.prim_func
def tir_relu(x: T.handle, y: T.handle):
T.func_attr({"global_symbol": "tir_relu"})
A = T.match_buffer(x, (32, 32))
B = T.match_buffer(y, (32, 32))
for i, j in T.grid(32, 32):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = T.max(A[vi, vj], 0.0)
@T.prim_func
def tir_zeros(x: T.handle, n: T.int64):
T.func_attr({"global_symbol": "tir_zeros"})
A = T.match_buffer(x, [n])
for i in range(n):
with T.block():
vi = T.axis.remap("S", [i])
A[vi] = 1.0
@R.function
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tuple:
cls = Module
with R.dataflow():
lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_tir(
cls.tir_zeros, [], R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32])
)
gv = (lv1, lv2)
R.output(gv)
return gv
main_fn = Module["main"]
bindings = main_fn.body.blocks[0].bindings
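# Bindings of main's dataflow block, used as fixtures throughout the tests below:
# bindings[0] -> lv0 (call_tir tir_matmul), bindings[1] -> lv1 (call_tir tir_relu),
# bindings[2] -> lv2 (call_tir tir_zeros).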
## Node-wise Matching
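# Node-wise patterns match a single Relax expression via `pattern.match(expr)`.
# For example, is_op("relax.add")(wildcard(), wildcard()).match(rx.op.add(x, y))
# holds for any operands x and y, as the tests below demonstrate.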
def test_expr_pattern():
ep = is_expr(rx.Var("x"))
assert isinstance(ep, ExprPattern)
assert isinstance(ep.expr, rx.Var)
def test_var_pattern():
v = is_var("x")
assert isinstance(v, VarPattern)
assert v.name == "x"
assert v.match(rx.Var("x"))
assert is_var().match(rx.Var("x"))
assert is_var().match(rx.DataflowVar("x")) # DataflowVar is also a Var
assert not v.match(rx.GlobalVar("x"))
def test_dataflow_var_pattern():
v = is_dfv("x")
assert isinstance(v, DataflowVarPattern)
assert v.name == "x"
assert v.match(rx.DataflowVar("x"))
assert not v.match(rx.GlobalVar("x"))
assert is_dfv().match(bindings[0].var)
def test_global_var_pattern():
assert is_gv("x").match(rx.GlobalVar("x"))
# TODO: disabled as regex is not supported due to
# symbol conflict with PyTorch
# assert is_gv("x.*").match(rx.GlobalVar("x_2"))
assert is_gv().match(rx.GlobalVar("x"))
assert not is_gv("x").match(rx.GlobalVar("y"))
assert not is_gv("x").match(rx.Var("x"))
def test_constant_pattern():
c = is_const()
assert isinstance(c, ConstantPattern)
assert c.match(rx.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]]))
def test_wildcard_pattern():
wc = wildcard()
assert isinstance(wc, WildcardPattern)
assert wc.match(rx.Var("x"))
def test_call_pattern():
wc1 = wildcard()
wc2 = wildcard()
c = is_op("relax.add")(wc1, wc2)
assert isinstance(c, CallPattern)
assert isinstance(c.args[0], WildcardPattern)
assert isinstance(c.args[1], WildcardPattern)
assert c.match(rx.op.add(rx.Var("x"), rx.Var("y")))
def test_function_pattern():
wc1 = wildcard()
wc2 = wildcard()
f = FunctionPattern([wc1, wc2], is_op("relax.add")(wc1, wc2))
assert isinstance(f, FunctionPattern)
assert isinstance(f.params[0], WildcardPattern)
assert isinstance(f.params[1], WildcardPattern)
assert isinstance(f.body, CallPattern)
assert isinstance(f.body.args[0], WildcardPattern)
assert isinstance(f.body.args[1], WildcardPattern)
x = rx.Var("x", R.Tensor("float32"))
y = rx.Var("y", R.Tensor("float32"))
assert f.match(rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32")))
assert not f.match(
rx.Function([x, y], rx.op.multiply(x, y), ret_struct_info=R.Tensor("float32"))
)
def test_tuple_pattern():
wc1 = wildcard()
wc2 = is_dfv()
t = is_tuple([wc1, wc2])
assert isinstance(t, TuplePattern)
assert isinstance(t.fields[0], WildcardPattern)
assert isinstance(t.fields[1], DataflowVarPattern)
assert t.match(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]))
assert not t.match(rx.Tuple([rx.DataflowVar("x"), rx.GlobalVar("y")]))
assert not t.match(rx.Tuple([]))
assert t[0].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0))
assert t[1].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1))
# Negative index is also allowed
assert t[-1].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1))
# None means any index.
assert t[None].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0))
assert t[None].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1))
with pytest.raises(IndexError):
t[2] # index cannot be greater than or equal to the tuple size.
def test_unordered_tuple_pattern():
t = is_tuple([is_const(), is_dfv()], unordered=True)
assert isinstance(t, UnorderedTuplePattern)
assert isinstance(t.fields[0], ConstantPattern)
assert isinstance(t.fields[1], DataflowVarPattern)
assert t.match(rx.Tuple([rx.const([]), rx.DataflowVar("x")]))
assert t.match(rx.Tuple([rx.DataflowVar("x"), rx.const([])]))
assert not t.match(rx.Tuple([rx.DataflowVar("x"), rx.DataflowVar("y")]))
assert not t.match(rx.Tuple([]))
def test_tuple_get_item_pattern():
assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match(
rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)
)
assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match(
rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)
)
def test_or_pattern():
dfv_or_gv = is_dfv("x") | is_gv("x")
assert isinstance(dfv_or_gv, OrPattern)
assert dfv_or_gv.match(rx.DataflowVar("x"))
assert dfv_or_gv.match(rx.GlobalVar("x"))
assert not dfv_or_gv.match(rx.Var("x"))
assert not dfv_or_gv.match(rx.DataflowVar("y"))
assert not dfv_or_gv.match(rx.GlobalVar("y"))
def test_and_pattern():
# float[2, 3, 3]
f32_233 = wildcard().has_shape((2, 3, 3)) & has_dtype("float32")
assert isinstance(f32_233, AndPattern)
assert f32_233.match(rx.Var("x", R.Tensor((2, 3, 3), "float32")))
assert not f32_233.match(rx.Var("x", R.Tensor((3, 3, 3), "float32")))
assert not f32_233.match(rx.Var("x", R.Tensor("float32", ndim=3)))
def test_not_pattern():
no_shape233 = ~wildcard().has_shape((2, 3, 3))
assert isinstance(no_shape233, NotPattern)
assert no_shape233.match(rx.Var("x", R.Tensor((3, 3, 3), "float32")))
assert not no_shape233.match(rx.Var("x", R.Tensor((2, 3, 3), "float32")))
def test_dtype_pattern():
dtype = "float16"
pattern = has_dtype(dtype)
assert isinstance(pattern, DataTypePattern)
assert pattern.dtype == dtype
assert has_dtype("float32").match(bindings[0].var)
def test_shape_pattern():
shape = [32, 32]
pattern = wildcard().has_shape(shape)
assert isinstance(pattern, ShapePattern)
tvm.ir.structural_equal(pattern.shape, shape)
assert pattern.match(bindings[0].var)
assert wildcard().has_shape([32, 32]).match(bindings[0].var)
n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64")
symsh_var = rx.Var("x", R.Tensor([n, m, n + m], "float32"))
assert wildcard().has_shape([n, m, n + m]).match(symsh_var)
assert wildcard().has_shape([n, m, m + n]).match(symsh_var) # + is commutative.
assert not wildcard().has_shape([1, 2, 3]).match(symsh_var)
assert not wildcard().has_shape([m, n, n + m]).match(symsh_var)
def test_prim_arr_pattern():
"""
The difference between is_shape and has_shape is that:
1) is_shape directly matches a shape (e.g., as an argument);
2) has_shape matches a tensor and puts assumptions on the tensor's shape.
"""
pattern = is_shape([32, 32])
assert pattern[0] == 32
assert pattern[1] == 32
assert isinstance(pattern, PrimArrPattern)
assert pattern.match(rx.get_shape_of(bindings[0].var))
n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64")
symbolic_shape = rx.ShapeExpr([n, m, n + m])
assert is_shape([n, m, n + m]).match(symbolic_shape)
assert not is_shape([n, m, n * m]).match(symbolic_shape)
def test_extern_fn_pattern():
pattern = ExternFuncPattern("test.blockbuilder.nop")
assert pattern.match(rx.ExternFunc("test.blockbuilder.nop"))
def test_op_attr():
x = rx.Var("x", R.Tensor("float32"))
y = rx.Var("y", R.Tensor("float32"))
conv2d = rx.op.nn.conv2d(x, y, strides=(3, 3))
xp = is_var("x")
yp = is_var("y")
# TODO(@yuchen): reenable the assert after figuring out why it fails
# assert is_op("nn.conv2d")(xp, yp).has_attr({"strides": [3, 3]}).match(conv2d)
assert not is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [4, 3]}).match(conv2d)
assert not is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [3, 3]}).match(conv2d)
def test_match_call_attr():
x = rx.Var("x", R.Tensor("float32"))
y = rx.Var("y", R.Tensor("float32"))
fn = rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32"))
annotated_fn = fn.with_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"})
xp = is_var("x")
yp = is_var("y")
root_pattern = FunctionPattern([xp, yp], is_op("relax.add")(xp, yp))
assert root_pattern.has_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"}).match(
annotated_fn
)
assert root_pattern.has_attr({"Codegen": "test-codegen"}).match(annotated_fn)
assert not root_pattern.has_attr({"ping": "pong"}).match(annotated_fn)
assert root_pattern.has_attr({}).match(annotated_fn)
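# Note: has_attr({}) matches regardless of which attributes are present, while a
# non-empty dict requires each listed key to exist with an equal value, as the
# assertions above show.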
def test_is_call_tir():
lv1_val = bindings[1].value
lv2_val = bindings[2].value
var2val = get_var2val(Module["main"])
assert is_call_tir("tir_relu").match(lv1_val)
assert is_call_tir("tir_relu", [is_call_tir("tir_matmul")]).match(lv1_val, var2val=var2val)
assert not is_call_tir("tir_relu", [is_call_tir("tir_relu")]).match(lv1_val, var2val=var2val)
assert is_call_tir("tir_zeros", wildcard(), wildcard()).match(lv2_val, var2val=var2val)
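# Note: is_call_tir matches R.call_tir calls by the callee PrimFunc's name; when a
# var2val map (built with get_var2val) is supplied, nested argument patterns are
# matched through the variables' bound values, as exercised above.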
@R.function(pure=False)
def simple_call_packed(
x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")
) -> R.Tensor:
gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32")))
return gv0
def test_varg_default_wildcard():
expr = simple_call_packed.body.blocks[0].bindings[0].value
yes_pattern_explicit = ExternFuncPattern("test.vm.mul")(wildcard(), wildcard())
yes_pattern_implicit = ExternFuncPattern("test.vm.mul")(varg_default_wildcard=True)
no_pattern = ExternFuncPattern("test.vm.mul")(wildcard())
assert yes_pattern_explicit.match(expr)
assert yes_pattern_implicit.match(expr)
assert not no_pattern.match(expr)
def test_simple_call_packed():
expr = simple_call_packed.body.blocks[0].bindings[0].value
assert is_call_packed("test.vm.mul").match(expr)
assert is_call_packed("test.vm.mul", [is_var("x"), is_var("w")]).match(expr)
## Graph-wise Matching
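# Graph-wise patterns are registered inside a PatternContext and matched against a
# whole dataflow block with ctx.match_dfb(dfb). As the tests below illustrate,
# `a ^ b` (or a.used_by(b)) constrains a to be used by b, while `a >> b`
# (or a.only_used_by(b)) constrains a to be *only* used by b.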
def test_simple_used_by():
with PatternContext() as ctx:
n0 = is_var("x") # x is a free var (fn arg)
n1 = wildcard()
n0 ^ n1
dfb = main_fn.body.blocks[0]
matched = ctx.match_dfb(dfb)
assert matched
assert matched[n0] == main_fn.params[0]
assert matched[n1] == dfb.bindings[0].var
def test_simple_call_tir_edge():
with PatternContext() as ctx:
n0 = is_call_tir("tir_matmul")
n1 = is_call_tir("tir_relu")
n0.used_by(n1)
dfb = main_fn.body.blocks[0]
matched = ctx.match_dfb(dfb)
assert matched
assert matched[n0] == dfb.bindings[0].var
assert matched[n1] == dfb.bindings[1].var
def test_simple_oub():
with PatternContext() as ctx:
n0 = is_call_tir("tir_matmul")
n1 = is_call_tir("tir_relu")
n0 >> n1
dfb = main_fn.body.blocks[0]
matched = ctx.match_dfb(dfb)
assert matched
assert matched[n0] == dfb.bindings[0].var
assert matched[n1] == dfb.bindings[1].var
def test_counter_syntax_match():
with PatternContext() as ctx:
n0 = is_call_dps_packed("extern_matmul")
n1 = is_call_dps_packed("extern_impossible")
n0 >> n1
dfb = main_fn.body.blocks[0]
assert not ctx.match_dfb(dfb)
with PatternContext() as ctx:
n0 = is_call_dps_packed("extern_matmul")
n1 = is_call_dps_packed("extern_impossible")
n0 ^ n1
dfb = main_fn.body.blocks[0]
assert not ctx.match_dfb(dfb)
@tvm.script.ir_module
class Diamond:
@R.function
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# matmul
# / \
# relu sigmoid
# \ /
# add
lv0 = R.call_dps_packed("extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("extern_relu", (lv0,), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32"))
lv3 = R.call_dps_packed("extern_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32"))
R.output(lv3)
return lv3
def test_diamond():
with PatternContext() as ctx:
n0 = is_call_dps_packed("extern_matmul")
n1 = is_call_dps_packed("extern_relu")
n2 = is_call_dps_packed("extern_sigmoid")
n3 = is_call_dps_packed("extern_add")
n0 ^ n1
n0 ^ n2
n1 >> n3
n2 >> n3
dfb = Diamond["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
# simplify it with fork_to
with PatternContext() as ctx:
n1 = is_call_dps_packed("extern_relu")
n2 = is_call_dps_packed("extern_sigmoid")
n3 = is_call_dps_packed("extern_add")
is_call_dps_packed("extern_matmul").fork_to(n1, n2)
n1 >> n3
n2 >> n3
dfb = Diamond["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
def test_diamond_counter_oub():
with PatternContext() as ctx:
n0 = is_call_dps_packed("extern_matmul")
n1 = is_call_dps_packed("extern_relu")
n2 = is_call_dps_packed("extern_sigmoid")
n3 = is_call_dps_packed("extern_add")
n0 >> n1
n0 >> n2
n1 >> n3
n2 >> n3
dfb = Diamond["main"].body.blocks[0]
assert not ctx.match_dfb(dfb)
@tvm.script.ir_module
class SmallDiamond:
@R.function
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# relu
# / \
# \ /
# add
lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32"))
R.output(lv1)
return lv1
@tvm.script.ir_module
class SmallParallel:
@R.function
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# relu relu
# \ /
# add
lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_dps_packed("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32"))
R.output(lv2)
return lv2
def test_distinguish_diamond_and_parallel():
# pattern lang cannot distinguish the two cases above.
diamond = SmallDiamond["main"].body.blocks[0]
parallel = SmallParallel["main"].body.blocks[0]
with PatternContext() as ctx:
# describe a diamond pattern
fork = is_call_dps_packed("my_relu")
join = is_call_dps_packed("my_add")
fork.only_used_by(join, index=0)
fork.only_used_by(join, index=1)
assert ctx.match_dfb(diamond)
assert not ctx.match_dfb(parallel)
with PatternContext() as ctx:
# describe a parallel pattern
join = is_call_dps_packed("my_add")
        # Due to one-to-one matching:
        # is_call_dps_packed("my_relu") creates the first relu pattern
is_call_dps_packed("my_relu") >> join
        # A second is_call_dps_packed("my_relu") creates another,
        # distinct relu pattern (its object address is different).
is_call_dps_packed("my_relu") >> join
assert ctx.match_dfb(parallel)
assert not ctx.match_dfb(diamond)
@tvm.script.ir_module
class CBRx2:
@R.function
def main(
x: R.Tensor((32, 32), "float32"),
w0: R.Tensor((1, 1), "float32"),
bias0: R.Tensor((32, 32), "float32"),
w1: R.Tensor((1, 1), "float32"),
bias1: R.Tensor((32, 32), "float32"),
) -> R.Tensor:
        # TensorRT's CBR optimization pattern
# input
# / \
# cbr0 cbr1
# \ /
# concat
with R.dataflow():
lv0 = R.call_dps_packed("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_dps_packed("my_relu", (lv1), R.Tensor((32, 32), dtype="float32"))
lv3 = R.call_dps_packed("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32"))
lv4 = R.call_dps_packed("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32"))
lv5 = R.call_dps_packed("my_relu", (lv4), R.Tensor((32, 32), dtype="float32"))
lv6 = R.call_dps_packed("concat", (lv2, lv5), R.Tensor((32, 64), dtype="float32"))
R.output(lv6)
return lv6
def test_nested_context():
dfb = CBRx2["main"].body.blocks[0]
with PatternContext() as ctx0:
(
is_call_dps_packed("conv1x1")
>> is_call_dps_packed("bias_add")
>> is_call_dps_packed("my_relu")
)
with PatternContext() as ctx1:
is_call_dps_packed("conv1x1") >> is_call_dps_packed("my_relu") # pattern to miss
with PatternContext() as ctx2:
is_call_dps_packed("bias_add") >> is_call_dps_packed("my_relu")
assert ctx2.match_dfb(dfb)
assert PatternContext.current() == ctx2
assert not ctx1.match_dfb(dfb)
assert PatternContext.current() == ctx1
assert ctx0.match_dfb(dfb)
assert PatternContext.current() == ctx0
def test_two_cbr():
with PatternContext() as ctx:
cbr0 = (
is_call_dps_packed("conv1x1")
>> is_call_dps_packed("bias_add")
>> is_call_dps_packed("my_relu")
)
cbr1 = cbr0.dup()
assert cbr0.patterns[0] != cbr1.patterns[0]
assert cbr0.patterns[1] != cbr1.patterns[1]
assert cbr0.patterns[2] != cbr1.patterns[2]
is_var("x").fork_to(cbr0, cbr1)
dfb = CBRx2["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
# Deny the pattern
cbr0 = (
is_call_dps_packed("conv1x1")
>> is_call_dps_packed("bias_add")
>> is_call_dps_packed("my_relu")
)
cbr1 = cbr0.dup()
# input has no fork at y.
is_var("y").fork_to(cbr0, cbr1)
dfb = CBRx2["main"].body.blocks[0]
assert not ctx.match_dfb(dfb)
def test_two_matmul():
    # Same as Figure 2(a) in the TASO paper.
@tvm.script.ir_module
class MatMul2:
@R.function
def main(
a: R.Tensor((32, 16), "float32"),
b: R.Tensor((16, 48), "float32"),
c: R.Tensor((48, 32), "float32"),
) -> R.Tensor:
with R.dataflow():
lv0 = R.call_dps_packed("matmul", (a, b), R.Tensor((32, 48), dtype="float32"))
lv1 = R.call_dps_packed("matmul", (lv0, c), R.Tensor((32, 32), dtype="float32"))
R.output(lv1)
return lv1
with PatternContext() as ctx:
is_call_dps_packed("matmul") >> is_call_dps_packed("matmul")
dfb = MatMul2["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
is_call_dps_packed("matmul").has_shape([32, 48]) >> is_call_dps_packed("matmul").has_shape(
[32, 32]
)
dfb = MatMul2["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") >> is_call_dps_packed("matmul")
dfb = MatMul2["main"].body.blocks[0]
# Three MatMul cannot match
assert not ctx.match_dfb(dfb)
def test_concat_mm_split():
    # Same as Figure 2(b) in the TASO paper.
@tvm.script.ir_module
class CMS:
@R.function
def main(
a: R.Tensor((32, 32), "float32"),
b: R.Tensor((16, 32), "float32"),
c: R.Tensor((16, 32), "float32"),
) -> R.Tensor:
with R.dataflow():
lv0 = R.call_dps_packed("my_concat", (b, c), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_dps_packed(
"my_split",
(lv1,),
[R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32), dtype="float32")],
)
lv3 = R.TupleGetItem(lv2, 0)
lv4 = R.TupleGetItem(lv2, 1)
lv5 = R.add(lv3, lv4)
R.output(lv5)
return lv5
with PatternContext() as ctx:
(
is_call_dps_packed("my_concat")
>> is_call_dps_packed("my_matmul")
>> is_call_dps_packed("my_split")
)
dfb = CMS["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
split = is_call_dps_packed("my_split")
lv3 = TupleGetItemPattern(split, 0).has_shape([16, 32])
lv4 = TupleGetItemPattern(split, 1).has_shape([16, 32])
split.fork_to(lv3, lv4)
add = is_op("relax.add")(lv3, lv4)
# TODO(@ganler): simplify this through implicit graph pattern.
lv3 >> add
lv4 >> add
dfb = CMS["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
def test_self_attention():
    # The example comes from:
# https://developer.nvidia.com/blog/nlu-with-tensorrt-bert/
@tvm.script.ir_module
class SelfAttention:
@R.function
def main(
x: R.Tensor(("b", "s", "n", "h"), "float32"),
wq: R.Tensor(("h", "h"), "float32"),
wk: R.Tensor(("h", "h"), "float32"),
wv: R.Tensor(("h", "h"), "float32"),
) -> R.Tensor:
b, s, n, h = T.int64(), T.int64(), T.int64(), T.int64()
with R.dataflow():
fcq = R.call_dps_packed("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32"))
tpq = R.call_dps_packed(
"my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32")
)
fck = R.call_dps_packed("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32"))
tpk = R.call_dps_packed(
"my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32")
)
mul = R.multiply(tpq, tpk)
scale = R.multiply(mul, R.const(1.1, "float32"))
softmax = R.call_dps_packed(
"softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32")
)
fcv = R.call_dps_packed("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32"))
tpv = R.call_dps_packed(
"my_transpose", (fcv,), R.Tensor((b, s, h, n), dtype="float32")
)
out = R.multiply(softmax, tpv)
R.output(out)
return out
with PatternContext() as ctx:
fc_trans_q = is_call_dps_packed("my_fc") >> is_call_dps_packed("my_transpose")
fc_trans_k = fc_trans_q.dup()
fc_trans_v = fc_trans_q.dup()
is_var("x").fork_to(fc_trans_q, fc_trans_k, fc_trans_v)
dfb = SelfAttention["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
def test_nested_diamond():
@tvm.script.ir_module
class DiamondInDiamond:
@R.function
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# matmul0 matmul1
# / \ / \
# sigmoid2 add4 sigmoid3
# \ / \ /
# add5 add6
# \ /
# add7
lv0 = R.call_dps_packed(
"extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")
)
lv1 = R.call_dps_packed(
"extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")
)
lv2 = R.call_dps_packed(
"extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")
)
lv3 = R.call_dps_packed(
"extern_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32")
)
lv4 = R.call_dps_packed(
"extern_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")
)
lv5 = R.call_dps_packed(
"extern_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32")
)
lv6 = R.call_dps_packed(
"extern_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32")
)
lv7 = R.call_dps_packed(
"extern_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32")
)
R.output(lv7)
return lv7
# match matmul0 diamond
with PatternContext() as ctx:
sigmoid2 = is_call_dps_packed("extern_sigmoid")
add4 = is_call_dps_packed("extern_add")
is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4)
add5 = is_call_dps_packed("extern_add")
sigmoid2 >> add5
add4 ^ add5
assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
# counter case: mis-match matmul0 diamond
with PatternContext() as ctx:
sigmoid2 = is_call_dps_packed("extern_sigmoid")
add4 = is_call_dps_packed("extern_add")
is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4)
add5 = is_call_dps_packed("extern_add")
sigmoid2 >> add5
add4 >> add5 # not only-used-by relation
assert not ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
# match matmul1 diamond
with PatternContext() as ctx:
sigmoid3 = is_call_dps_packed("extern_sigmoid")
add4 = is_call_dps_packed("extern_add")
is_call_dps_packed("extern_matmul").fork_to(sigmoid3, add4)
add6 = is_call_dps_packed("extern_add")
sigmoid3 >> add6
add4 ^ add6
assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
# match add-4-5-6-7
with PatternContext() as ctx:
add5, add6, add7 = (
is_call_dps_packed("extern_add"),
is_call_dps_packed("extern_add"),
is_call_dps_packed("extern_add"),
)
is_call_dps_packed("extern_add").fork_to(add5, add6) # add4
add5 >> add7
add6 >> add7
assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
def test_incremental_solving():
@R.function
def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# relu -> sigmoid -> neg
lv0 = R.call_dps_packed("extern_relu", (x), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32"))
lv2 = R.call_dps_packed("extern_neg", (lv1), R.Tensor((32, 32), dtype="float32"))
R.output(lv2)
return lv2
relu = is_call_dps_packed("extern_relu")
sigmoid = is_call_dps_packed("extern_sigmoid")
neg = is_call_dps_packed("extern_neg")
with PatternContext() as ctx0:
relu >> sigmoid
with PatternContext(incremental=True) as ctx1:
            # Because we are doing incremental solving,
            # relu >> sigmoid is still a constraint in this context.
            # That is, the total constraint is:
            # relu >> sigmoid >> neg
sigmoid >> neg
assert ctx1.match_dfb(simple_chain.body.blocks[0])
        # match relu -> sigmoid
assert ctx0.match_dfb(simple_chain.body.blocks[0])
def test_incremental_solving_counter():
@R.function
def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# sigmoid -> neg
lv0 = R.call_dps_packed("extern_sigmoid", (x), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_dps_packed("extern_neg", (lv0), R.Tensor((32, 32), dtype="float32"))
R.output(lv1)
return lv1
relu = is_call_dps_packed("extern_relu")
sigmoid = is_call_dps_packed("extern_sigmoid")
neg = is_call_dps_packed("extern_neg")
with PatternContext() as ctx0:
relu >> sigmoid # cannot match
with PatternContext(incremental=False) as ctx1:
# total constraint: sigmoid >> neg
sigmoid >> neg
assert ctx1.match_dfb(simple_chain.body.blocks[0])
with PatternContext(incremental=True) as ctx1:
# total constraint: relu >> sigmoid >> neg
sigmoid >> neg
assert not ctx1.match_dfb(simple_chain.body.blocks[0])
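# The expression-rewriting tests below use rewrite_call(pattern, rewriter, func):
# for every match of `pattern` in `func`, `rewriter(matched_call, matchings)` is
# invoked, where `matchings` maps each sub-pattern to the expression it captured;
# returning the original call leaves that site unchanged.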
def test_rewrite_simple():
@R.function
def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
with R.dataflow():
x2 = R.add(x, x)
x4 = R.add(x2, x2)
R.output(x4)
return x4
@R.function
def expected1(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
with R.dataflow():
lv: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2, "float32"))
x4: R.Tensor((16, 16), dtype="float32") = R.multiply(lv, R.const(2, "float32"))
R.output(x4)
return x4
@R.function
def expected2(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
with R.dataflow():
x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4, "float32"))
R.output(x4)
return x4
x = wildcard()
pattern = is_op("relax.add")(x, x)
def rewriter(_, matchings):
return R.multiply(matchings[x], R.const(2, "float32"))
rewritten = rewrite_call(pattern, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, expected1.with_attr("global_symbol", "main"))
add1 = is_op("relax.add")(x, x)
pattern = is_op("relax.add")(add1, add1)
def rewriter(_, matchings):
return R.multiply(matchings[x], R.const(4, "float32"))
rewritten = rewrite_call(pattern, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, expected2.with_attr("global_symbol", "main"))
# No rewriting, return the original call node as is
def rewriter(orig, _):
return orig
rewritten = rewrite_call(pattern, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, main)
def test_rewrite_attention():
@R.function
def main(
Q: R.Tensor((2, 4096, 8, 40), "float32"),
K: R.Tensor((2, 4096, 8, 40), "float32"),
V: R.Tensor((2, 4096, 8, 40), "float32"),
) -> R.Tensor((2, 4096, 8, 40), "float32"):
with R.dataflow():
lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3])
lv59 = R.reshape(lv58, R.shape([16, 4096, 40]))
lv61 = R.permute_dims(K, axes=[0, 2, 1, 3])
lv62 = R.reshape(lv61, R.shape([16, 4096, 40]))
lv64 = R.permute_dims(V, axes=[0, 2, 1, 3])
lv65 = R.reshape(lv64, R.shape([16, 4096, 40]))
lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1])
lv3_1 = R.matmul(lv59, lv62_transposed)
lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
lv69 = R.nn.softmax(lv68, axis=-1)
lv_3 = R.matmul(lv69, lv65)
lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3])
R.output(lv72)
return lv72
@R.function
def expected(
Q: R.Tensor((2, 4096, 8, 40), dtype="float32"),
K: R.Tensor((2, 4096, 8, 40), dtype="float32"),
V: R.Tensor((2, 4096, 8, 40), dtype="float32"),
) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
with R.dataflow():
lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.nn.attention(Q, V, K)
R.output(lv72)
return lv72
def BSNH_to_BSH(tensor):
return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor), wildcard())
def BSH_to_BSNH(tensor):
return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor, wildcard()))
Q = wildcard()
K = wildcard()
V = wildcard()
Q_3D = BSNH_to_BSH(Q)
V_3D = BSNH_to_BSH(V)
K_3D = BSNH_to_BSH(K)
matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(V_3D))
multiply = is_op("relax.multiply")(matmul1, is_const())
softmax = is_op("relax.nn.softmax")(multiply)
matmul2 = is_op("relax.matmul")(softmax, K_3D)
pattern = BSH_to_BSNH(matmul2)
def rewriter(_, matchings):
return R.nn.attention(matchings[Q], matchings[K], matchings[V])
rewritten = rewrite_call(pattern, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "main"))
def test_attention_qkv():
@tvm.script.ir_module
class QKV_proj:
@R.function
def main(
x: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
) -> R.Tensor:
with R.dataflow():
lv0 = R.matmul(x, w0)
lv1 = R.matmul(x, w1)
lv2 = R.matmul(x, w2)
out = (lv0, lv1, lv2)
R.output(out)
return out
with PatternContext() as ctx:
inp_pat = wildcard()
Q_weight_pat = wildcard()
K_weight_pat = wildcard()
V_weight_pat = wildcard()
matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
dfb = QKV_proj["main"].body.blocks[0]
out = ctx.match_dfb(dfb)
assert out[Q_weight_pat].name_hint == "w0"
assert out[K_weight_pat].name_hint == "w1"
assert out[V_weight_pat].name_hint == "w2"
def test_attention_fake_qkv():
@tvm.script.ir_module
class QKV_proj:
@R.function
def main(
x1: R.Tensor((2, 1024, 640), "float32"),
x2: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
) -> R.Tensor:
with R.dataflow():
lv0 = R.matmul(x1, w0)
lv1 = R.matmul(x2, w1)
lv2 = R.matmul(x2, w2)
out = (lv0, lv1, lv2)
R.output(out)
return out
with PatternContext() as ctx:
inp_pat = wildcard()
Q_weight_pat = wildcard()
K_weight_pat = wildcard()
V_weight_pat = wildcard()
matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
dfb = QKV_proj["main"].body.blocks[0]
assert ctx.match_dfb(dfb) is None
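# The helper below pairs a PatternContext (three matmuls sharing one input) with a
# rewriter for rewrite_bindings. The rewriter returns a dict mapping each matched
# binding to its replacement, so all three matmuls are rewritten at once into a
# single fused matmul followed by strided slices.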
def get_qkv_proj_rewriter():
with PatternContext() as ctx:
inp_pat = wildcard()
Q_weight_pat = wildcard()
K_weight_pat = wildcard()
V_weight_pat = wildcard()
matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
def qkv_proj_rewriter(matchings, _):
inp = matchings[inp_pat]
Q_weight = matchings[Q_weight_pat]
K_weight = matchings[K_weight_pat]
V_weight = matchings[V_weight_pat]
width = Q_weight.struct_info.shape[1]
concat = R.concat([Q_weight, K_weight, V_weight], axis=1)
matmul = R.matmul(inp, concat)
Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width])
K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2])
V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3])
return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V}
return ctx, qkv_proj_rewriter
def test_combine_matmul_twice():
@R.function(private=True)
def qkv_x2(
x1: R.Tensor((2, 1024, 640), "float32"),
x2: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
w3: R.Tensor((640, 640), "float32"),
w4: R.Tensor((640, 640), "float32"),
w5: R.Tensor((640, 640), "float32"),
):
with R.dataflow():
lv0 = R.matmul(x1, w0)
lv1 = R.matmul(x1, w1)
lv2 = R.matmul(x1, w2)
lv3 = R.matmul(x2, w3)
lv4 = R.matmul(x2, w4)
lv5 = R.matmul(x2, w5)
out = (lv0, lv1, lv2, lv3, lv4, lv5)
R.output(out)
return out
@R.function(private=True)
def expected(
x1: R.Tensor((2, 1024, 640), "float32"),
x2: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
w3: R.Tensor((640, 640), "float32"),
w4: R.Tensor((640, 640), "float32"),
w5: R.Tensor((640, 640), "float32"),
):
with R.dataflow():
lv = R.concat((w0, w1, w2), axis=1)
lv1 = R.matmul(x1, lv)
lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640])
lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280])
lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920])
lv2_1 = R.concat((w3, w4, w5), axis=1)
lv3 = R.matmul(x2, lv2_1, out_dtype="void")
lv3_1 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640])
lv4 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280])
lv5 = R.strided_slice(lv3, axes=[2], begin=[1280], end=[1920])
out = lv0, lv1_1, lv2, lv3_1, lv4, lv5
R.output(out)
return out
ctx, rewriter = get_qkv_proj_rewriter()
rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
tvm.ir.assert_structural_equal(rewritten, expected)
def test_dataflow_may_start_with_match_cast():
"""Inputs to rewrite_bindings may contain R.match_cast
This is a regression test. In previous implementations, applying
`rewrite_bindings` when `R.match_cast` is the first binding of a
`R.dataflow` block would cause a segfault.
"""
@R.function(private=True)
def before(
x_untyped: R.Tensor,
w0_untyped: R.Tensor,
w1_untyped: R.Tensor,
w2_untyped: R.Tensor,
):
with R.dataflow():
x = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), "float32"))
w0 = R.match_cast(w0_untyped, R.Tensor((640, 640), "float32"))
w1 = R.match_cast(w1_untyped, R.Tensor((640, 640), "float32"))
w2 = R.match_cast(w2_untyped, R.Tensor((640, 640), "float32"))
out_0 = R.matmul(x, w0)
out_1 = R.matmul(x, w1)
out_2 = R.matmul(x, w2)
out = (out_0, out_1, out_2)
R.output(out)
return out
@R.function(private=True)
def expected(
x_untyped: R.Tensor,
w0_untyped: R.Tensor,
w1_untyped: R.Tensor,
w2_untyped: R.Tensor,
):
with R.dataflow():
x = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), "float32"))
w0 = R.match_cast(w0_untyped, R.Tensor((640, 640), "float32"))
w1 = R.match_cast(w1_untyped, R.Tensor((640, 640), "float32"))
w2 = R.match_cast(w2_untyped, R.Tensor((640, 640), "float32"))
w_concat = R.concat((w0, w1, w2), axis=1)
out_concat = R.matmul(x, w_concat)
out_0 = R.strided_slice(out_concat, axes=[2], begin=[0], end=[640])
out_1 = R.strided_slice(out_concat, axes=[2], begin=[640], end=[1280])
out_2 = R.strided_slice(out_concat, axes=[2], begin=[1280], end=[1920])
out = (out_0, out_1, out_2)
R.output(out)
return out
ctx, rewriter = get_qkv_proj_rewriter()
rewritten = rewrite_bindings(ctx, rewriter, before)
tvm.ir.assert_structural_equal(rewritten, expected)
def test_combine_matmul_emit_order():
@R.function(private=True)
def main(
x1: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
):
with R.dataflow():
w0_t = R.permute_dims(w0, axes=None)
lv0 = R.matmul(x1, w0_t)
w1_t = R.permute_dims(w1, axes=None)
w1_t_t = R.permute_dims(w1_t, axes=None)
lv1 = R.matmul(x1, w1_t_t)
w2_t = R.permute_dims(w2, axes=None)
lv2 = R.matmul(x1, w2_t)
out = (lv0, lv1, lv2)
R.output(out)
return out
@R.function(private=True)
def expected(
x1: R.Tensor((2, 1024, 640), dtype="float32"),
w0: R.Tensor((640, 640), dtype="float32"),
w1: R.Tensor((640, 640), dtype="float32"),
w2: R.Tensor((640, 640), dtype="float32"),
):
with R.dataflow():
w0_t = R.permute_dims(w0, axes=None)
w1_t = R.permute_dims(w1, axes=None)
w1_t_t = R.permute_dims(w1_t, axes=None)
w2_t = R.permute_dims(w2, axes=None)
lv = R.concat((w0_t, w1_t_t, w2_t), axis=1)
lv1 = R.matmul(x1, lv, out_dtype="void")
lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640])
lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280])
lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920])
out = lv0, lv1_1, lv2
R.output(out)
return out
ctx, rewriter = get_qkv_proj_rewriter()
rewritten = rewrite_bindings(ctx, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, expected)
# make sure it builds
mod = tvm.IRModule()
mod["main"] = rewritten
tvm.compile(mod, target="llvm")
def test_combine_transposed_matmul_twice():
@R.function(private=True)
def main(
x1: R.Tensor((2, 1024, 640), "float32"),
x2: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
w3: R.Tensor((640, 640), "float32"),
):
with R.dataflow():
w0_t = R.permute_dims(w0, axes=None)
lv0 = R.matmul(x1, w0_t)
w1_t = R.permute_dims(w1, axes=None)
lv1 = R.matmul(x1, w1_t)
w2_t = R.permute_dims(w2, axes=None)
lv2 = R.matmul(x2, w2_t)
w3_t = R.permute_dims(w3, axes=None)
lv3 = R.matmul(x2, w3_t)
out = (lv0, lv1, lv2, lv3)
R.output(out)
return out
@R.function(private=True)
def expected(
x1: R.Tensor((2, 1024, 640), dtype="float32"),
x2: R.Tensor((2, 1024, 640), dtype="float32"),
w0: R.Tensor((640, 640), dtype="float32"),
w1: R.Tensor((640, 640), dtype="float32"),
w2: R.Tensor((640, 640), dtype="float32"),
w3: R.Tensor((640, 640), dtype="float32"),
):
with R.dataflow():
lv: R.Tensor((1280, 640), dtype="float32") = R.concat((w0, w1), axis=0)
lv1: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv, axes=None)
lv2: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(x1, lv1, out_dtype="void")
lv3: R.Tuple(
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
) = R.split(lv2, indices_or_sections=[640], axis=-1)
lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv3[0]
lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3[1]
lv_1: R.Tensor((1280, 640), dtype="float32") = R.concat((w2, w3), axis=0)
lv1_2: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv_1, axes=None)
lv2_1: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(
x2, lv1_2, out_dtype="void"
)
lv3_1: R.Tuple(
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
) = R.split(lv2_1, indices_or_sections=[640], axis=-1)
lv2_1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3_1[0]
lv3_1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3_1[1]
out: R.Tuple(
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
) = (lv0, lv1_1, lv2_1_1, lv3_1_1)
R.output(out)
return out
with PatternContext() as ctx:
inp_pat = wildcard()
w1_pat = wildcard()
w2_pat = wildcard()
matmul1 = is_op("relax.matmul")(inp_pat, is_op("relax.permute_dims")(w1_pat))
matmul2 = is_op("relax.matmul")(inp_pat, is_op("relax.permute_dims")(w2_pat))
def rewriter(matchings, _):
inp = matchings[inp_pat]
w1 = matchings[w1_pat]
w2 = matchings[w2_pat]
concat = R.concat([w1, w2], axis=0)
matmul = R.matmul(inp, R.permute_dims(concat))
sections = [w1.struct_info.shape[0]]
chunks = R.split(matmul, sections, -1)
return {
matchings[matmul1]: chunks[0],
matchings[matmul2]: chunks[1],
}
rewritten = rewrite_bindings(ctx, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, expected)
# make sure it builds
mod = tvm.IRModule()
mod["main"] = rewritten
print(mod)
tvm.compile(mod, target="llvm")
def test_commutative_pattern_match():
@R.function(private=True)
def before(
x: R.Tensor((1024,)),
):
with R.dataflow():
y = R.add(x, x)
out = R.add(R.const(1.0), y)
R.output(out)
return out
@R.function(private=True)
def expected(
x: R.Tensor((1024,)),
):
with R.dataflow():
y = R.add(x, x)
out = R.add(y, R.const(2.0))
R.output(out)
return out
pattern_add = is_op("relax.add")
pattern_mul = is_op("relax.multiply")
pattern_op = pattern_add | pattern_mul
pattern_arg = wildcard()
pattern_const = is_const()
pattern = pattern_op(pattern_arg, pattern_const)
def rewriter(expr, matches):
op = matches[pattern_op]
arg = matches[pattern_arg]
const = matches[pattern_const].data.numpy()
if const.shape == tuple() and const[()] == 1.0:
return rx.Call(op, [arg, rx.const(2.0)])
else:
return expr
after = rewrite_call(pattern, rewriter, before)
tvm.ir.assert_structural_equal(after, expected)
def test_repeated_pattern_match():
"""rewrite_call should iterate until convergence"""
@R.function(private=True)
def before(
x: R.Tensor((1024,)),
y: R.Tensor((1024,)),
z: R.Tensor((1024,)),
):
with R.dataflow():
a = R.add(x, y)
b = R.add(a, z)
out = R.multiply(b, R.const(5.0))
R.output(out)
return out
@R.function(private=True)
def expected(
x: R.Tensor((1024,)),
y: R.Tensor((1024,)),
z: R.Tensor((1024,)),
):
with R.dataflow():
x = R.multiply(x, R.const(5.0))
y = R.multiply(y, R.const(5.0))
a = R.add(x, y)
z = R.multiply(z, R.const(5.0))
b = R.add(a, z)
R.output(b)
return b
pattern_add_lhs = wildcard()
pattern_add_rhs = wildcard()
pattern_add = is_op("relax.add")(pattern_add_lhs, pattern_add_rhs)
mul_const = is_const()
pattern_mul = is_op("relax.multiply")(pattern_add, mul_const)
pattern = pattern_mul
def rewriter(_expr, matches):
const = matches[mul_const]
return (matches[pattern_add_lhs] * const) + (matches[pattern_add_rhs] * const)
after = rewrite_call(pattern, rewriter, before)
tvm.ir.assert_structural_equal(after, expected)
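# Relax expressions support Python operator sugar, so the rewriter above can build
# the replacement with `*` and `+`, which correspond to the R.multiply and R.add
# calls seen in `expected`.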
bind_to_dataflow_var = tvm.testing.parameter(
by_dict={"var-to-var": False, "var-to-dataflow-var": True}
)
def test_rewrite_without_trivial_binding(bind_to_dataflow_var):
"""rewrite_call should avoid producing trivial "y = x" bindings
This may not be possible in all cases, and follows the same
    rules as CanonicalizeBindings. For example, a `relax.Var` that is
    bound to a `relax.DataflowVar` may not be removed, to ensure
that the `relax.DataflowVar` is only used within a
`DataflowBlock`.
"""
if bind_to_dataflow_var:
@R.function(private=True)
def before(x: R.Tensor((1024,))):
with R.dataflow():
a = R.add(x, x)
b = R.reshape(a, (1024,))
R.output(b)
return b
@R.function(private=True)
def expected(x: R.Tensor((1024,))):
with R.dataflow():
b = R.add(x, x)
R.output(b)
return b
else:
@R.function(private=True)
def before(x: R.Tensor((1024,))):
a = R.add(x, x)
b = R.reshape(a, (1024,))
return b
@R.function(private=True)
def expected(x: R.Tensor((1024,))):
a = R.add(x, x)
return a
pattern_arg = wildcard()
pattern_shape_expr = wildcard()
pattern = is_op("relax.reshape")(pattern_arg, pattern_shape_expr)
def rewriter(expr, matches):
arg = matches[pattern_arg]
shape_expr = matches[pattern_shape_expr]
if tvm.ir.structural_equal(arg.struct_info.shape, shape_expr):
return arg
else:
return expr
after = rewrite_call(pattern, rewriter, before)
tvm.ir.assert_structural_equal(after, expected)
same_shape_func_type = tvm.testing.parameter(
"same_static_shape",
"same_dynamic_shape",
"different_static_shape",
"different_dynamic_shape",
)
def test_same_shape_pattern(same_shape_func_type):
if same_shape_func_type == "same_static_shape":
@R.function(private=True)
def func(
a: R.Tensor((1024, 128), "float32"),
b: R.Tensor((1024, 128), "float32"),
) -> R.Tensor:
with R.dataflow():
c = R.multiply(a, R.const(2.0))
d = R.add(b, c)
out = d
R.output(out)
return out
elif same_shape_func_type == "same_dynamic_shape":
@R.function(private=True)
def func(
a: R.Tensor(("n", 128), "float32"),
b: R.Tensor(("n", 128), "float32"),
) -> R.Tensor:
with R.dataflow():
c = R.multiply(a, R.const(2.0))
d = R.add(b, c)
out = d
R.output(out)
return out
elif same_shape_func_type == "different_static_shape":
@R.function(private=True)
def func(
a: R.Tensor((1024, 128), "float32"),
b: R.Tensor((1, 128), "float32"),
) -> R.Tensor:
with R.dataflow():
c = R.multiply(a, R.const(2.0))
d = R.add(b, c)
out = d
R.output(out)
return out
elif same_shape_func_type == "different_dynamic_shape":
@R.function(private=True)
def func(
a: R.Tensor(("n", 128), "float32"),
b: R.Tensor(("m", 128), "float32"),
) -> R.Tensor:
with R.dataflow():
c = R.multiply(a, R.const(2.0))
d = R.add(b, c)
out = d
R.output(out)
return out
else:
raise ValueError(f"Unknown value of same_shape_func_type={same_shape_func_type}")
with PatternContext() as ctx:
pat_lhs = wildcard()
pat_rhs = wildcard()
pat_sum = is_op("relax.add")(pat_lhs, pat_rhs)
pat_lhs.same_shape_as(pat_rhs)
block = func.body.blocks[0]
match = ctx.match_dfb(block)
if "same" in same_shape_func_type:
assert match
else:
assert match is None
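# The next test folds R.strided_slice into R.split when the slice covers an even
# division of the axis. For example, with size=1024, begin=0, end=512, and stride 1,
# gcd(0, 512, 1024) = 512 and (end - begin) // gcd == 1, so the slice becomes
# R.split(x, 1024 // 512)[0 // 512], i.e. R.split(x, 2)[0].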
def test_iterative_rewrite_without_trivial_binding():
"""Avoid introducing common sub-expressions
Pattern replacement may produce the same intermediate, which
should appear only once in the final result.
"""
@R.function(private=True)
def before(x: R.Tensor((1024,))):
with R.dataflow():
a = R.strided_slice(x, [0], [0], [512], [1])
b = R.strided_slice(x, [0], [512], [1024], [1])
c = R.add(a, b)
R.output(c)
return c
@R.function(private=True)
def expected(x: R.Tensor((1024,))):
with R.dataflow():
x_split = R.split(x, 2)
a = x_split[0]
b = x_split[1]
c = R.add(a, b)
R.output(c)
return c
pattern_arg = wildcard()
pattern_axes = wildcard()
pattern_begin = wildcard()
pattern_end = wildcard()
pattern_strides = wildcard()
pattern = is_op("relax.strided_slice")(
pattern_arg, pattern_axes, pattern_begin, pattern_end, pattern_strides
)
def rewriter(expr, matches):
arg = matches[pattern_arg]
axes = matches[pattern_axes]
begin = matches[pattern_begin]
end = matches[pattern_end]
strides = matches[pattern_strides]
strided_slice = matches[pattern]
if arg.struct_info.shape is None:
return expr
if len(axes) != 1:
return expr
axis = axes[0].value
begin = begin[0].value
end = end[0].value
stride = strides[0].value
if stride != 1:
return expr
size = arg.struct_info.shape[0]
if (
isinstance(size, tir.IntImm)
and isinstance(begin, tir.IntImm)
and isinstance(end, tir.IntImm)
):
size = size.value
begin = begin.value
end = end.value
else:
return expr
gcd = functools.reduce(math.gcd, [begin, end, size])
if (end - begin) // gcd == 1:
return rx.op.split(arg, size // gcd)[begin // gcd]
return expr
after = rewrite_call(pattern, rewriter, before)
tvm.ir.assert_structural_equal(after, expected)
def test_iterative_rewrite_with_removed_intermediates():
"""Pattern replacement may require canonicalization
A pattern may replace a tuple returned by a function with a tuple
whose contents are known by Relax. In that case, canonicalization
is required to unwrap the TupleGetItem instances into the known
contents.
This test case shows the intermediate results produced in the
process of pattern-matching.
"""
@R.function(private=True)
def before(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
with R.dataflow():
c = R.concat([a, b])
d = R.split(c, 2)
e = d[0]
f = d[1]
g = R.add(a, e)
h = R.add(f, g)
R.output(h)
return h
    # First pattern rewrite. The concat/split pair can be unwrapped, so
# `d` is rewritten from `R.split(c, 2)` into `(a, b)`.
#
# @R.function(private=True)
# def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
# with R.dataflow():
# c = R.concat([a, b])
# d = (a,b)
# e = d[0]
# f = d[1]
# g = R.add(a, e)
# h = R.add(f, g)
# R.output(h)
# Canonicalization step. Because `d` is known to be `(a,b)`,
# canonicalization can rewrite `d[0]` into `a` and `d[1]` into
# `b`.
#
# @R.function(private=True)
# def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
# with R.dataflow():
# c = R.concat([a, b])
# d = (a,b)
# e = a
# f = b
# g = R.add(a, a)
# h = R.add(b, g)
# R.output(h)
# Dead-code-elimination step. This technically isn't required
# until the pattern matching has converged, but performing it now
# prevents testing for matches on dead code.
#
# @R.function(private=True)
# def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
# with R.dataflow():
# g = R.add(a, a)
# h = R.add(b, g)
# R.output(h)
# Second pattern-matching step. Now, the `R.add(a,a)` can match
# the other option in our pattern, and be rewritten as
# `R.multiply(a,R.const(2))`.
#
# @R.function(private=True)
# def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
# with R.dataflow():
# g = R.multiply(a, R.const(2))
# h = R.add(b, g)
# R.output(h)
# Canonicalization and dead-code-elimination are applied again,
# but have no effect this time.
@R.function(private=True)
def expected(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
with R.dataflow():
g = R.multiply(a, R.const(2))
h = R.add(b, g)
R.output(h)
return h
pat_args = wildcard()
op_concat = is_op("relax.concat")
pat_concat = op_concat(pat_args).has_attr({"axis": 0})
op_split = is_op("relax.split")
pat_split = op_split(pat_concat).has_attr({"axis": 0, "indices_or_sections": T.int64(2)})
pat_unwrap_concat_split = pat_split
pat_arg = wildcard()
op_add = is_op("relax.add")
pat_add_self = op_add(pat_arg, pat_arg)
pattern = pat_unwrap_concat_split | pat_add_self
def rewriter(expr, matches):
if pat_unwrap_concat_split in matches:
args = matches[pat_args]
if len(args) == 2 and tvm.ir.structural_equal(args[0].struct_info, args[1].struct_info):
return args
elif pat_add_self in matches:
arg = matches[pat_arg]
return arg * rx.const(2)
return expr
after = rewrite_call(pattern, rewriter, before)
tvm.ir.assert_structural_equal(expected, after)
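# The remaining wildcard tests use wildcard().has_struct_info(sinfo), which only
# matches expressions whose StructInfo is compatible with `sinfo`; a tensor pattern
# with an unspecified dtype accepts any dtype, and symbolic shape variables in the
# pattern must bind consistently across operands.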
def test_wildcard_with_struct_info_updates_when_matching():
"""A DFPattern may be restricted to a specific StructInfo"""
pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
pat = is_op("relax.add")(pat_lhs, pat_rhs)
def rewriter(expr, matches):
lhs = matches[pat_lhs]
rhs = matches[pat_rhs]
return rx.op.multiply(lhs, rhs)
@R.function(private=True)
def before():
with R.dataflow():
A = R.zeros([2, 3], "int32")
B = R.ones([2, 3], "int32")
C = R.add(A, B)
R.output(C)
return C
@R.function(private=True)
def expected():
with R.dataflow():
A = R.zeros([2, 3], "int32")
B = R.ones([2, 3], "int32")
C = R.multiply(A, B)
R.output(C)
return C
after = rewrite_call(pat, rewriter, before)
tvm.ir.assert_structural_equal(expected, after)
def test_wildcard_with_struct_info_is_no_op_when_not_matching():
"""StructInfoPattern requires the StructInfo provided
    Here, the pattern would match, except that the function has
`R.Tensor([16,32])`, and the pattern requires `R.Tensor([2,3])`.
"""
pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
pat = is_op("relax.add")(pat_lhs, pat_rhs)
def rewriter(expr, matches):
lhs = matches[pat_lhs]
rhs = matches[pat_rhs]
return rx.op.multiply(lhs, rhs)
@R.function(private=True)
def before():
with R.dataflow():
            # This R.add does not have the shape required by the
            # pattern, so it will not be updated.
A = R.zeros([16, 32], "int32")
B = R.ones([16, 32], "int32")
C = R.add(A, B)
R.output(C)
return C
expected = before
after = rewrite_call(pat, rewriter, before)
tvm.ir.assert_structural_equal(expected, after)
def test_wildcard_struct_info_for_unknown_dtype():
"""TensorStructInfo with unknown dtype allows any dtype"""
pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
pat = is_op("relax.add")(pat_lhs, pat_rhs)
def rewriter(expr, matches):
lhs = matches[pat_lhs]
rhs = matches[pat_rhs]
return rx.op.multiply(lhs, rhs)
@R.function(private=True)
def before():
with R.dataflow():
A = R.zeros([2, 3], "int32")
B = R.ones([2, 3], "int32")
C = R.add(A, B)
D = R.zeros([2, 3], "float32")
E = R.ones([2, 3], "float32")
F = R.add(D, E)
output = (C, F)
R.output(output)
return output
@R.function(private=True)
def expected():
with R.dataflow():
A = R.zeros([2, 3], "int32")
B = R.ones([2, 3], "int32")
C = R.multiply(A, B)
D = R.zeros([2, 3], "float32")
E = R.ones([2, 3], "float32")
F = R.multiply(D, E)
output = (C, F)
R.output(output)
return output
after = rewrite_call(pat, rewriter, before)
tvm.ir.assert_structural_equal(expected, after)
def test_wildcard_struct_info_with_symbolic_vars():
"""StructInfoPattern may define symbolic vars
This test finds an elementwise `R.add`, while ignoring a
broadcasted `R.add`.
"""
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
pat_lhs = wildcard().has_struct_info(R.Tensor([m, n]))
pat_rhs = wildcard().has_struct_info(R.Tensor([m, n]))
pat = is_op("relax.add")(pat_lhs, pat_rhs)
def rewriter(expr, matches):
lhs = matches[pat_lhs]
rhs = matches[pat_rhs]
return rx.op.multiply(lhs, rhs)
@R.function(private=True)
def before():
with R.dataflow():
A = R.zeros([64, 128], "int32")
B = R.ones([64, 128], "int32")
C = R.add(A, B)
D = R.zeros([64, 128], "float32")
E = R.ones([1, 128], "float32")
F = R.add(D, E)
output = (C, F)
R.output(output)
return output
@R.function(private=True)
def expected():
with R.dataflow():
A = R.zeros([64, 128], "int32")
B = R.ones([64, 128], "int32")
C = R.multiply(A, B)
D = R.zeros([64, 128], "float32")
E = R.ones([1, 128], "float32")
F = R.add(D, E)
output = (C, F)
R.output(output)
return output
after = rewrite_call(pat, rewriter, before)
tvm.ir.assert_structural_equal(expected, after)
def test_backtrack_if_rewriter_returns_no_op():
"""Rewriter participates in the pattern matching
Sometimes, the pattern-matching syntax is insufficient to check if
a replacement may be performed. In this case, the `rewriter`
function may perform additional validation. If this validation
fails, the `rewriter` function can return the original expression,
and no replacement is performed.
In addition, when the `rewriter` returns the original expression,
the pattern match should backtrack to determine if another branch
of the match may have produced a replacement.
This functionality allows pattern replacements to be composed.
"""
pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())
pat_arg = wildcard()
pat_zeros = is_op("relax.zeros")(wildcard())
pat_add = is_op("relax.add")(pat_arg, pat_zeros)
# OR conditions are checked in the order that they occur. Because
# `pat_match_no_rewrite` is a superset of `pat_add`, it will
# always match first.
pat = pat_match_no_rewrite | pat_add
def rewriter(expr, matches):
if pat_match_no_rewrite in matches:
# This branch simulates a rewrite whose precondition has
# failed. If the pattern-matching treats this as a
            # successful match with no replacement required, then no
# rewrite would be performed. On the other hand, if the
# pattern-matching treats this as an unsuccessful match,
# then it can backtrack and attempt `pat_add` instead.
return expr
elif pat_add in matches:
return matches[pat_arg]
else:
raise RuntimeError("Pattern matched, but neither branch matched")
@R.function(private=True)
def before():
with R.dataflow():
A = R.ones([64, 128], "int32")
B = R.zeros([64, 128], "int32")
C = R.add(A, B)
R.output(C)
return C
@R.function(private=True)
def expected():
with R.dataflow():
C = R.ones([64, 128], "int32")
R.output(C)
return C
after = rewrite_call(pat, rewriter, before)
tvm.ir.assert_structural_equal(expected, after)
def test_backtrack_for_no_op_rewriter_does_not_match_on_var():
"""The matches should always contain the bound value
This is a regression test. In versions from
https://github.com/apache/tvm/pull/16732 to
https://github.com/apache/tvm/pull/16828, the `rewrite_call`
function could erroneously call the rewriter with `expr` and
`matches[pat]` set to a variable (`C`) instead of the value to
which it is bound (`R.add(A,B)`).
"""
pat_a = is_op("relax.add")(wildcard(), wildcard())
pat_b = is_op("relax.add")(wildcard(), wildcard())
pat = pat_a | pat_b
def rewriter(expr, matches):
assert isinstance(matches[pat], rx.Call)
return expr
@R.function(private=True)
def before():
with R.dataflow():
A = R.ones([64, 128], "int32")
B = R.zeros([64, 128], "int32")
C = R.add(A, B)
R.output(C)
return C
expected = before
after = rewrite_call(pat, rewriter, before)
tvm.ir.assert_structural_equal(expected, after)
if __name__ == "__main__":
tvm.testing.main()