blob: 0fb9eb3baade931bc228195d720b7ea9a72b0629 [file] [log] [blame]
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-locals, too-many-arguments, too-many-statements
"""Executor manager."""
from __future__ import absolute_import
import logging
import numpy as np
from .base import mx_real_t
from . import ndarray as nd
from .context import cpu
from .io import DataDesc
def _split_input_slice(batch_size, work_load_list):
"""Get input slice from the input shape.
Parameters
----------
batch_size : int
The number of samples in a mini-batch.
work_load_list : list of float or int, optional
The list of work load for different devices,
in the same order as `ctx`.
Returns
-------
slices : list of slice
The split slices to get a specific slice.
Raises
------
ValueError
In case of too many splits, leading to some empty slices.
"""
total_work_load = sum(work_load_list)
batch_num_list = [round(work_load * batch_size / total_work_load)
for work_load in work_load_list]
batch_num_sum = sum(batch_num_list)
if batch_num_sum < batch_size:
batch_num_list[-1] += batch_size - batch_num_sum
slices = []
end = 0
for batch_num in batch_num_list:
begin = int(min((end, batch_size)))
end = int(min((begin + batch_num, batch_size)))
if begin >= end:
raise ValueError('Too many slices. Some splits are empty.')
slices.append(slice(begin, end))
return slices
def _check_arguments(symbol):
"""Check the argument names of symbol.
This function checks the duplication of arguments in Symbol.
The check is done for feedforward net for now.
Parameters
----------
symbol : Symbol
The network configuration.
"""
arg_set = set()
arg_names = symbol.list_arguments()
for name in arg_names:
if name in arg_set:
raise ValueError(('Find duplicated argument name \"%s\", ' +
'please make the weight name non-duplicated(using name arguments), ' +
'arguments are %s') % (name, str(arg_names)))
arg_set.add(name)
aux_set = set()
aux_names = symbol.list_auxiliary_states()
for name in aux_names:
if name in aux_set:
raise ValueError(
('Find duplicated auxiliary param name \"%s\", ' +
'please make the weight name non-duplicated(using name arguments), ' +
'arguments are %s, auxiliary params are %s'
) % (name, str(arg_names), str(aux_names)))
aux_set.add(name)
def _load_general(data, targets):
    """Copy each source array into its destination array(s).

    ``data`` and ``targets`` are walked in lockstep. When a target is itself an
    ``NDArray``, the whole source is copied into it. Otherwise the target is a
    sequence of ``(slice, NDArray)`` pairs. Each destination then receives
    ``src[slice]``, and the last slice must stop exactly at ``src.shape[0]``.
    """
    for src, dst_spec in zip(data, targets):
        if not isinstance(dst_spec, nd.NDArray):
            assert dst_spec[-1][0].stop == src.shape[0], \
                "Batch size miss match. Expected %d, got %d" % (
                    dst_spec[-1][0].stop, src.shape[0])
            for part, dst in dst_spec:
                src[part].copyto(dst)
            continue
        src.copyto(dst_spec)
def _load_data(batch, targets):
    """Copy ``batch.data`` into the (possibly sliced) destination arrays in ``targets``."""
    sources = batch.data
    _load_general(sources, targets)
def _load_label(batch, targets):
    """Copy ``batch.label`` into the (possibly sliced) destination arrays in ``targets``."""
    sources = batch.label
    _load_general(sources, targets)
