| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import tvm |
| import tvm.testing |
| from tvm import te, tir |
| from tvm.script import tir as T |
| |
| |
class CanonicalChecker:
    """Test helper that compares ``canonical_simplify`` output with a reference.

    Every instance creates its own ``tvm.arith.Analyzer`` and exposes it as
    ``self.analyzer`` so a test can register bounds or bindings on it before
    calling :meth:`verify`.
    """

    def __init__(self):
        self.analyzer = tvm.arith.Analyzer()

    def _convert(self, expr):
        """Return ``T.int32(expr)`` for a Python ``int``; anything else is returned as-is."""
        # TODO(Lunderberg): Make utility functions `tir.convert` and
        # `relax.convert` that convert to their respective IR types.
        # Implementation should be in C++, and should only consist of
        # conversions that are applied automatically through FFI.
        return T.int32(expr) if isinstance(expr, int) else expr

    def verify(self, data, expected):
        """Assert that canonically simplifying ``data`` gives ``expected``.

        Parameters
        ----------
        data
            Expression handed to ``self.analyzer.canonical_simplify``.
        expected
            Reference result; a plain Python ``int`` is first passed
            through :meth:`_convert`.

        Raises
        ------
        AssertionError
            If the simplified result is not structurally equal to the
            reference. The message lists the input, result and reference.
        """
        simplified = self.analyzer.canonical_simplify(data)
        reference = self._convert(expected)
        assert tvm.ir.structural_equal(
            simplified, reference
        ), "\ndata={}\nres={}\nexpected={}".format(data, simplified, reference)
| |
| |
def test_mul_sum_simplify():
    """Check simplification of linear sums, alone and under div/mod by the constant 2."""
    checker = CanonicalChecker()
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    trunc_div, trunc_mod = tvm.tir.truncdiv, tvm.tir.truncmod
    floor_div, floor_mod = tvm.te.floordiv, tvm.te.floormod

    # Each entry pairs an input expression with the canonical form it must produce.
    cases = [
        # plain sums and products
        (2 + (3 * x + z + y + 1) * 4 + x, x * 13 + z * 4 + y * 4 + 6),
        (x * 3 - 4 * x + 1, 1 - x),
        (y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2),
        # truncdiv / truncmod
        (trunc_div(x + y + x + y * 3, 2), y * 2 + x),
        (trunc_mod(x + y + x + y * 3, 2), 0),
        # floordiv / floormod
        (floor_mod(x + x + y * 3, 2), floor_mod(y * 3, 2)),
        (floor_div(x + y + x + y * 3, 2), y * 2 + x),
        (floor_mod(x + y + x + y * 3, 2), 0),
        (floor_div(x + x + y * 3, 2), floor_div(y * 3, 2) + x),
    ]
    for expr, reference in cases:
        checker.verify(expr, reference)
