PROTON-1525: enable configuring the initial remote max frame size limit to permit larger sasl frames
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/Transport.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/Transport.java
index 5d8b79d..b4864e5 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/Transport.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/Transport.java
@@ -217,6 +217,15 @@
     int getRemoteMaxFrameSize();
 
     /**
+     * Allows overriding the initial remote-max-frame-size to a value greater than the default 512bytes. The value set
+     * will be used until such time as the Open frame arrives from the peer and populates the remote max frame size.
+     *
+     * This method must be called before before {@link #sasl()} in order to influence SASL behaviour.
+     * @param size the remote frame size to use
+     */
+    void setInitialRemoteMaxFrameSize(int size);
+
+    /**
      * Gets the local channel-max value to be advertised to the remote peer
      *
      * @return the local channel-max value
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
index dbf81cb..a6f75d5 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
@@ -61,12 +61,13 @@
     private ByteBuffer _buffer;
 
     private final ByteBufferDecoder _decoder;
+    private int _frameSizeLimit;
 
-
-    SaslFrameParser(SaslFrameHandler sasl, ByteBufferDecoder decoder)
+    SaslFrameParser(SaslFrameHandler sasl, ByteBufferDecoder decoder, int frameSizeLimit)
     {
         _sasl = sasl;
         _decoder = decoder;
+        _frameSizeLimit = frameSizeLimit;
     }
 
     /**
@@ -259,10 +260,10 @@
                         break;
                     }
 
-                    if (size > 512)
+                    if (size > _frameSizeLimit)
                     {
                         frameParsingError = new TransportException(
-                                "specified frame size %d larger than maximum SASL frame size 512", size);
+                                "specified frame size %d larger than maximum SASL frame size %d", size, _frameSizeLimit);
                         state = State.ERROR;
                         break;
                     }
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslImpl.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslImpl.java
index ffa49ff..d6f510b 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslImpl.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslImpl.java
@@ -97,7 +97,7 @@
         _outputBuffer = newWriteableBuffer(maxFrameSize);
 
         AMQPDefinedTypes.registerAllTypes(_decoder,_encoder);
-        _frameParser = new SaslFrameParser(this, _decoder);
+        _frameParser = new SaslFrameParser(this, _decoder, maxFrameSize);
         _frameWriter = new FrameWriter(_encoder, maxFrameSize, FrameWriter.SASL_FRAME_TYPE, null, _transport);
     }
 
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/TransportImpl.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/TransportImpl.java
index a7908f2..5441dec 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/TransportImpl.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/TransportImpl.java
@@ -102,7 +102,7 @@
     private EncoderImpl _encoder = new EncoderImpl(_decoder);
 
     private int _maxFrameSize = DEFAULT_MAX_FRAME_SIZE;
-    private int _remoteMaxFrameSize = 512;
+    private int _remoteMaxFrameSize = MIN_MAX_FRAME_SIZE;
     private int _channelMax       = CHANNEL_MAX_LIMIT;
     private int _remoteChannelMax = CHANNEL_MAX_LIMIT;
 
@@ -195,6 +195,17 @@
     }
 
     @Override
+    public void setInitialRemoteMaxFrameSize(int remoteMaxFrameSize)
+    {
+        if(_init)
+        {
+            throw new IllegalStateException("Cannot set initial remote max frame size after transport has been initialised");
+        }
+
+        _remoteMaxFrameSize = remoteMaxFrameSize;
+    }
+
+    @Override
     public void setMaxFrameSize(int maxFrameSize)
     {
         if(_init)
diff --git a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
index e67177f..e766572 100644
--- a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
+++ b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
@@ -36,6 +36,7 @@
 import org.apache.qpid.proton.codec.DecodeException;
 import org.apache.qpid.proton.codec.DecoderImpl;
 import org.apache.qpid.proton.codec.EncoderImpl;
+import org.apache.qpid.proton.engine.Transport;
 import org.apache.qpid.proton.engine.TransportException;
 import org.junit.Test;
 
@@ -47,7 +48,7 @@
     private final SaslFrameHandler _mockSaslFrameHandler = mock(SaslFrameHandler.class);
     private final ByteBufferDecoder _mockDecoder = mock(ByteBufferDecoder.class);
     private final SaslFrameParser _frameParser;
-    private final SaslFrameParser _frameParserWithMockDecoder = new SaslFrameParser(_mockSaslFrameHandler, _mockDecoder);
+    private final SaslFrameParser _frameParserWithMockDecoder = new SaslFrameParser(_mockSaslFrameHandler, _mockDecoder, Transport.MIN_MAX_FRAME_SIZE);
     private final AmqpFramer _amqpFramer = new AmqpFramer();
 
     private final SaslInit _saslFrameBody;
@@ -59,7 +60,7 @@
         EncoderImpl encoder = new EncoderImpl(decoder);
         AMQPDefinedTypes.registerAllTypes(decoder,encoder);
 
-        _frameParser = new SaslFrameParser(_mockSaslFrameHandler, decoder);
+        _frameParser = new SaslFrameParser(_mockSaslFrameHandler, decoder, Transport.MIN_MAX_FRAME_SIZE);
         _saslFrameBody = new SaslInit();
         _saslFrameBody.setMechanism(Symbol.getSymbol("unused"));
         _saslFrameBytes = ByteBuffer.wrap(_amqpFramer.generateSaslFrame(0, new byte[0], _saslFrameBody));
@@ -99,6 +100,28 @@
     }
 
     /*
+     * Test that when the frame parser is created with a size limit above the 512 byte min-max-frame-size, frames
+     * arriving with headers indicating they are over this size causes an exception.
+     */
+    @Test
+    public void testInputOfFrameWithInvalidSizeWhenSpecifyingLargeMaxFrameSize()
+    {
+        SaslFrameParser frameParserWithLargeMaxSize = new SaslFrameParser(_mockSaslFrameHandler, _mockDecoder, 2017);
+        sendAmqpSaslHeader(frameParserWithLargeMaxSize);
+
+        // http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-security-v1.0-os.html#doc-idp43536
+        // Description: '2057byte sized' SASL frame header
+        byte[] oversizedSaslFrameHeader = new byte[] { (byte) 0x00, 0x00, 0x08, 0x09, 0x02, 0x01, 0x00, 0x00 };
+
+        try {
+            frameParserWithLargeMaxSize.input(ByteBuffer.wrap(oversizedSaslFrameHeader));
+            fail("expected exception");
+        } catch (TransportException e) {
+            assertThat(e.getMessage(), containsString("frame size 2057 larger than maximum SASL frame size 2017"));
+        }
+    }
+
+    /*
      * Test that SASL frames indicating they are under 8 bytes (the minimum size of the frame header) causes an error.
      */
     @Test
