| /*! |
| * Copyright (c) 2017 by Contributors |
| * \brief Logic related to tensorize, used by ComputeOpNode. |
| * \file tensorize.cc |
| */ |
| #include <tvm/ir.h> |
| #include <tvm/ir_mutator.h> |
| #include <tvm/ir_pass.h> |
| #include <tvm/api_registry.h> |
| #include "op_util.h" |
| #include "compute_op.h" |
| #include "../schedule/message_passing.h" |
| |
| namespace tvm { |
| |
| using namespace ir; |
| using namespace op; |
| |
| // Detect the region of inputs and output to be tensorized. |
| // out_dom: the domain of root iter vars in the output op. |
| // in_region: the region of each input tensor. |
| // Returns the location where the tensorized scope starts. |
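| // |
| // Illustrative example (schedule and intrinsic names are hypothetical): |
| //   xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], 8, 8) |
| //   s[C].tensorize(xi, gemm_intrin) |
| // The tensorized scope starts at xi, so the returned location is xi's |
| // position in leaf_iter_vars, out_dom covers the 8x8 tile of C written by |
| // one intrinsic call, and in_region holds the slice of each input it reads. |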
| size_t InferTensorizeRegion( |
| const ComputeOpNode* self, |
| const Stage& stage, |
| const std::unordered_map<IterVar, Range>& dom_map, |
| std::unordered_map<IterVar, Range>* out_dom, |
| std::unordered_map<Tensor, Array<Range> >* in_region) { |
| // Get the bound of the tensorized scope. |
| bool found_point = false; |
| size_t loc_scope = 0; |
| std::unordered_map<IterVar, IntSet> up_state; |
| // Loop over the leaf iteration variables, from innermost to outermost. |
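| // Classify each leaf relative to the tensorize point: extent-1 loops |
| // collapse to their min; leaves at or inside the point keep their full |
| // range so PassUpDomain can derive the region covered by one intrinsic |
| // call; leaves outside the point are pinned to their own loop variable |
| // and remain ordinary loops around the intrinsic. |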
| for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) { |
| IterVar iv = stage->leaf_iter_vars[i - 1]; |
| CHECK(iv->iter_type == kDataPar || |
| iv->iter_type == kCommReduce); |
| auto vit = dom_map.find(iv); |
| CHECK(vit != dom_map.end()); |
| const Range& vrange = vit->second; |
| if (is_one(vrange->extent)) { |
| up_state[iv] = IntSet::single_point(vrange->min); |
| } else if (found_point) { |
| CHECK(is_zero(vrange->min)); |
| up_state[iv] = IntSet::single_point(iv->var); |
| } else { |
| up_state[iv] = IntSet::range(vrange); |
| } |
| auto iit = stage->iter_var_attrs.find(iv); |
| if (iit != stage->iter_var_attrs.end()) { |
| const IterVarAttr& attr = (*iit).second; |
| if (!found_point) { |
| CHECK(!attr->bind_thread.defined()) |
| << "Do not allow thread in tensorize scope"; |
| } |
| if (attr->iter_type == kTensorized) { |
| CHECK(!found_point) << "Do not allow more than one tensorized point"; |
| found_point = true; |
| loc_scope = i - 1; |
| } |
| } |
| } |
| CHECK(found_point); |
| // Get domain of the tensorized scope. |
| schedule::PassUpDomain(stage, dom_map, &up_state); |
| // Get domains of inputs |
| std::unordered_map<Tensor, TensorDom> in_dom; |
| std::unordered_map<const Variable*, IntSet> temp_dmap; |
| Array<Tensor> inputs = self->InputTensors(); |
| for (Tensor t : inputs) { |
| in_dom.emplace(t, TensorDom(t.ndim())); |
| } |
| for (IterVar iv : self->root_iter_vars()) { |
| IntSet iset = up_state.at(iv); |
| (*out_dom)[iv] = iset.cover_range(dom_map.at(iv)); |
| temp_dmap[iv->var.get()] = iset; |
| } |
| // Input domains |
| self->PropBoundToInputs(stage->op, temp_dmap, &in_dom); |
| Range none; |
| for (const auto& kv : in_dom) { |
| Array<Range> vec; |
| const Tensor& t = kv.first; |
| for (size_t i = 0; i < t.ndim(); ++i) { |
| Range r = arith::Union(kv.second.data.at(i)).cover_range(none); |
| CHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t; |
| vec.push_back(std::move(r)); |
| } |
| (*in_region)[t] = std::move(vec); |
| } |
| return loc_scope; |
| } |
| |
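| // Check that the generated loop nest is compatible with tensorization: |
| // variables defined inside the tensorized scope (loop vars, bound thread |
| // IterVars, let bindings) must not appear in the boundary predicates |
| // produced by splits, because those predicates are evaluated outside the |
| // intrinsic call. |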
| void VerifyTensorizeLoopNest(const ComputeOpNode* self, |
| const Stage& stage, |
| const ComputeLoopNest& n, |
| size_t tloc) { |
| // Verification step. |
| std::unordered_set<const Variable*> banned; |
| CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1); |
| CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || |
| n.init_nest.size() == 0); |
| auto f_push_banned = [&banned](const Stmt& s) { |
| if (const For* op = s.as<For>()) { |
| banned.insert(op->loop_var.get()); |
| } else if (const AttrStmt* op = s.as<AttrStmt>()) { |
| if (const IterVarNode* iv = op->node.as<IterVarNode>()) { |
| banned.insert(iv->var.get()); |
| } |
| } else if (const LetStmt* op = s.as<LetStmt>()) { |
| banned.insert(op->var.get()); |
| } |
| }; |
| for (size_t i = tloc; i < stage->leaf_iter_vars.size(); ++i) { |
| for (const Stmt& s : n.main_nest[i + 1]) { |
| f_push_banned(s); |
| } |
| if (n.init_nest.size() != 0) { |
| for (const Stmt& s : n.init_nest[i + 1]) { |
| f_push_banned(s); |
| } |
| } |
| } |
| for (const Expr& pred : n.main_predicates) { |
| if (ir::ExprUseVar(pred, banned)) { |
| LOG(FATAL) << "Tensorize failed, split condition " |
| << pred << " relies on var defined inside tensorize scope"; |
| } |
| } |
| for (const Expr& pred : n.init_predicates) { |
| if (ir::ExprUseVar(pred, banned)) { |
| LOG(FATAL) << "Tensorize failed, split condition " |
| << pred << " relies on var defined inside tensorize scope"; |
| } |
| } |
| } |
| |
| |
| // Remap tensor reads, loop indices, and iteration axes of the stage onto |
| // the tensor intrinsic's declaration, so the body can be matched against it. |
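| // Roughly (names illustrative): a read A(i, j) whose inferred input region |
| // is [[i0, i0 + 16), [j0, j0 + 16)] becomes a read of the intrinsic's input |
| // placeholder at (i - i0, j - j0), and the stage's compute/reduce axes are |
| // substituted by the intrinsic's axes, so that the rewritten body can be |
| // compared structurally with the intrinsic's declared compute. |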
| class TensorIntrinMatcher final : public IRMutator { |
| public: |
| Expr Mutate_(const Call* op, const Expr& e) final { |
| Expr expr = IRMutator::Mutate_(op, e); |
| op = expr.as<Call>(); |
| if (op->call_type == Call::Halide) { |
| Tensor t = Operation(op->func.node_).output(op->value_index); |
| auto it = in_remap_.find(t); |
| if (it != in_remap_.end()) { |
| const InputEntry& e = it->second; |
| CHECK_EQ(op->args.size(), e.region.size()); |
| Array<Expr> args; |
| for (size_t i = e.start; i < e.region.size(); ++i) { |
| args.push_back(op->args[i] - e.region[i]->min); |
| } |
| return Call::make( |
| op->type, e.tensor->op->name, args, |
| op->call_type, e.tensor->op, e.tensor->value_index); |
| } |
| } |
| return expr; |
| } |
| |
| Expr Mutate_(const Variable* op, const Expr& e) final { |
| auto it = var_remap_.find(op); |
| if (it != var_remap_.end()) { |
| return it->second; |
| } else { |
| return e; |
| } |
| } |
| |
| Expr Mutate_(const Reduce* op, const Expr& e) final { |
| Expr expr = IRMutator::Mutate_(op, e); |
| op = expr.as<Reduce>(); |
| Array<IterVar> axis; |
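| // Keep only reduction axes that have a counterpart in the intrinsic; the |
| // leading fuzzily-matched axes were already substituted via var_remap_ and |
| // are dropped from the rebuilt Reduce node. |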
| for (size_t i = 0; i < op->axis.size(); ++i) { |
| auto it = axis_remap_.find(op->axis[i]); |
| if (it != axis_remap_.end()) { |
| axis.push_back(it->second); |
| } |
| } |
| return Reduce::make( |
| op->combiner, op->source, axis, op->condition, op->value_index); |
| } |
| |
| void Init(const ComputeOpNode* self, |
| const Stage& stage, |
| const std::unordered_map<IterVar, Range>& out_dom, |
| const std::unordered_map<Tensor, Array<Range> >& in_region, |
| const TensorIntrin& intrin, |
| Map<Var, Range>* compute_intrin_iter_space) { |
| CHECK(self == stage->op.get()); |
| // input remap. |
| Array<Tensor> inputs = self->InputTensors(); |
| CHECK_EQ(inputs.size(), intrin->inputs.size()); |
| for (size_t i = 0; i < inputs.size(); ++i) { |
| InputEntry e; |
| e.tensor = intrin->inputs[i]; |
| e.region = Array<Range>(in_region.at(inputs[i])); |
| CHECK_GE(e.region.size(), e.tensor.ndim()); |
| // Enable fuzzy matching, to match [1, n, m] to [n, m] |
| e.start = e.region.size() - e.tensor.ndim(); |
| for (size_t j = 0; j < e.start; ++j) { |
| CHECK(is_one(e.region[j]->extent)) |
| << "Tensorize " << intrin->name << ":" |
| << " Input dimension mismatch with tensor intrin " |
| << " expected shape=" << e.tensor->shape |
| << ", given region=" << e.region; |
| } |
| in_remap_[inputs[i]] = e; |
| } |
| // output remap |
| const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>(); |
| CHECK(intrin_compute) << "Only support compute intrinsic for now"; |
| CHECK_GE(self->axis.size(), intrin_compute->axis.size()) |
| << "Tensorize: Output mismatch with tensor intrin "; |
| // Enable fuzzy matching, to match [1, n, m] to [n, m] |
| size_t axis_start = self->axis.size() - intrin_compute->axis.size(); |
| for (size_t i = 0; i < axis_start; ++i) { |
| Range r = out_dom.at(self->axis[i]); |
| CHECK(is_one(r->extent)) |
| << "Tensorize: Output mismatch with tensor intrin " |
| << " intrin-dim=" << intrin_compute->axis.size() |
| << ", tensorize-dim=" << self->axis.size(); |
| var_remap_[self->axis[i]->var.get()] = r->min; |
| } |
| // Assume we tensorize at region axis i [min, min + extent) |
| // The corresponding intrinsic axis is j [0, extent) |
| // Remap index i to j + min |
| for (size_t i = axis_start; i < self->axis.size(); ++i) { |
| IterVar iv = self->axis[i]; |
| IterVar target_iv = intrin_compute->axis[i - axis_start]; |
| Range r = out_dom.at(iv); |
| var_remap_[iv->var.get()] = target_iv->var + r->min; |
| axis_remap_[iv] = target_iv; |
| compute_intrin_iter_space->Set(target_iv->var, target_iv->dom); |
| } |
| // Remap reduction axis |
| CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size()) |
| << "Tensorize: Reduction dimension mismatch with tensor intrin"; |
| axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size(); |
| for (size_t i = 0; i < axis_start; ++i) { |
| Range r = out_dom.at(self->reduce_axis[i]); |
| CHECK(is_one(r->extent)) |
| << "Tensorize: Reduction mismatch with tensor intrin " |
| << " intrin-dim=" << intrin_compute->reduce_axis.size() |
| << ", tensorize-dim=" << self->reduce_axis.size(); |
| var_remap_[self->reduce_axis[i]->var.get()] = r->min; |
| } |
| for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) { |
| IterVar iv = self->reduce_axis[i]; |
| IterVar target_iv = intrin_compute->reduce_axis[i - axis_start]; |
| Range r = out_dom.at(iv); |
| var_remap_[iv->var.get()] = target_iv->var + r->min; |
| axis_remap_[iv] = target_iv; |
| compute_intrin_iter_space->Set(target_iv->var, target_iv->dom); |
| } |
| } |
| |
| private: |
| // Input entry |
| struct InputEntry { |
| Tensor tensor; |
| size_t start; |
| Array<Range> region; |
| }; |
| // input data remap |
| std::unordered_map<Tensor, InputEntry> in_remap_; |
| // variable remap. |
| std::unordered_map<const Variable*, Expr> var_remap_; |
| // IterVar remap. |
| std::unordered_map<IterVar, IterVar> axis_remap_; |
| }; |
| |
| // Try to match the tensor dataflow of the stage with the intrinsic's declaration. |
| Array<Expr> MatchTensorizeBody( |
| const ComputeOpNode* self, |
| const Stage& stage, |
| const std::unordered_map<IterVar, Range>& out_dom, |
| const std::unordered_map<Tensor, Array<Range> >& in_region, |
| const TensorIntrin& intrin, |
| Map<Var, Range>* compute_intrin_iter_space) { |
| TensorIntrinMatcher matcher; |
| matcher.Init(self, stage, out_dom, in_region, intrin, compute_intrin_iter_space); |
| Array<Expr> ret; |
| for (Expr expr : self->body) { |
| ret.push_back(matcher.Mutate(expr)); |
| } |
| return ret; |
| } |
| |
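| // Verify that the stage's body matches the intrinsic's declared compute: |
| // rewrite the body through TensorIntrinMatcher, simplify both sides under |
| // the intrinsic's iteration space, then require matching types and |
| // structurally equal expressions. |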
| void VerifyTensorizeBody( |
| const ComputeOpNode* self, |
| const Stage& stage, |
| const std::unordered_map<IterVar, Range>& out_dom, |
| const std::unordered_map<Tensor, Array<Range> >& in_region, |
| const TensorIntrin& intrin) { |
| Map<Var, Range> compute_intrin_iter_space; |
| Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin, |
| &compute_intrin_iter_space); |
| const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>(); |
| CHECK(intrin_compute) << "Only support compute intrinsic for now"; |
| CHECK_EQ(body.size(), intrin_compute->body.size()) |
| << "Tensorize failed: body size mismatch"; |
| for (size_t i = 0; i < body.size(); ++i) { |
| Expr lhs = Simplify(body[i], compute_intrin_iter_space); |
| lhs = CanonicalSimplify(lhs, compute_intrin_iter_space); |
| Expr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space); |
| rhs = CanonicalSimplify(rhs, compute_intrin_iter_space); |
| if (lhs.type() != rhs.type()) { |
| LOG(FATAL) |
| << "Failed to match the data type with TensorIntrin " |
| << intrin->name << "'s declaration " |
| << " provided=" << lhs.type() |
| << ", intrin=" << rhs.type(); |
| } |
| CHECK(Equal(lhs, rhs)) |
| << "Failed to match the compute with TensorIntrin " |
| << intrin->name << "'s declaration " |
| << " provided= " << lhs |
| << ", intrin= " << rhs; |
| } |
| } |
| |
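| // Lower a tensorized stage: infer the tensorized region, verify the loop |
| // nest and the body against the intrinsic, bind the intrinsic's buffers to |
| // sub-regions of the actual tensors, and emit the intrinsic body (or its |
| // reduce_init/reduce_update pair when the reduction must be split). |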
| Stmt MakeTensorize(const ComputeOpNode* self, |
| const Stage& stage, |
| const std::unordered_map<IterVar, Range>& dom_map, |
| bool debug_keep_trivial_loop) { |
| std::unordered_map<IterVar, Range> out_dom; |
| std::unordered_map<Tensor, Array<Range> > in_region; |
| size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region); |
| TensorIntrin intrin = stage->iter_var_attrs.at( |
| stage->leaf_iter_vars[tloc])->tensor_intrin; |
| CHECK(intrin.defined()); |
| ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop); |
| VerifyTensorizeLoopNest(self, stage, n, tloc); |
| VerifyTensorizeBody(self, stage, out_dom, in_region, intrin); |
| // Start binding data. |
| Stmt nop = Evaluate::make(0); |
| std::vector<Stmt> input_bind_nest, output_bind_nest; |
| Array<Tensor> inputs = self->InputTensors(); |
| CHECK_EQ(inputs.size(), intrin->inputs.size()) |
| << "Tensorize failed: input size mismatch "; |
| // input binding |
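| // Each binding is an AttrStmt keyed by buffer_bind_scope: the node holds the |
| // (buffer, tensor) pair and the value is tvm_tuple(min0, extent0, ...); |
| // later lowering (storage flattening) uses it to alias the intrinsic's |
| // buffer onto that sub-region of the tensor. |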
| for (size_t i = 0; i < intrin->inputs.size(); ++i) { |
| Tensor tensor = inputs[i]; |
| Buffer buffer = intrin->buffers[i]; |
| Array<NodeRef> bind_spec{buffer, tensor}; |
| auto it = in_region.find(tensor); |
| CHECK(it != in_region.end()); |
| const Array<Range>& region = it->second; |
| Array<Expr> tuple; |
| for (const Range r : region) { |
| tuple.push_back(r->min); |
| tuple.push_back(r->extent); |
| } |
| input_bind_nest.emplace_back(AttrStmt::make( |
| bind_spec, ir::attr::buffer_bind_scope, |
| Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); |
| } |
| // output binding |
| const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>(); |
| CHECK(intrin_compute) << "Only support compute intrinsic for now"; |
| CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size()); |
| CHECK_EQ(intrin_compute->body.size(), self->body.size()); |
| Array<Expr> tuple; |
| for (IterVar iv : self->axis) { |
| auto it = out_dom.find(iv); |
| CHECK(it != out_dom.end()); |
| tuple.push_back(it->second->min); |
| tuple.push_back(it->second->extent); |
| } |
| for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) { |
| Tensor tensor = stage->op.output(i - intrin->inputs.size()); |
| Buffer buffer = intrin->buffers[i]; |
| Array<NodeRef> bind_spec{buffer, tensor}; |
| output_bind_nest.emplace_back(AttrStmt::make( |
| bind_spec, ir::attr::buffer_bind_scope, |
| Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); |
| } |
| // Check variable remap |
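| // Bind the intrinsic's declared reduction domain to the tensorized extents: |
| // the declared min is pinned to 0 and the declared extent to the inferred |
| // extent. ArgBinder records plain variable bindings in vmap and emits an |
| // assert for anything it cannot bind directly. |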
| std::unordered_map<const Variable*, Expr> vmap; |
| ir::ArgBinder binder(&vmap); |
| CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size()) |
| << "Tensorization fail: reduction axis size do not match"; |
| size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size(); |
| for (size_t i = 0; i < start; ++i) { |
| IterVar iv = self->reduce_axis[i]; |
| auto it = out_dom.find(iv); |
| CHECK(it != out_dom.end()); |
| CHECK(is_one(it->second->extent)) |
| << "Tensorization fail: reduction axis size do not match"; |
| } |
| for (size_t i = start; i < self->reduce_axis.size(); ++i) { |
| IterVar iv = self->reduce_axis[i]; |
| IterVar target = intrin_compute->reduce_axis[i - start]; |
| auto it = out_dom.find(iv); |
| CHECK(it != out_dom.end()); |
| binder.Bind(target->dom->min, make_const(iv->dom->min.type(), 0), |
| "tensor_intrin.reduction.min"); |
| binder.Bind(target->dom->extent, it->second->extent, |
| "tensor_intrin.reduction.extent"); |
| } |
| if (tloc <= n.num_common_loop) { |
| // No need to split the reduction. |
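| // The tensorize point lies within the loops shared by init and update, so a |
| // single call to the intrinsic's body covers the whole tile, including the |
| // complete reduction. |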
| std::vector<std::vector<Stmt> > nest( |
| n.main_nest.begin(), n.main_nest.begin() + tloc + 1); |
| nest.emplace_back(op::MakeIfNest(n.main_predicates)); |
| CHECK_EQ(n.init_predicates.size(), 0U); |
| CHECK(intrin->body.defined()) |
| << "Normal store op for intrin " << intrin << " is not defined"; |
| Stmt body = MergeNest(output_bind_nest, intrin->body); |
| body = MergeNest(input_bind_nest, body); |
| body = Substitute(body, vmap); |
| body = MergeNest(binder.asserts(), body); |
| body = Substitute(body, n.main_vmap); |
| return MergeNest(nest, body); |
| } else { |
| // Need to split reduction |
| CHECK(intrin->reduce_update.defined()) |
| << "Reduction update op for intrin " << intrin << " is not defined"; |
| // Need init and update steps |
| CHECK_NE(self->reduce_axis.size(), 0U); |
| std::vector<std::vector<Stmt> > common( |
| n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); |
| std::vector<std::vector<Stmt> > update_nest( |
| n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); |
| update_nest.emplace_back(op::MakeIfNest(n.main_predicates)); |
| |
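| // If the intrinsic declares reduce_init, emit the common loops around |
| // Block(init, update); otherwise fall back to guarding the plain body below. |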
| if (intrin->reduce_init.defined()) { |
| // init nest |
| std::vector<std::vector<Stmt> > init_nest( |
| n.init_nest.begin(), n.init_nest.begin() + tloc + 1); |
| init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); |
| Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); |
| init = Substitute(init, n.init_vmap); |
| init = MergeNest(init_nest, init); |
| // The update |
| Stmt update = MergeNest(output_bind_nest, intrin->reduce_update); |
| update = MergeNest(input_bind_nest, update); |
| update = Substitute(update, vmap); |
| update = MergeNest(binder.asserts(), update); |
| update = Substitute(update, n.main_vmap); |
| update = MergeNest(update_nest, update); |
| return MergeNest(common, Block::make(init, update)); |
| } else { |
| // When no init op is available, use the body op to reset on the first iteration. |
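| // TransformUpdate guards on the outer reduction loop variables: the plain |
| // body (which both resets and accumulates) runs on their first iteration, |
| // and reduce_update runs on every later iteration. |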
| CHECK(intrin->body.defined()) |
| << "Normal body op for intrin " << intrin << " is not defined"; |
| Stmt update = TransformUpdate(stage, dom_map, n, |
| intrin->body, |
| intrin->reduce_update); |
| update = MergeNest(output_bind_nest, update); |
| update = MergeNest(input_bind_nest, update); |
| update = Substitute(update, vmap); |
| update = MergeNest(binder.asserts(), update); |
| update = Substitute(update, n.main_vmap); |
| update = MergeNest(update_nest, update); |
| return MergeNest(common, update); |
| } |
| } |
| } |
| |
| // Register functions for unittests |
| TVM_REGISTER_API("test.op.InferTensorizeRegion") |
| .set_body([](TVMArgs args, TVMRetValue* ret) { |
| Stage stage = args[0]; |
| Map<IterVar, Range> dmap = args[1]; |
| std::unordered_map<IterVar, Range> out_dom; |
| std::unordered_map<Tensor, Array<Range> > in_region; |
| CHECK(stage->op.as<ComputeOpNode>()); |
| InferTensorizeRegion(stage->op.as<ComputeOpNode>(), |
| stage, |
| as_unordered_map(dmap), |
| &out_dom, &in_region); |
| *ret = Array<NodeRef>{Map<IterVar, Range>(out_dom), |
| Map<Tensor, Array<Range> >(in_region)}; |
| }); |
| |
| TVM_REGISTER_API("test.op.MatchTensorizeBody") |
| .set_body([](TVMArgs args, TVMRetValue* ret) { |
| Stage stage = args[0]; |
| Map<IterVar, Range> out_dom = args[1]; |
| Map<Tensor, Array<Range> > in_region = args[2]; |
| TensorIntrin intrin = args[3]; |
| Map<Var, Range> vrange; |
| CHECK(stage->op.as<ComputeOpNode>()); |
| *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(), |
| stage, |
| as_unordered_map(out_dom), |
| as_unordered_map(in_region), |
| intrin, |
| &vrange); |
| }); |
| } // namespace tvm |