blob: 6374d20173b2a83e2852e6d549f52c37d8bfd1f2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import relay
from tvm.relay import Function, transform
from tvm.relay.testing import inception_v3
import numpy as np
import pytest
cpu_scope = tvm.target.VirtualDevice(tvm.cpu(), tvm.target.Target("llvm"))
metatable = {"VirtualDevice": [cpu_scope]}
core = tvm.IRModule()
core.import_from_std("core.rly")
def optimize_and_check(before_program, after_program, passes):
if isinstance(before_program, str):
before_program = tvm.relay.parse(before_program)
if isinstance(after_program, str):
after_program = tvm.relay.parse(after_program)
if not isinstance(passes, list):
passes = [passes]
optimize = tvm.transform.Sequential(passes)
optimized_program = optimize(before_program)
print("Actual:")
print(optimized_program)
print("Expected:")
print(after_program)
tvm.ir.assert_structural_equal(optimized_program, after_program, map_free_vars=True)
def test_dead_let():
before_program = """
#[version = "0.0.5"]
def @main(%z: int) {
let %x = 1;
%z
}
"""
after_program = """
#[version = "0.0.5"]
def @main(%z: int) {
%z
}
"""
optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
def test_one_live_let():
before_program = """
#[version = "0.0.5"]
def @main(%z: int) {
let %x = 1;
let %y = 2;
%x + %x
}
"""
after_program = """
#[version = "0.0.5"]
def @main(%z: int) {
let %x = 1;
%x + %x
}
"""
optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
def test_nested_let():
before_program = """
#[version = "0.0.5"]
def @main(%d: int, %b: int) {
let %a = %b;
let %c = %d;
%c
}
"""
after_program = """
#[version = "0.0.5"]
def @main(%d: int, %b: int) {
let %c = %d;
%c
}
"""
optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
def test_live_recursion():
before_program = """
#[version = "0.0.5"]
def @main() {
let %f = fn (%n: int, %data: int) -> int {
if (%n == 0) {
%data
} else {
%f(%n - 1, log(%data))
}
};
%f(2, 10000)
}
"""
after_program = """
#[version = "0.0.5"]
def @main() {
let %f = fn (%n: int, %data: int) -> int {
if (%n == 0) {
%data
} else {
%f(%n - 1, log(%data))
}
};
%f(2, 10000)
}
"""
optimize_and_check(
before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
)
def test_dead_recursion():
before_program = """
#[version = "0.0.5"]
def @main() {
let %f = fn (%n: int, %data: int) -> int {
if (%n == 0) {
%data
} else {
%f(%n - 1, log(%data))
}
};
()
}
"""
after_program = """
#[version = "0.0.5"]
def @main() {
()
}
"""
optimize_and_check(
before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
)
def test_add_with_let():
before_program = """
#[version = "0.0.5"]
def @main() {
(let %a = 1; 3) + 2
}
"""
after_program = """
#[version = "0.0.5"]
def @main() {
3 + 2
}
"""
optimize_and_check(
before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
)
def test_tuple_get_item():
before_program = """
#[version = "0.0.5"]
def @main() {
let %a = 100;
(1, 2, 3, 4).0
}
"""
after_program = """
#[version = "0.0.5"]
def @main() {
(1, 2, 3, 4).0
}
"""
optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
def test_inline_into_function():
"""Don't inline across function boundaries."""
before_program = """
#[version = "0.0.5"]
def @main() {
let %x = 1 + 1;
let %f = fn (%y: int) -> int {
let %z = %y + %y;
%x + %z
};
(%f(2), %f(3))
}
"""
after_program = """
#[version = "0.0.5"]
def @main() {
let %x = 1 + 1;
let %f = fn (%y: int) -> int {
%x + (%y + %y)
};
(%f(2), %f(3))
}
"""
optimize_and_check(
before_program, after_program, transform.DeadCodeElimination(inline_once=True)
)
def test_impure_op():
shape = np.array([64, 2])
metatable = {
"VirtualDevice": [cpu_scope],
"relay.Constant": [relay.const(shape, dtype="int64")],
}
"""Don't elide calls to side-effecting operators."""
before_program = tvm.relay.parse(
"""
#[version = "0.0.5"]
def @main() {
let %size: int64 = cast(1024, dtype="int64");
let %alignment: int64 = cast(64, dtype="int64");
let %x = memory.alloc_storage(%size, meta[relay.Constant][0], %alignment, virtual_device=meta[VirtualDevice][0]);
let %_ = memory.kill(%x);
0
}
""",
"from_string",
core,
metatable,
)
after_program = tvm.relay.parse(
"""
#[version = "0.0.5"]
def @main() {
%0 = memory.alloc_storage(cast(1024, dtype="int64"),
meta[relay.Constant][0],
cast(64, dtype="int64"),
virtual_device=meta[VirtualDevice][0]);
let %_ = memory.kill(%0);
0
}
""",
"from_string",
core,
metatable,
)
optimize_and_check(
before_program, after_program, transform.DeadCodeElimination(inline_once=True)
)
def test_impure_func():
shape = np.array([64, 2])
metatable = {
"VirtualDevice": [cpu_scope],
"relay.Constant": [relay.const(shape, dtype="int64")],
}
"""Don't elide calls to side-effecting functions."""
before_program = tvm.relay.parse(
"""
#[version = "0.0.5"]
def @f() -> int {
let %size: int64 = cast(1024, dtype="int64");
let %alignment: int64 = cast(64, dtype="int64");
let %x = memory.alloc_storage(%size, meta[relay.Constant][0], %alignment, virtual_device=meta[VirtualDevice][0]);
let %_ = memory.kill(%x);
0
}
def @main() -> int {
let %y = @f();
0
}
""",
"from_string",
core,
metatable,
)
after_program = tvm.relay.parse(
"""
#[version = "0.0.5"]
def @f() -> int {
%0 = memory.alloc_storage(cast(1024, dtype="int64"),
meta[relay.Constant][0],
cast(64, dtype="int64"),
virtual_device=meta[VirtualDevice][0]);
let %_ = memory.kill(%0);
0
}
def @main() -> int {
let %y = @f();
0
}
""",
"from_string",
core,
metatable,
)
optimize_and_check(
before_program, after_program, transform.DeadCodeElimination(inline_once=True)
)
def test_refs():
"""Don't elide expressions with reference create/read/write side effects"""
before_program = """
#[version = "0.0.5"]
def @f(%r) -> int {
let %v = ref_read(%r);
let %u = ref_write(%r, %v + 1);
%v
}
def @main() -> int {
let %r = ref(0);
let %y = @f(%r);
let %z = @f(%r);
%z
}
"""
after_program = before_program
optimize_and_check(
before_program,
after_program,
[transform.InferType(), transform.DeadCodeElimination(inline_once=True)],
)
def test_complexity():
mod = transform.InferType()(
tvm.IRModule.from_expr(inception_v3.get_net(1, 1000, (3, 299, 299), "float32"))
)
optimize_and_check(mod, mod, transform.DeadCodeElimination(inline_once=True))
if __name__ == "__main__":
tvm.testing.main()