GEODE-8651: MsgReader's readHeader and readMessage should be synchron… (#5665)


    Co-authored-by: Anil <agingade@pivotal.io>
    Co-authored-by: Darrel Schneider <darrel@vmware.com>
    Co-authored-by: Bill Burcham <bill.burcham@gmail.com>
    Co-authored-by: Ernie Burghardt <eburghardt@pivotal.io>
diff --git a/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java b/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java
index 44fbe9d..29d15e3 100644
--- a/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java
+++ b/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java
@@ -799,7 +799,8 @@
     }
   }
 
-  private void notifyHandshakeWaiter(boolean success) {
+  @VisibleForTesting
+  void clearSSLInputBuffer() {
     if (getConduit().useSSL() && ioFilter != null) {
       synchronized (ioFilter.getSynchObject()) {
         if (!ioFilter.isClosed()) {
@@ -809,7 +810,18 @@
         }
       }
     }
+  }
+
+  @VisibleForTesting
+  void notifyHandshakeWaiter(boolean success) {
     synchronized (handshakeSync) {
+      /*
+       * Return early to avoid modifying ioFilter's buffer more than once.
+       */
+      if (handshakeRead || handshakeCancelled) {
+        return;
+      }
+      clearSSLInputBuffer();
       if (success) {
         handshakeRead = true;
       } else {
@@ -2649,25 +2661,27 @@
     final KnownVersion version = getRemoteVersion();
     try {
       msgReader = new MsgReader(this, ioFilter, version);
-
-      Header header = msgReader.readHeader();
-
       ReplyMessage msg;
       int len;
-      if (header.getMessageType() == NORMAL_MSG_TYPE) {
-        msg = (ReplyMessage) msgReader.readMessage(header);
-        len = header.getMessageLength();
-      } else {
-        MsgDestreamer destreamer = obtainMsgDestreamer(header.getMessageId(), version);
-        while (header.getMessageType() == CHUNKED_MSG_TYPE) {
+
+      synchronized (ioFilter.getSynchObject()) {
+        Header header = msgReader.readHeader();
+
+        if (header.getMessageType() == NORMAL_MSG_TYPE) {
+          msg = (ReplyMessage) msgReader.readMessage(header);
+          len = header.getMessageLength();
+        } else {
+          MsgDestreamer destreamer = obtainMsgDestreamer(header.getMessageId(), version);
+          while (header.getMessageType() == CHUNKED_MSG_TYPE) {
+            msgReader.readChunk(header, destreamer);
+            header = msgReader.readHeader();
+          }
           msgReader.readChunk(header, destreamer);
-          header = msgReader.readHeader();
+          msg = (ReplyMessage) destreamer.getMessage();
+          releaseMsgDestreamer(header.getMessageId(), destreamer);
+          len = destreamer.size();
         }
-        msgReader.readChunk(header, destreamer);
-        msg = (ReplyMessage) destreamer.getMessage();
-        releaseMsgDestreamer(header.getMessageId(), destreamer);
-        len = destreamer.size();
-      }
+      } // sync
       // I'd really just like to call dispatchMessage here. However,
       // that call goes through a bunch of checks that knock about
       // 10% of the performance. Since this direct-ack stuff is all
diff --git a/geode-core/src/test/java/org/apache/geode/internal/tcp/ConnectionTest.java b/geode-core/src/test/java/org/apache/geode/internal/tcp/ConnectionTest.java
index 233de78..c064afb 100644
--- a/geode-core/src/test/java/org/apache/geode/internal/tcp/ConnectionTest.java
+++ b/geode-core/src/test/java/org/apache/geode/internal/tcp/ConnectionTest.java
@@ -19,10 +19,12 @@
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.isNull;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.net.Socket;
 import java.net.UnknownHostException;
@@ -113,4 +115,56 @@
       assertThat(connection.getP2PConnectTimeout(distributionConfig)).isEqualTo(100);
     });
   }
+
+  private Connection createSpiedConnection() throws IOException {
+    ConnectionTable connectionTable = mock(ConnectionTable.class);
+    Distribution distribution = mock(Distribution.class);
+    DistributionManager distributionManager = mock(DistributionManager.class);
+    DMStats dmStats = mock(DMStats.class);
+    CancelCriterion stopper = mock(CancelCriterion.class);
+    SocketCloser socketCloser = mock(SocketCloser.class);
+    TCPConduit tcpConduit = mock(TCPConduit.class);
+
+    when(connectionTable.getBufferPool()).thenReturn(new BufferPool(dmStats));
+    when(connectionTable.getConduit()).thenReturn(tcpConduit);
+    when(connectionTable.getDM()).thenReturn(distributionManager);
+    when(connectionTable.getSocketCloser()).thenReturn(socketCloser);
+    when(distributionManager.getDistribution()).thenReturn(distribution);
+    when(stopper.cancelInProgress()).thenReturn(null);
+    when(tcpConduit.getCancelCriterion()).thenReturn(stopper);
+    when(tcpConduit.getDM()).thenReturn(distributionManager);
+    when(tcpConduit.getSocketId()).thenReturn(new InetSocketAddress(getLocalHost(), 10337));
+    when(tcpConduit.getStats()).thenReturn(dmStats);
+
+    SocketChannel channel = SocketChannel.open();
+
+    Connection connection = new Connection(connectionTable, channel.socket());
+    connection = spy(connection);
+    return connection;
+  }
+
+  @Test
+  public void firstCallToNotifyHandshakeWaiterWillClearSSLInputBuffer() throws Exception {
+    Connection connection = createSpiedConnection();
+    connection.notifyHandshakeWaiter(true);
+    verify(connection, times(1)).clearSSLInputBuffer();
+  }
+
+  @Test
+  public void secondCallWithTrueToNotifyHandshakeWaiterShouldNotClearSSLInputBuffer()
+      throws Exception {
+    Connection connection = createSpiedConnection();
+    connection.notifyHandshakeWaiter(true);
+    connection.notifyHandshakeWaiter(true);
+    verify(connection, times(1)).clearSSLInputBuffer();
+  }
+
+  @Test
+  public void secondCallWithFalseToNotifyHandshakeWaiterShouldNotClearSSLInputBuffer()
+      throws Exception {
+    Connection connection = createSpiedConnection();
+    connection.notifyHandshakeWaiter(true);
+    connection.notifyHandshakeWaiter(false);
+    verify(connection, times(1)).clearSSLInputBuffer();
+  }
 }