/*!
 *  Copyright (c) 2016 by Contributors
 * \file symbol.hpp
 * \brief implementation of the symbol
 * \author Zhang Chen, Chuntao Hong
 */

#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_HPP_
#define CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_HPP_

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "dmlc/logging.h"
#include "mxnet-cpp/symbol.h"

#include "mxnet-cpp/op_suppl.h"

namespace mxnet {
namespace cpp {
/*!
 * \brief Accessor for the OpMap pointer used to look up operator creators.
 * \return reference to a function-local static pointer that is initialised
 *         with a newly allocated OpMap the first time this function runs.
 */
inline OpMap*& Symbol::op_map() {
  static OpMap* shared_op_map = new OpMap();
  return shared_op_map;
}
/*!
 * \brief Wrap an existing symbol handle.
 * \param handle handle that is stored in a newly created shared SymBlob
 */
inline Symbol::Symbol(SymbolHandle handle)
    : blob_ptr_(std::make_shared<SymBlob>(handle)) {}
/*!
 * \brief Create a variable symbol through MXSymbolCreateVariable.
 * \param name C string passed to the C API as the variable name
 */
inline Symbol::Symbol(const char *name) {
  SymbolHandle handle;
  // A non-zero return code from the C API trips the check.
  CHECK_EQ(MXSymbolCreateVariable(name, &(handle)), 0);
  // The new handle is stored in a shared SymBlob.
  blob_ptr_ = std::make_shared<SymBlob>(handle);
}
/*!
 * \brief Create a variable symbol from a std::string name.
 *
 * Delegates to the `const char *` constructor with `name.c_str()`.
 * \param name name of the variable
 */
inline Symbol::Symbol(const std::string &name)
    : Symbol(name.c_str()) {}
/*!
 * \brief Factory form of the std::string constructor.
 * \param name name of the variable
 * \return a Symbol constructed from `name`
 */
inline Symbol Symbol::Variable(const std::string &name) {
  return Symbol(name);
}
/*
 * Arithmetic operators taking another Symbol on the right-hand side.
 * Each one forwards (*this, rhs) to the correspondingly named helper
 * (_Plus, _Minus, _Mul, _Div, _Mod) and returns the resulting Symbol.
 */
inline Symbol Symbol::operator+(const Symbol &rhs) const {
  return _Plus(*this, rhs);
}
inline Symbol Symbol::operator-(const Symbol &rhs) const {
  return _Minus(*this, rhs);
}
inline Symbol Symbol::operator*(const Symbol &rhs) const {
  return _Mul(*this, rhs);
}
inline Symbol Symbol::operator/(const Symbol &rhs) const {
  return _Div(*this, rhs);
}
inline Symbol Symbol::operator%(const Symbol &rhs) const {
  return _Mod(*this, rhs);
}
/*
 * Arithmetic operators taking an mx_float scalar on the right-hand side.
 * Each one forwards (*this, scalar) to the correspondingly named helper
 * (_PlusScalar, _MinusScalar, _MulScalar, _DivScalar, _ModScalar).
 */
inline Symbol Symbol::operator+(mx_float scalar) const { return _PlusScalar(*this, scalar); }
inline Symbol Symbol::operator-(mx_float scalar) const { return _MinusScalar(*this, scalar); }
inline Symbol Symbol::operator*(mx_float scalar) const { return _MulScalar(*this, scalar); }
inline Symbol Symbol::operator/(mx_float scalar) const { return _DivScalar(*this, scalar); }
inline Symbol Symbol::operator%(mx_float scalar) const { return _ModScalar(*this, scalar); }
/*!
 * \brief Get one output of this symbol as a new Symbol.
 * \param index position of the requested output, passed to MXSymbolGetOutput
 * \return Symbol wrapping the handle produced by MXSymbolGetOutput
 */
inline Symbol Symbol::operator[](int index) {
  SymbolHandle out;
  // Check the return code like the other C API wrappers in this file do:
  // if the call fails, `out` may never have been written and must not be
  // wrapped in a Symbol.
  CHECK_EQ(MXSymbolGetOutput(GetHandle(), index, &out), 0);
  return Symbol(out);
}
/*!
 * \brief Get the output whose name equals `index`.
 *
 * Scans the names returned by ListOutputs() in order and returns the output
 * at the position of the first exact match. If no name matches, a FATAL
 * message is logged.
 * \param index name of the requested output
 * \return Symbol for the matching output
 */
inline Symbol Symbol::operator[](const std::string &index) {
  const auto output_names = ListOutputs();
  mx_uint position = 0;
  for (const auto &output_name : output_names) {
    if (output_name == index) {
      return (*this)[position];
    }
    ++position;
  }
  LOG(FATAL) << "Cannot find output that matches name " << index;
  // Fallback return after the fatal log so every path returns a value.
  return (*this)[0];
}
/*!
 * \brief Create a single Symbol that groups the given symbols.
 * \param symbols symbols whose handles are passed to MXSymbolCreateGroup
 * \return Symbol wrapping the handle produced by MXSymbolCreateGroup
 */
inline Symbol Symbol::Group(const std::vector<Symbol> &symbols) {
  std::vector<SymbolHandle> handle_list;
  handle_list.reserve(symbols.size());
  for (const auto &t : symbols) {
    handle_list.push_back(t.GetHandle());
  }
  SymbolHandle out;
  // Check the return code like the other C API wrappers in this file do, so
  // that a failed call never leads to wrapping an unset handle.
  CHECK_EQ(MXSymbolCreateGroup(handle_list.size(), handle_list.data(), &out), 0);
  return Symbol(out);
}
/*!
 * \brief Create a Symbol from a file via MXSymbolCreateFromFile.
 * \param file_name path handed to the C API
 * \return Symbol wrapping the created handle
 */
inline Symbol Symbol::Load(const std::string &file_name) {
  // The result is discarded; the call only triggers op_map()'s one-time
  // initialisation. NOTE(review): presumably the operator registry must exist
  // before a symbol is loaded -- confirm.
  op_map();
  SymbolHandle handle;
  // A non-zero return code from the C API trips the check.
  CHECK_EQ(MXSymbolCreateFromFile(file_name.c_str(), &(handle)), 0);
  return Symbol(handle);
}
/*!
 * \brief Create a Symbol from a JSON string via MXSymbolCreateFromJSON.
 * \param json_str JSON text handed to the C API
 * \return Symbol wrapping the created handle
 */
inline Symbol Symbol::LoadJSON(const std::string &json_str) {
  SymbolHandle handle;
  // A non-zero return code from the C API trips the check.
  CHECK_EQ(MXSymbolCreateFromJSON(json_str.c_str(), &(handle)), 0);
  return Symbol(handle);
}
/*!
 * \brief Save this symbol to a file via MXSymbolSaveToFile.
 * \param file_name path handed to the C API
 */
inline void Symbol::Save(const std::string &file_name) const {
  // A non-zero return code from the C API trips the check.
  CHECK_EQ(MXSymbolSaveToFile(GetHandle(), file_name.c_str()), 0);
}
/*!
 * \brief Serialise this symbol to JSON via MXSymbolSaveToJSON.
 * \return a std::string copy of the C string returned by the C API
 */
inline std::string Symbol::ToJSON() const {
  const char *out_json;
  // A non-zero return code from the C API trips the check.
  CHECK_EQ(MXSymbolSaveToJSON(GetHandle(), &out_json), 0);
  // Copy the returned C string into an owned std::string.
  return std::string(out_json);
}
/*!
 * \brief Obtain the internals symbol via MXSymbolGetInternals.
 * \return Symbol wrapping the handle produced by the C API
 */
inline Symbol Symbol::GetInternals() const {
  SymbolHandle handle;
  // A non-zero return code from the C API trips the check.
  CHECK_EQ(MXSymbolGetInternals(GetHandle(), &handle), 0);
  return Symbol(handle);
}
/*!
 * \brief Create a symbol for an operator and compose it with its inputs.
 *
 * Looks up the creator for `operator_name` in op_map(), creates the atomic
 * symbol with the given configuration key/value pairs, then composes it with
 * the given inputs.
 *
 * \param operator_name name used to look up the creator; also the name passed
 *        to MXSymbolCompose
 * \param name NOTE(review): currently unused -- the composed symbol is named
 *        after `operator_name`. Left unchanged so existing callers keep the
 *        names they get today; confirm whether `name` was intended instead.
 * \param input_keys keys of the inputs passed to MXSymbolCompose
 * \param input_values handles of the inputs; one entry per element of
 *        `input_keys` is read by the C API call
 * \param config_keys keys of the operator configuration
 * \param config_values values of the operator configuration; one entry per
 *        element of `config_keys` is read by the C API call
 */
inline Symbol::Symbol(const std::string &operator_name, const std::string &name,
               std::vector<const char *> input_keys,
               std::vector<SymbolHandle> input_values,
               std::vector<const char *> config_keys,
               std::vector<const char *> config_values) {
  SymbolHandle handle;
  AtomicSymbolCreator creator = op_map()->GetSymbolCreator(operator_name);
  // Both C API calls are now checked, matching the other wrappers in this
  // file; previously a failure here went unnoticed and an unset or
  // half-built handle was stored.
  CHECK_EQ(MXSymbolCreateAtomicSymbol(creator, config_keys.size(),
                                      config_keys.data(), config_values.data(),
                                      &handle),
           0);
  CHECK_EQ(MXSymbolCompose(handle, operator_name.c_str(), input_keys.size(),
                           input_keys.data(), input_values.data()),
           0);
  blob_ptr_ = std::make_shared<SymBlob>(handle);
}

/*!
 * \brief Copy this symbol via MXSymbolCopy.
 * \return Symbol wrapping the handle produced by the C API
 */
inline Symbol Symbol::Copy() const {
  SymbolHandle handle;
  // A non-zero return code from the C API trips the check.
  CHECK_EQ(MXSymbolCopy(GetHandle(), &handle), 0);
  return Symbol(handle);
}

inline std::vector<std::string> Symbol::ListArguments() const {
  std::vector<std::string> ret;
  mx_uint size;
  const char **sarr;
  MXSymbolListArguments(GetHandle(), &size, &sarr);
  for (mx_uint i = 0; i < size; ++i) {
    ret.push_back(std::string(sarr[i]));
  }
  return ret;
}
inline std::vector<std::string> Symbol::ListOutputs() const {
  std::vector<std::string> ret;
  mx_uint size;
  const char **sarr;
  MXSymbolListOutputs(GetHandle(), &size, &sarr);
  for (mx_uint i = 0; i < size; ++i) {
    ret.push_back(std::string(sarr[i]));
  }
  return ret;
}
inline std::vector<std::string> Symbol::ListAuxiliaryStates() const {
  std::vector<std::string> ret;
  mx_uint size;
  const char **sarr;
  MXSymbolListAuxiliaryStates(GetHandle(), &size, &sarr);
  for (mx_uint i = 0; i < size; ++i) {
    ret.push_back(std::string(sarr[i]));
  }
  return ret;
}

inline void Symbol::InferShape(
    const std::map<std::string, std::vector<mx_uint> > &arg_shapes,
    std::vector<std::vector<mx_uint> > *in_shape,
    std::vector<std::vector<mx_uint> > *aux_shape,
    std::vector<std::vector<mx_uint> > *out_shape) const {

  std::vector<const char *> keys;
  std::vector<mx_uint> arg_ind_ptr;
  std::vector<mx_uint> arg_shape_data;

  for (const auto &arg : arg_shapes) {
    keys.push_back(arg.first.c_str());
    arg_ind_ptr.push_back(arg_shape_data.size());
    for (auto i : arg.second) {
      arg_shape_data.push_back(i);
    }
  }
  arg_ind_ptr.push_back(arg_shape_data.size());

  mx_uint in_shape_size;
  const mx_uint *in_shape_ndim;
  const mx_uint **in_shape_data;
  mx_uint out_shape_size;
  const mx_uint *out_shape_ndim;
  const mx_uint **out_shape_data;
  mx_uint aux_shape_size;
  const mx_uint *aux_shape_ndim;
  const mx_uint **aux_shape_data;
  int complete;

  CHECK_EQ(MXSymbolInferShape(GetHandle(), keys.size(), keys.data(),
                              arg_ind_ptr.data(), arg_shape_data.data(),
                              &in_shape_size, &in_shape_ndim, &in_shape_data,
                              &out_shape_size, &out_shape_ndim, &out_shape_data,
                              &aux_shape_size, &aux_shape_ndim, &aux_shape_data,
                              &complete),
           0);

  if (complete) {
    for (mx_uint i = 0; i < in_shape_size; ++i) {
      in_shape->push_back(std::vector<mx_uint>());
      for (mx_uint j = 0; j < in_shape_ndim[i]; ++j) {
        (*in_shape)[i].push_back(in_shape_data[i][j]);
      }
    }
    for (mx_uint i = 0; i < aux_shape_size; ++i) {
      aux_shape->push_back(std::vector<mx_uint>());
      for (mx_uint j = 0; j < aux_shape_ndim[i]; ++j) {
        (*aux_shape)[i].push_back(aux_shape_data[i][j]);
      }
    }
    for (mx_uint i = 0; i < out_shape_size; ++i) {
      out_shape->push_back(std::vector<mx_uint>());
      for (mx_uint j = 0; j < out_shape_ndim[i]; ++j) {
        (*out_shape)[i].push_back(out_shape_data[i][j]);
      }
    }
  }
}

inline void Symbol::InferExecutorArrays(
    const Context &context, std::vector<NDArray> *arg_arrays,
    std::vector<NDArray> *grad_arrays, std::vector<OpReqType> *grad_reqs,
    std::vector<NDArray> *aux_arrays,
    const std::map<std::string, NDArray> &args_map,
    const std::map<std::string, NDArray> &arg_grad_store,
    const std::map<std::string, OpReqType> &grad_req_type,
    const std::map<std::string, NDArray> &aux_map) const {

  const auto arg_name_list = ListArguments();
  std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
  std::map<std::string, std::vector<mx_uint> > arg_shapes;

  for (const auto &arg_name : arg_name_list) {
    auto iter = args_map.find(arg_name);
    if (iter != args_map.end()) {
      arg_shapes[arg_name] = iter->second.GetShape();
    }
  }

  InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);

