- Feature Name: tvmscript-metaprogramming
- Start Date: 2022-06-16
- RFC PR: [apache/tvm-rfcs#79](https://github.com/apache/tvm-rfcs/pull/79)
- GitHub Issue: [apache/tvm#12442](https://github.com/apache/tvm/issues/12442)
- Co-Authors: Yaxing Cai ([**@cyx-6**](https://github.com/cyx-6), main implementation), Lite Ye
([**@yelite**](https://github.com/yelite)), Yong Wu
([**@yongwww**](https://github.com/yongwww)), Yuchen Jin
([**@YuchenJin**](https://github.com/YuchenJin)), Eric Lunderberg
([**@Lunderberg**](https://github.com/Lunderberg)), Masahiro Masuda
([**@masahi**](https://github.com/masahi)), Junru Shao
([**@junrushao1994**](https://github.com/junrushao1994), main designer)
# Summary
[summary]: #summary
This RFC proposes a new TVMScript parser infrastructure that supports extensive
metaprogramming and syntactic sugar. The new infrastructure is IR-agnostic,
treating TIR as just one of its dialects. Additionally, the new infrastructure will
integrate better with tooling in the Python ecosystem (pylint, mypy, etc.).
# Motivation
[motivation]: #motivation
**What is TVMScript.**
See [Blitz Course to TensorIR](https://tvm.apache.org/docs/tutorial/tensor_ir_blitz_course.html) and
[TVMScript Unified Printer RFC](https://github.com/apache/tvm-rfcs/pull/74/files#diff-6965a40ad8df7618ae68e11c88f924542a506c74a931cc3011ae9f99989b5f51R20-R26)
for an introduction to TVMScript.
**What is metaprogramming.** In the context of TVMScript, metaprogramming means
a programmable way to control IR generation. For example, in
https://github.com/apache/tvm/pull/11097, a metaprogramming feature was added
to the TVMScript parser that allows users to programmatically control the shapes
of the input buffers of a `PrimFunc`.
### Limitation of current design
The current parser lacks generic metaprogramming capabilities that would give
users more control over IR construction. This makes it challenging to support
operators like NMS (non-maximum suppression, which is crucial to object
detection models). There is an implementation of NMS at
[python/tvm/topi/cuda/nms.py#L367-L386](https://github.com/apache/tvm/blob/d0650bad66d0ff89a01347537021bc442a98c223/python/tvm/topi/cuda/nms.py#L367-L386).
Implementing NMS-like operators requires rank-polymorphism and the
ability to interleave the host program with TVMScript, both of which are
difficult to implement under the current design.
TVMScript also needs reasonable support for Python tooling. Currently it doesn’t
play nicely with pylint and mypy. For example,
[test_meta_schedule_postproc_rewrite_tensorize.py](https://github.com/apache/tvm/blob/d0650bad66d0ff89a01347537021bc442a98c223/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py)
has 100+ warnings from pylint within only 500 lines of code. This
confuses users and leaves the impression that TVMScript isn’t a
mature product and isn’t production-ready. Even though this can be
improved incrementally under the current design, we believe it’s easier to get
an ideal result with a design that has tooling support in mind from the start.
The current design also lacks a unified approach for different IRs. At
[https://github.com/tlc-pack/relax/tree/relax/python/tvm/script/relax](https://github.com/tlc-pack/relax/tree/relax/python/tvm/script/relax),
a mature implementation of the TVMScript parser is maintained for Relax, but it’s
hard to extend if we want to support more IRs for TVM unity.
To conclude, with this RFC, we want to:
1. Add more metaprogramming features to TVMScript, making it easier for TVM
   developers to write complicated operators.
2. Improve the tooling and documentation of TVMScript, reducing the friction for
   an average machine learning practitioner to use TVMScript.
3. Modularize the TVMScript parser and turn it into reusable infrastructure,
   lowering the cost of implementing a parser for a new IR.
# Guide-level explanation
[guide-level-explanation]: #guide-level-explanation
## Metaprogramming features to support
### (F1) Template Metaprogramming
Users should be able to use variables from an outer scope in a TVMScript
function/class. The parsed result should be identical to the function/class with
each such variable replaced by its value. For instance,
```python
@T.prim_func
def matmul(
    A: T.Buffer[(128, 128)],
) -> None:
    ...

def gen_matmul(n, m):
    @T.prim_func
    def f(A: T.Buffer[(n, m)]):
        ...
    return f

f = gen_matmul(n=128, m=128)  # `f` should be identical to `matmul`
```
This is already partially supported by https://github.com/apache/tvm/pull/11097,
which allows using a `PrimExpr` captured from an outer function. With the new
parser, we want to support this feature in more places and with more variable types.
### (F2) Rank-polymorphism
Users should be able to write a single function that handles input buffers of
different ranks (different numbers of dimensions). For example, users should be
able to write a generic function that performs a broadcast add,
```python
def broadcast_add(a, b, c):
    @T.prim_func
    def f(
        A: T.BufferFrom(a),
        B: T.BufferFrom(b),
        C: T.BufferFrom(c),
    ) -> None:
        for i, i_a, i_b in T.some_broadcast_method(A.shape, B.shape):
            with T.block():
                C[*i] = A[*i_a] + B[*i_b]
    return f

broadcast_add(
    a=Buffer((128, 1), "float32"),
    b=Buffer((1, 128), "float32"),
    c=Buffer((128, 128), "float32"),
)
```
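`T.some_broadcast_method` above is a placeholder in this RFC, not an existing API. As a rough sketch of the index arithmetic such a helper could perform, the plain-Python code below derives the output shape and maps an output index to each operand's index, assuming NumPy-style broadcasting rules. The function names `broadcast_shape` and `operand_index` are hypothetical, and strings stand in for loop variables.

```python
from itertools import zip_longest
from typing import List, Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Right-align two shapes and combine them with NumPy-style rules."""
    out = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast {tuple(a)} with {tuple(b)}")
        # A size-1 axis stretches to match the other operand.
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


def operand_index(out_index: Sequence[object], shape: Sequence[int]) -> List[object]:
    """Map an index into the broadcast output to an index into one operand."""
    # Drop leading output axes the operand lacks; pin size-1 axes to 0.
    trailing = out_index[len(out_index) - len(shape):]
    return [0 if extent == 1 else idx for idx, extent in zip(trailing, shape)]


a_shape, b_shape = (128, 1), (1, 128)
c_shape = broadcast_shape(a_shape, b_shape)
i = ["i0", "i1"]                    # stand-ins for the output loop variables
i_a = operand_index(i, a_shape)     # index used for A[*i_a]
i_b = operand_index(i, b_shape)     # index used for B[*i_b]
```

Because the number of loop variables is derived from the shapes at parse time, the same function body works for any rank.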
### (F3) Sugar: TE Compute in TIR
Users should be able to replace boilerplate code with a function call, which is
expanded into a larger chunk of code during parsing. For example, we may want to
use TE’s compute-like syntax to replace nested loops,
```python
@T.prim_func
def te_compute_sugar(
    A: T.Buffer[(128, 128)],
    B: T.Buffer[(128, 128)],
) -> None:
    ...
    C = T.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
    ...

## expands to ====>

@T.prim_func
def te_compute_expanded(
    A: T.Buffer[(128, 128)],
    B: T.Buffer[(128, 128)],
) -> None:
    ...
    for i in range(128):
        for j in range(128):
            with T.block("..."):
                C[i, j] = A[i, j] + B[i, j]
    ...
```
### (F4) Interleave host program and TVMScript program to customize metaprogramming
As an escape hatch from writing code to be parsed by the TVMScript
parser, users should be able to write imperative code that constructs IR nodes
directly and embed it inside regular TVMScript. This code is evaluated
by the Python interpreter at parse time. This gives users the ultimate tool when
TVMScript isn’t expressive enough for their use cases. For example, at
[python/tvm/topi/vision/nms.py#L380-L431](https://github.com/apache/tvm/blob/3cb4597ed48360e3f3d80161d1c03f833072d28e/python/tvm/topi/vision/nms.py#L380-L431),
there are blocks of repetitive code that compute the coordinates of the four
corners of a bounding box. This can be simplified as:
```python
# Before, without IRBuilder interleaving
@T.prim_func
def nms(...):
    ...
    for i in range(batch_size):
        ...
        a_l = min(
            output[batch_idx, box_a_idx, box_start_idx],
            output[batch_idx, box_a_idx, box_start_idx + 2],
        )
        a_t = min(
            output[batch_idx, box_a_idx, box_start_idx + 1],
            output[batch_idx, box_a_idx, box_start_idx + 3],
        )
        a_r = max(
            output[batch_idx, box_a_idx, box_start_idx],
            output[batch_idx, box_a_idx, box_start_idx + 2],
        )
        a_b = max(
            output[batch_idx, box_a_idx, box_start_idx + 1],
            output[batch_idx, box_a_idx, box_start_idx + 3],
        )
        ...
        for k in range(j):
            check_iou = ...
            ...
            if check_iou > 0:
                # b_l: left, b_t: top, b_r: right, b_b: bottom
                b_l = min(
                    output[batch_idx, box_b_idx, box_start_idx],
                    output[batch_idx, box_b_idx, box_start_idx + 2],
                )
                b_t = min(
                    output[batch_idx, box_b_idx, box_start_idx + 1],
                    output[batch_idx, box_b_idx, box_start_idx + 3],
                )
                b_r = max(
                    output[batch_idx, box_b_idx, box_start_idx],
                    output[batch_idx, box_b_idx, box_start_idx + 2],
                )
                b_b = max(
                    output[batch_idx, box_b_idx, box_start_idx + 1],
                    output[batch_idx, box_b_idx, box_start_idx + 3],
                )
                ...

# With IRBuilder interleaving:
from tvm.script import tir as T

def get_box_coordinates(output, batch_idx, box_idx, box_start_idx):
    """a method executed by python interpreter"""
    box_l = T.min(
        output[batch_idx, box_idx, box_start_idx],
        output[batch_idx, box_idx, box_start_idx + 2],
    )  # type(box_l) is PrimExpr
    ...  # Repeat for other coordinates
    return box_l, box_t, box_r, box_b

@T.prim_func(capture=[get_box_coordinates])
def nms(...):
    ...
    for i in range(batch_size):
        ...
        a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, box_a_idx, box_start_idx)
        ...
        for k in range(j):
            check_iou = ...
            ...
            if check_iou > 0:
                b_l, b_t, b_r, b_b = get_box_coordinates(output, batch_idx, box_b_idx, box_start_idx)
                ...
```
# Reference-level explanation
[reference-level-explanation]: #reference-level-explanation
## IRBuilder as Core
As the foundation of IR construction, we will provide a set of APIs called
IRBuilder that lets users construct IR imperatively. IRBuilder will be used by the
parser, as well as by users directly as described in feature F4. IRBuilder
allows users to write code in a style similar to TVMScript, while the code
is executed as a host program. For example,
```python
from tvm.script.builder import Builder, def_, def_many
from tvm.script import tir as T

with Builder() as b:
    with T.prim_func():
        T.func_name("main")
        buffer_a = T.Buffer((128, 128, 128), "float32")
        buffer_b = T.Buffer((128, 128, 128), "float32")
        arg_a = T.arg("A", buffer_a)
        arg_b = T.arg("B", buffer_b)
        with T.grid(128, 128, 128) as (i, j, k):
            def_many(["i", "j", "k"], [i, j, k])
            with T.block(name="block"):
                vi = def_("vi", T.axis.spatial(128, i))
                vj = def_("vj", T.axis.spatial(128, j))
                vk = def_("vk", T.axis.reduce(128, k))

f = b.get()  # f is a PrimFunc
```
produces a result equivalent to
```python
@T.prim_func
def main(
    A: T.Buffer(shape=(128, 128, 128), dtype="float32"),
    B: T.Buffer(shape=(128, 128, 128), dtype="float32"),
) -> None:
    for i, j, k in T.grid(128, 128, 128):
        with T.block("block"):
            vi = T.axis.S(128, i)
            vj = T.axis.S(128, j)
            vk = T.axis.R(128, k)
```
As shown in the example above, users don't need to pass the builder `b` to
subsequent calls to the IRBuilder API. The current builder state is kept in
thread-local storage, which improves the ergonomics of the IRBuilder API by
avoiding the need to pass the builder state explicitly.
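The implicit-builder pattern can be illustrated in pure Python. The sketch below is not the proposed implementation (which will be in C++); it is a minimal toy, assuming a per-thread stack of active builders, and the `Builder.current()` helper and the recorded `calls` list are made up for illustration.

```python
import threading
from typing import List


class Builder:
    """Toy stand-in for the proposed IRBuilder; it only records calls."""

    _state = threading.local()  # one stack of active builders per thread

    def __init__(self) -> None:
        self.calls: List[str] = []

    @classmethod
    def _stack(cls) -> List["Builder"]:
        if not hasattr(cls._state, "stack"):
            cls._state.stack = []
        return cls._state.stack

    @classmethod
    def current(cls) -> "Builder":
        stack = cls._stack()
        if not stack:
            raise RuntimeError("no active Builder; use `with Builder() as b:`")
        return stack[-1]

    def __enter__(self) -> "Builder":
        self._stack().append(self)
        return self

    def __exit__(self, *exc_info) -> None:
        self._stack().pop()


def func_name(name: str) -> None:
    """API function that locates its builder implicitly."""
    Builder.current().calls.append(f"func_name({name!r})")


with Builder() as b:
    func_name("main")  # no builder argument needed
```

Using a stack rather than a single slot lets builders nest, and keeping it thread-local means two threads constructing IR at the same time do not see each other's state.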
IRBuilder will be implemented in C++ so that it can be used in an
environment without Python. Python bindings will be created to expose IRBuilder
to the TVMScript parser.
With the separation between IRBuilder and the parser, most of the implementation
and documentation can be shared between TVMScript and the IR definition. For
example, most operators are simply imported into the IRBuilder package, like
```python
from tvm.tir import sin, cos
```
so the documentation and type signatures only need to be written once, and the
APIs are guaranteed to be consistent.
## Parse-time evaluation
The TVMScript parser can be considered a thin layer built on top of the IRBuilder
API. The parser transforms the input AST into a sequence of calls to the IRBuilder
API by evaluating small fragments of code as it visits the AST. The IRBuilder is
responsible for building the actual IR graph. All metaprogramming features
discussed above (F1 through F4) can be implemented through this parse-time
evaluation in a consistent manner. Using the same `gen_matmul` example from F1,
```python
def gen_matmul(n, m):
    @T.prim_func
    def f(A: T.Buffer[(n, m)]):
        ...
    return f
```
What the parser does here is:
1. Collect the environment inside `gen_matmul` into a dictionary.
   1. All primitive types (`int`, `str`, `float` and `None`) are captured
      automatically, while advanced types (like functions) need to be
      explicitly declared in the decorator to be captured (for example,
      `@T.prim_func(capture=[get_box_coordinates])`).
2. Call the corresponding function from IRBuilder as the parser visits the AST
   of function `f`.
   1. When visiting the function argument, call `eval` on its type annotation,
      with the environment captured in the first step.
   2. `T.Buffer[(n, m)]` gets evaluated to a value with type `tir.Buffer`.
   3. Call the IRBuilder API `T.arg("A", buffer)` to add an argument to the
      function that’s being constructed.
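The annotation-evaluation step can be sketched with the standard library alone. The code below is only an illustration under stated assumptions: it uses Python's `ast` module in place of the parser's own AST, a toy `Buffer` class in place of `tir.Buffer`, and a hypothetical helper named `eval_arg_annotations`.

```python
import ast
from typing import Any, Dict, List, Tuple


class Buffer:
    """Toy stand-in for `tir.Buffer`; `Buffer[shape]` builds an instance."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape

    def __class_getitem__(cls, shape: Tuple[int, ...]) -> "Buffer":
        return cls(tuple(shape))


SOURCE = """
def f(A: Buffer[(n, m)]):
    ...
"""


def eval_arg_annotations(source: str, env: Dict[str, Any]) -> List[Tuple[str, Any]]:
    """Evaluate each argument annotation of the first function in `source`."""
    func = ast.parse(source).body[0]
    assert isinstance(func, ast.FunctionDef)
    args = []
    for arg in func.args.args:
        # Compile only the annotation expression and evaluate it in `env`.
        code = compile(ast.Expression(body=arg.annotation), "<annotation>", "eval")
        args.append((arg.arg, eval(code, dict(env))))
    return args


env = {"Buffer": Buffer, "n": 128, "m": 64}  # the captured environment
(name, buffer), = eval_arg_annotations(SOURCE, env)
# A real parser would now call the IRBuilder, e.g. `T.arg(name, buffer)`.
```

The key point is that `n` and `m` are resolved by the Python interpreter itself, so any expression that is valid in an annotation works without extra parser logic.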
Another example,
```python
for *i in T.grid(*A.shape):
    ...
```
The parser will:
1. Evaluate the expression `T.grid(*A.shape)` using the steps described above.
   `T.grid` returns a value that is nearly equivalent to `List[Var]`.
2. Call `exec` on a specially constructed statement `*i = __tvm_rhs_var__`,
   with `locals` that maps `__tvm_rhs_var__` to the value evaluated in step 1.
3. Collect the value of `i` from the `locals` dictionary.
4. Call the IRBuilder API `def_many(["i"], [i])`.
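The `exec`-based binding step can be tried out in isolation. The sketch below is a guess at the mechanism rather than the parser's code: `bind_targets` is a hypothetical helper, strings stand in for `Var` objects, and the targets are written in tuple form (`first, *rest` or `*i,`) because plain Python only accepts a starred target inside a tuple or list.

```python
import ast
from typing import Any, Dict

RHS_NAME = "__tvm_rhs_var__"


def bind_targets(target_src: str, value: Any) -> Dict[str, Any]:
    """Bind `value` to the names in `target_src` using Python's own rules."""
    tree = ast.parse(f"{target_src} = {RHS_NAME}")
    stmt = tree.body[0] if len(tree.body) == 1 else None
    # Only accept the single form `<lhs> = __tvm_rhs_var__`.
    if not (
        isinstance(stmt, ast.Assign)
        and len(stmt.targets) == 1
        and isinstance(stmt.value, ast.Name)
        and stmt.value.id == RHS_NAME
    ):
        raise ValueError(f"unsupported assignment target: {target_src!r}")
    scope: Dict[str, Any] = {RHS_NAME: value}
    exec(compile(tree, "<tvmscript-assign>", "exec"), {}, scope)
    scope.pop(RHS_NAME)
    return scope  # maps each bound name to its value


loop_vars = ["v0", "v1", "v2"]  # stand-ins for the Vars returned by T.grid
bound = bind_targets("first, *rest", loop_vars)
packed = bind_targets("*i,", loop_vars)
```

Here `bound` maps `first` to `"v0"` and `rest` to `["v1", "v2"]`, after which a parser could pass the names and values on to `def_many`. Delegating to `exec` means nested and starred targets behave exactly as they do in regular Python.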
As mentioned above, all metaprogramming features (F1 through F4) can be
implemented through this parse-time evaluation. It's straightforward to see how
F1 and F2 are implemented by parse-time evaluation, but it might be harder to
grasp the idea behind F3 and F4.
For F3 (TE Compute in TIR),
```python
C = T.compute((128, 128), lambda i, j: A[i, j] + B[i, j])

# builds a graph similar to
for i in range(128):
    for j in range(128):
        with T.block("..."):
            C[i, j] = A[i, j] + B[i, j]
```
`T.compute` is provided by the IRBuilder API and constructs all IR nodes (For,
BlockRealize, Block and BufferStore). `T.compute` calls the lambda function to
get the rhs of the BufferStore, then returns a `Buffer` node that
represents `C`. The parser handles the assignment by setting the `name_hint` of
the returned buffer to `"C"` (by calling the IRBuilder API `def_("C", C)`)
and putting the buffer into the internal variable table (which is used to resolve
variables when evaluating subsequent statements and expressions).
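To make the expansion concrete, here is a toy version of a compute-style helper. Everything in it is an assumption made for illustration: the dataclasses are simplified stand-ins for TIR nodes, `Block` and `BlockRealize` are left out for brevity, and the `emit` callback replaces the implicit builder state.

```python
import inspect
from dataclasses import dataclass
from typing import Any, Callable, List, Sequence, Tuple


@dataclass
class Var:
    name: str


@dataclass
class Expr:
    op: str
    args: Tuple[Any, ...]

    def __add__(self, other: "Expr") -> "Expr":
        return Expr("add", (self, other))


@dataclass
class Buffer:
    shape: Tuple[int, ...]
    name_hint: str = ""

    def __getitem__(self, indices: Any) -> Expr:
        if not isinstance(indices, tuple):
            indices = (indices,)
        return Expr("load", (self.name_hint, indices))


@dataclass
class BufferStore:
    buffer: Buffer
    indices: List[Var]
    value: Expr


@dataclass
class For:
    var: Var
    extent: int
    body: Any


def compute(
    shape: Sequence[int],
    fcompute: Callable[..., Expr],
    emit: Callable[[Any], None],
) -> Buffer:
    """Build the loop nest for `fcompute` and return the output buffer."""
    out = Buffer(tuple(shape))
    # One loop variable per lambda parameter, named after that parameter.
    loop_vars = [Var(name) for name in inspect.signature(fcompute).parameters]
    stmt: Any = BufferStore(out, loop_vars, fcompute(*loop_vars))
    for var, extent in zip(reversed(loop_vars), reversed(tuple(shape))):
        stmt = For(var, extent, stmt)  # wrap from the innermost loop outward
    emit(stmt)
    return out


A = Buffer((128, 128), "A")
B = Buffer((128, 128), "B")
emitted: List[Any] = []
C = compute((128, 128), lambda i, j: A[i, j] + B[i, j], emitted.append)
C.name_hint = "C"  # what the parser's `def_("C", C)` step would do
```

The lambda is called exactly once, with symbolic loop variables rather than numbers, so its return value is an expression tree that becomes the stored value.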
For F4 (Interleave host program and TVMScript program),
```python
a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, box_a_idx, box_start_idx)
```
The call to the `get_box_coordinates` function is evaluated when the parser visits
the assign statement. The parser calls the IRBuilder API
`def_many(["a_l", "a_t", "a_r", "a_b"], <returned_tuple>)`
and puts the results into the internal variable table.
Note that we will not place extra restrictions on the signature of a user-provided
function (`get_box_coordinates` in this example). More precisely, the function
argument types can be anything, because the parser is able to capture outer
variables and thus bring variables of arbitrary type into scope. The restriction on
the return type depends on the language spec of the target IR. For example, in
TIR the user can write `A[i] = ...` to represent `BufferStore`, where the rhs is a
`PrimExpr`; the user can then provide a custom function
`compute_magic_number(index: PrimExpr) -> PrimExpr` and use it on the rhs as
`A[i] = compute_magic_number(i)`.
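The F4 flow (evaluate the call, then name and record the results) can be mimicked with a small toy. This is a sketch under assumptions, not the parser's code: `PrimExpr` here is a string-backed stand-in, `visit_assign` and the dictionary-based variable table are hypothetical, and the helper takes fewer arguments than the NMS example to stay short.

```python
import ast
from typing import Any, Callable, Dict, Sequence, Tuple


class PrimExpr:
    """Toy expression node that only remembers its textual form."""

    def __init__(self, text: str) -> None:
        self.text = text
        self.name_hint = ""


def get_box_coordinates(output: str, box_idx: str) -> Tuple[PrimExpr, ...]:
    """Host-side helper: runs in Python and returns IR nodes, not numbers."""

    def at(k: int) -> str:
        return f"{output}[{box_idx}, {k}]"

    return (
        PrimExpr(f"min({at(0)}, {at(2)})"),  # left
        PrimExpr(f"min({at(1)}, {at(3)})"),  # top
        PrimExpr(f"max({at(0)}, {at(2)})"),  # right
        PrimExpr(f"max({at(1)}, {at(3)})"),  # bottom
    )


def def_many(names: Sequence[str], values: Sequence[PrimExpr], table: Dict[str, Any]) -> None:
    """Name each value and record it for later statements."""
    for name, value in zip(names, values):
        value.name_hint = name
        table[name] = value


def visit_assign(src: str, table: Dict[str, Any], captured: Dict[str, Callable]) -> None:
    node = ast.parse(src).body[0]
    assert isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Tuple)
    # Evaluate the right-hand side with captured functions and known variables.
    rhs = compile(ast.Expression(body=node.value), "<rhs>", "eval")
    values = eval(rhs, {**captured, **table})
    names = [elt.id for elt in node.targets[0].elts]
    def_many(names, values, table)


var_table: Dict[str, Any] = {"output": "output", "box_a_idx": "box_a_idx"}
visit_assign(
    "a_l, a_t, a_r, a_b = get_box_coordinates(output, box_a_idx)",
    var_table,
    captured={"get_box_coordinates": get_box_coordinates},
)
```

After the call, `var_table["a_l"]` holds the node for the left coordinate, so later statements that mention `a_l` resolve to that node.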
By running `eval` and `exec` provided by the Python interpreter, we can implement
language features that are difficult to implement manually, and also ensure that
expressions in TVMScript have the same semantics as in regular Python code.
## Parser Registration
The logic for processing a particular Python syntax node is registered
through a decorator. For example,
```python
@dispatch.register(token="tir", type_name="For")
def visit_for(self: Parser, node: doc.For) -> None:
    ...
```
handles the transformation of a Python `For` node. The `token="tir"` in the
decorator means that the handler is for TIR. `self: Parser` provides all the
infrastructural APIs and maintains a stack of `token`s to determine which
function to dispatch to. This makes it possible to embed one IR in another (for
example, embedding TIR in Relax). The folder structure will look like
```
python/tvm/script/
└── parser
    ├── core
    │   └── ...  # Parser infrastructure
    ├── tir      # TIR dialect
    │   └── ...
    └── relax    # Hypothetical Relax dialect (not part of our RFC)
        └── ...
```
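The token-stack dispatch described in this section can be sketched as follows. This is an illustrative toy, not the proposed `dispatch` module: it uses the standard `ast` module instead of `doc`, a module-level dictionary as the registry, and a made-up rule that a `with` statement in the hypothetical Relax dialect switches to TIR for its body.

```python
import ast
from typing import Callable, Dict, List, Tuple

_HANDLERS: Dict[Tuple[str, str], Callable] = {}


def register(token: str, type_name: str) -> Callable:
    """Register a handler for one (dialect token, AST node type) pair."""

    def decorator(func: Callable) -> Callable:
        _HANDLERS[(token, type_name)] = func
        return func

    return decorator


class Parser:
    def __init__(self, token: str) -> None:
        self.tokens: List[str] = [token]  # stack of active dialects
        self.log: List[str] = []

    def visit(self, node: ast.AST) -> None:
        key = (self.tokens[-1], type(node).__name__)
        handler = _HANDLERS.get(key)
        if handler is None:
            raise NotImplementedError(f"no handler registered for {key}")
        handler(self, node)

    def visit_body(self, body: List[ast.stmt]) -> None:
        for stmt in body:
            self.visit(stmt)


@register(token="tir", type_name="For")
def visit_tir_for(self: Parser, node: ast.For) -> None:
    self.log.append("tir.For")
    self.visit_body(node.body)


@register(token="tir", type_name="Pass")
def visit_tir_pass(self: Parser, node: ast.Pass) -> None:
    self.log.append("tir.Pass")


@register(token="relax", type_name="With")
def visit_relax_with(self: Parser, node: ast.With) -> None:
    # Hypothetical embedding rule: the body of this `with` is parsed as TIR.
    self.log.append("relax.With")
    self.tokens.append("tir")
    try:
        self.visit_body(node.body)
    finally:
        self.tokens.pop()


parser = Parser("relax")
tree = ast.parse("with tir_region:\n    for i in xs:\n        pass")
parser.visit(tree.body[0])
```

Because the handler is looked up with the token on top of the stack, the same `For` node can be treated differently depending on which dialect is active when it is visited.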
## Usage of `eval` and `exec`
The parser uses `eval` and `exec` in the following places:
- It calls `eval` to evaluate fragments of expressions.
- It calls `exec` to evaluate different kinds of assignment statements, like `first, *rest, last = T.grid(...)`.
The use of `eval` and `exec` is necessary for our implementation. TVMScript
allows users to construct an IR graph in TVM declaratively as if they were writing
Python code, lowering the barrier to using TVM for low-level customization.
However, it is still restrictive and does not allow many Python
features to be used to build abstractions in users' code. All features proposed in
this RFC can be seen as an effort to narrow this gap, so they are designed to
follow Python semantics. On the implementation side, the most robust
approach is to leverage the Python interpreter itself to provide those
features, rather than writing our own restricted Python interpreter;
`eval` and `exec` are the most suitable tools for this. Other
mechanisms, like multiprocessing + IPC and subinterpreters, either lack an easy
path to exchange Python objects or require a dependency on the C API of CPython.
In our use cases, `eval` and `exec` do not create additional
security risks for users. All inputs to `eval` come directly from the user's
code, without modification. Our usage of `exec` only executes one
specific form of code,
```python
<lhs> = __tvm_rhs_var__
```
where `<lhs>` is the sequence of tokens from the left-hand side of the assign
statement, taken directly from the user's code. The situation is very different
from the cases that made `eval` and `exec` infamous, where they are used to
evaluate untrusted input from end users, typically in a web server. Furthermore,
we require users to explicitly capture variables. Therefore, the evaluation will
only involve objects from `tvm`, Python builtins and objects explicitly captured
by users. This rules out the possibility that an external function is called by
the parser without the user's acknowledgement. If a malicious actor wants to
exploit the system through the `eval` or `exec` in the TVMScript parser, they
must first have another RCE (remote code execution) vulnerability in the Python
runtime to modify the code at runtime, which makes such an exploit useless (one
needs to first have an RCE to exploit another RCE with the same exposure).
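The capture rule that underpins this argument can be sketched in a few lines. The helper below, `build_eval_env`, is hypothetical and simplified: it auto-captures only the primitive types listed earlier, adds explicitly captured objects under their `__name__`, and omits the `tvm` namespace that the real parser would also expose.

```python
import builtins
from typing import Any, Dict, Iterable

PRIMITIVES = (int, str, float, type(None))


def build_eval_env(outer_vars: Dict[str, Any], capture: Iterable[Any] = ()) -> Dict[str, Any]:
    """Assemble the namespace that parse-time `eval` is allowed to see."""
    env: Dict[str, Any] = {"__builtins__": builtins}
    for name, value in outer_vars.items():
        if isinstance(value, PRIMITIVES):  # captured automatically
            env[name] = value
    for obj in capture:                    # captured only on request
        env[obj.__name__] = obj
    return env


def helper(x: int) -> int:
    return x + 1


def other(x: int) -> int:
    return x * 2


outer_scope = {"n": 128, "helper": helper, "other": other}
env = build_eval_env(outer_scope, capture=[helper])

captured_result = eval("helper(n)", env)  # allowed: `helper` was captured
try:
    eval("other(n)", env)                 # `other` was never captured
    uncaptured_error = None
except NameError as err:
    uncaptured_error = err
```

An uncaptured function simply does not exist in the evaluation namespace, so calling it fails with a `NameError` instead of running silently.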
# Drawbacks
[drawbacks]: #drawbacks
N/A
# Rationale and alternatives
[rationale-and-alternatives]: #rationale-and-alternatives
N/A
# Prior art
[prior-art]: #prior-art
### Prior works in TVM
- Hybrid Script: [https://tvm.apache.org/docs/reference/langref/hybrid_script.html](https://tvm.apache.org/docs/reference/langref/hybrid_script.html)
- RFC for TVMScript: [https://discuss.tvm.apache.org/t/rfc-hybrid-script-support-for-tir/7516](https://discuss.tvm.apache.org/t/rfc-hybrid-script-support-for-tir/7516)
### Other libraries
#### Taichi
[https://www.taichi-lang.org](https://www.taichi-lang.org/)
Taichi has a metaprogramming model ([link to doc](https://docs.taichi.graphics/docs/master/meta)) very similar to the one presented in this RFC.
The biggest difference is that Taichi requires `ti.static` to be wrapped around everything that needs to be evaluated at compile time. It also
has advanced features like loop unrolling, compile-time branching and compile-time recursion.
The TVMScript parser does not need a special marker to denote compile-time
evaluation. Expressions are consistently evaluated at compile time (evaluated
to an IR node like PrimExpr, rather than a concrete value like a float matrix),
thanks to the separation of IRBuilder and parser. Features like loop
unrolling can be implemented in the IRBuilder layer per target IR. This keeps
the core parser as minimal as possible.
#### Triton
[http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)
Triton uses meta-parameters to generalize kernels (for example,
[tutorials/03-matrix-multiplication.html](https://triton-lang.org/master/getting-started/tutorials/03-matrix-multiplication.html)).
Meta-parameters are placed together with real parameters, but with the type
annotation `tl.constexpr` to differentiate them. This method slightly deviates from
regular Python semantics, as users will intuitively expect to pass them
together with real parameters when calling the kernel.
In TVMScript, one of the design principles is to narrow the gap between
TVMScript and regular Python code. TVMScript should not surprise users with
syntax or language features that deviate from Python. All features proposed in
this RFC are designed to strictly follow the semantics of Python and aim to
be intuitive to Python users.
# Unresolved questions
[unresolved-questions]: #unresolved-questions
N/A
# Future possibilities
[future-possibilities]: #future-possibilities
N/A