blob: 27db5656d7d7c467ffe54311d2884b549e9f6c29 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnet
import java.io._
import scala.collection.mutable
import scala.util.Either
object Optimizer {
  /**
   * Wraps an [[Optimizer]] as a kvstore updater that lazily creates, caches,
   * disposes and (de)serializes one optimizer state per key.
   *
   * @param optimizer the optimizer that performs the updates and owns the states
   * @return an updater whose per-key states are kept in its `states` cache
   */
  def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
    new MXKVStoreUpdater with MXKVStoreCachedStates {
      override def update(index: Int, grad: NDArray, weight: NDArray): Unit = {
        // Create the state on first use of this key; a cached state (even null) is reused.
        val state = states.getOrElseUpdate(index, optimizer.createState(index, weight))
        optimizer.update(index, weight, grad, state)
      }

      override def dispose(): Unit = {
        states.values.foreach(optimizer.disposeState)
        states.clear()
      }

      /**
       * Serializes all cached states.
       * Layout: Int entry count, then per entry: Int key, Int byte length, state bytes.
       * A length of 0 encodes a null state, or a state the optimizer serialized to
       * null/empty bytes; `deserializeState` maps it back to null.
       */
      override def serializeState(): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bos)
        try {
          out.writeInt(states.size)
          // Every entry must be written, including null states: the count above covers
          // all of them, so skipping any would desynchronize the reader.
          states.foreach { case (k, v) =>
            out.writeInt(k)
            val stateBytes = if (v == null) null else optimizer.serializeState(v)
            if (stateBytes == null) {
              out.writeInt(0)
            } else {
              out.writeInt(stateBytes.length)
              out.write(stateBytes)
            }
          }
          out.flush()
          bos.toByteArray
        } finally {
          closeQuietly(out)
        }
      }

      /**
       * Replaces cached states with those encoded in `bytes` (produced by serializeState()).
       * Keys not present in `bytes` keep their current state.
       */
      override def deserializeState(bytes: Array[Byte]): Unit = {
        val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
        try {
          val size = in.readInt()
          (0 until size).foreach { _ =>
            val key = in.readInt()
            val bytesLength = in.readInt()
            val value =
              if (bytesLength > 0) {
                val stateBytes = new Array[Byte](bytesLength)
                in.readFully(stateBytes)
                optimizer.deserializeState(stateBytes)
              } else {
                null
              }
            // Release the state being replaced (as dispose() does) so it is not leaked.
            states.put(key, value).foreach(optimizer.disposeState)
          }
        } finally {
          closeQuietly(in)
        }
      }
    }
  }

  /** Best-effort close: non-fatal failures while closing are ignored. */
  private def closeQuietly(stream: Closeable): Unit = {
    try {
      stream.close()
    } catch {
      case scala.util.control.NonFatal(_) =>
    }
  }
}
/**
 * Base class for optimizers that update parameters from their gradients.
 *
 * Subclasses implement the update rule together with the creation, disposal and
 * (de)serialization of any per-parameter state. This base class keeps the
 * per-parameter learning-rate / weight-decay multipliers and the update counters.
 */
abstract class Optimizer extends Serializable {
  // Multipliers keyed either by parameter index (Left) or by parameter name (Right).
  protected val lrMult: mutable.Map[Either[Int, String], Float] =
    mutable.HashMap.empty[Either[Int, String], Float]
  protected val wdMult: mutable.Map[Either[Int, String], Float] =
    mutable.HashMap.empty[Either[Int, String], Float]
  // Largest per-index update count seen so far (maintained by updateCount).
  protected var numUpdate: Int = 0
  protected val indexUpdateCount: mutable.Map[Int, Int] = mutable.HashMap.empty[Int, Int]
  // Set by setArgNames; when true, getWd applies weight decay only to indices in weightSet.
  protected var specialized: Boolean = false
  protected val weightSet: mutable.Set[Int] = mutable.HashSet.empty[Int]
  protected var rescaleGrad: Float = 1
  @transient protected var symbol: Symbol = null
  protected var idx2name: Map[Int, String] = null

