blob: 17c7e59c7f60501fa1afb631256f2177b03548df [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/transform/capture_index_in_spans.cc
* \brief A pass to set spans to capture the post-dfs index of every node.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include "../ir/indexed_graph.h"
namespace tvm {
namespace relay {
namespace transform {
namespace {
/*!
 * \brief Rewriter which replaces the span of every visited expression with a synthetic span
 * encoding indexes taken from an IndexedGraph.
 *
 * Each synthetic span uses the source name "index". Its line and end_line hold the index of the
 * expression's node in the graph, and its column and end_column hold the index of that node's
 * dominator parent, or -1 when the node has no dominator parent.
 */
class SpansRewriter : public ExprRewriter {
 public:
  /*!
   * \brief Construct a rewriter which looks up nodes in \p indexed_graph.
   * \param indexed_graph Graph consulted for every rewritten expression. Only the raw pointer
   * is stored.
   */
  explicit SpansRewriter(const IndexedGraph<Expr>* indexed_graph)
      : source_name_(SourceName::Get("index")), indexed_graph_(indexed_graph) {}

 private:
  // Every overload below follows the same recipe: build the synthetic span from the
  // *original* (pre-rewrite) node, then copy the already-rewritten expression with only
  // its span replaced.

  Expr Rewrite_(const VarNode* var_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<Var>(var_node));
    return WithFields(Downcast<Var>(post), {}, {}, {}, index_span);
  }

  Expr Rewrite_(const GlobalVarNode* global_var_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<GlobalVar>(global_var_node));
    return WithFields(Downcast<GlobalVar>(post), {}, {}, {}, index_span);
  }

  Expr Rewrite_(const ConstantNode* constant_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<Constant>(constant_node));
    return WithFields(Downcast<Constant>(post), {}, {}, index_span);
  }

  Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<Tuple>(tuple_node));
    return WithFields(Downcast<Tuple>(post), {}, {}, index_span);
  }

  Expr Rewrite_(const FunctionNode* function_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<Function>(function_node));
    return WithFields(Downcast<Function>(post), {}, {}, {}, {}, {}, {}, index_span);
  }

  Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<Call>(call_node));
    return WithFields(Downcast<Call>(post), {}, {}, {}, {}, {}, index_span);
  }

  Expr Rewrite_(const LetNode* let_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<Let>(let_node));
    return WithFields(Downcast<Let>(post), {}, {}, {}, {}, index_span);
  }

  Expr Rewrite_(const IfNode* if_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<If>(if_node));
    return WithFields(Downcast<If>(post), {}, {}, {}, {}, index_span);
  }

  // OpNodes are not rewritten.

  Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<TupleGetItem>(tuple_get_item_node));
    return WithFields(Downcast<TupleGetItem>(post), {}, {}, {}, index_span);
  }

  Expr Rewrite_(const RefCreateNode* ref_create_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<RefCreate>(ref_create_node));
    return WithFields(Downcast<RefCreate>(post), {}, {}, index_span);
  }

  Expr Rewrite_(const RefReadNode* ref_read_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<RefRead>(ref_read_node));
    return WithFields(Downcast<RefRead>(post), {}, {}, index_span);
  }

  Expr Rewrite_(const RefWriteNode* ref_write_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<RefWrite>(ref_write_node));
    return WithFields(Downcast<RefWrite>(post), {}, {}, {}, index_span);
  }

  // ConstructorNodes are not rewritten.

  Expr Rewrite_(const MatchNode* match_node, const Expr& post) final {
    const Span index_span = MakeSpan(GetRef<Match>(match_node));
    return WithFields(Downcast<Match>(post), {}, {}, {}, index_span);
  }

  /*!
   * \brief Build the synthetic span for \p expr.
   * \param expr Expression to look up in the indexed graph.
   * \return A span whose line/end_line are the index of \p expr's node and whose
   * column/end_column are the index of the node's dominator parent (-1 if there is none).
   */
  Span MakeSpan(const Expr& expr) {
    auto graph_node = indexed_graph_->item_to_node(expr);
    const int self_index = static_cast<int>(graph_node->index_);
    // -1 marks a node without a dominator parent.
    int dominator_index = -1;
    if (graph_node->dominator_parent_) {
      dominator_index = static_cast<int>(graph_node->dominator_parent_->index_);
    }
    return Span(source_name_, /*line=*/self_index, /*end_line=*/self_index,
                /*column=*/dominator_index, /*end_column=*/dominator_index);
  }

  /*! \brief Source name ("index") shared by every span this rewriter creates. */
  SourceName source_name_;
  /*! \brief Graph used to map expressions to their nodes; held as a raw pointer. */
  const IndexedGraph<Expr>* indexed_graph_;
};
} // namespace
/*!
 * \brief Create the function pass named "CapturePostDfsIndexInSpans".
 *
 * For each function the pass builds an indexed graph of that function, then post-order
 * rewrites the function with a SpansRewriter bound to that graph.
 *
 * \return The function pass, created with opt level 0 and no required passes.
 */
tvm::transform::Pass CapturePostDfsIndexInSpans() {
  auto rewrite_spans = [](Function func, IRModule mod, transform::PassContext pass_ctx) {
    // The graph only needs to live for the duration of this rewrite.
    const auto graph = CreateIndexedGraph(func);
    SpansRewriter spans_rewriter(graph.get());
    Expr rewritten = PostOrderRewrite(func, &spans_rewriter);
    return Downcast<Function>(rewritten);
  };
  return CreateFunctionPass(rewrite_spans, 0, "CapturePostDfsIndexInSpans", {});
}
// Register CapturePostDfsIndexInSpans as a typed global function under the name
// "relay._transform.CapturePostDfsIndexInSpans".
// NOTE(review): presumably this is the name the Python-side bindings look up -- confirm.
TVM_REGISTER_GLOBAL("relay._transform.CapturePostDfsIndexInSpans")
.set_body_typed(CapturePostDfsIndexInSpans);
} // namespace transform
} // namespace relay
} // namespace tvm