RATIS-2028. Refactor RaftLog to supply log as ReferenceCountedObject (#1045)

diff --git a/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java b/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java
index e504462..e4fbd66 100644
--- a/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java
+++ b/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java
@@ -21,6 +21,7 @@
 import org.apache.ratis.server.metrics.RaftLogMetrics;
 import org.apache.ratis.server.protocol.TermIndex;
 import org.apache.ratis.server.storage.RaftStorageMetadata;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.TimeDuration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -57,11 +58,25 @@
 
   /**
    * @return null if the log entry is not found in this log;
-   *         otherwise, return the log entry corresponding to the given index.
+   *         otherwise, return a copy of the log entry corresponding to the given index.
+   * @deprecated use {@link RaftLog#retainLog(long)} instead in order to avoid copying.
    */
+  @Deprecated
   LogEntryProto get(long index) throws RaftLogIOException;
 
   /**
+   * @return a retained {@link ReferenceCountedObject} to the log entry corresponding to the given index if it exists;
+   *         otherwise, return null.
+   *         Since the returned reference is retained, the caller must call {@link ReferenceCountedObject#release()}}
+   *         after use.
+   */
+  default ReferenceCountedObject<LogEntryProto> retainLog(long index) throws RaftLogIOException {
+    ReferenceCountedObject<LogEntryProto> wrap = ReferenceCountedObject.wrap(get(index));
+    wrap.retain();
+    return wrap;
+  }
+
+  /**
    * @return null if the log entry is not found in this log;
    *         otherwise, return the {@link EntryWithData} corresponding to the given index.
    */
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
index 17a741e..e8aeb66 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
@@ -1800,7 +1800,9 @@
         MemoizedSupplier.valueOf(() -> stateMachine.startTransaction(entry, getInfo().getCurrentRole())));
   }
 
-  CompletableFuture<Message> applyLogToStateMachine(LogEntryProto next) throws RaftLogIOException {
+  CompletableFuture<Message> applyLogToStateMachine(ReferenceCountedObject<LogEntryProto> nextRef)
+      throws RaftLogIOException {
+    LogEntryProto next = nextRef.get();
     if (!next.hasStateMachineLogEntry()) {
       stateMachine.event().notifyTermIndexUpdated(next.getTerm(), next.getIndex());
     }
@@ -1815,11 +1817,7 @@
       TransactionContext trx = getTransactionContext(next, true);
       final ClientInvocationId invocationId = ClientInvocationId.valueOf(next.getStateMachineLogEntry());
       writeIndexCache.add(invocationId.getClientId(), ((TransactionContextImpl) trx).getLogIndexFuture());
-
-      // TODO: RaftLog to provide the log entry as a ReferenceCountedObject as per RATIS-2028.
-      ReferenceCountedObject<?> ref = ReferenceCountedObject.wrap(next);
-      ((TransactionContextImpl) trx).setDelegatedRef(ref);
-      ref.retain();
+      ((TransactionContextImpl) trx).setDelegatedRef(nextRef);
       try {
         // Let the StateMachine inject logic for committed transactions in sequential order.
         trx = stateMachine.applyTransactionSerial(trx);
@@ -1828,8 +1826,6 @@
         return replyPendingRequest(invocationId, TermIndex.valueOf(next), stateMachineFuture);
       } catch (Exception e) {
         throw new RaftLogIOException(e);
-      } finally {
-        ref.release();
       }
     }
     return null;
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java
index 5f6e972..b01270d 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java
@@ -235,10 +235,17 @@
     final long committed = raftLog.getLastCommittedIndex();
     for(long applied; (applied = getLastAppliedIndex()) < committed && state == State.RUNNING && !shouldStop(); ) {
       final long nextIndex = applied + 1;
-      final LogEntryProto next = raftLog.get(nextIndex);
-      if (next != null) {
+      final ReferenceCountedObject<LogEntryProto> next = raftLog.retainLog(nextIndex);
+      if (next == null) {
+        LOG.debug("{}: logEntry {} is null. There may be snapshot to load. state:{}",
+            this, nextIndex, state);
+        break;
+      }
+
+      try {
+        final LogEntryProto entry = next.get();
         if (LOG.isTraceEnabled()) {
-          LOG.trace("{}: applying nextIndex={}, nextLog={}", this, nextIndex, LogProtoUtils.toLogEntryString(next));
+          LOG.trace("{}: applying nextIndex={}, nextLog={}", this, nextIndex, LogProtoUtils.toLogEntryString(entry));
         } else {
           LOG.debug("{}: applying nextIndex={}", this, nextIndex);
         }
@@ -252,10 +259,8 @@
         } else {
           notifyAppliedIndex(incremented);
         }
-      } else {
-        LOG.debug("{}: logEntry {} is null. There may be snapshot to load. state:{}",
-            this, nextIndex, state);
-        break;
+      } finally {
+        next.release();
       }
     }
     return futures;
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/RaftLogBase.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/RaftLogBase.java
index 49e66e2..0a9a1c9 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/RaftLogBase.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/RaftLogBase.java
@@ -240,13 +240,19 @@
       //log neither lastMetadataEntry, nor entries with a smaller commit index.
       return false;
     }
