# blob: 90603a8bf293d704add41724a1d01a7e5038b667 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import pytest
import tvm.testing
from tvm import dlight as dl
from tvm.script import tir as T
from tvm.target import Target
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
    """Shared base for the before/after schedule tests in this file.

    Supplies the ``transform`` fixture; subclasses provide the ``before``
    and ``expected`` functions.
    """

    @pytest.fixture
    def transform(self):
        """Return a callable that applies the dlight default schedule to a module."""

        def schedule_module(mod):
            # The target is entered for the whole construction-and-apply step.
            with Target("nvidia/geforce-gtx-1080-ti"):
                # Use Matmul rule for Conv for now
                matmul_rule = dl.gpu.Matmul()
                scheduler = dl.ApplyDefaultSchedule(matmul_rule)
                return scheduler(mod)

        return schedule_module
class TestConv3d(BaseBeforeAfter):
    """Conv3d-style reduction written as `before`, with the scheduled form in `expected`.

    NOTE(review): the before/expected comparison itself is presumably driven by
    tvm.testing.CompareBeforeAfter together with the inherited `transform`
    fixture -- confirm against the base class.
    """
    # fmt: off
    @T.prim_func
    def before(
        A: T.Buffer((14308, 3, 2, 14, 14), "float16"),
        W: T.Buffer((1280, 3, 2, 14, 14), "float16"),
        C: T.Buffer((14308, 1280, 1, 1, 1), "float16"),
    ):
        # pad_A has the same shape as A and is filled by a plain element-wise
        # copy, so no border is actually added here.
        pad_A = T.alloc_buffer((14308, 3, 2, 14, 14), "float16")
        for i0, i1, i2, i3, i4 in T.grid(14308, 3, 2, 14, 14):
            with T.block("pad_A"):
                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                pad_A[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4]
        # Five spatial axes (nn, ff, yy, xx, zz) followed by four reduction axes
        # (rc, ry, rx, rz). The output spatial extents are 1 x 1 x 1 and the
        # reduction window (2, 14, 14) equals the spatial extent of A.
        for nn, ff, yy, xx, zz, rc, ry, rx, rz in T.grid(14308, 1280, 1, 1, 1, 3, 2, 14, 14):
            with T.block("C"):
                v_nn, v_ff, v_yy, v_xx, v_zz, v_rc, v_ry, v_rx, v_rz = T.axis.remap("SSSSSRRRR", [nn, ff, yy, xx, zz, rc, ry, rx, rz])
                with T.init():
                    C[v_nn, v_ff, v_yy, v_xx, v_zz] = T.float16(0.0)
                C[v_nn, v_ff, v_yy, v_xx, v_zz] += pad_A[v_nn, v_rc, v_yy * 2 + v_ry, v_xx * 14 + v_rx, v_zz * 14 + v_rz]* W[v_ff, v_rc, v_ry, v_rx, v_rz]

    @T.prim_func
    def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3, 2, 14, 14), "float16"), C: T.Buffer((14308, 1280, 1, 1, 1), "float16")):
        T.func_attr({"tir.is_scheduled": True})
        # with T.block("root"):
        # Staging buffers use rounded-up extents:
        #   14336 = 448 * 32  (14308 rounded up to a multiple of 32)
        #   1184  = 74 * 16   (1176 = 3 * 2 * 14 * 14 rounded up to a multiple of 16)
        C_reindex_pad_local = T.alloc_buffer((1, 14336, 1280), "float16", scope="local")
        pad_A_reindex_pad_shared = T.alloc_buffer((1, 14336, 1184), "float16", scope="shared")
        W_reindex_pad_shared = T.alloc_buffer((1, 1280, 1184), "float16", scope="shared")
        # blockIdx.y (20) advances v2 in steps of 64 and blockIdx.x (448) advances
        # v1 in steps of 32; threadIdx.y (16) and threadIdx.x (8) each step by 4.
        for ax0_ax2_0_fused in T.thread_binding(20, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding(448, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                # Zero the 4 x 4 accumulator tile (4 along v1, 2 * 2 along v2).
                                for ax1_3_init, ax2_3_0_init in T.grid(4, 2):
                                    for ax2_3_1_init in T.vectorized(2):
                                        with T.block("C_init"):
                                            v0 = T.axis.spatial(1, 0)
                                            v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
                                            v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init)
                                            C_reindex_pad_local[0, v1, v2] = T.float16(0.0)
                                # Reduction axis is walked in 74 steps of 16.
                                for ax3_0 in range(74):
                                    # Stage a 32 x 16 tile of A: 16 * 8 * 2 * 2 = 512 elements,
                                    # split into row (// 16) and column (% 16).
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(2):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("pad_A_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(14336, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        # v2 is unflattened into the four trailing indices of A
                                                        # (392 = 2 * 14 * 14, 196 = 14 * 14); positions outside
                                                        # the original extents are filled with zero.
                                                        pad_A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 14308 and v2 < 1176, A[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0))
                                    # Stage a 64 x 16 tile of W: 16 * 8 * 4 * 2 = 1024 elements.
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                    with T.block("W_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                        v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        # Only the reduction index needs a guard here: the first
                                                        # extent of W (1280) is not rounded up.
                                                        W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < 1176, W[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0))
                                    # Accumulate the 16 reduction elements of this step into the 4 x 4 tile.
                                    for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2):
                                        for ax2_3_1 in T.vectorized(2):
                                            with T.block("C_update"):
                                                v0 = T.axis.spatial(1, 0)
                                                v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
                                                v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1)
                                                v3 = T.axis.reduce(1184, ax3_0 * 16 + ax3_1)
                                                C_reindex_pad_local[0, v1, v2] = C_reindex_pad_local[0, v1, v2] + pad_A_reindex_pad_shared[0, v1, v3] * W_reindex_pad_shared[0, v2, v3]
                                # Copy the accumulator tile back to C; T.where skips the
                                # rows at or beyond 14308 that exist only in the rounded-up buffer.
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
                                    for ax2_1_1 in T.vectorized(2):
                                        with T.block("C_reindex_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
                                            T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < 14308)
                                            C[v1, v2, 0, 0, 0] = C_reindex_pad_local[v0, v1, v2]
    # fmt: on
if __name__ == "__main__":
    # Only reached when this file is executed directly, not when it is imported.
    tvm.testing.main()