| |
| |
def test_split_index_simplify():
    """Check that an index split into div/mod pieces is recombined where possible.

    The truncdiv cases run first without any bounds, then with bounds on
    ``x`` and ``y``; the floordiv cases run after those bounds have been
    replaced by a range that includes negative values.
    """
    checker = CanonicalChecker()
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    trunc_div, trunc_mod = tvm.tir.truncdiv, tvm.tir.truncmod
    floor_div, floor_mod = tvm.te.floordiv, tvm.te.floormod

    # truncdiv / truncmod with no bounds registered on the analyzer
    unbounded_cases = [
        # split div const
        (trunc_div(x, 3) * 3 + trunc_mod(x, 3), x),
        (trunc_div(x, 6) * 6 + trunc_mod(trunc_div(x, 3), 2) * 3 + trunc_mod(x, 3), x),
        (trunc_div(trunc_div(trunc_mod(x, 16), 2) * 2, 4), trunc_div(trunc_mod(x, 16), 4)),
        (trunc_div(trunc_mod(x, 2), 8), 0),
        (trunc_div(trunc_mod(x, 2), 7), 0),
        (trunc_div(trunc_div(trunc_mod(x, 16), 2) * 2, 6), trunc_div(trunc_mod(x, 16), 6)),
        # split mod const
        (trunc_mod((x * 8), 16), trunc_mod(x, 2) * 8),
        (trunc_mod(x * 8, 2), 0),
    ]
    for expr, reference in unbounded_cases:
        checker.verify(expr, reference)

    # simplify then fold
    checker.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
    checker.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000))
    checker.verify(trunc_div(x * 4 + y, 2) * 2 + trunc_mod(x * 4 + y, 2), x * 4 + y)
    # complex fold (z has no registered bound)
    checker.verify(trunc_div(z * 9 + y, 2) * 2 + trunc_mod(z * 9 + y, 2), z * 9 + y)

    # NOTE(review): the trailing True presumably lets these calls replace the
    # bounds registered above - confirm against Analyzer.update.
    checker.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True)
    checker.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True)
    checker.verify(trunc_div(x * 4 + y, 2) * 2 + trunc_mod(x * 4 + y, 2), x * 4 + y)

    widened_cases = [
        # floordiv / floormod
        (floor_div(x * 5, 2), floor_div(x * 5, 2)),
        (floor_div(x, 3) * 3 + floor_mod(x, 3), x),
        (floor_div(x, 6) * 6 + floor_mod(floor_div(x, 3), 2) * 3 + floor_mod(x, 3), x),
        (floor_div(floor_div(floor_mod(x, 16), 2) * 2, 4), floor_div(floor_mod(x, 16), 4)),
        (floor_div(floor_mod(x, 2), 8), 0),
        (floor_div(floor_mod(x, 2), 7), 0),
        (floor_div(floor_div(floor_mod(x, 16), 2) * 2, 6), floor_div(floor_mod(x, 16), 6)),
        # cannot simplify mixed case, unless we canonicalize into one mode.
        (
            trunc_div(x, 6) * 2 + trunc_mod(floor_div(x, 3), 2),
            trunc_div(x, 6) * 2 + trunc_mod(floor_div(x, 3), 2),
        ),
        # negated operand of truncmod
        (trunc_mod(-x, 2), trunc_mod(x, -2) * -1),
    ]
    for expr, reference in widened_cases:
        checker.verify(expr, reference)
| |
| |
def test_div_simplify():
    """Check division of an affine expression in ``x`` by 16, with and without bounds on ``x``."""
    checker = CanonicalChecker()
    x = te.var("x")
    trunc_div = tvm.tir.truncdiv
    floor_div = tvm.te.floordiv

    # truncdiv: a constant term that is a multiple of 16 divides out exactly.
    checker.verify(trunc_div(16 + 48 * x, 16), x * 3 + 1)
    # (17+48*x)/16 is not simplifiable for arbitrary x, because when
    # 17+48*x < 0 we have (17+48*x)/16 != 1+3*x.
    checker.verify(trunc_div(17 + 48 * x, 16), trunc_div(x * 48 + 17, 16))
    # Once x >= 0 is known, 17+48*x >= 0 and the division can be simplified.
    checker.analyzer.update(x, tvm.arith.ConstIntBound(0, 10))
    checker.verify(trunc_div(17 + 48 * x, 16), x * 3 + 1)
    # An expression that is not simplifiable for any values of the variables.
    checker.verify(trunc_div(17 + 47 * x, 16), trunc_div(x * 47 + 17, 16))

    # floordiv, checked after x has been given a range that includes negatives
    checker.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 10000), True)
    checker.verify(floor_div(16 + 48 * x, 16), x * 3 + 1)
    checker.verify(floor_div(17 + 48 * x, 16), x * 3 + 1)
    checker.verify(floor_div(17 + 47 * x, 16), floor_div(x * 47 + 17, 16))
