blob: 0f5f1ace6f73c3e23a141b58cc85b2838e515627 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.giraph.block_app.library;
import org.apache.giraph.block_app.framework.api.BlockWorkerReceiveApi;
import org.apache.giraph.block_app.framework.api.BlockWorkerSendApi;
import org.apache.giraph.block_app.framework.api.local.LocalBlockRunner;
import org.apache.giraph.block_app.framework.block.Block;
import org.apache.giraph.block_app.framework.block.SequenceBlock;
import org.apache.giraph.block_app.framework.piece.Piece;
import org.apache.giraph.block_app.framework.piece.interfaces.VertexReceiver;
import org.apache.giraph.block_app.framework.piece.interfaces.VertexSender;
import org.apache.giraph.block_app.test_setup.NumericTestGraph;
import org.apache.giraph.combiner.MaxMessageCombiner;
import org.apache.giraph.combiner.SumMessageCombiner;
import org.apache.giraph.conf.GiraphConfiguration;
import org.apache.giraph.conf.GiraphConstants;
import org.apache.giraph.edge.Edge;
import org.apache.giraph.function.primitive.PrimitiveRefs.LongRef;
import org.apache.giraph.reducers.impl.SumReduce;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.junit.Assert;
import org.junit.Test;
import com.google.common.collect.Iterators;
/**
* Tests and examples of using SendMessageChain
*/
public class TestMessageChain {
  /**
   * Creates a configuration with {@link LongWritable} vertex ids and values and
   * {@link NullWritable} edge values.
   */
  private static GiraphConfiguration createConf() {
    GiraphConfiguration conf = new GiraphConfiguration();
    GiraphConstants.VERTEX_ID_CLASS.set(conf, LongWritable.class);
    GiraphConstants.VERTEX_VALUE_CLASS.set(conf, LongWritable.class);
    GiraphConstants.EDGE_VALUE_CLASS.set(conf, NullWritable.class);
    return conf;
  }

  /**
   * Creates the graph shared by all tests: the path 1 - 2 - 3 built from
   * symmetric edges, plus the isolated vertex 4.
   */
  private static NumericTestGraph<LongWritable, LongWritable, NullWritable> createTestGraph() {
    NumericTestGraph<LongWritable, LongWritable, NullWritable> graph =
        new NumericTestGraph<>(createConf());
    graph.addVertex(1);
    graph.addVertex(2);
    graph.addVertex(3);
    graph.addVertex(4);
    graph.addSymmetricEdge(1, 2);
    graph.addSymmetricEdge(2, 3);
    return graph;
  }

  /**
   * Two-hop chain without combiners: every vertex sends its ID to its neighbors,
   * each vertex forwards the max it received to its neighbors, and the final
   * receiver stores the max of what it got. For vertex 2 the result is its own
   * ID, because its neighbors' neighbors include vertex 2 itself.
   */
  @Test
  public void testReply() {
    NumericTestGraph<LongWritable, LongWritable, NullWritable> graph = createTestGraph();
    // calculates max ID of FOFs
    Block reply = SendMessageChain.<LongWritable, LongWritable, NullWritable, LongWritable>
    startSendToNeighbors(
        "SendMyIdToAllNeighbors",
        LongWritable.class,
        VertexSuppliers.vertexIdSupplier()
    ).thenSendToNeighbors(
        "SendMaxIReceivedToAllNeighbors",
        LongWritable.class,
        (vertex, messages) -> new LongWritable(max(messages))
    ).endConsume(
        (vertex, messages) -> vertex.getValue().set(max(messages))
    );
    LocalBlockRunner.runBlock(graph.getTestGraph(), reply, new Object());
    assertVertexValues(graph, 3, 2, 3, 0);
  }

  /**
   * Same computation as {@link #testReply()}, but both hops use
   * {@link MaxMessageCombiner#LONG}, so receivers get a single (possibly null)
   * message instead of an iterable.
   */
  @Test
  public void testReplyCombiner() {
    NumericTestGraph<LongWritable, LongWritable, NullWritable> graph = createTestGraph();
    // calculates max ID of FOFs
    Block reply = SendMessageChain.<LongWritable, LongWritable, NullWritable, LongWritable>
    startSendToNeighbors(
        "SendMyIdToAllNeighbors",
        MaxMessageCombiner.LONG,
        VertexSuppliers.vertexIdSupplier()
    ).thenSendToNeighbors(
        "SendMaxIReceivedToAllNeighbors",
        MaxMessageCombiner.LONG,
        (vertex, message) -> message
    ).endConsume(
        // a null combined message means nothing was received
        (vertex, message) -> vertex.getValue().set(message != null ? message.get() : 0)
    );
    LocalBlockRunner.runBlock(graph.getTestGraph(), reply, new Object());
    assertVertexValues(graph, 3, 2, 3, 0);
  }

