# blob: 8728df7e3f3abd2f0fd296ddb6528c56e0cb42df [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import pytest
import tvm
import tvm.testing
from tvm import te
from tvm.arith import ConstIntBound
# Short module-level aliases for the two sentinel attributes exposed on
# ConstIntBound.  The test tables below use them wherever one side of an
# expected or known range is written as unbounded.
NEG_INF = ConstIntBound.NEG_INF
POS_INF = ConstIntBound.POS_INF
class TestCase:
    """One scenario for the constant-integer-bound tests.

    Parameters
    ----------
    expr
        The expression whose bounds are queried.
    expected_bounds
        A ``(min, max)`` pair the analyzer result is compared against.
    known_bounds
        Optional mapping from variable to a ``(min, max)`` pair to register
        before the query.  Defaults to an empty mapping.
    constraint
        Optional constraint expression; ``None`` means no constraint.
    """

    # The class name matches pytest's ``Test*`` collection pattern, but this
    # is a plain data holder with an ``__init__``.  Opting out explicitly
    # stops pytest from trying to collect it (and warning that it cannot).
    __test__ = False

    def __init__(self, expr, expected_bounds, known_bounds=None, constraint=None):
        self.expr = expr
        self.expected_bounds = expected_bounds
        # A fresh dict per instance; a caller-supplied mapping is stored as-is.
        self.known_bounds = {} if known_bounds is None else known_bounds
        self.constraint = constraint

    @property
    def __name__(self):
        """String form of ``expr``.

        NOTE(review): presumably consumed by the parametrization machinery to
        label each generated test -- confirm against ``tvm.testing.parameter``.
        """
        return str(self.expr)
class BaseCompare:
    """Shared test body inherited by every parametrized class in this file.

    Subclasses only provide a ``test_case`` parameter; the check itself
    lives here.
    """

    def test_const_bounds(self, test_case):
        """Query the bounds of ``test_case.expr`` and compare with expectations.

        A ``None`` on either side of ``expected_bounds`` means that side is
        not checked; otherwise both ends must match exactly.
        """
        analyzer = tvm.arith.Analyzer()

        # Register every known variable range and confirm it was recorded.
        for var, var_range in test_case.known_bounds.items():
            analyzer.update(var, ConstIntBound(*var_range))
            assert analyzer.const_int_bound_is_bound(var)

        # Only enter a constraint scope when the case actually supplies one.
        constraint = test_case.constraint
        if constraint is None:
            scope = contextlib.nullcontext()
        else:
            scope = analyzer.constraint_scope(constraint)

        with scope:
            result = analyzer.const_int_bound(test_case.expr)

        expected = test_case.expected_bounds
        lower, upper = expected[0], expected[1]
        if lower is None:
            assert result.max_value == upper
        elif upper is None:
            assert result.min_value == lower
        else:
            assert (result.min_value, result.max_value) == expected
class TestDataType(BaseCompare):
    """Bounds of a bare variable with no registered range, for several dtypes."""

    test_case = tvm.testing.parameter(
        TestCase(te.var("x", dtype="int64"), expected_bounds=(NEG_INF, POS_INF)),
        TestCase(te.var("x", dtype="int8"), expected_bounds=(-128, 127)),
        TestCase(te.var("x", dtype="uint8"), expected_bounds=(0, 255)),
        TestCase(te.size_var("x", dtype="int32"), expected_bounds=(0, POS_INF)),
    )
class TestCastBound(BaseCompare):
    """Bounds of ``truncmod(x, 3)`` on an int8 variable after dtype casts."""

    x = te.var("x", dtype="int8")
    tmod = tvm.tir.truncmod

    test_case = tvm.testing.parameter(
        TestCase(
            tmod(x, 3).astype("uint32"),
            expected_bounds=(0, 2),
        ),
        TestCase(
            tmod(x, 3).astype("float32").astype("int32"),
            expected_bounds=(-2, 2),
        ),
    )
class TestAddSubBound(BaseCompare):
    """Bounds of sums and differences of int64 variables."""

    x = te.var("x", "int64")
    y = te.var("y", "int64")

    test_case = tvm.testing.parameter(
        TestCase(x + y, expected_bounds=(NEG_INF, POS_INF)),
        TestCase(
            x + y,
            expected_bounds=(1, 14),
            known_bounds={x: (0, 4), y: (1, 10)},
        ),
        TestCase(
            x - y,
            expected_bounds=(-10, 3),
            known_bounds={x: (0, 4), y: (1, 10)},
        ),
        TestCase(
            x - y,
            expected_bounds=(-10, POS_INF),
            known_bounds={x: (0, POS_INF), y: (1, 10)},
        ),
        TestCase(
            1 - x,
            expected_bounds=(NEG_INF, 1),
            known_bounds={x: (0, POS_INF), y: (1, 10)},
        ),
    )
