# blob: ff03ab4152c975dcf2c48d264f55cd17684845ba (git web-view residue; kept as a comment)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
import tvm.testing
from tvm.relax.transform import LegalizeOps
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
##################### Neural network #####################
def test_conv1d():
    """Legalize a grouped, strided, padded, dilated R.nn.conv1d into a TIR group conv1d.

    `Conv1d` is the input module; `Expected` pins the exact TIR that
    LegalizeOps must emit, checked with structural equality below.
    """
    # fmt: off
    @tvm.script.ir_module
    class Conv1d:
        @R.function
        def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((64, 16, 3), "float32")) -> R.Tensor((2, 64, 13), "float32"):
            gv: R.Tensor((2, 64, 13), "float32") = R.nn.conv1d(x, w, strides=(2,), padding=(1,), dilation=(2,), groups=8)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dtype="float32")) -> R.Tensor((2, 64, 13), dtype="float32"):
            gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 64, 13), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def conv1d(A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), B: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), group_conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")):
            T.func_attr({"tir.noalias": True})
            # Zero-pad the width axis by 1 on each side (28 -> 30).
            pad_temp = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(30)))
            for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(30)):
                with T.block("pad_temp"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(A[v_i0, v_i1, v_i2 - T.int64(1)])
                    T.writes(pad_temp[v_i0, v_i1, v_i2])
                    pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), A[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0))
            # Grouped NCW conv: stride 2 and dilation 2 show up as the *2 factors on yy/ry.
            for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(16), T.int64(3)):
                with T.block("group_conv1d_ncw"):
                    v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry])
                    T.reads(pad_temp[v_nn, v_ff // T.int64(8) * T.int64(16) + v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], B[v_ff, v_rc, v_ry])
                    T.writes(group_conv1d_ncw[v_nn, v_ff, v_yy])
                    with T.init():
                        group_conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
                    group_conv1d_ncw[v_nn, v_ff, v_yy] = group_conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_ff // T.int64(8) * T.int64(16) + v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * B[v_ff, v_rc, v_ry]
    # fmt: on

    mod = LegalizeOps()(Conv1d)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv1d_with_out_dtype():
    """Legalize R.nn.conv1d with out_dtype="float16": inputs are cast before the MAC."""
    # fmt: off
    @tvm.script.ir_module
    class Conv1d:
        @R.function
        def main(x: R.Tensor((2, 3, 28), "float32"), w: R.Tensor((4, 3, 3), "float32")) -> R.Tensor((2, 4, 26), "float16"):
            gv: R.Tensor((2, 4, 26), "float16") = R.nn.conv1d(x, w, out_dtype="float16")
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3, 28), dtype="float32"), w: R.Tensor((4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 26), dtype="float16"):
            gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 4, 26), dtype="float16"))
            return gv

        @T.prim_func(private=True)
        def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(4), T.int64(26)), "float16")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            # No padding requested, so pad_temp is a plain copy of the input.
            pad_temp = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28)))
            for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(28)):
                with T.block("pad_temp"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(rxplaceholder[v_i0, v_i1, v_i2])
                    T.writes(pad_temp[v_i0, v_i1, v_i2])
                    pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2]
            for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(3), T.int64(3)):
                with T.block("conv1d_ncw"):
                    v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry])
                    T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry], rxplaceholder_1[v_ff, v_rc, v_ry])
                    T.writes(conv1d_ncw[v_nn, v_ff, v_yy])
                    with T.init():
                        conv1d_ncw[v_nn, v_ff, v_yy] = T.float16(0)
                    # Both operands are cast to float16 so the accumulation happens in float16.
                    conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + T.Cast("float16", pad_temp[v_nn, v_rc, v_yy + v_ry]) * T.Cast("float16", rxplaceholder_1[v_ff, v_rc, v_ry])
    # fmt: on

    mod = LegalizeOps()(Conv1d)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv1d_nwc():
    """Legalize R.nn.conv1d with data_layout="NWC" into a conv1d_nwc TIR kernel."""
    # fmt: off
    @tvm.script.ir_module
    class Conv1d:
        @R.function
        def main(x: R.Tensor((2, 28, 128), "float32"), w: R.Tensor((64, 128, 3), "float32")) -> R.Tensor((2, 26, 64), "float32"):
            gv: R.Tensor((2, 26, 64), "float32") = R.nn.conv1d(x, w, data_layout="NWC")
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 28, 128), dtype="float32"), w: R.Tensor((64, 128, 3), dtype="float32")) -> R.Tensor((2, 26, 64), dtype="float32"):
            gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 26, 64), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3)), "float32"), conv1d_nwc: T.Buffer((T.int64(2), T.int64(26), T.int64(64)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            # No padding requested, so pad_temp is a plain copy of the input.
            pad_temp = T.alloc_buffer((T.int64(2), T.int64(28), T.int64(128)))
            for i0, i1, i2 in T.grid(T.int64(2), T.int64(28), T.int64(128)):
                with T.block("pad_temp"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(rxplaceholder[v_i0, v_i1, v_i2])
                    T.writes(pad_temp[v_i0, v_i1, v_i2])
                    pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2]
            # Loop order follows the NWC layout: batch, spatial, out-channel, then reductions.
            for nn, yy, ff, ry, rc in T.grid(T.int64(2), T.int64(26), T.int64(64), T.int64(3), T.int64(128)):
                with T.block("conv1d_nwc"):
                    v_nn, v_yy, v_ff, v_ry, v_rc = T.axis.remap("SSSRR", [nn, yy, ff, ry, rc])
                    T.reads(pad_temp[v_nn, v_yy + v_ry, v_rc], rxplaceholder_1[v_ff, v_rc, v_ry])
                    T.writes(conv1d_nwc[v_nn, v_yy, v_ff])
                    with T.init():
                        conv1d_nwc[v_nn, v_yy, v_ff] = T.float32(0)
                    conv1d_nwc[v_nn, v_yy, v_ff] = conv1d_nwc[v_nn, v_yy, v_ff] + pad_temp[v_nn, v_yy + v_ry, v_rc] * rxplaceholder_1[v_ff, v_rc, v_ry]
    # fmt: on

    mod = LegalizeOps()(Conv1d)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv1d_symbolic():
    """Legalize R.nn.conv1d with fully symbolic shapes; the TIR uses match_buffer on T.handle."""
    # fmt: off
    @tvm.script.ir_module
    class Conv1d:
        @R.function
        def main(x: R.Tensor(("n", "c", "w"), "float32"), kernel: R.Tensor(("f", "c", "kw"), "float32")) -> R.Tensor(("n", "f", "w - kw + 1"), "float32"):
            n = T.int64()
            w = T.int64()
            f = T.int64()
            kw = T.int64()
            gv: R.Tensor((n, f, w - kw + 1), "float32") = R.nn.conv1d(x, kernel)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("n", "c", "w"), dtype="float32"), kernel: R.Tensor(("f", "c", "kw"), dtype="float32")) -> R.Tensor(("n", "f", "w - kw + 1"), dtype="float32"):
            n = T.int64()
            f = T.int64()
            w = T.int64()
            kw = T.int64()
            c = T.int64()
            gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w + 1 - kw), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1d_ncw: T.handle):
            T.func_attr({"tir.noalias": True})
            # Symbolic dims are bound from the runtime buffer shapes via match_buffer.
            n, c, w = T.int64(), T.int64(), T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, w))
            f, kw = T.int64(), T.int64()
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kw))
            conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w + T.int64(1) - kw))
            # with T.block("root"):
            pad_temp = T.alloc_buffer((n, c, w))
            for i0, i1, i2 in T.grid(n, c, w):
                with T.block("pad_temp"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(rxplaceholder[v_i0, v_i1, v_i2])
                    T.writes(pad_temp[v_i0, v_i1, v_i2])
                    pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2]
            for nn, ff, yy, rc, ry in T.grid(n, f, w + T.int64(1) - kw, c, kw):
                with T.block("conv1d_ncw"):
                    v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry])
                    T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry], rxplaceholder_1[v_ff, v_rc, v_ry])
                    T.writes(conv1d_ncw[v_nn, v_ff, v_yy])
                    with T.init():
                        conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
                    conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_rc, v_yy + v_ry] * rxplaceholder_1[v_ff, v_rc, v_ry]
    # fmt: on

    mod = LegalizeOps()(Conv1d)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv1d_transpose():
    """Legalize grouped R.nn.conv1d_transpose: dilate data, pad, flip the kernel, then convolve."""
    # fmt: off
    @I.ir_module
    class Conv1dTranspose:
        @R.function
        def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((128, 16, 3), "float32")):
            gv = R.nn.conv1d_transpose(x, w, strides=2, padding=1, dilation=1, output_padding=1, groups=8)
            return gv

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def conv1d_transpose(x: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), w: T.Buffer((T.int64(128), T.int64(16), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56)), "float32")):
            T.func_attr({"tir.noalias": True})
            data_dilate = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(55)))
            data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58)))
            kernel = T.alloc_buffer((T.int64(16), T.int64(128), T.int64(3)))
            # Stride 2 becomes input dilation: insert zeros between input elements.
            for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(55)):
                with T.block("data_dilate"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    data_dilate[v_i0, v_i1, v_i2] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0), x[v_i0, v_i1, v_i2 // T.int64(2)], T.float32(0.0))
            for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(58)):
                with T.block("data_pad"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    data_pad[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(56), data_dilate[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0.0))
            # Transpose conv uses the kernel with in/out channels swapped and width reversed.
            for o, i, w_1 in T.grid(T.int64(16), T.int64(128), T.int64(3)):
                with T.block("kernel"):
                    v_o, v_i, v_w = T.axis.remap("SSS", [o, i, w_1])
                    kernel[v_o, v_i, v_w] = w[v_i, v_o, T.int64(2) - v_w]
            for b, c, w_1, dc, dw in T.grid(T.int64(2), T.int64(128), T.int64(56), T.int64(16), T.int64(3)):
                with T.block("compute"):
                    v_b, v_c, v_w, v_dc, v_dw = T.axis.remap("SSSRR", [b, c, w_1, dc, dw])
                    with T.init():
                        compute[v_b, v_c, v_w] = T.float32(0.0)
                    compute[v_b, v_c, v_w] = compute[v_b, v_c, v_w] + data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_w + v_dw] * kernel[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dw]

        @R.function
        def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((128, 16, 3), dtype="float32")) -> R.Tensor((2, 128, 56), dtype="float32"):
            cls = Expected
            gv = R.call_tir(cls.conv1d_transpose, (x, w), out_sinfo=R.Tensor((2, 128, 56), dtype="float32"))
            return gv
    # fmt: on

    mod = LegalizeOps()(Conv1dTranspose)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv2d():
    """Legalize a grouped, strided, padded, dilated R.nn.conv2d into a TIR group conv2d."""
    # fmt: off
    @tvm.script.ir_module
    class Conv2d:
        @R.function
        def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64, 16, 3, 3), "float32")) -> R.Tensor((2, 64, 13, 13), "float32"):
            gv: R.Tensor((2, 64, 13, 13), "float32") = R.nn.conv2d(x, w, strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64, 16, 3, 3), "float32")) -> R.Tensor((2, 64, 13, 13), "float32"):
            gv = R.call_tir(Expected.conv2d, (x, w), R.Tensor((2, 64, 13, 13), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3), T.int64(3)), "float32"), group_conv2d_nchw: T.Buffer((T.int64(2), T.int64(64), T.int64(13), T.int64(13)), "float32")):
            T.func_attr({"tir.noalias": True})
            # Zero-pad H and W by 1 on each side (28 -> 30).
            pad_temp = T.alloc_buffer([T.int64(2), T.int64(128), T.int64(30), T.int64(30)], dtype="float32")
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(30), T.int64(30)):
                with T.block("pad_temp"):
                    i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1)])
                    T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1])
                    pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(T.int64(1) <= i2_1 and i2_1 < T.int64(29) and T.int64(1) <= i3_1 and i3_1 < T.int64(29), rxplaceholder[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1)], T.float32(0), dtype="float32")
            # Grouped NCHW conv: stride 2 and dilation 2 show up as the *2 factors on yy/xx/ry/rx.
            for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(13), T.int64(16), T.int64(3), T.int64(3)):
                with T.block("group_conv2d_nchw"):
                    nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                    T.reads(pad_temp[nn, ff // T.int64(8) * T.int64(16) + rc, yy * T.int64(2) + ry * T.int64(2), xx * T.int64(2) + rx * T.int64(2)], rxplaceholder_1[ff, rc, ry, rx])
                    T.writes(group_conv2d_nchw[nn, ff, yy, xx])
                    with T.init():
                        group_conv2d_nchw[nn, ff, yy, xx] = T.float32(0)
                    group_conv2d_nchw[nn, ff, yy, xx] = group_conv2d_nchw[nn, ff, yy, xx] + pad_temp[nn, ff // T.int64(8) * T.int64(16) + rc, yy * T.int64(2) + ry * T.int64(2), xx * T.int64(2) + rx * T.int64(2)] * rxplaceholder_1[ff, rc, ry, rx]
    # fmt: on

    mod = LegalizeOps()(Conv2d)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv2d_with_out_dtype():
    """Legalize R.nn.conv2d with out_dtype="float16": operands are cast before accumulating."""
    # fmt: off
    @tvm.script.ir_module
    class Conv2d:
        @R.function
        def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")) -> R.Tensor((2, 4, 26, 26), "float16"):
            gv: R.Tensor((2, 4, 26, 26), "float16") = R.nn.conv2d(x, w, out_dtype="float16")
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")) -> R.Tensor((2, 4, 26, 26), "float16"):
            gv = R.call_tir(Expected.conv2d, (x, w), R.Tensor((2, 4, 26, 26), dtype="float16"))
            return gv

        @T.prim_func(private=True)
        def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float16")):
            T.func_attr({"tir.noalias": True})
            # No padding requested, so pad_temp is a plain copy of the input.
            pad_temp = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32")
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)):
                with T.block("pad_temp"):
                    i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1])
                    T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1])
                    pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1]
            for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26), T.int64(3), T.int64(3), T.int64(3)):
                with T.block("conv2d_nchw"):
                    nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                    T.reads(pad_temp[nn, rc, yy + ry, xx + rx], rxplaceholder_1[ff, rc, ry, rx])
                    T.writes(conv2d_nchw[nn, ff, yy, xx])
                    with T.init():
                        conv2d_nchw[nn, ff, yy, xx] = T.float16(0)
                    # Both operands are cast to float16 so the accumulation happens in float16.
                    conv2d_nchw[nn, ff, yy, xx] = conv2d_nchw[nn, ff, yy, xx] + T.Cast("float16", pad_temp[nn, rc, yy + ry, xx + rx]) * T.Cast("float16", rxplaceholder_1[ff, rc, ry, rx])
    # fmt: on

    mod = LegalizeOps()(Conv2d)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv2d_nhwc():
    """Legalize R.nn.conv2d with data_layout="NHWC" into a conv2d_nhwc TIR kernel."""
    # fmt: off
    @tvm.script.ir_module
    class Conv2d:
        @R.function
        def main(x: R.Tensor((2, 28, 28, 128), "float32"), w: R.Tensor((64, 128, 3, 3), "float32")) -> R.Tensor((2, 26, 26, 64), "float32"):
            gv: R.Tensor((2, 26, 26, 64), "float32") = R.nn.conv2d(x, w, data_layout="NHWC")
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 28, 28, 128), "float32"), w: R.Tensor((64, 128, 3, 3), "float32")) -> R.Tensor((2, 26, 26, 64), "float32"):
            gv = R.call_tir(Expected.conv2d, (x, w), R.Tensor((2, 26, 26, 64), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nhwc: T.Buffer((T.int64(2), T.int64(26), T.int64(26), T.int64(64)), "float32")):
            T.func_attr({"tir.noalias": True})
            # No padding requested, so pad_temp is a plain copy of the input.
            pad_temp = T.alloc_buffer([T.int64(2), T.int64(28), T.int64(28), T.int64(128)], dtype="float32")
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(28), T.int64(28), T.int64(128)):
                with T.block("pad_temp"):
                    i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1])
                    T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1])
                    pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1]
            # Loop order follows the NHWC layout: batch, spatial, out-channel, then reductions.
            for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(26), T.int64(26), T.int64(64), T.int64(3), T.int64(3), T.int64(128)):
                with T.block("conv2d_nhwc"):
                    nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                    T.reads(pad_temp[nn, yy + ry, xx + rx, rc], rxplaceholder_1[ff, rc, ry, rx])
                    T.writes(conv2d_nhwc[nn, yy, xx, ff])
                    with T.init():
                        conv2d_nhwc[nn, yy, xx, ff] = T.float32(0)
                    conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + pad_temp[nn, yy + ry, xx + rx, rc] * rxplaceholder_1[ff, rc, ry, rx]
    # fmt: on

    mod = LegalizeOps()(Conv2d)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv2d_symbolic():
    """Legalize R.nn.conv2d with fully symbolic shapes; the TIR uses match_buffer on T.handle."""
    # fmt: off
    @tvm.script.ir_module
    class Conv2d:
        @R.function
        def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")) -> R.Tensor(("n", "f", "h - kh + 1", "w - kw + 1"), "float32"):
            n = T.int64()
            h = T.int64()
            w = T.int64()
            f = T.int64()
            kh = T.int64()
            kw = T.int64()
            gv: R.Tensor((n, f, h - kh + 1, w - kw + 1), "float32") = R.nn.conv2d(x, kernel)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")) -> R.Tensor(("n", "f", "h - kh + 1", "w - kw + 1"), "float32"):
            n = T.int64()
            f = T.int64()
            h = T.int64()
            kh = T.int64()
            w = T.int64()
            kw = T.int64()
            gv = R.call_tir(Expected.conv2d, (x, kernel), R.Tensor((n, f, h + 1 - kh, w + 1 - kw), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def conv2d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv2d_nchw: T.handle):
            T.func_attr({"tir.noalias": True})
            # Symbolic dims are bound from the runtime buffer shapes via match_buffer.
            c = T.int64()
            f = T.int64()
            h = T.int64()
            kh = T.int64()
            kw = T.int64()
            n = T.int64()
            w = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w], dtype="float32")
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [f, c, kh, kw], dtype="float32")
            conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h + T.int64(1) - kh, w + T.int64(1) - kw], dtype="float32")
            pad_temp = T.alloc_buffer([n, c, h, w], dtype="float32")
            for i0, i1, i2, i3 in T.grid(n, c, h, w):
                with T.block("pad_temp"):
                    i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1])
                    T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1])
                    pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1]
            for i0, i1, i2, i3, i4, i5, i6 in T.grid(n, f, h + T.int64(1) - kh, w + T.int64(1) - kw, c, kh, kw):
                with T.block("conv2d_nchw"):
                    nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                    T.reads(pad_temp[nn, rc, yy + ry, xx + rx], rxplaceholder_1[ff, rc, ry, rx])
                    T.writes(conv2d_nchw[nn, ff, yy, xx])
                    with T.init():
                        conv2d_nchw[nn, ff, yy, xx] = T.float32(0)
                    conv2d_nchw[nn, ff, yy, xx] = conv2d_nchw[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * rxplaceholder_1[ff, rc, ry, rx]
    # fmt: on

    mod = LegalizeOps()(Conv2d)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv2d_transpose():
    """Legalize grouped R.nn.conv2d_transpose: dilate data, pad, flip the kernel, then convolve."""
    # fmt: off
    @I.ir_module
    class Conv2dTranspose:
        @R.function
        def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((128, 16, 3, 3), "float32")):
            gv = R.nn.conv2d_transpose(x, w, strides=(2, 3), padding=(1, 1), dilation=(1, 1), output_padding=(1, 2), groups=8)
            return gv

    @I.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 128, 28, 28), dtype="float32"), w: R.Tensor((128, 16, 3, 3), dtype="float32")) -> R.Tensor((2, 128, 56, 84), dtype="float32"):
            gv = R.call_tir(Expected.conv2d_transpose, (x, w), out_sinfo=R.Tensor((2, 128, 56, 84), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(16), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56), T.int64(84)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            data_dilate = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(55), T.int64(82)))
            data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58), T.int64(86)))
            kernel_transform = T.alloc_buffer((T.int64(16), T.int64(128), T.int64(3), T.int64(3)))
            # Strides (2, 3) become input dilation: zeros inserted between input elements.
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(55), T.int64(82)):
                with T.block("data_dilate"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(3)])
                    T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3])
                    data_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0) and v_i3 % T.int64(3) == T.int64(0), rxplaceholder[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(3)], T.float32(0))
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(58), T.int64(86)):
                with T.block("data_pad"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
                    T.writes(data_pad[v_i0, v_i1, v_i2, v_i3])
                    data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(56) and T.int64(1) <= v_i3 and v_i3 < T.int64(83), data_dilate[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0))
            # Transpose conv uses the kernel with in/out channels swapped and H/W reversed.
            for i, o, h, w in T.grid(T.int64(16), T.int64(128), T.int64(3), T.int64(3)):
                with T.block("kernel_transform"):
                    v_i, v_o, v_h, v_w = T.axis.remap("SSSS", [i, o, h, w])
                    T.reads(rxplaceholder_1[v_o, v_i, T.int64(2) - v_h, T.int64(2) - v_w])
                    T.writes(kernel_transform[v_i, v_o, v_h, v_w])
                    kernel_transform[v_i, v_o, v_h, v_w] = rxplaceholder_1[v_o, v_i, T.int64(2) - v_h, T.int64(2) - v_w]
            for b, c, h, w, dc, dh, dw in T.grid(T.int64(2), T.int64(128), T.int64(56), T.int64(84), T.int64(16), T.int64(3), T.int64(3)):
                with T.block("compute"):
                    v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w, dc, dh, dw])
                    T.reads(data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dh, v_dw])
                    T.writes(compute[v_b, v_c, v_h, v_w])
                    with T.init():
                        compute[v_b, v_c, v_h, v_w] = T.float32(0)
                    compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dh, v_dw]
    # fmt: on

    mod = LegalizeOps()(Conv2dTranspose)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv2d_transpose_with_out_dtype():
    """Legalize R.nn.conv2d_transpose with out_dtype="float16": casts applied at the MAC."""
    # fmt: off
    @tvm.script.ir_module
    class Conv2dTranspose:
        @R.function
        def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 4, 3, 3), "float32")):
            gv = R.nn.conv2d_transpose(x, w, out_dtype="float16")
            return gv

    @I.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 30, 30), dtype="float16"):
            gv = R.call_tir(Expected.conv2d_transpose, (x, w), out_sinfo=R.Tensor((2, 4, 30, 30), dtype="float16"))
            return gv

        @T.prim_func(private=True)
        def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(4), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4), T.int64(30), T.int64(30)), "float16")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            data_dilate = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)))
            data_pad = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(32), T.int64(32)))
            kernel_transform = T.alloc_buffer((T.int64(4), T.int64(3), T.int64(3), T.int64(3)))
            # Default stride of 1: the dilate stage degenerates to a copy.
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)):
                with T.block("data_dilate"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3])
                    T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3])
                    data_dilate[v_i0, v_i1, v_i2, v_i3] = rxplaceholder[v_i0, v_i1, v_i2, v_i3]
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(32)):
                with T.block("data_pad"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)])
                    T.writes(data_pad[v_i0, v_i1, v_i2, v_i3])
                    data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(30) and T.int64(2) <= v_i3 and v_i3 < T.int64(30), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)], T.float32(0))
            # Transpose conv uses the kernel with in/out channels swapped and H/W reversed.
            for o, i, h, w in T.grid(T.int64(4), T.int64(3), T.int64(3), T.int64(3)):
                with T.block("kernel_transform"):
                    v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w])
                    T.reads(rxplaceholder_1[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w])
                    T.writes(kernel_transform[v_o, v_i, v_h, v_w])
                    kernel_transform[v_o, v_i, v_h, v_w] = rxplaceholder_1[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w]
            for b, c, h, w, dc, dh, dw in T.grid(T.int64(2), T.int64(4), T.int64(30), T.int64(30), T.int64(3), T.int64(3), T.int64(3)):
                with T.block("compute"):
                    v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w, dc, dh, dw])
                    T.reads(data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c, v_dc, v_dh, v_dw])
                    T.writes(compute[v_b, v_c, v_h, v_w])
                    with T.init():
                        compute[v_b, v_c, v_h, v_w] = T.float16(0)
                    # Both operands are cast to float16 so the accumulation happens in float16.
                    compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + T.Cast("float16", data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw]) * T.Cast("float16", kernel_transform[v_c, v_dc, v_dh, v_dw])
    # fmt: on

    mod = LegalizeOps()(Conv2dTranspose)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_conv2d_transpose_symbolic():
    """Legalize R.nn.conv2d_transpose with symbolic shapes and strides=(3, 3)."""
    # fmt: off
    @tvm.script.ir_module
    class Conv2dTranspose:
        @R.function
        def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")):
            gv = R.nn.conv2d_transpose(x, kernel, strides=(3, 3))
            return gv

    @I.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("n", "c", "h", "w"), dtype="float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), dtype="float32")) -> R.Tensor(("n", "c", "h * 3 + kh - 3", "w * 3 + kw - 3"), dtype="float32"):
            n = T.int64()
            c = T.int64()
            h = T.int64()
            kh = T.int64()
            w = T.int64()
            kw = T.int64()
            f = T.int64()
            gv = R.call_tir(Expected.conv2d_transpose, (x, kernel), out_sinfo=R.Tensor((n, c, h * 3 + kh - 3, w * 3 + kw - 3), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def conv2d_transpose(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_compute: T.handle):
            T.func_attr({"tir.noalias": True})
            # Symbolic dims are bound from the runtime buffer shapes via match_buffer.
            n = T.int64()
            c = T.int64()
            h = T.int64()
            w = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, h, w))
            f = T.int64()
            kh = T.int64()
            kw = T.int64()
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kh, kw))
            compute = T.match_buffer(var_compute, (n, c, h * T.int64(3) + kh - T.int64(3), w * T.int64(3) + kw - T.int64(3)))
            # with T.block("root"):
            data_dilate = T.alloc_buffer((n, c, h * T.int64(3) - T.int64(2), w * T.int64(3) - T.int64(2)))
            data_pad = T.alloc_buffer((n, c, h * T.int64(3) + kh * T.int64(2) - T.int64(4), w * T.int64(3) + kw * T.int64(2) - T.int64(4)))
            kernel_transform = T.alloc_buffer((c, c, kh, kw))
            # Stride 3 becomes input dilation: zeros inserted between input elements.
            for i0, i1, i2, i3 in T.grid(n, c, h * T.int64(3) - T.int64(2), w * T.int64(3) - T.int64(2)):
                with T.block("data_dilate"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[v_i0, v_i1, v_i2 // T.int64(3), v_i3 // T.int64(3)])
                    T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3])
                    data_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 % T.int64(3) == T.int64(0) and v_i3 % T.int64(3) == T.int64(0), rxplaceholder[v_i0, v_i1, v_i2 // T.int64(3), v_i3 // T.int64(3)], T.float32(0))
            # Symbolic padding of (kh - 1, kw - 1) on each side of the dilated data.
            for i0, i1, i2, i3 in T.grid(n, c, h * T.int64(3) + kh * T.int64(2) - T.int64(4), w * T.int64(3) + kw * T.int64(2) - T.int64(4)):
                with T.block("data_pad"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw])
                    T.writes(data_pad[v_i0, v_i1, v_i2, v_i3])
                    data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(kh <= v_i2 + T.int64(1) and v_i2 + T.int64(3)< h * T.int64(3) + kh and kw <= v_i3 + T.int64(1) and v_i3 + T.int64(3) < w * T.int64(3) + kw , data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw], T.float32(0))
            # Transpose conv uses the kernel with in/out channels swapped and H/W reversed.
            for o, i, h_1, w_1 in T.grid(c, c, kh, kw):
                with T.block("kernel_transform"):
                    v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h_1, w_1])
                    T.reads(rxplaceholder_1[v_i, v_o, kh - v_h - T.int64(1), kw - v_w - T.int64(1)])
                    T.writes(kernel_transform[v_o, v_i, v_h, v_w])
                    kernel_transform[v_o, v_i, v_h, v_w] = rxplaceholder_1[v_i, v_o, kh - v_h - T.int64(1), kw - v_w - T.int64(1)]
            for b, c_1, h_1, w_1, dc, dh, dw in T.grid(n, c, h * T.int64(3) + kh - T.int64(3), w * T.int64(3) + kw - T.int64(3), c, kh, kw):
                with T.block("compute"):
                    v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c_1, h_1, w_1, dc, dh, dw])
                    T.reads(data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c, v_dc, v_dh, v_dw])
                    T.writes(compute[v_b, v_c, v_h, v_w])
                    with T.init():
                        compute[v_b, v_c, v_h, v_w] = T.float32(0)
                    compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc, v_dh, v_dw]
    # fmt: on

    mod = LegalizeOps()(Conv2dTranspose)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_max_pool2d():
    """Legalize NHWC R.nn.max_pool2d with 3x3 window, stride 2, padding 1 into TIR pool_max."""
    # fmt: off
    @tvm.script.ir_module
    class MaxPool2D:
        @R.function
        def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 56, 6), "float32"):
            gv: R.Tensor((4, 56, 56, 6), "float32") = R.nn.max_pool2d(x, pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], layout="NHWC")
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 56, 6), "float32"):
            gv = R.call_tir(Expected.max_pool2d, (x,), R.Tensor((4, 56, 56, 6), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), T.int64(6)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6)), "float32")):
            T.func_attr({"tir.noalias": True})
            # Padding uses float32 min (-3.4028234663852886e+38) so it never wins the max.
            pad_temp = T.alloc_buffer([T.int64(4), T.int64(114), T.int64(114), T.int64(6)], dtype="float32")
            for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(114), T.int64(114), T.int64(6)):
                with T.block("pad_temp"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[ax0, ax1 - T.int64(1), ax2 - T.int64(1), ax3])
                    T.writes(pad_temp[ax0, ax1, ax2, ax3])
                    pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else(T.int64(1) <= ax1 and ax1 < T.int64(113) and T.int64(1) <= ax2 and ax2 < T.int64(113), rxplaceholder[ax0, ax1 - T.int64(1), ax2 - T.int64(1), ax3], T.float32(-3.4028234663852886e+38), dtype="float32")
            for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(4), T.int64(56), T.int64(56), T.int64(6), T.int64(3), T.int64(3)):
                with T.block("pool_max"):
                    ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                    T.reads(pad_temp[ax0, ax1 * T.int64(2) + rv0, ax2 * T.int64(2) + rv1, ax3])
                    T.writes(pool_max[ax0, ax1, ax2, ax3])
                    T.block_attr({"schedule_rule":"meta_schedule.pool_max"})
                    with T.init():
                        pool_max[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e+38)
                    pool_max[ax0, ax1, ax2, ax3] = T.max(pool_max[ax0, ax1, ax2, ax3], pad_temp[ax0, ax1 * T.int64(2) + rv0, ax2 * T.int64(2) + rv1, ax3])
    # fmt: on

    mod = LegalizeOps()(MaxPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_max_pool2d_NCHW16c():
    """LegalizeOps lowers R.nn.max_pool2d with a blocked "NCHW16c" layout.

    With the default zero padding there is no pad buffer: the expected TIR is a
    single 7-axis reduction block ("pool_max") over the 3x3 window, initialized
    to float32 minimum (-3.4028234663852886e+38).
    """
    # fmt: off
    @tvm.script.ir_module
    class MaxPool2D:
        @R.function
        def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 110, 16), "float32"):
            gv: R.Tensor((4, 4, 110, 110, 16), "float32") = R.nn.max_pool2d(x, pool_size=[3, 3], layout="NCHW16c")
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 110, 16), "float32"):
            gv = R.call_tir(Expected.max_pool2d, (x,), R.Tensor((4, 4, 110, 110, 16), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)):
                with T.block("pool_max"):
                    ax0, ax1, ax2, ax3, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", [i0, i1, i2, i3, i4, i5, i6])
                    T.reads(rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1, ax4])
                    T.writes(pool_max[ax0, ax1, ax2, ax3, ax4])
                    T.block_attr({"schedule_rule":"meta_schedule.pool_max"})
                    with T.init():
                        pool_max[ax0, ax1, ax2, ax3, ax4] = T.float32(-3.4028234663852886e+38)
                    pool_max[ax0, ax1, ax2, ax3, ax4] = T.max(pool_max[ax0, ax1, ax2, ax3, ax4], rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1, ax4])
    # fmt: on

    mod = LegalizeOps()(MaxPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_max_pool2d_ceil_mode():
    """LegalizeOps lowers R.nn.max_pool2d with ceil_mode=True (NCHW layout).

    The expected TIR pads the 112x112 input to 116x116 with the float32 minimum
    value (ceil mode enlarges the padded extent beyond the explicit [1,1,1,1]
    padding), then reduces each 3x3 window at stride 3 to produce 38x38.
    """
    # fmt: off
    @tvm.script.ir_module
    class MaxPool2D:
        @R.function
        def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor((4, 6, 38, 38), "float32"):
            gv: R.Tensor((4, 6, 38, 38), "float32") = R.nn.max_pool2d(x, pool_size=[3, 3], strides=[3, 3], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=True)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) -> R.Tensor((4, 6, 38, 38), dtype="float32"):
            gv = R.call_tir(Expected.max_pool2d, (x,), R.Tensor((4, 6, 38, 38), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T.int64(112)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38)), "float32")):
            T.func_attr({"tir.noalias": True})
            pad_temp = T.alloc_buffer([T.int64(4), T.int64(6), T.int64(116), T.int64(116)], dtype="float32")
            for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(6), T.int64(116), T.int64(116)):
                with T.block("pad_temp"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[ax0, ax1, ax2 - T.int64(1), ax3 - T.int64(1)])
                    T.writes(pad_temp[ax0, ax1, ax2, ax3])
                    pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else(T.int64(1) <= ax2 and ax2 < T.int64(113) and T.int64(1) <= ax3 and ax3 < T.int64(113), rxplaceholder[ax0, ax1, ax2 - T.int64(1), ax3 - T.int64(1)], T.float32(-3.4028234663852886e+38), dtype="float32")
            for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(4), T.int64(6), T.int64(38), T.int64(38), T.int64(3), T.int64(3)):
                with T.block("pool_max"):
                    ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                    T.reads(pad_temp[ax0, ax1, ax2 * T.int64(3) + rv0, ax3 * T.int64(3) + rv1])
                    T.writes(pool_max[ax0, ax1, ax2, ax3])
                    T.block_attr({"schedule_rule":"meta_schedule.pool_max"})
                    with T.init():
                        pool_max[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e+38)
                    pool_max[ax0, ax1, ax2, ax3] = T.max(pool_max[ax0, ax1, ax2, ax3], pad_temp[ax0, ax1, ax2 * T.int64(3) + rv0, ax3 * T.int64(3) + rv1])
    # fmt: on

    mod = LegalizeOps()(MaxPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
@pytest.mark.skip("TOPI pooling casts every shape value to i32.")
def test_max_pool2d_symbolic():
    """max_pool2d legalization with fully symbolic shapes (n, c, h, w, kh, kw).

    Skipped because TOPI pooling casts shape values to i32, which breaks the
    int64 symbolic-shape expectations.

    NOTE(review): no `Expected` module is defined in this function, so the
    final assert would raise NameError if the skip marker were removed —
    an Expected module needs to be written before re-enabling this test.
    """
    # fmt: off
    @tvm.script.ir_module
    class MaxPool2D:
        @R.function
        def main(dumb_param: R.Tensor(("kh", "kw")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "h - kh + 1", "w - kw + 1"), "float32"):
            n = T.int64()
            c = T.int64()
            h = T.int64()
            w = T.int64()
            kh = T.int64()
            kw = T.int64()
            gv: R.Tensor((n, c, h - kh + 1, w - kw + 1), "float32") = R.nn.max_pool2d(x, pool_size=[kh, kw])
            return gv
    # fmt: on

    mod = LegalizeOps()(MaxPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_avg_pool2d():
    """LegalizeOps lowers R.nn.avg_pool2d (NHWC, 3x3 window, stride 2, pad 1).

    The expected TIR zero-pads to 114x114, accumulates each window into
    "pool_sum", then divides by the clamped count of in-bounds elements
    (count_include_pad=False style divisor) in "pool_avg".
    """
    # fmt: off
    @tvm.script.ir_module
    class AvgPool2D:
        @R.function
        def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 56, 6), "float32"):
            gv: R.Tensor((4, 56, 56, 6), "float32") = R.nn.avg_pool2d(x, pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], layout="NHWC")
            return gv

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), T.int64(6)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            pad_temp = T.alloc_buffer((T.int64(4), T.int64(114), T.int64(114), T.int64(6)))
            pool_sum = T.alloc_buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6)))
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(114), T.int64(114), T.int64(6)):
                with T.block("pad_temp"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(rxplaceholder[v_ax0, v_ax1 - T.int64(1), v_ax2 - T.int64(1), v_ax3])
                    T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3])
                    pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(1) <= v_ax1 and v_ax1 < T.int64(113) and T.int64(1) <= v_ax2 and v_ax2 < T.int64(113), rxplaceholder[v_ax0, v_ax1 - T.int64(1), v_ax2 - T.int64(1), v_ax3], T.float32(0))
            for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(4), T.int64(56), T.int64(56), T.int64(6), T.int64(3), T.int64(3)):
                with T.block("pool_sum"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                    T.reads(pad_temp[v_ax0, v_ax1 * T.int64(2) + v_rv0, v_ax2 * T.int64(2) + v_rv1, v_ax3])
                    T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                    with T.init():
                        pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
                    pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + pad_temp[v_ax0, v_ax1 * T.int64(2) + v_rv0, v_ax2 * T.int64(2) + v_rv1, v_ax3]
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(56), T.int64(56), T.int64(6)):
                with T.block("pool_avg"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / T.Cast("float32", T.max((T.min(v_ax1 * T.int64(2) + T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax1 * T.int64(2), T.int64(0)) - v_ax1 * T.int64(2)) * (T.min(v_ax2 * T.int64(2) + T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax2 * T.int64(2), T.int64(0)) - v_ax2 * T.int64(2)), T.int64(1)))

        @R.function
        def main(x: R.Tensor((4, 112, 112, 6), dtype="float32")) -> R.Tensor((4, 56, 56, 6), dtype="float32"):
            gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4, 56, 56, 6), dtype="float32"))
            return gv
    # fmt: on

    mod = LegalizeOps()(AvgPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_avg_pool2d_NCHW16c():
    """LegalizeOps lowers R.nn.avg_pool2d with a blocked "NCHW16c" layout.

    No padding is requested, so the expected TIR has no pad buffer: a 3x3
    window sum into "pool_sum" followed by division by the (clamped) window
    element count in "pool_avg".
    """
    # fmt: off
    @tvm.script.ir_module
    class AvgPool2D:
        @R.function
        def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 110, 16), "float32"):
            gv: R.Tensor((4, 4, 110, 110, 16), "float32") = R.nn.avg_pool2d(x, pool_size=[3, 3], layout="NCHW16c")
            return gv

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            pool_sum = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)))
            for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)):
                with T.block("pool_sum"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap("SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1])
                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 + v_rv0, v_ax3 + v_rv1, v_ax4])
                    T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    with T.init():
                        pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(0)
                    pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] + rxplaceholder[v_ax0, v_ax1, v_ax2 + v_rv0, v_ax3 + v_rv1, v_ax4]
            for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)):
                with T.block("pool_avg"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                    T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] / T.Cast("float32", T.max((T.min(T.int64(2), T.int64(111) - v_ax2) + T.int64(1) - T.max(T.int64(0) - v_ax2, T.int64(0))) * (T.min(T.int64(2), T.int64(111) - v_ax3) + T.int64(1) - T.max(T.int64(0) - v_ax3, T.int64(0))), T.int64(1)))

        @R.function
        def main(x: R.Tensor((4, 4, 112, 112, 16), dtype="float32")) -> R.Tensor((4, 4, 110, 110, 16), dtype="float32"):
            gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4, 4, 110, 110, 16), dtype="float32"))
            return gv
    # fmt: on

    mod = LegalizeOps()(AvgPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_avg_pool2d_ceil_mode():
    """LegalizeOps lowers R.nn.avg_pool2d with ceil_mode=True (NCHW layout).

    The expected TIR zero-pads 112x112 to 116x116 (ceil mode enlarges the
    padded extent beyond the explicit [1,1,1,1] padding), sums each 3x3
    window at stride 3 into "pool_sum", and divides by the clamped count of
    in-bounds elements in "pool_avg" to produce the 38x38 output.
    """
    # fmt: off
    @tvm.script.ir_module
    class AvgPool2D:
        @R.function
        def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor((4, 6, 38, 38), "float32"):
            gv: R.Tensor((4, 6, 38, 38), "float32") = R.nn.avg_pool2d(x, pool_size=[3, 3], strides=[3, 3], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=True)
            return gv

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T.int64(112)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            pad_temp = T.alloc_buffer((T.int64(4), T.int64(6), T.int64(116), T.int64(116)))
            pool_sum = T.alloc_buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38)))
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(6), T.int64(116), T.int64(116)):
                with T.block("pad_temp"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 - T.int64(1), v_ax3 - T.int64(1)])
                    T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3])
                    pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(1) <= v_ax2 and v_ax2 < T.int64(113) and T.int64(1) <= v_ax3 and v_ax3 < T.int64(113), rxplaceholder[v_ax0, v_ax1, v_ax2 - T.int64(1), v_ax3 - T.int64(1)], T.float32(0))
            for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(4), T.int64(6), T.int64(38), T.int64(38), T.int64(3), T.int64(3)):
                with T.block("pool_sum"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                    T.reads(pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(3) + v_rv0, v_ax3 * T.int64(3) + v_rv1])
                    T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                    with T.init():
                        pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
                    pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(3) + v_rv0, v_ax3 * T.int64(3) + v_rv1]
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(6), T.int64(38), T.int64(38)):
                with T.block("pool_avg"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / T.Cast("float32", T.max((T.min(v_ax2 * T.int64(3) + T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax2 * T.int64(3), T.int64(0)) - v_ax2 * T.int64(3)) * (T.min(v_ax3 * T.int64(3) + T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax3 * T.int64(3), T.int64(0)) - v_ax3 * T.int64(3)), T.int64(1)))

        @R.function
        def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) -> R.Tensor((4, 6, 38, 38), dtype="float32"):
            gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4, 6, 38, 38), dtype="float32"))
            return gv
    # fmt: on

    mod = LegalizeOps()(AvgPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
@pytest.mark.skip("TOPI pooling casts every shape value to i32.")
def test_avg_pool2d_symbolic():
    """avg_pool2d legalization with fully symbolic shapes (n, c, h, w, kh, kw).

    Skipped because TOPI pooling casts shape values to i32, which breaks the
    int64 symbolic-shape expectations.

    NOTE(review): no `Expected` module is defined in this function, so the
    final assert would raise NameError if the skip marker were removed —
    an Expected module needs to be written before re-enabling this test.
    """
    # fmt: off
    @tvm.script.ir_module
    class AvgPool2D:
        @R.function
        def main(dumb_param: R.Tensor(("kh", "kw")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "h - kh + 1", "w - kw + 1"), "float32"):
            n = T.int64()
            c = T.int64()
            h = T.int64()
            w = T.int64()
            kh = T.int64()
            kw = T.int64()
            gv: R.Tensor((n, c, h - kh + 1, w - kw + 1), "float32") = R.nn.avg_pool2d(x, pool_size=[kh, kw])
            return gv
    # fmt: on

    mod = LegalizeOps()(AvgPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_adaptive_avg_pool2d():
    """LegalizeOps lowers R.nn.adaptive_avg_pool2d to output_size=[1, 1]
    on a blocked "NCHW16c" layout.

    With a 7x7 spatial input collapsing to 1x1, the expected TIR sums the
    whole 7x7 plane per channel and multiplies by 1/49
    (0.020408163265306121) instead of dividing.
    """
    # fmt: off
    @tvm.script.ir_module
    class AdaptiveAvgPool2D:
        @R.function
        def main(x: R.Tensor((2, 4, 7, 7, 16), "float32")) -> R.Tensor((2, 4, 1, 1, 16), "float32"):
            gv: R.Tensor((2, 4, 1, 1, 16), "float32") = R.nn.adaptive_avg_pool2d(x, output_size=[1, 1], layout="NCHW16c")
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 4, 7, 7, 16), "float32")) -> R.Tensor((2, 4, 1, 1, 16), "float32"):
            gv = R.call_tir(Expected.adaptive_avg_pool2d, (x,), R.Tensor((2, 4, 1, 1, 16), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(7), T.int64(7), T.int64(16)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)), "float32")):
            T.func_attr({"tir.noalias": True})
            adaptive_pool_sum = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)], dtype="float32")
            for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16), T.int64(7), T.int64(7)):
                with T.block("adaptive_pool_sum"):
                    ax0, ax1, ax2, ax3, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", [i0, i1, i2, i3, i4, i5, i6])
                    T.reads(rxplaceholder[ax0, ax1, ax2 * T.int64(7) + rv0, ax3 * T.int64(7) + rv1, ax4])
                    T.writes(adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4])
                    with T.init():
                        adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] = T.float32(0)
                    adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] = adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] + rxplaceholder[ax0, ax1, ax2 * T.int64(7) + rv0, ax3 * T.int64(7) + rv1, ax4]
            for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)):
                with T.block("adaptive_pool_avg"):
                    ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                    T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4])
                    T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4])
                    T.block_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"})
                    adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4] = adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] * T.float32(0.020408163265306121)
    # fmt: on

    mod = LegalizeOps()(AdaptiveAvgPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_adaptive_avg_pool2d_without_output_size():
    """LegalizeOps lowers R.nn.adaptive_avg_pool2d with no output_size.

    Without an output_size the output shape equals the input shape (7x7), so
    each "window" is a single element: the reduction runs over a 1x1 range
    and the final "adaptive_pool_avg" block is a plain copy of the sum.
    """
    # fmt: off
    @tvm.script.ir_module
    class AdaptiveAvgPool2D:
        @R.function
        def main(x: R.Tensor((2, 16, 7, 7), "float32")) -> R.Tensor((2, 16, 7, 7), "float32"):
            gv: R.Tensor((2, 16, 7, 7), "float32") = R.nn.adaptive_avg_pool2d(x)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 16, 7, 7), "float32")) -> R.Tensor((2, 16, 7, 7), "float32"):
            gv = R.call_tir(Expected.adaptive_avg_pool2d, (x,), R.Tensor((2, 16, 7, 7), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(16), T.int64(7), T.int64(7)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(16), T.int64(7), T.int64(7)), "float32")):
            T.func_attr({"tir.noalias": True})
            adaptive_pool_sum = T.alloc_buffer([T.int64(2), T.int64(16), T.int64(7), T.int64(7)], dtype="float32")
            for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(16), T.int64(7), T.int64(7), T.int64(1), T.int64(1)):
                with T.block("adaptive_pool_sum"):
                    ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                    T.reads(rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1])
                    T.writes(adaptive_pool_sum[ax0, ax1, ax2, ax3])
                    with T.init():
                        adaptive_pool_sum[ax0, ax1, ax2, ax3] = T.float32(0)
                    adaptive_pool_sum[ax0, ax1, ax2, ax3] = adaptive_pool_sum[ax0, ax1, ax2, ax3] + rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1]
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(16), T.int64(7), T.int64(7)):
                with T.block("adaptive_pool_avg"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3])
                    T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3])
                    T.block_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"})
                    adaptive_pool_avg[ax0, ax1, ax2, ax3] = adaptive_pool_sum[ax0, ax1, ax2, ax3]
    # fmt: on

    mod = LegalizeOps()(AdaptiveAvgPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
@pytest.mark.skip("TOPI pooling casts every shape value to i32.")
def test_adaptive_avg_pool2d_symbolic():
    """adaptive_avg_pool2d legalization with symbolic shapes (n, c, oh, ow).

    Skipped because TOPI pooling casts shape values to i32, which breaks the
    int64 symbolic-shape expectations.

    NOTE(review): no `Expected` module is defined in this function, so the
    final assert would raise NameError if the skip marker were removed —
    an Expected module needs to be written before re-enabling this test.
    """
    # fmt: off
    @tvm.script.ir_module
    class AdaptiveAvgPool2D:
        @R.function
        def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "oh", "ow"), "float32"):
            n = T.int64()
            c = T.int64()
            oh = T.int64()
            ow = T.int64()
            gv: R.Tensor((n, c, oh, ow), "float32") = R.nn.adaptive_avg_pool2d(x, (oh, ow))
            return gv
    # fmt: on

    mod = LegalizeOps()(AdaptiveAvgPool2D)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_relu():
    """LegalizeOps lowers R.nn.relu on a static (2, 3) tensor to an
    elementwise max(x, 0) TIR block."""
    # fmt: off
    @tvm.script.ir_module
    class Relu:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.nn.relu(x)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv = R.call_tir(Expected.relu, (x,), R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def relu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("compute"):
                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[i0_1, i1_1])
                    T.writes(compute[i0_1, i1_1])
                    compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0))
    # fmt: on

    mod = LegalizeOps()(Relu)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_relu_symbolic():
    """LegalizeOps lowers R.nn.relu with symbolic (m, n) shapes.

    The expected prim_func takes opaque handles and binds the symbolic
    extents via T.match_buffer, then computes the same elementwise max(x, 0).
    """
    # fmt: off
    @tvm.script.ir_module
    class Relu:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
            m = T.int64()
            n = T.int64()
            gv: R.Tensor((m, n), "float32") = R.nn.relu(x)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
            m = T.int64()
            n = T.int64()
            gv = R.call_tir(Expected.relu, (x,), R.Tensor((m, n), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def relu(var_rxplaceholder: T.handle, var_compute: T.handle):
            T.func_attr({"tir.noalias": True})
            m = T.int64()
            n = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32")
            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
            for i0, i1 in T.grid(m, n):
                with T.block("compute"):
                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[i0_1, i1_1])
                    T.writes(compute[i0_1, i1_1])
                    compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0))
    # fmt: on

    mod = LegalizeOps()(Relu)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_leakyrelu():
    """LegalizeOps lowers R.nn.leakyrelu (alpha=0.02) on a static (2, 3)
    tensor to an elementwise Select(0 < x, x, x * 0.02) TIR block."""
    # fmt: off
    @tvm.script.ir_module
    class LeakyRelu:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.nn.leakyrelu(x, 0.02)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv = R.call_tir(Expected.leaky_relu, (x,), R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def leaky_relu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("compute"):
                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[i0_1, i1_1])
                    T.writes(compute[i0_1, i1_1])
                    compute[i0_1, i1_1] = T.Select(T.float32(0) < rxplaceholder[i0_1, i1_1], rxplaceholder[i0_1, i1_1], \
                        rxplaceholder[i0_1, i1_1] * T.float32(0.02))
    # fmt: on

    mod = LegalizeOps()(LeakyRelu)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_leakyrelu_symbolic():
    """LegalizeOps lowers R.nn.leakyrelu (alpha=0.03) with symbolic (m, n)
    shapes; buffers are bound from handles via T.match_buffer."""
    # fmt: off
    @tvm.script.ir_module
    class LeakyRelu:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
            m = T.int64()
            n = T.int64()
            gv: R.Tensor((m, n), "float32") = R.nn.leakyrelu(x, 0.03)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
            m = T.int64()
            n = T.int64()
            gv = R.call_tir(Expected.leaky_relu, (x, ), R.Tensor((m, n), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def leaky_relu(var_rxplaceholder: T.handle, var_compute: T.handle):
            T.func_attr({"tir.noalias": True})
            m = T.int64()
            n = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32")
            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
            for i0, i1 in T.grid(m, n):
                with T.block("compute"):
                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[i0_1, i1_1])
                    T.writes(compute[i0_1, i1_1])
                    compute[i0_1, i1_1] = T.Select(T.float32(0) < rxplaceholder[i0_1, i1_1], rxplaceholder[i0_1, i1_1], \
                        rxplaceholder[i0_1, i1_1] * T.float32(0.03))
    # fmt: on

    mod = LegalizeOps()(LeakyRelu)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_prelu():
    """LegalizeOps lowers R.nn.prelu with a single learned slope.

    The expected TIR first broadcasts the scalar slope y[0] across the
    channel axis ("slope_broadcasted"), then computes the elementwise
    Select(0 < x, x, x * slope).
    """
    # fmt: off
    @tvm.script.ir_module
    class PRelu:
        @R.function
        def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.nn.prelu(x, y)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def prelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(1),), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            slope_broadcasted = T.alloc_buffer((T.int64(3),))
            for c in range(T.int64(3)):
                with T.block("slope_broadcasted"):
                    v_c = T.axis.spatial(T.int64(3), c)
                    T.reads(y[T.int64(0)])
                    T.writes(slope_broadcasted[v_c])
                    slope_broadcasted[v_c] = y[T.int64(0)]
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
    # fmt: on

    mod = LegalizeOps()(PRelu)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_prelu_symbolic():
    """LegalizeOps lowers R.nn.prelu with a symbolic batch dimension m.

    The channel extent (7) stays static, so the slope broadcast is fixed-size
    while the data buffers are bound from handles via T.match_buffer.
    """
    # fmt: off
    @tvm.script.ir_module
    class PRelu:
        @R.function
        def main(x: R.Tensor(("m", 7), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor(("m", 7), "float32"):
            m = T.int64()
            gv: R.Tensor((m, 7), "float32") = R.nn.prelu(x, y)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("m", 7), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor(("m", 7), dtype="float32"):
            m = T.int64()
            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((m, 7), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def prelu(var_x: T.handle, y: T.Buffer((T.int64(1),), "float32"), var_compute: T.handle):
            T.func_attr({"tir.noalias": True})
            m = T.int64()
            x = T.match_buffer(var_x, (m, T.int64(7)))
            compute = T.match_buffer(var_compute, (m, T.int64(7)))
            # with T.block("root"):
            slope_broadcasted = T.alloc_buffer((T.int64(7),))
            for c in range(T.int64(7)):
                with T.block("slope_broadcasted"):
                    v_c = T.axis.spatial(T.int64(7), c)
                    T.reads(y[T.int64(0)])
                    T.writes(slope_broadcasted[v_c])
                    slope_broadcasted[v_c] = y[T.int64(0)]
            for i0, i1 in T.grid(m, T.int64(7)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
    # fmt: on

    mod = LegalizeOps()(PRelu)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_gelu():
    """LegalizeOps lowers R.nn.gelu (erf formulation) on a static (2, 3) tensor.

    The expected TIR stages the computation x * (0.5 + 0.5 * erf(x / sqrt(2)))
    through intermediate buffers: scale by 1/sqrt(2) (0.70710678...), erf,
    scale by 0.5, add 0.5, and final multiply by x.
    """
    # fmt: off
    @tvm.script.ir_module
    class Gelu:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.nn.gelu(x)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv = R.call_tir(Expected.gelu, (x,), R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def gelu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            T_multiply_1 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32")
            compute = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32")
            T_multiply_2 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32")
            T_divide = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32")
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1])
                    T.writes(T_multiply_1[ax0, ax1])
                    T_multiply_1[ax0, ax1] = rxplaceholder[ax0, ax1] * T.float32(0.70710678118654757)
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("compute"):
                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_multiply_1[i0_1, i1_1])
                    T.writes(compute[i0_1, i1_1])
                    compute[i0_1, i1_1] = T.erf(T_multiply_1[i0_1, i1_1], dtype="float32")
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply_1"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(compute[ax0, ax1])
                    T.writes(T_multiply_2[ax0, ax1])
                    T_multiply_2[ax0, ax1] = compute[ax0, ax1] * T.float32(0.5)
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_divide"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_multiply_2[ax0, ax1])
                    T.writes(T_divide[ax0, ax1])
                    T_divide[ax0, ax1] = T.float32(0.5) + T_multiply_2[ax0, ax1]
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply_2"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1], T_divide[ax0, ax1])
                    T.writes(T_multiply[ax0, ax1])
                    T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * T_divide[ax0, ax1]
    # fmt: on

    mod = LegalizeOps()(Gelu)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_gelu_symbolic():
    """LegalizeOps lowers R.nn.gelu (erf formulation) with symbolic (m, n)
    shapes; same staged computation as test_gelu but with handle-bound
    buffers (and the add-stage buffer named T_add instead of T_divide)."""
    # fmt: off
    @tvm.script.ir_module
    class Gelu:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
            m = T.int64()
            n = T.int64()
            gv: R.Tensor((m, n), "float32") = R.nn.gelu(x)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
            m = T.int64()
            n = T.int64()
            gv = R.call_tir(Expected.gelu, (x,), R.Tensor((m, n), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def gelu(var_rxplaceholder: T.handle, var_T_multiply: T.handle):
            T.func_attr({"tir.noalias": True})
            m = T.int64()
            n = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32")
            T_multiply = T.match_buffer(var_T_multiply, [m, n], dtype="float32")
            T_multiply_1 = T.alloc_buffer([m, n], dtype="float32")
            compute = T.alloc_buffer([m, n], dtype="float32")
            T_multiply_2 = T.alloc_buffer([m, n], dtype="float32")
            T_add = T.alloc_buffer([m, n], dtype="float32")
            for i0, i1 in T.grid(m, n):
                with T.block("T_multiply"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1])
                    T.writes(T_multiply_1[ax0, ax1])
                    T_multiply_1[ax0, ax1] = rxplaceholder[ax0, ax1] * T.float32(0.70710678118654757)
            for i0, i1 in T.grid(m, n):
                with T.block("compute"):
                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_multiply_1[i0_1, i1_1])
                    T.writes(compute[i0_1, i1_1])
                    compute[i0_1, i1_1] = T.erf(T_multiply_1[i0_1, i1_1], dtype="float32")
            for i0, i1 in T.grid(m, n):
                with T.block("T_multiply_1"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(compute[ax0, ax1])
                    T.writes(T_multiply_2[ax0, ax1])
                    T_multiply_2[ax0, ax1] = compute[ax0, ax1] * T.float32(0.5)
            for i0, i1 in T.grid(m, n):
                with T.block("T_add"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_multiply_2[ax0, ax1])
                    T.writes(T_add[ax0, ax1])
                    T_add[ax0, ax1] = T.float32(0.5) + T_multiply_2[ax0, ax1]
            for i0, i1 in T.grid(m, n):
                with T.block("T_multiply_2"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1], T_add[ax0, ax1])
                    T.writes(T_multiply[ax0, ax1])
                    T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * T_add[ax0, ax1]
    # fmt: on

    mod = LegalizeOps()(Gelu)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_gelu_tanh():
    """LegalizeOps lowers R.nn.gelu_tanh on a static (2, 3) tensor.

    The expected TIR stages the tanh approximation
    0.5 * x * (1 + tanh(0.79788456... * x * (1 + 0.044715 * x * x)))
    through one intermediate buffer per elementwise step.
    """
    # fmt: off
    @tvm.script.ir_module
    class GeluTanh:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.nn.gelu_tanh(x)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
            gv = R.call_tir(Expected.gelu_tanh, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
            T_multiply_2 = T.alloc_buffer((T.int64(2), T.int64(3)))
            T_multiply_3 = T.alloc_buffer((T.int64(2), T.int64(3)))
            T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3)))
            T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
            T_multiply_5 = T.alloc_buffer((T.int64(2), T.int64(3)))
            compute = T.alloc_buffer((T.int64(2), T.int64(3)))
            T_add_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1])
                    T.writes(T_multiply_1[v_ax0, v_ax1])
                    T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1])
                    T.writes(T_multiply_2[v_ax0, v_ax1])
                    T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * A[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply_2"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1])
                    T.writes(T_multiply_3[v_ax0, v_ax1])
                    T_multiply_3[v_ax0, v_ax1] = T.float32(0.044714999999999998) * A[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply_3"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_3[v_ax0, v_ax1], A[v_ax0, v_ax1])
                    T.writes(T_multiply_4[v_ax0, v_ax1])
                    T_multiply_4[v_ax0, v_ax1] = T_multiply_3[v_ax0, v_ax1] * A[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_4[v_ax0, v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = T.float32(1) + T_multiply_4[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply_4"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_2[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
                    T.writes(T_multiply_5[v_ax0, v_ax1])
                    T_multiply_5[v_ax0, v_ax1] = T_multiply_2[v_ax0, v_ax1] * T_add[v_ax0, v_ax1]
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_multiply_5[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.tanh(T_multiply_5[v_i0, v_i1])
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_add_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(compute[v_ax0, v_ax1])
                    T.writes(T_add_1[v_ax0, v_ax1])
                    T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply_5"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
                    T.writes(T_multiply[v_ax0, v_ax1])
                    T_multiply[v_ax0, v_ax1] = T_multiply_1[v_ax0, v_ax1] * T_add_1[v_ax0, v_ax1]
    # fmt: on  (was missing; restores black formatting after this test)

    mod = LegalizeOps()(GeluTanh)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_gelu_tanh_symbolic():
    """Legalize R.nn.gelu_tanh on a tensor with symbolic shape (m, n).

    Verifies that LegalizeOps lowers the op to a call_tir of a TIR prim_func
    computing the tanh approximation of GELU:
        0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))
    with input/output buffers bound through T.match_buffer on symbolic m, n.
    The Expected module below is compiler-printed TVMScript; it must stay
    textually exact for assert_structural_equal to pass.
    """
    # fmt: off
    @tvm.script.ir_module
    class GeluTanh:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
            m = T.int64()
            n = T.int64()
            gv: R.Tensor((m, n), "float32") = R.nn.gelu_tanh(x)
            return gv
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"):
            m = T.int64()
            n = T.int64()
            gv = R.call_tir(Expected.gelu_tanh, (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
            return gv
        @T.prim_func(private=True)
        def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle):
            T.func_attr({"tir.noalias": True})
            m, n = T.int64(), T.int64()
            A = T.match_buffer(var_A, (m, n))
            T_multiply = T.match_buffer(var_T_multiply, (m, n))
            # with T.block("root"):
            T_multiply_1 = T.alloc_buffer((m, n))
            T_multiply_2 = T.alloc_buffer((m, n))
            T_multiply_3 = T.alloc_buffer((m, n))
            T_multiply_4 = T.alloc_buffer((m, n))
            T_add = T.alloc_buffer((m, n))
            T_multiply_5 = T.alloc_buffer((m, n))
            compute = T.alloc_buffer((m, n))
            T_add_1 = T.alloc_buffer((m, n))
            for ax0, ax1 in T.grid(m, n):
                with T.block("T_multiply"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1])
                    T.writes(T_multiply_1[v_ax0, v_ax1])
                    T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(m, n):
                with T.block("T_multiply_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1])
                    T.writes(T_multiply_2[v_ax0, v_ax1])
                    T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * A[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(m, n):
                with T.block("T_multiply_2"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1])
                    T.writes(T_multiply_3[v_ax0, v_ax1])
                    T_multiply_3[v_ax0, v_ax1] = T.float32(0.044714999999999998) * A[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(m, n):
                with T.block("T_multiply_3"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_3[v_ax0, v_ax1], A[v_ax0, v_ax1])
                    T.writes(T_multiply_4[v_ax0, v_ax1])
                    T_multiply_4[v_ax0, v_ax1] = T_multiply_3[v_ax0, v_ax1] * A[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(m, n):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_4[v_ax0, v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = T.float32(1) + T_multiply_4[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(m, n):
                with T.block("T_multiply_4"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_2[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
                    T.writes(T_multiply_5[v_ax0, v_ax1])
                    T_multiply_5[v_ax0, v_ax1] = T_multiply_2[v_ax0, v_ax1] * T_add[v_ax0, v_ax1]
            for i0, i1 in T.grid(m, n):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_multiply_5[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.tanh(T_multiply_5[v_i0, v_i1])
            for ax0, ax1 in T.grid(m, n):
                with T.block("T_add_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(compute[v_ax0, v_ax1])
                    T.writes(T_add_1[v_ax0, v_ax1])
                    T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
            for ax0, ax1 in T.grid(m, n):
                with T.block("T_multiply_5"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
                    T.writes(T_multiply[v_ax0, v_ax1])
                    T_multiply[v_ax0, v_ax1] = T_multiply_1[v_ax0, v_ax1] * T_add_1[v_ax0, v_ax1]
    # fmt: on
    mod = LegalizeOps()(GeluTanh)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_silu():
    """Legalize R.nn.silu on a static-shape (2, 3) float32 tensor.

    SiLU (a.k.a. swish) is x * sigmoid(x); the Expected prim_func computes
    sigmoid into a temporary buffer and then multiplies elementwise.
    The Expected module is compiler-printed TVMScript and must remain
    textually exact for assert_structural_equal.
    """
    # fmt: off
    @tvm.script.ir_module
    class Silu:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.nn.silu(x)
            return gv
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv = R.call_tir(Expected.silu, (x,), R.Tensor((2, 3), dtype="float32"))
            return gv
        @T.prim_func(private=True)
        def silu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            compute = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32")
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("compute"):
                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[i0_1, i1_1])
                    T.writes(compute[i0_1, i1_1])
                    compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1])
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1], compute[ax0, ax1])
                    T.writes(T_multiply[ax0, ax1])
                    T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * compute[ax0, ax1]
    # fmt: on
    mod = LegalizeOps()(Silu)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_silu_symbolic():
    """Legalize R.nn.silu on a tensor with symbolic shape (m, n).

    Same lowering as test_silu (x * sigmoid(x)) but with buffers bound via
    T.match_buffer over symbolic extents m and n. Expected is printed
    TVMScript and must stay textually exact for assert_structural_equal.
    """
    # fmt: off
    @tvm.script.ir_module
    class Silu:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
            m = T.int64()
            n = T.int64()
            gv: R.Tensor((m, n), "float32") = R.nn.silu(x)
            return gv
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
            m = T.int64()
            n = T.int64()
            gv = R.call_tir(Expected.silu, (x,), R.Tensor((m, n), dtype="float32"))
            return gv
        @T.prim_func(private=True)
        def silu(var_rxplaceholder: T.handle, var_T_multiply: T.handle):
            T.func_attr({"tir.noalias": True})
            m = T.int64()
            n = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32")
            T_multiply = T.match_buffer(var_T_multiply, [m, n], dtype="float32")
            compute = T.alloc_buffer([m, n], dtype="float32")
            for i0, i1 in T.grid(m, n):
                with T.block("compute"):
                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[i0_1, i1_1])
                    T.writes(compute[i0_1, i1_1])
                    compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1])
            for i0, i1 in T.grid(m, n):
                with T.block("T_multiply"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1], compute[ax0, ax1])
                    T.writes(T_multiply[ax0, ax1])
                    T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * compute[ax0, ax1]
    # fmt: on
    mod = LegalizeOps()(Silu)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_softmax():
    """Legalize R.nn.softmax over axis=-2 of a static (2, 3, 16, 32) tensor.

    The Expected prim_func is the standard numerically-stable softmax:
    reduce-max along the softmax axis, exp of the shifted input, reduce-sum
    of the exponentials, then normalize. Axis -2 (size 16) is the reduction
    axis, so the reduction buffers drop that dimension. The module is
    compiler-printed TVMScript and must stay textually exact.
    """
    # fmt: off
    @tvm.script.ir_module
    class Softmax:
        @R.function
        def main(x: R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor((2, 3, 16, 32), "float32"):
            gv: R.Tensor((2, 3, 16, 32), "float32") = R.nn.softmax(x, axis=-2)
            return gv
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor((2, 3, 16, 32), "float32"):
            gv = R.call_tir(Expected.softmax, (x,), R.Tensor((2, 3, 16, 32), dtype="float32"))
            return gv
        @T.prim_func(private=True)
        def softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"), T_softmax_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32")):
            T.func_attr({"tir.noalias": True})
            T_softmax_maxelem = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32")
            T_softmax_exp = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(16), T.int64(32)], dtype="float32")
            T_softmax_expsum = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32")
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)):
                with T.block("T_softmax_maxelem"):
                    i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[i0_1, i1_1, k, i2_1])
                    T.writes(T_softmax_maxelem[i0_1, i1_1, i2_1])
                    with T.init():
                        T_softmax_maxelem[i0_1, i1_1, i2_1] = T.float32(-3.4028234663852886e+38)
                    T_softmax_maxelem[i0_1, i1_1, i2_1] = T.max(T_softmax_maxelem[i0_1, i1_1, i2_1], rxplaceholder[i0_1, i1_1, k, i2_1])
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)):
                with T.block("T_softmax_exp"):
                    i0_2, i1_2, i2_2, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[i0_2, i1_2, i2_2, i3_1], T_softmax_maxelem[i0_2, i1_2, i3_1])
                    T.writes(T_softmax_exp[i0_2, i1_2, i2_2, i3_1])
                    T_softmax_exp[i0_2, i1_2, i2_2, i3_1] = T.exp(rxplaceholder[i0_2, i1_2, i2_2, i3_1] - T_softmax_maxelem[i0_2, i1_2, i3_1], dtype="float32")
            for i0_3, i1_3, i2_3, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)):
                with T.block("T_softmax_expsum"):
                    i0_4, i1_4, i2_4, k = T.axis.remap("SSSR", [i0_3, i1_3, i2_3, i3])
                    T.reads(T_softmax_exp[i0_4, i1_4, k, i2_4])
                    T.writes(T_softmax_expsum[i0_4, i1_4, i2_4])
                    with T.init():
                        T_softmax_expsum[i0_4, i1_4, i2_4] = T.float32(0)
                    T_softmax_expsum[i0_4, i1_4, i2_4] = T_softmax_expsum[i0_4, i1_4, i2_4] + T_softmax_exp[i0_4, i1_4, k, i2_4]
            for i0_5, i1_5, i2_5, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)):
                with T.block("T_softmax_norm"):
                    i0_6, i1_6, i2_6, i3_2 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3])
                    T.reads(T_softmax_exp[i0_6, i1_6, i2_6, i3_2], T_softmax_expsum[i0_6, i1_6, i3_2])
                    T.writes(T_softmax_norm[i0_6, i1_6, i2_6, i3_2])
                    T.block_attr({"axis":2})
                    T_softmax_norm[i0_6, i1_6, i2_6, i3_2] = T_softmax_exp[i0_6, i1_6, i2_6, i3_2] / T_softmax_expsum[i0_6, i1_6, i3_2]
    # fmt: on
    mod = LegalizeOps()(Softmax)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_softmax_symbolic():
    """Legalize R.nn.softmax (default axis=-1) on a symbolic (a, b, c) tensor.

    Same max/exp/sum/normalize lowering as test_softmax, but over the last
    axis and with buffers bound via T.match_buffer on symbolic a, b, c.
    Expected is compiler-printed TVMScript and must stay textually exact.
    """
    # fmt: off
    @tvm.script.ir_module
    class Softmax:
        @R.function
        def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"):
            a = T.int64()
            b = T.int64()
            c = T.int64()
            gv: R.Tensor((a, b, c), "float32") = R.nn.softmax(x)
            return gv
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"):
            a = T.int64()
            b = T.int64()
            c = T.int64()
            gv = R.call_tir(Expected.softmax, (x,), R.Tensor((a, b, c), dtype="float32"))
            return gv
        @T.prim_func(private=True)
        def softmax(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle):
            T.func_attr({"tir.noalias": True})
            a = T.int64()
            b = T.int64()
            c = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32")
            T_softmax_norm = T.match_buffer(var_T_softmax_norm, [a, b, c], dtype="float32")
            T_softmax_maxelem = T.alloc_buffer([a, b], dtype="float32")
            T_softmax_exp = T.alloc_buffer([a, b, c], dtype="float32")
            T_softmax_expsum = T.alloc_buffer([a, b], dtype="float32")
            for i0, i1, i2 in T.grid(a, b, c):
                with T.block("T_softmax_maxelem"):
                    i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2])
                    T.reads(rxplaceholder[i0_1, i1_1, k])
                    T.writes(T_softmax_maxelem[i0_1, i1_1])
                    with T.init():
                        T_softmax_maxelem[i0_1, i1_1] = T.float32(-3.4028234663852886e+38)
                    T_softmax_maxelem[i0_1, i1_1] = T.max(T_softmax_maxelem[i0_1, i1_1], rxplaceholder[i0_1, i1_1, k])
            for i0, i1, i2 in T.grid(a, b, c):
                with T.block("T_softmax_exp"):
                    i0_2, i1_2, i2_1 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(rxplaceholder[i0_2, i1_2, i2_1], T_softmax_maxelem[i0_2, i1_2])
                    T.writes(T_softmax_exp[i0_2, i1_2, i2_1])
                    T_softmax_exp[i0_2, i1_2, i2_1] = T.exp(rxplaceholder[i0_2, i1_2, i2_1] - T_softmax_maxelem[i0_2, i1_2], dtype="float32")
            for i0_3, i1_3, i2 in T.grid(a, b, c):
                with T.block("T_softmax_expsum"):
                    i0_4, i1_4, k = T.axis.remap("SSR", [i0_3, i1_3, i2])
                    T.reads(T_softmax_exp[i0_4, i1_4, k])
                    T.writes(T_softmax_expsum[i0_4, i1_4])
                    with T.init():
                        T_softmax_expsum[i0_4, i1_4] = T.float32(0)
                    T_softmax_expsum[i0_4, i1_4] = T_softmax_expsum[i0_4, i1_4] + T_softmax_exp[i0_4, i1_4, k]
            for i0_5, i1_5, i2 in T.grid(a, b, c):
                with T.block("T_softmax_norm"):
                    i0_6, i1_6, i2_2 = T.axis.remap("SSS", [i0_5, i1_5, i2])
                    T.reads(T_softmax_exp[i0_6, i1_6, i2_2], T_softmax_expsum[i0_6, i1_6])
                    T.writes(T_softmax_norm[i0_6, i1_6, i2_2])
                    T.block_attr({"axis":2})
                    T_softmax_norm[i0_6, i1_6, i2_2] = T_softmax_exp[i0_6, i1_6, i2_2] / T_softmax_expsum[i0_6, i1_6]
    # fmt: on
    mod = LegalizeOps()(Softmax)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_log_softmax():
    """Legalize R.nn.log_softmax over axis=-2 of a static (2, 3, 16, 32) tensor.

    The Expected prim_func computes the stable form
        x - max(x) - log(sum(exp(x - max(x))))
    where both reductions run over axis -2 (size 16), so the reduction
    buffers drop that dimension. Expected is compiler-printed TVMScript
    and must stay textually exact for assert_structural_equal.
    """
    # fmt: off
    @tvm.script.ir_module
    class LogSoftmax:
        @R.function
        def main(x: R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor(None, "float32", ndim=4):
            gv: R.Tensor((2, 3, 16, 32), "float32") = R.nn.log_softmax(x, axis=-2)
            return gv
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3, 16, 32), dtype="float32")) -> R.Tensor((2, 3, 16, 32), dtype="float32"):
            gv = R.call_tir(Expected.log_softmax, (x,), R.Tensor((2, 3, 16, 32), dtype="float32"))
            return gv
        @T.prim_func(private=True)
        def log_softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"),):
            T.func_attr({"tir.noalias": True})
            T_softmax_maxelem = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32")
            compute_1 = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32")
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)):
                with T.block("T_softmax_maxelem"):
                    i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[i0_1, i1_1, k, i2_1])
                    T.writes(T_softmax_maxelem[i0_1, i1_1, i2_1])
                    with T.init():
                        T_softmax_maxelem[i0_1, i1_1, i2_1] = T.float32(-3.4028234663852886e38)
                    T_softmax_maxelem[i0_1, i1_1, i2_1] = T.max(T_softmax_maxelem[i0_1, i1_1, i2_1], rxplaceholder[i0_1, i1_1, k, i2_1])
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)):
                with T.block("compute"):
                    i0_2, i1_2, i2_2, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[i0_2, i1_2, k, i2_2], T_softmax_maxelem[i0_2, i1_2, i2_2])
                    T.writes(compute_1[i0_2, i1_2, i2_2])
                    with T.init():
                        compute_1[i0_2, i1_2, i2_2] = T.float32(0)
                    compute_1[i0_2, i1_2, i2_2] = compute_1[i0_2, i1_2, i2_2] + T.exp(rxplaceholder[i0_2, i1_2, k, i2_2] - T_softmax_maxelem[i0_2, i1_2, i2_2], dtype="float32")
            for i0_3, i1_3, i2_3, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)):
                with T.block("compute_1"):
                    i0_4, i1_4, i2_4, i3_1 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3])
                    T.reads(rxplaceholder[i0_4, i1_4, i2_4, i3_1], T_softmax_maxelem[i0_4, i1_4, i3_1], compute_1[i0_4, i1_4, i3_1])
                    T.writes(compute[i0_4, i1_4, i2_4, i3_1])
                    T.block_attr({"axis": 2})
                    compute[i0_4, i1_4, i2_4, i3_1] = (rxplaceholder[i0_4, i1_4, i2_4, i3_1] - T_softmax_maxelem[i0_4, i1_4, i3_1] - T.log(compute_1[i0_4, i1_4, i3_1], dtype="float32"))
    # fmt: on
    mod = LegalizeOps()(LogSoftmax)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_log_softmax_symbolic():
    """Legalize R.nn.log_softmax (default axis=-1) on a symbolic (a, b, c) tensor.

    Same stable x - max - log(sum(exp(...))) lowering as test_log_softmax,
    but reducing over the last axis with T.match_buffer-bound symbolic
    shapes. Expected is compiler-printed TVMScript and must stay exact.
    """
    # fmt: off
    @tvm.script.ir_module
    class LogSoftmax:
        @R.function
        def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"):
            a = T.int64()
            b = T.int64()
            c = T.int64()
            gv: R.Tensor((a, b, c), "float32") = R.nn.log_softmax(x)
            return gv
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="float32"):
            a = T.int64()
            b = T.int64()
            c = T.int64()
            # block 0
            gv = R.call_tir(Expected.log_softmax, (x,), R.Tensor((a, b, c), dtype="float32"))
            return gv
        @T.prim_func(private=True)
        def log_softmax(var_rxplaceholder: T.handle, var_compute: T.handle):
            T.func_attr({"tir.noalias": True})
            a = T.int64()
            b = T.int64()
            c = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32")
            compute = T.match_buffer(var_compute, [a, b, c], dtype="float32")
            T_softmax_maxelem = T.alloc_buffer([a, b], dtype="float32")
            compute_1 = T.alloc_buffer([a, b], dtype="float32")
            for i0, i1, k in T.grid(a, b, c):
                with T.block("T_softmax_maxelem"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(rxplaceholder[v_i0, v_i1, v_k])
                    T.writes(T_softmax_maxelem[v_i0, v_i1])
                    with T.init():
                        T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e38)
                    T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], rxplaceholder[v_i0, v_i1, v_k])
            for i0, i1, k in T.grid(a, b, c):
                with T.block("compute"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(rxplaceholder[v_i0, v_i1, v_k], T_softmax_maxelem[v_i0, v_i1])
                    T.writes(compute_1[v_i0, v_i1])
                    with T.init():
                        compute_1[v_i0, v_i1] = T.float32(0)
                    compute_1[v_i0, v_i1] = compute_1[v_i0, v_i1] + T.exp(rxplaceholder[v_i0, v_i1, v_k] - T_softmax_maxelem[v_i0, v_i1], dtype="float32")
            for i0, i1, i2 in T.grid(a, b, c):
                with T.block("compute_1"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(rxplaceholder[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1], compute_1[v_i0, v_i1],)
                    T.writes(compute[v_i0, v_i1, v_i2])
                    T.block_attr({"axis": 2})
                    compute[v_i0, v_i1, v_i2] = (rxplaceholder[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1] - T.log(compute_1[v_i0, v_i1], dtype="float32"))
    # fmt: on
    mod = LegalizeOps()(LogSoftmax)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_cross_entropy_with_logits():
    """Legalize R.nn.cross_entropy_with_logits on 1-D (3,) inputs.

    For unbatched inputs the op lowers to -sum(x * y): an elementwise
    multiply, a full reduction to a scalar, and a final negation
    (multiply by -1). Expected is compiler-printed TVMScript and must
    stay textually exact for assert_structural_equal.
    """
    # fmt: off
    @tvm.script.ir_module
    class CrossEntropyWithLogits:
        @R.function
        def main(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")) -> R.Tensor(None, "float32", ndim=2):
            gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(x, y)
            return gv
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((3,), dtype="float32"), y: R.Tensor((3,), dtype="float32")):
            gv = R.call_tir(Expected.cross_entropy_with_logits, (x, y), R.Tensor((), dtype="float32"))
            return gv
        @T.prim_func(private=True)
        def cross_entropy_with_logits(rxplaceholder: T.Buffer(T.int64(3), "float32"), rxplaceholder_1: T.Buffer(T.int64(3), "float32"), T_multiply: T.Buffer((), "float32")):
            T.func_attr({"tir.noalias": True})
            T_multiply_1 = T.alloc_buffer([T.int64(3)], dtype="float32")
            T_multiply_red = T.alloc_buffer([], dtype="float32")
            for i0 in T.serial(T.int64(3)):
                with T.block("T_multiply"):
                    ax0 = T.axis.spatial(T.int64(3), i0)
                    T.reads(rxplaceholder[ax0], rxplaceholder_1[ax0])
                    T.writes(T_multiply_1[ax0])
                    T_multiply_1[ax0] = rxplaceholder[ax0] * rxplaceholder_1[ax0]
            for i0 in T.serial(T.int64(3)):
                with T.block("T_multiply_red"):
                    k0 = T.axis.reduce(T.int64(3), i0)
                    T.reads(T_multiply_1[k0])
                    T.writes(T_multiply_red[()])
                    with T.init():
                        T_multiply_red[()] = T.float32(0)
                    T_multiply_red[()] = T_multiply_red[()] + T_multiply_1[k0]
            with T.block("T_multiply_1"):
                vi = T.axis.spatial(1, T.int64(0))
                T.reads(T_multiply_red[()])
                T.writes(T_multiply[()])
                T_multiply[()] = T_multiply_red[()] * T.float32(-1)
    # fmt: on
    mod = LegalizeOps()(CrossEntropyWithLogits)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_cross_entropy_with_logits_batch():
    """Legalize R.nn.cross_entropy_with_logits on batched (2, 3) inputs.

    With a batch dimension the op lowers to -sum(x * y) / batch_size; the
    final divide-by-2 is folded into a multiply by 0.5 in the printed TIR.
    Expected is compiler-printed TVMScript and must stay textually exact.
    """
    # fmt: off
    @tvm.script.ir_module
    class CrossEntropyWithLogits:
        @R.function
        def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2):
            gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(x, y)
            return gv
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
            gv = R.call_tir(Expected.cross_entropy_with_logits, (x, y), R.Tensor((), dtype="float32"))
            return gv
        @T.prim_func(private=True)
        def cross_entropy_with_logits(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((), "float32")):
            T.func_attr({"tir.noalias": True})
            T_multiply = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32")
            T_multiply_red = T.alloc_buffer([], dtype="float32")
            T_multiply_1 = T.alloc_buffer([], dtype="float32")
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, ax1])
                    T.writes(T_multiply[ax0, ax1])
                    T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * rxplaceholder_1[ax0, ax1]
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply_red"):
                    k0, k1 = T.axis.remap("RR", [i0, i1])
                    T.reads(T_multiply[k0, k1])
                    T.writes(T_multiply_red[()])
                    with T.init():
                        T_multiply_red[()] = T.float32(0)
                    T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0, k1]
            with T.block("T_multiply_1"):
                vi = T.axis.spatial(1, T.int64(0))
                T.reads(T_multiply_red[()])
                T.writes(T_multiply_1[()])
                T_multiply_1[()] = T_multiply_red[()] * T.float32(-1)
            with T.block("T_divide"):
                vi = T.axis.spatial(1, T.int64(0))
                T.reads(T_multiply_1[()])
                T.writes(T_divide[()])
                T_divide[()] = T_multiply_1[()] * T.float32(0.5)
    # fmt: on
    mod = LegalizeOps()(CrossEntropyWithLogits)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_cross_entropy_with_logits_batch_symbolic():
    """Legalize R.nn.cross_entropy_with_logits on symbolic-batch (n, m) inputs.

    Same -sum(x * y) / batch lowering as the static batch test, but the
    batch size n is symbolic, so the division appears as an explicit
    divide by T.Cast("float32", n) instead of a folded constant multiply.
    Expected is compiler-printed TVMScript and must stay textually exact.
    """
    # fmt: off
    @tvm.script.ir_module
    class CrossEntropyWithLogits:
        @R.function
        def main(x: R.Tensor(("n", "m"), "float32"), y: R.Tensor(("n", "m"), "float32")) -> R.Tensor(None, "float32", ndim=2):
            n = T.int64()
            m = T.int64()
            gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(x, y)
            return gv
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("n", "m"), dtype="float32"), y: R.Tensor(("n", "m"), dtype="float32")):
            gv = R.call_tir(Expected.cross_entropy_with_logits, (x, y), R.Tensor((), dtype="float32"))
            return gv
        @T.prim_func(private=True)
        def cross_entropy_with_logits(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, T_divide: T.Buffer((), "float32")):
            T.func_attr({"tir.noalias": True})
            m = T.int64()
            n = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32")
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32")
            T_multiply = T.alloc_buffer([n, m], dtype="float32")
            T_multiply_red = T.alloc_buffer([], dtype="float32")
            T_multiply_1 = T.alloc_buffer([], dtype="float32")
            for ax0, ax1 in T.grid(n, m):
                with T.block("T_multiply"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax0, v_ax1])
                    T.writes(T_multiply[v_ax0, v_ax1])
                    T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * rxplaceholder_1[v_ax0, v_ax1]
            for k0, k1 in T.grid(n, m):
                with T.block("T_multiply_red"):
                    v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
                    T.reads(T_multiply[v_k0, v_k1])
                    T.writes(T_multiply_red[()])
                    with T.init():
                        T_multiply_red[()] = T.float32(0)
                    T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, v_k1]
            with T.block("T_multiply_1"):
                vi = T.axis.spatial(1, T.int64(0))
                T.reads(T_multiply_red[()])
                T.writes(T_multiply_1[()])
                T_multiply_1[()] = T_multiply_red[()] * T.float32(-1)
            with T.block("T_divide"):
                vi = T.axis.spatial(1, T.int64(0))
                T.reads(T_multiply_1[()])
                T.writes(T_divide[()])
                T_divide[()] = T_multiply_1[()] / T.Cast("float32", n)
    # fmt: on
    mod = LegalizeOps()(CrossEntropyWithLogits)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_batch_norm():
    """Legalize R.nn.batch_norm (axis=1, training form) on a (2, 3, 28, 28) input.

    The op returns a 3-tuple: the normalized output
        (x - mean) / sqrt(var + eps) * gamma + beta
    plus updated moving_mean and moving_var computed from the batch
    statistics with momentum 0.1 (new = 0.9 * old + 0.1 * batch_stat).
    The Expected prim_func therefore contains both the normalization
    chain (reshape/subtract/divide/multiply/add over the channel axis)
    and the batch mean/variance reductions over axes (0, 2, 3); the
    constant 0.00063775... is 1/1568 = 1/(2*28*28), the reduction size.
    Expected is compiler-printed TVMScript and must stay textually exact
    for assert_structural_equal.
    """
    # fmt: off
    @tvm.script.ir_module
    class BatchNorm:
        @R.function
        def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32"), beta: R.Tensor((3,), "float32"), moving_mean: R.Tensor((3,), "float32"), moving_var: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")):
            gv: R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1)
            return gv
    @tvm.script.ir_module
    class Expected:
        @T.prim_func(private=True)
        def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle):
            T.func_attr({"tir.noalias": True})
            x = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(28), T.int64(28)))
            gamma = T.match_buffer(var_gamma, (T.int64(3),))
            beta = T.match_buffer(var_beta, (T.int64(3),))
            moving_mean = T.match_buffer(var_moving_mean, (T.int64(3),))
            moving_var = T.match_buffer(var_moving_var, (T.int64(3),))
            T_add = T.match_buffer(var_T_add, (T.int64(2), T.int64(3), T.int64(28), T.int64(28)))
            T_add_1 = T.match_buffer(var_T_add_1, (T.int64(3),))
            T_add_2 = T.match_buffer(var_T_add_2, (T.int64(3),))
            with T.block("root"):
                T.reads()
                T.writes()
                T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1)))
                T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)))
                T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1)))
                T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1)))
                compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1)))
                T_divide = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)))
                T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1)))
                T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)))
                T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1)))
                T_multiply_1 = T.alloc_buffer((T.int64(3),))
                x_red = T.alloc_buffer((T.int64(3),))
                T_divide_1 = T.alloc_buffer((T.int64(3),))
                T_multiply_2 = T.alloc_buffer((T.int64(3),))
                T_multiply_3 = T.alloc_buffer((T.int64(3),))
                T_reshape_4 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1)))
                T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)))
                T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)))
                T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)))
                T_multiply_red = T.alloc_buffer((T.int64(3),))
                T_divide_2 = T.alloc_buffer((T.int64(3),))
                T_multiply_5 = T.alloc_buffer((T.int64(3),))
                for ax0 in range(T.int64(1)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_reshape"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)])
                                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
                for ax0 in range(T.int64(2)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(28)):
                            for ax3 in range(T.int64(28)):
                                with T.block("T_subtract"):
                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(T.int64(1)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_reshape_1"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)])
                                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
                for ax0 in range(T.int64(1)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_add"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05)
                for i0 in range(T.int64(1)):
                    for i1 in range(T.int64(3)):
                        for i2 in range(T.int64(1)):
                            for i3 in range(T.int64(1)):
                                with T.block("compute"):
                                    v_i0 = T.axis.spatial(T.int64(1), i0)
                                    v_i1 = T.axis.spatial(T.int64(3), i1)
                                    v_i2 = T.axis.spatial(T.int64(1), i2)
                                    v_i3 = T.axis.spatial(T.int64(1), i3)
                                    T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3])
                                    T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                                    compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3])
                for ax0 in range(T.int64(2)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(28)):
                            for ax3 in range(T.int64(28)):
                                with T.block("T_divide"):
                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
                                    T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(T.int64(1)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_reshape_2"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)])
                                    T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
                for ax0 in range(T.int64(2)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(28)):
                            for ax3 in range(T.int64(28)):
                                with T.block("T_multiply"):
                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
                                    T.reads(T_divide[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(T.int64(1)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_reshape_3"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)])
                                    T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
                for ax0 in range(T.int64(2)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(28)):
                            for ax3 in range(T.int64(28)):
                                with T.block("T_add_1"):
                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
                                    T.reads(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(T.int64(3)):
                    with T.block("T_multiply_1"):
                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
                        T.reads(moving_mean[v_ax0])
                        T.writes(T_multiply_1[v_ax0])
                        T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0]
                for ax0 in range(T.int64(3)):
                    for k0 in range(T.int64(2)):
                        for k2 in range(T.int64(28)):
                            for k3 in range(T.int64(28)):
                                with T.block("x_red"):
                                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
                                    v_k0 = T.axis.reduce(T.int64(2), k0)
                                    v_k2 = T.axis.reduce(T.int64(28), k2)
                                    v_k3 = T.axis.reduce(T.int64(28), k3)
                                    T.reads(x[v_k0, v_ax0, v_k2, v_k3])
                                    T.writes(x_red[v_ax0])
                                    with T.init():
                                        x_red[v_ax0] = T.float32(0.0)
                                    x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3]
                for ax0 in range(T.int64(3)):
                    with T.block("T_divide_1"):
                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
                        T.reads(x_red[v_ax0])
                        T.writes(T_divide_1[v_ax0])
                        T_divide_1[v_ax0] = x_red[v_ax0] * T.float32(0.00063775510204081628)
                for ax0 in range(T.int64(3)):
                    with T.block("T_multiply_2"):
                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
                        T.reads(T_divide_1[v_ax0])
                        T.writes(T_multiply_2[v_ax0])
                        T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0]
                for ax0 in range(T.int64(3)):
                    with T.block("T_add_2"):
                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
                        T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0])
                        T.writes(T_add_1[v_ax0])
                        T_add_1[v_ax0] = T_multiply_1[v_ax0] + T_multiply_2[v_ax0]
                for ax0 in range(T.int64(3)):
                    with T.block("T_multiply_3"):
                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
                        T.reads(moving_var[v_ax0])
                        T.writes(T_multiply_3[v_ax0])
                        T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0]
                for ax0 in range(T.int64(1)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_reshape_4"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)])
                                    T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
                for ax0 in range(T.int64(2)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(28)):
                            for ax3 in range(T.int64(28)):
                                with T.block("T_subtract_1"):
                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(T.int64(2)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(28)):
                            for ax3 in range(T.int64(28)):
                                with T.block("T_subtract_2"):
                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(T.int64(2)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(28)):
                            for ax3 in range(T.int64(28)):
                                with T.block("T_multiply_4"):
                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
                                    T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]
                for ax0 in range(T.int64(3)):
                    for k0 in range(T.int64(2)):
                        for k2 in range(T.int64(28)):
                            for k3 in range(T.int64(28)):
                                with T.block("T_multiply_red"):
                                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
                                    v_k0 = T.axis.reduce(T.int64(2), k0)
                                    v_k2 = T.axis.reduce(T.int64(28), k2)
                                    v_k3 = T.axis.reduce(T.int64(28), k3)
                                    T.reads(T_multiply_4[v_k0, v_ax0, v_k2, v_k3])
                                    T.writes(T_multiply_red[v_ax0])
                                    with T.init():
                                        T_multiply_red[v_ax0] = T.float32(0.0)
                                    T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3]
                for ax0 in range(T.int64(3)):
                    with T.block("T_divide_2"):
                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
                        T.reads(T_multiply_red[v_ax0])
                        T.writes(T_divide_2[v_ax0])
                        T_divide_2[v_ax0] = T_multiply_red[v_ax0] * T.float32(0.00063775510204081628)
                for ax0 in range(T.int64(3)):
                    with T.block("T_multiply_5"):
                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
                        T.reads(T_divide_2[v_ax0])
                        T.writes(T_multiply_5[v_ax0])
                        T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_2[v_ax0]
                for ax0 in range(T.int64(3)):
                    with T.block("T_add_3"):
                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
                        T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0])
                        T.writes(T_add_2[v_ax0])
                        T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0]
        @R.function
        def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"), moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")):
            cls = Expected
            gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")])
            return gv
    # fmt: on
    mod = LegalizeOps()(BatchNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_batch_norm_symbolic():
    """LegalizeOps lowers R.nn.batch_norm with fully symbolic dims to one call_tir.

    The input normalizes over axis=1 while the dims are named ("n","h","w","c"),
    so the per-channel buffers in the expected TIR are sized by `h` (the axis-1
    extent) and the updated moving-stat outputs get shape T.max(c, h).
    NOTE(review): T.max(c, h) arises because `gamma`/`beta`/moving stats are
    declared with extent `c` but the reduction axis has extent `h` — confirm
    this asymmetry is the intended legalization behavior under test.
    """
    # fmt: off
    @tvm.script.ir_module
    class BatchNorm:
        @R.function
        def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")):
            n = T.int64()
            h = T.int64()
            w = T.int64()
            c = T.int64()
            gv: R.Tuple(R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1)
            return gv

    @tvm.script.ir_module
    class Expected:
        @T.prim_func(private=True)
        def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle):
            T.func_attr({"tir.noalias": True})
            n, h, w, c = T.int64(), T.int64(), T.int64(), T.int64()
            x = T.match_buffer(var_x, (n, h, w, c))
            gamma = T.match_buffer(var_gamma, (c,))
            beta = T.match_buffer(var_beta, (c,))
            moving_mean = T.match_buffer(var_moving_mean, (c,))
            moving_var = T.match_buffer(var_moving_var, (c,))
            T_add = T.match_buffer(var_T_add, (n, h, w, c))
            T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),))
            T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),))
            with T.block("root"):
                T.reads()
                T.writes()
                T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
                T_subtract = T.alloc_buffer((n, h, w, c))
                T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
                T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
                compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
                T_divide = T.alloc_buffer((n, h, w, c))
                T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
                T_multiply = T.alloc_buffer((n, h, w, c))
                T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
                T_multiply_1 = T.alloc_buffer((c,))
                x_red = T.alloc_buffer((h,))
                T_divide_1 = T.alloc_buffer((h,))
                T_multiply_2 = T.alloc_buffer((h,))
                T_multiply_3 = T.alloc_buffer((c,))
                T_reshape_4 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
                T_subtract_1 = T.alloc_buffer((n, h, w, c))
                T_subtract_2 = T.alloc_buffer((n, h, w, c))
                T_multiply_4 = T.alloc_buffer((n, h, w, c))
                T_multiply_red = T.alloc_buffer((h,))
                T_divide_2 = T.alloc_buffer((h,))
                T_multiply_5 = T.alloc_buffer((h,))
                # --- normalized output: (x - mean) / sqrt(var + eps) * gamma + beta ---
                for ax0 in range(T.int64(1)):
                    for ax1 in range(h):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_reshape"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c])
                                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
                for ax0 in range(n):
                    for ax1 in range(h):
                        for ax2 in range(w):
                            for ax3 in range(c):
                                with T.block("T_subtract"):
                                    v_ax0 = T.axis.spatial(n, ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(w, ax2)
                                    v_ax3 = T.axis.spatial(c, ax3)
                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(T.int64(1)):
                    for ax1 in range(h):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_reshape_1"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c])
                                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
                for ax0 in range(T.int64(1)):
                    for ax1 in range(h):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_add"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3])
                                    # epsilon added before the square root
                                    T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05)
                for i0 in range(T.int64(1)):
                    for i1 in range(h):
                        for i2 in range(T.int64(1)):
                            for i3 in range(T.int64(1)):
                                with T.block("compute"):
                                    v_i0 = T.axis.spatial(T.int64(1), i0)
                                    v_i1 = T.axis.spatial(h, i1)
                                    v_i2 = T.axis.spatial(T.int64(1), i2)
                                    v_i3 = T.axis.spatial(T.int64(1), i3)
                                    T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3])
                                    T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                                    compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3])
                for ax0 in range(n):
                    for ax1 in range(h):
                        for ax2 in range(w):
                            for ax3 in range(c):
                                with T.block("T_divide"):
                                    v_ax0 = T.axis.spatial(n, ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(w, ax2)
                                    v_ax3 = T.axis.spatial(c, ax3)
                                    T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(T.int64(1)):
                    for ax1 in range(h):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_reshape_2"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c])
                                    T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
                for ax0 in range(n):
                    for ax1 in range(h):
                        for ax2 in range(w):
                            for ax3 in range(c):
                                with T.block("T_multiply"):
                                    v_ax0 = T.axis.spatial(n, ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(w, ax2)
                                    v_ax3 = T.axis.spatial(c, ax3)
                                    T.reads(T_divide[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(T.int64(1)):
                    for ax1 in range(h):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_reshape_3"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c])
                                    T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
                for ax0 in range(n):
                    for ax1 in range(h):
                        for ax2 in range(w):
                            for ax3 in range(c):
                                with T.block("T_add_1"):
                                    v_ax0 = T.axis.spatial(n, ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(w, ax2)
                                    v_ax3 = T.axis.spatial(c, ax3)
                                    T.reads(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                # --- updated moving mean: momentum * old + (1 - momentum) * batch mean ---
                for ax0 in range(c):
                    with T.block("T_multiply_1"):
                        v_ax0 = T.axis.spatial(c, ax0)
                        T.reads(moving_mean[v_ax0])
                        T.writes(T_multiply_1[v_ax0])
                        T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0]
                for ax0 in range(h):
                    for k0 in range(n):
                        for k2 in range(w):
                            for k3 in range(c):
                                with T.block("x_red"):
                                    v_ax0 = T.axis.spatial(h, ax0)
                                    v_k0 = T.axis.reduce(n, k0)
                                    v_k2 = T.axis.reduce(w, k2)
                                    v_k3 = T.axis.reduce(c, k3)
                                    T.reads(x[v_k0, v_ax0, v_k2, v_k3])
                                    T.writes(x_red[v_ax0])
                                    with T.init():
                                        x_red[v_ax0] = T.float32(0.0)
                                    x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3]
                for ax0 in range(h):
                    with T.block("T_divide_1"):
                        v_ax0 = T.axis.spatial(h, ax0)
                        T.reads(x_red[v_ax0])
                        T.writes(T_divide_1[v_ax0])
                        T_divide_1[v_ax0] = x_red[v_ax0] / T.Cast("float32", n * w * c)
                for ax0 in range(h):
                    with T.block("T_multiply_2"):
                        v_ax0 = T.axis.spatial(h, ax0)
                        T.reads(T_divide_1[v_ax0])
                        T.writes(T_multiply_2[v_ax0])
                        T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0]
                for ax0 in range(T.max(c, h)):
                    with T.block("T_add_2"):
                        v_ax0 = T.axis.spatial(T.max(c, h), ax0)
                        T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0])
                        T.writes(T_add_1[v_ax0])
                        T_add_1[v_ax0] = T_multiply_1[v_ax0] + T_multiply_2[v_ax0]
                # --- updated moving variance: momentum * old + (1 - momentum) * batch var ---
                for ax0 in range(c):
                    with T.block("T_multiply_3"):
                        v_ax0 = T.axis.spatial(c, ax0)
                        T.reads(moving_var[v_ax0])
                        T.writes(T_multiply_3[v_ax0])
                        T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0]
                for ax0 in range(T.int64(1)):
                    for ax1 in range(h):
                        for ax2 in range(T.int64(1)):
                            for ax3 in range(T.int64(1)):
                                with T.block("T_reshape_4"):
                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
                                    T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h])
                                    T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
                for ax0 in range(n):
                    for ax1 in range(h):
                        for ax2 in range(w):
                            for ax3 in range(c):
                                with T.block("T_subtract_1"):
                                    v_ax0 = T.axis.spatial(n, ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(w, ax2)
                                    v_ax3 = T.axis.spatial(c, ax3)
                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(n):
                    for ax1 in range(h):
                        for ax2 in range(w):
                            for ax3 in range(c):
                                with T.block("T_subtract_2"):
                                    v_ax0 = T.axis.spatial(n, ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(w, ax2)
                                    v_ax3 = T.axis.spatial(c, ax3)
                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                    T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
                for ax0 in range(n):
                    for ax1 in range(h):
                        for ax2 in range(w):
                            for ax3 in range(c):
                                with T.block("T_multiply_4"):
                                    v_ax0 = T.axis.spatial(n, ax0)
                                    v_ax1 = T.axis.spatial(h, ax1)
                                    v_ax2 = T.axis.spatial(w, ax2)
                                    v_ax3 = T.axis.spatial(c, ax3)
                                    T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]
                for ax0 in range(h):
                    for k0 in range(n):
                        for k2 in range(w):
                            for k3 in range(c):
                                with T.block("T_multiply_red"):
                                    v_ax0 = T.axis.spatial(h, ax0)
                                    v_k0 = T.axis.reduce(n, k0)
                                    v_k2 = T.axis.reduce(w, k2)
                                    v_k3 = T.axis.reduce(c, k3)
                                    T.reads(T_multiply_4[v_k0, v_ax0, v_k2, v_k3])
                                    T.writes(T_multiply_red[v_ax0])
                                    with T.init():
                                        T_multiply_red[v_ax0] = T.float32(0.0)
                                    T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3]
                for ax0 in range(h):
                    with T.block("T_divide_2"):
                        v_ax0 = T.axis.spatial(h, ax0)
                        T.reads(T_multiply_red[v_ax0])
                        T.writes(T_divide_2[v_ax0])
                        T_divide_2[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c)
                for ax0 in range(h):
                    with T.block("T_multiply_5"):
                        v_ax0 = T.axis.spatial(h, ax0)
                        T.reads(T_divide_2[v_ax0])
                        T.writes(T_multiply_5[v_ax0])
                        T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_2[v_ax0]
                for ax0 in range(T.max(c, h)):
                    with T.block("T_add_3"):
                        v_ax0 = T.axis.spatial(T.max(c, h), ax0)
                        T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0])
                        T.writes(T_add_2[v_ax0])
                        T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0]

        @R.function
        def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32")):
            n = T.int64()
            h = T.int64()
            w = T.int64()
            c = T.int64()
            cls = Expected
            gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")])
            return gv
    # fmt: on

    mod = LegalizeOps()(BatchNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_layer_norm():
    """LegalizeOps lowers R.nn.layer_norm (static shapes, axes=[-2,-1]) to call_tir.

    The expected TIR fuses mean and mean-of-squares into one two-output
    reduction block, then applies (x - mu) * rsqrt(var + eps) * gamma + beta
    in a single elementwise block; 0.05 is 1/(4*5), the reciprocal of the
    reduced-element count.
    """
    # fmt: off
    @tvm.script.ir_module
    class LayerNorm:
        @R.function
        def main(x: R.Tensor((2, 3, 4, 5), "float32"), gamma: R.Tensor((4, 5), "float32"), beta: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"):
            gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.layer_norm(x, gamma, beta, axes=[-2, -1])
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3, 4, 5), "float32"), gamma: R.Tensor((4, 5), "float32"), beta: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"):
            gv = R.call_tir(Expected.layer_norm, (x, gamma, beta), R.Tensor((2, 3, 4, 5), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")):
            T.func_attr({"tir.noalias": True})
            rxplaceholder_red_temp_v0 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32")
            rxplaceholder_red_temp_v1 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32")
            # v0 accumulates sum(x), v1 accumulates sum(x*x) in one pass
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("rxplaceholder_red_temp"):
                    ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[ax0, ax1, k2, k3])
                    T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1])
                    with T.init():
                        rxplaceholder_red_temp_v0[ax0, ax1] = T.float32(0)
                        rxplaceholder_red_temp_v1[ax0, ax1] = T.float32(0)
                    v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3]
                    v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3] * rxplaceholder[ax0, ax1, k2, k3]
                    rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0
                    rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1
            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_layer_norm"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], rxplaceholder_1[ax2, ax3], rxplaceholder_2[ax2, ax3])
                    T.writes(T_layer_norm[ax0, ax1, ax2, ax3])
                    T_layer_norm[ax0, ax1, ax2, ax3] = (rxplaceholder[ax0, ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) * T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] * T.float32(0.05) - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05) * (rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax2, ax3] + rxplaceholder_2[ax2, ax3]
    # fmt: on

    mod = LegalizeOps()(LayerNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_layer_norm_1d():
    """LegalizeOps lowers a rank-1 R.nn.layer_norm (axes=[-1]) to call_tir.

    Exercises the degenerate case where the reduction consumes every axis, so
    the mean/variance accumulators are 0-d buffers (`T.alloc_buffer(())`).
    0.33333... is 1/3, the reciprocal of the single reduced extent.
    """
    # fmt: off
    @I.ir_module
    class LayerNorm_1D:
        @R.function
        def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                layer_norm: R.Tensor((3,), dtype="float32") = R.nn.layer_norm(x, layer_norm_weight, layer_norm_bias, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
                gv: R.Tensor((3,), dtype="float32") = layer_norm
                R.output(gv)
            return gv

    @I.ir_module
    class LayerNorm_1D_Expected:
        @T.prim_func(private=True)
        def layer_norm(x: T.Buffer((T.int64(3),), "float32"), layer_norm_weight: T.Buffer((T.int64(3),), "float32"), layer_norm_bias: T.Buffer((T.int64(3),), "float32"), T_layer_norm: T.Buffer((T.int64(3),), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            x_red_temp_v0 = T.alloc_buffer(())
            x_red_temp_v1 = T.alloc_buffer(())
            # single reduction producing sum(x) and sum(x*x) into scalar buffers
            for k0 in range(T.int64(3)):
                with T.block("x_red_temp"):
                    v_k0 = T.axis.reduce(T.int64(3), k0)
                    T.reads(x[v_k0])
                    T.writes(x_red_temp_v0[()], x_red_temp_v1[()])
                    with T.init():
                        x_red_temp_v0[()] = T.float32(0.0)
                        x_red_temp_v1[()] = T.float32(0.0)
                    v_x_red_temp_v0: T.float32 = x_red_temp_v0[()] + x[v_k0]
                    v_x_red_temp_v1: T.float32 = x_red_temp_v1[()] + x[v_k0] * x[v_k0]
                    x_red_temp_v0[()] = v_x_red_temp_v0
                    x_red_temp_v1[()] = v_x_red_temp_v1
            for ax0 in range(T.int64(3)):
                with T.block("T_layer_norm"):
                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
                    T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()], layer_norm_weight[v_ax0], layer_norm_bias[v_ax0])
                    T.writes(T_layer_norm[v_ax0])
                    T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] * T.float32(0.33333333333333331)) * T.rsqrt(x_red_temp_v1[()] * T.float32(0.33333333333333331) - x_red_temp_v0[()] * T.float32(0.33333333333333331) * (x_red_temp_v0[()] * T.float32(0.33333333333333331)) + T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0]

        @R.function
        def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = LayerNorm_1D_Expected
            with R.dataflow():
                layer_norm = R.call_tir(cls.layer_norm, (x, layer_norm_weight, layer_norm_bias), out_sinfo=R.Tensor((3,), dtype="float32"))
                gv: R.Tensor((3,), dtype="float32") = layer_norm
                R.output(gv)
            return gv
    # fmt: on

    mod = LegalizeOps()(LayerNorm_1D)
    tvm.ir.assert_structural_equal(mod, LayerNorm_1D_Expected)
