blob: a8bb6c26083fd3671bce66239d6581334971723c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/runtime/contrib/json/json_runtime.h
* \brief Utilities for json runtime.
*/
#ifndef TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_
#define TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_
#include <tvm/ffi/extra/module.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/tensor.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "json_node.h"
namespace tvm {
namespace runtime {
namespace json {
/*!
* \brief A json runtime that executes the serialized JSON format. This runtime
* can be extended by user defined runtime for execution.
*/
class JSONRuntimeBase : public ffi::ModuleObj {
public:
JSONRuntimeBase(const std::string& symbol_name, const std::string& graph_json,
const ffi::Array<ffi::String> const_names)
: symbol_name_(symbol_name), graph_json_(graph_json), const_names_(const_names) {
LoadGraph(graph_json_);
}
const char* kind() const override { return "json"; } // May be overridden
/*! \brief Get the property of the runtime module .*/
int GetPropertyMask() const override {
return ffi::Module::kBinarySerializable | ffi::Module::kRunnable;
}
/*! \brief Initialize a specific json runtime. */
virtual void Init(const ffi::Array<Tensor>& consts) = 0;
/*! \brief Invoke the execution engine to inteprete a specific json runtime. */
virtual void Run() = 0;
/*! \brief Does the backend support debug & profiling */
virtual bool CanDebug() { return false; }
/*!
* \brief Invoke the profiler
* \param pointer to profiler
*/
virtual void RunProfile(profiling::Profiler* prof) {
LOG(FATAL) << "Not expected to be here : Profiling call w/o support ?";
}
/*!
* \brief Invoke the debugger
* \return External compiler specific debug blob
*/
virtual std::string DebugDump(void) {
LOG(FATAL) << "Not expected to be here : Debug dump w/o support ?";
}
/*!
* \brief Get a packed function.
* \param name The name/symbol of the function.
* \param sptr_to_self The pointer to the module node.
* \return The packed function.
*/
ffi::Optional<ffi::Function> GetFunction(const ffi::String& name) override {
ObjectPtr<Object> sptr_to_self = ffi::GetObjectPtr<Object>(this);
if (name == "get_symbol") {
return ffi::Function(
[sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; });
} else if (name == "get_const_vars") {
return ffi::Function(
[sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->const_names_; });
} else if (this->symbol_name_ == name) {
return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) {
ICHECK(this->initialized_) << "The module has not been initialized";
// Bind argument tensors to data entries.
this->SetInputOutputBuffers(args);
// Execute the subgraph.
this->Run();
});
} else if (this->symbol_name_ + "_debug" == name) {
// NOTE: the current debug convention is not very compatible with
// the FFI convention, consider clean up
if (!this->CanDebug()) {
return ffi::Function(nullptr);
}
return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) {
ICHECK(this->initialized_) << "The module has not been initialized";
// Bind argument tensors to data entries.
this->SetInputOutputBuffers(args);
if (auto opt_str = rv->try_cast<ffi::String>()) {
ffi::String purpose = std::move(opt_str.value());
if ("debug_dump" == purpose) {
*rv = this->DebugDump();
}
} else {
// Profile the subgraph.
profiling::Profiler* prof = static_cast<profiling::Profiler*>(rv->cast<void*>());
this->RunProfile(prof);
}
// ffi::String vendor_prof = this->RunProfile(prof);
});
} else if ("__init_" + this->symbol_name_ == name) {
// The function to initialize constant tensors.
return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) {
ICHECK_EQ(args.size(), 1U);
std::lock_guard<std::mutex> guard(this->initialize_mutex_);
if (!this->initialized_) {
this->Init(args[0].cast<ffi::Array<Tensor>>());
this->initialized_ = true;
}
*rv = 0;
});
} else {
return std::nullopt;
}
}
ffi::Bytes SaveToBytes() const override {
std::string buffer;
dmlc::MemoryStringStream ms(&buffer);
dmlc::Stream* stream = &ms;
// Save the symbol
stream->Write(symbol_name_);
// Save the graph
stream->Write(graph_json_);
// Save the required const names
std::vector<std::string> consts;
for (const auto& it : const_names_) {
consts.push_back(it);
}
stream->Write(consts);
return ffi::Bytes(buffer);
}
template <typename T,
typename = typename std::enable_if<std::is_base_of<JSONRuntimeBase, T>::value>::type>
static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) {
dmlc::MemoryFixedSizeStream ms(const_cast<char*>(bytes.data()), bytes.size());
dmlc::Stream* stream = &ms;
std::string symbol;
std::string graph_json;
std::vector<std::string> consts;
// Load the symbol
ICHECK(stream->Read(&symbol)) << "Loading symbol name failed";
ICHECK(stream->Read(&graph_json)) << "Loading graph json failed";
ICHECK(stream->Read(&consts)) << "Loading the const name list failed";
ffi::Array<ffi::String> const_names;
for (const auto& it : consts) {
const_names.push_back(it);
}
auto n = ffi::make_object<T>(symbol, graph_json, const_names);
return ffi::Module(n);
}
/*!
* \brief Get the JSON generated by codegen.
*
* \param format the format to return.
* \return A string of JSON.
*/
ffi::String InspectSource(const ffi::String& format) const override { return graph_json_; }
protected:
/*!
* \brief Set up the input and output buffers by binding their DLTensor pointers to the
* corresponding data entry.
*
* \param args The packed args.
*/
void SetInputOutputBuffers(const ffi::PackedArgs& args) {
ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size())
<< "Found mismatch in the number of provided data entryies and required.";
for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
: EntryID(outputs_[i - input_var_eid_.size()]);
const DLTensor* arg;
if (auto opt_nd = args[i].as<Tensor>()) {
Tensor arr = opt_nd.value();
arg = arr.operator->();
} else {
arg = args[i].cast<DLTensor*>();
}
// Assign input/output the Tensor pointers to data entry so that we can directly
// read/write host buffers.
data_entry_[eid] = arg;
}
}
/*!
* \brief Load the graph and record the entries for inputs and constants.
*
* \param graph_json The graph in the json format.
*/
void LoadGraph(const std::string& graph_json) {
std::istringstream is(graph_json);
dmlc::JSONReader reader(&is);
this->Load(&reader);
std::vector<std::string> consts;
for (size_t i = 0; i < input_nodes_.size(); i++) {
uint32_t nid = input_nodes_[i];
std::string name = nodes_[nid].name_;
if (nodes_[nid].op_type_ == "input") {
ICHECK_EQ(nodes_[nid].GetOpShape().size(), nodes_[nid].GetOpDataType().size());
for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
input_var_eid_.push_back(EntryID(nid, j));
}
nodes_[nid].SetNumOutput(nodes_[nid].GetOpShape().size());
} else {
ICHECK_EQ(nodes_[nid].op_type_, "const");
auto pos = std::find(std::begin(const_names_), std::end(const_names_), name);
ICHECK(pos != std::end(const_names_)) << "Found non-existent constant: " << name;
const_idx_.push_back(nid);
consts.push_back(name);
}
}
ICHECK_EQ(consts.size(), const_names_.size())
<< "Found mismatch for the number of constants in the graph and required.";
for (size_t i = 0; i < consts.size(); i++) {
ICHECK_EQ(consts[i], const_names_[i])
<< "The position of constant in the graph must be the same as the required.";
}
// Reserve data entries.
data_entry_.resize(NumEntries());
}
/*!
* \brief Set up the constants/weights for inference by binding their DLTensor pointer to
* the corresponding data entry.
*
* \param consts A list of constant Tensor to be used.
*/
void SetupConstants(const ffi::Array<Tensor>& consts) {
for (size_t i = 0; i < consts.size(); ++i) {
data_entry_[EntryID(const_idx_[i], 0)] = consts[i].operator->();
}
}
// Load the graph.
void Load(dmlc::JSONReader* reader) {
reader->BeginObject();
std::string key;
std::string symbol_;
while (reader->NextObjectItem(&key)) {
if (key == "nodes") {
reader->Read(&nodes_);
} else if (key == "arg_nodes") {
reader->Read(&input_nodes_);
} else if (key == "node_row_ptr") {
reader->Read(&node_row_ptr_);
} else if (key == "heads") {
reader->Read(&outputs_);
} else if (key == "symbol") {
reader->Read(&symbol_);
} else {
LOG(FATAL) << "Unknown key: " << key;
}
}
}
// Get the node entry index.
uint32_t EntryID(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; }
// Get the node entry index.
uint32_t EntryID(const JSONGraphNodeEntry& e) const { return EntryID(e.id_, e.index_); }
// Number of node entries.
uint32_t NumEntries() const { return node_row_ptr_.back(); }
protected:
/*! \brief The only subgraph name for this module. */
std::string symbol_name_;
/*! \brief The graph. */
std::string graph_json_;
/*! \brief The required constant names. */
ffi::Array<ffi::String> const_names_;
/*! \brief The json graph nodes. */
std::vector<JSONGraphNode> nodes_;
/*! \brief The input nodes, including variables and constants. */
std::vector<uint32_t> input_nodes_;
/*! \brief Used for quick entry indexing. */
std::vector<uint32_t> node_row_ptr_;
/*! \brief Output entries. */
std::vector<JSONGraphNodeEntry> outputs_;
/*! \brief Data of that entry. */
std::vector<const DLTensor*> data_entry_;
/*! \brief Map the input name to entry id. */
std::vector<uint32_t> input_var_eid_;
/*! \brief input const node index. */
std::vector<uint32_t> const_idx_;
/*! \brief Indicate if the engine has been initialized. */
bool initialized_{false};
/*! \brief Initializer mutex*/
std::mutex initialize_mutex_;
};
} // namespace json
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_