@pytest.mark.xfail(reason="Not currently supported")
class TestBoundsUsingReciprocals(BaseCompare):
    """Special handling for differences of reciprocals

    These terms can appear when comparing the number of operations for
    different orderings of matrix multiplications, with A, B, and C
    known to be positive values.

    In these cases, comparing `(A+B)*C < A*B` is equivalent to
    `1/A + 1/B < 1/C`.  Working in terms of the reciprocals
    allows the ConstIntBound analyzer to provide a tighter
    bound for these differences than would otherwise be
    available.

    For `(A+B)*C - A*B`, the normal bottom-up integer bounds are unable to
    provide the bounds required to provide these inequalities, because they
    treat the terms as uncorrelated.  That is, they assume that `(A+B)*C` may
    achieve its minimum while `A*B` simultaneously achieves its maximum.
    """

    A = te.var("A", "int64")
    B = te.var("B", "int64")
    C = te.var("C", "int64")

    symmetric_bounds = {A: (1, 4095), B: (1, 4095), C: (2048, 2048)}
    asymmetric_bounds = {A: (1, 1024), B: (1, POS_INF), C: (2048, 2048)}

    test_case = tvm.testing.parameter(
        # Both variables share the same finite range.
        TestCase((A + B) * C - A * B, expected_bounds=(2048, None), known_bounds=symmetric_bounds),
        TestCase((A + B) * C - B * A, expected_bounds=(2048, None), known_bounds=symmetric_bounds),
        TestCase(A * B - (A + B) * C, expected_bounds=(None, -2048), known_bounds=symmetric_bounds),
        TestCase(B * A - (A + B) * C, expected_bounds=(None, -2048), known_bounds=symmetric_bounds),
        # B has no finite upper limit.
        TestCase((A + B) * C - A * B, expected_bounds=(2048, None), known_bounds=asymmetric_bounds),
        TestCase((A + B) * C - B * A, expected_bounds=(2048, None), known_bounds=asymmetric_bounds),
        TestCase(A * B - (A + B) * C, expected_bounds=(None, -2048), known_bounds=asymmetric_bounds),
        TestCase(B * A - (A + B) * C, expected_bounds=(None, -2048), known_bounds=asymmetric_bounds),
    )
class TestMulBound(BaseCompare):
    """Bounds of products of two ranged variables."""

    x = te.var("x")
    y = te.var("y")

    test_case = tvm.testing.parameter(
        TestCase(
            x * y + 20,
            expected_bounds=(0, 60),
            known_bounds={x: (-2, 4), y: (4, 10)},
        ),
        TestCase(
            x * y,
            expected_bounds=(-32, 24),
            known_bounds={x: (-3, 4), y: (-8, 2)},
        ),
        TestCase(
            x * y,
            expected_bounds=(NEG_INF, POS_INF),
            known_bounds={x: (NEG_INF, 4), y: (-8, 2)},
        ),
    )
class TestTruncDivBound(BaseCompare):
    """Bounds of ``truncdiv(x, y)`` for several divisor ranges."""

    x = te.var("x")
    y = te.var("y")
    expr = tvm.tir.truncdiv(x, y)

    test_case = tvm.testing.parameter(
        # Only the lower end is checked for this case.
        TestCase(
            expr,
            expected_bounds=(-2, None),
            known_bounds={x: (-9, 4), y: (4, 10)},
        ),
        TestCase(
            expr,
            expected_bounds=(-4, 9),
            known_bounds={x: (-9, 4), y: (-2, 0)},
        ),
        TestCase(
            expr,
            expected_bounds=(NEG_INF, POS_INF),
            known_bounds={x: (NEG_INF, 4), y: (-2, 1)},
        ),
        TestCase(
            expr,
            expected_bounds=(-9, 9),
            known_bounds={x: (-9, 4), y: (-4, 12)},
        ),
    )
