blob: cead775e97cd883767765c47e7eca916a2364818 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring, missing-module-docstring
import pytest
import tvm
from tvm.script import tir as T
from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol
# Generic workload that the matmul specialization test starts from.
# C[vi, vj] accumulates A[vi, vk] * B[vj, vk] over vk, with A and B of shape
# [m, n] and C of shape [m, m].  `n` is an explicit int32 parameter, whereas
# `m` is only introduced through the buffer shapes.
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle, n: T.int32) -> None:
    m = T.int32()
    A = T.match_buffer(a, [m, n])
    B = T.match_buffer(b, [m, n])
    C = T.match_buffer(c, [m, m])
    for i, j, k in T.grid(m, m, n):
        with T.block("update"):
            # "SSR": i and j are bound as spatial axes, k as the reduction axis.
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            # Initial value of the accumulator, before the reduction over vk.
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Expected result of specializing `matmul` with a 128x128 buffer bound to `a`:
# every extent is the constant 128 and the scalar parameter `n` is gone.
@T.prim_func
def matmul_128(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            # "SSR": i and j are bound as spatial axes, k as the reduction axis.
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Expected result of specializing `matmul` with n = 128 only: the reduction
# extent is the constant 128 while `m` remains symbolic.
@T.prim_func
def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None:
    m = T.int32()
    A = T.match_buffer(a, [m, 128])
    B = T.match_buffer(b, [m, 128])
    C = T.match_buffer(c, [m, m])
    for i, j, k in T.grid(m, m, 128):
        with T.block("update"):
            # "SSR": i and j are bound as spatial axes, k as the reduction axis.
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Expected result of specializing `matmul` with n = x * 8.
# The well-formedness check is disabled because x is considered undefined:
# it appears only as part of x * 8, never on its own.
@T.prim_func(check_well_formed=False)
def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None:
    x = T.int32()
    m = T.int32()
    A = T.match_buffer(a, [m, x * 8])
    B = T.match_buffer(b, [m, x * 8])
    C = T.match_buffer(c, [m, m])
    for i, j, k in T.grid(m, m, x * 8):
        with T.block("update"):
            # "SSR": i and j are bound as spatial axes, k as the reduction axis.
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Generic workload that the element-wise specialization test starts from.
# Two stages over a symbolic (m, n) shape: B = A * 2.0, then C = B + 1.0.
# B is an intermediate buffer allocated inside the function, so it is not
# part of the parameter list.
@T.prim_func
def element_wise(a: T.handle, c: T.handle) -> None:
    m = T.int32()
    n = T.int32()
    A = T.match_buffer(a, (m, n), "float32")
    C = T.match_buffer(c, (m, n), "float32")
    B = T.alloc_buffer((m, n), "float32")
    # Stage 1: write the doubled input into the intermediate buffer.
    for i, j in T.grid(m, n):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    # Stage 2: read the intermediate buffer and add one.
    for i, j in T.grid(m, n):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
# Expected result of specializing `element_wise` with a (128, 64) buffer bound
# to `a`: A, B and C all take the constant shape (128, 64).
@T.prim_func
def element_wise_128_64(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 64), "float32")
    C = T.match_buffer(c, (128, 64), "float32")
    B = T.alloc_buffer((128, 64), "float32")
    # Stage 1: write the doubled input into the intermediate buffer.
    for i, j in T.grid(128, 64):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    # Stage 2: read the intermediate buffer and add one.
    for i, j in T.grid(128, 64):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
# Expected result of specializing `element_wise` so that only the first
# dimension becomes the constant 128; the second extent `n` stays symbolic.
@T.prim_func
def element_wise_128_n(a: T.handle, c: T.handle) -> None:
    n = T.int32()
    A = T.match_buffer(a, (128, n), "float32")
    C = T.match_buffer(c, (128, n), "float32")
    B = T.alloc_buffer((128, n), "float32")
    # Stage 1: write the doubled input into the intermediate buffer.
    for i, j in T.grid(128, n):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    # Stage 2: read the intermediate buffer and add one.
    for i, j in T.grid(128, n):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
# Generic workload that the mem_copy specialization test starts from.
# Copies A into B element by element.  Both buffers share the shape (m, n),
# the strides [p, 1] and the element offset q; m, n, p and q are all explicit
# int32 parameters.
@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int32, q: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q)
    B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q)
    for i, j in T.grid(m, n):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]
# Expected result of specializing `mem_copy` with m = n = 16, p = 8, q = 4:
# shape (16, 16), strides [8, 1], element offset 4, and no scalar parameters.
@T.prim_func
def mem_copy_16_16_8_4(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4)
    B = T.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4)
    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]
# Expected result of specializing `mem_copy` with q replaced by n: the element
# offset of both buffers is `n`, and `q` is no longer a parameter.
@T.prim_func
def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n)
    B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n)
    for i, j in T.grid(m, n):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]
def test_specialize_nothing():
    """An empty parameter map must hand back the original PrimFunc object."""
    unchanged = matmul.specialize({})
    # `same_as` compares object identity, i.e. both names share one pointer.
    assert unchanged.same_as(matmul)
