[SYSTEMDS-2634] Reduced number of RPC calls in federated backend

This patch improves the performance of the federated runtime backend by
merging the execution and cleanup RPC request batches into a single
batch of requests. Since every batch returns only a single response,
the worker now selects which one to return: the first error if any
request failed, otherwise the GET_VAR result, otherwise the response
of the last request. Overall, this reduced the number of RPC calls by
almost 2x and removed unnecessary synchronization barriers.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 720534a..4d0d5d9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -680,7 +680,7 @@
 		
 		//clear federated matrix
 		if( _fedMapping != null )
-			_fedMapping.cleanup(tid, _fedMapping.getID());
+			_fedMapping.execCleanup(tid, _fedMapping.getID());
 		
 		// change object state EMPTY
 		setDirty(false);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index f4af303..0dcb846 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -91,10 +91,24 @@
 			PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
 			PrivacyMonitor.clearCheckedConstraints();
 	
-			response = executeCommand(request);
-			conditionalAddCheckedConstraints(request, response);
-			if (!response.isSuccessful()){
-				log.error("Command " + request.getType() + " failed: " + response.getErrorMessage() + "full command: \n" + request.toString());
+			//execute command and handle privacy constraints
+			FederatedResponse tmp = executeCommand(request);
+			conditionalAddCheckedConstraints(request, tmp);
+			
+			//select the response for the entire batch of requests
+			if (!tmp.isSuccessful()) {
+				log.error("Command " + request.getType() + " failed: " 
+					+ tmp.getErrorMessage() + "full command: \n" + request.toString());
+				response = (response == null || response.isSuccessful()) 
+					? tmp : response; //return first error
+			}
+			else if( request.getType() == RequestType.GET_VAR ) {
+				if( response != null && response.isSuccessful() )
+					log.error("Multiple GET_VAR are not supported in single batch of requests.");
+				response = tmp; //return last get result
+			}
+			else if( response == null && i == requests.length-1 ) {
+				response = tmp; //return last
 			}
 		}
 		ctx.writeAndFlush(response).addListener(new CloseListener());
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index b272bf9..72d1196 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -170,7 +170,14 @@
 		return readResponses;
 	}
 	
-	public void cleanup(long tid, long... id) {
+	public FederatedRequest cleanup(long tid, long... id) { // only builds the remove-variables request for the given ids; it is returned, not sent
+		FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
+			VariableCPInstruction.prepareRemoveInstruction(id).toString());
+		request.setTID(tid);
+		return request;
+	}
+	
+	public void execCleanup(long tid, long... id) {
 		FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
 			VariableCPInstruction.prepareRemoveInstruction(id).toString());
 		request.setTID(tid);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index faae560..7df7c51 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -32,7 +32,6 @@
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysds.runtime.functionobjects.KahanFunction;
-import org.apache.sysds.runtime.functionobjects.KahanPlus;
 import org.apache.sysds.runtime.functionobjects.Mean;
 import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 34caec2..c28a163 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -68,10 +68,10 @@
 				new CPOperand[]{input1, input2},
 				new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
 			FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+			FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
 			//execute federated operations and aggregate
-			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2);
+			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
 			MatrixBlock ret = FederationUtils.aggAdd(tmp);
-			mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
 			ec.setMatrixOutput(output.getName(), ret);
 		}
 		else if(mo1.isFederated(FType.ROW)) { // MV + MM
@@ -81,16 +81,16 @@
 				new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
 			if( mo2.getNumColumns() == 1 ) { //MV
 				FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+				FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
 				//execute federated operations and aggregate
-				Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+				Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
 				MatrixBlock ret = FederationUtils.rbind(tmp);
-				mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
 				ec.setMatrixOutput(output.getName(), ret);
 			}
 			else { //MM
 				//execute federated operations and aggregate
-				mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
-				mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+				FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+				mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
 				MatrixObject out = ec.getMatrixObject(output);
 				out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
 				out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), mo2.getNumColumns()));
@@ -104,10 +104,10 @@
 			FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
 				new CPOperand[]{input1, input2}, new long[]{fr1[0].getID(), mo2.getFedMapping().getID()});
 			FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+			FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
 			//execute federated operations and aggregate
-			Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+			Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
 			MatrixBlock ret = FederationUtils.aggAdd(tmp);
-			mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
 			ec.setMatrixOutput(output.getName(), ret);
 		}
 		else { //other combinations
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index e87bf57..60fe40b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -55,19 +55,19 @@
 	public void processInstruction(ExecutionContext ec) {
 		AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
 		MatrixObject in = ec.getMatrixObject(input1);
+		FederationMap map = in.getFedMapping();
 		
 		//create federated commands for aggregation
 		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
 			new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()});
 		FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+		FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
 		
 		//execute federated commands and cleanups