class TestTruncModBound(BaseCompare):
    """Bounds of ``truncmod(x, y)`` with the divisor range fixed at (4, 10)."""

    x = te.var("x")
    y = te.var("y")
    expr = tvm.tir.truncmod(x, y)

    test_case = tvm.testing.parameter(
        TestCase(
            expr,
            expected_bounds=(-9, 4),
            known_bounds={x: (-9, 4), y: (4, 10)},
        ),
        TestCase(
            expr,
            expected_bounds=(-9, 9),
            known_bounds={x: (NEG_INF, POS_INF), y: (4, 10)},
        ),
        TestCase(
            expr,
            expected_bounds=(0, 9),
            known_bounds={x: (1, POS_INF), y: (4, 10)},
        ),
    )
class TestFloorDivBound(BaseCompare):
    """Bounds of ``x // y`` for default-dtype and uint32 variables."""

    x = te.var("x")
    y = te.var("y")
    ux = te.var("x", dtype="uint32")
    uy = te.var("y", dtype="uint32")

    test_case = tvm.testing.parameter(
        # Only the lower end is checked for this case.
        TestCase(
            x // y,
            expected_bounds=(-9 // 4, None),
            known_bounds={x: (-9, 4), y: (4, 10)},
        ),
        TestCase(
            x // y,
            expected_bounds=(-4, 9),
            known_bounds={x: (-9, 4), y: (-2, 0)},
        ),
        TestCase(
            x // y,
            expected_bounds=(NEG_INF, POS_INF),
            known_bounds={x: (NEG_INF, 4), y: (-2, 1)},
        ),
        TestCase(
            x // y,
            expected_bounds=(-9, 9),
            known_bounds={x: (-9, 4), y: (-4, 12)},
        ),
        TestCase(
            ux // uy,
            expected_bounds=(0, 4),
            known_bounds={ux: (1, 4), uy: (0, 12)},
        ),
    )
class TestFloorModBound(BaseCompare):
    """Bounds of ``x % y`` with the divisor range fixed at (4, 10)."""

    x = te.var("x")
    y = te.var("y")

    test_case = tvm.testing.parameter(
        TestCase(
            x % y,
            expected_bounds=(0, 9),
            known_bounds={x: (-9, 4), y: (4, 10)},
        ),
        TestCase(
            x % y,
            expected_bounds=(0, 9),
            known_bounds={x: (NEG_INF, POS_INF), y: (4, 10)},
        ),
        TestCase(
            x % y,
            expected_bounds=(0, 9),
            known_bounds={x: (1, POS_INF), y: (4, 10)},
        ),
    )
class TestMinMaxBound(BaseCompare):
    """Bounds of ``min(x, y)`` and ``max(x, y)``."""

    x = te.var("x")
    y = te.var("y")

    test_case = tvm.testing.parameter(
        TestCase(
            tvm.te.min(x, y),
            expected_bounds=(-9, 10),
            known_bounds={x: (-9, 11), y: (4, 10)},
        ),
        TestCase(
            tvm.te.min(x, y),
            expected_bounds=(NEG_INF, 10),
            known_bounds={x: (NEG_INF, POS_INF), y: (4, 10)},
        ),
        TestCase(
            tvm.te.max(x, y),
            expected_bounds=(4, POS_INF),
            known_bounds={x: (NEG_INF, POS_INF), y: (4, 10)},
        ),
        TestCase(
            tvm.te.max(x, y),
            expected_bounds=(4, POS_INF),
            known_bounds={x: (1, POS_INF), y: (4, 10)},
        ),
    )
class TestSelectBound(BaseCompare):
    """Bounds of a ``Select`` whose two branches have different ranges."""

    x = te.var("x")
    y = te.var("y")

    test_case = tvm.testing.parameter(
        TestCase(
            tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1),
            expected_bounds=(0, 11),
            known_bounds={x: (-9, 11), y: (4, 10)},
        ),
    )
class TestShiftAndBound(BaseCompare):
    """Bounds of right-shift (``>>``) and bitwise-and (``&``) expressions."""

    x = te.var("x")
    y = te.var("y")

    test_case = tvm.testing.parameter(
        TestCase(
            x >> y,
            expected_bounds=(-3, 2),
            known_bounds={x: (-9, 11), y: (2, 10)},
        ),
        TestCase(
            x & y,
            expected_bounds=(0, 10),
            known_bounds={x: (-9, 11), y: (2, 10)},
        ),
        TestCase(
            x & y,
            expected_bounds=(0, 10),
            known_bounds={x: (10, 11), y: (2, 10)},
        ),
    )
