RATIS-2001. TransactionContext can be wrongly reused. (#1015)

diff --git a/ratis-common/src/main/java/org/apache/ratis/util/Preconditions.java b/ratis-common/src/main/java/org/apache/ratis/util/Preconditions.java
index 36e647f..c757de2 100644
--- a/ratis-common/src/main/java/org/apache/ratis/util/Preconditions.java
+++ b/ratis-common/src/main/java/org/apache/ratis/util/Preconditions.java
@@ -20,6 +20,7 @@
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.function.Supplier;
 
@@ -87,6 +88,11 @@
         () -> name + ": expected == " + expected + " but computed == " + computed);
   }
 
+  static void assertEquals(Object expected, Object computed, String name) {
+    assertTrue(Objects.equals(expected, computed),
+        () -> name + ": expected == " + expected + " but computed == " + computed);
+  }
+
   static void assertNull(Object object, Supplier<Object> message) {
     assertTrue(object == null, message);
   }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
index ce19fda..043c731 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
@@ -1197,12 +1197,12 @@
         && (server.getRaftConf().isSingleton() || lease.isValid());
   }
 
-  void replyPendingRequest(long logIndex, RaftClientReply reply) {
-    pendingRequests.replyPendingRequest(logIndex, reply);
+  void replyPendingRequest(TermIndex termIndex, RaftClientReply reply) {
+    pendingRequests.replyPendingRequest(termIndex, reply);
   }
 
-  TransactionContext getTransactionContext(long index) {
-    return pendingRequests.getTransactionContext(index);
+  TransactionContext getTransactionContext(TermIndex termIndex) {
+    return pendingRequests.getTransactionContext(termIndex);
   }
 
   long[] getFollowerNextIndices() {
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java
index a0b96cc..06a3a7b 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java
@@ -17,28 +17,28 @@
  */
 package org.apache.ratis.server.impl;
 
-import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto.TypeCase;
 import org.apache.ratis.proto.RaftProtos.CommitInfoProto;
 import org.apache.ratis.protocol.*;
 import org.apache.ratis.protocol.exceptions.NotLeaderException;
-import org.apache.ratis.server.raftlog.RaftLog;
+import org.apache.ratis.server.protocol.TermIndex;
 import org.apache.ratis.statemachine.TransactionContext;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.Preconditions;
 
 import java.util.Collection;
+import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
 
-public class PendingRequest implements Comparable<PendingRequest> {
-  private final long index;
+class PendingRequest {
+  private final TermIndex termIndex;
   private final RaftClientRequest request;
   private final TransactionContext entry;
   private final CompletableFuture<RaftClientReply> futureToComplete = new CompletableFuture<>();
   private final CompletableFuture<RaftClientReply> futureToReturn;
 
-  PendingRequest(long index, RaftClientRequest request, TransactionContext entry) {
-    this.index = index;
+  PendingRequest(RaftClientRequest request, TransactionContext entry) {
+    this.termIndex = entry == null? null: TermIndex.valueOf(entry.getLogEntry());
     this.request = request;
     this.entry = entry;
     if (request.is(TypeCase.FORWARD)) {
@@ -49,7 +49,7 @@
   }
 
   PendingRequest(SetConfigurationRequest request) {
-    this(RaftLog.INVALID_LOG_INDEX, request, null);
+    this(request, null);
   }
 
   RaftClientReply convert(RaftClientRequest q, RaftClientReply p) {
@@ -63,8 +63,8 @@
         .build();
   }
 
-  long getIndex() {
-    return index;
+  TermIndex getTermIndex() {
+    return Objects.requireNonNull(termIndex, "termIndex");
   }
 
   RaftClientRequest getRequest() {
@@ -102,13 +102,7 @@
   }
 
   @Override
-  @SuppressFBWarnings("EQ_COMPARETO_USE_OBJECT_EQUALS")
-  public int compareTo(PendingRequest that) {
-    return Long.compare(this.index, that.index);
-  }
-
-  @Override
   public String toString() {
-    return JavaUtils.getClassSimpleName(getClass()) + ":index=" + index + ", request=" + request;
+    return JavaUtils.getClassSimpleName(getClass()) + "-" + termIndex + ":request=" + request;
   }
 }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequests.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequests.java
index dc840b0..259695d 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequests.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequests.java
@@ -28,6 +28,7 @@
 import org.apache.ratis.protocol.SetConfigurationRequest;
 import org.apache.ratis.server.RaftServerConfigKeys;
 import org.apache.ratis.server.metrics.RaftServerMetricsImpl;
+import org.apache.ratis.server.protocol.TermIndex;
 import org.apache.ratis.statemachine.TransactionContext;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.Preconditions;
@@ -96,7 +97,7 @@
 
   private static class RequestMap {
     private final Object name;
-    private final ConcurrentMap<Long, PendingRequest> map = new ConcurrentHashMap<>();
+    private final ConcurrentMap<TermIndex, PendingRequest> map = new ConcurrentHashMap<>();
     private final RaftServerMetricsImpl raftServerMetrics;
 
     /** Permits to put new requests, always synchronized. */
@@ -112,8 +113,8 @@
       this.resource = new RequestLimits(elementLimit, megabyteLimit);
       this.raftServerMetrics = raftServerMetrics;
 
-      raftServerMetrics.addNumPendingRequestsGauge(() -> resource.getElementCount());
-      raftServerMetrics.addNumPendingRequestsMegaByteSize(() -> resource.getMegaByteSize());
+      raftServerMetrics.addNumPendingRequestsGauge(resource::getElementCount);
+      raftServerMetrics.addNumPendingRequestsMegaByteSize(resource::getMegaByteSize);
     }
 
     Permit tryAcquire(Message message) {
@@ -150,27 +151,27 @@
       return permit;
     }
 
-    synchronized PendingRequest put(Permit permit, long index, PendingRequest p) {
-      LOG.debug("{}: PendingRequests.put {} -> {}", name, index, p);
+    synchronized PendingRequest put(Permit permit, PendingRequest p) {
+      LOG.debug("{}: PendingRequests.put {}", name, p);
       final Permit removed = permits.remove(permit);
       if (removed == null) {
         return null;
       }
       Preconditions.assertTrue(removed == permit);
-      final PendingRequest previous = map.put(index, p);
+      final PendingRequest previous = map.put(p.getTermIndex(), p);
       Preconditions.assertTrue(previous == null);
       return p;
     }
 
-    PendingRequest get(long index) {
-      final PendingRequest r = map.get(index);
-      LOG.debug("{}: PendingRequests.get {} returns {}", name, index, r);
+    PendingRequest get(TermIndex termIndex) {
+      final PendingRequest r = map.get(termIndex);
+      LOG.debug("{}: PendingRequests.get {} returns {}", name, termIndex, r);
       return r;
     }
 
-    PendingRequest remove(long index) {
-      final PendingRequest r = map.remove(index);
-      LOG.debug("{}: PendingRequests.remove {} returns {}", name, index, r);
+    PendingRequest remove(TermIndex termIndex) {
+      final PendingRequest r = map.remove(termIndex);
+      LOG.debug("{}: PendingRequests.remove {} returns {}", name, termIndex, r);
       if (r == null) {
         return null;
       }
@@ -193,7 +194,7 @@
       LOG.debug("{}: PendingRequests.setNotLeaderException", name);
       final List<TransactionContext> transactions = new ArrayList<>(map.size());
       for(;;) {
-        final Iterator<Long> i = map.keySet().iterator();
+        final Iterator<TermIndex> i = map.keySet().iterator();
         if (!i.hasNext()) { // the map is empty
           return transactions;
         }
@@ -232,11 +233,8 @@
   }
 
   PendingRequest add(Permit permit, RaftClientRequest request, TransactionContext entry) {
-    // externally synced for now
-    final long index = entry.getLogEntry().getIndex();
-    LOG.debug("{}: addPendingRequest at index={}, request={}", name, index, request);
-    final PendingRequest pending = new PendingRequest(index, request, entry);
-    return pendingRequests.put(permit, index, pending);
+    final PendingRequest pending = new PendingRequest(request, entry);
+    return pendingRequests.put(permit, pending);
   }
 
   PendingRequest addConfRequest(SetConfigurationRequest request) {
@@ -265,17 +263,17 @@
     pendingSetConf = null;
   }
 
-  TransactionContext getTransactionContext(long index) {
-    PendingRequest pendingRequest = pendingRequests.get(index);
+  TransactionContext getTransactionContext(TermIndex termIndex) {
+    final PendingRequest pendingRequest = pendingRequests.get(termIndex);
     // it is possible that the pendingRequest is null if this peer just becomes
     // the new leader and commits transactions received by the previous leader
     return pendingRequest != null ? pendingRequest.getEntry() : null;
   }
 
-  void replyPendingRequest(long index, RaftClientReply reply) {
-    final PendingRequest pending = pendingRequests.remove(index);
+  void replyPendingRequest(TermIndex termIndex, RaftClientReply reply) {
+    final PendingRequest pending = pendingRequests.remove(termIndex);
     if (pending != null) {
-      Preconditions.assertTrue(pending.getIndex() == index);
+      Preconditions.assertEquals(termIndex, pending.getTermIndex(), "termIndex");
       pending.setReply(reply);
     }
   }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
index 37d7f30..de64913 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
@@ -1800,7 +1800,7 @@
    *                           from which we will get transaction result later
    */
   private CompletableFuture<Message> replyPendingRequest(
-      ClientInvocationId invocationId, long logIndex, CompletableFuture<Message> stateMachineFuture) {
+      ClientInvocationId invocationId, TermIndex termIndex, CompletableFuture<Message> stateMachineFuture) {
     // update the retry cache
     final CacheEntry cacheEntry = retryCache.getOrCreateEntry(invocationId);
     Objects.requireNonNull(cacheEntry , "cacheEntry == null");
@@ -1812,8 +1812,8 @@
     }
 
     return stateMachineFuture.whenComplete((reply, exception) -> {
-      transactionManager.remove(logIndex);
-      final RaftClientReply.Builder b = newReplyBuilder(invocationId, logIndex);
+      transactionManager.remove(termIndex);
+      final RaftClientReply.Builder b = newReplyBuilder(invocationId, termIndex.getIndex());
       final RaftClientReply r;
       if (exception == null) {
         r = b.setSuccess().setMessage(reply).build();
@@ -1825,7 +1825,7 @@
       }
 
       // update pending request
-      role.getLeaderState().ifPresent(leader -> leader.replyPendingRequest(logIndex, r));
+      role.getLeaderState().ifPresent(leader -> leader.replyPendingRequest(termIndex, r));
       cacheEntry.updateResult(r);
     });
   }
@@ -1835,18 +1835,19 @@
       return null;
     }
 
+    final TermIndex termIndex = TermIndex.valueOf(entry);
     final Optional<LeaderStateImpl> leader = getRole().getLeaderState();
     if (leader.isPresent()) {
-      final TransactionContext context = leader.get().getTransactionContext(entry.getIndex());
+      final TransactionContext context = leader.get().getTransactionContext(termIndex);
       if (context != null) {
         return context;
       }
     }
 
     if (!createNew) {
-      return transactionManager.get(entry.getIndex());
+      return transactionManager.get(termIndex);
     }
-    return transactionManager.computeIfAbsent(entry.getIndex(),
+    return transactionManager.computeIfAbsent(termIndex,
         // call startTransaction only once
         MemoizedSupplier.valueOf(() -> stateMachine.startTransaction(entry, getInfo().getCurrentRole())));
   }
@@ -1872,7 +1873,7 @@
         trx = stateMachine.applyTransactionSerial(trx);
 
         final CompletableFuture<Message> stateMachineFuture = stateMachine.applyTransaction(trx);
-        return replyPendingRequest(invocationId, next.getIndex(), stateMachineFuture);
+        return replyPendingRequest(invocationId, TermIndex.valueOf(next), stateMachineFuture);
       } catch (Exception e) {
         throw new RaftLogIOException(e);
       }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/TransactionManager.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/TransactionManager.java
index aa989cf..283900f 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/TransactionManager.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/TransactionManager.java
@@ -17,6 +17,7 @@
  */
 package org.apache.ratis.server.impl;
 
+import org.apache.ratis.server.protocol.TermIndex;
 import org.apache.ratis.statemachine.TransactionContext;
 
 import java.util.Optional;
@@ -28,17 +29,17 @@
  * Managing {@link TransactionContext}.
  */
 class TransactionManager {
-  private final ConcurrentMap<Long, Supplier<TransactionContext>> contexts = new ConcurrentHashMap<>();
+  private final ConcurrentMap<TermIndex, Supplier<TransactionContext>> contexts = new ConcurrentHashMap<>();
 
-  TransactionContext get(long index) {
-    return Optional.ofNullable(contexts.get(index)).map(Supplier::get).orElse(null);
+  TransactionContext get(TermIndex termIndex) {
+    return Optional.ofNullable(contexts.get(termIndex)).map(Supplier::get).orElse(null);
   }
 
-  TransactionContext computeIfAbsent(long index, Supplier<TransactionContext> constructor) {
-    return contexts.computeIfAbsent(index, i -> constructor).get();
+  TransactionContext computeIfAbsent(TermIndex termIndex, Supplier<TransactionContext> constructor) {
+    return contexts.computeIfAbsent(termIndex, i -> constructor).get();
   }
 
-  void remove(long index) {
-    contexts.remove(index);
+  void remove(TermIndex termIndex) {
+    contexts.remove(termIndex);
   }
 }
\ No newline at end of file