# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
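"""Train a linear regression model with SVRGModule on the year prediction
dataset. The network construction and data-loading helpers are provided by
common.py and data_reader.py in this example directory."""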
import argparse
import mxnet as mx
from common import create_lin_reg_network, create_logger
from data_reader import get_year_prediction_data
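
# command-line options for the training run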
parser = argparse.ArgumentParser()
parser.add_argument('-e', dest='epochs', help='number of epochs for training phase', type=int, default=100)
parser.add_argument('-f', dest='update_freq', help='update frequency for SVRGModule', type=int, default=2)
parser.add_argument('-b', dest='batch_size', help='batch size for training', type=int, default=100)
parser.add_argument('-m', dest='metrics', help='evaluation metric used during training', type=str, default='mse')
parser.add_argument('--gpus', type=str, help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
parser.add_argument('--kv-store', type=str, default='local', help='key-value store type')
args = parser.parse_args()
# devices for training
ctx = mx.cpu() if args.gpus is None or args.gpus == "" else [mx.gpu(int(i)) for i in args.gpus.split(',')]
logger = create_logger()
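# key-value store for gradient aggregation ('local' runs on a single machine)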
kv = mx.kvstore.create(args.kv_store)
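# load the year prediction features and labels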
feature_dim, train_features, train_labels = get_year_prediction_data()
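# build the data iterator and the SVRGModule-backed linear regression module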
train_iter, mod = create_lin_reg_network(train_features, train_labels, feature_dim, args.batch_size,
                                         args.update_freq, ctx, logger)
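# train with SGD; SVRG recomputes the full gradient every `update_freq` epochs to reduce gradient variance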
mod.fit(train_iter, eval_metric=args.metrics, optimizer='sgd',
        optimizer_params=(('learning_rate', 0.025),), num_epoch=args.epochs, kvstore=kv)