blob: bb27b3f3883768c1fead6e29c6d4c4141cb090f1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.iterative;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichCoGroupFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.io.IOException;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/** Implementation of PageRank accounting for "sink" vertices with 0 out-degree. */
@RunWith(Parameterized.class)
@SuppressWarnings({"serial", "unchecked"})
public class DanglingPageRankITCase extends MultipleProgramsTestBase {
/** Name under which the PageRank statistics aggregator is registered and later looked up. */
private static final String AGGREGATOR_NAME = "pagerank.aggregator";
/**
 * Creates the test case and hands the given execution mode to the superclass constructor.
 *
 * @param mode execution mode passed through to {@code MultipleProgramsTestBase}
 */
public DanglingPageRankITCase(TestExecutionMode mode) {
super(mode);
}
/**
 * Runs PageRank with dangling-vertex compensation as a bulk iteration on a five-vertex graph.
 *
 * <p>Page 3 is the only vertex flagged as dangling and the only one without a
 * {@link PageWithLinks} entry. After the iteration the test checks that every returned page ID
 * lies in {@code [1, 5]}, that page 3 is still flagged as dangling, and that all ranks sum to
 * 1.0 within a tolerance of 0.001.
 *
 * @throws Exception if building or running the job fails; the exception is propagated
 *     unchanged so that the test report shows the original cause and stack trace
 */
@Test
public void testDanglingPageRank() throws Exception {
    final int numIterations = 25;
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    // (pageId, dangling flag)
    DataSet<Tuple2<Long, Boolean>> vertices =
            env.fromElements(
                    new Tuple2<>(1L, false),
                    new Tuple2<>(2L, false),
                    new Tuple2<>(5L, false),
                    new Tuple2<>(3L, true),
                    new Tuple2<>(4L, false));

    // adjacency lists; page 3 has none
    DataSet<PageWithLinks> edges =
            env.fromElements(
                    new PageWithLinks(2L, new long[] {1}),
                    new PageWithLinks(5L, new long[] {2, 4}),
                    new PageWithLinks(4L, new long[] {3, 2}),
                    new PageWithLinks(1L, new long[] {4, 2, 3}));

    // both counts are computed eagerly and captured by the functions below
    final long numVertices = vertices.count();
    final long numDanglingVertices =
            vertices.filter(
                            new FilterFunction<Tuple2<Long, Boolean>>() {
                                @Override
                                public boolean filter(Tuple2<Long, Boolean> value) {
                                    return value.f1;
                                }
                            })
                    .count();

    // every page starts with the uniform rank 1 / numVertices
    DataSet<PageWithRankAndDangling> verticesWithInitialRank =
            vertices.map(
                    new MapFunction<Tuple2<Long, Boolean>, PageWithRankAndDangling>() {
                        @Override
                        public PageWithRankAndDangling map(Tuple2<Long, Boolean> value) {
                            return new PageWithRankAndDangling(
                                    value.f0, 1.0 / numVertices, value.f1);
                        }
                    });

    IterativeDataSet<PageWithRankAndDangling> iteration =
            verticesWithInitialRank.iterate(numIterations);

    // register the statistics aggregator together with its convergence criterion
    iteration
            .getAggregators()
            .registerAggregationConvergenceCriterion(
                    AGGREGATOR_NAME,
                    new PageRankStatsAggregator(),
                    new DiffL1NormConvergenceCriterion());

    // each page splits its current rank evenly among its link targets
    DataSet<PageWithRank> partialRanks =
            iteration
                    .join(edges)
                    .where("pageId")
                    .equalTo("pageId")
                    .with(
                            new FlatJoinFunction<
                                    PageWithRankAndDangling, PageWithLinks, PageWithRank>() {
                                @Override
                                public void join(
                                        PageWithRankAndDangling page,
                                        PageWithLinks links,
                                        Collector<PageWithRank> out) {
                                    double rankToDistribute =
                                            page.rank / (double) links.targets.length;
                                    // a single output instance is reused for all targets;
                                    // only the page ID changes between emits
                                    PageWithRank output =
                                            new PageWithRank(0L, rankToDistribute);
                                    for (long target : links.targets) {
                                        output.pageId = target;
                                        out.collect(output);
                                    }
                                }
                            });

    // combine the partial ranks per page into the next rank and feed the aggregator
    DataSet<PageWithRankAndDangling> newRanks =
            iteration
                    .coGroup(partialRanks)
                    .where("pageId")
                    .equalTo("pageId")
                    .with(
                            new RichCoGroupFunction<
                                    PageWithRankAndDangling,
                                    PageWithRank,
                                    PageWithRankAndDangling>() {
                                private static final double BETA = 0.85;
                                private final double randomJump = (1.0 - BETA) / numVertices;
                                private PageRankStatsAggregator aggregator;
                                private double danglingRankFactor;

                                @Override
                                public void open(Configuration parameters) throws Exception {
                                    int currentIteration =
                                            getIterationRuntimeContext().getSuperstepNumber();
                                    aggregator =
                                            getIterationRuntimeContext()
                                                    .getIterationAggregator(AGGREGATOR_NAME);
                                    if (currentIteration == 1) {
                                        // no previous aggregate yet: derive the dangling rank
                                        // from the uniform initial ranks
                                        danglingRankFactor =
                                                BETA
                                                        * (double) numDanglingVertices
                                                        / ((double) numVertices
                                                                * (double) numVertices);
                                    } else {
                                        // use the dangling rank summed up in the previous
                                        // superstep
                                        PageRankStats previousAggregate =
                                                getIterationRuntimeContext()
                                                        .getPreviousIterationAggregate(
                                                                AGGREGATOR_NAME);
                                        danglingRankFactor =
                                                BETA
                                                        * previousAggregate.danglingRank()
                                                        / (double) numVertices;
                                    }
                                }

                                @Override
                                public void coGroup(
                                        Iterable<PageWithRankAndDangling> currentPages,
                                        Iterable<PageWithRank> partialRanks,
                                        Collector<PageWithRankAndDangling> out) {
                                    // compute the next rank
                                    long edges = 0;
                                    double summedRank = 0;
                                    for (PageWithRank partial : partialRanks) {
                                        summedRank += partial.rank;
                                        edges++;
                                    }
                                    double rank =
                                            BETA * summedRank + randomJump + danglingRankFactor;

                                    // current rank, for stats and convergence
                                    PageWithRankAndDangling currentPage =
                                            currentPages.iterator().next();
                                    double currentRank = currentPage.rank;
                                    boolean isDangling = currentPage.dangling;

                                    // maintain statistics to compensate for probability
                                    // loss on dangling nodes
                                    double danglingRankToAggregate = isDangling ? rank : 0;
                                    long danglingVerticesToAggregate = isDangling ? 1 : 0;
                                    double diff = Math.abs(currentRank - rank);
                                    aggregator.aggregate(
                                            diff,
                                            rank,
                                            danglingRankToAggregate,
                                            danglingVerticesToAggregate,
                                            1,
                                            edges);

                                    currentPage.rank = rank;
                                    out.collect(currentPage);
                                }
                            });

    List<PageWithRankAndDangling> result = iteration.closeWith(newRanks).collect();

    double totalRank = 0.0;
    for (PageWithRankAndDangling r : result) {
        totalRank += r.rank;
        if (r.pageId < 1 || r.pageId > 5) {
            fail("Unexpected page ID in result: " + r);
        }
        assertTrue("Page 3 must still be flagged as dangling: " + r, r.pageId != 3 || r.dangling);
    }
    // the ranks form a probability distribution over the pages
    assertEquals(1.0, totalRank, 0.001);
}
// ------------------------------------------------------------------------
// custom types
// ------------------------------------------------------------------------
/** POJO for page ID and rank value. */
public static class PageWithRank {
    /** ID of the page this rank value belongs to. */
    public long pageId;

