blob: fdb6689579be4a991fcf6814a909926da99d4962 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""export helper functions"""
# coding: utf-8
import os
import logging
import mxnet as mx
def load_module(sym_filepath, params_filepath):
    """Load an MXNet checkpoint and return its symbol and merged params.

    The checkpoint prefix is derived from the symbol file name and the
    epoch number from the params file name; both are then handed to
    ``mx.model.load_checkpoint``.

    Parameters
    ----------
    sym_filepath : str
        Path to the symbol json file, named like ``prefix-symbol.json``.
    params_filepath : str
        Path to the params file, named like ``prefix-epoch.params``.
        When the file name carries no ``-epoch`` suffix, epoch 0 is used.

    Returns
    -------
    sym : MXNet symbol
        Model symbol object
    params : dict
        Model weights including both arg and aux params. On a key
        collision the aux value wins, because it is merged last.

    Raises
    ------
    ValueError
        If either path is not an existing file, or if the epoch suffix of
        the params file name is not an integer.
    """
    if not (os.path.isfile(sym_filepath) and os.path.isfile(params_filepath)):
        raise ValueError("Symbol and params files provided are invalid")

    # Parse only the file names: a '-' or '.' inside a directory name must
    # not be mistaken for the prefix/epoch separator or the extension dot.
    sym_dir, sym_file = os.path.split(sym_filepath)
    sym_stem = os.path.splitext(sym_file)[0]
    # Drop the trailing "-symbol" part and re-attach the directory so that
    # load_checkpoint still receives a full path prefix.
    model_name = os.path.join(sym_dir, sym_stem.rsplit('-', 1)[0])

    params_stem = os.path.splitext(os.path.basename(params_filepath))[0]
    params_parts = params_stem.rsplit('-', 1)
    try:
        # Setting num_epochs to 0 if not present in filename
        num_epochs = 0 if len(params_parts) == 1 else int(params_parts[1])
    except ValueError:
        # int() is the only step that can fail here; surface the expected
        # naming scheme before propagating the error to the caller.
        logging.error("Model and params name should be in format: "
                      "prefix-symbol.json, prefix-epoch.params")
        raise

    sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs)

    # Merging arg and aux parameters
    params = dict(arg_params)
    params.update(aux_params)
    return sym, params