RATIS-1983. Refactor client request processing to support reference count. (#998)

diff --git a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientAsynchronousProtocol.java b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientAsynchronousProtocol.java
index 1a9f83c..1985bbe 100644
--- a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientAsynchronousProtocol.java
+++ b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientAsynchronousProtocol.java
@@ -1,4 +1,4 @@
-/**
+/*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
@@ -17,12 +17,40 @@
  */
 package org.apache.ratis.protocol;
 
+import org.apache.ratis.util.JavaUtils;
+import org.apache.ratis.util.ReferenceCountedObject;
+
 import java.io.IOException;
 import java.util.concurrent.CompletableFuture;
 
 /** Asynchronous version of {@link RaftClientProtocol}. */
 public interface RaftClientAsynchronousProtocol {
-  CompletableFuture<RaftClientReply> submitClientRequestAsync(
-      RaftClientRequest request) throws IOException;
+  /**
+   * It is recommended to override {@link #submitClientRequestAsync(ReferenceCountedObject)} instead.
+   * Then, it does not have to override this method.
+   */
+  default CompletableFuture<RaftClientReply> submitClientRequestAsync(
+      RaftClientRequest request) throws IOException {
+    return submitClientRequestAsync(ReferenceCountedObject.wrap(request));
+  }
 
+  /**
+   * A referenced counted request is submitted from a client for processing.
+   * Implementations of this method should retain the request, process it and then release it.
+   * The request may be retained even after the future returned by this method has completed.
+   *
+   * @return a future of the reply
+   * @see ReferenceCountedObject
+   */
+  default CompletableFuture<RaftClientReply> submitClientRequestAsync(
+      ReferenceCountedObject<RaftClientRequest> requestRef) {
+    try {
+      // for backward compatibility
+      return submitClientRequestAsync(requestRef.retain())
+          .whenComplete((r, e) -> requestRef.release());
+    } catch (Exception e) {
+      requestRef.release();
+      return JavaUtils.completeExceptionally(e);
+    }
+  }
 }
\ No newline at end of file
diff --git a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java
index 0dd378d..3f72f5f 100644
--- a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java
+++ b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java
@@ -102,6 +102,30 @@
   }
 
   /**
+   * @return a {@link ReferenceCountedObject} of the given value by delegating to this object.
+   */
+  default <V> ReferenceCountedObject<V> delegate(V value) {
+    final ReferenceCountedObject<T> delegated = this;
+    return new ReferenceCountedObject<V>() {
+      @Override
+      public V get() {
+        return value;
+      }
+
+      @Override
+      public V retain() {
+        delegated.retain();
+        return value;
+      }
+
+      @Override
+      public boolean release() {
+        return delegated.release();
+      }
+    };
+  }
+
+  /**
    * Wrap the given value as a {@link ReferenceCountedObject}.
    *
    * @param value the value being wrapped.
diff --git a/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java b/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java
index 5f258ee..858e300 100644
--- a/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java
+++ b/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java
@@ -32,6 +32,7 @@
 import org.apache.ratis.protocol.RaftClientRequest;
 import org.apache.ratis.protocol.RaftGroupId;
 import org.apache.ratis.server.RaftServer;
+import org.apache.ratis.server.raftlog.LogProtoUtils;
 import org.apache.ratis.server.storage.RaftStorage;
 import org.apache.ratis.statemachine.StateMachineStorage;
 import org.apache.ratis.statemachine.TransactionContext;
@@ -40,6 +41,7 @@
 import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
 import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException;
 import org.apache.ratis.util.FileUtils;
+import org.apache.ratis.util.JavaUtils;
 
 import java.io.IOException;
 import java.nio.file.Path;
@@ -168,9 +170,11 @@
   }
 
   static class LocalStream implements DataStream {
+    private final String name;
     private final DataChannel dataChannel;
 
-    LocalStream(DataChannel dataChannel) {
+    LocalStream(String name, DataChannel dataChannel) {
+      this.name = JavaUtils.getClassSimpleName(getClass()) + "[" + name + "]";
       this.dataChannel = dataChannel;
     }
 
@@ -190,6 +194,11 @@
         }
       });
     }
+
+    @Override
+    public String toString() {
+      return name;
+    }
   }
 
   @Override
@@ -202,13 +211,14 @@
       return FileStoreCommon.completeExceptionally(
           "Failed to parse stream header", e);
     }
-    return files.createDataChannel(proto.getStream().getPath().toStringUtf8())
-        .thenApply(LocalStream::new);
+    final String file = proto.getStream().getPath().toStringUtf8();
+    return files.createDataChannel(file)
+        .thenApply(channel -> new LocalStream(file, channel));
   }
 
   @Override
   public CompletableFuture<?> link(DataStream stream, LogEntryProto entry) {
-    LOG.info("linking {}", stream);
+    LOG.info("linking {} to {}", stream, LogProtoUtils.toLogEntryString(entry));
     return files.streamLink(stream);
   }
 
diff --git a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java
index 3821b05..e019074 100644
--- a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java
+++ b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java
@@ -23,6 +23,7 @@
 import org.apache.ratis.protocol.RaftClientRequest;
 import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
 import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.ReflectionUtils;
 
 import java.io.IOException;
@@ -98,6 +99,13 @@
    */
   LogEntryProto getLogEntry();
 
