# blob: 482c2246654d2ff170dafc7f11f8fa43a2cee436 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, missing-docstring, too-many-statements
import tvm
from tvm import relay
def get_recursive_count_loop():
    """Build a module containing a self-recursive, inline-marked ``sum_up``.

    ``sum_up(i)`` returns ``i`` when ``i == 0`` and ``sum_up(i - 1) + i``
    otherwise. ``main`` forwards its single int32 scalar argument to it.

    Returns
    -------
    tuple
        ``(mod, sum_up)``: the ``tvm.IRModule`` and the ``relay.GlobalVar``
        bound to the recursive function.
    """
    module = tvm.IRModule({})
    counter = relay.GlobalVar("sum_up")
    n = relay.var("i", shape=[], dtype="int32")

    builder = relay.ScopeBuilder()
    with builder.if_scope(relay.equal(n, relay.const(0, dtype="int32"))):
        # Base case: return the argument itself (which is 0 here).
        builder.ret(n)
    with builder.else_scope():
        # Recursive case: call back into ``sum_up`` through its GlobalVar.
        decremented = relay.subtract(n, relay.const(1, dtype="int32"))
        recursive = relay.Call(counter, [decremented])
        builder.ret(relay.add(recursive, n))

    recursive_fn = relay.Function(
        [n], builder.get(), ret_type=relay.TensorType([], "int32")
    ).with_attr("Inline", tvm.tir.IntImm("int32", 1))
    module[counter] = recursive_fn

    entry_arg = relay.var("i", shape=[], dtype="int32")
    module["main"] = relay.Function([entry_arg], counter(entry_arg))
    return module, counter
