# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
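"""Tests for relax.transform.RewriteDataflowReshape.

The pass scans Relax dataflow blocks for R.call_tir bindings whose callee
PrimFunc matches the reshape pattern (relax.analysis.has_reshape_pattern)
and rewrites them to R.reshape, e.g.::

    mod = relax.transform.RewriteDataflowReshape()(mod)
"""
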
import tvm
import tvm.testing
from tvm import relax
from tvm.script import relax as R, tir as T, ir as I


def test_reshape_expand_dims():
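    """A call_tir to a reshape-pattern PrimFunc is rewritten to R.reshape.

    The expand_dims PrimFunc also matches the reshape pattern, but its
    result is the output var of the dataflow block, so its call_tir is
    kept as-is.
    """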
@tvm.script.ir_module
class Module:
@T.prim_func
def reshape(
rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"),
T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"),
):
for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(
rxplaceholder[
(v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
(v_ax0 * 12 + v_ax1 * 3 + v_ax2) % T.int64(3),
]
)
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[
(v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
(v_ax0 * 12 + v_ax1 * 3 + v_ax2) % T.int64(3),
]
@T.prim_func
def expand_dims(
rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"),
expand_dims: T.Buffer(
(T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), "float32"
),
):
for i0, i1, i2, i3, i4 in T.grid(
T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)
):
with T.block("expand_dims"):
i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(rxplaceholder[i0_1, i2_1, i4_1])
T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1])
expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1]
@R.function
def main(
x: R.Tensor((8, 3), dtype="float32")
) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"):
cls = Module
with R.dataflow():
y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32"))
z = R.call_tir(
cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), "float32")
)
R.output(z)
return z
@tvm.script.ir_module
class Expected:
@T.prim_func
def reshape(
rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"),
T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"),
):
for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(
rxplaceholder[
(v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // T.int64(3),
(v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) % T.int64(3),
]
)
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[
(v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // T.int64(3),
(v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) % T.int64(3),
]
@T.prim_func
def expand_dims(
rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"),
expand_dims: T.Buffer(
(T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), "float32"
),
):
for i0, i1, i2, i3, i4 in T.grid(
T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)
):
with T.block("expand_dims"):
i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(rxplaceholder[i0_1, i2_1, i4_1])
T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1])
expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1]
@R.function
def main(
x: R.Tensor((8, 3), dtype="float32")
) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"):
with R.dataflow():
cls = Expected
y: R.Tensor((2, 4, 3), "float32") = R.reshape(x, (2, 4, 3))
# Note: `z` is the output var of the dataflow block, and is thus
# not expected to be rewritten.
z = R.call_tir(
cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), dtype="float32")
)
R.output(z)
return z
assert relax.analysis.has_reshape_pattern(Module["expand_dims"])
mod = relax.transform.RewriteDataflowReshape()(Module)
tvm.ir.assert_structural_equal(mod, Expected)


def test_reshape_pattern_detect():
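    """The reshape pattern is detected in an already-scheduled PrimFunc.

    The reshape loop nest has been fused and bound to blockIdx.x and
    threadIdx.x, yet the call_tir is still rewritten to R.reshape.
    """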
# fmt: off
@tvm.script.ir_module
class Module:
@T.prim_func
def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), T.int64(64)), "float32")):
for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(10)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(2), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) // T.int64(1310720))
v_ax1 = T.axis.spatial(T.int64(4096), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(1310720) // T.int64(320))
v_ax2 = T.axis.spatial(T.int64(5), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(320) // T.int64(64))
v_ax3 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(64))
T.reads(rxplaceholder[(((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 * T.int64(64) + v_ax3) % T.int64(320)])
T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[(((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 * T.int64(64) + v_ax3) % T.int64(320)]
@T.prim_func
def expand_dims(
rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), T.int64(64)), "float32"),
expand_dims: T.Buffer(
(T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64)),
"float32",
),
):
for i0, i1, i2, i3, i4, i5 in T.grid(
T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64)
):
with T.block("expand_dims"):
i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5])
T.reads(rxplaceholder[i0_1, i2_1, i4_1, i5_1])
T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1])
expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = rxplaceholder[i0_1, i2_1, i4_1, i5_1]
@R.function
def main(
x: R.Tensor((2, 4096, 320), dtype="float32")
) -> R.Tensor((2, 1, 4096, 1, 5, 64), dtype="float32"):
cls = Module
with R.dataflow():
y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4096, 5, 64), dtype="float32"))
z = R.call_tir(
cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4096, 1, 5, 64), "float32")
)
R.output(z)
return z
@tvm.script.ir_module
class Expected:
@T.prim_func
def expand_dims(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), T.int64(64)), "float32"), expand_dims_1: T.Buffer((T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64)), "float32")):
# with T.block("root"):
for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64)):
with T.block("expand_dims"):
i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5])
T.reads(rxplaceholder[i0_1, i2_1, i4_1, i5_1])
T.writes(expand_dims_1[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1])
expand_dims_1[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = rxplaceholder[i0_1, i2_1, i4_1, i5_1]
@T.prim_func
def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), T.int64(64)), "float32")):
# with T.block("root"):
for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(10)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(2), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) // T.int64(1310720))
v_ax1 = T.axis.spatial(T.int64(4096), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(1310720) // T.int64(320))
v_ax2 = T.axis.spatial(T.int64(5), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(320) // T.int64(64))
v_ax3 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(64))
T.reads(rxplaceholder[(((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 * T.int64(64) + v_ax3) % T.int64(320)])
T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[(((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 * T.int64(64) + v_ax3) % T.int64(320)]
@R.function
def main(x: R.Tensor((2, 4096, 320), dtype="float32")) -> R.Tensor((2, 1, 4096, 1, 5, 64), dtype="float32"):
cls = Expected
with R.dataflow():
y: R.Tensor((2, 4096, 5, 64), dtype="float32") = R.reshape(x, R.shape([2, 4096, 5, 64]))
z = R.call_tir(cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4096, 1, 5, 64), dtype="float32"))
R.output(z)
return z
# fmt: on
assert relax.analysis.has_reshape_pattern(Module["reshape"])
mod = relax.transform.RewriteDataflowReshape()(Module)
tvm.ir.assert_structural_equal(mod, Expected)


def test_reshape_dynamic_shape():
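    """A reshape PrimFunc with a dynamic symbolic extent `n` is rewritten."""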
@tvm.script.ir_module
class Module:
@T.prim_func(private=True)
def reshape(var_A: T.handle, var_T_reshape: T.handle):
T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
n = T.int32()
A = T.match_buffer(var_A, (n, 16, 128), "float16")
T_reshape = T.match_buffer(var_T_reshape, (1, n, 16, 128), "float16")
# with T.block("root"):
for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 2, thread="blockIdx.x"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(
n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2048
)
v1 = T.axis.spatial(
16, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2048 // 128
)
v2 = T.axis.spatial(
128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128
)
T.reads(
A[((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128]
)
T.writes(T_reshape[0, v0, v1, v2])
T_reshape[0, v0, v1, v2] = A[
((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128
]
@R.function
def main(
x: R.Tensor((8, 16, 128), dtype="float16")
) -> R.Tensor((1, 8, 16, 128), dtype="float16"):
cls = Module
with R.dataflow():
y = R.call_tir(
cls.reshape, (x,), out_sinfo=R.Tensor((1, 8, 16, 128), dtype="float16")
)
z = R.add(y, R.const(1, "float16"))
R.output(z)
return z
@tvm.script.ir_module
class Expected:
@T.prim_func(private=True)
def reshape(var_A: T.handle, var_T_reshape: T.handle):
T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
n = T.int32()
A = T.match_buffer(var_A, (n, 16, 128), "float16")
T_reshape = T.match_buffer(var_T_reshape, (1, n, 16, 128), "float16")
# with T.block("root"):
for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 2, thread="blockIdx.x"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(
n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2048
)
v1 = T.axis.spatial(
16, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2048 // 128
)
v2 = T.axis.spatial(
128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128
)
T.reads(
A[((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128]
)
T.writes(T_reshape[0, v0, v1, v2])
T_reshape[0, v0, v1, v2] = A[
((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128
]
@R.function
def main(
x: R.Tensor((8, 16, 128), dtype="float16")
) -> R.Tensor((1, 8, 16, 128), dtype="float16"):
with R.dataflow():
y: R.Tensor((1, 8, 16, 128), dtype="float16") = R.reshape(
x, R.shape([1, 8, 16, 128])
)
z: R.Tensor((1, 8, 16, 128), dtype="float16") = R.add(y, R.const(1, "float16"))
R.output(z)
return z
assert relax.analysis.has_reshape_pattern(Module["reshape"])
mod = relax.transform.RewriteDataflowReshape()(Module)
tvm.ir.assert_structural_equal(mod, Expected)


def test_reshape_non_dataflow():
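    """A call_tir binding outside a dataflow block is left unchanged."""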
@tvm.script.ir_module
class Module:
@T.prim_func
def reshape(
rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"),
T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"),
):
for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(
rxplaceholder[
(v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
(v_ax0 * 12 + v_ax1 * 3 + v_ax2) % T.int64(3),
]
)
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[
(v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
(v_ax0 * 12 + v_ax1 * 3 + v_ax2) % T.int64(3),
]
@R.function
def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"):
cls = Module
y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32"))
return y
assert relax.analysis.has_reshape_pattern(Module["reshape"])
    # The binding var of the call_tir is not a DataflowVar, so the pass makes no change.
mod = relax.transform.RewriteDataflowReshape()(Module)
tvm.ir.assert_structural_equal(mod, Module)


def test_tuple_get_reshape():
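    """A reshape fed by a tuple element is rewritten to R.reshape.

    fused_reshape5 takes three arguments but reads only the first, so the
    call_tir is rewritten to an R.reshape of that argument alone.
    """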
@tvm.script.ir_module
class Module:
@T.prim_func
def fused_reshape5(
lv2_0: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
lv2_1: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
lv2_2: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
T_reshape_handle_intermediate: T.Buffer(
(T.int64(2), T.int64(4096), T.int64(8), T.int64(40)), "float16"
),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4096), T.int64(8), T.int64(40)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(
lv2_0[
(
((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1)
// T.int64(4096)
+ v_ax0
)
% T.int64(2),
((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096),
(v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
]
)
T.writes(T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv2_0[
(
((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096)
+ v_ax0
)
% T.int64(2),
((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096),
(v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
]
@R.function
def main(
lv41_1: R.Tuple(
R.Tensor((2, 4096, 320), dtype="float16"),
R.Tensor((2, 4096, 320), dtype="float16"),
R.Tensor((2, 4096, 320), dtype="float16"),
),
) -> R.Tensor((2, 4096, 8, 40), dtype="float16"):
cls = Module
with R.dataflow():
lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0]
lv1: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[1]
lv2: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[2]
lv645 = R.call_tir(
cls.fused_reshape5,
(lv, lv1, lv2),
out_sinfo=R.Tensor((2, 4096, 8, 40), dtype="float16"),
)
out: R.Tensor((2, 4096, 8, 40), dtype="float16") = R.add(lv645, lv645)
R.output(out)
return out
@tvm.script.ir_module
class Expected:
@T.prim_func
def fused_reshape5(
lv2_0: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
lv2_1: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
lv2_2: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float16"),
T_reshape_handle_intermediate: T.Buffer(
(T.int64(2), T.int64(4096), T.int64(8), T.int64(40)), "float16"
),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4096), T.int64(8), T.int64(40)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(
lv2_0[
(
((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1)
// T.int64(4096)
+ v_ax0
)
% T.int64(2),
((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096),
(v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
]
)
T.writes(T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
T_reshape_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv2_0[
(
((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) // T.int64(4096)
+ v_ax0
)
% T.int64(2),
((v_ax2 * T.int64(40) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096),
(v_ax2 * T.int64(40) + v_ax3) % T.int64(320),
]
@R.function
def main(
lv41_1: R.Tuple(
R.Tensor((2, 4096, 320), dtype="float16"),
R.Tensor((2, 4096, 320), dtype="float16"),
R.Tensor((2, 4096, 320), dtype="float16"),
),
) -> R.Tensor((2, 4096, 8, 40), dtype="float16"):
with R.dataflow():
lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0]
lv1: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[1]
lv2: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[2]
lv645: R.Tensor((2, 4096, 8, 40), dtype="float16") = R.reshape(
lv, R.shape([2, 4096, 8, 40])
)
out: R.Tensor((2, 4096, 8, 40), dtype="float16") = R.add(lv645, lv645)
R.output(out)
return out
rewritten = relax.transform.RewriteDataflowReshape()(Module)
tvm.ir.assert_structural_equal(rewritten, Expected)


def test_invalid_reshape():
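    """A pattern-matching PrimFunc that drops elements must not be rewritten."""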
@tvm.script.ir_module
class Module:
        # The strided_slice op matches the reshape pattern, but it may read only
        # part of its input. It must not be replaced with the reshape op, since
        # reshape is expected to preserve the number of elements ("volume") of
        # the input.
@T.prim_func
def strided_slice(
A: T.Buffer((T.int64(1), T.int64(1024)), "int32"),
T_strided_slice: T.Buffer((T.int64(1), T.int64(1000)), "int32"),
):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
with T.block("T_strided_slice"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1])
T.writes(T_strided_slice[v_ax0, v_ax1])
T_strided_slice[v_ax0, v_ax1] = A[v_ax0, v_ax1]
@T.prim_func
def add_one(
A: T.Buffer((T.int64(1), T.int64(1000)), "int32"),
            T_add_one: T.Buffer((T.int64(1), T.int64(1000)), "int32"),
):
for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
with T.block("T_add_one"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1])
T.writes(T_add_one[v_ax0, v_ax1])
T_add_one[v_ax0, v_ax1] = A[v_ax0, v_ax1] + 1
@R.function
def main(A: R.Tensor((1, 1024), dtype="int32")) -> R.Tensor((1, 1000), dtype="int32"):
with R.dataflow():
cls = Module
S = R.call_tir(
cls.strided_slice, (A,), out_sinfo=R.Tensor((1, 1000), dtype="int32")
)
A = R.call_tir(cls.add_one, (S,), out_sinfo=R.Tensor((1, 1000), dtype="int32"))
R.output(A)
return A
assert relax.analysis.has_reshape_pattern(Module["strided_slice"])
rewritten = relax.transform.RewriteDataflowReshape()(Module)
tvm.ir.assert_structural_equal(rewritten, Module)


def test_reshape_detect_nop():
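    """A dataflow block with no call_tir bindings passes through unchanged."""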
@tvm.script.ir_module
class Module:
@R.function
def main(x: R.Tensor((8, 8), dtype="float16")) -> R.Tensor((8, 8), dtype="float16"):
with R.dataflow():
gv = R.call_pure_packed(
"foo", x, x, sinfo_args=(R.Tensor((8, 8), dtype="float16"),)
)
out = R.call_pure_packed(
"foo", gv, gv, sinfo_args=(R.Tensor((8, 8), dtype="float16"),)
)
R.output(out)
return out
rewritten = relax.transform.RewriteDataflowReshape()(Module)
tvm.ir.assert_structural_equal(rewritten, Module)


def test_reshape_scalar():
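    """Reshaping a 0-d tensor to shape (1,) is detected after legalization."""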
@tvm.script.ir_module
class Module:
@R.function
def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"):
with R.dataflow():
lv1: R.Tensor((1,), dtype="float32") = R.reshape(x, [1])
lv2: R.Tensor((1,), dtype="float32") = R.add(lv1, lv1)
R.output(lv2)
return lv2
@tvm.script.ir_module
class Expected:
@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(1),), "float32"),
B: T.Buffer((T.int64(1),), "float32"),
T_add: T.Buffer((T.int64(1),), "float32"),
):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0 in range(T.int64(1)):
with T.block("T_add"):
v_ax0 = T.axis.spatial(T.int64(1), ax0)
T.reads(A[v_ax0], B[v_ax0])
T.writes(T_add[v_ax0])
T_add[v_ax0] = A[v_ax0] + B[v_ax0]
@T.prim_func(private=True)
def reshape(A: T.Buffer((), "float32"), T_reshape: T.Buffer((T.int64(1),), "float32")):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for ax0 in range(T.int64(1)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(1), ax0)
T.reads(A[()])
T.writes(T_reshape[v_ax0])
T_reshape[v_ax0] = A[()]
@R.function
def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"):
cls = Expected
with R.dataflow():
lv1: R.Tensor((1,), dtype="float32") = R.reshape(x, R.shape([1]))
lv2 = R.call_tir(cls.add, (lv1, lv1), out_sinfo=R.Tensor((1,), dtype="float32"))
R.output(lv2)
return lv2
mod = Module
mod = relax.transform.LegalizeOps()(mod)
rewritten = relax.transform.RewriteDataflowReshape()(mod)
tvm.ir.assert_structural_equal(rewritten, Expected)


def test_rewrite_static_reshape():
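    """LegalizeOps + RewriteDataflowReshape + DCE round-trip a static reshape.

    R.reshape is first lowered to a TIR PrimFunc, then raised back to
    R.reshape, after which the dead PrimFunc is removed.
    """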
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor([256], dtype="float32")):
with R.dataflow():
y = R.reshape(x, [64, 4])
z = R.add(y, y)
R.output(z)
return z
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor((256,), dtype="float32")):
cls = Expected
with R.dataflow():
y = R.reshape(x, R.shape([64, 4]))
z = R.call_tir(cls.add, (y, y), out_sinfo=R.Tensor((64, 4), dtype="float32"))
R.output(z)
return z
@T.prim_func(private=True)
def add(
y1: T.Buffer((T.int64(64), T.int64(4)), "float32"),
y2: T.Buffer((T.int64(64), T.int64(4)), "float32"),
z: T.Buffer((T.int64(64), T.int64(4)), "float32"),
):
T.func_attr({"tir.noalias": True})
for iters in T.grid(T.int64(64), T.int64(4)):
with T.block("T_add"):
i, j = T.axis.remap("SS", iters)
z[i, j] = y1[i, j] + y2[i, j]
After = tvm.ir.transform.Sequential(
[
# Lower both R.reshape and R.add from Relax to TIR
relax.transform.LegalizeOps(),
            # Identify reshape patterns and raise the call_tir to
            # cls.reshape from TIR back to R.reshape
relax.transform.RewriteDataflowReshape(),
# Clean up afterwards, removing the no-longer-required
# PrimFunc "reshape"
relax.transform.DeadCodeElimination(),
]
)(Before)
tvm.ir.assert_structural_equal(Expected, After)


def test_rewrite_dynamic_reshape():
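    """The round-trip also works with symbolic shapes, passing N via tir_vars."""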
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
N = T.int64()
with R.dataflow():
y = R.reshape(x, [N * 4, T.int64(4)])
z = R.add(y, y)
R.output(z)
return z
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
N = T.int64()
cls = Expected
with R.dataflow():
y = R.reshape(x, R.shape([N * 4, T.int64(4)]))
z = R.call_tir(
cls.add,
(y, y),
tir_vars=[N],
out_sinfo=R.Tensor((N * 4, 4), dtype="float32"),
)
R.output(z)
return z
@T.prim_func(private=True)
def add(
y1_handle: T.handle,
y2_handle: T.handle,
z_handle: T.handle,
N: T.int64,
):
y1 = T.match_buffer(y1_handle, [N * 4, T.int64(4)], "float32")
y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32")
z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32")
T.func_attr({"tir.noalias": True})
for iters in T.grid(N * 4, T.int64(4)):
with T.block("T_add"):
i, j = T.axis.remap("SS", iters)
z[i, j] = y1[i, j] + y2[i, j]
After = tvm.ir.transform.Sequential(
[
# Lower both R.reshape and R.add from Relax to TIR
relax.transform.LegalizeOps(),
            # Identify reshape patterns and raise the call_tir to
            # cls.reshape from TIR back to R.reshape
relax.transform.RewriteDataflowReshape(),
# Clean up afterwards, removing the no-longer-required
# PrimFunc "reshape"
relax.transform.DeadCodeElimination(),
]
)(Before)
tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
tvm.testing.main()