def test_specialize_matmul():
    """Specialize `matmul` via a buffer, a constant, and a symbolic expression."""
    handle_a, _, _, extent_n = matmul.params

    # Fully specialized: the 128x128 buffer bound to `a` fixes every extent.
    concrete_a = tvm.tir.decl_buffer((128, 128))
    by_buffer = matmul.specialize({handle_a: concrete_a})
    assert_structural_equal_ignore_global_symbol(by_buffer, matmul_128)

    # Partially specialized: only `n` becomes a constant, `m` stays symbolic.
    by_constant = matmul.specialize({extent_n: 128})
    assert_structural_equal_ignore_global_symbol(by_constant, matmul_m_128)

    # Symbolically specialized: `n` is replaced by the expression x * 8.
    x = tvm.tir.Var("x", "int32")
    by_expression = matmul.specialize({extent_n: x * 8})
    assert_structural_equal_ignore_global_symbol(by_expression, matmul_m_8x)
def test_specialize_elemwise():
    """Specialize `element_wise` through its input and through its output buffer."""
    handle_a, handle_c = element_wise.params
    out_buffer = element_wise.buffer_map[handle_c]

    # Fully specialized: a (128, 64) buffer bound to `a` fixes both extents.
    fixed_shape = tvm.tir.decl_buffer((128, 64))
    fully_specialized = element_wise.specialize({handle_a: fixed_shape})
    assert_structural_equal_ignore_global_symbol(fully_specialized, element_wise_128_64)

    # Partially specialized: the first extent becomes 128, while the second
    # reuses the original symbolic extent taken from C's shape.
    half_fixed_shape = tvm.tir.decl_buffer((128, out_buffer.shape[1]))
    partly_specialized = element_wise.specialize({handle_c: half_fixed_shape})
    assert_structural_equal_ignore_global_symbol(partly_specialized, element_wise_128_n)
def test_specialize_mem_copy():
    """Specialize `mem_copy` via a strided buffer, via scalars, and var-to-var."""
    handle_a, _, rows, cols, stride, offset = mem_copy.params

    # Fully specialized through a buffer carrying shape, strides and offset.
    strided = tvm.tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)
    from_buffer = mem_copy.specialize({handle_a: strided})
    assert_structural_equal_ignore_global_symbol(from_buffer, mem_copy_16_16_8_4)

    # Fully specialized through the scalar parameters; same expected function.
    from_scalars = mem_copy.specialize({cols: 16, rows: 16, stride: 8, offset: 4})
    assert_structural_equal_ignore_global_symbol(from_scalars, mem_copy_16_16_8_4)

    # Partially specialized: the element offset is tied to the column extent.
    offset_is_cols = mem_copy.specialize({offset: cols})
    assert_structural_equal_ignore_global_symbol(offset_is_cols, mem_copy_m_n_p_n)
def test_specialize_recursive_load():
    """Placeholder for specializing nested buffer loads such as ``A[C[i]]``.

    Nothing is checked yet, so the test is reported as skipped instead of as
    a vacuous pass; this keeps the missing coverage visible in test reports.
    """
    # TODO(Siyuan): add recursive Load testcase, e.g. A[C[i]]
    pytest.skip("recursive Load specialization test is not implemented yet")
