Cloud Bigtable stream changes and handle CloseStream responses (#25460)

* Handle ChangeStreamMutation

* Evaluate CloseStream split and merge messages from Change Stream API

* Fix rebase issues

---------

Co-authored-by: Pablo <pabloem@users.noreply.github.com>
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java
index 8f307f5..34d3aff 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java
@@ -18,11 +18,15 @@
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams;
 
 import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
+import com.google.protobuf.ByteString;
 import com.google.protobuf.TextFormat;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 
 /** Helper functions to evaluate the completeness of collection of ByteStringRanges. */
 public class ByteStringRangeHelper {
-
   /**
    * Returns formatted string of a partition for debugging.
    *
@@ -36,4 +40,145 @@
         + TextFormat.escapeBytes(partition.getEnd())
         + "')";
   }
+
+  /**
+   * Convert partitions to a string for debugging.
+   *
+   * @param partitions to print
+   * @return string representation of partitions
+   */
+  public static String partitionsToString(List<ByteStringRange> partitions) {
+    return partitions.stream()
+        .map(ByteStringRangeHelper::formatByteStringRange)
+        .collect(Collectors.joining(", ", "{", "}"));
+  }
+
+  /**
+   * Orders partitions by start key, then by end key. An end key of "" represents the end of the
+   * keyspace, so it sorts after every other end key.
+   */
+  @VisibleForTesting
+  static class PartitionComparator implements Comparator<ByteStringRange> {
+    // if first > second, it returns positive number
+    // if first < second, it returns negative number
+    // if first == second, it returns 0
+    // First is greater than second if either of the following are true:
+    // - Its start key comes after second's start key
+    // - The start keys are equal and its end key comes after second's end key
+    @Override
+    public int compare(ByteStringRange first, ByteStringRange second) {
+      int compareStart =
+          ByteString.unsignedLexicographicalComparator()
+              .compare(first.getStart(), second.getStart());
+      if (compareStart != 0) {
+        return compareStart;
+      }
+      if (first.getEnd().isEmpty() && !second.getEnd().isEmpty()) {
+        return 1;
+      }
+      if (second.getEnd().isEmpty() && !first.getEnd().isEmpty()) {
+        return -1;
+      }
+      return ByteString.unsignedLexicographicalComparator()
+          .compare(first.getEnd(), second.getEnd());
+    }
+  }
+
+  /**
+   * Returns whichever end key reaches further into the keyspace. An end key of "" is the end
+   * of the keyspace, so it is further than every other end key.
+   */
+  private static ByteString laterEndKey(ByteString first, ByteString second) {
+    if (first.isEmpty() || second.isEmpty()) {
+      return ByteString.EMPTY;
+    }
+    return ByteString.unsignedLexicographicalComparator().compare(first, second) >= 0
+        ? first
+        : second;
+  }
+
+  private static boolean childStartsBeforeParent(
+      ByteString parentStartKey, ByteString childStartKey) {
+    // Check if the start key of the child partition comes before the start key of the entire
+    // parentPartitions
+    return ByteString.unsignedLexicographicalComparator().compare(parentStartKey, childStartKey)
+        > 0;
+  }
+
+  private static boolean childEndsAfterParent(ByteString parentEndKey, ByteString childEndKey) {
+    // A final end key is represented by "" but this evaluates to < all characters, so we need to
+    // handle it as a special case.
+    if (childEndKey.isEmpty() && !parentEndKey.isEmpty()) {
+      return true;
+    }
+
+    // Check if the end key of the child partition comes after the end key of the entire
+    // parentPartitions. "" Represents the final end key so we need to handle that as a
+    // special case when it is the end key of the entire parentPartitions
+    return ByteString.unsignedLexicographicalComparator().compare(parentEndKey, childEndKey) < 0
+        && !parentEndKey.isEmpty();
+  }
+
+  // This assumes the partitions are sorted with PartitionComparator. If they have not already
+  // been sorted the result will be incorrect.
+  private static boolean gapsInParentPartitions(List<ByteStringRange> sortedParentPartitions) {
+    if (sortedParentPartitions.isEmpty()) {
+      return false;
+    }
+    // Track the furthest end key covered so far rather than just the previous partition's end:
+    // an overlapping partition can end before an earlier, wider partition does.
+    ByteString coveredEndKey = sortedParentPartitions.get(0).getEnd();
+    for (int i = 1; i < sortedParentPartitions.size(); i++) {
+      if (coveredEndKey.isEmpty()) {
+        // "" is the end of the keyspace, so everything that follows is already covered.
+        return false;
+      }
+      ByteStringRange partition = sortedParentPartitions.get(i);
+      if (ByteString.unsignedLexicographicalComparator()
+              .compare(partition.getStart(), coveredEndKey)
+          > 0) {
+        return true;
+      }
+      coveredEndKey = laterEndKey(coveredEndKey, partition.getEnd());
+    }
+    return false;
+  }
+
+  /**
+   * Returns true if parentPartitions is a superset of childPartition.
+   *
+   * <p>If ordered parentPartitions row ranges form a contiguous range, and start key is before or
+   * at childPartition's start key, and end key is at or after childPartition's end key, then
+   * parentPartitions is a superset of childPartition.
+   *
+   * <p>Overlaps from parents are valid because arbitrary partitions can merge and they may overlap.
+   * They will form a valid new partition. However, if there are any missing parent partitions, then
+   * merge cannot happen with missing row ranges.
+   *
+   * <p>parentPartitions is not modified, so an immutable list may be passed.
+   *
+   * @param parentPartitions list of partitions to determine if it forms a large contiguous range
+   * @param childPartition the smaller partition
+   * @return true if parentPartitions is a superset of childPartition, otherwise false.
+   */
+  public static boolean isSuperset(
+      List<ByteStringRange> parentPartitions, ByteStringRange childPartition) {
+    if (parentPartitions.isEmpty()) {
+      return false;
+    }
+    // Sort a copy so the caller's list is neither reordered nor required to be mutable.
+    List<ByteStringRange> sortedParents =
+        parentPartitions.stream().sorted(new PartitionComparator()).collect(Collectors.toList());
+    // After sorting, the first partition has the smallest start key. With overlaps, the last
+    // partition's end key is not necessarily the furthest, so take the furthest of all of them.
+    ByteString parentStartKey = sortedParents.get(0).getStart();
+    ByteString parentEndKey = sortedParents.get(0).getEnd();
+    for (ByteStringRange parent : sortedParents) {
+      parentEndKey = laterEndKey(parentEndKey, parent.getEnd());
+    }
+
+    return !childStartsBeforeParent(parentStartKey, childPartition.getStart())
+        && !childEndsAfterParent(parentEndKey, childPartition.getEnd())
+        && !gapsInParentPartitions(sortedParents);
+  }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java
index ed14eb5..b7b24c4 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java
@@ -48,6 +48,14 @@
           "heartbeat_count");
 
   /**
+   * Counter for the total number of CloseStream records identified by the Connector.
+   */
+  public static final Counter CLOSESTREAM_COUNT =
+      Metrics.counter(
+          org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics.class,
+          "closestream_count");
+
+  /**
    * Counter for the total number of ChangeStreamMutations that are initiated by users (not garbage
    * collection) identified during the execution of the Connector.
    */
@@ -71,6 +79,12 @@
           org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics.class,
           "processing_delay_from_commit_timestamp");
 
