RATIS-1504. Add timeout handling to DataStreamManagement#checkSuccessRemoteWrite. (#1064)

* RATIS-1504. Add timeout handling to DataStreamManagement#checkSuccessRemoteWrite.
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientReplies.java b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientReplies.java
index fc97b6f..4c49b1d 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientReplies.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientReplies.java
@@ -30,11 +30,10 @@
 
 import java.util.Map;
 import java.util.Objects;
-import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
-import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
 
 public class NettyClientReplies {
   public static final Logger LOG = LoggerFactory.getLogger(NettyClientReplies.class);
@@ -56,8 +55,8 @@
 
     ReplyEntry submitRequest(RequestEntry requestEntry, boolean isClose, CompletableFuture<DataStreamReply> f) {
       LOG.debug("put {} to the map for {}", requestEntry, clientInvocationId);
-      final MemoizedSupplier<ReplyEntry> replySupplier = MemoizedSupplier.valueOf(() -> new ReplyEntry(isClose, f));
-      return map.computeIfAbsent(requestEntry, r -> replySupplier.get());
+      // ConcurrentHashMap.computeIfAbsent javadoc: the function is applied at most once per key.
+      return map.computeIfAbsent(requestEntry, r -> new ReplyEntry(isClose, f));
     }
 
     void receiveReply(DataStreamReply reply) {
@@ -147,7 +146,7 @@
   static class ReplyEntry {
     private final boolean isClosed;
     private final CompletableFuture<DataStreamReply> replyFuture;
-    private final AtomicReference<ScheduledFuture<?>> timeoutFuture = new AtomicReference<>();
+    private ScheduledFuture<?> timeoutFuture; // for reply timeout
 
     ReplyEntry(boolean isClosed, CompletableFuture<DataStreamReply> replyFuture) {
       this.isClosed = isClosed;
@@ -158,22 +157,26 @@
       return isClosed;
     }
 
-    void complete(DataStreamReply reply) {
-      cancelTimeoutFuture();
+    synchronized void complete(DataStreamReply reply) {
+      cancel(timeoutFuture);
       replyFuture.complete(reply);
     }
 
-    void completeExceptionally(Throwable t) {
-      cancelTimeoutFuture();
+    synchronized void completeExceptionally(Throwable t) {
+      cancel(timeoutFuture);
       replyFuture.completeExceptionally(t);
     }
 
-    private void cancelTimeoutFuture() {
-      Optional.ofNullable(timeoutFuture.get()).ifPresent(f -> f.cancel(false));
+    static void cancel(ScheduledFuture<?> future) {
+      if (future != null) {
+        future.cancel(true);
+      }
     }
 
-    void setTimeoutFuture(ScheduledFuture<?> timeoutFuture) {
-      this.timeoutFuture.compareAndSet(null, timeoutFuture);
+    synchronized void scheduleTimeout(Supplier<ScheduledFuture<?>> scheduleMethod) {
+      if (!replyFuture.isDone()) {
+        timeoutFuture = scheduleMethod.get();
+      }
     }
   }
 }
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
index b2dc381..534fcc5 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
@@ -53,7 +53,6 @@
 import org.apache.ratis.thirdparty.io.netty.handler.codec.ByteToMessageDecoder;
 import org.apache.ratis.thirdparty.io.netty.handler.codec.MessageToMessageEncoder;
 import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
-import org.apache.ratis.thirdparty.io.netty.util.concurrent.ScheduledFuture;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.MemoizedSupplier;
 import org.apache.ratis.util.NetUtils;
@@ -466,15 +465,13 @@
         LOG.debug("{}: write after {}", this, request);
 
         final TimeDuration timeout = isClose ? closeTimeout : requestTimeout;
-        // if reply success cancel this future
-        final ScheduledFuture<?> timeoutFuture = channel.eventLoop().schedule(() -> {
+        replyEntry.scheduleTimeout(() -> channel.eventLoop().schedule(() -> {
           if (!f.isDone()) {
             f.completeExceptionally(new TimeoutIOException(
-                "Timeout " + timeout + ": Failed to send " + request + " channel: " + channel));
+                "Timeout " + timeout + ": Failed to send " + request + " via channel " + channel));
             replyMap.fail(requestEntry);
           }
-        }, timeout.toLong(timeout.getUnit()), timeout.getUnit());
-        replyEntry.setTimeoutFuture(timeoutFuture);
+        }, timeout.getDuration(), timeout.getUnit()));
       }
     });
     return f;
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java
index e265d8b..74d5cd7 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java
@@ -18,6 +18,7 @@
 
 package org.apache.ratis.netty.server;
 
+import org.apache.ratis.client.RaftClientConfigKeys;
 import org.apache.ratis.client.impl.ClientProtoUtils;
 import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
 import org.apache.ratis.conf.RaftProperties;
@@ -219,6 +220,7 @@
   private final ChannelMap channels;
   private final ExecutorService requestExecutor;
   private final ExecutorService writeExecutor;
+  private final TimeDuration requestTimeout;
 
   private final NettyServerStreamRpcMetrics nettyServerStreamRpcMetrics;
 
@@ -235,6 +237,7 @@
     this.writeExecutor = ConcurrentUtils.newThreadPoolWithMax(useCachedThreadPool,
           RaftServerConfigKeys.DataStream.asyncWriteThreadPoolSize(properties),
           name + "-write-");
+    this.requestTimeout = RaftClientConfigKeys.DataStream.requestTimeout(server.getProperties());
 
     this.nettyServerStreamRpcMetrics = metrics;
   }
@@ -339,7 +342,7 @@
         .build();
   }
 
-  static void sendReply(List<CompletableFuture<DataStreamReply>> remoteWrites,
+  private void sendReply(List<CompletableFuture<DataStreamReply>> remoteWrites,
       DataStreamRequestByteBuf request, long bytesWritten, Collection<CommitInfoProto> commitInfos,
       ChannelHandlerContext ctx) {
     final boolean success = checkSuccessRemoteWrite(remoteWrites, bytesWritten, request);
@@ -493,10 +496,15 @@
     Preconditions.assertTrue(request.getStreamOffset() == reply.getStreamOffset());
   }
 
-  static boolean checkSuccessRemoteWrite(List<CompletableFuture<DataStreamReply>> replyFutures, long bytesWritten,
+  private boolean checkSuccessRemoteWrite(List<CompletableFuture<DataStreamReply>> replyFutures, long bytesWritten,
       final DataStreamRequestByteBuf request) {
     for (CompletableFuture<DataStreamReply> replyFuture : replyFutures) {
-      final DataStreamReply reply = replyFuture.join();
+      final DataStreamReply reply;
+      try {
+        reply = replyFuture.get(requestTimeout.getDuration(), requestTimeout.getUnit());
+      } catch (Exception e) {
+        throw new CompletionException("Failed to get reply for bytesWritten=" + bytesWritten + ", " + request, e);
+      }
       assertReplyCorrespondingToRequest(request, reply);
       if (!reply.isSuccess()) {
         LOG.warn("reply is not success, request: {}", request);