blob: dd8149d4822ed5c7a39035699125be8f4fce5706 [file] [log] [blame]
# pylint: skip-file
import mxnet as mx
import numpy as np
# Shape used for every NDArray created by the tests in this file.
shape = (4, 4)
# Keys used by the list (multi-key) push/pull tests; key 3 is used
# separately by the single-key tests.
keys = [5, 7, 11]
def init_kv():
    """Return a default kvstore with key 3 and every key in ``keys`` set to zeros."""
    store = mx.kv.create()
    store.init(3, mx.nd.zeros(shape))  # single key
    # One zero array referenced once per key in ``keys``.
    initial = [mx.nd.zeros(shape)] * len(keys)
    store.init(keys, initial)  # list of keys
    return store
def check_diff_to_scalar(A, x):
    """Assert that every element of ``A`` equals the scalar ``x`` exactly.

    ``A - x`` must yield an object exposing ``asnumpy()`` (the callers in this
    file pass MXNet NDArrays). The check passes only when the summed absolute
    difference is exactly zero.
    """
    residual = np.abs((A - x).asnumpy()).sum()
    assert residual == 0
def test_single_kv_pair():
    """Push ones to key 3 and check that pulling it back yields ones."""
    store = init_kv()
    store.push(3, mx.nd.ones(shape))
    pulled = mx.nd.empty(shape)
    store.pull(3, out=pulled)
    check_diff_to_scalar(pulled, 1)
def test_init():
    """Check that a value passed to ``init`` is returned by a later ``pull``."""
    store = mx.kv.create()
    initial = mx.nd.ones(shape) * 4
    store.init(3, initial)
    result = mx.nd.zeros(shape)
    store.pull(3, out=result)
    check_diff_to_scalar(result, 4)
def test_list_kv_pair():
    """List key-value pair push & pull.

    Pushes an array of 4s to every key in ``keys``, pulls each key into its
    own output buffer and checks that every buffer holds 4.
    """
    kv = init_kv()
    kv.push(keys, [mx.nd.ones(shape) * 4 for _ in keys])
    # One distinct buffer per key. ``[buf] * len(keys)`` would alias a single
    # NDArray, so every pull would overwrite the same buffer and a wrong value
    # for any key but the last could go unnoticed.
    val = [mx.nd.empty(shape) for _ in keys]
    kv.pull(keys, out=val)
    for v in val:
        check_diff_to_scalar(v, 4)
def test_aggregator():
    """Aggregate values pushed from multiple devices.

    No updater is installed. The test expects that pulling a key after pushing
    one value per device returns the sum over devices: ``num_devs`` for ones
    and ``num_devs * 2.0`` for twos.
    """
    kv = init_kv()
    # devices
    num_devs = 4
    devs = [mx.Context('cpu', i) for i in range(num_devs)]

    # single key: one value per device
    vals = [mx.nd.ones(shape, d) for d in devs]
    kv.push(3, vals)
    kv.pull(3, out=vals)
    for v in vals:
        check_diff_to_scalar(v, num_devs)

    # list of keys: build an independent per-device list for every key.
    # ``[[...]] * len(keys)`` would alias one inner list across all keys, so
    # each key's pull would overwrite the others' outputs and only the last
    # key would effectively be checked.
    vals = [[mx.nd.ones(shape, d) * 2.0 for d in devs] for _ in keys]
    kv.push(keys, vals)
    kv.pull(keys, out=vals)
    for vv in vals:
        for v in vv:
            check_diff_to_scalar(v, num_devs * 2.0)
def updater(key, recv, local):
    """Custom kvstore updater: accumulate the received value into the stored one.

    Installed via ``kv._set_updater`` in ``test_updater``.

    Args:
        key: key being updated (not used here).
        recv: value received for ``key`` (presumably the device-aggregated
            push; confirm against the kvstore docs).
        local: value currently stored for ``key``. It is updated with ``+=``,
            which only has an effect if ``local`` supports in-place addition
            (as the NDArrays used by the tests are assumed to).
    """
    local += recv
def test_updater(dev='cpu'):
    """Push through a custom ``+=`` updater and check accumulation.

    With ``updater`` installed, the test expects every push to add the
    device-aggregated value to the stored one. After ``num_push`` pushes of
    ones from ``num_devs`` devices, each key should therefore hold
    ``num_devs * num_push``.

    Args:
        dev: device type passed to ``mx.Context`` (default ``'cpu'``).
    """
    kv = init_kv()
    kv._set_updater(updater)
    # devices
    num_devs = 4
    devs = [mx.Context(dev, i) for i in range(num_devs)]

    # single key: stored zeros + one push of num_devs ones
    vals = [mx.nd.ones(shape, d) for d in devs]
    kv.push(3, vals)
    kv.pull(3, out=vals)
    for v in vals:
        check_diff_to_scalar(v, num_devs)

    # list of keys: independent per-device arrays for every key.
    # ``[[...]] * len(keys)`` would alias one inner list across all keys, so
    # each key's pull would overwrite the same outputs.
    vals = [[mx.nd.ones(shape, d) for d in devs] for _ in keys]
    num_push = 4
    for _ in range(num_push):
        kv.push(keys, vals)
    kv.pull(keys, out=vals)
    for vv in vals:
        for v in vv:
            check_diff_to_scalar(v, num_devs * num_push)
def test_get_type():
    """Check that a kvstore reports the type string it was created with."""
    expected = 'local_allreduce_cpu'
    store = mx.kv.create(expected)
    assert store.type == expected
if __name__ == '__main__':
    # Run every test, in a fixed order, when executed as a script.
    for _test in (test_init, test_get_type, test_single_kv_pair,
                  test_list_kv_pair, test_aggregator, test_updater):
        _test()