# blob: f1fe5ece972ca587ff7db94fa4099bc8d0d7325e
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import s_tir
from tvm.script import tir as T
@T.prim_func
def buffer_load_store_func(a: T.handle, b: T.handle) -> None:
    # Fixture for test_buffer_load_store: every buffer is accessed at a
    # different nesting depth.
    #   A: written in the first block, read in the init of the second block.
    #   B: written in the init and updated in both kk loops of the second block.
    #   C: read in both kk loops.
    #   D: read only in the second kk loop.
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    C = T.alloc_buffer((128, 128), "float32")
    D = T.alloc_buffer((128, 128), "float32")
    # First loop nest: store zero into every element of A.
    for ii, jj in T.grid(128, 128):
        with T.sblock():
            i, j = T.axis.remap("SS", [ii, jj])
            A[i, j] = T.float32(0)
    # Second loop nest: i0/j0 are remapped as "S" axes and k0 as an "R" axis.
    for i0, j0, k0 in T.grid(32, 32, 32):
        with T.sblock():
            i, j, k = T.axis.remap("SSR", [i0, j0, k0])
            # Init part: copy a 4x4 tile of A into B.
            with T.init():
                for ii, jj in T.grid(4, 4):
                    B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
            for ii, jj in T.grid(4, 4):
                # First kk loop touches B and C only.
                for kk in range(0, 4):
                    B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
                # Second kk loop is the only place D is accessed.
                for kk in range(0, 4):
                    B[i * 4 + ii, j * 4 + jj] += (
                        D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]
                    )
@T.prim_func
def buffer_opaque_access(b: T.handle, c: T.handle) -> None:
    # Fixture for test_opaque_access: B is touched in the first block only
    # through its data handle, then read element-wise in the second block.
    B = T.match_buffer(b, [16, 16], "float32")
    C = T.match_buffer(c, [16, 16], "float32")
    with T.sblock():
        # Access regions are declared by hand: no reads, all of B written.
        T.reads([])
        T.writes(B[0:16, 0:16])
        # A is declared inside the block rather than passed as a parameter.
        A = T.decl_buffer([256], "float32")
        for i, j in T.grid(16, 16):
            A[i * 16 + j] = 1
        for i in range(0, 16):
            for j in range(0, 16):
                T.evaluate(A[i * 16 + j])
            for j in range(0, 16):
                # B appears here only as B.data passed to the intrinsic call.
                T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, T.float32(0), dtype="handle"))
    # Second loop nest: copy B into C one element per block instance.
    for i, j in T.grid(16, 16):
        with T.sblock():
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj]
@T.prim_func
def lca_is_func_root(a: T.handle) -> None:
    # Fixture for test_lca_func_root: the only access to A is a store written
    # directly in the function body, with no loop or block around it in the
    # script. The test expects the reported LCA for A to be None.
    A = T.match_buffer(a, [0, 0], "float32")
    A[0, 0] = 1.0
@T.prim_func
def match_buffer_func(a: T.handle, b: T.handle) -> None:
    # Fixture for test_match_buffer: A and B are only reached through
    # sub-region buffers created with T.match_buffer inside blocks.
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    for i, j in T.grid(8, 8):
        with T.sblock("block"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16])
            T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            # B0 (4x4) and B1 (4x8) are views onto two regions of B.
            B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4))
            B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8))
            for ii, jj in T.grid(16, 16):
                with T.sblock("AAA"):
                    vii, vjj = T.axis.remap("SS", [ii, jj])
                    # AA is a zero-dimensional view of a single element of A;
                    # this store is the only write to A in the function.
                    AA = T.match_buffer(A[vii, vjj], ())
                    AA[()] = 1.0
            # The B views are used only through their data handles, after the
            # inner loop nest and still inside block "block".
            T.evaluate(B0.data)
            T.evaluate(B1.data)
