blob: f95262424d31bd19f3ac276506780fbb443c9043 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file device_api.cc
* \brief Device specific implementations
*/
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/rvalue_ref.h>
#include <tvm/ffi/string.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/module.h>
#include <algorithm>
#include <array>
#include <cctype>
#include <cstdlib>
#include <mutex>
#include <sstream>
#include <string>
#include <tuple>
#include <variant>
namespace tvm {
namespace runtime {
/*!
 * \brief Registry that lazily resolves DeviceAPI instances and caches them.
 *
 * Device types below kRPCSessMask are cached per type in a fixed-size table;
 * every other device type shares a single cached "rpc" API. Instances are
 * obtained from global functions named "device_api.<name>".
 */
class DeviceAPIManager {
 public:
  /*! \brief Number of slots in the per-device-type cache. */
  static const int kMaxDeviceAPI = TVMDeviceExtType_End;
  /*!
   * \brief Get the API for a device; a missing API is reported as an error.
   * \param dev The device whose device_type selects the API.
   */
  static DeviceAPI* Get(const Device& dev) { return Get(dev.device_type); }
  /*!
   * \brief Get the API for a device type code.
   * \param dev_type The device type code.
   * \param allow_missing When true, nullptr is returned instead of failing
   *        if no API can be found for the type.
   */
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  // Cache indexed by device type; nullptr means "not resolved yet".
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  // Cache for the API shared by all device types >= kRPCSessMask.
  DeviceAPI* rpc_api_{nullptr};
  // Guards writes to api_ and rpc_api_.
  std::mutex mutex_;
  // constructor
  DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
  // Global static variable; allocated once and not deleted in this file.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager* inst = new DeviceAPIManager();
    return inst;
  }
  // Get or initialize API.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      // The cache only has kMaxDeviceAPI slots; a negative type or one in
      // [kMaxDeviceAPI, kRPCSessMask) must not be used as an index.
      if (type < 0 || type >= kMaxDeviceAPI) {
        TVM_FFI_ICHECK(allow_missing)
            << "Invalid device type " << type << ": no Device API can be registered for it.";
        return nullptr;
      }
      // NOTE(review): this first read happens before the mutex is taken and is
      // therefore not synchronized with the write below — confirm whether an
      // atomic slot is wanted here.
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DLDeviceType2Str(type), allow_missing);
      return api_[type];
    } else {
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }
  // Resolve the API through the global function "device_api.<name>".
  // Returns nullptr when the function is absent and allow_missing is true;
  // otherwise an absent function fails the ICHECK.
  DeviceAPI* GetAPI(const std::string& name, bool allow_missing) {
    std::string factory = "device_api." + name;
    const auto f = tvm::ffi::Function::GetGlobal(factory);
    if (!f.has_value()) {
      TVM_FFI_ICHECK(allow_missing) << "Device API " << name << " is not enabled.";
      return nullptr;
    }
    // The factory hands back the API as an opaque pointer.
    void* ptr = (*f)().cast<void*>();
    return static_cast<DeviceAPI*>(ptr);
  }
};
/*!
 * \brief Look up the DeviceAPI that serves the given device.
 * \param dev The device; only its device_type is used for the lookup.
 * \param allow_missing Forwarded unchanged to DeviceAPIManager::Get.
 * \return Whatever DeviceAPIManager::Get returns for the device type.
 */
DeviceAPI* DeviceAPI::Get(Device dev, bool allow_missing) {
  const int type_code = static_cast<int>(dev.device_type);
  return DeviceAPIManager::Get(type_code, allow_missing);
}
/*!
 * \brief Default workspace allocation: forwards to AllocDataSpace.
 * \param dev The device to allocate on.
 * \param size Number of bytes requested.
 * \param type_hint Data type hint passed through to AllocDataSpace.
 * \return The pointer returned by AllocDataSpace.
 */
void* DeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
  // Workspaces always use the temporary-allocation alignment.
  void* workspace = AllocDataSpace(dev, size, kTempAllocaAlignment, type_hint);
  return workspace;
}
/*!
 * \brief Compute the alignment to request for data of the given type.
 * \param dtype The element data type.
 * \return (bits / 8) * lanes, raised to kAllocAlignment when smaller.
 */
static size_t GetDataAlignment(const DLDataType dtype) {
  const size_t natural = (dtype.bits / 8) * dtype.lanes;
  // Never go below the allocator-wide minimum alignment.
  return natural < kAllocAlignment ? kAllocAlignment : natural;
}
/*!
 * \brief Compute the size of a tensor's data for the default memory scope.
 * \param arr The tensor; its shape, ndim and dtype are read.
 * \param mem_scope Optional memory scope; only unset, empty or "global" is handled.
 * \return ffi::GetDataSize applied to the product of the shape extents and the dtype.
 * \note Any other memory scope raises InternalError.
 */
