/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#define PICOJSON_USE_INT64
#ifndef __STDC_FORMAT_MACROS
#define __STDC_FORMAT_MACROS
#endif
#include <picojson.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/disco/builtin.h>
#include <tvm/runtime/vm/tensor_cache_support.h>
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>
#include "../file_utils.h"
#include "./utils.h"
namespace tvm {
namespace runtime {
using vm::TensorCacheMetadata;
using FileRecord = TensorCacheMetadata::FileRecord;
using ParamRecord = TensorCacheMetadata::FileRecord::ParamRecord;
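/*! \brief Sharding instructions for a single weight: an ordered list of shard
* functions, each with its function name, output tensor info, and any extra
* integer arguments. */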
struct ShardInfo {
struct TensorInfo {
ffi::Shape shape;
DataType dtype;
};
struct ShardFunc {
std::string name;
TensorInfo output_info;
std::vector<int64_t> params;
};
std::vector<ShardFunc> funcs;
};
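/*! \brief Cast a JSON value to the expected picojson type, aborting on mismatch. */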
template <typename ExpectedType>
inline ExpectedType AsType(const picojson::value& json) {
ICHECK(json.is<ExpectedType>());
return json.get<ExpectedType>();
}
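/*! \brief Look up `key` in a JSON object and cast the value to the expected type. */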
template <typename ValueType>
inline ValueType GetValue(const picojson::object& json, const std::string& key) {
return AsType<ValueType>(json.at(key));
}
std::unordered_map<std::string, ShardInfo> LoadShardInfoFromStr(const std::string& json_str);
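/*! \brief Parse a `[shape, dtype]` JSON pair into a TensorInfo. */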
ShardInfo::TensorInfo LoadTensorInfoFromJSON(const picojson::array& json_tensor_info) {
CHECK_EQ(json_tensor_info.size(), 2) << "ValueError: Invalid tensor info JSON";
picojson::array shape_json = AsType<picojson::array>(json_tensor_info[0]);
int ndim = static_cast<int>(shape_json.size());
std::vector<int64_t> shape;
shape.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
shape.push_back(AsType<int64_t>(shape_json[i]));
}
std::string dtype = AsType<std::string>(json_tensor_info[1]);
return ShardInfo::TensorInfo{ffi::Shape(std::move(shape)),
DataType(ffi::StringToDLDataType(dtype))};
}
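/*! \brief Parse a `[name, output_info, params...]` JSON array into a ShardFunc. */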
ShardInfo::ShardFunc LoadShardFuncFromJSON(const picojson::array& json_shard_func) {
int n = static_cast<int>(json_shard_func.size());
CHECK_GE(n, 2) << "ValueError: Invalid shard function JSON";
ShardInfo::ShardFunc shard_info;
shard_info.name = AsType<std::string>(json_shard_func[0]);
shard_info.output_info = LoadTensorInfoFromJSON(AsType<picojson::array>(json_shard_func[1]));
shard_info.params.reserve(n - 2);
for (int i = 2; i < n; ++i) {
shard_info.params.push_back(AsType<int64_t>(json_shard_func[i]));
}
return shard_info;
}
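/*!
* \brief Parse shard information from a JSON string into a map from parameter
* name to ShardInfo. As implied by the parsers above, the expected layout is,
* for example (the function name here is purely illustrative):
* \code{.json}
* {"param_0": [["shard_row_major", [[64, 128], "float16"], 2]]}
* \endcode
* where each inner array is `[func_name, [output_shape, dtype], int_params...]`.
*/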
std::unordered_map<std::string, ShardInfo> LoadShardInfoFromStr(const std::string& json_str) {
picojson::value json_info;
std::string err = picojson::parse(json_info, json_str);
CHECK(err.empty()) << "ValueError: Failed to parse the shard info JSON: " << err;
picojson::object json_obj = AsType<picojson::object>(json_info);
std::unordered_map<std::string, ShardInfo> result;
for (const auto& kv : json_obj) {
std::string name = kv.first;
picojson::array json_shard_funcs = AsType<picojson::array>(kv.second);
ShardInfo info;
std::vector<ShardInfo::ShardFunc>& shard_funcs = info.funcs;
shard_funcs.reserve(json_shard_funcs.size());
for (const picojson::value& json_shard_func : json_shard_funcs) {
shard_funcs.push_back(LoadShardFuncFromJSON(AsType<picojson::array>(json_shard_func)));
}
result[name] = info;
}
return result;
}
/*! \brief An object that helps to load parameters in shards. */
class ShardLoaderObj : public Object {
public:
/*! \brief Create a shard loader. */
static ObjectRef Create(const std::string& path_to_metadata, const std::string& metadata,
std::string shard_info, ffi::Optional<ffi::Module> mod);
/*! \brief Load the i-th parameter */
Tensor Load(int weight_index) const;
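/*! \brief Load the i-th parameter on worker 0 only; other workers return an
* uninitialized tensor of matching shape and dtype */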
Tensor LoadParamOnWorker0(int weight_index) const;
/*! \brief Load all the parameters */
ffi::Array<Tensor> LoadAll() const;
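/*! \brief Apply a shard function to a loaded parameter and return the result */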
Tensor ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, const Tensor& param) const;
/*! \brief Load all the pre-sharded parameters */
ffi::Array<Tensor> LoadAllPresharded() const;
/*! \brief Load the i-th parameter from presharded binaries */
Tensor LoadPresharded(int weight_index) const;
/*! \brief Slice the given tensor at a specific dimension */
Tensor Shard(Tensor source, int dim, int num_slices) const;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.disco.ShardLoader", ShardLoaderObj, Object);
public:
/*! \brief Information of how each weight is stored and sharded */
struct ParamInfo {
const FileRecord* file;
const ParamRecord* param;
ShardInfo shard_info;
};
/*! \brief The ffi::Functions being used during sharding */
std::unordered_map<std::string, ffi::Function> shard_funcs_;
/*! \brief The metadata loaded from `tensor-cache.json` */
TensorCacheMetadata metadata_;
/*! \brief Sharding information for each weight */
std::vector<ParamInfo> param_info_;
/*! \brief Maps the name of a shard to its index */
std::unordered_map<std::string, int> param_name_to_index_;
/*! \brief The current file opened to load weights in it */
mutable const FileRecord* current_file_;
/*! \brief The context of the current file to be loaded from */
mutable std::string current_file_stream_;
private:
/*! \brief Load the i-th parameter without post-processing
*
* This function should not be called externally, as it does not
* check for post-processing that may be required. Callers should
* instead use the public methods `Load` or `LoadPresharded`.
*
* \param weight_index The index of the tensor to load
*
* \returns The full tensor at the specified index
*/
Tensor LoadDirect(int weight_index) const;
};
ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std::string& metadata,
std::string shard_info, ffi::Optional<ffi::Module> mod) {
if (shard_info.empty() && mod.has_value()) {
if (auto get_shard_info = (*mod)->GetFunction("get_shard_info")) {
shard_info = (*get_shard_info)().cast<ffi::String>();
}
}
ObjectPtr<ShardLoaderObj> n = ffi::make_object<ShardLoaderObj>();
n->metadata_ = TensorCacheMetadata::LoadFromStr(metadata, path_to_metadata);
n->current_file_ = nullptr;
n->param_info_.clear();
std::unordered_map<std::string, ShardInfo> shards = LoadShardInfoFromStr(shard_info);
for (const FileRecord& file_record : n->metadata_.records) {
for (const ParamRecord& param_record : file_record.records) {
const std::string& name = param_record.name;
int index = n->param_info_.size();
n->param_name_to_index_[name] = index;
ShardInfo& param_shard_info = shards[name];
for (const ShardInfo::ShardFunc& shard_func : param_shard_info.funcs) {
const std::string& func_name = shard_func.name;
if (ffi::Optional<ffi::Function> f =
mod.has_value() ? (*mod)->GetFunction(func_name, true) : std::nullopt) {
n->shard_funcs_[func_name] = *f;
} else if (const auto f = tvm::ffi::Function::GetGlobal(func_name)) {
n->shard_funcs_[func_name] = *f;
} else {
LOG(FATAL) << "ValueError: Undefined function: " << func_name;
}
}
n->param_info_.emplace_back(ParamInfo{&file_record, &param_record, param_shard_info});
}
}
return ObjectRef(std::move(n));
}
Tensor ShardLoaderObj::ApplyShardFunc(const ShardInfo::ShardFunc& shard_func,
const Tensor& param) const {
Device device = param->device;
Tensor o = Tensor::Empty(shard_func.output_info.shape, shard_func.output_info.dtype, device);
ffi::Function f = this->shard_funcs_.at(shard_func.name);
int n = static_cast<int>(shard_func.params.size());
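// Packed-call convention: [input tensor, shard_func.params..., output tensor].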
std::vector<ffi::AnyView> packed_args(n + 2);
const DLTensor* w_in = param.operator->();
const DLTensor* w_out = o.operator->();
packed_args[0] = const_cast<DLTensor*>(w_in);
for (int i = 0; i < n; ++i) {
packed_args[i + 1] = shard_func.params[i];
}
packed_args[n + 1] = const_cast<DLTensor*>(w_out);
ffi::Any rv;
f.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &rv);
return o;
}
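/*! \brief Return the path obtained by replacing the file name component of `path` with `filename`. */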
std::string GetSiblingPath(const std::string& path, const std::string& filename) {
size_t found = path.find_last_of("/\\");
if (found != std::string::npos) {
return path.substr(0, found + 1) + filename;
}
LOG(FATAL) << "ValueError: Cannot find the parent directory: " << path;
}
Tensor ShardLoaderObj::LoadParamOnWorker0(int weight_index) const {
DiscoWorker* worker = DiscoWorker::ThreadLocal();
int worker_id = worker->worker_id;
Device device = worker->default_device;
int param_index = param_name_to_index_.at("param_" + std::to_string(weight_index));
const ParamInfo& param_info = param_info_.at(param_index);
const ParamRecord* param = param_info.param;
const FileRecord* file = param_info.file;
auto load = [this, param, device, file]() {
if (file != current_file_) {
current_file_ = file;
std::string file_name = GetSiblingPath(this->metadata_.path, file->data_path);
LoadBinaryFromFile(file_name, &this->current_file_stream_);
}
return param->Load(device, &this->current_file_stream_);
};
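// Only worker 0 reads the file; the other workers return an uninitialized
// tensor with the correct shape and dtype.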
if (worker_id == 0) {
Tensor w = load();
return w;
} else {
Tensor w = Tensor::Empty(param->shape, param->dtype, device);
return w;
}
}
std::tuple<int, int> ParseParamShardingInfo(const ParamRecord* param) {
// Given a name of the form "param_<i>_shard-X-of-Y", return the pair
// {num_shards = Y, worker_id = X - 1}.
std::string name = param->name;
size_t pos1 = name.rfind("-of-");
CHECK(pos1 != std::string::npos)
<< "Attempt to read num_shards from unexpected param name: " << name;
size_t pos2 = name.rfind("_shard-", pos1 - 1);
CHECK(pos2 != std::string::npos)
<< "Attempt to read sharded worker_id from unexpected param name: " << name;
int num_shards = std::stoi(name.substr(pos1 + 4));
int worker_id = std::stoi(name.substr(pos2 + 7, pos1 - pos2 - 7)) - 1;
CHECK_GT(num_shards, 1);
CHECK_GE(worker_id, 0);
CHECK_LT(worker_id, num_shards);
return {num_shards, worker_id};
}
Tensor ShardLoaderObj::LoadDirect(int weight_index) const {
const ParamInfo& param_info = param_info_.at(weight_index);
const ParamRecord* param = param_info.param;
const FileRecord* file = param_info.file;
DiscoWorker* worker = DiscoWorker::ThreadLocal();
Device device = worker->default_device;
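// Cache the raw contents of the most recently opened file, so consecutive
// loads from the same file skip re-reading it from disk.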
if (file != current_file_) {
current_file_ = file;
std::string file_name = GetSiblingPath(this->metadata_.path, file->data_path);
LoadBinaryFromFile(file_name, &this->current_file_stream_);
}
return param->Load(device, &this->current_file_stream_);
}
Tensor ShardLoaderObj::Load(int weight_index) const {
DiscoWorker* worker = DiscoWorker::ThreadLocal();
int worker_id = worker->worker_id;
int num_shards = worker->num_workers;
Device device = worker->default_device;
const ParamInfo& param_info = param_info_.at(weight_index);
const ParamRecord* param = param_info.param;
bool needs_sharding = !param_info.shard_info.funcs.empty();
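// When sharding is required, worker 0 loads the full weight, applies the
// shard functions, and scatters one slice to each worker. Otherwise the
// full weight is broadcast to all workers.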
if (needs_sharding) {
ffi::Shape shape = param_info.shard_info.funcs.back().output_info.shape;
DataType dtype = param_info.shard_info.funcs.back().output_info.dtype;
ICHECK(shape.size() >= 1 && shape[0] == num_shards)
<< "ValueError: The first dimension of the "
<< "output shape must be equal to the "
<< "number of shards, but got: " << shape << " and num_shards = " << num_shards;
Tensor recv = Tensor::Empty(ffi::Shape(shape.begin() + 1, shape.end()), dtype, device);
if (worker_id == 0) {
Tensor w = LoadDirect(weight_index);
for (const ShardInfo::ShardFunc& shard_func : param_info.shard_info.funcs) {
w = this->ApplyShardFunc(shard_func, w);
}
ScatterFromWorker0(w, /*in_group=*/false, recv);
} else {
ScatterFromWorker0(std::nullopt, /*in_group=*/false, recv);
}
return recv;
} else {
if (worker_id == 0) {
Tensor w = LoadDirect(weight_index);
BroadcastFromWorker0(w, /*in_group=*/false, w);
return w;
} else {
Tensor w = Tensor::Empty(param->shape, param->dtype, device);
BroadcastFromWorker0(w, /*in_group=*/false, w);
return w;
}
}
}
ffi::Array<Tensor> ShardLoaderObj::LoadAll() const {
int n = static_cast<int>(param_info_.size());
ffi::Array<Tensor> shards;
shards.reserve(n);
for (int i = 0; i < n; ++i) {
std::string param_name = "param_" + std::to_string(i);
ICHECK(this->param_name_to_index_.count(param_name));
int shard_id = this->param_name_to_index_.at(param_name);
shards.push_back(this->Load(shard_id));
}
return shards;
}
Tensor ShardLoaderObj::LoadPresharded(int weight_index) const {
DiscoWorker* worker = DiscoWorker::ThreadLocal();
int worker_id = worker->worker_id;
int num_shards = worker->num_workers;
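// Pre-sharded parameters are grouped by worker: worker w's copy of weight i
// is stored at index w * num_weights + i.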
size_t num_weights = param_info_.size() / num_shards;
size_t index = worker_id * num_weights + weight_index;
CHECK(index < param_info_.size())
<< "Loading param " << weight_index << " for shard " << worker_id << " at position " << index
<< " is out of bounds for the provided ndarray chace.";
const auto& shard_info = param_info_[index];
const ParamRecord* param = shard_info.param;
const FileRecord* file = shard_info.file;
auto [p_num_shards, p_worker_id] = ParseParamShardingInfo(param);
CHECK_EQ(num_shards, p_num_shards)
<< "Runtime number of shards (" << num_shards
<< ") does not match number of compiled shards (" << p_num_shards << "): " << param->name
<< " loaded from " << file->data_path;
CHECK_EQ(worker_id, p_worker_id)
<< "Runtime worker_id (" << worker_id << ") does not match worker_id of compiled shard ("
<< p_worker_id << "): " << param->name << " loaded from " << file->data_path;
return LoadDirect(index);
}
ffi::Array<Tensor> ShardLoaderObj::LoadAllPresharded() const {
DiscoWorker* worker = DiscoWorker::ThreadLocal();
size_t worker_id = static_cast<size_t>(worker->worker_id);
size_t num_workers = static_cast<size_t>(worker->num_workers);
size_t num_params = param_info_.size() / num_workers;
ffi::Array<Tensor> params;
params.reserve(num_params);
for (size_t i_param = 0; i_param < num_params; ++i_param) {
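// Pre-sharded weights are named "param_<i>_shard-<worker_id + 1>-of-<num_workers>".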
std::string param_name = "param_" + std::to_string(i_param) + "_shard-" +
std::to_string(worker_id + 1) + "-of-" + std::to_string(num_workers);
auto it = param_name_to_index_.find(param_name);
CHECK(it != param_name_to_index_.end())
<< "Parameter " << param_name << " was not found in the parameter set";
params.push_back(this->LoadDirect(it->second));
}
return params;
}
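// Register the shard loader entry points with the global FFI registry.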
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("runtime.disco.ShardLoader", ShardLoaderObj::Create)
.def("runtime.disco.ShardLoaderLoad",
[](ObjectRef loader_obj, ffi::Shape weight_index) {
const auto* loader = loader_obj.as<ShardLoaderObj>();
CHECK(loader != nullptr)
<< "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey();
return loader->Load(IntegerFromShape(weight_index));
})
.def("runtime.disco.ShardLoaderLoadPresharded",
[](ObjectRef loader_obj, ffi::Shape weight_index) {
const auto* loader = loader_obj.as<ShardLoaderObj>();
CHECK(loader != nullptr)
<< "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey();
return loader->LoadPresharded(IntegerFromShape(weight_index));
})
.def("runtime.disco.ShardLoaderLoadAll",
[](ObjectRef loader_obj) {
const auto* loader = loader_obj.as<ShardLoaderObj>();
CHECK(loader != nullptr)
<< "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey();
return loader->LoadAll();
})
.def("runtime.disco.ShardLoaderLoadAllPresharded",
[](ObjectRef loader_obj) {
const auto* loader = loader_obj.as<ShardLoaderObj>();
CHECK(loader != nullptr)
<< "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey();
return loader->LoadAllPresharded();
})
.def("runtime.disco.ShardLoaderLoadParamOnWorker0",
[](ObjectRef loader_obj, int param_index) {
const auto* loader = loader_obj.as<ShardLoaderObj>();
CHECK(loader != nullptr)
<< "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey();
return loader->LoadParamOnWorker0(param_index);
});
}
} // namespace runtime
} // namespace tvm