| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """Unit tests for converting TensorFlow control flow op to Relay.""" |
| import pytest |
| |
| try: |
| import tensorflow.compat.v1 as tf |
| |
| tf.disable_v2_behavior() |
| except ImportError: |
| import tensorflow as tf |
| from tensorflow.python.ops import control_flow_ops |
| import numpy as np |
| from tvm import nd, relay, ir, testing |
| from tvm.relay.frontend.tensorflow import from_tensorflow |
| |
| |
| def check_equal(graph, tf_out, input_map=None): |
| with testing.disable_span_filling(): |
| mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) |
| with testing.enable_span_filling(): |
| mod_with_span, _ = from_tensorflow(graph.as_graph_def(add_shapes=True)) |
| assert ir.structural_equal(mod["main"], mod_with_span["main"]) |
| |
| if input_map is not None: |
| params.update(input_map) |
| relay_out = relay.create_executor("vm", mod=mod).evaluate()(**params) |
| if isinstance(relay_out, nd.NDArray): |
| np.testing.assert_allclose(tf_out, relay_out.numpy()) |
| else: |
| if not isinstance(tf_out, (list, tuple)): |
| tf_out = [tf_out] |
| for x, y in zip(tf_out, [r.numpy() for r in relay_out]): |
| np.testing.assert_allclose(x, y) |
| |
| |
| def test_vanilla_loop(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| i = tf.constant(0, name="while/constant") |
| |
| def c(i): |
| return tf.less(i, 10) |
| |
| def b(i): |
| return tf.add(i, 1) |
| |
| r = tf.while_loop(c, b, [i]) |
| |
| with tf.Session() as sess: |
| tf_out = sess.run(r) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_callnode_loop_vars(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| i = tf.add(tf.constant(0), 1) |
| |
| def c(i): |
| return tf.less(i, 10) |
| |
| def b(i): |
| return tf.add(i, 1) |
| |
| r = tf.while_loop(c, b, [i]) |
| |
| with tf.Session() as sess: |
| tf_out = sess.run(r) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_loop_2_vars(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| i0 = tf.constant(0) |
| j0 = tf.ones([2, 2]) |
| |
| def c(i, j): |
| return i < 10 |
| |
| def b(i, j): |
| return [tf.add(i, 1), j] |
| |
| i1, i2 = tf.while_loop(c, b, loop_vars=[i0, j0]) |
| i1 += tf.constant(1337) |
| |
| with tf.Session() as sess: |
| tf_out = sess.run(i1) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_loop_3_vars(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| i0 = tf.constant(1) |
| j0 = tf.constant(2) |
| k0 = tf.constant(4) |
| |
| def c(i, j, k): |
| return i < 10 |
| |
| def b(i, j, k): |
| return [i + 1, j * k, k + i] |
| |
| r = tf.while_loop(c, b, loop_vars=[i0, j0, k0]) |
| |
| with tf.Session() as sess: |
| tf_out = sess.run(r) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_loop_conditions(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| i = tf.constant(1) |
| j = tf.constant(1) |
| k = tf.constant(5) |
| |
| def c(i, j, k): |
| return tf.equal( |
| tf.not_equal(tf.less(i + j, 10), tf.less(j * k, 100)), tf.greater_equal(k, i + j) |
| ) |
| |
| def b(i, j, k): |
| return [i + j, j + k, k + 1] |
| |
| r = tf.while_loop(c, b, loop_vars=[i, j, k]) |
| with tf.Session() as sess: |
| tf_out = sess.run(r) |
| |
| check_equal(graph, tf_out) |
| |
| |
| @pytest.mark.skip |
| def test_loop_bodies(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| |
| def body(x): |
| a = tf.constant(np.array([[5, 6], [7, 8]]), dtype=tf.int32) |
| b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32) |
| c = a + b |
| return tf.nn.relu(x + c) |
| |
| def condition(x): |
| return tf.reduce_sum(x) < 100 |
| |
| x = tf.constant(0, shape=[2, 2]) |
| r = tf.while_loop(condition, body, [x]) |
| with tf.Session() as sess: |
| tf_out = sess.run(r) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_nested_loop(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| |
| def body(x): |
| def nest_body(c): |
| return tf.multiply(c, 2) |
| |
| def cd(c): |
| return tf.less(c, 10) |
| |
| c = tf.constant(2) |
| res = tf.while_loop(cd, nest_body, loop_vars=[c]) |
| return tf.nn.relu(x + res) |
| |
| def condition(x): |
| return tf.greater(x, 100) |
| |
| x = tf.constant(3) |
| r = tf.while_loop(condition, body, loop_vars=[x]) |
| |
| with tf.Session() as sess: |
| tf_out = sess.run(r) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_vanilla_cond(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| i = tf.constant(1) |
| j = tf.constant(4) |
| |
| def f1(): |
| return tf.multiply(1, 17) |
| |
| def f2(): |
| return tf.add(4, 23) |
| |
| r = tf.cond(tf.less(i, j), f1, f2) |
| |
| with tf.Session(graph=graph) as sess: |
| tf_out = sess.run(r) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_multiple_cond_vars(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| x1 = tf.constant(7) |
| x2 = tf.constant(12) |
| z = tf.constant(20) |
| r = tf.cond(tf.less(tf.add(x1, x2), 10), lambda: tf.add(10, 2), lambda: tf.square(5)) |
| |
| with tf.Session() as sess: |
| tf_out = sess.run(r) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_cond_fn_parameters(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| |
| def fn1(x, y): |
| return tf.multiply(5, 6) |
| |
| def fn2(x, y): |
| return tf.add(3, 4) |
| |
| i = tf.constant(1) |
| j = tf.constant(2) |
| k = tf.constant(3) |
| r = tf.cond(tf.less(i, j), lambda: fn1(i, k), lambda: fn2(j, k)) |
| |
| with tf.Session() as sess: |
| tf_out = sess.run(r, feed_dict={i: 1, j: 2, k: 3}) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_nested_cond(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| |
| def fn1(a, b): |
| def nest_fn1(): |
| return tf.add(1, 2) |
| |
| def nest_fn2(): |
| return tf.subtract(10, 5) |
| |
| res = tf.cond(tf.less(1, 2), nest_fn1, nest_fn2) |
| return tf.multiply(tf.add(87, res), 10) |
| |
| def fn2(a, b): |
| return tf.add(10, 10) |
| |
| x = tf.constant(5) |
| y = tf.constant(6) |
| z = tf.constant(7) |
| pred = tf.less(x, y) |
| r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z)) |
| |
| with tf.Session() as sess: |
| tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True}) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_loop_in_cond(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| |
| def fn1(a, b): |
| i = tf.constant(0) |
| |
| def cd(i): |
| return tf.less(i, 10) |
| |
| def bd(i): |
| return tf.add(i, 1) |
| |
| res = tf.while_loop(cd, bd, [i]) |
| return tf.multiply(tf.add(20, res), 10) |
| |
| def fn2(a, b): |
| return tf.add(10, 20) |
| |
| x = tf.constant(7) |
| y = tf.constant(20) |
| z = tf.constant(10) |
| pred = tf.less(x, y) |
| r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z)) |
| |
| with tf.Session() as sess: |
| tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True}) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_cond_in_loop(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| |
| def body(x): |
| x = tf.constant(7) |
| z = tf.constant(20) |
| res = tf.cond(tf.less(x, 10), lambda: tf.add(10, 20), lambda: tf.square(10)) |
| return tf.multiply(res, x) |
| |
| x = tf.constant(21) |
| |
| def condition(x): |
| return tf.less(x, 100) |
| |
| r = tf.while_loop(condition, body, loop_vars=[x]) |
| with tf.Session() as sess: |
| tf_out = sess.run(r) |
| |
| check_equal(graph, tf_out) |
| |
| |
| def test_vanilla_loop_bound(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| dshape = (2, 10) |
| dtype = "float32" |
| dname = "data" |
| np_data = np.random.uniform(size=dshape).astype(dtype) |
| data = tf.placeholder(shape=dshape, dtype=dtype, name=dname) |
| x = tf.slice(data, [1, 4], [1, 4]) |
| outer = x + 5.0 |
| |
| def body(x, y): |
| res = tf.cond(tf.less(y, 10), lambda: tf.add(10.0, 20.0), lambda: tf.square(10.0)) |
| z = tf.constant(7) |
| res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10) |
| return tf.multiply(res, x * outer), y + 1 |
| |
| y = tf.constant(0) |
| |
| def condition(x, y): |
| return tf.less(y, 20) |
| |
| r = tf.while_loop(condition, body, loop_vars=[x, y]) |
| with tf.Session() as sess: |
| tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data}) |
| |
| check_equal(graph, tf_out, {dname: np_data}) |
| |
| |
| def test_nested_loop_bound(): |
| graph = tf.Graph() |
| with graph.as_default(): |
| dshape = (2, 10) |
| dtype = "float32" |
| dname = "data" |
| np_data = np.random.uniform(size=dshape).astype(dtype) |
| data = tf.placeholder(shape=dshape, dtype=dtype, name=dname) |
| x = tf.slice(data, [1, 4], [1, 4]) |
| outer = x + 5.0 |
| |
| def body(x, y): |
| res = tf.cond(tf.less(y, 10), lambda: tf.add(10.0, 20.0), lambda: tf.square(10.0)) |
| |
| def nested_body(nx, ny): |
| return nx + 1, res + 2.0 |
| |
| def nested_cond(nx, ny): |
| return tf.less(nx, 15) |
| |
| nx = tf.constant(0) |
| ny = tf.constant(0.0) |
| nested_res = tf.while_loop(nested_cond, nested_body, loop_vars=[nx, ny]) |
| res = res + nested_res[1] |
| z = tf.constant(7) |
| res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10) |
| return tf.multiply(res, x * outer), y + 1 |
| |
| y = tf.constant(0) |
| |
| def condition(x, y): |
| return tf.less(y, 20) |
| |
| r = tf.while_loop(condition, body, loop_vars=[x, y]) |
| with tf.Session() as sess: |
| tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data}) |
| |
| check_equal(graph, tf_out, {dname: np_data}) |
| |
| |
| def test_switch(): |
| graph = tf.Graph() |
| |
| with graph.as_default(): |
| data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype("float32") |
| dname = "data" |
| flag_name = "flag" |
| data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname) |
| split = tf.split(data, 2, axis=0) |
| flag = tf.placeholder(shape={}, dtype=tf.bool, name=flag_name) |
| output_false, output_true = control_flow_ops.switch(split[1], flag) |
| with tf.Session() as sess: |
| tf_out = sess.run(output_false, feed_dict={data.name: data_np, flag.name: False}) |
| |
| check_equal(graph, tf_out, {dname: data_np, flag_name: False}) |
| |
| |
| def test_loop_tuple_input(): |
| graph = tf.Graph() |
| |
| with graph.as_default(): |
| data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype("float32") |
| dname = "data" |
| data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname) |
| split = tf.split(data, 2, axis=0) |
| |
| def body(x, y): |
| return x + 2, y + 1 |
| |
| start = tf.constant(0) |
| |
| def condition(x, y): |
| return tf.less(y, 20) |
| |
| r = tf.while_loop(condition, body, loop_vars=[split[1], start]) |
| with tf.Session() as sess: |
| tf_out = sess.run(r, feed_dict={data.name: data_np}) |
| |
| check_equal(graph, tf_out, {dname: data_np}) |
| |
| |
| if __name__ == "__main__": |
| # tf.while_loop |
| test_vanilla_loop() |
| test_loop_2_vars() |
| test_loop_3_vars() |
| test_loop_conditions() |
| # TODO(@jroesch): Need to fix memory alloc to support closure |
| # test_loop_bodies() |
| test_callnode_loop_vars() |
| |
| # tf.cond |
| test_vanilla_cond() |
| test_multiple_cond_vars() |
| test_cond_fn_parameters() |
| |
| # nested cases |
| test_nested_loop() |
| test_nested_cond() |
| test_loop_in_cond() |
| test_cond_in_loop() |
| test_vanilla_loop_bound() |
| test_nested_loop_bound() |
| |
| test_switch() |
| test_loop_tuple_input() |