blob: 174ebf2dd9dddbe43e11be1bf6bbf48839fe6e33 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file torch_module-inl.h
* \brief torch module operator
* \author Min Lin
*/
#ifndef PLUGIN_TORCH_TORCH_CRITERION_INL_H_
#define PLUGIN_TORCH_TORCH_CRITERION_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <stdio.h>
#include <cstring>
#include <map>
#include <string>
#include <vector>
#include <utility>
#include "../../src/operator/operator_common.h"
#include "./torch_base.h"
namespace mxnet {
namespace op {
struct TorchCriterionParam : public dmlc::Parameter<TorchCriterionParam> {
std::string lua_string;
TShape label_shape;
float grad_scale;
DMLC_DECLARE_PARAMETER(TorchCriterionParam) {
DMLC_DECLARE_FIELD(lua_string)
.describe("lua string that is called to generate the torch criterion object");
DMLC_DECLARE_FIELD(label_shape)
.set_default(TShape())
.enforce_nonzero()
.describe("Shape of label (without batch size).");
DMLC_DECLARE_FIELD(grad_scale)
.set_default(1.0f)
.describe("Scale the gradient by a float factor (a.k.a weight of this loss).");
}
};
/**
* \brief This is the implementation of activation operator.
* \tparam xpu The device that the op will be executed on.
*/
template<typename xpu>
class TorchCriterionOp : public Operator {
private:
TorchCriterionParam param_;
TorchState* torchState_;
int lua_reference_;
public:
explicit TorchCriterionOp(TorchCriterionParam p) {
this->param_ = p;
this->torchState_ = new TorchState();
lua_State *L = torchState_->L;
CHECK_EQ(lua_gettop(L), 0);
std::string exec = std::string("return ") + p.lua_string
+ TorchTensor::ModuleType(xpu::kDevMask);
CHECK_EQ(luaL_loadstring(L, exec.c_str()), 0);
int err = lua_pcall(L, 0, 1, 0);
CHECK_EQ(err, 0) << lua_tostring(L, -1);
// serialize
this->lua_reference_ = lua_ref(L, LUA_REGISTRYINDEX);
}
~TorchCriterionOp() {
delete this->torchState_;
}
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
lua_State *L = torchState_->L;
CHECK_EQ(lua_gettop(L), 0);
CHECK_EQ(in_data.size(), 2);
CHECK_EQ(out_data.size(), 1);
Stream<xpu> *s = ctx.get_stream<xpu>();
torchState_->SetStream(s);
lua_rawgeti(L, LUA_REGISTRYINDEX, lua_reference_);
// call forward
// | self
lua_getfield(L, -1, "forward");
// | self | forward
lua_pushvalue(L, -2);
// | self | forward | self
for (index_t i = 0; i < in_data.size(); ++i) {
THGeneralTensor th = TorchTensor::TBlobToTHTensor(torchState_, in_data[i]);
luaT_pushudata(L, th, TorchTensor::TensorType(in_data[i]));
}
// | self | forward | self | pred | label
int err = lua_pcall(L, 3, 1, 0);
CHECK_EQ(err, 0) << lua_tostring(L, -1);
CHECK(lua_isnumber(L, -1)) << "Criterion must return a number";
real_t loss = static_cast<real_t>(lua_tonumber(L, -1));
lua_pop(L, 1);
Tensor<xpu, 2> out = out_data[0].FlatTo2D<xpu, real_t>(s);
Assign(out, req[0], loss*param_.grad_scale);
lua_pop(L, 1);
CHECK_EQ(lua_gettop(L), 0);
}
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
lua_State *L = torchState_->L;
CHECK_EQ(lua_gettop(L), 0);
CHECK_EQ(in_data.size(), 2);
CHECK_EQ(out_data.size(), 1);
CHECK_EQ(req[0], kWriteTo) << "Torch Criterion only supports write to in_grad";
CHECK_EQ(req[1], kNullOp) << "Torch Criterion cannot back prop to label";
Stream<xpu> *s = ctx.get_stream<xpu>();
torchState_->SetStream(s);
lua_rawgeti(L, LUA_REGISTRYINDEX, lua_reference_);
THGeneralTensor th = TorchTensor::TBlobToTHTensor(torchState_, in_grad[0]);
luaT_pushudata(L, th, TorchTensor::TensorType(in_grad[0]));
lua_setfield(L, -2, "gradInput");
lua_getfield(L, -1, "backward");
// | self | backward
lua_pushvalue(L, -2);
// | self | backward | self
for (index_t i = 0; i < in_data.size(); ++i) {
th = TorchTensor::TBlobToTHTensor(torchState_, in_data[i]);
luaT_pushudata(L, th, TorchTensor::TensorType(in_data[i]));
}
// | self | forward | self | pred | label
int err = lua_pcall(L, 3, 0, 0);
CHECK_EQ(err, 0) << lua_tostring(L, -1);
Tensor<xpu, 2> grad = in_grad[0].FlatTo2D<xpu, real_t>(s);
grad *= param_.grad_scale * in_grad[0].shape_[0];
lua_pop(L, 1);
CHECK_EQ(lua_gettop(L), 0);
}
}; // class TorchCriterionOp
// Decalre Factory function, used for dispatch specialization
template<typename xpu>
Operator* CreateOp(TorchCriterionParam type);
#if DMLC_USE_CXX11
class TorchCriterionProp : public OperatorProperty {
public:
std::vector<std::string> ListArguments() const override {
return {"data", "label"};
}
virtual std::vector<std::string> ListOutputs() const {
return {"output"};
}
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
}
std::map<std::string, std::string> GetParams() const override {
return param_.__DICT__();
}
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 2);
const TShape &dshape = in_shape->at(0);
if (dshape.ndim() == 0) return false;
std::vector<index_t> lshape;
lshape.push_back(dshape[0]);
lshape.insert(lshape.end(), param_.label_shape.data(),
param_.label_shape.data() + param_.label_shape.ndim());
TShape shape(lshape.begin(), lshape.end());
SHAPE_ASSIGN_CHECK(*in_shape, 1, shape);
out_shape->clear();
out_shape->push_back(Shape1(dshape[0]));
return true;
}
OperatorProperty* Copy() const override {
auto ptr = new TorchCriterionProp();
ptr->param_ = param_;
return ptr;
}
std::string TypeString() const override {
return "TorchCriterion";
}
// decalre dependency and inplace optimization options
std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const override {
std::vector<int> dep;
dep.insert(dep.end(), in_data.begin(), in_data.end());
// Ensure that the backward and forward cannot be called at the same time
dep.insert(dep.end(), out_data.begin(), out_data.end());
return dep;
}
Operator* CreateOperator(Context ctx) const override;
private:
TorchCriterionParam param_;
};
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet
#endif // PLUGIN_TORCH_TORCH_CRITERION_INL_H_