blob: 0ad05ea8328892aa87390c1d4459a11ab0a1dfe0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys
import pytest
import tvm
import tvm.testing
from tvm import tir
from tvm.ir import IRModule
from tvm.script import tir as T
from tvm.tir.schedule.testing import (
assert_structural_equal_ignore_global_symbol,
verify_trace_roundtrip,
)
# pylint: disable=no-member,invalid-name,unused-variable
# 128x128 matmul computing C = A @ B^T (note B is indexed as B[vj, vk]), with the
# zero-initialisation expressed as a separate "init" block inside the (i, j) loops and the
# accumulation as an "update" block under the inner k loop.
# NOTE: the tests below depend on this exact structure (loop variable names i/j/k, the
# "init"/"update" block names, and the AST path to "update"), so keep it unchanged.
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i, j in T.grid(128, 128):
        # Both "init" and "update" write C, so both count as output blocks of the root.
        with T.block("init"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = 0.0
        for k in range(0, 128):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# 1024x1024 matmul C = A @ B into the intermediate buffer C (reduction init via T.init),
# followed by an elementwise ReLU D = max(C, 0). "matmul" is the sole producer of "relu",
# and "relu" is the sole consumer of "matmul".
@T.prim_func
def matmul_relu(a: T.handle, b: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (1024, 1024))
    B = T.match_buffer(b, (1024, 1024))
    C = T.alloc_buffer((1024, 1024))
    D = T.match_buffer(d, (1024, 1024))
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = T.max(C[vi, vj], 0.0)
# Expected IR of ``matmul_relu`` after test_annotate_unannotate_loop attaches annotations
# to the outer two loops (i and j) of the "matmul" nest; the "relu" nest is untouched.
@T.prim_func
def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (1024, 1024))
    B = T.match_buffer(b, (1024, 1024))
    C = T.alloc_buffer((1024, 1024))
    D = T.match_buffer(d, (1024, 1024))
    # Loop i carries "test1" and "test4"; loop j carries "test2" and "test3".
    for i in T.serial(0, 1024, annotations={"test1": "aaa", "test4": {"arr": [0, 0], "key": 3}}):
        for j in T.serial(0, 1024, annotations={"test2": 612, "test3": ["aa", 1]}):
            for k in T.serial(0, 1024):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = 0.0
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = T.max(C[vi, vj], 0.0)
# Expected IR of ``matmul_relu`` after test_annotate_unannotate_block attaches block-level
# annotations (via T.block_attr) to both the "matmul" and the "relu" blocks.
@T.prim_func
def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (1024, 1024))
    B = T.match_buffer(b, (1024, 1024))
    C = T.alloc_buffer((1024, 1024))
    D = T.match_buffer(d, (1024, 1024))
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            # Annotations "test1" and "test4" on the "matmul" block.
            T.block_attr({"test1": "aaa", "test4": {"arr": [0, 0], "key": 3}})
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            # Annotations "test2" (a float) and "test3" on the "relu" block.
            T.block_attr({"test2": 0.22, "test3": ["aa", 1]})
            D[vi, vj] = T.max(C[vi, vj], 0.0)
# IRModule holding two PrimFuncs, each with a block named "init". Because there is more than
# one function, a Schedule over this module must be told which function to work on
# (``Schedule.work_on``) before blocks can be looked up by name.
# NOTE(review): despite the names, both bodies only copy A into B (no addition happens);
# that is sufficient for the tests, which only need two schedulable functions.
@tvm.script.ir_module
class ModuleWithMultipleFuncs:
    @T.prim_func
    def vector_add(
        A: T.Buffer(128, "float32"),
        B: T.Buffer(128, "float32"),
    ) -> None:
        for i in range(128):
            with T.block("init"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi]

    # Identical twin of ``vector_add``; it exists only so that the module has two functions.
    @T.prim_func
    def vector_add_2(
        A: T.Buffer(128, "float32"),
        B: T.Buffer(128, "float32"),
    ) -> None:
        for i in range(128):
            with T.block("init"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi]