# pylint: disable=too-many-branches
def _bind_exec(sym, ctx, input_shapes, param_names, need_grad=False,
               base_exec=None, shared_data_arrays=None, input_types=None, logger=logging):
    """bind executor for bucketing, potentially sharing data with an existing executor.

    Parameters
    ----------
    sym : Symbol
        The network to bind.
    ctx : Context
        Device on which new arrays are allocated and the executor is bound.
    input_shapes : dict of str to tuple
        Shapes of the data/label inputs, forwarded to ``sym.infer_shape``.
    param_names : list of str
        Names of model parameters. Every argument of ``sym`` not in this list is
        treated as a data or label input.
    need_grad : bool or set of str
        ``False`` requests no gradients. ``True`` requests gradients for every
        argument whose name is not a key of ``input_shapes``. A set names the
        arguments that need gradients explicitly.
    base_exec : Executor, optional
        When given, parameter, gradient and auxiliary arrays are borrowed from
        this executor instead of being allocated. It is always forwarded to
        ``sym.bind`` as ``shared_exec``.
    shared_data_arrays : dict of str to NDArray, optional
        Pool of data/label arrays keyed by argument name. An existing entry is
        reused (reshaped) when it holds at least as many elements as needed.
        Newly allocated or enlarged arrays are written back into the pool.
    input_types : dict of str to dtype, optional
        Input types forwarded to ``sym.infer_type``. Defaults to ``mx_real_t``
        for every key of ``input_shapes``.
    logger : logging logger
        Used to warn when a shared data array is too small and is re-allocated.

    Returns
    -------
    Executor
        The executor returned by ``sym.bind``.

    Raises
    ------
    AssertionError
        If ``need_grad`` is not a bool or a set. Also raised via ``assert`` when
        inference returns None or when borrowed arrays do not match the inferred
        shapes/dtypes; those checks are skipped under ``python -O``.
    """
    arg_shape, _, aux_shape = sym.infer_shape(**input_shapes)
    # NOTE(review): presumably None means the shapes were under-specified — confirm.
    assert(arg_shape is not None)
    if input_types is None:
        input_types = {k: mx_real_t for k in input_shapes.keys()}
    arg_types, _, aux_types = sym.infer_type(**input_types)
    assert(arg_types is not None)

    arg_arrays = []
    # Decided from the raw argument, before need_grad is normalized below:
    # None when no gradients are requested, otherwise filled per parameter.
    grad_arrays = {} if need_grad != False else None

    arg_names = sym.list_arguments()

    # Normalize need_grad into the set of argument names that get gradients.
    if need_grad is False:
        need_grad = set()
    elif need_grad is True:
        need_grad = set(arg_names) - set(input_shapes.keys())
    elif isinstance(need_grad, set):
        pass
    else:
        raise AssertionError("need_grad must be boolean or set.")
    grad_req = {name:('write' if name in need_grad else 'null') for name in arg_names}

    # create or borrow arguments and gradients
    for i, name in enumerate(arg_names):
        if not name in param_names:
            # data or label
            if shared_data_arrays is not None and \
                name in shared_data_arrays:
                arg_arr = shared_data_arrays[name]

                # Compare element counts: a larger buffer can be reshaped down.
                if np.prod(arg_arr.shape) >= np.prod(arg_shape[i]):
                    # good, we can share this memory
                    assert(arg_types[i] == arg_arr.dtype)
                    arg_arr = arg_arr.reshape(arg_shape[i])
                else:
                    logger.warning(('bucketing: data "%s" has a shape %s' % (name, arg_shape[i])) +
                                   (', which is larger than already allocated ') +
                                   ('shape %s' % (arg_arr.shape,)) +
                                   ('. Need to re-allocate. Consider putting ') +
                                   ('default_bucket_key to be the bucket taking the largest ') +
                                   ('input for better memory sharing.'))
                    arg_arr = nd.zeros(arg_shape[i], ctx, dtype=arg_types[i])

                    # replace existing shared array because the new one is bigger
                    shared_data_arrays[name] = arg_arr
            else:
                arg_arr = nd.zeros(arg_shape[i], ctx, dtype=arg_types[i])
                # Publish the fresh buffer so later binds can reuse it.
                if shared_data_arrays is not None:
                    shared_data_arrays[name] = arg_arr

            arg_arrays.append(arg_arr)
        else:
            # model parameter
            if base_exec is None:
                arg_arr = nd.zeros(arg_shape[i], ctx, dtype=arg_types[i])
                if name in need_grad:
                    grad_arr = nd.zeros(arg_shape[i], ctx, dtype=arg_types[i])
                    grad_arrays[name] = grad_arr
            else:
                # Reuse the base executor's parameter (and gradient) arrays.
                arg_arr = base_exec.arg_dict[name]
                assert arg_arr.shape == arg_shape[i]
                assert arg_arr.dtype == arg_types[i]
                if name in need_grad:
                    grad_arrays[name] = base_exec.grad_dict[name]
            arg_arrays.append(arg_arr)

    # create or borrow aux variables
    if base_exec is None:
        aux_arrays = [nd.zeros(s, ctx, dtype=t) for s, t in zip(aux_shape, aux_types)]
    else:
        for i, a in enumerate(base_exec.aux_arrays):
            assert aux_shape[i] == a.shape
            assert aux_types[i] == a.dtype

        # New list object holding the same aux arrays as base_exec.
        aux_arrays = [a for a in base_exec.aux_arrays]
    executor = sym.bind(ctx=ctx, args=arg_arrays, args_grad=grad_arrays,
                        aux_states=aux_arrays,
                        grad_req=grad_req, shared_exec=base_exec)
    return executor
