blob: 2f8c287c006cfa122963338bf43d6dac2d8fbb12 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.horn.bsp;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;
import org.apache.hama.commons.io.MatrixWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DoubleMatrix;
/**
 * SmallLayeredNeuralNetworkMessage transmits the messages between peers during
 * the training of neural networks.
 *
 * <p>
 * A message carries the training error, a convergence flag, the current
 * matrices and, optionally, the previous matrices.
 *
 * <p>
 * Serialized layout (see {@link #write(DataOutput)}):
 * <ol>
 * <li>double - training error</li>
 * <li>boolean - convergence flag</li>
 * <li>int - number of matrices per group</li>
 * <li>boolean - whether the previous matrices follow</li>
 * <li>the current matrices, each written by {@code MatrixWritable}</li>
 * <li>the previous matrices (only if flagged), same count as the current ones</li>
 * </ol>
 */
public class SmallLayeredNeuralNetworkMessage implements Writable {

  /** Training error carried by this message. */
  protected double trainingError;

  /** Current matrices; must be non-null when the message is written. */
  protected DoubleMatrix[] curMatrices;

  /** Previous matrices; {@code null} when the message carries none. */
  protected DoubleMatrix[] prevMatrices;

  /** Convergence flag carried by this message. */
  protected boolean converge;

  /**
   * Creates an empty message. The fields are expected to be populated later,
   * either through {@link #readFields(DataInput)} or through the setters.
   */
  public SmallLayeredNeuralNetworkMessage() {
  }

  /**
   * Creates a fully populated message.
   *
   * @param trainingError the training error to transmit
   * @param converge the convergence flag to transmit
   * @param weightMatrices the current matrices
   * @param prevMatrices the previous matrices, or {@code null} if there are
   *          none
   */
  public SmallLayeredNeuralNetworkMessage(double trainingError,
      boolean converge, DoubleMatrix[] weightMatrices,
      DoubleMatrix[] prevMatrices) {
    this.trainingError = trainingError;
    this.converge = converge;
    this.curMatrices = weightMatrices;
    this.prevMatrices = prevMatrices;
  }

  /**
   * Restores this message from {@code input}, replacing every field.
   *
   * @param input the stream positioned at the start of a serialized message
   * @throws IOException if the stream cannot be read or announces a negative
   *           number of matrices
   */
  @Override
  public void readFields(DataInput input) throws IOException {
    trainingError = input.readDouble();
    converge = input.readBoolean();
    int numMatrices = input.readInt();
    boolean hasPrevMatrices = input.readBoolean();
    if (numMatrices < 0) {
      throw new IOException("Invalid number of matrices in message: "
          + numMatrices);
    }

    // read matrix updates
    curMatrices = readMatrices(input, numMatrices);

    // A Writable instance may be reused for several reads, so the previous
    // matrices must be cleared explicitly when this message carries none;
    // otherwise stale data from an earlier message would survive.
    prevMatrices = hasPrevMatrices ? readMatrices(input, numMatrices) : null;
  }

  /**
   * Reads {@code count} consecutive matrices from {@code input}.
   *
   * @param input the stream to read from
   * @param count the number of matrices to read (non-negative)
   * @return a newly allocated {@code DenseDoubleMatrix[]} of length
   *         {@code count}
   * @throws IOException if a matrix cannot be read
   */
  private static DoubleMatrix[] readMatrices(DataInput input, int count)
      throws IOException {
    DoubleMatrix[] matrices = new DenseDoubleMatrix[count];
    for (int i = 0; i < count; ++i) {
      matrices[i] = (DenseDoubleMatrix) MatrixWritable.read(input);
    }
    return matrices;
  }

  /**
   * Serializes this message to {@code output} using the layout described in
   * the class documentation.
   *
   * @param output the stream to write to
   * @throws IOException if the stream cannot be written
   * @throws IllegalStateException if the current matrices are not set, or if
   *           the previous matrices are set but their count differs from the
   *           current matrices (the reader expects both groups to have the
   *           same length)
   */
  @Override
  public void write(DataOutput output) throws IOException {
    if (curMatrices == null) {
      throw new IllegalStateException(
          "curMatrices must be set before the message is written");
    }
    // readFields() reads exactly curMatrices.length previous matrices, so a
    // different count would desynchronize the stream on the receiving side.
    if (prevMatrices != null && prevMatrices.length != curMatrices.length) {
      throw new IllegalStateException("prevMatrices has "
          + prevMatrices.length + " matrices but curMatrices has "
          + curMatrices.length);
    }

    output.writeDouble(trainingError);
    output.writeBoolean(converge);
    output.writeInt(curMatrices.length);
    output.writeBoolean(prevMatrices != null);

    for (DoubleMatrix matrix : curMatrices) {
      MatrixWritable.write(matrix, output);
    }
    if (prevMatrices != null) {
      for (DoubleMatrix matrix : prevMatrices) {
        MatrixWritable.write(matrix, output);
      }
    }
  }

  /**
   * @return the training error carried by this message
   */
  public double getTrainingError() {
    return trainingError;
  }

  /**
   * @param trainingError the training error to carry
   */
  public void setTrainingError(double trainingError) {
    this.trainingError = trainingError;
  }

  /**
   * @return the convergence flag carried by this message
   */
  public boolean isConverge() {
    return converge;
  }

  /**
   * @param converge the convergence flag to carry
   */
  public void setConverge(boolean converge) {
    this.converge = converge;
  }

  /**
   * @return the current matrices (the internal array, not a copy); may be
   *         {@code null} if never set
   */
  public DoubleMatrix[] getCurMatrices() {
    return curMatrices;
  }

  /**
   * Sets the current matrices (the counterpart of {@link #getCurMatrices()}).
   *
   * @param curMatrices the current matrices; stored by reference
   */
  public void setMatrices(DoubleMatrix[] curMatrices) {
    this.curMatrices = curMatrices;
  }

  /**
   * @return the previous matrices (the internal array, not a copy), or
   *         {@code null} if the message carries none
   */
  public DoubleMatrix[] getPrevMatrices() {
    return prevMatrices;
  }

  /**
   * @param prevMatrices the previous matrices, or {@code null} for none;
   *          stored by reference
   */
  public void setPrevMatrices(DoubleMatrix[] prevMatrices) {
    this.prevMatrices = prevMatrices;
  }
}