| # coding: utf-8 |
| """Autograd for NDArray.""" |
| from __future__ import absolute_import |
| from __future__ import division |
| |
| import ctypes |
| from .base import _LIB, check_call, string_types |
| from .base import mx_uint, NDArrayHandle, c_array |
| from .ndarray import NDArray |
| from .symbol import _GRAD_REQ_MAP |
| |
| |
| def set_is_training(is_train): |
| """Set status to training/not training. When training, graph will be constructed |
| for gradient computation. Operators will also run with ctx.is_train=True. For example, |
| Dropout will drop inputs randomly when is_train=True while simply passing through |
| if is_train=False. |
| |
| Parameters |
| ---------- |
| is_train: bool |
| |
| Returns |
| ------- |
| previous state before this set. |
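
    Example::
        # A minimal sketch; `model` and `x` stand in for any callable and input NDArray.
        prev = set_is_training(True)
        y = model(x)              # e.g. Dropout now drops inputs randomly
        set_is_training(prev)     # restore the previous state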
| """ |
| prev = ctypes.c_int() |
| check_call(_LIB.MXAutogradSetIsTraining( |
| ctypes.c_int(is_train), ctypes.byref(prev))) |
| return bool(prev.value) |
| |
| |
| class TrainingStateScope(object): |
| """Scope for managing training state. |
| |
| Example:: |
| with TrainingStateScope(True): |
| y = model(x) |
| compute_gradient([y]) |
| """ |
| def __init__(self, enter_state): |
| self._enter_state = enter_state |
| self._prev = None |
| |
| def __enter__(self): |
| self._prev = set_is_training(self._enter_state) |
| |
| def __exit__(self, ptype, value, trace): |
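        # Only restore the previous state if entering the scope actually changed it.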
| if self._prev != self._enter_state: |
| set_is_training(self._prev) |
| |
| |
| def train_section(): |
| """Returns a training scope context to be used in 'with' statement |
| and captures training code. |
| |
| Example:: |
| with autograd.train_section(): |
| y = model(x) |
| compute_gradient([y]) |
| metric.update(...) |
| optim.step(...) |
| """ |
| return TrainingStateScope(True) |
| |
| |
| def test_section(): |
| """Returns a testing scope context to be used in 'with' statement |
| and captures testing code. |
| |
| Example:: |
| with autograd.train_section(): |
| y = model(x) |
| compute_gradient([y]) |
| with autograd.test_section(): |
| # testing, IO, gradient updates... |
| """ |
| return TrainingStateScope(False) |
| |
| |
| def mark_variables(variables, gradients, grad_reqs='write'): |
| """Mark NDArrays as variables to compute gradient for autograd. |
| |
| Parameters |
| ---------- |
    variables: NDArray or list of NDArray
        Arrays to compute gradients for.
    gradients: NDArray or list of NDArray
        Buffers that will receive the gradients of the corresponding variables.
    grad_reqs: str or list of str, default 'write'
        How each gradient is stored: 'write' to overwrite, 'add' to accumulate,
        or 'null' to skip the gradient.
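
    Example::
        # A minimal sketch, assuming mxnet is imported as mx:
        x = mx.nd.ones((2, 2))
        dx = mx.nd.zeros(x.shape)    # buffer that will hold the gradient of x
        mark_variables(x, dx, grad_reqs='write')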
| """ |
| if isinstance(variables, NDArray): |
| assert isinstance(gradients, NDArray) |
| variables = [variables] |
| gradients = [gradients] |
| |
| variable_handles = [] |
| gradient_handles = [] |
| for var, gradvar in zip(variables, gradients): |
| variable_handles.append(var.handle) |
| gradient_handles.append(gradvar.handle) |
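    # A single grad_req string applies to all variables; otherwise map each entry.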
| if isinstance(grad_reqs, string_types): |
| grad_reqs = [_GRAD_REQ_MAP[grad_reqs]]*len(variables) |
| else: |
| grad_reqs = [_GRAD_REQ_MAP[i] for i in grad_reqs] |
| |
| check_call(_LIB.MXAutogradMarkVariables( |
| len(variable_handles), |
| c_array(NDArrayHandle, variable_handles), |
| c_array(mx_uint, grad_reqs), |
| c_array(NDArrayHandle, gradient_handles))) |
| |
| |
| def backward(heads, head_grads=None, retain_graph=False): |
| """Compute the gradients of heads w.r.t previously marked variables. |
| |
| Parameters |
| ---------- |
| heads: NDArray or list of NDArray |
| Output NDArray(s) |
    head_grads: NDArray or list of NDArray or None
        Gradients with respect to heads. If None, gradients default to ones of
        the same shape as the corresponding heads.
    retain_graph: bool, optional
        Whether to keep the computation graph so that backward can be called
        again on the same graph. Defaults to False.
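
    Example::
        # A minimal sketch; assumes x was marked via mark_variables with a
        # gradient buffer dx, and that the forward pass runs in train_section().
        with train_section():
            y = x * 2
        backward(y)    # the gradient of y w.r.t. x is written into dx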
| """ |
| if isinstance(heads, NDArray): |
| assert head_grads is None or isinstance(head_grads, NDArray) |
| heads = [heads] |
| head_grads = [head_grads] if head_grads is not None else None |
| |
| output_handles = [] |
| for arr in heads: |
| output_handles.append(arr.handle) |
| |
| if head_grads is None: |
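        # No explicit head gradients were given: pass a null pointer so the
        # backend uses default head gradients (ones).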
| check_call(_LIB.MXAutogradBackward( |
| len(output_handles), |
| c_array(NDArrayHandle, output_handles), |
| ctypes.c_void_p(0), |
| ctypes.c_int(retain_graph))) |
| return |
| |
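    # Collect head-gradient handles; None entries become null handles so the
    # backend falls back to the default gradient for that head.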
| ograd_handles = [] |
| for arr in head_grads: |
| if arr is not None: |
| ograd_handles.append(arr.handle) |
| else: |
| ograd_handles.append(NDArrayHandle(0)) |
| assert len(ograd_handles) == len(output_handles), \ |
| "heads and head_grads must have the same length" |
| |
| check_call(_LIB.MXAutogradBackward( |
| len(output_handles), |
| c_array(NDArrayHandle, output_handles), |
| c_array(NDArrayHandle, ograd_handles), |
| ctypes.c_int(retain_graph))) |