# A "tuple" reduction: the single block "data_red_temp" reduces each row of ``data`` into two
# buffers at once (row sum into data_red_temp_v0, row sum of squares into data_red_temp_v1),
# and "T_add" then reads both buffers. The producer/consumer edge between the two blocks
# therefore goes through two buffers, which the *_multiple_buffer_depdencies tests exercise.
@T.prim_func
def tuple_reduction(data: T.Buffer((4, 32), "float32"), T_add: T.Buffer((4,), "float32")) -> None:
    # Function attributes: exported as "main", buffers do not alias.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # Body, wrapped in an explicit root block that owns the two temporary buffers.
    with T.block("root"):
        T.reads()
        T.writes()
        data_red_temp_v0 = T.alloc_buffer([4], dtype="float32")
        data_red_temp_v1 = T.alloc_buffer([4], dtype="float32")
        for i0, i1 in T.grid(4, 32):
            with T.block("data_red_temp"):
                # ax0 is spatial (row), k1 is the reduction axis (column).
                ax0, k1 = T.axis.remap("SR", [i0, i1])
                T.reads(data[ax0, k1])
                T.writes(data_red_temp_v0[ax0], data_red_temp_v1[ax0])
                with T.init():
                    data_red_temp_v0[ax0] = T.float32(0)
                    data_red_temp_v1[ax0] = T.float32(0)
                # Both updates are computed into let-bound temporaries before either buffer
                # is written.
                v_data_red_temp_v0: T.float32 = data_red_temp_v0[ax0] + data[ax0, k1]
                v_data_red_temp_v1: T.float32 = (
                    data_red_temp_v1[ax0] + data[ax0, k1] * data[ax0, k1]
                )
                data_red_temp_v0[ax0] = v_data_red_temp_v0
                data_red_temp_v1[ax0] = v_data_red_temp_v1
        for i0 in range(4):
            with T.block("T_add"):
                ax0 = T.axis.remap("S", [i0])
                T.reads(data_red_temp_v0[ax0], data_red_temp_v1[ax0])
                T.writes(T_add[ax0])
                T_add[ax0] = data_red_temp_v0[ax0] + data_red_temp_v1[ax0]
# pylint: enable=no-member,invalid-name,unused-variable
# Tests that accept ``use_block_name`` hand the schedule primitive under test either a
# BlockRV (False, case "block_obj") or the block's string name (True, case "block_name").
use_block_name = tvm.testing.parameter(
    by_dict={
        "block_obj": False,
        "block_name": True,
    },
)
def test_tir_schedule_creation():
    """A schedule built from a PrimFunc and one built from an equivalent IRModule share "main".

    Covers ``Schedule.__init__`` for both input kinds, ``Schedule.mod`` and ``Schedule.state``.
    """
    from_func = tir.Schedule(matmul, debug_mask="all")
    from_module = tir.Schedule(IRModule({"main": matmul}), debug_mask="all")
    module_pairs = (
        (from_func.mod, from_module.mod),
        (from_func.state.mod, from_module.state.mod),
    )
    for lhs_mod, rhs_mod in module_pairs:
        assert lhs_mod["main"].same_as(rhs_mod["main"])
def test_tir_schedule_get_block():
    """``get_block``, ``get_sref`` and ``get`` all resolve to the same "update" block."""
    schedule = tir.Schedule(matmul, debug_mask="all")
    update_rv = schedule.get_block(name="update")
    update_sref = schedule.get_sref(update_rv)
    update_stmt = schedule.get(update_rv)
    assert update_stmt.name_hint == "update"
    assert update_sref.stmt.same_as(update_stmt)
    assert schedule.state.get_sref(update_stmt).same_as(update_sref)
    # Walk the original function: root block -> loop i -> loop j -> the sequence
    # [init realize, loop k]; the block realize under loop k is "update".
    root_block = matmul.body.block
    j_loop_seq = root_block.body.body.body
    assert update_stmt.same_as(j_loop_seq[1].body.block)
def test_tir_schedule_work_on():
    """With several functions in the module, ``work_on`` must select one before lookups."""
    schedule = tir.Schedule(ModuleWithMultipleFuncs, debug_mask="all")
    # No function has been selected yet, so the lookup cannot be resolved.
    with pytest.raises(ValueError, match="does not know which function to be working on"):
        schedule.get_block(name="init")
    schedule.work_on(func_name="vector_add")
    schedule.get_block(name="init")
    expected_gv = schedule.mod.get_global_var("vector_add")
    assert schedule.func_working_on == expected_gv