+  /** Counter for the total number of active partitions being streamed. */
+  public static final Counter PARTITION_STREAM_COUNT =
+      Metrics.counter(
+          org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics.class,
+          "partition_stream_count");
+
   /**
    * Increments the {@link
    * org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#LIST_PARTITIONS_COUNT} by
@@ -91,6 +105,15 @@
 
   /**
    * Increments the {@link
+   * org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#CLOSESTREAM_COUNT} by 1
+   * if the metric is enabled.
+   */
+  public void incClosestreamCount() {
+    inc(CLOSESTREAM_COUNT);
+  }
+
+  /**
+   * Increments the {@link
    * org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#CHANGE_STREAM_MUTATION_USER_COUNT}
    * by 1 if the metric is enabled.
    */
@@ -108,6 +131,24 @@
   }
 
   /**
+   * Increments the {@link
+   * org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#PARTITION_STREAM_COUNT}
+   * by 1.
+   */
+  public void incPartitionStreamCount() {
+    inc(PARTITION_STREAM_COUNT);
+  }
+
+  /**
+   * Decrements the {@link
+   * org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#PARTITION_STREAM_COUNT}
+   * by 1.
+   */
+  public void decPartitionStreamCount() {
+    dec(PARTITION_STREAM_COUNT);
+  }
+
+  /**
    * Adds measurement of an instance for the {@link
    * org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#PROCESSING_DELAY_FROM_COMMIT_TIMESTAMP}.
    */
@@ -119,6 +160,10 @@
     counter.inc();
   }
 
+  private void dec(Counter counter) {
+    counter.dec();
+  }
+
   private void update(Distribution distribution, long value) {
     distribution.update(value);
   }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java
index e64cd6f..5f37847 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java
@@ -22,10 +22,12 @@
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord;
+import com.google.cloud.bigtable.data.v2.models.CloseStream;
 import com.google.cloud.bigtable.data.v2.models.Heartbeat;
 import com.google.cloud.bigtable.data.v2.models.Range;
 import com.google.protobuf.ByteString;
 import java.util.Optional;
