blob: 4ad786a2ce93d2b6e3bd8d21ba6e79a060d9fdc8 [file] [log] [blame]
/*******************************************************************************
* Copyright 2016 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file mkl_util-inl.h
* \brief
* \author lingyan.guo@intel.com
* zhenlin.luo@intel.com
*
*******************************************************************************/
#ifndef MXNET_OPERATOR_MKL_MKL_UTIL_INL_H_
#define MXNET_OPERATOR_MKL_MKL_UTIL_INL_H_
#include <vector>
#define MKLDNN_CALL(func) \
{ \
dnnError_t status = (func); \
CHECK_EQ(status, E_SUCCESS) << "MKL DNN call failed (status: " << status << ")."; \
}
namespace mxnet {
namespace op {
#if MKL_EXPERIMENTAL == 1
template<typename DType>
inline DType * mkl_prv_data(const TBlob &b) {
std::shared_ptr<MKLMemHolder> bottom_data_mem = b.Mkl_mem_;
bool mem_valid = (bottom_data_mem != nullptr) && bottom_data_mem->head_at_prv();
if (mem_valid) {
return reinterpret_cast<DType*>(bottom_data_mem->prv_data());
}
return NULL;
}
template<typename DType>
inline int mkl_prv_count(const TBlob &b) {
std::shared_ptr<MKLMemHolder> bottom_data_mem = b.Mkl_mem_;
bool mem_valid = (bottom_data_mem != nullptr) && bottom_data_mem->head_at_prv();
if (mem_valid) {
return bottom_data_mem->prv_count();
}
return 0;
}
#endif
inline void mkl_set_priv_flag(const TBlob &b) {
#if MKL_EXPERIMENTAL == 1
std::shared_ptr<MKLMemHolder> bottom_data_mem = b.Mkl_mem_;
bool mem_valid = (bottom_data_mem != nullptr) && bottom_data_mem->head_at_prv();
if (mem_valid) {
bottom_data_mem->disable_prv_2_cpu(true);
}
#endif
}
#if MKL_EXPERIMENTAL == 1
template<typename DType>
inline std::shared_ptr<MKLData<DType> > mkl_get_mem_desc(
const std::shared_ptr<MKLMemHolder> data_mem) {
std::shared_ptr<PrvMemDescr> prv_descriptor =
data_mem->get_prv_descriptor();
CHECK_EQ(prv_descriptor->get_descr_type(),
PrvMemDescr::PRV_DESCR_MKL2017);
std::shared_ptr<MKLData<DType> > mem_descr
= std::static_pointer_cast<MKLData<DType>>
(prv_descriptor);
CHECK(mem_descr != NULL);
return mem_descr;
}
#endif
template<typename xpu, int dim, typename DType>
inline mshadow::Tensor<xpu, dim, DType> mkl_experimental_direct_get(
const TBlob &b, mshadow::Stream<xpu> *s) {
mkl_set_priv_flag(b);
return b.get<xpu, dim, DType>(s);
}
template<typename xpu, int dim, typename DType>
inline mshadow::Tensor<xpu, dim, DType> mkl_experimental_direct_get_with_shape(
const TBlob &b, const mshadow::Shape<dim> &shape, mshadow::Stream<xpu> *s) {
mkl_set_priv_flag(b);
return b.get_with_shape<xpu, dim, DType>(shape, s);
}
} // namespace op
#if MKL_EXPERIMENTAL == 1
inline void mkl_tblobs_prv_to_cpu(const std::vector<TBlob> &data) {
for (size_t i = 0; i < data.size(); i++) {
std::shared_ptr<MKLMemHolder> mem_holder = data[i].Mkl_mem_;
if (mem_holder != nullptr && mem_holder->b_eager_mode) {
mem_holder->check_and_prv_to_cpu(data[i].dptr_);
}
}
}
inline void mkl_set_tblob_eager_mode(const TBlob &data) {
std::shared_ptr<MKLMemHolder> mem_holder = data.Mkl_mem_;
if (mem_holder != nullptr) {
mem_holder->set_eager_mode(true);
}
}
#endif
} // namespace mxnet
#endif // MXNET_OPERATOR_MKL_MKL_UTIL_INL_H_