blob: 2edf74ebfb3d70aa4dfb584c272edc4436a29dbc [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm.testing
from tvm.script import ir as I, tir as T
class BaseTestCase:
    """Shared checks for the `InlinePrivateFunctions` transform.

    Subclasses are expected to define two class attributes:

    - `Before`: the input handed to the transform.
    - `Expected`: the result the transform should produce.
    """

    def test_well_formed(self):
        """Run the transform on `Before` and verify the output is well-formed."""
        inline_pass = tvm.tir.transform.InlinePrivateFunctions()
        transformed = inline_pass(self.Before)
        tvm.tir.analysis.verify_well_formed(transformed)

    def test_produces_expected(self):
        """Run the transform on `Before` and compare the output to `Expected`."""
        inline_pass = tvm.tir.transform.InlinePrivateFunctions()
        transformed = inline_pass(self.Before)
        tvm.ir.assert_structural_equal(self.Expected, transformed)
class TestSimple(BaseTestCase):
    """Simple case directly acting on PrimFunc

    `Before` holds a public `main` that calls a private `subroutine`
    once per loop iteration.  `Expected` holds only `main`, with the
    body of `subroutine` written out at the call site.
    """

    # Input module: `main` passes the address of row `i` of each buffer
    # to the private subroutine.
    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")):
            for i in range(64):
                Before.subroutine(T.address_of(A[i, 0]), T.address_of(B[i, 0]))

        # Marked private, which makes it a candidate for the pass under test.
        @T.prim_func(private=True)
        def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")):
            # Re-declare buffers over the raw pointers received as arguments.
            A = T.decl_buffer([16, 16], "float32", data=A_data)
            B = T.decl_buffer([16], "float32", data=B_data)
            # Row-wise sum: B[i] accumulates A[i, 0..15].
            for i in range(16):
                B[i] = 0.0
                for j in range(16):
                    B[i] = B[i] + A[i, j]

    # Expected module: no `subroutine` remains.  Each pointer argument
    # becomes a local handle binding followed by the callee's buffer
    # declaration and loop nest.
    @I.ir_module
    class Expected:
        @T.prim_func
        def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")):
            for i in range(64):
                A_view_data: T.handle("float32") = T.address_of(A[i, 0])
                Aview = T.decl_buffer([16, 16], "float32", data=A_view_data)
                B_view_data: T.handle("float32") = T.address_of(B[i, 0])
                Bview = T.decl_buffer([16], "float32", data=B_view_data)
                for j in range(16):
                    Bview[j] = 0.0
                    for k in range(16):
                        Bview[j] = Bview[j] + Aview[j, k]
class TestRetainCrossFunctionSubroutines(BaseTestCase):
    """Do not inline functions that cross device boundaries

    When lowering TIR, calls for which the callsite and callee have
    different targets are used at some stages, before being further
    lowered to explicit device kernel launches.  Since inlining the
    function would remove this cross-device information,
    InlinePrivateFunctions should not inline these cases.
    """

    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")):
            # Caller is annotated with the "llvm" target.
            T.func_attr({"target": T.target("llvm")})
            for i in range(64):
                Before.subroutine(T.address_of(A[i, 0]), T.address_of(B[i, 0]))

        @T.prim_func(private=True)
        def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")):
            # Callee is annotated with a different target ("cuda") than
            # its caller.
            T.func_attr({"target": T.target("cuda")})
            A = T.decl_buffer([16, 16], "float32", data=A_data)
            B = T.decl_buffer([16], "float32", data=B_data)
            for i in range(16):
                B[i] = 0.0
                for j in range(16):
                    B[i] = B[i] + A[i, j]

    # The transform is expected to leave the module unchanged.
    Expected = Before
class TestRetainRecursiveSubroutines(BaseTestCase):
    """Do not inline recursive functions

    To avoid potentially infinite loops at compile-time, disable
    inlining of recursive functions.  If inlining of these functions
    would be useful, this restriction may be relaxed with improved
    analysis of the subroutine.
    """

    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer(16, "float32")):
            Before.subroutine(T.address_of(A[0]), 16)

        @T.prim_func(private=True)
        def subroutine(A_data: T.handle("float32"), A_size: T.int32):
            A = T.decl_buffer(A_size, "float32", data=A_data)
            A[1] = A[0] + A[1]

            # Self-call on the remainder of the buffer; this is what
            # makes the subroutine recursive.
            if A_size > 1:
                Before.subroutine(T.address_of(A[1]), A_size - 1)

    # The transform is expected to leave the module unchanged.
    Expected = Before
