[SYSTEMDS-2747] Federated weighted cross entropy operations (WCEMM)
Quaternary operations, part 1
Closes #1133.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index e9c41b2..881991a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -119,7 +119,7 @@
long size = 0;
for(int i=0; i<ffr.length; i++) {
Object input = ffr[i].get().getData()[0];
- MatrixBlock tmp = (input instanceof ScalarObject) ?
+ MatrixBlock tmp = (input instanceof ScalarObject) ?
new MatrixBlock(((ScalarObject)input).getDoubleValue()) : (MatrixBlock) input;
size += ranges[i].getSize(0);
sop1 = sop1.setConstant(ranges[i].getSize(0));
@@ -317,6 +317,10 @@
}
}
+ /**
+  * Convenience overload that aggregates federated partial results into a scalar
+  * when no federation map is available. It forwards to
+  * {@link #aggScalar(AggregateUnaryOperator, Future[], FederationMap)} passing
+  * {@code null} as the map.
+  *
+  * @param aop aggregate unary operator describing the aggregation to apply
+  * @param ffr futures of the federated responses carrying the partial results
+  * @return the scalar produced by the three-argument overload
+  */
+ public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr) {
+ return aggScalar(aop, ffr, null);
+ }
+
public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) {
if(!(aop.aggOp.increOp.fn instanceof KahanFunction || (aop.aggOp.increOp.fn instanceof Builtin &&
(((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN
@@ -366,7 +370,7 @@
throw new DMLRuntimeException("Unsupported aggregation operator: "
+ aop.aggOp.increOp.fn.getClass().getSimpleName());
}
-
+
public static FederationMap federateLocalData(CacheableData<?> data) {
long id = FederationUtils.getNextFedDataID();
FederatedLocalData federatedLocalData = new FederatedLocalData(id, data);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 00f6b72..0cf1fac 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -25,7 +25,7 @@
import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
public abstract class FEDInstruction extends Instruction {
-
+
public enum FEDType {
AggregateBinary,
AggregateUnary,
@@ -40,41 +40,42 @@
Reorg,
Reshape,
MatrixIndexing,
+ Quaternary,
QSort,
QPick
}
-
+
protected final FEDType _fedType;
protected long _tid = -1; //main
-
+
protected FEDInstruction(FEDType type, String opcode, String istr) {
this(type, null, opcode, istr);
}
-
+
protected FEDInstruction(FEDType type, Operator op, String opcode, String istr) {
super(op);
_fedType = type;
instString = istr;
instOpcode = opcode;
}
-
+
@Override
public IType getType() {
return IType.FEDERATED;
}
-
+
public FEDType getFEDInstructionType() {
return _fedType;
}
-
+
public long getTID() {
return _tid;
}
-
+
public void setTID(long tid) {
_tid = tid;
}
-
+
@Override
public Instruction preprocessInstruction(ExecutionContext ec) {
Instruction tmp = super.preprocessInstruction(ec);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index e6a64cb..34f40bb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -36,6 +36,7 @@
import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -48,6 +49,7 @@
import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
@@ -101,7 +103,7 @@
}
else if(instruction.input1 != null && instruction.input1.isMatrix()
&& ec.containsVariable(instruction.input1)) {
-
+
MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
if(instruction.getOpcode().equalsIgnoreCase("cm") && mo1.isFederated()) {
fedinst = CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
@@ -160,18 +162,18 @@
}
else if(inst instanceof VariableCPInstruction ){
VariableCPInstruction ins = (VariableCPInstruction) inst;
- if(ins.getVariableOpcode() == VariableOperationCode.Write
+ if(ins.getVariableOpcode() == VariableOperationCode.Write
&& ins.getInput1().isMatrix()
&& ins.getInput3().getName().contains("federated")){
fedinst = VariableFEDInstruction.parseInstruction(ins);
}
- else if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable
- && ins.getInput1().isMatrix()
+ else if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable
+ && ins.getInput1().isMatrix()
&& ec.getCacheableData(ins.getInput1()).isFederated()){
fedinst = VariableFEDInstruction.parseInstruction(ins);
}
- else if(ins.getVariableOpcode() == VariableOperationCode.CastAsMatrixVariable
- && ins.getInput1().isFrame()
+ else if(ins.getVariableOpcode() == VariableOperationCode.CastAsMatrixVariable
+ && ins.getInput1().isFrame()
&& ec.getCacheableData(ins.getInput1()).isFederated()){
fedinst = VariableFEDInstruction.parseInstruction(ins);
}
@@ -183,16 +185,22 @@
fedinst = AggregateTernaryFEDInstruction.parseInstruction(ins);
}
}
+ else if(inst instanceof QuaternaryCPInstruction) {
+ QuaternaryCPInstruction instruction = (QuaternaryCPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ fedinst = QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
+ }
//set thread id for federated context management
if( fedinst != null ) {
fedinst.setTID(ec.getTID());
return fedinst;
}
-
+
return inst;
}
-
+
public static Instruction checkAndReplaceSP(Instruction inst, ExecutionContext ec) {
FEDInstruction fedinst = null;
if (inst instanceof MapmmSPInstruction) {
@@ -256,12 +264,18 @@
return VariableCPInstruction.parseInstruction(instruction.getInstructionString());
}
}
+ else if(inst instanceof QuaternarySPInstruction) {
+ QuaternarySPInstruction instruction = (QuaternarySPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ fedinst = QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
+ }
//set thread id for federated context management
if( fedinst != null ) {
fedinst.setTID(ec.getTID());
return fedinst;
}
-
+
return inst;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
new file mode 100644
index 0000000..2b62ec5
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.fed.QuaternaryWCeMMFEDInstruction;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+
+/**
+ * Base class of federated quaternary instructions. It stores the fourth input
+ * operand (the three other inputs and the output are handled by the parent class)
+ * and provides the parser that turns a CP or Spark instruction string into the
+ * matching federated instruction. Currently only the opcode {@code wcemm} is
+ * supported.
+ */
+public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
+{
+	// fourth input operand; for wcemm this is epsilon
+	protected CPOperand _input4 = null;
+
+	/**
+	 * @param type            federated instruction type
+	 * @param operator        operator carried by this instruction
+	 * @param in1             first input operand
+	 * @param in2             second input operand
+	 * @param in3             third input operand
+	 * @param in4             fourth input operand, stored in {@link #_input4}
+	 * @param out             output operand
+	 * @param opcode          instruction opcode
+	 * @param instruction_str full instruction string
+	 */
+	protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, Operator operator,
+		CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String instruction_str)
+	{
+		super(type, operator, in1, in2, in3, out, opcode, instruction_str);
+		_input4 = in4;
+	}
+
+	/**
+	 * Parses an instruction string into a federated quaternary instruction.
+	 * A string starting with the Spark execution type is first rewritten into the
+	 * CP form: the execution type is swapped, {@code mapwcemm} becomes {@code wcemm}
+	 * and a thread count of 1 is appended.
+	 *
+	 * @param str instruction string (CP or Spark form)
+	 * @return the federated instruction for the parsed opcode
+	 * @throws DMLRuntimeException if the number of fields is unexpected, the operand
+	 *         data types are not supported, or the opcode is not supported
+	 */
+	public static QuaternaryFEDInstruction parseInstruction(String str)
+	{
+		if(str.startsWith(ExecType.SPARK.name())) {
+			// rewrite the spark instruction to a cp instruction; only the leading
+			// execution type token is swapped so that operand names which happen
+			// to contain "SPARK" are left untouched
+			str = ExecType.CP.name() + str.substring(ExecType.SPARK.name().length());
+			str = str.replace("mapwcemm", "wcemm");
+			str += Lop.OPERAND_DELIMITOR + "1"; //num threads
+		}
+
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		// validate the number of fields before any positional access, so that a
+		// malformed instruction fails with a descriptive DMLRuntimeException
+		// instead of an ArrayIndexOutOfBoundsException
+		InstructionUtils.checkNumFields(parts, 7);
+
+		String opcode = parts[0];
+
+		CPOperand in1 = new CPOperand(parts[1]);
+		CPOperand in2 = new CPOperand(parts[2]);
+		CPOperand in3 = new CPOperand(parts[3]);
+		CPOperand out = new CPOperand(parts[5]);
+
+		if(opcode.equals("wcemm")) {
+			CPOperand in4 = new CPOperand(parts[4]);
+			checkDataTypes(in1, in2, in3, in4);
+
+			WCeMMType wcemm_type = WCeMMType.valueOf(parts[6]);
+			// NOTE(review): the operand name is parsed as a double, which presumably
+			// relies on epsilon being a literal here - confirm for non-literal epsilon
+			QuaternaryOperator quaternary_operator = (wcemm_type.hasFourInputs() ?
+				new QuaternaryOperator(wcemm_type, Double.parseDouble(in4.getName())) :
+				new QuaternaryOperator(wcemm_type));
+			return new QuaternaryWCeMMFEDInstruction(quaternary_operator, in1, in2, in3, in4, out, opcode, str);
+		}
+
+		throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for QuaternaryFEDInstruction.");
+	}
+
+	/**
+	 * Verifies the operand data types: the first three inputs have to be matrices,
+	 * the fourth input may be a scalar or a matrix.
+	 *
+	 * @param in1 first input operand
+	 * @param in2 second input operand
+	 * @param in3 third input operand
+	 * @param in4 fourth input operand
+	 * @throws DMLRuntimeException if any operand has an unsupported data type
+	 */
+	protected static void checkDataTypes(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4) {
+		if(in1.getDataType() != DataType.MATRIX || in2.getDataType() != DataType.MATRIX
+			|| in3.getDataType() != DataType.MATRIX
+			|| !(in4.getDataType() == DataType.SCALAR || in4.getDataType() == DataType.MATRIX)) {
+			throw new DMLRuntimeException("Federated quaternary operations "
+				+ "only supported with matrix inputs and scalar epsilon.");
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
new file mode 100644
index 0000000..8566b39
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+
+import java.util.concurrent.Future;
+
+/**
+ * Federated weighted cross entropy instruction (wcemm). X has to be federated
+ * while U and V have to be local; U is broadcast sliced, V (and epsilon, if
+ * present) is broadcast, the instruction is executed at the federated sites and
+ * the partial scalar results are summed up into the output variable.
+ */
+public class QuaternaryWCeMMFEDInstruction extends QuaternaryFEDInstruction
+{
+	// input1 ... federated X
+	// input2 ... U
+	// input3 ... V
+	// _input4 ... W (=epsilon)
+	protected QuaternaryWCeMMFEDInstruction(Operator operator,
+		CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
+		CPOperand out, String opcode, String instruction_str)
+	{
+		super(FEDType.Quaternary, operator, in1, in2, in3, in4, out, opcode, instruction_str);
+	}
+
+	/**
+	 * Executes the instruction on the federated sites of X and stores the
+	 * aggregated scalar result in the output variable.
+	 *
+	 * @param ec execution context providing the inputs and receiving the output
+	 * @throws DMLRuntimeException if X is not federated or U or V is federated
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec)
+	{
+		QuaternaryOperator qop = (QuaternaryOperator) _optr;
+		MatrixObject X = ec.getMatrixObject(input1);
+		MatrixObject U = ec.getMatrixObject(input2);
+		MatrixObject V = ec.getMatrixObject(input3);
+
+		// validate the federation layout first so that no input is acquired
+		// when the instruction cannot be executed anyway
+		if(!(X.isFederated() && !U.isFederated() && !V.isFederated()))
+			throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = ("
+				+X.isFederated()+", "+U.isFederated()+", "+V.isFederated()+")");
+
+		ScalarObject eps = null;
+		if(qop.hasFourInputs()) {
+			if(_input4.getDataType() == DataType.SCALAR) {
+				eps = ec.getScalarInput(_input4);
+			}
+			else {
+				// epsilon handed over as matrix: read cell (0,0) and release the
+				// matrix input again, it is not needed beyond this point
+				eps = new DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0));
+				ec.releaseMatrixInput(_input4.getName());
+			}
+		}
+
+		FederationMap fedMap = X.getFedMapping();
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(U, false);
+		FederatedRequest fr2 = fedMap.broadcast(V);
+		FederatedRequest fr3 = null;
+		FederatedRequest frComp = null;
+
+		// broadcast the scalar epsilon if there are four inputs
+		if(eps != null) {
+			fr3 = fedMap.broadcast(eps);
+			// change the is_literal flag from true to false because when broadcasted it is no literal anymore
+			instString = instString.replace("true", "false");
+			frComp = FederationUtils.callInstruction(instString, output,
+				new CPOperand[]{input1, input2, input3, _input4},
+				new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID(), fr3.getID()});
+		}
+		else {
+			frComp = FederationUtils.callInstruction(instString, output,
+				new CPOperand[]{input1, input2, input3},
+				new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID()});
+		}
+
+		// fetch the partial result and clean up the intermediates at the sites
+		FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID());
+		FederatedRequest frClean1 = fedMap.cleanup(getTID(), frComp.getID());
+		FederatedRequest frClean2 = fedMap.cleanup(getTID(), fr1[0].getID());
+		FederatedRequest frClean3 = fedMap.cleanup(getTID(), fr2.getID());
+
+		Future<FederatedResponse>[] response;
+		if(fr3 != null) {
+			FederatedRequest frClean4 = fedMap.cleanup(getTID(), fr3.getID());
+			// execute federated instructions
+			response = fedMap.execute(getTID(), true, fr1, fr2, fr3,
+				frComp, frGet, frClean1, frClean2, frClean3, frClean4);
+		}
+		else {
+			// execute federated instructions
+			response = fedMap.execute(getTID(), true, fr1, fr2,
+				frComp, frGet, frClean1, frClean2, frClean3);
+		}
+
+		//aggregate partial results from federated responses
+		AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+		ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response));
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index b2f3a53..a033769 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -19,9 +19,7 @@
package org.apache.sysds.runtime.instructions.fed;
-import java.util.AbstractMap;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
import org.apache.sysds.common.Types;
@@ -36,16 +34,13 @@
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
-import org.apache.sysds.runtime.functionobjects.SortIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
-import org.apache.sysds.runtime.util.IndexRange;
public class ReorgFEDInstruction extends UnaryFEDInstruction {
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 4d3e9d9..98c6b79 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -837,6 +837,10 @@
return TestUtils.readDMLMatrixFromHDFS(baseDirectory + OUTPUT_DIR + fileName);
}
+ /**
+  * Reads a matrix from the expected-results directory. The file is resolved as
+  * {@code baseDirectory + EXPECTED_DIR + fileName} and read via
+  * {@code TestUtils.readDMLMatrixFromHDFS}; no cache directory component is added
+  * to the path.
+  *
+  * @param fileName name of the matrix file inside the expected directory
+  * @return mapping from cell index to cell value
+  */
+ protected static HashMap<CellIndex, Double> readDMLMatrixFromExpectedDir(String fileName) {
+ return TestUtils.readDMLMatrixFromHDFS(baseDirectory + EXPECTED_DIR + fileName);
+ }
+
public HashMap<CellIndex, Double> readRMatrixFromExpectedDir(String fileName) {
if(LOG.isInfoEnabled())
LOG.info("R script out: " + baseDirectory + EXPECTED_DIR + cacheDir + fileName);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
new file mode 100644
index 0000000..bf676a3
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+/**
+ * Tests the federated weighted cross entropy operator with and without epsilon,
+ * in single node and spark execution mode. X is split row-wise into two halves
+ * served by two local federated workers; the result of the federated script is
+ * compared against a reference script that runs on the row-bound local matrix.
+ */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedWeightedCrossEntropyTest extends AutomatedTestBase
+{
+	private final static String STD_TEST_NAME = "FederatedWCeMMTest";
+	private final static String EPS_TEST_NAME = "FederatedWCeMMEpsTest";
+	private final static String TEST_DIR = "functions/federated/quaternary/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWeightedCrossEntropyTest.class.getSimpleName() + "/";
+
+	private final static String OUTPUT_NAME = "Z";
+	private final static double TOLERANCE = 1e-9;
+	private final static int blocksize = 1024;
+
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+	@Parameterized.Parameter(2)
+	public int rank;
+	@Parameterized.Parameter(3)
+	public double epsilon;
+	@Parameterized.Parameter(4)
+	public double sparsity;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(STD_TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, STD_TEST_NAME, new String[]{OUTPUT_NAME}));
+		addTestConfiguration(EPS_TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, EPS_TEST_NAME, new String[]{OUTPUT_NAME}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// rows must be even
+		return Arrays.asList(new Object[][] {
+			// {rows, cols, rank, epsilon, sparsity}
+			{2000, 50, 10, 0.01, 0.01},
+			{2000, 50, 10, 0.01, 0.9},
+			{2000, 50, 10, 6.45, 0.01},
+			{2000, 50, 10, 6.45, 0.9}
+		});
+	}
+
+	@BeforeClass
+	public static void init() {
+		TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+	}
+
+	@Test
+	public void federatedWeightedCrossEntropySingleNode() {
+		federatedWeightedCrossEntropy(STD_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedCrossEntropySpark() {
+		federatedWeightedCrossEntropy(STD_TEST_NAME, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedWeightedCrossEntropySingleNodeEpsilon() {
+		federatedWeightedCrossEntropy(EPS_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedCrossEntropySparkEpsilon() {
+		federatedWeightedCrossEntropy(EPS_TEST_NAME, ExecMode.SPARK);
+	}
+
+// -----------------------------------------------------------------------------
+
+	/**
+	 * Runs the reference script and the federated script of the given test and
+	 * compares their outputs. The federated workers are shut down and the
+	 * previous execution mode is restored even if the run or an assertion fails.
+	 *
+	 * @param testname name of the test configuration and script prefix
+	 * @param execMode execution mode to run both scripts in
+	 */
+	public void federatedWeightedCrossEntropy(String testname, ExecMode execMode)
+	{
+		// store the previous platform config to restore it after the test
+		ExecMode platform_old = setExecMode(execMode);
+
+		try {
+			getAndLoadTestConfiguration(testname);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			int fed_rows = rows / 2;
+			int fed_cols = cols;
+
+			// generate dataset
+			// matrix handled by two federated workers
+			double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 3);
+			double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 7);
+
+			double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
+			double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
+
+			writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+			writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+
+			writeInputMatrixWithMTD("U", U, true);
+			writeInputMatrixWithMTD("V", V, true);
+
+			// empty script name because we don't execute any script, just start the worker
+			fullDMLScriptName = "";
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+			Thread thread2 = startLocalFedWorkerThread(port2);
+
+			try {
+				getAndLoadTestConfiguration(testname);
+
+				// Run reference dml script with normal matrix
+				fullDMLScriptName = HOME + testname + "Reference.dml";
+				programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+					"in_U=" + input("U"), "in_V=" + input("V"), "in_W=" + Double.toString(epsilon),
+					"out_Z=" + expected(OUTPUT_NAME)};
+				runTest(true, false, null, -1);
+
+				// Run actual dml script with federated matrix
+				fullDMLScriptName = HOME + testname + ".dml";
+				programArgs = new String[] {"-stats", "-nvargs",
+					"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+					"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+					"in_U=" + input("U"),
+					"in_V=" + input("V"),
+					"in_W=" + Double.toString(epsilon),
+					"rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + output(OUTPUT_NAME)};
+				runTest(true, false, null, -1);
+
+				// compare the results via files
+				HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+				HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+				TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+
+				// check for federated operations
+				Assert.assertTrue(heavyHittersContainsString("fed_wcemm"));
+
+				// check that federated input files are still existing
+				Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+				Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+			}
+			finally {
+				// stop the workers even if a script run or an assertion failed
+				TestUtils.shutdownThreads(thread1, thread2);
+			}
+		}
+		finally {
+			// never leave the changed execution mode behind for subsequent tests
+			resetExecMode(platform_old);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTest.dml
new file mode 100644
index 0000000..84c0b92
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X is assembled from two federated workers. The ranges list holds, per worker,
+# a begin and an end (row, col) coordinate: the first worker starts at (0, 0),
+# the second one at ($rows, 0), i.e. the data is partitioned row-wise and $rows
+# is the number of rows of a single partition.
+X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)))
+
+# U and V are read from files, epsilon is passed as script argument
+U = read($in_U)
+V = read($in_V)
+epsilon = $in_W
+
+# weighted cross entropy with epsilon; the scalar sum is converted to a matrix
+# so that it can be written as matrix output. The accompanying Java test asserts
+# that this expression shows up as fed_wcemm in the heavy hitters.
+Z = as.matrix(sum(X * log(U %*% t(V) + epsilon)))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTestReference.dml
new file mode 100644
index 0000000..c01f99a
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# reference computation: X is built by row-binding the two input parts
+# instead of creating a federated matrix
+X = rbind(read($in_X1), read($in_X2))
+U = read($in_U)
+V = read($in_V)
+# epsilon is passed as script argument
+epsilon = $in_W
+
+# same expression as in the federated script, evaluated on the row-bound X
+Z = as.matrix(sum(X * log(U %*% t(V) + epsilon)))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTest.dml
new file mode 100644
index 0000000..75ae2ef
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X is assembled from two federated workers. The ranges list holds, per worker,
+# a begin and an end (row, col) coordinate: the first worker starts at (0, 0),
+# the second one at ($rows, 0), i.e. the data is partitioned row-wise and $rows
+# is the number of rows of a single partition.
+X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)))
+
+# U and V are read from files
+U = read($in_U)
+V = read($in_V)
+
+# weighted cross entropy without epsilon; the scalar sum is converted to a matrix
+# so that it can be written as matrix output. The accompanying Java test asserts
+# that this expression shows up as fed_wcemm in the heavy hitters.
+Z = as.matrix(sum(X * log(U %*% t(V))))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTestReference.dml
new file mode 100644
index 0000000..499ed3d
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTestReference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# reference computation: X is built by row-binding the two input parts
+# instead of creating a federated matrix
+X = rbind(read($in_X1), read($in_X2))
+U = read($in_U)
+V = read($in_V)
+
+# same expression as in the federated script, evaluated on the row-bound X
+Z = as.matrix(sum(X * log(U %*% t(V))))
+
+write(Z, $out_Z)