blob: 9369026862fb2e5d47a137c611e11035627f3d1b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.gpu.context;
import java.util.HashMap;
import java.util.HashSet;
import java.util.stream.Collectors;
import java.util.Optional;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.utils.GPUStatistics;
import jcuda.Pointer;
public class GPULazyCudaFreeMemoryManager {
protected static final Log LOG = LogFactory.getLog(GPULazyCudaFreeMemoryManager.class.getName());
GPUMemoryManager gpuManager;
public GPULazyCudaFreeMemoryManager(GPUMemoryManager gpuManager) {
this.gpuManager = gpuManager;
}
/**
* Map of free blocks allocate on GPU. maps size_of_block -> pointer on GPU
*/
private HashMap<Long, Set<Pointer>> rmvarGPUPointers = new HashMap<>();
/**
* Get any pointer of the given size from rmvar-ed pointers (applicable if eager cudaFree is set to false)
*
* @param opcode opcode
* @param size size in bytes
* @return pointer
*/
public Pointer getRmvarPointer(String opcode, long size) {
if (rmvarGPUPointers.containsKey(size)) {
if(LOG.isTraceEnabled())
LOG.trace("Getting rmvar-ed pointers for size:" + size);
Pointer A = remove(rmvarGPUPointers, size); // remove from rmvarGPUPointers as you are not calling cudaFree
if(DMLScript.STATISTICS)
GPUStatistics.cudaAllocReuseCount.increment();
return A;
}
else {
return null;
}
}
public Set<Pointer> getAllPointers() {
return rmvarGPUPointers.values().stream().flatMap(ptrs -> ptrs.stream()).collect(Collectors.toSet());
}
public void clearAll() {
Set<Pointer> toFree = new HashSet<>();
for(Set<Pointer> ptrs : rmvarGPUPointers.values()) {
toFree.addAll(ptrs);
}
rmvarGPUPointers.clear();
for(Pointer ptr : toFree) {
gpuManager.guardedCudaFree(ptr);
}
}
public Pointer getRmvarPointerMinSize(String opcode, long minSize) throws DMLRuntimeException {
Optional<Long> toClear = rmvarGPUPointers.entrySet().stream().filter(e -> e.getValue().size() > 0).map(e -> e.getKey())
.filter(size -> size >= minSize).min((s1, s2) -> s1 < s2 ? -1 : 1);
if(toClear.isPresent()) {
Pointer A = remove(rmvarGPUPointers, toClear.get()); // remove from rmvarGPUPointers as you are not calling cudaFree
if(DMLScript.STATISTICS)
GPUStatistics.cudaAllocReuseCount.increment();
return A;
}
return null;
}
/**
* Remove any pointer in the given hashmap
*
* @param hm hashmap of size, pointers
* @param size size in bytes
* @return the pointer that was removed
*/
private static Pointer remove(HashMap<Long, Set<Pointer>> hm, long size) {
Pointer A = hm.get(size).iterator().next();
remove(hm, size, A);
return A;
}
/**
* Remove a specific pointer in the given hashmap
*
* @param hm hashmap of size, pointers
* @param size size in bytes
* @param ptr pointer to be removed
*/
private static void remove(HashMap<Long, Set<Pointer>> hm, long size, Pointer ptr) {
hm.get(size).remove(ptr);
if (hm.get(size).isEmpty())
hm.remove(size);
}
/**
* Return the total memory in bytes used by this memory manager
* @return number of bytes
*/
public long getTotalMemoryAllocated() {
long rmvarMemoryAllocated = 0;
for(long numBytes : rmvarGPUPointers.keySet()) {
rmvarMemoryAllocated += numBytes;
}
return rmvarMemoryAllocated;
}
/**
* Get total number of rmvared pointers
*
* @return number of pointers
*/
public int getNumPointers() {
return rmvarGPUPointers.size();
}
/**
* Add a pointer to the rmvar-ed list
* @param size size of the pointer
* @param toFree pointer
*/
public void add(long size, Pointer toFree) {
Set<Pointer> freeList = rmvarGPUPointers.get(size);
if (freeList == null) {
freeList = new HashSet<>();
rmvarGPUPointers.put(size, freeList);
}
if (freeList.contains(toFree))
throw new RuntimeException("GPU : Internal state corrupted, double free");
freeList.add(toFree);
}
/**
* Remove a specific pointer if present in the internal hashmap
*
* @param size size in bytes
* @param ptr pointer to be removed
*/
public void removeIfPresent(long size, Pointer ptr) {
if(rmvarGPUPointers.containsKey(size) && rmvarGPUPointers.get(size).contains(ptr)) {
rmvarGPUPointers.get(size).remove(ptr);
if (rmvarGPUPointers.get(size).isEmpty())
rmvarGPUPointers.remove(size);
}
}
}