blob: 48d91dfa804471e224055bae66051da685f54dc7 [file] [log] [blame]
"""Statement AST Node in TVM.
Users do not need to deal with AST nodes directly,
but they can be helpful for developers doing quick prototyping.
While not displayed in the documentation or the Python file,
each statement node has subfields that can be visited from the Python side.

.. code-block:: python

    x = tvm.var("n")
    a = tvm.var("array", tvm.handle)
    st = tvm.make.Store(a, x + 1, 1)
    assert isinstance(st, tvm.stmt.Store)
    assert(st.buffer_var == a)
"""
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from . import make as _make
class Stmt(NodeBase):
    """Common base class of the statement nodes defined in this module."""
@register_node
class LetStmt(Stmt):
    """LetStmt node.

    Parameters
    ----------
    var : Var
        The variable in the binding.

    value : Expr
        The value to be bound.

    body : Stmt
        The body statement.
    """

    def __init__(self, var, value, body):
        # Forward the fields to the ``_make.LetStmt`` constructor.
        self.__init_handle_by_constructor__(_make.LetStmt, var, value, body)
@register_node
class AssertStmt(Stmt):
    """AssertStmt node.

    Parameters
    ----------
    condition : Expr
        The assert condition.

    message : Expr
        The error message.

    body : Stmt
        The body statement.
    """

    def __init__(self, condition, message, body):
        # Forward the fields to the ``_make.AssertStmt`` constructor.
        self.__init_handle_by_constructor__(
            _make.AssertStmt,
            condition,
            message,
            body,
        )
@register_node
class ProducerConsumer(Stmt):
    """ProducerConsumer node.

    Parameters
    ----------
    func : Operation
        The Operation.

    is_producer : bool
        Whether the node is a producer.

    body : Stmt
        The body statement.
    """

    def __init__(self, func, is_producer, body):
        # Forward the fields to the ``_make.ProducerConsumer`` constructor.
        self.__init_handle_by_constructor__(
            _make.ProducerConsumer,
            func,
            is_producer,
            body,
        )
@register_node
class For(Stmt):
    """For node.

    Parameters
    ----------
    loop_var : Var
        The loop variable.

    min_val : Expr
        The beginning value.

    extent : Expr
        The length of the loop.

    for_type : int
        The for type.

    device_api : int
        The device api type.

    body : Stmt
        The body statement.
    """

    # Integer codes exposed on the class.
    # NOTE(review): presumably the accepted values of ``for_type`` -- confirm
    # against the backend definition.
    Serial = 0
    Parallel = 1
    Vectorized = 2
    Unrolled = 3

    def __init__(self, loop_var, min_val, extent, for_type, device_api, body):
        # Forward the fields, in declaration order, to ``_make.For``.
        self.__init_handle_by_constructor__(
            _make.For,
            loop_var,
            min_val,
            extent,
            for_type,
            device_api,
            body,
        )
@register_node
class Store(Stmt):
    """Store node.

    Parameters
    ----------
    buffer_var : Var
        The buffer variable.

    value : Expr
        The value we want to store.

    index : Expr
        The index in the store expression.

    predicate : Expr
        The store predicate.
    """

    def __init__(self, buffer_var, value, index, predicate):
        # Forward the fields to the ``_make.Store`` constructor.
        self.__init_handle_by_constructor__(
            _make.Store,
            buffer_var,
            value,
            index,
            predicate,
        )
@register_node
class Provide(Stmt):
    """Provide node.

    Parameters
    ----------
    func : Operation
        The operation to create the function.

    value_index : int
        The output value index.

    value : Expr
        The value to be stored.

    args : list of Expr
        The index arguments of the Provide.
    """

    def __init__(self, func, value_index, value, args):
        # Forward the fields to the ``_make.Provide`` constructor.
        self.__init_handle_by_constructor__(
            _make.Provide,
            func,
            value_index,
            value,
            args,
        )
@register_node
class Allocate(Stmt):
    """Allocate node.

    Parameters
    ----------
    buffer_var : Var
        The buffer variable.

    dtype : str
        The data type of the buffer.

    extents : list of Expr
        The extents of the allocation.

    condition : Expr
        The condition.

    body : Stmt
        The body statement.
    """

    def __init__(self, buffer_var, dtype, extents, condition, body):
        # Forward the fields to the ``_make.Allocate`` constructor.
        self.__init_handle_by_constructor__(
            _make.Allocate,
            buffer_var,
            dtype,
            extents,
            condition,
            body,
        )