def test_tir_schedule_get_loops(use_block_name):
    """``get_loops`` on "update" returns its enclosing loops i, j, k from outermost inward."""
    schedule = tir.Schedule(matmul, debug_mask="all")
    if use_block_name:
        target = "update"
    else:
        target = schedule.get_block(name="update")
    loop_i, loop_j, loop_k = schedule.get_loops(target)
    for loop_rv, expected_name in zip((loop_i, loop_j, loop_k), ("i", "j", "k")):
        assert schedule.get(loop_rv).loop_var.name == expected_name
def test_tir_schedule_copy_1(use_block_name):
    """Random variables obtained from a schedule remain resolvable on its ``copy()``."""
    expected_names = ("i", "j", "k")
    original = tir.Schedule(matmul, debug_mask="all")
    update_rv = original.get_block(name="update")
    i, j, k = original.get_loops(block="update" if use_block_name else update_rv)
    for loop_rv, name in zip((i, j, k), expected_names):
        assert original.get(loop_rv).loop_var.name == name
    duplicate = original.copy()
    assert duplicate.get(update_rv).name_hint == "update"
    for loop_rv, name in zip((i, j, k), expected_names):
        assert duplicate.get(loop_rv).loop_var.name == name
def test_tir_schedule_copy_2():
    """A copied schedule gets its own srefs and then evolves independently of the original."""
    sch = tir.Schedule(mod=matmul, debug_mask="all")
    i, j, k = sch.get_loops(sch.get_block("update"))
    sch_copy = sch.copy()
    # The copy holds distinct srefs ...
    for loop_rv in (i, j, k):
        assert not sch.get_sref(loop_rv).same_as(sch_copy.get_sref(loop_rv))
    # ... that still point at the very same loop statements.
    for loop_rv in (i, j, k):
        assert sch.get_sref(loop_rv).stmt.same_as(sch_copy.get_sref(loop_rv).stmt)
    # Diverge: split loop i on the original and loop j on the copy.
    i_0, i_1 = sch.split(i, factors=[None, 64])
    j_0, j_1 = sch_copy.split(j, factors=[None, 32])
    assert sch.get_sref(i_0).stmt.extent == 2
    assert sch.get_sref(i_1).stmt.extent == 64
    # Random variables produced by one schedule are unknown to the other one.
    for other_sch, foreign_rv in ((sch_copy, i_0), (sch_copy, i_1), (sch, j_0), (sch, j_1)):
        with pytest.raises(IndexError):
            other_sch.get_sref(foreign_rv)
    assert sch_copy.get_sref(j_0).stmt.extent == 4
    assert sch_copy.get_sref(j_1).stmt.extent == 32
    verify_trace_roundtrip(sch, mod=matmul)
    verify_trace_roundtrip(sch_copy, mod=matmul)
def test_tir_schedule_remove_rv():
    """Once ``remove_rv`` has dropped a random variable, resolving it raises IndexError."""
    schedule = tir.Schedule(matmul, debug_mask="all")
    update_rv = schedule.get_block(name="update")
    assert schedule.get(update_rv).name_hint == "update"
    schedule.remove_rv(update_rv)
    with pytest.raises(IndexError):
        schedule.get(update_rv)
def test_get_child_blocks():
    """``get_child_blocks`` lists "init" then "update" under the outer loop and under the root."""
    s = tir.Schedule(matmul, debug_mask="all")
    init = s.get_block("init")
    update = s.get_block("update")

    def check_children(parent):
        # Compare statements by identity with ``same_as``, like the rest of this file, rather
        # than relying on what ``==`` means for TVM objects.
        blocks = s.get_child_blocks(parent)
        assert len(blocks) == 2
        assert s.get(init).same_as(s.get(blocks[0]))
        assert s.get(update).same_as(s.get(blocks[1]))

    # Parent is a loop: the outermost loop (i) enclosing both blocks.
    check_children(s.get_loops(init)[0])
    # Parent is a block: the implicit root block.
    check_children(s.get_block("root"))
