blob: 7e6b404d9fbce2fd25ba98671654aa61c0a5954e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.schemas.transforms;
import static org.apache.beam.sdk.TestUtils.KvMatcher.isKv;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.UsesSchema;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.Keys;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.Top;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.hamcrest.Matcher;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Test for {@link Group}. */
@RunWith(JUnit4.class)
@Category(UsesSchema.class)
public class GroupTest implements Serializable {
// Pipeline used by every test in this class, registered as a JUnit rule. The field is transient,
// so Java serialization of this Serializable test class skips it.
@Rule public final transient TestPipeline pipeline = TestPipeline.create();
/**
 * Plain value class with two string fields and one long field, used as pipeline input by the
 * grouping tests. The schema is derived from the public fields via {@link JavaFieldSchema}.
 */
@DefaultSchema(JavaFieldSchema.class)
public static class POJO implements Serializable {
  public String field1;
  public long field2;
  public String field3;

  /** Creates a fully populated instance. */
  public POJO(String field1, long field2, String field3) {
    this.field1 = field1;
    this.field2 = field2;
    this.field3 = field3;
  }

  /** No-arg constructor; leaves all fields at their Java defaults. */
  public POJO() {}

  /** Two instances are equal when they share the exact class and all three field values. */
  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (o == null) {
      return false;
    }
    if (o.getClass() != getClass()) {
      return false;
    }
    POJO other = (POJO) o;
    if (field2 != other.field2) {
      return false;
    }
    if (!Objects.equals(field1, other.field1)) {
      return false;
    }
    return Objects.equals(field3, other.field3);
  }

  @Override
  public int hashCode() {
    return Objects.hash(field1, field2, field3);
  }

  /** Renders all fields; the string fields are wrapped in single quotes. */
  @Override
  public String toString() {
    StringBuilder text = new StringBuilder("POJO{");
    text.append("field1='").append(field1).append('\'');
    text.append(", field2=").append(field2);
    text.append(", field3='").append(field3).append('\'');
    text.append('}');
    return text.toString();
  }
}
/**
 * Schema with the same field names as {@link POJO}: {@code field1} (string), {@code field2}
 * (int64) and {@code field3} (string). Used to check the schema of grouped values.
 */
private static final Schema POJO_SCHEMA =
Schema.builder()
.addStringField("field1")
.addInt64Field("field2")
.addStringField("field3")
.build();
/** Groups four POJOs by {@code field1} and checks the key and members of both groups. */
@Test
@Category(NeedsRunner.class)
public void testGroupByOneField() {
  PCollection<POJO> input =
      pipeline.apply(
          Create.of(
              new POJO("key1", 1, "value1"),
              new POJO("key1", 2, "value2"),
              new POJO("key2", 3, "value3"),
              new POJO("key2", 4, "value4")));
  PCollection<KV<Row, Iterable<POJO>>> grouped = input.apply(Group.byFieldNames("field1"));

  // The grouping key is a Row holding only the field that was grouped on.
  Schema keySchema = Schema.builder().addStringField("field1").build();
  Row firstKey = Row.withSchema(keySchema).addValue("key1").build();
  Row secondKey = Row.withSchema(keySchema).addValue("key2").build();
  Collection<POJO> firstGroup =
      ImmutableList.of(new POJO("key1", 1L, "value1"), new POJO("key1", 2L, "value2"));
  Collection<POJO> secondGroup =
      ImmutableList.of(new POJO("key2", 3L, "value3"), new POJO("key2", 4L, "value4"));
  List<KV<Row, Collection<POJO>>> expectedGroups =
      ImmutableList.of(KV.of(firstKey, firstGroup), KV.of(secondKey, secondGroup));

  PAssert.that(grouped)
      .satisfies(actual -> containsKIterableVs(expectedGroups, actual, new POJO[0]));
  pipeline.run();
}
/** Groups four POJOs by {@code field1} and {@code field2} together and checks both groups. */
@Test
@Category(NeedsRunner.class)
public void testGroupByMultiple() {
  PCollection<POJO> input =
      pipeline.apply(
          Create.of(
              new POJO("key1", 1, "value1"),
              new POJO("key1", 1, "value2"),
              new POJO("key2", 2, "value3"),
              new POJO("key2", 2, "value4")));
  PCollection<KV<Row, Iterable<POJO>>> grouped =
      input.apply(Group.byFieldNames("field1", "field2"));

  // The key Row carries both grouping fields, in the order they were requested.
  Schema keySchema = Schema.builder().addStringField("field1").addInt64Field("field2").build();
  Row firstKey = Row.withSchema(keySchema).addValues("key1", 1L).build();
  Row secondKey = Row.withSchema(keySchema).addValues("key2", 2L).build();
  Collection<POJO> firstGroup =
      ImmutableList.of(new POJO("key1", 1L, "value1"), new POJO("key1", 1L, "value2"));
  Collection<POJO> secondGroup =
      ImmutableList.of(new POJO("key2", 2L, "value3"), new POJO("key2", 2L, "value4"));
  List<KV<Row, Collection<POJO>>> expectedGroups =
      ImmutableList.of(KV.of(firstKey, firstGroup), KV.of(secondKey, secondGroup));

  PAssert.that(grouped)
      .satisfies(actual -> containsKIterableVs(expectedGroups, actual, new POJO[0]));
  pipeline.run();
}
/** Wrapper around a single {@link POJO}, used to exercise grouping on nested fields. */
@DefaultSchema(JavaFieldSchema.class)
public static class OuterPOJO implements Serializable {
  public POJO inner;

