Merge pull request #14704: [BEAM-12253] Change Read.UnboundedSourceAsSDFRestrictionTracker.getSplitBacklog to use the reader cache

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
index 09464db..241fa1f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
@@ -63,6 +63,7 @@
 import org.apache.beam.sdk.values.ValueWithRecordId.StripIdsDoFn;
 import org.apache.beam.sdk.values.ValueWithRecordId.ValueWithRecordIdCoder;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalListener;
@@ -827,24 +828,29 @@
                 source, checkpoint, BoundedWindow.TIMESTAMP_MIN_VALUE));
       }
 
+      /** Initializes {@code currentReader} from the reader cache, or creates a new reader. */
+      private void initializeCurrentReader() throws IOException {
+        Preconditions.checkState(currentReader == null, "currentReader is already initialized");
+        Object cacheKey =
+            createCacheKey(initialRestriction.getSource(), initialRestriction.getCheckpoint());
+        // Atomically take the entry so it cannot be evicted between lookup and removal.
+        currentReader = cachedReaders.asMap().remove(cacheKey);
+        if (currentReader == null) {
+          currentReader =
+              initialRestriction
+                  .getSource()
+                  .createReader(pipelineOptions, initialRestriction.getCheckpoint());
+        } else {
+          // A reader obtained from the cache is known to have been started already.
+          readerHasBeenStarted = true;
+        }
+      }
+
       @Override
       public boolean tryClaim(UnboundedSourceValue<OutputT>[] position) {
         try {
           if (currentReader == null) {
-            Object cacheKey =
-                createCacheKey(initialRestriction.getSource(), initialRestriction.getCheckpoint());
-            currentReader = cachedReaders.getIfPresent(cacheKey);
-            if (currentReader == null) {
-              currentReader =
-                  initialRestriction
-                      .getSource()
-                      .createReader(pipelineOptions, initialRestriction.getCheckpoint());
-            } else {
-              // If the reader is from cache, then we know that the reader has been started.
-              // We also remove this cache entry to avoid eviction.
-              readerHasBeenStarted = true;
-              cachedReaders.invalidate(cacheKey);
-            }
+            initializeCurrentReader();
           }
           if (currentReader instanceof EmptyUnboundedSource.EmptyUnboundedReader) {
             return false;
@@ -872,6 +878,8 @@
               currentReader.close();
             } catch (IOException closeException) {
               e.addSuppressed(closeException);
+            } finally {
+              currentReader = null;
             }
           }
           throw new RuntimeException(e);
@@ -957,10 +965,7 @@
 
         if (currentReader == null) {
           try {
-            currentReader =
-                initialRestriction
-                    .getSource()
-                    .createReader(pipelineOptions, initialRestriction.getCheckpoint());
+            initializeCurrentReader();
           } catch (IOException e) {
             throw new RuntimeException(e);
           }