PROTON-1426: check the initial bytes to ensure the expected SASL layer header is given
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
index 8becc72..37754ba 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
@@ -21,6 +21,8 @@
 
 package org.apache.qpid.proton.engine.impl;
 
+import static org.apache.qpid.proton.engine.impl.AmqpHeader.SASL_HEADER;
+
 import java.nio.ByteBuffer;
 
 import org.apache.qpid.proton.amqp.Binary;
@@ -35,6 +37,14 @@
 
     enum State
     {
+        HEADER0,
+        HEADER1,
+        HEADER2,
+        HEADER3,
+        HEADER4,
+        HEADER5,
+        HEADER6,
+        HEADER7,
         SIZE_0,
         SIZE_1,
         SIZE_2,
@@ -45,12 +55,11 @@
         ERROR
     }
 
-    private State _state = State.SIZE_0;
+    private State _state = State.HEADER0;
     private int _size;
 
     private ByteBuffer _buffer;
 
-    private int _ignore = 8;
     private final ByteBufferDecoder _decoder;
 
 
@@ -70,19 +79,144 @@
         State state = _state;
         ByteBuffer oldIn = null;
 
-        // Note that we simply skip over the header rather than parsing it.
-        if(_ignore != 0)
-        {
-            int bytesToEat = Math.min(_ignore, input.remaining());
-            input.position(input.position() + bytesToEat);
-            _ignore -= bytesToEat;
-        }
-
         while(input.hasRemaining() && state != State.ERROR && !_sasl.isDone())
         {
             switch(state)
             {
+                case HEADER0:
+                    if(input.hasRemaining())
+                    {
+                        byte c = input.get();
+                        if(c != SASL_HEADER[0])
+                        {
+                            frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[0], state);
+                            state = State.ERROR;
+                            break;
+                        }
+                        state = State.HEADER1;
+                    }
+                    else
+                    {
+                        break;
+                    }
+                case HEADER1:
+                    if(input.hasRemaining())
+                    {
+                        byte c = input.get();
+                        if(c != SASL_HEADER[1])
+                        {
+                            frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[1], state);
+                            state = State.ERROR;
+                            break;
+                        }
+                        state = State.HEADER2;
+                    }
+                    else
+                    {
+                        break;
+                    }
+                case HEADER2:
+                    if(input.hasRemaining())
+                    {
+                        byte c = input.get();
+                        if(c != SASL_HEADER[2])
+                        {
+                            frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[2], state);
+                            state = State.ERROR;
+                            break;
+                        }
+                        state = State.HEADER3;
+                    }
+                    else
+                    {
+                        break;
+                    }
+                case HEADER3:
+                    if(input.hasRemaining())
+                    {
+                        byte c = input.get();
+                        if(c != SASL_HEADER[3])
+                        {
+                            frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[3], state);
+                            state = State.ERROR;
+                            break;
+                        }
+                        state = State.HEADER4;
+                    }
+                    else
+                    {
+                        break;
+                    }
+                case HEADER4:
+                    if(input.hasRemaining())
+                    {
+                        byte c = input.get();
+                        if(c != SASL_HEADER[4])
+                        {
+                            frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[4], state);
+                            state = State.ERROR;
+                            break;
+                        }
+                        state = State.HEADER5;
+                    }
+                    else
+                    {
+                        break;
+                    }
+                case HEADER5:
+                    if(input.hasRemaining())
+                    {
+                        byte c = input.get();
+                        if(c != SASL_HEADER[5])
+                        {
+                            frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[5], state);
+                            state = State.ERROR;
+                            break;
+                        }
+                        state = State.HEADER6;
+                    }
+                    else
+                    {
+                        break;
+                    }
+                case HEADER6:
+                    if(input.hasRemaining())
+                    {
+                        byte c = input.get();
+                        if(c != SASL_HEADER[6])
+                        {
+                            frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[6], state);
+                            state = State.ERROR;
+                            break;
+                        }
+                        state = State.HEADER7;
+                    }
+                    else
+                    {
+                        break;
+                    }
+                case HEADER7:
+                    if(input.hasRemaining())
+                    {
+                        byte c = input.get();
+                        if(c != SASL_HEADER[7])
+                        {
+                            frameParsingError = new TransportException("AMQP SASL header mismatch value %x, expecting %x. In state: %s", c, SASL_HEADER[7], state);
+                            state = State.ERROR;
+                            break;
+                        }
+                        state = State.SIZE_0;
+                    }
+                    else
+                    {
+                        break;
+                    }
                 case SIZE_0:
+                    if(!input.hasRemaining())
+                    {
+                        break;
+                    }
+
                     if(input.remaining() >= 4)
                     {
                         size = input.getInt();
diff --git a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
index 31dd200..7ca6f74 100644
--- a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
+++ b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
@@ -20,9 +20,10 @@
 
 import static org.mockito.Mockito.*;
 import static org.junit.Assert.*;
-import static org.junit.matchers.JUnitMatchers.containsString;
+import static org.hamcrest.CoreMatchers.containsString;
 
 import java.nio.ByteBuffer;
+import java.util.Arrays;
 
 import org.apache.qpid.proton.amqp.Binary;
 import org.apache.qpid.proton.amqp.Symbol;
@@ -39,7 +40,6 @@
 import org.junit.Test;
 
 /**
- * TODO test case where header is malformed
  * TODO test case where input provides frame and half etc
  */
 public class SaslFrameParserTest
@@ -131,6 +131,67 @@
         }
     }
 
+    /*
+     * Test that if the first 8 bytes don't match the AMQP SASL header, it causes an error.
+     */
+    @Test
+    public void testInputOfInvalidHeader() {
+        for (int invalidIndex = 0; invalidIndex < 8; invalidIndex++) {
+            doInputOfInvalidHeaderTestImpl(invalidIndex);
+        }
+    }
+
+    private void doInputOfInvalidHeaderTestImpl(int invalidIndex) {
+        SaslFrameHandler mockSaslFrameHandler = mock(SaslFrameHandler.class);
+        ByteBufferDecoder mockDecoder = mock(ByteBufferDecoder.class);
+
+        SaslFrameParser saslFrameParser = new SaslFrameParser(mockSaslFrameHandler, mockDecoder);
+
+        byte[] header = Arrays.copyOf(AmqpHeader.SASL_HEADER, AmqpHeader.SASL_HEADER.length);
+        header[invalidIndex] = 'X';
+
+        try {
+            saslFrameParser.input(ByteBuffer.wrap(header));
+            fail("expected exception");
+        } catch (TransportException e) {
+            assertThat(e.getMessage(), containsString("AMQP SASL header mismatch"));
+            assertThat(e.getMessage(), containsString("In state: HEADER" + invalidIndex));
+        }
+
+        // Check that further interaction throws TransportException.
+        try {
+            saslFrameParser.input(ByteBuffer.wrap(new byte[0]));
+            fail("expected exception");
+        } catch (TransportException e) {
+            // Expected
+        }
+    }
+
+    /*
+     * Test that if the first 8 bytes, fed in one at a time, don't match the AMQP SASL header, it causes an error.
+     */
+    @Test
+    public void testInputOfValidHeaderInSegments()
+    {
+        sendAmqpSaslHeaderInSegments();
+
+        // Try feeding in an actual frame now to check we get past the header parsing ok
+        when(_mockSaslFrameHandler.isDone()).thenReturn(false);
+
+        _frameParser.input(_saslFrameBytes);
+
+        verify(_mockSaslFrameHandler).handle(isA(SaslInit.class), (Binary)isNull());
+    }
+
+    private void sendAmqpSaslHeaderInSegments()
+    {
+        for (int headerIndex = 0; headerIndex < 8; headerIndex++)
+        {
+            byte headerPart = AmqpHeader.SASL_HEADER[headerIndex];
+            _frameParser.input(ByteBuffer.wrap(new byte[] { headerPart }));
+        }
+    }
+
     private void sendAmqpSaslHeader(SaslFrameParser saslFrameParser)
     {
         saslFrameParser.input(ByteBuffer.wrap(AmqpHeader.SASL_HEADER));