// blob: 8c849957ff9001d52886e4622a401fd3d9e5347e [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hama.graph;
import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.MapWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
import com.google.common.base.Preconditions;
/**
 * Runner class to do the tasks that need to be done if aggregation was
 * configured.
 *
 * <p>
 * The aggregator class names are read from the configuration key
 * {@link GraphJob#AGGREGATOR_CLASS_ATTR} as a {@code ';'}-separated list.
 * Every per-aggregator array in this class is indexed by the position of the
 * aggregator in that list. Until {@link #setupAggregators(BSPPeer)} has found
 * such a list, {@link #isEnabled()} returns {@code false} and all arrays are
 * {@code null}.
 *
 * @param <V> not referenced inside this class.
 * @param <E> not referenced inside this class.
 * @param <M> the value type handed to the aggregators.
 */
@SuppressWarnings("rawtypes")
public final class AggregationRunner<V extends WritableComparable, E extends Writable, M extends Writable> {

  // multiple aggregator arrays, all indexed by aggregator position
  private Aggregator<M>[] aggregators;
  // values / increments stored by the last receiveAggregatedValues() call
  private Writable[] globalAggregatorResult;
  private IntWritable[] globalAggregatorIncrement;
  // true where the aggregator at that index is an AbstractAggregator
  private boolean[] isAbstractAggregator;
  private String[] aggregatorClassNames;
  // map keys of the form "<flag>;<index>" used to ship values / increments
  private Text[] aggregatorValueFlag;
  private Text[] aggregatorIncrementFlag;

  // aggregator on the master side; only allocated on the master task
  private Aggregator<M>[] masterAggregator;

  private boolean enabled = false;
  private Configuration conf;

  /**
   * Reads the configured aggregator classes from the peer's configuration
   * and, if any are configured, instantiates one aggregator per class name
   * (plus one master aggregator per class name on the master task) and
   * enables this runner. Does nothing besides storing the configuration if
   * the key is not set.
   *
   * @param peer the peer whose configuration is read.
   * @throws IllegalArgumentException if a configured aggregator class cannot
   *           be found.
   */
  @SuppressWarnings("unchecked")
  public void setupAggregators(
      BSPPeer<Writable, Writable, Writable, Writable, GraphJobMessage> peer) {
    this.conf = peer.getConfiguration();
    String aggregatorClasses = conf.get(GraphJob.AGGREGATOR_CLASS_ATTR);
    if (aggregatorClasses == null) {
      return;
    }

    enabled = true;
    aggregatorClassNames = aggregatorClasses.split(";");
    final int count = aggregatorClassNames.length;
    final boolean master = GraphJobRunner.isMasterTask(peer);

    // init to the split size
    aggregators = new Aggregator[count];
    globalAggregatorResult = new Writable[count];
    globalAggregatorIncrement = new IntWritable[count];
    isAbstractAggregator = new boolean[count];
    aggregatorValueFlag = new Text[count];
    aggregatorIncrementFlag = new Text[count];
    if (master) {
      masterAggregator = new Aggregator[count];
    }

    for (int i = 0; i < count; i++) {
      aggregators[i] = getNewAggregator(aggregatorClassNames[i]);
      aggregatorValueFlag[i] = new Text(
          GraphJobRunner.S_FLAG_AGGREGATOR_VALUE + ";" + i);
      aggregatorIncrementFlag[i] = new Text(
          GraphJobRunner.S_FLAG_AGGREGATOR_INCREMENT + ";" + i);
      isAbstractAggregator[i] = aggregators[i] instanceof AbstractAggregator;
      if (master) {
        masterAggregator[i] = getNewAggregator(aggregatorClassNames[i]);
      }
    }
  }

