| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # pylint: skip-file |
| import mxnet as mx |
| from mxnet.test_utils import get_mnist_iterator |
| import numpy as np |
| import logging |
| import time |
| |
# Configure the root logger to emit DEBUG-level (and higher) records,
# presumably so progress reported by mxnet during training is printed.
logging.basicConfig(level="DEBUG")
| |
def build_network():
    """Build an MLP whose final logits feed two independent softmax heads.

    Layout: data -> fc1(128)+relu1 -> fc2(64)+relu2 -> fc3(10), and fc3 is
    consumed by both ``softmax1`` and ``softmax2``.

    Returns
    -------
    Symbol
        ``mx.symbol.Group`` of the two SoftmaxOutput symbols, in the order
        softmax1, softmax2.
    """
    # (dense layer name, activation name, hidden units) for each hidden stage.
    hidden_stages = (
        ('fc1', 'relu1', 128),
        ('fc2', 'relu2', 64),
    )

    net = mx.symbol.Variable('data')
    for fc_name, act_name, width in hidden_stages:
        net = mx.symbol.FullyConnected(data=net, name=fc_name, num_hidden=width)
        net = mx.symbol.Activation(data=net, name=act_name, act_type="relu")

    logits = mx.symbol.FullyConnected(data=net, name='fc3', num_hidden=10)

    # Both heads read the same logits. Their names presumably determine the
    # label inputs ('softmax1_label' / 'softmax2_label') used by the iterator
    # and the Module below.
    heads = [mx.symbol.SoftmaxOutput(data=logits, name=head_name)
             for head_name in ('softmax1', 'softmax2')]
    return mx.symbol.Group(heads)
| |
class Multi_mnist_iterator(mx.io.DataIter):
    """Adapt a single-label iterator to a network with two softmax heads.

    Every batch drawn from the wrapped iterator is re-emitted with its first
    label array duplicated, once for ``softmax1_label`` and once for
    ``softmax2_label``.
    """

    def __init__(self, data_iter):
        super(Multi_mnist_iterator, self).__init__()
        self.data_iter = data_iter
        # Report the same batch size as the wrapped iterator.
        self.batch_size = data_iter.batch_size

    @property
    def provide_data(self):
        """Data descriptors, passed through unchanged from the wrapped iterator."""
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        """Label descriptors: one per head, all sharing the source label's shape."""
        # Element [1] of the first source label descriptor is reused for both heads.
        shared_shape = self.data_iter.provide_label[0][1]
        # NOTE: a real multi-task application would describe distinct labels here.
        return [(head_label, shared_shape)
                for head_label in ('softmax1_label', 'softmax2_label')]

    def hard_reset(self):
        """Delegate ``hard_reset`` to the wrapped iterator."""
        self.data_iter.hard_reset()

    def reset(self):
        """Delegate ``reset`` to the wrapped iterator."""
        self.data_iter.reset()

    def next(self):
        """Fetch the next source batch and duplicate its first label for both heads.

        Any exception from the wrapped iterator's ``next()`` (e.g. end of
        data) propagates unchanged.
        """
        source = self.data_iter.next()
        shared_label = source.label[0]
        return mx.io.DataBatch(data=source.data,
                               label=[shared_label, shared_label],
                               pad=source.pad,
                               index=source.index)
