blob: e70574d41cb50e5f9a130df4f7cec54af09280f1 [file] [log] [blame]
import mxnet as mx
import argparse
import os, sys
import time
import numpy as np
from mxnet import profiler
def parse_args():
parser = argparse.ArgumentParser(description='Set network parameters for benchmark test.')
parser.add_argument('--profile_filename', type=str, default='profile_executor_5iter.json')
parser.add_argument('--iter_num', type=int, default=5)
parser.add_argument('--fc1', type=int, default=128)
parser.add_argument('--fc2', type=int, default=128)
parser.add_argument('--fc3', type=int, default=128)
parser.add_argument('--fc4', type=int, default=128)
return parser.parse_args()
def _download(data_dir):
if not os.path.isdir(data_dir):
os.system("mkdir " + data_dir)
os.chdir(data_dir)
if (not os.path.exists('train-images-idx3-ubyte')) or \
(not os.path.exists('train-labels-idx1-ubyte')) or \
(not os.path.exists('t10k-images-idx3-ubyte')) or \
(not os.path.exists('t10k-labels-idx1-ubyte')):
os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip")
os.system("unzip -u mnist.zip; rm mnist.zip")
os.chdir("..")
def get_data(data_shape):
data_dir = "mnist/"
batch_size = 128
if '://' not in data_dir:
_download(data_dir)
train = mx.io.MNISTIter(
image = data_dir + "train-images-idx3-ubyte",
label = data_dir + "train-labels-idx1-ubyte",
input_shape = data_shape,
batch_size = batch_size,
shuffle = True,
)
val = mx.io.MNISTIter(
image = data_dir + "t10k-images-idx3-ubyte",
label = data_dir + "t10k-labels-idx1-ubyte",
input_shape = data_shape,
batch_size = batch_size,
)
return (train, val)
def get_symbol():
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=args.fc1)
act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type='relu')
fc2 = mx.symbol.FullyConnected(data=act1 , name='fc2', num_hidden=args.fc2)
act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type='relu')
fc3 = mx.symbol.FullyConnected(data=act2 , name='fc3', num_hidden=args.fc3)
act3 = mx.symbol.Activation(data=fc3, name='relu3', act_type='relu')
fc4 = mx.symbol.FullyConnected(data=act3 , name='fc4', num_hidden=args.fc4)
act4 = mx.symbol.Activation(data=fc4, name='relu4', act_type='relu')
fc5 = mx.symbol.FullyConnected(data=act4 , name='fc5', num_hidden=10)
net = mx.symbol.SoftmaxOutput(data=fc5 , name='softmax')
return net, [('data', (128, 1, 28, 28))], [('softmax_label', (128, ))]
def get_module(ctx, sym, provide_data, provide_label, batch_size=None, is_train=True, use_memonger=False):
if use_memonger:
sym = search_plan(sym, data=data_shapes)
mod = mx.mod.Module(symbol=sym,
data_names=[name for name, _ in provide_data],
label_names=[name for name, _ in provide_label],
context=ctx)
if batch_size is not None:
provide_data = [(name, (batch_size,) + shape[1:]) for name, shape in provide_data]
provide_label = [(name, (batch_size,) + shape[1:]) for name, shape in provide_label]
if is_train:
mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=True, inputs_need_grad=False)
else:
mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=False, inputs_need_grad=False)
mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
mod.init_optimizer(optimizer='ccsgd',
optimizer_params={
'learning_rate': 0.0001,
'momentum': 0.0,
'wd': 0.0
})
return mod
def benchmark(mod, dry_run=10, iterations=10):
if len(mod._context) == 1:
ctx = mod._context[0]
else:
ctx = mx.cpu()
data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
label = [mx.nd.array(np.random.randint(1, 100, size=shape), ctx=ctx) for _, shape in mod.label_shapes]
batch = mx.io.DataBatch(data, label)
# dry run
for i in range(dry_run):
mod.forward(batch, is_train=True)
mod.backward()
for output in mod.get_outputs(merge_multi_context=False)[0]:
output.wait_to_read()
mod.update()
t0 = time.clock()
profiler.profiler_set_state('run')
# real run
for i in range(iterations):
mod.forward(batch, is_train=True)
mod.backward()
mod.update()
for output in mod.get_outputs(merge_multi_context=False)[0]:
output.wait_to_read()
profiler.profiler_set_state('stop')
t1 = time.clock()
return (t1 - t0)*1000.0 / iterations
def executor(num_iteration):
sym, provide_data, provide_label = get_symbol()
ctx = [mx.cpu(0)]
mod = get_module(ctx, sym, provide_data, provide_label, batch_size=128)
return benchmark(mod, iterations=args.iter_num)
args = parse_args()
if __name__ == '__main__':
mx.profiler.profiler_set_config(mode='symbolic', filename=args.profile_filename)
print('profile file save to {0}'.format(args.profile_filename))
print('executor num_iteration: {0}'.format(args.iter_num))
executor_time = executor(args.iter_num)
print("executor {0} ms / iteration".format(executor_time))