blob: c165f6b05cf39dd7ba31a1cdf4b8f752c72c53dd [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/runtime/contrib/json/json_node.h
* \brief The graph nodes used by JSON runtime.
*/
#ifndef TVM_RUNTIME_CONTRIB_JSON_JSON_NODE_H_
#define TVM_RUNTIME_CONTRIB_JSON_JSON_NODE_H_
#include <dlpack/dlpack.h>
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/extra/json.h>
#include <tvm/ffi/string.h>
#include <tvm/runtime/data_type.h>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace tvm {
namespace runtime {
namespace json {
/*! \brief Attribute map of a JSON graph node: attribute name -> type-erased value. */
using JSONGraphAttrs = ffi::Map<ffi::String, ffi::Any>;
/*!
* \brief The node entry in the serialized json graph.
*/
class JSONGraphNodeEntry {
 public:
  /*! \brief Construct an entry whose id, index and version are all zero. */
  JSONGraphNodeEntry() = default;

  /*!
   * \brief Construct an entry from explicit field values.
   *
   * \param id The json graph node ID.
   * \param index The entry index.
   * \param version The version field; defaults to 0.
   */
  JSONGraphNodeEntry(int id, int index, int version = 0)
      : id_(static_cast<uint32_t>(id)),
        index_(static_cast<uint32_t>(index)),
        version_(static_cast<uint32_t>(version)) {}

  /*!
   * \brief Serialize a node entry.
   *
   * \return A json array of three integers: [id, index, version].
   */
  ffi::json::Value SaveToJSON() const {
    ffi::json::Array arr;
    arr.push_back(static_cast<int64_t>(id_));
    arr.push_back(static_cast<int64_t>(index_));
    arr.push_back(static_cast<int64_t>(version_));
    return arr;
  }

  /*!
   * \brief Deserialize the json array into a node entry.
   *
   * \param arr A json array holding at least [id, index]; an optional third
   *        element is read as the version, which is set to 0 when absent.
   */
  void Load(ffi::json::Array arr) {
    TVM_FFI_ICHECK_GE(arr.size(), 2) << "invalid json format";
    id_ = static_cast<uint32_t>(arr[0].cast<int64_t>());
    index_ = static_cast<uint32_t>(arr[1].cast<int64_t>());
    if (arr.size() > 2) {
      version_ = static_cast<uint32_t>(arr[2].cast<int64_t>());
    } else {
      version_ = 0;
    }
  }

  // The fields are zero-initialized so that a default-constructed entry never
  // exposes indeterminate values (e.g. when it is serialized before Load()).
  /*! \brief The json graph node ID. */
  uint32_t id_{0};
  /*! \brief The entry index. */
  uint32_t index_{0};
  /*! \brief The version field (third element of the serialized array). */
  uint32_t version_{0};
};
/*!
* \brief The node of the serialized json graph. It includes an array of
* entries.
*/
class JSONGraphNode {
public:
// Constructors.
JSONGraphNode() = default;
JSONGraphNode(const std::string& name, const std::string& op_type,
const std::vector<JSONGraphNodeEntry>& inputs = {}, size_t num_outputs = 1) {
name_ = name;
op_type_ = op_type;
num_inputs_ = inputs.size();
inputs_ = inputs;
num_outputs_ = num_outputs;
}
/*!
* \brief Serialize a node so that it can be saved to disk.
*/
ffi::json::Value SaveToJSON() {
if (!inputs_.empty()) {
SetAttr("num_inputs", static_cast<int64_t>(inputs_.size()));
SetAttr("num_outputs", static_cast<int64_t>(num_outputs_));
}
ffi::json::Object obj;
obj.Set(ffi::String("op"), ffi::String(op_type_));
obj.Set(ffi::String("name"), ffi::String(name_));
if (!inputs_.empty()) {
ffi::json::Array inputs_arr;
for (const auto& e : inputs_) {
inputs_arr.push_back(e.SaveToJSON());
}
obj.Set(ffi::String("inputs"), std::move(inputs_arr));
}
if (attrs_.size() > 0) {
ffi::json::Object attrs_obj;
for (const auto& kv : attrs_) {
std::string key = kv.first;
const ffi::Any& v = kv.second;
attrs_obj.Set(ffi::String(key), v);
}
obj.Set(ffi::String("attrs"), std::move(attrs_obj));
}
return obj;
}
/*!
* \brief Load the attribute of a node from a json object.
*/
void LoadAttrs(ffi::json::Object obj) {
for (const auto& kv : obj) {
std::string key = std::string(kv.first.cast<ffi::String>());
if (key == "num_inputs") {
num_inputs_ = static_cast<uint32_t>(kv.second.cast<int64_t>());
} else if (key == "num_outputs") {
num_outputs_ = static_cast<uint32_t>(kv.second.cast<int64_t>());
} else {
attrs_.Set(key, kv.second);
}
}
// Populate cached shape/dtype from attrs for fast access
if (HasAttr("shape")) {
shape_ = GetAttr<ffi::Array<ffi::Array<int64_t>>>("shape");
}
if (HasAttr("dtype")) {
dtype_ = GetAttr<ffi::Array<DLDataType>>("dtype");
}
if (shape_.defined() && dtype_.defined()) {
TVM_FFI_ICHECK_EQ(shape_.size(), dtype_.size());
}
}
/*!
* \brief Load a node from a json object.
*/
void Load(ffi::json::Object obj) {
for (const auto& kv : obj) {
std::string key = std::string(kv.first.cast<ffi::String>());
if (key == "op") {
op_type_ = std::string(kv.second.cast<ffi::String>());
} else if (key == "name") {
name_ = std::string(kv.second.cast<ffi::String>());
} else if (key == "inputs") {
auto arr = kv.second.cast<ffi::json::Array>();
inputs_.clear();
for (const auto& e : arr) {
JSONGraphNodeEntry entry;
entry.Load(e.cast<ffi::json::Array>());
inputs_.push_back(entry);
}
} else if (key == "attr" || key == "attrs") {
this->LoadAttrs(kv.second.cast<ffi::json::Object>());
} else {
TVM_FFI_THROW(InternalError) << "Unknown key: " << key;
}
}
}
/*!
* \brief Check if a node is a leaf node, i.e. input to the graph.
*
* \return True if the node has no input, otherwise, false.
*/
bool IsLeaf() const { return inputs_.empty(); }
/*!
* \brief Return the number of outputs of the node.
*
* \return The number of the output.
*/
uint32_t GetNumOutput() const { return num_outputs_; }
/*!
* \brief Return the input entries.
*
* \return The input entries.
*/
std::vector<JSONGraphNodeEntry> GetInputs() const { return inputs_; }
/*!
* \brief Return the op type.
*
* \return The op type.
*/
std::string GetOpType() const { return op_type_; }
/*!
* \brief Return the op name.
*
* \return The op name.
*/
std::string GetOpName() const { return name_; }
/*!
* \brief Return the op output shapes.
*
* \return The shapes.
*/
ffi::Array<ffi::Array<int64_t>> GetOpShape() const { return shape_; }
/*!
* \brief Return the op types.
*
* \return The types.
*/
ffi::Array<DLDataType> GetOpDataType() const { return dtype_; }
/*!
* \brief Set the number of outputs of the node.
*
* \param num_outputs The number of output.
*/
void SetNumOutput(uint32_t num_outputs) { num_outputs_ = num_outputs; }
/*!
* \brief Get the value of an attribute in the node.
*
* \tparam T The return type.
* \param key The key for lookup.
*
* \return The value.
*/
template <typename T>
T GetAttr(const std::string& key) const {
TVM_FFI_ICHECK(attrs_.count(key) > 0) << "Key: " << key << " is not found";
return attrs_[key].cast<T>();
}
/*!
* \brief Set an attribute for the node.
*
* \param key The key of the attribute.
* \param value The attribute value (native type: int64_t, double, ffi::String, ffi::Array, etc.)
*/
void SetAttr(const std::string& key, ffi::Any value) { attrs_.Set(key, std::move(value)); }
/*!
* \brief Set shape for the node.
*/
void SetShape(const std::vector<std::vector<int64_t>>& shape) {
ffi::Array<ffi::Array<int64_t>> arr;
for (const auto& s : shape) {
ffi::Array<int64_t> row;
for (auto x : s) row.push_back(x);
arr.push_back(std::move(row));
}
shape_ = arr;
SetAttr("shape", std::move(arr));
}
/*!
* \brief Set dtype for the node.
*/
void SetDType(const std::vector<DLDataType>& dtype) {
ffi::Array<ffi::String> str_arr;
ffi::Array<DLDataType> dt_arr;
for (const auto& d : dtype) {
str_arr.push_back(ffi::String(ffi::DLDataTypeToString(d)));
dt_arr.push_back(d);
}
dtype_ = std::move(dt_arr);
SetAttr("dtype", std::move(str_arr));
}
/*!
* \brief Check if node has attribute.
*
* \param key The key of the attribute.
*
* \return True if attribute exists, false otherwise.
*/
bool HasAttr(const std::string& key) const { return attrs_.count(key) > 0; }
void CaptureAttrs(const JSONGraphNode& that) {
for (const auto& kv : that.attrs_) {
attrs_.Set(kv.first, kv.second);
}
}
virtual ~JSONGraphNode() {}
private:
/*! \brief The number of input. */
uint32_t num_inputs_{0};
/*! \brief The number of output. */
uint32_t num_outputs_{1};
/*! \brief The name of the op. It is the symbol that used for runtime lookup. */
std::string name_;
/*! \brief The operator type, i.e. input is "null". */
std::string op_type_;
/*! \brief The inputs of the node. */
std::vector<JSONGraphNodeEntry> inputs_;
/*! \brief Attribute of the node, including shape and dtype. */
JSONGraphAttrs attrs_;
/*! \brief Cached shape for fast access (also stored in attrs_). */
ffi::Array<ffi::Array<int64_t>> shape_;
/*! \brief Cached dtype for fast access (also stored in attrs_ as strings). */
ffi::Array<DLDataType> dtype_;
friend class JSONRuntimeBase;
};
} // namespace json
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_CONTRIB_JSON_JSON_NODE_H_