| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| #include <tvm/ffi/function.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/runtime/logging.h> |
| #include <tvm/script/printer/doc.h> |
| |
| #include <algorithm> |
| #include <cmath> |
| #include <string> |
| |
| #include "../../../support/str_escape.h" |
| #include "../../../support/utils.h" |
| #include "./base_doc_printer.h" |
| |
| namespace tvm { |
| namespace script { |
| namespace printer { |
| |
/*!
 * \brief Binding strength of Python expression forms, listed from weakest to strongest.
 *
 * The ordering follows the table in
 * https://docs.python.org/3/reference/expressions.html#operator-precedence
 * so a numerically larger enumerator binds more tightly.
 */
enum class ExprPrecedence : int32_t {
  kUnkown = 0,      /*!< No precedence assigned */
  kLambda = 1,      /*!< `lambda` expression */
  kIfThenElse = 2,  /*!< Conditional expression (`x if c else y`) */
  kBooleanOr = 3,   /*!< `or` */
  kBooleanAnd = 4,  /*!< `and` */
  kBooleanNot = 5,  /*!< `not` */
  kComparison = 6,  /*!< Comparison operators */
  kBitwiseOr = 7,   /*!< `|` */
  kBitwiseXor = 8,  /*!< `^` */
  kBitwiseAnd = 9,  /*!< `&` */
  kShift = 10,      /*!< `<<` and `>>` */
  kAdd = 11,        /*!< `+` and `-` (binary) */
  kMult = 12,       /*!< `*`, `/`, `//` and `%` */
  kUnary = 13,      /*!< Unary `+`, unary `-` and `~` */
  kExp = 14,        /*!< `**` */
  kIdentity = 15,   /*!< Indexing, attribute access, calls and atoms */
};
| |
| ExprPrecedence GetExprPrecedence(const ExprDoc& doc) { |
| // Key is the value of OperationDocNode::Kind |
| static const std::vector<ExprPrecedence> op_kind_precedence = []() { |
| using OpKind = OperationDocNode::Kind; |
| std::map<OpKind, ExprPrecedence> raw_table = { |
| {OpKind::kUSub, ExprPrecedence::kUnary}, |
| {OpKind::kInvert, ExprPrecedence::kUnary}, |
| {OpKind::kNot, ExprPrecedence::kBooleanNot}, |
| {OpKind::kAdd, ExprPrecedence::kAdd}, |
| {OpKind::kSub, ExprPrecedence::kAdd}, |
| {OpKind::kMult, ExprPrecedence::kMult}, |
| {OpKind::kDiv, ExprPrecedence::kMult}, |
| {OpKind::kFloorDiv, ExprPrecedence::kMult}, |
| {OpKind::kMod, ExprPrecedence::kMult}, |
| {OpKind::kPow, ExprPrecedence::kExp}, |
| {OpKind::kLShift, ExprPrecedence::kShift}, |
| {OpKind::kRShift, ExprPrecedence::kShift}, |
| {OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd}, |
| {OpKind::kBitOr, ExprPrecedence::kBitwiseOr}, |
| {OpKind::kBitXor, ExprPrecedence::kBitwiseXor}, |
| {OpKind::kLt, ExprPrecedence::kComparison}, |
| {OpKind::kLtE, ExprPrecedence::kComparison}, |
| {OpKind::kEq, ExprPrecedence::kComparison}, |
| {OpKind::kNotEq, ExprPrecedence::kComparison}, |
| {OpKind::kGt, ExprPrecedence::kComparison}, |
| {OpKind::kGtE, ExprPrecedence::kComparison}, |
| {OpKind::kAnd, ExprPrecedence::kBooleanAnd}, |
| {OpKind::kOr, ExprPrecedence::kBooleanOr}, |
| {OpKind::kIfThenElse, ExprPrecedence::kIfThenElse}, |
| }; |
| int n = static_cast<int>(OpKind::kSpecialEnd); |
| std::vector<ExprPrecedence> table(n + 1, ExprPrecedence::kUnkown); |
| for (const auto& kv : raw_table) { |
| table[static_cast<int>(kv.first)] = kv.second; |
| } |
| return table; |
| }(); |
| |
| // Key is the type index of Doc |
| static const std::unordered_map<uint32_t, ExprPrecedence> doc_type_precedence = { |
| {LiteralDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
| {IdDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
| {AttrAccessDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
| {IndexDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
| {CallDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
| {LambdaDocNode::RuntimeTypeIndex(), ExprPrecedence::kLambda}, |
| {TupleDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
| {ListDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
| {DictDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
| }; |
| |
| if (const auto* op_doc = doc.as<OperationDocNode>()) { |
| size_t kind = static_cast<int>(op_doc->kind); |
| ICHECK_LT(kind, op_kind_precedence.size()) << "ValueError: Invalid operation: " << kind; |
| ExprPrecedence precedence = op_kind_precedence[kind]; |
| ICHECK(precedence != ExprPrecedence::kUnkown) |
| << "Precedence for operator " << static_cast<int>(op_doc->kind) << " is unknown"; |
| return precedence; |
| } |
| auto it = doc_type_precedence.find(doc->type_index()); |
| if (it != doc_type_precedence.end()) { |
| return it->second; |
| } |
| ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is unknown"; |
| throw; |
| } |
| |
| class PythonDocPrinter : public DocPrinter { |
| public: |
| explicit PythonDocPrinter(const PrinterConfig& options) : DocPrinter(options) {} |
| |
| protected: |
| using DocPrinter::PrintDoc; |
| |
| void PrintTypedDoc(const LiteralDoc& doc) final; |
| void PrintTypedDoc(const IdDoc& doc) final; |
| void PrintTypedDoc(const AttrAccessDoc& doc) final; |
| void PrintTypedDoc(const IndexDoc& doc) final; |
| void PrintTypedDoc(const OperationDoc& doc) final; |
| void PrintTypedDoc(const CallDoc& doc) final; |
| void PrintTypedDoc(const LambdaDoc& doc) final; |
| void PrintTypedDoc(const ListDoc& doc) final; |
| void PrintTypedDoc(const DictDoc& doc) final; |
| void PrintTypedDoc(const TupleDoc& doc) final; |
| void PrintTypedDoc(const SliceDoc& doc) final; |
| void PrintTypedDoc(const StmtBlockDoc& doc) final; |
| void PrintTypedDoc(const AssignDoc& doc) final; |
| void PrintTypedDoc(const IfDoc& doc) final; |
| void PrintTypedDoc(const WhileDoc& doc) final; |
| void PrintTypedDoc(const ForDoc& doc) final; |
| void PrintTypedDoc(const ExprStmtDoc& doc) final; |
| void PrintTypedDoc(const AssertDoc& doc) final; |
| void PrintTypedDoc(const ReturnDoc& doc) final; |
| void PrintTypedDoc(const ScopeDoc& doc) final; |
| void PrintTypedDoc(const FunctionDoc& doc) final; |
| void PrintTypedDoc(const ClassDoc& doc) final; |
| void PrintTypedDoc(const CommentDoc& doc) final; |
| void PrintTypedDoc(const DocStringDoc& doc) final; |
| |
| private: |
| void NewLineWithoutIndent() { |
| size_t start_pos = output_.tellp(); |
| output_ << "\n"; |
| size_t end_pos = output_.tellp(); |
| underlines_exempted_.push_back({start_pos, end_pos}); |
| } |
| |
| template <typename DocType> |
| void PrintJoinedDocs(const ffi::Array<DocType>& docs, const std::string& separator) { |
| bool is_first = true; |
| for (auto& doc : docs) { |
| if (is_first) { |
| is_first = false; |
| } else { |
| output_ << separator; |
| } |
| PrintDoc(doc); |
| } |
| } |
| |
| void PrintIndentedBlock(const ffi::Array<StmtDoc>& docs) { |
| IncreaseIndent(); |
| for (const StmtDoc& d : docs) { |
| NewLine(); |
| PrintDoc(d); |
| } |
| if (docs.empty()) { |
| NewLine(); |
| output_ << "pass"; |
| } |
| DecreaseIndent(); |
| } |
| |
| void PrintDecorators(const ffi::Array<ExprDoc>& decorators) { |
| for (const ExprDoc& decorator : decorators) { |
| output_ << "@"; |
| PrintDoc(decorator); |
| NewLine(); |
| } |
| } |
| |
| /*! |
| * \brief Print expression and add parenthesis if needed. |
| */ |
| void PrintChildExpr(const ExprDoc& doc, ExprPrecedence parent_precedence, |
| bool parenthesis_for_same_precedence = false) { |
| ExprPrecedence doc_precedence = GetExprPrecedence(doc); |
| if (doc_precedence < parent_precedence || |
| (parenthesis_for_same_precedence && doc_precedence == parent_precedence)) { |
| output_ << "("; |
| PrintDoc(doc); |
| output_ << ")"; |
| } else { |
| PrintDoc(doc); |
| } |
| } |
| |
| /*! |
| * \brief Print expression and add parenthesis if doc has lower precedence than parent. |
| */ |
| void PrintChildExpr(const ExprDoc& doc, const ExprDoc& parent, |
| bool parenthesis_for_same_precedence = false) { |
| ExprPrecedence parent_precedence = GetExprPrecedence(parent); |
| return PrintChildExpr(doc, parent_precedence, parenthesis_for_same_precedence); |
| } |
| |
| /*! |
| * \brief Print expression and add parenthesis if doc doesn't have higher precedence than parent. |
| * |
| * This function should be used to print an child expression that needs to be wrapped |
| * by parenthesis even if it has the same precedence as its parent, e.g., the `b` in `a + b` |
| * and the `b` and `c` in `a if b else c`. |
| */ |
| void PrintChildExprConservatively(const ExprDoc& doc, const ExprDoc& parent) { |
| PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence=*/true); |
| } |
| |
| void MaybePrintCommentInline(const StmtDoc& stmt) { |
| if (stmt->comment.has_value()) { |
| const std::string& comment = stmt->comment.value(); |
| bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end(); |
| CHECK(!has_newline) << "ValueError: the comment string of " << stmt->GetTypeKey() |
| << " cannot have newline."; |
| size_t start_pos = output_.tellp(); |
| output_ << " # " << comment; |
| size_t end_pos = output_.tellp(); |
| underlines_exempted_.push_back({start_pos, end_pos}); |
| } |
| } |
| |
| void MaybePrintCommenMultiLines(const StmtDoc& stmt, bool new_line = false) { |
| if (stmt->comment.has_value()) { |
| std::vector<std::string> comment_lines = support::Split(stmt->comment.value(), '\n'); |
| bool first_line = true; |
| size_t start_pos = output_.tellp(); |
| for (const std::string& line : comment_lines) { |
| if (first_line) { |
| output_ << "# " << line; |
| first_line = false; |
| } else { |
| NewLine() << "# " << line; |
| } |
| } |
| size_t end_pos = output_.tellp(); |
| underlines_exempted_.push_back({start_pos, end_pos}); |
| if (new_line) { |
| NewLine(); |
| } |
| } |
| } |
| |
| void PrintDocString(const ffi::String& comment) { |
| size_t start_pos = output_.tellp(); |
| output_ << "\"\"\""; |
| |
| std::vector<std::string> comment_lines = support::Split(comment, '\n'); |
| for (const std::string& line : comment_lines) { |
| if (line.empty()) { |
| // No indentation on empty line |
| output_ << "\n"; |
| } else { |
| NewLine() << line; |
| } |
| } |
| |
| NewLine() << "\"\"\""; |
| size_t end_pos = output_.tellp(); |
| underlines_exempted_.push_back({start_pos, end_pos}); |
| } |
| |
| void PrintBlockComment(const ffi::String& comment) { |
| IncreaseIndent(); |
| NewLine(); |
| PrintDocString(comment); |
| DecreaseIndent(); |
| } |
| }; |
| |
| void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { |
| const ffi::Any& value = doc->value; |
| if (value == nullptr) { |
| output_ << "None"; |
| } else if (const auto* int_imm = value.as<IntImmNode>()) { |
| if (int_imm->dtype.is_bool()) { |
| output_ << (int_imm->value ? "True" : "False"); |
| } else { |
| output_ << int_imm->value; |
| } |
| } else if (const auto* float_imm = value.as<FloatImmNode>()) { |
| // TODO(yelite): Make float number printing roundtrippable |
| if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { |
| output_ << '"' << float_imm->value << '"'; |
| } else if (std::nearbyint(float_imm->value) == float_imm->value) { |
| // Special case for floating-point values which would be |
| // formatted using %g, are not displayed in scientific |
| // notation, and whose fractional part is zero. |
| // |
| // By default, using `operator<<(std::ostream&, double)` |
| // delegates to the %g printf formatter. This strips off any |
| // trailing zeros, and also strips the decimal point if no |
| // trailing zeros are found. When parsed in python, due to the |
| // missing decimal point, this would incorrectly convert a float |
| // to an integer. Providing the `std::showpoint` modifier |
| // instead delegates to the %#g printf formatter. On its own, |
| // this resolves the round-trip errors, but also prevents the |
| // trailing zeros from being stripped off. |
| std::showpoint(output_); |
| std::fixed(output_); |
| output_.precision(1); |
| output_ << float_imm->value; |
| } else { |
| std::defaultfloat(output_); |
| std::noshowpoint(output_); |
| output_.precision(17); |
| output_ << float_imm->value; |
| } |
| |
| } else if (const auto opt_str = value.as<ffi::String>()) { |
| output_ << "\"" << support::StrEscape((*opt_str).data(), (*opt_str).size()) << "\""; |
| } else { |
| LOG(FATAL) << "TypeError: Unsupported literal value type: " << value.GetTypeKey(); |
| } |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; } |
| |
| void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { |
| PrintChildExpr(doc->value, doc); |
| output_ << "." << doc->name; |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { |
| PrintChildExpr(doc->value, doc); |
| if (doc->indices.size() == 0) { |
| output_ << "[()]"; |
| } else { |
| output_ << "["; |
| PrintJoinedDocs(doc->indices, ", "); |
| output_ << "]"; |
| } |
| } |
| |
| const std::string OperatorToString(OperationDocNode::Kind operation_kind) { |
| static const std::vector<std::string> op_kind2str = []() { |
| using OpKind = OperationDocNode::Kind; |
| std::map<OpKind, std::string> raw_table = { |
| {OpKind::kUSub, "-"}, // |
| {OpKind::kInvert, "~"}, // |
| {OpKind::kNot, "not "}, // |
| {OpKind::kAdd, "+"}, // |
| {OpKind::kSub, "-"}, // |
| {OpKind::kMult, "*"}, // |
| {OpKind::kDiv, "/"}, // |
| {OpKind::kFloorDiv, "//"}, // |
| {OpKind::kMod, "%"}, // |
| {OpKind::kPow, "**"}, // |
| {OpKind::kLShift, "<<"}, // |
| {OpKind::kRShift, ">>"}, // |
| {OpKind::kBitAnd, "&"}, // |
| {OpKind::kBitOr, "|"}, // |
| {OpKind::kBitXor, "^"}, // |
| {OpKind::kLt, "<"}, // |
| {OpKind::kLtE, "<="}, // |
| {OpKind::kEq, "=="}, // |
| {OpKind::kNotEq, "!="}, // |
| {OpKind::kGt, ">"}, // |
| {OpKind::kGtE, ">="}, // |
| {OpKind::kAnd, "and"}, // |
| {OpKind::kOr, "or"}, // |
| }; |
| |
| std::vector<std::string> table; |
| table.resize(static_cast<int>(OperationDocNode::Kind::kSpecialEnd) + 1); |
| |
| for (const auto& kv : raw_table) { |
| table[static_cast<int>(kv.first)] = kv.second; |
| } |
| |
| return table; |
| }(); |
| |
| auto op_index = static_cast<int>(operation_kind); |
| ICHECK_LT(op_index, op_kind2str.size()); |
| const std::string str = op_kind2str[op_index]; |
| ICHECK(!str.empty()) << "OperationDocNode::Kind " << static_cast<int>(operation_kind) |
| << " cannot be converted to operator token in Python directly."; |
| return str; |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) { |
| using OpKind = OperationDocNode::Kind; |
| if (doc->kind < OpKind::kUnaryEnd) { |
| // Unary Operators |
| ICHECK_EQ(doc->operands.size(), 1); |
| output_ << OperatorToString(doc->kind); |
| PrintChildExpr(doc->operands[0], doc); |
| } else if (doc->kind == OpKind::kPow) { |
| // Power operator is different than other binary operators |
| // It's right-associative and binds less tightly than unary operator on its right. |
| // https://docs.python.org/3/reference/expressions.html#the-power-operator |
| // https://docs.python.org/3/reference/expressions.html#operator-precedence |
| ICHECK_EQ(doc->operands.size(), 2); |
| PrintChildExprConservatively(doc->operands[0], doc); |
| output_ << " ** "; |
| PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary); |
| } else if (doc->kind < OpKind::kBinaryEnd) { |
| // Binary Operator |
| ICHECK_EQ(doc->operands.size(), 2); |
| PrintChildExpr(doc->operands[0], doc); |
| output_ << " " << OperatorToString(doc->kind) << " "; |
| PrintChildExprConservatively(doc->operands[1], doc); |
| } else if (doc->kind == OpKind::kIfThenElse) { |
| ICHECK_EQ(doc->operands.size(), 3) |
| << "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size(); |
| PrintChildExpr(doc->operands[1], doc); |
| output_ << " if "; |
| PrintChildExprConservatively(doc->operands[0], doc); |
| output_ << " else "; |
| PrintChildExprConservatively(doc->operands[2], doc); |
| } else { |
| LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast<int>(doc->kind); |
| throw; |
| } |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { |
| PrintChildExpr(doc->callee, doc); |
| |
| output_ << "("; |
| |
| // Print positional args |
| bool is_first = true; |
| for (const ExprDoc& arg : doc->args) { |
| if (is_first) { |
| is_first = false; |
| } else { |
| output_ << ", "; |
| } |
| PrintDoc(arg); |
| } |
| |
| // Print keyword args |
| ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) |
| << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values."; |
| for (size_t i = 0; i < doc->kwargs_keys.size(); i++) { |
| if (is_first) { |
| is_first = false; |
| } else { |
| output_ << ", "; |
| } |
| const ffi::String& keyword = doc->kwargs_keys[i]; |
| output_ << keyword; |
| output_ << "="; |
| PrintDoc(doc->kwargs_values[i]); |
| } |
| |
| output_ << ")"; |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) { |
| output_ << "lambda "; |
| PrintJoinedDocs(doc->args, ", "); |
| output_ << ": "; |
| PrintChildExpr(doc->body, doc); |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) { |
| output_ << "["; |
| PrintJoinedDocs(doc->elements, ", "); |
| output_ << "]"; |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const TupleDoc& doc) { |
| output_ << "("; |
| if (doc->elements.size() == 1) { |
| PrintDoc(doc->elements[0]); |
| output_ << ","; |
| } else { |
| PrintJoinedDocs(doc->elements, ", "); |
| } |
| output_ << ")"; |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const DictDoc& doc) { |
| ICHECK_EQ(doc->keys.size(), doc->values.size()) |
| << "DictDoc should have equal number of elements in keys and values."; |
| output_ << "{"; |
| size_t idx = 0; |
| for (const ExprDoc& key : doc->keys) { |
| if (idx > 0) { |
| output_ << ", "; |
| } |
| PrintDoc(key); |
| output_ << ": "; |
| PrintDoc(doc->values[idx]); |
| idx++; |
| } |
| output_ << "}"; |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) { |
| if (doc->start != nullptr) { |
| PrintDoc(doc->start.value()); |
| } |
| output_ << ":"; |
| if (doc->stop != nullptr) { |
| PrintDoc(doc->stop.value()); |
| } |
| if (doc->step != nullptr) { |
| output_ << ":"; |
| PrintDoc(doc->step.value()); |
| } |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) { |
| for (const StmtDoc& stmt : doc->stmts) { |
| PrintDoc(stmt); |
| if (stmt != doc->stmts.back()) { |
| NewLine(); |
| } |
| } |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) { |
| if (const auto* tuple_doc = doc->lhs.as<TupleDocNode>()) { |
| PrintJoinedDocs(tuple_doc->elements, ", "); |
| } else { |
| PrintDoc(doc->lhs); |
| } |
| |
| if (doc->annotation) { |
| output_ << ": "; |
| PrintDoc(doc->annotation.value()); |
| } |
| if (doc->rhs) { |
| output_ << " = "; |
| if (const auto* tuple_doc = doc->rhs.as<TupleDocNode>()) { |
| if (tuple_doc->elements.size() > 1) { |
| PrintJoinedDocs(tuple_doc->elements, ", "); |
| } else { |
| PrintDoc(doc->rhs.value()); |
| } |
| } else { |
| PrintDoc(doc->rhs.value()); |
| } |
| } |
| MaybePrintCommentInline(doc); |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) { |
| MaybePrintCommenMultiLines(doc, true); |
| output_ << "if "; |
| PrintDoc(doc->predicate); |
| output_ << ":"; |
| |
| PrintIndentedBlock(doc->then_branch); |
| |
| if (!doc->else_branch.empty()) { |
| NewLine(); |
| output_ << "else:"; |
| PrintIndentedBlock(doc->else_branch); |
| } |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) { |
| MaybePrintCommenMultiLines(doc, true); |
| output_ << "while "; |
| PrintDoc(doc->predicate); |
| output_ << ":"; |
| |
| PrintIndentedBlock(doc->body); |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) { |
| MaybePrintCommenMultiLines(doc, true); |
| output_ << "for "; |
| if (const auto* tuple = doc->lhs.as<TupleDocNode>()) { |
| if (tuple->elements.size() == 1) { |
| PrintDoc(tuple->elements[0]); |
| output_ << ","; |
| } else { |
| PrintJoinedDocs(tuple->elements, ", "); |
| } |
| } else { |
| PrintDoc(doc->lhs); |
| } |
| output_ << " in "; |
| PrintDoc(doc->rhs); |
| output_ << ":"; |
| |
| PrintIndentedBlock(doc->body); |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) { |
| MaybePrintCommenMultiLines(doc, true); |
| output_ << "with "; |
| PrintDoc(doc->rhs); |
| if (doc->lhs != nullptr) { |
| output_ << " as "; |
| PrintDoc(doc->lhs.value()); |
| } |
| output_ << ":"; |
| |
| PrintIndentedBlock(doc->body); |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const ExprStmtDoc& doc) { |
| PrintDoc(doc->expr); |
| MaybePrintCommentInline(doc); |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const AssertDoc& doc) { |
| output_ << "assert "; |
| PrintDoc(doc->test); |
| if (doc->msg.defined()) { |
| output_ << ", "; |
| PrintDoc(doc->msg.value()); |
| } |
| MaybePrintCommentInline(doc); |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) { |
| output_ << "return "; |
| PrintDoc(doc->value); |
| MaybePrintCommentInline(doc); |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) { |
| for (const AssignDoc& arg_doc : doc->args) { |
| ICHECK(!arg_doc->comment.has_value()) << "Function arg cannot have comment attached to them."; |
| } |
| |
| PrintDecorators(doc->decorators); |
| |
| output_ << "def "; |
| PrintDoc(doc->name); |
| |
| output_ << "("; |
| PrintJoinedDocs(doc->args, ", "); |
| output_ << ")"; |
| |
| if (doc->return_type.defined()) { |
| output_ << " -> "; |
| PrintDoc(doc->return_type.value()); |
| } |
| |
| output_ << ":"; |
| |
| if (doc->comment.has_value()) { |
| PrintBlockComment(doc->comment.value()); |
| } |
| PrintIndentedBlock(doc->body); |
| NewLineWithoutIndent(); |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) { |
| PrintDecorators(doc->decorators); |
| |
| output_ << "class "; |
| PrintDoc(doc->name); |
| output_ << ":"; |
| |
| if (doc->comment.has_value()) { |
| PrintBlockComment(doc->comment.value()); |
| } |
| PrintIndentedBlock(doc->body); |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) { |
| if (doc->comment.has_value()) { |
| MaybePrintCommenMultiLines(doc, false); |
| } |
| } |
| |
| void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) { |
| if (doc->comment.has_value() && !doc->comment.value().empty()) { |
| PrintDocString(doc->comment.value()); |
| } |
| } |
| |
| ffi::String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { |
| if (cfg->num_context_lines < 0) { |
| cfg->num_context_lines = std::numeric_limits<int32_t>::max(); |
| } |
| PythonDocPrinter printer(cfg); |
| printer.Append(doc, cfg); |
| std::string result = printer.GetString(); |
| int last_space = result.size(); |
| while (last_space > 0 && std::isspace(result[last_space - 1])) { |
| last_space--; |
| } |
| return result.substr(0, last_space); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("script.printer.DocToPythonScript", DocToPythonScript); |
| } |
| |
| } // namespace printer |
| } // namespace script |
| } // namespace tvm |