blob: 68817a8849f760c1dd484031475993a1be2423e0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file torch_criterion-inl.h
* \brief torch criterion operator
* \author Min Lin
*/
#ifndef PLUGIN_TORCH_TORCH_CRITERION_INL_H_
#define PLUGIN_TORCH_TORCH_CRITERION_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <stdio.h>
#include <cstring>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "../../src/operator/operator_common.h"
#include "./torch_base.h"
namespace mxnet {
namespace op {
/*!
 * \brief Parameters of the TorchCriterion operator.
 *
 * Field declaration order below determines the order in which DMLC registers
 * (and documents / serializes) the parameters, so keep it stable.
 */
struct TorchCriterionParam : public dmlc::Parameter<TorchCriterionParam> {
// Lua expression evaluated by TorchCriterionOp (prefixed with "return ") to
// construct the criterion object.
std::string lua_string;
// Per-sample label shape; shape inference prepends the batch size taken from
// the data input. NOTE(review): whether a default-constructed mxnet::TShape
// has ndim 0 or -1 depends on the MXNet version — consumers should guard both.
mxnet::TShape label_shape;
// Multiplier applied to both the loss written in Forward and the gradient
// written in Backward.
float grad_scale;
DMLC_DECLARE_PARAMETER(TorchCriterionParam) {
DMLC_DECLARE_FIELD(lua_string)
.describe("lua string that is called to generate the torch criterion object");
DMLC_DECLARE_FIELD(label_shape)
.set_default(mxnet::TShape())
.enforce_nonzero()
.describe("Shape of label (without batch size).");
DMLC_DECLARE_FIELD(grad_scale)
.set_default(1.0f)
.describe("Scale the gradient by a float factor (a.k.a weight of this loss).");
}
};
/**
* \brief This is the implementation of activation operator.
* \tparam xpu The device that the op will be executed on.
*/
template <typename xpu>
class TorchCriterionOp : public Operator {
private:
TorchCriterionParam param_;
TorchState* torchState_;
int lua_reference_;
public:
explicit TorchCriterionOp(TorchCriterionParam p) {
this->param_ = p;
this->torchState_ = new TorchState();
lua_State* L = torchState_->L;
CHECK_EQ(lua_gettop(L), 0);
std::string exec =
std::string("return ") + p.lua_string + TorchTensor::ModuleType(xpu::kDevMask);
CHECK_EQ(luaL_loadstring(L, exec.c_str()), 0);
int err = lua_pcall(L, 0, 1, 0);
CHECK_EQ(err, 0) << lua_tostring(L, -1);
// serialize
this->lua_reference_ = lua_ref(L, LUA_REGISTRYINDEX);
}
~TorchCriterionOp() {
delete this->torchState_;
}
virtual void Forward(const OpContext& ctx,
const std::vector<TBlob>& in_data,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& out_data,
const std::vector<TBlob>& aux_args) {
using namespace mshadow;
lua_State* L = torchState_->L;
CHECK_EQ(lua_gettop(L), 0);
CHECK_EQ(in_data.size(), 2);
CHECK_EQ(out_data.size(), 1);
Stream<xpu>* s = ctx.get_stream<xpu>();
torchState_->SetStream(s);
lua_rawgeti(L, LUA_REGISTRYINDEX, lua_reference_);
// call forward
// | self
lua_getfield(L, -1, "forward");
// | self | forward
lua_pushvalue(L, -2);
// | self | forward | self
for (index_t i = 0; i < in_data.size(); ++i) {
THGeneralTensor th = TorchTensor::TBlobToTHTensor(torchState_, in_data[i]);
luaT_pushudata(L, th, TorchTensor::TensorType(in_data[i]));
}
// | self | forward | self | pred | label
int err = lua_pcall(L, 3, 1, 0);
CHECK_EQ(err, 0) << lua_tostring(L, -1);
CHECK(lua_isnumber(L, -1)) << "Criterion must return a number";
real_t loss = static_cast<real_t>(lua_tonumber(L, -1));
lua_pop(L, 1);
Tensor<xpu, 2> out = out_data[0].FlatTo2D<xpu, real_t>(s);
Assign(out, req[0], loss * param_.grad_scale);
lua_pop(L, 1);
CHECK_EQ(lua_gettop(L), 0);
}
virtual void Backward(const OpContext& ctx,
const std::vector<TBlob>& out_grad,
const std::vector<TBlob>& in_data,
const std::vector<TBlob>& out_data,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& in_grad,
const std::vector<TBlob>& aux_args) {
using namespace mshadow;
lua_State* L = torchState_->L;
CHECK_EQ(lua_gettop(L), 0);
CHECK_EQ(in_data.size(), 2);
CHECK_EQ(out_data.size(), 1);
CHECK_EQ(req[0], kWriteTo) << "Torch Criterion only supports write to in_grad";
CHECK_EQ(req[1], kNullOp) << "Torch Criterion cannot back prop to label";
Stream<xpu>* s = ctx.get_stream<xpu>();
torchState_->SetStream(s);
lua_rawgeti(L, LUA_REGISTRYINDEX, lua_reference_);
THGeneralTensor th = TorchTensor::TBlobToTHTensor(torchState_, in_grad[0]);
luaT_pushudata(L, th, TorchTensor::TensorType(in_grad[0]));
lua_setfield(L, -2, "gradInput");
lua_getfield(L, -1, "backward");
// | self | backward
lua_pushvalue(L, -2);
// | self | backward | self
for (index_t i = 0; i < in_data.size(); ++i) {
th = TorchTensor::TBlobToTHTensor(torchState_, in_data[i]);
luaT_pushudata(L, th, TorchTensor::TensorType(in_data[i]));
}
// | self | forward | self | pred | label
int err = lua_pcall(L, 3, 0, 0);
CHECK_EQ(err, 0) << lua_tostring(L, -1);
Tensor<xpu, 2> grad = in_grad[0].FlatTo2D<xpu, real_t>(s);
grad *= param_.grad_scale * in_grad[0].shape_[0];
lua_pop(L, 1);
CHECK_EQ(lua_gettop(L), 0);
}
}; // class TorchCriterionOp
/*!
 * \brief Factory for the TorchCriterion operator. Declared generically here
 *        and specialized per device (xpu) elsewhere for dispatch.
 * \param param Parsed criterion parameters.
 * \return The created operator. NOTE(review): ownership presumably passes to
 *         the caller — confirm against the definition.
 */
template <typename xpu>
Operator* CreateOp(TorchCriterionParam param);
#if DMLC_USE_CXX11
class TorchCriterionProp : public OperatorProperty {
public:
std::vector<std::string> ListArguments() const override {
return {"data", "label"};
}
virtual std::vector<std::string> ListOutputs() const {
return {"output"};
}
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
}
std::map<std::string, std::string> GetParams() const override {
return param_.__DICT__();
}
bool InferShape(mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape,
mxnet::ShapeVector* aux_shape) const override {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 2);
const mxnet::TShape& dshape = in_shape->at(0);
if (dshape.ndim() == 0)
return false;
std::vector<index_t> lshape;
lshape.push_back(dshape[0]);
lshape.insert(lshape.end(),
param_.label_shape.data(),
param_.label_shape.data() + param_.label_shape.ndim());
mxnet::TShape shape(lshape.begin(), lshape.end());
SHAPE_ASSIGN_CHECK(*in_shape, 1, shape);
out_shape->clear();
out_shape->push_back(Shape1(dshape[0]));
return true;
}
OperatorProperty* Copy() const override {
auto ptr = new TorchCriterionProp();
ptr->param_ = param_;
return ptr;
}
std::string TypeString() const override {
return "TorchCriterion";
}
// decalre dependency and inplace optimization options
std::vector<int> DeclareBackwardDependency(const std::vector<int>& out_grad,
const std::vector<int>& in_data,
const std::vector<int>& out_data) const override {
std::vector<int> dep;
dep.insert(dep.end(), in_data.begin(), in_data.end());
// Ensure that the backward and forward cannot be called at the same time
dep.insert(dep.end(), out_data.begin(), out_data.end());
return dep;
}
Operator* CreateOperator(Context ctx) const override;
private:
TorchCriterionParam param_;
};
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet
#endif // PLUGIN_TORCH_TORCH_CRITERION_INL_H_