| |
| |
def test_fp16_const_fold():
    """Check arithmetic on float16 constants where one operand is 0 or 1."""
    checker = CanonicalChecker()
    zero, one, half = (tvm.tir.const(value, "float16") for value in (0, 1, 0.5))

    cases = [
        # addition / subtraction of zero
        (zero + half, half),
        (half - zero, half),
        # multiplication by zero / one
        (zero * half, zero),
        (half * one, half),
        # division by one, and of zero
        (half / one, half),
        (zero / half, zero),
    ]
    for expr, reference in cases:
        checker.verify(expr, reference)
| |
| |
def test_floormod_simplify():
    """Check simplification of nested and negated floormod expressions."""
    checker = CanonicalChecker()
    floor_mod = tvm.te.floormod
    x, y = te.var("x"), te.var("y")

    cases = [
        # nested floormod with constant offsets
        (
            floor_mod(floor_mod((x * 4) + y - 466036, 24528) - 24512, 16),
            floor_mod((x * 4) + y + 12, 16),
        ),
        (floor_mod(floor_mod((x * 4), 16), 8), floor_mod(x, 2) * 4),
        # negated operand
        (floor_mod(-x, 2), floor_mod(x, -2) * -1),
    ]
    for expr, reference in cases:
        checker.verify(expr, reference)
| |
| |
def test_canonical_mixed():
    """Check cancellation of identical terms and a few mixed div/mod/compare cases."""
    checker = CanonicalChecker()
    x = te.var("x")
    three = tvm.tir.const(3, "int32")
    trunc_div, trunc_mod = tvm.tir.truncdiv, tvm.tir.truncmod
    floor_div = tvm.te.floordiv

    cases = [
        # identical truncdiv terms cancel
        (trunc_div(x, (three * three)) - trunc_div(x, (three * three)), 0),
        (trunc_div(x, (three + three)) - trunc_div(x, (three + three)), 0),
        # constant moved across a comparison
        (x - 2 < 3, x < 5),
        # identical max / min / product terms cancel
        (tvm.te.max(x, 1) - tvm.te.max(x, 1), 0),
        (tvm.te.min(x, 1) - tvm.te.min(x, 1), 0),
        (x * x - x * x, 0),
        (trunc_mod(trunc_div(trunc_mod(x, 20), 2) * 2, 4), trunc_div(trunc_mod(x, 4), 2) * 2),
        # identical floordiv terms cancel
        (floor_div(x, (three * three)) - floor_div(x, (three * three)), 0),
        (floor_div(x, (three + three)) - floor_div(x, (three + three)), 0),
    ]
    for expr, reference in cases:
        checker.verify(expr, reference)
