RATIS-1997. Refactor StateMachine interface to use ReferenceCountedObject (#1036)

diff --git a/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java b/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java
index 858e300..0ee7a60 100644
--- a/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java
+++ b/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java
@@ -42,6 +42,7 @@
 import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException;
 import org.apache.ratis.util.FileUtils;
 import org.apache.ratis.util.JavaUtils;
+import org.apache.ratis.util.ReferenceCountedObject;
 
 import java.io.IOException;
 import java.nio.file.Path;
@@ -123,7 +124,8 @@
   }
 
   @Override
-  public CompletableFuture<Integer> write(LogEntryProto entry, TransactionContext context) {
+  public CompletableFuture<Integer> write(ReferenceCountedObject<LogEntryProto> entryRef, TransactionContext context) {
+    LogEntryProto entry = entryRef.retain();
     final FileStoreRequestProto proto = getProto(context, entry);
     if (proto.getRequestCase() != FileStoreRequestProto.RequestCase.WRITEHEADER) {
       return null;
@@ -132,9 +134,10 @@
     final WriteRequestHeaderProto h = proto.getWriteHeader();
     final CompletableFuture<Integer> f = files.write(entry.getIndex(),
         h.getPath().toStringUtf8(), h.getClose(),  h.getSync(), h.getOffset(),
-        entry.getStateMachineLogEntry().getStateMachineEntry().getStateMachineData());
+        entry.getStateMachineLogEntry().getStateMachineEntry().getStateMachineData()
+    ).whenComplete((r, e) -> entryRef.release());
     // sync only if closing the file
-    return h.getClose()? f: null;
+    return h.getClose() ? f: null;
   }
 
   static FileStoreRequestProto getProto(TransactionContext context, LogEntryProto entry) {
diff --git a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java
index b1fc5ad..915b70b 100644
--- a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java
+++ b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java
@@ -92,7 +92,9 @@
      * Write asynchronously the state machine data in the given log entry to this state machine.
      *
      * @return a future for the write task
+     * @deprecated Applications should implement {@link #write(ReferenceCountedObject, TransactionContext)} instead.
      */
+    @Deprecated
     default CompletableFuture<?> write(LogEntryProto entry) {
       return CompletableFuture.completedFuture(null);
     }
@@ -101,12 +103,37 @@
      * Write asynchronously the state machine data in the given log entry to this state machine.
      *
      * @return a future for the write task
+     * @deprecated Applications should implement {@link #write(ReferenceCountedObject, TransactionContext)} instead.
      */
+    @Deprecated
     default CompletableFuture<?> write(LogEntryProto entry, TransactionContext context) {
       return write(entry);
     }
 
     /**
+     * Write asynchronously the state machine data in the given log entry to this state machine.
+     *
+     * @param entryRef Reference to a log entry.
+     *                 Implementations of this method may call {@link ReferenceCountedObject#get()}
+     *                 to access the log entry before this method returns.
+     *                 If the log entry is needed after this method returns,
+     *                 e.g. for asynchronous computation or caching,
+     *                 the implementation must invoke {@link ReferenceCountedObject#retain()}
+     *                 and {@link ReferenceCountedObject#release()}.
+     * @return a future for the write task
+     */
+    default CompletableFuture<?> write(ReferenceCountedObject<LogEntryProto> entryRef, TransactionContext context) {
+      final LogEntryProto entry = entryRef.get();
+      try {
+        final LogEntryProto copy = LogEntryProto.parseFrom(entry.toByteString());
+        return write(copy, context);
+      } catch (InvalidProtocolBufferException e) {
+        return JavaUtils.completeExceptionally(new IllegalStateException(
+            "Failed to copy log entry " + TermIndex.valueOf(entry), e));
+      }
+    }
+
+    /**
      * Create asynchronously a {@link DataStream} to stream state machine data.
      * The state machine may use the first message (i.e. request.getMessage()) as the header to create the stream.
      *
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java
index 4e057c0..baac0c6 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java
@@ -428,7 +428,7 @@
       // If the entry has state machine data, then the entry should be inserted
       // to statemachine first and then to the cache. Not following the order
       // will leave a spurious entry in the cache.
-      final Task write = fileLogWorker.writeLogEntry(entry, removedStateMachineData, context);
+      final Task write = fileLogWorker.writeLogEntry(entryRef, removedStateMachineData, context);
       if (stateMachineCachingEnabled) {
         // The stateMachineData will be cached inside the StateMachine itself.
         cache.appendEntry(LogSegment.Op.WRITE_CACHE_WITH_STATE_MACHINE_CACHE,
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
index 0d1ea76..0250607 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java
@@ -438,7 +438,8 @@
     addIOTask(new StartLogSegment(segmentToClose.getEndIndex() + 1));
   }
 
-  Task writeLogEntry(LogEntryProto entry, LogEntryProto removedStateMachineData, TransactionContext context) {
+  Task writeLogEntry(ReferenceCountedObject<LogEntryProto> entry,
+      LogEntryProto removedStateMachineData, TransactionContext context) {
     return addIOTask(new WriteLog(entry, removedStateMachineData, context));
   }
 
@@ -486,25 +487,28 @@
     private final CompletableFuture<?> stateMachineFuture;
     private final CompletableFuture<Long> combined;
 
-    WriteLog(LogEntryProto entry, LogEntryProto removedStateMachineData, TransactionContext context) {
+    WriteLog(ReferenceCountedObject<LogEntryProto> entryRef, LogEntryProto removedStateMachineData,
+        TransactionContext context) {
+      LogEntryProto origEntry = entryRef.get();
       this.entry = removedStateMachineData;
-      if (this.entry == entry) {
-        final StateMachineLogEntryProto proto = entry.hasStateMachineLogEntry()? entry.getStateMachineLogEntry(): null;
+      if (this.entry == origEntry) {
+        final StateMachineLogEntryProto proto = origEntry.hasStateMachineLogEntry() ?
+            origEntry.getStateMachineLogEntry(): null;
         if (stateMachine != null && proto != null && proto.getType() == StateMachineLogEntryProto.Type.DATASTREAM) {
           final ClientInvocationId invocationId = ClientInvocationId.valueOf(proto);
           final CompletableFuture<DataStream> removed = server.getDataStreamMap().remove(invocationId);
-          this.stateMachineFuture = removed == null? stateMachine.data().link(null, entry)
-              : removed.thenApply(stream -> stateMachine.data().link(stream, entry));
+          this.stateMachineFuture = removed == null? stateMachine.data().link(null, origEntry)
+              : removed.thenApply(stream -> stateMachine.data().link(stream, origEntry));
         } else {
           this.stateMachineFuture = null;
         }
       } else {
         try {
-          // this.entry != entry iff the entry has state machine data
-          this.stateMachineFuture = stateMachine.data().write(entry, context);
+          // this.entry != origEntry if it has state machine data
+          this.stateMachineFuture = stateMachine.data().write(entryRef, context);
         } catch (Exception e) {
-          LOG.error(name + ": writeStateMachineData failed for index " + entry.getIndex()
-              + ", entry=" + LogProtoUtils.toLogEntryString(entry, stateMachine::toStateMachineLogEntryString), e);
+          LOG.error(name + ": writeStateMachineData failed for index " + origEntry.getIndex()
+              + ", entry=" + LogProtoUtils.toLogEntryString(origEntry, stateMachine::toStateMachineLogEntryString), e);
           throw e;
         }
       }
diff --git a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
index 7c40ec2..17d5a60 100644
--- a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
+++ b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
@@ -48,6 +48,7 @@
 import org.apache.ratis.util.LifeCycle;
 import org.apache.ratis.util.MD5FileUtil;
 import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -367,7 +368,8 @@
   }
 
   @Override
-  public CompletableFuture<Void> write(LogEntryProto entry) {
+  public CompletableFuture<Void> write(ReferenceCountedObject<LogEntryProto> entry, TransactionContext context) {
+    Preconditions.assertTrue(entry.get() != null);
     return blocking.getFuture(Blocking.Type.WRITE_STATE_MACHINE_DATA);
   }
 
diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java
index 3d5d5f8..38341e0 100644
--- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java
+++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java
@@ -39,6 +39,7 @@
 import org.apache.ratis.server.raftlog.RaftLog;
 import org.apache.ratis.server.storage.RaftStorage;
 import org.apache.ratis.server.storage.RaftStorageTestUtils;
+import org.apache.ratis.statemachine.TransactionContext;
 import org.apache.ratis.statemachine.impl.SimpleStateMachine4Testing;
 import org.apache.ratis.statemachine.StateMachine;
 import org.apache.ratis.statemachine.impl.BaseStateMachine;
@@ -634,7 +635,7 @@
     final LogEntryProto entry = prepareLogEntry(0, 0, null, true);
     final StateMachine sm = new BaseStateMachine() {
       @Override
-      public CompletableFuture<Void> write(LogEntryProto entry) {
+      public CompletableFuture<Void> write(ReferenceCountedObject<LogEntryProto> entry, TransactionContext context) {
         getLifeCycle().transition(LifeCycle.State.STARTING);
         getLifeCycle().transition(LifeCycle.State.RUNNING);