@register_node
class AttrStmt(Stmt):
    """AttrStmt node.

    Parameters
    ----------
    node : Node
        The node to annotate with the attribute.

    attr_key : str
        Attribute type key.

    value : Expr
        The value of the attribute.

    body : Stmt
        The body statement.
    """

    def __init__(self, node, attr_key, value, body):
        # Forward the fields to the ``_make.AttrStmt`` constructor.
        self.__init_handle_by_constructor__(
            _make.AttrStmt,
            node,
            attr_key,
            value,
            body,
        )
@register_node
class Free(Stmt):
    """Free node.

    Parameters
    ----------
    buffer_var : Var
        The buffer variable.
    """

    def __init__(self, buffer_var):
        # Forward the single field to the ``_make.Free`` constructor.
        self.__init_handle_by_constructor__(_make.Free, buffer_var)
@register_node
class Realize(Stmt):
    """Realize node.

    Parameters
    ----------
    func : Operation
        The operation to create the function.

    value_index : int
        The output value index.

    dtype : str
        The data type of the operation.

    bounds : list of Range
        The bounds of the realize.

    condition : Expr
        The realize condition.

    body : Stmt
        The realize body.
    """

    def __init__(self, func, value_index, dtype, bounds, condition, body):
        # Forward the fields, in declaration order, to ``_make.Realize``.
        self.__init_handle_by_constructor__(
            _make.Realize,
            func,
            value_index,
            dtype,
            bounds,
            condition,
            body,
        )
@register_node
class Block(Stmt):
    """Block node.

    Parameters
    ----------
    first : Stmt
        The first statement.

    rest : Stmt
        The following statement.
    """

    def __init__(self, first, rest):
        # Forward the two statements to the ``_make.Block`` constructor.
        self.__init_handle_by_constructor__(_make.Block, first, rest)
@register_node
class IfThenElse(Stmt):
    """IfThenElse node.

    Parameters
    ----------
    condition : Expr
        The condition expression.

    then_case : Stmt
        The statement to execute if the condition is true.

    else_case : Stmt
        The statement to execute if the condition is false.
    """

    def __init__(self, condition, then_case, else_case):
        # Forward the fields to the ``_make.IfThenElse`` constructor.
        self.__init_handle_by_constructor__(
            _make.IfThenElse,
            condition,
            then_case,
            else_case,
        )
@register_node
class Evaluate(Stmt):
    """Evaluate node.

    Parameters
    ----------
    value : Expr
        The expression to be evaluated.
    """

    def __init__(self, value):
        # Forward the expression to the ``_make.Evaluate`` constructor.
        self.__init_handle_by_constructor__(_make.Evaluate, value)
@register_node
class Prefetch(Stmt):
    """Prefetch node.

    Parameters
    ----------
    func : Operation
        The operation to create the function.

    value_index : int
        The output value index.

    dtype : str
        The data type to be prefetched.

    bounds : list of Range
        The bounds to be prefetched.
    """

    def __init__(self, func, value_index, dtype, bounds):
        # Forward the fields to the ``_make.Prefetch`` constructor.
        self.__init_handle_by_constructor__(
            _make.Prefetch,
            func,
            value_index,
            dtype,
            bounds,
        )
def stmt_seq(*args):
    """Make a sequence of statements.

    Arguments that are not already ``Stmt`` instances are wrapped in
    ``Evaluate``. The statements are then chained left to right with
    ``Block``.

    Parameters
    ----------
    args : list of Expr or Var or Stmt
        List of statements to be combined as sequence.

    Returns
    -------
    stmt : Stmt
        The combined statement. When no arguments are given, this is
        ``Evaluate(0)``.
    """
    ret = None
    for value in args:
        if not isinstance(value, Stmt):
            value = Evaluate(value)
        ret = value if ret is None else Block(ret, value)
    # Compare against None explicitly: ``ret`` is a node object, and relying
    # on its truthiness would discard a statement that evaluates as falsy.
    return Evaluate(0) if ret is None else ret
def stmt_list(stmt):
    """Make a list of statements from nested blocks.

    ``Block`` nodes are flattened recursively (``first`` before ``rest``),
    ``ProducerConsumer`` nodes are replaced by the unpacking of their body,
    and any other statement becomes a single-element list.

    Parameters
    ----------
    stmt : Stmt
        A block statement.

    Returns
    -------
    stmt_list : list of Stmt
        The unpacked list of statements.
    """
    if isinstance(stmt, Block):
        head = stmt_list(stmt.first)
        tail = stmt_list(stmt.rest)
        return head + tail
    if isinstance(stmt, ProducerConsumer):
        # Look through the wrapper and unpack what it contains.
        return stmt_list(stmt.body)
    return [stmt]
# Attach the helpers to the ``make`` module so they are also reachable as
# ``make.stmt_list`` and ``make.stmt_seq``.
_make.stmt_list = stmt_list
_make.stmt_seq = stmt_seq