def test_call_chain_inline_leaf():
    r"""Test when only the leaf call is inlined.

    The call graph is like the following::

            main
           /    \
          g1    g2
         /
      g11(inline)

    Only ``g11`` carries the ``Inline`` attribute. Its call inside ``g1`` is
    replaced by its body (the identity) and ``g11`` is removed from the
    module. ``g1``, ``g2`` and ``main`` are otherwise unchanged.
    """
    # NOTE: the docstring is raw so the trailing backslash in the diagram is
    # not treated as a line continuation.

    def get_mod():
        mod = tvm.IRModule({})
        x11 = relay.var("x11", shape=(3, 5))
        g11 = relay.GlobalVar("g11")
        fn11 = relay.Function([x11], x11)
        fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        mod[g11] = fn11

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1 + g11(x1))
        fn1 = relay.Function([x1, y1], sb.get())
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        # g11(x1) was the identity, so it collapses to x1.
        sb.ret(x1 + y1 + x1)
        fn1 = relay.Function([x1, y1], sb.get())
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True)
def test_call_chain_inline_multiple_levels():
    r"""Test inlining across multiple levels of a call chain.

    The call graph is like the following::

              main
             /    \
      g1(inline)   g2
           /
      g11(inline)

    Both ``g11`` and ``g1`` are marked ``Inline``. ``g11`` is folded into
    ``g1`` and ``g1`` into ``main``, so both are removed from the module and
    ``main`` computes ``p0 + p1 + p0`` directly.
    """

    def get_mod():
        mod = tvm.IRModule({})
        x11 = relay.var("x11", shape=(3, 5))
        g11 = relay.GlobalVar("g11")
        fn11 = relay.Function([x11], x11)
        fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        mod[g11] = fn11

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1 + g11(x1))
        fn1 = relay.Function([x1, y1], sb.get())
        fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        mod = tvm.IRModule({})
        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        # g1(p0, p1) -> p0 + p1 + g11(p0) -> p0 + p1 + p0
        call_fn1 = p0 + p1 + p0
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True)
def test_call_chain_inline_multiple_levels_extern_compiler():
    r"""Test multi-level inlining when the leaf targets an external compiler.

    The call graph is like the following::

              main
             /    \
      g1(inline)   g2
           /
      g11(inline, external compiler)

    ``g1`` is inlined into ``main``. ``g11`` also carries a ``Compiler``
    attribute, so its body is not spliced in. Instead, ``main`` calls the
    function literal directly (``fn11(p0)``), and the ``g11`` global is
    removed from the module.
    """

    def get_mod():
        mod = tvm.IRModule({})
        x11 = relay.var("x11", shape=(3, 5))
        g11 = relay.GlobalVar("g11")
        fn11 = relay.Function([x11], x11)
        fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        fn11 = fn11.with_attr("Compiler", "a")
        mod[g11] = fn11

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1 + g11(x1))
        fn1 = relay.Function([x1, y1], sb.get())
        fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        mod = tvm.IRModule({})
        # The externally compiled leaf survives as a function literal.
        x11 = relay.var("x11", shape=(3, 5))
        fn11 = relay.Function([x11], x11)
        fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        fn11 = fn11.with_attr("Compiler", "a")

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = p0 + p1 + fn11(p0)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True)
def test_recursive_call_with_global():
    """Inline a non-recursive global into a recursive, inline-marked function.

    ``sum_up`` calls itself, so it is kept as a global. The inline-marked
    identity ``gx`` that it calls is replaced by its argument, and ``gx`` is
    removed from the module.
    """

    def build_input():
        module = tvm.IRModule({})
        arg = relay.var("x", shape=[], dtype="int32")
        identity = relay.Function([arg], arg).with_attr(
            "Inline", tvm.tir.IntImm("int32", 1)
        )
        identity_gv = relay.GlobalVar("gx")
        module[identity_gv] = identity

        recursive_gv = relay.GlobalVar("sum_up")
        n = relay.var("i", shape=[], dtype="int32")
        builder = relay.ScopeBuilder()
        with builder.if_scope(relay.equal(n, relay.const(0, dtype="int32"))):
            builder.ret(n)
        with builder.else_scope():
            decremented = relay.subtract(n, relay.const(1, dtype="int32"))
            helper_call = identity_gv(n)
            partial = relay.Call(recursive_gv, [decremented]) + helper_call
            builder.ret(relay.add(partial, n))
        recursive_fn = relay.Function(
            [n], builder.get(), ret_type=relay.TensorType([], "int32")
        ).with_attr("Inline", tvm.tir.IntImm("int32", 1))
        module[recursive_gv] = recursive_fn

        entry_arg = relay.var("i", shape=[], dtype="int32")
        module["main"] = relay.Function([entry_arg], recursive_gv(entry_arg))
        return module

    def build_expected():
        module = tvm.IRModule({})
        recursive_gv = relay.GlobalVar("sum_up")
        n = relay.var("i", shape=[], dtype="int32")
        builder = relay.ScopeBuilder()
        with builder.if_scope(relay.equal(n, relay.const(0, dtype="int32"))):
            builder.ret(n)
        with builder.else_scope():
            decremented = relay.subtract(n, relay.const(1, dtype="int32"))
            # ``gx(i)`` has been replaced by its argument ``i``.
            partial = relay.Call(recursive_gv, [decremented]) + n
            builder.ret(relay.add(partial, n))
        recursive_fn = relay.Function(
            [n], builder.get(), ret_type=relay.TensorType([], "int32")
        ).with_attr("Inline", tvm.tir.IntImm("int32", 1))
        module[recursive_gv] = recursive_fn

        entry_arg = relay.var("i", shape=[], dtype="int32")
        module["main"] = relay.Function([entry_arg], recursive_gv(entry_arg))
        return module

    inlined = relay.transform.Inline()(build_input())
    tvm.ir.assert_structural_equal(inlined, build_expected(), map_free_vars=True)
def test_recursive_called():
    """A recursive inline-marked function reachable from ``main`` is kept.

    ``sum_up`` calls itself, so the pass must not inline it. The output
    should be structurally equal to an independently built copy of the
    input module.
    """
    # ``get_recursive_count_loop`` already defines ``main`` as ``sum_up(i)``,
    # so there is no need to rebuild it here.
    mod, _ = get_recursive_count_loop()
    # Build the reference separately instead of aliasing ``mod``. That way the
    # comparison still means something even if the pass ever updated its
    # input module in place.
    ref_mod, _ = get_recursive_count_loop()
    mod = relay.transform.Inline()(mod)
    tvm.ir.assert_structural_equal(mod, ref_mod, map_free_vars=True)
