[SYSTEMDS-2739] Adjust computeTime for CostNSize with ref counts

This patch improves the CostNsize lineage cache eviction policy
by adjusting the compute time of a cache entry with reference
counts (#hits, #misses). This patch also introduces a non-recursive
equals method for LineageItem.
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 71d55e3..71f01bd 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -225,7 +225,7 @@
 	public static boolean probe(LineageItem key) {
 		//TODO problematic as after probe the matrix might be kicked out of cache
 		boolean p = _cache.containsKey(key);  // in cache or in disk
-		if (!p && DMLScript.STATISTICS && LineageCacheEviction._removelist.contains(key))
+		if (!p && DMLScript.STATISTICS && LineageCacheEviction._removelist.containsKey(key))
 			// The sought entry was in cache but removed later 
 			LineageCacheStatistics.incrementDelHits();
 		return p;
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 20c9b67..3b8ee07 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -40,7 +40,7 @@
 		"uamean", "max", "min", "ifelse", "-", "sqrt", ">", "uak+", "<=",
 		"^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand", "replace",
 		"^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax", "qsort", 
-		"qpick", "transformapply"
+		"qpick", "transformapply", "uarmax", "n+"
 		//TODO: Reuse everything. 
 	};
 	private static String[] REUSE_OPCODES  = new String[] {};
@@ -286,6 +286,10 @@
 		// Check the LRU component of weights array.
 		return (WEIGHTS[1] > 0);
 	}
+	
+	public static boolean isCostNsize() {
+		return (WEIGHTS[0] > 0); // true iff WEIGHTS[0] is positive; presumably the cost&size weight - TODO confirm
+	}
 
 	public static boolean isDagHeightBased() {
 		// Check the DAGHEIGHT component of weights array.
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index 00d385e..a82c8a5 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysds.runtime.lineage;
 
+import java.util.Map;
+
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
@@ -138,6 +140,30 @@
 		recomputeScore();
 	}
 	
+	protected synchronized void computeScore(Map<LineageItem, Integer> removeList) {
+		// Refresh the timestamp, then scale score by (1 + count recorded for this key in removeList).
+		setTimestamp();
+		Integer missCount = removeList.get(_key); // single lookup; null when the key has no recorded count
+		if (missCount != null) {
+			//FIXME: scale _computeTime instead of score (that now overflows); it would also cover spilling.
+			score = score * (1 + missCount);
+		}
+		if (LineageCache.DEBUG && _computeTime < 0) // diagnostic only, so keep it behind the debug flag
+			System.out.println("after recache: "+_computeTime+" miss count: "+missCount);
+	}
+	
+	protected synchronized void updateComputeTime() {
+		// Doubles the score of this entry. _computeTime itself is not modified yet; the
+		// check below only reports (in debug mode) that doubling it would overflow a long.
+		if (LineageCache.DEBUG && (Long.MAX_VALUE - _computeTime) < _computeTime)
+			System.out.println("Overflow for: "+_key.getOpcode());
+		//FIXME: increase computetime instead of score (that now leads to overflow).
+		// updating computingtime seamlessly takes care of spilling 
+		//_computeTime += _computeTime;
+		//recomputeScore();
+		score *= 2;
+	}
+	
 	protected synchronized long getTimestamp() {
 		return _timestamp;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
index 7025818..6fc3d38 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
@@ -20,9 +20,8 @@
 package org.apache.sysds.runtime.lineage;
 
 import java.io.IOException;
-import java.util.HashSet;
+import java.util.HashMap;
 import java.util.Map;
-import java.util.Set;
 import java.util.TreeSet;
 
 import org.apache.sysds.api.DMLScript;
@@ -37,7 +36,7 @@
 	private static long _cachesize = 0;
 	private static long CACHE_LIMIT; //limit in bytes
 	private static long _startTimestamp = 0;
-	protected static final Set<LineageItem> _removelist = new HashSet<>();
+	protected static final Map<LineageItem, Integer> _removelist = new HashMap<>();
 	private static String _outdir = null;
 	private static TreeSet<LineageCacheEntry> weightedQueue = new TreeSet<>(LineageCacheConfig.LineageCacheComparator);
 	
@@ -71,19 +70,27 @@
 			// Don't add the memory pinned entries in weighted queue. 
 			// The eviction queue should contain only entries that can
 			// be removed or spilled to disk.
-			entry.setTimestamp();
+			//entry.setTimestamp();
+			entry.computeScore(_removelist); 
+			// Adjust score according to cache miss counts.
 			weightedQueue.add(entry);
 		}
 	}
 	
 	protected static void getEntry(LineageCacheEntry entry) {
 		// Reset the timestamp to maintain the LRU component of the scoring function
-		if (!LineageCacheConfig.isTimeBased()) 
-			return;
-		
-		if (weightedQueue.remove(entry)) {
-			entry.setTimestamp();
-			weightedQueue.add(entry);
+		if (LineageCacheConfig.isTimeBased()) { 
+			if (weightedQueue.remove(entry)) {
+				entry.setTimestamp();
+				weightedQueue.add(entry);
+			}
+		}
+		// Raise the sought entry's score; it is removed and re-added so the TreeSet re-sorts it (presumably by score).
+		if (LineageCacheConfig.isCostNsize()) {
+			if (weightedQueue.remove(entry)) {
+				entry.updateComputeTime();
+				weightedQueue.add(entry);
+			}
 		}
 	}
 
@@ -91,8 +98,13 @@
 		if (cache.remove(e._key) != null)
 			_cachesize -= e.getSize();
 
+		// Increase priority if same entry is removed multiple times
+		if (_removelist.containsKey(e._key))
+			_removelist.replace(e._key, _removelist.get(e._key)+1);
+		else
+			_removelist.put(e._key, 1);
+
 		if (DMLScript.STATISTICS) {
-			_removelist.add(e._key);
 			LineageCacheStatistics.incrementMemDeletes();
 		}
 		// NOTE: The caller of this method maintains the eviction queue.
@@ -207,10 +219,11 @@
 				removeOrSpillEntry(cache, e, false);
 				continue;
 			}
