[SYSTEMDS-2624] Rework acquireRead and cleanup of federated matrices

The integration of federated matrices and frames into our caching and
buffer pool framework was suboptimal because (1) repeated pinning of a
federated data into memory of the coordinator caused repeated
consolidation of the federated data, and (2) missing cleanup of
federated data intermediates. This rework fixes both issues and thus,
reduces unnecessary memory pressure at the workers and unnecessary data
transfer at the coordinator.

For the cleanup of federated data, we added special guards to ensure
only federated intermediates (and files thereof) are deleted but NOT the
original federated input files.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index a86acae..b9ef965 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -242,8 +242,9 @@
 			
 			// try to reuse instruction result from lineage cache
 			if( !LineageCache.reuse(tmp, ec) ) {
-				// process actual instruction
 				long et0 = !ReuseCacheType.isNone() ? System.nanoTime() : 0;
+				
+				// process actual instruction
 				tmp.processInstruction(ec);
 				
 				// cache result
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index c287aeb..c809a84 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -466,10 +466,16 @@
 		//(probe data for cache_nowrite / jvm_reuse)
 		if( _data==null && isEmpty(true) ) {
 			try {
-				if( DMLScript.STATISTICS )
-					CacheStatistics.incrementHDFSHits();
-				
-				if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() ) {
+				if( isFederated() ) {
+					_data = readBlobFromFederated( _fedMapping );
+					
+					//mark for initial local write despite read operation
+					_requiresLocalWrite = CACHING_WRITE_CACHE_ON_READ;
+				}
+				else if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() ) {
+					if( DMLScript.STATISTICS )
+						CacheStatistics.incrementHDFSHits();
+					
 					//check filename
 					if( _hdfsFileName == null )
 						throw new DMLRuntimeException("Cannot read matrix for empty filename.");
@@ -661,6 +667,10 @@
 					gObj.clearData(null, DMLScript.EAGER_CUDA_FREE);
 		}
 		
+		//clear federated matrix
+		if( _fedMapping != null )
+			_fedMapping.cleanup(_fedMapping.getID());
+		
 		// change object state EMPTY
 		setDirty(false);
 		setEmpty();
@@ -923,19 +933,30 @@
 		return hashCode() + " " + debugNameEnding;
 	}
 
-	protected T readBlobFromHDFS(String fname) 
-		throws IOException 
-	{
+	//HDFS read
+	protected T readBlobFromHDFS(String fname) throws IOException {
 		MetaDataFormat iimd = (MetaDataFormat) _metaData;
 		DataCharacteristics dc = iimd.getDataCharacteristics();
 		return readBlobFromHDFS(fname, dc.getDims());
 	}
 
-	protected abstract T readBlobFromHDFS(String fname, long[] dims) throws IOException;
+	protected abstract T readBlobFromHDFS(String fname, long[] dims)
+		throws IOException;
 
+	//RDD read
 	protected abstract T readBlobFromRDD(RDDObject rdd, MutableBoolean status)
 		throws IOException;
 
+	// Federated read
+	protected T readBlobFromFederated(FederationMap fedMap) throws IOException {
+		MetaDataFormat iimd = (MetaDataFormat) _metaData;
+		DataCharacteristics dc = iimd.getDataCharacteristics();
+		return readBlobFromFederated(fedMap, dc.getDims());
+	}
+	
+	protected abstract T readBlobFromFederated(FederationMap fedMap, long[] dims)
+		throws IOException;
+
 	protected abstract void writeBlobToHDFS(String fname, String ofmt, int rep, FileFormatProperties fprop)
 		throws IOException;
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
index ef6e790..0418ce7 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
@@ -31,6 +31,7 @@
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysds.runtime.io.FileFormatProperties;
 import org.apache.sysds.runtime.io.FrameReaderFactory;
@@ -160,41 +161,6 @@
 	}
 	
 	@Override
-	public FrameBlock acquireRead() {
-		// forward call for non-federated objects
-		if( !isFederated() )
-			return super.acquireRead();
-		
-		FrameBlock result = new FrameBlock(_schema);
-		// provide long support?
-		result.ensureAllocatedColumns((int) _metaData.getDataCharacteristics().getRows());
-		List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = _fedMapping.requestFederatedData();
-		try {
-			for(Pair<FederatedRange, Future<FederatedResponse>> readResponse : readResponses) {
-				FederatedRange range = readResponse.getLeft();
-				FederatedResponse response = readResponse.getRight().get();
-				// add result
-				FrameBlock multRes = (FrameBlock) response.getData()[0];
-				for (int r = 0; r < multRes.getNumRows(); r++) {
-					for (int c = 0; c < multRes.getNumColumns(); c++) {
-						int destRow = range.getBeginDimsInt()[0] + r;
-						int destCol = range.getBeginDimsInt()[1] + c;
-						result.set(destRow, destCol, multRes.get(r, c));
-					}
-				}
-			}
-		}
-		catch(Exception e) {
-			throw new DMLRuntimeException("Federated Frame read failed.", e);
-		}
-		
-		//keep returned object for future use 
-		acquireModify(result);
-		
-		return result;
-	}
-	
-	@Override
 	protected FrameBlock readBlobFromCache(String fname) throws IOException {
 		return (FrameBlock)LazyWriteBuffer.readBlock(fname, false);
 	}
@@ -269,6 +235,36 @@
 		
 		return fb;
 	}
