[SYSTEMDS-2550] Federated parameter server scaling and weight handling
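
This patch adds gradient weighing and a user-controllable seed to the federated
parameter server. The data-partitioning schemes now take a global seed instead
of System.currentTimeMillis(), each scheme additionally returns per-worker
weighing factors (local rows / average rows) that are applied to the gradients
before they are pushed to the server, and SCALE_BATCH runtime balancing scales
the local batch size instead of throwing NotImplementedException.

A minimal DML sketch of the extended paramserv call, mirroring the test scripts
in this patch (surrounding names and values are illustrative):

    model = paramserv(model=model_list, features=X, labels=y,
      upd="./TwoNN.dml::gradients", agg="./TwoNN.dml::aggregation",
      k=num_workers, utype="BSP", freq="BATCH", epochs=4, batchsize=1,
      scheme="KEEP_DATA_ON_WORKER", runtime_balancing="RUN_MIN",
      weighing="true", hyperparams=hyperparams, seed=200)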

Closes #1141.
diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 5171f21..05bfc48 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -289,8 +289,8 @@
 		Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS,
 			Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
 			Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
-			Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_RUNTIME_BALANCING,
-			Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
+			Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
+			Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED);
 		checkInvalidParameters(getOpCode(), getVarParams(), valid);
 
 		// check existence and correctness of parameters
@@ -308,9 +308,11 @@
 		checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT64, conditional);
 		checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT64, conditional);
 		checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
-		checkStringParam(true, fname, Statement.PS_RUNTIME_BALANCING, conditional);
+		checkStringParam(true, fname, Statement.PS_FED_RUNTIME_BALANCING, conditional);
+		checkStringParam(true, fname, Statement.PS_FED_WEIGHING, conditional);
 		checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional);
 		checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional);
+		checkDataValueType(true, fname, Statement.PS_SEED, DataType.SCALAR, ValueType.INT64, conditional);
 
 		// set output characteristics
 		output.setDataType(DataType.LIST);
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index 6767d85..9104246 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -70,6 +70,7 @@
 	public static final String PS_AGGREGATION_FUN = "agg";
 	public static final String PS_MODE = "mode";
 	public static final String PS_GRADIENTS = "gradients";
+	public static final String PS_SEED = "seed";
 	public enum PSModeType {
 		FEDERATED, LOCAL, REMOTE_SPARK
 	}
@@ -87,9 +88,10 @@
 	public enum PSFrequency {
 		BATCH, EPOCH
 	}
-	public static final String PS_RUNTIME_BALANCING = "runtime_balancing";
+	public static final String PS_FED_WEIGHING = "weighing";
+	public static final String PS_FED_RUNTIME_BALANCING = "runtime_balancing";
 	public enum PSRuntimeBalancing {
-		NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH, SCALE_BATCH_AND_WEIGH
+		NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH
 	}
 	public static final String PS_EPOCHS = "epochs";
 	public static final String PS_BATCH_SIZE = "batchsize";
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
index 54a45d0..f09b5c2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
@@ -228,7 +228,7 @@
 		return val;
 	}
 
-	protected final double sumValuesSparse(int valIx, SparseRow[] rows, double[] dictVals, int rowsIndex) {
+	protected static double sumValuesSparse(int valIx, SparseRow[] rows, double[] dictVals, int rowsIndex) {
 		throw new NotImplementedException("This Method was implemented incorrectly");
 		// final int numCols = getNumCols();
 		// final int valOff = valIx * numCols;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 393b131..48249db 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -24,6 +24,8 @@
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.Statement.PSFrequency;
+import org.apache.sysds.parser.Statement.PSRuntimeBalancing;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
 import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
@@ -37,13 +39,17 @@
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.functionobjects.Multiply;
 import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.IntObject;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
 import org.apache.sysds.runtime.util.ProgramConverter;
 
 import java.util.ArrayList;
@@ -58,21 +64,29 @@
 public class FederatedPSControlThread extends PSWorker implements Callable<Void> {
 	private static final long serialVersionUID = 6846648059569648791L;
 	protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
-	
-	Statement.PSRuntimeBalancing _runtimeBalancing;
-	FederatedData _featuresData;
-	FederatedData _labelsData;
-	final long _localStartBatchNumVarID;
-	final long _modelVarID;
-	int _numBatchesPerGlobalEpoch;
-	int _possibleBatchesPerLocalEpoch;
-	boolean _cycleStartAt0 = false;
 
-	public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) {
+	private FederatedData _featuresData;
+	private FederatedData _labelsData;
+	private final long _localStartBatchNumVarID;
+	private final long _modelVarID;
+
+	// runtime balancing
+	private PSRuntimeBalancing _runtimeBalancing;
+	private int _numBatchesPerEpoch;
+	private int _possibleBatchesPerLocalEpoch;
+	private boolean _weighing;
+	private double _weighingFactor = 1;
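+	// if true, every epoch starts at local batch 0 instead of cycling through the local batches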
+	private boolean _cycleStartAt0 = false;
+
+	public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq,
+		PSRuntimeBalancing runtimeBalancing, boolean weighing, int epochs, long batchSize,
+		int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps)
+	{
 		super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
 
-		_numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch;
+		_numBatchesPerEpoch = numBatchesPerGlobalEpoch;
 		_runtimeBalancing = runtimeBalancing;
+		_weighing = weighing;
 		// generate the IDs for model and batch counter. These get overwritten on the federated worker each time
 		_localStartBatchNumVarID = FederationUtils.getNextFedDataID();
 		_modelVarID = FederationUtils.getNextFedDataID();
@@ -80,65 +94,72 @@
 
 	/**
 	 * Sets up the federated worker and control thread
+	 *
+	 * @param weighingFactor Gradients from this worker will be multiplied by this factor if weighing is enabled
 	 */
-	public void setup() {
+	public void setup(double weighingFactor) {
 		// prepare features and labels
 		_featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0];
 		_labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0];
 
-		// calculate number of batches and get data size
+		// the weighing factor is always set, but only applied when weighing is enabled
+		_weighingFactor = weighingFactor;
+
+		// different runtime balancing calculations
 		long dataSize = _features.getNumRows();
-		_possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
-		if(!(_runtimeBalancing == Statement.PSRuntimeBalancing.RUN_MIN 
-			|| _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_AVG 
-			|| _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_MAX)) {
-			_numBatchesPerGlobalEpoch = _possibleBatchesPerLocalEpoch;
+
+		// calculate the scaled batch size if balancing via batch size;
+		// depending on the partition sizes this can still involve some cycling
+		if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
+			_batchSize = (int) Math.ceil((double) dataSize / _numBatchesPerEpoch);
 		}
 
-		if(_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH 
-			|| _runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH_AND_WEIGH) {
-			throw new NotImplementedException();
+		// Calculate possible batches with batch size
+		_possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
+
+		// If no runtime balancing is specified, just run the possible number of batches
+		// WARNING: will get stuck on a mismatch
+		if(_runtimeBalancing == PSRuntimeBalancing.NONE) {
+			_numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
 		}
 
+		LOG.info("Setup config for worker " + this.getWorkerName());
+		LOG.info("Batch size: " + _batchSize + ", possible batches: " + _possibleBatchesPerLocalEpoch
+				+ ", batches to run: " + _numBatchesPerEpoch + ", weighing factor: " + _weighingFactor);
+
 		// serialize program
 		// create program blocks for the instruction filtering
 		String programSerialized;
-		ArrayList<ProgramBlock> programBlocks = new ArrayList<>();
+		ArrayList<ProgramBlock> pbs = new ArrayList<>();
 
 		BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram());
 		gradientProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_inst)));
-		programBlocks.add(gradientProgramBlock);
+		pbs.add(gradientProgramBlock);
 
-		if(_freq == Statement.PSFrequency.EPOCH) {
+		if(_freq == PSFrequency.EPOCH) {
 			BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram());
 			aggProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_ps.getAggInst())));
-			programBlocks.add(aggProgramBlock);
+			pbs.add(aggProgramBlock);
 		}
 
-		StringBuilder sb = new StringBuilder();
-		sb.append(PROG_BEGIN);
-		sb.append( NEWLINE );
-		sb.append(ProgramConverter.serializeProgram(_ec.getProgram(),
-				programBlocks,
-				new HashMap<>(),
-				false
-		));
-		sb.append(PROG_END);
-		programSerialized = sb.toString();
+		programSerialized = InstructionUtils.concatStrings(
+			PROG_BEGIN, NEWLINE,
+			ProgramConverter.serializeProgram(_ec.getProgram(), pbs, new HashMap<>(), false),
+			PROG_END);
 
 		// write program and meta data to worker
 		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
 			new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
 				new SetupFederatedWorker(_batchSize,
-						dataSize,
-						_possibleBatchesPerLocalEpoch,
-						programSerialized,
-						_inst.getNamespace(),
-						_inst.getFunctionName(),
-						_ps.getAggInst().getFunctionName(),
-						_ec.getListObject("hyperparams"),
-						_localStartBatchNumVarID,
-						_modelVarID
+					dataSize,
+					_possibleBatchesPerLocalEpoch,
+					programSerialized,
+					_inst.getNamespace(),
+					_inst.getFunctionName(),
+					_ps.getAggInst().getFunctionName(),
+					_ec.getListObject("hyperparams"),
+					_localStartBatchNumVarID,
+					_modelVarID
 				)
 		));
 
@@ -286,12 +307,23 @@
 		return _ps.pull(_workerID);
 	}
 
-	protected void pushGradients(ListObject gradients) {
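+	/**
+	 * Scales the gradients by this worker's weighing factor (when weighing is
+	 * enabled) and pushes them to the parameter server.
+	 */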
+	protected void scaleAndPushGradients(ListObject gradients) {
+		// scale gradients (the gradients list must only contain MatrixObjects)
+		if(_weighing && _weighingFactor != 1) {
+			gradients.getData().parallelStream().forEach((matrix) -> {
+				MatrixObject matrixObject = (MatrixObject) matrix;
+				MatrixBlock input = matrixObject.acquireReadAndRelease().scalarOperations(
+					new RightScalarOperator(Multiply.getMultiplyFnObject(), _weighingFactor), new MatrixBlock());
+				matrixObject.acquireModify(input);
+				matrixObject.release();
+			});
+		}
+
 		// Push the gradients to ps
 		_ps.push(_workerID, gradients);
 	}
 
-	static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
+	protected static int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
 		return currentLocalBatchNumber % possibleBatchesPerLocalEpoch;
 	}
 
@@ -300,18 +332,18 @@
 	 */
 	protected void computeWithBatchUpdates() {
 		for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
-			int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+			int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
 
-			for (int batchCounter = 0; batchCounter < _numBatchesPerGlobalEpoch; batchCounter++) {
+			for (int batchCounter = 0; batchCounter < _numBatchesPerEpoch; batchCounter++) {
 				int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
 				ListObject model = pullModel();
 				ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum);
-				pushGradients(gradients);
+				scaleAndPushGradients(gradients);
 				ParamservUtils.cleanupListObject(model);
 				ParamservUtils.cleanupListObject(gradients);
+				LOG.info("[+] " + this.getWorkerName() + " completed BATCH " + localStartBatchNum);
 			}
-			if( LOG.isInfoEnabled() )
-				LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+			LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter);
 		}
 	}
 
@@ -327,15 +359,14 @@
 	 */
 	protected void computeWithEpochUpdates() {
 		for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
-			int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+			int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
 
 			// Pull the global parameters from ps
 			ListObject model = pullModel();
-			ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerGlobalEpoch, localStartBatchNum, true);
-			pushGradients(gradients);
-			
-			if( LOG.isInfoEnabled() )
-				LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+			ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true);
+			scaleAndPushGradients(gradients);
+
+			LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter);
 			ParamservUtils.cleanupListObject(model);
 			ParamservUtils.cleanupListObject(gradients);
 		}
@@ -424,12 +455,12 @@
 			ArrayList<DataIdentifier> inputs = func.getInputParams();
 			ArrayList<DataIdentifier> outputs = func.getOutputParams();
 			CPOperand[] boundInputs = inputs.stream()
-					.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
-					.toArray(CPOperand[]::new);
+				.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+				.toArray(CPOperand[]::new);
 			ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
-					.collect(Collectors.toCollection(ArrayList::new));
+				.collect(Collectors.toCollection(ArrayList::new));
 			Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
-					func.getInputParamNames(), outputNames, "gradient function");
+				func.getInputParamNames(), outputNames, "gradient function");
 			DataIdentifier gradientsOutput = outputs.get(0);
 
 			// recreate aggregation instruction and output if needed
@@ -440,12 +471,12 @@
 				inputs = func.getInputParams();
 				outputs = func.getOutputParams();
 				boundInputs = inputs.stream()
-						.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
-						.toArray(CPOperand[]::new);
+					.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+					.toArray(CPOperand[]::new);
 				outputNames = outputs.stream().map(DataIdentifier::getName)
-						.collect(Collectors.toCollection(ArrayList::new));
+					.collect(Collectors.toCollection(ArrayList::new));
 				aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
-						func.getInputParamNames(), outputNames, "aggregation function");
+					func.getInputParamNames(), outputNames, "aggregation function");
 				aggregationOutput = outputs.get(0);
 			}
 
@@ -492,8 +523,6 @@
 				ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
 				ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
 				ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
-				if( LOG.isInfoEnabled() )
-					LOG.info("[+]" + " completed batch " + localBatchNum);
 			}
 
 			// model clean up
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
index 460faba..34e94f0 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -20,7 +20,6 @@
 package org.apache.sysds.runtime.controlprogram.paramserv.dp;
 
 import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
@@ -35,13 +34,25 @@
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Balance to Avg Federated scheme
+ *
+ * When the parameter server runs in federated mode, it cannot pull in the data that already resides on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global average number of examples is
+ * taken and each worker subsamples or replicates its data to match that number (cf. the other federated schemes).
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only row-federated matrices are supported at the moment.
+ */
 public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
 	@Override
-	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+	public Result partition(MatrixObject features, MatrixObject labels, int seed) {
 		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
 		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+		BalanceMetrics balanceMetricsBefore = getBalanceMetrics(pFeatures);
+		List<Double> weighingFactors = getWeighingFactors(pFeatures, balanceMetricsBefore);
 
-		int average_num_rows = (int) Math.round(pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN));
+		int average_num_rows = (int) balanceMetricsBefore._avgRows;
 
 		for(int i = 0; i < pFeatures.size(); i++) {
 			// Works, because the map contains a single entry
@@ -49,7 +60,7 @@
 			FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
 
 			Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-					featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, average_num_rows)));
+					featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, average_num_rows)));
 
 			try {
 				FederatedResponse response = udfResponse.get();
@@ -66,7 +77,7 @@
 			pLabels.get(i).updateDataCharacteristics(update);
 		}
 
-		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors);
 	}
 
 	/**
@@ -74,10 +85,12 @@
 	 */
 	private static class balanceDataOnFederatedWorker extends FederatedUDF {
 		private static final long serialVersionUID = 6631958250346625546L;
+		private final int _seed;
 		private final int _average_num_rows;
-		
-		protected balanceDataOnFederatedWorker(long[] inIDs, int average_num_rows) {
+
+		protected balanceDataOnFederatedWorker(long[] inIDs, int seed, int average_num_rows) {
 			super(inIDs);
+			_seed = seed;
 			_average_num_rows = average_num_rows;
 		}
 
@@ -88,14 +101,14 @@
 
 			if(features.getNumRows() > _average_num_rows) {
 				// generate subsampling matrix
-				MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+				MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), _seed);
 				subsampleTo(features, subsampleMatrixBlock);
 				subsampleTo(labels, subsampleMatrixBlock);
 			}
 			else if(features.getNumRows() < _average_num_rows) {
 				int num_rows_needed = _average_num_rows - Math.toIntExact(features.getNumRows());
 				// generate replication matrix
-				MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+				MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), _seed);
 				replicateTo(features, replicateMatrixBlock);
 				replicateTo(labels, replicateMatrixBlock);
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index f5c9638..e00923e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -45,16 +45,31 @@
 		public final List<MatrixObject> _pLabels;
 		public final int _workerNum;
 		public final BalanceMetrics _balanceMetrics;
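+		// one weighing factor per worker, computed as localRows / avgRows (see getWeighingFactors)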
+		public final List<Double> _weighingFactors;
 
-		public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum, BalanceMetrics balanceMetrics) {
-			this._pFeatures = pFeatures;
-			this._pLabels = pLabels;
-			this._workerNum = workerNum;
-			this._balanceMetrics = balanceMetrics;
+		public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum, BalanceMetrics balanceMetrics, List<Double> weighingFactors) {
+			_pFeatures = pFeatures;
+			_pLabels = pLabels;
+			_workerNum = workerNum;
+			_balanceMetrics = balanceMetrics;
+			_weighingFactors = weighingFactors;
 		}
 	}
 
-	public abstract Result doPartitioning(MatrixObject features, MatrixObject labels);
+	public static final class BalanceMetrics {
+		public final long _minRows;
+		public final long _avgRows;
+		public final long _maxRows;
+
+		public BalanceMetrics(long minRows, long avgRows, long maxRows) {
+			_minRows = minRows;
+			_avgRows = avgRows;
+			_maxRows = maxRows;
+		}
+	}
+
+	public abstract Result partition(MatrixObject features, MatrixObject labels, int seed);
 
 	/**
 	 * Takes a row federated Matrix and slices it into a matrix for each worker
@@ -110,16 +125,12 @@
 		return new BalanceMetrics(minRows, sum / slices.size(), maxRows);
 	}
 
-	public static final class BalanceMetrics {
-		public final long _minRows;
-		public final long _avgRows;
-		public final long _maxRows;
-
-		public BalanceMetrics(long minRows, long avgRows, long maxRows) {
-			this._minRows = minRows;
-			this._avgRows = avgRows;
-			this._maxRows = maxRows;
-		}
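+	/**
+	 * Computes one weighing factor per worker as localRows / avgRows, so that
+	 * gradients from differently sized partitions contribute proportionally.
+	 */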
+	static List<Double> getWeighingFactors(List<MatrixObject> pFeatures, BalanceMetrics balanceMetrics) {
+		List<Double> weighingFactors = new ArrayList<>();
+		pFeatures.forEach((feature) -> {
+			weighingFactors.add((double) feature.getNumRows() / balanceMetrics._avgRows);
+		});
+		return weighingFactors;
 	}
 
 	/**
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
index d1ebb6c..ce2f954 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
@@ -24,10 +24,11 @@
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 
 public class FederatedDataPartitioner {
-
 	private final DataPartitionFederatedScheme _scheme;
+	private final int _seed;
 
-	public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) {
+	public FederatedDataPartitioner(Statement.FederatedPSScheme scheme, int seed) {
+		_seed = seed;
 		switch (scheme) {
 			case KEEP_DATA_ON_WORKER:
 				_scheme = new KeepDataOnWorkerFederatedScheme();
@@ -50,6 +51,6 @@
 	}
 
 	public DataPartitionFederatedScheme.Result doPartitioning(MatrixObject features, MatrixObject labels) {
-		return _scheme.doPartitioning(features, labels);
+		return _scheme.partition(features, labels, _seed);
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
index e306f25..afbaf4d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -22,11 +22,20 @@
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import java.util.List;
 
+/**
+ * Keep Data on Worker Federated scheme
+ *
+ * When the parameter server runs in federated mode, it cannot pull in the data that already resides on the workers.
+ * All entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only row-federated matrices are supported at the moment.
+ */
 public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedScheme {
 	@Override
-	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+	public Result partition(MatrixObject features, MatrixObject labels, int seed) {
 		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
 		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
-		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+		BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
+		List<Double> weighingFactors = getWeighingFactors(pFeatures, balanceMetrics);
+		return new Result(pFeatures, pLabels, pFeatures.size(), balanceMetrics, weighingFactors);
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
index 068cfa9..a1b8f6c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -34,11 +34,23 @@
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Replicate to Max Federated scheme
+ *
+ * When the parameter server runs in federated mode, it cannot pull in the data that already resides on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global maximum number of examples is
+ * taken and each worker replicates its data to match that number. The replicated rows are selected by multiplying
+ * with a permutation matrix seeded with the global seed and are appended to the original data.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only row-federated matrices are supported at the moment.
+ */
 public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme {
 	@Override
-	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+	public Result partition(MatrixObject features, MatrixObject labels, int seed) {
 		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
 		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+		List<Double> weighingFactors = getWeighingFactors(pFeatures, getBalanceMetrics(pFeatures));
 
 		int max_rows = 0;
 		for (MatrixObject pFeature : pFeatures) {
@@ -51,7 +63,7 @@
 			FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
 
 			Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-					featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, max_rows)));
+					featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, max_rows)));
 
 			try {
 				FederatedResponse response = udfResponse.get();
@@ -68,7 +80,7 @@
 			pLabels.get(i).updateDataCharacteristics(update);
 		}
 
-		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors);
 	}
 
 	/**
@@ -76,10 +88,12 @@
 	 */
 	private static class replicateDataOnFederatedWorker extends FederatedUDF {
 		private static final long serialVersionUID = -6930898456315100587L;
+		private final int _seed;
 		private final int _max_rows;
-		
-		protected replicateDataOnFederatedWorker(long[] inIDs, int max_rows) {
+
+		protected replicateDataOnFederatedWorker(long[] inIDs, int seed, int max_rows) {
 			super(inIDs);
+			_seed = seed;
 			_max_rows = max_rows;
 		}
 
@@ -92,7 +106,7 @@
 			if(features.getNumRows() < _max_rows) {
 				int num_rows_needed = _max_rows - Math.toIntExact(features.getNumRows());
 				// generate replication matrix
-				MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+				MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), _seed);
 				replicateTo(features, replicateMatrixBlock);
 				replicateTo(labels, replicateMatrixBlock);
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index 65ef69d..1920593 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -33,11 +33,23 @@
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Shuffle Federated scheme
+ *
+ * When the parameter server runs in federated mode, it cannot pull in the data that already resides on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the data is shuffled by generating a
+ * permutation matrix seeded with the global seed and multiplying it with the data.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only row-federated matrices are supported at the moment.
+ */
 public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
 	@Override
-	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+	public Result partition(MatrixObject features, MatrixObject labels, int seed) {
 		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
 		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+		BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
+		List<Double> weighingFactors = getWeighingFactors(pFeatures, balanceMetrics);
 
 		for(int i = 0; i < pFeatures.size(); i++) {
 			// Works, because the map contains a single entry
@@ -45,7 +57,7 @@
 			FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
 
 			Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-					featuresData.getVarID(), new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()})));
+					featuresData.getVarID(), new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed)));
 
 			try {
 				FederatedResponse response = udfResponse.get();
@@ -57,7 +69,7 @@
 			}
 		}
 
-		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+		return new Result(pFeatures, pLabels, pFeatures.size(), balanceMetrics, weighingFactors);
 	}
 
 	/**
@@ -65,9 +77,11 @@
 	 */
 	private static class shuffleDataOnFederatedWorker extends FederatedUDF {
 		private static final long serialVersionUID = 3228664618781333325L;
+		private final int _seed;
 
-		protected shuffleDataOnFederatedWorker(long[] inIDs) {
+		protected shuffleDataOnFederatedWorker(long[] inIDs, int seed) {
 			super(inIDs);
+			_seed = seed;
 		}
 
 		@Override
@@ -76,7 +90,7 @@
 			MatrixObject labels = (MatrixObject) data[1];
 
 			// generate permutation matrix
-			MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+			MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), _seed);
 			shuffle(features, permutationMatrixBlock);
 			shuffle(labels, permutationMatrixBlock);
 			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
index 9b62cc8..937c37e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -34,11 +34,23 @@
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Subsample to Min Federated scheme
+ *
+ * When the parameter server runs in federated mode, it cannot pull in the data that already resides on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global minimum number of examples is
+ * taken and each worker subsamples its data to match that number. The subsampling is done by multiplying with a
+ * permutation matrix seeded with the global seed.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only row-federated matrices are supported at the moment.
+ */
 public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme {
 	@Override
-	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+	public Result partition(MatrixObject features, MatrixObject labels, int seed) {
 		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
 		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+		List<Double> weighingFactors = getWeighingFactors(pFeatures, getBalanceMetrics(pFeatures));
 
 		int min_rows = Integer.MAX_VALUE;
 		for (MatrixObject pFeature : pFeatures) {
@@ -51,7 +63,7 @@
 			FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
 
 			Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-					featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, min_rows)));
+					featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, min_rows)));
 
 			try {
 				FederatedResponse response = udfResponse.get();
@@ -68,7 +80,7 @@
 			pLabels.get(i).updateDataCharacteristics(update);
 		}
 
-		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+		return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors);
 	}
 
 	/**
@@ -76,10 +88,12 @@
 	 */
 	private static class subsampleDataOnFederatedWorker extends FederatedUDF {
 		private static final long serialVersionUID = 2213790859544004286L;
+		private final int _seed;
 		private final int _min_rows;
-		
-		protected subsampleDataOnFederatedWorker(long[] inIDs, int min_rows) {
+
+		protected subsampleDataOnFederatedWorker(long[] inIDs, int seed, int min_rows) {
 			super(inIDs);
+			_seed = seed;
 			_min_rows = min_rows;
 		}
 
@@ -91,7 +105,7 @@
 			// subsample down to minimum
 			if(features.getNumRows() > _min_rows) {
 				// generate subsampling matrix
-				MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_min_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+				MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_min_rows, Math.toIntExact(features.getNumRows()), _seed);
 				subsampleTo(features, subsampleMatrixBlock);
 				subsampleTo(labels, subsampleMatrixBlock);
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index a2b8d9f..a66e039 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -19,6 +19,17 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
 import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN;
 import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE;
 import static org.apache.sysds.parser.Statement.PS_EPOCHS;
@@ -32,18 +43,9 @@
 import static org.apache.sysds.parser.Statement.PS_SCHEME;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
-import static org.apache.sysds.parser.Statement.PS_RUNTIME_BALANCING;
-
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
+import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
+import static org.apache.sysds.parser.Statement.PS_FED_WEIGHING;
+import static org.apache.sysds.parser.Statement.PS_SEED;
 
 import org.apache.commons.lang3.concurrent.BasicThreadFactory;
 import org.apache.commons.logging.Log;
@@ -121,37 +123,36 @@
 	}
 
 	private void runFederated(ExecutionContext ec) {
-		System.out.println("PARAMETER SERVER");
-		System.out.println("[+] Running in federated mode");
+		LOG.info("PARAMETER SERVER");
+		LOG.info("[+] Running in federated mode");
 
 		// get inputs
-		PSFrequency freq = getFrequency();
-		PSUpdateType updateType = getUpdateType();
-		PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
-		FederatedPSScheme federatedPSScheme = getFederatedScheme();
 		String updFunc = getParam(PS_UPDATE_FUN);
 		String aggFunc = getParam(PS_AGGREGATION_FUN);
+		PSUpdateType updateType = getUpdateType();
+		PSFrequency freq = getFrequency();
+		FederatedPSScheme federatedPSScheme = getFederatedScheme();
+		PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
+		boolean weighing = getWeighing();
+		int seed = getSeed();
 
-		// partition federated data
-		DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme)
-				.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
-		List<MatrixObject> pFeatures = result._pFeatures;
-		List<MatrixObject> pLabels = result._pLabels;
-		int workerNum = result._workerNum;
-
-		// calculate runtime balancing
-		int numBatchesPerEpoch = 0;
-		if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
-			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize());
-		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG) {
-			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize());
- 		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
-			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize());
+		if( LOG.isInfoEnabled() ) {
+			LOG.info("[+] Update Type: " + updateType);
+			LOG.info("[+] Frequency: " + freq);
+			LOG.info("[+] Data Partitioning: " + federatedPSScheme);
+			LOG.info("[+] Runtime Balancing: " + runtimeBalancing);
+			LOG.info("[+] Weighing: " + weighing);
+			LOG.info("[+] Seed: " + seed);
 		}
+		
+		// partition federated data
+		DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme, seed)
+			.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
+		int workerNum = result._workerNum;
 
 		// setup threading
 		BasicThreadFactory factory = new BasicThreadFactory.Builder()
-				.namingPattern("workers-pool-thread-%d").build();
+			.namingPattern("workers-pool-thread-%d").build();
 		ExecutorService es = Executors.newFixedThreadPool(workerNum, factory);
 
 		// Get the compiled execution context
@@ -166,10 +167,11 @@
 		ListObject model = ec.getListObject(getParam(PS_MODEL));
 		ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC);
 		// Create the local workers
-		int finalNumBatchesPerEpoch = numBatchesPerEpoch;
+		int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
 		List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
-				.mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
-				.collect(Collectors.toList());
+			.mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighing,
+				getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
+			.collect(Collectors.toList());
 
 		if(workerNum != threads.size()) {
 			throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!");
@@ -177,9 +179,9 @@
 
 		// Set features and labels for the control threads and write the program, instructions, and hyperparams to the federated workers
 		for (int i = 0; i < threads.size(); i++) {
-			threads.get(i).setFeatures(pFeatures.get(i));
-			threads.get(i).setLabels(pLabels.get(i));
-			threads.get(i).setup();
+			threads.get(i).setFeatures(result._pFeatures.get(i));
+			threads.get(i).setLabels(result._pLabels.get(i));
+			threads.get(i).setup(result._weighingFactors.get(i));
 		}
 
 		try {
@@ -395,14 +397,14 @@
 	}
 
 	private PSRuntimeBalancing getRuntimeBalancing() {
-		if (!getParameterMap().containsKey(PS_RUNTIME_BALANCING)) {
+		if (!getParameterMap().containsKey(PS_FED_RUNTIME_BALANCING)) {
 			return DEFAULT_RUNTIME_BALANCING;
 		}
 		try {
-			return PSRuntimeBalancing.valueOf(getParam(PS_RUNTIME_BALANCING));
+			return PSRuntimeBalancing.valueOf(getParam(PS_FED_RUNTIME_BALANCING));
 		} catch (IllegalArgumentException e) {
 			throw new DMLRuntimeException(String.format("Paramserv function: "
-					+ "not support '%s' runtime balancing.", getParam(PS_RUNTIME_BALANCING)));
+				+ "does not support '%s' runtime balancing.", getParam(PS_FED_RUNTIME_BALANCING)));
 		}
 	}
 
@@ -507,4 +509,32 @@
 		}
 		return federated_scheme;
 	}
+
+	/**
+	 * Calculates the number of batches per epoch depending on the balance metrics and the runtime balancing
+	 *
+	 * @param runtimeBalancing the runtime balancing
+	 * @param balanceMetrics the balance metrics calculated during data partitioning
+	 * @return numBatchesPerEpoch
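+	 *         (e.g., avgRows = 1000 and batchsize = 100 yield ceil(1000 / 100) = 10 batches per epoch for CYCLE_AVG)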
+	 */
+	private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing, DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) {
+		int numBatchesPerEpoch = 0;
+		if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
+			numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._minRows / (float) getBatchSize());
+		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
+				|| runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
+			numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._avgRows / (float) getBatchSize());
+		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
+			numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._maxRows / (float) getBatchSize());
+		}
+		return numBatchesPerEpoch;
+	}
+
+	private boolean getWeighing() {
+		return getParameterMap().containsKey(PS_FED_WEIGHING) && Boolean.parseBoolean(getParam(PS_FED_WEIGHING));
+	}
+
+	private int getSeed() {
+		return (getParameterMap().containsKey(PS_SEED)) ? Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis();
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 6a52fc4..a00e8dc 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -54,44 +54,45 @@
 	private final String _freq;
 	private final String _scheme;
 	private final String _runtime_balancing;
+	private final String _weighing;
 	private final String _data_distribution;
+	private final int _seed;
 
 	// parameters
 	@Parameterized.Parameters
 	public static Collection<Object[]> parameters() {
 		return Arrays.asList(new Object[][] {
 			// Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency, scheme, runtime balancing, weighing, data distribution, seed
-
 			// basic functionality
-			{"TwoNN", 2, 4, 1, 4, 0.01,       "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "IMBALANCED"},
-			{"CNN",   2, 4, 1, 4, 0.01,       "BSP", "EPOCH", "SHUFFLE",             "NONE" ,     "IMBALANCED"},
-			{"CNN",   2, 4, 1, 4, 0.01,       "ASP", "BATCH", "REPLICATE_TO_MAX",    "RUN_MIN" ,  "IMBALANCED"},
-			{"TwoNN", 2, 4, 1, 4, 0.01,       "ASP", "EPOCH", "BALANCE_TO_AVG",      "CYCLE_MAX", "IMBALANCED"},
-			{"TwoNN", 5, 1000, 100, 2, 0.01,  "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" ,     "BALANCED"},
+			{"TwoNN",	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"RUN_MIN" ,	"true",	"IMBALANCED",	200},
+			{"CNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "SHUFFLE", 				"NONE" , 		"true",	"IMBALANCED", 	200},
+			{"CNN",		2, 4, 1, 4, 0.01, 		"ASP", "BATCH", "REPLICATE_TO_MAX", 	"RUN_MIN" , 	"true",	"IMBALANCED",	200},
+			{"TwoNN", 	2, 4, 1, 4, 0.01, 		"ASP", "EPOCH", "BALANCE_TO_AVG", 		"CYCLE_MAX" , 	"true",	"IMBALANCED",	200},
+			{"TwoNN", 	5, 1000, 100, 2, 0.01, 	"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"NONE" , 		"true",	"BALANCED",		200},
 
-			/*
-				// runtime balancing
-				{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"RUN_MIN" , 	"IMBALANCED"},
-				{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"RUN_MIN" , 	"IMBALANCED"},
-				{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_AVG" , 	"IMBALANCED"},
-				{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_AVG" , 	"IMBALANCED"},
-				{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_MAX" , 	"IMBALANCED"},
-				{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_MAX" , 	"IMBALANCED"},
+			/* // runtime balancing
+			{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"RUN_MIN" , 	"true",	"IMBALANCED",	200},
+			{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"RUN_MIN" , 	"true",	"IMBALANCED",	200},
+			{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_AVG" , 	"true",	"IMBALANCED",	200},
+			{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_AVG" , 	"true",	"IMBALANCED",	200},
+			{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_MAX" ,	"true", "IMBALANCED",	200},
+			{"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_MAX" ,	"true", "IMBALANCED",	200},
 
-				// data partitioning
-				{"TwoNN", 2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "SHUFFLE", 				"CYCLE_AVG" , 	"IMBALANCED"},
-				{"TwoNN", 2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "REPLICATE_TO_MAX",	 	"NONE" , 		"IMBALANCED"},
-				{"TwoNN", 2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "SUBSAMPLE_TO_MIN",		"NONE" , 		"IMBALANCED"},
-				{"TwoNN", 2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "BALANCE_TO_AVG",		"NONE" , 		"IMBALANCED"},
+			// data partitioning
+			{"TwoNN", 	2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "SHUFFLE", 				"CYCLE_AVG" , 	"true",	"IMBALANCED",	200},
+			{"TwoNN", 	2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "REPLICATE_TO_MAX",	 	"NONE" , 		"true",	"IMBALANCED",	200},
+			{"TwoNN", 	2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "SUBSAMPLE_TO_MIN",		"NONE" , 		"true",	"IMBALANCED",	200},
+			{"TwoNN", 	2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "BALANCE_TO_AVG",		"NONE" , 		"true",	"IMBALANCED",	200},
 
-				// balanced tests
-				{"CNN", 	5, 1000, 100, 2, 0.01, 	"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"NONE" , 		"BALANCED"}
-			 */
+			// balanced tests
+			{"CNN", 	5, 1000, 100, 2, 0.01, 	"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"NONE" , 		"true",	"BALANCED",		200} */
+
 		});
 	}
 
 	public FederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size,
-		int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String data_distribution) {
+		int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String weighing, String data_distribution, int seed) {
+
 		_networkType = networkType;
 		_numFederatedWorkers = numFederatedWorkers;
 		_dataSetSize = dataSetSize;
@@ -102,7 +103,9 @@
 		_freq = freq;
 		_scheme = scheme;
 		_runtime_balancing = runtime_balancing;
+		_weighing = weighing;
 		_data_distribution = data_distribution;
+		_seed = seed;
 	}
 
 	@Override
@@ -185,11 +188,12 @@
 					"freq=" + _freq,
 					"scheme=" + _scheme,
 					"runtime_balancing=" + _runtime_balancing,
+					"weighing=" + _weighing,
 					"network_type=" + _networkType,
 					"channels=" + C,
 					"hin=" + Hin,
 					"win=" + Win,
-					"seed=" + 25));
+					"seed=" + _seed));
 
 			programArgs = programArgsList.toArray(new String[0]);
 			LOG.debug(runTest(null));
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml
index 69c7e76..0f9ae63 100644
--- a/src/test/scripts/functions/federated/paramserv/CNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -163,7 +163,7 @@
  */
 train_paramserv = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
-                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, string weighing,
                  double eta, int C, int Hin, int Win,
                  int seed = -1)
     return (list[unknown] model) {
@@ -211,7 +211,7 @@
     upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients",
     agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
     k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
-    scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams)
+    scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, hyperparams=hyperparams, seed=seed)
 }
 
 /*
diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index 10d2cc7..5176cca 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -26,10 +26,12 @@
 features = read($features)
 labels = read($labels)
 
+print($weighing)
+
 if($network_type == "TwoNN") {
-  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $seed)
+  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighing, $eta, $seed)
 }
 else {
-  model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $channels, $hin, $win, $seed)
+  model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighing, $eta, $channels, $hin, $win, $seed)
 }
 print(toString(model))
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
index 9bd49d8..a6dc6f2 100644
--- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -125,7 +125,7 @@
  */
 train_paramserv = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
-                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, string weighing,
                  double eta, int seed = -1)
     return (list[unknown] model) {
 
@@ -155,7 +155,7 @@
     upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
     agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
     k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
-    scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams)
+    scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, hyperparams=hyperparams, seed=seed)
 }
 
 /*