blob: 5d45db52417769e1b4b82dcf724fa0c971a7e4f5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm.testing
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
def test_single_buffer():
    """Split a single layout-rewrite preprocessing block out of a PrimFunc.

    ``Before.tir_func`` contains one block tagged with
    ``meta_schedule.layout_rewrite_preproc`` that repacks ``W`` into
    ``W_rewrite``. After ``SplitLayoutRewritePreproc`` the expected module has
    two PrimFuncs:

    * ``tir_func_weight_prepack`` -- only the repacking loop, with ``W_rewrite``
      promoted from a local allocation to an output parameter;
    * ``tir_func_prepacked`` -- only the compute loop, taking ``W_rewrite`` as
      an input parameter.

    ``forward`` is rewritten to call the prepack function on ``w`` and feed its
    result to the compute function.
    """

    @I.ir_module
    class Before:
        @T.prim_func(private=True)
        def tir_func(
            X: T.Buffer((224, 224), "float32"),
            W: T.Buffer((224, 224), "float32"),
            Out: T.Buffer((224, 224), "float32"),
        ):
            T.func_attr({"layout_free_buffers": [1]})
            W_rewrite = T.alloc_buffer((4, 4, 56, 56))
            # Tagged preprocessing block: expected to move into its own PrimFunc.
            for i, j in T.grid(224, 224):
                with T.block("W_rewrite"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
            # Compute block: expected to stay behind, reading the packed buffer.
            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
                with T.block("Out"):
                    vi = T.axis.spatial(224, i0 * 56 + i1)
                    vj = T.axis.spatial(224, j0 * 56 + j1)
                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]

        @R.function
        def forward(
            x: R.Tensor((224, 224), dtype="float32"),
            w: R.Tensor((224, 224), dtype="float32"),
        ) -> R.Tensor((224, 224), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = Before
            with R.dataflow():
                gv = R.call_tir(
                    cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32")
                )
                R.output(gv)
            return gv

    @I.ir_module
    class After:
        @T.prim_func(private=True)
        def tir_func_prepacked(
            X: T.Buffer((224, 224), "float32"),
            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
            Out: T.Buffer((224, 224), "float32"),
        ):
            # The "layout_free_buffers" attribute is not carried over.
            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
                with T.block("Out"):
                    vi = T.axis.spatial(224, i0 * 56 + i1)
                    vj = T.axis.spatial(224, j0 * 56 + j1)
                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]

        @T.prim_func(private=True)
        def tir_func_weight_prepack(
            W: T.Buffer((224, 224), "float32"),
            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
        ):
            # The preproc block annotation is dropped once the block is isolated.
            for i, j in T.grid(224, 224):
                with T.block("W_rewrite"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]

        @R.function
        def forward(
            x: R.Tensor((224, 224), dtype="float32"),
            w: R.Tensor((224, 224), dtype="float32"),
        ) -> R.Tensor((224, 224), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = After
            with R.dataflow():
                # Pack the weight first, then run the compute on the packed copy.
                lv = R.call_tir(
                    cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32")
                )
                lv1 = R.call_tir(
                    cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32")
                )
                gv: R.Tensor((224, 224), dtype="float32") = lv1
                R.output(gv)
            return gv

    split_pass = relax.transform.SplitLayoutRewritePreproc()
    actual = split_pass(Before)
    tvm.ir.assert_structural_equal(actual, After)
def test_multiple_buffers():
    """Split two layout-rewrite preprocessing blocks into one prepack PrimFunc.

    ``Before.tir_func`` repacks both ``W1`` and ``W2`` in separate blocks
    tagged with ``meta_schedule.layout_rewrite_preproc``. The expected result
    gathers both repacking loops into a single ``tir_func_weight_prepack``
    with two output buffers. Correspondingly, ``forward`` calls it with a
    two-element ``out_sinfo`` list and passes ``lv0[0]`` / ``lv0[1]`` to
    ``tir_func_prepacked``.

    The fixture marks both weight parameters (indices 1 and 2) as
    ``layout_free_buffers``, mirroring the single-buffer and attr-inheritance
    fixtures. The expected functions carry no attributes, since the pass
    drops ``layout_free_buffers`` from the functions it generates.
    """

    @I.ir_module
    class Before:
        @T.prim_func(private=True)
        def tir_func(
            X: T.Buffer((224, 224), "float32"),
            W1: T.Buffer((224, 224), "float32"),
            W2: T.Buffer((224, 224), "float32"),
            Out: T.Buffer((224, 224), "float32"),
        ):
            # Both rewritten weights are layout-free, consistent with the
            # other fixtures in this file that annotate their rewritten weight.
            T.func_attr({"layout_free_buffers": [1, 2]})
            W1_rewrite = T.alloc_buffer((4, 4, 56, 56))
            W2_rewrite = T.alloc_buffer((4, 4, 56, 56))
            for i, j in T.grid(224, 224):
                with T.block("W1_rewrite"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
                    W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj]
            for i, j in T.grid(224, 224):
                with T.block("W2_rewrite"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
                    W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj]
            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
                with T.block("Out"):
                    vi = T.axis.spatial(224, i0 * 56 + i1)
                    vj = T.axis.spatial(224, j0 * 56 + j1)
                    Out[vi, vj] = (
                        X[vi, vj]
                        + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
                        + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
                    )

        @R.function
        def forward(
            x: R.Tensor((224, 224), dtype="float32"),
            w1: R.Tensor((224, 224), dtype="float32"),
            w2: R.Tensor((224, 224), dtype="float32"),
        ) -> R.Tensor((224, 224), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = Before
            with R.dataflow():
                gv = R.call_tir(
                    cls.tir_func, (x, w1, w2), out_sinfo=R.Tensor((224, 224), dtype="float32")
                )
                R.output(gv)
            return gv

    @I.ir_module
    class After:
        @T.prim_func(private=True)
        def tir_func_prepacked(
            X: T.Buffer((224, 224), "float32"),
            W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
            W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
            Out: T.Buffer((224, 224), "float32"),
        ):
            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
                with T.block("Out"):
                    vi = T.axis.spatial(224, i0 * 56 + i1)
                    vj = T.axis.spatial(224, j0 * 56 + j1)
                    Out[vi, vj] = (
                        X[vi, vj]
                        + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
                        + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
                    )

        @T.prim_func(private=True)
        def tir_func_weight_prepack(
            W1: T.Buffer((224, 224), "float32"),
            W2: T.Buffer((224, 224), "float32"),
            W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
            W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
        ):
            # Both preprocessing loops land in the same prepack function,
            # with inputs listed before outputs.
            for i, j in T.grid(224, 224):
                with T.block("W1_rewrite"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj]
            for i, j in T.grid(224, 224):
                with T.block("W2_rewrite"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj]

        @R.function
        def forward(
            x: R.Tensor((224, 224), dtype="float32"),
            w1: R.Tensor((224, 224), dtype="float32"),
            w2: R.Tensor((224, 224), dtype="float32"),
        ) -> R.Tensor((224, 224), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = After
            with R.dataflow():
                # A single prepack call produces both packed weights as a tuple.
                lv0 = R.call_tir(
                    cls.tir_func_weight_prepack,
                    (w1, w2),
                    out_sinfo=[
                        R.Tensor((4, 4, 56, 56), "float32"),
                        R.Tensor((4, 4, 56, 56), "float32"),
                    ],
                )
                lv1 = R.call_tir(
                    cls.tir_func_prepacked,
                    (x, lv0[0], lv0[1]),
                    out_sinfo=R.Tensor((224, 224), "float32"),
                )
                gv: R.Tensor((224, 224), dtype="float32") = lv1
                R.output(gv)
            return gv

    mod = relax.transform.SplitLayoutRewritePreproc()(Before)
    tvm.ir.assert_structural_equal(mod, After)
def test_attr_inheritance():
    """Check which function attributes survive the split.

    ``Before.tir_func`` carries both ``layout_free_buffers`` and
    ``tir.noalias``. In the expected module, both generated PrimFuncs
    (``tir_func_prepacked`` and ``tir_func_weight_prepack``) keep
    ``tir.noalias`` while ``layout_free_buffers`` is absent. The Relax
    ``forward`` rewrite is the same as in the single-buffer case.
    """

    @I.ir_module
    class Before:
        @T.prim_func(private=True)
        def tir_func(
            X: T.Buffer((224, 224), "float32"),
            W: T.Buffer((224, 224), "float32"),
            Out: T.Buffer((224, 224), "float32"),
        ):
            # "tir.noalias" is the attribute expected to be inherited.
            T.func_attr({"layout_free_buffers": [1], "tir.noalias": True})
            W_rewrite = T.alloc_buffer((4, 4, 56, 56))
            for i, j in T.grid(224, 224):
                with T.block("W_rewrite"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
                with T.block("Out"):
                    vi = T.axis.spatial(224, i0 * 56 + i1)
                    vj = T.axis.spatial(224, j0 * 56 + j1)
                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]

        @R.function
        def forward(
            x: R.Tensor((224, 224), dtype="float32"),
            w: R.Tensor((224, 224), dtype="float32"),
        ) -> R.Tensor((224, 224), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = Before
            with R.dataflow():
                gv = R.call_tir(
                    cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32")
                )
                R.output(gv)
            return gv

    @I.ir_module
    class After:
        @T.prim_func(private=True)
        def tir_func_prepacked(
            X: T.Buffer((224, 224), "float32"),
            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
            Out: T.Buffer((224, 224), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
                with T.block("Out"):
                    vi = T.axis.spatial(224, i0 * 56 + i1)
                    vj = T.axis.spatial(224, j0 * 56 + j1)
                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]

        @T.prim_func(private=True)
        def tir_func_weight_prepack(
            W: T.Buffer((224, 224), "float32"),
            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            for i, j in T.grid(224, 224):
                with T.block("W_rewrite"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]

        @R.function
        def forward(
            x: R.Tensor((224, 224), dtype="float32"),
            w: R.Tensor((224, 224), dtype="float32"),
        ) -> R.Tensor((224, 224), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = After
            with R.dataflow():
                lv = R.call_tir(
                    cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32")
                )
                lv1 = R.call_tir(
                    cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32")
                )
                gv: R.Tensor((224, 224), dtype="float32") = lv1
                R.output(gv)
            return gv

    split_pass = relax.transform.SplitLayoutRewritePreproc()
    actual = split_pass(Before)
    tvm.ir.assert_structural_equal(actual, After)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    tvm.testing.main()