blob: 2e419971401438e1f32dd65281569761554cd62b [file] [log] [blame]
# coding: utf-8
""" Key value store interface of MXNet for parameter synchronization."""
from __future__ import absolute_import
import ctypes
import pickle
from .ndarray import NDArray
from .base import _LIB
from .base import check_call, c_array, c_str, string_types, mx_uint, py_str
from .base import NDArrayHandle, KVStoreHandle
from . import optimizer as opt
def _ctype_key_value(keys, vals):
    """Converts key/value arguments into parallel ctypes arrays. For internal use.

    Accepted shapes:

    * ``keys`` is an int and ``vals`` is a single NDArray: one pair.
    * ``keys`` is an int and ``vals`` is a sequence of NDArray: the key is
      repeated once per value.
    * ``keys`` is a sequence of int: each ``(key, value)`` pair is converted
      recursively using the two rules above, and the results are concatenated.

    Returns
    -------
    tuple
        ``(key_array, handle_array)``: a ``ctypes.c_int`` array of keys and an
        ``NDArrayHandle`` array of the matching value handles, both built
        with ``c_array``.

    Raises
    ------
    AssertionError
        If key/value counts differ, a key in a key sequence is not an int,
        or a value in a value sequence is not an NDArray.
    """
    if not isinstance(keys, int):
        # Sequence of keys: validate everything up front, then flatten.
        assert len(keys) == len(vals)
        for one_key in keys:
            assert isinstance(one_key, int)
        flat_keys = []
        flat_handles = []
        for one_key, one_val in zip(keys, vals):
            sub_keys, sub_handles = _ctype_key_value(one_key, one_val)
            # Iterating a ctypes array yields its elements as plain values.
            flat_keys.extend(sub_keys)
            flat_handles.extend(sub_handles)
        return (c_array(ctypes.c_int, flat_keys),
                c_array(NDArrayHandle, flat_handles))

    if isinstance(vals, NDArray):
        return (c_array(ctypes.c_int, [keys]),
                c_array(NDArrayHandle, [vals.handle]))

    # One key shared by several values (e.g. one copy per device).
    for item in vals:
        assert isinstance(item, NDArray)
    return (c_array(ctypes.c_int, [keys] * len(vals)),
            c_array(NDArrayHandle, [item.handle for item in vals]))
