blob: f9bf00e4239ef3cd42dc3def665b165661d39795 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
# pylint: disable=import-outside-toplevel, redefined-builtin
"""TF2 to relay converter test utilities"""
import tvm
from tvm import relay
from tvm.runtime.vm import VirtualMachine
import tvm.contrib.graph_executor as runtime
from tvm.relay.frontend.tensorflow2 import from_tensorflow
import tvm.testing
from tvm.relay.testing.tf import vmobj_to_list as vmobj_to_list
import tensorflow as tf
from tensorflow.python.eager.def_function import Function
def run_tf_code(func, input_):
    """Execute a TF2 callable and normalize its outputs to numpy.

    Parameters
    ----------
    func : tf.function-wrapped callable or plain Python callable
        If *func* is a ``tf.function`` (``Function``), *input_* is passed
        through unchanged; otherwise it is wrapped in ``tf.constant`` first.
    input_ : array-like
        A single input value fed to *func*.

    Returns
    -------
    list of numpy.ndarray, or a single numpy.ndarray
        Multi-output results (list/tuple/dict) are converted to a list of
        arrays; dict outputs are ordered by sorted key for determinism.
        NOTE: the bare-tensor path of a plain callable returns a single
        ndarray (not a list) — kept for backward compatibility with callers.
    """
    # isinstance (rather than `type(...) is ...`) also accepts subclasses.
    if isinstance(func, Function):
        f_out = func(input_)
        if isinstance(f_out, (list, tuple)):
            np_out = [x.numpy() for x in f_out]
        else:
            np_out = [f_out.numpy()]
    else:
        f_out = func(tf.constant(input_))
        if isinstance(f_out, dict):
            # Sort keys so the output order is deterministic across runs.
            np_out = [f_out[k].numpy() for k in sorted(f_out.keys())]
        elif isinstance(f_out, (list, tuple)):
            # Tuples previously fell through to `.numpy()` and crashed;
            # treat them like lists, matching the Function branch above.
            np_out = [x.numpy() for x in f_out]
        else:
            np_out = f_out.numpy()
    return np_out
def compile_graph_executor(mod, params, target="llvm", target_host="llvm", opt_level=3):
    """Build a Relay module into a graph-executor runtime library.

    Parameters
    ----------
    mod : tvm.IRModule
        The Relay module to build.
    params : dict
        Constant parameters bound into the module.
    target, target_host : str
        TVM target and host target strings (default "llvm").
    opt_level : int
        Relay optimization level used for the build (default 3).

    Returns
    -------
    The compiled runtime library suitable for GraphModule.
    """
    tvm_target = tvm.target.Target(target, host=target_host)
    with tvm.transform.PassContext(opt_level):
        return relay.build(mod, target=tvm_target, params=params)
def compile_vm(mod, params, target="llvm", target_host="llvm", opt_level=3, disabled_pass=None):
    """Compile a Relay module into a VirtualMachine executable.

    Parameters
    ----------
    mod : tvm.IRModule
        The Relay module to compile.
    params : dict
        Constant parameters bound into the module.
    target, target_host : str
        TVM target and host target strings (default "llvm").
    opt_level : int
        Relay optimization level (default 3).
    disabled_pass : list, optional
        Names of passes to disable during compilation.

    Returns
    -------
    The VM executable produced by ``relay.vm.compile``.
    """
    tvm_target = tvm.target.Target(target, host=target_host)
    with tvm.transform.PassContext(opt_level, disabled_pass=disabled_pass):
        return relay.vm.compile(mod, target=tvm_target, params=params)
def run_vm(vm_exec, input_, ctx=tvm.cpu(0)):
    """Run a compiled VM executable on one input and return numpy outputs.

    Invokes the executable's "main" function with *input_* on device *ctx*
    and flattens the resulting TVM runtime object(s) via ``vmobj_to_list``.
    """
    machine = VirtualMachine(vm_exec, ctx)
    result = machine.invoke("main", input_)
    return vmobj_to_list(result)
def run_graph_executor(lib, input_, ctx=tvm.cpu(0)):
    """Run a built graph-executor library on one input.

    Binds *input_* to input slot 0, executes the graph on device *ctx*,
    and returns every output as a numpy array.
    """
    gmod = runtime.GraphModule(lib["default"](ctx))
    gmod.set_input(0, input_)
    gmod.run()
    num_outputs = gmod.get_num_outputs()
    return [gmod.get_output(idx).numpy() for idx in range(num_outputs)]
def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None):
    """compare tf and tvm execution for the same input.
    Parameters
    ----------
    gdef: TF2 graph def extracted to be fed into from_tensorflow parser.
    (https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
    input_: a single numpy array object
    output_: the expected output from TF to match TVM output with
    runtime: choose TVM runtime; either "vm" for VirtualMachine or "graph" for GraphExecutor
    output_tensors : List of output tensor names (Optional)
        if not specified then the last node is assumed as graph output.
    """
    mod, params = from_tensorflow(gdef, outputs=output_tensors)
    # Dispatch on the requested runtime, then compare against the TF result.
    if runtime == "vm":
        tvm_out = run_vm(compile_vm(mod, params), input_)
    elif runtime == "graph":
        tvm_out = run_graph_executor(compile_graph_executor(mod, params), input_)
    else:
        raise RuntimeError("Runtime input not supported: %s" % runtime)
    tvm.testing.assert_allclose(output_, tvm_out, atol=1e-5)