| |
| # pylint: disable=invalid-name, unused-import |
| """A parser for Relay's text format.""" |
| from __future__ import absolute_import |
| |
| import sys |
| |
| from collections import deque |
| from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict |
| |
| import tvm |
| |
| from . import module |
| from .base import Span, SourceName |
| from . import expr |
| from . import ty |
| from . import op |
| |
| |
class ParseError(Exception):
    """Exception type for parse errors.

    The human-readable description is available both as ``self.message``
    and through the standard exception machinery (``str(err)``,
    ``err.args``).
    """

    def __init__(self, message):
        # type: (str) -> None
        # Forward the message to Exception so that str(err), err.args and
        # tracebacks actually show it instead of an empty string.
        super(ParseError, self).__init__(message)
        self.message = message
| |
# Major version of the running interpreter; selects which generated grammar
# package (py2 or py3) is imported and which pip/runtime name is suggested
# in the error message below.
PYTHON_VERSION = sys.version_info.major
try:
    if PYTHON_VERSION == 2:
        from .grammar.py2.RelayVisitor import RelayVisitor
        from .grammar.py2.RelayParser import RelayParser
        from .grammar.py2.RelayLexer import RelayLexer
    else:
        from .grammar.py3.RelayVisitor import RelayVisitor
        from .grammar.py3.RelayParser import RelayParser
        from .grammar.py3.RelayLexer import RelayLexer
except ImportError:
    raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")

try:
    from antlr4 import ParserRuleContext, InputStream, CommonTokenStream
    from antlr4.tree.Tree import TerminalNode
except ImportError:
    # The two sentences are separate literals; the trailing space keeps them
    # from running together in the rendered message.
    raise ParseError("Couldn't find ANTLR runtime. " +
                     "Try running `pip{} install antlr4-python{}-runtime`."
                     .format(PYTHON_VERSION, PYTHON_VERSION))
| |
# Maps the token type of a binary-operator token to the Relay operator
# constructor used to desugar it.
BINARY_OPS = {
    # arithmetic
    RelayParser.MUL: op.multiply,
    RelayParser.DIV: op.divide,
    RelayParser.ADD: op.add,
    RelayParser.SUB: op.subtract,
    # comparison
    RelayParser.LT: op.less,
    RelayParser.GT: op.greater,
    RelayParser.LE: op.less_equal,
    RelayParser.GE: op.greater_equal,
    RelayParser.EQ: op.equal,
    RelayParser.NE: op.not_equal,
}

# Name prefixes that identify a builtin scalar type identifier.
TYPE_PREFIXES = ["int", "uint", "float", "bool"]

# Typing aliases for the scoping structures used by the visitor:
# a Scope is a deque of (name, value) pairs, and Scopes is a deque of Scope.
T = TypeVar("T")
Scope = Deque[Tuple[str, T]]
Scopes = Deque[Scope[T]]
| |
def lookup(scopes, name):
    # type: (Scopes[T], str) -> Optional[T]
    """Return the value bound to `name` in `scopes`, or None if unbound.

    `scopes` is iterated in order, and each scope is iterated in order as a
    sequence of (key, value) pairs; the first pair whose key equals `name`
    wins.
    """
    matches = (
        bound
        for frame in scopes
        for key, bound in frame
        if key == name
    )
    return next(matches, None)
| |
def spanify(f):
    """A decorator which attaches span information
    to the value returned by calling `f`.

    Intended for use with the below AST visiting
    methods. The idea is that after we do the work
    of constructing the AST we attach Span information.

    The wrapped callable must be invoked with the visitor instance as its
    first positional argument (its `source_name` attribute is read) and the
    ANTLR parser context as its second. The value returned by `f` must
    provide a `set_span` method.
    """
    # Local import keeps this block self-contained.
    import functools

    # Preserve the name and docstring of the decorated visitor method.
    @functools.wraps(f)
    def _wrapper(*args, **kwargs):
        # Assumes 0th arg is self and gets source_name from object.
        sn = args[0].source_name
        # Assumes 1st arg is an ANTLR parser context.
        ctx = args[1]
        ast = f(*args, **kwargs)
        # NOTE(review): in the ANTLR4 Python runtime getSourceInterval()
        # appears to return (start, stop) token indices rather than a
        # line/column pair -- confirm whether ctx.start.line /
        # ctx.start.column is what Span actually expects here.
        line, col = ctx.getSourceInterval()
        sp = Span(sn, line, col)
        ast.set_span(sp)
        return ast
    return _wrapper