  /** Creates a wrapper around the given inner value. */
  public OuterPOJO(POJO inner) {
    this.inner = inner;
  }

  /** No-arg constructor; leaves {@code inner} null. */
  public OuterPOJO() {}

  /** Equal when the other object has the exact same class and an equal {@code inner}. */
  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (o == null) {
      return false;
    }
    if (o.getClass() != getClass()) {
      return false;
    }
    OuterPOJO other = (OuterPOJO) o;
    return Objects.equals(inner, other.inner);
  }

  @Override
  public int hashCode() {
    return Objects.hash(inner);
  }

  @Override
  public String toString() {
    StringBuilder text = new StringBuilder("OuterPOJO{");
    text.append("inner=").append(inner);
    text.append('}');
    return text.toString();
  }
}
/**
 * Groups {@link OuterPOJO} elements by {@code inner.field1} and {@code inner.field2} and checks
 * that the key Row uses the leaf field names {@code field1} and {@code field2}.
 */
@Test
@Category(NeedsRunner.class)
public void testGroupByNestedKey() {
  PCollection<OuterPOJO> input =
      pipeline.apply(
          Create.of(
              new OuterPOJO(new POJO("key1", 1L, "value1")),
              new OuterPOJO(new POJO("key1", 1L, "value2")),
              new OuterPOJO(new POJO("key2", 2L, "value3")),
              new OuterPOJO(new POJO("key2", 2L, "value4"))));
  PCollection<KV<Row, Iterable<OuterPOJO>>> grouped =
      input.apply(Group.byFieldNames("inner.field1", "inner.field2"));

  Schema keySchema = Schema.builder().addStringField("field1").addInt64Field("field2").build();
  Row firstKey = Row.withSchema(keySchema).addValues("key1", 1L).build();
  Row secondKey = Row.withSchema(keySchema).addValues("key2", 2L).build();
  Collection<OuterPOJO> firstGroup =
      ImmutableList.of(
          new OuterPOJO(new POJO("key1", 1L, "value1")),
          new OuterPOJO(new POJO("key1", 1L, "value2")));
  Collection<OuterPOJO> secondGroup =
      ImmutableList.of(
          new OuterPOJO(new POJO("key2", 2L, "value3")),
          new OuterPOJO(new POJO("key2", 2L, "value4")));
  List<KV<Row, Collection<OuterPOJO>>> expectedGroups =
      ImmutableList.of(KV.of(firstKey, firstGroup), KV.of(secondKey, secondGroup));

  PAssert.that(grouped)
      .satisfies(actual -> containsKIterableVs(expectedGroups, actual, new OuterPOJO[0]));
  pipeline.run();
}
/** Groups all elements into one iterable and checks that it holds exactly the inputs. */
@Test
@Category(NeedsRunner.class)
public void testGroupGlobally() {
  Collection<POJO> inputElements =
      ImmutableList.of(
          new POJO("key1", 1, "value1"),
          new POJO("key1", 1, "value2"),
          new POJO("key2", 2, "value3"),
          new POJO("key2", 2, "value4"));

  PCollection<POJO> input = pipeline.apply(Create.of(inputElements));
  PCollection<Iterable<POJO>> grouped = input.apply(Group.globally());

  PAssert.that(grouped).satisfies(actual -> containsSingleIterable(inputElements, actual));
  pipeline.run();
}
/** Counts all elements with a global aggregation and expects the single result 4. */
@Test
@Category(NeedsRunner.class)
public void testGlobalAggregation() {
  Collection<POJO> inputElements =
      ImmutableList.of(
          new POJO("key1", 1, "value1"),
          new POJO("key1", 1, "value2"),
          new POJO("key2", 2, "value3"),
          new POJO("key2", 2, "value4"));

  PCollection<POJO> input = pipeline.apply(Create.of(inputElements));
  PCollection<Long> total = input.apply(Group.<POJO>globally().aggregate(Count.combineFn()));

  PAssert.that(total).containsInAnyOrder(4L);
  pipeline.run();
}
/** Counts elements per {@code field1} value; the input holds two "key1" and three "key2". */
@Test
@Category(NeedsRunner.class)
public void testPerKeyAggregation() {
  Collection<POJO> inputElements =
      ImmutableList.of(
          new POJO("key1", 1, "value1"),
          new POJO("key1", 1, "value2"),
          new POJO("key2", 2, "value3"),
          new POJO("key2", 2, "value4"),
          new POJO("key2", 2, "value4"));

  PCollection<POJO> input = pipeline.apply(Create.of(inputElements));
  PCollection<KV<Row, Long>> countsPerKey =
      input.apply(Group.<POJO>byFieldNames("field1").aggregate(Count.combineFn()));

  Schema keySchema = Schema.builder().addStringField("field1").build();
  Row firstKey = Row.withSchema(keySchema).addValue("key1").build();
  Row secondKey = Row.withSchema(keySchema).addValue("key2").build();
  Collection<KV<Row, Long>> expectedCounts =
      ImmutableList.of(KV.of(firstKey, 2L), KV.of(secondKey, 3L));

  PAssert.that(countsPerKey).containsInAnyOrder(expectedCounts);
  pipeline.run();
}
/** Checks the schemas attached to the keys and to the flattened values of a grouped output. */
@Test
@Category(NeedsRunner.class)
public void testOutputCoders() {
  Schema expectedKeySchema = Schema.builder().addStringField("field1").build();

  PCollection<POJO> input = pipeline.apply(Create.of(new POJO("key1", 1, "value1")));
  PCollection<KV<Row, Iterable<POJO>>> grouped = input.apply(Group.byFieldNames("field1"));

  // The extracted keys must carry a schema equivalent to the one-field key schema.
  PCollection<Row> keys = grouped.apply(Keys.create());
  boolean keySchemaMatches = keys.getSchema().equivalent(expectedKeySchema);
  assertTrue(keySchemaMatches);

  // The values, once flattened back to single elements, must carry the POJO schema.
  PCollection<Iterable<POJO>> groupedValues = grouped.apply(Values.create());
  PCollection<POJO> values = groupedValues.apply(Flatten.iterables());
  boolean valueSchemaMatches = values.getSchema().equivalent(POJO_SCHEMA);
  assertTrue(valueSchemaMatches);

  pipeline.run();
}
/**
 * Value class with two long fields and one int field, used as input for the field-aggregation
 * tests. The schema is derived from the public fields via {@link JavaFieldSchema}.
 */