    /** Rank value carried for the page. */
    public double rank;

    /** Creates an instance with {@code pageId} 0 and {@code rank} 0.0. */
    public PageWithRank() {}

    /**
     * Creates an instance with the given values.
     *
     * @param pageId ID of the page
     * @param rank rank value for the page
     */
    public PageWithRank(long pageId, double rank) {
        this.pageId = pageId;
        this.rank = rank;
    }

    /**
     * Returns a readable representation in the same format as
     * {@code PageWithRankAndDangling#toString()}.
     */
    @Override
    public String toString() {
        return "PageWithRank{" + "pageId=" + pageId + ", rank=" + rank + '}';
    }
}
/** POJO holding a page ID, its rank value, and a flag for "dangling" (0 out-degree) vertices. */
public static class PageWithRankAndDangling {
    /** ID of the page. */
    public long pageId;

    /** Rank value of the page. */
    public double rank;

    /** Whether the page is marked as a dangling vertex. */
    public boolean dangling;

    /** Creates an instance with {@code pageId} 0, {@code rank} 0.0 and {@code dangling} false. */
    public PageWithRankAndDangling() {
        this(0L, 0.0, false);
    }

    /**
     * Creates an instance with the given values.
     *
     * @param pageId ID of the page
     * @param rank rank value of the page
     * @param dangling whether the page is a dangling vertex
     */
    public PageWithRankAndDangling(long pageId, double rank, boolean dangling) {
        this.pageId = pageId;
        this.rank = rank;
        this.dangling = dangling;
    }

