blob: c3375436f1374f03a7e3eff687efea308e1d03ba [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/si_builder.cc
* \brief Implementation for building a source info during rewriting expressions.
*/
#include <tvm/ir/si_builder.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <vector>
namespace tvm {
using RelayExprSet = std::unordered_set<relay::Expr, ObjectPtrHash, ObjectPtrEqual>;
using PrimExprSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
using StmtSet = std::unordered_set<tir::Stmt, ObjectPtrHash, ObjectPtrEqual>;
class RelayCollectSpans : public relay::ExprVisitor {
public:
explicit RelayCollectSpans(const RelayExprSet& inputs = {}) : inputs_(inputs) {}
// From entry to inputs, recursively collect spans. The spans of inputs are included.
Span CollectSpans(const relay::Expr& entry);
void VisitExpr(const relay::Expr& expr) final;
private:
Array<Span> spans_;
const RelayExprSet& inputs_;
};
void RelayCollectSpans::VisitExpr(const relay::Expr& expr) {
if (visit_counter_.count(expr.get())) {
return;
}
if (expr->span.defined()) {
spans_.push_back(expr->span);
}
if (inputs_.find(expr) != inputs_.end()) {
// becuase it returns directly, it should be recorded as visted manually.
visit_counter_.insert({expr.get(), 1});
return;
}
relay::ExprVisitor::VisitExpr(expr);
}
Span RelayCollectSpans::CollectSpans(const relay::Expr& entry) {
VisitExpr(entry);
return SequentialSpan(spans_);
}
class RelayRecursivelyFill : public relay::ExprMutator {
public:
explicit RelayRecursivelyFill(const Span& span, const RelayExprSet& inputs = {})
: span_(span), inputs_(inputs) {}
// From entry until inputs, recursively fill spans into expressions. Inputs are not filled.
void Fill(const relay::Expr& entry);
relay::Expr VisitExpr(const relay::Expr& expr) final;
private:
const Span& span_;
const RelayExprSet& inputs_;
};
relay::Expr RelayRecursivelyFill::VisitExpr(const relay::Expr& expr) {
if (inputs_.find(expr) != inputs_.end()) {
return expr;
}
// Skip op node. Align with python frontend
if (!expr.as<OpNode>()) {
expr->span = span_;
}
return relay::ExprMutator::VisitExpr(expr);
}
void RelayRecursivelyFill::Fill(const relay::Expr& entry) { Mutate(entry); }
class TirCollectSpans : public tir::StmtExprVisitor {
public:
explicit TirCollectSpans(const PrimExprSet& expr_inputs = {}, const StmtSet& stmt_inputs = {})
: expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {}
void VisitExpr(const PrimExpr& expr) final;
void VisitStmt(const tir::Stmt& stmt) final;
bool IsInput(const PrimExpr& expr);
bool IsInput(const tir::Stmt& stmt);
// From entry to inputs, recursively collect spans. The spans of inputs are included.
Span CollectSpans(const PrimExpr& expr);
// From entry to inputs, recursively collect spans. The spans of inputs are included.
Span CollectSpans(const tir::Stmt& stmt);
private:
Array<Span> spans_;
std::unordered_map<const Object*, size_t> visit_counter_;
const PrimExprSet& expr_inputs_;
const StmtSet& stmt_inputs_;
};
Span TirCollectSpans::CollectSpans(const PrimExpr& expr) {
operator()(expr);
return SequentialSpan(spans_);
}
Span TirCollectSpans::CollectSpans(const tir::Stmt& stmt) {
operator()(stmt);
return SequentialSpan(spans_);
}
bool TirCollectSpans::IsInput(const PrimExpr& expr) {
return expr_inputs_.find(expr) != expr_inputs_.end();
}
bool TirCollectSpans::IsInput(const tir::Stmt& stmt) {
return stmt_inputs_.find(stmt) != stmt_inputs_.end();
}
void TirCollectSpans::VisitExpr(const PrimExpr& expr) {
if (visit_counter_.count(expr.get())) {
return;
}
if (expr->span.defined()) {
spans_.push_back(expr->span);
}
if (IsInput(expr)) {
// becuase it returns directly, it should be recorded as visted manually.
visit_counter_.insert({expr.get(), 1});
return;
}
StmtExprVisitor::VisitExpr(expr);
}
void TirCollectSpans::VisitStmt(const tir::Stmt& stmt) {
if (visit_counter_.count(stmt.get())) {
return;
}
if (stmt->span.defined()) {
spans_.push_back(stmt->span);
}
if (IsInput(stmt)) {
// becuase it returns directly, it should be recorded as visted manually.
visit_counter_.insert({stmt.get(), 1});
return;
}
StmtExprVisitor::VisitStmt(stmt);
}
class TirRecursivelyFill : public tir::StmtExprMutator {
public:
TirRecursivelyFill(const Span& span, const PrimExprSet& expr_inputs = {},
const StmtSet& stmt_inputs = {})
: span_(span), expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {}
// From entry until inputs, recursively fill spans into expressions. Inputs are not filled.
tir::Stmt Fill(const tir::Stmt& s) { return operator()(s); }
// From entry until inputs, recursively fill spans into expressions. Inputs are not filled.
PrimExpr Fill(const PrimExpr& e) { return operator()(e); }
bool IsInput(const PrimExpr& expr);
bool IsInput(const tir::Stmt& stmt);
PrimExpr VisitExpr(const PrimExpr& expr) final;
tir::Stmt VisitStmt(const tir::Stmt& stmt) final;
private:
const Span& span_;
const PrimExprSet& expr_inputs_;
const StmtSet& stmt_inputs_;
};
tir::Stmt TirRecursivelyFill::VisitStmt(const tir::Stmt& stmt) {
if (IsInput(stmt)) {
return stmt;
}
stmt->span = span_;
return StmtExprMutator::VisitStmt(stmt);
}
bool TirRecursivelyFill::IsInput(const PrimExpr& expr) {
return expr_inputs_.find(expr) != expr_inputs_.end();
}
bool TirRecursivelyFill::IsInput(const tir::Stmt& stmt) {
return stmt_inputs_.find(stmt) != stmt_inputs_.end();
}
PrimExpr TirRecursivelyFill::VisitExpr(const PrimExpr& expr) {
if (IsInput(expr)) {
return expr;
}
expr->span = span_;
return StmtExprMutator::VisitExpr(expr);
}
struct SIBuilder::Impl {
virtual Span Build() const { return Span(); }
virtual void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const {}
virtual void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const {}
virtual void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const {}
virtual void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const {}
virtual void CollectSpansSpan(const relay::Expr& entry, const RelayExprSet& inputs) {}
virtual void CollectSpansSpan(const PrimExpr& entry, const PrimExprSet& inputs) {}
virtual void CollectSpansSpan(const tir::Stmt& entry, const PrimExprSet& inputs) {}
virtual void CollectSpansSpan(const tir::Stmt& entry, const StmtSet& inputs) {}
};
SIBuilder::~SIBuilder() = default;
Span SIBuilder::Build() const { return impl_->Build(); }
template <>
void SIBuilder::RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const {
impl_->RecursivelyFillSpan(entry, inputs);
}
template <>
void SIBuilder::RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const {
impl_->RecursivelyFillSpan(entry, inputs);
}
void SIBuilder::RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const {
impl_->RecursivelyFillSpan(entry, inputs);
}
void SIBuilder::RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const {
impl_->RecursivelyFillSpan(entry, inputs);
}
std::unique_ptr<SIBuilder::Impl> SIBuilder::CreateImpl(const Span& span) {
struct Impl : public SIBuilder::Impl {
explicit Impl(const Span& span) : span_(span) {}
Span Build() const final { return span_; }
void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const final {
RelayRecursivelyFill(Build(), inputs).Fill(entry);
}
void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const final {
TirRecursivelyFill(Build(), inputs).Fill(entry);
}
void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const final {
TirRecursivelyFill(Build(), inputs).Fill(entry);
}
void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const final {
TirRecursivelyFill(Build(), {}, inputs).Fill(entry);
}
void CollectSpansSpan(const relay::Expr& entry, const RelayExprSet& inputs) final {
span_ = RelayCollectSpans(inputs).CollectSpans(entry);
}
void CollectSpansSpan(const PrimExpr& entry, const PrimExprSet& inputs) final {
span_ = TirCollectSpans(inputs).CollectSpans(entry);
}
void CollectSpansSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final {
span_ = TirCollectSpans(inputs).CollectSpans(entry);
}
void CollectSpansSpan(const tir::Stmt& entry, const StmtSet& inputs) final {
span_ = TirCollectSpans({}, inputs).CollectSpans(entry);
}
private:
Span span_;
};
const bool enable_si_builder = transform::PassContext::Current()
->GetConfig<Bool>("ir.enable_si_builder", Bool(false))
.value();
if (enable_si_builder) {
return std::make_unique<Impl>(span);
}
return std::make_unique<SIBuilder::Impl>();
}
SIBuilder::SIBuilder(const Span& span) : impl_(CreateImpl(span)) {}
SIBuilder::SIBuilder(const Array<Span>& spans) : impl_(CreateImpl(SequentialSpan(spans))) {}
SIBuilder::SIBuilder(const std::initializer_list<Span>& init)
: impl_(CreateImpl(SequentialSpan(Array<Span>(init)))) {}
template <>
SIBuilder::SIBuilder(const relay::Expr& expr, const Array<relay::Expr>& inputs)
: impl_(CreateImpl(Span())) {
impl_->CollectSpansSpan(expr, RelayExprSet(inputs.begin(), inputs.end()));
}
template <>
SIBuilder::SIBuilder(const PrimExpr& expr, const Array<PrimExpr>& inputs)
: impl_(CreateImpl(Span())) {
impl_->CollectSpansSpan(expr, PrimExprSet(inputs.begin(), inputs.end()));
}
SIBuilder::SIBuilder(const tir::Stmt& s, const Array<PrimExpr>& inputs)
: impl_(CreateImpl(Span())) {
impl_->CollectSpansSpan(s, PrimExprSet(inputs.begin(), inputs.end()));
}
SIBuilder::SIBuilder(const tir::Stmt& s, const Array<tir::Stmt>& inputs)
: impl_(CreateImpl(Span())) {
impl_->CollectSpansSpan(s, StmtSet(inputs.begin(), inputs.end()));
}
// Register build pipeline related options
TVM_REGISTER_PASS_CONFIG_OPTION("ir.enable_si_builder", Bool);
} // namespace tvm