RATIS-1989. Intermittent timeout in TestStreamObserverWithTimeout (#1012)

diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
index dd4e199..970134d 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
@@ -139,7 +139,7 @@
   StreamObserver<InstallSnapshotRequestProto> installSnapshot(
       String name, TimeDuration timeout, int limit, StreamObserver<InstallSnapshotReplyProto> responseHandler) {
     return StreamObserverWithTimeout.newInstance(name, ServerStringUtils::toInstallSnapshotRequestString,
-        timeout, limit, i -> asyncStub.withInterceptors(i).installSnapshot(responseHandler));
+        () -> timeout, limit, i -> asyncStub.withInterceptors(i).installSnapshot(responseHandler));
   }
 
   // short-circuit the backoff timer and make them reconnect immediately.
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/StreamObserverWithTimeout.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/StreamObserverWithTimeout.java
index ff89e7d..3cc754e 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/StreamObserverWithTimeout.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/StreamObserverWithTimeout.java
@@ -32,13 +32,14 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import java.util.function.IntSupplier;
+import java.util.function.Supplier;
 
 public final class StreamObserverWithTimeout<T> implements StreamObserver<T> {
   public static final Logger LOG = LoggerFactory.getLogger(StreamObserverWithTimeout.class);
 
   public static <T> StreamObserverWithTimeout<T> newInstance(
       String name, Function<T, String> request2String,
-      TimeDuration timeout, int outstandingLimit,
+      Supplier<TimeDuration> timeout, int outstandingLimit,
       Function<ClientInterceptor, StreamObserver<T>> newStreamObserver) {
     final AtomicInteger responseCount = new AtomicInteger();
     final ResourceSemaphore semaphore = outstandingLimit > 0? new ResourceSemaphore(outstandingLimit): null;
@@ -55,7 +56,7 @@
   private final String name;
   private final Function<T, String> requestToStringFunction;
 
-  private final TimeDuration timeout;
+  private final Supplier<TimeDuration> timeoutSupplier;
   private final StreamObserver<T> observer;
   private final TimeoutExecutor scheduler = TimeoutExecutor.getInstance();
 
@@ -65,17 +66,18 @@
   private final ResourceSemaphore semaphore;
 
   private StreamObserverWithTimeout(String name, Function<T, String> requestToStringFunction,
-      TimeDuration timeout, IntSupplier responseCount, ResourceSemaphore semaphore, StreamObserver<T> observer) {
+      Supplier<TimeDuration> timeoutSupplier, IntSupplier responseCount, ResourceSemaphore semaphore,
+      StreamObserver<T> observer) {
     this.name = JavaUtils.getClassSimpleName(getClass()) + "-" + name;
     this.requestToStringFunction = requestToStringFunction;
 
-    this.timeout = timeout;
+    this.timeoutSupplier = timeoutSupplier;
     this.responseCount = responseCount;
     this.semaphore = semaphore;
     this.observer = observer;
   }
 
-  private void acquire(StringSupplier request) {
+  private void acquire(StringSupplier request, TimeDuration timeout) {
     if (semaphore == null) {
       return;
     }
@@ -96,14 +98,16 @@
   @Override
   public void onNext(T request) {
     final StringSupplier requestString = StringSupplier.get(() -> requestToStringFunction.apply(request));
-    acquire(requestString);
+    final TimeDuration timeout = timeoutSupplier.get();
+    acquire(requestString, timeout);
     observer.onNext(request);
     final int id = requestCount.incrementAndGet();
-    scheduler.onTimeout(timeout, () -> handleTimeout(id, requestString),
+    LOG.debug("{}: send {} with timeout={}: {}", name, id, timeout, requestString);
+    scheduler.onTimeout(timeout, () -> handleTimeout(id, timeout, requestString),
         LOG, () -> name + ": Timeout check failed for request: " + requestString);
   }
 
-  private void handleTimeout(int id, StringSupplier request) {
+  private void handleTimeout(int id, TimeDuration timeout, StringSupplier request) {
     if (id > responseCount.getAsInt()) {
       onError(new TimeoutIOException(name + ": Timed out " + timeout + " for sending request " + request));
     }
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcTestClient.java b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcTestClient.java
index 130c05e..ca8957e 100644
--- a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcTestClient.java
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcTestClient.java
@@ -37,6 +37,7 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiFunction;
 
 /** gRPC client for testing */
@@ -55,8 +56,11 @@
 
   static StreamObserverFactory withTimeout(TimeDuration timeout) {
     final String className = JavaUtils.getClassSimpleName(HelloRequest.class) + ":";
+    final AtomicBoolean initialized = new AtomicBoolean();
     return (stub, responseHandler) -> StreamObserverWithTimeout.newInstance("test",
-        r -> className + r.getName(), timeout, 2,
+        r -> className + r.getName(),
+        () -> initialized.getAndSet(true) ? timeout : TimeDuration.ONE_MINUTE.add(timeout),
+        2,
         i -> stub.withInterceptors(i).hello(responseHandler));
   }
 
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcTestServer.java b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcTestServer.java
index d497ac1..345c565 100644
--- a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcTestServer.java
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcTestServer.java
@@ -23,6 +23,8 @@
 import org.apache.ratis.thirdparty.io.grpc.Server;
 import org.apache.ratis.thirdparty.io.grpc.ServerBuilder;
 import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
+import org.apache.ratis.thirdparty.io.netty.util.concurrent.ThreadPerTaskExecutor;
+import org.apache.ratis.util.Daemon;
 import org.apache.ratis.util.IOUtils;
 import org.apache.ratis.util.TimeDuration;
 import org.slf4j.Logger;
@@ -31,16 +33,22 @@
 import java.io.Closeable;
 import java.io.IOException;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
 
 /** gRPC server for testing */
 class GrpcTestServer implements Closeable {
   private static final Logger LOG = LoggerFactory.getLogger(GrpcTestServer.class);
+  private static final AtomicLong COUNTER = new AtomicLong();
 
   private final Server server;
 
-  GrpcTestServer(int port, int slow, TimeDuration timeout) {
+  GrpcTestServer(int port, int warmup, int slow, TimeDuration timeout) {
     this.server = ServerBuilder.forPort(port)
-        .addService(new GreeterImpl(slow, timeout))
+        .executor(new ThreadPerTaskExecutor(r -> Daemon.newBuilder()
+            .setName("test-server-" + COUNTER.getAndIncrement())
+            .setRunnable(r)
+            .build()))
+        .addService(new GreeterImpl(warmup, slow, timeout))
         .build();
   }
 
@@ -64,14 +72,16 @@
       return ") Hello " + request;
     }
 
+    private final int warmup;
     private final int slow;
     private final TimeDuration shortSleepTime;
     private final TimeDuration longSleepTime;
     private int count = 0;
 
-    GreeterImpl(int slow, TimeDuration timeout) {
+    GreeterImpl(int warmup, int slow, TimeDuration timeout) {
+      this.warmup = warmup;
       this.slow = slow;
-      this.shortSleepTime = timeout.multiply(0.1);
+      this.shortSleepTime = timeout.multiply(0.25);
       this.longSleepTime = timeout.multiply(2);
     }
 
@@ -81,7 +91,8 @@
         @Override
         public void onNext(HelloRequest helloRequest) {
           final String reply = count + toReplySuffix(helloRequest.getName());
-          final TimeDuration sleepTime = count < slow ? shortSleepTime : longSleepTime;
+          final TimeDuration sleepTime = count < warmup ? TimeDuration.ZERO :
+              count < (warmup + slow) ? shortSleepTime : longSleepTime;
           LOG.info("count = {}, slow = {}, sleep {}", reply, slow, sleepTime);
           try {
             sleepTime.sleep();
@@ -105,4 +116,4 @@
       };
     }
   }
-}
+}
\ No newline at end of file
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestStreamObserverWithTimeout.java b/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestStreamObserverWithTimeout.java
index 7a32fb9..d0c936a 100644
--- a/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestStreamObserverWithTimeout.java
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestStreamObserverWithTimeout.java
@@ -24,6 +24,7 @@
 import org.apache.ratis.util.Slf4jUtils;
 import org.apache.ratis.util.StringUtils;
 import org.apache.ratis.util.TimeDuration;
+import org.apache.ratis.util.TimeoutTimer;
 import org.junit.Assert;
 import org.junit.Test;
 import org.slf4j.event.Level;
@@ -37,6 +38,8 @@
 public class TestStreamObserverWithTimeout extends BaseTest {
   {
     Slf4jUtils.setLogLevel(ResponseNotifyClientInterceptor.LOG, Level.TRACE);
+    Slf4jUtils.setLogLevel(StreamObserverWithTimeout.LOG, Level.DEBUG);
+    Slf4jUtils.setLogLevel(TimeoutTimer.LOG, Level.DEBUG);
   }
 
   enum Type {
@@ -57,14 +60,14 @@
   @Test
   public void testWithDeadline() throws Exception {
     //the total sleep time is within the deadline
-    runTestTimeout(8, Type.WithDeadline);
+    runTestTimeout(2, Type.WithDeadline);
   }
 
   @Test
   public void testWithDeadlineFailure() {
     //Expected to have DEADLINE_EXCEEDED
     testFailureCase("total sleep time is longer than the deadline",
-        () -> runTestTimeout(12, Type.WithDeadline),
+        () -> runTestTimeout(5, Type.WithDeadline),
         ExecutionException.class, StatusRuntimeException.class);
   }
 
@@ -72,7 +75,7 @@
   public void testWithTimeout() throws Exception {
     //Each sleep time is within the timeout,
     //Note that the total sleep time is longer than the timeout, but it does not matter.
-    runTestTimeout(12, Type.WithTimeout);
+    runTestTimeout(5, Type.WithTimeout);
   }
 
   void runTestTimeout(int slow, Type type) throws Exception {
@@ -80,14 +83,20 @@
     final TimeDuration timeout = ONE_SECOND.multiply(0.5);
     final StreamObserverFactory function = type.createFunction(timeout);
 
+    // The first request may take longer due to initialization, so Type.WithTimeout sends one warm-up request first.
+    final int warmup = type == Type.WithTimeout ? 1 : 0;
     final List<String> messages = new ArrayList<>();
     for (int i = 0; i < 2 * slow; i++) {
-      messages.add("m" + i);
+      messages.add("m" + (i + warmup));
     }
-    try (GrpcTestServer server = new GrpcTestServer(NetUtils.getFreePort(), slow, timeout)) {
+    try (GrpcTestServer server = new GrpcTestServer(NetUtils.getFreePort(), warmup, slow, timeout)) {
       final int port = server.start();
       try (GrpcTestClient client = new GrpcTestClient(NetUtils.LOCALHOST, port, function)) {
 
+        if (warmup == 1) {
+          client.send("warmup").join();
+        }
+
         final List<CompletableFuture<String>> futures = new ArrayList<>();
         for (String m : messages) {
           futures.add(client.send(m));
@@ -95,20 +104,20 @@
 
         int i = 0;
         for (; i < slow; i++) {
-          final String expected = i + GrpcTestServer.GreeterImpl.toReplySuffix(messages.get(i));
+          final String expected = (i + warmup) + GrpcTestServer.GreeterImpl.toReplySuffix(messages.get(i));
           final String reply = futures.get(i).get();
-          Assert.assertEquals("expected = " + expected + " != reply = " + reply, expected, reply);
-          LOG.info("{}) passed", i);
+          Assert.assertEquals(expected, reply);
+          LOG.info("{}) passed", (i + warmup));
         }
 
         for (; i < messages.size(); i++) {
           final CompletableFuture<String> f = futures.get(i);
           try {
             final String reply = f.get();
-            Assert.fail(i + ") reply = " + reply + ", "
+            Assert.fail((i + warmup) + ") reply = " + reply + ", "
                 + StringUtils.completableFuture2String(f, false));
           } catch (ExecutionException e) {
-             LOG.info("GOOD! {}) {}, {}", i, StringUtils.completableFuture2String(f, true), e);
+             LOG.info("GOOD! {}) {}, {}", (i + warmup), StringUtils.completableFuture2String(f, true), e);
           }
         }
       }