def test_get_producers(use_block_name):
    """The single producer of "relu" in ``matmul_relu`` is the "matmul" block."""
    sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
    relu = "relu" if use_block_name else sch.get_block("relu")
    [producer] = sch.get_producers(relu)
    actual_stmt = sch.get_sref(producer).stmt
    expected_stmt = sch.get_sref(sch.get_block("matmul")).stmt
    tvm.ir.assert_structural_equal(actual_stmt, expected_stmt)
    verify_trace_roundtrip(sch, mod=matmul_relu)
def test_get_producers_multiple_buffer_depdencies(use_block_name):
    """A producer feeding several buffers of one consumer is reported exactly once.

    In ``tuple_reduction`` the "data_red_temp" block writes both temporaries and "T_add"
    reads both, yet ``get_producers("T_add")`` must return a single block.
    """
    sch = tir.Schedule(mod=tuple_reduction, debug_mask="all")
    block = "T_add" if use_block_name else sch.get_block("T_add")
    # Unpacking into a 1-tuple asserts that exactly one producer is returned.
    (producer,) = sch.get_producers(block)
    tvm.ir.assert_structural_equal(
        sch.get_sref(producer).stmt,
        sch.get_sref(sch.get_block("data_red_temp")).stmt,
    )
    # Like test_get_producers, make sure the recorded trace replays cleanly.
    verify_trace_roundtrip(sch, mod=tuple_reduction)
def test_get_consumers(use_block_name):
    """The single consumer of "matmul" in ``matmul_relu`` is the "relu" block."""
    sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
    matmul_block = "matmul" if use_block_name else sch.get_block("matmul")
    [consumer] = sch.get_consumers(matmul_block)
    actual_stmt = sch.get_sref(consumer).stmt
    expected_stmt = sch.get_sref(sch.get_block("relu")).stmt
    tvm.ir.assert_structural_equal(actual_stmt, expected_stmt)
    verify_trace_roundtrip(sch, mod=matmul_relu)
def test_get_consumers_multiple_buffer_depdencies(use_block_name):
    """A consumer reading several buffers of one producer is reported exactly once.

    In ``tuple_reduction`` "T_add" reads both temporaries written by "data_red_temp", yet
    ``get_consumers("data_red_temp")`` must return a single block.
    """
    sch = tir.Schedule(mod=tuple_reduction, debug_mask="all")
    block = "data_red_temp" if use_block_name else sch.get_block("data_red_temp")
    # Unpacking into a 1-tuple asserts that exactly one consumer is returned.
    (consumer,) = sch.get_consumers(block)
    tvm.ir.assert_structural_equal(
        sch.get_sref(consumer).stmt,
        sch.get_sref(sch.get_block("T_add")).stmt,
    )
    # Like test_get_consumers, make sure the recorded trace replays cleanly.
    verify_trace_roundtrip(sch, mod=tuple_reduction)
def test_annotate_unannotate_loop():
    """Loop annotations yield ``matmul_relu_ann1``; removing them all restores ``matmul_relu``."""
    sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
    # Named ``matmul_block`` so it does not shadow the module-level ``matmul`` PrimFunc.
    matmul_block = sch.get_block("matmul")
    # Loop 0 is ``i`` and loop 1 is ``j`` of the matmul nest (see matmul_relu_ann1).
    sch.annotate(sch.get_loops(matmul_block)[0], "test1", "aaa")
    sch.annotate(sch.get_loops(matmul_block)[1], "test2", 612)
    sch.annotate(sch.get_loops(matmul_block)[1], "test3", ["aa", 1])
    sch.annotate(sch.get_loops(matmul_block)[0], "test4", {"arr": [0, 0], "key": 3})
    assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_relu_ann1)
    verify_trace_roundtrip(sch=sch, mod=matmul_relu)
    sch.unannotate(sch.get_loops(matmul_block)[0], "test1")
    sch.unannotate(sch.get_loops(matmul_block)[1], "test2")
    sch.unannotate(sch.get_loops(matmul_block)[1], "test3")
    sch.unannotate(sch.get_loops(matmul_block)[0], "test4")
    # With every key removed, the function must be identical to the unannotated input again.
    assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_relu)
    verify_trace_roundtrip(sch=sch, mod=matmul_relu)