| |
| |
def test_reduce_combiner_simplify():
    """Check how canonical simplification rewrites reductions with custom combiners.

    Three things are exercised:

    * a combiner whose body is a ``Select`` on ``dummy`` collapses to a plain
      sum or product once the bound registered for ``dummy`` decides the
      condition;
    * when only one component of a tuple reduction is requested, only that
      component and the ones it depends on remain in ``source``;
    * components whose source contains a call to a ``GlobalVar`` are kept.
    """
    ck = CanonicalChecker()
    dummy = te.var("dummy")
    comm_reducer = te.comm_reducer
    prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.tir.const(1, t0))

    sum_or_prod = comm_reducer(
        lambda x, y: tvm.tir.Select(dummy < 0, x + y, x * y),
        lambda t0: tvm.tir.Select(dummy < 0, tvm.tir.const(0, t0), tvm.tir.const(1, t0)),
    )
    sum_and_prod = comm_reducer(
        lambda x, y: (x[0] + y[0], x[1] * y[1]),
        lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(5, t1) - tvm.tir.const(4, t1)),
    )
    some_reducer1 = comm_reducer(
        lambda x, y: (
            x[0] + y[0],
            x[0] + y[0] + x[1] + y[1],
            x[0] * y[2] + y[0] * x[2],
            x[1] + y[2],
            4.0,
        ),
        lambda t0, t1, t2, t3, t4: (
            tvm.tir.const(0, t0),
            tvm.tir.const(1, t1),
            tvm.tir.const(2, t2),
            tvm.tir.const(3, t3),
            tvm.tir.const(4, t4),
        ),
    )

    k = te.reduce_axis((0, 10), name="k")
    A = te.placeholder((10,), name="A")
    # Test that SimplifyCombiner makes use of vranges
    ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, -4))
    ck.verify(sum_or_prod(A[k], k), te.sum(A[k], k))
    ck.verify(sum_or_prod(A[k], k, init=1), te.sum(A[k], k, init=1))
    ck.analyzer.update(dummy, tvm.arith.ConstIntBound(5, 9), True)
    ck.verify(sum_or_prod(A[k], k), prod(A[k], k))
    ck.verify(sum_or_prod(A[k], k, init=1), prod(A[k], k, init=1))
    ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, 100), True)
    ck.verify(sum_and_prod((A[k], A[10 - k]), k)[0], te.sum(A[k], k))
    ck.verify(sum_and_prod((A[k], A[10 - k]), k)[1], prod(A[10 - k], k))

    reference_simplified_sources = [
        [A[0]],
        [A[0], A[1]],
        [A[0], A[2]],
        [A[0], A[1], A[2], A[3]],
        [A[4]],
    ]
    for j, expected_sources in enumerate(reference_simplified_sources):
        # Here we use the j-th component of the result, so only it and the components it
        # depends on are left.
        simplified = ck.analyzer.canonical_simplify(
            some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j]
        )

        # zip() stops at the shorter sequence, so a result with missing or
        # extra sources would otherwise go unnoticed; compare lengths first.
        assert len(simplified.source) == len(
            expected_sources
        ), "component {}: expected {} sources, got {}".format(
            j, len(expected_sources), len(simplified.source)
        )

        # Check that the remaining components are the expected ones.
        for lhs, rhs in zip(simplified.source, expected_sources):
            tvm.ir.assert_structural_equal(lhs, rhs)

    # Test that components with side effects are not removed.
    # A separate name is used so the `dummy` variable captured by the
    # `sum_or_prod` lambdas above is not rebound to a different object.
    dummy_global = tvm.ir.GlobalVar("dummy")

    def side_effect(*xs):
        """Build an int32 call to ``dummy_global`` with ``xs`` as arguments."""
        return tvm.tir.Call("int32", dummy_global, xs)

    ck.verify(
        sum_and_prod((A[k], side_effect(A[10 - k])), k)[0],
        sum_and_prod((A[k], side_effect(A[10 - k])), k)[0],
    )
    ck.verify(sum_and_prod((side_effect(A[k]), A[10 - k]), k)[0], te.sum(side_effect(A[k]), k))
| |
| |
def test_reduce_simplify():
    """Check simplification of ``te.sum`` reductions with redundant parts."""
    checker = CanonicalChecker()
    k = te.reduce_axis((0, 10), name="k")
    j = te.reduce_axis((-5, 3), name="j")
    A = te.placeholder((10,), name="A")

    # The Select condition k + j < 12 is dropped from the summand.
    guarded_sum = te.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j])
    checker.verify(guarded_sum, te.sum(k + j, [k, j]))

    # A reduction over no axes is replaced by its source ...
    checker.verify(te.sum(A[3], []), A[3])
    # ... or by the float32 constant 1.0 in this where/init case.
    empty_with_init = te.sum(A[3], [], where=k > 12, init=1.0)
    checker.verify(empty_with_init, tvm.tir.const(1.0, dtype="float32"))

    # NOTE(review): the previous comment here read "The rule below is not
    # typical, removed for now", yet the check is still executed - confirm
    # whether it is meant to stay.
    checker.verify(te.sum(te.div(k, 10), k), te.sum(tvm.tir.const(0, "int32"), k))
