blob: bae11e18021dd170121146f1cd758155d1cfa6c7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
import os
import numpy
import json
import sys
import re
import scipy.signal
import logging
import ast
import inspect
import collections
import numbers
try:
import cPickle as pickle
except:
import pickle
from collections import namedtuple, OrderedDict
import time
import mxnet as mx
import mxnet.ndarray as nd
# Default device used when callers do not specify an MXNet context.
_ctx = mx.cpu()
# Module-wide RNG seeded with a fixed value so runs are reproducible.
_numpy_rng = numpy.random.RandomState(123456)
def get_default_ctx():
    """Return the module-level default MXNet context (CPU)."""
    return _ctx
def get_numpy_rng():
    """Return the module-level numpy RandomState (fixed seed 123456)."""
    return _numpy_rng
def get_saving_path(prefix="", epoch=None):
    """Build the checkpoint file paths for a given prefix.

    Parameters
    ----------
    prefix : str
        Path prefix of the checkpoint (may contain directories).
    epoch : int or None
        When given, the params file name embeds the zero-padded epoch number.

    Returns
    -------
    (str, str, str)
        Paths of the symbol json, the params file and the misc json.
    """
    sym_saving_path = '%s-symbol.json' % prefix
    if epoch is None:
        param_saving_path = '%s.params' % prefix
    else:
        param_saving_path = '%s-%05d.params' % (prefix, epoch)
    misc_saving_path = '%s-misc.json' % prefix
    return sym_saving_path, param_saving_path, misc_saving_path
def logging_config(name=None, level=logging.DEBUG, console_level=logging.DEBUG):
    """Configure root logging to write to <cwd>/<name>/<name>.log and the console.

    Parameters
    ----------
    name : str or None
        Log folder/file name. When None, derived from the caller's file name.
    level : int
        Level of the file handler (and of the root logger).
    console_level : int
        Level of the console (stream) handler.

    Returns
    -------
    str
        The folder the log file is written into.
    """
    if name is None:
        # Derive the name from the caller's source file. The original
        # `split('.')[0]` on the full path breaks whenever any directory in
        # the path contains a dot; basename + splitext is robust.
        caller_file = inspect.stack()[1][1]
        name = os.path.splitext(os.path.basename(caller_file))[0]
    folder = os.path.join(os.getcwd(), name)
    if not os.path.exists(folder):
        os.makedirs(folder)
    logpath = os.path.join(folder, name + ".log")
    print("All Logs will be saved to %s" % logpath)
    logging.root.setLevel(level)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logfile = logging.FileHandler(logpath)
    logfile.setLevel(level)
    logfile.setFormatter(formatter)
    logging.root.addHandler(logfile)
    # TODO Update logging patterns in other files
    logconsole = logging.StreamHandler()
    logconsole.setLevel(console_level)
    logconsole.setFormatter(formatter)
    logging.root.addHandler(logconsole)
    return folder
def save_params(dir_path=os.curdir, epoch=None, name="", params=None, aux_states=None,
                ctx=mx.cpu()):
    """Save parameters and auxiliary states into a single .params file.

    Parameters
    ----------
    dir_path : str
        Target directory (created if it does not exist).
    epoch : int or None
        Optional epoch number embedded in the file name.
    name : str
        Checkpoint name prefix.
    params : dict of str -> NDArray or None
        Parameters, saved under the 'arg:' key prefix. None means no params.
    aux_states : dict of str -> NDArray or None
        Auxiliary states, saved under the 'aux:' key prefix. None means none.
    ctx : mx.Context
        Context the arrays are copied to before saving.

    Returns
    -------
    str
        The path the parameters were written to.
    """
    prefix = os.path.join(dir_path, name)
    _, param_saving_path, _ = get_saving_path(prefix, epoch)
    if not os.path.isdir(dir_path) and not (dir_path == ""):
        os.makedirs(dir_path)
    # Treat missing dicts as empty instead of crashing on None.items()
    # (both parameters default to None in the signature).
    params = {} if params is None else params
    aux_states = {} if aux_states is None else aux_states
    save_dict = {('arg:%s' % k): v.copyto(ctx) for k, v in params.items()}
    save_dict.update({('aux:%s' % k): v.copyto(ctx) for k, v in aux_states.items()})
    nd.save(param_saving_path, save_dict)
    return param_saving_path
