| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file src/relay/transforms/annotate_target.cc |
| * \brief Wraps an expr with compiler_begin and compiler_end to indicate that |
| * this expr should be handled by the external compiler. |
| */ |
| |
| #include <tvm/relay/attrs/annotation.h> |
| #include <tvm/relay/expr_functor.h> |
| #include <tvm/relay/op_attr_types.h> |
| #include <tvm/relay/transform.h> |
| #include <tvm/runtime/container.h> |
| |
| #include "pass_util.h" |
| |
| namespace tvm { |
| namespace relay { |
| namespace annotate_target { |
| |
| static const PackedFunc* make_begin_op = |
| runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); |
| static const PackedFunc* make_end_op = |
| runtime::Registry::Get("relay.op.annotation._make.compiler_end"); |
| |
| // A helper class to insert annotation boundaries for a program region that will |
| // be handled by a specific compiler. |
| class AnnotateTargetRewriter : public ExprRewriter { |
| public: |
| explicit AnnotateTargetRewriter(Array<runtime::String> targets) : targets_(std::move(targets)) {} |
| |
| /*! |
| * \brief This function annotates a compiler end and a compiler begin to all arguments. |
| * |
| * The compiler end is based on the arg target while the compiler begin is based on the given |
| * target. If target is not given and all arguments are going to the same target, then we will |
| * use that target; otherwise we use default for this op. Note that all arg exprs must be |
| * available in op_expr_to_target before calling this function. |
| * |
| * \param args An array of arguments of the given node. |
| * \param target The target of the current node. |
| * \return A pair of target and annotated argument expressions. |
| */ |
| std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args, |
| const std::string& target = "") { |
| std::string ref_target = ""; |
| Array<Expr> compiler_ends; |
| for (auto arg : args) { |
| std::string arg_target = "default"; |
| const CallNode* call = arg.as<CallNode>(); |
| |
| if (call && call->op == CompilerBeginOp()) { |
| // Argument is already compiler begin node meaning that this is not the first time |
| // running this pass, so we simply remove it and will add a new one later. |
| CHECK_EQ(call->args.size(), 1U); |
| const CallNode* end = call->args[0].as<CallNode>(); |
| if (end->op == CompilerEndOp()) { |
| arg_target = end->attrs.as<CompilerAttrs>()->compiler; |
| } |
| compiler_ends.push_back(call->args[0]); |
| } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { |
| arg_target = op_expr_to_target_[arg]; |
| compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op)); |
| } else { |
| // Input vars. |
| compiler_ends.push_back(arg); |
| } |
| |
| // Maintain reference target in case the target of the current node is unassigned. |
| if (ref_target == "") { |
| ref_target = arg_target; |
| } else if (ref_target != arg_target) { |
| ref_target = "default"; |
| } |
| } |
| |
| // Determine compiler begin target. |
| std::string op_target = (target == "") ? ref_target : target; |
| |
| Array<Expr> compiler_begins; |
| for (const auto& end : compiler_ends) { |
| compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op)); |
| } |
| |
| return {op_target, compiler_begins}; |
| } |
| |
| Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) { |
| Expr new_op = (*ann_op)(expr, target); |
| new_op->checked_type_ = expr->checked_type_; |
| return new_op; |
| } |
| |
| Expr Rewrite_(const CallNode* pre, const Expr& post) final { |
| // Supported targets for this node. The order implies the priority. |
| std::vector<std::string> supported_targets; |
| |
| auto op_node = pre->op.as<OpNode>(); |
| |
| // This graph has annotations, meaning that this is not the first time running this pass. |
| if (op_node && pre->op == CompilerBeginOp()) { |
| // Bypass compiler begin due to lack of target information. It will be processed |
| // when the following op handling arguments. |
| CHECK_EQ(pre->args.size(), 1U); |
| return post.as<CallNode>()->args[0]; |
| } else if (op_node && pre->op == CompilerEndOp()) { |
| // Override compiler end with the new target. |
| CHECK_EQ(pre->args.size(), 1U); |
| auto input_expr = post.as<CallNode>()->args[0]; |
| CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end()); |
| return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op); |
| } |
| |
| // Peek the first argument. If it is compiler begin then this node had annotated by |
| // another target before, so we also consider that target as a supported target. |
| const CallNode* first_arg_call = pre->args[0].as<CallNode>(); |
| if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { |
| std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler; |
| if (arg_target != "default") { |
| supported_targets.push_back(arg_target); |
| } |
| } |
| |
| // Check which targets this op can be offloaded. |
| if (op_node) { |
| // TVM operators: Check target specific op checking function and add to supported_targets |
| // if it is supported. |
| Op op = Downcast<Op>(pre->op); |
| CHECK(op.defined()); |
| for (const auto& target : this->targets_) { |
| if (!Op::HasAttrMap("target." + std::string(target))) { |
| continue; |
| } |
| auto fannotate = Op::GetAttrMap<FTVMAnnotateTarget>("target." + std::string(target)); |
| if (fannotate.count(op) && fannotate[op](pre->attrs, pre->args)) { |
| supported_targets.push_back(target); |
| } |
| } |
| } else if (pre->op->IsInstance<FunctionNode>()) { |
| // Composite function: Add the target of a composite function to supported_targets |
| // if it is in the target list. |
| Function func = Downcast<Function>(pre->op); |
| CHECK(func.defined()); |
| |
| if (auto comp_name = func->GetAttr<String>(attr::kComposite)) { |
| std::string comp_name_str = comp_name.value(); |
| size_t i = comp_name_str.find('.'); |
| if (i != std::string::npos) { |
| std::string comp_target = comp_name_str.substr(0, i); |
| for (const auto& target : this->targets_) { |
| if (std::string(target) == comp_target) { |
| supported_targets.push_back(comp_target); |
| break; |
| } |
| } |
| } |
| } |
| } |
| supported_targets.push_back("default"); // Make default as the last option. |
| |
| // TODO(@comaniac, @zhiics): Now we simply assign this node to the target with |
| // the highest priority, but we should preserve all supported targets so that |
| // we can make a better decision. |
| std::string target = supported_targets[0]; |
| |
| // Visit and mutate arguments after the target of this op has been determined. |
| Call post_call = Downcast<Call>(post); |
| |
| // Add annotations to each arg. |
| auto target_n_args = AnnotateArgs(post_call->args, target); |
| Array<Expr> compiler_begins = std::get<1>(target_n_args); |
| Call new_call = Call(post_call->op, compiler_begins, post_call->attrs); |
| new_call->checked_type_ = pre->checked_type_; |
| |
| // Update the target map. |
| op_expr_to_target_[new_call] = target; |
| |
| return std::move(new_call); |
| } |
| |
| Expr Rewrite_(const TupleNode* op, const Expr& post) final { |
| auto expr = Downcast<Tuple>(post); |
| |
| auto target_n_args = AnnotateArgs(expr->fields); |
| auto new_expr = Tuple(std::get<1>(target_n_args)); |
| op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
| return std::move(new_expr); |
| } |
| |
| Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { |
| auto expr = Downcast<TupleGetItem>(post); |
| |
| auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple})); |
| auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index); |
| op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
| return std::move(new_expr); |
| } |
| |
| Expr Rewrite_(const FunctionNode* fn, const Expr& post) final { |
| Function func; |
| Expr new_body; |
| // don't step into composite functions |
| if (fn->GetAttr<String>(attr::kComposite).defined()) { |
| func = GetRef<Function>(fn); |
| new_body = func->body; |
| } else { |
| func = Downcast<Function>(post); |
| new_body = func->body; |
| if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) { |
| new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op); |
| op_expr_to_target_[new_body] = op_expr_to_target_[func->body]; |
| } |
| } |
| return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); |
| } |
| |
| Expr Rewrite_(const LetNode* op, const Expr& post) final { |
| auto let = Downcast<Let>(post); |
| |
| auto target_n_args = AnnotateArgs({let->value, let->body}); |
| auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); |
| op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
| return std::move(new_expr); |
| } |
| |
| Expr Rewrite_(const IfNode* op, const Expr& post) final { |
| auto expr = Downcast<If>(post); |
| |
| auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch}); |
| CHECK_EQ(std::get<1>(target_n_args).size(), 3U); |
| auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1], |
| std::get<1>(target_n_args)[2]); |
| op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
| return std::move(new_expr); |
| } |
| |
| Expr Rewrite_(const RefCreateNode* op, const Expr& post) final { |
| auto expr = Downcast<RefCreate>(post); |
| |
| auto target_n_args = AnnotateArgs(Array<Expr>({expr->value})); |
| auto new_expr = RefCreate(std::get<1>(target_n_args)[0]); |
| op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
| return std::move(new_expr); |
| } |
| |
| Expr Rewrite_(const RefReadNode* op, const Expr& post) final { |
| auto expr = Downcast<RefRead>(post); |
| |
| auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref})); |
| auto new_expr = RefRead(std::get<1>(target_n_args)[0]); |
| op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
| return std::move(new_expr); |
| } |
| |
| Expr Rewrite_(const RefWriteNode* op, const Expr& post) final { |
| auto expr = Downcast<RefWrite>(post); |
| |
| auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value})); |
| auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); |
| op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
| return std::move(new_expr); |
| } |
| |
| private: |
| /*! \brief The target backends for annotation. */ |
| Array<runtime::String> targets_; |
| /*! \brief Maintain the decision of the target for each op expr. */ |
| std::unordered_map<Expr, std::string, ObjectPtrHash, ObjectPtrEqual> op_expr_to_target_; |
| }; |
| |
| Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) { |
| auto rewriter = AnnotateTargetRewriter(targets); |
| return PostOrderRewrite(expr, &rewriter); |
| } |
| |
| } // namespace annotate_target |
| |
| namespace transform { |
| |
| Pass AnnotateTarget(const Array<runtime::String>& targets) { |
| runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
| [=](Function f, IRModule m, PassContext pc) { |
| return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, targets)); |
| }; |
| auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"}); |
| return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); |
| } |
| |
| TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget").set_body_typed(AnnotateTarget); |
| |
| } // namespace transform |
| |
| } // namespace relay |
| } // namespace tvm |