| |
| |
def test_simplify_if_then_else():
    """Check that branch bodies are simplified using the enclosing condition."""
    checker = CanonicalChecker()
    x = te.var("x")
    y = te.var("y")
    trunc_div = tvm.tir.truncdiv
    trunc_mod = tvm.tir.truncmod

    def inner_branch():
        """Build the nested if_then_else shared by both inputs below."""
        return tvm.tir.if_then_else(
            24512 <= trunc_mod(((x * 4) + y) - 466036, 24528),
            trunc_mod(trunc_mod(((x * 4) + y) - 466036, 24528) - 24512, 16),
            x,
        )

    # simplification that takes condition into account: the same guard is
    # written once on the sum and once with y moved to the right-hand side.
    guard_on_sum = tvm.tir.if_then_else((x * 4 + y) >= 466036, inner_branch(), y)
    guard_rearranged = tvm.tir.if_then_else((x * 4) >= 466036 - y, inner_branch(), y)
    reference = tvm.tir.if_then_else(
        tvm.tir.LE(466036, (x * 4 + y)),
        tvm.tir.if_then_else(
            tvm.tir.LE(24512, trunc_mod(((x * 4) + y) - 4, 24528)),
            trunc_mod(((x * 4) + y) - 4, 16),
            x,
        ),
        y,
    )
    checker.verify(guard_on_sum, reference)
    checker.verify(guard_rearranged, reference)

    # can only simplify if condition
    both_bounded = tvm.tir.all(x >= -1, y >= 0)
    select_in = tvm.tir.Select(both_bounded, trunc_mod(x + y + 100, 3), trunc_mod(x + 100, 3))
    select_ref = tvm.tir.Select(both_bounded, trunc_mod(x + y + 1, 3), trunc_mod(x + 100, 3))
    # The reference is itself simplified before the comparison.
    checker.verify(select_in, checker.analyzer.canonical_simplify(select_ref))

    # The inner condition tdiv(x, 3) > 2 is resolved under x >= 10.
    select_in = tvm.tir.Select(x >= 10, tvm.tir.if_then_else(trunc_div(x, 3) > 2, x, 0), 0)
    select_ref = tvm.tir.Select(x >= 10, x, 0)
    checker.verify(select_in, checker.analyzer.canonical_simplify(select_ref))

    # With the opposite inner condition the whole expression becomes 0.
    select_in = tvm.tir.Select(x >= 10, tvm.tir.if_then_else(trunc_div(x, 3) < 2, x, 0), 0)
    checker.verify(select_in, 0)
| |
| |
def test_complex_cases():
    """Check two larger truncdiv/truncmod expressions with bounded ``x`` and ``y``."""
    checker = CanonicalChecker()
    x = te.var("x")
    y = te.var("y")
    trunc_div = tvm.tir.truncdiv
    trunc_mod = tvm.tir.truncmod

    # First expression: built before any bounds are registered.
    first = (
        trunc_div(trunc_div(trunc_mod(x * 128 + y, 1296), 36) * 2 + 1, 2) * 36
        + trunc_div(trunc_mod((x * 128) + y, 36) * 2 + 1, 2)
        - trunc_mod((x * 128) + y, 1296)
        + 1
    )
    checker.analyzer.update(x, tvm.arith.ConstIntBound(0, 5))
    checker.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))
    checker.verify(first, 1)

    # Second expression: checked after the bound on y is changed to [0, 1024].
    checker.analyzer.update(y, tvm.arith.ConstIntBound(0, 1024), True)
    second = (
        trunc_div(x * 1024 + y, 65536)
        + trunc_div(trunc_mod(x * 1024 + y, 65536), 256)
        + trunc_div(trunc_mod(x * 1024 + y, 256), 16)
        + trunc_mod(x * 1024 + y, 16)
        - trunc_div(y, 256)
        - trunc_div(trunc_mod(y, 256), 16)
        - trunc_mod(y, 16)
        - (x * 4)
    )
    second_reference = trunc_div((x * 1024) + y, 256) - trunc_div(y, 256) - (x * 4)
    checker.verify(second, second_reference)
