| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| #ifndef MXNET_STORAGE_CPU_SHARED_STORAGE_MANAGER_H_ |
| #define MXNET_STORAGE_CPU_SHARED_STORAGE_MANAGER_H_ |
| |
| #if !defined(ANDROID) && !defined(__ANDROID__) |
| |
#ifndef _WIN32
#include <sys/mman.h>
#include <sys/fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#else
#include <Windows.h>
#include <process.h>
#endif // _WIN32

#include <atomic>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <limits>
#include <mutex>
#include <new>
#include <random>
#include <sstream>
#include <string>
#include <unordered_map>
#include "./storage_manager.h"
| |
| namespace mxnet { |
| namespace storage { |
| /*! |
| * \brief Storage manager for cpu shared memory |
| */ |
| class CPUSharedStorageManager final : public StorageManager { |
| public: |
| /*! |
| * \brief Default constructor. |
| */ |
| CPUSharedStorageManager() : rand_gen_(std::random_device()()) {} |
| /*! |
| * \brief Default destructor. |
| */ |
| ~CPUSharedStorageManager() { |
| for (const auto& kv : pool_) { |
| FreeImpl(kv.second); |
| } |
| #ifdef _WIN32 |
| CheckAndRealFree(); |
| #endif |
| } |
| |
| void Alloc(Storage::Handle* handle, bool failsafe) override; |
| void Free(Storage::Handle handle) override { |
| std::lock_guard<std::recursive_mutex> lock(mutex_); |
| pool_.erase(handle.dptr); |
| FreeImpl(handle); |
| } |
| |
| void DirectFree(Storage::Handle handle) override { |
| Free(handle); |
| } |
| |
| void IncrementRefCount(const Storage::Handle& handle) { |
| std::atomic<int>* counter = |
| reinterpret_cast<std::atomic<int>*>(static_cast<char*>(handle.dptr) - alignment_); |
| ++(*counter); |
| } |
| |
| int DecrementRefCount(const Storage::Handle& handle) { |
| std::atomic<int>* counter = |
| reinterpret_cast<std::atomic<int>*>(static_cast<char*>(handle.dptr) - alignment_); |
| return --(*counter); |
| } |
| |
| private: |
| static constexpr size_t alignment_ = 16; |
| |
| std::recursive_mutex mutex_; |
| std::mt19937 rand_gen_; |
| std::unordered_map<void*, Storage::Handle> pool_; |
| #ifdef _WIN32 |
| std::unordered_map<void*, Storage::Handle> is_free_; |
| std::unordered_map<void*, HANDLE> map_handle_map_; |
| #endif |
| |
| void FreeImpl(const Storage::Handle& handle); |
| #ifdef _WIN32 |
| void CheckAndRealFree(); |
| #endif |
| |
| std::string SharedHandleToString(int shared_pid, int shared_id) { |
| std::stringstream name; |
| name << "/mx_" << std::hex << shared_pid << "_" << std::hex << shared_id; |
| return name.str(); |
| } |
| DISALLOW_COPY_AND_ASSIGN(CPUSharedStorageManager); |
| }; // class CPUSharedStorageManager |
| |
| void CPUSharedStorageManager::Alloc(Storage::Handle* handle, bool /* failsafe */) { |
| std::lock_guard<std::recursive_mutex> lock(mutex_); |
| std::uniform_int_distribution<> dis(0, std::numeric_limits<int>::max()); |
| int fid = -1; |
| std::string filename; |
| bool is_new = false; |
| size_t size = handle->size + alignment_; |
| void* ptr = nullptr; |
| #ifdef _WIN32 |
| CheckAndRealFree(); |
| HANDLE map_handle = nullptr; |
| uint32_t error = 0; |
| if (handle->shared_id == -1 && handle->shared_pid == -1) { |
| is_new = true; |
| handle->shared_pid = _getpid(); |
| for (int i = 0; i < 10; ++i) { |
| handle->shared_id = dis(rand_gen_); |
| filename = SharedHandleToString(handle->shared_pid, handle->shared_id); |
| map_handle = CreateFileMapping( |
| INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, 0, size, filename.c_str()); |
| if ((error = GetLastError()) == ERROR_SUCCESS) { |
| break; |
| } |
| } |
| } else { |
| filename = SharedHandleToString(handle->shared_pid, handle->shared_id); |
| map_handle = OpenFileMapping(FILE_MAP_READ | FILE_MAP_WRITE, FALSE, filename.c_str()); |
| error = GetLastError(); |
| } |
| |
| if (error != ERROR_SUCCESS && map_handle == nullptr) { |
| LOG(FATAL) << "Failed to open shared memory. CreateFileMapping failed with error " << error; |
| } |
| |
| ptr = MapViewOfFile(map_handle, FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0); |
| CHECK_NE(ptr, (void*)0) << "Failed to map shared memory. MapViewOfFile failed with error " |
| << GetLastError(); |
| map_handle_map_[ptr] = map_handle; |
| #else |
| if (handle->shared_id == -1 && handle->shared_pid == -1) { |
| is_new = true; |
| handle->shared_pid = getpid(); |
| for (int i = 0; i < 10; ++i) { |
| handle->shared_id = dis(rand_gen_); |
| filename = SharedHandleToString(handle->shared_pid, handle->shared_id); |
| fid = shm_open(filename.c_str(), O_EXCL | O_CREAT | O_RDWR, 0666); |
| if (fid != -1) |
| break; |
| } |
| } else { |
| #ifdef __linux__ |
| fid = handle->shared_id; |
| #else |
| filename = SharedHandleToString(handle->shared_pid, handle->shared_id); |
| fid = shm_open(filename.c_str(), O_RDWR, 0666); |
| #endif // __linux__ |
| } |
| |
| if (fid == -1) { |
| if (is_new) { |
| LOG(FATAL) << "Failed to open shared memory. shm_open failed with error " << strerror(errno); |
| } else { |
| LOG(FATAL) << "Invalid file descriptor from shared array."; |
| } |
| } |
| |
| if (is_new) |
| CHECK_EQ(ftruncate(fid, size), 0); |
| |
| ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fid, 0); |
| CHECK_NE(ptr, MAP_FAILED) << "Failed to map shared memory. mmap failed with error " |
| << strerror(errno); |
| #ifdef __linux__ |
| handle->shared_id = fid; |
| if (is_new) { |
| CHECK_EQ(shm_unlink(filename.c_str()), 0) |
| << "Failed to unlink shared memory. shm_unlink failed with error " << strerror(errno); |
| } |
| #else |
| CHECK_EQ(close(fid), 0) << "Failed to close shared memory. close failed with error " |
| << strerror(errno); |
| #endif // __linux__ |
| #endif // _WIN32 |
| |
| if (is_new) { |
| new (ptr) std::atomic<int>(1); |
| } |
| handle->dptr = static_cast<char*>(ptr) + alignment_; |
| pool_[handle->dptr] = *handle; |
| } |
| |
| void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) { |
| int count = DecrementRefCount(handle); |
| CHECK_GE(count, 0); |
| #ifdef _WIN32 |
| is_free_[handle.dptr] = handle; |
| #else |
| CHECK_EQ(munmap(static_cast<char*>(handle.dptr) - alignment_, handle.size + alignment_), 0) |
| << "Failed to unmap shared memory. munmap failed with error " << strerror(errno); |
| |
| #ifdef __linux__ |
| if (handle.shared_id != -1) { |
| CHECK_EQ(close(handle.shared_id), 0) |
| << "Failed to close shared memory. close failed with error " << strerror(errno); |
| } |
| #else |
| if (count == 0) { |
| auto filename = SharedHandleToString(handle.shared_pid, handle.shared_id); |
| CHECK_EQ(shm_unlink(filename.c_str()), 0) |
| << "Failed to unlink shared memory. shm_unlink failed with error " << strerror(errno); |
| } |
| #endif // __linux__ |
| #endif // _WIN32 |
| } |
| |
#ifdef _WIN32
/*!
 * \brief Unmap and close every deferred view whose shared reference counter has reached zero.
 *
 * Entries whose counter is still non-zero stay in \c is_free_ and are re-examined on the next
 * call.
 */
