| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \brief Tensor Compute Op. |
| * \file tensor_compute_op.cc |
| */ |
| #include <tvm/arith/analyzer.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/te/operation.h> |
| #include <tvm/tir/builtin.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/stmt_functor.h> |
| |
| #include <unordered_set> |
| |
| #include "./compute_op.h" |
| #include "./op_util.h" |
| |
| namespace tvm { |
| namespace te { |
| using namespace tir; |
| // TensorComputeOpNode |
// Pretty-printer hook: renders a TensorComputeOpNode as
// "tensor_compute_op(<name>, <address>)" for debug output.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const TensorComputeOpNode*>(node.get());
      // NOTE: streaming `op` prints the raw node pointer, which distinguishes
      // otherwise identical instances in debug logs.
      p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
    });

// Register the node type with TVM's reflection/serialization machinery.
TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
| |
| int TensorComputeOpNode::num_outputs() const { |
| return static_cast<int>(this->intrin->buffers.size() - this->inputs.size()); |
| } |
| |
| DataType TensorComputeOpNode::output_dtype(size_t i) const { |
| return this->intrin->buffers[this->inputs.size() + i]->dtype; |
| } |
| |
| TensorComputeOp::TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis, |
| Array<IterVar> reduce_axis, int schedulable_ndim, |
| TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions, |
| Array<PrimExpr> scalar_inputs) { |
| auto n = make_object<TensorComputeOpNode>(); |
| n->name = std::move(name); |
| n->tag = std::move(tag); |
| n->axis = std::move(axis); |
| n->reduce_axis = std::move(reduce_axis); |
| n->schedulable_ndim = std::move(schedulable_ndim); |
| n->intrin = std::move(intrin); |
| n->inputs = std::move(tensors); |
| n->input_regions = std::move(regions); |
| n->scalar_inputs = std::move(scalar_inputs); |
| data_ = std::move(n); |
| } |
| |
// FFI hook: exposes the TensorComputeOp constructor to frontends (e.g. the
// Python API) under the global name "te.TensorComputeOp". Arguments mirror
// the C++ constructor one-to-one.
TVM_REGISTER_GLOBAL("te.TensorComputeOp")
    .set_body_typed([](std::string name, std::string tag, Array<IterVar> axis,
                       Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
                       Array<Tensor> tensors, Array<Region> regions,
                       Array<PrimExpr> scalar_inputs) {
      return TensorComputeOp(name, tag, axis, reduce_axis, schedulable_ndim, intrin, tensors,
                             regions, scalar_inputs);
    });
