| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| #include "create_primfunc.h" |
| |
| #include <tvm/arith/analyzer.h> |
| #include <tvm/ffi/function.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/ir/name_supply.h> |
| #include <tvm/te/operation.h> |
| #include <tvm/tir/analysis.h> |
| #include <tvm/tir/function.h> |
| #include <tvm/tir/stmt_functor.h> |
| |
| #include <algorithm> |
| #include <set> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <utility> |
| #include <vector> |
| |
| #include "../../support/array.h" |
| #include "../../tir/ir/data_type_rewriter.h" |
| #include "../../tir/ir/functor_common.h" |
| #include "../../tir/transform/ir_utils.h" |
| #include "graph.h" |
| |
| namespace tvm { |
| namespace tir { |
| |
| /*! \brief The helper mutator that transforms ProducerLoad to BufferLoad */ |
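| // For illustration only (hypothetical names): if `tensor2buffers_` maps tensor `B` to buffer |
| // `B_buf`, then a TE access ProducerLoad(B, {i, j}) is rewritten to BufferLoad(B_buf, {i, j}); |
| // the indices are left untouched. |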
| class ProducerToBufferTransformer : public StmtExprMutator { |
| public: |
| explicit ProducerToBufferTransformer(const std::unordered_map<te::Tensor, Buffer>& tensor2buffers) |
| : tensor2buffers_(tensor2buffers) {} |
| |
| PrimExpr VisitExpr_(const ProducerLoadNode* op) final { |
| auto visited_op = Downcast<ProducerLoad>(StmtExprMutator::VisitExpr_(op)); |
| te::Tensor tensor = Downcast<te::Tensor>(visited_op->producer); |
| auto it = tensor2buffers_.find(tensor); |
| ICHECK(it != tensor2buffers_.end()) << "IndexError: Cannot find the tensor " << tensor; |
| const Buffer& buffer = it->second; |
| return BufferLoad(buffer, visited_op->indices); |
| } |
| |
| private: |
| /*! \brief The map from TE tensors to their corresponding buffers */ |
| const std::unordered_map<te::Tensor, Buffer>& tensor2buffers_; |
| }; |
| |
| /*! \brief The helper mutator that rewrites buffers and buffer vars accessed by the block body */ |
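| // Illustrative sketch (names are hypothetical): with var_map = {p_data -> x_data} and |
| // buffer_map = {P -> X}, a load P[i] becomes X[i], a store P[i] = v becomes X[i] = v, and a |
| // raw use of the handle var p_data becomes x_data. |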
| class BufferSubstituter : public StmtExprMutator { |
| public: |
| explicit BufferSubstituter(const std::unordered_map<const VarNode*, PrimExpr>& var_map, |
| const std::unordered_map<const BufferNode*, Buffer>& buffer_map) |
| : var_map_(var_map), buffer_map_(buffer_map) {} |
| |
| PrimExpr VisitExpr_(const VarNode* op) final { |
| auto it = var_map_.find(op); |
| if (it != var_map_.end()) { |
| return it->second; |
| } |
| return StmtExprMutator::VisitExpr_(op); |
| } |
| |
| PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
| auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
| auto it = buffer_map_.find(load->buffer.get()); |
| if (it != buffer_map_.end()) { |
| return BufferLoad(it->second, load->indices, load->predicate, load->span); |
| } |
| return load; |
| } |
| |
| Stmt VisitStmt_(const BufferStoreNode* op) final { |
| auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
| auto it = buffer_map_.find(store->buffer.get()); |
| if (it != buffer_map_.end()) { |
| return BufferStore(it->second, store->value, store->indices, store->predicate, store->span); |
| } |
| return store; |
| } |
| |
| private: |
| const std::unordered_map<const VarNode*, PrimExpr>& var_map_; |
| const std::unordered_map<const BufferNode*, Buffer>& buffer_map_; |
| }; |
| |
| /*! \brief Helper data structure to store information. */ |
| struct CreateFuncInfo { |
| /*! \brief The Tensor arg_list. */ |
| ffi::Array<te::Tensor> arg_list; |
| /*! \brief The map from each Tensor to its corresponding buffer. */ |
| std::unordered_map<te::Tensor, Buffer> tensor2buffers; |
| /*! \brief The transformer from ProducerLoad to BufferLoad. */ |
| ProducerToBufferTransformer transformer; |
| /*! \brief The buffers that should be allocated at the function root. */ |
| ffi::Array<Buffer> root_alloc; |
| /*! \brief The NameSupply used to make block names unique. */ |
| NameSupply name_supply; |
| |
| ffi::String FreshName(ffi::String base_name) { return name_supply->FreshName(base_name); } |
| |
| explicit CreateFuncInfo(ffi::Array<te::Tensor> arg_list) |
| : arg_list(std::move(arg_list)), transformer(tensor2buffers) {} |
| |
| bool IsArg(const te::Tensor& tensor) const { |
| return std::any_of(arg_list.begin(), arg_list.end(), |
| [&tensor](const te::Tensor& arg) { return tensor == arg; }); |
| } |
| }; |
| |
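| /*! |
| * \brief Normalize the "layout_free_placeholders" block annotation into the function-level |
| * "layout_free_buffers" attribute (a list of parameter indices), and strip the scheduling |
| * hint annotations listed in `blocklist`. |
| * |
| * For example (an illustrative sketch): if parameter 1 of the PrimFunc is bound to buffer B |
| * and a block is annotated with "layout_free_placeholders" = [B], the block annotation is |
| * removed and the function gains "layout_free_buffers" = [1]. |
| */ |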
| class LayoutFreePlaceholdersNormalizer : public StmtMutator { |
| public: |
| PrimFunc Process(PrimFunc func) { |
| for (int i = 0, n = func->params.size(); i < n; ++i) { |
| if (auto v = func->params[i].as<Var>()) { |
| if (ffi::Optional<Buffer> buffer = func->buffer_map.Get(v.value())) { |
| buffer2index_[buffer.value()] = i; |
| } |
| } |
| } |
| PrimFuncNode* f = func.CopyOnWrite(); |
| f->body = VisitStmt(std::move(f->body)); |
| if (this->layout_free_buffer_indices_.empty()) { |
| return func; |
| } |
| ffi::Array<int64_t> indices; |
| indices.reserve(this->layout_free_buffer_indices_.size()); |
| for (int i : this->layout_free_buffer_indices_) { |
| indices.push_back(i); |
| } |
| return WithAttr(std::move(func), tir::attr::layout_free_buffers, indices); |
| } |
| |
| Stmt VisitStmt_(const SBlockNode* _block) final { |
| SBlock block = Downcast<SBlock>(StmtMutator::VisitStmt_(_block)); |
| SBlockNode* n = block.CopyOnWrite(); |
| if (auto opt_ann = n->annotations.Get(topi_attr)) { |
| ffi::Array<Buffer> new_buffers; |
| for (Buffer buffer : Downcast<ffi::Array<Buffer>>(opt_ann.value())) { |
| auto it = buffer2index_.find(buffer); |
| if (it != buffer2index_.end()) { |
| layout_free_buffer_indices_.insert(it->second); |
| } else { |
| new_buffers.push_back(buffer); |
| } |
| } |
| if (new_buffers.empty()) { |
| n->annotations.erase(topi_attr); |
| } else { |
| n->annotations.Set(topi_attr, new_buffers); |
| } |
| } |
| for (const ffi::String& attr : this->blocklist) { |
| auto it = n->annotations.find(attr); |
| if (it != n->annotations.end()) { |
| n->annotations.erase(attr); |
| } |
| } |
| return block; |
| } |
| |
| std::unordered_map<tir::Buffer, int, ObjectPtrHash, ObjectPtrEqual> buffer2index_; |
| std::set<int> layout_free_buffer_indices_; |
| ffi::String topi_attr = "layout_free_placeholders"; |
| std::vector<ffi::String> blocklist = {"const_matrix", |
| "auto_scheduler_simplify_const_tensor_indices", "workload"}; |
| }; |
| |
| /*! |
| * \brief The iter levels specify the nested structure with respect to iteration domain |
| * dependencies. |
| * (1) Each iter resides in exactly one level. |
| * (2) The domain of a lower-level iter is either free or depends only on iters in higher |
| * levels. |
| **/ |
| using NestedIterLevels = std::vector<std::vector<IterVar>>; |
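| // Illustrative example (hypothetical axes): for axes {i: [0, n), j: [0, i + 1)} the extent of |
| // j depends on i, so the levels are {{i}, {j}}; for independent axes {i: [0, n), j: [0, m)} |
| // both axes stay in a single level {{i, j}}. |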
| |
| NestedIterLevels GenerateNestedIterLevels(const ffi::Array<IterVar>& axes, |
| arith::Analyzer* analyzer) { |
| int global_max_depth = 0; |
| std::unordered_map<Var, int> depth; |
| std::unordered_map<Var, IterVar> var2iter; |
| for (const auto& axis : axes) { |
| var2iter[axis->var] = axis; |
| } |
| |
| std::function<int(const IterVar&)> traverse = [&](const IterVar& axis) -> int { |
| auto depth_it = depth.find(axis->var); |
| if (depth_it != depth.end()) { // cache |
| return depth_it->second; |
| } |
| std::vector<Var> dep_vars; |
| for (const Var& v : UndefinedVars(analyzer->Simplify(axis->dom->min))) { |
| dep_vars.push_back(v); |
| } |
| for (const Var& v : UndefinedVars(analyzer->Simplify(axis->dom->extent))) { |
| dep_vars.push_back(v); |
| } |
| int cur_depth = 0; |
| for (const Var& v : dep_vars) { |
| auto it = var2iter.find(v); |
| if (it == var2iter.end()) { |
| // Not a dependency on another axis var; it may be a symbolic shape var or similar. |
| continue; |
| } |
| int dep_depth = traverse(it->second); |
| cur_depth = std::max(cur_depth, dep_depth + 1); |
| } |
| // Do not reuse `depth_it` as a hint here: the recursive calls above may have rehashed `depth`. |
| depth.emplace(axis->var, cur_depth); |
| global_max_depth = std::max(global_max_depth, cur_depth); |
| return cur_depth; |
| }; |
| |
| for (const auto& axis : axes) { |
| traverse(axis); |
| } |
| NestedIterLevels levels; |
| levels.resize(global_max_depth + 1); |
| for (const auto& axis : axes) { |
| const Var& var = axis->var; |
| levels[depth[var]].push_back(axis); |
| } |
| return levels; |
| } |
| |
| /*! |
| * \brief Generate output buffers from compute op's output tensors, and bind to context func info. |
| * \param compute_op The target compute op. |
| * \param info Generation context info. |
| * \returns The output buffer objects, ordered by compute op's outputs. |
| **/ |
| ffi::Array<Buffer> GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncInfo* info) { |
| // Step 1. Collect output tensors in TE operation. |
| ffi::Array<te::Tensor> tensors; |
| if (compute_op->body[0]->IsInstance<ReduceNode>()) { |
| auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { |
| StructuralEqual eq; |
| return eq(a->combiner, b->combiner) && // |
| eq(a->source, b->source) && // |
| eq(a->axis, b->axis) && // |
| eq(a->condition, b->condition) && // |
| eq(a->init, b->init); |
| }; |
| PrimExpr expr_body = compute_op->body[0]; |
| tensors.push_back(compute_op.output(0)); |
| const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>(); |
| // Specially handle the case of multiple reduction outputs sharing one combiner. |
| for (size_t k = 1; k < compute_op->body.size(); ++k) { |
| const tir::ReduceNode* reduce_ = compute_op->body[k].as<tir::ReduceNode>(); |
| ICHECK(reduce_); |
| ICHECK(f_reducer_equal(reduce_, reduce)) |
| << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " |
| << "but the first argument has body " << ffi::GetRef<PrimExpr>(reduce_) << ", while the " |
| << k << "-th argument has body " << ffi::GetRef<PrimExpr>(reduce); |
| tensors.push_back(compute_op.output(k)); |
| } |
| } else { |
| for (size_t k = 0; k < compute_op->body.size(); ++k) { |
| tensors.push_back(compute_op.output(k)); |
| } |
| } |
| // Step 2. Prepare buffers for compute outputs |
| // - Declare buffers |
| // - Update `op2buffers` |
| // - Add the non-argument tensors to `alloc_buffer` of the root block |
| ffi::Array<Buffer> buffers; |
| for (const te::Tensor& tensor : tensors) { |
| Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); |
| info->tensor2buffers[tensor] = buffer; |
| buffers.push_back(buffer); |
| if (!info->IsArg(tensor)) { |
| info->root_alloc.push_back(info->tensor2buffers[tensor]); |
| } |
| } |
| return buffers; |
| } |
| |
| /*! |
| * \brief Generate block annotation dict from compute op attrs. |
| * \param compute_op The target compute op. |
| * \param info Generation context info. |
| * \returns The block annotation dict. |
| **/ |
| ffi::Map<ffi::String, ffi::Any> GenerateBlockAnnotations(const te::ComputeOp& compute_op, |
| CreateFuncInfo* info) { |
| ffi::Map<ffi::String, ffi::Any> annotations; |
| auto mutate_attr = [&info](const ffi::Any& value) -> ffi::Any { |
| if (auto tensor_value = value.try_cast<te::Tensor>()) { |
| return info->tensor2buffers.at(tensor_value.value()); |
| } else { |
| return value; |
| } |
| }; |
| for (const auto& pair : compute_op->attrs) { |
| const ffi::String& key = pair.first; |
| const Any& value = pair.second; |
| // TensorIR will not allow Tensor data structure |
| if (value.as<ffi::ArrayObj>()) { |
| const auto array_value = Downcast<ffi::Array<ffi::Any>>(value); |
| annotations.Set(key, array_value.Map(mutate_attr)); |
| } else { |
| annotations.Set(key, mutate_attr(value)); |
| } |
| } |
| // Set script_parsing_detect_access |
| annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3)); |
| return annotations; |
| } |
| |
| /*! |
| * \brief Generate init stmt for reduction. |
| * \param indices Target store indices for the block. |
| * \param buffers Target store buffers for the block. |
| * \param reduce Reduce description node. |
| * \param var_map Var re-mapping for TE compute axes. |
| * \param info Generation context info. |
| * \returns Init stmt. |
| **/ |
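| // For a single-output sum reduction (illustrative only), the generated init is just the |
| // identity-element store, e.g. C[vi] = 0f, with vi the data-parallel block var and 0f the |
| // combiner's identity element. |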
| Stmt GenerateInitStmt(const ffi::Array<PrimExpr>& indices, const ffi::Array<Buffer>& buffers, |
| const ReduceNode* reduce, const ffi::Map<Var, PrimExpr>& var_map, |
| CreateFuncInfo* info) { |
| // helper to transform the expr and remap iters to the block domain |
| auto f_transform_and_remap = [&](const PrimExpr& e) { |
| return Substitute(info->transformer(e), var_map); |
| }; |
| int n_buffers = buffers.size(); |
| ffi::Array<Stmt> init_stmts; |
| init_stmts.reserve(n_buffers); |
| for (int i = 0; i < n_buffers; ++i) { |
| const Buffer& buffer = buffers[i]; |
| PrimExpr identity = f_transform_and_remap(reduce->combiner->identity_element[i]); |
| init_stmts.push_back(BufferStore(buffer, identity, indices)); |
| } |
| return SeqStmt::Flatten(init_stmts); |
| } |
| |
| /*! |
| * \brief Generate body execution stmt. |
| * \param indices Target store indices for the block. |
| * \param buffers Target store buffers for the block. |
| * \param var_map Var re-mapping for TE compute axes. |
| * \param expr_body Target computation expression. |
| * \param info Generation context info. |
| * \param analyzer Arithmetic analyzer in context. |
| * \returns Body stmt. |
| **/ |
| Stmt GenerateBodyStmt(const ffi::Array<PrimExpr>& indices, const ffi::Array<Buffer>& buffers, |
| const ffi::Map<Var, PrimExpr>& var_map, PrimExpr expr_body, |
| CreateFuncInfo* info, arith::Analyzer* analyzer) { |
| // helper to transform the expr and remap iters to the block domain |
| auto f_transform_and_remap = [&](const PrimExpr& e) { |
| return Substitute(info->transformer(e), var_map); |
| }; |
| Stmt body; |
| if (const auto* reduce = expr_body.as<ReduceNode>()) { |
| // Case 1. Reduce compute |
| int n_buffers = buffers.size(); |
| |
| ffi::Array<PrimExpr> lhs; |
| ffi::Array<PrimExpr> rhs; |
| lhs.reserve(n_buffers); |
| rhs.reserve(n_buffers); |
| |
| // Make the LHS operands and RHS operands: |
| // - A LHS operand is the buffer storing the reduction result, with corresponding indices. |
| // - A RHS operand is the value to be reduced. |
| for (int i = 0; i < n_buffers; ++i) { |
| const PrimExpr& left = BufferLoad(buffers[i], indices); |
| const PrimExpr& right = analyzer->Simplify(f_transform_and_remap(reduce->source[i])); |
| lhs.push_back(left); |
| rhs.push_back(right); |
| ICHECK_EQ(left->dtype, right->dtype); |
| } |
| |
| ffi::Array<Var> temp_vars; |
| ffi::Array<Stmt> body_stmts; |
| temp_vars.reserve(n_buffers); |
| body_stmts.reserve(n_buffers); |
| |
| // - When there is only one buffer, we directly create a BufferStore which stores |
| //   "combiner(lhs, rhs)" into the target buffer position. |
| // - When there are multiple buffers, every combined value must be computed from the old |
| //   buffer contents before any of them is overwritten. We therefore bind each |
| //   "combiner(lhs, rhs)" to an intermediate variable with a LetStmt, and then store the |
| //   variables into the target buffer positions. |
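| // For example (an illustrative sketch with an argmax-style reducer over hypothetical buffers |
| // `idx` and `val`), the generated body is roughly: |
| //   let v_idx = combiner(lhs, rhs)[0] |
| //   let v_val = combiner(lhs, rhs)[1] |
| //   idx[indices] = v_idx |
| //   val[indices] = v_val |
| // so every combined value is computed before any buffer is overwritten. |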
| for (int i = 0; i < n_buffers; ++i) { |
| const Buffer& buffer = buffers[i]; |
| PrimExpr value{nullptr}; |
| if (n_buffers > 1) { |
| temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype()))); |
| value = temp_vars.back(); |
| } else { |
| PrimExpr combined = reduce->combiner.get()->operator()(lhs, rhs)[i]; |
| value = f_transform_and_remap(combined); |
| } |
| body_stmts.push_back(BufferStore(buffer, value, indices)); |
| } |
| body = SeqStmt::Flatten(body_stmts); |
| if (n_buffers > 1) { |
| // When there are multiple buffers, we wrap the body with LetStmts. |
| for (int i = n_buffers - 1; i >= 0; --i) { |
| PrimExpr value = f_transform_and_remap(reduce->combiner.get()->operator()(lhs, rhs)[i]); |
| body = LetStmt(temp_vars[i], std::move(value), std::move(body)); |
| } |
| } |
| } else { |
| // Case 2. Data parallel compute |
| ICHECK_EQ(buffers.size(), 1); |
| const PrimExpr& compute_body = f_transform_and_remap(expr_body); |
| body = BufferStore(buffers[0], analyzer->Simplify(compute_body), indices); |
| } |
| return body; |
| } |
| |
| /*! \brief Record loops, block vars and binding in the single level scope. */ |
| struct NestedScopeInfo { |
| // loop var and range in the scope. |
| std::vector<std::pair<Var, Range>> loop_vars; |
| // block iters for current level's block. |
| ffi::Array<IterVar> block_iters; |
| // block bindings for current level's block. |
| ffi::Array<PrimExpr> bindings; |
| // store indices for current level's block. |
| ffi::Array<PrimExpr> store_indices; |
| // mapping from original TE compute axes to new block vars. |
| ffi::Map<Var, PrimExpr> axes_remap; |
| |
| // helper to add new block var |
| void AddBlockIter(const ffi::Optional<IterVar>& origin_axis, const IterVar& iter, |
| const PrimExpr& value) { |
| block_iters.push_back(iter); |
| bindings.push_back(value); |
| if (origin_axis.defined()) { |
| if (iter->iter_type != IterVarType::kCommReduce) { |
| store_indices.push_back(iter->var); |
| } |
| axes_remap.Set(origin_axis.value()->var, iter->var); |
| } |
| } |
| |
| // helper to renew leaf block var defs to ensure SSA. |
| void Renew(const ffi::Array<IterVar>& origin_axes) { |
| block_iters.MutateByApply([](const IterVar& itervar) { |
| auto n = ffi::make_object<IterVarNode>(*itervar.get()); |
| n->var = n->var.copy_with_suffix(""); |
| return IterVar(n); |
| }); |
| for (size_t i = 0; i < origin_axes.size(); ++i) { |
| Var block_var = block_iters[i]->var; |
| if (origin_axes[i]->iter_type != IterVarType::kCommReduce) { |
| store_indices.Set(i, block_var); |
| } |
| axes_remap.Set(origin_axes[i]->var, block_var); |
| } |
| } |
| }; |
| |
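| /*! |
| * \brief Generate the TIR statement (nested loops and blocks) for a TE compute op. |
| * |
| * Rough shape of the result (an illustrative, TVMScript-like sketch) for a compute whose |
| * reduction extent depends on an outer axis, e.g. out[i] = sum(A[i, k], k < i + 1): |
| * |
| *   for i in range(n): |
| *     block out_l1(vi = i): |
| *       for k in range(vi + 1): |
| *         block out(vi_1 = vi, vk = k): |
| *           init: out[vi_1] = 0 |
| *           out[vi_1] = out[vi_1] + A[vi_1, vk] |
| * |
| * When all axis domains are independent, there is a single level and the result collapses to |
| * the familiar single block under a plain loop nest. |
| */ |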
| Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info, |
| arith::Analyzer* analyzer) { |
| // Step 1. Collect all iter axes in original TE compute op |
| ffi::Array<IterVar> axes = compute_op->axis; |
| axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end()); |
| |
| // Step 2. Prepare nested iteration scopes. |
| // For each axis, we generate the loop and the first block binding at the level the axis |
| // belongs to. At lower levels, we only create a new block var and bind it to the previous |
| // level's block var. |
| auto axes_levels = GenerateNestedIterLevels(axes, analyzer); |
| ICHECK(!axes_levels.empty()); |
| std::vector<NestedScopeInfo> scopes; |
| scopes.reserve(axes_levels.size()); |
| std::unordered_set<Var> defined_axes; |
| for (size_t i = 0; i < axes_levels.size(); ++i) { |
| NestedScopeInfo cur_scope; |
| for (size_t j = 0; j < axes.size(); ++j) { |
| const IterVar& axis = axes[j]; |
| DataType index_type = |
| DataType::Int(std::max(axis->dom->min.dtype().bits(), axis->dom->extent.dtype().bits())); |
| bool first_times_define = |
| std::find(axes_levels[i].begin(), axes_levels[i].end(), axis) != axes_levels[i].end(); |
| if (first_times_define) { |
| Var loop_var = Var(axis->var->name_hint, index_type); |
| Var block_var("v_" + axis->var->name_hint, index_type); |
| PrimExpr min = axis->dom->min; |
| PrimExpr extent = axis->dom->extent; |
| if (i > 0) { |
| const auto& scope_repl = scopes[i - 1].axes_remap; |
| min = Substitute(min, scope_repl); |
| extent = Substitute(extent, scope_repl); |
| } |
| Range dom = Range::FromMinExtent(analyzer->Simplify(min), analyzer->Simplify(extent)); |
| IterVar new_block_iter(dom, block_var, axis->iter_type, axis->thread_tag, axis->span); |
| cur_scope.loop_vars.emplace_back(loop_var, dom); |
| cur_scope.AddBlockIter(axis, new_block_iter, loop_var); |
| defined_axes.insert(axis->var); |
| } else if (defined_axes.count(axis->var)) { |
| ICHECK_GT(i, 0); |
| ICHECK(scopes[i - 1].axes_remap.count(axis->var)); |
| PrimExpr prev_binding = scopes[i - 1].axes_remap.at(axis->var); |
| Var block_var("v_" + axis->var->name_hint, index_type); |
| Range dom = Range::FromMinExtent(prev_binding, make_const(index_type, 1)); |
| IterVar new_block_iter(dom, block_var, axis->iter_type, axis->thread_tag, axis->span); |
| cur_scope.AddBlockIter(axis, new_block_iter, prev_binding); |
| } |
| } |
| if (i == axes_levels.size() - 1 && cur_scope.block_iters.empty()) { |
| // for the leaf scope, we ensure at least one block var exists |
| IterVar dummy(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), |
| IterVarType::kDataPar); |
| cur_scope.AddBlockIter(std::nullopt, dummy, 0); |
| } |
| scopes.push_back(cur_scope); |
| } |
| |
| // Step 3. Generate output buffers for each output tensor |
| ffi::Array<Buffer> buffers = GenerateOutputBuffers(compute_op, info); |
| |
| // Step 4. Generate leaf block stmts. |
| ffi::Array<Stmt> seq_stmt; |
| auto leaf = scopes.back(); |
| ffi::Map<ffi::String, ffi::Any> annotations = GenerateBlockAnnotations(compute_op, info); |
| const ReduceNode* reduce = compute_op->body[0].as<ReduceNode>(); |
| if (reduce) { |
| PrimExpr expr_body = compute_op->body[0]; |
| Stmt init = GenerateInitStmt(leaf.store_indices, buffers, reduce, leaf.axes_remap, info); |
| Stmt body = |
| GenerateBodyStmt(leaf.store_indices, buffers, leaf.axes_remap, expr_body, info, analyzer); |
| seq_stmt.push_back(SBlockRealize(/*iter_values=*/leaf.bindings, |
| /*predicate=*/Bool(true), |
| /*block=*/ |
| SBlock(/*iter_vars=*/leaf.block_iters, |
| /*reads=*/{}, |
| /*writes=*/{}, |
| /*name_hint=*/info->FreshName(compute_op->name), |
| /*body=*/body, |
| /*init=*/init, |
| /*alloc_buffers=*/{}, |
| /*match_buffers=*/{}, |
| /*annotations=*/annotations))); |
| |
| } else { |
| for (int i = 0; i < compute_op->num_outputs(); ++i) { |
| if (i > 0) { |
| // Renew block var defs to ensure SSA |
| leaf.Renew(axes); |
| } |
| PrimExpr expr_body = compute_op->body[i]; |
| Stmt body = GenerateBodyStmt(leaf.store_indices, {buffers[i]}, leaf.axes_remap, expr_body, |
| info, analyzer); |
| seq_stmt.push_back(SBlockRealize(/*iter_values=*/leaf.bindings, |
| /*predicate=*/Bool(true), |
| /*block=*/ |
| SBlock(/*iter_vars=*/leaf.block_iters, |
| /*reads=*/{}, |
| /*writes=*/{}, |
| /*name_hint=*/info->FreshName(buffers[i]->name), |
| /*body=*/body, |
| /*init=*/std::nullopt, |
| /*alloc_buffers=*/{}, |
| /*match_buffers=*/{}, |
| /*annotations=*/annotations))); |
| } |
| } |
| Stmt body = SeqStmt::Flatten(seq_stmt); |
| |
| // Step 5. Generate nested parent scopes. |
| for (size_t i = scopes.size(); i > 0; --i) { |
| const auto& cur = scopes[i - 1]; |
| if (i < scopes.size()) { |
| auto block_name = info->FreshName(compute_op->name + "_l" + std::to_string(i)); |
| const auto& block_iters = cur.block_iters; |
| |
| ffi::Optional<Stmt> init{std::nullopt}; |
| if (reduce && std::any_of(block_iters.begin(), block_iters.end(), [](const IterVar& iter) { |
| return iter->iter_type == IterVarType::kCommReduce; |
| })) { |
| // If a reduce axis is defined in a non-leaf scope, the nested block is also a reduction |
| // block, so we must insert an init stmt at the parent level as well. |
| init = GenerateInitStmt(cur.store_indices, buffers, reduce, cur.axes_remap, info); |
| } |
| |
| // wrap nested block |
| body = SBlockRealize(/*iter_values=*/cur.bindings, |
| /*predicate=*/Bool(true), |
| /*block=*/ |
| SBlock(/*iter_vars=*/block_iters, |
| /*reads=*/{}, |
| /*writes=*/{}, |
| /*name_hint=*/block_name, |
| /*body=*/body, |
| /*init=*/init, |
| /*alloc_buffers=*/{}, |
| /*match_buffers=*/{}, |
| /*annotations=*/annotations)); |
| } |
| for (size_t j = cur.loop_vars.size(); j > 0; --j) { |
| const auto& [loop_var, dom] = cur.loop_vars[j - 1]; |
| body = For(loop_var, dom->min, dom->extent, ForKind::kSerial, body); |
| } |
| } |
| return body; |
| } |
| |
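| // For a te.extern stage (illustrative only), the result is a single opaque block with no |
| // iter vars whose body is the extern op's body, with the input placeholder buffers replaced |
| // by the buffers already bound for the inputs. Read/write regions are left empty and are |
| // filled in later by "script.Complete". |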
| Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* info) { |
| // Step 1. Check that all inputs have been visited, and update var_map. |
| std::unordered_map<const VarNode*, PrimExpr> var_map; |
| std::unordered_map<const BufferNode*, Buffer> input_buffer_map; |
| ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size()); |
| for (size_t i = 0; i < extern_op->inputs.size(); ++i) { |
| const Buffer& placeholder = extern_op->input_placeholders[i]; |
| const te::Tensor& input_tensor = extern_op->inputs[i]; |
| auto it = info->tensor2buffers.find(input_tensor); |
| ICHECK(it != info->tensor2buffers.end()); |
| var_map[placeholder->data.get()] = it->second->data; |
| input_buffer_map[placeholder.get()] = it->second; |
| } |
| |
| // Step 2. Update info with the op's output tensors and placeholder buffers. |
| ICHECK_EQ(extern_op->num_outputs(), extern_op->output_placeholders.size()); |
| for (int i = 0; i < extern_op->num_outputs(); ++i) { |
| const Buffer& placeholder = extern_op->output_placeholders[i]; |
| const te::Tensor& output_tensor = extern_op.output(i); |
| info->tensor2buffers[output_tensor] = placeholder; |
| if (!info->IsArg(output_tensor)) { |
| info->root_alloc.push_back(placeholder); |
| } |
| } |
| |
| // The access region does not need to be collected here, as it will |
| // be generated with the later application of "script.Complete" in |
| // GenerateAndCompletePrimFunc. Waiting until later also handles |
| // the case where there is only a single SBlock, which then |
| // becomes the root SBlock of the function, and should not have |
| // reads/writes filled in. |
| |
| // Step 3. Substitute the input placeholder vars/buffers and rewrite ProducerLoad to BufferLoad. |
| BufferSubstituter substituter(var_map, input_buffer_map); |
| Stmt substituted_body = substituter(extern_op->body); |
| |
| ProducerToBufferTransformer transformer(info->tensor2buffers); |
| Stmt body = transformer(substituted_body); |
| |
| // Step 4. Generate opaque block as body. |
| return SBlockRealize(/*iter_values=*/{}, |
| /*predicate=*/Bool(true), |
| /*block=*/ |
| SBlock(/*iter_vars=*/{}, |
| /*reads=*/{}, |
| /*writes=*/{}, |
| /*name_hint=*/info->FreshName(extern_op->name), |
| /*body=*/std::move(body), |
| /*init=*/std::nullopt, |
| /*alloc_buffers=*/{}, |
| /*match_buffers=*/{}, |
| /*annotations=*/extern_op->attrs)); |
| } |
| |
| ffi::Array<te::Operation> CollectOrderedOps(const ffi::Array<te::Tensor>& arg_list) { |
| ffi::Array<te::Operation> arg_ops; |
| for (const te::Tensor& arg : arg_list) { |
| arg_ops.push_back(arg->op); |
| } |
| te::ReadGraph g = te::CreateReadGraph(arg_ops); |
| ffi::Array<te::Operation> order = te::PostDFSOrder(arg_ops, g); |
| |
| for (const te::Operation& op : order) { |
| if (!(op->IsInstance<te::PlaceholderOpNode>() || op->IsInstance<te::ComputeOpNode>() || |
| op->IsInstance<te::ExternOpNode>())) |
| LOG(FATAL) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". " |
| << "Only te.placeholder and te.compute are allowed for now."; |
| } |
| return order; |
| } |
| |
| void InitializeBufferBinds(const ffi::Array<te::Operation>& ordered_ops, CreateFuncInfo* info) { |
| // Process any TE operations which contain user defined buffers |
| for (const auto& op : ordered_ops) { |
| // Initialize the tensor2buffer binds map with buffers defined by the te.extern |
| if (const auto* extern_op = op.as<te::ExternOpNode>()) { |
| ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size()); |
| for (size_t i = 0; i < extern_op->inputs.size(); ++i) { |
| const te::Tensor& input = extern_op->inputs[i]; |
| const Buffer& buffer = extern_op->input_placeholders[i]; |
| info->tensor2buffers[input] = buffer; |
| } |
| } |
| } |
| } |
| |
| void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, |
| ffi::Array<Stmt>* root_stmts, arith::Analyzer* analyzer) { |
| if (const auto* placeholder = op.as<te::PlaceholderOpNode>()) { |
| // Case 1. PlaceholderOp (te.placeholder) |
| ICHECK_EQ(op->num_outputs(), 1); |
| const te::Tensor& tensor = op.output(0); |
| // Check that the placeholder's output tensor is one of the function arguments. |
| ICHECK(info->IsArg(tensor)) << "The operation " << op << " produces tensor " << tensor |
| << ", but this tensor does not appear as a function argument. " |
| << "The function accepts arguments " << info->arg_list; |
| // Declare a buffer for any argument tensors without a pre-existing |
| // buffer declaration recorded in the tensor2buffer binds map |
| if (info->tensor2buffers.count(tensor) == 0) { |
| const Buffer& buffer = |
| decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name, "global"); |
| info->tensor2buffers[tensor] = buffer; |
| } |
| } else if (auto compute_op = op.as<te::ComputeOp>()) { |
| // Case 2. ComputeOp (te.compute) |
| root_stmts->push_back(GenerateStmtFromCompute(compute_op.value(), info, analyzer)); |
| } else if (const auto extern_op = op.as<te::ExternOp>()) { |
| // Case 3. ExternOp (te.extern) |
| root_stmts->push_back(GenerateStmtFromExternOp(extern_op.value(), info)); |
| } else { |
| ICHECK(false) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". " |
| << "Only te.placeholder and te.compute are allowed for now."; |
| } |
| } |
| |
| PrimFunc GenerateAndCompletePrimFunc(const ffi::Array<te::Tensor>& arg_list, |
| const ffi::Array<Stmt>& root_stmts, CreateFuncInfo* info) { |
| ffi::Array<Var> parameters; |
| ffi::Map<Var, Buffer> buffer_map; |
| for (const te::Tensor& tensor : arg_list) { |
| Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); |
| parameters.push_back(arg); |
| auto it = info->tensor2buffers.find(tensor); |
| ICHECK(it != info->tensor2buffers.end()); |
| buffer_map.Set(arg, it->second); |
| } |
| PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters), |
| /*body=*/SeqStmt::Flatten(root_stmts), |
| /*ret_type=*/VoidType(), |
| /*buffer_map=*/std::move(buffer_map)), |
| {{"global_symbol", ffi::String("main")}, {"tir.noalias", true}}); |
| const auto fcomplete = tvm::ffi::Function::GetGlobal("script.Complete"); |
| ICHECK(fcomplete.has_value()); |
| func = (*fcomplete)(std::move(func), info->root_alloc).cast<PrimFunc>(); |
| return func; |
| } |
| |
| PrimFunc CreatePrimFuncWithConstants(const ffi::Array<te::Tensor>& arg_list, |
| const ffi::Array<runtime::Tensor>& constants, |
| std::optional<DataType> index_dtype_override) { |
| // Information used in CreatePrimFunc and its sub-functions. |
| CreateFuncInfo info(arg_list); |
| // Root body stmts. |
| ffi::Array<Stmt> root_stmts; |
| // Analyzer |
| arith::Analyzer analyzer; |
| |
| // Step 1. Create ordered array of operations and validate they are supported. |
| ffi::Array<te::Operation> order = CollectOrderedOps(arg_list); |
| |
| // Step 2. Initialize buffer binds map |
| InitializeBufferBinds(order, &info); |
| |
| // Step 3. Rewrite compute stages into blocks. |
| for (const te::Operation& op : order) { |
| RewriteStageToBlock(op, &info, &root_stmts, &analyzer); |
| } |
| |
| // Step 4. Create func and complete prim func. |
| auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info); |
| func = tir::BindParams(func, constants); |
| if (index_dtype_override.has_value()) { |
| func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func)); |
| } |
| auto result = LayoutFreePlaceholdersNormalizer().Process(std::move(func)); |
| return result; |
| } |
| |
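| // Minimal usage sketch (assuming the te::placeholder / te::compute helpers from |
| // <tvm/te/operation.h>; exact helper signatures may differ between versions): |
| // |
| //   te::Tensor A = te::placeholder({m, n}, DataType::Float(32), "A"); |
| //   te::Tensor B = te::compute( |
| //       {m, n}, [&](const Var& i, const Var& j) { return A(i, j) + 1.0f; }, "B"); |
| //   PrimFunc f = CreatePrimFunc({A, B}, DataType::Int(64)); |
| // |
| // Here m and n are PrimExprs; the resulting PrimFunc takes both tensors as buffer arguments |
| // and contains a single block computing B. |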
| PrimFunc CreatePrimFunc(const ffi::Array<te::Tensor>& arg_list, |
| std::optional<DataType> index_dtype_override) { |
| return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def_packed("te.CreatePrimFunc", [](ffi::PackedArgs args, ffi::Any* ret) { |
| ffi::Array<ObjectRef> arg_list = args[0].cast<ffi::Array<ObjectRef>>(); |
| std::optional<DataType> index_dtype_override{std::nullopt}; |
| // Add conversion to make std::optional compatible with FFI. |
| if (args[1] != nullptr) { |
| index_dtype_override = args[1].cast<DataType>(); |
| } |
| *ret = CreatePrimFunc(arg_list, index_dtype_override); |
| }); |
| } |
| |
| // Relax variant: the argument list may contain both te::Tensor and tir::Var entries. |
| PrimFunc GenerateAndCompletePrimFunc(const ffi::Array<ObjectRef>& arg_tir_var_list, |
| const ffi::Array<Stmt>& root_stmts, CreateFuncInfo* info) { |
| ffi::Array<Var> parameters; |
| ffi::Map<Var, Buffer> buffer_map; |
| for (const ObjectRef& arg : arg_tir_var_list) { |
| if (auto opt_tensor = arg.as<te::Tensor>()) { |
| te::Tensor tensor = opt_tensor.value(); |
| Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); |
| parameters.push_back(arg); |
| auto it = info->tensor2buffers.find(tensor); |
| ICHECK(it != info->tensor2buffers.end()); |
| buffer_map.Set(arg, it->second); |
| } else if (auto var = arg.as<tir::Var>()) { |
| parameters.push_back(var.value()); |
| } |
| } |
| PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters), |
| /*body=*/SeqStmt::Flatten(root_stmts), |
| /*ret_type=*/VoidType(), |
| /*buffer_map=*/std::move(buffer_map)), |
| {{"global_symbol", ffi::String("main")}, {"tir.noalias", true}}); |
| const auto fcomplete = tvm::ffi::Function::GetGlobal("script.Complete"); |
| ICHECK(fcomplete.has_value()); |
| func = (*fcomplete)(std::move(func), info->root_alloc).cast<PrimFunc>(); |
| return func; |
| } |
| |
| PrimFunc CreatePrimFuncWithConstants(const ffi::Array<ObjectRef>& arg_list, |
| const ffi::Array<runtime::Tensor>& constants, |
| std::optional<DataType> index_dtype_override) { |
| ffi::Array<te::Tensor> tensor_arg_list; |
| for (const ObjectRef& x : arg_list) { |
| if (auto tensor_node = x.as<te::TensorNode>()) { |
| te::Tensor tensor = ffi::GetRef<te::Tensor>(tensor_node); |
| tensor_arg_list.push_back(tensor); |
| } |
| } |
| // Information used in CreatePrimFunc and its sub-functions. |
| CreateFuncInfo info(tensor_arg_list); |
| // Root body stmts. |
| ffi::Array<Stmt> root_stmts; |
| // Analyzer |
| arith::Analyzer analyzer; |
| |
| // Step 1. Create ordered array of operations and validate they are supported. |
| ffi::Array<te::Operation> order = CollectOrderedOps(tensor_arg_list); |
| |
| // Step 2. Initialize buffer binds map |
| InitializeBufferBinds(order, &info); |
| |
| // Step 3. Rewrite compute stages into blocks. |
| for (const te::Operation& op : order) { |
| RewriteStageToBlock(op, &info, &root_stmts, &analyzer); |
| } |
| auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info); |
| func = tir::BindParams(func, constants); |
| if (index_dtype_override.has_value()) { |
| func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func)); |
| } |
| auto result = LayoutFreePlaceholdersNormalizer().Process(std::move(func)); |
| return result; |
| } |
| |
| PrimFunc CreatePrimFunc(const ffi::Array<ObjectRef>& arg_list, |
| std::optional<DataType> index_dtype_override) { |
| return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); |
| } |
| |
| } // namespace tir |
| } // namespace tvm |