+import java.util.stream.Collectors;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.TimestampConverter;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.model.PartitionRecord;
@@ -135,6 +137,38 @@
         return Optional.of(DoFn.ProcessContinuation.stop());
       }
       metrics.incHeartbeatCount();
+    } else if (record instanceof CloseStream) {
+      CloseStream closeStream = (CloseStream) record;
+      StreamProgress streamProgress = new StreamProgress(closeStream);
+
+      if (shouldDebug) {
+        LOG.info(
+            "RCSP {}: CloseStream: {}",
+            formatByteStringRange(partitionRecord.getPartition()),
+            closeStream.getChangeStreamContinuationTokens().stream()
+                .map(
+                    c ->
+                        "{partition: "
+                            + formatByteStringRange(c.getPartition())
+                            + " token: "
+                            + c.getToken()
+                            + "}")
+                .collect(Collectors.joining(", ", "[", "]")));
+      }
+      // If the tracker fails to claim the streamProgress, it most likely means the runner initiated
+      // a checkpoint. See {@link
+      // org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.ReadChangeStreamPartitionProgressTracker}
+      // for more information regarding runner-initiated checkpoints.
+      if (!tracker.tryClaim(streamProgress)) {
+        if (shouldDebug) {
+          LOG.info(
+              "RCSP {}: Failed to claim close stream tracker",
+              formatByteStringRange(partitionRecord.getPartition()));
+        }
+        return Optional.of(DoFn.ProcessContinuation.stop());
+      }
+      metrics.incClosestreamCount();
+      return Optional.of(DoFn.ProcessContinuation.resume());
     } else if (record instanceof ChangeStreamMutation) {
       ChangeStreamMutation changeStreamMutation = (ChangeStreamMutation) record;
       final Instant watermark =
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java
index a36b57f..8a49788 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java
@@ -17,11 +17,21 @@
  */
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.action;
 
+import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.formatByteStringRange;
+import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.isSuperset;
+import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.partitionsToString;
+
 import com.google.api.gax.rpc.ServerStream;
+import com.google.cloud.bigtable.common.Status;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord;
+import com.google.cloud.bigtable.data.v2.models.CloseStream;
+import com.google.cloud.bigtable.data.v2.models.Range;
 import com.google.protobuf.ByteString;
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Optional;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.ChangeStreamDao;
@@ -126,6 +136,47 @@
               + tracker.currentRestriction());
     }
 
