| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.test.operators; |
| |
| import org.apache.flink.api.common.functions.CoGroupFunction; |
| import org.apache.flink.api.common.functions.FlatJoinFunction; |
| import org.apache.flink.api.common.functions.FlatMapFunction; |
| import org.apache.flink.api.common.functions.GroupCombineFunction; |
| import org.apache.flink.api.common.functions.GroupReduceFunction; |
| import org.apache.flink.api.common.functions.JoinFunction; |
| import org.apache.flink.api.common.functions.MapFunction; |
| import org.apache.flink.api.common.operators.Order; |
| import org.apache.flink.api.common.typeinfo.BasicTypeInfo; |
| import org.apache.flink.api.common.typeinfo.TypeHint; |
| import org.apache.flink.api.java.DataSet; |
| import org.apache.flink.api.java.ExecutionEnvironment; |
| import org.apache.flink.api.java.tuple.Tuple3; |
| import org.apache.flink.api.java.typeutils.TupleTypeInfo; |
| import org.apache.flink.test.operators.util.CollectionDataSets; |
| import org.apache.flink.test.util.AbstractTestBase; |
| import org.apache.flink.util.Collector; |
| |
| import org.junit.Test; |
| |
| import java.util.List; |
| |
| /** Integration tests for {@link org.apache.flink.api.common.typeinfo.TypeHint}. */ |
| public class TypeHintITCase extends AbstractTestBase { |
| |
| @Test |
| public void testIdentityMapWithMissingTypesAndStringTypeHint() throws Exception { |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| |
| DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Tuple3<Integer, Long, String>> identityMapDs = |
| ds.map(new Mapper<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>>()) |
| .returns(new TypeHint<Tuple3<Integer, Long, String>>() {}); |
| List<Tuple3<Integer, Long, String>> result = identityMapDs.collect(); |
| |
| String expectedResult = "(2,2,Hello)\n" + "(3,2,Hello world)\n" + "(1,1,Hi)\n"; |
| |
| compareResultAsText(result, expectedResult); |
| } |
| |
| @Test |
| public void testIdentityMapWithMissingTypesAndTypeInformationTypeHint() throws Exception { |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| |
| DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Tuple3<Integer, Long, String>> identityMapDs = |
| ds |
| // all following generics get erased during compilation |
| .map( |
| new Mapper< |
| Tuple3<Integer, Long, String>, |
| Tuple3<Integer, Long, String>>()) |
| .returns( |
| new TupleTypeInfo<Tuple3<Integer, Long, String>>( |
| BasicTypeInfo.INT_TYPE_INFO, |
| BasicTypeInfo.LONG_TYPE_INFO, |
| BasicTypeInfo.STRING_TYPE_INFO)); |
| List<Tuple3<Integer, Long, String>> result = identityMapDs.collect(); |
| |
| String expectedResult = "(2,2,Hello)\n" + "(3,2,Hello world)\n" + "(1,1,Hi)\n"; |
| |
| compareResultAsText(result, expectedResult); |
| } |
| |
| @Test |
| public void testFlatMapWithClassTypeHint() throws Exception { |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| |
| DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Integer> identityMapDs = |
| ds.flatMap(new FlatMapper<Tuple3<Integer, Long, String>, Integer>()) |
| .returns(Integer.class); |
| List<Integer> result = identityMapDs.collect(); |
| |
| String expectedResult = "2\n" + "3\n" + "1\n"; |
| |
| compareResultAsText(result, expectedResult); |
| } |
| |
| @Test |
| public void testJoinWithTypeInformationTypeHint() throws Exception { |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| |
| DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Integer> resultDs = |
| ds1.join(ds2) |
| .where(0) |
| .equalTo(0) |
| .with( |
| new Joiner< |
| Tuple3<Integer, Long, String>, |
| Tuple3<Integer, Long, String>, |
| Integer>()) |
| .returns(BasicTypeInfo.INT_TYPE_INFO); |
| List<Integer> result = resultDs.collect(); |
| |
| String expectedResult = "2\n" + "3\n" + "1\n"; |
| |
| compareResultAsText(result, expectedResult); |
| } |
| |
| @Test |
| public void testFlatJoinWithTypeInformationTypeHint() throws Exception { |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| |
| DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Integer> resultDs = |
| ds1.join(ds2) |
| .where(0) |
| .equalTo(0) |
| .with( |
| new FlatJoiner< |
| Tuple3<Integer, Long, String>, |
| Tuple3<Integer, Long, String>, |
| Integer>()) |
| .returns(BasicTypeInfo.INT_TYPE_INFO); |
| List<Integer> result = resultDs.collect(); |
| |
| String expectedResult = "2\n" + "3\n" + "1\n"; |
| |
| compareResultAsText(result, expectedResult); |
| } |
| |
| @Test |
| public void testUnsortedGroupReduceWithTypeInformationTypeHint() throws Exception { |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| |
| DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Integer> resultDs = |
| ds.groupBy(0) |
| .reduceGroup(new GroupReducer<Tuple3<Integer, Long, String>, Integer>()) |
| .returns(BasicTypeInfo.INT_TYPE_INFO); |
| List<Integer> result = resultDs.collect(); |
| |
| String expectedResult = "2\n" + "3\n" + "1\n"; |
| |
| compareResultAsText(result, expectedResult); |
| } |
| |
| @Test |
| public void testSortedGroupReduceWithTypeInformationTypeHint() throws Exception { |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| |
| DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Integer> resultDs = |
| ds.groupBy(0) |
| .sortGroup(0, Order.ASCENDING) |
| .reduceGroup(new GroupReducer<Tuple3<Integer, Long, String>, Integer>()) |
| .returns(BasicTypeInfo.INT_TYPE_INFO); |
| List<Integer> result = resultDs.collect(); |
| |
| String expectedResult = "2\n" + "3\n" + "1\n"; |
| |
| compareResultAsText(result, expectedResult); |
| } |
| |
| @Test |
| public void testCombineGroupWithTypeInformationTypeHint() throws Exception { |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| |
| DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Integer> resultDs = |
| ds.groupBy(0) |
| .combineGroup(new GroupCombiner<Tuple3<Integer, Long, String>, Integer>()) |
| .returns(BasicTypeInfo.INT_TYPE_INFO); |
| List<Integer> result = resultDs.collect(); |
| |
| String expectedResult = "2\n" + "3\n" + "1\n"; |
| |
| compareResultAsText(result, expectedResult); |
| } |
| |
| @Test |
| public void testCoGroupWithTypeInformationTypeHint() throws Exception { |
| final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); |
| |
| DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env); |
| DataSet<Integer> resultDs = |
| ds1.coGroup(ds2) |
| .where(0) |
| .equalTo(0) |
| .with( |
| new CoGrouper< |
| Tuple3<Integer, Long, String>, |
| Tuple3<Integer, Long, String>, |
| Integer>()) |
| .returns(BasicTypeInfo.INT_TYPE_INFO); |
| List<Integer> result = resultDs.collect(); |
| |
| String expectedResult = "2\n" + "3\n" + "1\n"; |
| |
| compareResultAsText(result, expectedResult); |
| } |
| |
| // -------------------------------------------------------------------------------------------- |
| |
| private static class Mapper<T, V> implements MapFunction<T, V> { |
| private static final long serialVersionUID = 1L; |
| |
| @SuppressWarnings("unchecked") |
| @Override |
| public V map(T value) throws Exception { |
| return (V) value; |
| } |
| } |
| |
| private static class FlatMapper<T, V> implements FlatMapFunction<T, V> { |
| private static final long serialVersionUID = 1L; |
| |
| @SuppressWarnings({"unchecked", "rawtypes"}) |
| @Override |
| public void flatMap(T value, Collector<V> out) throws Exception { |
| out.collect((V) ((Tuple3) value).f0); |
| } |
| } |
| |
| private static class Joiner<IN1, IN2, OUT> implements JoinFunction<IN1, IN2, OUT> { |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public OUT join(IN1 first, IN2 second) throws Exception { |
| return (OUT) ((Tuple3) first).f0; |
| } |
| } |
| |
| private static class FlatJoiner<IN1, IN2, OUT> implements FlatJoinFunction<IN1, IN2, OUT> { |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public void join(IN1 first, IN2 second, Collector<OUT> out) throws Exception { |
| out.collect((OUT) ((Tuple3) first).f0); |
| } |
| } |
| |
| private static class GroupReducer<IN, OUT> implements GroupReduceFunction<IN, OUT> { |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public void reduce(Iterable<IN> values, Collector<OUT> out) throws Exception { |
| out.collect((OUT) ((Tuple3) values.iterator().next()).f0); |
| } |
| } |
| |
| private static class GroupCombiner<IN, OUT> implements GroupCombineFunction<IN, OUT> { |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public void combine(Iterable<IN> values, Collector<OUT> out) throws Exception { |
| out.collect((OUT) ((Tuple3) values.iterator().next()).f0); |
| } |
| } |
| |
| private static class CoGrouper<IN1, IN2, OUT> implements CoGroupFunction<IN1, IN2, OUT> { |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| public void coGroup(Iterable<IN1> first, Iterable<IN2> second, Collector<OUT> out) |
| throws Exception { |
| out.collect((OUT) ((Tuple3) first.iterator().next()).f0); |
| } |
| } |
| } |