| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.sdk.extensions.sql; |
| |
| import static org.hamcrest.Matchers.containsString; |
| import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage; |
| |
| import org.apache.beam.sdk.extensions.sql.impl.ParseException; |
| import org.apache.beam.sdk.schemas.Schema; |
| import org.apache.beam.sdk.testing.PAssert; |
| import org.apache.beam.sdk.transforms.Combine.CombineFn; |
| import org.apache.beam.sdk.transforms.SerializableFunction; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.PCollectionTuple; |
| import org.apache.beam.sdk.values.Row; |
| import org.apache.beam.sdk.values.TupleTag; |
| import org.apache.calcite.linq4j.function.Parameter; |
| import org.junit.Test; |
| |
| /** Tests for UDF/UDAF. */ |
| public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase { |
| |
| /** GROUP-BY with UDAF. */ |
| @Test |
| public void testUdaf() throws Exception { |
| Schema resultType = Schema.builder().addInt32Field("f_int2").addInt32Field("squaresum").build(); |
| |
| Row row = Row.withSchema(resultType).addValues(0, 30).build(); |
| |
| String sql1 = |
| "SELECT f_int2, squaresum1(f_int) AS `squaresum`" + " FROM PCOLLECTION GROUP BY f_int2"; |
| PCollection<Row> result1 = |
| boundedInput1.apply( |
| "testUdaf1", SqlTransform.query(sql1).registerUdaf("squaresum1", new SquareSum())); |
| PAssert.that(result1).containsInAnyOrder(row); |
| |
| String sql2 = |
| "SELECT f_int2, squaresum2(f_int) AS `squaresum`" + " FROM PCOLLECTION GROUP BY f_int2"; |
| PCollection<Row> result2 = |
| PCollectionTuple.of(new TupleTag<>("PCOLLECTION"), boundedInput1) |
| .apply( |
| "testUdaf2", SqlTransform.query(sql2).registerUdaf("squaresum2", new SquareSum())); |
| PAssert.that(result2).containsInAnyOrder(row); |
| |
| pipeline.run().waitUntilFinish(); |
| } |
| |
| /** Test that an indirect subclass of a {@link CombineFn} works as a UDAF. BEAM-3777 */ |
| @Test |
| public void testUdafMultiLevelDescendent() { |
| Schema resultType = Schema.builder().addInt32Field("f_int2").addInt32Field("squaresum").build(); |
| |
| Row row = Row.withSchema(resultType).addValues(0, 354).build(); |
| |
| String sql1 = |
| "SELECT f_int2, double_square_sum(f_int) AS `squaresum`" |
| + " FROM PCOLLECTION GROUP BY f_int2"; |
| PCollection<Row> result1 = |
| boundedInput1.apply( |
| "testUdaf", |
| SqlTransform.query(sql1).registerUdaf("double_square_sum", new SquareSquareSum())); |
| PAssert.that(result1).containsInAnyOrder(row); |
| |
| pipeline.run().waitUntilFinish(); |
| } |
| |
| /** |
| * Test that correct exception is thrown when subclass of {@link CombineFn} is not parameterized. |
| * BEAM-3777 |
| */ |
| @Test |
| public void testRawCombineFnSubclass() { |
| exceptions.expect(ParseException.class); |
| exceptions.expectCause(hasMessage(containsString("CombineFn must be parameterized"))); |
| pipeline.enableAbandonedNodeEnforcement(false); |
| |
| Schema resultType = Schema.builder().addInt32Field("f_int2").addInt32Field("squaresum").build(); |
| |
| Row row = Row.withSchema(resultType).addValues(0, 354).build(); |
| |
| String sql1 = |
| "SELECT f_int2, squaresum(f_int) AS `squaresum`" + " FROM PCOLLECTION GROUP BY f_int2"; |
| PCollection<Row> result1 = |
| boundedInput1.apply( |
| "testUdaf", SqlTransform.query(sql1).registerUdaf("squaresum", new RawCombineFn())); |
| } |
| |
| /** test UDF. */ |
| @Test |
| public void testUdf() throws Exception { |
| Schema resultType = Schema.builder().addInt32Field("f_int").addInt32Field("cubicvalue").build(); |
| Row row = Row.withSchema(resultType).addValues(2, 8).build(); |
| |
| String sql1 = "SELECT f_int, cubic1(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2"; |
| PCollection<Row> result1 = |
| boundedInput1.apply( |
| "testUdf1", SqlTransform.query(sql1).registerUdf("cubic1", CubicInteger.class)); |
| PAssert.that(result1).containsInAnyOrder(row); |
| |
| String sql2 = "SELECT f_int, cubic2(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2"; |
| PCollection<Row> result2 = |
| PCollectionTuple.of(new TupleTag<>("PCOLLECTION"), boundedInput1) |
| .apply( |
| "testUdf2", SqlTransform.query(sql2).registerUdf("cubic2", new CubicIntegerFn())); |
| PAssert.that(result2).containsInAnyOrder(row); |
| |
| String sql3 = "SELECT f_int, substr(f_string) as sub_string FROM PCOLLECTION WHERE f_int = 2"; |
| PCollection<Row> result3 = |
| PCollectionTuple.of(new TupleTag<>("PCOLLECTION"), boundedInput1) |
| .apply( |
| "testUdf3", SqlTransform.query(sql3).registerUdf("substr", UdfFnWithDefault.class)); |
| |
| Schema subStrSchema = |
| Schema.builder().addInt32Field("f_int").addStringField("sub_string").build(); |
| Row subStrRow = Row.withSchema(subStrSchema).addValues(2, "s").build(); |
| PAssert.that(result3).containsInAnyOrder(subStrRow); |
| |
| pipeline.run().waitUntilFinish(); |
| } |
| |
| /** UDAF(CombineFn) for test, which returns the sum of square. */ |
| public static class SquareSum extends CombineFn<Integer, Integer, Integer> { |
| @Override |
| public Integer createAccumulator() { |
| return 0; |
| } |
| |
| @Override |
| public Integer addInput(Integer accumulator, Integer input) { |
| return accumulator + input * input; |
| } |
| |
| @Override |
| public Integer mergeAccumulators(Iterable<Integer> accumulators) { |
| int v = 0; |
| for (Integer accumulator : accumulators) { |
| v += accumulator; |
| } |
| return v; |
| } |
| |
| @Override |
| public Integer extractOutput(Integer accumulator) { |
| return accumulator; |
| } |
| } |
| |
| /** |
| * Non-parameterized CombineFn. Intended to test that non-parameterized CombineFns are correctly |
| * rejected. The methods just return null, as they should never be called. |
| */ |
| public static class RawCombineFn extends CombineFn { |
| |
| @Override |
| public Object createAccumulator() { |
| return null; |
| } |
| |
| @Override |
| public Object addInput(Object accumulator, Object input) { |
| return null; |
| } |
| |
| @Override |
| public Object mergeAccumulators(Iterable accumulators) { |
| return null; |
| } |
| |
| @Override |
| public Object extractOutput(Object accumulator) { |
| return null; |
| } |
| } |
| |
| /** An example UDAF with two levels of descendancy from CombineFn. */ |
| public static class SquareSquareSum extends SquareSum { |
| @Override |
| public Integer addInput(Integer accumulator, Integer input) { |
| return super.addInput(accumulator, input * input); |
| } |
| } |
| |
| /** An example UDF for test. */ |
| public static class CubicInteger implements BeamSqlUdf { |
| public static Integer eval(Integer input) { |
| return input * input * input; |
| } |
| } |
| |
| /** An example UDF with {@link SerializableFunction}. */ |
| public static class CubicIntegerFn implements SerializableFunction<Integer, Integer> { |
| @Override |
| public Integer apply(Integer input) { |
| return input * input * input; |
| } |
| } |
| |
| /** A UDF with default parameters. */ |
| public static final class UdfFnWithDefault implements BeamSqlUdf { |
| public static String eval( |
| @Parameter(name = "s") String s, @Parameter(name = "n", optional = true) Integer n) { |
| return s.substring(0, n == null ? 1 : n); |
| } |
| } |
| } |