blob: e9b4286edad804ecb221217abf49091a1eaf42e4 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script Parser For TIR
We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
different python versions. Synr also provides an error handling context that we
use for error reporting.
"""
# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
import types
import json
import operator
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from synr import ast, Transformer, to_ast
import tvm
from tvm import IRModule
from tvm._ffi.base import TVMError
from tvm.ir import GlobalVar
from tvm.ir.function import BaseFunc
from tvm.tir import buffer
from tvm.tir.function import PrimFunc
from . import _ffi_api
from . import tir
from .context_maintainer import ContextMaintainer
from .meta_unparser import MetaUnparser
from .registry import Registry
from .diagnostics import TVMDiagnosticCtx
from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting
from .tir.intrin import Intrin
from .tir.node import Slice, BufferSlice
from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
from .tir.special_stmt import SpecialStmt
from .tir import ty
class CallArgumentReader(object):
"""Helper class to read required arguments from passed arguments.
When parsing a function call, we need to match the arguments provided in
the AST to the required arguments of the function. This class makes sure
all the positional arguments are filled and also fill keyword arguments
with thier default value if a different value was not provided.
"""
def __init__(self, func_name, args, kwargs, parser, node):
self.func_name = func_name
self.args = args
self.kwargs = kwargs
self.parser = parser
self.node = node
def get_pos_only_arg(self, pos, name):
"""Get corresponding position only function argument from argument list"""
if len(self.args) >= pos:
arg = self.args[pos - 1]
elif name not in self.kwargs:
# If no positional argument was found in the AST, we see if it was
# defined by name instead.
# TODO(tkonolige): this error message is not quite correct. The
# number of required arguments is >= pos
self.parser.report_error(
f"{self.func_name} requires {pos} arguments, but only {len(self.args)} were given.",
self.node.span,
)
else:
arg = self.kwargs[name]
return arg
def get_kwarg(self, pos, name, default):
"""Get corresponding keyword function argument from argument list.
If the user hasn't provided the argument, set it to the default value.
"""
if len(self.args) >= pos:
arg = self.args[pos - 1]
elif name in self.kwargs:
arg = self.kwargs[name]
else:
return default
return arg
def get_varargs(self, pos):
"""Get corresponding variable argument from argument list"""
if len(self.args) >= pos and len(self.kwargs) == 0:
return self.args[pos - 1 :]
return []
class TVMScriptParser(Transformer):
"""Synr AST visitor pass which finally lowers to TIR.
Notes for Extension
-------------------
1. To support a new type of AST node, add a function transform_xxx().
2. To support new functions, add the function to the appropriate registry:
We divide allowed function calls in TVM script into 3 categories,
intrin, scope_handler and special_stmt.
1. intrin functions are low level functions like mod, load, and
constants. They correspond to a tir `IRNode`. They must have a
return value. The user can register intrin functions for the parser to
use.
2. scope_handler functions have no return value. They take two
arguments: the parser and the AST node. scope_handler functions are
used in with and for statements.
3. special_stmt functions handle cases that do not have a corresponding
tir `IRNode`. These functions take the parser and the AST node as
arguments and may return a value.
When visiting a Call node, we check the special_stmt registry first. If
no registered function is found, we then check the intrin registry.
When visiting With node, we check the with_scope registry.
When visiting For node, we check the for_scope registry.
"""
_binop_maker = {
ast.BuiltinOp.Add: tvm.tir.Add,
ast.BuiltinOp.Sub: tvm.tir.Sub,
ast.BuiltinOp.Mul: tvm.tir.Mul,
ast.BuiltinOp.Div: tvm.tir.Div,
ast.BuiltinOp.FloorDiv: tvm.tir.FloorDiv,
ast.BuiltinOp.Mod: tvm.tir.FloorMod,
ast.BuiltinOp.BitOr: lambda lhs, rhs, span: operator.or_(lhs, rhs),
ast.BuiltinOp.BitAnd: lambda lhs, rhs, span: operator.and_(lhs, rhs),
ast.BuiltinOp.BitXor: lambda lhs, rhs, span: operator.xor(lhs, rhs),
ast.BuiltinOp.GT: tvm.tir.GT,
ast.BuiltinOp.GE: tvm.tir.GE,
ast.BuiltinOp.LT: tvm.tir.LT,
ast.BuiltinOp.LE: tvm.tir.LE,
ast.BuiltinOp.Eq: tvm.tir.EQ,
ast.BuiltinOp.NotEq: tvm.tir.NE,
ast.BuiltinOp.And: tvm.tir.And,
ast.BuiltinOp.Or: tvm.tir.Or,
}
_unaryop_maker = {
ast.BuiltinOp.USub: lambda rhs, span: operator.neg(rhs),
ast.BuiltinOp.Invert: lambda rhs, span: operator.invert(rhs),
ast.BuiltinOp.Not: tvm.tir.Not,
}
# pylint gets confused here with synr.Transformer which doesn't have a
# custom init, so just disable it
def __init__(
self, base_lineno, tir_namespace, closure_vars
): # pylint: disable=super-init-not-called
self.context = None
self.base_lineno = base_lineno
self.current_lineno = 0
self.current_col_offset = 0
self.tir_namespace = tir_namespace
self.closure_vars = closure_vars
self.meta = None
self._inside_buffer_sugar = False
def init_function_parsing_env(self):
"""Initialize function parsing environment"""
self.context = ContextMaintainer(self.report_error, self.closure_vars) # scope emitter
def init_meta(self, meta_dict):
if meta_dict is not None:
self.meta = tvm.ir.load_json(json.dumps(meta_dict))
def transform(self, node):
"""Generic transformation for visiting the AST. Dispatches to
`transform_ClassName` for the appropriate ClassName."""
old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
if hasattr(node, "lineno"):
self.current_lineno = self.base_lineno + node.lineno - 1
if hasattr(node, "col_offset"):
self.current_col_offset = node.col_offset
method = "transform_" + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
transform_res = visitor(node)
self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
return transform_res
def match_tir_namespace(self, identifier: str) -> bool:
"""Check if the namespace is equal to tvm.script.tir"""
return identifier in self.tir_namespace
def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]):
"""Report an error occuring at a location.
This just dispatches to synr's DiagnosticContext.
Parameters
----------
message : str
Error message
span : Union[synr.ast.Span, tvm.ir.Span]
Location of the error
"""
if isinstance(span, tvm.ir.Span):
span = synr_span_from_tvm(span)
self.error(message, span)
def parse_body(self, parent):
"""Parse remaining statements in this scope.
Parameters
----------
parent : synr.ast.Node
Parent node of this scope. Errors will be reported here.
"""
body = []
spans = []
stmt = parent
while len(self.context.node_stack[-1]) > 0:
stmt = self.context.node_stack[-1].pop()
spans.append(stmt.span)
res = self.transform(stmt)
if res is not None:
body.append(res)
if len(body) == 0:
self.report_error(
"Expected another statement at the end of this block. Perhaps you "
"used a concise statement and forgot to include a body afterwards.",
stmt.span,
)
else:
return (
tvm.tir.SeqStmt(body, tvm_span_from_synr(ast.Span.union(spans)))
if len(body) > 1
else body[0]
)
def parse_arg_list(self, func, node_call):
"""Match the arguments of a function call in the AST to the required
arguments of the function. This handles positional arguments,
positional arguments specified by name, keyword arguments, and varargs.
Parameters
----------
func : Function
The function that provides the signature
node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall]
The AST call node that calls into the function.
Returns
-------
arg_list : list
The parsed positional argument.
"""
assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall))
# collect arguments
args = [self.transform(arg) for arg in node_call.params]
if isinstance(node_call, ast.TypeApply):
kw_args = {} # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
else:
kw_args = {
self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
}
# get the name and parameter list of func
if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
func_name, param_list = func.signature()
else:
self.report_error(
"Internal Error: function must be of type Intrin, ScopeHandler or SpecialStmt, "
f"but it is {type(func).__name__}",
node_call.span,
)
# check arguments and parameter list and get a list of arguments
reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
pos_only, kwargs, varargs = param_list
internal_args = list()
for i, arg_name in enumerate(pos_only):
internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
for i, arg_info in enumerate(kwargs):
arg_name, default = arg_info
internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default))
if varargs is not None:
internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1))
elif len(args) + len(kw_args) > len(pos_only) + len(kwargs):
self.report_error(
"Arguments mismatched. "
+ f"Expected {len(pos_only) + len(kwargs)} args but got "
+ f"{len(args) + len(kw_args)}",
node_call.span,
)
return internal_args
def parse_type(self, type_node, parent):
"""Parse a type annotation.
We require the parent object to the type so that we have a place to
report the error message if the type does not exist.
"""
if type_node is None:
self.report_error("A type annotation is required", parent.span)
res_type = self.transform(type_node)
return tvm.ir.TupleType([]) if res_type is None else res_type.evaluate()
def generic_visit(self, node):
"""Fallback visitor if node type is not handled. Reports an error."""
self.report_error(type(node).__name__ + " AST node is not supported", node.span)
def transform_Module(self, node):
"""Module visitor
Right now, we only support two formats for TVM Script.
Example
-------
1. Generate a PrimFunc (If the code is printed, then it may also contain metadata)
.. code-block:: python
import tvm
@tvm.script
def A(...):
...
# returns a PrimFunc
func = A
2. Generate an IRModule
.. code-block:: python
import tvm
@tvm.script.ir_module
class MyMod():
@T.prim_func
def A(...):
...
@T.prim_func
def B(...):
...
__tvm_meta__ = ...
# returns an IRModule
mod = MyMod
"""
if len(node.funcs) == 1:
return self.transform(next(iter(node.funcs.values())))
elif len(node.funcs) == 0:
self.report_error(
"You must supply at least one class or function definition", node.span
)
else:
self.report_error(
"Only one-function, one-class or function-with-meta source code is allowed",
ast.Span.union([x.span for x in list(node.funcs.values())[1:]]),
)
def transform_Class(self, node):
"""Class definition visitor.
A class can have multiple function definitions and a single
:code:`__tvm_meta__` statement. Each class corresponds to a single
:code:`IRModule`.
Example
-------
.. code-block:: python
@tvm.script.ir_module
class MyClass:
__tvm_meta__ = {}
def A():
T.evaluate(0)
"""
if len(node.assignments) == 1:
if not (
len(node.assignments[0].lhs) == 1
and isinstance(node.assignments[0].lhs[0], ast.Var)
and node.assignments[0].lhs[0].id.name == "__tvm_meta__"
):
self.report_error(
"The only top level assignments allowed are `__tvm_meta__ = ...`",
node.assignments[0].span,
)
self.init_meta(
MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context)
)
elif len(node.assignments) > 1:
self.report_error(
"Only a single top level `__tvm_meta__` is allowed",
ast.Span.union([x.span for x in node.assignments[1:]]),
)
return IRModule(
{GlobalVar(name): self.transform(func) for name, func in node.funcs.items()}
)
def transform_Function(self, node):
"""Function definition visitor.
Each function definition is translated to a single :code:`PrimFunc`.
There are a couple restrictions on TVM Script functions:
1. Function arguments must have their types specified.
2. The body of the function can contain :code:`func_attr` to specify
attributes of the function (like it's name).
3. The body of the function can also contain multiple :code:`buffer_bind`s,
which give shape and dtype information to arguments.
4. Return statements are implicit.
Example
-------
.. code-block:: python
@T.prim_func
def my_function(x: T.handle): # 1. Argument types
T.func_attr({"global_symbol": "mmult"}) # 2. Function attributes
X_1 = tir.buffer_bind(x, [1024, 1024]) # 3. Buffer binding
T.evaluate(0) # 4. This function returns 0
"""
def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]):
if isinstance(decorator, ast.Call):
if len(decorator.params) != 1:
return False
func_name = decorator.func_name
else:
func_name = decorator
if isinstance(func_name, ast.Var):
return func_name.id.name == "as_torch"
def check_decorator(decorators: List[ast.Expr]) -> bool:
"""Check the decorator is `T.prim_func"""
if len(decorators) > 2 or len(decorators) == 0:
return False
if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]):
return False
d: ast.Expr = decorators[-1]
return (
isinstance(d, ast.Attr)
and isinstance(d.object, ast.Var)
and self.match_tir_namespace(d.object.id.name)
and d.field.name == "prim_func"
)
self.init_function_parsing_env()
self.context.enter_scope(nodes=node.body.stmts)
# add parameters of function
for arg in node.params:
# Note that this case is for T.match_buffer syntax sugar
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance(
self.transform(arg.ty.func_name), ty.GenericBufferType
):
result = self.handle_match_buffer_type(arg.ty, arg.name)
if not isinstance(result, buffer.Buffer):
self.report_error(
"The result type of evaluating TypeCall and TypeApply stmt"
f" is wrong: {type(result)}. It should be a Buffer",
node.span,
)
arg_name_with_handle = arg.name + "_handle"
arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle"))
self.context.func_buffer_map[arg_var] = result
self.context.update_symbol(arg.name, result, node)
else:
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
self.context.update_symbol(arg.name, arg_var, node)
self.context.func_params.append(arg_var)
if not check_decorator(node.decorators):
self.report_error(
"All functions should be decorated by `T.prim_func`",
node.span,
)
# fetch the body of root block
body = self.parse_body(node.body)
# return a tir.PrimFunc
dict_attr = self.context.func_dict_attr
ret_type = self.parse_type(node.ret_type, node) if node.ret_type is not None else None
func = tvm.tir.PrimFunc(
self.context.func_params,
body,
ret_type,
buffer_map=self.context.func_buffer_map,
preflattened_buffer_map=self.context.func_preflattened_buffer_map,
attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None,
span=tvm_span_from_synr(node.span),
)
# New Scope : Implicit root block
# Each function contains an implicit root block in TensorIR,
# so here we need a block scope for it.
# If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or low-level func),
# the root block will not be added. The logic to add root block is in `_ffi_api.Complete`
# Fix the PrimFunc
# 1. generate root block if necessary
# 2. generate surrounding loops for blocks if necessary
func = call_with_error_reporting(
self.report_error,
node.span,
_ffi_api.Complete,
func,
self.context.root_alloc_buffers,
)
self.context.exit_scope()
return func
def transform_Lambda(self, node):
"""Lambda visitor
Return an array of input parameters and the transformed lambda body.
"""
self.context.enter_scope(nodes=[node.body])
# add parameters of the lambda
arg_vars = []
for arg in node.params:
# Use "void" for dtype here. The actual type is not yet known and will be
# determined later. Using void type will allow IRSubstitute to do the
# replacement without flagging a type-mismatch error.
arg_var = tvm.te.var(arg.name, dtype="")
arg_vars.append(arg_var)
self.context.update_symbol(arg.name, arg_var, node)
# the body of a lambda must be an expr
if not isinstance(node.body, ast.Expr):
self.report_error("The body of a lambda must be an expression", node.span)
# transform the body of the lambda
body = self.transform(node.body)
self.context.exit_scope()
return arg_vars, body
def transform_Assign(self, node):
"""Assign visitor
AST abstract grammar:
Assign(expr* targets, expr value, string? type_comment)
By now 5 patterns of Assign is supported:
1. special stmts with return value
1.1 Buffer = T.match_buffer()/T.buffer_decl()
1.2 Var = T.var()
1.3 Var = T.env_thread()
2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
3. (Store) Var[PrimExpr] = PrimExpr
4. with scope handlers with concise scoping and var def
4.1 var = T.allocate()
5. A call to a pure python function, consuming and producing TVMScript values.
The outputs are inlined into the following body (no variable is created).
x, y = f(...)
"""
if isinstance(node.rhs, ast.Call):
# Pattern 1 & Pattern 4
if isinstance(node.rhs.func_name, ast.Op):
func = None
else:
func = self.transform(node.rhs.func_name)
if isinstance(func, WithScopeHandler):
if not func.concise_scope or not func.def_symbol:
self.report_error(
"with scope handler " + func.signature()[0] + " is not suitable here",
node.rhs.span,
)
# Pattern 4
arg_list = self.parse_arg_list(func, node.rhs)
func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
func.body = self.parse_body(node)
return func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
elif isinstance(func, SpecialStmt):
# Pattern 1
arg_list = self.parse_arg_list(func, node.rhs)
func.handle(node, self.context, arg_list, node.rhs.func_name.span)
return self.parse_body(node)
elif isinstance(func, types.FunctionType):
# Pattern 5
args = [self.transform(arg) for arg in node.rhs.params]
try:
out = func(*args)
except Exception as e:
self.report_error(
"Error occurred when invoking the function "
+ func.__name__
+ ": \n"
+ str(e),
node.rhs.span,
)
if len(node.lhs) == 1 and not isinstance(out, list):
out = [out]
assert len(out) == len(node.lhs)
for var, value in zip(node.lhs, out):
self.context.update_symbol(var.id.name, value, node)
body = self.parse_body(node)
for var, value in zip(node.lhs, out):
self.context.remove_symbol(var.id.name)
return body
if isinstance(node.rhs, (ast.Call, ast.Constant)):
# Pattern 4 of let binding
value = self.transform(node.rhs)
if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
# This is a little confusing because it only is true when
# we have taken this branch. We might need to clarify what
# exectly is allowed in Assignments in tvmscript.
self.report_error(
"Left hand side of assignment must be an unqualified variable",
node.span,
)
ast_var = node.lhs[0]
if node.ty is None and hasattr(value, "dtype"):
var_ty = value.dtype
else:
var_ty = self.parse_type(node.ty, ast_var)
var = tvm.te.var(
ast_var.id.name,
var_ty,
span=tvm_span_from_synr(ast_var.span),
)
self.context.update_symbol(var.name, var, node)
body = self.parse_body(node)
self.context.remove_symbol(var.name)
return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
self.report_error(
"""Assignments should be one of:
1. A "special statement" with return value
1.1 Buffer = T.match_buffer()/T.buffer_decl()
1.2 Var = T.var()
1.3 Var = T.env_thread()
2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
3. A store into a variable: Var[PrimExpr] = PrimExpr
4. A with scope handler with concise scoping and var def
4.1 var = T.allocate()
5. The right-hand side being a call to a pure python function, consuming and
producing TVMScript values.
x, y = f(...)""",
node.span,
)
def transform_SubscriptAssign(self, node):
"""Visitor for statements of the form :code:`x[1] = 2`."""
symbol = self.transform(node.params[0])
indexes = self.transform(node.params[1])
rhs = self.transform(node.params[2])
rhs_span = tvm_span_from_synr(node.params[2].span)
if isinstance(symbol, tvm.tir.Buffer):
if len(indexes) != len(symbol.shape):
self.report_error(
f"Buffer {symbol.name} is {len(symbol.shape)}-dimensional, "
f"cannot be indexed by {len(indexes)}-dimensional indices.",
node.params[1].span,
)
def __convert_index(x):
if isinstance(x, Slice):
return x.as_index_expr(self.report_error)
return x
# BufferStore
indexes = [__convert_index(x) for x in indexes]
return tvm.tir.BufferStore(
symbol,
tvm.runtime.convert(rhs, span=rhs_span),
indexes,
span=tvm_span_from_synr(node.span),
)
else:
if symbol.dtype == "handle" and len(indexes) != 1:
self.report_error(
"Handles only support one-dimensional indexing. Use `T.match_buffer` to "
"construct a multidimensional buffer from a handle.",
node.params[0].span,
)
if len(indexes) != 1:
self.report_error(
f"Store is only allowed with one index, but {len(indexes)} were provided.",
node.params[1].span,
)
self.report_error(
"Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span
)
def transform_AttrAssign(self, node):
"""Visitor for statements of the form :code:`x.y = 2`."""
obj = self.transform(node.params[0])
field = node.params[1]
value = self.transform(node.params[2])
if not hasattr(obj, field.name):
self.error(f"Field {field.name} does not exist", field.span)
var = getattr(obj, field.name)
if not isinstance(var, tvm.tir.Var):
self.error(
f"Can only assign to tir.Var attributes, not {type(var).__name__}", node.span
)
body = self.parse_body(node)
return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
def transform_Assert(self, node):
"""Assert visitor
Pattern corresponds to concise mode of :code:`with T.Assert()`.
"""
condition = self.transform(node.condition)
if node.msg is None:
self.report_error("Assert statements must have an error message.", node.span)
message = self.transform(node.msg)
body = self.parse_body(node)
return tvm.tir.AssertStmt(
condition, tvm.runtime.convert(message), body, span=tvm_span_from_synr(node.span)
)
def transform_For(self, node):
"""For visitor
AST abstract grammar:
For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
By now 1 pattern of For is supported:
1. for scope handler
for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/
T.grid()/T.thread_binding()
"""
if not isinstance(node.rhs, ast.Call):
self.report_error("The loop iterator should be a function call.", node.rhs.span)
func = self.transform(node.rhs.func_name)
if not isinstance(func, ForScopeHandler):
self.report_error(
"Only For scope handlers can be used in a for statement.", node.rhs.func_name.span
)
# prepare for new for scope
old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
self.current_lineno = node.span.start_line
self.current_col_offset = node.span.start_column
self.context.enter_scope(nodes=node.body.stmts)
# for scope handler process the scope
arg_list = [
tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span))
for arg in self.parse_arg_list(func, node.rhs)
]
func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
func.body = self.parse_body(node)
res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
# exit the scope
self.context.exit_scope()
self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
return res
def transform_While(self, node):
"""While visitor
AST abstract grammar:
While(expr condition, stmt* body)
"""
condition = self.transform(node.condition)
# body
self.context.enter_scope(nodes=node.body.stmts)
body = self.parse_body(node)
self.context.exit_scope()
return tvm.tir.While(condition, body, span=tvm_span_from_synr(node.span))
def transform_With(self, node):
"""With visitor
AST abstract grammar:
With(withitem* items, stmt* body, string? type_comment)
withitem = (expr context_expr, expr? optional_vars)
By now 2 patterns of With is supported:
1. with scope handler with symbol def
with T.allocate() as targets:
2. with scope handler without symbol def
with T.block(*axes)/T.let()/T.Assert()/T.attr()/T.realize()
"""
if not isinstance(node.rhs, ast.Call):
self.report_error(
"The context expression of a `with` statement should be a function call.",
node.rhs.span,
)
func = self.transform(node.rhs.func_name)
if not isinstance(func, WithScopeHandler):
self.report_error(
f"Function {func} cannot be used in a `with` statement.", node.rhs.func_name.span
)
# prepare for new block scope
old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
self.current_lineno = node.body.span.start_line
self.current_col_offset = node.body.span.start_column
self.context.enter_block_scope(nodes=node.body.stmts)
# with scope handler process the scope
arg_list = self.parse_arg_list(func, node.rhs)
func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
func.body = self.parse_body(node)
res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
# exit the scope
self.context.exit_block_scope()
self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
return res
def transform_If(self, node):
"""If visitor
AST abstract grammar:
If(expr test, stmt* body, stmt* orelse)
"""
condition = self.transform(node.condition)
# then body
self.context.enter_scope(nodes=node.true.stmts)
then_body = self.parse_body(node)
self.context.exit_scope()
# else body
if len(node.false.stmts) > 0:
self.context.enter_scope(nodes=node.false.stmts)
else_body = self.parse_body(node)
self.context.exit_scope()
else:
else_body = None
return tvm.tir.IfThenElse(
condition, then_body, else_body, span=tvm_span_from_synr(node.span)
)
def transform_Call(self, node):
"""Call visitor
3 different Call patterns are allowed:
1. Intrin representing a PrimExpr/IterVar
1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max
1.2 tir.range/reduce_axis/scan_axis/opaque_axis
2. tir.Op(dtype, ...)
3. other callable functions
"""
if isinstance(node.func_name, ast.Op):
if node.func_name.name == ast.BuiltinOp.Subscript:
return self.transform_Subscript(node)
if node.func_name.name in self._binop_maker:
lhs = self.transform(node.params[0])
# There is no supertype for everything that can appear in
# an expression, so we manually add what we might get here.
if not isinstance(lhs, (tvm.tir.PrimExpr, BufferSlice)):
# We would really like to report a more specific
# error here, but this parser contains no distinction
# between parsing statements and parsing expressions. All
# rules just call `transform`.
self.report_error(
f"Left hand side of binary op must be a PrimExpr, "
"but it is a {type(lhs).__name__}",
node.params[0].span,
)
rhs = self.transform(node.params[1])
if not isinstance(rhs, (tvm.tir.PrimExpr, BufferSlice)):
self.report_error(
f"Right hand side of binary op must be a PrimExpr, "
"but it is a {type(rhs).__name__}",
node.params[1].span,
)
return call_with_error_reporting(
self.report_error,
node.span,
lambda node, lhs, rhs, span: self._binop_maker[node.func_name.name](
lhs, rhs, span=span
),
node,
lhs,
rhs,
tvm_span_from_synr(node.span),
)
if node.func_name.name in self._unaryop_maker:
rhs = self.transform(node.params[0])
return self._unaryop_maker[node.func_name.name](
rhs, span=tvm_span_from_synr(node.span)
)
self.report_error(f"Unsupported operator {node.func_name.name}.", node.func_name.span)
else:
func = self.transform(node.func_name)
if isinstance(func, Intrin) and not func.stmt:
# pattern 1
arg_list = self.parse_arg_list(func, node)
return call_with_error_reporting(
self.report_error,
node.func_name.span,
func.handle,
arg_list,
node.func_name.span,
)
else:
args = [self.transform(arg) for arg in node.params]
kw_args = {
self.transform(k): self.transform(v) for k, v in node.keyword_params.items()
}
if isinstance(func, tvm.tir.op.Op):
if not "dtype" in kw_args.keys():
self.report_error(f"{func} requires a dtype keyword argument.", node.span)
# pattern 2
return tvm.tir.Call(
kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span)
)
elif callable(func):
# pattern 3
return func(*args, **kw_args)
else:
self.report_error(
f"Function is neither callable nor a tvm.tir.op.Op (it is a {type(func)}).",
node.func_name.span,
)
def transform_UnassignedCall(self, node):
"""Visitor for statements that are function calls.
This handles function calls that appear on thier own line like `tir.realize`.
Examples
--------
.. code-block:: python
@T.prim_func
def f():
A = T.buffer_decl([10, 10])
T.realize(A[1:2, 1:2], "") # This is an UnassignedCall
A[1, 1] = 2 # This is also an UnassignedCall
"""
# Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign.
if isinstance(node.call.func_name, ast.Op):
if node.call.func_name.name == ast.BuiltinOp.SubscriptAssign:
return self.transform_SubscriptAssign(node.call)
if node.call.func_name.name == ast.BuiltinOp.AttrAssign:
return self.transform_AttrAssign(node.call)
self.report_error(
"Binary and unary operators are not allowed as a statement", node.span
)
# handle a regular function call
func = self.transform(node.call.func_name)
arg_list = self.parse_arg_list(func, node.call)
if isinstance(func, tir.scope_handler.AssertHandler):
self.report_error(
"A standalone `T.Assert` is not allowed. Use `assert condition, message` "
"instead.",
node.call.func_name.span,
)
if isinstance(func, Intrin):
if func.stmt:
return call_with_error_reporting(
self.report_error,
node.call.func_name.span,
func.handle,
arg_list,
node.call.func_name.span,
)
else:
self.report_error(f"This intrinsic cannot be used as a statement.", node.call.span)
elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol:
func.enter_scope(node, self.context, arg_list, node.call.func_name.span)
func.body = self.parse_body(node)
return func.exit_scope(node, self.context, arg_list, node.call.func_name.span)
elif isinstance(func, SpecialStmt) and not func.def_symbol:
func.handle(node, self.context, arg_list, node.call.func_name.span)
return
self.report_error(
"Unexpected statement. Expected an assert, an intrinsic, a with statement, or a "
f"special statement, but got {type(func).__name__}.",
node.call.func_name.span,
)
def transform_Slice(self, node):
"""Index slice visitor."""
start = self.transform(node.start)
end = self.transform(node.end)
if not (
isinstance(node.step, ast.Constant)
and isinstance(node.step.value, int)
and node.step.value > 0
):
self.report_error(
"Only positive integer step size is supported for slices.", node.step.span
)
return Slice(start, end, node.step.value, tvm_span_from_synr(node.span))
def transform_Subscript(self, node):
"""Array access visitor.
By now only 3 types of Subscript are supported:
1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore)
Var[index] Buffer element access()
2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...]))
3. Array[index], Buffer element access
"""
symbol = self.transform(node.params[0])
if symbol is None:
self.report_error(
f"Variable {node.params[0].id.name} is not defined.", node.params[0].span
)
indexes = [self.transform(x) for x in node.params[1].values]
if isinstance(symbol, tvm.tir.expr.Var):
if symbol.dtype == "handle":
self.report_error(
"Cannot read directly from a handle, use `T.match_buffer` "
"to create a buffer to read from.",
node.params[0].span,
)
if len(indexes) > 1:
self.report_error(
"Only a single index can be provided when indexing into a `var`.",
node.params[1].span,
)
index = indexes[0]
if not isinstance(index, (tvm.tir.PrimExpr, int)):
self.report_error(
"Var load index should be an int or PrimExpr, but it is a" + type(index),
node.span,
)
self.report_error(
"Use of tir.Load has been deprecated in favor of tir.BufferLoad", node.span
)
elif isinstance(symbol, tvm.tir.Buffer):
return BufferSlice(
symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span)
)
elif isinstance(symbol, tvm.container.Array):
if len(indexes) > 1:
self.report_error(
"Array access should be one-dimension access, but the indices are "
+ str(indexes),
node.span,
)
index = indexes[0]
if not isinstance(index, (int, tvm.tir.expr.IntImm)):
self.report_error(
"Array access index expected int or IntImm, but got " + type(index),
node.span,
)
if int(index) >= len(symbol):
self.report_error(
f"Array access out of bound, size: {len(symbol)}, got index {index}.",
node.span,
)
return symbol[int(index)]
else:
self.report_error(
f"Cannot subscript from a {type(symbol).__name__}. Only variables and "
"buffers are supported.",
node.params[0].span,
)
def transform_Attr(self, node):
"""Visitor for field access of the form `x.y`.
This visitor is used to lookup function and symbol names. We have two
cases to handle here:
1. If we have a statement of the form `tir.something`, then we lookup
`tir.something` in the `Registry`. If the function is not in the
registry, then we try to find a `tvm.ir.op.Op` with the same name.
2. All other names `tvm.something` are lookup up in this current python
namespace.
"""
def get_full_attr_name(node: ast.Attr) -> str:
reverse_field_names = [node.field.name]
while isinstance(node.object, ast.Attr):
node = node.object
reverse_field_names.append(node.field.name)
if isinstance(node.object, ast.Var):
reverse_field_names.append(node.object.id.name)
return ".".join(reversed(reverse_field_names))
if isinstance(node.object, (ast.Var, ast.Attr)):
full_attr_name = get_full_attr_name(node)
attr_object, fields = full_attr_name.split(".", maxsplit=1)
if self.match_tir_namespace(attr_object):
func_name = "tir." + fields
res = Registry.lookup(func_name)
if res is not None:
return res
try:
return tvm.ir.op.Op.get(func_name)
except TVMError as e:
# Check if we got an attribute error
if e.args[0].find("AttributeError"):
self.report_error(f"Unregistered function `tir.{fields}`.", node.span)
else:
raise e
symbol = self.transform(node.object)
if symbol is None:
self.report_error("Unsupported Attribute expression.", node.object.span)
if not hasattr(symbol, node.field.name):
self.report_error(
f"Type {type(symbol)} does not have a field called `{node.field.name}`.", node.span
)
res = getattr(symbol, node.field.name)
return res
def transform_TypeAttr(self, node):
"""Visitor for field access of the form `x.y` for types.
We have two cases here:
1. If the type is of the form `T.something`, we look up the type in
the `tir` namespace in this module.
2. If the type is of the form `tvm.x.something` then we look up
`tvm.x.something` in this modules namespace.
"""
if isinstance(node.object, ast.TypeVar):
if self.match_tir_namespace(node.object.id.name):
if not hasattr(tir, node.field.name):
self.report_error(
f"Invalid type annotation `tir.{node.field.name}`.", node.span
)
return getattr(tir, node.field.name)
symbol = self.transform(node.object)
if symbol is None:
self.report_error("Unsupported Attribute expression", node.object.span)
if not hasattr(symbol, node.field):
self.report_error(
f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span
)
res = getattr(symbol, node.field)
return res
def transform_DictLiteral(self, node):
"""Dictionary literal visitor.
Handles dictionary literals of the form `{x:y, z:2}`.
"""
keys = [self.transform(key) for key in node.keys]
values = [self.transform(value) for value in node.values]
return dict(zip(keys, values))
def transform_Tuple(self, node):
"""Tuple visitor.
Handles tuples of the form `(x, y, 2)`.
"""
return tuple(self.transform(element) for element in node.values)
def transform_ArrayLiteral(self, node):
"""List literal visitor.
Handles lists of the form `[x, 2, 3]`.
"""
return [self.transform(element) for element in node.values]
def transform_Var(self, node):
"""Variable visitor
Handles variables like `x` in `x = 2`.
"""
name = node.id.name
if name == "meta":
return self.meta
symbol = Registry.lookup(name)
if symbol is not None:
return symbol
symbol = self.context.lookup_symbol(name)
if symbol is not None:
return symbol
self.report_error(f"Unknown identifier {name}.", node.span)
def transform_TypeVar(self, node):
"""Type variable visitor.
Equivalent to `transform_Var` but for types.
"""
name = node.id.name
symbol = Registry.lookup(name) or self.context.lookup_symbol(name)
if symbol is not None:
return symbol
self.report_error(f"Unknown identifier {name}.", node.span)
def transform_Constant(self, node):
"""Constant value visitor.
Constant values include `None`, `"strings"`, `2` (integers), `4.2`
(floats), and `true` (booleans).
"""
return tvm.runtime.convert(node.value, span=tvm_span_from_synr(node.span))
def transform_TypeConstant(self, node):
"""Constant value visitor for types.
See `transform_Constant`.
"""
if self._inside_buffer_sugar:
return self.transform_Constant(node)
return node.value
def transform_TypeTuple(self, node):
"""Tuple value visitor for types.
Mostly used in `transform_TypeCall` and `transform_TypeApply`.
"""
return [self.transform(value) for value in node.values]
def transform_TypeCall(self, node):
"""TypeCall visitor
This occurs when an expression is used inside a T.Buffer
parameter annotation.
"""
# ast.Call has the BuiltinOp as node.func_name.name, where
# ast.TypeCall has the BuiltinOp as node.func_name. So we can
# delegate to self.transform_Call, but the error messages for
# unsupported operations will highlight the entire expression
# and not just the function itself.
op = ast.Op(node.span, node.func_name)
call = ast.Call(node.span, op, node.params, node.keyword_params)
return self.transform_Call(call)
def transform_TypeApply(self, node):
"""Visitor for Type[Type] expressions.
Mostly used for ``T.Ptr`` expressions.
"""
func = self.transform(node.func_name)
if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"):
self.report_error(
f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), "
f"but found {type(func).__name__} instead.",
node.span,
)
param_types = []
for idx, param in enumerate(node.params):
param_type = self.transform(param)
if not isinstance(param_type, ty.TypeGeneric) and func.require_type_generic_at(idx):
self.report_error(
f"Expected a type but found {type(param).__name__} "
f"at {idx}th type argument",
param.span,
)
param_types.append(param_type)
if len(param_types) == 1:
return func[param_types[0]]
else:
return func[param_types]
def handle_match_buffer_type(self, node, buffer_name):
"""special function to handle syntax sugar for match buffer.
This method is for buffer declarations in the function parameters.
"""
func = self.transform(node.func_name)
assert isinstance(func, SpecialStmt)
# parse args and kwargs for TypeCall and TypeApply
self._inside_buffer_sugar = True
try:
arg_list = self.parse_arg_list(func, node)
finally:
self._inside_buffer_sugar = False
# Note that the third element in arg_list would always be the 'name'
# TODO: This index is hardcoded as a workaround. Better to make it programmatic
if arg_list[2] is None:
arg_list[2] = buffer_name
buf = func.handle(node, self.context, arg_list, node.func_name.span)
return buf
def transform_Return(self, node):
self.report_error(
"TVM script does not support return statements. Instead the last statement in any "
"block is implicitly returned.",
node.span,
)
def get_tir_namespace(script: Union[Callable, type]) -> List[str]:
assert inspect.isfunction(script) or inspect.isclass(script)
env: Dict[str, Any] = script.__globals__
return [key for key in env.keys() if env[key] == tir]
def from_source(
input_func: Union[str, Callable], tir_prefix: Optional[List[str]] = None
) -> Union[PrimFunc, IRModule]:
"""Parse function or string into PrimFunc or IRModule.
If possible, pass the TVM script in as a function so that line numbers and
filename will be accurate.
Parameters
----------
input_module : Union[str, Callable]
The python function to be parsed.
tir_prefix : Optional[List[str]]
The tir prefix list. Only works for str input, default by "tir" and "T".
Returns
-------
output : Union[Function, Module]
The Function or Module in IR.
"""
if isinstance(input_func, str):
tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix
return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {}))
elif inspect.isfunction(input_func):
_, start_line = inspect.getsourcelines(input_func)
env: Dict[str, Any] = input_func.__globals__
namespace = [key for key in env.keys() if env[key] is tir]
_closure_vars = inspect.getclosurevars(input_func)
closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals}
parser = TVMScriptParser(start_line, namespace, closure_vars)
result = to_ast(input_func, TVMDiagnosticCtx(), parser)
return result
else:
raise TypeError("Only function definitions are supported.")
def ir_module(input_module: type) -> IRModule:
"""Decorate a python class as tvm IRModule.
Parameters
----------
input_module : type
The python class to be parsed.
Returns
-------
output : IRModule
The result IRModule.
"""
if inspect.isclass(input_module):
func_dict = {
name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc)
}
return IRModule(func_dict)
raise TypeError("Only class definitions are supported.")