class DataParallelExecutorGroup(object):
    """A group of executors living on different devices, for data parallelization.

    One executor is bound per device in ``ctx``. Executor ``i`` handles the
    samples ``slices[i]`` of every mini-batch.

    Parameters
    ----------
    sym: Symbol
        The network configuration.
    arg_names: list of str
        Equals `sym.list_arguments()`
    param_names: list of str
        List of names of all trainable parameters.
    ctx: list of Context
        List of devices for training (data parallelization).
    slices: list of slice
        Describes how the data parallelization splits data into different devices:
        ``slices[i]`` is the range of batch samples handled by ``ctx[i]``.
    train_data: DataIter (or DataBatch)
        The dataset for training. It could be any object with `provide_data` and
        `provide_label` properties. Loading of actual data is not necessarily needed
        at this stage.
    shared_group: DataParallelExecutorGroup
        An existing executor group, if to share parameters with it.
    """
    def __init__(self, sym, arg_names, param_names, ctx, slices, train_data, shared_group=None):
        # make sure the architecture is valid
        _check_arguments(sym)

        # Per-device pools of data/label buffers. They are shared with
        # ``shared_group`` so that different groups reuse the same memory.
        if shared_group is None:
            self.shared_data_arrays = [{} for _ in ctx]
        else:
            self.shared_data_arrays = shared_group.shared_data_arrays

        self.data_names = [x[0] for x in train_data.provide_data]
        self.label_names = [x[0] for x in train_data.provide_label]
        self.aux_names = sym.list_auxiliary_states()
        # Positions (in arg_names order) of the trainable parameters.
        param_name_set = set(param_names)
        self.param_idx = [i for i, name in enumerate(arg_names) if name in param_name_set]
        self.param_names = [arg_names[i] for i in self.param_idx]

        self.train_execs = []
        for i, ctxi in enumerate(ctx):
            data_shapes = {}
            data_types = {}
            for x in train_data.provide_data + train_data.provide_label:
                # Replace the first (batch) dimension with this device's slice length.
                data_shapes[x[0]] = tuple([slices[i].stop - slices[i].start] + list(x[1][1:]))
                if isinstance(x, DataDesc):
                    data_types[x.name] = x.dtype
                else:
                    data_types[x[0]] = mx_real_t
            shared_exec = None if shared_group is None else shared_group.train_execs[i]
            train_exec = _bind_exec(sym, ctxi, data_shapes, self.param_names,
                                    need_grad=True, base_exec=shared_exec,
                                    shared_data_arrays=self.shared_data_arrays[i],
                                    input_types=data_types)
            self.train_execs.append(train_exec)

        # data structure
        # For each data/label name: one (slice, input array) pair per executor.
        self.data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
                            for name in self.data_names]
        self.label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
                             for name in self.label_names]

        # For each parameter / gradient / aux state: its array on every executor.
        self.param_arrays = [[e.arg_arrays[i] for e in self.train_execs]
                             for i in self.param_idx]
        self.grad_arrays = [[e.grad_arrays[i] for e in self.train_execs]
                            for i in self.param_idx]

        self.aux_arrays = [[e.aux_arrays[i] for e in self.train_execs]
                           for i in range(len(self.aux_names))]

        self.slices = slices

    def load_data_batch(self, data_batch):
        """Load data and labels into arrays."""
        _load_data(data_batch, self.data_arrays)
        _load_label(data_batch, self.label_arrays)

    def forward(self, is_train=False):
        """Perform a forward pass on each executor, passing ``is_train`` through."""
        for texec in self.train_execs:
            texec.forward(is_train=is_train)

    def backward(self):
        """Perform a backward pass on each executor."""
        for texec in self.train_execs:
            texec.backward()

    def update_metric(self, metric, labels):
        """Update evaluation metric with label and current outputs.

        Each executor's outputs are paired with the slice of every label array
        that executor processed.
        """
        for texec, islice in zip(self.train_execs, self.slices):
            labels_slice = [label[islice] for label in labels]
            metric.update(labels_slice, texec.outputs)