@DefaultSchema(JavaFieldSchema.class)
public static class AggregatePojos implements Serializable {
  public long field1;
  public long field2;
  public int field3;

  /** Creates a fully populated instance. */
  public AggregatePojos(long field1, long field2, int field3) {
    this.field1 = field1;
    this.field2 = field2;
    this.field3 = field3;
  }

  /** No-arg constructor; leaves all fields at zero. */
  public AggregatePojos() {}

  /** Equal when the other object has the exact same class and the same three field values. */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    AggregatePojos agg = (AggregatePojos) o;
    return field1 == agg.field1 && field2 == agg.field2 && field3 == agg.field3;
  }

  @Override
  public int hashCode() {
    return Objects.hash(field1, field2, field3);
  }

  /**
   * Renders all fields so that assertion failures show values rather than an identity hash,
   * matching the other value classes in this test.
   */
  @Override
  public String toString() {
    return "AggregatePojos{"
        + "field1="
        + field1
        + ", field2="
        + field2
        + ", field3="
        + field3
        + '}';
  }
}
/**
 * Groups by {@code field2} and aggregates three outputs per key: the sum of {@code field1}, the
 * sum of {@code field3}, and the largest {@code field1} value.
 */
@Test
@Category(NeedsRunner.class)
public void testByKeyWithSchemaAggregateFn() {
  Collection<AggregatePojos> inputElements =
      ImmutableList.of(
          new AggregatePojos(1, 1, 2),
          new AggregatePojos(2, 1, 3),
          new AggregatePojos(3, 2, 4),
          new AggregatePojos(4, 2, 5));

  PCollection<AggregatePojos> input = pipeline.apply(Create.of(inputElements));
  PCollection<KV<Row, Row>> aggregations =
      input.apply(
          Group.<AggregatePojos>byFieldNames("field2")
              .aggregateField("field1", Sum.ofLongs(), "field1_sum")
              .aggregateField("field3", Sum.ofIntegers(), "field3_sum")
              .aggregateField("field1", Top.largestLongsFn(1), "field1_top"));

  Schema keySchema = Schema.builder().addInt64Field("field2").build();
  Schema valueSchema =
      Schema.builder()
          .addInt64Field("field1_sum")
          .addInt32Field("field3_sum")
          .addArrayField("field1_top", FieldType.INT64)
          .build();

  // Key 1 covers the first two inputs, key 2 the last two.
  Row firstKey = Row.withSchema(keySchema).addValue(1L).build();
  Row firstValue = Row.withSchema(valueSchema).addValue(3L).addValue(5).addArray(2L).build();
  Row secondKey = Row.withSchema(keySchema).addValue(2L).build();
  Row secondValue = Row.withSchema(valueSchema).addValue(7L).addValue(9).addArray(4L).build();
  List<KV<Row, Row>> expectedAggregations =
      ImmutableList.of(KV.of(firstKey, firstValue), KV.of(secondKey, secondValue));

  PAssert.that(aggregations).satisfies(actual -> containsKvs(expectedAggregations, actual));
  pipeline.run();
}
/**
 * Aggregates over the whole input: the sum of {@code field1}, the sum of {@code field3}, and the
 * largest {@code field1} value, all returned in a single Row.
 */
