blob: b47ab56f68d23bb3432f65db9d7993e8cfe1cda5 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file torch_base.cc
* \brief Registration of Torch tensor functions (`_th_*`) and their output-shape rules with MXNet
* \author Junyuan Xie
*/
#include "./torch_function.h"
namespace mxnet {
// Torch functions exposed to MXNet. Each macro pairs an MXNet-visible name
// (`_th_<name>`) with the Torch function it forwards to. The macros are
// presumably defined in the included torch_function.h; the CONSTRUCTOR /
// UNARY / BINARY / TENARY variants appear to differ in how many tensor
// operands they take -- confirm against the macro definitions.
// Construction or extraction functions
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_eye, eye);
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_ones, ones);
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_rand, rand);
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_randn, randn);
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_randperm, randperm);
// tril / triu extract from an input rather than construct, so they use the
// unary macro instead of the constructor one.
MXNET_REGISTER_TORCH_UNARY_FUN(_th_tril, tril);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_triu, triu);
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_zeros, zeros);
// Element-wise Mathematical Operations
MXNET_REGISTER_TORCH_UNARY_FUN(_th_abs, abs);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_sign, sign);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_acos, acos);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_asin, asin);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_atan, atan);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_ceil, ceil);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_cos, cos);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_cosh, cosh);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_exp, exp);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_floor, floor);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_log, log);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_log1p, log1p);
// pow is the only element-wise op in this group that declares an extra
// scalar argument ("n").
MXNET_REGISTER_TORCH_UNARY_FUN(_th_pow, pow)
.add_argument("n", "float", "pow(x, n) returns x^n, element-wise. "
"pow(n, x) returns n^x, element-wise.");
MXNET_REGISTER_TORCH_UNARY_FUN(_th_round, round);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_sin, sin);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_sinh, sinh);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_sqrt, sqrt);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_tan, tan);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_tanh, tanh);
// Basic operations
// The "_scalar" registrations take one tensor plus a declared float "value".
MXNET_REGISTER_TORCH_UNARY_FUN(_th_add_scalar, add)
.add_argument("value", "float", "Add value to all elements in x");
// NOTE(review): _th_add and _th_add_axpy both forward to Torch `add` through
// different macro variants; confirm which one is meant to cover the
// add(x, value, y) form.
MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(_th_add, add);
MXNET_REGISTER_TORCH_BINARY_FUN(_th_add_axpy, add);
// csub registrations are currently disabled.
// MXNET_REGISTER_TORCH_UNARY_FUN(_th_csub_scalar, csub);
// MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(_th_csub, csub);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_mul_scalar, mul)
.add_argument("value", "float", "Multiply value to all elements in x");
MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(_th_cmul, cmul);
// NOTE(review): unlike the other scalar-taking unary ops above, clamp
// declares no arguments here, although Torch's clamp needs min/max bounds --
// confirm how (or whether) they reach the call.
MXNET_REGISTER_TORCH_UNARY_FUN(_th_clamp, clamp);
MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(_th_cpow, cpow);
// "TENARY" is the macro's own spelling; presumably a three-tensor-operand
// form (e.g. addcmul(x, t1, t2)) -- confirm against the macro definition.
MXNET_REGISTER_TORCH_TENARY_FUN(_th_addcmul, addcmul);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_div_scalar, div)
.add_argument("value", "float", "Divide all elements in x by value");
MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(_th_cdiv, cdiv);
MXNET_REGISTER_TORCH_TENARY_FUN(_th_addcdiv, addcdiv);
// Matrix/vector accumulate-products (Torch add* BLAS-style family).
MXNET_REGISTER_TORCH_TENARY_FUN(_th_addmv, addmv);
MXNET_REGISTER_TORCH_TENARY_FUN(_th_addr, addr);
MXNET_REGISTER_TORCH_TENARY_FUN(_th_addmm, addmm);
MXNET_REGISTER_TORCH_TENARY_FUN(_th_addbmm, addbmm);
MXNET_REGISTER_TORCH_TENARY_FUN(_th_baddbmm, baddbmm);
struct TorchMMShape {
static std::vector<mshadow::TShape> GetShape(NDArray **u,
const std::map<std::string, std::string>& param) {
CHECK_EQ(u[0]->shape().ndim(), 2);
CHECK_EQ(u[1]->shape().ndim(), 2);
CHECK_EQ(u[0]->shape()[1], u[1]->shape()[0]);
index_t shape[] = {u[0]->shape()[0], u[1]->shape()[1]};
mshadow::TShape tshape(shape, shape+2);
return {tshape};
}
static constexpr const char* fname = "mm";
static const int num_inputs = 2;
static const int num_outputs = 1;
};
MXNET_REGISTER_TORCH_FUN(_th_mm, TorchMMShape);
struct TorchMVShape {
static std::vector<mshadow::TShape> GetShape(NDArray **u,
const std::map<std::string, std::string>& param) {
CHECK_EQ(u[0]->shape().ndim(), 2);
CHECK_EQ(u[1]->shape().ndim(), 1);
CHECK_EQ(u[0]->shape()[1], u[1]->shape()[0]);
index_t shape[] = {u[0]->shape()[0]};
mshadow::TShape tshape(shape, shape+1);
return {tshape};
}
static constexpr const char* fname = "mv";
static const int num_inputs = 2;
static const int num_outputs = 1;
};
MXNET_REGISTER_TORCH_FUN(_th_mv, TorchMVShape);
/*!
 * \brief Output-shape rule for Torch `bmm` (batched matrix-matrix product).
 *
 * Both inputs must be 3-D, share the same batch size, and have agreeing
 * inner dimensions: a (b x n x k) input times a (b x k x m) input produces
 * a single (b x n x m) output. The batch dimension is carried through to
 * the output, matching Torch's bmm result; a 2-D (n x m) shape would not
 * describe the batched result.
 */
struct TorchBMMShape {
  static std::vector<mshadow::TShape> GetShape(NDArray **u,
    const std::map<std::string, std::string>& param) {
    // `param` is not read by this rule.
    CHECK_EQ(u[0]->shape().ndim(), 3) << "bmm expects a 3-D first input";
    CHECK_EQ(u[1]->shape().ndim(), 3) << "bmm expects a 3-D second input";
    CHECK_EQ(u[0]->shape()[0], u[1]->shape()[0])
      << "bmm inputs must have the same batch size";
    CHECK_EQ(u[0]->shape()[2], u[1]->shape()[1])
      << "bmm inner dimensions must match";
    // (batch, rows of lhs, columns of rhs) -- keep the batch axis.
    index_t shape[] = {u[0]->shape()[0], u[0]->shape()[1], u[1]->shape()[2]};
    mshadow::TShape tshape(shape, shape + 3);
    return {tshape};
  }
  // Torch function name this rule belongs to.
  static constexpr const char* fname = "bmm";
  static const int num_inputs = 2;
  static const int num_outputs = 1;
};
MXNET_REGISTER_TORCH_FUN(_th_bmm, TorchBMMShape);
struct TorchGERShape {
static std::vector<mshadow::TShape> GetShape(NDArray **u,
const std::map<std::string, std::string>& param) {
CHECK_EQ(u[0]->shape().ndim(), 1);
CHECK_EQ(u[1]->shape().ndim(), 1);
index_t shape[] = {u[0]->shape()[0], u[1]->shape()[0]};
mshadow::TShape tshape(shape, shape+2);
return {tshape};
}
static constexpr const char* fname = "ger";
static const int num_inputs = 2;
static const int num_outputs = 1;
};
MXNET_REGISTER_TORCH_FUN(_th_ger, TorchGERShape);
} // namespace mxnet