def test_recursive_not_called():
    """Inline into ``main`` while an unrelated recursive function is present.

    ``main`` is replaced so it no longer calls ``sum_up``. The inline-marked
    identity ``g1`` is folded into ``main`` and removed, and ``sum_up`` stays
    in the module.
    """

    def build_input():
        module, _ = get_recursive_count_loop()
        lhs = relay.var("x", shape=(2, 2))
        rhs = relay.var("y", shape=(2, 2))
        param = relay.var("x1", shape=(2, 2))
        identity = relay.Function([param], param).with_attr(
            "Inline", tvm.tir.IntImm("int32", 1)
        )
        identity_gv = relay.GlobalVar("g1")
        module[identity_gv] = identity
        module["main"] = relay.Function([lhs, rhs], lhs + rhs + identity_gv(lhs))
        return module

    def build_expected():
        module, _ = get_recursive_count_loop()
        lhs = relay.var("x", shape=(2, 2))
        rhs = relay.var("y", shape=(2, 2))
        module["main"] = relay.Function([lhs, rhs], lhs + rhs + lhs)
        return module

    result = relay.transform.Inline()(build_input())
    reference = build_expected()
    tvm.ir.assert_structural_equal(result, reference, map_free_vars=True)
def test_recursive_not_called_extern_compiler():
    """Like ``test_recursive_not_called``, but ``g1`` targets an external compiler.

    The call to ``g1`` in ``main`` becomes a direct call of the tagged
    function literal, ``g1`` is removed, and ``sum_up`` stays in the module.
    """

    def make_identity():
        # Identity over a (2, 2) input, tagged with Inline and Compiler="a".
        param = relay.var("x1", shape=(2, 2))
        fn = relay.Function([param], param)
        fn = fn.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        return fn.with_attr("Compiler", "a")

    def build_input():
        module, _ = get_recursive_count_loop()
        lhs = relay.var("x", shape=(2, 2))
        rhs = relay.var("y", shape=(2, 2))
        identity = make_identity()
        identity_gv = relay.GlobalVar("g1")
        module[identity_gv] = identity
        module["main"] = relay.Function([lhs, rhs], lhs + rhs + identity_gv(lhs))
        return module

    def build_expected():
        module, _ = get_recursive_count_loop()
        lhs = relay.var("x", shape=(2, 2))
        rhs = relay.var("y", shape=(2, 2))
        identity = make_identity()
        module["main"] = relay.Function([lhs, rhs], lhs + rhs + identity(lhs))
        return module

    result = relay.transform.Inline()(build_input())
    reference = build_expected()
    tvm.ir.assert_structural_equal(result, reference, map_free_vars=True)
def test_globalvar_as_call_arg():
    """Inline two inline-marked globals that are called from ``main``.

    Both ``g1`` and ``g2`` are inlined, so the resulting module should
    contain only ``main``, computing ``(p0 + p1) * (p2 - p3)``.
    """

    def get_mod():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1)
        fn1 = relay.Function([x1, y1], sb.get())
        fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        # Start from a fresh module. Without this line, ``mod`` resolves
        # through the closure to the enclosing test's already-inlined module.
        # The assertion below would then compare that module with itself and
        # could never fail.
        mod = tvm.IRModule({})
        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = p0 + p1
        call_fn2 = p2 - p3
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True)
def test_globalvar_as_call_arg_extern_compiler():
    """Inline-marked globals that target an external compiler keep their bodies.

    ``main`` ends up calling the tagged function literals directly, and the
    ``g1``/``g2`` globals are removed from the module.
    """

    def extern_binary(op, lhs_name, rhs_name, compiler):
        # Build op(lhs, rhs) over two (3, 5) inputs, tagged Inline + Compiler.
        lhs = relay.var(lhs_name, shape=(3, 5))
        rhs = relay.var(rhs_name, shape=(3, 5))
        builder = relay.ScopeBuilder()
        builder.ret(op(lhs, rhs))
        fn = relay.Function([lhs, rhs], builder.get())
        fn = fn.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        return fn.with_attr("Compiler", compiler)

    def make_params():
        return [relay.var(name, shape=(3, 5)) for name in ("p0", "p1", "p2", "p3")]

    def build_input():
        module = tvm.IRModule({})
        add_gv = relay.GlobalVar("g1")
        module[add_gv] = extern_binary(lambda a, b: a + b, "x1", "y1", "a")
        sub_gv = relay.GlobalVar("g2")
        module[sub_gv] = extern_binary(lambda a, b: a - b, "x2", "y2", "b")
        p0, p1, p2, p3 = params = make_params()
        module["main"] = relay.Function(params, add_gv(p0, p1) * sub_gv(p2, p3))
        return module

    def build_expected():
        module = tvm.IRModule({})
        add_fn = extern_binary(lambda a, b: a + b, "x1", "y1", "a")
        sub_fn = extern_binary(lambda a, b: a - b, "x2", "y2", "b")
        p0, p1, p2, p3 = params = make_params()
        product = relay.Call(add_fn, [p0, p1]) * relay.Call(sub_fn, [p2, p3])
        module["main"] = relay.Function(params, product)
        return module

    result = relay.transform.Inline()(build_input())
    tvm.ir.assert_structural_equal(result, build_expected(), map_free_vars=True)
