| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file device_api.cc |
| * \brief Device specific implementations |
| */ |
| #include <tvm/ffi/container/tensor.h> |
| #include <tvm/ffi/extra/c_env_api.h> |
| #include <tvm/ffi/function.h> |
| #include <tvm/ffi/optional.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/ffi/rvalue_ref.h> |
| #include <tvm/ffi/string.h> |
| #include <tvm/runtime/base.h> |
| #include <tvm/runtime/c_backend_api.h> |
| #include <tvm/runtime/device_api.h> |
| #include <tvm/runtime/module.h> |
| |
| #include <algorithm> |
| #include <array> |
| #include <cctype> |
| #include <cstdlib> |
| #include <mutex> |
| #include <sstream> |
| #include <string> |
| #include <tuple> |
| #include <variant> |
| |
| namespace tvm { |
| namespace runtime { |
| |
| class DeviceAPIManager { |
| public: |
| static const int kMaxDeviceAPI = TVMDeviceExtType_End; |
| // Get API |
| static DeviceAPI* Get(const Device& dev) { return Get(dev.device_type); } |
| static DeviceAPI* Get(int dev_type, bool allow_missing = false) { |
| return Global()->GetAPI(dev_type, allow_missing); |
| } |
| |
| private: |
| std::array<DeviceAPI*, kMaxDeviceAPI> api_; |
| DeviceAPI* rpc_api_{nullptr}; |
| std::mutex mutex_; |
| // constructor |
| DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } |
| // Global static variable. |
| static DeviceAPIManager* Global() { |
| static DeviceAPIManager* inst = new DeviceAPIManager(); |
| return inst; |
| } |
| // Get or initialize API. |
| DeviceAPI* GetAPI(int type, bool allow_missing) { |
| if (type < kRPCSessMask) { |
| if (api_[type] != nullptr) return api_[type]; |
| std::lock_guard<std::mutex> lock(mutex_); |
| if (api_[type] != nullptr) return api_[type]; |
| api_[type] = GetAPI(DLDeviceType2Str(type), allow_missing); |
| return api_[type]; |
| } else { |
| if (rpc_api_ != nullptr) return rpc_api_; |
| std::lock_guard<std::mutex> lock(mutex_); |
| if (rpc_api_ != nullptr) return rpc_api_; |
| rpc_api_ = GetAPI("rpc", allow_missing); |
| return rpc_api_; |
| } |
| } |
| DeviceAPI* GetAPI(const std::string name, bool allow_missing) { |
| std::string factory = "device_api." + name; |
| const auto f = tvm::ffi::Function::GetGlobal(factory); |
| if (!f.has_value()) { |
| TVM_FFI_ICHECK(allow_missing) << "Device API " << name << " is not enabled."; |
| return nullptr; |
| } |
| void* ptr = (*f)().cast<void*>(); |
| return static_cast<DeviceAPI*>(ptr); |
| } |
| }; |
| |
| DeviceAPI* DeviceAPI::Get(Device dev, bool allow_missing) { |
| return DeviceAPIManager::Get(static_cast<int>(dev.device_type), allow_missing); |
| } |
| |
| void* DeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { |
| return AllocDataSpace(dev, size, kTempAllocaAlignment, type_hint); |
| } |
| |
| static size_t GetDataAlignment(const DLDataType dtype) { |
| size_t align = (dtype.bits / 8) * dtype.lanes; |
| if (align < kAllocAlignment) return kAllocAlignment; |
| return align; |
| } |
| |
| size_t DeviceAPI::GetDataSize(const DLTensor& arr, ffi::Optional<ffi::String> mem_scope) { |
| if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { |
| size_t size = 1; |
| for (int i = 0; i < arr.ndim; ++i) { |
| size *= static_cast<size_t>(arr.shape[i]); |
| } |
| return ffi::GetDataSize(size, arr.dtype); |
| } |
| TVM_FFI_THROW(InternalError) << "Device does not support physical mem computation with " |
| << "specified memory scope: " << mem_scope.value(); |
| return 0; |
| } |
| |
| void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, |
| ffi::Optional<ffi::String> mem_scope) { |
| if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { |
| // by default, we can always redirect to the flat memory allocations |
| DLTensor temp; |
| temp.data = nullptr; |
| temp.device = dev; |
| temp.ndim = ndim; |
| temp.dtype = dtype; |
| temp.shape = const_cast<int64_t*>(shape); |
| temp.strides = nullptr; |
| temp.byte_offset = 0; |
| size_t size = GetDataSize(temp); |
| size_t alignment = GetDataAlignment(temp.dtype); |
| return AllocDataSpace(dev, size, alignment, dtype); |
| } |
| TVM_FFI_THROW(InternalError) << "Device does not support allocate data space with " |
| << "specified memory scope: " << mem_scope.value(); |
| return nullptr; |
| } |
| |
| void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { |
| // by default, we can always redirect to the flat memory copy operation. |
| size_t nbytes = GetDataSize(*from); |
| TVM_FFI_ICHECK_EQ(nbytes, GetDataSize(*to)); |
| |
| TVM_FFI_ICHECK(ffi::IsContiguous(*from) && ffi::IsContiguous(*to)) |
| << "CopyDataFromTo only support contiguous array for now"; |
| CopyDataFromTo(from->data, from->byte_offset, to->data, to->byte_offset, nbytes, from->device, |
| to->device, from->dtype, stream); |
| } |
| |
| void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, |
| size_t num_bytes, Device dev_from, Device dev_to, |
| DLDataType type_hint, TVMStreamHandle stream) { |
| TVM_FFI_THROW(InternalError) << "Device does not support CopyDataFromTo."; |
| } |
| |
| void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); } |
| |
| TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } |
| |
| void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} |
| |
| void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { |
| TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(dev.device_type, dev.device_id, stream, nullptr)); |
| } |
| |
| TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { |
| return TVMFFIEnvGetStream(dev.device_type, dev.device_id); |
| } |
| |
| void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("runtime.Device_StreamCreate", |
| [](DLDevice dev) { |
| return reinterpret_cast<int64_t>(DeviceAPIManager::Get(dev)->CreateStream(dev)); |
| }) |
| .def("runtime.Device_StreamFree", |
| [](DLDevice dev, int64_t stream) { |
| DeviceAPIManager::Get(dev)->FreeStream(dev, reinterpret_cast<TVMStreamHandle>(stream)); |
| }) |
| .def("runtime.Device_SetStream", |
| [](DLDevice dev, int64_t stream) { |
| DeviceAPIManager::Get(dev)->SetStream(dev, reinterpret_cast<TVMStreamHandle>(stream)); |
| }) |
| .def("runtime.Device_StreamSync", |
| [](DLDevice dev, int64_t stream) { |
| DeviceAPIManager::Get(dev)->StreamSync(dev, reinterpret_cast<TVMStreamHandle>(stream)); |
| }) |
| .def("runtime.Device_StreamSyncFromTo", [](DLDevice dev, int64_t src, int64_t dst) { |
| DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, reinterpret_cast<TVMStreamHandle>(src), |
| reinterpret_cast<TVMStreamHandle>(dst)); |
| }); |
| } |
| |
| // set device api |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def_packed(tvm::runtime::symbol::tvm_set_device, |
| [](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { |
| DLDevice dev; |
| dev.device_type = static_cast<DLDeviceType>(args[0].cast<int>()); |
| dev.device_id = args[1].cast<int>(); |
| DeviceAPIManager::Get(dev)->SetDevice(dev); |
| }) |
| .def_packed("runtime.GetDeviceAttr", |
| [](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { |
| DLDevice dev; |
| dev.device_type = static_cast<DLDeviceType>(args[0].cast<int>()); |
| dev.device_id = args[1].cast<int>(); |
| |
| DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].cast<int>()); |
| if (kind == kExist) { |
| DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true); |
| if (api != nullptr) { |
| api->GetAttr(dev, kind, ret); |
| } else { |
| *ret = 0; |
| } |
| } else { |
| DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret); |
| } |
| }) |
| .def("runtime.TVMSetStream", [](int device_type, int device_id, void* stream) { |
| Device dev; |
| dev.device_type = static_cast<DLDeviceType>(device_type); |
| dev.device_id = device_id; |
| DeviceAPIManager::Get(dev)->SetStream(dev, stream); |
| }); |
| } |
| } // namespace runtime |
| } // namespace tvm |
| |
| using namespace tvm::runtime; |
| |
| int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFFIObjectHandle* func) { |
| return TVMFFIEnvModLookupFromImports(mod_node, func_name, func); |
| } |
| |
| void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, |
| int dtype_bits_hint) { |
| DLDevice dev; |
| dev.device_type = static_cast<DLDeviceType>(device_type); |
| dev.device_id = device_id; |
| |
| DLDataType type_hint; |
| type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint); |
| type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint); |
| type_hint.lanes = 1; |
| |
| return DeviceAPIManager::Get(dev)->AllocWorkspace(dev, static_cast<size_t>(size), type_hint); |
| } |
| |
| int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { |
| DLDevice dev; |
| dev.device_type = static_cast<DLDeviceType>(device_type); |
| dev.device_id = device_id; |
| DeviceAPIManager::Get(dev)->FreeWorkspace(dev, ptr); |
| return 0; |
| } |
| |
/*!
 * \brief Invoke \p f at most once per handle.
 * \param handle Slot recording whether the call already happened; a null
 *        value means "not yet".
 * \param f Callback to invoke.
 * \param cdata Argument passed to \p f.
 * \param nbytes Not used by this implementation.
 * \return The result of \p f on the first call, 0 on every later call.
 */
int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
  // A non-null slot means the callback has already been dispatched.
  if (*handle != nullptr) {
    return 0;
  }
  // Mark the slot before calling, so the slot stays marked whatever f returns.
  *handle = reinterpret_cast<void*>(1);
  return f(cdata);
}