[SYSTEMDS-2624] Cleanup federated workers for repeated execution

This patch fixes the cleanup of federated workers to perform a full
cleanup of variables and execution context before and after every
execution. This change now enables keeping the federated workers as
standing executors and launching repeated coordinator jobs without any
conflicts of existing variables or unnecessary memory pressure and
evictions. Furthermore, this also adds related Kmeans tests with multiple
runs (and reset ID sequences to provoke conflicts) as well as PCA tests
with more than two workers.
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index d4743ed..635aef7 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -59,6 +59,7 @@
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedWorker;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
@@ -505,6 +506,9 @@
 		sb.append(DMLScript.getUUID());
 		String dirSuffix = sb.toString();
 		
+		//0) cleanup federated workers if necessary
+		FederatedData.clearFederatedWorkers();
+		
 		//1) cleanup scratch space (everything for current uuid) 
 		//(required otherwise export to hdfs would skip assumed unnecessary writes if same name)
 		HDFSTool.deleteFileIfExistOnHDFS( config.getTextValue(DMLConfig.SCRATCH_SPACE) + dirSuffix );
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index c6d7e6e..d5ef760 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -422,6 +422,10 @@
 		LOG.trace("PARFOR: ParForProgramBlock created with mode = "+_execMode+", optmode = "+_optMode+", numThreads = "+_numThreads);
 	}
 	
+	public static void resetWorkerIDs() {
+		_pwIDSeq.reset();
+	}
+	
 	public long getID() {
 		return _ID;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/ExecutionContextMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/ExecutionContextMap.java
index 1d06f46..352728d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/ExecutionContextMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/ExecutionContextMap.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.controlprogram.federated;
 
+import java.util.ArrayList;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
@@ -44,6 +45,18 @@
 			k -> deriveExecutionContext(_main));
 	}
 	
+	public void clear() {
+		//handle main symbol table (w/ tmp list for concurrent modification)
+		for( String varName : new ArrayList<>(_main.getVariables().keySet()) )
+			_main.cleanupDataObject(_main.removeVariable(varName));
+		
+		//handle parfor execution contexts
+		for( ExecutionContext ec : _parEc.values() )
+			for( String varName : ec.getVariables().keySet() )
+				_main.cleanupDataObject(ec.removeVariable(varName));
+		_parEc.clear();
+	}
+	
 	private static ExecutionContext createExecutionContext() {
 		ExecutionContext ec = ExecutionContextFactory.createContext();
 		ec.setAutoCreateVars(true); //w/o createvar inst
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 296e6f2..d161522 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -32,30 +32,37 @@
 import io.netty.handler.codec.serialization.ObjectDecoder;
 import io.netty.handler.codec.serialization.ObjectEncoder;
 import io.netty.util.concurrent.Promise;
+
 import org.apache.sysds.common.Types;
 import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 
 import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
 import java.util.concurrent.Future;
 
 
 public class FederatedData {
-	private Types.DataType _dataType;
-	private InetSocketAddress _address;
-	private String _filepath;
+	private static Set<InetSocketAddress> _allFedSites = new HashSet<>();
+	
+	private final Types.DataType _dataType;
+	private final InetSocketAddress _address;
+	private final String _filepath;
 	/**
 	 * The ID of default matrix/tensor on which operations get executed if no other ID is given.
 	 */
 	private long _varID = -1; // -1 is never valid since varIDs start at 0
-	private int _nrThreads = DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS;
-
 
 	public FederatedData(Types.DataType dataType, InetSocketAddress address, String filepath) {
 		_dataType = dataType;
 		_address = address;
 		_filepath = filepath;
+		if( _address != null )
+			_allFedSites.add(_address);
 	}
 	
 	/**
@@ -105,16 +112,20 @@
 		return executeFederatedOperation(request);
 	}
 	
+	public synchronized Future<FederatedResponse> executeFederatedOperation(FederatedRequest... request) {
+		return executeFederatedOperation(_address, request);
+	}
+	
 	/**
 	 * Executes an federated operation on a federated worker.
 	 *
 	 * @param request the requested operation
 	 * @return the response
 	 */
-	public synchronized Future<FederatedResponse> executeFederatedOperation(FederatedRequest... request) {
+	public static Future<FederatedResponse> executeFederatedOperation(InetSocketAddress address, FederatedRequest... request) {
 		// Careful with the number of threads. Each thread opens connections to multiple files making resulting in 
 		// java.io.IOException: Too many open files
-		EventLoopGroup workerGroup = new NioEventLoopGroup(_nrThreads);
+		EventLoopGroup workerGroup = new NioEventLoopGroup(DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS);
 		try {
 			Bootstrap b = new Bootstrap();
 			final DataRequestHandler handler = new DataRequestHandler(workerGroup);
@@ -128,7 +139,7 @@
 				}
 			});
 			
-			ChannelFuture f = b.connect(_address).sync();
+			ChannelFuture f = b.connect(address).sync();
 			Promise<FederatedResponse> promise = f.channel().eventLoop().newPromise();
 			handler.setPromise(promise);
 			f.channel().writeAndFlush(request);
@@ -142,6 +153,21 @@
 		}
 	}
 	
