# blob: 5b8e7e9ba20c51c839cf1ea37d87f4dc6f86dcf5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for advanced emit_te features, checked with the help of TVMScript assertions."""
# The tests here depend on tvmscript
import tvm
from tvm import te, tir
from tvm import relax as rx
from tvm.ir.base import assert_structural_equal
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T
def test_emit_te_with_symbolic_arg():
    """Check that a TIR variable passed to emit_te reaches the generated PrimFunc."""

    # The Python name of this TE function becomes the name of the emitted
    # PrimFunc, so it has to match `Expected.te_func` below.
    def te_func(A, offset):
        return te.compute(A.shape, lambda i: A[i + offset], name="B")

    # `m` is only bound through the shape argument `y`; it is handed to
    # emit_te directly as a TIR variable rather than as a Relax expression.
    m = tir.Var("m", "int64")
    x = rx.Var("x", R.Tensor([10], "float32"))
    y = rx.Var("y", R.Shape([m]))

    builder = rx.BlockBuilder()
    with builder.function("main", [x, y]):
        result = builder.emit_te(te_func, x, m)
        builder.emit_func_output(result)
    actual = builder.get()

    # Reference module written in TVMScript: the symbolic offset shows up as
    # an extra scalar parameter of the PrimFunc and as `tir_vars` at the call.
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def te_func(
            A: T.Buffer((T.int64(10),), "float32"),
            B: T.Buffer((T.int64(10),), "float32"),
            m: T.int64,
        ):
            T.func_attr({"tir.noalias": True})
            for i in range(T.int64(10)):
                with T.block("B"):
                    v_i = T.axis.spatial(T.int64(10), i)
                    T.writes(B[v_i])
                    B[v_i] = A[v_i + m]

        @R.function
        def main(
            x: R.Tensor((10,), dtype="float32"), y: R.Shape(["m"])
        ) -> R.Tensor((10,), dtype="float32"):
            m = T.int64()
            cls = Expected
            gv = R.call_tir(
                cls.te_func,
                (x,),
                out_sinfo=R.Tensor((10,), dtype="float32"),
                tir_vars=R.shape([m]),
            )
            return gv

    assert_structural_equal(actual, Expected)
def test_symbolic_shape_in_prim_value():
    """Check that emit_te accepts a symbolic variable supplied through an R.Prim argument."""

    # The Python name of this TE function becomes the name of the emitted
    # PrimFunc, so it has to match `Expected.te_slice` below.
    def te_slice(tensor, i):
        return te.compute([tensor.shape[1]], lambda j: tensor[i, j], name="slice")

    # Build the module under test with the BlockBuilder API. The row index is
    # a Relax variable whose R.Prim struct info carries the symbolic TIR var.
    data = rx.Var("A", R.Tensor([16, 16], "float32"))
    symbolic_index = tir.Var("tir_i", "int64")
    prim_index = rx.Var("relax_i", R.Prim(value=symbolic_index))

    builder = rx.BlockBuilder()
    with builder.function("main", params=[data, prim_index]):
        sliced = builder.emit_te(te_slice, data, prim_index)
        builder.emit_func_output(sliced)
    actual = builder.get()

    # Reference module written in TVMScript: the symbolic index shows up as
    # an extra scalar parameter of the PrimFunc and as `tir_vars` at the call.
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def te_slice(
            A: T.Buffer([T.int64(16), T.int64(16)], "float32"),
            Output: T.Buffer(T.int64(16), "float32"),
            row_index: T.int64,
        ):
            T.func_attr({"tir.noalias": True})
            for i in range(A.shape[1]):
                with T.block("slice"):
                    vi = T.axis.remap("S", [i])
                    Output[vi] = A[row_index, vi]

        @R.function
        def main(
            A: R.Tensor([16, 16], "float32"),
            arg_row_index: R.Prim(value="row_index"),
        ):
            cls = Expected
            row_index = T.int64()
            gv = R.call_tir(
                cls.te_slice,
                A,
                tir_vars=[row_index],
                out_sinfo=R.Tensor([16], "float32"),
            )
            return gv

    assert_structural_equal(actual, Expected)