# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
from tvm import relax
from tvm.relax.transform import LegalizeOps
from tvm.relax.transform.legalize_ops.common import register_legalize
from tvm.script import relax as R, tir as T, ir as I
import tvm.testing
def test_customize_legalize():
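    """LegalizeOps accepts a dict mapping op names to custom legalization
    functions, which take precedence over the default FLegalize lowering."""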
# fmt: off
@tvm.script.ir_module
class Add:
@R.function
def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
gv: R.Tensor((4, 3, 2, 3), "float32") = R.add(x, y)
return gv
@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
cls = Expected
gv = R.call_tir(cls.add, (y, x), R.Tensor((4, 3, 2, 3), dtype="float32"))
return gv
@T.prim_func(private=True)
def add(rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
T.func_attr({"tir.noalias": True})
for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
with T.block("T_add"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3])
T.writes(T_add[ax0, ax1, ax2, ax3])
T_add[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + rxplaceholder[T.int64(0), ax2, ax3]
# fmt: on
def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call):
from tvm import topi # pylint: disable=import-outside-toplevel
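        # Deliberately swap the operand order so the effect of the custom
        # rule is observable: Expected.main calls cls.add with (y, x).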
return bb.call_te(topi.add, call.args[1], call.args[0])
mod = LegalizeOps({"relax.add": customize_legalize_add})(Add)
tvm.ir.assert_structural_equal(mod, Expected)
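
# The pattern above generalizes to any op with a TE/TOPI kernel.  A minimal
# sketch (assuming `topi.multiply` as the replacement implementation):
#
#   def custom_multiply(bb: relax.BlockBuilder, call: relax.Call):
#       return bb.call_te(topi.multiply, call.args[0], call.args[1])
#
#   mod = LegalizeOps({"relax.multiply": custom_multiply})(mod)
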
def test_legalize_multiple_types_of_call():
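    """Relax op calls are legalized, while existing R.call_tir calls and
    cross-function Relax calls are left untouched."""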
# fmt: off
@tvm.script.ir_module
class Before:
@R.function
def mul2(x: R.Tensor((3, 3), "float32")):
gv = R.multiply(x, R.const(2.0, "float32"))
return gv
@T.prim_func(private=True)
def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")):
for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(rxplaceholder[v_ax0, v_ax1])
T.writes(T_id[v_ax0, v_ax1])
T_id[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1]
@R.function
def main(x: R.Tensor((3, 3), "float32")):
cls = Before
gv: R.Tensor((3, 3), "float32") = cls.mul2(x)
gv1 = R.call_tir(cls.identity, gv, R.Tensor((3, 3), dtype="float32"))
gv2 = R.multiply(gv1, R.const(2.0, "float32"))
return gv2
@tvm.script.ir_module
class Expected:
@R.function
def mul2(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"):
cls = Expected
gv = R.call_tir(cls.multiply, (x,), R.Tensor((3, 3), dtype="float32"))
return gv
@T.prim_func(private=True)
def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")):
for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(rxplaceholder[v_ax0, v_ax1])
T.writes(T_id[v_ax0, v_ax1])
T_id[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1]
@T.prim_func(private=True)
def multiply(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float32")):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
with T.block("T_multiply"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(rxplaceholder[v_ax0, v_ax1])
T.writes(T_multiply[v_ax0, v_ax1])
T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * T.float32(2)
@R.function
def main(x1: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"):
cls = Expected
gv1: R.Tensor((3, 3), dtype="float32") = cls.mul2(x1)
gv11 = R.call_tir(cls.identity, gv1, R.Tensor((3, 3), dtype="float32"))
gv2 = R.call_tir(cls.multiply, (gv11,), R.Tensor((3, 3), dtype="float32"))
return gv2
# fmt: on
After = LegalizeOps()(Before)
tvm.ir.assert_structural_equal(After, Expected)
def test_can_not_legalize():
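    """Calls are left unchanged when the op has no registered legalization,
    or when the argument shapes are not fully known."""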
    # case 1: the op doesn't have a registered legalization function
    add_legalize = tvm.ir.Op.get("relax.add").get_attr("FLegalize")
    # temporarily remove the attribute to simulate an op without legalization
    tvm.ir.Op.get("relax.add").reset_attr("FLegalize")
# fmt: off
@tvm.script.ir_module
class Before0:
@R.function
def main(x: R.Tensor((3, 3), "float32")):
gv: R.Tensor((3, 3), "float32") = R.add(x, x)
return gv
# fmt: on
After0 = LegalizeOps()(Before0)
tvm.ir.assert_structural_equal(After0, Before0)
register_legalize("relax.add", add_legalize)
    # case 2: the argument shapes are not fully known
s = relax.Var("s", relax.ShapeStructInfo((3, 3)))
x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
y = relax.Var("y", relax.TensorStructInfo(s, "float32"))
bb = relax.BlockBuilder()
with bb.function("main", [x, y]):
with bb.dataflow():
gv = bb.emit_output(R.add(x, y))
bb.emit_func_output(gv)
Before1 = bb.get()
After1 = LegalizeOps()(Before1)
tvm.ir.assert_structural_equal(After1, Before1)
def test_legalize_scalar_data_type_preserve():
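    """Scalar constants keep their original dtype (float16, uint8, bool)
    in the legalized TIR."""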
# fmt: off
@tvm.script.ir_module
class Before0:
@R.function
def main(x: R.Tensor((3, 3), "float16")):
gv: R.Tensor((3, 3), "float16") = R.multiply(x, R.const(1.14514, "float16"))
return gv
@tvm.script.ir_module
class Before1:
@R.function
def main(x: R.Tensor((3, 3), "uint8")):
gv: R.Tensor((3, 3), "uint8") = R.multiply(x, R.const(2, "uint8"))
return gv
@tvm.script.ir_module
class Before2:
@R.function
def main(x: R.Tensor((3, 3), "bool")):
gv: R.Tensor((3, 3), "bool") = R.equal(x, R.const(True, "bool"))
return gv
@tvm.script.ir_module
class Expected0:
@T.prim_func(private=True)
def multiply(
rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float16"),
T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float16"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
with T.block("T_multiply"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(rxplaceholder[v_ax0, v_ax1])
T.writes(T_multiply[v_ax0, v_ax1])
T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * T.float16(
1.1455078125
)
@R.function
def main(x: R.Tensor((3, 3), dtype="float16")) -> R.Tensor((3, 3), dtype="float16"):
cls = Expected0
gv = R.call_tir(cls.multiply, (x,), out_sinfo=R.Tensor((3, 3), dtype="float16"))
return gv
@tvm.script.ir_module
class Expected1:
@T.prim_func(private=True)
def multiply(
rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "uint8"),
T_multiply: T.Buffer((T.int64(3), T.int64(3)), "uint8"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
with T.block("T_multiply"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(rxplaceholder[v_ax0, v_ax1])
T.writes(T_multiply[v_ax0, v_ax1])
T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * T.uint8(2)
@R.function
def main(x: R.Tensor((3, 3), dtype="uint8")) -> R.Tensor((3, 3), dtype="uint8"):
cls = Expected1
gv = R.call_tir(cls.multiply, (x,), out_sinfo=R.Tensor((3, 3), dtype="uint8"))
return gv
@tvm.script.ir_module
class Expected2:
@T.prim_func(private=True)
def equal(
rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "bool"),
T_equal: T.Buffer((T.int64(3), T.int64(3)), "bool"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
with T.block("T_equal"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(rxplaceholder[v_ax0, v_ax1])
T.writes(T_equal[v_ax0, v_ax1])
T_equal[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] == tvm.tir.const(True, "bool")
@R.function
def main(x: R.Tensor((3, 3), dtype="bool")) -> R.Tensor((3, 3), dtype="bool"):
cls = Expected2
gv = R.call_tir(cls.equal, (x,), out_sinfo=R.Tensor((3, 3), dtype="bool"))
return gv
# fmt: on
After0 = LegalizeOps()(Before0)
tvm.ir.assert_structural_equal(After0, Expected0)
After1 = LegalizeOps()(Before1)
tvm.ir.assert_structural_equal(After1, Expected1)
After2 = LegalizeOps()(Before2)
tvm.ir.assert_structural_equal(After2, Expected2)
def test_matmul_legalization_requires_known_dtype():
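    """Legalizing R.matmul with arguments of unknown dtype should raise a
    user-friendly error."""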
@I.ir_module
class ArbitraryDtype:
@R.function
def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]):
return R.matmul(A, B)
with pytest.raises(AssertionError) as err:
LegalizeOps()(ArbitraryDtype)
# This error should be caught while attempting to legalize the
# R.matmul, where we can present a user-friendly error.
# Otherwise, the error isn't caught until the implementation of
# `BlockBuilder.call_te`, when attempting to create a numeric
# constant of type kHandle, which produces a much less
# user-friendly error.
err_message = err.value.args[0]
assert err_message.startswith("To legalize R.matmul")
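
# A custom FLegalize callback may either return an un-emitted Relax
# expression, letting LegalizeOps emit it, or emit the expression through
# the BlockBuilder itself and return the resulting Var.  The fixture below
# exercises both styles.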
emit_legalization_through_builder = tvm.testing.parameter(
by_dict={
"return_relax_expr": False,
"return_relax_var": True,
}
)
@pytest.fixture
def custom_op(emit_legalization_through_builder):
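    """Register a temporary custom_op.matmul_bias_add operator whose
    legalization expands into R.matmul followed by R.add, then clean up
    the registered attributes on teardown."""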
op_name = "custom_op.matmul_bias_add"
def infer_struct_info(call: relax.Call, context):
activations, weight, bias = call.args
matmul_call = relax.op.matmul(activations, weight)
matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
matmul_call, context
)
matmul_var = relax.Var("dummy_var", matmul_sinfo)
add_call = matmul_var + bias
add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)
return add_sinfo
def legalize(bb: relax.BlockBuilder, call: relax.Call):
activations, weight, bias = call.args
legalized = relax.op.matmul(activations, weight) + bias
if emit_legalization_through_builder:
legalized = bb.emit(legalized)
return legalized
op_attrs = {
"FInferStructInfo": infer_struct_info,
"FLegalize": legalize,
"FPurity": True,
}
for key, value in op_attrs.items():
tvm.ir.register_op_attr(op_name, key, value)
op = tvm.ir.Op.get(op_name)
yield op
for key in op_attrs:
op.reset_attr(key)
def test_recursive_legalization(custom_op):
"""Legalization of an operator may produce new operators requiring legalization"""
@I.ir_module
class Before:
@R.function
def main(
A: R.Tensor([16, 32, 64], "float32"),
Weight: R.Tensor([64, 128], "float32"),
Bias: R.Tensor([16, 32, 128], "float32"),
):
return relax.Call(custom_op, [A, Weight, Bias])
AfterFirstIter = LegalizeOps()(Before)
AfterSecondIter = LegalizeOps()(AfterFirstIter)
# After LegalizeOps, the custom operation should be replaced by
# `R.matmul` and `R.add`, which should in turn be replaced with
# TIR implementations. Therefore, the second application of
# LegalizeOps() should be a no-op.
tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)
def test_legalize_with_vdevice():
"""Legalization may generate kernels for multiple targets
This is a regression test. In previous implementations, Relax
expressions whose argument types differed only by their `vdevice`
would be legalized to use the same `PrimFunc`.
"""
@I.ir_module
class Before:
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
@R.function
def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
C = R.add(A, B)
return C
@R.function
def func_llvm(
A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
):
C = R.add(A, B)
return C
@I.ir_module
class Expected:
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
@R.function
def func_cuda(
A: R.Tensor((32, 32), dtype="float32"),
B: R.Tensor((32, 32), dtype="float32"),
):
cls = Expected
C = R.call_tir(cls.add, (A, B), out_sinfo=R.Tensor((32, 32), dtype="float32"))
return C
@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
):
T.func_attr({"tir.noalias": True})
for iters in T.grid(T.int64(32), T.int64(32)):
with T.block("T_add"):
ax0, ax1 = T.axis.remap("SS", iters)
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]
@R.function
def func_llvm(
A: R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
B: R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
):
cls = Expected
C = R.call_tir(
cls.add_llvm,
(A, B),
out_sinfo=R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
)
return C
@T.prim_func(private=True)
def add_llvm(
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
):
T.func_attr({"target": T.target("llvm"), "tir.noalias": True})
for iters in T.grid(T.int64(32), T.int64(32)):
with T.block("T_add"):
ax0, ax1 = T.axis.remap("SS", iters)
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]
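
    # func_cuda's parameters carry no vdevice, so its kernel (`add`) is
    # generated under the ambient target scope; func_llvm's llvm vdevice
    # forces a separate kernel (`add_llvm`) with an explicit target attr.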
with tvm.target.Target("cuda"):
After = tvm.relax.transform.LegalizeOps()(Before)
tvm.ir.assert_structural_equal(Expected, After)
if __name__ == "__main__":
tvm.testing.main()