+	public static void clearFederatedWorkers() {
+		if( _allFedSites.isEmpty() )
+			return;
+		
+		//create and execute clear request on all workers
+		FederatedRequest fr = new FederatedRequest(RequestType.CLEAR);
+		List<Future<FederatedResponse>> ret = new ArrayList<>();
+		for( InetSocketAddress address : _allFedSites )
+			ret.add(executeFederatedOperation(address, fr));
+		
+		//wait for successful completion
+		FederationUtils.waitFor(ret);
+		_allFedSites.clear();
+	}
+	
 	private static class DataRequestHandler extends ChannelInboundHandlerAdapter {
 		private Promise<FederatedResponse> _prom;
 		private EventLoopGroup _workerGroup;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index d62e6f6..6c9be16 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -37,6 +37,7 @@
 		GET_VAR,   // return local variable to main
 		EXEC_INST, // execute arbitrary instruction over
 		EXEC_UDF,  // execute arbitrary user-defined function
+		CLEAR,     // clear all variables and execution contexts (i.e., rmvar ALL)
 	}
 	
 	private RequestType _method;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 35b844d..6dd0abc 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -119,6 +119,8 @@
 					return execInstruction(request);
 				case EXEC_UDF:
 					return execUDF(request);
+				case CLEAR:
+					return execClear();
 				default:
 					String message = String.format("Method %s is not supported.", method);
 					return new FederatedResponse(ResponseType.ERROR,
@@ -278,6 +280,11 @@
 		}
 	}
 
+	private FederatedResponse execClear() {
+		_ecm.clear();
+		return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
+	}
+	
 	private static void checkNumParams(int actual, int... expected) {
 		if (Arrays.stream(expected).anyMatch(x -> x == actual))
 			return;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index b25f8b9..b272bf9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -154,7 +154,7 @@
 		// prepare results (future federated responses), with optional wait to ensure the 
 		// order of requests without data dependencies (e.g., cleanup RPCs)
 		if( wait )
-			waitFor(ret);
+			FederationUtils.waitFor(ret);
 		return ret.toArray(new Future[0]);
 	}
 	
@@ -178,17 +178,7 @@
 		for(FederatedData fd : _fedMap.values())
 			tmp.add(fd.executeFederatedOperation(request));
 		//wait to avoid interference w/ following requests
-		waitFor(tmp);
-	}
-	
-	private static void waitFor(List<Future<FederatedResponse>> responses) {
-		try {
-			for(Future<FederatedResponse> fr : responses)
-				fr.get();
-		}
-		catch(Exception ex) {
-			throw new DMLRuntimeException(ex);
-		}
+		FederationUtils.waitFor(tmp);
 	}
 	
 	private static FederatedRequest[] addAll(FederatedRequest a, FederatedRequest[] b) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 429834b..9a14aba 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -21,6 +21,7 @@
 
 
 import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.Future;
 
 import org.apache.sysds.common.Types.ExecType;
