RATIS-2020. Refactor TransactionContext to supply LogEntryProto via a ReferenceCountedObject (#1042)
diff --git a/ratis-examples/src/main/java/org/apache/ratis/examples/arithmetic/ArithmeticStateMachine.java b/ratis-examples/src/main/java/org/apache/ratis/examples/arithmetic/ArithmeticStateMachine.java
index 28e3fb1..e8b142f 100644
--- a/ratis-examples/src/main/java/org/apache/ratis/examples/arithmetic/ArithmeticStateMachine.java
+++ b/ratis-examples/src/main/java/org/apache/ratis/examples/arithmetic/ArithmeticStateMachine.java
@@ -164,7 +164,7 @@
@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
- final LogEntryProto entry = trx.getLogEntry();
+ final LogEntryProto entry = trx.getLogEntryUnsafe();
final AssignmentMessage assignment = new AssignmentMessage(entry.getStateMachineLogEntry().getLogData());
final long index = entry.getIndex();
diff --git a/ratis-examples/src/main/java/org/apache/ratis/examples/counter/server/CounterStateMachine.java b/ratis-examples/src/main/java/org/apache/ratis/examples/counter/server/CounterStateMachine.java
index b88a763..47880af 100644
--- a/ratis-examples/src/main/java/org/apache/ratis/examples/counter/server/CounterStateMachine.java
+++ b/ratis-examples/src/main/java/org/apache/ratis/examples/counter/server/CounterStateMachine.java
@@ -247,7 +247,7 @@
*/
@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
- final LogEntryProto entry = trx.getLogEntry();
+ final LogEntryProto entry = trx.getLogEntryUnsafe();
//increment the counter and update term-index
final TermIndex termIndex = TermIndex.valueOf(entry);
final int incremented = incrementCounter(termIndex);
diff --git a/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java b/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java
index 0ee7a60..f870cba 100644
--- a/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java
+++ b/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java
@@ -227,7 +227,7 @@
@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
- final LogEntryProto entry = trx.getLogEntry();
+ final LogEntryProto entry = trx.getLogEntryUnsafe();
final long index = entry.getIndex();
updateLastAppliedTermIndex(entry.getTerm(), index);
diff --git a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java
index e019074..2ec87e3 100644
--- a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java
+++ b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java
@@ -28,6 +28,7 @@
import java.io.IOException;
import java.util.Objects;
+import java.util.Optional;
/**
* Context for a transaction.
@@ -94,11 +95,40 @@
LogEntryProto initLogEntry(long term, long index);
/**
- * Returns the committed log entry
- * @return the committed log entry
+ * @return a copy of the committed log entry if it exists; otherwise, returns null
+ *
+ * @deprecated Use {@link #getLogEntryRef()} or {@link #getLogEntryUnsafe()} to avoid copying.
*/
+ @Deprecated
LogEntryProto getLogEntry();
+ /**
+ * @return the committed log entry if it exists; otherwise, returns null.
+ * The returned value is safe to use only before {@link StateMachine#applyTransaction} returns.
+ * Once {@link StateMachine#applyTransaction} has returned, it is unsafe to use the log entry
+ * since the underlying buffers can possiby be released.
+ */
+ default LogEntryProto getLogEntryUnsafe() {
+ return getLogEntryRef().get();
+ }
+
+ /**
+ * Get a {@link ReferenceCountedObject} to the committed log entry.
+ *
+ * It is safe to access the log entry by calling {@link ReferenceCountedObject#get()}
+ * (without {@link ReferenceCountedObject#retain()})
+ * inside the scope of {@link StateMachine#applyTransaction}.
+ *
+ * If the log entry is needed after {@link StateMachine#applyTransaction} returns,
+ * e.g. for asynchronous computation or caching,
+ * the caller must invoke {@link ReferenceCountedObject#retain()} and {@link ReferenceCountedObject#release()}.
+ *
+ * @return a reference to the committed log entry if it exists; otherwise, returns null.
+ */
+ default ReferenceCountedObject<LogEntryProto> getLogEntryRef() {
+ return Optional.ofNullable(getLogEntryUnsafe()).map(this::wrap).orElse(null);
+ }
+
/** Wrap the given log entry as a {@link ReferenceCountedObject} for retaining it for later use. */
default ReferenceCountedObject<LogEntryProto> wrap(LogEntryProto entry) {
Preconditions.assertSame(getLogEntry().getTerm(), entry.getTerm(), "entry.term");
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
index 4f313a4..e8a4adc 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
@@ -527,7 +527,7 @@
PendingRequest addPendingRequest(PendingRequests.Permit permit, RaftClientRequest request, TransactionContext entry) {
if (LOG.isDebugEnabled()) {
LOG.debug("{}: addPendingRequest at {}, entry={}", this, request,
- LogProtoUtils.toLogEntryString(entry.getLogEntry()));
+ LogProtoUtils.toLogEntryString(entry.getLogEntryUnsafe()));
}
return pendingRequests.add(permit, request, entry);
}
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java
index 06a3a7b..4271d76 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java
@@ -38,7 +38,7 @@
private final CompletableFuture<RaftClientReply> futureToReturn;
PendingRequest(RaftClientRequest request, TransactionContext entry) {
- this.termIndex = entry == null? null: TermIndex.valueOf(entry.getLogEntry());
+ this.termIndex = entry == null? null: TermIndex.valueOf(entry.getLogEntryUnsafe());
this.request = request;
this.entry = entry;
if (request.is(TypeCase.FORWARD)) {
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
index 133cfeb..0885fb8 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
@@ -1802,6 +1802,10 @@
final ClientInvocationId invocationId = ClientInvocationId.valueOf(next.getStateMachineLogEntry());
writeIndexCache.add(invocationId.getClientId(), ((TransactionContextImpl) trx).getLogIndexFuture());
+ // TODO: RaftLog to provide the log entry as a ReferenceCountedObject as per RATIS-2028.
+ ReferenceCountedObject<?> ref = ReferenceCountedObject.wrap(next);
+ ((TransactionContextImpl) trx).setDelegatedRef(ref);
+ ref.retain();
try {
// Let the StateMachine inject logic for committed transactions in sequential order.
trx = stateMachine.applyTransactionSerial(trx);
@@ -1810,6 +1814,8 @@
return replyPendingRequest(invocationId, TermIndex.valueOf(next), stateMachineFuture);
} catch (Exception e) {
throw new RaftLogIOException(e);
+ } finally {
+ ref.release();
}
}
return null;
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java
index d02994e..27eaf31 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java
@@ -318,7 +318,7 @@
void appendLog(TransactionContext operation) throws StateMachineException {
getLog().append(currentTerm.get(), operation);
- Objects.requireNonNull(operation.getLogEntry());
+ Objects.requireNonNull(operation.getLogEntryUnsafe(), "transaction-logEntry");
}
/** @return true iff the given peer id is recognized as the leader. */
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/LogProtoUtils.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/LogProtoUtils.java
index de06faf..b177f0e 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/LogProtoUtils.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/LogProtoUtils.java
@@ -27,6 +27,7 @@
import org.apache.ratis.server.protocol.TermIndex;
import org.apache.ratis.thirdparty.com.google.protobuf.AbstractMessage;
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.ratis.util.Preconditions;
import org.apache.ratis.util.ProtoUtils;
@@ -221,4 +222,21 @@
final List<RaftPeer> oldListener = ProtoUtils.toRaftPeers(proto.getOldListenersList());
return ServerImplUtils.newRaftConfiguration(conf, listener, entry.getIndex(), oldConf, oldListener);
}
+
+ public static LogEntryProto copy(LogEntryProto proto) {
+ if (proto == null) {
+ return null;
+ }
+
+ if (!proto.hasStateMachineLogEntry() && !proto.hasMetadataEntry() && !proto.hasConfigurationEntry()) {
+ // empty entry, just return as is.
+ return proto;
+ }
+
+ try {
+ return LogEntryProto.parseFrom(proto.toByteString());
+ } catch (InvalidProtocolBufferException e) {
+ throw new IllegalArgumentException("Failed to copy log entry " + TermIndex.valueOf(proto), e);
+ }
+ }
}
diff --git a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/BaseStateMachine.java b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/BaseStateMachine.java
index c987c53..98f270d 100644
--- a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/BaseStateMachine.java
+++ b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/BaseStateMachine.java
@@ -18,7 +18,7 @@
package org.apache.ratis.statemachine.impl;
-import org.apache.ratis.proto.RaftProtos;
+import org.apache.ratis.proto.RaftProtos.LogEntryProto;
import org.apache.ratis.protocol.Message;
import org.apache.ratis.protocol.RaftClientRequest;
import org.apache.ratis.protocol.RaftGroupId;
@@ -110,10 +110,10 @@
@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
// return the same message contained in the entry
- RaftProtos.LogEntryProto entry = Objects.requireNonNull(trx.getLogEntry());
+ final LogEntryProto entry = Objects.requireNonNull(trx.getLogEntryUnsafe());
updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
return CompletableFuture.completedFuture(
- Message.valueOf(trx.getLogEntry().getStateMachineLogEntry().getLogData()));
+ Message.valueOf(entry.getStateMachineLogEntry().getLogData()));
}
@Override
diff --git a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java
index 7c4f178..44bd32c 100644
--- a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java
@@ -25,12 +25,14 @@
import org.apache.ratis.statemachine.StateMachine;
import org.apache.ratis.statemachine.TransactionContext;
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.util.MemoizedSupplier;
import org.apache.ratis.util.Preconditions;
import org.apache.ratis.util.ReferenceCountedObject;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
+import java.util.function.Supplier;
/**
* Implementation of {@link TransactionContext}
@@ -69,6 +71,9 @@
/** Committed LogEntry. */
private volatile LogEntryProto logEntry;
+ /** Committed LogEntry copy. */
+ private volatile Supplier<LogEntryProto> logEntryCopy;
+
/** For wrapping {@link #logEntry} in order to release the underlying buffer. */
private volatile ReferenceCountedObject<?> delegatedRef;
@@ -112,7 +117,7 @@
*/
TransactionContextImpl(RaftPeerRole serverRole, StateMachine stateMachine, LogEntryProto logEntry) {
this(serverRole, null, stateMachine, logEntry.getStateMachineLogEntry());
- this.logEntry = logEntry;
+ setLogEntry(logEntry);
this.logIndexFuture.complete(logEntry.getIndex());
}
@@ -135,8 +140,10 @@
if (delegatedRef == null) {
return TransactionContext.super.wrap(entry);
}
- Preconditions.assertSame(getLogEntry().getTerm(), entry.getTerm(), "entry.term");
- Preconditions.assertSame(getLogEntry().getIndex(), entry.getIndex(), "entry.index");
+ final LogEntryProto expected = getLogEntryUnsafe();
+ Objects.requireNonNull(expected, "logEntry == null");
+ Preconditions.assertSame(expected.getTerm(), entry.getTerm(), "entry.term");
+ Preconditions.assertSame(expected.getIndex(), entry.getIndex(), "entry.index");
return delegatedRef.delegate(entry);
}
@@ -168,18 +175,31 @@
Objects.requireNonNull(stateMachineLogEntry, "stateMachineLogEntry == null");
logIndexFuture.complete(index);
- return logEntry = LogProtoUtils.toLogEntryProto(stateMachineLogEntry, term, index);
+ return setLogEntry(LogProtoUtils.toLogEntryProto(stateMachineLogEntry, term, index));
}
public CompletableFuture<Long> getLogIndexFuture() {
return logIndexFuture;
}
+ private LogEntryProto setLogEntry(LogEntryProto entry) {
+ this.logEntry = entry;
+ this.logEntryCopy = MemoizedSupplier.valueOf(() -> LogProtoUtils.copy(entry));
+ return entry;
+ }
+
+
@Override
public LogEntryProto getLogEntry() {
+ return logEntryCopy == null ? null : logEntryCopy.get();
+ }
+
+ @Override
+ public LogEntryProto getLogEntryUnsafe() {
return logEntry;
}
+
@Override
public TransactionContext setException(Exception ioe) {
this.exception = ioe;
@@ -209,4 +229,8 @@
// call this to let the SM know that Transaction cannot be synced
return stateMachine.cancelTransaction(this);
}
+
+ public static LogEntryProto getLogEntry(TransactionContext context) {
+ return ((TransactionContextImpl) context).logEntry;
+ }
}
diff --git a/ratis-server/src/test/java/org/apache/ratis/ReadOnlyRequestTests.java b/ratis-server/src/test/java/org/apache/ratis/ReadOnlyRequestTests.java
index eea7559..ead2a8b 100644
--- a/ratis-server/src/test/java/org/apache/ratis/ReadOnlyRequestTests.java
+++ b/ratis-server/src/test/java/org/apache/ratis/ReadOnlyRequestTests.java
@@ -19,6 +19,7 @@
import org.apache.ratis.client.RaftClient;
import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.proto.RaftProtos;
import org.apache.ratis.protocol.Message;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftPeerId;
@@ -366,10 +367,11 @@
@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
- LOG.debug("apply trx with index=" + trx.getLogEntry().getIndex());
- updateLastAppliedTermIndex(trx.getLogEntry().getTerm(), trx.getLogEntry().getIndex());
+ final RaftProtos.LogEntryProto logEntry = trx.getLogEntryUnsafe();
+ LOG.debug("apply trx with index=" + logEntry.getIndex());
+ updateLastAppliedTermIndex(logEntry.getTerm(), logEntry.getIndex());
- String command = trx.getLogEntry().getStateMachineLogEntry()
+ String command = logEntry.getStateMachineLogEntry()
.getLogData().toString(StandardCharsets.UTF_8);
LOG.info("receive command: {}", command);
diff --git a/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java b/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java
index 28f8e6a..246abb9 100644
--- a/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java
+++ b/ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java
@@ -56,7 +56,7 @@
}
}
}
- RaftProtos.LogEntryProto entry = trx.getLogEntry();
+ final RaftProtos.LogEntryProto entry = trx.getLogEntryUnsafe();
updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
return CompletableFuture.completedFuture(new RaftTestUtil.SimpleMessage("done"));
}
diff --git a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
index 17d5a60..7dd1db3 100644
--- a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
+++ b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
@@ -84,7 +84,7 @@
return (SimpleStateMachine4Testing)s.getStateMachine();
}
- private final SortedMap<Long, LogEntryProto> indexMap = Collections.synchronizedSortedMap(new TreeMap<>());
+ private final SortedMap<Long, ReferenceCountedObject<LogEntryProto>> indexMap = Collections.synchronizedSortedMap(new TreeMap<>());
private final SortedMap<String, LogEntryProto> dataMap = Collections.synchronizedSortedMap(new TreeMap<>());
private final Daemon checkpointer;
private final SimpleStateMachineStorage storage = new SimpleStateMachineStorage();
@@ -199,8 +199,9 @@
return leaderElectionTimeoutInfo;
}
- private void put(LogEntryProto entry) {
- final LogEntryProto previous = indexMap.put(entry.getIndex(), entry);
+ private void put(ReferenceCountedObject<LogEntryProto> entryRef) {
+ LogEntryProto entry = entryRef.retain();
+ final ReferenceCountedObject<LogEntryProto> previous = indexMap.put(entry.getIndex(), entryRef);
Preconditions.assertNull(previous, "previous");
final String s = entry.getStateMachineLogEntry().getLogData().toStringUtf8();
dataMap.put(s, entry);
@@ -246,27 +247,17 @@
@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
blocking.await(Blocking.Type.APPLY_TRANSACTION);
- LogEntryProto entry = Objects.requireNonNull(trx.getLogEntry());
+ ReferenceCountedObject<LogEntryProto> entryRef = Objects.requireNonNull(trx.getLogEntryRef());
+ LogEntryProto entry = entryRef.get();
LOG.info("applyTransaction for log index {}", entry.getIndex());
- // TODO: Logs kept in StateMachine's cache may be corrupted. Copy for now to have the test pass.
- // Use ReferenceCount per RATIS-1997.
- LogEntryProto copied = copy(entry);
- put(copied);
+ put(entryRef);
updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
final SimpleMessage m = new SimpleMessage(entry.getIndex() + " OK");
return collecting.collect(Collecting.Type.APPLY_TRANSACTION, m);
}
- private LogEntryProto copy(LogEntryProto log) {
- try {
- return LogEntryProto.parseFrom(log.toByteString());
- } catch (InvalidProtocolBufferException e) {
- throw new IllegalStateException("Error copying log entry", e);
- }
- }
-
@Override
public long takeSnapshot() {
final TermIndex termIndex = getLastAppliedTermIndex();
@@ -280,7 +271,8 @@
LOG.debug("Taking a snapshot with {}, file:{}", termIndex, snapshotFile);
try (SegmentedRaftLogOutputStream out = new SegmentedRaftLogOutputStream(snapshotFile, false,
segmentMaxSize, preallocatedSize, ByteBuffer.allocateDirect(bufferSize))) {
- for (final LogEntryProto entry : indexMap.values()) {
+ for (final ReferenceCountedObject<LogEntryProto> entryRef : indexMap.values()) {
+ LogEntryProto entry = entryRef.get();
if (entry.getIndex() > endIndex) {
break;
} else {
@@ -315,7 +307,7 @@
snapshot.getFile().getPath().toFile(), 0, endIndex, false)) {
LogEntryProto entry;
while ((entry = in.nextEntry()) != null) {
- put(entry);
+ put(ReferenceCountedObject.wrap(entry));
updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
}
}
@@ -390,10 +382,11 @@
running = false;
checkpointer.interrupt();
});
+ indexMap.values().forEach(ReferenceCountedObject::release);
}
public LogEntryProto[] getContent() {
- return indexMap.values().toArray(new LogEntryProto[0]);
+ return indexMap.values().stream().map(ReferenceCountedObject::get).toArray(LogEntryProto[]::new);
}
public void blockStartTransaction() {
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
index e4a930f..2970bbe 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
@@ -165,7 +165,7 @@
@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
- final LogEntryProto entry = Objects.requireNonNull(trx.getLogEntry());
+ final LogEntryProto entry = Objects.requireNonNull(trx.getLogEntryUnsafe());
updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
final SingleDataStream s = getSingleDataStream(ClientInvocationId.valueOf(entry.getStateMachineLogEntry()));
final ByteString bytesWritten = bytesWritten2ByteString(s.getDataChannel().getBytesWritten());
diff --git a/ratis-test/src/test/java/org/apache/ratis/statemachine/TestStateMachine.java b/ratis-test/src/test/java/org/apache/ratis/statemachine/TestStateMachine.java
index 0941898..07ea4ed 100644
--- a/ratis-test/src/test/java/org/apache/ratis/statemachine/TestStateMachine.java
+++ b/ratis-test/src/test/java/org/apache/ratis/statemachine/TestStateMachine.java
@@ -91,7 +91,7 @@
@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
try {
- assertNotNull(trx.getLogEntry());
+ assertNotNull(trx.getLogEntryUnsafe());
assertNotNull(trx.getStateMachineLogEntry());
Object context = trx.getStateMachineContext();
if (isLeader.get()) {