def _updater_wrapper(updater):
"""A wrapper for the user-defined handle."""
def updater_handle(key, lhs_handle, rhs_handle, _):
""" ctypes function """
lhs = NDArray(NDArrayHandle(lhs_handle))
rhs = NDArray(NDArrayHandle(rhs_handle))
updater(key, lhs, rhs)
return updater_handle
class KVStore(object):
    """A key-value store for synchronization of values, over multiple devices."""

    def __init__(self, handle):
        """Initializes a new KVStore.

        Parameters
        ----------
        handle : KVStoreHandle
            `KVStore` handle of C API.
        """
        assert isinstance(handle, KVStoreHandle)
        self.handle = handle
        # ``_set_updater`` stores both the Python updater and the ctypes
        # callback built from it. Keeping the callback on the instance keeps
        # it referenced while the C library holds it; ctypes requires the
        # caller to keep such callbacks alive.
        self._updater = None
        self._updater_func = None

    def __del__(self):
        """Frees the underlying C handle."""
        # ``handle`` is absent when ``__init__`` failed before assigning it;
        # there is nothing to free then, and touching it would raise.
        handle = getattr(self, 'handle', None)
        if handle is not None:
            check_call(_LIB.MXKVStoreFree(handle))

    def init(self, key, value):
        """Initializes a single or a sequence of key-value pairs into the store.

        For each key, one must `init` it before calling `push` or `pull`.
        When multiple workers invoke `init` for the same key, only
        the value supplied by the worker with rank `0` is used. This function
        returns after the data has been initialized successfully.

        Parameters
        ----------
        key : int or sequence of int
            The keys.
        value : NDArray or sequence of NDArray
            Values corresponding to the keys.

        Examples
        --------
        >>> # init a single key-value pair
        >>> shape = (2,3)
        >>> kv = mx.kv.create('local')
        >>> kv.init(3, mx.nd.ones(shape)*2)
        >>> a = mx.nd.zeros(shape)
        >>> kv.pull(3, out=a)
        >>> print(a.asnumpy())
        [[ 2.  2.  2.]
        [ 2.  2.  2.]]
        >>> # init a list of key-value pairs
        >>> keys = [5, 7, 9]
        >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
        """
        ckeys, cvals = _ctype_key_value(key, value)
        check_call(_LIB.MXKVStoreInit(
            self.handle, mx_uint(len(ckeys)), ckeys, cvals))

    def push(self, key, value, priority=0):
        """Pushes a single or a sequence of key-value pairs into the store.

        This function returns immediately after adding an operator to the
        engine. The actual operation is executed asynchronously after all
        previous `push` and `pull` calls for the same input key(s) are
        finished. There is no synchronization between workers; use
        ``_barrier()`` to sync all workers.

        Parameters
        ----------
        key : int or list of int
            Keys.
        value : NDArray or list of NDArray or list of list of NDArray
            Values corresponding to the keys.
        priority : int, optional
            The priority of the push operation.
            Higher priority push operations are likely to be executed before
            other push actions.

        Examples
        --------
        >>> # push a single key-value pair
        >>> kv.push(3, mx.nd.ones(shape)*8)
        >>> kv.pull(3, out=a) # pull out the value
        >>> print(a.asnumpy())
        [[ 8.  8.  8.]
        [ 8.  8.  8.]]
        >>> # aggregate the value and the push
        >>> gpus = [mx.gpu(i) for i in range(4)]
        >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
        >>> kv.push(3, b)
        >>> kv.pull(3, out=a)
        >>> print(a.asnumpy())
        [[ 4.  4.  4.]
        [ 4.  4.  4.]]
        >>> # push a list of keys.
        >>> # single device
        >>> kv.push(keys, [mx.nd.ones(shape)]*len(keys))
        >>> b = [mx.nd.zeros(shape)]*len(keys)
        >>> kv.pull(keys, out=b)
        >>> print(b[1].asnumpy())
        [[ 1.  1.  1.]
        [ 1.  1.  1.]]
        >>> # multiple devices:
        >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
        >>> kv.push(keys, b)
        >>> kv.pull(keys, out=b)
        >>> print(b[1][1].asnumpy())
        [[ 4.  4.  4.]
        [ 4.  4.  4.]]
        """
        ckeys, cvals = _ctype_key_value(key, value)
        check_call(_LIB.MXKVStorePush(
            self.handle, mx_uint(len(ckeys)), ckeys, cvals,
            ctypes.c_int(priority)))

    def pull(self, key, out=None, priority=0):
        """Pulls a single value or a sequence of values from the store.

        This function returns immediately after adding an operator to the
        engine. Subsequent attempts to read from the `out` variable will be
        blocked until the pull operation completes.

        `pull` is executed asynchronously after all previous `push` and `pull`
        calls for the same input key(s) are finished.

        The returned values are guaranteed to be the latest values in the
        store.

        Parameters
        ----------
        key : int or list of int
            Keys.
        out : NDArray or list of NDArray or list of list of NDArray
            Values corresponding to the keys. Required.
        priority : int, optional
            The priority of the pull operation.
            Higher priority pull operations are likely to be executed before
            other pull actions.

        Raises
        ------
        AssertionError
            If `out` is not given.

        Examples
        --------
        >>> # pull a single key-value pair
        >>> a = mx.nd.zeros(shape)
        >>> kv.pull(3, out=a)
        >>> print(a.asnumpy())
        [[ 2.  2.  2.]
        [ 2.  2.  2.]]
        >>> # pull into multiple devices
        >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
        >>> kv.pull(3, out=b)
        >>> print(b[1].asnumpy())
        [[ 2.  2.  2.]
        [ 2.  2.  2.]]
        >>> # pull a list of key-value pairs.
        >>> # On single device
        >>> keys = [5, 7, 9]
        >>> b = [mx.nd.zeros(shape)]*len(keys)
        >>> kv.pull(keys, out=b)
        >>> print(b[1].asnumpy())
        [[ 2.  2.  2.]
        [ 2.  2.  2.]]
        >>> # On multiple devices
        >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
        >>> kv.pull(keys, out=b)
        >>> print(b[1][1].asnumpy())
        [[ 2.  2.  2.]
        [ 2.  2.  2.]]
        """
        assert out is not None
        ckeys, cvals = _ctype_key_value(key, out)
        check_call(_LIB.MXKVStorePull(
            self.handle, mx_uint(len(ckeys)), ckeys, cvals,
            ctypes.c_int(priority)))

    def set_optimizer(self, optimizer):
        """Registers an optimizer with the store.

        When the store type contains ``'dist'`` and this process is a worker
        node, the optimizer is pickled and sent to all servers; this returns
        after that action is done. Otherwise, a local updater is built with
        ``opt.get_updater(optimizer)`` and installed via ``_set_updater``.

        Parameters
        ----------
        optimizer : Optimizer
            The optimizer.
        """
        is_worker = ctypes.c_int()
        check_call(_LIB.MXKVStoreIsWorkerNode(ctypes.byref(is_worker)))
        if 'dist' in self.type and is_worker.value:
            # Pickle protocol 0 is text-based: possibly slower, but not a big
            # deal for a one-off message. ``pickle.dumps`` returns ``bytes``
            # on Python 3, so decode it with ``py_str`` before it is wrapped
            # by ``c_str`` inside ``_send_command_to_servers``.
            # NOTE(review): assumes ``c_str`` expects a text string on
            # Python 3; confirm against ``base.c_str``.
            optim_str = py_str(pickle.dumps(optimizer, 0))
            self._send_command_to_servers(0, optim_str)
        else:
            self._set_updater(opt.get_updater(optimizer))

    @property
    def type(self):
        """Returns the type of this kvstore.

        Returns
        -------
        type : str
            The string type.
        """
        kv_type = ctypes.c_char_p()
        check_call(_LIB.MXKVStoreGetType(self.handle, ctypes.byref(kv_type)))
        return py_str(kv_type.value)

    @property
    def rank(self):
        """Returns the rank of this worker node.

        Returns
        -------
        rank : int
            The rank of this node, which is in range [0, num_workers).
        """
        rank = ctypes.c_int()
        check_call(_LIB.MXKVStoreGetRank(self.handle, ctypes.byref(rank)))
        return rank.value

    @property
    def num_workers(self):
        """Returns the number of worker nodes.

        Returns
        -------
        size : int
            The number of worker nodes.
        """
        size = ctypes.c_int()
        check_call(_LIB.MXKVStoreGetGroupSize(self.handle, ctypes.byref(size)))
        return size.value

    def save_optimizer_states(self, fname):
        """Saves optimizer (updater) state to file.

        Only available when a local updater is installed (see
        ``set_optimizer``).

        Parameters
        ----------
        fname : str
            Path to output states file.

        Raises
        ------
        AssertionError
            If no local updater is installed.
        """
        assert self._updater is not None, "Cannot save states for distributed training"
        with open(fname, 'wb') as fout:
            fout.write(self._updater.get_states())

    def load_optimizer_states(self, fname):
        """Loads optimizer (updater) state from file.

        Only available when a local updater is installed (see
        ``set_optimizer``).

        Parameters
        ----------
        fname : str
            Path to input states file.

        Raises
        ------
        AssertionError
            If no local updater is installed.
        """
        assert self._updater is not None, "Cannot load states for distributed training"
        # Read inside ``with`` so the file is closed even if reading fails.
        with open(fname, 'rb') as fin:
            states = fin.read()
        self._updater.set_states(states)

    def _set_updater(self, updater):
        """Sets a push updater into the store.

        This function only changes the local store. When running on multiple
        machines one must use `set_optimizer`.

        Parameters
        ----------
        updater : function
            The updater function, called as ``updater(key, input, stored)``.

        Examples
        --------
        >>> def update(key, input, stored):
        ...     print("update on key: %d" % key)
        ...     stored += input * 2
        >>> kv._set_updater(update)
        >>> kv.pull(3, out=a)
        >>> print(a.asnumpy())
        [[ 4.  4.  4.]
        [ 4.  4.  4.]]
        >>> kv.push(3, mx.nd.ones(shape))
        update on key: 3
        >>> kv.pull(3, out=a)
        >>> print(a.asnumpy())
        [[ 6.  6.  6.]
        [ 6.  6.  6.]]
        """
        self._updater = updater
        _updater_proto = ctypes.CFUNCTYPE(
            None, ctypes.c_int, NDArrayHandle, NDArrayHandle, ctypes.c_void_p)
        # Keep the ctypes callback on ``self`` so it is not garbage-collected
        # while the C side still holds a pointer to it.
        self._updater_func = _updater_proto(_updater_wrapper(updater))
        check_call(_LIB.MXKVStoreSetUpdater(self.handle, self._updater_func, None))

    def _barrier(self):
        """Invokes a global barrier among all worker nodes.

        For example, assume there are `n` machines. We would like machine `0`
        to first `init` the values and then have all the workers `pull` the
        initialized value. Before pulling, we can invoke `_barrier()` to
        guarantee that the initialization is finished.
        """
        check_call(_LIB.MXKVStoreBarrier(self.handle))

    def _send_command_to_servers(self, head, body):
        """Sends a command to all server nodes.

        Sending a command to a server node will cause that server node to
        invoke ``KVStoreServer.controller`` to execute the command.

        This function returns after the command has been executed on all
        server nodes.

        Parameters
        ----------
        head : int
            The head of the command.
        body : str
            The body of the command.
        """
        # NOTE(review): the triple "m" in the name below presumably matches the
        # exported C symbol; do not "fix" it without checking the C API.
        check_call(_LIB.MXKVStoreSendCommmandToServers(
            self.handle, mx_uint(head), c_str(body)))
def create(name='local'):
    """Creates a new KVStore.

    Parameters
    ----------
    name : str
        The type of KVStore, for example:

        - ``'local'`` works for multiple devices on a single machine
          (single process).
        - ``'dist'`` works for multiple machines (multiple processes).

    Returns
    -------
    kv : KVStore
        The created KVStore.

    Raises
    ------
    TypeError
        If `name` is not a string.
    """
    if isinstance(name, string_types):
        new_handle = KVStoreHandle()
        check_call(_LIB.MXKVStoreCreate(c_str(name), ctypes.byref(new_handle)))
        return KVStore(new_handle)
    raise TypeError('name must be a string')