# blob: dd10e15853f1b495e0b448204f2b78aef921998e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
from tvm import relay
import numpy as np
import pytest
from tvm.testing import enabled_targets
# Legacy module-level scratch list. ``verify_structure`` no longer reads or
# writes it; the name is kept only so any existing reference keeps resolving.
var_list = []


def verify_structure(stmt, expected_struct):
    """Assert that the If/For/AttrStmt nesting inside ``stmt`` matches ``expected_struct``.

    Every ``tir.IfThenElse``, ``tir.For`` and ``tir.AttrStmt`` reached by a
    post-order visit of ``stmt`` is reduced to a signature tuple:

    * ``("tir.IfThenElse", names)`` where ``names`` is the tuple of variable
      names met while post-order visiting the condition,
    * ``("tir.For", loop_var_name)``,
    * ``("tir.AttrStmt", attr_key, int(value))``.

    Parameters
    ----------
    stmt :
        The statement to inspect.
    expected_struct : dict
        Maps each signature to a tuple with one entry per child statement
        (then/else cases for an if, the body otherwise). An entry is the
        child's signature, or ``None`` when the child is not itself one of the
        tracked nodes (this includes a missing else branch).

    Raises
    ------
    AssertionError
        If the extracted structure differs from ``expected_struct``.

    Notes
    -----
    Nodes that produce the same signature share a single dictionary entry; a
    node recorded later overwrites an earlier one.
    """
    node_dict = {}

    def _condition_vars(cond):
        # A fresh list per condition keeps names from one if-statement (or
        # from an earlier call) out of another one's signature.
        names = []

        def _collect(op):
            if isinstance(op, tvm.tir.Var):
                names.append(op.name)

        tvm.tir.stmt_functor.post_order_visit(cond, _collect)
        return tuple(names)

    def _visit(op):
        if isinstance(op, tvm.tir.IfThenElse):
            children = (op.then_case, op.else_case)
            signature = ("tir.IfThenElse", _condition_vars(op.condition))
        elif isinstance(op, tvm.tir.For):
            children = (op.body,)
            signature = ("tir.For", op.loop_var.name)
        elif isinstance(op, tvm.tir.AttrStmt):
            children = (op.body,)
            signature = ("tir.AttrStmt", op.attr_key, int(op.value))
        else:
            # Any other node kind is not part of the compared structure.
            return
        node_dict[op] = (children, signature)

    tvm.tir.stmt_functor.post_order_visit(stmt, _visit)

    # Resolve every child to its signature now that all nodes are known.
    struct = {
        signature: tuple(
            node_dict[child][1] if child in node_dict else None for child in children
        )
        for children, signature in node_dict.values()
    }

    assert (
        struct == expected_struct
    ), f"Structure mismatch: expect {expected_struct} but got {struct}"
def _opaque_eval(var):
    """Return an ``Evaluate`` statement wrapping an extern call to ``dummy`` on ``var``."""
    extern_call = tvm.tir.call_extern("int32", "dummy", var)
    return tvm.tir.Evaluate(extern_call)
def test_hoist_top_for():
    """A condition on ``i`` alone, nested under loops i/j/k, is expected directly inside loop ``i``."""
    builder = tvm.tir.ir_builder.create()
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    data = builder.pointer("float32", name="data")

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            with builder.if_scope(builder.likely(i < 2)):
                builder.emit(_opaque_eval(m))
            with builder.else_scope():
                builder.emit(_opaque_eval(n))

    func = tvm.tir.PrimFunc([], builder.get())
    module = tvm.IRModule.from_expr(func)
    hoisted_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body

    loop_i, loop_j, loop_k = ("tir.For", "i"), ("tir.For", "j"), ("tir.For", "k")
    if_on_i = ("tir.IfThenElse", ("i",))
    # The if wraps the whole j loop in both of its branches.
    expected_struct = {
        loop_k: (None,),
        loop_j: (loop_k,),
        if_on_i: (loop_j, loop_j),
        loop_i: (if_on_i,),
    }
    verify_structure(hoisted_body, expected_struct)