def save_misc(dir_path=os.curdir, epoch=None, name="", content=None):
    """Dump `content` as JSON to the misc path derived from dir_path/name/epoch.

    Returns
    -------
    str
        The path the JSON was written to.
    """
    _, _, misc_saving_path = get_saving_path(os.path.join(dir_path, name), epoch)
    with open(misc_saving_path, 'w') as fp:
        json.dump(content, fp)
    return misc_saving_path
def quick_save_json(dir_path=os.curdir, file_name="", content=None):
    """Serialize `content` as JSON into dir_path/file_name.

    Parameters
    ----------
    dir_path : str
        Target directory, created if missing.
    file_name : str
        Name of the JSON file.
    content : object
        Any JSON-serializable object.
    """
    file_path = os.path.join(dir_path, file_name)
    if not os.path.isdir(dir_path):
        os.makedirs(dir_path)
    with open(file_path, 'w') as fp:
        json.dump(content, fp)
    # Lazy %-args: the message is only formatted if the record is emitted.
    logging.info('Save json into %s', file_path)
def safe_eval(expr):
    """Evaluate a Python literal given as a string; pass non-strings through.

    Uses ast.literal_eval, which accepts only literal syntax (numbers,
    strings, tuples, lists, dicts, ...) — unlike eval(), it cannot execute
    arbitrary code.

    Parameters
    ----------
    expr : str or object

    Returns
    -------
    object
        The parsed literal, or `expr` itself when it is not a string.
    """
    # isinstance (not `type(...) is str`) also accepts str subclasses.
    if isinstance(expr, str):
        return ast.literal_eval(expr)
    return expr
def norm_clipping(params_grad, threshold):
    """Rescale all gradients in place so their global L2 norm is <= threshold.

    Parameters
    ----------
    params_grad : dict of str -> NDArray
        Gradients, modified in place when clipping triggers.
    threshold : float
        Maximum allowed global gradient norm.

    Returns
    -------
    float
        The global gradient norm before clipping.
    """
    assert isinstance(params_grad, dict)
    squared_sum = sum(nd.norm(g).asnumpy()[0] ** 2 for g in params_grad.values())
    norm_val = numpy.sqrt(squared_sum)
    # print('grad norm: %g' % norm_val)
    if norm_val > threshold:
        scale = threshold / norm_val
        for g in params_grad.values():
            g *= scale
    return norm_val
def sample_categorical(prob, rng):
    """Sample from independent categorical distributions.

    Each row of `prob` is treated as its own categorical distribution.

    Parameters
    ----------
    prob : numpy.ndarray
        Probability of the categorical distribution. Shape --> (batch_num, category_num)
    rng : numpy.random.RandomState

    Returns
    -------
    ret : numpy.ndarray
        Sampling result (category indices stored as float32). Shape --> (batch_num,)
    """
    batch_num, category_num = prob.shape
    ret = numpy.empty(batch_num, dtype=numpy.float32)
    upper = category_num - 0.5
    for ind in range(batch_num):
        # Invert the CDF with one uniform draw; clamp so floating-point
        # round-off in cumsum cannot produce an out-of-range index.
        cdf = numpy.cumsum(prob[ind])
        ret[ind] = numpy.searchsorted(cdf, rng.rand()).clip(min=0.0, max=upper)
    return ret
def sample_normal(mean, var, rng):
    """Sample from independent normal distributions.

    Every element is drawn from its own N(mean, var).

    Parameters
    ----------
    mean : numpy.ndarray
        Means of the normal distribution. Shape --> (batch_num, sample_dim)
    var : numpy.ndarray
        Variance of the normal distribution. Shape --> (batch_num, sample_dim)
    rng : numpy.random.RandomState

    Returns
    -------
    ret : numpy.ndarray
        The sampling result. Shape --> (batch_num, sample_dim)
    """
    noise = rng.randn(*mean.shape)
    return mean + numpy.sqrt(var) * noise
def sample_mog(prob, mean, var, rng):
    """Sample from independent mixture-of-Gaussians (MoG) distributions.

    Each batch row is an independent MoG: first pick a mixture component,
    then sample from that component's Gaussian.

    Parameters
    ----------
    prob : numpy.ndarray
        Mixture probability of each gaussian. Shape --> (batch_num, center_num)
    mean : numpy.ndarray
        Mean of each gaussian. Shape --> (batch_num, center_num, sample_dim)
    var : numpy.ndarray
        Variance of each gaussian. Shape --> (batch_num, center_num, sample_dim)
    rng : numpy.random.RandomState

    Returns
    -------
    ret : numpy.ndarray
        Sampling result. Shape --> (batch_num, sample_dim)
    """
    batch_idx = numpy.arange(mean.shape[0])
    centers = sample_categorical(prob, rng).astype(numpy.int32)
    chosen_mean = mean[batch_idx, centers, :]
    chosen_var = var[batch_idx, centers, :]
    return sample_normal(mean=chosen_mean, var=chosen_var, rng=rng)
