/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
| #include "./utils.h" |
| |
| namespace tvm { |
| namespace script { |
| namespace printer { |
| |
| IfDoc PrintIfExpr(const relax::If& n, const AccessPath& n_p, const IRDocsifier& d, // |
| const ffi::Optional<ExprDoc>& var, const ffi::Optional<ExprDoc>& ann) { |
| using relax::SeqExpr; |
| ExprDoc cond = d->AsDoc<ExprDoc>(n->cond, n_p->Attr("cond")); |
| std::vector<ffi::Array<StmtDoc>> branches{ |
| PrintSeqExpr(n->true_branch, n_p->Attr("true_branch"), d, false), |
| PrintSeqExpr(n->false_branch, n_p->Attr("false_branch"), d, false), |
| }; |
| if (var.defined()) { |
| for (ffi::Array<StmtDoc>& stmts : branches) { |
| ExprDoc ret = Downcast<ExprStmtDoc>(stmts.back())->expr; |
| stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann)); |
| } |
| } |
| return IfDoc(cond, branches[0], branches[1]); |
| } |
| |
| TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
| .set_dispatch<relax::MatchCast>( |
| "", [](relax::MatchCast n, AccessPath n_p, IRDocsifier d) -> Doc { |
| using relax::StructInfo; |
| using relax::MatchStructInfo; |
| ffi::Optional<ExprDoc> ann = std::nullopt; |
| if (d->cfg->show_all_struct_info) { |
| ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); |
| } |
| ExprDoc rhs = Relax(d, "match_cast") |
| ->Call({d->AsDoc<ExprDoc>(n->value, n_p->Attr("value")), |
| d->AsDoc<ExprDoc>(n->struct_info, n_p->Attr("struct_info_"))}); |
| ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); |
| return AssignDoc(lhs, rhs, ann); |
| }); |
| |
| TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
| .set_dispatch<relax::VarBinding>( // |
| "", [](relax::VarBinding n, AccessPath n_p, IRDocsifier d) -> Doc { |
| if (const auto if_ = n->value.as<relax::IfNode>()) { |
| ffi::Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); |
| ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); |
| return PrintIfExpr(ffi::GetRef<relax::If>(if_), n_p->Attr("value"), d, lhs, ann); |
| } else if (n->value->IsInstance<tvm::BaseFuncNode>() && |
| !n->value->IsInstance<relax::ExternFuncNode>()) { |
| IdDoc lhs = DefineVar(n->var, d->frames.back(), d); |
| d->cfg->binding_names.push_back(lhs->name); |
| Doc ret = d->AsDoc(n->value, n_p->Attr("value")); |
| d->cfg->binding_names.pop_back(); |
| return ret; |
| } else if (d->cfg->syntax_sugar && relax::HasVoidStructInfo(n->value) && |
| relax::HasVoidStructInfo(n->var)) { |
| ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value")); |
| return ExprStmtDoc(rhs); |
| } else { |
| ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value")); |
| ffi::Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); |
| ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); |
| return AssignDoc(lhs, rhs, ann); |
| } |
| }); |
| |
| TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
| .set_dispatch<relax::If>("", [](relax::If n, AccessPath n_p, IRDocsifier d) -> Doc { |
| return PrintIfExpr(n, n_p, d, std::nullopt, std::nullopt); |
| }); |
| |
| TVM_SCRIPT_REPR(relax::MatchCastNode, ReprPrintRelax); |
| TVM_SCRIPT_REPR(relax::VarBindingNode, ReprPrintRelax); |
| TVM_SCRIPT_REPR(relax::IfNode, ReprPrintRelax); |
| |
| } // namespace printer |
| } // namespace script |
| } // namespace tvm |