# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import numpy as np
import mxnet as mx
import numba
import logging
# We use numba.jit to implement the loss gradient.
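# For a score row s and true label y, the multi-class hinge loss
# (Crammer & Singer) is L = max(0, 1 + max_{j != y} s_j - s_y); the
# subgradient used below is -1 at the true label and +1 at the class
# with the largest margin term (a simplification that skips the
# zero-gradient case when the margin is already satisfied).
# Note: because the function calls NDArray/NumPy object methods, numba
# will likely run it in object mode rather than nopython mode.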
@numba.jit
def mc_hinge_grad(scores, labels):
    scores = scores.asnumpy()
    # labels come back from MXNet as float32; cast so they can be used as indices
    labels = labels.asnumpy().astype(np.int32)
    n, _ = scores.shape
    grad = np.zeros_like(scores)
    for i in range(n):
        # margin term 1 + s_j - s_y for every class j
        score = 1 + scores[i] - scores[i, labels[i]]
        # exclude the true class from the argmax
        score[labels[i]] = 0
        ind_pred = score.argmax()
        grad[i, labels[i]] -= 1
        grad[i, ind_pred] += 1
    return grad
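# Worked example (hypothetical values): for a single row
# scores = [[0.5, 2.0, 1.0]] with label 0, the margin vector is
# 1 + [0.5, 2.0, 1.0] - 0.5 = [1.0, 2.5, 1.5]; zeroing the true class
# gives [0.0, 2.5, 1.5], so ind_pred = 1 and the gradient row is
# [-1.0, 1.0, 0.0].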
if __name__ == '__main__':
    n_epoch = 10
    batch_size = 100
    num_gpu = 2
    # fall back to the CPU when no GPU is requested
    contexts = mx.context.cpu() if num_gpu < 1 else [mx.context.gpu(i) for i in range(num_gpu)]
    # build an MLP module
    data = mx.symbol.Variable('data')
    fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
    act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
    mlp = mx.mod.Module(fc3, context=contexts)
    loss = mx.mod.PythonLossModule(grad_func=mc_hinge_grad)
    mod = mx.mod.SequentialModule() \
        .add(mlp) \
        .add(loss, take_labels=True, auto_wiring=True)
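    # auto_wiring connects the MLP's outputs to the loss module's inputs,
    # and take_labels routes the iterator's labels to the loss module so
    # mc_hinge_grad receives both scores and labels.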
    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False)
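    # Both iterators expect the raw MNIST ubyte files under data/
    # (e.g. from http://yann.lecun.com/exdb/mnist/).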
    logging.basicConfig(level=logging.DEBUG)
    mod.fit(train_dataiter, eval_data=val_dataiter,
            optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
            num_epoch=n_epoch)