blob: c954c1859d643eb3601e489e47844c01bd10da7a [file] [log] [blame]
#!/usr/bin/env python
import sys
sys.path.insert(0, "../../python/")
import mxnet as mx
import numpy as np
# Keys exercised by the tests; keys[k] is paired with shapes[k].
keys = [3, 5, 7]
# Square arrays of increasing size.  The last one (2000 x 2000 = 4e6 elements)
# is deliberately large so that it exceeds MXNET_KVSTORE_BIGARRAY_BOUND.
shapes = [(side, side) for side in (4, 100, 2000)]
lr = .1          # factor applied to each summed gradient in the reference
nworker = 4      # pushes/pulls use devices mx.gpu(0) .. mx.gpu(nworker - 1)
nrepeat = 10     # number of push/pull rounds per test
## generate data
# data[r][k][w] is the array worker w pushes for key k (shape shapes[k]) in
# round r, drawn as float64 samples from [-1, 1).  The 2000x2000 entries
# dominate memory: 32 MB each, ~1.3 GB for 10 rounds x 4 workers.
data = [
    [
        [np.random.random(shape) * 2 - 1 for _worker in range(nworker)]
        for shape in shapes
    ]
    for _round in range(nrepeat)
]
## individual key interface
def test_kvstore(kv_type):
    """Check per-key push/pull on a kvstore created with type ``kv_type``.

    Every key is initialised to zeros.  In each of the ``nrepeat`` rounds,
    each key gets its own ``push`` carrying one array per GPU worker, and is
    then pulled back onto every worker.  The pulled values are compared with
    a NumPy reference that adds ``lr * sum(worker arrays)`` per round.

    Uses the module-level ``keys``, ``shapes``, ``data``, ``lr`` and
    ``nworker``.

    Raises:
        AssertionError: if, for any key, the sum over workers of the L1
            error relative to the reference's L1 norm is not below 1e-6.
    """
    print(kv_type)
    kv = mx.kv.create(kv_type)
    # NOTE(review): ``lr`` is passed positionally to mx.optimizer.create; the
    # reference below assumes each update adds lr * (summed pushes) to the
    # weight -- confirm against the 'test' optimizer definition.
    kv.set_optimizer(mx.optimizer.create('test', lr))
    for key, shape in zip(keys, shapes):
        kv.init(key, mx.nd.zeros(shape))

    # Host-side expected value of every key (float64).
    expected = [np.zeros(shape) for shape in shapes]
    for round_data in data:
        # One push per key; worker g's array is placed on mx.gpu(g).
        for key, worker_grads in zip(keys, round_data):
            kv.push(key, [mx.nd.array(grad, mx.gpu(g))
                          for g, grad in enumerate(worker_grads)])
        expected = [acc + sum(worker_grads) * lr
                    for acc, worker_grads in zip(expected, round_data)]

        for key, shape, ref in zip(keys, shapes, expected):
            out = [mx.nd.zeros(shape, mx.gpu(g)) for g in range(nworker)]
            kv.pull(key, out=out)
            # Per-worker L1 errors are summed, then normalised by the
            # reference's L1 norm, so the bound covers all workers together.
            err = sum(np.sum(np.abs(o.asnumpy() - ref)) for o in out)
            err = err / np.sum(np.abs(ref))
            assert err < 1e-6, (err, shape)
# Run the per-key test against each local kvstore variant, in this order.
for _kv_type in ('local_update_cpu', 'local_allreduce_cpu',
                 'local_allreduce_device'):
    test_kvstore(_kv_type)
## group keys interface
def test_group_kvstore(kv_type):
    """Check list-of-keys push/pull on a kvstore created with type ``kv_type``.

    All keys are initialised with a single ``init`` call.  In each of the
    ``nrepeat`` rounds, one ``push`` carries every key's per-worker arrays
    and one ``pull`` fills preallocated outputs on every GPU worker.  The
    results are compared with a NumPy reference that adds
    ``lr * sum(worker arrays)`` per round.

    Uses the module-level ``keys``, ``shapes``, ``data``, ``lr`` and
    ``nworker``.

    Raises:
        AssertionError: if, for any key, the sum over workers of the L1
            error relative to the reference's L1 norm is not below 1e-6.
    """
    print(kv_type)
    kv = mx.kv.create(kv_type)
    # NOTE(review): ``lr`` is passed positionally to mx.optimizer.create; the
    # reference below assumes each update adds lr * (summed pushes) to the
    # weight -- confirm against the 'test' optimizer definition.
    kv.set_optimizer(mx.optimizer.create('test', lr))
    kv.init(keys, [mx.nd.zeros(shape) for shape in shapes])

    # Host-side expected value of every key (float64).
    expected = [np.zeros(shape) for shape in shapes]
    # out[k][g]: pull target for key k on mx.gpu(g); allocated once and
    # overwritten by every pull.
    out = [[mx.nd.zeros(shape, mx.gpu(g)) for g in range(nworker)]
           for shape in shapes]
    for round_data in data:
        # A single push for all keys: one list of per-worker arrays per key.
        kv.push(keys, [[mx.nd.array(grad, mx.gpu(g))
                        for g, grad in enumerate(worker_grads)]
                       for worker_grads in round_data])
        kv.pull(keys, out=out)
        expected = [acc + sum(worker_grads) * lr
                    for acc, worker_grads in zip(expected, round_data)]

        for ref, worker_outs in zip(expected, out):
            # Per-worker L1 errors are summed, then normalised by the
            # reference's L1 norm, so the bound covers all workers together.
            err = sum(np.sum(np.abs(o.asnumpy() - ref)) for o in worker_outs)
            err = err / np.sum(np.abs(ref))
            assert err < 1e-6, (err, ref.shape)
# Run the group-keys test against each local kvstore variant, in this order.
for _kv_type in ('local_update_cpu', 'local_allreduce_cpu',
                 'local_allreduce_device'):
    test_group_kvstore(_kv_type)