/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
| package org.apache.flink.test.manual; |
| |
| import org.apache.flink.api.common.functions.CrossFunction; |
| import org.apache.flink.api.common.functions.FilterFunction; |
| import org.apache.flink.api.common.functions.JoinFunction; |
| import org.apache.flink.api.common.functions.ReduceFunction; |
| import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint; |
| import org.apache.flink.api.java.DataSet; |
| import org.apache.flink.api.java.ExecutionEnvironment; |
| import org.apache.flink.api.java.tuple.Tuple2; |
| import org.apache.flink.api.java.typeutils.TupleTypeInfo; |
| import org.apache.flink.types.IntValue; |
| |
| import org.junit.Assert; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import java.io.Serializable; |
| import java.util.Collections; |
| import java.util.Comparator; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Random; |
| |
| import static org.hamcrest.Matchers.is; |
| |
| /** |
| * These programs demonstrate the effects of user defined functions which modify input objects or |
| * return locally created objects that are retained and reused on future calls. The programs do not |
| * retain and later modify input objects. |
| */ |
| public class OverwriteObjects { |
| |
| public static final Logger LOG = LoggerFactory.getLogger(OverwriteObjects.class); |
| |
| // DataSets are created with this number of elements |
| private static final int NUMBER_OF_ELEMENTS = 3_000_000; |
| |
| // DataSet values are randomly generated over this range |
| private static final int KEY_RANGE = 1_000_000; |
| |
| private static final int MAX_PARALLELISM = 4; |
| |
| private static final long RANDOM_SEED = new Random().nextLong(); |
| |
| private static final Tuple2Comparator<IntValue, IntValue> comparator = new Tuple2Comparator<>(); |
| |
| public static void main(String[] args) throws Exception { |
| new OverwriteObjects().run(); |
| } |
| |
| public void run() throws Exception { |
| LOG.info("Random seed = {}", RANDOM_SEED); |
| |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| |
| for (int parallelism = MAX_PARALLELISM; parallelism > 0; parallelism--) { |
| LOG.info("Parallelism = {}", parallelism); |
| |
| env.setParallelism(parallelism); |
| |
| testReduce(env); |
| testGroupedReduce(env); |
| testJoin(env); |
| testCross(env); |
| } |
| } |
| |
| // -------------------------------------------------------------------------------------------- |
| |
| public void testReduce(ExecutionEnvironment env) throws Exception { |
| /* |
| * Test ChainedAllReduceDriver |
| */ |
| |
| LOG.info("Testing reduce"); |
| |
| env.getConfig().enableObjectReuse(); |
| |
| Tuple2<IntValue, IntValue> enabledResult = |
| getDataSet(env).reduce(new OverwriteObjectsReduce(false)).collect().get(0); |
| |
| env.getConfig().disableObjectReuse(); |
| |
| Tuple2<IntValue, IntValue> disabledResult = |
| getDataSet(env).reduce(new OverwriteObjectsReduce(false)).collect().get(0); |
| |
| Assert.assertEquals(NUMBER_OF_ELEMENTS, enabledResult.f1.getValue()); |
| Assert.assertEquals(NUMBER_OF_ELEMENTS, disabledResult.f1.getValue()); |
| |
| Assert.assertEquals(disabledResult, enabledResult); |
| } |
| |
| public void testGroupedReduce(ExecutionEnvironment env) throws Exception { |
| /* |
| * Test ReduceCombineDriver and ReduceDriver |
| */ |
| |
| LOG.info("Testing grouped reduce"); |
| |
| env.getConfig().enableObjectReuse(); |
| |
| List<Tuple2<IntValue, IntValue>> enabledResult = |
| getDataSet(env).groupBy(0).reduce(new OverwriteObjectsReduce(true)).collect(); |
| |
| Collections.sort(enabledResult, comparator); |
| |
| env.getConfig().disableObjectReuse(); |
| |
| List<Tuple2<IntValue, IntValue>> disabledResult = |
| getDataSet(env).groupBy(0).reduce(new OverwriteObjectsReduce(true)).collect(); |
| |
| Collections.sort(disabledResult, comparator); |
| |
| Assert.assertThat(disabledResult, is(enabledResult)); |
| } |
| |
| private class OverwriteObjectsReduce implements ReduceFunction<Tuple2<IntValue, IntValue>> { |
| private Scrambler scrambler; |
| |
| public OverwriteObjectsReduce(boolean keyed) { |
| scrambler = new Scrambler(keyed); |
| } |
| |
| @Override |
| public Tuple2<IntValue, IntValue> reduce( |
| Tuple2<IntValue, IntValue> a, Tuple2<IntValue, IntValue> b) throws Exception { |
| return scrambler.scramble(a, b); |
| } |
| } |
| |
| // -------------------------------------------------------------------------------------------- |
| |
| public void testJoin(ExecutionEnvironment env) throws Exception { |
| /* |
| * Test JoinDriver, LeftOuterJoinDriver, RightOuterJoinDriver, and FullOuterJoinDriver |
| */ |
| |
| for (JoinHint joinHint : JoinHint.values()) { |
| if (joinHint == JoinHint.OPTIMIZER_CHOOSES) { |
| continue; |
| } |
| |
| List<Tuple2<IntValue, IntValue>> enabledResult; |
| |
| List<Tuple2<IntValue, IntValue>> disabledResult; |
| |
| // Inner join |
| |
| LOG.info("Testing inner join with JoinHint = {}", joinHint); |
| |
| env.getConfig().enableObjectReuse(); |
| |
| enabledResult = |
| getDataSet(env) |
| .join(getDataSet(env), joinHint) |
| .where(0) |
| .equalTo(0) |
| .with(new OverwriteObjectsJoin()) |
| .collect(); |
| |
| Collections.sort(enabledResult, comparator); |
| |
| env.getConfig().disableObjectReuse(); |
| |
| disabledResult = |
| getDataSet(env) |
| .join(getDataSet(env), joinHint) |
| .where(0) |
| .equalTo(0) |
| .with(new OverwriteObjectsJoin()) |
| .collect(); |
| |
| Collections.sort(disabledResult, comparator); |
| |
| Assert.assertEquals("JoinHint=" + joinHint, disabledResult, enabledResult); |
| |
| // Left outer join |
| |
| if (joinHint != JoinHint.BROADCAST_HASH_FIRST) { |
| LOG.info("Testing left outer join with JoinHint = {}", joinHint); |
| |
| env.getConfig().enableObjectReuse(); |
| |
| enabledResult = |
| getDataSet(env) |
| .leftOuterJoin(getFilteredDataSet(env), joinHint) |
| .where(0) |
| .equalTo(0) |
| .with(new OverwriteObjectsJoin()) |
| .collect(); |
| |
| Collections.sort(enabledResult, comparator); |
| |
| env.getConfig().disableObjectReuse(); |
| |
| disabledResult = |
| getDataSet(env) |
| .leftOuterJoin(getFilteredDataSet(env), joinHint) |
| .where(0) |
| .equalTo(0) |
| .with(new OverwriteObjectsJoin()) |
| .collect(); |
| |
| Collections.sort(disabledResult, comparator); |
| |
| Assert.assertThat("JoinHint=" + joinHint, disabledResult, is(enabledResult)); |
| } |
| |
| // Right outer join |
| |
| if (joinHint != JoinHint.BROADCAST_HASH_SECOND) { |
| LOG.info("Testing right outer join with JoinHint = {}", joinHint); |
| |
| env.getConfig().enableObjectReuse(); |
| |
| enabledResult = |
| getDataSet(env) |
| .rightOuterJoin(getFilteredDataSet(env), joinHint) |
| .where(0) |
| .equalTo(0) |
| .with(new OverwriteObjectsJoin()) |
| .collect(); |
| |
| Collections.sort(enabledResult, comparator); |
| |
| env.getConfig().disableObjectReuse(); |
| |
| disabledResult = |
| getDataSet(env) |
| .rightOuterJoin(getFilteredDataSet(env), joinHint) |
| .where(0) |
| .equalTo(0) |
| .with(new OverwriteObjectsJoin()) |
| .collect(); |
| |
| Collections.sort(disabledResult, comparator); |
| |
| Assert.assertThat("JoinHint=" + joinHint, disabledResult, is(enabledResult)); |
| } |
| |
| // Full outer join |
| |
| if (joinHint != JoinHint.BROADCAST_HASH_FIRST |
| && joinHint != JoinHint.BROADCAST_HASH_SECOND) { |
| LOG.info("Testing full outer join with JoinHint = {}", joinHint); |
| |
| env.getConfig().enableObjectReuse(); |
| |
| enabledResult = |
| getDataSet(env) |
| .fullOuterJoin(getFilteredDataSet(env), joinHint) |
| .where(0) |
| .equalTo(0) |
| .with(new OverwriteObjectsJoin()) |
| .collect(); |
| |
| Collections.sort(enabledResult, comparator); |
| |
| env.getConfig().disableObjectReuse(); |
| |
| disabledResult = |
| getDataSet(env) |
| .fullOuterJoin(getFilteredDataSet(env), joinHint) |
| .where(0) |
| .equalTo(0) |
| .with(new OverwriteObjectsJoin()) |
| .collect(); |
| |
| Collections.sort(disabledResult, comparator); |
| |
| Assert.assertThat("JoinHint=" + joinHint, disabledResult, is(enabledResult)); |
| } |
| } |
| } |
| |
| private class OverwriteObjectsJoin |
| implements JoinFunction< |
| Tuple2<IntValue, IntValue>, |
| Tuple2<IntValue, IntValue>, |
| Tuple2<IntValue, IntValue>> { |
| private Scrambler scrambler = new Scrambler(true); |
| |
| @Override |
| public Tuple2<IntValue, IntValue> join( |
| Tuple2<IntValue, IntValue> a, Tuple2<IntValue, IntValue> b) throws Exception { |
| return scrambler.scramble(a, b); |
| } |
| } |
| |
| // -------------------------------------------------------------------------------------------- |
| |
| public void testCross(ExecutionEnvironment env) throws Exception { |
| /* |
| * Test CrossDriver |
| */ |
| |
| LOG.info("Testing cross"); |
| |
| DataSet<Tuple2<IntValue, IntValue>> small = getDataSet(env, 100, 20); |
| DataSet<Tuple2<IntValue, IntValue>> large = getDataSet(env, 10000, 2000); |
| |
| // test NESTEDLOOP_BLOCKED_OUTER_FIRST and NESTEDLOOP_BLOCKED_OUTER_SECOND with object reuse |
| // enabled |
| |
| env.getConfig().enableObjectReuse(); |
| |
| List<Tuple2<IntValue, IntValue>> enabledResultWithHuge = |
| small.crossWithHuge(large).with(new OverwriteObjectsCross()).collect(); |
| |
| List<Tuple2<IntValue, IntValue>> enabledResultWithTiny = |
| small.crossWithTiny(large).with(new OverwriteObjectsCross()).collect(); |
| |
| Assert.assertThat(enabledResultWithHuge, is(enabledResultWithTiny)); |
| |
| // test NESTEDLOOP_BLOCKED_OUTER_FIRST and NESTEDLOOP_BLOCKED_OUTER_SECOND with object reuse |
| // disabled |
| |
| env.getConfig().disableObjectReuse(); |
| |
| List<Tuple2<IntValue, IntValue>> disabledResultWithHuge = |
| small.crossWithHuge(large).with(new OverwriteObjectsCross()).collect(); |
| |
| List<Tuple2<IntValue, IntValue>> disabledResultWithTiny = |
| small.crossWithTiny(large).with(new OverwriteObjectsCross()).collect(); |
| |
| Assert.assertThat(disabledResultWithHuge, is(disabledResultWithTiny)); |
| |
| // verify match between object reuse enabled and disabled |
| Assert.assertThat(disabledResultWithHuge, is(enabledResultWithHuge)); |
| Assert.assertThat(disabledResultWithTiny, is(enabledResultWithTiny)); |
| } |
| |
| private class OverwriteObjectsCross |
| implements CrossFunction< |
| Tuple2<IntValue, IntValue>, |
| Tuple2<IntValue, IntValue>, |
| Tuple2<IntValue, IntValue>> { |
| private Scrambler scrambler = new Scrambler(true); |
| |
| @Override |
| public Tuple2<IntValue, IntValue> cross( |
| Tuple2<IntValue, IntValue> a, Tuple2<IntValue, IntValue> b) throws Exception { |
| return scrambler.scramble(a, b); |
| } |
| } |
| |
| // -------------------------------------------------------------------------------------------- |
| |
| private DataSet<Tuple2<IntValue, IntValue>> getDataSet( |
| ExecutionEnvironment env, int numberOfElements, int keyRange) { |
| return env.fromCollection( |
| new TupleIntValueIntValueIterator(numberOfElements, keyRange), |
| TupleTypeInfo.<Tuple2<IntValue, IntValue>>getBasicAndBasicValueTupleTypeInfo( |
| IntValue.class, IntValue.class)); |
| } |
| |
| private DataSet<Tuple2<IntValue, IntValue>> getDataSet(ExecutionEnvironment env) { |
| return getDataSet(env, NUMBER_OF_ELEMENTS, KEY_RANGE); |
| } |
| |
| private DataSet<Tuple2<IntValue, IntValue>> getFilteredDataSet(ExecutionEnvironment env) { |
| return getDataSet(env) |
| .filter( |
| new FilterFunction<Tuple2<IntValue, IntValue>>() { |
| @Override |
| public boolean filter(Tuple2<IntValue, IntValue> value) |
| throws Exception { |
| return (value.f0.getValue() % 2) == 0; |
| } |
| }); |
| } |
| |
| private static class TupleIntValueIntValueIterator |
| implements Iterator<Tuple2<IntValue, IntValue>>, Serializable { |
| private int numElements; |
| private final int keyRange; |
| private Tuple2<IntValue, IntValue> ret = new Tuple2<>(new IntValue(), new IntValue()); |
| |
| public TupleIntValueIntValueIterator(int numElements, int keyRange) { |
| this.numElements = numElements; |
| this.keyRange = keyRange; |
| } |
| |
| private final Random rnd = new Random(123); |
| |
| @Override |
| public boolean hasNext() { |
| return numElements > 0; |
| } |
| |
| @Override |
| public Tuple2<IntValue, IntValue> next() { |
| numElements--; |
| ret.f0.setValue(rnd.nextInt(keyRange)); |
| ret.f1.setValue(1); |
| return ret; |
| } |
| |
| @Override |
| public void remove() { |
| throw new UnsupportedOperationException(); |
| } |
| } |
| |
| private static class Tuple2Comparator<T0 extends Comparable<T0>, T1 extends Comparable<T1>> |
| implements Comparator<Tuple2<T0, T1>> { |
| @Override |
| public int compare(Tuple2<T0, T1> o1, Tuple2<T0, T1> o2) { |
| int cmp = o1.f0.compareTo(o2.f0); |
| |
| if (cmp != 0) { |
| return cmp; |
| } |
| |
| return o1.f1.compareTo(o2.f1); |
| } |
| } |
| |
| // -------------------------------------------------------------------------------------------- |
| |
| private static class Scrambler implements Serializable { |
| private Tuple2<IntValue, IntValue> d = new Tuple2<>(new IntValue(), new IntValue()); |
| |
| private final boolean keyed; |
| |
| public Scrambler(boolean keyed) { |
| this.keyed = keyed; |
| } |
| |
| public Tuple2<IntValue, IntValue> scramble( |
| Tuple2<IntValue, IntValue> a, Tuple2<IntValue, IntValue> b) { |
| /* |
| * Scramble all fields except returned object's key |
| * |
| * Randomly select among four return values: |
| * |
| * 0) return first object (a) |
| * 1) return second object (b) |
| * 2) return new object |
| * 3) return reused local object (d) |
| */ |
| |
| Random random = new Random(RANDOM_SEED); |
| |
| if (a != null && b != null) { |
| random.setSeed((((long) a.f0.getValue()) << 32) + b.f0.getValue()); |
| } else if (a != null) { |
| random.setSeed(a.f0.getValue()); |
| } else if (b != null) { |
| random.setSeed(b.f0.getValue()); |
| } else { |
| throw new RuntimeException("One of a or b should be not null"); |
| } |
| |
| Tuple2<IntValue, IntValue> result; |
| |
| switch (random.nextInt(4)) { |
| case 0: |
| result = a; |
| break; |
| case 1: |
| result = b; |
| break; |
| case 2: |
| result = d; |
| break; |
| case 3: |
| result = new Tuple2<>(new IntValue(), new IntValue()); |
| break; |
| default: |
| throw new RuntimeException("Unexpected value in switch statement"); |
| } |
| |
| if (a == null || b == null) { |
| // null values are seen when processing outer joins |
| if (result == null) { |
| result = d; |
| } |
| |
| if (a == null) { |
| b.f0.copyTo(result.f0); |
| b.f1.copyTo(result.f1); |
| } else { |
| a.f0.copyTo(result.f0); |
| a.f1.copyTo(result.f1); |
| } |
| } else { |
| if (keyed) { |
| result.f0.setValue(a.f0.getValue()); |
| } else { |
| result.f0.setValue(a.f0.getValue() + b.f0.getValue()); |
| } |
| |
| result.f1.setValue(a.f1.getValue() + b.f1.getValue()); |
| } |
| |
| scrambleIfNot(a, result); |
| scrambleIfNot(b, result); |
| scrambleIfNot(d, result); |
| |
| return result; |
| } |
| |
| private Random random = new Random(~RANDOM_SEED); |
| |
| private void scrambleIfNot(Tuple2<IntValue, IntValue> t, Object o) { |
| // verify that the tuple is not null and the same as the |
| // comparison object, then scramble the fields |
| if (t != null && t != o) { |
| t.f0.setValue(random.nextInt()); |
| t.f1.setValue(random.nextInt()); |
| } |
| } |
| } |
| } |