+	
+	@Override
+	protected FrameBlock readBlobFromFederated(FederationMap fedMap, long[] dims)
+		throws IOException
+	{
+		FrameBlock ret = new FrameBlock(_schema);
+		// provide long support?
+		ret.ensureAllocatedColumns((int) dims[0]);
+		List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = fedMap.requestFederatedData();
+		try {
+			for(Pair<FederatedRange, Future<FederatedResponse>> readResponse : readResponses) {
+				FederatedRange range = readResponse.getLeft();
+				FederatedResponse response = readResponse.getRight().get();
+				// add result
+				FrameBlock multRes = (FrameBlock) response.getData()[0];
+				for (int r = 0; r < multRes.getNumRows(); r++) {
+					for (int c = 0; c < multRes.getNumColumns(); c++) {
+						int destRow = range.getBeginDimsInt()[0] + r;
+						int destCol = range.getBeginDimsInt()[1] + c;
+						ret.set(destRow, destCol, multRes.get(r, c));
+					}
+				}
+			}
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("Federated Frame read failed.", e);
+		}
+		
+		return ret;
+	}
 
 	@Override
 	protected void writeBlobToHDFS(String fname, String ofmt, int rep, FileFormatProperties fprop)
@@ -290,5 +286,4 @@
 		//lazy evaluation of pending transformations.
 		SparkExecutionContext.writeFrameRDDtoHDFS(rdd, fname, iimd.getFileFormat());
 	}
-
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 6216ba5..d96b700 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -39,6 +39,7 @@
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysds.runtime.io.FileFormatProperties;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -393,40 +394,7 @@
 
 		return sb.toString();
 	}
-	
-	@Override
-	public MatrixBlock acquireRead() {
-		// forward call for non-federated objects
-		if( !isFederated() )
-			return super.acquireRead();
-		
-		long[] dims = getDataCharacteristics().getDims();
-		// TODO sparse optimization
-		MatrixBlock result = new MatrixBlock((int) dims[0], (int) dims[1], false);
-		List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = _fedMapping.requestFederatedData();
-		try {
-			for (Pair<FederatedRange, Future<FederatedResponse>> readResponse : readResponses) {
-				FederatedRange range = readResponse.getLeft();
-				FederatedResponse response = readResponse.getRight().get();
-				// add result
-				int[] beginDimsInt = range.getBeginDimsInt();
-				int[] endDimsInt = range.getEndDimsInt();
-				MatrixBlock multRes = (MatrixBlock) response.getData()[0];
-				result.copy(beginDimsInt[0], endDimsInt[0] - 1,
-					beginDimsInt[1], endDimsInt[1] - 1, multRes, false);
-				result.setNonZeros(result.getNonZeros() + multRes.getNonZeros());
-			}
-		}
-		catch (Exception e) {
-			throw new DMLRuntimeException("Federated matrix read failed.", e);
-		}
-		
-		//keep returned object for future use 
-		acquireModify(result);
-		
-		return result;
-	}
-	
+
 	// *********************************************
 	// ***                                       ***
 	// ***      LOW-LEVEL PROTECTED METHODS      ***
@@ -540,6 +508,33 @@
 		return mb;
 	}
 	
