blob: b4c7adc844567c6652c26555be59c85977e962ea [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import List
import tvm
import tvm.testing
from tvm import relax as rx
from tvm.script import relax as R, tir as T
from tvm.relax.analysis import detect_recursion
def assert_groups(groups: List[List[rx.GlobalVar]], expected: List[List[str]]) -> None:
    """Assert that ``groups`` contains exactly the ``expected`` recursive groups.

    Ordering is ignored, both between groups and within a group, and functions
    are matched only by their ``name_hint`` for convenience.

    Parameters
    ----------
    groups : List[List[rx.GlobalVar]]
        The groups returned by ``detect_recursion``.
    expected : List[List[str]]
        The expected groups, each given as a list of function names. A bare
        string is accepted as a single-function group (``"c"`` means ``["c"]``)
        instead of being split into its characters.
    """
    from collections import Counter

    # set("ab") would give {"a", "b"}, so a bare string must be wrapped, not iterated.
    expected_sets = [
        frozenset([group]) if isinstance(group, str) else frozenset(group) for group in expected
    ]
    actual_sets = [frozenset(gv.name_hint for gv in group) for group in groups]
    # Compare as multisets: this checks the group count and membership in one step.
    assert Counter(actual_sets) == Counter(expected_sets), (
        f"expected groups {sorted(sorted(s) for s in expected_sets)}, "
        f"got {sorted(sorted(s) for s in actual_sets)}"
    )
def test_no_recursion():
    """Functions that call no other function produce no recursive groups."""

    @tvm.script.ir_module
    class NoRecursion:
        @R.function
        def a(x: R.Object) -> R.Object:
            return x

        @R.function
        def b(x: R.Object) -> R.Object:
            return x

    groups = detect_recursion(NoRecursion)
    assert len(groups) == 0
def test_simple_recursion():
    """A single function that calls itself forms a one-element group."""

    @tvm.script.ir_module
    class SimpleRecursion:
        @R.function
        def c(x: R.Object) -> R.Object:
            return SimpleRecursion.c(x)

    groups = detect_recursion(SimpleRecursion)
    # ``expected`` is a list of groups, so a lone group still needs its own list.
    assert_groups(groups, [["c"]])
def test_tree():
    """A call graph that fans out but never loops back has no recursion.

    Call edges: a -> b -> c, c -> d, c -> e, d -> e. Despite the name this is a
    DAG rather than a strict tree (``e`` has two callers, ``c`` and ``d``), and
    reaching a function along two paths is not a cycle.
    """
    # no cycle!
    @tvm.script.ir_module
    class Tree:
        @R.function
        def a(x: R.Object) -> R.Object:
            return Tree.b(x)

        @R.function
        def b(x: R.Object) -> R.Object:
            return Tree.c(x)

        @R.function
        def c(x: R.Object) -> R.Object:
            z: R.Object = Tree.d(x)
            return Tree.e(z)

        @R.function
        def d(x: R.Object) -> R.Object:
            return Tree.e(x)

        @R.function
        def e(x: R.Object) -> R.Object:
            return x

    groups = detect_recursion(Tree)
    assert len(groups) == 0
def test_two_function_case():
    """Two functions calling each other form one group; an unrelated function is excluded."""

    @tvm.script.ir_module
    class TwoFunctionCase:
        @R.function
        def a(x: R.Object) -> R.Object:
            return TwoFunctionCase.b(x)

        @R.function
        def b(x: R.Object) -> R.Object:
            return TwoFunctionCase.a(x)

        # not part of the group, shouldn't be reported
        @R.function
        def c(x: R.Object) -> R.Object:
            return x

    groups = detect_recursion(TwoFunctionCase)
    assert_groups(groups, [["a", "b"]])
def test_two_groups_of_two():
    """Two disjoint mutually-recursive pairs (a <-> b, c <-> d) are reported as separate groups."""

    @tvm.script.ir_module
    class TwoGroupsOfTwo:
        @R.function
        def a(x: R.Object) -> R.Object:
            return TwoGroupsOfTwo.b(x)

        @R.function
        def b(x: R.Object) -> R.Object:
            return TwoGroupsOfTwo.a(x)

        @R.function
        def c(x: R.Object) -> R.Object:
            return TwoGroupsOfTwo.d(x)

        @R.function
        def d(x: R.Object) -> R.Object:
            return TwoGroupsOfTwo.c(x)

        # not part of either group, shouldn't be reported
        @R.function
        def e(x: R.Object) -> R.Object:
            return x

    groups = detect_recursion(TwoGroupsOfTwo)
    assert_groups(groups, [["a", "b"], ["c", "d"]])
def test_mutual_recursion_and_simple_recursion():
    """A mutually-recursive pair and an independent self-recursive function yield two groups."""

    @tvm.script.ir_module
    class MutualAndSimple:
        @R.function
        def a(x: R.Object) -> R.Object:
            return MutualAndSimple.b(x)

        @R.function
        def b(x: R.Object) -> R.Object:
            return MutualAndSimple.a(x)

        # forms its own group
        @R.function
        def c(x: R.Object) -> R.Object:
            return MutualAndSimple.c(x)

    groups = detect_recursion(MutualAndSimple)
    assert_groups(groups, [["a", "b"], ["c"]])
def test_simultaneous_mutual_and_simple_recursion():
    """Self-calls inside a mutually-recursive pair do not split it into extra groups."""
    # even though both call themselves and each other,
    # it should still form only one group
    @tvm.script.ir_module
    class SimultaneousMutualAndSimple:
        @R.function
        def a(x: R.Object) -> R.Object:
            cls = SimultaneousMutualAndSimple
            return cls.b(cls.a(x))

        @R.function
        def b(x: R.Object) -> R.Object:
            cls = SimultaneousMutualAndSimple
            return cls.a(cls.b(x))

    groups = detect_recursion(SimultaneousMutualAndSimple)
    assert_groups(groups, [["a", "b"]])
def test_three_function_case():
    """A three-function cycle (a -> b -> c -> a) is reported as one group."""

    @tvm.script.ir_module
    class ThreeFunctionCase:
        @R.function
        def a(x: R.Object) -> R.Object:
            return ThreeFunctionCase.b(x)

        @R.function
        def b(x: R.Object) -> R.Object:
            return ThreeFunctionCase.c(x)

        @R.function
        def c(x: R.Object) -> R.Object:
            return ThreeFunctionCase.a(x)

    groups = detect_recursion(ThreeFunctionCase)
    assert_groups(groups, [["a", "b", "c"]])
def test_call_from_outside_of_group():
    """Callers that enter a cycle but are never called back are not group members.

    The cycle is b -> c -> d -> b; ``a`` (calls ``d``) and ``e`` (calls ``b``)
    only enter it.
    """

    @tvm.script.ir_module
    class CallFromOutOfGroup:
        # A calls into a group of mutually recursive functions,
        # but is not part of the cycle
        @R.function
        def a(x: R.Object) -> R.Object:
            return CallFromOutOfGroup.d(x)

        @R.function
        def b(x: R.Object) -> R.Object:
            return CallFromOutOfGroup.c(x)

        @R.function
        def c(x: R.Object) -> R.Object:
            return CallFromOutOfGroup.d(x)

        @R.function
        def d(x: R.Object) -> R.Object:
            return CallFromOutOfGroup.b(x)

        # E also calls into the cycle but isn't part of it
        @R.function
        def e(x: R.Object) -> R.Object:
            return CallFromOutOfGroup.b(x)

    groups = detect_recursion(CallFromOutOfGroup)
    assert_groups(groups, [["b", "c", "d"]])
def test_call_from_group_to_outside():
    """A function called from inside a cycle, but never calling back, is not a member."""

    @tvm.script.ir_module
    class CallFromGroupToOutside:
        # a, b and c form a cycle (a -> b -> c -> a); d is called by b
        # but never calls back into the cycle
        @R.function
        def a(x: R.Object) -> R.Object:
            return CallFromGroupToOutside.b(x)

        @R.function
        def b(x: R.Object) -> R.Object:
            # d is called from a member of the group but it is not part of the cycle
            z: R.Object = CallFromGroupToOutside.d(x)
            return CallFromGroupToOutside.c(z)

        @R.function
        def c(x: R.Object) -> R.Object:
            return CallFromGroupToOutside.a(x)

        @R.function
        def d(x: R.Object) -> R.Object:
            return x

    groups = detect_recursion(CallFromGroupToOutside)
    assert_groups(groups, [["a", "b", "c"]])
def test_group_with_two_cycles():
    """
    a -> b <- f
    ^    |    ^
    |    v    |
    d <- c -> e

    There are two smaller cycles in this group (a -> b -> c -> d -> a and
    b -> c -> e -> f -> b), but both can be traversed as one big closed walk:
    b -> c -> d -> a -> b -> c -> e -> f -> b
    so all six functions are expected to form a single group.
    """

    @tvm.script.ir_module
    class GroupWithTwoCycles:
        @R.function
        def a(x: R.Object) -> R.Object:
            return GroupWithTwoCycles.b(x)

        @R.function
        def b(x: R.Object) -> R.Object:
            return GroupWithTwoCycles.c(x)

        @R.function
        def c(x: R.Object) -> R.Object:
            y = GroupWithTwoCycles.d(x)
            return GroupWithTwoCycles.e(y)

        @R.function
        def d(x: R.Object) -> R.Object:
            return GroupWithTwoCycles.a(x)

        @R.function
        def e(x: R.Object) -> R.Object:
            return GroupWithTwoCycles.f(x)

        @R.function
        def f(x: R.Object) -> R.Object:
            return GroupWithTwoCycles.b(x)

    groups = detect_recursion(GroupWithTwoCycles)
    assert_groups(groups, [["a", "b", "c", "d", "e", "f"]])
def test_multicycle_example():
    """
    Example from the documentation

    A <-> B <-> C
    ^     |     ^
    |     v     |
    |     D     |
    |     |     |
    v     v     v
    E <-> F <-> G

    Uppercase letters correspond to the lowercase functions below. Every
    function lies on some cycle through the others, so all seven are
    expected to form a single group.
    """

    @tvm.script.ir_module
    class MulticycleExample:
        @R.function
        def a(x: R.Object) -> R.Object:
            cls = MulticycleExample
            y = cls.b(x)
            return cls.e(y)

        @R.function
        def b(x: R.Object) -> R.Object:
            cls = MulticycleExample
            y = cls.a(x)
            z = cls.c(y)
            return cls.d(z)

        @R.function
        def c(x: R.Object) -> R.Object:
            cls = MulticycleExample
            y = cls.g(x)
            return cls.b(y)

        @R.function
        def d(x: R.Object) -> R.Object:
            cls = MulticycleExample
            return cls.f(x)

        @R.function
        def e(x: R.Object) -> R.Object:
            cls = MulticycleExample
            y = cls.f(x)
            return cls.a(y)

        @R.function
        def f(x: R.Object) -> R.Object:
            cls = MulticycleExample
            y = cls.g(x)
            return cls.e(y)

        @R.function
        def g(x: R.Object) -> R.Object:
            cls = MulticycleExample
            y = cls.f(x)
            return cls.c(y)

    groups = detect_recursion(MulticycleExample)
    assert_groups(groups, [["a", "b", "c", "d", "e", "f", "g"]])
def test_control_flow():
    """Calls in both branches of an if/else count as call edges.

    The condition is the constant ``True``, yet ``c`` (only reachable from the
    ``else`` branch) is still expected in the group.
    """

    @tvm.script.ir_module
    class ControlFlowExample:
        @R.function
        def a(x: R.Object) -> R.Object:
            cls = ControlFlowExample
            y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool")
            if y:
                ret = cls.b(x)
            else:
                ret = cls.c(x)
            return ret

        @R.function
        def b(x: R.Object) -> R.Object:
            cls = ControlFlowExample
            return cls.a(x)

        @R.function
        def c(x: R.Object) -> R.Object:
            cls = ControlFlowExample
            return cls.a(x)

    groups = detect_recursion(ControlFlowExample)
    assert_groups(groups, [["a", "b", "c"]])
def test_returning_self():
    """Referencing a function's own GlobalVar, even without calling it, counts as recursion."""

    @tvm.script.ir_module
    class ReturnsSelf:
        @R.function
        def a() -> R.Object:
            # this is also a form of recursion
            return ReturnsSelf.a

    groups = detect_recursion(ReturnsSelf)
    assert_groups(groups, [["a"]])
def test_mutual_recursion_via_references():
    """Function references (not only calls) are treated as edges.

    ``b`` never calls anything; it only returns a tuple referencing ``a``,
    ``b`` and ``c``. Those references are what close the cycle, so ``c`` is
    expected in the group as well.
    """

    @tvm.script.ir_module
    class GatherReferences:
        @R.function
        def a(x: R.Object) -> R.Object:
            cls = GatherReferences
            return cls.b(x)

        @R.function
        def b(x: R.Object) -> R.Object:
            cls = GatherReferences
            return (cls.a, cls.b, cls.c)

        @R.function
        def c(x: R.Object) -> R.Object:
            cls = GatherReferences
            return cls.a(x)

    groups = detect_recursion(GatherReferences)
    assert_groups(groups, [["a", "b", "c"]])
def test_disregard_primfuncs():
    """A PrimFunc invoked via ``R.call_tir`` from recursive functions is not reported.

    ``a`` and ``b`` call each other and both call ``identity_identity``; only
    the two Relax functions should appear in the group.
    """

    @tvm.script.ir_module
    class CallPrimFunc:
        # copied from test_analysis.py
        @T.prim_func
        def identity_identity(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")):
            # NOTE(review): C is allocated 128x128 but only its 4x4 corner is
            # touched below; kept as-is since it was copied verbatim.
            C = T.alloc_buffer((128, 128), "float32")
            for i0, i1 in T.grid(4, 4):
                with T.block("identity"):
                    vi0, vi1 = T.axis.remap("SS", [i0, i1])
                    C[vi0, vi1] = A[vi0, vi1]
            for i0, i1 in T.grid(4, 4):
                with T.block("identity"):
                    vi0, vi1 = T.axis.remap("SS", [i0, i1])
                    B[vi0, vi1] = C[vi0, vi1]

        @R.function
        def a(x: R.Tensor((4, 4), "float32")) -> R.Object:
            cls = CallPrimFunc
            y = R.call_tir(cls.identity_identity, x, R.Tensor((4, 4), "float32"))
            return cls.b(y)

        @R.function
        def b(x: R.Tensor((4, 4), "float32")) -> R.Object:
            cls = CallPrimFunc
            y = R.call_tir(cls.identity_identity, x, R.Tensor((4, 4), "float32"))
            return cls.a(y)

    groups = detect_recursion(CallPrimFunc)
    # the prim func should not be listed here
    assert_groups(groups, [["a", "b"]])
# Allow running this file directly as a script through TVM's test entry point.
if __name__ == "__main__":
    tvm.testing.main()