blob: 3c5cbb1051ccffb7cbf7853bea2701d65e90a2c4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.reef.examples.group.bgd;
import org.apache.reef.examples.group.bgd.operatornames.*;
import org.apache.reef.examples.group.bgd.parameters.*;
import org.apache.reef.examples.group.bgd.utils.StepSizes;
import org.apache.reef.examples.group.utils.math.DenseVector;
import org.apache.reef.examples.group.utils.math.Vector;
import org.apache.reef.examples.group.utils.timer.Timer;
import org.apache.reef.exception.evaluator.NetworkException;
import org.apache.reef.io.Tuple;
import org.apache.reef.io.network.group.api.operators.Broadcast;
import org.apache.reef.io.network.group.api.operators.Reduce;
import org.apache.reef.io.network.group.api.GroupChanges;
import org.apache.reef.io.network.group.api.task.CommunicationGroupClient;
import org.apache.reef.io.network.group.api.task.GroupCommClient;
import org.apache.reef.io.network.util.Pair;
import org.apache.reef.io.serialization.Codec;
import org.apache.reef.io.serialization.SerializableCodec;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.task.Task;
import javax.inject.Inject;
import java.util.ArrayList;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Master task for BGD example.
 *
 * <p>The master drives the optimization loop. In every iteration it
 * <ol>
 *   <li>broadcasts a control message (plus either the model or the last step size) and
 *       reduces the loss and gradient reported back,</li>
 *   <li>turns the gradient into a descent direction,</li>
 *   <li>broadcasts the descent direction, reduces the line search evaluations and
 *       picks the step size with the smallest regularized loss,</li>
 *   <li>applies that step to its own copy of the model.</li>
 * </ol>
 * When the loop ends it broadcasts {@code ControlMessages.Stop} and returns the encoded
 * list of per-iteration objective values.
 *
 * <p>After every broadcast/reduce round the master checks for topology changes; when changes
 * exist the topology is updated and the full model is re-sent with the next broadcast.
 */
public class MasterTask implements Task {

  public static final String TASK_ID = "MasterTask";

  private static final Logger LOG = Logger.getLogger(MasterTask.class.getName());

  /** The loop in {@link #call(byte[])} stops once the descent direction norm is at or below this value. */
  private static final double GRADIENT_NORM_TOLERANCE = 1e-3;

  private final CommunicationGroupClient communicationGroupClient;
  private final Broadcast.Sender<ControlMessages> controlMessageBroadcaster;
  private final Broadcast.Sender<Vector> modelBroadcaster;
  private final Reduce.Receiver<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer;
  private final Broadcast.Sender<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster;
  private final Broadcast.Sender<Vector> descentDirectionBroadcaster;
  private final Reduce.Receiver<Pair<Vector, Integer>> lineSearchEvaluationsReducer;
  private final Broadcast.Sender<Double> minEtaBroadcaster;

  /** When {@code false}, a broadcast/reduce round is repeated whenever the topology changed during it. */
  private final boolean ignoreAndContinue;
  private final StepSizes ts;
  private final double lambda;
  private final int maxIters;

  /** Objective value of every completed iteration, in order. */
  private final ArrayList<Double> losses = new ArrayList<>();
  private final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<>();
  private final Vector model;

  /** Whether the next broadcast has to carry the full model; set from {@link #chkAndUpdate()}. */
  private boolean sendModel = true;

  /** Step size chosen by the most recent line search. */
  private double minEta = 0;

