/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2016 by Contributors
* \file attach_op_execs_pass.cc
* \brief Operator executors and the pass that attaches them to the graph.
*/
#include <mxnet/base.h>
#include <mxnet/operator.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include "../common/utils.h"
#include "../common/exec_utils.h"
#include "./exec_pass.h"
#include "../operator/nn/mkldnn/mkldnn_base-inl.h"
namespace mxnet {
namespace op {
const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs);
} // namespace op
namespace exec {
// Abstract OpExecutor that provides a storage fallback procedure for
// non-default-storage inputs and outputs.
// FComputeExecutor and StatefulComputeExecutor inherit from this class.
class StorageFallbackOpExecutor : public OpExecutor {
public:
explicit StorageFallbackOpExecutor(const NodeAttrs& attrs,
const DispatchMode& dispatch_mode,
const std::vector<uint32_t> &mutate_idx)
: OpExecutor(attrs, dispatch_mode), mutate_idx_(mutate_idx) {}
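// Reset the lazy-init flag so InitBlobs() rebuilds the temporary buffers on the next run.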
void Setup() override {
init_ = false;
}
protected:
// lazily allocate temporary default-storage NDArrays matching the inputs and outputs
void InitBlobs() {
if (!init_) {
pre_temp_buf_.clear();
post_temp_buf_.clear();
for (const auto& nd : in_array) {
pre_temp_buf_.emplace_back(nd.shape(), nd.ctx(), true, nd.dtype());
}
for (const auto& nd : out_array) {
post_temp_buf_.emplace_back(nd.shape(), nd.ctx(), true, nd.dtype());
}
init_ = true;
}
}
// storage fallback before fcompute is launched
void PreFCompute(bool is_gpu) {
using namespace common;
InitBlobs();
in_data_.clear(); out_data_.clear();
pre_temp_src_.clear(); pre_temp_dst_.clear();
post_temp_src_.clear(); post_temp_dst_.clear();
in_temp_idx_map_.clear();
tmp_req = req;
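// SetupDefaultBlobsInOut collects TBlobs for the default-storage arrays; for
// non-default-storage arrays it substitutes the temporary buffers and records
// the src/dst pairs that CastNonDefaultStorage converts below.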
SetupDefaultBlobsInOut(in_array, out_array, &pre_temp_buf_, &post_temp_buf_, &req,
&in_data_, &out_data_,
&pre_temp_src_, &pre_temp_dst_,
&post_temp_src_, &post_temp_dst_,
&in_temp_idx_map_, mutate_idx_);
common::CastNonDefaultStorage(pre_temp_src_, pre_temp_dst_, op_ctx, is_gpu);
}
// storage fallback after fcompute completes: cast temporary outputs back and restore the saved reqs
void PostFCompute(bool is_gpu) {
common::CastNonDefaultStorage(post_temp_src_, post_temp_dst_, op_ctx, is_gpu);
req = tmp_req;
}
// temporary copy of the output requirements;
// saves the original reqs so PostFCompute() can restore them
std::vector<OpReqType> tmp_req;
// default storage tensor blobs for fcompute
std::vector<TBlob> in_data_, out_data_;
// These are NDArray buffers for cast storage.
std::vector<NDArray> pre_temp_buf_, post_temp_buf_;
// source NDArray for cast storage
std::vector<NDArray> pre_temp_src_, post_temp_src_;
// destination NDArray for cast storage
std::vector<NDArray> pre_temp_dst_, post_temp_dst_;
// mapping from index in input_blobs to index in pre_temp_dst
std::unordered_map<uint32_t, uint32_t> in_temp_idx_map_;
// indices of mutable inputs
std::vector<uint32_t> mutate_idx_;
// whether blobs are initialized
bool init_;
};
// stateful compute executor
class StatefulComputeExecutor : public StorageFallbackOpExecutor {
public:
void Run(RunContext rctx, bool is_gpu) override {
op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
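// Discard stale MKL-DNN-formatted memory held by the output arrays before they are overwritten.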
InvalidateOutputs(out_array, req);
#endif
PreFCompute(is_gpu);
fcompute_(state_, op_ctx, in_data_, req, out_data_);
PostFCompute(is_gpu);
}
ExecType exec_type() const override {
return exec_type_;
}
engine::VarHandle var() const override {
return state_.get_var();
}
OpStatePtr state() const override {
return state_;
}
explicit StatefulComputeExecutor(const NodeAttrs& attrs,
const DispatchMode dispatch_mode,
const OpStatePtr& state,
const FStatefulCompute& fcompute,
ExecType exec_type,
const std::vector<uint32_t> &mutate_idx)
: StorageFallbackOpExecutor(attrs, dispatch_mode, mutate_idx),
state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
private:
OpStatePtr state_;
FStatefulCompute fcompute_;
ExecType exec_type_;
};
// stateful compute_ex executor
class StatefulComputeExExecutor : public OpExecutor {
public:
void Run(RunContext rctx, bool is_gpu) override {
op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(out_array, req);
// TODO(alex): (MXNET-847) Remove this fallback once the subgraph feature is implemented
const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
if (!is_mkldnn.get(attrs.op, false)) {
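// The operator is not MKL-DNN-aware: reorder any MKL-DNN-formatted inputs
// to the default layout before invoking it.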
CreateDefaultInputs(in_array, &in_array_fallback);
fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
return;
}
#endif
fcompute_(state_, op_ctx, in_array, req, out_array);
}
void Setup() override {}
ExecType exec_type() const override {
return exec_type_;
}
engine::VarHandle var() const override {
return state_.get_var();
}
OpStatePtr state() const override {
return state_;
}
explicit StatefulComputeExExecutor(const NodeAttrs& attrs,
const DispatchMode& dispatch_mode,
const OpStatePtr& state,
const FStatefulComputeEx& fcompute,
ExecType exec_type)
: OpExecutor(attrs, dispatch_mode), state_(state), fcompute_(fcompute),
exec_type_(exec_type) {}
private:
OpStatePtr state_;
FStatefulComputeEx fcompute_;
ExecType exec_type_;
};
// fcompute executor
class FComputeExecutor : public StorageFallbackOpExecutor {
public:
void Run(RunContext rctx, bool is_gpu) override {
using namespace common;
op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(out_array, req);
#endif
PreFCompute(is_gpu);
fcompute_(attrs, op_ctx, in_data_, req, out_data_);
PostFCompute(is_gpu);
}
ExecType exec_type() const override {
return exec_type_;
}
explicit FComputeExecutor(const NodeAttrs& attrs, const DispatchMode dispatch_mode,
FCompute fcompute, ExecType exec_type,
const std::vector<uint32_t> &mutate_idx)
: StorageFallbackOpExecutor(attrs, dispatch_mode, mutate_idx),
fcompute_(fcompute), exec_type_(exec_type) {
}
private:
FCompute fcompute_;
ExecType exec_type_;
};
// fcompute_ex executor
class FComputeExExecutor : public OpExecutor {
public:
void Run(RunContext rctx, bool is_gpu) override {
op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(out_array, req);
// TODO(alex): (MXNET-847) Remove this fallback once the subgraph feature is implemented
const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
if (!is_mkldnn.get(attrs.op, false)) {
CreateDefaultInputs(in_array, &in_array_fallback);
fcompute_(attrs, op_ctx, in_array_fallback, req, out_array);
return;
}
#endif
fcompute_(attrs, op_ctx, in_array, req, out_array);
}
void Setup() override {}
ExecType exec_type() const override {
return exec_type_;
}
explicit FComputeExExecutor(const NodeAttrs& attrs, const DispatchMode dispatch_mode,
FComputeEx fcompute, ExecType exec_type)
: OpExecutor(attrs, dispatch_mode), fcompute_(fcompute), exec_type_(exec_type) {
}
private:
FComputeEx fcompute_;
ExecType exec_type_;
};
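// Create the executor for node i of graph g and store it in (*p_ret)[i].
// When p_state is non-null, any op state created here is also recorded in (*p_state)[i].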
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i) {
using nnvm::DTypeVector;
using mxnet::ShapeVector;
using nnvm::FMutateInputs;
static auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
static auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
static auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
const auto& vshape = g.GetAttr<ShapeVector>("shape");
const auto& vctx = g.GetAttr<ContextVector>("context");
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
// get the indexed graph
const auto& idx = g.indexed_graph();
OpExecVector& ret = *p_ret;
// variable nodes need no executor
const auto& inode = idx[i];
if (inode.source->is_variable()) return;
const nnvm::Op *op = inode.source->op();
ExecType exec_type = ExecType::kSync;
std::vector<uint32_t> mutate_index;
if (fmutate_inputs.count(op)) {
mutate_index = fmutate_inputs[op](inode.source->attrs);
}
if (fexec_type.count(op)) {
exec_type = fexec_type[op](inode.source->attrs);
}
CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
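// Stateful operator: create its state, then pick the stateful executor matching the dispatch mode.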
if (fcreate_op_state.count(op)) {
mxnet::ShapeVector ishape;
std::vector<int> itype;
for (const auto& e : inode.inputs) {
ishape.emplace_back(vshape[idx.entry_id(e)]);
itype.emplace_back(vdtype[idx.entry_id(e)]);
}
OpStatePtr state = fcreate_op_state[op](
inode.source->attrs, vctx[i], ishape, itype);
if (p_state) {
CHECK_GT(p_state->size(), i);
p_state->at(i) = state;
}
FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(inode.source->attrs,
dispatch_modes[i], state,
fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
ret[i] = std::make_shared<StatefulComputeExecutor>(inode.source->attrs,
dispatch_modes[i],
state, fcompute,
exec_type, mutate_index);
}
} else if (is_layer_backward.get(op, false)) {
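// The backward node of a stateful layer shares the state created by its
// forward node, found through the first control dependency.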
CHECK_GE(inode.control_deps.size(), 1);
uint32_t fwd_id = inode.control_deps[0];
CHECK(vctx[fwd_id] == vctx[i]);
CHECK(ret[fwd_id] != nullptr);
FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(
inode.source->attrs, dispatch_modes[i], ret[fwd_id]->state(),
fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
ret[i] = std::make_shared<StatefulComputeExecutor>(inode.source->attrs,
dispatch_modes[i], ret[fwd_id]->state(), fcompute, exec_type,
mutate_index);
}
} else {
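// Stateless operator: use FComputeEx when the dispatch mode requests it,
// otherwise fall back to plain FCompute.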
FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<FComputeExExecutor>(
inode.source->attrs, dispatch_modes[i], fcomp_ex, exec_type);
} else if (fcompute != nullptr) {
ret[i] = std::make_shared<FComputeExecutor>(
inode.source->attrs, dispatch_modes[i], fcompute, exec_type, mutate_index);
} else {
LOG(INFO) << "Neither FCompute nor FComputeEx registered for " << op->name;
}
}
}
// pass to attach operator executors
Graph AttachOpExecs(Graph g) {
const auto& idx = g.indexed_graph();
OpExecVector ret(idx.num_nodes());
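// Create one executor per operator node; CreateOpExecs returns early for variable nodes.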
for (size_t i = 0; i < idx.num_nodes(); ++i) {
CreateOpExecs(g, &ret, nullptr, i);
}
g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
return g;
}
} // namespace exec
} // namespace mxnet