[FLINK-35580] Fix synchronization issue when closing RocksDBWriteBatchWrapper
diff --git a/flink-core/src/main/java/org/apache/flink/core/fs/ICloseableRegistry.java b/flink-core/src/main/java/org/apache/flink/core/fs/ICloseableRegistry.java
index ebdb0f1..90fd201 100644
--- a/flink-core/src/main/java/org/apache/flink/core/fs/ICloseableRegistry.java
+++ b/flink-core/src/main/java/org/apache/flink/core/fs/ICloseableRegistry.java
@@ -19,11 +19,9 @@
 package org.apache.flink.core.fs;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.util.ExceptionUtils;
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
  * This class allows to register instances of {@link Closeable}, which are all closed if this
@@ -39,19 +37,6 @@
 @Internal
 public interface ICloseableRegistry extends Closeable {
 
-    static Closeable asCloseable(AutoCloseable autoCloseable) {
-        AtomicBoolean closed = new AtomicBoolean(false);
-        return () -> {
-            if (closed.compareAndSet(false, true)) {
-                try {
-                    autoCloseable.close();
-                } catch (Exception e) {
-                    ExceptionUtils.rethrowIOException(e);
-                }
-            }
-        };
-    }
-
     /**
      * Registers a {@link Closeable} with the registry. In case the registry is already closed, this
      * method throws an {@link IllegalStateException} and closes the passed {@link Closeable}.
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index ba4c657..9f1857a 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -84,6 +84,7 @@
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
+import java.io.Closeable;
 import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
@@ -840,7 +841,10 @@
         try (RocksIteratorWrapper iterator =
                         RocksDBOperationUtils.getRocksIterator(db, stateMetaInfo.f0, readOptions);
                 RocksDBWriteBatchWrapper batchWriter =
-                        new RocksDBWriteBatchWrapper(db, getWriteOptions(), getWriteBatchSize())) {
+                        new RocksDBWriteBatchWrapper(db, getWriteOptions(), getWriteBatchSize());
+                Closeable ignored =
+                        cancelStreamRegistry.registerCloseableTemporarily(
+                                writeBatchWrapper.getCancelCloseable())) {
             iterator.seekToFirst();
 
             DataInputDeserializer serializedValueInput = new DataInputDeserializer();
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapper.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapper.java
index 354009e..c2da551 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapper.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapper.java
@@ -19,6 +19,7 @@
 package org.apache.flink.contrib.streaming.state;
 
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.Preconditions;
 
@@ -32,6 +33,7 @@
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
+import java.io.Closeable;
 import java.util.ArrayList;
 import java.util.List;
 
@@ -47,6 +49,9 @@
     private static final int PER_RECORD_BYTES = 100;
     // default 0 for disable memory size based flush
     private static final long DEFAULT_BATCH_SIZE = 0;
+    private static final int DEFAULT_CANCELLATION_CHECK_INTERVAL = MIN_CAPACITY;
+    private static final int DEFAULT_CANCELLATION_CHECK_INTERVAL_BYTES =
+            DEFAULT_CANCELLATION_CHECK_INTERVAL * PER_RECORD_BYTES;
 
     private final RocksDB db;
 
@@ -61,6 +66,14 @@
     /** List of all objects that we need to close in close(). */
     private final List<AutoCloseable> toClose;
 
+    private volatile boolean cancelled;
+
+    private final int cancellationCheckInterval;
+
+    private final long cancellationCheckIntervalBytes;
+
+    private long lastCancellationCheckBatchSize = 0L;
+
     public RocksDBWriteBatchWrapper(@Nonnull RocksDB rocksDB, long writeBatchSize) {
         this(rocksDB, null, 500, writeBatchSize);
     }
@@ -79,6 +92,22 @@
             @Nullable WriteOptions options,
             int capacity,
             long batchSize) {
+        this(
+                rocksDB,
+                options,
+                capacity,
+                batchSize,
+                DEFAULT_CANCELLATION_CHECK_INTERVAL,
+                DEFAULT_CANCELLATION_CHECK_INTERVAL_BYTES);
+    }
+
+    public RocksDBWriteBatchWrapper(
+            @Nonnull RocksDB rocksDB,
+            @Nullable WriteOptions options,
+            int capacity,
+            long batchSize,
+            int cancellationCheckInterval,
+            long cancellationCheckIntervalBytes) {
         Preconditions.checkArgument(
                 capacity >= MIN_CAPACITY && capacity <= MAX_CAPACITY,
                 "capacity should be between " + MIN_CAPACITY + " and " + MAX_CAPACITY);
@@ -104,16 +133,27 @@
             // We own this object, so we must ensure that we close it.
             this.toClose.add(this.options);
         }
