[FLINK-35580] Fix synchronization issue when closing RocksDBWriteBatchWrapper
diff --git a/flink-core/src/main/java/org/apache/flink/core/fs/ICloseableRegistry.java b/flink-core/src/main/java/org/apache/flink/core/fs/ICloseableRegistry.java
index ebdb0f1..90fd201 100644
--- a/flink-core/src/main/java/org/apache/flink/core/fs/ICloseableRegistry.java
+++ b/flink-core/src/main/java/org/apache/flink/core/fs/ICloseableRegistry.java
@@ -19,11 +19,9 @@
package org.apache.flink.core.fs;
import org.apache.flink.annotation.Internal;
-import org.apache.flink.util.ExceptionUtils;
import java.io.Closeable;
import java.io.IOException;
-import java.util.concurrent.atomic.AtomicBoolean;
/**
* This class allows to register instances of {@link Closeable}, which are all closed if this
@@ -39,19 +37,6 @@
@Internal
public interface ICloseableRegistry extends Closeable {
- static Closeable asCloseable(AutoCloseable autoCloseable) {
- AtomicBoolean closed = new AtomicBoolean(false);
- return () -> {
- if (closed.compareAndSet(false, true)) {
- try {
- autoCloseable.close();
- } catch (Exception e) {
- ExceptionUtils.rethrowIOException(e);
- }
- }
- };
- }
-
/**
* Registers a {@link Closeable} with the registry. In case the registry is already closed, this
* method throws an {@link IllegalStateException} and closes the passed {@link Closeable}.
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index ba4c657..9f1857a 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -84,6 +84,7 @@
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
@@ -840,7 +841,10 @@
try (RocksIteratorWrapper iterator =
RocksDBOperationUtils.getRocksIterator(db, stateMetaInfo.f0, readOptions);
RocksDBWriteBatchWrapper batchWriter =
- new RocksDBWriteBatchWrapper(db, getWriteOptions(), getWriteBatchSize())) {
+ new RocksDBWriteBatchWrapper(db, getWriteOptions(), getWriteBatchSize());
+ Closeable ignored =
+ cancelStreamRegistry.registerCloseableTemporarily(
+ batchWriter.getCancelCloseable())) {
iterator.seekToFirst();
DataInputDeserializer serializedValueInput = new DataInputDeserializer();
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapper.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapper.java
index 354009e..c2da551 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapper.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapper.java
@@ -19,6 +19,7 @@
package org.apache.flink.contrib.streaming.state;
import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.execution.CancelTaskException;
import org.apache.flink.util.IOUtils;
import org.apache.flink.util.Preconditions;
@@ -32,6 +33,7 @@
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import java.io.Closeable;
import java.util.ArrayList;
import java.util.List;
@@ -47,6 +49,9 @@
private static final int PER_RECORD_BYTES = 100;
// default 0 for disable memory size based flush
private static final long DEFAULT_BATCH_SIZE = 0;
+ private static final int DEFAULT_CANCELLATION_CHECK_INTERVAL = MIN_CAPACITY;
+ private static final int DEFAULT_CANCELLATION_CHECK_INTERVAL_BYTES =
+ DEFAULT_CANCELLATION_CHECK_INTERVAL * PER_RECORD_BYTES;
private final RocksDB db;
@@ -61,6 +66,14 @@
/** List of all objects that we need to close in close(). */
private final List<AutoCloseable> toClose;
+ private volatile boolean cancelled;
+
+ private final int cancellationCheckInterval;
+
+ private final long cancellationCheckIntervalBytes;
+
+ private long lastCancellationCheckBatchSize = 0L;
+
public RocksDBWriteBatchWrapper(@Nonnull RocksDB rocksDB, long writeBatchSize) {
this(rocksDB, null, 500, writeBatchSize);
}
@@ -79,6 +92,22 @@
@Nullable WriteOptions options,
int capacity,
long batchSize) {
+ this(
+ rocksDB,
+ options,
+ capacity,
+ batchSize,
+ DEFAULT_CANCELLATION_CHECK_INTERVAL,
+ DEFAULT_CANCELLATION_CHECK_INTERVAL_BYTES);
+ }
+
+ public RocksDBWriteBatchWrapper(
+ @Nonnull RocksDB rocksDB,
+ @Nullable WriteOptions options,
+ int capacity,
+ long batchSize,
+ int cancellationCheckInterval,
+ long cancellationCheckIntervalBytes) {
Preconditions.checkArgument(
capacity >= MIN_CAPACITY && capacity <= MAX_CAPACITY,
"capacity should be between " + MIN_CAPACITY + " and " + MAX_CAPACITY);
@@ -104,16 +133,27 @@
// We own this object, so we must ensure that we close it.
this.toClose.add(this.options);
}
+ this.cancellationCheckInterval = cancellationCheckInterval;
+ this.cancellationCheckIntervalBytes = cancellationCheckIntervalBytes;
}
public void put(@Nonnull ColumnFamilyHandle handle, @Nonnull byte[] key, @Nonnull byte[] value)
throws RocksDBException {
+ maybeEnsureNotCancelled();
batch.put(handle, key, value);
flushIfNeeded();
}
+ private void maybeEnsureNotCancelled() {
+ if (batch.count() % cancellationCheckInterval == 0
+ || batch.getDataSize() - lastCancellationCheckBatchSize
+ >= cancellationCheckIntervalBytes) {
+ ensureNotCancelled();
+ }
+ }
+
public void remove(@Nonnull ColumnFamilyHandle handle, @Nonnull byte[] key)
throws RocksDBException {
@@ -123,8 +163,10 @@
}
public void flush() throws RocksDBException {
+ ensureNotCancelled();
db.write(options, batch);
batch.clear();
+ lastCancellationCheckBatchSize = 0;
}
@VisibleForTesting
@@ -132,9 +174,18 @@
return options;
}
+ public void markCancelled() {
+ this.cancelled = true;
+ }
+
+ public Closeable getCancelCloseable() {
+ return this::markCancelled;
+ }
+
@Override
public void close() throws RocksDBException {
try {
+ ensureNotCancelled();
if (batch.count() != 0) {
flush();
}
@@ -143,6 +194,13 @@
}
}
+ private void ensureNotCancelled() {
+ if (cancelled) {
+ throw new CancelTaskException();
+ }
+ lastCancellationCheckBatchSize = batch.getDataSize();
+ }
+
private void flushIfNeeded() throws RocksDBException {
boolean needFlush =
batch.count() == capacity || (batchSize > 0 && getDataSize() >= batchSize);
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java
index fdaf21f..052404b 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java
@@ -52,8 +52,6 @@
import java.util.Map;
import java.util.function.Function;
-import static org.apache.flink.core.fs.ICloseableRegistry.asCloseable;
-
/** Encapsulates the process of restoring a RocksDB instance from a full snapshot. */
public class RocksDBFullRestoreOperation<K> implements RocksDBRestoreOperation {
private final FullSnapshotRestoreOperation<K> savepointRestoreOperation;
@@ -149,7 +147,7 @@
new RocksDBWriteBatchWrapper(this.rocksHandle.getDb(), writeBatchSize);
Closeable ignored =
cancelStreamRegistryForRestore.registerCloseableTemporarily(
- asCloseable(writeBatchWrapper))) {
+ writeBatchWrapper.getCancelCloseable())) {
ColumnFamilyHandle handle = null;
while (keyGroups.hasNext()) {
KeyGroup keyGroup = keyGroups.next();
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java
index c4eed5f..da2faca 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java
@@ -64,8 +64,6 @@
import java.util.Map;
import java.util.function.Function;
-import static org.apache.flink.core.fs.ICloseableRegistry.asCloseable;
-
/** Encapsulates the process of restoring a RocksDB instance from a full snapshot. */
public class RocksDBHeapTimersFullRestoreOperation<K> implements RocksDBRestoreOperation {
private final FullSnapshotRestoreOperation<K> savepointRestoreOperation;
@@ -194,7 +192,7 @@
new RocksDBWriteBatchWrapper(this.rocksHandle.getDb(), writeBatchSize);
Closeable ignored =
cancelStreamRegistryForRestore.registerCloseableTemporarily(
- asCloseable(writeBatchWrapper))) {
+ writeBatchWrapper.getCancelCloseable())) {
HeapPriorityQueueSnapshotRestoreWrapper<HeapPriorityQueueElement> restoredPQ = null;
ColumnFamilyHandle handle = null;
while (keyGroups.hasNext()) {
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
index 6d7ed6e..89ec9d1 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
@@ -90,7 +90,6 @@
import java.util.function.Function;
import java.util.stream.Collectors;
-import static org.apache.flink.core.fs.ICloseableRegistry.asCloseable;
import static org.apache.flink.runtime.metrics.MetricNames.DOWNLOAD_STATE_DURATION;
import static org.apache.flink.runtime.metrics.MetricNames.RESTORE_ASYNC_COMPACTION_DURATION;
import static org.apache.flink.runtime.metrics.MetricNames.RESTORE_STATE_DURATION;
@@ -798,7 +797,7 @@
new RocksDBWriteBatchWrapper(this.rocksHandle.getDb(), writeBatchSize);
Closeable ignored =
cancelStreamRegistryForRestore.registerCloseableTemporarily(
- asCloseable(writeBatchWrapper))) {
+ writeBatchWrapper.getCancelCloseable())) {
for (IncrementalLocalKeyedStateHandle handleToCopy : toImport) {
try (RestoredDBInstance restoredDBInstance =
restoreTempDBInstanceFromLocalState(handleToCopy)) {
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapperTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapperTest.java
index 6a7c95a..05465ca 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapperTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBWriteBatchWrapperTest.java
@@ -19,6 +19,8 @@
package org.apache.flink.contrib.streaming.state;
import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.runtime.execution.CancelTaskException;
import org.junit.Assert;
import org.junit.Rule;
@@ -29,11 +31,14 @@
import org.rocksdb.RocksDB;
import org.rocksdb.WriteOptions;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import static org.apache.flink.contrib.streaming.state.RocksDBConfigurableOptions.WRITE_BATCH_SIZE;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
@@ -43,6 +48,58 @@
@Rule public TemporaryFolder folder = new TemporaryFolder();
+ @Test(expected = CancelTaskException.class)
+ public void testAsyncCancellation() throws Exception {
+ final CompletableFuture<Void> writeStartedFuture = new CompletableFuture<>();
+ final CompletableFuture<Void> cancellationRequestedFuture = new CompletableFuture<>();
+ final CloseableRegistry registry = new CloseableRegistry();
+ new Thread(
+ () -> {
+ writeStartedFuture.join();
+ try {
+ registry.close();
+ cancellationRequestedFuture.complete(null);
+ } catch (IOException e) {
+ cancellationRequestedFuture.completeExceptionally(e);
+ }
+ })
+ .start();
+
+ final int capacity = 1000; // max
+ final int cancellationCheckInterval = 1;
+ long batchSizeBytes = WRITE_BATCH_SIZE.defaultValue().getBytes();
+
+ try (RocksDB db = RocksDB.open(folder.newFolder().getAbsolutePath());
+ WriteOptions options = new WriteOptions().setDisableWAL(true);
+ ColumnFamilyHandle handle =
+ db.createColumnFamily(new ColumnFamilyDescriptor("test".getBytes()));
+ RocksDBWriteBatchWrapper writeBatchWrapper =
+ new RocksDBWriteBatchWrapper(
+ db,
+ options,
+ capacity,
+ batchSizeBytes,
+ cancellationCheckInterval,
+ batchSizeBytes)) {
+ registry.registerCloseable(writeBatchWrapper.getCancelCloseable());
+ writeStartedFuture.complete(null);
+
+ //noinspection InfiniteLoopStatement
+ for (int i = 0; ; i++) {
+ try {
+ writeBatchWrapper.put(
+ handle, ("key:" + i).getBytes(), ("value:" + i).getBytes());
+ } catch (Exception e) {
+ cancellationRequestedFuture.join(); // shouldn't have any errors
+ throw e;
+ }
+ // make sure that cancellation is triggered earlier than periodic flush
+ // but allow some delay of cancellation propagation
+ assertThat(i).isLessThan(cancellationCheckInterval * 2);
+ }
+ }
+ }
+
@Test
public void basicTest() throws Exception {