+    // Handle a CloseStream in the current restriction; every branch below stops this partition.
+    CloseStream closeStream = tracker.currentRestriction().getCloseStream();
+    if (closeStream != null) {
+      if (closeStream.getStatus().getCode() != Status.Code.OUT_OF_RANGE) {
+        LOG.error(
+            "RCSP {}: Reached unexpected terminal state: {}",
+            formatByteStringRange(partitionRecord.getPartition()),
+            closeStream.getStatus().toString());
+        metrics.decPartitionStreamCount();
+        return ProcessContinuation.stop();
+      }
+      // The partitions in the continuation tokens should be a superset of this partition.
+      // If there's only one token, then the token's partition should be a superset of this one.
+      // If there is more than one token, the tokens should form a contiguous row range that is
+      // a superset of this partition.
+      List<Range.ByteStringRange> partitions = new ArrayList<>();
+      for (ChangeStreamContinuationToken changeStreamContinuationToken :
+          closeStream.getChangeStreamContinuationTokens()) {
+        partitions.add(changeStreamContinuationToken.getPartition());
+        metadataTableDao.writeNewPartition(
+            changeStreamContinuationToken,
+            partitionRecord.getPartition(),
+            watermarkEstimator.getState());
+      }
+      if (shouldDebug) {
+        LOG.info(
+            "RCSP {}: Split/Merge into {}",
+            formatByteStringRange(partitionRecord.getPartition()),
+            partitionsToString(partitions));
+      }
+      if (!isSuperset(partitions, partitionRecord.getPartition())) {
+        LOG.warn(
+            "RCSP {}: CloseStream has child partition(s) {} that doesn't cover the keyspace",
+            formatByteStringRange(partitionRecord.getPartition()),
+            partitionsToString(partitions));
+      }
+      metadataTableDao.deleteStreamPartitionRow(partitionRecord.getPartition());
+      metrics.decPartitionStreamCount();
+      return ProcessContinuation.stop();
+    }
+
     // Update the metadata table with the watermark
     metadataTableDao.updateWatermark(
         partitionRecord.getPartition(),
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java
index 3c3e828..fd7c3ae 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java
@@ -21,9 +21,13 @@
 import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableAdminDao.NEW_PARTITION_PREFIX;
 import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableAdminDao.STREAM_PARTITION_PREFIX;
 
+import com.google.api.gax.rpc.ServerStream;
 import com.google.cloud.bigtable.data.v2.BigtableDataClient;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.Filters;
+import com.google.cloud.bigtable.data.v2.models.Query;
 import com.google.cloud.bigtable.data.v2.models.Range;
+import com.google.cloud.bigtable.data.v2.models.Row;
 import com.google.cloud.bigtable.data.v2.models.RowMutation;
 import com.google.protobuf.ByteString;
 import javax.annotation.Nullable;
@@ -96,6 +100,78 @@
   }
 
   /**
+   * Convert partition to a New Partition row key to query for partitions ready to be streamed as
+   * the result of splits and merges.
+   *
+   * @param partition convert to row key
+   * @return row key to insert to Cloud Bigtable.
+   */
+  public ByteString convertPartitionToNewPartitionRowKey(Range.ByteStringRange partition) {
+    return getFullNewPartitionPrefix().concat(Range.ByteStringRange.toByteString(partition));
+  }
+
+  /**
+   * Reads the new partitions resulting from splits and merges that are waiting to be streamed.
+   * @return stream of those new partition rows, limited to the latest cell of each column.
+   */
+  public ServerStream<Row> readNewPartitions() {
+    // It's important that we limit to the latest value per column because it's possible to write to
+    // the same column multiple times. We don't want to read and send duplicate tokens to the
+    // server.
+    Query query =
+        Query.create(tableId)
+            .prefix(getFullNewPartitionPrefix())
+            .filter(Filters.FILTERS.limit().cellsPerColumn(1));
+    return dataClient.readRows(query);
+  }
+
+  /**
+   * After a split or merge from a close stream, write the new partition's information to the
+   * metadata table.
+   *
+   * @param changeStreamContinuationToken the token that can be used to pick up from where the
+   *     parent left off
+   * @param parentPartition the parent that stopped and split or merged
+   * @param lowWatermark the low watermark of the parent stream
+   */
+  public void writeNewPartition(
+      ChangeStreamContinuationToken changeStreamContinuationToken,
+      Range.ByteStringRange parentPartition,
+      Instant lowWatermark) {
+    writeNewPartition(
+        changeStreamContinuationToken.getPartition(),
+        changeStreamContinuationToken.toByteString(),
+        Range.ByteStringRange.toByteString(parentPartition),
+        lowWatermark);
+  }
+
+  /**
+   * After a split or merge from a close stream, write the new partition's information to the
+   * metadata table.
+   *
+   * @param newPartition the new partition
+   * @param newPartitionContinuationToken continuation token for the new partition
+   * @param parentPartition the parent that stopped
+   * @param lowWatermark low watermark of the parent
+   */
+  private void writeNewPartition(
+      Range.ByteStringRange newPartition,
+      ByteString newPartitionContinuationToken,
+      ByteString parentPartition,
+      Instant lowWatermark) {
+    ByteString rowKey = convertPartitionToNewPartitionRowKey(newPartition);
+    RowMutation rowMutation =
+        RowMutation.create(tableId, rowKey)
+            .setCell(MetadataTableAdminDao.CF_INITIAL_TOKEN, newPartitionContinuationToken, 1)
+            .setCell(MetadataTableAdminDao.CF_PARENT_PARTITIONS, parentPartition, 1)
+            .setCell(
+                MetadataTableAdminDao.CF_PARENT_LOW_WATERMARKS,
+                parentPartition,
+                ByteString.copyFromUtf8(Long.toString(lowWatermark.getMillis())));
+    dataClient.mutateRow(rowMutation);
+  }
+
+  /**
    * Update the metadata for the rowKey. This helper adds necessary prefixes to the row key.
    *
    * @param rowKey row key of the row to update
@@ -135,6 +211,18 @@
   }
 
   /**
+   * Delete the row key represented by the partition. This represents that the partition will no
+   * longer be streamed.
+   *
+   * @param partition forms the row key of the row to delete
+   */
+  public void deleteStreamPartitionRow(Range.ByteStringRange partition) {
+    ByteString rowKey = convertPartitionToStreamPartitionRowKey(partition);
+    RowMutation rowMutation = RowMutation.create(tableId, rowKey).deleteRow();
+    dataClient.mutateRow(rowMutation);
+  }
+
+  /**
    * Set the version number for DetectNewPartition. This value can be checked later to verify that
    * the existing metadata table is compatible with current beam connector code.
    */
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java
index a7871dd..e366507 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java
@@ -82,6 +82,7 @@
 
   @GetInitialRestriction
   public StreamProgress initialRestriction() {
+    metrics.incPartitionStreamCount();
     return new StreamProgress();
   }
 
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTracker.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTracker.java
index f615889..5eeb096 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTracker.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTracker.java
@@ -79,7 +79,7 @@
    */
   @Override
   public void checkDone() throws java.lang.IllegalStateException {
-    boolean done = shouldStop;
+    boolean done = shouldStop || streamProgress.getCloseStream() != null;
     Preconditions.checkState(done, "There's more work to be done");
   }
 
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
index ef35a04..c594af3 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction;
 
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.CloseStream;
 import com.google.protobuf.Timestamp;
 import java.io.Serializable;
 import java.util.Objects;