def test_specialize_with_const_folding():
    """Expressions of a specialized variable should be folded to constants.

    Binding `b` to a 16-element buffer fixes n = 16.  The expected function
    therefore contains 16 // 8 = 2, 16 - 1 = 15 and (16 + 1) * 42 = 714 in
    place of the corresponding expressions in n.
    """

    @T.prim_func
    def before(a: T.handle, b: T.handle):
        n = T.int32()
        # n appears only inside expressions and in B's shape.
        A = T.match_buffer(a, [n // 8, 8], "int32")
        B = T.match_buffer(b, [n], "int32")
        for i in range(n - 1):
            with T.block():
                vi = T.axis.S(n - 1, i)
                B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42

    @T.prim_func
    def expected(a: T.handle, b: T.handle):
        A = T.match_buffer(a, [2, 8], "int32")
        B = T.match_buffer(b, [16], "int32")
        for i in range(15):
            with T.block():
                vi = T.axis.S(15, i)
                B[vi] = A[vi // 8, vi % 8] + 714

    # Specialize through the buffer parameter `b`, not through n directly.
    b = before.params[1]
    after = before.specialize({b: tvm.tir.decl_buffer([16], dtype="int32")})
    assert_structural_equal_ignore_global_symbol(expected, after)
def test_specialize_decl_buffer():
    """Buffers occurring in a DeclBuffer statement should be updated"""

    @T.prim_func(private=True)
    def before(A_data: T.handle("float32"), A_size: T.int32):
        # The buffer is declared in the body, with its extent given by A_size.
        A_buf = T.decl_buffer(A_size, "float32", data=A_data)
        for i in range(A_size):
            A_buf[i] = A_buf[i] * 2.0

    @T.prim_func(private=True)
    def expected(A_data: T.handle("float32")):
        A_buf = T.decl_buffer(16, "float32", data=A_data)
        for i in range(16):
            A_buf[i] = A_buf[i] * 2.0

    # Replace the scalar parameter A_size by the constant 16.
    param_map = {before.params[1]: T.int32(16)}
    after = before.specialize(param_map)
    tvm.ir.assert_structural_equal(expected, after)
def test_specialize_buffer_var_to_var():
    """A buffer var may be remapped by specialization

    If a buffer variable is replaced by a specialization, then other
    buffers using the same buffer var should also be updated.
    """

    @T.prim_func(private=True)
    def before(A: T.Buffer([16, 16], "float32"), B: T.Buffer([16, 16], "float32")):
        # Flat views that share the data vars of A and B respectively.
        A_flat = T.decl_buffer([256], "float32", data=A.data)
        B_flat = T.decl_buffer([256], "float32", data=B.data)
        for i in range(256):
            B_flat[i] = A_flat[i] * 2.0

    # well-formed checker complains about multiple nested definitions of B_flat
    # since it appears in the buffer map twice
    @T.prim_func(private=True, check_well_formed=False)
    def expected(A: T.Buffer([16, 16], "float32"), B_handle: T.handle):
        # B and B_flat are both backed by A.data after the specialization.
        B = T.match_buffer(B_handle, [16, 16], "float32", data=A.data)
        A_flat = T.decl_buffer([256], "float32", data=A.data)
        B_flat = T.decl_buffer([256], "float32", data=A.data)
        for i in range(256):
            B_flat[i] = A_flat[i] * 2.0

    # Map the handle of parameter B onto the buffer object of parameter A.
    A = before.buffer_map[before.params[0]]
    B_handle = before.params[1]
    param_map = {B_handle: A}
    after = before.specialize(param_map)
    tvm.ir.assert_structural_equal(expected, after)
def test_specialize_buffer_var_to_expr():
    """Handle specialization of buffer var

    The `tir::Buffer::data` field must be an explicit `tir::Var`, and
    cannot be replaced with a `tir::PrimExpr` of type
    `DataType::Handle()`.  However, these substitutions are useful
    when lowering.  If these occur, a binding of the `tir::Var` is
    included in the specialized function.
    """

    @T.prim_func(private=True)
    def before(A_data: T.handle("float32"), B_data: T.handle("float32")):
        A_buf = T.decl_buffer(32, "float32", data=A_data)
        B_buf = T.decl_buffer(16, "float32", data=B_data)
        for i in range(16):
            B_buf[i] = A_buf[i] * 2.0

    @T.prim_func(private=True)
    def expected(A_data: T.handle("float32")):
        A_buf = T.decl_buffer(32, "float32", data=A_data)
        # B_data is no longer a parameter; it is bound to the address of A_buf[16].
        B_data: T.Ptr[T.float32] = T.address_of(A_buf[16])
        B_buf = T.decl_buffer(16, "float32", data=B_data)
        for i in range(16):
            B_buf[i] = A_buf[i] * 2.0

    B_data = before.params[1]
    # Presumably `before.body` is the outermost DeclBuffer, i.e. the one for A_buf.
    A_buf = before.body.buffer
    # Substitute the handle parameter with an expression rather than a variable.
    param_map = {B_data: tvm.tir.address_of(A_buf[16])}
    after = before.specialize(param_map)
    tvm.ir.assert_structural_equal(expected, after)
def test_specialization_updates_struct_info():
    """Update struct info in specialization

    A PrimFunc may have a `relax.StructInfo`.  If that PrimFunc is
    specialized, the struct info should be updated.
    """

    @T.prim_func(private=True)
    def before(n: T.int32) -> T.int32:
        T.ret(n * 10)

    @T.prim_func(private=True)
    def expected() -> T.int32:
        # n = 5 substituted into n * 10.
        T.ret(50)

    # Check the struct info of both hand-written functions first: one int32
    # parameter for `before`, none for `expected`, int32 result for both.
    sinfo_before = tvm.relax.FuncStructInfo(
        [tvm.relax.PrimStructInfo("int32")], tvm.relax.PrimStructInfo("int32")
    )
    tvm.ir.assert_structural_equal(before.struct_info, sinfo_before)

    sinfo_expected = tvm.relax.FuncStructInfo([], tvm.relax.PrimStructInfo("int32"))
    tvm.ir.assert_structural_equal(expected.struct_info, sinfo_expected)

    n = before.params[0]
    param_map = {n: 5}
    after = before.specialize(param_map)

    # Both the function and its struct info must reflect the removed parameter.
    tvm.ir.assert_structural_equal(after, expected)
    tvm.ir.assert_structural_equal(after.struct_info, sinfo_expected)
if __name__ == "__main__":
    # Import the submodule explicitly.  `import tvm` by itself only exposes
    # `tvm.testing` when some other import happens to have loaded it, and this
    # file does not import it anywhere else.
    import tvm.testing

    # Run the tests of this file when it is executed as a script.
    tvm.testing.main()