def test_hoist_multi_var_if():
    """A condition on ``i + j`` is expected directly inside loop ``j``, around loop ``k``."""
    builder = tvm.tir.ir_builder.create()
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    data = builder.pointer("float32", name="data")

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            with builder.if_scope(builder.likely(i + j < 2)):
                builder.emit(_opaque_eval(m))
            with builder.else_scope():
                builder.emit(_opaque_eval(n))

    func = tvm.tir.PrimFunc([], builder.get())
    hoisted_mod = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))
    hoisted_body = hoisted_mod["main"].body

    loop_i, loop_j, loop_k = ("tir.For", "i"), ("tir.For", "j"), ("tir.For", "k")
    if_on_ij = ("tir.IfThenElse", ("i", "j"))
    expected_struct = {
        loop_k: (None,),
        if_on_ij: (loop_k, loop_k),
        loop_j: (if_on_ij,),
        loop_i: (loop_j,),
    }
    verify_structure(hoisted_body, expected_struct)
def test_hoist_no_match_for():
    """With a store ahead of loop ``k`` in loop ``j``, the if is expected around loop ``k`` only."""
    builder = tvm.tir.ir_builder.create()
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    data = builder.pointer("float32", name="data")

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        # This store shares loop j's body with loop k.
        data[i * 3 + j] = data[i * 3 + j] + 0.5
        with builder.for_range(0, n, "k") as k:
            with builder.if_scope(builder.likely(i < 2)):
                builder.emit(_opaque_eval(m))
            with builder.else_scope():
                builder.emit(_opaque_eval(n))

    func = tvm.tir.PrimFunc([], builder.get())
    module = tvm.IRModule.from_expr(func)
    hoisted_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body

    loop_i, loop_j, loop_k = ("tir.For", "i"), ("tir.For", "j"), ("tir.For", "k")
    expected_struct = {
        loop_k: (None,),
        ("tir.IfThenElse", ("i",)): (loop_k, loop_k),
        loop_j: (None,),
        loop_i: (loop_j,),
    }
    verify_structure(hoisted_body, expected_struct)
def test_no_else():
    """An if without an else branch is expected directly inside loop ``i``, with no else child."""
    builder = tvm.tir.ir_builder.create()
    l, m, n = te.var("l"), te.var("m"), te.var("n")

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            with builder.if_scope(builder.likely(i < 2)):
                builder.emit(_opaque_eval(m))

    func = tvm.tir.PrimFunc([], builder.get())
    module = tvm.IRModule.from_expr(func)
    hoisted_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body

    loop_i, loop_j, loop_k = ("tir.For", "i"), ("tir.For", "j"), ("tir.For", "k")
    if_on_i = ("tir.IfThenElse", ("i",))
    expected_struct = {
        loop_k: (None,),
        loop_j: (loop_k,),
        if_on_i: (loop_j, None),
        loop_i: (if_on_i,),
    }
    verify_structure(hoisted_body, expected_struct)
def test_attr_stmt():
    """Under two ``thread_extent`` attributes, a condition on i/j is expected inside loop ``j``."""
    builder = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = builder.pointer("float32", name="data")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    builder.scope_attr(tx, "thread_extent", dshape[0])
    builder.scope_attr(bx, "thread_extent", dshape[1])

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            with builder.if_scope(tvm.tir.any(i < 4, j >= 8)):
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.5
            with builder.else_scope():
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0

    func = tvm.tir.PrimFunc([], builder.get())
    module = tvm.IRModule.from_expr(func)
    hoisted_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body

    loop_i, loop_j, loop_k = ("tir.For", "i"), ("tir.For", "j"), ("tir.For", "k")
    if_on_ij = ("tir.IfThenElse", ("i", "j"))
    extent_64 = ("tir.AttrStmt", "thread_extent", 64)
    extent_32 = ("tir.AttrStmt", "thread_extent", 32)
    expected_struct = {
        loop_k: (None,),
        if_on_ij: (loop_k, loop_k),
        loop_j: (if_on_ij,),
        loop_i: (loop_j,),
        extent_64: (loop_i,),
        extent_32: (extent_64,),
    }
    verify_structure(hoisted_body, expected_struct)