  /**
   * Update the parameters.
   * @param index A unique integer key used to index the parameters
   * @param weight weight ndarray
   * @param grad grad ndarray
   * @param state NDArray or other objects returned by createState;
   *              the auxiliary state used in optimization.
   */
  // TODO: make state a ClassTag
  def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit

  // Create additional optimizer state such as momentum.
  // TODO: make returned state a ClassTag
  def createState(index: Int, weight: NDArray): AnyRef

  // Dispose the state it created
  def disposeState(state: AnyRef): Unit

  def serializeState(state: AnyRef): Array[Byte]

  def deserializeState(bytes: Array[Byte]): AnyRef

  // Set individual learning rate scale for parameters
  @deprecated("Use setLrMult instead.")
  def setLrScale(lrScale: Map[Int, Float]): Unit = {
    val argsLrScale: Map[Either[Int, String], Float] = lrScale.map { case (k, v) => Left(k) -> v }
    setLrMult(argsLrScale)
  }

  /**
   * Sets an individual learning rate multiplier for each parameter.
   * If you specify a learning rate multiplier for a parameter, then
   * the learning rate for the parameter will be set as the product of
   * the global learning rate and its multiplier.
   *
   * note:: A default multiplier can also come from a symbol argument's
   * `__lr_mult__` attribute, which `setSymbol` reads.
   * @param argsLrMult: Map[Either[Int, String], Float]
   * For each of its key-value entries, the learning rate multiplier for the
   * parameter specified in the key will be set as the given value.
   *
   * You can specify the parameter with either its name or its index.
   * If you use the name, you should also call the `setSymbol` method first,
   * and the name you specified in the key of `argsLrMult` should match
   * the name of the parameter in the `sym` you pass to `setSymbol` method.
   * If you use the index, it should correspond to the index of the parameter
   * used in the `update` method.
   *
   * Specifying a parameter by its index is only supported for backward
   * compatibility, and we recommend to use the name instead.
   */
  def setLrMult(argsLrMult: Map[Either[Int, String], Float]): Unit = {
    argsLrMult.foreach { case (k, v) => this.lrMult(k) = v }
  }

  /**
   * Sets an individual weight decay multiplier for each parameter.
   *
   * By default, the weight decay multiplier is set as 0 for all
   * parameters whose name don't end with ``_weight`` or ``_gamma``, if
   * you call the `setIdx2Name` method to set idx2name. That default never
   * overrides a multiplier that was already set explicitly.
   *
   * note:: A default multiplier can also come from a symbol argument's
   * `__wd_mult__` attribute, which `setSymbol` reads.
   * @param argsWdMult: Map[Either[Int, String], Float]
   * For each of its key-value entries, the weight decay multiplier for the
   * parameter specified in the key will be set as the given value.
   *
   * You can specify the parameter with either its name or its index.
   * If you use the name, you should also call the `setSymbol` method first,
   * and the name you specified in the key of `argsWdMult` should match
   * the name of the parameter in the `sym` you pass to `setSymbol` method.
   * If you use the index, it should correspond to the index of the parameter
   * used in the `update` method.
   *
   * Specifying a parameter by its index is only supported for backward
   * compatibility, and we recommend to use the name instead.
   */
  def setWdMult(argsWdMult: Map[Either[Int, String], Float]): Unit = {
    argsWdMult.foreach { case (k, v) => this.wdMult(k) = v }
  }

  /**
   * Records which parameter indices are weights, so that getWd applies weight
   * decay only to them. Arguments whose names end with "data" or "label" are
   * skipped and do not consume an index; among the remaining ones, names ending
   * with "weight" are recorded. A null `argNames` leaves the optimizer unchanged.
   */
  def setArgNames(argNames: Seq[String]): Unit = {
    if (argNames != null) {
      specialized = true
      argNames
        .filterNot(name => name.endsWith("data") || name.endsWith("label"))
        .zipWithIndex
        .foreach { case (name, index) =>
          if (name.endsWith("weight")) weightSet.add(index)
        }
    }
  }

  // Set rescaling factor of gradient.
  def setRescaleGrad(rescaleGrad: Float): Unit = {
    this.rescaleGrad = rescaleGrad
  }

