blob: c9fdbd0c424e3a0748cb1db66aac7b1b96bc1ddb [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import tir
from tvm.script import tir as T
from tvm.tir.transform import HoistedConditionals, HoistedLetBindings, HoistExpression
class BaseBeforeAfter:
hoisted_conditionals = tvm.testing.parameter(HoistedConditionals.All)
hoisted_let_bindings = tvm.testing.parameter(HoistedLetBindings.All)
def test_hoist(self, hoisted_conditionals, hoisted_let_bindings):
before = self.before
before_mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
config = {
"tir.HoistExpression": {
"hoisted_conditionals": hoisted_conditionals.value,
"hoisted_let_bindings": hoisted_let_bindings.value,
}
}
with tvm.transform.PassContext(config=config):
after_mod = tvm.tir.transform.HoistExpression()(before_mod)
after = after_mod["main"]
expected = self.expected.with_attr("global_symbol", "main")
try:
tvm.ir.assert_structural_equal(after, expected)
except ValueError as err:
script = tvm.IRModule({"expected": expected, "after": after, "before": before}).script()
raise ValueError(
f"Function after simplification did not match expected:\n{script}"
) from err
class TestHoistToTop(BaseBeforeAfter):
hoisted_conditionals = tvm.testing.parameter(
HoistedConditionals.IfElseStmt,
HoistedConditionals.All,
)
@T.prim_func
def before(A: T.Buffer((16,), "float32"), n: T.int32):
for i in T.serial(16):
if n != 0:
A[i] = 0.0
@T.prim_func
def expected(A: T.Buffer((16,), "float32"), n: T.int32):
if n != 0:
for i in T.serial(16):
A[i] = 0.0
class TestSuppressHoistIfElse(BaseBeforeAfter):
hoisted_conditionals = tvm.testing.parameter(
HoistedConditionals.Never,
HoistedConditionals.IfElseExpr,
)
@T.prim_func
def before(A: T.Buffer((16,), "float32"), n: T.int32):
for i in T.serial(16):
if n != 0:
A[i] = 0.0
expected = before
class TestHoistBlockVar(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((128, 16), "float32"), n: T.int32):
i = T.env_thread("threadIdx.x")
T.launch_thread(i, 128)
for j in T.serial(16):
if i < 32:
A[i, j] = 0.0
@T.prim_func
def expected(A: T.Buffer((128, 16), "float32"), n: T.int32):
i = T.env_thread("threadIdx.x")
T.launch_thread(i, 128)
if i < 32:
for j in T.serial(16):
A[i, j] = 0.0
class TestSuppressHoistBlockVar(BaseBeforeAfter):
hoisted_conditionals = tvm.testing.parameter(
HoistedConditionals.All & ~HoistedConditionals.UsingBlockVar
)
@T.prim_func
def before(A: T.Buffer((128, 16), "float32"), n: T.int32):
thread_x = T.env_thread("threadIdx.x")
T.launch_thread(thread_x, 128)
for i in T.thread_binding(0, 128, thread="threadIdx.x"):
if i < 32:
for j in T.serial(16):
A[i, j] = 0.0
expected = before
class TestHoistAcrossBlockVar(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((128, 16), "float32"), n: T.int32):
thread_x = T.env_thread("threadIdx.x")
T.launch_thread(thread_x, 128)
for i in T.thread_binding(0, 128, thread="threadIdx.x"):
if n == 0:
for j in T.serial(16):
A[i, j] = 0.0
@T.prim_func
def expected(A: T.Buffer((128, 16), "float32"), n: T.int32):
thread_x = T.env_thread("threadIdx.x")
if n == 0:
T.launch_thread(thread_x, 128)
for i in T.thread_binding(0, 128, thread="threadIdx.x"):
for j in T.serial(16):
A[i, j] = 0.0
class TestSuppressHoistAcrossBlockVar(BaseBeforeAfter):
hoisted_conditionals = tvm.testing.parameter(
HoistedConditionals.All & ~HoistedConditionals.UsingBlockVar
)
@T.prim_func
def before(A: T.Buffer((128, 16), "float32"), n: T.int32):
thread_x = T.env_thread("threadIdx.x")
T.launch_thread(thread_x, 128)
for i in T.thread_binding(0, 128, thread="threadIdx.x"):
for j in T.serial(16):
if n == 0:
A[i, j] = 0.0
@T.prim_func
def expected(A: T.Buffer((128, 16), "float32"), n: T.int32):
thread_x = T.env_thread("threadIdx.x")
T.launch_thread(thread_x, 128)
if n == 0:
for i in T.thread_binding(0, 128, thread="threadIdx.x"):
for j in T.serial(16):
A[i, j] = 0.0
class TestHoistToMiddle(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
for j in T.serial(4):
if i < 3:
A[i, j] = 0.0
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
if i < 3:
for j in T.serial(4):
A[i, j] = 0.0
class TestHoistWithLet(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
for j in T.serial(4):
condition = i < 3
if condition:
A[i, j] = 0.0
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
condition = i < 3
if condition:
for j in T.serial(4):
A[i, j] = 0.0
class TestHoistDisableLet(BaseBeforeAfter):
"""As TestHoistWithLet, but forbid hoisting of LetStmt
Because the condition depends on the let binding, it should no
longer be hoisted.
"""
hoisted_let_bindings = tvm.testing.parameter(HoistedLetBindings.Never)
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
for j in T.serial(4):
condition = i < 3
if condition:
A[i, j] = 0.0
expected = before
class TestHoistIfElse(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
for j in T.serial(4):
if i < 3:
A[i, j] = 0.0
else:
A[i, j] = 1.0
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
if i < 3:
for j in T.serial(4):
A[i, j] = 0.0
else:
for j in T.serial(4):
A[i, j] = 1.0
class TestHoistSequentialAssign(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
for j in T.serial(4):
if i < 3:
A[i, j] = 0.0
B[i, j] = 0.0
else:
A[i, j] = 1.0
B[i, j] = 1.0
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
if i < 3:
for j in T.serial(4):
A[i, j] = 0.0
B[i, j] = 0.0
else:
for j in T.serial(4):
A[i, j] = 1.0
B[i, j] = 1.0
class TestHoistMultiIf(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
for j in T.serial(4):
for k in T.serial(4):
if j < 3:
if i < 2:
A[i, j] = 0.0
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
if i < 2:
for j in T.serial(4):
if j < 3:
for k in T.serial(4):
A[i, j] = 0.0
class TestHoistComplexConditional(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i, j, k in T.grid(4, 4, 4):
if j < 3 and i < 2:
A[i, j] = 0.0
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
if i < 2:
for j in T.serial(4):
if j < 3:
for k in T.serial(4):
A[i, j] = 0.0
class TestSuppressSplittingConditional(BaseBeforeAfter):
hoisted_conditionals = tvm.testing.parameter(
HoistedConditionals.All & ~HoistedConditionals.BooleanExpression
)
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i, j, k in T.grid(4, 4, 4):
if j < 3 and i < 2:
A[i, j] = 0.0
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
for i, j in T.grid(4, 4):
if j < 3 and i < 2:
for k in T.serial(4):
A[i, j] = 0.0
class TestHoistMultiIfElse(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
for j in T.serial(4):
for k in T.serial(4):
if j < 3:
if i < 2:
A[i, j] = 0.0
else:
A[i, j] = 1.0
else:
if i < 2:
A[i, j] = 2.0
else:
A[i, j] = 3.0
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
if i < 2:
for j in T.serial(4):
if j < 3:
for k in T.serial(4):
A[i, j] = 0.0
else:
for k in T.serial(4):
A[i, j] = 2.0
else:
for j in T.serial(4):
if j < 3:
for k in T.serial(4):
A[i, j] = 1.0
else:
for k in T.serial(4):
A[i, j] = 3.0
class TestHoistMultiIfElseDifferentBranches(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
for j in T.serial(4):
for k in T.serial(4):
if j < 3:
if i < 2:
A[i, j] = 0.0
else:
A[i, j] = 1.0
else:
if i < 1:
A[i, j] = 2.0
else:
A[i, j] = 3.0
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
if i < 2:
if i < 1:
for j in T.serial(4):
if j < 3:
for k in T.serial(4):
A[i, j] = 0.0
else:
for k in T.serial(4):
A[i, j] = 2.0
else:
for j in T.serial(4):
if j < 3:
for k in T.serial(4):
A[i, j] = 0.0
else:
for k in T.serial(4):
A[i, j] = 3.0
else:
for j in T.serial(4):
if j < 3:
for k in T.serial(4):
A[i, j] = 1.0
else:
for k in T.serial(4):
A[i, j] = 3.0
class TestHoistIfElseExpr(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i, j in T.grid(4, 4):
A[i, j] = T.if_then_else(i < 2, 1.0, 2.0, dtype="float32")
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
if i < 2:
for j in T.serial(4):
A[i, j] = 1.0
else:
for j in T.serial(4):
A[i, j] = 2.0
class TestSuppressHoistIfElseExpr(TestHoistIfElseExpr):
hoisted_conditionals = tvm.testing.parameter(
HoistedConditionals.All & ~HoistedConditionals.IfElseExpr
)
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i, j in T.grid(4, 4):
A[i, j] = T.if_then_else(i < 2, 1.0, 2.0, dtype="float32")
expected = before
class TestHoistLetExpr(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i, j in T.grid(4, 4):
x = T.float32()
A[i, j] = T.Let(5.0 * x + T.cast(j, "float32"), where={x: T.cast(i + 1, "float32")})
@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
x = T.cast(i + 1, "float32")
for j in T.serial(4):
A[i, j] = 5.0 * x + T.cast(j, "float32")
class TestSuppressHoistLetExpr(BaseBeforeAfter):
hoisted_let_bindings = tvm.testing.parameter(
HoistedLetBindings.All & ~HoistedLetBindings.LetExpr
)
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i, j in T.grid(4, 4):
x = T.float32()
A[i, j] = T.Let(5.0 * x + T.cast(j, "float32"), where={x: T.cast(i + 1, "float32")})
expected = before
if __name__ == "__main__":
tvm.testing.main()