// blob: 480c2c3a9f8ddca4aaa67aad3bfd33c87a37b3aa [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISABLE_WARNINGS
#include "singa/core/memory.h"

#include <iostream>
#include <mutex>
#include <utility>
#include <vector>

#include "singa/proto/core.pb.h"
#include "singa/utils/logging.h"
#ifdef USE_CUDA
namespace singa {
std::pair<size_t, size_t> CnMemPool::GetMemUsage() {
size_t free, total;
auto status = cnmemMemGetInfo(&free, &total, NULL);
CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
<< cnmemGetErrorString(status);
return std::make_pair(free, total);
}
std::pair<size_t, size_t> CnMemPool::GetMemUsage(int id) {
CHECK_EQ(cudaSetDevice(id), cudaError_t::cudaSuccess);
size_t free, total;
auto status = cnmemMemGetInfo(&free, &total, NULL);
CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
<< cnmemGetErrorString(status);
return std::make_pair(free, total);
}
/// Builds the pool configuration from explicit values.
///
/// @param numDevices device ids 0 .. numDevices-1 are added to the
///        configuration, in increasing order.
/// @param init_size  stored as the configuration's init_size (Init() scales
///        this value by 2^20 when sizing each device's pool).
/// @param max_size   stored as the configuration's max_size.
///
/// Only conf_ is filled in here; cnmem itself is not touched until Init().
CnMemPool::CnMemPool(int numDevices, size_t init_size, size_t max_size) {
  int device_id = 0;
  while (device_id < numDevices) {
    conf_.add_device(device_id);
    ++device_id;
  }
  conf_.set_init_size(init_size);
  conf_.set_max_size(max_size);
}
/// Builds the pool from an existing configuration, which is copied into conf_.
/// cnmem itself is not touched until Init().
CnMemPool::CnMemPool(const MemPoolConf &conf) : conf_(conf) {}
void CnMemPool::Init() {
mtx_.lock();
if (!initialized_) {
const size_t kNBytesPerMB = (1u << 20);
CHECK_GE(conf_.device_size(), 1);
cnmemDevice_t *settingPtr = new cnmemDevice_t[conf_.device_size()];
CHECK_GT(conf_.init_size(), 0u);
int i = 0;
for (auto device : conf_.device()) {
settingPtr[i].device = device;
settingPtr[i].size = conf_.init_size() * kNBytesPerMB;
settingPtr[i].numStreams = 0;
settingPtr[i].streams = NULL;
settingPtr[i].streamSizes = 0;
i++;
}
auto status = cnmemInit(conf_.device_size(), settingPtr, conf_.flag());
CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
<< " " << cnmemGetErrorString(status);
delete[] settingPtr;
initialized_ = true;
}
mtx_.unlock();
}
/// Finalizes the cnmem manager if initialized_ is set, then clears the flag.
///
/// Runs while holding mtx_. A cnmemFinalize status other than
/// CNMEM_STATUS_SUCCESS fails the CHECK and reports cnmem's error string.
CnMemPool::~CnMemPool() {
  // Scoped lock: released when the destructor body ends, before the members
  // (including mtx_ itself) are destroyed.
  std::lock_guard<decltype(mtx_)> guard(mtx_);
  if (!initialized_) return;

  const cnmemStatus_t status = cnmemFinalize();
  CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
      << " " << cnmemGetErrorString(status);
  initialized_ = false;
}
void CnMemPool::Malloc(void **ptr, const size_t size) {
if (!initialized_) Init();
cnmemStatus_t status = cnmemMalloc(ptr, size, NULL);
CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
<< " " << cnmemGetErrorString(status);
}
void CnMemPool::Free(void *ptr) {
CHECK(initialized_)
<< "Cannot free the memory as the pool is not initialzied";
cnmemStatus_t status = cnmemFree(ptr, NULL);
CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
<< " " << cnmemGetErrorString(status);
}
// ===========================================================================
void CudaMemPool::Malloc(void **ptr, const size_t size) {
cudaError_t status = cudaMalloc(ptr, size);
CHECK_EQ(status, cudaError_t::cudaSuccess);
}
void CudaMemPool::Free(void *ptr) {
cudaError_t status = cudaFree(ptr);
CHECK_EQ(status, cudaError_t::cudaSuccess);
}
} // namespace singa
#endif
#endif