[SYSTEMDS-2784] Enable lineage-based reuse in federated workers

This patch builds the initial infrastructure for lineage based
reuse in federated workers. Changes include:
 - Lineage tracing InitFEDInstruction
 - Lineage trace READ and PUT requests. For PUT, lineageitem hash
   is sent with the request, which will be replaced by Adler32
   in future commits.
 - Disable compiler assisted optimizations for lineage-based reuse
   (e.g. mark for caching) for the workers.
 - Testing infrastructure.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 6c9be16..33dad44 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -23,8 +23,10 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.stream.Collectors;
 
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.utils.Statistics;
 
 public class FederatedRequest implements Serializable {
@@ -45,6 +47,7 @@
 	private long _tid;
 	private List<Object> _data;
 	private boolean _checkPrivacy;
+	private List<Integer> _lineageHash;
 	
 	
 	public FederatedRequest(RequestType method) {
@@ -117,6 +120,16 @@
 		return _checkPrivacy;
 	}
 	
+	public void setLineageHash(LineageItem[] liItems) {
+		// copy the hash of the corresponding lineage DAG
+		// TODO: copy both Adler32 checksum (on data) and hash (on lineage DAG)
+		_lineageHash = Arrays.stream(liItems).map(li -> li.hashCode()).collect(Collectors.toList());
+	}
+	
+	public int getLineageHash(int i) {
+		return _lineageHash.get(i);
+	}
+	
 	@Override
 	public String toString() {
 		StringBuilder sb = new StringBuilder("FederatedRequest[");
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index e3ec403..5c0a0bc 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -48,6 +48,8 @@
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
 import org.apache.sysds.runtime.privacy.DMLPrivacyException;
@@ -232,6 +234,10 @@
 		cd.enableCleanup(false); // guard against deletion
 		_ecm.get(tid).setVariable(String.valueOf(id), cd);
 
+		if (DMLScript.LINEAGE)
+			// create a literal type lineage item with the file name
+			_ecm.get(tid).getLineage().set(String.valueOf(id), new LineageItem(filename));
+
 		if(dataType == Types.DataType.FRAME) {
 			FrameObject frameObject = (FrameObject) cd;
 			frameObject.acquireRead();
@@ -264,6 +270,10 @@
 
 		// set variable and construct empty response
 		ec.setVariable(varname, data);
+		if (DMLScript.LINEAGE)
+			// TODO: Identify MO uniquely. Use Adler32 checksum.
+			ec.getLineage().set(varname, new LineageItem(String.valueOf(request.getLineageHash(0))));
+
 		return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
 	}
 
@@ -299,6 +309,14 @@
 		pb.getInstructions().clear();
 		Instruction receivedInstruction = InstructionParser.parseSingleInstruction((String) request.getParam(0));
 		pb.getInstructions().add(receivedInstruction);
+
+		if (DMLScript.LINEAGE)
+			// Compiler assisted optimizations are not applicable for Fed workers.
+			// e.g. isMarkedForCaching fails as output operands are saved in the 
+			// symbol table only after the instruction execution finishes. 
+			// NOTE: In shared JVM, this will disable compiler assistance even for the coordinator 
+			LineageCacheConfig.setCompAssRW(false);
+
 		try {
 			pb.execute(ec); // execute single instruction
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 6ed642e..4a8194b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -21,6 +21,7 @@
 
 import java.util.concurrent.Future;
 
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -31,6 +32,7 @@
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
@@ -78,6 +80,10 @@
 		else if(mo1.isFederated(FType.ROW)) { // MV + MM
 			//construct commands: broadcast rhs, fed mv, retrieve results
 			FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+			if (DMLScript.LINEAGE)
+				//also copy the hash of the lineage DAG
+				fr1.setLineageHash(LineageItemUtils.getLineage(ec, input1));
+				//TODO: calculate Adler32 checksum on data, and move this code inside FederationMap.
 			FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
 				new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
 			if( mo2.getNumColumns() == 1 ) { //MV
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 17e2855..bc16149 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -56,9 +56,11 @@
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 
-public class InitFEDInstruction extends FEDInstruction {
+public class InitFEDInstruction extends FEDInstruction implements LineageTraceable {
 
 	private static final Log LOG = LogFactory.getLog(InitFEDInstruction.class.getName());
 
@@ -342,4 +344,34 @@
 			throw new DMLRuntimeException("Exception in frame response from federated worker.", e);
 		}
 	}
+
+	@Override
+	public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+		String type = ec.getScalarInput(_type).getStringValue();
+		ListObject addresses = ec.getListObject(_addresses.getName());
+		ListObject ranges = ec.getListObject(_ranges.getName());
+		LineageItem[] liInputs = new LineageItem[addresses.getLength()];
+
+		for(int i = 0; i < addresses.getLength(); i++) {
+			Data addressData = addresses.getData().get(i);
+			if(addressData instanceof StringObject) {
+				String address = ((StringObject)addressData).getStringValue();
+				// get beginning and end of data ranges
+				List<Data> rangesData = ranges.getData();
+				List<Data> beginDimsData = ((ListObject) rangesData.get(i*2)).getData();
+				List<Data> endDimsData = ((ListObject) rangesData.get(i*2+1)).getData();
+				String rl = ((ScalarObject)beginDimsData.get(0)).getStringValue();
+				String cl = ((ScalarObject)beginDimsData.get(1)).getStringValue();
+				String ru = ((ScalarObject)endDimsData.get(0)).getStringValue();
+				String cu = ((ScalarObject)endDimsData.get(1)).getStringValue();
+				// form a string with all the information and create a lineage item
+				String data = InstructionUtils.concatOperands(type, address, rl, cl, ru, cu);
+				liInputs[i] = new LineageItem(data);
+			}
+			else {
+				throw new DMLRuntimeException("federated instruction only takes strings as addresses");
+			}
+		}
+		return Pair.of(_output.getName(), new LineageItem(getOpcode(), liInputs));
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 0143fed..d51f05b 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -36,6 +36,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
+import java.util.stream.Collectors;
 
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.IOUtils;
@@ -1530,7 +1531,11 @@
 	 * @return the thread associated with the worker.
 	 */
 	protected Thread startLocalFedWorkerThread(int port) {
-		return startLocalFedWorkerThread(port, FED_WORKER_WAIT);
+		return startLocalFedWorkerThread(port, null, FED_WORKER_WAIT);
+	}
+
+	protected Thread startLocalFedWorkerThread(int port, String[] otherArgs) {
+		return startLocalFedWorkerThread(port, otherArgs, FED_WORKER_WAIT);
 	}
 
 	/**
@@ -1543,11 +1548,17 @@
 	 * @return the thread associated with the worker.
 	 */
 	protected Thread startLocalFedWorkerThread(int port, int sleep) {
+		return startLocalFedWorkerThread(port, null, sleep);
+	}
+	protected Thread startLocalFedWorkerThread(int port, String[] otherArgs, int sleep) {
 		Thread t = null;
 		String[] fedWorkArgs = {"-w", Integer.toString(port)};
 		ArrayList<String> args = new ArrayList<>();
 
 		addProgramIndependentArguments(args);
+		
+		if (otherArgs != null)
+			args.addAll(Arrays.stream(otherArgs).collect(Collectors.toList()));
 
 		for(int i = 0; i < fedWorkArgs.length; i++)
 			args.add(fedWorkArgs[i]);
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
new file mode 100644
index 0000000..00c6d6f
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.lineage;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FedFullReuseTest extends AutomatedTestBase {
+
+	private final static String TEST_DIR = "functions/lineage/";
+	private final static String TEST_NAME = "FedFullReuse1";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FedFullReuseTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// rows have to be even and > 1
+		return Arrays.asList(new Object[][] {
+			// {2, 1000}, {10, 100},
+			{100, 10}, 
+			//{1000, 1},
+			// {10, 2000}, {2000, 10}
+		});
+	}
+
+	@Test
+	public void federatedReuseMM() {    //reuse inside federated workers
+		federatedReuse();
+	}
+	
+	public void federatedReuse() {
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		// write input matrices
+		int halfRows = rows / 2;
+		// Share two matrices between two federated worker
+		double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+		double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+		double[][] Y1 = getRandomMatrix(cols, halfRows, 0, 1, 1, 44);
+		double[][] Y2 = getRandomMatrix(cols, halfRows, 0, 1, 1, 21);
+
+		writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+		writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+		writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols));
+		writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols));
+
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		String[] otherargs = new String[] {"-lineage", "reuse_full"};
+		Lineage.resetInternalState();
+		Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S);
+		Thread t2 = startLocalFedWorkerThread(port2, otherargs);
+
+		TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+		loadTestConfiguration(config);
+
+		// Run reference dml script with normal matrix. Reuse of ba+*.
+		fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+		programArgs = new String[] {"-stats", "-lineage", "reuse_full",
+			"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+			"Y2=" + input("Y2"), "Z=" + expected("Z")};
+		runTest(true, false, null, -1);
+		long mmCount = Statistics.getCPHeavyHitterCount("ba+*");
+
+		// Run actual dml script with federated matrix
+		// The fed workers reuse ba+*
+		fullDMLScriptName = HOME + TEST_NAME + ".dml";
+		programArgs = new String[] {"-stats","-lineage", "reuse_full",
+			"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+			"X2=" + TestUtils.federatedAddress(port2, input("X2")),
+			"Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+			"Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+		runTest(true, false, null, -1);
+		long mmCount_fed = Statistics.getCPHeavyHitterCount("ba+*");
+
+		// compare results 
+		compareResults(1e-9);
+		// compare matrix multiplication count
+		// #federated execution of ba+* = #threads times #non-federated execution of ba+* (after reuse) 
+		Assert.assertTrue("Violated reuse count: "+mmCount_fed+" == "+mmCount*2, 
+				mmCount_fed == mmCount * 2); // #threads = 2
+
+		TestUtils.shutdownThreads(t1, t2);
+	}
+
+}
diff --git a/src/test/scripts/functions/lineage/FedFullReuse1.dml b/src/test/scripts/functions/lineage/FedFullReuse1.dml
new file mode 100644
index 0000000..4597332
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedFullReuse1.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+    ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)));
+Y = federated(addresses=list($Y1, $Y2),
+    ranges=list(list(0, 0), list($c, $r / 2), list(0, $r / 2), list($c, $r)));
+
+for(i in 1:10)
+  Z = X %*% Y;
+
+write(Z, $Z);
diff --git a/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml b/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml
new file mode 100644
index 0000000..6049f5d
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2));
+Y = cbind(read($Y1), read($Y2));
+
+for(i in 1:10)
+  Z = X %*% Y;
+
+write(Z, $Z);