-		FederationMap map = in.getFedMapping();
-		Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2);
+		Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
 		if( output.isScalar() )
 			ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp));
 		else
 			ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, map));
-		map.cleanup(getTID(), fr1.getID());
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 63c2d71..bceb6ae 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -42,39 +42,36 @@
 		FederatedRequest fr2 = null;
 
 		if( mo2.isFederated() ) {
-			if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)){
+			if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
 				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
 					new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
 				mo1.getFedMapping().execute(getTID(), true, fr2);
-				
-			} else{
+			}
+			else {
 				throw new DMLRuntimeException("Matrix-matrix binary operations "
 					+ " with a federated right input are not supported yet.");
 			}
-
-		} 
+		}
 		else {
 			//matrix-matrix binary oFederatedRequest fr2 = null;perations -> lhs fed input -> fed output
-			
 			if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { //MV row vector
 				FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
 				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
 					new long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
+				FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
 				//execute federated instruction and cleanup intermediates
-				mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
-				mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+				mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
 			}
 			else { //MM or MV col vector
 				FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
 				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
 					new long[]{mo1.getFedMapping().getID(), fr1.getID()});
+				FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
 				//execute federated instruction and cleanup intermediates
-				mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
-				mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+				mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
 			}
 		}
 		
-		
 		//derive new fed mapping for output
 		MatrixObject out = ec.getMatrixObject(output);
 		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index 75bfe33..b6ea1fb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -39,17 +39,20 @@
 		CPOperand scalar = input2.isScalar() ? input2 : input1;
 		MatrixObject mo = ec.getMatrixObject(matrix);
 		
-		//execute federated matrix-scalar operation and cleanups
+		//prepare federated request matrix-scalar
 		FederatedRequest fr1 = !scalar.isLiteral() ?
 			mo.getFedMapping().broadcast(ec.getScalarInput(scalar)) : null;
 		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
 			new CPOperand[]{matrix, (fr1 != null)?scalar:null},
 			new long[]{mo.getFedMapping().getID(), (fr1 != null)?fr1.getID():-1});
 		
-		mo.getFedMapping().execute(getTID(), true, (fr1!=null) ?
-			new FederatedRequest[]{fr1, fr2}: new FederatedRequest[]{fr2});
-		if( fr1 != null )
-			mo.getFedMapping().cleanup(getTID(), fr1.getID());
+		//execute federated matrix-scalar operation and cleanups
+		if( fr1 != null ) {
+			FederatedRequest fr3 = mo.getFedMapping().cleanup(getTID(), fr1.getID());
+			mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+		}
+		else
+			mo.getFedMapping().execute(getTID(), true, fr2);
 		
 		//derive new fed mapping for output
 		MatrixObject out = ec.getMatrixObject(output);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
index 2dee64b..99a305b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -86,11 +86,12 @@
 			FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
 				new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
 			FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+			FederatedRequest fr4 = mo1.getFedMapping()
+				.cleanup(getTID(), fr1.getID(), fr2.getID());
 			
 			//execute federated operations and aggregate
-			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
 			MatrixBlock ret = FederationUtils.aggAdd(tmp);
-			mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
 			ec.setMatrixOutput(output.getName(), ret);
 		}
 		else { //XtwXv | XtXvy
@@ -101,11 +102,12 @@
 				new CPOperand[]{input1, input2, input3},
 				new long[]{mo1.getFedMapping().getID(), fr1.getID(), fr0[0].getID()});
 			FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+			FederatedRequest fr4 = mo1.getFedMapping()
+				.cleanup(getTID(), fr0[0].getID(), fr1.getID(), fr2.getID());
 			
 			//execute federated operations and aggregate
-			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3);
+			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3, fr4);
 			MatrixBlock ret = FederationUtils.aggAdd(tmp);
-			mo1.getFedMapping().cleanup(getTID(), fr0[0].getID(), fr1.getID(), fr2.getID());
 			ec.setMatrixOutput(output.getName(), ret);
 		}
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index 292bced..fbe88d6 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -67,11 +67,11 @@
 			FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
 				new CPOperand[]{input1}, new long[]{mo1.getFedMapping().getID()});
 			FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+			FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
 			
 			//execute federated operations and aggregate
-			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2);
+			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
 			MatrixBlock ret = FederationUtils.aggAdd(tmp);
-			mo1.getFedMapping().cleanup(getTID(), fr1.getID());
 			ec.setMatrixOutput(output.getName(), ret);
 		}
 		else { //other combinations