class DataParallelExecutorManager(object):
    """ Helper class to manage multiple executors for data parallelism.

    Parameters
    ----------
    symbol : Symbol
        Output symbol.
    ctx : list of Context
        Devices to run on.
    train_data : DataIter
        Training data iterator. Its ``batch_size`` is split across ``ctx``.
        When ``sym_gen`` is given it must also expose ``default_bucket_key``.
    arg_names: list of str
        Name of all arguments of the network.
    param_names: list of str
        Name of all trainable parameters of the network.
    aux_names: list of str
        Name of all auxiliary states of the network.
    work_load_list : list of float or int, optional
        The list of work load for different devices,
        in the same order as ctx. Defaults to an equal share per device.
    logger : logging logger
        When not specified, default logger will be used.
    sym_gen : callable, optional
        A function that generates a new Symbol for a given bucket key
        (called as ``sym_gen(key)``). Used only for bucketing.
    """
    def __init__(self, symbol, ctx, train_data,
                 arg_names, param_names, aux_names,
                 work_load_list=None, logger=None, sym_gen=None):
        if logger is None:
            logger = logging
        # preparation
        num_device = len(ctx)
        logger.info('Start training with %s', str(ctx))

        if work_load_list is None:
            work_load_list = [1] * num_device
        assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \
            "Invalid settings for work load. "

        slices = _split_input_slice(train_data.batch_size, work_load_list)
        self.slices = slices

        self.arg_names = arg_names
        self.param_names = param_names
        self.aux_names = aux_names
        self.ctx = ctx

        # The default executor group owns the parameter/gradient/aux memory;
        # bucket groups created later borrow it via ``shared_group``.
        self.execgrp = DataParallelExecutorGroup(symbol, self.arg_names, self.param_names, self.ctx,
                                                 self.slices, train_data)

        self.symbol = symbol

        self.sym_gen = sym_gen
        self.curr_execgrp = None  # this is set when data is loaded
        if self.sym_gen is not None:
            self.execgrp_bucket = {train_data.default_bucket_key: self.execgrp}

    def install_monitor(self, monitor):
        """Install monitor on all executors.

        Raises
        ------
        NotImplementedError
            When bucketing (``sym_gen``) is enabled.
        """
        if self.sym_gen is not None:
            raise NotImplementedError("Monitoring is not implemented for bucketing")

        for train_exec in self.execgrp.train_execs:
            monitor.install(train_exec)

    def set_params(self, arg_params, aux_params):
        """Set parameter and aux values on every executor of the default group.

        Parameters
        ----------
        arg_params : dict of str to NDArray
            Source parameter arrays, keyed by parameter name.
        aux_params : dict of str to NDArray
            Source aux arrays, keyed by auxiliary-state name.
        """
        for texec in self.execgrp.train_execs:
            texec.copy_params_from(arg_params, aux_params)

    def copy_to(self, arg_params, aux_params):
        """ Copy data from each executor to ``arg_params`` and ``aux_params``.

        Each value is averaged over all devices on the CPU, cast to the target
        array's dtype and then copied into the target array.

        Parameters
        ----------
        arg_params : dict of str to NDArray
            Target parameter arrays, keyed by every name in ``param_names``.
        aux_params : dict of str to NDArray
            Target aux arrays, keyed by every name in ``aux_names``.

        Notes
        -----
        - This function will inplace update the NDArrays in arg_params and aux_params.
        """
        for name, block in zip(self.param_names, self.param_arrays):
            weight = sum(w.copyto(cpu()) for w in block) / len(block)
            weight.astype(arg_params[name].dtype).copyto(arg_params[name])
        for name, block in zip(self.aux_names, self.aux_arrays):
            weight = sum(w.copyto(cpu()) for w in block) / len(block)
            weight.astype(aux_params[name].dtype).copyto(aux_params[name])

    @property
    def param_arrays(self):
        """Shared parameter arrays."""
        # param arrays should be shared by all executor groups
        return self.execgrp.param_arrays

    @property
    def grad_arrays(self):
        """Shared gradient arrays."""
        # grad arrays should be shared by all executor groups
        return self.execgrp.grad_arrays

    @property
    def aux_arrays(self):
        """Shared aux states."""
        # aux arrays are also shared by all executor groups
        return self.execgrp.aux_arrays

    def load_data_batch(self, data_batch):
        """Select the executor group for ``data_batch`` and load its data and labels.

        With bucketing, a new group sharing parameters with the default group is
        created (and cached) the first time a ``bucket_key`` is seen.
        """
        if self.sym_gen is not None:
            key = data_batch.bucket_key
            if key not in self.execgrp_bucket:
                # create new bucket entry
                symbol = self.sym_gen(key)
                execgrp = DataParallelExecutorGroup(symbol, self.arg_names,
                                                    self.param_names, self.ctx,
                                                    self.slices, data_batch,
                                                    shared_group=self.execgrp)
                self.execgrp_bucket[key] = execgrp

            self.curr_execgrp = self.execgrp_bucket[key]
        else:
            self.curr_execgrp = self.execgrp

        self.curr_execgrp.load_data_batch(data_batch)

    def forward(self, is_train=False):
        """Run forward on the current executor group (call ``load_data_batch`` first)."""
        self.curr_execgrp.forward(is_train=is_train)

    def backward(self):
        """Run backward on the current executor group (call ``load_data_batch`` first)."""
        self.curr_execgrp.backward()

    def update_metric(self, metric, labels):
        """Update metric with the current executor group (call ``load_data_batch`` first)."""
        self.curr_execgrp.update_metric(metric, labels)