| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.runtime.operators; |
| |
| import org.apache.flink.api.common.ExecutionConfig; |
| import org.apache.flink.api.common.functions.GroupCombineFunction; |
| import org.apache.flink.api.common.functions.GroupReduceFunction; |
| import org.apache.flink.api.common.functions.RichGroupReduceFunction; |
| import org.apache.flink.api.common.typeutils.TypeComparator; |
| import org.apache.flink.api.common.typeutils.TypeSerializer; |
| import org.apache.flink.api.common.typeutils.base.IntComparator; |
| import org.apache.flink.api.common.typeutils.base.IntSerializer; |
| import org.apache.flink.api.java.tuple.Tuple2; |
| import org.apache.flink.api.java.typeutils.runtime.TupleComparator; |
| import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; |
| import org.apache.flink.runtime.operators.testutils.DelayingIterator; |
| import org.apache.flink.runtime.operators.testutils.DiscardingOutputCollector; |
| import org.apache.flink.runtime.operators.testutils.ExpectedTestException; |
| import org.apache.flink.runtime.operators.testutils.InfiniteIntTupleIterator; |
| import org.apache.flink.runtime.operators.testutils.UnaryOperatorTestBase; |
| import org.apache.flink.runtime.operators.testutils.UniformIntTupleGenerator; |
| import org.apache.flink.util.Collector; |
| import org.apache.flink.util.MutableObjectIterator; |
| |
| import org.junit.Test; |
| |
| import java.util.ArrayList; |
| |
| import static org.junit.Assert.*; |
| |
| public class CombineTaskTest |
| extends UnaryOperatorTestBase< |
| RichGroupReduceFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>, |
| Tuple2<Integer, Integer>, |
| Tuple2<Integer, Integer>> { |
| |
| private static final long COMBINE_MEM = 3 * 1024 * 1024; |
| |
| private final double combine_frac; |
| |
| private final ArrayList<Tuple2<Integer, Integer>> outList = new ArrayList<>(); |
| |
| @SuppressWarnings("unchecked") |
| private final TypeSerializer<Tuple2<Integer, Integer>> serializer = |
| new TupleSerializer<>( |
| (Class<Tuple2<Integer, Integer>>) (Class<?>) Tuple2.class, |
| new TypeSerializer<?>[] {IntSerializer.INSTANCE, IntSerializer.INSTANCE}); |
| |
| private final TypeComparator<Tuple2<Integer, Integer>> comparator = |
| new TupleComparator<>( |
| new int[] {0}, |
| new TypeComparator<?>[] {new IntComparator(true)}, |
| new TypeSerializer<?>[] {IntSerializer.INSTANCE}); |
| |
| public CombineTaskTest(ExecutionConfig config) { |
| super(config, COMBINE_MEM, 0); |
| |
| combine_frac = (double) COMBINE_MEM / this.getMemoryManager().getMemorySize(); |
| } |
| |
| @Test |
| public void testCombineTask() { |
| try { |
| int keyCnt = 100; |
| int valCnt = 20; |
| |
| setInput(new UniformIntTupleGenerator(keyCnt, valCnt, false), serializer); |
| addDriverComparator(this.comparator); |
| addDriverComparator(this.comparator); |
| setOutput(this.outList, serializer); |
| |
| getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE); |
| getTaskConfig().setRelativeMemoryDriver(combine_frac); |
| getTaskConfig().setFilehandlesDriver(2); |
| |
| final GroupReduceCombineDriver<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> |
| testTask = new GroupReduceCombineDriver<>(); |
| |
| testDriver(testTask, MockCombiningReduceStub.class); |
| |
| int expSum = 0; |
| for (int i = 1; i < valCnt; i++) { |
| expSum += i; |
| } |
| |
| assertTrue(this.outList.size() == keyCnt); |
| |
| for (Tuple2<Integer, Integer> record : this.outList) { |
| assertTrue(record.f1 == expSum); |
| } |
| |
| this.outList.clear(); |
| } catch (Exception e) { |
| e.printStackTrace(); |
| fail(e.getMessage()); |
| } |
| } |
| |
| @Test |
| public void testFailingCombineTask() { |
| try { |
| int keyCnt = 100; |
| int valCnt = 20; |
| |
| setInput(new UniformIntTupleGenerator(keyCnt, valCnt, false), serializer); |
| addDriverComparator(this.comparator); |
| addDriverComparator(this.comparator); |
| setOutput(new DiscardingOutputCollector<Tuple2<Integer, Integer>>()); |
| |
| getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE); |
| getTaskConfig().setRelativeMemoryDriver(combine_frac); |
| getTaskConfig().setFilehandlesDriver(2); |
| |
| final GroupReduceCombineDriver<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> |
| testTask = new GroupReduceCombineDriver<>(); |
| |
| try { |
| testDriver(testTask, MockFailingCombiningReduceStub.class); |
| fail("Exception not forwarded."); |
| } catch (ExpectedTestException etex) { |
| // good! |
| } |
| } catch (Exception e) { |
| e.printStackTrace(); |
| fail(e.getMessage()); |
| } |
| } |
| |
| @Test |
| public void testCancelCombineTaskSorting() { |
| try { |
| MutableObjectIterator<Tuple2<Integer, Integer>> slowInfiniteInput = |
| new DelayingIterator<>(new InfiniteIntTupleIterator(), 1); |
| |
| setInput(slowInfiniteInput, serializer); |
| addDriverComparator(this.comparator); |
| addDriverComparator(this.comparator); |
| setOutput(new DiscardingOutputCollector<Tuple2<Integer, Integer>>()); |
| |
| getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE); |
| getTaskConfig().setRelativeMemoryDriver(combine_frac); |
| getTaskConfig().setFilehandlesDriver(2); |
| |
| final GroupReduceCombineDriver<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> |
| testTask = new GroupReduceCombineDriver<>(); |
| |
| Thread taskRunner = |
| new Thread() { |
| @Override |
| public void run() { |
| try { |
| testDriver(testTask, MockFailingCombiningReduceStub.class); |
| } catch (Exception e) { |
| // exceptions may happen during canceling |
| } |
| } |
| }; |
| taskRunner.start(); |
| |
| // give the task some time |
| Thread.sleep(500); |
| |
| // cancel |
| testTask.cancel(); |
| |
| // make sure it reacts to the canceling in some time |
| long deadline = System.currentTimeMillis() + 10000; |
| do { |
| taskRunner.interrupt(); |
| taskRunner.join(5000); |
| } while (taskRunner.isAlive() && System.currentTimeMillis() < deadline); |
| |
| assertFalse("Task did not cancel properly within in 10 seconds.", taskRunner.isAlive()); |
| } catch (Exception e) { |
| e.printStackTrace(); |
| fail(e.getMessage()); |
| } |
| } |
| |
| // ------------------------------------------------------------------------ |
| // Test Combiners |
| // ------------------------------------------------------------------------ |
| |
| public static class MockCombiningReduceStub |
| implements GroupReduceFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>, |
| GroupCombineFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> { |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public void reduce( |
| Iterable<Tuple2<Integer, Integer>> records, |
| Collector<Tuple2<Integer, Integer>> out) { |
| int key = 0; |
| int sum = 0; |
| |
| for (Tuple2<Integer, Integer> next : records) { |
| key = next.f0; |
| sum += next.f1; |
| } |
| |
| out.collect(new Tuple2<>(key, sum)); |
| } |
| |
| @Override |
| public void combine( |
| Iterable<Tuple2<Integer, Integer>> records, |
| Collector<Tuple2<Integer, Integer>> out) { |
| reduce(records, out); |
| } |
| } |
| |
| public static final class MockFailingCombiningReduceStub |
| implements GroupReduceFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>, |
| GroupCombineFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> { |
| private static final long serialVersionUID = 1L; |
| |
| private int cnt; |
| |
| @Override |
| public void reduce( |
| Iterable<Tuple2<Integer, Integer>> records, |
| Collector<Tuple2<Integer, Integer>> out) { |
| int key = 0; |
| int sum = 0; |
| |
| for (Tuple2<Integer, Integer> next : records) { |
| key = next.f0; |
| sum += next.f1; |
| } |
| |
| int resultValue = sum - key; |
| out.collect(new Tuple2<>(key, resultValue)); |
| } |
| |
| @Override |
| public void combine( |
| Iterable<Tuple2<Integer, Integer>> records, |
| Collector<Tuple2<Integer, Integer>> out) { |
| int key = 0; |
| int sum = 0; |
| |
| for (Tuple2<Integer, Integer> next : records) { |
| key = next.f0; |
| sum += next.f1; |
| } |
| |
| if (++this.cnt >= 10) { |
| throw new ExpectedTestException(); |
| } |
| |
| int resultValue = sum - key; |
| out.collect(new Tuple2<>(key, resultValue)); |
| } |
| } |
| } |