blob: 62c531bb637416b6541e3c6454a6f2699940c556 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import sys
sys.path.insert(0, '../../python')
import mxnet as mx
import numpy as np
import os, pickle, gzip
import logging
from mxnet.test_utils import get_cifar10
# Samples per batch. get_iterator() passes it to both ImageRecordIters, and
# run_cifar10() passes it to the Speedometer callbacks.
batch_size = 128
# small mlp network
def get_net():
    """Return the symbol for a small MLP classifier.

    The ``data`` variable is cast to float32. It then passes through two
    FullyConnected + ReLU hidden layers (128 and 64 units), then a 10-unit
    FullyConnected layer that feeds a SoftmaxOutput named ``softmax``.
    """
    body = mx.symbol.Cast(data=mx.symbol.Variable('data'), dtype="float32")
    # (fully-connected name, activation name, width) for each hidden layer,
    # applied in order.
    hidden_layers = (('fc1', 'relu1', 128), ('fc2', 'relu2', 64))
    for fc_name, act_name, width in hidden_layers:
        body = mx.symbol.FullyConnected(body, name=fc_name, num_hidden=width)
        body = mx.symbol.Activation(body, name=act_name, act_type="relu")
    logits = mx.symbol.FullyConnected(body, name='fc3', num_hidden=10)
    return mx.symbol.SoftmaxOutput(logits, name="softmax")
# Make sure the CIFAR-10 data is available before any iterator is built.
# NOTE(review): presumably mxnet.test_utils.get_cifar10 downloads/extracts the
# dataset into data/cifar/ (the train.rec / test.rec paths read by
# get_iterator) when it is missing. Confirm against mxnet.test_utils.
# This runs at import time, so merely importing/collecting this module may
# touch the network and the filesystem.
get_cifar10()
def get_iterator(kv):
    """Create the CIFAR-10 training and validation iterators for this worker.

    Args:
        kv: KVStore whose ``num_workers`` and ``rank`` are passed as
            ``num_parts`` / ``part_index``, selecting which part of each
            record file this worker reads.

    Returns:
        A ``(train, val)`` tuple:

        - ``train`` reads ``data/cifar/train.rec`` with these augmentations:
          random-resized-crop, brightness/contrast/saturation jitter, PCA
          noise and random mirroring. It is wrapped in a PrefetchingIter.
        - ``val`` reads ``data/cifar/test.rec`` with random crop and mirror
          disabled.

        Both use the mean image ``data/cifar/mean.bin`` and produce
        ``(3, 28, 28)`` samples.
    """
    crop_shape = (3, 28, 28)
    # Options shared by both iterators: mean image, output shape, batch
    # size and the per-worker partitioning taken from the kvstore.
    common = dict(
        mean_img="data/cifar/mean.bin",
        data_shape=crop_shape,
        batch_size=batch_size,
        num_parts=kv.num_workers,
        part_index=kv.rank,
    )
    raw_train = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/train.rec",
        random_resized_crop=True,
        min_aspect_ratio=0.75,
        max_aspect_ratio=1.33,
        min_random_area=0.08,
        max_random_area=1,
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        pca_noise=0.1,
        rand_mirror=True,
        **common)
    train = mx.io.PrefetchingIter(raw_train)
    val = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/test.rec",
        rand_crop=False,
        rand_mirror=False,
        **common)
    return train, val