def npy_softmax(x, axis=1):
    """Numerically stable softmax along `axis` (max-subtraction trick)."""
    shifted = x - numpy.max(x, axis=axis, keepdims=True)
    exp_x = numpy.exp(shifted)
    return exp_x / exp_x.sum(axis=axis, keepdims=True)
def npy_sigmoid(x):
    """Element-wise logistic sigmoid: 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + numpy.exp(-x))
def npy_onehot(x, num):
    """One-hot encode an integer array.

    Parameters
    ----------
    x : numpy.ndarray
        Integer class indices, any shape.
    num : int
        Number of classes.

    Returns
    -------
    numpy.ndarray
        One-hot encoding with shape x.shape + (num,).
    """
    flat = numpy.zeros(shape=(x.size, num))
    flat[numpy.arange(x.size), x.ravel()] = 1
    return flat.reshape(x.shape + (num,))
def npy_binary_entropy(prediction, target):
    """Total (summed) binary cross-entropy between predicted probabilities
    and 0/1 targets. The 1E-9 terms guard against log(0).
    """
    assert prediction.shape == target.shape
    pos = numpy.log(prediction + 1E-9) * target
    neg = numpy.log(1 - prediction + 1E-9) * (1 - target)
    return -(pos + neg).sum()
def block_all(sym_list):
    """Wrap every symbol in BlockGrad so no gradient flows through them."""
    return [mx.symbol.BlockGrad(s) for s in sym_list]
def load_params(dir_path="", epoch=None, name=""):
    """Load a saved checkpoint, blocking until the params file exists.

    Polls every 60 seconds while the file is missing (the writer may still
    be saving it).

    Parameters
    ----------
    dir_path : str
    epoch : int or None
    name : str

    Returns
    -------
    (dict, dict, str)
        arg params, aux states, and the path they were loaded from.
    """
    prefix = os.path.join(dir_path, name)
    _, param_loading_path, _ = get_saving_path(prefix, epoch)
    while not os.path.isfile(param_loading_path):
        logging.info("in load_param, %s Not Found!" % param_loading_path)
        time.sleep(60)
    arg_params = {}
    aux_params = {}
    # Keys are stored as 'arg:<name>' / 'aux:<name>'; split on the first ':'.
    for full_key, value in nd.load(param_loading_path).items():
        kind, param_name = full_key.split(':', 1)
        if kind == 'arg':
            arg_params[param_name] = value
        if kind == 'aux':
            aux_params[param_name] = value
    return arg_params, aux_params, param_loading_path
def load_misc(dir_path="", epoch=None, name=""):
    """Load the misc JSON saved next to a checkpoint and return its content."""
    _, _, misc_saving_path = get_saving_path(os.path.join(dir_path, name), epoch)
    with open(misc_saving_path, 'r') as fp:
        return json.load(fp)
def load_npz(path):
    """Load an .npz archive into a plain dict of arrays.

    The context manager guarantees the underlying file handle is closed
    once every array has been materialized.
    """
    with numpy.load(path) as data:
        return {key: data[key] for key in data.keys()}
def discount_cumsum(x, discount):
    """Discounted cumulative sum along axis 0: y[t] = sum_k discount**k * x[t+k].

    Implemented as an IIR filter over the reversed signal, see
    https://docs.scipy.org/doc/scipy/reference/tutorial/signal.html#difference-equation-filtering
    The recurrence is y[t] - discount*y[t+1] = x[t], i.e. on the reversed
    sequence rev(y)[t] - discount*rev(y)[t-1] = rev(x)[t].
    """
    reversed_x = x[::-1]
    reversed_y = scipy.signal.lfilter([1], [1, -discount], reversed_x, axis=0)
    return reversed_y[::-1]
def discount_return(x, discount):
    """Total discounted return: sum_t discount**t * x[t]."""
    weights = discount ** numpy.arange(len(x))
    return numpy.sum(x * weights)
def update_on_kvstore(kv, params, params_grad):
    """Push each gradient to the kvstore and pull the updated weight back.

    Keys are indexed by their position in `params`; higher-priority (earlier)
    keys are communicated first via the negative-index priority.
    """
    for idx, key in enumerate(params.keys()):
        kv.push(idx, params_grad[key], priority=-idx)
        kv.pull(idx, params[key], priority=-idx)
def parse_ctx(ctx_args):
    """Parse a context spec string such as "gpu0,cpu" into [(device, id), ...].

    Device ids default to 0 when omitted, e.g. "cpu" -> ("cpu", 0).

    Parameters
    ----------
    ctx_args : str

    Returns
    -------
    list of (str, int)
    """
    # Raw string: a plain '...\d...' literal relies on Python leaving unknown
    # escapes untouched, which is deprecated (SyntaxWarning since 3.12).
    ctx = re.findall(r'([a-z]+)(\d*)', ctx_args)
    return [(device, int(num)) if len(num) > 0 else (device, 0)
            for device, num in ctx]
def get_npy_list(ndarray_list):
    """Get a numpy-array list from a ndarray list.

    Parameters
    ----------
    ndarray_list : list of NDArray

    Returns
    -------
    ret : list of numpy.ndarray
    """
    return [arr.asnumpy() for arr in ndarray_list]
def get_sym_list(syms, default_names=None, default_shapes=None):
    """Normalize `syms` into a list of mx Symbols.

    When `syms` is None, fresh Variables are created from `default_names`
    (with `default_shapes` when supplied). Otherwise `syms` may be a single
    Symbol or a list/tuple of Symbols; its length is validated against
    `default_names` when that is given.

    Raises
    ------
    ValueError
        If the number of symbols does not match len(default_names).
    """
    if syms is None and default_names is not None:
        if default_shapes is None:
            return [mx.sym.Variable(name=name) for name in default_names]
        return [mx.sym.Variable(name=name, shape=shape)
                for name, shape in zip(default_names, default_shapes)]
    assert isinstance(syms, (list, tuple, mx.symbol.Symbol))
    if isinstance(syms, (list, tuple)):
        if default_names is not None and len(syms) != len(default_names):
            raise ValueError("Size of symbols do not match expectation. Received %d, Expected %d. "
                             "syms=%s, names=%s" % (len(syms), len(default_names),
                                                    str(list(sym.name for sym in syms)),
                                                    str(default_names)))
        return list(syms)
    # Single Symbol: only valid when one name (or none) is expected.
    if default_names is not None and len(default_names) != 1:
        raise ValueError("Size of symbols do not match expectation. Received 1, Expected %d. "
                         "syms=%s, names=%s"
                         % (len(default_names), str([syms.name]), str(default_names)))
    return [syms]
def get_numeric_list(values, typ, expected_len=None):
    """Coerce a scalar or an iterable of numbers into a list of `typ`.

    Parameters
    ----------
    values : numbers.Number or list/tuple
        A scalar is broadcast to `expected_len` copies (or a single-element
        list when expected_len is None).
    typ : callable
        Target numeric constructor, e.g. int, float or numpy.float32.
    expected_len : int or None
        When given, the resulting list must have exactly this length.

    Returns
    -------
    list of typ

    Raises
    ------
    ValueError
        If `values` is neither a number nor a list/tuple.
    """
    if isinstance(values, numbers.Number):
        if expected_len is not None:
            return [typ(values)] * expected_len
        return [typ(values)]
    if isinstance(values, (list, tuple)):
        if expected_len is not None:
            assert len(values) == expected_len
        try:
            return [typ(value) for value in values]
        except (ValueError, TypeError):
            # Also catch TypeError: e.g. typ(None) raises TypeError, not
            # ValueError, and should take the same graceful-exit path.
            print("Need iterable with numeric elements, received: %s" % str(values))
            sys.exit(1)
    raise ValueError("Unaccepted value type, values=%s" % str(values))
def get_int_list(values, expected_len=None):
    """Coerce values into a list of numpy.int32 (see get_numeric_list)."""
    return get_numeric_list(values, numpy.int32, expected_len)
def get_float_list(values, expected_len=None):
    """Coerce values into a list of numpy.float32 (see get_numeric_list)."""
    return get_numeric_list(values, numpy.float32, expected_len)
def get_bucket_key(bucket_kwargs):
    """Turn a bucket kwargs dict into a hashable tuple of (key, value) pairs."""
    assert isinstance(bucket_kwargs, dict)
    return tuple(bucket_kwargs.items())