# blob: 312f9796074d1259f912abcda88f3ad8372dae04 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import numpy as np
import logging
from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
def test_svrg_inference(args):
    """Train the SVRG module through its high-level ``fit`` call.

    The iterators and the module come from ``create_network``; the module is
    fitted with the 'sgd' optimizer (learning rate 0.025) and evaluated on the
    validation iterator with the 'mse' metric.

    :param args: object exposing ``epochs``, ``batch_size`` and ``update_freq``
        attributes (the parsed command-line namespace when run as a script)
    """
    num_epochs, batch_size, update_freq = args.epochs, args.batch_size, args.update_freq
    train_data, validation_data, svrg_mod = create_network(batch_size, update_freq)

    sgd_settings = (('learning_rate', 0.025),)
    svrg_mod.fit(
        train_data,
        eval_data=validation_data,
        eval_metric='mse',
        optimizer='sgd',
        optimizer_params=sgd_settings,
        num_epoch=num_epochs,
    )
def get_validation_score(args):
    """Train with an explicit epoch/batch loop, then print the validation MSE.

    Unlike ``test_svrg_inference`` this drives the module step by step: bind,
    parameter and optimizer initialisation, then for every epoch a pass over
    the training iterator calling ``forward_backward``, ``update`` and
    ``update_metric``. Afterwards the validation iterator is fed to
    ``predict`` and ``score``.

    :param args: object exposing ``epochs``, ``batch_size`` and ``update_freq``
        attributes (the parsed command-line namespace when run as a script)
    """
    num_epochs, batch_size, update_freq = args.epochs, args.batch_size, args.update_freq
    train_data, validation_data, svrg_mod = create_network(batch_size, update_freq)

    # Manual equivalent of the setup that a fit() call would otherwise perform.
    svrg_mod.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label)
    svrg_mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False,
                         force_init=False, allow_extra=False)
    svrg_mod.init_optimizer(kvstore='local', optimizer='sgd',
                            optimizer_params=(('learning_rate', 0.025),))

    mse_metric = mx.metric.create("mse")
    for epoch_idx in range(num_epochs):
        mse_metric.reset()
        # Every ``update_freq`` epochs (starting with epoch 0) the module is
        # asked to refresh its full gradients over the training iterator.
        full_grads_due = epoch_idx % svrg_mod.update_freq == 0
        if full_grads_due:
            svrg_mod.update_full_grads(train_data)
        # Rewind before the per-batch pass below.
        train_data.reset()
        for data_batch in train_data:
            svrg_mod.forward_backward(data_batch=data_batch)
            svrg_mod.update()
            svrg_mod.update_metric(mse_metric, data_batch.label)

    predictions = svrg_mod.predict(validation_data)
    # create_network holds out 20% of its 1000 samples for validation, and the
    # network has a single output unit, hence 200 rows by 1 column.
    expected_shape = (200, 1)
    assert predictions.shape == expected_shape

    results = svrg_mod.score(validation_data, ['mse'])
    mse_value = results[0][1]
    print("Training Loss on Validation Set is {}".format(mse_value))
def create_network(batch_size, update_freq):
    """Build the data iterators and the SVRG module for a linear regression.

    Draws 1000 two-feature integer samples, labels them with the fixed linear
    map ``[1.0, 2.0]`` and splits them 80/20 into training and validation
    iterators. Also configures the root logger at INFO level.

    :param batch_size: batch size given to both ``mx.io.NDArrayIter`` objects
    :param update_freq: forwarded to ``SVRGModule`` as ``update_freq``
    :return: tuple ``(train_iter, val_iter, mod)`` - the training iterator,
        the validation iterator, and the ``SVRGModule`` instance
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')

    # Synthetic data set: integer features in [1, 5), label = x0 * 1.0 + x1 * 2.0
    samples = np.random.randint(1, 5, [1000, 2])
    true_weights = np.array([1.0, 2.0])
    targets = samples.dot(true_weights)

    # First 80% of the rows are for training, the remainder for validation.
    split_at = int(samples.shape[0] * 0.8)
    train_iter = mx.io.NDArrayIter(
        samples[:split_at, :],
        targets[:split_at],
        batch_size=batch_size,
        shuffle=True,
        label_name='lin_reg_label',
    )
    # NOTE(review): no label_name is passed here, unlike the training iterator
    # above - confirm that this is intended.
    val_iter = mx.io.NDArrayIter(samples[split_at:, :], targets[split_at:],
                                 batch_size=batch_size)

    # Single fully connected unit feeding a linear-regression output layer.
    features = mx.sym.Variable('data')
    target = mx.symbol.Variable('lin_reg_label')
    dense = mx.sym.FullyConnected(data=features, name='fc1', num_hidden=1)
    regression_out = mx.sym.LinearRegressionOutput(data=dense, label=target, name="lro")

    mod = SVRGModule(
        symbol=regression_out,
        data_names=['data'],
        label_names=['lin_reg_label'],
        update_freq=update_freq,
        logger=logging,
    )
    return train_iter, val_iter, mod
# run as a script
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # (flag, destination attribute, default) for each integer-valued option.
    for flag, dest, default in (
            ('-e', 'epochs', 100),
            ('-bs', 'batch_size', 32),
            ('-f', 'update_freq', 2),
    ):
        parser.add_argument(flag, dest=dest, default=default, type=int)
    args = parser.parse_args()

    # First exercise the fit()-based path, then the hand-written training loop.
    print("========================== SVRG Module Inference ==========================")
    test_svrg_inference(args)
    print("========================SVRG Module Score ============================")
    get_validation_score(args)