RATIS-236. Use TimeoutScheduler in RaftClientImpl.
diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java
index 44df5c3..eb78463 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java
@@ -97,7 +97,7 @@
   /** Map: id -> {@link SlidingWindow}, in order to support async calls to the RAFT service or individual servers. */
   private final ConcurrentMap<String, SlidingWindow.Client<PendingAsyncRequest, RaftClientReply>>
       slidingWindows = new ConcurrentHashMap<>();
-  private final ScheduledExecutorService scheduler;
+  private final TimeoutScheduler scheduler;
   private final Semaphore asyncRequestSemaphore;
 
   RaftClientImpl(ClientId clientId, RaftGroup group, RaftPeerId leaderId,
@@ -111,7 +111,7 @@
     this.retryInterval = RaftClientConfigKeys.Rpc.retryInterval(properties);
 
     asyncRequestSemaphore = new Semaphore(RaftClientConfigKeys.Async.maxOutstandingRequests(properties));
-    scheduler = Executors.newScheduledThreadPool(RaftClientConfigKeys.Async.schedulerThreads(properties));
+    scheduler = TimeoutScheduler.newInstance(RaftClientConfigKeys.Async.schedulerThreads(properties));
     clientRpc.addServers(peers);
   }
 
@@ -237,9 +237,10 @@
     final CompletableFuture<RaftClientReply> f = pending.getReplyFuture();
     return sendRequestAsync(request).thenCompose(reply -> {
       if (reply == null) {
-        final TimeUnit unit = retryInterval.getUnit();
-        scheduler.schedule(() -> getSlidingWindow(request).retry(pending, this::sendRequestWithRetryAsync),
-            retryInterval.toLong(unit), unit);
+        LOG.debug("schedule a retry in {} for {}", retryInterval, request);
+        scheduler.onTimeout(retryInterval,
+            () -> getSlidingWindow(request).retry(pending, this::sendRequestWithRetryAsync),
+            LOG, () -> "Failed to retry " + request);
       } else {
         f.complete(reply);
       }
@@ -383,7 +384,7 @@
   }
 
   void assertScheduler(int numThreads) {
-    Preconditions.assertTrue(((ScheduledThreadPoolExecutor) scheduler).getCorePoolSize() == numThreads);
+    Preconditions.assertTrue(scheduler.getNumThreads() == numThreads);
   }
 
   long getCallId() {
diff --git a/ratis-common/src/main/java/org/apache/ratis/util/TimeoutScheduler.java b/ratis-common/src/main/java/org/apache/ratis/util/TimeoutScheduler.java
index 7a7d16c..7007a53 100644
--- a/ratis-common/src/main/java/org/apache/ratis/util/TimeoutScheduler.java
+++ b/ratis-common/src/main/java/org/apache/ratis/util/TimeoutScheduler.java
@@ -22,6 +22,7 @@
 
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
@@ -32,14 +33,10 @@
 
   private static final TimeDuration DEFAULT_GRACE_PERIOD = TimeDuration.valueOf(1, TimeUnit.MINUTES);
 
-  private static final Supplier<TimeoutScheduler> INSTANCE = JavaUtils.memoize(TimeoutScheduler::new);
-
-  public static TimeoutScheduler getInstance() {
-    return INSTANCE.get();
+  public static TimeoutScheduler newInstance(int numThreads) {
+    return new TimeoutScheduler(numThreads);
   }
 
-  private TimeoutScheduler() {}
-
   /** When there is no tasks, the time period to wait before shutting down the scheduler. */
   private final AtomicReference<TimeDuration> gracePeriod = new AtomicReference<>(DEFAULT_GRACE_PERIOD);
 
@@ -47,7 +44,28 @@
   private int numTasks = 0;
   /** The scheduleID for each task */
   private int scheduleID = 0;
-  private ScheduledExecutorService scheduler = null;
+
+  /** The number of threads passed to the executor when it is created. */
+  private final int numThreads;
+  /** The executor, created lazily when a task is scheduled; volatile since getNumThreads() reads it without locking. */
+  private volatile ScheduledExecutorService scheduler = null;
+
+  private TimeoutScheduler(int numThreads) {
+    // Fail fast: otherwise a negative value is only rejected later, when the executor is lazily created.
+    if (numThreads < 0) {
+      throw new IllegalArgumentException("numThreads = " + numThreads + " < 0");
+    }
+    this.numThreads = numThreads;
+  }
+
+  /**
+   * @return the core pool size of the executor if it currently exists;
+   *         otherwise, the number of threads this scheduler was created with.
+   */
+  public int getNumThreads() {
+    final ScheduledExecutorService s = scheduler;
+    return s instanceof ScheduledThreadPoolExecutor? ((ScheduledThreadPoolExecutor)s).getCorePoolSize(): numThreads;
+  }
 
   TimeDuration getGracePeriod() {
     return gracePeriod.get();
@@ -86,13 +104,13 @@
     if (scheduler == null) {
       Preconditions.assertTrue(numTasks == 0);
       LOG.debug("Initialize scheduler");
-      scheduler = Executors.newScheduledThreadPool(1);
+      scheduler = Executors.newScheduledThreadPool(numThreads);
     }
     numTasks++;
     final int sid = scheduleID++;
 
     LOG.debug("schedule a task: timeout {}, sid {}", timeout, sid);
-    scheduler.schedule(() -> toSchedule.accept(sid), timeout.getDuration(), timeout.getUnit());
+    schedule(scheduler, () -> toSchedule.accept(sid), () -> "task #" + sid, timeout);
   }
 
   private synchronized void onTaskCompleted() {
@@ -100,10 +118,19 @@
       final int sid = scheduleID;
       final TimeDuration grace = getGracePeriod();
       LOG.debug("Schedule a shutdown task: grace {}, sid {}", grace, sid);
-      scheduler.schedule(() -> tryShutdownScheduler(sid), grace.getDuration(), grace.getUnit());
+      schedule(scheduler, () -> tryShutdownScheduler(sid), () -> "shutdown task #" + sid, grace);
     }
   }
 
+  /**
+   * Schedule the given task to run once on the given service after the given duration.
+   * The task is wrapped by {@code LogUtils.newRunnable(LOG, task, name)}, presumably for logging; see LogUtils.
+   */
+  static void schedule(ScheduledExecutorService service, Runnable task, Supplier<String> name, TimeDuration timeDuration) {
+    service.schedule(LogUtils.newRunnable(LOG, task, name),
+        timeDuration.getDuration(), timeDuration.getUnit());
+  }
+
   private synchronized void tryShutdownScheduler(int sid) {
     if (sid == scheduleID) {
       // No new tasks submitted, shutdown the scheduler.
@@ -116,8 +143,7 @@
   }
 
   /** When timeout, run the task.  Log the error, if there is any. */
-  public static <THROWABLE extends Throwable> void onTimeout(
-      TimeDuration timeout, CheckedRunnable<THROWABLE> task, Logger log, Supplier<String> errorMessage) {
-    getInstance().onTimeout(timeout, task, t -> log.error(errorMessage.get(), t));
+  public void onTimeout(TimeDuration timeout, CheckedRunnable<?> task, Logger log, Supplier<String> errorMessage) {
+    onTimeout(timeout, task, t -> log.error(errorMessage.get(), t));
   }
 }
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/RaftClientProtocolClient.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/RaftClientProtocolClient.java
index d01bbe8..2d095ab 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/RaftClientProtocolClient.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/RaftClientProtocolClient.java
@@ -48,7 +48,6 @@
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
@@ -59,7 +58,10 @@
   private final Supplier<String> name;
   private final RaftPeer target;
   private final ManagedChannel channel;
+
   private final TimeDuration requestTimeoutDuration;
+  private final TimeoutScheduler scheduler = TimeoutScheduler.newInstance(1);
+
   private final RaftClientProtocolServiceBlockingStub blockingStub;
   private final RaftClientProtocolServiceStub asyncStub;
   private final AdminProtocolServiceBlockingStub adminBlockingStub;
@@ -188,7 +190,7 @@
           () -> getName() + ":" + getClass().getSimpleName());
       try {
         requestStreamObserver.onNext(ClientProtoUtils.toRaftClientRequestProto(request));
-        TimeoutScheduler.onTimeout(requestTimeoutDuration, () -> timeoutCheck(request), LOG,
+        scheduler.onTimeout(requestTimeoutDuration, () -> timeoutCheck(request), LOG,
             () -> "Timeout check failed for client request: " + request);
       } catch(Throwable t) {
         handleReplyFuture(request.getCallId(), future -> future.completeExceptionally(t));
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GRpcLogAppender.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GRpcLogAppender.java
index cb18caa..9a2163c 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GRpcLogAppender.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GRpcLogAppender.java
@@ -48,7 +48,8 @@
   private long callId = 0;
   private volatile boolean firstResponseReceived = false;
 
-  private static TimeDuration requestTimeoutDuration;
+  private final TimeDuration requestTimeoutDuration;
+  private final TimeoutScheduler scheduler = TimeoutScheduler.newInstance(1);
 
   private volatile StreamObserver<AppendEntriesRequestProto> appendLogRequestObserver;
 
@@ -172,7 +173,7 @@
         server.getId(), null, request);
 
     s.onNext(request);
-    TimeoutScheduler.onTimeout(requestTimeoutDuration, () -> timeoutAppendRequest(request), LOG,
+    scheduler.onTimeout(requestTimeoutDuration, () -> timeoutAppendRequest(request), LOG,
         () -> "Timeout check failed for append entry request: " + request);
     follower.updateLastRpcSendTime();
   }
diff --git a/ratis-server/src/test/java/org/apache/ratis/RaftAsyncTests.java b/ratis-server/src/test/java/org/apache/ratis/RaftAsyncTests.java
index d8909de..a343f76 100644
--- a/ratis-server/src/test/java/org/apache/ratis/RaftAsyncTests.java
+++ b/ratis-server/src/test/java/org/apache/ratis/RaftAsyncTests.java
@@ -63,8 +63,6 @@
   public void setup() {
     getProperties().setClass(MiniRaftCluster.STATEMACHINE_CLASS_KEY,
         SimpleStateMachine4Testing.class, StateMachine.class);
-    TimeDuration retryCacheExpiryDuration = TimeDuration.valueOf(5, TimeUnit.SECONDS);
-    RaftServerConfigKeys.RetryCache.setExpiryTime(getProperties(), retryCacheExpiryDuration);
   }
 
   @Test
@@ -144,12 +142,12 @@
     cluster.shutdown();
   }
 
-  void runTestBasicAppendEntriesAsync(ReplicationLevel replication) throws Exception {
-    final CLUSTER cluster = newCluster(NUM_SERVERS);
+  void runTestBasicAppendEntriesAsync(ReplicationLevel replication, boolean killLeader) throws Exception {
+    final CLUSTER cluster = newCluster(killLeader? 5: 3);
     try {
       cluster.start();
       waitForLeader(cluster);
-      RaftBasicTests.runTestBasicAppendEntries(true, replication, 1000, cluster, LOG);
+      RaftBasicTests.runTestBasicAppendEntries(true, replication, killLeader, 1000, cluster, LOG);
     } finally {
       cluster.shutdown();
     }
@@ -157,12 +155,17 @@
 
   @Test
   public void testBasicAppendEntriesAsync() throws Exception {
-    runTestBasicAppendEntriesAsync(ReplicationLevel.MAJORITY);
+    runTestBasicAppendEntriesAsync(ReplicationLevel.MAJORITY, false);
+  }
+
+  @Test
+  public void testBasicAppendEntriesAsyncKillLeader() throws Exception {
+    runTestBasicAppendEntriesAsync(ReplicationLevel.MAJORITY, true);
   }
 
   @Test
   public void testBasicAppendEntriesAsyncWithAllReplication() throws Exception {
-    runTestBasicAppendEntriesAsync(ReplicationLevel.ALL);
+    runTestBasicAppendEntriesAsync(ReplicationLevel.ALL, false);
   }
 
   @Test
@@ -253,10 +256,20 @@
 
   @Test
   public void testRequestTimeout() throws Exception {
-    final CLUSTER cluster = newCluster(NUM_SERVERS);
-    cluster.start();
-    RaftBasicTests.testRequestTimeout(true, cluster, LOG);
-    cluster.shutdown();
+    final TimeDuration oldExpiryTime = RaftServerConfigKeys.RetryCache.expiryTime(getProperties());
+    RaftServerConfigKeys.RetryCache.setExpiryTime(getProperties(), TimeDuration.valueOf(5, TimeUnit.SECONDS));
+    try {
+      final CLUSTER cluster = newCluster(NUM_SERVERS);
+      try {
+        cluster.start();
+        RaftBasicTests.testRequestTimeout(true, cluster, LOG);
+      } finally {
+        cluster.shutdown();
+      }
+    } finally {
+      // Reset even when the test fails, so that the short expiry time does not leak into the other tests.
+      RaftServerConfigKeys.RetryCache.setExpiryTime(getProperties(), oldExpiryTime);
+    }
   }
 
   @Test
diff --git a/ratis-server/src/test/java/org/apache/ratis/RaftBasicTests.java b/ratis-server/src/test/java/org/apache/ratis/RaftBasicTests.java
index 64b3c7e..f35878c 100644
--- a/ratis-server/src/test/java/org/apache/ratis/RaftBasicTests.java
+++ b/ratis-server/src/test/java/org/apache/ratis/RaftBasicTests.java
@@ -98,12 +98,17 @@
 
   @Test
   public void testBasicAppendEntries() throws Exception {
-    runTestBasicAppendEntries(false, ReplicationLevel.MAJORITY, 10, getCluster(), LOG);
+    runTestBasicAppendEntries(false, ReplicationLevel.MAJORITY, false, 10, getCluster(), LOG);
+  }
+
+  @Test
+  public void testBasicAppendEntriesKillLeader() throws Exception {
+    runTestBasicAppendEntries(false, ReplicationLevel.MAJORITY, true, 10, getCluster(), LOG);
   }
 
   @Test
   public void testBasicAppendEntriesWithAllReplication() throws Exception {
-    runTestBasicAppendEntries(false, ReplicationLevel.ALL, 10, getCluster(), LOG);
+    runTestBasicAppendEntries(false, ReplicationLevel.ALL, false, 10, getCluster(), LOG);
   }
 
   static void killAndRestartServer(RaftPeerId id, long killSleepMs, long restartSleepMs, MiniRaftCluster cluster, Logger LOG) {
@@ -119,9 +124,10 @@
   }
 
   static void runTestBasicAppendEntries(
-      boolean async, ReplicationLevel replication, int numMessages, MiniRaftCluster cluster, Logger LOG) throws Exception {
-    LOG.info("runTestBasicAppendEntries: async? {}, replication={}, numMessages={}",
-        async, replication, numMessages);
+      boolean async, ReplicationLevel replication, boolean killLeader, int numMessages, MiniRaftCluster cluster, Logger LOG)
+      throws Exception {
+    LOG.info("runTestBasicAppendEntries: async? {}, replication={}, killLeader={}, numMessages={}",
+        async, replication, killLeader, numMessages);
     for (RaftServer s : cluster.getServers()) {
       cluster.restartServer(s.getId(), false);
     }
@@ -129,6 +135,11 @@
     final long term = leader.getState().getCurrentTerm();
 
     new Thread(() -> killAndRestartServer(cluster.getFollowers().get(0).getId(), 0, 1000, cluster, LOG)).start();
+    if (killLeader) {
+      // Also kill and restart the leader in a separate thread: killSleepMs = 2000, restartSleepMs = 4000.
+      LOG.info("killAndRestart leader {}", leader.getId());
+      new Thread(() -> killAndRestartServer(leader.getId(), 2000, 4000, cluster, LOG)).start();
+    }
 
     LOG.info(cluster.printServers());
 
diff --git a/ratis-server/src/test/java/org/apache/ratis/util/TestTimeoutScheduler.java b/ratis-server/src/test/java/org/apache/ratis/util/TestTimeoutScheduler.java
index 7c4ef4f..6a63569 100644
--- a/ratis-server/src/test/java/org/apache/ratis/util/TestTimeoutScheduler.java
+++ b/ratis-server/src/test/java/org/apache/ratis/util/TestTimeoutScheduler.java
@@ -46,7 +46,7 @@
 
   @Test(timeout = 1000)
   public void testSingleTask() throws Exception {
-    final TimeoutScheduler scheduler = TimeoutScheduler.getInstance();
+    final TimeoutScheduler scheduler = TimeoutScheduler.newInstance(1);
     final TimeDuration grace = TimeDuration.valueOf(100, TimeUnit.MILLISECONDS);
     scheduler.setGracePeriod(grace);
     Assert.assertFalse(scheduler.hasScheduler());
@@ -81,7 +81,7 @@
 
   @Test(timeout = 1000)
   public void testMultipleTasks() throws Exception {
-    final TimeoutScheduler scheduler = TimeoutScheduler.getInstance();
+    final TimeoutScheduler scheduler = TimeoutScheduler.newInstance(1);
     final TimeDuration grace = TimeDuration.valueOf(100, TimeUnit.MILLISECONDS);
     scheduler.setGracePeriod(grace);
     Assert.assertFalse(scheduler.hasScheduler());
@@ -127,7 +127,7 @@
 
   @Test(timeout = 1000)
   public void testExtendingGracePeriod() throws Exception {
-    final TimeoutScheduler scheduler = TimeoutScheduler.getInstance();
+    final TimeoutScheduler scheduler = TimeoutScheduler.newInstance(1);
     final TimeDuration grace = TimeDuration.valueOf(100, TimeUnit.MILLISECONDS);
     scheduler.setGracePeriod(grace);
     Assert.assertFalse(scheduler.hasScheduler());
@@ -177,7 +177,7 @@
 
   @Test(timeout = 1000)
   public void testRestartingScheduler() throws Exception {
-    final TimeoutScheduler scheduler = TimeoutScheduler.getInstance();
+    final TimeoutScheduler scheduler = TimeoutScheduler.newInstance(1);
     final TimeDuration grace = TimeDuration.valueOf(100, TimeUnit.MILLISECONDS);
     scheduler.setGracePeriod(grace);
     Assert.assertFalse(scheduler.hasScheduler());