| |
| Array<Tensor> TensorComputeOpNode::InputTensors() const { return inputs; } |
| |
// Return a version of this op with input tensors found in `rmap` replaced,
// both in the `inputs` array and inside the intrinsic's body / reduce_init /
// reduce_update statements. Returns `self` unchanged if nothing was replaced.
Operation TensorComputeOpNode::ReplaceInputs(const Operation& self,
                                             const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  // Copy-on-write: mutate copies of the node and its intrinsic, then decide
  // at the end whether anything actually changed.
  auto n = make_object<TensorComputeOpNode>(*this);
  auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
  intrin->body = ReplaceTensor(this->intrin->body, rmap);
  if (intrin->reduce_init.defined()) {
    intrin->reduce_init = ReplaceTensor(this->intrin->reduce_init, rmap);
  }
  if (intrin->reduce_update.defined()) {
    intrin->reduce_update = ReplaceTensor(this->intrin->reduce_update, rmap);
  }
  // Substitute directly-listed input tensors.
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    Tensor t = n->inputs[i];
    if (rmap.count(t)) {
      n->inputs.Set(i, rmap.at(t));
    }
  }

  // `n->intrin` still aliases the ORIGINAL intrinsic at this point, so these
  // same_as checks compare the rewritten statements against the originals.
  if (intrin->body.same_as(n->intrin->body) &&
      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
      intrin->reduce_update.same_as(n->intrin->reduce_update) && inputs.same_as(n->inputs)) {
    return self;
  } else {
    // Something changed: install the rewritten intrinsic and return a new op.
    n->intrin = TensorIntrin(intrin);
    return Operation(n);
  }
}
| |
| void TensorComputeOpNode::PropBoundToInputs( |
| const Operation& self, arith::Analyzer* analyzer, |
| const std::unordered_map<const VarNode*, IntSet>& dom_map, |
| std::unordered_map<Tensor, TensorDom>* out_dom_map) const { |
| for (size_t i = 0; i < this->inputs.size(); ++i) { |
| Tensor t = this->inputs[i]; |
| Region region = input_regions[i]; |
| |
| auto it = out_dom_map->find(t); |
| if (it == out_dom_map->end()) continue; |
| TensorDom& dom = it->second; |
| for (size_t j = 0; j < t.ndim(); ++j) { |
| dom.data[j].emplace_back(EvalSet(region[j], dom_map)); |
| } |
| } |
| } |
| |
| size_t TensorComputeOpNode::num_schedulable_dims() const { return schedulable_ndim; } |
| |
| Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, |
| const std::unordered_map<IterVar, Range>& dom_map, |
| bool debug_keep_trivial_loop) const { |
| CHECK_EQ(stage->op.operator->(), this); |
| |
| // Start bind data. |
| Stmt nop = Evaluate(0); |
| std::vector<Stmt> input_bind_nest, output_bind_nest; |
| Array<Tensor> inputs = this->InputTensors(); |
| |
| // input binding |
| size_t num_inputs = inputs.size(); |
| for (size_t i = 0; i < num_inputs; ++i) { |
| Tensor tensor = inputs[i]; |
| Region region = this->input_regions[i]; |
| Buffer buffer = this->intrin->buffers[i]; |
| Array<ObjectRef> bind_spec{buffer, tensor}; |
| |
| Array<PrimExpr> tuple; |
| for (size_t i = 0; i < region.size(); ++i) { |
| tuple.push_back(region[i]->min); |
| tuple.push_back(region[i]->extent); |
| } |
| input_bind_nest.emplace_back( |
| AttrStmt(bind_spec, tir::attr::buffer_bind_scope, |
| Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop)); |
| } |
| |
| // output binding |
| for (int i = 0; i < this->num_outputs(); ++i) { |
| Tensor tensor = stage->op.output(i); |
| Buffer buffer = this->intrin->buffers[num_inputs + i]; |
| Array<ObjectRef> bind_spec{buffer, tensor}; |
| |
| Array<PrimExpr> tuple; |
| for (size_t i = 0; i < this->axis.size(); ++i) { |
| auto ivar = this->axis[i]; |
| if (i < static_cast<size_t>(this->schedulable_ndim)) { |
| tuple.push_back(ivar->var); |
| tuple.push_back(1); |
| } else { |
| Range dom = ivar->dom; |
| tuple.push_back(dom->min); |
| tuple.push_back(dom->extent); |
| } |
| } |
| |
| output_bind_nest.emplace_back( |
| AttrStmt(bind_spec, tir::attr::buffer_bind_scope, |
| Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop)); |
| } |
| |
| // Check variable remap |
| std::unordered_map<const VarNode*, PrimExpr> vmap; |
| tir::ArgBinder binder(&vmap); |
| |
| // Map the expressions passed in the call to the TensorIntrin, to the placeholder |
| // variables |
| Array<PrimExpr> user_expr = this->scalar_inputs; |
| Array<Var> scalar_params = this->intrin->scalar_params; |
| Array<PrimExpr> sp_expr; |
| for (auto sp : scalar_params) { |
| PrimExpr esp = sp; |
| sp_expr.push_back(esp); |
| } |
| CHECK_EQ(sp_expr.size(), user_expr.size()); |
| // TODO(jdavies-huawei): what name should be used here? |
| binder.BindArray(sp_expr, user_expr, this->name); |
| |
| size_t tloc = stage->leaf_iter_vars.size(); |
| ComputeLoopNest n = ComputeLoopNest::Create(this, stage, dom_map, debug_keep_trivial_loop); |
| |
| if (this->reduce_axis.size() == 0) { |
| std::vector<std::vector<Stmt> > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); |
| nest.emplace_back(MakeIfNest(n.main_predicates)); |
| CHECK_EQ(n.init_predicates.size(), 0U); |
| CHECK(this->intrin->body.defined()) |
| << "Normal store op for intrin " << this << " is not defined"; |
| Stmt body = MergeNest(output_bind_nest, this->intrin->body); |
| body = MergeNest(input_bind_nest, body); |
| body = tir::Substitute(body, vmap); |
| body = MergeNest(binder.asserts(), body); |
| body = te::Substitute(body, n.main_vmap); |
| Stmt ret = MergeNest(nest, body); |
| return ret; |
| } else { |
| // Need to split reduction |
| CHECK(this->intrin->reduce_update.defined()) << "Reduction update op is not defined"; |
| // Need init and update steps |
| CHECK_NE(this->reduce_axis.size(), 0U); |
| std::vector<std::vector<Stmt> > common(n.main_nest.begin(), |
| n.main_nest.begin() + n.num_common_loop + 1); |
| std::vector<std::vector<Stmt> > update_nest(n.main_nest.begin() + n.num_common_loop + 1, |
| n.main_nest.begin() + tloc + 1); |
| update_nest.emplace_back(MakeIfNest(n.main_predicates)); |
| |
| if (this->intrin->reduce_init.defined()) { |
| // init nest |
| std::vector<std::vector<Stmt> > init_nest(n.init_nest.begin(), |
| n.init_nest.begin() + tloc + 1); |
| init_nest.emplace_back(MakeIfNest(n.init_predicates)); |
| Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init); |
| init = te::Substitute(init, n.init_vmap); |
| init = MergeNest(init_nest, init); |
| // The update |
| Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update); |
| update = MergeNest(input_bind_nest, update); |
| update = tir::Substitute(update, vmap); |
| update = MergeNest(binder.asserts(), update); |
| update = te::Substitute(update, n.main_vmap); |
| update = MergeNest(update_nest, update); |
| return MergeNest(common, SeqStmt::Flatten(init, update)); |
| } else { |
| // When init op is not available, use body op for reset in the first iter. |
| CHECK(this->intrin->body.defined()) << "Normal body op is not defined"; |
| Stmt update = |
| TransformUpdate(stage, dom_map, n, this->intrin->body, this->intrin->reduce_update); |
| update = MergeNest(output_bind_nest, update); |
| update = MergeNest(input_bind_nest, update); |
| update = tir::Substitute(update, vmap); |
| update = MergeNest(binder.asserts(), update); |
| update = te::Substitute(update, n.main_vmap); |
| update = MergeNest(update_nest, update); |
| return MergeNest(common, update); |
| } |
| } |
| } |
| } // namespace te |
| } // namespace tvm |