[SYSTEMDS-2784] Enable lineage-based reuse in federated workers
This patch builds the initial infrastructure for lineage based
reuse in federated workers. Changes include:
- Lineage tracing InitFEDInstruction
- Lineage trace READ and PUT requests. For PUT, lineageitem hash
is sent with the request, which will be replaced by Adler32
in future commits.
- Disable compiler assisted optimizations for lineage-based reuse
(e.g. mark for caching) for the workers.
- Testing infrastructure.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 6c9be16..33dad44 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -23,8 +23,10 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.stream.Collectors;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.utils.Statistics;
public class FederatedRequest implements Serializable {
@@ -45,6 +47,7 @@
private long _tid;
private List<Object> _data;
private boolean _checkPrivacy;
+ private List<Integer> _lineageHash;
public FederatedRequest(RequestType method) {
@@ -117,6 +120,16 @@
return _checkPrivacy;
}
+ public void setLineageHash(LineageItem[] liItems) {
+ // copy the hash of the corresponding lineage DAG
+ // TODO: copy both Adler32 checksum (on data) and hash (on lineage DAG)
+ _lineageHash = Arrays.stream(liItems).map(li -> li.hashCode()).collect(Collectors.toList());
+ }
+
+ public int getLineageHash(int i) {
+ return _lineageHash.get(i);
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder("FederatedRequest[");
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index e3ec403..5c0a0bc 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -48,6 +48,8 @@
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
@@ -232,6 +234,10 @@
cd.enableCleanup(false); // guard against deletion
_ecm.get(tid).setVariable(String.valueOf(id), cd);
+ if (DMLScript.LINEAGE)
+ // create a literal type lineage item with the file name
+ _ecm.get(tid).getLineage().set(String.valueOf(id), new LineageItem(filename));
+
if(dataType == Types.DataType.FRAME) {
FrameObject frameObject = (FrameObject) cd;
frameObject.acquireRead();
@@ -264,6 +270,10 @@
// set variable and construct empty response
ec.setVariable(varname, data);
+ if (DMLScript.LINEAGE)
+ // TODO: Identify MO uniquely. Use Adler32 checksum.
+ ec.getLineage().set(varname, new LineageItem(String.valueOf(request.getLineageHash(0))));
+
return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
@@ -299,6 +309,14 @@
pb.getInstructions().clear();
Instruction receivedInstruction = InstructionParser.parseSingleInstruction((String) request.getParam(0));
pb.getInstructions().add(receivedInstruction);
+
+ if (DMLScript.LINEAGE)
+ // Compiler assisted optimizations are not applicable for Fed workers.
+ // e.g. isMarkedForCaching fails as output operands are saved in the
+ // symbol table only after the instruction execution finishes.
+ // NOTE: In shared JVM, this will disable compiler assistance even for the coordinator
+ LineageCacheConfig.setCompAssRW(false);
+
try {
pb.execute(ec); // execute single instruction
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 6ed642e..4a8194b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -21,6 +21,7 @@
import java.util.concurrent.Future;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -31,6 +32,7 @@
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -78,6 +80,10 @@
else if(mo1.isFederated(FType.ROW)) { // MV + MM
//construct commands: broadcast rhs, fed mv, retrieve results
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+ if (DMLScript.LINEAGE)
+ //also copy the hash of the lineage DAG
+ fr1.setLineageHash(LineageItemUtils.getLineage(ec, input1));
+ //TODO: calculate Adler32 checksum on data, and move this code inside FederationMap.
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
if( mo2.getNumColumns() == 1 ) { //MV
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 17e2855..bc16149 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -56,9 +56,11 @@
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.meta.DataCharacteristics;
-public class InitFEDInstruction extends FEDInstruction {
+public class InitFEDInstruction extends FEDInstruction implements LineageTraceable {
private static final Log LOG = LogFactory.getLog(InitFEDInstruction.class.getName());
@@ -342,4 +344,34 @@
throw new DMLRuntimeException("Exception in frame response from federated worker.", e);
}
}
+
+ @Override
+ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+ String type = ec.getScalarInput(_type).getStringValue();
+ ListObject addresses = ec.getListObject(_addresses.getName());
+ ListObject ranges = ec.getListObject(_ranges.getName());
+ LineageItem[] liInputs = new LineageItem[addresses.getLength()];
+
+ for(int i = 0; i < addresses.getLength(); i++) {
+ Data addressData = addresses.getData().get(i);
+ if(addressData instanceof StringObject) {
+ String address = ((StringObject)addressData).getStringValue();
+ // get beginning and end of data ranges
+ List<Data> rangesData = ranges.getData();
+ List<Data> beginDimsData = ((ListObject) rangesData.get(i*2)).getData();
+ List<Data> endDimsData = ((ListObject) rangesData.get(i*2+1)).getData();
+ String rl = ((ScalarObject)beginDimsData.get(0)).getStringValue();
+ String cl = ((ScalarObject)beginDimsData.get(1)).getStringValue();
+ String ru = ((ScalarObject)endDimsData.get(0)).getStringValue();
+ String cu = ((ScalarObject)endDimsData.get(1)).getStringValue();
+ // form a string with all the information and create a lineage item
+ String data = InstructionUtils.concatOperands(type, address, rl, cl, ru, cu);
+ liInputs[i] = new LineageItem(data);
+ }
+ else {
+ throw new DMLRuntimeException("federated instruction only takes strings as addresses");
+ }
+ }
+ return Pair.of(_output.getName(), new LineageItem(getOpcode(), liInputs));
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 0143fed..d51f05b 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -36,6 +36,7 @@
import java.util.List;
import java.util.Map;
import java.util.Properties;
+import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
@@ -1530,7 +1531,11 @@
* @return the thread associated with the worker.
*/
protected Thread startLocalFedWorkerThread(int port) {
- return startLocalFedWorkerThread(port, FED_WORKER_WAIT);
+ return startLocalFedWorkerThread(port, null, FED_WORKER_WAIT);
+ }
+
+ protected Thread startLocalFedWorkerThread(int port, String[] otherArgs) {
+ return startLocalFedWorkerThread(port, otherArgs, FED_WORKER_WAIT);
}
/**
@@ -1543,11 +1548,17 @@
* @return the thread associated with the worker.
*/
protected Thread startLocalFedWorkerThread(int port, int sleep) {
+ return startLocalFedWorkerThread(port, null, sleep);
+ }
+ protected Thread startLocalFedWorkerThread(int port, String[] otherArgs, int sleep) {
Thread t = null;
String[] fedWorkArgs = {"-w", Integer.toString(port)};
ArrayList<String> args = new ArrayList<>();
addProgramIndependentArguments(args);
+
+ if (otherArgs != null)
+ args.addAll(Arrays.stream(otherArgs).collect(Collectors.toList()));
for(int i = 0; i < fedWorkArgs.length; i++)
args.add(fedWorkArgs[i]);
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
new file mode 100644
index 0000000..00c6d6f
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.lineage;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FedFullReuseTest extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/lineage/";
+ private final static String TEST_NAME = "FedFullReuse1";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FedFullReuseTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ // rows have to be even and > 1
+ return Arrays.asList(new Object[][] {
+ // {2, 1000}, {10, 100},
+ {100, 10},
+ //{1000, 1},
+ // {10, 2000}, {2000, 10}
+ });
+ }
+
+ @Test
+ public void federatedReuseMM() { //reuse inside federated workers
+ federatedReuse();
+ }
+
+ public void federatedReuse() {
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int halfRows = rows / 2;
+ // Share two matrices between two federated worker
+ double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+ double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+ double[][] Y1 = getRandomMatrix(cols, halfRows, 0, 1, 1, 44);
+ double[][] Y2 = getRandomMatrix(cols, halfRows, 0, 1, 1, 21);
+
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols));
+
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ String[] otherargs = new String[] {"-lineage", "reuse_full"};
+ Lineage.resetInternalState();
+ Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, otherargs);
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ // Run reference dml script with normal matrix. Reuse of ba+*.
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "-lineage", "reuse_full",
+ "-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "Z=" + expected("Z")};
+ runTest(true, false, null, -1);
+ long mmCount = Statistics.getCPHeavyHitterCount("ba+*");
+
+ // Run actual dml script with federated matrix
+ // The fed workers reuse ba+*
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats","-lineage", "reuse_full",
+ "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+ "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+ runTest(true, false, null, -1);
+ long mmCount_fed = Statistics.getCPHeavyHitterCount("ba+*");
+
+ // compare results
+ compareResults(1e-9);
+ // compare matrix multiplication count
+ // #federated execution of ba+* = #threads times #non-federated execution of ba+* (after reuse)
+ Assert.assertTrue("Violated reuse count: "+mmCount_fed+" == "+mmCount*2,
+ mmCount_fed == mmCount * 2); // #threads = 2
+
+ TestUtils.shutdownThreads(t1, t2);
+ }
+
+}
diff --git a/src/test/scripts/functions/lineage/FedFullReuse1.dml b/src/test/scripts/functions/lineage/FedFullReuse1.dml
new file mode 100644
index 0000000..4597332
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedFullReuse1.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)));
+Y = federated(addresses=list($Y1, $Y2),
+ ranges=list(list(0, 0), list($c, $r / 2), list(0, $r / 2), list($c, $r)));
+
+for(i in 1:10)
+ Z = X %*% Y;
+
+write(Z, $Z);
diff --git a/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml b/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml
new file mode 100644
index 0000000..6049f5d
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2));
+Y = cbind(read($Y1), read($Y2));
+
+for(i in 1:10)
+ Z = X %*% Y;
+
+write(Z, $Z);