| |
class Multi_Accuracy(mx.metric.EvalMetric):
    """Calculate accuracies of multi label.

    With ``num=None`` every (label, prediction) pair passed to ``update`` is
    pooled into one accuracy. With an integer ``num``, ``update`` must receive
    exactly ``num`` label arrays, and a separate accuracy is tracked for each
    one. These are reported as ``'<name>-task<i>'``.

    Parameters
    ----------
    num : int or None
        Number of outputs to track separately, or None to pool them.
    """

    def __init__(self, num=None):
        # ``num`` is assigned before the base constructor runs. That constructor
        # presumably calls ``reset``, which reads ``self.num`` to choose between
        # scalar and per-task counters.
        self.num = num
        super(Multi_Accuracy, self).__init__('multi-accuracy')

    def reset(self):
        """Resets the internal evaluation result to initial state."""
        self.num_inst = 0 if self.num is None else [0] * self.num
        self.sum_metric = 0.0 if self.num is None else [0.0] * self.num

    def reset_local(self):
        """Reset the running counters. For this metric it is identical to ``reset``.

        NOTE(review): newer MXNet releases call ``reset_local`` between logging
        intervals (e.g. from ``mx.callback.Speedometer`` with ``auto_reset``).
        The inherited implementation appears to reset to scalar counters, which
        would break the per-task lists used when ``num`` is set. Confirm
        against the installed MXNet version.
        """
        self.reset()

    def update(self, labels, preds):
        """Accumulate correct-prediction counts from one batch.

        Parameters
        ----------
        labels : list of NDArray
            Ground-truth class indices, one array per output.
        preds : list of NDArray
            Class scores, one array per output. The predicted class is taken
            with ``mx.nd.argmax_channel``.

        Raises
        ------
        ValueError
            If ``num`` is set and ``len(labels) != num``. Also raised, via
            ``mx.metric.check_label_shapes``, on mismatched label/pred counts
            or sizes.
        """
        mx.metric.check_label_shapes(labels, preds)

        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if self.num is not None and len(labels) != self.num:
            raise ValueError('Multi_Accuracy expected %d label arrays, got %d'
                             % (self.num, len(labels)))

        for i, (label_nd, pred_nd) in enumerate(zip(labels, preds)):
            pred_label = mx.nd.argmax_channel(pred_nd).asnumpy().astype('int32')
            label = label_nd.asnumpy().astype('int32')

            mx.metric.check_label_shapes(label, pred_label)

            correct = (pred_label.flat == label.flat).sum()
            count = len(pred_label.flat)
            if self.num is None:
                self.sum_metric += correct
                self.num_inst += count
            else:
                self.sum_metric[i] += correct
                self.num_inst[i] += count

    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        names : list of str
            Name of the metrics.
        values : list of float
            Value of the evaluations (NaN for a task with no samples yet).
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get()
        # Build concrete lists. The previous ``zip(...)`` produced a one-shot
        # iterator on Python 3 instead of the documented lists.
        names = ['%s-task%d' % (self.name, i) for i in range(self.num)]
        values = [float('nan') if self.num_inst[i] == 0
                  else self.sum_metric[i] / self.num_inst[i]
                  for i in range(self.num)]
        return names, values

    def get_name_value(self):
        """Returns zipped name and value pairs.

        Returns
        -------
        list of tuples
            A (name, value) tuple list.
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get_name_value()
        names, values = self.get()
        return list(zip(names, values))
| |
| |
# ---- Training configuration ------------------------------------------------
batch_size = 100
num_epochs = 100
# NOTE(review): hard-coded to the first GPU; switch to mx.cpu() on hosts without one.
device = mx.gpu(0)
lr = 0.01

# ---- Model and data ----------------------------------------------------------
network = build_network()

# Flattened 784-dim MNIST inputs. Each source iterator is wrapped so that every
# batch carries one label per softmax head.
train, val = map(Multi_mnist_iterator,
                 get_mnist_iterator(batch_size=batch_size, input_shape=(784,)))

# The label names match the two SoftmaxOutput heads built by build_network().
model = mx.mod.Module(
    symbol=network,
    context=device,
    label_names=('softmax1_label', 'softmax2_label'),
)

# ---- Training ----------------------------------------------------------------
model.fit(
    train_data=train,
    eval_data=val,
    eval_metric=Multi_Accuracy(num=2),  # one accuracy per head
    num_epoch=num_epochs,
    optimizer_params=(
        ('learning_rate', lr),
        ('momentum', 0.9),
        ('wd', 0.00001),
    ),
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
    # Report throughput every 50 batches.
    batch_end_callback=mx.callback.Speedometer(batch_size, 50),
)
| |