# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
import tvm
from tvm.tir.transform import DefaultGPUSchedule
from tvm.script import tir as T
import tvm.testing
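

# Broadcast with symbolic output extents: the fused loop is split across 256 blocks
# of 1024 threads and the tail iterations are masked with T.where.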
def test_broadcast_to_symbolic():
    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
    # fmt: off
    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def broadcast_to(
            rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"),
            var_T_broadcast_to: T.handle,
        ):
            T.func_attr({"tir.noalias": True})
            x_0 = T.int64()
            x_1 = T.int64()
            T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1))
            # with T.block("root"):
            for ax0, ax1 in T.grid(x_0, x_1):
                with T.block("T_broadcast_to"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax0, T.int64(0)])
                    T.writes(T_broadcast_to[v_ax0, v_ax1])
                    T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0, T.int64(0)]

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def broadcast_to(rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"), var_T_broadcast_to: T.handle):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            x_0, x_1 = T.int64(), T.int64()
            T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1))
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
                for ax0_ax1_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                    for ax0_ax1_fused_0 in range((x_0 * x_1 + T.int64(262143)) // T.int64(262144)):
                        with T.block("T_broadcast_to"):
                            v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % (x_1 * x_0) // x_1)
                            v_ax1 = T.axis.spatial(x_1, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % x_1)
                            T.where((ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1) * T.int64(1024) + ax0_ax1_fused_2 < x_0 * x_1)
                            T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0, T.int64(0)]
    # fmt: on
    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long

    target = tvm.target.Target("nvidia/geforce-rtx-3070")
    with target, tvm.transform.PassContext(opt_level=3):
        After = DefaultGPUSchedule()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
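

# The module holds three structurally identical matmuls: one without a target, one with a
# per-function cuda target, and one with an llvm target. Only the first two receive thread
# bindings; the llvm-target function keeps its serial loop nest.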
def test_matmul():
    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
    # fmt: off
    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def matmul(
            A: T.Buffer((32, 32), "float16"),
            B: T.Buffer((32, 32), "float16"),
            C: T.Buffer((32, 32), "float16"),
        ):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # with T.block("root"):
            for i, j, k in T.grid(32, 32, 32):
                with T.block("C"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(A[v_i, v_k], B[v_k, v_j])
                    T.writes(C[v_i, v_j])
                    with T.init():
                        C[v_i, v_j] = T.float16(0)
                    C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]

        @T.prim_func
        def matmul_gpu(
            A: T.Buffer((32, 32), "float16"),
            B: T.Buffer((32, 32), "float16"),
            C: T.Buffer((32, 32), "float16"),
        ):
            T.func_attr({"global_symbol": "main",
                         "target": T.target({"arch": "sm_86",
                                             "keys": ["cuda", "gpu"],
                                             "kind": "cuda",
                                             "max_num_threads": 1024,
                                             "tag": "",
                                             "thread_warp_size": 32}),
                         "tir.noalias": True})
            # with T.block("root"):
            for i, j, k in T.grid(32, 32, 32):
                with T.block("C"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(A[v_i, v_k], B[v_k, v_j])
                    T.writes(C[v_i, v_j])
                    with T.init():
                        C[v_i, v_j] = T.float16(0)
                    C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]

        @T.prim_func
        def matmul_cpu(
            A: T.Buffer((32, 32), "float16"),
            B: T.Buffer((32, 32), "float16"),
            C: T.Buffer((32, 32), "float16"),
        ):
            T.func_attr({"global_symbol": "main",
                         "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
                         "tir.noalias": True})
            # with T.block("root"):
            for i, j, k in T.grid(32, 32, 32):
                with T.block("C"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(A[v_i, v_k], B[v_k, v_j])
                    T.writes(C[v_i, v_j])
                    with T.init():
                        C[v_i, v_j] = T.float16(0)
                    C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def matmul(
            A: T.Buffer((32, 32), "float16"),
            B: T.Buffer((32, 32), "float16"),
            C: T.Buffer((32, 32), "float16"),
        ):
            T.func_attr({"tir.is_scheduled": True, "global_symbol": "main", "tir.noalias": True})
            # with T.block("root"):
            for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
                for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                    for k in range(32):
                        with T.block("C"):
                            v_i = T.axis.spatial(
                                32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32
                            )
                            v_j = T.axis.spatial(
                                32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32
                            )
                            v_k = T.axis.reduce(32, k)
                            T.reads(A[v_i, v_k], B[v_k, v_j])
                            T.writes(C[v_i, v_j])
                            with T.init():
                                C[v_i, v_j] = T.float16(0)
                            C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]

        @T.prim_func
        def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")):
            T.func_attr({"global_symbol": "main", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            for i, j, k in T.grid(32, 32, 32):
                with T.block("C"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(A[v_i, v_k], B[v_k, v_j])
                    T.writes(C[v_i, v_j])
                    with T.init():
                        C[v_i, v_j] = T.float16(0)
                    C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]

        @T.prim_func
        def matmul_gpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")):
            T.func_attr({"global_symbol": "main", "target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
                for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                    for k in range(32):
                        with T.block("C"):
                            v_i = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32)
                            v_j = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32)
                            v_k = T.axis.reduce(32, k)
                            T.reads(A[v_i, v_k], B[v_k, v_j])
                            T.writes(C[v_i, v_j])
                            with T.init():
                                C[v_i, v_j] = T.float16(0)
                            C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
    # fmt: on
    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long

    target = tvm.target.Target("nvidia/geforce-rtx-3070")
    with target, tvm.transform.PassContext(opt_level=3):
        After = DefaultGPUSchedule()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
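

# Static-shape broadcasting add: all four spatial loops are fused and bound to a
# single block of 72 threads.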
def test_add():
    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
    # fmt: off
    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
                with T.block("T_add"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
                    T.writes(T_add[ax0, ax1, ax2, ax3])
                    T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def add(
            rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"),
            rxplaceholder_1: T.Buffer(
                (T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"
            ),
            T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32"),
        ):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                for i0_i1_i2_i3_fused_1 in T.thread_binding(
                    T.int64(72), thread="threadIdx.x"
                ):
                    with T.block("T_add"):
                        ax0 = T.axis.spatial(
                            T.int64(4),
                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
                            // T.int64(18),
                        )
                        ax1 = T.axis.spatial(
                            T.int64(3),
                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
                            % T.int64(18)
                            // T.int64(6),
                        )
                        ax2 = T.axis.spatial(
                            T.int64(2),
                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
                            % T.int64(6)
                            // T.int64(3),
                        )
                        ax3 = T.axis.spatial(
                            T.int64(3),
                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
                            % T.int64(3),
                        )
                        T.reads(
                            rxplaceholder[T.int64(0), ax2, ax3],
                            rxplaceholder_1[ax0, ax1, ax2, T.int64(0)],
                        )
                        T.writes(T_add[ax0, ax1, ax2, ax3])
                        T_add[ax0, ax1, ax2, ax3] = (
                            rxplaceholder[T.int64(0), ax2, ax3]
                            + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
                        )
    # fmt: on
    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long

    target = tvm.target.Target("nvidia/geforce-rtx-3070")
    with target, tvm.transform.PassContext(opt_level=3):
        After = DefaultGPUSchedule()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
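

# Fill kernel with a 0-d input buffer: the 2x3 loop nest is fused and bound to one
# block of 6 threads.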
def test_full():
    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
    # fmt: off
    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_full"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[()])
                    T.writes(T_full[ax0, ax1])
                    T_full[ax0, ax1] = rxplaceholder[()]

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def full(
            rxplaceholder: T.Buffer((), "int32"),
            T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
        ):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"):
                    with T.block("T_full"):
                        ax0 = T.axis.spatial(
                            T.int64(2),
                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3),
                        )
                        ax1 = T.axis.spatial(
                            T.int64(3),
                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) % T.int64(3),
                        )
                        T.reads(rxplaceholder[()])
                        T.writes(T_full[ax0, ax1])
                        T_full[ax0, ax1] = rxplaceholder[()]
    # fmt: on
    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long

    target = tvm.target.Target("nvidia/geforce-rtx-3070")
    with target, tvm.transform.PassContext(opt_level=3):
        After = DefaultGPUSchedule()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
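

# A PrimFunc that already carries tir.is_scheduled must pass through the pass untouched.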
def test_scheduled():
    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
    # fmt: off
    @tvm.script.ir_module
    class Scheduled:
        @T.prim_func
        def full(
            rxplaceholder: T.Buffer((), "int32"),
            T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
        ):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"):
                    with T.block("T_full"):
                        ax0 = T.axis.spatial(
                            T.int64(2),
                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3),
                        )
                        ax1 = T.axis.spatial(
                            T.int64(3),
                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) % T.int64(3),
                        )
                        T.reads(rxplaceholder[()])
                        T.writes(T_full[ax0, ax1])
                        T_full[ax0, ax1] = rxplaceholder[()]
    # fmt: on
    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long

    target = tvm.target.Target("nvidia/geforce-rtx-3070")
    with target, tvm.transform.PassContext(opt_level=3):
        # should do nothing
        After = DefaultGPUSchedule()(Scheduled)
    tvm.ir.assert_structural_equal(After, Scheduled)
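

# Every unscheduled PrimFunc in the module gets thread bindings, not just the first one.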
def test_multiple():
    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
    # fmt: off
    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
                with T.block("T_add"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
                    T.writes(T_add[ax0, ax1, ax2, ax3])
                    T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]

        @T.prim_func
        def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_full"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[()])
                    T.writes(T_full[ax0, ax1])
                    T_full[ax0, ax1] = rxplaceholder[()]

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def add(
            rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"),
            rxplaceholder_1: T.Buffer(
                (T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"
            ),
            T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32"),
        ):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                for i0_i1_i2_i3_fused_1 in T.thread_binding(
                    T.int64(72), thread="threadIdx.x"
                ):
                    with T.block("T_add"):
                        ax0 = T.axis.spatial(
                            T.int64(4),
                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
                            // T.int64(18),
                        )
                        ax1 = T.axis.spatial(
                            T.int64(3),
                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
                            % T.int64(18)
                            // T.int64(6),
                        )
                        ax2 = T.axis.spatial(
                            T.int64(2),
                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
                            % T.int64(6)
                            // T.int64(3),
                        )
                        ax3 = T.axis.spatial(
                            T.int64(3),
                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
                            % T.int64(3),
                        )
                        T.reads(
                            rxplaceholder[T.int64(0), ax2, ax3],
                            rxplaceholder_1[ax0, ax1, ax2, T.int64(0)],
                        )
                        T.writes(T_add[ax0, ax1, ax2, ax3])
                        T_add[ax0, ax1, ax2, ax3] = (
                            rxplaceholder[T.int64(0), ax2, ax3]
                            + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
                        )

        @T.prim_func
        def full(
            rxplaceholder: T.Buffer((), "int32"),
            T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
        ):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"):
                    with T.block("T_full"):
                        ax0 = T.axis.spatial(
                            T.int64(2),
                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3),
                        )
                        ax1 = T.axis.spatial(
                            T.int64(3),
                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) % T.int64(3),
                        )
                        T.reads(rxplaceholder[()])
                        T.writes(T_full[ax0, ax1])
                        T_full[ax0, ax1] = rxplaceholder[()]
    # fmt: on
    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long

    target = tvm.target.Target("nvidia/geforce-rtx-3070")
    with target, tvm.transform.PassContext(opt_level=3):
        After = DefaultGPUSchedule()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
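

# The same broadcasting add, scheduled under an apple/m1-gpu (Metal) target rather than CUDA.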
def test_add_on_metal():
    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
    # fmt: off
    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
                with T.block("T_add"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
                    T.writes(T_add[ax0, ax1, ax2, ax3])
                    T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(72), thread="threadIdx.x"):
                    with T.block("T_add"):
                        ax0 = T.axis.spatial(T.int64(4), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) // T.int64(18))
                        ax1 = T.axis.spatial(T.int64(3), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(18) // T.int64(6))
                        ax2 = T.axis.spatial(T.int64(2), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(6) // T.int64(3))
                        ax3 = T.axis.spatial(T.int64(3), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(3))
                        T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
                        T.writes(T_add[ax0, ax1, ax2, ax3])
                        T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
    # fmt: on
    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long

    target = tvm.target.Target("apple/m1-gpu")
    with target, tvm.transform.PassContext(opt_level=0):
        mod = DefaultGPUSchedule()(Before)
    tvm.ir.assert_structural_equal(mod, Expected)
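

# A block with no surrounding loops: the pass wraps it in trivial blockIdx.x/threadIdx.x
# bindings of extent 1.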
def test_scalar_add():
    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
    # fmt: off
    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def add(rxplaceholder: T.Buffer((), "int64"), T_add: T.Buffer((), "int64")):
            T.func_attr({"tir.noalias": True})
            with T.block("T_add"):
                vi = T.axis.spatial(1, T.int64(0))
                T.reads(rxplaceholder[()])
                T.writes(T_add[()])
                T_add[()] = rxplaceholder[()] + T.int64(1)

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def add(rxplaceholder: T.Buffer((), "int64"), T_add: T.Buffer((), "int64")):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
                for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
                    with T.block("T_add"):
                        vi = T.axis.spatial(1, T.int64(0))
                        T.reads(rxplaceholder[()])
                        T.writes(T_add[()])
                        T_add[()] = rxplaceholder[()] + T.int64(1)
    # fmt: on
    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long

    target = tvm.target.Target("nvidia/geforce-rtx-3070")
    with target, tvm.transform.PassContext(opt_level=0):
        mod = DefaultGPUSchedule()(Before)
    tvm.ir.assert_structural_equal(mod, Expected)
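

# The reduction loops must stay serial; only a trivial 1x1 thread binding is added around the block.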
def test_sum():
    # sum has two reduction axes and no spatial axis
    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
    # fmt: off
    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "float64")):
            for k0, k1 in T.grid(T.int64(2), T.int64(2)):
                with T.block("A_red"):
                    v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
                    with T.init():
                        A_red[()] = T.float64(0)
                    A_red[()] = A_red[()] + A[v_k0, v_k1]

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "float64")):
            T.func_attr({"tir.is_scheduled": True})
            # with T.block("root"):
            for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
                for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
                    for k0, k1 in T.grid(T.int64(2), T.int64(2)):
                        with T.block("A_red"):
                            v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
                            T.reads(A[v_k0, v_k1])
                            T.writes(A_red[()])
                            with T.init():
                                A_red[()] = T.float64(0)
                            A_red[()] = A_red[()] + A[v_k0, v_k1]
    # fmt: on
    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long

    target = tvm.target.Target("nvidia/geforce-rtx-3070")
    with target, tvm.transform.PassContext(opt_level=0):
        mod = DefaultGPUSchedule()(Before)
    tvm.ir.assert_structural_equal(mod, Expected)


if __name__ == "__main__":
    tvm.testing.main()