@@ -36,9 +37,10 @@
  */
 @Internal
 public class StreamProgress implements Serializable {
-  private static final long serialVersionUID = -5384329262726188695L;
+  private static final long serialVersionUID = -8597355120329526194L;
 
   private @Nullable ChangeStreamContinuationToken currentToken;
+  private @Nullable CloseStream closeStream;
   private @Nullable Timestamp lowWatermark;
 
   public @Nullable ChangeStreamContinuationToken getCurrentToken() {
@@ -49,6 +51,10 @@
     return lowWatermark;
   }
 
+  public @Nullable CloseStream getCloseStream() {
+    return closeStream;
+  }
+
   public StreamProgress() {}
 
   public StreamProgress(@Nullable ChangeStreamContinuationToken token, Timestamp lowWatermark) {
@@ -56,6 +62,10 @@
     this.lowWatermark = lowWatermark;
   }
 
+  public StreamProgress(@Nullable CloseStream closeStream) {
+    this.closeStream = closeStream;
+  }
+
   @Override
   public boolean equals(@Nullable Object o) {
     if (this == o) {
@@ -66,7 +76,8 @@
     }
     StreamProgress that = (StreamProgress) o;
     return Objects.equals(getCurrentToken(), that.getCurrentToken())
-        && Objects.equals(getLowWatermark(), that.getLowWatermark());
+        && Objects.equals(getLowWatermark(), that.getLowWatermark())
+        && Objects.equals(getCloseStream(), that.getCloseStream());
   }
 
   @Override
@@ -81,6 +92,8 @@
         + currentToken
         + ", lowWatermark="
         + lowWatermark
+        + ", closeStream="
+        + closeStream
         + '}';
   }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelperTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelperTest.java