| |
| # TODO(@jmp): Use https://stackoverflow.com/q/13889941 |
| # to figure out how to get ANTLR4 to be more unhappy about syntax errors |
class ParseTreeToRelayIR(RelayVisitor):
    """Parse Relay text format into Relay IR.

    The visitor walks an ANTLR parse tree and builds Relay expressions,
    types and (for programs containing definitions) a module. It keeps
    the following state:

    * ``var_scopes``: stack of local variable scopes (innermost first).
    * ``global_var_scope``: flat scope of global variables.
    * ``type_param_scopes``: stack of type variable scopes (innermost first).
    * ``graph_expr``: expressions bound to graph variables, indexed by id.
    * ``module``: module that collects global definitions.
    """

    def __init__(self, source_name):
        # type: (Union[str, SourceName]) -> None
        self.source_name = source_name
        self.module = module.Module({})  # type: module.Module

        # Adding an empty scope allows naked lets without pain.
        self.var_scopes = deque([deque()])  # type: Scopes[expr.Var]
        self.global_var_scope = deque()  # type: Scope[expr.GlobalVar]
        self.type_param_scopes = deque([deque()])  # type: Scopes[ty.TypeVar]
        self.graph_expr = []  # type: List[expr.Expr]

        super(ParseTreeToRelayIR, self).__init__()

    def enter_var_scope(self):
        # type: () -> None
        """Enter a new Var scope so it can be popped off later."""

        self.var_scopes.appendleft(deque())

    def exit_var_scope(self):
        # type: () -> Scope[expr.Var]
        """Pop off the current Var scope and return it."""

        return self.var_scopes.popleft()

    def mk_var(self, name, type_):
        # type: (str, ty.Type) -> expr.Var
        """Create a new Var and add it to the Var scope."""

        var = expr.Var(name, type_)
        # Newest bindings go first so that lookup() sees them before
        # older bindings of the same name.
        self.var_scopes[0].appendleft((name, var))
        return var

    def mk_global_var(self, name):
        # type: (str) -> expr.GlobalVar
        """Create a new GlobalVar and add it to the GlobalVar scope."""

        var = expr.GlobalVar(name)
        self.global_var_scope.append((name, var))
        return var

    def enter_type_param_scope(self):
        # type: () -> None
        """Enter a new TypeVar scope so it can be popped off later."""

        self.type_param_scopes.appendleft(deque())

    def exit_type_param_scope(self):
        # type: () -> Scope[ty.TypeVar]
        """Pop off the current TypeVar scope and return it."""

        return self.type_param_scopes.popleft()

    def mk_typ(self, name, kind):
        # type: (str, ty.Kind) -> ty.TypeVar
        """Create a new TypeVar and add it to the TypeVar scope."""

        typ = ty.TypeVar(name, kind)
        self.type_param_scopes[0].appendleft((name, typ))
        return typ

    def visitTerminal(self, node):
        # type: (TerminalNode) -> Union[expr.Expr, int, float]
        """Visit lexer tokens that aren't ignored or visited by other functions.

        Variable tokens are resolved against the matching scope, literal
        tokens are converted to Python values.

        Raises
        ------
        ParseError
            If a local or graph variable cannot be resolved, a boolean
            literal is not `True`/`False`, or the token kind is unhandled.
        """

        node_type = node.getSymbol().type
        node_text = node.getText()
        # Variable tokens carry a one-character sigil; strip it.
        name = node_text[1:]

        # variables
        if node_type == RelayLexer.GLOBAL_VAR:
            # Unresolved globals yield None here (no error is raised).
            return lookup(deque([self.global_var_scope]), name)
        elif node_type == RelayLexer.LOCAL_VAR:
            # Remove the leading '%' and lookup the name.
            var = lookup(self.var_scopes, name)
            if var is None:
                raise ParseError("Couldn't resolve `{}`.".format(name))
            return var
        elif node_type == RelayLexer.GRAPH_VAR:
            try:
                return self.graph_expr[int(name)]
            except IndexError:
                raise ParseError("Couldn't resolve `{}`".format(name))

        # data types
        elif node_type == RelayLexer.NAT:
            return int(node_text)
        elif node_type == RelayLexer.FLOAT:
            return float(node_text)
        elif node_type == RelayLexer.BOOL_LIT:
            if node_text == "True":
                return True
            elif node_text == "False":
                return False
            else:
                raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text))

        else:
            raise ParseError("todo: {}".format(node_text))

    def visit_list(self, ctx_list):
        # type: (List[ParserRuleContext]) -> List[Any]
        """Visit a list of contexts."""

        return [self.visit(ctx) for ctx in ctx_list]

    def getType_(self, ctx):
        # type: (Optional[RelayParser.Type_Context]) -> Optional[ty.Type]
        """Return a (possibly None) Relay type."""

        if ctx is None:
            return None

        return self.visit(ctx)

    def visitProg(self, ctx):
        # type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module]
        """Visit a whole program.

        Returns the module when the program contains definitions,
        otherwise the single top-level expression.
        """
        if ctx.defn():
            self.visit_list(ctx.defn())
            return self.module

        return self.visit(ctx.expr())

    # Exprs

    def visitOpIdent(self, ctx):
        # type: (RelayParser.OpIdentContext) -> op.Op
        """Resolve an operator identifier through `op.get`."""
        return op.get(ctx.CNAME().getText())

    # pass through
    def visitParens(self, ctx):
        # type: (RelayParser.ParensContext) -> expr.Expr
        return self.visit(ctx.expr())

    # pass through
    def visitBody(self, ctx):
        # type: (RelayParser.BodyContext) -> expr.Expr
        return self.visit(ctx.expr())

    def visitScalarFloat(self, ctx):
        # type: (RelayParser.ScalarFloatContext) -> expr.Constant
        return expr.const(self.visit(ctx.FLOAT()))

    def visitScalarInt(self, ctx):
        # type: (RelayParser.ScalarIntContext) -> expr.Constant
        return expr.const(self.visit(ctx.NAT()))

    def visitScalarBool(self, ctx):
        # type: (RelayParser.ScalarBoolContext) -> expr.Constant
        return expr.const(self.visit(ctx.BOOL_LIT()))

    def visitNeg(self, ctx):
        # type: (RelayParser.NegContext) -> Union[expr.Constant, expr.Call]
        """Negate an expression, folding the sign into scalar constants."""
        val = self.visit(ctx.expr())
        if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0:
            # fold Neg in for scalars
            return expr.const(-val.data.asnumpy().item())

        return op.negative(val)

    def visitTuple(self, ctx):
        # type: (RelayParser.TupleContext) -> expr.Tuple
        tup = self.visit_list(ctx.expr())
        return expr.Tuple(tup)

    # Currently doesn't support mutable sequencing.
    def visitLet(self, ctx):
        # type: (RelayParser.LetContext) -> expr.Let
        """Desugar various sequence constructs to Relay Let nodes.

        Raises
        ------
        ParseError
            If the binding is marked mutable or binds a non-local id.
        """
        if ctx.MUT() is not None:
            raise ParseError("Mutation is currently unsupported.")

        if ctx.var() is None or ctx.var().ident() is None:
            # anonymous identity
            ident = "_"
            type_ = None
        else:
            local_var = ctx.var().ident().LOCAL_VAR()
            if local_var is None:
                raise ParseError("Only local ids may be used in `let`s.")
            ident = local_var.getText()[1:]
            type_ = self.getType_(ctx.var().type_())

        var = self.mk_var(ident, type_)

        # The bound value is visited in its own scope so that bindings it
        # introduces do not leak into the body.
        self.enter_var_scope()
        value = self.visit(ctx.expr(0))
        self.exit_var_scope()

        body = self.visit(ctx.expr(1))

        return expr.Let(var, value, body)

    def visitBinOp(self, ctx):
        # type: (RelayParser.BinOpContext) -> expr.Call
        """Desugar binary operators."""
        arg0, arg1 = self.visit_list(ctx.expr())
        relay_op = BINARY_OPS.get(ctx.op.type)

        if relay_op is None:
            raise ParseError("Unimplemented binary op.")

        return relay_op(arg0, arg1)

    @spanify
    def visitVar(self, ctx):
        # type: (RelayParser.VarContext) -> expr.Var
        """Visit a single variable."""
        ident = ctx.ident().LOCAL_VAR()

        if ident is None:
            raise ParseError("Only local ids may be used in vars.")

        type_ = self.getType_(ctx.type_())

        return self.mk_var(ident.getText()[1:], type_)

    def visitVarList(self, ctx):
        # type: (RelayParser.VarListContext) -> List[expr.Var]
        return self.visit_list(ctx.var())

    # TODO: support a larger class of values than just Relay exprs
    def visitAttr(self, ctx):
        # type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
        return (ctx.CNAME().getText(), self.visit(ctx.expr()))

    def visitAttrList(self, ctx):
        # type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
        return dict(self.visit_list(ctx.attr()))

    def visitArgList(self,
                     ctx  # type: RelayParser.ArgListContext
                    ):
        # type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
        """Return the (variables, attributes) pair of an argument list.

        Either element is None when the corresponding sub-list is absent.
        """
        var_list = self.visit(ctx.varList()) if ctx.varList() else None
        attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None

        return (var_list, attr_list)

    def mk_func(self, ctx):
        # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function
        """Construct a function from either a Func or Defn."""

        # Enter var scope early to put params in scope.
        self.enter_var_scope()
        # Capture type params in params.
        self.enter_type_param_scope()
        var_list, attr_list = self.visit(ctx.argList())
        if var_list is None:
            # An argument list holding only attributes has no variable
            # list; the function then simply has no parameters.
            var_list = []
        ret_type = self.getType_(ctx.type_())

        type_params = list(self.exit_type_param_scope())
        if type_params:
            # Scope entries are (name, TypeVar) pairs; keep the TypeVars.
            _, type_params = zip(*type_params)

        body = self.visit(ctx.body())
        self.exit_var_scope()

        attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None

        return expr.Function(var_list, body, ret_type, type_params, attrs)

    @spanify
    def visitFunc(self, ctx):
        # type: (RelayParser.FuncContext) -> expr.Function
        return self.mk_func(ctx)

    # TODO: how to set spans for definitions?
    # @spanify
    def visitDefn(self, ctx):
        # type: (RelayParser.DefnContext) -> None
        """Register a global function definition in the module."""
        ident = ctx.ident().GLOBAL_VAR()
        if ident is None:
            raise ParseError("Only global ids may be used in `def`s.")
        ident_name = ident.getText()[1:]
        ident = self.mk_global_var(ident_name)

        self.module[ident] = self.mk_func(ctx)

    @spanify
    def visitCall(self, ctx):
        # type: (RelayParser.CallContext) -> expr.Call
        """Construct a call: the first sub-expression is the callee."""
        visited_exprs = self.visit_list(ctx.expr())

        func = visited_exprs[0]
        args = visited_exprs[1:]

        return expr.Call(func, args, None, None)

    @spanify
    def visitIfElse(self, ctx):
        # type: (RelayParser.IfElseContext) -> expr.If
        """Construct a Relay If node. Creates a new scope for each branch."""
        cond = self.visit(ctx.expr())

        self.enter_var_scope()
        true_branch = self.visit(ctx.body(0))
        self.exit_var_scope()

        self.enter_var_scope()
        false_branch = self.visit(ctx.body(1))
        self.exit_var_scope()

        return expr.If(cond, true_branch, false_branch)

    @spanify
    def visitGraph(self, ctx):
        # type: (RelayParser.GraphContext) -> expr.Expr
        """Visit a graph variable assignment.

        Graph variables must be introduced in order: the id being bound
        has to equal the number of graph expressions recorded so far.
        Returns the expression that follows the binding.
        """
        if ctx.ident().GRAPH_VAR() is None:
            raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText()))
        graph_nid = int(ctx.ident().GRAPH_VAR().getText()[1:])

        self.enter_var_scope()
        value = self.visit(ctx.expr(0))
        self.exit_var_scope()

        if graph_nid != len(self.graph_expr):
            raise ParseError(
                "Expected new graph variable to be `%{}`, ".format(len(self.graph_expr)) + \
                "but got `%{}`".format(graph_nid))
        self.graph_expr.append(value)

        kont = self.visit(ctx.expr(1))
        return kont

    # Types

    # pylint: disable=unused-argument
    def visitIncompleteType(self, ctx):
        # type: (RelayParser.IncompleteTypeContext) -> None
        return None

    def visitIdentType(self, ctx):
        # type: (RelayParser.IdentTypeContext) -> ty.TensorType
        """Resolve a builtin scalar type name.

        Raises
        ------
        ParseError
            If the name starts with none of the known type prefixes.
        """
        ident_type = ctx.CNAME().getText()

        # str.startswith accepts a tuple of candidate prefixes.
        if ident_type.startswith(tuple(TYPE_PREFIXES)):
            return ty.scalar_type(ident_type)

        raise ParseError("Unknown builtin type: {}".format(ident_type))

    # def visitCallType(self, ctx):
    #     # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType]
    #     ident_type = ctx.identType().CNAME().getText()

    #     args = self.visit_list(ctx.type_())

    #     if not args:
    #         raise ParseError("Type-level functions must have arguments!")

    #     func_type = TYPE_FUNCS.get(ident_type)(args)

    #     if func_type is None:
    #         raise ParseError("Unknown type-level function: `{}`".format(ident_type))
    #     else:
    #         return func_type

    def visitParensShape(self, ctx):
        # type: (RelayParser.ParensShapeContext) -> int
        return self.visit(ctx.shape())

    def visitShapeSeq(self, ctx):
        # type: (RelayParser.ShapeSeqContext) -> List[int]
        return self.visit_list(ctx.shape())

    def visitTensorType(self, ctx):
        # type: (RelayParser.TensorTypeContext) -> ty.TensorType
        """Create a simple tensor type. No generics."""

        shape = self.visit(ctx.shapeSeq())
        dtype = self.visit(ctx.type_())

        if not isinstance(dtype, ty.TensorType):
            raise ParseError("Expected dtype to be a Relay base type.")

        # The element type was visited as a TensorType; only its dtype
        # is needed for the resulting tensor type.
        dtype = dtype.dtype

        return ty.TensorType(shape, dtype)

    def visitTupleType(self, ctx):
        # type: (RelayParser.TupleTypeContext) -> ty.TupleType
        return ty.TupleType(self.visit_list(ctx.type_()))

    def visitFuncType(self, ctx):
        # type: (RelayParser.FuncTypeContext) -> ty.FuncType
        """Build a function type; the last listed type is the return type."""
        types = self.visit_list(ctx.type_())

        arg_types = types[:-1]
        ret_type = types[-1]

        return ty.FuncType(arg_types, ret_type, [], None)
