# blob: 7fe8d7c679fa1684bf4effe8309190b032ddcfef [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm.script import tir as T
from tvm.tir.buffer import Buffer
from tvm.tir.function import PrimFunc
from tvm.tir.stmt import Block
def _check_func_signature_remap(lhs: PrimFunc, rhs: PrimFunc):
    """Assert that the signature definitions of two functions are all distinct.

    The two functions, each pair of positional parameters, and the buffers
    those parameters are bound to in ``buffer_map`` must compare unequal
    with ``!=``.  Every parameter is looked up in ``buffer_map``.

    Parameters
    ----------
    lhs : PrimFunc
        The first function.
    rhs : PrimFunc
        The second function, compared pairwise against ``lhs``.
    """
    assert lhs != rhs
    # zip() stops at the shorter sequence, so without this check a difference
    # in parameter count would leave the trailing parameters unverified.
    assert len(lhs.params) == len(rhs.params)
    for x, y in zip(lhs.params, rhs.params):
        assert x != y
        assert lhs.buffer_map[x] != rhs.buffer_map[y]
def _check_buffer_decl(lhs: Buffer, rhs: Buffer):
    """Assert that two buffers, and their ``data`` fields, compare unequal."""
    buffers_differ = lhs != rhs
    assert buffers_differ
    data_differ = lhs.data != rhs.data
    assert data_differ
def _check_block_signature_remap(lhs: Block, rhs: Block):
    """Assert that the signature definitions of two blocks are all distinct.

    Checks, pairwise and in order: the blocks themselves, their iteration
    variables (both the ``IterVar`` and its ``var``), their allocated
    buffers, and their match-buffer regions together with the matched
    buffers.

    Parameters
    ----------
    lhs : Block
        The first block.
    rhs : Block
        The second block, compared pairwise against ``lhs``.
    """
    assert lhs != rhs

    # zip() silently truncates to the shorter input, so each sequence length
    # is compared first; otherwise extra trailing entries would go unchecked.
    assert len(lhs.iter_vars) == len(rhs.iter_vars)
    for x, y in zip(lhs.iter_vars, rhs.iter_vars):
        assert x != y
        assert x.var != y.var

    assert len(lhs.alloc_buffers) == len(rhs.alloc_buffers)
    for x, y in zip(lhs.alloc_buffers, rhs.alloc_buffers):
        _check_buffer_decl(x, y)

    assert len(lhs.match_buffers) == len(rhs.match_buffers)
    for x, y in zip(lhs.match_buffers, rhs.match_buffers):
        assert x != y
        _check_buffer_decl(x.buffer, y.buffer)
def test_simple():
    """Run renew_defs on an elementwise function with one allocated buffer and one block."""

    @T.prim_func
    # Buffer A should be remapped
    def elementwise(A: T.Buffer((128, 128), "float32")):
        # Buffer B should be remapped
        B = T.alloc_buffer((128, 128), "float32")
        # i, j should be remapped
        for i, j in T.grid(128, 128):
            with T.block("B"):
                # vi, vj should be remapped
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] * 2.0

    original = elementwise
    renewed = tvm.tir.stmt_functor.renew_defs(original)
    tvm.ir.assert_structural_equal(original, renewed)

    _check_func_signature_remap(original, renewed)

    # Root block
    root_before = original.body.block
    root_after = renewed.body.block
    _check_block_signature_remap(root_before, root_after)

    # Outer loop: variable i must be remapped
    loop_i_before = root_before.body
    loop_i_after = root_after.body
    assert loop_i_before.loop_var != loop_i_after.loop_var

    # Inner loop: variable j must be remapped
    loop_j_before = loop_i_before.body
    loop_j_after = loop_i_after.body
    assert loop_j_before.loop_var != loop_j_after.loop_var

    # Block nested inside both loops
    _check_block_signature_remap(loop_j_before.body.block, loop_j_after.body.block)