def test_inline_globalvar_without_args():
    """Inline nullary globals that are used as values rather than called directly.

    ``main`` selects ``g1`` or ``g2`` with an ``If`` and calls the result.
    After inlining, the global references are replaced by the function
    literals themselves.
    """

    def const_fn(value):
        # Nullary function returning ``value``, marked for inlining.
        return relay.Function([], relay.const(value)).with_attr(
            "Inline", tvm.tir.IntImm("int32", 1)
        )

    def build_input():
        module = tvm.IRModule({})
        first, second = const_fn(1), const_fn(2)
        first_gv = relay.GlobalVar("g1")
        second_gv = relay.GlobalVar("g2")
        module[first_gv] = first
        module = relay.transform.InferType()(module)
        module[second_gv] = second
        cond = relay.var("p", "bool")
        selected = relay.If(cond, first_gv, second_gv)
        module["main"] = relay.Function([cond], relay.Call(selected, []))
        return relay.transform.InferType()(module)

    def build_expected():
        module = tvm.IRModule({})
        first, second = const_fn(1), const_fn(2)
        cond = relay.var("p", "bool")
        selected = relay.If(cond, first, second)
        module["main"] = relay.Function([cond], relay.Call(selected, []))
        return relay.transform.InferType()(module)

    result = relay.transform.Inline()(build_input())
    tvm.ir.assert_structural_equal(result, build_expected(), map_free_vars=True)
def test_inline_globalvar_without_args_extern_compiler():
    """Nullary externally compiled globals referenced through an ``If``.

    The ``g1``/``g2`` references inside ``main`` are replaced by the tagged
    function literals themselves.
    """

    def const_fn(value, compiler):
        # Nullary function returning ``value``, tagged Inline + Compiler.
        fn = relay.Function([], relay.const(value))
        fn = fn.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        return fn.with_attr("Compiler", compiler)

    def build_input():
        module = tvm.IRModule({})
        first, second = const_fn(1, "a"), const_fn(2, "b")
        first_gv = relay.GlobalVar("g1")
        second_gv = relay.GlobalVar("g2")
        module[first_gv] = first
        module[second_gv] = second
        cond = relay.var("p", "bool")
        selected = relay.If(cond, first_gv, second_gv)
        module["main"] = relay.Function([cond], relay.Call(selected, []))
        return module

    def build_expected():
        module = tvm.IRModule({})
        first, second = const_fn(1, "a"), const_fn(2, "b")
        cond = relay.var("p", "bool")
        selected = relay.If(cond, first, second)
        module["main"] = relay.Function([cond], relay.Call(selected, []))
        return module

    result = relay.transform.Inline()(build_input())
    tvm.ir.assert_structural_equal(result, build_expected(), map_free_vars=True)
def test_globalvar_called_by_multiple_functions():
    r"""Inline a global that has more than one caller.

    The call graph is like the following::

        main      g0
        /  \     /
      g1   g2(inline)

    ``g2`` is called from both ``main`` and ``g0``. Both call sites are
    replaced by ``x - y`` and ``g2`` is removed; ``g1`` and ``g0`` remain as
    globals.
    """
    # NOTE: the docstring is raw because the diagram contains "\ ", which is
    # an invalid escape sequence in a regular string literal.

    def get_mod():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1)
        fn1 = relay.Function([x1, y1], sb.get())
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        z0 = relay.var("z0", shape=(3, 5))
        fn0 = relay.Function([x0, y0, z0], g2(x0, y0) + z0)
        g0 = relay.GlobalVar("g0")
        mod[g0] = fn0

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1)
        fn1 = relay.Function([x1, y1], sb.get())
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn2 = p2 - p3
        mod["main"] = relay.Function([p0, p1, p2, p3], g1(p0, p1) * call_fn2)

        # g0's call to g2 is inlined as well.
        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        z0 = relay.var("z0", shape=(3, 5))
        fn0 = relay.Function([x0, y0, z0], x0 - y0 + z0)
        g0 = relay.GlobalVar("g0")
        mod[g0] = fn0
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True)
def test_entry_with_inline():
    """Inline-marked entry functions with no callers are left untouched.

    The call graph is like the following (no ``main``)::

        g1(inline)    g2(inline)

    Neither function is called, so the pass output must equal the input.
    """

    def build():
        module = tvm.IRModule({})
        specs = (
            ("g1", "x1", "y1", lambda a, b: a + b),
            ("g2", "x2", "y2", lambda a, b: a - b),
        )
        for gv_name, lhs_name, rhs_name, op in specs:
            lhs = relay.var(lhs_name, shape=(3, 5))
            rhs = relay.var(rhs_name, shape=(3, 5))
            fn = relay.Function([lhs, rhs], op(lhs, rhs))
            fn = fn.with_attr("Inline", tvm.tir.IntImm("int32", 1))
            module[relay.GlobalVar(gv_name)] = fn
        return module

    result = relay.transform.Inline()(build())
    tvm.ir.assert_structural_equal(result, build(), map_free_vars=True)