    /** Returns all three fields in the form {@code PageWithRankAndDangling{pageId=..., ...}}. */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("PageWithRankAndDangling{");
        text.append("pageId=").append(pageId);
        text.append(", rank=").append(rank);
        text.append(", dangling=").append(dangling);
        return text.append('}').toString();
    }
}
/** POJO for page ID and list of target IDs. */
public static class PageWithLinks {
    /** ID of the page the links start from. */
    public long pageId;

    /** IDs of the pages this page links to. */
    public long[] targets;

    /** Creates an instance with {@code pageId} 0 and no target array. */
    public PageWithLinks() {}

    /**
     * Creates an instance with the given values.
     *
     * @param pageId ID of the source page
     * @param targets IDs of the link targets; the array is stored as passed, not copied
     */
    public PageWithLinks(long pageId, long[] targets) {
        this.pageId = pageId;
        this.targets = targets;
    }

    /**
     * Returns a readable representation that lists the target IDs rather than the array's
     * identity hash; a {@code null} target array is rendered as {@code null}.
     */
    @Override
    public String toString() {
        return "PageWithLinks{"
                + "pageId="
                + pageId
                + ", targets="
                + java.util.Arrays.toString(targets)
                + '}';
    }
}
// ------------------------------------------------------------------------
// statistics
// ------------------------------------------------------------------------
/**
 * PageRank statistics: a rank difference, a rank sum, the rank held by dangling vertices, and
 * counters for dangling vertices, vertices and edges.
 *
 * <p>{@link #write(DataOutputView)} and {@link #read(DataInputView)} use the same field order:
 * the three doubles first, then the three longs.
 */
public static class PageRankStats implements Value {
    private double diff;
    private double rank;
    private double danglingRank;
    private long numDanglingVertices;
    private long numVertices;
    private long edges;

    /** Creates statistics with all values at zero. */
    public PageRankStats() {}

    /**
     * Creates statistics with the given values.
     *
     * @param diff rank difference
     * @param rank rank sum
     * @param danglingRank rank held by dangling vertices
     * @param numDanglingVertices number of dangling vertices
     * @param numVertices number of vertices
     * @param edges number of edges
     */
    public PageRankStats(
            double diff,
            double rank,
            double danglingRank,
            long numDanglingVertices,
            long numVertices,
            long edges) {
        this.diff = diff;
        this.rank = rank;
        this.danglingRank = danglingRank;
        this.numDanglingVertices = numDanglingVertices;
        this.numVertices = numVertices;
        this.edges = edges;
    }

    /** Returns the rank difference. */
    public double diff() {
        return diff;
    }

    /** Returns the rank sum. */
    public double rank() {
        return rank;
    }

    /** Returns the rank held by dangling vertices. */
    public double danglingRank() {
        return danglingRank;
    }

    /** Returns the number of dangling vertices. */
    public long numDanglingVertices() {
        return numDanglingVertices;
    }

    /** Returns the number of vertices. */
    public long numVertices() {
        return numVertices;
    }