@@ -44,6 +45,10 @@
 public class FederationUtils {
 	private static final IDSequence _idSeq = new IDSequence();
 	
+	public static void resetFedDataID() {
+		_idSeq.reset();
+	}
+	
 	public static long getNextFedDataID() {
 		return _idSeq.getNextID();
 	}
@@ -159,4 +164,14 @@
 			throw new DMLRuntimeException("Unsupported aggregation operator: "
 				+ aop.aggOp.increOp.fn.getClass().getSimpleName());
 	}
+	
+	public static void waitFor(List<Future<FederatedResponse>> responses) {
+		try {
+			for(Future<FederatedResponse> fr : responses)
+				fr.get();
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
index 7da40ab..7b60476 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
@@ -26,6 +26,8 @@
 
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -50,6 +52,8 @@
 	public int cols;
 	@Parameterized.Parameter(2)
 	public int runs;
+	@Parameterized.Parameter(3)
+	public int rep;
 
 	@Override
 	public void setUp() {
@@ -61,8 +65,9 @@
 	public static Collection<Object[]> data() {
 		// rows have to be even and > 1
 		return Arrays.asList(new Object[][] {
-			{10000, 10, 1}, {2000, 50, 1}, {1000, 100, 1},
-			{10000, 10, 2}, {2000, 50, 2}, {1000, 100, 2}, //concurrent requests
+			{10000, 10, 1, 1}, {2000, 50, 1, 1}, {1000, 100, 1, 1},
+			{10000, 10, 2, 1}, {2000, 50, 2, 1}, {1000, 100, 2, 1}, //concurrent requests
+			{10000, 10, 2, 2}, //repeated exec
 			//TODO more runs e.g., 16 -> but requires rework RPC framework first
 			//(e.g., see paramserv?)
 		});
@@ -115,26 +120,31 @@
 			"-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
 			"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
 			"runs=" + String.valueOf(runs), "out=" + output("Z")};
-		runTest(true, false, null, -1);
-
+		
+		for( int i=0; i<rep; i++ ) {
+			ParForProgramBlock.resetWorkerIDs();
+			FederationUtils.resetFedDataID();
+			runTest(true, false, null, -1);
+		
+			// check for federated operations
+			Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+			Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
+			Assert.assertTrue(heavyHittersContainsString("fed_uarmin"));
+			Assert.assertTrue(heavyHittersContainsString("fed_*"));
+			Assert.assertTrue(heavyHittersContainsString("fed_+"));
+			Assert.assertTrue(heavyHittersContainsString("fed_<="));
+			Assert.assertTrue(heavyHittersContainsString("fed_/"));
+			Assert.assertTrue(heavyHittersContainsString("fed_r'"));
+			
+			//check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+		}
+		
 		// compare via files
 		//compareResults(1e-9); --> randomized
 		TestUtils.shutdownThreads(t1, t2);
 		
-		// check for federated operations
-		Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
-		Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
-		Assert.assertTrue(heavyHittersContainsString("fed_uarmin"));
-		Assert.assertTrue(heavyHittersContainsString("fed_*"));
-		Assert.assertTrue(heavyHittersContainsString("fed_+"));
-		Assert.assertTrue(heavyHittersContainsString("fed_<="));
-		Assert.assertTrue(heavyHittersContainsString("fed_/"));
-		Assert.assertTrue(heavyHittersContainsString("fed_r'"));
-		
-		//check that federated input files are still existing
-		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
-		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
-		
 		resetExecMode(platformOld);
 	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
index 906b124..4b4457a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -83,19 +83,28 @@
 		String HOME = SCRIPT_DIR + TEST_DIR;
 
 		// write input matrices
-		int halfRows = rows / 2;
+		int quarterRows = rows / 4;
 		// We have two matrices handled by a single federated worker
-		double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 3);
-		double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 7);
-		writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
-		writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+		double[][] X1 = getRandomMatrix(quarterRows, cols, 0, 1, 1, 3);
+		double[][] X2 = getRandomMatrix(quarterRows, cols, 0, 1, 1, 7);
+		double[][] X3 = getRandomMatrix(quarterRows, cols, 0, 1, 1, 8);
+		double[][] X4 = getRandomMatrix(quarterRows, cols, 0, 1, 1, 9);
+		MatrixCharacteristics mc= new MatrixCharacteristics(quarterRows, cols, blocksize, quarterRows * cols);
+		writeInputMatrixWithMTD("X1", X1, false, mc);
+		writeInputMatrixWithMTD("X2", X2, false, mc);
+		writeInputMatrixWithMTD("X3", X3, false, mc);
+		writeInputMatrixWithMTD("X4", X4, false, mc);
 
 		// empty script name because we don't execute any script, just start the worker
 		fullDMLScriptName = "";
 		int port1 = getRandomAvailablePort();
 		int port2 = getRandomAvailablePort();
+		int port3 = getRandomAvailablePort();
+		int port4 = getRandomAvailablePort();
 		Thread t1 = startLocalFedWorker(port1);
 		Thread t2 = startLocalFedWorker(port2);
+		Thread t3 = startLocalFedWorker(port3);
+		Thread t4 = startLocalFedWorker(port4);
 
 		TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
 		loadTestConfiguration(config);
@@ -103,21 +112,24 @@
 		
 		// Run reference dml script with normal matrix
 		fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-		programArgs = new String[] {"-stats", "-args", input("X1"), input("X2"),
+		programArgs = new String[] {"-stats", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
 			String.valueOf(scaleAndShift).toUpperCase(), expected("Z")};
 		runTest(true, false, null, -1);
 
 		// Run actual dml script with federated matrix
 		fullDMLScriptName = HOME + TEST_NAME + ".dml";
-		programArgs = new String[] {"-stats",
-			"-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
-			"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
+		programArgs = new String[] {"-stats", "-nvargs", 
+			"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+			"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+			"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+			"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+			"rows=" + rows, "cols=" + cols,
 			"scaleAndShift=" + String.valueOf(scaleAndShift).toUpperCase(), "out=" + output("Z")};
 		runTest(true, false, null, -1);
 
 		// compare via files
 		compareResults(1e-9);
-		TestUtils.shutdownThreads(t1, t2);
+		TestUtils.shutdownThreads(t1, t2, t3, t4);
 		
 		// check for federated operations
 		Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
@@ -134,6 +146,8 @@
 		//check that federated input files are still existing
 		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
 		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
 		
 		resetExecMode(platformOld);
 	}
diff --git a/src/test/scripts/functions/federated/FederatedPCATest.dml b/src/test/scripts/functions/federated/FederatedPCATest.dml
index b235d44..049a789 100644
--- a/src/test/scripts/functions/federated/FederatedPCATest.dml
+++ b/src/test/scripts/functions/federated/FederatedPCATest.dml
@@ -19,7 +19,8 @@
 #
 #-------------------------------------------------------------
 
-X = federated(addresses=list($in_X1, $in_X2),
-    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)))
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+		list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)))
 [X2,M] = pca(X=X,  K=2, scale=$scaleAndShift, center=$scaleAndShift)
 write(X2, $out)
diff --git a/src/test/scripts/functions/federated/FederatedPCATestReference.dml b/src/test/scripts/functions/federated/FederatedPCATestReference.dml
index 0b17fe0..3ab3b97 100644
--- a/src/test/scripts/functions/federated/FederatedPCATestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedPCATestReference.dml
@@ -19,6 +19,6 @@
 #
 #-------------------------------------------------------------
 
-X = rbind(read($1), read($2))
-[X2,M] = pca(X=X,  K=2, scale=$3, center=$3)
-write(X2, $4)
+X = rbind(read($1), read($2), read($3), read($4));
+[X2,M] = pca(X=X,  K=2, scale=$5, center=$5)
+write(X2, $6)