[FLINK-18683][table-common] Support @DataTypeHint for table/aggregate function output types
Allows to use @DataTypeHint(...) as a synonym @FunctionHint(output = @DataTypeHint(...)) for
table and imperative aggregate functions.
This closes #13149.
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/DataTypeHint.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/DataTypeHint.java
index 5f191c9..0a3f8e0 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/DataTypeHint.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/DataTypeHint.java
@@ -67,6 +67,9 @@
* class is annotated with {@code @DataTypeHint(defaultDecimalPrecision = 12, defaultDecimalScale = 2)}. Individual
* field annotations allow to deviate from those default values.
*
+ * <p>A data type hint on top of a table or aggregate function is similar to defining {@link FunctionHint#output()}
+ * for the output type of the function.
+ *
* @see FunctionHint
*/
@PublicEvolving
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
index 1c6875c..dd9dc6a 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
@@ -35,6 +35,7 @@
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
+import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -45,6 +46,8 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;
+import static org.apache.flink.table.types.extraction.ExtractionUtils.collectAnnotationsOfClass;
+import static org.apache.flink.table.types.extraction.ExtractionUtils.collectAnnotationsOfMethod;
import static org.apache.flink.table.types.extraction.ExtractionUtils.collectMethods;
import static org.apache.flink.table.types.extraction.ExtractionUtils.createMethodSignatureString;
import static org.apache.flink.table.types.extraction.ExtractionUtils.extractionError;
@@ -437,9 +440,30 @@
/**
* Extraction that uses a generic type variable for producing a {@link FunctionResultTemplate}.
+ *
+ * <p>If enabled, a {@link DataTypeHint} from method or class has higher priority.
*/
- static ResultExtraction createGenericResultExtraction(Class<? extends UserDefinedFunction> baseClass, int genericPos) {
+ static ResultExtraction createGenericResultExtraction(
+ Class<? extends UserDefinedFunction> baseClass,
+ int genericPos,
+ boolean allowDataTypeHint) {
return (extractor, method) -> {
+ if (allowDataTypeHint) {
+ final Set<DataTypeHint> dataTypeHints = new HashSet<>();
+ dataTypeHints.addAll(collectAnnotationsOfMethod(DataTypeHint.class, method));
+ dataTypeHints.addAll(collectAnnotationsOfClass(DataTypeHint.class, extractor.function));
+ if (dataTypeHints.size() > 1) {
+ throw extractionError(
+ "More than one data type hint found for output of function. " +
+ "Please use a function hint instead.");
+ }
+ if (dataTypeHints.size() == 1) {
+ return FunctionTemplate.createResultTemplate(
+ extractor.typeFactory,
+ dataTypeHints.iterator().next());
+ }
+ // otherwise continue with regular extraction
+ }
final DataType dataType = DataTypeExtractor.extractFromGeneric(
extractor.typeFactory,
baseClass,
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java
index 01109f2..e313543 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java
@@ -74,6 +74,28 @@
);
}
+ /**
+ * Creates an instance of {@link FunctionResultTemplate} from a {@link DataTypeHint}.
+ */
+ static @Nullable FunctionResultTemplate createResultTemplate(
+ DataTypeFactory typeFactory,
+ @Nullable DataTypeHint hint) {
+ if (hint == null) {
+ return null;
+ }
+ final DataTypeTemplate template;
+ try {
+ template = DataTypeTemplate.fromAnnotation(typeFactory, hint);
+ } catch (Throwable t) {
+ throw extractionError(t, "Error in data type hint annotation.");
+ }
+ if (template.dataType != null) {
+ return FunctionResultTemplate.of(template.dataType);
+ }
+ throw extractionError(
+ "Data type hint does not specify a data type for use as function result.");
+ }
+
@Nullable FunctionSignatureTemplate getSignatureTemplate() {
return signatureTemplate;
}
@@ -141,25 +163,6 @@
argumentNames);
}
- private static @Nullable FunctionResultTemplate createResultTemplate(
- DataTypeFactory typeFactory,
- @Nullable DataTypeHint hint) {
- if (hint == null) {
- return null;
- }
- final DataTypeTemplate template;
- try {
- template = DataTypeTemplate.fromAnnotation(typeFactory, hint);
- } catch (Throwable t) {
- throw extractionError(t, "Error in data type hint annotation.");
- }
- if (template.dataType != null) {
- return FunctionResultTemplate.of(template.dataType);
- }
- throw extractionError(
- "Data type hint does not specify a data type for use as function result.");
- }
-
private static FunctionArgumentTemplate createArgumentTemplate(
DataTypeFactory typeFactory,
DataTypeHint hint) {
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java
index fd95927..e486a3b 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java
@@ -96,8 +96,8 @@
function,
UserDefinedFunctionHelper.AGGREGATE_ACCUMULATE,
createParameterSignatureExtraction(1),
- createGenericResultExtraction(AggregateFunction.class, 1),
- createGenericResultExtraction(AggregateFunction.class, 0),
+ createGenericResultExtraction(AggregateFunction.class, 1, false),
+ createGenericResultExtraction(AggregateFunction.class, 0, true),
createParameterWithAccumulatorVerification());
return extractTypeInference(mappingExtractor);
}
@@ -114,7 +114,7 @@
UserDefinedFunctionHelper.TABLE_EVAL,
createParameterSignatureExtraction(0),
null,
- createGenericResultExtraction(TableFunction.class, 0),
+ createGenericResultExtraction(TableFunction.class, 0, true),
createParameterVerification());
return extractTypeInference(mappingExtractor);
}
@@ -130,8 +130,8 @@
function,
UserDefinedFunctionHelper.TABLE_AGGREGATE_ACCUMULATE,
createParameterSignatureExtraction(1),
- createGenericResultExtraction(TableAggregateFunction.class, 1),
- createGenericResultExtraction(TableAggregateFunction.class, 0),
+ createGenericResultExtraction(TableAggregateFunction.class, 1, false),
+ createGenericResultExtraction(TableAggregateFunction.class, 0, true),
createParameterWithAccumulatorVerification());
return extractTypeInference(mappingExtractor);
}
@@ -148,7 +148,7 @@
UserDefinedFunctionHelper.ASYNC_TABLE_EVAL,
createParameterSignatureExtraction(1),
null,
- createGenericResultExtraction(AsyncTableFunction.class, 0),
+ createGenericResultExtraction(AsyncTableFunction.class, 0, true),
createParameterWithArgumentVerification(CompletableFuture.class));
return extractTypeInference(mappingExtractor);
}
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java
index aeb7918..ca7bd70 100644
--- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java
@@ -23,6 +23,7 @@
import org.apache.flink.table.annotation.InputGroup;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.functions.TableAggregateFunction;
@@ -379,7 +380,52 @@
TypeStrategies.explicit(DataTypes.BIGINT()))
.expectOutputMapping(
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.INT())),
- TypeStrategies.explicit(DataTypes.INT()))
+ TypeStrategies.explicit(DataTypes.INT())),
+
+ TestSpec
+ .forTableFunction(
+ "A data type hint on the class is used instead of a function output hint",
+ DataTypeHintOnTableFunctionClass.class)
+ .expectNamedArguments()
+ .expectTypedArguments()
+ .expectOutputMapping(
+ InputTypeStrategies.sequence(
+ new String[]{},
+ new ArgumentTypeStrategy[]{}),
+ TypeStrategies.explicit(DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())))),
+
+ TestSpec
+ .forTableFunction(
+ "A data type hint on the method is used instead of a function output hint",
+ DataTypeHintOnTableFunctionMethod.class)
+ .expectNamedArguments("i")
+ .expectTypedArguments(DataTypes.INT())
+ .expectOutputMapping(
+ InputTypeStrategies.sequence(
+ new String[]{"i"},
+ new ArgumentTypeStrategy[]{InputTypeStrategies.explicit(DataTypes.INT())}),
+ TypeStrategies.explicit(DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())))),
+
+ TestSpec
+ .forTableFunction(
+ "Invalid data type hint on top of method and class",
+ InvalidDataTypeHintOnTableFunction.class)
+ .expectErrorMessage(
+ "More than one data type hint found for output of function. " +
+ "Please use a function hint instead."),
+
+ TestSpec
+ .forScalarFunction(
+ "A data type hint on the method is used for enriching (not a function output hint)",
+ DataTypeHintOnScalarFunction.class)
+ .expectNamedArguments()
+ .expectTypedArguments()
+ .expectOutputMapping(
+ InputTypeStrategies.sequence(
+ new String[]{},
+ new ArgumentTypeStrategy[]{}),
+ TypeStrategies.explicit(
+ DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())).bridgedTo(RowData.class)))
);
}
@@ -861,4 +907,32 @@
return n;
}
}
+
+ @DataTypeHint("ROW<i INT>")
+ private static class DataTypeHintOnTableFunctionClass extends TableFunction<Row> {
+ public void eval() {
+ // nothing to do
+ }
+ }
+
+ private static class DataTypeHintOnTableFunctionMethod extends TableFunction<Row> {
+ @DataTypeHint("ROW<i INT>")
+ public void eval(Integer i) {
+ // nothing to do
+ }
+ }
+
+ @DataTypeHint("ROW<i BOOLEAN>")
+ private static class InvalidDataTypeHintOnTableFunction extends TableFunction<Row> {
+ @DataTypeHint("ROW<i INT>")
+ public void eval(Integer i) {
+ // nothing to do
+ }
+ }
+
+ private static class DataTypeHintOnScalarFunction extends ScalarFunction {
+ public @DataTypeHint("ROW<i INT>") RowData eval() {
+ return null;
+ }
+ }
}