# pylint: skip-file
from data import mnist_iterator
import mxnet as mx
import numpy as np
import logging
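# This example trains an MLP on MNIST through MXNet's Torch plugin
# (mx.symbol.TorchModule / mx.symbol.TorchCriterion), which embeds Lua
# Torch nn modules inside an MXNet symbolic graph. It assumes MXNet was
# built with the Torch plugin enabled (see plugin/torch in the MXNet
# source tree).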
# define the mlp
use_torch_criterion = False  # if True, train against Torch's nn.ClassNLLCriterion instead of MXNet's SoftmaxOutput
data = mx.symbol.Variable('data')
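# Each TorchModule below wraps one Lua nn layer. num_data, num_params and
# num_outputs declare how many data inputs, learnable parameter blobs and
# outputs the wrapped module has: nn.Linear carries a weight and a bias
# (num_params=2), while nn.ReLU is parameter-free (num_params=0).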
fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1')
act1 = mx.symbol.TorchModule(data_0=fc1, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu1')
fc2 = mx.symbol.TorchModule(data_0=act1, lua_string='nn.Linear(128, 64)', num_data=1, num_params=2, num_outputs=1, name='fc2')
act2 = mx.symbol.TorchModule(data_0=fc2, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu2')
fc3 = mx.symbol.TorchModule(data_0=act2, lua_string='nn.Linear(64, 10)', num_data=1, num_params=2, num_outputs=1, name='fc3')
if use_torch_criterion:
    logsoftmax = mx.symbol.TorchModule(data_0=fc3, lua_string='nn.LogSoftMax()', num_data=1, num_params=0, num_outputs=1, name='logsoftmax')
    # Torch's class labels start from 1, so shift MXNet's 0-based labels
    label = mx.symbol.Variable('softmax_label') + 1
    mlp = mx.symbol.TorchCriterion(data=logsoftmax, label=label, lua_string='nn.ClassNLLCriterion()', name='softmax')
else:
    mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
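# Optional sanity check: the composed symbol should expose the Torch
# layers' weights and biases as ordinary MXNet arguments.
# print(mlp.list_arguments())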
# load MNIST via the example's local data.py helper, flattened to 784-dim vectors
train, val = mnist_iterator(batch_size=100, input_shape=(784,))
# train
logging.basicConfig(level=logging.DEBUG)
model = mx.model.FeedForward(
    ctx=mx.cpu(0), symbol=mlp, num_epoch=20,
    learning_rate=0.1, momentum=0.9, wd=0.00001)
if use_torch_criterion:
    # mx.metric.Torch tracks the loss reported by the Torch criterion
    model.fit(X=train, eval_data=val, eval_metric=mx.metric.Torch())
else:
    model.fit(X=train, eval_data=val)
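# A hedged follow-up sketch, assuming the same old FeedForward API: score
# the trained model on the held-out iterator with the default accuracy
# metric (only meaningful for the SoftmaxOutput branch).
# print('validation accuracy: %f' % model.score(val))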