/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file c_api.cc
 * \brief C API of mxnet
 */
#include <vector>
#include <sstream>
#include <string>
#include <mutex>
#include <memory>
#include <functional>
#include <unordered_map>
#include <utility>
#include "dmlc/base.h"
#include "dmlc/logging.h"
#include "dmlc/io.h"
#include "dmlc/memory_io.h"
#include "dmlc/recordio.h"
#include "dmlc/omp.h"
#include "mxnet/base.h"
#include "mxnet/ndarray.h"
#include "mxnet/operator.h"
#include "mxnet/io.h"
#include "mxnet/c_api.h"
#include "mxnet/kvstore.h"
#include "mxnet/rtc.h"
#include "mxnet/storage.h"
#include "mxnet/libinfo.h"
#include "mxnet/imperative.h"
#include "mxnet/lib_api.h"
#include "../initialize.h"
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../operator/operator_common.h"
#include "../operator/subgraph/common.h"
#include "../operator/tensor/matrix_op-inl.h"
#include "../operator/tvmop/op_module.h"
#include "../operator/subgraph/partitioner/custom_subgraph_property.h"
#include "../operator/subgraph/subgraph_property.h"
#include "../common/alm.h"
#include "../common/utils.h"
#include "../profiler/profiler.h"
#include "../serialization/cnpy.h"
#include "miniz.h"
#include "nnvm/pass_functions.h"

// FTZ only applies to SSE and AVX instructions.
#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
#define SUPPORT_FTZ_DMZ 1
#else
#define SUPPORT_FTZ_DMZ 0
#endif

#if SUPPORT_FTZ_DMZ
#include <immintrin.h>
#include <xmmintrin.h>
#endif
#if SUPPORT_FTZ_DMZ && !defined(_MSC_VER)
#include <x86intrin.h>
#endif

#if MXNET_USE_CUDA
#include <cuda_profiler_api.h>
#endif
#include "../common/cuda/nvtx.h"

using namespace mxnet;

/*!
 * \brief Export the metadata of one function-registry entry through C out-parameters.
 *
 * Internal helper used to implement MXSymbolGetAtomicSymbolInfo and MXFuncGetInfo.
 * The argument names, type-info strings and descriptions are written back to back
 * (in that order) into the thread-local ret_vec_charp buffer, and the three
 * returned tables point at consecutive slices of that buffer.
 *
 * \param e registry entry to describe
 * \param name receives e->name
 * \param description receives e->description
 * \param num_args receives the number of arguments of the entry
 * \param arg_names receives the table of argument names
 * \param arg_type_infos receives the table of argument type-info strings
 * \param arg_descriptions receives the table of argument descriptions
 * \param return_type receives e->return_type; may be null, in which case it is skipped
 * \return status code produced by the API_END macro
 */
template <typename FunRegType>
inline int MXAPIGetFunctionRegInfo(const FunRegType* e,
                                   const char** name,
                                   const char** description,
                                   uint32_t* num_args,
                                   const char*** arg_names,
                                   const char*** arg_type_infos,
                                   const char*** arg_descriptions,
                                   const char** return_type) {
  MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get();

  API_BEGIN();
  const auto& params       = e->arguments;
  const size_t param_count = params.size();

  // scalar fields of the entry
  *name        = e->name.c_str();
  *description = e->description.c_str();
  *num_args    = static_cast<uint32_t>(param_count);
  if (return_type != nullptr) {
    *return_type = e->return_type.c_str();
  }

  // rebuild the thread-local table: [names | type infos | descriptions]
  auto& table = ret->ret_vec_charp;
  table.clear();
  for (const auto& param : params) {
    table.push_back(param.name.c_str());
  }
  for (const auto& param : params) {
    table.push_back(param.type_info_str.c_str());
  }
  for (const auto& param : params) {
    table.push_back(param.description.c_str());
  }

  // hand out the three slices of the table
  auto table_begin  = dmlc::BeginPtr(table);
  *arg_names        = table_begin;
  *arg_type_infos   = table_begin + param_count;
  *arg_descriptions = table_begin + (param_count * 2);
  API_END();
}

// NOTE: return value is added in API_END

/*!
 * \brief Gather the messages reported by an extension library into one traceback string.
 *
 * Each message is rendered on its own line as "\t[i] message".
 *
 * \param msgSize callback returning the number of messages currently available
 * \param msgGet callback that stores a pointer to the i-th message in its out-parameter
 * \return empty string when msgSize() reports no messages, otherwise
 *         "\nExtension Traceback:\n" followed by one line per message
 */
std::string getExtensionMsgs(mxnet::ext::msgSize_t msgSize, mxnet::ext::msgGet_t msgGet) {
  std::string str;
  if (msgSize() > 0) {
    str = "\nExtension Traceback:\n";
    for (int i = 0; i < msgSize(); i++) {
      // start from null so that a callback which does not fill the out-parameter
      // cannot make us read an indeterminate pointer
      const char* tmp = nullptr;
      msgGet(i, &tmp);
      // format: [i] message
      str += "\t[";
      str += std::to_string(i);
      str += "] ";
      // constructing/appending a std::string from a null pointer is undefined behaviour,
      // so a missing message is rendered as an empty one
      if (tmp != nullptr) {
        str += tmp;
      }
      str += "\n";
    }
  }
  return str;
}

/*!
 * \brief Common compute function dispatcher for forward/backward and stateful forward/backward
 * state_ptr will be nullptr for regular ops; fcomp_fp is nullptr for stateful ops
 *
 * Flattens the input/output NDArrays into parallel arrays of plain C values, builds
 * allocation callbacks on top of the operator's requested resources and then invokes
 * either callFComp (regular op) or callFStatefulComp (stateful op) from the extension
 * library. A zero return value from the library call aborts through CHECK, with the
 * extension's queued messages appended to the error text.
 *
 * \param op_name operator name, only used in error messages
 * \param callFComp library entry point used for regular ops
 * \param fcomp_fp compute function passed to callFComp; nullptr for stateful ops
 * \param attrs node attributes; its dict is passed to callFComp as key/value arrays
 *              (only dereferenced when fcomp_fp is not nullptr)
 * \param callFStatefulComp library entry point used for stateful ops
 * \param stateful_forward_flag forwarded unchanged as the first argument of callFStatefulComp
 * \param state_ptr operator state holding a CustomStatefulOpWrapper; nullptr for regular ops
 * \param ctx operator context; ctx.requested must hold at least two resources, index 0
 *            being used for temporary workspace and index 1 for the parallel random generator
 * \param inputs input arrays
 * \param req write request types (not forwarded to the library by this function)
 * \param outputs output arrays
 * \param msgSize callback returning the number of messages queued by the extension
 * \param msgGet callback retrieving one queued extension message
 */
void CustomFComputeDispatcher(const std::string op_name,
                              const mxnet::ext::opCallFComp_t callFComp,
                              const mxnet::ext::fcomp_t fcomp_fp,
                              const nnvm::NodeAttrs* attrs,
                              const mxnet::ext::opCallFStatefulComp_t callFStatefulComp,
                              int stateful_forward_flag,
                              const OpStatePtr* state_ptr,
                              const OpContext& ctx,
                              const std::vector<NDArray>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<NDArray>& outputs,
                              mxnet::ext::msgSize_t msgSize,
                              mxnet::ext::msgGet_t msgGet) {
  using namespace mxnet::ext;

  std::vector<void*> in_data, out_data;
  std::vector<const int64_t*> in_shapes, out_shapes;
  std::vector<int> in_dims, out_dims;
  std::vector<int> in_types, out_types;
  std::vector<size_t> in_verIDs, out_verIDs;
  std::vector<const char*> in_dev_type, out_dev_type;
  std::vector<int> in_dev_id, out_dev_id;
  std::vector<NDArray> conv_dnnl;  // converted NDArrays from DNNL format

  // Extra data for sparse inputs and outputs.
  // stype codes used below: 0 = default (dense), 1 = row sparse, 2 = CSR
  std::vector<int> in_stypes(inputs.size(), 0), out_stypes(outputs.size(), 0);
  std::vector<void*> in_indices(inputs.size(), nullptr), out_indices(outputs.size(), nullptr);
  std::vector<void*> in_indptr(inputs.size(), nullptr), out_indptr(outputs.size(), nullptr);
  std::vector<int64_t> in_indices_shapes(inputs.size(), 0), out_indices_shapes(outputs.size(), 0);
  std::vector<int64_t> in_indptr_shapes(inputs.size(), 0), out_indptr_shapes(outputs.size(), 0);

#if MXNET_USE_ONEDNN == 1
  // The loop below keeps the address of elements of conv_dnnl (in_nd) and hands pointers
  // obtained through them (e.g. shape().data()) to the library. A reallocation of the
  // vector while it is being filled would relocate those elements and could leave the
  // stored pointers dangling, so reserve the maximum possible size (one converted array
  // per input) up front.
  conv_dnnl.reserve(inputs.size());
#endif

  // convert inputs/outputs NDArray to C types to be passed to lib_api.h
  for (size_t i = 0; i < inputs.size(); i++) {
    NDArray const* in_nd = &(inputs[i]);
#if MXNET_USE_ONEDNN == 1
    // reorder data if in DNNL format
    if (in_nd->IsDNNLData()) {
      // convert from DNNL
      conv_dnnl.push_back(in_nd->Reorder2Default());
      in_nd = &(conv_dnnl.back());
    }
#endif
    // pull out parts to pass over to library
    in_data.push_back(in_nd->data().dptr_);
    in_shapes.push_back(in_nd->shape().data());
    in_dims.push_back(in_nd->shape().ndim());
    in_types.push_back(in_nd->dtype());
    in_verIDs.push_back(in_nd->version());
    // string repr of supported context for custom library, currently only "cpu" and "gpu"
    const char* ctx_str = in_nd->ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
    in_dev_type.push_back(ctx_str);

    in_dev_id.push_back(in_nd->ctx().real_dev_id());
    // sparse auxiliary arrays are read from the original input array
    if (inputs[i].storage_type() == mxnet::kRowSparseStorage) {
      in_stypes[i]         = 1;
      in_indices[i]        = inputs[i].aux_data(rowsparse::kIdx).dptr_;
      in_indices_shapes[i] = inputs[i].aux_shape(rowsparse::kIdx).Size();
    } else if (inputs[i].storage_type() == mxnet::kCSRStorage) {
      in_stypes[i]         = 2;
      in_indices[i]        = inputs[i].aux_data(csr::kIdx).dptr_;
      in_indptr[i]         = inputs[i].aux_data(csr::kIndPtr).dptr_;
      in_indices_shapes[i] = inputs[i].aux_shape(csr::kIdx).Size();
      in_indptr_shapes[i]  = inputs[i].aux_shape(csr::kIndPtr).Size();
    }
  }

  for (size_t i = 0; i < outputs.size(); i++) {
    out_data.push_back(outputs[i].data().dptr_);
    out_shapes.push_back(outputs[i].shape().data());
    out_dims.push_back(outputs[i].shape().ndim());
    out_types.push_back(outputs[i].dtype());
    out_verIDs.push_back(outputs[i].version());
    const char* ctx_str = outputs[i].ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
    out_dev_type.push_back(ctx_str);
    out_dev_id.push_back(outputs[i].ctx().real_dev_id());

    if (outputs[i].storage_type() == mxnet::kRowSparseStorage) {
      out_stypes[i]         = 1;
      out_indices[i]        = outputs[i].aux_data(rowsparse::kIdx).dptr_;
      out_indices_shapes[i] = outputs[i].aux_shape(rowsparse::kIdx).Size();
    } else if (outputs[i].storage_type() == mxnet::kCSRStorage) {
      out_stypes[i]         = 2;
      out_indices[i]        = outputs[i].aux_data(csr::kIdx).dptr_;
      out_indptr[i]         = outputs[i].aux_data(csr::kIndPtr).dptr_;
      out_indices_shapes[i] = outputs[i].aux_shape(csr::kIdx).Size();
      out_indptr_shapes[i]  = outputs[i].aux_shape(csr::kIndPtr).Size();
    }
  }

  // get memory resource and mxnet backend streams
  CHECK(ctx.requested.size() >= 2)
      << "Custom operator should register at least memory resource and parallel random resource";
  const Resource& resource                = ctx.requested.at(0);
  mshadow::Stream<mxnet::cpu>* cpu_stream = ctx.get_stream<mxnet::cpu>();
  mshadow::Stream<mxnet::gpu>* gpu_stream = ctx.get_stream<mxnet::gpu>();

  // create lambda that captures stream & resource objects
  // this temp workspace holds memory allocated by custom library via OpResource
  auto cpu_alloc = [&](int size) {
    mshadow::Tensor<mxnet::cpu, 1, char> workspace =
        resource.get_space_typed<mxnet::cpu, 1, char>(mshadow::Shape1(size), cpu_stream);
    return workspace.dptr_;
  };
  auto gpu_alloc = [&](int size) {
    mshadow::Tensor<mxnet::gpu, 1, char> workspace =
        resource.get_space_typed<mxnet::gpu, 1, char>(mshadow::Shape1(size), gpu_stream);
    return workspace.dptr_;
  };

  // create lambda that allocates memory for sparse and
  // returns allocated arrays for data, indices and indptr.
  // idxptr_len == 0 selects the row-sparse layout (indptr is left untouched in that case)
  auto sparse_alloc = [&](int index,
                          int indices_len,
                          int idxptr_len,
                          void** data,
                          int64_t** indices,
                          int64_t** indptr) {
    if (idxptr_len == 0) {
      // Row Sparse
      outputs[index].CheckAndAlloc({mshadow::Shape1(indices_len)});
      *data    = outputs[index].data().dptr_;
      *indices = reinterpret_cast<int64_t*>(outputs[index].aux_data(rowsparse::kIdx).dptr_);
    } else {
      // CSR
      outputs[index].CheckAndAlloc({mshadow::Shape1(idxptr_len), mshadow::Shape1(indices_len)});
      *data    = outputs[index].data().dptr_;
      *indices = reinterpret_cast<int64_t*>(outputs[index].aux_data(csr::kIdx).dptr_);
      *indptr  = reinterpret_cast<int64_t*>(outputs[index].aux_data(csr::kIndPtr).dptr_);
    }
  };

  // create no-capture lambda so that we can cast it to function pointer
  // lambda with captures cannot be cast to function pointer and pass to lib_api.h
  // this needs to be a lambda function so that we can do the decltype cast
  typedef decltype(cpu_alloc) alloc_type_cpu;
  auto cpu_malloc = [](void* _cpu_alloc, int size) {
    // cast the void* argument to the type for the cpu_alloc lambda function
    alloc_type_cpu* cpualloc = static_cast<alloc_type_cpu*>(_cpu_alloc);
    // call cpu_alloc to actually allocate memory and return the pointer
    return static_cast<void*>((*cpualloc)(size));
  };

  using alloc_type_gpu = decltype(gpu_alloc);
  auto gpu_malloc      = [](void* _gpu_alloc, int size) {
    alloc_type_gpu* gpualloc = static_cast<alloc_type_gpu*>(_gpu_alloc);
    return static_cast<void*>((*gpualloc)(size));
  };

  using alloc_type_sparse = decltype(sparse_alloc);
  auto sparse_malloc      = [](void* _sparse_alloc,
                          int index,
                          int indices_len,
                          int idxptr_len,
                          void** data,
                          int64_t** indices,
                          int64_t** indptr) {
    alloc_type_sparse* sparsealloc = static_cast<alloc_type_sparse*>(_sparse_alloc);
    (*sparsealloc)(index, indices_len, idxptr_len, data, indices, indptr);
  };

  // get actual cudaStream_t out of mxnet gpu stream and pass to lib_api.h
  // (only looked up when the first input or first output lives on a GPU)
  void* cuda_stream = nullptr;
#if MXNET_USE_CUDA
  if ((inputs.size() > 0 && inputs[0].ctx().dev_mask() == Context::kGPU) ||
      (outputs.size() > 0 && outputs[0].ctx().dev_mask() == Context::kGPU)) {
    cuda_stream = static_cast<void*>(gpu_stream->stream_);
  }
#endif

  // get mxnet initialized and seeded RNG states and pass to lib_api.h
  void *rng_cpu_states = nullptr, *rng_gpu_states = nullptr;
  using mxnet::common::random::RandGenerator;
  RandGenerator<cpu, float>* pgen_cpu = ctx.requested.at(1).get_parallel_random<cpu, float>();
  rng_cpu_states                      = pgen_cpu->GetStates();
#if MXNET_USE_CUDA
  RandGenerator<gpu, float>* pgen_gpu = ctx.requested.at(1).get_parallel_random<gpu, float>();
  rng_gpu_states                      = pgen_gpu->GetStates();
#endif

  // exactly one of the two modes (regular / stateful) must be selected by the caller
  CHECK((fcomp_fp != nullptr && state_ptr == nullptr) ||
        (fcomp_fp == nullptr && state_ptr != nullptr))
      << "Can only register either regular op or stateful op for '" << op_name << "'";

  if (fcomp_fp != nullptr) {
    // convert attributes to vector of char*
    std::vector<const char*> attr_keys, attr_vals;
    for (auto& kv : attrs->dict) {
      attr_keys.push_back(kv.first.c_str());
      attr_vals.push_back(kv.second.c_str());
    }

    // call fcompute function
    int retval       = callFComp(fcomp_fp,
                           attr_keys.data(),
                           attr_vals.data(),
                           attr_keys.size(),
                           in_shapes.data(),
                           in_dims.data(),
                           in_data.data(),
                           in_types.data(),
                           in_verIDs.data(),
                           in_dev_type.data(),
                           in_dev_id.data(),
                           in_data.size(),
                           out_shapes.data(),
                           out_dims.data(),
                           out_data.data(),
                           out_types.data(),
                           out_verIDs.data(),
                           out_dev_type.data(),
                           out_dev_id.data(),
                           out_data.size(),
                           cpu_malloc,
                           &cpu_alloc,
                           gpu_malloc,
                           &gpu_alloc,
                           cuda_stream,
                           sparse_malloc,
                           &sparse_alloc,
                           in_stypes.data(),
                           out_stypes.data(),
                           in_indices.data(),
                           out_indices.data(),
                           in_indptr.data(),
                           out_indptr.data(),
                           in_indices_shapes.data(),
                           out_indices_shapes.data(),
                           in_indptr_shapes.data(),
                           out_indptr_shapes.data(),
                           rng_cpu_states,
                           rng_gpu_states);
    // collect whatever the extension reported so it can be attached to a failure
    std::string msgs = getExtensionMsgs(msgSize, msgGet);
    CHECK(retval) << "Error calling FCompute for custom operator '" << op_name << "'" << msgs;
  }

  if (state_ptr != nullptr) {
    // retrieve op state object created from CreateOpState
    CustomStatefulOpWrapper& op     = state_ptr->get_state<CustomStatefulOpWrapper>();
    CustomStatefulOp* state_op_inst = op.get_instance();
    std::string msgs                = getExtensionMsgs(msgSize, msgGet);
    CHECK(state_op_inst != nullptr)
        << "Error custom stateful operator is null for operator '" << op_name << "'" << msgs;

    // call fcompute function
    int retval = callFStatefulComp(stateful_forward_flag,
                                   state_op_inst,
                                   in_shapes.data(),
                                   in_dims.data(),
                                   in_data.data(),
                                   in_types.data(),
                                   in_verIDs.data(),
                                   in_dev_type.data(),
                                   in_dev_id.data(),
                                   in_data.size(),
                                   out_shapes.data(),
                                   out_dims.data(),
                                   out_data.data(),
                                   out_types.data(),
                                   out_verIDs.data(),
                                   out_dev_type.data(),
                                   out_dev_id.data(),
                                   out_data.size(),
                                   cpu_malloc,
                                   &cpu_alloc,
                                   gpu_malloc,
                                   &gpu_alloc,
                                   cuda_stream,
                                   sparse_malloc,
                                   &sparse_alloc,
                                   in_stypes.data(),
                                   out_stypes.data(),
                                   in_indices.data(),
                                   out_indices.data(),
                                   in_indptr.data(),
                                   out_indptr.data(),
                                   in_indices_shapes.data(),
                                   out_indices_shapes.data(),
                                   in_indptr_shapes.data(),
                                   out_indptr_shapes.data(),
                                   rng_cpu_states,
                                   rng_gpu_states);
    // refresh the messages: the call above may have queued new ones
    msgs       = getExtensionMsgs(msgSize, msgGet);
    CHECK(retval) << "Error calling FStatefulCompute for custom operator '" << op_name << "'"
                  << msgs;
  }
}

/*!
 * \brief Register one custom operator (and, when applicable, its backward operator)
 *        in the nnvm::Op registry.
 *
 * The functional parameters (resc_req ... grad_reg) are callables supplied by the caller
 * and are installed as the corresponding operator attributes. If an operator with the
 * same name already exists, its argument list is cleared and all attributes are set
 * with priority level 11 instead of 10.
 *
 * \param name operator name used for the registry lookup
 * \param name_str operator name as std::string; used for error messages and to build
 *                 the "_backward_" operator name
 * \param isSubgraphOp selects the subgraph variants of the inference / input-count functions
 * \param resc_req resource request function (installed on forward and backward op)
 * \param attr_parser attribute parser (non-subgraph ops only)
 * \param num_inputs number-of-inputs function (non-subgraph ops only)
 * \param num_outputs number-of-outputs function (non-subgraph ops only)
 * \param num_inouts number-of-inputs function of the backward op (non-subgraph ops only)
 * \param infer_type type inference function (non-subgraph ops only)
 * \param infer_shape shape inference function (non-subgraph ops only)
 * \param infer_storage_type storage type inference function; also installed on the backward op
 * \param mutate_inputs FMutateInputs function, installed only when mutate_fp is not nullptr
 * \param num_subgraph_inputs number-of-inputs function for subgraph ops
 * \param infer_subgraph_type type inference function for subgraph ops
 * \param infer_subgraph_shape shape inference function for subgraph ops
 * \param infer_subgraph_storage_type storage type inference function for subgraph ops
 * \param create_opstate FCreateOpState function, installed when createop_map is non-empty
 * \param grad_reg FGradient function, installed when a backward op is registered
 * \param mutate_fp library mutate-inputs function; only tested against nullptr here
 * \param createop_map per-context ("cpu"/"gpu") create-op-state functions; a non-empty
 *                     map makes the operator stateful
 * \param forward_ctx_map per-context ("cpu"/"gpu") forward compute functions
 * \param backward_ctx_map per-context ("cpu"/"gpu") backward compute functions
 * \param callFComp library entry point passed on to CustomFComputeDispatcher
 * \param callFStatefulComp library entry point passed on to CustomFComputeDispatcher
 * \param msgSize extension message-count callback passed on to CustomFComputeDispatcher
 * \param msgGet extension message getter passed on to CustomFComputeDispatcher
 */
template <typename RescReq,
          typename AttrParser,
          typename NumInputs,
          typename NumOutputs,
          typename NumInOuts,
          typename InferType,
          typename InferShape,
          typename InferSType,
          typename MutateInputs,
          typename SubgraphNumInputs,
          typename SubgraphInferType,
          typename SubgraphInferShape,
          typename SubgraphInferSType,
          typename CreateOpState,
          typename GradReg>