| |
| |
def test_simplify_cast():
    """Check simplification of expressions wrapped in ``tir.Cast``.

    Covers a widening int32 -> int64 cast, narrowing int64 -> int32 casts
    with bounded operands, a narrowing cast that must be left unchanged,
    and a floormod expression whose casts cancel down to a constant.
    """
    ck = CanonicalChecker()
    tcast = tvm.tir.Cast
    flm = tvm.te.floormod
    # cast(i64, i + j + 1) - cast(i64, i)
    i = te.var("i", dtype="int32")
    j = te.var("j", dtype="int32")
    res = tcast("int64", i + j + 1) - tcast("int64", i)
    ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64"))
    # cast(i32, i + j + 1) - cast(i32, i)
    i = te.var("i", dtype="int64")
    j = te.var("j", dtype="int64")
    ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))
    ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
    res = tcast("int32", i + j + 1) - tcast("int32", i)
    ck.verify(res, tcast("int32", j) + 1)
    # cast(i32, i + j - 100): fresh variables, with i allowed up to 2**31 - 1;
    # the expression is expected to come back unchanged.
    i = te.var("i", dtype="int64")
    j = te.var("j", dtype="int64")
    ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2**31 - 1))
    ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
    res = tcast("int32", i + j - 100)
    ck.verify(res, res)
    # cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
    # - cast(i32, flm(axis, 7i64) * 2i64)
    axis = te.var("axis", dtype="int64")
    ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))
    one_i64 = tvm.tir.const(1, "int64")
    two_i64 = tvm.tir.const(2, "int64")
    seven_i64 = tvm.tir.const(7, "int64")
    res = (
        tcast("int32", flm(axis, seven_i64) * two_i64 + one_i64)
        + tvm.tir.const(1, "int32")
        - tcast("int32", flm(axis, seven_i64) * two_i64)
    )
    ck.verify(res, 2)
| |
| |
def test_simplify_normalize_min_value_expr():
    """Check equality comparisons that involve ``te.min_value("int32")``.

    ``min - x == 0`` must become ``x == min`` and ``min + x == 0`` must
    become the constant ``False``, for either operand order of ``==`` and
    of the inner sum.
    """
    checker = CanonicalChecker()
    x = te.var("x", "int32")
    int32_min = te.min_value("int32")

    cases = [
        (int32_min - x == 0, x == int32_min),
        (int32_min + x == 0, tir.const(False)),
        (0 == int32_min - x, x == int32_min),
        (0 == int32_min + x, tir.const(False)),
        (-x + int32_min == 0, x == int32_min),
        (x + int32_min == 0, tir.const(False)),
        (0 == -x + int32_min, x == int32_min),
        (0 == x + int32_min, tir.const(False)),
    ]
    for expr, reference in cases:
        checker.verify(expr, reference)
| |
| |
def test_proddiv_simplify():
    """Check div/mod simplification when numerator and denominator share product factors.

    The same family of cases is run for floormod, truncmod, floordiv and
    truncdiv: common variable and constant factors are cancelled, and any
    remaining common scale is pulled outside the mod.
    """
    ck = CanonicalChecker()
    flm = tvm.te.floormod
    fld = tvm.te.floordiv
    tdiv = tvm.te.truncdiv
    tmod = tvm.te.truncmod

    # `z` previously reused the name hint "y", which made y and z
    # indistinguishable in the data/res/expected failure message.
    x, y, z = te.var("x"), te.var("y"), te.var("z")

    # floormod
    ck.verify(flm(x * 32 * x, x), 0)
    ck.verify(flm(z * x * 32 * x * y, x * z), 0)
    ck.verify(flm(z * x * 32 * x * y, x * z * y * 8 * x), 0)
    ck.verify(flm(z * x * 32 * (x * y), 6 * x * z), flm(x * y * 16, 3) * (x * z * 2))
    ck.verify(flm(x * 32 * x, x * z), flm(x * 32, z) * x)

    # truncmod
    ck.verify(tmod(x * 32 * x, x), 0)
    ck.verify(tmod(z * x * 32 * x * y, x * z), 0)
    ck.verify(tmod(z * x * 32 * (x * y), 6 * x * z), tmod(x * y * 16, 3) * (x * z * 2))
    ck.verify(tmod(x * 32 * x, x * z), tmod(x * 32, z) * x)

    # floordiv
    ck.verify(fld(x * 2 * x * z, 4 * x * x * x), fld(z, x * 2))
    ck.verify(fld(x * (2 * y) * 3, 3 * y), x * 2)
    ck.verify(fld(x * (2 * y) * 3, 3 * y * z), fld(x * 2, z))

    # truncdiv
    ck.verify(tdiv(x * 2 * x * z, 4 * x * x * x), tdiv(z, x * 2))
    ck.verify(tdiv(x * (2 * y) * 3, 3 * y), x * 2)
    ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z))
