| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import copy |
| import numpy as np |
| import mxnet as mx |
| from mxnet import gluon |
| from numpy.testing import assert_allclose, assert_array_equal |
| from collections import defaultdict |
| from mxnet.test_utils import * |
| from mxnet.base import _as_list |
| from mxnet.attribute import AttrScope |
| from common import with_seed |
| |
| |
| @with_seed() |
| def test_while_loop_simple_forward(): |
| |
| class _TestBlock(gluon.HybridBlock): |
| |
| def __init__(self, cond, func, max_iterations): |
| super(_TestBlock, self).__init__() |
| self.cond = cond |
| self.func = func |
| self.max_iterations = max_iterations |
| |
| def hybrid_forward(self, F, *loop_vars): |
| return F.contrib.while_loop( |
| cond=self.cond, |
| func=self.func, |
| loop_vars=loop_vars, |
| max_iterations=self.max_iterations |
| ) |
| |
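    # Rough sketch of contrib.while_loop semantics (an approximation: the
    # real op stacks per-step outputs along a new leading axis and, in the
    # symbolic case, pads that axis up to max_iterations):
    #
    #   def while_loop(cond, func, loop_vars, max_iterations):
    #       outputs = []
    #       while cond(*loop_vars) and len(outputs) < max_iterations:
    #           step_out, loop_vars = func(*loop_vars)
    #           outputs.append(step_out)
    #       return outputs, loop_vars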
| for hybridize in [False, True]: |
        # Case 1.1: result should be sum([1, 2, 3, 4, 5])
| model = _TestBlock( |
| cond=lambda i, s: i <= 5, |
| func=lambda i, s: (None, (i + 1, s + i)), |
| max_iterations=10, |
| ) |
| if hybridize: |
| model.hybridize() |
| _, result = model( |
| mx.nd.array([1], dtype="int64"), # i |
| mx.nd.array([0], dtype="int64"), # s |
| ) |
| assert result[0].asscalar() == 6 |
| assert result[1].asscalar() == 15 |
| # Case 1.2: result should be sum([1, 2, 3 ... 1000]) |
| model = _TestBlock( |
| cond=lambda i, s, true: true, |
| func=lambda i, s, true: (None, (i + 1, s + i, true)), |
| max_iterations=1000, |
| ) |
| if hybridize: |
| model.hybridize() |
| _, result = model( |
| mx.nd.array([1], dtype="int64"), # i |
| mx.nd.array([0], dtype="int64"), # s |
| mx.nd.array([1], dtype="int64"), # true |
| ) |
| assert result[0].asscalar() == 1001 |
| assert result[1].asscalar() == 500500 |
| assert result[2].asscalar() == 1 |
| # Case 1.3: result should be sum([]) |
| model = _TestBlock( |
| cond=lambda i, s, false: false, |
| func=lambda i, s, false: (None, (i + 1, s + i, false)), |
| max_iterations=1000, |
| ) |
| if hybridize: |
| model.hybridize() |
| _, result = model( |
| mx.nd.array([1], dtype="int64"), # i |
| mx.nd.array([0], dtype="int64"), # s |
| mx.nd.array([0], dtype="int64"), # false |
| ) |
| assert result[0].asscalar() == 1 |
| assert result[1].asscalar() == 0 |
| assert result[2].asscalar() == 0 |
| # Case 2.1: result should be sum([1, 2, 3 ... 100]) |
| model = _TestBlock( |
| cond=lambda i, s: i <= 100, |
| func=lambda i, s: (i, (i + 1, s + i)), |
| max_iterations=1000, |
| ) |
| if hybridize: |
| model.hybridize() |
| outputs, (result_i, result_s) = model( |
| mx.nd.array([1], dtype="int64"), # i |
| mx.nd.array([0], dtype="int64"), # s |
| ) |
| assert all(outputs.asnumpy()[ : 100] == np.arange(1, 101).reshape(100, 1)) |
| assert result_i.asscalar() == 101 |
| assert result_s.asscalar() == 5050 |
| # Case 2.2: result should be sum([1, 2, 3 ... 1000]) |
| model = _TestBlock( |
| cond=lambda i, s, true: true, |
| func=lambda i, s, true: (i, (i + 1, s + i, true)), |
| max_iterations=1000, |
| ) |
| if hybridize: |
| model.hybridize() |
| outputs, (result_i, result_s, _) = model( |
| mx.nd.array([1], dtype="int64"), # i |
| mx.nd.array([0], dtype="int64"), # s |
| mx.nd.array([1], dtype="int64"), # true |
| ) |
| assert all(outputs.asnumpy() == np.arange(1, 1001).reshape(1000, 1)) |
| assert result_i.asscalar() == 1001 |
| assert result_s.asscalar() == 500500 |
| # Case 2.3: a corner case, in which loop body is never executed |
| model = _TestBlock( |
| cond=lambda i, s, false: false, |
| func=lambda i, s, false: (i, (i + 1, s + i, false)), |
| max_iterations=1000, |
| ) |
| if hybridize: |
| model.hybridize() |
| _, (result_i, result_s, _) = model( |
| mx.nd.array([1], dtype="int64"), # i |
| mx.nd.array([0], dtype="int64"), # s |
| mx.nd.array([0], dtype="int64"), # false |
| ) |
| assert result_i.asscalar() == 1 |
| assert result_s.asscalar() == 0 |
| |
| |
| def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for, n_steps): |
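    """Compare the imperative (NDArray) and symbolic (Symbol) versions of
    contrib.while_loop on the same randomly initialized inputs.

    The loop outputs are scaled by 2 and the final loop vars by 3 before
    backprop, so every output head receives a distinct gradient; the outputs
    and input gradients of both versions must match.
    """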
| |
| def _create_vars(num, prefix): |
| return [mx.sym.var(prefix + str(i)) for i in range(num)] |
| |
| def _create_arrays(shapes): |
| return [mx.nd.random.uniform(-1.0, 1.0, shape=x) for x in shapes] |
| |
| def _create_dict(prefix, shapes): |
| return {prefix + str(i): mx.nd.random.uniform(-1.0, 1.0, shape=x) for i, x in enumerate(shapes)} |
| |
| def _merge_dict(*dicts): |
| result = {} |
| for item in dicts: |
| result.update(item) |
| return result |
| |
| def _to_numpy_list(arrays): |
| return [x.asnumpy() if x is not None else x for x in arrays] |
| |
| def _get_imperative_result(n_steps): |
| free_vars = [args["FreeVar" + str(i)].copy() for i, _ in enumerate(free_var_shapes)] |
| loop_vars = [args["LoopVar" + str(i)].copy() for i, _ in enumerate(loop_var_shapes)] |
| loop_var_start = int(is_for) |
| if is_train: |
| for var in free_vars + loop_vars[loop_var_start: ]: |
| var.attach_grad() |
| with mx.autograd.record(train_mode=is_train): |
| outputs, final_loop_vars = mx.nd.contrib.while_loop( |
| cond=lambda *_loop_vars: cond(_loop_vars, free_vars), |
| func=lambda *_loop_vars: func(_loop_vars, free_vars), |
| loop_vars=loop_vars, |
| max_iterations=max_iterations, |
| ) |
| outputs = _as_list(outputs) |
| final_loop_vars = _as_list(final_loop_vars) |
| outputs = [x[: n_steps] for x in outputs] |
| out_grads = _create_arrays(x.shape for x in outputs) \ |
| + _create_arrays(x.shape for x in final_loop_vars) |
| loop_result_nd = [x * 2 for x in outputs] + [x * 3 for x in final_loop_vars] |
| grads = [] |
| if is_train: |
| cat_out = mx.nd.concat(*[x.reshape(-1) for x in loop_result_nd], dim=0) |
| cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0)) |
| grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \ |
| + [loop_vars[i].grad for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] |
| return _to_numpy_list(loop_result_nd), _to_numpy_list(grads), out_grads |
| |
| def _get_symbolic_result(out_grads, n_steps): |
| |
| def _copy_args_dict(name_list): |
| return {name: args[name].copy() for name in name_list} |
| |
| def _zeros_like_dict(name_list): |
| return {name: mx.nd.zeros_like(args[name]) for name in name_list} |
| |
| free_syms = _create_vars(len(free_var_shapes), "FreeVar") |
| loop_syms = _create_vars(len(loop_var_shapes), "LoopVar") |
| outputs, final_loop_syms = mx.sym.contrib.while_loop( |
| cond=lambda *_loop_vars: cond(_loop_vars, free_syms), |
| func=lambda *_loop_vars: func(_loop_vars, free_syms), |
| loop_vars=loop_syms, |
| max_iterations=max_iterations, |
| ) |
| outputs = _as_list(outputs) |
| final_loop_syms = _as_list(final_loop_syms) |
| if n_steps == 0: |
| outputs = [] |
| else: |
| outputs = [x.slice_axis(axis=0, begin=0, end=n_steps) for x in outputs] |
| loop_result_sym = [x * 2 for x in outputs] + [x * 3 for x in final_loop_syms] |
| loop_result_sym = mx.sym.Group(loop_result_sym) |
| |
| loop_var_start = int(is_for) |
| args_names = ["FreeVar" + str(i) for i, _ in enumerate(free_var_shapes)] \ |
| + ["LoopVar" + str(i) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] |
        args_grad = None if not is_train else _zeros_like_dict(args_names)
| executor = loop_result_sym.bind( |
| ctx=default_context(), |
| args=_copy_args_dict(loop_result_sym.list_inputs()), |
| args_grad=args_grad, |
| ) |
| loop_result_nd = executor.forward(is_train=is_train) |
| grads = [] |
| if is_train: |
| executor.backward(out_grads=out_grads) |
| grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \ |
| + [executor.grad_dict.get("LoopVar" + str(i), None) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] |
| return _to_numpy_list(loop_result_nd), _to_numpy_list(grads) |
| |
| args = _merge_dict( |
| _create_dict("FreeVar", free_var_shapes), |
| _create_dict("LoopVar", loop_var_shapes), |
| ) |
| if is_for: |
| assert loop_var_shapes[0] == (1, ) |
| args["LoopVar0"] = mx.nd.array([0]) |
| imp_outs, imp_grads, out_grads = _get_imperative_result(n_steps) |
| sym_outs, sym_grads = _get_symbolic_result(out_grads, n_steps) |
| for imp_out, sym_out in zip(imp_outs, sym_outs): |
| if imp_out is None or sym_out is None: |
| continue |
| assert_almost_equal(imp_out, sym_out, rtol=1e-3, atol=1e-3) |
| for imp_grad, sym_grad in zip(imp_grads, sym_grads): |
| if imp_grad is None or sym_grad is None: |
| continue |
| assert_almost_equal(imp_grad, sym_grad, rtol=1e-3, atol=1e-3) |
| |
| |
| @with_seed() |
| def test_while_loop_for_foreach(): |
| |
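    # Helpers producing loop conditions. .prod() collapses the element-wise
    # comparison into a single scalar, which while_loop interprets as a
    # boolean predicate.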
| def make_true_cond(): |
| return lambda loop_vars, _: (loop_vars[0] < 1e35).prod() |
| |
| def make_false_cond(): |
| return lambda loop_vars, _: (loop_vars[0] > 1e35).prod() |
| |
| def make_for_cond(length): |
| return lambda loop_vars, _: loop_vars[0] < length |
| |
| def case_0(): |
        # This is a simple testcase in which all loop steps are independent
        # It scans the array and outputs each element
| # There is 1 output |
| # There is 1 state: i |
| def _simple_func(loop, free): |
| (i, ), (scanned, ) = loop, free |
| in_ = scanned.take(i).squeeze(axis=0) |
| return (in_, i + 1) |
| _verify_while_loop( |
| cond=make_true_cond(), |
| func=_simple_func, |
| max_iterations=1, |
| is_train=True, |
| is_for=True, |
| loop_var_shapes=[ |
| (1, ), # i |
| ], |
| free_var_shapes=[ |
| (1, 3), # scanned |
| ], |
| n_steps=1, |
| ) |
| |
| def case_1(**params): |
| # This is a simple testcase that simulates a cumulative sum |
| # There is 1 output |
| # There is 1 state: s |
| step_funcs = [ |
| lambda a, b, s: s, |
| lambda a, b, s: a * 1.5 + b * 2.5 - s * 3.5, |
| lambda a, b, s: a * 1.5 - s * 3.5 + b * 2.5, |
| lambda a, b, s: b * 2.5 + a * 1.5 - s * 3.5, |
| lambda a, b, s: b * 2.5 - s * 3.5 + a * 1.5, |
| lambda a, b, s: s * -3.5 + a * 1.5 + b * 2.5, |
| lambda a, b, s: s * -3.5 + b * 2.5 + a * 1.5, |
| lambda a, b, s: a * 2.5 * b + s * 0.3, |
| lambda a, b, s: b * 2.5 * a + s * 0.3, |
| lambda a, b, s: 2.5 * a * b + s * 0.3, |
| lambda a, b, s: b * a * 2.5 + s * 0.3, |
| lambda a, b, s: 2.5 * b * a + s * 0.3, |
| lambda a, b, s: b * a * 2.5 + s * 0.3, |
| lambda a, b, s: s * 0.3 + a * 2.5 * b, |
| lambda a, b, s: s * 0.3 + b * 2.5 * a, |
| lambda a, b, s: s * 0.3 + 2.5 * a * b, |
| lambda a, b, s: s * 0.3 + b * a * 2.5, |
| lambda a, b, s: s * 0.3 + 2.5 * b * a, |
| lambda a, b, s: s * 0.3 + b * a * 2.5, |
| ] |
| def make_func(step_func): |
| def step(loop, free): |
| (s, ), (a, b) = loop, free |
| out = step_func(a, b, s) |
| return (out, out) |
| return step |
| case_id = 0 |
| for is_train in [True, False]: |
| for step_func in step_funcs: |
| case_id += 1 |
| _verify_while_loop( |
| func=make_func(step_func), |
| is_train=is_train, |
| is_for=False, |
| **params |
| ) |
| |
| def case_2(**params): |
        # This is a testcase that involves non-differentiable operators
        # There is 1 output
        # There are 2 states: i, s
| step_funcs = [ |
| lambda in_, s, f_1: (in_ * 2) * s * f_1, |
| lambda in_, s, f_1: (in_ * 2) * f_1 * s, |
| lambda in_, s, f_1: s * (in_ * 2) * f_1, |
| lambda in_, s, f_1: s * f_1 * (in_ * 2), |
| lambda in_, s, f_1: f_1 * (in_ * 2) * s, |
| lambda in_, s, f_1: f_1 * s * (in_ * 2), |
| lambda in_, s, f_1: (2 * in_) * s * f_1, |
| lambda in_, s, f_1: (2 * in_) * f_1 * s, |
| lambda in_, s, f_1: s * (2 * in_) * f_1, |
| lambda in_, s, f_1: s * f_1 * (2 * in_), |
| lambda in_, s, f_1: f_1 * (2 * in_) * s, |
| lambda in_, s, f_1: f_1 * s * (2 * in_), |
| ] |
| def make_func(step_func): |
| """This simulates: |
| def compute(s, inputs, f_1, length): |
| outputs = [] |
| for i in range(length): |
| s += inputs[i] * 2 + f_1 |
| outputs.append(s) |
| return outputs, s |
| """ |
| def step(loop, free): |
| (i, s), (scanned, f_1, _) = loop, free |
| in_ = scanned.take(i).squeeze(axis=0) |
| out = step_func(in_, s, f_1) |
| return (out, (i + 1, out)) |
| return step |
| case_id = 0 |
| for is_train in [True, False]: |
| for step_func in step_funcs: |
| case_id += 1 |
| _verify_while_loop( |
| func=make_func(step_func), |
| max_iterations=1000, |
| is_train=is_train, |
| is_for=True, |
| **params |
| ) |
| |
| def case_3(length, **params): |
| # This is a testcase for multiple non-differentiable operators and different ways of slicing |
| # There are 2 outputs |
| # There are 3 states: i, s_0, s_1 |
| step_funcs = [ |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_1, |
| lambda i_0, i_1, s_0, s_1, f_0: s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: s_1, |
| lambda i_0, i_1, s_0, s_1, f_0: f_0, |
| ] |
| def make_func(step_func): |
| """This simulates: |
| def compute(input_0, input_1, s_0, s_1, f_0, length): |
| output_0 = [] |
| output_1 = [] |
| for i in range(length): |
| i_0 = input_0[i] |
| i_1 = input_1[length - 1 - i] |
| out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 |
| s_0 = (s_0 + out) * 1.05 |
| s_1 = (s_1 - out * 0.5) * 0.95 |
| output_0.append(out) |
| output_1.append(out * 1.5) |
                return output_0, output_1, s_0, s_1
| """ |
| def step(loop, free): |
| (i, s_0, s_1), (sc_0, sc_1, f_0, _) = loop, free |
| i_0 = sc_0.take(i).squeeze(axis=0) |
| i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) |
| out = step_func(i_0, i_1, s_0, s_1, f_0) |
| return ([out, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95]) |
| return step |
| case_id = 0 |
| for is_train in [True, False]: |
| for step_func in step_funcs: |
| case_id += 1 |
| _verify_while_loop( |
| func=make_func(step_func), |
| max_iterations=1000, |
| is_train=is_train, |
| is_for=True, |
| **params |
| ) |
| |
| def case_4(length, single_shape, **params): |
| # It is for the case that inputs & outputs are the same |
| # There are 3 outputs |
| # There are 4 states: i, s_0, s_1, s_2 |
| # i is used in both non-differentiable (take) and differentiable (+) occasions |
| step_funcs = [ |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_1, |
| lambda i_0, i_1, s_0, s_1, f_0: s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: s_1, |
| lambda i_0, i_1, s_0, s_1, f_0: f_0, |
| ] |
| def make_func(step_func): |
| """This simulates: |
| def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): |
| # here s_2 remains untouched |
| output_0 = [] |
| output_1 = [] |
| output_2 = [] |
| for i in range(length): |
| i_0 = input_0[i] |
| i_1 = input_1[length - 1 - i] |
| out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 |
| out = out * i * i_0 * i_1 |
| s_0 = (s_0 + out) * 1.05 |
| s_1 = (s_1 - out * 0.5) * 0.95 |
| output_0.append(out) |
| output_1.append(f_0) |
| output_2.append(out * 1.5) |
| return output_0, output_1, output_2, s_0, s_1, s_2 |
| """ |
| def step(loop, free): |
| (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free |
| i_0 = sc_0.take(i).squeeze(axis=0) |
| i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) |
| out = step_func(i_0, i_1, s_0, s_1, f_0) |
| out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) |
| out = out * i_0 * i_1 |
| return ([out, f_0, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) |
| return step |
| case_id = 0 |
| for is_train in [True, False]: |
| for step_func in step_funcs: |
| case_id += 1 |
| _verify_while_loop( |
| func=make_func(step_func), |
| max_iterations=1000, |
| is_train=is_train, |
| is_for=True, |
| **params |
| ) |
| |
| def case_5(length, single_shape, **params): |
        # It is for the case that the loop produces no outputs
        # There are 0 outputs
        # There are 4 states: i, s_0, s_1, s_2
        # i is used in both non-differentiable (take) and differentiable (+) occasions
| step_funcs = [ |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_1, |
| lambda i_0, i_1, s_0, s_1, f_0: s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: s_1, |
| lambda i_0, i_1, s_0, s_1, f_0: f_0, |
| ] |
| def make_func(step_func): |
| """This simulates: |
            def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
                # here s_2 remains untouched, and the loop itself returns no outputs
                for i in range(length):
                    i_0 = input_0[i]
                    i_1 = input_1[length - 1 - i]
                    out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
                    out = out * i * i_0 * i_1
                    s_0 = (s_0 + out) * 1.05
                    s_1 = (s_1 - out * 0.5) * 0.95
                return s_0, s_1, s_2
| """ |
| def step(loop, free): |
| (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free |
| i_0 = sc_0.take(i).squeeze(axis=0) |
| i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) |
| out = step_func(i_0, i_1, s_0, s_1, f_0) |
| out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) |
| out = out * i_0 * i_1 |
| return ([], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) |
| return step |
| case_id = 0 |
| for is_train in [True, False]: |
| for step_func in step_funcs: |
| case_id += 1 |
| _verify_while_loop( |
| func=make_func(step_func), |
| max_iterations=1000, |
| is_train=is_train, |
| is_for=True, |
| **params |
| ) |
| |
| def case_6(length, single_shape, **params): |
| # It is for the case that inputs & outputs are the same |
| # There are 3 outputs |
| # There are 4 states: i, s_0, s_1, s_2 |
        # i is used in both non-differentiable (take) and differentiable (+) occasions
| step_funcs = [ |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, |
| lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_0, |
| lambda i_0, i_1, s_0, s_1, f_0: i_1, |
| lambda i_0, i_1, s_0, s_1, f_0: s_0, |
| lambda i_0, i_1, s_0, s_1, f_0: s_1, |
| lambda i_0, i_1, s_0, s_1, f_0: f_0, |
| ] |
| def make_func(step_func): |
| """This simulates: |
            def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
                # here s_2 remains untouched
                output_0 = []
                output_1 = []
                output_2 = []
                for i in range(length):
                    i_0 = input_0[i]
                    i_1 = input_1[length - 1 - i]
                    out_0 = step(i_0, i_1, s_0, s_1, f_0) * i
                    out_1 = step(i_1, s_0, f_0, s_1, i_0) * i
                    s_0 = (s_0 + out_1) * 1.05
                    s_1 = (s_1 - out_0 * 0.5) * 0.95
                    output_0.append(dot(out_0, s_2))
                    output_1.append(f_0)
                    output_2.append(dot(s_2, out_1) * 1.5)
                return output_0, output_1, output_2, s_0, s_1, s_2
| """ |
| def step(loop, free): |
| (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free |
| F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd |
| i_0 = sc_0.take(i).squeeze(axis=0) |
| i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) |
| out_0 = step_func(i_0, i_1, s_0, s_1, f_0) |
| out_0 = out_0 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) |
| out_1 = step_func(i_1, s_0, f_0, s_1, i_0) |
| out_1 = out_1 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) |
| return ([F.dot(out_0, s_2), f_0, F.dot(s_2, out_1) * 1.5], [i + 1, (s_0 + out_1) * 1.05, (s_1 - out_0 * 0.5) * 0.95, s_2]) |
| return step |
| case_id = 0 |
| for is_train in [True, False]: |
| for step_func in step_funcs: |
| case_id += 1 |
| _verify_while_loop( |
| func=make_func(step_func), |
| max_iterations=1000, |
| is_train=is_train, |
| is_for=True, |
| **params |
| ) |
| |
    # Case 0: the simplest case
| case_0() |
| # Case 1.1.* |
| case_1( |
| cond=make_true_cond(), |
| loop_var_shapes=[ |
| (1, ), # s |
| ], |
| free_var_shapes=[ |
| (1, ), # a |
| (1, ), # b |
| ], |
| max_iterations=5, |
| n_steps=5, |
| ) |
| # Case 1.2.* |
| case_1( |
| cond=make_true_cond(), |
| loop_var_shapes=[ |
| (2, 3, 4), # s |
| ], |
| free_var_shapes=[ |
| (2, 3, 4), # a |
| (2, 3, 4), # b |
| ], |
| max_iterations=3, |
| n_steps=3, |
| ) |
| # Case 1.3.* |
| case_1( |
| cond=make_false_cond(), |
| loop_var_shapes=[ |
| (2, 3, 4), # s |
| ], |
| free_var_shapes=[ |
| (2, 3, 4), # a |
| (2, 3, 4), # b |
| ], |
| max_iterations=20, |
| n_steps=0, |
| ) |
| # Case 2.1.* |
| case_2( |
| cond=make_for_cond(length=5), |
| loop_var_shapes=[ |
| (1, ), # i |
| (2, ), # s |
| ], |
| free_var_shapes=[ |
| (100, 2), # scanned |
| (2, ), # f_1 |
| (3, 4, 5, 6), # f_2, unused |
| ], |
| n_steps=5, |
| ) |
| # Case 2.2.* |
| case_2( |
| cond=make_for_cond(length=3), |
| loop_var_shapes=[ |
| (1, ), # i |
| (2, ), # s |
| ], |
| free_var_shapes=[ |
| (30, 2), # scanned |
| (2, ), # f_1 |
| (3, 4, 5, 6), # f_2, unused |
| ], |
| n_steps=3, |
| ) |
| # Case 3.* |
| case_3( |
| length=5, |
| cond=make_for_cond(length=5), |
| loop_var_shapes=[ |
| (1, ), # i |
| (2, ), # s_0 |
| (2, ), # s_1 |
| ], |
| free_var_shapes=[ |
| (30, 2), # sc_0 |
| (30, 2), # sc_1 |
| (2, ), # f_0 |
| (3, 4, 5, 6), # f_1, unused |
| ], |
| n_steps=5, |
| ) |
| # Case 4.1.* |
| case_4( |
| length=4, |
| cond=make_for_cond(length=4), |
| single_shape=[5], |
| loop_var_shapes=[ |
| (1, ), # i |
| (5, ), # s_0 |
| (5, ), # s_1 |
| (23, 6, 8), # s_2 |
| ], |
| free_var_shapes=[ |
| (30, 5), # sc_0 |
| (30, 5), # sc_1 |
| (5, ), # f_0 |
| (3, 4, 5, 6), # f_1, unused |
| ], |
| n_steps=4, |
| ) |
| # Case 4.2.* |
| case_4( |
| length=5, |
| cond=make_for_cond(length=5), |
| single_shape=[5, 12], |
| loop_var_shapes=[ |
| (1, ), # i |
| (5, 12), # s_0 |
| (5, 12), # s_1 |
| (23, 6, 8), # s_2 |
| ], |
| free_var_shapes=[ |
| (30, 5, 12), # sc_0 |
| (30, 5, 12), # sc_1 |
| (5, 12), # f_0 |
| (3, 4, 5, 6), # f_1, unused |
| ], |
| n_steps=5, |
| ) |
| # Case 5.1.* |
| case_5( |
| length=4, |
| cond=make_for_cond(length=4), |
| single_shape=[5], |
| loop_var_shapes=[ |
| (1, ), # i |
| (5, ), # s_0 |
| (5, ), # s_1 |
| (23, 6, 8), # s_2 |
| ], |
| free_var_shapes=[ |
| (30, 5), # sc_0 |
| (30, 5), # sc_1 |
| (5, ), # f_0 |
| (3, 4, 5, 6), # f_1, unused |
| ], |
| n_steps=4, |
| ) |
| # Case 5.2.* |
| case_5( |
| length=5, |
| cond=make_for_cond(length=5), |
| single_shape=[3, 4, 2], |
| loop_var_shapes=[ |
| (1, ), # i |
| (3, 4, 2), # s_0 |
| (3, 4, 2), # s_1 |
| (23, 6, 8), # s_2 |
| ], |
| free_var_shapes=[ |
| (30, 3, 4, 2), # sc_0 |
| (30, 3, 4, 2), # sc_1 |
| (3, 4, 2), # f_0 |
| (3, 4, 5, 6), # f_1, unused |
| ], |
| n_steps=5, |
| ) |
| # Case 6.* |
| case_6( |
| length=5, |
| cond=make_for_cond(length=5), |
| single_shape=[5, 3], |
| loop_var_shapes=[ |
| (1, ), # i |
| (5, 3), # s_0 |
| (5, 3), # s_1 |
| (3, 5), # s_2 |
| ], |
| free_var_shapes=[ |
| (30, 5, 3), # sc_0 |
| (30, 5, 3), # sc_1 |
| (5, 3), # f_0 |
| (3, 4, 5, 6), # f_1, unused |
| ], |
| n_steps=5, |
| ) |
| |
| |
| @with_seed() |
| def test_while_loop_nested(): |
| |
| def _to_np_list(arrays): |
| return [x.asnumpy() if x is not None else x for x in arrays] |
| |
| def _array(shape): |
| return mx.nd.random.uniform(-1.0, 1.0, shape=shape) |
| |
| def inner_cond(i, j, x_sum, sc): |
| return j < 2 |
| |
| def inner_body(i, j, x_sum, sc): |
| x_ij = sc.take(j).squeeze(axis=0) |
| return (x_ij, x_ij), (i, j + 1, x_sum, sc) |
| |
| def outer_cond(i, j, x_sum, sc): |
| return i < 2 |
| |
| def outer_body(i, j, x_sum, sc): |
| F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd |
| (x_ij, x_ji), (i_p, j_p, x_sum_p, sc_p) = F.contrib.while_loop( |
| cond=inner_cond, |
| func=inner_body, |
| loop_vars=(i, j, x_sum, sc), |
| max_iterations=2, |
| ) |
| return (x_ij, x_ji), (i_p + 1, j_p - 2, x_sum_p, sc_p) |
| |
| def make_loop(i, j, x_sum, sc): |
| F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd |
| (x_ij, x_ji), (new_i, new_j, x_sum_p, sc_p) = F.contrib.while_loop( |
| cond=outer_cond, |
| func=outer_body, |
| loop_vars=(i, j, x_sum, sc), |
| max_iterations=2, |
| ) |
| return new_i, new_j, x_sum_p, sc_p, x_ij, x_ji |
| |
| args = { |
| "i": mx.nd.array([0]), |
| "j": mx.nd.array([0]), |
| "x_sum": _array([5, 3]), |
| "sc": _array([10, 10, 5, 3]), |
| } |
| args_grad = { |
| "x_sum": _array([5, 3]), |
| "sc": _array([10, 10, 5, 3]), |
| } |
| out_grad = [ |
| _array([1]), |
| _array([1]), |
| _array([5, 3]), |
| _array([10, 10, 5, 3]), |
| _array([2, 2, 10, 5, 3]), |
| _array([2, 2, 10, 5, 3]), |
| ] |
| def _get_imp_result(is_train, args, args_grad, out_grad): |
| args = {k: v.copy() for k, v in args.items()} |
| args_grad = {k: v.copy() for k, v in args_grad.items()} |
| i, j, x_sum, sc = [args[x].copy() for x in ["i", "j", "x_sum", "sc"]] |
| if is_train: |
| x_sum.attach_grad() |
| sc.attach_grad() |
| with mx.autograd.record(train_mode=is_train): |
| results = make_loop(i, j, x_sum, sc) |
| cat_res = mx.nd.concat(*[x.reshape(-1) for x in results], dim=0) |
| if not is_train: |
| return _to_np_list(results), [] |
| cat_grad = mx.nd.concat(*[x.reshape(-1) for x in out_grad], dim=0) |
| assert cat_grad.shape == cat_res.shape |
| cat_res.backward(out_grad=cat_grad) |
| grads = [x_sum.grad, sc.grad] |
| return _to_np_list(results), _to_np_list(grads) |
| |
| def _get_sym_result(is_train, args, args_grad, out_grad): |
| args = {k: v.copy() for k, v in args.items()} |
| args_grad = {k: v.copy() for k, v in args_grad.items()} |
| i, j, x_sum, sc = [ |
| mx.sym.var("i"), |
| mx.sym.var("j"), |
| mx.sym.var("x_sum"), |
| mx.sym.var("sc"), |
| ] |
| result_sym = mx.sym.Group(make_loop(i, j, x_sum, sc)) |
| executor = result_sym.bind( |
| ctx=default_context(), |
| args=args, |
| args_grad=args_grad, |
| ) |
| results = executor.forward(is_train=is_train) |
| if not is_train: |
| return _to_np_list(results), [] |
| executor.backward(out_grads=out_grad) |
| grads = [executor.grad_dict["x_sum"], executor.grad_dict["sc"]] |
| return _to_np_list(results), _to_np_list(grads) |
| |
| for is_train in [True, False]: |
| imp_out, imp_grad = _get_imp_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad) |
| sym_out, sym_grad = _get_sym_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad) |
| assert len(imp_out) == len(sym_out) |
| assert len(imp_grad) == len(sym_grad) |
| for x, y in zip(imp_out, sym_out): |
| assert_almost_equal(x, y, rtol=1e-3, atol=1e-3) |
| for x, y in zip(imp_grad, sym_grad): |
| assert_almost_equal(x, y, rtol=1e-3, atol=1e-3) |
| |
| |
| @with_seed() |
| def test_while_loop_rnn(): |
| def _array(shape): |
| return mx.nd.random.uniform(-1.0, 1.0, shape=shape) |
| |
| cell_types = [mx.rnn.LSTMCell] |
| num_params = [2] |
| |
| batch_size = 2 |
| hidden_dim = 3 |
| input_dim = 4 |
| seq_len = 3 |
| |
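    # An RNN driven by while_loop is compared against an explicitly unrolled
    # RNN below: forward outputs and input gradients of both executors must
    # match over repeated runs.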
| for cell, n_param in zip(cell_types, num_params): |
| # using while_loop |
| params = mx.rnn.RNNParams() |
| data = mx.sym.var("data") |
| iter_i = mx.sym.var("i") |
| def _cond(*states): |
| i = states[0] |
| return i < seq_len |
| def _func(*states): |
| i = states[0] |
| states = states[1:] |
| in_ = data.take(i).squeeze(axis=0) |
| rnn = cell(hidden_dim, prefix='', params=params) |
| next_hidden, next_states = rnn(in_, states) |
| return [next_hidden], [i + 1] + list(next_states) |
| states = [mx.sym.var("s_" + str(i)) for i in range(n_param)] |
| result = mx.sym.contrib.while_loop( |
| cond=_cond, |
| func=_func, |
| loop_vars=[iter_i] + states, |
| max_iterations=seq_len |
| ) |
| result = mx.sym.Group(result[0] + result[1][1: ]) |
| arg_shapes, _, _ = result.infer_shape( |
| data=(seq_len, batch_size, input_dim), |
| s_0=(batch_size, hidden_dim), |
| ) |
| rnn_inputs = result.list_inputs() |
| args = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs) if name != "i"} |
| args["i"] = mx.nd.zeros([1]) |
| args_grad = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs)} |
| e_1 = result.bind(ctx=default_context(), |
| args={name: array.copy() for name, array in args.items()}, |
| args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"}, |
| ) |
| # using unrolled rnn |
| rnn = cell(hidden_dim, prefix='') |
| unroll_outs = [] |
| for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0, squeeze_axis=True): |
| h, states = rnn(inputs, states) |
| unroll_outs.append(mx.sym.expand_dims(h, axis=0)) |
| unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0)) |
| unroll_outs.extend(states) |
| result = mx.sym.Group(unroll_outs) |
| e_2 = result.bind(ctx=default_context(), |
| args={name: array.copy() for name, array in args.items() if name != "i"}, |
| args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"}, |
| ) |
| for case_id in range(100): |
| out_grads = [_array(arr.shape) for arr in e_1.outputs] |
| args = {name: array.copy() for name, array in args.items()} |
| e_1.forward(is_train=True, **args) |
| e_1.backward(out_grads) |
| args = {name: array.copy() for name, array in args.items() if name != "i"} |
| e_2.forward(is_train=True, **args) |
| e_2.backward(out_grads) |
| assert len(e_1.outputs) == len(e_2.outputs) |
| for x, y in zip(e_1.outputs, e_2.outputs): |
| x = x.asnumpy() |
| y = y.asnumpy() |
| assert_almost_equal(x, y, rtol=1e-3, atol=1e-3) |
| grad_keys = list(e_2.grad_dict.keys()) |
| e_1_grad = [e_1.grad_dict[x] for x in grad_keys] |
| e_2_grad = [e_2.grad_dict[x] for x in grad_keys] |
| for x, y in zip(e_1_grad, e_2_grad): |
| x = x.asnumpy() |
| y = y.asnumpy() |
| assert_almost_equal(x, y, rtol=1e-3, atol=1e-3) |
| |

def _verify_cond(cond_func, then_func, else_func, input_var_shapes, free_var_shapes, is_train):
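    """Compare the imperative (NDArray) and symbolic (Symbol) versions of
    contrib.cond on the same randomly initialized inputs; the outputs and
    the gradients w.r.t. input and free variables must match.
    """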
| |
| def _create_symbol(prefix, i): |
| return mx.sym.var(prefix + str(i)) |
| |
| def _create_array(shape): |
| return mx.nd.random.uniform(-1.0, 1.0, shape=shape) |
| |
| def _to_numpy_list(arrays): |
| return [x.asnumpy() if x is not None else x for x in arrays] |
| |
| def _merge_dict(*dicts): |
| result = {} |
| for item in dicts: |
| result.update(item) |
| return result |
| |
| _input_syms = [_create_symbol("InputVar", i) for i, _ in enumerate(input_var_shapes)] |
| _free_syms = [_create_symbol("FreeVar", i) for i, _ in enumerate(free_var_shapes)] |
| _input_vars = [_create_array(x) for x in input_var_shapes] |
| _free_vars = [_create_array(x) for x in free_var_shapes] |
| _args_dict = _merge_dict( |
| {"InputVar" + str(i): x for i, x in enumerate(_input_vars)}, |
| {"FreeVar" + str(i): x for i, x in enumerate(_free_vars)}, |
| ) |
| |
| def _get_imperative_result(): |
| free_vars = [x.copy() for x in _free_vars] |
| input_vars = [x.copy() for x in _input_vars] |
| out_grads = [] |
| if is_train: |
| for var in free_vars + input_vars: |
| var.attach_grad() |
| with mx.autograd.record(train_mode=is_train): |
| outputs = mx.nd.contrib.cond( |
| pred=cond_func(input_vars, free_vars), |
| then_func=lambda: then_func(input_vars, free_vars), |
| else_func=lambda: else_func(input_vars, free_vars), |
| ) |
| outputs = _as_list(outputs) |
| outputs = [x * 2 for x in outputs] |
| grads = [] |
| if is_train: |
| out_grads = [_create_array(x.shape) for x in outputs] |
| cat_out = mx.nd.concat(*[x.reshape(-1) for x in outputs], dim=0) |
| cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0)) |
| grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \ |
| + [input_vars[i].grad for i, _ in enumerate(input_var_shapes)] |
| return _to_numpy_list(outputs), _to_numpy_list(grads), out_grads |
| |
| def _get_symbolic_result(out_grads): |
| outputs_sym = mx.sym.contrib.cond( |
| pred=cond_func(_input_syms, _free_syms), |
| then_func=lambda: then_func(_input_syms, _free_syms), |
| else_func=lambda: else_func(_input_syms, _free_syms), |
| ) |
| outputs_sym = _as_list(outputs_sym) |
| outputs_sym = [x * 2 for x in outputs_sym] |
| outputs_sym = mx.sym.Group(outputs_sym) |
| executor = outputs_sym.bind( |
| ctx=default_context(), |
| args={name: _args_dict[name].copy() for name in outputs_sym.list_inputs()}, |
| args_grad=None if not is_train else _merge_dict( |
| {"InputVar" + str(i): mx.nd.zeros(s) for i, s in enumerate(input_var_shapes)}, |
| {"FreeVar" + str(i): mx.nd.zeros(s) for i, s in enumerate(free_var_shapes)}, |
| ), |
| ) |
| outputs = executor.forward(is_train=is_train) |
| grads = [] |
| if is_train: |
| executor.backward(out_grads=out_grads) |
| grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \ |
| + [executor.grad_dict.get("InputVar" + str(i), None) for i, _ in enumerate(input_var_shapes)] |
| return _to_numpy_list(outputs), _to_numpy_list(grads) |
| |
| imp_outs, imp_grads, out_grads = _get_imperative_result() |
| sym_outs, sym_grads = _get_symbolic_result(out_grads) |
| for imp_out, sym_out in zip(imp_outs, sym_outs): |
| if imp_out is None or sym_out is None: |
| continue |
| assert_almost_equal(imp_out, sym_out, rtol=1e-3, atol=1e-3) |
| for imp_grad, sym_grad in zip(imp_grads, sym_grads): |
| if imp_grad is None or sym_grad is None: |
| continue |
| assert_almost_equal(imp_grad, sym_grad, rtol=1e-3, atol=1e-3) |
| |
| |
| @with_seed() |
| def test_cond(): |
    # The test enumerates combinations along these dimensions:
    # * whether the three graphs (cond/then/else) contain free variables
    # * whether the three graphs contain input_vars
    # * whether all input_vars are used
    # * which branch is taken
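    # Rough sketch of contrib.cond semantics (an approximation; only the
    # taken branch is executed at runtime):
    #
    #   def cond(pred, then_func, else_func):
    #       return then_func() if pred else else_func()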
| def run_case(cond_func, then_func, else_func, **params): |
| def make_cond(is_inverse): |
| def cond(inputs, free): |
| x = cond_func(inputs, free) |
| if is_inverse: |
| if isinstance(x, mx.sym.Symbol): |
| return mx.sym.logical_not(x) |
| else: |
| return mx.nd.logical_not(x) |
| return x |
| return cond |
| for is_train in [True, False]: |
| for is_inverse in [False, True]: |
| _verify_cond( |
| cond_func=make_cond(is_inverse), |
| then_func=then_func, |
| else_func=else_func, |
| is_train=is_train, |
| **params |
| ) |
| # Each function can |
| # 1. use_free_vars or not: T/F |
| # 2. use_input_vars or not: T/F |
| # 3. use_all_input_vars or not: T/F |
| # (a, b, c) are inputs, (d, e, f) are free_vars |
| cond_funcs = [ |
| lambda a, b, c, d, e, f: (a * b).sum() < 0.5, # F, T, F |
| lambda a, b, c, d, e, f: (a + b + c).sum() < 0.5, # F, T, T |
| lambda a, b, c, d, e, f: (d + e).sum() < 0.5, # T, F, F |
| lambda a, b, c, d, e, f: (d + e * a).sum() < 0.5, # T, T, F |
| lambda a, b, c, d, e, f: (d + e * a + b * c).sum() < 0.5, # T, T, T |
| ] |
| body_funcs = [ |
| lambda a, b, c, d, e, f: a * b, # F, T, F |
| lambda a, b, c, d, e, f: a * b * c, # F, T, T |
| lambda a, b, c, d, e, f: d * e, # T, F, F |
| lambda a, b, c, d, e, f: d * e * a, # T, T, F |
| lambda a, b, c, d, e, f: d * e * a * b * c, # T, T, T |
| # some extra tests |
| lambda a, b, c, d, e, f: b * c, |
| lambda a, b, c, d, e, f: a * c, |
| lambda a, b, c, d, e, f: (a + b) * c, |
| lambda a, b, c, d, e, f: c * (b - a), |
| ] |
| # enumerate all kinds of possible combinations |
| for cond_func in cond_funcs: |
| for then_func in body_funcs: |
| for else_func in body_funcs: |
| run_case( |
| cond_func=lambda x, y: cond_func(x[0], x[1], x[2], y[0], y[1], y[2]), |
| then_func=lambda x, y: then_func(x[0], x[1], x[2], y[0], y[1], y[2]), |
| else_func=lambda x, y: else_func(x[0], x[1], x[2], y[0], y[1], y[2]), |
| input_var_shapes=[ |
| (2, 3), |
| (2, 3), |
| (2, 3), |
| ], |
| free_var_shapes=[ |
| (2, 3), |
| (2, 3), |
| (2, 3), |
| ] |
| ) |
| |

class TestRNNLayer(gluon.HybridBlock):
| def __init__(self, cell_type, hidden_size, prefix=None, params=None): |
| super(TestRNNLayer, self).__init__(prefix=prefix, params=params) |
| self.cell = cell_type(hidden_size, prefix='rnn_') |
| |
| def hybrid_forward(self, F, inputs, states): |
| out, states = F.contrib.foreach(self.cell, inputs, states) |
| return out |
| |

def check_contrib_rnn(cell_type, num_states):
| batch_size = 10 |
| hidden_size = 100 |
| rnn_data = mx.nd.normal(loc=0, scale=1, shape=(5, batch_size, 50)) |
| state_shape = (batch_size, hidden_size) |
| states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(num_states)] |
| layer = TestRNNLayer(cell_type, hidden_size) |
| layer.initialize(ctx=default_context()) |
| res1 = layer(rnn_data, states) |
| params1 = layer.collect_params() |
| orig_params1 = copy.deepcopy(params1) |
| |
| trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03}) |
| with mx.autograd.record(): |
| res1 = layer(rnn_data, states) |
| res1.backward() |
| trainer.step(batch_size) |
| |
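    # Re-run the layer under several hybridization configs; each exercises a
    # different cached-graph code path (inlining disabled via inline_limit=0,
    # static memory allocation, and static alloc + static shape).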
| configs = [ |
| {}, |
| {'inline_limit': 0}, |
| {'static_alloc': True}, |
| {'static_alloc': True, 'static_shape': True} ] |
| for config in configs: |
| layer = TestRNNLayer(cell_type, hidden_size) |
| layer.initialize(ctx=default_context()) |
| layer.hybridize(**config) |
| res2 = layer(rnn_data, states) |
| params2 = layer.collect_params() |
| for key, val in orig_params1.items(): |
| params2[key].set_data(copy.deepcopy(val.data())) |
| |
| trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03}) |
| with mx.autograd.record(): |
| res2 = layer(rnn_data, states) |
| assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=1e-3, atol=1e-3) |
| res2.backward() |
| trainer.step(batch_size) |
| |
| for key, val in params1.items(): |
| weight1 = val.data() |
| weight2 = params2[key].data() |
| assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(), |
| rtol=1e-3, atol=1e-3) |
| |
| |
| @with_seed() |
| def test_contrib_rnn(): |
| cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2), |
| (gluon.rnn.GRUCell, 1)] |
| for cell_type, num_states in cell_types: |
| check_contrib_rnn(cell_type, num_states) |
| |
| |
| @with_seed() |
| def test_foreach(): |
| v3 = mx.sym.var("v0") |
| v4 = mx.sym.var("v1") |
| v5 = mx.sym.var("v2") |
| v6 = mx.sym.var("v3") |
| v7 = mx.sym.var("v4") |
| v8 = mx.sym.var("v5") |
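    # Note: the Python names v3..v8 are unrelated to the symbol names
    # "v0".."v5"; verify_foreach keys its argument dicts by the symbol names.
    #
    # Rough sketch of contrib.foreach semantics (an approximation):
    #
    #   def foreach(step, data, states):
    #       outputs = []
    #       for d in data:                  # slice along axis 0
    #           out, states = step(d, states)
    #           outputs.append(out)
    #       return mx.nd.stack(*outputs), states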
| |
| def verify_foreach(step, in_syms, state_syms, free_syms, |
| in_arrs, init_states, frees, out_grads, is_train=True, |
| free_vars_func=None, num_iters=1): |
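        """Run foreach symbolically via a bound executor, then re-implement it
        imperatively under autograd; outputs and input gradients must match.

        out_grads is a pair: gradients for the per-step outputs and gradients
        for the final states.
        """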
| step_sym = lambda in_syms, state_syms : step(in_syms, state_syms, free_syms) |
| res, states = mx.sym.contrib.foreach(step_sym, in_syms, state_syms) |
| out = _as_list(res) |
| num_outputs = len(out) |
| for i in range(num_outputs): |
| out[i] = out[i] * 2 |
| out.extend(states) |
| out = mx.sym.Group(out) |
| js_1 = out.tojson() |
| out = mx.sym.load_json(js_1) |
| js_2 = out.tojson() |
| assert js_1 == js_2 |
| arr_grads = [] |
| arg_dict = {} |
| arg_grad_dict = {} |
| i = 0 |
| for arr in _as_list(in_arrs): |
| arr_grad = mx.nd.empty(arr.shape) |
| arr_grads.append(arr_grad) |
| arg_dict['v'+str(i)] = arr |
| arg_grad_dict['v'+str(i)] = arr_grad |
| i = i + 1 |
| for arr in init_states: |
| arr_grad = mx.nd.empty(arr.shape) |
| arr_grads.append(arr_grad) |
| arg_dict['v'+str(i)] = arr |
| arg_grad_dict['v'+str(i)] = arr_grad |
| i = i + 1 |
| for arr in frees: |
| arr_grad = mx.nd.empty(arr.shape) |
| arr_grads.append(arr_grad) |
| arg_dict['v'+str(i)] = arr |
| arg_grad_dict['v'+str(i)] = arr_grad |
| i = i + 1 |
| |
| if is_train: |
| e = out.bind(ctx=default_context(), args=arg_dict, args_grad=arg_grad_dict) |
| else: |
| e = out.bind(ctx=default_context(), args=arg_dict) |
        # The same inputs are used on every iteration, so repeated forward
        # and backward passes should always return the same outputs.
| for i in range(num_iters): |
| e.forward(is_train=is_train) |
            if is_train:
| # backward |
| tmp_grads = out_grads[0][:] |
| tmp_grads.extend(out_grads[1]) |
| e.backward(tmp_grads) |
| |
| # Below we use imperative to reimplement foreach and compute its gradients. |
| res = [] |
| for i in range(len(_as_list(out_grads[0]))): |
| res.append([]) |
| for arr in _as_list(in_arrs): |
| arr.attach_grad() |
| for arr in init_states: |
| arr.attach_grad() |
| for arr in frees: |
| arr.attach_grad() |
| with mx.autograd.record(): |
| frees_imp = frees if free_vars_func is None else free_vars_func(frees) |
| step_imp = lambda in_arrs, state_arrs : step(in_arrs, state_arrs, frees_imp) |
| states = [mx.nd.expand_dims(s, 0) for s in init_states] |
| res, states = mx.nd.contrib.foreach(step_imp, in_arrs, init_states) |
| |
| res2 = _as_list(res) |
| for i in range(len(res2)): |
| res2[i] = res2[i] * 2 |
| outs = [] |
| outs[:] = res2[:] |
| if isinstance(states, list): |
| outs.extend(states) |
| states = [mx.nd.expand_dims(s, 0) for s in states] |
| res2.extend(states) |
| else: |
| outs.append(states) |
| states = mx.nd.expand_dims(states, 0) |
| res2.append(states) |
        if is_train:
            res = mx.nd.concat(*res2, dim=0)
            tmp_grads = out_grads[0][:]
            tmp_grads1 = [mx.nd.expand_dims(grad, 0) for grad in out_grads[1]]
            tmp_grads.extend(tmp_grads1)
            res.backward(mx.nd.concat(*tmp_grads, dim=0))
| for i in range(len(outs)): |
| assert e.outputs[i].shape == outs[i].shape |
| assert_almost_equal(e.outputs[i].asnumpy(), outs[i].asnumpy(), |
| rtol=1e-3, atol=1e-3) |
        if is_train:
| all_ins = _as_list(in_arrs)[:] |
| all_ins.extend(init_states) |
| all_ins.extend(frees) |
| size = min(len(all_ins), len(e.grad_arrays)) |
| for i in range(size): |
| assert_almost_equal(all_ins[i].grad.asnumpy(), |
| e.grad_arrays[i].asnumpy(), |
| rtol=1e-3, atol=1e-3) |
| |
| # Test cases: |
| # * graph inputs are stored in different orders. |
| # This is to test if foreach finds the data arrays and weight arrays |
| # in the right location. |
| # * the number of iterations: odd or even. |
| # * multiple inputs and multiple outputs. |
| # * inference. |
| def step1(in1, states, free): |
| out = in1 * 2 + states[0] + free[0] |
| return (out, [out]) |
| frees1 = [mx.nd.arange(2), mx.nd.arange(2) + 1] |
| arrs = mx.nd.arange(6).reshape(shape=(3, 2)) |
| states = [mx.nd.arange(2)] |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape)]] |
| verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True, |
| lambda frees : [frees[0] + frees[1]]) |
| verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False, |
| lambda frees : [frees[0] + frees[1]]) |
| verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True, |
| lambda frees : [frees[0] + frees[1]], 5) |
| verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False, |
| lambda frees : [frees[0] + frees[1]], 5) |
| |
| # Test the even number of iterations. |
| frees = [mx.nd.random.uniform(shape=(2))] |
| arrs = mx.nd.random.uniform(shape=(2, 2)) |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape)]] |
| verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads) |
| verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False) |
| # Test the odd number of iterations |
| arrs = mx.nd.random.uniform(shape=(3, 2)) |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape)]] |
| verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads) |
| verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False) |
| |
| # Reorder the input and state in the subgraph inputs. |
| def step2(in1, states, free): |
| out = states[0] + in1 * 2 + free[0] |
| return (out, [out]) |
| # Test the even number of iterations. |
| arrs = mx.nd.random.uniform(shape=(2, 2)) |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape)]] |
| verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads) |
| verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False) |
| # Test the odd number of iterations. |
| arrs = mx.nd.random.uniform(shape=(3, 2)) |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape)]] |
| verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads) |
| verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False) |
| |
| # Test multiple inputs and outputs. |
| def step3(in1, states, free): |
| out = in1[0] + in1[1] * 2 + states[0] + states[1] * 2 + free[0] |
| return ([out, out], [out * 2, out * 3]) |
| arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))] |
| states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] |
| verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads) |
| verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False) |
| |
| # Test multiple inputs and outputs. |
| # The order of subgraph inputs doesn't match the operator inputs |
| def step4(in1, states, free): |
| out = in1[1] * 2 + states[0] + free[0] + states[1] * 2 + in1[0] |
| return ([out, out * 2], [out * 2, out * 3]) |
| arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))] |
| states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] |
| verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads) |
| verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False) |
| |
| # Test multiple inputs and outputs. |
| # The data inputs and states have different shapes. |
| def step5(in1, states, free): |
| if isinstance(in1[0], mx.nd.NDArray): |
| out1 = mx.nd.broadcast_add(states[0] + free[1], in1[1] * 2) |
| out2 = mx.nd.broadcast_add(in1[0], free[0] + states[1] * 2) |
| else: |
| out1 = mx.sym.broadcast_add(states[0] + free[1], in1[1] * 2) |
| out2 = mx.sym.broadcast_add(in1[0], free[0] + states[1] * 2) |
| return ([out1, out2 * 2], [states[0] * 2, states[1] * 3]) |
| frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2, 2))] |
| arrs = [mx.nd.random.uniform(shape=(3, 2, 2)), mx.nd.random.uniform(shape=(3, 2))] |
| states = [mx.nd.random.uniform(shape=(2, 2)), mx.nd.random.uniform(shape=(2))] |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] |
| verify_foreach(step5, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False) |
| |
| # Test multiple inputs and outputs. |
| # The data inputs and states have different shapes and data types. |
| def step6(in1, states, free): |
| if isinstance(in1[0], mx.nd.NDArray): |
| out1 = mx.nd.broadcast_add(states[0] + mx.nd.cast(free[1], 'float32'), |
| mx.nd.cast(in1[1], 'float32') * 2) |
| out2 = mx.nd.broadcast_add(in1[0], |
| free[0] + mx.nd.cast(states[1], 'float32') * 2) |
| else: |
| out1 = mx.sym.broadcast_add(states[0] + mx.sym.cast(free[1], 'float32'), |
| mx.sym.cast(in1[1], 'float32') * 2) |
| out2 = mx.sym.broadcast_add(in1[0], |
| free[0] + mx.sym.cast(states[1], 'float32') * 2) |
| return ([out1, out2 * 2], [states[0] * 2, states[1] * 3]) |
| frees = [mx.nd.random.uniform(shape=(2)), |
| mx.nd.cast(mx.nd.random.uniform(shape=(2, 2)), 'float64')] |
| arrs = [mx.nd.random.uniform(shape=(3, 2, 2)), |
| mx.nd.cast(mx.nd.random.uniform(shape=(3, 2)), dtype='float16')] |
| states = [mx.nd.random.uniform(shape=(2, 2)), |
| mx.nd.cast(mx.nd.random.uniform(shape=(2)), dtype='int32')] |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] |
| verify_foreach(step6, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False) |
| |
| # Test multiple inputs and outputs. |
| # some of the inputs are used twice. |
| def step7(in1, states, free): |
| out1 = states[0] + in1[0] + free[1] + in1[1] * 2 + free[0] |
| out2 = in1[0] + free[0] + states[1] * 2 + in1[1] |
| return ([out1, out2 * 2], [states[0] * 2, states[1] * 3]) |
| frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] |
| arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))] |
| states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] |
| verify_foreach(step7, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False) |
| |
| # Test the case that the output is the input. |
| arrs = mx.nd.random.uniform(shape=(3, 2)) |
| states = [mx.nd.arange(2)] |
| frees = [mx.nd.random.uniform(shape=(2))] |
| out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], |
| [mx.nd.random.uniform(-10, 10, states[0].shape)]] |
| def step8(in1, states, free): |
| return (in1, [states[0] * free[0]]) |
| verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads) |
| verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads, False) |
| def step9(in1, states, free): |
| return (in1 * free[0], states) |
| verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads) |
| verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads, False) |
| |
| # Test the case that not all inputs are used. |
| def step10(in1, states, free): |
| return (in1, states) |
| verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads) |
| verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads, False) |
| def step11(in1, states, free): |
| return (in1, free) |
| try: |
| verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads) |
| verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads, False) |
| except AssertionError: |
| print("the states have to be used") |
| def step12(in1, states, free): |
| return (in1, [states[0] + 1, states[0] + 2]) |
| states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] |
| frees = [] |
| try: |
| verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads) |
| verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads, False) |
| except AssertionError: |
| print("the states have to be used") |
| |
| # test without free variables. |
| def step13(in1, states, free): |
| return (in1, states) |
| states = [mx.nd.random.uniform(shape=(2))] |
| verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads) |
| verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads, False) |
| |
| # test when there isn't output data or output states. |
| def step14(in1, states, free): |
| return (in1 + free[0], []) |
| frees = [mx.nd.random.uniform(shape=(2))] |
| verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads) |
| verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads, False) |
| def step15(in1, states, free): |
| return ([], [in1 * states[0] * free[0]]) |
| out_grads = [[], [mx.nd.random.uniform(-10, 10, states[0].shape)]] |
| verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads) |
| verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads, False) |
| |
| # Test the case of iterating on a 1D data array. |
| def step16(in1, states, free): |
| return ([in1[0] * states[0]], [states[0] * 2]) |
| arrs = [mx.nd.arange(3)] |
| states = [mx.nd.random.uniform(shape=(1))] |
| out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))], |
| [mx.nd.random.uniform(-10, 10, (1))]] |
| verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads) |
| verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads, False) |
| def step17(in1, states, free): |
| return ([in1[1] * in1[0] * states[0]], [states[0] * 2]) |
| arrs = [mx.nd.random.uniform(shape=(3, 1)), mx.nd.arange(3)] |
| states = [mx.nd.random.uniform(shape=(1))] |
| out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))], |
| [mx.nd.random.uniform(-10, 10, (1))]] |
| verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads) |
| verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads, False) |
| |
| |
| @with_seed() |
| def test_foreach_nested(): |
| # Test nested foreach. |
| def step_in(in1, states): |
| out = in1 * 2 + states[0] |
| return (out, [out]) |
| |
| def step_sym(in1, states): |
| out1 = mx.sym.contrib.foreach(step_in, in1, states) |
| out = mx.sym.broadcast_add(out1[0], states[0]) |
| return (out, [mx.sym.squeeze(mx.sym.slice(out, begin=(0, 0), end=(1, 2)))]) |
| def step_nd(in1, states): |
| out1 = mx.nd.contrib.foreach(step_in, in1, states) |
| out = mx.nd.broadcast_add(out1[0], states[0]) |
| return (out, [mx.nd.squeeze(mx.nd.slice(out, begin=(0, 0), end=(1, 2)))]) |
| |
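    # A nested foreach roughly corresponds to a double Python loop over
    # axis 0, with the state threaded through both levels:
    #
    #   for x in data:            # outer foreach (step_sym / step_nd)
    #       for x_in in x:        # inner foreach (step_in)
    #           out = x_in * 2 + state
    #           state = out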
| data_sym = mx.sym.var("v1") |
| state_sym = mx.sym.var("v2") |
| out, states = mx.sym.contrib.foreach(step_sym, data_sym, [state_sym]) |
| assert isinstance(states, list) |
| assert len(states) == 1 |
| out = mx.sym.broadcast_add(out, states[0]) |
| |
| js_1 = out.tojson() |
| out = mx.sym.load_json(js_1) |
| js_2 = out.tojson() |
| assert js_1 == js_2 |
| |
| data = mx.nd.arange(8).reshape((2, 2, 2)) |
| state = mx.nd.arange(2) |
| data_grad = mx.nd.empty(data.shape) |
| state_grad = mx.nd.empty(state.shape) |
| e = out.bind(ctx=default_context(), args={'v1':data, 'v2':state}, |
| args_grad={'v1':data_grad, 'v2':state_grad}) |
| e.forward(is_train=True) |
| out_grads = [] |
| for out in e.outputs: |
| out_grads.append(mx.nd.random.uniform(shape=out.shape)) |
| e.backward(out_grads) |
| |
| data.attach_grad() |
| state.attach_grad() |
| with mx.autograd.record(): |
| out, states = mx.nd.contrib.foreach(step_nd, data, [state]) |
| assert isinstance(states, list) |
| assert len(states) == 1 |
| res = mx.nd.broadcast_add(out, states[0]) |
| assert_almost_equal(res.asnumpy(), e.outputs[0].asnumpy(), rtol=1e-3, atol=1e-3) |
| |
| res.backward(out_grads[0]) |
| assert_almost_equal(data.grad.asnumpy(), data_grad.asnumpy(), rtol=1e-3, atol=1e-3) |
| assert_almost_equal(state.grad.asnumpy(), state_grad.asnumpy(), rtol=1e-3, atol=1e-3) |
| |
| |
| def check_foreach_rnn(cell_type, num_states): |
| data = mx.sym.var("data") |
| params = mx.rnn.RNNParams() |
| hidden_dim = 4 |
| input_dim = 5 |
| seq_len = 2 |
| batch_size = 2 |
| |
    # The step function applies an RNN cell to one time-step slice of the input.
| def step(in1, states): |
| rnn = cell_type(hidden_dim, prefix='', params=params) |
| next_h, states = rnn(in1, states) |
| return (next_h, states) |
| |
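    # Group the foreach outputs and the final states into one symbol so a
    # single executor can bind and compare all of them at once.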
| def sym_group(out): |
        if isinstance(out[0], mx.sym.Symbol):
| ret = [out[0]] |
| else: |
| ret = out[0] |
| ret.extend(out[1]) |
| return mx.sym.Group(ret) |
| |
| rnn = cell_type(hidden_dim, prefix='', params=params) |
| if num_states == 2: |
| init_states = [mx.sym.var("h"), mx.sym.var("c")] |
| else: |
| init_states = [mx.sym.var("h")] |
| out = mx.sym.contrib.foreach(step, data, init_states) |
| out = sym_group(out) |
| arg_shapes, out_shapes, aux_shapes = out.infer_shape(data=(seq_len, batch_size, input_dim), |
| h=(batch_size, hidden_dim)) |
| rnn_inputs = out.list_inputs() |
| |
| # Inputs |
| args1 = {name:mx.nd.random.uniform(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)} |
| args2 = copy.deepcopy(args1) |
| # gradients for the backward of the foreach symbol |
| args_grad1 = {name:mx.nd.empty(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)} |
| # gradients for the backward of the unrolled symbol. |
| args_grad2 = {name:mx.nd.empty(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)} |
| |
    # Symbol that runs the RNN cell with foreach.
| out = mx.sym.contrib.foreach(step, data, init_states) |
| out = sym_group(out) |
| js_1 = out.tojson() |
| out = mx.sym.load_json(js_1) |
| js_2 = out.tojson() |
| assert js_1 == js_2 |
| e1 = out.bind(ctx=default_context(), args=args1, args_grad=args_grad1) |
| |
    # Symbol that runs the manually unrolled RNN cell.
    cell = cell_type(hidden_dim, prefix='')
    unroll_outs = []
    states = init_states
    for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0, squeeze_axis=True):
        h, states = cell(inputs, states)
| unroll_outs.append(mx.sym.expand_dims(h, axis=0)) |
| unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0)) |
| unroll_outs.extend(states) |
| out = mx.sym.Group(unroll_outs) |
| js_1 = out.tojson() |
| out = mx.sym.load_json(js_1) |
| js_2 = out.tojson() |
| assert js_1 == js_2 |
| e2 = out.bind(ctx=default_context(), args=args2, args_grad=args_grad2) |
| |
    for _ in range(5):
| out_grads = [] |
| for arr in e1.outputs: |
| out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape)) |
| |
| args = {name:mx.nd.random.uniform(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)} |
| |
| e1.forward(is_train=True, **args) |
| outputs1 = e1.outputs |
| e1.backward(out_grads) |
| |
| e2.forward(is_train=True, **args) |
| outputs2 = e2.outputs |
| e2.backward(out_grads) |
| |
| for i in range(len(outputs2)): |
| assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(), |
| rtol=1e-3, atol=1e-3) |
| input_names = out.list_inputs() |
| for i in range(len(e1.grad_arrays)): |
| name = input_names[i] |
| assert_almost_equal(args_grad1[name].asnumpy(), args_grad2[name].asnumpy(), |
| rtol=1e-3, atol=1e-3) |
| |
| |
| @with_seed() |
| def test_foreach_rnn(): |
| cell_types = [(mx.rnn.LSTMCell, 2), (mx.rnn.RNNCell, 1), (mx.rnn.GRUCell, 1)] |
| for cell_type, num_states in cell_types: |
| check_foreach_rnn(cell_type, num_states) |
| |
| |
| @with_seed() |
| def test_cut_subgraph_foreach(): |
| class TestLayer(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(TestLayer, self).__init__(prefix=prefix, params=params) |
| |
| def hybrid_forward(self, F, inputs, states): |
| def step1(data, states): |
| return data + 1, states |
| out1, states1 = F.contrib.foreach(step1, inputs, states) |
| out2, states2 = F.contrib.foreach(step1, out1, states) |
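            # step2 below closes over states1, an output of the first foreach,
            # so states1 is fed into the last foreach as a free variable; this
            # is the subgraph-cutting behavior under test.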
| def step2(data, states): |
| return data + states[0], states1 |
| out, states = F.contrib.foreach(step2, out2, states) |
| return out |
| |
| data = mx.nd.normal(loc=0, scale=1, shape=(5, 10)) |
    states = mx.nd.normal(loc=0, scale=1, shape=(10,))
| layer = TestLayer() |
| layer.initialize(ctx=default_context()) |
| res1 = layer(data, [states]) |
| |
| with mx.autograd.record(): |
| res1 = layer(data, [states]) |
| |
| layer = TestLayer() |
| layer.initialize(ctx=default_context()) |
| layer.hybridize() |
| res2 = layer(data, [states]) |
| |
| with mx.autograd.record(): |
| res2 = layer(data, [states]) |
| assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=1e-3, atol=1e-3) |
| |
| |
| @with_seed() |
| def test_uniq_name(): |
| class ForeachLayer1(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(ForeachLayer1, self).__init__(prefix=prefix, params=params) |
| |
| def hybrid_forward(self, F, inputs, states): |
| def step1(data, states): |
| return data + 1, states |
| out1, states1 = F.contrib.foreach(step1, inputs, states) |
| # The input variables have the same symbol name. |
| out, states = F.contrib.foreach(step1, out1, states1) |
| return out |
| |
| class ForeachLayer2(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(ForeachLayer2, self).__init__(prefix=prefix, params=params) |
| |
| def hybrid_forward(self, F, inputs, states): |
| def step1(data, states): |
| return data + 1, states |
| out1, states1 = F.contrib.foreach(step1, inputs, states) |
| def step2(data, states): |
| return data, [states[0] + states1[0] + F.squeeze(out1.slice_axis(axis=0, begin=0, end=1))] |
| # The input variables have the same symbol names. |
| # The free variables have the same symbol names as the input variables. |
| out, states = F.contrib.foreach(step2, out1, states1) |
| return out |
| |
| class WhileLayer1(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(WhileLayer1, self).__init__(prefix=prefix, params=params) |
| |
| def hybrid_forward(self, F, inputs, states): |
| def cond(state1, state2): |
| s = F.squeeze(state1.slice_axis(axis=0, begin=0, end=1)) |
| return s == s |
| def step(state1, state2): |
| return state1 + 1, [state1, state2] |
| states = [states[0], states[0] + 1] |
| out1, states1 = F.contrib.while_loop(cond, step, states, max_iterations=5) |
| # The input variables have the same symbol name. |
| out, states = F.contrib.while_loop(cond, step, states1, max_iterations=5) |
| return out |
| |
| class WhileLayer2(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(WhileLayer2, self).__init__(prefix=prefix, params=params) |
| |
| def hybrid_forward(self, F, inputs, states): |
| def cond(state1, state2): |
| s = F.squeeze(state1.slice_axis(axis=0, begin=0, end=1)) |
| return s == s |
| def step1(state1, state2): |
| return state1 + 1, [state1, state2] |
| states = [states[0], states[0] + 1] |
| out1, states1 = F.contrib.while_loop(cond, step1, states, max_iterations=5) |
| def step2(state1, state2): |
| return state1 + 1, [state1 + states1[0], state2 + states1[1]] |
| # The input variables have the same symbol name. |
| out, states = F.contrib.while_loop(cond, step2, states1, max_iterations=5) |
| return out |
| |
| TestLayers = [ForeachLayer1, ForeachLayer2, |
| WhileLayer1, WhileLayer2] |
| |
| data = mx.nd.normal(loc=0, scale=1, shape=(2, 5)) |
    states = mx.nd.normal(loc=0, scale=1, shape=(5,))
| for TestLayer in TestLayers: |
| layer = TestLayer() |
| layer.initialize(ctx=default_context()) |
| res1 = layer(data, [states]) |
| |
| with mx.autograd.record(): |
| res1 = layer(data, [states]) |
| |
| layer = TestLayer() |
| layer.initialize(ctx=default_context()) |
| layer.hybridize() |
| res2 = layer(data, [states]) |
| |
| with mx.autograd.record(): |
| res2 = layer(data, [states]) |
| assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001) |
| |
| |
| @with_seed() |
| def test_cut_subgraph_while_loop(): |
| class TestLayer(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(TestLayer, self).__init__(prefix=prefix, params=params) |
| def hybrid_forward(self, F, data): |
| out1, data1 = F.contrib.while_loop( |
| cond=lambda i: i <= 5, |
| func=lambda i: (None, (i + 1, )), |
| loop_vars=(data, ), |
| max_iterations=10, |
| ) |
| out2, data2 = F.contrib.while_loop( |
| cond=lambda i: data1[0], |
| func=lambda i: (None, (i + 1, )), |
| loop_vars=data1[0], |
| max_iterations=10, |
| ) |
| return data2[0] |
| data = mx.nd.normal(loc=0, scale=1, shape=(1, )) |
| layer = TestLayer() |
| layer.initialize(ctx=default_context()) |
| res1 = layer(data) |
| with mx.autograd.record(): |
| res1 = layer(data) |
| layer = TestLayer() |
| layer.initialize(ctx=default_context()) |
| layer.hybridize() |
| res2 = layer(data) |
| with mx.autograd.record(): |
| res2 = layer(data) |
| assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=1e-3, atol=1e-3) |
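

# For intuition, a hedged Python-level sketch of the while_loop semantics used
# above: run func while cond holds, up to max_iterations, threading loop_vars
# through each step. Illustrative only and not called by the tests; the real
# operator additionally stacks the per-step outputs and pads them up to
# max_iterations.
def _while_loop_reference(cond, func, loop_vars, max_iterations):
    step_outputs = []
    steps = 0
    while steps < max_iterations and cond(*loop_vars).asscalar():
        out, loop_vars = func(*loop_vars)
        step_outputs.append(out)
        steps += 1
    return step_outputs, loop_vars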
| |
| |
| @with_seed() |
| def test_cut_subgraph_cond(): |
| class TestLayer(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(TestLayer, self).__init__(prefix=prefix, params=params) |
| def hybrid_forward(self, F, data): |
| data1 = F.contrib.cond( |
| data > 0.5, |
| then_func=lambda: data * 2, |
| else_func=lambda: data * 3, |
| ) |
| data2 = F.contrib.cond( |
| data1 > 0.5, |
| then_func=lambda: data1 * 2, |
| else_func=lambda: data1 * 3, |
| ) |
| return data2 |
| data = mx.nd.normal(loc=0, scale=1, shape=(1, )) |
| layer = TestLayer() |
| layer.initialize(ctx=default_context()) |
| res1 = layer(data) |
| with mx.autograd.record(): |
| res1 = layer(data) |
| layer = TestLayer() |
| layer.initialize(ctx=default_context()) |
| layer.hybridize() |
| res2 = layer(data) |
| with mx.autograd.record(): |
| res2 = layer(data) |
| assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=1e-3, atol=1e-3) |
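

# Likewise, a hedged eager-mode sketch of contrib.cond: evaluate the scalar
# predicate, then run exactly one branch. Illustrative only; the symbolic
# version builds both branch subgraphs and selects between them at runtime.
def _cond_reference(pred, then_func, else_func):
    return then_func() if pred.asscalar() else else_func()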
| |
| |
| def test_scope(): |
| class TestBlock1(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(TestBlock1, self).__init__(prefix=prefix, params=params) |
| def hybrid_forward(self, F, data): |
| (new_data, ) = F.contrib.cond( |
| data > 0.5, |
| then_func=lambda: data * 2, |
| else_func=lambda: data * 3, |
| name="my_cond", |
| ) |
| return new_data |
| class TestBlock2(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(TestBlock2, self).__init__(prefix=prefix, params=params) |
| def hybrid_forward(self, F, data): |
| (new_data, ) = F.contrib.cond( |
| data > 0.5, |
| then_func=lambda: data * 2, |
| else_func=lambda: data * 3, |
| name="my_cond", |
| ) |
| return new_data |
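    # Reset the global subgraph-name counter so the counts asserted below are
    # deterministic, regardless of which tests ran earlier.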
| AttrScope._subgraph_names = defaultdict(int) |
| data = mx.nd.normal(loc=0, scale=1, shape=(1, )) |
| block1 = TestBlock1() |
| block1.initialize(ctx=default_context()) |
| block1.hybridize() |
| _ = block1(data) |
| block2 = TestBlock2() |
| block2.initialize(ctx=default_context()) |
| block2.hybridize() |
| _ = block2(data) |
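    # Each cond creates three subgraphs (pred, then, else). Both blocks reuse
    # the name "my_cond", so each of the three counters should reach 2.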
| assert len(AttrScope._subgraph_names) == 3 |
| assert AttrScope._subgraph_names['my_cond_else'] == 2 |
| assert AttrScope._subgraph_names['my_cond_pred'] == 2 |
| assert AttrScope._subgraph_names['my_cond_then'] == 2 |
| |
| |
| def test_output_format_foreach(): |
| class TestLayer1(gluon.HybridBlock): |
| def __init__(self, step, prefix=None, params=None): |
| super(TestLayer1, self).__init__(prefix=prefix, params=params) |
| self.step = step |
| def hybrid_forward(self, F, ins, states): |
| out, states = F.contrib.foreach(self.step, ins, states) |
| return out, states |
| |
| def step1(data, state): |
| return data, state |
| def step2(data, state): |
| return [data], state |
| def step3(data, state): |
| if isinstance(state, list): |
| return [], [state[0] + data] |
| else: |
| return [], state + data |
| def step4(data, state): |
| if isinstance(state, list): |
| return [data, state[0]], state |
| else: |
| return [data, state], state |
| |
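    # foreach is expected to preserve the Python structure (NDArray vs. list)
    # returned by the step function; each step above exercises a different
    # output/state structure.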
| steps = [step1, step2, step3, step4] |
| data = mx.nd.normal(loc=0, scale=1, shape=(10, 2)) |
    state = mx.nd.normal(loc=0, scale=1, shape=(2,))
| for step in steps: |
| layer1 = TestLayer1(step) |
| layer1.initialize(ctx=default_context()) |
| layer2 = TestLayer1(step) |
| layer2.initialize(ctx=default_context()) |
| layer2.hybridize() |
| out1, state1 = layer1(data, [state]) |
| out2, state2 = layer2(data, [state]) |
| step_out, step_state = step(data, [state]) |
| assert type(out1) == type(step_out) |
| assert type(out2) == type(step_out) |
| assert type(state1) == type(step_state) |
| assert type(state2) == type(step_state) |
| out1 = _as_list(out1) |
| out2 = _as_list(out2) |
| state1 = _as_list(state1) |
| state2 = _as_list(state2) |
| for i in range(len(out1)): |
| assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| for i in range(len(state1)): |
| assert_almost_equal(state1[i].asnumpy(), state2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| |
| layer1 = TestLayer1(step) |
| layer1.initialize(ctx=default_context()) |
| layer2 = TestLayer1(step) |
| layer2.initialize(ctx=default_context()) |
| layer2.hybridize() |
| out1, state1 = layer1(data, state) |
| out2, state2 = layer2(data, state) |
| step_out, step_state = step(data, state) |
| assert type(out1) == type(step_out) |
| assert type(out2) == type(step_out) |
| assert type(state1) == type(step_state) |
| assert type(state2) == type(step_state) |
| out1 = _as_list(out1) |
| out2 = _as_list(out2) |
| state1 = _as_list(state1) |
| state2 = _as_list(state2) |
| for i in range(len(out1)): |
| assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| for i in range(len(state1)): |
| assert_almost_equal(state1[i].asnumpy(), state2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| |
| if step == step3: |
| continue |
| layer1 = TestLayer1(step) |
| layer1.initialize(ctx=default_context()) |
| layer2 = TestLayer1(step) |
| layer2.initialize(ctx=default_context()) |
| layer2.hybridize() |
| out1, state1 = layer1(data, [state, [state + 1]]) |
| out2, state2 = layer2(data, [state, [state + 1]]) |
| step_out, step_state = step(data, [state, [state + 1]]) |
| assert type(out1) == type(step_out) |
| assert type(out2) == type(step_out) |
| assert type(state1) == type(step_state) |
| assert type(state2) == type(step_state) |
| out1 = _as_list(out1) |
| out2 = _as_list(out2) |
| state1 = _as_list(state1) |
| state2 = _as_list(state2) |
| for i in range(len(out1)): |
| assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| for i in range(len(state1)): |
| if isinstance(state1[i], list): |
| assert_almost_equal(state1[i][0].asnumpy(), state2[i][0].asnumpy(), |
| rtol=0.001, atol=0.0001) |
| else: |
| assert_almost_equal(state1[i].asnumpy(), state2[i].asnumpy(), |
| rtol=0.001, atol=0.0001) |
| |
| |
| def test_output_format_while(): |
| class TestLayer1(gluon.HybridBlock): |
| def __init__(self, step, use_list, nested_list=False, prefix=None, params=None): |
| super(TestLayer1, self).__init__(prefix=prefix, params=params) |
| self.step = step |
| self.use_list = use_list |
| self.nested_list = nested_list |
| def hybrid_forward(self, F, states): |
| def cond(state1): |
| scalar = state1.slice_axis(axis=0, begin=0, end=1) |
| return scalar == scalar |
| cond_func = cond |
| if self.use_list: |
| states = [states] |
| elif self.nested_list: |
| def cond2(state1, state2): |
| scalar = state1.slice_axis(axis=0, begin=0, end=1) |
| return scalar == scalar |
| cond_func = cond2 |
| states = [states, [states + 1]] |
| out, states = F.contrib.while_loop(cond_func, self.step, states, max_iterations=5) |
| return out, states |
| |
| def step1(state): |
| return state, state |
| def step2(state): |
| if isinstance(state, list): |
| return state, state |
| else: |
| return [state], state |
| def step3(state): |
| return [], state |
| |
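    # while_loop should likewise preserve the structure of the step outputs
    # and states, with and without hybridization.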
| steps = [step1, step2, step3] |
    state = mx.nd.normal(loc=0, scale=1, shape=(2,))
| for step in steps: |
| layer1 = TestLayer1(step, False) |
| layer1.initialize(ctx=default_context()) |
| layer2 = TestLayer1(step, False) |
| layer2.initialize(ctx=default_context()) |
| layer2.hybridize() |
| out1, state1 = layer1(state) |
| out2, state2 = layer2(state) |
| assert type(out1) == type(out2) |
        assert type(state1) == type(state2)
| out1 = _as_list(out1) |
| out2 = _as_list(out2) |
| state1 = _as_list(state1) |
| state2 = _as_list(state2) |
| for i in range(len(out1)): |
| assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| for i in range(len(state1)): |
| assert_almost_equal(state1[i].asnumpy(), state2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| |
| layer1 = TestLayer1(step, True) |
| layer1.initialize(ctx=default_context()) |
| layer2 = TestLayer1(step, True) |
| layer2.initialize(ctx=default_context()) |
| layer2.hybridize() |
| out1, state1 = layer1(state) |
| out2, state2 = layer2(state) |
| assert type(out1) == type(out2) |
| assert type(state1) == type(state2) |
| out1 = _as_list(out1) |
| out2 = _as_list(out2) |
| state1 = _as_list(state1) |
| state2 = _as_list(state2) |
| for i in range(len(out1)): |
| assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| for i in range(len(state1)): |
| assert_almost_equal(state1[i].asnumpy(), state2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| |
| def step4(state, state2): |
| states = _as_list(state) |
| states.append(state2) |
| return state, states |
| def step5(state, state2): |
| states = _as_list(state) |
| states.append(state2) |
| if isinstance(state, list): |
| return state, states |
| else: |
| return [state], states |
| def step6(state, state2): |
| states = _as_list(state) |
| states.append(state2) |
| return [], states |
| |
| steps = [step4, step5, step6] |
| for step in steps: |
| layer1 = TestLayer1(step, False, True) |
| layer1.initialize(ctx=default_context()) |
| layer2 = TestLayer1(step, False, True) |
| layer2.initialize(ctx=default_context()) |
| layer2.hybridize() |
| out1, state1 = layer1(state) |
| out2, state2 = layer2(state) |
| assert type(out1) == type(out2) |
| assert type(state1) == type(state2) |
| out1 = _as_list(out1) |
| out2 = _as_list(out2) |
| state1 = _as_list(state1) |
| state2 = _as_list(state2) |
| for i in range(len(out1)): |
| assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| for i in range(len(state1)): |
| if not isinstance(state1[i], list): |
| assert_almost_equal(state1[i].asnumpy(), state2[i].asnumpy(), |
| rtol=0.001, atol=0.0001) |
| |
| |
| def test_output_format_cond(): |
| class TestLayer1(gluon.HybridBlock): |
| def __init__(self, func, prefix=None, params=None): |
| super(TestLayer1, self).__init__(prefix=prefix, params=params) |
| self.func = func |
| def hybrid_forward(self, F, data): |
| def then_func(): |
| return self.func(data) |
| def else_func(): |
| return self.func(data) |
| return F.contrib.cond(data.slice_axis(axis=0, begin=0, end=1), |
| then_func, else_func) |
| |
| def func1(data): |
| return data |
| def func2(data): |
| return [data] |
| def func3(data): |
| return [data, data] |
| |
| funcs = [func1, func2, func3] |
    data = mx.nd.normal(loc=0, scale=1, shape=(2,))
| for func in funcs: |
| layer1 = TestLayer1(func) |
| layer1.initialize(ctx=default_context()) |
| layer2 = TestLayer1(func) |
| layer2.initialize(ctx=default_context()) |
| layer2.hybridize() |
| out1 = layer1(data) |
| out2 = layer2(data) |
| func_out = func(data) |
| assert type(out1) == type(func_out) |
| assert type(out2) == type(func_out) |
| out1 = _as_list(out1) |
| out2 = _as_list(out2) |
| for i in range(len(out1)): |
| assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001) |
| |

def test_foreach_with_unknown_dim():
    # MXNet uses 0 as a placeholder for an unknown dimension in a shape.
    step = lambda data, states: (data + states[0], [states[0] * 2])
    # Input shape in NCHW format, where N (the batch size) is unknown.
    data = mx.sym.var('data', shape=(0, 3, 32, 32))
| states = [mx.sym.var('state')] |
| outs, states = mx.sym.contrib.foreach(step, data, states) |
| _, output_shape, _ = outs.infer_shape_partial() |
| assert_allclose((0, 3, 32, 32), output_shape[0]) |
| |

if __name__ == '__main__':
| import nose |
| nose.runmodule() |