| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| #ifndef TVM_RUNTIME_DISCO_PROTOCOL_H_ |
| #define TVM_RUNTIME_DISCO_PROTOCOL_H_ |
| |
#include <tvm/ffi/function.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/disco/session.h>
#include <tvm/support/io.h>
#include <tvm/support/serializer.h>

#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "../../support/arena.h"
#include "../../support/base64.h"
#include "../../support/bytes_io.h"
#include "../minrpc/rpc_reference.h"
| |
| namespace tvm { |
| namespace runtime { |
| |
| /*! |
| * \brief The communication protocol used by Disco message channel. |
| * \tparam SubClassType The subclass type that inherits this protocol. |
| */ |
| template <class SubClassType> |
| struct DiscoProtocol { |
| protected: |
| /*! \brief Virtual destructor */ |
| virtual ~DiscoProtocol() = default; |
| |
| /*! \brief Recycle all the memory used in the arena */ |
| inline void RecycleAll() { |
| this->any_arena_.clear(); |
| this->arena_.RecycleAll(); |
| } |
| |
| /*! \brief Get the length of the object being serialized. Used by RPCReference. */ |
| inline uint64_t GetFFIAnyProtocolBytes(const TVMFFIAny* obj); |
| |
| /*! \brief Write the object to stream. Used by RPCReference. */ |
| inline void WriteFFIAny(const TVMFFIAny* obj); |
| |
| /*! \brief Read the object from stream. Used by RPCReference. */ |
| inline void ReadFFIAny(TVMFFIAny* out); |
| |
| /*! \brief Callback method used when starting a new message. Used by RPCReference. */ |
| void MessageStart(uint64_t packet_nbytes) {} |
| |
| /*! \brief Callback method used when a new message is complete. Used by RPCReference. */ |
| void MessageDone() {} |
| |
| /*! \brief Callback method when an error occurs in (de)-serialization. Used by RPCReference. */ |
| void ThrowError(RPCServerStatus status) { |
| TVM_FFI_THROW(InternalError) << "Unexpected error in RPC: " << RPCServerStatusToString(status); |
| } |
| |
| /*!\ brief Arena used by RPCReference to allocate POD memory */ |
| template <typename T> |
| T* ArenaAlloc(int count) { |
| static_assert(std::is_pod<T>::value, "need to be trival"); |
| return arena_.template allocate_<T>(count); |
| } |
| |
| support::Arena arena_; |
| std::vector<Any> any_arena_; |
| friend struct RPCReference; |
| }; |
| |
| /*! |
| * \brief The debug extension of the communication protocol that allows serialization and |
| * deserialization of Tensors and reflection-capable TVM objects. |
| */ |
| struct DiscoDebugObject : public Object { |
| public: |
| /*! \brief The data to be serialized */ |
| ffi::Any data; |
| |
| /*! \brief Wrap an Tensor or reflection-capable TVM object into the debug extension. */ |
| static ObjectRef Wrap(const ffi::Any& data) { |
| ObjectPtr<DiscoDebugObject> n = ffi::make_object<DiscoDebugObject>(); |
| n->data = data; |
| return ObjectRef(n); |
| } |
| |
| /*! \brief Wrap an Tensor or reflection-capable TVM object into the debug extension. */ |
| static ObjectRef Wrap(const ffi::AnyView& data) { |
| ffi::Any rv; |
| rv = data; |
| return Wrap(std::move(rv)); |
| } |
| |
| /*! \brief Serialize the debug object to string */ |
| inline std::string SaveToStr() const; |
| /*! \brief Deserialize the debug object from string */ |
| static inline ObjectPtr<DiscoDebugObject> LoadFromStr(std::string json_str); |
| /*! \brief Get the size of the debug object in bytes */ |
| inline uint64_t GetFFIAnyProtocolBytes() const { |
| return sizeof(uint64_t) + this->SaveToStr().size(); |
| } |
| TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.disco.DiscoDebugObject", DiscoDebugObject, SessionObj); |
| }; |
| |
| template <class SubClassType> |
| inline uint64_t DiscoProtocol<SubClassType>::GetFFIAnyProtocolBytes(const TVMFFIAny* value) { |
| const AnyView* any_view_ptr = reinterpret_cast<const AnyView*>(value); |
| if (any_view_ptr->as<DRefObj>()) { |
| return sizeof(uint32_t) + sizeof(int64_t); |
| } else if (const auto opt_str = any_view_ptr->as<ffi::String>()) { |
| uint64_t size = (*opt_str).size(); |
| return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char); |
| } else if (const auto opt_bytes = any_view_ptr->as<ffi::Bytes>()) { |
| uint64_t size = (*opt_bytes).size(); |
| return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char); |
| } else if (const auto opt_shape = any_view_ptr->as<ffi::Shape>()) { |
| uint64_t ndim = (*opt_shape).size(); |
| return sizeof(uint32_t) + sizeof(uint64_t) + ndim * sizeof(ffi::ShapeObj::index_type); |
| } else if (const auto opt_debug_obj = any_view_ptr->as<DiscoDebugObject>()) { |
| return sizeof(uint32_t) + (*opt_debug_obj).GetFFIAnyProtocolBytes(); |
| } else { |
| TVM_FFI_THROW(ValueError) << "Object type is not supported in Disco calling convention: " |
| << any_view_ptr->GetTypeKey() |
| << " (type_index = " << any_view_ptr->type_index() << ")"; |
| return 0; |
| } |
| } |
| template <class SubClassType> |
| inline void DiscoProtocol<SubClassType>::WriteFFIAny(const TVMFFIAny* value) { |
| SubClassType* self = static_cast<SubClassType*>(this); |
| const AnyView* any_view_ptr = reinterpret_cast<const AnyView*>(value); |
| if (const auto* ref = any_view_ptr->as<DRefObj>()) { |
| int64_t reg_id = ref->reg_id; |
| self->template Write<uint32_t>(TypeIndex::kRuntimeDiscoDRef); |
| self->template Write<int64_t>(reg_id); |
| } else if (const auto opt_str = any_view_ptr->as<ffi::String>()) { |
| self->template Write<uint32_t>(ffi::TypeIndex::kTVMFFIStr); |
| self->template Write<uint64_t>((*opt_str).size()); |
| self->template WriteArray<char>((*opt_str).data(), (*opt_str).size()); |
| } else if (const auto opt_bytes = any_view_ptr->as<ffi::Bytes>()) { |
| self->template Write<uint32_t>(ffi::TypeIndex::kTVMFFIBytes); |
| self->template Write<uint64_t>((*opt_bytes).size()); |
| self->template WriteArray<char>((*opt_bytes).data(), (*opt_bytes).size()); |
| } else if (const auto opt_shape = any_view_ptr->as<ffi::Shape>()) { |
| self->template Write<uint32_t>(ffi::TypeIndex::kTVMFFIShape); |
| self->template Write<uint64_t>((*opt_shape).size()); |
| self->template WriteArray<ffi::ShapeObj::index_type>((*opt_shape).data(), (*opt_shape).size()); |
| } else if (const auto opt_debug_obj = any_view_ptr->as<DiscoDebugObject>()) { |
| self->template Write<uint32_t>(0); |
| std::string str = (*opt_debug_obj).SaveToStr(); |
| self->template Write<uint64_t>(str.size()); |
| self->template WriteArray<char>(str.data(), str.size()); |
| } else { |
| TVM_FFI_THROW(ValueError) << "Object type is not supported in Disco calling convention: " |
| << any_view_ptr->GetTypeKey() |
| << " (type_index = " << any_view_ptr->type_index() << ")"; |
| } |
| } |
| |
| template <class SubClassType> |
| inline void DiscoProtocol<SubClassType>::ReadFFIAny(TVMFFIAny* out) { |
| SubClassType* self = static_cast<SubClassType*>(this); |
| ffi::Any result{nullptr}; |
| uint32_t type_index; |
| self->template Read<uint32_t>(&type_index); |
| if (type_index == TypeIndex::kRuntimeDiscoDRef) { |
| ObjectPtr<DRefObj> dref = ffi::make_object<DRefObj>(); |
| self->template Read<int64_t>(&dref->reg_id); |
| dref->session = Session{nullptr}; |
| result = ObjectRef(std::move(dref)); |
| } else if (type_index == ffi::TypeIndex::kTVMFFIStr) { |
| uint64_t size = 0; |
| self->template Read<uint64_t>(&size); |
| std::string data(size, '\0'); |
| self->template ReadArray<char>(data.data(), size); |
| result = ffi::String(std::move(data)); |
| } else if (type_index == ffi::TypeIndex::kTVMFFIBytes) { |
| uint64_t size = 0; |
| self->template Read<uint64_t>(&size); |
| std::string data(size, '\0'); |
| self->template ReadArray<char>(data.data(), size); |
| result = ffi::Bytes(std::move(data)); |
| } else if (type_index == ffi::TypeIndex::kTVMFFIShape) { |
| uint64_t ndim = 0; |
| self->template Read<uint64_t>(&ndim); |
| std::vector<ffi::ShapeObj::index_type> data(ndim); |
| self->template ReadArray<ffi::ShapeObj::index_type>(data.data(), ndim); |
| result = ffi::Shape(std::move(data)); |
| } else if (type_index == 0) { |
| uint64_t size = 0; |
| self->template Read<uint64_t>(&size); |
| std::string data(size, '\0'); |
| self->template ReadArray<char>(data.data(), size); |
| result = DiscoDebugObject::LoadFromStr(std::move(data))->data.cast<ObjectRef>(); |
| } else { |
| TVM_FFI_THROW(ValueError) << "Object type is not supported in Disco calling convention: " |
| << Object::TypeIndex2Key(type_index) |
| << " (type_index = " << type_index << ")"; |
| } |
| *reinterpret_cast<ffi::AnyView*>(out) = result; |
| any_arena_.push_back(result); |
| } |
| |
| inline std::string DiscoDebugObject::SaveToStr() const { |
| if (auto opt_nd = this->data.as<Tensor>()) { |
| Tensor array = opt_nd.value(); |
| std::string result; |
| { |
| support::BytesOutStream mstrm(&result); |
| support::Base64OutStream b64strm(&mstrm); |
| runtime::SaveDLTensor(&b64strm, array.operator->()); |
| b64strm.Finish(); |
| } |
| result.push_back('1'); |
| return result; |
| } else if (auto opt_obj = this->data.as<ObjectRef>()) { |
| ObjectRef obj = opt_obj.value(); |
| const auto f = tvm::ffi::Function::GetGlobal("node.SaveJSON"); |
| TVM_FFI_CHECK(f.has_value(), ValueError) |
| << "Cannot serialize object in non-debugging mode: " << obj->GetTypeKey(); |
| std::string result = (*f)(obj).cast<std::string>(); |
| result.push_back('0'); |
| return result; |
| } |
| TVM_FFI_THROW(ValueError) << "Cannot serialize the following type code in non-debugging mode: " |
| << this->data.GetTypeKey(); |
| return ""; |
| } |
| |
| inline ObjectPtr<DiscoDebugObject> DiscoDebugObject::LoadFromStr(std::string json_str) { |
| TVM_FFI_ICHECK(!json_str.empty()); |
| char control_bit = json_str.back(); |
| json_str.pop_back(); |
| ObjectPtr<DiscoDebugObject> result = ffi::make_object<DiscoDebugObject>(); |
| if (control_bit == '0') { |
| const auto f = tvm::ffi::Function::GetGlobal("node.LoadJSON"); |
| TVM_FFI_CHECK(f.has_value(), ValueError) << "Cannot deserialize object in non-debugging mode"; |
| result->data = (*f)(json_str); |
| } else if (control_bit == '1') { |
| support::BytesInStream mstrm(json_str); |
| support::Base64InStream b64strm(&mstrm); |
| b64strm.InitPosition(); |
| runtime::Tensor array; |
| TVM_FFI_ICHECK(array.Load(&b64strm)); |
| result->data = std::move(array); |
| } else { |
| TVM_FFI_THROW(ValueError) << "Unsupported control bit: " << control_bit |
| << ". Full string: " << json_str; |
| } |
| return result; |
| } |
| |
| } // namespace runtime |
| } // namespace tvm |
| #endif // TVM_RUNTIME_DISCO_PROTOCOL_H_ |