# blob: d2737dc12af7346d75dab7b10666c75310ec9ee0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import get_data
import mxnet as mx
import numpy as np
import logging
# Symbolic definition of a multi-layer perceptron: two ReLU-activated
# fully connected layers (128 and 64 hidden units) followed by a 10-unit
# fully connected layer that feeds the softmax output.
data = mx.symbol.Variable('data')

# First hidden layer.
fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")

# Second hidden layer.
fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")

# Output layer and loss symbol; `softmax` is what the modules below wrap.
fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=10)
softmax = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
# Number of training epochs and mini-batch size used throughout the script.
n_epoch = 2
batch_size = 100

# The dataset lives in a "data" directory next to this script.
# get_data.get_mnist is expected to place the MNIST idx files there
# (helper defined in utils; behaviour not visible from this file).
basedir = os.path.dirname(__file__)
data_dir = os.path.join(basedir, "data")
get_data.get_mnist(data_dir)

# Options shared by the training and validation iterators.
iter_kwargs = dict(
    data_shape=(784,),
    batch_size=batch_size,
    shuffle=True,
    flat=True,
    silent=False,
)

# Training iterator: the only one given an explicit seed.
train_dataiter = mx.io.MNISTIter(
    image=os.path.join(data_dir, "train-images-idx3-ubyte"),
    label=os.path.join(data_dir, "train-labels-idx1-ubyte"),
    seed=10,
    **iter_kwargs)

# Validation iterator over the t10k files.
val_dataiter = mx.io.MNISTIter(
    image=os.path.join(data_dir, "t10k-images-idx3-ubyte"),
    label=os.path.join(data_dir, "t10k-labels-idx1-ubyte"),
    **iter_kwargs)
################################################################################
# Intermediate-level API
################################################################################
# Build a module around the network, bind it to the shapes advertised by
# the training iterator, then initialise parameters and the optimizer.
mod = mx.mod.Module(softmax)
mod.bind(
    data_shapes=train_dataiter.provide_data,
    label_shapes=train_dataiter.provide_label,
)
mod.init_params()

optimizer_params = {'learning_rate': 0.01, 'momentum': 0.9}
mod.init_optimizer(optimizer_params=optimizer_params)
metric = mx.metric.create('acc')

# Hand-written training loop: forward, metric update, backward, parameter
# update for every batch; the metric is reported once per epoch.
for i_epoch in range(n_epoch):
    for batch in train_dataiter:
        mod.forward(batch)
        # The metric is updated right after the forward pass, before the
        # backward pass and the parameter update.
        mod.update_metric(metric, batch.label)
        mod.backward()
        mod.update()

    for metric_name, metric_value in metric.get_name_value():
        print('epoch %03d: %s=%f' % (i_epoch, metric_name, metric_value))

    # Start the next epoch with a cleared metric and a rewound iterator.
    metric.reset()
    train_dataiter.reset()
################################################################################
# High-level API
################################################################################
# Log at DEBUG level (presumably so that messages logged during fit() are
# shown -- confirm against the Module implementation).
logging.basicConfig(level=logging.DEBUG)

# Rewind the training data and train a fresh module on the same network,
# this time letting fit() drive the whole loop.
train_dataiter.reset()
mod = mx.mod.Module(softmax)
fit_optimizer_params = {'learning_rate': 0.01, 'momentum': 0.9}
mod.fit(
    train_dataiter,
    eval_data=val_dataiter,
    optimizer_params=fit_optimizer_params,
    num_epoch=n_epoch,
)
# prediction iterator API: walk the validation set batch by batch and
# print the accuracy of every 20th batch.
for preds, i_batch, batch in mod.iter_predict(val_dataiter):
    # Predicted class = index of the largest value along axis 1 of the
    # first output.
    pred_label = preds[0].asnumpy().argmax(axis=1)
    label = batch.label[0].asnumpy().astype('int32')
    if i_batch % 20 == 0:
        n_correct = (label == pred_label).sum()
        batch_acc = n_correct / float(len(pred_label))
        print('batch %03d acc: %.3f' % (i_batch, batch_acc))
# a dummy call just to test if the API works for merge_batches=True
preds = mod.predict(val_dataiter)

# Predict again, this time keeping one entry per batch, and compute the
# accuracy manually by pairing each entry with the labels of the batch at
# the same position.
preds = mod.predict(val_dataiter, merge_batches=False)

# Rewind before re-reading the labels.
# NOTE(review): this assumes the iterator yields batches in the same order
# after reset() as it did during predict() -- confirm for shuffle=True.
val_dataiter.reset()

acc_sum = 0.0
acc_cnt = 0
for i, batch in enumerate(val_dataiter):
    batch_outputs = preds[i]
    pred_label = batch_outputs[0].asnumpy().argmax(axis=1)
    label = batch.label[0].asnumpy().astype('int32')
    acc_sum += (label == pred_label).sum()
    acc_cnt += len(pred_label)

print('validation Accuracy: %.3f' % (acc_sum / acc_cnt))
# evaluate on the validation set with an evaluation metric, then print
# every (name, value) pair the metric reports
mod.score(val_dataiter, metric)
for metric_name, metric_value in metric.get_name_value():
    print('%s=%f' % (metric_name, metric_value))