  /**
   * Ends the chain with a reduction. The max FOF ID of each vertex (0 for a
   * vertex that received nothing) is summed over all vertices: 3 + 2 + 3 + 0 = 8.
   * A follow-up piece then writes that sum into every vertex value.
   */
  @Test
  public void testReplyCombinerEndReduce() {
    NumericTestGraph<LongWritable, LongWritable, NullWritable> graph = createTestGraph();
    LongRef sumOfAll = new LongRef(0);
    // sums, over all vertices, the max ID of FOFs
    Block reply = SendMessageChain.<LongWritable, LongWritable, NullWritable, LongWritable>
    startSendToNeighbors(
        "SendMyIdToAllNeighbors",
        MaxMessageCombiner.LONG,
        VertexSuppliers.vertexIdSupplier()
    ).thenSendToNeighbors(
        "SendMaxIReceivedToAllNeighbors",
        MaxMessageCombiner.LONG,
        (vertex, message) -> message
    ).endReduce(
        "SumAllReceivedValues",
        SumReduce.LONG,
        (vertex, message) -> message != null ? message : new LongWritable(0),
        (value) -> sumOfAll.value = value.get()
    );
    LocalBlockRunner.runBlock(
        graph.getTestGraph(),
        new SequenceBlock(
            reply,
            Pieces.forAllVertices(
                "SetAllValuesToReduced",
                (vertex) -> ((LongWritable) vertex.getValue()).set(sumOfAll.value))),
        new Object());
    assertVertexValues(graph, 8, 8, 8, 8);
    // When SERIALIZE_MASTER is used, the block executes in a separate (serialized)
    // environment, so this test's own sumOfAll instance stays unchanged.
    Assert.assertEquals(LocalBlockRunner.SERIALIZE_MASTER.getDefaultValue() ? 0 : 8, sumOfAll.value);
  }

  /**
   * Starts the chain with a hand-written {@link Piece}. Each vertex sends its ID
   * only to its first neighbor, and the max received value is passed on to the
   * rest of the chain. The second hop uses {@link SumMessageCombiner#LONG}.
   */
  @Test
  public void testStartCustom() {
    NumericTestGraph<LongWritable, LongWritable, NullWritable> graph = createTestGraph();
    Block reply = SendMessageChain.<LongWritable, LongWritable, NullWritable, LongWritable>
    startCustom(
        // Sends ID to its first neighbor, passing max of received messages to next part of the chain
        (consumer) -> new Piece<LongWritable, LongWritable, NullWritable, LongWritable, Object>() {
          @Override
          public VertexSender<LongWritable, LongWritable, NullWritable> getVertexSender(
              BlockWorkerSendApi<LongWritable, LongWritable, NullWritable, LongWritable> workerApi,
              Object executionStage) {
            return (vertex) -> {
              Edge<LongWritable, NullWritable> edge =
                  Iterators.getNext(vertex.getEdges().iterator(), null);
              // isolated vertices (no edges) send nothing
              if (edge != null) {
                workerApi.sendMessage(edge.getTargetVertexId(), vertex.getId());
              }
            };
          }

          @Override
          public VertexReceiver<LongWritable, LongWritable, NullWritable, LongWritable>
          getVertexReceiver(BlockWorkerReceiveApi<LongWritable> workerApi, Object executionStage) {
            return (vertex, messages) -> consumer.apply(vertex, new LongWritable(max(messages)));
          }

          @Override
          protected Class<LongWritable> getMessageClass() {
            return LongWritable.class;
          }
        }
    ).thenSendToNeighbors(
        "SendMaxIReceivedToAllNeighbors",
        SumMessageCombiner.LONG,
        (vertex, message) -> message
    ).endConsume(
        // a null combined message means nothing was received
        (vertex, message) -> vertex.getValue().set(message != null ? message.get() : 0)
    );
    LocalBlockRunner.runBlock(graph.getTestGraph(), reply, new Object());
    assertVertexValues(graph, 3, 2, 3, 0);
  }

  /**
   * Asserts the final values of the test graph's vertices.
   *
   * @param graph graph after the block under test has run
   * @param expected expected value of vertex {@code i + 1} at index {@code i}
   */
  private static void assertVertexValues(
      NumericTestGraph<LongWritable, LongWritable, NullWritable> graph, long... expected) {
    for (int i = 0; i < expected.length; i++) {
      int id = i + 1;
      Assert.assertEquals(
          "value of vertex " + id, expected[i], graph.getVertex(id).getValue().get());
    }
  }

  /**
   * Returns the largest message value, or 0 when there are no messages. The
   * 0 fallback is safe here because every vertex ID in the test graph is positive.
   *
   * @param messages received messages, possibly empty
   * @return max message value, or 0 if {@code messages} is empty
   */
  private static long max(Iterable<LongWritable> messages) {
    long result = 0;
    for (LongWritable message : messages) {
      result = Math.max(result, message.get());
    }
    return result;
  }
}