  /**
   * Creates the master task and looks up all group communication operators it uses.
   *
   * @param groupCommClient client used to obtain the {@code AllCommunicationGroup}
   * @param dimensions      size of the model vector
   * @param lambda          regularization weight
   * @param maxIters        iteration bound checked by {@link #converged(int, double)}
   * @param rampup          stored as {@code ignoreAndContinue}; when {@code true}, rounds are not
   *                        repeated just because the topology changed
   * @param ts              candidate step sizes for the line search
   */
  @Inject
  public MasterTask(
      final GroupCommClient groupCommClient,
      @Parameter(ModelDimensions.class) final int dimensions,
      @Parameter(Lambda.class) final double lambda,
      @Parameter(Iterations.class) final int maxIters,
      @Parameter(EnableRampup.class) final boolean rampup,
      final StepSizes ts) {
    this.lambda = lambda;
    this.maxIters = maxIters;
    this.ts = ts;
    this.ignoreAndContinue = rampup;
    this.model = new DenseVector(dimensions);
    this.communicationGroupClient = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class);
    this.controlMessageBroadcaster = communicationGroupClient.getBroadcastSender(ControlMessageBroadcaster.class);
    this.modelBroadcaster = communicationGroupClient.getBroadcastSender(ModelBroadcaster.class);
    this.lossAndGradientReducer = communicationGroupClient.getReduceReceiver(LossAndGradientReducer.class);
    this.modelAndDescentDirectionBroadcaster =
        communicationGroupClient.getBroadcastSender(ModelAndDescentDirectionBroadcaster.class);
    this.descentDirectionBroadcaster = communicationGroupClient.getBroadcastSender(DescentDirectionBroadcaster.class);
    this.lineSearchEvaluationsReducer = communicationGroupClient.getReduceReceiver(LineSearchEvaluationsReducer.class);
    this.minEtaBroadcaster = communicationGroupClient.getBroadcastSender(MinEtaBroadcaster.class);
  }

  /**
   * Runs the optimization loop, then tells the group to stop.
   *
   * <p>The iteration counter starts at 1 and the loop ends as soon as it reaches
   * {@code maxIters}, so at most {@code maxIters - 1} iterations are executed.
   *
   * @param memento not used by this task
   * @return the objective values of all iterations, encoded with the loss codec
   * @throws Exception if a broadcast or reduce operation fails or is interrupted
   */
  @Override
  public byte[] call(final byte[] memento) throws Exception {
    // Start with the largest possible norm so the first convergence check cannot succeed on it.
    double gradientNorm = Double.MAX_VALUE;
    for (int iteration = 1; !converged(iteration, gradientNorm); ++iteration) {
      try (Timer t = new Timer("Current Iteration(" + iteration + ")")) {
        final Pair<Double, Vector> lossAndGradient = computeLossAndGradient();
        losses.add(lossAndGradient.getFirst());
        final Vector descentDirection = getDescentDirection(lossAndGradient.getSecond());

        updateModel(descentDirection);

        gradientNorm = descentDirection.norm2();
      }
    }
    LOG.log(Level.INFO, "OUT: Stop");
    controlMessageBroadcaster.send(ControlMessages.Stop);

    for (final Double loss : losses) {
      LOG.log(Level.INFO, "OUT: LOSS = {0}", loss);
    }
    return lossCodec.encode(losses);
  }

  /**
   * Runs a line search along {@code descentDirection}, remembers the chosen step size in
   * {@code minEta} and moves the model by that step.
   *
   * @param descentDirection direction to move the model in
   * @throws NetworkException     if a group communication operation fails
   * @throws InterruptedException if a group communication operation is interrupted
   */
  private void updateModel(final Vector descentDirection) throws NetworkException, InterruptedException {
    try (Timer t = new Timer("GetDescentDirection + FindMinEta + UpdateModel")) {
      final Vector lineSearchEvals = lineSearch(descentDirection);
      minEta = findMinEta(model, descentDirection, lineSearchEvals);
      model.multAdd(minEta, descentDirection);
    }

    LOG.log(Level.INFO, "OUT: New Model = {0}", model);
  }

  /**
   * Broadcasts the descent direction (together with the model when {@code sendModel} is set) and
   * reduces the line search evaluations, scaled by the number of examples.
   *
   * <p>The round is repeated while the reduce yields {@code null} (tracked as "all dead"), or
   * while the topology keeps changing and {@code ignoreAndContinue} is off.
   *
   * @param descentDirection direction to evaluate the line search along
   * @return the line search evaluations divided by the reported example count
   * @throws NetworkException     if a group communication operation fails
   * @throws InterruptedException if a group communication operation is interrupted
   */
  private Vector lineSearch(final Vector descentDirection) throws NetworkException, InterruptedException {
    Vector lineSearchResults = null;
    boolean allDead = false;
    do {
      try (Timer t = new Timer("LineSearch - Broadcast("
          + (sendModel ? "ModelAndDescentDirection" : "DescentDirection") + ") + Reduce(LossEvalsInLineSearch)")) {
        if (sendModel) {
          LOG.log(Level.INFO, "OUT: DoLineSearchWithModel");
          controlMessageBroadcaster.send(ControlMessages.DoLineSearchWithModel);
          modelAndDescentDirectionBroadcaster.send(new Pair<>(model, descentDirection));
        } else {
          LOG.log(Level.INFO, "OUT: DoLineSearch");
          controlMessageBroadcaster.send(ControlMessages.DoLineSearch);
          descentDirectionBroadcaster.send(descentDirection);
        }
        final Pair<Vector, Integer> lineSearchEvals = lineSearchEvaluationsReducer.reduce();
        if (lineSearchEvals != null) {
          final int numExamples = lineSearchEvals.getSecond();
          lineSearchResults = lineSearchEvals.getFirst();
          // Turn the reduced sums into per-example averages.
          lineSearchResults.scale(1.0 / numExamples);
          LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples);
          LOG.log(Level.INFO, "OUT: LineSearchEvals: {0}", lineSearchResults);
          allDead = false;
        } else {
          allDead = true;
        }
      }

      // A topology change means the next broadcast has to carry the full model again.
      sendModel = chkAndUpdate();
    } while (allDead || (!ignoreAndContinue && sendModel));
    return lineSearchResults;
  }

  /**
   * Broadcasts either the model or the last step size and reduces the loss and gradient.
   *
   * <p>The reduced loss and gradient are divided by the reported example count; the returned
   * objective value additionally contains the regularization term
   * {@code (lambda / 2) * ||model||^2}. The round is repeated under the same conditions as in
   * {@link #lineSearch(Vector)}.
   *
   * @return pair of (objective value, per-example gradient)
   * @throws NetworkException     if a group communication operation fails
   * @throws InterruptedException if a group communication operation is interrupted
   */
  private Pair<Double, Vector> computeLossAndGradient() throws NetworkException, InterruptedException {
    Pair<Double, Vector> returnValue = null;
    boolean allDead = false;
    do {
      try (Timer t = new Timer("Broadcast(" + (sendModel ? "Model" : "MinEta") + ") + Reduce(LossAndGradient)")) {
        if (sendModel) {
          LOG.log(Level.INFO, "OUT: ComputeGradientWithModel");
          controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithModel);
          modelBroadcaster.send(model);
        } else {
          // NOTE(review): presumably receivers apply minEta to their own model copy - confirm.
          LOG.log(Level.INFO, "OUT: ComputeGradientWithMinEta");
          controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithMinEta);
          minEtaBroadcaster.send(minEta);
        }
        final Pair<Pair<Double, Integer>, Vector> lossAndGradient = lossAndGradientReducer.reduce();

        if (lossAndGradient != null) {
          final int numExamples = lossAndGradient.getFirst().getSecond();
          LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples);
          final double lossPerExample = lossAndGradient.getFirst().getFirst() / numExamples;
          LOG.log(Level.INFO, "OUT: Loss: {0}", lossPerExample);
          final double objFunc = (lambda / 2) * model.norm2Sqr() + lossPerExample;
          LOG.log(Level.INFO, "OUT: Objective Func Value: {0}", objFunc);
          final Vector gradient = lossAndGradient.getSecond();
          gradient.scale(1.0 / numExamples);
          LOG.log(Level.INFO, "OUT: Gradient: {0}", gradient);
          returnValue = new Pair<>(objFunc, gradient);
          allDead = false;
        } else {
          allDead = true;
        }
      }

      // A topology change means the next broadcast has to carry the full model again.
      sendModel = chkAndUpdate();
    } while (allDead || (!ignoreAndContinue && sendModel));
    return returnValue;
  }

  /**
   * Checks for topology changes and, if there are any, asks the communication group to update
   * its topology. The time taken by each of the two calls is logged.
   *
   * @return {@code true} if changes existed and the topology was updated, {@code false} otherwise
   */
  private boolean chkAndUpdate() {
    long t1 = System.currentTimeMillis();
    final GroupChanges changes = communicationGroupClient.getTopologyChanges();
    long t2 = System.currentTimeMillis();
    LOG.log(Level.INFO, "OUT: Time to get TopologyChanges = " + (t2 - t1) / 1000.0 + " sec");
    if (changes.exist()) {
      LOG.log(Level.INFO, "OUT: There exist topology changes. Asking to update Topology");
      t1 = System.currentTimeMillis();
      communicationGroupClient.updateTopology();
      t2 = System.currentTimeMillis();
      LOG.log(Level.INFO, "OUT: Time to update Topology = " + (t2 - t1) / 1000.0 + " sec");
      return true;
    } else {
      LOG.log(Level.INFO, "OUT: No changes in topology exist. So not updating topology");
      return false;
    }
  }

  /**
   * Decides whether the optimization loop should stop.
   *
   * @param iters    current iteration number
   * @param gradNorm norm of the most recent descent direction
   * @return {@code true} once {@code iters} has reached {@code maxIters} or the norm is within
   *         the tolerance
   */
  private boolean converged(final int iters, final double gradNorm) {
    return iters >= maxIters || Math.abs(gradNorm) <= GRADIENT_NORM_TOLERANCE;
  }

  /**
   * Picks the step size with the smallest regularized loss.
   *
   * <p>For every candidate step size {@code eta} the regularization term
   * {@code (lambda / 2) * ||w + eta * d||^2} is added to the matching entry of
   * {@code lineSearchEvals}; the squared norm is expanded as
   * {@code ||w||^2 + eta^2 * ||d||^2 + 2 * eta * (w . d)} so no temporary vector is needed.
   * {@code lineSearchEvals} is modified in place.
   *
   * @param theModel        current model {@code w}
   * @param descentDir      descent direction {@code d}
   * @param lineSearchEvals loss per candidate step size, indexed like {@code ts.getT()}
   * @return the candidate step size whose regularized loss is minimal
   */
  private double findMinEta(final Vector theModel, final Vector descentDir, final Vector lineSearchEvals) {
    final double wNormSqr = theModel.norm2Sqr();
    final double dNormSqr = descentDir.norm2Sqr();
    final double wDotd = theModel.dot(descentDir);
    final double[] t = ts.getT();
    int i = 0;
    for (final double eta : t) {
      final double modelNormSqr = wNormSqr + (eta * eta) * dNormSqr + 2 * eta * wDotd;
      final double loss = lineSearchEvals.get(i) + ((lambda / 2) * modelNormSqr);
      lineSearchEvals.set(i, loss);
      ++i;
    }
    LOG.log(Level.INFO, "OUT: Regularized LineSearchEvals: {0}", lineSearchEvals);
    final Tuple<Integer, Double> minTup = lineSearchEvals.min();
    LOG.log(Level.INFO, "OUT: MinTup: {0}", minTup);
    // The key of the minimum is used as the index into the step size array.
    final double minT = t[minTup.getKey()];
    LOG.log(Level.INFO, "OUT: MinT: {0}", minT);
    return minT;
  }

  /**
   * Converts the loss gradient into a descent direction: {@code -(gradient + lambda * model)}.
   *
   * <p>The computation happens in place, so the returned vector is the {@code gradient}
   * argument itself.
   *
   * @param gradient per-example loss gradient; overwritten with the descent direction
   * @return the same vector instance, now holding the descent direction
   */
  private Vector getDescentDirection(final Vector gradient) {
    gradient.multAdd(lambda, model);
    gradient.scale(-1);
    LOG.log(Level.INFO, "OUT: DescentDirection: {0}", gradient);
    return gradient;
  }
}