[SYSTEMDS-2747] Federated weighted cross entropy operations (WCEMM)

Quaternary operations, part 1
Closes #1133.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index e9c41b2..881991a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -119,7 +119,7 @@
 			long size = 0;
 			for(int i=0; i<ffr.length; i++) {
 				Object input = ffr[i].get().getData()[0];
-				MatrixBlock tmp = (input instanceof ScalarObject) ? 
+				MatrixBlock tmp = (input instanceof ScalarObject) ?
 					new MatrixBlock(((ScalarObject)input).getDoubleValue()) : (MatrixBlock) input;
 				size += ranges[i].getSize(0);
 				sop1 = sop1.setConstant(ranges[i].getSize(0));
@@ -317,6 +317,10 @@
 		}
 	}
 
+	public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr) {
+		return aggScalar(aop, ffr, null); // overload without federation map: delegates with map = null
+	}
+
 	public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) {
 		if(!(aop.aggOp.increOp.fn instanceof KahanFunction || (aop.aggOp.increOp.fn instanceof Builtin &&
 			(((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN
@@ -366,7 +370,7 @@
 			throw new DMLRuntimeException("Unsupported aggregation operator: "
 				+ aop.aggOp.increOp.fn.getClass().getSimpleName());
 	}
-	
+
 	public static FederationMap federateLocalData(CacheableData<?> data) {
 		long id = FederationUtils.getNextFedDataID();
 		FederatedLocalData federatedLocalData = new FederatedLocalData(id, data);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 00f6b72..0cf1fac 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -25,7 +25,7 @@
 import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
 
 public abstract class FEDInstruction extends Instruction {
-	
+
 	public enum FEDType {
 		AggregateBinary,
 		AggregateUnary,
@@ -40,41 +40,42 @@
 		Reorg,
 		Reshape,
 		MatrixIndexing,
+		Quaternary,
 		QSort,
 		QPick
 	}
-	
+
 	protected final FEDType _fedType;
 	protected long _tid = -1; //main
-	
+
 	protected FEDInstruction(FEDType type, String opcode, String istr) {
 		this(type, null, opcode, istr);
 	}
-	
+
 	protected FEDInstruction(FEDType type, Operator op, String opcode, String istr) {
 		super(op);
 		_fedType = type;
 		instString = istr;
 		instOpcode = opcode;
 	}
-	
+
 	@Override
 	public IType getType() {
 		return IType.FEDERATED;
 	}
-	
+
 	public FEDType getFEDInstructionType() {
 		return _fedType;
 	}
-	
+
 	public long getTID() {
 		return _tid;
 	}
-	
+
 	public void setTID(long tid) {
 		_tid = tid;
 	}
-	
+
 	@Override
 	public Instruction preprocessInstruction(ExecutionContext ec) {
 		Instruction tmp = super.preprocessInstruction(ec);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index e6a64cb..34f40bb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -36,6 +36,7 @@
 import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -48,6 +49,7 @@
 import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
 
@@ -101,7 +103,7 @@
 			}
 			else if(instruction.input1 != null && instruction.input1.isMatrix()
 				&& ec.containsVariable(instruction.input1)) {
-				
+
 				MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
 				if(instruction.getOpcode().equalsIgnoreCase("cm") && mo1.isFederated()) {
 					fedinst = CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
@@ -160,18 +162,18 @@
 		}
 		else if(inst instanceof VariableCPInstruction ){
 			VariableCPInstruction ins = (VariableCPInstruction) inst;
-			if(ins.getVariableOpcode() == VariableOperationCode.Write 
+			if(ins.getVariableOpcode() == VariableOperationCode.Write
 				&& ins.getInput1().isMatrix()
 				&& ins.getInput3().getName().contains("federated")){
 				fedinst = VariableFEDInstruction.parseInstruction(ins);
 			}
-			else if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable 
-				&& ins.getInput1().isMatrix() 
+			else if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable
+				&& ins.getInput1().isMatrix()
 				&& ec.getCacheableData(ins.getInput1()).isFederated()){
 				fedinst = VariableFEDInstruction.parseInstruction(ins);
 			}
-			else if(ins.getVariableOpcode() == VariableOperationCode.CastAsMatrixVariable 
-				&& ins.getInput1().isFrame() 
+			else if(ins.getVariableOpcode() == VariableOperationCode.CastAsMatrixVariable
+				&& ins.getInput1().isFrame()
 				&& ec.getCacheableData(ins.getInput1()).isFederated()){
 				fedinst = VariableFEDInstruction.parseInstruction(ins);
 			}
@@ -183,16 +185,22 @@
 				fedinst = AggregateTernaryFEDInstruction.parseInstruction(ins);
 			}
 		}
+		else if(inst instanceof QuaternaryCPInstruction) {
+			QuaternaryCPInstruction instruction = (QuaternaryCPInstruction) inst;
+			Data data = ec.getVariable(instruction.input1);
+			if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+				fedinst = QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
+		}
 
 		//set thread id for federated context management
 		if( fedinst != null ) {
 			fedinst.setTID(ec.getTID());
 			return fedinst;
 		}
-		
+
 		return inst;
 	}
-	
+
 	public static Instruction checkAndReplaceSP(Instruction inst, ExecutionContext ec) {
 		FEDInstruction fedinst = null;
 		if (inst instanceof MapmmSPInstruction) {
@@ -256,12 +264,18 @@
 				return VariableCPInstruction.parseInstruction(instruction.getInstructionString());
 			}
 		}
+		else if(inst instanceof QuaternarySPInstruction) {
+			QuaternarySPInstruction instruction = (QuaternarySPInstruction) inst;
+			Data data = ec.getVariable(instruction.input1);
+			if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+				fedinst = QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
+		}
 		//set thread id for federated context management
 		if( fedinst != null ) {
 			fedinst.setTID(ec.getTID());
 			return fedinst;
 		}
-		
+
 		return inst;
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
new file mode 100644
index 0000000..2b62ec5
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *	 http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.	See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.fed.QuaternaryWCeMMFEDInstruction;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+
+/** Base class of federated quaternary instructions; holds the fourth input operand. */
+public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
+{
+	protected CPOperand _input4 = null;
+
+	protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, Operator operator,
+		CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String instruction_str)
+	{
+		super(type, operator, in1, in2, in3, out, opcode, instruction_str);
+		_input4 = in4;
+	}
+
+	/** Parses a CP or Spark quaternary instruction string into a federated instruction (only wcemm so far). */
+	public static QuaternaryFEDInstruction parseInstruction(String str)
+	{
+		if(str.startsWith(ExecType.SPARK.name())) {
+			// rewrite the spark instruction to a cp instruction; only the leading exec type is swapped
+			str = ExecType.CP.name() + str.substring(ExecType.SPARK.name().length());
+			str = str.replace("mapwcemm", "wcemm");
+			str += Lop.OPERAND_DELIMITOR + "1"; //num threads
+		}
+
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		InstructionUtils.checkNumFields(parts, 7); // validate before indexing into parts
+		String opcode = parts[0];
+
+		CPOperand in1 = new CPOperand(parts[1]);
+		CPOperand in2 = new CPOperand(parts[2]);
+		CPOperand in3 = new CPOperand(parts[3]);
+		CPOperand out = new CPOperand(parts[5]);
+
+		if(opcode.equals("wcemm")) {
+			CPOperand in4 = new CPOperand(parts[4]);
+			checkDataTypes(in1, in2, in3, in4);
+			WCeMMType wcemm_type = WCeMMType.valueOf(parts[6]);
+			QuaternaryOperator quaternary_operator = (wcemm_type.hasFourInputs() ?
+				new QuaternaryOperator(wcemm_type, Double.parseDouble(in4.getName())) :
+				new QuaternaryOperator(wcemm_type));
+			return new QuaternaryWCeMMFEDInstruction(quaternary_operator, in1, in2, in3, in4, out, opcode, str);
+		}
+
+		throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for QuaternaryFEDInstruction.");
+	}
+
+	/** Verifies that in1, in2 and in3 are matrices and that in4 is either a scalar or a matrix. */
+	protected static void checkDataTypes(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4) {
+		if(in1.getDataType() != DataType.MATRIX || in2.getDataType() != DataType.MATRIX
+			|| in3.getDataType() != DataType.MATRIX
+			|| !(in4.getDataType() == DataType.SCALAR || in4.getDataType() == DataType.MATRIX)) {
+			throw new DMLRuntimeException("Federated quaternary operations only supported with matrix inputs and scalar epsilon.");
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
new file mode 100644
index 0000000..8566b39
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *	 http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.	See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+
+import java.util.concurrent.Future;
+
+/** Federated weighted cross entropy (wcemm) instruction for a federated X and non-federated U and V. */
+public class QuaternaryWCeMMFEDInstruction extends QuaternaryFEDInstruction
+{
+	// input1 ... federated X, input2 ... U, input3 ... V
+	// _input4 ... W (=epsilon)
+	protected QuaternaryWCeMMFEDInstruction(Operator operator, CPOperand in1, CPOperand in2, CPOperand in3,
+		CPOperand in4, CPOperand out, String opcode, String instruction_str)
+	{
+		super(FEDType.Quaternary, operator, in1, in2, in3, in4, out, opcode, instruction_str);
+	}
+
+	/** Runs the instruction at the federated sites of X and sums the partial scalar results into the output. */
+	@Override
+	public void processInstruction(ExecutionContext ec)
+	{
+		QuaternaryOperator qop = (QuaternaryOperator) _optr;
+		MatrixObject X = ec.getMatrixObject(input1);
+		MatrixObject U = ec.getMatrixObject(input2);
+		MatrixObject V = ec.getMatrixObject(input3);
+		ScalarObject eps = null;
+		if(qop.hasFourInputs()) {
+			if(_input4.getDataType() == DataType.SCALAR)
+				eps = ec.getScalarInput(_input4);
+			else {
+				eps = new DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0));
+				ec.releaseMatrixInput(_input4.getName()); // unpin the epsilon matrix after reading cell (0,0)
+			}
+		}
+
+		if(!(X.isFederated() && !U.isFederated() && !V.isFederated()))
+			throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = ("
+				+X.isFederated()+", "+U.isFederated()+", "+V.isFederated()+")");
+
+		FederationMap fedMap = X.getFedMapping();
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(U, false);
+		FederatedRequest fr2 = fedMap.broadcast(V);
+		FederatedRequest fr3 = null;
+		FederatedRequest frComp = null;
+
+		// broadcast the scalar epsilon if there are four inputs
+		if(eps != null) {
+			fr3 = fedMap.broadcast(eps);
+			// broadcast epsilon is no literal anymore: send a copy with true -> false, keep instString intact
+			String fedInstString = instString.replace("true", "false");
+			frComp = FederationUtils.callInstruction(fedInstString, output,
+				new CPOperand[]{input1, input2, input3, _input4},
+				new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID(), fr3.getID()});
+		}
+		else {
+			frComp = FederationUtils.callInstruction(instString, output,
+				new CPOperand[]{input1, input2, input3},
+				new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID()});
+		}
+
+		FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID());
+		FederatedRequest frClean1 = fedMap.cleanup(getTID(), frComp.getID());
+		FederatedRequest frClean2 = fedMap.cleanup(getTID(), fr1[0].getID());
+		FederatedRequest frClean3 = fedMap.cleanup(getTID(), fr2.getID());
+
+		// execute federated instructions (with an extra cleanup request if epsilon was broadcast)
+		Future<FederatedResponse>[] response;
+		if(fr3 != null) {
+			FederatedRequest frClean4 = fedMap.cleanup(getTID(), fr3.getID());
+			response = fedMap.execute(getTID(), true, fr1, fr2, fr3,
+				frComp, frGet, frClean1, frClean2, frClean3, frClean4);
+		}
+		else {
+			response = fedMap.execute(getTID(), true, fr1, fr2,
+				frComp, frGet, frClean1, frClean2, frClean3);
+		}
+
+		//aggregate partial results from federated responses
+		AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+		ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response));
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index b2f3a53..a033769 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -19,9 +19,7 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
-import java.util.AbstractMap;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 
 import org.apache.sysds.common.Types;
@@ -36,16 +34,13 @@
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.functionobjects.DiagIndex;
 import org.apache.sysds.runtime.functionobjects.RevIndex;
-import org.apache.sysds.runtime.functionobjects.SortIndex;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
-import org.apache.sysds.runtime.util.IndexRange;
 
 public class ReorgFEDInstruction extends UnaryFEDInstruction {
 	
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 4d3e9d9..98c6b79 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -837,6 +837,10 @@
 		return TestUtils.readDMLMatrixFromHDFS(baseDirectory + OUTPUT_DIR + fileName);
 	}
 
+	protected static HashMap<CellIndex, Double> readDMLMatrixFromExpectedDir(String fileName) {
+		return TestUtils.readDMLMatrixFromHDFS(baseDirectory + EXPECTED_DIR + fileName); // same reader as above, resolved against EXPECTED_DIR
+	}
+	
 	public HashMap<CellIndex, Double> readRMatrixFromExpectedDir(String fileName) {
 		if(LOG.isInfoEnabled())
 			LOG.info("R script out: " + baseDirectory + EXPECTED_DIR + cacheDir + fileName);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
new file mode 100644
index 0000000..bf676a3
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.	See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.	The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.	You may obtain a copy of the License at
+ *
+ *	 http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.	See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedWeightedCrossEntropyTest extends AutomatedTestBase
+{
+	private final static String STD_TEST_NAME = "FederatedWCeMMTest"; // script without epsilon
+	private final static String EPS_TEST_NAME = "FederatedWCeMMEpsTest"; // script with epsilon
+	private final static String TEST_DIR = "functions/federated/quaternary/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWeightedCrossEntropyTest.class.getSimpleName() + "/";
+
+	private final static String OUTPUT_NAME = "Z";
+	private final static double TOLERANCE = 1e-9;
+	private final static int blocksize = 1024;
+
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+	@Parameterized.Parameter(2)
+	public int rank;
+	@Parameterized.Parameter(3)
+	public double epsilon;
+	@Parameterized.Parameter(4)
+	public double sparsity;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(STD_TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, STD_TEST_NAME, new String[]{OUTPUT_NAME}));
+		addTestConfiguration(EPS_TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, EPS_TEST_NAME, new String[]{OUTPUT_NAME}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// rows must be even: X is generated as two row partitions of rows / 2 rows each
+		return Arrays.asList(new Object[][] {
+			// {rows, cols, rank, epsilon, sparsity}
+			{2000, 50, 10, 0.01, 0.01},
+			{2000, 50, 10, 0.01, 0.9},
+			{2000, 50, 10, 6.45, 0.01},
+			{2000, 50, 10, 6.45, 0.9}
+		});
+	}
+
+	@BeforeClass
+	public static void init() {
+		TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+	}
+
+	@Test
+	public void federatedWeightedCrossEntropySingleNode() {
+		federatedWeightedCrossEntropy(STD_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedCrossEntropySpark() {
+		federatedWeightedCrossEntropy(STD_TEST_NAME, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedWeightedCrossEntropySingleNodeEpsilon() {
+		federatedWeightedCrossEntropy(EPS_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedCrossEntropySparkEpsilon() {
+		federatedWeightedCrossEntropy(EPS_TEST_NAME, ExecMode.SPARK);
+	}
+
+// --- shared driver: runs the reference and the federated script, then compares their written outputs ---
+
+	public void federatedWeightedCrossEntropy(String testname, ExecMode execMode)
+	{
+		// store the previous platform config to restore it after the test
+		ExecMode platform_old = setExecMode(execMode);
+
+		getAndLoadTestConfiguration(testname);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		int fed_rows = rows / 2; // rows held by each of the two federated workers
+		int fed_cols = cols;
+
+		// generate dataset
+		// X is split row-wise into X1 and X2, one partition per federated worker
+		double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 3);
+		double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 7);
+
+		double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
+		double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
+
+		writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+		writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+
+		writeInputMatrixWithMTD("U", U, true);
+		writeInputMatrixWithMTD("V", V, true);
+
+		// empty script name because we don't execute any script, just start the worker
+		fullDMLScriptName = "";
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+		Thread thread2 = startLocalFedWorkerThread(port2);
+
+		getAndLoadTestConfiguration(testname); // NOTE(review): already loaded above - confirm the reload is required
+
+		// Run reference dml script with normal matrix
+		fullDMLScriptName = HOME + testname + "Reference.dml";
+		programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+			"in_U=" + input("U"), "in_V=" + input("V"), "in_W=" + Double.toString(epsilon),
+			"out_Z=" + expected(OUTPUT_NAME)};
+		runTest(true, false, null, -1);
+
+		// Run actual dml script with federated matrix
+		fullDMLScriptName = HOME + testname + ".dml";
+		programArgs = new String[] {"-stats", "-nvargs",
+			"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+			"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+			"in_U=" + input("U"),
+			"in_V=" + input("V"),
+			"in_W=" + Double.toString(epsilon),
+			"rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + output(OUTPUT_NAME)};
+		runTest(true, false, null, -1);
+
+		// compare the results via files
+		HashMap<CellIndex, Double> refResults  = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+		HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+		TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+
+		TestUtils.shutdownThreads(thread1, thread2); // NOTE(review): not reached if a step above throws - consider try/finally
+
+		// check for federated operations
+		Assert.assertTrue(heavyHittersContainsString("fed_wcemm"));
+
+		// check that federated input files are still existing
+		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+		resetExecMode(platform_old);
+	}
+}
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTest.dml
new file mode 100644
index 0000000..84c0b92
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)))
+# X: two federated row partitions ($rows x $cols each), addressed as one (2 * $rows) x $cols matrix
+U = read($in_U)
+V = read($in_V)
+epsilon = $in_W
+# weighted cross entropy with epsilon; the JUnit test asserts that this runs as fed_wcemm
+Z = as.matrix(sum(X * log(U %*% t(V) + epsilon)))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTestReference.dml
new file mode 100644
index 0000000..c01f99a
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2))
+U = read($in_U)
+V = read($in_V)
+epsilon = $in_W
+# non-federated reference: weighted cross entropy with epsilon on the row-bound partitions
+Z = as.matrix(sum(X * log(U %*% t(V) + epsilon)))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTest.dml
new file mode 100644
index 0000000..75ae2ef
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)))
+# X: two federated row partitions ($rows x $cols each), addressed as one (2 * $rows) x $cols matrix
+U = read($in_U)
+V = read($in_V)
+
+Z = as.matrix(sum(X * log(U %*% t(V))))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTestReference.dml
new file mode 100644
index 0000000..499ed3d
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTestReference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2))
+U = read($in_U)
+V = read($in_V)
+# non-federated reference: weighted cross entropy (no epsilon) on the row-bound partitions
+Z = as.matrix(sum(X * log(U %*% t(V))))
+
+write(Z, $out_Z)