PROTON-1525: enable configuring the initial remote max frame size limit to permit larger sasl frames
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/Transport.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/Transport.java
index 5d8b79d..b4864e5 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/Transport.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/Transport.java
@@ -217,6 +217,15 @@
int getRemoteMaxFrameSize();
/**
+ * Allows overriding the initial remote-max-frame-size to a value greater than the default 512bytes. The value set
+ * will be used until such time as the Open frame arrives from the peer and populates the remote max frame size.
+ *
+ * This method must be called before before {@link #sasl()} in order to influence SASL behaviour.
+ * @param size the remote frame size to use
+ */
+ void setInitialRemoteMaxFrameSize(int size);
+
+ /**
* Gets the local channel-max value to be advertised to the remote peer
*
* @return the local channel-max value
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
index dbf81cb..a6f75d5 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslFrameParser.java
@@ -61,12 +61,13 @@
private ByteBuffer _buffer;
private final ByteBufferDecoder _decoder;
+ private int _frameSizeLimit;
-
- SaslFrameParser(SaslFrameHandler sasl, ByteBufferDecoder decoder)
+ SaslFrameParser(SaslFrameHandler sasl, ByteBufferDecoder decoder, int frameSizeLimit)
{
_sasl = sasl;
_decoder = decoder;
+ _frameSizeLimit = frameSizeLimit;
}
/**
@@ -259,10 +260,10 @@
break;
}
- if (size > 512)
+ if (size > _frameSizeLimit)
{
frameParsingError = new TransportException(
- "specified frame size %d larger than maximum SASL frame size 512", size);
+ "specified frame size %d larger than maximum SASL frame size %d", size, _frameSizeLimit);
state = State.ERROR;
break;
}
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslImpl.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslImpl.java
index ffa49ff..d6f510b 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslImpl.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/SaslImpl.java
@@ -97,7 +97,7 @@
_outputBuffer = newWriteableBuffer(maxFrameSize);
AMQPDefinedTypes.registerAllTypes(_decoder,_encoder);
- _frameParser = new SaslFrameParser(this, _decoder);
+ _frameParser = new SaslFrameParser(this, _decoder, maxFrameSize);
_frameWriter = new FrameWriter(_encoder, maxFrameSize, FrameWriter.SASL_FRAME_TYPE, null, _transport);
}
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/TransportImpl.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/TransportImpl.java
index a7908f2..5441dec 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/TransportImpl.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/TransportImpl.java
@@ -102,7 +102,7 @@
private EncoderImpl _encoder = new EncoderImpl(_decoder);
private int _maxFrameSize = DEFAULT_MAX_FRAME_SIZE;
- private int _remoteMaxFrameSize = 512;
+ private int _remoteMaxFrameSize = MIN_MAX_FRAME_SIZE;
private int _channelMax = CHANNEL_MAX_LIMIT;
private int _remoteChannelMax = CHANNEL_MAX_LIMIT;
@@ -195,6 +195,17 @@
}
@Override
+ public void setInitialRemoteMaxFrameSize(int remoteMaxFrameSize)
+ {
+ if(_init)
+ {
+ throw new IllegalStateException("Cannot set initial remote max frame size after transport has been initialised");
+ }
+
+ _remoteMaxFrameSize = remoteMaxFrameSize;
+ }
+
+ @Override
public void setMaxFrameSize(int maxFrameSize)
{
if(_init)
diff --git a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
index e67177f..e766572 100644
--- a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
+++ b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/SaslFrameParserTest.java
@@ -36,6 +36,7 @@
import org.apache.qpid.proton.codec.DecodeException;
import org.apache.qpid.proton.codec.DecoderImpl;
import org.apache.qpid.proton.codec.EncoderImpl;
+import org.apache.qpid.proton.engine.Transport;
import org.apache.qpid.proton.engine.TransportException;
import org.junit.Test;
@@ -47,7 +48,7 @@
private final SaslFrameHandler _mockSaslFrameHandler = mock(SaslFrameHandler.class);
private final ByteBufferDecoder _mockDecoder = mock(ByteBufferDecoder.class);
private final SaslFrameParser _frameParser;
- private final SaslFrameParser _frameParserWithMockDecoder = new SaslFrameParser(_mockSaslFrameHandler, _mockDecoder);
+ private final SaslFrameParser _frameParserWithMockDecoder = new SaslFrameParser(_mockSaslFrameHandler, _mockDecoder, Transport.MIN_MAX_FRAME_SIZE);
private final AmqpFramer _amqpFramer = new AmqpFramer();
private final SaslInit _saslFrameBody;
@@ -59,7 +60,7 @@
EncoderImpl encoder = new EncoderImpl(decoder);
AMQPDefinedTypes.registerAllTypes(decoder,encoder);
- _frameParser = new SaslFrameParser(_mockSaslFrameHandler, decoder);
+ _frameParser = new SaslFrameParser(_mockSaslFrameHandler, decoder, Transport.MIN_MAX_FRAME_SIZE);
_saslFrameBody = new SaslInit();
_saslFrameBody.setMechanism(Symbol.getSymbol("unused"));
_saslFrameBytes = ByteBuffer.wrap(_amqpFramer.generateSaslFrame(0, new byte[0], _saslFrameBody));
@@ -99,6 +100,28 @@
}
/*
+ * Test that when the frame parser is created with a size limit above the 512 byte min-max-frame-size, frames
+ * arriving with headers indicating they are over this size causes an exception.
+ */
+ @Test
+ public void testInputOfFrameWithInvalidSizeWhenSpecifyingLargeMaxFrameSize()
+ {
+ SaslFrameParser frameParserWithLargeMaxSize = new SaslFrameParser(_mockSaslFrameHandler, _mockDecoder, 2017);
+ sendAmqpSaslHeader(frameParserWithLargeMaxSize);
+
+ // http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-security-v1.0-os.html#doc-idp43536
+ // Description: '2057byte sized' SASL frame header
+ byte[] oversizedSaslFrameHeader = new byte[] { (byte) 0x00, 0x00, 0x08, 0x09, 0x02, 0x01, 0x00, 0x00 };
+
+ try {
+ frameParserWithLargeMaxSize.input(ByteBuffer.wrap(oversizedSaslFrameHeader));
+ fail("expected exception");
+ } catch (TransportException e) {
+ assertThat(e.getMessage(), containsString("frame size 2057 larger than maximum SASL frame size 2017"));
+ }
+ }
+
+ /*
* Test that SASL frames indicating they are under 8 bytes (the minimum size of the frame header) causes an error.
*/
@Test
@@ -226,7 +249,7 @@
SaslFrameHandler mockSaslFrameHandler = mock(SaslFrameHandler.class);
ByteBufferDecoder mockDecoder = mock(ByteBufferDecoder.class);
- SaslFrameParser saslFrameParser = new SaslFrameParser(mockSaslFrameHandler, mockDecoder);
+ SaslFrameParser saslFrameParser = new SaslFrameParser(mockSaslFrameHandler, mockDecoder, Transport.MIN_MAX_FRAME_SIZE);
byte[] header = Arrays.copyOf(AmqpHeader.SASL_HEADER, AmqpHeader.SASL_HEADER.length);
header[invalidIndex] = 'X';
diff --git a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/TransportImplTest.java b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/TransportImplTest.java
index a411cca..f6ca12c 100644
--- a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/TransportImplTest.java
+++ b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/TransportImplTest.java
@@ -71,7 +71,6 @@
public class TransportImplTest
{
- @SuppressWarnings("deprecation")
private TransportImpl _transport = new TransportImpl();
private static final int CHANNEL_ID = 1;
@@ -136,7 +135,6 @@
@Test
public void testEmptyInputWhenRemoteConnectionIsClosedUsingOldApi_isAllowed()
{
- @SuppressWarnings("deprecation")
ConnectionImpl connection = new ConnectionImpl();
_transport.bind(connection);
connection.setRemoteState(EndpointState.CLOSED);
@@ -174,7 +172,6 @@
@Test
public void testBoundTransport_continuesToHandleFrames()
{
- @SuppressWarnings("deprecation")
Connection connection = new ConnectionImpl();
assertTrue(_transport.isHandlingFrames());
@@ -217,7 +214,6 @@
int smallMaxFrameSize = 512;
_transport = new TransportImpl(smallMaxFrameSize);
- @SuppressWarnings("deprecation")
Connection conn = new ConnectionImpl();
_transport.bind(conn);
@@ -1637,4 +1633,38 @@
assertEquals("Unexpected frames written: " + getFrameTypesWritten(transport), 4, transport.writes.size());
assertTrue("Unexpected frame type", transport.writes.get(3) instanceof Detach);
}
+
+ @Test
+ public void testInitialRemoteMaxFrameSizeOverride()
+ {
+ MockTransportImpl transport = new MockTransportImpl();
+ transport.setInitialRemoteMaxFrameSize(768);
+
+ assertEquals("Unexpected value : " + getFrameTypesWritten(transport), 768, transport.getRemoteMaxFrameSize());
+
+ Connection connection = Proton.connection();
+ transport.bind(connection);
+ connection.open();
+ pumpMockTransport(transport);
+
+ assertEquals("Unexpected frames written: " + getFrameTypesWritten(transport), 1, transport.writes.size());
+ assertTrue("Unexpected frame type", transport.writes.get(0) instanceof Open);
+
+ try
+ {
+ transport.setInitialRemoteMaxFrameSize(12345);
+ fail("expected an exception");
+ }
+ catch (IllegalStateException ise )
+ {
+ //expected
+ }
+
+ // Send the necessary responses to open
+ Open open = new Open();
+ open.setMaxFrameSize(UnsignedInteger.valueOf(4567));
+ transport.handleFrame(new TransportFrame(0, open, null));
+
+ assertEquals("Unexpected value : " + getFrameTypesWritten(transport), 4567, transport.getRemoteMaxFrameSize());
+ }
}
diff --git a/proton-j/src/test/java/org/apache/qpid/proton/systemtests/SaslTest.java b/proton-j/src/test/java/org/apache/qpid/proton/systemtests/SaslTest.java
index 93718a0..12f5143 100644
--- a/proton-j/src/test/java/org/apache/qpid/proton/systemtests/SaslTest.java
+++ b/proton-j/src/test/java/org/apache/qpid/proton/systemtests/SaslTest.java
@@ -25,6 +25,7 @@
import static org.junit.Assert.fail;
import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
import java.util.logging.Logger;
import org.junit.Test;
@@ -284,4 +285,105 @@
}
+ /*
+ * Test that transports configured to do so are able to perform SASL process where frames are
+ * exchanged larger than the 512byte min-max-frame-size.
+ */
+ @Test
+ public void testSaslNegotiationWithConfiguredLargerFrameSize() throws Exception
+ {
+ final byte[] largeInitialResponseBytesOrig = fillBytes("initialResponse", 1431);
+ final byte[] largeChallengeBytesOrig = fillBytes("challenge", 1375);
+ final byte[] largeResponseBytesOrig = fillBytes("response", 1282);
+ final byte[] largeAdditionalBytesOrig = fillBytes("additionalData", 1529);
+
+ getClient().transport = Proton.transport();
+ getServer().transport = Proton.transport();
+
+ // Configure transports to allow for larger initial frame sizes
+ getClient().transport.setInitialRemoteMaxFrameSize(2048);
+ getServer().transport.setInitialRemoteMaxFrameSize(2048);
+
+ Sasl clientSasl = getClient().transport.sasl();
+ clientSasl.client();
+
+ Sasl serverSasl = getServer().transport.sasl();
+ serverSasl.server();
+
+ // Negotiate the mech
+ serverSasl.setMechanisms(TESTMECH1, TESTMECH2);
+
+ pumpClientToServer();
+ pumpServerToClient();
+
+ assertArrayEquals("Client should now know the server's mechanisms.", new String[] { TESTMECH1, TESTMECH2 },
+ clientSasl.getRemoteMechanisms());
+ assertEquals("Unexpected SASL outcome at client", SaslOutcome.PN_SASL_NONE, clientSasl.getOutcome());
+
+ // Select a mech, send large initial response along with it in sasl-init, verify server receives it
+ clientSasl.setMechanisms(TESTMECH1);
+ byte[] initialResponseBytes = Arrays.copyOf(largeInitialResponseBytesOrig, largeInitialResponseBytesOrig.length);
+ clientSasl.send(initialResponseBytes, 0, initialResponseBytes.length);
+
+ pumpClientToServer();
+
+ assertArrayEquals("Server should now know the client's chosen mechanism.", new String[] { TESTMECH1 },
+ serverSasl.getRemoteMechanisms());
+
+ byte[] serverReceivedInitialResponseBytes = new byte[serverSasl.pending()];
+ serverSasl.recv(serverReceivedInitialResponseBytes, 0, serverReceivedInitialResponseBytes.length);
+
+ assertArrayEquals("Server should now know the clients initial response", largeInitialResponseBytesOrig,
+ serverReceivedInitialResponseBytes);
+
+ // Send a large challenge in a sasl-challenge, verify client receives it
+ byte[] challengeBytes = Arrays.copyOf(largeChallengeBytesOrig, largeChallengeBytesOrig.length);
+ serverSasl.send(challengeBytes, 0, challengeBytes.length);
+
+ pumpServerToClient();
+
+ byte[] clientReceivedChallengeBytes = new byte[clientSasl.pending()];
+ clientSasl.recv(clientReceivedChallengeBytes, 0, clientReceivedChallengeBytes.length);
+
+ assertEquals("Unexpected SASL outcome at client", SaslOutcome.PN_SASL_NONE, clientSasl.getOutcome());
+ assertArrayEquals("Client should now know the server's challenge", largeChallengeBytesOrig,
+ clientReceivedChallengeBytes);
+
+ // Send a large response in a sasl-response, verify server receives it
+ byte[] responseBytes = Arrays.copyOf(largeResponseBytesOrig, largeResponseBytesOrig.length);
+ clientSasl.send(responseBytes, 0, responseBytes.length);
+
+ pumpClientToServer();
+
+ byte[] serverReceivedResponseBytes = new byte[serverSasl.pending()];
+ serverSasl.recv(serverReceivedResponseBytes, 0, serverReceivedResponseBytes.length);
+
+ assertArrayEquals("Server should now know the client's response", largeResponseBytesOrig, serverReceivedResponseBytes);
+
+ // Send an outcome with large additional data in a sasl-outcome, verify client receives it
+ byte[] additionalBytes = Arrays.copyOf(largeAdditionalBytesOrig, largeAdditionalBytesOrig.length);
+ serverSasl.send(additionalBytes, 0, additionalBytes.length);
+ serverSasl.done(SaslOutcome.PN_SASL_OK);
+ pumpServerToClient();
+
+ assertEquals("Unexpected SASL outcome at client", SaslOutcome.PN_SASL_OK, clientSasl.getOutcome());
+
+ byte[] clientReceivedAdditionalBytes = new byte[clientSasl.pending()];
+ clientSasl.recv(clientReceivedAdditionalBytes, 0, clientReceivedAdditionalBytes.length);
+
+ assertArrayEquals("Client should now know the server's outcome additional data", largeAdditionalBytesOrig,
+ clientReceivedAdditionalBytes);
+ }
+
+ private byte[] fillBytes(String seedString, int length)
+ {
+ byte[] seed = seedString.getBytes(StandardCharsets.UTF_8);
+ byte[] bytes = new byte[length];
+ for (int i = 0; i < length; i++)
+ {
+ bytes[i] = seed[i % seed.length];
+ }
+
+ return bytes;
+ }
}