@@ -226,7 +249,7 @@
         SaslFrameHandler mockSaslFrameHandler = mock(SaslFrameHandler.class);
         ByteBufferDecoder mockDecoder = mock(ByteBufferDecoder.class);
 
-        SaslFrameParser saslFrameParser = new SaslFrameParser(mockSaslFrameHandler, mockDecoder);
+        SaslFrameParser saslFrameParser = new SaslFrameParser(mockSaslFrameHandler, mockDecoder, Transport.MIN_MAX_FRAME_SIZE);
 
         byte[] header = Arrays.copyOf(AmqpHeader.SASL_HEADER, AmqpHeader.SASL_HEADER.length);
         header[invalidIndex] = 'X';
diff --git a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/TransportImplTest.java b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/TransportImplTest.java
index a411cca..f6ca12c 100644
--- a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/TransportImplTest.java
+++ b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/TransportImplTest.java
@@ -71,7 +71,6 @@
 
 public class TransportImplTest
 {
-    @SuppressWarnings("deprecation")
     private TransportImpl _transport = new TransportImpl();
 
     private static final int CHANNEL_ID = 1;
@@ -136,7 +135,6 @@
     @Test
     public void testEmptyInputWhenRemoteConnectionIsClosedUsingOldApi_isAllowed()
     {
-        @SuppressWarnings("deprecation")
         ConnectionImpl connection = new ConnectionImpl();
         _transport.bind(connection);
         connection.setRemoteState(EndpointState.CLOSED);
@@ -174,7 +172,6 @@
     @Test
     public void testBoundTransport_continuesToHandleFrames()
     {
-        @SuppressWarnings("deprecation")
         Connection connection = new ConnectionImpl();
 
         assertTrue(_transport.isHandlingFrames());
@@ -217,7 +214,6 @@
         int smallMaxFrameSize = 512;
         _transport = new TransportImpl(smallMaxFrameSize);
 
-        @SuppressWarnings("deprecation")
         Connection conn = new ConnectionImpl();
         _transport.bind(conn);
 
@@ -1637,4 +1633,38 @@
         assertEquals("Unexpected frames written: " + getFrameTypesWritten(transport), 4, transport.writes.size());
         assertTrue("Unexpected frame type", transport.writes.get(3) instanceof Detach);
     }
+
+    @Test
+    public void testInitialRemoteMaxFrameSizeOverride()
+    {
+        MockTransportImpl transport = new MockTransportImpl();
+        transport.setInitialRemoteMaxFrameSize(768);
+
+        assertEquals("Unexpected value : " + getFrameTypesWritten(transport), 768, transport.getRemoteMaxFrameSize());
+
+        Connection connection = Proton.connection();
+        transport.bind(connection);
+        connection.open();
+        pumpMockTransport(transport);
+
+        assertEquals("Unexpected frames written: " + getFrameTypesWritten(transport), 1, transport.writes.size());
+        assertTrue("Unexpected frame type", transport.writes.get(0) instanceof Open);
+
+        try
+        {
+            transport.setInitialRemoteMaxFrameSize(12345);
+            fail("expected an exception");
+        }
+        catch (IllegalStateException ise )
+        {
+            //expected
+        }
+
+        // Send the necessary responses to open
+        Open open = new Open();
+        open.setMaxFrameSize(UnsignedInteger.valueOf(4567));
+        transport.handleFrame(new TransportFrame(0, open, null));
+
+        assertEquals("Unexpected value : " + getFrameTypesWritten(transport), 4567, transport.getRemoteMaxFrameSize());
+    }
 }
diff --git a/proton-j/src/test/java/org/apache/qpid/proton/systemtests/SaslTest.java b/proton-j/src/test/java/org/apache/qpid/proton/systemtests/SaslTest.java
index 93718a0..12f5143 100644
--- a/proton-j/src/test/java/org/apache/qpid/proton/systemtests/SaslTest.java
+++ b/proton-j/src/test/java/org/apache/qpid/proton/systemtests/SaslTest.java
@@ -25,6 +25,7 @@
 import static org.junit.Assert.fail;
 
 import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
 import java.util.logging.Logger;
 
 import org.junit.Test;
@@ -284,4 +285,105 @@
 
     }
 
