blob: fb128fa91053bc71ca49fddabe5058ff7b27d0a8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file symbol.h
* \brief Rcpp Symbolic construction interface of MXNet
*/
#ifndef MXNET_RCPP_SYMBOL_H_
#define MXNET_RCPP_SYMBOL_H_
#include <Rcpp.h>
#include <mxnet/c_api.h>
#include <nnvm/c_api.h>
#include <string>
#include <algorithm>
#include <vector>
namespace mxnet {
namespace R {
// forward declare symbol functiono
class SymbolFunction;
/*! \brief The Rcpp Symbol class of MXNet */
class Symbol {
public:
// typedef RObjectType
typedef Rcpp::RObject RObjectType;
/*! \return typename from R side. */
inline static const char* TypeName() {
return "MXSymbol";
}
/*!
* \brief Apply the symbol as function on kwargs
* \param kwargs keyword arguments to the data
* \return A resulting symbol.
*/
RObjectType Apply(const Rcpp::List& kwargs) const;
/*!
* \brief Print the debug string of symbol
* \return the debug string.
*/
std::string DebugStr() const;
/*! \return the arguments in the symbol */
std::vector<std::string> ListArguments() const;
/*! \return the auxiliary states symbol */
std::vector<std::string> ListAuxiliaryStates() const;
/*! \return the outputs in the symbol */
std::vector<std::string> ListOuputs() const;
/*! \return the attributes of the symbol */
Rcpp::List getAttrs() const;
/*!
* \brief sets the attributes of the symbol
* \param attr list of keyword arguments
*/
void setAttrs(Rcpp::List attr);
/*!
* \brief Save the symbol to file
* \param fname the file name we need to save to
*/
void Save(const std::string& fname) const;
/*!
* \brief save the symbol to json string
* \return a JSON string representation of symbol.
*/
std::string AsJSON() const;
/*!
* \brief Get a new grouped symbol whose output contains all the
* internal outputs of this symbol.
* \return The internal of the symbol.
*/
RObjectType GetInternals() const;
/*!
* \brief Gets a new grouped symbol whose output contains
* inputs to output nodes of the original symbol.
* \return The children of the symbol.
*/
RObjectType GetChildren() const;
/*!
* \brief Get index-th outputs of the symbol.
* \param symbol The symbol
* \param index the Index of the output.
* \param out The output symbol whose outputs are the index-th symbol.
*/
RObjectType GetOutput(mx_uint index) const;
/*! \brief Infer the shapes of arguments, outputs, and auxiliary states */
SEXP InferShape(const Rcpp::List& kwargs) const;
/*!
* \brief Create a symbolic variable with specified name.
*
* \param name string, Name of the variable.
* \return The created variable symbol.
*/
static RObjectType Variable(const std::string& name);
/*!
* \brief Load a symbol variable from filename.
*
* \param filename string, the path to the symbol file.
* \return The loaded corresponding symbol.
*/
static RObjectType Load(const std::string& filename);
/*!
* \brief Load a symbol variable from json string
*
* \param json string, json string of symbol.
* \return The loaded corresponding symbol.
*/
static RObjectType LoadJSON(const std::string& json);
/*!
* \brief Create a symbol that groups symbols together.
*
* \param ... List of symbols to be grouped.
* \return The created grouped symbol.
*/
static RObjectType Group(const Rcpp::List& symbols);
/*! \brief static function to initialize the Rcpp functions */
static void InitRcppModule();
// destructor
~Symbol() {
MX_CALL(MXSymbolFree(handle_));
}
// get external pointer of Symbol
inline static Symbol* XPtr(const Rcpp::RObject& obj) {
return Rcpp::as<Symbol*>(obj);
}
private:
// friend with SymbolFunction
friend class SymbolFunction;
friend class Executor;
// enable trivial copy constructors etc.
Symbol() {}
// constructor
explicit Symbol(SymbolHandle handle)
: handle_(handle) {}
/*!
* \brief create a R object that correspond to the Class
* \param handle the Handle needed for output.
*/
inline static Rcpp::RObject RObject(SymbolHandle handle) {
return Rcpp::internal::make_new_object(new Symbol(handle));
}
/*!
* \brief Return a clone of Symbol
* Do not expose to R side
* \param obj The source to be cloned from
* \return a Cloned Symbol
*/
inline RObjectType Clone() const;
/*!
* \brief Compose the symbol with kwargs
* \param kwargs keyword arguments to the data
* \param name name of the symbol.
*/
void Compose(const Rcpp::List& kwargs, const std::string &name);
/*! \brief internal executor handle */
SymbolHandle handle_;
};
/*! \brief The Symbol functions to be invoked */
class SymbolFunction : public ::Rcpp::CppFunction {
public:
virtual SEXP operator() (SEXP* args);
virtual int nargs() {
return 1;
}
virtual bool is_void() {
return false;
}
virtual void signature(std::string& s, const char* name) { // NOLINT(*)
::Rcpp::signature< SEXP, ::Rcpp::List >(s, name);
}
virtual const char* get_name() {
return name_.c_str();
}
virtual SEXP get_formals() {
return Rcpp::List::create(Rcpp::_["alist"]);
}
virtual DL_FUNC get_function_ptr() {
return (DL_FUNC)NULL; // NOLINT(*)
}
/*! \brief static function to initialize the Rcpp functions */
static void InitRcppModule();
private:
// make constructor private
explicit SymbolFunction(OpHandle handle, std::string name);
/*! \brief internal creator handle. */
OpHandle handle_;
// name of the function
std::string name_;
// hint used to generate the names
std::string name_hint_;
// key to variable size arguments, if any.
std::string key_var_num_args_;
};
} // namespace R
} // namespace mxnet
RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::Symbol);
namespace Rcpp {
template<>
inline bool is<mxnet::R::Symbol>(SEXP x) {
return internal::is__module__object_fix<mxnet::R::Symbol>(x);
}
}
#endif // MXNET_RCPP_SYMBOL_H_