blob: 8eaa303a6c1b0b5a4ebf976c13ae14b34358dfb7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""ONNX test backend wrapper"""
try:
import onnx.backend.test
except ImportError:
raise ImportError("Onnx and protobuf need to be installed")
import test_cases
import unittest
import backend as mxnet_backend
import logging
# Conversion directions exercised against each backend (see prepare_tests).
operations = ['import', 'export']
# Backend flavours passed to MXNetBackend.set_params below.
backends = ['mxnet', 'gluon']
# pytest magic variable: names extra plugins pytest should load for this
# module. Written as an explicit one-element tuple (the same value as a bare
# trailing-comma expression).
pytest_plugins = ("onnx.backend.test.report",)
def test_suite(backend_tests):  # type: (onnx.backend.test.BackendTest) -> unittest.TestSuite
    '''
    Build a TestSuite that can be run by a unittest TestRunner.

    Adapted from onnx/onnx/backend/test/runner/__init__.py. Upstream
    ``Runner.test_suite()`` iterates ``sorted(self.test_cases.values())``,
    which fails on Python 3 because TestCase classes (objects of type 'type')
    cannot be ordered. Here the cases are ordered by their mapping key (the
    test-case name), which keeps the suite order deterministic.

    :param backend_tests: object exposing a ``test_cases`` mapping of
        name -> ``unittest.TestCase`` subclass, e.g. the result of
        ``BackendTest.enable_report()``.
    :return: unittest.TestSuite holding every test method of every case.
    '''
    suite = unittest.TestSuite()
    ordered_cases = sorted(backend_tests.test_cases.items(),
                           key=lambda item: item[0])
    for _, case in ordered_cases:
        suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(case))
    return suite


# This helper's name starts with ``test_`` and it takes a required argument.
# Because the module is loaded by pytest (see ``pytest_plugins``), pytest would
# otherwise collect it as a test and error with
# "fixture 'backend_tests' not found".
test_suite.__test__ = False
def prepare_tests(backend, oper):
    """
    Build the ONNX BackendTest collection for one backend and direction.

    For each of the operator, basic-model and standard-model registries in
    ``test_cases``, the patterns listed under the shared ``'both'`` key and
    then those under the ``oper`` key are passed to ``include``. Afterwards
    ``exclude('.*scalar.*')`` is applied.

    :param backend: mxnet/gluon backend
    :param oper: str. export or import
    :return: backend test list
    """
    backend_test = onnx.backend.test.BackendTest(backend, __name__)
    registries = (
        test_cases.IMPLEMENTED_OPERATORS_TEST,
        test_cases.BASIC_MODEL_TESTS,
        test_cases.STANDARD_MODEL,
    )
    for registry in registries:
        for pattern in registry.get('both', []) + registry.get(oper, []):
            backend_test.include(pattern)
    # Tests for scalar ops are in test_node.py
    backend_test.exclude('.*scalar.*')
    return backend_test
# Run the ONNX backend tests for every (backend, operation) combination.
# NOTE(review): this executes at import time, as before. Results are now
# recorded so that a direct script run can report failure through its exit
# status instead of always exiting 0.
_FAILED_RUNS = []
for bkend in backends:
    for operation in operations:
        log = logging.getLogger(bkend + operation)
        if bkend == 'gluon' and operation == 'export':
            log.warning('Gluon->ONNX export not implemented. Skipping tests...')
            continue
        log.info('Executing tests for %s backend: %s', bkend, operation)
        mxnet_backend.MXNetBackend.set_params(bkend, operation)
        BACKEND_TESTS = prepare_tests(mxnet_backend, operation)
        result = unittest.TextTestRunner().run(
            test_suite(BACKEND_TESTS.enable_report()))
        if not result.wasSuccessful():
            _FAILED_RUNS.append((bkend, operation))

# Only force a non-zero exit status when run as a script. An importer (e.g. a
# pytest collection pass) keeps the original behaviour of simply continuing.
if __name__ == '__main__' and _FAILED_RUNS:
    import sys
    sys.exit('ONNX backend tests failed for: ' +
             ', '.join('%s/%s' % run for run in _FAILED_RUNS))