+        this.cancellationCheckInterval = cancellationCheckInterval;
+        this.cancellationCheckIntervalBytes = cancellationCheckIntervalBytes;
     }
 
     public void put(@Nonnull ColumnFamilyHandle handle, @Nonnull byte[] key, @Nonnull byte[] value)
             throws RocksDBException {
+        maybeEnsureNotCancelled();
 
         batch.put(handle, key, value);
 
         flushIfNeeded();
     }
 
+    private void maybeEnsureNotCancelled() {
+        if (batch.count() % cancellationCheckInterval == 0
+                || batch.getDataSize() - lastCancellationCheckBatchSize
+                        >= cancellationCheckIntervalBytes) {
+            ensureNotCancelled();
+        }
+    }
+
     public void remove(@Nonnull ColumnFamilyHandle handle, @Nonnull byte[] key)
             throws RocksDBException {
 
@@ -123,8 +163,10 @@
     }
 
     public void flush() throws RocksDBException {
+        ensureNotCancelled();
         db.write(options, batch);
         batch.clear();
+        lastCancellationCheckBatchSize = 0;
     }
 
     @VisibleForTesting
@@ -132,9 +174,18 @@
         return options;
     }
 
+    public void markCancelled() {
+        this.cancelled = true;
+    }
+
+    public Closeable getCancelCloseable() {
+        return this::markCancelled;
+    }
+
     @Override
     public void close() throws RocksDBException {
         try {
+            ensureNotCancelled();
             if (batch.count() != 0) {
                 flush();
             }
@@ -143,6 +194,13 @@
         }
     }
 
