blob: 914402fb62ce8c6ee9e2d2c35afe7cdae5526b49 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te
def test_cast():
    """Modular information of an integer expression is kept across casts."""
    analyzer = tvm.arith.Analyzer()
    x = te.var("x", dtype="int8")
    expected = [
        # 3x cast to uint32 is still a multiple of 3.
        ((x * 3).astype("uint32"), 3, 0),
        # 3x + 1 round-tripped through float32 back to int32 stays 1 mod 3.
        ((x * 3 + 1).astype("float32").astype("int32"), 3, 1),
    ]
    for expr, coeff, base in expected:
        result = analyzer.modular_set(expr)
        assert result.coeff == coeff
        assert result.base == base
def test_add_sub():
    """Addition and subtraction combine the modular sets of their operands."""
    analyzer = tvm.arith.Analyzer()
    x, y = te.var("x", "int64"), te.var("y", "int64")

    def expect(expr, coeff, base):
        # Check that expr is known to be of the form coeff * k + base.
        info = analyzer.modular_set(expr)
        assert info.coeff == coeff
        assert info.base == base

    # 6x + 4y is always a multiple of gcd(6, 4) = 2.
    expect(x * 6 + y * 4, 2, 0)
    # After binding y := 4x + 1, the expression 1 - y reduces to -4x.
    analyzer.bind(y, x * 4 + 1)
    expect(1 - y, 4, 0)
def test_mul():
    """(4x + 2) * (6y + 1) = 24xy + 4x + 12y + 2, i.e. 2 mod 4."""
    analyzer = tvm.arith.Analyzer()
    x = te.var("x")
    y = te.var("y")
    lhs = x * 4 + 2
    rhs = y * 6 + 1
    result = analyzer.modular_set(lhs * rhs)
    assert result.coeff == 4
    assert result.base == 2
def test_floormod():
    """floormod by 256 keeps the common factor 4 of 128x + 4y."""
    analyzer = tvm.arith.Analyzer()
    x, y = te.var("x"), te.var("y")
    numerator = x * 128 + y * 4
    info = analyzer.modular_set(tvm.tir.floormod(numerator, 256))
    assert info.coeff == 4
    assert info.base == 0
def test_div_shift():
    """Modular sets of (4x + 2) divided by 2 via truncdiv, floordiv and >> 1.

    Floor division and right shift round down for every sign of the
    dividend. For truncdiv the analyzer only produces modular information
    once x is known to be non-negative.
    """
    analyzer = tvm.arith.Analyzer()
    x = te.var("x")
    tdiv = tvm.tir.truncdiv
    # Take floordiv from tvm.tir, like truncdiv above and the mod helpers
    # used elsewhere in this file.
    fld = tvm.tir.floordiv

    # Sign of x is unknown, so the analyzer conservatively reports no
    # modular information for truncating division.
    m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
    assert m.coeff == 1
    assert m.base == 0
    # Right shift always rounds down: (4x + 2) >> 1 is 2x + 1.
    m = analyzer.modular_set((x * 4 + 2) >> 1)
    assert m.coeff == 2
    assert m.base == 1
    # Floor division rounds down as well, independent of sign.
    m = analyzer.modular_set(fld(x * 4 + 2, 2))
    assert m.coeff == 2
    assert m.base == 1
    # With x bounded to [0, 100] the analyzer can now derive 2x + 1 for truncdiv.
    analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
    m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
    assert m.coeff == 2
    assert m.base == 1
def test_mod():
    """Modular sets of truncmod and floormod applied to affine expressions in x."""
    analyzer = tvm.arith.Analyzer()
    x = te.var("x")
    tmod = tvm.tir.truncmod
    fmod = tvm.tir.floormod

    # Sign of x is unknown: truncmod(4x + 1, 4) can be 1 or -3, so the
    # analyzer reports no modular information.
    m = analyzer.modular_set(tmod(x * 4 + 1, 4))
    assert m.coeff == 1
    assert m.base == 0
    # When the base is 0 the sign of the dividend does not matter.
    m = analyzer.modular_set(tmod(x * 4, 4))
    assert m.coeff == 4
    assert m.base == 0
    # floormod results do not depend on the sign of x.
    m = analyzer.modular_set(fmod(x * 4 + 3, 2))
    assert m.coeff == 2
    assert m.base == 1
    m = analyzer.modular_set(fmod(x * 4 + 3, 8))
    assert m.coeff == 4
    assert m.base == 3
    # With x bounded to [0, 100], truncmod yields precise information too.
    analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
    m = analyzer.modular_set(tmod(x * 4 + 3, 2))
    assert m.coeff == 2
    assert m.base == 1
