import numpy as np
import mxnet as mx
import os
def reldiff(a, b):
    """Compute the relative L1 difference between two arrays."""
    diff = np.sum(np.abs(a - b))
    norm = np.sum(np.abs(a))
    reldiff = diff / norm
    return reldiff
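
# NOTE: reldiff normalizes by sum(|a|), so it assumes the reference array is
# not all zeros; the uniform(-1, 1) inputs generated below make that safe in
# practice.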
def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None):
    """Check executor consistency against a NumPy reference on uniform random inputs.

    uf: NumPy forward function, gf: NumPy gradient function,
    sf: optional symbolic function (defaults to applying uf to the symbols).
    """
    shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim))
    lhs = mx.symbol.Variable('lhs')
    rhs = mx.symbol.Variable('rhs')
    if sf is not None:
        ret = sf(lhs, rhs)
    else:
        ret = uf(lhs, rhs)
    assert ret.list_arguments() == ['lhs', 'rhs']
    lshape = shape if lshape is None else lshape
    rshape = shape if rshape is None else rshape

    lhs_arr = mx.nd.array(np.random.uniform(-1, 1, lshape))
    rhs_arr = mx.nd.array(np.random.uniform(-1, 1, rshape))
    lhs_grad = mx.nd.empty(lshape)
    rhs_grad = mx.nd.empty(rshape)
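
    # Bind the same symbol three ways (positional args with grads, positional
    # args without grads, and keyword dicts) to check that every binding style
    # computes the same forward result.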
    executor = ret.bind(mx.Context('cpu'),
                        args=[lhs_arr, rhs_arr],
                        args_grad=[lhs_grad, rhs_grad])
    exec3 = ret.bind(mx.Context('cpu'),
                     args=[lhs_arr, rhs_arr])
    exec4 = ret.bind(mx.Context('cpu'),
                     args={'rhs': rhs_arr, 'lhs': lhs_arr},
                     args_grad={'lhs': lhs_grad, 'rhs': rhs_grad})
    executor.forward()
    exec3.forward()
    exec4.forward()
    out1 = uf(lhs_arr.asnumpy(), rhs_arr.asnumpy())
    out2 = executor.outputs[0].asnumpy()
    out3 = exec3.outputs[0].asnumpy()
    out4 = exec4.outputs[0].asnumpy()
    assert reldiff(out1, out2) < 1e-6
    assert reldiff(out1, out3) < 1e-6
    assert reldiff(out1, out4) < 1e-6
    # test gradient
    out_grad = mx.nd.array(np.ones(out2.shape))
    lhs_grad2, rhs_grad2 = gf(out_grad.asnumpy(),
                              lhs_arr.asnumpy(),
                              rhs_arr.asnumpy())
    executor.backward([out_grad])
    assert reldiff(lhs_grad.asnumpy(), lhs_grad2) < 1e-6
    assert reldiff(rhs_grad.asnumpy(), rhs_grad2) < 1e-6
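
# The MXNET_EXEC_BULK_* variables below presumably gate MXNet's bulk (fused)
# graph execution; setting them to "0" forces operators to run individually,
# so the binding tests below cover both execution modes.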
def test_bind(disable_bulk_exec=False):
    if disable_bulk_exec:
        prev_fwd_var = os.environ.get("MXNET_EXEC_BULK_FWD_THRESHOLD_TRAIN", "1")
        prev_bwd_var = os.environ.get("MXNET_EXEC_BULK_BWD_TRAIN", "1")
        os.environ["MXNET_EXEC_BULK_FWD_THRESHOLD_TRAIN"] = "0"
        os.environ["MXNET_EXEC_BULK_BWD_TRAIN"] = "0"

    np.random.seed(0)
    nrepeat = 10
    maxdim = 4
    for repeat in range(nrepeat):
        for dim in range(1, maxdim):
            check_bind_with_uniform(lambda x, y: x + y,
                                    lambda g, x, y: (g, g),
                                    dim)
            check_bind_with_uniform(lambda x, y: x - y,
                                    lambda g, x, y: (g, -g),
                                    dim)
            check_bind_with_uniform(lambda x, y: x * y,
                                    lambda g, x, y: (y * g, x * g),
                                    dim)
            check_bind_with_uniform(lambda x, y: x / y,
                                    lambda g, x, y: (g / y, -x * g / (y**2)),
                                    dim)
            check_bind_with_uniform(lambda x, y: np.maximum(x, y),
                                    lambda g, x, y: (g * (x > y), g * (y > x)),
                                    dim,
                                    sf=mx.symbol.maximum)
            check_bind_with_uniform(lambda x, y: np.minimum(x, y),
                                    lambda g, x, y: (g * (x < y), g * (y < x)),
                                    dim,
                                    sf=mx.symbol.minimum)
    if disable_bulk_exec:
        # restore the previous environment values
        os.environ["MXNET_EXEC_BULK_FWD_THRESHOLD_TRAIN"] = prev_fwd_var
        os.environ["MXNET_EXEC_BULK_BWD_TRAIN"] = prev_bwd_var
def test_dot():
    np.random.seed(0)
    nrepeat = 10
    maxdim = 4
    for repeat in range(nrepeat):
        s = tuple(np.random.randint(1, 500, size=3))
        check_bind_with_uniform(lambda x, y: np.dot(x, y),
                                lambda g, x, y: (np.dot(g, y.T), np.dot(x.T, g)),
                                2,
                                lshape=(s[0], s[1]),
                                rshape=(s[1], s[2]),
                                sf=mx.symbol.dot)
    for repeat in range(nrepeat):
        s = tuple(np.random.randint(1, 500, size=1))
        check_bind_with_uniform(lambda x, y: np.dot(x, y),
                                lambda g, x, y: (g * y, g * x),
                                2,
                                lshape=(s[0],),
                                rshape=(s[0],),
                                sf=mx.symbol.dot)
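
# Executor.reshape creates a new executor that shares memory with the base
# executor, so writes through one are visible through the other; the shared
# memory check below relies on that.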
def test_reshape():
    x = mx.sym.Variable('x')
    y = mx.sym.FullyConnected(x, num_hidden=4)

    exe = y.simple_bind(mx.cpu(), x=(5, 4), grad_req='null')
    exe.arg_arrays[0][:] = 1
    exe.arg_arrays[1][:] = mx.nd.ones((4, 4))
    exe.arg_arrays[2][:] = 0

    new_exe = exe.reshape(x=(3, 4))
    new_exe.forward(is_train=False)
    # test sub exec forward
    assert np.all(new_exe.outputs[0].asnumpy() == 4)
    # test shared memory
    assert np.all(exe.outputs[0].asnumpy()[:3] == 4)
    # test base exec forward
    exe.forward(is_train=False)
    assert np.all(exe.outputs[0].asnumpy() == 4)
if __name__ == "__main__":
    test_bind(disable_bulk_exec=False)
    test_bind(disable_bulk_exec=True)
    test_dot()
    test_reshape()