[IR][SIBuilder] (#14574)
* [IR][SIBuilder]
- Add SIBuilder to handle the span propagation between passes
- Add SequentialSpan for multiple source expressions conversion between
passes
- Add test cases for SIBuilder and SequentialSpan
* [IR][SIBuilder]
- Make null implementation as base class
- Add comments and change naming based on reviewing
---------
Co-authored-by: Joey Tsai <chunit@qti.qualcomm.com>
diff --git a/include/tvm/ir/si_builder.h b/include/tvm/ir/si_builder.h
new file mode 100644
index 0000000..ab5f2d4
--- /dev/null
+++ b/include/tvm/ir/si_builder.h
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/si_builder.h
+ * \brief build a source info during rewriting expressions.
+ */
+#ifndef TVM_IR_SI_BUILDER_H_
+#define TVM_IR_SI_BUILDER_H_
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/tir/stmt.h>
+
+#include <memory>
+#include <unordered_set>
+
+namespace tvm {
+
+/*!
+ * \brief Source Information Builder, SIBuilder provides helper APIs for filling spans,
+ * particularly useful for one-to-many, many-to-one and many-to-many IR transformations.
+ */
+class SIBuilder {
+ public:
+ /*!
+ * \brief Create SIBuilder from a given span
+ */
+ explicit SIBuilder(const Span& span = Span());
+
+ /*!
+ * \brief Create SIBuilder from a given span sequence
+ */
+ explicit SIBuilder(const Array<Span>& spans = Array<Span>());
+ explicit SIBuilder(const std::initializer_list<Span>& init);
+
+ /*!
+ * \brief Create SIBuilder via a subgraph,
+ * Will construct span based on the exprs in the subgraph. Including the inputs exprs.
+ *
+ * \param entry Entry expr for subgraph
+ * \param inputs End exprs for subgraph
+ */
+ template <typename T, typename = std::enable_if_t<std::is_base_of<BaseExpr, T>::value>>
+ explicit SIBuilder(const T& entry, const tvm::Array<T>& inputs = {});
+ explicit SIBuilder(const tir::Stmt& entry, const tvm::Array<PrimExpr>& inputs = {});
+ explicit SIBuilder(const tir::Stmt& entry, const tvm::Array<tir::Stmt>& inputs = {});
+
+ ~SIBuilder();
+
+ SIBuilder(const SIBuilder&) = delete;
+ SIBuilder& operator=(const SIBuilder&) = delete;
+
+ /*!
+ * \brief build a span of source information, which is based on the given span or subgraph.
+ *
+ * \return the built span
+ */
+ Span Build() const;
+
+ /*!
+ * \brief Recursively fill all span of exprs in subgraph from entry until inputs.
+ *
+ * \param entry Entry expr for subgraph.
+ * \param inputs End exprs for subgraph, will not be filled with new span.
+ */
+ template <typename T, typename = std::enable_if_t<std::is_base_of<BaseExpr, T>::value>>
+ void RecursivelyFillSpan(
+ const T& entry, const std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
+
+ void RecursivelyFillSpan(
+ const tir::Stmt& entry,
+ const std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
+ void RecursivelyFillSpan(
+ const tir::Stmt& entry,
+ const std::unordered_set<tir::Stmt, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
+
+ private:
+ struct Impl;
+ std::unique_ptr<Impl> impl_;
+
+ std::unique_ptr<Impl> CreateImpl(const Span& span);
+};
+
+} // namespace tvm
+
+#endif // TVM_IR_SI_BUILDER_H_
diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h
index 536099f..9b3041f 100644
--- a/include/tvm/ir/source_map.h
+++ b/include/tvm/ir/source_map.h
@@ -114,7 +114,7 @@
}
static constexpr const char* _type_key = "Span";
- TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
+ TVM_DECLARE_BASE_OBJECT_INFO(SpanNode, Object);
};
class Span : public ObjectRef {
@@ -127,6 +127,50 @@
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};
+/*!
+ * \brief Store a list of spans for an expr generated from mulitple source exprs
+ */
+class SequentialSpanNode : public SpanNode {
+ public:
+ /*! \brief The original source list of spans to construct a sequential span. */
+ Array<Span> spans;
+
+ // override attr visitor
+ void VisitAttrs(AttrVisitor* v) {
+ SpanNode::VisitAttrs(v);
+ v->Visit("spans", &spans);
+ }
+
+ static constexpr const char* _type_key = "SequentialSpan";
+ TVM_DECLARE_FINAL_OBJECT_INFO(SequentialSpanNode, SpanNode);
+
+ bool SEqualReduce(const SequentialSpanNode* other, SEqualReducer equal) const {
+ if (spans.size() != other->spans.size()) {
+ return false;
+ }
+
+ for (size_t i = 0, e = spans.size(); i != e; ++i) {
+ if (!StructuralEqual()(spans[i], other->spans[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+};
+
+/*!
+ * \brief Reference class of SequentialSpanNode.
+ * \sa SequentialSpanNode
+ */
+class SequentialSpan : public Span {
+ public:
+ TVM_DLL SequentialSpan(Array<Span> spans);
+
+ TVM_DLL SequentialSpan(std::initializer_list<Span> init);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(SequentialSpan, Span, SequentialSpanNode);
+};
+
/*! \brief A program source in any language.
*
* Could represent the source from an ML framework or a source
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 4f63cbe..5875f4b 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -25,6 +25,7 @@
Node,
SourceName,
Span,
+ SequentialSpan,
assert_structural_equal,
load_json,
save_json,
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index 5f3a679..21a5ed6 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -69,6 +69,23 @@
)
+@register_object("SequentialSpan")
+class SequentialSpan(Object):
+ """A sequence of source spans
+
+ This span is specific for an expression, which is from multiple expressions
+ after an IR transform.
+
+ Parameters
+ ----------
+ spans : Array
+ The array of spans.
+ """
+
+ def __init__(self, spans):
+ self.__init_handle_by_constructor__(_ffi_api.SequentialSpan, spans)
+
+
@register_object
class EnvFunc(Object):
"""Environment function.
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 02eec18..ef2b515 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -73,6 +73,7 @@
# Span
Span = base.Span
+SequentialSpan = base.SequentialSpan
SourceName = base.SourceName
# Type
diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py
index 8667bfb..460746f 100644
--- a/python/tvm/relay/base.py
+++ b/python/tvm/relay/base.py
@@ -20,7 +20,7 @@
import tvm._ffi
from tvm.ir import Node as RelayNode
-from tvm.ir import SourceName, Span
+from tvm.ir import SourceName, Span, SequentialSpan
from tvm.runtime import Object
from . import _ffi_api
diff --git a/src/ir/si_builder.cc b/src/ir/si_builder.cc
new file mode 100644
index 0000000..c337543
--- /dev/null
+++ b/src/ir/si_builder.cc
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/si_builder.cc
+ * \brief Implementation for building a source info during rewriting expressions.
+ */
+#include <tvm/ir/si_builder.h>
+#include <tvm/ir/transform.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <vector>
+
+namespace tvm {
+
+using RelayExprSet = std::unordered_set<relay::Expr, ObjectPtrHash, ObjectPtrEqual>;
+using PrimExprSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+using StmtSet = std::unordered_set<tir::Stmt, ObjectPtrHash, ObjectPtrEqual>;
+
+class RelayCollectSpans : public relay::ExprVisitor {
+ public:
+ explicit RelayCollectSpans(const RelayExprSet& inputs = {}) : inputs_(inputs) {}
+
+ // From entry to inputs, recursively collect spans. The spans of inputs are included.
+ Span CollectSpans(const relay::Expr& entry);
+
+ void VisitExpr(const relay::Expr& expr) final;
+
+ private:
+ Array<Span> spans_;
+ const RelayExprSet& inputs_;
+};
+
+void RelayCollectSpans::VisitExpr(const relay::Expr& expr) {
+ if (visit_counter_.count(expr.get())) {
+ return;
+ }
+ if (expr->span.defined()) {
+ spans_.push_back(expr->span);
+ }
+ if (inputs_.find(expr) != inputs_.end()) {
+ // becuase it returns directly, it should be recorded as visted manually.
+ visit_counter_.insert({expr.get(), 1});
+ return;
+ }
+ relay::ExprVisitor::VisitExpr(expr);
+}
+
+Span RelayCollectSpans::CollectSpans(const relay::Expr& entry) {
+ VisitExpr(entry);
+ return SequentialSpan(spans_);
+}
+
+class RelayRecursivelyFill : public relay::ExprMutator {
+ public:
+ explicit RelayRecursivelyFill(const Span& span, const RelayExprSet& inputs = {})
+ : span_(span), inputs_(inputs) {}
+
+ // From entry until inputs, recursively fill spans into expressions. Inputs are not filled.
+ void Fill(const relay::Expr& entry);
+
+ relay::Expr VisitExpr(const relay::Expr& expr) final;
+
+ private:
+ const Span& span_;
+ const RelayExprSet& inputs_;
+};
+
+relay::Expr RelayRecursivelyFill::VisitExpr(const relay::Expr& expr) {
+ if (inputs_.find(expr) != inputs_.end()) {
+ return expr;
+ }
+ // Skip op node. Align with python frontend
+ if (!expr.as<OpNode>()) {
+ expr->span = span_;
+ }
+
+ return relay::ExprMutator::VisitExpr(expr);
+}
+
+void RelayRecursivelyFill::Fill(const relay::Expr& entry) { Mutate(entry); }
+
+class TirCollectSpans : public tir::StmtExprVisitor {
+ public:
+ explicit TirCollectSpans(const PrimExprSet& expr_inputs = {}, const StmtSet& stmt_inputs = {})
+ : expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {}
+
+ void VisitExpr(const PrimExpr& expr) final;
+ void VisitStmt(const tir::Stmt& stmt) final;
+
+ bool IsInput(const PrimExpr& expr);
+ bool IsInput(const tir::Stmt& stmt);
+
+ // From entry to inputs, recursively collect spans. The spans of inputs are included.
+ Span CollectSpans(const PrimExpr& expr);
+ // From entry to inputs, recursively collect spans. The spans of inputs are included.
+ Span CollectSpans(const tir::Stmt& stmt);
+
+ private:
+ Array<Span> spans_;
+ std::unordered_map<const Object*, size_t> visit_counter_;
+ const PrimExprSet& expr_inputs_;
+ const StmtSet& stmt_inputs_;
+};
+
+Span TirCollectSpans::CollectSpans(const PrimExpr& expr) {
+ operator()(expr);
+ return SequentialSpan(spans_);
+}
+
+Span TirCollectSpans::CollectSpans(const tir::Stmt& stmt) {
+ operator()(stmt);
+ return SequentialSpan(spans_);
+}
+
+bool TirCollectSpans::IsInput(const PrimExpr& expr) {
+ return expr_inputs_.find(expr) != expr_inputs_.end();
+}
+
+bool TirCollectSpans::IsInput(const tir::Stmt& stmt) {
+ return stmt_inputs_.find(stmt) != stmt_inputs_.end();
+}
+
+void TirCollectSpans::VisitExpr(const PrimExpr& expr) {
+ if (visit_counter_.count(expr.get())) {
+ return;
+ }
+ if (expr->span.defined()) {
+ spans_.push_back(expr->span);
+ }
+ if (IsInput(expr)) {
+ // becuase it returns directly, it should be recorded as visted manually.
+ visit_counter_.insert({expr.get(), 1});
+ return;
+ }
+ StmtExprVisitor::VisitExpr(expr);
+}
+
+void TirCollectSpans::VisitStmt(const tir::Stmt& stmt) {
+ if (visit_counter_.count(stmt.get())) {
+ return;
+ }
+ if (stmt->span.defined()) {
+ spans_.push_back(stmt->span);
+ }
+ if (IsInput(stmt)) {
+ // becuase it returns directly, it should be recorded as visted manually.
+ visit_counter_.insert({stmt.get(), 1});
+ return;
+ }
+ StmtExprVisitor::VisitStmt(stmt);
+}
+
+class TirRecursivelyFill : public tir::StmtExprMutator {
+ public:
+ TirRecursivelyFill(const Span& span, const PrimExprSet& expr_inputs = {},
+ const StmtSet& stmt_inputs = {})
+ : span_(span), expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {}
+
+ // From entry until inputs, recursively fill spans into expressions. Inputs are not filled.
+ tir::Stmt Fill(const tir::Stmt& s) { return operator()(s); }
+ // From entry until inputs, recursively fill spans into expressions. Inputs are not filled.
+ PrimExpr Fill(const PrimExpr& e) { return operator()(e); }
+
+ bool IsInput(const PrimExpr& expr);
+ bool IsInput(const tir::Stmt& stmt);
+
+ PrimExpr VisitExpr(const PrimExpr& expr) final;
+ tir::Stmt VisitStmt(const tir::Stmt& stmt) final;
+
+ private:
+ const Span& span_;
+ const PrimExprSet& expr_inputs_;
+ const StmtSet& stmt_inputs_;
+};
+
+tir::Stmt TirRecursivelyFill::VisitStmt(const tir::Stmt& stmt) {
+ if (IsInput(stmt)) {
+ return stmt;
+ }
+ stmt->span = span_;
+ return StmtExprMutator::VisitStmt(stmt);
+}
+
+bool TirRecursivelyFill::IsInput(const PrimExpr& expr) {
+ return expr_inputs_.find(expr) != expr_inputs_.end();
+}
+
+bool TirRecursivelyFill::IsInput(const tir::Stmt& stmt) {
+ return stmt_inputs_.find(stmt) != stmt_inputs_.end();
+}
+
+PrimExpr TirRecursivelyFill::VisitExpr(const PrimExpr& expr) {
+ if (IsInput(expr)) {
+ return expr;
+ }
+ expr->span = span_;
+ return StmtExprMutator::VisitExpr(expr);
+}
+
+struct SIBuilder::Impl {
+ virtual Span Build() const { return Span(); }
+ virtual void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const {}
+ virtual void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const {}
+ virtual void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const {}
+ virtual void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const {}
+ virtual void CollectSpansSpan(const relay::Expr& entry, const RelayExprSet& inputs) {}
+ virtual void CollectSpansSpan(const PrimExpr& entry, const PrimExprSet& inputs) {}
+ virtual void CollectSpansSpan(const tir::Stmt& entry, const PrimExprSet& inputs) {}
+ virtual void CollectSpansSpan(const tir::Stmt& entry, const StmtSet& inputs) {}
+};
+
+SIBuilder::~SIBuilder() = default;
+
+Span SIBuilder::Build() const { return impl_->Build(); }
+
+template <>
+void SIBuilder::RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const {
+ impl_->RecursivelyFillSpan(entry, inputs);
+}
+
+template <>
+void SIBuilder::RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const {
+ impl_->RecursivelyFillSpan(entry, inputs);
+}
+
+void SIBuilder::RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const {
+ impl_->RecursivelyFillSpan(entry, inputs);
+}
+
+void SIBuilder::RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const {
+ impl_->RecursivelyFillSpan(entry, inputs);
+}
+
+std::unique_ptr<SIBuilder::Impl> SIBuilder::CreateImpl(const Span& span) {
+ struct Impl : public SIBuilder::Impl {
+ explicit Impl(const Span& span) : span_(span) {}
+ Span Build() const final { return span_; }
+ void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const final {
+ RelayRecursivelyFill(Build(), inputs).Fill(entry);
+ }
+ void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const final {
+ TirRecursivelyFill(Build(), inputs).Fill(entry);
+ }
+ void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const final {
+ TirRecursivelyFill(Build(), inputs).Fill(entry);
+ }
+ void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const final {
+ TirRecursivelyFill(Build(), {}, inputs).Fill(entry);
+ }
+ void CollectSpansSpan(const relay::Expr& entry, const RelayExprSet& inputs) final {
+ span_ = RelayCollectSpans(inputs).CollectSpans(entry);
+ }
+ void CollectSpansSpan(const PrimExpr& entry, const PrimExprSet& inputs) final {
+ span_ = TirCollectSpans(inputs).CollectSpans(entry);
+ }
+ void CollectSpansSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final {
+ span_ = TirCollectSpans(inputs).CollectSpans(entry);
+ }
+ void CollectSpansSpan(const tir::Stmt& entry, const StmtSet& inputs) final {
+ span_ = TirCollectSpans({}, inputs).CollectSpans(entry);
+ }
+
+ private:
+ Span span_;
+ };
+
+ const bool enable_si_builder = transform::PassContext::Current()
+ ->GetConfig<Bool>("ir.enable_si_builder", Bool(false))
+ .value();
+
+ if (enable_si_builder) {
+ return std::make_unique<Impl>(span);
+ }
+
+ return std::make_unique<SIBuilder::Impl>();
+}
+
+SIBuilder::SIBuilder(const Span& span) : impl_(CreateImpl(span)) {}
+SIBuilder::SIBuilder(const Array<Span>& spans) : impl_(CreateImpl(SequentialSpan(spans))) {}
+SIBuilder::SIBuilder(const std::initializer_list<Span>& init)
+ : impl_(CreateImpl(SequentialSpan(Array<Span>(init)))) {}
+
+template <>
+SIBuilder::SIBuilder(const relay::Expr& expr, const Array<relay::Expr>& inputs)
+ : impl_(CreateImpl(Span())) {
+ impl_->CollectSpansSpan(expr, RelayExprSet(inputs.begin(), inputs.end()));
+}
+
+template <>
+SIBuilder::SIBuilder(const PrimExpr& expr, const Array<PrimExpr>& inputs)
+ : impl_(CreateImpl(Span())) {
+ impl_->CollectSpansSpan(expr, PrimExprSet(inputs.begin(), inputs.end()));
+}
+
+SIBuilder::SIBuilder(const tir::Stmt& s, const Array<PrimExpr>& inputs)
+ : impl_(CreateImpl(Span())) {
+ impl_->CollectSpansSpan(s, PrimExprSet(inputs.begin(), inputs.end()));
+}
+
+SIBuilder::SIBuilder(const tir::Stmt& s, const Array<tir::Stmt>& inputs)
+ : impl_(CreateImpl(Span())) {
+ impl_->CollectSpansSpan(s, StmtSet(inputs.begin(), inputs.end()));
+}
+
+// Register build pipeline related options
+TVM_REGISTER_PASS_CONFIG_OPTION("ir.enable_si_builder", Bool);
+
+} // namespace tvm
diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc
index 8b91390..721a30a 100644
--- a/src/ir/source_map.cc
+++ b/src/ir/source_map.cc
@@ -88,11 +88,58 @@
TVM_REGISTER_NODE_TYPE(SpanNode);
+SequentialSpan::SequentialSpan(tvm::Array<Span> spans) {
+ auto n = make_object<SequentialSpanNode>();
+ tvm::Array<Span> tmp_spans;
+ for (const Span& s : spans) {
+ if (const SequentialSpanNode* seq_s = s.as<SequentialSpanNode>()) {
+ tmp_spans.insert(tmp_spans.end(), seq_s->spans.begin(), seq_s->spans.end());
+ } else {
+ tmp_spans.push_back(s);
+ }
+ }
+ n->spans = std::move(tmp_spans);
+
+ n->line = 0;
+ n->end_line = 0;
+ n->column = 0;
+ n->end_column = 0;
+
+ data_ = std::move(n);
+}
+
+SequentialSpan::SequentialSpan(std::initializer_list<Span> init) {
+ auto n = make_object<SequentialSpanNode>();
+ tvm::Array<Span> spans = tvm::Array<Span>(init);
+ tvm::Array<Span> tmp_spans;
+ for (const Span& s : spans) {
+ if (const SequentialSpanNode* seq_s = s.as<SequentialSpanNode>()) {
+ tmp_spans.insert(tmp_spans.end(), seq_s->spans.begin(), seq_s->spans.end());
+ } else {
+ tmp_spans.push_back(s);
+ }
+ }
+ n->spans = std::move(tmp_spans);
+
+ n->line = 0;
+ n->end_line = 0;
+ n->column = 0;
+ n->end_column = 0;
+
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(SequentialSpanNode);
+
TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int line, int end_line,
int column, int end_column) {
return Span(source_name, line, end_line, column, end_column);
});
+TVM_REGISTER_GLOBAL("ir.SequentialSpan").set_body_typed([](tvm::Array<Span> spans) {
+ return SequentialSpan(spans);
+});
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
@@ -100,6 +147,19 @@
<< ", " << node->column << ", " << node->end_column << ")";
});
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<SequentialSpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const SequentialSpanNode*>(ref.get());
+
+ p->stream << "SequentailSpan([ ";
+ int index = 0;
+ const int last = node->spans.size() - 1;
+ while (index < last) {
+ p->stream << node->spans[index++] << ", ";
+ }
+ p->stream << node->spans[last] << " ])";
+ });
+
TVM_REGISTER_NODE_TYPE(SourceNode);
/*! \brief Construct a source from a string. */
diff --git a/tests/cpp/si_builder_test.cc b/tests/cpp/si_builder_test.cc
new file mode 100644
index 0000000..f65deba
--- /dev/null
+++ b/tests/cpp/si_builder_test.cc
@@ -0,0 +1,399 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <tvm/ir/si_builder.h>
+#include <tvm/ir/source_map.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/te/operation.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt.h>
+
+tvm::Span _CreateSpan(std::string text) {
+ return tvm::Span(tvm::SourceName::Get(text), 0, 0, 0, 0);
+}
+
+class RelayCheckSpan : public tvm::relay::ExprVisitor {
+ public:
+ std::vector<tvm::Span> tmp_result_;
+ std::vector<tvm::Span> lhs_spans_;
+ std::vector<tvm::Span> rhs_spans_;
+
+ std::vector<tvm::Span> CollectSpan(tvm::relay::Expr expr) {
+ tmp_result_.clear();
+ VisitExpr(expr);
+ return tmp_result_;
+ }
+
+ void Check(tvm::relay::Expr lhs, tvm::relay::Expr rhs) {
+ tvm::relay::Function lhs_f =
+ tvm::relay::Function(tvm::relay::FreeVars(lhs), lhs, tvm::relay::Type(), {});
+ tvm::relay::Function rhs_f =
+ tvm::relay::Function(tvm::relay::FreeVars(rhs), rhs, tvm::relay::Type(), {});
+ EXPECT_TRUE(tvm::StructuralEqual()(lhs_f, rhs_f));
+ lhs_spans_ = CollectSpan(lhs);
+ rhs_spans_ = CollectSpan(rhs);
+
+ EXPECT_EQ(lhs_spans_.size(), rhs_spans_.size());
+ for (std::size_t i = 0; i != lhs_spans_.size(); i++) {
+ EXPECT_TRUE(tvm::StructuralEqual()(lhs_spans_[i], rhs_spans_[i]));
+ }
+ }
+
+ void VisitExpr(const tvm::relay::Expr& expr) {
+ if (expr->span.defined()) {
+ tmp_result_.push_back(expr->span);
+ }
+ using TParent = ExprFunctor<void(const tvm::relay::Expr&)>;
+ TParent::VisitExpr(expr);
+ visit_counter_.emplace(expr.get(), 1);
+ }
+};
+
+TEST(SIBuilder, SequentialSpan) {
+ using namespace tvm;
+ Array<Span> ingredients = {_CreateSpan("first"), _CreateSpan("second"), _CreateSpan("third")};
+
+ SequentialSpan seq_span_1{ingredients[0], ingredients[1]};
+ EXPECT_EQ(seq_span_1->spans.size(), 2);
+ for (std::size_t i = 0; i != seq_span_1->spans.size(); i++) {
+ EXPECT_EQ(seq_span_1->spans[i], ingredients[i]);
+ }
+
+ // nested SequentialSpan test
+ SequentialSpan seq_span_2{seq_span_1, ingredients[2]};
+ EXPECT_EQ(seq_span_2->spans.size(), 3);
+ for (std::size_t i = 0; i != seq_span_2->spans.size(); i++) {
+ EXPECT_EQ(seq_span_2->spans[i], ingredients[i]);
+ }
+
+ // Array constructor test
+ Array<Span> tvm_array(ingredients);
+ SequentialSpan seq_span_3(tvm_array);
+ EXPECT_EQ(seq_span_3->spans.size(), 3);
+ for (std::size_t i = 0; i != seq_span_3->spans.size(); i++) {
+ EXPECT_EQ(seq_span_3->spans[i], ingredients[i]);
+ }
+}
+
+TEST(SIBuilder, CreateSapn) {
+ using namespace tvm;
+ auto pass_ctx = transform::PassContext::Create();
+ pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
+ tvm::With<transform::PassContext> ctx_scope(pass_ctx);
+ Span span_1 = _CreateSpan("first");
+ {
+ SIBuilder si_builder(span_1);
+ EXPECT_EQ(span_1, si_builder.Build());
+ }
+
+ Span span_2 = _CreateSpan("second");
+ Array<Span> ingredients = {span_1, span_2};
+ SequentialSpan seq_span_1{ingredients[0], ingredients[1]};
+ {
+ SIBuilder si_builder_1(seq_span_1);
+ SIBuilder si_builder_2({span_1, span_2});
+ SIBuilder si_builder_3{span_1, span_2};
+
+ Span created_span_1 = si_builder_1.Build();
+ Span created_span_2 = si_builder_2.Build();
+ Span created_span_3 = si_builder_3.Build();
+
+ auto created_seq_span_1 = created_span_1.as<SequentialSpanNode>();
+ auto created_seq_span_2 = created_span_2.as<SequentialSpanNode>();
+ auto created_seq_span_3 = created_span_3.as<SequentialSpanNode>();
+ EXPECT_EQ(created_seq_span_1->spans.size(), 2);
+ EXPECT_EQ(created_seq_span_2->spans.size(), 2);
+ EXPECT_EQ(created_seq_span_3->spans.size(), 2);
+ for (std::size_t i = 0; i != 2; i++) {
+ EXPECT_EQ(created_seq_span_1->spans[i], ingredients[i]);
+ EXPECT_EQ(created_seq_span_2->spans[i], ingredients[i]);
+ EXPECT_EQ(created_seq_span_3->spans[i], ingredients[i]);
+ }
+ }
+}
+
+TEST(SIBuilder, DisableSIBuilder) {
+ using namespace tvm;
+ auto pass_ctx = transform::PassContext::Create();
+ pass_ctx->config.Set("ir.enable_si_builder", Bool(false));
+ tvm::With<transform::PassContext> ctx_scope(pass_ctx);
+ Span span_1 = _CreateSpan("first");
+ {
+ SIBuilder si_builder(span_1);
+ EXPECT_NE(span_1, si_builder.Build());
+ }
+}
+
+TEST(SIBuilder, RelayRecursivelyFill) {
+ using namespace tvm;
+ auto pass_ctx = transform::PassContext::Create();
+ pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
+ tvm::With<transform::PassContext> ctx_scope(pass_ctx);
+ Span test_span = _CreateSpan("test_span");
+ Span a_node_span = _CreateSpan("a_node");
+
+ auto tensor_type = relay::TensorType({2, 3}, tvm::DataType::Float(32));
+ relay::Expr add_op = relay::Op::Get("add");
+ relay::Expr relu_op = relay::Op::Get("nn.relu");
+ relay::Expr leaky_relu_op = relay::Op::Get("nn.leaky_relu");
+ // Reset span of OpNode. Because a relay Op Node is a static reference, any change on it will
+ // be assigned the original object.
+ add_op->span = Span();
+ relu_op->span = Span();
+ leaky_relu_op->span = Span();
+
+ relay::Expr a = relay::Var("a", tensor_type, a_node_span);
+ relay::Expr x = relay::Call(relu_op, {a}, tvm::Attrs(), {});
+ relay::Expr y = relay::Call(leaky_relu_op, {x}, tvm::Attrs(), {});
+ relay::Expr z = relay::Call(add_op, {y, x}, tvm::Attrs(), {});
+
+ relay::Expr expected_a = relay::Var("a", tensor_type, a_node_span);
+ relay::Expr expected_x = relay::Call(relu_op, {expected_a}, tvm::Attrs(), {}, test_span);
+ relay::Expr expected_y = relay::Call(leaky_relu_op, {expected_x}, tvm::Attrs(), {}, test_span);
+ relay::Expr expected_z =
+ relay::Call(add_op, {expected_y, expected_x}, tvm::Attrs(), {}, test_span);
+
+ SIBuilder si_builder(test_span);
+ si_builder.RecursivelyFillSpan(z, {a});
+ RelayCheckSpan checker;
+ checker.Check(z, expected_z);
+}
+
+TEST(SIBuilder, RelayCollectSpans) {
+ using namespace tvm;
+ auto pass_ctx = transform::PassContext::Create();
+ pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
+ tvm::With<transform::PassContext> ctx_scope(pass_ctx);
+ Span a_node_span = _CreateSpan("a_node");
+ Span x_node_span = _CreateSpan("x_node");
+ Span y_node_span = _CreateSpan("y_node");
+ Span z_node_span = _CreateSpan("z_node");
+ std::vector<Span> target = {z_node_span, y_node_span, x_node_span, a_node_span};
+
+ auto tensor_type = relay::TensorType({2, 3}, tvm::DataType::Float(32));
+ relay::Expr add_op = relay::Op::Get("add");
+ relay::Expr relu_op = relay::Op::Get("nn.relu");
+ relay::Expr leaky_relu_op = relay::Op::Get("nn.leaky_relu");
+ // Reset span of OpNode. Because a relay Op Node is a static reference, any change on it will
+ // be assigned the original object.
+ add_op->span = Span();
+ relu_op->span = Span();
+ leaky_relu_op->span = Span();
+
+ relay::Expr a = relay::Var("a", tensor_type, a_node_span);
+ relay::Expr x = relay::Call(relu_op, {a}, tvm::Attrs(), {}, x_node_span);
+ relay::Expr y = relay::Call(leaky_relu_op, {x}, tvm::Attrs(), {}, y_node_span);
+ relay::Expr z = relay::Call(add_op, {y, x}, tvm::Attrs(), {}, z_node_span);
+
+ SIBuilder si_builder(z, {a});
+ Span created_span = si_builder.Build();
+ auto created_seq_span = created_span.as<SequentialSpanNode>();
+ EXPECT_EQ(created_seq_span->spans.size(), 4);
+ for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) {
+ EXPECT_TRUE(StructuralEqual()(created_seq_span->spans[i], target[i]));
+ }
+}
+
+TEST(SIBuilder, TirCollectSpansPrimExpr) {
+ using namespace tvm;
+ auto pass_ctx = transform::PassContext::Create();
+ pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
+ tvm::With<transform::PassContext> ctx_scope(pass_ctx);
+ Span a_node_span = _CreateSpan("a_node");
+ Span b_node_span = _CreateSpan("b_node");
+ Span x_node_span = _CreateSpan("x_node");
+ Span add_1_node_span = _CreateSpan("add_1_node");
+ Span add_2_node_span = _CreateSpan("add_2_node");
+ Span z_node_span = _CreateSpan("z_node");
+ std::vector<Span> target = {z_node_span, add_2_node_span, add_1_node_span, x_node_span,
+ a_node_span};
+ tir::Var a("a");
+ tir::Var b("b");
+ auto x = a + b;
+ auto add_1 = x + 1;
+ auto add_2 = add_1 + 2;
+ auto z = max(add_2, 100);
+ x->span = x_node_span;
+ a->span = a_node_span;
+ b->span = b_node_span;
+ add_1->span = add_1_node_span;
+ add_2->span = add_2_node_span;
+ z->span = z_node_span;
+
+ SIBuilder si_builder(z, {x});
+ Span created_span = si_builder.Build();
+ auto created_seq_span = created_span.as<SequentialSpanNode>();
+
+ EXPECT_EQ(created_seq_span->spans.size(), 4);
+ for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) {
+ EXPECT_TRUE(StructuralEqual()(created_seq_span->spans[i], target[i]));
+ }
+}
+
+TEST(SIBuilder, TirCollectSpansStmtWithPrimInput) {
+ using namespace tvm;
+ auto pass_ctx = transform::PassContext::Create();
+ pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
+ tvm::With<transform::PassContext> ctx_scope(pass_ctx);
+ Span a_node_span = _CreateSpan("a_node");
+ Span b_node_span = _CreateSpan("b_node");
+ Span x_node_span = _CreateSpan("x_node");
+ Span z_node_span = _CreateSpan("z_plus_1");
+ Span stmt_node_span = _CreateSpan("stmt_node");
+ std::vector<Span> target = {stmt_node_span, z_node_span, x_node_span};
+ tir::Var a("a");
+ tir::Var b("b");
+ auto x = a + b;
+ x->span = x_node_span;
+ auto fmaketest = [&]() {
+ auto z = x + 1;
+ z->span = z_node_span;
+ tir::Stmt ret = te::Evaluate(z);
+ return ret;
+ };
+ auto stmt = fmaketest();
+ stmt->span = stmt_node_span;
+ SIBuilder si_builder(stmt, {x});
+ Span created_span = si_builder.Build();
+ auto created_seq_span = created_span.as<SequentialSpanNode>();
+
+ EXPECT_EQ(created_seq_span->spans.size(), 3);
+ for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) {
+ EXPECT_TRUE(StructuralEqual()(created_seq_span->spans[i], target[i]));
+ }
+}
+
+TEST(SIBuilder, TirCollectSpansStmtWithStmtInput) {
+ using namespace tvm;
+ auto pass_ctx = transform::PassContext::Create();
+ pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
+ tvm::With<transform::PassContext> ctx_scope(pass_ctx);
+ Span zero_node_span = _CreateSpan("zero_node");
+ Span body_node_span = _CreateSpan("body_node");
+ Span init_node_span = _CreateSpan("init_node");
+ Span block_node_span = _CreateSpan("block_node");
+ std::vector<Span> target = {block_node_span, init_node_span, body_node_span};
+
+ tir::Stmt zero = tir::Evaluate(Integer(0), zero_node_span);
+ tir::Stmt body = tir::Evaluate(Integer(1), body_node_span);
+ tir::Stmt init = tir::IfThenElse(tir::const_true(), zero, zero, init_node_span);
+ tir::Block block({}, {}, {}, "block", body, init, Array<tir::Buffer>(),
+ Array<tir::MatchBufferRegion>(), Map<String, ObjectRef>(), block_node_span);
+ SIBuilder si_builder(block, {init});
+ Span created_span = si_builder.Build();
+ auto created_seq_span = created_span.as<SequentialSpanNode>();
+
+ EXPECT_EQ(created_seq_span->spans.size(), 3);
+ for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) {
+ EXPECT_TRUE(StructuralEqual()(created_seq_span->spans[i], target[i]));
+ }
+}
+
+TEST(SIBuilder, TirRecursivelyFillPrimExpr) {
+ using namespace tvm;
+ auto pass_ctx = transform::PassContext::Create();
+ pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
+ tvm::With<transform::PassContext> ctx_scope(pass_ctx);
+ Span test_span = _CreateSpan("test_span");
+ tir::Var a("a");
+ tir::Var b("b");
+ auto x = a + b;
+ auto add_1 = x + 1;
+ auto add_2 = add_1 + 2;
+ auto z = max(add_2, 100);
+
+ SIBuilder si_builder(test_span);
+ si_builder.RecursivelyFillSpan(z, {a, b});
+ EXPECT_TRUE(!a->span.defined());
+ EXPECT_TRUE(!b->span.defined());
+ EXPECT_TRUE(StructuralEqual()(x->span, test_span));
+ EXPECT_TRUE(StructuralEqual()(add_1->span, test_span));
+ EXPECT_TRUE(StructuralEqual()(add_2->span, test_span));
+ EXPECT_TRUE(StructuralEqual()(z->span, test_span));
+
+ ObjectRef tmp = z;
+ PrimExpr zz = Downcast<PrimExpr>(tmp);
+ std::ostringstream os;
+ os << z;
+ EXPECT_TRUE(zz.same_as(z));
+ EXPECT_EQ(os.str(), "T.max(a + b + 1 + 2, 100)");
+}
+
+TEST(SIBuilder, TirRecursivelyFillStmtWithPrimInput) {
+ using namespace tvm;
+ auto pass_ctx = transform::PassContext::Create();
+ pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
+ tvm::With<transform::PassContext> ctx_scope(pass_ctx);
+ Span test_span = _CreateSpan("test_span");
+ tir::Var a("a");
+ tir::Var b("b");
+ auto x = a + b;
+ auto z = x + 1;
+ tir::Stmt stmt = te::Evaluate(z);
+ SIBuilder si_builder(test_span);
+ const std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual> inputs = {a, b};
+ si_builder.RecursivelyFillSpan(stmt, inputs);
+
+ EXPECT_TRUE(!a->span.defined());
+ EXPECT_TRUE(!b->span.defined());
+ EXPECT_TRUE(StructuralEqual()(x->span, test_span));
+ EXPECT_TRUE(StructuralEqual()(z->span, test_span));
+ EXPECT_TRUE(StructuralEqual()(stmt->span, test_span));
+
+ ObjectRef tmp = z;
+ PrimExpr zz = Downcast<PrimExpr>(tmp);
+ std::ostringstream os;
+ os << z;
+ EXPECT_TRUE(zz.same_as(z));
+ EXPECT_EQ(os.str(), "a + b + 1");
+}
+
+TEST(SIBuilder, TirRecursivelyFillStmtWithStmtInput) {
+ using namespace tvm;
+ auto pass_ctx = transform::PassContext::Create();
+ pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
+ tvm::With<transform::PassContext> ctx_scope(pass_ctx);
+ tir::Stmt zero = tir::Evaluate(Integer(0));
+ tir::Stmt init = tir::IfThenElse(tir::const_true(), zero, zero);
+ tir::Stmt body = tir::Evaluate(Integer(1));
+ tir::Block block(/*iter_vars=*/{}, /*reads=*/{},
+ /*writes=*/{}, /*name_hint=*/"block", /*body=*/body,
+ /*init=*/init);
+
+ Span test_span = _CreateSpan("test_span");
+ const std::unordered_set<tir::Stmt, ObjectPtrHash, ObjectPtrEqual> inputs = {init};
+ SIBuilder si_builder(test_span);
+ si_builder.RecursivelyFillSpan(block, {init});
+ EXPECT_TRUE(!zero->span.defined());
+ EXPECT_TRUE(!init->span.defined());
+ EXPECT_TRUE(StructuralEqual()(body->span, test_span));
+ EXPECT_TRUE(StructuralEqual()(block->span, test_span));
+
+ tir::Stmt expected_zero = tir::Evaluate(Integer(0));
+ tir::Stmt expected_init = tir::IfThenElse(tir::const_true(), zero, zero);
+ tir::Stmt expected_body = tir::Evaluate(Integer(1));
+ tir::Block expected_block(/*iter_vars=*/{}, /*reads=*/{},
+ /*writes=*/{}, /*name_hint=*/"block", /*body=*/expected_body,
+ /*init=*/expected_init);
+ EXPECT_TRUE(tvm::StructuralEqual()(block, expected_block));
+}