/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
5B * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Compute Op.
* \file compute_op.cc
*/
#include "compute_op.h"
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <string>
#include <unordered_set>
#include <utility>
#include "../../arith/interval_set.h"
#include "../schedule/message_passing.h"
#include "op_util.h"
namespace tvm {
namespace te {
using namespace tir;
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ComputeOpNode*>(node.get());
p->stream << "compute(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode* op);
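// Returns true when two Reduce nodes describe the same multi-output reduction:
// they must share the same combiner, source, axis and condition, and either
// both lack an init or share the same init, so that the only remaining
// difference between them is value_index.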
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
(a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
}
int ComputeOpNode::num_outputs() const { return body.size(); }
Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
if (reduce_axis.size() == 0) return axis;
Array<IterVar> ret = axis;
for (IterVar iv : reduce_axis) {
ret.push_back(iv);
}
return ret;
}
DataType ComputeOpNode::output_dtype(size_t idx) const {
CHECK_LT(idx, num_outputs());
return body[idx].dtype();
}
Array<PrimExpr> BaseComputeOpNode::output_shape(size_t idx) const {
CHECK_LT(idx, num_outputs());
// for now, all outputs of a BaseComputeOp have the same shape
Array<PrimExpr> shape;
for (const auto& ivar : this->axis) {
const Range& r = ivar->dom;
shape.push_back(r->extent);
}
return shape;
}
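// Create a single-output compute op whose data-parallel axes are named
// ax0..ax{ndim-1}, and return its only output tensor.
// Rough usage sketch from C++ (the tensor and variable names are illustrative only):
//   Tensor A = placeholder({m, n}, DataType::Float(32), "A");
//   Tensor B = compute(
//       {m, n}, [&](const Array<Var>& i) { return A(i[0], i[1]) + 1; }, "B");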
Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::string tag,
Map<String, ObjectRef> attrs) {
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
args.push_back(axis.back()->var);
}
return ComputeOp(name, tag, attrs, axis, {fcompute(args)}).output(0);
}
Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string name,
std::string tag, Map<String, ObjectRef> attrs) {
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
args.push_back(axis.back()->var);
}
Operation op = ComputeOp(name, tag, attrs, axis, fcompute(args));
Array<Tensor> outputs;
for (int idx = 0; idx < op->num_outputs(); ++idx) {
outputs.push_back(op.output(idx));
}
return outputs;
}
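// Construct a ComputeOp. If the body is a (possibly multi-output) Reduce, its
// reduction axes are recorded on the op; the body is then verified so that
// Reduce nodes appear only at the top level.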
ComputeOp::ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> axis, Array<PrimExpr> body) {
if (!attrs.defined()) {
attrs = Map<String, ObjectRef>();
}
auto n = make_object<ComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->axis = std::move(axis);
n->body = std::move(body);
if (n->body[0]->IsInstance<tir::ReduceNode>()) {
const tir::ReduceNode* reduce = n->body[0].as<tir::ReduceNode>();
n->reduce_axis = reduce->axis;
}
VerifyComputeOp(n.get());
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("te.ComputeOp")
.set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> axis,
Array<PrimExpr> body) { return ComputeOp(name, tag, attrs, axis, body); });
// The schedule-related logic
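// Collect every distinct tensor read through a ProducerLoad in the body,
// in order of first appearance.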
Array<Tensor> ComputeOpNode::InputTensors() const {
Array<Tensor> ret;
std::unordered_set<Tensor> visited;
for (auto& e : body) {
tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
if (auto* pload = n.as<tir::ProducerLoadNode>()) {
Tensor t = Downcast<Tensor>(pload->producer);
if (!visited.count(t)) {
ret.push_back(t);
visited.insert(t);
}
}
});
}
return ret;
}
Operation ComputeOpNode::ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
VerifyComputeOp(this);
Array<PrimExpr> arr;
if (this->body[0]->IsInstance<tir::ReduceNode>()) {
    // Handle reduce specially so that the replaced op
    // still shares all the components
PrimExpr new_reduce = te::ReplaceTensor(this->body[0], rmap);
if (!new_reduce.same_as(this->body[0])) {
const tir::ReduceNode* r = new_reduce.as<tir::ReduceNode>();
for (size_t k = 0; k < this->body.size(); ++k) {
auto n = make_object<tir::ReduceNode>(*r);
n->value_index = static_cast<int>(k);
n->dtype = r->source[k].dtype();
arr.push_back(PrimExpr(n));
}
} else {
arr = this->body;
}
} else {
arr =
UpdateArray(this->body, [&rmap](const PrimExpr& e) { return te::ReplaceTensor(e, rmap); });
}
if (!arr.same_as(this->body)) {
return ComputeOp(this->name, this->tag, this->attrs, this->axis, arr);
} else {
return self;
}
}
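// Propagate the index ranges required by the body to the domains of the input
// tensors, tightening each dimension against the tensor shape only when the
// shape bound can be proven tighter.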
void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
if (auto* pload = n.as<tir::ProducerLoadNode>()) {
Tensor t = Downcast<Tensor>(pload->producer);
if (t->op.defined() && out_dom_map->count(t)) {
TensorDom& dom = out_dom_map->at(t);
for (size_t i = 0; i < t.ndim(); ++i) {
// We assume that the value of the argument cannot be out of bounds (otherwise it is
// undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection.
IntSet arg_intset = analyzer->int_set(pload->indices[i], ConvertDomMap(dom_map));
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
PrimExpr shape_i_max_value = t->shape[i] - 1;
PrimExpr min_value = arg_interval->min_value;
PrimExpr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter.
            // We must update both ends of the bound in pairs. Here is a counterexample: shape_i is
            // [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is
            // [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0],
            // which is awkward for further analysis.
if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) ||
(analyzer->CanProve(shape_i_min_value >= min_value) &&
analyzer->CanProve(shape_i_max_value <= max_value))) {
min_value = shape_i_min_value;
max_value = shape_i_max_value;
}
dom.data[i].push_back(IntSet::Interval(min_value, max_value));
} else {
dom.data[i].push_back(arg_intset);
}
}
}
}
};
for (auto& e : body) tir::PostOrderVisit(e, fvisit);
}
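// Derive each data-parallel axis range by unioning the regions requested on
// output(0), covered within the original axis domain; reduction axes always
// keep their own domains.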
void BaseComputeOpNode::GatherBound(const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
const TensorDom& tdom = tensor_dom.at(self.output(0));
for (size_t i = 0; i < this->axis.size(); ++i) {
Range r = arith::Union(tdom.data.at(i)).CoverRange(this->axis[i]->dom);
CHECK(!out_dom_map->count(this->axis[i]));
(*out_dom_map)[this->axis[i]] = r;
}
for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
CHECK(!out_dom_map->count(this->reduce_axis[i]));
(*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
}
}
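// Wrap the body in a ProducerRealize for every output, using the bounds from
// realize_map, and attach buffer_dim_align attributes for axes that carry a
// dim_align_factor in the schedule.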
Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
Region bounds;
for (IterVar iv : this->axis) {
bounds.push_back(realize_map.at(iv));
}
Stmt realize = body;
for (int i = this->num_outputs(); i > 0; --i) {
Tensor t = stage->op.output(i - 1);
realize = tir::ProducerRealize(t, bounds, const_true(), realize);
// alignment requirement, only useful for compute
for (size_t i = 0; i < num_schedulable_dims(); ++i) {
auto it = stage->iter_var_attrs.find(this->axis[i]);
if (it != stage->iter_var_attrs.end()) {
IterVarAttr attr = (*it).second;
if (attr->dim_align_factor != 0) {
Array<PrimExpr> tuple = {static_cast<int>(i), attr->dim_align_factor,
attr->dim_align_offset};
realize =
tir::AttrStmt(t, tir::attr::buffer_dim_align,
Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), realize);
}
}
}
}
return realize;
}
size_t ComputeOpNode::num_schedulable_dims() const { return axis.size(); }
// Build a reduction body.
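// *init stores the identity element (or the user-supplied Reduce init) into
// each output; *provide stores the combined update, guarded by the reduce
// condition when that condition is not trivially true.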
void MakeReduction(const ComputeOpNode* op, const Array<Tensor>& tensors, Stmt* init,
Stmt* provide) {
Array<PrimExpr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
std::vector<Stmt> inits, provides;
size_t size = op->body.size();
const ReduceNode* reduce = op->body[0].as<ReduceNode>();
CHECK(reduce);
const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
CHECK(combiner);
Array<PrimExpr> lhs;
for (size_t i = 0; i < size; ++i) {
lhs.push_back(tensors[i](args));
}
Array<PrimExpr> init_value = combiner->identity_element;
Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
// If an init was passed to ReduceNode, use that for initialization
// instead of combiner->identity_element
Array<PrimExpr> reduce_init = reduce->init;
if (!reduce_init.empty()) {
init_value = reduce_init;
}
for (size_t i = 0; i < size; ++i) {
Tensor t = tensors[i];
inits.emplace_back(ProducerStore(t, init_value[i], args));
provides.emplace_back(ProducerStore(t, update_value[i], args));
}
*init = SeqStmt::Flatten(inits);
*provide = SeqStmt::Flatten(provides);
if (!is_one(reduce->condition)) {
*provide = IfThenElse(reduce->condition, *provide);
}
}
// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) {
Array<PrimExpr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
return ProducerStore(t, op->body[t->value_index], args);
}
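// Emit the compute statement for a stage: either an initialization nest plus a
// reduction update nest, or a plain ProducerStore per output, merged into the
// loop nest produced by ComputeLoopNest::Create.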
Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
// grab the nest structure
ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop);
// Normal loop structure
n.init_nest.emplace_back(MakeIfNest(n.init_predicates));
n.main_nest.emplace_back(MakeIfNest(n.main_predicates));
if (self->reduce_axis.size() != 0) {
// make reduction.
Stmt init, provide;
Array<Tensor> source;
for (size_t i = 0; i < self->body.size(); ++i) {
source.push_back(stage->op.output(i));
}
MakeReduction(self, source, &init, &provide);
init = MergeNest(n.init_nest, init);
init = Substitute(init, n.init_vmap);
// common nest
std::vector<std::vector<Stmt> > common(n.main_nest.begin(),
n.main_nest.begin() + n.num_common_loop + 1);
std::vector<std::vector<Stmt> > reduce(n.main_nest.begin() + n.num_common_loop + 1,
n.main_nest.end());
provide = MergeNest(reduce, provide);
if (debug_keep_trivial_loop) {
provide = MergeNest(common, provide);
} else {
provide = MergeNest(common, SeqStmt::Flatten(init, provide));
}
    // run substitution on the full nest, because the loop condition
    // could depend on outer loops.
return Substitute(provide, n.main_vmap);
} else {
std::vector<Stmt> provides;
for (size_t i = 0; i < self->body.size(); ++i) {
provides.emplace_back(MakeProvide(self, stage->op.output(i)));
}
Stmt provide = SeqStmt::Flatten(provides);
provide = MergeNest(n.main_nest, provide);
    // run substitution on the full nest, because the loop condition
    // could depend on outer loops.
return Substitute(provide, n.main_vmap);
}
}
enum class ComputeType { kNormal, kCrossThreadReduction, kTensorize };
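// Classify how a stage must be lowered: any tensorized axis selects
// kTensorize, a reduction axis bound to a thread selects
// kCrossThreadReduction, and everything else falls back to kNormal.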
ComputeType DetectComputeType(const ComputeOpNode* self, const Stage& stage) {
// Verify correctness of leaf nest.
int normal_red = 0, thread_red = 0, tensorize = 0;
for (IterVar iv : stage->leaf_iter_vars) {
IterVarAttr attr;
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end()) {
attr = (*it).second;
}
if (attr.defined() && attr->iter_type == kTensorized) {
++tensorize;
}
if (iv->iter_type == kCommReduce) {
if (attr.defined() && attr->bind_thread.defined()) {
++thread_red;
} else {
++normal_red;
}
} else {
CHECK_EQ(thread_red, 0) << "Cross thread reduce cannot swap with normal data axis";
}
}
if (tensorize != 0) {
CHECK(thread_red == 0) << "Cannot mix cross thread reduction with Tensorize";
return ComputeType::kTensorize;
}
if (thread_red != 0) {
return ComputeType::kCrossThreadReduction;
} else {
return ComputeType::kNormal;
}
}
// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
ComputeType ctype = DetectComputeType(this, stage);
if (ctype == ComputeType::kCrossThreadReduction) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
} else if (ctype == ComputeType::kTensorize) {
return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
} else {
return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
}
}
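// Build the main loop nest and its predicates for a stage; for a reduction,
// also build an init nest starting after the common outer loops (the first
// num_common_loop loops), whose loop variables are taken from the main nest.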
ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
CHECK_EQ(stage->op.operator->(), self);
ComputeLoopNest ret;
// make main loop nest
ret.main_nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set<IterVar>(),
&ret.main_vmap, debug_keep_trivial_loop);
ret.main_predicates =
MakeBoundCheck(stage, dom_map, ret.main_vmap, false, std::unordered_set<IterVar>());
for (auto& e : ret.main_predicates) {
e = likely(e);
}
if (stage->store_predicate.defined()) {
ret.main_predicates.push_back(stage->store_predicate);
}
if (self->reduce_axis.size() != 0) {
// try to find the location to insert the initialization.
// Fuse the initialization and provide loop when possible.
std::unordered_map<IterVar, int> update_state;
for (IterVar iv : self->reduce_axis) {
update_state[iv] = 2;
}
for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
update_state[self->axis[i]] = 1;
}
// find which iter var is related to reduction and which is related to axis.
te::PassDownBitMaskOr(stage, &update_state);
auto leaf_iter_vars = stage->leaf_iter_vars;
    // find the first loop that is related to reduction.
size_t begin_loop = leaf_iter_vars.size();
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
int flag = update_state.at(iv);
if ((flag & 2) != 0) {
begin_loop = i;
break;
}
ret.init_vmap[iv] = ret.main_vmap.at(iv);
}
ret.num_common_loop = begin_loop;
// skip loops that are related to reduction and are unrelated to axis.
std::unordered_set<IterVar> skip_iter;
for (auto kv : update_state) {
int flag = kv.second;
if (flag == 2) skip_iter.insert(kv.first);
}
ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap),
debug_keep_trivial_loop);
ret.init_predicates = MakeBoundCheck(stage, dom_map, ret.init_vmap, true, skip_iter);
for (auto& e : ret.init_predicates) {
e = likely(e);
}
} else {
CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
ret.num_common_loop = stage->leaf_iter_vars.size();
}
  // copy elision here.
return ret;
}
namespace {
/*!
* \brief Verify if ComputeOp is valid with respect to Reduce operations.
*
* The following two properties are verified:
* (1) All Reduce operations must exist at top level.
 * (2) For a list of operations, if one is a Reduce, then the others
 * must be Reduce as well, and they must share the same attributes
 * except value_index.
*/
class ComputeVerifier final : protected tir::ExprVisitor {
public:
/// Special member functions
//@{
explicit ComputeVerifier(const ComputeOpNode* compute)
: compute_(compute), reduce_(compute->body[0].as<tir::ReduceNode>()) {}
virtual ~ComputeVerifier() = default;
ComputeVerifier(const ComputeVerifier&) = delete;
ComputeVerifier(ComputeVerifier&&) = delete;
ComputeVerifier& operator=(const ComputeVerifier&) = delete;
ComputeVerifier& operator=(ComputeVerifier&&) = delete;
//@}
/// Interface to perform compute verification
void Run() {
for (const PrimExpr e : compute_->body) {
// Check for consistency of top level reductions
const tir::ReduceNode* reduce = e.as<tir::ReduceNode>();
      CHECK((reduce && reduce_) || (!reduce && !reduce_))
          << "Either all bodies of a ComputeOp must be Reduce operations, or none of them may be.";
if (reduce && reduce_) {
CHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
level_ = 0;
ExprVisitor::VisitExpr(e);
}
}
protected:
/// Visitor implementation
//@{
void VisitExpr(const PrimExpr& n) final {
++level_;
ExprVisitor::VisitExpr(n);
--level_;
}
void VisitExpr_(const tir::ReduceNode* op) final {
// Check for non top level reductions
CHECK(0 == level_) << "Reductions are only allowed at the top level of compute. "
<< "Please create another tensor for further composition.";
}
//@}
private:
const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify
const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation
int level_{0}; ///< Level of op being processed
};
} // namespace
/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode* op) {
ComputeVerifier v(op);
v.Run();
}
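// Guard the tensorize update: run `update` once any reduction loop variable
// has advanced past its minimum, otherwise run `body`. Predicates that use a
// reduction variable are rejected because they would conflict with this reset
// condition.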
Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
const ComputeLoopNest& n, Stmt body, Stmt update) {
Array<PrimExpr> conds;
std::unordered_set<const VarNode*> banned;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto iit = stage->iter_var_attrs.find(iv);
if (iit != stage->iter_var_attrs.end()) {
const IterVarAttr& attr = (*iit).second;
if (attr->iter_type == kTensorized) {
break;
}
}
if (iv->iter_type == kCommReduce) {
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range& vrange = vit->second;
conds.push_back(likely(iv->var > vrange->min));
banned.insert(iv->var.get());
}
}
auto fbanned = [&](const VarNode* node) { return banned.count(node); };
for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition " << pred
<< " has a conflict with the reset condition";
}
}
auto cond = foldl([](PrimExpr a, PrimExpr b) { return a || b; }, const_false(1), conds);
return IfThenElse(cond, update, body);
}
} // namespace te
} // namespace tvm