| |
def make_parser(data):
    # type: (str) -> RelayParser
    """Construct a RelayParser for the given program text.

    The text is wrapped in an input stream, lexed with RelayLexer and the
    resulting token stream is handed to RelayParser.
    """
    tokens = CommonTokenStream(RelayLexer(InputStream(data)))
    return RelayParser(tokens)
| |
# Number of anonymous sources parsed so far; used to give each of them a
# distinct default name.
__source_name_counter__ = 0

def fromtext(data, source_name=None):
    # type: (str, Optional[Union[str, SourceName]]) -> Union[expr.Expr, module.Module]
    """Parse a Relay program.

    Parameters
    ----------
    data : str
        The program text.
    source_name : optional
        Name recorded for the source. A `str` is wrapped in `SourceName`;
        any other non-None value is passed through unchanged. When omitted,
        a fresh name of the form ``source_file<N>`` is generated.

    Returns
    -------
    The result of visiting the parsed program with ParseTreeToRelayIR.
    """
    global __source_name_counter__

    if source_name is None:
        source_name = "source_file{0}".format(__source_name_counter__)
        # Advance the counter so the next anonymous source gets a new name.
        __source_name_counter__ += 1

    if isinstance(source_name, str):
        source_name = SourceName(source_name)

    tree = make_parser(data).prog()
    return ParseTreeToRelayIR(source_name).visit(tree)