blob: a28967cb419427b1ab25f1abbd00749d334989d3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"
namespace tvm {
namespace script {
namespace printer {
ffi::Array<StmtDoc> PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p,
const IRDocsifier& d, bool use_ret) {
With<RelaxFrame> f(d);
const ffi::Array<relax::BindingBlock>& blocks = n->blocks;
AccessPath blocks_p = n_p->Attr("blocks");
ffi::Array<StmtDoc>* stmts = &(*f)->stmts;
for (int i = 0, l = blocks.size(); i < l; ++i) {
Doc block = d->AsDoc(blocks[i], blocks_p->ArrayItem(i));
if (const auto* stmt_block = block.as<StmtBlockDocNode>()) {
stmts->insert(stmts->end(), stmt_block->stmts.begin(), stmt_block->stmts.end());
} else if (const auto* stmt = block.as<StmtDocNode>()) {
stmts->push_back(ffi::GetRef<StmtDoc>(stmt));
} else {
LOG(FATAL) << "TypeError: Unknown type: " << block->GetTypeKey();
}
}
ExprDoc ret = d->AsDoc<ExprDoc>(n->body, n_p->Attr("body"));
if (use_ret) {
stmts->push_back(ReturnDoc(ret));
} else {
stmts->push_back(ExprStmtDoc(ret));
}
return *stmts;
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::SeqExpr>("", [](relax::SeqExpr n, AccessPath n_p, IRDocsifier d) -> Doc {
return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false));
});
ffi::Array<StmtDoc> PrintBindingBlock(const relax::BindingBlock& n, const AccessPath& n_p,
const IRDocsifier& d,
ffi::Array<ExprDoc>* non_dataflow_vars) {
const ffi::Array<relax::Binding>& bindings = n->bindings;
AccessPath bindings_p = n_p->Attr("bindings");
ffi::Array<StmtDoc> stmts;
for (int i = 0, l = bindings.size(); i < l; ++i) {
const relax::Binding& binding = bindings[i];
AccessPath binding_p = bindings_p->ArrayItem(i);
ICHECK(binding->var.defined());
Doc binding_doc = d->AsDoc(binding, binding_p);
if (const auto* stmt = binding_doc.as<StmtDocNode>()) {
stmts.push_back(ffi::GetRef<StmtDoc>(stmt));
} else if (const auto* stmt_block = binding_doc.as<StmtBlockDocNode>()) {
stmts.insert(stmts.end(), stmt_block->stmts.begin(), stmt_block->stmts.end());
} else {
LOG(FATAL) << "TypeError: Unknown type: " << binding_doc->GetTypeKey();
}
if (non_dataflow_vars != nullptr && !binding->var->IsInstance<relax::DataflowVarNode>()) {
non_dataflow_vars->push_back(d->AsDoc<ExprDoc>(binding->var, binding_p->Attr("var")));
}
}
return stmts;
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::BindingBlock>( //
"", [](relax::BindingBlock n, AccessPath n_p, IRDocsifier d) -> Doc {
return StmtBlockDoc(PrintBindingBlock(n, n_p, d, nullptr));
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::DataflowBlock>( //
"", [](relax::DataflowBlock n, AccessPath n_p, IRDocsifier d) -> Doc {
ffi::Array<ExprDoc> non_dataflow_vars;
ffi::Array<StmtDoc> stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars);
stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars)));
return ScopeDoc(std::nullopt, Relax(d, "dataflow")->Call({}), stmts);
});
TVM_SCRIPT_REPR(relax::SeqExprNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::BindingBlockNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::DataflowBlockNode, ReprPrintRelax);
} // namespace printer
} // namespace script
} // namespace tvm