void registerOp(const char* name,
                const std::string& name_str,
                bool isSubgraphOp,
                RescReq resc_req,
                AttrParser attr_parser,
                NumInputs num_inputs,
                NumOutputs num_outputs,
                NumInOuts num_inouts,
                InferType infer_type,
                InferShape infer_shape,
                InferSType infer_storage_type,
                MutateInputs mutate_inputs,
                SubgraphNumInputs num_subgraph_inputs,
                SubgraphInferType infer_subgraph_type,
                SubgraphInferShape infer_subgraph_shape,
                SubgraphInferSType infer_subgraph_storage_type,
                CreateOpState create_opstate,
                GradReg grad_reg,
                mxnet::ext::mutateInputs_t mutate_fp,
                const std::unordered_map<std::string, mxnet::ext::createOpState_t>& createop_map,
                const std::unordered_map<std::string, mxnet::ext::fcomp_t>& forward_ctx_map,
                const std::unordered_map<std::string, mxnet::ext::fcomp_t>& backward_ctx_map,
                mxnet::ext::opCallFComp_t callFComp,
                mxnet::ext::opCallFStatefulComp_t callFStatefulComp,
                mxnet::ext::msgSize_t msgSize,
                mxnet::ext::msgGet_t msgGet) {
  using namespace mxnet::ext;

  // check if operator is already registered
  // (the lookup must happen before __REGISTER_OR_GET__, which is used to obtain the entry)
  const nnvm::Op* regOpPtr = dmlc::Registry<nnvm::Op>::Get()->Find(name);
  nnvm::Op& regOp          = dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(name);
  int plevel               = 10;
  if (regOpPtr != nullptr) {
    // overwrite registration of existing op with custom op
    regOp.arguments.clear();
    // set attribute with higher plevel (11) to allow re-registering once
    // TODO(samskalicky): enable constant overwriting of registration multiple times
    plevel++;
  }
  // define supported resources for both subgraph ops and regular ops
  regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
  if (!isSubgraphOp) {
    regOp.set_attr_parser(attr_parser);
    regOp.set_num_inputs(num_inputs);
    regOp.set_num_outputs(num_outputs);
    regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
    regOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
    regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
    // optionally add fmutate inputs if user specified a function
    if (mutate_fp != nullptr)
      regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs", mutate_inputs, plevel);
  } else {
    // subgraph ops: no attr parser is installed; output count and mutable inputs
    // come from the default subgraph-op helpers in mxnet::op
    using namespace mxnet::op;
    regOp.set_num_inputs(num_subgraph_inputs);
    regOp.set_num_outputs(DefaultSubgraphOpNumOutputs);
    regOp.set_attr<nnvm::FInferType>("FInferType", infer_subgraph_type, plevel);
    regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_subgraph_shape, plevel);
    regOp.set_attr<FInferStorageType>("FInferStorageType", infer_subgraph_storage_type, plevel);
    regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs, plevel);
  }
  // optionally add stateful forward
  if (createop_map.size() != 0) {
    regOp.set_attr<FCreateOpState>("FCreateOpState", create_opstate, plevel);
    // captures by value ([=]); NOTE(review): presumably because the registry keeps the
    // callable beyond the lifetime of this call - confirm against set_attr
    auto fstate_forward = [=](const OpStatePtr& state_ptr,
                              const OpContext& ctx,
                              const std::vector<NDArray>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<NDArray>& outputs) {
      // stateful forward: no fcomp / attrs, stateful_forward_flag = 1
      CustomFComputeDispatcher(name_str,
                               nullptr,
                               nullptr,
                               nullptr,
                               callFStatefulComp,
                               1,
                               &state_ptr,
                               ctx,
                               inputs,
                               req,
                               outputs,
                               msgSize,
                               msgGet);
    };
    // only register for the contexts the library provided a create-op-state function for
    if (createop_map.count("cpu") > 0)
      regOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_forward, plevel);
    if (createop_map.count("gpu") > 0)
      regOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", fstate_forward, plevel);
  } else {
    // stateless forward: one lambda serves both contexts and picks the compute
    // function from forward_ctx_map according to the run context's device mask
    auto forward_lambda = [=](const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<NDArray>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<NDArray>& outputs) {
      if (ctx.run_ctx.ctx.dev_mask() == Context::kCPU) {
        CHECK_GT(forward_ctx_map.count("cpu"), 0);
        fcomp_t fcomp = forward_ctx_map.at("cpu");
        CustomFComputeDispatcher(name_str,
                                 callFComp,
                                 fcomp,
                                 &attrs,
                                 nullptr,
                                 0,
                                 nullptr,
                                 ctx,
                                 inputs,
                                 req,
                                 outputs,
                                 msgSize,
                                 msgGet);
      } else if (ctx.run_ctx.ctx.dev_mask() == Context::kGPU) {
        CHECK_GT(forward_ctx_map.count("gpu"), 0);
        fcomp_t fcomp = forward_ctx_map.at("gpu");
        CustomFComputeDispatcher(name_str,
                                 callFComp,
                                 fcomp,
                                 &attrs,
                                 nullptr,
                                 0,
                                 nullptr,
                                 ctx,
                                 inputs,
                                 req,
                                 outputs,
                                 msgSize,
                                 msgGet);
      }
    };
    if (forward_ctx_map.count("cpu") > 0)
      regOp.set_attr<FComputeEx>("FComputeEx<cpu>", forward_lambda, plevel);
    if (forward_ctx_map.count("gpu") > 0)
      regOp.set_attr<FComputeEx>("FComputeEx<gpu>", forward_lambda, plevel);
  }
  // optionally add fgradient if user specified a function, or for stateful ops
  if (backward_ctx_map.size() != 0 || createop_map.size() != 0) {
    // the backward operator is registered under "_backward_<name>"
    std::string grad_name = "_backward_" + name_str;
    nnvm::Op& gradOp      = dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(grad_name);
    regOp.set_attr<nnvm::FGradient>("FGradient", grad_reg, plevel);
    gradOp.set_attr<nnvm::TIsBackward>("TIsBackward", true, plevel);
    gradOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
    gradOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);

    if (!isSubgraphOp) {
      // register attr parser and standard functions for non-subgraph ops
      gradOp.set_attr_parser(attr_parser);
      gradOp.set_num_inputs(num_inouts);
      gradOp.set_num_outputs(num_inputs);
    } else {
      // for subgraph ops use special functions that do not invoke attr_parser
      using namespace mxnet::op;
      auto grad_inouts = [=](const nnvm::NodeAttrs& attrs) {
        // for backward passes, inputs + outputs + input gradients (one for each output)
        uint32_t cnt = num_subgraph_inputs(attrs);
        cnt += 2 * DefaultSubgraphOpNumOutputs(attrs);
        return cnt;
      };
      gradOp.set_num_inputs(grad_inouts);
      gradOp.set_num_outputs(num_subgraph_inputs);
    }

    if (createop_map.size() != 0) {
      // for stateful operators
      gradOp.set_attr<bool>("TIsLayerOpBackward", true, plevel);
      auto fstate_backward = [=](const OpStatePtr& state_ptr,
                                 const OpContext& ctx,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
        // stateful backward: same as the stateful forward but with stateful_forward_flag = 0
        CustomFComputeDispatcher(name_str,
                                 nullptr,
                                 nullptr,
                                 nullptr,
                                 callFStatefulComp,
                                 0,
                                 &state_ptr,
                                 ctx,
                                 inputs,
                                 req,
                                 outputs,
                                 msgSize,
                                 msgGet);
      };
      // unlike the forward case, both contexts are registered unconditionally here
      gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_backward, plevel);
      gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", fstate_backward, plevel);
    } else {
      // for stateless operators
      // the per-context compute function is looked up once here and captured by the lambda
      if (backward_ctx_map.count("cpu") > 0) {
        fcomp_t fcomp_back_cpu   = backward_ctx_map.at("cpu");
        auto backward_cpu_lambda = [=](const nnvm::NodeAttrs& attrs,
                                       const OpContext& ctx,
                                       const std::vector<NDArray>& inputs,
                                       const std::vector<OpReqType>& req,
                                       const std::vector<NDArray>& outputs) {
          CustomFComputeDispatcher(name_str,
                                   callFComp,
                                   fcomp_back_cpu,
                                   &attrs,
                                   nullptr,
                                   0,
                                   nullptr,
                                   ctx,
                                   inputs,
                                   req,
                                   outputs,
                                   msgSize,
                                   msgGet);
        };
        gradOp.set_attr<FComputeEx>("FComputeEx<cpu>", backward_cpu_lambda, plevel);
      }
      if (backward_ctx_map.count("gpu") > 0) {
        fcomp_t fcomp_back_gpu   = backward_ctx_map.at("gpu");
        auto backward_gpu_lambda = [=](const nnvm::NodeAttrs& attrs,
                                       const OpContext& ctx,
                                       const std::vector<NDArray>& inputs,
                                       const std::vector<OpReqType>& req,
                                       const std::vector<NDArray>& outputs) {
          CustomFComputeDispatcher(name_str,
                                   callFComp,
                                   fcomp_back_gpu,
                                   &attrs,
                                   nullptr,
                                   0,
                                   nullptr,
                                   ctx,
                                   inputs,
                                   req,
                                   outputs,
                                   msgSize,
                                   msgGet);
        };
        gradOp.set_attr<FComputeEx>("FComputeEx<gpu>", backward_gpu_lambda, plevel);
      }
    }
  }
  // every custom op exposes a single variadic "data" argument in its signature
  regOp.add_argument("data", "NDArray[]", "Source inputs");
}

void registerOperators(void* lib,
                       int verbose,
                       mxnet::ext::msgSize_t msgSize,
                       mxnet::ext::msgGet_t msgGet) {
  using namespace mxnet::ext;

  // get C type interface functions
  opCallFree_t callFree = get_func<opCallFree_t>(lib, const_cast<char*>(MXLIB_OPCALLFREE_STR));

  opCallParseAttrs_t callParseAttrs =
      get_func<opCallParseAttrs_t>(lib, const_cast<char*>(MXLIB_OPCALLPARSEATTRS_STR));

  opCallInferShape_t callInferShape =
      get_func<opCallInferShape_t>(lib, const_cast<char*>(MXLIB_OPCALLINFERSHAPE_STR));

  opCallInferType_t callInferType =
      get_func<opCallInferType_t>(lib, const_cast<char*>(MXLIB_OPCALLINFERTYPE_STR));

  opCallInferSType_t callInferSType =
      get_func<opCallInferSType_t>(lib, const_cast<char*>(MXLIB_OPCALLINFERSTYPE_STR));

  opCallFComp_t callFComp = get_func<opCallFComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFCOMP_STR));

  opCallMutateInputs_t callMutateInputs =
      get_func<opCallMutateInputs_t>(lib, const_cast<char*>(MXLIB_OPCALLMUTATEINPUTS_STR));

  opCallCreateOpState_t callCreateOpState =
      get_func<opCallCreateOpState_t>(lib, const_cast<char*>(MXLIB_OPCALLCREATEOPSTATE_STR));

  opCallDestroyOpState_t callDestroyOpState =
      get_func<opCallDestroyOpState_t>(lib, const_cast<char*>(MXLIB_OPCALLDESTROYOPSTATE_STR));

  opCallFStatefulComp_t callFStatefulComp =
      get_func<opCallFStatefulComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFSTATEFULCOMP_STR));

  // get number of operators registered in the library
  opRegSize_t opRegSize = get_func<opRegSize_t>(lib, const_cast<char*>(MXLIB_OPREGSIZE_STR));
  int numOps            = opRegSize();
  if (verbose)
    LOG(INFO) << "Found " << numOps << " operators in library";

  /*
   * Get all custom operators implementation from custom library
   * loop and register each operator in the library to NNVM
   */
  opRegGet_t opRegGet = get_func<opRegGet_t>(lib, const_cast<char*>(MXLIB_OPREGGET_STR));
  for (int i = 0; i < numOps; i++) {
    const char* name;
    // function pointers holding implementation from custom library
    parseAttrs_t parse_fp = nullptr;
    inferType_t type_fp   = nullptr;
    inferSType_t stype_fp = nullptr;
    inferShape_t shape_fp = nullptr;
    // optional attributes
    mutateInputs_t mutate_fp = nullptr;
    bool isSubgraphOp        = false;
    int _isSubgraphOp        = 0;
    // lists of forward and backward function associated with each context
    const char **forward_ctx, **backward_ctx, **createop_ctx;
    fcomp_t *forward_fcomp, *backward_fcomp;
    createOpState_t* createop_fp;
    int forward_count, backward_count, createop_count;

    // main function to get custom operator implemenation from the custom library
    opRegGet(i,
             &name,
             &_isSubgraphOp,
             &forward_ctx,
             &forward_fcomp,
             &forward_count,
             &backward_ctx,
             &backward_fcomp,
             &backward_count,
             &createop_ctx,
             &createop_fp,
             &createop_count,
             &parse_fp,
             &type_fp,
             &stype_fp,
             &shape_fp,
             &mutate_fp);

    // construct maps of context to forward/backward custom library function
    std::unordered_map<std::string, fcomp_t> forward_ctx_map;
    std::unordered_map<std::string, fcomp_t> backward_ctx_map;
    std::unordered_map<std::string, createOpState_t> createop_map;
    for (int i = 0; i < forward_count; i++) {
      std::string ctx_str(forward_ctx[i]);
      forward_ctx_map[ctx_str] = forward_fcomp[i];
    }
    for (int i = 0; i < backward_count; i++) {
      std::string ctx_str(backward_ctx[i]);
      backward_ctx_map[ctx_str] = backward_fcomp[i];
    }
    for (int i = 0; i < createop_count; i++) {
      std::string ctx_str(createop_ctx[i]);
      createop_map[ctx_str] = createop_fp[i];
    }
    // set bool, dont pass bool across ABI boundary
    isSubgraphOp = _isSubgraphOp;

    // validate custom operator functions from the dynamic library
    if (!isSubgraphOp) {
      CHECK(parse_fp != nullptr) << "Error loading '" << name
                                 << "' custom op, ParseAttrs function was not set.";
      CHECK(forward_ctx_map.size() != 0 || createop_map.size() != 0)
          << "Error loading '" << name
          << "' custom op, Forward or CreateOpState function was not set.";
      CHECK(type_fp != nullptr) << "Error loading '" << name
                                << "' custom op, InferType function was not set.";
      CHECK(shape_fp != nullptr) << "Error loading '" << name
                                 << "' custom op, InferShape function was not set.";
    } else {
      CHECK(createop_map.size() != 0)
          << "Error loading '" << name
          << "' custom subgraph op, CreateOpState function was not set.";
    }
    if (verbose)
      LOG(INFO) << "\tOp[" << i << "] " << name;
    if (verbose && isSubgraphOp)
      LOG(INFO) << "\t\tisSubgraphOp";
    std::string name_str(name);

    /*
     * Below are a series of lambda functions that will be registered in the NNVM op registration
     * Each one has the standard MXNet signature and converts to types supported by externally
     * registered operators.
     */

    // lambda function to call parse attributes
    auto attr_parser = [=](const NodeAttrs* attrs) {
      // convert attributes to vector of char
      std::vector<const char*> attr_keys, attr_vals;
      for (auto& kv : attrs->dict) {
        attr_keys.push_back(kv.first.c_str());
        attr_vals.push_back(kv.second.c_str());
      }
      // convert subgraph symbol from node attributes to char*
      std::string subgraph_json;
      if (!attrs->subgraphs.empty()) {
        nnvm::Graph g;
        g.outputs     = attrs->subgraphs[0].get()->outputs;
        subgraph_json = nnvm::pass::SaveJSON(g);
        attr_keys.push_back(MX_STR_SUBGRAPH_SYM_JSON);
        attr_vals.push_back(subgraph_json.c_str());
      }

      int num_in  = -1;
      int num_out = -1;
      int retval  = callParseAttrs(
          parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), &num_in, &num_out);
      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(retval) << "Error calling ParseAttrs for custom operator '" << name_str << "'" << msgs;

      // return type void
    };

    // lambda function to call parse attributes and return the number of inputs
    auto num_inputs = [=](const NodeAttrs& attrs) {
      // convert attributes to vector of char
      std::vector<const char*> attr_keys, attr_vals;
      for (auto& kv : attrs.dict) {
        attr_keys.push_back(kv.first.c_str());
        attr_vals.push_back(kv.second.c_str());
      }

      int num_in  = -1;
      int num_out = -1;
      int retval  = callParseAttrs(
          parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), &num_in, &num_out);
      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(retval) << "Error calling ParseAttrs::num_inputs for custom operator '" << name_str
                    << "'" << msgs;

      // get extra inputs, if exists
      int extra_inputs = 0;
      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));

      return num_in + extra_inputs;
    };

    // lambda function to call parse attributes and return the number of inputs for subgraph ops
    auto num_subgraph_inputs = [=](const NodeAttrs& attrs) {
      // get number of inputs for subgraph
      int num_in = mxnet::op::DefaultSubgraphOpNumInputs(attrs);

      // get extra inputs, if exists
      int extra_inputs = 0;
      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));

      return num_in + extra_inputs;
    };

    // lambda function to call parse attributes and return the number of outputs
    auto num_outputs = [=](const NodeAttrs& attrs) {
      // convert attributes to vector of char*
      std::vector<const char*> attr_keys, attr_vals;
      for (auto& kv : attrs.dict) {
        attr_keys.push_back(kv.first.c_str());
        attr_vals.push_back(kv.second.c_str());
      }

      int num_in  = -1;
      int num_out = -1;
      int retval  = callParseAttrs(
          parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), &num_in, &num_out);
      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(retval) << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str
                    << "'" << msgs;

      return num_out;
    };

    // lambda function to call parse attributes and return the number of inputs and outputs
    // for backward computation
    auto num_inouts = [=](const NodeAttrs& attrs) {
      // convert attributes to vector of char*
      std::vector<const char*> attr_keys, attr_vals;
      for (auto& kv : attrs.dict) {
        attr_keys.push_back(kv.first.c_str());
        attr_vals.push_back(kv.second.c_str());
      }

      int num_in  = -1;
      int num_out = -1;
      int retval  = callParseAttrs(
          parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), &num_in, &num_out);
      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(retval) << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str
                    << "'" << msgs;
      // for backward passes, inputs + outputs + input gradients (one for each output)

      // get extra inputs, if exists
      int extra_inputs = 0;
      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));

      return num_in + extra_inputs + 2 * num_out;
    };

    // lambda function to call infer shape
    auto infer_shape = [=](const nnvm::NodeAttrs& attrs,
                           mxnet::ShapeVector* in_shape,
                           mxnet::ShapeVector* out_shape) {
      // convert attributes to vector of char*
      std::vector<const char*> attr_keys, attr_vals;
      for (auto& kv : attrs.dict) {
        attr_keys.push_back(kv.first.c_str());
        attr_vals.push_back(kv.second.c_str());
      }

      // get extra inputs, if exists
      int extra_inputs = 0;
      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
      int num_inputs = in_shape->size() - extra_inputs;

      std::vector<uint32_t*> inshapes(num_inputs);
      std::vector<int> indims(num_inputs);

      // determine amount of memory needed to store all the input shapes
      size_t buff_size = 0;
      for (size_t i = 0; i < num_inputs; ++i)
        buff_size += (*in_shape)[i].ndim();

      // copy input shapes from ShapeVector to raw memory layout
      std::vector<uint32_t> inbuff(buff_size);
      uint32_t* ptr = inbuff.data();
      for (size_t i = 0; i < num_inputs; ++i) {
        inshapes[i] = ptr;
        indims[i]   = (*in_shape)[i].ndim();
        for (int j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) {
          *ptr = static_cast<uint32_t>((*in_shape)[i][j]);
        }
      }

      // modified input shapes will be allocated by infer shape function
      uint32_t** mod_inshapes = nullptr;
      int* mod_indims         = nullptr;
      // output shapes will be allocated by infer shape function
      uint32_t** outshapes = nullptr;
      int* outdims         = nullptr;

      int retval       = callInferShape(shape_fp,
                                  attr_keys.data(),
                                  attr_vals.data(),
                                  attr_keys.size(),
                                  inshapes.data(),
                                  indims.data(),
                                  num_inputs,
                                  &mod_inshapes,
                                  &mod_indims,
                                  &outshapes,
                                  &outdims,
                                  out_shape->size());
      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(retval) << "Error calling InferShape for custom operator '" << name_str << "'" << msgs;

      std::vector<uint32_t*> in_shapes(num_inputs);
      // determine amount of memory needed to store all the modified input shapes
      buff_size = 0;
      for (unsigned i = 0; i < num_inputs; i++) {
        buff_size += mod_indims[i];
      }

      // copy modified input shapes from custom op memory to MXNet memory
      std::vector<uint32_t> mod_inbuff(buff_size);
      ptr = mod_inbuff.data();
      for (unsigned i = 0; i < num_inputs; ++i) {
        in_shapes[i] = ptr;
        for (int j = 0; j < mod_indims[i]; ++j, ++ptr) {
          *ptr = static_cast<uint32_t>(mod_inshapes[i][j]);
        }
      }

      // assign modified input shapes to ShapeVector
      for (unsigned i = 0; i < num_inputs; ++i) {
        SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape(in_shapes[i], in_shapes[i] + mod_indims[i]));
      }

      std::vector<uint32_t*> out_shapes(out_shape->size());
      // determine amount of memory needed to store all the output shapes
      buff_size = 0;
      for (unsigned i = 0; i < out_shape->size(); i++) {
        buff_size += outdims[i];
      }

      // copy output shapes from custom op memory to MXNet memory
      std::vector<uint32_t> outbuff(buff_size);
      ptr = outbuff.data();
      for (unsigned i = 0; i < out_shape->size(); ++i) {
        out_shapes[i] = ptr;
        for (int j = 0; j < outdims[i]; ++j, ++ptr) {
          *ptr = static_cast<uint32_t>(outshapes[i][j]);
        }
      }

      // assign output shapes to ShapeVector
      for (unsigned i = 0; i < out_shape->size(); ++i) {
        SHAPE_ASSIGN_CHECK(*out_shape, i, mxnet::TShape(out_shapes[i], out_shapes[i] + outdims[i]));
      }

      // free memory used by custom op to allocate shapes/dims
      callFree(mod_indims);
      for (unsigned i = 0; i < num_inputs; i++) {
        callFree(mod_inshapes[i]);
      }
      callFree(mod_inshapes);

      callFree(outdims);
      for (unsigned i = 0; i < out_shape->size(); i++) {
        callFree(outshapes[i]);
      }
      callFree(outshapes);

      return true;
    };

    // lambda function to call infer shape for subgraph ops
    auto infer_subgraph_shape = [=](const nnvm::NodeAttrs& attrs,
                                    mxnet::ShapeVector* in_shape,
                                    mxnet::ShapeVector* out_shape) {
      // convert attributes to vector of char*
      std::vector<const char*> attr_keys, attr_vals;
      for (auto& kv : attrs.dict) {
        attr_keys.push_back(kv.first.c_str());
        attr_vals.push_back(kv.second.c_str());
      }

      // get extra inputs, if exists
      int extra_inputs = 0;
      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));

      auto in_first                    = in_shape->begin();
      auto in_last                     = in_first + in_shape->size() - extra_inputs;
      mxnet::ShapeVector* sg_in_shapes = new mxnet::ShapeVector(in_first, in_last);
      bool res = mxnet::op::DefaultSubgraphOpShape(attrs, sg_in_shapes, out_shape);

      // assign modified input shapes to ShapeVector
      for (unsigned i = 0; i < sg_in_shapes->size(); ++i) {
        SHAPE_ASSIGN_CHECK(*in_shape, i, sg_in_shapes->at(i));
      }
      return res;
    };

    // lambda function to call infer type
    auto infer_type = [=](const nnvm::NodeAttrs& attrs,
                          std::vector<int>* in_type,
                          std::vector<int>* out_type) {
      // convert attributes to vector of char*
      std::vector<const char*> attr_keys, attr_vals;
      for (auto& kv : attrs.dict) {
        attr_keys.push_back(kv.first.c_str());
        attr_vals.push_back(kv.second.c_str());
      }

      // get extra inputs, if exists
      int extra_inputs = 0;
      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
      int num_inputs = in_type->size() - extra_inputs;

      // copy input types from in_type
      std::vector<int> intypes(*in_type);

      // output types will be populated by inferType function
      std::vector<int> outtypes(out_type->size());

      int retval       = callInferType(type_fp,
                                 attr_keys.data(),
                                 attr_vals.data(),
                                 attr_keys.size(),
                                 intypes.data(),
                                 num_inputs,
                                 outtypes.data(),
                                 out_type->size());
      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(retval) << "Error calling InferType for custom operator '" << name_str << "'" << msgs;

      // copy and assign modified input types from custom op to MXNet memory
      for (size_t i = 0; i < num_inputs; i++) {
        TYPE_ASSIGN_CHECK(*in_type, i, intypes[i]);
      }
      // copy and assign output types from custom op to MXNet memory
      for (size_t i = 0; i < out_type->size(); i++) {
        TYPE_ASSIGN_CHECK(*out_type, i, outtypes[i]);
      }

      return true;
    };

    // lambda function to call infer type for subgraph ops
    auto infer_subgraph_type =
        [=](const nnvm::NodeAttrs& attrs, std::vector<int>* in_type, std::vector<int>* out_type) {
          // convert attributes to vector of char*
          std::vector<const char*> attr_keys, attr_vals;
          for (auto& kv : attrs.dict) {
            attr_keys.push_back(kv.first.c_str());
            attr_vals.push_back(kv.second.c_str());
          }

          // get extra inputs, if exists
          int extra_inputs = 0;
          if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
            extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));

          auto in_first                 = in_type->begin();
          auto in_last                  = in_first + in_type->size() - extra_inputs;
          std::vector<int>* sg_in_types = new std::vector<int>(in_first, in_last);

          bool res = mxnet::op::DefaultSubgraphOpType(attrs, sg_in_types, out_type);
          // copy and assign modified input types
          for (size_t i = 0; i < sg_in_types->size(); i++) {
            TYPE_ASSIGN_CHECK(*in_type, i, sg_in_types->at(i));
          }
          return res;
        };

    // lambda function to convert from external mutate_inputs to internal MXNet types
    auto mutate_inputs = [=](const nnvm::NodeAttrs& attrs) {
      // convert attributes to vector of char*
      std::vector<const char*> attr_keys, attr_vals;
      for (auto& kv : attrs.dict) {
        attr_keys.push_back(kv.first.c_str());
        attr_vals.push_back(kv.second.c_str());
      }

      // C type placeholder for mutate input indices vector
      int* mutate_indices = nullptr;
      int indices_size    = 0;

      // call mutate inputs function
      int retval       = callMutateInputs(mutate_fp,
                                    attr_keys.data(),
                                    attr_vals.data(),
                                    attr_keys.size(),
                                    &mutate_indices,
                                    &indices_size);
      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(retval) << "Error calling MutateInputs for custom operator '" << name_str << "'"
                    << msgs;

      std::vector<uint32_t> mutate_indices_list(indices_size);
      for (int i = 0; i < indices_size; i++) {
        mutate_indices_list[i] = static_cast<uint32_t>(mutate_indices[i]);
      }

      return mutate_indices_list;
    };

    // lambda function to set storage types
    auto infer_storage_type = [=](const nnvm::NodeAttrs& attrs,
                                  const int dev_mask,
                                  DispatchMode* dispatch_mode,
                                  std::vector<int>* in_stypes,
                                  std::vector<int>* out_stypes) {
      if (stype_fp == nullptr) {
        // InferSType is not defined in customized lib.
        CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage))
            << "Error input tensors are not dense for custom operator '" << name_str << "'";
        // set outputs as dense
        return op::storage_type_assign(
            out_stypes, mxnet::kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
      } else {
        // InferSType is defined in customized lib.
        // convert attributes to vector of char*
        std::vector<const char*> attr_keys, attr_vals;
        for (const auto& kv : attrs.dict) {
          attr_keys.push_back(kv.first.c_str());
          attr_vals.push_back(kv.second.c_str());
        }

        // get extra inputs, if exists
        int extra_inputs = 0;
        if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
          extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
        int num_inputs = in_stypes->size() - extra_inputs;

        // copy input types from in_stype
        std::vector<int> instypes(*in_stypes);

        // output types will be populated by inferType function
        std::vector<int> outstypes(out_stypes->size());
        int retval       = callInferSType(stype_fp,
                                    attr_keys.data(),
                                    attr_vals.data(),
                                    attr_keys.size(),
                                    instypes.data(),
                                    num_inputs,
                                    outstypes.data(),
                                    out_stypes->size());
        std::string msgs = getExtensionMsgs(msgSize, msgGet);
        CHECK(retval) << "Error calling InferSType for custom operator '" << name_str << "'"
                      << msgs;

        // copy and assign modified input storage types from custom op to MXNet memory.
        for (size_t i = 0; i < num_inputs; i++) {
          STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, instypes[i]);
        }
        // copy and assign output storage types from custom op to MXNet memory.
        for (size_t i = 0; i < out_stypes->size(); i++) {
          STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, outstypes[i]);
        }
        // assign dispatch mode
        DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
        return true;
      }
    };

    // lambda function to set storage types for subgraph ops
    auto infer_subgraph_storage_type = [=](const nnvm::NodeAttrs& attrs,
                                           const int dev_mask,
                                           DispatchMode* dispatch_mode,
                                           std::vector<int>* in_stypes,
                                           std::vector<int>* out_stypes) {
      // get extra inputs, if exists
      int extra_inputs = 0;
      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));

      auto in_first                  = in_stypes->begin();
      auto in_last                   = in_first + in_stypes->size() - extra_inputs;
      std::vector<int>* sg_in_stypes = new std::vector<int>(in_first, in_last);

      bool res = mxnet::op::DefaultSubgraphOpStorageType(
          attrs, dev_mask, dispatch_mode, sg_in_stypes, out_stypes);
      // copy and assign modified input storage types
      for (size_t i = 0; i < sg_in_stypes->size(); i++) {
        STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, sg_in_stypes->at(i));
      }
      return res;
    };

    // FGradient register lambda
    auto grad_reg = [=](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // create node for gradient
      auto p                = nnvm::Node::Create();
      std::string grad_name = "_backward_" + name_str;
      p->attrs.op           = nnvm::Op::Get(grad_name.c_str());
      p->attrs.name         = n->attrs.name + "_backward";
      // copy attributes and subgraphs
      p->attrs.dict = n->attrs.dict;
      for (const auto& s : n->attrs.subgraphs)
        p->attrs.subgraphs.push_back(s);
      // set control dependency and attr parser
      p->control_deps.emplace_back(n);
      if (p->op()->attr_parser != nullptr) {
        p->op()->attr_parser(&(p->attrs));
      }
      // gradient inputs: copy gradients first
      std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
      // copy inputs second
      for (auto& h : n->inputs) {
        heads.push_back(h);
      }
      // gradient inputs: copy outputs last
      uint32_t n_out = n->num_outputs();
      for (uint32_t i = 0; i < n_out; ++i) {
        heads.emplace_back(n, i, 0);
      }
      // set inputs to gradient node
      p->inputs = heads;
      CHECK_EQ(p->num_inputs(), p->inputs.size())
          << "Number of inputs to operator " << grad_name << " (" << p->num_inputs()
          << ") does not match the actual number of inputs provided to operator " << p->attrs.name
          << " (" << p->inputs.size() << ").";
      // create output node entries
      return mxnet::op::CreateNodeEntries(p);
    };

    auto resc_req = [=](const NodeAttrs& attrs) {
      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace,
                                          ResourceRequest::kParallelRandom};
    };

    // library author should implement and return a 'state' which points to an instance
    // in lambda we create OpStatePtr using the returned 'state'
    auto create_opstate = [=](const NodeAttrs& attrs,
                              Context ctx,
                              const std::vector<TShape>& in_shapes,
                              const std::vector<int>& in_types) {
      // convert attributes to vector of char*
      std::vector<const char*> attr_keys, attr_vals;
      for (auto& kv : attrs.dict) {
        attr_keys.push_back(kv.first.c_str());
        attr_vals.push_back(kv.second.c_str());
      }

      // string repr of supported context for custom library, currently only "cpu" and "gpu"
      const char* ctx_str = ctx.dev_mask() == Context::kCPU ? "cpu" : "gpu";

      std::vector<uint32_t*> inshapes(in_shapes.size());
      std::vector<int> indims(in_shapes.size());

      // determine amount of memory needed to store all the input shapes
      size_t buff_size = 0;
      for (const auto& in_shape : in_shapes)
        buff_size += in_shape.ndim();

      // copy input shapes to raw memory layout
      std::vector<uint32_t> inbuff(buff_size);
      uint32_t* ptr = inbuff.data();
      for (size_t i = 0; i < in_shapes.size(); ++i) {
        inshapes[i] = ptr;
        indims[i]   = in_shapes[i].ndim();
        for (int j = 0; j < in_shapes[i].ndim(); ++j, ++ptr) {
          *ptr = static_cast<uint32_t>(in_shapes[i][j]);
        }
      }

      // convert subgraph symbol from node attributes to char*
      std::string subgraph_json;
      if (!attrs.subgraphs.empty()) {
        nnvm::Graph g;
        g.outputs     = attrs.subgraphs[0].get()->outputs;
        subgraph_json = nnvm::pass::SaveJSON(g);
        attr_keys.push_back(MX_STR_SUBGRAPH_SYM_JSON);
        attr_vals.push_back(subgraph_json.c_str());
      }

      // create a pointer to hold custom op state object
      // only create one stateful op depending on passing context
      // user can add new supported context and call to custom library
      void* state_op_inst = nullptr;
      if (ctx.dev_mask() == Context::kCPU) {
        CHECK(createop_map.count("cpu") > 0)
            << "CPU CreateOpState not implemented for '" << name_str << "'";
        int retval       = callCreateOpState(createop_map.at("cpu"),
                                       attr_keys.data(),
                                       attr_vals.data(),
                                       attr_keys.size(),
                                       ctx_str,
                                       ctx.real_dev_id(),
                                       inshapes.data(),
                                       indims.data(),
                                       in_shapes.size(),
                                       in_types.data(),
                                       &state_op_inst);
        std::string msgs = getExtensionMsgs(msgSize, msgGet);
        CHECK(retval) << "Error calling CreateOpState CPU for custom operator '" << name_str << "'"
                      << msgs;
      } else if (ctx.dev_mask() == Context::kGPU) {
        CHECK(createop_map.count("gpu") > 0)
            << "GPU CreateOpState not implemented for '" << name_str << "'";
        int retval       = callCreateOpState(createop_map.at("gpu"),
                                       attr_keys.data(),
                                       attr_vals.data(),
                                       attr_keys.size(),
                                       ctx_str,
                                       ctx.real_dev_id(),
                                       inshapes.data(),
                                       indims.data(),
                                       in_shapes.size(),
                                       in_types.data(),
                                       &state_op_inst);
        std::string msgs = getExtensionMsgs(msgSize, msgGet);
        CHECK(retval) << "Error calling CreateOpState GPU for custom operator '" << name_str << "'"
                      << msgs;
      }

      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(state_op_inst != nullptr)
          << "Error custom library failed to create stateful operator '" << name_str << "'" << msgs;

      CustomStatefulOp* state_op = reinterpret_cast<CustomStatefulOp*>(state_op_inst);
      if (!state_op->wasCreated() && !state_op->ignore_warn)
        LOG(INFO) << "WARNING! Custom stateful op " << state_op_inst << " was created without "
                  << "calling CustomStatefulOp::create(). Please ensure this object was "
                  << "allocated with 'new' since it will be destructed with 'delete'. "
                  << "To suppress this message without calling CustomStatefulOp::create() "
                  << "set ignore_warn to 'true' on custom stateful op instance.";
      return OpStatePtr::Create<CustomStatefulOpWrapper>(state_op, callDestroyOpState);
    };

    /* -------------- BELOW IS THE REGISTRATION FOR CUSTOM OPERATORS --------------- */

    registerOp(name,
               name_str,
               isSubgraphOp,
               resc_req,
               attr_parser,
               num_inputs,
               num_outputs,
               num_inouts,
               infer_type,
               infer_shape,
               infer_storage_type,
               mutate_inputs,
               num_subgraph_inputs,
               infer_subgraph_type,
               infer_subgraph_shape,
               infer_subgraph_storage_type,
               create_opstate,
               grad_reg,
               mutate_fp,
               createop_map,
               forward_ctx_map,
               backward_ctx_map,
               callFComp,
               callFStatefulComp,
               msgSize,
               msgGet);
  }
}  // NOLINT