  /**
   * Runs the aggregators by sending their values to the master task.
   *
   * <p>
   * The message always carries the two counters. If aggregators are
   * configured it additionally carries each aggregator's current value (and,
   * for {@link AbstractAggregator}s, the number of times it aggregated).
   * Afterwards all local aggregators, and on the master task also the master
   * aggregators, are replaced by fresh instances for the next iteration.
   *
   * @param peer the peer used to send the message to the master task.
   * @param activeVertices the value stored under
   *          {@code GraphJobRunner.FLAG_MESSAGE_COUNTS}.
   * @param changedVertexCnt the value stored under
   *          {@code GraphJobRunner.FLAG_VERTEX_ALTER_COUNTER}.
   * @throws IOException if sending the message fails.
   */
  public void sendAggregatorValues(
      BSPPeer<Writable, Writable, Writable, Writable, GraphJobMessage> peer,
      int activeVertices, int changedVertexCnt) throws IOException {
    // send msgCounts to the master task
    MapWritable updatedCnt = new MapWritable();
    updatedCnt.put(GraphJobRunner.FLAG_MESSAGE_COUNTS, new IntWritable(
        activeVertices));
    // send total number of vertices changes
    updatedCnt.put(GraphJobRunner.FLAG_VERTEX_ALTER_COUNTER, new LongWritable(
        changedVertexCnt));

    // also send aggregated values to the master
    if (aggregators != null) {
      for (int i = 0; i < aggregators.length; i++) {
        updatedCnt.put(aggregatorValueFlag[i], aggregators[i].getValue());
        if (isAbstractAggregator[i]) {
          updatedCnt.put(aggregatorIncrementFlag[i],
              ((AbstractAggregator<M>) aggregators[i]).getTimesAggregated());
        }
      }

      // now create new aggregators for the next iteration
      final boolean master = GraphJobRunner.isMasterTask(peer);
      for (int i = 0; i < aggregators.length; i++) {
        aggregators[i] = getNewAggregator(aggregatorClassNames[i]);
        if (master) {
          masterAggregator[i] = getNewAggregator(aggregatorClassNames[i]);
        }
      }
    }

    peer.send(GraphJobRunner.getMasterTask(peer), new GraphJobMessage(
        updatedCnt));
  }

  /**
   * Aggregates the last value before computation and the value after the
   * computation. Does nothing if no aggregators are configured.
   *
   * @param index position of the aggregator to feed.
   * @param lastValue the value before compute().
   * @param value the value after compute().
   */
  public void aggregateVertex(int index, M lastValue, M value) {
    if (!isEnabled()) {
      return;
    }
    Aggregator<M> aggregator = this.aggregators[index];
    aggregator.aggregate(value);
    if (isAbstractAggregator[index]) {
      // abstract aggregators additionally see the before/after pair
      AbstractAggregator<M> intern = (AbstractAggregator<M>) aggregator;
      intern.aggregate(lastValue, value);
      intern.aggregateInternal();
    }
  }

  /**
   * The method the master task does, it globally aggregates the values of each
   * peer and updates the given map accordingly. Does nothing if no aggregators
   * are configured.
   *
   * @param updatedCnt the map that receives, per aggregator, the aggregated
   *          value and (for {@link AbstractAggregator}s) the aggregation
   *          count.
   */
  public void doMasterAggregation(MapWritable updatedCnt) {
    if (!isEnabled()) {
      return;
    }
    // work through the master aggregators
    for (int i = 0; i < masterAggregator.length; i++) {
      Writable lastAggregatedValue = masterAggregator[i].getValue();
      if (isAbstractAggregator[i]) {
        final AbstractAggregator<M> intern = (AbstractAggregator<M>) masterAggregator[i];
        // finalize exactly once and reuse the result for the null check
        final Writable finalizeAggregation = intern.finalizeAggregation();
        if (finalizeAggregation != null) {
          lastAggregatedValue = finalizeAggregation;
        }
        // this count is usually the times of active
        // vertices in the graph
        updatedCnt.put(aggregatorIncrementFlag[i],
            intern.getTimesAggregated());
      }
      updatedCnt.put(aggregatorValueFlag[i], lastAggregatedValue);
    }
  }

