# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-locals,wrong-import-position,import-error
from __future__ import absolute_import
import os
import unittest
import logging
import tempfile
from mxnet import nd, sym
from mxnet.gluon import nn
from mxnet.contrib import onnx as onnx_mxnet
import mxnet as mx
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)


def _assert_sym_equal(lhs, rhs):
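    """Check that two symbols have identical input names and the same number of outputs."""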
    assert lhs.list_inputs() == rhs.list_inputs()  # input names must be identical
    assert len(lhs.list_outputs()) == len(rhs.list_outputs())  # number of outputs must be identical


def _force_list(output):
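    """Wrap a single NDArray in a list so single- and multi-output nets are handled uniformly."""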
    if isinstance(output, nd.NDArray):
        return [output]
    return list(output)


def _optional_group(symbols, group=False):
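    """Group a list of symbols into a single symbol when ``group`` is True, else return them as-is."""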
    if group:
        return sym.Group(symbols)
    else:
        return symbols


def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params=None):
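    """Export ``net`` to ONNX, re-import it as a symbol and as a Gluon block, and
    check that graph structure and forward outputs survive the round trip."""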
    net.initialize()
    data = nd.random.uniform(0, 1, (1, 1024))
    output = _force_list(net(data))  # first forward pass triggers deferred weight initialization
    net_sym = _optional_group(net(sym.Variable('data')), group_outputs)
    net_params = {name: param._reduce() for name, param in net.collect_params().items()}
    net_params.update(extra_params or {})
    with tempfile.TemporaryDirectory() as tmpdirname:
        onnx_file_path = os.path.join(tmpdirname, 'net.onnx')
        export_path = onnx_mxnet.export_model(
            sym=net_sym,
            params=net_params,
            input_shape=[shape_type(data.shape)],
            onnx_file_path=onnx_file_path)
        assert export_path == onnx_file_path
        # Try importing the model to symbol
        _assert_sym_equal(net_sym, onnx_mxnet.import_model(export_path)[0])
        # Try importing the model to gluon
        imported_net = onnx_mxnet.import_to_gluon(export_path, ctx=None)
        _assert_sym_equal(net_sym, _optional_group(imported_net(sym.Variable('data')), group_outputs))
        # Confirm network outputs are the same
        imported_net_output = _force_list(imported_net(data))
        for out, imp_out in zip(output, imported_net_output):
            mx.test_utils.assert_almost_equal(out.asnumpy(), imp_out.asnumpy())


class TestExport(unittest.TestCase):
""" Tests ONNX export.
"""
    def test_onnx_export_single_output(self):
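        """Export and re-import a simple two-layer MLP with a single output."""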
        net = nn.HybridSequential(prefix='single_output_net')
        with net.name_scope():
            net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
        _check_onnx_export(net)

    def test_onnx_export_multi_output(self):
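        """Export and re-import a network whose forward pass returns ten outputs."""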
        class MultiOutputBlock(nn.HybridBlock):
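            """Return the outputs of ten independent Dense layers applied to the same input."""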
            def __init__(self):
                super(MultiOutputBlock, self).__init__()
                with self.name_scope():
                    self.net = nn.HybridSequential()
                    for i in range(10):
                        self.net.add(nn.Dense(100 + i * 10, activation='relu'))

            def hybrid_forward(self, F, x):
                out = tuple(block(x) for block in self.net._children.values())
                return out

        net = MultiOutputBlock()
        assert len(sym.Group(net(sym.Variable('data'))).list_outputs()) == 10
        _check_onnx_export(net, group_outputs=True)

    def test_onnx_export_list_shape(self):
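        """Check that the input shape can be given as a list as well as a tuple."""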
        net = nn.HybridSequential(prefix='list_shape_net')
        with net.name_scope():
            net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
        _check_onnx_export(net, shape_type=list)

    def test_onnx_export_extra_params(self):
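        """Check that export succeeds when the params dict contains entries not used by the graph."""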
        net = nn.HybridSequential(prefix='extra_params_net')
        with net.name_scope():
            net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
        _check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])})


if __name__ == '__main__':
    unittest.main()