def test_layer_norm_fp16():
    """LegalizeOps lowers a float16 R.nn.layer_norm with mixed-precision math.

    Checks that the legalized TIR accumulates the reductions in float32
    (inputs cast up via T.Cast), keeps the accumulator buffers float32, and
    casts the normalized value back to float16 before the gamma/beta affine.
    """
    # fmt: off
    @tvm.script.ir_module
    class LayerNorm:
        @R.function
        def main(x: R.Tensor((2, 3, 4, 5), "float16"), gamma: R.Tensor((4, 5), "float16"), beta: R.Tensor((4, 5), "float16")) -> R.Tensor((2, 3, 4, 5), "float16"):
            gv: R.Tensor((2, 3, 4, 5), "float16") = R.nn.layer_norm(x, gamma, beta, axes=[-2, -1])
            return gv

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle):
            T.func_attr({"tir.noalias": True})
            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16")
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4), T.int64(5)), "float16")
            rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4), T.int64(5)), "float16")
            T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16")
            with T.block("root"):
                T.reads()
                T.writes()
                # float32 accumulators despite float16 inputs
                rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(3)))
                rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(3)))
                for ax0 in range(T.int64(2)):
                    for ax1 in range(T.int64(3)):
                        for k2 in range(T.int64(4)):
                            for k3 in range(T.int64(5)):
                                with T.block("rxplaceholder_red_temp"):
                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_k2 = T.axis.reduce(T.int64(4), k2)
                                    v_k3 = T.axis.reduce(T.int64(5), k3)
                                    T.reads(rxplaceholder[v_ax0, v_ax1, v_k2, v_k3])
                                    T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1])
                                    with T.init():
                                        rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
                                        rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
                                    v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_k2, v_k3])
                                    v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) * T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_k2, v_k3])
                                    rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0
                                    rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1
                for ax0 in range(T.int64(2)):
                    for ax1 in range(T.int64(3)):
                        for ax2 in range(T.int64(4)):
                            for ax3 in range(T.int64(5)):
                                with T.block("T_layer_norm"):
                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                    v_ax2 = T.axis.spatial(T.int64(4), ax2)
                                    v_ax3 = T.axis.spatial(T.int64(5), ax3)
                                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], rxplaceholder_1[v_ax2, v_ax3], rxplaceholder_2[v_ax2, v_ax3])
                                    T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                                    T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16", (T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3]) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5)) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5))) + T.float32(1.0000000000000001e-05))) * rxplaceholder_1[v_ax2, v_ax3] + rxplaceholder_2[v_ax2, v_ax3]

        @R.function
        def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), gamma: R.Tensor((4, 5), dtype="float16"), beta: R.Tensor((4, 5), dtype="float16")) -> R.Tensor((2, 3, 4, 5), dtype="float16"):
            gv = R.call_tir(Expected.layer_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16"))
            return gv
    # fmt: on

    mod = LegalizeOps()(LayerNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_layer_norm_symbolic():
    """LegalizeOps lowers R.nn.layer_norm over symbolic (n, s, f) shapes.

    The mean divisor becomes the symbolic product
    T.Cast("float32", s) * T.Cast("float32", f) instead of a folded constant,
    since the reduced extents are unknown at compile time.
    """
    # fmt: off
    @tvm.script.ir_module
    class LayerNorm:
        @R.function
        def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "float32"), beta: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"):
            n = T.int64()
            s = T.int64()
            f = T.int64()
            gv: R.Tensor((n, s, f), "float32") = R.nn.layer_norm(x, gamma, beta, axes=[1, 2])
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "float32"), beta: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"):
            n = T.int64()
            s = T.int64()
            f = T.int64()
            gv = R.call_tir(Expected.layer_norm, (x, gamma, beta), R.Tensor((n, s, f), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle):
            T.func_attr({"tir.noalias": True})
            f = T.int64()
            n = T.int64()
            s = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [n, s, f], dtype="float32")
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [s, f], dtype="float32")
            rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [s, f], dtype="float32")
            T_layer_norm = T.match_buffer(var_T_layer_norm, [n, s, f], dtype="float32")
            rxplaceholder_red_temp_v0 = T.alloc_buffer([n], dtype="float32")
            rxplaceholder_red_temp_v1 = T.alloc_buffer([n], dtype="float32")
            # fused sum / sum-of-squares reduction over the (s, f) axes
            for i0, i1, i2 in T.grid(n, s, f):
                with T.block("rxplaceholder_red_temp"):
                    ax0, k1, k2 = T.axis.remap("SRR", [i0, i1, i2])
                    T.reads(rxplaceholder[ax0, k1, k2])
                    T.writes(rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0])
                    with T.init():
                        rxplaceholder_red_temp_v0[ax0] = T.float32(0)
                        rxplaceholder_red_temp_v1[ax0] = T.float32(0)
                    v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[ax0] + rxplaceholder[ax0, k1, k2]
                    v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[ax0] + rxplaceholder[ax0, k1, k2] * rxplaceholder[ax0, k1, k2]
                    rxplaceholder_red_temp_v0[ax0] = v_rxplaceholder_red_temp_v0
                    rxplaceholder_red_temp_v1[ax0] = v_rxplaceholder_red_temp_v1
            for i0, i1, i2 in T.grid(n, s, f):
                with T.block("T_layer_norm"):
                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, ax2])
                    T.writes(T_layer_norm[ax0, ax1, ax2])
                    T_layer_norm[ax0, ax1, ax2] = (rxplaceholder[ax0, ax1, ax2] - rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f))) * T.rsqrt(rxplaceholder_red_temp_v1[ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) - rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) * (rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f))) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax1, ax2] + rxplaceholder_2[ax1, ax2]
    # fmt: on

    mod = LegalizeOps()(LayerNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_group_norm():
    """LegalizeOps lowers R.nn.group_norm (2 groups over channel axis 1) to call_tir.

    The expected TIR reshapes (2,4,4,5) -> (2,2,2,4,5) to split channels into
    groups, reduces mean/mean-of-squares per (batch, group), normalizes with
    reshaped gamma/beta, and reshapes back. 0.025 is 1/(2*4*5), the reciprocal
    of the per-group element count.
    """
    # fmt: off
    @tvm.script.ir_module
    class GroupNorm:
        @R.function
        def main(x: R.Tensor((2, 4, 4, 5), "float32"), gamma: R.Tensor((4,), "float32"), beta: R.Tensor((4,), "float32")) -> R.Tensor((2, 4, 4, 5), "float32"):
            gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3])
            return gv

    @tvm.script.ir_module
    class Expected:
        @T.prim_func(private=True)
        def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4),), "float32"), rxplaceholder_2: T.Buffer((T.int64(4),), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32")):
            T.func_attr({"tir.noalias": True})
            T_reshape_1 = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)))
            rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(2)))
            rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(2)))
            T_reshape_2 = T.alloc_buffer((T.int64(2), T.int64(2)))
            T_reshape_3 = T.alloc_buffer((T.int64(2), T.int64(2)))
            T_group_norm = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)))
            # split channel axis into (group, channel-in-group)
            for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)):
                with T.block("T_reshape"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                    T.reads(rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)])
                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)]
            # per-(batch, group) sum and sum-of-squares in one reduction
            for ax0, ax1, k2, k3, k4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)):
                with T.block("rxplaceholder_red_temp"):
                    v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4])
                    T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4])
                    T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1])
                    with T.init():
                        rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
                        rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
                    v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]
                    v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]
                    rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0
                    rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1
            for ax0, ax1 in T.grid(T.int64(2), T.int64(2)):
                with T.block("T_reshape_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)])
                    T.writes(T_reshape_2[v_ax0, v_ax1])
                    T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]
            for ax0, ax1 in T.grid(T.int64(2), T.int64(2)):
                with T.block("T_reshape_2"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)])
                    T.writes(T_reshape_3[v_ax0, v_ax1])
                    T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]
            for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)):
                with T.block("T_group_norm"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                    T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2])
                    T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2]
            # merge (group, channel-in-group) back into the original channel axis
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)):
                with T.block("T_reshape_3"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)])
                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)]

        @R.function
        def main(x: R.Tensor((2, 4, 4, 5), dtype="float32"), gamma: R.Tensor((4,), dtype="float32"), beta: R.Tensor((4,), dtype="float32")) -> R.Tensor((2, 4, 4, 5), dtype="float32"):
            gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 4, 4, 5), dtype="float32"))
            return gv
    # fmt: on

    mod = LegalizeOps()(GroupNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_group_norm_fp16():
    """Legalize R.nn.group_norm on float16 input.

    The expected TIR widens the reshaped input to float32 (block "T_cast"),
    accumulates the sum / sum-of-squares reduction in float32, and casts the
    normalized value back to float16 in block "T_group_norm".
    """
    # fmt: off
    @tvm.script.ir_module
    class GroupNorm:
        @R.function
        def main(x: R.Tensor((2, 4, 4, 5), "float16"), gamma: R.Tensor((4,), "float16"), beta: R.Tensor((4,), "float16")) -> R.Tensor((2, 4, 4, 5), "float16"):
            gv: R.Tensor((2, 4, 4, 5), "float16") = R.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3])
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 4, 4, 5), dtype="float16"), gamma: R.Tensor((4,), dtype="float16"), beta: R.Tensor((4,), dtype="float16")) -> R.Tensor((2, 4, 4, 5), dtype="float16"):
            gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 4, 4, 5), dtype="float16"))
            return gv

        @T.prim_func(private=True)
        def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float16"), rxplaceholder_1: T.Buffer((T.int64(4),), "float16"), rxplaceholder_2: T.Buffer((T.int64(4),), "float16"), T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float16")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            T_reshape_1 = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)), "float16")
            T_cast = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)))
            rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(2)))
            rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(2)))
            T_reshape_2 = T.alloc_buffer((T.int64(2), T.int64(2)), "float16")
            T_reshape_3 = T.alloc_buffer((T.int64(2), T.int64(2)), "float16")
            T_group_norm = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)), "float16")
            # Split the channel axis into (num_groups=2, channels_per_group=2).
            for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)):
                with T.block("T_reshape"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                    T.reads(rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)])
                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)]
            # Widen fp16 -> fp32 so the reduction below runs in float32.
            for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)):
                with T.block("T_cast"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                    T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.Cast("float32", T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            # Fused reduction: v0 accumulates sum(x), v1 accumulates sum(x*x).
            for ax0, ax1, k2, k3, k4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)):
                with T.block("rxplaceholder_red_temp"):
                    v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4])
                    T.reads(T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4])
                    T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1])
                    with T.init():
                        rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
                        rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
                    v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4]
                    v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4]
                    rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0
                    rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1
            # Reshape gamma (4,) -> (2, 2) to match the (group, channel) split.
            for ax0, ax1 in T.grid(T.int64(2), T.int64(2)):
                with T.block("T_reshape_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)])
                    T.writes(T_reshape_2[v_ax0, v_ax1])
                    T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]
            # Reshape beta (4,) -> (2, 2).
            for ax0, ax1 in T.grid(T.int64(2), T.int64(2)):
                with T.block("T_reshape_2"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)])
                    T.writes(T_reshape_3[v_ax0, v_ax1])
                    T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]
            # Normalize: 0.025 == 1/(2*4*5), the element count of one group.
            # The fp32 result is cast back to float16 before scale/shift.
            for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)):
                with T.block("T_group_norm"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                    T.reads(T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2])
                    T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.Cast("float16", (T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05))) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2]
            # Merge the (group, channel) axes back to the original 4-D layout.
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)):
                with T.block("T_reshape_3"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)])
                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)]
    # fmt: on
    mod = LegalizeOps()(GroupNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_group_norm_symbolic():
    """Legalize R.nn.group_norm with symbolic shapes.

    The channel count is the symbolic expression ``4 * c``; ``c`` comes in
    through a shape argument and is forwarded to the PrimFunc via
    ``tir_vars=R.shape([c])``.  The mean/variance divisor is the symbolic
    product ``c * h * w`` instead of a constant.
    """
    # fmt: off
    @tvm.script.ir_module
    class GroupNorm:
        @R.function
        def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), "float32"), gamma: R.Tensor(("4 * c",), "float32"), beta: R.Tensor(("4 * c",), "float32")) -> R.Tensor(("n", "4 * c", "h", "w"), "float32"):
            n = T.int64()
            c = T.int64()
            h = T.int64()
            w = T.int64()
            gv: R.Tensor((n, 4 * c, h, w), "float32") = R.nn.group_norm(x, gamma, beta, num_groups=4, channel_axis=1, axes=[2, 3])
            return gv

    @tvm.script.ir_module
    class Expected:
        @T.prim_func(private=True)
        def group_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_reshape: T.handle, c: T.int64):
            T.func_attr({"tir.noalias": True})
            n = T.int64()
            h = T.int64()
            w = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, (n, T.int64(4) * c, h, w))
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4) * c,))
            rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4) * c,))
            T_reshape = T.match_buffer(var_T_reshape, (n, T.int64(4) * c, h, w))
            # with T.block("root"):
            T_reshape_1 = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w))
            rxplaceholder_red_temp_v0 = T.alloc_buffer((n, T.int64(4)))
            rxplaceholder_red_temp_v1 = T.alloc_buffer((n, T.int64(4)))
            T_reshape_2 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4)))
            T_reshape_3 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4)))
            T_group_norm = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w))
            # Split channels 4*c -> (num_groups=4, c per group).
            for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w):
                with T.block("T_reshape"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                    T.reads(rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w])
                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w]
            # Fused reduction: v0 accumulates sum(x), v1 accumulates sum(x*x).
            for ax0, ax1, k2, k3, k4 in T.grid(n, T.int64(4), c, h, w):
                with T.block("rxplaceholder_red_temp"):
                    v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4])
                    T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4])
                    T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1])
                    with T.init():
                        rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
                        rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
                    v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]
                    v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]
                    rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0
                    rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1
            # Reshape gamma (4*c,) -> (4, c).
            for ax0, ax1 in T.grid(T.int64(4), c):
                with T.block("T_reshape_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))])
                    T.writes(T_reshape_2[v_ax0, v_ax1])
                    T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))]
            # Reshape beta (4*c,) -> (4, c).
            for ax0, ax1 in T.grid(T.int64(4), c):
                with T.block("T_reshape_2"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))])
                    T.writes(T_reshape_3[v_ax0, v_ax1])
                    T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))]
            # Normalize; the divisor is the symbolic group size c * h * w.
            for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w):
                with T.block("T_group_norm"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                    T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2])
                    T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                    T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2]
            # Merge (group, channel) back into the original channel axis.
            for ax0, ax1, ax2, ax3 in T.grid(n, c * T.int64(4), h, w):
                with T.block("T_reshape_3"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_group_norm[(((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) % w])
                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) % w]

        @R.function
        def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"), gamma: R.Tensor(("4 * c",), dtype="float32"), beta: R.Tensor(("4 * c",), dtype="float32")) -> R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"):
            n = T.int64()
            c = T.int64()
            h = T.int64()
            w = T.int64()
            # The symbolic var c is passed to the PrimFunc through tir_vars.
            gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_sinfo=R.Tensor((n, 4 * c, h, w), dtype="float32"), tir_vars=R.shape([c]))
            return gv
    # fmt: on
    mod = LegalizeOps()(GroupNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_rms_norm():
    """Legalize R.nn.rms_norm over the last two axes of a float32 tensor.

    With float32 input, the "T_cast" blocks are identity copies (no
    ``T.Cast`` is emitted); the mean-of-squares scale 0.05 equals
    ``1 / (4 * 5)``, the element count of the normalized axes.
    """
    # fmt: off
    @tvm.script.ir_module
    class RMSNorm:
        @R.function
        def main(x: R.Tensor((2, 3, 4, 5), "float32"), weight: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"):
            gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.rms_norm(x, weight, axes=[-2, -1])
            return gv

    @tvm.script.ir_module
    class Expected:
        @T.prim_func(private=True)
        def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
            T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
            T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
            rsqrt = T.alloc_buffer((T.int64(2), T.int64(3)))
            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
            T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
            # Identity copy of the input (already float32).
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_cast"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
            # Elementwise square x * x.
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_multiply"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
            # Sum of squares over the normalized axes (the last two).
            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_multiply_red"):
                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
                    T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
                    T.writes(T_multiply_red[v_ax0, v_ax1])
                    with T.init():
                        T_multiply_red[v_ax0, v_ax1] = T.float32(0)
                    T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
            # rsqrt(mean(x^2) + eps); 0.05 == 1/(4*5).
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("rsqrt"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_red[v_ax0, v_ax1])
                    T.writes(rsqrt[v_ax0, v_ax1])
                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
            # Identity copy of the weight (already float32).
            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
                with T.block("T_cast_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(B[v_ax0, v_ax1])
                    T.writes(T_cast_2[v_ax0, v_ax1])
                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_rms_norm"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3])
                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3]
            # Identity copy to the output buffer (no dtype change needed).
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_cast_2"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]

        @R.function
        def main(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight: R.Tensor((4, 5), dtype="float32")) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
            cls = Expected
            gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32"))
            return gv
    # fmt: on
    mod = LegalizeOps()(RMSNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_rms_norm_fp16():
    """Legalize R.nn.rms_norm on float16 input.

    Unlike the float32 variant, the "T_cast" blocks here emit real
    ``T.Cast`` nodes: the computation is carried out in float32 and the
    final result is cast back to float16 in block "T_cast_2".
    """
    # fmt: off
    @tvm.script.ir_module
    class RMSNorm:
        @R.function
        def main(x: R.Tensor((2, 3, 4, 5), "float16"), weight: R.Tensor((4, 5), "float16")) -> R.Tensor((2, 3, 4, 5), "float16"):
            gv: R.Tensor((2, 3, 4, 5), "float16") = R.nn.rms_norm(x, weight, axes=[-2, -1])
            return gv

    @tvm.script.ir_module
    class Expected:
        @T.prim_func(private=True)
        def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), B: T.Buffer((T.int64(4), T.int64(5)), "float16"), T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            # Intermediate buffers default to float32: the math runs in fp32.
            T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
            T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
            T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
            rsqrt = T.alloc_buffer((T.int64(2), T.int64(3)))
            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
            T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
            # Widen input fp16 -> fp32.
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_cast"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_ax3])
            # Elementwise square x * x.
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_multiply"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
            # Sum of squares over the normalized axes.
            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_multiply_red"):
                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
                    T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
                    T.writes(T_multiply_red[v_ax0, v_ax1])
                    with T.init():
                        T_multiply_red[v_ax0, v_ax1] = T.float32(0)
                    T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
            # rsqrt(mean(x^2) + eps); 0.05 == 1/(4*5).
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("rsqrt"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_red[v_ax0, v_ax1])
                    T.writes(rsqrt[v_ax0, v_ax1])
                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
            # Widen weight fp16 -> fp32.
            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
                with T.block("T_cast_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(B[v_ax0, v_ax1])
                    T.writes(T_cast_2[v_ax0, v_ax1])
                    T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1])
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_rms_norm"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3])
                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3]
            # Narrow the fp32 result back to float16.
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_cast_2"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])

        @R.function
        def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), weight: R.Tensor((4, 5), dtype="float16")) -> R.Tensor((2, 3, 4, 5), dtype="float16"):
            cls = Expected
            gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16"))
            return gv
    # fmt: on
    mod = LegalizeOps()(RMSNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_rms_norm_symbolic():
    """Legalize R.nn.rms_norm with fully symbolic shapes (n, s, f).

    Buffers are bound via ``T.match_buffer`` and the mean-of-squares
    divisor becomes the symbolic product ``s * f`` instead of a constant.
    """
    # fmt: off
    @tvm.script.ir_module
    class RMSNorm:
        @R.function
        def main(x: R.Tensor(("n", "s", "f"), "float32"), weight: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"):
            n = T.int64()
            s = T.int64()
            f = T.int64()
            gv: R.Tensor((n, s, f), "float32") = R.nn.rms_norm(x, weight, axes=[1, 2])
            return gv

    @tvm.script.ir_module
    class Expected:
        @T.prim_func(private=True)
        def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle):
            T.func_attr({"tir.noalias": True})
            n, s, f = T.int64(), T.int64(), T.int64()
            A = T.match_buffer(var_A, (n, s, f))
            B = T.match_buffer(var_B, (s, f))
            T_cast = T.match_buffer(var_T_cast, (n, s, f))
            # with T.block("root"):
            T_cast_1 = T.alloc_buffer((n, s, f))
            T_multiply = T.alloc_buffer((n, s, f))
            T_multiply_red = T.alloc_buffer((n,))
            rsqrt = T.alloc_buffer((n,))
            T_cast_2 = T.alloc_buffer((s, f))
            T_rms_norm = T.alloc_buffer((n, s, f))
            # Identity copy of the input (already float32).
            for ax0, ax1, ax2 in T.grid(n, s, f):
                with T.block("T_cast"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(A[v_ax0, v_ax1, v_ax2])
                    T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                    T_cast_1[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2]
            # Elementwise square x * x.
            for ax0, ax1, ax2 in T.grid(n, s, f):
                with T.block("T_multiply"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                    T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
            # Sum of squares over axes 1 and 2.
            for ax0, k1, k2 in T.grid(n, s, f):
                with T.block("T_multiply_red"):
                    v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
                    T.reads(T_multiply[v_ax0, v_k1, v_k2])
                    T.writes(T_multiply_red[v_ax0])
                    with T.init():
                        T_multiply_red[v_ax0] = T.float32(0)
                    T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2]
            # rsqrt(mean(x^2) + eps); divisor is the symbolic product s * f.
            for ax0 in range(n):
                with T.block("rsqrt"):
                    v_ax0 = T.axis.spatial(n, ax0)
                    T.reads(T_multiply_red[v_ax0])
                    T.writes(rsqrt[v_ax0])
                    rsqrt[v_ax0] = T.rsqrt(T_multiply_red[v_ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) + T.float32(1.0000000000000001e-05))
            # Identity copy of the weight (already float32).
            for ax0, ax1 in T.grid(s, f):
                with T.block("T_cast_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(B[v_ax0, v_ax1])
                    T.writes(T_cast_2[v_ax0, v_ax1])
                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
            for ax0, ax1, ax2 in T.grid(n, s, f):
                with T.block("T_rms_norm"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(rsqrt[v_ax0], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1, v_ax2])
                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
                    T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax1, v_ax2]
            # Identity copy to the output buffer (no dtype change needed).
            for ax0, ax1, ax2 in T.grid(n, s, f):
                with T.block("T_cast_2"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
                    T.writes(T_cast[v_ax0, v_ax1, v_ax2])
                    T_cast[v_ax0, v_ax1, v_ax2] = T_rms_norm[v_ax0, v_ax1, v_ax2]

        @R.function
        def main(x: R.Tensor(("n", "s", "f"), dtype="float32"), weight: R.Tensor(("s", "f"), dtype="float32")) -> R.Tensor(("n", "s", "f"), dtype="float32"):
            n = T.int64()
            s = T.int64()
            f = T.int64()
            cls = Expected
            gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((n, s, f), dtype="float32"))
            return gv
    # fmt: on
    mod = LegalizeOps()(RMSNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_rms_norm_no_bias():
    """Legalize R.nn.rms_norm without a bias term.

    NOTE(review): this test body is byte-identical to ``test_rms_norm``
    above — the Relax call and expected IR are the same. Presumably it was
    meant to exercise an rms_norm variant that takes/omits a bias operand;
    confirm the intent against the R.nn.rms_norm API before relying on the
    extra coverage.
    """
    # fmt: off
    @tvm.script.ir_module
    class RMSNorm:
        @R.function
        def main(x: R.Tensor((2, 3, 4, 5), "float32"), weight: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"):
            gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.rms_norm(x, weight, axes=[-2, -1])
            return gv

    @tvm.script.ir_module
    class Expected:
        @T.prim_func(private=True)
        def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
            T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
            T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
            rsqrt = T.alloc_buffer((T.int64(2), T.int64(3)))
            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
            T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
            # Identity copy of the input (already float32).
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_cast"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
            # Elementwise square x * x.
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_multiply"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
            # Sum of squares over the normalized axes.
            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_multiply_red"):
                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
                    T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
                    T.writes(T_multiply_red[v_ax0, v_ax1])
                    with T.init():
                        T_multiply_red[v_ax0, v_ax1] = T.float32(0)
                    T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
            # rsqrt(mean(x^2) + eps); 0.05 == 1/(4*5).
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("rsqrt"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_multiply_red[v_ax0, v_ax1])
                    T.writes(rsqrt[v_ax0, v_ax1])
                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
            # Identity copy of the weight (already float32).
            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
                with T.block("T_cast_1"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(B[v_ax0, v_ax1])
                    T.writes(T_cast_2[v_ax0, v_ax1])
                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_rms_norm"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3])
                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3]
            # Identity copy to the output buffer (no dtype change needed).
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                with T.block("T_cast_2"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]

        @R.function
        def main(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight: R.Tensor((4, 5), dtype="float32")) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
            cls = Expected
            gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32"))
            return gv
    # fmt: on
    mod = LegalizeOps()(RMSNorm)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_attention():
    """Legalize R.nn.attention (with bias, scale, and TopLeft causal mask).

    The op is lowered to an explicit TIR pipeline: transpose/reshape of
    q/k/v, a scaled QK^T batch matmul, bias add, a masked (trilu) softmax,
    and a final matmul with v. The legalized module must structurally
    equal the hand-written ``Expected`` module below.
    """
    # fmt: off
    @tvm.script.ir_module
    class Attention:
        @R.function
        def main(q: R.Tensor((4, 16, 32, 8), "float32"), k: R.Tensor((4, 8, 32, 8), "float32"), v: R.Tensor((4, 8, 32, 16), "float32"), bias: R.Tensor((4, 32, 16, 8), "float32")):
            scale = T.FloatImm("float32", 0.1)
            gv: R.Tensor((4, 16, 32, 16), "float32") = R.nn.attention(q, k, v, bias, scale=scale, causal_mask="TopLeft")
            return gv

    @tvm.script.ir_module
    class Expected:
        @T.prim_func(private=True)
        def attention_bias(A: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8)), "float32"), B: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(8)), "float32"), C: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(16)), "float32"), D: T.Buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)), "float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(16)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            T_transpose_1 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)))
            T_reshape = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
            T_transpose_2 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(8), T.int64(8)))
            T_reshape_1 = T.alloc_buffer((T.int64(128), T.int64(8), T.int64(8)))
            T_batch_matmul_NT = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
            T_multiply = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
            T_reshape_2 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)))
            T_add = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)))
            T_reshape_3 = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
            trilu = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
            trilu_red = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(1)))
            T_subtract = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
            compute = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
            trilu_1 = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
            trilu_red_1 = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(1)))
            T_divide = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
            T_transpose_3 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(8), T.int64(16)))
            T_reshape_4 = T.alloc_buffer((T.int64(128), T.int64(8), T.int64(16)))
            T_batch_matmul_NN = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(16)))
            T_reshape_5 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(16)))
            # q: (b, s, h, d) -> (b, h, s, d), then fold (b, h) into one batch axis.
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)):
                with T.block("T_transpose"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
                    T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                with T.block("T_reshape"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)])
                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                    T_reshape[v_ax0, v_ax1, v_ax2] = T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]
            # Same layout normalization for k.
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(8)):
                with T.block("T_transpose_1"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(B[v_ax0, v_ax2, v_ax1, v_ax3])
                    T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = B[v_ax0, v_ax2, v_ax1, v_ax3]
            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(8)):
                with T.block("T_reshape_1"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2])
                    T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
            # Scores = q @ k^T, then scaled by 0.1 (the `scale` argument).
            for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(8), T.int64(8)):
                with T.block("T_batch_matmul_NT"):
                    v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
                    T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, v_k])
                    T.writes(T_batch_matmul_NT[v_b, v_i, v_j])
                    T.block_attr({"layout_free_placeholders": [T_reshape_1]})
                    with T.init():
                        T_batch_matmul_NT[v_b, v_i, v_j] = T.float32(0)
                    T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k]
            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                with T.block("T_multiply"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(T_batch_matmul_NT[v_ax0, v_ax1, v_ax2])
                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                    T_multiply[v_ax0, v_ax1, v_ax2] = T_batch_matmul_NT[v_ax0, v_ax1, v_ax2] * T.float32(0.10000000000000001)
            # Bias is added in the 4-D (b, h, s_q, s_kv) layout.
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)):
                with T.block("T_reshape_2"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % T.int64(16), v_ax3 % T.int64(8)])
                    T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % T.int64(16), v_ax3 % T.int64(8)]
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)):
                with T.block("T_add"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3], D[v_ax0, v_ax1, v_ax2, v_ax3])
                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] + D[v_ax0, v_ax1, v_ax2, v_ax3]
            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                with T.block("T_reshape_3"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(T_add[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)])
                    T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2])
                    T_reshape_3[v_ax0, v_ax1, v_ax2] = T_add[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]
            # TopLeft causal mask: keep lower-triangular scores (v_i2 <= v_i1),
            # then a numerically-stable masked softmax (max-subtract / exp / sum / divide).
            for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                with T.block("trilu"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(T_reshape_3[v_i0, v_i1, v_i2])
                    T.writes(trilu[v_i0, v_i1, v_i2])
                    trilu[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, T_reshape_3[v_i0, v_i1, v_i2], T.float32(0))
            for ax0, ax1, ax2, k2 in T.grid(T.int64(128), T.int64(16), T.int64(1), T.int64(8)):
                with T.block("trilu_red"):
                    v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, ax2, k2])
                    T.reads(trilu[v_ax0, v_ax1, v_k2])
                    T.writes(trilu_red[v_ax0, v_ax1, v_ax2])
                    with T.init():
                        trilu_red[v_ax0, v_ax1, v_ax2] = T.float32(-3.4028234663852886e+38)
                    trilu_red[v_ax0, v_ax1, v_ax2] = T.max(trilu_red[v_ax0, v_ax1, v_ax2], trilu[v_ax0, v_ax1, v_k2])
            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                with T.block("T_subtract"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(trilu[v_ax0, v_ax1, v_ax2], trilu_red[v_ax0, v_ax1, T.int64(0)])
                    T.writes(T_subtract[v_ax0, v_ax1, v_ax2])
                    T_subtract[v_ax0, v_ax1, v_ax2] = trilu[v_ax0, v_ax1, v_ax2] - trilu_red[v_ax0, v_ax1, T.int64(0)]
            for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                with T.block("compute"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(T_subtract[v_i0, v_i1, v_i2])
                    T.writes(compute[v_i0, v_i1, v_i2])
                    compute[v_i0, v_i1, v_i2] = T.exp(T_subtract[v_i0, v_i1, v_i2])
            for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                with T.block("trilu_1"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(compute[v_i0, v_i1, v_i2])
                    T.writes(trilu_1[v_i0, v_i1, v_i2])
                    trilu_1[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, compute[v_i0, v_i1, v_i2], T.float32(0))
            for ax0, ax1, ax2, k2 in T.grid(T.int64(128), T.int64(16), T.int64(1), T.int64(8)):
                with T.block("trilu_red_1"):
                    v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, ax2, k2])
                    T.reads(trilu_1[v_ax0, v_ax1, v_k2])
                    T.writes(trilu_red_1[v_ax0, v_ax1, v_ax2])
                    with T.init():
                        trilu_red_1[v_ax0, v_ax1, v_ax2] = T.float32(0)
                    trilu_red_1[v_ax0, v_ax1, v_ax2] = trilu_red_1[v_ax0, v_ax1, v_ax2] + trilu_1[v_ax0, v_ax1, v_k2]
            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                with T.block("T_divide"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(trilu_1[v_ax0, v_ax1, v_ax2], trilu_red_1[v_ax0, v_ax1, T.int64(0)])
                    T.writes(T_divide[v_ax0, v_ax1, v_ax2])
                    T_divide[v_ax0, v_ax1, v_ax2] = trilu_1[v_ax0, v_ax1, v_ax2] / trilu_red_1[v_ax0, v_ax1, T.int64(0)]
            # Attention output = softmax(scores) @ v, then restore (b, s, h, d) layout.
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(16)):
                with T.block("T_transpose_2"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(C[v_ax0, v_ax2, v_ax1, v_ax3])
                    T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = C[v_ax0, v_ax2, v_ax1, v_ax3]
            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(16)):
                with T.block("T_reshape_4"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)])
                    T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2])
                    T_reshape_4[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)]
            for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(16), T.int64(8)):
                with T.block("T_batch_matmul_NN"):
                    v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
                    T.reads(T_divide[v_b, v_i, v_k], T_reshape_4[v_b, v_k, v_j])
                    T.writes(T_batch_matmul_NN[v_b, v_i, v_j])
                    T.block_attr({"layout_free_placeholders": [T_reshape_4]})
                    with T.init():
                        T_batch_matmul_NN[v_b, v_i, v_j] = T.float32(0)
                    T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, v_i, v_j] + T_divide[v_b, v_i, v_k] * T_reshape_4[v_b, v_k, v_j]
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(16)):
                with T.block("T_reshape_5"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)])
                    T.writes(T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)]
            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(16), T.int64(32), T.int64(16)):
                with T.block("T_transpose_3"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3])
                    T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                    T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3]

        @R.function
        def main(q: R.Tensor((4, 16, 32, 8), dtype="float32"), k: R.Tensor((4, 8, 32, 8), dtype="float32"), v: R.Tensor((4, 8, 32, 16), dtype="float32"), bias: R.Tensor((4, 32, 16, 8), dtype="float32")) -> R.Tensor((4, 16, 32, 16), dtype="float32"):
            cls = Expected
            gv = R.call_tir(cls.attention_bias, (q, k, v, bias), out_sinfo=R.Tensor((4, 16, 32, 16), dtype="float32"))
            return gv
    # fmt: on

    mod = LegalizeOps()(Attention)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_dynamic_attention():
    """The sequence lengths may be dynamic.

    In previous implementations, the `seq_len` and `seq_len_kv` were
    assumed to be static integers, and produced an exception during
    legalization.  This is a regression test: it only checks that the
    LegalizeOps pass completes without raising, so no Expected module
    or structural-equality assertion is needed.
    """

    @tvm.script.ir_module
    class Attention:
        @R.function
        def main(
            q: R.Tensor((4, "seq_len", 32, 8), "float32"),
            k: R.Tensor((4, "seq_len_kv", 32, 8), "float32"),
            v: R.Tensor((4, "seq_len_kv", 32, 16), "float32"),
            bias: R.Tensor((4, 32, "seq_len", "seq_len_kv"), "float32"),
        ):
            scale = T.FloatImm("float32", 0.1)
            gv = R.nn.attention(q, k, v, bias, scale=scale, causal_mask="BottomRight")
            return gv

    # Must not raise even though the sequence dimensions are symbolic.
    LegalizeOps()(Attention)
