| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.controlprogram.paramserv; |
| |
| import static org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils.PS_FUNC_PREFIX; |
| |
| import java.io.Serializable; |
| import java.util.ArrayList; |
| import java.util.stream.Collectors; |
| |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.parser.DataIdentifier; |
| import org.apache.sysds.parser.Statement; |
| import org.apache.sysds.runtime.DMLRuntimeException; |
| import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock; |
| import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; |
| import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; |
| import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; |
| import org.apache.sysds.runtime.instructions.cp.CPOperand; |
| import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction; |
| |
| // TODO use the validate features and labels to calculate the model precision when training |
| public abstract class PSWorker implements Serializable |
| { |
| private static final long serialVersionUID = -3510485051178200118L; |
| |
| protected int _workerID; |
| protected int _epochs; |
| protected long _batchSize; |
| protected ExecutionContext _ec; |
| protected ParamServer _ps; |
| protected DataIdentifier _output; |
| protected FunctionCallCPInstruction _inst; |
| protected MatrixObject _features; |
| protected MatrixObject _labels; |
| protected String _updFunc; |
| protected Statement.PSFrequency _freq; |
| |
| protected PSWorker() {} |
| |
| protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) { |
| _workerID = workerID; |
| _updFunc = updFunc; |
| _freq = freq; |
| _epochs = epochs; |
| _batchSize = batchSize; |
| _ec = ec; |
| _ps = ps; |
| setupUpdateFunction(updFunc, ec); |
| } |
| |
| protected void setupUpdateFunction(String updFunc, ExecutionContext ec) { |
| // Get the update function |
| String[] cfn = ParamservUtils.getCompleteFuncName(updFunc, PS_FUNC_PREFIX); |
| String ns = cfn[0]; |
| String fname = cfn[1]; |
| FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname); |
| ArrayList<DataIdentifier> inputs = func.getInputParams(); |
| ArrayList<DataIdentifier> outputs = func.getOutputParams(); |
| CPOperand[] boundInputs = inputs.stream() |
| .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) |
| .toArray(CPOperand[]::new); |
| ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) |
| .collect(Collectors.toCollection(ArrayList::new)); |
| _inst = new FunctionCallCPInstruction(ns, fname, boundInputs, |
| func.getInputParamNames(), outputNames, "update function"); |
| |
| // Check the inputs of the update function |
| checkInput(false, inputs, DataType.MATRIX, Statement.PS_FEATURES); |
| checkInput(false, inputs, DataType.MATRIX, Statement.PS_LABELS); |
| checkInput(false, inputs, DataType.LIST, Statement.PS_MODEL); |
| checkInput(true, inputs, DataType.LIST, Statement.PS_HYPER_PARAMS); |
| |
| // Check the output of the update function |
| if (outputs.size() != 1) { |
| throw new DMLRuntimeException(String.format("The output of the '%s' function " |
| + "should provide one list containing the gradients.", updFunc)); |
| } |
| if (outputs.get(0).getDataType() != DataType.LIST) { |
| throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", updFunc)); |
| } |
| _output = outputs.get(0); |
| } |
| |
| private void checkInput(boolean optional, ArrayList<DataIdentifier> inputs, DataType dt, String pname) { |
| if (optional && inputs.stream().noneMatch(input -> pname.equals(input.getName()))) { |
| // We do not need to check if the input is optional and is not provided |
| return; |
| } |
| if (inputs.stream().filter(input -> input.getDataType() == dt && pname.equals(input.getName())).count() != 1) { |
| throw new DMLRuntimeException(String.format("The '%s' function should provide " |
| + "an input of '%s' type named '%s'.", _updFunc, dt, pname)); |
| } |
| } |
| |
| public void setFeatures(MatrixObject features) { |
| _features = features; |
| } |
| |
| public void setLabels(MatrixObject labels) { |
| _labels = labels; |
| } |
| |
| public MatrixObject getFeatures() { |
| return _features; |
| } |
| |
| public MatrixObject getLabels() { |
| return _labels; |
| } |
| |
| public abstract String getWorkerName(); |
| |
| /** |
| * ----- The following methods are dedicated to statistics ------------- |
| */ |
| protected abstract void incWorkerNumber(); |
| |
| protected abstract void accLocalModelUpdateTime(Timing time); |
| |
| protected abstract void accBatchIndexingTime(Timing time); |
| |
| protected abstract void accGradientComputeTime(Timing time); |
| |
| protected void accNumEpochs(int n) { |
| //do nothing |
| } |
| |
| protected void accNumBatches(int n) { |
| //do nothing |
| } |
| } |