new file mode 100644
index 0000000..78e63a1
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelperTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams;
+
+import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.formatByteStringRange;
+import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.partitionsToString;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.junit.Test;
+
+public class ByteStringRangeHelperTest {
+
+  @Test
+  public void testParentIsEntireKeySpaceIsSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("", "");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("A", "B");
+
+    assertTrue(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testChildIsEntireKeySpaceParentIsLeftSubSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("", "n");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("", "");
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testChildIsEntireKeySpaceParentIsRightSubSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("n", "");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("", "");
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testChildIsEntireKeySpaceParentIsSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition1 = ByteStringRange.create("", "n");
+    ByteStringRange partition2 = ByteStringRange.create("n", "");
+    parentPartitions.add(partition1);
+    parentPartitions.add(partition2);
+
+    ByteStringRange childPartition = ByteStringRange.create("", "");
+    assertTrue(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentKeySpaceStartsBeforeAndEndAfterChildIsSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("A", "B");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("AA", "AB");
+
+    assertTrue(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentStartKeyIsAfterChildStartKeyIsNotSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("AA", "B");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("A", "AB");
+
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentEndKeyIsBeforeChildEndKeyIsNotSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("A", "B");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("AA", "BA");
+
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentIsSameAsChildIsSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("A", "B");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("A", "B");
+
+    assertTrue(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentIsMissingPartitionIsNotSuperSet() {
+    ByteStringRange partition1 = ByteStringRange.create("A", "B");
+    ByteStringRange partition2 = ByteStringRange.create("C", "Z");
+    List<ByteStringRange> parentPartitions = Arrays.asList(partition1, partition2);
+
+    ByteStringRange childPartition = ByteStringRange.create("A", "Z");
+
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentHasOverlapIsSuperSet() {
+    ByteStringRange partition1 = ByteStringRange.create("A", "C");
+    ByteStringRange partition2 = ByteStringRange.create("B", "Z");
+    List<ByteStringRange> parentPartitions = Arrays.asList(partition1, partition2);
+
+    ByteStringRange childPartition = ByteStringRange.create("A", "Z");
+
+    assertTrue(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testEmptyParentsIsNotSuperset() {
+    List<ByteStringRange> parentPartitions = Collections.emptyList();
+    ByteStringRange childPartition = ByteStringRange.create("", "");
+
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testPartitionsToString() {
+    ByteStringRange partition1 = ByteStringRange.create("", "A");
+    ByteStringRange partition2 = ByteStringRange.create("A", "B");
+    ByteStringRange partition3 = ByteStringRange.create("B", "");
+    List<ByteStringRange> partitions = Arrays.asList(partition1, partition2, partition3);
+    String partitionsString = partitionsToString(partitions);
+    assertEquals(
+        String.format(
+            "{%s, %s, %s}",
+            formatByteStringRange(partition1),
+            formatByteStringRange(partition2),
+            formatByteStringRange(partition3)),
+        partitionsString);
+  }
+
+  @Test
+  public void testPartitionsToStringEmptyPartition() {
+    List<ByteStringRange> partitions = new ArrayList<>();
+    String partitionsString = partitionsToString(partitions);
+    assertEquals("{}", partitionsString);
+  }
+
+  @Test
+  public void testPartitionComparator() {
+    ByteStringRange partition1 = ByteStringRange.create("", "a");
+    ByteStringRange partition2 = ByteStringRange.create("", "");
+    ByteStringRange partition3 = ByteStringRange.create("a", "z");
+    ByteStringRange partition4 = ByteStringRange.create("a", "");
+    List<ByteStringRange> unsorted = Arrays.asList(partition3, partition4, partition2, partition1);
+    List<ByteStringRange> sorted = Arrays.asList(partition1, partition2, partition3, partition4);
+    unsorted.sort(new ByteStringRangeHelper.PartitionComparator());
+    assertEquals(sorted, unsorted);
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java
index b2c8548..46453a4 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java
@@ -17,7 +17,9 @@
  */
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.action;
 
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
@@ -27,10 +29,13 @@
 
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
+import com.google.cloud.bigtable.data.v2.models.CloseStream;
 import com.google.cloud.bigtable.data.v2.models.Heartbeat;
 import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.Timestamp;
+import com.google.rpc.Status;
+import java.util.Collections;
 import java.util.Optional;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.TimestampConverter;
@@ -89,6 +94,27 @@
   }
 
   @Test
+  public void testCloseStreamResume() {
+    // OUT_OF_RANGE (code 11) with a continuation token should be claimed and resumed.
+    ChangeStreamContinuationToken token =
+        new ChangeStreamContinuationToken(ByteStringRange.create("a", "b"), "1234");
+    CloseStream closeStream = mock(CloseStream.class);
+    Status statusProto = Status.newBuilder().setCode(11).build();
+    when(closeStream.getStatus())
+        .thenReturn(com.google.cloud.bigtable.common.Status.fromProto(statusProto));
+    when(closeStream.getChangeStreamContinuationTokens())
+        .thenReturn(Collections.singletonList(token));
+
+    Optional<DoFn.ProcessContinuation> result =
+        action.run(partitionRecord, closeStream, tracker, receiver, watermarkEstimator, false);
+
+    assertTrue(result.isPresent());
+    assertEquals(DoFn.ProcessContinuation.resume(), result.get());
+    verify(metrics).incClosestreamCount();
+    verify(tracker).tryClaim(eq(new StreamProgress(closeStream)));
+  }
+
+  @Test
   public void testChangeStreamMutationUser() {
     ByteStringRange partition = ByteStringRange.create("", "");
     when(partitionRecord.getPartition()).thenReturn(partition);
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionActionTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionActionTest.java
new file mode 100644
index 0000000..bea8728
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionActionTest.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams.action;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.api.gax.rpc.ServerStream;
+import com.google.cloud.Timestamp;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord;
+import com.google.cloud.bigtable.data.v2.models.CloseStream;
+import com.google.cloud.bigtable.data.v2.models.Heartbeat;
+import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
+import com.google.protobuf.ByteString;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.Optional;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.ChangeStreamDao;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableDao;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.model.PartitionRecord;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.ReadChangeStreamPartitionProgressTracker;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.StreamProgress;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.values.KV;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+public class ReadChangeStreamPartitionActionTest {
+
+  private ReadChangeStreamPartitionAction action;
+
+  private MetadataTableDao metadataTableDao;
+  private ChangeStreamDao changeStreamDao;
+  private ChangeStreamMetrics metrics;
+  private ChangeStreamAction changeStreamAction;
+
+  // Mocked restriction, tracker, receiver and watermark estimator shared by every test.
+  private StreamProgress restriction;
+  private RestrictionTracker<StreamProgress, StreamProgress> tracker;
+  private DoFn.OutputReceiver<KV<ByteString, ChangeStreamMutation>> receiver;
+  private ManualWatermarkEstimator<Instant> watermarkEstimator;
+
+  private ByteStringRange partition;
+  private String uuid;
+  private PartitionRecord partitionRecord;
+
+  @Before
+  public void setUp() throws Exception {
+    metadataTableDao = mock(MetadataTableDao.class);
+    changeStreamDao = mock(ChangeStreamDao.class);
+    metrics = mock(ChangeStreamMetrics.class);
+    changeStreamAction = mock(ChangeStreamAction.class);
+    Duration heartbeatDurationSeconds = Duration.standardSeconds(1);
+
+    action =
+        new ReadChangeStreamPartitionAction(
+            metadataTableDao,
+            changeStreamDao,
+            metrics,
+            changeStreamAction,
+            heartbeatDurationSeconds);
+
+    restriction = mock(StreamProgress.class);
+    tracker = mock(ReadChangeStreamPartitionProgressTracker.class);
+    receiver = mock(DoFn.OutputReceiver.class);
+    watermarkEstimator = mock(ManualWatermarkEstimator.class);
+
+    partition = ByteStringRange.create("A", "B");
+    uuid = "123456";
+    Timestamp startTime = Timestamp.now();
+    Timestamp parentLowWatermark = Timestamp.now();
+    partitionRecord = new PartitionRecord(partition, startTime, uuid, parentLowWatermark);
+    when(tracker.currentRestriction()).thenReturn(restriction);
+    when(restriction.getCurrentToken()).thenReturn(null);
+    when(restriction.getCloseStream()).thenReturn(null);
+    // Pin the watermark to now; presumably keeps run() off its lagging-watermark debug path.
+    when(watermarkEstimator.getState()).thenReturn(Instant.now());
+  }
+
+  /** Returns a mock {@link CloseStream} whose status carries the given google.rpc code. */
+  private static CloseStream mockCloseStreamWithStatus(int statusCode) {
+    CloseStream mockCloseStream = Mockito.mock(CloseStream.class);
+    Status statusProto = Status.newBuilder().setCode(statusCode).build();
+    Mockito.when(mockCloseStream.getStatus())
+        .thenReturn(com.google.cloud.bigtable.common.Status.fromProto(statusProto));
+    return mockCloseStream;
+  }
+
+  @Test
+  public void testThatChangeStreamWorkerCounterIsIncrementedOnInitialRun() throws IOException {
+    // Return null token to indicate that this is the first ever run.
+    when(restriction.getCurrentToken()).thenReturn(null);
+    when(restriction.getCloseStream()).thenReturn(null);
+
+    final ServerStream<ChangeStreamRecord> responses = mock(ServerStream.class);
+    final Iterator<ChangeStreamRecord> responseIterator = mock(Iterator.class);
+    when(responses.iterator()).thenReturn(responseIterator);
+
+    Heartbeat mockHeartBeat = Mockito.mock(Heartbeat.class);
+    when(responseIterator.next()).thenReturn(mockHeartBeat);
+    when(responseIterator.hasNext()).thenReturn(true);
+    when(changeStreamDao.readChangeStreamPartition(any(), any(), any(), anyBoolean()))
+        .thenReturn(responses);
+
+    when(changeStreamAction.run(any(), any(), any(), any(), any(), anyBoolean()))
+        .thenReturn(Optional.of(DoFn.ProcessContinuation.stop()));
+
+    final DoFn.ProcessContinuation result =
+        action.run(partitionRecord, tracker, receiver, watermarkEstimator);
+    assertEquals(DoFn.ProcessContinuation.stop(), result);
+    verify(changeStreamAction).run(any(), any(), any(), any(), any(), anyBoolean());
+  }
+
+  @Test
+  public void testCloseStreamTerminateOKStatus() throws IOException {
+    CloseStream mockCloseStream = mockCloseStreamWithStatus(0); // OK
+    when(restriction.getCloseStream()).thenReturn(mockCloseStream);
+    final DoFn.ProcessContinuation result =
+        action.run(partitionRecord, tracker, receiver, watermarkEstimator);
+    assertEquals(DoFn.ProcessContinuation.stop(), result);
+    // Should terminate before processing any stream partition responses.
+    verify(changeStreamAction, never()).run(any(), any(), any(), any(), any(), anyBoolean());
+    // Should decrement the metric on termination.
+    verify(metrics).decPartitionStreamCount();
+    // Should not try to write any new partition to the metadata table.
+    verify(metadataTableDao, never()).writeNewPartition(any(), any(), any());
+    verify(metadataTableDao, never()).deleteStreamPartitionRow(any());
+  }
+
+  @Test
+  public void testCloseStreamTerminateNotOutOfRangeStatus() throws IOException {
+    // Any code other than OUT_OF_RANGE (11) should terminate; 10 (ABORTED) is used here.
+    CloseStream mockCloseStream = mockCloseStreamWithStatus(10);
+    when(restriction.getCloseStream()).thenReturn(mockCloseStream);
+    final DoFn.ProcessContinuation result =
+        action.run(partitionRecord, tracker, receiver, watermarkEstimator);
+    assertEquals(DoFn.ProcessContinuation.stop(), result);
+    // Should terminate before processing any stream partition responses.
+    verify(changeStreamAction, never()).run(any(), any(), any(), any(), any(), anyBoolean());
+    // Should decrement the metric on termination.
+    verify(metrics).decPartitionStreamCount();
+    // Should not try to write any new partition to the metadata table.
+    verify(metadataTableDao, never()).writeNewPartition(any(), any(), any());
+    verify(metadataTableDao, never()).deleteStreamPartitionRow(any());
+  }
+
+  @Test
+  public void testCloseStreamWritesContinuationTokens() throws IOException {
+    ChangeStreamContinuationToken changeStreamContinuationToken1 =
+        new ChangeStreamContinuationToken(ByteStringRange.create("A", "AJ"), "1234");
+    ChangeStreamContinuationToken changeStreamContinuationToken2 =
+        new ChangeStreamContinuationToken(ByteStringRange.create("AJ", "B"), "5678");
+
+    CloseStream mockCloseStream = mockCloseStreamWithStatus(11); // OUT_OF_RANGE
+    Mockito.when(mockCloseStream.getChangeStreamContinuationTokens())
+        .thenReturn(Arrays.asList(changeStreamContinuationToken1, changeStreamContinuationToken2));
+
+    when(restriction.getCloseStream()).thenReturn(mockCloseStream);
+    final DoFn.ProcessContinuation result =
+        action.run(partitionRecord, tracker, receiver, watermarkEstimator);
+    assertEquals(DoFn.ProcessContinuation.stop(), result);
+    // Should terminate before processing any stream partition responses.
+    verify(changeStreamAction, never()).run(any(), any(), any(), any(), any(), anyBoolean());
+    // Should decrement the metric on termination.
+    verify(metrics).decPartitionStreamCount();
+    // Should write one new partition per continuation token and delete the parent's row.
+    verify(metadataTableDao).writeNewPartition(eq(changeStreamContinuationToken1), any(), any());
+    verify(metadataTableDao).writeNewPartition(eq(changeStreamContinuationToken2), any(), any());
+    verify(metadataTableDao, times(1)).deleteStreamPartitionRow(partitionRecord.getPartition());
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java
index 8bb2381..fbea020 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java
@@ -18,7 +18,9 @@
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
+import com.google.api.gax.rpc.ServerStream;
 import com.google.cloud.bigtable.admin.v2.BigtableTableAdminClient;
 import com.google.cloud.bigtable.admin.v2.BigtableTableAdminSettings;
 import com.google.cloud.bigtable.data.v2.BigtableDataClient;
@@ -27,6 +29,8 @@
 import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
 import com.google.cloud.bigtable.data.v2.models.Row;
 import com.google.cloud.bigtable.emulator.v2.BigtableEmulatorRule;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.InvalidProtocolBufferException;
 import java.io.IOException;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.UniqueIdGenerator;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.encoder.MetadataTableEncoder;
@@ -81,6 +85,46 @@
   }
 
   @Test
+  public void testNewPartitionsWriteRead() throws InvalidProtocolBufferException {
+    // This tests a split of ["", "") into ["", "a") and ["a", "").
+    ByteStringRange parentPartition = ByteStringRange.create("", "");
+    ByteStringRange partition1 = ByteStringRange.create("", "a");
+    ChangeStreamContinuationToken changeStreamContinuationToken1 =
+        new ChangeStreamContinuationToken(partition1, "tk1");
+    ByteStringRange partition2 = ByteStringRange.create("a", "");
+    ChangeStreamContinuationToken changeStreamContinuationToken2 =
+        new ChangeStreamContinuationToken(partition2, "tk2");
+
+    Instant lowWatermark = Instant.now();
+    metadataTableDao.writeNewPartition(
+        changeStreamContinuationToken1, parentPartition, lowWatermark);
+    metadataTableDao.writeNewPartition(
+        changeStreamContinuationToken2, parentPartition, lowWatermark);
+
+    // Each new-partition row key is decoded as this prefix followed by the serialized partition.
+    // The prefix does not depend on the row, so compute it once outside the loop.
+    ByteString newPartitionPrefix =
+        metadataTableDao
+            .getChangeStreamNamePrefix()
+            .concat(MetadataTableAdminDao.NEW_PARTITION_PREFIX);
+
+    ServerStream<Row> rows = metadataTableDao.readNewPartitions();
+    int rowsCount = 0;
+    boolean matchedPartition1 = false;
+    boolean matchedPartition2 = false;
+    for (Row row : rows) {
+      rowsCount++;
+      ByteStringRange partition =
+          ByteStringRange.toByteStringRange(row.getKey().substring(newPartitionPrefix.size()));
+      matchedPartition1 |= partition.equals(partition1);
+      matchedPartition2 |= partition.equals(partition2);
+    }
+    assertTrue(matchedPartition1);
+    assertTrue(matchedPartition2);
+    assertEquals(2, rowsCount);
+  }
+
+  @Test
   public void testUpdateWatermark() {
     ByteStringRange partition = ByteStringRange.create("a", "b");
     Instant watermark = Instant.now();