| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import tvm |
| import tvm.testing |
| from tvm import relax as rx |
| from tvm.relax.analysis import contains_impure_call |
| from tvm.script import relax as R |
| |
| |
def test_simple_pure_case():
    """Expect no impure call in a function using only R.add, R.multiply and R.const."""

    @tvm.script.ir_module
    class PureTest:
        @R.function
        def pure_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            doubled = R.add(x, x)
            product = R.multiply(x, doubled)
            return R.add(product, R.const(1, "int32"))

    func = PureTest["pure_func"]
    has_impure_call = contains_impure_call(func)
    assert not has_impure_call
| |
| |
def test_simple_impure_case():
    """Expect an impure call to be found in a function whose body calls R.print."""

    @tvm.script.ir_module
    class ImpureTest:
        @R.function(pure=False)
        def impure_func() -> R.Object:
            printed = R.print(format="I am a message")
            return printed

    func = ImpureTest["impure_func"]
    has_impure_call = contains_impure_call(func)
    assert has_impure_call
| |
| |
def test_nested_function():
    """Check how an unused, impure nested function affects the analysis.

    The outer function is expected to report no impure call even though it
    defines an inner function that calls R.print, while the inner function
    expression itself is expected to report one.
    """

    @tvm.script.ir_module
    class NestedTest:
        @R.function
        def pure_with_impure_nested() -> R.Tensor((), "int32"):
            # unused
            @R.function(pure=False)
            def impure_inner() -> R.Object:
                printed = R.print(format="Another, worse, message")
                return printed

            zero = R.const(0, dtype="int32")
            return R.add(zero, zero)

    outer = NestedTest["pure_with_impure_nested"]
    assert not contains_impure_call(outer)

    # The nested definition is the first statement of the outer body, so it is
    # expected to be the value of the first binding in the first block.
    inner = outer.body.blocks[0].bindings[0].value
    assert contains_impure_call(inner)
| |
| |
def test_ignoring_recursive_call():
    """Check that passing ``own_name`` lets the analysis ignore a recursive call.

    Ignoring a recursive call can be useful if some transformation removes an
    impure operation and the compiler needs to check whether the impure
    function has become pure.
    """

    @tvm.script.ir_module
    class RecursiveTest:
        @R.function(pure=False)
        def recursive_impure() -> R.Object:
            x = R.const(1, "int32")
            y = R.add(x, x)
            z = R.print(x, y, format="{} {}")
            w = RecursiveTest.recursive_impure()
            return w

    func = RecursiveTest["recursive_impure"]
    assert contains_impure_call(func)

    # ...but if we remove the impure call (the print)...
    old_body = func.body
    old_bindings = old_body.blocks[0].bindings
    # The last binding is the recursive call; its op is what we pass as own_name.
    self_ref = old_bindings[-1].value.op
    # Keep the bindings for x, y and the recursive call, skipping the print.
    kept_bindings = [old_bindings[idx] for idx in (0, 1, -1)]

    # Note: the function is constructed this way so that the old vars keep
    # their current StructInfo. Normalization would fix that up, but this
    # situation is meant to correspond to an intermediate state that might
    # arise within a pass.
    rebuilt_body = rx.SeqExpr([rx.BindingBlock(kept_bindings)], old_body.body)

    # If the recursive call were not ignored, the fact that the var's
    # StructInfo marks it as impure would throw the analysis off.
    assert not contains_impure_call(rebuilt_body, own_name=self_ref)
    assert contains_impure_call(rebuilt_body)
| |
| |
# Allow running this test file directly as a script via TVM's test entry point.
if __name__ == "__main__":
    tvm.testing.main()