+    ReferenceCountedObject<LogEntryProto> ref = null;
     try {
-      if (get(newCommitIndex).hasMetadataEntry()) {
+      ref = retainLog(newCommitIndex);
+      if (ref.get().hasMetadataEntry()) {
         // do not log the metadata entry
         return false;
       }
     } catch(RaftLogIOException e) {
       LOG.error("Failed to get log entry for index " + newCommitIndex, e);
+    } finally {
+      if (ref != null) {
+        ref.release();
+      }
     }
     return true;
   }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLog.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLog.java
index fc7973a..feedaee 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLog.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLog.java
@@ -22,8 +22,10 @@
 import org.apache.ratis.server.metrics.RaftLogMetricsBase;
 import org.apache.ratis.server.protocol.TermIndex;
 import org.apache.ratis.proto.RaftProtos.LogEntryProto;
+import org.apache.ratis.server.raftlog.LogProtoUtils;
 import org.apache.ratis.server.raftlog.RaftLogBase;
 import org.apache.ratis.server.raftlog.LogEntryHeader;
+import org.apache.ratis.server.raftlog.RaftLogIOException;
 import org.apache.ratis.server.storage.RaftStorageMetadata;
 import org.apache.ratis.statemachine.TransactionContext;
 import org.apache.ratis.util.AutoCloseableLock;
