| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.controlprogram.paramserv; |
| |
| import org.apache.commons.lang.NotImplementedException; |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.sysds.parser.DataIdentifier; |
| import org.apache.sysds.parser.Statement; |
| import org.apache.sysds.parser.Statement.PSFrequency; |
| import org.apache.sysds.parser.Statement.PSRuntimeBalancing; |
| import org.apache.sysds.runtime.DMLRuntimeException; |
| import org.apache.sysds.runtime.controlprogram.BasicProgramBlock; |
| import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock; |
| import org.apache.sysds.runtime.controlprogram.ProgramBlock; |
| import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; |
| import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedData; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; |
| import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; |
| import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; |
| import org.apache.sysds.runtime.functionobjects.Multiply; |
| import org.apache.sysds.runtime.instructions.Instruction; |
| import org.apache.sysds.runtime.instructions.InstructionUtils; |
| import org.apache.sysds.runtime.instructions.cp.CPOperand; |
| import org.apache.sysds.runtime.instructions.cp.Data; |
| import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.IntObject; |
| import org.apache.sysds.runtime.instructions.cp.ListObject; |
| import org.apache.sysds.runtime.instructions.cp.StringObject; |
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; |
| import org.apache.sysds.runtime.util.ProgramConverter; |
| |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.HashMap; |
| import java.util.concurrent.Callable; |
| import java.util.concurrent.Future; |
| import java.util.stream.Collectors; |
| |
| import static org.apache.sysds.runtime.util.ProgramConverter.*; |
| |
| public class FederatedPSControlThread extends PSWorker implements Callable<Void> { |
| private static final long serialVersionUID = 6846648059569648791L; |
| protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName()); |
| |
| private FederatedData _featuresData; |
| private FederatedData _labelsData; |
| private final long _localStartBatchNumVarID; |
| private final long _modelVarID; |
| |
| // runtime balancing |
| private PSRuntimeBalancing _runtimeBalancing; |
| private int _numBatchesPerEpoch; |
| private int _possibleBatchesPerLocalEpoch; |
| private boolean _weighing; |
| private double _weighingFactor = 1; |
| private boolean _cycleStartAt0 = false; |
| |
/**
 * Creates the control thread for one federated worker.
 * <p>
 * Worker id, update function, frequency, epochs, batch size, execution context and parameter
 * server are handed to the parent worker; the balancing configuration is kept locally.
 *
 * @param workerID                 id of this worker
 * @param updFunc                  name of the update function
 * @param freq                     update frequency
 * @param runtimeBalancing         runtime balancing mode
 * @param weighing                 whether gradients are scaled by the weighing factor
 * @param epochs                   number of epochs to run
 * @param batchSize                batch size
 * @param numBatchesPerGlobalEpoch number of batches to run per epoch
 * @param ec                       execution context
 * @param ps                       parameter server
 */