@T.prim_func
def global_buffer_with_blockidx(
    a: T.Buffer((1, 32), "int32"), b: T.Buffer((1, 32), "int32")
) -> None:
    # Fixture for test_global_buffer_with_blockidx: an element-wise copy from
    # a to b whose two loops are bound to blockIdx.x (extent 1) and
    # threadIdx.x (extent 32).
    for i0 in T.thread_binding(0, 1, thread="blockIdx.x"):
        for i1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.sblock("copy"):
                i, j = T.axis.remap("SS", [i0, i1])
                T.reads(a[i, j])
                T.writes(b[i, j])
                b[i, j] = a[i, j]
def test_buffer_load_store():
    """Check the LCA reported for each buffer of ``buffer_load_store_func``."""
    prim = buffer_load_store_func
    buf_a, buf_b = (prim.buffer_map[param] for param in prim.params)
    root = prim.body.block
    buf_c, buf_d = root.alloc_buffers
    lca_of = s_tir.analysis.detect_buffer_access_lca(prim)

    # A is accessed in both top-level loop nests, so the root block is expected.
    assert lca_of[buf_a] == root

    # B: expected at the j0 loop, which directly encloses the k0 loop that is
    # remapped to the reduction axis.
    j0_loop = root.body[1].body
    assert lca_of[buf_b] == j0_loop

    # C is read in both kk loops, so the enclosing jj loop is expected.
    update_block = j0_loop.body.body.block
    jj_loop = update_block.body.body
    assert lca_of[buf_c] == jj_loop

    # D is read only in the second kk loop.
    second_kk_loop = jj_loop.body[1]
    assert lca_of[buf_d] == second_kk_loop
def test_opaque_access():
    """Check the LCA reported for B and C of ``buffer_opaque_access``."""
    prim = buffer_opaque_access
    buf_b, buf_c = (prim.buffer_map[param] for param in prim.params)
    lca_of = s_tir.analysis.detect_buffer_access_lca(prim)
    root = prim.body.block

    # Buffer A is declared inside the first block with T.decl_buffer and is not
    # a function parameter, so it is not looked up here.

    # B is written in the first block and read in the second one, so the root
    # block is expected.
    assert lca_of[buf_b] == root

    # C is written only by the block nested in the second loop nest.
    copy_block = root.body[1].body.body.block
    assert lca_of[buf_c] == copy_block
def test_lca_func_root():
    """Check that the LCA of A in ``lca_is_func_root`` is reported as None."""
    prim = lca_is_func_root
    (buf_a,) = (prim.buffer_map[param] for param in prim.params)
    lca_of = s_tir.analysis.detect_buffer_access_lca(prim)
    assert lca_of[buf_a] is None
def test_match_buffer():
    """Check the LCA reported for A and B of ``match_buffer_func``."""
    prim = match_buffer_func
    buf_a, buf_b = (prim.buffer_map[param] for param in prim.params)
    lca_of = s_tir.analysis.detect_buffer_access_lca(prim)

    # Descend from the root block to block "block", then to block "AAA".
    root = prim.body.block
    outer_block = root.body.body.body.block
    inner_block = outer_block.body[0].body.body.block

    # A is written only through the AA view inside the inner block.
    assert lca_of[buf_a] == inner_block
    # B is referenced only by the outer block (its reads region and the
    # B0/B1 views matched there).
    assert lca_of[buf_b] == outer_block
def test_global_buffer_with_blockidx():
    """Check that both buffers of ``global_buffer_with_blockidx`` share one LCA."""
    prim = global_buffer_with_blockidx
    buf_a, buf_b = (prim.buffer_map[param] for param in prim.params)
    lca_of = s_tir.analysis.detect_buffer_access_lca(prim)

    # The outermost loop under the root block is the one bound to blockIdx.x;
    # both the source and destination buffers are expected to resolve to it.
    block_idx_loop = prim.body.block.body
    for buf in (buf_a, buf_b):
        assert lca_of[buf] == block_idx_loop
if __name__ == "__main__":
    # Run every test in declaration order when executed as a script.
    for run_case in (
        test_buffer_load_store,
        test_opaque_access,
        test_lca_func_root,
        test_match_buffer,
        test_global_buffer_with_blockidx,
    ):
        run_case()