/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file src/runtime/memory/pooled_allocator.h
 * \brief A memory allocator that recycles freed buffers through a size-keyed pool.
 */
#ifndef TVM_RUNTIME_MEMORY_POOLED_ALLOCATOR_H_
#define TVM_RUNTIME_MEMORY_POOLED_ALLOCATOR_H_

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/memory/memory_manager.h>

#include <atomic>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace runtime {
namespace memory {

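/*!
 * \brief A pooling allocator that caches freed buffers, keyed by their
 *  page-rounded size, and hands them back for later allocations of the same
 *  rounded size instead of returning the memory to the device.
 */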
class PooledAllocator : public Allocator {
 public:
  static constexpr size_t kDefaultPageSize = 4096;

  explicit PooledAllocator(size_t page_size = kDefaultPageSize)
      : Allocator(kPooled), page_size_(page_size), used_memory_(0) {}

  ~PooledAllocator() { ReleaseAll(); }
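
  /*!
   * \brief Allocate a buffer, preferring a cached one.
   *
   * The request is rounded up to a multiple of page_size_. If the pool holds
   * a free buffer of that rounded size it is reused; otherwise fresh device
   * memory is allocated. On an allocation failure the pool is drained once
   * via ReleaseAll() and the allocation is retried.
   */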
  Buffer Alloc(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) override {
    std::lock_guard<std::recursive_mutex> lock(mu_);
    size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_;
    auto&& it = memory_pool_.find(size);
    if (it != memory_pool_.end() && !it->second.empty()) {
      auto&& pool = it->second;
      auto ret = pool.back();
      pool.pop_back();
      return ret;
    }
    Buffer buf;
    buf.device = dev;
    buf.size = size;
    buf.alloc_type = kPooled;
    try {
      buf.data = DeviceAllocDataSpace(dev, size, alignment, type_hint);
    } catch (InternalError& err) {
      LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.what();
      LOG(WARNING) << "Trying to release all unused memory and reallocate...";
      ReleaseAll();
      buf.data = DeviceAllocDataSpace(dev, size, alignment, type_hint);
    }
    used_memory_.fetch_add(size, std::memory_order_relaxed);
    VLOG(1) << "allocate " << size << " B, used memory " << used_memory_ << " B";
    return buf;
  }
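
  /*!
   * \brief Shape-based allocation. Only memory scopes accepted by
   *  AllowMemoryScope() are supported; any other scope is a fatal error.
   */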
  Buffer Alloc(Device dev, ffi::Shape shape, DLDataType type_hint,
               const std::string& mem_scope) override {
    if (AllowMemoryScope(mem_scope)) {
      return Allocator::Alloc(dev, shape, type_hint, mem_scope);
    }
    LOG(FATAL) << "PooledAllocator does not support memory scope: " << mem_scope;
    return {};
  }
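
  /*!
   * \brief Return a buffer to the size-keyed pool for reuse; the underlying
   *  device memory stays resident until Clear() or destruction.
   */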
  void Free(const Buffer& buffer) override {
    std::lock_guard<std::recursive_mutex> lock(mu_);
    if (memory_pool_.find(buffer.size) == memory_pool_.end()) {
      memory_pool_.emplace(buffer.size, std::vector<Buffer>{});
    }
    memory_pool_.at(buffer.size).push_back(buffer);
    VLOG(1) << "reclaim buffer " << buffer.size;
  }
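
  //! Release every pooled buffer back to the device.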
  void Clear() override { ReleaseAll(); }

  size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); }

 protected:
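  // Device-facing hooks, virtual so that subclasses or tests can redirect
  // the raw DeviceAPI calls.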
  virtual void* DeviceAllocDataSpace(Device dev, size_t nbytes, size_t alignment,
                                     DLDataType type_hint) {
    return DeviceAPI::Get(dev)->AllocDataSpace(dev, nbytes, alignment, type_hint);
  }

  virtual void DeviceFreeDataSpace(Device dev, void* ptr) {
    DeviceAPI::Get(dev)->FreeDataSpace(dev, ptr);
  }
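
  /*!
   * \brief Free every cached buffer back to its device and reset the usage
   *  counter. Buffers still held by callers are unaffected.
   */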
  virtual void ReleaseAll() {
    std::lock_guard<std::recursive_mutex> lock(mu_);
    for (auto const& it : memory_pool_) {
      auto const& pool = it.second;
      for (auto const& buf : pool) {
        DeviceFreeDataSpace(buf.device, buf.data);
      }
    }
    memory_pool_.clear();
    used_memory_ = 0;
    VLOG(1) << "release all buffers";
  }

 protected:
  size_t page_size_;
  std::atomic<size_t> used_memory_;
  std::unordered_map<size_t, std::vector<Buffer>> memory_pool_;
  std::recursive_mutex mu_;
};
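
// A minimal usage sketch (assumes a CPU device; in practice the allocator is
// usually obtained through MemoryManager::GetOrCreateAllocator rather than
// constructed directly):
//
//   Device dev{kDLCPU, 0};
//   PooledAllocator alloc;
//   DLDataType f32{kDLFloat, 32, 1};
//   Buffer a = alloc.Alloc(dev, 1000, 64, f32);  // rounds up to one 4096 B page
//   alloc.Free(a);                               // cached in the pool, not released
//   Buffer b = alloc.Alloc(dev, 3000, 64, f32);  // same rounded size: reuses a's memory
//   alloc.Free(b);
//   alloc.Clear();                               // now the device memory is freed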

}  // namespace memory
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_MEMORY_POOLED_ALLOCATOR_H_