@Test
@Category(NeedsRunner.class)
public void testGloballyWithSchemaAggregateFn() {
  Collection<AggregatePojos> inputElements =
      ImmutableList.of(
          new AggregatePojos(1, 1, 2),
          new AggregatePojos(2, 1, 3),
          new AggregatePojos(3, 2, 4),
          new AggregatePojos(4, 2, 5));

  PCollection<AggregatePojos> input = pipeline.apply(Create.of(inputElements));
  PCollection<Row> aggregate =
      input.apply(
          Group.<AggregatePojos>globally()
              .aggregateField("field1", Sum.ofLongs(), "field1_sum")
              .aggregateField("field3", Sum.ofIntegers(), "field3_sum")
              .aggregateField("field1", Top.largestLongsFn(1), "field1_top"));

  Schema aggregateSchema =
      Schema.builder()
          .addInt64Field("field1_sum")
          .addInt32Field("field3_sum")
          .addArrayField("field1_top", FieldType.INT64)
          .build();
  Row expectedAggregate =
      Row.withSchema(aggregateSchema).addValues(10L, 14).addArray(4L).build();

  PAssert.that(aggregate).containsInAnyOrder(expectedAggregate);
  pipeline.run();
}
/**
 * Combine function that sums every {@code Long} value found in its input Rows. The running total
 * lives in slot 0 of a one-element {@code long[]} accumulator.
 */
public static class MultipleFieldCombineFn extends CombineFn<Row, long[], Long> {
  /** Returns a fresh accumulator whose total is zero. */
  @Override
  public long[] createAccumulator() {
    return new long[1];
  }

  /** Adds each {@code Long} value of {@code input} to the total; other values are skipped. */
  @Override
  public long[] addInput(long[] accumulator, Row input) {
    long rowTotal = 0;
    for (Object value : input.getValues()) {
      if (!(value instanceof Long)) {
        continue;
      }
      rowTotal += (Long) value;
    }
    accumulator[0] += rowTotal;
    return accumulator;
  }

  /**
   * Folds all accumulators into the first one and returns it; only that first accumulator is
   * modified. An empty input yields a fresh zero accumulator.
   */
  @Override
  public long[] mergeAccumulators(Iterable<long[]> accumulators) {
    Iterator<long[]> remaining = accumulators.iterator();
    if (!remaining.hasNext()) {
      return createAccumulator();
    }
    long[] merged = remaining.next();
    remaining.forEachRemaining(next -> merged[0] += next[0]);
    return merged;
  }