def test_nested_for():
    """Two nested ifs: the ``i`` one is expected around loop ``j``, the i/j one around loop ``k``."""
    builder = tvm.tir.ir_builder.create()
    data = builder.pointer("float32", name="data")

    with builder.for_range(0, 5, "i") as i, builder.for_range(0, 10, "j") as j:
        with builder.if_scope(i >= 3):
            data[i * 3 + j] = data[i * 3 + j] + 0.5
            with builder.for_range(0, 15, "k") as k, builder.for_range(0, 20, "l") as l:
                with builder.if_scope(tvm.tir.any(i < 4, j >= 8)):
                    data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
                with builder.else_scope():
                    data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5

    func = tvm.tir.PrimFunc([], builder.get())
    module = tvm.IRModule.from_expr(func)
    hoisted_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body

    loop_i, loop_j = ("tir.For", "i"), ("tir.For", "j")
    loop_k, loop_l = ("tir.For", "k"), ("tir.For", "l")
    if_on_i = ("tir.IfThenElse", ("i",))
    expected_struct = {
        loop_l: (None,),
        loop_k: (loop_l,),
        ("tir.IfThenElse", ("i", "j")): (loop_k, loop_k),
        loop_j: (None,),
        if_on_i: (loop_j, None),
        loop_i: (if_on_i,),
    }
    verify_structure(hoisted_body, expected_struct)
def test_if_block():
    """Check the structure produced for two consecutive loop nests with several if blocks.

    Both nests use loops named i/j/k, so their signatures collide in the
    expected mapping.
    """
    builder = tvm.tir.ir_builder.create()
    data = builder.pointer("float32", name="data")
    n = te.var("n")

    # First nest: an if on i holding a store and a k/l nest with two more ifs.
    with builder.for_range(0, 5, "i") as i, builder.for_range(0, 10, "j") as j:
        with builder.if_scope(i >= 3):
            data[i * 3 + j] = data[i * 3 + j] + 0.5
            with builder.for_range(0, 15, "k") as k, builder.for_range(0, 20, "l") as l:
                with builder.if_scope(tvm.tir.any(i < 4, j >= 8)):
                    data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
                with builder.else_scope():
                    data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5
                with builder.if_scope(j < 5):
                    data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1

    # Second nest: the condition only involves the free variable n.
    with builder.for_range(0, 5, "i") as i, builder.for_range(0, 10, "j") as j:
        with builder.for_range(0, 15, "k") as k:
            with builder.if_scope(n >= 3):
                data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6

    func = tvm.tir.PrimFunc([], builder.get())
    module = tvm.IRModule.from_expr(func)
    hoisted_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body

    loop_i, loop_j = ("tir.For", "i"), ("tir.For", "j")
    if_on_i = ("tir.IfThenElse", ("i",))
    expected_struct = {
        ("tir.IfThenElse", ("i", "j")): (None, None),
        ("tir.IfThenElse", ("j",)): (None, None),
        ("tir.For", "l"): (None,),
        ("tir.For", "k"): (None,),
        loop_j: (loop_j,),
        if_on_i: (loop_j, None),
        loop_i: (if_on_i,),
        ("tir.IfThenElse", ("n",)): (loop_j, None),
    }
    verify_structure(hoisted_body, expected_struct)
def test_multi_if():
    """Nested ifs on ``i`` and on ``j`` are each expected directly inside their own loop."""
    builder = tvm.tir.ir_builder.create()
    data = builder.pointer("float32", name="data")

    with builder.for_range(0, 10, "i") as i, builder.for_range(0, 10, "j") as j:
        with builder.for_range(0, 10, "k") as k:
            with builder.if_scope(3 <= i), builder.if_scope(3 <= j):
                data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5

    func = tvm.tir.PrimFunc([], builder.get())
    hoisted_mod = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))
    hoisted_body = hoisted_mod["main"].body

    loop_i, loop_j, loop_k = ("tir.For", "i"), ("tir.For", "j"), ("tir.For", "k")
    if_on_i = ("tir.IfThenElse", ("i",))
    if_on_j = ("tir.IfThenElse", ("j",))
    expected_struct = {
        loop_k: (None,),
        if_on_j: (loop_k, None),
        loop_j: (if_on_j,),
        if_on_i: (loop_j, None),
        loop_i: (if_on_i,),
    }
    verify_structure(hoisted_body, expected_struct)
