# blob: ec5ece985cd3b743aaff0b1c39dd49d40668ba10 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import sys
import os
sys.path.insert(0, "../../python/")
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
from get_data import MNISTIterator
import mxnet as mx
import numpy as np
import logging
import time
logging.basicConfig(level=logging.DEBUG)
def build_network():
    """Assemble an MLP whose last layer feeds two softmax outputs.

    The network is ``data -> fc1(128) -> relu1 -> fc2(64) -> relu2 -> fc3(10)``
    and both ``softmax1`` and ``softmax2`` take ``fc3`` as their input.

    Returns
    -------
    The symbol produced by ``mx.symbol.Group`` over the two softmax outputs.
    """
    # (fully-connected name, activation name, hidden units) for each hidden layer.
    hidden_layers = (
        ('fc1', 'relu1', 128),
        ('fc2', 'relu2', 64),
    )

    net = mx.symbol.Variable('data')
    for fc_name, act_name, width in hidden_layers:
        net = mx.symbol.FullyConnected(data=net, name=fc_name, num_hidden=width)
        net = mx.symbol.Activation(data=net, name=act_name, act_type="relu")

    # Final 10-unit layer, shared by both output heads.
    logits = mx.symbol.FullyConnected(data=net, name='fc3', num_hidden=10)

    heads = [
        mx.symbol.SoftmaxOutput(data=logits, name=head_name)
        for head_name in ('softmax1', 'softmax2')
    ]
    return mx.symbol.Group(heads)
class Multi_mnist_iterator(mx.io.DataIter):
    """Adapter that exposes a single-label iterator as a two-label iterator.

    Every batch taken from the wrapped iterator is re-emitted with its first
    label array duplicated, once for ``softmax1_label`` and once for
    ``softmax2_label``.

    Parameters
    ----------
    data_iter
        The iterator to wrap. It must provide ``batch_size``,
        ``provide_data``, ``provide_label``, ``reset``, ``hard_reset`` and
        ``next``.
    """

    def __init__(self, data_iter):
        super(Multi_mnist_iterator, self).__init__()
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        """Data description, forwarded unchanged from the wrapped iterator."""
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        """Label descriptions: one entry per softmax head, same shape for both."""
        first_label = self.data_iter.provide_label[0]
        label_shape = first_label[1]
        # A real application would supply a distinct label for each head;
        # this example reuses the single MNIST label for both.
        return [(label_name, label_shape)
                for label_name in ('softmax1_label', 'softmax2_label')]

    def hard_reset(self):
        """Delegate ``hard_reset`` to the wrapped iterator."""
        self.data_iter.hard_reset()

    def reset(self):
        """Delegate ``reset`` to the wrapped iterator."""
        self.data_iter.reset()

    def next(self):
        """Return the next batch with its first label array listed twice."""
        source_batch = self.data_iter.next()
        shared_label = source_batch.label[0]
        return mx.io.DataBatch(
            data=source_batch.data,
            label=[shared_label, shared_label],
            pad=source_batch.pad,
            index=source_batch.index,
        )
class Multi_Accuracy(mx.metric.EvalMetric):
    """Accuracy metric for a network with several classification outputs.

    Parameters
    ----------
    num : int or None
        Number of outputs to score separately. With ``None`` a single
        accuracy is accumulated across every output; with an integer, one
        running accuracy is kept per output.
    """

    def __init__(self, num=None):
        # ``reset`` reads ``self.num``, so it is assigned before the base
        # initializer runs.
        # NOTE(review): relies on EvalMetric.__init__ invoking reset() - confirm.
        self.num = num
        super(Multi_Accuracy, self).__init__('multi-accuracy')

    def reset(self):
        """Resets the internal evaluation result to initial state."""
        if self.num is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num

    def update(self, labels, preds):
        """Accumulate correct-prediction counts for one batch.

        Parameters
        ----------
        labels : list
            One label array per output; each must offer ``asnumpy()``.
        preds : list
            One prediction array per output, in the same order as ``labels``.
            The predicted class is taken with ``mx.nd.argmax_channel``.
        """
        mx.metric.check_label_shapes(labels, preds)

        if self.num is not None:
            assert len(labels) == self.num

        for i, label_nd in enumerate(labels):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = label_nd.asnumpy().astype('int32')

            mx.metric.check_label_shapes(label, pred_label)

            correct = (pred_label.flat == label.flat).sum()
            count = len(pred_label.flat)
            if self.num is None:
                self.sum_metric += correct
                self.num_inst += count
            else:
                self.sum_metric[i] += correct
                self.num_inst[i] += count

    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        names : list of str
            Name of the metrics.
        values : list of float
            Value of the evaluations (``nan`` for an output with no
            instances seen yet).
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get()

        # Build real lists instead of returning a bare ``zip`` object: on
        # Python 3 ``zip`` is a one-shot iterator, which cannot be indexed or
        # read twice and cannot be unpacked into two values when ``num`` is 0.
        names = []
        values = []
        for i in range(self.num):
            names.append('%s-task%d' % (self.name, i))
            if self.num_inst[i] == 0:
                values.append(float('nan'))
            else:
                values.append(self.sum_metric[i] / self.num_inst[i])
        return names, values

    def get_name_value(self):
        """Returns zipped name and value pairs.

        Returns
        -------
        list of tuples
            A (name, value) tuple list.
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get_name_value()
        name, value = self.get()
        return list(zip(name, value))
# ---- Training configuration -------------------------------------------------
batch_size = 100
num_epochs = 100
device = mx.gpu(0)  # trains on GPU 0
lr = 0.01

network = build_network()

# MNISTIterator yields a (train, validation) pair of iterators over flat
# 784-element inputs; each is wrapped so that it emits two label arrays.
train, val = (
    Multi_mnist_iterator(single_label_iter)
    for single_label_iter in MNISTIterator(batch_size=batch_size, input_shape=(784,))
)

# Optimizer-related keyword arguments passed straight through to FeedForward.
optimizer_settings = {
    'learning_rate': lr,
    'momentum': 0.9,
    'wd': 0.00001,
}

model = mx.model.FeedForward(
    ctx=device,
    symbol=network,
    num_epoch=num_epochs,
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
    **optimizer_settings
)

# One accuracy is reported per softmax head (num=2).
model.fit(
    X=train,
    eval_data=val,
    eval_metric=Multi_Accuracy(num=2),
    batch_end_callback=mx.callback.Speedometer(batch_size, 50),
)