blob: a36fd214794b2183581400c264bec3cfc26aee6f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
import tvm.testing
from tvm import te
from tvm.tir.buffer import decl_buffer
def test_deduce():
a = te.var("a")
b = te.var("b")
c = te.var("c")
d = te.var("d")
b_s = tvm.arith.IntervalSet(2, 3)
c_s = tvm.arith.IntervalSet(10, 15)
d_s = tvm.arith.IntervalSet(-3, -1)
zero = tvm.tir.const(0, "int32")
fdiv = tvm.te.floordiv
e0 = (-b) * a + c - d
res0 = tvm.arith.deduce_bound(a, e0 >= 0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = fdiv(d - c, b * -1)
tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
e0 = d * a + c - d
res0 = tvm.arith.deduce_bound(a, e0 >= 0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = fdiv(d - c, d)
tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
e1 = a * 4 + b < c
res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = fdiv(c - 1 - b, 4)
tvm.testing.assert_prim_expr_equal(res1.max_value, ans1)
# expression containing variable a is on rhs
e1 = c > a * 4 + b
res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
tvm.testing.assert_prim_expr_equal(res1.max_value, ans1)
e2 = tvm.te.max(5, a * 4) < 0
res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf"
# expression containing variable a is on rhs
e2 = zero < tvm.te.max(5, a * 4)
res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf"
e3 = (-b) + a * c - d
res3 = tvm.arith.deduce_bound(a, e3 >= 0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = fdiv(2, c) + 1
tvm.testing.assert_prim_expr_equal(res3.min_value, ans3)
res3 = tvm.arith.deduce_bound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
tvm.testing.assert_prim_expr_equal(res3.min_value, ans3)
# tests for `EQ` op
res4 = tvm.arith.deduce_bound(a, a == b, {}, {})
tvm.testing.assert_prim_expr_equal(res4.max_value, b)
tvm.testing.assert_prim_expr_equal(res4.min_value, b)
# Unsatisfiable `EQ`, variable as one of the Operand
res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s})
assert str(res5.max_value) == "neg_inf"
assert str(res5.min_value) == "pos_inf"
# variable `a` on the RHS side
res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {})
tvm.testing.assert_prim_expr_equal(res6.max_value, 10)
tvm.testing.assert_prim_expr_equal(res6.min_value, 10)
# Add, Sub in `EQ`
e4 = (a - c) == (b + d)
ans4 = b + d + c
res7 = tvm.arith.deduce_bound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
tvm.testing.assert_prim_expr_equal(res7.max_value, ans4)
tvm.testing.assert_prim_expr_equal(res7.min_value, ans4)
# Satisfiable Mul in `EQ` with negative sign
res8 = tvm.arith.deduce_bound(a, (5 * a == -10), {}, {})
tvm.testing.assert_prim_expr_equal(res8.max_value, -2)
tvm.testing.assert_prim_expr_equal(res8.min_value, -2)
# Unsatisfiable Mul in `EQ`
e5 = 4 * a == b
res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {})
assert str(res9.max_value) == "neg_inf"
assert str(res9.min_value) == "pos_inf"
res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {})
# simplifier is now able to prove symbolic relation (b * a % b == 0)
tvm.testing.assert_prim_expr_equal(res10.max_value, 1)
tvm.testing.assert_prim_expr_equal(res10.min_value, 1)
def test_check():
a = te.var("a")
b = te.var("b")
c = te.var("c")
d = te.var("d")
b_s = tvm.arith.IntervalSet(2, 3)
c_s = tvm.arith.IntervalSet(5, 7)
d_s = tvm.arith.IntervalSet(-3, -1)
# no compare operator
res1 = tvm.arith.deduce_bound(a, a + b, {b: b_s}, {})
assert res1.is_nothing()
# multiple compare operators
res2 = tvm.arith.deduce_bound(a, (a + b > 3).astype(c.dtype) > c, {b: b_s, c: c_s}, {})
assert res2.is_nothing()
# multiple target variable
res2 = tvm.arith.deduce_bound(a, a * 2 - a > b, {b: b_s}, {})
assert res2.is_nothing()
def test_deduce_basic():
def test_basic(a1, a2, coff):
a = te.var("a")
b = te.var("b")
b_s = tvm.arith.IntervalSet(a1, a2)
e0 = b + a * coff + 3
res1 = tvm.arith.deduce_bound(a, e0 < 17, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) < 17, True)
# expression containing variable a is on rhs
res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32") < e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) > 17, True)
# expression containing variable a is on rhs
res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32") >= e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) <= 17, True)
res1 = tvm.arith.deduce_bound(a, e0 >= 17, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) >= 17, True)
test_basic(0, 4, 4)
test_basic(1, 5, 4)
test_basic(2, 6, 4)
test_basic(0, 4, -4)
test_basic(1, 5, -4)
test_basic(2, 6, -4)
def test_deduce_complex():
def test_complex(a1, a2, coff):
a = te.var("a")
b = te.var("b")
b_s = tvm.arith.IntervalSet(a1, a2)
e0 = (b * 3 + a * coff) * 4
res1 = tvm.arith.deduce_bound(a, e0 < 63, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) < 63, True)
# expression containing variable a is on rhs
res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32") >= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) <= 63, True)
res1 = tvm.arith.deduce_bound(a, e0 > 63, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) > 63, True)
# expression containing variable a is on rhs
res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) >= 63, True)
test_complex(0, 4, 4)
test_complex(0, 4, -4)
test_complex(2, 6, 4)
test_complex(0, 4, -4)
test_complex(1, 5, -4)
test_complex(2, 6, -4)
def test_deduce_non_support():
a = te.var("a")
def test_non_support(lhs):
res = tvm.arith.deduce_bound(a, lhs < 10, {}, {})
assert res.is_nothing()
test_non_support(tvm.tir.floormod(a, 16))
test_non_support(tvm.tir.Min(a, 16))
test_non_support(tvm.tir.Max(a, 16))
test_non_support(tvm.tir.LE(a, 16))
test_non_support(tvm.tir.LT(a, 16))
test_non_support(tvm.tir.GE(a, 16))
test_non_support(tvm.tir.GT(a, 16))
test_non_support(tvm.tir.EQ(a, 16))
test_non_support(tvm.tir.NE(a, 16))
test_non_support(tvm.tir.log(a))
test_non_support(tvm.tir.BufferLoad(decl_buffer([16], "int32"), [a]))
def test_deduce_floordiv():
def do_test(gen_expr, dom_map, expect_min, expect_max):
a = te.var("a")
expr = gen_expr(a)
res = tvm.arith.deduce_bound(a, expr, dom_map, dom_map)
if isinstance(expect_min, str):
assert str(res.min_value) == expect_min
else:
tvm.testing.assert_prim_expr_equal(res.min_value, expect_min)
if isinstance(expect_max, str):
assert str(res.max_value) == expect_max
else:
tvm.testing.assert_prim_expr_equal(res.max_value, expect_max)
# test basic cases
do_test(lambda a: a // 8 > 3, {}, 32, "pos_inf")
do_test(lambda a: a // 8 >= 3, {}, 24, "pos_inf")
do_test(lambda a: a // 8 < 3, {}, "neg_inf", 23)
do_test(lambda a: a // 8 <= 3, {}, "neg_inf", 31)
do_test(lambda a: a // 8 == 3, {}, "pos_inf", "neg_inf")
do_test(lambda a: a // 8 > -3, {}, -16, "pos_inf")
do_test(lambda a: a // 8 >= -3, {}, -24, "pos_inf")
do_test(lambda a: a // -8 > 3, {}, "neg_inf", -32)
do_test(lambda a: a // -8 >= 3, {}, "neg_inf", -24)
do_test(lambda a: a // -8 < 3, {}, -23, "pos_inf")
do_test(lambda a: a // -8 <= 3, {}, -31, "pos_inf")
do_test(lambda a: 8 // a >= 2, {}, "pos_inf", "neg_inf")
# test nested cases
b = te.var("b")
bs = {b: tvm.arith.IntervalSet(2, 6)}
do_test(lambda a: b * 3 + a // 8 < 63, bs, "neg_inf", 359)
do_test(lambda a: b * 3 + a // 8 <= 63, bs, "neg_inf", 367)
do_test(lambda a: b * 3 + a // 8 > 63, bs, 464, "pos_inf")
do_test(lambda a: b * 3 + a // 8 >= 63, bs, 456, "pos_inf")
if __name__ == "__main__":
tvm.testing.main()