+    private void ensureNotCancelled() {
+        if (cancelled) {
+            throw new CancelTaskException();
+        }
+        lastCancellationCheckBatchSize = batch.getDataSize();
+    }
+
     private void flushIfNeeded() throws RocksDBException {
         boolean needFlush =
                 batch.count() == capacity || (batchSize > 0 && getDataSize() >= batchSize);
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java
index fdaf21f..052404b 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java
@@ -52,8 +52,6 @@
 import java.util.Map;
 import java.util.function.Function;
 
-import static org.apache.flink.core.fs.ICloseableRegistry.asCloseable;
-
 /** Encapsulates the process of restoring a RocksDB instance from a full snapshot. */
 public class RocksDBFullRestoreOperation<K> implements RocksDBRestoreOperation {
     private final FullSnapshotRestoreOperation<K> savepointRestoreOperation;
@@ -149,7 +147,7 @@
                         new RocksDBWriteBatchWrapper(this.rocksHandle.getDb(), writeBatchSize);
                 Closeable ignored =
                         cancelStreamRegistryForRestore.registerCloseableTemporarily(
-                                asCloseable(writeBatchWrapper))) {
+                                writeBatchWrapper.getCancelCloseable())) {
             ColumnFamilyHandle handle = null;
             while (keyGroups.hasNext()) {
                 KeyGroup keyGroup = keyGroups.next();
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java
index c4eed5f..da2faca 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java
@@ -64,8 +64,6 @@
 import java.util.Map;
 import java.util.function.Function;
 
-import static org.apache.flink.core.fs.ICloseableRegistry.asCloseable;
-
 /** Encapsulates the process of restoring a RocksDB instance from a full snapshot. */
 public class RocksDBHeapTimersFullRestoreOperation<K> implements RocksDBRestoreOperation {
     private final FullSnapshotRestoreOperation<K> savepointRestoreOperation;
@@ -194,7 +192,7 @@
                         new RocksDBWriteBatchWrapper(this.rocksHandle.getDb(), writeBatchSize);
                 Closeable ignored =
                         cancelStreamRegistryForRestore.registerCloseableTemporarily(
-                                asCloseable(writeBatchWrapper))) {
+                                writeBatchWrapper.getCancelCloseable())) {
             HeapPriorityQueueSnapshotRestoreWrapper<HeapPriorityQueueElement> restoredPQ = null;
             ColumnFamilyHandle handle = null;
             while (keyGroups.hasNext()) {
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
index 6d7ed6e..89ec9d1 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
@@ -90,7 +90,6 @@
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
-import static org.apache.flink.core.fs.ICloseableRegistry.asCloseable;
 import static org.apache.flink.runtime.metrics.MetricNames.DOWNLOAD_STATE_DURATION;
 import static org.apache.flink.runtime.metrics.MetricNames.RESTORE_ASYNC_COMPACTION_DURATION;
 import static org.apache.flink.runtime.metrics.MetricNames.RESTORE_STATE_DURATION;
@@ -798,7 +797,7 @@
                         new RocksDBWriteBatchWrapper(this.rocksHandle.getDb(), writeBatchSize);
                 Closeable ignored =
                         cancelStreamRegistryForRestore.registerCloseableTemporarily(
-                                asCloseable(writeBatchWrapper))) {
+                                writeBatchWrapper.getCancelCloseable())) {
             for (IncrementalLocalKeyedStateHandle handleToCopy : toImport) {
                 try (RestoredDBInstance restoredDBInstance =
                         restoreTempDBInstanceFromLocalState(handleToCopy)) {
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapperTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapperTest.java
index 6a7c95a..05465ca 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapperTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapperTest.java
@@ -19,6 +19,8 @@
 package org.apache.flink.contrib.streaming.state;
 
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.runtime.execution.CancelTaskException;
 
 import org.junit.Assert;
 import org.junit.Rule;
@@ -29,11 +31,14 @@
 import org.rocksdb.RocksDB;
 import org.rocksdb.WriteOptions;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ThreadLocalRandom;
 
 import static org.apache.flink.contrib.streaming.state.RocksDBConfigurableOptions.WRITE_BATCH_SIZE;
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
@@ -43,6 +48,58 @@
 
     @Rule public TemporaryFolder folder = new TemporaryFolder();
 
+    @Test(expected = CancelTaskException.class)
+    public void testAsyncCancellation() throws Exception {
+        final CompletableFuture<Void> writeStartedFuture = new CompletableFuture<>();
+        final CompletableFuture<Void> cancellationRequestedFuture = new CompletableFuture<>();
+        final CloseableRegistry registry = new CloseableRegistry();
+        new Thread(
+                        () -> {
+                            writeStartedFuture.join();
+                            try {
+                                registry.close();
+                                cancellationRequestedFuture.complete(null);
+                            } catch (IOException e) {
+                                cancellationRequestedFuture.completeExceptionally(e);
+                            }
+                        })
+                .start();
+
+        final int capacity = 1000; // max
+        final int cancellationCheckInterval = 1;
+        long batchSizeBytes = WRITE_BATCH_SIZE.defaultValue().getBytes();
+
+        try (RocksDB db = RocksDB.open(folder.newFolder().getAbsolutePath());
+                WriteOptions options = new WriteOptions().setDisableWAL(true);
+                ColumnFamilyHandle handle =
+                        db.createColumnFamily(new ColumnFamilyDescriptor("test".getBytes()));
+                RocksDBWriteBatchWrapper writeBatchWrapper =
+                        new RocksDBWriteBatchWrapper(
+                                db,
+                                options,
+                                capacity,
+                                batchSizeBytes,
+                                cancellationCheckInterval,
+                                batchSizeBytes)) {
+            registry.registerCloseable(writeBatchWrapper.getCancelCloseable());
+            writeStartedFuture.complete(null);
+
+            //noinspection InfiniteLoopStatement
+            for (int i = 0; ; i++) {
+                try {
+                    writeBatchWrapper.put(
+                            handle, ("key:" + i).getBytes(), ("value:" + i).getBytes());
+                } catch (Exception e) {
+                    cancellationRequestedFuture.join(); // shouldn't have any errors
+                    throw e;
+                }
+                // make sure that cancellation is triggered earlier than periodic flush
+                // but allow some delay of cancellation propagation
+                assertThat(i).isLessThan(cancellationCheckInterval * 2);
+            }
+        }
+    }
+
     @Test
     public void basicTest() throws Exception {