blob: f34dce5c86fd69a0762971fc886ee0393297e96c [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm.tir import floordiv, floormod
from tvm.script import tir as T
def ifuse(inputs, pred_extent=None):
"""Fuse iterators"""
value, extent = 0, 1
for i, ext in inputs:
value = value * ext + i
extent = extent * ext
return value, extent if pred_extent is None else pred_extent
def isplit(axis, factor):
"""Split iterators"""
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
return [
(fld(axis[0], factor), fld(axis[1] + (factor - 1), factor)),
(flm(axis[0], factor), factor),
]
def var_dom(iters):
"""Get domains of iterators"""
return {var: tvm.ir.Range(0, ext) for var, ext in iters}
def convert_iter_expr(expr):
return tvm.arith.normalize_iter_map_to_expr(expr)
def assert_iter_sum_pattern(
expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True
):
keys = list(expect_dict.keys())
res = tvm.arith.detect_iter_map(
keys,
dom_map,
predicate=predicate,
check_level=check_level,
simplify_trivial_iterators=simplify_trivial_iterators,
)
indices = res.indices
assert len(indices) == len(keys), res.errors
for i, input_iter in enumerate(keys):
spec = expect_dict[input_iter]
(
extent,
base,
) = spec[0:2]
scale = spec[2] if len(spec) > 2 else 1
expect_iter = spec[3] if len(spec) > 3 else None
sum_expr = indices[i]
assert isinstance(sum_expr, tvm.arith.IterSumExpr)
if extent == 1:
assert len(sum_expr.args) == 0
else:
assert len(sum_expr.args) == 1
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent)
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale)
tvm.testing.assert_prim_expr_equal(sum_expr.base, base)
if expect_iter is not None:
if not isinstance(expect_iter, tvm.arith.IterMapExpr):
sum_expr = convert_iter_expr(sum_expr)
tvm.ir.assert_structural_equal(sum_expr, expect_iter)
def assert_iter_map_simplify(
expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True
):
keys = list(expect_dict.keys())
imap = tvm.arith.detect_iter_map(
keys,
dom_map,
predicate=predicate,
check_level=check_level,
simplify_trivial_iterators=simplify_trivial_iterators,
)
res = tvm.arith.iter_map_simplify(
keys,
dom_map,
predicate=predicate,
check_level=check_level,
simplify_trivial_iterators=simplify_trivial_iterators,
)
for i, input_expr in enumerate(keys):
expected_expr = expect_dict[input_expr]
tvm.ir.assert_structural_equal(res[i], expected_expr)
def assert_iter_sum_failure(iters, dom_map, predicate=True, check_level="surjective"):
res = tvm.arith.detect_iter_map(
list(iters), dom_map, predicate=predicate, check_level=check_level
).indices
assert len(res) == 0
def test_trivial():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
dom_map = var_dom([(x, 3), (y, 4), (z, 1)])
assert_iter_sum_pattern({x: (3, 0), y: (4, 0), 3: (1, 3)}, dom_map)
assert_iter_sum_pattern({x: (3, 0), 3: (1, 3)}, dom_map)
# not independent
assert_iter_sum_failure([x, x, 3], dom_map)
assert_iter_sum_pattern(
{x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=True
)
assert_iter_sum_pattern(
{x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=False
)
assert_iter_sum_failure([x, z], dom_map, check_level="bijective")
def test_fuse():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
c = tvm.tir.SizeVar("c", "int32")
c0 = tvm.tir.SizeVar("c0", "int32")
assert_iter_sum_pattern({y * 3 + 1 + c + x: (12, 1 + c)}, var_dom([(x, 3), (y, 4)]))
assert_iter_sum_pattern({ifuse([(x, 3), (y, 4)])[0]: (12, 0)}, var_dom([(x, 3), (y, 4)]))
# fuse with symbolic factor
assert_iter_sum_pattern({(y + 1) * c + x: (4 * c, c)}, var_dom([(x, c), (y, 4)]))
# duplication
assert_iter_sum_failure([y * 3 + x, y], var_dom([(x, 3), (y, 4)]))
assert_iter_sum_failure([y, x + 1, y], var_dom([(x, 3), (y, 4)]))
# factor mismatch
assert_iter_sum_failure([y * 4 + x], var_dom([(x, 3), (y, 4)]))
# simple stride pattern
assert_iter_sum_pattern({x * 4 + y * 2: (6, 0, 2, (x * 2 + y) * 2)}, var_dom([(x, 3), (y, 2)]))
# simple stride pattern with symbolic
assert_iter_sum_pattern(
{x * 2 * c0 + y * 2: (3 * c0, 0, 2, (x * c0 + y) * 2)}, var_dom([(x, 3), (y, c0)])
)
def test_split():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
c0 = tvm.tir.SizeVar("c0", "int32")
c1 = tvm.tir.SizeVar("c1", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
assert_iter_sum_pattern({fld(x, 3): (8, 0), flm(x, 3) * 2 + c1: (3, c1, 2)}, var_dom([(x, 24)]))
assert_iter_sum_pattern(
{fld(x, 6): (4, 0), fld(flm(x, 6), 2): (3, 0), flm(x, 2): (2, 0)}, var_dom([(x, 24)])
)
# simple symbolic bound
# TODO(tvm-team) improve symbolic divisible check to enable
# more complicated symbolic bound
assert_iter_sum_pattern({fld(x, c0): (c1, 0), flm(x, c0): (c0, 0)}, var_dom([(x, c1 * c0)]))
assert_iter_sum_pattern({fld(x * 2, 4): (4, 0, 1), flm(x * 2, 4): (2, 0, 2)}, var_dom([(x, 8)]))
assert_iter_sum_pattern(
{
fld(x * 2, 4) * 4 + flm(x * 2, 4): (8, 0, 2),
},
var_dom([(x, 8)]),
)
assert_iter_sum_failure([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)]))
# domain of x is undefined
assert_iter_sum_pattern(
{fld(flm(x, 49) + y, 49): (1, fld(flm(x, 49) + y, 49))}, var_dom([(y, 1)])
)
def test_compound():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
xo, xi = isplit((x, 10), 5)
yo, yi = isplit((y, 9), 3)
z = ifuse([yo, xo, yi])
# reconstruct the pattern manually
mx = tvm.arith.IterMark(x, 10)
my = tvm.arith.IterMark(y, 9)
xoscale = 3
yoscale = 6
yiscale = 1
mxo = tvm.arith.IterSplitExpr(mx, 5, 2, xoscale)
myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale)
myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale)
mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18)
sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0)
assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)]))
def test_compound_floormod_two_regression():
x = tvm.tir.Var("x", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
# regression
# extent of 2 of negative scale cannot be normalized
assert_iter_sum_failure(
[fld(x, 2) * 2 - flm(x, 2) + 1],
dom_map=var_dom([(x, 8)]),
)
def test_predicate():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
# available contraints
# upper bound only
assert_iter_sum_pattern(
{x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 128
)
assert_iter_sum_pattern(
{x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y <= 127
)
# lower bound only
assert_iter_sum_pattern(
{x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y > 5
)
assert_iter_sum_pattern(
{x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y >= 6
)
# lower bound + upper bound
assert_iter_sum_pattern(
{x * 10 + y: (122, 6)},
var_dom([(x, 13), (y, 10)]),
predicate=tvm.tir.And(x * 10 + y > 5, x * 10 + y < 128),
)
assert_iter_sum_pattern(
{x * 10 + y: (122, 6)},
var_dom([(x, 13), (y, 10)]),
predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127),
)
assert_iter_sum_pattern(
{x * 64 + y * 4 + z: (16, 16)},
var_dom([(x, 16), (y, 16), (z, 4)]),
predicate=tvm.tir.And(x * 64 + y * 4 + z < 32, 4 <= x * 16 + y),
)
# constraint on one fused iter
i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")
k = tvm.tir.Var("k", "int32")
assert_iter_sum_pattern(
{i * 8 + j * 2 + k: (88, 1)},
var_dom([(i, 11), (j, 5), (k, 2)]),
predicate=tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9),
)
# constraint on single var
assert_iter_sum_pattern({i: (10, 0)}, var_dom([(i, 48)]), predicate=i < 10)
# iterations are subparts of constraint, invalid case 1
assert_iter_sum_failure(
[i, j, k],
var_dom([(i, 128), (j, 128), (k, 128)]),
predicate=tvm.tir.all(i * 16384 + j * 128 + k < 100),
)
# iterations are subparts of constraint, invalid case 2
assert_iter_sum_failure(
[i * 128 + j, k],
var_dom([(i, 128), (j, 128), (k, 128)]),
predicate=i * 16384 + j * 128 + k < 100,
)
# irrelavant predicate
assert_iter_sum_pattern({i + j: (1, j)}, var_dom([(i, 1)]), predicate=j <= 24)
# constraint on nested fused iters
assert_iter_sum_pattern(
{i * 8 + j * 2 + k: (22, 3)},
var_dom([(i, 11), (j, 5), (k, 2)]),
predicate=tvm.tir.all(
1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25
),
)
# duplicate constraint on one fused iter
assert_iter_sum_pattern(
{i * 6 + j * 2 + k: (66, 2)},
var_dom([(i, 11), (j, 5), (k, 2)]),
predicate=tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9),
)
# duplicate constraint on nested fused iters
assert_iter_sum_pattern(
{i * 6 + j * 2 + k: (15, 3)},
var_dom([(i, 11), (j, 5), (k, 2)]),
predicate=tvm.tir.all(
1 <= j * 2 + k,
2 <= j * 2 + k,
j * 2 + k < 8,
j * 2 + k < 9,
3 <= i * 6 + j * 2 + k,
i * 6 + j * 2 + k < 25,
1 <= i * 6 + j * 2 + k,
i * 6 + j * 2 + k < 18,
),
)
# constraint on non-disjoint fused iters should fail
assert_iter_sum_failure(
[i * 8 + j * 2 + k],
var_dom([(i, 11), (j, 5), (k, 2)]),
predicate=tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j),
)
# constraint with differnent lower bound
assert_iter_sum_pattern(
{
(i * 16 + j) // 23 * 8
+ (i * 16 + j) % 23
- 15: (
64,
0,
1,
(i * 16 + j) // 23 * 8 + ((i * 16 + j) % 23 + tvm.tir.IntImm("int32", -15)),
)
},
var_dom([(i, 12), (j, 16)]),
predicate=tvm.tir.And(
tvm.tir.And(
i * 16 + j < 184, tvm.tir.LE(tvm.tir.IntImm("int32", 8), (i * 16 + j) % 23)
),
tvm.tir.LE(tvm.tir.IntImm("int32", 15), (i * 16 + j) % 23),
),
)
# constraint on many disjoint fused iters, case 1
# i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2)
# i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1)
# i1 * 60 in [60, 240), extent=180 (= scale of i0)
i0 = tvm.tir.Var("i0", "int32")
i1 = tvm.tir.Var("i1", "int32")
i2 = tvm.tir.Var("i2", "int32")
i3 = tvm.tir.Var("i3", "int32")
i4 = tvm.tir.Var("i4", "int32")
i5 = tvm.tir.Var("i5", "int32")
assert_iter_sum_pattern(
{i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5: (540, 93)},
var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]),
predicate=tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5),
)
# constraint on many disjoint fused iters, case 2
assert_iter_sum_pattern(
{i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4: (135, 28)},
var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]),
predicate=tvm.tir.all(
3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10
),
)
# constraint on split iters
assert_iter_sum_pattern(
{i % 16: (7, 3), i // 16: (8, 4)},
var_dom([(i, 1024)]),
predicate=tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12),
check_level="bijective",
)
# constraint on split iters, nested case 1
assert_iter_sum_pattern(
{(i * 32 + j) % 16: (7, 3)},
var_dom([(i, 5), (j, 32)]),
predicate=tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10),
)
# constraint on split iters, nested case 2
assert_iter_sum_failure(
[
(i * 32 + j) % 16,
],
var_dom([(i, 5), (j, 32)]),
predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32),
check_level="bijective",
)
assert_iter_sum_pattern(
{(i * 32 + j) % 16: (16, 0)},
var_dom([(i, 5), (j, 32)]),
predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32),
)
assert_iter_sum_pattern(
{(i * 32 + j - 1) % 16: (16, 0), (i * 32 + j - 1) // 16: (4, 0)},
var_dom([(i, 5), (j, 32)]),
predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64),
)
# non-standard form of predicate
assert_iter_sum_pattern(
{x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 < 128 - y
)
# duplicate constraint
assert_iter_sum_pattern(
{x * 10 + y: (64, 0)},
var_dom([(x, 13), (y, 10)]),
predicate=tvm.tir.all(x * 10 + y < 128, x * 10 + y < 64),
)
# useless constraint
assert_iter_sum_pattern(
{x * 10 + y: (130, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 140
)
i1 = tvm.tir.Var("i1", "int32")
i2 = tvm.tir.Var("i2", "int32")
i3 = tvm.tir.Var("i3", "int32")
i4 = tvm.tir.Var("i4", "int32")
assert_iter_sum_pattern(
{i1 * 20 + i2 * 10 + i3 * 3 + i4: (128, 0)},
var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]),
predicate=(
tvm.tir.all(
i1 * 2 + i2 < 13,
i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128,
i3 * 3 + i4 < 10,
)
),
)
# wrong constraint
assert_iter_sum_failure(
[i1 * 20 + i2 * 10 + i3 * 3 + i4],
var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]),
predicate=(
tvm.tir.all(
i1 * 2 + i2 < 13,
i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128,
i3 * 3 + i4 < 7,
)
),
)
# incompatible constraint
assert_iter_sum_failure(
[i1 * 20 + i2 * 10 + i3 * 3 + i4],
var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]),
predicate=(
tvm.tir.all(
i1 * 2 + i2 < 13,
i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128,
i3 * 3 + i4 < 10,
i1 * 4 + i3 < 20,
)
),
)
assert_iter_sum_failure(
[i1 * 20 + i2 * 10 + i3 * 3 + i4],
var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]),
predicate=(
tvm.tir.all(
i1 * 2 + i2 < 13,
i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128,
i1 * 4 + i3 < 20,
)
),
)
# zero iter
xo = tvm.tir.Var("xo", "int32")
xi = tvm.tir.Var("xi", "int32")
y = tvm.tir.Var("y", "int32")
assert_iter_sum_pattern(
{xo * 129 + xi: (128, 0), y: (128, 0)},
var_dom([(xo, 1), (xi, 129), (y, 128)]),
predicate=xo * 129 + xi < 128,
)
# strided iteration predicate
assert_iter_sum_pattern(
{xo * 16 + xi * 4: (10, 0, 4)},
var_dom([(xo, 3), (xi, 4)]),
predicate=xo * 4 + xi < 10,
)
def convert_division(divisions):
if divisions is None or len(divisions) == 0:
return []
res = []
for division in divisions[:-1]:
res.append(
[
tvm.arith.normalize_iter_map_to_expr(division[0].source),
tvm.arith.normalize_iter_map_to_expr(division[1].source),
]
)
res.append([divisions[-1][0].extent, divisions[-1][1].extent])
return res
def create_iter(name, extent):
return tvm.tir.Var(name, "int32"), extent
def test_subspace_division():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
c = tvm.tir.SizeVar("c", "int32")
# simple 1.1
res = tvm.arith.subspace_divide(
[z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x]
)
res = convert_division(res)
assert len(res) == 2
tvm.ir.assert_structural_equal(res[0][0], z * 4 + y)
tvm.ir.assert_structural_equal(res[0][1], x + c)
# simple 1.2
res = tvm.arith.subspace_divide(
[z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x], z * 4 + y < 18
)
res = convert_division(res)
assert len(res) == 2
tvm.ir.assert_structural_equal(res[0][0], z * 4 + y)
tvm.ir.assert_structural_equal(res[0][1], x + c)
tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18)
tvm.ir.assert_structural_equal(res[1][1], T.bool(True))
# compound 1
i0 = create_iter("i0", 4)
j0 = create_iter("j0", 8)
i3 = create_iter("i3", 2)
i1, i2 = isplit(j0, 4)
k0 = ifuse([i0, i1])
k1 = ifuse([i2, i3])
# compound 1.1
res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]])
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4))
tvm.ir.assert_structural_equal(res[0][1], T.int32(0))
tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4))
tvm.ir.assert_structural_equal(res[1][1], i3[0])
assert_iter_sum_pattern
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices
assert len(res1) == 2
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices
assert len(res2) == 2
# compound 1.2
res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]])
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], i0[0])
tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0])
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices
assert len(res1) == 2
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices
assert len(res2) == 2
# compound 1.3
res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i0[0], i3[0]])
res = convert_division(res)
assert len(res) == 0
# compound 1.4
res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], k0[0] < 7)
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4))
tvm.ir.assert_structural_equal(res[0][1], T.int32(0))
tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4))
tvm.ir.assert_structural_equal(res[1][1], i3[0])
tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7)
tvm.ir.assert_structural_equal(res[2][1], T.bool(True))
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices
assert len(res1) == 2
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices
assert len(res2) == 2
# compound 1.5
res = tvm.arith.subspace_divide(
[k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]], k1[0] < 7
)
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], i0[0])
tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0])
tvm.ir.assert_structural_equal(res[2][0], T.bool(True))
tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7)
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices
assert len(res1) == 2
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices
assert len(res2) == 2
# compound 1.6
res = tvm.arith.subspace_divide(
[k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], tvm.tir.all(k0[0] < 7, k1[0] < 7)
)
res = convert_division(res)
assert len(res) == 0
# compound 2
j0 = create_iter("j0", 4)
l0 = create_iter("l0", 2)
l1 = create_iter("l1", 6)
j3 = create_iter("j3", 3)
k0 = ifuse([l0, l1])
i1, j2 = isplit(k0, 3)
j1, i1 = isplit(i1, 2)
i0 = ifuse([j0, j1])
i2 = ifuse([j2, j3])
# compound 2.1
res = tvm.arith.subspace_divide(
[i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l1[0], j3[0]]
)
res = convert_division(res)
assert len(res) == 4
tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0])
tvm.ir.assert_structural_equal(res[0][1], T.int32(0))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3))
tvm.ir.assert_structural_equal(res[2][0], T.int32(0))
tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0])
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices
assert len(res1) == 3
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices
assert len(res2) == 3
# compound 2.2
res = tvm.arith.subspace_divide(
[i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], l1[0], j3[0]]
)
res = convert_division(res)
assert len(res) == 4
tvm.ir.assert_structural_equal(res[0][0], j0[0])
tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3))
tvm.ir.assert_structural_equal(res[2][0], T.int32(0))
tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0])
res1 = tvm.arith.detect_iter_map(
[res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3])
).indices
assert len(res1) == 3
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])).indices
assert len(res2) == 3
# compound 2.3
res = tvm.arith.subspace_divide(
[i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], j3[0]]
)
res = convert_division(res)
assert len(res) == 0
# compound 2.4
res = tvm.arith.subspace_divide(
[i0[0], i1[0], i2[0]],
var_dom([j0, l0, l1, j3]),
[l1[0], j3[0]],
tvm.tir.all(i0[0] < 7, i2[0] < 8),
)
res = convert_division(res)
assert len(res) == 4
tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0])
tvm.ir.assert_structural_equal(res[0][1], T.int32(0))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3))
tvm.ir.assert_structural_equal(res[2][0], T.int32(0))
tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0])
tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7)
tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8)
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices
assert len(res1) == 3
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices
assert len(res2) == 3
# compound 2.5
res = tvm.arith.subspace_divide(
[i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [j3[0]], i2[0] < 8
)
res = convert_division(res)
assert len(res) == 0
def test_subspace_divide_trivial_iters():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
# trivial 1.1
res = tvm.arith.subspace_divide(
[x * 16 + y], var_dom([(x, 1), (y, 16)]), [y], simplify_trivial_iterators=False
)
res = convert_division(res)
assert len(res) == 2
tvm.ir.assert_structural_equal(res[0][0], x)
tvm.ir.assert_structural_equal(res[0][1], y)
# trivial 1.2
res = tvm.arith.subspace_divide(
[x, y],
var_dom([(x, 1), (y, 1)]),
[y],
simplify_trivial_iterators=False,
)
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], x)
tvm.ir.assert_structural_equal(res[0][1], T.int32(0))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], y)
def test_complex():
n0 = create_iter("n0", 2)
n1 = create_iter("n1", 4)
m0 = ifuse([n0, n1], 6)
m1 = create_iter("m1", 3)
l0 = create_iter("l0", 4)
l1 = create_iter("l1", 8)
l2 = ifuse([m0, m1], 16)
l3 = create_iter("l3", 32)
k0, k4 = isplit(l0, 2)
k1, k5 = isplit(l1, 2)
k2, k6 = isplit(l2, 4)
k3, k7 = isplit(l3, 4)
j0 = ifuse([k0, k1], 7)
j1 = ifuse([k2, k3])
j2 = ifuse([k4, k5])
j3 = ifuse([k6, k7], 15)
i0 = ifuse([j0, j1], 200)
i1 = ifuse([j2, j3], 50)
n0_mark = tvm.arith.IterMark(n0[0], n0[1])
n1_mark = tvm.arith.IterMark(n1[0], n1[1])
l0_mark = tvm.arith.IterMark(l0[0], l0[1])
l1_mark = tvm.arith.IterMark(l1[0], l1[1])
m1_mark = tvm.arith.IterMark(m1[0], m1[1])
l3_mark = tvm.arith.IterMark(l3[0], l3[1])
m0_expr = tvm.arith.IterSumExpr(
[
tvm.arith.IterSplitExpr(n0_mark, 1, n0[1], 4),
tvm.arith.IterSplitExpr(n1_mark, 1, n1[1], 1),
],
0,
)
m0_mark = tvm.arith.IterMark(m0_expr, 6)
l2_expr = tvm.arith.IterSumExpr(
[tvm.arith.IterSplitExpr(m0_mark, 1, 6, 3), tvm.arith.IterSplitExpr(m1_mark, 1, m1[1], 1)],
0,
)
l2_mark = tvm.arith.IterMark(l2_expr, 16)
k0_expr = tvm.arith.IterSplitExpr(l0_mark, 2, 2, 4)
k1_expr = tvm.arith.IterSplitExpr(l1_mark, 2, 4, 1)
k2_expr = tvm.arith.IterSplitExpr(l2_mark, 4, 4, 8)
k3_expr = tvm.arith.IterSplitExpr(l3_mark, 4, 8, 1)
k4_expr = tvm.arith.IterSplitExpr(l0_mark, 1, 2, 30)
k5_expr = tvm.arith.IterSplitExpr(l1_mark, 1, 2, 15)
k6_expr = tvm.arith.IterSplitExpr(l2_mark, 1, 4, 4)
k7_expr = tvm.arith.IterSplitExpr(l3_mark, 1, 4, 1)
j0_expr = tvm.arith.IterSumExpr([k0_expr, k1_expr], 0)
j0_mark = tvm.arith.IterMark(j0_expr, 7)
i0_expr = tvm.arith.IterSumExpr(
[tvm.arith.IterSplitExpr(j0_mark, 1, 7, 32), k2_expr, k3_expr], 0
)
j3_expr = tvm.arith.IterSumExpr([k6_expr, k7_expr], 0)
j3_mark = tvm.arith.IterMark(j3_expr, 15)
i1_expr = tvm.arith.IterSumExpr(
[k4_expr, k5_expr, tvm.arith.IterSplitExpr(j3_mark, 1, 15, 1)], 0
)
i0_mark = tvm.arith.IterMark(i0_expr, i0[1])
i1_mark = tvm.arith.IterMark(i1_expr, i1[1])
i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0)
i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0)
assert_iter_sum_pattern(
{i0[0]: (200, 0, 1, i0_final), i1[0]: (50, 0, 1, i1_final)},
var_dom([l0, l1, n0, n1, m1, l3]),
predicate=tvm.tir.all(
i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15
),
)
# wrong constraint
assert_iter_sum_failure(
[i0[0], i1[0]],
var_dom([l0, l1, n0, n1, m1, l3]),
tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14),
)
# subspace_division
res = tvm.arith.subspace_divide(
[i0[0], i1[0]],
var_dom([l0, l1, n0, n1, m1, l3]),
[n0[0], n1[0], m1[0], l3[0]],
tvm.tir.all(m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15),
)
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], floordiv(l0[0], 2) * 4 + floordiv(l1[0], 2))
tvm.ir.assert_structural_equal(
res[0][1], (floordiv((n0[0] * 4 + n1[0]) * 3 + m1[0], 4) * 8) + floordiv(l3[0], 4)
)
tvm.ir.assert_structural_equal(res[1][0], ((floormod(l0[0], 2) * 2) + floormod(l1[0], 2)))
tvm.ir.assert_structural_equal(
res[1][1], ((floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4) + floormod(l3[0], 4))
)
tvm.ir.assert_structural_equal(res[2][0], (floordiv(l0[0], 2) * 4) + floordiv(l1[0], 2) < 7)
tvm.ir.assert_structural_equal(
res[2][1],
tvm.tir.all(
n0[0] * 4 + n1[0] < 6,
(n0[0] * 4 + n1[0]) * 3 + m1[0] < 16,
floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4 + floormod(l3[0], 4) < 15,
),
)
assert_iter_sum_pattern(
{res[0][1]: (32, 0), res[1][1]: (15, 0)}, var_dom([n0, n1, m1, l3]), res[2][1]
)
assert_iter_sum_pattern({res[0][0]: (8, 0), res[1][0]: (4, 0)}, var_dom([l0, l1]))
def test_normalize_iter_map_to_expr():
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
xo, xi = isplit((x, 10), 5)
yo, yi = isplit((y, 9), 3)
z = ifuse([yo, xo, yi])
res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([(x, 10), (y, 9)]))
tvm.ir.assert_structural_equal(
tvm.arith.normalize_iter_map_to_expr(res.indices[0]),
fld(y, 3) * 6 + fld(x, 5) * 3 + flm(y, 3),
)
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res.indices[1]), flm(x, 5))
# iter mark wrap a complex expr
split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x * y + 1, 1024), 1, 1024, 1)
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x * y + 1)
def test_inverse_affine_iter_map():
analyzer = tvm.arith.Analyzer()
l0 = create_iter("l0", 64)
l1 = create_iter("l1", 64)
l2 = create_iter("l2", 64)
# simple case
l0_0, l0_1 = isplit(l0, 16)
l1_0, l1_1 = isplit(l1, 4)
l0_1_l1_1_fused = ifuse([l0_1, l1_1])
iter_map = tvm.arith.detect_iter_map(
[l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1])
).indices
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 2
l0_inverse = floordiv(outputs[0], 4) + outputs[1] * 16
l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4
assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
assert analyzer.can_prove_equal(res[l1[0]], l1_inverse)
# compound case
l0_0, l0_1 = isplit(l0, 16)
l1_0, l1_1 = isplit(l1, 4)
l2_1, l2_2 = isplit(l2, 4)
l2_0, l2_1 = isplit(l2_1, 4)
l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0])
iter_map = tvm.arith.detect_iter_map(
[l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2])
).indices
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 3
l0_inverse = floordiv(outputs[0], 64) + outputs[1] * 16
l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4
l2_inverse = (
floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2]
)
assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
assert analyzer.can_prove_equal(res[l1[0]], l1_inverse)
assert analyzer.can_prove_equal(res[l2[0]], l2_inverse)
# diamond-shape DAG
l0_0, l0_1 = isplit(l0, 16)
l1 = ifuse([l0_1, l0_0])
l1_0, l1_1 = isplit(l1, 8)
l2 = ifuse([l1_1, l1_0])
iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])).indices
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 1
l1_inverse = floormod(outputs[0], 8) * 8 + floordiv(outputs[0], 8)
l0_inverse = floormod(l1_inverse, 4) * 16 + floordiv(l1_inverse, 4)
assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
def test_inverse_affine_map_trivial_iter():
analyzer = tvm.arith.Analyzer()
l0 = create_iter("l0", 64)
l1 = create_iter("l1", 64)
iter_map = tvm.arith.detect_iter_map([0, l0[0], l1[0]], var_dom([l0, l1])).indices
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
# output_0 is expected to be constant and it is not included in the inverse map
assert len(res) == 2
assert analyzer.can_prove_equal(res[l0[0]], outputs[1])
assert analyzer.can_prove_equal(res[l1[0]], outputs[2])
def test_free_variables():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
# illegal iter if z is within dom
assert_iter_sum_failure([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)]))
# iter is valid if z is free, even there are linear forms of z
assert_iter_sum_pattern(
{z * 19 + y * 3 + x: (9, z * 19)},
var_dom(
[
(x, 3),
(y, 3),
]
),
)
assert_iter_sum_pattern(
{z * z + y * 3 + x: (9, z * z)},
var_dom(
[
(x, 3),
(y, 3),
]
),
)
class TestPadding:
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
positive_test_case = tvm.testing.parameter(
# left padding only, offset divisible
({y: 192}, {fld(64 + y, 32): (6, 2, 1), flm(64 + y, 32): (32, 0, 1)}, "bijective"),
# left padding only, offset non-divisible
({y: 176}, {fld(80 + y, 32): (6, 2, 1)}),
({y: 176}, {flm(fld(80 + y, 2), 16): (16, 0, 1), flm(80 + y, 2): (2, 0, 1)}),
# right padding only, offset divisible
({x: 5, y: 4}, {fld(x * 32 + y * 8, 16): (10, 0, 1), flm(x * 32 + y * 8, 16): (2, 0, 8)}),
# right padding only, offset non-divisible
({x: 26}, {fld(x, 15): (2, 0, 1)}),
({x: 26}, {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}),
# padding constants on both side
({x: 45}, {fld(x + 71, 32): (2, 2, 1)}),
({x: 45}, {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}),
# padding for free iteration part
({y: 360}, {fld(x * 360 + y, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}),
({y: 360}, {flm(x * 360 + y, 16): (16, 0, 1)}),
# multiple split with same mark offset, could
# be surjective on missing (padded // LCM)
(
{x: 240},
{
flm(x + 10, 3): (3, 0),
flm(fld(x + 10, 3), 4): (4, 0),
flm(fld(fld(x + 10, 3), 4), 5): (5, 0),
},
),
# different offsets on splits
(
{x: 240},
{
flm(x + 1, 3): (3, 0),
flm(fld(x + 10, 3) + 2, 4): (4, 0),
flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0),
},
),
)
negative_test_case = tvm.testing.parameter(
# left padding only, offset non-divisible
({y: 176}, {fld(80 + y, 32), flm(80 + y, 32)}),
({y: 176}, {fld(80 + y, 32), fld(80 + y, 4)}),
# right padding only, offset divisible
({x: 5, y: 4}, {fld(x * 32 + y * 8, 5)}),
# multiple split with same mark offset, could
# be surjective on missing (padded // LCM)
(
{x: 240},
{
flm(x + 10, 3),
flm(fld(x + 10, 3), 4),
flm(fld(fld(x + 10, 3), 4), 5),
fld(fld(fld(x + 10, 3), 4), 5),
},
),
# original extent is smaller than the divident
# it is not surjective wrt to the region [0, 16)
({x: 3}, {flm(x, 16)}),
# (x % c1) // c2 is not proved as surjective if c1 % c2 != 0
({x: 255}, {fld(flm(x, 255), 16)}),
)
def test_padding(self, positive_test_case):
iter_extent, mapped_iterators, *args = positive_test_case
check_level = args[0] if args else "surjective"
dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()}
assert_iter_sum_pattern(mapped_iterators, dom_map, check_level=check_level)
def test_padding_error(self, negative_test_case):
iter_extent, mapped_iterators, *args = negative_test_case
check_level = args[0] if args else "surjective"
dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()}
assert_iter_sum_failure(mapped_iterators, dom_map, check_level=check_level)
def test_overlapped_fuse():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
a = tvm.tir.Var("x", "int32")
b = tvm.tir.Var("y", "int32")
# non-bijective fuse of two
assert_iter_sum_pattern(
{
x * 7 + y: (22, 0, 1),
},
var_dom([(x, 3), (y, 8)]),
check_level="surjective",
)
assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective")
# non-bijective fuse of three
assert_iter_sum_pattern(
{
x * 18 + y * 7 + z: (40, 0, 1),
},
var_dom([(x, 2), (y, 3), (z, 8)]),
check_level="surjective",
)
assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]), check_level="bijective")
# negative scale fusion is not allowed
assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
# with predicate
assert_iter_sum_pattern(
{
a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1),
},
var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]),
predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10),
check_level="surjective",
)
# stride=1 kernel
assert_iter_sum_pattern(
{x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective"
)
# do not allow both strided and overlapped
assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective")
def test_iter_map_simplify_symbolic_case():
"""Test itermap simplify"""
x = tvm.tir.Var("x", "int64")
y = tvm.tir.Var("y", "int64")
z = x * 32 + y
n = tvm.tir.SizeVar("n", "int64")
def simple_fuse0(x):
return (x // n) * n + x % n
assert_iter_map_simplify({simple_fuse0(x): x}, var_dom([(x, n * 32)]))
assert_iter_map_simplify({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)]))
def fsymbolic_fuse0(x):
return ((x // (n * n)) % 32) * (n * n) + ((x // n) % n) * n + x % n
assert_iter_map_simplify({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)]))
assert_iter_map_simplify({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)]))
def fsymbolic_fuse1(x):
return ((x % (n * n * 32)) // (n * n) * n + (x % (n * n) // n)) * n + x % n
assert_iter_map_simplify({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)]))
assert_iter_map_simplify({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)]))
def fsymbolic_fuse2(i):
return (i // (n * n) * n + i % (n * n) // n) * n + i % n
assert_iter_map_simplify({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)]))
def test_iter_map_simplify_symbolic_predicate():
"""Test itermap simplify"""
x = tvm.tir.Var("x", "int64")
y = tvm.tir.Var("y", "int64")
n = tvm.tir.SizeVar("n", "int64")
def simple_fuse0(x):
return (x // n) * n + x % n
z = x * 32 + y
assert_iter_map_simplify(
{simple_fuse0(z): z}, var_dom([(x, (n + 1) // 2), (y, 32)]), predicate=(z < n * 16)
)
def fsymbolic_fuse2(i):
return (i // (n * n) * n + i % (n * n) // n) * n + i % n
z = x * 64 + y
assert_iter_map_simplify(
{fsymbolic_fuse2(z): z},
var_dom([(x, (n * n + 1) // 2), (y, 64)]),
predicate=(z < n * n * 32),
)
def test_iter_map_simplify_symbolic_reshape():
n = tvm.tir.Var("n", "int64")
fused = tvm.tir.Var("fused", "int64")
ax0 = (fused // 4096) // n
ax1 = (fused // 4096) % n
ax2 = fused % 4096
rhs_index = ((ax2 // 4096 + ax0 * n + ax1) % n) * 4096 + ax2 % 4096
assert_iter_map_simplify({rhs_index: fused}, var_dom([(fused, n * 4096)]))
def test_iter_map_simplify_unit_loop_order():
"""Test itermap simplify"""
x = tvm.tir.Var("x", "int64")
y = tvm.tir.Var("y", "int64")
z = tvm.tir.Var("z", "int64")
# trivial iterators can be found at any when comparing via scale
# ensure order unchange
assert_iter_map_simplify(
{x + y + z: x + y + z}, var_dom([(x, 1), (y, 1), (z, 1)]), simplify_trivial_iterators=False
)
# Even with simplifcation, it should follow the original order
assert_iter_map_simplify(
{x + y + (z // 4) * 4 + z % 4: z + x + y},
var_dom([(x, 1), (y, 1), (z, 32)]),
simplify_trivial_iterators=False,
)
assert_iter_map_simplify(
{y + 64 - x % 2 * 64: y + 64 - x % 2 * 64},
var_dom([(x, 6), (y, 64)]),
simplify_trivial_iterators=False,
)
# When we have iterators that have same scale but one of them come
# with unit extent, we should prioritize unit extent
assert_iter_map_simplify(
{x // 128 + y + z: y + z},
var_dom([(x, 128), (y, 128), (z, 1)]),
simplify_trivial_iterators=False,
)
def assert_normalize_to_iter_sum(index, input_iters, args, base):
"""Assert the result of arith.normalize_to_iter_sum is correct
Parameters
----------
index : tvm.tir.PrimExpr
The index to be normalized
input_iters : Mapping[Var, Range]
The input iterators
args : List[Union[tvm.arith.IterSplitExpr, Tuple[PrimExpr, PrimExpr]]]
The expected result. Ordered list of args of the expected IterSumExpr. Each arg can be
either IterSplitExpr or a tuple of (PrimExpr, PrimExpr) where the first element is the
iterator normalized to PrimExpr and the second element is the scale.
base : tvm.tir.PrimExpr
The expected base
"""
res = tvm.arith.normalize_to_iter_sum(index, input_iters)
assert isinstance(res, tvm.arith.IterSumExpr)
assert len(res.args) == len(args)
for split, item in zip(res.args, args):
if isinstance(item, tvm.arith.IterSplitExpr):
tvm.ir.assert_structural_equal(split, item)
continue
tvm.testing.assert_prim_expr_equal(split.scale, item[1])
tvm.testing.assert_prim_expr_equal(
tvm.arith.normalize_iter_map_to_expr(split), item[0] * item[1]
)
tvm.testing.assert_prim_expr_equal(res.base, base)
def test_normalize_to_iter_sum():
x = tvm.tir.Var("x", "int64")
y = tvm.tir.Var("y", "int64")
z = tvm.tir.Var("z", "int64")
a = tvm.tir.Var("a", "int64")
n = tvm.tir.Var("n", "int64")
flm = tvm.tir.floormod
assert_normalize_to_iter_sum(
z + ((y + x * 4 + 2) * n) + 3,
var_dom([(x, 9), (y, 4), (z, 3)]),
[(x, n * 4), (y, n), (z, 1)],
2 * n + 3,
)
# max cannot detected so it goes into base
assert_normalize_to_iter_sum(
tvm.tir.max(z, a) + ((y + x * 4 + 2) * n) + 3,
var_dom([(x, 9), (y, 4), (z, 3)]),
[(x, n * 4), (y, n)],
tvm.tir.max(z, a) + 2 * n + 3,
)
# order by symbolc prod
assert_normalize_to_iter_sum(
z + ((y * 4 * a + x * 4 + 2) * n) + 3,
var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
[(y, a * n * 4), (x, n * 4), (z, 1)],
2 * n + 3,
)
# order by cscale
assert_normalize_to_iter_sum(
z + 2 * y * 3 + 4 * x,
var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
[(y, 6), (x, 4), (z, 1)],
0,
)
# split pattern
assert_normalize_to_iter_sum(
z + 2 * y * 3 + 4 * (x // 2),
var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
[(y, 6), (x // 2, 4), (z, 1)],
0,
)
# non-divisible
assert_normalize_to_iter_sum(
x // 5,
var_dom([(x, 4096)]),
[
tvm.arith.IterSplitExpr(
tvm.arith.IterMark(x, 4096),
lower_factor=tvm.tir.const(5, "int64"),
extent=tvm.tir.const(820, "int64"),
scale=tvm.tir.const(1, "int64"),
)
],
0,
)
# iter simplify
assert_normalize_to_iter_sum(
z * 2 + 2 * y * 3 + 4 * (x // 4) + (x % 4),
var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
[(y, 6), (z, 2), (x, 1)],
0,
)
def test_detect_iter_map_with_bufferload_recursion():
n = tvm.tir.Var("n", "int32")
m = tvm.tir.Var("m", "int32")
divisor = tvm.tir.Var("divisor", "int32")
i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")
buffer = tvm.tir.decl_buffer((n,), "int32", name="seqlen")
indices = [(buffer[i] + j) // divisor]
iter_vars = {
i: tvm.ir.Range(tvm.tir.const(0, "int32"), n),
j: tvm.ir.Range(tvm.tir.const(0, "int32"), m),
}
result = tvm.arith.detect_iter_map(indices, iter_vars)
assert len(result.indices) == 0
if __name__ == "__main__":
tvm.testing.main()