@@ -45,8 +47,13 @@
   static class EntryList {
     private final List<ReferenceCountedObject<LogEntryProto>> entries = new ArrayList<>();
 
+    ReferenceCountedObject<LogEntryProto> getRef(int i) {
+      return i >= 0 && i < entries.size() ? entries.get(i) : null;
+    }
+
     LogEntryProto get(int i) {
-      return i >= 0 && i < entries.size() ? entries.get(i).get() : null;
+      final ReferenceCountedObject<LogEntryProto> ref = getRef(i);
+      return ref != null ? ref.get() : null;
     }
 
     TermIndex getTermIndex(int i) {
@@ -108,16 +115,34 @@
   }
 
   @Override
-  public LogEntryProto get(long index) {
+  public LogEntryProto get(long index) throws RaftLogIOException {
+    final ReferenceCountedObject<LogEntryProto> ref = retainLog(index);
+    try {
+      return LogProtoUtils.copy(ref.get());
+    } finally {
+      ref.release();
+    }
+  }
+
+  @Override
+  public ReferenceCountedObject<LogEntryProto> retainLog(long index) {
     checkLogState();
-    try(AutoCloseableLock readLock = readLock()) {
-      return entries.get(Math.toIntExact(index));
+    try (AutoCloseableLock readLock = readLock()) {
+      ReferenceCountedObject<LogEntryProto> ref = entries.getRef(Math.toIntExact(index));
+      ref.retain();
+      return ref;
     }
   }
 
   @Override
   public EntryWithData getEntryWithData(long index) {
-    return newEntryWithData(get(index), null);
+    // TODO. The reference counted object should be passed to LogAppender RATIS-2026.
+    ReferenceCountedObject<LogEntryProto> ref = retainLog(index);
+    try {
+      return newEntryWithData(ref.get(), null);
+    } finally {
+      ref.release();
+    }
   }
 
   @Override
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java
index 68da350..2fcd791 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java
@@ -224,7 +224,7 @@
    *
    * In the future we can make the cache loader configurable if necessary.
    */
-  class LogEntryLoader extends CacheLoader<LogRecord, LogEntryProto> {
+  class LogEntryLoader extends CacheLoader<LogRecord, ReferenceCountedObject<LogEntryProto>> {
     private final SegmentedRaftLogMetrics raftLogMetrics;
 
     LogEntryLoader(SegmentedRaftLogMetrics raftLogMetrics) {
@@ -232,18 +232,19 @@
     }
 
     @Override
-    public LogEntryProto load(LogRecord key) throws IOException {
+    public ReferenceCountedObject<LogEntryProto> load(LogRecord key) throws IOException {
       final File file = getFile();
       // note the loading should not exceed the endIndex: it is possible that
       // the on-disk log file should be truncated but has not been done yet.
-      final AtomicReference<LogEntryProto> toReturn = new AtomicReference<>();
+      final AtomicReference<ReferenceCountedObject<LogEntryProto>> toReturn = new AtomicReference<>();
       final LogSegmentStartEnd startEnd = LogSegmentStartEnd.valueOf(startIndex, endIndex, isOpen);
       readSegmentFile(file, startEnd, maxOpSize, getLogCorruptionPolicy(), raftLogMetrics, entryRef -> {
         final LogEntryProto entry = entryRef.retain();
         final TermIndex ti = TermIndex.valueOf(entry);
         putEntryCache(ti, entryRef, Op.LOAD_SEGMENT_FILE);
         if (ti.equals(key.getTermIndex())) {
-          toReturn.set(entry);
+          entryRef.retain();
+          toReturn.set(entryRef);
         }
         entryRef.release();
       });
@@ -260,10 +261,8 @@
       return size.get();
     }
 
-    LogEntryProto get(TermIndex ti) {
-      return Optional.ofNullable(map.get(ti))
-          .map(ReferenceCountedObject::get)
-          .orElse(null);
+    ReferenceCountedObject<LogEntryProto> get(TermIndex ti) {
+      return map.get(ti);
     }
 
     void clear() {
@@ -386,15 +385,15 @@
     return record;
   }
 
-  LogEntryProto getEntryFromCache(TermIndex ti) {
+  ReferenceCountedObject<LogEntryProto> getEntryFromCache(TermIndex ti) {
     return entryCache.get(ti);
   }
 
   /**
    * Acquire LogSegment's monitor so that there is no concurrent loading.
    */
-  synchronized LogEntryProto loadCache(LogRecord record) throws RaftLogIOException {
-    LogEntryProto entry = entryCache.get(record.getTermIndex());
+  synchronized ReferenceCountedObject<LogEntryProto> loadCache(LogRecord record) throws RaftLogIOException {
+    ReferenceCountedObject<LogEntryProto> entry = entryCache.get(record.getTermIndex());
     if (entry != null) {
       return entry;
     }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java
index baac0c6..bb0793a 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java
@@ -180,11 +180,17 @@
 
       @Override
       public void notifyTruncatedLogEntry(TermIndex ti) {
+        ReferenceCountedObject<LogEntryProto> ref = null;
         try {
-          final LogEntryProto entry = get(ti.getIndex());
+          ref = retainLog(ti.getIndex());
+          final LogEntryProto entry = ref != null ? ref.get() : null;
           notifyTruncatedLogEntry.accept(entry);
         } catch (RaftLogIOException e) {
           LOG.error("{}: Failed to read log {}", getName(), ti, e);
+        } finally {
+          if (ref != null) {
+            ref.release();
+          }
         }
       }
 
@@ -272,6 +278,19 @@
 
   @Override
   public LogEntryProto get(long index) throws RaftLogIOException {
+    final ReferenceCountedObject<LogEntryProto> ref = retainLog(index);
+    if (ref == null) {
+      return null;
+    }
+    try {
+      return LogProtoUtils.copy(ref.get());
+    } finally {
+      ref.release();
+    }
+  }
+
+  @Override
+  public ReferenceCountedObject<LogEntryProto> retainLog(long index) throws RaftLogIOException {
     checkLogState();
     final LogSegment segment;
     final LogRecord record;
@@ -284,9 +303,10 @@
       if (record == null) {
         return null;
       }
-      final LogEntryProto entry = segment.getEntryFromCache(record.getTermIndex());
+      final ReferenceCountedObject<LogEntryProto> entry = segment.getEntryFromCache(record.getTermIndex());
       if (entry != null) {
         getRaftLogMetrics().onRaftLogCacheHit();
+        entry.retain();
         return entry;
       }
     }
@@ -299,10 +319,19 @@
 
   @Override
   public EntryWithData getEntryWithData(long index) throws RaftLogIOException {
-    final LogEntryProto entry = get(index);
-    if (entry == null) {
+    final ReferenceCountedObject<LogEntryProto> entryRef = retainLog(index);
+    if (entryRef == null) {
       throw new RaftLogIOException("Log entry not found: index = " + index);
     }
+    try {
+      // TODO. The reference counted object should be passed to LogAppender RATIS-2026.
+      return getEntryWithData(entryRef.get());
+    } finally {
+      entryRef.release();
+    }
+  }
+
+  private EntryWithData getEntryWithData(LogEntryProto entry) throws RaftLogIOException {
     if (!LogProtoUtils.isStateMachineDataEmpty(entry)) {
       return newEntryWithData(entry, null);
     }
diff --git a/ratis-server/src/test/java/org/apache/ratis/server/storage/RaftStorageTestUtils.java b/ratis-server/src/test/java/org/apache/ratis/server/storage/RaftStorageTestUtils.java
index bb4f6a0..ee30bd2 100644
--- a/ratis-server/src/test/java/org/apache/ratis/server/storage/RaftStorageTestUtils.java
+++ b/ratis-server/src/test/java/org/apache/ratis/server/storage/RaftStorageTestUtils.java
@@ -21,12 +21,15 @@
 import static org.apache.ratis.server.metrics.SegmentedRaftLogMetrics.RATIS_LOG_WORKER_METRICS;
 
 import org.apache.ratis.metrics.RatisMetrics;
+import org.apache.ratis.proto.RaftProtos.LogEntryProto;
 import org.apache.ratis.server.RaftServerConfigKeys;
 import org.apache.ratis.server.protocol.TermIndex;
 import org.apache.ratis.server.raftlog.LogProtoUtils;
+import org.apache.ratis.server.raftlog.RaftLog;
 import org.apache.ratis.server.raftlog.RaftLogBase;
 import org.apache.ratis.server.raftlog.RaftLogIOException;
 import org.apache.ratis.util.AutoCloseableLock;
+import org.apache.ratis.util.ReferenceCountedObject;
 
 import java.io.File;
 import java.io.IOException;
@@ -72,11 +75,22 @@
       b.append(i == committed? 'c': ' ');
       b.append(String.format("%3d: ", i));
       try {
-        b.append(LogProtoUtils.toLogEntryString(log.get(i)));
+        b.append(LogProtoUtils.toLogEntryString(getLogUnsafe(log, i)));
       } catch (RaftLogIOException e) {
         b.append(e);
       }
       println.accept(b.toString());
     }
   }
+
+  static LogEntryProto getLogUnsafe(RaftLog log, long index) throws RaftLogIOException {
+    ReferenceCountedObject<LogEntryProto> ref = log.retainLog(index);
+    try {
+      return ref != null ? ref.get() : null;
+    } finally {
+      if (ref != null) {
+        ref.release();
+      }
+    }
+  }
 }
diff --git a/ratis-server/src/test/java/org/apache/ratis/statemachine/RaftSnapshotBaseTest.java b/ratis-server/src/test/java/org/apache/ratis/statemachine/RaftSnapshotBaseTest.java
index fe1a97d..9a716ca 100644
--- a/ratis-server/src/test/java/org/apache/ratis/statemachine/RaftSnapshotBaseTest.java
+++ b/ratis-server/src/test/java/org/apache/ratis/statemachine/RaftSnapshotBaseTest.java
@@ -21,6 +21,7 @@
 import static org.apache.ratis.server.impl.StateMachineMetrics.RATIS_STATEMACHINE_METRICS_DESC;
 import static org.apache.ratis.server.impl.StateMachineMetrics.STATEMACHINE_TAKE_SNAPSHOT_TIMER;
 import static org.apache.ratis.metrics.RatisMetrics.RATIS_APPLICATION_NAME_METRICS;
+import static org.apache.ratis.server.storage.RaftStorageTestUtils.getLogUnsafe;
 
 import org.apache.ratis.BaseTest;
 import org.apache.ratis.metrics.LongCounter;
@@ -43,6 +44,7 @@
 import org.apache.ratis.server.raftlog.RaftLog;
 import org.apache.ratis.server.raftlog.segmented.LogSegmentPath;
 import org.apache.ratis.proto.RaftProtos.LogEntryProto;
+import org.apache.ratis.server.storage.RaftStorageTestUtils;
 import org.apache.ratis.statemachine.impl.SimpleStateMachine4Testing;
 import org.apache.ratis.statemachine.impl.SimpleStateMachineStorage;
 import org.apache.ratis.util.FileUtils;
@@ -95,7 +97,7 @@
   public static void assertLogContent(RaftServer.Division server, boolean isLeader) throws Exception {
     final RaftLog log = server.getRaftLog();
     final long lastIndex = log.getLastEntryTermIndex().getIndex();
-    final LogEntryProto e = log.get(lastIndex);
+    final LogEntryProto e = getLogUnsafe(log, lastIndex);
     Assert.assertTrue(e.hasMetadataEntry());
 
     JavaUtils.attemptRepeatedly(() -> {
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
index 2970bbe..4713891 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
@@ -69,6 +69,8 @@
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ThreadLocalRandom;
 
+import static org.apache.ratis.server.storage.RaftStorageTestUtils.getLogUnsafe;
+
 public interface DataStreamTestUtils {
   Logger LOG = LoggerFactory.getLogger(DataStreamTestUtils.class);
 
@@ -383,7 +385,7 @@
 
   static LogEntryProto searchLogEntry(ClientInvocationId invocationId, RaftLog log) throws Exception {
     for (LogEntryHeader termIndex : log.getEntries(0, Long.MAX_VALUE)) {
-      final LogEntryProto entry = log.get(termIndex.getIndex());
+      final LogEntryProto entry = getLogUnsafe(log, termIndex.getIndex());
       if (entry.hasStateMachineLogEntry()) {
         if (invocationId.match(entry.getStateMachineLogEntry())) {
           return entry;
diff --git a/ratis-test/src/test/java/org/apache/ratis/server/ServerRestartTests.java b/ratis-test/src/test/java/org/apache/ratis/server/ServerRestartTests.java
index db4e92b..11311f3 100644
--- a/ratis-test/src/test/java/org/apache/ratis/server/ServerRestartTests.java
+++ b/ratis-test/src/test/java/org/apache/ratis/server/ServerRestartTests.java
@@ -65,6 +65,8 @@
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
+import static org.apache.ratis.server.storage.RaftStorageTestUtils.getLogUnsafe;
+
 /**
  * Test restarting raft peers.
  */
@@ -268,10 +270,10 @@
 
     final long lastIndex = leaderLog.getLastEntryTermIndex().getIndex();
     LOG.info("{}: leader lastIndex={}", leaderId, lastIndex);
-    final LogEntryProto lastEntry = leaderLog.get(lastIndex);
+    final LogEntryProto lastEntry = getLogUnsafe(leaderLog, lastIndex);
     LOG.info("{}: leader lastEntry entry[{}] = {}", leaderId, lastIndex, LogProtoUtils.toLogEntryString(lastEntry));
     final long loggedCommitIndex = lastEntry.getMetadataEntry().getCommitIndex();
-    final LogEntryProto lastCommittedEntry = leaderLog.get(loggedCommitIndex);
+    final LogEntryProto lastCommittedEntry = getLogUnsafe(leaderLog, loggedCommitIndex);
     LOG.info("{}: leader lastCommittedEntry = entry[{}] = {}",
         leaderId, loggedCommitIndex, LogProtoUtils.toLogEntryString(lastCommittedEntry));
 
@@ -317,11 +319,11 @@
   static void assertLastLogEntry(RaftServer.Division server) throws RaftLogIOException {
     final RaftLog raftLog = server.getRaftLog();
     final long lastIndex = raftLog.getLastEntryTermIndex().getIndex();
-    final LogEntryProto lastEntry = raftLog.get(lastIndex);
+    final LogEntryProto lastEntry = getLogUnsafe(raftLog, lastIndex);
     Assertions.assertTrue(lastEntry.hasMetadataEntry());
 
     final long loggedCommitIndex = lastEntry.getMetadataEntry().getCommitIndex();
-    final LogEntryProto lastCommittedEntry = raftLog.get(loggedCommitIndex);
+    final LogEntryProto lastCommittedEntry = getLogUnsafe(raftLog, loggedCommitIndex);
     Assertions.assertTrue(lastCommittedEntry.hasStateMachineLogEntry());
 
     final SimpleStateMachine4Testing leaderStateMachine = SimpleStateMachine4Testing.get(server);
diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java
index 8355c67..7692ad0 100644
--- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java
+++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java
@@ -141,11 +141,11 @@
       Assertions.assertEquals(term, ti.getTerm());
       Assertions.assertEquals(offset, record.getOffset());
 
-      LogEntryProto entry = segment.getEntryFromCache(ti);
+      ReferenceCountedObject<LogEntryProto> entry = segment.getEntryFromCache(ti);
       if (entry == null) {
         entry = segment.loadCache(record);
       }
-      offset += getEntrySize(entry, Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
+      offset += getEntrySize(entry.get(), Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE);
     }
   }
 
diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java
index 55fd6fb..7b20bab 100644
--- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java
+++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java
@@ -74,6 +74,7 @@
 
 import static java.lang.Boolean.FALSE;
 import static java.lang.Boolean.TRUE;
+import static org.apache.ratis.server.storage.RaftStorageTestUtils.getLogUnsafe;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.params.provider.Arguments.arguments;
 
@@ -204,7 +205,7 @@
 
   private LogEntryProto getLastEntry(SegmentedRaftLog raftLog)
       throws IOException {
-    return raftLog.get(raftLog.getLastEntryTermIndex().getIndex());
+    return getLogUnsafe(raftLog, raftLog.getLastEntryTermIndex().getIndex());
   }
 
   @ParameterizedTest
@@ -229,7 +230,7 @@
       LogEntryProto[] entriesFromLog = Arrays.stream(termIndices)
           .map(ti -> {
             try {
-              return raftLog.get(ti.getIndex());
+              return getLogUnsafe(raftLog, ti.getIndex());
             } catch (IOException e) {
               throw new RuntimeException(e);
             }
@@ -451,7 +452,7 @@
       LogEntryProto[] entriesFromLog = Arrays.stream(termIndices)
           .map(ti -> {
             try {
-              return raftLog.get(ti.getIndex());
+              return getLogUnsafe(raftLog, ti.getIndex());
             } catch (IOException e) {
               throw new RuntimeException(e);
             }
diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java
index fa892b2..8717232 100644
--- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java
+++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java
@@ -81,8 +81,8 @@
     for (long index = start; index <= end; index++) {
       final LogSegment segment = cache.getSegment(index);
       final LogRecord record = segment.getLogRecord(index);
-      final LogEntryProto entry = segment.getEntryFromCache(record.getTermIndex());
-      Assertions.assertEquals(index, entry.getIndex());
+      final ReferenceCountedObject<LogEntryProto> entry = segment.getEntryFromCache(record.getTermIndex());
+      Assertions.assertEquals(index, entry.get().getIndex());
     }
 
     long[] offsets = new long[]{start, start + 1, start + (end - start) / 2,