Merge pull request #15490 from KevinGG/beam_sql_out_cache

[BEAM-10708] Introspect beam_sql output
diff --git a/CHANGES.md b/CHANGES.md
index 5547c75..e6181da 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -46,10 +46,15 @@
 
 * X behavior is deprecated and will be removed in X versions ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
 
-## Known Issues
+## Bugfixes
 
 * Fixed X (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
+
+## Known Issues
+
+* ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
 -->
+
 # [2.34.0] - Unreleased
 * Add an [example](https://github.com/cometta/python-apache-beam-spark) of deploying Python Apache Beam job with Spark Cluster
 ## Highlights
@@ -81,10 +86,14 @@
 
 * X behavior is deprecated and will be removed in X versions ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
 
-## Known Issues
+## Bugfixes
 
 * Fixed X (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
 
+## Known Issues
+
+* ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
+
 # [2.33.0] - Unreleased
 
 ## Highlights
@@ -172,7 +181,7 @@
 * Python GBK will stop supporting unbounded PCollections that have global windowing and a default trigger in Beam 2.33. This can be overriden with `--allow_unsafe_triggers`. ([BEAM-9487](https://issues.apache.org/jira/browse/BEAM-9487)).
 * Python GBK will start requiring safe triggers or the `--allow_unsafe_triggers` flag starting with Beam 2.33. ([BEAM-9487](https://issues.apache.org/jira/browse/BEAM-9487)).
 
-## Known Issues
+## Bugfixes
 
 * Fixed race condition in RabbitMqIO causing duplicate acks (Java) ([BEAM-6516](https://issues.apache.org/jira/browse/BEAM-6516)))
 
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 67fe8b8..5c00e27 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -510,6 +510,7 @@
         aws_java_sdk2_sdk_core                      : "software.amazon.awssdk:sdk-core:$aws_java_sdk2_version",
         aws_java_sdk2_sns                           : "software.amazon.awssdk:sns:$aws_java_sdk2_version",
         aws_java_sdk2_sqs                           : "software.amazon.awssdk:sqs:$aws_java_sdk2_version",
+        aws_java_sdk2_sts                           : "software.amazon.awssdk:sts:$aws_java_sdk2_version",
         aws_java_sdk2_s3                            : "software.amazon.awssdk:s3:$aws_java_sdk2_version",
         aws_java_sdk2_http_client_spi               : "software.amazon.awssdk:http-client-spi:$aws_java_sdk2_version",
         aws_java_sdk2_regions                       : "software.amazon.awssdk:regions:$aws_java_sdk2_version",
diff --git a/runners/samza/job-server/build.gradle b/runners/samza/job-server/build.gradle
index 266ec0e..4660bfc 100644
--- a/runners/samza/job-server/build.gradle
+++ b/runners/samza/job-server/build.gradle
@@ -69,6 +69,7 @@
         pipelineOpts: pipelineOptions,
         environment: BeamModulePlugin.PortableValidatesRunnerConfiguration.Environment.EMBEDDED,
         testCategories: {
+            includeCategories 'org.apache.beam.sdk.testing.NeedsRunner'
             includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner'
             // TODO: BEAM-12350
             excludeCategories 'org.apache.beam.sdk.testing.UsesAttemptedMetrics'
@@ -112,6 +113,62 @@
             excludeTestsMatching 'org.apache.beam.sdk.testing.TestStreamTest.testFirstElementLate'
             // TODO(BEAM-12036)
             excludeTestsMatching 'org.apache.beam.sdk.testing.TestStreamTest.testLateDataAccumulating'
+            // TODO(BEAM-12886)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.GroupByKeyTest$WindowTests.testWindowFnPostMerging'
+            // TODO(BEAM-12887)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimestampTests.testParDoShiftTimestampInvalid'
+            // TODO(BEAM-12888)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimestampTests.testParDoShiftTimestampInvalidZeroAllowed'
+            // TODO(BEAM-12889)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.DeduplicateTest.testEventTime'
+            // TODO(BEAM-12890)
+            excludeTestsMatching 'org.apache.beam.sdk.io.TFRecordIOTest.testReadInvalidRecord'
+            // TODO(BEAM-12891)
+            excludeTestsMatching 'org.apache.beam.sdk.io.TFRecordIOTest.testReadInvalidDataMask'
+            // TODO(BEAM-12892)
+            excludeTestsMatching 'org.apache.beam.sdk.io.TFRecordIOTest.testReadInvalidLengthMask'
+            // TODO(BEAM-12893)
+            excludeTestsMatching 'org.apache.beam.sdk.io.TextIOReadTest$CompressedReadTest.testCompressedReadWithoutExtension'
+            // TODO(BEAM-12894)
+            excludeTestsMatching 'org.apache.beam.sdk.io.WriteFilesTest.testWithRunnerDeterminedShardingUnbounded'
+            // TODO(BEAM-12895)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$MultipleInputsAndOutputTests.testParDoWritingToUndeclaredTag'
+            // TODO(BEAM-12896)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$MultipleInputsAndOutputTests.testParDoReadingFromUnknownSideInput'
+            // TODO(BEAM-12897)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.ViewTest.testMapSideInputWithNullValuesCatchesDuplicates'
+
+            // TODO(BEAM-12743)
+            excludeTestsMatching 'org.apache.beam.sdk.coders.PCollectionCustomCoderTest.testEncodingNPException'
+            excludeTestsMatching 'org.apache.beam.sdk.coders.PCollectionCustomCoderTest.testEncodingIOException'
+            excludeTestsMatching 'org.apache.beam.sdk.coders.PCollectionCustomCoderTest.testDecodingNPException'
+            excludeTestsMatching 'org.apache.beam.sdk.coders.PCollectionCustomCoderTest.testDecodingIOException'
+            // TODO(BEAM-12744)
+            excludeTestsMatching 'org.apache.beam.sdk.PipelineTest.testEmptyPipeline'
+            // TODO(BEAM-12745)
+            excludeTestsMatching 'org.apache.beam.sdk.io.AvroIOTest*'
+            // TODO(BEAM-12746)
+            excludeTestsMatching 'org.apache.beam.sdk.io.FileIOTest*'
+            // TODO(BEAM-12747)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.WithTimestampsTest.withTimestampsBackwardsInTimeShouldThrow'
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.WithTimestampsTest.withTimestampsWithNullTimestampShouldThrow'
+            // TODO(BEAM-12748)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.ViewTest.testEmptySingletonSideInput'
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.ViewTest.testNonSingletonSideInput'
+            // TODO(BEAM-12749)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.MapElementsTest.testMapSimpleFunction'
+            // TODO(BEAM-12750)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testInGlobalWindowBatchSizeByteSizeFn'
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testInStreamingMode'
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testWithShardedKeyInGlobalWindow'
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testWithUnevenBatches'
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testInGlobalWindowBatchSizeByteSize'
+            // TODO(BEAM-10025)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testOutputTimestampDefaultUnbounded'
+            // TODO(BEAM-11479)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testOutputTimestamp'
+            // TODO(BEAM-11479)
+            excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testRelativeTimerWithOutputTimestamp'
         }
 )
 
diff --git a/runners/spark/job-server/spark_job_server.gradle b/runners/spark/job-server/spark_job_server.gradle
index c32f7b0..f1bb021 100644
--- a/runners/spark/job-server/spark_job_server.gradle
+++ b/runners/spark/job-server/spark_job_server.gradle
@@ -59,7 +59,8 @@
   validatesPortableRunner project(path: ":runners:core-java", configuration: "testRuntime")
   validatesPortableRunner project(path: ":runners:portability:java", configuration: "testRuntime")
   runtime project(":sdks:java:extensions:google-cloud-platform-core")
-//  TODO: Enable AWS and HDFS file system.
+  runtime project(":sdks:java:io:amazon-web-services2")
+//  TODO: Enable HDFS file system.
 }
 
 // NOTE: runShadow must be used in order to run the job server. The standard run
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
index 4fe26d7..6391ba4 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
@@ -17,27 +17,26 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
-import java.util.ArrayList;
-import java.util.List;
-import org.apache.beam.runners.core.Concatenate;
+import java.io.Serializable;
+import org.apache.beam.runners.core.InMemoryStateInternals;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateInternalsFactory;
+import org.apache.beam.runners.core.SystemReduceFn;
 import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
+import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.KeyValueGroupedDataset;
-import scala.Tuple2;
 
 class GroupByKeyTranslatorBatch<K, V>
     implements TransformTranslator<
@@ -49,62 +48,43 @@
       AbstractTranslationContext context) {
 
     @SuppressWarnings("unchecked")
-    final PCollection<KV<K, V>> input = (PCollection<KV<K, V>>) context.getInput();
-    @SuppressWarnings("unchecked")
-    final PCollection<KV<K, List<V>>> output = (PCollection<KV<K, List<V>>>) context.getOutput();
-    final Combine.CombineFn<V, List<V>, List<V>> combineFn = new Concatenate<>();
+    final PCollection<KV<K, V>> inputPCollection = (PCollection<KV<K, V>>) context.getInput();
+    Dataset<WindowedValue<KV<K, V>>> input = context.getDataset(inputPCollection);
+    WindowingStrategy<?, ?> windowingStrategy = inputPCollection.getWindowingStrategy();
+    KvCoder<K, V> kvCoder = (KvCoder<K, V>) inputPCollection.getCoder();
+    Coder<V> valueCoder = kvCoder.getValueCoder();
 
-    WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
+    // group by key only
+    Coder<K> keyCoder = kvCoder.getKeyCoder();
+    KeyValueGroupedDataset<K, WindowedValue<KV<K, V>>> groupByKeyOnly =
+        input.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
 
-    Dataset<WindowedValue<KV<K, V>>> inputDataset = context.getDataset(input);
+    // group also by windows
+    WindowedValue.FullWindowedValueCoder<KV<K, Iterable<V>>> outputCoder =
+        WindowedValue.FullWindowedValueCoder.of(
+            KvCoder.of(keyCoder, IterableCoder.of(valueCoder)),
+            windowingStrategy.getWindowFn().windowCoder());
+    Dataset<WindowedValue<KV<K, Iterable<V>>>> output =
+        groupByKeyOnly.flatMapGroups(
+            new GroupAlsoByWindowViaOutputBufferFn<>(
+                windowingStrategy,
+                new InMemoryStateInternalsFactory<>(),
+                SystemReduceFn.buffering(valueCoder),
+                context.getSerializableOptions()),
+            EncoderHelpers.fromBeamCoder(outputCoder));
 
-    KvCoder<K, V> inputCoder = (KvCoder<K, V>) input.getCoder();
-    Coder<K> keyCoder = inputCoder.getKeyCoder();
-    KvCoder<K, List<V>> outputKVCoder = (KvCoder<K, List<V>>) output.getCoder();
-    Coder<List<V>> outputCoder = outputKVCoder.getValueCoder();
+    context.putDataset(context.getOutput(), output);
+  }
 
-    KeyValueGroupedDataset<K, WindowedValue<KV<K, V>>> groupedDataset =
-      inputDataset.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
-
-    Coder<List<V>> accumulatorCoder = null;
-    try {
-      accumulatorCoder =
-        combineFn.getAccumulatorCoder(
-          input.getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
-    } catch (CannotProvideCoderException e) {
-      throw new RuntimeException(e);
+  /**
+   * In-memory state internals factory.
+   *
+   * @param <K> State key type.
+   */
+  static class InMemoryStateInternalsFactory<K> implements StateInternalsFactory<K>, Serializable {
+    @Override
+    public StateInternals stateInternalsForKey(K key) {
+      return InMemoryStateInternals.forKey(key);
     }
-
-    Dataset<Tuple2<K, Iterable<WindowedValue<List<V>>>>> combinedDataset =
-      groupedDataset.agg(
-        new AggregatorCombiner<K, V, List<V>, List<V>, BoundedWindow>(
-          combineFn, windowingStrategy, accumulatorCoder, outputCoder)
-          .toColumn());
-
-    // expand the list into separate elements and put the key back into the elements
-    WindowedValue.WindowedValueCoder<KV<K, List<V>>> wvCoder =
-      WindowedValue.FullWindowedValueCoder.of(
-        outputKVCoder, input.getWindowingStrategy().getWindowFn().windowCoder());
-    Dataset<WindowedValue<KV<K, List<V>>>> outputDataset =
-      combinedDataset.flatMap(
-        (FlatMapFunction<
-          Tuple2<K, Iterable<WindowedValue<List<V>>>>, WindowedValue<KV<K, List<V>>>>)
-          tuple2 -> {
-            K key = tuple2._1();
-            Iterable<WindowedValue<List<V>>> windowedValues = tuple2._2();
-            List<WindowedValue<KV<K, List<V>>>> result = new ArrayList<>();
-            for (WindowedValue<List<V>> windowedValue : windowedValues) {
-              KV<K, List<V>> kv = KV.of(key, windowedValue.getValue());
-              result.add(
-                WindowedValue.of(
-                  kv,
-                  windowedValue.getTimestamp(),
-                  windowedValue.getWindows(),
-                  windowedValue.getPane()));
-            }
-            return result.iterator();
-          },
-        EncoderHelpers.fromBeamCoder(wvCoder));
-    context.putDataset(output, outputDataset);
   }
 }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
index 8c11464..0d6ab0c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
@@ -38,6 +38,7 @@
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
@@ -365,12 +366,62 @@
   }
 
   /**
+   * A {@link PTransform} that does a combine using an aggregation built up by calls to
+   * aggregateField and aggregateFields. The output of this transform will have a schema that is
+   * determined by the output types of all the composed combiners.
+   *
+   * @param <InputT> the type of the input elements
+   */
+  public abstract static class AggregateCombiner<InputT>
+      extends PTransform<PCollection<InputT>, PCollection<Row>> {
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over single field of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     */
+    public abstract <CombineInputT, AccumT, CombineOutputT>
+        AggregateCombiner<InputT> aggregateField(
+            int inputFieldId,
+            CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+            Field outputField);
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over a single field of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     */
+    public abstract <CombineInputT, AccumT, CombineOutputT>
+        AggregateCombiner<InputT> aggregateField(
+            String inputFieldName,
+            CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+            Field outputField);
+
+    /**
+     * Build up an aggregation function over the input elements by field id.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public abstract <CombineInputT, AccumT, CombineOutputT>
+        AggregateCombiner<InputT> aggregateFieldsById(
+            List<Integer> inputFieldIds,
+            CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+            Field outputField);
+  }
+
+  /**
    * a {@link PTransform} that does a global combine using an aggregation built up by calls to
    * aggregateField and aggregateFields. The output of this transform will have a schema that is
    * determined by the output types of all the composed combiners.
    */
-  public static class CombineFieldsGlobally<InputT>
-      extends PTransform<PCollection<InputT>, PCollection<Row>> {
+  public static class CombineFieldsGlobally<InputT> extends AggregateCombiner<InputT> {
     private final SchemaAggregateFn.Inner schemaAggregateFn;
 
     CombineFieldsGlobally(SchemaAggregateFn.Inner schemaAggregateFn) {
@@ -378,6 +429,15 @@
     }
 
     /**
+     * Returns a transform that does a global combine using an aggregation built up by calls to
+     * aggregateField and aggregateFields. The schema of its output is not known up front; it is
+     * determined by the output types of all the composed combiners.
+     */
+    public static CombineFieldsGlobally create() {
+      return new CombineFieldsGlobally<>(SchemaAggregateFn.create());
+    }
+
+    /**
      * Build up an aggregation function over the input elements.
      *
      * <p>This method specifies an aggregation over single field of the input. The union of all
@@ -431,6 +491,7 @@
      * <p>This method specifies an aggregation over single field of the input. The union of all
      * calls to aggregateField and aggregateFields will determine the output schema.
      */
+    @Override
     public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField(
         String inputFieldName,
         CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
@@ -450,6 +511,7 @@
               FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField));
     }
 
+    @Override
     public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField(
         int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
       return new CombineFieldsGlobally<>(
@@ -526,6 +588,7 @@
           FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
     }
 
+    @Override
     public <CombineInputT, AccumT, CombineOutputT>
         CombineFieldsGlobally<InputT> aggregateFieldsById(
             List<Integer> inputFieldIds,
@@ -551,9 +614,13 @@
     @Override
     public PCollection<Row> expand(PCollection<InputT> input) {
       SchemaAggregateFn.Inner fn = schemaAggregateFn.withSchema(input.getSchema());
+      Combine.Globally<Row, Row> combineFn = Combine.globally(fn);
+      if (!(input.getWindowingStrategy().getWindowFn() instanceof GlobalWindows)) {
+        combineFn = combineFn.withoutDefaults();
+      }
       return input
           .apply("toRows", Convert.toRows())
-          .apply("Global Combine", Combine.globally(fn))
+          .apply("Global Combine", combineFn)
           .setRowSchema(fn.getOutputSchema());
     }
   }
@@ -566,8 +633,7 @@
    * specified extracted fields.
    */
   @AutoValue
-  public abstract static class ByFields<InputT>
-      extends PTransform<PCollection<InputT>, PCollection<Row>> {
+  public abstract static class ByFields<InputT> extends AggregateCombiner<InputT> {
     abstract FieldAccessDescriptor getFieldAccessDescriptor();
 
     abstract String getKeyField();
@@ -698,6 +764,7 @@
      * <p>This method specifies an aggregation over single field of the input. The union of all
      * calls to aggregateField and aggregateFields will determine the output schema.
      */
+    @Override
     public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
         String inputFieldName,
         CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
@@ -725,6 +792,7 @@
           getValueField());
     }
 
+    @Override
     public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
         int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
       return CombineFieldsByFields.of(
@@ -812,6 +880,7 @@
           FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
     }
 
+    @Override
     public <CombineInputT, AccumT, CombineOutputT>
         CombineFieldsByFields<InputT> aggregateFieldsById(
             List<Integer> inputFieldIds,
@@ -875,8 +944,7 @@
    * determined by the output types of all the composed combiners.
    */
   @AutoValue
-  public abstract static class CombineFieldsByFields<InputT>
-      extends PTransform<PCollection<InputT>, PCollection<Row>> {
+  public abstract static class CombineFieldsByFields<InputT> extends AggregateCombiner<InputT> {
     abstract ByFields<InputT> getByFields();
 
     abstract SchemaAggregateFn.Inner getSchemaAggregateFn();
@@ -995,6 +1063,7 @@
      * <p>This method specifies an aggregation over single field of the input. The union of all
      * calls to aggregateField and aggregateFields will determine the output schema.
      */
+    @Override
     public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
         String inputFieldName,
         CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
@@ -1020,6 +1089,7 @@
           .build();
     }
 
+    @Override
     public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
         int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
       return toBuilder()
@@ -1095,6 +1165,7 @@
           FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
     }
 
+    @Override
     public <CombineInputT, AccumT, CombineOutputT>
         CombineFieldsByFields<InputT> aggregateFieldsById(
             List<Integer> inputFieldIds,
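
A minimal sketch of how the new AggregateCombiner base class can be used, with illustrative field names/ids ("price", "total", key field 0) that are assumptions, not part of this change. Both Group.byFieldIds(...) and the new Group.CombineFieldsGlobally.create() expose the same aggregateField/aggregateFieldsById surface, so a caller can build the aggregation without knowing whether it is keyed or global:

import java.util.Arrays;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.transforms.Group;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.values.Row;

class AggregateCombinerSketch {
  static Group.AggregateCombiner<Row> sumOfPrice(boolean hasGroupingKeys) {
    Group.AggregateCombiner<Row> base;
    if (hasGroupingKeys) {
      base = Group.byFieldIds(Arrays.asList(0)); // GROUP BY path; key field id is illustrative
    } else {
      base = Group.CombineFieldsGlobally.create(); // global-aggregation path added here
    }
    // Either subclass returns an AggregateCombiner, so the caller stays agnostic of which one it got.
    return base.aggregateField("price", Sum.ofLongs(), Field.of("total", FieldType.INT64));
  }
}

Note also that CombineFieldsGlobally#expand now switches to Combine.globally(fn).withoutDefaults() whenever the input is not globally windowed, since a windowed global combine cannot produce a default value for empty windows.
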
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java
index 9eee454..d93fe1e 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java
@@ -48,6 +48,7 @@
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.UsesMapState;
 import org.apache.beam.sdk.testing.UsesSchema;
+import org.apache.beam.sdk.testing.UsesSetState;
 import org.apache.beam.sdk.testing.UsesStatefulParDo;
 import org.apache.beam.sdk.testing.ValidatesRunner;
 import org.apache.beam.sdk.values.KV;
@@ -768,7 +769,7 @@
   }
 
   @Test
-  @Category({NeedsRunner.class, UsesStatefulParDo.class})
+  @Category({NeedsRunner.class, UsesStatefulParDo.class, UsesSetState.class})
   public void testSetStateSchemaInference() throws NoSuchSchemaException {
     final String stateId = "foo";
 
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java
index f3e14da..3173db9 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java
@@ -31,6 +31,7 @@
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.Schema.Field;
 import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.transforms.Group;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
@@ -216,6 +217,8 @@
     private WindowFn<Row, IntervalWindow> windowFn;
     private int windowFieldIndex;
     private List<FieldAggregation> fieldAggregations;
+    private final int groupSetCount;
+    private boolean ignoreValues;
 
     private Transform(
         WindowFn<Row, IntervalWindow> windowFn,
@@ -227,6 +230,8 @@
       this.windowFieldIndex = windowFieldIndex;
       this.fieldAggregations = fieldAggregations;
       this.outputSchema = outputSchema;
+      this.groupSetCount = groupSet.asList().size();
+      this.ignoreValues = false;
       this.keyFieldsIds =
           groupSet.asList().stream().filter(i -> i != windowFieldIndex).collect(toList());
     }
@@ -243,39 +248,60 @@
       if (windowFn != null) {
         windowedStream = assignTimestampsAndWindow(upstream);
       }
-
       validateWindowIsSupported(windowedStream);
+      // Check if there are fields to be grouped
+      if (groupSetCount > 0) {
+        org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner<Row> byFields =
+            org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(keyFieldsIds);
+        PTransform<PCollection<Row>, PCollection<Row>> combiner = createCombiner(byFields);
+        boolean verifyRowValues =
+            pinput.getPipeline().getOptions().as(BeamSqlPipelineOptions.class).getVerifyRowValues();
+        return windowedStream
+            .apply(combiner)
+            .apply(
+                "mergeRecord",
+                ParDo.of(
+                    mergeRecord(outputSchema, windowFieldIndex, ignoreValues, verifyRowValues)))
+            .setRowSchema(outputSchema);
+      }
+      org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner<Row> globally =
+          org.apache.beam.sdk.schemas.transforms.Group.CombineFieldsGlobally.create();
+      PTransform<PCollection<Row>, PCollection<Row>> combiner = createCombiner(globally);
+      return windowedStream.apply(combiner).setRowSchema(outputSchema);
+    }
 
-      org.apache.beam.sdk.schemas.transforms.Group.ByFields<Row> byFields =
-          org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(keyFieldsIds);
-      org.apache.beam.sdk.schemas.transforms.Group.CombineFieldsByFields<Row> combined = null;
+    private PTransform<PCollection<Row>, PCollection<Row>> createCombiner(
+        org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner<Row> initialCombiner) {
+
+      org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner combined = null;
       for (FieldAggregation fieldAggregation : fieldAggregations) {
         List<Integer> inputs = fieldAggregation.inputs;
         CombineFn combineFn = fieldAggregation.combineFn;
-        if (inputs.size() > 1 || inputs.isEmpty()) {
-          // In this path we extract a Row (an empty row if inputs.isEmpty).
-          combined =
-              (combined == null)
-                  ? byFields.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField)
-                  : combined.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField);
-        } else {
+        if (inputs.size() == 1) {
           // Combining over a single field, so extract just that field.
           combined =
               (combined == null)
-                  ? byFields.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField)
+                  ? initialCombiner.aggregateField(
+                      inputs.get(0), combineFn, fieldAggregation.outputField)
                   : combined.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField);
+        } else {
+          // In this path we extract a Row (an empty row if inputs.isEmpty).
+          combined =
+              (combined == null)
+                  ? initialCombiner.aggregateFieldsById(
+                      inputs, combineFn, fieldAggregation.outputField)
+                  : combined.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField);
         }
       }
 
       PTransform<PCollection<Row>, PCollection<Row>> combiner = combined;
-      boolean ignoreValues = false;
       if (combiner == null) {
         // If no field aggregations were specified, we run a constant combiner that always returns
         // a single empty row for each key. This is used by the SELECT DISTINCT query plan - in this
         // case a group by is generated to determine unique keys, and a constant null combiner is
         // used.
         combiner =
-            byFields.aggregateField(
+            initialCombiner.aggregateField(
                 "*",
                 AggregationCombineFnAdapter.createConstantCombineFn(),
                 Field.of(
@@ -283,15 +309,7 @@
                     FieldType.row(AggregationCombineFnAdapter.EMPTY_SCHEMA).withNullable(true)));
         ignoreValues = true;
       }
-
-      boolean verifyRowValues =
-          pinput.getPipeline().getOptions().as(BeamSqlPipelineOptions.class).getVerifyRowValues();
-      return windowedStream
-          .apply(combiner)
-          .apply(
-              "mergeRecord",
-              ParDo.of(mergeRecord(outputSchema, windowFieldIndex, ignoreValues, verifyRowValues)))
-          .setRowSchema(outputSchema);
+      return combiner;
     }
 
     /** Extract timestamps from the windowFieldIndex, then window into windowFns. */
@@ -349,7 +367,6 @@
           if (!ignoreValues) {
             fieldValues.addAll(kvRow.getRow(1).getValues());
           }
-
           if (windowStartFieldIndex != -1) {
             fieldValues.add(windowStartFieldIndex, ((IntervalWindow) window).start());
           }
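
An illustrative pipeline fragment showing the query shape that the new global-aggregation branch handles (table and field names such as "price" are assumptions, not taken from this change); with no GROUP BY clause, BeamAggregationRel now routes the combine through Group.CombineFieldsGlobally:

import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

class GlobalAggregationSketch {
  static PCollection<Row> totals(PCollection<Row> orders) {
    // No GROUP BY: aggregates over the whole (windowed) input rather than per key.
    return orders.apply(
        SqlTransform.query("SELECT COUNT(*) AS c, SUM(price) AS total FROM PCOLLECTION"));
  }
}
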
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
index 6dd76a3..dcc8ec9 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.extensions.sql.impl.transform;
 
+import java.io.Serializable;
 import java.math.BigDecimal;
 import java.math.MathContext;
 import java.math.RoundingMode;
@@ -24,6 +25,7 @@
 import java.util.function.Function;
 import org.apache.beam.sdk.coders.BigDecimalCoder;
 import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.KvCoder;
@@ -56,16 +58,24 @@
       BUILTIN_AGGREGATOR_FACTORIES =
           ImmutableMap.<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>builder()
               .put("ANY_VALUE", typeName -> Sample.anyValueCombineFn())
-              .put("COUNT", typeName -> Count.combineFn())
-              .put("MAX", BeamBuiltinAggregations::createMax)
-              .put("MIN", BeamBuiltinAggregations::createMin)
-              .put("SUM", BeamBuiltinAggregations::createSum)
-              .put("$SUM0", BeamBuiltinAggregations::createSum)
-              .put("AVG", BeamBuiltinAggregations::createAvg)
-              .put("BIT_OR", BeamBuiltinAggregations::createBitOr)
-              .put("BIT_XOR", BeamBuiltinAggregations::createBitXOr)
+              // Drop null elements for these aggregations.
+              .put("COUNT", typeName -> new DropNullFn(Count.combineFn()))
+              .put("MAX", typeName -> new DropNullFn(BeamBuiltinAggregations.createMax(typeName)))
+              .put("MIN", typeName -> new DropNullFn(BeamBuiltinAggregations.createMin(typeName)))
+              .put("SUM", typeName -> new DropNullFn(BeamBuiltinAggregations.createSum(typeName)))
+              .put(
+                  "$SUM0", typeName -> new DropNullFn(BeamBuiltinAggregations.createSum0(typeName)))
+              .put("AVG", typeName -> new DropNullFn(BeamBuiltinAggregations.createAvg(typeName)))
+              .put(
+                  "BIT_OR",
+                  typeName -> new DropNullFn(BeamBuiltinAggregations.createBitOr(typeName)))
+              .put(
+                  "BIT_XOR",
+                  typeName -> new DropNullFn(BeamBuiltinAggregations.createBitXOr(typeName)))
               // JIRA link:https://issues.apache.org/jira/browse/BEAM-10379
-              .put("BIT_AND", BeamBuiltinAggregations::createBitAnd)
+              .put(
+                  "BIT_AND",
+                  typeName -> new DropNullFn(BeamBuiltinAggregations.createBitAnd(typeName)))
               .put("VAR_POP", t -> VarianceFn.newPopulation(t.getTypeName()))
               .put("VAR_SAMP", t -> VarianceFn.newSample(t.getTypeName()))
               .put("COVAR_POP", t -> CovarianceFn.newPopulation(t.getTypeName()))
@@ -150,7 +160,7 @@
       case BYTE:
         return new ByteSum();
       case INT64:
-        return Sum.ofLongs();
+        return new LongSum();
       case FLOAT:
         return new FloatSum();
       case DOUBLE:
@@ -163,6 +173,32 @@
     }
   }
 
+  /**
+   * {@link CombineFn} for SUM0, which returns 0 (rather than null) when there is nothing to sum;
+   * based on {@link Sum} and {@link Combine.BinaryCombineFn}.
+   */
+  static CombineFn createSum0(Schema.FieldType fieldType) {
+    switch (fieldType.getTypeName()) {
+      case INT32:
+        return new IntegerSum0();
+      case INT16:
+        return new ShortSum0();
+      case BYTE:
+        return new ByteSum0();
+      case INT64:
+        return new LongSum0();
+      case FLOAT:
+        return new FloatSum0();
+      case DOUBLE:
+        return new DoubleSum0();
+      case DECIMAL:
+        return new BigDecimalSum0();
+      default:
+        throw new UnsupportedOperationException(
+            String.format("[%s] is not supported in SUM0", fieldType));
+    }
+  }
+
   /** {@link CombineFn} for AVG. */
   static CombineFn createAvg(Schema.FieldType fieldType) {
     switch (fieldType.getTypeName()) {
@@ -224,6 +260,13 @@
     }
   }
 
+  static class IntegerSum extends Combine.BinaryCombineFn<Integer> {
+    @Override
+    public Integer apply(Integer left, Integer right) {
+      return (int) (left + right);
+    }
+  }
+
   static class ShortSum extends Combine.BinaryCombineFn<Short> {
     @Override
     public Short apply(Short left, Short right) {
@@ -245,6 +288,20 @@
     }
   }
 
+  static class DoubleSum extends Combine.BinaryCombineFn<Double> {
+    @Override
+    public Double apply(Double left, Double right) {
+      return (double) left + right;
+    }
+  }
+
+  static class LongSum extends Combine.BinaryCombineFn<Long> {
+    @Override
+    public Long apply(Long left, Long right) {
+      return Math.addExact(left, right);
+    }
+  }
+
   static class BigDecimalSum extends Combine.BinaryCombineFn<BigDecimal> {
     @Override
     public BigDecimal apply(BigDecimal left, BigDecimal right) {
@@ -252,6 +309,90 @@
     }
   }
 
+  static class IntegerSum0 extends IntegerSum {
+    @Override
+    public @Nullable Integer identity() {
+      return 0;
+    }
+  }
+
+  static class ShortSum0 extends ShortSum {
+    @Override
+    public @Nullable Short identity() {
+      return 0;
+    }
+  }
+
+  static class ByteSum0 extends ByteSum {
+    @Override
+    public @Nullable Byte identity() {
+      return 0;
+    }
+  }
+
+  static class FloatSum0 extends FloatSum {
+    @Override
+    public @Nullable Float identity() {
+      return 0F;
+    }
+  }
+
+  static class DoubleSum0 extends DoubleSum {
+    @Override
+    public @Nullable Double identity() {
+      return 0D;
+    }
+  }
+
+  static class LongSum0 extends LongSum {
+    @Override
+    public @Nullable Long identity() {
+      return 0L;
+    }
+  }
+
+  static class BigDecimalSum0 extends BigDecimalSum {
+    @Override
+    public @Nullable BigDecimal identity() {
+      return BigDecimal.ZERO;
+    }
+  }
+
+  private static class DropNullFn<InputT, AccumT, OutputT>
+      extends CombineFn<InputT, AccumT, OutputT> {
+    private final CombineFn<InputT, AccumT, OutputT> combineFn;
+
+    DropNullFn(CombineFn<InputT, AccumT, OutputT> combineFn) {
+      this.combineFn = combineFn;
+    }
+
+    @Override
+    public AccumT createAccumulator() {
+      return combineFn.createAccumulator();
+    }
+
+    @Override
+    public AccumT addInput(AccumT accumulator, InputT input) {
+      return (input == null) ? accumulator : combineFn.addInput(accumulator, input);
+    }
+
+    @Override
+    public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+      return combineFn.mergeAccumulators(accumulators);
+    }
+
+    @Override
+    public OutputT extractOutput(AccumT accumulator) {
+      return combineFn.extractOutput(accumulator);
+    }
+
+    @Override
+    public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<InputT> inputCoder)
+        throws CannotProvideCoderException {
+      return combineFn.getAccumulatorCoder(registry, inputCoder);
+    }
+  }
+
   /** {@link CombineFn} for <em>AVG</em> on {@link Number} types. */
   abstract static class Avg<T extends Number> extends CombineFn<T, KV<Integer, BigDecimal>, T> {
     @Override
@@ -376,29 +517,45 @@
     }
   }
 
-  static class BitOr<T extends Number> extends CombineFn<T, Long, Long> {
-    @Override
-    public Long createAccumulator() {
-      return 0L;
+  static class BitOr<T extends Number> extends CombineFn<T, BitOr.Accum, Long> {
+    static class Accum implements Serializable {
+      /** True if no inputs have been seen yet. */
+      boolean isEmpty = true;
+      /** The bitwise-or of the inputs seen so far. */
+      long bitOr = 0L;
     }
 
     @Override
-    public Long addInput(Long accum, T input) {
-      return accum | input.longValue();
+    public Accum createAccumulator() {
+      return new Accum();
     }
 
     @Override
-    public Long mergeAccumulators(Iterable<Long> accums) {
-      Long merged = createAccumulator();
-      for (Long accum : accums) {
-        merged = merged | accum;
+    public Accum addInput(Accum accum, T input) {
+      accum.isEmpty = false;
+      accum.bitOr |= input.longValue();
+      return accum;
+    }
+
+    @Override
+    public Accum mergeAccumulators(Iterable<Accum> accums) {
+      Accum merged = createAccumulator();
+      for (Accum accum : accums) {
+        if (accum.isEmpty) {
+          continue;
+        }
+        merged.isEmpty = false;
+        merged.bitOr |= accum.bitOr;
       }
       return merged;
     }
 
     @Override
-    public Long extractOutput(Long accum) {
-      return accum;
+    public Long extractOutput(Accum accum) {
+      if (accum.isEmpty) {
+        return null;
+      }
+      return accum.bitOr;
     }
   }
 
@@ -409,14 +566,9 @@
    * (https://issues.apache.org/jira/browse/BEAM-10379)
    */
   static class BitAnd<T extends Number> extends CombineFn<T, BitAnd.Accum, Long> {
-    static class Accum {
+    static class Accum implements Serializable {
       /** True if no inputs have been seen yet. */
       boolean isEmpty = true;
-      /**
-       * True if any null inputs have been seen. If we see a single null value, the end result is
-       * null, so if isNull is true, isEmpty and bitAnd are ignored.
-       */
-      boolean isNull = false;
       /** The bitwise-and of the inputs seen so far. */
       long bitAnd = -1L;
     }
@@ -428,13 +580,6 @@
 
     @Override
     public Accum addInput(Accum accum, T input) {
-      if (accum.isNull) {
-        return accum;
-      }
-      if (input == null) {
-        accum.isNull = true;
-        return accum;
-      }
       accum.isEmpty = false;
       accum.bitAnd &= input.longValue();
       return accum;
@@ -444,9 +589,6 @@
     public Accum mergeAccumulators(Iterable<Accum> accums) {
       Accum merged = createAccumulator();
       for (Accum accum : accums) {
-        if (accum.isNull) {
-          return accum;
-        }
         if (accum.isEmpty) {
           continue;
         }
@@ -458,41 +600,55 @@
 
     @Override
     public Long extractOutput(Accum accum) {
-      if (accum.isEmpty || accum.isNull) {
+      if (accum.isEmpty) {
         return null;
       }
       return accum.bitAnd;
     }
   }
 
-  public static class BitXOr<T extends Number> extends CombineFn<T, Long, Long> {
+  public static class BitXOr<T extends Number> extends CombineFn<T, BitXOr.Accum, Long> {
 
-    @Override
-    public Long createAccumulator() {
-      return 0L;
+    static class Accum implements Serializable {
+      /** True if no inputs have been seen yet. */
+      boolean isEmpty = true;
+      /** The bitwise-xor of the inputs seen so far. */
+      long bitXOr = 0L;
     }
 
     @Override
-    public Long addInput(Long mutableAccumulator, T input) {
+    public Accum createAccumulator() {
+      return new Accum();
+    }
+
+    @Override
+    public Accum addInput(Accum mutableAccumulator, T input) {
       if (input != null) {
-        return mutableAccumulator ^ input.longValue();
-      } else {
-        return 0L;
+        mutableAccumulator.isEmpty = false;
+        mutableAccumulator.bitXOr ^= input.longValue();
       }
+      return mutableAccumulator;
     }
 
     @Override
-    public Long mergeAccumulators(Iterable<Long> accumulators) {
-      Long merged = createAccumulator();
-      for (Long accum : accumulators) {
-        merged = merged ^ accum;
+    public Accum mergeAccumulators(Iterable<Accum> accumulators) {
+      Accum merged = createAccumulator();
+      for (Accum accum : accumulators) {
+        if (accum.isEmpty) {
+          continue;
+        }
+        merged.isEmpty = false;
+        merged.bitXOr ^= accum.bitXOr;
       }
       return merged;
     }
 
     @Override
-    public Long extractOutput(Long accumulator) {
-      return accumulator;
+    public Long extractOutput(Accum accumulator) {
+      if (accumulator.isEmpty) {
+        return null;
+      }
+      return accumulator.bitXOr;
     }
   }
 }
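
A small, test-style sketch of the empty-input semantics introduced above; it is illustrative only and assumes it lives in the same package (org.apache.beam.sdk.extensions.sql.impl.transform), since these CombineFns are package-private:

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;

import org.junit.Test;

public class EmptyInputSemanticsSketch {
  @Test
  public void sum0AndBitOrOnEmptyInput() {
    // SUM0 overrides identity(), so a group with nothing to sum yields 0 instead of null.
    assertEquals(Long.valueOf(0L), new BeamBuiltinAggregations.LongSum0().identity());

    // BIT_OR (and likewise BIT_XOR) now track whether any input was seen and return null otherwise.
    BeamBuiltinAggregations.BitOr<Long> bitOr = new BeamBuiltinAggregations.BitOr<>();
    assertNull(bitOr.extractOutput(bitOr.createAccumulator()));
    assertEquals(
        Long.valueOf(3L),
        bitOr.extractOutput(bitOr.addInput(bitOr.addInput(bitOr.createAccumulator(), 1L), 2L)));
  }
}

The DropNullFn wrapper above is also why AggregationCombineFnAdapter (next file) no longer needs its own null check: null rows are dropped before they reach the wrapped CombineFn.
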
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
index 2178a2d..efef9d5 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
@@ -51,10 +51,7 @@
 
     @Override
     public Object addInput(Object accumulator, T input) {
-      T processedInput = getInput(input);
-      return (processedInput == null)
-          ? accumulator
-          : combineFn.addInput(accumulator, getInput(input));
+      return combineFn.addInput(accumulator, getInput(input));
     }
 
     @Override
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java
index 92b05b2..51977ed 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.extensions.sql.impl.transform.agg;
 
+import java.io.Serializable;
 import org.apache.beam.sdk.transforms.Combine;
 
 /**
@@ -32,7 +33,7 @@
 
   public static class CountIfFn extends Combine.CombineFn<Boolean, CountIfFn.Accum, Long> {
 
-    public static class Accum {
+    public static class Accum implements Serializable {
       boolean isExpressionFalse = true;
       long countIfResult = 0L;
     }
@@ -56,6 +57,7 @@
       CountIfFn.Accum merged = createAccumulator();
       for (CountIfFn.Accum accum : accums) {
         if (!accum.isExpressionFalse) {
+          merged.isExpressionFalse = false;
           merged.countIfResult += accum.countIfResult;
         }
       }
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java
index f8112e2..14ceb3c 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java
@@ -124,6 +124,11 @@
     BigDecimal adjustedCount =
         this.isSample ? covariance.count().subtract(BigDecimal.ONE) : covariance.count();
 
+    // Avoid ArithmeticException: Division is undefined when adjustedCount is 0
+    if (adjustedCount.equals(BigDecimal.ZERO)) {
+      return BigDecimal.ZERO;
+    }
+
     return covariance.covariance().divide(adjustedCount, MATH_CTX);
   }
 }
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFn.java
index ae1d9c8..032c9b0 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFn.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFn.java
@@ -141,6 +141,11 @@
     BigDecimal adjustedCount =
         this.isSample ? variance.count().subtract(BigDecimal.ONE) : variance.count();
 
+    // Avoid ArithmeticException: Division is undefined when adjustedCount is 0
+    if (adjustedCount.equals(BigDecimal.ZERO)) {
+      return BigDecimal.ZERO;
+    }
+
     return variance.variance().divide(adjustedCount, MATH_CTX);
   }
 }
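
A minimal illustration of why the zero check added to CovarianceFn and VarianceFn matters: the sample estimators (VAR_SAMP/COVAR_SAMP) divide by n - 1, so a group with a single row used to trigger an ArithmeticException. The values below are made up for the sketch:

import java.math.BigDecimal;
import java.math.MathContext;

class SampleEstimatorDivisorSketch {
  public static void main(String[] args) {
    BigDecimal sumOfSquaredDeviations = BigDecimal.ZERO;                // single-row group
    BigDecimal adjustedCount = BigDecimal.ONE.subtract(BigDecimal.ONE); // n - 1 == 0
    if (adjustedCount.equals(BigDecimal.ZERO)) {
      System.out.println(BigDecimal.ZERO); // what the new guard returns instead of dividing
    } else {
      System.out.println(sumOfSquaredDeviations.divide(adjustedCount, MathContext.DECIMAL64));
    }
  }
}
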
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/ArrayAgg.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/ArrayAgg.java
index 721b4fa..e5ec5a8 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/ArrayAgg.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/ArrayAgg.java
@@ -20,26 +20,27 @@
 import java.util.ArrayList;
 import java.util.List;
 import org.apache.beam.sdk.transforms.Combine;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 public class ArrayAgg {
 
-  public static class ArrayAggArray extends Combine.CombineFn<Object, List<Object>, List<Object>> {
+  public static class ArrayAggArray<T> extends Combine.CombineFn<T, List<T>, @Nullable List<T>> {
     @Override
-    public List<Object> createAccumulator() {
+    public List<T> createAccumulator() {
       return new ArrayList<>();
     }
 
     @Override
-    public List<Object> addInput(List<Object> accum, Object input) {
+    public List<T> addInput(List<T> accum, T input) {
       accum.add(input);
       return accum;
     }
 
     @Override
-    public List<Object> mergeAccumulators(Iterable<List<Object>> accums) {
-      List<Object> merged = new ArrayList<>();
-      for (List<Object> accum : accums) {
-        for (Object o : accum) {
+    public List<T> mergeAccumulators(Iterable<List<T>> accums) {
+      List<T> merged = new ArrayList<>();
+      for (List<T> accum : accums) {
+        for (T o : accum) {
           merged.add(o);
         }
       }
@@ -47,7 +48,10 @@
     }
 
     @Override
-    public List<Object> extractOutput(List<Object> accumulator) {
+    public @Nullable List<T> extractOutput(List<T> accumulator) {
+      if (accumulator.isEmpty()) {
+        return null;
+      }
       return accumulator;
     }
   }
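
With the change above, ARRAY_AGG's CombineFn is generic and returns null rather than an empty list when it received no input; a minimal illustration (not part of this change):

import java.util.List;
import org.apache.beam.sdk.extensions.sql.impl.udaf.ArrayAgg;

class ArrayAggNullOnEmptySketch {
  public static void main(String[] args) {
    ArrayAgg.ArrayAggArray<String> fn = new ArrayAgg.ArrayAggArray<>();
    System.out.println(fn.extractOutput(fn.createAccumulator()));                   // null, instead of []
    System.out.println(fn.extractOutput(fn.addInput(fn.createAccumulator(), "a"))); // [a]
  }
}
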
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
index 206dde0..5f2b574 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
@@ -157,6 +157,113 @@
     input.apply(SqlTransform.query(sql));
   }
 
+  /** Multiple aggregation functions without GROUP-BY, with unbounded PCollection. */
+  @Test
+  public void testAggregationNonGroupedFunctionsWithUnbounded() throws Exception {
+    runAggregationFunctionsWithoutGroup(unboundedInput1);
+  }
+
+  /** Multiple aggregation functions without GROUP-BY, with bounded PCollection. */
+  @Test
+  public void testAggregationNonGroupedFunctionsWithBounded() throws Exception {
+    runAggregationFunctionsWithoutGroup(boundedInput1);
+  }
+
+  private void runAggregationFunctionsWithoutGroup(PCollection<Row> input) throws Exception {
+    String sql =
+        "select count(*) as getFieldCount, "
+            + "sum(f_long) as sum1, avg(f_long) as avg1, "
+            + "max(f_long) as max1, min(f_long) as min1, "
+            + "sum(f_short) as sum2, avg(f_short) as avg2, "
+            + "max(f_short) as max2, min(f_short) as min2, "
+            + "sum(f_byte) as sum3, avg(f_byte) as avg3, "
+            + "max(f_byte) as max3, min(f_byte) as min3, "
+            + "sum(f_float) as sum4, avg(f_float) as avg4, "
+            + "max(f_float) as max4, min(f_float) as min4, "
+            + "sum(f_double) as sum5, avg(f_double) as avg5, "
+            + "max(f_double) as max5, min(f_double) as min5, "
+            + "max(f_timestamp) as max6, min(f_timestamp) as min6, "
+            + "max(f_string) as max7, min(f_string) as min7, "
+            + "var_pop(f_double) as varpop1, var_samp(f_double) as varsamp1, "
+            + "var_pop(f_int) as varpop2, var_samp(f_int) as varsamp2 "
+            + "FROM TABLE_A";
+
+    PCollection<Row> result =
+        PCollectionTuple.of(new TupleTag<>("TABLE_A"), input)
+            .apply("testAggregationFunctions", SqlTransform.query(sql));
+
+    Schema resultType =
+        Schema.builder()
+            .addInt64Field("size")
+            .addInt64Field("sum1")
+            .addInt64Field("avg1")
+            .addInt64Field("max1")
+            .addInt64Field("min1")
+            .addInt16Field("sum2")
+            .addInt16Field("avg2")
+            .addInt16Field("max2")
+            .addInt16Field("min2")
+            .addByteField("sum3")
+            .addByteField("avg3")
+            .addByteField("max3")
+            .addByteField("min3")
+            .addFloatField("sum4")
+            .addFloatField("avg4")
+            .addFloatField("max4")
+            .addFloatField("min4")
+            .addDoubleField("sum5")
+            .addDoubleField("avg5")
+            .addDoubleField("max5")
+            .addDoubleField("min5")
+            .addDateTimeField("max6")
+            .addDateTimeField("min6")
+            .addStringField("max7")
+            .addStringField("min7")
+            .addDoubleField("varpop1")
+            .addDoubleField("varsamp1")
+            .addInt32Field("varpop2")
+            .addInt32Field("varsamp2")
+            .build();
+
+    Row row =
+        Row.withSchema(resultType)
+            .addValues(
+                4L,
+                10000L,
+                2500L,
+                4000L,
+                1000L,
+                (short) 10,
+                (short) 2,
+                (short) 4,
+                (short) 1,
+                (byte) 10,
+                (byte) 2,
+                (byte) 4,
+                (byte) 1,
+                10.0F,
+                2.5F,
+                4.0F,
+                1.0F,
+                10.0,
+                2.5,
+                4.0,
+                1.0,
+                parseTimestampWithoutTimeZone("2017-01-01 02:04:03"),
+                parseTimestampWithoutTimeZone("2017-01-01 01:01:03"),
+                "第四行",
+                "string_row1",
+                1.25,
+                1.666666667,
+                1,
+                1)
+            .build();
+
+    PAssert.that(result).containsInAnyOrder(row);
+
+    pipeline.run().waitUntilFinish();
+  }
+
   /** GROUP-BY with multiple aggregation functions with bounded PCollection. */
   @Test
   public void testAggregationFunctionsWithBounded() throws Exception {
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java
index 6d9d114..65dee35 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java
@@ -412,11 +412,9 @@
           FunctionSignatureId.FN_SUM_DOUBLE, // sum
           FunctionSignatureId.FN_SUM_NUMERIC, // sum
           // FunctionSignatureId.FN_SUM_BIGNUMERIC, // sum
-          // JIRA link: https://issues.apache.org/jira/browse/BEAM-10379
-          // FunctionSignatureId.FN_BIT_AND_INT64, // bit_and
+          FunctionSignatureId.FN_BIT_AND_INT64, // bit_and
           FunctionSignatureId.FN_BIT_OR_INT64, // bit_or
-          // TODO(BEAM-10379) Re-enable when nulls are handled properly.
-          // FunctionSignatureId.FN_BIT_XOR_INT64, // bit_xor
+          FunctionSignatureId.FN_BIT_XOR_INT64, // bit_xor
           // FunctionSignatureId.FN_LOGICAL_AND, // logical_and
           // FunctionSignatureId.FN_LOGICAL_OR, // logical_or
           // Approximate aggregate functions.
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
index bd73241..ef00982 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
@@ -41,7 +41,6 @@
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.AggregateCall;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.logical.LogicalProject;
-import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.type.RelDataType;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rex.RexNode;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.sql.SqlAggFunction;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.sql.type.SqlReturnTypeInference;
@@ -92,12 +91,9 @@
       aggregateCalls = new ArrayList<>();
       // For aggregate calls, their input ref follow after GROUP BY input ref.
       int columnRefoff = groupFieldsListSize;
-      boolean nullable = false;
-      if (input.getProjects().size() > columnRefoff) {
-        nullable = input.getProjects().get(columnRefoff).getType().isNullable();
-      }
       for (ResolvedComputedColumn computedColumn : zetaNode.getAggregateList()) {
-        AggregateCall aggCall = convertAggCall(computedColumn, columnRefoff, nullable);
+        AggregateCall aggCall =
+            convertAggCall(computedColumn, columnRefoff, groupSet.size(), input);
         aggregateCalls.add(aggCall);
         if (!aggCall.getArgList().isEmpty()) {
           // Only increment column reference offset when aggregates use them (BEAM-8042).
@@ -177,7 +173,7 @@
   }
 
   private AggregateCall convertAggCall(
-      ResolvedComputedColumn computedColumn, int columnRefOff, boolean nullable) {
+      ResolvedComputedColumn computedColumn, int columnRefOff, int groupCount, RelNode input) {
     ResolvedAggregateFunctionCall aggregateFunctionCall =
         (ResolvedAggregateFunctionCall) computedColumn.getExpr();
 
@@ -259,12 +255,19 @@
       }
     }
 
-    RelDataType returnType =
-        ZetaSqlCalciteTranslationUtils.toCalciteType(
-            computedColumn.getColumn().getType(), nullable, getCluster().getRexBuilder());
-
     String aggName = getTrait().resolveAlias(computedColumn.getColumn());
     return AggregateCall.create(
-        sqlAggFunction, false, false, false, argList, -1, RelCollations.EMPTY, returnType, aggName);
+        sqlAggFunction,
+        false,
+        false,
+        false,
+        argList,
+        -1,
+        RelCollations.EMPTY,
+        groupCount,
+        input,
+        // When we pass null as the return type, Calcite infers it for us.
+        null,
+        aggName);
   }
 }
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java
index c59df96..01b23eb 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java
@@ -72,8 +72,7 @@
           .put("sum", SqlStdOperatorTable.SUM)
           .put("any_value", SqlStdOperatorTable.ANY_VALUE)
           .put("count", SqlStdOperatorTable.COUNT)
-          // .put("bit_and", SqlStdOperatorTable.BIT_AND) //JIRA link:
-          // https://issues.apache.org/jira/browse/BEAM-10379
+          .put("bit_and", SqlStdOperatorTable.BIT_AND)
           .put("string_agg", SqlOperators.STRING_AGG_STRING_FN) // NULL values not supported
           .put("array_agg", SqlOperators.ARRAY_AGG_FN)
           .put("bit_or", SqlStdOperatorTable.BIT_OR)
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java
index b44f63e..29dfac4 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java
@@ -48,6 +48,7 @@
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.sql.SqlOperator;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.sql.SqlSyntax;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.sql.type.ArraySqlType;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.sql.type.FamilyOperandTypeChecker;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.sql.type.InferTypes;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.sql.type.OperandTypes;
@@ -89,8 +90,8 @@
   public static final SqlOperator ARRAY_AGG_FN =
       createUdafOperator(
           "array_agg",
-          x -> createTypeFactory().createArrayType(x.getOperandType(0), -1),
-          new UdafImpl<>(new ArrayAgg.ArrayAggArray()));
+          x -> new ArraySqlType(x.getOperandType(0), true),
+          new UdafImpl<>(new ArrayAgg.ArrayAggArray<>()));
 
   public static final SqlOperator START_WITHS =
       createUdfOperator(
@@ -161,7 +162,7 @@
   public static final SqlOperator BIT_XOR =
       createUdafOperator(
           "BIT_XOR",
-          x -> createTypeFactory().createSqlType(SqlTypeName.BIGINT),
+          x -> NULLABLE_BIGINT,
           new UdafImpl<>(new BeamBuiltinAggregations.BitXOr<Number>()));
 
   public static final SqlOperator COUNTIF =
diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java
index a802b30..35f8e22 100644
--- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java
+++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java
@@ -3840,7 +3840,20 @@
   }
 
   @Test
-  @Ignore("NULL values don't work correctly. (https://issues.apache.org/jira/browse/BEAM-10379)")
+  public void testZetaSQLBitOrNull() {
+    String sql =
+        "SELECT bit_or(CAST(x as int64)) FROM "
+            + "(SELECT NULL x UNION ALL SELECT 5 UNION ALL SELECT 6);";
+
+    PCollection<Row> stream = execute(sql);
+
+    final Schema schema = Schema.builder().addInt64Field("field1").build();
+    PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addValue(7L).build());
+
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
   public void testZetaSQLBitAnd() {
     String sql = "SELECT BIT_AND(row_id) FROM table_all_types GROUP BY bool_col";
 
@@ -3851,6 +3864,124 @@
         .containsInAnyOrder(
             Row.withSchema(schema).addValue(1L).build(),
             Row.withSchema(schema).addValue(0L).build());
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testZetaSQLBitAndInt64() {
+    String sql = "SELECT bit_and(CAST(x as int64)) FROM (SELECT 1 x FROM (SELECT 1) WHERE false)";
+
+    PCollection<Row> stream = execute(sql);
+
+    final Schema schema = Schema.builder().addNullableField("field1", FieldType.INT64).build();
+    PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addValue((Long) null).build());
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testZetaSQLBitAndNulls() {
+    String sql =
+        "SELECT bit_and(CAST(x as int64)) FROM "
+            + "(SELECT NULL x UNION ALL SELECT 5 UNION ALL SELECT 6)";
+
+    PCollection<Row> stream = execute(sql);
+
+    final Schema schema = Schema.builder().addInt64Field("field1").build();
+    PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addValue(4L).build());
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testCountEmpty() {
+    String sql = "SELECT COUNT(x) FROM UNNEST([]) AS x";
+
+    PCollection<Row> stream = execute(sql);
+
+    Schema schema = Schema.builder().addInt64Field("field1").build();
+    PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addValue(0L).build());
+
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testBitwiseOrEmpty() {
+    String sql = "SELECT BIT_OR(x) FROM UNNEST([]) AS x";
+
+    PCollection<Row> stream = execute(sql);
+
+    Schema schema = Schema.builder().addNullableField("field1", FieldType.INT64).build();
+    PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addValue((Long) null).build());
+
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testArrayAggNulls() {
+    String sql = "SELECT ARRAY_AGG(x) FROM UNNEST([1, NULL, 3]) AS x";
+
+    PCollection<Row> stream = execute(sql);
+
+    Schema schema =
+        Schema.builder()
+            .addField(
+                Field.of(
+                    "field1",
+                    FieldType.array(FieldType.of(Schema.TypeName.INT64).withNullable(true))))
+            .build();
+    PAssert.that(stream)
+        .containsInAnyOrder(Row.withSchema(schema).addArray(1L, (Long) null, 3L).build());
+
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testArrayAggEmpty() {
+    String sql = "SELECT ARRAY_AGG(x) FROM UNNEST([]) AS x";
+
+    PCollection<Row> stream = execute(sql);
+
+    Schema schema = Schema.builder().addNullableField("field1", FieldType.INT64).build();
+    PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addValue((Long) null).build());
+
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testInt64SumOverflow() {
+    String sql =
+        "SELECT SUM(col1)\n"
+            + "FROM (SELECT CAST(9223372036854775807 as int64) as col1 UNION ALL\n"
+            + "      SELECT CAST(1 as int64))\n";
+
+    ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
+    BeamRelNode beamRelNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
+    BeamSqlRelUtils.toPCollection(pipeline, beamRelNode);
+    thrown.expect(RuntimeException.class);
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testInt64SumUnderflow() {
+    String sql =
+        "SELECT SUM(col1)\n"
+            + "FROM (SELECT CAST(-9223372036854775808 as int64) as col1 UNION ALL\n"
+            + "      SELECT CAST(-1 as int64))\n";
+
+    ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
+    BeamRelNode beamRelNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
+    BeamSqlRelUtils.toPCollection(pipeline, beamRelNode);
+    thrown.expect(RuntimeException.class);
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testZetaSQLSumNulls() {
+    String sql = "SELECT SUM(x) AS sum FROM UNNEST([null, null, null]) AS x";
+
+    PCollection<Row> stream = execute(sql);
+
+    Schema schema = Schema.builder().addNullableField("field1", FieldType.INT64).build();
+    PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addValue((Long) null).build());
 
     pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
   }
@@ -3870,8 +4001,6 @@
   }
 
   @Test
-  @Ignore(
-      "Null values are not handled properly, so BIT_XOR is temporarily removed from SupportedZetaSqlBuiltinFunctions. https://issues.apache.org/jira/browse/BEAM-10379")
   public void testZetaSQLBitXor() {
     String sql = "SELECT BIT_XOR(x) AS bit_xor FROM UNNEST([5678, 1234]) AS x";
     PCollection<Row> stream = execute(sql);
@@ -3883,6 +4012,28 @@
   }
 
   @Test
+  public void testZetaSQLBitXorEmpty() {
+    String sql = "SELECT bit_xor(CAST(x as int64)) FROM (SELECT 1 x FROM (SELECT 1) WHERE false);";
+    PCollection<Row> stream = execute(sql);
+
+    Schema schema = Schema.builder().addNullableField("field1", FieldType.INT64).build();
+    PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addValue((Long) null).build());
+
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testZetaSQLBitXorNull() {
+    String sql = "SELECT bit_xor(x) FROM (SELECT CAST(NULL AS int64) x);";
+    PCollection<Row> stream = execute(sql);
+
+    Schema schema = Schema.builder().addNullableField("field1", FieldType.INT64).build();
+    PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addValue((Long) null).build());
+
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
   public void testCountIfZetaSQLDialect() {
     String sql =
         "WITH is_positive AS ( SELECT x > 0 flag FROM UNNEST([5, -2, 3, 6, -10, -7, 4, 0]) AS x) "
diff --git a/sdks/java/io/amazon-web-services2/build.gradle b/sdks/java/io/amazon-web-services2/build.gradle
index 116a12c..937feb0 100644
--- a/sdks/java/io/amazon-web-services2/build.gradle
+++ b/sdks/java/io/amazon-web-services2/build.gradle
@@ -53,6 +53,7 @@
   compile library.java.commons_lang3
   compile library.java.http_core
   compile library.java.commons_codec
+  runtime library.java.aws_java_sdk2_sts
   testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
   testCompile project(path: ":sdks:java:io:common", configuration: "testRuntime")
   testCompile project(path: ":sdks:java:io:kinesis", configuration: "testRuntime")
diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py
index bfcb47f..8b47b0b 100644
--- a/sdks/python/apache_beam/runners/direct/executor.py
+++ b/sdks/python/apache_beam/runners/direct/executor.py
@@ -23,7 +23,6 @@
 import itertools
 import logging
 import queue
-import sys
 import threading
 import traceback
 from typing import TYPE_CHECKING
@@ -478,8 +477,7 @@
     update = self.visible_updates.take()
     try:
       if update.exception:
-        t, v, tb = update.exc_info
-        raise t(v).with_traceback(tb)
+        raise update.exception
     finally:
       self.executor_service.shutdown()
       self.executor_service.await_completion()
@@ -576,10 +574,6 @@
       self.committed_bundle = committed_bundle
       self.unprocessed_bundle = unprocessed_bundle
       self.exception = exception
-      self.exc_info = sys.exc_info()
-      if self.exc_info[1] is not exception:
-        # Not the right exception.
-        self.exc_info = (exception, None, None)
 
   class _VisibleExecutorUpdate(object):
     """An update of interest to the user.
@@ -587,10 +581,9 @@
     Used for awaiting the completion to decide whether to return normally or
     raise an exception.
     """
-    def __init__(self, exc_info=(None, None, None)):
-      self.finished = exc_info[0] is not None
-      self.exception = exc_info[1] or exc_info[0]
-      self.exc_info = exc_info
+    def __init__(self, exception=None):
+      self.finished = exception is not None
+      self.exception = exception
 
   class _MonitorTask(_ExecutorService.CallableTask):
     """MonitorTask continuously runs to ensure that pipeline makes progress."""
@@ -618,7 +611,7 @@
                 'A task failed with exception: %s', update.exception)
             self._executor.visible_updates.offer(
                 _ExecutorServiceParallelExecutor._VisibleExecutorUpdate(
-                    update.exc_info))
+                    update.exception))
           update = self._executor.all_updates.poll()
         self._executor.evaluation_context.schedule_pending_unblocked_tasks(
             self._executor.executor_service)
@@ -626,8 +619,7 @@
       except Exception as e:  # pylint: disable=broad-except
         _LOGGER.error('Monitor task died due to exception.\n %s', e)
         self._executor.visible_updates.offer(
-            _ExecutorServiceParallelExecutor._VisibleExecutorUpdate(
-                sys.exc_info()))
+            _ExecutorServiceParallelExecutor._VisibleExecutorUpdate(e))
       finally:
         if not self._should_shutdown():
           self._executor.executor_service.submit(self)
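The exc_info bookkeeping removed above is unnecessary in Python 3: a caught exception object already carries its traceback in `__traceback__`, so storing the instance and re-raising it later preserves the original stack, and re-instantiating the type with `t(v)` is both lossy and fragile for exception classes with non-trivial constructors. A minimal sketch of the behavior the new code relies on (illustrative only, not part of the patch):

import traceback

def fail():
    raise ValueError("boom")

saved = None
try:
    fail()
except Exception as exc:
    # exc.__traceback__ already points at the frames where fail() raised,
    # so no sys.exc_info() tuple needs to be kept alongside it.
    saved = exc

try:
    raise saved
except ValueError:
    # The printed traceback still names fail() as the origin of the error.
    traceback.print_exc()

The `raise self._exception from None` form used in data_plane.py additionally suppresses the implicit exception chaining ("During handling of the above exception...") that would otherwise appear when re-raising inside another handler.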
diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager_test.py b/sdks/python/apache_beam/runners/interactive/recording_manager_test.py
index f809703..ec7c78e 100644
--- a/sdks/python/apache_beam/runners/interactive/recording_manager_test.py
+++ b/sdks/python/apache_beam/runners/interactive/recording_manager_test.py
@@ -424,6 +424,21 @@
         {CacheKey.from_pcoll('squares', squares).to_str()})
 
   def test_clear(self):
+    p1 = beam.Pipeline(InteractiveRunner())
+    elems_1 = p1 | 'elems 1' >> beam.Create([0, 1, 2])
+
+    ib.watch(locals())
+    ie.current_env().track_user_pipelines()
+
+    recording_manager = RecordingManager(p1)
+    recording = recording_manager.record([elems_1], max_n=3, max_duration=500)
+    recording.wait_until_finish()
+    record_describe = recording_manager.describe()
+    self.assertGreater(record_describe['size'], 0)
+    recording_manager.clear()
+    self.assertEqual(recording_manager.describe()['size'], 0)
+
+  def test_clear_specific_pipeline(self):
     """Tests that clear can empty the cache for a specific pipeline."""
 
     # Create two pipelines so we can check that clearing the cache won't clear
@@ -448,16 +463,18 @@
     rm_2 = RecordingManager(p2)
     recording = rm_2.record([elems_2], max_n=3, max_duration=500)
     recording.wait_until_finish()
-
     # Assert that clearing only one recording clears that recording.
-    self.assertGreater(rm_1.describe()['size'], 0)
-    self.assertGreater(rm_2.describe()['size'], 0)
-    rm_1.clear()
-    self.assertEqual(rm_1.describe()['size'], 0)
-    self.assertGreater(rm_2.describe()['size'], 0)
+    if rm_1.describe()['state'] == PipelineState.STOPPED \
+            and rm_2.describe()['state'] == PipelineState.STOPPED:
 
-    rm_2.clear()
-    self.assertEqual(rm_2.describe()['size'], 0)
+      self.assertGreater(rm_1.describe()['size'], 0)
+      self.assertGreater(rm_2.describe()['size'], 0)
+      rm_1.clear()
+      self.assertEqual(rm_1.describe()['size'], 0)
+      self.assertGreater(rm_2.describe()['size'], 0)
+
+      rm_2.clear()
+      self.assertEqual(rm_2.describe()['size'], 0)
 
   def test_record_pipeline(self):
     # Add the TestStream so that it can be cached.
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index e89a669..d395cee 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -24,10 +24,8 @@
 import collections
 import logging
 import queue
-import sys
 import threading
 import time
-from types import TracebackType
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Callable
@@ -41,7 +39,6 @@
 from typing import Optional
 from typing import Set
 from typing import Tuple
-from typing import Type
 from typing import Union
 
 import grpc
@@ -60,9 +57,6 @@
 else:
   OutputStream = type(coder_impl.create_OutputStream())
 
-ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
-OptExcInfo = Union[ExcInfo, Tuple[None, None, None]]
-
 # This module is experimental. No backwards-compatibility guarantees.
 
 _LOGGER = logging.getLogger(__name__)
@@ -429,7 +423,7 @@
     self._receive_lock = threading.Lock()
     self._reads_finished = threading.Event()
     self._closed = False
-    self._exc_info = (None, None, None)  # type: OptExcInfo
+    self._exception = None  # type: Optional[Exception]
 
   def close(self):
     # type: () -> None
@@ -497,9 +491,8 @@
             raise RuntimeError('Channel closed prematurely.')
           if abort_callback():
             return
-          t, v, tb = self._exc_info
-          if t:
-            raise t(v).with_traceback(tb)
+          if self._exception:
+            raise self._exception from None
         else:
           if isinstance(element, beam_fn_api_pb2.Elements.Timers):
             if element.is_last:
@@ -644,10 +637,10 @@
           _put_queue(timer.instruction_id, timer)
         for data in elements.data:
           _put_queue(data.instruction_id, data)
-    except:  # pylint: disable=bare-except
+    except Exception as e:
       if not self._closed:
         _LOGGER.exception('Failed to read inputs in the data plane.')
-        self._exc_info = sys.exc_info()
+        self._exception = e
         raise
     finally:
       self._closed = True
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 4ad0727..6fdca4d 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -31,7 +31,6 @@
 import time
 import traceback
 from concurrent import futures
-from types import TracebackType
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Callable
@@ -45,7 +44,6 @@
 from typing import MutableMapping
 from typing import Optional
 from typing import Tuple
-from typing import Type
 from typing import TypeVar
 from typing import Union
 
@@ -73,9 +71,6 @@
   from apache_beam.portability.api import endpoints_pb2
   from apache_beam.utils.profiler import Profile
 
-ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
-OptExcInfo = Union[ExcInfo, Tuple[None, None, None]]
-
 T = TypeVar('T')
 _KT = TypeVar('_KT')
 _VT = TypeVar('_VT')
@@ -1002,7 +997,7 @@
     )  # type: queue.Queue[Union[beam_fn_api_pb2.StateRequest, Sentinel]]
     self._responses_by_id = {}  # type: Dict[str, _Future]
     self._last_id = 0
-    self._exc_info = None  # type: Optional[OptExcInfo]
+    self._exception = None  # type: Optional[Exception]
     self._context = threading.local()
     self.start()
 
@@ -1041,8 +1036,8 @@
           future.set(response)
           if self._done:
             break
-      except:  # pylint: disable=bare-except
-        self._exc_info = sys.exc_info()
+      except Exception as e:
+        self._exception = e
         raise
 
     reader = threading.Thread(target=pull_responses, name='read_state')
@@ -1099,10 +1094,8 @@
     # type: (beam_fn_api_pb2.StateRequest) -> beam_fn_api_pb2.StateResponse
     req_future = self._request(request)
     while not req_future.wait(timeout=1):
-      if self._exc_info:
-        t, v, tb = self._exc_info
-        if t and v and tb:
-          raise t(v).with_traceback(tb)
+      if self._exception:
+        raise self._exception
       elif self._done:
         raise RuntimeError()
     response = req_future.get()
diff --git a/website/www/site/content/en/blog/beam-2.32.0.md b/website/www/site/content/en/blog/beam-2.32.0.md
index d0d3b12..39c6f89 100644
--- a/website/www/site/content/en/blog/beam-2.32.0.md
+++ b/website/www/site/content/en/blog/beam-2.32.0.md
@@ -88,8 +88,10 @@
 
 ## Deprecations
 
+* Python GBK will stop supporting unbounded PCollections that have global windowing and a default trigger in Beam 2.33. This can be overridden with `--allow_unsafe_triggers`. ([BEAM-9487](https://issues.apache.org/jira/browse/BEAM-9487)).
+* Python GBK will require safe triggers or the `--allow_unsafe_triggers` flag starting with Beam 2.33. ([BEAM-9487](https://issues.apache.org/jira/browse/BEAM-9487)).
 
-## Known Issues
+## Bugfixes
 
 * Fixed race condition in RabbitMqIO causing duplicate acks (Java) ([BEAM-6516](https://issues.apache.org/jira/browse/BEAM-6516))
 
diff --git a/website/www/site/content/en/documentation/dsls/dataframes/differences-from-pandas.md b/website/www/site/content/en/documentation/dsls/dataframes/differences-from-pandas.md
index da0054e..fb432e9 100644
--- a/website/www/site/content/en/documentation/dsls/dataframes/differences-from-pandas.md
+++ b/website/www/site/content/en/documentation/dsls/dataframes/differences-from-pandas.md
@@ -85,7 +85,7 @@
 
 ## Using Interactive Beam to access the full pandas API
 
-Interactive Beam is a module designed for use in interactive notebooks. The module, which by convention is imported as `ib`, provides an `ib.collect` function that brings a `PCollection` or deferred DataFrrame into local memory as a pandas DataFrame. After using `ib.collect` to materialize a deferred DataFrame you will be able to perform any operation in the pandas API, not just those that are supported in Beam.
+Interactive Beam is a module designed for use in interactive notebooks. The module, which by convention is imported as `ib`, provides an `ib.collect` function that brings a `PCollection` or deferred DataFrame into local memory as a pandas DataFrame. After using `ib.collect` to materialize a deferred DataFrame you will be able to perform any operation in the pandas API, not just those that are supported in Beam.
 
 <!-- TODO: Add code sample: https://issues.apache.org/jira/browse/BEAM-12535 -->
 
diff --git a/website/www/site/content/en/documentation/glossary.md b/website/www/site/content/en/documentation/glossary.md
index 4bc7395..acca8a3 100644
--- a/website/www/site/content/en/documentation/glossary.md
+++ b/website/www/site/content/en/documentation/glossary.md
@@ -308,6 +308,14 @@
 * [Overview](/documentation/programming-guide/#overview)
 * [Transforms](/documentation/programming-guide/#transforms)
 
+## Resource hints
+
+A Beam feature that lets you provide information to a runner about the compute resource requirements of your pipeline. You can use resource hints to define requirements for specific transforms or for an entire pipeline. For example, you could use a resource hint to specify the minimum amount of memory to allocate to workers. The runner is responsible for interpreting resource hints, and runners can ignore unsupported hints.
+
+To learn more, see:
+
+* [Resource hints](/documentation/runtime/resource-hints)
+
 ## Runner
 
 A runner runs a pipeline on a specific platform. Most runners are translators or adapters to massively parallel big data processing systems. Other runners exist for local testing and debugging. Among the supported runners are Google Cloud Dataflow, Apache Spark, Apache Samza, Apache Flink, the Interactive Runner, and the Direct Runner.
diff --git a/website/www/site/content/en/documentation/runtime/resource-hints.md b/website/www/site/content/en/documentation/runtime/resource-hints.md
new file mode 100644
index 0000000..38357de
--- /dev/null
+++ b/website/www/site/content/en/documentation/runtime/resource-hints.md
@@ -0,0 +1,81 @@
+---
+title: "Resource hints"
+---
+<!--
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+# Resource hints
+
+Resource hints let pipeline authors provide information to a runner about compute resource requirements. You can use resource hints to define requirements for specific transforms or for an entire pipeline. The runner is responsible for interpreting resource hints, and runners can ignore unsupported hints.
+
+Resource hints can be nested. For example, resource hints can be specified on subtransforms of a composite transform, and that composite transform can also have resource hints applied. By default, the innermost hint takes precedence. However, hints can define custom reconciliation behavior. For example, `min_ram` takes the maximum value for all `min_ram` values set on a given step in the pipeline (see the sketch at the end of this page).
+
+{{< language-switcher java py >}}
+
+## Available hints
+
+Currently, Beam supports the following resource hints:
+
+* `min_ram="numberXB"`: The minimum amount of RAM to allocate to workers. Beam can parse various byte units, including MB, GB, MiB, and GiB (for example, `min_ram="4GB"`). This hint is intended to provide an advisory minimum memory requirement for processing a transform.
+* `accelerator="hint"`: This hint is intended to describe a hardware accelerator to use for processing a transform. For example, the following is valid accelerator syntax for the Dataflow runner: `accelerator="type:<type>;count:<n>;<options>"`.
+
+The interpretation and actuation of resource hints can vary between runners. For an example implementation, see the [Dataflow resource hints](https://cloud.google.com/dataflow/docs/guides/right-fitting#available_resource_hints).
+
+## Specifying resource hints for a pipeline
+
+To specify resource hints for an entire pipeline, you can use pipeline options. The following command shows the basic syntax.
+
+{{< highlight java >}}
+mvn compile exec:java -Dexec.mainClass=com.example.MyPipeline \
+    -Dexec.args="... \
+                 --resourceHints=min_ram=<N>GB \
+                 --resourceHints=accelerator='hint'" \
+    -Pdirect-runner
+{{< /highlight >}}
+{{< highlight py >}}
+python my_pipeline.py \
+    ... \
+    --resource_hints min_ram=<N>GB \
+    --resource_hints accelerator="hint"
+{{< /highlight >}}
+
+## Specifying resource hints for a transform
+
+{{< paragraph class="language-java" >}}
+You can set resource hints programmatically on pipeline transforms using [setResourceHints](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/transforms/PTransform.html#setResourceHints-org.apache.beam.sdk.transforms.resourcehints.ResourceHints-).
+{{< /paragraph >}}
+
+{{< paragraph class="language-py" >}}
+You can set resource hints programmatically on pipeline transforms using [PTransform.with_resource_hints](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.ptransform.html#apache_beam.transforms.ptransform.PTransform.with_resource_hints) (also see [ResourceHint](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/transforms/resources.py#L51)).
+{{< /paragraph >}}
+
+{{< highlight java >}}
+pcoll.apply(MyCompositeTransform.of(...)
+    .setResourceHints(
+        ResourceHints.create()
+            .withMinRam("15GB")
+            .withAccelerator("type:nvidia-tesla-k80;count:1;install-nvidia-driver")));
+
+pcoll.apply(ParDo.of(new BigMemFn())
+    .setResourceHints(
+        ResourceHints.create().withMinRam("30GB")));
+{{< /highlight >}}
+{{< highlight py >}}
+pcoll | MyPTransform().with_resource_hints(
+    min_ram="4GB",
+    accelerator="type:nvidia-tesla-k80;count:1;install-nvidia-driver")
+
+pcoll | beam.ParDo(BigMemFn()).with_resource_hints(
+    min_ram="30GB")
+{{< /highlight >}}
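+
+The nesting behavior described at the top of this page can be combined with the per-transform syntax. The sketch below is illustrative only: `MyComposite` is a hypothetical composite, `BigMemFn` is the placeholder DoFn from the example above, and the effective resources ultimately depend on the runner.
+
+{{< highlight py >}}
+class MyComposite(beam.PTransform):
+  def expand(self, pcoll):
+    # This step only inherits the composite's min_ram="4GB" hint.
+    small = pcoll | "SmallStep" >> beam.Map(lambda x: x)
+    # min_ram reconciles by taking the maximum of all values set on a step,
+    # so this step effectively runs with min_ram="16GB".
+    return small | "BigStep" >> beam.ParDo(BigMemFn()).with_resource_hints(
+        min_ram="16GB")
+
+pcoll | MyComposite().with_resource_hints(min_ram="4GB")
+{{< /highlight >}}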
diff --git a/website/www/site/layouts/partials/section-menu/en/documentation.html b/website/www/site/layouts/partials/section-menu/en/documentation.html
index 4491ae1..6d1e664 100644
--- a/website/www/site/layouts/partials/section-menu/en/documentation.html
+++ b/website/www/site/layouts/partials/section-menu/en/documentation.html
@@ -213,6 +213,7 @@
 
   <ul class="section-nav-list">
     <li><a href="/documentation/runtime/environments/">Container environments</a></li>
+    <li><a href="/documentation/runtime/resource-hints/">Resource hints</a></li>
     <li><a href="/documentation/runtime/sdk-harness-config/">SDK Harness Configuration</a></li>
   </ul>
 </li>