def test_callee_not_inline():
    """An inline-marked entry whose callee is not inline stays unchanged.

    The call graph is like the following (the module has no ``main``)::

        g2(inline)
            |
           g1

    ``g2`` has no callers and ``g1`` is not marked ``Inline``, so the pass
    output must equal the input.
    """

    def get_mod():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        fn1 = relay.Function([x1, y1], x1 + y1)
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        fn2 = relay.Function([x2, y2], x2 - g1(x2, y2))
        fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    tvm.ir.assert_structural_equal(mod, get_mod(), map_free_vars=True)
def test_callee_not_inline_leaf_inline():
    """Inline a leaf below a non-inline function under an uncalled inline entry.

    The call graph is like the following (the module has no ``main``)::

        g2(inline)
            |
           g1
            |
        g0(inline)

    ``g0`` is folded into ``g1`` and removed. ``g1`` is not marked
    ``Inline``, so ``g2``'s call to it is kept, and ``g2`` itself (which has
    no callers) stays in the module.
    """

    def get_mod():
        mod = tvm.IRModule({})
        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        fn0 = relay.Function([x0, y0], x0 * y0)
        fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g0 = relay.GlobalVar("g0")
        mod[g0] = fn0

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        fn1 = relay.Function([x1, y1], x1 + g0(x1, y1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        fn2 = relay.Function([x2, y2], x2 - g1(x2, y2))
        fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    def expected():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        # g0(x1, y1) is replaced by its body x1 * y1.
        fn1 = relay.Function([x1, y1], x1 + x1 * y1)
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        fn2 = relay.Function([x2, y2], x2 - g1(x2, y2))
        fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True)
def test_callee_not_inline_leaf_inline_extern_compiler():
    """Like ``test_callee_not_inline_leaf_inline``, but the leaf is external.

    The call graph is like the following (the module has no ``main``)::

        g2(inline)
            |
           g1
            |
        g0(inline, external compiler)

    ``g0`` carries a ``Compiler`` attribute, so its body is not spliced in.
    Instead, ``g1`` calls the function literal directly and the ``g0``
    global is removed. ``g2`` and its call to ``g1`` are unchanged.
    """

    def get_mod():
        mod = tvm.IRModule({})
        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        fn0 = relay.Function([x0, y0], x0 * y0)
        fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        fn0 = fn0.with_attr("Compiler", "aa")
        g0 = relay.GlobalVar("g0")
        mod[g0] = fn0

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        fn1 = relay.Function([x1, y1], x1 + g0(x1, y1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        fn2 = relay.Function([x2, y2], x2 - g1(x2, y2))
        fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    def expected():
        mod = tvm.IRModule({})
        # The externally compiled leaf survives as a function literal.
        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        fn0 = relay.Function([x0, y0], x0 * y0)
        fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        fn0 = fn0.with_attr("Compiler", "aa")

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        fn1 = relay.Function([x1, y1], x1 + fn0(x1, y1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        fn2 = relay.Function([x2, y2], x2 - g1(x2, y2))
        fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True)
if __name__ == "__main__":
    # ``tvm.testing`` is not imported anywhere else in this file, and
    # ``import tvm`` alone is not relied on to load that submodule, so import
    # it explicitly before use.
    import tvm.testing

    tvm.testing.main()