# blob: 494143fc0bf281c6cf72645017db3b16616e2572 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Test roundtrip of runtime modules """
# pylint: disable=missing-docstring
import pytest
import tvm
import tvm.testing
from tvm import TVMError
from tvm import relay
def test_csource_module():
    """Round-trip a C source module through JSON and check it is preserved.

    The module is checked for type key "c" and binary serializability both
    before and after the save/load cycle.
    """
    original = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], [])
    assert original.type_key == "c"
    assert original.is_binary_serializable

    # Serialize and deserialize in two explicit steps.
    serialized = tvm.ir.save_json(original)
    restored = tvm.ir.load_json(serialized)
    assert restored.type_key == "c"
    assert restored.is_binary_serializable
def test_aot_module():
    """Check that the AOT executor codegen module cannot be round-tripped."""
    make_aot_codegen = tvm.get_global_func("relay.build_module._AOTExecutorCodegen")
    aot_mod = make_aot_codegen()

    # The module reports itself as not binary serializable, so the
    # save/load round trip is expected to raise a TVMError.
    assert not aot_mod.is_binary_serializable
    with pytest.raises(TVMError):
        tvm.ir.load_json(tvm.ir.save_json(aot_mod))
def get_test_mod():
    """Build ``x + y`` over two float32 (1, 10) inputs for the "cuda" target.

    Returns whatever ``relay.build_module._build_module_no_factory`` produces
    for that function.
    """
    inputs = [relay.var(name, shape=(1, 10), dtype="float32") for name in ("x", "y")]
    add_func = relay.Function(inputs, relay.add(*inputs))
    return relay.build_module._build_module_no_factory(add_func, target="cuda")
def get_cuda_mod():
    """Return the module found two import levels below the built test module.

    The tests in this file rely on it being a binary-serializable "cuda" module.
    """
    built = get_test_mod()
    first_import = built.imported_modules[0]
    return first_import.imported_modules[0]
@tvm.testing.requires_cuda
def test_cuda_module():
    """Round-trip a single cuda module through JSON and check it is preserved."""
    original = get_cuda_mod()
    assert original.type_key == "cuda"
    assert original.is_binary_serializable

    # Serialize and deserialize in two explicit steps.
    serialized = tvm.ir.save_json(original)
    restored = tvm.ir.load_json(serialized)
    assert restored.type_key == "cuda"
    assert restored.is_binary_serializable
@tvm.testing.requires_cuda
def test_valid_submodules():
    """Round-trip a nested tree of cuda modules and check each node afterwards.

    The tree is ``root -> child -> (leaf_a, leaf_b)``; the same checks are
    applied to the original tree and to the tree loaded back from JSON.
    """

    def assert_cuda_tree(top):
        # Root of the tree.
        assert top.type_key == "cuda"
        assert top.is_binary_serializable
        # Its only inspected import.
        middle = top.imported_modules[0]
        assert middle.type_key == "cuda"
        assert middle.is_binary_serializable
        # The two modules imported by that import.
        leaves = middle.imported_modules
        first, second = leaves[0], leaves[1]
        assert first.type_key == "cuda"
        assert second.type_key == "cuda"
        assert first.is_binary_serializable
        assert second.is_binary_serializable

    root, child, leaf_a, leaf_b = (get_cuda_mod() for _ in range(4))

    # Create the nested cuda module.
    root.import_module(child)
    child.import_module(leaf_a)
    child.import_module(leaf_b)

    # Root module and all submodules are cuda modules, so each one should be
    # binary serializable.
    assert_cuda_tree(root)

    # The round-tripped module should pass the same checks.
    restored = tvm.ir.load_json(tvm.ir.save_json(root))
    assert_cuda_tree(restored)
@tvm.testing.requires_cuda
def test_invalid_submodules():
    """Check that one non-serializable submodule blocks the JSON round trip.

    The tree is ``root -> child -> (cuda_leaf, aot_leaf)``, where ``aot_leaf``
    is the AOT executor codegen module.
    """
    root, child, cuda_leaf = (get_cuda_mod() for _ in range(3))
    aot_leaf = tvm.get_global_func("relay.build_module._AOTExecutorCodegen")()

    # Create the nested module tree.
    root.import_module(child)
    child.import_module(cuda_leaf)
    child.import_module(aot_leaf)

    # Only the last-imported submodule is not binary serializable.
    assert root.is_binary_serializable
    middle = root.imported_modules[0]
    assert middle.is_binary_serializable
    leaves = middle.imported_modules
    assert leaves[0].is_binary_serializable
    assert not leaves[1].is_binary_serializable

    # Therefore, the round trip of the whole tree must fail.
    with pytest.raises(TVMError):
        tvm.ir.load_json(tvm.ir.save_json(root))
if __name__ == "__main__":
    # When executed directly as a script, hand control to tvm.testing.main().
    tvm.testing.main()