void registerPartitioners(void* lib,
                          int verbose,
                          mxnet::ext::msgSize_t msgSize,
                          mxnet::ext::msgGet_t msgGet) {
  using namespace mxnet::ext;

  // get C type interface functions
  opCallFree_t callFree = get_func<opCallFree_t>(lib, const_cast<char*>(MXLIB_OPCALLFREE_STR));

  partCallSupportedOps_t callSupportedOps =
      get_func<partCallSupportedOps_t>(lib, const_cast<char*>(MXLIB_PARTCALLSUPPORTEDOPS_STR));

  partCallCreateSelector_t callCreateSelector =
      get_func<partCallCreateSelector_t>(lib, const_cast<char*>(MXLIB_PARTCALLCREATESELECTOR_STR));

  partCallSelect_t callSelect =
      get_func<partCallSelect_t>(lib, const_cast<char*>(MXLIB_PARTCALLSELECT_STR));

  partCallSelectInput_t callSelectInput =
      get_func<partCallSelectInput_t>(lib, const_cast<char*>(MXLIB_PARTCALLSELECTINPUT_STR));

  partCallSelectOutput_t callSelectOutput =
      get_func<partCallSelectOutput_t>(lib, const_cast<char*>(MXLIB_PARTCALLSELECTOUTPUT_STR));

  partCallFilter_t callFilter =
      get_func<partCallFilter_t>(lib, const_cast<char*>(MXLIB_PARTCALLFILTER_STR));

  partCallReset_t callReset =
      get_func<partCallReset_t>(lib, const_cast<char*>(MXLIB_PARTCALLRESET_STR));

  partCallReviewSubgraph_t callReviewSubgraph =
      get_func<partCallReviewSubgraph_t>(lib, const_cast<char*>(MXLIB_PARTCALLREVIEWSUBGRAPH_STR));

  // get number of partitioners registered in the library
  partRegSize_t partRegSize =
      get_func<partRegSize_t>(lib, const_cast<char*>(MXLIB_PARTREGSIZE_STR));
  int numParts = partRegSize();
  if (verbose)
    LOG(INFO) << "Found " << numParts << " partitioners in library";

  /*
   * Get all custom partitioners implementation from custom library
   * loop and register each partitioner in the library to NNVM
   */
  partRegGetCount_t partRegGetCount =
      get_func<partRegGetCount_t>(lib, const_cast<char*>(MXLIB_PARTREGGETCOUNT_STR));
  partRegGet_t partRegGet = get_func<partRegGet_t>(lib, const_cast<char*>(MXLIB_PARTREGGET_STR));
  for (int i = 0; i < numParts; i++) {
    const char* name;
    // get custom partitioner strategy count from the dynamic library
    int count = partRegGetCount(i, &name);
    CHECK(count > 0) << "Error loading '" << name << "' custom partitioner, no strategies defined";
    std::string name_str(name);
    if (verbose)
      LOG(INFO) << "\tPartitioner[" << i << "] " << name;

    mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_BACKEND__(name);

    for (int j = 0; j < count; j++) {
      const char* strategy;
      // function pointers holding implementation from custom library
      supportedOps_t supportedOps_fp     = nullptr;
      createSelector_t createSelector_fp = nullptr;
      reviewSubgraph_t reviewSubgraph_fp = nullptr;
      // name of subgraph op
      const char* op_name = nullptr;

      // get custom partitioner strategy from the dynamic library
      partRegGet(
          i, j, &strategy, &supportedOps_fp, &createSelector_fp, &reviewSubgraph_fp, &op_name);
      // validate custom partitioner functions from the dynamic library
      if (supportedOps_fp == nullptr && createSelector_fp == nullptr)
        LOG(ERROR) << "Error loading '" << name << "' custom partitioner strategy '" << strategy
                   << "', must implement supportedOps or createSelector";
      std::string strategy_str(strategy);
      std::string op_name_str(op_name);
      if (verbose)
        LOG(INFO) << "\t\tStrategy[" << j << "] " << strategy_str << " subgraphOp: '" << op_name_str
                  << "'";
      mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_CUSTOM_PROPERTY__(
          name_str,
          std::make_shared<mxnet::op::CustomSubgraphProperty>(strategy_str,
                                                              callSupportedOps,
                                                              supportedOps_fp,
                                                              callCreateSelector,
                                                              createSelector_fp,
                                                              callSelect,
                                                              callSelectInput,
                                                              callSelectOutput,
                                                              callFilter,
                                                              callReset,
                                                              callReviewSubgraph,
                                                              reviewSubgraph_fp,
                                                              callFree,
                                                              op_name_str));
    }
  }
}

void registerPasses(void* lib,
                    int verbose,
                    mxnet::ext::msgSize_t msgSize,
                    mxnet::ext::msgGet_t msgGet) {
  using namespace mxnet::ext;

  // get C type interface functions
  opCallFree_t callFree = get_func<opCallFree_t>(lib, const_cast<char*>(MXLIB_OPCALLFREE_STR));

  passCallGraphPass_t callGraphPass =
      get_func<passCallGraphPass_t>(lib, const_cast<char*>(MXLIB_PASSCALLGRAPHPASS_STR));

  // get number of passes registered in the library
  partRegSize_t passRegSize =
      get_func<passRegSize_t>(lib, const_cast<char*>(MXLIB_PASSREGSIZE_STR));
  int numPasses = passRegSize();
  if (verbose)
    LOG(INFO) << "Found " << numPasses << " graph passes in library";

  /*
   * Get all custom pass implementation from custom library
   * loop and register each pass in the library to NNVM
   */
  passRegGet_t passRegGet = get_func<passRegGet_t>(lib, const_cast<char*>(MXLIB_PASSREGGET_STR));
  for (int i = 0; i < numPasses; i++) {
    const char* name;
    // function pointers holding implementation from custom library
    graphPass_t pass_fp = nullptr;

    // main function to get custom pass implemenation from the custom library
    passRegGet(i, &pass_fp, &name);

    if (verbose)
      LOG(INFO) << "\tGraph Pass [" << i << "] " << name;

    auto pass_lambda = [=](nnvm::Graph&& g) {
      // get pass name
      const char* pass_name = g.GetAttr<const char*>("pass_name");
      // get options
      const std::unordered_map<std::string, std::string>& options_map =
          g.GetAttr<const std::unordered_map<std::string, std::string>>("options_map");
      // convert options_map_ to char* to pass to backend library
      std::vector<const char*> opt_keys, opt_vals;
      for (auto& kv : options_map) {
        opt_keys.push_back(kv.first.c_str());
        opt_vals.push_back(kv.second.c_str());
      }

      // get input args and arg names
      std::vector<std::string> in_arg_names = g.GetAttr<std::vector<std::string>>("in_arg_names");
      std::vector<std::string> in_aux_names = g.GetAttr<std::vector<std::string>>("in_aux_names");
      NDArray** in_args_ptr                 = g.GetAttr<NDArray**>("in_args");
      NDArray** in_aux_ptr                  = g.GetAttr<NDArray**>("in_aux");

      // get shapes/types
      mxnet::ShapeVector shapes;
      if (g.HasAttr("shape"))
        shapes = g.GetAttr<mxnet::ShapeVector>("shape");
      std::vector<int> dtypes;
      if (g.HasAttr("dtype"))
        dtypes = g.GetAttr<std::vector<int>>("dtype");
      g.attrs.clear();
      const nnvm::IndexedGraph& indexed_graph = g.indexed_graph();

      // set shape attrs for each node in the graph
      if (shapes.size() > 0) {
        for (unsigned nid = 0; nid < indexed_graph.num_nodes(); nid++) {
          nnvm::Node* node = const_cast<nnvm::Node*>(indexed_graph[nid].source);
          std::stringstream ss;
          ss << "[";
          // set the output shapes for this node
          for (unsigned oid = 0; oid < node->num_outputs(); oid++) {
            const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid);
            mxnet::TShape& shape        = shapes[out_entry_id];
            ss << shape;
            if (oid < node->num_outputs() - 1)
              ss << ",";
          }
          ss << "]";
          node->attrs.dict[MX_STR_SHAPE] = ss.str();
        }
      }
      // set dtype attrs for each node in the graph
      if (dtypes.size() > 0) {
        for (unsigned nid = 0; nid < indexed_graph.num_nodes(); nid++) {
          nnvm::Node* node = const_cast<nnvm::Node*>(indexed_graph[nid].source);
          std::stringstream ss;
          ss << "[";
          // set the output dtypes for this node
          for (unsigned oid = 0; oid < node->num_outputs(); oid++) {
            const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid);
            int dtype                   = dtypes[out_entry_id];
            ss << dtype;
            if (oid < node->num_outputs() - 1)
              ss << ",";
          }
          ss << "]";
          node->attrs.dict[MX_STR_DTYPE] = ss.str();
        }
      }

      std::vector<const char*> arg_names, aux_names;
      std::vector<void*> arg_data, aux_data;
      std::vector<const int64_t*> arg_shapes, aux_shapes;
      std::vector<int> arg_dims, aux_dims;
      std::vector<int> arg_types, aux_types;
      std::vector<size_t> arg_verIDs, aux_verIDs;
      std::vector<const char*> arg_dev_type, aux_dev_type;
      std::vector<int> arg_dev_id, aux_dev_id;

      // convert input args
      for (size_t i = 0; i < in_arg_names.size(); i++) {
        if (in_args_ptr[i] != nullptr) {
          arg_names.push_back(in_arg_names[i].c_str());
          const NDArray& in_arg = *(in_args_ptr[i]);

#if MXNET_USE_ONEDNN == 1
          // reorder data if in DNNL format
          if (in_arg.IsDNNLData()) {
            in_arg.Reorder2DefaultAsync();
            in_arg.WaitToRead();
          }
#endif

          // pull out parts of NDArray to send to backend
          arg_data.push_back(in_arg.data().dptr_);
          arg_shapes.push_back(in_arg.shape().data());
          arg_dims.push_back(in_arg.shape().ndim());
          arg_types.push_back(in_arg.dtype());
          arg_verIDs.push_back(in_arg.version());
          const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
          arg_dev_type.push_back(arg_ctx_str);
          arg_dev_id.push_back(in_arg.ctx().real_dev_id());
        }
      }

      // convert input aux
      for (size_t i = 0; i < in_aux_names.size(); i++) {
        if (in_aux_ptr[i] != nullptr) {
          aux_names.push_back(in_aux_names[i].c_str());
          const auto& in_aux = *(in_aux_ptr[i]);

#if MXNET_USE_ONEDNN == 1
          // reorder data if in DNNL format
          if (in_aux.IsDNNLData()) {
            in_aux.Reorder2DefaultAsync();
            in_aux.WaitToRead();
          }
#endif

          // pull out parts of NDArray to send to backend
          aux_data.push_back(in_aux.data().dptr_);
          aux_shapes.push_back(in_aux.shape().data());
          aux_dims.push_back(in_aux.shape().ndim());
          aux_types.push_back(in_aux.dtype());
          aux_verIDs.push_back(in_aux.version());
          const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
          aux_dev_type.push_back(aux_ctx_str);
          aux_dev_id.push_back(in_aux.ctx().real_dev_id());
        }
      }

      // convert graph to string
      std::string in_json = nnvm::pass::SaveJSON(g);

      std::vector<std::string> new_arg_names, new_aux_names;
      std::vector<NDArray*> new_args, new_aux;

      // create lambda that captures stream & resource objects
      // this temp workspace holds memory allocated by custom library via OpResource
      auto ndarray_alloc =
          [&](const mxnet::TShape& shape, Context ctx, int dtype, std::string name, bool isArg) {
            NDArray* arr = new NDArray(shape, ctx, false, dtype);
            if (isArg) {
              new_args.push_back(arr);
              new_arg_names.push_back(name);
            } else {
              new_aux.push_back(arr);
              new_aux_names.push_back(name);
            }
            return arr;
          };

      // create no-capture lambda so that we can cast it to function pointer
      // lambda with captures cannot be cast to function pointer and pass to lib_api.h
      // this needs to be a lambda function so that we can do the decltype cast
      using alloc_type_ndarray = decltype(ndarray_alloc);
      auto ndarray_malloc      = [](const void* _ndarray_alloc,
                               const int64_t* shapes,
                               int num_shapes,
                               const char* dev_str,
                               int dev_id,
                               int dtype,
                               const char* name,
                               int isArg,
                               void** data) {
        mxnet::TShape shape(num_shapes, 0);
        for (int i = 0; i < num_shapes; i++)
          shape[i] = shapes[i];
        int dev_type = -1;
        if (strcmp(dev_str, "cpu") == 0)
          dev_type = kCPU;
        else
          dev_type = kGPU;
        Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);

        // cast the void* argument to the type for the cpu_alloc lambda function
        const alloc_type_ndarray* ndalloc = static_cast<const alloc_type_ndarray*>(_ndarray_alloc);
        // call cpu_alloc to actually allocate memory and return the pointer
        NDArray* arr = (*ndalloc)(shape, ctx, dtype, name, isArg);
        *data        = arr->data().dptr_;
      };

      char* out_json;
      int retval       = callGraphPass(pass_fp,
                                 in_json.c_str(),
                                 &out_json,
                                 opt_keys.data(),
                                 opt_vals.data(),
                                 opt_keys.size(),
                                 pass_name,
                                 arg_names.data(),
                                 arg_names.size(),
                                 arg_data.data(),
                                 arg_shapes.data(),
                                 arg_dims.data(),
                                 arg_types.data(),
                                 arg_verIDs.data(),
                                 arg_dev_type.data(),
                                 arg_dev_id.data(),
                                 aux_names.data(),
                                 aux_names.size(),
                                 aux_data.data(),
                                 aux_shapes.data(),
                                 aux_dims.data(),
                                 aux_types.data(),
                                 aux_verIDs.data(),
                                 aux_dev_type.data(),
                                 aux_dev_id.data(),
                                 ndarray_malloc,
                                 &ndarray_alloc);
      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(retval) << "Error calling graph pass for '" << pass_name << "'" << msgs;

      std::string out_string(out_json);
      nnvm::Graph out_graph = nnvm::pass::LoadJSON(out_string);

      out_graph.attrs["new_args"]      = std::make_shared<nnvm::any>(new_args);
      out_graph.attrs["new_arg_names"] = std::make_shared<nnvm::any>(new_arg_names);
      out_graph.attrs["new_aux"]       = std::make_shared<nnvm::any>(new_aux);
      out_graph.attrs["new_aux_names"] = std::make_shared<nnvm::any>(new_aux_names);

      callFree(out_json);
      return out_graph;
    };

    nnvm::PassFunctionReg& pass = dmlc::Registry<nnvm::PassFunctionReg>::Get()->__REGISTER__(name);
    pass.set_body(pass_lambda);
    pass.set_change_graph(true);
  }
}