class TestMixIndexBound(BaseCompare):
    """Bounds of index-style expressions mixing truncdiv, truncmod and scaling."""

    x = te.var("x")
    y = te.var("y")
    tdiv = tvm.tir.truncdiv
    tmod = tvm.tir.truncmod

    test_case = tvm.testing.parameter(
        TestCase(
            tmod(x, 8) + tdiv(x, 8) * 8,
            expected_bounds=(0, 24 - 1),
            known_bounds={x: (0, 24 - 1), y: (0, 3 - 1)},
        ),
        TestCase(
            y + x * 3,
            expected_bounds=(0, 24 * 3 - 1),
            known_bounds={x: (0, 24 - 1), y: (0, 3 - 1)},
        ),
        TestCase(
            tmod(x, 7) + tdiv(x, 7) * 7,
            expected_bounds=(0, (23 // 7) * 7 + 6),
            known_bounds={x: (0, 24 - 1), y: (0, 3 - 1)},
        ),
    )
class TestLetBound(BaseCompare):
    """Bounds of a ``Let`` expression that binds ``x`` to 1 and yields ``x + 1``."""

    x = te.var("x")

    test_case = tvm.testing.parameter(
        TestCase(tvm.tir.Let(x, 1, x + 1), expected_bounds=(2, 2)),
    )
class TestFloorModNegativeDivisor(BaseCompare):
    """Bounds of ``a % b`` when the divisor range spans negative values."""

    # These two aliases are not referenced by the case below; they are kept
    # so the class namespace stays the same.
    flm = tvm.te.floormod
    fld = tvm.te.floordiv

    a = te.var("a")
    b = te.var("b")

    test_case = tvm.testing.parameter(
        TestCase(
            a % b,
            expected_bounds=(-4, 6),
            known_bounds={a: (0, 6), b: (-5, 7)},
        ),
    )
class TestDivModAssumeNoZeroDivisor(BaseCompare):
    """Div/mod of non-negative expressions assumes no division by zero.

    This assumption is important to get the best result from symbolic
    shape programs.
    """

    a = te.var("a")
    b = te.var("b")

    test_case = tvm.testing.parameter(
        TestCase(
            a // b,
            expected_bounds=(0, 6),
            known_bounds={a: (0, 6), b: (0, POS_INF)},
        ),
        TestCase(
            a % b,
            expected_bounds=(0, 6),
            known_bounds={a: (0, 6), b: (0, POS_INF)},
        ),
    )
class TestMultipleCondition(BaseCompare):
    """Lower bound of ``a % 58 - 1`` under a two-part constraint on ``a % 58``."""

    a = te.var("a")

    test_case = tvm.testing.parameter(
        # Only the lower end is checked for this case.
        TestCase(
            a % 58 - 1,
            expected_bounds=(0, None),
            known_bounds={a: (0, 128)},
            constraint=tvm.tir.all(1 <= a % 58, a % 58 < 57),
        ),
    )
class TestBroadcastBound(BaseCompare):
    """Bounds of a 4-lane ``Broadcast`` of a ranged scalar."""

    a = te.var("a")

    test_case = tvm.testing.parameter(
        TestCase(
            tvm.tir.Broadcast(a, 4),
            expected_bounds=(0, 128),
            known_bounds={a: (0, 128)},
        ),
    )
class TestRampBound(BaseCompare):
    """Bounds of ``Ramp(a, 2, 4) + 2`` for a ranged base value."""

    a = te.var("a")

    test_case = tvm.testing.parameter(
        TestCase(
            tvm.tir.Ramp(a, 2, 4) + 2,
            expected_bounds=(2, 128 + 2 * 3 + 2),
            known_bounds={a: (0, 128)},
        ),
    )
class TestModularSetBound(BaseCompare):
    """Bounds of ``(bx * 2048 + tx * 16) % 7168`` for ranged int32 variables."""

    # Not referenced by the case below; kept so the class namespace and the
    # construction at class-definition time stay the same.
    analyzer = tvm.arith.Analyzer()

    tx = tvm.te.var("tx", dtype="int32")
    bx = tvm.te.var("bx", dtype="int32")
    expr = (bx * 2048 + tx * 16) % 7168

    test_case = tvm.testing.parameter(
        TestCase(
            expr,
            expected_bounds=(0, 7152),
            known_bounds={bx: (0, 3584), tx: (0, 128)},
        ),
    )
# When this file is executed as a script (rather than imported), hand
# control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()