[#1564] fix(server): disk health check invalid when hang (#1568)

### What changes were proposed in this pull request?

Use a `CompletableFuture` to detect whether the disk check has timed out. If it has, the disk is directly marked as corrupted.

### Why are the changes needed?

Fix: #1564 

### Does this PR introduce _any_ user-facing change?

Yes. Introduces the conf `rss.server.health.checker.localStorageExecutionTimeoutMS`. The unit is ms, and the default value is 1 minute.

### How was this patch tested?

Unit tests. I will also apply this to our online environment.
diff --git a/common/src/main/java/org/apache/uniffle/common/future/CompletableFutureExtension.java b/common/src/main/java/org/apache/uniffle/common/future/CompletableFutureExtension.java
new file mode 100644
index 0000000..22948c8
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/future/CompletableFutureExtension.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.future;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.function.BiConsumer;
+
+public class CompletableFutureExtension {
+  public static <T> CompletableFuture<T> orTimeout(
+      CompletableFuture<T> future, long timeout, TimeUnit unit) {
+    if (future.isDone()) {
+      return future;
+    }
+
+    return future.whenComplete(new Canceller(Delayer.delay(new Timeout(future), timeout, unit)));
+  }
+
+  static final class Timeout implements Runnable {
+    final CompletableFuture<?> future;
+
+    Timeout(CompletableFuture<?> future) {
+      this.future = future;
+    }
+
+    public void run() {
+      if (null != future && !future.isDone()) {
+        future.completeExceptionally(new TimeoutException());
+      }
+    }
+  }
+
+  static final class Canceller implements BiConsumer<Object, Throwable> {
+    final Future<?> future;
+
+    Canceller(Future<?> future) {
+      this.future = future;
+    }
+
+    public void accept(Object ignore, Throwable ex) {
+      if (null == ex && null != future && !future.isDone()) {
+        future.cancel(false);
+      }
+    }
+  }
+
+  static final class Delayer {
+    static ScheduledFuture<?> delay(Runnable command, long delay, TimeUnit unit) {
+      return delayer.schedule(command, delay, unit);
+    }
+
+    static final class DaemonThreadFactory implements ThreadFactory {
+      public Thread newThread(Runnable r) {
+        Thread t = new Thread(r);
+        t.setDaemon(true);
+        t.setName("CompletableFutureExtensionDelayScheduler");
+        return t;
+      }
+    }
+
+    static final ScheduledThreadPoolExecutor delayer;
+
+    static {
+      delayer = new ScheduledThreadPoolExecutor(1, new DaemonThreadFactory());
+      delayer.setRemoveOnCancelPolicy(true);
+    }
+  }
+}
diff --git a/common/src/test/java/org/apache/uniffle/common/future/CompletableFutureExtensionTest.java b/common/src/test/java/org/apache/uniffle/common/future/CompletableFutureExtensionTest.java
new file mode 100644
index 0000000..ec12000
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/future/CompletableFutureExtensionTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.future;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+
+public class CompletableFutureExtensionTest {
+
+  @Test
+  public void timeoutExceptionTest() throws ExecutionException, InterruptedException {
+    // case1
+    CompletableFuture<Integer> future =
+        CompletableFuture.supplyAsync(
+            () -> {
+              try {
+                Thread.sleep(2000);
+              } catch (InterruptedException e) {
+                e.printStackTrace();
+              }
+              return 10;
+            });
+
+    CompletableFuture<Integer> wrap =
+        CompletableFutureExtension.orTimeout(future, 1, TimeUnit.SECONDS);
+    try {
+      wrap.get();
+      fail();
+    } catch (Exception e) {
+      if (!(e instanceof ExecutionException) || !(e.getCause() instanceof TimeoutException)) {
+        fail();
+      }
+    }
+
+    // case2
+    future =
+        CompletableFuture.supplyAsync(
+            () -> {
+              try {
+                Thread.sleep(2000);
+              } catch (InterruptedException e) {
+                e.printStackTrace();
+              }
+              return 10;
+            });
+    wrap = CompletableFutureExtension.orTimeout(future, 3, TimeUnit.SECONDS);
+    assertEquals(10, wrap.get());
+  }
+}
diff --git a/server/src/main/java/org/apache/uniffle/server/LocalStorageChecker.java b/server/src/main/java/org/apache/uniffle/server/LocalStorageChecker.java
index ac30baf..0554035 100644
--- a/server/src/main/java/org/apache/uniffle/server/LocalStorageChecker.java
+++ b/server/src/main/java/org/apache/uniffle/server/LocalStorageChecker.java
@@ -21,8 +21,15 @@
 import java.io.FileInputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.List;
-import java.util.concurrent.CountDownLatch;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 
@@ -34,6 +41,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.future.CompletableFutureExtension;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.storage.common.LocalStorage;
 import org.apache.uniffle.storage.util.ShuffleStorageUtils;
@@ -48,8 +57,10 @@
   private final double diskMaxUsagePercentage;
   private final double diskRecoveryUsagePercentage;
   private final double minStorageHealthyPercentage;
-  private final List<StorageInfo> storageInfos = Lists.newArrayList();
+  protected List<StorageInfo> storageInfos = Lists.newArrayList();
   private boolean isHealthy = true;
+  private ExecutorService workers;
+  private final long diskCheckerExecutionTimeoutMs;
 
   public LocalStorageChecker(ShuffleServerConf conf, List<LocalStorage> storages) {
     super(conf);
@@ -72,6 +83,10 @@
         conf.getDouble(ShuffleServerConf.HEALTH_STORAGE_RECOVERY_USAGE_PERCENTAGE);
     this.minStorageHealthyPercentage =
         conf.getDouble(ShuffleServerConf.HEALTH_MIN_STORAGE_PERCENTAGE);
+
+    this.diskCheckerExecutionTimeoutMs =
+        conf.getLong(ShuffleServerConf.HEALTH_CHECKER_LOCAL_STORAGE_EXECUTE_TIMEOUT);
+    this.workers = Executors.newFixedThreadPool(basePaths.size());
   }
 
   @Override
@@ -81,36 +96,60 @@
     AtomicLong wholeDiskUsedSpace = new AtomicLong(0L);
     AtomicLong serviceUsedSpace = new AtomicLong(0L);
     AtomicInteger corruptedDirs = new AtomicInteger(0);
-    CountDownLatch cdl = new CountDownLatch(storageInfos.size());
-    storageInfos
-        .parallelStream()
-        .forEach(
-            storageInfo -> {
-              if (!storageInfo.checkStorageReadAndWrite()) {
-                storageInfo.markCorrupted();
-                corruptedDirs.incrementAndGet();
-                cdl.countDown();
-                return;
-              }
 
-              long total = getTotalSpace(storageInfo.storageDir);
-              long free = getFreeSpace(storageInfo.storageDir);
+    Map<StorageInfo, CompletableFuture<Void>> futureMap = new HashMap<>();
+    for (StorageInfo storageInfo : storageInfos) {
+      CompletableFuture<Void> storageCheckFuture =
+          CompletableFuture.supplyAsync(
+              () -> {
+                if (!storageInfo.checkStorageReadAndWrite()) {
+                  storageInfo.markCorrupted();
+                  corruptedDirs.incrementAndGet();
+                  return null;
+                }
 
-              totalSpace.addAndGet(total);
-              wholeDiskUsedSpace.addAndGet(total - free);
-              serviceUsedSpace.addAndGet(getServiceUsedSpace(storageInfo.storageDir));
+                long total = getTotalSpace(storageInfo.storageDir);
+                long free = getFreeSpace(storageInfo.storageDir);
 
-              storageInfo.updateStorageFreeSpace(free);
-              if (storageInfo.checkIsSpaceEnough(total, free)) {
-                num.incrementAndGet();
-              }
-              cdl.countDown();
-            });
-    try {
-      cdl.await();
-    } catch (InterruptedException e) {
-      LOG.error("Failed to check local storage!");
+                totalSpace.addAndGet(total);
+                wholeDiskUsedSpace.addAndGet(total - free);
+                serviceUsedSpace.addAndGet(getServiceUsedSpace(storageInfo.storageDir));
+
+                storageInfo.updateStorageFreeSpace(free);
+                if (storageInfo.checkIsSpaceEnough(total, free)) {
+                  num.incrementAndGet();
+                }
+                return null;
+              },
+              workers);
+
+      futureMap.put(
+          storageInfo,
+          CompletableFutureExtension.orTimeout(
+              storageCheckFuture, diskCheckerExecutionTimeoutMs, TimeUnit.MILLISECONDS));
     }
+
+    for (Map.Entry<StorageInfo, CompletableFuture<Void>> entry : futureMap.entrySet()) {
+      StorageInfo storageInfo = entry.getKey();
+      CompletableFuture<Void> f = entry.getValue();
+
+      try {
+        f.get();
+      } catch (Exception e) {
+        if (e instanceof ExecutionException) {
+          if (e.getCause() instanceof TimeoutException) {
+            storageInfo.markCorrupted();
+            LOG.error(
+                "Timeout of checking local storage: {}. This should not happen and mark this disk corrupted.",
+                storageInfo.storage.getBasePath());
+            continue;
+          }
+        }
+
+        throw new RssException(e);
+      }
+    }
+
     ShuffleServerMetrics.gaugeLocalStorageTotalSpace.set(totalSpace.get());
     ShuffleServerMetrics.gaugeLocalStorageWholeDiskUsedSpace.set(wholeDiskUsedSpace.get());
     ShuffleServerMetrics.gaugeLocalStorageServiceUsedSpace.set(serviceUsedSpace.get());
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
index f968a17..a6cd265 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
@@ -306,6 +306,13 @@
           .defaultValue(5000L)
           .withDescription("The health script file execute timeout ms.");
 
+  public static final ConfigOption<Long> HEALTH_CHECKER_LOCAL_STORAGE_EXECUTE_TIMEOUT =
+      ConfigOptions.key("rss.server.health.checker.localStorageExecutionTimeoutMS")
+          .longType()
+          .defaultValue(1000 * 60L)
+          .withDescription(
+              "The health checker for LocalStorageChecker execution timeout (Unit: ms). Default value is 1min");
+
   public static final ConfigOption<Double> SERVER_MEMORY_SHUFFLE_LOWWATERMARK_PERCENTAGE =
       ConfigOptions.key("rss.server.memory.shuffle.lowWaterMark.percentage")
           .doubleType()
diff --git a/server/src/test/java/org/apache/uniffle/server/LocalStorageCheckerTest.java b/server/src/test/java/org/apache/uniffle/server/LocalStorageCheckerTest.java
index e5eaba1..9336ce9 100644
--- a/server/src/test/java/org/apache/uniffle/server/LocalStorageCheckerTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/LocalStorageCheckerTest.java
@@ -20,12 +20,81 @@
 import java.io.File;
 import java.io.IOException;
 import java.nio.file.Files;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
 
+import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
 import org.junit.jupiter.api.io.TempDir;
 
+import org.apache.uniffle.common.StorageType;
+import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.storage.common.LocalStorage;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+
 public class LocalStorageCheckerTest {
+  @BeforeAll
+  public static void setup() {
+    ShuffleServerMetrics.register();
+  }
+
+  @AfterAll
+  public static void clear() {
+    ShuffleServerMetrics.clear();
+  }
+
+  private class SlowDiskStorageChecker extends LocalStorageChecker {
+    private long hangTimeSec;
+
+    SlowDiskStorageChecker(ShuffleServerConf conf, List<LocalStorage> storages, long hangTimeSec) {
+      super(conf, storages);
+      this.hangTimeSec = hangTimeSec;
+
+      List<StorageInfo> storageInfoList =
+          storages.stream().map(x -> new SlowStorageInfo(x)).collect(Collectors.toList());
+      super.storageInfos = storageInfoList;
+    }
+
+    private class SlowStorageInfo extends StorageInfo {
+
+      SlowStorageInfo(LocalStorage storage) {
+        super(storage);
+      }
+
+      @Override
+      public boolean checkStorageReadAndWrite() {
+        try {
+          Thread.sleep(hangTimeSec * 1000);
+        } catch (InterruptedException e) {
+          throw new RuntimeException(e);
+        }
+        return true;
+      }
+    }
+  }
+
+  @Test
+  @Timeout(10)
+  public void testCheckingStorageHang(@TempDir File tempDir) {
+    String basePath = tempDir.getAbsolutePath();
+
+    ShuffleServerConf conf = new ShuffleServerConf();
+    conf.set(RssBaseConf.RSS_STORAGE_BASE_PATH, Arrays.asList(basePath));
+    conf.set(RssBaseConf.RSS_STORAGE_TYPE, StorageType.LOCALFILE);
+    conf.set(ShuffleServerConf.HEALTH_CHECKER_LOCAL_STORAGE_EXECUTE_TIMEOUT, 2 * 1000L);
+
+    LocalStorage localStorage =
+        LocalStorage.newBuilder().basePath(tempDir.getAbsolutePath()).capacity(100000L).build();
+
+    SlowDiskStorageChecker checker =
+        new SlowDiskStorageChecker(conf, Arrays.asList(localStorage), 600);
+    assertFalse(checker.checkIsHealthy());
+  }
 
   @Test
   public void testGetUniffleUsedSpace(@TempDir File tempDir) throws IOException {