+    /*
+     * Test that transports configured to do so are able to perform SASL process where frames are
+     * exchanged larger than the 512byte min-max-frame-size.
+     */
+    @Test
+    public void testSaslNegotiationWithConfiguredLargerFrameSize() throws Exception
+    {
+        final byte[] largeInitialResponseBytesOrig = fillBytes("initialResponse", 1431);
+        final byte[] largeChallengeBytesOrig = fillBytes("challenge", 1375);
+        final byte[] largeResponseBytesOrig = fillBytes("response", 1282);
+        final byte[] largeAdditionalBytesOrig = fillBytes("additionalData", 1529);
+
+        getClient().transport = Proton.transport();
+        getServer().transport = Proton.transport();
+
+        // Configure transports to allow for larger initial frame sizes
+        getClient().transport.setInitialRemoteMaxFrameSize(2048);
+        getServer().transport.setInitialRemoteMaxFrameSize(2048);
+
+        Sasl clientSasl = getClient().transport.sasl();
+        clientSasl.client();
+
+        Sasl serverSasl = getServer().transport.sasl();
+        serverSasl.server();
+
+        // Negotiate the mech
+        serverSasl.setMechanisms(TESTMECH1, TESTMECH2);
+
+        pumpClientToServer();
+        pumpServerToClient();
+
+        assertArrayEquals("Client should now know the server's mechanisms.", new String[] { TESTMECH1, TESTMECH2 },
+                clientSasl.getRemoteMechanisms());
+        assertEquals("Unexpected SASL outcome at client", SaslOutcome.PN_SASL_NONE, clientSasl.getOutcome());
+
+        // Select a mech, send large initial response along with it in sasl-init, verify server receives it
+        clientSasl.setMechanisms(TESTMECH1);
+        byte[] initialResponseBytes = Arrays.copyOf(largeInitialResponseBytesOrig, largeInitialResponseBytesOrig.length);
+        clientSasl.send(initialResponseBytes, 0, initialResponseBytes.length);
+
+        pumpClientToServer();
+
+        assertArrayEquals("Server should now know the client's chosen mechanism.", new String[] { TESTMECH1 },
+                serverSasl.getRemoteMechanisms());
+
+        byte[] serverReceivedInitialResponseBytes = new byte[serverSasl.pending()];
+        serverSasl.recv(serverReceivedInitialResponseBytes, 0, serverReceivedInitialResponseBytes.length);
+
+        assertArrayEquals("Server should now know the clients initial response", largeInitialResponseBytesOrig,
+                serverReceivedInitialResponseBytes);
+
+        // Send a large challenge in a sasl-challenge, verify client receives it
+        byte[] challengeBytes = Arrays.copyOf(largeChallengeBytesOrig, largeChallengeBytesOrig.length);
+        serverSasl.send(challengeBytes, 0, challengeBytes.length);
+
+        pumpServerToClient();
+
+        byte[] clientReceivedChallengeBytes = new byte[clientSasl.pending()];
+        clientSasl.recv(clientReceivedChallengeBytes, 0, clientReceivedChallengeBytes.length);
+
+        assertEquals("Unexpected SASL outcome at client", SaslOutcome.PN_SASL_NONE, clientSasl.getOutcome());
+        assertArrayEquals("Client should now know the server's challenge", largeChallengeBytesOrig,
+                clientReceivedChallengeBytes);
+
+        // Send a large response in a sasl-response, verify server receives it
+        byte[] responseBytes = Arrays.copyOf(largeResponseBytesOrig, largeResponseBytesOrig.length);
+        clientSasl.send(responseBytes, 0, responseBytes.length);
+
+        pumpClientToServer();
+
+        byte[] serverReceivedResponseBytes = new byte[serverSasl.pending()];
+        serverSasl.recv(serverReceivedResponseBytes, 0, serverReceivedResponseBytes.length);
+
+        assertArrayEquals("Server should now know the client's response", largeResponseBytesOrig, serverReceivedResponseBytes);
+
+        // Send an outcome with large additional data in a sasl-outcome, verify client receives it
+        byte[] additionalBytes = Arrays.copyOf(largeAdditionalBytesOrig, largeAdditionalBytesOrig.length);
+        serverSasl.send(additionalBytes, 0, additionalBytes.length);
+        serverSasl.done(SaslOutcome.PN_SASL_OK);
+        pumpServerToClient();
+
+        assertEquals("Unexpected SASL outcome at client", SaslOutcome.PN_SASL_OK, clientSasl.getOutcome());
+
+        byte[] clientReceivedAdditionalBytes = new byte[clientSasl.pending()];
+        clientSasl.recv(clientReceivedAdditionalBytes, 0, clientReceivedAdditionalBytes.length);
+
+        assertArrayEquals("Client should now know the server's outcome additional data", largeAdditionalBytesOrig,
+                clientReceivedAdditionalBytes);
+    }
+
+    private byte[] fillBytes(String seedString, int length)
+    {
+        byte[] seed = seedString.getBytes(StandardCharsets.UTF_8);
+        byte[] bytes = new byte[length];
+        for (int i = 0; i < length; i++)
+        {
+            bytes[i] = seed[i % seed.length];
+        }
+
+        return bytes;
+    }
 }