# blob: b3e72943bf6ebd3c17293442274d6d21486aa326 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import pytest
import tvm
import tvm.testing
from tvm import tir
from tvm.script import tir as T
from tvm.tir.schedule.testing import (
verify_trace_roundtrip,
assert_structural_equal_ignore_global_symbol,
)
# pylint: disable=no-member,invalid-name,unused-variable
# Input workload scheduled by test_merge and test_merge2: block "B" fills the
# intermediate buffer B from A, then blocks "C" and "D" each read B inside
# their own four-deep tiled loop nest.
@T.prim_func
def elementwise(a: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    D = T.match_buffer(d, (64, 64))
    B = T.alloc_buffer((128, 128))
    # Block "B": B = A * 2 over the full 128x128 domain.
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    # Block "C": 128x128 domain written as an 8x8 grid of 16x16 tiles.
    for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
        with T.block("C"):
            vi = T.axis.spatial(128, i_0 * 16 + i_1)
            vj = T.axis.spatial(128, j_0 * 16 + j_1)
            T.reads(B[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + T.float32(1)
    # Block "D": 64x64 domain written as an 8x8 grid of 8x8 tiles; its two
    # outer loops have the same extents (8, 8) as the outer loops of block "C".
    for i_0, j_0, i_1, j_1 in T.grid(8, 8, 8, 8):
        with T.block("D"):
            vi = T.axis.spatial(64, i_0 * 8 + i_1)
            vj = T.axis.spatial(64, j_0 * 8 + j_1)
            T.reads(B[vi, vj])
            T.writes(D[vi, vj])
            D[vi, vj] = B[vi, vj] + T.float32(2)
# Expected form of `elementwise` asserted by test_merge: the loop nests of
# blocks "C" and "D" share a single outer loop `i_0_m`, and each block keeps
# its own remaining three loops underneath it.
@T.prim_func
def elementwise_merged(a: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    D = T.match_buffer(d, (64, 64))
    B = T.alloc_buffer((128, 128))
    # Block "B" is identical to the one in `elementwise`.
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    # Shared outer loop of extent 8, used by both bodies below.
    for i_0_m in range(8):
        # Remaining loops of block "C" (j_0, i_1, j_1).
        for j_0, i_1, j_1 in T.grid(8, 16, 16):
            with T.block("C"):
                vi = T.axis.spatial(128, i_0_m * 16 + i_1)
                vj = T.axis.spatial(128, j_0 * 16 + j_1)
                T.reads(B[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = B[vi, vj] + T.float32(1)
        # Remaining loops of block "D" (j_0, i_1, j_1).
        for j_0, i_1, j_1 in T.grid(8, 8, 8):
            with T.block("D"):
                vi = T.axis.spatial(64, i_0_m * 8 + i_1)
                vj = T.axis.spatial(64, j_0 * 8 + j_1)
                T.reads(B[vi, vj])
                T.writes(D[vi, vj])
                D[vi, vj] = B[vi, vj] + T.float32(2)
# Expected form of `elementwise` asserted by test_merge2: the loop nests of
# blocks "C" and "D" share the two outer loops `i_0_m` and `j_0_m`, and each
# block keeps its own two inner loops underneath them.
@T.prim_func
def elementwise_merged2(a: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    D = T.match_buffer(d, (64, 64))
    B = T.alloc_buffer((128, 128))
    # Block "B" is identical to the one in `elementwise`.
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    # Shared 8x8 outer loops, used by both bodies below.
    for i_0_m, j_0_m in T.grid(8, 8):
        # Inner 16x16 tile of block "C".
        for i_1, j_1 in T.grid(16, 16):
            with T.block("C"):
                vi = T.axis.spatial(128, i_0_m * 16 + i_1)
                vj = T.axis.spatial(128, j_0_m * 16 + j_1)
                T.reads(B[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = B[vi, vj] + T.float32(1)
        # Inner 8x8 tile of block "D".
        for i_1, j_1 in T.grid(8, 8):
            with T.block("D"):
                vi = T.axis.spatial(64, i_0_m * 8 + i_1)
                vj = T.axis.spatial(64, j_0_m * 8 + j_1)
                T.reads(B[vi, vj])
                T.writes(D[vi, vj])
                D[vi, vj] = B[vi, vj] + T.float32(2)
def test_merge():
    """Merge loop 0 of block "C" with loop 0 of block "D" in `elementwise`.

    The scheduled module must be structurally equal to `elementwise_merged`
    (global symbol ignored), and the recorded trace must survive a round trip
    when replayed on the original `elementwise`.
    """
    schedule = tir.Schedule(elementwise, debug_mask="all")
    c_block = schedule.get_block("C")
    d_block = schedule.get_block("D")
    # Index 0 selects the first entry returned by get_loops for each block.
    c_loop = schedule.get_loops(c_block)[0]
    d_loop = schedule.get_loops(d_block)[0]
    schedule.merge(c_loop, d_loop)
    assert_structural_equal_ignore_global_symbol(elementwise_merged, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=elementwise)
def test_merge2():
    """Merge loop 1 of block "C" with loop 1 of block "D" in `elementwise`.

    The scheduled module must be structurally equal to `elementwise_merged2`
    (global symbol ignored), and the recorded trace must survive a round trip
    when replayed on the original `elementwise`.
    """
    schedule = tir.Schedule(elementwise, debug_mask="all")
    c_block = schedule.get_block("C")
    d_block = schedule.get_block("D")
    # Index 1 selects the second entry returned by get_loops for each block.
    c_loop = schedule.get_loops(c_block)[1]
    d_loop = schedule.get_loops(d_block)[1]
    schedule.merge(c_loop, d_loop)
    assert_structural_equal_ignore_global_symbol(elementwise_merged2, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=elementwise)
def test_merge_fail_not_only_child():
    """`merge` must raise ScheduleError for the third loops of "B" and "C".

    In the workload below, the `k` loop around block "B" has a sibling: the
    `k` loop around block "D" lives under the same `i, j` nest.
    """

    # TVMScript fixture; annotated with `#` comments only so that nothing is
    # added to the parsed function body.
    @T.prim_func
    def elementwise_with_seq(a: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (128, 128, 128))
        C = T.match_buffer(c, (128, 128, 128))
        B = T.alloc_buffer((128, 128, 128))
        D = T.alloc_buffer((128, 128, 128))
        # One i, j nest holding two sibling k loops (blocks "D" and "B").
        for i, j in T.grid(128, 128):
            for k in T.serial(0, 128):
                with T.block("D"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    D[vi, vj, vk] = A[vi, vj, vk] * 2.0
            for k in T.serial(0, 128):
                with T.block("B"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0
        # Separate nest for block "C", which reads B.
        for i, j in T.grid(128, 128):
            for k in T.serial(0, 128):
                with T.block("C"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    C[vi, vj, vk] = B[vi, vj, vk] * 2.0

    schedule = tir.Schedule(elementwise_with_seq, debug_mask="all")
    blk_b = schedule.get_block("B")
    loop_of_b = schedule.get_loops(blk_b)[2]
    blk_c = schedule.get_block("C")
    loop_of_c = schedule.get_loops(blk_c)[2]
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.merge(loop_of_b, loop_of_c)
def test_merge_fail_not_start_with_zero():
    """`merge` must raise ScheduleError for the third loops of "B" and "C".

    In the workload below, the `k` loop around block "B" is
    `T.serial(1, 128)`, whereas the one around block "C" is `T.serial(0, 128)`.
    """

    # TVMScript fixture; annotated with `#` comments only so that nothing is
    # added to the parsed function body.
    @T.prim_func
    def elementwise_loops_not_start_with_zero(a: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (128, 128, 128))
        C = T.match_buffer(c, (128, 128, 128))
        B = T.alloc_buffer((128, 128, 128))
        for i, j in T.grid(128, 128):
            # First argument of T.serial is 1 here, not 0.
            for k in T.serial(1, 128):
                with T.block("B"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0
        for i, j in T.grid(128, 128):
            for k in T.serial(0, 128):
                with T.block("C"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    C[vi, vj, vk] = A[vi, vj, vk] * 2.0

    schedule = tir.Schedule(elementwise_loops_not_start_with_zero, debug_mask="all")
    blk_b = schedule.get_block("B")
    loop_of_b = schedule.get_loops(blk_b)[2]
    blk_c = schedule.get_block("C")
    loop_of_c = schedule.get_loops(blk_c)[2]
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.merge(loop_of_b, loop_of_c)
def test_merge_fail_not_same_extent():
    """`merge` must raise ScheduleError for the third loops of "B" and "C".

    In the workload below, block "B" sits under `T.grid(64, 128)` while block
    "C" sits under `T.grid(128, 128)`, so the two nests differ in the extent
    of their first loop.
    """

    # TVMScript fixture; annotated with `#` comments only so that nothing is
    # added to the parsed function body.
    @T.prim_func
    def elementwise_loops_not_same_extent(a: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (128, 128, 128))
        C = T.match_buffer(c, (128, 128, 128))
        B = T.alloc_buffer((64, 128, 128))
        # First loop has extent 64 for block "B" ...
        for i, j in T.grid(64, 128):
            for k in T.serial(0, 128):
                with T.block("B"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0
        # ... and extent 128 for block "C".
        for i, j in T.grid(128, 128):
            for k in T.serial(0, 128):
                with T.block("C"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    C[vi, vj, vk] = A[vi, vj, vk] * 2.0

    schedule = tir.Schedule(elementwise_loops_not_same_extent, debug_mask="all")
    blk_b = schedule.get_block("B")
    loop_of_b = schedule.get_loops(blk_b)[2]
    blk_c = schedule.get_block("C")
    loop_of_c = schedule.get_loops(blk_c)[2]
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.merge(loop_of_b, loop_of_c)
def test_merge_fail_not_same_level():
    """`merge` must raise ScheduleError for loops taken from different depths.

    Both blocks in the workload below have identical three-loop nests; the
    test picks the second loop of block "B" and the third loop of block "C".
    """

    # TVMScript fixture; annotated with `#` comments only so that nothing is
    # added to the parsed function body.
    @T.prim_func
    def elementwise_not_same_level(a: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (128, 128, 128))
        C = T.match_buffer(c, (128, 128, 128))
        B = T.alloc_buffer((128, 128, 128))
        for i, j in T.grid(128, 128):
            for k in T.serial(0, 128):
                with T.block("B"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0
        for i, j in T.grid(128, 128):
            for k in T.serial(0, 128):
                with T.block("C"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    C[vi, vj, vk] = A[vi, vj, vk] * 2.0

    schedule = tir.Schedule(elementwise_not_same_level, debug_mask="all")
    blk_b = schedule.get_block("B")
    # Index 1 for "B" versus index 2 for "C": deliberately mismatched depths.
    loop_of_b = schedule.get_loops(blk_b)[1]
    blk_c = schedule.get_block("C")
    loop_of_c = schedule.get_loops(blk_c)[2]
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.merge(loop_of_b, loop_of_c)
def test_merge_fail_with_different_scope():
    """`merge` must raise ScheduleError for the third loops of "B" and "C".

    In the workload below, the loop nest of block "B" is wrapped in an extra
    block "A", while the loop nest of block "C" is placed directly in the
    function body.
    """

    # TVMScript fixture; annotated with `#` comments only so that nothing is
    # added to the parsed function body.
    @T.prim_func
    def elementwise_with_different_scope(a: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (128, 128, 128))
        C = T.match_buffer(c, (128, 128, 128))
        B = T.alloc_buffer((128, 128, 128))
        # Enclosing block "A" wraps only the loops of block "B".
        with T.block("A"):
            for i, j in T.grid(128, 128):
                for k in T.serial(0, 128):
                    with T.block("B"):
                        vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                        B[vi, vj, vk] = A[vi, vj, vk] * 2.0
        for i, j in T.grid(128, 128):
            for k in T.serial(0, 128):
                with T.block("C"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    C[vi, vj, vk] = A[vi, vj, vk] * 2.0

    schedule = tir.Schedule(elementwise_with_different_scope, debug_mask="all")
    blk_b = schedule.get_block("B")
    loop_of_b = schedule.get_loops(blk_b)[2]
    blk_c = schedule.get_block("C")
    loop_of_c = schedule.get_loops(blk_c)[2]
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.merge(loop_of_b, loop_of_c)
# When this file is executed directly as a script, hand control to
# tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()