/*!
 * \brief Loads dynamic custom library and initializes it
 * \param path library path
 */
int MXLoadLib(const char* path, unsigned verbose, void** lib) {
  API_BEGIN();
  *lib = LibraryInitializer::Get()->lib_load(path);
  if (!*lib)
    LOG(FATAL) << "Unable to load library";

  // check that library and MXNet use same version of library API
  mxnet::ext::opVersion_t opVersion =
      get_func<mxnet::ext::opVersion_t>(*lib, const_cast<char*>(MXLIB_OPVERSION_STR));
  int libVersion = opVersion();
  if (MX_LIBRARY_VERSION != libVersion)
    LOG(FATAL) << "Library version (" << libVersion << ") does not match MXNet version ("
               << MX_LIBRARY_VERSION << ")";

  // get error messaging APIs
  mxnet::ext::msgSize_t msgSize =
      get_func<mxnet::ext::msgSize_t>(*lib, const_cast<char*>(MXLIB_MSGSIZE_STR));
  mxnet::ext::msgGet_t msgGet =
      get_func<mxnet::ext::msgGet_t>(*lib, const_cast<char*>(MXLIB_MSGGET_STR));

  // initialize library by passing MXNet version
  mxnet::ext::initialize_t initialize =
      get_func<mxnet::ext::initialize_t>(*lib, const_cast<char*>(MXLIB_INITIALIZE_STR));
  if (!initialize(static_cast<int>(MXNET_VERSION))) {
    std::string msgs = getExtensionMsgs(msgSize, msgGet);
    LOG(FATAL) << "Library failed to initialize" << msgs;
  }

  // find ops, partitioners, and passes in library
  registerOperators(*lib, verbose, msgSize, msgGet);
  registerPartitioners(*lib, verbose, msgSize, msgGet);
  registerPasses(*lib, verbose, msgSize, msgGet);
  API_END();
}

int MXLibInfoFeatures(const struct LibFeature** lib_features, size_t* size) {
  using namespace features;
  API_BEGIN();
  LibInfo* lib_info = LibInfo::getInstance();
  *lib_features     = lib_info->getFeatures().data();
  *size             = lib_info->getFeatures().size();
  API_END();
}

/*!
 * \brief Reports the value of _GLIBCXX_USE_CXX11_ABI this file was compiled with.
 * \param result receives the macro value, or -1 when the macro is not defined
 */
int MXLibInfoCompiledWithCXX11ABI(int* result) {
#ifdef _GLIBCXX_USE_CXX11_ABI
  constexpr int abi_flag = _GLIBCXX_USE_CXX11_ABI;
#else
  constexpr int abi_flag = -1;
#endif
  API_BEGIN();
  *result = abi_flag;
  API_END();
}

/*!
 * \brief Forwards the given seed to mxnet::RandomSeed.
 * \param seed seed value
 */
int MXRandomSeed(int seed) {
  API_BEGIN();
  { mxnet::RandomSeed(seed); }
  API_END();
}

int MXRandomSeedContext(int seed, int dev_type, int dev_id) {
  API_BEGIN();
  Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
  mxnet::RandomSeed(ctx, seed);
  API_END();
}

/*!
 * \brief Switches the flush-to-zero / denormals-are-zero modes on or off.
 *
 * When SUPPORT_FTZ_DMZ is not enabled this only writes false to prev_state.
 * Otherwise the change is pushed through the engine as a synchronous CPU operation
 * and the call waits for the engine to finish before returning.
 *
 * \param value true to enable the modes, false to disable them
 * \param prev_state receives the flush-to-zero mode that was active before the call
 */
int MXSetFlushDenorms(bool value, bool* prev_state) {
  API_BEGIN();
  *prev_state = false;

#if SUPPORT_FTZ_DMZ
  auto is_dmz_flag_available = []() {
    // Intel 64 and IA-32 Architectures Software Developer’s Manual: Vol. 1
    // "Checking for the DAZ Flag in the MXCSR Register"
    constexpr unsigned int mxcsr_mask_offset = 28;
    constexpr unsigned int dmz_flag_offset   = 5;
    constexpr unsigned int fxsave_req_bytes  = 512;

    // FXSAVE needs a zeroed, 16-byte aligned 512-byte area. A stack buffer gives the
    // alignment guarantee (malloc does not on every platform) and cannot fail or leak.
    alignas(16) char fxsave_area[fxsave_req_bytes] = {};
    _fxsave(fxsave_area);

    // read MXCSR_MASK via memcpy instead of dereferencing a casted char pointer
    uint32_t mxcsr_mask = 0;
    memcpy(&mxcsr_mask, fxsave_area + mxcsr_mask_offset, sizeof(mxcsr_mask));
    // DMZ flag is supported if sixth bit of MXCSR_MASK is hot
    const bool dmz_flag = ((mxcsr_mask >> dmz_flag_offset) & 0x1) != 0;
    return dmz_flag;
  };

  Engine::Get()->PushSync(
      [value, prev_state, is_dmz_flag_available](RunContext rctx) {
        const unsigned int DMZ_STATE = value ? _MM_DENORMALS_ZERO_ON : _MM_DENORMALS_ZERO_OFF;
        const unsigned int FTZ_STATE = value ? _MM_FLUSH_ZERO_ON : _MM_FLUSH_ZERO_OFF;
        *prev_state                  = _MM_GET_FLUSH_ZERO_MODE();
        _MM_SET_FLUSH_ZERO_MODE(FTZ_STATE);

        // If the DAZ flag is not supported, then it is a reserved bit and attempting to write a 1
        // to it will cause a general-protection exception (#GP)
        if (is_dmz_flag_available()) {
          _MM_SET_DENORMALS_ZERO_MODE(DMZ_STATE);
        }
      },
      Context::CPU(),
      {},
      {},
      FnProperty::kNormal,
      0,
      "SetFlushDenorms");

  // prev_state is written by the pushed function, so it must have run before returning
  Engine::Get()->WaitForAll();

#endif

  API_END();
}

/*!
 * \brief Stops the custom-operator singleton, then notifies the engine of the shutdown
 *        and waits for all of its pending work.
 */
int MXNotifyShutdown() {
  API_BEGIN();
  // custom operators are stopped first, before the engine is touched
  auto* custom_ops = mxnet::op::custom::CustomOperator::Get();
  custom_ops->Stop();

  Engine::Get()->NotifyShutdown();
  Engine::Get()->WaitForAll();
  API_END();
}

/*!
 * \brief Passes the requested thread count to omp_set_num_threads.
 * \param thread_num number of OpenMP threads to use
 */
int MXSetNumOMPThreads(int thread_num) {
  API_BEGIN();
  { omp_set_num_threads(thread_num); }
  API_END();
}

/*!
 * \brief Sets the bulk size of the engine.
 * \param bulk_size new bulk size
 * \param prev_bulk_size receives the value returned by Engine::set_bulk_size
 */
int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size) {
  API_BEGIN();
  const int returned_size = Engine::Get()->set_bulk_size(bulk_size);
  *prev_bulk_size         = returned_size;
  API_END();
}

/*!
 * \brief Queries the number of GPUs from Context::GetGPUCount.
 * \param out receives the GPU count
 */
int MXGetGPUCount(int* out) {
  API_BEGIN();
  const int gpu_count = Context::GetGPUCount();
  *out                = gpu_count;
  API_END();
}

// Deprecated: use MXGetGPUMemoryInformation64() instead.
// The 64-bit values reported by Context::GetGPUMemoryInformation are narrowed to int
// here, so values that do not fit into an int are not represented faithfully.
int MXGetGPUMemoryInformation(int dev, int* free_mem, int* total_mem) {
  API_BEGIN();
  uint64_t free64 = 0UL, total64 = 0UL;
  Context::GetGPUMemoryInformation(dev, &free64, &total64);
  const int free_narrow  = static_cast<int>(free64);
  const int total_narrow = static_cast<int>(total64);
  *free_mem              = free_narrow;
  *total_mem             = total_narrow;
  API_END();
}

/*!
 * \brief Queries free and total memory of a GPU via Context::GetGPUMemoryInformation.
 * \param dev device id
 * \param free_mem receives the free memory value
 * \param total_mem receives the total memory value
 */
int MXGetGPUMemoryInformation64(int dev, uint64_t* free_mem, uint64_t* total_mem) {
  API_BEGIN();
  // the output pointers are handed through unchanged
  { Context::GetGPUMemoryInformation(dev, free_mem, total_mem); }
  API_END();
}

/*!
 * \brief Reports MXNET_VERSION.
 * \param out receives MXNET_VERSION converted to int
 */
int MXGetVersion(int* out) {
  API_BEGIN();
  const int version = static_cast<int>(MXNET_VERSION);
  *out              = version;
  API_END();
}

/*!
 * \brief Reports the value of MXNET_BRANCH.
 * \param out receives MXNET_BRANCH
 */
int MXGetBranch(const char** out) {
  API_BEGIN();
  const char* branch = MXNET_BRANCH;
  *out               = branch;
  API_END();
}

/*!
 * \brief Reports the value of MXNET_COMMIT_HASH.
 * \param out receives MXNET_COMMIT_HASH
 */
int MXGetCommitHash(const char** out) {
  API_BEGIN();
  const char* commit_hash = MXNET_COMMIT_HASH;
  *out                    = commit_hash;
  API_END();
}

#if MXNET_USE_TVM_OP
/*!
 * \brief Loads a TVM operator library into the global TVMOpModule.
 *
 * With MXNET_USE_CUDA the file "libtvmop.cubin" located next to the library is loaded
 * as well and imported into the global module.
 *
 * \param libpath path of the TVM operator library
 */
int MXLoadTVMOp(const char* libpath) {
  API_BEGIN();
  // load the library exactly once into the global module
  tvm::runtime::TVMOpModule* global_module = tvm::runtime::TVMOpModule::Get();
  global_module->Load(libpath);
#if MXNET_USE_CUDA
  // number of trailing characters of libpath replaced by "libtvmop.cubin"
  // (presumably the length of the library file name "libtvmop.so" — TODO confirm)
  constexpr size_t kLibNameLen = 11;
  std::string libpathstr(libpath);
  // guard the subtraction below against unsigned wrap-around on short paths
  CHECK_GE(libpathstr.size(), kLibNameLen)
      << "Cannot derive cubin path from library path '" << libpathstr << "'";
  std::string cubinpath =
      libpathstr.substr(0, libpathstr.size() - kLibNameLen) + "libtvmop.cubin";
  tvm::runtime::TVMOpModule cubin_module;
  cubin_module.Load(cubinpath);
  global_module->Import(cubin_module);
#endif
  API_END();
}

/*!
 * \brief Copies a set of TVM config spaces into the TVMOpConfig registry.
 *
 * For every key in config a registry entry is fetched (or created) and populated
 * with the entities and spaces of the matching ConfigSpace.
 *
 * \param config config spaces to register
 */
int MXLoadTVMConfig(ConfigSpaces config) {
  API_BEGIN();
  auto* registry = ::dmlc::Registry<tvm::runtime::TVMOpConfig>::Get();
  for (int space_idx = 0; space_idx < config.spaces_size; ++space_idx) {
    const std::string op_key(config.spaces_key[space_idx]);
    tvm::runtime::TVMOpConfig& op_config = registry->__REGISTER_OR_GET__(op_key);
    const ConfigSpace& space            = config.spaces_val[space_idx];

    // named entities
    for (int e = 0; e < space.entity_map_size; ++e) {
      op_config.add_entity(std::string(space.entity_map_key[e]), space.entity_map_val[e].val);
    }

    // named spaces, each one a list of entity values
    for (int s = 0; s < space.space_map_size; ++s) {
      const auto& other_entities = space.space_map_val[s];
      std::string space_name     = std::string(space.space_map_key[s]);
      std::vector<int> values;
      for (int v = 0; v < other_entities.entities_size; ++v) {
        values.push_back(other_entities.entities[v].val);
      }
      op_config.add_space(space_name, values);
    }
  }
  API_END();
}

#endif  // MXNET_USE_TVM_OP

/*!
 * \brief Creates a default-constructed NDArray.
 * \param out receives the handle of the new NDArray (allocated with new)
 */
int MXNDArrayCreateNone(NDArrayHandle* out) {
  API_BEGIN();
  NDArray* empty_array = new NDArray();
  *out                 = empty_array;
  API_END();
}

template <typename DataType>
void CreateNDArray(const DataType* shape,
                   int ndim,
                   int dev_type,
                   int dev_id,
                   int delay_alloc,
                   int dtype,
                   NDArrayHandle* out) {
  mxnet::TShape requested_shape = mxnet::TShape(shape, shape + ndim);
  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(requested_shape.Size(), (int64_t{1} << 31) - 1)
        << "[CreateNDArray] Size of tensor you are trying to allocate is larger than "
           "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  NDArray* nd = new NDArray(requested_shape,
                            Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
                            delay_alloc != 0,
                            dtype);
  nd->AssignStorageInfo(profiler::ProfilerScope::Get()->GetCurrentProfilerScope(),
                        MXNET_STORAGE_DEFAULT_NAME_CSTR);
  *out = nd;
}

/*!
 * \brief Creates an NDArray from a shape given as 64-bit integers.
 *        All arguments are forwarded to CreateNDArray.
 */
int MXNDArrayCreate64(const int64_t* shape,
                      int ndim,
                      int dev_type,
                      int dev_id,
                      int delay_alloc,
                      int dtype,
                      NDArrayHandle* out) {
  API_BEGIN();
  // the shape element type (int64_t) is deduced from the pointer argument
  CreateNDArray(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out);
  API_END();
}

/*!
 * \brief Creates an NDArray from a shape given as unsigned 32-bit integers.
 *        All arguments are forwarded to CreateNDArray; ndim is converted to int.
 */
int MXNDArrayCreate(const uint32_t* shape,
                    uint32_t ndim,
                    int dev_type,
                    int dev_id,
                    int delay_alloc,
                    int dtype,
                    NDArrayHandle* out) {
  API_BEGIN();
  const int num_dims = static_cast<int>(ndim);
  // the shape element type (uint32_t) is deduced from the pointer argument
  CreateNDArray(shape, num_dims, dev_type, dev_id, delay_alloc, dtype, out);
  API_END();
}

template <typename DType>
void CreateSparseNDArray(int storage_type,
                         const DType* shape,
                         int ndim,
                         int dev_type,
                         int dev_id,
                         int delay_alloc,
                         int dtype,
                         uint32_t num_aux,
                         int* aux_type,
                         int* aux_ndims,
                         const DType* aux_shape,
                         NDArrayHandle* out) {
  std::vector<int> aux_types;
  mxnet::ShapeVector aux_shapes;
  auto shape_start = aux_shape;
  for (size_t i = 0; i < num_aux; i++) {
    // types
    aux_types.push_back(aux_type[i]);
    // shapes
    aux_shapes.emplace_back(shape_start, shape_start + aux_ndims[i]);
    shape_start += aux_ndims[i];
  }
  NDArray* nd = new NDArray(NDArrayStorageType(storage_type),
                            mxnet::TShape(shape, shape + ndim),
                            Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
                            delay_alloc != 0,
                            dtype,
                            aux_types,
                            aux_shapes);
  nd->AssignStorageInfo(profiler::ProfilerScope::Get()->GetCurrentProfilerScope(),
                        MXNET_STORAGE_DEFAULT_NAME_CSTR);
  *out = nd;
}

/*!
 * \brief Creates a sparse NDArray from shapes given as unsigned 32-bit integers.
 *        All arguments are forwarded to CreateSparseNDArray<uint32_t>.
 */
int MXNDArrayCreateSparseEx(int storage_type,
                            const uint32_t* shape,
                            uint32_t ndim,
                            int dev_type,
                            int dev_id,
                            int delay_alloc,
                            int dtype,
                            uint32_t num_aux,
                            int* aux_type,
                            uint32_t* aux_ndims,
                            const uint32_t* aux_shape,
                            NDArrayHandle* out) {
  API_BEGIN();
  // the helper works with int for ndim and the aux dimension counts
  const int num_dims      = static_cast<int>(ndim);
  int* const aux_num_dims = reinterpret_cast<int*>(aux_ndims);
  CreateSparseNDArray<uint32_t>(storage_type,
                                shape,
                                num_dims,
                                dev_type,
                                dev_id,
                                delay_alloc,
                                dtype,
                                num_aux,
                                aux_type,
                                aux_num_dims,
                                aux_shape,
                                out);
  API_END();
}

/*!
 * \brief Creates a sparse NDArray from shapes given as 64-bit integers.
 *        All arguments are forwarded to CreateSparseNDArray<int64_t>.
 */
int MXNDArrayCreateSparseEx64(int storage_type,
                              const int64_t* shape,
                              int ndim,
                              int dev_type,
                              int dev_id,
                              int delay_alloc,
                              int dtype,
                              uint32_t num_aux,
                              int* aux_type,
                              int* aux_ndims,
                              const int64_t* aux_shape,
                              NDArrayHandle* out) {
  API_BEGIN();
  // ndim and aux_ndims already have the types the helper expects, no cast is needed
  CreateSparseNDArray<int64_t>(storage_type,
                               shape,
                               ndim,
                               dev_type,
                               dev_id,
                               delay_alloc,
                               dtype,
                               num_aux,
                               aux_type,
                               aux_ndims,
                               aux_shape,
                               out);
  API_END();
}

int MXNDArrayLoadFromRawBytes(const void* buf, size_t size, NDArrayHandle* out) {
  NDArray* ptr = nullptr;
  API_BEGIN();
  dmlc::MemoryFixedSizeStream strm((void*)buf, size);  // NOLINT(*)
  ptr = new NDArray();
  if (!ptr->Load(&strm)) {
    throw dmlc::Error("Invalid NDArray serialization format");
  }
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}

/*!
 * \brief Serializes an NDArray into the string buffer of the thread-local API entry.
 * \param handle handle of the NDArray to serialize
 * \param out_size receives the length of the serialized data
 * \param out_buf receives a pointer into the thread-local buffer holding the data
 */
int MXNDArraySaveRawBytes(NDArrayHandle handle, size_t* out_size, const char** out_buf) {
  MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get();
  API_BEGIN();
  auto& buffer = ret->ret_str;
  buffer.resize(0);  // start from an empty buffer
  dmlc::MemoryStringStream writer(&buffer);
  NDArray* source = static_cast<NDArray*>(handle);
  source->Save(&writer);
  *out_size = buffer.length();
  *out_buf  = buffer.c_str();
  API_END();
}

/*!
 * \brief Calls NDArray::SyncCopyFromCPU on the array behind the handle.
 * \param handle handle of the destination NDArray
 * \param data source buffer
 * \param size size argument passed on to SyncCopyFromCPU
 */
int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, const void* data, size_t size) {
  API_BEGIN();
  NDArray* destination = static_cast<NDArray*>(handle);
  destination->SyncCopyFromCPU(data, size);
  API_END();
}

/*!
 * \brief Calls NDArray::SyncCopyToCPU on the array behind the handle.
 * \param handle handle of the source NDArray
 * \param data destination buffer
 * \param size size argument passed on to SyncCopyToCPU
 */
int MXNDArraySyncCopyToCPU(NDArrayHandle handle, void* data, size_t size) {
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  source->SyncCopyToCPU(data, size);
  API_END();
}

/*!
 * \brief Copy src.data() to dst.data() if i = -1, else dst.aux_data(i) if i >= 0
 * This function blocks. Do not use it in performance critical code.
 * \param handle_dst handle of a dst ndarray whose data/aux_data has been allocated
 * \param handle_src handle of a src ndarray which has default storage type
 * \param i dst data blob indicator
 */
int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst,
                                 const NDArrayHandle handle_src,
                                 const int i) {
  API_BEGIN();
  NDArray& destination = *static_cast<NDArray*>(handle_dst);
  NDArray& source      = *static_cast<NDArray*>(handle_src);
  // -1 is passed as the source blob indicator, i selects the destination blob
  destination.SyncCopyFromNDArray(source, -1, i);
  API_END();
}

/*!
 * \brief Calls NDArray::SyncCheckFormat on the array behind the handle.
 * \param handle handle of the NDArray to check
 * \param full_check passed on to SyncCheckFormat
 */
int MXNDArraySyncCheckFormat(NDArrayHandle handle, const bool full_check) {
  API_BEGIN();
  static_cast<NDArray*>(handle)->SyncCheckFormat(full_check);
  API_END();
}

/*!
 * \brief Calls NDArray::WaitToRead on the array behind the handle.
 * \param handle handle of the NDArray to wait for
 */
int MXNDArrayWaitToRead(NDArrayHandle handle) {
  API_BEGIN();
  NDArray* array = static_cast<NDArray*>(handle);
  array->WaitToRead();
  API_END();
}

/*!
 * \brief Calls NDArray::WaitToWrite on the array behind the handle.
 * \param handle handle of the NDArray to wait for
 */
int MXNDArrayWaitToWrite(NDArrayHandle handle) {
  API_BEGIN();
  NDArray* array = static_cast<NDArray*>(handle);
  array->WaitToWrite();
  API_END();
}

/*!
 * \brief Calls Engine::WaitForAll on the engine singleton.
 */
int MXNDArrayWaitAll() {
  API_BEGIN();
  auto* engine = Engine::Get();
  engine->WaitForAll();
  API_END();
}

int MXNDArrayLegacySave(const char* fname,
                        uint32_t num_args,
                        NDArrayHandle* args,
                        const char** keys) {
  API_BEGIN();
  std::vector<NDArray> data(num_args);
  std::vector<std::string> names;
  for (uint32_t i = 0; i < num_args; ++i) {
    data[i] = *static_cast<NDArray*>(args[i]);
  }
  if (keys != nullptr) {
    names.resize(num_args);
    for (uint32_t i = 0; i < num_args; ++i) {
      names[i] = keys[i];
    }
  }
  {
    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
    mxnet::NDArray::Save(fo.get(), data, names);
  }
  API_END();
}

int MXNDArraySave(const char* fname, uint32_t num_args, NDArrayHandle* args, const char** keys) {
  API_BEGIN();

  CHECK_NOTNULL(fname);

  // We may use mz_zip_writer_init_v2 later instead of mz_zip_writer_init_file
  // and write an adapter for DMLC stream based on pZip->m_pWrite (and
  // pZip->m_pIO_opaque)
  if (num_args == 1 && keys == nullptr) {
    NDArray* array = static_cast<NDArray*>(args[0]);
    if (array->storage_type() == kDefaultStorage) {
      npy::save_array(fname, *array);
    } else {
      mz_zip_archive archive{};
      CHECK(mz_zip_writer_init_file(&archive, fname, 0))
          << "Failed to open archive " << fname << ": "
          << mz_zip_get_error_string(mz_zip_get_last_error(&archive));
      npz::save_array(&archive, "", *array);
      CHECK(mz_zip_writer_finalize_archive(&archive))
          << "Failed to finalize archive " << fname
          << mz_zip_get_error_string(mz_zip_get_last_error(&archive));
      CHECK(mz_zip_writer_end(&archive))
          << "Failed to end archive " << fname
          << mz_zip_get_error_string(mz_zip_get_last_error(&archive));
    }
  } else {
    mz_zip_archive archive{};
    CHECK(mz_zip_writer_init_file(&archive, fname, 0))
        << "Failed to open archive " << fname << ": "
        << mz_zip_get_error_string(mz_zip_get_last_error(&archive));
    for (uint32_t i = 0; i < num_args; ++i) {
      NDArray* array              = static_cast<NDArray*>(args[i]);
      const std::string array_key = keys == nullptr ? "arr_" + std::to_string(i) : keys[i];
      npz::save_array(&archive, array_key, *array);
    }
    CHECK(mz_zip_writer_finalize_archive(&archive))
        << "Failed to finalize archive " << fname
        << mz_zip_get_error_string(mz_zip_get_last_error(&archive));
    CHECK(mz_zip_writer_end(&archive)) << "Failed to end archive " << fname
                                       << mz_zip_get_error_string(mz_zip_get_last_error(&archive));
  }
  API_END();
}

