blob: 3037031ad991389e57c33095867c0329480b37eb [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file caffe_blob.h
* \brief Conversion between MXNet tensors (TBlob) and Caffe blobs
* \author Haoran Wang
*/
#ifndef PLUGIN_CAFFE_CAFFE_BLOB_H_
#define PLUGIN_CAFFE_CAFFE_BLOB_H_
#include <mxnet/tensor_blob.h>
#include <cstddef>
#include <vector>
#include <caffe/blob.hpp>
#include <caffe/layer.hpp>
namespace mxnet {
namespace op {
namespace caffe {
/*!
 * \brief Selects which memory of a Caffe blob a TBlob is bound to
 *        (consumed by TBlob2CaffeBlob / SetDataGradToBlob).
 */
enum caffeMemoryTypes {
  Data = 0,  // the blob's data memory
  Grad = 1,  // presumably the blob's gradient (diff) memory -- confirm in SetDataGradToBlob
  Non = 2    // presumably "none": no memory is bound -- confirm in SetDataGradToBlob
};
/*! \brief Build an MXNet TShape from a Caffe-style std::vector<int> shape. */
TShape Vector2TShape(const std::vector<int> &dims);
/*! \brief Convert an MXNet TShape into the std::vector<int> shape form passed to ::caffe::Blob::Reshape. */
std::vector<int> TShape2Vector(const TShape &shape);
// implementation of tensor to blob, called by TensorToBlob
template<typename Device, typename Dtype>
void SetDataGradToBlob(caffeMemoryTypes memType,
typename std::vector< ::caffe::Blob<Dtype>*>::iterator blob,
typename std::vector<TBlob>::const_iterator itr);
/*!
 * \brief Convert MXNet tensors into Caffe blobs (used by caffe_operator_inl.h).
 *
 * Walks n consecutive (blob, tblob) pairs in lockstep. Each Caffe blob is first
 * reshaped to its TBlob's shape, then the memory selected by memType is bound
 * to it through SetDataGradToBlob.
 *
 * \param memType which blob memory to bind (see caffeMemoryTypes)
 * \param blob    iterator to the first destination Caffe blob
 * \param tblob   iterator to the first source TBlob
 * \param n       number of pairs to convert; n <= 0 converts nothing
 */
template<typename Device, typename Dtype>
void TBlob2CaffeBlob(caffeMemoryTypes memType,
                     typename std::vector< ::caffe::Blob<Dtype>*>::iterator blob,
                     typename std::vector<TBlob>::const_iterator tblob,
                     int n = 1) {
  for (int left = n; left > 0; --left) {
    const std::vector<int> caffe_shape = TShape2Vector(tblob->shape_);
    (*blob)->Reshape(caffe_shape);
    SetDataGradToBlob<Device, Dtype>(memType, blob, tblob);
    ++blob;
    ++tblob;
  }
}
/*!
 * \brief Install externally created weight blobs into a Caffe layer.
 *
 * Slot i of caffeOp->blobs() is reset to weights[i]. Because reset() is
 * called with the raw pointer, the layer's blob holder takes it over.
 * NOTE(review): this presumably means the layer now owns and deletes these
 * blobs (blobs() holding shared_ptr) -- confirm callers never free them too.
 *
 * \param caffeOp layer whose blob slots are overwritten
 * \param weights one blob per existing layer slot; the counts must match
 */
template<typename Dtype>
void SetOpBlobs(::caffe::Layer<Dtype> *caffeOp,
                const std::vector< ::caffe::Blob<Dtype>*>& weights) {
  auto& op_blobs = caffeOp->blobs();
  CHECK_EQ(op_blobs.size(), weights.size());
  // size_t index: comparing a signed int against size() triggers -Wsign-compare.
  for (std::size_t i = 0; i < weights.size(); ++i) {
    op_blobs[i].reset(weights[i]);
  }
}
/*!
 * \brief Workaround for setters missing from ::caffe::Blob: lets the diff
 *        (gradient) memory of a blob be pointed at a caller-supplied buffer.
 * \warning Instances are obtained by static_cast from a ::caffe::Blob<Dtype>*
 *          (see MXCAFFEBLOB), so do not add or override any virtual functions
 *          (or data members) in this class.
 * \tparam Dtype element type of the blob
 */
template<class Dtype>
class CaffeBlobFriend : public ::caffe::Blob<Dtype> {
 public:
  /*! \brief Forward a non-null CPU pointer to diff_->set_cpu_data(). */
  void set_cpu_diff(Dtype* diff) {
    CHECK(diff);
    auto& diff_mem = this->diff_;
    diff_mem->set_cpu_data(diff);
  }

  /*! \brief Forward a non-null GPU pointer to diff_->set_gpu_data(). */
  void set_gpu_diff(Dtype* diff) {
    CHECK(diff);
    auto& diff_mem = this->diff_;
    diff_mem->set_gpu_data(diff);
  }
};
// View a ::caffe::Blob<dtype>* as a CaffeBlobFriend<dtype>* to reach its extra setters.
// Parameter names avoid double-underscore (reserved) identifiers and the non-standard '$'.
#define MXCAFFEBLOB(blob_ptr, dtype) \
  (static_cast<mxnet::op::caffe::CaffeBlobFriend<dtype> *>(blob_ptr))
/*!
 * \brief Workaround for functions missing from ::caffe::Layer: exposes a
 *        setter for the layer's protected phase_ member.
 * \warning Instances are obtained by static_cast from a ::caffe::Layer<Dtype>*
 *          (see MXCAFFELAYER), so do not add or override any virtual functions
 *          in this class.
 * \tparam Dtype element type of the layer
 */
template <typename Dtype>
class CaffeLayerFriend : public ::caffe::Layer<Dtype> {
 public:
  // Construction is disabled; this type is only used through a cast pointer.
  explicit CaffeLayerFriend(const ::caffe::LayerParameter& param) = delete;

  /*! \brief Overwrite the layer's phase_ with the given phase. */
  void SetPhase(::caffe::Phase phase) {
    this->phase_ = phase;
  }
};
// View a ::caffe::Layer<dtype>* as a CaffeLayerFriend<dtype>* to reach SetPhase().
// Parameter names avoid double-underscore (reserved) identifiers and the non-standard '$'.
#define MXCAFFELAYER(layer_ptr, dtype) \
  (static_cast<mxnet::op::caffe::CaffeLayerFriend<dtype> *>(layer_ptr))
} // namespace caffe
} // namespace op
} // namespace mxnet
#endif // PLUGIN_CAFFE_CAFFE_BLOB_H_