// blob: 2f79b58a52cdee7707f6082ccae1c6ed26d430ea [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mxnet
import org.apache.mxnet.Base._
import org.slf4j.{Logger, LoggerFactory}
import scala.collection.mutable.ArrayBuffer
object Executor {
  /**
   * Build a name -> NDArray dictionary from parallel sequences.
   * @param names    names used as keys; must not contain duplicates
   * @param ndarrays arrays paired with `names` by position
   * @return map from each name to the array at the same position
   * @throws IllegalArgumentException if `names` contains duplicates
   */
  private[mxnet] def getDict(names: Seq[String],
                             ndarrays: Seq[NDArray]): Map[String, NDArray] = {
    val allUnique = names.distinct.length == names.length
    require(allUnique, "Duplicate names detected")
    names.zip(ndarrays).toMap
  }
}
/**
 * Symbolic Executor component of MXNet <br />
 * <b>
 * WARNING: it is your responsibility to clear this object through dispose().
 * </b>
 *
 * Constructor: please use Symbol.bind and Symbol.simpleBind instead.
 *
 * @author Yizhi Liu
 *
 * @param handle ExecutorHandle generated by calling Bind
 * @param symbol Symbol this executor uses to list argument / auxiliary state names,
 *               to infer shapes and to re-bind in `reshape`
 * @see Symbol.bind : to create executor
 */
