GEODE-8651: MsgReader's readHeader and readMessage should be synchron… (#5665)
Co-authored-by: Anil <agingade@pivotal.io>
Co-authored-by: Darrel Schneider <darrel@vmware.com>
Co-authored-by: Bill Burcham <bill.burcham@gmail.com>
Co-authored-by: Ernie Burghardt <eburghardt@pivotal.io>
diff --git a/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java b/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java
index 44fbe9d..29d15e3 100644
--- a/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java
+++ b/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java
@@ -799,7 +799,8 @@
}
}
- private void notifyHandshakeWaiter(boolean success) {
+ @VisibleForTesting
+ void clearSSLInputBuffer() {
if (getConduit().useSSL() && ioFilter != null) {
synchronized (ioFilter.getSynchObject()) {
if (!ioFilter.isClosed()) {
@@ -809,7 +810,18 @@
}
}
}
+ }
+
+ @VisibleForTesting
+ void notifyHandshakeWaiter(boolean success) {
synchronized (handshakeSync) {
+ /*
+ * Return early to avoid modifying ioFilter's buffer more than once.
+ */
+ if (handshakeRead || handshakeCancelled) {
+ return;
+ }
+ clearSSLInputBuffer();
if (success) {
handshakeRead = true;
} else {
@@ -2649,25 +2661,27 @@
final KnownVersion version = getRemoteVersion();
try {
msgReader = new MsgReader(this, ioFilter, version);
-
- Header header = msgReader.readHeader();
-
ReplyMessage msg;
int len;
- if (header.getMessageType() == NORMAL_MSG_TYPE) {
- msg = (ReplyMessage) msgReader.readMessage(header);
- len = header.getMessageLength();
- } else {
- MsgDestreamer destreamer = obtainMsgDestreamer(header.getMessageId(), version);
- while (header.getMessageType() == CHUNKED_MSG_TYPE) {
+
+ synchronized (ioFilter.getSynchObject()) {
+ Header header = msgReader.readHeader();
+
+ if (header.getMessageType() == NORMAL_MSG_TYPE) {
+ msg = (ReplyMessage) msgReader.readMessage(header);
+ len = header.getMessageLength();
+ } else {
+ MsgDestreamer destreamer = obtainMsgDestreamer(header.getMessageId(), version);
+ while (header.getMessageType() == CHUNKED_MSG_TYPE) {
+ msgReader.readChunk(header, destreamer);
+ header = msgReader.readHeader();
+ }
msgReader.readChunk(header, destreamer);
- header = msgReader.readHeader();
+ msg = (ReplyMessage) destreamer.getMessage();
+ releaseMsgDestreamer(header.getMessageId(), destreamer);
+ len = destreamer.size();
}
- msgReader.readChunk(header, destreamer);
- msg = (ReplyMessage) destreamer.getMessage();
- releaseMsgDestreamer(header.getMessageId(), destreamer);
- len = destreamer.size();
- }
+ } // sync
// I'd really just like to call dispatchMessage here. However,
// that call goes through a bunch of checks that knock about
// 10% of the performance. Since this direct-ack stuff is all
diff --git a/geode-core/src/test/java/org/apache/geode/internal/tcp/ConnectionTest.java b/geode-core/src/test/java/org/apache/geode/internal/tcp/ConnectionTest.java
index 233de78..c064afb 100644
--- a/geode-core/src/test/java/org/apache/geode/internal/tcp/ConnectionTest.java
+++ b/geode-core/src/test/java/org/apache/geode/internal/tcp/ConnectionTest.java
@@ -19,10 +19,12 @@
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
+import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.UnknownHostException;
@@ -113,4 +115,56 @@
assertThat(connection.getP2PConnectTimeout(distributionConfig)).isEqualTo(100);
});
}
+
+ private Connection createSpiedConnection() throws IOException {
+ ConnectionTable connectionTable = mock(ConnectionTable.class);
+ Distribution distribution = mock(Distribution.class);
+ DistributionManager distributionManager = mock(DistributionManager.class);
+ DMStats dmStats = mock(DMStats.class);
+ CancelCriterion stopper = mock(CancelCriterion.class);
+ SocketCloser socketCloser = mock(SocketCloser.class);
+ TCPConduit tcpConduit = mock(TCPConduit.class);
+
+ when(connectionTable.getBufferPool()).thenReturn(new BufferPool(dmStats));
+ when(connectionTable.getConduit()).thenReturn(tcpConduit);
+ when(connectionTable.getDM()).thenReturn(distributionManager);
+ when(connectionTable.getSocketCloser()).thenReturn(socketCloser);
+ when(distributionManager.getDistribution()).thenReturn(distribution);
+ when(stopper.cancelInProgress()).thenReturn(null);
+ when(tcpConduit.getCancelCriterion()).thenReturn(stopper);
+ when(tcpConduit.getDM()).thenReturn(distributionManager);
+ when(tcpConduit.getSocketId()).thenReturn(new InetSocketAddress(getLocalHost(), 10337));
+ when(tcpConduit.getStats()).thenReturn(dmStats);
+
+ SocketChannel channel = SocketChannel.open();
+
+ Connection connection = new Connection(connectionTable, channel.socket());
+ connection = spy(connection);
+ return connection;
+ }
+
+ @Test
+ public void firstCallToNotifyHandshakeWaiterWillClearSSLInputBuffer() throws Exception {
+ Connection connection = createSpiedConnection();
+ connection.notifyHandshakeWaiter(true);
+ verify(connection, times(1)).clearSSLInputBuffer();
+ }
+
+ @Test
+ public void secondCallWithTrueToNotifyHandshakeWaiterShouldNotClearSSLInputBuffer()
+ throws Exception {
+ Connection connection = createSpiedConnection();
+ connection.notifyHandshakeWaiter(true);
+ connection.notifyHandshakeWaiter(true);
+ verify(connection, times(1)).clearSSLInputBuffer();
+ }
+
+ @Test
+ public void secondCallWithFalseToNotifyHandshakeWaiterShouldNotClearSSLInputBuffer()
+ throws Exception {
+ Connection connection = createSpiedConnection();
+ connection.notifyHandshakeWaiter(true);
+ connection.notifyHandshakeWaiter(false);
+ verify(connection, times(1)).clearSSLInputBuffer();
+ }
}