def test_annotate_unannotate_block():
    """Block annotations yield ``matmul_relu_ann2``; removing them all restores ``matmul_relu``."""
    sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
    # Named ``matmul_block`` so it does not shadow the module-level ``matmul`` PrimFunc.
    matmul_block = sch.get_block("matmul")
    relu = sch.get_block("relu")
    sch.annotate(matmul_block, "test1", "aaa")
    sch.annotate(relu, "test2", 0.22)
    sch.annotate(relu, "test3", ["aa", 1])
    sch.annotate(matmul_block, "test4", {"arr": [0, 0], "key": 3})
    assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_relu_ann2)
    verify_trace_roundtrip(sch=sch, mod=matmul_relu)
    sch.unannotate(matmul_block, "test1")
    sch.unannotate(relu, "test2")
    sch.unannotate(relu, "test3")
    sch.unannotate(matmul_block, "test4")
    # With every key removed, the function must be identical to the unannotated input again.
    assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_relu)
    verify_trace_roundtrip(sch=sch, mod=matmul_relu)
def test_get_output_blocks_single_output():
    """In ``matmul_relu`` only "relu" is an output block of the root ("matmul" writes C)."""
    schedule = tir.Schedule(mod=matmul_relu, debug_mask="all")
    outputs = schedule.get_output_blocks("root")
    assert len(outputs) == 1, "Unexpected number of blocks when 1 was expected"
    (only_output,) = (schedule.get(outputs[0]),)
    assert only_output.name_hint == "relu"
    relu_rv = schedule.get_block("relu")
    assert schedule.get(relu_rv).same_as(only_output)
def test_get_output_blocks_multiple_outputs():
    """Both "init" and "update" write the output buffer C of ``matmul``, in that order."""
    schedule = tir.Schedule(mod=matmul, debug_mask="all")
    outputs = schedule.get_output_blocks("root")
    assert len(outputs) == 2, "Unexpected number of blocks when 2 were expected"
    init_stmt = schedule.get(outputs[0])
    assert init_stmt.name_hint == "init"
    update_stmt = schedule.get(outputs[1])
    assert update_stmt.name_hint == "update"
    for name, stmt in (("init", init_stmt), ("update", update_stmt)):
        assert schedule.get(schedule.get_block(name)).same_as(stmt)
def test_get_output_blocks_nested():
    """Output blocks are collected through nested blocks.

    From the root, both the outer "blockized_B" block and the inner "B" block (which writes
    the output buffer B) are reported; scoped to "blockized_B", only "B" is reported.
    """

    @T.prim_func
    def blockized(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
    ) -> None:
        with T.block("blockized_B"):
            vio = T.axis.spatial(1, 0)
            vjo = T.axis.spatial(1, 0)
            for i, j in T.grid(128, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] * 2.0

    # Scope: the root block.
    sch = tir.Schedule(mod=blockized, debug_mask="all")
    output_blocks = sch.get_output_blocks("root")
    assert len(output_blocks) == 2, "Unexpected number of blocks when 2 were expected"
    block_1 = sch.get(output_blocks[0])
    assert block_1.name_hint == "blockized_B"
    block_2 = sch.get(output_blocks[1])
    assert block_2.name_hint == "B"
    blockized_block = sch.get_block("blockized_B")
    assert sch.get(blockized_block).same_as(block_1)
    b_block = sch.get_block("B")
    assert sch.get(b_block).same_as(block_2)

    # Scope: the outer "blockized_B" block, on a fresh schedule.
    sch = tir.Schedule(mod=blockized, debug_mask="all")
    output_blocks = sch.get_output_blocks("blockized_B")
    # Message wording aligned with test_get_output_blocks_single_output ("1 was expected").
    assert len(output_blocks) == 1, "Unexpected number of blocks when 1 was expected"
    block = sch.get(output_blocks[0])
    assert block.name_hint == "B"
    b_block = sch.get_block("B")
    assert sch.get(b_block).same_as(block)
if __name__ == "__main__":
    # Allow running this file directly by delegating to TVM's test entry point.
    tvm.testing.main()