blob: 2b35ae71f2d6d6267a72df41e5de646d658f473b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
from tvm import relay, testing, transform
from tvm.relay.frontend.common import StrAttrsDict, set_span
from relay.utils.tag_span import _set_span, _create_span, _verify_structural_equal_with_span
def test_key_is_present():
attrs = StrAttrsDict({"a": 1})
assert attrs.has_attr("a")
def test_key_is_not_present():
attrs = StrAttrsDict({"a": 1})
assert not attrs.has_attr("b")
class TestSetSpan:
def test_pass_ctx_switch(self):
def _res(should_fill):
if should_fill:
with testing.enable_span_filling():
return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
else:
with testing.disable_span_filling():
return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
disable = relay.var("x", shape=(1, 64, 56, 56))
enable = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
_verify_structural_equal_with_span(_res(False), disable)
_verify_structural_equal_with_span(_res(True), enable)
# Should tag all exprs without span, and stop when expr is span-tagged
def test_builtin_tuple(self):
def _res():
a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
return set_span(tuple([a, b]), "tuple")
def _golden():
a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
b = relay.const(np.zeros([1, 1, 1]), dtype="int64", span=_create_span("tuple"))
return tuple([a, b])
res_tuple, golden_tuple = _res(), _golden()
assert len(res_tuple) == len(golden_tuple)
for i in range(len(res_tuple)):
_verify_structural_equal_with_span(res_tuple[i], golden_tuple[i])
def test_builtin_list(self):
def _res():
a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
t = relay.Tuple([a, b])
t_a = relay.TupleGetItem(t, 0)
t_b = relay.TupleGetItem(t, 1)
return set_span([t_a, t_b], "list")
def _golden():
a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
b = relay.const(np.zeros([1, 1, 1]), dtype="int64", span=_create_span("list"))
t = relay.Tuple([a, b], span=_create_span("list"))
t_a = relay.TupleGetItem(t, 0, span=_create_span("list"))
t_b = relay.TupleGetItem(t, 1, span=_create_span("list"))
return [t_a, t_b]
res_list, golden_list = _res(), _golden()
assert len(res_list) == len(golden_list)
for i in range(len(res_list)):
_verify_structural_equal_with_span(res_list[i], golden_list[i])
def test_var(self):
x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
x_expected = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
_verify_structural_equal_with_span(x, x_expected)
def test_constant(self):
c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype="int64"), "const_c")
c_expected = relay.const(
np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("const_c")
)
_verify_structural_equal_with_span(c, c_expected)
def test_call(self):
def _res():
x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
y = set_span(
relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d"
)
return relay.Function([x], y)
def _golden():
x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("conv2d"))
y = _set_span(
relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d"
)
return relay.Function([x], y)
_verify_structural_equal_with_span(_res(), _golden())
def test_tuple(self):
def _res():
a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
b = relay.const(np.ones([1, 1, 1]), dtype="int64")
t = set_span(relay.Tuple([a, b]), "t")
return relay.Function([], t)
def _golden():
a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
b = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("t"))
t = relay.Tuple([a, b], span=_create_span("t"))
return relay.Function([], t)
_verify_structural_equal_with_span(_res(), _golden())
def test_tuple_getitem(self):
def _res():
a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
b = relay.const(np.ones([1, 1, 1]), dtype="int64")
t = relay.Tuple([a, b])
i = set_span(relay.TupleGetItem(t, 0), "i")
return relay.Function([], i)
def _golden():
a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
b = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("i"))
t = relay.Tuple([a, b], span=_create_span("i"))
i = relay.TupleGetItem(t, 0, span=_create_span("i"))
return relay.Function([], i)
_verify_structural_equal_with_span(_res(), _golden())
def test_let(self):
def _res():
x = set_span(relay.Var("x"), "x_var")
c_1 = relay.const(np.ones(10))
add = relay.add(x, x)
body = set_span(relay.Let(x, c_1, add), "let")
c_2 = set_span(relay.const(np.zeros(10)), "zeros")
y = set_span(relay.add(body, c_2), "add_2")
return relay.Function([x], y)
def _golden():
x = relay.Var("x", span=_create_span("x_var"))
c_1 = relay.const(np.ones(10), span=_create_span("let"))
add = _set_span(relay.add(x, x), "let")
body = relay.Let(x, c_1, add, span=_create_span("let"))
c_2 = relay.const(np.zeros(10), span=_create_span("zeros"))
y = _set_span(relay.add(body, c_2), "add_2")
return relay.Function([x], y)
_verify_structural_equal_with_span(_res(), _golden())
def test_if(self):
def _res():
x = set_span(relay.var("x", shape=[], dtype="float32"), "x_var")
y = set_span(relay.var("y", shape=[], dtype="float32"), "y_var")
eq = relay.equal(x, y)
true_branch = set_span(relay.add(x, y), "true_branch")
false_branch = relay.subtract(x, y)
ife = set_span(relay.If(eq, true_branch, false_branch), "if")
return relay.Function([x, y], ife)
def _golden():
x = relay.var("x", shape=[], dtype="float32", span=_create_span("x_var"))
y = relay.var("y", shape=[], dtype="float32", span=_create_span("y_var"))
eq = _set_span(relay.equal(x, y), "if")
true_branch = _set_span(relay.add(x, y), "true_branch")
false_branch = _set_span(relay.subtract(x, y), "if")
ife = relay.If(eq, true_branch, false_branch, span=_create_span("if"))
return relay.Function([x, y], ife)
_verify_structural_equal_with_span(_res(), _golden())
def test_fn(self):
def _res():
x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
f = set_span(relay.Function([x], y), "func")
return f
def _golden():
x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("func"))
y = _set_span(
relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "func"
)
f = relay.Function([x], y, span=_create_span("func"))
return f
_verify_structural_equal_with_span(_res(), _golden())
if __name__ == "__main__":
testing.main()