| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """ |
| BatchNorm without given mean and variance given testcases |
| ==================== |
| This is a test script to test fused_batch_norm operators |
| in TensorFlow frontend when mean and variance are not given. |
| """ |
| import tvm |
| import tvm.testing |
| import numpy as np |
| |
| try: |
| import tensorflow.compat.v1 as tf |
| |
| tf.disable_v2_behavior() |
| except ImportError: |
| import tensorflow as tf |
| from tvm import relay |
| from tensorflow.python.framework import graph_util |
| |
| |
def verify_fused_batch_norm(shape):
    """Compare TVM against TensorFlow for ``fused_batch_norm`` with no mean/variance.

    Builds a TF1 graph that applies ``tf.nn.fused_batch_norm`` to a float32
    placeholder of ``shape``. The graph uses random ``scale`` (alpha) and
    ``offset`` (beta) constants sized to the last axis and passes no
    ``mean``/``variance``. It then:

    * freezes the graph,
    * imports it through the Relay TensorFlow frontend, checking that span
      filling does not change the imported IR,
    * compiles it for each enabled target,
    * asserts the TVM output matches TensorFlow's within 1e-3.

    Parameters
    ----------
    shape : tuple of int
        Input shape. Its last axis sizes the scale/offset constants. The op is
        called with its default ``data_format``, so that axis is presumably
        the channel axis.
    """
    from tvm.contrib import graph_executor

    num_channels = shape[-1]
    g = tf.Graph()
    with g.as_default():
        input_tensor = tf.placeholder(tf.float32, shape=shape, name="input")
        alpha = tf.constant(np.random.rand(num_channels), dtype=tf.float32, name="alpha")
        beta = tf.constant(np.random.rand(num_channels), dtype=tf.float32, name="beta")
        # mean/variance are deliberately omitted: this is the case under test.
        bn = tf.nn.fused_batch_norm(x=input_tensor, offset=beta, scale=alpha, name="bn")
        out = tf.identity(bn[0], name="output")
    # Feed float32 explicitly so TF and TVM receive exactly the same values.
    data = np.random.rand(*shape).astype("float32")

    # The graph contains only constants (no variables), so no initializer is run.
    with tf.Session(graph=out.graph) as sess:
        tf_out = sess.run(out, feed_dict={input_tensor: data})
        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])

    for device in ["llvm"]:
        if not tvm.testing.device_enabled(device):
            print(f"Skip because {device} is not enabled")
            continue
        dev = tvm.device(device, 0)
        with tvm.testing.disable_span_filling():
            mod, params = relay.frontend.from_tensorflow(constant_graph, outputs=["output"])
        with tvm.testing.enable_span_filling():
            mod_with_span, _ = relay.frontend.from_tensorflow(constant_graph, outputs=["output"])
        # Span information must not alter the structure of the imported function.
        assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"])
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target=device, params=params)
        # Use the executor factory directly instead of the deprecated
        # ``graph, lib, params = relay.build(...)`` unpacking. The factory's
        # "default" executor is created with the bound params already set.
        m = graph_executor.GraphModule(lib["default"](dev))
        m.set_input("input", data)
        m.run()
        tvm_out = m.get_output(0)
        tvm.testing.assert_allclose(
            tvm_out.numpy(), tf_out.astype(tvm_out.dtype), atol=1e-3, rtol=1e-3
        )
| |
| |
def test_fused_batch_norm():
    """Check fused_batch_norm (no mean/variance) over several input shapes."""
    # First vary spatial size and channel count at batch 1, then grow the batch.
    cases = (
        (1, 12, 12, 32),
        (1, 24, 24, 64),
        (1, 64, 64, 128),
        (8, 12, 12, 32),
        (16, 12, 12, 32),
        (32, 12, 12, 32),
    )
    for case in cases:
        verify_fused_batch_norm(shape=case)
| |
| |
if __name__ == "__main__":
    # Run this file through pytest via TVM's test entry point, so every test_*
    # function is collected and the process exit status reflects failures.
    tvm.testing.main()