[FLINK-33563] Implement type inference for Agg functions
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
index fa1441f..6f34233 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
@@ -672,14 +672,16 @@
BuiltInFunctionDefinition.newBuilder()
.name("collect")
.kind(AGGREGATE)
- .outputTypeStrategy(TypeStrategies.MISSING)
+ .inputTypeStrategy(sequence(ANY))
+ .outputTypeStrategy(SpecificTypeStrategies.COLLECT)
.build();
public static final BuiltInFunctionDefinition DISTINCT =
BuiltInFunctionDefinition.newBuilder()
.name("distinct")
.kind(AGGREGATE)
- .outputTypeStrategy(TypeStrategies.MISSING)
+ .inputTypeStrategy(sequence(ANY))
+ .outputTypeStrategy(argument(0))
.build();
// --------------------------------------------------------------------------------------------
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CollectTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CollectTypeStrategy.java
new file mode 100644
index 0000000..df03ba7
--- /dev/null
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CollectTypeStrategy.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.inference.strategies;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.TypeStrategy;
+
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Type strategy that returns a non-null {@link DataTypes#MULTISET(DataType)} whose element type
+ * is the type of the single argument. Infers no type unless exactly one argument is given.
+ */
+@Internal
+public class CollectTypeStrategy implements TypeStrategy {
+
+ @Override
+ public Optional<DataType> inferType(CallContext callContext) {
+ List<DataType> argumentDataTypes = callContext.getArgumentDataTypes();
+ if (argumentDataTypes.size() != 1) {
+ return Optional.empty();
+ }
+
+ return Optional.of(DataTypes.MULTISET(argumentDataTypes.get(0)).notNull());
+ }
+}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
index 4fb0f64..bf9c41d 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
@@ -49,6 +49,9 @@
/** See {@link MapTypeStrategy}. */
public static final TypeStrategy MAP = new MapTypeStrategy();
+ /** See {@link CollectTypeStrategy}. */
+ public static final TypeStrategy COLLECT = new CollectTypeStrategy();
+
/** See {@link IfNullTypeStrategy}. */
public static final TypeStrategy IF_NULL = new IfNullTypeStrategy();
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CollectTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CollectTypeStrategyTest.java
new file mode 100644
index 0000000..5590375
--- /dev/null
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CollectTypeStrategyTest.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.inference.strategies;
+
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.types.inference.TypeStrategiesTestBase;
+
+import java.util.stream.Stream;
+
+/** Tests for {@link CollectTypeStrategy} via {@link SpecificTypeStrategies#COLLECT}. */
+class CollectTypeStrategyTest extends TypeStrategiesTestBase {
+
+ @Override
+ protected Stream<TestSpec> testData() {
+ return Stream.of(
+ TestSpec.forStrategy("Infer a collect type", SpecificTypeStrategies.COLLECT)
+ .inputTypes(DataTypes.BIGINT())
+ .expectDataType(DataTypes.MULTISET(DataTypes.BIGINT()).notNull()));
+ }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala
index d2c7087..9272f29 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala
@@ -117,14 +117,6 @@
case fd: FunctionDefinition =>
fd match {
- case DISTINCT =>
- assert(args.size == 1)
- DistinctAgg(args.head)
-
- case COLLECT =>
- assert(args.size == 1)
- Collect(args.head)
-
case ORDER_ASC =>
assert(args.size == 1)
Asc(args.head)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/aggregations.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/aggregations.scala
index f13d389..d3c19b8 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/aggregations.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/aggregations.scala
@@ -48,39 +48,6 @@
.fromDataTypeToLegacyInfo(resolvedCall.getOutputDataType)
}
-case class DistinctAgg(child: PlannerExpression) extends Aggregation {
-
- def distinct: PlannerExpression = DistinctAgg(child)
-
- override private[flink] def resultType: TypeInformation[_] = child.resultType
-
- override private[flink] def validateInput(): ValidationResult = {
- super.validateInput()
- child match {
- case agg: Aggregation =>
- child.validateInput()
- case _ =>
- ValidationFailure(
- s"Distinct modifier cannot be applied to $child! " +
- s"It can only be applied to an aggregation expression, for example, " +
- s"'a.count.distinct which is equivalent with COUNT(DISTINCT a).")
- }
- }
-
- override private[flink] def children = Seq(child)
-}
-
-/** Returns a multiset aggregates. */
-case class Collect(child: PlannerExpression) extends Aggregation {
-
- override private[flink] def children: Seq[PlannerExpression] = Seq(child)
-
- override private[flink] def resultType: TypeInformation[_] =
- MultisetTypeInfo.getInfoFor(child.resultType)
-
- override def toString: String = s"collect($child)"
-}
-
/** Expression for calling a user-defined (table)aggregate function. */
case class AggFunctionCall(
aggregateFunction: ImperativeAggregateFunction[_, _],