-
+			
 			// Estimate time to write to FS + read from FS.
 			double spilltime = getDiskSpillEstimate(e) * 1000; // in milliseconds
 			double exectime = ((double) e._computeTime) / 1000000; // in milliseconds
+			//FIXME: this computeTime is not adjusted according to hit/miss counts
 
 			if (LineageCache.DEBUG) {
 				System.out.print("LI = " + e._key.getOpcode());
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
index e34979d..f0fcad4 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
@@ -178,7 +178,7 @@
 			return false;
 		
 		resetVisitStatusNR();
-		boolean ret = equalsLI((LineageItem) o);
+		boolean ret = equalsLINR((LineageItem) o);
 		resetVisitStatusNR();
 		return ret;
 	}
@@ -198,6 +198,33 @@
 		return ret;
 	}
 	
+	private boolean equalsLINR(LineageItem that) {
+		// Iterative pairwise comparison of both DAGs using two parallel stacks.
+		Stack<LineageItem> s1 = new Stack<>();
+		Stack<LineageItem> s2 = new Stack<>();
+		s1.push(this);
+		s2.push(that);
+		while (!s1.empty() && !s2.empty()) {
+			LineageItem li1 = s1.pop();
+			LineageItem li2 = s2.pop();
+			// A settled pair must not end the whole comparison: skip only this pair.
+			if (li1.isVisited() || li1 == li2)
+				continue;
+
+			boolean ret = li1._opcode.equals(li2._opcode);
+			ret &= li1._data.equals(li2._data);
+			ret &= (li1.hashCode() == li2.hashCode());
+			if (!ret) return false;
+			if (li1._inputs != null && li1._inputs.length == li2._inputs.length)
+				for (int i=0; i<li1._inputs.length; i++) {
+					s1.push(li1.getInputs()[i]);
+					s2.push(li2.getInputs()[i]);
+				}
+			li1.setVisited();
+		}
+		return true; // no compared pair differed
+	}
+	
 	@Override
 	public int hashCode() {
 		if (_hash == 0) {