class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
                              private[mxnet] val symbol: Symbol) extends WarnIfNotDisposed {
  // Argument, gradient and auxiliary state arrays. They start out as null and are
  // assigned from outside this class (presumably by the binding code -- TODO confirm).
  private[mxnet] var argArrays: Array[NDArray] = null
  private[mxnet] var gradArrays: Array[NDArray] = null
  private[mxnet] var auxArrays: Array[NDArray] = null
  // Output arrays, fetched once from the native executor at construction time.
  val outputs: Array[NDArray] = getOutputs
  // Lazily built caches behind argDict / gradDict / auxDict.
  protected var _argDict: Map[String, NDArray] = null
  protected var _gradDict: Map[String, NDArray] = null
  protected var _auxDict: Map[String, NDArray] = null
  // Kept in a field so the callback installed by setMonitorCallback stays referenced.
  protected var monitorCallback: MXMonitorCallback = null
  // Binding parameters, replayed by reshape() when it binds the new executor.
  private[mxnet] var _ctx: Context = null
  private[mxnet] var _gradsReq: Iterable[_] = null
  private[mxnet] var _group2ctx: Map[String, Context] = null

  private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])

  private var disposed = false
  protected def isDisposed: Boolean = disposed

  /**
   * Dispose every output array and free the native executor handle.
   * Calling it again after the first successful call is a no-op.
   */
  def dispose(): Unit = {
    if (!disposed) {
      outputs.foreach(_.dispose())
      _LIB.mxExecutorFree(handle)
      disposed = true
    }
  }

  /**
   * Return a new executor with the same symbol and shared memory,
   * but different input/output shapes.
   * For runtime reshaping, variable length sequences, etc.
   * The returned executor shares state with the current one,
   * and cannot be used in parallel with it.
   * @param partialShaping Whether to allow changing the shape of unspecified arguments.
   * @param allowUpSizing Whether to allow allocating new ndarrays that are larger than
   *                      the original.
   * @param kwargs Map of string to Shape.
   *                - new shape for arguments.
   * @return A new executor that shares memory with this.
   * @throws IllegalArgumentException if the shapes cannot be inferred from `kwargs`, or if
   *                                  an array has to grow while `allowUpSizing` is false
   * @throws AssertionError if the shape of an array not named in `kwargs` changes while
   *                        `partialShaping` is false
   */
  def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
              kwargs: Map[String, Shape]): Executor = {
    val (argShapes, _, auxShapes) = symbol.inferShape(kwargs)
    require(argShapes != null, "Insufficient argument shapes provided.")

    var newArgDict = Map[String, NDArray]()
    var newGradDict = Map[String, NDArray]()
    symbol.listArguments().zipWithIndex.foreach { case (name, i) =>
      val newShape = argShapes(i)
      val arr = argArrays(i)
      val dArr = if (gradArrays == null) null else gradArrays(i)
      // Arguments named in kwargs may always change shape.
      val allocate = needsAllocation("arg", name, arr, newShape,
        partialShaping || kwargs.contains(name), allowUpSizing)
      newArgDict = newArgDict + (name -> resized(arr, newShape, allocate))
      // The gradient array follows the decision taken for its argument array.
      if (dArr != null) {
        newGradDict = newGradDict + (name -> resized(dArr, newShape, allocate))
      }
    }

    var newAuxDict = Map[String, NDArray]()
    val zip3 = (symbol.listAuxiliaryStates(), auxShapes, auxArrays).zipped
    zip3.foreach { case (name, newShape, arr) =>
      val allocate = needsAllocation("aux", name, arr, newShape, partialShaping, allowUpSizing)
      newAuxDict = newAuxDict + (name -> resized(arr, newShape, allocate))
    }

    // _gradsReq is stored untyped; pick the bind overload matching its runtime type.
    // Anything that is not a Seq (including null) is handed to the Map overload.
    _gradsReq match {
      case reqs: Seq[_] =>
        symbol.bind(_ctx,
          newArgDict,
          newGradDict,
          reqs.asInstanceOf[Seq[String]],
          newAuxDict,
          _group2ctx,
          this)
      case reqs =>
        symbol.bind(_ctx,
          newArgDict,
          newGradDict,
          reqs.asInstanceOf[Map[String, String]],
          newAuxDict,
          _group2ctx,
          this)
    }
  }

  /**
   * Validate the reshape of one array and report whether a new, larger array is needed.
   * @param kind "arg" or "aux"; only used in error messages
   * @param name name of the array; only used in error messages
   * @param arr the existing array
   * @param newShape the shape inferred for the new executor
   * @param shapeChangeAllowed whether `newShape` may differ from the current shape
   * @param allowUpSizing whether a shape with more elements than `arr` is acceptable
   * @return true if `newShape` has more elements than `arr`
   */
  private def needsAllocation(kind: String, name: String, arr: NDArray, newShape: Shape,
                              shapeChangeAllowed: Boolean, allowUpSizing: Boolean): Boolean = {
    if (!shapeChangeAllowed && !newShape.equals(arr.shape)) {
      throw new AssertionError(s"Shape of unspecified array $kind:$name changed. " +
        "This can cause the new executor to not share parameters " +
        "with the old one. Please check for error in network. " +
        "If this is intended, set partialShaping = true to suppress this warning.")
    }
    val grows = newShape.product > arr.shape.product
    require(!grows || allowUpSizing, s"New shape of $kind:$name larger than original. " +
      "First making a big executor and then down sizing it " +
      "is more efficient than the reverse. " +
      "If you really want to up size, set allowUpSizing = true " +
      "to enable allocation of new arrays.")
    grows
  }

  /**
   * Produce the array used by the reshaped executor: a newly created array on the
   * context of `arr` when `allocate` is set, otherwise `arr` reshaped to `newShape`.
   */
  private def resized(arr: NDArray, newShape: Shape, allocate: Boolean): NDArray = {
    if (allocate) NDArray.empty(newShape, arr.context) else arr.reshape(newShape.toArray)
  }

  /**
   * list all the output ndarray
   * @return A list of ndarray binded to the heads of executor.
   */
  private def getOutputs: Array[NDArray] = {
    val ndHandles = ArrayBuffer[NDArrayHandle]()
    checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
    ndHandles.toArray.map(new NDArray(_))
  }

  /**
   * Calculate the outputs specified by the binded symbol.
   * @param isTrain whether this forward pass is run in training mode.
   * @param kwargs Additional specification of input arguments: each array is copied
   *               into the argument array of the same name before running.
   * @throws IllegalArgumentException if a name in `kwargs` is not an argument name
   */
  def forward(isTrain: Boolean, kwargs: (String, NDArray)*): Unit = {
    kwargs.foreach { case (name, array) =>
      require(argDict.contains(name), s"Unknown argument $name")
      array.copyTo(argDict(name))
    }
    checkCall(_LIB.mxExecutorForward(handle, if (isTrain) 1 else 0))
  }

  /** Run a forward pass with `isTrain = false` and no extra inputs. */
  def forward(): Unit = {
    forward(isTrain = false)
  }

  /**
   * Do backward pass to get the gradient of arguments.
   * @param outGrads Gradient on the outputs to be propagated back.
   *                 This parameter is only needed when bind is called
   *                 on outputs that are not a loss function.
   * @throws IllegalArgumentException if `outGrads` is null
   */
  def backward(outGrads: Array[NDArray]): Unit = {
    require(outGrads != null)
    val ndArrayPtrs = outGrads.map(_.handle)
    checkCall(_LIB.mxExecutorBackward(handle, ndArrayPtrs))
  }

  /**
   * Do backward pass with a single output gradient.
   * @param outGrad Gradient on the output to be propagated back.
   * @throws IllegalArgumentException if `outGrad` is null
   */
  def backward(outGrad: NDArray): Unit = {
    require(outGrad != null)
    backward(Array(outGrad))
  }

  /** Do backward pass without supplying any output gradients. */
  def backward(): Unit = {
    backward(Array.empty[NDArray])
  }

  /**
   * Install callback.
   * @param callback Takes a string and an NDArrayHandle.
   */
  def setMonitorCallback(callback: MXMonitorCallback): Unit = {
    monitorCallback = callback
    checkCall(_LIB.mxExecutorSetMonitorCallback(handle, monitorCallback))
  }

  /**
   * Get dictionary representation of argument arrays.
   * The dictionary is built on first access and cached afterwards.
   * @return The dictionary that maps name of arguments to NDArrays.
   * @throws IllegalArgumentException if there are duplicated names in the arguments.
   */
  def argDict: Map[String, NDArray] = {
    if (_argDict == null) {
      _argDict = Executor.getDict(symbol.listArguments(), argArrays)
    }
    _argDict
  }

  /**
   * Get dictionary representation of gradient arrays.
   * The dictionary is built on first access and cached afterwards.
   * NOTE(review): gradArrays may still be null when this is called (reshape guards
   * against that case); building the dictionary would presumably fail then -- confirm.
   * @return The dictionary that maps name of arguments to gradient arrays.
   * @throws IllegalArgumentException if there are duplicated names in the grads.
   */
  def gradDict: Map[String, NDArray] = {
    if (_gradDict == null) {
      _gradDict = Executor.getDict(symbol.listArguments(), gradArrays)
    }
    _gradDict
  }

  /**
   * Get dictionary representation of auxiliary states arrays.
   * The dictionary is built on first access and cached afterwards.
   * @return The dictionary that maps name of auxiliary states to NDArrays.
   * @throws IllegalArgumentException if there are duplicated names in the auxiliary states.
   */
  def auxDict: Map[String, NDArray] = {
    if (_auxDict == null) {
      _auxDict = Executor.getDict(symbol.listAuxiliaryStates(), auxArrays)
    }
    _auxDict
  }

  /**
   * Copy parameters from argParams, auxParams into executor's internal array.
   * @param argParams : dict of name to NDArray of arguments
   * @param auxParams : dict of name to NDArray of auxiliary states; skipped when null.
   * @param allowExtraParams
   *        Whether allow extra parameters that are not needed by symbol
   *        If this is true, no error will be thrown when argParams or auxParams
   *        contain extra parameters that is not needed by the executor.
   * @throws IllegalArgumentException
   *         If there is additional parameters in the dict but allowExtraParams = false
   */
  def copyParamsFrom(argParams: Map[String, NDArray],
                     auxParams: Map[String, NDArray],
                     allowExtraParams: Boolean = false): Unit = {
    copyInto(argParams, argDict, "arguments", allowExtraParams)
    if (auxParams != null) {
      copyInto(auxParams, auxDict, "auxiliary states", allowExtraParams)
    }
  }

  /** Copy argument parameters only; see the three-argument overload. */
  def copyParamsFrom(argParams: Map[String, NDArray], allowExtraParams: Boolean): Unit = {
    copyParamsFrom(argParams, null, allowExtraParams)
  }

  /** Copy argument parameters only, rejecting names unknown to the executor. */
  def copyParamsFrom(argParams: Map[String, NDArray]): Unit = {
    copyParamsFrom(argParams, allowExtraParams = false)
  }

  /**
   * Copy every array of `params` into the array of the same name in `target`.
   * @param params source arrays keyed by name
   * @param target destination dictionary; passed by name so that it is only built
   *               when `params` actually has entries
   * @param kind description of `target` used in the error message
   * @param allowExtraParams whether names missing from `target` are silently skipped
   */
  private def copyInto(params: Map[String, NDArray], target: => Map[String, NDArray],
                       kind: String, allowExtraParams: Boolean): Unit = {
    params.foreach { case (name, array) =>
      target.get(name) match {
        case Some(destination) => array.copyTo(destination)
        case None => require(allowExtraParams, s"Find name $name that is not in the $kind")
      }
    }
  }

  /**
   * Get a debug string about internal execution plan.
   * @return Debug string of the executor.
   */
  def debugStr: String = {
    val str = new RefString
    checkCall(_LIB.mxExecutorPrint(handle, str))
    str.value
  }
}