blob: 1fda0f4321088be83c61b8cfd2c7d3d9603aa274 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys
import pytest
import tvm
import tvm.testing
from tvm import tir
from tvm.script import tir as T
from tvm.tir.schedule.testing import (
verify_trace_roundtrip,
assert_structural_equal_ignore_global_symbol,
)
# pylint: disable=no-member,invalid-name,unused-variable
########## Function before schedule ##########
# Base workload: two elementwise stages (B = A * 2, then C = B + 1) over
# 128x128 buffers, sharing the intermediate buffer B. Used as input for
# cache_read / cache_write schedule tests.
@T.prim_func
def elementwise(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
# Same two-stage elementwise workload as `elementwise`, but with int64 buffer
# shapes; exercises dtype handling of the cached buffer's shape.
@T.prim_func
def elementwise_shape_int64(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (T.int64(128), T.int64(128)))
    B = T.alloc_buffer((T.int64(128), T.int64(128)))
    C = T.match_buffer(c, (T.int64(128), T.int64(128)))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
# Expected output of reindex_cache_read on `elementwise`: B is staged into a
# shared-scope cache whose layout is remapped to [vj, vi // 2, vi % 2].
@T.prim_func
def elementwise_reindex_cache_read(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
):
    B = T.alloc_buffer((128, 128))
    B_shared = T.alloc_buffer((128, 64, 2), scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        # Cache-fill stage copying B into the reindexed shared buffer.
        with T.block("B_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(B_shared[vj, vi // 2, vi % 2])
            B_shared[vj, vi // 2, vi % 2] = B[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B_shared[vj, vi // 2, vi % 2])
            T.writes(C[vi, vj])
            C[vi, vj] = B_shared[vj, vi // 2, vi % 2] + T.float32(1)
# Expected output of reindex_cache_write on `elementwise`: block "B" writes
# into a transposed shared cache ([vj, vi]) which is then copied back into B.
@T.prim_func
def elementwise_reindex_cache_write(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
):
    B = T.alloc_buffer((128, 128))
    B_shared = T.alloc_buffer((128, 128), scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(B_shared[vj, vi])
            B_shared[vj, vi] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        # Write-back stage restoring the original [vi, vj] layout.
        with T.block("B_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B_shared[vj, vi])
            T.writes(B[vi, vj])
            B[vi, vj] = B_shared[vj, vi]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + T.float32(1)
# Two-level reduction workload: block "B" reduces A over axis l, then block
# "C" reduces B over axis k. Input for reindex_cache_write-on-reduction tests.
@T.prim_func
def reduce(A: T.Buffer((128, 128, 128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
    B = T.alloc_buffer((128, 128, 128), dtype="float32")
    for i, j, k in T.grid(128, 128, 128):
        for l in range(128):
            with T.block("B"):
                vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
                with T.init():
                    B[vi, vj, vk] = T.float32(0)
                B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
        # "C" sits inside the i/j/k grid, sibling to the inner l loop.
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + B[vi, vj, vk]
# Expected output of reindex_cache_write applied to block "B" of `reduce`:
# the reduction accumulates into a transposed shared cache ([vj, vi, vk]),
# then block "B_shared" writes the result back into B.
@T.prim_func
def reduce_reindex_cache_write_0(
    A: T.Buffer((128, 128, 128, 128), "float32"), C: T.Buffer((128, 128), "float32")
):
    B = T.alloc_buffer((128, 128, 128))
    B_shared = T.alloc_buffer((128, 128, 128), scope="shared")
    for i, j, k in T.grid(128, 128, 128):
        for l in range(128):
            with T.block("B"):
                vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
                T.reads(A[vi, vj, vk, vl])
                T.writes(B_shared[vj, vi, vk])
                with T.init():
                    B_shared[vj, vi, vk] = T.float32(0)
                B_shared[vj, vi, vk] = B_shared[vj, vi, vk] + A[vi, vj, vk, vl]
        with T.block("B_shared"):
            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
            T.reads(B_shared[vj, vi, vk])
            T.writes(B[vi, vj, vk])
            B[vi, vj, vk] = B_shared[vj, vi, vk]
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(B[vi, vj, vk])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + B[vi, vj, vk]
# Expected output after reindex_cache_write is applied to BOTH reduction
# blocks of `reduce`: B and C each accumulate into a transposed shared cache,
# with write-back blocks "B_shared" and "C_shared" restoring the layouts.
@T.prim_func
def reduce_reindex_cache_write_1(
    A: T.Buffer((128, 128, 128, 128), "float32"), C: T.Buffer((128, 128), "float32")
):
    B = T.alloc_buffer((128, 128, 128))
    B_shared = T.alloc_buffer((128, 128, 128), scope="shared")
    C_shared = T.alloc_buffer((128, 128), scope="shared")
    for i, j, k in T.grid(128, 128, 128):
        for l in range(128):
            with T.block("B"):
                vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
                T.reads(A[vi, vj, vk, vl])
                T.writes(B_shared[vj, vi, vk])
                with T.init():
                    B_shared[vj, vi, vk] = T.float32(0)
                B_shared[vj, vi, vk] = B_shared[vj, vi, vk] + A[vi, vj, vk, vl]
        with T.block("B_shared"):
            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
            T.reads(B_shared[vj, vi, vk])
            T.writes(B[vi, vj, vk])
            B[vi, vj, vk] = B_shared[vj, vi, vk]
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(B[vi, vj, vk])
            T.writes(C_shared[vj, vi])
            with T.init():
                C_shared[vj, vi] = T.float32(0)
            C_shared[vj, vi] = C_shared[vj, vi] + B[vi, vj, vk]
    for i, j in T.grid(128, 128):
        # Final write-back of the transposed C cache into the output buffer.
        with T.block("C_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(C_shared[vj, vi])
            T.writes(C[vi, vj])
            C[vi, vj] = C_shared[vj, vi]
# Workload with a nested SeqStmt: blocks "B0" and "B1" live under a shared
# outer 8x8 tile loop. Tests cache insertion position inside nested sequences.
@T.prim_func
def func_nested_seq(b: T.handle, c: T.handle) -> None:
    A = T.alloc_buffer((128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("A"):
            vi, vj = T.axis.remap("SS", [i, j])
            A[vi, vj] = 2.0
    for i, j in T.grid(8, 8):
        for x, y in T.grid(16, 16):
            with T.block("B0"):
                vi = T.axis.S(128, i * 16 + x)
                vj = T.axis.S(128, j * 16 + y)
                B[vi, vj] = 1.0
        for x, y in T.grid(16, 16):
            with T.block("B1"):
                vi = T.axis.S(128, i * 16 + x)
                vj = T.axis.S(128, j * 16 + y)
                B[vi, vj] = A[vi, vj] + B[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0
# Workload where A is produced and consumed inside an explicit "scope" block
# and also consumed outside it; tests caching under a non-root scope.
@T.prim_func
def access_under_scope(b: T.handle, c: T.handle) -> None:
    A = T.alloc_buffer((128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i0, j0 in T.grid(8, 8):
        with T.block("scope"):
            i, j = T.axis.remap("SS", [i0, j0])
            for x, y in T.grid(16, 16):
                with T.block("A"):
                    vi = T.axis.S(128, i * 16 + x)
                    vj = T.axis.S(128, j * 16 + y)
                    A[vi, vj] = 1.0
            for x, y in T.grid(16, 16):
                with T.block("B"):
                    vi = T.axis.S(128, i * 16 + x)
                    vj = T.axis.S(128, j * 16 + y)
                    B[vi, vj] = A[vi, vj] + 1.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0
# Workload with three kinds of buffer access: a plain load/store, an opaque
# access via tvm_load_matrix_sync on raw buffer data, and an access routed
# through nested T.match_buffer views. Tests that cache_read/cache_write
# rewrite opaque and match_buffer accesses, not just BufferLoad/Store.
@T.prim_func
def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")
    for i, j in T.grid(128, 128):
        with T.block("load_store"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(D[vi, vj])
            D[vi, vj] = A[vi, vj]
    for i, j in T.grid(8, 8):
        # Opaque access: touches A.data/B.data directly, so the read/write
        # regions must be declared explicitly per 16x16 tile.
        with T.block("opaque"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.evaluate(
                T.tvm_load_matrix_sync(
                    B.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A.data,
                        vi * 2048 + vj * 16,
                        128,
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
    for i, j in T.grid(8, 8):
        # Access through nested match_buffer views A0/C0 over 16x16 subregions.
        with T.block("match_buffer"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            A0 = T.match_buffer(
                A[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            C0 = T.match_buffer(
                C[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            T.evaluate(
                T.tvm_load_matrix_sync(
                    C0.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A0.data,
                        A0.elem_offset,
                        A0.strides[0],
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
# Workload where buffer A has two consumers ("B" under the tiled loop and "C"
# at root); tests consumer selection when inserting a cache stage.
@T.prim_func
def func_multi_consumer() -> None:
    A = T.alloc_buffer((128))
    B = T.alloc_buffer((128))
    C = T.alloc_buffer((128))
    for i in T.grid(8):
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A[vi] = 1.0
        for j in T.grid(16):
            with T.block("B"):
                vi = T.axis.S(128, i * 16 + j)
                B[vi] = A[vi] + 1.0
    for i in T.grid(128):
        with T.block("C"):
            vi = T.axis.S(128, i)
            C[vi] = A[vi]
# Expected output of reindex_cache_read on `func_multi_consumer`: only
# consumer "B" reads the reindexed shared cache A_shared[vi // 32, vi % 32];
# consumer "C" still reads A directly.
@T.prim_func
def reindex_cache_read_multi_consumer() -> None:
    A = T.alloc_buffer((128,))
    B = T.alloc_buffer((128,))
    C = T.alloc_buffer((128,))
    A_shared = T.alloc_buffer((4, 32), scope="shared")
    for i in range(8):
        for j in range(16):
            with T.block("A"):
                vi = T.axis.spatial(128, i * 16 + j)
                T.reads()
                T.writes(A[vi])
                A[vi] = T.float32(1)
        for j in range(16):
            with T.block("A_shared"):
                vi = T.axis.spatial(128, i * 16 + j)
                T.reads(A[vi])
                T.writes(A_shared[vi // 32, vi % 32])
                A_shared[vi // 32, vi % 32] = A[vi]
        for j in range(16):
            with T.block("B"):
                vi = T.axis.spatial(128, i * 16 + j)
                T.reads(A_shared[vi // 32, vi % 32])
                T.writes(B[vi])
                B[vi] = A_shared[vi // 32, vi % 32] + T.float32(1)
    for i in range(128):
        with T.block("C"):
            vi = T.axis.spatial(128, i)
            T.reads(A[vi])
            T.writes(C[vi])
            C[vi] = A[vi]
# Workload where buffer A is written by two producer blocks ("A0", "A1");
# used to check that caching a multi-producer buffer is rejected.
@T.prim_func
def func_multi_producer() -> None:
    A = T.alloc_buffer((128))
    B = T.alloc_buffer((128))
    for i in range(128):
        with T.block("A0"):
            vi = T.axis.S(128, i)
            A[vi] = 1.0
    for i in range(128):
        with T.block("A1"):
            vi = T.axis.S(128, i)
            A[vi] = 2.0
    for i in range(128):
        with T.block("B"):
            vi = T.axis.S(128, i)
            B[vi] = A[vi]
# Workload whose blocks carry a T.where predicate (120 elements padded to a
# 16x8 grid); tests cache stages generated for predicated blocks.
@T.prim_func
def func_with_block_predicate() -> None:
    A = T.alloc_buffer((120))
    B = T.alloc_buffer((120))
    for i, j in T.grid(16, 8):
        with T.block("producer"):
            T.where(i * 8 + j < 120)
            ax = T.axis.S(120, i * 8 + j)
            A[ax] = 0.0
    for i, j in T.grid(16, 8):
        with T.block("consumer"):
            T.where(i * 8 + j < 120)
            ax = T.axis.S(120, i * 8 + j)
            B[ax] = A[ax] + 1.0
# Workload that copies data_io into a scratch buffer, mutates the scratch
# in place through an opaque extern call, and copies it back; tests caching
# around an in-place (read+write) region.
@T.prim_func
def inplace_func(data_io: T.Buffer((64), "int32")):
    data_1d = T.alloc_buffer([64], dtype="int32")
    for i0 in T.serial(64):
        with T.block("copy_in"):
            v0 = T.axis.remap("S", [i0])
            data_1d[v0] = data_io[v0]
    for i0 in T.serial(1):
        with T.block("ext_call"):
            T.reads(data_1d[:64])
            T.writes(data_1d[:64])
            T.evaluate(T.call_extern("call_impl", data_1d.data, dtype=""))
    for i0 in T.serial(64):
        with T.block("copy_out"):
            v0 = T.axis.remap("S", [i0])
            data_io[v0] = data_1d[v0]
# Minimal in-place workload: a single extern call that both reads and writes
# the function argument buffer directly.
@T.prim_func
def inplace_call(data_io: T.Buffer((64), "int32")):
    for i0 in T.serial(1):
        with T.block("ext_call"):
            T.reads(data_io[:64])
            T.writes(data_io[:64])
            T.evaluate(T.call_extern("call_impl", data_io.data, dtype=""))
# Expected output of cache_read on `func_nested_seq`: A_global is filled
# after the nested B0/B1 sequence and consumed by block "C".
@T.prim_func
def cache_read_nested_seq_target(
    B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    A = T.alloc_buffer([128, 128], dtype="float32")
    A_global = T.alloc_buffer([128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        with T.block("A"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads()
            T.writes(A[vi, vj])
            A[vi, vj] = T.float32(2)
    for i, j in T.grid(8, 8):
        for x, y in T.grid(16, 16):
            with T.block("B0"):
                vi = T.axis.spatial(128, i * 16 + x)
                vj = T.axis.spatial(128, j * 16 + y)
                T.reads()
                T.writes(B[vi, vj])
                B[vi, vj] = T.float32(1)
        for x, y in T.grid(16, 16):
            with T.block("B1"):
                vi = T.axis.spatial(128, i * 16 + x)
                vj = T.axis.spatial(128, j * 16 + y)
                T.reads(A[vi, vj], B[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] + B[vi, vj]
    for ax0, ax1 in T.grid(128, 128):
        with T.block("A_global"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_global[v0, v1])
            A_global[v0, v1] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_global[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = A_global[vi, vj] * T.float32(2)
# Workload with a nested buffer access: B's value is used as an index into A
# (A[B[v_ax0], v_ax1]); tests cache_read rewriting of index expressions.
@T.prim_func
def nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
    A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
    B = T.match_buffer(var_B, T.int64(1), dtype="int32")
    C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
    for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
        with T.block("C"):
            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[B[v_ax0], v_ax1], B[v_ax0])
            T.writes(C[v_ax0, v_ax1])
            C[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]
########## Expected function after cache_read ##########
# Expected output of cache_read on `elementwise`: A is staged through
# A_global (global scope) for block "B", and B through B_local (local scope)
# for block "C".
@T.prim_func
def cache_read_elementwise(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    A_global = T.alloc_buffer((128, 128))
    B_local = T.alloc_buffer((128, 128), scope="local")
    for i, j in T.grid(128, 128):
        with T.block("A_global"):
            vi, vj = T.axis.remap("SS", [i, j])
            A_global[vi, vj] = A[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A_global[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("B_local"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_local[vi, vj] = B[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B_local[vi, vj] + 1.0
# Expected output of cache_read on `access_under_scope`: A_local is allocated
# INSIDE the "scope" block (sized to the 16x16 tile), while A_global is
# inserted at root for the outer consumer "C".
@T.prim_func
def cache_read_under_scope(b: T.handle, c: T.handle) -> None:
    A = T.alloc_buffer((128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    A_global = T.alloc_buffer((128, 128))
    for i0, j0 in T.grid(8, 8):
        with T.block("scope"):
            i, j = T.axis.remap("SS", [i0, j0])
            A_local = T.alloc_buffer((16, 16), scope="local")
            for x, y in T.grid(16, 16):
                with T.block("A"):
                    vi = T.axis.S(128, i * 16 + x)
                    vj = T.axis.S(128, j * 16 + y)
                    A[vi, vj] = 1.0
            for x, y in T.grid(16, 16):
                # Tile-local cache fill: A_local indices are tile-relative.
                with T.block("A_local"):
                    vi = T.axis.S(16, x)
                    vj = T.axis.S(16, y)
                    A_local[vi, vj] = A[i * 16 + vi, j * 16 + vj]
            for x, y in T.grid(16, 16):
                with T.block("B"):
                    vi = T.axis.S(128, i * 16 + x)
                    vj = T.axis.S(128, j * 16 + y)
                    B[vi, vj] = A_local[vi - i * 16, vj - j * 16] + 1.0
    for i, j in T.grid(128, 128):
        with T.block("A_global"):
            vi, vj = T.axis.remap("SS", [i, j])
            A_global[vi, vj] = A[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A_global[vi, vj] * 2.0
# Expected output of cache_read on `opaque_access`: every access to A —
# the plain load, the opaque A.data pointer, and the match_buffer view —
# is redirected to the A_global cache.
@T.prim_func
def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")
    A_global = T.alloc_buffer((128, 128), dtype="float16")
    for i, j in T.grid(128, 128):
        with T.block("A_global"):
            vi, vj = T.axis.remap("SS", [i, j])
            A_global[vi, vj] = A[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("load_store"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_global[vi, vj])
            T.writes(D[vi, vj])
            D[vi, vj] = A_global[vi, vj]
    for i, j in T.grid(8, 8):
        with T.block("opaque"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.evaluate(
                T.tvm_load_matrix_sync(
                    B.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A_global.data,
                        vi * 2048 + vj * 16,
                        128,
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
    for i, j in T.grid(8, 8):
        with T.block("match_buffer"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            A0 = T.match_buffer(
                A_global[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            C0 = T.match_buffer(
                C[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            T.evaluate(
                T.tvm_load_matrix_sync(
                    C0.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A0.data,
                        A0.elem_offset,
                        A0.strides[0],
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
# Expected output of cache_read on `func_multi_consumer` when BOTH consumers
# use the cache: the fill stage (reusing block name "A") is inserted inside
# the tiled loop and both "B" and "C" read A_global.
@T.prim_func
def cache_read_multi_consumer() -> None:
    A = T.alloc_buffer((128))
    B = T.alloc_buffer((128))
    C = T.alloc_buffer((128))
    A_global = T.alloc_buffer((128))
    for i in T.grid(8):
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A[vi] = 1.0
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A_global[vi] = A[vi]
        for j in T.grid(16):
            with T.block("B"):
                vi = T.axis.S(128, i * 16 + j)
                B[vi] = A_global[vi] + 1.0
    for i in T.grid(128):
        with T.block("C"):
            vi = T.axis.S(128, i)
            C[vi] = A_global[vi]
# Expected output of cache_read on `func_multi_consumer` when only consumer
# "C" is targeted: "B" keeps reading A, and the fill stage sits at root just
# before "C".
@T.prim_func
def cache_read_multi_consumer_target() -> None:
    A = T.alloc_buffer((128))
    B = T.alloc_buffer((128))
    C = T.alloc_buffer((128))
    A_global = T.alloc_buffer((128))
    for i in T.grid(8):
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A[vi] = 1.0
        for j in T.grid(16):
            with T.block("B"):
                vi = T.axis.S(128, i * 16 + j)
                B[vi] = A[vi] + 1.0
    for i in T.grid(128):
        with T.block("A"):
            vi = T.axis.S(128, i)
            A_global[vi] = A[vi]
    for i in T.grid(128):
        with T.block("C"):
            vi = T.axis.S(128, i)
            C[vi] = A_global[vi]
# Expected output of two chained cache_read calls on `elementwise`:
# B -> B_shared -> B_local, with "C" reading the innermost cache.
@T.prim_func
def continuous_cache_read(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    B_shared = T.alloc_buffer((128, 128), scope="shared")
    B_local = T.alloc_buffer((128, 128), scope="local")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("B_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_shared[vi, vj] = B[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("B_local"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_local[vi, vj] = B_shared[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B_local[vi, vj] + 1.0
# Expected output of cache_read on `func_with_block_predicate`: the fill loop
# covers all 120 elements without a predicate, while producer and consumer
# keep their T.where guards.
@T.prim_func
def block_predicate_cache_read() -> None:
    A = T.alloc_buffer([120], dtype="float32")
    B = T.alloc_buffer([120], dtype="float32")
    A_shared = T.alloc_buffer([120], dtype="float32", scope="shared")
    for i, j in T.grid(16, 8):
        with T.block("producer"):
            ax = T.axis.spatial(120, i * 8 + j)
            T.where(i * 8 + j < 120)
            A[ax] = T.float32(0)
    for ax0 in T.serial(120):
        with T.block("A_shared"):
            v0 = T.axis.spatial(120, ax0)
            A_shared[v0] = A[v0]
    for i, j in T.grid(16, 8):
        with T.block("consumer"):
            ax = T.axis.spatial(120, i * 8 + j)
            T.where(i * 8 + j < 120)
            B[ax] = A_shared[ax] + T.float32(1)
# Expected output of cache_read on `elementwise_shape_int64`: the cache-fill
# loop extents inherit the int64 shape of the source buffer.
@T.prim_func
def cache_read_shape_int64(var_A: T.handle, var_C: T.handle) -> None:
    A = T.match_buffer(var_A, (T.int64(128), T.int64(128)), dtype="float32")
    C = T.match_buffer(var_C, (T.int64(128), T.int64(128)), dtype="float32")
    B = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32")
    A_global = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32")
    for ax0, ax1 in T.grid(T.int64(128), T.int64(128)):
        with T.block("A_global"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_global[v0, v1])
            A_global[v0, v1] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_global[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A_global[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + T.float32(1)
# Expected output of cache_read on `inplace_func`: the copy_in stage now
# reads through a local cache of data_io; the in-place extern call and the
# copy_out stage are unchanged.
@T.prim_func
def cache_read_inplace(data_io: T.Buffer(64, "int32")) -> None:
    data_1d = T.alloc_buffer([64], dtype="int32")
    data_io_local = T.alloc_buffer([64], dtype="int32", scope="local")
    for ax0 in T.serial(64):
        with T.block("data_io_local"):
            v0 = T.axis.spatial(64, ax0)
            T.reads(data_io[v0])
            T.writes(data_io_local[v0])
            data_io_local[v0] = data_io[v0]
    for i0 in T.serial(64):
        with T.block("copy_in"):
            v0 = T.axis.spatial(64, i0)
            T.reads(data_io_local[v0])
            T.writes(data_1d[v0])
            data_1d[v0] = data_io_local[v0]
    for i0 in T.serial(1):
        with T.block("ext_call"):
            T.reads(data_1d[0:64])
            T.writes(data_1d[0:64])
            T.evaluate(T.call_extern("call_impl", data_1d.data, dtype=""))
    for i0 in T.serial(64):
        with T.block("copy_out"):
            v0 = T.axis.spatial(64, i0)
            T.reads(data_1d[v0])
            T.writes(data_io[v0])
            data_io[v0] = data_1d[v0]
# Expected output of cache_inplace on `inplace_call`: the read+write region
# of the extern call is wrapped by a local read-cache and a local write-cache
# (both named "data_io_local"), plus global staging buffers on each side.
@T.prim_func
def cache_inplace_buffer(data_io: T.Buffer(64, "int32")) -> None:
    data_io_local = T.alloc_buffer([64], dtype="int32", scope="local")
    data_io_global = T.alloc_buffer([64], dtype="int32")
    data_io_global_1 = T.alloc_buffer([64], dtype="int32")
    for ax0 in T.serial(64):
        with T.block("data_io_global"):
            v0 = T.axis.spatial(64, ax0)
            T.reads(data_io[v0])
            T.writes(data_io_global[v0])
            data_io_global[v0] = data_io[v0]
    for i0 in T.serial(1):
        for ax0 in T.serial(64):
            with T.block("data_io_local"):
                v0 = T.axis.spatial(64, ax0)
                T.reads(data_io_global[v0])
                T.writes(data_io_local[v0])
                data_io_local[v0] = data_io_global[v0]
        with T.block("ext_call"):
            T.reads(data_io_local[0:64])
            T.writes(data_io_local[0:64])
            T.evaluate(T.call_extern("call_impl", data_io_local.data, dtype=""))
        for ax0 in T.serial(64):
            with T.block("data_io_local"):
                v0 = T.axis.spatial(64, ax0)
                T.reads(data_io_local[v0])
                T.writes(data_io_global_1[v0])
                data_io_global_1[v0] = data_io_local[v0]
    for ax0 in T.serial(64):
        with T.block("data_io_global"):
            v0 = T.axis.spatial(64, ax0)
            T.reads(data_io_global_1[v0])
            T.writes(data_io[v0])
            data_io[v0] = data_io_global_1[v0]
# Expected output of cache_read on `nested_buffer_access`: B is cached into
# B_global, and the nested index expression A[B[...]] is rewritten to read
# the cached copy.
@T.prim_func
def cache_read_nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
    A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
    B = T.match_buffer(var_B, T.int64(1), dtype="int32")
    C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
    B_global = T.alloc_buffer((T.int64(1),), "int32")
    for ax0 in range(T.int64(1)):
        with T.block("B_global"):
            v0 = T.axis.spatial(T.int64(1), ax0)
            T.reads(B[v0])
            T.writes(B_global[v0])
            B_global[v0] = B[v0]
    for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
        with T.block("C"):
            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[B_global[v_ax0], v_ax1], B_global[v_ax0])
            T.writes(C[v_ax0, v_ax1])
            C[v_ax0, v_ax1] = A[B_global[v_ax0], v_ax1]
########## Expected function after cache_write ##########
# Expected output of cache_write on `elementwise`: "B" computes into B_global
# (local scope) and "C" into C_local, each followed by a write-back stage.
@T.prim_func
def cache_write_elementwise(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    B_global = T.alloc_buffer((128, 128), scope="local")
    C_local = T.alloc_buffer((128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B_global"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_global[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = B_global[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("C_local"):
            vi, vj = T.axis.remap("SS", [i, j])
            C_local[vi, vj] = B[vi, vj] + 1.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = C_local[vi, vj]
# Expected output of cache_write on `access_under_scope`: tile-sized write
# caches (A_local, B_global) are allocated inside the "scope" block with
# tile-relative indexing, and A's write-back (A_global) happens at root.
@T.prim_func
def cache_write_under_scope(b: T.handle, c: T.handle) -> None:
    A = T.alloc_buffer((128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    A_global = T.alloc_buffer((128, 128))
    for i0, j0 in T.grid(8, 8):
        with T.block("scope"):
            i, j = T.axis.remap("SS", [i0, j0])
            A_local = T.alloc_buffer((16, 16), scope="local")
            B_global = T.alloc_buffer((16, 16))
            for x, y in T.grid(16, 16):
                with T.block("A_local"):
                    vi = T.axis.S(128, i * 16 + x)
                    vj = T.axis.S(128, j * 16 + y)
                    A_local[vi - i * 16, vj - j * 16] = 1.0
            for x, y in T.grid(16, 16):
                with T.block("A"):
                    vi = T.axis.S(16, x)
                    vj = T.axis.S(16, y)
                    A_global[i * 16 + vi, j * 16 + vj] = A_local[vi, vj]
            for x, y in T.grid(16, 16):
                with T.block("B"):
                    vi = T.axis.S(128, i * 16 + x)
                    vj = T.axis.S(128, j * 16 + y)
                    B_global[vi - i * 16, vj - j * 16] = A_global[vi, vj] + 1.0
            for x, y in T.grid(16, 16):
                with T.block("B_global"):
                    vi = T.axis.S(16, x)
                    vj = T.axis.S(16, y)
                    B[i * 16 + vi, j * 16 + vj] = B_global[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("A_global"):
            vi, vj = T.axis.remap("SS", [i, j])
            A[vi, vj] = A_global[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0
# Expected output of cache_write on `opaque_access`: each output buffer
# (D, B, C) is computed into a *_global cache — including the opaque
# B.data write and the match_buffer-based C write — then copied out by
# dedicated write-back blocks at the end.
@T.prim_func
def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")
    D_global = T.alloc_buffer((128, 128), dtype="float16")
    B_global = T.alloc_buffer((128, 128), dtype="float16")
    C_global = T.alloc_buffer((128, 128), dtype="float16")
    for i, j in T.grid(128, 128):
        with T.block("load_store"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(D_global[vi, vj])
            D_global[vi, vj] = A[vi, vj]
    for i, j in T.grid(8, 8):
        with T.block("opaque"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.evaluate(
                T.tvm_load_matrix_sync(
                    B_global.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A.data,
                        vi * 2048 + vj * 16,
                        128,
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
    for i, j in T.grid(8, 8):
        with T.block("match_buffer"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            A0 = T.match_buffer(
                A[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            C0 = T.match_buffer(
                C_global[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            T.evaluate(
                T.tvm_load_matrix_sync(
                    C0.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A0.data,
                        A0.elem_offset,
                        A0.strides[0],
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
    for i, j in T.grid(128, 128):
        with T.block("D"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = D_global[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = B_global[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = C_global[vi, vj]
# Expected output of cache_write on `func_multi_consumer`: the producer
# computes into A_global and a write-back stage refills A inside the tiled
# loop; both consumers keep reading A.
@T.prim_func
def cache_write_multi_consumer() -> None:
    A = T.alloc_buffer((128))
    B = T.alloc_buffer((128))
    C = T.alloc_buffer((128))
    A_global = T.alloc_buffer((128))
    for i in T.grid(8):
        for j in T.grid(16):
            with T.block("A_global"):
                vi = T.axis.S(128, i * 16 + j)
                A_global[vi] = 1.0
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A[vi] = A_global[vi]
        for j in T.grid(16):
            with T.block("B"):
                vi = T.axis.S(128, i * 16 + j)
                B[vi] = A[vi] + 1.0
    for i in T.grid(128):
        with T.block("C"):
            vi = T.axis.S(128, i)
            C[vi] = A[vi]
# Expected output of cache_write on `func_multi_consumer` with consumer_blocks
# = ["B"]: only "B" reads A_global; the write-back into A happens at root so
# "C" still reads A.
@T.prim_func
def cache_write_multi_consumer_B_consume_cache():
    A = T.alloc_buffer([128], dtype="float32")
    B = T.alloc_buffer([128], dtype="float32")
    C = T.alloc_buffer([128], dtype="float32")
    A_global = T.alloc_buffer([128], dtype="float32")
    for i in T.serial(8):
        for j in T.serial(16):
            with T.block("A"):
                vi = T.axis.spatial(128, i * 16 + j)
                A_global[vi] = 1.0
        for j in T.serial(16):
            with T.block("B"):
                vi = T.axis.spatial(128, i * 16 + j)
                B[vi] = A_global[vi] + 1.0
    for ax0 in T.serial(128):
        with T.block("A_global"):
            v0 = T.axis.spatial(128, ax0)
            A[v0] = A_global[v0]
    for i in T.serial(128):
        with T.block("C"):
            vi = T.axis.spatial(128, i)
            C[vi] = A[vi]
# Expected output of cache_write on `func_multi_consumer` with consumer_blocks
# = ["C"]: "C" reads A_global while the in-loop write-back keeps "B" reading A.
@T.prim_func
def cache_write_multi_consumer_C_consume_cache():
    A = T.alloc_buffer([128], dtype="float32")
    B = T.alloc_buffer([128], dtype="float32")
    C = T.alloc_buffer([128], dtype="float32")
    A_global = T.alloc_buffer([128], dtype="float32")
    for i in T.serial(8):
        for j in T.serial(16):
            with T.block("A"):
                vi = T.axis.spatial(128, i * 16 + j)
                A_global[vi] = T.float32(1)
        for ax0 in T.serial(16):
            with T.block("A_global"):
                v0 = T.axis.spatial(128, i * 16 + ax0)
                A[v0] = A_global[v0]
        for j in T.serial(16):
            with T.block("B"):
                vi = T.axis.spatial(128, i * 16 + j)
                B[vi] = A[vi] + T.float32(1)
    for i in T.serial(128):
        with T.block("C"):
            vi = T.axis.spatial(128, i)
            C[vi] = A_global[vi]
# Expected output of cache_write on `func_multi_consumer` when every consumer
# uses the cache: "B" and "C" both read A_global, and A is only refilled by a
# trailing write-back loop.
@T.prim_func
def cache_write_multi_consumer_all_consume_cache():
    A = T.alloc_buffer([128], dtype="float32")
    B = T.alloc_buffer([128], dtype="float32")
    C = T.alloc_buffer([128], dtype="float32")
    A_global = T.alloc_buffer([128], dtype="float32")
    for i in T.serial(8):
        for j in T.serial(16):
            with T.block("A"):
                vi = T.axis.spatial(128, i * 16 + j)
                A_global[vi] = T.float32(1)
        for j in T.serial(16):
            with T.block("B"):
                vi = T.axis.spatial(128, i * 16 + j)
                B[vi] = A_global[vi] + T.float32(1)
    for i in T.serial(128):
        with T.block("C"):
            vi = T.axis.spatial(128, i)
            C[vi] = A_global[vi]
    for ax0 in T.serial(128):
        with T.block("A_global"):
            v0 = T.axis.spatial(128, ax0)
            A[v0] = A_global[v0]
# Expected output of two chained cache_write calls on `elementwise`:
# compute lands in B_local, then flows B_local -> B_shared -> B.
# Note: all three stages intentionally reuse the block name "B".
@T.prim_func
def continuous_cache_write(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    B_shared = T.alloc_buffer((128, 128), scope="shared")
    B_local = T.alloc_buffer((128, 128), scope="local")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_local[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_shared[vi, vj] = B_local[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = B_shared[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
# Expected output of cache_write on the intermediate buffer A of
# `func_with_block_predicate`: the predicated producer writes A_shared and an
# unpredicated full-range loop copies it back into A.
@T.prim_func
def block_predicate_cache_write_intermediate_buf() -> None:
    A = T.alloc_buffer([120], dtype="float32")
    B = T.alloc_buffer([120], dtype="float32")
    A_shared = T.alloc_buffer([120], dtype="float32", scope="shared")
    for i, j in T.grid(16, 8):
        with T.block("producer"):
            ax = T.axis.spatial(120, i * 8 + j)
            T.where(i * 8 + j < 120)
            A_shared[ax] = T.float32(0)
    for ax0 in T.serial(120):
        with T.block("A_shared"):
            v0 = T.axis.spatial(120, ax0)
            A[v0] = A_shared[v0]
    for i, j in T.grid(16, 8):
        with T.block("consumer"):
            ax = T.axis.spatial(120, i * 8 + j)
            T.where(i * 8 + j < 120)
            B[ax] = A[ax] + 1.0
# Expected output of cache_write on the output buffer B of
# `func_with_block_predicate`: the predicated consumer writes B_shared and a
# full-range write-back loop fills B.
@T.prim_func
def block_predicate_cache_write_output_buf() -> None:
    A = T.alloc_buffer([120], dtype="float32")
    B = T.alloc_buffer([120], dtype="float32")
    B_shared = T.alloc_buffer([120], dtype="float32", scope="shared")
    for i, j in T.grid(16, 8):
        with T.block("producer"):
            ax = T.axis.spatial(120, i * 8 + j)
            T.where(i * 8 + j < 120)
            A[ax] = T.float32(0)
    for i, j in T.grid(16, 8):
        with T.block("consumer"):
            ax = T.axis.spatial(120, i * 8 + j)
            T.where(i * 8 + j < 120)
            B_shared[ax] = A[ax] + T.float32(1)
    for ax0 in T.serial(120):
        with T.block("B_shared"):
            v0 = T.axis.spatial(120, ax0)
            B[v0] = B_shared[v0]
# Symbolic-shape matmul workload: shapes are (n+31)//32*32 (n rounded up to a
# multiple of 32), computed as 32x32 output tiles inside the "matmul_o" outer
# block. Input for cache_read under symbolic shapes.
@T.prim_func
def symbolic_matmul_blocked(var_A: T.handle, var_B: T.handle, var_C: T.handle, n: T.int32):
    A = T.match_buffer(var_A, ((n + 31) // 32 * 32, 4))
    B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32))
    C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32))
    for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32):
        with T.block("matmul_o"):
            v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0])
            T.reads(
                A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4],
                B[0:4, v_i1_o * 32 : v_i1_o * 32 + 32],
            )
            T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32])
            for i0_1, i1_1, k in T.grid(32, 32, 4):
                with T.block("matmul"):
                    v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k])
                    T.reads(A[v_i0_o * 32 + v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i])
                    T.writes(C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i])
                    with T.init():
                        C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = T.float32(0)
                    C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = (
                        C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]
                        + A[v_i0_o * 32 + v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i]
                    )
# Expected TIR after test_symbolic_matmul_blocked_cache_read: a 32x4 "shared"
# cache of A is staged inside the outer "matmul_o" scope block.
@T.prim_func
def symbolic_matmul_blocked_cache_read(
    var_A: T.handle, var_B: T.handle, var_C: T.handle, n: T.int32
):
    A = T.match_buffer(var_A, ((n + 31) // 32 * 32, 4))
    B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32))
    C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32))
    for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32):
        with T.block("matmul_o"):
            v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0])
            T.reads(
                A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4],
                B[0:4, v_i1_o * 32 : v_i1_o * 32 + 32],
            )
            T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32])
            # Cache stage: copy the 32x4 tile of A used by this scope instance.
            A_shared = T.alloc_buffer((32, 4), scope="shared")
            for ax0, ax1 in T.grid(32, 4):
                with T.block("A_shared"):
                    v0 = T.axis.spatial(32, ax0)
                    v1 = T.axis.spatial(4, ax1)
                    T.reads(A[v_i0_o * 32 + v0, v1])
                    T.writes(A_shared[v0, v1])
                    A_shared[v0, v1] = A[v_i0_o * 32 + v0, v1]
            for i0_1, i1_1, k in T.grid(32, 32, 4):
                with T.block("matmul"):
                    v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k])
                    T.reads(A_shared[v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i])
                    T.writes(C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i])
                    with T.init():
                        C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = T.float32(0)
                    C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = (
                        C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]
                        + A_shared[v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i]
                    )
# Expected TIR after test_symbolic_matmul_blocked_cache_write: the matmul
# accumulates into a 32x32 "local" cache which is then copied out to C.
@T.prim_func
def symbolic_matmul_blocked_cache_write(
    var_A: T.handle, var_B: T.handle, var_C: T.handle, n: T.int32
):
    A = T.match_buffer(var_A, ((n + 31) // 32 * 32, 4))
    B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32))
    C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32))
    for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32):
        with T.block("matmul_o"):
            v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0])
            T.reads(
                A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4],
                B[0:4, v_i1_o * 32 : v_i1_o * 32 + 32],
            )
            T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32])
            C_pad_local = T.alloc_buffer((32, 32), scope="local")
            for i0_1, i1_1, k in T.grid(32, 32, 4):
                with T.block("matmul"):
                    v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k])
                    T.reads(A[v_i0_o * 32 + v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i])
                    T.writes(C_pad_local[v_i0_i, v_i1_i])
                    with T.init():
                        C_pad_local[v_i0_i, v_i1_i] = T.float32(0)
                    C_pad_local[v_i0_i, v_i1_i] = (
                        C_pad_local[v_i0_i, v_i1_i]
                        + A[v_i0_o * 32 + v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i]
                    )
            # Write-back stage: copy the local tile into the output buffer.
            for ax0, ax1 in T.grid(32, 32):
                with T.block("C_pad_local"):
                    v0 = T.axis.spatial(32, ax0)
                    v1 = T.axis.spatial(32, ax1)
                    T.reads(C_pad_local[v0, v1])
                    T.writes(C[v_i0_o * 32 + v0, v_i1_o * 32 + v1])
                    C[v_i0_o * 32 + v0, v_i1_o * 32 + v1] = C_pad_local[v0, v1]
########## Testcases for cache_read ##########
# Run each parametrized test twice: once addressing blocks by block RV objects
# and once addressing them by their string names.
use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
def test_cache_read_elementwise(use_block_name):
    """Basic cache_read on both read buffers of the two-stage elementwise func."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    block_b = sch.get_block("B")
    block_c = sch.get_block("C")
    if use_block_name:
        # Address both the block and the read buffer by name.
        a_cache = sch.cache_read("B", "A", "global")
        b_cache = sch.cache_read("C", "B", "local")
    else:
        a_cache = sch.cache_read(block_b, 0, "global")
        b_cache = sch.cache_read(block_c, 0, "local")
    # The returned RVs point at the new cache stages; the originals are intact.
    assert sch.get(a_cache) == sch.get(sch.get_block("A_global"))
    assert sch.get(b_cache) == sch.get(sch.get_block("B_local"))
    assert sch.get(block_b) == sch.get(sch.get_block("B"))
    assert sch.get(block_c) == sch.get(sch.get_block("C"))
    assert_structural_equal_ignore_global_symbol(cache_read_elementwise, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=elementwise)
def test_cache_read_under_scope(use_block_name):
    """cache_read on blocks nested under an outer scope block."""
    sch = tir.Schedule(access_under_scope, debug_mask="all")
    if use_block_name:
        block_b, block_c = "B", "C"
    else:
        block_b = sch.get_block("B")
        block_c = sch.get_block("C")
    sch.cache_read(block_b, 0, "local")
    sch.cache_read(block_c, 0, "global")
    assert_structural_equal_ignore_global_symbol(cache_read_under_scope, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=access_under_scope)
def test_cache_read_opaque_access(use_block_name):
    """cache_read on a block whose buffer access is opaque."""
    sch = tir.Schedule(opaque_access, debug_mask="all")
    target = "load_store" if use_block_name else sch.get_block("load_store")
    sch.cache_read(target, 0, "global")
    assert_structural_equal_ignore_global_symbol(cache_read_opaque_access, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=opaque_access)
def test_cache_read_location(use_block_name):
    """Cache-stage placement for a read buffer with multiple consumers."""

    def _block(sch, name):
        # Address a block by name or by RV, depending on the parametrization.
        return name if use_block_name else sch.get_block(name)

    # Default: the cache stage serves every consumer.
    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
    sch.cache_read(_block(sch, "B"), 0, "global")
    assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
    # Targeting one specific consumer block.
    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
    b = _block(sch, "B")
    c = _block(sch, "C")
    sch.cache_read(b, 0, "global", consumer_blocks=[c])
    assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer_target, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
    # Listing every consumer explicitly matches the unspecified-default result.
    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
    b = _block(sch, "B")
    c = _block(sch, "C")
    sch.cache_read(b, 0, "global", consumer_blocks=[b, c])
    assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
def test_continuous_cache_read(use_block_name):
    """Two chained cache_read stages on the same read buffer."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    target = "C" if use_block_name else sch.get_block("C")
    sch.cache_read(target, 0, "shared")
    sch.cache_read(target, 0, "local")
    assert_structural_equal_ignore_global_symbol(continuous_cache_read, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=elementwise)
def test_cache_read_with_block_predicate(use_block_name):
    """cache_read of a consumer guarded by a T.where block predicate."""
    sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
    consumer = "consumer" if use_block_name else sch.get_block("consumer")
    sch.cache_read(consumer, 0, "shared")
    assert_structural_equal_ignore_global_symbol(block_predicate_cache_read, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
def test_cache_read_non_int32_shape(use_block_name):
    """cache_read on buffers whose shape dimensions are int64."""
    sch = tir.Schedule(elementwise_shape_int64, debug_mask="all")
    target = "B" if use_block_name else sch.get_block("B")
    sch.cache_read(target, 0, "global")
    assert_structural_equal_ignore_global_symbol(cache_read_shape_int64, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64)
def test_cache_read_nested_buffer_access(use_block_name):
    """cache_read of the second read buffer in a nested-access block."""
    sch = tir.Schedule(nested_buffer_access, debug_mask="all")
    target = "C" if use_block_name else sch.get_block("C")
    sch.cache_read(target, 1, "global")
    assert_structural_equal_ignore_global_symbol(cache_read_nested_buffer_access, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=nested_buffer_access)
def test_cache_read_fail_multi_producer(use_block_name):
    """cache_read rejects a buffer that has more than one producer."""
    sch = tir.Schedule(func_multi_producer, debug_mask="all")
    target = "B" if use_block_name else sch.get_block("B")
    with pytest.raises(tvm.tir.ScheduleError):
        sch.cache_read(target, 0, "global")
def test_cache_read_fail_index_out_of_bound(use_block_name):
    """A read-buffer index past the block's read list raises ScheduleError."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    target = "B" if use_block_name else sch.get_block("B")
    with pytest.raises(tvm.tir.ScheduleError):
        # Block "B" reads only buffer A, so index 1 is out of range.
        sch.cache_read(target, 1, "global")
def test_cache_read_fail_invalid_storage_scope(use_block_name):
    """An unrecognized storage scope raises ScheduleError."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    target = "B" if use_block_name else sch.get_block("B")
    with pytest.raises(tvm.tir.ScheduleError):
        sch.cache_read(target, 0, "test_scope")
def test_cache_read_allocate_const():
    """cache_read works on buffers backed by T.allocate_const data."""

    @T.prim_func
    def before(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
        B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
        B_buf = T.decl_buffer((8), dtype="float32", data=B)
        for i in range(8):
            with T.block("C"):
                vi = T.axis.spatial(8, i)
                C[vi] = A[vi] + B_buf[vi]

    @T.prim_func
    def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
        B_buf_global = T.alloc_buffer((8), dtype="float32")
        A_global = T.alloc_buffer((8), dtype="float32")
        B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
        B_buf = T.decl_buffer((8), data=B)
        for ax0 in range(8):
            with T.block("A_global"):
                v0 = T.axis.spatial(8, ax0)
                A_global[v0] = A[v0]
        for ax0 in range(8):
            with T.block("B_buf_global"):
                v0 = T.axis.spatial(8, ax0)
                B_buf_global[v0] = B_buf[v0]
        for i in range(8):
            with T.block("C"):
                vi = T.axis.spatial(8, i)
                C[vi] = A_global[vi] + B_buf_global[vi]

    sch = tir.Schedule(before)
    block_c = sch.get_block("C")
    # Cache the const-backed read (index 1) first, then the input read (index 0).
    sch.cache_read(block_c, 1, "global")
    sch.cache_read(block_c, 0, "global")
    after = sch.mod["main"]
    assert_structural_equal_ignore_global_symbol(expected, after)
    verify_trace_roundtrip(sch=sch, mod=before)
def test_inplace_cache_read():
    """cache_read with the consumer itself listed works on an in-place block."""
    # Use `tir.Schedule` like every other test in this file (it is the same
    # class as `tvm.tir.Schedule`; changed purely for consistency).
    sch = tir.Schedule(inplace_func, debug_mask="all")
    block = sch.get_block("copy_in")
    sch.cache_read(block, 0, "local", [block])
    assert_structural_equal_ignore_global_symbol(cache_read_inplace, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=inplace_func)
def test_cache_inplace():
    """cache_inplace on an opaque extern call, then cache the staged buffers."""
    # cache_inplace can introduce WAR dependencies, which is expected here but
    # breaks the stage-pipeline property, so only verify the sref tree.
    debug_mask = tvm.tir.schedule.state.ScheduleDebugMask.VERIFY_SREF_TREE
    sch = tir.Schedule(inplace_call, debug_mask=debug_mask)
    block = sch.get_block("ext_call")
    blocks = sch.cache_inplace(block, 0, "local")
    # The block RVs returned below were previously bound and never used again;
    # only the schedule-state mutation matters.
    sch.cache_read(blocks[0], 0, "global", [blocks[0]])
    sch.cache_write(blocks[1], 0, "global")
    assert_structural_equal_ignore_global_symbol(cache_inplace_buffer, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=inplace_call, debug_mask=debug_mask)
def test_cache_read_nested_seq(use_block_name):
    """cache_read with an explicit consumer list inside nested sequences."""
    sch = tir.Schedule(func_nested_seq, debug_mask="all")
    target = "C" if use_block_name else sch.get_block("C")
    sch.cache_read(target, 0, "global", consumer_blocks=[target])
    assert_structural_equal_ignore_global_symbol(cache_read_nested_seq_target, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=func_nested_seq)
########## Testcases for cache_write ##########
def test_cache_write_elementwise(use_block_name):
    """Basic cache_write on both stages of the elementwise func."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    block_b = sch.get_block("B")
    block_c = sch.get_block("C")
    b_cache = sch.cache_write("B" if use_block_name else block_b, 0, "local")
    c_cache = sch.cache_write("C" if use_block_name else block_c, 0, "global")
    # The returned RVs point at the new cache stages; the originals are intact.
    assert sch.get(b_cache) == sch.get(sch.get_block("B_local"))
    assert sch.get(c_cache) == sch.get(sch.get_block("C_global"))
    assert sch.get(block_b) == sch.get(sch.get_block("B"))
    assert sch.get(block_c) == sch.get(sch.get_block("C"))
    assert_structural_equal_ignore_global_symbol(cache_write_elementwise, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=elementwise)
def test_cache_write_under_scope(use_block_name):
    """cache_write on blocks nested under an outer scope block."""
    sch = tir.Schedule(access_under_scope, debug_mask="all")
    if use_block_name:
        block_a, block_b = "A", "B"
    else:
        block_a = sch.get_block("A")
        block_b = sch.get_block("B")
    block_scope = sch.get_block("scope")
    sch.cache_write(block_a, 0, "local")
    sch.cache_write(block_b, 0, "global")
    sch.cache_write(block_scope, 0, "global")
    assert_structural_equal_ignore_global_symbol(cache_write_under_scope, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=access_under_scope)
def test_cache_write_opaque_access(use_block_name):
    """cache_write on blocks with opaque access and T.match_buffer regions."""
    sch = tir.Schedule(opaque_access, debug_mask="all")
    if use_block_name:
        targets = ["load_store", "opaque", "match_buffer"]
    else:
        targets = [sch.get_block(n) for n in ("load_store", "opaque", "match_buffer")]
    for target in targets:
        sch.cache_write(target, 0, "global")
    assert_structural_equal_ignore_global_symbol(cache_write_opaque_access, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=opaque_access)
def test_cache_write_location(use_block_name):
    """Cache-stage placement for a written buffer with multiple consumers."""

    def _block(sch, name):
        # Address a block by name or by RV, depending on the parametrization.
        return name if use_block_name else sch.get_block(name)

    # Default: every consumer reads the cache buffer.
    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
    sch.cache_write(_block(sch, "A"), 0, "global")
    assert_structural_equal_ignore_global_symbol(cache_write_multi_consumer, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
    # Only B reads the cache buffer; C keeps the original output buffer.
    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
    a = _block(sch, "A")
    b = _block(sch, "B")
    sch.cache_write(a, 0, "global", consumer_blocks=[b])
    assert_structural_equal_ignore_global_symbol(
        cache_write_multi_consumer_B_consume_cache, sch.mod["main"]
    )
    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
    # Only C reads the cache buffer; B keeps the original output buffer.
    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
    a = _block(sch, "A")
    c = _block(sch, "C")
    sch.cache_write(a, 0, "global", consumer_blocks=[c])
    assert_structural_equal_ignore_global_symbol(
        cache_write_multi_consumer_C_consume_cache, sch.mod["main"]
    )
    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
    # Both B and C read the cache buffer.
    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
    a = _block(sch, "A")
    b = _block(sch, "B")
    c = _block(sch, "C")
    sch.cache_write(a, 0, "global", consumer_blocks=[b, c])
    assert_structural_equal_ignore_global_symbol(
        cache_write_multi_consumer_all_consume_cache, sch.mod["main"]
    )
    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
def test_continuous_cache_write(use_block_name):
    """Two chained cache_write stages on the same written buffer."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    target = "B" if use_block_name else sch.get_block("B")
    sch.cache_write(target, 0, "shared")
    sch.cache_write(target, 0, "local")
    assert_structural_equal_ignore_global_symbol(continuous_cache_write, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=elementwise)
def test_cache_write_with_block_predicate(use_block_name):
    """cache_write of blocks guarded by a T.where block predicate."""
    # Case 1: the written buffer is an intermediate allocation.
    sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
    producer = "producer" if use_block_name else sch.get_block("producer")
    sch.cache_write(producer, 0, "shared")
    assert_structural_equal_ignore_global_symbol(
        block_predicate_cache_write_intermediate_buf, sch.mod["main"]
    )
    verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
    # Case 2: the written buffer is an external output.
    sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
    consumer = "consumer" if use_block_name else sch.get_block("consumer")
    sch.cache_write(consumer, 0, "shared")
    assert_structural_equal_ignore_global_symbol(
        block_predicate_cache_write_output_buf, sch.mod["main"]
    )
    verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
def test_cache_write_fail_multi_producer(use_block_name):
    """cache_write rejects a buffer written by more than one producer."""
    sch = tir.Schedule(func_multi_producer, debug_mask="all")
    if use_block_name:
        block_a0, block_a1 = "A0", "A1"
    else:
        block_a0 = sch.get_block("A0")
        block_a1 = sch.get_block("A1")
    with pytest.raises(tvm.tir.ScheduleError):
        sch.cache_write(block_a0, 0, "global")
    with pytest.raises(tvm.tir.ScheduleError):
        sch.cache_write(block_a1, 0, "global")
def test_cache_write_fail_index_out_of_bound(use_block_name):
    """A write-buffer index past the block's write list raises ScheduleError."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    target = "B" if use_block_name else sch.get_block("B")
    with pytest.raises(tvm.tir.ScheduleError):
        # Block "B" writes only buffer B, so index 1 is out of range.
        sch.cache_write(target, 1, "global")
def test_cache_write_fail_invalid_storage_scope(use_block_name):
    """An unrecognized storage scope raises ScheduleError."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    target = "B" if use_block_name else sch.get_block("B")
    with pytest.raises(tvm.tir.ScheduleError):
        sch.cache_write(target, 0, "test_scope")
@pytest.mark.parametrize("use_decl_buffer", [True, False])
def test_cache_write_allocate_const(use_decl_buffer):
    """cache_read/cache_write coexist with T.allocate_const buffers."""

    def apply_decl_buffer(*args, **kwargs):
        # Exercise both spellings of declaring a buffer over the const data.
        if use_decl_buffer:
            return T.decl_buffer(*args, **kwargs)
        else:
            return T.Buffer(*args, **kwargs)

    @T.prim_func
    def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")):
        B = T.alloc_buffer([128, 128], dtype="float32")
        const1 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
        const1_buf = apply_decl_buffer([8], dtype="float32", data=const1)
        const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
        const2_buf = apply_decl_buffer([8], dtype="float32", data=const2)
        for i, j in T.grid(128, 128):
            for x in range(8):
                with T.block("B"):
                    vi, vj, vx = T.axis.remap("SSS", [i, j, x])
                    T.reads(A[vi, vj], const1_buf[vx], const2_buf[vx])
                    T.writes(B[vi, vj])
                    B[vi, vj] = A[vi, vj] * const1_buf[vx] + const2_buf[vx]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(B[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = B[vi, vj] + 1.0

    @T.prim_func
    def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")):
        B = T.alloc_buffer([128, 128], dtype="float32")
        A_global = T.alloc_buffer([128, 128], dtype="float32")
        C_global = T.alloc_buffer([128, 128], dtype="float16")
        const1 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
        const1_buf = apply_decl_buffer([8], dtype="float32", data=const1)
        const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
        const2_buf = apply_decl_buffer([8], dtype="float32", data=const2)
        # Read-cache stage for A inserted by cache_read.
        for ax0, ax1 in T.grid(128, 128):
            with T.block("A_global"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v0, v1])
                T.writes(A_global[v0, v1])
                A_global[v0, v1] = A[v0, v1]
        for i, j, x in T.grid(128, 128, 8):
            with T.block("B"):
                vi, vj, vx = T.axis.remap("SSS", [i, j, x])
                T.reads(A_global[vi, vj], const1_buf[vx], const2_buf[vx])
                T.writes(B[vi, vj])
                B[vi, vj] = A_global[vi, vj] * const1_buf[vx] + const2_buf[vx]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(B[vi, vj])
                T.writes(C_global[vi, vj])
                C_global[vi, vj] = B[vi, vj] + T.float32(1)
        # Write-back stage for C inserted by cache_write.
        for ax0, ax1 in T.grid(128, 128):
            with T.block("C_global"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(C_global[v0, v1])
                T.writes(C[v0, v1])
                C[v0, v1] = C_global[v0, v1]

    sch = tir.Schedule(before)
    block_b = sch.get_block("B")
    block_c = sch.get_block("C")
    sch.cache_read(block_b, 0, "global")
    sch.cache_write(block_c, 0, "global")
    after = sch.mod["main"]
    assert_structural_equal_ignore_global_symbol(expected, after)
    verify_trace_roundtrip(sch=sch, mod=before)
def test_reindex_cache_read():
    """reindex_cache_read with a rank-changing index map."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    index_map = lambda i, j: (j, i // 2, i % 2)
    sch.reindex_cache_read("C", 0, "shared", index_map)
    assert_structural_equal_ignore_global_symbol(elementwise_reindex_cache_read, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=elementwise)
def test_reindex_cache_read_multi_consumer():
    """reindex_cache_read when the cached buffer has several consumers."""
    sch = tir.Schedule(func_multi_consumer)
    index_map = lambda i: (i // 32, i % 32)
    sch.reindex_cache_read("B", 0, "shared", index_map)
    assert_structural_equal_ignore_global_symbol(reindex_cache_read_multi_consumer, sch.mod["main"])
    # NOTE(zihao): we do not verify trace roundtrip because of in set analysis issues.
def test_reindex_cache_read_fail_not_match():
    """An index map that does not match the read pattern raises ScheduleError."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    with pytest.raises(tvm.tir.ScheduleError):
        sch.reindex_cache_read("C", 0, "shared", lambda i, j: j * 2)
def test_reindex_cache_read_failed_not_single_point():
    """reindex_cache_read rejects a block whose read is not a single point."""
    sch = tir.Schedule(access_under_scope, debug_mask="all")
    with pytest.raises(tvm.tir.ScheduleError):
        sch.reindex_cache_read("scope", 0, "shared", lambda i, j: (i, j))
def test_reindex_cache_write():
    """reindex_cache_write with a transposing index map."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    index_map = lambda i, j: (j, i)
    sch.reindex_cache_write("B", 0, "shared", index_map)
    assert_structural_equal_ignore_global_symbol(elementwise_reindex_cache_write, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=elementwise)
def test_reindex_cache_write_reduce():
    """reindex_cache_write applied twice to the blocks of a reduction func."""
    sch = tir.Schedule(reduce, debug_mask="all")
    map_b = lambda i, j, k, l: (j, i, k)
    sch.reindex_cache_write("B", 0, "shared", map_b)
    assert_structural_equal_ignore_global_symbol(reduce_reindex_cache_write_0, sch.mod["main"])
    map_c = lambda i, j, k: [j, i]
    sch.reindex_cache_write("C", 0, "shared", map_c)
    assert_structural_equal_ignore_global_symbol(reduce_reindex_cache_write_1, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=reduce)
def test_reindex_cache_write_fail_not_match():
    """An index map that does not match the write pattern raises ScheduleError."""
    sch = tir.Schedule(elementwise, debug_mask="all")
    with pytest.raises(tvm.tir.ScheduleError):
        sch.reindex_cache_write("B", 0, "shared", lambda i, j: i)
def test_reindex_cache_write_fail_not_single_point():
    """reindex_cache_write rejects a block whose write is not a single point."""
    sch = tir.Schedule(access_under_scope, debug_mask="all")
    with pytest.raises(tvm.tir.ScheduleError):
        sch.reindex_cache_write("scope", 0, "shared", lambda i, j: (i, j))
def test_symbolic_matmul_blocked_cache_read(use_block_name):
    """cache_read inside a blocked scope whose extents depend on symbolic n."""
    sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all")
    block = "matmul" if use_block_name else sch.get_block("matmul")
    sch.cache_read(block=block, read_buffer_index=0, storage_scope="shared")
    # Pass (expected, actual) to follow the convention used by every other
    # test in this file; structural equality itself is symmetric.
    assert_structural_equal_ignore_global_symbol(
        symbolic_matmul_blocked_cache_read, sch.mod["main"]
    )
    verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked)
def test_symbolic_matmul_blocked_cache_write(use_block_name):
    """cache_write inside a blocked scope whose extents depend on symbolic n."""
    sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all")
    block = "matmul" if use_block_name else sch.get_block("matmul")
    sch.cache_write(block=block, write_buffer_index=0, storage_scope="local")
    # Pass (expected, actual) to follow the convention used by every other
    # test in this file; structural equality itself is symmetric.
    assert_structural_equal_ignore_global_symbol(
        symbolic_matmul_blocked_cache_write, sch.mod["main"]
    )
    verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked)
if __name__ == "__main__":
    # Allow running the tests in this file directly (outside of pytest).
    tvm.testing.main()