blob: 7bd5543f49bef8336e5636e386940bb84129a074 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.horn.core;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hama.commons.math.DoubleMatrix;
import com.google.common.base.Preconditions;
/**
 * Server-side {@link ParameterMerger} that receives {@link ParameterMessage}s
 * from workers, tracks the training errors they report, and raises a shared
 * convergence flag when training should stop.
 *
 * <p>
 * Stopping is triggered by either of two conditions:
 * <ul>
 * <li>the average of the latest full window of training errors is larger than
 * the average of the previous window, or</li>
 * <li>the number of processed merges reaches {@code mergeLimit}.</li>
 * </ul>
 *
 * <p>
 * All mutable bookkeeping is updated while holding the monitor of
 * {@code inMemoryModel}.
 */
public class ParameterMergerServer implements ParameterMerger {

  private static final Log LOG = LogFactory.getLog(ParameterMergerServer.class);

  /* The parameter merge base. */
  protected LayeredNeuralNetwork inMemoryModel;

  /* To terminate or not to terminate: set to true once merging should stop. */
  protected AtomicBoolean isConverge;

  /*
   * The number of slave workers that request commits. Stored by the
   * constructor; not read anywhere in this class.
   */
  protected int SlaveCount;

  /* After mergeLimit merges, terminate whether the result is converging or not. */
  protected int mergeLimit;

  /*
   * Window of the last n training errors. Convergence is decided based on the
   * average value of these errors once the window is full.
   */
  protected double[] trainingErrors;

  /*
   * Average training error of the previous full window. If the average of the
   * current window is larger than this value, the convergence flag is set;
   * otherwise this value is replaced by the current average.
   */
  protected double prevAvgTrainingError = Double.MAX_VALUE;

  /* Current write index into trainingErrors. */
  protected int curTrainingError = 0;

  /* How many merges have been conducted? */
  protected int mergeCount = 0;

  /**
   * Creates a merger server.
   *
   * @param inMemoryModel the model whose matrices are returned to callers; its
   *          monitor also guards this object's bookkeeping
   * @param isConverge shared flag that this server sets to {@code true} when
   *          merging should stop
   * @param slaveCount number of slave workers
   * @param mergeLimit number of merges after which the flag is set regardless
   *          of the observed training errors
   * @param convergenceCheckInterval size of the training-error window, i.e.
   *          how many reported errors are averaged per convergence check
   */
  public ParameterMergerServer(LayeredNeuralNetwork inMemoryModel,
      AtomicBoolean isConverge, int slaveCount, int mergeLimit,
      int convergenceCheckInterval) {
    this.inMemoryModel = inMemoryModel;
    this.isConverge = isConverge;
    this.SlaveCount = slaveCount;
    this.mergeLimit = mergeLimit;
    this.trainingErrors = new double[convergenceCheckInterval];
  }

  /**
   * Returns the protocol version declared by {@link ParameterMerger}.
   *
   * @param s ignored
   * @param l ignored
   * @return {@code ParameterMerger.versionID}
   */
  @Override
  public long getProtocolVersion(String s, long l) throws IOException {
    return ParameterMerger.versionID;
  }

  /**
   * Processes one update message from a worker and replies with the current
   * model state.
   *
   * <p>
   * While the convergence flag is not set, the reported training error is
   * recorded, the convergence window is evaluated when full, and the merge
   * counter is advanced. Once the flag is set, the message is not recorded
   * and only the reply is produced.
   *
   * @param msg the worker's message; its current and previous matrix arrays
   *          must have the same length
   * @return a message carrying the convergence flag and the in-memory model's
   *         weight matrices and previous matrix updates
   * @throws IllegalArgumentException if the two matrix arrays in {@code msg}
   *           differ in length
   */
  @Override
  public ParameterMessage merge(
      ParameterMessage msg) {

    double trainingError = msg.getTrainingError();
    DoubleMatrix[] weightUpdates = msg.getCurMatrices();
    DoubleMatrix[] prevWeightUpdates = msg.getPrevMatrices();

    Preconditions
        .checkArgument(weightUpdates.length == prevWeightUpdates.length);

    LOG.info("Start merging: " + this.mergeCount);

    // Cheap unsynchronized check so converged servers skip the lock entirely.
    if (!this.isConverge.get()) {
      synchronized (inMemoryModel) {
        // Re-check under the lock: another caller may have declared
        // convergence while this one was waiting. Without this second check a
        // late caller would keep writing into the training-error window after
        // it was closed and could index past its end.
        if (!this.isConverge.get()) {
          LOG.info(">>>> before: " + this.inMemoryModel.getWeightMatrices()[0].get(0, 0));

          // NOTE(review): the actual weight merge is disabled in the original
          // source; kept as-is. Confirm whether this is intentional.
          // this.inMemoryModel.addWeights(weightUpdates);
          // this.inMemoryModel.addPrevWeights(prevWeightUpdates);

          LOG.info(", after: " + this.inMemoryModel.getWeightMatrices()[0].get(0, 0));

          recordTrainingError(trainingError);

          if (++this.mergeCount == this.mergeLimit) {
            this.isConverge.set(true);
          }
        }
      }
    }

    // NOTE(review): the first constructor argument is a constant 0 in the
    // original; presumably a training-error slot unused in replies - confirm.
    return new ParameterMessage(0, this.isConverge.get(),
        this.inMemoryModel.getWeightMatrices(),
        this.inMemoryModel.getPrevMatricesUpdates());
  }

  /**
   * Appends a training error to the window and, once the window is full,
   * compares its average with the previous window's average. The caller must
   * hold the monitor of {@code inMemoryModel}.
   *
   * @param trainingError the error reported by the worker
   */
  private void recordTrainingError(double trainingError) {
    // add trainingError to trainingErrors
    this.trainingErrors[this.curTrainingError++] = trainingError;

    // the window is evaluated only when it is completely filled
    if (this.curTrainingError < this.trainingErrors.length) {
      return;
    }

    double curAvgTrainingError = 0.0;
    for (int i = 0; i < this.trainingErrors.length; ++i) {
      curAvgTrainingError += this.trainingErrors[i];
    }
    curAvgTrainingError /= this.trainingErrors.length;

    if (prevAvgTrainingError < curAvgTrainingError) {
      // the average error went up compared to the previous window: stop
      this.isConverge.set(true);
    } else {
      // update
      prevAvgTrainingError = curAvgTrainingError;
    }

    // Always rewind the write index so the window can never be indexed out of
    // bounds, even if the shared flag is cleared again by its owner.
    this.curTrainingError = 0;
  }
}