blob: ce82f2b1ad6c17bd1afa23f76405004519bda292 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
import mxnet as mx
import mxnet.ndarray as nd
import numpy
import os
import pickle
from collections import OrderedDict
import logging
from utils import *
logger = logging.getLogger(__name__)
class Base(object):
    """Basic wrapper around an MXNet symbol and its bound executor(s).

    Supports "bucketing": each distinct ``bucket_kwargs`` gets its own symbol,
    and each distinct set of input shapes within a bucket gets its own
    (reshaped) executor. Parameters and auxiliary states live on ``ctx`` and
    are shared across all buckets/executors.

    Parameters
    ----------
    data_shapes : dict
        Maps input tensor names to shape tuples.
    sym_gen : mx.sym.Symbol or callable
        Either a fixed symbol, or a function invoked as
        ``sym_gen(**bucket_kwargs)`` to build the symbol for a bucket.
    params : None or dict, optional
        Parameter NDArrays; each is copied onto ``ctx``. If None, parameters
        are allocated and filled by ``initializer`` during the first
        ``switch_bucket`` call.
    aux_states : None or dict, optional
        Auxiliary state NDArrays (e.g. BatchNorm moving statistics); each is
        copied onto ``ctx``.
    default_bucket_kwargs : None or dict, optional
        Kwargs identifying the initial bucket.
    learn_init_keys : None or list of str, optional
        Names of inputs that are treated as learnable parameters (e.g.
        learnable initial states) rather than plain data.
    initializer : mx.init.Initializer, optional
        Used to initialize parameters when ``params`` is None.
        NOTE(review): the default is a single shared instance created at
        import time; Xavier initializers are stateless so this is harmless,
        but pass your own instance if in doubt.
    ctx : mx.Context, optional
        Device on which parameters and executors are allocated.
        NOTE(review): the default ``mx.gpu()`` is evaluated once, at import
        time; kept for backward compatibility.
    name : str, optional
        Name used in log messages and checkpoint file names.
    """
    def __init__(self, data_shapes, sym_gen, params=None, aux_states=None,
                 default_bucket_kwargs=None, learn_init_keys=None,
                 initializer=mx.init.Xavier(factor_type="in", rnd_type="gaussian", magnitude=2),
                 ctx=mx.gpu(), name='Net'):
        self.sym_gen = sym_gen
        bucket_kwargs = default_bucket_kwargs.copy() if \
            default_bucket_kwargs is not None else dict()
        self.curr_bucket_key = None
        self.ctx = ctx
        self.name = name
        self.initializer = initializer
        if params is None:
            # Parameters will be allocated and initialized in switch_bucket().
            self.params = None
            self.params_grad = None
        else:
            self.params = OrderedDict([(k, v.copyto(ctx)) for k, v in params.items()])
            self.params_grad = OrderedDict([(n, nd.empty(v.shape, ctx=ctx))
                                            for n, v in self.params.items()])
        if aux_states is not None:
            self.aux_states = OrderedDict([(k, v.copyto(ctx)) for k, v in aux_states.items()])
        else:
            self.aux_states = None
        # bucket_key -> {'exe': {data_shape_items: executor}, 'data_shapes': ..., 'sym': ...}
        self._buckets = dict()
        self.learn_init_keys = learn_init_keys if learn_init_keys is not None else []
        self.learn_init_key_shapes = {k: data_shapes[k] for k in self.learn_init_keys}
        self.switch_bucket(bucket_kwargs=bucket_kwargs, data_shapes=data_shapes)
        # Lazily-created gradient accumulators, see update_acc_grad().
        self.acc_grad = None

    @property
    def exe(self):
        """Get the executor for the current bucket and current data shapes.

        Returns
        -------
        exe : mxnet.executor.Executor
        """
        return self._buckets[self.curr_bucket_key]['exe'][tuple(self.data_shapes.items())]

    @property
    def data_shapes(self):
        """dict: input name -> shape tuple for the current bucket."""
        return self._buckets[self.curr_bucket_key]['data_shapes']

    @property
    def sym(self):
        """mx.sym.Symbol: the symbol of the current bucket."""
        return self._buckets[self.curr_bucket_key]['sym']

    def switch_bucket(self, bucket_kwargs=None, data_shapes=None):
        """Select (creating if necessary) the bucket/executor to use.

        Parameters
        ----------
        bucket_kwargs : None or dict, optional
            Identifies the bucket; if None, stay on the current bucket.
        data_shapes : None or dict, optional
            Input shapes; required when the bucket does not exist yet.
        """
        if bucket_kwargs is not None:
            self.curr_bucket_key = get_bucket_key(bucket_kwargs=bucket_kwargs)
        # 1. Check if bucket key exists
        if self.curr_bucket_key in self._buckets:
            if data_shapes is not None:
                if tuple(data_shapes.items()) not in self._buckets[self.curr_bucket_key]['exe']:
                    #TODO Optimize the reshaping functionality!
                    self._buckets[self.curr_bucket_key]['exe'][tuple(data_shapes.items())] = \
                        self.exe.reshape(partial_shaping=True, allow_up_sizing=True, **data_shapes)
                    self._buckets[self.curr_bucket_key]['data_shapes'] = data_shapes
                else:
                    self._buckets[self.curr_bucket_key]['data_shapes'] = data_shapes
            return
        # 2. If the bucket key does not exist, create new symbol + executor
        assert data_shapes is not None, "Must set data_shapes for new bucket!"
        if isinstance(self.sym_gen, mx.symbol.Symbol):
            sym = self.sym_gen
        else:
            sym = self.sym_gen(**dict(self.curr_bucket_key))
        arg_names = sym.list_arguments()
        aux_names = sym.list_auxiliary_states()
        # Learnable arguments: everything that is not a plain data input,
        # plus the explicitly learnable initial-state inputs.
        param_names = [n for n in arg_names
                       if n in self.learn_init_keys or (n not in data_shapes.keys())]
        for k, v in data_shapes.items():
            assert isinstance(v, tuple), "Data_shapes must be tuple! Find k=%s, v=%s, " \
                                         "data_shapes=%s" % (k, str(v), str(data_shapes))
        arg_shapes, _, aux_shapes = sym.infer_shape(**data_shapes)
        arg_name_shape = OrderedDict([(k, s) for k, s in zip(arg_names, arg_shapes)])
        if self.params is None:
            self.params = OrderedDict([(n, nd.empty(arg_name_shape[n], ctx=self.ctx))
                                       for n in param_names])
            self.params_grad = OrderedDict([(n, nd.empty(arg_name_shape[n], ctx=self.ctx))
                                            for n in param_names])
            if len(self.params) > 0:
                assert self.initializer is not None, \
                    'We must set the initializer if we do not manually initialize ' \
                    'the free parameters of the network!!'
                for k, v in self.params.items():
                    self.initializer(k, v)
        else:
            # Sanity check: inferred argument shapes must match the given data
            # shapes plus the shapes of the already-existing parameters.
            # (Set union instead of list concatenation: dict.items() is a view
            # in Python 3 and does not support `+`.)
            assert set(arg_name_shape.items()) == \
                set(data_shapes.items()) | set((k, v.shape) for k, v in self.params.items())
        if self.aux_states is None:
            self.aux_states = OrderedDict([(k, nd.empty(s, ctx=self.ctx))
                                           for k, s in zip(aux_names, aux_shapes)])
        data_inputs = {k: mx.nd.empty(data_shapes[k], ctx=self.ctx)
                       for k in set(data_shapes.keys()) - set(self.learn_init_keys)}
        if len(self._buckets) > 0:
            # Share memory with an existing executor to reduce footprint.
            shared_exe = list(list(self._buckets.values())[0]['exe'].values())[0]
        else:
            shared_exe = None
        self._buckets[self.curr_bucket_key] = {
            'exe': {tuple(data_shapes.items()):
                        sym.bind(ctx=self.ctx,
                                 args=dict(self.params, **data_inputs),
                                 args_grad=dict(self.params_grad.items()),
                                 aux_states=self.aux_states,
                                 shared_exec=shared_exe)
                    },
            'data_shapes': data_shapes,
            'sym': sym
        }

    def save_params(self, dir_path="", epoch=None):
        """Save parameters, aux states and data-shape metadata to `dir_path`."""
        param_saving_path = save_params(dir_path=dir_path, name=self.name, epoch=epoch,
                                        params=self.params,
                                        aux_states=self.aux_states)
        # list(map(...)): under Python 3, a bare map object would be stored
        # lazily instead of a concrete list of ints.
        misc_saving_path = save_misc(dir_path=dir_path, epoch=epoch, name=self.name,
                                     content={'data_shapes': {k: list(map(int, v))
                                                              for k, v in self.data_shapes.items()}})
        logger.info('Saving %s, params: \"%s\", misc: \"%s\"',
                    self.name, param_saving_path, misc_saving_path)

    def load_params(self, name="", dir_path="", epoch=None):
        """Load parameters and aux states saved by `save_params` into this net."""
        params, aux_states, param_loading_path = load_params(dir_path=dir_path, epoch=epoch, name=name)
        logger.info('Loading params from \"%s\" to %s', param_loading_path, self.name)
        for k, v in params.items():
            if k in self.params:
                logger.debug(' Loading %s %s', k, str(v.shape))
                self.params[k][:] = v
            else:
                # warn() is deprecated; warning() is the documented API.
                logger.warning("Found unused param in the saved model file: %s", k)
        for k, v in aux_states.items():
            self.aux_states[k][:] = v

    @property
    def internal_sym_names(self):
        """list of str: output names of all internal symbols."""
        return self.sym.get_internals().list_outputs()

    @property
    def output_keys(self):
        """list of str: output names of the current symbol."""
        return self.sym.list_outputs()

    def compute_internal(self, sym_name, bucket_kwargs=None, **arg_dict):
        """Run a forward pass up to an internal symbol and return its value.

        Parameters
        ----------
        sym_name : str
            Name of the internal symbol to evaluate.
        bucket_kwargs : None or dict, optional
            Bucket to use; None keeps the current bucket.
        **arg_dict
            Input name -> NDArray for the forward pass.

        Returns
        -------
        mx.nd.NDArray
            The single output of the internal symbol.
        """
        data_shapes = {k: v.shape for k, v in arg_dict.items()}
        self.switch_bucket(bucket_kwargs=bucket_kwargs,
                           data_shapes=data_shapes)
        internal_sym = self.sym.get_internals()[sym_name]
        data_inputs = {k: mx.nd.empty(v, ctx=self.ctx)
                       for k, v in self.data_shapes.items()
                       if k in internal_sym.list_arguments()}
        params = {k: v for k, v in self.params.items() if
                  k in internal_sym.list_arguments()}
        aux_states = {k: v for k, v in self.aux_states.items()
                      if k in internal_sym.list_auxiliary_states()}
        exe = internal_sym.bind(ctx=self.ctx,
                                args=dict(params, **data_inputs),
                                args_grad=None,
                                grad_req='null',
                                aux_states=aux_states,
                                shared_exec=self.exe)
        for k, v in arg_dict.items():
            exe.arg_dict[k][:] = v
        exe.forward(is_train=False)
        assert 1 == len(exe.outputs)
        for output in exe.outputs:
            output.wait_to_read()
        return exe.outputs[0]

    def forward(self, is_train=False, bucket_kwargs=None, **arg_dict):
        """Run a forward pass and return the outputs (after wait_to_read)."""
        data_shapes = {k: v.shape for k, v in arg_dict.items()}
        # Learnable-init inputs keep their fixed shapes regardless of arg_dict.
        for name in self.learn_init_keys:
            data_shapes[name] = self.learn_init_key_shapes[name]
        self.switch_bucket(bucket_kwargs=bucket_kwargs,
                           data_shapes=data_shapes)
        for k, v in arg_dict.items():
            assert self.exe.arg_dict[k].shape == v.shape,\
                "Shape not match: key %s, need %s, received %s" \
                %(k, str(self.exe.arg_dict[k].shape), str(v.shape))
            self.exe.arg_dict[k][:] = v
        self.exe.forward(is_train=is_train)
        for output in self.exe.outputs:
            output.wait_to_read()
        return self.exe.outputs

    def backward(self, out_grads=None, **arg_dict):
        """Run a backward pass on the current executor (forward must precede)."""
        for k, v in arg_dict.items():
            assert self.exe.arg_dict[k].shape == v.shape, \
                "Shape not match: key %s, need %s, received %s" \
                % (k, str(self.exe.arg_dict[k].shape), str(v.shape))
            self.exe.arg_dict[k][:] = v
        self.exe.backward(out_grads=out_grads)

    def forward_backward(self, bucket_kwargs=None, out_grads=None, **arg_dict):
        """Run forward (is_train=True) then backward; return the outputs."""
        data_shapes = {k: v.shape for k, v in arg_dict.items()}
        for name in self.learn_init_keys:
            data_shapes[name] = self.learn_init_key_shapes[name]
        self.switch_bucket(bucket_kwargs=bucket_kwargs,
                           data_shapes=data_shapes)
        for k, v in arg_dict.items():
            self.exe.arg_dict[k][:] = v
        self.exe.forward(is_train=True)
        self.exe.backward(out_grads=out_grads)
        for output in self.exe.outputs:
            output.wait_to_read()
        return self.exe.outputs

    def update(self, updater, params_grad=None):
        """Apply `updater` to every parameter using `params_grad` (default: own grads)."""
        if params_grad is None:
            params_grad = self.params_grad
        # Must be ordered so updater indices stay stable across calls.
        assert type(params_grad) is OrderedDict
        for ind, k in enumerate(self.params.keys()):
            updater(index=ind, grad=params_grad[k], weight=self.params[k])

    def update_acc_grad(self):
        """Accumulate the current gradients into `self.acc_grad` (created lazily)."""
        if self.acc_grad is None:
            self.acc_grad = OrderedDict([(n, nd.zeros(v.shape, ctx=self.ctx))
                                         for n, v in self.params_grad.items()])
        for k, v in self.acc_grad.items():
            v[:] = v + self.params_grad[k]

    def reset_acc_grad(self):
        """Zero out the accumulated gradients."""
        for v in self.acc_grad.values():
            v[:] = 0

    def copy(self, name=None, ctx=None):
        """Return a new Base with copied params/aux states, same bucket, on `ctx`."""
        if ctx is None:
            ctx = self.ctx
        if name is None:
            name = self.name + '-copy-' + str(ctx)
        return Base(data_shapes=self.data_shapes,
                    sym_gen=self.sym_gen,
                    default_bucket_kwargs=dict(self.curr_bucket_key),
                    params=self.params,
                    aux_states=self.aux_states, ctx=ctx, name=name)

    def copy_params_to(self, dst):
        """Copy this net's parameter values into `dst`'s parameters in place."""
        for k, v in self.params.items():
            dst.params[k][:] = v
            # TODO `wait_to_read()` here seems unnecessary, remove it in the future!
            dst.params[k].wait_to_read()

    @property
    def total_param_num(self):
        """int: total number of scalar parameters."""
        return sum(v.size for v in self.params.values())

    def print_stat(self):
        """Log the parameter/aux-state shapes and the total parameter count."""
        logger.info("Name: %s", self.name)
        assert self.params is not None, "Fatal Error!"
        logger.info("Params: ")
        for k, v in self.params.items():
            logger.info("   %s: %s", k, v.shape)
        if self.aux_states is None or 0 == len(self.aux_states):
            logger.info("Aux States: None")
        else:
            logger.info("Aux States: " + ' '.join(
                ["%s:%s" % (str(k), str(v.shape)) for k, v in self.aux_states.items()]))
        logger.info("Total Parameter Num: " + str(self.total_param_num))