[MINOR] Federated tests now in same JVM
To make this work, a static flag inside statistics is set, to allow or
disallow the worker to clear and print the statistics. This is still
a temporary hack, since the wanted behavior would be isolated statistic
objects for each the worker and the controller.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 5dbccb4..0a0a05b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -113,7 +113,7 @@
response = tmp; //return last
}
- if (DMLScript.STATISTICS && request.getType() == RequestType.CLEAR){
+ if (DMLScript.STATISTICS && request.getType() == RequestType.CLEAR && Statistics.allowWorkerStatistics){
System.out.println("Federated Worker " + Statistics.display());
Statistics.reset();
}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index 8f22ab4..7642397 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -159,6 +159,8 @@
public static long recomputeNNZTime = 0;
public static long examSparsityTime = 0;
public static long allocateDoubleArrTime = 0;
+
+ public static boolean allowWorkerStatistics = true;
public static void incrementNativeFailuresCounter() {
numNativeFailures.increment();
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 9951982..28a82ff 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -101,7 +101,7 @@
public static final boolean TEST_GPU = false;
public static final double GPU_TOLERANCE = 1e-9;
- public static final int FED_WORKER_WAIT = 2000; // in ms
+ public static final int FED_WORKER_WAIT = 800; // in ms
// With OpenJDK 8u242 on Windows, the new changes in JDK are not allowing
// to set the native library paths internally thus breaking the code.
@@ -1311,6 +1311,47 @@
}
/**
+ * Start a thread for a worker. This will share the same JVM,
+ * so all static variables will be shared.!
+ *
+ * Also when using the local Fed Worker thread the statistics printing,
+ * and clearing from the worker is disabled.
+ *
+ * @param port Port to use
+ * @return the thread associated with the worker.
+ */
+ protected Thread startLocalFedWorkerThread(int port) {
+ Thread t = null;
+ String[] fedWorkArgs = {"-w", Integer.toString(port)};
+ ArrayList<String> args = new ArrayList<>();
+
+ addProgramIndependentArguments(args);
+
+ for(int i = 0; i < fedWorkArgs.length; i++)
+ args.add(fedWorkArgs[i]);
+
+ String[] finalArguments = args.toArray(new String[args.size()]);
+
+ Statistics.allowWorkerStatistics = false;
+
+ try {
+ t = new Thread(() -> {
+ try {
+ main(finalArguments);
+ }
+ catch(IOException e) {
+ }
+ });
+ t.start();
+ java.util.concurrent.TimeUnit.MILLISECONDS.sleep(FED_WORKER_WAIT);
+ }
+ catch(InterruptedException e) {
+ e.printStackTrace();
+ }
+ return t;
+ }
+
+ /**
* Start java worker in same JVM.
* @param args the command line arguments
* @return the thread associated with the process.s
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
index 96fef5d..3c9439d 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
@@ -116,10 +116,10 @@
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
- Process t3 = startLocalFedWorker(port3);
- Process t4 = startLocalFedWorker(port4);
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
+ Thread t3 = startLocalFedWorkerThread(port3);
+ Thread t4 = startLocalFedWorkerThread(port4);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index 636b279..8fbcfa7 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -94,8 +94,8 @@
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index 933e971..c763541 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -101,8 +101,8 @@
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
index 4caf52e..c102ef9 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
@@ -76,7 +76,6 @@
if(rtplatform == Types.ExecMode.SPARK) {
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
- Process t1, t2;
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
@@ -98,8 +97,8 @@
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorker(port1);
- t2 = startLocalFedWorker(port2);
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
index fef8889..9f4aaea 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
@@ -19,8 +19,6 @@
package org.apache.sysds.test.functions.federated.algorithms;
-import java.io.BufferedReader;
-import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Collection;
@@ -95,20 +93,8 @@
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
-
- BufferedReader output = new BufferedReader(new InputStreamReader(t1.getInputStream()));
- BufferedReader error = new BufferedReader(new InputStreamReader(t1.getInputStream()));
-
- Thread t = new Thread(() -> {
- output.lines().forEach(s -> System.out.println(s));
- });
- Thread te = new Thread(() -> {
- error.lines().forEach(s -> System.err.println(s));
- });
- t.start();
- te.start();
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -131,7 +117,6 @@
compareResults(1e-9);
TestUtils.shutdownThreads(t1, t2);
- TestUtils.shutdownThreads(t, te);
// check for federated operations
Assert.assertTrue("contains federated matrix mult", heavyHittersContainsString("fed_ba+*"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
index 7a5a2fd..b86c4da 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
@@ -101,10 +101,10 @@
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
- Process t3 = startLocalFedWorker(port3);
- Process t4 = startLocalFedWorker(port4);
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
+ Thread t3 = startLocalFedWorkerThread(port3);
+ Thread t4 = startLocalFedWorkerThread(port4);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
index c042cbb..b1fb692 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
@@ -101,10 +101,10 @@
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
- Process t3 = startLocalFedWorker(port3);
- Process t4 = startLocalFedWorker(port4);
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
+ Thread t3 = startLocalFedWorkerThread(port3);
+ Thread t4 = startLocalFedWorkerThread(port4);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
index ebdee88..24f04f0 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
@@ -78,7 +78,6 @@
if(rtplatform == Types.ExecMode.SPARK) {
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
- Process t1, t2;
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
@@ -105,8 +104,8 @@
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorker(port1);
- t2 = startLocalFedWorker(port2);
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
index 052234d..efd98c2 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
@@ -99,8 +99,8 @@
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
index 187648b..c0a53d4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
@@ -100,8 +100,8 @@
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedConstructionTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedConstructionTest.java
index 74180be..9013f93 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedConstructionTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedConstructionTest.java
@@ -120,7 +120,7 @@
String HOME = SCRIPT_DIR + TEST_DIR;
int port = getRandomAvailablePort();
- Process t = startLocalFedWorker(port);
+ Thread t = startLocalFedWorkerThread(port);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
index c75b0b5..6039354 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
@@ -96,8 +96,8 @@
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
index e59eea8..8b3b04f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
@@ -79,7 +79,7 @@
writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols));
int port = getRandomAvailablePort();
- Process t = startLocalFedWorker(port);
+ Thread t = startLocalFedWorkerThread(port);
// we need the reference file to not be written to hdfs, so we get the correct format
rtplatform = Types.ExecMode.SINGLE_NODE;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
new file mode 100644
index 0000000..d412743
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+@Ignore
+public class FederatedStatisticsTest extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME = "FederatedLogRegTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedStatisticsTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ // rows have to be even and > 1
+ return Arrays.asList(new Object[][] {{10000, 10}, {1000, 100}, {2000, 43}});
+ }
+
+ @Test
+ public void federatedSinglenodeLogReg() {
+ federatedLogReg(Types.ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ public void federatedHybridLogReg() {
+ federatedLogReg(Types.ExecMode.HYBRID);
+ }
+
+ public void federatedLogReg(Types.ExecMode execMode) {
+ ExecMode platformOld = setExecMode(execMode);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int halfRows = rows / 2;
+ // We have two matrices handled by a single federated worker
+ double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+ double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+ double[][] Y = getRandomMatrix(rows, 1, -1, 1, 1, 1233);
+ for(int i = 0; i < rows; i++)
+ Y[i][0] = (Y[i][0] > 0) ? 1 : -1;
+
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows));
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Process t1 = startLocalFedWorker(port1);
+ Process t2 = startLocalFedWorker(port2);
+
+ BufferedReader output = new BufferedReader(new InputStreamReader(t1.getInputStream()));
+ BufferedReader error = new BufferedReader(new InputStreamReader(t1.getInputStream()));
+
+ Thread t = new Thread(() -> {
+ output.lines().forEach(s -> System.out.println(s));
+ });
+ Thread te = new Thread(() -> {
+ error.lines().forEach(s -> System.err.println(s));
+ });
+ t.start();
+ te.start();
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+ setOutputBuffering(false);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y"), expected("Z")};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "30", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
+ "in_Y=" + input("Y"), "out=" + output("Z")};
+ runTest(true, false, null, -1);
+
+ // compare via files
+ compareResults(1e-9);
+
+ TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(t, te);
+
+ // check for federated operations
+ Assert.assertTrue("contains federated matrix mult", heavyHittersContainsString("fed_ba+*"));
+ Assert.assertTrue("contains federated row unary aggregate",
+ heavyHittersContainsString("fed_uark+", "fed_uarsqk+"));
+ Assert.assertTrue("contains federated matrix mult chain or transpose",
+ heavyHittersContainsString("fed_mmchain", "fed_r'"));
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
+ resetExecMode(platformOld);
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
index c6d4be6..9ce65eb 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
@@ -78,7 +78,7 @@
double[][] A = getRandomMatrix(rows / 2, cols, -10, 10, 1, 1);
writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows / 2, cols, blocksize, (rows / 2) * cols));
int port = getRandomAvailablePort();
- Process t = startLocalFedWorker(port);
+ Thread t = startLocalFedWorkerThread(port);
// we need the reference file to not be written to hdfs, so we get the correct format
rtplatform = Types.ExecMode.SINGLE_NODE;