size_t DeviceAPI::GetDataSize(const DLTensor& arr, ffi::Optional<ffi::String> mem_scope) {
  const bool flat_scope =
      !mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global";
  if (!flat_scope) {
    TVM_FFI_THROW(InternalError) << "Device does not support physical mem computation with "
                                 << "specified memory scope: " << mem_scope.value();
    return 0;
  }
  // Multiply all extents together to obtain the element count.
  size_t num_elems = 1;
  int axis = 0;
  while (axis < arr.ndim) {
    num_elems *= static_cast<size_t>(arr.shape[axis]);
    ++axis;
  }
  return ffi::GetDataSize(num_elems, arr.dtype);
}
/*!
 * \brief Allocate space for a tensor described by shape and dtype.
 * \param dev The device to allocate on.
 * \param ndim Number of dimensions in shape.
 * \param shape Extents of each dimension.
 * \param dtype Element data type.
 * \param mem_scope Optional memory scope; only unset, empty or "global" is handled.
 * \return The pointer returned by the flat (size, alignment) AllocDataSpace overload.
 * \note Any other memory scope raises InternalError.
 */
void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
                                ffi::Optional<ffi::String> mem_scope) {
  const bool flat_scope =
      !mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global";
  if (!flat_scope) {
    TVM_FFI_THROW(InternalError) << "Device does not support allocate data space with "
                                 << "specified memory scope: " << mem_scope.value();
    return nullptr;
  }
  // Describe the request as a data-less tensor so GetDataSize can be reused,
  // then redirect to the flat memory allocation.
  DLTensor flat_view;
  flat_view.device = dev;
  flat_view.dtype = dtype;
  flat_view.ndim = ndim;
  flat_view.shape = const_cast<int64_t*>(shape);
  flat_view.strides = nullptr;
  flat_view.byte_offset = 0;
  flat_view.data = nullptr;
  const size_t nbytes = GetDataSize(flat_view);
  const size_t align = GetDataAlignment(flat_view.dtype);
  return AllocDataSpace(dev, nbytes, align, dtype);
}
/*!
 * \brief Copy one tensor into another through the flat byte-copy overload.
 * \param from Source tensor.
 * \param to Destination tensor.
 * \param stream Stream handle forwarded to the byte-copy overload.
 * \note Checks that both tensors have the same data size and are contiguous.
 */
void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
  const DLTensor& src = *from;
  const DLTensor& dst = *to;
  // by default, we can always redirect to the flat memory copy operation.
  const size_t nbytes = GetDataSize(src);
  TVM_FFI_ICHECK_EQ(nbytes, GetDataSize(dst));
  TVM_FFI_ICHECK(ffi::IsContiguous(src) && ffi::IsContiguous(dst))
      << "CopyDataFromTo only support contiguous array for now";
  CopyDataFromTo(src.data, src.byte_offset, dst.data, dst.byte_offset, nbytes, src.device,
                 dst.device, src.dtype, stream);
}
/*!
 * \brief Default byte-level copy: always raises InternalError.
 *
 * None of the arguments are inspected by this default implementation.
 */
void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
                               size_t num_bytes, Device dev_from, Device dev_to,
                               DLDataType type_hint, TVMStreamHandle stream) {
  TVM_FFI_THROW(InternalError) << "Device does not support CopyDataFromTo.";
}
/*!
 * \brief Default workspace release: forwards to FreeDataSpace.
 * \param dev The device the workspace belongs to.
 * \param ptr The workspace pointer to release.
 */
void DeviceAPI::FreeWorkspace(Device dev, void* ptr) {
  FreeDataSpace(dev, ptr);
}
/*!
 * \brief Default stream creation: creates nothing.
 * \param dev Unused by this default implementation.
 * \return Always nullptr.
 */
TVMStreamHandle DeviceAPI::CreateStream(Device dev) {
  return nullptr;
}
/*!
 * \brief Default stream release: does nothing.
 * \param dev Unused by this default implementation.
 * \param stream Unused by this default implementation.
 */
void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {
}
/*!
 * \brief Set the stream for a device through TVMFFIEnvSetStream.
 * \param dev The device whose device_type and device_id are passed on.
 * \param stream The stream handle to set.
 */
void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
  // The last argument is passed as nullptr; the call's status is checked by the macro.
  TVM_FFI_CHECK_SAFE_CALL(
      TVMFFIEnvSetStream(dev.device_type, dev.device_id, stream, nullptr));
}
/*!
 * \brief Query the stream for a device through TVMFFIEnvGetStream.
 * \param dev The device whose device_type and device_id are passed on.
 * \return The handle returned by TVMFFIEnvGetStream.
 */
TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) {
  TVMStreamHandle current = TVMFFIEnvGetStream(dev.device_type, dev.device_id);
  return current;
}
/*!
 * \brief Default stream-to-stream synchronization: does nothing.
 * \param dev Unused by this default implementation.
 * \param event_src Unused by this default implementation.
 * \param event_dst Unused by this default implementation.
 */
void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src,
                                 TVMStreamHandle event_dst) {
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("runtime.Device_StreamCreate",
[](DLDevice dev) {
return reinterpret_cast<int64_t>(DeviceAPIManager::Get(dev)->CreateStream(dev));
})
.def("runtime.Device_StreamFree",
[](DLDevice dev, int64_t stream) {
DeviceAPIManager::Get(dev)->FreeStream(dev, reinterpret_cast<TVMStreamHandle>(stream));
})
.def("runtime.Device_SetStream",
[](DLDevice dev, int64_t stream) {
DeviceAPIManager::Get(dev)->SetStream(dev, reinterpret_cast<TVMStreamHandle>(stream));
})
.def("runtime.Device_StreamSync",
[](DLDevice dev, int64_t stream) {
DeviceAPIManager::Get(dev)->StreamSync(dev, reinterpret_cast<TVMStreamHandle>(stream));
})
.def("runtime.Device_StreamSyncFromTo", [](DLDevice dev, int64_t src, int64_t dst) {
DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, reinterpret_cast<TVMStreamHandle>(src),
reinterpret_cast<TVMStreamHandle>(dst));
});
}
// set device api
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed(tvm::runtime::symbol::tvm_set_device,
[](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) {
DLDevice dev;
dev.device_type = static_cast<DLDeviceType>(args[0].cast<int>());
dev.device_id = args[1].cast<int>();
DeviceAPIManager::Get(dev)->SetDevice(dev);
})
.def_packed("runtime.GetDeviceAttr",
[](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) {
DLDevice dev;
dev.device_type = static_cast<DLDeviceType>(args[0].cast<int>());
dev.device_id = args[1].cast<int>();
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].cast<int>());
if (kind == kExist) {
DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true);
if (api != nullptr) {
api->GetAttr(dev, kind, ret);
} else {
*ret = 0;
}
} else {
DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret);
}
})
.def("runtime.TVMSetStream", [](int device_type, int device_id, void* stream) {
Device dev;
dev.device_type = static_cast<DLDeviceType>(device_type);
dev.device_id = device_id;
DeviceAPIManager::Get(dev)->SetStream(dev, stream);
});
}
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
/*!
 * \brief Look up a function by name via TVMFFIEnvModLookupFromImports.
 * \param mod_node The module handle passed through to the lookup.
 * \param func_name Name of the function to look up.
 * \param func Output location passed through to the lookup.
 * \return The status code returned by TVMFFIEnvModLookupFromImports.
 */
int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFFIObjectHandle* func) {
  const int status = TVMFFIEnvModLookupFromImports(mod_node, func_name, func);
  return status;
}
/*!
 * \brief Allocate a workspace on the given device.
 * \param device_type Device type code.
 * \param device_id Device id.
 * \param size Number of bytes requested.
 * \param dtype_code_hint Type code stored in the type hint.
 * \param dtype_bits_hint Bit width stored in the type hint.
 * \return The pointer returned by the device API's AllocWorkspace.
 */
void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint,
                               int dtype_bits_hint) {
  // The type hint always describes a single-lane element.
  DLDataType type_hint;
  type_hint.lanes = 1;
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  DLDevice dev{static_cast<DLDeviceType>(device_type), device_id};
  const size_t nbytes = static_cast<size_t>(size);
  return DeviceAPIManager::Get(dev)->AllocWorkspace(dev, nbytes, type_hint);
}
/*!
 * \brief Release a workspace on the given device.
 * \param device_type Device type code.
 * \param device_id Device id.
 * \param ptr The workspace pointer handed to the device API's FreeWorkspace.
 * \return Always 0.
 */
int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
  DLDevice dev{static_cast<DLDeviceType>(device_type), device_id};
  DeviceAPIManager::Get(dev)->FreeWorkspace(dev, ptr);
  return 0;
}
/*!
 * \brief Invoke a callback only if the handle has not been marked yet.
 * \param handle In/out marker; nullptr means the callback has not been run.
 * \param f The callback to invoke.
 * \param cdata Argument passed to the callback.
 * \param nbytes Not used by this implementation.
 * \return The callback's return value on the first call, 0 once the handle is marked.
 */
int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
  // A non-null handle records that the callback was already dispatched.
  if (*handle != nullptr) {
    return 0;
  }
  // Mark the handle before running the callback.
  *handle = reinterpret_cast<void*>(1);
  return f(cdata);
}