blob: bae11e18021dd170121146f1cd758155d1cfa6c7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
import os
import numpy
import json
import sys
import re
import scipy.signal
import logging
import ast
import inspect
import collections
import numbers
import cPickle as pickle
import pickle
from collections import namedtuple, OrderedDict
import time
import mxnet as mx
import mxnet.ndarray as nd
_ctx = mx.cpu()
_numpy_rng = numpy.random.RandomState(123456)
def get_default_ctx():
return _ctx
def get_numpy_rng():
return _numpy_rng
def get_saving_path(prefix="", epoch=None):
sym_saving_path = os.path.join('%s-symbol.json' % prefix)
if epoch is not None:
param_saving_path = os.path.join('%s-%05d.params' % (prefix, epoch))
param_saving_path = os.path.join('%s.params' % prefix)
misc_saving_path = os.path.join('%s-misc.json' % prefix)
return sym_saving_path, param_saving_path, misc_saving_path
def logging_config(name=None, level=logging.DEBUG, console_level=logging.DEBUG):
if name is None:
name = inspect.stack()[1][1].split('.')[0]
folder = os.path.join(os.getcwd(), name)
if not os.path.exists(folder):
logpath = os.path.join(folder, name + ".log")
print("All Logs will be saved to %s" %logpath)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logfile = logging.FileHandler(logpath)
#TODO Update logging patterns in other files
logconsole = logging.StreamHandler()
return folder
def save_params(dir_path=os.curdir, epoch=None, name="", params=None, aux_states=None,
prefix = os.path.join(dir_path, name)
_, param_saving_path, _ = get_saving_path(prefix, epoch)
if not os.path.isdir(dir_path) and not (dir_path == ""):
save_dict = {('arg:%s' % k): v.copyto(ctx) for k, v in params.items()}
save_dict.update({('aux:%s' % k): v.copyto(ctx) for k, v in aux_states.items()}), save_dict)
return param_saving_path
def save_misc(dir_path=os.curdir, epoch=None, name="", content=None):
prefix = os.path.join(dir_path, name)
_, _, misc_saving_path = get_saving_path(prefix, epoch)
with open(misc_saving_path, 'w') as fp:
json.dump(content, fp)
return misc_saving_path
def quick_save_json(dir_path=os.curdir, file_name="", content=None):
file_path = os.path.join(dir_path, file_name)
if not os.path.isdir(dir_path):
with open(file_path, 'w') as fp:
json.dump(content, fp)'Save json into %s' % file_path)
def safe_eval(expr):
if type(expr) is str:
return ast.literal_eval(expr)
return expr
def norm_clipping(params_grad, threshold):
assert isinstance(params_grad, dict)
norm_val = numpy.sqrt(sum([nd.norm(grad).asnumpy()[0]**2 for grad in params_grad.values()]))
# print('grad norm: %g' % norm_val)
ratio = 1.0
if norm_val > threshold:
ratio = threshold / norm_val
for grad in params_grad.values():
grad *= ratio
return norm_val
def sample_categorical(prob, rng):
"""Sample from independent categorical distributions
Each batch is an independent categorical distribution.
prob : numpy.ndarray
Probability of the categorical distribution. Shape --> (batch_num, category_num)
rng : numpy.random.RandomState
ret : numpy.ndarray
Sampling result. Shape --> (batch_num,)
ret = numpy.empty(prob.shape[0], dtype=numpy.float32)
for ind in range(prob.shape[0]):
ret[ind] = numpy.searchsorted(numpy.cumsum(prob[ind]), rng.rand()).clip(min=0.0,
1] - 0.5)
return ret
def sample_normal(mean, var, rng):
"""Sample from independent normal distributions
Each element is an independent normal distribution.
mean : numpy.ndarray
Means of the normal distribution. Shape --> (batch_num, sample_dim)
var : numpy.ndarray
Variance of the normal distribution. Shape --> (batch_num, sample_dim)
rng : numpy.random.RandomState
ret : numpy.ndarray
The sampling result. Shape --> (batch_num, sample_dim)
ret = numpy.sqrt(var) * rng.randn(*mean.shape) + mean
return ret
def sample_mog(prob, mean, var, rng):
"""Sample from independent mixture of gaussian (MoG) distributions
Each batch is an independent MoG distribution.
prob : numpy.ndarray
mixture probability of each gaussian. Shape --> (batch_num, center_num)
mean : numpy.ndarray
mean of each gaussian. Shape --> (batch_num, center_num, sample_dim)
var : numpy.ndarray
variance of each gaussian. Shape --> (batch_num, center_num, sample_dim)
rng : numpy.random.RandomState
ret : numpy.ndarray
sampling result. Shape --> (batch_num, sample_dim)
gaussian_inds = sample_categorical(prob, rng).astype(numpy.int32)
mean = mean[numpy.arange(mean.shape[0]), gaussian_inds, :]
var = var[numpy.arange(mean.shape[0]), gaussian_inds, :]
ret = sample_normal(mean=mean, var=var, rng=rng)
return ret
def npy_softmax(x, axis=1):
e_x = numpy.exp(x - numpy.max(x, axis=axis, keepdims=True))
out = e_x / e_x.sum(axis=axis, keepdims=True)
return out
def npy_sigmoid(x):
return 1/(1 + numpy.exp(-x))
def npy_onehot(x, num):
ret = numpy.zeros(shape=(x.size, num))
ret[numpy.arange(x.size), x.ravel()] = 1
ret = ret.reshape(x.shape + (num,))
return ret
def npy_binary_entropy(prediction, target):
assert prediction.shape == target.shape
return - (numpy.log(prediction + 1E-9) * target +
numpy.log(1 - prediction + 1E-9) * (1 - target)).sum()
def block_all(sym_list):
return [mx.symbol.BlockGrad(sym) for sym in sym_list]
def load_params(dir_path="", epoch=None, name=""):
prefix = os.path.join(dir_path, name)
_, param_loading_path, _ = get_saving_path(prefix, epoch)
while not os.path.isfile(param_loading_path):"in load_param, %s Not Found!" % param_loading_path)
save_dict = nd.load(param_loading_path)
arg_params = {}
aux_params = {}
for k, v in save_dict.items():
tp, name = k.split(':', 1)
if tp == 'arg':
arg_params[name] = v
if tp == 'aux':
aux_params[name] = v
return arg_params, aux_params, param_loading_path
def load_misc(dir_path="", epoch=None, name=""):
prefix = os.path.join(dir_path, name)
_, _, misc_saving_path = get_saving_path(prefix, epoch)
with open(misc_saving_path, 'r') as fp:
misc = json.load(fp)
return misc
def load_npz(path):
with numpy.load(path) as data:
ret = {k: data[k] for k in data.keys()}
return ret
def discount_cumsum(x, discount):
# See
# Here, we have y[t] - discount*y[t+1] = x[t]
# or rev(y)[t] - discount*rev(y)[t-1] = rev(x)[t]
return scipy.signal.lfilter([1], [1, -discount], x[::-1], axis=0)[::-1]
def discount_return(x, discount):
return numpy.sum(x * (discount ** numpy.arange(len(x))))
def update_on_kvstore(kv, params, params_grad):
for ind, k in enumerate(params.keys()):
kv.push(ind, params_grad[k], priority=-ind)
kv.pull(ind, params[k], priority=-ind)
def parse_ctx(ctx_args):
ctx = re.findall('([a-z]+)(\d*)', ctx_args)
ctx = [(device, int(num)) if len(num) > 0 else (device, 0) for device, num in ctx]
return ctx
def get_npy_list(ndarray_list):
"""Get a numpy-array list from a ndarray list
ndarray_list : list of NDArray
ret : list of numpy.ndarray
ret = [v.asnumpy() for v in ndarray_list]
return ret
def get_sym_list(syms, default_names=None, default_shapes=None):
if syms is None and default_names is not None:
if default_shapes is not None:
return [mx.sym.Variable(name=name, shape=shape) for (name, shape)
in zip(default_names, default_shapes)]
return [mx.sym.Variable(name=name) for name in default_names]
assert isinstance(syms, (list, tuple, mx.symbol.Symbol))
if isinstance(syms, (list, tuple)):
if default_names is not None and len(syms) != len(default_names):
raise ValueError("Size of symbols do not match expectation. Received %d, Expected %d. "
"syms=%s, names=%s" %(len(syms), len(default_names),
str(list( for sym in syms)),
return list(syms)
if default_names is not None and len(default_names) != 1:
raise ValueError("Size of symbols do not match expectation. Received 1, Expected %d. "
"syms=%s, names=%s"
% (len(default_names), str([]), str(default_names)))
return [syms]
def get_numeric_list(values, typ, expected_len=None):
if isinstance(values, numbers.Number):
if expected_len is not None:
return [typ(values)] * expected_len
return [typ(values)]
elif isinstance(values, (list, tuple)):
if expected_len is not None:
assert len(values) == expected_len
ret = [typ(value) for value in values]
return ret
print("Need iterable with numeric elements, received: %s" %str(values))
raise ValueError("Unaccepted value type, values=%s" %str(values))
def get_int_list(values, expected_len=None):
return get_numeric_list(values, numpy.int32, expected_len)
def get_float_list(values, expected_len=None):
return get_numeric_list(values, numpy.float32, expected_len)
def get_bucket_key(bucket_kwargs):
assert isinstance(bucket_kwargs, dict)
return tuple(bucket_kwargs.items())