  for (size_t i = 0; i < in_shapes.size(); ++i) {
    const auto &shape = in_shapes[i];
    const auto &arg_name = arg_name_list[i];
    auto iter_arg = args_map.find(arg_name);
    if (iter_arg != args_map.end()) {
      arg_arrays->push_back(iter_arg->second);
    } else {
      arg_arrays->push_back(NDArray(shape, context, false));
      NDArray::SampleGaussian(0, 1, &arg_arrays->back());
    }
    auto iter_grad = arg_grad_store.find(arg_name);
    if (iter_grad != arg_grad_store.end()) {
      grad_arrays->push_back(iter_grad->second);
    } else {
      grad_arrays->push_back(NDArray(shape, context, false));
    }
    auto iter_req = grad_req_type.find(arg_name);
    if (iter_req != grad_req_type.end()) {
      grad_reqs->push_back(iter_req->second);
    } else if (arg_name.rfind("data") == arg_name.length() - 4
            || arg_name.rfind("label") == arg_name.length() - 5) {
      grad_reqs->push_back(OpReqType::kNullOp);
    } else {
      grad_reqs->push_back(OpReqType::kWriteTo);
    }
  }

  const auto aux_name_list = ListAuxiliaryStates();
  for (size_t i = 0; i < aux_shapes.size(); ++i) {
    const auto &shape = aux_shapes[i];
    const auto &aux_name = aux_name_list[i];
    auto iter_aux = aux_map.find(aux_name);
    if (iter_aux != aux_map.end()) {
      aux_arrays->push_back(iter_aux->second);
    } else {
      aux_arrays->push_back(NDArray(shape, context, false));
      NDArray::SampleGaussian(0, 1, &aux_arrays->back());
    }
  }
}
inline void Symbol::InferArgsMap(
    const Context &context, std::map<std::string, NDArray> *args_map,
    const std::map<std::string, NDArray> &known_args) const {

  const auto arg_name_list = ListArguments();
  std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
  std::map<std::string, std::vector<mx_uint> > arg_shapes;

  for (const auto &arg_name : arg_name_list) {
    auto iter = known_args.find(arg_name);
    if (iter != known_args.end()) {
      arg_shapes[arg_name] = iter->second.GetShape();
    }
  }

  InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);

