Merge pull request #15477 from echauchot/BEAM-5172-ignore-es-flaky-tests

[BEAM-5172] Temporarily ignore the flaky testSplit and testSizes tests until a fix is available.
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 57bae57..4cdf4ae 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -444,20 +444,20 @@
     def checkerframework_version = "3.10.0"
     def classgraph_version = "4.8.104"
     def errorprone_version = "2.3.4"
-    def google_clients_version = "1.31.0"
+    def google_clients_version = "1.32.1"
     def google_cloud_bigdataoss_version = "2.2.2"
     def google_cloud_pubsublite_version = "0.13.2"
     def google_code_gson_version = "2.8.6"
     def google_oauth_clients_version = "1.31.0"
     // Try to keep grpc_version consistent with gRPC version in google_cloud_platform_libraries_bom
-    def grpc_version = "1.37.0"
-    def guava_version = "30.1-jre"
+    def grpc_version = "1.40.1"
+    def guava_version = "30.1.1-jre"
     def hadoop_version = "2.10.1"
     def hamcrest_version = "2.1"
     def influxdb_version = "2.19"
-    def httpclient_version = "4.5.10"
-    def httpcore_version = "4.4.12"
-    def jackson_version = "2.12.1"
+    def httpclient_version = "4.5.13"
+    def httpcore_version = "4.4.14"
+    def jackson_version = "2.12.4"
     def jaxb_api_version = "2.3.3"
     def jsr305_version = "3.0.2"
     def kafka_version = "2.4.1"
@@ -466,7 +466,7 @@
     def postgres_version = "42.2.16"
     def powermock_version = "2.0.9"
     // Try to keep protobuf_version consistent with the protobuf version in google_cloud_platform_libraries_bom
-    def protobuf_version = "3.15.8"
+    def protobuf_version = "3.17.3"
     def quickcheck_version = "0.8"
     def slf4j_version = "1.7.30"
     def spark_version = "2.4.8"
@@ -519,7 +519,7 @@
         cassandra_driver_core                       : "com.datastax.cassandra:cassandra-driver-core:$cassandra_driver_version",
         cassandra_driver_mapping                    : "com.datastax.cassandra:cassandra-driver-mapping:$cassandra_driver_version",
         classgraph                                  : "io.github.classgraph:classgraph:$classgraph_version",
-        commons_codec                               : "commons-codec:commons-codec:1.14",
+        commons_codec                               : "commons-codec:commons-codec:1.15",
         commons_compress                            : "org.apache.commons:commons-compress:1.21",
         commons_csv                                 : "org.apache.commons:commons-csv:1.8",
         commons_io                                  : "commons-io:commons-io:2.6",
@@ -530,23 +530,23 @@
         gax                                         : "com.google.api:gax", // google_cloud_platform_libraries_bom sets version
         gax_grpc                                    : "com.google.api:gax-grpc", // google_cloud_platform_libraries_bom sets version
         gax_httpjson                                : "com.google.api:gax-httpjson", // google_cloud_platform_libraries_bom sets version
-        google_api_client                           : "com.google.api-client:google-api-client:1.31.1", // 1.31.1 is required to run 1.31.0 of google_clients_version below.
+        google_api_client                           : "com.google.api-client:google-api-client:$google_clients_version", // kept consistent with the other libraries below that use $google_clients_version.
         google_api_client_jackson2                  : "com.google.api-client:google-api-client-jackson2:$google_clients_version",
         google_api_client_java6                     : "com.google.api-client:google-api-client-java6:$google_clients_version",
         google_api_common                           : "com.google.api:api-common", // google_cloud_platform_libraries_bom sets version
-        google_api_services_bigquery                : "com.google.apis:google-api-services-bigquery:v2-rev20210410-$google_clients_version",
-        google_api_services_clouddebugger           : "com.google.apis:google-api-services-clouddebugger:v2-rev20210326-$google_clients_version",
-        google_api_services_cloudresourcemanager    : "com.google.apis:google-api-services-cloudresourcemanager:v1-rev20210331-$google_clients_version",
-        google_api_services_dataflow                : "com.google.apis:google-api-services-dataflow:v1b3-rev20210408-$google_clients_version",
-        google_api_services_healthcare              : "com.google.apis:google-api-services-healthcare:v1-rev20210603-$google_clients_version",
-        google_api_services_pubsub                  : "com.google.apis:google-api-services-pubsub:v1-rev20210322-$google_clients_version",
+        google_api_services_bigquery                : "com.google.apis:google-api-services-bigquery:v2-rev20210813-$google_clients_version",
+        google_api_services_clouddebugger           : "com.google.apis:google-api-services-clouddebugger:v2-rev20210813-$google_clients_version",
+        google_api_services_cloudresourcemanager    : "com.google.apis:google-api-services-cloudresourcemanager:v1-rev20210815-$google_clients_version",
+        google_api_services_dataflow                : "com.google.apis:google-api-services-dataflow:v1b3-rev20210818-$google_clients_version",
+        google_api_services_healthcare              : "com.google.apis:google-api-services-healthcare:v1-rev20210806-$google_clients_version",
+        google_api_services_pubsub                  : "com.google.apis:google-api-services-pubsub:v1-rev20210809-$google_clients_version",
         google_api_services_storage                 : "com.google.apis:google-api-services-storage:v1-rev20210127-$google_clients_version",
         google_auth_library_credentials             : "com.google.auth:google-auth-library-credentials", // google_cloud_platform_libraries_bom sets version
         google_auth_library_oauth2_http             : "com.google.auth:google-auth-library-oauth2-http", // google_cloud_platform_libraries_bom sets version
         google_cloud_bigquery                       : "com.google.cloud:google-cloud-bigquery", // google_cloud_platform_libraries_bom sets version
-        google_cloud_bigquery_storage               : "com.google.cloud:google-cloud-bigquerystorage:1.21.1",
-        google_cloud_bigtable_client_core           : "com.google.cloud.bigtable:bigtable-client-core:1.19.1",
-        google_cloud_bigtable_emulator              : "com.google.cloud:google-cloud-bigtable-emulator:0.125.2",
+        google_cloud_bigquery_storage               : "com.google.cloud:google-cloud-bigquerystorage", // google_cloud_platform_libraries_bom sets version
+        google_cloud_bigtable_client_core           : "com.google.cloud.bigtable:bigtable-client-core:1.23.1",
+        google_cloud_bigtable_emulator              : "com.google.cloud:google-cloud-bigtable-emulator:0.137.1",
         google_cloud_core                           : "com.google.cloud:google-cloud-core", // google_cloud_platform_libraries_bom sets version
         google_cloud_core_grpc                      : "com.google.cloud:google-cloud-core-grpc", // google_cloud_platform_libraries_bom sets version
         google_cloud_datacatalog_v1beta1            : "com.google.cloud:google-cloud-datacatalog", // google_cloud_platform_libraries_bom sets version
@@ -558,7 +558,7 @@
         // The GCP Libraries BOM dashboard shows the versions set by the BOM:
         // https://storage.googleapis.com/cloud-opensource-java-dashboard/com.google.cloud/libraries-bom/20.0.0/artifact_details.html
         // Update libraries-bom version on sdks/java/container/license_scripts/dep_urls_java.yaml
-        google_cloud_platform_libraries_bom         : "com.google.cloud:libraries-bom:20.0.0",
+        google_cloud_platform_libraries_bom         : "com.google.cloud:libraries-bom:22.0.0",
         google_cloud_spanner                        : "com.google.cloud:google-cloud-spanner", // google_cloud_platform_libraries_bom sets version
         google_code_gson                            : "com.google.code.gson:gson:$google_code_gson_version",
         // google-http-client's version is explicitly declared for sdks/java/maven-archetypes/examples
diff --git a/model/pipeline/src/main/proto/metrics.proto b/model/pipeline/src/main/proto/metrics.proto
index 10166e1..8f819b6 100644
--- a/model/pipeline/src/main/proto/metrics.proto
+++ b/model/pipeline/src/main/proto/metrics.proto
@@ -417,6 +417,9 @@
     GCS_PROJECT_ID = 17 [(label_props) = { name: "GCS_PROJECT_ID"}];
     DATASTORE_PROJECT = 18 [(label_props) = { name: "DATASTORE_PROJECT" }];
     DATASTORE_NAMESPACE = 19 [(label_props) = { name: "DATASTORE_NAMESPACE" }];
+    BIGTABLE_PROJECT_ID = 20 [(label_props) = { name: "BIGTABLE_PROJECT_ID"}];
+    INSTANCE_ID = 21 [(label_props) = { name: "INSTANCE_ID"}];
+    TABLE_ID = 22 [(label_props) = { name: "TABLE_ID"}];
   }
 
   // A set of key and value labels which define the scope of the metric. For
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/Concatenate.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/Concatenate.java
new file mode 100644
index 0000000..c61152e
--- /dev/null
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/Concatenate.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.values.PCollection;
+
+/**
+ * Combiner that combines {@code T}s into a single {@code List<T>} containing all inputs.
+ *
+ * <p>For internal use by runner translations (e.g. of {@link GroupByKey}). The input {@link
+ * PCollection} must fit in memory; for a large {@link PCollection} this is expected to crash!
+ *
+ * <p>This is copied from the Dataflow runner code.
+ *
+ * @param <T> the type of elements to concatenate.
+ */
+public class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
+  @Override
+  public List<T> createAccumulator() {
+    return new ArrayList<>();
+  }
+
+  @Override
+  public List<T> addInput(List<T> accumulator, T input) {
+    accumulator.add(input);
+    return accumulator;
+  }
+
+  @Override
+  public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
+    List<T> result = createAccumulator();
+    for (List<T> accumulator : accumulators) {
+      result.addAll(accumulator);
+    }
+    return result;
+  }
+
+  @Override
+  public List<T> extractOutput(List<T> accumulator) {
+    return accumulator;
+  }
+
+  @Override
+  public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
+    return ListCoder.of(inputCoder);
+  }
+
+  @Override
+  public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
+    return ListCoder.of(inputCoder);
+  }
+}
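
For reference, a minimal sketch of what the shared Concatenate combiner does when plugged into Combine.perKey at the SDK level. The pipeline wiring and class name below are illustrative only; the sketch assumes nothing beyond the Concatenate class added above and standard Beam SDK classes.

import java.util.List;
import org.apache.beam.runners.core.Concatenate;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

public class ConcatenateSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

    // Small keyed input. Concatenate buffers every value for a key in memory,
    // so this pattern is only safe when each key's values fit on one worker.
    PCollection<KV<String, Integer>> input =
        p.apply(Create.of(KV.of("a", 1), KV.of("a", 2), KV.of("b", 3)));

    // Combine.perKey with Concatenate yields one List per key,
    // e.g. "a" -> [1, 2] and "b" -> [3] (list order is not guaranteed).
    PCollection<KV<String, List<Integer>>> grouped =
        input.apply(Combine.perKey(new Concatenate<Integer>()));

    p.run().waitUntilFinish();
  }
}

Each key ends up with a single List holding all of its values, which is exactly the in-memory buffering behavior the runner translations in this change rely on.
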
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java
index 800413c..3133cc1 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java
@@ -33,6 +33,16 @@
         projectId, datasetId, tableId);
   }
 
+  public static String bigtableTableID(String project, String instance, String table) {
+    return String.format("projects/%s/instances/%s/tables/%s", project, instance, table);
+  }
+
+  public static String bigtableResource(String projectId, String instanceId, String tableId) {
+    return String.format(
+        "//bigtable.googleapis.com/projects/%s/instances/%s/tables/%s",
+        projectId, instanceId, tableId);
+  }
+
   public static String datastoreResource(String projectId, String namespace) {
     return String.format(
         "//bigtable.googleapis.com/projects/%s/namespaces/%s", projectId, namespace);
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
index 88e55db..57ee58a 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
@@ -80,6 +80,9 @@
     public static final String BIGQUERY_QUERY_NAME = "BIGQUERY_QUERY_NAME";
     public static final String DATASTORE_PROJECT = "DATASTORE_PROJECT";
     public static final String DATASTORE_NAMESPACE = "DATASTORE_NAMESPACE";
+    public static final String BIGTABLE_PROJECT_ID = "BIGTABLE_PROJECT_ID";
+    public static final String INSTANCE_ID = "INSTANCE_ID";
+    public static final String TABLE_ID = "TABLE_ID";
 
     static {
       // Note: One benefit of defining these strings above, instead of pulling them in from
@@ -109,6 +112,10 @@
       checkArgument(DATASTORE_PROJECT.equals(extractLabel(MonitoringInfoLabels.DATASTORE_PROJECT)));
       checkArgument(
           DATASTORE_NAMESPACE.equals(extractLabel(MonitoringInfoLabels.DATASTORE_NAMESPACE)));
+      checkArgument(
+          BIGTABLE_PROJECT_ID.equals(extractLabel(MonitoringInfoLabels.BIGTABLE_PROJECT_ID)));
+      checkArgument(INSTANCE_ID.equals(extractLabel(MonitoringInfoLabels.INSTANCE_ID)));
+      checkArgument(TABLE_ID.equals(extractLabel(MonitoringInfoLabels.TABLE_ID)));
     }
   }
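
To show how the new labels pair with GcpResourceIdentifiers.bigtableResource, here is a hedged sketch of building the label map a Bigtable metric might carry. The helper class and method names are hypothetical, and the sketch assumes the pre-existing Labels.RESOURCE constant in addition to the labels added in this change.

import java.util.HashMap;
import java.util.Map;
import org.apache.beam.runners.core.metrics.GcpResourceIdentifiers;
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;

public class BigtableLabelsSketch {
  // Hypothetical helper: the label map a Bigtable IO metric might carry.
  public static Map<String, String> bigtableLabels(
      String projectId, String instanceId, String tableId) {
    Map<String, String> labels = new HashMap<>();
    labels.put(MonitoringInfoConstants.Labels.BIGTABLE_PROJECT_ID, projectId);
    labels.put(MonitoringInfoConstants.Labels.INSTANCE_ID, instanceId);
    labels.put(MonitoringInfoConstants.Labels.TABLE_ID, tableId);
    // Fully qualified resource name, e.g.
    // //bigtable.googleapis.com/projects/p/instances/i/tables/t
    labels.put(
        MonitoringInfoConstants.Labels.RESOURCE,
        GcpResourceIdentifiers.bigtableResource(projectId, instanceId, tableId));
    return labels;
  }
}
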
 
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java
index 8f34413..7920fa4 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java
@@ -18,14 +18,11 @@
 package org.apache.beam.runners.flink;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import org.apache.beam.runners.core.Concatenate;
 import org.apache.beam.runners.core.construction.CreatePCollectionViewTranslation;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderRegistry;
-import org.apache.beam.sdk.coders.ListCoder;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.Combine;
@@ -59,51 +56,6 @@
   }
 
   /**
-   * Combiner that combines {@code T}s into a single {@code List<T>} containing all inputs.
-   *
-   * <p>For internal use by {@link CreateStreamingFlinkView}. This combiner requires that the input
-   * {@link PCollection} fits in memory. For a large {@link PCollection} this is expected to crash!
-   *
-   * @param <T> the type of elements to concatenate.
-   */
-  private static class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
-    @Override
-    public List<T> createAccumulator() {
-      return new ArrayList<>();
-    }
-
-    @Override
-    public List<T> addInput(List<T> accumulator, T input) {
-      accumulator.add(input);
-      return accumulator;
-    }
-
-    @Override
-    public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
-      List<T> result = createAccumulator();
-      for (List<T> accumulator : accumulators) {
-        result.addAll(accumulator);
-      }
-      return result;
-    }
-
-    @Override
-    public List<T> extractOutput(List<T> accumulator) {
-      return accumulator;
-    }
-
-    @Override
-    public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
-      return ListCoder.of(inputCoder);
-    }
-
-    @Override
-    public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
-      return ListCoder.of(inputCoder);
-    }
-  }
-
-  /**
    * Creates a primitive {@link PCollectionView}.
    *
    * <p>For internal use only by runner implementors.
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
index d798846..d3d69d0 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
@@ -26,7 +26,6 @@
 
 import com.google.auto.service.AutoService;
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -39,6 +38,7 @@
 import java.util.stream.Collectors;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;
+import org.apache.beam.runners.core.Concatenate;
 import org.apache.beam.runners.core.construction.NativeTransforms;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.RehydratedComponents;
@@ -62,10 +62,7 @@
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.ListCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
-import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.join.RawUnionValue;
 import org.apache.beam.sdk.transforms.join.UnionCoder;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -73,7 +70,6 @@
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
 import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap;
@@ -559,53 +555,6 @@
         Iterables.getOnlyElement(transform.getTransform().getOutputsMap().values()), dataSource);
   }
 
-  /**
-   * Combiner that combines {@code T}s into a single {@code List<T>} containing all inputs.
-   *
-   * <p>For internal use to translate {@link GroupByKey}. For a large {@link PCollection} this is
-   * expected to crash!
-   *
-   * <p>This is copied from the dataflow runner code.
-   *
-   * @param <T> the type of elements to concatenate.
-   */
-  private static class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
-    @Override
-    public List<T> createAccumulator() {
-      return new ArrayList<>();
-    }
-
-    @Override
-    public List<T> addInput(List<T> accumulator, T input) {
-      accumulator.add(input);
-      return accumulator;
-    }
-
-    @Override
-    public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
-      List<T> result = createAccumulator();
-      for (List<T> accumulator : accumulators) {
-        result.addAll(accumulator);
-      }
-      return result;
-    }
-
-    @Override
-    public List<T> extractOutput(List<T> accumulator) {
-      return accumulator;
-    }
-
-    @Override
-    public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
-      return ListCoder.of(inputCoder);
-    }
-
-    @Override
-    public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
-      return ListCoder.of(inputCoder);
-    }
-  }
-
   private static void urnNotFound(
       PTransformNode transform, RunnerApi.Pipeline pipeline, BatchTranslationContext context) {
     throw new IllegalArgumentException(
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
index c2826e0..af5259e 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
@@ -21,12 +21,12 @@
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import org.apache.beam.runners.core.Concatenate;
 import org.apache.beam.runners.core.construction.CreatePCollectionViewTranslation;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
@@ -48,10 +48,8 @@
 import org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.ListCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.runners.AppliedPTransform;
@@ -59,7 +57,6 @@
 import org.apache.beam.sdk.transforms.CombineFnBase;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
-import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.join.RawUnionValue;
@@ -425,53 +422,6 @@
     }
   }
 
-  /**
-   * Combiner that combines {@code T}s into a single {@code List<T>} containing all inputs.
-   *
-   * <p>For internal use to translate {@link GroupByKey}. For a large {@link PCollection} this is
-   * expected to crash!
-   *
-   * <p>This is copied from the dataflow runner code.
-   *
-   * @param <T> the type of elements to concatenate.
-   */
-  private static class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
-    @Override
-    public List<T> createAccumulator() {
-      return new ArrayList<>();
-    }
-
-    @Override
-    public List<T> addInput(List<T> accumulator, T input) {
-      accumulator.add(input);
-      return accumulator;
-    }
-
-    @Override
-    public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
-      List<T> result = createAccumulator();
-      for (List<T> accumulator : accumulators) {
-        result.addAll(accumulator);
-      }
-      return result;
-    }
-
-    @Override
-    public List<T> extractOutput(List<T> accumulator) {
-      return accumulator;
-    }
-
-    @Override
-    public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
-      return ListCoder.of(inputCoder);
-    }
-
-    @Override
-    public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
-      return ListCoder.of(inputCoder);
-    }
-  }
-
   private static class CombinePerKeyTranslatorBatch<K, InputT, AccumT, OutputT>
       implements FlinkBatchPipelineTranslator.BatchTransformTranslator<
           PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> {
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTransformOverride.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTransformOverride.java
index 8644e73..1f8fbcc 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTransformOverride.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTransformOverride.java
@@ -17,12 +17,8 @@
  */
 package org.apache.beam.runners.samza.translation;
 
-import java.util.ArrayList;
-import java.util.List;
+import org.apache.beam.runners.core.Concatenate;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderRegistry;
-import org.apache.beam.sdk.coders.ListCoder;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -67,41 +63,4 @@
       return input;
     }
   }
-
-  private static class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
-    @Override
-    public List<T> createAccumulator() {
-      return new ArrayList<>();
-    }
-
-    @Override
-    public List<T> addInput(List<T> accumulator, T input) {
-      accumulator.add(input);
-      return accumulator;
-    }
-
-    @Override
-    public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
-      List<T> result = createAccumulator();
-      for (List<T> accumulator : accumulators) {
-        result.addAll(accumulator);
-      }
-      return result;
-    }
-
-    @Override
-    public List<T> extractOutput(List<T> accumulator) {
-      return accumulator;
-    }
-
-    @Override
-    public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
-      return ListCoder.of(inputCoder);
-    }
-
-    @Override
-    public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
-      return ListCoder.of(inputCoder);
-    }
-  }
 }
diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle
index 5b49e14..5a2e0f9 100644
--- a/runners/spark/spark_runner.gradle
+++ b/runners/spark/spark_runner.gradle
@@ -321,7 +321,7 @@
   forkEvery 1
   maxParallelForks 4
   // Increase memory heap in order to avoid OOM errors
-  jvmArgs '-Xmx4g'
+  jvmArgs '-Xmx7g'
   useJUnit {
     includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner'
     // Unbounded
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
index 6391ba4..4fe26d7 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
@@ -17,26 +17,27 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
-import java.io.Serializable;
-import org.apache.beam.runners.core.InMemoryStateInternals;
-import org.apache.beam.runners.core.StateInternals;
-import org.apache.beam.runners.core.StateInternalsFactory;
-import org.apache.beam.runners.core.SystemReduceFn;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.runners.core.Concatenate;
 import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.KeyValueGroupedDataset;
+import scala.Tuple2;
 
 class GroupByKeyTranslatorBatch<K, V>
     implements TransformTranslator<
@@ -48,43 +49,62 @@
       AbstractTranslationContext context) {
 
     @SuppressWarnings("unchecked")
-    final PCollection<KV<K, V>> inputPCollection = (PCollection<KV<K, V>>) context.getInput();
-    Dataset<WindowedValue<KV<K, V>>> input = context.getDataset(inputPCollection);
-    WindowingStrategy<?, ?> windowingStrategy = inputPCollection.getWindowingStrategy();
-    KvCoder<K, V> kvCoder = (KvCoder<K, V>) inputPCollection.getCoder();
-    Coder<V> valueCoder = kvCoder.getValueCoder();
+    final PCollection<KV<K, V>> input = (PCollection<KV<K, V>>) context.getInput();
+    @SuppressWarnings("unchecked")
+    final PCollection<KV<K, List<V>>> output = (PCollection<KV<K, List<V>>>) context.getOutput();
+    final Combine.CombineFn<V, List<V>, List<V>> combineFn = new Concatenate<>();
 
-    // group by key only
-    Coder<K> keyCoder = kvCoder.getKeyCoder();
-    KeyValueGroupedDataset<K, WindowedValue<KV<K, V>>> groupByKeyOnly =
-        input.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
+    WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
 
-    // group also by windows
-    WindowedValue.FullWindowedValueCoder<KV<K, Iterable<V>>> outputCoder =
-        WindowedValue.FullWindowedValueCoder.of(
-            KvCoder.of(keyCoder, IterableCoder.of(valueCoder)),
-            windowingStrategy.getWindowFn().windowCoder());
-    Dataset<WindowedValue<KV<K, Iterable<V>>>> output =
-        groupByKeyOnly.flatMapGroups(
-            new GroupAlsoByWindowViaOutputBufferFn<>(
-                windowingStrategy,
-                new InMemoryStateInternalsFactory<>(),
-                SystemReduceFn.buffering(valueCoder),
-                context.getSerializableOptions()),
-            EncoderHelpers.fromBeamCoder(outputCoder));
+    Dataset<WindowedValue<KV<K, V>>> inputDataset = context.getDataset(input);
 
-    context.putDataset(context.getOutput(), output);
-  }
+    KvCoder<K, V> inputCoder = (KvCoder<K, V>) input.getCoder();
+    Coder<K> keyCoder = inputCoder.getKeyCoder();
+    KvCoder<K, List<V>> outputKVCoder = (KvCoder<K, List<V>>) output.getCoder();
+    Coder<List<V>> outputCoder = outputKVCoder.getValueCoder();
 
-  /**
-   * In-memory state internals factory.
-   *
-   * @param <K> State key type.
-   */
-  static class InMemoryStateInternalsFactory<K> implements StateInternalsFactory<K>, Serializable {
-    @Override
-    public StateInternals stateInternalsForKey(K key) {
-      return InMemoryStateInternals.forKey(key);
+    KeyValueGroupedDataset<K, WindowedValue<KV<K, V>>> groupedDataset =
+      inputDataset.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
+
+    Coder<List<V>> accumulatorCoder = null;
+    try {
+      accumulatorCoder =
+        combineFn.getAccumulatorCoder(
+          input.getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
+    } catch (CannotProvideCoderException e) {
+      throw new RuntimeException(e);
     }
+
+    Dataset<Tuple2<K, Iterable<WindowedValue<List<V>>>>> combinedDataset =
+      groupedDataset.agg(
+        new AggregatorCombiner<K, V, List<V>, List<V>, BoundedWindow>(
+          combineFn, windowingStrategy, accumulatorCoder, outputCoder)
+          .toColumn());
+
+    // re-attach the key to each windowed, combined list and re-emit it as a KV element
+    WindowedValue.WindowedValueCoder<KV<K, List<V>>> wvCoder =
+      WindowedValue.FullWindowedValueCoder.of(
+        outputKVCoder, input.getWindowingStrategy().getWindowFn().windowCoder());
+    Dataset<WindowedValue<KV<K, List<V>>>> outputDataset =
+      combinedDataset.flatMap(
+        (FlatMapFunction<
+          Tuple2<K, Iterable<WindowedValue<List<V>>>>, WindowedValue<KV<K, List<V>>>>)
+          tuple2 -> {
+            K key = tuple2._1();
+            Iterable<WindowedValue<List<V>>> windowedValues = tuple2._2();
+            List<WindowedValue<KV<K, List<V>>>> result = new ArrayList<>();
+            for (WindowedValue<List<V>> windowedValue : windowedValues) {
+              KV<K, List<V>> kv = KV.of(key, windowedValue.getValue());
+              result.add(
+                WindowedValue.of(
+                  kv,
+                  windowedValue.getTimestamp(),
+                  windowedValue.getWindows(),
+                  windowedValue.getPane()));
+            }
+            return result.iterator();
+          },
+        EncoderHelpers.fromBeamCoder(wvCoder));
+    context.putDataset(output, outputDataset);
   }
 }
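
The least obvious part of the rewritten translator is the flatMap that re-attaches the key after the Concatenate-based aggregation while preserving each value's timestamp, windows, and pane. A small, self-contained sketch of just that step follows; the helper class and method names are hypothetical and not part of this change.

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;

public class ReattachKeySketch {
  // Mirrors the flatMap body above: for each windowed List produced by the
  // combiner, re-attach the key while keeping timestamp, windows and pane.
  static <K, V> List<WindowedValue<KV<K, List<V>>>> reattachKey(
      K key, Iterable<WindowedValue<List<V>>> windowedValues) {
    List<WindowedValue<KV<K, List<V>>>> result = new ArrayList<>();
    for (WindowedValue<List<V>> wv : windowedValues) {
      result.add(
          WindowedValue.of(
              KV.of(key, wv.getValue()), wv.getTimestamp(), wv.getWindows(), wv.getPane()));
    }
    return result;
  }

  public static void main(String[] args) {
    WindowedValue<List<Integer>> combined =
        WindowedValue.valueInGlobalWindow(Collections.singletonList(42));
    // Prints a single WindowedValue wrapping KV{a, [42]} in the global window.
    System.out.println(reattachKey("a", Collections.singletonList(combined)));
  }
}
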
diff --git a/sdks/go/examples/xlang/transforms.go b/sdks/go/examples/xlang/transforms.go
index acad214..3d410e8 100644
--- a/sdks/go/examples/xlang/transforms.go
+++ b/sdks/go/examples/xlang/transforms.go
@@ -20,7 +20,6 @@
 	"reflect"
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam"
-	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
 )
@@ -51,7 +50,7 @@
 	pl := beam.CrossLanguagePayload(prefixPayload{Data: prefix})
 	outT := beam.UnnamedOutput(typex.New(reflectx.String))
 	outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:prefix", pl, addr, beam.UnnamedInput(col), outT)
-	return outs[graph.UnnamedOutputTag]
+	return outs[beam.UnnamedOutputTag()]
 }
 
 func CoGroupByKey(s beam.Scope, addr string, col1, col2 beam.PCollection) beam.PCollection {
@@ -59,21 +58,21 @@
 	namedInputs := map[string]beam.PCollection{"col1": col1, "col2": col2}
 	outT := beam.UnnamedOutput(typex.NewCoGBK(typex.New(reflectx.Int64), typex.New(reflectx.String)))
 	outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:cgbk", nil, addr, namedInputs, outT)
-	return outs[graph.UnnamedOutputTag]
+	return outs[beam.UnnamedOutputTag()]
 }
 
 func CombinePerKey(s beam.Scope, addr string, col beam.PCollection) beam.PCollection {
 	s = s.Scope("XLangTest.CombinePerKey")
 	outT := beam.UnnamedOutput(typex.NewKV(typex.New(reflectx.String), typex.New(reflectx.Int64)))
 	outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:compk", nil, addr, beam.UnnamedInput(col), outT)
-	return outs[graph.UnnamedOutputTag]
+	return outs[beam.UnnamedOutputTag()]
 }
 
 func CombineGlobally(s beam.Scope, addr string, col beam.PCollection) beam.PCollection {
 	s = s.Scope("XLangTest.CombineGlobally")
 	outT := beam.UnnamedOutput(typex.New(reflectx.Int64))
 	outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:comgl", nil, addr, beam.UnnamedInput(col), outT)
-	return outs[graph.UnnamedOutputTag]
+	return outs[beam.UnnamedOutputTag()]
 }
 
 func Flatten(s beam.Scope, addr string, col1, col2 beam.PCollection) beam.PCollection {
@@ -81,14 +80,14 @@
 	namedInputs := map[string]beam.PCollection{"col1": col1, "col2": col2}
 	outT := beam.UnnamedOutput(typex.New(reflectx.Int64))
 	outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:flatten", nil, addr, namedInputs, outT)
-	return outs[graph.UnnamedOutputTag]
+	return outs[beam.UnnamedOutputTag()]
 }
 
 func GroupByKey(s beam.Scope, addr string, col beam.PCollection) beam.PCollection {
 	s = s.Scope("XLangTest.GroupByKey")
 	outT := beam.UnnamedOutput(typex.NewCoGBK(typex.New(reflectx.String), typex.New(reflectx.Int64)))
 	outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:gbk", nil, addr, beam.UnnamedInput(col), outT)
-	return outs[graph.UnnamedOutputTag]
+	return outs[beam.UnnamedOutputTag()]
 }
 
 func Multi(s beam.Scope, addr string, main1, main2, side beam.PCollection) (mainOut, sideOut beam.PCollection) {
@@ -112,5 +111,5 @@
 	s = s.Scope("XLang.Count")
 	outT := beam.UnnamedOutput(typex.NewKV(typex.New(reflectx.String), typex.New(reflectx.Int64)))
 	c := beam.CrossLanguage(s, "beam:transforms:xlang:count", nil, addr, beam.UnnamedInput(col), outT)
-	return c[graph.UnnamedOutputTag]
+	return c[beam.UnnamedOutputTag()]
 }
diff --git a/sdks/go/pkg/beam/io/xlang/kafkaio/kafka.go b/sdks/go/pkg/beam/io/xlang/kafkaio/kafka.go
index f2317f0..8781cff 100644
--- a/sdks/go/pkg/beam/io/xlang/kafkaio/kafka.go
+++ b/sdks/go/pkg/beam/io/xlang/kafkaio/kafka.go
@@ -48,7 +48,6 @@
 	"reflect"
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam"
-	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
 )
@@ -131,7 +130,7 @@
 	pl := beam.CrossLanguagePayload(rpl)
 	outT := beam.UnnamedOutput(typex.NewKV(typex.New(rcfg.key), typex.New(rcfg.val)))
 	out := beam.CrossLanguage(s, readURN, pl, addr, nil, outT)
-	return out[graph.UnnamedOutputTag]
+	return out[beam.UnnamedOutputTag()]
 }
 
 type readOption func(*readConfig)
diff --git a/sdks/go/pkg/beam/xlang.go b/sdks/go/pkg/beam/xlang.go
index 4ec0162..b9c6ec7 100644
--- a/sdks/go/pkg/beam/xlang.go
+++ b/sdks/go/pkg/beam/xlang.go
@@ -22,32 +22,44 @@
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
 )
 
-// xlang exposes an API to execute cross-language transforms within the Go SDK.
-// It is experimental and likely to change. It exposes convenient wrappers
-// around the core functions to pass in any combination of named/unnamed
-// inputs/outputs.
-
 // UnnamedInput is a helper function for passing single unnamed inputs to
-// `beam.CrossLanguage`.
+// beam.CrossLanguage.
 //
 // Example:
-//    beam.CrossLanguage(s, urn, payload, addr, UnnamedInput(input), outputs);
+//    beam.CrossLanguage(s, urn, payload, addr, UnnamedInput(input), outputs)
 func UnnamedInput(col PCollection) map[string]PCollection {
 	return map[string]PCollection{graph.UnnamedInputTag: col}
 }
 
 // UnnamedOutput is a helper function for passing single unnamed output types to
-// `beam.CrossLanguage`.
+// beam.CrossLanguage. The associated output can be accessed with beam.UnnamedOutputTag.
 //
 // Example:
-//    beam.CrossLanguage(s, urn, payload, addr, inputs, UnnamedOutput(output));
+//    resultMap := beam.CrossLanguage(s, urn, payload, addr, inputs, UnnamedOutput(output))
+//    result := resultMap[beam.UnnamedOutputTag()]
 func UnnamedOutput(t FullType) map[string]FullType {
 	return map[string]FullType{graph.UnnamedOutputTag: t}
 }
 
-// CrossLanguagePayload encodes a native Go struct into a payload for
-// cross-language transforms. To find the expected structure of a payload,
-// consult the documentation in the SDK performing the expansion.
+// UnnamedOutputTag provides the output tag used for an output passed to beam.UnnamedOutput.
+// Needed to retrieve the unnamed output PCollection from the result of beam.CrossLanguage.
+func UnnamedOutputTag() string {
+	return graph.UnnamedOutputTag
+}
+
+// CrossLanguagePayload encodes a native Go struct into a payload for cross-language transforms.
+// Payloads are []byte-encoded ExternalConfigurationPayload protobufs. To fill the contents of
+// the protobuf, the provided struct is converted to a row-encoded representation with an
+// accompanying schema, so the input struct must be compatible with schemas.
+//
+// See https://beam.apache.org/documentation/programming-guide/#schemas for basic information on
+// schemas, and pkg/beam/core/runtime/graphx/schema for details on schemas in the Go SDK.
+//
+// Example:
+//    type stringPayload struct {
+//        Data string
+//    }
+//    encodedPl := beam.CrossLanguagePayload(stringPayload{Data: "foo"})
 func CrossLanguagePayload(pl interface{}) []byte {
 	bytes, err := xlangx.EncodeStructPayload(pl)
 	if err != nil {
@@ -56,8 +68,73 @@
 	return bytes
 }
 
-// CrossLanguage executes a cross-language transform that uses named inputs and
-// returns named outputs.
+// CrossLanguage is a low-level transform for executing cross-language transforms written in other
+// SDKs. Because this is low-level, it is recommended to use one of the higher-level IO-specific
+// wrappers where available. These can be found in the pkg/beam/io/xlang subdirectory.
+// CrossLanguage is useful for executing cross-language transforms which do not have any existing
+// IO wrappers.
+//
+// Usage requires an address for an expansion service accessible during pipeline construction, a
+// URN identifying the desired transform, an optional payload with configuration information,
+// named input PCollections, and named output types. It returns a map of named output PCollections.
+//
+// For more information on expansion services and other aspects of cross-language transforms in
+// general, refer to the Beam programming guide: https://beam.apache.org/documentation/programming-guide/#multi-language-pipelines
+//
+// Payload
+//
+// Payloads are configuration data that some cross-language transforms require for expansion.
+// Consult the documentation of the transform in the source SDK to find out what payload data it
+// requires. If no payload is required, pass in nil.
+//
+// CrossLanguage accepts payloads as a []byte containing an encoded ExternalConfigurationPayload
+// protobuf. The helper function beam.CrossLanguagePayload is the recommended way to easily encode
+// a standard Go struct for use as a payload.
+//
+// Inputs and Outputs
+//
+// As with most transforms, any input PCollections must be provided. Unlike most transforms,
+// output types must also be provided, because Go cannot infer output types from external transforms.
+//
+// Inputs and outputs to a cross-language transform may be either named or unnamed. Named
+// inputs/outputs are used when there is more than one input or output, and are provided as maps
+// with names as keys. Unnamed inputs/outputs are used when there is only one, and a map can be
+// quickly constructed with the UnnamedInput and UnnamedOutput functions.
+//
+// An example of defining named inputs and outputs:
+//
+//    namedInputs := map[string]beam.PCollection{"pcol1": pcol1, "pcol2": pcol2}
+//    namedOutputTypes := map[string]typex.FullType{
+//        "main": typex.New(reflectx.String),
+//        "side": typex.New(reflectx.Int64),
+//    }
+//
+// CrossLanguage returns a map of named output PCollections. The names match those given in the
+// provided named outputs. If beam.UnnamedOutput was used, the single output PCollection can be
+// retrieved with beam.UnnamedOutputTag().
+//
+// An example of retrieving named outputs from a call to CrossLanguage:
+//
+//    outputs := beam.CrossLanguage(...)
+//    mainPcol := outputs["main"]
+//    sidePcol := outputs["side"]
+//
+// Example
+//
+// This example shows using CrossLanguage to execute the Prefix cross-language transform using an
+// expansion service running on localhost:8099. Prefix requires a payload containing a prefix to
+// prepend to every input string.
+//
+//    type prefixPayload struct {
+//        Data string
+//    }
+//    encodedPl := beam.CrossLanguagePayload(prefixPayload{Data: "foo"})
+//    urn := "beam:transforms:xlang:test:prefix"
+//    expansionAddr := "localhost:8099"
+//    outputType := beam.UnnamedOutput(typex.New(reflectx.String))
+//    input := beam.UnnamedInput(inputPcol)
+//    outs := beam.CrossLanguage(s, urn, encodedPl, expansionAddr, input, outputType)
+//    outPcol := outs[beam.UnnamedOutputTag()]
 func CrossLanguage(
 	s Scope,
 	urn string,
@@ -86,7 +163,9 @@
 	return mapNodeToPCollection(namedOutputs)
 }
 
-// TryCrossLanguage coordinates the core functions required to execute the cross-language transform
+// TryCrossLanguage coordinates the core functions required to execute the cross-language transform.
+// This is mainly intended for internal use. For the general-use entry point, see
+// beam.CrossLanguage.
 func TryCrossLanguage(s Scope, ext *graph.ExternalTransform, ins []*graph.Inbound, outs []*graph.Outbound) (map[string]*graph.Node, error) {
 	// Adding an edge in the graph corresponding to the ExternalTransform
 	edge, isBoundedUpdater := graph.NewCrossLanguage(s.real, s.scope, ext, ins, outs)
diff --git a/sdks/java/container/license_scripts/dep_urls_java.yaml b/sdks/java/container/license_scripts/dep_urls_java.yaml
index 36439e8..14dbb3e 100644
--- a/sdks/java/container/license_scripts/dep_urls_java.yaml
+++ b/sdks/java/container/license_scripts/dep_urls_java.yaml
@@ -41,7 +41,7 @@
   '1.1.6':
     type: "3-Clause BSD"
 libraries-bom:
-  '20.0.0':
+  '22.0.0':
     license: "https://raw.githubusercontent.com/GoogleCloudPlatform/cloud-opensource-java/master/LICENSE"
     type: "Apache License 2.0"
 paranamer:
@@ -53,6 +53,6 @@
   '1.5':
     license: "https://git.tukaani.org/?p=xz-java.git;a=blob;f=COPYING;h=c1d404dc7a6f06a0437bf1055fedaa4a4c89d728;hb=9f1f97a26f090ffec6568c004a38c6534aa82b94"
 jackson-bom:
-  '2.12.1':
+  '2.12.4':
     license: "https://raw.githubusercontent.com/FasterXML/jackson-bom/master/LICENSE"
     type: "Apache License 2.0"
diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProviderTest.java
similarity index 97%
rename from sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java
rename to sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProviderTest.java
index 5244108..e4e6de7 100644
--- a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java
+++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProviderTest.java
@@ -68,9 +68,9 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
-/** Tests for {@link JavaCLassLookupTransformProvider}. */
+/** Tests for {@link JavaClassLookupTransformProvider}. */
 @RunWith(JUnit4.class)
-public class JavaCLassLookupTransformProviderTest {
+public class JavaClassLookupTransformProviderTest {
 
   private static final String TEST_URN = "test:beam:transforms:count";
 
@@ -466,7 +466,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructor");
 
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -485,7 +485,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorMethod");
 
     payloadBuilder.setConstructorMethod("from");
     Row constructorRow =
@@ -505,7 +505,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods");
 
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -547,7 +547,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor");
 
     Row constructorRow =
         Row.withSchema(
@@ -571,7 +571,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod");
 
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -606,7 +606,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithWrapperTypes");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithWrapperTypes");
 
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -636,7 +636,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypes");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypes");
 
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -681,7 +681,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithArray");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithArray");
 
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -718,7 +718,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithList");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithList");
 
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -758,7 +758,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeArray");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypeArray");
 
     Schema complexTypeSchema =
         Schema.builder()
@@ -824,7 +824,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeList");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypeList");
 
     Schema complexTypeSchema =
         Schema.builder()
@@ -889,7 +889,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods");
     payloadBuilder.setConstructorMethod("from");
 
     Row constructorRow =
@@ -934,7 +934,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods");
     payloadBuilder.setConstructorMethod("from");
 
     Row constructorRow =
@@ -977,7 +977,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations");
     payloadBuilder.setConstructorMethod("create_transform");
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -1019,7 +1019,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$UnavailableClass");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$UnavailableClass");
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
             .withFieldValue("strField1", "test_str_1")
@@ -1042,7 +1042,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructor");
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("incorrectField", FieldType.STRING)))
             .withFieldValue("incorrectField", "test_str_1")
@@ -1065,7 +1065,7 @@
     ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
         ExternalTransforms.JavaClassLookupPayload.newBuilder();
     payloadBuilder.setClassName(
-        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods");
+        "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods");
     Row constructorRow =
         Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
             .withFieldValue("strField1", "test_str_1")
diff --git a/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml b/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml
index ad11523..dd76f47 100644
--- a/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml
+++ b/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml
@@ -19,19 +19,19 @@
 
 version: v1
 allowedClasses:
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructor
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorMethod
   allowedConstructorMethods:
   - from
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods
   allowedBuilderMethods:
   - withStrField2
   - withIntField1
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod
   allowedBuilderMethods:
   - withFields
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods
   allowedConstructorMethods:
   - from
   allowedBuilderMethods:
@@ -39,28 +39,28 @@
   - withIntField1
   - strField2
   - intField1
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations
   allowedConstructorMethods:
   - create_transform
   allowedBuilderMethods:
   - abc
   - xyz
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithWrapperTypes
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithWrapperTypes
   allowedBuilderMethods:
   - withDoubleWrapperField
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypes
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypes
   allowedBuilderMethods:
   - withComplexTypeField
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithArray
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithArray
   allowedBuilderMethods:
   - withStrArrayField
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithList
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithList
   allowedBuilderMethods:
   - withStrListField
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeArray
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypeArray
   allowedBuilderMethods:
   - withComplexTypeArrayField
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeList
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypeList
   allowedBuilderMethods:
   - withComplexTypeListField
 
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
index e94dc2c..bd919dd 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
@@ -21,13 +21,9 @@
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 
 import com.datastax.driver.core.Cluster;
-import com.datastax.driver.core.ColumnMetadata;
 import com.datastax.driver.core.ConsistencyLevel;
 import com.datastax.driver.core.PlainTextAuthProvider;
 import com.datastax.driver.core.QueryOptions;
-import com.datastax.driver.core.ResultSet;
-import com.datastax.driver.core.ResultSetFuture;
-import com.datastax.driver.core.Row;
 import com.datastax.driver.core.Session;
 import com.datastax.driver.core.SocketOptions;
 import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy;
@@ -36,32 +32,32 @@
 import java.math.BigInteger;
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.Iterator;
 import java.util.List;
-import java.util.NoSuchElementException;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.function.BiFunction;
 import java.util.stream.Collectors;
+import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -89,6 +85,11 @@
  *
  * }</pre>
  *
+ * <p>Alternatively, one may use {@code CassandraIO.<Person>readAll()
+ * .withCoder(SerializableCoder.of(Person.class))} to query a subset of the Cassandra database by
+ * creating a PCollection of {@code CassandraIO.Read<Person>}, each with its own query or
+ * RingRange.
+ *
  * <h3>Writing to Apache Cassandra</h3>
  *
  * <p>{@code CassandraIO} provides a sink to write a collection of entities to Apache Cassandra.
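The new javadoc above introduces readAll(). As a rough usage sketch (not part of this change), the fragment below builds two Read elements that differ only in their query and feeds them to readAll(); Person is a hypothetical mapped entity (compare Scientist in CassandraIOTest), and the host, keyspace and table names are placeholders.

// Sketch only: Person, "localhost", "my_ks" and "person" are placeholders, not part of this change.
CassandraIO.Read<Person> base =
    CassandraIO.<Person>read()
        .withHosts(Collections.singletonList("localhost"))
        .withPort(9042)
        .withKeyspace("my_ks")
        .withTable("person")
        .withEntity(Person.class)
        .withCoder(SerializableCoder.of(Person.class));

// Each Read element carries its own query; readAll() executes them in parallel.
PCollection<Person> people =
    pipeline
        .apply(
            Create.of(
                base.withQuery("SELECT * FROM my_ks.person WHERE person_department='phys'"),
                base.withQuery("SELECT * FROM my_ks.person WHERE person_department='math'")))
        .apply(CassandraIO.<Person>readAll().withCoder(SerializableCoder.of(Person.class)));

This mirrors the testReadAllQuery case added further down in CassandraIOTest.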
@@ -137,6 +138,11 @@
     return new AutoValue_CassandraIO_Read.Builder<T>().build();
   }
 
+  /** Provide a {@link ReadAll} {@link PTransform} to read data from a Cassandra database. */
+  public static <T> ReadAll<T> readAll() {
+    return new AutoValue_CassandraIO_ReadAll.Builder<T>().build();
+  }
+
   /** Provide a {@link Write} {@link PTransform} to write data to a Cassandra database. */
   public static <T> Write<T> write() {
     return Write.<T>builder(MutationType.WRITE).build();
@@ -186,6 +192,9 @@
 
     abstract @Nullable SerializableFunction<Session, Mapper> mapperFactoryFn();
 
+    @Nullable
+    abstract ValueProvider<Set<RingRange>> ringRanges();
+
     abstract Builder<T> builder();
 
     /** Specify the hosts of the Apache Cassandra instances. */
@@ -371,6 +380,14 @@
       return builder().setMapperFactoryFn(mapperFactory).build();
     }
 
+    /** Specify the {@link RingRange}s (token ranges) that this read should be restricted to. */
+    public Read<T> withRingRanges(Set<RingRange> ringRange) {
+      return withRingRanges(ValueProvider.StaticValueProvider.of(ringRange));
+    }
+
+    public Read<T> withRingRanges(ValueProvider<Set<RingRange>> ringRange) {
+      return builder().setRingRanges(ringRange).build();
+    }
+
     @Override
     public PCollection<T> expand(PBegin input) {
       checkArgument((hosts() != null && port() != null), "WithHosts() and withPort() are required");
@@ -379,7 +396,69 @@
       checkArgument(entity() != null, "withEntity() is required");
       checkArgument(coder() != null, "withCoder() is required");
 
-      return input.apply(org.apache.beam.sdk.io.Read.from(new CassandraSource<>(this, null)));
+      PCollection<Read<T>> splits =
+          input
+              .apply(Create.of(this))
+              .apply("Create Splits", ParDo.of(new SplitFn<T>()))
+              .setCoder(SerializableCoder.of(new TypeDescriptor<Read<T>>() {}));
+
+      return splits.apply("ReadAll", CassandraIO.<T>readAll().withCoder(coder()));
+    }
+
+    private static class SplitFn<T> extends DoFn<Read<T>, Read<T>> {
+      @ProcessElement
+      public void process(
+          @Element CassandraIO.Read<T> read, OutputReceiver<Read<T>> outputReceiver) {
+        Set<RingRange> ringRanges = getRingRanges(read);
+        for (RingRange rr : ringRanges) {
+          outputReceiver.output(read.withRingRanges(ImmutableSet.of(rr)));
+        }
+      }
+
+      private static <T> Set<RingRange> getRingRanges(Read<T> read) {
+        try (Cluster cluster =
+            getCluster(
+                read.hosts(),
+                read.port(),
+                read.username(),
+                read.password(),
+                read.localDc(),
+                read.consistencyLevel(),
+                read.connectTimeout(),
+                read.readTimeout())) {
+          if (isMurmur3Partitioner(cluster)) {
+            LOG.info("Murmur3Partitioner detected, splitting");
+            Integer splitCount;
+            if (read.minNumberOfSplits() != null && read.minNumberOfSplits().get() != null) {
+              splitCount = read.minNumberOfSplits().get();
+            } else {
+              splitCount = cluster.getMetadata().getAllHosts().size();
+            }
+            List<BigInteger> tokens =
+                cluster.getMetadata().getTokenRanges().stream()
+                    .map(tokenRange -> new BigInteger(tokenRange.getEnd().getValue().toString()))
+                    .collect(Collectors.toList());
+            SplitGenerator splitGenerator =
+                new SplitGenerator(cluster.getMetadata().getPartitioner());
+
+            return splitGenerator.generateSplits(splitCount, tokens).stream()
+                .flatMap(List::stream)
+                .collect(Collectors.toSet());
+
+          } else {
+            LOG.warn(
+                "Only Murmur3Partitioner is supported for splitting, using an unique source for "
+                    + "the read");
+            String partitioner = cluster.getMetadata().getPartitioner();
+            RingRange totalRingRange =
+                RingRange.of(
+                    SplitGenerator.getRangeMin(partitioner),
+                    SplitGenerator.getRangeMax(partitioner));
+            return Collections.singleton(totalRingRange);
+          }
+        }
+      }
     }
 
     @AutoValue.Builder
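For reference, a sketch of the new expansion path from the user's side: read() looks the same as before, but expand() now generates RingRanges in SplitFn and delegates to readAll(), so withMinNumberOfSplits() effectively controls how many ranges are produced (otherwise one per host). The entity and connection settings below are placeholders.

// Sketch with the same placeholder Person entity and cluster settings as above:
// asking for at least 50 ring ranges so the read parallelises even on a single-node cluster.
PCollection<Person> people =
    pipeline.apply(
        CassandraIO.<Person>read()
            .withHosts(Collections.singletonList("localhost"))
            .withPort(9042)
            .withKeyspace("my_ks")
            .withTable("person")
            .withMinNumberOfSplits(50)
            .withEntity(Person.class)
            .withCoder(SerializableCoder.of(Person.class)));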
@@ -418,6 +497,8 @@
 
       abstract Optional<SerializableFunction<Session, Mapper>> mapperFactoryFn();
 
+      abstract Builder<T> setRingRanges(ValueProvider<Set<RingRange>> ringRange);
+
       abstract Read<T> autoBuild();
 
       public Read<T> build() {
@@ -429,390 +510,6 @@
     }
   }
 
-  @VisibleForTesting
-  static class CassandraSource<T> extends BoundedSource<T> {
-    final Read<T> spec;
-    final List<String> splitQueries;
-    // split source ached size - can't be calculated when already split
-    Long estimatedSize;
-    private static final String MURMUR3PARTITIONER = "org.apache.cassandra.dht.Murmur3Partitioner";
-
-    CassandraSource(Read<T> spec, List<String> splitQueries) {
-      this(spec, splitQueries, null);
-    }
-
-    private CassandraSource(Read<T> spec, List<String> splitQueries, Long estimatedSize) {
-      this.estimatedSize = estimatedSize;
-      this.spec = spec;
-      this.splitQueries = splitQueries;
-    }
-
-    @Override
-    public Coder<T> getOutputCoder() {
-      return spec.coder();
-    }
-
-    @Override
-    public BoundedReader<T> createReader(PipelineOptions pipelineOptions) {
-      return new CassandraReader(this);
-    }
-
-    @Override
-    public List<BoundedSource<T>> split(
-        long desiredBundleSizeBytes, PipelineOptions pipelineOptions) {
-      try (Cluster cluster =
-          getCluster(
-              spec.hosts(),
-              spec.port(),
-              spec.username(),
-              spec.password(),
-              spec.localDc(),
-              spec.consistencyLevel(),
-              spec.connectTimeout(),
-              spec.readTimeout())) {
-        if (isMurmur3Partitioner(cluster)) {
-          LOG.info("Murmur3Partitioner detected, splitting");
-          return splitWithTokenRanges(
-              spec, desiredBundleSizeBytes, getEstimatedSizeBytes(pipelineOptions), cluster);
-        } else {
-          LOG.warn(
-              "Only Murmur3Partitioner is supported for splitting, using a unique source for "
-                  + "the read");
-          return Collections.singletonList(
-              new CassandraIO.CassandraSource<>(spec, Collections.singletonList(buildQuery(spec))));
-        }
-      }
-    }
-
-    private static String buildQuery(Read spec) {
-      return (spec.query() == null)
-          ? String.format("SELECT * FROM %s.%s", spec.keyspace().get(), spec.table().get())
-          : spec.query().get().toString();
-    }
-
-    /**
-     * Compute the number of splits based on the estimated size and the desired bundle size, and
-     * create several sources.
-     */
-    private List<BoundedSource<T>> splitWithTokenRanges(
-        CassandraIO.Read<T> spec,
-        long desiredBundleSizeBytes,
-        long estimatedSizeBytes,
-        Cluster cluster) {
-      long numSplits =
-          getNumSplits(desiredBundleSizeBytes, estimatedSizeBytes, spec.minNumberOfSplits());
-      LOG.info("Number of desired splits is {}", numSplits);
-
-      SplitGenerator splitGenerator = new SplitGenerator(cluster.getMetadata().getPartitioner());
-      List<BigInteger> tokens =
-          cluster.getMetadata().getTokenRanges().stream()
-              .map(tokenRange -> new BigInteger(tokenRange.getEnd().getValue().toString()))
-              .collect(Collectors.toList());
-      List<List<RingRange>> splits = splitGenerator.generateSplits(numSplits, tokens);
-      LOG.info("{} splits were actually generated", splits.size());
-
-      final String partitionKey =
-          cluster.getMetadata().getKeyspace(spec.keyspace().get()).getTable(spec.table().get())
-              .getPartitionKey().stream()
-              .map(ColumnMetadata::getName)
-              .collect(Collectors.joining(","));
-
-      List<TokenRange> tokenRanges =
-          getTokenRanges(cluster, spec.keyspace().get(), spec.table().get());
-      final long estimatedSize = getEstimatedSizeBytesFromTokenRanges(tokenRanges) / splits.size();
-
-      List<BoundedSource<T>> sources = new ArrayList<>();
-      for (List<RingRange> split : splits) {
-        List<String> queries = new ArrayList<>();
-        for (RingRange range : split) {
-          if (range.isWrapping()) {
-            // A wrapping range is one that overlaps from the end of the partitioner range and its
-            // start (ie : when the start token of the split is greater than the end token)
-            // We need to generate two queries here : one that goes from the start token to the end
-            // of
-            // the partitioner range, and the other from the start of the partitioner range to the
-            // end token of the split.
-            queries.add(generateRangeQuery(spec, partitionKey, range.getStart(), null));
-            // Generation of the second query of the wrapping range
-            queries.add(generateRangeQuery(spec, partitionKey, null, range.getEnd()));
-          } else {
-            queries.add(generateRangeQuery(spec, partitionKey, range.getStart(), range.getEnd()));
-          }
-        }
-        sources.add(new CassandraIO.CassandraSource<>(spec, queries, estimatedSize));
-      }
-      return sources;
-    }
-
-    private static String generateRangeQuery(
-        Read spec, String partitionKey, BigInteger rangeStart, BigInteger rangeEnd) {
-      final String rangeFilter =
-          Joiner.on(" AND ")
-              .skipNulls()
-              .join(
-                  rangeStart == null
-                      ? null
-                      : String.format("(token(%s) >= %d)", partitionKey, rangeStart),
-                  rangeEnd == null
-                      ? null
-                      : String.format("(token(%s) < %d)", partitionKey, rangeEnd));
-      final String query =
-          (spec.query() == null)
-              ? buildQuery(spec) + " WHERE " + rangeFilter
-              : buildQuery(spec) + " AND " + rangeFilter;
-      LOG.debug("CassandraIO generated query : {}", query);
-      return query;
-    }
-
-    private static long getNumSplits(
-        long desiredBundleSizeBytes,
-        long estimatedSizeBytes,
-        @Nullable ValueProvider<Integer> minNumberOfSplits) {
-      long numSplits =
-          desiredBundleSizeBytes > 0 ? (estimatedSizeBytes / desiredBundleSizeBytes) : 1;
-      if (numSplits <= 0) {
-        LOG.warn("Number of splits is less than 0 ({}), fallback to 1", numSplits);
-        numSplits = 1;
-      }
-      return minNumberOfSplits != null ? Math.max(numSplits, minNumberOfSplits.get()) : numSplits;
-    }
-
-    /**
-     * Returns cached estimate for split or if missing calculate size for whole table. Highly
-     * innacurate if query is specified.
-     *
-     * @param pipelineOptions
-     * @return
-     */
-    @Override
-    public long getEstimatedSizeBytes(PipelineOptions pipelineOptions) {
-      if (estimatedSize != null) {
-        return estimatedSize;
-      } else {
-        try (Cluster cluster =
-            getCluster(
-                spec.hosts(),
-                spec.port(),
-                spec.username(),
-                spec.password(),
-                spec.localDc(),
-                spec.consistencyLevel(),
-                spec.connectTimeout(),
-                spec.readTimeout())) {
-          if (isMurmur3Partitioner(cluster)) {
-            try {
-              List<TokenRange> tokenRanges =
-                  getTokenRanges(cluster, spec.keyspace().get(), spec.table().get());
-              this.estimatedSize = getEstimatedSizeBytesFromTokenRanges(tokenRanges);
-              return this.estimatedSize;
-            } catch (Exception e) {
-              LOG.warn("Can't estimate the size", e);
-              return 0L;
-            }
-          } else {
-            LOG.warn("Only Murmur3 partitioner is supported, can't estimate the size");
-            return 0L;
-          }
-        }
-      }
-    }
-
-    @VisibleForTesting
-    static long getEstimatedSizeBytesFromTokenRanges(List<TokenRange> tokenRanges) {
-      long size = 0L;
-      for (TokenRange tokenRange : tokenRanges) {
-        size = size + tokenRange.meanPartitionSize * tokenRange.partitionCount;
-      }
-      return Math.round(size / getRingFraction(tokenRanges));
-    }
-
-    @Override
-    public void populateDisplayData(DisplayData.Builder builder) {
-      super.populateDisplayData(builder);
-      if (spec.hosts() != null) {
-        builder.add(DisplayData.item("hosts", spec.hosts().toString()));
-      }
-      if (spec.port() != null) {
-        builder.add(DisplayData.item("port", spec.port()));
-      }
-      builder.addIfNotNull(DisplayData.item("keyspace", spec.keyspace()));
-      builder.addIfNotNull(DisplayData.item("table", spec.table()));
-      builder.addIfNotNull(DisplayData.item("username", spec.username()));
-      builder.addIfNotNull(DisplayData.item("localDc", spec.localDc()));
-      builder.addIfNotNull(DisplayData.item("consistencyLevel", spec.consistencyLevel()));
-    }
-    // ------------- CASSANDRA SOURCE UTIL METHODS ---------------//
-
-    /**
-     * Gets the list of token ranges that a table occupies on a give Cassandra node.
-     *
-     * <p>NB: This method is compatible with Cassandra 2.1.5 and greater.
-     */
-    private static List<TokenRange> getTokenRanges(Cluster cluster, String keyspace, String table) {
-      try (Session session = cluster.newSession()) {
-        ResultSet resultSet =
-            session.execute(
-                "SELECT range_start, range_end, partitions_count, mean_partition_size FROM "
-                    + "system.size_estimates WHERE keyspace_name = ? AND table_name = ?",
-                keyspace,
-                table);
-
-        ArrayList<TokenRange> tokenRanges = new ArrayList<>();
-        for (Row row : resultSet) {
-          TokenRange tokenRange =
-              new TokenRange(
-                  row.getLong("partitions_count"),
-                  row.getLong("mean_partition_size"),
-                  new BigInteger(row.getString("range_start")),
-                  new BigInteger(row.getString("range_end")));
-          tokenRanges.add(tokenRange);
-        }
-        // The table may not contain the estimates yet
-        // or have partitions_count and mean_partition_size fields = 0
-        // if the data was just inserted and the amount of data in the table was small.
-        // This is very common situation during tests,
-        // when we insert a few rows and immediately query them.
-        // However, for tiny data sets the lack of size estimates is not a problem at all,
-        // because we don't want to split tiny data anyways.
-        // Therefore, we're not issuing a warning if the result set was empty
-        // or mean_partition_size and partitions_count = 0.
-        return tokenRanges;
-      }
-    }
-
-    /** Compute the percentage of token addressed compared with the whole tokens in the cluster. */
-    @VisibleForTesting
-    static double getRingFraction(List<TokenRange> tokenRanges) {
-      double ringFraction = 0;
-      for (TokenRange tokenRange : tokenRanges) {
-        ringFraction =
-            ringFraction
-                + (distance(tokenRange.rangeStart, tokenRange.rangeEnd).doubleValue()
-                    / SplitGenerator.getRangeSize(MURMUR3PARTITIONER).doubleValue());
-      }
-      return ringFraction;
-    }
-
-    /**
-     * Check if the current partitioner is the Murmur3 (default in Cassandra version newer than 2).
-     */
-    @VisibleForTesting
-    static boolean isMurmur3Partitioner(Cluster cluster) {
-      return MURMUR3PARTITIONER.equals(cluster.getMetadata().getPartitioner());
-    }
-
-    /** Measure distance between two tokens. */
-    @VisibleForTesting
-    static BigInteger distance(BigInteger left, BigInteger right) {
-      return (right.compareTo(left) > 0)
-          ? right.subtract(left)
-          : right.subtract(left).add(SplitGenerator.getRangeSize(MURMUR3PARTITIONER));
-    }
-
-    /**
-     * Represent a token range in Cassandra instance, wrapping the partition count, size and token
-     * range.
-     */
-    @VisibleForTesting
-    static class TokenRange {
-      private final long partitionCount;
-      private final long meanPartitionSize;
-      private final BigInteger rangeStart;
-      private final BigInteger rangeEnd;
-
-      TokenRange(
-          long partitionCount, long meanPartitionSize, BigInteger rangeStart, BigInteger rangeEnd) {
-        this.partitionCount = partitionCount;
-        this.meanPartitionSize = meanPartitionSize;
-        this.rangeStart = rangeStart;
-        this.rangeEnd = rangeEnd;
-      }
-    }
-
-    private class CassandraReader extends BoundedSource.BoundedReader<T> {
-      private final CassandraIO.CassandraSource<T> source;
-      private Cluster cluster;
-      private Session session;
-      private Iterator<T> iterator;
-      private T current;
-
-      CassandraReader(CassandraSource<T> source) {
-        this.source = source;
-      }
-
-      @Override
-      public boolean start() {
-        LOG.debug("Starting Cassandra reader");
-        cluster =
-            getCluster(
-                source.spec.hosts(),
-                source.spec.port(),
-                source.spec.username(),
-                source.spec.password(),
-                source.spec.localDc(),
-                source.spec.consistencyLevel(),
-                source.spec.connectTimeout(),
-                source.spec.readTimeout());
-        session = cluster.connect(source.spec.keyspace().get());
-        LOG.debug("Queries: " + source.splitQueries);
-        List<ResultSetFuture> futures = new ArrayList<>();
-        for (String query : source.splitQueries) {
-          futures.add(session.executeAsync(query));
-        }
-
-        final Mapper<T> mapper = getMapper(session, source.spec.entity());
-
-        for (ResultSetFuture result : futures) {
-          if (iterator == null) {
-            iterator = mapper.map(result.getUninterruptibly());
-          } else {
-            iterator = Iterators.concat(iterator, mapper.map(result.getUninterruptibly()));
-          }
-        }
-
-        return advance();
-      }
-
-      @Override
-      public boolean advance() {
-        if (iterator.hasNext()) {
-          current = iterator.next();
-          return true;
-        }
-        current = null;
-        return false;
-      }
-
-      @Override
-      public void close() {
-        LOG.debug("Closing Cassandra reader");
-        if (session != null) {
-          session.close();
-        }
-        if (cluster != null) {
-          cluster.close();
-        }
-      }
-
-      @Override
-      public T getCurrent() throws NoSuchElementException {
-        if (current == null) {
-          throw new NoSuchElementException();
-        }
-        return current;
-      }
-
-      @Override
-      public CassandraIO.CassandraSource<T> getCurrentSource() {
-        return source;
-      }
-
-      private Mapper<T> getMapper(Session session, Class<T> enitity) {
-        return source.spec.mapperFactoryFn().apply(session);
-      }
-    }
-  }
-
   /** Specify the mutation type: either write or delete. */
   public enum MutationType {
     WRITE,
@@ -1179,7 +876,7 @@
   }
 
   /** Get a Cassandra cluster using hosts and port. */
-  private static Cluster getCluster(
+  static Cluster getCluster(
       ValueProvider<List<String>> hosts,
       ValueProvider<Integer> port,
       ValueProvider<String> username,
@@ -1301,4 +998,53 @@
       }
     }
   }
+
+  /**
+   * A {@link PTransform} to read data from Apache Cassandra. See {@link CassandraIO} for more
+   * information on usage and configuration.
+   */
+  @AutoValue
+  public abstract static class ReadAll<T> extends PTransform<PCollection<Read<T>>, PCollection<T>> {
+    @AutoValue.Builder
+    abstract static class Builder<T> {
+
+      abstract Builder<T> setCoder(Coder<T> coder);
+
+      abstract ReadAll<T> autoBuild();
+
+      public ReadAll<T> build() {
+        return autoBuild();
+      }
+    }
+
+    @Nullable
+    abstract Coder<T> coder();
+
+    abstract Builder<T> builder();
+
+    /** Specify the {@link Coder} used to serialize the entity in the {@link PCollection}. */
+    public ReadAll<T> withCoder(Coder<T> coder) {
+      checkArgument(coder != null, "coder can not be null");
+      return builder().setCoder(coder).build();
+    }
+
+    @Override
+    public PCollection<T> expand(PCollection<Read<T>> input) {
+      checkArgument(coder() != null, "withCoder() is required");
+      return input
+          .apply("Reshuffle", Reshuffle.viaRandomKey())
+          .apply("Read", ParDo.of(new ReadFn<>()))
+          .setCoder(this.coder());
+    }
+  }
+
+  /**
+   * Check if the current partitioner is Murmur3 (the default in Cassandra versions newer than 2).
+   */
+  @VisibleForTesting
+  static boolean isMurmur3Partitioner(Cluster cluster) {
+    return MURMUR3PARTITIONER.equals(cluster.getMetadata().getPartitioner());
+  }
+
+  private static final String MURMUR3PARTITIONER = "org.apache.cassandra.dht.Murmur3Partitioner";
 }
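A sketch of driving the new withRingRanges()/readAll() combination directly. The two token bounds are arbitrary placeholders that roughly halve the Murmur3 ring; real ranges would normally come from SplitGenerator or from token values of the partition key, and Person is the same placeholder entity as above.

// Sketch only: placeholder entity, connection settings and ring-range bounds.
Set<RingRange> ranges =
    new HashSet<>(
        Arrays.asList(
            RingRange.of(new BigInteger("-9223372036854775808"), BigInteger.ZERO),
            RingRange.of(BigInteger.ZERO, new BigInteger("9223372036854775807"))));

PCollection<Person> people =
    pipeline
        .apply(
            Create.of(
                CassandraIO.<Person>read()
                    .withHosts(Collections.singletonList("localhost"))
                    .withPort(9042)
                    .withKeyspace("my_ks")
                    .withTable("person")
                    .withRingRanges(ranges)
                    .withEntity(Person.class)
                    .withCoder(SerializableCoder.of(Person.class))))
        .apply(CassandraIO.<Person>readAll().withCoder(SerializableCoder.of(Person.class)));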
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java
new file mode 100644
index 0000000..5091ac4
--- /dev/null
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.cassandra;
+
+import com.datastax.driver.core.Cluster;
+import com.datastax.driver.core.Session;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.beam.sdk.io.cassandra.CassandraIO.Read;
+import org.apache.beam.sdk.options.ValueProvider;
+
+@SuppressWarnings({
+  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+public class ConnectionManager {
+
+  private static final ConcurrentHashMap<String, Cluster> clusterMap =
+      new ConcurrentHashMap<String, Cluster>();
+  private static final ConcurrentHashMap<String, Session> sessionMap =
+      new ConcurrentHashMap<String, Session>();
+
+  static {
+    Runtime.getRuntime()
+        .addShutdownHook(
+            new Thread(
+                () -> {
+                  for (Session session : sessionMap.values()) {
+                    if (!session.isClosed()) {
+                      session.close();
+                    }
+                  }
+                }));
+  }
+
+  private static String readToClusterHash(Read<?> read) {
+    return Objects.requireNonNull(read.hosts()).get().stream().reduce(",", (a, b) -> a + b)
+        + Objects.requireNonNull(read.port()).get()
+        + safeVPGet(read.localDc())
+        + safeVPGet(read.consistencyLevel());
+  }
+
+  private static String readToSessionHash(Read<?> read) {
+    return readToClusterHash(read) + read.keyspace().get();
+  }
+
+  static Session getSession(Read<?> read) {
+    Cluster cluster =
+        clusterMap.computeIfAbsent(
+            readToClusterHash(read),
+            k ->
+                CassandraIO.getCluster(
+                    Objects.requireNonNull(read.hosts()),
+                    Objects.requireNonNull(read.port()),
+                    read.username(),
+                    read.password(),
+                    read.localDc(),
+                    read.consistencyLevel(),
+                    read.connectTimeout(),
+                    read.readTimeout()));
+    return sessionMap.computeIfAbsent(
+        readToSessionHash(read),
+        k -> cluster.connect(Objects.requireNonNull(read.keyspace()).get()));
+  }
+
+  private static String safeVPGet(ValueProvider<String> s) {
+    return s != null ? s.get() : "";
+  }
+}
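ConnectionManager keys its caches on the connection-relevant parts of a Read (hosts, port, localDc and consistencyLevel, plus keyspace for sessions). A minimal sketch of the intended reuse, assuming code in the same package (getSession is package-private), a reachable cluster, and two Reads that share those settings:

// Sketch only: read1 and read2 are assumed to share hosts, port, localDc,
// consistencyLevel and keyspace.
Session s1 = ConnectionManager.getSession(read1);
Session s2 = ConnectionManager.getSession(read2);
// Identical settings hash to the same cache key, so the Cluster and Session are reused.
assert s1 == s2;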
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java
index 8f6d578..92ec2c5 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java
@@ -34,7 +34,7 @@
 })
 class DefaultObjectMapper<T> implements Mapper<T>, Serializable {
 
-  private transient com.datastax.driver.mapping.Mapper<T> mapper;
+  private final transient com.datastax.driver.mapping.Mapper<T> mapper;
 
   DefaultObjectMapper(com.datastax.driver.mapping.Mapper mapper) {
     this.mapper = mapper;
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java
index 7976665..ef75ff3 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java
@@ -34,7 +34,7 @@
 class DefaultObjectMapperFactory<T> implements SerializableFunction<Session, Mapper> {
 
   private transient MappingManager mappingManager;
-  Class<T> entity;
+  final Class<T> entity;
 
   DefaultObjectMapperFactory(Class<T> entity) {
     this.entity = entity;
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java
new file mode 100644
index 0000000..193cdf0
--- /dev/null
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.cassandra;
+
+import com.datastax.driver.core.Cluster;
+import com.datastax.driver.core.ColumnMetadata;
+import com.datastax.driver.core.PreparedStatement;
+import com.datastax.driver.core.ResultSet;
+import com.datastax.driver.core.Session;
+import com.datastax.driver.core.Token;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.io.cassandra.CassandraIO.Read;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+@SuppressWarnings({
+  "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
+  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+class ReadFn<T> extends DoFn<Read<T>, T> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(ReadFn.class);
+
+  @ProcessElement
+  public void processElement(@Element Read<T> read, OutputReceiver<T> receiver) {
+    try {
+      Session session = ConnectionManager.getSession(read);
+      Mapper<T> mapper = read.mapperFactoryFn().apply(session);
+      String partitionKey =
+          session.getCluster().getMetadata().getKeyspace(read.keyspace().get())
+              .getTable(read.table().get()).getPartitionKey().stream()
+              .map(ColumnMetadata::getName)
+              .collect(Collectors.joining(","));
+
+      String query = generateRangeQuery(read, partitionKey, read.ringRanges() != null);
+      PreparedStatement preparedStatement = session.prepare(query);
+      Set<RingRange> ringRanges =
+          read.ringRanges() == null ? Collections.emptySet() : read.ringRanges().get();
+
+      for (RingRange rr : ringRanges) {
+        Token startToken = session.getCluster().getMetadata().newToken(rr.getStart().toString());
+        Token endToken = session.getCluster().getMetadata().newToken(rr.getEnd().toString());
+        ResultSet rs =
+            session.execute(preparedStatement.bind().setToken(0, startToken).setToken(1, endToken));
+        Iterator<T> iter = mapper.map(rs);
+        while (iter.hasNext()) {
+          T n = iter.next();
+          receiver.output(n);
+        }
+      }
+
+      if (read.ringRanges() == null) {
+        ResultSet rs = session.execute(preparedStatement.bind());
+        Iterator<T> iter = mapper.map(rs);
+        while (iter.hasNext()) {
+          receiver.output(iter.next());
+        }
+      }
+    } catch (Exception ex) {
+      LOG.error("error", ex);
+    }
+  }
+
+  private Session getSession(Read<T> read) {
+    Cluster cluster =
+        CassandraIO.getCluster(
+            read.hosts(),
+            read.port(),
+            read.username(),
+            read.password(),
+            read.localDc(),
+            read.consistencyLevel(),
+            read.connectTimeout(),
+            read.readTimeout());
+
+    return cluster.connect(read.keyspace().get());
+  }
+
+  private static String generateRangeQuery(
+      Read<?> spec, String partitionKey, Boolean hasRingRange) {
+    final String rangeFilter =
+        (hasRingRange)
+            ? Joiner.on(" AND ")
+                .skipNulls()
+                .join(
+                    String.format("(token(%s) >= ?)", partitionKey),
+                    String.format("(token(%s) < ?)", partitionKey))
+            : "";
+    final String combinedQuery = buildInitialQuery(spec, hasRingRange) + rangeFilter;
+    LOG.debug("CassandraIO generated query : {}", combinedQuery);
+    return combinedQuery;
+  }
+
+  private static String buildInitialQuery(Read<?> spec, Boolean hasRingRange) {
+    return (spec.query() == null)
+        ? String.format("SELECT * FROM %s.%s", spec.keyspace().get(), spec.table().get())
+            + " WHERE "
+        : spec.query().get() + (hasRingRange ? " AND " : "");
+  }
+}
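For clarity, the statements ReadFn prepares, illustrated with the schema used by CassandraIOTest (keyspace beam_ks, table scientist, partition key person_department). These strings are derived from generateRangeQuery/buildInitialQuery above and are illustrative only:

// Full-table read restricted by ring ranges (no user query):
String tokenRangeScan =
    "SELECT * FROM beam_ks.scientist"
        + " WHERE (token(person_department) >= ?) AND (token(person_department) < ?)";
// User-supplied query combined with ring ranges (the token filter is appended with AND):
String userQueryWithRanges =
    "SELECT * FROM beam_ks.scientist WHERE person_department='phys'"
        + " AND (token(person_department) >= ?) AND (token(person_department) < ?)";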
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
index b5f94d7..c83e47f 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
@@ -17,23 +17,31 @@
  */
 package org.apache.beam.sdk.io.cassandra;
 
+import java.io.Serializable;
 import java.math.BigInteger;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.annotations.Experimental.Kind;
 
 /** Models a Cassandra token range. */
-final class RingRange {
+@Experimental(Kind.SOURCE_SINK)
+@SuppressWarnings({
+  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+public final class RingRange implements Serializable {
   private final BigInteger start;
   private final BigInteger end;
 
-  RingRange(BigInteger start, BigInteger end) {
+  private RingRange(BigInteger start, BigInteger end) {
     this.start = start;
     this.end = end;
   }
 
-  BigInteger getStart() {
+  public BigInteger getStart() {
     return start;
   }
 
-  BigInteger getEnd() {
+  public BigInteger getEnd() {
     return end;
   }
 
@@ -55,4 +63,34 @@
   public String toString() {
     return String.format("(%s,%s]", start.toString(), end.toString());
   }
+
+  public static RingRange of(BigInteger start, BigInteger end) {
+    return new RingRange(start, end);
+  }
+
+  @Override
+  public boolean equals(@Nullable Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+
+    RingRange ringRange = (RingRange) o;
+
+    if (getStart() != null
+        ? !getStart().equals(ringRange.getStart())
+        : ringRange.getStart() != null) {
+      return false;
+    }
+    return getEnd() != null ? getEnd().equals(ringRange.getEnd()) : ringRange.getEnd() == null;
+  }
+
+  @Override
+  public int hashCode() {
+    int result = getStart() != null ? getStart().hashCode() : 0;
+    result = 31 * result + (getEnd() != null ? getEnd().hashCode() : 0);
+    return result;
+  }
 }
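Since RingRange is now public, Serializable and value-comparable, it can be used directly as a Set element for withRingRanges(). A small sketch of the value semantics:

// Sketch: equals()/hashCode() are value-based, so equal ranges collapse in a Set.
Set<RingRange> ranges = new HashSet<>();
ranges.add(RingRange.of(BigInteger.valueOf(10), BigInteger.valueOf(100)));
ranges.add(RingRange.of(BigInteger.valueOf(10), BigInteger.valueOf(100)));
// ranges.size() == 1; toString() renders the range as "(10,100]"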
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
index de49421..bc1205a 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
@@ -39,22 +39,22 @@
     this.partitioner = partitioner;
   }
 
-  private static BigInteger getRangeMin(String partitioner) {
+  static BigInteger getRangeMin(String partitioner) {
     if (partitioner.endsWith("RandomPartitioner")) {
       return BigInteger.ZERO;
     } else if (partitioner.endsWith("Murmur3Partitioner")) {
-      return new BigInteger("2").pow(63).negate();
+      return BigInteger.valueOf(2).pow(63).negate();
     } else {
       throw new UnsupportedOperationException(
           "Unsupported partitioner. " + "Only Random and Murmur3 are supported");
     }
   }
 
-  private static BigInteger getRangeMax(String partitioner) {
+  static BigInteger getRangeMax(String partitioner) {
     if (partitioner.endsWith("RandomPartitioner")) {
-      return new BigInteger("2").pow(127).subtract(BigInteger.ONE);
+      return BigInteger.valueOf(2).pow(127).subtract(BigInteger.ONE);
     } else if (partitioner.endsWith("Murmur3Partitioner")) {
-      return new BigInteger("2").pow(63).subtract(BigInteger.ONE);
+      return BigInteger.valueOf(2).pow(63).subtract(BigInteger.ONE);
     } else {
       throw new UnsupportedOperationException(
           "Unsupported partitioner. " + "Only Random and Murmur3 are supported");
@@ -84,7 +84,7 @@
       BigInteger start = ringTokens.get(i);
       BigInteger stop = ringTokens.get((i + 1) % tokenRangeCount);
 
-      if (!inRange(start) || !inRange(stop)) {
+      if (!isInRange(start) || !isInRange(stop)) {
         throw new RuntimeException(
             String.format("Tokens (%s,%s) not in range of %s", start, stop, partitioner));
       }
@@ -127,7 +127,7 @@
 
       // Append the splits between the endpoints
       for (int j = 0; j < splitCount; j++) {
-        splits.add(new RingRange(endpointTokens.get(j), endpointTokens.get(j + 1)));
+        splits.add(RingRange.of(endpointTokens.get(j), endpointTokens.get(j + 1)));
         LOG.debug("Split #{}: [{},{})", j + 1, endpointTokens.get(j), endpointTokens.get(j + 1));
       }
     }
@@ -144,7 +144,7 @@
     return coalesceSplits(getTargetSplitSize(totalSplitCount), splits);
   }
 
-  private boolean inRange(BigInteger token) {
+  private boolean isInRange(BigInteger token) {
     return !(token.compareTo(rangeMin) < 0 || token.compareTo(rangeMax) > 0);
   }
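The now package-private getRangeMin/getRangeMax expose the ring bounds used elsewhere in this change, for example in SplitFn's single-range fallback. A small sketch, assuming code in the same package as SplitGenerator:

// Sketch only (SplitGenerator and its helpers are package-private):
BigInteger min = SplitGenerator.getRangeMin("Murmur3Partitioner"); // -2^63
BigInteger max = SplitGenerator.getRangeMax("Murmur3Partitioner"); //  2^63 - 1
// The whole-ring RingRange used when the partitioner cannot be split:
RingRange wholeRing = RingRange.of(min, max);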
 
diff --git a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
index a52808b..131ce83 100644
--- a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
+++ b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
@@ -18,23 +18,18 @@
 package org.apache.beam.sdk.io.cassandra;
 
 import static junit.framework.TestCase.assertTrue;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.distance;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.getEstimatedSizeBytesFromTokenRanges;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.getRingFraction;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.isMurmur3Partitioner;
-import static org.apache.beam.sdk.testing.SourceTestUtils.readFromSource;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.lessThan;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 
 import com.datastax.driver.core.Cluster;
+import com.datastax.driver.core.Metadata;
+import com.datastax.driver.core.ProtocolVersion;
 import com.datastax.driver.core.ResultSet;
 import com.datastax.driver.core.Row;
 import com.datastax.driver.core.Session;
+import com.datastax.driver.core.TypeCodec;
 import com.datastax.driver.core.exceptions.NoHostAvailableException;
-import com.datastax.driver.core.querybuilder.QueryBuilder;
+import com.datastax.driver.mapping.annotations.ClusteringColumn;
 import com.datastax.driver.mapping.annotations.Column;
 import com.datastax.driver.mapping.annotations.Computed;
 import com.datastax.driver.mapping.annotations.PartitionKey;
@@ -44,10 +39,14 @@
 import java.io.IOException;
 import java.io.Serializable;
 import java.math.BigInteger;
+import java.nio.ByteBuffer;
 import java.nio.file.Files;
 import java.nio.file.Paths;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.Callable;
@@ -61,13 +60,8 @@
 import javax.management.remote.JMXConnectorFactory;
 import javax.management.remote.JMXServiceURL;
 import org.apache.beam.sdk.coders.SerializableCoder;
-import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.TokenRange;
 import org.apache.beam.sdk.io.common.NetworkTestHelper;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.testing.SourceTestUtils;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
@@ -82,7 +76,6 @@
 import org.apache.cassandra.service.StorageServiceMBean;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.junit.AfterClass;
-import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.ClassRule;
 import org.junit.Rule;
@@ -99,7 +92,7 @@
   "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
 })
 public class CassandraIOTest implements Serializable {
-  private static final long NUM_ROWS = 20L;
+  private static final long NUM_ROWS = 22L;
   private static final String CASSANDRA_KEYSPACE = "beam_ks";
   private static final String CASSANDRA_HOST = "127.0.0.1";
   private static final String CASSANDRA_TABLE = "scientist";
@@ -190,39 +183,44 @@
     LOG.info("Create Cassandra tables");
     session.execute(
         String.format(
-            "CREATE TABLE IF NOT EXISTS %s.%s(person_id int, person_name text, PRIMARY KEY"
-                + "(person_id));",
+            "CREATE TABLE IF NOT EXISTS %s.%s(person_department text, person_id int, person_name text, PRIMARY KEY"
+                + "((person_department), person_id));",
             CASSANDRA_KEYSPACE, CASSANDRA_TABLE));
     session.execute(
         String.format(
-            "CREATE TABLE IF NOT EXISTS %s.%s(person_id int, person_name text, PRIMARY KEY"
-                + "(person_id));",
+            "CREATE TABLE IF NOT EXISTS %s.%s(person_department text, person_id int, person_name text, PRIMARY KEY"
+                + "((person_department), person_id));",
             CASSANDRA_KEYSPACE, CASSANDRA_TABLE_WRITE));
 
     LOG.info("Insert records");
-    String[] scientists = {
-      "Einstein",
-      "Darwin",
-      "Copernicus",
-      "Pasteur",
-      "Curie",
-      "Faraday",
-      "Newton",
-      "Bohr",
-      "Galilei",
-      "Maxwell"
+    String[][] scientists = {
+      new String[] {"phys", "Einstein"},
+      new String[] {"bio", "Darwin"},
+      new String[] {"phys", "Copernicus"},
+      new String[] {"bio", "Pasteur"},
+      new String[] {"bio", "Curie"},
+      new String[] {"phys", "Faraday"},
+      new String[] {"math", "Newton"},
+      new String[] {"phys", "Bohr"},
+      new String[] {"phys", "Galileo"},
+      new String[] {"math", "Maxwell"},
+      new String[] {"logic", "Russel"},
     };
     for (int i = 0; i < NUM_ROWS; i++) {
       int index = i % scientists.length;
-      session.execute(
+      String insertStr =
           String.format(
-              "INSERT INTO %s.%s(person_id, person_name) values("
+              "INSERT INTO %s.%s(person_department, person_id, person_name) values("
+                  + "'"
+                  + scientists[index][0]
+                  + "', "
                   + i
                   + ", '"
-                  + scientists[index]
+                  + scientists[index][1]
                   + "');",
               CASSANDRA_KEYSPACE,
-              CASSANDRA_TABLE));
+              CASSANDRA_TABLE);
+      session.execute(insertStr);
     }
     flushMemTablesAndRefreshSizeEstimates();
   }
@@ -277,25 +275,6 @@
   }
 
   @Test
-  public void testEstimatedSizeBytes() throws Exception {
-    PipelineOptions pipelineOptions = PipelineOptionsFactory.create();
-    CassandraIO.Read<Scientist> read =
-        CassandraIO.<Scientist>read()
-            .withHosts(Collections.singletonList(CASSANDRA_HOST))
-            .withPort(cassandraPort)
-            .withKeyspace(CASSANDRA_KEYSPACE)
-            .withTable(CASSANDRA_TABLE);
-    CassandraIO.CassandraSource<Scientist> source = new CassandraIO.CassandraSource<>(read, null);
-    long estimatedSizeBytes = source.getEstimatedSizeBytes(pipelineOptions);
-    // the size is non determanistic in Cassandra backend: checks that estimatedSizeBytes >= 12960L
-    // -20%  && estimatedSizeBytes <= 12960L +20%
-    assertThat(
-        "wrong estimated size in " + CASSANDRA_KEYSPACE + "/" + CASSANDRA_TABLE,
-        estimatedSizeBytes,
-        greaterThan(0L));
-  }
-
-  @Test
   public void testRead() throws Exception {
     PCollection<Scientist> output =
         pipeline.apply(
@@ -304,6 +283,7 @@
                 .withPort(cassandraPort)
                 .withKeyspace(CASSANDRA_KEYSPACE)
                 .withTable(CASSANDRA_TABLE)
+                .withMinNumberOfSplits(50)
                 .withCoder(SerializableCoder.of(Scientist.class))
                 .withEntity(Scientist.class));
 
@@ -321,9 +301,113 @@
     PAssert.that(mapped.apply("Count occurrences per scientist", Count.perKey()))
         .satisfies(
             input -> {
+              int count = 0;
               for (KV<String, Long> element : input) {
+                count++;
                 assertEquals(element.getKey(), NUM_ROWS / 10, element.getValue().longValue());
               }
+              assertEquals(11, count);
+              return null;
+            });
+
+    pipeline.run();
+  }
+
+  private CassandraIO.Read<Scientist> getReadWithRingRange(RingRange... rr) {
+    return CassandraIO.<Scientist>read()
+        .withHosts(Collections.singletonList(CASSANDRA_HOST))
+        .withPort(cassandraPort)
+        .withRingRanges(new HashSet<>(Arrays.asList(rr)))
+        .withKeyspace(CASSANDRA_KEYSPACE)
+        .withTable(CASSANDRA_TABLE)
+        .withCoder(SerializableCoder.of(Scientist.class))
+        .withEntity(Scientist.class);
+  }
+
+  private CassandraIO.Read<Scientist> getReadWithQuery(String query) {
+    return CassandraIO.<Scientist>read()
+        .withHosts(Collections.singletonList(CASSANDRA_HOST))
+        .withPort(cassandraPort)
+        .withQuery(query)
+        .withKeyspace(CASSANDRA_KEYSPACE)
+        .withTable(CASSANDRA_TABLE)
+        .withCoder(SerializableCoder.of(Scientist.class))
+        .withEntity(Scientist.class);
+  }
+
+  @Test
+  public void testReadAllQuery() {
+    String physQuery =
+        String.format(
+            "SELECT * From %s.%s WHERE person_department='phys' AND person_id=0;",
+            CASSANDRA_KEYSPACE, CASSANDRA_TABLE);
+
+    String mathQuery =
+        String.format(
+            "SELECT * From %s.%s WHERE person_department='math' AND person_id=6;",
+            CASSANDRA_KEYSPACE, CASSANDRA_TABLE);
+
+    PCollection<Scientist> output =
+        pipeline
+            .apply(Create.of(getReadWithQuery(physQuery), getReadWithQuery(mathQuery)))
+            .apply(
+                CassandraIO.<Scientist>readAll().withCoder(SerializableCoder.of(Scientist.class)));
+
+    PCollection<String> mapped =
+        output.apply(
+            MapElements.via(
+                new SimpleFunction<Scientist, String>() {
+                  @Override
+                  public String apply(Scientist scientist) {
+                    return scientist.name;
+                  }
+                }));
+    PAssert.that(mapped).containsInAnyOrder("Einstein", "Newton");
+    PAssert.thatSingleton(output.apply("count", Count.globally())).isEqualTo(2L);
+    pipeline.run();
+  }
+
+  @Test
+  public void testReadAllRingRange() {
+    RingRange physRR =
+        fromEncodedKey(
+            cluster.getMetadata(), TypeCodec.varchar().serialize("phys", ProtocolVersion.V3));
+
+    RingRange mathRR =
+        fromEncodedKey(
+            cluster.getMetadata(), TypeCodec.varchar().serialize("math", ProtocolVersion.V3));
+
+    RingRange logicRR =
+        fromEncodedKey(
+            cluster.getMetadata(), TypeCodec.varchar().serialize("logic", ProtocolVersion.V3));
+
+    PCollection<Scientist> output =
+        pipeline
+            .apply(Create.of(getReadWithRingRange(physRR), getReadWithRingRange(mathRR, logicRR)))
+            .apply(
+                CassandraIO.<Scientist>readAll().withCoder(SerializableCoder.of(Scientist.class)));
+
+    PCollection<KV<String, Integer>> mapped =
+        output.apply(
+            MapElements.via(
+                new SimpleFunction<Scientist, KV<String, Integer>>() {
+                  @Override
+                  public KV<String, Integer> apply(Scientist scientist) {
+                    return KV.of(scientist.department, scientist.id);
+                  }
+                }));
+
+    PAssert.that(mapped.apply("Count occurrences per department", Count.perKey()))
+        .satisfies(
+            input -> {
+              HashMap<String, Long> map = new HashMap<>();
+              for (KV<String, Long> element : input) {
+                map.put(element.getKey(), element.getValue());
+              }
+              assertEquals(3, map.size()); // do we have all three departments
+              assertEquals(map.get("phys"), 10L, 0L);
+              assertEquals(map.get("math"), 4L, 0L);
+              assertEquals(map.get("logic"), 2L, 0L);
               return null;
             });
 
@@ -339,8 +423,9 @@
                 .withPort(cassandraPort)
                 .withKeyspace(CASSANDRA_KEYSPACE)
                 .withTable(CASSANDRA_TABLE)
+                .withMinNumberOfSplits(20)
                 .withQuery(
-                    "select person_id, writetime(person_name) from beam_ks.scientist where person_id=10")
+                    "select person_id, writetime(person_name) from beam_ks.scientist where person_id=10 AND person_department='logic'")
                 .withCoder(SerializableCoder.of(Scientist.class))
                 .withEntity(Scientist.class));
 
@@ -365,6 +450,7 @@
       ScientistWrite scientist = new ScientistWrite();
       scientist.id = i;
       scientist.name = "Name " + i;
+      scientist.department = "bio";
       data.add(scientist);
     }
 
@@ -485,52 +571,6 @@
     assertEquals(1, counter.intValue());
   }
 
-  @Test
-  public void testSplit() throws Exception {
-    PipelineOptions options = PipelineOptionsFactory.create();
-    CassandraIO.Read<Scientist> read =
-        CassandraIO.<Scientist>read()
-            .withHosts(Collections.singletonList(CASSANDRA_HOST))
-            .withPort(cassandraPort)
-            .withKeyspace(CASSANDRA_KEYSPACE)
-            .withTable(CASSANDRA_TABLE)
-            .withEntity(Scientist.class)
-            .withCoder(SerializableCoder.of(Scientist.class));
-
-    // initialSource will be read without splitting (which does not happen in production)
-    // so we need to provide splitQueries to avoid NPE in source.reader.start()
-    String splitQuery = QueryBuilder.select().from(CASSANDRA_KEYSPACE, CASSANDRA_TABLE).toString();
-    CassandraIO.CassandraSource<Scientist> initialSource =
-        new CassandraIO.CassandraSource<>(read, Collections.singletonList(splitQuery));
-    int desiredBundleSizeBytes = 2048;
-    long estimatedSize = initialSource.getEstimatedSizeBytes(options);
-    List<BoundedSource<Scientist>> splits = initialSource.split(desiredBundleSizeBytes, options);
-    SourceTestUtils.assertSourcesEqualReferenceSource(initialSource, splits, options);
-    float expectedNumSplitsloat =
-        (float) initialSource.getEstimatedSizeBytes(options) / desiredBundleSizeBytes;
-    long sum = 0;
-
-    for (BoundedSource<Scientist> subSource : splits) {
-      sum += subSource.getEstimatedSizeBytes(options);
-    }
-
-    // due to division and cast estimateSize != sum but will be close. Exact equals checked below
-    assertEquals((long) (estimatedSize / splits.size()) * splits.size(), sum);
-
-    int expectedNumSplits = (int) Math.ceil(expectedNumSplitsloat);
-    assertEquals("Wrong number of splits", expectedNumSplits, splits.size());
-    int emptySplits = 0;
-    for (BoundedSource<Scientist> subSource : splits) {
-      if (readFromSource(subSource, options).isEmpty()) {
-        emptySplits += 1;
-      }
-    }
-    assertThat(
-        "There are too many empty splits, parallelism is sub-optimal",
-        emptySplits,
-        lessThan((int) (ACCEPTABLE_EMPTY_SPLITS_PERCENTAGE * splits.size())));
-  }
-
   private List<Row> getRows(String table) {
     ResultSet result =
         session.execute(
@@ -545,6 +585,7 @@
 
     Scientist einstein = new Scientist();
     einstein.id = 0;
+    einstein.department = "phys";
     einstein.name = "Einstein";
     pipeline
         .apply(Create.of(einstein))
@@ -561,7 +602,8 @@
     // re-insert suppressed doc to make the test autonomous
     session.execute(
         String.format(
-            "INSERT INTO %s.%s(person_id, person_name) values("
+            "INSERT INTO %s.%s(person_department, person_id, person_name) values("
+                + "'phys', "
                 + einstein.id
                 + ", '"
                 + einstein.name
@@ -570,58 +612,6 @@
             CASSANDRA_TABLE));
   }
 
-  @Test
-  public void testValidPartitioner() {
-    Assert.assertTrue(isMurmur3Partitioner(cluster));
-  }
-
-  @Test
-  public void testDistance() {
-    BigInteger distance = distance(new BigInteger("10"), new BigInteger("100"));
-    assertEquals(BigInteger.valueOf(90), distance);
-
-    distance = distance(new BigInteger("100"), new BigInteger("10"));
-    assertEquals(new BigInteger("18446744073709551526"), distance);
-  }
-
-  @Test
-  public void testRingFraction() {
-    // simulate a first range taking "half" of the available tokens
-    List<TokenRange> tokenRanges = new ArrayList<>();
-    tokenRanges.add(new TokenRange(1, 1, BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("0")));
-    assertEquals(0.5, getRingFraction(tokenRanges), 0);
-
-    // add a second range to cover all tokens available
-    tokenRanges.add(new TokenRange(1, 1, new BigInteger("0"), BigInteger.valueOf(Long.MAX_VALUE)));
-    assertEquals(1.0, getRingFraction(tokenRanges), 0);
-  }
-
-  @Test
-  public void testEstimatedSizeBytesFromTokenRanges() {
-    List<TokenRange> tokenRanges = new ArrayList<>();
-    // one partition containing all tokens, the size is actually the size of the partition
-    tokenRanges.add(
-        new TokenRange(
-            1, 1000, BigInteger.valueOf(Long.MIN_VALUE), BigInteger.valueOf(Long.MAX_VALUE)));
-    assertEquals(1000, getEstimatedSizeBytesFromTokenRanges(tokenRanges));
-
-    // one partition with half of the tokens, we estimate the size to the double of this partition
-    tokenRanges = new ArrayList<>();
-    tokenRanges.add(
-        new TokenRange(1, 1000, BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("0")));
-    assertEquals(2000, getEstimatedSizeBytesFromTokenRanges(tokenRanges));
-
-    // we have three partitions covering all tokens, the size is the sum of partition size *
-    // partition count
-    tokenRanges = new ArrayList<>();
-    tokenRanges.add(
-        new TokenRange(1, 1000, BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("-3")));
-    tokenRanges.add(new TokenRange(1, 1000, new BigInteger("-2"), new BigInteger("10000")));
-    tokenRanges.add(
-        new TokenRange(2, 3000, new BigInteger("10001"), BigInteger.valueOf(Long.MAX_VALUE)));
-    assertEquals(8000, getEstimatedSizeBytesFromTokenRanges(tokenRanges));
-  }
-
   /** Simple Cassandra entity used in read tests. */
   @Table(name = CASSANDRA_TABLE, keyspace = CASSANDRA_KEYSPACE)
   static class Scientist implements Serializable {
@@ -632,10 +622,14 @@
     @Computed("writetime(person_name)")
     Long nameTs;
 
-    @PartitionKey()
+    @ClusteringColumn()
     @Column(name = "person_id")
     int id;
 
+    @PartitionKey
+    @Column(name = "person_department")
+    String department;
+
     @Override
     public String toString() {
       return id + ":" + name;
@@ -650,7 +644,9 @@
         return false;
       }
       Scientist scientist = (Scientist) o;
-      return id == scientist.id && Objects.equal(name, scientist.name);
+      return id == scientist.id
+          && Objects.equal(name, scientist.name)
+          && Objects.equal(department, scientist.department);
     }
 
     @Override
@@ -659,6 +655,11 @@
     }
   }
 
+  private static RingRange fromEncodedKey(Metadata metadata, ByteBuffer... bb) {
+    BigInteger bi = BigInteger.valueOf((long) metadata.newToken(bb).getValue());
+    return RingRange.of(bi, bi.add(BigInteger.valueOf(1L)));
+  }
+
   private static final String CASSANDRA_TABLE_WRITE = "scientist_write";
   /** Simple Cassandra entity used in write tests. */
   @Table(name = CASSANDRA_TABLE_WRITE, keyspace = CASSANDRA_KEYSPACE)
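A note on the CassandraIO test changes above: person_department is now the partition key (with person_id demoted to a clustering column), and testReadAllRingRange uses fromEncodedKey to turn an encoded department key into a one-token-wide RingRange, so each Read covers exactly one partition. A standalone sketch of that helper, assuming (as the cast above implies) a Murmur3 partitioner whose Token#getValue() is a Long:

    import java.math.BigInteger;
    import java.nio.ByteBuffer;
    import com.datastax.driver.core.Metadata;
    import com.datastax.driver.core.ProtocolVersion;
    import com.datastax.driver.core.TypeCodec;
    import org.apache.beam.sdk.io.cassandra.RingRange;

    class RingRangeSketch {
      /** Mirrors fromEncodedKey(...) above: a range of width one around a single key's token. */
      static RingRange singleKeyRange(Metadata metadata, String partitionKey) {
        ByteBuffer encoded = TypeCodec.varchar().serialize(partitionKey, ProtocolVersion.V3);
        BigInteger token = BigInteger.valueOf((long) metadata.newToken(encoded).getValue());
        return RingRange.of(token, token.add(BigInteger.ONE));
      }
    }

The end token is start plus one, so the range contains only the hashed key's token and nothing else.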
diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle
index 3ec4d39..215a66b 100644
--- a/sdks/java/io/google-cloud-platform/build.gradle
+++ b/sdks/java/io/google-cloud-platform/build.gradle
@@ -131,7 +131,7 @@
 
   compile "org.threeten:threetenbp:1.4.4"
 
-  testCompile "org.apache.arrow:arrow-memory-netty:4.0.0"
+  testCompile library.java.arrow_memory_netty
   testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
   testCompile project(path: ":sdks:java:extensions:google-cloud-platform-core", configuration: "testRuntime")
   testCompile project(path: ":sdks:java:extensions:protobuf", configuration: "testRuntime")
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
index 484fe3b..dfe797e 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
@@ -86,7 +86,6 @@
   private final SerializableFunction<ElementT, TableRow> toTableRow;
   private final SerializableFunction<ElementT, TableRow> toFailsafeTableRow;
   private final Set<String> allowedMetricUrns;
-  private @Nullable DatasetService datasetService;
 
   /** Tracks bytes written, exposed as "ByteCount" Counter. */
   private Counter byteCounter = SinkMetrics.bytesWritten();
@@ -222,6 +221,15 @@
     /** The list of unique ids for each BigQuery table row. */
     private transient Map<String, List<String>> uniqueIdsForTableRows;
 
+    private transient @Nullable DatasetService datasetService;
+
+    private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException {
+      if (datasetService == null) {
+        datasetService = bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class));
+      }
+      return datasetService;
+    }
+
     /** Prepares a target BigQuery table. */
     @StartBundle
     public void startBundle() {
@@ -257,10 +265,10 @@
           tableRows.entrySet()) {
         TableReference tableReference = BigQueryHelpers.parseTableSpec(entry.getKey());
         flushRows(
+            getDatasetService(options),
             tableReference,
             entry.getValue(),
             uniqueIdsForTableRows.get(entry.getKey()),
-            options,
             failedInserts,
             successfulInserts);
       }
@@ -272,6 +280,18 @@
       }
       reportStreamingApiLogging(options);
     }
+
+    @Teardown
+    public void onTeardown() {
+      try {
+        if (datasetService != null) {
+          datasetService.close();
+          datasetService = null;
+        }
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    }
   }
 
   // The max duration input records are allowed to be buffered in the state, if using ViaStateful.
@@ -325,13 +345,22 @@
   // shuffling.
   private class InsertBatchedElements
       extends DoFn<KV<ShardedKey<String>, Iterable<TableRowInfo<ElementT>>>, Void> {
+    private transient @Nullable DatasetService datasetService;
+
+    private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException {
+      if (datasetService == null) {
+        datasetService = bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class));
+      }
+      return datasetService;
+    }
+
     @ProcessElement
     public void processElement(
         @Element KV<ShardedKey<String>, Iterable<TableRowInfo<ElementT>>> input,
         BoundedWindow window,
         ProcessContext context,
         MultiOutputReceiver out)
-        throws InterruptedException {
+        throws InterruptedException, IOException {
       List<FailsafeValueInSingleWindow<TableRow, TableRow>> tableRows = new ArrayList<>();
       List<String> uniqueIds = new ArrayList<>();
       for (TableRowInfo<ElementT> row : input.getValue()) {
@@ -347,7 +376,13 @@
       TableReference tableReference = BigQueryHelpers.parseTableSpec(input.getKey().getKey());
       List<ValueInSingleWindow<ErrorT>> failedInserts = Lists.newArrayList();
       List<ValueInSingleWindow<TableRow>> successfulInserts = Lists.newArrayList();
-      flushRows(tableReference, tableRows, uniqueIds, options, failedInserts, successfulInserts);
+      flushRows(
+          getDatasetService(options),
+          tableReference,
+          tableRows,
+          uniqueIds,
+          failedInserts,
+          successfulInserts);
 
       for (ValueInSingleWindow<ErrorT> row : failedInserts) {
         out.get(failedOutputTag).output(row.getValue());
@@ -357,44 +392,43 @@
       }
       reportStreamingApiLogging(options);
     }
-  }
 
-  @Teardown
-  public void onTeardown() {
-    try {
-      if (datasetService != null) {
-        datasetService.close();
-        datasetService = null;
+    @Teardown
+    public void onTeardown() {
+      try {
+        if (datasetService != null) {
+          datasetService.close();
+          datasetService = null;
+        }
+      } catch (Exception e) {
+        throw new RuntimeException(e);
       }
-    } catch (Exception e) {
-      throw new RuntimeException(e);
     }
   }
 
   /** Writes the accumulated rows into BigQuery with streaming API. */
   private void flushRows(
+      DatasetService datasetService,
       TableReference tableReference,
       List<FailsafeValueInSingleWindow<TableRow, TableRow>> tableRows,
       List<String> uniqueIds,
-      BigQueryOptions options,
       List<ValueInSingleWindow<ErrorT>> failedInserts,
       List<ValueInSingleWindow<TableRow>> successfulInserts)
       throws InterruptedException {
     if (!tableRows.isEmpty()) {
       try {
         long totalBytes =
-            getDatasetService(options)
-                .insertAll(
-                    tableReference,
-                    tableRows,
-                    uniqueIds,
-                    retryPolicy,
-                    failedInserts,
-                    errorContainer,
-                    skipInvalidRows,
-                    ignoreUnknownValues,
-                    ignoreInsertIds,
-                    successfulInserts);
+            datasetService.insertAll(
+                tableReference,
+                tableRows,
+                uniqueIds,
+                retryPolicy,
+                failedInserts,
+                errorContainer,
+                skipInvalidRows,
+                ignoreUnknownValues,
+                ignoreInsertIds,
+                successfulInserts);
         byteCounter.inc(totalBytes);
       } catch (IOException e) {
         throw new RuntimeException(e);
@@ -402,13 +436,6 @@
     }
   }
 
-  private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException {
-    if (datasetService == null) {
-      datasetService = bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class));
-    }
-    return datasetService;
-  }
-
   private void reportStreamingApiLogging(BigQueryOptions options) {
     MetricsContainer processWideContainer = MetricsEnvironment.getProcessWideContainer();
     if (processWideContainer instanceof MetricsLogger) {
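The BatchedStreamingWrite changes above move the cached DatasetService out of the enclosing transform and into each DoFn: the client is created lazily on first use per DoFn instance and closed in @Teardown, so the lifecycle hooks now live on the DoFns that actually own the client. The same shape, reduced to a sketch with a placeholder Service interface standing in for DatasetService (not Beam code):

    import org.apache.beam.sdk.transforms.DoFn;
    import org.checkerframework.checker.nullness.qual.Nullable;

    /** Placeholder for a closeable client such as DatasetService (illustration only). */
    interface Service extends AutoCloseable {
      void write(String element) throws Exception;
    }

    abstract class LazyServiceFn extends DoFn<String, Void> {
      private transient @Nullable Service service;

      /** Builds the client; in the real code this is bqServices.getDatasetService(options). */
      abstract Service createService();

      private Service getService() {
        if (service == null) {
          service = createService(); // one client per DoFn instance, per worker
        }
        return service;
      }

      @ProcessElement
      public void processElement(@Element String element) throws Exception {
        getService().write(element);
      }

      @Teardown
      public void teardown() {
        try {
          if (service != null) {
            service.close();
            service = null;
          }
        } catch (Exception e) {
          throw new RuntimeException(e);
        }
      }
    }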
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
index 1d3d894..8b9b705 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
@@ -965,49 +965,53 @@
       // earlier stages of the pipeline or if a query depends on earlier stages of a pipeline.
       // For these cases the withoutValidation method can be used to disable the check.
       if (getValidate()) {
-        if (table != null) {
-          checkArgument(table.isAccessible(), "Cannot call validate if table is dynamically set.");
-        }
-        if (table != null && table.get().getProjectId() != null) {
-          // Check for source table presence for early failure notification.
-          DatasetService datasetService = getBigQueryServices().getDatasetService(bqOptions);
-          BigQueryHelpers.verifyDatasetPresence(datasetService, table.get());
-          BigQueryHelpers.verifyTablePresence(datasetService, table.get());
-        } else if (getQuery() != null) {
-          checkArgument(
-              getQuery().isAccessible(), "Cannot call validate if query is dynamically set.");
-          JobService jobService = getBigQueryServices().getJobService(bqOptions);
-          try {
-            jobService.dryRunQuery(
-                bqOptions.getBigQueryProject() == null
-                    ? bqOptions.getProject()
-                    : bqOptions.getBigQueryProject(),
-                new JobConfigurationQuery()
-                    .setQuery(getQuery().get())
-                    .setFlattenResults(getFlattenResults())
-                    .setUseLegacySql(getUseLegacySql()),
-                getQueryLocation());
-          } catch (Exception e) {
-            throw new IllegalArgumentException(
-                String.format(QUERY_VALIDATION_FAILURE_ERROR, getQuery().get()), e);
+        try (DatasetService datasetService = getBigQueryServices().getDatasetService(bqOptions)) {
+          if (table != null) {
+            checkArgument(
+                table.isAccessible(), "Cannot call validate if table is dynamically set.");
           }
+          if (table != null && table.get().getProjectId() != null) {
+            // Check for source table presence for early failure notification.
+            BigQueryHelpers.verifyDatasetPresence(datasetService, table.get());
+            BigQueryHelpers.verifyTablePresence(datasetService, table.get());
+          } else if (getQuery() != null) {
+            checkArgument(
+                getQuery().isAccessible(), "Cannot call validate if query is dynamically set.");
+            JobService jobService = getBigQueryServices().getJobService(bqOptions);
+            try {
+              jobService.dryRunQuery(
+                  bqOptions.getBigQueryProject() == null
+                      ? bqOptions.getProject()
+                      : bqOptions.getBigQueryProject(),
+                  new JobConfigurationQuery()
+                      .setQuery(getQuery().get())
+                      .setFlattenResults(getFlattenResults())
+                      .setUseLegacySql(getUseLegacySql()),
+                  getQueryLocation());
+            } catch (Exception e) {
+              throw new IllegalArgumentException(
+                  String.format(QUERY_VALIDATION_FAILURE_ERROR, getQuery().get()), e);
+            }
 
-          DatasetService datasetService = getBigQueryServices().getDatasetService(bqOptions);
-          // If the user provided a temp dataset, check if the dataset exists before launching the
-          // query
-          if (getQueryTempDataset() != null) {
-            // The temp table is only used for dataset and project id validation, not for table name
-            // validation
-            TableReference tempTable =
-                new TableReference()
-                    .setProjectId(
-                        bqOptions.getBigQueryProject() == null
-                            ? bqOptions.getProject()
-                            : bqOptions.getBigQueryProject())
-                    .setDatasetId(getQueryTempDataset())
-                    .setTableId("dummy table");
-            BigQueryHelpers.verifyDatasetPresence(datasetService, tempTable);
+            // If the user provided a temp dataset, check if the dataset exists before launching the
+            // query
+            if (getQueryTempDataset() != null) {
+              // The temp table is only used for dataset and project id validation,
+              // not for table name validation.
+              TableReference tempTable =
+                  new TableReference()
+                      .setProjectId(
+                          bqOptions.getBigQueryProject() == null
+                              ? bqOptions.getProject()
+                              : bqOptions.getBigQueryProject())
+                      .setDatasetId(getQueryTempDataset())
+                      .setTableId("dummy table");
+              BigQueryHelpers.verifyDatasetPresence(datasetService, tempTable);
+            }
           }
+        } catch (Exception e) {
+          throw new RuntimeException(e);
         }
       }
     }
@@ -1401,15 +1405,17 @@
                           options.getJobName(), jobUuid, JobType.QUERY),
                       queryTempDataset);
 
-              DatasetService datasetService = getBigQueryServices().getDatasetService(options);
-              LOG.info("Deleting temporary table with query results {}", tempTable);
-              datasetService.deleteTable(tempTable);
-              // Delete dataset only if it was created by Beam
-              boolean datasetCreatedByBeam = !queryTempDataset.isPresent();
-              if (datasetCreatedByBeam) {
-                LOG.info(
-                    "Deleting temporary dataset with query results {}", tempTable.getDatasetId());
-                datasetService.deleteDataset(tempTable.getProjectId(), tempTable.getDatasetId());
+              try (DatasetService datasetService =
+                  getBigQueryServices().getDatasetService(options)) {
+                LOG.info("Deleting temporary table with query results {}", tempTable);
+                datasetService.deleteTable(tempTable);
+                // Delete dataset only if it was created by Beam
+                boolean datasetCreatedByBeam = !queryTempDataset.isPresent();
+                if (datasetCreatedByBeam) {
+                  LOG.info(
+                      "Deleting temporary dataset with query results {}", tempTable.getDatasetId());
+                  datasetService.deleteDataset(tempTable.getProjectId(), tempTable.getDatasetId());
+                }
               }
             }
           };
@@ -2484,17 +2490,20 @@
       // The user specified a table.
       if (getJsonTableRef() != null && getJsonTableRef().isAccessible() && getValidate()) {
         TableReference table = getTableWithDefaultProject(options).get();
-        DatasetService datasetService = getBigQueryServices().getDatasetService(options);
-        // Check for destination table presence and emptiness for early failure notification.
-        // Note that a presence check can fail when the table or dataset is created by an earlier
-        // stage of the pipeline. For these cases the #withoutValidation method can be used to
-        // disable the check.
-        BigQueryHelpers.verifyDatasetPresence(datasetService, table);
-        if (getCreateDisposition() == BigQueryIO.Write.CreateDisposition.CREATE_NEVER) {
-          BigQueryHelpers.verifyTablePresence(datasetService, table);
-        }
-        if (getWriteDisposition() == BigQueryIO.Write.WriteDisposition.WRITE_EMPTY) {
-          BigQueryHelpers.verifyTableNotExistOrEmpty(datasetService, table);
+        try (DatasetService datasetService = getBigQueryServices().getDatasetService(options)) {
+          // Check for destination table presence and emptiness for early failure notification.
+          // Note that a presence check can fail when the table or dataset is created by an earlier
+          // stage of the pipeline. For these cases the #withoutValidation method can be used to
+          // disable the check.
+          BigQueryHelpers.verifyDatasetPresence(datasetService, table);
+          if (getCreateDisposition() == BigQueryIO.Write.CreateDisposition.CREATE_NEVER) {
+            BigQueryHelpers.verifyTablePresence(datasetService, table);
+          }
+          if (getWriteDisposition() == BigQueryIO.Write.WriteDisposition.WRITE_EMPTY) {
+            BigQueryHelpers.verifyTableNotExistOrEmpty(datasetService, table);
+          }
+        } catch (Exception e) {
+          throw new RuntimeException(e);
         }
       }
     }
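In the BigQueryIO validation paths above, the DatasetService is now opened in a try-with-resources block so the client is always closed, and any checked exception from the checks or from close() is rethrown unchecked. A minimal, self-contained illustration of that shape (Service, datasetExists and the factory are placeholders, not Beam APIs):

    import java.util.concurrent.Callable;

    /** Minimal illustration of the try-with-resources shape used above; not Beam code. */
    public class CloseableValidationSketch {

      /** Stand-in for DatasetService: an AutoCloseable client whose calls throw checked exceptions. */
      interface Service extends AutoCloseable {
        boolean datasetExists(String dataset) throws Exception;
      }

      static void validate(Callable<Service> factory, String dataset) {
        try (Service service = factory.call()) {
          // The early-failure check runs while the client is open; close() happens automatically.
          if (!service.datasetExists(dataset)) {
            throw new IllegalArgumentException("Dataset not found: " + dataset);
          }
        } catch (Exception e) {
          // The check and close() can both throw checked exceptions; rethrow unchecked,
          // matching the catch blocks added above.
          throw new RuntimeException(e);
        }
      }
    }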
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java
index aeee3d2..eb9f8c4 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java
@@ -38,10 +38,14 @@
 import io.grpc.Status.Code;
 import io.grpc.StatusRuntimeException;
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.List;
 import java.util.NoSuchElementException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
+import org.apache.beam.runners.core.metrics.GcpResourceIdentifiers;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.ServiceCallMetric;
 import org.apache.beam.sdk.io.gcp.bigtable.BigtableIO.BigtableSource;
 import org.apache.beam.sdk.io.range.ByteKeyRange;
 import org.apache.beam.sdk.values.KV;
@@ -130,12 +134,40 @@
       String tableNameSr =
           session.getOptions().getInstanceName().toTableNameStr(source.getTableId().get());
 
+      HashMap<String, String> baseLabels = new HashMap<>();
+      baseLabels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
+      baseLabels.put(MonitoringInfoConstants.Labels.SERVICE, "BigTable");
+      baseLabels.put(MonitoringInfoConstants.Labels.METHOD, "google.bigtable.v2.ReadRows");
+      baseLabels.put(
+          MonitoringInfoConstants.Labels.RESOURCE,
+          GcpResourceIdentifiers.bigtableResource(
+              session.getOptions().getProjectId(),
+              session.getOptions().getInstanceId(),
+              source.getTableId().get()));
+      baseLabels.put(
+          MonitoringInfoConstants.Labels.BIGTABLE_PROJECT_ID, session.getOptions().getProjectId());
+      baseLabels.put(
+          MonitoringInfoConstants.Labels.INSTANCE_ID, session.getOptions().getInstanceId());
+      baseLabels.put(
+          MonitoringInfoConstants.Labels.TABLE_ID,
+          GcpResourceIdentifiers.bigtableTableID(
+              session.getOptions().getProjectId(),
+              session.getOptions().getInstanceId(),
+              source.getTableId().get()));
+      ServiceCallMetric serviceCallMetric =
+          new ServiceCallMetric(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, baseLabels);
       ReadRowsRequest.Builder requestB =
           ReadRowsRequest.newBuilder().setRows(rowSet).setTableName(tableNameSr);
       if (source.getRowFilter() != null) {
         requestB.setFilter(source.getRowFilter());
       }
-      results = session.getDataClient().readRows(requestB.build());
+      try {
+        results = session.getDataClient().readRows(requestB.build());
+        serviceCallMetric.call("ok");
+      } catch (StatusRuntimeException e) {
+        serviceCallMetric.call(e.getStatus().getCode().value());
+        throw e;
+      }
       return advance();
     }
 
@@ -182,10 +214,12 @@
   static class BigtableWriterImpl implements Writer {
     private BigtableSession session;
     private BulkMutation bulkMutation;
+    private BigtableTableName tableName;
 
     BigtableWriterImpl(BigtableSession session, BigtableTableName tableName) {
       this.session = session;
       bulkMutation = session.createBulkMutation(tableName);
+      this.tableName = tableName;
     }
 
     @Override
@@ -231,6 +265,28 @@
               .addAllMutations(record.getValue())
               .build();
 
+      HashMap<String, String> baseLabels = new HashMap<>();
+      baseLabels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
+      baseLabels.put(MonitoringInfoConstants.Labels.SERVICE, "BigTable");
+      baseLabels.put(MonitoringInfoConstants.Labels.METHOD, "google.bigtable.v2.MutateRows");
+      baseLabels.put(
+          MonitoringInfoConstants.Labels.RESOURCE,
+          GcpResourceIdentifiers.bigtableResource(
+              session.getOptions().getProjectId(),
+              session.getOptions().getInstanceId(),
+              tableName.getTableId()));
+      baseLabels.put(
+          MonitoringInfoConstants.Labels.BIGTABLE_PROJECT_ID, session.getOptions().getProjectId());
+      baseLabels.put(
+          MonitoringInfoConstants.Labels.INSTANCE_ID, session.getOptions().getInstanceId());
+      baseLabels.put(
+          MonitoringInfoConstants.Labels.TABLE_ID,
+          GcpResourceIdentifiers.bigtableTableID(
+              session.getOptions().getProjectId(),
+              session.getOptions().getInstanceId(),
+              tableName.getTableId()));
+      ServiceCallMetric serviceCallMetric =
+          new ServiceCallMetric(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, baseLabels);
       CompletableFuture<MutateRowResponse> result = new CompletableFuture<>();
       Futures.addCallback(
           new VendoredListenableFutureAdapter<>(bulkMutation.add(request)),
@@ -238,10 +294,17 @@
             @Override
             public void onSuccess(MutateRowResponse mutateRowResponse) {
               result.complete(mutateRowResponse);
+              serviceCallMetric.call("ok");
             }
 
             @Override
             public void onFailure(Throwable throwable) {
+              if (throwable instanceof StatusRuntimeException) {
+                serviceCallMetric.call(
+                    ((StatusRuntimeException) throwable).getStatus().getCode().value());
+              } else {
+                serviceCallMetric.call("unknown");
+              }
               result.completeExceptionally(throwable);
             }
           },
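The BigtableServiceImpl hunks above wrap each ReadRows and MutateRows call with a ServiceCallMetric keyed by the API_REQUEST_COUNT URN, reporting "ok" on success and the gRPC status code on failure. The wrapping pattern, trimmed to a sketch (the label map here is abbreviated; the real code also sets SERVICE, RESOURCE and the Bigtable project/instance/table labels as shown above):

    import java.util.HashMap;
    import io.grpc.StatusRuntimeException;
    import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
    import org.apache.beam.runners.core.metrics.ServiceCallMetric;

    /** Sketch of the metric wrapping used above; not the production code. */
    class ServiceCallMetricSketch {
      static void callWithMetric(String fullMethodName, Runnable rpc) {
        HashMap<String, String> labels = new HashMap<>();
        labels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
        labels.put(MonitoringInfoConstants.Labels.METHOD, fullMethodName);
        ServiceCallMetric metric =
            new ServiceCallMetric(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, labels);
        try {
          rpc.run(); // e.g. a readRows or bulk-mutation request
          metric.call("ok");
        } catch (StatusRuntimeException e) {
          metric.call(e.getStatus().getCode().value());
          throw e;
        }
      }
    }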
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java
index 69be079..b983cc3 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.io.gcp.bigtable;
 
+import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.times;
@@ -42,9 +43,15 @@
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
+import java.util.HashMap;
+import org.apache.beam.runners.core.metrics.GcpResourceIdentifiers;
+import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
 import org.apache.beam.sdk.io.gcp.bigtable.BigtableIO.BigtableSource;
 import org.apache.beam.sdk.io.range.ByteKey;
 import org.apache.beam.sdk.io.range.ByteKeyRange;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
@@ -61,8 +68,12 @@
 @RunWith(JUnit4.class)
 public class BigtableServiceImplTest {
 
+  private static final String PROJECT_ID = "project";
+  private static final String INSTANCE_ID = "instance";
+  private static final String TABLE_ID = "table";
+
   private static final BigtableTableName TABLE_NAME =
-      new BigtableInstanceName("project", "instance").toTableName("table");
+      new BigtableInstanceName(PROJECT_ID, INSTANCE_ID).toTableName(TABLE_ID);
 
   @Mock private BigtableSession mockSession;
 
@@ -76,10 +87,13 @@
   public void setup() {
     MockitoAnnotations.initMocks(this);
     BigtableOptions options =
-        new BigtableOptions.Builder().setProjectId("project").setInstanceId("instance").build();
+        new BigtableOptions.Builder().setProjectId(PROJECT_ID).setInstanceId(INSTANCE_ID).build();
     when(mockSession.getOptions()).thenReturn(options);
     when(mockSession.createBulkMutation(eq(TABLE_NAME))).thenReturn(mockBulkMutation);
     when(mockSession.getDataClient()).thenReturn(mockBigtableDataClient);
+    // Set up the ProcessWideContainer so the tests can verify that metrics are set.
+    MetricsContainerImpl container = new MetricsContainerImpl(null);
+    MetricsEnvironment.setProcessWideContainer(container);
   }
 
   /**
@@ -94,7 +108,7 @@
     ByteKey start = ByteKey.copyFrom("a".getBytes(StandardCharsets.UTF_8));
     ByteKey end = ByteKey.copyFrom("b".getBytes(StandardCharsets.UTF_8));
     when(mockBigtableSource.getRanges()).thenReturn(Arrays.asList(ByteKeyRange.of(start, end)));
-    when(mockBigtableSource.getTableId()).thenReturn(StaticValueProvider.of("table_name"));
+    when(mockBigtableSource.getTableId()).thenReturn(StaticValueProvider.of(TABLE_ID));
     @SuppressWarnings("unchecked")
     ResultScanner<Row> mockResultScanner = Mockito.mock(ResultScanner.class);
     Row expectedRow = Row.newBuilder().setKey(ByteString.copyFromUtf8("a")).build();
@@ -109,6 +123,7 @@
     underTest.close();
 
     verify(mockResultScanner, times(1)).close();
+    verifyMetricWasSet("google.bigtable.v2.ReadRows", "ok", 1);
   }
 
   /**
@@ -140,4 +155,27 @@
     underTest.close();
     verify(mockBulkMutation, times(1)).flush();
   }
+
+  private void verifyMetricWasSet(String method, String status, long count) {
+    // Verify the metric was reported with the expected labels and count.
+    HashMap<String, String> labels = new HashMap<>();
+    labels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
+    labels.put(MonitoringInfoConstants.Labels.SERVICE, "BigTable");
+    labels.put(MonitoringInfoConstants.Labels.METHOD, method);
+    labels.put(
+        MonitoringInfoConstants.Labels.RESOURCE,
+        GcpResourceIdentifiers.bigtableResource(PROJECT_ID, INSTANCE_ID, TABLE_ID));
+    labels.put(MonitoringInfoConstants.Labels.BIGTABLE_PROJECT_ID, PROJECT_ID);
+    labels.put(MonitoringInfoConstants.Labels.INSTANCE_ID, INSTANCE_ID);
+    labels.put(
+        MonitoringInfoConstants.Labels.TABLE_ID,
+        GcpResourceIdentifiers.bigtableTableID(PROJECT_ID, INSTANCE_ID, TABLE_ID));
+    labels.put(MonitoringInfoConstants.Labels.STATUS, status);
+
+    MonitoringInfoMetricName name =
+        MonitoringInfoMetricName.named(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, labels);
+    MetricsContainerImpl container =
+        (MetricsContainerImpl) MetricsEnvironment.getProcessWideContainer();
+    assertEquals(count, (long) container.getCounter(name).getCumulative());
+  }
 }
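The assertion helper above reads the counter back from the process-wide metrics container. The same lookup can be written in isolation as follows (labels trimmed for brevity; they must match whatever the instrumented code actually set, so the real test also adds PTRANSFORM, SERVICE, RESOURCE and the Bigtable id labels):

    import static org.junit.Assert.assertEquals;

    import java.util.HashMap;
    import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
    import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
    import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
    import org.apache.beam.sdk.metrics.MetricsEnvironment;

    /** Standalone sketch of the counter lookup performed by verifyMetricWasSet above. */
    class ProcessWideCounterAssertSketch {
      static void assertApiRequestCount(String method, String status, long expected) {
        HashMap<String, String> labels = new HashMap<>();
        labels.put(MonitoringInfoConstants.Labels.METHOD, method);
        labels.put(MonitoringInfoConstants.Labels.STATUS, status);
        MonitoringInfoMetricName name =
            MonitoringInfoMetricName.named(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, labels);
        MetricsContainerImpl container =
            (MetricsContainerImpl) MetricsEnvironment.getProcessWideContainer();
        assertEquals(expected, (long) container.getCounter(name).getCumulative());
      }
    }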
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index 2cab8b4..c0e36f7 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -1900,6 +1900,17 @@
                 .withMaxRetries(retryConfiguration.getMaxAttempts());
       }
 
+      @Override
+      public void populateDisplayData(DisplayData.Builder builder) {
+        spec.populateDisplayData(builder);
+        builder.add(
+            DisplayData.item(
+                "query", preparedStatement == null ? "null" : preparedStatement.toString()));
+        builder.add(
+            DisplayData.item("dataSource", dataSource == null ? "null" : dataSource.toString()));
+        builder.add(DisplayData.item("spec", spec == null ? "null" : spec.toString()));
+      }
+
       @ProcessElement
       public void processElement(ProcessContext context) throws Exception {
         T record = context.element();
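The JdbcIO hunk above adds populateDisplayData to the write DoFn so the prepared statement, data source and spec show up as display data, falling back to the string "null" when a field is unset. In general a DoFn can expose configuration this way by overriding populateDisplayData; a minimal sketch with a hypothetical query field (not the JdbcIO code):

    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.transforms.display.DisplayData;

    /** Minimal DoFn exposing its configuration via DisplayData (illustration only). */
    class DisplayDataSketchFn extends DoFn<String, String> {
      private final String query; // placeholder configuration value

      DisplayDataSketchFn(String query) {
        this.query = query;
      }

      @Override
      public void populateDisplayData(DisplayData.Builder builder) {
        super.populateDisplayData(builder);
        builder.add(DisplayData.item("query", query == null ? "null" : query));
      }

      @ProcessElement
      public void processElement(@Element String element, OutputReceiver<String> out) {
        out.output(element);
      }
    }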
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py
index 886e56e..9ed0b25 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py
@@ -208,8 +208,6 @@
 
   def load_pcoder(self, *labels):
     saved_pcoder = self._saved_pcoders.get(self._path(*labels), None)
-    # TODO(BEAM-12506): Get rid of the SafeFastPrimitivesCoder for
-    # WindowedValueHolder.
     if saved_pcoder is None or isinstance(saved_pcoder,
                                           coders.FastPrimitivesCoder):
       return self._default_pcoder
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index 054c9a6..fc8a8aa 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -390,8 +390,6 @@
   def load_pcoder(self, *labels):
     saved_pcoder = self._saved_pcoders.get(
         os.path.join(self._cache_dir, *labels), None)
-    # TODO(BEAM-12506): Get rid of the SafeFastPrimitivesCoder for
-    # WindowedValueHolder.
     if saved_pcoder is None or isinstance(saved_pcoder,
                                           coders.FastPrimitivesCoder):
       return self._default_pcoder
diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
index cee3d34..1dc42e0 100644
--- a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
+++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
@@ -227,7 +227,11 @@
       endpoint=test_stream_service.endpoint)
   sql_source = {}
   for tag, output in output_pcolls.items():
-    sql_source[tag_to_name[tag]] = output
+    name = tag_to_name[tag]
+    # Must mark the element_type to avoid introducing pickled Python coder
+    # to the Java expansion service.
+    output.element_type = name_to_pcoll[name].element_type
+    sql_source[name] = output
   return sql_source