+	@Override
+	protected MatrixBlock readBlobFromFederated(FederationMap fedMap, long[] dims)
+		throws IOException
+	{
+		// TODO sparse optimization
+		MatrixBlock ret = new MatrixBlock((int) dims[0], (int) dims[1], false);
+		List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = fedMap.requestFederatedData();
+		try {
+			for (Pair<FederatedRange, Future<FederatedResponse>> readResponse : readResponses) {
+				FederatedRange range = readResponse.getLeft();
+				FederatedResponse response = readResponse.getRight().get();
+				// add result
+				int[] beginDimsInt = range.getBeginDimsInt();
+				int[] endDimsInt = range.getEndDimsInt();
+				MatrixBlock multRes = (MatrixBlock) response.getData()[0];
+				ret.copy(beginDimsInt[0], endDimsInt[0] - 1,
+					beginDimsInt[1], endDimsInt[1] - 1, multRes, false);
+				ret.setNonZeros(ret.getNonZeros() + multRes.getNonZeros());
+			}
+		}
+		catch (Exception e) {
+			throw new DMLRuntimeException("Federated matrix read failed.", e);
+		}
+		
+		return ret;
+	}
+	
 	/**
 	 * Writes in-memory matrix to HDFS in a specified format.
 	 */
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
index d222f43..32fb72b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
@@ -28,6 +28,7 @@
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.data.TensorBlock;
 import org.apache.sysds.runtime.data.TensorIndexes;
 import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
@@ -141,6 +142,13 @@
 		// TODO read from RDD
 		return SparkExecutionContext.toTensorBlock((JavaPairRDD<TensorIndexes, TensorBlock>) rdd.getRDD(), tc);
 	}
+	
+	@Override
+	protected TensorBlock readBlobFromFederated(FederationMap fedMap, long[] dims)
+		throws IOException
+	{
+		throw new DMLRuntimeException("Unsupported federated tensors");
+	}
 
 	@Override
 	protected void writeBlobToHDFS(String fname, String ofmt, int rep, FileFormatProperties fprop)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index fcb5db3..31a467f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -59,7 +59,9 @@
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 public class ExecutionContext {
@@ -71,6 +73,7 @@
 	//symbol table
 	protected LocalVariableMap _variables;
 	protected boolean _autoCreateVars;
+	protected Set<String> _guardedFiles = new HashSet<>();
 	
 	//lineage map, cache, prepared dedup blocks
 	protected Lineage _lineage;
@@ -131,6 +134,10 @@
 	public void setAutoCreateVars(boolean flag) {
 		_autoCreateVars = flag;
 	}
+	
+	public void addGuardedFilename(String fname) {
+		_guardedFiles.add(fname);
+	}
 
 	/**
 	 * Get the i-th GPUContext
@@ -751,7 +758,7 @@
 			//compute ref count only if matrix cleanup actually necessary
 			if ( mo.isCleanupEnabled() && !getVariables().hasReferences(mo) )  {
 				mo.clearData(); //clean cached data
-				if( fileExists ) {
+				if( fileExists && !_guardedFiles.contains(mo.getFileName()) ) {
 					HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName());
 					HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName()+".mtd");
 				}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index f14bbb0..47ca43c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -181,6 +181,7 @@
 		
 		//TODO spawn async load of data, otherwise on first access
 		_ec.setVariable(String.valueOf(id), cd);
+		_ec.addGuardedFilename(filename);
 		
 		if (dataType == Types.DataType.FRAME) {
 			FrameObject frameObject = (FrameObject) cd;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 00f3b04..0a5a2a2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -54,9 +54,10 @@
 		}
 		else if (inst instanceof BinaryCPInstruction) {
 			BinaryCPInstruction instruction = (BinaryCPInstruction) inst;
-			if( instruction.input1.isMatrix() && ec.getMatrixObject(instruction.input1).isFederated()
-				|| instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederated() ) {
-				return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+			if( (instruction.input1.isMatrix() && ec.getMatrixObject(instruction.input1).isFederated())
+				|| (instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederated()) ) {
+				if(!instruction.getOpcode().equals("append")) //TODO support rbind/cbind
+					return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
 			}
 		}
 		else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedRCBindTest.java
index 81c0b00..0f28a7b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedRCBindTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedRCBindTest.java
@@ -96,7 +96,8 @@
 		TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
 		loadTestConfiguration(config);
 		fullDMLScriptName = HOME + TEST_NAME + ".dml";
-		programArgs = new String[] {"-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
+		programArgs = new String[] {"-nvargs",
+			"in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
 			"cols=" + cols, "out_R=" + output("R"), "out_C=" + output("C")};
 
 		runTest(true, false, null, -1);