# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration test for MetaSchedule"""
import tempfile

import pytest

import tvm
import tvm.testing
from tvm import meta_schedule as ms
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
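

# Module0: a plain Relax graph with three grouped conv2d ops plus a bias add;
# all weights are inline Relax constants. It is lowered to TIR inside
# test_extracting_tasks below.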
# fmt: off
@I.ir_module
class Module0:
    @R.function
    def main(data: R.Tensor((1, 8, 8, 4), dtype="int32")) -> R.Tensor((1, 8, 8, 4), dtype="int32"):
        cls = Module0
        with R.dataflow():
            c = R.const([[[[-171701247],[-1719837685],[1801664104],[-634316588]],[[920159370],[-132073802],[2142531563],[1465185701]],[[-1505608067],[1737948828],[1581089391],[-1986167320]]],[[[-1449581822],[35714587],[496324563],[-1430879015]],[[-1615680873],[1198514997],[1494683955],[1567376558]],[[1319924884],[-380548171],[296785437],[-1546305981]]],[[[-398644701],[-2004794585],[-1850413687],[2072643657]],[[847950121],[-544212073],[-199532669],[-343273682]],[[953721562],[-1930209358],[1573600108],[-577689853]]]], "int32")
            lv: R.Tensor((1, 8, 8, 4), dtype="int32") = R.nn.conv2d(data, c, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=4, data_layout="NHWC", kernel_layout="HWOI", out_layout="NHWC", out_dtype="int32")
            b = R.const([[[[1, 1, 1, 1]]]], "int32")
            lv1: R.Tensor((1, 8, 8, 4), dtype="int32") = R.add(lv, b)
            c1 = R.const([[[[2042349344],[-2076067063],[1528163722],[-1156452837]],[[-2097172051],[1137787079],[-601389657],[1907495997]],[[987801941],[1073738593],[-1410339796],[-689755358]]],[[[90351522],[-44886952],[-1914103775],[-691553659]],[[-1288505112],[-1376578817],[-2067933148],[-1413101824]],[[1261422027],[-156976862],[-1185734459],[1608778622]]],[[[-664209483],[1907479806],[1838595152],[464942526]],[[877953160],[415131837],[-2010736511],[1218242769]],[[-1440127632],[112931],[521745784],[-1931145893]]]], "int32")
            lv2: R.Tensor((1, 8, 8, 4), dtype="int32") = R.nn.conv2d(lv1, c1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=4, data_layout="NHWC", kernel_layout="HWOI", out_layout="NHWC", out_dtype="int32")
            c2 = R.const([[[[687940110],[-910571705],[-901609800],[-500525928]],[[506872399],[1070176297],[-305936110],[1625439784]],[[-1565626954],[-1705688881],[-866370805],[-1750740826]]],[[[300497007],[-626864803],[390295545],[222549121]],[[319224543],[-2003064970],[657992492],[2014175448]],[[653278589],[-768810984],[-294555581],[-1197167662]]],[[[1703154671],[-1540759805],[-568817430],[-1729755444]],[[-275458074],[2078945571],[1683298006],[-1029327874]],[[1315093181],[159010501],[875694807],[-223655381]]]], "int32")
            lv3: R.Tensor((1, 8, 8, 4), dtype="int32") = R.nn.conv2d(lv2, c2, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=4, data_layout="NHWC", kernel_layout="HWOI", out_layout="NHWC", out_dtype="int32")
            gv: R.Tensor((1, 8, 8, 4), dtype="int32") = lv3
            R.output(gv)
        return gv
# fmt: on
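

# Module: a hand-written, already-lowered form of the same graph shape. The
# Relax `main` calls three TIR PrimFuncs whose weights are embedded with
# T.allocate_const; conv2d and conv2d0 differ only in their constant data.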
# fmt: off
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def conv2d(rxplaceholder: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), DepthwiseConv2d: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")):
        T.func_attr({"op_pattern": 4, "tir.noalias": True})
        # with T.block("root"):
        PaddedInput = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), T.int64(4)), "int32")
        fused_constant = T.allocate_const([-171701247, -1719837685, 1801664104, -634316588, 920159370, -132073802, 2142531563, 1465185701, -1505608067, 1737948828, 1581089391, -1986167320, -1449581822, 35714587, 496324563, -1430879015, -1615680873, 1198514997, 1494683955, 1567376558, 1319924884, -380548171, 296785437, -1546305981, -398644701, -2004794585, -1850413687, 2072643657, 847950121, -544212073, -199532669, -343273682, 953721562, -1930209358, 1573600108, -577689853], "int32", [3, 3, 4, 1])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), T.int64(4)):
            with T.block("PaddedInput"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(rxplaceholder[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3])
                T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3])
                PaddedInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 and v_i2 < T.int64(9), rxplaceholder[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], 0)
        for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4), T.int64(3), T.int64(3)):
            with T.block("DepthwiseConv2d"):
                v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, j, c, di, dj])
                fused_constant_1 = T.Buffer((3, 3, 4, 1), "int32", data=fused_constant)
                T.reads(PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c], fused_constant_1[v_di, v_dj, v_c, T.int64(0)])
                T.writes(DepthwiseConv2d[v_b, v_i, v_j, v_c])
                with T.init():
                    DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0
                DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b, v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant_1[v_di, v_dj, v_c, T.int64(0)]
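
    # conv2d0 is structurally identical to conv2d; only the values embedded
    # via T.allocate_const differ.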
    @T.prim_func(private=True)
    def conv2d0(rxplaceholder0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), DepthwiseConv2d0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")):
        T.func_attr({"op_pattern": 4, "tir.noalias": True})
        # with T.block("root"):
        PaddedInput0 = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), T.int64(4)), "int32")
        fused_constant0 = T.allocate_const([2042349344, -2076067063, 1528163722, -1156452837, -2097172051, 1137787079, -601389657, 1907495997, 987801941, 1073738593, -1410339796, -689755358, 90351522, -44886952, -1914103775, -691553659, -1288505112, -1376578817, -2067933148, -1413101824, 1261422027, -156976862, -1185734459, 1608778622, -664209483, 1907479806, 1838595152, 464942526, 877953160, 415131837, -2010736511, 1218242769, -1440127632, 112931, 521745784, -1931145893], "int32", [3, 3, 4, 1])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), T.int64(4)):
            with T.block("PaddedInput"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(rxplaceholder0[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3])
                T.writes(PaddedInput0[v_i0, v_i1, v_i2, v_i3])
                PaddedInput0[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 and v_i2 < T.int64(9), rxplaceholder0[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], 0)
        for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4), T.int64(3), T.int64(3)):
            with T.block("DepthwiseConv2d"):
                v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, j, c, di, dj])
                fused_constant0_1 = T.Buffer((3, 3, 4, 1), "int32", data=fused_constant0)
                T.reads(PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c], fused_constant0_1[v_di, v_dj, v_c, T.int64(0)])
                T.writes(DepthwiseConv2d0[v_b, v_i, v_j, v_c])
                with T.init():
                    DepthwiseConv2d0[v_b, v_i, v_j, v_c] = 0
                DepthwiseConv2d0[v_b, v_i, v_j, v_c] = DepthwiseConv2d0[v_b, v_i, v_j, v_c] + PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant0_1[v_di, v_dj, v_c, T.int64(0)]
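
    # fused_conv2d_add: the depthwise conv2d fused with the bias add; both the
    # kernel and the bias constants are embedded with T.allocate_const.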
    @T.prim_func(private=True)
    def fused_conv2d_add(data: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), T_add: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        PaddedInput = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), T.int64(4)), "int32")
        DepthwiseConv2d = T.alloc_buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")
        fused_nn_conv2d_constant = T.allocate_const([1, 1, 1, 1], "int32", [1, 1, 1, 4])
        fused_constant_2 = T.allocate_const([687940110, -910571705, -901609800, -500525928, 506872399, 1070176297, -305936110, 1625439784, -1565626954, -1705688881, -866370805, -1750740826, 300497007, -626864803, 390295545, 222549121, 319224543, -2003064970, 657992492, 2014175448, 653278589, -768810984, -294555581, -1197167662, 1703154671, -1540759805, -568817430, -1729755444, -275458074, 2078945571, 1683298006, -1029327874, 1315093181, 159010501, 875694807, -223655381], "int32", [3, 3, 4, 1])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), T.int64(4)):
            with T.block("PaddedInput"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3])
                T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3])
                PaddedInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 and v_i2 < T.int64(9), data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], 0)
        for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4), T.int64(3), T.int64(3)):
            with T.block("DepthwiseConv2d"):
                v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, j, c, di, dj])
                fused_constant_2_1 = T.Buffer((3, 3, 4, 1), "int32", data=fused_constant_2)
                T.reads(PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c], fused_constant_2_1[v_di, v_dj, v_c, T.int64(0)])
                T.writes(DepthwiseConv2d[v_b, v_i, v_j, v_c])
                with T.init():
                    DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0
                DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b, v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant_2_1[v_di, v_dj, v_c, T.int64(0)]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                fused_nn_conv2d_constant_1 = T.Buffer((1, 1, 1, 4), "int32", data=fused_nn_conv2d_constant)
                T.reads(DepthwiseConv2d[v_ax0, v_ax1, v_ax2, v_ax3], fused_nn_conv2d_constant_1[v_ax0, T.int64(0), T.int64(0), v_ax3])
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3] = DepthwiseConv2d[v_ax0, v_ax1, v_ax2, v_ax3] + fused_nn_conv2d_constant_1[v_ax0, T.int64(0), T.int64(0), v_ax3]

    @R.function
    def main(data: R.Tensor((1, 8, 8, 4), dtype="int32")) -> R.Tensor((1, 8, 8, 4), dtype="int32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.fused_conv2d_add, data, out_sinfo=R.Tensor((1, 8, 8, 4), dtype="int32"))
            lv2 = R.call_tir(cls.conv2d, lv, out_sinfo=R.Tensor((1, 8, 8, 4), dtype="int32"))
            lv3 = R.call_tir(cls.conv2d0, lv2, out_sinfo=R.Tensor((1, 8, 8, 4), dtype="int32"))
            gv: R.Tensor((1, 8, 8, 4), dtype="int32") = lv3
            R.output(gv)
        return gv
# fmt: on
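

# Task extraction should deduplicate the convolution PrimFuncs according to the
# chosen module-equality mode: "structural" compares the full module structure
# (including embedded constants), "ignore-tensor" ignores the values of
# constant tensors, and "anchor-block" compares only the anchor compute block
# of each PrimFunc.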
def test_extracting_tasks():
    target = "llvm -mcpu=core-avx2 -num-cores=1"
    relax_mod = Module0
    relax_mod = relax.transform.LegalizeOps()(relax_mod)
    relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod)
    relax_mod = relax.transform.FuseOps()(relax_mod)
    relax_mod = relax.transform.FoldConstant()(relax_mod)
    relax_mod = relax.transform.FuseTIR()(relax_mod)
    relax_expectation = {
        # The Relax constants do not reach the TIR during lowering; they stay
        # as Relax-level arguments, so the two standalone depthwise conv2d
        # PrimFuncs are structurally identical and deduplicate into one task.
        "structural": 2,
        "ignore-tensor": 2,
        "anchor-block": 1,
    }
    for module_equality, count in relax_expectation.items():
        extracted_tasks = ms.relax_integration.extract_tasks(
            relax_mod,
            target,
            {},
            module_equality=module_equality,
        )
        assert len(extracted_tasks) == count
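
    # In Module the weights are embedded in the PrimFuncs via T.allocate_const,
    # so "structural" sees three distinct tasks, while "ignore-tensor" ignores
    # the constant values and merges the two standalone depthwise convolutions.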
    tir_relax_mod = Module
    tir_relax_expectation = {"structural": 3, "ignore-tensor": 2, "anchor-block": 1}
    for module_equality, count in tir_relax_expectation.items():
        extracted_tasks = ms.relax_integration.extract_tasks(
            tir_relax_mod,
            target,
            {},
            module_equality=module_equality,
        )
        assert len(extracted_tasks) == count
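

# End-to-end check: tuning must succeed under each module-equality mode, and
# the resulting database must be usable to compile the tuned module.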
@pytest.mark.parametrize("module_equality", ["structural", "ignore-tensor", "anchor-block"])
def test_using_anchor_trace(module_equality):
    relax_mod = Module
    target = "llvm -mcpu=core-avx2 -num-cores=1"
    with tempfile.TemporaryDirectory() as work_dir:
        database = ms.relax_integration.tune_relax(
            mod=relax_mod,
            params={},
            target=target,
            work_dir=work_dir,
            # for faster tuning
            max_trials_global=100,
            max_trials_per_task=4,
            num_trials_per_iter=4,
            strategy="replay-trace",
            module_equality=module_equality,
            seed=0,
        )
        ms.relax_integration.compile_relax(
            database,
            mod=relax_mod,
            target=target,
            params={},
        )


if __name__ == "__main__":
    tvm.testing.main()