[SYSTEMDS-2549] Extended federated binary element-wise operations
This patch generalizes the existing federated binary element-wise
operations to avoid unsupported scenarios. Specifically, if the
right-hand-side matrix (instead of left-hand-side) matrix is federated
and the operation is commutative (e.g., mult/add) we canonicalize the
inputs accordingly.
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index bceb6ae..ea34df1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -25,6 +25,7 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
@@ -39,8 +40,16 @@
MatrixObject mo1 = ec.getMatrixObject(input1);
MatrixObject mo2 = ec.getMatrixObject(input2);
+ //canonicalization for federated lhs
+ if( !mo1.isFederated() && mo2.isFederated()
+ && mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics())
+ && ((BinaryOperator)_optr).isCommutative() ) {
+ mo1 = ec.getMatrixObject(input2);
+ mo2 = ec.getMatrixObject(input1);
+ }
+
+ //execute federated operation on mo1 or mo2
FederatedRequest fr2 = null;
-
if( mo2.isFederated() ) {
if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
@@ -48,12 +57,12 @@
mo1.getFedMapping().execute(getTID(), true, fr2);
}
else {
- throw new DMLRuntimeException("Matrix-matrix binary operations "
- + " with a federated right input are not supported yet.");
+ throw new DMLRuntimeException("Matrix-matrix binary operations with a "
+ + "federated right input are only supported for special cases yet.");
}
}
else {
- //matrix-matrix binary oFederatedRequest fr2 = null;perations -> lhs fed input -> fed output
+ //matrix-matrix binary operations -> lhs fed input -> fed output
if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { //MV row vector
FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index beca629..bc4cdd0 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -56,6 +56,7 @@
private static final long serialVersionUID = -2547950181558989209L;
public final ValueFunction fn;
+ public final boolean commutative;
public BinaryOperator(ValueFunction p) {
//binaryop is sparse-safe iff (0 op 0) == 0
@@ -65,6 +66,8 @@
|| p instanceof BitwAnd || p instanceof BitwOr || p instanceof BitwXor
|| p instanceof BitwShiftL || p instanceof BitwShiftR);
fn = p;
+ commutative = p instanceof Plus || p instanceof Multiply
+ || p instanceof And || p instanceof Or || p instanceof Xor;
}
/**
@@ -111,6 +114,10 @@
return null;
}
+ public boolean isCommutative() {
+ return commutative;
+ }
+
@Override
public String toString() {
return "BinaryOperator("+fn.getClass().getSimpleName()+")";
diff --git a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
index d71ce9d..a28d98d 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
@@ -188,9 +188,11 @@
dimOut.set(dim1.getRows(), dim2.getCols(), dim1.getBlocksize());
}
+ public abstract boolean equalDims(Object anObject);
+
@Override
public abstract boolean equals(Object anObject);
-
+
@Override
public abstract int hashCode();
}
diff --git a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
index 0b29cce..bdc4b21 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
@@ -229,7 +229,17 @@
return !nnzKnown() || numRows==0 || numColumns==0
|| (nonZero < numRows*numColumns - singleBlk);
}
-
+
+ @Override
+ public boolean equalDims(Object anObject) {
+ if( !(anObject instanceof MatrixCharacteristics) )
+ return false;
+ MatrixCharacteristics mc = (MatrixCharacteristics) anObject;
+ return dimsKnown() && mc.dimsKnown()
+ && numRows == mc.numRows
+ && numColumns == mc.numColumns;
+ }
+
@Override
public boolean equals (Object anObject) {
if( !(anObject instanceof MatrixCharacteristics) )
diff --git a/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
index 449cc2d..2b554a2 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
@@ -157,6 +157,15 @@
}
@Override
+ public boolean equalDims(Object anObject) {
+ if( !(anObject instanceof TensorCharacteristics) )
+ return false;
+ TensorCharacteristics tc = (TensorCharacteristics) anObject;
+ return dimsKnown() && tc.dimsKnown()
+ && Arrays.equals(_dims, tc._dims);
+ }
+
+ @Override
public boolean equals (Object anObject) {
if( !(anObject instanceof TensorCharacteristics) )
return false;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index 2b9d287..44de28f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -123,7 +123,7 @@
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
Assert.assertTrue(heavyHittersContainsString("fed_uark+","fed_uarsqk+"));
Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
- Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+ //Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
//check that federated input files are still existing
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index eb70a4b..0dd339f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -128,8 +128,10 @@
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
- Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
+ //Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
Assert.assertTrue(heavyHittersContainsString("fed_uarmin"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uark+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
Assert.assertTrue(heavyHittersContainsString("fed_*"));
Assert.assertTrue(heavyHittersContainsString("fed_+"));
Assert.assertTrue(heavyHittersContainsString("fed_<="));