# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import mxnet as mx
from numpy.testing import assert_allclose
from mxnet.test_utils import *
from mxnet.base import _as_list
mx.npx.reset_np()
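# These tests exercise the legacy mx.nd / mx.sym control-flow operators
# (while_loop, cond, foreach): each case builds the same computation both
# imperatively under autograd and as a bound symbol executor, then checks
# that outputs and gradients agree. Roughly, the while_loop contract being
# tested looks like this (a sketch, not part of the tests):
#
#   outputs, final_states = mx.nd.contrib.while_loop(
#       cond=lambda *states: ...,  # scalar predicate over the loop states
#       func=lambda *states: ...,  # returns (step_outputs, new_states)
#       loop_vars=init_states,
#       max_iterations=N,          # hard bound; outputs are padded to N rows
#   )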
def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for, n_steps):
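"""Check mx.nd.contrib.while_loop against mx.sym.contrib.while_loop.

Runs the same (cond, func) loop imperatively under autograd and through a
bound symbol executor, then asserts that outputs and gradients agree.
`is_for` marks loops whose first loop variable is an iteration counter;
`n_steps` is the number of steps the loop is expected to execute.
"""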
def _create_vars(num, prefix):
return [mx.sym.var(prefix + str(i)) for i in range(num)]
def _create_arrays(shapes):
return [mx.nd.random.uniform(-1.0, 1.0, shape=x) for x in shapes]
def _create_dict(prefix, shapes):
return {prefix + str(i): mx.nd.random.uniform(-1.0, 1.0, shape=x) for i, x in enumerate(shapes)}
def _merge_dict(*dicts):
result = {}
for item in dicts:
result.update(item)
return result
def _to_numpy_list(arrays):
return [x.asnumpy() if x is not None else x for x in arrays]
def _get_imperative_result(n_steps):
free_vars = [args["FreeVar" + str(i)].copy() for i, _ in enumerate(free_var_shapes)]
loop_vars = [args["LoopVar" + str(i)].copy() for i, _ in enumerate(loop_var_shapes)]
loop_var_start = int(is_for)
if is_train:
for var in free_vars + loop_vars[loop_var_start: ]:
var.attach_grad()
with mx.autograd.record(train_mode=is_train):
outputs, final_loop_vars = mx.nd.contrib.while_loop(
cond=lambda *_loop_vars: cond(_loop_vars, free_vars),
func=lambda *_loop_vars: func(_loop_vars, free_vars),
loop_vars=loop_vars,
max_iterations=max_iterations,
)
outputs = _as_list(outputs)
final_loop_vars = _as_list(final_loop_vars)
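# The imperative while_loop pads its stacked outputs along axis 0 up to
# max_iterations, so only the first n_steps rows are meaningful.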
outputs = [x[: n_steps] for x in outputs]
out_grads = _create_arrays(x.shape for x in outputs) \
+ _create_arrays(x.shape for x in final_loop_vars)
loop_result_nd = [x * 2 for x in outputs] + [x * 3 for x in final_loop_vars]
grads = []
if is_train:
cat_out = mx.nd.concat(*[x.reshape(-1) for x in loop_result_nd], dim=0)
cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0))
grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \
+ [loop_vars[i].grad for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start]
return _to_numpy_list(loop_result_nd), _to_numpy_list(grads), out_grads
def _get_symbolic_result(out_grads, n_steps):
def _copy_args_dict(name_list):
return {name: args[name].copy() for name in name_list}
def _zeros_like_dict(name_list):
return {name: mx.nd.zeros_like(args[name]) for name in name_list}
free_syms = _create_vars(len(free_var_shapes), "FreeVar")
loop_syms = _create_vars(len(loop_var_shapes), "LoopVar")
outputs, final_loop_syms = mx.sym.contrib.while_loop(
cond=lambda *_loop_vars: cond(_loop_vars, free_syms),
func=lambda *_loop_vars: func(_loop_vars, free_syms),
loop_vars=loop_syms,
max_iterations=max_iterations,
)
outputs = _as_list(outputs)
final_loop_syms = _as_list(final_loop_syms)
if n_steps == 0:
outputs = []
else:
outputs = [x.slice_axis(axis=0, begin=0, end=n_steps) for x in outputs]
loop_result_sym = [x * 2 for x in outputs] + [x * 3 for x in final_loop_syms]
loop_result_sym = mx.sym.Group(loop_result_sym)
loop_var_start = int(is_for)
args_names = ["FreeVar" + str(i) for i, _ in enumerate(free_var_shapes)] \
+ ["LoopVar" + str(i) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start]
args_grad = None if not is_train else _zeros_like_dict(args_names)
executor = loop_result_sym._bind(
ctx=default_device(),
args=_copy_args_dict(loop_result_sym.list_inputs()),
args_grad=args_grad,
)
loop_result_nd = executor.forward(is_train=is_train)
grads = []
if is_train:
executor.backward(out_grads=out_grads)
grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \
+ [executor.grad_dict.get("LoopVar" + str(i), None) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start]
return _to_numpy_list(loop_result_nd), _to_numpy_list(grads)
args = _merge_dict(
_create_dict("FreeVar", free_var_shapes),
_create_dict("LoopVar", loop_var_shapes),
)
if is_for:
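# For for-style loops the first loop variable is the iteration counter:
# it must have shape (1,) and start at 0, overriding the random init above.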
assert loop_var_shapes[0] == (1, )
args["LoopVar0"] = mx.nd.array([0])
imp_outs, imp_grads, out_grads = _get_imperative_result(n_steps)
sym_outs, sym_grads = _get_symbolic_result(out_grads, n_steps)
for imp_out, sym_out in zip(imp_outs, sym_outs):
if imp_out is None or sym_out is None:
continue
assert_almost_equal(imp_out, sym_out, rtol=1e-3, atol=1e-3)
for imp_grad, sym_grad in zip(imp_grads, sym_grads):
if imp_grad is None or sym_grad is None:
continue
assert_almost_equal(imp_grad, sym_grad, rtol=1e-3, atol=1e-3)
@pytest.mark.skip(reason="Bug in while loop op, tracked at incubator-mxnet/issues/18575")
def test_while_loop_for_foreach():
def make_true_cond():
return lambda loop_vars, _: (loop_vars[0] < 1e35).prod()
def make_false_cond():
return lambda loop_vars, _: (loop_vars[0] > 1e35).prod()
def make_for_cond(length):
return lambda loop_vars, _: loop_vars[0] < length
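# The true/false conds compare against a 1e35 sentinel that uniform(-1, 1)
# data can never cross, and .prod() collapses the element-wise result into a
# single scalar predicate; make_for_cond bounds the counter, turning the
# while_loop into a plain for-loop over `length` steps.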
def case_0():
# This is a simple test case in which all loop steps are independent
# It basically scans the array and outputs itself
# There is 1 output
# There is 1 state: i
def _simple_func(loop, free):
(i, ), (scanned, ) = loop, free
in_ = scanned.take(i).squeeze(axis=0)
return (in_, i + 1)
_verify_while_loop(
cond=make_true_cond(),
func=_simple_func,
max_iterations=1,
is_train=True,
is_for=True,
loop_var_shapes=[
(1, ), # i
],
free_var_shapes=[
(1, 3), # scanned
],
n_steps=1,
)
def case_1(**params):
# This is a simple test case that simulates a cumulative sum
# There is 1 output
# There is 1 state: s
step_funcs = [
lambda a, b, s: s,
lambda a, b, s: a * 1.5 + b * 2.5 - s * 3.5,
lambda a, b, s: a * 1.5 - s * 3.5 + b * 2.5,
lambda a, b, s: b * 2.5 + a * 1.5 - s * 3.5,
lambda a, b, s: b * 2.5 - s * 3.5 + a * 1.5,
lambda a, b, s: s * -3.5 + a * 1.5 + b * 2.5,
lambda a, b, s: s * -3.5 + b * 2.5 + a * 1.5,
lambda a, b, s: a * 2.5 * b + s * 0.3,
lambda a, b, s: b * 2.5 * a + s * 0.3,
lambda a, b, s: 2.5 * a * b + s * 0.3,
lambda a, b, s: b * a * 2.5 + s * 0.3,
lambda a, b, s: 2.5 * b * a + s * 0.3,
lambda a, b, s: b * a * 2.5 + s * 0.3,
lambda a, b, s: s * 0.3 + a * 2.5 * b,
lambda a, b, s: s * 0.3 + b * 2.5 * a,
lambda a, b, s: s * 0.3 + 2.5 * a * b,
lambda a, b, s: s * 0.3 + b * a * 2.5,
lambda a, b, s: s * 0.3 + 2.5 * b * a,
lambda a, b, s: s * 0.3 + b * a * 2.5,
]
def make_func(step_func):
def step(loop, free):
(s, ), (a, b) = loop, free
out = step_func(a, b, s)
return (out, out)
return step
case_id = 0
for is_train in [True, False]:
for step_func in step_funcs:
case_id += 1
_verify_while_loop(
func=make_func(step_func),
is_train=is_train,
is_for=False,
**params
)
def case_2(**params):
# This is a test case that involves non-differentiable operators
# There is 1 output
# There are 2 states: i, s
step_funcs = [
lambda in_, s, f_1: (in_ * 2) * s * f_1,
lambda in_, s, f_1: (in_ * 2) * f_1 * s,
lambda in_, s, f_1: s * (in_ * 2) * f_1,
lambda in_, s, f_1: s * f_1 * (in_ * 2),
lambda in_, s, f_1: f_1 * (in_ * 2) * s,
lambda in_, s, f_1: f_1 * s * (in_ * 2),
lambda in_, s, f_1: (2 * in_) * s * f_1,
lambda in_, s, f_1: (2 * in_) * f_1 * s,
lambda in_, s, f_1: s * (2 * in_) * f_1,
lambda in_, s, f_1: s * f_1 * (2 * in_),
lambda in_, s, f_1: f_1 * (2 * in_) * s,
lambda in_, s, f_1: f_1 * s * (2 * in_),
]
def make_func(step_func):
"""This simulates:
def compute(s, inputs, f_1, length):
outputs = []
for i in range(length):
s = step_func(inputs[i], s, f_1)
outputs.append(s)
return outputs, s
"""
def step(loop, free):
(i, s), (scanned, f_1, _) = loop, free
in_ = scanned.take(i).squeeze(axis=0)
out = step_func(in_, s, f_1)
return (out, (i + 1, out))
return step
case_id = 0
for is_train in [True, False]:
for step_func in step_funcs:
case_id += 1
_verify_while_loop(
func=make_func(step_func),
max_iterations=1000,
is_train=is_train,
is_for=True,
**params
)
def case_3(length, **params):
# This is a test case for multiple non-differentiable operators and different ways of slicing
# There are 2 outputs
# There are 3 states: i, s_0, s_1
step_funcs = [
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2),
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2),
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0,
lambda i_0, i_1, s_0, s_1, f_0: i_1,
lambda i_0, i_1, s_0, s_1, f_0: s_0,
lambda i_0, i_1, s_0, s_1, f_0: s_1,
lambda i_0, i_1, s_0, s_1, f_0: f_0,
]
def make_func(step_func):
"""This simulates:
def compute(input_0, input_1, s_0, s_1, f_0, length):
output_0 = []
output_1 = []
for i in range(length):
i_0 = input_0[i]
i_1 = input_1[length - 1 - i]
out = step_func(i_0, i_1, s_0, s_1, f_0)
s_0 = (s_0 + out) * 1.05
s_1 = (s_1 - out * 0.5) * 0.95
output_0.append(out)
output_1.append(out * 1.5)
return output_0, output_1, s_0, s_1
"""
def step(loop, free):
(i, s_0, s_1), (sc_0, sc_1, f_0, _) = loop, free
i_0 = sc_0.take(i).squeeze(axis=0)
i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
out = step_func(i_0, i_1, s_0, s_1, f_0)
return ([out, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95])
return step
case_id = 0
for is_train in [True, False]:
for step_func in step_funcs:
case_id += 1
_verify_while_loop(
func=make_func(step_func),
max_iterations=1000,
is_train=is_train,
is_for=True,
**params
)
def case_4(length, single_shape, **params):
# This covers the case where some outputs are identical to the inputs
# There are 3 outputs
# There are 4 states: i, s_0, s_1, s_2
# i is used in both non-differentiable (take) and differentiable (+) contexts
step_funcs = [
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2),
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2),
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0,
lambda i_0, i_1, s_0, s_1, f_0: i_1,
lambda i_0, i_1, s_0, s_1, f_0: s_0,
lambda i_0, i_1, s_0, s_1, f_0: s_1,
lambda i_0, i_1, s_0, s_1, f_0: f_0,
]
def make_func(step_func):
"""This simulates:
def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
# here s_2 remains untouched
output_0 = []
output_1 = []
output_2 = []
for i in range(length):
i_0 = input_0[i]
i_1 = input_1[length - 1 - i]
out = step_func(i_0, i_1, s_0, s_1, f_0)
out = out * i * i_0 * i_1
s_0 = (s_0 + out) * 1.05
s_1 = (s_1 - out * 0.5) * 0.95
output_0.append(out)
output_1.append(f_0)
output_2.append(out * 1.5)
return output_0, output_1, output_2, s_0, s_1, s_2
"""
def step(loop, free):
(i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free
i_0 = sc_0.take(i).squeeze(axis=0)
i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
out = step_func(i_0, i_1, s_0, s_1, f_0)
out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape)
out = out * i_0 * i_1
return ([out, f_0, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2])
return step
case_id = 0
for is_train in [True, False]:
for step_func in step_funcs:
case_id += 1
_verify_while_loop(
func=make_func(step_func),
max_iterations=1000,
is_train=is_train,
is_for=True,
**params
)
def case_5(length, single_shape, **params):
# This covers the case where the loop produces no scan outputs
# There are 0 outputs
# There are 4 states: i, s_0, s_1, s_2
# i is used in both non-differentiable (take) and differentiable (+) contexts
step_funcs = [
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2),
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2),
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0,
lambda i_0, i_1, s_0, s_1, f_0: i_1,
lambda i_0, i_1, s_0, s_1, f_0: s_0,
lambda i_0, i_1, s_0, s_1, f_0: s_1,
lambda i_0, i_1, s_0, s_1, f_0: f_0,
]
def make_func(step_func):
"""This simulates:
def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
# here s_2 remains untouched
output_0 = []
output_1 = []
output_2 = []
for i in range(length):
i_0 = input_0[i]
i_1 = input_1[length - 1 - i]
out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
out = out * i * i_0 * i_1
s_0 = (s_0 + out) * 1.05
s_1 = (s_1 - out * 0.5) * 0.95
output_0.append(out)
output_1.append(f_0)
output_2.append(out * 1.5)
return output_0, output_1, output_2, s_0, s_1, s_2
"""
def step(loop, free):
(i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free
i_0 = sc_0.take(i).squeeze(axis=0)
i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
out = step_func(i_0, i_1, s_0, s_1, f_0)
out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape)
out = out * i_0 * i_1
return ([], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2])
return step
case_id = 0
for is_train in [True, False]:
for step_func in step_funcs:
case_id += 1
_verify_while_loop(
func=make_func(step_func),
max_iterations=1000,
is_train=is_train,
is_for=True,
**params
)
def case_6(length, single_shape, **params):
# This covers the case where some outputs are identical to the inputs
# There are 3 outputs
# There are 4 states: i, s_0, s_1, s_2
# i is used in both non-differentiable (take) and differentiable (+) contexts
step_funcs = [
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2),
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2),
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0,
lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0,
lambda i_0, i_1, s_0, s_1, f_0: i_0,
lambda i_0, i_1, s_0, s_1, f_0: i_1,
lambda i_0, i_1, s_0, s_1, f_0: s_0,
lambda i_0, i_1, s_0, s_1, f_0: s_1,
lambda i_0, i_1, s_0, s_1, f_0: f_0,
]
def make_func(step_func):
"""This simulates:
def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
# here s_2 remains untouched
output_0 = []
output_1 = []
output_2 = []
for i in range(length):
i_0 = input_0[i]
i_1 = input_1[length - 1 - i]
out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
out = out * i * i_0 * i_1
s_0 = (s_0 + out) * 1.05
s_1 = (s_1 - out * 0.5) * 0.95
output_0.append(out)
output_1.append(f_0)
output_2.append(out * 1.5)
return output_0, output_1, output_2, s_0, s_1, s_2
"""
def step(loop, free):
(i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free
F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd
i_0 = sc_0.take(i).squeeze(axis=0)
i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
out_0 = step_func(i_0, i_1, s_0, s_1, f_0)
out_0 = out_0 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape)
out_1 = step_func(i_1, s_0, f_0, s_1, i_0)
out_1 = out_1 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape)
return ([F.dot(out_0, s_2), f_0, F.dot(s_2, out_1) * 1.5], [i + 1, (s_0 + out_1) * 1.05, (s_1 - out_0 * 0.5) * 0.95, s_2])
return step
case_id = 0
for is_train in [True, False]:
for step_func in step_funcs:
case_id += 1
_verify_while_loop(
func=make_func(step_func),
max_iterations=1000,
is_train=is_train,
is_for=True,
**params
)
# Case 0: the simplest case
case_0()
# Case 1.1.*
case_1(
cond=make_true_cond(),
loop_var_shapes=[
(1, ), # s
],
free_var_shapes=[
(1, ), # a
(1, ), # b
],
max_iterations=5,
n_steps=5,
)
# Case 1.2.*
case_1(
cond=make_true_cond(),
loop_var_shapes=[
(2, 3, 4), # s
],
free_var_shapes=[
(2, 3, 4), # a
(2, 3, 4), # b
],
max_iterations=3,
n_steps=3,
)
# Case 1.3.*
case_1(
cond=make_false_cond(),
loop_var_shapes=[
(2, 3, 4), # s
],
free_var_shapes=[
(2, 3, 4), # a
(2, 3, 4), # b
],
max_iterations=20,
n_steps=0,
)
# Case 2.1.*
case_2(
cond=make_for_cond(length=5),
loop_var_shapes=[
(1, ), # i
(2, ), # s
],
free_var_shapes=[
(100, 2), # scanned
(2, ), # f_1
(3, 4, 5, 6), # f_2, unused
],
n_steps=5,
)
# Case 2.2.*
case_2(
cond=make_for_cond(length=3),
loop_var_shapes=[
(1, ), # i
(2, ), # s
],
free_var_shapes=[
(30, 2), # scanned
(2, ), # f_1
(3, 4, 5, 6), # f_2, unused
],
n_steps=3,
)
# Case 3.*
case_3(
length=5,
cond=make_for_cond(length=5),
loop_var_shapes=[
(1, ), # i
(2, ), # s_0
(2, ), # s_1
],
free_var_shapes=[
(30, 2), # sc_0
(30, 2), # sc_1
(2, ), # f_0
(3, 4, 5, 6), # f_1, unused
],
n_steps=5,
)
# Case 4.1.*
case_4(
length=4,
cond=make_for_cond(length=4),
single_shape=[5],
loop_var_shapes=[
(1, ), # i
(5, ), # s_0
(5, ), # s_1
(23, 6, 8), # s_2
],
free_var_shapes=[
(30, 5), # sc_0
(30, 5), # sc_1
(5, ), # f_0
(3, 4, 5, 6), # f_1, unused
],
n_steps=4,
)
# Case 4.2.*
case_4(
length=5,
cond=make_for_cond(length=5),
single_shape=[5, 12],
loop_var_shapes=[
(1, ), # i
(5, 12), # s_0
(5, 12), # s_1
(23, 6, 8), # s_2
],
free_var_shapes=[
(30, 5, 12), # sc_0
(30, 5, 12), # sc_1
(5, 12), # f_0
(3, 4, 5, 6), # f_1, unused
],
n_steps=5,
)
# Case 5.1.*
case_5(
length=4,
cond=make_for_cond(length=4),
single_shape=[5],
loop_var_shapes=[
(1, ), # i
(5, ), # s_0
(5, ), # s_1
(23, 6, 8), # s_2
],
free_var_shapes=[
(30, 5), # sc_0
(30, 5), # sc_1
(5, ), # f_0
(3, 4, 5, 6), # f_1, unused
],
n_steps=4,
)
# Case 5.2.*
case_5(
length=5,
cond=make_for_cond(length=5),
single_shape=[3, 4, 2],
loop_var_shapes=[
(1, ), # i
(3, 4, 2), # s_0
(3, 4, 2), # s_1
(23, 6, 8), # s_2
],
free_var_shapes=[
(30, 3, 4, 2), # sc_0
(30, 3, 4, 2), # sc_1
(3, 4, 2), # f_0
(3, 4, 5, 6), # f_1, unused
],
n_steps=5,
)
# Case 6.*
case_6(
length=5,
cond=make_for_cond(length=5),
single_shape=[5, 3],
loop_var_shapes=[
(1, ), # i
(5, 3), # s_0
(5, 3), # s_1
(3, 5), # s_2
],
free_var_shapes=[
(30, 5, 3), # sc_0
(30, 5, 3), # sc_1
(5, 3), # f_0
(3, 4, 5, 6), # f_1, unused
],
n_steps=5,
)
def test_while_loop_nested():
def _to_np_list(arrays):
return [x.asnumpy() if x is not None else x for x in arrays]
def _array(shape):
return mx.nd.random.uniform(-1.0, 1.0, shape=shape)
def inner_cond(i, j, x_sum, sc):
return j < 2
def inner_body(i, j, x_sum, sc):
x_ij = sc.take(j).squeeze(axis=0)
return (x_ij, x_ij), (i, j + 1, x_sum, sc)
def outer_cond(i, j, x_sum, sc):
return i < 2
def outer_body(i, j, x_sum, sc):
F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd
(x_ij, x_ji), (i_p, j_p, x_sum_p, sc_p) = F.contrib.while_loop(
cond=inner_cond,
func=inner_body,
loop_vars=(i, j, x_sum, sc),
max_iterations=2,
)
return (x_ij, x_ji), (i_p + 1, j_p - 2, x_sum_p, sc_p)
def make_loop(i, j, x_sum, sc):
F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd
(x_ij, x_ji), (new_i, new_j, x_sum_p, sc_p) = F.contrib.while_loop(
cond=outer_cond,
func=outer_body,
loop_vars=(i, j, x_sum, sc),
max_iterations=2,
)
return new_i, new_j, x_sum_p, sc_p, x_ij, x_ji
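# Each while_loop level stacks its step outputs along a new leading axis of
# size max_iterations (2 here), so the scanned outputs x_ij / x_ji end up
# with shape (2, 2, 10, 5, 3): outer steps x inner steps x one slice of sc.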
args = {
"i": mx.nd.array([0]),
"j": mx.nd.array([0]),
"x_sum": _array([5, 3]),
"sc": _array([10, 10, 5, 3]),
}
args_grad = {
"x_sum": _array([5, 3]),
"sc": _array([10, 10, 5, 3]),
}
out_grad = [
_array([1]),
_array([1]),
_array([5, 3]),
_array([10, 10, 5, 3]),
_array([2, 2, 10, 5, 3]),
_array([2, 2, 10, 5, 3]),
]
def _get_imp_result(is_train, args, args_grad, out_grad):
args = {k: v.copy() for k, v in args.items()}
args_grad = {k: v.copy() for k, v in args_grad.items()}
i, j, x_sum, sc = [args[x].copy() for x in ["i", "j", "x_sum", "sc"]]
if is_train:
x_sum.attach_grad()
sc.attach_grad()
with mx.autograd.record(train_mode=is_train):
results = make_loop(i, j, x_sum, sc)
cat_res = mx.nd.concat(*[x.reshape(-1) for x in results], dim=0)
if not is_train:
return _to_np_list(results), []
cat_grad = mx.nd.concat(*[x.reshape(-1) for x in out_grad], dim=0)
assert cat_grad.shape == cat_res.shape
cat_res.backward(out_grad=cat_grad)
grads = [x_sum.grad, sc.grad]
return _to_np_list(results), _to_np_list(grads)
def _get_sym_result(is_train, args, args_grad, out_grad):
args = {k: v.copy() for k, v in args.items()}
args_grad = {k: v.copy() for k, v in args_grad.items()}
i, j, x_sum, sc = [
mx.sym.var("i"),
mx.sym.var("j"),
mx.sym.var("x_sum"),
mx.sym.var("sc"),
]
result_sym = mx.sym.Group(make_loop(i, j, x_sum, sc))
executor = result_sym._bind(
ctx=default_device(),
args=args,
args_grad=args_grad,
)
results = executor.forward(is_train=is_train)
if not is_train:
return _to_np_list(results), []
executor.backward(out_grads=out_grad)
grads = [executor.grad_dict["x_sum"], executor.grad_dict["sc"]]
return _to_np_list(results), _to_np_list(grads)
for is_train in [True, False]:
imp_out, imp_grad = _get_imp_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad)
sym_out, sym_grad = _get_sym_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad)
assert len(imp_out) == len(sym_out)
assert len(imp_grad) == len(sym_grad)
for x, y in zip(imp_out, sym_out):
assert_almost_equal(x, y, rtol=1e-3, atol=1e-3)
for x, y in zip(imp_grad, sym_grad):
assert_almost_equal(x, y, rtol=1e-3, atol=1e-3)
def _verify_cond(cond_func, then_func, else_func, input_var_shapes, free_var_shapes, is_train):
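"""Check mx.nd.contrib.cond against mx.sym.contrib.cond.

Evaluates the predicate and branch functions both imperatively under
autograd and through a bound symbol executor, then asserts that the
outputs and gradients of the two paths match.
"""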
def _create_symbol(prefix, i):
return mx.sym.var(prefix + str(i))
def _create_array(shape):
return mx.nd.random.uniform(-1.0, 1.0, shape=shape)
def _to_numpy_list(arrays):
return [x.asnumpy() if x is not None else x for x in arrays]
def _merge_dict(*dicts):
result = {}
for item in dicts:
result.update(item)
return result
_input_syms = [_create_symbol("InputVar", i) for i, _ in enumerate(input_var_shapes)]
_free_syms = [_create_symbol("FreeVar", i) for i, _ in enumerate(free_var_shapes)]
_input_vars = [_create_array(x) for x in input_var_shapes]
_free_vars = [_create_array(x) for x in free_var_shapes]
_args_dict = _merge_dict(
{"InputVar" + str(i): x for i, x in enumerate(_input_vars)},
{"FreeVar" + str(i): x for i, x in enumerate(_free_vars)},
)
def _get_imperative_result():
free_vars = [x.copy() for x in _free_vars]
input_vars = [x.copy() for x in _input_vars]
out_grads = []
if is_train:
for var in free_vars + input_vars:
var.attach_grad()
with mx.autograd.record(train_mode=is_train):
outputs = mx.nd.contrib.cond(
pred=cond_func(input_vars, free_vars),
then_func=lambda: then_func(input_vars, free_vars),
else_func=lambda: else_func(input_vars, free_vars),
)
outputs = _as_list(outputs)
outputs = [x * 2 for x in outputs]
grads = []
if is_train:
out_grads = [_create_array(x.shape) for x in outputs]
cat_out = mx.nd.concat(*[x.reshape(-1) for x in outputs], dim=0)
cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0))
grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \
+ [input_vars[i].grad for i, _ in enumerate(input_var_shapes)]
return _to_numpy_list(outputs), _to_numpy_list(grads), out_grads
def _get_symbolic_result(out_grads):
outputs_sym = mx.sym.contrib.cond(
pred=cond_func(_input_syms, _free_syms),
then_func=lambda: then_func(_input_syms, _free_syms),
else_func=lambda: else_func(_input_syms, _free_syms),
)
outputs_sym = _as_list(outputs_sym)
outputs_sym = [x * 2 for x in outputs_sym]
outputs_sym = mx.sym.Group(outputs_sym)
executor = outputs_sym._bind(
ctx=default_device(),
args={name: _args_dict[name].copy() for name in outputs_sym.list_inputs()},
args_grad=None if not is_train else _merge_dict(
{"InputVar" + str(i): mx.nd.zeros(s) for i, s in enumerate(input_var_shapes)},
{"FreeVar" + str(i): mx.nd.zeros(s) for i, s in enumerate(free_var_shapes)},
),
)
outputs = executor.forward(is_train=is_train)
grads = []
if is_train:
executor.backward(out_grads=out_grads)
grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \
+ [executor.grad_dict.get("InputVar" + str(i), None) for i, _ in enumerate(input_var_shapes)]
return _to_numpy_list(outputs), _to_numpy_list(grads)
imp_outs, imp_grads, out_grads = _get_imperative_result()
sym_outs, sym_grads = _get_symbolic_result(out_grads)
for imp_out, sym_out in zip(imp_outs, sym_outs):
if imp_out is None or sym_out is None:
continue
assert_almost_equal(imp_out, sym_out, rtol=1e-3, atol=1e-3)
for imp_grad, sym_grad in zip(imp_grads, sym_grads):
if imp_grad is None or sym_grad is None:
continue
assert_almost_equal(imp_grad, sym_grad, rtol=1e-3, atol=1e-3)
def test_cond():
# Test dimensions:
# - whether there are free variables in the three graphs (cond/then/else)
# - whether the three graphs contain input_vars
# - whether all input_vars are used
# - which branch is taken
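# Rough shape of the operator under test (ndarray flavor, a sketch):
#   out = mx.nd.contrib.cond(pred, then_func, else_func)
# where pred is a scalar ndarray and only the selected branch is executed.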
def run_case(cond_func, then_func, else_func, **params):
def make_cond(is_inverse):
def cond(inputs, free):
x = cond_func(inputs, free)
if is_inverse:
if isinstance(x, mx.sym.Symbol):
return mx.sym.logical_not(x)
else:
return mx.nd.logical_not(x)
return x
return cond
for is_train in [True, False]:
for is_inverse in [False, True]:
_verify_cond(
cond_func=make_cond(is_inverse),
then_func=then_func,
else_func=else_func,
is_train=is_train,
**params
)
# Each function can
# 1. use_free_vars or not: T/F
# 2. use_input_vars or not: T/F
# 3. use_all_input_vars or not: T/F
# (a, b, c) are inputs, (d, e, f) are free_vars
cond_funcs = [
lambda a, b, c, d, e, f: (a * b).sum() < 0.5, # F, T, F
lambda a, b, c, d, e, f: (a + b + c).sum() < 0.5, # F, T, T
lambda a, b, c, d, e, f: (d + e).sum() < 0.5, # T, F, F
lambda a, b, c, d, e, f: (d + e * a).sum() < 0.5, # T, T, F
lambda a, b, c, d, e, f: (d + e * a + b * c).sum() < 0.5, # T, T, T
]
body_funcs = [
lambda a, b, c, d, e, f: a * b, # F, T, F
lambda a, b, c, d, e, f: a * b * c, # F, T, T
lambda a, b, c, d, e, f: d * e, # T, F, F
lambda a, b, c, d, e, f: d * e * a, # T, T, F
lambda a, b, c, d, e, f: d * e * a * b * c, # T, T, T
# some extra tests
lambda a, b, c, d, e, f: b * c,
lambda a, b, c, d, e, f: a * c,
lambda a, b, c, d, e, f: (a + b) * c,
lambda a, b, c, d, e, f: c * (b - a),
]
# Enumerate all possible combinations.
for cond_func in cond_funcs:
for then_func in body_funcs:
for else_func in body_funcs:
run_case(
cond_func=lambda x, y: cond_func(x[0], x[1], x[2], y[0], y[1], y[2]),
then_func=lambda x, y: then_func(x[0], x[1], x[2], y[0], y[1], y[2]),
else_func=lambda x, y: else_func(x[0], x[1], x[2], y[0], y[1], y[2]),
input_var_shapes=[
(2, 3),
(2, 3),
(2, 3),
],
free_var_shapes=[
(2, 3),
(2, 3),
(2, 3),
]
)
@pytest.mark.garbage_expected
def test_foreach():
v3 = mx.sym.var("v0")
v4 = mx.sym.var("v1")
v5 = mx.sym.var("v2")
v6 = mx.sym.var("v3")
v7 = mx.sym.var("v4")
v8 = mx.sym.var("v5")
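# Note: the Python names v3..v8 are historical; what matters is that the
# symbols are named "v0".."v5", matching the 'v' + str(i) keys built below.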
def verify_foreach(step, in_syms, state_syms, free_syms,
in_arrs, init_states, frees, out_grads, is_train=True,
free_vars_func=None, num_iters=1):
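"""Build `step` into mx.sym.contrib.foreach, bind and execute it, then
re-implement the same scan with mx.nd.contrib.foreach under autograd and
compare outputs (and gradients when is_train) between the two paths.
`free_vars_func` optionally preprocesses the free variables for the
imperative run; `num_iters` repeats forward/backward to check that
re-execution is deterministic.
"""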
step_sym = lambda in_syms, state_syms : step(in_syms, state_syms, free_syms)
res, states = mx.sym.contrib.foreach(step_sym, in_syms, state_syms)
out = _as_list(res)
num_outputs = len(out)
for i in range(num_outputs):
out[i] = out[i] * 2
out.extend(states)
out = mx.sym.Group(out)
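# Round-trip the symbol through JSON to verify that the foreach subgraph
# serializes and deserializes losslessly.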
js_1 = out.tojson()
out = mx.sym.fromjson(js_1)
js_2 = out.tojson()
assert js_1 == js_2
arr_grads = []
arg_dict = {}
arg_grad_dict = {}
i = 0
for arr in _as_list(in_arrs):
arr_grad = mx.nd.empty(arr.shape)
arr_grads.append(arr_grad)
arg_dict['v'+str(i)] = arr
arg_grad_dict['v'+str(i)] = arr_grad
i = i + 1
for arr in init_states:
arr_grad = mx.nd.empty(arr.shape)
arr_grads.append(arr_grad)
arg_dict['v'+str(i)] = arr
arg_grad_dict['v'+str(i)] = arr_grad
i = i + 1
for arr in frees:
arr_grad = mx.nd.empty(arr.shape)
arr_grads.append(arr_grad)
arg_dict['v'+str(i)] = arr
arg_grad_dict['v'+str(i)] = arr_grad
i = i + 1
if is_train:
e = out._bind(ctx=default_device(), args=arg_dict, args_grad=arg_grad_dict)
else:
e = out._bind(ctx=default_device(), args=arg_dict)
# The inputs to forward and backward do not change between iterations, so
# repeated forward/backward passes should always return the same outputs.
for _ in range(num_iters):
e.forward(is_train=is_train)
if is_train:
# backward
tmp_grads = out_grads[0][:]
tmp_grads.extend(out_grads[1])
e.backward(tmp_grads)
# Below we use the imperative API to re-implement foreach and compute its gradients.
res = []
for _ in range(len(_as_list(out_grads[0]))):
res.append([])
for arr in _as_list(in_arrs):
arr.attach_grad()
for arr in init_states:
arr.attach_grad()
for arr in frees:
arr.attach_grad()
with mx.autograd.record():
frees_imp = frees if free_vars_func is None else free_vars_func(frees)
step_imp = lambda in_arrs, state_arrs : step(in_arrs, state_arrs, frees_imp)
res, states = mx.nd.contrib.foreach(step_imp, in_arrs, init_states)
res2 = _as_list(res)
for i in range(len(res2)):
res2[i] = res2[i] * 2
outs = list(res2)
if isinstance(states, list):
outs.extend(states)
states = [mx.nd.expand_dims(s, 0) for s in states]
res2.extend(states)
else:
outs.append(states)
states = mx.nd.expand_dims(states, 0)
res2.append(states)
if is_train:
res = mx.nd.concat(*res2, dim=0)
tmp_grads = out_grads[0][:]
tmp_grads1 = [mx.nd.expand_dims(grad, 0) for grad in out_grads[1]]
tmp_grads.extend(tmp_grads1)
if is_train:
res.backward(mx.nd.concat(*tmp_grads, dim=0))
for i in range(len(outs)):
assert e.outputs[i].shape == outs[i].shape
assert_almost_equal(e.outputs[i].asnumpy(), outs[i].asnumpy(),
rtol=1e-3, atol=1e-3)
if is_train:
all_ins = _as_list(in_arrs)[:]
all_ins.extend(init_states)
all_ins.extend(frees)
size = min(len(all_ins), len(e.grad_arrays))
for i in range(size):
assert_almost_equal(all_ins[i].grad.asnumpy(),
e.grad_arrays[i].asnumpy(),
rtol=1e-3, atol=1e-3)
# Test cases:
# * graph inputs are stored in different orders.
# This tests whether foreach finds the data arrays and weight arrays
# in the right locations.
# * the number of iterations: odd or even.
# * multiple inputs and multiple outputs.
# * inference.
def step1(in1, states, free):
out = in1 * 2 + states[0] + free[0]
return (out, [out])
frees1 = [mx.nd.arange(2), mx.nd.arange(2) + 1]
arrs = mx.nd.arange(6).reshape(shape=(3, 2))
states = [mx.nd.arange(2)]
out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape)]]
verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True,
lambda frees : [frees[0] + frees[1]])
verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False,
lambda frees : [frees[0] + frees[1]])
verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True,
lambda frees : [frees[0] + frees[1]], 5)
verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False,
lambda frees : [frees[0] + frees[1]], 5)
# Test an even number of iterations.
frees = [mx.nd.random.uniform(shape=(2))]
arrs = mx.nd.random.uniform(shape=(2, 2))
out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape)]]
verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads)
verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False)
# Test an odd number of iterations.
arrs = mx.nd.random.uniform(shape=(3, 2))
out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape)]]
verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads)
verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False)
# Reorder the input and state in the subgraph inputs.
def step2(in1, states, free):
out = states[0] + in1 * 2 + free[0]
return (out, [out])
# Test an even number of iterations.
arrs = mx.nd.random.uniform(shape=(2, 2))
out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape)]]
verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads)
verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False)
# Test an odd number of iterations.
arrs = mx.nd.random.uniform(shape=(3, 2))
out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape)]]
verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads)
verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False)
# Test multiple inputs and outputs.
def step3(in1, states, free):
out = in1[0] + in1[1] * 2 + states[0] + states[1] * 2 + free[0]
return ([out, out], [out * 2, out * 3])
arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads)
verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False)
# Test multiple inputs and outputs.
# The order of subgraph inputs doesn't match the operator inputs
def step4(in1, states, free):
out = in1[1] * 2 + states[0] + free[0] + states[1] * 2 + in1[0]
return ([out, out * 2], [out * 2, out * 3])
arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads)
verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False)
# Test multiple inputs and outputs.
# The data inputs and states have different shapes.
def step5(in1, states, free):
if isinstance(in1[0], mx.nd.NDArray):
out1 = mx.nd.broadcast_add(states[0] + free[1], in1[1] * 2)
out2 = mx.nd.broadcast_add(in1[0], free[0] + states[1] * 2)
else:
out1 = mx.sym.broadcast_add(states[0] + free[1], in1[1] * 2)
out2 = mx.sym.broadcast_add(in1[0], free[0] + states[1] * 2)
return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2, 2))]
arrs = [mx.nd.random.uniform(shape=(3, 2, 2)), mx.nd.random.uniform(shape=(3, 2))]
states = [mx.nd.random.uniform(shape=(2, 2)), mx.nd.random.uniform(shape=(2))]
out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
verify_foreach(step5, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
# Test multiple inputs and outputs.
# The data inputs and states have different shapes and data types.
def step6(in1, states, free):
if isinstance(in1[0], mx.nd.NDArray):
out1 = mx.nd.broadcast_add(states[0] + mx.nd.cast(free[1], 'float32'),
mx.nd.cast(in1[1], 'float32') * 2)
out2 = mx.nd.broadcast_add(in1[0],
free[0] + mx.nd.cast(states[1], 'float32') * 2)
else:
out1 = mx.sym.broadcast_add(states[0] + mx.sym.cast(free[1], 'float32'),
mx.sym.cast(in1[1], 'float32') * 2)
out2 = mx.sym.broadcast_add(in1[0],
free[0] + mx.sym.cast(states[1], 'float32') * 2)
return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
frees = [mx.nd.random.uniform(shape=(2)),
mx.nd.cast(mx.nd.random.uniform(shape=(2, 2)), 'float64')]
arrs = [mx.nd.random.uniform(shape=(3, 2, 2)),
mx.nd.cast(mx.nd.random.uniform(shape=(3, 2)), dtype='float16')]
states = [mx.nd.random.uniform(shape=(2, 2)),
mx.nd.cast(mx.nd.random.uniform(shape=(2)), dtype='int32')]
out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
verify_foreach(step6, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
# Test multiple inputs and outputs.
# Some of the inputs are used twice.
def step7(in1, states, free):
out1 = states[0] + in1[0] + free[1] + in1[1] * 2 + free[0]
out2 = in1[0] + free[0] + states[1] * 2 + in1[1]
return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
verify_foreach(step7, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
# Test the case that the output is the input.
arrs = mx.nd.random.uniform(shape=(3, 2))
states = [mx.nd.arange(2)]
frees = [mx.nd.random.uniform(shape=(2))]
out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
[mx.nd.random.uniform(-10, 10, states[0].shape)]]
def step8(in1, states, free):
return (in1, [states[0] * free[0]])
verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads)
verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads, False)
def step9(in1, states, free):
return (in1 * free[0], states)
verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads)
verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads, False)
# Test the case that not all inputs are used.
def step10(in1, states, free):
return (in1, states)
verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads)
verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads, False)
def step11(in1, states, free):
return (in1, free)
try:
verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads)
verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads, False)
except AssertionError:
print("the states have to be used")
def step12(in1, states, free):
return (in1, [states[0] + 1, states[0] + 2])
states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
frees = []
try:
verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads)
verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads, False)
except AssertionError:
print("the states have to be used")
# Test without free variables.
def step13(in1, states, free):
return (in1, states)
states = [mx.nd.random.uniform(shape=(2))]
verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads)
verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads, False)
# Test when there is no output data or no output states.
def step14(in1, states, free):
return (in1 + free[0], [])
frees = [mx.nd.random.uniform(shape=(2))]
out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], []]
verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads)
verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads, False)
def step15(in1, states, free):
return ([], [in1 * states[0] * free[0]])
out_grads = [[], [mx.nd.random.uniform(-10, 10, states[0].shape)]]
verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads)
verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads, False)
# Test the case of iterating on a 1D data array.
def step16(in1, states, free):
return ([in1[0] * states[0]], [states[0] * 2])
arrs = [mx.nd.arange(3)]
states = [mx.nd.random.uniform(shape=(1))]
out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))],
[mx.nd.random.uniform(-10, 10, (1))]]
verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads)
verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads, False)
def step17(in1, states, free):
return ([in1[1] * in1[0] * states[0]], [states[0] * 2])
arrs = [mx.nd.random.uniform(shape=(3, 1)), mx.nd.arange(3)]
states = [mx.nd.random.uniform(shape=(1))]
out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))],
[mx.nd.random.uniform(-10, 10, (1))]]
verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads)
verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads, False)
def test_foreach_nested():
# Test nested foreach.
def step_in(in1, states):
out = in1 * 2 + states[0]
return (out, [out])
def step_sym(in1, states):
out1 = mx.sym.contrib.foreach(step_in, in1, states)
out = mx.sym.broadcast_add(out1[0], states[0])
return (out, [mx.sym.squeeze(mx.sym.slice(out, begin=(0, 0), end=(1, 2)))])
def step_nd(in1, states):
out1 = mx.nd.contrib.foreach(step_in, in1, states)
out = mx.nd.broadcast_add(out1[0], states[0])
return (out, [mx.nd.squeeze(mx.nd.slice(out, begin=(0, 0), end=(1, 2)))])
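# step_sym / step_nd each run an inner foreach, so the outer foreach scans
# the (2, 2, 2) input one (2, 2) slice at a time and the inner loop consumes
# each slice row by row.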
data_sym = mx.sym.var("v1")
state_sym = mx.sym.var("v2")
out, states = mx.sym.contrib.foreach(step_sym, data_sym, [state_sym])
assert isinstance(states, list)
assert len(states) == 1
out = mx.sym.broadcast_add(out, states[0])
js_1 = out.tojson()
out = mx.sym.fromjson(js_1)
js_2 = out.tojson()
assert js_1 == js_2
data = mx.nd.arange(8).reshape((2, 2, 2))
state = mx.nd.arange(2)
data_grad = mx.nd.empty(data.shape)
state_grad = mx.nd.empty(state.shape)
e = out._bind(ctx=default_device(), args={'v1':data, 'v2':state},
args_grad={'v1':data_grad, 'v2':state_grad})
e.forward(is_train=True)
out_grads = []
for out in e.outputs:
out_grads.append(mx.nd.random.uniform(shape=out.shape))
e.backward(out_grads)
data.attach_grad()
state.attach_grad()
with mx.autograd.record():
out, states = mx.nd.contrib.foreach(step_nd, data, [state])
assert isinstance(states, list)
assert len(states) == 1
res = mx.nd.broadcast_add(out, states[0])
assert_almost_equal(res.asnumpy(), e.outputs[0].asnumpy(), rtol=1e-3, atol=1e-3)
res.backward(out_grads[0])
assert_almost_equal(data.grad.asnumpy(), data_grad.asnumpy(), rtol=1e-3, atol=1e-3)
assert_almost_equal(state.grad.asnumpy(), state_grad.asnumpy(), rtol=1e-3, atol=1e-3)
def test_foreach_with_unknown_dim():
# MXNet supports using 0 as a placeholder for unknown dimensions in a shape
step = lambda data, states: (data + states[0], [states[0] * 2])
# The input shape is in NCHW format, with N unknown
data = mx.sym.var('data', shape=(0, 3, 32, 32))
states = [mx.sym.var('state')]
outs, states = mx.sym.contrib.foreach(step, data, states)
_, output_shape, _ = outs.infer_shape_partial()
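# infer_shape_partial leaves dimensions it cannot infer as 0, so the unknown
# batch dimension propagates unchanged to the output shape.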
assert_allclose((0, 3, 32, 32), output_shape[0])