blob: 97a6b889c0115bbb77d675ee36e5aa9861d95167 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import re
import pytest
import tvm.testing
from tvm import ir, tir
from tvm.ir import Range
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tir as T
def _assert_print(obj, expected):
assert obj.script(verbose_expr=True).strip() == expected.strip()
def test_prim_func():
a = tir.Var("a", "handle")
b = tir.Var("b", "handle")
func = tir.PrimFunc(
params=[a, b],
ret_type=None,
buffer_map={
a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"),
b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"),
},
body=tir.Evaluate(0),
).with_attr("global_symbol", "main")
_assert_print(
func,
expected="""
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
T.evaluate(0)""",
)
def test_prim_func_no_sugar_inlined_buffer():
a = tir.Var("a", "handle")
b = tir.Var("b", "handle")
func = tir.PrimFunc(
params=[a, b],
ret_type=None,
buffer_map={
a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"),
b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"),
},
body=tir.Evaluate(a),
).with_attr("global_symbol", "main")
_assert_print(
func,
expected="""
# from tvm.script import tir as T
@T.prim_func
def main(a: T.handle, B: T.Buffer((256, 256), "float32")):
A = T.match_buffer(a, (128, 128))
T.evaluate(a)
""",
)
def test_prim_func_no_sugar_shared_buffer_data():
a = tir.Var("a", "handle")
b = tir.Var("b", "handle")
buffer_data = tir.decl_buffer(shape=[128, 128], dtype="float32", name="A").data
func = tir.PrimFunc(
params=[a, b],
ret_type=None,
buffer_map={
a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A", data=buffer_data),
b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B", data=buffer_data),
},
body=tir.Evaluate(0),
).with_attr("global_symbol", "main")
_assert_print(
func,
expected="""
# from tvm.script import tir as T
@T.prim_func
def main(a: T.handle, b: T.handle):
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (256, 256), data=A.data)
T.evaluate(0)
""",
)
def test_block_realize():
i = tir.Var("i", "int32")
j = tir.Var("j", "int32")
k = tir.Var("k", "int32")
with IRBuilder() as ib:
with T.block(name="block", no_realize=False):
vi = ib.name("vi", T.axis.spatial(128, i))
vj = ib.name("vj", T.axis.spatial(64, j))
vk = ib.name("vk", T.axis.reduce(32, k))
T.reads()
T.writes()
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
i = T.int32()
j = T.int32()
k = T.int32()
with T.block("block"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(64, j)
vk = T.axis.reduce(32, k)
T.reads()
T.writes()
T.evaluate(0)""",
)
def test_block():
i = tir.Var("i", "int32")
j = tir.Var("j", "int32")
k = tir.Var("k", "int32")
with IRBuilder() as ib:
with T.block(name="block", no_realize=False):
vi = ib.name("vi", T.axis.spatial(128, i))
vj = ib.name("vj", T.axis.spatial(64, j))
vk = ib.name("vk", T.axis.reduce(32, k))
T.reads()
T.writes()
T.evaluate(0)
obj = ib.get().block
_assert_print(
obj,
"""
with T.block("block", no_realize=True):
vi = T.axis.spatial(128)
vj = T.axis.spatial(64)
vk = T.axis.reduce(32)
T.reads()
T.writes()
T.evaluate(0)""",
)
def test_match_buffer_region():
src = tir.decl_buffer((128, 128), "float32", name="src")
tgt = tir.decl_buffer((64, 64), "float32", name="tgt")
obj = tir.MatchBufferRegion(
tgt,
tir.BufferRegion(
src,
[
Range(64, 128),
Range(64, 128),
],
),
)
_assert_print(
obj,
"""
src = T.Buffer((128, 128))
tgt = T.match_buffer(src[64:128, 64:128], (64, 64))
""",
)
def test_buffer():
a = tir.decl_buffer((128, 128), "float16", name="A")
_assert_print(
a,
"""A = T.Buffer((128, 128), "float16")
A""",
)
def test_buffer_region():
src = tir.decl_buffer((128, 128), "float32", name="src")
obj = tir.BufferRegion(
src,
[
Range(64, 128),
Range(64, 128),
],
)
_assert_print(
obj,
"""
src = T.Buffer((128, 128))
src[64:128, 64:128]
""",
)
def test_buffer_load():
a = tir.decl_buffer((128, 128), "float16", name="A")
obj = tir.BufferLoad(a, [128, 128])
_assert_print(
obj,
"""
A = T.Buffer((128, 128), "float16")
A[128, 128]
""",
)
def test_buffer_store():
a = tir.decl_buffer((128, 128), "float16", name="A")
with IRBuilder() as ib:
T.buffer_store(a, a[128, 128] + 1, [128, 128])
obj = ib.get()
_assert_print(
obj,
"""
A = T.Buffer((128, 128), "float16")
A[128, 128] = A[128, 128] + T.float16(1)
""",
)
def test_for():
with IRBuilder() as ib:
with T.grid(128, 128, 128) as (i, j, k):
ib.name_many(["i", "j", "k"], [i, j, k])
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
for i, j, k in T.grid(128, 128, 128):
T.evaluate(0)
""",
)
def test_let_stmt():
with IRBuilder() as ib:
with T.LetStmt(T.float32(10)) as v:
ib.name("v", v)
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
with T.LetStmt(T.float32(10)) as v:
T.evaluate(0)
""",
)
def test_attr_stmt():
with IRBuilder() as ib:
with T.attr("pragma", "unroll", 1):
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
with T.attr("pragma", "unroll", 1):
T.evaluate(0)
""",
)
def test_assert_stmt():
with IRBuilder() as ib:
with T.Assert(True, "assertion"):
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
with T.Assert(T.bool(True), "assertion"):
T.evaluate(0)
""",
)
def test_while():
with IRBuilder() as ib:
x = T.int32()
with T.While(x < 10):
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
v = T.int32()
while v < 10:
T.evaluate(0)
""",
)
def test_allocate():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32"):
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
with T.allocate([128, 128], "float32", "global") as v:
T.evaluate(0)
""",
)
def test_allocate_with_decl_buffer_sugar():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32") as buffer_data:
with T.decl_buffer([128, 128], "float32", data=buffer_data) as buffer:
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
with T.decl_buffer((128, 128)) as buffer:
T.evaluate(0)
""",
)
def test_allocate_with_decl_buffer_sugar_multi_usage():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32") as buffer_data:
with T.decl_buffer([128, 128], "float32", data=buffer_data) as buffer:
T.evaluate(buffer_data)
obj = ib.get()
_assert_print(
obj,
"""
with T.decl_buffer((128, 128)) as buffer:
T.evaluate(buffer.data)
""",
)
def test_allocate_with_decl_buffer_no_sugar_mismatch():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32") as buffer_data:
with T.decl_buffer([256, 256], "float32", data=buffer_data) as buffer:
T.evaluate(buffer_data)
obj = ib.get()
_assert_print(
obj,
"""
with T.allocate([128, 128], "float32", "global") as v:
buffer = T.decl_buffer((256, 256), data=v)
T.evaluate(v)
""",
)
def test_decl_buffer():
with IRBuilder() as ib:
with T.decl_buffer((10, 10), data=T.ptr("float32")):
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
v = T.handle("float32", "global")
with T.decl_buffer((10, 10), data=v) as buffer:
T.evaluate(0)
""",
)
def test_prefetch():
a = tir.decl_buffer((128, 128), "float16", name="A")
with IRBuilder() as ib:
T.prefetch(a, [Range(0, 64), Range(0, 64)])
obj = ib.get()
_assert_print(
obj,
"""
A = T.Buffer((128, 128), "float16")
T.prefetch(A, [T.Range(0, 64), T.Range(0, 64)])
""",
)
def test_seq_stmt():
with IRBuilder() as ib:
with T.serial(10):
T.evaluate(1)
T.evaluate(2)
obj = ib.get().body
_assert_print(
obj,
"""
T.evaluate(1)
T.evaluate(2)
""",
)
def test_if_then_else():
with IRBuilder() as ib:
with T.If(T.int32() == 1):
with T.Then():
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
v = T.int32()
if v == 1:
T.evaluate(0)
""",
)
def test_evaluate():
with IRBuilder() as ib:
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
T.evaluate(0)
""",
)
def test_buffer_realize():
with IRBuilder() as ib:
a = tir.decl_buffer((128, 128), "float32", name="A")
with T.realize(a[0:128, 0:128], "test_storage_scope", True):
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
A = T.Buffer((128, 128))
with T.realize(A[0:128, 0:128], "test_storage_scope"):
T.evaluate(0)
""",
)
def test_var():
a = tir.Var("a", "float32")
_assert_print(
a,
"""
a = T.float32()
a""",
)
def test_size_var():
a = tir.SizeVar("a", "float32")
_assert_print(
a,
"""
a = T.float32(is_size_var=True)
a""",
)
def test_iter_var():
a = tir.IterVar((0, 8), "a", iter_type=tir.IterVar.DataPar)
_assert_print(
a,
"""
a = T.int32()
T.iter_var(a, T.Range(0, 8), "DataPar", "")
""",
)
def test_string_imm():
s = tir.StringImm("str")
_assert_print(s, '"str"')
def test_cast():
obj = tir.Cast("float64", tir.Var("a", "float32"))
_assert_print(
obj,
"""
a = T.float32()
T.Cast("float64", a)
""",
)
def test_llvm_intrin_imm():
a = tir.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0))
_assert_print(a, 'T.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0))')
a = tir.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))
_assert_print(a, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))')
def test_binary_arith():
a = tir.Var("a", "int32")
b = tir.Var("b", "int32")
for op, sign in [
(tir.Add, "+"),
(tir.Sub, "-"),
(tir.Mul, "*"),
(tir.Mod, "truncmod"),
(tir.FloorDiv, "//"),
(tir.FloorMod, "%"),
(tir.LT, "<"),
(tir.LE, "<="),
(tir.EQ, "=="),
(tir.NE, "!="),
(tir.GT, ">"),
(tir.GE, ">="),
]:
obj = op(a, b)
if sign.isalpha():
expected = """
a = T.int32()
b = T.int32()
T.{}(a, b)""".format(
sign
)
else:
expected = """
a = T.int32()
b = T.int32()
a {} b""".format(
sign
)
_assert_print(obj, expected)
def test_binary_arith_const():
a = tir.IntImm("int64", 3)
b = tir.IntImm("int64", 4)
for op, name in [
(tir.Add, "Add"),
(tir.Sub, "Sub"),
(tir.Mul, "Mul"),
(tir.Div, "Div"),
(tir.Mod, "truncmod"),
(tir.FloorDiv, "FloorDiv"),
(tir.FloorMod, "FloorMod"),
(tir.LT, "LT"),
(tir.LE, "LE"),
(tir.EQ, "EQ"),
(tir.NE, "NE"),
(tir.GT, "GT"),
(tir.GE, "GE"),
]:
obj = op(a, b)
expected = """
T.{}({}, {})""".format(
name, str(a), str(b)
)
_assert_print(obj, expected)
def test_int_div():
a = tir.Var("a", "int32")
b = tir.Var("b", "int32")
_assert_print(
tir.Div(a, b),
"""
a = T.int32()
b = T.int32()
T.Div(a, b)
""",
)
def test_logical():
a = tir.Var("a", "bool")
b = tir.Var("b", "bool")
_assert_print(
tir.And(a, b),
"""
a = T.bool()
b = T.bool()
a and b
""",
)
_assert_print(
tir.Or(a, b),
"""
a = T.bool()
b = T.bool()
a or b
""",
)
_assert_print(
tir.Not(a),
"""
a = T.bool()
not a
""",
)
def test_select():
obj = tir.Select(True, 0, 2)
_assert_print(
obj,
"""T.Select(T.bool(True), 0, 2)
""",
)
@pytest.mark.parametrize(
"lanes, scripted_lanes", [(32, "32"), (tvm.tir.vscale() * 8, "T.vscale() * 8")]
)
def test_ramp(lanes, scripted_lanes):
a = tir.Var("a", "int32")
obj = tir.Ramp(a, 1, lanes)
_assert_print(
obj,
"""
a = T.int32()
T.Ramp(a, 1, {})
""".format(
scripted_lanes
),
)
@pytest.mark.parametrize(
"lanes, scripted_lanes", [(4, "4"), (tvm.tir.vscale() * 4, "T.vscale() * 4")]
)
def test_broadcast(lanes, scripted_lanes):
obj = tir.Broadcast(0, lanes)
_assert_print(
obj,
"""
T.Broadcast(0, {})
""".format(
scripted_lanes
),
)
def test_let_expr():
x = tir.Var("x", "int32")
obj = tir.Let(x, 1, x + 1)
_assert_print(
obj,
"""
x = T.int32()
T.Let(x + 1, where={x: 1})
""",
)
def test_call():
obj = tir.atan(T.float32(1.0))
_assert_print(
obj,
"""
T.atan(T.float32(1))
""",
)
def test_comm_reducer():
obj = T.comm_reducer(lambda x, y: x + y, identity=[T.float32(0)])
_assert_print(
obj,
"""
T.comm_reducer(lambda x, y: x + y, [T.float32(0)])
""",
)
def test_any():
obj = tir.Any()
_assert_print(
obj,
"""
T.Any()
""",
)
def test_int_imm():
obj = T.int16(1)
_assert_print(
obj,
"""
T.int16(1)
""",
)
def test_float_imm():
obj = T.float16(1)
_assert_print(
obj,
"""
T.float16(1)
""",
)
def test_range():
obj = Range(0, 10)
_assert_print(
obj,
"""
I.Range(0, 10)
""",
)
def test_prim_type():
obj = ir.PrimType("float32")
_assert_print(obj, "T.float32")
def test_pointer_type():
obj = ir.PointerType(ir.PrimType("int32"), "global")
_assert_print(obj, 'T.handle("int32", "global")')
def test_tuple_type():
obj = ir.TupleType([ir.PrimType("float32"), ir.PrimType("int32")])
_assert_print(obj, "T.Tuple(T.float32, T.int32)")
def test_remap():
from tvm.script import tir as T
@T.prim_func
def block_with_remap_implicitly():
for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
with T.block("update"):
v0 = T.axis.spatial(128, i0 + 1)
v1 = T.axis.spatial(128, i1)
v2 = T.axis.reduce(128, i2)
v3 = T.axis.spatial(128, i3 - 1)
v4 = T.axis.reduce(128, i4)
v5 = T.axis.spatial(128, i5)
@T.prim_func
def block_with_remap_explicitly():
for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
with T.block("update"):
v0 = T.axis.spatial(128, i0 + 1)
v1, v2 = T.axis.remap("SR", [i1, i2])
v3 = T.axis.spatial(128, i3 - 1)
v4, v5 = T.axis.remap("RS", [i4, i5])
expected_output = """
# from tvm.script import tir as T
@T.prim_func
def main():
# with T.block("root"):
for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
with T.block("update"):
v0 = T.axis.spatial(128, i0 + 1)
v1, v2 = T.axis.remap("SR", [i1, i2])
v3 = T.axis.spatial(128, i3 - 1)
v4, v5 = T.axis.remap("RS", [i4, i5])
T.reads()
T.writes()
T.evaluate(0)"""
_assert_print(block_with_remap_explicitly.with_attr("global_symbol", "main"), expected_output)
_assert_print(block_with_remap_implicitly.with_attr("global_symbol", "main"), expected_output)
def test_root_block():
from tvm.script import tir as T
@T.prim_func
def root_block_implicitly():
a = T.alloc_buffer([128, 128])
for i, j in T.grid(128, 128):
with T.block():
T.evaluate(0)
@T.prim_func
def root_block_explicitly():
with T.block("root"):
a = T.alloc_buffer([128, 128])
for i, j in T.grid(128, 128):
with T.block():
T.evaluate(0)
expected_output = """
# from tvm.script import tir as T
@T.prim_func
def main():
# with T.block("root"):
a = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block(""):
T.reads()
T.writes()
T.evaluate(0)
"""
_assert_print(root_block_implicitly.with_attr("global_symbol", "main"), expected_output)
_assert_print(root_block_explicitly.with_attr("global_symbol", "main"), expected_output)
def test_private_primfunc():
from tvm.script import tir as T
a = tir.Var("a", "handle")
b = tir.Var("b", "handle")
func = tir.PrimFunc(
params=[a, b],
ret_type=None,
buffer_map={
a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"),
b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"),
},
body=tir.Evaluate(0),
)
_assert_print(
func,
expected="""
# from tvm.script import tir as T
@T.prim_func(private=True)
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
T.evaluate(0)""",
)
def test_prim_func_different_symbol():
from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
T.func_attr({"global_symbol": "func"})
T.evaluate(0)
expected_output = """
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
T.evaluate(0)
"""
_assert_print(main, expected_output)
def test_variable_with_cpp_address():
"""The show_object_address option displays the C++ addressess
Because the C++ address may vary with each execution, the output
produced with this option cannot be compared to a fixed string.
Instead, this test uses the normal script output to generate a
regular expression against with the test output must match. The
regular expression validates that all names have been appended
with "_0x" followed by a hexadecimal number, and that the address
is the same for each variable.
"""
from tvm.script import tir as T
# The test function has all named objects suffixed with "_name",
# to avoid spurious replacement when generating the expected
# regex.
@T.prim_func
def func(a_name: T.handle):
N_name = T.int64()
A_name = T.match_buffer(a_name, N_name, "float32")
for i_name in range(N_name):
A_name[i_name] = A_name[i_name] + 1.0
without_address = func.script(show_object_address=False)
script = func.script(show_object_address=True)
expected_regex = re.escape(without_address)
for name in ["a_name", "A_name", "N_name", "i_name"]:
# Replace all occurrences with a backref to an earlier match
expected_regex = expected_regex.replace(name, rf"(?P={name})")
# Then replace the first such backref with a capturing group.
expected_regex = expected_regex.replace(
rf"(?P={name})", rf"(?P<{name}>{name}_0x[A-Fa-f0-9]+)", 1
)
assert re.match(expected_regex, script)
def test_return_statement():
from tvm.script import tir as T
@T.prim_func
def func():
T.evaluate(T.ret(5))
expected_output = """
# from tvm.script import tir as T
@T.prim_func
def func():
return 5
"""
_assert_print(func, expected_output)
if __name__ == "__main__":
tvm.testing.main()