def test_match_buffer():
    """Run renew_defs on a block whose match_buffer defines stride and elem_offset vars."""

    # well-formed checker complains about multiple definitions for variable A0_s1,
    # likely stemming from strides=[s, s]
    @T.prim_func(check_well_formed=False)
    # A and B should be remapped
    def func_match_buffer(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")):
        with T.block("root"):
            s = T.int32()
            e = T.int32()
            # A0 should be remapped
            A0 = T.match_buffer(
                A[0:128, 0:128],
                shape=(128, 128),
                dtype="float32",
                # s and e should be remapped
                strides=[s, s],
                elem_offset=e,
            )
            for i, j in T.grid(128, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A0[vi, vj] * 2.0

    original = func_match_buffer
    renewed = tvm.tir.stmt_functor.renew_defs(original)
    tvm.ir.assert_structural_equal(original, renewed)

    _check_func_signature_remap(original, renewed)
    _check_block_signature_remap(original.body.block, renewed.body.block)
    assert original.body.block.body.loop_var != renewed.body.block.body.loop_var

    # Fetch the block under the function body again and run the signature
    # check on it once more before inspecting its match_buffers.
    block_before = original.body.block
    block_after = renewed.body.block
    _check_block_signature_remap(block_before, block_after)

    buffer_before = block_before.match_buffers[0].buffer
    buffer_after = block_after.match_buffers[0].buffer

    # Stride var s should be remapped
    for dim in (0, 1):
        assert buffer_before.strides[dim] != buffer_after.strides[dim]

    # s should be only remapped once: both strides of a buffer stay the same var
    for matched in (buffer_before, buffer_after):
        assert matched.strides[0] == matched.strides[1]

    # Element-offset var e should be remapped
    assert buffer_before.elem_offset != buffer_after.elem_offset
def test_undefined_buffer():
    """Run renew_defs on a function that accesses a buffer built over T.allocate data."""

    @T.prim_func
    def access_alloc():
        # Buffer A should be remapped
        A_data = T.allocate([128], "float16", "global")
        A = T.Buffer(shape=[128], dtype="float16", data=A_data)
        # check if buffer var also get remapped
        T.evaluate(A.data)
        for i in range(128):
            A[i] = A[i] + T.float16(1.0)

    original = access_alloc
    renewed = tvm.tir.stmt_functor.renew_defs(original)
    tvm.ir.assert_structural_equal(original, renewed)

    assert original.body.buffer_var != renewed.body.buffer_var

    # Buffer used by the statement under the second entry of the function
    # body's inner sequence (the loop that follows T.evaluate above).
    store_buffers = [func.body.body[1].body.buffer for func in (original, renewed)]
    _check_buffer_decl(*store_buffers)
def test_symbolic_func():
    """Run renew_defs on a function with symbolic buffer shapes and check structural equality."""

    @T.prim_func
    def symbolic_func(a: T.handle, b: T.handle, n: T.int32):
        m = T.int32()
        A = T.match_buffer(a, (n, m))
        B = T.match_buffer(b, (n, m * 2))
        for i, j in T.grid(n, m):
            B[i, j * 2] = A[i, j]
            B[i, j * 2 + 1] = A[i, j]

    original = symbolic_func
    renewed = tvm.tir.stmt_functor.renew_defs(original)
    tvm.ir.assert_structural_equal(original, renewed)
def test_buffer_map():
    """Check that a shape entry reached through buffer_map differs after renew_defs."""

    @T.prim_func
    def main(a: T.handle, b: T.handle):
        m = T.int64()
        A = T.match_buffer(a, (m * 2,))
        B = T.match_buffer(b, (m, 2))
        for i, j in T.grid(m, 2):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi * 2 + vj]

    original = main
    renewed = tvm.tir.stmt_functor.renew_defs(main)
    tvm.ir.assert_structural_equal(original, renewed)

    # First shape entry of the buffer bound to the second parameter
    shape_before = original.buffer_map[original.params[1]].shape
    shape_after = renewed.buffer_map[renewed.params[1]].shape
    assert shape_before[0] != shape_after[0]
def test_gather():
    """Run renew_defs on a function whose read region is indexed by another buffer's load."""

    @T.prim_func(private=True)
    def take(
        A: T.Buffer((4096, 4096), "float16"),
        B: T.Buffer((1,), "int32"),
        T_take: T.Buffer((1, 4096), "float16"),
    ):
        for ax0, ax1 in T.grid(1, 4096):
            with T.block("T_take"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[B[v_ax0], v_ax1], B[v_ax0])
                T.writes(T_take[v_ax0, v_ax1])
                T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]

    original = take
    renewed = tvm.tir.stmt_functor.renew_defs(take)
    tvm.ir.assert_structural_equal(original, renewed)
# When executed as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()