| # coding: utf-8 |
| """ Key value store interface of MXNet for parameter synchronization.""" |
| from __future__ import absolute_import |
| |
| import ctypes |
| import pickle |
| from .ndarray import NDArray |
| from .base import _LIB |
| from .base import check_call, c_array, c_str, string_types, mx_uint, py_str |
| from .base import NDArrayHandle, KVStoreHandle |
| from . import optimizer as opt |
| |
def _ctype_key_value(keys, vals):
    """Build the ctypes key/value arrays passed to the KVStore C API.

    Internal helper.

    Parameters
    ----------
    keys : int or sequence of int
        A single key, or one key per entry of ``vals``.
    vals : NDArray or sequence of NDArray or sequence of sequences of NDArray
        For a single key: one NDArray, or several NDArrays that all share
        that key.  For a sequence of keys: one such entry per key.

    Returns
    -------
    tuple
        ``(c_keys, c_vals)`` as produced by ``c_array``: the keys as
        ``ctypes.c_int`` and the values as ``NDArrayHandle``.  A key is
        repeated once for every NDArray associated with it.
    """
    if not isinstance(keys, int):
        # Several keys: expand every (key, value) pair through the
        # single-key path and concatenate the pieces.
        assert(len(keys) == len(vals))
        for single_key in keys:
            assert(isinstance(single_key, int))
        flat_keys = []
        flat_handles = []
        for single_key, single_val in zip(keys, vals):
            key_arr, handle_arr = _ctype_key_value(single_key, single_val)
            flat_keys.extend(key_arr)
            flat_handles.extend(handle_arr)
        return (c_array(ctypes.c_int, flat_keys),
                c_array(NDArrayHandle, flat_handles))

    if isinstance(vals, NDArray):
        # One key, one value.
        return (c_array(ctypes.c_int, [keys]),
                c_array(NDArrayHandle, [vals.handle]))

    # One key shared by several values.
    for candidate in vals:
        assert(isinstance(candidate, NDArray))
    repeated_keys = [keys] * len(vals)
    handles = [item.handle for item in vals]
    return (c_array(ctypes.c_int, repeated_keys),
            c_array(NDArrayHandle, handles))
