# Source blob: 4ed3c6178fce0c7f5394c5ba143a630a17866880
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys
import pytest
import tvm
import tvm.testing
from tvm import tir
from tvm.script import tir as T
from tvm.tir.schedule.testing import (
verify_trace_roundtrip,
assert_structural_equal_ignore_global_symbol,
)
# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
# Input workload for test_reduction_decompose1.
# Sums A[32, 4, 128] over its last axis into B[32, 4]. The 128-long reduction is
# expressed as a 16 x 8 split: the outer block "blockized_B" carries the spatial
# axis `io` and the outer reduction axis `ko`, and wraps the inner "B_init" and
# "B" blocks.
@T.prim_func
def rowsum_blockized(a: T.handle, b: T.handle) -> None:
    B = T.match_buffer(b, [32, 4])
    A = T.match_buffer(a, [32, 4, 128])
    for i0, i2_0 in T.grid(32, 16):
        with T.block("blockized_B"):
            io, ko = T.axis.remap("SR", [i0, i2_0])
            # Init of the outer block: zero the four outputs of row `io`.
            with T.init():
                for i1 in T.serial(0, 4):
                    with T.block("B_init"):
                        ii_init = T.axis.S(4, i1)
                        B[io, ii_init] = 0.0
            # Update: accumulate the 8 elements belonging to reduction chunk `ko`.
            for i1_1, i2_1 in T.grid(4, 8):
                with T.block("B"):
                    ii = T.axis.S(4, i1_1)
                    k = T.axis.R(128, ko * 8 + i2_1)
                    B[io, ii] = B[io, ii] + A[io, ii, k]
