blob: 767e5924b294de4edf06ae62ad68d5a815ef2152 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import os
import logging
import argparse
from math import ceil
import sparse_sgd
# symbol net
def get_symbol():
    """Build the 3-layer MLP symbol (128-64-10 hidden units, ReLU, softmax output)."""
    net = mx.symbol.Variable('data')
    net = mx.symbol.FullyConnected(net, name='fc1', num_hidden=128)
    net = mx.symbol.Activation(net, name='relu1', act_type="relu")
    net = mx.symbol.FullyConnected(net, name='fc2', num_hidden=64)
    net = mx.symbol.Activation(net, name='relu2', act_type="relu")
    net = mx.symbol.FullyConnected(net, name='fc3', num_hidden=10)
    return mx.symbol.SoftmaxOutput(net, name='sm')
# download ubyte version of mnist and untar
def download_data():
    """Fetch the ubyte MNIST archive into ./data and unpack it if any file is missing.

    Keeps the example's shell-tool approach (wget/unzip), but creates the
    directory with os.makedirs instead of shelling out to mkdir, and unzips
    with `-d data/` instead of os.chdir — mutating the process-wide cwd
    would silently break any other relative path used by this script.
    """
    if not os.path.isdir("data"):
        os.makedirs("data")
    needed = ('data/train-images-idx3-ubyte',
              'data/train-labels-idx1-ubyte',
              'data/t10k-images-idx3-ubyte',
              'data/t10k-labels-idx1-ubyte')
    if not all(os.path.exists(f) for f in needed):
        os.system("wget -q http://data.mxnet.io/mxnet/data/mnist.zip -P data/")
        # -u: only extract files that are missing or newer than on disk
        os.system("unzip -u data/mnist.zip -d data/")
# get data iterators
def get_iters(batch_size):
    """Return (train, val) MNISTIter pairs over the flat 784-dim ubyte files."""
    shared = dict(
        data_shape=(784,),
        label_name='sm_label',
        batch_size=batch_size,
        shuffle=True,
        flat=True,
        silent=False)
    train = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        seed=10,
        **shared)
    val = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        **shared)
    return (train, val)
def test_mlp(args):
    """Train the MLP on MNIST with the 'sparsesgd' optimizer, then remove checkpoints.

    Args:
        args: parsed CLI namespace. `pruning_switch_epoch` is a required
            comma-separated int list (its last entry sets num_epoch); either
            the *_sparsity or the *_threshold comma-separated float strings
            configure pruning (sparsity takes precedence when given).
    """
    # get parameters
    prefix = './mlp'
    batch_size = 100
    pruning_switch_epoch = [int(i) for i in args.pruning_switch_epoch.split(',')]
    num_epoch = pruning_switch_epoch[-1]
    batches_per_epoch = ceil(60000.0 / batch_size)
    # sparsity and threshold are alternative pruning modes: if sparsity strings
    # are supplied they are parsed and the thresholds stay as passed through;
    # otherwise the threshold strings are parsed instead
    weight_sparsity = args.weight_sparsity
    bias_sparsity = args.bias_sparsity
    weight_threshold = args.weight_threshold
    bias_threshold = args.bias_threshold
    if args.weight_sparsity:
        weight_sparsity = [float(i) for i in args.weight_sparsity.split(',')]
        bias_sparsity = [float(i) for i in args.bias_sparsity.split(',')]
    else:
        weight_threshold = [float(i) for i in args.weight_threshold.split(',')]
        bias_threshold = [float(i) for i in args.bias_threshold.split(',')]

    # get symbols and iterators
    sym = get_symbol()
    download_data()
    (train, val) = get_iters(batch_size)

    # fit model on two CPU contexts
    model = mx.mod.Module(
        sym,
        context=[mx.cpu(i) for i in range(2)],
        data_names=['data'],
        label_names=['sm_label'])
    optimizer_params = {
        'learning_rate': 0.1,
        'wd': 0.004,
        'momentum': 0.9,
        'pruning_switch_epoch': pruning_switch_epoch,
        'batches_per_epoch': batches_per_epoch,
        'weight_sparsity': weight_sparsity,
        'bias_sparsity': bias_sparsity,
        'weight_threshold': weight_threshold,
        'bias_threshold': bias_threshold}
    logging.info('Start training...')
    model.fit(train,
              eval_data=val,
              eval_metric='acc',
              epoch_end_callback=mx.callback.do_checkpoint(prefix),
              num_epoch=num_epoch,
              optimizer='sparsesgd',
              optimizer_params=optimizer_params)
    # fixed log-message typo: 'traning' -> 'training'
    logging.info('Finish training...')

    # remove checkpoint files written by do_checkpoint (1-indexed epochs)
    for i in range(num_epoch):
        os.remove('%s-%04d.params' % (prefix, i + 1))
    os.remove('%s-symbol.json' % prefix)
if __name__ == "__main__":
    # print logging by default
    logging.basicConfig(level=logging.DEBUG)
    parser = argparse.ArgumentParser(description="sparse training")
    parser.add_argument('--pruning_switch_epoch', type=str)
    # the four pruning options share the same declaration shape
    for opt in ('--weight_sparsity', '--bias_sparsity',
                '--weight_threshold', '--bias_threshold'):
        parser.add_argument(opt, type=str, default=None)
    args = parser.parse_args()
    test_mlp(args)