PROTON-1426: check the initial bytes to ensure the expected SASL layer header is given
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
index 8becc72..37754ba 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
@@ -21,6 +21,8 @@
package org.apache.qpid.proton.engine.impl;
+import static org.apache.qpid.proton.engine.impl.AmqpHeader.SASL_HEADER;
+
import java.nio.ByteBuffer;
import org.apache.qpid.proton.amqp.Binary;
@@ -35,6 +37,14 @@
enum State
{
+ HEADER0,
+ HEADER1,
+ HEADER2,
+ HEADER3,
+ HEADER4,
+ HEADER5,
+ HEADER6,
+ HEADER7,
SIZE_0,
SIZE_1,
SIZE_2,
@@ -45,12 +55,11 @@
ERROR
}
- private State _state = State.SIZE_0;
+ private State _state = State.HEADER0;
private int _size;
private ByteBuffer _buffer;
- private int _ignore = 8;
private final ByteBufferDecoder _decoder;
@@ -70,19 +79,144 @@
State state = _state;
ByteBuffer oldIn = null;
- // Note that we simply skip over the header rather than parsing it.
- if(_ignore != 0)
- {
- int bytesToEat = Math.min(_ignore, input.remaining());
- input.position(input.position() + bytesToEat);
- _ignore -= bytesToEat;
- }
-
while(input.hasRemaining() && state != State.ERROR && !_sasl.isDone())
{
switch(state)
{
+ case HEADER0:
+ if(input.hasRemaining())
+ {
+ byte c = input.get();
+ if(c != SASL_HEADER[0])
+ {
+ frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[0], state);
+ state = State.ERROR;
+ break;
+ }
+ state = State.HEADER1;
+ }
+ else
+ {
+ break;
+ }
+ case HEADER1:
+ if(input.hasRemaining())
+ {
+ byte c = input.get();
+ if(c != SASL_HEADER[1])
+ {
+ frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[1], state);
+ state = State.ERROR;
+ break;
+ }
+ state = State.HEADER2;
+ }
+ else
+ {
+ break;
+ }
+ case HEADER2:
+ if(input.hasRemaining())
+ {
+ byte c = input.get();
+ if(c != SASL_HEADER[2])
+ {
+ frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[2], state);
+ state = State.ERROR;
+ break;
+ }
+ state = State.HEADER3;
+ }
+ else
+ {
+ break;
+ }
+ case HEADER3:
+ if(input.hasRemaining())
+ {
+ byte c = input.get();
+ if(c != SASL_HEADER[3])
+ {
+ frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[3], state);
+ state = State.ERROR;
+ break;
+ }
+ state = State.HEADER4;
+ }
+ else
+ {
+ break;
+ }
+ case HEADER4:
+ if(input.hasRemaining())
+ {
+ byte c = input.get();
+ if(c != SASL_HEADER[4])
+ {
+ frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[4], state);
+ state = State.ERROR;
+ break;
+ }
+ state = State.HEADER5;
+ }
+ else
+ {
+ break;
+ }
+ case HEADER5:
+ if(input.hasRemaining())
+ {
+ byte c = input.get();
+ if(c != SASL_HEADER[5])
+ {
+ frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[5], state);
+ state = State.ERROR;
+ break;
+ }
+ state = State.HEADER6;
+ }
+ else
+ {
+ break;
+ }
+ case HEADER6:
+ if(input.hasRemaining())
+ {
+ byte c = input.get();
+ if(c != SASL_HEADER[6])
+ {
+ frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[6], state);
+ state = State.ERROR;
+ break;
+ }
+ state = State.HEADER7;
+ }
+ else
+ {
+ break;
+ }
+ case HEADER7:
+ if(input.hasRemaining())
+ {
+ byte c = input.get();
+ if(c != SASL_HEADER[7])
+ {
+ frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[7], state);
+ state = State.ERROR;
+ break;
+ }
+ state = State.SIZE_0;
+ }
+ else
+ {
+ break;
+ }
case SIZE_0:
+ if(!input.hasRemaining())
+ {
+ break;
+ }
+
if(input.remaining() >= 4)
{
size = input.getInt();
diff --git a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
index 31dd200..7ca6f74 100644
--- a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
+++ b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
@@ -20,9 +20,10 @@
import static org.mockito.Mockito.*;
import static org.junit.Assert.*;
-import static org.junit.matchers.JUnitMatchers.containsString;
+import static org.hamcrest.CoreMatchers.containsString;
import java.nio.ByteBuffer;
+import java.util.Arrays;
import org.apache.qpid.proton.amqp.Binary;
import org.apache.qpid.proton.amqp.Symbol;
@@ -39,7 +40,6 @@
import org.junit.Test;
/**
- * TODO test case where header is malformed
* TODO test case where input provides frame and half etc
*/
public class SaslFrameParserTest
@@ -131,6 +131,67 @@
}
}
+ /*
+ * Test that if the first 8 bytes don't match the AMQP SASL header, it causes an error.
+ */
+ @Test
+ public void testInputOfInvalidHeader() {
+ for (int invalidIndex = 0; invalidIndex < 8; invalidIndex++) {
+ doInputOfInvalidHeaderTestImpl(invalidIndex);
+ }
+ }
+
+ private void doInputOfInvalidHeaderTestImpl(int invalidIndex) {
+ SaslFrameHandler mockSaslFrameHandler = mock(SaslFrameHandler.class);
+ ByteBufferDecoder mockDecoder = mock(ByteBufferDecoder.class);
+
+ SaslFrameParser saslFrameParser = new SaslFrameParser(mockSaslFrameHandler, mockDecoder);
+
+ byte[] header = Arrays.copyOf(AmqpHeader.SASL_HEADER, AmqpHeader.SASL_HEADER.length);
+ header[invalidIndex] = 'X';
+
+ try {
+ saslFrameParser.input(ByteBuffer.wrap(header));
+ fail("expected exception");
+ } catch (TransportException e) {
+ assertThat(e.getMessage(), containsString("AMQP SASL header mismatch"));
+ assertThat(e.getMessage(), containsString("In state: HEADER" + invalidIndex));
+ }
+
+ // Check that further interaction throws TransportException.
+ try {
+ saslFrameParser.input(ByteBuffer.wrap(new byte[0]));
+ fail("expected exception");
+ } catch (TransportException e) {
+ // Expected
+ }
+ }
+
+ /*
+ * Test that if the first 8 bytes, fed in one at a time, don't match the AMQP SASL header, it causes an error.
+ */
+ @Test
+ public void testInputOfValidHeaderInSegments()
+ {
+ sendAmqpSaslHeaderInSegments();
+
+ // Try feeding in an actual frame now to check we get past the header parsing ok
+ when(_mockSaslFrameHandler.isDone()).thenReturn(false);
+
+ _frameParser.input(_saslFrameBytes);
+
+ verify(_mockSaslFrameHandler).handle(isA(SaslInit.class), (Binary)isNull());
+ }
+
+ private void sendAmqpSaslHeaderInSegments()
+ {
+ for (int headerIndex = 0; headerIndex < 8; headerIndex++)
+ {
+ byte headerPart = AmqpHeader.SASL_HEADER[headerIndex];
+ _frameParser.input(ByteBuffer.wrap(new byte[] { headerPart }));
+ }
+ }
+
private void sendAmqpSaslHeader(SaslFrameParser saslFrameParser)
{
saslFrameParser.input(ByteBuffer.wrap(AmqpHeader.SASL_HEADER));