def test_no_hoisting_1():
    """A condition on the innermost loop variable ``k`` must leave the body unchanged."""
    builder = tvm.tir.ir_builder.create()
    data = builder.pointer("float32", name="data")
    n = te.var("n")

    with builder.for_range(0, 10, "i") as i, builder.for_range(0, 10, "j") as j:
        with builder.for_range(0, 10, "k") as k:
            with builder.if_scope(k <= 3):
                data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))

    # Default configuration.
    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    # Same expectation with the block-scope option switched on.
    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(block_scope_body, original)
def test_no_hoisting_2():
    """An if followed by another store in the same loop body must leave the body unchanged."""
    builder = tvm.tir.ir_builder.create()
    data = builder.pointer("float32", name="data")
    n = te.var("n")
    x = te.var("x")

    with builder.for_range(0, 10, "i") as i, builder.for_range(0, 10, "j") as j:
        with builder.for_range(0, 10, "k") as k:
            with builder.if_scope(i <= 3):
                data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.3
            # Unconditional store that follows the if inside loop k.
            data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(block_scope_body, original)
@pytest.mark.xfail(reason="Inconsistent thread_extent", strict=True)
def test_no_hoisting_3():
    """Both thread axes get a second, different ``thread_extent`` inside loop ``k``.

    Marked as a strict expected failure.
    """
    builder = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    dshape_inner = (33, 63)
    data = builder.pointer("float32", name="data")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    builder.scope_attr(tx, "thread_extent", dshape[0])
    builder.scope_attr(bx, "thread_extent", dshape[1])

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            builder.scope_attr(tx, "thread_extent", dshape_inner[0])
            builder.scope_attr(bx, "thread_extent", dshape_inner[1])
            with builder.if_scope(tx < 3):
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
            with builder.else_scope():
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(block_scope_body, original)
def test_no_hoisting_4():
    """With ``tx``'s extent declared inside loop ``k``, an if on ``tx`` must stay in place."""
    builder = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    dshape_inner = (33, 63)
    data = builder.pointer("float32", name="data")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    builder.scope_attr(bx, "thread_extent", dshape[1])

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            builder.scope_attr(tx, "thread_extent", dshape_inner[0])
            with builder.if_scope(tx < 3):
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
            with builder.else_scope():
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(block_scope_body, original)
@pytest.mark.xfail(reason="Inconsistent thread_extent", strict=True)
def test_no_hoisting_5():
    """``bx`` gets a second extent inside loop ``j`` and ``tx`` a second one inside loop ``k``.

    Marked as a strict expected failure.
    """
    builder = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    dshape_inner = (33, 63)
    data = builder.pointer("float32", name="data")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    builder.scope_attr(tx, "thread_extent", dshape[0])
    builder.scope_attr(bx, "thread_extent", dshape[1])

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        builder.scope_attr(bx, "thread_extent", dshape_inner[1])
        with builder.for_range(0, n, "k") as k:
            builder.scope_attr(tx, "thread_extent", dshape_inner[0])
            with builder.if_scope(tx < 3):
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
            with builder.else_scope():
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(block_scope_body, original)
def test_no_hoisting_6():
    """A condition on ``tx + k`` (innermost loop variable) must leave the body unchanged."""
    builder = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = builder.pointer("float32", name="data")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    builder.scope_attr(tx, "thread_extent", dshape[0])
    builder.scope_attr(bx, "thread_extent", dshape[1])

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            with builder.if_scope((tx + k) < 3):
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
            with builder.else_scope():
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(block_scope_body, original)
def test_no_hoisting_7():
    """An if on ``tx + j`` around loop ``k``, which holds an if on ``tx + k``, must stay unchanged."""
    builder = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = builder.pointer("float32", name="data")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    builder.scope_attr(tx, "thread_extent", dshape[0])
    builder.scope_attr(bx, "thread_extent", dshape[1])

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.if_scope((tx + j) < 9), builder.for_range(0, n, "k") as k:
            with builder.if_scope((tx + k) < 3):
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(block_scope_body, original)
def test_hoisting_block_scope_1():
    """An rfactor'd, thread-bound reduction: unchanged by default, changed with the block-scope option."""
    n = te.size_var("n")
    m = te.size_var("m")
    A = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "k")
    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")

    sched = te.create_schedule(B.op)
    k_outer, k_inner = sched[B].split(B.op.reduce_axis[0], factor=16)
    BF = sched.rfactor(B, k_inner)
    x_outer, x_inner = sched[B].split(sched[B].op.axis[0], factor=32)
    sched[B.op].bind(x_outer, te.thread_axis("blockIdx.x"))
    sched[B.op].bind(x_inner, te.thread_axis("threadIdx.y"))
    sched[B].bind(sched[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))
    sched[BF].compute_at(sched[B], sched[B].op.reduce_axis[0])

    module = tvm.driver.build_module.schedule_to_module(sched, [A, B], "main", None)
    for cleanup in (tvm.tir.transform.Simplify(), tvm.tir.transform.RemoveNoOp()):
        module = cleanup(module)
    original = module["main"].body

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    assert not tvm.ir.structural_equal(block_scope_body, original)
def test_hoisting_block_scope_2():
    """An if on ``tx`` below ``bx``'s extent attribute: only the block-scope option changes the body."""
    builder = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    dshape_inner = (33, 63)
    data = builder.pointer("float32", name="data")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    # Only tx's extent is declared up here; bx's is declared inside loop k below.
    builder.scope_attr(tx, "thread_extent", dshape[0])

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            builder.scope_attr(bx, "thread_extent", dshape[1])
            with builder.if_scope(tx < 3):
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
            with builder.else_scope():
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], builder.get()))
    for cleanup in (tvm.tir.transform.Simplify(), tvm.tir.transform.RemoveNoOp()):
        module = cleanup(module)
    original = module["main"].body

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    assert not tvm.ir.structural_equal(block_scope_body, original)
@pytest.mark.xfail(reason="Inconsistent thread_extent", strict=True)
def test_hoisting_block_scope_3():
    """Both thread axes get a second, different ``thread_extent`` inside loop ``j``.

    Marked as a strict expected failure.
    """
    builder = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    dshape_inner = (33, 63)
    data = builder.pointer("float32", name="data")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    builder.scope_attr(tx, "thread_extent", dshape[0])
    builder.scope_attr(bx, "thread_extent", dshape[1])

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        builder.scope_attr(tx, "thread_extent", dshape_inner[0])
        builder.scope_attr(bx, "thread_extent", dshape_inner[1])
        with builder.for_range(0, n, "k") as k:
            with builder.if_scope(tx < 3):
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
            with builder.else_scope():
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    assert not tvm.ir.structural_equal(block_scope_body, original)
def test_hoisting_block_scope_4():
    """A parallel, vectorized schedule with pragmas: only the block-scope option changes the body."""
    nn = 1024
    n = tvm.runtime.convert(nn)
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    AA = te.compute((n,), lambda *i: A(*i), name="A")
    BB = te.compute((n,), lambda *i: B(*i), name="B")
    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name="T")
    C = te.compute(A.shape, lambda *i: T(*i), name="C")

    sched = te.create_schedule(C.op)
    x_outer, x_inner = sched[C].split(C.op.axis[0], factor=4)
    x_launch, x_parallel = sched[C].split(x_outer, factor=13)
    sched[C].parallel(x_parallel)
    sched[C].pragma(x_launch, "parallel_launch_point")
    sched[C].pragma(x_parallel, "parallel_stride_pattern")
    sched[C].pragma(x_parallel, "parallel_barrier_when_finish")
    sched[C].vectorize(x_inner)

    module = tvm.driver.build_module.schedule_to_module(sched, [A, B, C], "main", None)
    module = tvm.tir.transform.Simplify()(module)
    original = module["main"].body

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    assert not tvm.ir.structural_equal(block_scope_body, original)
def test_hoisting_block_scope_5():
    """A condition reading ``data[g]``: the default pass changes the body.

    Running the pass again on that result, with the block-scope option on,
    must leave it unchanged.
    """
    builder = tvm.tir.ir_builder.create()
    data = builder.pointer("float32", name="data", scope="global")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    g = te.var("g")

    builder.scope_attr(data, "storage_scope", "global")
    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            with builder.if_scope(data[g] < 3):
                data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 0.3
            with builder.else_scope():
                data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 1.3

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))
    first_pass_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    assert not tvm.ir.structural_equal(first_pass_body, original)

    # Feed the first result back in as the new input.
    rebuilt = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], first_pass_body))
    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        second_pass_body = tvm.tir.transform.HoistIfThenElse()(rebuilt)["main"].body
    tvm.ir.assert_structural_equal(second_pass_body, first_pass_body)
