blob: e679c05894a36690abb72f5d1da4ee2c7cf00e0d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import mxnet as mx
from lenet5_common import get_iters
def lenet5():
    """Build the LeNet-5 symbol: fp16 compute body, fp32 softmax head."""
    # pylint: disable=no-member
    net = mx.sym.Variable('data')
    # compute in half precision; cast back to fp32 before the loss
    net = mx.sym.Cast(net, 'float16')
    # first conv block
    net = mx.sym.Convolution(data=net, kernel=(5, 5), num_filter=20)
    net = mx.sym.Activation(data=net, act_type="tanh")
    net = mx.sym.Pooling(data=net, pool_type="max",
                         kernel=(2, 2), stride=(2, 2))
    # second conv block
    net = mx.sym.Convolution(data=net, kernel=(5, 5), num_filter=50)
    net = mx.sym.Activation(data=net, act_type="tanh")
    net = mx.sym.Pooling(data=net, pool_type="max",
                         kernel=(2, 2), stride=(2, 2))
    # fully connected head
    net = mx.sym.Flatten(data=net)
    net = mx.sym.FullyConnected(data=net, num_hidden=500)
    net = mx.sym.Activation(data=net, act_type="tanh")
    net = mx.sym.FullyConnected(data=net, num_hidden=10)
    net = mx.sym.Cast(net, 'float32')
    # softmax loss output
    net = mx.sym.SoftmaxOutput(data=net, name='softmax')
    # pylint: enable=no-member
    return net
def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter):
    """Train LeNet-5 on MNIST and return the fitted module.

    Trains on GPU 0 with SGD, validates each epoch, then scores the
    held-out test iterator and asserts accuracy above 0.95.
    """
    module = mx.mod.Module(lenet5(), context=mx.gpu(0))
    module.fit(
        train_iter,
        eval_data=val_iter,
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.1, 'momentum': 0.9},
        eval_metric='acc',
        batch_end_callback=mx.callback.Speedometer(batch_size, 1),
        num_epoch=num_epochs)
    # sanity-check final accuracy on the test set
    metric = mx.metric.Accuracy()
    module.score(test_iter, metric)
    accuracy = metric.get()[1]
    assert accuracy > 0.95, "LeNet-5 training accuracy on MNIST was too low"
    return module
if __name__ == '__main__':
    num_epochs = 10
    batch_size = 128
    model_name = 'lenet5'
    model_dir = os.getenv("LENET_MODEL_DIR", "/tmp")
    # One checkpoint prefix shared by the existence check and the save
    # call below, so both always refer to the same files.
    model_prefix = '%s/%s' % (model_dir, model_name)
    model_file = '%s-symbol.json' % model_prefix
    params_file = '%s-%04d.params' % (model_prefix, num_epochs)
    if not (os.path.exists(model_file) and os.path.exists(params_file)):
        mnist = mx.test_utils.get_mnist()
        # Call get_iters once instead of twice; the trailing element
        # (test labels) is not needed here.
        train_iter, val_iter, test_iter, _ = get_iters(mnist, batch_size)
        trained_lenet = train_lenet5(num_epochs, batch_size,
                                     train_iter, val_iter, test_iter)
        # Fix: save under model_dir (the original saved into the current
        # working directory, so the existence check above never matched).
        trained_lenet.save_checkpoint(model_prefix, num_epochs)