| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.test.iterative; |
| |
| import org.apache.flink.api.common.functions.RichFlatMapFunction; |
| import org.apache.flink.api.common.functions.RichGroupReduceFunction; |
| import org.apache.flink.api.common.functions.RichJoinFunction; |
| import org.apache.flink.api.java.DataSet; |
| import org.apache.flink.api.java.ExecutionEnvironment; |
| import org.apache.flink.api.java.functions.KeySelector; |
| import org.apache.flink.api.java.operators.DeltaIteration; |
| import org.apache.flink.api.java.tuple.Tuple2; |
| import org.apache.flink.test.util.JavaProgramTestBase; |
| import org.apache.flink.util.Collector; |
| |
| import java.util.ArrayList; |
| import java.util.List; |
| |
| /** |
| * Iterative Connected Components test case which recomputes only the elements of the solution set |
| * whose at least one dependency (in-neighbor) has changed since the last iteration. Requires two |
| * joins with the solution set. |
| */ |
| @SuppressWarnings("serial") |
| public class DependencyConnectedComponentsITCase extends JavaProgramTestBase { |
| |
| private static final int MAX_ITERATIONS = 20; |
| private static final int parallelism = 1; |
| |
| protected static List<Tuple2<Long, Long>> verticesInput = new ArrayList<Tuple2<Long, Long>>(); |
| protected static List<Tuple2<Long, Long>> edgesInput = new ArrayList<Tuple2<Long, Long>>(); |
| private String resultPath; |
| private String expectedResult; |
| |
| @Override |
| protected void preSubmit() throws Exception { |
| verticesInput.clear(); |
| edgesInput.clear(); |
| |
| // vertices input |
| verticesInput.add(new Tuple2<>(1L, 1L)); |
| verticesInput.add(new Tuple2<>(2L, 2L)); |
| verticesInput.add(new Tuple2<>(3L, 3L)); |
| verticesInput.add(new Tuple2<>(4L, 4L)); |
| verticesInput.add(new Tuple2<>(5L, 5L)); |
| verticesInput.add(new Tuple2<>(6L, 6L)); |
| verticesInput.add(new Tuple2<>(7L, 7L)); |
| verticesInput.add(new Tuple2<>(8L, 8L)); |
| verticesInput.add(new Tuple2<>(9L, 9L)); |
| |
| // vertices input |
| edgesInput.add(new Tuple2<>(1L, 2L)); |
| edgesInput.add(new Tuple2<>(1L, 3L)); |
| edgesInput.add(new Tuple2<>(2L, 3L)); |
| edgesInput.add(new Tuple2<>(2L, 4L)); |
| edgesInput.add(new Tuple2<>(2L, 1L)); |
| edgesInput.add(new Tuple2<>(3L, 1L)); |
| edgesInput.add(new Tuple2<>(3L, 2L)); |
| edgesInput.add(new Tuple2<>(4L, 2L)); |
| edgesInput.add(new Tuple2<>(4L, 6L)); |
| edgesInput.add(new Tuple2<>(5L, 6L)); |
| edgesInput.add(new Tuple2<>(6L, 4L)); |
| edgesInput.add(new Tuple2<>(6L, 5L)); |
| edgesInput.add(new Tuple2<>(7L, 8L)); |
| edgesInput.add(new Tuple2<>(7L, 9L)); |
| edgesInput.add(new Tuple2<>(8L, 7L)); |
| edgesInput.add(new Tuple2<>(8L, 9L)); |
| edgesInput.add(new Tuple2<>(9L, 7L)); |
| edgesInput.add(new Tuple2<>(9L, 8L)); |
| |
| resultPath = getTempDirPath("result"); |
| |
| expectedResult = |
| "(1,1)\n" + "(2,1)\n" + "(3,1)\n" + "(4,1)\n" + "(5,1)\n" + "(6,1)\n" + "(7,7)\n" |
| + "(8,7)\n" + "(9,7)\n"; |
| } |
| |
| @Override |
| protected void testProgram() throws Exception { |
| DependencyConnectedComponentsProgram.runProgram(resultPath); |
| } |
| |
| @Override |
| protected void postSubmit() throws Exception { |
| compareResultsByLinesInMemory(expectedResult, resultPath); |
| } |
| |
| private static class DependencyConnectedComponentsProgram { |
| |
| public static String runProgram(String resultPath) throws Exception { |
| |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| env.setParallelism(parallelism); |
| |
| DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput); |
| DataSet<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput); |
| int keyPosition = 0; |
| |
| DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration = |
| initialSolutionSet.iterateDelta( |
| initialSolutionSet, MAX_ITERATIONS, keyPosition); |
| |
| DataSet<Long> candidates = |
| iteration |
| .getWorkset() |
| .join(edges) |
| .where(0) |
| .equalTo(0) |
| .with(new FindCandidatesJoin()) |
| .groupBy( |
| new KeySelector<Long, Long>() { |
| public Long getKey(Long id) { |
| return id; |
| } |
| }) |
| .reduceGroup(new RemoveDuplicatesReduce()); |
| |
| DataSet<Tuple2<Long, Long>> candidatesDependencies = |
| candidates |
| .join(edges) |
| .where( |
| new KeySelector<Long, Long>() { |
| public Long getKey(Long id) { |
| return id; |
| } |
| }) |
| .equalTo( |
| new KeySelector<Tuple2<Long, Long>, Long>() { |
| public Long getKey(Tuple2<Long, Long> vertexWithId) { |
| return vertexWithId.f1; |
| } |
| }) |
| .with(new FindCandidatesDependenciesJoin()); |
| |
| DataSet<Tuple2<Long, Long>> verticesWithNewComponents = |
| candidatesDependencies |
| .join(iteration.getSolutionSet()) |
| .where(0) |
| .equalTo(0) |
| .with(new NeighborWithComponentIDJoin()) |
| .groupBy(0) |
| .reduceGroup(new MinimumReduce()); |
| |
| DataSet<Tuple2<Long, Long>> updatedComponentId = |
| verticesWithNewComponents |
| .join(iteration.getSolutionSet()) |
| .where(0) |
| .equalTo(0) |
| .flatMap(new MinimumIdFilter()); |
| |
| iteration.closeWith(updatedComponentId, updatedComponentId).writeAsText(resultPath); |
| |
| env.execute(); |
| |
| return resultPath; |
| } |
| } |
| |
| private static final class FindCandidatesJoin |
| extends RichJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Long> { |
| |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public Long join(Tuple2<Long, Long> vertexWithCompId, Tuple2<Long, Long> edge) |
| throws Exception { |
| // emit target vertex |
| return edge.f1; |
| } |
| } |
| |
| private static final class RemoveDuplicatesReduce extends RichGroupReduceFunction<Long, Long> { |
| |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public void reduce(Iterable<Long> values, Collector<Long> out) { |
| out.collect(values.iterator().next()); |
| } |
| } |
| |
| private static final class FindCandidatesDependenciesJoin |
| extends RichJoinFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>> { |
| |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public Tuple2<Long, Long> join(Long candidateId, Tuple2<Long, Long> edge) throws Exception { |
| return edge; |
| } |
| } |
| |
| private static final class NeighborWithComponentIDJoin |
| extends RichJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> { |
| |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public Tuple2<Long, Long> join(Tuple2<Long, Long> edge, Tuple2<Long, Long> vertexWithCompId) |
| throws Exception { |
| |
| vertexWithCompId.setField(edge.f1, 0); |
| return vertexWithCompId; |
| } |
| } |
| |
| private static final class MinimumReduce |
| extends RichGroupReduceFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> { |
| |
| private static final long serialVersionUID = 1L; |
| final Tuple2<Long, Long> resultVertex = new Tuple2<Long, Long>(); |
| |
| @Override |
| public void reduce(Iterable<Tuple2<Long, Long>> values, Collector<Tuple2<Long, Long>> out) { |
| Long vertexId = 0L; |
| Long minimumCompId = Long.MAX_VALUE; |
| |
| for (Tuple2<Long, Long> value : values) { |
| vertexId = value.f0; |
| Long candidateCompId = value.f1; |
| if (candidateCompId < minimumCompId) { |
| minimumCompId = candidateCompId; |
| } |
| } |
| resultVertex.f0 = vertexId; |
| resultVertex.f1 = minimumCompId; |
| |
| out.collect(resultVertex); |
| } |
| } |
| |
| private static final class MinimumIdFilter |
| extends RichFlatMapFunction< |
| Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> { |
| |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public void flatMap( |
| Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> vertexWithNewAndOldId, |
| Collector<Tuple2<Long, Long>> out) { |
| if (vertexWithNewAndOldId.f0.f1 < vertexWithNewAndOldId.f1.f1) { |
| out.collect(vertexWithNewAndOldId.f0); |
| } |
| } |
| } |
| } |