# Number of training epochs for each run_cifar10() call. Presumably kept at 1
# so the test stays short.
num_epoch = 1
def run_cifar10(train, val, use_module):
    """Train the get_net() MLP for ``num_epoch`` epochs and check accuracy.

    Args:
        train: training DataIter. It is reset before training.
        val: validation DataIter. It is reset before training, evaluated
            during training and scored afterwards.
        use_module: if True, train through ``mx.mod.Module``; otherwise
            train through the legacy ``mx.model.FeedForward`` API.

    Raises:
        AssertionError: if the final validation accuracy is not above 0.08.
    """
    train.reset()
    val.reset()
    devs = [mx.cpu(0)]
    net = get_net()
    optim_args = {'learning_rate': 0.001, 'wd': 0.00001, 'momentum': 0.9}
    eval_metrics = ['accuracy']
    if use_module:
        executor = mx.mod.Module(net, context=devs)
        executor.fit(
            train,
            eval_data=val,
            optimizer_params=optim_args,
            eval_metric=eval_metrics,
            num_epoch=num_epoch,
            arg_params=None,
            aux_params=None,
            begin_epoch=0,
            batch_end_callback=mx.callback.Speedometer(batch_size, 50),
            epoch_end_callback=None)
    else:
        # The legacy API takes the optimizer settings as plain keyword
        # arguments rather than an optimizer_params dict.
        executor = mx.model.FeedForward.create(
            net,
            train,
            ctx=devs,
            eval_data=val,
            eval_metric=eval_metrics,
            num_epoch=num_epoch,
            arg_params=None,
            aux_params=None,
            begin_epoch=0,
            batch_end_callback=mx.callback.Speedometer(batch_size, 50),
            epoch_end_callback=None,
            **optim_args)
    ret = executor.score(val, eval_metrics)
    if use_module:
        # Module.score yields (metric name, value) pairs; take the first value.
        accuracy = list(ret)[0][1]
    else:
        # The FeedForward score result is indexed directly for the value.
        accuracy = ret[0]
    logging.info('final accuracy = %f', accuracy)
    # 0.08 is below the 0.1 chance level of the 10-way output. This is a
    # smoke check that training and scoring ran, not a quality threshold.
    assert accuracy > 0.08
class CustomDataIter(mx.io.DataIter):
    """Wrap another DataIter and report legacy ``(name, shape)`` tuples.

    Only ``provide_data`` / ``provide_label`` are rewritten as plain tuples.
    Every iteration method delegates to the wrapped iterator, so the batches
    themselves are the wrapped iterator's own.
    """

    def __init__(self, data):
        super(CustomDataIter, self).__init__()
        self.data = data
        # Batch size is the leading dimension of the first data shape.
        self.batch_size = data.provide_data[0][1][0]
        # use legacy tuple
        self.provide_data = [(n, s) for n, s in data.provide_data]
        self.provide_label = [(n, s) for n, s in data.provide_label]

    def reset(self):
        self.data.reset()

    def next(self):
        return self.data.next()

    def iter_next(self):
        return self.data.iter_next()

    def getdata(self):
        return self.data.getdata()

    def getlabel(self):
        # Delegate to the wrapped iterator's getlabel(). The previous call to
        # the misspelled ``getlable`` raised AttributeError.
        return self.data.getlabel()

    def getindex(self):
        return self.data.getindex()

    def getpad(self):
        return self.data.getpad()
def test_cifar10():
    """End-to-end training check on CIFAR-10.

    Trains with both the FeedForward and the Module APIs, first on the raw
    iterators and then through CustomDataIter, which reports legacy
    ``(name, shape)`` tuples in provide_data / provide_label.
    """
    # Print logging by default. basicConfig installs a stderr handler only
    # when the root logger has none. Adding a second StreamHandler by hand
    # would print every record twice, plus once more per call, so basicConfig
    # is relied on alone. The explicit setLevel keeps INFO/DEBUG records
    # flowing even when basicConfig is a no-op because handlers already exist
    # (e.g. installed by a test runner).
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger('').setLevel(logging.DEBUG)
    kv = mx.kvstore.create("local")
    # test float32 input
    (train, val) = get_iterator(kv)
    run_cifar10(train, val, use_module=False)
    run_cifar10(train, val, use_module=True)
    # test legacy tuples in provide_data and provide_label
    run_cifar10(CustomDataIter(train), CustomDataIter(val), use_module=False)
    run_cifar10(CustomDataIter(train), CustomDataIter(val), use_module=True)
if __name__ == "__main__":
    # Allow running this test directly as a script, outside a test runner.
    test_cifar10()