| # pylint: skip-file | 
 | import mxnet as mx | 
 | import numpy as np | 
 |  | 
 | shape = (4, 4) | 
 | keys = [5, 7, 11] | 
 | def init_kv(): | 
 |     """init kv """ | 
 |     kv = mx.kv.create() | 
 |     # single | 
 |     kv.init(3, mx.nd.zeros(shape)) | 
 |     # list | 
 |     kv.init(keys, [mx.nd.zeros(shape)] * len(keys)) | 
 |     return kv | 
 |  | 
 |  | 
 | def check_diff_to_scalar(A, x): | 
 |     """ assert A == x""" | 
 |     assert(np.sum(np.abs((A - x).asnumpy())) == 0) | 
 |  | 
 | def test_single_kv_pair(): | 
 |     """single key-value pair push & pull""" | 
 |  | 
 |     kv = init_kv() | 
 |     kv.push(3, mx.nd.ones(shape)) | 
 |     val = mx.nd.empty(shape) | 
 |     kv.pull(3, out = val) | 
 |     check_diff_to_scalar(val, 1) | 
 |  | 
 | def test_init(): | 
 |     """test init""" | 
 |     kv = mx.kv.create() | 
 |     kv.init(3, mx.nd.ones(shape)*4) | 
 |     a = mx.nd.zeros(shape) | 
 |     kv.pull(3, out=a) | 
 |     check_diff_to_scalar(a, 4) | 
 |  | 
 | def test_list_kv_pair(): | 
 |     """list key-value pair push & pull""" | 
 |  | 
 |     kv = init_kv() | 
 |  | 
 |     kv.push(keys, [mx.nd.ones(shape)*4] * len(keys)) | 
 |     val = [mx.nd.empty(shape)] * len(keys) | 
 |     kv.pull(keys, out = val) | 
 |     for v in val: | 
 |         check_diff_to_scalar(v, 4) | 
 |  | 
 |  | 
 | def test_aggregator(): | 
 |     """aggregate value on muliple devices""" | 
 |  | 
 |     kv = init_kv() | 
 |  | 
 |     # devices | 
 |     num_devs = 4 | 
 |     devs = [mx.Context('cpu', i) for i in range(num_devs)] | 
 |  | 
 |     # single | 
 |     vals = [mx.nd.ones(shape, d) for d in devs] | 
 |  | 
 |     kv.push(3, vals) | 
 |     kv.pull(3, out = vals) | 
 |  | 
 |     for v in vals: | 
 |         check_diff_to_scalar(v, num_devs) | 
 |  | 
 |     # list | 
 |     vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(keys) | 
 |     kv.push(keys, vals) | 
 |     kv.pull(keys, out = vals) | 
 |  | 
 |     for vv in vals: | 
 |         for v in vv: | 
 |             check_diff_to_scalar(v, num_devs * 2.0) | 
 |  | 
 |  | 
 | def updater(key, recv, local): | 
 |     """use updater: +=""" | 
 |     local += recv | 
 |  | 
 | def test_updater(dev = 'cpu'): | 
 |     """updater""" | 
 |  | 
 |     kv = init_kv() | 
 |     kv._set_updater(updater) | 
 |  | 
 |     # devices | 
 |     num_devs = 4 | 
 |     devs = [mx.Context(dev, i) for i in range(num_devs)] | 
 |  | 
 |     # single | 
 |     vals = [mx.nd.ones(shape, d) for d in devs] | 
 |  | 
 |     kv.push(3, vals) | 
 |     kv.pull(3, out = vals) | 
 |  | 
 |     for v in vals: | 
 |         check_diff_to_scalar(v, num_devs) | 
 |  | 
 |     # list | 
 |     vals = [[mx.nd.ones(shape, d) for d in devs]] * len(keys) | 
 |  | 
 |     num_push = 4 | 
 |     for i in range(num_push): | 
 |         kv.push(keys, vals) | 
 |  | 
 |     kv.pull(keys, out = vals) | 
 |  | 
 |     for vv in vals: | 
 |         for v in vv: | 
 |             check_diff_to_scalar(v, num_devs * num_push) | 
 |  | 
 | def test_get_type(): | 
 |     kvtype = 'local_allreduce_cpu' | 
 |     kv = mx.kv.create(kvtype) | 
 |     assert kv.type == kvtype | 
 |  | 
 | if __name__ == '__main__': | 
 |     test_init() | 
 |     test_get_type() | 
 |     test_single_kv_pair() | 
 |     test_list_kv_pair() | 
 |     test_aggregator() | 
 |     test_updater() |