| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import random |
| import sys |
| import pytest |
| import tvm |
| from tvm import te, arith, ir, tir, testing |
| from tvm.script import tir as T |
| |
| |
| def test_solution_consistency(): |
| seed = random.randrange(sys.maxsize) |
| print( |
| "\nThis test is intentionally non-deterministic, " |
| "if it fails please report it in github issue together with this seed {}\n".format(seed) |
| ) |
| random.seed(seed) |
| |
| def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): |
| variables = [te.var("x" + str(i)) for i in range(num_vars)] |
| |
| relations = [] |
| for i in range(num_formulas): |
| s1 = sum([v * random.randint(coef[0], coef[1]) for v in variables]) |
| s1 += random.randint(coef[0], coef[1]) |
| s2 = sum([v * random.randint(coef[0], coef[1]) for v in variables]) |
| s2 += random.randint(coef[0], coef[1]) |
| if random.random() < 0.7: |
| op = tvm.tir.EQ |
| else: |
| # we also make sure it can correctly handle inequalities |
| op = random.choice([tvm.tir.LE, tvm.tir.LT, tvm.tir.GE, tvm.tir.GT]) |
| relations.append(op(s1, s2)) |
| |
| vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables} |
| solution = arith.solve_linear_equations(relations, variables, vranges) |
| |
| testing.check_int_constraints_trans_consistency(solution) |
| |
| # leaving some variables as parameters should also be ok |
| for k in [1, 2]: |
| if len(variables) > k: |
| solution = arith.solve_linear_equations(relations, variables[:-k], vranges) |
| param_ranges = {v: vranges[v] for v in variables[-k:]} |
| testing.check_int_constraints_trans_consistency(solution, param_ranges) |
| |
| for i in range(2): |
| _check(num_vars=1, num_formulas=1) |
| for i in range(2): |
| _check(num_vars=1, num_formulas=2) |
| |
| for i in range(2): |
| _check(num_vars=2, num_formulas=1) |
| for i in range(2): |
| _check(num_vars=2, num_formulas=2) |
| for i in range(2): |
| _check(num_vars=2, num_formulas=3) |
| |
| for i in range(3): |
| _check(num_vars=3, num_formulas=3, coef=(-2, 2)) |
| for i in range(3): |
| _check(num_vars=3, num_formulas=4, coef=(-2, 2)) |
| |
| for i in range(3): |
| _check(num_vars=4, num_formulas=3, coef=(-1, 1)) |
| |
| for i in range(3): |
| _check(num_vars=10, num_formulas=2, coef=(-1, 1), bounds=(0, 4)) |
| for i in range(3): |
| _check(num_vars=10, num_formulas=3, coef=(0, 1), bounds=(0, 4)) |
| |
| |
| def test_empty_var_to_solve(): |
| x, y = te.var("x"), te.var("y") |
| equations = [ |
| tvm.tir.EQ(x + y, 20), |
| tvm.tir.EQ(x - y, 10), |
| ] |
| solution = arith.solve_linear_equations(equations) |
| assert len(solution.src_to_dst) == 0 |
| assert len(solution.dst_to_src) == 0 |
| assert len(solution.src.variables) == 0 |
| assert len(solution.src.ranges) == 0 |
| assert ir.structural_equal(solution.src.relations, equations) |
| assert ir.structural_equal(solution.src, solution.dst) |
| |
| |
| def test_unique_solution(): |
| x, y = te.var("x"), te.var("y") |
| |
| solution = arith.solve_linear_equations( |
| [ |
| tvm.tir.EQ(x + y, 20), |
| tvm.tir.EQ(x - y, 10), |
| ], |
| [x, y], |
| ) |
| assert list(solution.dst.variables) == [] |
| assert ir.structural_equal(solution.src_to_dst[x], T.int32(15)) |
| assert ir.structural_equal(solution.src_to_dst[y], T.int32(5)) |
| |
| |
| def test_low_rank(): |
| x, y, z = te.var("x"), te.var("y"), te.var("z") |
| ranges = {} |
| |
| solution = arith.solve_linear_equations( |
| [ |
| tvm.tir.EQ(x + y + z, 15), |
| tvm.tir.EQ(x + y, 10), |
| ], |
| [x, y, z], |
| ranges, |
| ) |
| [n0] = solution.dst.variables |
| assert ir.structural_equal(solution.src_to_dst[x], n0 + 10) |
| assert ir.structural_equal(solution.src_to_dst[y], -n0) |
| assert ir.structural_equal(solution.src_to_dst[z], T.int32(5)) |
| |
| |
| def test_infer_range(): |
| x, y = te.var("x"), te.var("y") |
| ranges = { |
| x: tvm.ir.Range.from_min_extent(-5, 10), |
| y: tvm.ir.Range.from_min_extent(0, 10), |
| } |
| |
| solution = arith.solve_linear_equations( |
| [ |
| tvm.tir.EQ(x + y, 0), |
| ], |
| [x, y], |
| ranges, |
| ) |
| [n0] = solution.dst.variables |
| assert ir.structural_equal(solution.src_to_dst[x], n0) |
| assert ir.structural_equal(solution.src_to_dst[y], -n0) |
| # inferred from y's range |
| assert ir.structural_equal(solution.dst.ranges[n0].min, T.int32(-9)) |
| assert ir.structural_equal(solution.dst.ranges[n0].extent, T.int32(10)) |
| # additional inequality is added into the system for x |
| [ineq] = solution.dst.relations |
| assert isinstance(ineq, tvm.tir.LE) |
| assert ir.structural_equal(ineq.a, T.int32(-5)) |
| assert ir.structural_equal(ineq.b, n0) |
| |
| |
| def test_ill_formed(): |
| x, y = te.var("x"), te.var("y") |
| |
| solution = arith.solve_linear_equations( |
| [ |
| tvm.tir.EQ(x + y, 0), |
| tvm.tir.EQ(x - y, 0), |
| tvm.tir.EQ(x, 5), |
| ], |
| [x, y], |
| {}, |
| ) |
| assert list(solution.dst.variables) == [] |
| [rel] = solution.dst.relations |
| ir.assert_structural_equal(rel, tir.const(False)) |
| assert len(solution.src_to_dst) == 0 |
| assert len(solution.dst_to_src) == 0 |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |