blob: ebed6c57586d0e61ac442e95c92a36044239d916 [file] [log] [blame]
#!/usr/bin/env python
# pylint: skip-file
import sys
sys.path.insert(0, "../../python/")
import mxnet as mx
import numpy as np
import time
def check_diff_to_scalar(A, x):
    """Assert that every element of ``A`` equals the scalar ``x``.

    ``A`` must support ``A - x`` and ``.asnumpy()`` (an MXNet NDArray in this
    script). Absolute differences are summed, so a mismatch in either
    direction makes the total non-zero; positive and negative errors cannot
    cancel out.

    Raises:
        AssertionError: if any element differs from ``x``. The assertion
            message is ``A`` converted to a NumPy array.
    """
    abs_diff = np.abs((A - x).asnumpy())
    assert abs_diff.sum() == 0, A.asnumpy()
# setup
# Small keys 3, 5 and 7 hold arrays of `shape`; key 99 (below) holds a large
# array of `big_shape`.
# NOTE(review): keys 5 and 7 are initialized but never pushed or pulled by
# test_sync_push_pull; confirm whether they are meant to be checked.
keys = [3, 5, 7]
# Passed to the 'test' optimizer below; the expected value computed in
# test_sync_push_pull scales linearly with it.
rate = 2
shape = (2, 2)
# Sized to exceed the kvstore's BIGARRAY_BOUND (presumably the
# MXNET_KVSTORE_BIGARRAY_BOUND setting; confirm), so key 99 exercises the
# large-array path.
big_shape = (1200, 1200) # bigger than BIGARRAY_BOUND
kv = mx.kv.create('dist_sync')
# init kv
# Every key starts as an array of ones. The "+ 1" in test_sync_push_pull's
# expected value accounts for this initial value. The list repeats one
# NDArray object per key; presumably kv.init copies values, so the sharing is
# harmless (confirm).
kv.init(keys, [mx.nd.ones(shape)] * len(keys))
kv.init(99, mx.nd.ones(big_shape))
# init updater on servers
# NOTE(review): test_sync_push_pull assumes the updater created by
# mx.optimizer.create('test', rate) adds rate * (aggregated pushed value) to
# the stored array on each push round; confirm against mx.optimizer.
kv.set_optimizer(mx.optimizer.create('test', rate))
# This worker's rank and the total worker count, as reported by the kvstore;
# both feed the pushed values and the expected sum in test_sync_push_pull.
my_rank = kv.rank
nworker = kv.num_workers
def test_sync_push_pull():
    """Push rank-dependent values and check the aggregated pulls on every worker.

    Each worker pushes arrays filled with ``my_rank + 1`` to the small key 3
    (``shape``) and the big key 99 (``big_shape``), ``nrepeat`` times.

    Assuming worker ranks are ``0 .. nworker - 1``, one round contributes
    ``1 + 2 + ... + nworker = nworker * (nworker + 1) / 2`` per element. This
    relies on the server-side 'test' updater adding ``rate`` times the
    aggregated push to the stored value (confirm against mx.optimizer).
    Together with the initial value of 1 from ``kv.init``, both keys must end
    up uniformly equal to::

        nworker * (nworker + 1) / 2 * rate * nrepeat + 1

    Relies on the module-level ``kv``, ``shape``, ``big_shape``, ``rate``,
    ``my_rank`` and ``nworker``.

    Raises:
        AssertionError: if either pulled array differs from the expected value.
    """
    nrepeat = 3
    for _ in range(nrepeat):
        kv.push(3, mx.nd.ones(shape) * (my_rank + 1))
        kv.push(99, mx.nd.ones(big_shape) * (my_rank + 1))

    # nworker * (nworker + 1) is always even, so halving it with integer
    # division is exact. Doing that before multiplying by `rate` keeps the
    # result exact for any rate, instead of going through float true division.
    num = (nworker + 1) * nworker // 2 * rate * nrepeat + 1

    val = mx.nd.zeros(shape)
    kv.pull(3, out=val)
    check_diff_to_scalar(val, num)

    val2 = mx.nd.zeros(big_shape)
    kv.pull(99, out=val2)
    check_diff_to_scalar(val2, num)
def _main():
    """Script entry point: run the synchronous push/pull test once."""
    test_sync_push_pull()


if __name__ == "__main__":
    _main()