| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
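
"""Miscellaneous utilities: default context/RNG accessors, checkpoint path
construction and saving/loading, logging setup, gradient norm clipping,
numpy-based sampling (categorical, normal, mixture of Gaussians) and activation
helpers, discounted-return computation, and small argument-normalization helpers."""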
| |
| from __future__ import absolute_import, division, print_function |
| |
| import os |
| import numpy |
| import json |
| import sys |
| import re |
| import scipy.signal |
| import logging |
| import ast |
| import inspect |
| import collections |
| import numbers |
| try: |
| import cPickle as pickle |
except ImportError:
| import pickle |
| from collections import namedtuple, OrderedDict |
| import time |
| import mxnet as mx |
| import mxnet.ndarray as nd |
| |
| |
| _ctx = mx.cpu() |
| _numpy_rng = numpy.random.RandomState(123456) |
| |
| |
| def get_default_ctx(): |
| return _ctx |
| |
| |
| def get_numpy_rng(): |
| return _numpy_rng |
| |
| |
| def get_saving_path(prefix="", epoch=None): |
    sym_saving_path = '%s-symbol.json' % prefix
    if epoch is not None:
        param_saving_path = '%s-%05d.params' % (prefix, epoch)
    else:
        param_saving_path = '%s.params' % prefix
    misc_saving_path = '%s-misc.json' % prefix
| return sym_saving_path, param_saving_path, misc_saving_path |
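
# Illustrative example (hypothetical prefix): get_saving_path('model/dqn', epoch=3)
# returns ('model/dqn-symbol.json', 'model/dqn-00003.params', 'model/dqn-misc.json').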
| |
| |
| def logging_config(name=None, level=logging.DEBUG, console_level=logging.DEBUG): |
| if name is None: |
        name = os.path.splitext(inspect.stack()[1][1])[0]
| folder = os.path.join(os.getcwd(), name) |
| if not os.path.exists(folder): |
| os.makedirs(folder) |
| logpath = os.path.join(folder, name + ".log") |
| print("All Logs will be saved to %s" %logpath) |
| logging.root.setLevel(level) |
| formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
| logfile = logging.FileHandler(logpath) |
| logfile.setLevel(level) |
| logfile.setFormatter(formatter) |
| logging.root.addHandler(logfile) |
| #TODO Update logging patterns in other files |
| logconsole = logging.StreamHandler() |
| logconsole.setLevel(console_level) |
| logconsole.setFormatter(formatter) |
| logging.root.addHandler(logconsole) |
| return folder |
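
# Illustrative usage (hypothetical name): logging_config(name='train') writes all
# records to <cwd>/train/train.log, echoes them to the console, and returns the
# created folder path.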
| |
| |
| def save_params(dir_path=os.curdir, epoch=None, name="", params=None, aux_states=None, |
| ctx=mx.cpu()): |
| prefix = os.path.join(dir_path, name) |
| _, param_saving_path, _ = get_saving_path(prefix, epoch) |
| if not os.path.isdir(dir_path) and not (dir_path == ""): |
| os.makedirs(dir_path) |
| save_dict = {('arg:%s' % k): v.copyto(ctx) for k, v in params.items()} |
| save_dict.update({('aux:%s' % k): v.copyto(ctx) for k, v in aux_states.items()}) |
| nd.save(param_saving_path, save_dict) |
| return param_saving_path |
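
# Illustrative usage (hypothetical arguments): given NDArray dicts `arg_params` and
# `aux_states`, save_params(dir_path='model', epoch=3, name='dqn', params=arg_params,
# aux_states=aux_states) writes model/dqn-00003.params and returns that path.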
| |
| |
| def save_misc(dir_path=os.curdir, epoch=None, name="", content=None): |
| prefix = os.path.join(dir_path, name) |
| _, _, misc_saving_path = get_saving_path(prefix, epoch) |
| with open(misc_saving_path, 'w') as fp: |
| json.dump(content, fp) |
| return misc_saving_path |
| |
| |
| def quick_save_json(dir_path=os.curdir, file_name="", content=None): |
| file_path = os.path.join(dir_path, file_name) |
| if not os.path.isdir(dir_path): |
| os.makedirs(dir_path) |
| with open(file_path, 'w') as fp: |
| json.dump(content, fp) |
    logging.info('Saved json to %s' % file_path)
| |
| |
| def safe_eval(expr): |
    if isinstance(expr, str):
| return ast.literal_eval(expr) |
| else: |
| return expr |
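
# Example: safe_eval("(84, 84)") returns the tuple (84, 84); a non-string argument
# such as an int or a tuple is returned unchanged.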
| |
| |
| def norm_clipping(params_grad, threshold): |
| assert isinstance(params_grad, dict) |
| norm_val = numpy.sqrt(sum([nd.norm(grad).asnumpy()[0]**2 for grad in params_grad.values()])) |
| # print('grad norm: %g' % norm_val) |
| ratio = 1.0 |
| if norm_val > threshold: |
| ratio = threshold / norm_val |
| for grad in params_grad.values(): |
| grad *= ratio |
| return norm_val |
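
# Illustrative usage (hypothetical gradient dict): norm_clipping(params_grad, 10.0)
# rescales every gradient in-place so the global L2 norm is at most 10, and returns
# the pre-clipping norm.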
| |
| |
| def sample_categorical(prob, rng): |
| """Sample from independent categorical distributions |
| |
| Each batch is an independent categorical distribution. |
| |
| Parameters |
| ---------- |
| prob : numpy.ndarray |
| Probability of the categorical distribution. Shape --> (batch_num, category_num) |
| rng : numpy.random.RandomState |
| |
| Returns |
| ------- |
| ret : numpy.ndarray |
| Sampling result. Shape --> (batch_num,) |
| """ |
| ret = numpy.empty(prob.shape[0], dtype=numpy.float32) |
| for ind in range(prob.shape[0]): |
        ret[ind] = numpy.searchsorted(numpy.cumsum(prob[ind]),
                                      rng.rand()).clip(min=0.0, max=prob.shape[1] - 0.5)
| return ret |
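
# Illustrative example: sample_categorical(numpy.array([[0.3, 0.7]]), get_numpy_rng())
# returns a length-1 array whose single (float-valued) entry is the sampled category
# index, 0 or 1.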
| |
| |
| def sample_normal(mean, var, rng): |
| """Sample from independent normal distributions |
| |
| Each element is an independent normal distribution. |
| |
| Parameters |
| ---------- |
| mean : numpy.ndarray |
| Means of the normal distribution. Shape --> (batch_num, sample_dim) |
| var : numpy.ndarray |
| Variance of the normal distribution. Shape --> (batch_num, sample_dim) |
| rng : numpy.random.RandomState |
| |
| Returns |
| ------- |
| ret : numpy.ndarray |
| The sampling result. Shape --> (batch_num, sample_dim) |
| """ |
| ret = numpy.sqrt(var) * rng.randn(*mean.shape) + mean |
| return ret |
| |
| |
| def sample_mog(prob, mean, var, rng): |
| """Sample from independent mixture of gaussian (MoG) distributions |
| |
| Each batch is an independent MoG distribution. |
| |
| Parameters |
| ---------- |
| prob : numpy.ndarray |
| mixture probability of each gaussian. Shape --> (batch_num, center_num) |
| mean : numpy.ndarray |
| mean of each gaussian. Shape --> (batch_num, center_num, sample_dim) |
| var : numpy.ndarray |
| variance of each gaussian. Shape --> (batch_num, center_num, sample_dim) |
| rng : numpy.random.RandomState |
| |
| Returns |
| ------- |
| ret : numpy.ndarray |
| sampling result. Shape --> (batch_num, sample_dim) |
| """ |
| gaussian_inds = sample_categorical(prob, rng).astype(numpy.int32) |
| mean = mean[numpy.arange(mean.shape[0]), gaussian_inds, :] |
| var = var[numpy.arange(mean.shape[0]), gaussian_inds, :] |
| ret = sample_normal(mean=mean, var=var, rng=rng) |
| return ret |
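
# Illustrative usage (hypothetical shapes): for a batch of 4 two-component mixtures
# over 3-dimensional samples,
#   prob = npy_softmax(numpy.random.randn(4, 2), axis=1)   # shape (4, 2)
#   mean = numpy.random.randn(4, 2, 3)
#   var = numpy.ones((4, 2, 3))
#   sample_mog(prob, mean, var, get_numpy_rng())            # shape (4, 3)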
| |
| |
| def npy_softmax(x, axis=1): |
| e_x = numpy.exp(x - numpy.max(x, axis=axis, keepdims=True)) |
| out = e_x / e_x.sum(axis=axis, keepdims=True) |
| return out |
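
# Example: npy_softmax(numpy.array([[1.0, 1.0]])) returns [[0.5, 0.5]]; the row-wise
# max is subtracted before exponentiation for numerical stability.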
| |
| |
| def npy_sigmoid(x): |
| return 1/(1 + numpy.exp(-x)) |
| |
| |
| def npy_onehot(x, num): |
| ret = numpy.zeros(shape=(x.size, num)) |
| ret[numpy.arange(x.size), x.ravel()] = 1 |
| ret = ret.reshape(x.shape + (num,)) |
| return ret |
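
# Example: npy_onehot(numpy.array([0, 2]), 3) returns [[1, 0, 0], [0, 0, 1]].
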
| |
| def npy_binary_entropy(prediction, target): |
| assert prediction.shape == target.shape |
| return - (numpy.log(prediction + 1E-9) * target + |
| numpy.log(1 - prediction + 1E-9) * (1 - target)).sum() |
| |
| |
| def block_all(sym_list): |
| return [mx.symbol.BlockGrad(sym) for sym in sym_list] |
| |
| |
| def load_params(dir_path="", epoch=None, name=""): |
| prefix = os.path.join(dir_path, name) |
| _, param_loading_path, _ = get_saving_path(prefix, epoch) |
| while not os.path.isfile(param_loading_path): |
| logging.info("in load_param, %s Not Found!" % param_loading_path) |
| time.sleep(60) |
| save_dict = nd.load(param_loading_path) |
| arg_params = {} |
| aux_params = {} |
| for k, v in save_dict.items(): |
        tp, param_name = k.split(':', 1)
        if tp == 'arg':
            arg_params[param_name] = v
        elif tp == 'aux':
            aux_params[param_name] = v
| return arg_params, aux_params, param_loading_path |
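
# Illustrative usage (hypothetical files): load_params(dir_path='model', epoch=3,
# name='dqn') blocks (retrying every 60 seconds) until model/dqn-00003.params exists,
# then returns (arg_params, aux_params, loading_path).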
| |
| |
| def load_misc(dir_path="", epoch=None, name=""): |
| prefix = os.path.join(dir_path, name) |
| _, _, misc_saving_path = get_saving_path(prefix, epoch) |
| with open(misc_saving_path, 'r') as fp: |
| misc = json.load(fp) |
| return misc |
| |
| |
| def load_npz(path): |
| with numpy.load(path) as data: |
| ret = {k: data[k] for k in data.keys()} |
| return ret |
| |
| |
| def discount_cumsum(x, discount): |
| # See https://docs.scipy.org/doc/scipy/reference/tutorial/signal.html#difference-equation-filtering |
| # Here, we have y[t] - discount*y[t+1] = x[t] |
| # or rev(y)[t] - discount*rev(y)[t-1] = rev(x)[t] |
| return scipy.signal.lfilter([1], [1, -discount], x[::-1], axis=0)[::-1] |
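
# Example: discount_cumsum(numpy.array([1.0, 1.0, 1.0]), 0.9) returns
# [2.71, 1.9, 1.0], i.e. y[t] = x[t] + 0.9 * y[t + 1].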
| |
| |
| def discount_return(x, discount): |
| return numpy.sum(x * (discount ** numpy.arange(len(x)))) |
| |
| |
| def update_on_kvstore(kv, params, params_grad): |
| for ind, k in enumerate(params.keys()): |
| kv.push(ind, params_grad[k], priority=-ind) |
| kv.pull(ind, params[k], priority=-ind) |
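
# Sketch of intent: for each parameter, push its gradient to the kvstore (whose
# configured updater/optimizer applies the update) and pull the refreshed weights
# back into `params` in place.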
| |
| |
| def parse_ctx(ctx_args): |
    ctx = re.findall(r'([a-z]+)(\d*)', ctx_args)
| ctx = [(device, int(num)) if len(num) > 0 else (device, 0) for device, num in ctx] |
| return ctx |
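
# Example: parse_ctx('gpu0,gpu1') returns [('gpu', 0), ('gpu', 1)] and parse_ctx('cpu')
# returns [('cpu', 0)].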
| |
| |
| def get_npy_list(ndarray_list): |
| """Get a numpy-array list from a ndarray list |
| Parameters |
| ---------- |
| ndarray_list : list of NDArray |
| |
| Returns |
| ------- |
| ret : list of numpy.ndarray |
| """ |
| ret = [v.asnumpy() for v in ndarray_list] |
| return ret |
| |
| |
| def get_sym_list(syms, default_names=None, default_shapes=None): |
| if syms is None and default_names is not None: |
| if default_shapes is not None: |
| return [mx.sym.Variable(name=name, shape=shape) for (name, shape) |
| in zip(default_names, default_shapes)] |
| else: |
| return [mx.sym.Variable(name=name) for name in default_names] |
| assert isinstance(syms, (list, tuple, mx.symbol.Symbol)) |
| if isinstance(syms, (list, tuple)): |
| if default_names is not None and len(syms) != len(default_names): |
| raise ValueError("Size of symbols do not match expectation. Received %d, Expected %d. " |
| "syms=%s, names=%s" %(len(syms), len(default_names), |
| str(list(sym.name for sym in syms)), |
| str(default_names))) |
| return list(syms) |
| else: |
| if default_names is not None and len(default_names) != 1: |
| raise ValueError("Size of symbols do not match expectation. Received 1, Expected %d. " |
| "syms=%s, names=%s" |
| % (len(default_names), str([syms.name]), str(default_names))) |
| return [syms] |
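
# Illustrative usage (hypothetical names): get_sym_list(None, default_names=['data'])
# creates [mx.sym.Variable('data')]; passing an existing Symbol or list of Symbols
# only validates the count against default_names and returns it as a list.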
| |
| |
| def get_numeric_list(values, typ, expected_len=None): |
| if isinstance(values, numbers.Number): |
| if expected_len is not None: |
| return [typ(values)] * expected_len |
| else: |
| return [typ(values)] |
| elif isinstance(values, (list, tuple)): |
| if expected_len is not None: |
| assert len(values) == expected_len |
| try: |
| ret = [typ(value) for value in values] |
| return ret |
        except ValueError:
            print("Need an iterable with numeric elements, received: %s" % str(values))
            sys.exit(1)
    else:
        raise ValueError("Unsupported value type, values=%s" % str(values))
| |
| |
| def get_int_list(values, expected_len=None): |
| return get_numeric_list(values, numpy.int32, expected_len) |
| |
| |
| def get_float_list(values, expected_len=None): |
| return get_numeric_list(values, numpy.float32, expected_len) |
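
# Examples: get_int_list(3, expected_len=2) returns [3, 3] (as numpy.int32) and
# get_float_list([1, 2]) returns [1.0, 2.0] (as numpy.float32).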
| |
| |
| def get_bucket_key(bucket_kwargs): |
| assert isinstance(bucket_kwargs, dict) |
| return tuple(bucket_kwargs.items()) |
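
# Example: get_bucket_key({'seq_len': 10}) returns (('seq_len', 10),), a hashable key
# that can identify a bucket configuration.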