def test_hoisting_block_scope_6():
    """A condition on ``tx + n``: unchanged by default, changed with the block-scope option."""
    builder = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = builder.pointer("float32", name="data")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    builder.scope_attr(tx, "thread_extent", dshape[0])
    builder.scope_attr(bx, "thread_extent", dshape[1])

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            with builder.if_scope((tx + n) < 3):
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
            with builder.else_scope():
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    assert not tvm.ir.structural_equal(block_scope_body, original)
def test_hoisting_block_scope_7():
    """A condition on ``tx + i``: unchanged by default, changed with the block-scope option."""
    builder = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = builder.pointer("float32", name="data")
    l, m, n = te.var("l"), te.var("m"), te.var("n")
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    builder.scope_attr(tx, "thread_extent", dshape[0])
    builder.scope_attr(bx, "thread_extent", dshape[1])

    with builder.for_range(0, l, "i") as i, builder.for_range(0, m, "j") as j:
        with builder.for_range(0, n, "k") as k:
            with builder.if_scope((tx + i) < 3):
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
            with builder.else_scope():
                data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    original = builder.get()
    module = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], original))

    default_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    tvm.ir.assert_structural_equal(default_body, original)

    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    with tvm.transform.PassContext(config=block_scope_cfg):
        block_scope_body = tvm.tir.transform.HoistIfThenElse()(module)["main"].body
    assert not tvm.ir.structural_equal(block_scope_body, original)