# Baseline 128x128x128 reduction used as the input of several tests below.
# Computes C[vi, vj] = sum over vk of A[vi, vk] * B[vj, vk]; note that B is
# indexed as [vj, vk]. The T.init() clause zeroes C inside the single
# "update" block.
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Expected IR for test_reduction_decompose0 (`matmul` decomposed at loop `i`):
# the init statement lives in its own i/j loop nest ("init"), followed by a
# separate i/j/k loop nest holding the accumulation ("update").
@T.prim_func
def matmul_decompose0(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i, j in T.grid(128, 128):
        with T.block("init"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = 0.0
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Expected IR for test_reduction_decompose1 (`rowsum_blockized` decomposed at
# its outermost loop). Despite the "matmul" prefix in the name, this is the
# row-sum workload: "blockized_B_init" zeroes B in its own i0 loop, and
# "blockized_B_update" performs the accumulation in a separate i0/i2_o nest.
@T.prim_func
def matmul_decompose1(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [32, 4, 128], elem_offset=0, align=64, offset_factor=1)
    B = T.match_buffer(b, [32, 4], elem_offset=0, align=64, offset_factor=1)
    for i0 in T.serial(0, 32):
        with T.block("blockized_B_init"):
            io = T.axis.S(32, i0)
            for i1 in T.serial(0, 4):
                with T.block("B_init"):
                    ii = T.axis.S(4, i1)
                    B[io, ii] = T.float32(0)
    for i0, i2_o in T.grid(32, 16):
        with T.block("blockized_B_update"):
            io, ko = T.axis.remap("SR", [i0, i2_o])
            for i1, i2_i in T.grid(4, 8):
                with T.block("B"):
                    ii = T.axis.S(4, i1)
                    k = T.axis.R(128, ko * 8 + i2_i)
                    B[io, ii] = B[io, ii] + A[io, ii, k]
# Expected IR for test_reduction_decompose2 (`matmul` decomposed at loop `k`):
# the init block stays inside the shared i0/i1 loops and is placed right
# before the i2 reduction loop that holds the update block.
@T.prim_func
def matmul_decompose2(a: T.handle, b: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=64, offset_factor=1)
    B = T.match_buffer(b, [128, 128], elem_offset=0, align=64, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=64, offset_factor=1)
    for i0, i1 in T.grid(128, 128):
        with T.block("update_init"):
            vi_init, vj_init = T.axis.remap("SS", [i0, i1])
            C[vi_init, vj_init] = T.float32(0)
        for i2 in T.serial(0, 128):
            with T.block("update_update"):
                vi, vj, vk = T.axis.remap("SSR", [i0, i1, i2])
                C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
# Input for the negative test test_reduction_decompose3.
# Same computation as `matmul`, but the loops are ordered i, k, j: the
# reduction loop `k` sits above the spatial loop `j`.
@T.prim_func
def matmul_decompose_fail3(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i, k, j in T.grid(128, 128, 128):
        with T.block("update"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Expected IR for test_reduction_decompose4: `matmul` with `i` split into
# 16 x 8 and `k` split into 19 x 7, then decomposed at the inner `i` loop.
# Both the init nest and the update nest sit under the shared outer loop i0_0.
@T.prim_func
def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=64, offset_factor=1)
    B = T.match_buffer(b, [128, 128], elem_offset=0, align=64, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=64, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        for i0_0 in T.serial(0, 16):
            for i0_1_init, i1_init in T.grid(8, 128):
                with T.block("update_init"):
                    vi_init = T.axis.S(128, i0_0 * 8 + i0_1_init)
                    vj_init = T.axis.S(128, i1_init)
                    C[vi_init, vj_init] = T.float32(0)
            for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7):
                with T.block("update_update"):
                    # 19 * 7 = 133 exceeds the reduction extent of 128, so the
                    # out-of-range iterations are masked by this predicate.
                    T.where((((i2_0 * 7) + i2_1) < 128))
                    vi = T.axis.S(128, i0_0 * 8 + i0_1)
                    vj = T.axis.S(128, i1)
                    vk = T.axis.R(128, i2_0 * 7 + i2_1)
                    C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
# Input for test_reduction_decompose_with_annotation: identical to `matmul`
# except that the "update" block carries the attribute {"test_annotation": 1}.
@T.prim_func
def matmul_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            T.block_attr({"test_annotation": 1})
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Expected IR for test_reduction_decompose_with_annotation: after decomposing
# at loop `i`, both the "init" block and the "update" block carry the
# {"test_annotation": 1} attribute of the original block.
@T.prim_func
def matmul_decompose_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i, j in T.grid(128, 128):
        with T.block("init"):
            T.block_attr({"test_annotation": 1})
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = 0.0
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            T.block_attr({"test_annotation": 1})
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Input for test_reduction_decompose_with_different_for_kind.
# Column sum B[vi] = sum over vk of A[vk, vi]; the outer reduction loop `k` is
# serial while the inner spatial loop `i` is vectorized.
@T.prim_func
def colsum_with_vectorization(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 32], dtype="float32")
    B = T.match_buffer(b, [32], dtype="float32")
    for k in T.serial(0, 128):
        for i in T.vectorized(0, 32):
            with T.block("B"):
                vk, vi = T.axis.remap("RS", [k, i])
                with T.init():
                    B[vi] = T.float32(0)
                B[vi] = B[vi] + A[vk, vi]
# Expected IR for test_reduction_decompose_with_different_for_kind: the init
# block gets its own loop over `i`, and that loop is vectorized like the
# spatial loop of the update nest.
@T.prim_func
def colsum_decompose_with_vectorization(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 32], dtype="float32")
    B = T.match_buffer(b, [32], dtype="float32")
    for i in T.vectorized(0, 32):
        with T.block("B_init"):
            vi = T.axis.S(32, i)
            B[vi] = T.float32(0)
    for k in T.serial(0, 128):
        for i in T.vectorized(0, 32):
            with T.block("B"):
                vk, vi = T.axis.remap("RS", [k, i])
                B[vi] = B[vi] + A[vk, vi]
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
# Test parameter with two named values: "block_obj" (False) and "block_name" (True).
# Tests that accept `use_block_name` use it to decide whether a block is passed to
# the schedule as its name string or as the object returned by `get_block`.
use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
def test_reduction_decompose0(use_block_name):
    """Decompose the `matmul` reduction at its outermost loop.

    The block is referenced either by name or by object, depending on
    `use_block_name`. The result must match `matmul_decompose0`, and the
    recorded trace must replay on the original `matmul`.
    """
    sch = tir.Schedule(matmul, debug_mask="all")
    if use_block_name:
        block = "update"
    else:
        block = sch.get_block("update")
    # Only the outermost loop is needed; the unpacking still expects three loops.
    outer_loop, _, _ = sch.get_loops(block)
    sch.decompose_reduction(block, outer_loop)
    assert_structural_equal_ignore_global_symbol(matmul_decompose0, sch.mod["main"])
    verify_trace_roundtrip(sch, mod=matmul)
def test_reduction_decompose1(use_block_name):
    """Decompose the outer "blockized_B" block of `rowsum_blockized`.

    The decomposition point is the outermost loop. The block is referenced by
    name or by object depending on `use_block_name`, and the result must match
    `matmul_decompose1`.
    """
    sch = tir.Schedule(rowsum_blockized, debug_mask="all")
    if use_block_name:
        outer_block = "blockized_B"
    else:
        outer_block = sch.get_block("blockized_B")
    # The outer block is surrounded by exactly two loops; only the first is used.
    spatial_loop, _ = sch.get_loops(outer_block)
    sch.decompose_reduction(outer_block, spatial_loop)
    assert_structural_equal_ignore_global_symbol(matmul_decompose1, sch.mod["main"])
    verify_trace_roundtrip(sch, mod=rowsum_blockized)
def test_reduction_decompose2():
    """Decompose the `matmul` reduction at its innermost (reduction) loop.

    The result must match `matmul_decompose2`, where the init block stays
    inside the two spatial loops.
    """
    sch = tir.Schedule(matmul, debug_mask="all")
    update_block = sch.get_block("update")
    _, _, reduce_loop = sch.get_loops(update_block)
    sch.decompose_reduction(update_block, reduce_loop)
    assert_structural_equal_ignore_global_symbol(matmul_decompose2, sch.mod["main"])
    verify_trace_roundtrip(sch, mod=matmul)
def test_reduction_decompose3():
    """Decomposing at the innermost loop of `matmul_decompose_fail3` must fail.

    The loops of that function are ordered i, k, j, so the innermost loop
    returned by `get_loops` is the spatial `j` loop nested under the
    reduction loop `k`. Requesting decomposition there is expected to raise
    `tvm.tir.ScheduleError`.
    """
    sch = tir.Schedule(matmul_decompose_fail3, debug_mask="all")
    update_block = sch.get_block("update")
    # Loops are named by position because the source order is (i, k, j).
    _, _, innermost_loop = sch.get_loops(update_block)
    with pytest.raises(tvm.tir.ScheduleError):
        sch.decompose_reduction(update_block, innermost_loop)
def test_reduction_decompose4():
    """Decompose `matmul` after splitting both the `i` and the `k` loops.

    `i` is split by factors [16, 8] and `k` by [19, 7]; the decomposition
    point is the inner part of `i`. The result must match
    `matmul_decompose4`.
    """
    sch = tir.Schedule(matmul, debug_mask="all")
    update_block = sch.get_block("update")
    loop_i, _, loop_k = sch.get_loops(update_block)
    _, i_inner = sch.split(loop_i, factors=[16, 8])
    # The split halves of `k` are not used afterwards, but the split itself
    # must still be applied (and recorded in the trace).
    _k_outer, _k_inner = sch.split(loop_k, factors=[19, 7])
    sch.decompose_reduction(update_block, i_inner)
    assert_structural_equal_ignore_global_symbol(matmul_decompose4, sch.mod["main"])
    verify_trace_roundtrip(sch, mod=matmul)
def test_reduction_decompose_with_annotation():
    """Decompose an annotated reduction block at its outermost loop.

    The result must match `matmul_decompose_with_annotation`, in which both
    resulting blocks carry the original block attribute.
    """
    sch = tir.Schedule(matmul_with_annotation, debug_mask="all")
    annotated_block = sch.get_block("update")
    outer_loop, _, _ = sch.get_loops(annotated_block)
    sch.decompose_reduction(annotated_block, outer_loop)
    assert_structural_equal_ignore_global_symbol(
        matmul_decompose_with_annotation, sch.mod["main"]
    )
    verify_trace_roundtrip(sch, mod=matmul_with_annotation)
def test_reduction_decompose_with_different_for_kind():
    """Decompose a reduction whose spatial loop is vectorized.

    Besides the structural comparison against
    `colsum_decompose_with_vectorization`, this checks that the original
    block handle resolves to the block found under the name "B_update", and
    that the handle returned by `decompose_reduction` resolves to the block
    found under the name "B_init".
    """
    sch = tir.Schedule(colsum_with_vectorization, debug_mask="all")
    update_block = sch.get_block("B")
    reduce_loop, _ = sch.get_loops(update_block)
    init_block = sch.decompose_reduction(update_block, reduce_loop)
    assert_structural_equal_ignore_global_symbol(
        sch.mod["main"], colsum_decompose_with_vectorization
    )
    # Look the blocks up in the same order as before ("B_update", then "B_init").
    found_update = sch.get(sch.get_block("B_update"))
    assert sch.get(update_block).same_as(found_update)
    found_init = sch.get(sch.get_block("B_init"))
    assert sch.get(init_block).same_as(found_init)
    verify_trace_roundtrip(sch, mod=colsum_with_vectorization)
def test_decompose_reduction_ref_hash_check():
    """Scheduling must not change the structural hash of the source module.

    A module is built from `matmul`, its structural hash is recorded, a
    schedule created from `mod["main"]` is decomposed at the reduction loop,
    and the hash of the module is then compared with the recorded value.
    """
    mod = tvm.IRModule.from_expr(matmul.with_attr("global_symbol", "main"))
    held_ref = mod  # second name for the same module object
    expected_hash = tvm.ir.structural_hash(held_ref)

    sch = tir.Schedule(mod["main"], debug_mask="all")
    update_block = sch.get_block("update")
    _, _, reduce_loop = sch.get_loops(update_block)
    sch.decompose_reduction(update_block, reduce_loop)

    actual_hash = tvm.ir.structural_hash(held_ref)
    assert expected_hash == actual_hash
def test_decompose_reduction_nested_block():
    """Decompose a reduction block that itself contains nested blocks.

    The "outer" block of `nested_block` allocates a buffer and holds two inner
    blocks; it is decomposed at its reduction loop `ko`, and the result is
    compared with `decomposed_nested_block`.
    """

    # Input: the outer block reduces over `vko`; for each `vko` it copies a
    # 32-element slice of A into the block-local buffer C ("inner_1") and then
    # accumulates C into B[vi] ("inner_2").
    @T.prim_func
    def nested_block(A: T.Buffer((1, 64), "float32"), B: T.Buffer((1,), "float32")):
        for i, ko in T.grid(1, 2):
            with T.block("outer"):
                vi, vko = T.axis.remap("SR", [i, ko])
                C = T.alloc_buffer((32,), dtype="float32")
                with T.init():
                    B[vi] = T.float32(0)
                for ki in T.serial(32):
                    with T.block("inner_1"):
                        vki = T.axis.remap("S", [ki])
                        C[vki] = A[vi, vko * 32 + vki]
                for ki in T.serial(32):
                    with T.block("inner_2"):
                        vki = T.axis.remap("R", [ki])
                        B[vi] += C[vki]

    # Expected: "outer_init" sits inside loop `i` ahead of the `ko` loop, and
    # "outer_update" keeps the buffer allocation and both inner blocks. Read
    # and write regions are spelled out explicitly here.
    @T.prim_func
    def decomposed_nested_block(A: T.Buffer((1, 64), "float32"), B: T.Buffer((1,), "float32")):
        for i in range(1):
            with T.block("outer_init"):
                vi = T.axis.spatial(1, i)
                T.reads()
                T.writes(B[vi])
                B[vi] = T.float32(0)
            for ko in range(2):
                with T.block("outer_update"):
                    vi, vko = T.axis.remap("SR", [i, ko])
                    T.reads(B[vi], A[vi, vko * 32 : vko * 32 + 32])
                    T.writes(B[vi])
                    C = T.alloc_buffer((32,))
                    for ki in range(32):
                        with T.block("inner_1"):
                            vki = T.axis.spatial(32, ki)
                            T.reads(A[vi, vko * 32 + vki])
                            T.writes(C[vki])
                            C[vki] = A[vi, vko * 32 + vki]
                    for ki in range(32):
                        with T.block("inner_2"):
                            vki = T.axis.reduce(32, ki)
                            T.reads(B[vi], C[vki])
                            T.writes(B[vi])
                            B[vi] = B[vi] + C[vki]

    sch = tir.Schedule(nested_block, debug_mask="all")
    outer = sch.get_block("outer")
    i, ko = sch.get_loops(outer)
    # Decompose at the reduction loop of the outer block.
    sch.decompose_reduction(outer, ko)
    assert_structural_equal_ignore_global_symbol(decomposed_nested_block, sch.mod["main"])
    verify_trace_roundtrip(sch, mod=nested_block)
class TestDecomposeReductionWithThreadBinding(tvm.testing.CompareBeforeAfter):
    """Decompose a reduction whose outer loop is bound to `threadIdx.x`.

    `transform` is applied to `before` and the result is compared with
    `expected`.
    """

    def transform(self):
        """Return a function that decomposes block "B" at its outermost loop."""

        def func(mod):
            sch = tir.Schedule(mod)
            # Block "B" has two loops; `t` is the outer, thread-bound one.
            t, _ = sch.get_loops("B")
            sch.decompose_reduction("B", t)
            return sch.mod

        return func

    # Input: row sum of A[32, 16] into B[32], one row per threadIdx.x.
    @T.prim_func
    def before(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")):
        for t in T.thread_binding(0, 32, thread="threadIdx.x"):
            for r in T.serial(16):
                with T.block("B"):
                    vi, vr = T.axis.remap("SR", [t, r])
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] += A[vi, vr]

    # Expected: the init block gets its own loop, bound to the same thread
    # axis ("threadIdx.x") as the loop of the update block.
    @T.prim_func
    def expected(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")):
        for t_init in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("B_init"):
                vi = T.axis.remap("S", [t_init])
                B[vi] = T.float32(0)
        for t in T.thread_binding(0, 32, thread="threadIdx.x"):
            for r in T.serial(16):
                with T.block("B"):
                    vi, vr = T.axis.remap("SR", [t, r])
                    B[vi] += A[vi, vr]
# When executed directly as a script, hand control to TVM's test entry point.
if __name__ == "__main__":
    tvm.testing.main()