# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
import tvm.testing
from tvm import relax, topi
from tvm.script import ir as I, relax as R, tir as T
def _check(mod_before, mod_expected):
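    """Run FuseTIR on `mod_before` and assert structural equality with `mod_expected`."""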
mod_after = relax.transform.FuseTIR()(mod_before)
tvm.ir.assert_structural_equal(mod_expected, mod_after)
def test_simple():
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}, private=True):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
lv1 = bb.emit_te(topi.exp, lv0)
gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
bb.emit_func_output(gv)
fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze")
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
with bb.function("main", [x, p0]):
with bb.dataflow():
gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0]))
bb.emit_func_output(gv)
return bb.get().with_attrs({"foo": "bar"})
def expected():
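        # Reference TE computation describing the single fused PrimFunc that
        # FuseTIR is expected to produce from the grouped Relax function.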
def fused_add_exp_squeeze(x, p0):
add = topi.add(x, p0)
exp = topi.exp(add)
squeeze = topi.squeeze(exp)
return squeeze
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
with bb.function("main", [x, p0]):
with bb.dataflow():
gv = bb.emit_output(bb.call_te(fused_add_exp_squeeze, x, p0))
bb.emit_func_output(gv)
return bb.get().with_attrs({"foo": "bar"})
_check(before(), expected())
def test_conv2d_fuse():
def before(dtype):
bb = relax.BlockBuilder()
# Grouped function 1
x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype))
p0 = relax.Var("p0", R.Tensor((), dtype))
with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": True}):
with bb.dataflow():
lv0 = bb.emit_te(
topi.nn.conv2d,
x,
w,
strides=1,
padding=1,
dilation=1,
primfunc_name_hint="conv2d",
)
lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1")
gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2"))
bb.emit_func_output(gv)
# Grouped function 2
x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype))
y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype))
with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": True}):
with bb.dataflow():
lv0 = bb.emit_te(
topi.nn.conv2d,
x,
w,
strides=1,
padding=0,
dilation=1,
primfunc_name_hint="conv2d1",
)
gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2"))
bb.emit_func_output(gv)
# Get the global variables of the grouped functions
mod = bb.get()
fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2")
fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2")
# Main function
x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
with bb.function("main", [x, w1, w2, w3]):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)]))
lv2 = bb.emit_te(
topi.nn.conv2d,
lv1,
w3,
strides=1,
padding=1,
dilation=1,
)
gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2]))
bb.emit_func_output(gv)
return bb.get()
def expected(dtype):
def fused_conv2d_add1_add2(x, w, p):
conv = topi.nn.conv2d(x, w, strides=1, padding=1, dilation=1)
add = topi.add(p, conv)
return topi.add(conv, add)
def fused_conv2d1_add2(x, w, p):
conv = topi.nn.conv2d(x, w, strides=1, padding=0, dilation=1)
return topi.add(conv, p)
bb = relax.BlockBuilder()
# Main function
x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
with bb.function("main", [x, w1, w2, w3]):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
lv1 = bb.emit_te(fused_conv2d_add1_add2, lv0, w1, relax.const(1, dtype))
lv2 = bb.emit_te(
topi.nn.conv2d,
lv1,
w3,
strides=1,
padding=1,
dilation=1,
)
gv = bb.emit_output(bb.call_te(fused_conv2d1_add2, lv1, w2, lv2))
bb.emit_func_output(gv)
return bb.get()
_check(before("float32"), expected("float32"))
def test_two_subfunction():
def before():
bb = relax.BlockBuilder()
x1 = relax.Var("x1", R.Tensor([10, 20], "float32"))
with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}):
with bb.dataflow():
lv1 = bb.emit_te(topi.exp, x1)
gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
bb.emit_func_output(gv)
mod = bb.get()
func_gv = mod.get_global_var("fused_exp_squeeze")
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit(relax.Call(func_gv, [x]))
lv2 = bb.emit(relax.Call(func_gv, [lv]))
gv = bb.emit_output(lv2)
bb.emit_func_output(gv)
return bb.get()
def expected():
def fused_exp_squeeze(x):
exp = topi.exp(x)
squeeze = topi.squeeze(exp)
return squeeze
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit_te(fused_exp_squeeze, x)
lv2 = bb.call_te(fused_exp_squeeze, lv)
gv = bb.emit_output(lv2)
bb.emit_func_output(gv)
return bb.get()
_check(before(), expected())
def test_fuse_same_primfunc():
def before():
bb = relax.BlockBuilder()
x1 = relax.Var("x1", R.Tensor([10, 20], "float32"))
with bb.function("fused_exp_exp_squeeze", [x1], attrs={"Primitive": True}):
with bb.dataflow():
lv1 = bb.emit_te(topi.exp, x1)
lv2 = bb.emit_te(topi.exp, lv1)
gv = bb.emit_output(bb.call_te(topi.squeeze, lv2))
bb.emit_func_output(gv)
mod = bb.get()
func_gv = mod.get_global_var("fused_exp_exp_squeeze")
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit(relax.Call(func_gv, [x]))
gv = bb.emit_output(lv)
bb.emit_func_output(gv)
return bb.get()
def expected():
def fused_exp_exp_squeeze(x):
exp = topi.exp(x)
exp = topi.exp(exp)
squeeze = topi.squeeze(exp)
return squeeze
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.call_te(fused_exp_exp_squeeze, x)
gv = bb.emit_output(lv)
bb.emit_func_output(gv)
return bb.get()
_check(before(), expected())
def test_fuse_with_tuple_as_param():
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
with bb.function("fused_exp_add", [x], attrs={"Primitive": True}, private=True):
with bb.dataflow():
lv0 = bb.emit(relax.TupleGetItem(x, 0))
lv1 = bb.emit(relax.TupleGetItem(x, 1))
lv2 = bb.emit_te(topi.exp, lv0)
gv = bb.emit_output(bb.call_te(topi.add, lv2, lv1))
bb.emit_func_output(gv)
mod = bb.get()
func_gv = mod.get_global_var("fused_exp_add")
x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
with bb.function("main", [x]):
with bb.dataflow():
gv = bb.emit_output(relax.Call(func_gv, [x]))
bb.emit_func_output(gv)
return bb.get()
def expected():
def fused_exp_add(x1, x2):
exp = topi.exp(x1)
return topi.add(exp, x2)
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
with bb.function("main", [x]):
with bb.dataflow():
lv0 = bb.emit(relax.TupleGetItem(x, 0))
lv1 = bb.emit(relax.TupleGetItem(x, 1))
gv = bb.emit_output(bb.call_te(fused_exp_add, lv0, lv1))
bb.emit_func_output(gv)
return bb.get()
_check(before(), expected())
def test_fuse_with_nested_tuple_as_param():
tuple_struct_info = R.Tuple(
[R.Tensor([10], "float32"), R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])]
)
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", tuple_struct_info)
with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}, private=True):
with bb.dataflow():
lv0 = bb.emit(relax.TupleGetItem(x, 0))
lv0_exp = bb.emit_te(topi.exp, lv0)
lv1 = bb.emit(relax.TupleGetItem(x, 1))
lv1_0 = bb.emit(relax.TupleGetItem(lv1, 0))
lv1_1 = bb.emit(relax.TupleGetItem(lv1, 1))
lv2 = bb.emit_te(topi.add, lv1_0, lv1_1)
gv = bb.emit_output(bb.call_te(topi.add, lv0_exp, lv2))
bb.emit_func_output(gv)
mod = bb.get()
func_gv = mod.get_global_var("fused_exp_add_add")
x = relax.Var("x", tuple_struct_info)
with bb.function("main", [x]):
with bb.dataflow():
gv = bb.emit_output(relax.Call(func_gv, [x]))
bb.emit_func_output(gv)
return bb.get()
def expected():
def fused_exp_add_add(x1, x2, x3):
exp = topi.exp(x1)
add = topi.add(x2, x3)
return topi.add(exp, add)
bb = relax.BlockBuilder()
x = relax.Var("x", tuple_struct_info)
with bb.function("main", [x]):
with bb.dataflow():
lv0 = bb.emit(relax.TupleGetItem(x, 0))
lv1 = bb.emit(relax.TupleGetItem(x, 1))
lv2 = bb.emit(relax.TupleGetItem(lv1, 0))
lv3 = bb.emit(relax.TupleGetItem(lv1, 1))
gv = bb.emit_output(bb.call_te(fused_exp_add_add, lv0, lv2, lv3))
bb.emit_func_output(gv)
return bb.get()
_check(before(), expected())
def test_fuse_with_call_tir_in_main():
def before():
bb = relax.BlockBuilder()
x1 = relax.Var("x1", R.Tensor([10, 20], "float32"))
with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}):
with bb.dataflow():
lv = bb.emit_te(topi.exp, x1)
gv = bb.emit_output(bb.call_te(topi.squeeze, lv))
bb.emit_func_output(gv)
mod = bb.get()
func_gv = mod.get_global_var("fused_exp_squeeze")
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv0 = bb.emit(relax.Call(func_gv, [x]))
lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32"))
gv = bb.emit_output(lv1)
bb.emit_func_output(gv)
return bb.get()
def expected():
def fused_exp_squeeze(x):
exp = topi.exp(x)
squeeze = topi.squeeze(exp)
return squeeze
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit_te(fused_exp_squeeze, x)
lv2 = bb.call_te(topi.add, lv, relax.const(1, "float32"))
gv = bb.emit_output(lv2)
bb.emit_func_output(gv)
return bb.get()
_check(before(), expected())
def test_fuse_with_const_in_argument():
def before():
bb = relax.BlockBuilder()
x1 = relax.Var("x1", R.Tensor([10, 20], "float32"))
x2 = relax.Var("x2", R.Tensor([], "float32"))
with bb.function("fused_add_exp_squeeze", [x1, x2], attrs={"Primitive": True}):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x1, x2)
lv1 = bb.emit_te(topi.exp, lv0)
gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
bb.emit_func_output(gv)
mod = bb.get()
func_gv = mod.get_global_var("fused_add_exp_squeeze")
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit(relax.Call(func_gv, [x, relax.const(1, "float32")]))
gv = bb.emit_output(lv)
bb.emit_func_output(gv)
return bb.get()
def expected():
def fused_add_exp_squeeze(x, y):
add = topi.add(x, y)
exp = topi.exp(add)
squeeze = topi.squeeze(exp)
return squeeze
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.call_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
gv = bb.emit_output(lv)
bb.emit_func_output(gv)
return bb.get()
_check(before(), expected())
def test_fuse_tuple_output():
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
with bb.function("fused_add_exp", [x, p0], attrs={"Primitive": True}):
with bb.dataflow():
gv0 = bb.emit_output(bb.call_te(topi.add, x, p0))
gv1 = bb.emit_output(bb.call_te(topi.exp, gv0))
bb.emit_func_output(relax.Tuple([gv0, gv1]))
fused_add_exp = bb.get().get_global_var("fused_add_exp")
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
with bb.function("main", [x, p0]):
with bb.dataflow():
gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0]))
bb.emit_func_output(gv)
return bb.get()
def expected():
def fused_add_exp(x, p0):
add = topi.add(x, p0)
exp = topi.exp(add)
return add, exp
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
with bb.function("main", [x, p0]):
with bb.dataflow():
gv = bb.emit_output(bb.call_te(fused_add_exp, x, p0))
bb.emit_func_output(gv)
return bb.get()
_check(before(), expected())
def test_fuse_with_immediate_tuple():
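    """Tuples constructed and immediately indexed inside the grouped function
    should be folded away, leaving a single fused add."""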
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
y = relax.Var("y", R.Tensor([10, 20], "float32"))
with bb.function("fused_add", [x, y], attrs={"Primitive": True}):
with bb.dataflow():
lv_tuple = bb.emit(relax.Tuple([x, relax.Tuple([x, y])]))
lv_x = bb.emit(relax.TupleGetItem(lv_tuple, 0))
lv0 = bb.emit(relax.TupleGetItem(lv_tuple, 1))
lv_y = bb.emit(relax.TupleGetItem(lv0, 1))
gv = bb.emit_output(bb.call_te(topi.add, lv_x, lv_y))
bb.emit_func_output(gv)
fused_add = bb.get().get_global_var("fused_add")
x = relax.Var("x", R.Tensor([10, 20], "float32"))
y = relax.Var("y", R.Tensor([10, 20], "float32"))
with bb.function("main", [x, y]):
with bb.dataflow():
gv = bb.emit_output(relax.Call(fused_add, [x, y]))
bb.emit_func_output(gv)
return bb.get()
def expected():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
y = relax.Var("y", R.Tensor([10, 20], "float32"))
with bb.function("main", [x, y]):
with bb.dataflow():
gv = bb.emit_output(bb.call_te(topi.add, x, y, primfunc_name_hint="fused_add"))
bb.emit_func_output(gv)
return bb.get()
_check(before(), expected())
def test_fuse_return_partial_result():
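    """Only part of a multi-output TE computation is consumed; the fused
    PrimFunc should return just the output the grouped function actually uses."""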
def te_argmax_idx_val(val):
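        """TE argmax along the second axis, returning both the index and the value."""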
from tvm import te
def f_combine(x, y):
lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType):
return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1)
argmax = te.comm_reducer(f_combine, f_identity, name="argmax")
m, n = val.shape
k = te.reduce_axis((0, n), "k")
max_idx, max_val = te.compute(
(m,), lambda i: argmax((k.var, val[i, k]), axis=k), name="argmax"
)
return max_idx, max_val
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
offset = relax.Var("offset", R.Tensor([10], "int32"))
with bb.function("fused_argmax_add", [x, offset], attrs={"Primitive": True}):
with bb.dataflow():
lv = bb.emit_te(te_argmax_idx_val, x)
idx = bb.emit(relax.TupleGetItem(lv, 0))
gv = bb.emit_output(bb.call_te(topi.add, idx, offset))
bb.emit_func_output(gv)
mod = bb.get()
func_gv = mod.get_global_var("fused_argmax_add")
x = relax.Var("x", R.Tensor([10, 20], "float32"))
offset = relax.Var("x", R.Tensor([10], "int32"))
with bb.function("main", [x, offset]):
with bb.dataflow():
gv = bb.emit_output(relax.Call(func_gv, [x, offset]))
bb.emit_func_output(gv)
return bb.get()
def expected():
def fused_argmax_add(x, offset):
idx, value = te_argmax_idx_val(x)
idx = topi.add(idx, offset)
return idx
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
offset = relax.Var("offset", R.Tensor([10], "int32"))
with bb.function("main", [x, offset]):
with bb.dataflow():
gv = bb.emit_output(bb.call_te(fused_argmax_add, x, offset))
bb.emit_func_output(gv)
return bb.get()
_check(before(), expected())
def test_multiple_relax_functions():
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor((), "float32"))
with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}, private=True):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
lv1 = bb.emit_te(topi.exp, lv0)
gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
bb.emit_func_output(gv)
fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze")
x = relax.Var("x", R.Tensor([20, 10], "float32"))
p0 = relax.Var("p0", R.Tensor((), "float32"))
with bb.function(
"fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": True}, private=True
):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
lv1 = bb.emit_te(topi.exp, lv0)
gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
bb.emit_func_output(gv)
fused_add1_exp1_squeeze1 = bb.get().get_global_var("fused_add1_exp1_squeeze1")
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("func1", [x]):
with bb.dataflow():
gv = bb.emit_output(
relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")])
)
bb.emit_func_output(gv)
x = relax.Var("x", R.Tensor([20, 10], "float32"))
with bb.function("func2", [x]):
with bb.dataflow():
gv = bb.emit_output(
relax.Call(fused_add1_exp1_squeeze1, [x, relax.const(1, "float32")])
)
bb.emit_func_output(gv)
return bb.get()
@I.ir_module
class Expected:
@R.function
def func1(x: R.Tensor((10, 20), dtype="float32")) -> R.Tensor((10, 20), dtype="float32"):
with R.dataflow():
gv2 = R.call_tir(
Expected.fused_add_exp_squeeze,
(x, R.const(1, "float32")),
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
R.output(gv2)
return gv2
@R.function
def func2(x: R.Tensor((20, 10), dtype="float32")) -> R.Tensor((20, 10), dtype="float32"):
with R.dataflow():
gv3 = R.call_tir(
Expected.fused_add1_exp1_squeeze1,
(x, R.const(1, "float32")),
out_sinfo=R.Tensor((20, 10), dtype="float32"),
)
R.output(gv3)
return gv3
@T.prim_func(private=True)
def fused_add1_exp1_squeeze1(
x: T.Buffer((T.int64(20), T.int64(10)), "float32"),
p0: T.Buffer((), "float32"),
T_squeeze: T.Buffer((T.int64(20), T.int64(10)), "float32"),
):
T.func_attr({"tir.noalias": True})
T_add = T.alloc_buffer((T.int64(20), T.int64(10)))
compute = T.alloc_buffer((T.int64(20), T.int64(10)))
for ax0, ax1 in T.grid(T.int64(20), T.int64(10)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(x[v_ax0, v_ax1], p0[()])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
for i0, i1 in T.grid(T.int64(20), T.int64(10)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(T_add[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.exp(T_add[v_i0, v_i1])
for ax0, ax1 in T.grid(T.int64(20), T.int64(10)):
with T.block("T_squeeze"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(compute[v_ax0, v_ax1])
T.writes(T_squeeze[v_ax0, v_ax1])
T_squeeze[v_ax0, v_ax1] = compute[v_ax0, v_ax1]
@T.prim_func(private=True)
def fused_add_exp_squeeze(
x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
p0: T.Buffer((), "float32"),
T_squeeze: T.Buffer((T.int64(10), T.int64(20)), "float32"),
):
T.func_attr({"tir.noalias": True})
T_add = T.alloc_buffer((T.int64(10), T.int64(20)))
compute = T.alloc_buffer((T.int64(10), T.int64(20)))
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(x[v_ax0, v_ax1], p0[()])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
for i0, i1 in T.grid(T.int64(10), T.int64(20)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(T_add[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.exp(T_add[v_i0, v_i1])
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_squeeze"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(compute[v_ax0, v_ax1])
T.writes(T_squeeze[v_ax0, v_ax1])
T_squeeze[v_ax0, v_ax1] = compute[v_ax0, v_ax1]
_check(before(), Expected)
def test_skip_call_dps_packed():
@I.ir_module
class Module:
@R.function
def main(x: R.Tensor((2, 3), "float32")):
with R.dataflow():
y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3), "float32"))
R.output(y)
return y
    # FuseTIR should leave the module unchanged.
_check(Module, Module)
def test_symbolic_shape_aware_fuse():
@I.ir_module
class Before:
@R.function
def fused_add_exp_squeeze(
x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
) -> R.Tensor(["n", "m"], dtype="float32"):
R.func_attr({"Primitive": True})
with R.dataflow():
lv0 = R.emit_te(topi.add, x, p0)
lv1 = R.emit_te(topi.exp, lv0)
gv = R.emit_te(topi.squeeze, lv1)
R.output(gv)
return gv
@R.function
def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused_add_exp_squeeze(x, R.const(1, "float32"))
R.output(gv)
return gv
def fused_add_exp_squeeze(x, p0):
return topi.squeeze(topi.exp(topi.add(x, p0)))
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
with R.dataflow():
gv = R.emit_te(fused_add_exp_squeeze, x, R.const(1, "float32"))
R.output(gv)
return gv
_check(Before, Expected)
def test_fuse_of_dynamic_kernel_with_var_params_and_static_args():
@I.ir_module
class Before:
@T.prim_func(private=True)
def dynamic_tir_kernel(a: T.handle, b: T.handle):
m = T.int64()
n = T.int64()
A = T.match_buffer(a, [m, n], "float32")
B = T.match_buffer(b, [m, n], "float32")
for iters in T.grid(m, n):
with T.block("compute"):
i, j = T.axis.remap("SS", iters)
B[i, j] = A[i, j] * i + j
@R.function(private=True)
def fused_function(x: R.Tensor([16, 32], "float32")) -> R.Tensor([16, 32], dtype="float32"):
R.func_attr({"Primitive": True})
cls = Before
with R.dataflow():
y = R.call_tir(cls.dynamic_tir_kernel, [x], out_sinfo=R.Tensor([16, 32], "float32"))
z = R.call_tir(cls.dynamic_tir_kernel, [y], out_sinfo=R.Tensor([16, 32], "float32"))
R.output(z)
return z
@R.function
def main(x: R.Tensor([16, 32], "float32")) -> R.Tensor([16, 32], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused_function(x)
R.output(gv)
return gv
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused_function(
X: T.Buffer([T.int64(16), T.int64(32)], "float32"),
Z: T.Buffer([T.int64(16), T.int64(32)], "float32"),
):
T.func_attr({"tir.noalias": True})
Y = T.alloc_buffer(X.shape, "float32")
for iters in T.grid(*X.shape):
with T.block("compute_Y"):
i, j = T.axis.remap("SS", iters)
Y[i, j] = X[i, j] * i + j
for iters in T.grid(*X.shape):
with T.block("compute_Z"):
i, j = T.axis.remap("SS", iters)
Z[i, j] = Y[i, j] * i + j
@R.function
def main(x: R.Tensor([16, 32], "float32")) -> R.Tensor([16, 32], dtype="float32"):
cls = Expected
with R.dataflow():
gv = R.call_tir(cls.fused_function, [x], out_sinfo=R.Tensor([16, 32], "float32"))
R.output(gv)
return gv
_check(Before, Expected)
def test_fuse_of_dynamic_kernel_with_expression_params_and_static_args():
"""Parameters and arguments do not need to match structurally
Here, the kernel requires arguments (m*n), and is provided
"""
@I.ir_module
class Before:
@T.prim_func(private=True)
def dynamic_tir_kernel(a: T.handle, b: T.handle, c: T.handle, d: T.handle):
m = T.int64()
n = T.int64()
A = T.match_buffer(a, [m * n], "float32")
B = T.match_buffer(b, [m], "float32")
C = T.match_buffer(c, [n], "float32")
D = T.match_buffer(d, [m * n], "float32")
for i, j in T.grid(m, n):
with T.block("compute"):
vi, vj = T.axis.remap("SS", [i, j])
D[vi * 32 + vj] = A[vi * 32 + vj] * B[vi] + C[vj]
@R.function(private=True)
def fused_function(
x: R.Tensor([16 * 32], "float32"),
B: R.Tensor([16], "float32"),
C: R.Tensor([32], "float32"),
):
R.func_attr({"Primitive": True})
cls = Before
with R.dataflow():
y = R.call_tir(
cls.dynamic_tir_kernel, [x, B, C], out_sinfo=R.Tensor([16 * 32], "float32")
)
z = R.call_tir(
cls.dynamic_tir_kernel, [y, B, C], out_sinfo=R.Tensor([16 * 32], "float32")
)
R.output(z)
return z
@R.function
def main(
x: R.Tensor([16 * 32], "float32"),
B: R.Tensor([16], "float32"),
C: R.Tensor([32], "float32"),
):
cls = Before
with R.dataflow():
gv = cls.fused_function(x, B, C)
R.output(gv)
return gv
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused_function(
X: T.Buffer(T.int64(512), "float32"),
B: T.Buffer(T.int64(16), "float32"),
C: T.Buffer(T.int64(32), "float32"),
Z: T.Buffer(T.int64(512), "float32"),
):
T.func_attr({"tir.noalias": True})
Y = T.alloc_buffer((T.int64(512),))
for i, j in T.grid(T.int64(16), T.int64(32)):
with T.block("compute"):
vi, vj = T.axis.remap("SS", [i, j])
Y[vi * 32 + vj] = X[vi * 32 + vj] * B[vi] + C[vj]
for i, j in T.grid(T.int64(16), T.int64(32)):
with T.block("compute_1"):
vi, vj = T.axis.remap("SS", [i, j])
Z[vi * 32 + vj] = Y[vi * 32 + vj] * B[vi] + C[vj]
@R.function
def main(
x: R.Tensor((512,), dtype="float32"),
B: R.Tensor((16,), dtype="float32"),
C: R.Tensor((32,), dtype="float32"),
) -> R.Tensor((512,), dtype="float32"):
cls = Expected
with R.dataflow():
gv = R.call_tir(
cls.fused_function, (x, B, C), out_sinfo=R.Tensor((512,), dtype="float32")
)
R.output(gv)
return gv
_check(Before, Expected)
def test_symbolic_shape_aware_fuse_with_allocation():
def te_mean(x, axis):
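        # Mean along `axis`, dividing by the fixed extent 4096 used in this test.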
return topi.divide(topi.sum(x, axis, keepdims=True), 4096)
@I.ir_module
class Before:
@R.function
def fused_mean_add_tir_sqrt_divide_multiply(
x: R.Tensor((1, "n", 4096), dtype="float32"),
y: R.Tensor((1, "n", 4096), dtype="float32"),
rms_norm_weight: R.Tensor((4096,), dtype="float32"),
) -> R.Tensor((1, "n", 4096), dtype="float32"):
R.func_attr({"Primitive": True})
with R.dataflow():
lv0 = R.emit_te(te_mean, x, axis=2)
lv1 = R.emit_te(topi.add, lv0, lv0)
lv2 = R.emit_te(topi.sqrt, lv1)
lv3 = R.emit_te(topi.divide, y, lv2)
gv = R.emit_te(topi.multiply, rms_norm_weight, lv3)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor((1, "n", 4096), dtype="float32"),
y: R.Tensor((1, "n", 4096), dtype="float32"),
rms_norm_weight: R.Tensor((4096,), dtype="float32"),
) -> R.Tensor((1, "n", 4096), dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused_mean_add_tir_sqrt_divide_multiply(x, y, rms_norm_weight)
R.output(gv)
return gv
def fused_mean_add_tir_sqrt_divide_multiply(x, y, rms_norm_weight):
lv0 = te_mean(x, axis=2)
lv1 = topi.add(lv0, lv0)
lv2 = topi.sqrt(lv1)
lv3 = topi.divide(y, lv2)
return topi.multiply(rms_norm_weight, lv3)
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor((1, "n", 4096), dtype="float32"),
y: R.Tensor((1, "n", 4096), dtype="float32"),
rms_norm_weight: R.Tensor((4096,), dtype="float32"),
) -> R.Tensor((1, "n", 4096), dtype="float32"):
with R.dataflow():
gv = R.emit_te(fused_mean_add_tir_sqrt_divide_multiply, x, y, rms_norm_weight)
R.output(gv)
return gv
_check(Before, Expected)
def test_symbolic_var_in_call_tir_args():
@I.ir_module
class Before:
@T.prim_func(private=True)
def foo(
X: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"),
Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"),
rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"),
m: T.int64,
):
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
with T.block("rotary"):
v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
rotary[v0, v1, v2, v3] = Y[m + v1 - 1, v3] * X[v0, v1, v2, v3]
@R.function
def fused(
x: R.Tensor((1, 1, 32, 128), dtype="float32"),
y: R.Tensor((2048, 128), dtype="float32"),
len: R.Shape(["m"]),
) -> R.Tensor((1, 1, 32, 128), dtype="float32"):
R.func_attr({"Primitive": True})
m = T.int64()
cls = Before
with R.dataflow():
lv1 = R.emit_te(topi.add, x, x)
gv = R.call_tir(
cls.foo,
[lv1, y],
out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float32"),
tir_vars=R.shape([m]),
)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor((1, 1, 32, 128), dtype="float32"),
y: R.Tensor((2048, 128), dtype="float32"),
len: R.Shape(["m"]),
) -> R.Tensor((1, 1, 32, 128), dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused(x, y, len)
R.output(gv)
return gv
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused(
X: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"),
Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"),
rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"),
m: T.int64,
):
T.func_attr({"tir.noalias": True})
T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)))
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
with T.block("T_add"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T_add[v_ax0, v_ax1, v_ax2, v_ax3] = (
X[v_ax0, v_ax1, v_ax2, v_ax3] + X[v_ax0, v_ax1, v_ax2, v_ax3]
)
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
with T.block("rotary"):
v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
rotary[v0, v1, v2, v3] = Y[m + v1 - T.int64(1), v3] * T_add[v0, v1, v2, v3]
@R.function
def main(
x: R.Tensor((1, 1, 32, 128), dtype="float32"),
y: R.Tensor((2048, 128), dtype="float32"),
len: R.Shape(["m"]),
) -> R.Tensor((1, 1, 32, 128), dtype="float32"):
m = T.int64()
cls = Expected
with R.dataflow():
gv = R.call_tir(
cls.fused,
(x, y),
out_sinfo=R.Tensor([1, 1, 32, 128], "float32"),
tir_vars=R.shape([m]),
)
R.output(gv)
return gv
_check(Before, Expected)
def test_same_buffer_multiple_read():
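    """The same tensor may be passed for multiple PrimFunc parameters; the fused
    PrimFunc should read the single shared input buffer in both places."""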
@I.ir_module
class Module:
@T.prim_func(private=True)
def concatenate(
rxplaceholder: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"),
rxplaceholder_1: T.Buffer(
(T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"
),
T_concat: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"),
):
T.func_attr({"op_pattern": 2, "tir.noalias": True})
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)):
with T.block("T_concat"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(
rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
)
T.writes(T_concat[v_ax0, v_ax1, v_ax2, v_ax3])
T_concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
T.int64(1) <= v_ax0,
rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
)
@T.prim_func(private=True)
def transpose2(
rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"),
T_transpose: T.Buffer((T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"),
):
T.func_attr({"op_pattern": 2, "tir.noalias": True})
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)):
with T.block("T_transpose"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2])
T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[
v_ax0, v_ax3, v_ax1, v_ax2
]
@R.function
def fused_concatenate_transpose2(
inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
R.func_attr({"Primitive": True})
cls = Module
with R.dataflow():
lv = R.call_tir(
cls.concatenate,
(inp_0, inp_0),
out_sinfo=R.Tensor((2, 4, 64, 64), dtype="float32"),
)
gv = R.call_tir(
cls.transpose2, (lv,), out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32")
)
R.output(gv)
return gv
@R.function
def main(
inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
R.func_attr({"num_input": 3})
cls = Module
with R.dataflow():
lv = cls.fused_concatenate_transpose2(inp_0)
R.output(lv)
return lv
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused_concatenate_transpose2(
inp_0: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"),
T_transpose_handle_intermediate: T.Buffer(
(T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"
),
):
T.func_attr({"tir.noalias": True})
T_concat_handle_intermediate = T.alloc_buffer(
(T.int64(2), T.int64(4), T.int64(64), T.int64(64))
)
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)):
with T.block("T_concat"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3])
T.writes(T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
T.int64(1) <= v_ax0,
inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
inp_0[v_ax0, v_ax1, v_ax2, v_ax3],
)
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)):
with T.block("T_transpose"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1, v_ax2])
T.writes(T_transpose_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
T_transpose_handle_intermediate[
v_ax0, v_ax1, v_ax2, v_ax3
] = T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1, v_ax2]
@R.function
def main(
inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
R.func_attr({"num_input": 3})
cls = Expected
with R.dataflow():
lv = R.call_tir(
cls.fused_concatenate_transpose2,
(inp_0,),
out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32"),
)
R.output(lv)
return lv
_check(Module, Expected)
def test_tir_expression_in_shape():
@I.ir_module
class Module:
@R.function
def fused_transpose_matmul(
x: R.Tensor((3, 4), dtype="float32"),
y: R.Tensor(("n - 1", 4), dtype="float32"),
tir_vars: R.Shape(["n"]),
) -> R.Tensor(("n - 1", 3), dtype="float32"):
R.func_attr({"Primitive": True})
with R.dataflow():
lv = R.emit_te(topi.transpose, x)
gv = R.emit_te(topi.matmul, y, lv)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor((3, 4), dtype="float32"),
y: R.Tensor(("n - 1", 4), dtype="float32"),
tir_vars: R.Shape(["n"]),
) -> R.Tensor(("n - 1", 3), dtype="float32"):
cls = Module
with R.dataflow():
lv = cls.fused_transpose_matmul(x, y, tir_vars)
R.output(lv)
return lv
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused_transpose_matmul(
x: T.Buffer((T.int64(3), T.int64(4)), "float32"),
p_y: T.handle,
p_output0: T.handle,
n: T.int64,
):
T.func_attr({"tir.noalias": True})
y = T.match_buffer(p_y, (n - T.int64(1), T.int64(4)))
var_T_matmul_intermediate = T.match_buffer(p_output0, (n - T.int64(1), T.int64(3)))
var_T_transpose_intermediate = T.alloc_buffer((T.int64(4), T.int64(3)))
for ax0, ax1 in T.grid(T.int64(4), T.int64(3)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
var_T_transpose_intermediate[v_ax0, v_ax1] = x[v_ax1, v_ax0]
for ax0, ax1, k in T.grid(n - T.int64(1), T.int64(3), T.int64(4)):
with T.block("T_matmul"):
v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
with T.init():
var_T_matmul_intermediate[v_ax0, v_ax1] = T.float32(0)
var_T_matmul_intermediate[v_ax0, v_ax1] = (
var_T_matmul_intermediate[v_ax0, v_ax1]
+ y[v_ax0, v_k] * var_T_transpose_intermediate[v_k, v_ax1]
)
@R.function
def main(
x: R.Tensor((3, 4), dtype="float32"),
y: R.Tensor(("n - 1", 4), dtype="float32"),
tir_vars: R.Shape(["n"]),
) -> R.Tensor(("n - 1", 3), dtype="float32"):
n = T.int64()
cls = Expected
with R.dataflow():
lv = R.call_tir(
cls.fused_transpose_matmul,
(x, y),
out_sinfo=R.Tensor((n - 1, 3), dtype="float32"),
tir_vars=R.shape([n]),
)
R.output(lv)
return lv
_check(Module, Expected)
def test_tuple_input_unused_field():
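    """Unused fields of a tuple parameter are dropped: only tup[0] is extracted
    in main and passed to the fused PrimFunc."""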
@I.ir_module
class Module:
@T.prim_func(private=True)
def reshape(
A: T.Buffer((T.int64(4), T.int64(8), T.int64(2048)), "float32"),
T_reshape: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"),
):
T.func_attr({"op_pattern": 2, "tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(
A[
(
((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1)
// T.int64(8)
+ v_ax0
)
% T.int64(4),
((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) % T.int64(8),
(v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
]
)
T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[
(
((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) // T.int64(8)
+ v_ax0
)
% T.int64(4),
((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) % T.int64(8),
(v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
]
@R.function(private=True)
def fused_reshape(
lv: R.Tuple(
R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")
)
) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
R.func_attr({"Primitive": True})
cls = Module
with R.dataflow():
lv1: R.Tensor((4, 8, 2048), dtype="float32") = lv[0]
gv = R.call_tir(
cls.reshape, (lv1,), out_sinfo=R.Tensor((4, 8, 32, 64), dtype="float32")
)
R.output(gv)
return gv
@R.function
def main(
tup: R.Tuple(
R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")
)
) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
cls = Module
with R.dataflow():
lv_1: R.Tensor((4, 8, 32, 64), dtype="float32") = cls.fused_reshape(tup)
R.output(lv_1)
return lv_1
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused_reshape(
lv_0: T.Buffer((T.int64(4), T.int64(8), T.int64(2048)), "float32"),
T_reshape_handle_intermediate: T.Buffer(
(T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"
),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(
lv_0[
(
((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1)
// T.int64(8)
+ v_ax0
)
% T.int64(4),
((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) % T.int64(8),
(v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
]
)
T.writes(T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv_0[
(
((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) // T.int64(8)
+ v_ax0
)
% T.int64(4),
((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) % T.int64(8),
(v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
]
@R.function
def main(
tup: R.Tuple(
R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")
)
) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
cls = Expected
with R.dataflow():
lv: R.Tensor((4, 8, 2048), dtype="float32") = tup[0]
lv_1 = R.call_tir(
cls.fused_reshape, (lv,), out_sinfo=R.Tensor((4, 8, 32, 64), dtype="float32")
)
R.output(lv_1)
return lv_1
_check(Module, Expected)
def test_unique_duplicated_buffer_allocation():
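    """Fusing two PrimFuncs whose output buffers share a name should produce
    uniquely named allocations (Out_intermediate vs. Out_intermediate_1)."""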
@I.ir_module
class Module:
@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
):
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.block("add"):
vi, vj = T.axis.remap("SS", [i, j])
Out[vi, vj] = A[vi, vj] + T.float16(1.0)
@T.prim_func(private=True)
def add1(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
):
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.block("add"):
vi, vj = T.axis.remap("SS", [i, j])
Out[vi, vj] = A[vi, vj] + T.float16(2.0)
@R.function
def main(
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((4096, 4096), dtype="float16"):
cls = Module
with R.dataflow():
gv: R.Tensor((4096, 4096), dtype="float16") = cls.fused_func(input_embeds)
R.output(gv)
return gv
@R.function(private=True)
def fused_func(
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((4096, 4096), dtype="float16"):
R.func_attr({"Primitive": True})
cls = Module
with R.dataflow():
lv = R.call_tir(
cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
)
gv = R.call_tir(cls.add1, (lv,), out_sinfo=R.Tensor((4096, 4096), dtype="float16"))
R.output(gv)
return gv
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused_func(
input_embeds: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
Out_intermediate_1: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
):
T.func_attr({"tir.noalias": True})
Out_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.block("add"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(input_embeds[vi, vj])
T.writes(Out_intermediate[vi, vj])
Out_intermediate[vi, vj] = input_embeds[vi, vj] + T.float16(1)
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.block("add_1"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(Out_intermediate[vi, vj])
T.writes(Out_intermediate_1[vi, vj])
Out_intermediate_1[vi, vj] = Out_intermediate[vi, vj] + T.float16(2)
@R.function
def main(
input_embeds: R.Tensor((4096, 4096), dtype="float16")
) -> R.Tensor((4096, 4096), dtype="float16"):
cls = Expected
with R.dataflow():
gv = R.call_tir(
cls.fused_func,
(input_embeds,),
out_sinfo=R.Tensor((4096, 4096), dtype="float16"),
)
R.output(gv)
return gv
_check(Module, Expected)
def test_extern_func():
bb = relax.BlockBuilder()
bb.add_func(relax.extern("extern_func"), "extern_func")
mod = bb.get()
# FuseTIR should keep the ExternFunc in the IRModule.
_check(mod, mod)
def test_symbolic_var_in_buffer_shape():
"""A PrimFunc may have dynamic buffer shapes
Symbolic variables in a PrimFunc may be present in the buffer
shape without a corresponding parameter. These symbolic variables
are inferred from the buffer's shape. (Or, at runtime, they are
typically determined from the DLTensor's known shape.)
"""
@I.ir_module
class Before:
@T.prim_func(private=True)
def foo(
X_handle: T.handle,
Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"),
rotary_handle: T.handle,
m: T.int64,
):
sequence_length = T.int64()
X = T.match_buffer(
X_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
)
rotary = T.match_buffer(
rotary_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
)
for i0, i1, i2, i3 in T.grid(T.int64(1), sequence_length, T.int64(32), T.int64(128)):
with T.block("rotary"):
v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
rotary[v0, v1, v2, v3] = Y[m + v1 - 1, v3] * X[v0, v1, v2, v3]
@R.function
def fused(
x: R.Tensor((1, "sequence_length", 32, 128), dtype="float32"),
y: R.Tensor((2048, 128), dtype="float32"),
len: R.Shape(["m"]),
) -> R.Tensor((1, "sequence_length", 32, 128), dtype="float32"):
R.func_attr({"Primitive": True})
sequence_length = T.int64()
m = T.int64()
cls = Before
with R.dataflow():
lv1 = R.emit_te(topi.add, x, x)
gv = R.call_tir(
cls.foo,
[lv1, y],
out_sinfo=R.Tensor((1, sequence_length, 32, 128), dtype="float32"),
tir_vars=R.shape([m]),
)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor((1, "sequence_length", 32, 128), dtype="float32"),
y: R.Tensor((2048, 128), dtype="float32"),
len: R.Shape(["m"]),
) -> R.Tensor((1, "sequence_length", 32, 128), dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused(x, y, len)
R.output(gv)
return gv
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused(
X_handle: T.handle,
Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"),
rotary_handle: T.handle,
m: T.int64,
):
T.func_attr({"tir.noalias": True})
sequence_length = T.int64()
X = T.match_buffer(
X_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
)
rotary = T.match_buffer(
rotary_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
)
T_add = T.alloc_buffer((T.int64(1), sequence_length, T.int64(32), T.int64(128)))
for ax0, ax1, ax2, ax3 in T.grid(
T.int64(1), sequence_length, T.int64(32), T.int64(128)
):
with T.block("T_add"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T_add[v_ax0, v_ax1, v_ax2, v_ax3] = (
X[v_ax0, v_ax1, v_ax2, v_ax3] + X[v_ax0, v_ax1, v_ax2, v_ax3]
)
for i0, i1, i2, i3 in T.grid(T.int64(1), sequence_length, T.int64(32), T.int64(128)):
with T.block("rotary"):
v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
rotary[v0, v1, v2, v3] = Y[m + v1 - T.int64(1), v3] * T_add[v0, v1, v2, v3]
@R.function
def main(
x: R.Tensor((1, "sequence_length", 32, 128), dtype="float32"),
y: R.Tensor((2048, 128), dtype="float32"),
len: R.Shape(["m"]),
) -> R.Tensor((1, "sequence_length", 32, 128), dtype="float32"):
sequence_length = T.int64()
m = T.int64()
cls = Expected
with R.dataflow():
gv = R.call_tir(
cls.fused,
(x, y),
out_sinfo=R.Tensor([1, sequence_length, 32, 128], "float32"),
tir_vars=R.shape([m]),
)
R.output(gv)
return gv
_check(Before, Expected)
def test_symbolic_var_called_with_static_shape():
"""A dynamic PrimFunc may be called with a static shape"""
@I.ir_module
class Before:
@T.prim_func(private=True)
def sum_1d(
X_handle: T.handle,
Y: T.Buffer([T.int64(1)], "float32"),
):
num_elements = T.int64()
X = T.match_buffer(X_handle, [num_elements], "float32")
for i in range(num_elements):
with T.block("sum"):
vi = T.axis.remap("R", [i])
with T.init():
Y[0] = 0.0
Y[0] = Y[0] + X[vi]
@R.function(private=True)
def fused(
x: R.Tensor([64], dtype="float32"),
) -> R.Tensor([1], dtype="float32"):
R.func_attr({"Primitive": True})
cls = Before
with R.dataflow():
gv = R.call_tir(
cls.sum_1d,
[x],
out_sinfo=R.Tensor([1], dtype="float32"),
)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor([64], dtype="float32"),
) -> R.Tensor([1], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused(x)
R.output(gv)
return gv
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused(
X: T.Buffer([T.int64(64)], "float32"),
Y: T.Buffer([T.int64(1)], "float32"),
):
T.func_attr({"tir.noalias": True})
for i in range(T.int64(64)):
with T.block("sum"):
vi = T.axis.remap("R", [i])
with T.init():
Y[0] = 0.0
Y[0] = Y[0] + X[vi]
@R.function
def main(
x: R.Tensor([64], dtype="float32"),
) -> R.Tensor([1], dtype="float32"):
cls = Expected
with R.dataflow():
gv = R.call_tir(cls.fused, (x,), out_sinfo=R.Tensor((1,), dtype="float32"))
R.output(gv)
return gv
_check(Before, Expected)
def test_symbolic_var_called_with_multiple_static_shapes():
"""A dynamic PrimFunc may be called with different shapes each time"""
@I.ir_module
class Before:
@T.prim_func(private=True)
def sum_1d(
X_handle: T.handle,
Sum: T.Buffer([T.int64(1)], "float32"),
):
num_elements = T.int64()
X = T.match_buffer(X_handle, [num_elements], "float32")
for i in range(num_elements):
with T.block("sum"):
vi = T.axis.remap("R", [i])
with T.init():
Sum[0] = 0.0
Sum[0] = Sum[0] + X[vi]
@T.prim_func(private=True)
def sum_scalar(
X: T.Buffer([T.int64(1)], "float32"),
Y: T.Buffer([T.int64(1)], "float32"),
Sum: T.Buffer([T.int64(1)], "float32"),
):
for i in range(T.int64(1)):
with T.block("Out"):
vi = T.axis.remap("S", [i])
Sum[vi] = X[vi] + Y[vi]
@R.function(private=True)
def fused(
x: R.Tensor([64], dtype="float32"),
y: R.Tensor([16], dtype="float32"),
) -> R.Tensor([1], dtype="float32"):
R.func_attr({"Primitive": True})
cls = Before
with R.dataflow():
x_sum = R.call_tir(
cls.sum_1d,
[x],
out_sinfo=R.Tensor([1], dtype="float32"),
)
y_sum = R.call_tir(
cls.sum_1d,
[y],
out_sinfo=R.Tensor([1], dtype="float32"),
)
gv = R.call_tir(
cls.sum_scalar,
[x_sum, y_sum],
out_sinfo=R.Tensor([1], dtype="float32"),
)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor([64], dtype="float32"),
y: R.Tensor([16], dtype="float32"),
) -> R.Tensor([1], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused(x, y)
R.output(gv)
return gv
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused(
X: T.Buffer([T.int64(64)], "float32"),
Y: T.Buffer([T.int64(16)], "float32"),
Out: T.Buffer([T.int64(1)], "float32"),
):
T.func_attr({"tir.noalias": True})
XSum = T.alloc_buffer([T.int64(1)], "float32")
YSum = T.alloc_buffer([T.int64(1)], "float32")
for i in range(T.int64(64)):
with T.block("XSum"):
vi = T.axis.remap("R", [i])
with T.init():
XSum[0] = 0.0
XSum[0] = XSum[0] + X[vi]
for i in range(T.int64(16)):
with T.block("YSum"):
vi = T.axis.remap("R", [i])
with T.init():
YSum[0] = 0.0
YSum[0] = YSum[0] + Y[vi]
for i in range(T.int64(1)):
with T.block("Out"):
vi = T.axis.remap("S", [i])
Out[vi] = XSum[vi] + YSum[vi]
@R.function
def main(
x: R.Tensor([64], dtype="float32"),
y: R.Tensor([16], dtype="float32"),
) -> R.Tensor([1], dtype="float32"):
cls = Expected
with R.dataflow():
gv = R.call_tir(cls.fused, (x, y), out_sinfo=R.Tensor((1,), dtype="float32"))
R.output(gv)
return gv
_check(Before, Expected)
def test_symbolic_var_called_with_static_argument():
"""A dynamic PrimFunc may accept a static argument
The `tir_vars` parameter in `R.call_tir` contains definitions for
all TIR variables explicitly listed in the function signature, and
    contains the TIR expression to be passed as the argument for each
    parameter.

    This test is identical to the earlier test named
    "test_symbolic_var_called_with_static_shape", except that `sum_1d`
    takes `num_elements` as an explicit parameter.
"""
@I.ir_module
class Before:
@T.prim_func(private=True)
def sum_1d(
X_handle: T.handle,
Y: T.Buffer([T.int64(1)], "float32"),
num_elements: T.int64,
):
X = T.match_buffer(X_handle, [num_elements], "float32")
for i in range(num_elements):
with T.block("sum"):
vi = T.axis.remap("R", [i])
with T.init():
Y[0] = 0.0
Y[0] = Y[0] + X[vi]
@R.function(private=True)
def fused(
x: R.Tensor([64], dtype="float32"),
) -> R.Tensor([1], dtype="float32"):
R.func_attr({"Primitive": True})
cls = Before
with R.dataflow():
gv = R.call_tir(
cls.sum_1d,
[x],
out_sinfo=R.Tensor([1], dtype="float32"),
tir_vars=R.shape([64]),
)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor([64], dtype="float32"),
) -> R.Tensor([1], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused(x)
R.output(gv)
return gv
@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused(
X: T.Buffer([T.int64(64)], "float32"),
Y: T.Buffer([T.int64(1)], "float32"),
):
T.func_attr({"tir.noalias": True})
for i in range(T.int64(64)):
with T.block("sum"):
vi = T.axis.remap("R", [i])
with T.init():
Y[0] = 0.0
Y[0] = Y[0] + X[vi]
@R.function
def main(
x: R.Tensor([64], dtype="float32"),
) -> R.Tensor([1], dtype="float32"):
cls = Expected
with R.dataflow():
gv = R.call_tir(cls.fused, (x,), out_sinfo=R.Tensor((1,), dtype="float32"))
R.output(gv)
return gv
_check(Before, Expected)
def test_gather():
@I.ir_module
class Before:
@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
):
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.block("add"):
vi, vj = T.axis.remap("SS", [i, j])
Out[vi, vj] = A[vi, vj] + T.float16(1.0)
@T.prim_func(private=True)
def take(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
B: T.Buffer((T.int64(1),), "int32"),
T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"),
):
for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)):
with T.block("T_take"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]
@R.function
def main(
input_ids: R.Tensor((1,), dtype="int32"),
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((1, 4096), dtype="float16"):
cls = Before
with R.dataflow():
gv: R.Tensor((1, 4096), dtype="float16") = cls.fused_func(input_ids, input_embeds)
R.output(gv)
return gv
@R.function(private=True)
def fused_func(
input_ids: R.Tensor((1,), dtype="int32"),
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((1, 4096), dtype="float16"):
R.func_attr({"Primitive": True})
cls = Before
with R.dataflow():
lv = R.call_tir(
cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
)
gv = R.call_tir(
cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16")
)
R.output(gv)
return gv
@I.ir_module
class After:
@T.prim_func(private=True)
def fused_func(
input_ids: T.Buffer((T.int64(1),), "int32"),
input_embeds: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"),
):
T.func_attr({"tir.noalias": True})
Out_handle_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.block("add"):
vi, vj = T.axis.remap("SS", [i, j])
Out_handle_intermediate[vi, vj] = input_embeds[vi, vj] + T.float16(1)
for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)):
with T.block("T_take"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T_take[v_ax0, v_ax1] = Out_handle_intermediate[input_ids[v_ax0], v_ax1]
@R.function
def main(
input_ids: R.Tensor((1,), dtype="int32"),
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((1, 4096), dtype="float16"):
cls = After
with R.dataflow():
gv = R.call_tir(
cls.fused_func,
(input_ids, input_embeds),
out_sinfo=R.Tensor((1, 4096), dtype="float16"),
)
R.output(gv)
return gv
_check(Before, After)
def test_inplace_simple():
@I.ir_module
class Module:
I.module_attrs({"foo": "bar"})
@T.prim_func(private=True)
def add_inplace(
A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B: T.Buffer((), "float32")
):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
# T.reads(A[v_ax0, v_ax1], B[()])
# T.writes(A[v_ax0, v_ax1])
A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
@T.prim_func(private=True)
def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(10), T.int64(20)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
# T.reads(A[v_i0, v_i1])
# T.writes(A[v_i0, v_i1])
A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
@T.prim_func(private=True)
def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_squeeze"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
# T.reads(A[v_ax0, v_ax1])
# T.writes(A[v_ax0, v_ax1])
A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
@R.function(private=True)
def fused_add_exp_squeeze(
x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
) -> R.Tensor((10, 20), dtype="float32"):
R.func_attr({"Primitive": True})
cls = Module
with R.dataflow():
# This overwrites x and is actually evil because the function is marked as pure
# but we are doing it just to test the pass. The automatic DataflowUseInplaceCalls
# transformation will not produce code like this, but it may make sense to do it
# if ownership of x is fully and truly transferred.
# Users should apply with caution!
lv = R.call_tir_inplace(
cls.add_inplace,
(x, p0),
inplace_indices=[0],
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
lv1 = R.call_tir_inplace(
cls.exp_inplace,
(lv,),
inplace_indices=[0],
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
gv = R.call_tir_inplace(
cls.squeeze_inplace,
(lv1,),
inplace_indices=[0],
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
) -> R.Tensor((10, 20), dtype="float32"):
cls = Module
with R.dataflow():
gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_add_exp_squeeze(x, p0)
R.output(gv1)
return gv1
@I.ir_module
class Expected:
I.module_attrs({"foo": "bar"})
@T.prim_func(private=True)
def fused_add_exp_squeeze(
x: T.Buffer((T.int64(10), T.int64(20)), "float32"), p0: T.Buffer((), "float32")
):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
for i0, i1 in T.grid(T.int64(10), T.int64(20)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
x[v_i0, v_i1] = T.exp(x[v_i0, v_i1])
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_squeeze"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
x[v_ax0, v_ax1] = x[v_ax0, v_ax1]
# note that this will clobber x! Use with caution
@R.function
def main(
x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
) -> R.Tensor((10, 20), dtype="float32"):
cls = Expected
with R.dataflow():
gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir_inplace(
cls.fused_add_exp_squeeze,
(x, p0),
out_sinfo=R.Tensor((10, 20), dtype="float32"),
inplace_indices=[0],
)
R.output(gv1)
return gv1
_check(Module, Expected)
def test_fuse_inplace_and_non_inplace():
@I.ir_module
class Module:
I.module_attrs({"foo": "bar"})
@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
B: T.Buffer((), "float32"),
Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
@T.prim_func(private=True)
def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(10), T.int64(20)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
@T.prim_func(private=True)
def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_squeeze"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
@R.function(private=True)
def fused_add_exp_squeeze(
x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
) -> R.Tensor((10, 20), dtype="float32"):
R.func_attr({"Primitive": True})
cls = Module
with R.dataflow():
lv = R.call_tir(
cls.add,
(x, p0),
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
lv1 = R.call_tir_inplace(
cls.exp_inplace,
(lv,),
inplace_indices=[0],
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
gv = R.call_tir_inplace(
cls.squeeze_inplace,
(lv1,),
inplace_indices=[0],
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
) -> R.Tensor((10, 20), dtype="float32"):
cls = Module
with R.dataflow():
gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_add_exp_squeeze(x, p0)
R.output(gv1)
return gv1
@I.ir_module
class Expected:
I.module_attrs({"foo": "bar"})
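        # The first call in the group was not in-place, so the fused prim func
        # writes into a fresh output buffer (p_output0) rather than mutating x.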
@T.prim_func(private=True)
def fused_add_exp_squeeze(
x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
p0: T.Buffer((), "float32"),
p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"),
):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
for i0, i1 in T.grid(T.int64(10), T.int64(20)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
p_output0[v_i0, v_i1] = T.exp(p_output0[v_i0, v_i1])
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_squeeze"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
p_output0[v_ax0, v_ax1] = p_output0[v_ax0, v_ax1]
@R.function
def main(
x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
) -> R.Tensor((10, 20), dtype="float32"):
cls = Expected
with R.dataflow():
gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir(
cls.fused_add_exp_squeeze,
(x, p0),
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
R.output(gv1)
return gv1
_check(Module, Expected)
def test_use_as_inplace_and_dps():
@I.ir_module
class Module:
        # This prim func is used both in-place (R.call_tir_inplace) and normally (DPS, R.call_tir)
@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
B: T.Buffer((), "float32"),
Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
@R.function(private=True)
def fused_sums(
x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
) -> R.Tensor((10, 20), dtype="float32"):
R.func_attr({"Primitive": True})
cls = Module
with R.dataflow():
lv = R.call_tir(
cls.add,
(x, p0),
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
lv1 = R.call_tir_inplace(
cls.add,
(x, p0, lv),
inplace_indices=[2],
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
lv2 = R.call_tir_inplace(
cls.add,
(x, p0, lv1),
inplace_indices=[2],
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
R.output(lv2)
return lv2
@R.function
def main(
x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
) -> R.Tensor((10, 20), dtype="float32"):
cls = Module
with R.dataflow():
gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_sums(x, p0)
R.output(gv1)
return gv1
@I.ir_module
class Expected:
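        # The shared `add` prim func is inlined once per call site; every call
        # writes the same output buffer, so the fused DPS prim func contains
        # three identical T_add loop nests.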
@T.prim_func(private=True)
def fused_sums(
x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
p0: T.Buffer((), "float32"),
p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"),
):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
@R.function
def main(
x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
) -> R.Tensor((10, 20), dtype="float32"):
cls = Expected
with R.dataflow():
gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir(
cls.fused_sums,
(x, p0),
out_sinfo=R.Tensor((10, 20), dtype="float32"),
)
R.output(gv1)
return gv1
_check(Module, Expected)
def test_private_nonprimitive_func():
"""Input IRModule may contain calls to non-primitive functions
This is a regression test. Prior implementations did not preserve
relax-to-relax function calls.
"""
@I.ir_module
class Before:
@R.function
def main(
input_ids: R.Tensor((1,), dtype="int32"),
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((1, 4096), dtype="float16"):
cls = Before
with R.dataflow():
gv = cls.fused_func(input_ids, input_embeds)
R.output(gv)
return gv
@R.function(private=True)
def fused_func(
input_ids: R.Tensor((1,), dtype="int32"),
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((1, 4096), dtype="float16"):
cls = Before
with R.dataflow():
lv = R.call_tir(
cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
)
gv = R.call_tir(
cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16")
)
R.output(gv)
return gv
@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
):
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.block("add"):
vi, vj = T.axis.remap("SS", [i, j])
Out[vi, vj] = A[vi, vj] + T.float16(1.0)
@T.prim_func(private=True)
def take(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
B: T.Buffer((T.int64(1),), "int32"),
T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"),
):
for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)):
with T.block("T_take"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]
_check(Before, Before)
def test_fuse_with_axis_separators():
@I.ir_module
class Before:
@T.prim_func(private=True)
def add(a: T.handle, b: T.handle, c: T.handle):
A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
for iters in T.grid(T.int64(16), T.int64(32)):
with T.block("compute"):
i, j = T.axis.remap("SS", iters)
C[i, j] = A[i, j] + B[i, j]
@R.function(private=True)
def fused_function(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
R.func_attr({"Primitive": True})
cls = Before
with R.dataflow():
w = R.call_tir(
cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
)
out = R.call_tir(
cls.add, [w, z], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
)
R.output(out)
return out
@R.function
def main(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused_function(x, y, z)
R.output(gv)
return gv
@I.ir_module
class Expected:
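        # The axis_separators from the original match_buffer declarations carry
        # over to the fused function's parameters and to the intermediate buffer.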
@T.prim_func(private=True)
def fused_function(x: T.handle, y: T.handle, z: T.handle, c: T.handle):
T.func_attr({"tir.noalias": True})
X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
Y = T.match_buffer(y, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
Z = T.match_buffer(z, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
Temp = T.alloc_buffer(X.shape, "float32", axis_separators=[1])
for iters in T.grid(*X.shape):
with T.block("compute_Y"):
i, j = T.axis.remap("SS", iters)
Temp[i, j] = X[i, j] + Y[i, j]
for iters in T.grid(*X.shape):
with T.block("compute_Z"):
i, j = T.axis.remap("SS", iters)
C[i, j] = Temp[i, j] + Z[i, j]
@R.function
def main(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
cls = Expected
with R.dataflow():
gv = R.call_tir(
cls.fused_function,
[x, y, z],
out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32"),
)
R.output(gv)
return gv
_check(Before, Expected)
def test_fuse_with_axis_separators_inconsistent_buffer_mapping():
@I.ir_module
class Before:
@T.prim_func(private=True)
def mul(a: T.handle, b: T.handle, c: T.handle):
A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[])
C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
for iters in T.grid(T.int64(16), T.int64(32)):
with T.block("compute"):
i, j = T.axis.remap("SS", iters)
C[i, j] = A[i, j] * B[i, j]
@R.function(private=True)
def fused_function(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
R.func_attr({"Primitive": True})
cls = Before
with R.dataflow():
out = R.call_tir(
cls.mul, [x, x], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
)
R.output(out)
return out
@R.function
def main(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused_function(x)
R.output(gv)
return gv
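    # `x` is bound to both buffer parameters of `mul`, whose axis_separators
    # differ ([1] vs. []), so FuseTIR reports the buffers as inconsistent.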
with pytest.raises(
tvm.TVMError, match=r"Inconsistent buffers.*and.*mapped to the same relax var:.*"
):
relax.transform.FuseTIR()(Before)
if __name__ == "__main__":
tvm.testing.main()