blob: f77ecf8fd39d656be12b6cbae81c5859c55b29ba [file] [view]
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
## SVRGModule Example
SVRGModule is an extension to the Module API that implements SVRG optimization, which stands for Stochastic
Variance Reduced Gradient. SVRG is an optimization technique that complements SGD and has several key
properties:
* Employs explicit variance reduction by using a different update rule compared to SGD.
* Ability to use relatively large learning rate, which leads to faster convergence compared to SGD.
* Guarantees for fast convergence for smooth and strongly convex functions.
#### API Usage Example
SVRGModule provides both high-level and intermediate-level APIs while minimizing the changes with Module API.
example_api_train.py: provides suggested usage of SVRGModule high-level and intermediate-level API.
example_inference.py: provides example usage of SVRGModule inference.
#### Linear Regression
This example trains a linear regression model using SVRGModule on a real dataset, YearPredictionMSD.
Logs of the training results can be found in experiments.log which will automatically generated when running the
training script.
##### Dataset
YearPredictionMSD: contains predictions of the release year of a song from audio features. It has over
400,000 samples with 90 features. It will be automatically downloaded on first execution and cached.
YearPredictionMSD dataset: https://archive.ics.uci.edu/ml/datasets/yearpredictionmsd
#### Benchmarks:
An initial set of benchmarks has been performed on YearPredictionDatasetMSD with linear regression model. A jupyter
notebook under `/benchmarks` demonstrates the training process and plots two graphs for benchmarking.
* benchmark1: A lr_scheduler returns a new learning rate based on the number of updates that have been performed.
* benchmark2: One drawback for SGD is that in order to converge faster, the learning rate has to decay to zero,
thus SGD needs to start with a small learning rate. The learning rate does not need to decay to zero for SVRG,
therefore we can use a relatively larger learning rate. SGD with learning rate of (0.001, 0.0025) and SVRG with
learning rate of (0.025) are benchmarked. Even though SVRG starts with a relatively large learning rate, it converges
much faster than SGD in both cases.