blob: f3e0edab6e076bd16074f38f7fd2bc453509156a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/repr_printer.h
* \brief Printer class to print repr string of each AST/IR nodes.
*/
#ifndef TVM_NODE_REPR_PRINTER_H_
#define TVM_NODE_REPR_PRINTER_H_
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/node/functor.h>
#include <tvm/node/script_printer.h>
#include <iostream>
#include <string>
namespace tvm {
/*!
 * \brief A printer class to print the AST/IR nodes.
 *
 * Per-type printing is registered in the functor table returned by vtable().
 * NOTE(review): dispatch is presumably on the node's runtime type via NodeFunctor;
 * confirm against tvm/node/functor.h.
 */
class ReprPrinter {
public:
/*! \brief The output stream all printing is written to (held by reference, not owned). */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};
/*!
 * \brief Construct a printer that writes to the given stream.
 * \param stream The output stream; it must outlive this printer since only a reference is kept.
 */
explicit ReprPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}
/*!
 * \brief Print an object reference to the stream.
 * \param node The node to be printed.
 */
TVM_DLL void Print(const ObjectRef& node);
/*!
 * \brief Print an ffi::Any value to the stream.
 * \param node The value to be printed.
 */
TVM_DLL void Print(const ffi::Any& node);
/*!
 * \brief Print indent to the stream.
 * NOTE(review): presumably emits whitespace derived from `indent`; confirm in the implementation.
 */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
// Functor type used to register a print function for a node type.
using FType = NodeFunctor<void(const ObjectRef&, ReprPrinter*)>;
/*! \return The global table of registered print functions. */
TVM_DLL static FType& vtable();
};
/*!
 * \brief Dump the node to stderr, used for debug purposes.
 * \param node The input node, given as an object reference.
 */
TVM_DLL void Dump(const runtime::ObjectRef& node);
/*!
 * \brief Dump the node to stderr, used for debug purposes.
 * \param node The input node, given as a raw (non-owning) object pointer.
 * NOTE(review): handling of a null pointer is defined by the implementation; confirm before relying on it.
 */
TVM_DLL void Dump(const runtime::Object* node);
} // namespace tvm
namespace tvm {
namespace ffi {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
ReprPrinter(os).Print(n);
return os;
}
// default print function for any
inline std::ostream& operator<<(std::ostream& os, const Any& n) { // NOLINT(*)
ReprPrinter(os).Print(n);
return os;
}
/*!
 * \brief Stream operator for ffi::Variant.
 *
 * The variant is boxed into an Any and printed through the Any overload of
 * ReprPrinter::Print.
 * \param os The output stream.
 * \param n The variant to print.
 * \return The same output stream, for chaining.
 */
template <typename... V>
inline std::ostream& operator<<(std::ostream& os, const ffi::Variant<V...>& n) {  // NOLINT(*)
  const Any boxed(n);
  ReprPrinter printer(os);
  printer.Print(boxed);
  return os;
}
namespace reflection {
/*!
 * \brief Print a single access step in path notation.
 *
 * Output per kind:
 *  - kAttr:             .name
 *  - kArrayItem:        [index]
 *  - kMapItem:          [key]
 *  - kAttrMissing:      .<missing attr `name`>
 *  - kArrayItemMissing: [<missing item at index>]
 *  - kMapItemMissing:   [<missing item at key>]
 * An unknown kind is reported through LOG(FATAL).
 * \param os The output stream.
 * \param step The access step to print.
 * \return The same output stream, for chaining.
 */
inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) {
  namespace refl = ffi::reflection;
  switch (step->kind) {
    case refl::AccessKind::kAttr: {
      os << '.' << step->key.cast<ffi::String>();
      return os;
    }
    case refl::AccessKind::kArrayItem: {
      os << "[" << step->key.cast<int64_t>() << "]";
      return os;
    }
    case refl::AccessKind::kMapItem: {
      os << "[" << step->key << "]";
      return os;
    }
    case refl::AccessKind::kAttrMissing: {
      // Quote the attribute name with a matched pair of backticks; previously
      // only the closing backtick was emitted.
      os << ".<missing attr `" << step->key.cast<ffi::String>() << "`>";
      return os;
    }
    case refl::AccessKind::kArrayItemMissing: {
      os << "[<missing item at " << step->key.cast<int64_t>() << ">]";
      return os;
    }
    case refl::AccessKind::kMapItemMissing: {
      os << "[<missing item at " << step->key << ">]";
      return os;
    }
    default: {
      LOG(FATAL) << "Unknown access step kind: " << static_cast<int>(step->kind);
    }
  }
  // Reached only if the default branch returns; keeps every path returning a value.
  return os;
}
inline std::ostream& operator<<(std::ostream& os, const AccessPath& path) {
ffi::Array<AccessStep> steps = path->ToSteps();
os << "<root>";
for (const auto& step : steps) {
os << step;
}
return os;
}
} // namespace reflection
} // namespace ffi
} // namespace tvm
#endif // TVM_NODE_REPR_PRINTER_H_