[BEAM-12312] Don't rely on remove() in LazyAggregateCombineFn#mergeAccumulators.
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
index 3b782d9..c489d12 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
@@ -73,10 +73,9 @@
@Override
public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
- Iterator<AccumT> it = accumulators.iterator();
- AccumT first = it.next();
- it.remove();
- return getAggregateFn().mergeAccumulators(first, accumulators);
+ AccumT first = accumulators.iterator().next();
+ Iterable<AccumT> rest = new SkipFirstElementIterable<>(accumulators);
+ return getAggregateFn().mergeAccumulators(first, rest);
}
@Override
@@ -99,4 +98,20 @@
public TypeVariable<?> getAccumTVariable() {
return AggregateFn.class.getTypeParameters()[1];
}
+
+ /** Wrapper {@link Iterable} which always skips its first element. */
+ private static class SkipFirstElementIterable<T> implements Iterable<T> {
+ private final Iterable<T> all;
+
+ SkipFirstElementIterable(Iterable<T> all) {
+ this.all = all;
+ }
+
+ @Override
+ public Iterator<T> iterator() {
+ Iterator<T> it = all.iterator();
+ it.next();
+ return it;
+ }
+ }
}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
index 21ab8d0..cf3f40d 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
@@ -19,12 +19,14 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
+import static org.junit.Assert.assertEquals;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -41,6 +43,13 @@
assertThat(coder, instanceOf(VarLongCoder.class));
}
+ @Test
+ public void mergeAccumulators() {
+ LazyAggregateCombineFn<Long, Long, Long> combiner = new LazyAggregateCombineFn<>(new Sum());
+ long merged = combiner.mergeAccumulators(ImmutableList.of(1L, 1L));
+ assertEquals(2L, merged);
+ }
+
public static class Sum implements AggregateFn<Long, Long, Long> {
@Override