| |
| |
def test_floormod_two():
    """Check that floormod by 2 of a sum with even-scaled variables folds to 1."""
    checker = CanonicalChecker()
    floor_mod = tvm.te.floormod
    x, y = te.var("x"), te.var("y")

    # The constants 1 and 2 sum to 3; both variable terms have even scale.
    odd_sum = x * 10 + 1 + y * 2 + 2
    checker.verify(floor_mod(odd_sum, 2), 1)
| |
| |
def test_simplify_le():
    """Check simplification of ``<`` comparisons whose operands have bound ranges.

    Ranges are attached with ``analyzer.bind``. Case 1 covers a scaled
    variable plus a small bounded term, case 2 covers a shared factor inside
    nested div/mod index arithmetic, and case 3 covers a comparison between
    two scaled variables.
    """
    ck = CanonicalChecker()
    # Case 1. Ignore the extra expr if it is smaller than the division number
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    ck.analyzer.bind(y, tvm.ir.Range(0, 8))
    ck.analyzer.bind(z, tvm.ir.Range(0, 2))
    ck.verify(x * 8 + y < 16, x < 2)
    # (a second, verbatim copy of the next check was removed as redundant)
    ck.verify(x * 8 + z * 4 < 16, x < 2)

    # TODO: Not sure why `-2 < x` is converted to `x > -2`; use an explicit simplify here.
    ck.verify(x * -8 + y < 16, ck.analyzer.rewrite_simplify(-2 < x))
    ck.verify(x * -8 + z * 4 < 16, ck.analyzer.rewrite_simplify(-2 < x))

    # Expected to come back unchanged: two extra terms, or a symbolic bound.
    ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)

    n = te.size_var("n")
    ck.verify(x * 8 + y < n, x * 8 + y < n)

    # Case 2. Simplify the extra expr
    x1, x2, ty, tx, vec = (
        tvm.te.var("x1"),
        tvm.te.var("x2"),
        tvm.te.var("ty"),
        tvm.te.var("tx"),
        tvm.te.var("vec"),
    )
    ck.analyzer.bind(x1, tvm.ir.Range(0, 2))
    ck.analyzer.bind(x2, tvm.ir.Range(0, 3))
    ck.analyzer.bind(ty, tvm.ir.Range(0, 8))
    ck.analyzer.bind(tx, tvm.ir.Range(0, 32))
    ck.analyzer.bind(vec, tvm.ir.Range(0, 8))
    ck.verify(
        x1 * 5632 + (((x2 * 8 + ty) * 32 + tx) * 8 + vec) % 5632 < 11008,
        x1 * 22 + (x2 * 8 + ty) % 22 < 43,
    )
    ck.verify(tx // 2 % 8 + vec < 8, tx % 16 // 2 + vec < 8)

    # Case 3. No failure (fresh variables; only y is bound)
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    ck.analyzer.bind(y, tvm.ir.Range(0, 1024))
    ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0)
| |
| |
# Script entry point: when the file is executed directly, hand control to
# the tvm.testing main function.
if __name__ == "__main__":
    tvm.testing.main()