    /** Returns the number of edges. */
    public long edges() {
        return edges;
    }

    /** Writes all six fields to {@code out}: the doubles first, then the longs. */
    @Override
    public void write(DataOutputView out) throws IOException {
        // floating-point statistics
        out.writeDouble(diff);
        out.writeDouble(rank);
        out.writeDouble(danglingRank);
        // counters
        out.writeLong(numDanglingVertices);
        out.writeLong(numVertices);
        out.writeLong(edges);
    }

    /** Reads all six fields from {@code in} in the order {@code write} emits them. */
    @Override
    public void read(DataInputView in) throws IOException {
        // floating-point statistics
        diff = in.readDouble();
        rank = in.readDouble();
        danglingRank = in.readDouble();
        // counters
        numDanglingVertices = in.readLong();
        numVertices = in.readLong();
        edges = in.readLong();
    }

    /** Returns all six values, each one labelled and wrapped in square brackets. */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder();
        text.append("PageRankStats: diff [").append(diff);
        text.append("], rank [").append(rank);
        text.append("], danglingRank [").append(danglingRank);
        text.append("], numDanglingVertices [").append(numDanglingVertices);
        text.append("], numVertices [").append(numVertices);
        text.append("], edges [").append(edges);
        return text.append("]").toString();
    }
}
/**
 * Aggregator that keeps running sums of the six {@link PageRankStats} values.
 *
 * <p>Values can be added either as a {@link PageRankStats} object or directly as primitives.
 */
private static class PageRankStatsAggregator implements Aggregator<PageRankStats> {
    private double diff;
    private double rank;
    private double danglingRank;
    private long numDanglingVertices;
    private long numVertices;
    private long edges;

    /** Returns a new {@link PageRankStats} snapshot of the current sums. */
    @Override
    public PageRankStats getAggregate() {
        return new PageRankStats(
                diff, rank, danglingRank, numDanglingVertices, numVertices, edges);
    }

    /**
     * Adds the given deltas to the corresponding running sums.
     *
     * @param diffDelta amount added to the rank difference
     * @param rankDelta amount added to the rank sum
     * @param danglingRankDelta amount added to the dangling rank
     * @param danglingVerticesDelta amount added to the dangling vertex count
     * @param verticesDelta amount added to the vertex count
     * @param edgesDelta amount added to the edge count
     */
    public void aggregate(
            double diffDelta,
            double rankDelta,
            double danglingRankDelta,
            long danglingVerticesDelta,
            long verticesDelta,
            long edgesDelta) {
        diff += diffDelta;
        rank += rankDelta;
        danglingRank += danglingRankDelta;
        numDanglingVertices += danglingVerticesDelta;
        numVertices += verticesDelta;
        edges += edgesDelta;
    }

    /** Adds every value of {@code pageRankStats} to the corresponding running sum. */
    @Override
    public void aggregate(PageRankStats pageRankStats) {
        aggregate(
                pageRankStats.diff(),
                pageRankStats.rank(),
                pageRankStats.danglingRank(),
                pageRankStats.numDanglingVertices(),
                pageRankStats.numVertices(),
                pageRankStats.edges());
    }

    /** Sets all running sums back to zero. */
    @Override
    public void reset() {
        diff = rank = danglingRank = 0.0;
        numDanglingVertices = numVertices = edges = 0L;
    }
}
/** Convergence criterion that is met once the aggregated rank difference drops below a bound. */
private static class DiffL1NormConvergenceCriterion
        implements ConvergenceCriterion<PageRankStats> {
    /** Exclusive upper bound on the aggregated difference for convergence. */
    private static final double EPSILON = 0.00005;

    /**
     * Tests the aggregated difference against {@link #EPSILON}; the superstep number is not
     * used.
     *
     * @param iteration superstep number (ignored)
     * @param pageRankStats aggregated statistics to test
     * @return {@code true} if {@code pageRankStats.diff()} is strictly less than
     *     {@link #EPSILON}
     */
    @Override
    public boolean isConverged(int iteration, PageRankStats pageRankStats) {
        final double aggregatedDiff = pageRankStats.diff();
        return aggregatedDiff < EPSILON;
    }
}
}