def test_nll_loss():
    """Legalize R.nn.nll_loss with explicit weights and reduction="mean".

    The op lowers to TIR as: a weighted negative-log-likelihood gather
    (skipping entries equal to ignore_index=-1), a sum reduction of those
    losses, a matching sum of the selected weights, and a final division
    (so the mean is weighted by the per-class weights).
    """
    # fmt: off
    @tvm.script.ir_module
    class NLLLoss:
        @R.function
        def main(
            predictions: R.Tensor((2, 3, 4, 5), "float32"),
            targets: R.Tensor((2, 4, 5), "int64"),
            weights: R.Tensor((3,), "float32"),
        ) -> R.Tensor((), "float32"):
            gv = R.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-1)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            predictions: R.Tensor((2, 3, 4, 5), dtype="float32"),
            targets: R.Tensor((2, 4, 5), dtype="int64"),
            weights: R.Tensor((3,), dtype="float32"),
        ) -> R.Tensor((), dtype="float32"):
            gv = R.call_tir(Expected.nll_loss, (predictions, targets, weights), R.Tensor((), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def nll_loss(
            predictions: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"),
            targets: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"),
            weights: T.Buffer(T.int64(3), "float32"),
            output: T.Buffer((), "float32"),
        ):
            # function attr dict
            T.func_attr({"tir.noalias": True})
            # body
            # with T.block("root")
            nll_loss = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(5)], dtype="float32")
            nll_loss_red = T.alloc_buffer([], dtype="float32")
            nll_loss_1 = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(5)], dtype="float32")
            nll_loss_red_1 = T.alloc_buffer([], dtype="float32")
            # Per-element weighted loss; ignore_index (-1) entries contribute 0.
            for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                with T.block("nll_loss"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(targets[v_ax0, v_ax1, v_ax2], predictions[v_ax0, targets[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], weights[targets[v_ax0, v_ax1, v_ax2]])
                    T.writes(nll_loss[v_ax0, v_ax1, v_ax2])
                    nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - predictions[v_ax0, targets[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * weights[targets[v_ax0, v_ax1, v_ax2]], T.float32(0))
            for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                with T.block("nll_loss_red"):
                    v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
                    T.reads(nll_loss[v_k0, v_k1, v_k2])
                    T.writes(nll_loss_red[()])
                    with T.init():
                        nll_loss_red[()] = T.float32(0)
                    nll_loss_red[()] = nll_loss_red[()] + nll_loss[v_k0, v_k1, v_k2]
            # Sum of the selected weights: the denominator of the weighted mean.
            for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                with T.block("nll_loss_1"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(targets[v_ax0, v_ax1, v_ax2], weights[targets[v_ax0, v_ax1, v_ax2]])
                    T.writes(nll_loss_1[v_ax0, v_ax1, v_ax2])
                    nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0, v_ax1, v_ax2] != T.int64(-1), weights[targets[v_ax0, v_ax1, v_ax2]], T.float32(0))
            for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                with T.block("nll_loss_red_1"):
                    v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
                    T.reads(nll_loss_1[v_k0, v_k1, v_k2])
                    T.writes(nll_loss_red_1[()])
                    with T.init():
                        nll_loss_red_1[()] = T.float32(0)
                    nll_loss_red_1[()] = nll_loss_red_1[()] + nll_loss_1[v_k0, v_k1, v_k2]
            with T.block("T_divide"):
                vi = T.axis.spatial(1, T.int64(0))
                T.reads(nll_loss_red[()], nll_loss_red_1[()])
                T.writes(output[()])
                output[()] = nll_loss_red[()] / nll_loss_red_1[()]
    # fmt: on

    mod = LegalizeOps()(NLLLoss)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_nll_no_weight():
    """Legalize R.nn.nll_loss when no weights argument is given.

    Without explicit weights, legalization materializes an all-ones
    weight vector (the T_full block) and then follows the same weighted
    NLL lowering as the explicit-weights case.
    """
    # fmt: off
    @tvm.script.ir_module
    class NLLLoss:
        @R.function
        def main(predictions: R.Tensor((2, 3, 4, 5), "float32"), targets: R.Tensor((2, 4, 5), "int64")) -> R.Tensor((), "float32"):
            gv: R.Tensor((), "float32") = R.nn.nll_loss(predictions, targets, reduction="mean", ignore_index=-1)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(predictions: R.Tensor((2, 3, 4, 5), dtype="float32"), targets: R.Tensor((2, 4, 5), dtype="int64"),) -> R.Tensor((), dtype="float32"):
            # block 0
            gv = R.call_tir(Expected.nll_loss_without_weight, (predictions, targets), R.Tensor((), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def nll_loss_without_weight(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"), T_divide: T.Buffer((), "float32"),):
            # function attr dict
            T.func_attr({"tir.noalias": True})
            # body
            # with T.block("root")
            T_full = T.alloc_buffer([T.int64(3)], dtype="float32")
            nll_loss = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(5)], dtype="float32")
            nll_loss_red = T.alloc_buffer([], dtype="float32")
            nll_loss_1 = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(5)], dtype="float32")
            nll_loss_red_1 = T.alloc_buffer([], dtype="float32")
            # Synthesized all-ones weights (one per class).
            for ax0 in T.serial(T.int64(3)):
                with T.block("T_full"):
                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
                    T.reads()
                    T.writes(T_full[v_ax0])
                    T_full[v_ax0] = T.float32(1)
            for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                with T.block("nll_loss"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], T_full[rxplaceholder_1[v_ax0, v_ax1, v_ax2]])
                    T.writes(nll_loss[v_ax0, v_ax1, v_ax2])
                    nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * T_full[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0))
            for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                with T.block("nll_loss_red"):
                    v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
                    T.reads(nll_loss[v_k0, v_k1, v_k2])
                    T.writes(nll_loss_red[()])
                    with T.init():
                        nll_loss_red[()] = T.float32(0)
                    nll_loss_red[()] = nll_loss_red[()] + nll_loss[v_k0, v_k1, v_k2]
            for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                with T.block("nll_loss_1"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], T_full[rxplaceholder_1[v_ax0, v_ax1, v_ax2]])
                    T.writes(nll_loss_1[v_ax0, v_ax1, v_ax2])
                    nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), T_full[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0))
            for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                with T.block("nll_loss_red_1"):
                    v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
                    T.reads(nll_loss_1[v_k0, v_k1, v_k2])
                    T.writes(nll_loss_red_1[()])
                    with T.init():
                        nll_loss_red_1[()] = T.float32(0)
                    nll_loss_red_1[()] = nll_loss_red_1[()] + nll_loss_1[v_k0, v_k1, v_k2]
            with T.block("T_divide"):
                vi = T.axis.spatial(1, T.int64(0))
                T.reads(nll_loss_red[()], nll_loss_red_1[()])
                T.writes(T_divide[()])
                T_divide[()] = nll_loss_red[()] / nll_loss_red_1[()]
    # fmt: on

    mod = LegalizeOps()(NLLLoss)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_nll_no_batch():
    """Legalize R.nn.nll_loss for unbatched (1-D predictions, scalar target) input.

    With no batch dimension the reductions collapse to single scalar
    blocks, and the symbolic class count ``C`` is handled via
    T.match_buffer on handle parameters.  ignore_index is 1 here, so a
    target of 1 yields a zero loss and weight.
    """
    # fmt: off
    @tvm.script.ir_module
    class NLLLoss:
        @R.function
        def main(predictions: R.Tensor(("C",), "float32"), targets: R.Tensor((), "int64"), weights: R.Tensor(("C",), "float32")) -> R.Tensor((), "float32"):
            gv = R.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=1)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(predictions: R.Tensor(("C",), dtype="float32"), targets: R.Tensor((), dtype="int64"), weights: R.Tensor(("C",), dtype="float32")) -> R.Tensor((), dtype="float32"):
            C = T.int64()
            gv = R.call_tir(Expected.nll_loss, (predictions, targets, weights), out_sinfo=R.Tensor((), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def nll_loss(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((), "int64"), var_rxplaceholder_1: T.handle, T_divide: T.Buffer((), "float32")):
            T.func_attr({"tir.noalias": True})
            C = T.int64()
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (C,))
            rxplaceholder_2 = T.match_buffer(var_rxplaceholder_1, (C,))
            # with T.block("root"):
            nll_loss = T.alloc_buffer(())
            nll_loss_1 = T.alloc_buffer(())
            with T.block("nll_loss"):
                vi = T.axis.spatial(T.int64(1), T.int64(0))
                T.reads(rxplaceholder[()], rxplaceholder_1[rxplaceholder[()]], rxplaceholder_2[rxplaceholder[()]])
                T.writes(nll_loss[()])
                nll_loss[()] = T.Select(rxplaceholder[()] != T.int64(1), (T.float32(0) - rxplaceholder_1[rxplaceholder[()]]) * rxplaceholder_2[rxplaceholder[()]], T.float32(0))
            with T.block("nll_loss_1"):
                vi = T.axis.spatial(T.int64(1), T.int64(0))
                T.reads(rxplaceholder[()], rxplaceholder_2[rxplaceholder[()]])
                T.writes(nll_loss_1[()])
                nll_loss_1[()] = T.Select(rxplaceholder[()] != T.int64(1), rxplaceholder_2[rxplaceholder[()]], T.float32(0))
            with T.block("T_divide"):
                vi = T.axis.spatial(1, T.int64(0))
                T.reads(nll_loss[()], nll_loss_1[()])
                T.writes(T_divide[()])
                T_divide[()] = nll_loss[()] / nll_loss_1[()]
    # fmt: on

    mod = LegalizeOps()(NLLLoss)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_nll_loss_symbolic():
    """Legalize R.nn.nll_loss when all data dimensions (N, C, d1, d2) are symbolic.

    Same lowering as the static test_nll_loss, but the prim_func takes
    handles and recovers the symbolic shapes with T.match_buffer, and
    every loop nest is bounded by the symbolic extents.
    """
    # fmt: off
    @tvm.script.ir_module
    class NLLLoss:
        @R.function
        def main(predictions: R.Tensor(("N", "C", "d1", "d2"), "float32"), targets: R.Tensor(("N", "d1", "d2"), "int64"), weights: R.Tensor(("C",), "float32")) -> R.Tensor((), "float32"):
            gv: R.Tensor((), "float32") = R.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-1)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(predictions: R.Tensor(("N", "C", "d1", "d2"), dtype="float32"), targets: R.Tensor(("N", "d1", "d2"), dtype="int64"), weights: R.Tensor(("C",), dtype="float32")) -> R.Tensor((), dtype="float32"):
            # block 0
            gv = R.call_tir(Expected.nll_loss, (predictions, targets, weights), R.Tensor((), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def nll_loss(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, T_divide: T.Buffer((), "float32"),):
            # function attr dict
            T.func_attr({"tir.noalias": True})
            C = T.int64()
            N = T.int64()
            d1 = T.int64()
            d2 = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [N, C, d1, d2], dtype="float32")
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N, d1, d2], dtype="int64")
            rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [C], dtype="float32")
            # body
            # with T.block("root")
            nll_loss = T.alloc_buffer([N, d1, d2], dtype="float32")
            nll_loss_red = T.alloc_buffer([], dtype="float32")
            nll_loss_1 = T.alloc_buffer([N, d1, d2], dtype="float32")
            nll_loss_red_1 = T.alloc_buffer([], dtype="float32")
            # Per-element weighted loss over the symbolic grid.
            for ax0, ax1, ax2 in T.grid(N, d1, d2):
                with T.block("nll_loss"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2],rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]],)
                    T.writes(nll_loss[v_ax0, v_ax1, v_ax2])
                    nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0),)
            for k0, k1, k2 in T.grid(N, d1, d2):
                with T.block("nll_loss_red"):
                    v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
                    T.reads(nll_loss[v_k0, v_k1, v_k2])
                    T.writes(nll_loss_red[()])
                    with T.init():
                        nll_loss_red[()] = T.float32(0)
                    nll_loss_red[()] = nll_loss_red[()] + nll_loss[v_k0, v_k1, v_k2]
            # Denominator: sum of the weights actually selected.
            for ax0, ax1, ax2 in T.grid(N, d1, d2):
                with T.block("nll_loss_1"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]],)
                    T.writes(nll_loss_1[v_ax0, v_ax1, v_ax2])
                    nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0),)
            for k0, k1, k2 in T.grid(N, d1, d2):
                with T.block("nll_loss_red_1"):
                    v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
                    T.reads(nll_loss_1[v_k0, v_k1, v_k2])
                    T.writes(nll_loss_red_1[()])
                    with T.init():
                        nll_loss_red_1[()] = T.float32(0)
                    nll_loss_red_1[()] = nll_loss_red_1[()] + nll_loss_1[v_k0, v_k1, v_k2]
            with T.block("T_divide"):
                vi = T.axis.spatial(1, T.int64(0))
                T.reads(nll_loss_red[()], nll_loss_red_1[()])
                T.writes(T_divide[()])
                T_divide[()] = nll_loss_red[()] / nll_loss_red_1[()]
    # fmt: on

    mod = LegalizeOps()(NLLLoss)
    tvm.ir.assert_structural_equal(mod, Expected)