+  /** Wrap the given log entry as a {@link ReferenceCountedObject} for retaining it for later use. */
+  default ReferenceCountedObject<LogEntryProto> wrap(LogEntryProto entry) {
+    Preconditions.assertSame(getLogEntry().getTerm(), entry.getTerm(), "entry.term");
+    Preconditions.assertSame(getLogEntry().getIndex(), entry.getIndex(), "entry.index");
+    return ReferenceCountedObject.wrap(entry);
+  }
+
   /**
    * Sets whether to commit the transaction to the RAFT log or not
    * @param shouldCommit true if the transaction is supposed to be committed to the RAFT log
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
index ed5457b..73451bf 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
@@ -42,6 +42,7 @@
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import javax.management.ObjectName;
+
 import org.apache.ratis.client.impl.ClientProtoUtils;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.metrics.Timekeeper;
@@ -138,6 +139,7 @@
 import org.apache.ratis.util.MemoizedSupplier;
 import org.apache.ratis.util.Preconditions;
 import org.apache.ratis.util.ProtoUtils;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.TimeDuration;
 import org.apache.ratis.util.Timestamp;
 import org.apache.ratis.util.function.CheckedSupplier;
@@ -822,15 +824,21 @@
   }
 
   /**
-   * Handle a normal update request from client.
+   * Append a transaction to the log for processing a client request.
+   * Note that the given request could be different from {@link TransactionContext#getClientRequest()}
+   * since the request could be converted; see {@link #convertRaftClientRequest(RaftClientRequest)}.
+   *
+   * @param request The client request.
+   * @param context The context of the transaction.
+   * @param cacheEntry the entry in the retry cache.
+   * @return a future of the reply.
    */
   private CompletableFuture<RaftClientReply> appendTransaction(
-      RaftClientRequest request, TransactionContextImpl context, CacheEntry cacheEntry) throws IOException {
+      RaftClientRequest request, TransactionContextImpl context, CacheEntry cacheEntry) {
+    Objects.requireNonNull(request, "request == null");
     CodeInjectionForTesting.execute(APPEND_TRANSACTION, getId(),
         request.getClientId(), request, context, cacheEntry);
 
-    assertLifeCycleState(LifeCycle.States.RUNNING);
-
     final PendingRequest pending;
     synchronized (this) {
       final CompletableFuture<RaftClientReply> reply = checkLeaderState(request, cacheEntry);
@@ -849,6 +857,7 @@
         return cacheEntry.getReplyFuture();
       }
       try {
+        assertLifeCycleState(LifeCycle.States.RUNNING);
         state.appendLog(context);
       } catch (StateMachineException e) {
         // the StateMachineException is thrown by the SM in the preAppend stage.
@@ -860,6 +869,9 @@
           leaderState.submitStepDownEvent(LeaderState.StepDownReason.STATE_MACHINE_EXCEPTION);
         }
         return CompletableFuture.completedFuture(exceptionReply);
+      } catch (ServerNotReadyException e) {
+        final RaftClientReply exceptionReply = newExceptionReply(request, e);
+        return CompletableFuture.completedFuture(exceptionReply);
       }
 
       // put the request into the pending queue
@@ -878,11 +890,13 @@
     role.getLeaderState().ifPresent(leader -> leader.submitStepDownEvent(LeaderState.StepDownReason.JVM_PAUSE));
   }
 
-  private RaftClientRequest filterDataStreamRaftClientRequest(RaftClientRequest request)
-      throws InvalidProtocolBufferException {
-    return !request.is(TypeCase.FORWARD) ? request : ClientProtoUtils.toRaftClientRequest(
-        RaftClientRequestProto.parseFrom(
-            request.getMessage().getContent().asReadOnlyByteBuffer()));
+  /** If the given request is {@link TypeCase#FORWARD}, convert it. */
+  static RaftClientRequest convertRaftClientRequest(RaftClientRequest request) throws InvalidProtocolBufferException {
+    if (!request.is(TypeCase.FORWARD)) {
+      return request;
+    }
+    return ClientProtoUtils.toRaftClientRequest(RaftClientRequestProto.parseFrom(
+        request.getMessage().getContent().asReadOnlyByteBuffer()));
   }
 
   <REPLY> CompletableFuture<REPLY> executeSubmitServerRequestAsync(
@@ -892,20 +906,29 @@
         serverExecutor).join();
   }
 