int MXNDArrayLoad(const char* fname,
                  uint32_t* out_size,
                  NDArrayHandle** out_arr,
                  uint32_t* out_name_size,
                  const char*** out_names) {
  MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get();
  ret->ret_vec_str.clear();
  API_BEGIN();

  uint32_t magic;
  {
    std::unique_ptr<dmlc::Stream> strm(dmlc::Stream::Create(fname, "r"));
    CHECK_EQ(strm->Read(&magic, sizeof(uint32_t)), sizeof(uint32_t))
        << "Failed to read 32 bits from file.";
  }

  if (magic == 0x04034b50 || magic == 0x504b0304 || magic == 0x06054b50 ||
      magic == 0x504b0506) {                       // zip file format; assumed to be npz
    auto [data, names] = npz::load_arrays(fname);  // NOLINT
    ret->ret_handles.resize(data.size());
    for (size_t i = 0; i < data.size(); ++i) {
      NDArray* ptr        = new NDArray();
      *ptr                = data[i];
      ret->ret_handles[i] = ptr;
    }
    ret->ret_vec_str.resize(names.size());
    for (size_t i = 0; i < names.size(); ++i) {
      ret->ret_vec_str[i] = names[i];
    }
    ret->ret_vec_charp.resize(names.size());
    for (size_t i = 0; i < names.size(); ++i) {
      ret->ret_vec_charp[i] = ret->ret_vec_str[i].c_str();
    }
    *out_size      = static_cast<uint32_t>(data.size());
    *out_arr       = dmlc::BeginPtr(ret->ret_handles);
    *out_name_size = static_cast<uint32_t>(names.size());
    *out_names     = dmlc::BeginPtr(ret->ret_vec_charp);
  } else if (magic == 0x4d554e93 || magic == 0x934e554d) {  // first bytes of npy format
    *out_size = 1;
    ret->ret_handles.resize(1);
    NDArray* ptr = new NDArray();
    *ptr         = npy::load_array(fname);  // Only supports local filesystem at this point in time
    ret->ret_handles[0] = ptr;
    *out_arr            = dmlc::BeginPtr(ret->ret_handles);
  } else {
    std::vector<NDArray> data;
    std::vector<std::string>& names = ret->ret_vec_str;
    {
      std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
      mxnet::NDArray::Load(fi.get(), &data, &names);
    }
    ret->ret_handles.resize(data.size());
    for (size_t i = 0; i < data.size(); ++i) {
      NDArray* ptr        = new NDArray();
      *ptr                = data[i];
      ret->ret_handles[i] = ptr;
    }
    ret->ret_vec_charp.resize(names.size());
    for (size_t i = 0; i < names.size(); ++i) {
      ret->ret_vec_charp[i] = names[i].c_str();
    }
    *out_size      = static_cast<uint32_t>(data.size());
    *out_arr       = dmlc::BeginPtr(ret->ret_handles);
    *out_name_size = static_cast<uint32_t>(names.size());
    *out_names     = dmlc::BeginPtr(ret->ret_vec_charp);
  }
  API_END();
}

int MXNDArrayLoadFromBuffer(const void* ndarray_buffer,
                            size_t size,
                            uint32_t* out_size,
                            NDArrayHandle** out_arr,
                            uint32_t* out_name_size,
                            const char*** out_names) {
  MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get();
  ret->ret_vec_str.clear();
  API_BEGIN();
  CHECK_NOTNULL(ndarray_buffer);
  std::vector<NDArray> data;
  std::vector<std::string>& names = ret->ret_vec_str;
  {
    std::unique_ptr<dmlc::MemoryFixedSizeStream> fi(
        new dmlc::MemoryFixedSizeStream(const_cast<void*>(ndarray_buffer), size));
    mxnet::NDArray::Load(fi.get(), &data, &names);
  }
  ret->ret_handles.resize(data.size());
  for (size_t i = 0; i < data.size(); ++i) {
    NDArray* ptr        = new NDArray();
    *ptr                = data[i];
    ret->ret_handles[i] = ptr;
  }
  ret->ret_vec_charp.resize(names.size());
  for (size_t i = 0; i < names.size(); ++i) {
    ret->ret_vec_charp[i] = names[i].c_str();
  }
  *out_size      = static_cast<uint32_t>(data.size());
  *out_arr       = dmlc::BeginPtr(ret->ret_handles);
  *out_name_size = static_cast<uint32_t>(names.size());
  *out_names     = dmlc::BeginPtr(ret->ret_vec_charp);
  API_END();
}

/*! \brief Delete the NDArray object referred to by handle. */
int MXNDArrayFree(NDArrayHandle handle) {
  API_BEGIN();
  NDArray* doomed = static_cast<NDArray*>(handle);
  delete doomed;
  API_END();
}

/*!
 * \brief Store the [slice_begin, slice_end) slice of the array behind handle into *ptr
 *        (via SliceWithRecord) and publish ptr through out.
 */
template <typename dtype>
void SliceArray(NDArrayHandle handle,
                dtype slice_begin,
                dtype slice_end,
                NDArray* ptr,
                NDArrayHandle* out) {
  NDArray* source = static_cast<NDArray*>(handle);
  *ptr            = source->SliceWithRecord(slice_begin, slice_end);
  *out            = ptr;
}

/*!
 * \brief Slice the array behind handle with 32-bit bounds; the result is returned as a
 *        new NDArray handle in out. The new NDArray is deleted again if an error occurs.
 */
int MXNDArraySlice(NDArrayHandle handle,
                   uint32_t slice_begin,
                   uint32_t slice_end,
                   NDArrayHandle* out) {
  NDArray* sliced = new NDArray();
  API_BEGIN();
  SliceArray<uint32_t>(handle, slice_begin, slice_end, sliced, out);
  API_END_HANDLE_ERROR(delete sliced);
}

/*!
 * \brief Slice the array behind handle with 64-bit bounds; the result is returned as a
 *        new NDArray handle in out. The new NDArray is deleted again if an error occurs.
 */
int MXNDArraySlice64(NDArrayHandle handle,
                     int64_t slice_begin,
                     int64_t slice_end,
                     NDArrayHandle* out) {
  NDArray* sliced = new NDArray();
  API_BEGIN();
  SliceArray<int64_t>(handle, slice_begin, slice_end, sliced, out);
  API_END_HANDLE_ERROR(delete sliced);
}

/*!
 * \brief Return, as a new NDArray handle, the result of AtWithRecord(idx) on the array
 *        behind handle (32-bit index). The new NDArray is deleted if an error occurs.
 */
int MXNDArrayAt(NDArrayHandle handle, uint32_t idx, NDArrayHandle* out) {
  NDArray* element = new NDArray();
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  *element        = source->AtWithRecord(idx);
  *out            = element;
  API_END_HANDLE_ERROR(delete element);
}

/*!
 * \brief Return, as a new NDArray handle, the result of AtWithRecord(idx) on the array
 *        behind handle (64-bit index). The new NDArray is deleted if an error occurs.
 */
int MXNDArrayAt64(NDArrayHandle handle, int64_t idx, NDArrayHandle* out) {
  NDArray* element = new NDArray();
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  *element        = source->AtWithRecord(idx);
  *out            = element;
  API_END_HANDLE_ERROR(delete element);
}

/*!
 * \brief Reshape the array behind handle and return the result as a new NDArray handle.
 *
 * Special values in dims: 0 copies the dimension at the same position from the source
 * shape, and a single -1 is inferred from the source size divided by the product of the
 * other dimensions.
 *
 * \param handle source array
 * \param ndim number of entries in dims
 * \param dims requested shape
 * \param out receives the handle of the reshaped array
 */
int MXNDArrayReshape(NDArrayHandle handle, int ndim, int* dims, NDArrayHandle* out) {
  NDArray* ptr = new NDArray();
  API_BEGIN();
  NDArray* arr = static_cast<NDArray*>(handle);
  mxnet::TShape new_shape(dims, dims + ndim);
  // Accumulate in 64 bits: the product of the known dimensions (and a dimension
  // copied from the source shape) may not fit in an int.
  dim_t size = 1;
  int pos    = -1;
  for (int i = 0; i < ndim; ++i) {
    dim_t dim = dims[i];
    if (dim == -1) {
      CHECK_EQ(pos, -1) << "Invalid new shape " << new_shape << ": more than one dimensions are -1";
      pos = i;
    } else {
      if (dim == 0) {
        CHECK_LT(i, arr->shape().ndim()) << "Invalid new shape " << new_shape
                                         << ": 0 dimension exceeds original shape " << arr->shape();
        dim = arr->shape()[i];
      }
      size *= dim;
      new_shape[i] = dim;
    }
  }
  if (pos >= 0) {
    // Report an error instead of dividing by zero when a known dimension is 0.
    CHECK(size != 0) << "Invalid new shape " << new_shape
                     << ": cannot infer the -1 dimension because the other dimensions "
                     << "multiply to 0";
    new_shape[pos] = arr->shape().Size() / size;
  }
  *ptr = arr->ReshapeWithRecord(new_shape);
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}

/*!
 * \brief Reshape the array behind handle to the shape inferred by
 *        mxnet::op::InferReshapeShape from dims and return it as a new NDArray handle.
 *
 * \param handle source array
 * \param ndim number of entries in dims
 * \param dims requested shape
 * \param reverse forwarded unchanged to InferReshapeShape
 * \param out receives the handle of the reshaped array
 */
int MXNDArrayReshape64(NDArrayHandle handle,
                       int ndim,
                       dim_t* dims,
                       bool reverse,
                       NDArrayHandle* out) {
  NDArray* reshaped = new NDArray();
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  mxnet::Tuple<dim_t> requested(dims, dims + ndim);
  mxnet::TShape target = mxnet::op::InferReshapeShape(requested, source->shape(), reverse);
  *reshaped            = source->ReshapeWithRecord(target);
  *out                 = reshaped;
  API_END_HANDLE_ERROR(delete reshaped);
}

/*!
 * \brief Report the storage type of the array behind handle, or kUndefinedStorage when
 *        the array is none.
 */
int MXNDArrayGetStorageType(NDArrayHandle handle, int* out_storage_type) {
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  if (source->is_none()) {
    *out_storage_type = kUndefinedStorage;
  } else {
    *out_storage_type = source->storage_type();
  }
  API_END();
}

/*!
 * \brief Write the shape of the array behind handle into the thread-local shape buffer.
 *
 * For a none array only out_dim is written (-1 when numpy shape semantics are active,
 * 0 otherwise). Otherwise out_dim receives the number of dimensions and, when that
 * number is non-negative, out_pdata points at the buffer holding the dimensions.
 *
 * \param handle array to query
 * \param out_pdata receives a pointer to the dimensions
 * \param out_dim receives the number of dimensions
 * \param ret thread-local entry whose arg_shape_buffer_ex stores the dimensions
 */
template <typename dtype>
inline void GetShape(NDArrayHandle handle,
                     const dtype** out_pdata,
                     int* out_dim,
                     MXAPIThreadLocalEntry<dtype>* ret) {
  NDArray* arr = static_cast<NDArray*>(handle);
  if (arr->is_none()) {
    if (Imperative::Get()->is_np_shape()) {
      *out_dim = -1;
    } else {
      *out_dim = 0;
    }
    return;
  }

  mxnet::TShape shape = arr->shape();
  // Handle dynamic shape in deferred compute mode
  if (!Imperative::DCInfo::IsNone(*arr) && !shape_is_known(shape) &&
      !Imperative::DCInfo::IsComputed(*arr)) {
    Imperative::DCInfo::Compute(*arr);
    shape = arr->shape();
  }

  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(shape.Size(), (int64_t{1} << 31) - 1)
        << "[Get Shape] Size of tensor you are trying to allocate is larger than "
           "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }

  if (!Imperative::Get()->is_np_shape()) {
    common::ConvertToLegacyShape(&shape);
  }

  *out_dim = shape.ndim();
  if (shape.ndim() < 0) {
    // No dimensions to report; out_pdata is left untouched.
    return;
  }
  std::vector<dtype>& dims = ret->arg_shape_buffer_ex;
  dims.resize(shape.ndim());
  mxnet::ShapeTypeCast(shape.begin(), shape.end(), dims.data());
  *out_pdata = dims.data();
}

/*! \brief Query the shape of the array behind handle with int-typed dimensions. */
int MXNDArrayGetShape(NDArrayHandle handle, int* out_dim, const int** out_pdata) {
  MXAPIThreadLocalEntry<>* tls_entry = MXAPIThreadLocalStore<>::Get();
  API_BEGIN();
  GetShape<int>(handle, out_pdata, out_dim, tls_entry);
  API_END();
}

/*! \brief Query the shape of the array behind handle with int64_t-typed dimensions. */
int MXNDArrayGetShape64(NDArrayHandle handle, int* out_dim, const int64_t** out_pdata) {
  MXAPIThreadLocalEntry<int64_t>* tls_entry = MXAPIThreadLocalStore<int64_t>::Get();
  API_BEGIN();
  GetShape<int64_t>(handle, out_pdata, out_dim, tls_entry);
  API_END();
}

/*!
 * \brief Return the raw data pointer of the array behind handle, or nullptr when the
 *        array is none. With oneDNN enabled, DNNL-formatted data is first reordered to
 *        the default layout and waited on.
 */
int MXNDArrayGetData(NDArrayHandle handle, void** out_pdata) {
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
#if MXNET_USE_ONEDNN == 1
  if (source->IsDNNLData()) {
    source->Reorder2DefaultAsync();
    source->WaitToRead();
  }
#endif
  if (source->is_none()) {
    *out_pdata = nullptr;
  } else {
    *out_pdata = source->data().dptr_;
  }
  API_END();
}

/*! \brief Store the result of ToDLPack() on the array behind handle into out_dlpack. */
int MXNDArrayToDLPack(NDArrayHandle handle, DLManagedTensorHandle* out_dlpack) {
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  *out_dlpack     = source->ToDLPack();
  API_END();
}

/*!
 * \brief Create a new NDArray from a DLPack managed tensor via NDArray::FromDLPack and
 *        return its handle in out_handle.
 *
 * \param dlpack handle of the DLManagedTensor to convert
 * \param transient_handle forwarded unchanged to NDArray::FromDLPack
 * \param out_handle receives the handle of the new NDArray
 */
int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
                        const bool transient_handle,
                        NDArrayHandle* out_handle) {
  API_BEGIN();
  DLManagedTensor* tensor = static_cast<DLManagedTensor*>(dlpack);
  *out_handle             = new NDArray(NDArray::FromDLPack(tensor, transient_handle));
  API_END();
}

/*!
 * \brief Invoke the deleter stored in a DLPack managed tensor.
 *
 * Does nothing when dlpack is null or when the tensor carries no deleter
 * (the deleter member is checked for null before it is called).
 */
int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack) {
  API_BEGIN();
  DLManagedTensor* p_dlpack = static_cast<DLManagedTensor*>(dlpack);
  if (p_dlpack != nullptr && p_dlpack->deleter != nullptr) {
    p_dlpack->deleter(p_dlpack);
  }
  API_END();
}

/*! \brief Report the dtype of the array behind handle, or -1 when the array is none. */
int MXNDArrayGetDType(NDArrayHandle handle, int* out_dtype) {
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  if (source->is_none()) {
    *out_dtype = -1;
  } else {
    *out_dtype = source->dtype();
  }
  API_END();
}

/*! \brief Report the type of the i-th aux data of the array behind handle. */
int MXNDArrayGetAuxType(NDArrayHandle handle, uint32_t i, int* out_type) {
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  *out_type       = source->aux_type(i);
  API_END();
}

/*!
 * \brief Return the i-th aux data blob as a deep copy, wrapped in a newly allocated
 * NDArray of default storage type.
 * This call blocks, so keep it out of performance critical code.
 */
int MXNDArrayGetAuxNDArray(NDArrayHandle handle, uint32_t i, NDArrayHandle* out) {
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  NDArray* aux    = new NDArray(source->aux_ndarray(i));
  *out            = aux;
  API_END();
}

/*!
 * \brief Return the data blob as a deep copy, wrapped in a newly allocated
 * NDArray of default storage type.
 * This call blocks, so keep it out of performance critical code.
 */
int MXNDArrayGetDataNDArray(NDArrayHandle handle, NDArrayHandle* out) {
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  NDArray* data   = new NDArray(source->data_ndarray());
  *out            = data;
  API_END();
}

/*!
 * \brief Report the device type and device id of the array behind handle; both are
 *        reported as 0 when the array is none.
 */
int MXNDArrayGetContext(NDArrayHandle handle, int* out_dev_type, int* out_dev_id) {
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  if (source->is_none()) {
    *out_dev_type = 0;
    *out_dev_id   = 0;
  } else {
    const Context& device = source->ctx();
    *out_dev_type         = device.dev_type;
    *out_dev_id           = device.dev_id;
  }
  API_END();
}

/*!
 * \brief Return the gradient of the array behind handle as a new NDArray handle, or
 *        nullptr when the gradient array is none.
 */
int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle* out) {
  API_BEGIN();
  NDArray* source  = static_cast<NDArray*>(handle);
  NDArray gradient = source->grad();
  if (!gradient.is_none()) {
    *out = new NDArray(gradient);
  } else {
    *out = nullptr;
  }
  API_END();
}

/*! \brief Return the result of Detach() on the array behind handle as a new NDArray handle. */
int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle* out) {
  API_BEGIN();
  NDArray* source   = static_cast<NDArray*>(handle);
  NDArray* detached = new NDArray(source->Detach());
  *out              = detached;
  API_END();
}

/*! \brief Set the fresh-out-grad flag of the array behind handle from state (cast to bool). */
int MXNDArraySetGradState(NDArrayHandle handle, int state) {
  API_BEGIN();
  const bool fresh = static_cast<bool>(state);
  static_cast<NDArray*>(handle)->set_fresh_out_grad(fresh);
  API_END();
}

/*! \brief Read the fresh-out-grad flag of the array behind handle into out. */
int MXNDArrayGetGradState(NDArrayHandle handle, int* out) {
  API_BEGIN();
  NDArray* source = static_cast<NDArray*>(handle);
  *out            = source->fresh_out_grad();
  API_END();
}

/*!
 * \brief Expose the list of registered NDArrayFunctionReg entries.
 *
 * \param out_size receives the number of entries
 * \param out_array receives a pointer to the registry's entry table
 */
int MXListFunctions(uint32_t* out_size, FunctionHandle** out_array) {
  API_BEGIN();
  auto& registered = dmlc::Registry<NDArrayFunctionReg>::List();
  *out_array       = (FunctionHandle*)(dmlc::BeginPtr(registered));  //  NOLINT(*)
  *out_size        = static_cast<uint32_t>(registered.size());
  API_END();
}

int MXGetFunction(const char* name, FunctionHandle* out) {
  API_BEGIN();
  *out = dmlc::Registry<NDArrayFunctionReg>::Find(name);
  API_END();
}

/*!
 * \brief Fetch the registry information of an NDArray function by forwarding every
 *        output argument to MXAPIGetFunctionRegInfo.
 */
int MXFuncGetInfo(FunctionHandle fun,
                  const char** name,
                  const char** description,
                  uint32_t* num_args,
                  const char*** arg_names,
                  const char*** arg_type_infos,
                  const char*** arg_descriptions,
                  const char** return_type) {
  const NDArrayFunctionReg* entry = static_cast<const NDArrayFunctionReg*>(fun);
  return MXAPIGetFunctionRegInfo(
      entry, name, description, num_args, arg_names, arg_type_infos, arg_descriptions, return_type);
}

/*!
 * \brief Copy the argument counts and the type mask of a registered NDArray function
 *        into the output parameters.
 */
int MXFuncDescribe(FunctionHandle fun,
                   uint32_t* num_use_vars,
                   uint32_t* num_scalars,
                   uint32_t* num_mutate_vars,
                   int* type_mask) {
  API_BEGIN();
  const NDArrayFunctionReg* entry = static_cast<const NDArrayFunctionReg*>(fun);
  *num_use_vars                   = entry->num_use_vars;
  *num_scalars                    = entry->num_scalars;
  *num_mutate_vars                = entry->num_mutate_vars;
  *type_mask                      = entry->type_mask;
  API_END();
}

/*!
 * \brief Call the body of a registered NDArray function.
 *
 * The handle arrays are reinterpreted as NDArray* arrays; all other arguments are
 * passed through unchanged.
 */
int MXFuncInvoke(FunctionHandle fun,
                 NDArrayHandle* use_vars,
                 float* scalar_args,
                 NDArrayHandle* mutate_vars,
                 int num_params,
                 char** param_keys,
                 char** param_vals) {
  API_BEGIN();
  const NDArrayFunctionReg* entry = static_cast<const NDArrayFunctionReg*>(fun);
  NDArray** inputs                = (NDArray**)(use_vars);     //  NOLINT(*)
  NDArray** mutated               = (NDArray**)(mutate_vars);  //  NOLINT(*)
  entry->body(inputs, scalar_args, mutated, num_params, param_keys, param_vals);
  API_END();
}

//--------------------------------------------
// Part 5: IO Interface
//--------------------------------------------
/*!
 * \brief Expose the list of registered DataIteratorReg entries.
 *
 * \param out_size receives the number of entries
 * \param out_array receives a pointer to the registry's entry table
 */
int MXListDataIters(uint32_t* out_size, DataIterCreator** out_array) {
  API_BEGIN();
  auto& registered = dmlc::Registry<DataIteratorReg>::List();
  *out_array       = (DataIterCreator*)(dmlc::BeginPtr(registered));  //  NOLINT(*)
  *out_size        = static_cast<uint32_t>(registered.size());
  API_END();
}

/*!
 * \brief Fetch the registry information of a data iterator creator through
 *        MXAPIGetFunctionRegInfo; no return type is requested (nullptr is passed).
 */
int MXDataIterGetIterInfo(DataIterCreator creator,
                          const char** name,
                          const char** description,
                          uint32_t* num_args,
                          const char*** arg_names,
                          const char*** arg_type_infos,
                          const char*** arg_descriptions) {
  DataIteratorReg* entry = static_cast<DataIteratorReg*>(creator);
  return MXAPIGetFunctionRegInfo(entry,
                                 name,
                                 description,
                                 num_args,
                                 arg_names,
                                 arg_type_infos,
                                 arg_descriptions,
                                 nullptr);
}

/*!
 * \brief Create a data iterator from a registry entry and initialize it with the given
 *        key/value parameters. The iterator is deleted again if an error occurs.
 *
 * \param creator registry entry of the iterator to create
 * \param num_param number of entries in keys and vals
 * \param keys parameter names
 * \param vals parameter values
 * \param out receives the handle of the new iterator
 */
int MXDataIterCreateIter(DataIterCreator creator,
                         uint32_t num_param,
                         const char** keys,
                         const char** vals,
                         DataIterHandle* out) {
  IIterator<DataBatch>* created = nullptr;
  API_BEGIN();
  DataIteratorReg* entry = static_cast<DataIteratorReg*>(creator);
  created                = entry->body();
  std::vector<std::pair<std::string, std::string>> params;
  for (uint32_t idx = 0; idx != num_param; ++idx) {
    params.emplace_back(std::string(keys[idx]), std::string(vals[idx]));
  }
  created->Init(params);
  *out = created;
  API_END_HANDLE_ERROR(delete created);
}

/*! \brief Delete the data iterator referred to by handle. */
int MXDataIterFree(DataIterHandle handle) {
  API_BEGIN();
  IIterator<DataBatch>* doomed = static_cast<IIterator<DataBatch>*>(handle);
  delete doomed;
  API_END();
}

/*! \brief Call BeforeFirst() on the data iterator referred to by handle. */
int MXDataIterBeforeFirst(DataIterHandle handle) {
  API_BEGIN();
  IIterator<DataBatch>* iter = static_cast<IIterator<DataBatch>*>(handle);
  iter->BeforeFirst();
  API_END();
}

/*! \brief Store the value of GetLenHint() of the data iterator into len. */
int MXDataIterGetLenHint(DataIterHandle handle, int64_t* len) {
  API_BEGIN();
  IIterator<DataBatch>* iter = static_cast<IIterator<DataBatch>*>(handle);
  *len                       = iter->GetLenHint();
  API_END();
}

/*! \brief Advance the data iterator and store the result of Next() into out. */
int MXDataIterNext(DataIterHandle handle, int* out) {
  API_BEGIN();
  IIterator<DataBatch>* iter = static_cast<IIterator<DataBatch>*>(handle);
  *out                       = iter->Next();
  API_END();
}