  for (size_t i = 0; i < in_shapes.size(); ++i) {
    const auto &shape = in_shapes[i];
    const auto &arg_name = arg_name_list[i];
    auto iter_arg = known_args.find(arg_name);
    if (iter_arg != known_args.end()) {
      (*args_map)[arg_name] = iter_arg->second;
    } else {
      (*args_map)[arg_name] = NDArray(shape, context, false);
      NDArray::SampleGaussian(0, 1, &(*args_map)[arg_name]);
    }
  }
}

/*!
 * \brief Create an Executor, inferring every array the caller did not supply.
 *
 * The argument, gradient, gradient-request and auxiliary arrays are produced
 * by InferExecutorArrays() and then handed to a newly allocated Executor.
 *
 * \param context context used for inference and passed to the Executor
 * \param args_map known argument arrays, keyed by argument name
 * \param arg_grad_store known gradient arrays, keyed by argument name
 * \param grad_req_type known gradient request types, keyed by argument name
 * \param aux_map known auxiliary arrays, keyed by auxiliary state name
 * \return pointer to an Executor allocated with `new`
 */
inline Executor *Symbol::SimpleBind(
    const Context &context, const std::map<std::string, NDArray> &args_map,
    const std::map<std::string, NDArray> &arg_grad_store,
    const std::map<std::string, OpReqType> &grad_req_type,
    const std::map<std::string, NDArray> &aux_map) {
  std::vector<NDArray> inferred_args;
  std::vector<NDArray> inferred_grads;
  std::vector<OpReqType> inferred_reqs;
  std::vector<NDArray> inferred_aux;

  InferExecutorArrays(context,
                      &inferred_args, &inferred_grads, &inferred_reqs, &inferred_aux,
                      args_map, arg_grad_store, grad_req_type, aux_map);

  Executor *executor = new Executor(*this, context, inferred_args,
                                    inferred_grads, inferred_reqs, inferred_aux);
  return executor;
}