| |
| |
| def _updater_wrapper(updater): |
| """A wrapper for the user-defined handle.""" |
| def updater_handle(key, lhs_handle, rhs_handle, _): |
| """ ctypes function """ |
| lhs = NDArray(NDArrayHandle(lhs_handle)) |
| rhs = NDArray(NDArrayHandle(rhs_handle)) |
| updater(key, lhs, rhs) |
| return updater_handle |
| |
| |
class KVStore(object):
    """A key-value store for synchronization of values, over multiple devices."""

    def __init__(self, handle):
        """Initialize a new KVStore.

        Parameters
        ----------
        handle : KVStoreHandle
            KVStore handle of C API.
        """
        assert isinstance(handle, KVStoreHandle)
        self.handle = handle
        # Python-level updater; stays None until `_set_updater` is called.
        self._updater = None
        # ctypes callback object built from the updater.  It is kept as an
        # attribute so it stays referenced after being handed to the C library.
        self._updater_func = None

    def __del__(self):
        """Release the underlying C handle."""
        check_call(_LIB.MXKVStoreFree(self.handle))

    def init(self, key, value):
        """ Initialize a single or a sequence of key-value pairs into the store.

        For each key, one must init it before push and pull.

        Only worker 0's (rank == 0) data are used.

        This function returns after data have been initialized successfully.

        Parameters
        ----------
        key : int or sequence of int
            The keys.
        value : NDArray or sequence of NDArray
            The values.

        Examples
        --------
        >>> # init a single key-value pair
        >>> shape = (2,3)
        >>> kv = mx.kv.create('local')
        >>> kv.init(3, mx.nd.ones(shape)*2)
        >>> a = mx.nd.zeros(shape)
        >>> kv.pull(3, out=a)
        >>> print(a.asnumpy())
        [[ 2.  2.  2.]
         [ 2.  2.  2.]]

        >>> # init a list of key-value pairs
        >>> keys = [5, 7, 9]
        >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
        """
        ckeys, cvals = _ctype_key_value(key, value)
        check_call(_LIB.MXKVStoreInit(
            self.handle, mx_uint(len(ckeys)), ckeys, cvals))

    def push(self, key, value, priority=0):
        """ Push a single or a sequence of key-value pairs into the store.

        Data consistency:

        1. This function returns after adding an operator to the engine.

        2. ``push`` is always called after all previous push and pull on the same
           key are finished.

        3. There is no synchronization between workers. One can use _barrier()
           to sync all workers.

        Parameters
        ----------
        key : int or list of int
            Keys.

        value : NDArray or list of NDArray or list of list of NDArray
            The values corresponding to the keys.

        priority : int, optional
            The priority of the push operation.
            The higher the priority, the faster this action is likely
            to be executed before other push actions.

        Examples
        --------
        >>> # push a single key-value pair
        >>> kv.push(3, mx.nd.ones(shape)*8)
        >>> kv.pull(3, out=a) # pull out the value
        >>> print(a.asnumpy())
        [[ 8.  8.  8.]
         [ 8.  8.  8.]]

        >>> # aggregate the value and the push
        >>> gpus = [mx.gpu(i) for i in range(4)]
        >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
        >>> kv.push(3, b)
        >>> kv.pull(3, out=a)
        >>> print(a.asnumpy())
        [[ 4.  4.  4.]
         [ 4.  4.  4.]]

        >>> # push a list of keys.
        >>> # single device
        >>> kv.push(keys, [mx.nd.ones(shape)]*len(keys))
        >>> b = [mx.nd.zeros(shape)]*len(keys)
        >>> kv.pull(keys, out=b)
        >>> print(b[1].asnumpy())
        [[ 1.  1.  1.]
         [ 1.  1.  1.]]

        >>> # multiple devices:
        >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
        >>> kv.push(keys, b)
        >>> kv.pull(keys, out=b)
        >>> print(b[1][1].asnumpy())
        [[ 4.  4.  4.]
         [ 4.  4.  4.]]
        """
        ckeys, cvals = _ctype_key_value(key, value)
        check_call(_LIB.MXKVStorePush(
            self.handle, mx_uint(len(ckeys)), ckeys, cvals,
            ctypes.c_int(priority)))

    def pull(self, key, out=None, priority=0):
        """ Pull a single value or a sequence of values from the store.

        Data consistency:

        1. This function returns after adding an operator to the engine. But any
           further read on out will be blocked until it is finished.

        2. ``pull`` is always called after all previous push and pull on the same
           key are finished.

        3. It pulls the newest value from the store.

        Parameters
        ----------
        key : int or list of int
            Keys.

        out : NDArray or list of NDArray or list of lists of NDArrays
            Destination arrays corresponding to the keys. Required: an
            AssertionError is raised when it is left as None.

        priority : int, optional
            The priority of the pull operation.
            The higher the priority, the faster this action is likely
            to be executed before other pull actions.

        Examples
        --------
        >>> # pull a single key-value pair
        >>> a = mx.nd.zeros(shape)
        >>> kv.pull(3, out=a)
        >>> print(a.asnumpy())
        [[ 2.  2.  2.]
         [ 2.  2.  2.]]

        >>> # pull into multiple devices
        >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
        >>> kv.pull(3, out=b)
        >>> print(b[1].asnumpy())
        [[ 2.  2.  2.]
         [ 2.  2.  2.]]

        >>> # pull a list of key-value pairs.
        >>> # On single device
        >>> keys = [5, 7, 9]
        >>> b = [mx.nd.zeros(shape)]*len(keys)
        >>> kv.pull(keys, out=b)
        >>> print(b[1].asnumpy())
        [[ 2.  2.  2.]
         [ 2.  2.  2.]]
        >>> # On multiple devices
        >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
        >>> kv.pull(keys, out=b)
        >>> print(b[1][1].asnumpy())
        [[ 2.  2.  2.]
         [ 2.  2.  2.]]
        """
        assert(out is not None)
        ckeys, cvals = _ctype_key_value(key, out)
        check_call(_LIB.MXKVStorePull(
            self.handle, mx_uint(len(ckeys)), ckeys, cvals,
            ctypes.c_int(priority)))

    def set_optimizer(self, optimizer):
        """Register an optimizer to the store

        If there are multiple machines, this process (should be a worker node)
        will pack this optimizer and send it to all servers. It returns after
        this action is done.

        Otherwise the optimizer is turned into an updater and installed on
        this store through ``_set_updater``.

        Parameters
        ----------
        optimizer : Optimizer
            the optimizer
        """
        is_worker = ctypes.c_int()
        check_call(_LIB.MXKVStoreIsWorkerNode(ctypes.byref(is_worker)))

        if 'dist' in self.type and is_worker.value:
            # Send the optimizer to the servers.  ASCII protocol 0 is used;
            # it might be slower, but that is not a big deal here.  Any
            # pickling error simply propagates to the caller.
            optim_str = pickle.dumps(optimizer, 0)
            self._send_command_to_servers(0, optim_str)
        else:
            self._set_updater(opt.get_updater(optimizer))

    @property
    def type(self):
        """Get the type of this kvstore.

        Returns
        -------
        type : str
            the string type
        """
        kv_type = ctypes.c_char_p()
        check_call(_LIB.MXKVStoreGetType(self.handle, ctypes.byref(kv_type)))
        return py_str(kv_type.value)

    @property
    def rank(self):
        """Get the rank of this worker node.

        Returns
        -------
        rank : int
            The rank of this node, which is in [0, num_workers)
        """
        rank = ctypes.c_int()
        check_call(_LIB.MXKVStoreGetRank(self.handle, ctypes.byref(rank)))
        return rank.value

    @property
    def num_workers(self):
        """Get the number of worker nodes.

        Returns
        -------
        size : int
            The number of worker nodes.
        """
        size = ctypes.c_int()
        check_call(_LIB.MXKVStoreGetGroupSize(self.handle, ctypes.byref(size)))
        return size.value

    def save_optimizer_states(self, fname):
        """Save optimizer (updater) state to file.

        Requires a local updater: ``set_optimizer`` does not install one on
        a worker of a 'dist' store, and an AssertionError is raised then.

        Parameters
        ----------
        fname : str
            Path to output states file.
        """
        assert self._updater is not None, "Cannot save states for distributed training"
        with open(fname, 'wb') as fout:
            fout.write(self._updater.get_states())

    def load_optimizer_states(self, fname):
        """Load optimizer (updater) state from file.

        Requires a local updater: ``set_optimizer`` does not install one on
        a worker of a 'dist' store, and an AssertionError is raised then.

        Parameters
        ----------
        fname : str
            Path to input states file.
        """
        assert self._updater is not None, "Cannot load states for distributed training"
        # Read through a context manager so the file is closed even when
        # reading or `set_states` raises.
        with open(fname, 'rb') as fin:
            self._updater.set_states(fin.read())

    def _set_updater(self, updater):
        """Set a push updater into the store.

        This function only changes the local store. Use set_optimizer for
        multiple machines.

        Parameters
        ----------
        updater : function
            The updater function.

        Examples
        --------
        >>> def update(key, input, stored):
        ...     print("update on key: %d" % key)
        ...     stored += input * 2
        >>> kv._set_updater(update)
        >>> kv.pull(3, out=a)
        >>> print(a.asnumpy())
        [[ 4.  4.  4.]
         [ 4.  4.  4.]]
        >>> kv.push(3, mx.nd.ones(shape))
        update on key: 3
        >>> kv.pull(3, out=a)
        >>> print(a.asnumpy())
        [[ 6.  6.  6.]
         [ 6.  6.  6.]]
        """
        self._updater = updater
        # C callback signature: void (int, NDArrayHandle, NDArrayHandle, void*)
        _updater_proto = ctypes.CFUNCTYPE(
            None, ctypes.c_int, NDArrayHandle, NDArrayHandle, ctypes.c_void_p)
        # Store the callback on self so it stays referenced after it has
        # been passed to the C library.
        self._updater_func = _updater_proto(_updater_wrapper(updater))
        check_call(_LIB.MXKVStoreSetUpdater(self.handle, self._updater_func, None))

    def _barrier(self):
        """Global barrier among all worker nodes.

        For example, assume there are n machines. We want to let machine 0 first
        init the values, and then pull the initialized value to all machines. Before
        pulling, we can place a barrier to guarantee that the initialization is
        finished.
        """
        check_call(_LIB.MXKVStoreBarrier(self.handle))

    def _send_command_to_servers(self, head, body):
        """Send a command to all server nodes.

        Send a command to all server nodes, which will make each server node run
        KVStoreServer.controller

        This function returns after the command has been executed in all server
        nodes.

        Parameters
        ----------
        head : int
            the head of the command
        body : str
            the body of the command
        """
        check_call(_LIB.MXKVStoreSendCommmandToServers(
            self.handle, mx_uint(head), c_str(body)))
| |
def create(name='local'):
    """Create a new KVStore.

    Parameters
    ----------
    name : str, default 'local'
        The type of KVStore.
        - local works for multiple devices on a single machine (single process).
        - dist works for multiple machines (multiple processes).

    Returns
    -------
    kv : KVStore
        The created KVStore.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    """
    if isinstance(name, string_types):
        # The C API fills in the freshly allocated handle by reference.
        kv_handle = KVStoreHandle()
        check_call(_LIB.MXKVStoreCreate(c_str(name), ctypes.byref(kv_handle)))
        return KVStore(kv_handle)
    raise TypeError('name must be a string')