/*!
 * \brief Return the label of the current batch (db.data[1]) as a new NDArray handle.
 *
 * When the batch holds fewer than two arrays, or the label has no elements, a
 * placeholder array of shape (1) on CPU(0) is returned instead. A label whose second
 * dimension is 1 is reshaped to one dimension.
 *
 * \param handle data iterator to read from
 * \param out receives the handle of the label array
 */
int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle* out) {
  API_BEGIN();
  const DataBatch& db = static_cast<IIterator<DataBatch>*>(handle)->Value();
  const bool no_label = db.data.size() < 2U;
  // temp hack to make label 1D
  // TODO(tianjun) make label 1D when label_width=0
  mxnet::TShape shape = no_label ? TShape({1}) : db.data[1].shape();
  // Build the label on the stack first: the heap allocation then happens only after
  // every step that may throw, so nothing is leaked on failure.
  NDArray label;
  if (no_label || shape.Size() < 1) {
    // it's possible that label is not available and not required
    // but we need to bypass the invalid copy
    label = NDArray(TShape({1}), mxnet::Context::CPU(0));
  } else if (shape.ndim() > 1 && shape[1] == 1) {
    label = db.data[1].Reshape(mshadow::Shape1(shape[0]));
  } else {
    label = db.data[1];
  }
  *out = new NDArray(label);
  API_END();
}

/*!
 * \brief Copy every array of the current batch into output NDArrays.
 *
 * When *outputs is null, new NDArrays are allocated, *num_outputs is set to the number
 * of arrays in the batch and *outputs is pointed at the thread-local handle table.
 * Otherwise the caller-provided handles are assigned to, and *num_outputs must equal
 * the number of arrays in the batch.
 */
int MXDataIterGetItems(DataIterHandle handle, int* num_outputs, NDArrayHandle** outputs) {
  MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get();
  API_BEGIN();
  const DataBatch& batch      = static_cast<IIterator<DataBatch>*>(handle)->Value();
  const bool allocate_outputs = (*outputs == nullptr);
  std::vector<NDArray*> targets;
  targets.reserve(batch.data.size());
  if (allocate_outputs) {
    *num_outputs = batch.data.size();
    for (int idx = 0; idx < *num_outputs; ++idx) {
      targets.push_back(new NDArray());
    }
  } else {
    CHECK_EQ(*num_outputs, batch.data.size())
        << "MXDataIterGetItems expects " << batch.data.size() << " outputs, but " << *num_outputs
        << " was given.";
    for (int idx = 0; idx < *num_outputs; ++idx) {
      targets.push_back(reinterpret_cast<NDArray*>((*outputs)[idx]));
    }
  }

  // copy outputs
  for (int idx = 0; idx < *num_outputs; ++idx) {
    *targets[idx] = batch.data[idx];
  }

  if (allocate_outputs) {
    // Hand the freshly allocated arrays back through the thread-local table.
    ret->ret_handles.clear();
    ret->ret_handles.reserve(*num_outputs);
    for (int idx = 0; idx < *num_outputs; ++idx) {
      ret->ret_handles.push_back(targets[idx]);
    }
    *outputs = dmlc::BeginPtr(ret->ret_handles);
  }
  API_END();
}

/*!
 * \brief Expose the index vector of the current batch.
 *
 * \param handle data iterator to read from
 * \param out_index receives a pointer to the batch's own index storage
 * \param out_size receives the number of indices
 */
int MXDataIterGetIndex(DataIterHandle handle, uint64_t** out_index, uint64_t* out_size) {
  API_BEGIN();
  IIterator<DataBatch>* iter = static_cast<IIterator<DataBatch>*>(handle);
  const DataBatch& batch     = iter->Value();
  *out_size                  = batch.index.size();
  *out_index                 = const_cast<uint64_t*>(batch.index.data());
  API_END();
}

/*! \brief Return the first array of the current batch (db.data[0]) as a new NDArray handle. */
int MXDataIterGetData(DataIterHandle handle, NDArrayHandle* out) {
  API_BEGIN();
  IIterator<DataBatch>* iter = static_cast<IIterator<DataBatch>*>(handle);
  const DataBatch& batch     = iter->Value();
  NDArray* first             = new NDArray();
  *first                     = batch.data[0];
  *out                       = first;
  API_END();
}

/*! \brief Store the num_batch_padd field of the current batch into pad. */
int MXDataIterGetPadNum(DataIterHandle handle, int* pad) {
  API_BEGIN();
  IIterator<DataBatch>* iter = static_cast<IIterator<DataBatch>*>(handle);
  const DataBatch& batch     = iter->Value();
  *pad                       = batch.num_batch_padd;
  API_END();
}

/*!
 * \brief Expose the list of registered DatasetReg entries.
 *
 * \param out_size receives the number of entries
 * \param out_array receives a pointer to the registry's entry table
 */
int MXListDatasets(uint32_t* out_size, DatasetCreator** out_array) {
  API_BEGIN();
  auto& registered = dmlc::Registry<DatasetReg>::List();
  *out_array       = (DatasetCreator*)(dmlc::BeginPtr(registered));  //  NOLINT(*)
  *out_size        = static_cast<uint32_t>(registered.size());
  API_END();
}

/*!
 * \brief Create a dataset from a registry entry and the given key/value parameters.
 *
 * The returned handle is a heap-allocated std::shared_ptr<Dataset> wrapping the
 * dataset produced by the entry's body.
 *
 * \param handle registry entry of the dataset to create
 * \param num_param number of entries in keys and vals
 * \param keys parameter names
 * \param vals parameter values
 * \param out receives the handle of the new dataset
 */
int MXDatasetCreateDataset(DatasetCreator handle,
                           uint32_t num_param,
                           const char** keys,
                           const char** vals,
                           DatasetHandle* out) {
  Dataset* created = nullptr;
  API_BEGIN();
  DatasetReg* entry = static_cast<DatasetReg*>(handle);
  std::vector<std::pair<std::string, std::string>> params;
  for (uint32_t idx = 0; idx != num_param; ++idx) {
    params.emplace_back(std::string(keys[idx]), std::string(vals[idx]));
  }
  created = entry->body(params);
  *out    = new std::shared_ptr<Dataset>(created);
  API_END_HANDLE_ERROR(delete created);
}

/*!
 * \brief Fetch the registry information of a dataset creator through
 *        MXAPIGetFunctionRegInfo; no return type is requested (nullptr is passed).
 */
int MXDatasetGetDatasetInfo(DatasetCreator creator,
                            const char** name,
                            const char** description,
                            uint32_t* num_args,
                            const char*** arg_names,
                            const char*** arg_type_infos,
                            const char*** arg_descriptions) {
  DatasetReg* entry = static_cast<DatasetReg*>(creator);
  return MXAPIGetFunctionRegInfo(entry,
                                 name,
                                 description,
                                 num_args,
                                 arg_names,
                                 arg_type_infos,
                                 arg_descriptions,
                                 nullptr);
}

/*! \brief Delete the std::shared_ptr<Dataset> object referred to by handle. */
int MXDatasetFree(DatasetHandle handle) {
  API_BEGIN();
  std::shared_ptr<Dataset>* holder = static_cast<std::shared_ptr<Dataset>*>(handle);
  delete holder;
  API_END();
}

/*! \brief Store the value of GetLen() of the dataset behind handle into out. */
int MXDatasetGetLen(DatasetHandle handle, uint64_t* out) {
  API_BEGIN();
  std::shared_ptr<Dataset>* holder = static_cast<std::shared_ptr<Dataset>*>(handle);
  const uint64_t length            = (*holder)->GetLen();
  *out                             = length;
  API_END();
}

/*!
 * \brief Fetch the item at index from a dataset and copy its arrays into output NDArrays.
 *
 * When *outputs is null, new NDArrays are allocated, *num_outputs is set to the number
 * of arrays in the item and *outputs is pointed at the thread-local handle table.
 * Otherwise the caller-provided handles are assigned to, and *num_outputs must equal
 * the number of arrays in the item.
 */
int MXDatasetGetItems(DatasetHandle handle,
                      uint64_t index,
                      int* num_outputs,
                      NDArrayHandle** outputs) {
  MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get();
  API_BEGIN();
  std::vector<NDArray> item;
  CHECK((*static_cast<std::shared_ptr<Dataset>*>(handle))->GetItem(index, &item))
      << "Error getting item at index: " << index;
  const bool allocate_outputs = (*outputs == nullptr);
  std::vector<NDArray*> targets;
  targets.reserve(item.size());
  if (allocate_outputs) {
    *num_outputs = item.size();
    for (int idx = 0; idx < *num_outputs; ++idx) {
      targets.push_back(new NDArray());
    }
  } else {
    CHECK_EQ(*num_outputs, item.size()) << "MXDatasetGetItems expects " << item.size()
                                        << " outputs, but " << *num_outputs << " was given.";
    for (int idx = 0; idx < *num_outputs; ++idx) {
      targets.push_back(reinterpret_cast<NDArray*>((*outputs)[idx]));
    }
  }
  // copy ndarrays
  for (int idx = 0; idx < *num_outputs; ++idx) {
    *(targets[idx]) = item[idx];
  }

  if (allocate_outputs) {
    // Hand the freshly allocated arrays back through the thread-local table.
    ret->ret_handles.clear();
    ret->ret_handles.reserve(*num_outputs);
    for (int idx = 0; idx < *num_outputs; ++idx) {
      ret->ret_handles.push_back(targets[idx]);
    }
    *outputs = dmlc::BeginPtr(ret->ret_handles);
  }
  API_END();
}

/*!
 * \brief Expose the list of registered BatchifyFunctionReg entries.
 *
 * \param out_size receives the number of entries
 * \param out_array receives a pointer to the registry's entry table
 */
int MXListBatchifyFunctions(uint32_t* out_size, BatchifyFunctionCreator** out_array) {
  API_BEGIN();
  auto& registered = dmlc::Registry<BatchifyFunctionReg>::List();
  *out_array       = (BatchifyFunctionCreator*)(dmlc::BeginPtr(registered));  //  NOLINT(*)
  *out_size        = static_cast<uint32_t>(registered.size());
  API_END();
}

/*!
 * \brief Create a batchify function from a registry entry and the given key/value
 *        parameters; the returned handle is a heap-allocated BatchifyFunctionPtr.
 *
 * \param handle registry entry of the batchify function to create
 * \param num_param number of entries in keys and vals
 * \param keys parameter names
 * \param vals parameter values
 * \param out receives the handle of the new batchify function
 */
int MXBatchifyFunctionCreateFunction(BatchifyFunctionCreator handle,
                                     uint32_t num_param,
                                     const char** keys,
                                     const char** vals,
                                     BatchifyFunctionHandle* out) {
  BatchifyFunction* created = nullptr;
  API_BEGIN();
  BatchifyFunctionReg* entry = static_cast<BatchifyFunctionReg*>(handle);
  std::vector<std::pair<std::string, std::string>> params;
  for (uint32_t idx = 0; idx != num_param; ++idx) {
    params.emplace_back(std::string(keys[idx]), std::string(vals[idx]));
  }
  created = entry->body(params);
  *out    = new BatchifyFunctionPtr(created);
  API_END_HANDLE_ERROR(delete created);
}

/*!
 * \brief Fetch the registry information of a batchify function creator through
 *        MXAPIGetFunctionRegInfo; no return type is requested (nullptr is passed).
 */
int MXBatchifyFunctionGetFunctionInfo(BatchifyFunctionCreator creator,
                                      const char** name,
                                      const char** description,
                                      uint32_t* num_args,
                                      const char*** arg_names,
                                      const char*** arg_type_infos,
                                      const char*** arg_descriptions) {
  BatchifyFunctionReg* entry = static_cast<BatchifyFunctionReg*>(creator);
  return MXAPIGetFunctionRegInfo(entry,
                                 name,
                                 description,
                                 num_args,
                                 arg_names,
                                 arg_type_infos,
                                 arg_descriptions,
                                 nullptr);
}
/*!
 * \brief Run a batchify function over batch_size samples of num_output arrays each.
 *
 * inputs is read as a flat table of batch_size * num_output handles, sample by sample.
 * When *outputs is null, the first num_output results are returned as newly allocated
 * NDArrays through the thread-local handle table; otherwise the results are assigned to
 * the caller-provided handles, and the function must produce exactly num_output arrays.
 *
 * \param handle batchify function to invoke
 * \param batch_size number of samples; must be positive
 * \param num_output number of arrays per sample; must be positive
 * \param inputs flat table of input array handles
 * \param outputs in/out table of output array handles
 */
int MXBatchifyFunctionInvoke(BatchifyFunctionHandle handle,
                             int batch_size,
                             int num_output,
                             NDArrayHandle* inputs,
                             NDArrayHandle** outputs) {
  MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get();
  API_BEGIN();
  CHECK_GT(batch_size, 0);
  CHECK_GT(num_output, 0);
  std::vector<std::vector<NDArray>> ndinputs;
  ndinputs.reserve(batch_size);
  int pos = 0;
  for (int i = 0; i < batch_size; ++i) {
    std::vector<NDArray> sample;
    sample.reserve(num_output);
    for (int j = 0; j < num_output; ++j) {
      sample.emplace_back(*reinterpret_cast<NDArray*>(inputs[pos++]));
      sample.back().WaitToRead();
    }
    // Move the sample in instead of copying the whole vector.
    ndinputs.emplace_back(std::move(sample));
  }
  std::vector<NDArray> res;
  CHECK((*static_cast<BatchifyFunctionPtr*>(handle))->Batchify(ndinputs, &res))
      << "Error call batchify with " << ndinputs.size() << " inputs";

  if (*outputs == nullptr) {
    // Guard the res[i] reads below: the function may have produced fewer arrays
    // than the caller asked for.
    CHECK_GE(res.size(), static_cast<size_t>(num_output))
        << "MXBatchifyFunctionInvoke produced " << res.size() << " outputs, but " << num_output
        << " were requested.";
    ret->ret_handles.clear();
    ret->ret_handles.reserve(num_output);
    for (int i = 0; i < num_output; ++i) {
      ret->ret_handles.push_back(new NDArray(res[i]));
    }
    *outputs = dmlc::BeginPtr(ret->ret_handles);
  } else {
    CHECK_EQ(num_output, res.size()) << "MXBatchifyFunctionInvoke expects " << res.size()
                                     << " outputs, but " << num_output << " was given.";
    // copy ndarrays
    for (int i = 0; i < num_output; ++i) {
      *reinterpret_cast<NDArray*>((*outputs)[i]) = res[i];
    }
  }
  API_END();
}

/*! \brief Delete the BatchifyFunctionPtr object referred to by handle. */
int MXBatchifyFunctionFree(BatchifyFunctionHandle handle) {
  API_BEGIN();
  BatchifyFunctionPtr* holder = static_cast<BatchifyFunctionPtr*>(handle);
  delete holder;
  API_END();
}
//--------------------------------------------
// Part 6: basic KVStore interface
//--------------------------------------------

int MXKVStoreCreate(const char* type, KVStoreHandle* out) {
  API_BEGIN();
  *out = KVStore::Create(type);
  API_END();
}

/*!
 * \brief Pass gradient compression parameters to a KVStore.
 *
 * \param handle KVStore to configure
 * \param num_params number of entries in keys and vals
 * \param keys parameter names
 * \param vals parameter values
 */
