[SYSTEMDS-2709] Fix missing federated unary aggregate for scalar mean

With the fixed missing size propagation for federated init statements,
now rewrites trigger, which expose operations we don't support yet. This
patch adds, besides the existing row means and columns means, also
support for full mean aggregates.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index c8da781..37cb7d5 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -98,7 +98,9 @@
 			MatrixBlock ret = null;
 			long size = 0;
 			for(int i=0; i<ffr.length; i++) {
-				MatrixBlock tmp = (MatrixBlock)ffr[i].get().getData()[0];
+				Object input = ffr[i].get().getData()[0];
+				MatrixBlock tmp = (input instanceof ScalarObject) ? 
+					new MatrixBlock(((ScalarObject)input).getDoubleValue()) : (MatrixBlock) input;
 				size += ranges[i].getSize(0);
 				sop1 = sop1.setConstant(ranges[i].getSize(0));
 				tmp = tmp.scalarOperations(sop1, new MatrixBlock());
@@ -167,10 +169,11 @@
 		}
 	}
 
-	public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr) {
+	public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) {
 		if(!(aop.aggOp.increOp.fn instanceof KahanFunction || (aop.aggOp.increOp.fn instanceof Builtin &&
-			(((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN ||
-				((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)))) {
+			(((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN
+			|| ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)
+			|| aop.aggOp.increOp.fn instanceof Mean ))) {
 			throw new DMLRuntimeException("Unsupported aggregation operator: "
 				+ aop.aggOp.increOp.getClass().getSimpleName());
 		}
@@ -181,7 +184,10 @@
 				boolean isMin = ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
 				return new DoubleObject(aggMinMax(ffr, isMin, true,  Optional.empty()).getValue(0,0));
 			}
-			else {
+			else if( aop.aggOp.increOp.fn instanceof Mean ) {
+				return new DoubleObject(aggMean(ffr, map).getValue(0,0));
+			}
+			else { //if (aop.aggOp.increOp.fn instanceof KahanFunction)
 				double sum = 0; //uak+
 				for( Future<FederatedResponse> fr : ffr )
 					sum += ((ScalarObject)fr.get().getData()[0]).getDoubleValue();
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 60fe40b..d06dfaa 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -66,7 +66,7 @@
 		//execute federated commands and cleanups
 		Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
 		if( output.isScalar() )
-			ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp));
+			ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp, map));
 		else
 			ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, map));
 	}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java
index af55b95..a970479 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java
@@ -37,26 +37,26 @@
 import org.junit.Assert;
 
 public class FederatedTestObjectConstructor {
-    public static MatrixObject constructFederatedInput(int rows, int cols, int blocksize, String host, long[][] begin,
-        long[][] end, int[] ports, String[] inputs, String file) {
-        MatrixObject fed = new MatrixObject(ValueType.FP64, file);
-        try {
-            fed.setMetaData(new MetaData(new MatrixCharacteristics(rows, cols, blocksize, rows * cols)));
-            List<Pair<FederatedRange, FederatedData>> d = new ArrayList<>();
-            for(int i = 0; i < ports.length; i++) {
-                FederatedRange X1r = new FederatedRange(begin[i], end[i]);
-                FederatedData X1d = new FederatedData(Types.DataType.MATRIX,
-                    new InetSocketAddress(InetAddress.getByName(host), ports[i]), inputs[i]);
-                d.add(new ImmutablePair<>(X1r, X1d));
-            }
+	public static MatrixObject constructFederatedInput(int rows, int cols, int blocksize, String host, long[][] begin,
+		long[][] end, int[] ports, String[] inputs, String file) {
+		MatrixObject fed = new MatrixObject(ValueType.FP64, file);
+		try {
+			fed.setMetaData(new MetaData(new MatrixCharacteristics(rows, cols, blocksize, rows * cols)));
+			List<Pair<FederatedRange, FederatedData>> d = new ArrayList<>();
+			for(int i = 0; i < ports.length; i++) {
+				FederatedRange X1r = new FederatedRange(begin[i], end[i]);
+				FederatedData X1d = new FederatedData(Types.DataType.MATRIX,
+					new InetSocketAddress(InetAddress.getByName(host), ports[i]), inputs[i]);
+				d.add(new ImmutablePair<>(X1r, X1d));
+			}
 
-            InitFEDInstruction.federateMatrix(fed, d);
-        }
-        catch(Exception e) {
-            e.printStackTrace();
-            Assert.assertTrue(false);
-        }
-        return fed;
+			InitFEDInstruction.federateMatrix(fed, d);
+		}
+		catch(Exception e) {
+			e.printStackTrace();
+			Assert.assertTrue(false);
+		}
+		return fed;
 
-    }
+	}
 }