  /**
   * Receives aggregated values from a master task and stores them for
   * {@link #getLastAggregatedValue(int)} and
   * {@link #getNumLastAggregatedVertices(int)}.
   *
   * @param updatedValues the map received from the master task.
   * @param iteration not used by this method.
   * @return false if the entry under
   *         {@code GraphJobRunner.FLAG_MESSAGE_COUNTS} is present and equals
   *         {@link Integer#MIN_VALUE}; true otherwise, in particular always
   *         true with respect to aggregators if none are defined.
   */
  public boolean receiveAggregatedValues(MapWritable updatedValues,
      long iteration) throws IOException, SyncException, InterruptedException {
    // the arrays only exist once aggregators have been set up
    if (aggregators != null) {
      for (int i = 0; i < aggregators.length; i++) {
        globalAggregatorResult[i] = updatedValues.get(aggregatorValueFlag[i]);
        globalAggregatorIncrement[i] = (IntWritable) updatedValues
            .get(aggregatorIncrementFlag[i]);
      }
    }
    IntWritable count = (IntWritable) updatedValues
        .get(GraphJobRunner.FLAG_MESSAGE_COUNTS);
    return count == null || count.get() != Integer.MIN_VALUE;
  }

  /**
   * @return true if aggregators were defined. Normally used by the internal
   *         stateful methods, outside shouldn't use it too extensively.
   */
  public boolean isEnabled() {
    return enabled;
  }

  /**
   * Method to let the master read messages from peers and aggregate a value.
   *
   * @param textIndex a flag of the form {@code "<flag>;<index>"}.
   * @param value the value to feed into the master aggregator at that index.
   */
  public void masterReadAggregatedValue(Text textIndex, M value) {
    int index = parseAggregatorIndex(textIndex);
    masterAggregator[index].aggregate(value);
  }

  /**
   * Method to let the master read messages from peers and aggregate the
   * incremental value. Ignored for aggregators that are not
   * {@link AbstractAggregator}s.
   *
   * @param textIndex a flag of the form {@code "<flag>;<index>"}.
   * @param value the increment; it is cast to {@link IntWritable}.
   */
  public void masterReadAggregatedIncrementalValue(Text textIndex, M value) {
    int index = parseAggregatorIndex(textIndex);
    if (isAbstractAggregator[index]) {
      ((AbstractAggregator<M>) masterAggregator[index])
          .addTimesAggregated(((IntWritable) value).get());
    }
  }

  /**
   * Extracts the aggregator index from a flag of the form
   * {@code "<flag>;<index>"}, i.e. the second {@code ';'}-separated field.
   */
  private static int parseAggregatorIndex(Text textIndex) {
    return Integer.parseInt(textIndex.toString().split(";")[1]);
  }

  /**
   * Instantiates the aggregator class with the given name via the stored
   * configuration.
   *
   * @param clsName name of the aggregator class.
   * @return a new aggregator instance.
   * @throws IllegalArgumentException if the class cannot be found; the
   *           original {@link ClassNotFoundException} is attached as cause.
   */
  @SuppressWarnings("unchecked")
  private Aggregator<M> getNewAggregator(String clsName) {
    try {
      return (Aggregator<M>) ReflectionUtils.newInstance(
          conf.getClassByName(clsName), conf);
    } catch (ClassNotFoundException e) {
      throw new IllegalArgumentException("Aggregator class " + clsName
          + " could not be found or instantiated!", e);
    }
  }

  /**
   * @param index position of the aggregator, from 0 (inclusive) to the number
   *          of aggregators (exclusive).
   * @return the value stored for that aggregator by the last
   *         {@link #receiveAggregatedValues(MapWritable, long)} call.
   * @throws IndexOutOfBoundsException if the index is out of range.
   */
  public final Writable getLastAggregatedValue(int index) {
    // element index (exclusive upper bound), not position index
    return globalAggregatorResult[Preconditions.checkElementIndex(index,
        globalAggregatorResult.length)];
  }

  /**
   * @param index position of the aggregator, from 0 (inclusive) to the number
   *          of aggregators (exclusive).
   * @return the increment stored for that aggregator by the last
   *         {@link #receiveAggregatedValues(MapWritable, long)} call; null if
   *         the received map had no increment entry for it.
   * @throws IndexOutOfBoundsException if the index is out of range.
   */
  public final IntWritable getNumLastAggregatedVertices(int index) {
    // element index (exclusive upper bound), not position index
    return globalAggregatorIncrement[Preconditions.checkElementIndex(index,
        globalAggregatorIncrement.length)];
  }
}