PROTON-2416: frame writer should reset its position on failure
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/FrameWriter.java b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/FrameWriter.java
index 4c037ac..2ec13c1 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/FrameWriter.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/engine/impl/FrameWriter.java
@@ -99,25 +99,29 @@
void writeFrame(int channel, Object frameBody, ReadableBuffer payload, Runnable onPayloadTooLarge) {
frameStart = frameBuffer.position();
+ try {
+ final int performativeSize = writePerformative(frameBody, payload, onPayloadTooLarge);
+ final int capacity = maxFrameSize > 0 ? maxFrameSize - performativeSize : Integer.MAX_VALUE;
+ final int payloadSize = Math.min(payload == null ? 0 : payload.remaining(), capacity);
- final int performativeSize = writePerformative(frameBody, payload, onPayloadTooLarge);
- final int capacity = maxFrameSize > 0 ? maxFrameSize - performativeSize : Integer.MAX_VALUE;
- final int payloadSize = Math.min(payload == null ? 0 : payload.remaining(), capacity);
+ if (transport.isFrameTracingEnabled()) {
+ logFrame(channel, frameBody, payload, payloadSize);
+ }
- if (transport.isFrameTracingEnabled()) {
- logFrame(channel, frameBody, payload, payloadSize);
+ if (payloadSize > 0) {
+ int oldLimit = payload.limit();
+ payload.limit(payload.position() + payloadSize);
+ frameBuffer.put(payload);
+ payload.limit(oldLimit);
+ }
+
+ endFrame(channel);
+
+ framesOutput++;
+ } catch (Exception e) {
+ frameBuffer.position(frameStart);
+ throw e;
}
-
- if (payloadSize > 0) {
- int oldLimit = payload.limit();
- payload.limit(payload.position() + payloadSize);
- frameBuffer.put(payload);
- payload.limit(oldLimit);
- }
-
- endFrame(channel);
-
- framesOutput++;
}
private int writePerformative(Object frameBody, ReadableBuffer payload, Runnable onPayloadTooLarge) {
diff --git a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/FrameWriterTest.java b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/FrameWriterTest.java
index dd93304..f0ac2f1 100644
--- a/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/FrameWriterTest.java
+++ b/proton-j/src/test/java/org/apache/qpid/proton/engine/impl/FrameWriterTest.java
@@ -22,6 +22,7 @@
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.spy;
@@ -29,14 +30,19 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
import org.apache.qpid.proton.amqp.Binary;
+import org.apache.qpid.proton.amqp.Symbol;
import org.apache.qpid.proton.amqp.UnsignedInteger;
import org.apache.qpid.proton.amqp.security.SaslFrameBody;
import org.apache.qpid.proton.amqp.security.SaslInit;
+import org.apache.qpid.proton.amqp.transport.Open;
import org.apache.qpid.proton.amqp.transport.ReceiverSettleMode;
import org.apache.qpid.proton.amqp.transport.Transfer;
import org.apache.qpid.proton.codec.AMQPDefinedTypes;
@@ -107,6 +113,38 @@
}
@Test
+ public void testFailToWriteFrame() {
+ FrameWriter framer = new FrameWriter(encoder, Integer.MAX_VALUE, (byte) 1, transport);
+
+ final class FailOnUnknownType implements Function<Boolean, Boolean> {
+ @Override
+ public Boolean apply(Boolean t) {
+ throw new IllegalStateException();
+ }
+ };
+
+ Open open = new Open();
+ Map<Symbol, Object> invalidProperties = new HashMap<>();
+ invalidProperties.put(Symbol.valueOf("invalid-unknown-type"), new FailOnUnknownType());
+ open.setProperties(invalidProperties);
+
+ try {
+ framer.writeFrame(0, open, null, null);
+ fail("should have thrown exception");
+ } catch (IllegalArgumentException e) {
+ // Expected
+ assertNotNull(e.getMessage());
+ assertTrue(e.getMessage().contains(FailOnUnknownType.class.getName()));
+ }
+
+ ByteBuffer destBuffer = ByteBuffer.allocate(16);
+ int read = framer.readBytes(destBuffer);
+
+ assertEquals("should not have been any output read", 0, read);
+ assertEquals("should not have been any output in buffer", 0, destBuffer.position());
+ }
+
+ @Test
public void testFrameWrittenToBufferWithLargePayloadAndMaxFrameSizeInvokesHandlerOnce() {
Transfer transfer = createTransfer();
FrameWriter framer = new FrameWriter(encoder, 2048, (byte) 0, transport);