def test_pad():
    """Legalize R.nn.pad with zero padding on the last two axes.

    pad_width (0, 0, 1, 1, 1, 1) pads axes 1 and 2 by one element on
    each side, so (2, 128, 28) -> (2, 130, 30); the lowered TIR guards
    the interior region with if_then_else and fills the border with 0.
    """

    @tvm.script.ir_module
    class Pad:
        @R.function
        def main(x: R.Tensor((2, 128, 28), "float32")) -> R.Tensor((2, 130, 30), "float32"):
            gv: R.Tensor((2, 130, 30), "float32") = R.nn.pad(x, (0, 0, 1, 1, 1, 1))
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 128, 28), dtype="float32"),
        ) -> R.Tensor((2, 130, 30), dtype="float32"):
            # NOTE: the args must be a tuple; `(x,)` is explicit where the
            # original `(x)` relied on call_tir auto-wrapping a bare expr.
            gv = R.call_tir(Expected.pad, (x,), out_sinfo=R.Tensor((2, 130, 30), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def pad(
            A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"),
            PadInput: T.Buffer((T.int64(2), T.int64(130), T.int64(30)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            for i0, i1, i2 in T.grid(T.int64(2), T.int64(130), T.int64(30)):
                with T.block("PadInput"):
                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(A[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1)])
                    T.writes(PadInput[v_i0, v_i1, v_i2])
                    # Interior -> shifted copy of A; border -> pad value 0.
                    PadInput[v_i0, v_i1, v_i2] = T.if_then_else(
                        T.int64(1) <= v_i1
                        and v_i1 < T.int64(129)
                        and T.int64(1) <= v_i2
                        and v_i2 < T.int64(29),
                        A[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1)],
                        T.float32(0),
                    )

    mod = LegalizeOps()(Pad)
    tvm.ir.assert_structural_equal(mod, Expected)
if __name__ == "__main__":
    # Delegate to TVM's pytest-based runner so all tests above execute.
    tvm.testing.main()