public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq,
	PSRuntimeBalancing runtimeBalancing, boolean weighing, int epochs, long batchSize,
	int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps)
{
	super(workerID, updFunc, freq, epochs, batchSize, ec, ps);

	// balancing and weighing configuration
	_weighing = weighing;
	_runtimeBalancing = runtimeBalancing;
	_numBatchesPerEpoch = numBatchesPerGlobalEpoch;

	// reserve the federated IDs for the batch counter (first) and the model (second);
	// the values behind these IDs get overwritten on the federated worker each time
	_localStartBatchNumVarID = FederationUtils.getNextFedDataID();
	_modelVarID = FederationUtils.getNextFedDataID();
}
| |
| /** |
| * Sets up the federated worker and control thread |
| * |
| * @param weighingFactor Gradients from this worker will be multiplied by this factor if weighing is enabled |
| */ |
| public void setup(double weighingFactor) { |
| // prepare features and labels |
| _featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0]; |
| _labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0]; |
| |
| // weighing factor is always set, but only used when weighing is specified |
| _weighingFactor = weighingFactor; |
| |
| // different runtime balancing calculations |
| long dataSize = _features.getNumRows(); |
| |
| // calculate scaled batch size if balancing via batch size. |
| // In some cases there will be some cycling |
| if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) { |
| _batchSize = (int) Math.ceil((double) dataSize / _numBatchesPerEpoch); |
| } |
| |
| // Calculate possible batches with batch size |
| _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize); |
| |
| // If no runtime balancing is specified, just run possible number of batches |
| // WARNING: Will get stuck on miss match |
| if(_runtimeBalancing == PSRuntimeBalancing.NONE) { |
| _numBatchesPerEpoch = _possibleBatchesPerLocalEpoch; |
| } |
| |
| LOG.info("Setup config for worker " + this.getWorkerName()); |
| LOG.info("Batch size: " + _batchSize + " possible batches: " + _possibleBatchesPerLocalEpoch |
| + " batches to run: " + _numBatchesPerEpoch + " weighing factor: " + _weighingFactor); |
| |
| // serialize program |
| // create program blocks for the instruction filtering |
| String programSerialized; |
| ArrayList<ProgramBlock> pbs = new ArrayList<>(); |
| |
| BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram()); |
| gradientProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_inst))); |
| pbs.add(gradientProgramBlock); |
| |
| if(_freq == PSFrequency.EPOCH) { |
| BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram()); |
| aggProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_ps.getAggInst()))); |
| pbs.add(aggProgramBlock); |
| } |
| |
| programSerialized = InstructionUtils.concatStrings( |
| PROG_BEGIN, NEWLINE, |
| ProgramConverter.serializeProgram(_ec.getProgram(), pbs, new HashMap<>(), false), |
| PROG_END); |
| |
| // write program and meta data to worker |
| Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation( |
| new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), |
| new SetupFederatedWorker(_batchSize, |
| dataSize, |
| _possibleBatchesPerLocalEpoch, |
| programSerialized, |
| _inst.getNamespace(), |
| _inst.getFunctionName(), |
| _ps.getAggInst().getFunctionName(), |
| _ec.getListObject("hyperparams"), |
| _localStartBatchNumVarID, |
| _modelVarID |
| ) |
| )); |
| |
| try { |
| FederatedResponse response = udfResponse.get(); |
| if(!response.isSuccessful()) |
| throw new DMLRuntimeException("FederatedLocalPSThread: Setup UDF failed"); |
| } |
| catch(Exception e) { |
| throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Setup UDF" + e.getMessage()); |
| } |
| } |
| |
| /** |
| * Setup UDF executed on the federated worker |
| */ |
| private static class SetupFederatedWorker extends FederatedUDF { |
| private static final long serialVersionUID = -3148991224792675607L; |
| private final long _batchSize; |
| private final long _dataSize; |
| private final int _possibleBatchesPerLocalEpoch; |
| private final String _programString; |
| private final String _namespace; |
| private final String _gradientsFunctionName; |
| private final String _aggregationFunctionName; |
| private final ListObject _hyperParams; |
| private final long _batchCounterVarID; |
| private final long _modelVarID; |
| |
| protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch, |
| String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, |
| ListObject hyperParams, long batchCounterVarID, long modelVarID) |
| { |
| super(new long[]{}); |
| _batchSize = batchSize; |
| _dataSize = dataSize; |
| _possibleBatchesPerLocalEpoch = possibleBatchesPerLocalEpoch; |
| _programString = programString; |
| _namespace = namespace; |
| _gradientsFunctionName = gradientsFunctionName; |
| _aggregationFunctionName = aggregationFunctionName; |
| _hyperParams = hyperParams; |
| _batchCounterVarID = batchCounterVarID; |
| _modelVarID = modelVarID; |
| } |
| |
| @Override |
| public FederatedResponse execute(ExecutionContext ec, Data... data) { |
| // parse and set program |
| ec.setProgram(ProgramConverter.parseProgram(_programString, 0, false)); |
| |
| // set variables to ec |
| ec.setVariable(Statement.PS_FED_BATCH_SIZE, new IntObject(_batchSize)); |
| ec.setVariable(Statement.PS_FED_DATA_SIZE, new IntObject(_dataSize)); |
| ec.setVariable(Statement.PS_FED_POSS_BATCHES_LOCAL, new IntObject(_possibleBatchesPerLocalEpoch)); |
| ec.setVariable(Statement.PS_FED_NAMESPACE, new StringObject(_namespace)); |
| ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new StringObject(_gradientsFunctionName)); |
| ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName)); |
| ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams); |
| ec.setVariable(Statement.PS_FED_BATCHCOUNTER_VARID, new IntObject(_batchCounterVarID)); |
| ec.setVariable(Statement.PS_FED_MODEL_VARID, new IntObject(_modelVarID)); |
| |
| return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); |
| } |
| } |
| |
| /** |
| * cleans up the execution context of the federated worker |
| */ |
| public void teardown() { |
| // write program and meta data to worker |
| Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation( |
| new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), |
| new TeardownFederatedWorker() |
| )); |
| |
| try { |
| FederatedResponse response = udfResponse.get(); |
| if(!response.isSuccessful()) |
| throw new DMLRuntimeException("FederatedLocalPSThread: Teardown UDF failed"); |
| } |
| catch(Exception e) { |
| throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Teardown UDF" + e.getMessage()); |
| } |
| } |
| |
| /** |
| * Teardown UDF executed on the federated worker |
| */ |
| private static class TeardownFederatedWorker extends FederatedUDF { |
| private static final long serialVersionUID = -153650281873318969L; |
| |
| protected TeardownFederatedWorker() { |
| super(new long[]{}); |
| } |
| |
| @Override |
| public FederatedResponse execute(ExecutionContext ec, Data... data) { |
| // remove variables from ec |
| ec.removeVariable(Statement.PS_FED_BATCH_SIZE); |
| ec.removeVariable(Statement.PS_FED_DATA_SIZE); |
| ec.removeVariable(Statement.PS_FED_POSS_BATCHES_LOCAL); |
| ec.removeVariable(Statement.PS_FED_NAMESPACE); |
| ec.removeVariable(Statement.PS_FED_GRADIENTS_FNAME); |
| ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME); |
| ec.removeVariable(Statement.PS_FED_BATCHCOUNTER_VARID); |
| ec.removeVariable(Statement.PS_FED_MODEL_VARID); |
| ParamservUtils.cleanupListObject(ec, Statement.PS_HYPER_PARAMS); |
| |
| return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); |
| } |
| } |
| |
| /** |
| * Entry point of the functionality |
| * |
| * @return void |
| * @throws Exception incase the execution fails |
| */ |
| @Override |
| public Void call() throws Exception { |
| try { |
| switch (_freq) { |
| case BATCH: |
| computeWithBatchUpdates(); |
| break; |
| /*case NBATCH: |
| computeWithNBatchUpdates(); |
| break; */ |
| case EPOCH: |
| computeWithEpochUpdates(); |
| break; |
| default: |
| throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq)); |
| } |
| } catch (Exception e) { |
| throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e); |
| } |
| teardown(); |
| return null; |
| } |
| |
| protected ListObject pullModel() { |
| // Pull the global parameters from ps |
| return _ps.pull(_workerID); |
| } |
| |
| protected void scaleAndPushGradients(ListObject gradients) { |
| // scale gradients - must only include MatrixObjects |
| if(_weighing && _weighingFactor != 1) { |
| gradients.getData().parallelStream().forEach((matrix) -> { |
| MatrixObject matrixObject = (MatrixObject) matrix; |
| MatrixBlock input = matrixObject.acquireReadAndRelease().scalarOperations( |
| new RightScalarOperator(Multiply.getMultiplyFnObject(), _weighingFactor), new MatrixBlock()); |
| matrixObject.acquireModify(input); |
| matrixObject.release(); |
| }); |
| } |
| |
| // Push the gradients to ps |
| _ps.push(_workerID, gradients); |
| } |
| |
| protected static int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) { |
| return currentLocalBatchNumber % possibleBatchesPerLocalEpoch; |
| } |
| |
| /** |
| * Computes all epochs and updates after each batch |
| */ |
| protected void computeWithBatchUpdates() { |
| for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) { |
| int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch; |
| |
| for (int batchCounter = 0; batchCounter < _numBatchesPerEpoch; batchCounter++) { |
| int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch); |
| ListObject model = pullModel(); |
| ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum); |
| scaleAndPushGradients(gradients); |
| ParamservUtils.cleanupListObject(model); |
| ParamservUtils.cleanupListObject(gradients); |
| LOG.info("[+] " + this.getWorkerName() + " completed BATCH " + localStartBatchNum); |
| } |
| LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter); |
| } |
| } |
| |
| /** |
| * Computes all epochs and updates after N batches |
| */ |
| protected void computeWithNBatchUpdates() { |
| throw new NotImplementedException(); |
| } |
| |
| /** |
| * Computes all epochs and updates after each epoch |
| */ |
| protected void computeWithEpochUpdates() { |
| for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) { |
| int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch; |
| |
| // Pull the global parameters from ps |
| ListObject model = pullModel(); |
| ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true); |
| scaleAndPushGradients(gradients); |
| |
| LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter); |
| ParamservUtils.cleanupListObject(model); |
| ParamservUtils.cleanupListObject(gradients); |
| } |
| } |
| |
| protected ListObject computeGradientsForNBatches(ListObject model, int numBatchesToCompute, int localStartBatchNum) { |
| return computeGradientsForNBatches(model, numBatchesToCompute, localStartBatchNum, false); |
| } |
| |
| /** |
| * Computes the gradients of n batches on the federated worker and is able to update the model local. |
| * Returns the gradients. |
| * |
| * @param model the current model from the parameter server |
| * @param localStartBatchNum the batch to start from |
| * @param localUpdate whether to update the model locally |
| * |
| * @return the gradient vector |
| */ |
| protected ListObject computeGradientsForNBatches(ListObject model, |
| int numBatchesToCompute, int localStartBatchNum, boolean localUpdate) |
| { |
| // put local start batch num on federated worker |
| Future<FederatedResponse> putBatchCounterResponse = _featuresData.executeFederatedOperation( |
| new FederatedRequest(RequestType.PUT_VAR, _localStartBatchNumVarID, new IntObject(localStartBatchNum))); |
| // put current model on federated worker |
| Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation( |
| new FederatedRequest(RequestType.PUT_VAR, _modelVarID, model)); |
| |
| try { |
| if(!putParamsResponse.get().isSuccessful() || !putBatchCounterResponse.get().isSuccessful()) |
| throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful"); |
| } |
| catch(Exception e) { |
| throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute put" + e.getMessage()); |
| } |
| |
| // create and execute the udf on the remote worker |
| Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation( |
| new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), |
| new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), |
| _localStartBatchNumVarID, _modelVarID}, numBatchesToCompute,localUpdate) |
| )); |
| |
| try { |
| Object[] responseData = udfResponse.get().getData(); |
| return (ListObject) responseData[0]; |
| } |
| catch(Exception e) { |
| throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF" + e.getMessage()); |
| } |
| } |
| |
| /** |
| * This is the code that will be executed on the federated Worker when computing one gradients for n batches |
| */ |
| private static class federatedComputeGradientsForNBatches extends FederatedUDF { |
| private static final long serialVersionUID = -3075901536748794832L; |
| int _numBatchesToCompute; |
| boolean _localUpdate; |
| |
| protected federatedComputeGradientsForNBatches(long[] inIDs, int numBatchesToCompute, boolean localUpdate) { |
| super(inIDs); |
| _numBatchesToCompute = numBatchesToCompute; |
| _localUpdate = localUpdate; |
| } |
| |
| @Override |
| public FederatedResponse execute(ExecutionContext ec, Data... data) { |
| // read in data by varid |
| MatrixObject features = (MatrixObject) data[0]; |
| MatrixObject labels = (MatrixObject) data[1]; |
| int localStartBatchNum = (int) ((IntObject) data[2]).getLongValue(); |
| ListObject model = (ListObject) data[3]; |
| |
| // get data from execution context |
| long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue(); |
| long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue(); |
| int possibleBatchesPerLocalEpoch = (int) ((IntObject) ec.getVariable(Statement.PS_FED_POSS_BATCHES_LOCAL)).getLongValue(); |
| String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue(); |
| String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue(); |
| String aggregationFuctionName = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue(); |
| |
| // recreate gradient instruction and output |
| FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false); |
| ArrayList<DataIdentifier> inputs = func.getInputParams(); |
| ArrayList<DataIdentifier> outputs = func.getOutputParams(); |
| CPOperand[] boundInputs = inputs.stream() |
| .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) |
| .toArray(CPOperand[]::new); |
| ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) |
| .collect(Collectors.toCollection(ArrayList::new)); |
| Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs, |
| func.getInputParamNames(), outputNames, "gradient function"); |
| DataIdentifier gradientsOutput = outputs.get(0); |
| |
| // recreate aggregation instruction and output if needed |
| Instruction aggregationInstruction = null; |
| DataIdentifier aggregationOutput = null; |
| if(_localUpdate && _numBatchesToCompute > 1) { |
| func = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFuctionName, false); |
| inputs = func.getInputParams(); |
| outputs = func.getOutputParams(); |
| boundInputs = inputs.stream() |
| .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) |
| .toArray(CPOperand[]::new); |
| outputNames = outputs.stream().map(DataIdentifier::getName) |
| .collect(Collectors.toCollection(ArrayList::new)); |
| aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs, |
| func.getInputParamNames(), outputNames, "aggregation function"); |
| aggregationOutput = outputs.get(0); |
| } |
| |
| ListObject accGradients = null; |
| int currentLocalBatchNumber = localStartBatchNum; |
| // prepare execution context |
| ec.setVariable(Statement.PS_MODEL, model); |
| for (int batchCounter = 0; batchCounter < _numBatchesToCompute; batchCounter++) { |
| int localBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, possibleBatchesPerLocalEpoch); |
| |
| // slice batch from feature and label matrix |
| long begin = localBatchNum * batchSize + 1; |
| long end = Math.min((localBatchNum + 1) * batchSize, dataSize); |
| MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end); |
| MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end); |
| |
| // prepare execution context |
| ec.setVariable(Statement.PS_FEATURES, bFeatures); |
| ec.setVariable(Statement.PS_LABELS, bLabels); |
| |
| // calculate gradients for batch |
| gradientsInstruction.processInstruction(ec); |
| ListObject gradients = ec.getListObject(gradientsOutput.getName()); |
| |
| // accrue the computed gradients - In the single batch case this is just a list copy |
| // is this equivalent for momentum based and AMS prob? |
| accGradients = ParamservUtils.accrueGradients(accGradients, gradients, false); |
| |
| // update the local model with gradients if needed |
| if(_localUpdate && batchCounter < _numBatchesToCompute - 1) { |
| // Invoke the aggregate function |
| assert aggregationInstruction != null; |
| aggregationInstruction.processInstruction(ec); |
| // Get the new model |
| model = ec.getListObject(aggregationOutput.getName()); |
| // Set new model in execution context |
| ec.setVariable(Statement.PS_MODEL, model); |
| // clean up gradients and result |
| ParamservUtils.cleanupListObject(ec, aggregationOutput.getName()); |
| } |
| |
| // clean up |
| ParamservUtils.cleanupListObject(ec, gradientsOutput.getName()); |
| ParamservUtils.cleanupData(ec, Statement.PS_FEATURES); |
| ParamservUtils.cleanupData(ec, Statement.PS_LABELS); |
| ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString()); |
| } |
| |
| // model clean up |
| ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString()); |
| ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL); |
| |
| return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, accGradients); |
| } |
| } |
| |
| // Statistics methods |
| @Override |
| public String getWorkerName() { |
| return String.format("Federated worker_%d", _workerID); |
| } |
| |
| @Override |
| protected void incWorkerNumber() { |
| |
| } |
| |
| @Override |
| protected void accLocalModelUpdateTime(Timing time) { |
| |
| } |
| |
| @Override |
| protected void accBatchIndexingTime(Timing time) { |
| |
| } |
| |
| @Override |
| protected void accGradientComputeTime(Timing time) { |
| |
| } |
| } |