@pytest.mark.skip()
def test_hoisting_op_conv():
    """Build a conv2d with and without the block-scope option and compare mean run times.

    The test is skipped unconditionally.
    """
    dtype = "float32"
    dshape = (1, 80, 73, 73)
    kshape = (192, 80, 3, 3)
    padding = (1, 1)
    groups = 1
    dilation = (1, 1)
    kernel_size = (3, 3)
    channels = 192
    scale = 1

    x = relay.var("x", shape=dshape, dtype=dtype)
    w = relay.var("w", shape=kshape, dtype=dtype)
    conv_kwargs = dict(
        padding=padding,
        dilation=dilation,
        groups=groups,
        channels=channels,
        kernel_size=kernel_size,
    )
    y = relay.nn.conv2d(x, w, **conv_kwargs)

    mod = tvm.IRModule()
    mod["main"] = relay.Function([x, w], y)
    mod = relay.transform.InferType()(mod)

    data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
    kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
    params = {"w": tvm.nd.array(kernel)}
    block_scope_cfg = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}

    for target, dev in enabled_targets():
        # Baseline build.
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build_module.build(mod, target=target, params=params)
            m = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
            x = np.random.uniform(size=dshape)
            data_tvm = tvm.nd.array(data)
            m.set_input("x", data_tvm)
            m.run()
            timer = m.module.time_evaluator("run", dev, number=300, repeat=3)
            t1 = np.array(timer(data_tvm).results) * 1000
            print("{} ms".format(t1.mean()))

        # Build with the block-scope option enabled.
        with tvm.transform.PassContext(opt_level=3, config=block_scope_cfg):
            lib = relay.build_module.build(mod, target=target, params=params)
            m = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
            x = np.random.uniform(size=dshape)
            data_tvm = tvm.nd.array(data)
            m.set_input("x", data_tvm)
            m.set_input(**params)
            m.run()
            timer = m.module.time_evaluator("run", dev, number=300, repeat=3)
            t2 = np.array(timer(data_tvm).results) * 1000
            print("{} ms".format(t2.mean()))

        tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1)
if __name__ == "__main__":
    # When this file is executed directly, delegate to tvm.testing.main().
    tvm.testing.main()