  /** Returns the accumulated total. */
  @Override
  public Long extractOutput(long[] accumulator) {
    return accumulator[0];
  }
}
/**
 * Feeds {@code field1} and {@code field2} together into one combine function and expects the sum
 * of both fields over all inputs, (1+2+3+4) + (1+1+2+2) = 16.
 */
@Test
@Category(NeedsRunner.class)
public void testAggregateByMultipleFields() {
  Collection<AggregatePojos> inputElements =
      ImmutableList.of(
          new AggregatePojos(1, 1, 2),
          new AggregatePojos(2, 1, 3),
          new AggregatePojos(3, 2, 4),
          new AggregatePojos(4, 2, 5));
  List<String> combinedFields = Lists.newArrayList("field1", "field2");

  PCollection<AggregatePojos> input = pipeline.apply(Create.of(inputElements));
  PCollection<Row> aggregate =
      input.apply(
          Group.<AggregatePojos>globally()
              .aggregateFields(combinedFields, new MultipleFieldCombineFn(), "field1+field2"));

  Schema outputSchema = Schema.builder().addInt64Field("field1+field2").build();
  Row expectedAggregate = Row.withSchema(outputSchema).addValues(16L).build();

  PAssert.that(aggregate).containsInAnyOrder(expectedAggregate);
  pipeline.run();
}
/** Wrapper around a single {@link AggregatePojos}, used to aggregate over nested fields. */
@DefaultSchema(JavaFieldSchema.class)
public static class OuterAggregate implements Serializable {
  public AggregatePojos inner;

  /** Creates a wrapper around the given inner value. */
  public OuterAggregate(AggregatePojos inner) {
    this.inner = inner;
  }

  /** No-arg constructor; leaves {@code inner} null. */
  public OuterAggregate() {}

  /** Equal when the other object has the exact same class and an equal {@code inner}. */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    OuterAggregate that = (OuterAggregate) o;
    return Objects.equals(inner, that.inner);
  }

  @Override
  public int hashCode() {
    return Objects.hash(inner);
  }

  /**
   * Renders the wrapped value so that assertion failures are readable, matching the other
   * value classes in this test.
   */
  @Override
  public String toString() {
    return "OuterAggregate{" + "inner=" + inner + '}';
  }
}
/**
 * Groups by the nested field {@code inner.field2} and aggregates nested fields per key: the sum
 * of {@code inner.field1}, the sum of {@code inner.field3}, and the largest {@code inner.field1}.
 */
@Test
@Category(NeedsRunner.class)
public void testByKeyWithSchemaAggregateFnNestedFields() {
  Collection<OuterAggregate> inputElements =
      ImmutableList.of(
          new OuterAggregate(new AggregatePojos(1, 1, 2)),
          new OuterAggregate(new AggregatePojos(2, 1, 3)),
          new OuterAggregate(new AggregatePojos(3, 2, 4)),
          new OuterAggregate(new AggregatePojos(4, 2, 5)));

  PCollection<OuterAggregate> input = pipeline.apply(Create.of(inputElements));
  PCollection<KV<Row, Row>> aggregations =
      input.apply(
          Group.<OuterAggregate>byFieldNames("inner.field2")
              .aggregateField("inner.field1", Sum.ofLongs(), "field1_sum")
              .aggregateField("inner.field3", Sum.ofIntegers(), "field3_sum")
              .aggregateField("inner.field1", Top.largestLongsFn(1), "field1_top"));

  Schema keySchema = Schema.builder().addInt64Field("field2").build();
  Schema valueSchema =
      Schema.builder()
          .addInt64Field("field1_sum")
          .addInt32Field("field3_sum")
          .addArrayField("field1_top", FieldType.INT64)
          .build();

  // Key 1 covers the first two inputs, key 2 the last two.
  Row firstKey = Row.withSchema(keySchema).addValue(1L).build();
  Row firstValue = Row.withSchema(valueSchema).addValue(3L).addValue(5).addArray(2L).build();
  Row secondKey = Row.withSchema(keySchema).addValue(2L).build();
  Row secondValue = Row.withSchema(valueSchema).addValue(7L).addValue(9).addArray(4L).build();
  List<KV<Row, Row>> expectedAggregations =
      ImmutableList.of(KV.of(firstKey, firstValue), KV.of(secondKey, secondValue));

  PAssert.that(aggregations).satisfies(actual -> containsKvs(expectedAggregations, actual));
  pipeline.run();
}
/**
 * Aggregates nested fields over the whole input: the sum of {@code inner.field1}, the sum of
 * {@code inner.field3}, and the largest {@code inner.field1}, all returned in a single Row.
 */
