| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file parser.cc |
| * \brief A parser for TVM IR. |
| */ |
| #include <tvm/ir/module.h> |
| #include <tvm/node/reflection.h> |
| #include <tvm/parser/parser.h> |
| #include <tvm/relay/adt.h> |
| #include <tvm/relay/expr.h> |
| #include <tvm/relay/function.h> |
| #include <tvm/relay/transform.h> |
| #include <tvm/runtime/logging.h> |
| #include <tvm/runtime/object.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/target/virtual_device.h> |
| |
| #include <fstream> |
| |
| #include "../support/scalars.h" |
| #include "./meta_ref.h" |
| #include "./op_table.h" |
| #include "./span_check.h" |
| #include "./tokenizer.h" |
| #include "tvm/runtime/builtin_fp16.h" |
| |
| namespace tvm { |
| namespace parser { |
| |
| using namespace relay; |
| using Expr = relay::Expr; |
| |
| /*! \brief The meta table maps from type key to a sequence of objects. */ |
| using MetaTable = Map<String, Array<ObjectRef>>; |
| |
| using tvm::transform::CreateModulePass; |
| using tvm::transform::PassContext; |
| |
| /*! \brief A helper for passing around spans with data structures with |
| * no span field. |
| */ |
| template <typename T> |
| struct Spanned { |
| T data; |
| Span span; |
| |
| Spanned() = default; |
| Spanned(const Spanned<T>& other) = default; |
| Spanned(T data, Span span) : data(data), span(span) {} |
| }; |
| |
| /*! \brief A wrapper structure for capturing the result of parsing |
| * a global definition *before* we add it to the IRModule. |
| * |
| * This enables the parser to parse everything in one pass before |
| * constructing the IRModule. |
| */ |
| struct GlobalFunc { |
| GlobalVar global; |
| Function function; |
| GlobalFunc() : global(), function() {} |
| GlobalFunc(GlobalVar global, Function function) : global(global), function(function) {} |
| GlobalFunc(const GlobalFunc& gfunc) { |
| this->global = gfunc.global; |
| this->function = gfunc.function; |
| } |
| }; |
| |
| /*! \brief A wrapper structure for capturing all top-level definitions |
| * when parsing a module. |
| */ |
| struct Definitions { |
| /*! \brief The set of global functions. */ |
| std::vector<GlobalFunc> funcs; |
| /*! \brief The set of type definitions. */ |
| std::vector<TypeData> types; |
| // TODO(@jroesch): contain meta-table below |
| }; |
| |
| /*! \brief A structure representing the semantic versioning information |
| * for a Relay program. |
| */ |
| class SemVer { |
| public: |
| int major_version; |
| int minor_version; |
| int patch_version; |
| |
| SemVer() : major_version(0), minor_version(0), patch_version(0) {} |
| SemVer(int major_version, int minor_version, int patch_version) |
| : major_version(major_version), minor_version(minor_version), patch_version(patch_version) {} |
| SemVer(const SemVer& other) |
| : major_version(other.major_version), |
| minor_version(other.minor_version), |
| patch_version(other.patch_version) {} |
| }; |
| |
| /*! \brief A simple wrapper around a mapping from raw string names |
| * to a TVM variable, type variable or other binder type. |
| */ |
| template <typename T> |
| struct Scope { |
| /*! \brief The internal map. */ |
| std::unordered_map<std::string, T> name_map; |
| }; |
| |
| /*! \brief A stack of scopes. |
| * |
| * In order to properly handle scoping we must maintain a stack of scopes. |
| * |
| * A stack allows users to write programs which contain repeated variable |
| * names and to properly handle both nested scopes and removal of variables |
| * when they go out of scope. |
| * |
| * This is the classic approach to lexical scoping. |
| */ |
| template <typename T> |
| class ScopeStack { |
| private: |
| std::vector<Scope<T>> scope_stack; |
| std::unordered_map<std::string, T> free_vars; |
| |
| public: |
| /*! \brief Adds a variable binding to the current scope. */ |
| void Add(const std::string& name, const T& value) { |
| if (!this->scope_stack.size()) { |
| LOG(FATAL) << "internal issue"; |
| } |
| this->scope_stack.back().name_map.insert({name, value}); |
| } |
| |
| void AddFreeVar(const std::string& name, const T& value) { free_vars.insert({name, value}); } |
| |
| /*! \brief Looks up a variable name in the scope stack returning the matching variable |
| * in most recent scope. */ |
| T Lookup(const std::string& name) { |
| for (auto scope = this->scope_stack.rbegin(); scope != this->scope_stack.rend(); ++scope) { |
| auto it = scope->name_map.find(name); |
| if (it != scope->name_map.end()) { |
| return it->second; |
| } |
| } |
| |
| // Check if we bound a free variable declaration. |
| auto it = free_vars.find(name); |
| if (it != free_vars.end()) { |
| return it->second; |
| } |
| |
| return T(); |
| } |
| |
| /*! \brief Adds a fresh scope. */ |
| void PushStack() { this->scope_stack.push_back(Scope<T>()); } |
| |
| /*! \brief Removes the most recent scope. */ |
| void PopStack() { this->scope_stack.pop_back(); } |
| }; |
| |
| struct DuplicateKeyError : public Error { |
| explicit DuplicateKeyError(const std::string& msg) : Error(msg) {} |
| }; |
| |
| /*! \brief A table of interning strings as global function and type names. */ |
| template <typename T> |
| struct InternTable { |
| /*! \brief The internal table mapping strings to a unique allocation. */ |
| std::unordered_map<std::string, T> table; |
| DiagnosticContext* ctx; |
| |
| /*! \brief Add the unique allocation. */ |
| void Add(const std::string& name, const T& t) { |
| auto it = table.find(name); |
| if (it != table.end()) { |
| throw DuplicateKeyError("duplicate key name in intern table"); |
| } else { |
| table.insert({name, t}); |
| } |
| } |
| |
| /*! \brief Return the unique allocation. */ |
| Optional<T> Get(const std::string& name) const { |
| auto it = table.find(name); |
| if (it != table.end()) { |
| return Optional<T>(it->second); |
| } else { |
| return Optional<T>(); |
| } |
| } |
| }; |
| |
| GlobalVar AddOrGet(InternTable<GlobalVar>* table, const std::string& name) { |
| auto var = table->Get(name); |
| if (var) { |
| return var.value(); |
| } else { |
| auto gvar = GlobalVar(name); |
| table->Add(name, gvar); |
| return gvar; |
| } |
| } |
| |
| GlobalTypeVar AddOrGet(InternTable<GlobalTypeVar>* table, const std::string& name, |
| TypeKind kind = TypeKind::kType) { |
| auto var = table->Get(name); |
| if (var) { |
| auto tvar = var.value(); |
| TypeKind& tvar_kind = const_cast<TypeKind&>(tvar->kind); |
| tvar_kind = kind; |
| return tvar; |
| } else { |
| auto gvar = GlobalTypeVar(name, kind); |
| table->Add(name, gvar); |
| return gvar; |
| } |
| } |
| |
| /*! \brief The parser class is the main interface to the parser. |
| * the parser is not currently exposed beyond this .cc file. |
| * |
| * The parser is initialized with a diagnostic context, an |
| * operator table, and a token stream. |
| * |
| * The rest of the internal state is used to map the human readable |
| * form to in-memory IR representation. |
| * |
| * The main entry point to the parser are a set of parsing methods |
| * such as `ParseModule` and `ParseExpr`. |
| * |
| * As with traditional recursive descent parsers the parsing methods |
| * are factored recursively just as one would do with a formal language |
| * grammar. |
| * |
| * You can view a recursive descent parser as a human friendly way to specify |
| * a state machine, and thus this factoring is necessary as the 'state' of this |
| * machine is the combination of the current parsing method and the next token. |
| * |
| * Parsing proceeds by matching a token and then dispatching to the appropriate |
| * method to parse the next tokens in the stream. |
| * |
| * For example if we are parsing a type and encounter a "Tensor" token we switch |
| * into a mode for parsing `[`, a shape, a comma, a data type and then a `]`. |
| * |
| * Certain matches like this are unambiguous and proceed in a straight line fashion |
| * once the initial token is found. Other parsing is more complex and requires some |
| * tricks to correctly parse. |
| * |
| * For example when we find a '(' in an expression context, it may be part of |
| * a tuple, the arguments to a call, or a parenthesized expression. The below code |
| * disambiguate these cases by factoring expression parsing into a series of methods |
| * which encode the parsing context and thus how to interpret the parenthesis. |
| * |
| * For more information one should be able to read the code in order starting with |
| * `ParseModule` or `ParseExpr`. |
| */ |
| class Parser { |
| public: |
| /*! \brief The version that the parser is parsing. */ |
| SemVer version; |
| |
| /*! \brief The IRModule we are building. */ |
| IRModule module; |
| |
| /*! \brief The diagnostic context used for error reporting. */ |
| DiagnosticContext diag_ctx; |
| |
| const Source& source; |
| |
| /*! \brief The current position in the token stream. */ |
| int pos; |
| |
| /*! \brief The token stream for the parser. */ |
| std::vector<Token> tokens; |
| |
| /*! \brief The configured operator table. */ |
| OperatorTable op_table; |
| |
| /*! \brief Configure the whitespace mode, right now we ignore all whitespace. */ |
| bool ignore_whitespace; |
| |
| /*! \brief A global mapping for GlobalVar. */ |
| InternTable<GlobalVar> global_names; |
| |
| /*! \brief A global mapping for type definitions. */ |
| InternTable<GlobalTypeVar> type_names; |
| |
| /*! \brief A global mapping for constructor names. */ |
| InternTable<Constructor> ctors; |
| |
| /*! \brief A mapping from graph variable to expression, i.e., `%0 = expr`. */ |
| std::unordered_map<int, Expr> graph_ctx; |
| |
| /*! \brief The set of type scopes used for generics. */ |
| ScopeStack<TypeVar> type_scopes; |
| |
| /*! \brief The set of expression scopes used for lexical scope. */ |
| ScopeStack<Var> expr_scopes; |
| |
| /*! \brief The metadata section. */ |
| MetaTable meta_table; |
| |
| Parser(IRModule module, DiagnosticContext ctx, const Source& source, std::vector<Token> tokens, |
| OperatorTable op_table, MetaTable table) |
| : module(module), |
| diag_ctx(ctx), |
| source(source), |
| pos(0), |
| tokens(tokens), |
| op_table(op_table), |
| ignore_whitespace(true), |
| meta_table(table) { |
| InitializeGlobals(); |
| InitializeTypeDefs(); |
| } |
| |
| /*! If we are parsing into a module with previously loaded data types we need to |
| * map constructor names and variable names in the global tables. |
| */ |
| void InitializeTypeDefs() { |
| for (auto pair : this->module->type_definitions) { |
| type_names.Add(pair.first->name_hint, pair.first); |
| for (auto ctor : pair.second->constructors) { |
| ctors.Add(ctor->name_hint, ctor); |
| } |
| } |
| } |
| |
| void InitializeGlobals() { |
| for (auto pair : this->module->functions) { |
| global_names.Add(pair.first->name_hint, pair.first); |
| } |
| } |
| |
| /*! \brief Examine the next token in the stream, the current parser is configured to be |
| * whitespace insensitive so we will skip all whitespace or comment tokens. */ |
| Token Peek() { |
| // For now we ignore all whitespace tokens and comments. |
| // We can tweak this behavior later to enable white space sensitivity in the parser. |
| while (pos < static_cast<int64_t>(tokens.size()) && ignore_whitespace && |
| (tokens.at(pos)->token_type == TokenType::kWhitespace || |
| tokens.at(pos)->token_type == TokenType::kNewline || |
| tokens.at(pos)->token_type == TokenType::kLineComment || |
| tokens.at(pos)->token_type == TokenType::kComment)) { |
| pos++; |
| } |
| |
| if (pos < static_cast<int64_t>(tokens.size())) { |
| return Token(this->tokens.at(pos)); |
| } else { |
| return Token::Null(); |
| } |
| } |
| |
| /*! \brief Lookahead by N tokens. |
| * \param n The number of tokens to lookahead. |
| * \return The Nth token. |
| */ |
| Token Lookahead(int n) { |
| ICHECK_GE(n, 1) << "lookahead is only valid when n >= 1"; |
| |
| // We intend to skip n - 1 tokens, then return the nth. |
| auto old_pos = pos; |
| for (int i = 0; i < n - 1; i++) { |
| Peek(); |
| pos++; |
| } |
| |
| auto tok = Peek(); |
| pos = old_pos; |
| return tok; |
| } |
| |
| /*! \brief Consume a token, this method is the lowest level way to consume a token |
| * and will not ignore white space or look ahead in anyway. |
| * |
| * /param token_type The token type to match. |
| */ |
| void Consume(const TokenType& token_type) { |
| if (tokens[pos]->token_type != token_type) { |
| this->diag_ctx.EmitFatal(Diagnostic::Error(tokens[pos]->span) |
| << "expected a " << Pretty(token_type) << " found " |
| << Pretty(Peek()->token_type)); |
| } |
| pos++; |
| } |
| |
| /*! Match a token in the stream, this will first invoke Peek, ignoring tokens such |
| * as whitespace or comments returning the first meaningful token. |
| * |
| * We then try and consume the requested token, this will trigger an error if the |
| * current token does not match the token_type. |
| */ |
| Token Match(const TokenType& token_type) { |
| auto tok = Peek(); |
| Consume(token_type); |
| return tok; |
| } |
| |
| /*! Conditionally consume a token when it matches, this will never trigger an error |
| * as we guard against consuming the token before we do. |
| * |
| * Useful for matching optional tokens, effectively looksahead by one. |
| */ |
| bool WhenMatch(const TokenType& token_type) { |
| VLOG(9) << "Parser::WhenMatch: Peek() == " << Peek(); |
| if (Peek()->token_type == token_type) { |
| Consume(token_type); |
| return true; |
| } else { |
| return false; |
| } |
| } |
| |
| /* \brief Add a graph binding to the parsing context |
| * |
| * For example if we parse %0 = add(...), map 0 -> add(...), etc. |
| */ |
| void AddGraphBinding(const Token& token, const Expr& expr) { |
| auto graph_no = token.ToNumber(); |
| this->graph_ctx.insert({graph_no, expr}); |
| } |
| |
| /* \brief Lookup a previously bound graph variable. |
| * |
| * Note: we take tokens in all lookup methods so that we |
| * that we can do error reporting based on token location. |
| */ |
| Expr LookupGraphBinding(const Token& token) { |
| auto graph_no = token.ToNumber(); |
| auto it = this->graph_ctx.find(graph_no); |
| if (it != this->graph_ctx.end()) { |
| return it->second; |
| } else { |
| LOG(FATAL) << "Local variable %" << graph_no << " has not yet been defined"; |
| throw; |
| } |
| } |
| |
| /*! \brief Bind a local variable in the expression scope. |
| * |
| * "x" -> Var("x"), these are needed to map from the raw string names |
| * to unique variable nodes. |
| * If a virtual device is specified, sets the virtual device of the variable. |
| */ |
| Var BindVar(const std::string& name, const relay::Type& type_annotation, |
| Optional<VirtualDevice> virtual_device = Optional<VirtualDevice>()) { |
| auto var = Var(name, type_annotation); |
| var->virtual_device_ = virtual_device.value_or(VirtualDevice::FullyUnconstrained()); |
| VLOG(1) << "Binding var named " << name << " to variable node " << PrettyPrint(var); |
| this->expr_scopes.Add(name, var); |
| return var; |
| } |
| |
| /*! \brief Bind a local variable in the expression scope. |
| * |
| * "x" -> Var("x"), these are needed to map from the raw string names |
| * to unique variable nodes. |
| */ |
| Var BindFreeVar(const std::string& name, const relay::Type& type_annotation) { |
| auto var = Var(name, type_annotation); |
| this->expr_scopes.AddFreeVar(name, var); |
| return var; |
| } |
| |
| /*! \brief Bind a type variable in the type scope. |
| * |
| * "A" -> TypeVar("A", ...), these are needed to map from raw string names |
| * to unique type variable nodes. |
| */ |
| TypeVar BindTypeVar(const std::string& name, const TypeKind type_kind) { |
| auto type_var = TypeVar(name, type_kind); |
| this->type_scopes.Add(name, type_var); |
| return type_var; |
| } |
| |
| /*! \brief Lookup a variable in the expression scope. |
| * |
| * Note: all lookup methods take tokens intentionally for error reporting information. |
| */ |
| Var LookupLocal(const Token& local) { |
| auto var = this->expr_scopes.Lookup(local.ToString()); |
| if (!var.defined()) { |
| diag_ctx.Emit(Diagnostic::Error(local->span) |
| << "this local variable has not been previously declared"); |
| } |
| return var; |
| } |
| |
| /*! \brief Lookup a variable in the type scope. |
| * |
| * Note: all lookup methods take tokens intentionally for error reporting information. |
| */ |
| TypeVar LookupTypeVar(const Token& ident) { |
| auto var = this->type_scopes.Lookup(ident.ToString()); |
| return var; |
| } |
| |
| /*! \brief Add an expression scope to the scope stack. */ |
| void PushScope() { this->expr_scopes.PushStack(); } |
| |
| /*! \brief Remove N expression scopes from the scope stack. */ |
| void PopScopes(int n) { |
| for (int i = 0; i < n; i++) { |
| this->expr_scopes.PopStack(); |
| } |
| } |
| |
| /*! \brief Add an type scope to the scope stack. */ |
| void PushTypeScope() { this->type_scopes.PushStack(); } |
| |
| /*! \brief Remove N type scopes from the scope stack. */ |
| void PopTypeScopes(int n) { |
| for (int i = 0; i < n; i++) { |
| this->type_scopes.PopStack(); |
| } |
| } |
| |
| /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */ |
| NDArray NumberToNDArray(const Token& token) { |
| if (token->token_type == TokenType::kInteger) { |
| return support::IntImmToNDArray(Downcast<tvm::IntImm>(token->data)); |
| } else if (token->token_type == TokenType::kFloat) { |
| return support::FloatImmToNDArray(Downcast<tvm::FloatImm>(token->data)); |
| } else { |
| LOG(FATAL) << "internal error: should only call this function on numeric tokens"; |
| return {}; |
| } |
| } |
| |
| [[noreturn]] void ParseError(const Token& token, const std::string& msg) { |
| throw std::runtime_error(msg); |
| } |
| |
| /*! \brief A parsing helper for a bracketed expression <start> <parser> <stop>. */ |
| template <typename R> |
| R Bracket(TokenType open, TokenType close, std::function<R()> parser) { |
| Match(open); |
| R result = parser(); |
| Match(close); |
| return result; |
| } |
| |
| /*! \brief Parse `(` parser() `)`. */ |
| template <typename R> |
| R Parens(std::function<R()> parser) { |
| return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, parser); |
| } |
| |
| /*! \brief Parse `{` parser() `}`. */ |
| template <typename R> |
| R Block(std::function<R()> parser) { |
| return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser); |
| } |
| |
| template <typename R> |
| R WithSpan(std::function<R()> parser) { |
| auto start_span = Peek()->span; |
| VLOG(9) << "WithSpan: start_span = " << start_span; |
| R ast = parser(); |
| if (ast.defined()) { |
| // The token at the head of the stream is now 1 past where we parsed. So we find its start |
| // position as its start and end, so that when we merge we only grow the spanned region |
| // to the start of the current stream. |
| auto span_pos = pos - 1; |
| while ((tokens.at(span_pos)->token_type == TokenType::kWhitespace || |
| tokens.at(span_pos)->token_type == TokenType::kNewline || |
| tokens.at(span_pos)->token_type == TokenType::kLineComment || |
| tokens.at(span_pos)->token_type == TokenType::kComment)) { |
| span_pos--; |
| } |
| auto end_token = tokens.at(span_pos); |
| VLOG(9) << "WithSpan: end_span = " << end_token->span; |
| ast->span = start_span.Merge(end_token->span); |
| } |
| return ast; |
| } |
| |
| struct MetaRef { |
| std::string type_key; |
| uint64_t node_index; |
| Span span; |
| MetaRef(std::string type_key, uint64_t node_index, Span span) |
| : type_key(type_key), node_index(node_index), span(span) {} |
| }; |
| |
| MetaRef MetaRefFromToken(const Token& tok) { |
| Call ref = Downcast<Call>(tok->data); |
| auto attrs = ref->attrs.as<MetaRefAttrs>(); |
| auto type_key = attrs->node_type_key; |
| auto index = attrs->node_index; |
| return MetaRef(type_key, index, ref->span); |
| } |
| |
| /*! \brief Parse a meta reference of the form `meta[type_key][node_index]`. |
| * For example `meta[relay.Constant][0]` references the first constant, `meta[relay.Constant][1]` |
| * the second, and so on. |
| */ |
| ObjectRef ParseMetaRef() { |
| auto meta_ref_tok = Match(TokenType::kMetaReference); |
| auto meta_ref = MetaRefFromToken(meta_ref_tok); |
| auto it = this->meta_table.find(meta_ref.type_key); |
| if (it != this->meta_table.end()) { |
| auto nodes = (*it).second; |
| if (meta_ref.node_index < nodes.size()) { |
| return nodes[meta_ref.node_index]; |
| } else { |
| this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span) |
| << "the node index `" << meta_ref.node_index |
| << "` is out of bounds for `" << meta_ref.type_key << "`"); |
| return ObjectRef(); |
| } |
| } else { |
| this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span) |
| << "no entry in the meta table for `" << meta_ref.type_key << "`"); |
| return ObjectRef(); |
| } |
| } |
| /*! \brief Parses a sequence beginning with a start token, seperated by a seperator token, and |
| * ending with a stop token. |
| * |
| * The simple form being <start> (<parse()> <seperator>)* <stop>. |
| * |
| * This also provides a fourth argument which is allowed to run when the sequence which matches |
| * the inner sequence can not proceed. |
| * |
| * This is useful for parsing things like attributes which don't match the standard expression |
| * parsers but are contained within the stop token. |
| */ |
| template <typename T> |
| Array<T> ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function<T()> parse, |
| std::function<bool()> before_stop = nullptr) { |
| VLOG(9) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep) |
| << " stop=" << ToString(stop); |
| Match(start); |
| |
| // This is for the empty arguments list case, if we have <start> <leftovers> <stop> token stream |
| // we must parse leftovers, then match a stop token. |
| if (before_stop) { |
| auto did_parse = before_stop(); |
| if (did_parse) { |
| Match(stop); |
| return {}; |
| } |
| } |
| |
| // This is the case in which we find an empty arguments lists and no leftovers. |
| if (WhenMatch(stop)) { |
| return Array<T>(); |
| } else { |
| VLOG(9) << "Parser::ParseSequence: parse first"; |
| auto data = parse(); |
| Array<T> elements = {data}; |
| |
| if (WhenMatch(stop)) { |
| return elements; |
| // parse '( expr ',' * ')' |
| } else if (WhenMatch(sep)) { |
| while (true) { |
| VLOG(9) << "Parser::ParseSequence: parse element"; |
| if (WhenMatch(stop)) { |
| break; |
| } else { |
| // If before stop is |
| if (before_stop) { |
| auto did_parse = before_stop(); |
| if (did_parse) { |
| Match(stop); |
| return elements; |
| } |
| } |
| auto data = parse(); |
| WhenMatch(sep); |
| elements.push_back(data); |
| } |
| } |
| return elements; |
| } else { |
| auto next = Peek(); |
| this->diag_ctx.EmitFatal(Diagnostic::Error(next->span) |
| << "expected a " << Pretty(stop) << " found " |
| << Pretty(next->token_type)); |
| return Array<T>(nullptr); |
| } |
| } |
| } |
| |
| /*! \brief Parse a full IRModule. */ |
| IRModule ParseModule() { |
| // Parse the semver header at the top of the module. |
| this->version = ParseSemVer(); |
| // Parse the definitions. |
| auto defs = ParseDefinitions(); |
| // Parse the metadata section at the end. |
| auto metadata = ParseMetadata(); |
| |
| Match(TokenType::kEndOfFile); |
| |
| for (auto type_def : defs.types) { |
| module->AddTypeDef(type_def->header, type_def); |
| } |
| |
| for (auto func : defs.funcs) { |
| module->Add(func.global, func.function, true); |
| } |
| |
| return module; |
| } |
| |
| /*! \brief Parse the semantic versioning header. */ |
| SemVer ParseSemVer(bool required = true) { |
| if (Peek()->token_type == TokenType::kVersion) { |
| auto version = Match(TokenType::kVersion); |
| // TODO(@jroesch): we currently only support 0.0.5. |
| if (version.ToString() != "\"0.0.5\"") { |
| this->diag_ctx.Emit(Diagnostic::Error(version->span) |
| << "invalid semantic version `" << version.ToString() << "`"); |
| } |
| } else if (required) { |
| this->diag_ctx.Emit(Diagnostic::Error(Peek()->span) |
| << "expected text format semantic version, found a " |
| << PrettyPrint(Peek())); |
| |
| this->diag_ctx.Emit(Diagnostic::Help(Peek()->span) |
| << "you can annotate it as #[version = \"0.0.5\"]"); |
| } |
| return SemVer(0, 0, 5); |
| } |
| |
| /*! \brief Parse zero or more Relay definitions. */ |
| Definitions ParseDefinitions() { |
| Definitions defs; |
| |
| while (true) { |
| auto next = Peek(); |
| switch (next->token_type) { |
| case TokenType::kDefn: { |
| Consume(TokenType::kDefn); |
| auto global_tok = Match(TokenType::kGlobal); |
| auto global_name = global_tok.ToString(); |
| auto global = AddOrGet(&global_names, global_name); |
| auto func = WithSpan<relay::Function>([&]() { return ParseFunctionDef(); }); |
| ICHECK(func->span.defined()) << "spans must be set in parser"; |
| defs.funcs.push_back(GlobalFunc(global, func)); |
| continue; |
| } |
| case TokenType::kTypeDef: { |
| defs.types.push_back(ParseTypeDef()); |
| continue; |
| } |
| case TokenType::kExtern: { |
| Consume(TokenType::kExtern); |
| auto type_def = ParseTypeDef(); |
| if (type_def->constructors.size()) { |
| diag_ctx.Emit(Diagnostic::Error(next->span) |
| << "an external type may not have any constructors"); |
| } |
| defs.types.push_back(type_def); |
| } |
| default: |
| return defs; |
| } |
| } |
| } |
| |
| /*! \brief Parse zero or more Relay type definitions. */ |
| TypeData ParseTypeDef() { |
| // Match the `type` keyword. |
| Match(TokenType::kTypeDef); |
| // Parse the type's identifier. |
| auto type_tok = Match(TokenType::kIdentifier); |
| auto type_id = type_tok.ToString(); |
| auto type_global = AddOrGet(&type_names, type_id, TypeKind::kAdtHandle); |
| |
| Array<TypeVar> generics; |
| |
| bool should_pop = false; |
| if (Peek()->token_type == TokenType::kLSquare) { |
| // If we have generics we need to add a type scope. |
| PushTypeScope(); |
| should_pop = true; |
| generics = ParseSequence<TypeVar>( |
| TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { |
| auto type_var_name = Match(TokenType::kIdentifier).ToString(); |
| return BindTypeVar(type_var_name, TypeKind::kType); |
| }); |
| } |
| |
| Array<tvm::Constructor> ctors; |
| if (Peek()->token_type == TokenType::kLCurly) { |
| // Parse the list of constructors. |
| ctors = ParseSequence<tvm::Constructor>( |
| TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&]() { |
| // First match the name of the constructor. |
| auto ctor_tok = Match(TokenType::kIdentifier); |
| auto ctor_name = ctor_tok.ToString(); |
| |
| Constructor ctor; |
| // Match the optional field list. |
| if (Peek()->token_type != TokenType::kOpenParen) { |
| ctor = tvm::Constructor(ctor_name, {}, type_global); |
| } else { |
| auto arg_types = |
| ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma, |
| TokenType::kCloseParen, [&]() { return ParseType(); }); |
| ctor = tvm::Constructor(ctor_name, arg_types, type_global); |
| } |
| |
| ICHECK(ctor.defined()); |
| |
| try { |
| this->ctors.Add(ctor_name, ctor); |
| } catch (const DuplicateKeyError& e) { |
| this->diag_ctx.EmitFatal(Diagnostic::Error(ctor_tok->span) |
| << "a constructor with the name " |
| << "`" << ctor_name << "` " |
| << "was previously defined"); |
| } |
| |
| return ctor; |
| }); |
| } |
| |
| // Now pop the type scope. |
| if (should_pop) { |
| PopTypeScopes(1); |
| } |
| |
| return TypeData(type_global, generics, ctors); |
| } |
| |
| std::string HackTokensAsString(int n) { |
| std::stringstream key; |
| n = std::min(static_cast<int>(tokens.size() - pos), n); |
| for (int i = 0; i < n; i++) { |
| key << ToString(tokens.at(pos + i)->token_type); |
| } |
| return key.str(); |
| } |
| |
| std::vector<Rule> ParseOp() { |
| std::vector<Rule> matched; |
| Peek(); |
| for (int i = 4; i > 0; i--) { |
| auto key = HackTokensAsString(i); |
| auto it = this->op_table.this_is_a_hack.find(key); |
| if (it != this->op_table.this_is_a_hack.end()) { |
| pos = pos + i; |
| matched.push_back(it->second); |
| } |
| } |
| |
| return matched; |
| } |
| |
| /*! \brief Parse a single Relay expression. */ |
| Expr ParseExpr() { |
| VLOG(9) << "Parser::ParseExpr"; |
| return WithSpan<Expr>([this] { |
| std::vector<Expr> exprs; |
| |
| while (true) { |
| VLOG(9) << "Parser::ParseExpr: parsing a single expression"; |
| auto next = Peek(); |
| switch (next->token_type) { |
| // For graph or let, match first rhs, then invoke ParseBindingExpr |
| // ParseBindingExpression then parse_lhs() parse_rhs() ';' continue |
| case TokenType::kLCurly: { |
| // NB: Might need to optimize to remove deep recursion. |
| // Stack should only grow proportionally to the number of |
| // nested scopes. |
| // Parses `{` expression `}`. |
| auto block = WithSpan<Expr>([&]() { |
| return Bracket<Expr>(TokenType::kLCurly, TokenType::kRCurly, [&]() { |
| PushScope(); |
| auto expr = ParseExpr(); |
| PopScopes(1); |
| return expr; |
| }); |
| }); |
| exprs.push_back(block); |
| break; |
| } |
| case TokenType::kFreeVar: { |
| Consume(TokenType::kFreeVar); |
| auto var_token = Match(TokenType::kLocal); |
| |
| Type type; |
| if (WhenMatch(TokenType::kColon)) { |
| type = ParseType(); |
| } else { |
| type = IncompleteType(); |
| } |
| |
| BindFreeVar(var_token.ToString(), type); |
| break; |
| } |
| // Parses `let ...`; |
| case TokenType::kLet: |
| exprs.push_back(ParseBindingExpr()); |
| break; |
| case TokenType::kMatch: |
| case TokenType::kPartialMatch: { |
| bool is_total = next->token_type == TokenType::kMatch; |
| Consume(next->token_type); |
| exprs.push_back(ParseMatch(is_total)); |
| break; |
| } |
| |
| // %x ... |
| case TokenType::kGraph: |
| if (Lookahead(2)->token_type == TokenType::kEqual) { |
| exprs.push_back(ParseBindingExpr()); |
| break; |
| } |
| // intentional fall through here. |
| default: { |
| exprs.push_back(ParseExprBinOp()); |
| break; |
| } |
| } |
| |
| if (!WhenMatch(TokenType::kSemicolon)) { |
| break; |
| } |
| } |
| |
| ICHECK_GE(exprs.size(), 1); |
| |
| if (exprs.size() == 1) { |
| // ICHECK(exprs[0].defined() && exprs[0]->span.defined()) |
| // << "parser must set expression spans.\n" |
| // << exprs[0]; |
| return exprs[0]; |
| } else { |
| auto body = exprs.back(); |
| exprs.pop_back(); |
| while (exprs.size()) { |
| auto value = exprs.back(); |
| ICHECK(value->span.defined()) << "parser must set expression spans."; |
| exprs.pop_back(); |
| body = relay::Let(Var("", IncompleteType()), value, body, value->span.Merge(body->span)); |
| } |
| ICHECK(body->span.defined()) << "parser must set expression spans."; |
| return body; |
| } |
| }); |
| } |
| |
| /*! \brief Parse a "binding expression"; an expression where |
| * a graph or let variable is bound. |
| * |
| * In order to avoid stack overflow this is implemented in a special |
| * iterative way to keep stack depth constant in a long chain of bindings. |
| */ |
| Expr ParseBindingExpr() { |
| // We use a loop here so that the stack depth |
| // does not grow linearly with a sequence of |
| // graph or let bindings. |
| // |
| // Assuming we start at call depth k, we will |
| // enter k + c call frames to parse the RHS |
| // of the bindings where `c` is the depth |
| // of recursion needed by RHS. |
| // |
| // If RHS is a call expresssion the c=1. |
| // |
| // Once we have parsed the RHS we will be |
| // back at depth K, and will return to |
| // this loop header to parse another |
| // graph or let binding. |
| // |
| // This ensures for n sequential bindings |
| // the call depth will be the same before |
| // and after parsing the n bindings. |
| VLOG(9) << "Parser::ParseBindingExpr"; |
| std::vector<std::tuple<Var, Expr, Span>> bindings; |
| int scopes = 0; |
| |
| while (true) { |
| auto next = Peek(); |
| if (next->token_type == TokenType::kGraph && Lookahead(2)->token_type == TokenType::kEqual) { |
| Match(TokenType::kGraph); |
| Match(TokenType::kEqual); |
| auto val = this->ParseExprBinOp(); |
| Match(TokenType::kSemicolon); |
| AddGraphBinding(next, val); |
| } else if (next->token_type == TokenType::kLet) { |
| auto span = next->span; |
| // Parse the 'let'. |
| Consume(TokenType::kLet); |
| |
| // Parse the local '%<id>'. |
| auto local_tok = Match(TokenType::kLocal); |
| auto string = local_tok.ToString(); |
| |
| // Parse the optional type annotation (':' <type>). |
| Type type; |
| if (WhenMatch(TokenType::kColon)) { |
| type = ParseType(); |
| } |
| |
| auto var = BindVar(string, type); |
| |
| // Parse the '='; |
| Match(TokenType::kEqual); |
| |
| // Parse the body, and the ';'. |
| auto val = this->ParseExprBinOp(); |
| Consume(TokenType::kSemicolon); |
| |
| // Add the bindings to the local data structure. |
| std::tuple<relay::Var, relay::Expr, Span> tuple(var, val, span); |
| bindings.push_back(tuple); |
| scopes++; |
| PushScope(); |
| } else { |
| // This is the only case we will increase the stack |
| // depth. |
| // |
| // If we parse a program which is a sequence of N bindings |
| // followed by a single body expression we will end up with |
| // a call depth of 3, the first call to ParseExpr, then |
| // ParseBindingExpr, then finally ParseExpr once more. |
| |
| auto body = this->ParseExpr(); |
| |
| // Remove the same number of scopes we added. |
| PopScopes(scopes); |
| |
| if (bindings.size() == 0) { |
| return body; |
| } else { |
| // We can now build the let binding up backwards. |
| for (auto binding = bindings.rbegin(); binding != bindings.rend(); binding++) { |
| auto span = body->span.Merge(std::get<2>(*binding)); |
| body = relay::Let(std::get<0>(*binding), std::get<1>(*binding), body, span); |
| } |
| return body; |
| } |
| } |
| } |
| } |
| |
| /*! Parse a function definition without a leading keyword or identifier. |
| * |
| * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }. |
| */ |
| Function ParseFunctionDef() { |
| VLOG(9) << "Parser::ParseFunctionDef"; |
| return WithSpan<Function>([&]() { |
| PushScope(); |
| PushTypeScope(); |
| |
| Array<TypeVar> generics; |
| if (Peek()->token_type == TokenType::kLSquare) { |
| generics = ParseSequence<TypeVar>( |
| TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { |
| auto type_var_name = Match(TokenType::kIdentifier).ToString(); |
| return BindTypeVar(type_var_name, TypeKind::kType); |
| }); |
| } |
| |
| Map<String, ObjectRef> raw_attrs; |
| |
| auto params = ParseSequence<Var>( |
| TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, |
| [&]() { |
| auto token = Match(TokenType::kLocal); |
| auto string = token.ToString(); |
| |
| // The fake attributes where the virtual device is specified. |
| VirtualDevice virtual_device; |
| if (WhenMatch(TokenType::kLCurly)) { |
| Map<String, ObjectRef> fake_attrs = ParseAttrs(); |
| VLOG(9) << "Fake attributes for function parameter: " << fake_attrs; |
| Match(TokenType::kRCurly); |
| if (fake_attrs.size() == 1 && fake_attrs.count(kVirtualDevice)) { |
| ICHECK(fake_attrs[kVirtualDevice].as<VirtualDeviceNode>()) |
| << "Expected the " << kVirtualDevice |
| << " to have type VirtualDeviceNode, but got " << virtual_device->GetTypeKey(); |
| virtual_device = Downcast<VirtualDevice>(fake_attrs[kVirtualDevice]); |
| } |
| } |
| |
| Type type; |
| if (WhenMatch(TokenType::kColon)) { |
| type = ParseType(); |
| } |
| return BindVar(string, type, virtual_device); |
| }, |
| [&] { |
| auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; |
| auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; |
| |
| if (is_ident && next_is_equal) { |
| raw_attrs = ParseAttrs(); |
| return true; |
| } |
| |
| return false; |
| }); |
| |
| Type ret_type; |
| if (WhenMatch(TokenType::kMinus)) { |
| Match(TokenType::kRAngle); |
| ret_type = ParseType(); |
| } |
| |
| auto body = Block<Expr>([&]() { return ParseExpr(); }); |
| |
| PopTypeScopes(1); |
| PopScopes(1); |
| |
| // TODO(@jroesch): attributes should never be null, they should always be empty. |
| if (raw_attrs.size()) { |
| // Promote kVirtualDevice to first-class |
| if (raw_attrs.count(kVirtualDevice)) { |
| ObjectRef vid = raw_attrs.at(kVirtualDevice); |
| ICHECK(vid.as<VirtualDeviceNode>()) |
| << "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got " |
| << vid->GetTypeKey(); |
| |
| DictAttrs attrs; |
| // Don't fill the raw_attrs in if there's nothing other than kVirtualDevice in the |
| // attributes |
| if (raw_attrs.size() > 1) { |
| raw_attrs.erase(kVirtualDevice); |
| attrs = DictAttrs(raw_attrs); |
| } |
| Function func = relay::Function(params, body, ret_type, generics, attrs); |
| func->virtual_device_ = vid; |
| return func; |
| } else { |
| return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); |
| } |
| } else { |
| return relay::Function(params, body, ret_type, generics, tvm::DictAttrs()); |
| } |
| }); |
| } |
| |
| /*! \brief Parse an if-expression. */ |
| Expr ParseIf() { |
| return WithSpan<Expr>([&]() { |
| VLOG(9) << "Parser::ParseIf"; |
| Consume(TokenType::kIf); |
| |
| auto guard = WithSpan<Expr>([&] { return Parens<Expr>([&] { return ParseExpr(); }); }); |
| |
| auto true_branch = Block<Expr>([&] { |
| this->PushScope(); |
| auto expr = ParseExpr(); |
| this->PopScopes(1); |
| return expr; |
| }); |
| |
| Match(TokenType::kElse); |
| |
| auto false_branch = Block<Expr>([&] { |
| this->PushScope(); |
| auto expr = ParseExpr(); |
| this->PopScopes(1); |
| return expr; |
| }); |
| |
| return relay::If(guard, true_branch, false_branch); |
| }); |
| } |
| |
| /* This factors parsing a list of patterns for both tuples, and constructors. */ |
| Array<Pattern> ParsePatternList() { |
| return ParseSequence<Pattern>(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, |
| [&] { return ParsePattern(); }); |
| } |
| |
| /*! \brief Parses a pattern for a match expression. |
| * |
| * A pattern is either a wildcard `_`, a local `%name`, |
| * a constructor `C(p1, ..., pn)` or tuple `(p1, ..., pn). |
| * |
| * This function recursively parses a pattern. |
| */ |
| Pattern ParsePattern() { |
| VLOG(9) << "Parser::ParsePattern"; |
| auto next = Peek(); |
| switch (next->token_type) { |
| case TokenType::kUnderscore: { |
| Match(TokenType::kUnderscore); |
| return PatternWildcard(); |
| } |
| case TokenType::kLocal: { |
| auto id = Match(TokenType::kLocal); |
| Type type_annotation; |
| if (WhenMatch(TokenType::kColon)) { |
| type_annotation = ParseType(); |
| } |
| auto var = BindVar(id.ToString(), type_annotation); |
| return PatternVar(var); |
| } |
| case TokenType::kIdentifier: { |
| auto id = Match(TokenType::kIdentifier); |
| auto ctor = ctors.Get(id.ToString()); |
| if (!ctor) { |
| diag_ctx.EmitFatal( |
| // TODO(@jroesch): split into error and help |
| // deal with multiple rendering |
| Diagnostic::Error(id->span) |
| << "undefined constructor name `" << id.ToString() |
| << "`, perhaps you intended to write a" |
| << "pattern variable, considering changing this to `%" << id.ToString() << "`"); |
| } |
| if (Peek()->token_type == TokenType::kOpenParen) { |
| auto fields = ParsePatternList(); |
| return PatternConstructor(ctor.value(), fields); |
| } else { |
| return PatternConstructor(ctor.value(), {}); |
| } |
| } |
| default: |
| return PatternTuple(ParsePatternList()); |
| } |
| } |
| |
| Clause ParseMatchArm() { |
| PushScope(); |
| auto pattern = ParsePattern(); |
| Match(TokenType::kEqual); |
| Consume(TokenType::kRAngle); |
| auto expr = ParseExpr(); |
| PopScopes(1); |
| return Clause(pattern, expr); |
| } |
| |
| Expr ParseMatch(bool is_total) { |
| return WithSpan<Expr>([&]() { |
| Expr scrutinee = ParseAtomicExpr(); |
| |
| Array<Clause> clauses = |
| ParseSequence<Clause>(TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, |
| [&] { return ParseMatchArm(); }); |
| |
| return relay::Match(scrutinee, clauses, is_total); |
| }); |
| } |
| |
| Expr ParseExprBinOp() { |
| VLOG(9) << "Parser::ParseExprBinOp"; |
| return WithSpan<Expr>([this] { |
| // We must parse at least one expression, the default |
| // case is that there is no operator and we will fall |
| // through. |
| std::vector<Expr> exprs; |
| Expr expr = WithSpan<Expr>([this] { return ParseCallExpr(); }); |
| |
| exprs.push_back(expr); |
| |
| // Now we parse an optional op. |
| std::vector<Rule> ops; |
| |
| // We will now parse 0 or more operator occurrences. |
| while (true) { |
| auto opt_op = ParseOp(); |
| |
| // If we didn't parse one we done. |
| if (opt_op.size() == 0) { |
| break; |
| } |
| |
| // Read the operation we parsed; |
| auto op = opt_op[0]; |
| |
| Expr right = WithSpan<Expr>([this] { return ParseCallExpr(); }); |
| ICHECK(right->span.defined()); |
| |
| // If the operator stack is empty |
| // we parse an operator and expression |
| // and push them to stacks, then |
| // continue. |
| if (ops.size() == 0) { |
| ops.push_back(op); |
| exprs.push_back(right); |
| continue; |
| } |
| |
| if (op.precedence > ops.back().precedence || |
| (op.precedence == ops.back().precedence && op.left_assoc == false)) { |
| ops.push_back(op); |
| exprs.push_back(right); |
| continue; |
| } |
| |
| while (ops.size() && (op.precedence < ops.back().precedence || |
| (op.precedence == ops.back().precedence && op.left_assoc == true))) { |
| Rule new_op = ops.back(); |
| ops.pop_back(); |
| Expr right = exprs.back(); |
| exprs.pop_back(); |
| Expr left = exprs.back(); |
| exprs.pop_back(); |
| ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op; |
| exprs.push_back( |
| relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span))); |
| } |
| |
| exprs.push_back(right); |
| ops.push_back(op); |
| } |
| |
| while (ops.size()) { |
| Rule new_op = ops.back(); |
| ops.pop_back(); |
| Expr right = exprs.back(); |
| exprs.pop_back(); |
| Expr left = exprs.back(); |
| exprs.pop_back(); |
| ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op; |
| exprs.push_back( |
| relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span))); |
| } |
| |
| ICHECK_EQ(ops.size(), 0) << "No operations should be left on the operation stack."; |
| |
| ICHECK_EQ(exprs.size(), 1) |
| << "Only a single expression should be left on the expression stack."; |
| |
| return exprs[0]; |
| }); |
| } |
| |
| ObjectRef ParseAttributeValue() { |
| VLOG(9) << "Parser::ParseAttributeValue"; |
| auto next = Peek(); |
| switch (next->token_type) { |
| case TokenType::kFloat: |
| case TokenType::kInteger: |
| case TokenType::kBoolean: |
| case TokenType::kStringLiteral: |
| return Match(next->token_type)->data; |
| case TokenType::kMetaReference: |
| return ParseMetaRef(); |
| case TokenType::kLSquare: { |
| return ParseSequence<ObjectRef>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, |
| [&]() { return ParseAttributeValue(); }); |
| } |
| case TokenType::kOpenParen: { |
| // TODO(@jroesch: need to figure out bracket vs. sequence) |
| // return ParseSequence<ObjectRef>(TokenType::kOpenParen, TokenType::kComma, |
| // TokenType::kCloseParen, |
| // [&]() { return ParseAttributeValue(); }); |
| return Bracket<ObjectRef>(TokenType::kOpenParen, TokenType::kCloseParen, |
| [&]() { return ParseAttributeValue(); }); |
| } |
| // TODO(@jroesch): not sure about this being the right way to handle nulls. |
| case TokenType::kIdentifier: { |
| if (auto text = next->data.as<tvm::StringObj>()) { |
| std::string id = GetRef<String>(text); |
| if (id == "nullptr") { |
| Match(TokenType::kIdentifier); |
| return ObjectRef(); |
| } |
| if (id == "None") { |
| Match(TokenType::kIdentifier); |
| return Optional<ObjectRef>(); |
| } |
| } |
| } |
| default: |
| return ParseAtomicExpr(); |
| } |
| } |
| |
| Map<String, ObjectRef> ParseAttrs() { |
| VLOG(9) << "Parser::ParseAttrs"; |
| Map<String, ObjectRef> kwargs; |
| while (Peek()->token_type == TokenType::kIdentifier) { |
| auto key = GetHierarchicalName(ParseHierarchicalName().data); |
| Match(TokenType::kEqual); |
| // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side. |
| auto value = ParseAttributeValue(); |
| // TODO(@jroesch): we need a robust way to handle this writing dtypes as strings in text |
| // format is bad. |
| kwargs.Set(key, value); |
| WhenMatch(TokenType::kComma); |
| } |
| VLOG(9) << "Parser::ParseAttrs: kwargs=" << kwargs; |
| return kwargs; |
| } |
| |
| Expr ParseCallArgs(Expr op) { |
| ICHECK(op.defined()) << "the operator must be defined"; |
| |
| VLOG(9) << "Parser::ParseCallArgs"; |
| Attrs attrs; |
| std::string op_key; |
| bool is_op = false; |
| |
| if (auto op_node = op.as<OpNode>()) { |
| is_op = true; |
| op_key = op_node->attrs_type_key; |
| } |
| |
| if (Peek()->token_type == TokenType::kOpenParen) { |
| Array<Expr> args = ParseSequence<Expr>( |
| TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, |
| [&] { return ParseExpr(); }, |
| [&] { |
| auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; |
| auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; |
| auto is_pretty_attrs = is_ident && next_is_equal; |
| auto is_meta_next = Lookahead(1)->token_type == TokenType::kMetaReference; |
| // TODO(@jroesch): might not handle trailing comma |
| auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen; |
| auto is_meta_attrs = is_meta_next && last_meta; |
| |
| if (is_pretty_attrs || is_meta_attrs) { |
| if (is_meta_attrs) { |
| auto meta_ref = ParseMetaRef(); |
| if (meta_ref.as<BaseAttrsNode>()) { |
| attrs = Downcast<Attrs>(meta_ref); |
| } else { |
| // Not awesome parsing code here. |
| this->pos--; |
| return false; |
| } |
| } else { |
| auto raw_attrs = ParseAttrs(); |
| if (is_op && op_key.size()) { |
| auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); |
| ICHECK(attr_obj.defined()); |
| attrs = Downcast<Attrs>(attr_obj); |
| } else if (raw_attrs.count("attrs_type_key")) { |
| String attr_key = Downcast<String>(raw_attrs["attrs_type_key"]); |
| if (attr_key.size()) { |
| raw_attrs.erase("attrs_type_key"); |
| auto attr_obj = |
| tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs); |
| ICHECK(attr_obj.defined()); |
| attrs = Downcast<Attrs>(attr_obj); |
| } |
| } else { |
| this->diag_ctx.EmitFatal(Diagnostic::Error(op->span) |
| << "unable to determine the 'attrs_type_key' with which " |
| "to represent the call attributes for this operator"); |
| } |
| } |
| return true; |
| } |
| return false; |
| }); |
| |
| if (!attrs.defined()) { |
| if (is_op && op_key.size()) { |
| auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, {}); |
| ICHECK(attr_obj.defined()); |
| attrs = Downcast<Attrs>(attr_obj); |
| } |
| } |
| |
| // TODO(@jroesch): in a secondary pass adjust spans. |
| return Expr(Call(op, args, attrs, {})); |
| } else { |
| return Expr(); |
| } |
| |
| return Expr(); |
| } |
| |
| Expr ParseCallExpr() { |
| VLOG(9) << "Parser::ParseCallExpr"; |
| return WithSpan<Expr>([this] { |
| Expr expr = ParseAtomicExpr(); |
| // Parse as many call args as possible, building up expression |
| // |
| // NB(@jroesch): this seems like a hack but in order to parse curried functions |
| // and avoid complex grammar we will parse multiple call lists in a row. |
| while (Peek()->token_type == TokenType::kOpenParen) { |
| auto new_expr = ParseCallArgs(expr); |
| |
| if (new_expr.defined()) { |
| expr = new_expr; |
| } else { |
| break; |
| } |
| } |
| |
| // We need a zero-arity case for constructors. |
| if (auto ctor_node = expr.as<ConstructorNode>()) { |
| if (ctor_node->inputs.size() == 0) { |
| return Expr(Call(expr, {})); |
| } |
| } |
| |
| return expr; |
| }); |
| } |
| |
| Expr GetOp(const std::string& op_name, const Span& span) { |
| VLOG(9) << "op_name=" << op_name << " span=" << span; |
| try { |
| return Op::Get(op_name); |
| } catch (const Error& e) { |
| // we can relax this, but probably need to relax checks or return non-null here. |
| this->diag_ctx.EmitFatal(Diagnostic::Error(span) |
| << "operator `" << op_name |
| << "` not found, perhaps you forgot to register it?"); |
| return Expr(); |
| } |
| } |
| |
| Expr ParseAtomicExpr() { |
| VLOG(9) << "Parser::ParseAtomicExpr"; |
| Expr expr = WithSpan<Expr>([this] { |
| auto next = Peek(); |
| switch (next->token_type) { |
| case TokenType::kInteger: |
| case TokenType::kFloat: { |
| Consume(next->token_type); |
| auto number = NumberToNDArray(next); |
| Expr e = Constant(number, next->span); |
| ICHECK(e->span.defined()) << "constant spans must be defined"; |
| return e; |
| } |
| case TokenType::kBoolean: { |
| Consume(TokenType::kBoolean); |
| int64_t value = Downcast<tvm::Integer>(next->data).IntValue(); |
| Expr e = Constant(support::BoolToNDArray(value), next->span); |
| ICHECK(e->span.defined()) << "constant spans must be defined"; |
| return e; |
| } |
| // Parse a local of the form `%x`. |
| case TokenType::kLocal: { |
| Consume(TokenType::kLocal); |
| return Expr(LookupLocal(next)); |
| } |
| // Parse a local of the form `@x`. |
| case TokenType::kGlobal: { |
| auto global_name = next.ToString(); |
| Consume(TokenType::kGlobal); |
| auto global = AddOrGet(&global_names, global_name); |
| return Expr(global); |
| } |
| // Parse a local of the form `x`. |
| // Right now we fail to parse `x.y`. |
| case TokenType::kIdentifier: { |
| auto ctor = ctors.Get(next.ToString()); |
| if (ctor) { |
| Consume(TokenType::kIdentifier); |
| return Expr(ctor.value()); |
| } else { |
| auto spanned_idents = ParseHierarchicalName(); |
| auto idents = spanned_idents.data; |
| auto span = spanned_idents.span; |
| return GetOp(GetHierarchicalName(idents), span); |
| } |
| } |
| case TokenType::kGraph: { |
| Consume(TokenType::kGraph); |
| return LookupGraphBinding(next); |
| } |
| case TokenType::kMetaReference: { |
| return Downcast<Expr>(ParseMetaRef()); |
| } |
| case TokenType::kFn: { |
| Consume(TokenType::kFn); |
| Expr e = ParseFunctionDef(); |
| ICHECK(e->span.defined()) << "function spans must be defined.\n" << e; |
| return e; |
| } |
| case TokenType::kIf: { |
| Expr e = ParseIf(); |
| return e; |
| } |
| case TokenType::kRef: { |
| Consume(TokenType::kRef); |
| Match(TokenType::kOpenParen); |
| auto ref_value = ParseExpr(); |
| Match(TokenType::kCloseParen); |
| return static_cast<Expr>(RefCreate(ref_value)); |
| } |
| case TokenType::kRefRead: { |
| return WithSpan<Expr>([&]() { |
| Consume(TokenType::kRefRead); |
| Match(TokenType::kOpenParen); |
| auto ref = ParseExpr(); |
| Match(TokenType::kCloseParen); |
| return static_cast<Expr>(RefRead(ref)); |
| }); |
| } |
| case TokenType::kRefWrite: { |
| return WithSpan<Expr>([&]() { |
| Consume(TokenType::kRefWrite); |
| Match(TokenType::kOpenParen); |
| auto ref = ParseExpr(); |
| Match(TokenType::kComma); |
| auto value = ParseExpr(); |
| Match(TokenType::kCloseParen); |
| return static_cast<Expr>(RefWrite(ref, value)); |
| }); |
| } |
| case TokenType::kOpenParen: { |
| Span sp = next->span; |
| Consume(TokenType::kOpenParen); |
| // parse '(' ')' |
| if (WhenMatch(TokenType::kCloseParen)) { |
| return Expr(Tuple(Array<Expr>())); |
| } else { |
| Expr subexpr = ParseExpr(); |
| // parse '(' expr ')' |
| if (WhenMatch(TokenType::kCloseParen)) { |
| return subexpr; |
| // parse '( expr ',' * ')' |
| } else if (WhenMatch(TokenType::kComma)) { |
| Array<Expr> exprs = {subexpr}; |
| while (true) { |
| if (WhenMatch(TokenType::kCloseParen)) { |
| break; |
| } else { |
| auto element = ParseExpr(); |
| auto comma = Peek(); |
| if (WhenMatch(TokenType::kComma)) { |
| sp = sp.Merge(element->span.Merge(comma->span)); |
| } else { |
| sp = sp.Merge(element->span); |
| } |
| exprs.push_back(element); |
| } |
| } |
| Expr tuple = Tuple(exprs, sp); |
| ICHECK(tuple->span.defined()) << "tuple span should be defined"; |
| return tuple; |
| } |
| } |
| } |
| default: { |
| this->diag_ctx.EmitFatal(Diagnostic::Error(next->span) |
| << "expected an expression found " << Pretty(next->token_type)); |
| return Expr(); |
| } |
| } |
| }); |
| |
| if (WhenMatch(TokenType::kPeriod)) { |
| auto token = Match(TokenType::kInteger); |
| auto index = token.ToNumber(); |
| auto span = token->span.Merge(expr->span); |
| VLOG(9) << "Parser::ParseAtomicExpr: tuple get item"; |
| return relay::TupleGetItem(expr, index, span); |
| } else { |
| return expr; |
| } |
| } |
| |
| /*! \brief Parse a hierarchical name. |
| * |
| * The tokenizer produces a token stream of <id1> . <id2> |
| * and so on for names of the form `nn.conv2d`. |
| * Currently we only use string names everywhere instead |
| * of a notion of a hierarchical name. |
| * |
| * The below utility reassembles a token stream into a |
| * single stream inserting the required periods needed |
| * to look up registered names. |
| */ |
| Spanned<Array<String>> ParseHierarchicalName() { |
| Array<String> idents; |
| Span span; |
| while (Peek()->token_type == TokenType::kIdentifier) { |
| auto token = Peek(); |
| |
| if (span.defined()) { |
| span = span.Merge(token->span); |
| } else { |
| span = token->span; |
| } |
| |
| auto name = token.ToString(); |
| idents.push_back(name); |
| Consume(TokenType::kIdentifier); |
| |
| // Keep parsing while we see a trailing period. |
| if (Peek()->token_type == TokenType::kPeriod) { |
| Consume(TokenType::kPeriod); |
| continue; |
| } else { |
| // No more periods means we are done! |
| break; |
| } |
| } |
| |
| return Spanned<Array<String>>(idents, span); |
| } |
| |
| std::string GetHierarchicalName(Array<String> idents) { |
| ICHECK_NE(idents.size(), 0); |
| std::stringstream hierarchical_name; |
| int i = 0; |
| int periods = idents.size() - 1; |
| for (auto ident : idents) { |
| hierarchical_name << ident; |
| if (i < periods) { |
| hierarchical_name << "."; |
| i++; |
| } |
| } |
| return hierarchical_name.str(); |
| } |
| |
| /*! \brief Parse a shape. */ |
| Array<tvm::PrimExpr> ParseShape() { |
| auto dims = ParseSequence<tvm::PrimExpr>( |
| TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { |
| tvm::PrimExpr dim; |
| if (Peek()->token_type == TokenType::kMetaReference) { |
| dim = Downcast<tvm::PrimExpr>(ParseMetaRef()); |
| } else if (WhenMatch(TokenType::kQuestion)) { |
| dim = tvm::tir::Any(); |
| } else { |
| dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data); |
| } |
| |
| return dim; |
| }); |
| return dims; |
| } |
| |
| /*! \brief Parse a function type. */ |
| Type ParseFunctionType() { |
| auto ty_params = ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma, |
| TokenType::kCloseParen, [&]() { return ParseType(); }); |
| |
| Match(TokenType::kMinus); |
| Match(TokenType::kRAngle); |
| auto ret_type = ParseType(); |
| |
| return relay::FuncType(ty_params, ret_type, {}, {}); |
| } |
| |
| // Parses a user defined ADT or type variable. |
| Type ParseNonPrimitiveType(const Token& tok) { |
| return WithSpan<Type>([&]() { |
| auto name = tok.ToString(); |
| Type head_type = LookupTypeVar(tok); |
| |
| if (!head_type.defined()) { |
| // head_type = type_names.Get(name); |
| head_type = AddOrGet(&type_names, name, TypeKind::kAdtHandle); |
| } |
| |
| if (!head_type.defined()) { |
| diag_ctx.EmitFatal(Diagnostic::Error(tok->span) |
| << "the type constructor `" << name << "` is undefined"); |
| } |
| |
| Array<Type> arg_types; |
| if (Peek()->token_type == TokenType::kLSquare) { |
| arg_types = ParseSequence<Type>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, |
| [&]() { return ParseType(); }); |
| } |
| |
| if (arg_types.size()) { |
| return static_cast<Type>(TypeCall(head_type, arg_types)); |
| } else { |
| if (head_type.as<GlobalTypeVarNode>()) { |
| return static_cast<Type>(TypeCall(head_type, {})); |
| } else { |
| return static_cast<Type>(head_type); |
| } |
| } |
| }); |
| } |
| |
| /*! \brief Parses a TVM type. |
| * |
| * This matches either a `Tensor[shape, dtype]`, a user defined ADT, a tuple type, |
| * a scalar type or an incomplete type `_`. |
| */ |
| Type ParseType() { |
| return WithSpan<Type>([&]() -> Type { |
| auto tok = Peek(); |
| |
| if (tok->token_type == TokenType::kOpenParen) { |
| auto tys = |
| ParseSequence<relay::Type>(TokenType::kOpenParen, TokenType::kComma, |
| TokenType::kCloseParen, [&]() { return ParseType(); }); |
| return relay::TupleType(tys); |
| } else if (WhenMatch(TokenType::kFn)) { |
| return ParseFunctionType(); |
| } else if (WhenMatch(TokenType::kIdentifier)) { |
| auto id = tok.ToString(); |
| if (id == "Tensor") { |
| Match(TokenType::kLSquare); |
| auto shape = ParseShape(); |
| Match(TokenType::kComma); |
| auto dtype_tok = Match(TokenType::kIdentifier); |
| auto dtype = DataType(String2DLDataType(dtype_tok.ToString())); |
| Match(TokenType::kRSquare); |
| return TensorType(shape, dtype); |
| } else { |
| auto ty = tok.ToString(); |
| if (ty.rfind("int", 0) == 0 || ty.find("float", 0) == 0 || ty.find("uint", 0) == 0 || |
| ty.find("bool", 0) == 0) { |
| // Need to do better error handling here. |
| auto dtype = DataType(String2DLDataType(tok.ToString())); |
| return TensorType({}, dtype); |
| } else { |
| return ParseNonPrimitiveType(tok); |
| } |
| } |
| } else if (WhenMatch(TokenType::kUnderscore)) { |
| return IncompleteType(); |
| } else { |
| this->diag_ctx.EmitFatal(Diagnostic::Error(tok->span) |
| << "failed to parse type found " << tok); |
| return Type(); |
| } |
| }); |
| } |
| |
| template <typename R> |
| R ConsumeWhitespace(std::function<R()> func) { |
| auto old = this->ignore_whitespace; |
| this->ignore_whitespace = true; |
| while (tokens[pos]->token_type == TokenType::kWhitespace) { |
| pos++; |
| } |
| auto res = func(); |
| this->ignore_whitespace = old; |
| return res; |
| } |
| |
| Map<String, Array<ObjectRef>> ParseMetadata() { |
| if (Peek()->token_type == TokenType::kMetadata) { |
| return Match(TokenType::kMetadata).ToMetadata(); |
| } else { |
| return Map<String, Array<ObjectRef>>(); |
| } |
| } |
| |
| /*! \brief A helper for debugging the parser, displays the next N tokens in the token stream. */ |
| void DisplayNextN(int n) { |
| std::cout << "remaining tokens: " << std::endl; |
| auto bound = std::min(pos + n, static_cast<int>(tokens.size())); |
| for (int i = 0; i < bound - pos; i++) { |
| std::cout << tokens[pos + i] << std::endl; |
| } |
| } |
| |
| // A function for debugging the operator parser. |
| void DebugStack(const std::vector<Expr>& exprs, const std::vector<Rule>& rules) { |
| std::cout << "Expr Stack: "; |
| for (auto expr : exprs) { |
| std::cout << expr << ", "; |
| } |
| |
| std::cout << std::endl; |
| std::cout << "Op Stack: "; |
| for (auto rule : rules) { |
| std::cout << rule.op << ", "; |
| } |
| |
| std::cout << std::endl; |
| } |
| }; |
| |
| Parser InitParser(const std::string& file_name, const std::string& file_content, |
| const Optional<IRModule>& init_module, const MetaTable& init_meta_table) { |
| VLOG(9) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size(); |
| SourceName src_name = SourceName::Get(file_name); |
| Source source(src_name, file_content); |
| |
| IRModule module; |
| if (!init_module) { |
| SourceMap source_map; |
| module = IRModule({}, {}, {}, source_map); |
| } else { |
| module = init_module.value(); |
| } |
| |
| module->source_map.Add(source); |
| |
| auto diag_ctx = DiagnosticContext::Default(module); |
| auto tokens_and_table = Tokenize(diag_ctx, source); |
| |
| auto tokens = tokens_and_table.first; |
| MetaTable meta_data_table = tokens_and_table.second.ToMetadata(); |
| |
| // Merge any entries in init_meta_table into anything captured in the #[metadata] section |
| // of the file_content. Metadata references within file_content must use indexes which account |
| // for this ordering. |
| for (const auto& pair : init_meta_table) { |
| Array<ObjectRef> items; |
| if (meta_data_table.count(pair.first)) { |
| items = meta_data_table[pair.first]; |
| } |
| for (const auto& obj : pair.second) { |
| items.push_back(obj); |
| } |
| meta_data_table.Set(pair.first, items); |
| } |
| |
| return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table)); |
| } |
| |
| IRModule ParseModule(const std::string& file_name, const std::string& file_content, |
| const Optional<IRModule>& init_module, const MetaTable& init_meta_table) { |
| VLOG_CONTEXT << "ParseModule"; |
| VLOG(9) << "parsing and type-checking " << file_name; |
| auto parser = InitParser(file_name, file_content, init_module, init_meta_table); |
| auto mod = parser.ParseModule(); |
| ICHECK(mod.defined()) << "The parser must return a non-null module."; |
| // NB(@jroesch): it is very important that we render any errors before we proceed |
| // if there were any errors which allow the parser to proceed we must render them |
| // here. |
| parser.diag_ctx.Render(); |
| auto infer_type = tvm::relay::transform::InferType(); |
| ICHECK(infer_type.defined()) << "The type inferencer must be non-null."; |
| return infer_type(mod); |
| } |
| |
| Expr ParseExpr(const std::string& file_name, const std::string& file_content) { |
| VLOG(9) << "ParseExpr"; |
| auto parser = InitParser(file_name, file_content, Optional<IRModule>(), MetaTable()); |
| parser.ParseSemVer(false); |
| parser.PushScope(); |
| auto expr = parser.ParseExpr(); |
| parser.Match(TokenType::kEndOfFile); |
| // NB(@jroesch): it is very important that we render any errors before we proceed |
| // if there were any errors which allow the parser to proceed we must render them |
| // here. |
| parser.diag_ctx.Render(); |
| return expr; |
| } |
| |
| TVM_REGISTER_GLOBAL("parser.ParseModuleInContext") |
| .set_body_typed([](const std::string& file_name, const std::string& file_content, |
| const Optional<IRModule>& init_module, const MetaTable& init_meta_table) { |
| return ParseModule(file_name, file_content, init_module, init_meta_table); |
| }); |
| |
| TVM_REGISTER_GLOBAL("parser.ParseModule") |
| .set_body_typed([](const std::string& file_name, const std::string& file_content) { |
| return ParseModule(file_name, file_content); |
| }); |
| |
| TVM_REGISTER_GLOBAL("parser.ParseExpr") |
| .set_body_typed([](tvm::String file_name, tvm::String file_content) { |
| return ParseExpr(file_name, file_content); |
| }); |
| |
| /*! |
| * \brief This pass pretty-prints mod then parses it back so as to establish spans and sources |
| * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for |
| * modules constructed programaticaly rather than textually. |
| */ |
| Pass AnnotateSpans() { |
| auto pass_func = [](const IRModule& mod, const PassContext& ctx) { |
| String text = AsText(mod, /*show_meta_data=*/true); |
| VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text; |
| return ParseModule("GeneratedSource", text); |
| }; |
| return CreateModulePass(pass_func, 0, "AnnotateSpans", {}); |
| } |
| |
| TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans); |
| |
| } // namespace parser |
| } // namespace tvm |