# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests to validate relax optimize layout tranform pass."""
import numpy as np

import tvm
import tvm.testing
from tvm.relax.transform import DeadCodeElimination, FuseTIR, OptimizeLayoutTransform
from tvm.script import ir as I, tir as T, relax as R


def _run_pass_compare_output(Before, Expected):
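    """Run ``Before`` through the cleanup pipeline and compare with ``Expected``.

    OptimizeLayoutTransform folds layout_transform ops that cancel each other
    out, DeadCodeElimination drops the bindings this leaves unused, and FuseTIR
    fuses the remaining primitive functions into TIR.
    """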
After = tvm.ir.transform.Sequential(
[
OptimizeLayoutTransform(),
DeadCodeElimination(),
FuseTIR(),
]
)(Before)
tvm.ir.assert_structural_equal(Expected, After)
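

# Each test pairs a "Before" module, where R.layout_transform ops are
# interleaved with calls into layout-customized TIR functions, with an
# "Expected" module in which the transform/inverse-transform round trips
# between those calls have been folded away.


def test_index_maps_are_inverses():
    """Illustrative sanity check, separate from the pass under test: the
    tiling map ``i -> (i // 4, i % 4)`` and the flattening map
    ``(axis0, axis1) -> (axis0 * 4 + axis1,)`` used by the tests below are
    mutual inverses, which is exactly the redundancy the pass folds away.
    """
    data = np.arange(16, dtype="float32")
    tiled = data.reshape(4, 4)  # numpy analogue of i -> (i // 4, i % 4)
    flat = tiled.reshape(16)  # numpy analogue of (axis0, axis1) -> (axis0 * 4 + axis1,)
    np.testing.assert_array_equal(flat, data)

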
def test_optimize_transform_layout_pass_one_arg():
@I.ir_module
class Before:
@T.prim_func(private=True)
def relax_add_replacement(
arg0: T.Buffer((4, 4), "float32"),
arg1: T.Buffer((4, 4), "float32"),
output: T.Buffer((4, 4), "float32"),
):
T.func_attr({"operator_name": "relax.add"})
# with T.block("root"):
for ax0, ax1 in T.grid(4, 4):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
T.writes(output[v_ax0, v_ax1])
output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

        @R.function
def main(
x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
) -> R.Tensor((16,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
x, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv2 = R.call_tir(
Before.relax_add_replacement,
(lv, lv1),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
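                # lv0 flattens lv2 to 1-D and lv3 immediately re-tiles it; this
                # inverse pair is the redundancy the pass should remove.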
lv0: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
lv3: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
lv0, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv5 = R.call_tir(
Before.relax_add_replacement,
(lv4, lv3),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv2_1
R.output(gv)
return gv

    @I.ir_module
class Expected:
@T.prim_func(private=True)
def relax_add_replacement(
arg0: T.Buffer((4, 4), "float32"),
arg1: T.Buffer((4, 4), "float32"),
output: T.Buffer((4, 4), "float32"),
):
T.func_attr({"operator_name": "relax.add"})
# with T.block("root"):
for ax0, ax1 in T.grid(4, 4):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
T.writes(output[v_ax0, v_ax1])
output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

        @R.function
def main(
x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
) -> R.Tensor((16,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
x, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv2 = R.call_tir(
Expected.relax_add_replacement,
(lv, lv1),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv5 = R.call_tir(
Expected.relax_add_replacement,
(lv1, lv2),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
gv: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
R.output(gv)
return gv

    _run_pass_compare_output(Before, Expected)
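

# Same folding with a shared producer: both operands of the final add are
# produced by flatten/re-tile pairs, and both pairs are expected to fold so
# that the last call consumes lv3 and lv4 directly.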
def test_optimize_transform_layout_pass_two_args():
@I.ir_module
class Before:
@T.prim_func(private=True)
def relax_add_replacement(
arg0: T.Buffer((4, 4), "float32"),
arg1: T.Buffer((4, 4), "float32"),
output: T.Buffer((4, 4), "float32"),
):
T.func_attr({"operator_name": "relax.add"})
# with T.block("root"):
for ax0, ax1 in T.grid(4, 4):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
T.writes(output[v_ax0, v_ax1])
output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

        @R.function
def main(
x: R.Tensor((16,), dtype="float32"),
y: R.Tensor((16,), dtype="float32"),
z: R.Tensor((16,), dtype="float32"),
) -> R.Tensor((16,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
x, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
z, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv3 = R.call_tir(
Before.relax_add_replacement,
(lv, lv1),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv4 = R.call_tir(
Before.relax_add_replacement,
(lv, lv2),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
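                # lv5-lv8 flatten lv3 and lv4 and immediately re-tile them;
                # both inverse pairs should fold so lv9 reads lv3/lv4 directly.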
lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv4, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
lv7: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
lv5, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv8: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
lv6, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv9 = R.call_tir(
Before.relax_add_replacement,
(lv7, lv8),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv10: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv9, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv10
R.output(gv)
return gv

    @I.ir_module
class Expected:
@T.prim_func(private=True)
def relax_add_replacement(
arg0: T.Buffer((4, 4), "float32"),
arg1: T.Buffer((4, 4), "float32"),
output: T.Buffer((4, 4), "float32"),
):
T.func_attr({"operator_name": "relax.add"})
# with T.block("root"):
for ax0, ax1 in T.grid(4, 4):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
T.writes(output[v_ax0, v_ax1])
output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

        @R.function
def main(
x: R.Tensor((16,), dtype="float32"),
y: R.Tensor((16,), dtype="float32"),
z: R.Tensor((16,), dtype="float32"),
) -> R.Tensor((16,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
x, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
z, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv3 = R.call_tir(
Expected.relax_add_replacement,
(lv, lv1),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv4 = R.call_tir(
Expected.relax_add_replacement,
(lv, lv2),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv5 = R.call_tir(
Expected.relax_add_replacement,
(lv3, lv4),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
gv: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
R.output(gv)
return gv

    _run_pass_compare_output(Before, Expected)
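

# Padded variant: the flattening side is an identity layout_transform followed
# by a call into a "remove_pad" TIR function. The unpad/re-pad round trip
# between the two relu calls is expected to cancel in the same way.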
def test_transform_layout_tir_remove_pad_transform_layout():
@I.ir_module
class Before:
@T.prim_func(private=True)
def relax_relu_replacement(
arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")
):
T.func_attr({"operator_name": "relax.relu"})
# with T.block("root"):
for ax0 in range(16):
with T.block("T_add"):
v_ax0 = T.axis.spatial(16, ax0)
T.reads(arg0[v_ax0])
T.writes(output[v_ax0])
output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))

        @T.prim_func(private=True)
def remove_pad(var_input: T.handle, var_output: T.handle):
T.func_attr({"operator_name": "remove_pad", "tir.noalias": T.bool(True)})
p0 = T.int64()
input = T.match_buffer(var_input, (p0,))
i0 = T.int64()
output = T.match_buffer(var_output, (i0,))
# with T.block("root"):
for ax0 in range(i0):
with T.block("output"):
v_ax0 = T.axis.spatial(i0, ax0)
T.reads(input[v_ax0])
T.writes(output[v_ax0])
output[v_ax0] = input[v_ax0]

        @R.function
def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((16,), dtype="float32") = R.layout_transform(
x,
index_map=T.index_map(lambda i: (i % 16,)),
pad_value=None,
axis_separators=[],
)
lv1 = R.call_tir(
Before.relax_relu_replacement,
(lv,),
out_sinfo=R.Tensor((16,), dtype="float32"),
)
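                # lv2/lv_1 strip the padding back to 14 elements and lv3
                # immediately re-pads to 16; this round trip should be elided
                # so the second relu reads lv1 directly.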
lv2: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv1,
index_map=T.index_map(lambda axis0: (axis0,)),
pad_value=None,
axis_separators=[],
)
lv_1 = R.call_tir(
Before.remove_pad, (lv2,), out_sinfo=R.Tensor((14,), dtype="float32")
)
lv3: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv_1,
index_map=T.index_map(lambda i: (i % 16,)),
pad_value=None,
axis_separators=[],
)
lv4 = R.call_tir(
Before.relax_relu_replacement,
(lv3,),
out_sinfo=R.Tensor((16,), dtype="float32"),
)
lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv4,
index_map=T.index_map(lambda axis0: (axis0,)),
pad_value=None,
axis_separators=[],
)
lv_2 = R.call_tir(
Before.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32")
)
gv: R.Tensor((14,), dtype="float32") = lv_2
R.output(gv)
return gv

    @I.ir_module
class Expected:
@T.prim_func(private=True)
def relax_relu_replacement(
arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")
):
T.func_attr({"operator_name": "relax.relu"})
# with T.block("root"):
for ax0 in range(16):
with T.block("T_add"):
v_ax0 = T.axis.spatial(16, ax0)
T.reads(arg0[v_ax0])
T.writes(output[v_ax0])
output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))

        @T.prim_func(private=True)
def remove_pad(var_input: T.handle, var_output: T.handle):
T.func_attr({"operator_name": "remove_pad", "tir.noalias": T.bool(True)})
p0 = T.int64()
input = T.match_buffer(var_input, (p0,))
i0 = T.int64()
output = T.match_buffer(var_output, (i0,))
# with T.block("root"):
for ax0 in range(i0):
with T.block("output"):
v_ax0 = T.axis.spatial(i0, ax0)
T.reads(input[v_ax0])
T.writes(output[v_ax0])
output[v_ax0] = input[v_ax0]

        @R.function
def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((16,), dtype="float32") = R.layout_transform(
x,
index_map=T.index_map(lambda i: (i % 16,)),
pad_value=None,
axis_separators=[],
)
lv1 = R.call_tir(
Expected.relax_relu_replacement,
(lv,),
out_sinfo=R.Tensor((16,), dtype="float32"),
)
lv4 = R.call_tir(
Expected.relax_relu_replacement,
(lv1,),
out_sinfo=R.Tensor((16,), dtype="float32"),
)
lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv4,
index_map=T.index_map(lambda axis0: (axis0,)),
pad_value=None,
axis_separators=[],
)
gv = R.call_tir(
Expected.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32")
)
R.output(gv)
return gv

    _run_pass_compare_output(Before, Expected)
if __name__ == "__main__":
tvm.testing.main()