[BEAM-10633] UdfImpl should be able to return java.util.List.
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/ScalarFunctionImpl.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/ScalarFunctionImpl.java
index 3ef4d9f..b4a9d7e 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/ScalarFunctionImpl.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/ScalarFunctionImpl.java
@@ -124,7 +124,7 @@
@Override
public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
- return CalciteUtils.sqlTypeWithAutoCast(typeFactory, method.getReturnType());
+ return CalciteUtils.sqlTypeWithAutoCast(typeFactory, method.getGenericReturnType());
}
@Override
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
index acb4ee1..d7016cb 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.utils;
+import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Date;
import java.util.Map;
@@ -283,18 +284,26 @@
/**
* SQL-Java type mapping, with specified Beam rules: <br>
- * 1. redirect {@link AbstractInstant} to {@link Date} so Calcite can recognize it.
+ * 1. redirect {@link AbstractInstant} to {@link Date} so Calcite can recognize it. <br>
+ * 2. For a list, the component type is needed to create a Sql array type.
*
- * @param rawType
- * @return
+ * @param type
+ * @return Calcite RelDataType
*/
- public static RelDataType sqlTypeWithAutoCast(RelDataTypeFactory typeFactory, Type rawType) {
+ public static RelDataType sqlTypeWithAutoCast(RelDataTypeFactory typeFactory, Type type) {
// For Joda time types, return SQL type for java.util.Date.
- if (rawType instanceof Class && AbstractInstant.class.isAssignableFrom((Class<?>) rawType)) {
+ if (type instanceof Class && AbstractInstant.class.isAssignableFrom((Class<?>) type)) {
return typeFactory.createJavaType(Date.class);
- } else if (rawType instanceof Class && ByteString.class.isAssignableFrom((Class<?>) rawType)) {
+ } else if (type instanceof Class && ByteString.class.isAssignableFrom((Class<?>) type)) {
return typeFactory.createJavaType(byte[].class);
+ } else if (type instanceof ParameterizedType
+ && java.util.List.class.isAssignableFrom(
+ (Class<?>) ((ParameterizedType) type).getRawType())) {
+ ParameterizedType parameterizedType = (ParameterizedType) type;
+ Class<?> genericType = (Class<?>) parameterizedType.getActualTypeArguments()[0];
+ RelDataType collectionElementType = typeFactory.createJavaType(genericType);
+ return typeFactory.createArrayType(collectionElementType, UNLIMITED_ARRAY_SIZE);
}
- return typeFactory.createJavaType((Class) rawType);
+ return typeFactory.createJavaType((Class) type);
}
}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
index 75e8a08..c2afc5d 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
@@ -23,6 +23,7 @@
import com.google.auto.service.AutoService;
import java.sql.Timestamp;
+import java.util.Arrays;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.beam.sdk.extensions.sql.impl.BeamCalciteTable;
@@ -30,6 +31,7 @@
import org.apache.beam.sdk.extensions.sql.meta.provider.UdfUdafProvider;
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestBoundedTable;
import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
@@ -101,6 +103,29 @@
pipeline.run().waitUntilFinish();
}
+ @Test
+ public void testListUdf() throws Exception {
+ Schema resultType1 = Schema.builder().addArrayField("array_field", FieldType.INT64).build();
+ Row row1 = Row.withSchema(resultType1).addValue(Arrays.asList(1L)).build();
+ String sql1 = "SELECT test_array(1)";
+ PCollection<Row> result1 =
+ boundedInput1.apply(
+ "testArrayUdf",
+ SqlTransform.query(sql1).registerUdf("test_array", TestReturnTypeList.class));
+ PAssert.that(result1).containsInAnyOrder(row1);
+
+ Schema resultType2 = Schema.builder().addInt32Field("int_field").build();
+ Row row2 = Row.withSchema(resultType2).addValue(3).build();
+ String sql2 = "select array_length(ARRAY[1, 2, 3])";
+ PCollection<Row> result2 =
+ boundedInput1.apply(
+ "testArrayUdf2",
+ SqlTransform.query(sql2).registerUdf("array_length", TestListLength.class));
+ PAssert.that(result2).containsInAnyOrder(row2);
+
+ pipeline.run().waitUntilFinish();
+ }
+
/** Test that an indirect subclass of a {@link CombineFn} works as a UDAF. BEAM-3777 */
@Test
public void testUdafMultiLevelDescendent() {
@@ -347,6 +372,20 @@
}
}
+ /** A UDF to test support of array as return type. */
+ public static final class TestReturnTypeList implements BeamSqlUdf {
+ public static java.util.List<Long> eval(Long i) {
+ return Arrays.asList(i);
+ }
+ }
+
+ /** A UDF to test support of array as argument type. */
+ public static final class TestListLength implements BeamSqlUdf {
+ public static Integer eval(java.util.List<Long> i) {
+ return i.size();
+ }
+ }
+
/**
* UDF to test support for {@link
* org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.TableMacro}.