| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| /*! |
| * \file src/runtime/vm/tensor_cache_support.cc |
| * \brief Runtime to support tensor cache file loading. |
| * |
| * This file provides a minimum support for tensor cache file loading. |
| * |
 * The main focus of this implementation is to enable loading
 * with a minimum set of intermediate files while also being
 * compatible with some of the multi-shard file layouts that are
 * friendlier in certain environments.
| * |
| * Tensor cache also provides a way to do system-wide |
| * parameter sharing across multiple VMs. |
| * |
 * There are likely other ways to load the parameters from the tensor cache.
 * We will keep the impact minimal by putting it as a private
 * runtime builtin provided in this file.
| */ |
| #include <tvm/ffi/extra/json.h> |
| #include <tvm/ffi/function.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/runtime/tensor.h> |
| #include <tvm/runtime/vm/tensor_cache_support.h> |
| |
| #include <string> |
| #include <vector> |
| |
| #include "../../support/utils.h" |
| #include "../file_utils.h" |
| |
| namespace tvm { |
| namespace runtime { |
| namespace vm { |
| |
| namespace json = tvm::ffi::json; |
| |
| TensorCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const json::Object& json) { |
| std::vector<ffi::Shape::index_type> shape; |
| { |
| json::Array shape_json = json["shape"].cast<json::Array>(); |
| shape.reserve(shape_json.size()); |
| for (const ffi::Any& d : shape_json) { |
| shape.push_back(d.cast<int64_t>()); |
| } |
| } |
| TensorCacheMetadata::FileRecord::ParamRecord result; |
| std::string dtype = json["dtype"].cast<ffi::String>(); |
| result.name = json["name"].cast<ffi::String>(); |
| result.dtype = DataType(ffi::StringToDLDataType(dtype)); |
| result.format = json["format"].cast<ffi::String>(); |
| result.nbytes = json["nbytes"].cast<int64_t>(); |
| result.byte_offset = json["byteOffset"].cast<int64_t>(); |
| result.shape = ffi::Shape(std::move(shape)); |
| return result; |
| } |
| |
| TensorCacheMetadata::FileRecord JSONAsFileRecord(const json::Object& json) { |
| json::Array records = json["records"].cast<json::Array>(); |
| TensorCacheMetadata::FileRecord result; |
| result.data_path = json["dataPath"].cast<ffi::String>(); |
| result.format = json["format"].cast<ffi::String>(); |
| result.nbytes = json["nbytes"].cast<int64_t>(); |
| result.records.reserve(records.size()); |
| for (const ffi::Any& item : records) { |
| result.records.push_back(JSONAsParamRecord(item.cast<json::Object>())); |
| } |
| return result; |
| } |
| |
| TensorCacheMetadata JSONAsTensorCacheMetadata(const json::Object& json) { |
| json::Array records = json["records"].cast<json::Array>(); |
| TensorCacheMetadata result; |
| result.records.reserve(records.size()); |
| for (const ffi::Any& item : records) { |
| result.records.push_back(JSONAsFileRecord(item.cast<json::Object>())); |
| } |
| return result; |
| } |
| |
| TensorCacheMetadata TensorCacheMetadata::LoadFromStr(const std::string& json_str, |
| const std::string& path) { |
| ffi::String err; |
| json::Value json_info = json::Parse(json_str, &err); |
| if (!err.empty()) { |
| TVM_FFI_THROW(InternalError) << "Failed to parse JSON: " << err |
| << ". The JSON string is:" << json_str; |
| } |
| TVM_FFI_CHECK(json_info.as<json::Object>(), ValueError) |
| << "The given string is not a JSON object: " << json_str; |
| TensorCacheMetadata result = JSONAsTensorCacheMetadata(json_info.cast<json::Object>()); |
| result.path = path; |
| return result; |
| } |
| |
| TVM_DLL TensorCacheMetadata TensorCacheMetadata::Load(const std::string& path) { |
| std::string json_str; |
| LoadBinaryFromFile(path + "/tensor-cache.json", &json_str); |
| ffi::String err; |
| json::Value json_info = json::Parse(json_str, &err); |
| if (!err.empty()) { |
| TVM_FFI_THROW(InternalError) << "Failed to parse JSON: " << err |
| << ". The JSON string is:" << json_str; |
| } |
| TVM_FFI_CHECK(json_info.as<json::Object>(), ValueError) |
| << "The given string is not a JSON object: " << json_str; |
| TensorCacheMetadata result = JSONAsTensorCacheMetadata(json_info.cast<json::Object>()); |
| result.path = path; |
| return result; |
| } |
| |
/*!
 * \brief Copy `nbytes` of raw host data into the tensor `param`.
 * \param param The destination tensor; its device decides the copy path.
 * \param data Pointer to the host bytes to copy.
 * \param nbytes Number of bytes to copy.
 * \param staging_buffer Optional reusable staging tensor, used (and possibly
 *        grown/replaced) only on the OpenCL path; may be nullptr to force a
 *        direct copy.
 */
void CopyTensorFromBytes(Tensor param, const void* data, size_t nbytes,
                         ffi::Optional<Tensor>* staging_buffer) {
  Device device = param->device;
  // Fast path: every backend except OpenCL (or when no staging buffer is
  // supplied) can copy host bytes directly into the tensor.
  if (device.device_type != kDLOpenCL || staging_buffer == nullptr) {
    param.CopyFromBytes(data, nbytes);
    return;
  }
  // Special handling for the OpenCL runtime.
  // It creates a host-side memory mirror for every cl_mem that copies data
  // from host, which can cause memory issues. Here we reuse a large staging
  // buffer to postpone deallocation.
  if (staging_buffer->defined()) {
    size_t curr_size = runtime::GetDataSize(*(staging_buffer->value().operator->()));
    // Existing staging buffer is too small for this copy: drop it so a
    // larger one is allocated below.
    if (curr_size < nbytes) {
      *staging_buffer = std::nullopt;
    }
  }
  if (!staging_buffer->defined()) {
    *staging_buffer = Tensor::Empty(param.Shape(), param->dtype, param->device);
  }
  // Reinterpret the staging buffer with the destination's shape/dtype, copy
  // host bytes into it, then copy device-to-device into the target tensor.
  Tensor staging_view = staging_buffer->value().CreateView(param.Shape(), param->dtype);
  staging_view.CopyFromBytes(data, nbytes);
  param.CopyFrom(staging_view);
  // Synchronize so the async copy finishes before staging_view is destroyed.
  DeviceAPI::Get(device)->StreamSync(device, nullptr);
}
| |
/*!
 * \brief Load this parameter from a raw shard into a fresh tensor on `device`.
 * \param device The device to allocate the result on.
 * \param raw_data The full raw content of the shard file; this record's bytes
 *        live at [byte_offset, byte_offset + nbytes).
 * \param staging_buffer Optional staging buffer, see CopyTensorFromBytes.
 * \return The loaded tensor.
 */
Tensor TensorCacheMetadata::FileRecord::ParamRecord::Load(
    Device device, const std::string* raw_data, ffi::Optional<Tensor>* staging_buffer) const {
  Tensor arr = Tensor::Empty(shape, dtype, device);
  if (dtype == DataType::Float(32) && format == "f32-to-bf16") {
    // Decode bf16 to f32: each stored 16-bit value holds the upper half of
    // an f32 bit pattern, so shifting left by 16 widens it (lower mantissa
    // bits become zero).
    std::vector<uint16_t> buffer(nbytes / 2);
    std::vector<uint32_t> decoded(nbytes / 2);
    std::memcpy(buffer.data(), raw_data->data() + byte_offset, nbytes);
    for (size_t i = 0; i < buffer.size(); ++i) {
      decoded[i] = static_cast<uint32_t>(buffer[i]) << 16;
    }
    CopyTensorFromBytes(arr, decoded.data(), decoded.size() * sizeof(uint32_t), staging_buffer);
  } else {
    // Raw bytes already match the target dtype; copy them straight in.
    CopyTensorFromBytes(arr, raw_data->data() + byte_offset, nbytes, staging_buffer);
  }
  return arr;
}
| |
| TVM_DLL ffi::Array<Tensor> TensorCacheMetadata::FileRecord::Load( |
| Device device, |
| const std::string& path_prefix, // |
| std::string* raw_data_buffer, // |
| ffi::Optional<Tensor>* staging_buffer) const { |
| LoadBinaryFromFile(path_prefix + "/" + this->data_path, raw_data_buffer); |
| TVM_FFI_CHECK_EQ(this->format, "raw-shard", ValueError) << "Only `raw-shard` format is supported"; |
| TVM_FFI_CHECK_EQ(this->nbytes, raw_data_buffer->length(), ValueError) |
| << "Encountered an corrupted parameter shard. It means it is not downloaded " |
| "completely or downloading is interrupted. Please try to download again."; |
| ffi::Array<Tensor> result; |
| result.reserve(this->records.size()); |
| for (const ParamRecord& nd_rec : this->records) { |
| result.push_back(nd_rec.Load(device, raw_data_buffer, staging_buffer)); |
| } |
| return result; |
| } |
| |
| /*! |
| * A Tensor cache to store pre-loaded arrays in the system. |
| */ |
| class TensorCache { |
| public: |
| static TensorCache* Global() { |
| static TensorCache* inst = new TensorCache(); |
| return inst; |
| } |
| |
| static void Update(ffi::String name, Tensor arr, bool override) { |
| TensorCache* pool = Global(); |
| if (!override) { |
| TVM_FFI_ICHECK_EQ(pool->pool_.count(name), 0) |
| << "Name " << name << " already exists in the cache"; |
| } |
| pool->pool_.Set(name, arr); |
| } |
| |
| static ffi::Optional<Tensor> Get(ffi::String name) { |
| TensorCache* pool = Global(); |
| auto it = pool->pool_.find(name); |
| if (it != pool->pool_.end()) { |
| return (*it).second; |
| } else { |
| return std::nullopt; |
| } |
| } |
| |
| static void Remove(ffi::String name) { |
| TensorCache* pool = Global(); |
| pool->pool_.erase(name); |
| } |
| |
| static void Clear() { Global()->pool_.clear(); } |
| |
| /*! |
| * \brief Load parameters from path and append them. |
| * \param cache_path The cache to path. |
| * \param device_type The type of device to be loaded. |
| * \param device_id The device id. |
| */ |
| static void Load(const std::string& cache_path, int device_type, int device_id) { |
| DLDevice device{static_cast<DLDeviceType>(device_type), device_id}; |
| TensorCacheMetadata metadata = TensorCacheMetadata::Load(cache_path); |
| ffi::Optional<Tensor> staging_buffer; |
| std::string raw_data; |
| ffi::Array<Tensor> params; |
| for (const TensorCacheMetadata::FileRecord& shard_rec : metadata.records) { |
| try { |
| params = shard_rec.Load(device, cache_path, &raw_data, &staging_buffer); |
| } catch (const std::runtime_error& e) { |
| TVM_FFI_THROW(ValueError) << "Error when loading parameters from " << shard_rec.data_path |
| << ": " << e.what(); |
| } |
| int num_params = params.size(); |
| for (int i = 0; i < num_params; ++i) { |
| Update(shard_rec.records[i].name, params[i], true); |
| } |
| } |
| } |
| |
| private: |
| ffi::Map<ffi::String, Tensor> pool_; |
| }; |
| |
// Register the tensor-cache builtins under the "vm.builtin.tensor_cache.*"
// namespace so they are callable through the FFI.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("vm.builtin.tensor_cache.get", TensorCache::Get)
      .def_packed("vm.builtin.tensor_cache.update",
                  [](ffi::PackedArgs args, ffi::Any* rv) {
                    // Signature: (name, tensor[, override]) -> ()
                    // `override` defaults to false when omitted.
                    TVM_FFI_ICHECK(args.size() == 2 || args.size() == 3);
                    ffi::String name = args[0].cast<ffi::String>();
                    bool is_override = args.size() == 2 ? false : args[2].cast<bool>();

                    Tensor arr;
                    if (auto opt_nd = args[1].as<Tensor>()) {
                      arr = opt_nd.value();
                    } else {
                      // We support converting DLTensors to Tensors as RPC references are always
                      // DLTensors
                      auto tensor = args[1].cast<DLTensor*>();
                      std::vector<int64_t> shape;
                      for (int64_t i = 0; i < tensor->ndim; i++) {
                        shape.push_back(tensor->shape[i]);
                      }
                      arr = Tensor::Empty(shape, tensor->dtype, tensor->device);
                      arr.CopyFrom(tensor);
                      // Synchronize so the copy completes before the caller
                      // may release the DLTensor's memory.
                      DeviceAPI::Get(arr->device)->StreamSync(arr->device, nullptr);
                    }

                    TensorCache::Update(name, arr, is_override);
                  })
      .def("vm.builtin.tensor_cache.remove", TensorCache::Remove)
      .def("vm.builtin.tensor_cache.clear", TensorCache::Clear)
      .def("vm.builtin.tensor_cache.load", TensorCache::Load);
}
| |
| // This param module node can be useful to get param dict in RPC mode |
| // when the remote already have loaded parameters from file. |
| class ParamModuleNode : public ffi::ModuleObj { |
| public: |
| const char* kind() const final { return "param_module"; } |
| |
| ffi::Optional<ffi::Function> GetFunction(const ffi::String& name) final { |
| if (name == "get_params") { |
| auto params = params_; |
| return ffi::Function([params](ffi::PackedArgs args, ffi::Any* rv) { *rv = params; }); |
| } else { |
| return ffi::Function(); |
| } |
| } |
| |
| static ffi::Array<Tensor> GetParams(const ffi::String& prefix, int num_params) { |
| ffi::Array<Tensor> params; |
| for (int i = 0; i < num_params || num_params == -1; ++i) { |
| std::string name = prefix + "_" + std::to_string(i); |
| auto opt = TensorCache::Get(name); |
| if (opt) { |
| params.push_back(opt.value()); |
| } else { |
| if (num_params == -1) return params; |
| TVM_FFI_THROW(InternalError) << "Cannot find " << name << " in cache"; |
| } |
| } |
| return params; |
| } |
| |
| static ffi::Array<Tensor> GetParamByName(const ffi::Array<ffi::String>& names) { |
| ffi::Array<Tensor> result; |
| result.reserve(names.size()); |
| for (const ffi::String& name : names) { |
| if (ffi::Optional<Tensor> opt = TensorCache::Get(name)) { |
| result.push_back(opt.value()); |
| } else { |
| TVM_FFI_THROW(ValueError) << "Cannot find parameter in cache: " << name; |
| } |
| } |
| return result; |
| } |
| |
| static ffi::Module Create(const std::string& prefix, int num_params) { |
| auto n = ffi::make_object<ParamModuleNode>(); |
| n->params_ = GetParams(prefix, num_params); |
| return ffi::Module(n); |
| } |
| |
| static ffi::Module CreateByName(const ffi::Array<ffi::String>& names) { |
| auto n = ffi::make_object<ParamModuleNode>(); |
| n->params_ = GetParamByName(names); |
| return ffi::Module(n); |
| } |
| |
| private: |
| ffi::Array<Tensor> params_; |
| }; |
| |
// Register the param-module builtins under the "vm.builtin.*" namespace so
// they are callable through the FFI.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("vm.builtin.param_module_from_cache", ParamModuleNode::Create)
      .def("vm.builtin.param_module_from_cache_by_name", ParamModuleNode::CreateByName)
      .def("vm.builtin.param_array_from_cache", ParamModuleNode::GetParams)
      .def("vm.builtin.param_array_from_cache_by_name", ParamModuleNode::GetParamByName)
      .def_packed("vm.builtin.param_array_from_cache_by_name_unpacked",
                  [](ffi::PackedArgs args, ffi::Any* rv) {
                    // Variadic form: each positional argument is one parameter
                    // name (must be a string); returns the matching tensors.
                    ffi::Array<ffi::String> names;
                    names.reserve(args.size());
                    for (int i = 0; i < args.size(); ++i) {
                      if (!args[i].try_cast<ffi::String>()) {
                        TVM_FFI_THROW(ValueError) << "Expect string as input, but get "
                                                  << args[i].GetTypeKey() << " at " << i;
                      }
                      names.push_back(args[i].cast<ffi::String>());
                    }
                    *rv = ParamModuleNode::GetParamByName(names);
                  });
}
| |
| } // namespace vm |
| } // namespace runtime |
| } // namespace tvm |