# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay."""
try:
    import tensorflow.compat.v1 as tf

    tf.disable_v2_behavior()
except ImportError:
    import tensorflow as tf
import numpy as np
from tvm import relay, ir, testing
from tvm.relay.frontend.tensorflow import from_tensorflow


def run_relay(graph, shape_dict=None, *vars):
    """Convert the TF graph to Relay twice (with span filling disabled and
    enabled), check that the two modules are structurally equal, and run the
    result on the debug executor with the given inputs."""
    with testing.disable_span_filling():
        mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True), shape=shape_dict)
    with testing.enable_span_filling():
        mod_with_span, _ = relay.frontend.from_tensorflow(
            graph.as_graph_def(add_shapes=True), shape=shape_dict
        )
    assert ir.structural_equal(mod["main"], mod_with_span["main"])
    return relay.create_executor("debug", mod=mod).evaluate()(*vars)


def test_assert_true():
    g = tf.Graph()
    shape = (1, 2)
    with g.as_default():
        x = tf.placeholder(tf.float32, shape=shape, name="input")
        assert_op = tf.Assert(tf.reduce_all(tf.less_equal(x, x)), ["it failed"])

        with tf.Session() as sess:
            x_value = np.random.rand(*shape)
            assert sess.run(assert_op, feed_dict={x: x_value}) is None

        # In TVM, tf.Assert is converted to a no-op that actually evaluates to 0,
        # though it should probably be None or an empty tuple.
        #
        # TODO: It appears that the frontend converter gets confused here and
        # entirely eliminates all operands from main(), likely because x <= x
        # is always true, so the placeholder can be eliminated. But TF doesn't
        # do that; it happens in Relay, and that optimization should not affect
        # the arity of the main function. We should still have to pass in
        # x_value here.
        np.testing.assert_allclose(0, run_relay(g, {"input": shape}).numpy())


def test_assert_true_var_capture():
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float32, shape=())

        # It turns out that tf.Assert creates a large and complex subgraph if
        # you capture a variable as part of the error message, so we need to
        # test that, too.
        assert_op = tf.Assert(tf.less_equal(x, x), ["it failed", x])

        with tf.Session() as sess:
            x_value = np.random.rand()
            assert sess.run(assert_op, feed_dict={x: x_value}) is None

        # TODO: The frontend converter reports the output of the graph as a
        # boolean, which is not correct: as you can see above, TF believes that
        # the value of this graph is None.
        np.testing.assert_allclose(True, run_relay(g, None, x_value).numpy())
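        # A rough way to observe the extra subgraph that capturing x in the
        # Assert message creates (a sketch, not exercised by this test): print
        # the node count of the graph def and compare it against a graph whose
        # Assert message does not capture x.
        #
        #   print(len(g.as_graph_def().node))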


def test_assert_false():
    g = tf.Graph()
    with g.as_default():
        assert_op = tf.Assert(tf.constant(False), ["it failed"])

        with tf.Session() as sess:
            try:
                print(sess.run(assert_op))
                assert False  # TF should have thrown an exception
            except tf.errors.InvalidArgumentError as e:
                assert "it failed" in e.message

        # In TVM, tf.Assert is converted to a no-op that actually evaluates to 0,
        # though it should probably be None or an empty tuple. For the same
        # reason, there should not be an error here, even though the assertion
        # argument is false.
        np.testing.assert_allclose(0, run_relay(g).numpy())


if __name__ == "__main__":
    test_assert_true()
    test_assert_true_var_capture()
    test_assert_false()