[SYSTEMDS-2739] Adjust computeTime for CostNSize with ref counts
This patch improves the CostNsize lineage cache eviction policy
by adjusting the compute time of an cache entry with reference
counts (#hits, #misses). This patch also introduces a non-recursive
equal method for LineageItem.
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 71d55e3..71f01bd 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -225,7 +225,7 @@
public static boolean probe(LineageItem key) {
//TODO problematic as after probe the matrix might be kicked out of cache
boolean p = _cache.containsKey(key); // in cache or in disk
- if (!p && DMLScript.STATISTICS && LineageCacheEviction._removelist.contains(key))
+ if (!p && DMLScript.STATISTICS && LineageCacheEviction._removelist.containsKey(key))
// The sought entry was in cache but removed later
LineageCacheStatistics.incrementDelHits();
return p;
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 20c9b67..3b8ee07 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -40,7 +40,7 @@
"uamean", "max", "min", "ifelse", "-", "sqrt", ">", "uak+", "<=",
"^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand", "replace",
"^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax", "qsort",
- "qpick", "transformapply"
+ "qpick", "transformapply", "uarmax", "n+"
//TODO: Reuse everything.
};
private static String[] REUSE_OPCODES = new String[] {};
@@ -286,6 +286,10 @@
// Check the LRU component of weights array.
return (WEIGHTS[1] > 0);
}
+
+ public static boolean isCostNsize() {
+ return (WEIGHTS[0] > 0);
+ }
public static boolean isDagHeightBased() {
// Check the DAGHEIGHT component of weights array.
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index 00d385e..a82c8a5 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.lineage;
+import java.util.Map;
+
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
@@ -138,6 +140,30 @@
recomputeScore();
}
+ protected synchronized void computeScore(Map<LineageItem, Integer> removeList) {
+ setTimestamp();
+ if (removeList.containsKey(_key)) {
+ //FIXME: increase computetime instead of score (that now leads to overflow).
+ // updating computingtime seamlessly takes care of spilling
+ //_computeTime = _computeTime * (1 + removeList.get(_key));
+ score = score * (1 + removeList.get(_key));
+ }
+ if (_computeTime < 0)
+ System.out.println("after recache: "+_computeTime+" miss count: "+removeList.get(_key));
+ }
+
+ protected synchronized void updateComputeTime() {
+ if ((Long.MAX_VALUE - _computeTime) < _computeTime) {
+ System.out.println("Overflow for: "+_key.getOpcode());
+ }
+ //FIXME: increase computetime instead of score (that now leads to overflow).
+ // updating computingtime seamlessly takes care of spilling
+ //_computeTime = _computeTime * (1 + removeList.get(_key));
+ //_computeTime += _computeTime;
+ //recomputeScore();
+ score *= 2;
+ }
+
protected synchronized long getTimestamp() {
return _timestamp;
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
index 7025818..6fc3d38 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
@@ -20,9 +20,8 @@
package org.apache.sysds.runtime.lineage;
import java.io.IOException;
-import java.util.HashSet;
+import java.util.HashMap;
import java.util.Map;
-import java.util.Set;
import java.util.TreeSet;
import org.apache.sysds.api.DMLScript;
@@ -37,7 +36,7 @@
private static long _cachesize = 0;
private static long CACHE_LIMIT; //limit in bytes
private static long _startTimestamp = 0;
- protected static final Set<LineageItem> _removelist = new HashSet<>();
+ protected static final Map<LineageItem, Integer> _removelist = new HashMap<>();
private static String _outdir = null;
private static TreeSet<LineageCacheEntry> weightedQueue = new TreeSet<>(LineageCacheConfig.LineageCacheComparator);
@@ -71,19 +70,27 @@
// Don't add the memory pinned entries in weighted queue.
// The eviction queue should contain only entries that can
// be removed or spilled to disk.
- entry.setTimestamp();
+ //entry.setTimestamp();
+ entry.computeScore(_removelist);
+ // Adjust score according to cache miss counts.
weightedQueue.add(entry);
}
}
protected static void getEntry(LineageCacheEntry entry) {
// Reset the timestamp to maintain the LRU component of the scoring function
- if (!LineageCacheConfig.isTimeBased())
- return;
-
- if (weightedQueue.remove(entry)) {
- entry.setTimestamp();
- weightedQueue.add(entry);
+ if (LineageCacheConfig.isTimeBased()) {
+ if (weightedQueue.remove(entry)) {
+ entry.setTimestamp();
+ weightedQueue.add(entry);
+ }
+ }
+ // Increase computation time of the sought entry.
+ if (LineageCacheConfig.isCostNsize()) {
+ if (weightedQueue.remove(entry)) {
+ entry.updateComputeTime();
+ weightedQueue.add(entry);
+ }
}
}
@@ -91,8 +98,13 @@
if (cache.remove(e._key) != null)
_cachesize -= e.getSize();
+ // Increase priority if same entry is removed multiple times
+ if (_removelist.containsKey(e._key))
+ _removelist.replace(e._key, _removelist.get(e._key)+1);
+ else
+ _removelist.put(e._key, 1);
+
if (DMLScript.STATISTICS) {
- _removelist.add(e._key);
LineageCacheStatistics.incrementMemDeletes();
}
// NOTE: The caller of this method maintains the eviction queue.
@@ -207,10 +219,11 @@
removeOrSpillEntry(cache, e, false);
continue;
}
-
+
// Estimate time to write to FS + read from FS.
double spilltime = getDiskSpillEstimate(e) * 1000; // in milliseconds
double exectime = ((double) e._computeTime) / 1000000; // in milliseconds
+ //FIXME: this comuteTime is not adjusted according to hit/miss counts
if (LineageCache.DEBUG) {
System.out.print("LI = " + e._key.getOpcode());
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
index e34979d..f0fcad4 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
@@ -178,7 +178,7 @@
return false;
resetVisitStatusNR();
- boolean ret = equalsLI((LineageItem) o);
+ boolean ret = equalsLINR((LineageItem) o);
resetVisitStatusNR();
return ret;
}
@@ -198,6 +198,33 @@
return ret;
}
+ private boolean equalsLINR(LineageItem that) {
+ Stack<LineageItem> s1 = new Stack<>();
+ Stack<LineageItem> s2 = new Stack<>();
+ s1.push(this);
+ s2.push(that);
+ boolean ret = false;
+ while (!s1.empty() && !s2.empty()) {
+ LineageItem li1 = s1.pop();
+ LineageItem li2 = s2.pop();
+ if (li1.isVisited() || li1 == li2)
+ return true;
+
+ ret = li1._opcode.equals(li2._opcode);
+ ret &= li1._data.equals(li2._data);
+ ret &= (li1.hashCode() == li2.hashCode());
+ if (!ret) break;
+ if (ret && li1._inputs != null && li1._inputs.length == li2._inputs.length)
+ for (int i=0; i<li1._inputs.length; i++) {
+ s1.push(li1.getInputs()[i]);
+ s2.push(li2.getInputs()[i]);
+ }
+ li1.setVisited();
+ }
+
+ return ret;
+ }
+
@Override
public int hashCode() {
if (_hash == 0) {