inline void CPUSharedStorageManager::CheckAndRealFree() {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  for (auto it = std::begin(is_free_); it != std::end(is_free_);) {
    // The mapped view begins alignment_ bytes before dptr; the counter lives at that address.
    void* ptr = static_cast<char*>(it->second.dptr) - alignment_;
    std::atomic<int>* counter = reinterpret_cast<std::atomic<int>*>(ptr);
    if (counter->load() != 0) {
      ++it;
      continue;
    }
    CHECK_NE(UnmapViewOfFile(ptr), 0) << "Failed to UnmapViewOfFile shared memory ";
    // Use find() rather than operator[]. operator[] would insert a null HANDLE for an unknown
    // view and then hand it to CloseHandle.
    auto handle_it = map_handle_map_.find(ptr);
    CHECK(handle_it != map_handle_map_.end())
        << "No file mapping handle recorded for shared memory view";
    CHECK_NE(CloseHandle(handle_it->second), 0) << "Failed to CloseHandle shared memory ";
    map_handle_map_.erase(handle_it);
    it = is_free_.erase(it);
  }
}
#endif  // _WIN32
| } // namespace storage |
| } // namespace mxnet |
| |
| #endif // !defined(ANDROID) && !defined(__ANDROID__) |
| |
| #endif // MXNET_STORAGE_CPU_SHARED_STORAGE_MANAGER_H_ |