  /**
   * Stores `sym` and imports the `__lr_mult__` / `__wd_mult__` attributes of its
   * arguments as name-keyed multipliers. A null `sym` only clears the stored symbol.
   */
  def setSymbol(sym: Symbol): Unit = {
    this.symbol = sym
    if (sym != null) {
      val attr = sym.attrMap
      for (name <- sym.listArguments()) {
        attr.get(name).foreach { argAttr =>
          argAttr.get("__lr_mult__").foreach(v => this.lrMult(Right(name)) = v.toFloat)
          argAttr.get("__wd_mult__").foreach(v => this.wdMult(Right(name)) = v.toFloat)
        }
      }
    }
  }

  /**
   * Sets the index-to-name mapping used to resolve name-keyed multipliers.
   * Parameters whose names don't end with "_weight" or "_gamma" default to a
   * weight decay multiplier of 0, unless a multiplier was already set for them.
   */
  def setIdx2Name(paramIdx2Name: Map[Int, String]): Unit = {
    this.idx2name = paramIdx2Name
    if (paramIdx2Name != null) {
      for (n <- paramIdx2Name.values if !(n.endsWith("_weight") || n.endsWith("_gamma"))) {
        // Only a default: keep multipliers already set via setWdMult or setSymbol.
        this.wdMult.getOrElseUpdate(Right(n), 0f)
      }
    }
  }

  /**
   * Increments the update count of `index` and keeps numUpdate at the maximum
   * count seen over all indices.
   * @param index The index will be updated
   */
  protected def updateCount(index: Int): Unit = {
    val count = indexUpdateCount.getOrElse(index, 0) + 1
    indexUpdateCount(index) = count
    numUpdate = math.max(count, numUpdate)
  }

  /**
   * Resolves the multiplier for parameter `index` from `mults`: an index-keyed
   * entry wins, then a name-keyed entry via idx2name, otherwise 1.0.
   */
  private def multiplierFor(mults: mutable.Map[Either[Int, String], Float], index: Int): Float =
    mults.get(Left(index)) match {
      case Some(mult) => mult
      case None if idx2name != null && idx2name.contains(index) =>
        mults.getOrElse(Right(idx2name(index)), 1.0f)
      case None => 1.0f
    }

  // Gets the learning rate given the index of the weight.
  protected def getLr(index: Int, lr: Float): Float = lr * multiplierFor(lrMult, index)

  // Gets weight decay for index; after setArgNames, non-weight indices get no decay.
  protected def getWd(index: Int, wd: Float): Float = {
    val baseWd = if (specialized && !weightSet.contains(index)) 0f else wd
    baseWd * multiplierFor(wdMult, index)
  }
}
/**
 * Update callback used by the kvstore for each key.
 * State (de)serialization is provided separately by MXKVStoreCachedStates.
 */
trait MXKVStoreUpdater {
  /**
   * User-defined update rule applied to one key of the kvstore.
   * Per the kvstore contract, the updater is responsible for deleting `recv` and `local`.
   * @param key the key being updated
   * @param recv the value pushed on this key
   * @param local the value stored locally on this key
   */
  def update(key: Int, recv: NDArray, local: NDArray): Unit

  /**
   * Releases what the updater holds; e.g. the updater built by
   * `Optimizer.getUpdater` disposes and clears its cached optimizer states.
   */
  def dispose(): Unit
}
/**
 * Mixin giving an updater a per-key cache of states that can be turned into
 * bytes and restored from them.
 */
trait MXKVStoreCachedStates {
  // State cached per key. Values may be null (e.g. the updater built by
  // Optimizer.getUpdater stores null for states with no serialized bytes).
  protected val states: mutable.HashMap[Int, AnyRef] = mutable.HashMap.empty[Int, AnyRef]

  /**
   * Encodes the cached states as bytes.
   * @return the serialized states
   */
  def serializeState(): Array[Byte]

  /**
   * Restores cached states from bytes.
   * @param bytes output previously produced by serializeState()
   */
  def deserializeState(bytes: Array[Byte]): Unit
}