/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnet
import java.io._
import ml.dmlc.mxnet.Base._
import org.slf4j.{Logger, LoggerFactory}
/**
* Key value store interface of MXNet for parameter synchronization.
* @author Yizhi Liu
*/
object KVStore {
// group id of scheduler/server/worker
val GROUP_NODE_SCHEDULER = 1
val GROUP_NODE_SERVER = 2
val GROUP_NODE_WORKER = 4
/**
* Create a new KVStore. <br />
* <b>
* WARNING: it is your responsibility to clear this object through dispose().
* </b>
*
   * @param name The type of KVStore, either "local" or "dist":
   *             - local: works for multiple devices on a single machine (single process)
   *             - dist: works across multiple machines (multiple processes)
* @return The created KVStore
*/
def create(name: String = "local"): KVStore = {
val handle = new KVStoreHandleRef
checkCall(_LIB.mxKVStoreCreate(name, handle))
new KVStore(handle.value)
}
}
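// Usage sketch (illustrative only, not part of this file): per the warning above, it is
// the caller's responsibility to release the native handle once the store is no longer needed.
//
//   val kv = KVStore.create("local")
//   try {
//     // init / push / pull against kv here
//   } finally {
//     kv.dispose()
//   }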
class KVStore(private[mxnet] val handle: KVStoreHandle) extends WarnIfNotDisposed {
private val logger: Logger = LoggerFactory.getLogger(classOf[KVStore])
private var updaterFunc: MXKVStoreUpdater = null
private var disposed = false
protected def isDisposed = disposed
/**
* Release the native memory.
* The object shall never be used after it is disposed.
*/
def dispose(): Unit = {
if (!disposed) {
_LIB.mxKVStoreFree(handle)
disposed = true
}
}
/**
* Initialize a single or a sequence of key-value pairs into the store.
   * For each key, one must init it before calling push or pull.
   * Only worker 0's (rank == 0) data are used.
   * This function returns after the data have been initialized successfully.
*
* @param keys The keys.
* @param values The values.
*/
def init(keys: Array[String], values: Array[NDArray]): Unit = {
require(keys.length == values.length, "len(keys) != len(values)")
val valuePtrs = values.map(_.handle)
checkCall(_LIB.mxKVStoreInitEx(handle, keys.length, keys, valuePtrs))
}
def init(key: String, value: NDArray): Unit = {
init(Array(key), Array(value))
}
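  // Init sketch (illustrative; the key name and shape below are assumptions):
  //   kv.init("weight", NDArray.zeros(Shape(2, 3)))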
/**
* Push a single or a sequence of key-value pairs into the store.
* Data consistency:
* 1. this function returns after adding an operator to the engine.
* 2. push is always called after all previous push and pull on the same key are finished
   * 3. there is no synchronization between workers. One can use barrier() to sync all workers
*
* @param keys Keys
   * @param values Corresponding values
   * @param priority
   *          The priority of the push operation.
   *          The higher the priority, the more likely this action is
   *          to be executed before other push actions.
*/
def push(keys: Array[String], values: Array[NDArray], priority: Int): Unit = {
require(keys.length == values.length, "len(keys) != len(values)")
val valuePtrs = values.map(_.handle)
checkCall(_LIB.mxKVStorePushEx(handle, keys.length, keys, valuePtrs, priority))
}
def push(keys: Array[String], values: Array[NDArray]): Unit = push(keys, values, 0)
def push(key: String, value: NDArray, priority: Int = 0): Unit = {
push(Array(key), Array(value), priority)
}
def push(key: String, values: Array[NDArray], priority: Int): Unit = {
val keys = Array.fill(values.length)(key)
push(keys, values, priority)
}
def push(key: String, values: Array[NDArray]): Unit = {
push(key, values, 0)
}
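  // Push sketch (illustrative; the key, shape and scaling are assumptions):
  //   kv.push("weight", NDArray.ones(Shape(2, 3)) * 8f)
  //   // values pushed on the same key from multiple devices are aggregated by the store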
/**
* Pull a single value or a sequence of values from the store.
*
* Data consistency:
* 1. this function returns after adding an operator to the engine. But any
* further read on out will be blocked until it is finished.
* 2. pull is always called after all previous push and pull on the same key are finished
* 3. It pulls the newest value from the store.
* @param keys Keys
   * @param outs Corresponding output NDArrays
   * @param priority
   *          The priority of the pull operation.
   *          The higher the priority, the more likely this action is
   *          to be executed before other pull actions.
*/
def pull(keys: Array[String], outs: Array[NDArray], priority: Int): Unit = {
require(keys.length == outs.length, "len(keys) != len(outs)")
val outPtrs = outs.map(_.handle)
checkCall(_LIB.mxKVStorePullEx(handle, keys.length, keys, outPtrs, priority))
}
def pull(keys: Array[String], outs: Array[NDArray]): Unit = pull(keys, outs, 0)
def pull(key: String, out: NDArray, priority: Int = 0): Unit = {
pull(Array(key), Array(out), priority)
}
def pull(key: String, outs: Array[NDArray], priority: Int): Unit = {
val keys = Array.fill(outs.length)(key)
pull(keys, outs, priority)
}
def pull(key: String, outs: Array[NDArray]): Unit = {
pull(key, outs, 0)
}
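  // Pull sketch (illustrative; the key and shape are assumptions):
  //   val out = NDArray.zeros(Shape(2, 3))
  //   kv.pull("weight", out)
  //   // any read of `out` blocks until the pull has finished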
  /** Get the type of this KVStore */
def `type`: String = {
val kvType = new RefString
checkCall(_LIB.mxKVStoreGetType(handle, kvType))
kvType.value
}
/**
* Get the number of worker nodes
* @return The number of worker nodes
*/
def numWorkers: Int = {
val size = new RefInt
checkCall(_LIB.mxKVStoreGetGroupSize(handle, size))
size.value
}
/**
* Get the rank of this worker node
   * @return The rank of this node, which is in [0, numWorkers)
*/
def rank: Int = {
val rank = new RefInt
checkCall(_LIB.mxKVStoreGetRank(handle, rank))
rank.value
}
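  // Rank/size sketch (illustrative): act on a single worker only.
  //   if (kv.rank == 0) {
  //     println(s"training with ${kv.numWorkers} workers")
  //   }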
/**
* Register an optimizer to the store
   * If there are multiple machines, this process (which should be a worker node)
   * will serialize this optimizer and send it to all servers. It returns after
* this action is done.
*
* @param optimizer the optimizer
*/
def setOptimizer(optimizer: Optimizer): Unit = {
val isWorker = new RefInt
checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker))
if (`type`.contains("dist") && isWorker.value != 0) {
val optSerialized = Serializer.getSerializer.serialize(optimizer)
val cmd = Serializer.encodeBase64String(optSerialized)
logger.debug("Send optimizer to server: {}", cmd)
sendCommandToServers(0, cmd)
} else {
setUpdater(Optimizer.getUpdater(optimizer))
}
}
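  // Optimizer sketch (illustrative; the SGD constructor argument is an assumption):
  //   import ml.dmlc.mxnet.optimizer.SGD
  //   kv.setOptimizer(new SGD(learningRate = 0.1f))
  //   // on a "dist" store this serializes the optimizer and ships it to the servers;
  //   // otherwise it installs Optimizer.getUpdater(optimizer) as the local updater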
/**
* Set a push updater into the store.
*
* This function only changes the local store. Use setOptimizer for
   * multi-machine training.
*
* @param updater the updater function
*/
def setUpdater(updater: MXKVStoreUpdater): Unit = {
this.updaterFunc = updater
checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc))
}
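  // Updater sketch (illustrative; assumes an updater derived from an SGD optimizer):
  //   kv.setUpdater(Optimizer.getUpdater(new SGD(learningRate = 0.1f)))
  //   // unlike setOptimizer, this only changes the local store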
/**
* Global barrier among all worker nodes
*
   * For example, assume there are n machines and we want machine 0 to first
   * init the values, after which all machines pull the initialized values.
   * Placing a barrier before the pull guarantees that the initialization has
   * finished.
*/
def barrier(): Unit = {
checkCall(_LIB.mxKVStoreBarrier(handle))
}
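  // Barrier sketch (illustrative; `initialWeight` and `weight` are assumed NDArrays):
  //   kv.init("weight", initialWeight)   // only worker 0's data are used
  //   kv.barrier()                       // wait until initialization is visible everywhere
  //   kv.pull("weight", weight)
  /**
   * Get the number of dead nodes for the given node id
   * (a node group id such as GROUP_NODE_SCHEDULER, GROUP_NODE_SERVER or GROUP_NODE_WORKER).
   * @param nodeId the node (group) id
   * @return the number of dead nodes
   */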
def numDeadNode(nodeId: Int): Int = {
val number = new RefInt
checkCall(_LIB.mxKVStoreGetNumDeadNode(handle, nodeId, number))
number.value
}
/**
   * Whether to perform a barrier when this KVStore finalizes.
   * @param barrierBeforeExit if true, all workers wait at a barrier before exiting
*/
def setBarrierBeforeExit(barrierBeforeExit: Boolean): Unit = {
val flag: Int = if (barrierBeforeExit) 1 else 0
checkCall(_LIB.mxKVStoreSetBarrierBeforeExit(handle, flag))
}
/**
* Send a command to all server nodes
*
* Send a command to all server nodes, which will make each server node run
* KVStoreServer.controller
*
   * This function returns after the command has been executed on all server nodes.
*
* @param head the head of the command
* @param body the body of the command
*/
private def sendCommandToServers(head: Int, body: String): Unit = {
checkCall(_LIB.mxKVStoreSendCommmandToServers(handle, head, body))
}
/**
* Save optimizer (updater) state to file
* @param fname Path to output states file.
*/
def saveOptimizerStates(fname: String): Unit = {
require(updaterFunc != null, "Cannot save states for distributed training")
updaterFunc match {
case cachedStates: MXKVStoreCachedStates =>
val target = new BufferedOutputStream(new FileOutputStream(fname))
try {
target.write(cachedStates.serializeState())
} finally {
target.close()
}
case _ =>
logger.warn("Updater does not have states, skip saving to {}", fname)
}
}
/**
* Load optimizer (updater) state from file
* @param fname Path to input states file.
*/
def loadOptimizerStates(fname: String): Unit = {
assert(updaterFunc != null, "Cannot load states for distributed training")
updaterFunc match {
case cachedStates: MXKVStoreCachedStates =>
        val bis = new BufferedInputStream(new FileInputStream(fname))
        try {
          val bArray = Stream.continually(bis.read).takeWhile(_ != -1).map(_.toByte).toArray
          cachedStates.deserializeState(bArray)
        } finally {
          bis.close()
        }
case _ =>
logger.warn("Updater does not have states, skip loading from {}", fname)
}
}
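  // State save/load sketch (illustrative; assumes the current updater mixes in
  // MXKVStoreCachedStates, e.g. one obtained from Optimizer.getUpdater):
  //   kv.saveOptimizerStates("sgd.states")
  //   // ... later, e.g. when resuming training ...
  //   kv.loadOptimizerStates("sgd.states")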
}