int MXKVStoreSetGradientCompression(KVStoreHandle handle,
                                    uint32_t num_params,
                                    const char** keys,
                                    const char** vals) {
  API_BEGIN();
  std::vector<std::pair<std::string, std::string>> settings;
  for (uint32_t idx = 0; idx != num_params; ++idx) {
    settings.emplace_back(std::string(keys[idx]), std::string(vals[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->SetGradientCompression(settings);
  API_END();
}

/*! \brief Delete the KVStore referred to by handle. */
int MXKVStoreFree(KVStoreHandle handle) {
  API_BEGIN();
  KVStore* doomed = static_cast<KVStore*>(handle);
  delete doomed;
  API_END();
}

/*!
 * \brief Call KVStore::Init with num integer keys and copies of the matching arrays.
 *
 * \param handle KVStore to initialize
 * \param num number of entries in keys and vals
 * \param keys integer keys
 * \param vals handles of the arrays to pass, one per key
 */
int MXKVStoreInit(KVStoreHandle handle, uint32_t num, const int* keys, NDArrayHandle* vals) {
  API_BEGIN();
  std::vector<int> key_list;
  std::vector<NDArray> value_list;
  key_list.reserve(num);
  value_list.reserve(num);
  for (uint32_t idx = 0; idx != num; ++idx) {
    key_list.push_back(keys[idx]);
    value_list.push_back(*static_cast<NDArray*>(vals[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->Init(key_list, value_list);
  API_END();
}

/*!
 * \brief Call KVStore::Init with num string keys and copies of the matching arrays.
 *
 * \param handle KVStore to initialize
 * \param num number of entries in keys and vals
 * \param keys string keys
 * \param vals handles of the arrays to pass, one per key
 */
int MXKVStoreInitEx(KVStoreHandle handle, uint32_t num, const char** keys, NDArrayHandle* vals) {
  API_BEGIN();
  std::vector<std::string> key_list;
  std::vector<NDArray> value_list;
  key_list.reserve(num);
  value_list.reserve(num);
  for (uint32_t idx = 0; idx != num; ++idx) {
    key_list.emplace_back(keys[idx]);
    value_list.push_back(*static_cast<NDArray*>(vals[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->Init(key_list, value_list);
  API_END();
}

/*!
 * \brief Call KVStore::Push with num integer keys and copies of the matching arrays.
 *
 * \param handle KVStore to push to
 * \param num number of entries in keys and vals
 * \param keys integer keys
 * \param vals handles of the arrays to push, one per key
 * \param priority forwarded unchanged to KVStore::Push
 */
int MXKVStorePush(KVStoreHandle handle,
                  uint32_t num,
                  const int* keys,
                  NDArrayHandle* vals,
                  int priority) {
  API_BEGIN();
  std::vector<int> key_list;
  std::vector<NDArray> value_list;
  key_list.reserve(num);
  value_list.reserve(num);
  for (uint32_t idx = 0; idx != num; ++idx) {
    key_list.push_back(keys[idx]);
    value_list.push_back(*static_cast<NDArray*>(vals[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->Push(key_list, value_list, priority);
  API_END();
}

/*!
 * \brief Call KVStore::Push with num string keys and copies of the matching arrays.
 *
 * \param handle KVStore to push to
 * \param num number of entries in keys and vals
 * \param keys string keys
 * \param vals handles of the arrays to push, one per key
 * \param priority forwarded unchanged to KVStore::Push
 */
int MXKVStorePushEx(KVStoreHandle handle,
                    uint32_t num,
                    const char** keys,
                    NDArrayHandle* vals,
                    int priority) {
  API_BEGIN();
  std::vector<std::string> key_list;
  std::vector<NDArray> value_list;
  key_list.reserve(num);
  value_list.reserve(num);
  for (uint32_t idx = 0; idx != num; ++idx) {
    key_list.emplace_back(keys[idx]);
    value_list.push_back(*static_cast<NDArray*>(vals[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->Push(key_list, value_list, priority);
  API_END();
}

/*!
 * \brief Call KVStore::Pull with num integer keys and the matching destination arrays;
 *        the last argument of Pull is always passed as true.
 *
 * \param handle KVStore to pull from
 * \param num number of entries in keys and vals
 * \param keys integer keys
 * \param vals handles of the destination arrays, one per key
 * \param priority forwarded unchanged to KVStore::Pull
 */
int MXKVStorePull(KVStoreHandle handle,
                  uint32_t num,
                  const int* keys,
                  NDArrayHandle* vals,
                  int priority) {
  API_BEGIN();
  std::vector<int> key_list;
  std::vector<NDArray*> target_list;
  key_list.reserve(num);
  target_list.reserve(num);
  for (uint32_t idx = 0; idx != num; ++idx) {
    key_list.push_back(keys[idx]);
    target_list.push_back(static_cast<NDArray*>(vals[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->Pull(key_list, target_list, priority, true);
  API_END();
}

/*!
 * \brief Call KVStore::Pull with num string keys and the matching destination arrays;
 *        the last argument of Pull is always passed as true.
 *
 * \param handle KVStore to pull from
 * \param num number of entries in keys and vals
 * \param keys string keys
 * \param vals handles of the destination arrays, one per key
 * \param priority forwarded unchanged to KVStore::Pull
 */
int MXKVStorePullEx(KVStoreHandle handle,
                    uint32_t num,
                    const char** keys,
                    NDArrayHandle* vals,
                    int priority) {
  API_BEGIN();
  std::vector<std::string> key_list;
  std::vector<NDArray*> target_list;
  key_list.reserve(num);
  target_list.reserve(num);
  for (uint32_t idx = 0; idx != num; ++idx) {
    key_list.emplace_back(keys[idx]);
    target_list.push_back(static_cast<NDArray*>(vals[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->Pull(key_list, target_list, priority, true);
  API_END();
}

/*!
 * \brief Call KVStore::Broadcast with integer keys.
 *
 * \param handle KVStore to use
 * \param vnum number of entries in vkeys and vals
 * \param vkeys integer keys of the input values
 * \param onum number of entries in okeys and outs
 * \param okeys integer keys of the outputs
 * \param vals handles of the input arrays (copied), one per value key
 * \param outs handles of the output arrays, one per output key
 * \param priority forwarded unchanged to KVStore::Broadcast
 */
int MXKVStoreBroadcast(KVStoreHandle handle,
                       mx_uint vnum,
                       const int* vkeys,
                       mx_uint onum,
                       const int* okeys,
                       NDArrayHandle* vals,
                       NDArrayHandle* outs,
                       int priority) {
  API_BEGIN();
  // Input side: keys plus copies of the source arrays.
  std::vector<int> value_keys;
  std::vector<NDArray> value_list;
  value_keys.reserve(vnum);
  value_list.reserve(vnum);
  for (mx_uint idx = 0; idx != vnum; ++idx) {
    value_keys.push_back(vkeys[idx]);
    value_list.push_back(*static_cast<NDArray*>(vals[idx]));
  }
  // Output side: keys plus pointers to the destination arrays.
  std::vector<int> output_keys;
  std::vector<NDArray*> output_list;
  output_keys.reserve(onum);
  output_list.reserve(onum);
  for (mx_uint idx = 0; idx != onum; ++idx) {
    output_keys.push_back(okeys[idx]);
    output_list.push_back(static_cast<NDArray*>(outs[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->Broadcast(value_keys, output_keys, value_list, output_list, priority);
  API_END();
}

/*!
 * \brief Call KVStore::Broadcast with string keys.
 *
 * \param handle KVStore to use
 * \param vnum number of entries in vkeys and vals
 * \param vkeys string keys of the input values
 * \param onum number of entries in okeys and outs
 * \param okeys string keys of the outputs
 * \param vals handles of the input arrays (copied), one per value key
 * \param outs handles of the output arrays, one per output key
 * \param priority forwarded unchanged to KVStore::Broadcast
 */
int MXKVStoreBroadcastEx(KVStoreHandle handle,
                         mx_uint vnum,
                         const char** vkeys,
                         mx_uint onum,
                         const char** okeys,
                         NDArrayHandle* vals,
                         NDArrayHandle* outs,
                         int priority) {
  API_BEGIN();
  // Input side: keys plus copies of the source arrays.
  std::vector<std::string> value_keys;
  std::vector<NDArray> value_list;
  value_keys.reserve(vnum);
  value_list.reserve(vnum);
  for (mx_uint idx = 0; idx != vnum; ++idx) {
    value_keys.emplace_back(vkeys[idx]);
    value_list.push_back(*static_cast<NDArray*>(vals[idx]));
  }
  // Output side: keys plus pointers to the destination arrays.
  std::vector<std::string> output_keys;
  std::vector<NDArray*> output_list;
  output_keys.reserve(onum);
  output_list.reserve(onum);
  for (mx_uint idx = 0; idx != onum; ++idx) {
    output_keys.emplace_back(okeys[idx]);
    output_list.push_back(static_cast<NDArray*>(outs[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->Broadcast(value_keys, output_keys, value_list, output_list, priority);
  API_END();
}

/*!
 * \brief Forward a combined push/pull request with integer keys to the KVStore.
 * \param handle KVStore handle.
 * \param vnum number of entries in vkeys / vals.
 * \param vkeys integer keys paired with vals.
 * \param onum number of entries in okeys / outs.
 * \param okeys integer keys paired with outs.
 * \param vals input NDArray handles; each array is copied by value into the request.
 * \param outs output NDArray handles; passed on as pointers.
 * \param priority priority forwarded to KVStore::PushPull.
 * \return the status produced by API_END().
 */
int MXKVStorePushPull(KVStoreHandle handle,
                      mx_uint vnum,
                      const int* vkeys,
                      mx_uint onum,
                      const int* okeys,
                      NDArrayHandle* vals,
                      NDArrayHandle* outs,
                      int priority) {
  API_BEGIN();
  // Integer keys can be copied straight out of the caller's arrays.
  std::vector<int> push_keys(vkeys, vkeys + vnum);
  std::vector<int> pull_keys(okeys, okeys + onum);

  std::vector<NDArray> push_arrays;
  push_arrays.reserve(vnum);
  for (mx_uint idx = 0; idx < vnum; ++idx) {
    push_arrays.push_back(*static_cast<NDArray*>(vals[idx]));
  }

  std::vector<NDArray*> pull_arrays;
  pull_arrays.reserve(onum);
  for (mx_uint idx = 0; idx < onum; ++idx) {
    pull_arrays.push_back(static_cast<NDArray*>(outs[idx]));
  }

  KVStore* store = static_cast<KVStore*>(handle);
  store->PushPull(push_keys, pull_keys, push_arrays, pull_arrays, priority);
  API_END();
}

/*!
 * \brief Forward a combined push/pull request with string keys to the KVStore.
 * \param handle KVStore handle.
 * \param vnum number of entries in vkeys / vals.
 * \param vkeys C-string keys paired with vals.
 * \param onum number of entries in okeys / outs.
 * \param okeys C-string keys paired with outs.
 * \param vals input NDArray handles; each array is copied by value into the request.
 * \param outs output NDArray handles; passed on as pointers.
 * \param priority priority forwarded to KVStore::PushPull.
 * \return the status produced by API_END().
 */
int MXKVStorePushPullEx(KVStoreHandle handle,
                        mx_uint vnum,
                        const char** vkeys,
                        mx_uint onum,
                        const char** okeys,
                        NDArrayHandle* vals,
                        NDArrayHandle* outs,
                        int priority) {
  API_BEGIN();
  std::vector<std::string> push_keys;
  std::vector<NDArray> push_arrays;
  push_keys.reserve(vnum);
  push_arrays.reserve(vnum);
  for (mx_uint idx = 0; idx < vnum; ++idx) {
    push_keys.emplace_back(vkeys[idx]);
    push_arrays.push_back(*static_cast<NDArray*>(vals[idx]));
  }

  std::vector<std::string> pull_keys;
  std::vector<NDArray*> pull_arrays;
  pull_keys.reserve(onum);
  pull_arrays.reserve(onum);
  for (mx_uint idx = 0; idx < onum; ++idx) {
    pull_keys.emplace_back(okeys[idx]);
    pull_arrays.push_back(static_cast<NDArray*>(outs[idx]));
  }

  KVStore* store = static_cast<KVStore*>(handle);
  store->PushPull(push_keys, pull_keys, push_arrays, pull_arrays, priority);
  API_END();
}

/*!
 * \brief Pull values for integer keys, passing the caller's ignore_sparse flag through.
 * \param handle KVStore handle.
 * \param num number of entries in keys / vals.
 * \param keys integer keys to pull.
 * \param vals destination NDArray handles, one per key.
 * \param priority priority forwarded to KVStore::Pull.
 * \param ignore_sparse forwarded unchanged as the last argument of KVStore::Pull.
 * \return the status produced by API_END().
 */
int MXKVStorePullWithSparse(KVStoreHandle handle,
                            uint32_t num,
                            const int* keys,
                            NDArrayHandle* vals,
                            int priority,
                            bool ignore_sparse) {
  API_BEGIN();
  std::vector<int> key_list(keys, keys + num);
  std::vector<NDArray*> targets;
  targets.reserve(num);
  for (uint32_t idx = 0; idx < num; ++idx) {
    targets.push_back(static_cast<NDArray*>(vals[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->Pull(key_list, targets, priority, ignore_sparse);
  API_END();
}

/*!
 * \brief Pull values for string keys, passing the caller's ignore_sparse flag through.
 * \param handle KVStore handle.
 * \param num number of entries in keys / vals.
 * \param keys C-string keys to pull.
 * \param vals destination NDArray handles, one per key.
 * \param priority priority forwarded to KVStore::Pull.
 * \param ignore_sparse forwarded unchanged as the last argument of KVStore::Pull.
 * \return the status produced by API_END().
 */
int MXKVStorePullWithSparseEx(KVStoreHandle handle,
                              uint32_t num,
                              const char** keys,
                              NDArrayHandle* vals,
                              int priority,
                              bool ignore_sparse) {
  API_BEGIN();
  std::vector<std::string> key_list;
  std::vector<NDArray*> targets;
  key_list.reserve(num);
  targets.reserve(num);
  for (uint32_t idx = 0; idx < num; ++idx) {
    key_list.emplace_back(keys[idx]);
    targets.push_back(static_cast<NDArray*>(vals[idx]));
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->Pull(key_list, targets, priority, ignore_sparse);
  API_END();
}

/*!
 * \brief Pull selected rows for integer keys.
 * \param handle KVStore handle.
 * \param num number of entries in keys / vals / row_ids.
 * \param keys integer keys to pull.
 * \param vals destination NDArray handles, one per key.
 * \param row_ids NDArray handles paired with vals; each array is copied by value.
 * \param priority priority forwarded to KVStore::PullRowSparse.
 * \return the status produced by API_END().
 */
int MXKVStorePullRowSparse(KVStoreHandle handle,
                           uint32_t num,
                           const int* keys,
                           NDArrayHandle* vals,
                           const NDArrayHandle* row_ids,
                           int priority) {
  API_BEGIN();
  std::vector<int> key_list(keys, keys + num);
  std::vector<std::pair<NDArray*, NDArray>> targets;
  targets.reserve(num);
  for (uint32_t idx = 0; idx < num; ++idx) {
    NDArray* dst        = static_cast<NDArray*>(vals[idx]);
    const NDArray& rows = *static_cast<NDArray*>(row_ids[idx]);
    targets.emplace_back(dst, rows);
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->PullRowSparse(key_list, targets, priority);
  API_END();
}

/*!
 * \brief Pull selected rows for string keys.
 * \param handle KVStore handle.
 * \param num number of entries in keys / vals / row_ids.
 * \param keys C-string keys to pull.
 * \param vals destination NDArray handles, one per key.
 * \param row_ids NDArray handles paired with vals; each array is copied by value.
 * \param priority priority forwarded to KVStore::PullRowSparse.
 * \return the status produced by API_END().
 */
int MXKVStorePullRowSparseEx(KVStoreHandle handle,
                             uint32_t num,
                             const char** keys,
                             NDArrayHandle* vals,
                             const NDArrayHandle* row_ids,
                             int priority) {
  API_BEGIN();
  std::vector<std::string> key_list;
  std::vector<std::pair<NDArray*, NDArray>> targets;
  key_list.reserve(num);
  targets.reserve(num);
  for (uint32_t idx = 0; idx < num; ++idx) {
    key_list.emplace_back(keys[idx]);
    NDArray* dst        = static_cast<NDArray*>(vals[idx]);
    const NDArray& rows = *static_cast<NDArray*>(row_ids[idx]);
    targets.emplace_back(dst, rows);
  }
  KVStore* store = static_cast<KVStore*>(handle);
  store->PullRowSparse(key_list, targets, priority);
  API_END();
}

/*!
 * \brief Install an integer-key updater on the KVStore that calls the C callback.
 *
 * On every invocation the wrapper heap-allocates a copy of the received array
 * and of the local array and passes both pointers to the callback. The wrapper
 * never deletes them.
 * NOTE(review): presumably the callback takes ownership of both handles - confirm.
 *
 * \param handle KVStore handle.
 * \param updater C callback to invoke.
 * \param updater_handle opaque pointer passed back to the callback unchanged.
 */
void MXKVStoreSetUpdaterImpl(KVStoreHandle handle, MXKVStoreUpdater updater, void* updater_handle) {
  // The explicit std::function type selects the integer-key signature.
  std::function<void(int, const NDArray&, NDArray*)> int_key_updater =
      [updater, updater_handle](int key, const NDArray& recv, NDArray* local) {
        NDArray* recv_handle  = new NDArray(recv);
        NDArray* local_handle = new NDArray(*local);
        updater(key, recv_handle, local_handle, updater_handle);
      };
  KVStore* store = static_cast<KVStore*>(handle);
  store->set_updater(int_key_updater);
}

/*!
 * \brief C API entry point that installs an integer-key updater on the KVStore.
 * \param handle KVStore handle.
 * \param updater C callback to invoke for updates.
 * \param updater_handle opaque pointer passed back to the callback.
 * \return the status produced by API_END().
 */
int MXKVStoreSetUpdater(KVStoreHandle handle, MXKVStoreUpdater updater, void* updater_handle) {
  API_BEGIN();
  // All work is delegated to the shared helper.
  MXKVStoreSetUpdaterImpl(handle, updater, updater_handle);
  API_END();
}

/*!
 * \brief Install both an integer-key and a string-key updater on the KVStore.
 *
 * The string-key wrapper heap-allocates a copy of the received array and of the
 * local array on every invocation and passes both pointers to the callback. The
 * wrapper never deletes them.
 * NOTE(review): presumably the callback takes ownership of both handles - confirm.
 *
 * \param handle KVStore handle.
 * \param updater C callback used for integer keys.
 * \param str_updater C callback used for string keys.
 * \param updater_handle opaque pointer passed back to both callbacks unchanged.
 * \return the status produced by API_END().
 */
int MXKVStoreSetUpdaterEx(KVStoreHandle handle,
                          MXKVStoreUpdater updater,
                          MXKVStoreStrUpdater str_updater,
                          void* updater_handle) {
  API_BEGIN();
  // Integer-key updater first, through the shared helper.
  MXKVStoreSetUpdaterImpl(handle, updater, updater_handle);

  // Then the string-key updater; the explicit std::function type selects that signature.
  std::function<void(const std::string&, const NDArray&, NDArray*)> str_key_updater =
      [str_updater, updater_handle](const std::string& key, const NDArray& recv, NDArray* local) {
        NDArray* recv_handle  = new NDArray(recv);
        NDArray* local_handle = new NDArray(*local);
        str_updater(key.c_str(), recv_handle, local_handle, updater_handle);
      };
  KVStore* store = static_cast<KVStore*>(handle);
  store->set_updater(str_key_updater);
  API_END();
}

/*!
 * \brief Write the value of KVStore::get_rank() to *rank.
 * \param handle KVStore handle.
 * \param rank output location.
 * \return the status produced by API_END().
 */
int MXKVStoreGetRank(KVStoreHandle handle, int* rank) {
  API_BEGIN();
  KVStore* store = static_cast<KVStore*>(handle);
  *rank          = store->get_rank();
  API_END();
}

/*!
 * \brief Write the value of KVStore::get_group_size() to *size.
 * \param handle KVStore handle.
 * \param size output location.
 * \return the status produced by API_END().
 */
int MXKVStoreGetGroupSize(KVStoreHandle handle, int* size) {
  API_BEGIN();
  KVStore* store = static_cast<KVStore*>(handle);
  *size          = store->get_group_size();
  API_END();
}

/*!
 * \brief Invoke KVStore::Barrier() on the given store.
 * \param handle KVStore handle.
 * \return the status produced by API_END().
 */
int MXKVStoreBarrier(KVStoreHandle handle) {
  API_BEGIN();
  KVStore* store = static_cast<KVStore*>(handle);
  store->Barrier();
  API_END();
}

/*!
 * \brief Forward the barrier-before-exit setting to the KVStore.
 * \param handle KVStore handle.
 * \param barrier_before_exit value passed to KVStore::set_barrier_before_exit.
 * \return the status produced by API_END().
 */
int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle, const int barrier_before_exit) {
  API_BEGIN();
  KVStore* store = static_cast<KVStore*>(handle);
  store->set_barrier_before_exit(barrier_before_exit);
  API_END();
}

/*!
 * \brief Collect key/value C strings into a map and pass it to KVStore::InitPSEnv.
 * \param num_vars number of entries in keys / vals.
 * \param keys variable names.
 * \param vals variable values; when a name repeats, the later value wins.
 * \return the status produced by API_END().
 */
int MXInitPSEnv(uint32_t num_vars, const char** keys, const char** vals) {
  API_BEGIN();
  std::unordered_map<std::string, std::string> env;
  for (uint32_t idx = 0; idx < num_vars; ++idx) {
    const std::string name(keys[idx]);
    env[name] = std::string(vals[idx]);
  }
  KVStore::InitPSEnv(env);
  API_END();
}

/*!
 * \brief Write the result of KVStore::IsWorkerNode() to *ret.
 * \param ret output location.
 * \return the status produced by API_END().
 */
int MXKVStoreIsWorkerNode(int* ret) {
  API_BEGIN();
  const int is_worker = KVStore::IsWorkerNode();
  *ret                = is_worker;
  API_END();
}

/*!
 * \brief Write the result of KVStore::IsServerNode() to *ret.
 * \param ret output location.
 * \return the status produced by API_END().
 */
int MXKVStoreIsServerNode(int* ret) {
  API_BEGIN();
  const int is_server = KVStore::IsServerNode();
  *ret                = is_server;
  API_END();
}

/*!
 * \brief Write the result of KVStore::IsSchedulerNode() to *ret.
 * \param ret output location.
 * \return the status produced by API_END().
 */
int MXKVStoreIsSchedulerNode(int* ret) {
  API_BEGIN();
  const int is_scheduler = KVStore::IsSchedulerNode();
  *ret                   = is_scheduler;
  API_END();
}

/*!
 * \brief Call KVStore::RunServer with a wrapper around the C controller callback.
 * \param handle KVStore handle.
 * \param controller C callback receiving (head, body, controller_handle).
 * \param controller_handle opaque pointer passed back to the callback unchanged.
 * \return the status produced by API_END().
 */
int MXKVStoreRunServer(KVStoreHandle handle,
                       MXKVStoreServerController controller,
                       void* controller_handle) {
  API_BEGIN();
  // Adapt the std::string body to the C-string the callback expects.
  auto on_command = [controller, controller_handle](int head, const std::string& body) {
    controller(head, body.c_str(), controller_handle);
  };
  KVStore* store = static_cast<KVStore*>(handle);
  store->RunServer(on_command);
  API_END();
}

/*!
 * \brief Forward a command id and body to KVStore::SendCommandToServers.
 * \param handle KVStore handle.
 * \param cmd_id command identifier.
 * \param cmd_body command body as a C string; copied into a std::string.
 * \return the status produced by API_END().
 */
int MXKVStoreSendCommmandToServers(KVStoreHandle handle, int cmd_id, const char* cmd_body) {
  API_BEGIN();
  const std::string body(cmd_body);
  KVStore* store = static_cast<KVStore*>(handle);
  store->SendCommandToServers(cmd_id, body);
  API_END();
}

/*!
 * \brief Write a C-string view of KVStore::type() to *type.
 *
 * NOTE(review): the pointer comes from c_str() on the value returned by type();
 * presumably that string is owned by the KVStore - confirm its lifetime.
 *
 * \param handle KVStore handle.
 * \param type output location; checked with CHECK_NOTNULL.
 * \return the status produced by API_END().
 */
int MXKVStoreGetType(KVStoreHandle handle, const char** type) {
  API_BEGIN();
  KVStore* store       = static_cast<KVStore*>(handle);
  *CHECK_NOTNULL(type) = store->type().c_str();
  API_END();
}

/*!
 * \brief Write the result of KVStore::get_num_dead_node() to *number.
 * \param handle KVStore handle.
 * \param node_id forwarded as the first argument.
 * \param number output location.
 * \param timeout_sec forwarded as the second argument.
 * \return the status produced by API_END().
 */
int MXKVStoreGetNumDeadNode(KVStoreHandle handle,
                            const int node_id,
                            int* number,
                            const int timeout_sec) {
  API_BEGIN();
  KVStore* store = static_cast<KVStore*>(handle);
  *number        = store->get_num_dead_node(node_id, timeout_sec);
  API_END();
}

/*!
 * \brief State behind a RecordIOHandle, shared by the writer and reader APIs.
 *
 * The create functions in this file set every field explicitly; the struct
 * itself provides no default values.
 */
struct MXRecordIOContext {
  // Set by MXRecordIOWriterCreate; nullptr for reader contexts.
  dmlc::RecordIOWriter* writer;
  // Set by MXRecordIOReaderCreate; nullptr for writer contexts.
  dmlc::RecordIOReader* reader;
  // Stream opened by the create function and deleted by the matching free function.
  dmlc::Stream* stream;
  // Buffer filled by MXRecordIOReaderReadRecord; nullptr for writer contexts.
  std::string* read_buff;
};

int MXRecordIOWriterCreate(const char* uri, RecordIOHandle* out) {
  API_BEGIN();
  dmlc::Stream* stream       = dmlc::Stream::Create(uri, "w");
  MXRecordIOContext* context = new MXRecordIOContext;
  context->writer            = new dmlc::RecordIOWriter(stream);
  context->reader            = nullptr;
  context->stream            = stream;
  context->read_buff         = nullptr;
  *out                       = reinterpret_cast<RecordIOHandle>(context);
  API_END();
}

/*!
 * \brief Destroy a writer context: the writer first, then its stream, then the context.
 * \param handle handle produced by MXRecordIOWriterCreate.
 * \return the status produced by API_END().
 */
int MXRecordIOWriterFree(RecordIOHandle handle) {
  API_BEGIN();
  MXRecordIOContext* ctx       = reinterpret_cast<MXRecordIOContext*>(handle);
  dmlc::RecordIOWriter* writer = ctx->writer;
  dmlc::Stream* stream         = ctx->stream;
  delete writer;
  delete stream;
  delete ctx;
  API_END();
}

/*!
 * \brief Pass `size` bytes starting at `buf` to the context's writer as one record.
 * \param handle handle produced by MXRecordIOWriterCreate.
 * \param buf start of the record bytes.
 * \param size number of bytes.
 * \return the status produced by API_END().
 */
int MXRecordIOWriterWriteRecord(RecordIOHandle handle, const char* buf, size_t size) {
  API_BEGIN();
  MXRecordIOContext* ctx = reinterpret_cast<MXRecordIOContext*>(handle);
  const void* payload    = reinterpret_cast<const void*>(buf);
  ctx->writer->WriteRecord(payload, size);
  API_END();
}

/*!
 * \brief Write the value of the writer's Tell() to *pos.
 * \param handle handle produced by MXRecordIOWriterCreate.
 * \param pos output location.
 * \return the status produced by API_END().
 */
int MXRecordIOWriterTell(RecordIOHandle handle, size_t* pos) {
  API_BEGIN();
  MXRecordIOContext* ctx       = reinterpret_cast<MXRecordIOContext*>(handle);
  dmlc::RecordIOWriter* writer = ctx->writer;
  *pos                         = writer->Tell();
  API_END();
}

int MXRecordIOReaderCreate(const char* uri, RecordIOHandle* out) {
  API_BEGIN();
  dmlc::Stream* stream       = dmlc::Stream::Create(uri, "r");
  MXRecordIOContext* context = new MXRecordIOContext;
  context->reader            = new dmlc::RecordIOReader(stream);
  context->writer            = nullptr;
  context->stream            = stream;
  context->read_buff         = new std::string();
  *out                       = reinterpret_cast<RecordIOHandle>(context);
  API_END();
}

/*!
 * \brief Destroy a reader context: reader, stream, read buffer, then the context.
 * \param handle handle produced by MXRecordIOReaderCreate.
 * \return the status produced by API_END().
 */
int MXRecordIOReaderFree(RecordIOHandle handle) {
  API_BEGIN();
  MXRecordIOContext* ctx       = reinterpret_cast<MXRecordIOContext*>(handle);
  dmlc::RecordIOReader* reader = ctx->reader;
  dmlc::Stream* stream         = ctx->stream;
  std::string* buffer          = ctx->read_buff;
  delete reader;
  delete stream;
  delete buffer;
  delete ctx;
  API_END();
}

/*!
 * \brief Read the next record into the context's buffer and expose it to the caller.
 *
 * When NextRecord() reports success, *buf points into the context's internal
 * buffer and *size is its length. Otherwise *buf is nullptr and *size is 0.
 *
 * \param handle handle produced by MXRecordIOReaderCreate.
 * \param buf receives a pointer to the record bytes, or nullptr.
 * \param size receives the record length, or 0.
 * \return the status produced by API_END().
 */
int MXRecordIOReaderReadRecord(RecordIOHandle handle, char const** buf, size_t* size) {
  API_BEGIN();
  MXRecordIOContext* ctx = reinterpret_cast<MXRecordIOContext*>(handle);
  std::string* record    = ctx->read_buff;
  const bool got_record  = ctx->reader->NextRecord(record);
  *buf                   = got_record ? record->c_str() : nullptr;
  *size                  = got_record ? record->size() : 0;
  API_END();
}

/*!
 * \brief Forward `pos` to the reader's Seek().
 * \param handle handle produced by MXRecordIOReaderCreate.
 * \param pos position passed to Seek().
 * \return the status produced by API_END().
 */
int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos) {
  API_BEGIN();
  MXRecordIOContext* ctx       = reinterpret_cast<MXRecordIOContext*>(handle);
  dmlc::RecordIOReader* reader = ctx->reader;
  reader->Seek(pos);
  API_END();
}

/*!
 * \brief Write the value of the reader's Tell() to *pos.
 * \param handle handle produced by MXRecordIOReaderCreate.
 * \param pos output location.
 * \return the status produced by API_END().
 */
int MXRecordIOReaderTell(RecordIOHandle handle, size_t* pos) {
  API_BEGIN();
  MXRecordIOContext* ctx       = reinterpret_cast<MXRecordIOContext*>(handle);
  dmlc::RecordIOReader* reader = ctx->reader;
  *pos                         = reader->Tell();
  API_END();
}

/*!
 * \brief Deprecated entry point; every call fails via LOG(FATAL).
 *
 * All parameters are ignored. The log message directs callers to CudaModule.
 */
int MXRtcCreate(char* name,
                uint32_t num_input,
                uint32_t num_output,
                char** input_names,
                char** output_names,
                NDArrayHandle* inputs,
                NDArrayHandle* outputs,
                char* kernel,
                RtcHandle* out) {
  API_BEGIN();
  LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
  API_END();
}

/*!
 * \brief Deprecated entry point; every call fails via LOG(FATAL).
 *
 * All parameters are ignored. The log message directs callers to CudaModule.
 */
int MXRtcPush(RtcHandle handle,
              uint32_t num_input,
              uint32_t num_output,
              NDArrayHandle* inputs,
              NDArrayHandle* outputs,
              uint32_t gridDimX,
              uint32_t gridDimY,
              uint32_t gridDimZ,
              uint32_t blockDimX,
              uint32_t blockDimY,
              uint32_t blockDimZ) {
  API_BEGIN();
  LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
  API_END();
}

/*!
 * \brief Deprecated entry point; every call fails via LOG(FATAL).
 *
 * The handle is ignored. The log message directs callers to CudaModule.
 */
int MXRtcFree(RtcHandle handle) {
  API_BEGIN();
  LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
  API_END();
}

/*!
 * \brief Register a custom operator creator under `op_type`.
 * \param op_type operator type name passed to CustomOperator::Register.
 * \param creator creator callback passed to CustomOperator::Register.
 * \return the status produced by API_END().
 */
int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator) {
  API_BEGIN();
  auto registry = mxnet::op::custom::CustomOperator::Get();
  registry->Register(op_type, creator);
  API_END();
}

/*!
 * \brief Create an rtc::CudaModule from source, options and export names.
 *
 * Without MXNET_USE_CUDA the call fails via LOG(FATAL).
 *
 * \param source source text passed to the CudaModule constructor.
 * \param num_options number of entries in options.
 * \param options option strings.
 * \param num_exports number of entries in exports.
 * \param exports export name strings.
 * \param out receives the new module handle; release it with MXRtcCudaModuleFree.
 * \return the status produced by API_END().
 */
int MXRtcCudaModuleCreate(const char* source,
                          int num_options,
                          const char** options,
                          int num_exports,
                          const char** exports,
                          CudaModuleHandle* out) {
  API_BEGIN();
#if MXNET_USE_CUDA
  std::vector<std::string> option_list;
  for (int idx = 0; idx < num_options; ++idx) {
    option_list.push_back(std::string(options[idx]));
  }
  std::vector<std::string> export_list;
  for (int idx = 0; idx < num_exports; ++idx) {
    export_list.push_back(std::string(exports[idx]));
  }
  rtc::CudaModule* module = new rtc::CudaModule(source, option_list, export_list);
  *out                    = module;
#else
  LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation.";
#endif
  API_END();
}

/*!
 * \brief Delete the rtc::CudaModule behind the handle.
 *
 * Without MXNET_USE_CUDA the call fails via LOG(FATAL).
 *
 * \param handle module handle to delete.
 * \return the status produced by API_END().
 */
int MXRtcCudaModuleFree(CudaModuleHandle handle) {
  API_BEGIN();
#if MXNET_USE_CUDA
  rtc::CudaModule* module = reinterpret_cast<rtc::CudaModule*>(handle);
  delete module;
#else
  LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation.";
#endif
  API_END();
}

/*!
 * \brief Look up a kernel in a CudaModule and return a heap-allocated shared_ptr to it.
 *
 * Without MXNET_USE_CUDA the call fails via LOG(FATAL).
 *
 * \param handle module handle.
 * \param name kernel name passed to CudaModule::GetKernel.
 * \param num_args number of entries in is_ndarray / is_const / arg_types.
 * \param is_ndarray per-argument flag, converted to bool.
 * \param is_const per-argument flag, converted to bool.
 * \param arg_types per-argument type code, converted to mshadow::TypeFlag.
 * \param out receives the new kernel handle; release it with MXRtcCudaKernelFree.
 * \return the status produced by API_END().
 */
int MXRtcCudaKernelCreate(CudaModuleHandle handle,
                          const char* name,
                          int num_args,
                          int* is_ndarray,
                          int* is_const,
                          int* arg_types,
                          CudaKernelHandle* out) {
  API_BEGIN();
#if MXNET_USE_CUDA
  rtc::CudaModule* module = reinterpret_cast<rtc::CudaModule*>(handle);
  std::vector<rtc::CudaModule::ArgType> arg_spec;
  for (int idx = 0; idx < num_args; ++idx) {
    const bool as_ndarray = static_cast<bool>(is_ndarray[idx]);
    const bool as_const   = static_cast<bool>(is_const[idx]);
    const auto elem_type  = static_cast<mshadow::TypeFlag>(arg_types[idx]);
    arg_spec.push_back(rtc::CudaModule::ArgType{as_ndarray, as_const, elem_type});
  }
  *out = new std::shared_ptr<rtc::CudaModule::Kernel>(module->GetKernel(name, arg_spec));
#else
  LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation.";
#endif
  API_END();
}

/*!
 * \brief Delete the heap-allocated shared_ptr behind a kernel handle.
 *
 * Without MXNET_USE_CUDA the call fails via LOG(FATAL).
 *
 * \param handle kernel handle to delete.
 * \return the status produced by API_END().
 */
int MXRtcCudaKernelFree(CudaKernelHandle handle) {
  API_BEGIN();
#if MXNET_USE_CUDA
  using KernelPtr   = std::shared_ptr<rtc::CudaModule::Kernel>;
  KernelPtr* holder = reinterpret_cast<KernelPtr*>(handle);
  delete holder;
#else
  LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation.";
#endif
  API_END();
}

/*!
 * \brief Launch a runtime-compiled kernel on GPU `dev_id`.
 *
 * Each entry of args is read according to the kernel signature: as an NDArray
 * when the signature marks it as one, otherwise as a scalar of the signature's
 * dtype. Without MXNET_USE_CUDA the call fails via LOG(FATAL).
 *
 * \param handle kernel handle.
 * \param dev_id GPU device id used to build the launch context.
 * \param args one pointer per signature entry.
 * \param grid_dim_x grid dimension forwarded to Kernel::Launch.
 * \param grid_dim_y grid dimension forwarded to Kernel::Launch.
 * \param grid_dim_z grid dimension forwarded to Kernel::Launch.
 * \param block_dim_x block dimension forwarded to Kernel::Launch.
 * \param block_dim_y block dimension forwarded to Kernel::Launch.
 * \param block_dim_z block dimension forwarded to Kernel::Launch.
 * \param shared_mem forwarded to Kernel::Launch as its last argument.
 * \return the status produced by API_END().
 */
int MXRtcCudaKernelCall(CudaKernelHandle handle,
                        int dev_id,
                        void** args,
                        uint32_t grid_dim_x,
                        uint32_t grid_dim_y,
                        uint32_t grid_dim_z,
                        uint32_t block_dim_x,
                        uint32_t block_dim_y,
                        uint32_t block_dim_z,
                        uint32_t shared_mem) {
  API_BEGIN();
#if MXNET_USE_CUDA
  auto holder = reinterpret_cast<std::shared_ptr<rtc::CudaModule::Kernel>*>(handle);
  rtc::CudaModule::Kernel& kernel = **holder;
  const auto& arg_spec            = kernel.signature();
  std::vector<dmlc::any> launch_args;
  for (size_t idx = 0; idx < arg_spec.size(); ++idx) {
    if (!arg_spec[idx].is_ndarray) {
      // Scalars are copied by value using the dtype recorded in the signature.
      MSHADOW_TYPE_SWITCH(arg_spec[idx].dtype, DType, {
        launch_args.emplace_back(*static_cast<DType*>(args[idx]));
      });
    } else {
      launch_args.emplace_back(*static_cast<NDArray*>(args[idx]));
    }
  }
  kernel.Launch(Context::GPU(dev_id),
                launch_args,
                grid_dim_x,
                grid_dim_y,
                grid_dim_z,
                block_dim_x,
                block_dim_y,
                block_dim_z,
                shared_mem);
#else
  LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation.";
#endif
  API_END();
}

/*!
 * \brief Report the shared-memory identifiers of an array's storage.
 *
 * An array already in the kCPUShared context is used directly after waiting for
 * pending writes. Any other array is first copied into a new CPUShared(0)
 * array. In both cases SharedIncrementRefCount is called on the storage handle
 * before the identifiers are returned.
 *
 * \param handle source NDArray handle.
 * \param shared_pid receives shandle.shared_pid.
 * \param shared_id receives shandle.shared_id.
 * \return the status produced by API_END().
 */
int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, int* shared_id) {
  API_BEGIN();
  NDArray* source = reinterpret_cast<NDArray*>(handle);
  Storage::Handle shared_storage;
  if (source->ctx().dev_type != Context::kCPUShared) {
    NDArray shared_copy(source->shape(), Context::CPUShared(0), false, source->dtype());
    CopyFromTo(*source, shared_copy);
    shared_copy.WaitToRead();
    shared_storage = shared_copy.storage_handle();
    // Increment while shared_copy is still in scope.
    Storage::Get()->SharedIncrementRefCount(shared_storage);
  } else {
    source->WaitToRead();
    shared_storage = source->storage_handle();
    Storage::Get()->SharedIncrementRefCount(shared_storage);
  }
  *shared_pid = shared_storage.shared_pid;
  *shared_id  = shared_storage.shared_id;
  API_END();
}

int MXNDArrayCreateFromSharedMem(int shared_pid,
                                 int shared_id,
                                 const int* shape,
                                 int ndim,
                                 int dtype,
                                 NDArrayHandle* out) {
  API_BEGIN();
  NDArray* nd = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype);
  nd->AssignStorageInfo(profiler::ProfilerScope::Get()->GetCurrentProfilerScope(),
                        MXNET_STORAGE_DEFAULT_NAME_CSTR);
  *out = nd;
  API_END();
}

// Short names for the Engine types used by the MXEnginePush* functions below.
using VarHandle          = Engine::VarHandle;
using CallbackOnStart    = Engine::CallbackOnStart;
using CallbackOnComplete = Engine::CallbackOnComplete;

/*!
 * \brief Check that both variable counts are non-negative.
 *
 * The const count is checked first, so it is the one reported when both are
 * negative.
 *
 * \param num_const_vars number of const (read) variables.
 * \param num_mutable_vars number of mutable (write) variables.
 */
void AssertValidNumberVars(int num_const_vars, int num_mutable_vars) {
  CHECK_GE(num_const_vars, 0) << "Non-negative number of const vars expected.";
  CHECK_GE(num_mutable_vars, 0) << "Non-negative number of mutable vars expected.";
}

/*!
 * \brief Push an asynchronous C callback onto the engine.
 *
 * When `deleter` is non-null, `func_param` is wrapped in a shared_ptr with that
 * deleter and captured by the pushed closure, so the deleter runs once the last
 * copy of the closure is destroyed. The wrapping happens before the counts are
 * validated.
 *
 * \param async_func callback invoked with (run context, on_start, on_complete, param).
 * \param func_param opaque parameter passed to async_func.
 * \param deleter optional deleter for func_param; may be nullptr.
 * \param ctx_handle pointer to the Context to run in; dereferenced unconditionally.
 * \param const_vars_handle array of num_const_vars engine variables.
 * \param num_const_vars number of const variables; must be non-negative.
 * \param mutable_vars_handle array of num_mutable_vars engine variables.
 * \param num_mutable_vars number of mutable variables; must be non-negative.
 * \param prop_handle optional pointer to an FnProperty; kNormal is used when null.
 * \param priority priority forwarded to Engine::PushAsync.
 * \param opr_name operator name forwarded to Engine::PushAsync.
 * \param wait forwarded to Engine::PushAsync as its last argument.
 * \return the status produced by API_END().
 */
int MXEnginePushAsync(EngineAsyncFunc async_func,
                      void* func_param,
                      EngineFuncParamDeleter deleter,
                      ContextHandle ctx_handle,
                      EngineVarHandle const_vars_handle,
                      int num_const_vars,
                      EngineVarHandle mutable_vars_handle,
                      int num_mutable_vars,
                      EngineFnPropertyHandle prop_handle,
                      int priority,
                      const char* opr_name,
                      bool wait) {
  API_BEGIN();

  Context run_ctx      = *static_cast<const Context*>(ctx_handle);
  VarHandle* read_vars = static_cast<VarHandle*>(const_vars_handle);
  VarHandle* mut_vars  = static_cast<VarHandle*>(mutable_vars_handle);
  const FnProperty fn_prop =
      prop_handle ? *static_cast<const FnProperty*>(prop_handle) : FnProperty::kNormal;

  Engine::AsyncFn task;
  if (deleter != nullptr) {
    // The closure shares ownership of func_param; the deleter fires when the
    // last copy of the closure goes away.
    std::shared_ptr<void> owned_param(func_param, deleter);
    task = [async_func, owned_param](
               RunContext rctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
      async_func(&rctx, &on_start, &on_complete, owned_param.get());
    };
  } else {
    task = [async_func, func_param](
               RunContext rctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
      async_func(&rctx, &on_start, &on_complete, func_param);
    };
  }

  AssertValidNumberVars(num_const_vars, num_mutable_vars);
  std::vector<VarHandle> read_list(read_vars, read_vars + num_const_vars);
  std::vector<VarHandle> mut_list(mut_vars, mut_vars + num_mutable_vars);
  Engine::Get()->PushAsync(task, run_ctx, read_list, mut_list, fn_prop, priority, opr_name, wait);

  API_END();
}

/*!
 * \brief Push a synchronous C callback onto the engine.
 *
 * When `deleter` is non-null, `func_param` is wrapped in a shared_ptr with that
 * deleter and captured by the pushed closure, so the deleter runs once the last
 * copy of the closure is destroyed. The wrapping happens before the counts are
 * validated.
 *
 * \param sync_func callback invoked with (run context, param).
 * \param func_param opaque parameter passed to sync_func.
 * \param deleter optional deleter for func_param; may be nullptr.
 * \param ctx_handle pointer to the Context to run in; dereferenced unconditionally.
 * \param const_vars_handle array of num_const_vars engine variables.
 * \param num_const_vars number of const variables; must be non-negative.
 * \param mutable_vars_handle array of num_mutable_vars engine variables.
 * \param num_mutable_vars number of mutable variables; must be non-negative.
 * \param prop_handle optional pointer to an FnProperty; kNormal is used when null.
 * \param priority priority forwarded to Engine::PushSync.
 * \param opr_name operator name forwarded to Engine::PushSync.
 * \return the status produced by API_END().
 */
int MXEnginePushSync(EngineSyncFunc sync_func,
                     void* func_param,
                     EngineFuncParamDeleter deleter,
                     ContextHandle ctx_handle,
                     EngineVarHandle const_vars_handle,
                     int num_const_vars,
                     EngineVarHandle mutable_vars_handle,
                     int num_mutable_vars,
                     EngineFnPropertyHandle prop_handle,
                     int priority,
                     const char* opr_name) {
  API_BEGIN();

  Context run_ctx      = *static_cast<const Context*>(ctx_handle);
  VarHandle* read_vars = static_cast<VarHandle*>(const_vars_handle);
  VarHandle* mut_vars  = static_cast<VarHandle*>(mutable_vars_handle);
  const FnProperty fn_prop =
      prop_handle ? *static_cast<const FnProperty*>(prop_handle) : FnProperty::kNormal;

  Engine::SyncFn task;
  if (deleter != nullptr) {
    // The closure shares ownership of func_param; the deleter fires when the
    // last copy of the closure goes away.
    std::shared_ptr<void> owned_param(func_param, deleter);
    task = [sync_func, owned_param](RunContext rctx) { sync_func(&rctx, owned_param.get()); };
  } else {
    task = [sync_func, func_param](RunContext rctx) { sync_func(&rctx, func_param); };
  }

  AssertValidNumberVars(num_const_vars, num_mutable_vars);
  std::vector<VarHandle> read_list(read_vars, read_vars + num_const_vars);
  std::vector<VarHandle> mut_list(mut_vars, mut_vars + num_mutable_vars);
  Engine::Get()->PushSync(task, run_ctx, read_list, mut_list, fn_prop, priority, opr_name);

  API_END();
}

/*!
 * \brief Push an asynchronous C callback, taking dependencies as NDArrays.
 *
 * Collects var() from every listed NDArray and delegates to MXEnginePushAsync.
 * The counts are validated up front because they are used as std::vector sizes
 * here; a negative count would otherwise be converted to a huge unsigned size
 * before the validation inside MXEnginePushAsync runs.
 *
 * \param async_func callback forwarded to MXEnginePushAsync.
 * \param func_param opaque parameter forwarded to MXEnginePushAsync.
 * \param deleter optional deleter forwarded to MXEnginePushAsync.
 * \param ctx_handle context handle forwarded to MXEnginePushAsync.
 * \param const_nds_handle array of num_const_nds NDArray handles that are read.
 * \param num_const_nds number of const arrays; must be non-negative.
 * \param mutable_nds_handle array of num_mutable_nds NDArray handles that are written.
 * \param num_mutable_nds number of mutable arrays; must be non-negative.
 * \param prop_handle property handle forwarded to MXEnginePushAsync.
 * \param priority priority forwarded to MXEnginePushAsync.
 * \param opr_name operator name forwarded to MXEnginePushAsync.
 * \param wait forwarded to MXEnginePushAsync.
 * \return the status of MXEnginePushAsync, or the API error status on exception.
 */
int MXEnginePushAsyncND(EngineAsyncFunc async_func,
                        void* func_param,
                        EngineFuncParamDeleter deleter,
                        ContextHandle ctx_handle,
                        NDArrayHandle* const_nds_handle,
                        int num_const_nds,
                        NDArrayHandle* mutable_nds_handle,
                        int num_mutable_nds,
                        EngineFnPropertyHandle prop_handle,
                        int priority,
                        const char* opr_name,
                        bool wait) {
  API_BEGIN();
  // Reject negative counts before they are used to size the vectors below.
  AssertValidNumberVars(num_const_nds, num_mutable_nds);
  NDArray** const_nds   = reinterpret_cast<NDArray**>(const_nds_handle);
  NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
  std::vector<VarHandle> const_var_vec(num_const_nds);
  for (int i = 0; i < num_const_nds; ++i)
    const_var_vec[i] = const_nds[i]->var();
  std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
  for (int i = 0; i < num_mutable_nds; ++i)
    mutable_var_vec[i] = mutable_nds[i]->var();
  // NOTE(review): this return leaves the function before API_END() is reached
  // on the non-throwing path - confirm that skipping it is acceptable.
  return MXEnginePushAsync(async_func,
                           func_param,
                           deleter,
                           ctx_handle,
                           const_var_vec.data(),
                           num_const_nds,
                           mutable_var_vec.data(),
                           num_mutable_nds,
                           prop_handle,
                           priority,
                           opr_name,
                           wait);
  API_END();
}

/*!
 * \brief Push a synchronous C callback, taking dependencies as NDArrays.
 *
 * Collects var() from every listed NDArray and delegates to MXEnginePushSync.
 * The counts are validated up front because they are used as std::vector sizes
 * here; a negative count would otherwise be converted to a huge unsigned size
 * before the validation inside MXEnginePushSync runs.
 *
 * \param sync_func callback forwarded to MXEnginePushSync.
 * \param func_param opaque parameter forwarded to MXEnginePushSync.
 * \param deleter optional deleter forwarded to MXEnginePushSync.
 * \param ctx_handle context handle forwarded to MXEnginePushSync.
 * \param const_nds_handle array of num_const_nds NDArray handles that are read.
 * \param num_const_nds number of const arrays; must be non-negative.
 * \param mutable_nds_handle array of num_mutable_nds NDArray handles that are written.
 * \param num_mutable_nds number of mutable arrays; must be non-negative.
 * \param prop_handle property handle forwarded to MXEnginePushSync.
 * \param priority priority forwarded to MXEnginePushSync.
 * \param opr_name operator name forwarded to MXEnginePushSync.
 * \return the status of MXEnginePushSync, or the API error status on exception.
 */
int MXEnginePushSyncND(EngineSyncFunc sync_func,
                       void* func_param,
                       EngineFuncParamDeleter deleter,
                       ContextHandle ctx_handle,
                       NDArrayHandle* const_nds_handle,
                       int num_const_nds,
                       NDArrayHandle* mutable_nds_handle,
                       int num_mutable_nds,
                       EngineFnPropertyHandle prop_handle,
                       int priority,
                       const char* opr_name) {
  API_BEGIN();
  // Reject negative counts before they are used to size the vectors below.
  AssertValidNumberVars(num_const_nds, num_mutable_nds);
  NDArray** const_nds   = reinterpret_cast<NDArray**>(const_nds_handle);
  NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
  std::vector<VarHandle> const_var_vec(num_const_nds);
  for (int i = 0; i < num_const_nds; ++i)
    const_var_vec[i] = const_nds[i]->var();
  std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
  for (int i = 0; i < num_mutable_nds; ++i)
    mutable_var_vec[i] = mutable_nds[i]->var();
  // NOTE(review): this return leaves the function before API_END() is reached
  // on the non-throwing path - confirm that skipping it is acceptable.
  return MXEnginePushSync(sync_func,
                          func_param,
                          deleter,
                          ctx_handle,
                          const_var_vec.data(),
                          num_const_nds,
                          mutable_var_vec.data(),
                          num_mutable_nds,
                          prop_handle,
                          priority,
                          opr_name);
  API_END();
}

int MXStorageEmptyCache(int dev_type, int dev_id) {
  API_BEGIN();
  Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
  Storage::Get()->ReleaseAll(ctx);
  API_END();
}

/*!
 * \brief Allocate a new NDArray copy-constructed from the source handle.
 *
 * The new object is deleted by the error handler if an exception is caught.
 *
 * \param src_handle NDArray to copy from.
 * \param out receives the new NDArray handle.
 * \return the status produced by API_END_HANDLE_ERROR().
 */
int MXShallowCopyNDArray(NDArrayHandle src_handle, NDArrayHandle* out) {
  NDArray* duplicate = nullptr;
  API_BEGIN();
  const NDArray& source = *static_cast<NDArray*>(src_handle);
  duplicate             = new NDArray(source);
  *out                  = duplicate;
  API_END_HANDLE_ERROR(delete duplicate);
}

/*!
 * \brief Call NDArray::StreamSync(stream) on the given array.
 * \param handle NDArray handle.
 * \param stream stream value passed to StreamSync.
 * \return the status produced by API_END().
 */
int MXPushStreamDep(NDArrayHandle handle, int stream) {
  API_BEGIN();
  NDArray* array = static_cast<NDArray*>(handle);
  array->StreamSync(stream);
  API_END();
}

/*!
 * \brief Report the stream obtained from a freshly built GPU run context.
 *
 * The temporary mshadow::Stream<gpu> is owned by a std::unique_ptr so it is
 * released when the call returns; previously one object leaked per call.
 * Without MXNET_USE_CUDA the call fails via LOG(FATAL).
 *
 * NOTE(review): the stream value is converted to int64_t and then stored in an
 * int, so it may be truncated - confirm callers only rely on the low bits.
 *
 * \param device_id GPU device id used to build the run context.
 * \param stream receives the stream value.
 * \return the status produced by API_END().
 */
int MXGetCurrentStream(int device_id, int* stream) {
  API_BEGIN();
#if MXNET_USE_CUDA
  // Keep ownership here; RunContext only receives a raw, non-owning pointer.
  std::unique_ptr<mshadow::Stream<gpu>> stream_holder(new mshadow::Stream<gpu>());
  RunContext rctx{Context::GPU(device_id), stream_holder.get(), nullptr};
  mshadow::Stream<gpu>* cur_stream = rctx.get_stream<gpu>();
  *stream = reinterpret_cast<int64_t>(mshadow::Stream<gpu>::GetStream(cur_stream));
#else
  LOG(FATAL) << "GPU is not enabled.";
#endif
  API_END();
}

/*!
 * \brief Start an NVTX range with the given name and color.
 *
 * Requires both MXNET_USE_CUDA and MXNET_USE_NVTX; otherwise the call fails
 * via LOG(FATAL).
 *
 * \param name range name passed to gpuRangeStart.
 * \param color color value passed to gpuRangeStart.
 * \return the status produced by API_END().
 */
int MXNVTXRangePush(const char* name, mx_uint color) {
  API_BEGIN();
#if MXNET_USE_CUDA && MXNET_USE_NVTX
  mxnet::common::cuda::nvtx::gpuRangeStart(color, name);
#else
  LOG(FATAL) << "Compile with USE_CUDA=1 and USE_NVTX=1 to have NVTX support.";
#endif
  API_END();
}

/*!
 * \brief Stop an NVTX range by calling gpuRangeStop().
 *
 * Requires both MXNET_USE_CUDA and MXNET_USE_NVTX; otherwise the call fails
 * via LOG(FATAL).
 *
 * \return the status produced by API_END().
 */
int MXNVTXRangePop() {
  API_BEGIN();
#if MXNET_USE_CUDA && MXNET_USE_NVTX
  mxnet::common::cuda::nvtx::gpuRangeStop();
#else
  LOG(FATAL) << "Compile with USE_CUDA=1 and USE_NVTX=1 to have NVTX support.";
#endif
  API_END();
}

/*!
 * \brief Call cudaProfilerStart().
 *
 * Requires both MXNET_USE_CUDA and MXNET_USE_NVTX; otherwise the call fails
 * via LOG(FATAL). The return value of cudaProfilerStart() is not inspected.
 *
 * \return the status produced by API_END().
 */
int MXCUDAProfilerStart() {
  API_BEGIN();
#if MXNET_USE_CUDA && MXNET_USE_NVTX
  cudaProfilerStart();
#else
  LOG(FATAL) << "Compile with USE_CUDA=1 and USE_NVTX=1 to have CUDA profiler support.";
#endif
  API_END();
}

/*!
 * \brief Call cudaProfilerStop().
 *
 * Requires both MXNET_USE_CUDA and MXNET_USE_NVTX; otherwise the call fails
 * via LOG(FATAL). The return value of cudaProfilerStop() is not inspected.
 *
 * \return the status produced by API_END().
 */
int MXCUDAProfilerStop() {
  API_BEGIN();
#if MXNET_USE_CUDA && MXNET_USE_NVTX
  cudaProfilerStop();
#else
  LOG(FATAL) << "Compile with USE_CUDA=1 and USE_NVTX=1 to have CUDA Profiler support.";
#endif
  API_END();
}

/*!
 * \brief Store `val` in the `optimize` field of the ALM parameters.
 * \param val new value of the flag.
 * \return the status produced by API_END().
 */
int MXSetOptimizeLayout(bool val) {
  API_BEGIN();
  auto& alm_params    = mxnet::alm::ALMParams::get();
  alm_params.optimize = val;
  API_END();
}

/*!
 * \brief Read the `optimize` field of the ALM parameters into *val.
 * \param val output location.
 * \return the status produced by API_END().
 */
int MXGetOptimizeLayout(bool* val) {
  API_BEGIN();
  auto& alm_params = mxnet::alm::ALMParams::get();
  *val             = alm_params.optimize;
  API_END();
}
