[MINOR] Fix multi-threaded federated MV multiply, and test issues
So far, the federated matrix-vector multiplications were always executed
in a single-threaded manner; now we execute them according to the local
parallelism configuration at the federated worker.
Also, it seems I introduced a bug in the privacy handling during the merge,
which this patch fixes as well (e.g., on scalar casts of non-cacheable data
objects).
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index bba731c..6fe814a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -38,15 +38,13 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
-import org.apache.sysds.runtime.functionobjects.Multiply;
-import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -187,8 +185,8 @@
matTo = PrivacyMonitor.handlePrivacy(matTo);
MatrixBlock matBlock1 = matTo.acquireReadAndRelease();
// TODO other datatypes
- AggregateBinaryOperator ab_op = new AggregateBinaryOperator(
- Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject()));
+ AggregateBinaryOperator ab_op = InstructionUtils
+ .getMatMultOperator(OptimizerUtils.getConstrainedNumThreads(0));
MatrixBlock result = isMatVecMult ?
matBlock1.aggregateBinaryOperations(matBlock1, vector, new MatrixBlock(), ab_op) :
vector.aggregateBinaryOperations(vector, matBlock1, new MatrixBlock(), ab_op);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
index 1e3186d..0df8108 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
@@ -19,17 +19,12 @@
package org.apache.sysds.runtime.instructions.cp;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.functionobjects.Multiply;
-import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
public class AggregateBinaryCPInstruction extends BinaryCPInstruction {
@@ -39,10 +34,6 @@
}
public static AggregateBinaryCPInstruction parseInstruction( String str ) {
- CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
- CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
- CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
-
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
@@ -50,15 +41,13 @@
throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
}
- InstructionUtils.checkNumFields( parts, 4 );
- in1.split(parts[1]);
- in2.split(parts[2]);
- out.split(parts[3]);
+ InstructionUtils.checkNumFields(parts, 4);
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand out = new CPOperand(parts[3]);
int k = Integer.parseInt(parts[4]);
-
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg, k);
- return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str);
+ AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(k);
+ return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index f40abb0..15e9ccb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -58,9 +58,6 @@
import org.apache.sysds.runtime.meta.MetaData;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.meta.TensorCharacteristics;
-import org.apache.sysds.runtime.privacy.DMLPrivacyException;
-import org.apache.sysds.runtime.privacy.PrivacyConstraint;
-import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
import org.apache.sysds.runtime.privacy.PrivacyMonitor;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.HDFSTool;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
index eb6ce30..78b05a0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
@@ -30,7 +30,6 @@
import org.apache.sysds.runtime.matrix.data.LibMatrixCuMatMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.utils.GPUStatistics;
@@ -64,8 +63,7 @@
CPOperand out = new CPOperand(parts[3]);
boolean isLeftTransposed = Boolean.parseBoolean(parts[4]);
boolean isRightTransposed = Boolean.parseBoolean(parts[5]);
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg, 1);
+ AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
return new AggregateBinaryGPUInstruction(aggbin, in1, in2, out, opcode, str, isLeftTransposed, isRightTransposed);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 0592a49..ab98af3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -27,8 +27,6 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysds.runtime.functionobjects.Multiply;
-import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -42,7 +40,6 @@
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
@@ -75,8 +72,7 @@
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+ AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
boolean outputEmptyBlocks = Boolean.parseBoolean(parts[4]);
SparkAggType aggtype = SparkAggType.valueOf(parts[5]);
return new CpmmSPInstruction(aggbin, in1, in2, out, outputEmptyBlocks, aggtype, opcode, str);
@@ -195,8 +191,7 @@
throws Exception
{
if( _op == null ) { //lazy operator construction
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+ _op = InstructionUtils.getMatMultOperator(1);
}
MatrixBlock blkIn1 = (MatrixBlock)arg0._2()._1().getValue();
@@ -224,8 +219,7 @@
public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> arg0) throws Exception {
//lazy operator construction
if( _op == null ) {
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+ _op = InstructionUtils.getMatMultOperator(1);
_rop = new ReorgOperator(SwapIndex.getSwapIndexFnObject());
}
//prepare inputs, including transpose of right-hand-side
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index 0d8ca26..4dce9c1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -80,9 +80,8 @@
boolean outputEmpty = Boolean.parseBoolean(parts[5]);
SparkAggType aggtype = SparkAggType.valueOf(parts[6]);
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
- return new MapmmSPInstruction(aggbin, in1, in2, out, type, outputEmpty, aggtype, opcode, str);
+ AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
+ return new MapmmSPInstruction(aggbin, in1, in2, out, type, outputEmpty, aggtype, opcode, str);
}
@Override
@@ -245,14 +244,10 @@
private final AggregateBinaryOperator _op;
private final PartitionedBroadcast<MatrixBlock> _pbc;
- public RDDMapMMFunction( CacheType type, PartitionedBroadcast<MatrixBlock> binput )
- {
+ public RDDMapMMFunction( CacheType type, PartitionedBroadcast<MatrixBlock> binput ) {
_type = type;
_pbc = binput;
-
- //created operator for reuse
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+ _op = InstructionUtils.getMatMultOperator(1);
}
@Override
@@ -412,14 +407,10 @@
private final AggregateBinaryOperator _op;
private final PartitionedBroadcast<MatrixBlock> _pbc;
- public RDDFlatMapMMFunction( CacheType type, PartitionedBroadcast<MatrixBlock> binput )
- {
+ public RDDFlatMapMMFunction( CacheType type, PartitionedBroadcast<MatrixBlock> binput ) {
_type = type;
_pbc = binput;
-
- //created operator for reuse
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+ _op = InstructionUtils.getMatMultOperator(1);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/PMapmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/PMapmmSPInstruction.java
index e3fe68e..eed76b7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/PMapmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/PMapmmSPInstruction.java
@@ -29,8 +29,6 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysds.runtime.functionobjects.Multiply;
-import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBlock;
@@ -40,7 +38,6 @@
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;
@@ -68,14 +65,12 @@
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
-
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+ AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
return new PMapmmSPInstruction(aggbin, in1, in2, out, opcode, str);
}
else {
throw new DMLRuntimeException("PMapmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
- }
+ }
}
@Override
@@ -162,14 +157,10 @@
private Broadcast<PartitionedBlock<MatrixBlock>> _pbc = null;
private long _offset = -1;
- public PMapMMFunction( Broadcast<PartitionedBlock<MatrixBlock>> binput, long offset )
- {
+ public PMapMMFunction( Broadcast<PartitionedBlock<MatrixBlock>> binput, long offset ) {
_pbc = binput;
_offset = offset;
-
- //created operator for reuse
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+ _op = InstructionUtils.getMatMultOperator(1);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/PmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/PmmSPInstruction.java
index b87c904..cbaf347 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/PmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/PmmSPInstruction.java
@@ -28,8 +28,6 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysds.runtime.functionobjects.Multiply;
-import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
@@ -37,7 +35,6 @@
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -66,8 +63,7 @@
CPOperand nrow = new CPOperand(parts[3]);
CPOperand out = new CPOperand(parts[4]);
CacheType type = CacheType.valueOf(parts[5]);
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+ AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
return new PmmSPInstruction(aggbin, in1, in2, out, nrow, type, opcode, str);
}
else {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
index 7bf9e8b..e76c5a4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ZipmmSPInstruction.java
@@ -59,8 +59,7 @@
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
boolean tRewrite = Boolean.parseBoolean(parts[4]);
- AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
- AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+ AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
return new ZipmmSPInstruction(aggbin, in1, in2, out, tRewrite, opcode, str);
}
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java
index 118a153..ee88bf4 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java
@@ -85,7 +85,7 @@
* @param input variable for which the privacy constraint is checked
*/
public static void handlePrivacyScalarOutput(CPOperand input, ExecutionContext ec) {
- Data data = ec.getCacheableData(input);
+ Data data = ec.getVariable(input);
if ( data != null && (data instanceof CacheableData<?>)){
PrivacyConstraint privacyConstraintIn = ((CacheableData<?>) data).getPrivacyConstraint();
if ( privacyConstraintIn != null && (privacyConstraintIn.getPrivacyLevel() == PrivacyLevel.Private) ){
diff --git a/src/test/java/org/apache/sysds/test/component/compress/ParCompressedMatrixTest.java b/src/test/java/org/apache/sysds/test/component/compress/ParCompressedMatrixTest.java
index 80b82bb..e86c269 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/ParCompressedMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/ParCompressedMatrixTest.java
@@ -24,11 +24,9 @@
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
-import org.apache.sysds.runtime.functionobjects.Multiply;
-import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.test.TestUtils;
@@ -160,8 +158,7 @@
.convertToMatrixBlock(TestUtils.generateTestMatrix(cols, 1, 1, 1, 1.0, 3));
// matrix-vector uncompressed
- AggregateOperator aop = new AggregateOperator(0, Plus.getPlusFnObject());
- AggregateBinaryOperator abop = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), aop, k);
+ AggregateBinaryOperator abop = InstructionUtils.getMatMultOperator(k);
MatrixBlock ret1 = mb.aggregateBinaryOperations(mb, vector, new MatrixBlock(), abop);
// matrix-vector compressed
@@ -188,8 +185,7 @@
.convertToMatrixBlock(TestUtils.generateTestMatrix(1, rows, 1, 1, 1.0, 3));
// Make Operator
- AggregateOperator aop = new AggregateOperator(0, Plus.getPlusFnObject());
- AggregateBinaryOperator abop = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), aop, k);
+ AggregateBinaryOperator abop = InstructionUtils.getMatMultOperator(k);
// vector-matrix uncompressed
MatrixBlock ret1 = mb.aggregateBinaryOperations(vector, mb, new MatrixBlock(), abop);