blob: 9227819fcceaa735170920f703a48e9f1a132c55 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef MXNET_OPERATOR_NPX_CONTROL_FLOW_H_
#define MXNET_OPERATOR_NPX_CONTROL_FLOW_H_
#include <mxnet/io.h>
#include <mxnet/base.h>
#include <dmlc/optional.h>
#include <functional>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
namespace mxnet {
namespace op {
// Parameter set registered through dmlc::Parameter for the foreach
// control-flow operator (operator identity inferred from the struct name).
struct NPXForeachParam : public dmlc::Parameter<NPXForeachParam> {
  // Number of inputs; registered with a lower bound of 1.
  int num_args;
  // The number of outputs of the subgraph.
  int num_outputs;
  // The number of output data of the subgraph.
  int num_out_data;
  // The location of states in the subgraph inputs.
  mxnet::Tuple<dim_t> in_state_locs;
  // The location of data arrays in the subgraph inputs.
  mxnet::Tuple<dim_t> in_data_locs;
  // The location of remaining arrays in the subgraph inputs.
  mxnet::Tuple<dim_t> remain_locs;
  // The index mapping from out_states to in_states.
  mxnet::Tuple<dim_t> in_state_index;

  DMLC_DECLARE_PARAMETER(NPXForeachParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1).describe("Number of inputs.");
    DMLC_DECLARE_FIELD(num_outputs).describe("The number of outputs of the subgraph.");
    DMLC_DECLARE_FIELD(num_out_data).describe("The number of output data of the subgraph.");
    DMLC_DECLARE_FIELD(in_state_locs).describe("The locations of loop states among the inputs.");
    DMLC_DECLARE_FIELD(in_data_locs).describe("The locations of input data among the inputs.");
    DMLC_DECLARE_FIELD(remain_locs).describe("The locations of remaining data among the inputs.");
    DMLC_DECLARE_FIELD(in_state_index).describe("The index mapping from out_states to in_states.");
  }

  // Stores the textual form of every field in `*dict`, keyed by the field
  // name used in DMLC_DECLARE_FIELD above. Entries already present under
  // those keys are overwritten; other entries are left untouched.
  // `dict` must not be null (it is dereferenced unconditionally).
  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
    std::ostringstream num_args_s, num_outputs_s, num_out_data_s, in_state_locs_s, in_data_locs_s,
        remain_locs_s, in_state_index_s;
    num_args_s << num_args;
    num_outputs_s << num_outputs;
    num_out_data_s << num_out_data;
    in_state_locs_s << in_state_locs;
    in_data_locs_s << in_data_locs;
    remain_locs_s << remain_locs;
    in_state_index_s << in_state_index;
    // The formatted values were previously discarded; publish them so the
    // caller actually receives the attributes.
    (*dict)["num_args"]       = num_args_s.str();
    (*dict)["num_outputs"]    = num_outputs_s.str();
    (*dict)["num_out_data"]   = num_out_data_s.str();
    (*dict)["in_state_locs"]  = in_state_locs_s.str();
    (*dict)["in_data_locs"]   = in_data_locs_s.str();
    (*dict)["remain_locs"]    = remain_locs_s.str();
    (*dict)["in_state_index"] = in_state_index_s.str();
  }
};  // struct NPXForeachParam
// Parameter set registered through dmlc::Parameter for the while_loop
// control-flow operator (operator identity inferred from the struct name
// and the field descriptions).
struct NPXWhileLoopParam : public dmlc::Parameter<NPXWhileLoopParam> {
  // Number of input arguments, including cond and func as two symbol inputs
  // (lower bound 2).
  int num_args;
  // The number of outputs of the subgraph (lower bound 1).
  int num_outputs;
  // The number of outputs from the function body (lower bound 0).
  int num_out_data;
  // Maximum number of iterations (lower bound 1).
  int max_iterations;
  // `cond` and `func` each take a subset of while_loop's inputs as the inputs
  // of their own subgraphs:
  //   `cond_input_locs` holds the indices of the inputs fed to `cond`,
  //   `func_input_locs` holds the indices of the inputs fed to `func`,
  //   `func_var_locs` holds the positions, within func's inputs, at which
  //   the loop "variables" are stored.
  mxnet::Tuple<dim_t> cond_input_locs;
  mxnet::Tuple<dim_t> func_input_locs;
  mxnet::Tuple<dim_t> func_var_locs;

