- Feature Name: TUNIP: TVMScript Unified Printer
- Start Date: 05/25/2022
- RFC PR: [apache/tvm-rfcs#74](https://github.com/apache/tvm-rfcs/pull/74)
- GitHub Issue: [apache/tvm#11912](https://github.com/apache/tvm/issues/11912)
- Co-Authors: Lite Ye ([**@yelite**](https://github.com/yelite)), Greg Bonik
([**@gbonik**](https://github.com/gbonik)), Yong Wu
([**@yongwww**](https://github.com/yongwww)), Yuchen Jin
([**@YuchenJin**](https://github.com/YuchenJin))
# Summary
[summary]: #summary
This RFC proposes to modularize the existing TVMScript printer and turn it into
shared infrastructure, yielding a unified printing mechanism across the TVM
stack in which TIR, Relax and any future vendor-specific IR are all treated
equally as dialects and can be printed together without engineering conflicts.
# Motivation
[motivation]: #motivation
TVMScript, as a roundtrippable Python-based text format, is
the central piece of TVM performance productivity. As the frontend of TVM, it
enables end users to directly construct the TVM IR, either TIR or Relax, in a
pragmatic way. From Relax to MetaSchedule and TIR, TVMScript enables
inspectability and reproducibility at any level of compilation and
optimization. Furthermore, based on TVMScript, developers are empowered to
intercept, manipulate and customize the compiler behavior in a principled way.
While TVMScript is gaining traction and buy-in from the open source community,
the TVMScript printer suffers from several deep design issues:
- IR fragment printing is not supported, which forces users to switch between
TVMScript syntax and TIRText syntax.
- The lack of modularity makes it impractical to scale up to and
maintain multiple IRs without engineering conflicts.
- Improving the co-existence of multi-level IRs often requires re-engineering
existing features.
**Goal.** This RFC introduces Tvmscript UNIfied Printer (TUNIP), a systematic
redesign that addresses the engineering, usability and scalability issues above.
The goals of this redesign are:

**Goal 1 [Unified Representation].** Become the unified roundtrippable
representation of TIR and Relax, allowing systematic mixing of IRs or IR
fragments (Relax + TIR) in the same IRModule in the target language (for
example, Python, C++).
Currently the TVMScript printer is designed specifically for TIR, and printing
multiple dialects together was not a design goal at that time. Therefore,
supporting Relax requires ad-hoc hacks around the system (for
instance, [relax#149](https://github.com/tlc-pack/relax/pull/149) added support
for printing `T.cast` and `T.max` in an ad-hoc way, without reusing the printing
code for TIR). The unified printer in this RFC addresses this issue by having a
unified approach for printing an IR tree to TVMScript. Engineers will be able to
implement a fully-fledged printer for Relax, TIR and any potential IR in the
future with minimal effort.
The folder structure that we want to pursue is:
```bash
include/tvm/script/printer/
└── ... # Public headers for the core infra
src/script/printer/
├── core # Core infra, which is IR-agnostic
│ ├── ir_docsifier.cc
│ └── ...
├── tir # TIR dialect
│   ├── expr.cc
│   ├── stmt.cc
│   └── ...
└── relax # Hypothetical Relax dialect (not part of our RFC)
   └── ...
```
**Goal 2 [Third-Party IRs in Multi-Stage Compilation].** Modularize and
infrastructuralize the printer to support more future IRs or third-party IRs at
any level with maintainability, for example, IRs at a lower level than TIR, or
the Relax VM executable.

The current TVMScript printer is tightly coupled with TIR by being a subclass
of TIR-specific functors
([link](https://github.com/apache/tvm/blob/main/src/printer/tvmscript_printer.cc#L129)).
This design isn't scalable when we want to support more IRs. More importantly,
it's impossible for the current approach to support a third-party IR being
registered in a dynamic library.
**Goal 3 [Reproducibility and Error Reporting].** Expand reproducibility and
flexible rendering of diagnostic messages during any level of IR
transformation.
For example, the following snippet runs and produces an error.
```py
import tvm
from tvm.script import tir as T

@T.prim_func
def func_a(A: T.Buffer[(1,), "int32"]):
    A[0] = 0

@T.prim_func
def func_b(A: T.Buffer[(8,), "int32"]):
    A[0] = 0

# The two functions differ only in the buffer shape, (1,) vs (8,).
tvm.ir.assert_structural_equal(func_a, func_b)
```
The current error message indicates what the difference was, but not where it
occurred. This can sometimes be inferred from a stack trace, but becomes
increasingly difficult with larger IR graphs.
```
ValueError: StructuralEqual check failed, caused by lhs:
1
and rhs:
8
```
TUNIP should enable individual utilities and IR passes to have error messages
directing the user to exact locations in the IR representation.
```
ValueError: StructuralEqual check failed, first delta highlighted below
@T.prim_func
def func_a(A: T.Buffer[(1,), "int32"]) -> None:
                       ^^^^
    A[0] = 0
@T.prim_func
def func_b(A: T.Buffer[(8,), "int32"]) -> None:
                       ^^^^
    A[0] = 0
```
# Guide-level explanation
[guide-level-explanation]: #guide-level-explanation
This section introduces the design philosophy of the printer, and demonstrates
the proposed user-facing APIs, where "users" means IR developers.
## Two-Stage Translation
Traditionally in the TVM stack, printing is a single-stage process. Each printer
assumes a certain syntax of the target language, and as a result there are so far
three different printers, all for TIR: ReprPrinter, TIRTextPrinter, and
TVMScriptPrinter.
We extend the idea of the existing Doc class at
[src/printer/doc.h#L67](https://github.com/apache/tvm/blob/main/src/printer/doc.h#L67)
to allow better consistency and scalability. An IR, which could
be TIR, Relax or any other IR developed by third-party vendors, is first
translated to an intermediate Doc tree, and then the Doc tree is mapped to a target
language, for example, Python, the C++ IRBuilder API, or Rust.
**Stage 1 [TVM IR => Doc]**. In the first stage, the printer takes care
of translating a TVM IR to a Doc tree. As an example, `tir.For` is translated to
`ForDoc` without having to worry about the underlying language. Note that some
complicated nodes in TVM IR, for example, `PrimFunc`, could be translated to
multiple Doc elements, including `FunctionDoc` and a few `StmtDoc`.

During the translation from IR to Doc tree, it is possible that some statement
influences the syntax of its children or vice versa, especially for syntactic
sugars and for declaring undefined variables in IR fragment printing. Therefore, a
generic data structure `Frame` is introduced to allow retrieval and
manipulation of the relevant context information.
**Stage 2 [Doc => target language]**. In the second stage, the Doc tree is
translated faithfully to the target language in text format. For example, when
the target language is Python, `ForDoc` is translated to Python's for loop
syntax:
```python
for ... in ...:
    ...
```
When the target language is the Python IRBuilder API, `ForDoc` is translated to:
```python
with T.For(...):
    ...
```
For generality, the Doc tree is designed to contain only the minimal set of
elements that exist in the languages used in developing TVM. A full spec of the
Doc can be found in the Doc Spec section of the reference-level explanation.
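
As a rough illustration of the two stages, the following self-contained Python sketch translates a toy loop node into a simplified `ForDoc` and then renders that one Doc with two different backends. `ToyLoop`, `to_doc`, `render_python` and `render_irbuilder` are hypothetical names, the Doc fields are plain strings rather than `ExprDoc` nodes, and the argument layout of `T.For(...)` is assumed purely for demonstration.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyLoop:
    """Hypothetical stand-in for an IR loop node (not the real tir.For)."""
    var: str
    iterable: str
    body: List[str]


@dataclass
class ForDoc:
    """Simplified ForDoc: fields are plain strings instead of ExprDoc/StmtDoc."""
    lhs: str
    rhs: str
    body: List[str]


def to_doc(loop: ToyLoop) -> ForDoc:
    # Stage 1: IR -> Doc. Nothing here depends on the target language.
    return ForDoc(lhs=loop.var, rhs=loop.iterable, body=list(loop.body))


def render_python(doc: ForDoc) -> str:
    # Stage 2, backend A: Python `for` loop syntax.
    lines = [f"for {doc.lhs} in {doc.rhs}:"]
    lines += ["    " + stmt for stmt in doc.body]
    return "\n".join(lines)


def render_irbuilder(doc: ForDoc) -> str:
    # Stage 2, backend B: `with T.For(...)` form (argument layout is assumed).
    lines = [f"with T.For({doc.lhs}, {doc.rhs}):"]
    lines += ["    " + stmt for stmt in doc.body]
    return "\n".join(lines)


doc = to_doc(ToyLoop(var="i", iterable="T.serial(8)", body=["A[i] = 0"]))
print(render_python(doc))
print(render_irbuilder(doc))
```

Only the two `render_*` functions know about the output syntax, so adding a new IR touches stage 1 alone and adding a new target language touches stage 2 alone.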
## Distributed Registration
A major engineering challenge for scaling TVMScript to multiple IRs is that the
existing printing logic has to be engineered, maintained and re-engineered in a
single file, which has caused significant confusion when developing multi-level
IRs for TVM Unity.
Inspired by the pass infrastructure, as well as the ReprPrinter in TVM, we
propose to develop infrastructure that enables distributed registration. It
allows printers for different levels of IR to be registered in separate
translation units, while keeping the ability to mix them
at various levels, for example, when Relax uses a TIR expression in its
function bodies, or TIR calls back into a Relax function.
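
The idea can be sketched in a few lines of Python. This is only an illustration of the registration pattern, not the proposed C++ API: `set_dispatch`, `as_doc`, the toy `Range` class and the `R.range(...)` output are all hypothetical, and the fallback to a default layer follows the behavior described in the IRDocsifier Spec.

```python
from typing import Callable, Dict, Tuple, Type

# (dispatch token, node type) -> print function. "" denotes the default layer.
_REGISTRY: Dict[Tuple[str, Type], Callable] = {}


def set_dispatch(node_type: Type, token: str = ""):
    """Decorator that registers a print function; usable from any module."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(token, node_type)] = fn
        return fn
    return decorator


def as_doc(node, token: str = "") -> str:
    """Prefer the printer registered under `token`, else the default layer."""
    fn = _REGISTRY.get((token, type(node))) or _REGISTRY.get(("", type(node)))
    if fn is None:
        raise TypeError(f"no printer registered for {type(node).__name__}")
    return fn(node)


class Range:
    """Toy IR node used only for this sketch."""
    def __init__(self, lo: int, extent: int):
        self.lo, self.extent = lo, extent


# Could live in one file ...
@set_dispatch(Range)
def _print_range_default(r: Range) -> str:
    return f"{r.lo}:{r.lo + r.extent}"


# ... while a dialect-specific override lives in another.
@set_dispatch(Range, token="relax")
def _print_range_relax(r: Range) -> str:
    return f"R.range({r.lo}, {r.lo + r.extent})"  # made-up syntax


print(as_doc(Range(1, 9)))           # default layer
print(as_doc(Range(1, 9), "tir"))    # no "tir" override, falls back
print(as_doc(Range(1, 9), "relax"))  # dialect-specific printer
```

Because registration is a side effect of loading a module, a new dialect can add printers without editing any shared file.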
## Diagnostics and Reproducibility
Existing error reporting mechanisms have not taken IR structure and
reproducibility into consideration. They usually report a single-line error
message without providing the necessary context of what the IR looks like during
compilation. For example, when comparing whether two TIRs are structurally
equivalent, the system may report:
```
ValueError: StructuralEqual check failed, caused by lhs:
{slow_memory_3_var: buffer(slow_memory_3_buffer_var, 0x501bf80), fast_memory_2_var: buffer(fast_memory_2_buffer_var, 0x501bd80), placeholder_3: buffer(placeholder_5, 0x50138a0), placeholder_2: buffer(placeholder_4, 0x5012b60), T_subtract: buffer(T_subtract_1, 0x5014390)}
and rhs:
{}
```
which lacks necessary information for users to understand where the mismatch
is.
As a recent effort, structural error reporting in TIR scheduling provides
relevant and reproducible context, as demonstrated below:
```python
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(a: tir.handle, b: tir.handle) -> None:
        A = tir.match_buffer(a, [128, 128, 128, 128], dtype="float32")
        B = tir.match_buffer(b, [128, 128, 128, 128], dtype="float32")
        # body
        # with tir.block("root")
        for i, j, k, l in tir.grid(128, 128, 128, 8):
            tir.Block#0
            with tir.block("B"):
            ^^^^^^^^^^^^^^^^^^^^
                vi, vj, vk = tir.axis.remap("SSS", [i, j, k])
                vl = tir.axis.spatial(128, l * 16)
                tir.reads([A[vi, vj, vk, vl]])
                tir.writes([B[vi, vj, vk, vl]])
                B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * tir.float32(2)
Error: ...
```
However, the underlying mechanism supports only S-TIR and error reporting on
`tir.ForNode` and `tir.BlockNode`, and is less extensible for generic cases.
To generalize this UX across the TVM stack, the following steps are additionally
executed during the first stage of translation:
- Each Doc node is optionally attached to a node in TVM IR.
- After the first stage is finished, collect all IR nodes that are attached to
a Doc into a map, whose key is the IR node and whose value is a list of Doc nodes.
- For each IR node that has a diagnostic message, trace back through its parents
until reaching an IR node in the map collected in the previous step. This
produces a map from Doc node to diagnostic message.
- In the second stage, the diagnostic message is printed as the Doc is being printed
into the target language.
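
The steps above can be sketched as follows. This is a minimal illustration under stated assumptions: `IRNode`, `collect_sources`, `attach_diagnostics` and `render` are hypothetical names, the explicit `parent` pointer is an assumption of the sketch (the RFC does not say how parents are found), and each Doc is reduced to a single line of text.

```python
from typing import Dict, List, Optional


class IRNode:
    """Toy IR node; the parent pointer is an assumption of this sketch."""
    def __init__(self, name: str, parent: Optional["IRNode"] = None):
        self.name, self.parent = name, parent


class Doc:
    """Toy Doc: one line of text, optionally attached to a source IR node."""
    def __init__(self, text: str, source: Optional[IRNode] = None):
        self.text, self.source = text, source


def collect_sources(docs: List[Doc]) -> Dict[IRNode, List[Doc]]:
    # Step 2: map every IR node that is attached to a Doc to its Doc nodes.
    table: Dict[IRNode, List[Doc]] = {}
    for doc in docs:
        if doc.source is not None:
            table.setdefault(doc.source, []).append(doc)
    return table


def attach_diagnostics(table: Dict[IRNode, List[Doc]],
                       diagnostics: Dict[IRNode, str]) -> Dict[Doc, str]:
    # Step 3: walk up from each diagnosed node to the nearest node in `table`.
    result: Dict[Doc, str] = {}
    for node, message in diagnostics.items():
        cur: Optional[IRNode] = node
        while cur is not None and cur not in table:
            cur = cur.parent
        if cur is not None:
            for doc in table[cur]:
                result[doc] = message
    return result


def render(docs: List[Doc], messages: Dict[Doc, str]) -> str:
    # Step 4: print each Doc and underline the ones carrying a diagnostic.
    lines = []
    for doc in docs:
        lines.append(doc.text)
        if doc in messages:
            lines.append("^" * len(doc.text) + " " + messages[doc])
    return "\n".join(lines)


func = IRNode("func")
loop = IRNode("loop", parent=func)
extent = IRNode("extent", parent=loop)  # has no Doc of its own
docs = [Doc("def main():", func), Doc("for i in range(8):", loop), Doc("pass")]
messages = attach_diagnostics(collect_sources(docs), {extent: "bad extent"})
print(render(docs, messages))
```

Here the diagnostic is reported on `extent`, which has no Doc of its own, so it is surfaced on the enclosing loop line instead.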
# Reference-level explanation
## Doc Spec
The Doc is designed to be a unified representation of TVMScript in
different languages. The overall structure is simplified from the Python ast, and
the meaning of each node is straightforward.
```py
Doc(Optional<ObjectRef> source) # Base class for doc
# Expression
ExprDoc() # Base class for expression
LiteralDoc(Union[IntImm, FloatImm, String, nullptr_t] value)
IdDoc(String name)
AttrAccessDoc(ExprDoc value, String attr)
IndexDoc(ExprDoc value, Array<Union<ExprDoc, SliceDoc>> indices)
CallDoc(ExprDoc callee, Array<ExprDoc> args, Array<String> kwargs_keys, Array<ExprDoc> kwargs_values)
OperationDoc(OperationKind kind, Array<ExprDoc> operands)
LambdaDoc(Array<IdDoc> args, ExprDoc body)
TupleDoc(Array<ExprDoc> elements)
ListDoc(Array<ExprDoc> elements)
DictDoc(Array<ExprDoc> keys, Array<ExprDoc> values)
# Statements
StmtDoc(Array<String> comments) # Base class
AssignDoc(ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation)
IfDoc(ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch)
WhileDoc(ExprDoc predicate, Array<StmtDoc> body)
ForDoc(ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body)
ScopeDoc(Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body)
ExprStmtDoc(ExprDoc expr)
# Special Docs
SliceDoc(Optional<ExprDoc> start, Optional<ExprDoc> stop)
FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators, ExprDoc return_type, Array<StmtDoc> body)
ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<AssignDoc> aliases, Array<FunctionDoc> functions)
```
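
To make the spec more concrete, here is a small Python sketch that models four of the expression nodes as dataclasses and renders the buffer-region expression `A[1:10, 2]`. The field names follow the spec above, but the dataclasses and the `to_python` renderer are illustrative assumptions rather than the proposed implementation.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class LiteralDoc:
    value: Union[int, float, str, None]


@dataclass
class IdDoc:
    name: str


@dataclass
class SliceDoc:
    start: Optional["ExprDoc"]
    stop: Optional["ExprDoc"]


@dataclass
class IndexDoc:
    value: "ExprDoc"
    indices: List[Union["ExprDoc", SliceDoc]]


ExprDoc = Union[LiteralDoc, IdDoc, IndexDoc]


def to_python(doc) -> str:
    """Render a Doc node as Python source text (hypothetical helper)."""
    if isinstance(doc, LiteralDoc):
        return repr(doc.value)
    if isinstance(doc, IdDoc):
        return doc.name
    if isinstance(doc, SliceDoc):
        start = "" if doc.start is None else to_python(doc.start)
        stop = "" if doc.stop is None else to_python(doc.stop)
        return f"{start}:{stop}"
    if isinstance(doc, IndexDoc):
        indices = ", ".join(to_python(i) for i in doc.indices)
        return f"{to_python(doc.value)}[{indices}]"
    raise TypeError(f"unsupported doc: {type(doc).__name__}")


region = IndexDoc(
    value=IdDoc("A"),
    indices=[SliceDoc(LiteralDoc(1), LiteralDoc(10)), LiteralDoc(2)],
)
print(to_python(region))
```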
## IRDocsifier Spec
IRDocsifier is responsible for transforming an IR node tree into a Doc tree. Its API
looks like:
```cpp
class IRDocsifierNode : public Object {
 public:
  // ir_prefix maintains a map from dispatch_token to ir prefix
  // so that the print function can construct an expression with
  // the current ir prefix, like `T.xxx` in TIR and `R.xxx` in Relax
  Map<String, String> ir_prefix;
  // TranslationTable maintains a map from IR node to Doc.
  // It is updated when a new variable gets into the scope,
  // like when printing PrimFunc or BlockRealize.
  // It is looked up when printing variable nodes like tir::Var and tir::Buffer
  TranslationTable translation_table;
  Array<Frame> frames;
  Array<String> dispatch_tokens;

  /*!
   * \brief Transform the input object into TDoc
   */
  template <class TDoc>
  TDoc AsDoc(const ObjectRef& obj);

  /*!
   * \brief Push a new dispatch token onto the stack
   * \details The top dispatch token decides which dispatch table to use
   *          when printing Object. This method returns a RAII guard which
   *          pops the token when going out of the scope.
   */
  WithCtx WithDispatchToken(const String& token);

  /*!
   * \brief Push a new frame onto the stack
   * \details Frame contains the contextual information that's needed during printing,
   *          for example, variables in the scope. This method returns a RAII guard which
   *          pops the frame and calls the cleanup method of frame when going out of the scope.
   */
  WithCtx WithFrame(const Frame& frame);

  /*!
   * \brief Get the top frame with type FrameType
   */
  template <typename FrameType>
  Optional<FrameType> GetFrame() const;
};
```
To register a print function with the `IRDocsifier`, one should use the
`TVM_STATIC_IR_FUNCTOR` macro and the `set_dispatch` method of the
`ObjectFunctor`.
- Registration of printing methods for IR nodes
```cpp
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<PrimType>("tir", [](PrimType ty, IRDocsifier p) -> Doc {
      using runtime::DLDataType2String;
      return TIR(p)->Attr(DLDataType2String(ty->dtype));
    });
// Explanation:
// 1. Here we register the print function of the PrimType node in TIR
// 2. The first arg to the `set_dispatch` function is the dispatch token.
//    It's optional and represents the name of the IR
// 3. The first argument to the print function is the node to be printed
// 4. The second argument is an instance of `IRDocsifier`, which can be used
//    to recursively translate the child nodes.
// 5. The print method returns a subclass of Doc

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<Range>([](Range e, IRDocsifier p) {
  return SliceDoc(p->AsExprDoc(e->min), p->AsExprDoc(e->min + e->extent));
});
// The first arg to `set_dispatch` can be omitted, in which case
// the print function is registered in the default layer.
// It is called by default and can be overridden by registering
// another print function under an IR name.

// This function is called instead of the previous one
// if the printer is printing relax.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<Range>("relax", [](Range e, IRDocsifier p) {
  ...
});
```
- Dispatch
```cpp
auto tir_dispatch_ctx = ir_docsifier->WithDispatchToken("tir");
Doc doc = ir_docsifier->AsDoc<Doc>(node);
// Here we set up the ir_docsifier to call print functions under
// the 'tir' dispatch token, and then call the AsDoc method to
// translate `node`, as an ObjectRef, into `Doc`, by using the
// print functions registered in the dispatch table.

template <class TDoc>
TDoc AsDoc(const ObjectRef& obj) const {
  return Downcast<TDoc>(AsDocImpl(obj));
}
```
## Frame Spec
Frame provides the contextual information during printing. Most commonly, a frame
contains the variables defined in the current scope (such as a TIR function, TIR
block, or TIR loop). A subclass of Frame can be created to store more specific
information. For instance, `tir::ForLoopFrame` should contain the information about
the TIR for loop in order to print iter var remapping when printing
BlockRealize.
```cpp
class FrameNode : public Object {
 public:
  Array<ObjectRef> objs;
  TranslationTableNode* translation_table;

  /*!
   * \brief Set the name of a variable IR node
   */
  virtual IdDoc DefByName(const ObjectRef& obj, const String& name);

  /*!
   * \brief Set the doc of a variable IR node
   * \details This is useful when the variable is implicitly defined in the TVMScript.
   *          For example, when defining a `tir::Buffer buf`, buf->data is also a tir::Var,
   *          which should be printed as `buf.data`, rather than an identifier
   *          in the TVMScript.
   */
  virtual ExprDoc DefByDoc(const ObjectRef& obj, const ExprDoc& doc);
};
```
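
The interplay between frames and the translation table can be modelled in Python, with a context manager playing the role of the RAII guard returned by `WithFrame`. The sketch below is an assumption-laden illustration: Docs are plain strings, `ToyDocsifier` is a hypothetical stand-in for `IRDocsifier`, and the choice that `cleanup` removes the frame's variables from the translation table is a guess at what the cleanup method would do.

```python
from contextlib import contextmanager
from typing import Dict, List, Optional


class Frame:
    """Toy frame: records the variables it defines so they can be undone."""
    def __init__(self, translation_table: Dict[object, str]):
        self.objs: List[object] = []
        self.translation_table = translation_table

    def def_by_name(self, obj: object, name: str) -> str:
        # The variable is printed as a plain identifier.
        self.objs.append(obj)
        self.translation_table[obj] = name
        return name

    def def_by_doc(self, obj: object, doc: str) -> str:
        # The variable is implicitly defined and printed as an expression.
        self.objs.append(obj)
        self.translation_table[obj] = doc
        return doc

    def cleanup(self) -> None:
        # Assumed behavior: variables leave scope together with the frame.
        for obj in self.objs:
            self.translation_table.pop(obj, None)
        self.objs.clear()


class ToyDocsifier:
    def __init__(self) -> None:
        self.translation_table: Dict[object, str] = {}
        self.frames: List[Frame] = []

    @contextmanager
    def with_frame(self, frame: Frame):
        self.frames.append(frame)
        try:
            yield frame
        finally:
            self.frames.pop()
            frame.cleanup()

    def get_frame(self, frame_type: type) -> Optional[Frame]:
        # Return the innermost frame of the requested type, if any.
        for frame in reversed(self.frames):
            if isinstance(frame, frame_type):
                return frame
        return None


p = ToyDocsifier()
buf, buf_data = object(), object()  # stand-ins for tir::Buffer and buf->data
with p.with_frame(Frame(p.translation_table)) as frame:
    frame.def_by_name(buf, "A")
    frame.def_by_doc(buf_data, "A.data")
    print(p.translation_table[buf_data])
print(len(p.translation_table))
```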
## Upgrade Plan
`IRModule.script()` is the current way to print TIR into TVMScript. It calls
the `script.AsTVMScript` function registered at
`src/printer/tvmscript_printer.cc`. We plan to split the upgrade process
into 5 steps.
1. Without breaking existing functionality, upstream the system
components piece by piece with small PRs under a tracking issue.
The new system resides mainly in `src/script`, so it does not affect
the functionality of the existing TVMScript printer.
2. Expose the unified printer as a global TVM function `script.printer.Script`, which is parallel
to the existing printer.
3. Add a boolean flag `use_legacy_printer` to the Python `IRModule.script`,
which defaults to `True`. `IRModule.script` calls `script.printer.Print` if `use_legacy_printer`
is explicitly turned off.
4. After stabilizing the new infra, change the default value of `use_legacy_printer` to `False`.
5. Finally, deprecate the `use_legacy_printer` flag and clean up legacy code.
# Drawbacks
[drawbacks]: #drawbacks
N/A
# Rationale and alternatives
[rationale-and-alternatives]: #rationale-and-alternatives
Compared to the existing way of printing TVMScript in a single stage, introducing
two-stage printing will certainly increase the amount of code that needs to be
written. However, we believe two-stage printing is the right choice because it
reduces the complexity of the printing logic of each IR dialect by removing
unnecessary details about the target language syntax and string operations.
Therefore, it's more scalable if we want to support printing multiple kinds of
IR (TIR, Relax, and any potential third-party IRs in the future).
For example, printing buffer region (like `A[1:10, 2]`) in the current printer looks like
```cpp
Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
  Doc doc;
  if (op->region.size() == 0) {
    doc << Print(op->buffer) << "[()]";
  } else {
    doc << Print(op->buffer) << "[";
    for (size_t i = 0; i < op->region.size(); ++i) {
      if (i != 0) doc << ", ";
      const auto& range = op->region[i];
      if (!is_one(range->extent)) {
        doc << Print(range->min) << " : " << Print(ana_.Simplify(range->min + range->extent));
      } else {
        doc << Print(range->min);
      }
    }
    doc << "]";
  }
  return doc;
}
```
while in the unified printer with two-stage printing
```cpp
ExprDoc PrintBufferRegion(tir::BufferRegion buffer_region, IRDocsifier p) {
  Array<Doc> indices;
  for (const Range& range : buffer_region->region) {
    if (tir::is_one(range->extent)) {
      indices.push_back(p->AsExprDoc(range->min));
    } else {
      indices.push_back(p->AsExprDoc(range));
    }
  }
  return p->AsExprDoc(buffer_region->buffer)->Index(indices);
}
```
The latter is much simpler because it is free from the noisy code that deals
with emitting valid Python index syntax.
Assume the printer needs to support `k` IRs, and it takes `m` time per IR to develop
the logic around IR semantics and `n` time to develop the logic around target
language syntax. It will take `k*(m+n)` time if we use single-stage printing
and `k*m + n` time if we adopt two-stage printing. We believe the cost of
extending the Doc class will be paid off as soon as `k` is larger than one,
based on our PoC on using two-stage printing for TIR.
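
Under this simple cost model, the saving of two-stage printing over single-stage printing works out to:

$$
k\,(m + n) - (k\,m + n) = (k - 1)\,n
$$

The two approaches cost the same for a single IR ($k = 1$). With two IRs such as TIR and Relax ($k = 2$), the target-language logic is written once instead of twice, saving $n$, and the saving grows linearly with each additional IR.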
Additionally, with two-stage printing we can change the output language from
Python to other languages easily. Although we will still focus on TVMScript in
Python in the foreseeable future, having such flexibility is a nice additional
benefit.
# Prior art
[prior-art]: #prior-art
RFC for TVMScript: https://discuss.tvm.apache.org/t/rfc-hybrid-script-support-for-tir/7516
# Unresolved questions
[unresolved-questions]: #unresolved-questions
N/A
# Future possibilities
[future-possibilities]: #future-possibilities
With the unified TVMScript printer, we have one of the building blocks towards
a more open architecture, where the community can author their own IR and plug
it into the TVM stack, interacting with other components and layers.
As a counterpart to this RFC, we will send out another RFC on the unified TVMScript parser,
to support parsing TVMScript into different kinds of IR.