[SYSTEMDS-2549,2624] Fix federated binary matrix-vector, var cleanup
This patch fixes two correctness issues related to (1) cleanup of
federated matrices, and (2) federated binary matrix-row vector
operators. Furthermore, this also includes a new federated Kmeans test
and some minor fixes for row aggregates, and improvements of federated
matrix multiplications.
diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml
index f18466d..90a7222 100644
--- a/scripts/builtin/kmeans.dml
+++ b/scripts/builtin/kmeans.dml
@@ -160,7 +160,7 @@
C_old = C; C = C_new;
}
- if(is_verbose == TRUE)
+ if(is_verbose)
print ("Run " + run_index + ", Iteration " + iter_count + ": Terminated with code = "
+ term_code + ", Centroid WCSS = " + wcss);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 31a467f..fcb5db3 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -59,9 +59,7 @@
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.HashSet;
import java.util.List;
-import java.util.Set;
import java.util.stream.Collectors;
public class ExecutionContext {
@@ -73,7 +71,6 @@
//symbol table
protected LocalVariableMap _variables;
protected boolean _autoCreateVars;
- protected Set<String> _guardedFiles = new HashSet<>();
//lineage map, cache, prepared dedup blocks
protected Lineage _lineage;
@@ -134,10 +131,6 @@
public void setAutoCreateVars(boolean flag) {
_autoCreateVars = flag;
}
-
- public void addGuardedFilename(String fname) {
- _guardedFiles.add(fname);
- }
/**
* Get the i-th GPUContext
@@ -758,7 +751,7 @@
//compute ref count only if matrix cleanup actually necessary
if ( mo.isCleanupEnabled() && !getVariables().hasReferences(mo) ) {
mo.clearData(); //clean cached data
- if( fileExists && !_guardedFiles.contains(mo.getFileName()) ) {
+ if( fileExists ) {
HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName());
HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName()+".mtd");
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 6571666..46ebce2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -41,8 +41,12 @@
* @param other the <code>FederatedRange</code> to copy
*/
public FederatedRange(FederatedRange other) {
- _beginDims = other._beginDims.clone();
- _endDims = other._endDims.clone();
+ this(other._beginDims.clone(), other._endDims.clone());
+ }
+
+ public FederatedRange(FederatedRange other, long clen) {
+ this(other._beginDims.clone(), other._endDims.clone());
+ _endDims[1] = clen;
}
public void setBeginDim(int dim, long value) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 47ca43c..1afbfb1 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -181,7 +181,7 @@
//TODO spawn async load of data, otherwise on first access
_ec.setVariable(String.valueOf(id), cd);
- _ec.addGuardedFilename(filename);
+ cd.enableCleanup(false); //guard against deletion
if (dataType == Types.DataType.FRAME) {
FrameObject frameObject = (FrameObject) cd;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 04532fd..d323bad 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -149,6 +149,14 @@
map.put(new FederatedRange(e.getKey()), new FederatedData(e.getValue(), id));
return new FederationMap(id, map);
}
+
+ public FederationMap copyWithNewID(long id, long clen) {
+ Map<FederatedRange, FederatedData> map = new TreeMap<>();
+ //TODO handling of file path, but no danger as never written
+ for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
+ map.put(new FederatedRange(e.getKey(), clen), new FederatedData(e.getValue(), id));
+ return new FederationMap(id, map);
+ }
public FederationMap rbind(long offset, FederationMap that) {
for( Entry<FederatedRange, FederatedData> e : that._fedMap.entrySet() ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 3fe1004..14f81bf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -66,13 +66,22 @@
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
- FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
- //execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(fr1, fr2, fr3);
- MatrixBlock ret = FederationUtils.rbind(tmp);
- mo1.getFedMapping().cleanup(fr1.getID(), fr2.getID());
- ec.setMatrixOutput(output.getName(), ret);
- //TODO should remain federated matrix (no need for agg)
+ if( mo2.getNumColumns() == 1 ) { //MV
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(fr1, fr2, fr3);
+ MatrixBlock ret = FederationUtils.rbind(tmp);
+ mo1.getFedMapping().cleanup(fr1.getID(), fr2.getID());
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ else { //MM
+ //execute federated operations and aggregate
+ mo1.getFedMapping().execute(fr1, fr2);
+ mo1.getFedMapping().cleanup(fr1.getID());
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
+ out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), mo2.getNumColumns()));
+ }
}
//#2 vector - federated matrix multiplication
else if (mo2.isFederated()) {// VM + MM
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index d124c76..7813f6a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -45,13 +45,23 @@
}
//matrix-matrix binary operations -> lhs fed input -> fed output
- FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
- FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
-
- //execute federated instruction and cleanup intermediates
- mo1.getFedMapping().execute(fr1, fr2);
- mo1.getFedMapping().cleanup(fr1.getID());
+ FederatedRequest fr2 = null;
+ if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { //MV row vector
+ FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+ fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
+ //execute federated instruction and cleanup intermediates
+ mo1.getFedMapping().execute(fr1, fr2);
+ mo1.getFedMapping().cleanup(fr1[0].getID());
+ }
+ else { //MM or MV col vector
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+ fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), fr1.getID()});
+ //execute federated instruction and cleanup intermediates
+ mo1.getFedMapping().execute(fr1, fr2);
+ mo1.getFedMapping().cleanup(fr1.getID());
+ }
//derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
diff --git a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
index 58bdcd0..d71ce9d 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
@@ -31,7 +31,7 @@
protected int _blocksize;
- public DataCharacteristics set(long nr, long nc, int len) {
+ public DataCharacteristics set(long nr, long nc, int blen) {
throw new DMLRuntimeException("DataCharacteristics.set(long, long, int): should never get called in the base class");
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
new file mode 100644
index 0000000..1ef2384
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedKmeansTest extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME = "FederatedKMeansTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedKmeansTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public int runs;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ // rows have to be even and > 1
+ return Arrays.asList(new Object[][] {
+ {10000, 10, 1}, {2000, 50, 1}, {1000, 100, 1},
+ //TODO support for multi-threaded federated interactions
+ //{10000, 10, 16}, {2000, 50, 16}, {1000, 100, 16}, //concurrent requests
+ });
+ }
+
+ @Test
+ public void federatedKmeansSinglenode() {
+ federatedKmeans(Types.ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ public void federatedKmeansHybrid() {
+ federatedKmeans(Types.ExecMode.HYBRID);
+ }
+
+ public void federatedKmeans(Types.ExecMode execMode) {
+ ExecMode platformOld = setExecMode(execMode);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int halfRows = rows / 2;
+ // We have two matrices handled by a single federated worker
+ double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 3);
+ double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 7);
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorker(port1);
+ Thread t2 = startLocalFedWorker(port2);
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+ setOutputBuffering(false);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-args", input("X1"), input("X2"),
+ String.valueOf(runs), expected("Z")};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats",
+ "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
+ "runs=" + String.valueOf(runs), "out=" + output("Z")};
+ runTest(true, false, null, -1);
+
+ // compare via files
+ //compareResults(1e-9); --> randomized
+ TestUtils.shutdownThreads(t1, t2);
+
+ // check for federated operations
+ Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uarmin"));
+ Assert.assertTrue(heavyHittersContainsString("fed_*"));
+ Assert.assertTrue(heavyHittersContainsString("fed_+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_<="));
+ Assert.assertTrue(heavyHittersContainsString("fed_/"));
+
+ //check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
+ resetExecMode(platformOld);
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
index bf674a8..53eac1e 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -27,6 +27,7 @@
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -129,6 +130,10 @@
Assert.assertTrue(heavyHittersContainsString("fed_replace"));
}
+ //check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
resetExecMode(platformOld);
}
}
diff --git a/src/test/scripts/functions/federated/FederatedKmeansTest.dml b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
new file mode 100644
index 0000000..95f136c
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)))
+[C,Y] = kmeans(X=X, k=4, runs=$runs)
+write(C, $out)
diff --git a/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
new file mode 100644
index 0000000..da32c8b
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2))
+[C,Y] = kmeans(X=X, k=4, runs=$3)
+write(C, $4)