@Test
@Category(NeedsRunner.class)
public void testGloballyWithSchemaAggregateFnNestedFields() {
  Collection<OuterAggregate> inputElements =
      ImmutableList.of(
          new OuterAggregate(new AggregatePojos(1, 1, 2)),
          new OuterAggregate(new AggregatePojos(2, 1, 3)),
          new OuterAggregate(new AggregatePojos(3, 2, 4)),
          new OuterAggregate(new AggregatePojos(4, 2, 5)));

  PCollection<OuterAggregate> input = pipeline.apply(Create.of(inputElements));
  PCollection<Row> aggregate =
      input.apply(
          Group.<OuterAggregate>globally()
              .aggregateField("inner.field1", Sum.ofLongs(), "field1_sum")
              .aggregateField("inner.field3", Sum.ofIntegers(), "field3_sum")
              .aggregateField("inner.field1", Top.largestLongsFn(1), "field1_top"));

  Schema aggregateSchema =
      Schema.builder()
          .addInt64Field("field1_sum")
          .addInt32Field("field3_sum")
          .addArrayField("field1_top", FieldType.INT64)
          .build();
  Row expectedAggregate =
      Row.withSchema(aggregateSchema).addValues(10L, 14).addArray(4L).build();

  PAssert.that(aggregate).containsInAnyOrder(expectedAggregate);
  pipeline.run();
}
/**
 * Asserts that {@code actualKvs} holds, in any order, exactly one entry per expected entry: the
 * key must equal the expected key and the value iterable must hold the expected values in any
 * order.
 *
 * @param expectedKvs expected key and group members for each group
 * @param actualKvs the grouped output to check
 * @param emptyArray array used only to give {@code toArray} the element type {@code T}
 * @return always {@code null}; the {@code Void} return lets a lambda call this as a function
 */
private static <T> Void containsKIterableVs(
    List<KV<Row, Collection<T>>> expectedKvs,
    Iterable<KV<Row, Iterable<T>>> actualKvs,
    T[] emptyArray) {
  // Typed on T rather than a fixed element class, because callers pass different value types.
  List<Matcher<? super KV<Row, Iterable<T>>>> matchers = new ArrayList<>();
  for (KV<Row, Collection<T>> expected : expectedKvs) {
    T[] values = expected.getValue().toArray(emptyArray);
    matchers.add(isKv(equalTo(expected.getKey()), containsInAnyOrder(values)));
  }
  // A generic array cannot be created, hence the raw Matcher[] passed to the varargs overload.
  assertThat(actualKvs, containsInAnyOrder(matchers.toArray(new Matcher[0])));
  return null;
}
/**
 * Asserts that {@code actualKvs} holds, in any order, exactly the expected key/value Row pairs,
 * comparing both key and value with {@code equalTo}.
 *
 * @param expectedKvs expected key Row and aggregate Row for each group
 * @param actualKvs the aggregated output to check
 * @return always {@code null}; the {@code Void} return lets a lambda call this as a function
 */
private static Void containsKvs(
    List<KV<Row, Row>> expectedKvs, Iterable<KV<Row, Row>> actualKvs) {
  List<Matcher<? super KV<Row, Row>>> matchers = new ArrayList<>();
  for (KV<Row, Row> expected : expectedKvs) {
    matchers.add(isKv(equalTo(expected.getKey()), equalTo(expected.getValue())));
  }
  // A generic array cannot be created, hence the raw Matcher[] passed to the varargs overload.
  assertThat(actualKvs, containsInAnyOrder(matchers.toArray(new Matcher[0])));
  return null;
}
/**
 * Asserts that {@code actual} holds exactly one iterable and that this iterable holds the
 * expected elements in any order.
 *
 * @return always {@code null}; the {@code Void} return lets a lambda call this as a function
 */
private static Void containsSingleIterable(
    Collection<POJO> expected, Iterable<Iterable<POJO>> actual) {
  Matcher<Iterable<? extends POJO>> onlyGroup =
      containsInAnyOrder(expected.toArray(new POJO[0]));
  assertThat(actual, containsInAnyOrder(onlyGroup));
  return null;
}
}