[SYSTEMDS-2550] Federated parameter server scaling and weight handling
Closes #1141.
diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 5171f21..05bfc48 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -289,8 +289,8 @@
Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS,
Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
- Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_RUNTIME_BALANCING,
- Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
+ Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
+ Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED);
checkInvalidParameters(getOpCode(), getVarParams(), valid);
// check existence and correctness of parameters
@@ -308,9 +308,11 @@
checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT64, conditional);
checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT64, conditional);
checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
- checkStringParam(true, fname, Statement.PS_RUNTIME_BALANCING, conditional);
+ checkStringParam(true, fname, Statement.PS_FED_RUNTIME_BALANCING, conditional);
+ checkStringParam(true, fname, Statement.PS_FED_WEIGHING, conditional);
checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional);
checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional);
+ checkDataValueType(true, fname, Statement.PS_SEED, DataType.SCALAR, ValueType.INT64, conditional);
// set output characteristics
output.setDataType(DataType.LIST);
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index 6767d85..9104246 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -70,6 +70,7 @@
public static final String PS_AGGREGATION_FUN = "agg";
public static final String PS_MODE = "mode";
public static final String PS_GRADIENTS = "gradients";
+ public static final String PS_SEED = "seed";
public enum PSModeType {
FEDERATED, LOCAL, REMOTE_SPARK
}
@@ -87,9 +88,10 @@
public enum PSFrequency {
BATCH, EPOCH
}
- public static final String PS_RUNTIME_BALANCING = "runtime_balancing";
+ public static final String PS_FED_WEIGHING = "weighing";
+ public static final String PS_FED_RUNTIME_BALANCING = "runtime_balancing";
public enum PSRuntimeBalancing {
- NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH, SCALE_BATCH_AND_WEIGH
+ NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH
}
public static final String PS_EPOCHS = "epochs";
public static final String PS_BATCH_SIZE = "batchsize";
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
index 54a45d0..f09b5c2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
@@ -228,7 +228,7 @@
return val;
}
- protected final double sumValuesSparse(int valIx, SparseRow[] rows, double[] dictVals, int rowsIndex) {
+ protected static double sumValuesSparse(int valIx, SparseRow[] rows, double[] dictVals, int rowsIndex) {
throw new NotImplementedException("This Method was implemented incorrectly");
// final int numCols = getNumCols();
// final int valOff = valIx * numCols;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 393b131..48249db 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -24,6 +24,8 @@
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.Statement.PSFrequency;
+import org.apache.sysds.parser.Statement.PSRuntimeBalancing;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
@@ -37,13 +39,17 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.util.ProgramConverter;
import java.util.ArrayList;
@@ -58,21 +64,29 @@
public class FederatedPSControlThread extends PSWorker implements Callable<Void> {
private static final long serialVersionUID = 6846648059569648791L;
protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
-
- Statement.PSRuntimeBalancing _runtimeBalancing;
- FederatedData _featuresData;
- FederatedData _labelsData;
- final long _localStartBatchNumVarID;
- final long _modelVarID;
- int _numBatchesPerGlobalEpoch;
- int _possibleBatchesPerLocalEpoch;
- boolean _cycleStartAt0 = false;
- public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) {
+ private FederatedData _featuresData;
+ private FederatedData _labelsData;
+ private final long _localStartBatchNumVarID;
+ private final long _modelVarID;
+
+ // runtime balancing
+ private PSRuntimeBalancing _runtimeBalancing;
+ private int _numBatchesPerEpoch;
+ private int _possibleBatchesPerLocalEpoch;
+ private boolean _weighing;
+ private double _weighingFactor = 1;
+ private boolean _cycleStartAt0 = false;
+
+ public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq,
+ PSRuntimeBalancing runtimeBalancing, boolean weighing, int epochs, long batchSize,
+ int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps)
+ {
super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
- _numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch;
+ _numBatchesPerEpoch = numBatchesPerGlobalEpoch;
_runtimeBalancing = runtimeBalancing;
+ _weighing = weighing;
// generate the IDs for model and batch counter. These get overwritten on the federated worker each time
_localStartBatchNumVarID = FederationUtils.getNextFedDataID();
_modelVarID = FederationUtils.getNextFedDataID();
@@ -80,65 +94,72 @@
/**
* Sets up the federated worker and control thread
+ *
+ * @param weighingFactor Gradients from this worker will be multiplied by this factor if weighing is enabled
*/
- public void setup() {
+ public void setup(double weighingFactor) {
// prepare features and labels
_featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0];
_labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0];
- // calculate number of batches and get data size
+ // weighing factor is always set, but only used when weighing is specified
+ _weighingFactor = weighingFactor;
+
+ // different runtime balancing calculations
long dataSize = _features.getNumRows();
- _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
- if(!(_runtimeBalancing == Statement.PSRuntimeBalancing.RUN_MIN
- || _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_AVG
- || _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_MAX)) {
- _numBatchesPerGlobalEpoch = _possibleBatchesPerLocalEpoch;
+
+ // calculate scaled batch size if balancing via batch size.
+ // In some cases there will be some cycling
+ if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
+ _batchSize = (int) Math.ceil((double) dataSize / _numBatchesPerEpoch);
}
- if(_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH
- || _runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH_AND_WEIGH) {
- throw new NotImplementedException();
+ // Calculate possible batches with batch size
+ _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
+
+ // If no runtime balancing is specified, just run possible number of batches
+ // WARNING: Will get stuck on miss match
+ if(_runtimeBalancing == PSRuntimeBalancing.NONE) {
+ _numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
}
+ LOG.info("Setup config for worker " + this.getWorkerName());
+ LOG.info("Batch size: " + _batchSize + " possible batches: " + _possibleBatchesPerLocalEpoch
+ + " batches to run: " + _numBatchesPerEpoch + " weighing factor: " + _weighingFactor);
+
// serialize program
// create program blocks for the instruction filtering
String programSerialized;
- ArrayList<ProgramBlock> programBlocks = new ArrayList<>();
+ ArrayList<ProgramBlock> pbs = new ArrayList<>();
BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram());
gradientProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_inst)));
- programBlocks.add(gradientProgramBlock);
+ pbs.add(gradientProgramBlock);
- if(_freq == Statement.PSFrequency.EPOCH) {
+ if(_freq == PSFrequency.EPOCH) {
BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram());
aggProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_ps.getAggInst())));
- programBlocks.add(aggProgramBlock);
+ pbs.add(aggProgramBlock);
}
- StringBuilder sb = new StringBuilder();
- sb.append(PROG_BEGIN);
- sb.append( NEWLINE );
- sb.append(ProgramConverter.serializeProgram(_ec.getProgram(),
- programBlocks,
- new HashMap<>(),
- false
- ));
- sb.append(PROG_END);
- programSerialized = sb.toString();
+ programSerialized = InstructionUtils.concatStrings(
+ PROG_BEGIN, NEWLINE,
+ ProgramConverter.serializeProgram(_ec.getProgram(), pbs, new HashMap<>(), false),
+ PROG_END);
// write program and meta data to worker
Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
new SetupFederatedWorker(_batchSize,
- dataSize,
- _possibleBatchesPerLocalEpoch,
- programSerialized,
- _inst.getNamespace(),
- _inst.getFunctionName(),
- _ps.getAggInst().getFunctionName(),
- _ec.getListObject("hyperparams"),
- _localStartBatchNumVarID,
- _modelVarID
+ dataSize,
+ _possibleBatchesPerLocalEpoch,
+ programSerialized,
+ _inst.getNamespace(),
+ _inst.getFunctionName(),
+ _ps.getAggInst().getFunctionName(),
+ _ec.getListObject("hyperparams"),
+ _localStartBatchNumVarID,
+ _modelVarID
)
));
@@ -286,12 +307,23 @@
return _ps.pull(_workerID);
}
- protected void pushGradients(ListObject gradients) {
+ protected void scaleAndPushGradients(ListObject gradients) {
+ // scale gradients - must only include MatrixObjects
+ if(_weighing && _weighingFactor != 1) {
+ gradients.getData().parallelStream().forEach((matrix) -> {
+ MatrixObject matrixObject = (MatrixObject) matrix;
+ MatrixBlock input = matrixObject.acquireReadAndRelease().scalarOperations(
+ new RightScalarOperator(Multiply.getMultiplyFnObject(), _weighingFactor), new MatrixBlock());
+ matrixObject.acquireModify(input);
+ matrixObject.release();
+ });
+ }
+
// Push the gradients to ps
_ps.push(_workerID, gradients);
}
- static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
+ protected static int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
return currentLocalBatchNumber % possibleBatchesPerLocalEpoch;
}
@@ -300,18 +332,18 @@
*/
protected void computeWithBatchUpdates() {
for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
- int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+ int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
- for (int batchCounter = 0; batchCounter < _numBatchesPerGlobalEpoch; batchCounter++) {
+ for (int batchCounter = 0; batchCounter < _numBatchesPerEpoch; batchCounter++) {
int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
ListObject model = pullModel();
ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum);
- pushGradients(gradients);
+ scaleAndPushGradients(gradients);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
+ LOG.info("[+] " + this.getWorkerName() + " completed BATCH " + localStartBatchNum);
}
- if( LOG.isInfoEnabled() )
- LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+ LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter);
}
}
@@ -327,15 +359,14 @@
*/
protected void computeWithEpochUpdates() {
for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
- int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+ int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
// Pull the global parameters from ps
ListObject model = pullModel();
- ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerGlobalEpoch, localStartBatchNum, true);
- pushGradients(gradients);
-
- if( LOG.isInfoEnabled() )
- LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+ ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true);
+ scaleAndPushGradients(gradients);
+
+ LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
}
@@ -424,12 +455,12 @@
ArrayList<DataIdentifier> inputs = func.getInputParams();
ArrayList<DataIdentifier> outputs = func.getOutputParams();
CPOperand[] boundInputs = inputs.stream()
- .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
- .toArray(CPOperand[]::new);
+ .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+ .toArray(CPOperand[]::new);
ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
- .collect(Collectors.toCollection(ArrayList::new));
+ .collect(Collectors.toCollection(ArrayList::new));
Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
- func.getInputParamNames(), outputNames, "gradient function");
+ func.getInputParamNames(), outputNames, "gradient function");
DataIdentifier gradientsOutput = outputs.get(0);
// recreate aggregation instruction and output if needed
@@ -440,12 +471,12 @@
inputs = func.getInputParams();
outputs = func.getOutputParams();
boundInputs = inputs.stream()
- .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
- .toArray(CPOperand[]::new);
+ .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+ .toArray(CPOperand[]::new);
outputNames = outputs.stream().map(DataIdentifier::getName)
- .collect(Collectors.toCollection(ArrayList::new));
+ .collect(Collectors.toCollection(ArrayList::new));
aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
- func.getInputParamNames(), outputNames, "aggregation function");
+ func.getInputParamNames(), outputNames, "aggregation function");
aggregationOutput = outputs.get(0);
}
@@ -492,8 +523,6 @@
ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
- if( LOG.isInfoEnabled() )
- LOG.info("[+]" + " completed batch " + localBatchNum);
}
// model clean up
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
index 460faba..34e94f0 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -20,7 +20,6 @@
package org.apache.sysds.runtime.controlprogram.paramserv.dp;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
@@ -35,13 +34,25 @@
import java.util.List;
import java.util.concurrent.Future;
+/**
+ * Balance to Avg Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global average number of examples is taken
+ * and the worker subsamples or replicates data to match that number of examples. See the other federated schemes.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
@Override
- public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+ public Result partition(MatrixObject features, MatrixObject labels, int seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+ BalanceMetrics balanceMetricsBefore = getBalanceMetrics(pFeatures);
+ List<Double> weighingFactors = getWeighingFactors(pFeatures, balanceMetricsBefore);
- int average_num_rows = (int) Math.round(pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN));
+ int average_num_rows = (int) balanceMetricsBefore._avgRows;
for(int i = 0; i < pFeatures.size(); i++) {
// Works, because the map contains a single entry
@@ -49,7 +60,7 @@
FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
- featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, average_num_rows)));
+ featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, average_num_rows)));
try {
FederatedResponse response = udfResponse.get();
@@ -66,7 +77,7 @@
pLabels.get(i).updateDataCharacteristics(update);
}
- return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+ return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors);
}
/**
@@ -74,10 +85,12 @@
*/
private static class balanceDataOnFederatedWorker extends FederatedUDF {
private static final long serialVersionUID = 6631958250346625546L;
+ private final int _seed;
private final int _average_num_rows;
-
- protected balanceDataOnFederatedWorker(long[] inIDs, int average_num_rows) {
+
+ protected balanceDataOnFederatedWorker(long[] inIDs, int seed, int average_num_rows) {
super(inIDs);
+ _seed = seed;
_average_num_rows = average_num_rows;
}
@@ -88,14 +101,14 @@
if(features.getNumRows() > _average_num_rows) {
// generate subsampling matrix
- MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+ MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), _seed);
subsampleTo(features, subsampleMatrixBlock);
subsampleTo(labels, subsampleMatrixBlock);
}
else if(features.getNumRows() < _average_num_rows) {
int num_rows_needed = _average_num_rows - Math.toIntExact(features.getNumRows());
// generate replication matrix
- MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+ MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), _seed);
replicateTo(features, replicateMatrixBlock);
replicateTo(labels, replicateMatrixBlock);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index f5c9638..e00923e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -45,16 +45,31 @@
public final List<MatrixObject> _pLabels;
public final int _workerNum;
public final BalanceMetrics _balanceMetrics;
+ public final List<Double> _weighingFactors;
- public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum, BalanceMetrics balanceMetrics) {
- this._pFeatures = pFeatures;
- this._pLabels = pLabels;
- this._workerNum = workerNum;
- this._balanceMetrics = balanceMetrics;
+
+ public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum, BalanceMetrics balanceMetrics, List<Double> weighingFactors) {
+ _pFeatures = pFeatures;
+ _pLabels = pLabels;
+ _workerNum = workerNum;
+ _balanceMetrics = balanceMetrics;
+ _weighingFactors = weighingFactors;
}
}
- public abstract Result doPartitioning(MatrixObject features, MatrixObject labels);
+ public static final class BalanceMetrics {
+ public final long _minRows;
+ public final long _avgRows;
+ public final long _maxRows;
+
+ public BalanceMetrics(long minRows, long avgRows, long maxRows) {
+ _minRows = minRows;
+ _avgRows = avgRows;
+ _maxRows = maxRows;
+ }
+ }
+
+ public abstract Result partition(MatrixObject features, MatrixObject labels, int seed);
/**
* Takes a row federated Matrix and slices it into a matrix for each worker
@@ -110,16 +125,12 @@
return new BalanceMetrics(minRows, sum / slices.size(), maxRows);
}
- public static final class BalanceMetrics {
- public final long _minRows;
- public final long _avgRows;
- public final long _maxRows;
-
- public BalanceMetrics(long minRows, long avgRows, long maxRows) {
- this._minRows = minRows;
- this._avgRows = avgRows;
- this._maxRows = maxRows;
- }
+ static List<Double> getWeighingFactors(List<MatrixObject> pFeatures, BalanceMetrics balanceMetrics) {
+ List<Double> weighingFactors = new ArrayList<>();
+ pFeatures.forEach((feature) -> {
+ weighingFactors.add((double) feature.getNumRows() / balanceMetrics._avgRows);
+ });
+ return weighingFactors;
}
/**
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
index d1ebb6c..ce2f954 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
@@ -24,10 +24,11 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
public class FederatedDataPartitioner {
-
private final DataPartitionFederatedScheme _scheme;
+ private final int _seed;
- public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) {
+ public FederatedDataPartitioner(Statement.FederatedPSScheme scheme, int seed) {
+ _seed = seed;
switch (scheme) {
case KEEP_DATA_ON_WORKER:
_scheme = new KeepDataOnWorkerFederatedScheme();
@@ -50,6 +51,6 @@
}
public DataPartitionFederatedScheme.Result doPartitioning(MatrixObject features, MatrixObject labels) {
- return _scheme.doPartitioning(features, labels);
+ return _scheme.partition(features, labels, _seed);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
index e306f25..afbaf4d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -22,11 +22,20 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import java.util.List;
+/**
+ * Keep Data on Worker Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers.
+ * All entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedScheme {
@Override
- public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+ public Result partition(MatrixObject features, MatrixObject labels, int seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
- return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+ BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
+ List<Double> weighingFactors = getWeighingFactors(pFeatures, balanceMetrics);
+ return new Result(pFeatures, pLabels, pFeatures.size(), balanceMetrics, weighingFactors);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
index 068cfa9..a1b8f6c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -34,11 +34,23 @@
import java.util.List;
import java.util.concurrent.Future;
+/**
+ * Replicate to Max Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global maximum number of examples is taken
+ * and the worker replicates data to match that number of examples. The generation is done by multiplying with a
+ * Permutation Matrix with a global seed. These selected examples are appended to the original data.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme {
@Override
- public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+ public Result partition(MatrixObject features, MatrixObject labels, int seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+ List<Double> weighingFactors = getWeighingFactors(pFeatures, getBalanceMetrics(pFeatures));
int max_rows = 0;
for (MatrixObject pFeature : pFeatures) {
@@ -51,7 +63,7 @@
FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
- featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, max_rows)));
+ featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, max_rows)));
try {
FederatedResponse response = udfResponse.get();
@@ -68,7 +80,7 @@
pLabels.get(i).updateDataCharacteristics(update);
}
- return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+ return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors);
}
/**
@@ -76,10 +88,12 @@
*/
private static class replicateDataOnFederatedWorker extends FederatedUDF {
private static final long serialVersionUID = -6930898456315100587L;
+ private final int _seed;
private final int _max_rows;
-
- protected replicateDataOnFederatedWorker(long[] inIDs, int max_rows) {
+
+ protected replicateDataOnFederatedWorker(long[] inIDs, int seed, int max_rows) {
super(inIDs);
+ _seed = seed;
_max_rows = max_rows;
}
@@ -92,7 +106,7 @@
if(features.getNumRows() < _max_rows) {
int num_rows_needed = _max_rows - Math.toIntExact(features.getNumRows());
// generate replication matrix
- MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+ MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), _seed);
replicateTo(features, replicateMatrixBlock);
replicateTo(labels, replicateMatrixBlock);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index 65ef69d..1920593 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -33,11 +33,23 @@
import java.util.List;
import java.util.concurrent.Future;
+/**
+ * Shuffle Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case it is shuffled by generating a permutation
+ * matrix with a global seed and doing a mat mult.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
@Override
- public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+ public Result partition(MatrixObject features, MatrixObject labels, int seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+ BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
+ List<Double> weighingFactors = getWeighingFactors(pFeatures, balanceMetrics);
for(int i = 0; i < pFeatures.size(); i++) {
// Works, because the map contains a single entry
@@ -45,7 +57,7 @@
FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
- featuresData.getVarID(), new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()})));
+ featuresData.getVarID(), new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed)));
try {
FederatedResponse response = udfResponse.get();
@@ -57,7 +69,7 @@
}
}
- return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+ return new Result(pFeatures, pLabels, pFeatures.size(), balanceMetrics, weighingFactors);
}
/**
@@ -65,9 +77,11 @@
*/
private static class shuffleDataOnFederatedWorker extends FederatedUDF {
private static final long serialVersionUID = 3228664618781333325L;
+ private final int _seed;
- protected shuffleDataOnFederatedWorker(long[] inIDs) {
+ protected shuffleDataOnFederatedWorker(long[] inIDs, int seed) {
super(inIDs);
+ _seed = seed;
}
@Override
@@ -76,7 +90,7 @@
MatrixObject labels = (MatrixObject) data[1];
// generate permutation matrix
- MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+ MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), _seed);
shuffle(features, permutationMatrixBlock);
shuffle(labels, permutationMatrixBlock);
return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
index 9b62cc8..937c37e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -34,11 +34,23 @@
import java.util.List;
import java.util.concurrent.Future;
+/**
+ * Subsample to Min Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global minimum number of examples is taken
+ * and the worker subsamples data to match that number of examples. The subsampling is done by multiplying with a
+ * Permutation Matrix with a global seed.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme {
@Override
- public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+ public Result partition(MatrixObject features, MatrixObject labels, int seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+ List<Double> weighingFactors = getWeighingFactors(pFeatures, getBalanceMetrics(pFeatures));
int min_rows = Integer.MAX_VALUE;
for (MatrixObject pFeature : pFeatures) {
@@ -51,7 +63,7 @@
FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
- featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, min_rows)));
+ featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, min_rows)));
try {
FederatedResponse response = udfResponse.get();
@@ -68,7 +80,7 @@
pLabels.get(i).updateDataCharacteristics(update);
}
- return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+ return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors);
}
/**
@@ -76,10 +88,12 @@
*/
private static class subsampleDataOnFederatedWorker extends FederatedUDF {
private static final long serialVersionUID = 2213790859544004286L;
+ private final int _seed;
private final int _min_rows;
-
- protected subsampleDataOnFederatedWorker(long[] inIDs, int min_rows) {
+
+ protected subsampleDataOnFederatedWorker(long[] inIDs, int seed, int min_rows) {
super(inIDs);
+ _seed = seed;
_min_rows = min_rows;
}
@@ -91,7 +105,7 @@
// subsample down to minimum
if(features.getNumRows() > _min_rows) {
// generate subsampling matrix
- MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_min_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+ MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_min_rows, Math.toIntExact(features.getNumRows()), _seed);
subsampleTo(features, subsampleMatrixBlock);
subsampleTo(labels, subsampleMatrixBlock);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index a2b8d9f..a66e039 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -19,6 +19,17 @@
package org.apache.sysds.runtime.instructions.cp;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN;
import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE;
import static org.apache.sysds.parser.Statement.PS_EPOCHS;
@@ -32,18 +43,9 @@
import static org.apache.sysds.parser.Statement.PS_SCHEME;
import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
-import static org.apache.sysds.parser.Statement.PS_RUNTIME_BALANCING;
-
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
+import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
+import static org.apache.sysds.parser.Statement.PS_FED_WEIGHING;
+import static org.apache.sysds.parser.Statement.PS_SEED;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.commons.logging.Log;
@@ -121,37 +123,36 @@
}
private void runFederated(ExecutionContext ec) {
- System.out.println("PARAMETER SERVER");
- System.out.println("[+] Running in federated mode");
+ LOG.info("PARAMETER SERVER");
+ LOG.info("[+] Running in federated mode");
// get inputs
- PSFrequency freq = getFrequency();
- PSUpdateType updateType = getUpdateType();
- PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
- FederatedPSScheme federatedPSScheme = getFederatedScheme();
String updFunc = getParam(PS_UPDATE_FUN);
String aggFunc = getParam(PS_AGGREGATION_FUN);
+ PSUpdateType updateType = getUpdateType();
+ PSFrequency freq = getFrequency();
+ FederatedPSScheme federatedPSScheme = getFederatedScheme();
+ PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
+ boolean weighing = getWeighing();
+ int seed = getSeed();
- // partition federated data
- DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme)
- .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
- List<MatrixObject> pFeatures = result._pFeatures;
- List<MatrixObject> pLabels = result._pLabels;
- int workerNum = result._workerNum;
-
- // calculate runtime balancing
- int numBatchesPerEpoch = 0;
- if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
- numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize());
- } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG) {
- numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize());
- } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
- numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize());
+ if( LOG.isInfoEnabled() ) {
+ LOG.info("[+] Update Type: " + updateType);
+ LOG.info("[+] Frequency: " + freq);
+ LOG.info("[+] Data Partitioning: " + federatedPSScheme);
+ LOG.info("[+] Runtime Balancing: " + runtimeBalancing);
+ LOG.info("[+] Weighing: " + weighing);
+ LOG.info("[+] Seed: " + seed);
}
+
+ // partition federated data
+ DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme, seed)
+ .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
+ int workerNum = result._workerNum;
// setup threading
BasicThreadFactory factory = new BasicThreadFactory.Builder()
- .namingPattern("workers-pool-thread-%d").build();
+ .namingPattern("workers-pool-thread-%d").build();
ExecutorService es = Executors.newFixedThreadPool(workerNum, factory);
// Get the compiled execution context
@@ -166,10 +167,11 @@
ListObject model = ec.getListObject(getParam(PS_MODEL));
ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC);
// Create the local workers
- int finalNumBatchesPerEpoch = numBatchesPerEpoch;
+ int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
- .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
- .collect(Collectors.toList());
+ .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighing,
+ getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
+ .collect(Collectors.toList());
if(workerNum != threads.size()) {
throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!");
@@ -177,9 +179,9 @@
// Set features and lables for the control threads and write the program and instructions and hyperparams to the federated workers
for (int i = 0; i < threads.size(); i++) {
- threads.get(i).setFeatures(pFeatures.get(i));
- threads.get(i).setLabels(pLabels.get(i));
- threads.get(i).setup();
+ threads.get(i).setFeatures(result._pFeatures.get(i));
+ threads.get(i).setLabels(result._pLabels.get(i));
+ threads.get(i).setup(result._weighingFactors.get(i));
}
try {
@@ -395,14 +397,14 @@
}
private PSRuntimeBalancing getRuntimeBalancing() {
- if (!getParameterMap().containsKey(PS_RUNTIME_BALANCING)) {
+ if (!getParameterMap().containsKey(PS_FED_RUNTIME_BALANCING)) {
return DEFAULT_RUNTIME_BALANCING;
}
try {
- return PSRuntimeBalancing.valueOf(getParam(PS_RUNTIME_BALANCING));
+ return PSRuntimeBalancing.valueOf(getParam(PS_FED_RUNTIME_BALANCING));
} catch (IllegalArgumentException e) {
throw new DMLRuntimeException(String.format("Paramserv function: "
- + "not support '%s' runtime balancing.", getParam(PS_RUNTIME_BALANCING)));
+ + "not support '%s' runtime balancing.", getParam(PS_FED_RUNTIME_BALANCING)));
}
}
@@ -507,4 +509,32 @@
}
return federated_scheme;
}
+
+ /**
+ * Calculates the number of batches per epoch depending on the balance metrics and the runtime balancing
+ *
+ * @param runtimeBalancing the runtime balancing
+ * @param balanceMetrics the balance metrics calculated during data partitioning
+ * @return numBatchesPerEpoch
+ */
+ private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing, DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) {
+ int numBatchesPerEpoch = 0;
+ if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
+ numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._minRows / (float) getBatchSize());
+ } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
+ || runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
+ numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._avgRows / (float) getBatchSize());
+ } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
+ numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._maxRows / (float) getBatchSize());
+ }
+ return numBatchesPerEpoch;
+ }
+
+ private boolean getWeighing() {
+ return getParameterMap().containsKey(PS_FED_WEIGHING) && Boolean.parseBoolean(getParam(PS_FED_WEIGHING));
+ }
+
+ private int getSeed() {
+ return (getParameterMap().containsKey(PS_SEED)) ? Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis();
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 6a52fc4..a00e8dc 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -54,44 +54,45 @@
private final String _freq;
private final String _scheme;
private final String _runtime_balancing;
+ private final String _weighing;
private final String _data_distribution;
+ private final int _seed;
// parameters
@Parameterized.Parameters
public static Collection<Object[]> parameters() {
return Arrays.asList(new Object[][] {
// Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency
-
// basic functionality
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "IMBALANCED"},
- {"CNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "SHUFFLE", "NONE" , "IMBALANCED"},
- {"CNN", 2, 4, 1, 4, 0.01, "ASP", "BATCH", "REPLICATE_TO_MAX", "RUN_MIN" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01, "ASP", "EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX", "IMBALANCED"},
- {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200},
+ {"CNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "SHUFFLE", "NONE" , "true", "IMBALANCED", 200},
+ {"CNN", 2, 4, 1, 4, 0.01, "ASP", "BATCH", "REPLICATE_TO_MAX", "RUN_MIN" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "ASP", "EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX" , "true", "IMBALANCED", 200},
+ {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED", 200},
- /*
- // runtime balancing
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"},
+ /* // runtime balancing
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED", 200},
- // data partitioning
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE_AVG" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE_TO_MAX", "NONE" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE_TO_AVG", "NONE" , "IMBALANCED"},
+ // data partitioning
+ {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE_AVG" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE_TO_MAX", "NONE" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE_TO_AVG", "NONE" , "true", "IMBALANCED", 200},
- // balanced tests
- {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}
- */
+ // balanced tests
+ {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED", 200} */
+
});
}
public FederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size,
- int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String data_distribution) {
+ int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String weighing, String data_distribution, int seed) {
+
_networkType = networkType;
_numFederatedWorkers = numFederatedWorkers;
_dataSetSize = dataSetSize;
@@ -102,7 +103,9 @@
_freq = freq;
_scheme = scheme;
_runtime_balancing = runtime_balancing;
+ _weighing = weighing;
_data_distribution = data_distribution;
+ _seed = seed;
}
@Override
@@ -185,11 +188,12 @@
"freq=" + _freq,
"scheme=" + _scheme,
"runtime_balancing=" + _runtime_balancing,
+ "weighing=" + _weighing,
"network_type=" + _networkType,
"channels=" + C,
"hin=" + Hin,
"win=" + Win,
- "seed=" + 25));
+ "seed=" + _seed));
programArgs = programArgsList.toArray(new String[0]);
LOG.debug(runTest(null));
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml
index 69c7e76..0f9ae63 100644
--- a/src/test/scripts/functions/federated/paramserv/CNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -163,7 +163,7 @@
*/
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
- int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+ int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, string weighing,
double eta, int C, int Hin, int Win,
int seed = -1)
return (list[unknown] model) {
@@ -211,7 +211,7 @@
upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients",
agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
- scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams)
+ scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, hyperparams=hyperparams, seed=seed)
}
/*
diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index 10d2cc7..5176cca 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -26,10 +26,12 @@
features = read($features)
labels = read($labels)
+print($weighing)
+
if($network_type == "TwoNN") {
- model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $seed)
+ model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighing, $eta, $seed)
}
else {
- model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $channels, $hin, $win, $seed)
+ model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighing, $eta, $channels, $hin, $win, $seed)
}
print(toString(model))
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
index 9bd49d8..a6dc6f2 100644
--- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -125,7 +125,7 @@
*/
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
- int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+ int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, string weighing,
double eta, int seed = -1)
return (list[unknown] model) {
@@ -155,7 +155,7 @@
upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
- scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams)
+ scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, hyperparams=hyperparams, seed=seed)
}
/*