-  CompletableFuture<RaftClientReply> executeSubmitClientRequestAsync(RaftClientRequest request) {
-    return CompletableFuture.supplyAsync(
-        () -> JavaUtils.callAsUnchecked(() -> submitClientRequestAsync(request), CompletionException::new),
-        clientExecutor).join();
+  CompletableFuture<RaftClientReply> executeSubmitClientRequestAsync(
+      ReferenceCountedObject<RaftClientRequest> request) {
+    return CompletableFuture.supplyAsync(() -> submitClientRequestAsync(request), clientExecutor).join();
   }
 
   @Override
   public CompletableFuture<RaftClientReply> submitClientRequestAsync(
-      RaftClientRequest request) throws IOException {
-    assertLifeCycleState(LifeCycle.States.RUNNING);
+      ReferenceCountedObject<RaftClientRequest> requestRef) {
+    final RaftClientRequest request = requestRef.retain();
     LOG.debug("{}: receive client request({})", getMemberId(), request);
+
+    try {
+      assertLifeCycleState(LifeCycle.States.RUNNING);
+    } catch (ServerNotReadyException e) {
+      final RaftClientReply reply = newExceptionReply(request, e);
+      requestRef.release();
+      return CompletableFuture.completedFuture(reply);
+    }
+
     final Timekeeper timer = raftServerMetrics.getClientRequestTimer(request.getType());
     final Optional<Timekeeper.Context> timerContext = Optional.ofNullable(timer).map(Timekeeper::time);
-    return replyFuture(request).whenComplete((clientReply, exception) -> {
+    return replyFuture(requestRef).whenComplete((clientReply, exception) -> {
+      requestRef.release();
       timerContext.ifPresent(Timekeeper.Context::stop);
       if (exception != null || clientReply.getException() != null) {
         raftServerMetrics.incFailedRequestCount(request.getType());
@@ -913,7 +936,8 @@
     });
   }
 
-  private CompletableFuture<RaftClientReply> replyFuture(RaftClientRequest request) throws IOException {
+  private CompletableFuture<RaftClientReply> replyFuture(ReferenceCountedObject<RaftClientRequest> requestRef) {
+    final RaftClientRequest request = requestRef.get();
     retryCache.invalidateRepliedRequests(request);
 
     final TypeCase type = request.getType().getTypeCase();
@@ -925,16 +949,17 @@
       case WATCH:
         return watchAsync(request);
       case MESSAGESTREAM:
-        return messageStreamAsync(request);
+        return messageStreamAsync(requestRef);
       case WRITE:
       case FORWARD:
-        return writeAsync(request);
+        return writeAsync(requestRef);
       default:
         throw new IllegalStateException("Unexpected request type: " + type + ", request=" + request);
     }
   }
 
-  private CompletableFuture<RaftClientReply> writeAsync(RaftClientRequest request) throws IOException {
+  private CompletableFuture<RaftClientReply> writeAsync(ReferenceCountedObject<RaftClientRequest> requestRef) {
+    final RaftClientRequest request = requestRef.get();
     final CompletableFuture<RaftClientReply> reply = checkLeaderState(request);
     if (reply != null) {
       return reply;
@@ -950,8 +975,15 @@
     // TODO: this client request will not be added to pending requests until
     // later which means that any failure in between will leave partial state in
     // the state machine. We should call cancelTransaction() for failed requests
-    final TransactionContextImpl context = (TransactionContextImpl) stateMachine.startTransaction(
-        filterDataStreamRaftClientRequest(request));
+    final TransactionContextImpl context;
+    try {
+      context = (TransactionContextImpl) stateMachine.startTransaction(convertRaftClientRequest(request));
+    } catch (IOException e) {
+      final RaftClientReply exceptionReply = newExceptionReply(request,
+          new RaftException("Failed to startTransaction for " + request, e));
+      cacheEntry.failWithReply(exceptionReply);
+      return CompletableFuture.completedFuture(exceptionReply);
+    }
     if (context.getException() != null) {
       final StateMachineException e = new StateMachineException(getMemberId(), context.getException());
       final RaftClientReply exceptionReply = newExceptionReply(request, e);
@@ -959,6 +991,7 @@
       return CompletableFuture.completedFuture(exceptionReply);
     }
 
+    context.setDelegatedRef(requestRef);
     return appendTransaction(request, context, cacheEntry);
   }
 
@@ -1062,7 +1095,8 @@
     }
   }
 
-  private CompletableFuture<RaftClientReply> messageStreamAsync(RaftClientRequest request) throws IOException {
+  private CompletableFuture<RaftClientReply> messageStreamAsync(ReferenceCountedObject<RaftClientRequest> requestRef) {
+    final RaftClientRequest request = requestRef.get();
     final CompletableFuture<RaftClientReply> reply = checkLeaderState(request);
     if (reply != null) {
       return reply;
@@ -1074,7 +1108,7 @@
         return f.thenApply(r -> null);
       }
       // the message stream has ended and the request become a WRITE request
-      return replyFuture(f.join());
+      return replyFuture(requestRef.delegate(f.join()));
     }
 
     return role.getLeaderState()
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerProxy.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerProxy.java
index fd80d69..cb7918e 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerProxy.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerProxy.java
@@ -52,6 +52,7 @@
 import org.apache.ratis.util.MemoizedSupplier;
 import org.apache.ratis.util.Preconditions;
 import org.apache.ratis.util.ProtoUtils;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.TimeDuration;
 
 import java.io.Closeable;
@@ -444,9 +445,15 @@
   }
 
   @Override
-  public CompletableFuture<RaftClientReply> submitClientRequestAsync(RaftClientRequest request) {
-    return getImplFuture(request.getRaftGroupId())
-        .thenCompose(impl -> impl.executeSubmitClientRequestAsync(request));
+  public CompletableFuture<RaftClientReply> submitClientRequestAsync(
+      ReferenceCountedObject<RaftClientRequest> requestRef) {
+    final RaftClientRequest request = requestRef.retain();
+    try {
+      return getImplFuture(request.getRaftGroupId())
+          .thenCompose(impl -> impl.executeSubmitClientRequestAsync(requestRef));
+    } finally {
+      requestRef.release();
+    }
   }
 
   @Override
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java
index b8e0e72..1e0ef66 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java
@@ -31,6 +31,7 @@
 import org.apache.ratis.thirdparty.com.google.protobuf.CodedOutputStream;
 import org.apache.ratis.util.FileUtils;
 import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.SizeInBytes;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -41,6 +42,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
@@ -66,17 +68,20 @@
   }
 
   static long getEntrySize(LogEntryProto entry, Op op) {
-    LogEntryProto e = entry;
-    if (op == Op.CHECK_SEGMENT_FILE_FULL) {
-      e = LogProtoUtils.removeStateMachineData(entry);
-    } else if (op == Op.LOAD_SEGMENT_FILE || op == Op.WRITE_CACHE_WITH_STATE_MACHINE_CACHE) {
-      Preconditions.assertTrue(entry == LogProtoUtils.removeStateMachineData(entry),
-          () -> "Unexpected LogEntryProto with StateMachine data: op=" + op + ", entry=" + entry);
-    } else {
-      Preconditions.assertTrue(op == Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE || op == Op.REMOVE_CACHE,
-          () -> "Unexpected op " + op + ", entry=" + entry);
+    switch (op) {
+      case CHECK_SEGMENT_FILE_FULL:
+      case LOAD_SEGMENT_FILE:
+      case WRITE_CACHE_WITH_STATE_MACHINE_CACHE:
+        Preconditions.assertTrue(entry == LogProtoUtils.removeStateMachineData(entry),
+            () -> "Unexpected LogEntryProto with StateMachine data: op=" + op + ", entry=" + entry);
+        break;
+      case WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE:
+      case REMOVE_CACHE:
+        break;
+      default:
+        throw new IllegalStateException("Unexpected op " + op + ", entry=" + entry);
     }
-    final int serialized = e.getSerializedSize();
+    final int serialized = entry.getSerializedSize();
     return serialized + CodedOutputStream.computeUInt32SizeNoTag(serialized) + 4L;
   }
 
@@ -123,7 +128,8 @@
   }
 
   public static int readSegmentFile(File file, LogSegmentStartEnd startEnd, SizeInBytes maxOpSize,
-      CorruptionPolicy corruptionPolicy, SegmentedRaftLogMetrics raftLogMetrics, Consumer<LogEntryProto> entryConsumer)
+      CorruptionPolicy corruptionPolicy, SegmentedRaftLogMetrics raftLogMetrics,
+      Consumer<ReferenceCountedObject<LogEntryProto>> entryConsumer)
       throws IOException {
     int count = 0;
     try (SegmentedRaftLogInputStream in = new SegmentedRaftLogInputStream(
@@ -135,7 +141,8 @@
         }
 
         if (entryConsumer != null) {
-          entryConsumer.accept(next);
+          // TODO: use reference count to support zero buffer copying for readSegmentFile
+          entryConsumer.accept(ReferenceCountedObject.wrap(next));
         }
         count++;
       }
@@ -162,10 +169,7 @@
     final CorruptionPolicy corruptionPolicy = CorruptionPolicy.get(storage, RaftStorage::getLogCorruptionPolicy);
     final boolean isOpen = startEnd.isOpen();
     final int entryCount = readSegmentFile(file, startEnd, maxOpSize, corruptionPolicy, raftLogMetrics, entry -> {
-      segment.append(keepEntryInCache || isOpen, entry, Op.LOAD_SEGMENT_FILE);
-      if (logConsumer != null) {
-        logConsumer.accept(entry);
-      }
+      segment.append(Op.LOAD_SEGMENT_FILE, entry, keepEntryInCache || isOpen, logConsumer);
     });
     LOG.info("Successfully read {} entries from segment file {}", entryCount, file);
 
@@ -233,10 +237,10 @@
       // the on-disk log file should be truncated but has not been done yet.
       final AtomicReference<LogEntryProto> toReturn = new AtomicReference<>();
       final LogSegmentStartEnd startEnd = LogSegmentStartEnd.valueOf(startIndex, endIndex, isOpen);
-      readSegmentFile(file, startEnd, maxOpSize,
-          getLogCorruptionPolicy(), raftLogMetrics, entry -> {
+      readSegmentFile(file, startEnd, maxOpSize, getLogCorruptionPolicy(), raftLogMetrics, entryRef -> {
+        final LogEntryProto entry = entryRef.retain();
         final TermIndex ti = TermIndex.valueOf(entry);
-        putEntryCache(ti, entry, Op.LOAD_SEGMENT_FILE);
+        putEntryCache(ti, entryRef, Op.LOAD_SEGMENT_FILE);
         if (ti.equals(key.getTermIndex())) {
           toReturn.set(entry);
         }
@@ -246,13 +250,48 @@
     }
   }
 
+  static class EntryCache {
+    private final Map<TermIndex, ReferenceCountedObject<LogEntryProto>> map = new ConcurrentHashMap<>();
+    private final AtomicLong size = new AtomicLong();
+
+    long size() {
+      return size.get();
+    }
+
+    LogEntryProto get(TermIndex ti) {
+      return Optional.ofNullable(map.get(ti))
+          .map(ReferenceCountedObject::get)
+          .orElse(null);
+    }
+
+    void clear() {
+      map.values().forEach(ReferenceCountedObject::release);
+      map.clear();
+      size.set(0);
+    }
+
+    void put(TermIndex key, ReferenceCountedObject<LogEntryProto> valueRef, Op op) {
+      valueRef.retain();
+      Optional.ofNullable(map.put(key, valueRef)).ifPresent(this::release);
+      size.getAndAdd(getEntrySize(valueRef.get(), op));
+    }
+
+    private void release(ReferenceCountedObject<LogEntryProto> entry) {
+      size.getAndAdd(-getEntrySize(entry.get(), Op.REMOVE_CACHE));
+      entry.release();
+    }
+
+    void remove(TermIndex key) {
+      Optional.ofNullable(map.remove(key)).ifPresent(this::release);
+    }
+  }
+
   File getFile() {
     return LogSegmentStartEnd.valueOf(startIndex, endIndex, isOpen).getFile(storage);
   }
 
   private volatile boolean isOpen;
   private long totalFileSize = SegmentedRaftLogFormat.getHeaderLength();
-  private AtomicLong totalCacheSize = new AtomicLong(0);
   /** Segment start index, inclusive. */
   private long startIndex;
   /** Segment end index, inclusive. */
@@ -270,7 +309,7 @@
   /**
    * the entryCache caches the content of log entries.
    */
-  private final Map<TermIndex, LogEntryProto> entryCache = new ConcurrentHashMap<>();
+  private final EntryCache entryCache = new EntryCache();
 
   private LogSegment(RaftStorage storage, boolean isOpen, long start, long end, SizeInBytes maxOpSize,
       SegmentedRaftLogMetrics raftLogMetrics) {
@@ -302,12 +341,29 @@
     return CorruptionPolicy.get(storage, RaftStorage::getLogCorruptionPolicy);
   }
 
-  void appendToOpenSegment(LogEntryProto entry, Op op) {
+  void appendToOpenSegment(Op op, ReferenceCountedObject<LogEntryProto> entryRef) {
     Preconditions.assertTrue(isOpen(), "The log segment %s is not open for append", this);
-    append(true, entry, op);
+    append(op, entryRef, true, null);
   }
 
-  private void append(boolean keepEntryInCache, LogEntryProto entry, Op op) {
+  private void append(Op op, ReferenceCountedObject<LogEntryProto> entryRef,
+      boolean keepEntryInCache, Consumer<LogEntryProto> logConsumer) {
+    final LogEntryProto entry = entryRef.retain();
+    try {
+      final LogRecord record = appendLogRecord(op, entry);
+      if (keepEntryInCache) {
+        putEntryCache(record.getTermIndex(), entryRef, op);
+      }
+      if (logConsumer != null) {
+        logConsumer.accept(entry);
+      }
+    } finally {
+      entryRef.release();
+    }
+  }
+
+
+  private LogRecord appendLogRecord(Op op, LogEntryProto entry) {
     Objects.requireNonNull(entry, "entry == null");
     if (records.isEmpty()) {
       Preconditions.assertTrue(entry.getIndex() == startIndex,
@@ -323,11 +379,9 @@
 
     final LogRecord record = new LogRecord(totalFileSize, entry);
     records.add(record);
-    if (keepEntryInCache) {
-      putEntryCache(record.getTermIndex(), entry, op);
-    }
     totalFileSize += getEntrySize(entry, op);
     endIndex = entry.getIndex();
+    return record;
   }
 
   LogEntryProto getEntryFromCache(TermIndex ti) {
@@ -370,7 +424,7 @@
   }
 
   long getTotalCacheSize() {
-    return totalCacheSize.get();
+    return entryCache.size();
   }
 
   /**
@@ -380,7 +434,7 @@
     Preconditions.assertTrue(fromIndex >= startIndex && fromIndex <= endIndex);
     for (long index = endIndex; index >= fromIndex; index--) {
       LogRecord removed = records.remove(Math.toIntExact(index - startIndex));
-      removeEntryCache(removed.getTermIndex(), Op.REMOVE_CACHE);
+      removeEntryCache(removed.getTermIndex());
       totalFileSize = removed.offset;
     }
     isOpen = false;
@@ -417,28 +471,18 @@
 
   void evictCache() {
     entryCache.clear();
-    totalCacheSize.set(0);
   }
 
-  void putEntryCache(TermIndex key, LogEntryProto value, Op op) {
-    final LogEntryProto previous = entryCache.put(key, value);
-    long previousSize = 0;
-    if (previous != null) {
-      // Different threads maybe load LogSegment file into cache at the same time, so duplicate maybe happen
-      previousSize = getEntrySize(value, Op.REMOVE_CACHE);
-    }
-    totalCacheSize.getAndAdd(getEntrySize(value, op) - previousSize);
+  void putEntryCache(TermIndex key, ReferenceCountedObject<LogEntryProto> valueRef, Op op) {
+    entryCache.put(key, valueRef, op);
   }
 
-  void removeEntryCache(TermIndex key, Op op) {
-    LogEntryProto value = entryCache.remove(key);
-    if (value != null) {
-      totalCacheSize.getAndAdd(-getEntrySize(value, op));
-    }
+  void removeEntryCache(TermIndex key) {
+    entryCache.remove(key);
   }
 
   boolean hasCache() {
-    return isOpen || !entryCache.isEmpty(); // open segment always has cache.
+    return isOpen || entryCache.size() > 0; // open segment always has cache.
   }
 
   boolean containsIndex(long index) {
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java
index a729f8e..1cfb593 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java
@@ -41,6 +41,7 @@
 import org.apache.ratis.util.AwaitToRun;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.StringUtils;
 
 import java.io.File;
@@ -53,6 +54,7 @@
 import java.util.concurrent.CompletionException;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
+import java.util.function.Function;
 import java.util.function.LongSupplier;
 
 import org.apache.ratis.util.UncheckedAutoCloseable;
@@ -391,6 +393,7 @@
     if (LOG.isTraceEnabled()) {
       LOG.trace("{}: appendEntry {}", getName(), LogProtoUtils.toLogEntryString(entry));
     }
+    final LogEntryProto removedStateMachineData = LogProtoUtils.removeStateMachineData(entry);
     try(AutoCloseableLock writeLock = writeLock()) {
       final Timekeeper.Context appendEntryTimerContext = getRaftLogMetrics().startAppendEntryTimer();
       validateLogEntry(entry);
@@ -399,7 +402,7 @@
       if (currentOpenSegment == null) {
         cache.addOpenSegment(entry.getIndex());
         fileLogWorker.startLogSegment(entry.getIndex());
-      } else if (isSegmentFull(currentOpenSegment, entry)) {
+      } else if (isSegmentFull(currentOpenSegment, removedStateMachineData)) {
         rollOpenSegment = true;
       } else {
         final TermIndex last = currentOpenSegment.getLastTermIndex();
@@ -421,17 +424,17 @@
       // If the entry has state machine data, then the entry should be inserted
       // to statemachine first and then to the cache. Not following the order
       // will leave a spurious entry in the cache.
-      CompletableFuture<Long> writeFuture =
-          fileLogWorker.writeLogEntry(entry, context).getFuture();
+      final Task write = fileLogWorker.writeLogEntry(entry, removedStateMachineData, context);
+      final Function<LogEntryProto, ReferenceCountedObject<LogEntryProto>> wrap = context != null ?
+          context::wrap : ReferenceCountedObject::wrap;
       if (stateMachineCachingEnabled) {
         // The stateMachineData will be cached inside the StateMachine itself.
-        cache.appendEntry(LogProtoUtils.removeStateMachineData(entry),
-            LogSegment.Op.WRITE_CACHE_WITH_STATE_MACHINE_CACHE);
+        cache.appendEntry(LogSegment.Op.WRITE_CACHE_WITH_STATE_MACHINE_CACHE, wrap.apply(removedStateMachineData));
       } else {
-        cache.appendEntry(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+        cache.appendEntry(LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, wrap.apply(entry)
+        );
       }
-      writeFuture.whenComplete((clientReply, exception) -> appendEntryTimerContext.stop());
-      return writeFuture;
+      return write.getFuture().whenComplete((clientReply, exception) -> appendEntryTimerContext.stop());
     } catch (Exception e) {
       LOG.error("{}: Failed to append {}", getName(), LogProtoUtils.toLogEntryString(entry), e);
       throw e;
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogCache.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogCache.java
index bd6d831..81f4677 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogCache.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogCache.java
@@ -32,6 +32,7 @@
 import org.apache.ratis.util.AutoCloseableReadWriteLock;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.SizeInBytes;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -596,11 +597,11 @@
     }
   }
 
-  void appendEntry(LogEntryProto entry, LogSegment.Op op) {
+  void appendEntry(LogSegment.Op op, ReferenceCountedObject<LogEntryProto> entry) {
     // SegmentedRaftLog does the segment creation/rolling work. Here we just
     // simply append the entry into the open segment.
     Preconditions.assertNotNull(openSegment, "openSegment");
-    openSegment.appendToOpenSegment(entry, op);
+    openSegment.appendToOpenSegment(op, entry);
   }
 
   /**
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
index 0e8d0f3..0d1ea76 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
@@ -438,8 +438,8 @@
     addIOTask(new StartLogSegment(segmentToClose.getEndIndex() + 1));
   }
 
-  Task writeLogEntry(LogEntryProto entry, TransactionContext context) {
-    return addIOTask(new WriteLog(entry, context));
+  Task writeLogEntry(LogEntryProto entry, LogEntryProto removedStateMachineData, TransactionContext context) {
+    return addIOTask(new WriteLog(entry, removedStateMachineData, context));
   }
 
   Task truncate(TruncationSegments ts, long index) {
@@ -486,8 +486,8 @@
     private final CompletableFuture<?> stateMachineFuture;
     private final CompletableFuture<Long> combined;
 
-    WriteLog(LogEntryProto entry, TransactionContext context) {
-      this.entry = LogProtoUtils.removeStateMachineData(entry);
+    WriteLog(LogEntryProto entry, LogEntryProto removedStateMachineData, TransactionContext context) {
+      this.entry = removedStateMachineData;
       if (this.entry == entry) {
         final StateMachineLogEntryProto proto = entry.hasStateMachineLogEntry()? entry.getStateMachineLogEntry(): null;
         if (stateMachine != null && proto != null && proto.getType() == StateMachineLogEntryProto.Type.DATASTREAM) {
diff --git a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java
index a1a878e..7c4f178 100644
--- a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java
@@ -26,6 +26,7 @@
 import org.apache.ratis.statemachine.TransactionContext;
 import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
 import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -46,7 +47,7 @@
   private final RaftClientRequest clientRequest;
 
   /** Exception from the {@link StateMachine} or from the log */
-  private Exception exception;
+  private volatile Exception exception;
 
   /** Data from the {@link StateMachine} */
   private final StateMachineLogEntryProto stateMachineLogEntry;
@@ -57,7 +58,7 @@
    * {@link StateMachine#startTransaction(RaftClientRequest)} and
    * {@link StateMachine#applyTransaction(TransactionContext)}.
    */
-  private Object stateMachineContext;
+  private volatile Object stateMachineContext;
 
   /**
    * Whether to commit the transaction to the RAFT Log.
@@ -67,7 +68,9 @@
   private boolean shouldCommit = true;
 
   /** Committed LogEntry. */
-  private LogEntryProto logEntry;
+  private volatile LogEntryProto logEntry;
+  /** For wrapping {@link #logEntry} in order to release the underlying buffer. */
+  private volatile ReferenceCountedObject<?> delegatedRef;
 
   private final CompletableFuture<Long> logIndexFuture = new CompletableFuture<>();
 
@@ -123,6 +126,20 @@
     return clientRequest;
   }
 
+  public void setDelegatedRef(ReferenceCountedObject<?> ref) {
+    this.delegatedRef = ref;
+  }
+
+  @Override
+  public ReferenceCountedObject<LogEntryProto> wrap(LogEntryProto entry) {
+    if (delegatedRef == null) {
+      return TransactionContext.super.wrap(entry);
+    }
+    Preconditions.assertSame(getLogEntry().getTerm(), entry.getTerm(), "entry.term");
+    Preconditions.assertSame(getLogEntry().getIndex(), entry.getIndex(), "entry.index");
+    return delegatedRef.delegate(entry);
+  }
+
   @Override
   public StateMachineLogEntryProto getStateMachineLogEntry() {
     return stateMachineLogEntry;
diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java
index 755476b..ece17a0 100644
--- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java
+++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java
@@ -21,18 +21,20 @@
 import org.apache.ratis.RaftTestUtil.SimpleOperation;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.metrics.impl.DefaultTimekeeperImpl;
+import org.apache.ratis.proto.RaftProtos.LogEntryProto;
+import org.apache.ratis.proto.RaftProtos.StateMachineLogEntryProto;
 import org.apache.ratis.server.RaftServerConfigKeys;
 import org.apache.ratis.server.impl.RaftServerTestUtil;
 import org.apache.ratis.server.metrics.SegmentedRaftLogMetrics;
 import org.apache.ratis.server.protocol.TermIndex;
 import org.apache.ratis.server.raftlog.LogProtoUtils;
+import org.apache.ratis.server.raftlog.segmented.LogSegment.Op;
 import org.apache.ratis.server.storage.RaftStorage;
-import org.apache.ratis.proto.RaftProtos.LogEntryProto;
-import org.apache.ratis.proto.RaftProtos.StateMachineLogEntryProto;
 import org.apache.ratis.server.storage.RaftStorageTestUtils;
 import org.apache.ratis.thirdparty.com.google.protobuf.CodedOutputStream;
 import org.apache.ratis.util.FileUtils;
 import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.SizeInBytes;
 import org.apache.ratis.util.TraditionalBinaryPrefix;
 import org.junit.After;
@@ -143,7 +145,7 @@
       if (entry == null) {
         entry = segment.loadCache(record);
       }
-      offset += getEntrySize(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      offset += getEntrySize(entry, Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
     }
   }
 
@@ -202,8 +204,8 @@
     while (size < max) {
       SimpleOperation op = new SimpleOperation("m" + i);
       LogEntryProto entry = LogProtoUtils.toLogEntryProto(op.getLogEntryContent(), term, i++ + start);
-      size += getEntrySize(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
-      segment.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      size += getEntrySize(entry, Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      segment.appendToOpenSegment(Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry));
     }
 
     Assert.assertTrue(segment.getTotalFileSize() >= max);
@@ -235,18 +237,18 @@
     final StateMachineLogEntryProto m = op.getLogEntryContent();
     try {
       LogEntryProto entry = LogProtoUtils.toLogEntryProto(m, 0, 1001);
-      segment.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      segment.appendToOpenSegment(Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry));
       Assert.fail("should fail since the entry's index needs to be 1000");
     } catch (IllegalStateException e) {
       // the exception is expected.
     }
 
     LogEntryProto entry = LogProtoUtils.toLogEntryProto(m, 0, 1000);
-    segment.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+    segment.appendToOpenSegment(Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry));
 
     try {
       entry = LogProtoUtils.toLogEntryProto(m, 0, 1002);
-      segment.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      segment.appendToOpenSegment(Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry));
       Assert.fail("should fail since the entry's index needs to be 1001");
     } catch (IllegalStateException e) {
       // the exception is expected.
@@ -261,7 +263,7 @@
     for (int i = 0; i < 100; i++) {
       LogEntryProto entry = LogProtoUtils.toLogEntryProto(
           new SimpleOperation("m" + i).getLogEntryContent(), term, i + start);
-      segment.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      segment.appendToOpenSegment(Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry));
     }
 
     // truncate an open segment (remove 1080~1099)
@@ -316,7 +318,7 @@
         1024, 1024, ByteBuffer.allocateDirect(bufferSize))) {
       SimpleOperation op = new SimpleOperation(new String(content));
       LogEntryProto entry = LogProtoUtils.toLogEntryProto(op.getLogEntryContent(), 0, 0);
-      size = LogSegment.getEntrySize(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      size = LogSegment.getEntrySize(entry, Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
       out.write(entry);
     }
     Assert.assertEquals(file.length(),
@@ -343,7 +345,7 @@
     Arrays.fill(content, (byte) 1);
     SimpleOperation op = new SimpleOperation(new String(content));
     LogEntryProto entry = LogProtoUtils.toLogEntryProto(op.getLogEntryContent(), 0, 0);
-    final long entrySize = LogSegment.getEntrySize(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+    final long entrySize = LogSegment.getEntrySize(entry, Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
 
     long totalSize = SegmentedRaftLogFormat.getHeaderLength();
     long preallocated = 16 * 1024;
diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java
index 1cf3d02..5be3c36 100644
--- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java
+++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java
@@ -34,7 +34,9 @@
 import org.apache.ratis.server.raftlog.LogProtoUtils;
 import org.apache.ratis.server.raftlog.segmented.SegmentedRaftLogCache.TruncationSegments;
 import org.apache.ratis.server.raftlog.segmented.LogSegment.LogRecord;
+import org.apache.ratis.server.raftlog.segmented.LogSegment.Op;
 import org.apache.ratis.proto.RaftProtos.LogEntryProto;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -64,7 +66,7 @@
     for (long i = start; i <= end; i++) {
       SimpleOperation m = new SimpleOperation("m" + i);
       LogEntryProto entry = LogProtoUtils.toLogEntryProto(m.getLogEntryContent(), 0, i);
-      s.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      s.appendToOpenSegment(Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry));
     }
     if (!isOpen) {
       s.close();
@@ -148,14 +150,15 @@
   }
 
   @Test
-  public void testAppendEntry() throws Exception {
+  public void testAppendEntry() {
     LogSegment closedSegment = prepareLogSegment(0, 99, false);
     cache.addSegment(closedSegment);
 
     final SimpleOperation m = new SimpleOperation("m");
     try {
       LogEntryProto entry = LogProtoUtils.toLogEntryProto(m.getLogEntryContent(), 0, 0);
-      cache.appendEntry(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      cache.appendEntry(Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry)
+      );
       Assert.fail("the open segment is null");
     } catch (IllegalStateException ignored) {
     }
@@ -164,7 +167,8 @@
     cache.addSegment(openSegment);
     for (long index = 101; index < 200; index++) {
       LogEntryProto entry = LogProtoUtils.toLogEntryProto(m.getLogEntryContent(), 0, index);
-      cache.appendEntry(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      cache.appendEntry(Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry)
+      );
     }
 
     Assert.assertNotNull(cache.getOpenSegment());
diff --git a/ratis-tools/src/main/java/org/apache/ratis/tools/ParseRatisLog.java b/ratis-tools/src/main/java/org/apache/ratis/tools/ParseRatisLog.java
index 564ce0b..ea512fa 100644
--- a/ratis-tools/src/main/java/org/apache/ratis/tools/ParseRatisLog.java
+++ b/ratis-tools/src/main/java/org/apache/ratis/tools/ParseRatisLog.java
@@ -24,6 +24,7 @@
 import org.apache.ratis.server.raftlog.LogProtoUtils;
 import org.apache.ratis.server.raftlog.segmented.LogSegmentPath;
 import org.apache.ratis.server.raftlog.segmented.LogSegment;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.SizeInBytes;
 
 import java.io.File;
@@ -69,7 +70,8 @@
   }
 
 
-  private void processLogEntry(LogEntryProto proto) {
+  private void processLogEntry(ReferenceCountedObject<LogEntryProto> ref) {
+    final LogEntryProto proto = ref.retain();
     if (proto.hasConfigurationEntry()) {
       numConfEntries++;
     } else if (proto.hasMetadataEntry()) {
@@ -77,12 +79,13 @@
     } else if (proto.hasStateMachineLogEntry()) {
       numStateMachineEntries++;
     } else {
-      System.out.println("Found invalid entry" + proto.toString());
+      System.out.println("Found an invalid entry: " + proto);
       numInvalidEntries++;
     }
 
     String str = LogProtoUtils.toLogEntryString(proto, smLogToString);
     System.out.println(str);
+    ref.release();
   }
 
   public static class Builder {