[SYSTEMDS-2762] Federated reshape operations (aligned)
Closes #1129.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 76521db..9590510 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -242,10 +242,17 @@
}
private static FederatedRequest[] addAll(FederatedRequest a, FederatedRequest[] b) {
- FederatedRequest[] ret = new FederatedRequest[b.length + 1];
- ret[0] = a;
- System.arraycopy(b, 0, ret, 1, b.length);
- return ret;
+		// guard: without any trailing requests (null or empty array),
+		// the result is a singleton array holding only the request a
+		boolean noTail = (b == null) || (b.length == 0);
+		if( noTail )
+			return new FederatedRequest[] {a};
+
+		// general case: request a goes first, followed by a copy of b
+		FederatedRequest[] merged = new FederatedRequest[1 + b.length];
+		merged[0] = a;
+		System.arraycopy(b, 0, merged, 1, b.length);
+		return merged;
}
public FederationMap identCopy(long tid, long id) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 093ff30..e9c41b2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -78,6 +78,25 @@
return new FederatedRequest(RequestType.EXEC_INST, id, linst);
}
+	/** Creates one EXEC_INST request per instruction string, all using the same newly drawn output ID.
+	 * Operand names are rewritten on local copies, so the passed inst array is not modified. */
+	public static FederatedRequest[] callInstruction(String[] inst, CPOperand varOldOut, CPOperand[] varOldIn, long[] varNewIn) {
+		long id = getNextFedDataID();
+		FederatedRequest[] fr = new FederatedRequest[inst.length];
+		for(int j=0; j<inst.length; j++) {
+			// exec type and output operand do not depend on the inputs, hence replaced once up front
+			String linst = inst[j].replace(ExecType.SPARK.name(), ExecType.CP.name());
+			linst = linst.replace(Lop.OPERAND_DELIMITOR + varOldOut.getName() + Lop.DATATYPE_PREFIX, Lop.OPERAND_DELIMITOR + String.valueOf(id) + Lop.DATATYPE_PREFIX);
+			for(int i = 0; i < varOldIn.length; i++)
+				if(varOldIn[i] != null) {
+					linst = linst.replace(Lop.OPERAND_DELIMITOR + varOldIn[i].getName() + Lop.DATATYPE_PREFIX, Lop.OPERAND_DELIMITOR + String.valueOf(varNewIn[i]) + Lop.DATATYPE_PREFIX);
+					linst = linst.replace("=" + varOldIn[i].getName(), "=" + String.valueOf(varNewIn[i])); //parameterized
+				}
+			fr[j] = new FederatedRequest(RequestType.EXEC_INST, id, (Object) linst);
+		}
+		return fr;
+	}
+
public static MatrixBlock aggAdd(Future<FederatedResponse>[] ffr) {
try {
SimpleOperator op = new SimpleOperator(Plus.getPlusFnObject());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 10b6147..00f6b72 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -38,6 +38,7 @@
Tsmm,
MMChain,
Reorg,
+ Reshape,
MatrixIndexing,
QSort,
QPick
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 3076b9b..af17637 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -107,6 +107,8 @@
} else if(inst.getOpcode().equalsIgnoreCase("qsort") && mo1.isFederated()) {
if(mo1.getFedMapping().getFederatedRanges().length == 1)
fedinst = QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
+ } else if(inst.getOpcode().equalsIgnoreCase("rshape") && mo1.isFederated()) {
+ fedinst = ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
} else if(inst instanceof AggregateUnaryCPInstruction && mo1.isFederated() &&
((AggregateUnaryCPInstruction) instruction).getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT) {
fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
new file mode 100644
index 0000000..b66f72a
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Arrays;
+import java.util.stream.Collectors;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.BooleanObject;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+/** Federated reshape (rshape) of a federated matrix, restricted to aligned reshapes where every
+ * worker reshapes its local partition and the output remains federated over the same workers. */
+public class ReshapeFEDInstruction extends UnaryFEDInstruction {
+	private final CPOperand _opRows;
+	private final CPOperand _opCols;
+	private final CPOperand _opDims;
+	private final CPOperand _opByRow;
+
+	private ReshapeFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
+		CPOperand in5, CPOperand out, String opcode, String istr) {
+		super(FEDInstruction.FEDType.Reshape, op, in1, out, opcode, istr);
+		_opRows = in2;
+		_opCols = in3;
+		_opDims = in4;
+		_opByRow = in5;
+	}
+
+	/** Parses an rshape instruction string; parts 1-5 are the inputs (matrix, rows, cols, dims, byrow)
+	 * and part 6 is the output. Throws a DMLRuntimeException for any other opcode. */
+	public static ReshapeFEDInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		InstructionUtils.checkNumFields(parts, 6);
+		String opcode = parts[0];
+		if(!opcode.equalsIgnoreCase("rshape"))
+			throw new DMLRuntimeException("Unknown opcode while parsing an ReshapeInstruction: " + str);
+		return new ReshapeFEDInstruction(new Operator(true), new CPOperand(parts[1]), new CPOperand(parts[2]),
+			new CPOperand(parts[3]), new CPOperand(parts[4]), new CPOperand(parts[5]), new CPOperand(parts[6]),
+			opcode, str);
+	}
+
+	/** Validates the reshape, executes it at all workers, and sets the federated mapping of the output. */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		// TODO support tensor out, frame and list
+		if(output.getDataType() != Types.DataType.MATRIX)
+			throw new DMLRuntimeException("Federated Reshape Instruction only supports matrix as output.");
+
+		MatrixObject mo1 = ec.getMatrixObject(input1);
+		boolean byRow = ((BooleanObject) ec
+			.getScalarInput(_opByRow.getName(), Types.ValueType.BOOLEAN, _opByRow.isLiteral())).getBooleanValue();
+		// dimensions are kept as long: int casts would truncate and rows * cols could overflow
+		long rows = ec.getScalarInput(_opRows).getLongValue();
+		long cols = ec.getScalarInput(_opCols).getLongValue();
+
+		if(!mo1.isFederated())
+			throw new DMLRuntimeException("Federated Rshape: "
+				+ "Federated input expected, but invoked w/ " + mo1.isFederated());
+		if(rows <= 0 || cols <= 0)
+			throw new DMLRuntimeException("Federated Rshape: invalid output dimensions " + rows + ":" + cols + ".");
+		if(mo1.getNumColumns() * mo1.getNumRows() != rows * cols)
+			throw new DMLRuntimeException("Reshape matrix requires consistent numbers of input/output cells ("
+				+ mo1.getNumRows() + ":" + mo1.getNumColumns() + ", " + rows + ":" + cols + ").");
+
+		// aligned reshape only: the number of cells of every federated range must be a multiple of
+		// the output row length (byrow) or of the output column length (bycol)
+		FederatedRange[] ranges = mo1.getFedMapping().getFederatedRanges();
+		long unit = byRow ? cols : rows;
+		if(Arrays.stream(ranges).anyMatch(r -> r.getSize() % unit != 0))
+			throw new DMLRuntimeException(
+				"Reshape matrix requires consistent numbers of input/output cells for each worker.");
+
+		//execute at federated site, with one instruction string per federated range
+		String[] newInstString = getNewInstString(mo1, instString, rows, cols, byRow);
+		FederatedRequest[] fr1 = FederationUtils.callInstruction(newInstString,
+			output, new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[0]);
+
+		// set new fed map
+		// NOTE(review): the ranges are updated on the mapping returned by mo1.getFedMapping(), i.e.,
+		// apparently on the input's own ranges, before copyWithNewID is called - confirm this is intended
+		FederationMap reshapedFedMap = mo1.getFedMapping();
+		for(int i = 0; i < ranges.length; i++) {
+			long cells = ranges[i].getSize();
+			long row = byRow ? cells / cols : rows;
+			long col = byRow ? cols : cells / rows;
+
+			// per dimension, a range begins at 0 if it is the first range or began at 0 before,
+			// and otherwise at the (already updated) end of its predecessor
+			long beginRow = (i == 0 || ranges[i].getBeginDims()[0] == 0) ? 0 : ranges[i - 1].getEndDims()[0];
+			ranges[i].setBeginDim(0, beginRow);
+			ranges[i].setEndDim(0, beginRow + row);
+			long beginCol = (i == 0 || ranges[i].getBeginDims()[1] == 0) ? 0 : ranges[i - 1].getEndDims()[1];
+			ranges[i].setBeginDim(1, beginCol);
+			ranges[i].setEndDim(1, beginCol + col);
+		}
+
+		//derive output federated mapping
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(rows, cols, (int) mo1.getBlocksize(), mo1.getNnz());
+		out.setFedMapping(reshapedFedMap.copyWithNewID(fr1[0].getID()));
+	}
+
+	/**
+	 * Creates one instruction string per federated range by replacing the number of output rows (byrow)
+	 * or output columns (bycol) with the number of rows/columns produced from the cells of that range.
+	 */
+	private static String[] getNewInstString(MatrixObject mo1, String instString, long rows, long cols, boolean byRow) {
+		FederatedRange[] ranges = mo1.getFedMapping().getFederatedRanges();
+		String[] instStrings = new String[mo1.getFedMapping().getSize()];
+		String[] instParts = instString.split(Lop.OPERAND_DELIMITOR);
+		int pos = byRow ? 3 : 4; // position of the rows (byrow) or cols (bycol) operand
+		// only the text in front of the first data type prefix of this single operand is rewritten, such that
+		// neither the remainder of the operand nor an identical operand at another position is changed
+		int end = instParts[pos].indexOf(Lop.DATATYPE_PREFIX);
+		String value = (end < 0) ? instParts[pos] : instParts[pos].substring(0, end);
+		String meta = (end < 0) ? "" : instParts[pos].substring(end);
+		String oldValue = String.valueOf(byRow ? rows : cols);
+
+		for(int i = 0; i < instStrings.length; i++) {
+			String newValue = String.valueOf(ranges[i].getSize() / (byRow ? cols : rows));
+			instParts[pos] = value.replace(oldValue, newValue) + meta;
+			instStrings[i] = String.join(Lop.OPERAND_DELIMITOR, instParts);
+		}
+		return instStrings;
+	}
+
+	@Override
+	public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+		return Pair.of(output.getName(),
+			new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, input1, _opRows, _opCols, _opDims, _opByRow)));
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedReshapeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedReshapeTest.java
new file mode 100644
index 0000000..c32a4a9
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedReshapeTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/** Tests federated reshape (fed_rshape) of a matrix spread over four local workers against a reference script. */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedReshapeTest extends AutomatedTestBase {
+	private final static String TEST_DIR = "functions/federated/";
+	private final static String TEST_NAME = "FederatedReshapeTest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedReshapeTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+	// dimensions of the federated input matrix
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+	// dimensions of the reshaped output matrix
+	@Parameterized.Parameter(2)
+	public int rRows;
+	@Parameterized.Parameter(3)
+	public int rCols;
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		return Arrays.asList(new Object[][] {
+			{12, 12, 144, 1},
+			{12, 12, 24, 6},
+			{12, 12, 48, 3}
+		});
+	}
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+	}
+
+	@Test
+	public void federatedReshapeCP() {
+		federatedReshape(Types.ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	@Ignore
+	public void federatedReshapeSP() {
+		federatedReshape(Types.ExecMode.SPARK);
+	}
+
+	public void federatedReshape(Types.ExecMode execMode) {
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		Types.ExecMode platformOld = rtplatform;
+
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		// uneven row partitions of 2, 2, 6, and 2 rows
+		double[][] X1 = getRandomMatrix(2, cols, 1, 5, 1, 3);
+		double[][] X2 = getRandomMatrix(2, cols, 1, 5, 1, 7);
+		double[][] X3 = getRandomMatrix(6, cols, 1, 5, 1, 8);
+		double[][] X4 = getRandomMatrix(2, cols, 1, 5, 1, 9);
+		MatrixCharacteristics mc1 = new MatrixCharacteristics(6, cols, blocksize, 6*cols);
+		MatrixCharacteristics mc2 = new MatrixCharacteristics(2, cols, blocksize, 2*cols);
+		writeInputMatrixWithMTD("X1", X1, false, mc2);
+		writeInputMatrixWithMTD("X2", X2, false, mc2);
+		writeInputMatrixWithMTD("X3", X3, false, mc1);
+		writeInputMatrixWithMTD("X4", X4, false, mc2);
+
+		// empty script name because we don't execute any script, just start the worker
+		fullDMLScriptName = "";
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		int port3 = getRandomAvailablePort();
+		int port4 = getRandomAvailablePort();
+		Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+		Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+		Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+		Thread t4 = startLocalFedWorkerThread(port4);
+
+		try {
+			// reference file should not be written to hdfs, so we set platform here
+			rtplatform = execMode;
+			if(rtplatform == Types.ExecMode.SPARK)
+				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+			// Run reference dml script with normal matrix for Row/Col
+			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+			programArgs = new String[] {"-stats", "100", "-args",
+				input("X1"), input("X2"), input("X3"), input("X4"), expected("S"), String.valueOf(rRows), String.valueOf(rCols)};
+			runTest(null);
+
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-stats", "100", "-nvargs",
+				"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+				"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+				"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+				"rows=" + rows,
+				"cols=" + cols,
+				"r_rows=" + rRows,
+				"r_cols=" + rCols,
+				"out_S=" + output("S")};
+			runTest(null);
+
+			// compare the federated result with the reference result via files
+			compareResults(0.01);
+			Assert.assertTrue(heavyHittersContainsString("fed_rshape"));
+			// check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+		}
+		finally {
+			// always stop the workers and restore the global settings, also if a run or assertion failed
+			TestUtils.shutdownThreads(t1, t2, t3, t4);
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/federated/FederatedReshapeTest.dml b/src/test/scripts/functions/federated/FederatedReshapeTest.dml
new file mode 100644
index 0000000..6aa8a16
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedReshapeTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Even four-way split, kept for reference (unused: the workers below hold 2, 2, 6, and 2 rows):
+# A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+#   ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+#     list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+
+A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list(2, $cols), list(2, 0), list(4, $cols),
+    list(4, 0), list(10, $cols), list(10, 0), list(12, $cols)));
+
+s = matrix(A, rows=$r_rows, cols=$r_cols);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedReshapeTestReference.dml b/src/test/scripts/functions/federated/FederatedReshapeTestReference.dml
new file mode 100644
index 0000000..2b364dd
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedReshapeTestReference.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Reference: rbind of the four inputs ($1-$4), reshaped to $6 x $7, and written to $5.
+A = rbind(read($1), read($2), read($3), read($4));
+s = matrix(A, rows=$6, cols=$7);
+write(s, $5);