def test_min_max_select():
    """min, max and Select yield a modular set covering all candidate values."""
    analyzer = tvm.arith.Analyzer()
    x, y = te.var("x"), te.var("y")
    cases = (
        # 3x and 9y are both multiples of 3.
        (tvm.te.min(x * 3, y * 9), 3, 0),
        # 3x + 1 and 9y + 4 are both 1 mod 3.
        (tvm.te.max(x * 3 + 1, y * 9 + 4), 3, 1),
        # 3x + 1 and 9y + 2 differ mod 3, so nothing can be concluded.
        (tvm.tir.Select(x > 0, x * 3 + 1, y * 9 + 2), 1, 0),
    )
    for expr, coeff, base in cases:
        info = analyzer.modular_set(expr)
        assert info.coeff == coeff
        assert info.base == base
def test_mix_index():
    """Combinations of +, -, *, truncdiv and min over two index variables."""
    a = te.var("a")
    b = te.var("b")
    analyzer = tvm.arith.Analyzer()
    tdiv = tvm.tir.truncdiv
    cases = [
        # gcd(4, 6) = 2 and 7 is odd.
        (a * 4 + b * 6 + 7, 2, 1),
        # 32ab + 12a + 8b + 3.
        ((a * 4 + 1) * (b * 8 + 3), 4, 3),
        # Division by a non-constant gives no information.
        (tdiv(a * 4 + 1, b * 8 + 3), 1, 0),
        # truncdiv(8b, 4) is even, so the product is even.
        ((a * 4 + 1) * tdiv(b * 8, 4), 2, 0),
        # 12a - 21b - 1 is 2 mod gcd(12, 21) = 3.
        ((a * 12 + 1) - (b * 3 * 7 + 2), 3, 2),
        # min(21b, 2) mixes a multiple of 3 with the constant 2.
        (a * 12 + tvm.te.min(b * 3 * 7, 2), 1, 0),
    ]
    for expr, coeff, base in cases:
        info = analyzer.modular_set(expr)
        assert info.coeff == coeff
        assert info.base == base
def test_constraint_scope():
    """Facts introduced by constraint_scope hold only inside their ``with`` block."""
    a = te.var("a")
    b = te.var("b")
    analyzer = tvm.arith.Analyzer()
    tmod = tvm.tir.truncmod

    def expect(expr, coeff, base):
        # Check that expr is known to be of the form coeff * k + base.
        info = analyzer.modular_set(expr)
        assert info.coeff == coeff
        assert info.base == base

    with analyzer.constraint_scope(tmod(b, 4) == 2):
        # b = 4k + 2, hence b + 1 = 4k + 3.
        expect(b + 1, 4, 3)
        with analyzer.constraint_scope(tmod(a, 2) == 1):
            # With a odd, 2a = 4j + 2, so b + 2a is a multiple of 4.
            expect(b + a * 2, 4, 0)
        # The parity of a is forgotten; only divisibility by 2 remains.
        expect(b + a * 2, 2, 0)
    # Outside every scope nothing is known about b.
    expect(b + 1, 1, 0)
def test_intersect():
    """Nested constraint scopes intersect their modular facts (CRT-style)."""
    a = te.var("a")
    analyzer = tvm.arith.Analyzer()
    tmod = tvm.tir.truncmod

    def expect(expr, coeff, base):
        # Check that expr is known to be of the form coeff * k + base.
        info = analyzer.modular_set(expr)
        assert info.coeff == coeff
        assert info.base == base

    # a = 1 mod 4 and a = 1 mod 3  =>  a = 1 mod 12.
    with analyzer.constraint_scope(tmod(a, 4) == 1):
        with analyzer.constraint_scope(tmod(a, 3) == 1):
            expect(a, 12, 1)

    # a = 2 mod 3, a = 3 mod 5, a = 2 mod 7  =>  a = 23 mod 105.
    with analyzer.constraint_scope(tmod(a, 3) == 2):
        with analyzer.constraint_scope(tmod(a, 5) == 3):
            with analyzer.constraint_scope(tmod(a, 7) == 2):
                expect(a, 105, 23)
def test_let():
    """The value bound by a Let is used when analysing the Let body."""
    analyzer = tvm.arith.Analyzer()
    x, y = te.var("x"), te.var("y")
    # let x = 10y in x + 1, which is 1 mod 10.
    let_expr = tvm.tir.Let(x, y * 10, x + 1)
    info = analyzer.modular_set(let_expr)
    assert info.coeff == 10
    assert info.base == 1
def test_bitwise_and():
    """Bitwise AND with a low-bit mask (2**p - 1) versus an arbitrary constant."""
    analyzer = tvm.arith.Analyzer()
    x = te.var("x")
    y = te.var("y")
    lhs = x * 16 + y * 4
    cases = (
        # 31 = 2**5 - 1: the multiple-of-4 property of lhs is preserved.
        (lhs & 31, 4, 0),
        # 17 is not of the form 2**p - 1: no information is derived.
        (lhs & 17, 1, 0),
    )
    for expr, coeff, base in cases:
        info = analyzer.modular_set(expr)
        assert info.coeff == coeff
        assert info.base == base
# Allow running this test file directly as a script.
if __name__ == "__main__":
    tvm.testing.main()