  DMLC_DECLARE_PARAMETER(NPXWhileLoopParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(2).describe(
        "Number of input arguments, including cond and func as two symbol inputs.");
    DMLC_DECLARE_FIELD(num_outputs)
        .set_lower_bound(1)
        .describe("The number of outputs of the subgraph.");
    DMLC_DECLARE_FIELD(num_out_data)
        .set_lower_bound(0)
        .describe("The number of outputs from the function body.");
    DMLC_DECLARE_FIELD(max_iterations).set_lower_bound(1).describe("Maximum number of iterations.");
    DMLC_DECLARE_FIELD(cond_input_locs)
        .describe("The locations of cond's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(func_input_locs)
        .describe("The locations of func's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(func_var_locs).describe("The locations of loop_vars among func's inputs.");
  }

  // Stores the textual form of every field in `*dict`, keyed by the field
  // name used in DMLC_DECLARE_FIELD above. Entries already present under
  // those keys are overwritten; other entries are left untouched.
  // `dict` must not be null (it is dereferenced unconditionally).
  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
    std::ostringstream num_args_s, num_outputs_s, num_out_data_s, max_iterations_s,
        cond_input_locs_s, func_input_locs_s, func_var_locs_s;
    num_args_s << num_args;
    num_outputs_s << num_outputs;
    num_out_data_s << num_out_data;
    max_iterations_s << max_iterations;
    cond_input_locs_s << cond_input_locs;
    func_input_locs_s << func_input_locs;
    func_var_locs_s << func_var_locs;
    // The formatted values were previously discarded; publish them so the
    // caller actually receives the attributes.
    (*dict)["num_args"]        = num_args_s.str();
    (*dict)["num_outputs"]     = num_outputs_s.str();
    (*dict)["num_out_data"]    = num_out_data_s.str();
    (*dict)["max_iterations"]  = max_iterations_s.str();
    (*dict)["cond_input_locs"] = cond_input_locs_s.str();
    (*dict)["func_input_locs"] = func_input_locs_s.str();
    (*dict)["func_var_locs"]   = func_var_locs_s.str();
  }

  // Pairs every loop-variable output with the operator input it corresponds
  // to and hands both to fill_value().
  //
  // For each output index i in [num_out_data, num_outputs) the matching input
  // is in[func_input_locs[func_var_locs[i - num_out_data]]]. Both vectors are
  // accessed with at(), so an out-of-range index throws std::out_of_range;
  // the Tuple lookups themselves use operator[].
  //
  // \param in       operator inputs; elements may be modified via fill_value.
  // \param out      operator outputs; elements may be modified via fill_value.
  // \param is_empty predicate evaluated on both elements of each pair.
  // \return always true.
  //
  // NOTE(review): fill_value is declared outside this header; presumably it
  // copies the known value into whichever side is empty -- confirm there.
  template <typename T>
  bool sync_in_out(std::vector<T>* in,
                   std::vector<T>* out,
                   std::function<bool(const T&)> is_empty) const {
    for (int i = this->num_out_data; i < this->num_outputs; ++i) {
      // each out->at(i) is a params, loop_var
      T& x = in->at(this->func_input_locs[this->func_var_locs[i - this->num_out_data]]);
      T& y = out->at(i);
      fill_value(&x, &y, is_empty(x), is_empty(y));
    }
    return true;
  }
};  // struct NPXWhileLoopParam
// Parameter set registered through dmlc::Parameter for the cond
// control-flow operator (operator identity inferred from the struct name
// and the field descriptions).
struct NPXCondParam : public dmlc::Parameter<NPXCondParam> {
  // Number of input arguments, including cond, then and else as three symbol
  // inputs (lower bound 3).
  int num_args;
  // The number of outputs of the subgraph (lower bound 1).
  int num_outputs;
  // The locations of cond's inputs in the given inputs.
  mxnet::Tuple<dim_t> cond_input_locs;
  // The locations of then's inputs in the given inputs.
  mxnet::Tuple<dim_t> then_input_locs;
  // The locations of else's inputs in the given inputs.
  mxnet::Tuple<dim_t> else_input_locs;

  DMLC_DECLARE_PARAMETER(NPXCondParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(3).describe(
        "Number of input arguments, including cond, then and else as three symbol inputs.");
    DMLC_DECLARE_FIELD(num_outputs)
        .set_lower_bound(1)
        .describe("The number of outputs of the subgraph.");
    DMLC_DECLARE_FIELD(cond_input_locs)
        .describe("The locations of cond's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(then_input_locs)
        .describe("The locations of then's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(else_input_locs)
        .describe("The locations of else's inputs in the given inputs.");
  }

  // Stores the textual form of every field in `*dict`, keyed by the field
  // name used in DMLC_DECLARE_FIELD above. Entries already present under
  // those keys are overwritten; other entries are left untouched.
  // `dict` must not be null (it is dereferenced unconditionally).
  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
    std::ostringstream num_args_s, num_outputs_s, cond_input_locs_s, then_input_locs_s,
        else_input_locs_s;
    num_args_s << num_args;
    num_outputs_s << num_outputs;
    cond_input_locs_s << cond_input_locs;
    then_input_locs_s << then_input_locs;
    else_input_locs_s << else_input_locs;
    // The formatted values were previously discarded; publish them so the
    // caller actually receives the attributes.
    (*dict)["num_args"]        = num_args_s.str();
    (*dict)["num_outputs"]     = num_outputs_s.str();
    (*dict)["cond_input_locs"] = cond_input_locs_s.str();
    (*dict)["then_input_locs"] = then_input_locs_s.str();
    (*dict)["else_input_locs"] = else_input_locs_s.str();
  }
};  // struct NPXCondParam
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NPX_CONTROL_FLOW_H_