[FLINK-35768][state,rocksdb] Refactor RocksDBStateDownloader
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java
index 6c319cb..ec7cd94 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java
@@ -20,6 +20,7 @@
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FileUtils;
@@ -33,9 +34,11 @@
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -68,11 +71,17 @@
// Make sure we also react to external close signals.
closeableRegistry.registerCloseable(internalCloser);
try {
- List<CompletableFuture<Void>> futures =
- transferAllStateDataToDirectoryAsync(downloadRequests, internalCloser)
- .collect(Collectors.toList());
- // Wait until either all futures completed successfully or one failed exceptionally.
- FutureUtils.completeAll(futures).get();
+ // We have to wait for all futures to be completed, to make sure in
+ // case of failure that we will clean up all the files
+ FutureUtils.completeAll(
+ createDownloadRunnables(downloadRequests, internalCloser).stream()
+ .map(
+ runnable ->
+ CompletableFuture.runAsync(
+ runnable,
+ transfer.getExecutorService()))
+ .collect(Collectors.toList()))
+ .get();
} catch (Exception e) {
downloadRequests.stream()
.map(StateHandleDownloadSpec::getDownloadDestination)
@@ -94,46 +103,39 @@
}
}
- /** Asynchronously runs the specified download requests on executorService. */
- private Stream<CompletableFuture<Void>> transferAllStateDataToDirectoryAsync(
- Collection<StateHandleDownloadSpec> handleWithPaths,
+ private List<Runnable> createDownloadRunnables(
+ Collection<StateHandleDownloadSpec> downloadRequests,
CloseableRegistry closeableRegistry) {
- return handleWithPaths.stream()
- .flatMap(
- downloadRequest ->
- // Take all files from shared and private state.
- Stream.concat(
- downloadRequest.getStateHandle().getSharedState()
- .stream(),
- downloadRequest.getStateHandle().getPrivateState()
- .stream())
- .map(
- // Create one runnable for each StreamStateHandle
- entry -> {
- String localPath = entry.getLocalPath();
- StreamStateHandle remoteFileHandle =
- entry.getHandle();
- Path downloadDest =
- downloadRequest
- .getDownloadDestination()
- .resolve(localPath);
- return ThrowingRunnable.unchecked(
- () ->
- downloadDataForStateHandle(
- downloadDest,
- remoteFileHandle,
- closeableRegistry));
- }))
- .map(
- runnable ->
- CompletableFuture.runAsync(
- runnable, transfer.getExecutorService()));
+ List<Runnable> runnables = new ArrayList<>();
+ for (StateHandleDownloadSpec downloadRequest : downloadRequests) {
+ Stream.concat(
+ downloadRequest.getStateHandle().getSharedState().stream(),
+ downloadRequest.getStateHandle().getPrivateState().stream())
+ .map(
+ handleAndLocalPath ->
+ runnables.add(
+ createDownloadRunnableUsingStreams(
+ handleAndLocalPath.getHandle(),
+ downloadRequest
+ .getDownloadDestination()
+ .resolve(handleAndLocalPath.getLocalPath()),
+ closeableRegistry)));
+ }
+ return runnables;
+ }
+
+ private Runnable createDownloadRunnableUsingStreams(
+ StreamStateHandle remoteFileHandle,
+ Path destination,
+ CloseableRegistry closeableRegistry) {
+ return ThrowingRunnable.unchecked(
+ () -> downloadDataForStateHandle(remoteFileHandle, destination, closeableRegistry));
}
/** Copies the file from a single state handle to the given path. */
private void downloadDataForStateHandle(
- Path restoreFilePath,
StreamStateHandle remoteFileHandle,
+ Path restoreFilePath,
CloseableRegistry closeableRegistry)
throws IOException {