class TestDeduplicateBlockName(BaseTestCase):
    """Block names must be de-duplicated after inlining

    `Before.main` calls the same private subroutine twice, and the
    subroutine contains a block named "scalar_mul".  `Expected` shows
    the two inlined copies carrying distinct block names
    ("scalar_mul_1" and "scalar_mul_2").

    The comparison against `Expected` is marked as xfail; the inherited
    `test_well_formed` check is not overridden here.
    """

    @pytest.mark.xfail(reason="Inlining of schedulable TIR not yet supported")
    def test_produces_expected(self):
        # `super()` already binds the instance, so no explicit `self`
        # argument is passed.  Passing it again would raise a TypeError
        # before the comparison ever ran.
        super().test_produces_expected()

    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer([2, 16], "float32"), B: T.Buffer([2, 16], "float32")):
            # One call per row of the two-row buffers.
            Before.subroutine(T.address_of(A[0, 0]), T.address_of(B[0, 0]))
            Before.subroutine(T.address_of(A[1, 0]), T.address_of(B[1, 0]))

        @T.prim_func(private=True)
        def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")):
            A = T.decl_buffer(16, "float32", data=A_data)
            B = T.decl_buffer(16, "float32", data=B_data)
            for i in range(16):
                with T.block("scalar_mul"):
                    B[i] = A[i] * 2.0

    @I.ir_module
    class Expected:
        # Inlining rewrites only the body of `main`, so its parameters
        # have the same shapes as in `Before.main`.
        @T.prim_func
        def main(A: T.Buffer([2, 16], "float32"), B: T.Buffer([2, 16], "float32")):
            # First inlined call (row 0).
            with T.LetStmt(T.address_of(A[0, 0]), var=T.handle("float32")) as A_data_1:
                A_1 = T.decl_buffer(16, "float32", data=A_data_1)
                B_data_1: T.handle("float32") = T.address_of(B[0, 0])
                B_1 = T.decl_buffer(16, "float32", data=B_data_1)
                for i in range(16):
                    with T.block("scalar_mul_1"):
                        B_1[i] = A_1[i] * 2.0

            # Second inlined call (row 1), with its own block name.
            with T.LetStmt(T.address_of(A[1, 0]), var=T.handle("float32")) as A_data_2:
                A_2 = T.decl_buffer(16, "float32", data=A_data_2)
                B_data_2: T.handle("float32") = T.address_of(B[1, 0])
                B_2 = T.decl_buffer(16, "float32", data=B_data_2)
                for i in range(16):
                    with T.block("scalar_mul_2"):
                        B_2[i] = A_2[i] * 2.0
class TestInlineCallOccurringInExpression(BaseTestCase):
    """Inline a Call node that is used in an expression

    The current implementation only replaces `tir.Call` instances that
    occur in a `tir.Evaluate` context.  This is the primary use case,
    used in destination-passing style.

    This unit test is marked as xfail.  If/when the implementation
    supports inlining of function calls occurring as part of an
    expression, the annotation can be removed.
    """

    @pytest.mark.xfail(reason="Inlining of PrimFuncs outside of tir.Evaluate is not yet supported")
    def test_produces_expected(self):
        # `super()` already binds the instance, so no explicit `self`
        # argument is passed.  Passing it again would raise a TypeError
        # before the comparison ever ran.
        super().test_produces_expected()

    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer(16, "float32")):
            for i in range(16):
                # The subroutine's return value is used as the stored
                # value, rather than being called as a bare statement.
                A[i] = Before.subroutine(i)

        @T.prim_func(private=True)
        def subroutine(i: T.int32) -> T.float32:
            cos = T.cos(T.cast(i, "float32"))
            sin = T.sin(T.cast(i, "float32"))
            retval = cos * cos + sin * sin
            T.ret(retval)

    @I.ir_module
    class Expected:
        @T.prim_func
        def main(A: T.Buffer(16, "float32")):
            for i in range(16):
                # Callee body written out in place, with the returned
                # value stored directly.
                cos = T.cos(T.cast(i, "float32"))
                sin = T.sin(T.cast(i, "float32"))
                retval = cos * cos + sin * sin
                A[i] = retval
class TestInlineFunctionWithBufferArguments(BaseTestCase):
    """Inline a function that accepts buffer arguments

    The current implementation does not support this usage.  This unit
    test is provided to display a possible user interaction, and is
    marked with `@pytest.mark.xfail`.  If/when the implementation
    supports inlining of function calls with buffer arguments, the
    annotation can be removed.
    """

    @pytest.mark.xfail(reason="Inlining of PrimFuncs with buffer arguments")
    def test_produces_expected(self):
        # `super()` already binds the instance, so no explicit `self`
        # argument is passed.  Passing it again would raise a TypeError
        # before the comparison ever ran.
        super().test_produces_expected()

    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer(16, "float32")):
            # The buffer argument is passed as an array handle built
            # from A's data pointer, shape, and element offset.
            Before.subroutine(
                T.tvm_stack_make_array(
                    A.data,
                    T.tvm_stack_make_shape(*A.shape, dtype="handle"),
                    0,
                    len(A.shape),
                    0.0,
                    A.elem_offset,
                    dtype="handle",
                )
            )

        @T.prim_func(private=True)
        def subroutine(A: T.Buffer(16, "float32")):
            for i in range(16):
                A[i] = A[i] * 2.0

    @I.ir_module
    class Expected:
        @T.prim_func
        def main(A: T.Buffer(16, "float32")):
            # Callee loop operating directly on the caller's buffer.
            for i in range(16):
                A[i] = A[i] * 2.0
# When executed directly as a script, hand control to `tvm.testing.main()`.
if __name__ == "__main__":
    tvm.testing.main()