/*!
 * \brief Create an Executor from caller-supplied arrays.
 *
 * All parameters are forwarded unchanged, together with `*this`, to the
 * Executor constructor.
 *
 * \param context context passed to the Executor
 * \param arg_arrays argument arrays
 * \param grad_arrays gradient arrays
 * \param grad_reqs gradient request types
 * \param aux_arrays auxiliary state arrays
 * \param group_to_ctx map from group name to Context
 * \param shared_exec Executor pointer forwarded to the new Executor
 * \return pointer to an Executor allocated with `new`
 */
inline Executor *Symbol::Bind(const Context &context,
                       const std::vector<NDArray> &arg_arrays,
                       const std::vector<NDArray> &grad_arrays,
                       const std::vector<OpReqType> &grad_reqs,
                       const std::vector<NDArray> &aux_arrays,
                       const std::map<std::string, Context> &group_to_ctx,
                       Executor *shared_exec) {
  Executor *executor = new Executor(*this, context,
                                    arg_arrays, grad_arrays, grad_reqs, aux_arrays,
                                    group_to_ctx, shared_exec);
  return executor;
}
/*
 * Arithmetic operators with an mx_float scalar on the left-hand side.
 * Addition and multiplication reuse the member operators with the operands
 * swapped; subtraction, division and modulo forward (lhs, rhs) to the
 * dedicated _RMinusScalar, _RDivScalar and _RModScalar helpers.
 */
inline Symbol operator+(mx_float lhs, const Symbol &rhs) {
  return rhs + lhs;
}
inline Symbol operator-(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RMinusScalar(lhs, rhs); }
inline Symbol operator*(mx_float lhs, const Symbol &rhs) {
  return rhs * lhs;
}
inline Symbol operator/(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RDivScalar(lhs, rhs); }
inline Symbol operator%(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RModScalar(lhs, rhs); }
}  // namespace cpp
}  // namespace mxnet

#endif  // CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_HPP_
