#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from .common import *
def train_lenet_gluon_save_params_api():
    """Run a freshly initialized ``Net`` once and persist its artifacts.

    For the model named ``lenet_gluon_save_params_api`` this writes, inside
    the folder returned by ``get_model_path``:

    * the random input batch, under ``<model_name>-data``;
    * the inference output, via ``save_inference_results``;
    * the network, via ``net.save`` with the ``<model_name>-params`` prefix.
    """
    model_name = 'lenet_gluon_save_params_api'

    def artifact_path(suffix):
        # The model folder is looked up on each call, right where it is used.
        return os.path.join(get_model_path(model_name), model_name + suffix)

    create_model_folder(model_name)
    logging.info(f'Saving files for model {model_name}')

    # Xavier-initialized network placed on the first CPU context.
    net = Net()
    net.initialize(mx.initializer.Xavier(magnitude=2.57), ctx=[mx.cpu(0)])

    # Input batch of shape (20, 1, 30, 30) drawn with uniform(-1, 1).
    test_data = mx.np.random.uniform(-1, 1, size=(20, 1, 30, 30))
    output = net(test_data)

    # Persist the input first, then the output, then the network itself.
    mx.npx.savez(artifact_path('-data'), data=test_data)
    save_inference_results(output, model_name)
    net.save(artifact_path('-params'))
@mx.util.use_np
def train_lenet_gluon_hybrid_export_api():
    """Run a hybridized ``HybridNet`` once and export it with its test data.

    For the model named ``lenet_gluon_hybrid_export_api`` this writes, inside
    the folder returned by ``get_model_path``:

    * the random input batch, under ``<model_name>-data``;
    * the inference output, via ``save_inference_results``;
    * the exported network, via ``net.export`` with ``<model_name>`` as the
      path prefix.
    """
    model_name = 'lenet_gluon_hybrid_export_api'

    def artifact_path(suffix=''):
        # The model folder is looked up on each call, right where it is used.
        return os.path.join(get_model_path(model_name), model_name + suffix)

    logging.info(f'Saving files for model {model_name}')
    create_model_folder(model_name)

    # Xavier-initialized network on the first CPU context, hybridized before
    # the forward pass.
    net = HybridNet()
    net.initialize(mx.initializer.Xavier(magnitude=2.57), ctx=[mx.cpu(0)])
    net.hybridize()

    # Input batch of shape (20, 1, 30, 30) drawn with uniform(-1, 1).
    test_data = mx.np.array(np.random.uniform(-1, 1, size=(20, 1, 30, 30)))
    output = net(test_data)

    # Persist the test data and the inference output.
    mx.npx.savez(artifact_path('-data'), data=test_data)
    save_inference_results(output, model_name)

    # v1.0.0 does not have the epoch param in the export API, so it is only
    # passed from 1.1.0 onwards. Epoch 0 is used because that is what 1.0.0
    # saved with by default, which keeps the exported file names aligned.
    predates_epoch_param = compare_versions(str(mxnet_version), '1.1.0') < 0
    export_kwargs = {} if predates_epoch_param else {'epoch': 0}
    net.export(artifact_path(), **export_kwargs)
def train_lstm_gluon_save_parameters_api():
    """Run a ``SimpleLSTMModel`` once and persist its data, output and params.

    The function returns early, after logging a warning, when
    ``compare_versions`` reports that ``mxnet_version`` is older than 1.2.1,
    because the ``save_parameters`` / ``load_parameters`` API used here is
    not available before that release.

    Otherwise, for the model named ``lstm_gluon_save_parameters_api`` it
    writes, inside the folder returned by ``get_model_path``:

    * the random input batch, under ``<model_name>-data``;
    * the inference output, via ``save_inference_results``;
    * the parameters, via ``net.save_parameters`` to ``<model_name>-params``.
    """
    # Only run on version >= 1.2.1, since save_parameters and load_parameters
    # are required below.
    if compare_versions(str(mxnet_version), '1.2.1') < 0:
        # logging.warning is used because logging.warn is a deprecated alias.
        logging.warning(f'Found MXNet version {str(mxnet_version)} and exiting because this version does not contain save_parameters'
                        ' and load_parameters functions')
        return

    model_name = 'lstm_gluon_save_parameters_api'
    logging.info(f'Saving files for model {model_name}')
    create_model_folder(model_name)

    # Xavier-initialized network placed on the first CPU context.
    net = SimpleLSTMModel()
    weights = mx.initializer.Xavier(magnitude=2.57)
    net.initialize(weights, ctx=[mx.cpu(0)])

    # Input batch of shape (10, 30) drawn with uniform(-1, 1).
    test_data = mx.np.array(np.random.uniform(-1, 1, size=(10, 30)))
    output = net(test_data)

    # Persist the input first, then the output, then the parameters.
    mx.npx.savez(os.path.join(get_model_path(model_name), ''.join([model_name, '-data'])), **{'data': test_data})
    save_inference_results(output, model_name)
    net.save_parameters(os.path.join(get_model_path(model_name), ''.join([model_name, '-params'])))
def create_root_folder():
    """Ensure a ``models`` directory exists under the current working directory.

    The directory is created when missing and left untouched when it is
    already there, so the function is safe to call repeatedly.

    Raises:
        FileExistsError: if ``models`` exists but is not a directory.
        OSError: if the directory cannot be created (e.g. missing permissions).
    """
    version_path = os.path.join(os.getcwd(), 'models')
    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() before calling os.mkdir().
    os.makedirs(version_path, exist_ok=True)
if __name__ == '__main__':
    # The root folder is created before any generator runs.
    create_root_folder()
    # Run every model generator in a fixed order.
    for train_fn in (
        train_lenet_gluon_save_params_api,
        train_lenet_gluon_hybrid_export_api,
        train_lstm_gluon_save_parameters_api,
    ):
        train_fn()