JAVA-2772: Support new protocol v5 message format
https://issues.apache.org/jira/browse/CASSANDRA-15299
diff --git a/changelog/README.md b/changelog/README.md
index 76838b7..5030a10 100644
--- a/changelog/README.md
+++ b/changelog/README.md
@@ -8,6 +8,7 @@
## 3.10.0 (in progress)
- [improvement] JAVA-2676: Don't reschedule flusher after empty runs
+- [new feature] JAVA-2772: Support new protocol v5 message format
## 3.9.0
diff --git a/driver-core/src/main/java/com/datastax/driver/core/BytesToSegmentDecoder.java b/driver-core/src/main/java/com/datastax/driver/core/BytesToSegmentDecoder.java
new file mode 100644
index 0000000..58eda4f
--- /dev/null
+++ b/driver-core/src/main/java/com/datastax/driver/core/BytesToSegmentDecoder.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+import java.nio.ByteOrder;
+
+/**
+ * Decodes {@link Segment}s from a stream of bytes.
+ *
+ * <p>This works like a regular length-field-based decoder, but we override {@link
+ * #getUnadjustedFrameLength} to handle two peculiarities: the length is encoded on 17 bits, and we
+ * also want to check the header CRC before we use it. So we parse the whole segment header ahead of
+ * time, and store it until we're ready to build the segment.
+ */
+class BytesToSegmentDecoder extends LengthFieldBasedFrameDecoder {
+
+ private final SegmentCodec segmentCodec;
+ private SegmentCodec.Header header;
+
+ BytesToSegmentDecoder(SegmentCodec segmentCodec) {
+ super(
+ // max length (Netty wants this to be the overall length including everything):
+ segmentCodec.headerLength()
+ + SegmentCodec.CRC24_LENGTH
+ + Segment.MAX_PAYLOAD_LENGTH
+ + SegmentCodec.CRC32_LENGTH,
+ // offset and size of the "length" field: that's the whole header
+ 0,
+ segmentCodec.headerLength() + SegmentCodec.CRC24_LENGTH,
+ // length adjustment: add the trailing CRC to the declared length
+ SegmentCodec.CRC32_LENGTH,
+ // bytes to skip: the header (we've already parsed it while reading the length)
+ segmentCodec.headerLength() + SegmentCodec.CRC24_LENGTH);
+ this.segmentCodec = segmentCodec;
+ }
+
+ @Override
+ protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
+ try {
+ ByteBuf payloadAndCrc = (ByteBuf) super.decode(ctx, in);
+ if (payloadAndCrc == null) {
+ return null;
+ } else {
+ assert header != null;
+ Segment segment = segmentCodec.decode(header, payloadAndCrc);
+ header = null;
+ return segment;
+ }
+ } catch (Exception e) {
+ // Don't hold on to a stale header if we failed to decode the rest of the segment
+ header = null;
+ throw e;
+ }
+ }
+
+ @Override
+ protected long getUnadjustedFrameLength(ByteBuf buffer, int offset, int length, ByteOrder order) {
+ // The parent class calls this repeatedly for the same "frame" if there weren't enough
+ // accumulated bytes the first time. Only decode the header the first time:
+ if (header == null) {
+ header = segmentCodec.decodeHeader(buffer.slice(offset, length));
+ }
+ return header.payloadLength;
+ }
+}
diff --git a/driver-core/src/main/java/com/datastax/driver/core/Connection.java b/driver-core/src/main/java/com/datastax/driver/core/Connection.java
index 6718273..974e2bd 100644
--- a/driver-core/src/main/java/com/datastax/driver/core/Connection.java
+++ b/driver-core/src/main/java/com/datastax/driver/core/Connection.java
@@ -22,6 +22,7 @@
import com.datastax.driver.core.exceptions.AuthenticationException;
import com.datastax.driver.core.exceptions.BusyConnectionException;
import com.datastax.driver.core.exceptions.ConnectionException;
+import com.datastax.driver.core.exceptions.CrcMismatchException;
import com.datastax.driver.core.exceptions.DriverException;
import com.datastax.driver.core.exceptions.DriverInternalError;
import com.datastax.driver.core.exceptions.FrameTooLongException;
@@ -345,6 +346,11 @@
return new AsyncFunction<Message.Response, Void>() {
@Override
public ListenableFuture<Void> apply(Message.Response response) throws Exception {
+
+ if (protocolVersion.compareTo(ProtocolVersion.V5) >= 0 && response.type != ERROR) {
+ switchToV5Framing();
+ }
+
switch (response.type) {
case READY:
return checkClusterName(protocolVersion, initExecutor);
@@ -1325,7 +1331,7 @@
// Special case, if we encountered a FrameTooLongException, raise exception on handler and
// don't defunct it since
// the connection is in an ok state.
- if (error != null && error instanceof FrameTooLongException) {
+ if (error instanceof FrameTooLongException) {
FrameTooLongException ftle = (FrameTooLongException) error;
int streamId = ftle.getStreamId();
ResponseHandler handler = pending.remove(streamId);
@@ -1344,6 +1350,9 @@
handler.callback.onException(
Connection.this, ftle, System.nanoTime() - handler.startTime, handler.retryCount);
return;
+ } else if (error instanceof CrcMismatchException) {
+ // Fall back to the defunct call below, but we want a clear warning in the logs
+ logger.warn("CRC mismatch while decoding a response, dropping the connection", error);
}
}
defunct(
@@ -1711,7 +1720,11 @@
pipeline.addLast("frameDecoder", new Frame.Decoder());
pipeline.addLast("frameEncoder", frameEncoder);
- if (compressor != null) {
+ if (compressor != null
+ // Frame-level compression is only done in legacy protocol versions. In V5 and above, it
+ // happens at a higher level ("segment" that groups multiple frames), so never install
+ // those handlers.
+ && protocolVersion.compareTo(ProtocolVersion.V5) < 0) {
pipeline.addLast("frameDecompressor", new Frame.Decompressor(compressor));
pipeline.addLast("frameCompressor", new Frame.Compressor(compressor));
}
@@ -1744,6 +1757,39 @@
}
}
+ /**
+ * Rearranges the pipeline to deal with the new framing structure in protocol v5 and above. This
+ * has to be done manually, because it only happens once we've confirmed that the server supports
+ * v5.
+ */
+ void switchToV5Framing() {
+ assert factory.protocolVersion.compareTo(ProtocolVersion.V5) >= 0;
+
+ // We want to do this on the event loop, to make sure it doesn't race with incoming requests
+ assert channel.eventLoop().inEventLoop();
+
+ ChannelPipeline pipeline = channel.pipeline();
+ SegmentCodec segmentCodec =
+ new SegmentCodec(
+ channel.alloc(), factory.configuration.getProtocolOptions().getCompression());
+
+ // Outbound: "message -> segment -> bytes" instead of "message -> frame -> bytes"
+ Message.ProtocolEncoder requestEncoder =
+ (Message.ProtocolEncoder) pipeline.get("messageEncoder");
+ pipeline.replace(
+ "messageEncoder",
+ "messageToSegmentEncoder",
+ new MessageToSegmentEncoder(channel.alloc(), requestEncoder));
+ pipeline.replace(
+ "frameEncoder", "segmentToBytesEncoder", new SegmentToBytesEncoder(segmentCodec));
+
+ // Inbound: "frame <- segment <- bytes" instead of "frame <- bytes"
+ pipeline.replace(
+ "frameDecoder", "bytesToSegmentDecoder", new BytesToSegmentDecoder(segmentCodec));
+ pipeline.addAfter(
+ "bytesToSegmentDecoder", "segmentToFrameDecoder", new SegmentToFrameDecoder());
+ }
+
/** A component that "owns" a connection, and should be notified when it dies. */
interface Owner {
void onConnectionDefunct(Connection connection);
diff --git a/driver-core/src/main/java/com/datastax/driver/core/Crc.java b/driver-core/src/main/java/com/datastax/driver/core/Crc.java
new file mode 100644
index 0000000..2abcfbf
--- /dev/null
+++ b/driver-core/src/main/java/com/datastax/driver/core/Crc.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.util.concurrent.FastThreadLocal;
+import java.nio.ByteBuffer;
+import java.util.zip.CRC32;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Copied and adapted from the server-side version. */
+class Crc {
+
+ private static final Logger logger = LoggerFactory.getLogger(Crc.class);
+
+ private static final FastThreadLocal<CRC32> crc32 =
+ new FastThreadLocal<CRC32>() {
+ @Override
+ protected CRC32 initialValue() {
+ return new CRC32();
+ }
+ };
+
+ private static final byte[] initialBytes =
+ new byte[] {(byte) 0xFA, (byte) 0x2D, (byte) 0x55, (byte) 0xCA};
+
+ private static final CrcUpdater CRC_UPDATER = selectCrcUpdater();
+
+ static int computeCrc32(ByteBuf buffer) {
+ CRC32 crc = newCrc32();
+ CRC_UPDATER.update(crc, buffer);
+ return (int) crc.getValue();
+ }
+
+ private static CRC32 newCrc32() {
+ CRC32 crc = crc32.get();
+ crc.reset();
+ crc.update(initialBytes);
+ return crc;
+ }
+
+ private static final int CRC24_INIT = 0x875060;
+ /**
+ * Polynomial chosen from https://users.ece.cmu.edu/~koopman/crc/index.html, by Philip Koopman
+ *
+ * <p>This webpage claims a copyright to Philip Koopman, which he licenses under the Creative
+ * Commons Attribution 4.0 International License (https://creativecommons.org/licenses/by/4.0)
+ *
+ * <p>It is unclear if this copyright can extend to a 'fact' such as this specific number,
+ * particularly as we do not use Koopman's notation to represent the polynomial, but we anyway
+ * attribute his work and link the terms of his license since they are not incompatible with our
+ * usage and we greatly appreciate his work.
+ *
+ * <p>This polynomial provides hamming distance of 8 for messages up to length 105 bits; we only
+ * support 8-64 bits at present, with an expected range of 40-48.
+ */
+ private static final int CRC24_POLY = 0x1974F0B;
+
+ /**
+ * NOTE: the order of bytes must reach the wire in the same order the CRC is computed, with the
+ * CRC immediately following in a trailer. Since we read in least significant byte order, if you
+ * write to a buffer using putInt or putLong, the byte order will be reversed and you will lose
+ * the guarantee of protection from burst corruptions of 24 bits in length.
+ *
+ * <p>Make sure either to write byte-by-byte to the wire, or to use Integer/Long.reverseBytes if
+ * you write to a BIG_ENDIAN buffer.
+ *
+ * <p>See http://users.ece.cmu.edu/~koopman/pubs/ray06_crcalgorithms.pdf
+ *
+ * <p>Complain to the ethernet spec writers, for having inverse bit to byte significance order.
+ *
+ * <p>Note we use the most naive algorithm here. We support at most 8 bytes, and typically supply
+ * 5 or fewer, so any efficiency of a table approach is swallowed by the time to hit L3, even for
+ * a tiny (4bit) table.
+ *
+ * @param bytes an up to 8-byte register containing bytes to compute the CRC over the bytes AND
+ * bits will be read least-significant to most significant.
+ * @param len the number of bytes, greater than 0 and fewer than 9, to be read from bytes
+ * @return the least-significant bit AND byte order crc24 using the CRC24_POLY polynomial
+ */
+ static int computeCrc24(long bytes, int len) {
+ int crc = CRC24_INIT;
+ while (len-- > 0) {
+ crc ^= (bytes & 0xff) << 16;
+ bytes >>= 8;
+
+ for (int i = 0; i < 8; i++) {
+ crc <<= 1;
+ if ((crc & 0x1000000) != 0) crc ^= CRC24_POLY;
+ }
+ }
+ return crc;
+ }
+
+ private static CrcUpdater selectCrcUpdater() {
+ try {
+ CRC32.class.getDeclaredMethod("update", ByteBuffer.class);
+ return new Java8CrcUpdater();
+ } catch (Exception e) {
+ logger.warn(
+ "It looks like you are running Java 7 or below. "
+ + "CRC checks (used in protocol {} and above) will require a memory copy, which can "
+ + "negatively impact performance. Consider using a more modern VM.",
+ ProtocolVersion.V5,
+ e);
+ return new Java6CrcUpdater();
+ }
+ }
+
+ private interface CrcUpdater {
+ void update(CRC32 crc, ByteBuf buffer);
+ }
+
+ private static class Java6CrcUpdater implements CrcUpdater {
+ @Override
+ public void update(CRC32 crc, ByteBuf buffer) {
+ if (buffer.hasArray()) {
+ crc.update(buffer.array(), buffer.arrayOffset(), buffer.readableBytes());
+ } else {
+ byte[] bytes = new byte[buffer.readableBytes()];
+ buffer.getBytes(buffer.readerIndex(), bytes);
+ crc.update(bytes);
+ }
+ }
+ }
+
+ @IgnoreJDK6Requirement
+ private static class Java8CrcUpdater implements CrcUpdater {
+ @Override
+ public void update(CRC32 crc, ByteBuf buffer) {
+ crc.update(buffer.internalNioBuffer(buffer.readerIndex(), buffer.readableBytes()));
+ }
+ }
+}
diff --git a/driver-core/src/main/java/com/datastax/driver/core/Frame.java b/driver-core/src/main/java/com/datastax/driver/core/Frame.java
index 36d6ee3..efe5949 100644
--- a/driver-core/src/main/java/com/datastax/driver/core/Frame.java
+++ b/driver-core/src/main/java/com/datastax/driver/core/Frame.java
@@ -87,29 +87,14 @@
final Header header;
final ByteBuf body;
- private Frame(Header header, ByteBuf body) {
+ Frame(Header header, ByteBuf body) {
this.header = header;
this.body = body;
}
private static Frame create(ByteBuf fullFrame) {
- assert fullFrame.readableBytes() >= 1
- : String.format("Frame too short (%d bytes)", fullFrame.readableBytes());
-
- int versionBytes = fullFrame.readByte();
- // version first byte is the "direction" of the frame (request or response)
- ProtocolVersion version = ProtocolVersion.fromInt(versionBytes & 0x7F);
- int hdrLen = Header.lengthFor(version);
- assert fullFrame.readableBytes() >= (hdrLen - 1)
- : String.format("Frame too short (%d bytes)", fullFrame.readableBytes());
-
- int flags = fullFrame.readByte();
- int streamId = readStreamid(fullFrame, version);
- int opcode = fullFrame.readByte();
- int length = fullFrame.readInt();
- assert length == fullFrame.readableBytes();
-
- Header header = new Header(version, flags, streamId, opcode);
+ Header header = Header.decode(fullFrame);
+ assert header.bodyLength == fullFrame.readableBytes();
return new Frame(header, fullFrame);
}
@@ -129,7 +114,7 @@
static Frame create(
ProtocolVersion version, int opcode, int streamId, EnumSet<Header.Flag> flags, ByteBuf body) {
- Header header = new Header(version, flags, streamId, opcode);
+ Header header = new Header(version, flags, streamId, opcode, body.readableBytes());
return new Frame(header, body);
}
@@ -139,16 +124,22 @@
final EnumSet<Flag> flags;
final int streamId;
final int opcode;
+ final int bodyLength;
- private Header(ProtocolVersion version, int flags, int streamId, int opcode) {
- this(version, Flag.deserialize(flags), streamId, opcode);
+ private Header(ProtocolVersion version, int flags, int streamId, int opcode, int bodyLength) {
+ this(version, Flag.deserialize(flags), streamId, opcode, bodyLength);
}
- private Header(ProtocolVersion version, EnumSet<Flag> flags, int streamId, int opcode) {
+ Header(ProtocolVersion version, EnumSet<Flag> flags, int streamId, int opcode, int bodyLength) {
this.version = version;
this.flags = flags;
this.streamId = streamId;
this.opcode = opcode;
+ this.bodyLength = bodyLength;
+ }
+
+ Header withNewBodyLength(int newBodyLength) {
+ return new Header(version, flags, streamId, opcode, newBodyLength);
}
/**
@@ -171,6 +162,46 @@
}
}
+ public void encodeInto(ByteBuf destination) {
+ // Don't bother with the direction, we only send requests.
+ destination.writeByte(version.toInt());
+ destination.writeByte(Flag.serialize(flags));
+ switch (version) {
+ case V1:
+ case V2:
+ destination.writeByte(streamId);
+ break;
+ case V3:
+ case V4:
+ case V5:
+ destination.writeShort(streamId);
+ break;
+ default:
+ throw version.unsupported();
+ }
+ destination.writeByte(opcode);
+ destination.writeInt(bodyLength);
+ }
+
+ static Header decode(ByteBuf buffer) {
+ assert buffer.readableBytes() >= 1
+ : String.format("Frame too short (%d bytes)", buffer.readableBytes());
+
+ int versionBytes = buffer.readByte();
+ // version first byte is the "direction" of the frame (request or response)
+ ProtocolVersion version = ProtocolVersion.fromInt(versionBytes & 0x7F);
+ int hdrLen = Header.lengthFor(version);
+ assert buffer.readableBytes() >= (hdrLen - 1)
+ : String.format("Frame too short (%d bytes)", buffer.readableBytes());
+
+ int flags = buffer.readByte();
+ int streamId = readStreamid(buffer, version);
+ int opcode = buffer.readByte();
+ int length = buffer.readInt();
+
+ return new Header(version, flags, streamId, opcode, length);
+ }
+
enum Flag {
// The order of that enum matters!!
COMPRESSED,
@@ -197,7 +228,7 @@
}
Frame with(ByteBuf newBody) {
- return new Frame(header, newBody);
+ return new Frame(header.withNewBodyLength(newBody.readableBytes()), newBody);
}
static final class Decoder extends ByteToMessageDecoder {
@@ -273,32 +304,11 @@
throws Exception {
ProtocolVersion protocolVersion = frame.header.version;
ByteBuf header = ctx.alloc().ioBuffer(Frame.Header.lengthFor(protocolVersion));
- // We don't bother with the direction, we only send requests.
- header.writeByte(frame.header.version.toInt());
- header.writeByte(Header.Flag.serialize(frame.header.flags));
- writeStreamId(frame.header.streamId, header, protocolVersion);
- header.writeByte(frame.header.opcode);
- header.writeInt(frame.body.readableBytes());
+ frame.header.encodeInto(header);
out.add(header);
out.add(frame.body);
}
-
- private void writeStreamId(int streamId, ByteBuf header, ProtocolVersion protocolVersion) {
- switch (protocolVersion) {
- case V1:
- case V2:
- header.writeByte(streamId);
- break;
- case V3:
- case V4:
- case V5:
- header.writeShort(streamId);
- break;
- default:
- throw protocolVersion.unsupported();
- }
- }
}
static class Decompressor extends MessageToMessageDecoder<Frame> {
diff --git a/driver-core/src/main/java/com/datastax/driver/core/FrameCompressor.java b/driver-core/src/main/java/com/datastax/driver/core/FrameCompressor.java
index d38e230..a7d6daf 100644
--- a/driver-core/src/main/java/com/datastax/driver/core/FrameCompressor.java
+++ b/driver-core/src/main/java/com/datastax/driver/core/FrameCompressor.java
@@ -23,8 +23,17 @@
abstract Frame compress(Frame frame) throws IOException;
+ /**
+ * Unlike {@link #compress(Frame)}, this variant does not store the uncompressed length if the
+ * underlying algorithm does not do it natively (like LZ4). It must be stored separately and
+ * passed back to {@link #decompress(ByteBuf, int)}.
+ */
+ abstract ByteBuf compress(ByteBuf buffer) throws IOException;
+
abstract Frame decompress(Frame frame) throws IOException;
+ abstract ByteBuf decompress(ByteBuf buffer, int uncompressedLength) throws IOException;
+
protected static ByteBuffer inputNioBuffer(ByteBuf buf) {
// Using internalNioBuffer(...) as we only hold the reference in this method and so can
// reduce Object allocations.
diff --git a/driver-core/src/main/java/com/datastax/driver/core/LZ4Compressor.java b/driver-core/src/main/java/com/datastax/driver/core/LZ4Compressor.java
index c72208f..a20f400 100644
--- a/driver-core/src/main/java/com/datastax/driver/core/LZ4Compressor.java
+++ b/driver-core/src/main/java/com/datastax/driver/core/LZ4Compressor.java
@@ -59,22 +59,41 @@
@Override
Frame compress(Frame frame) throws IOException {
ByteBuf input = frame.body;
- ByteBuf frameBody = input.isDirect() ? compressDirect(input) : compressHeap(input);
+ ByteBuf frameBody = compress(input, true);
return frame.with(frameBody);
}
- private ByteBuf compressDirect(ByteBuf input) throws IOException {
+ @Override
+ ByteBuf compress(ByteBuf buffer) throws IOException {
+ return compress(buffer, false);
+ }
+
+ private ByteBuf compress(ByteBuf buffer, boolean prependWithUncompressedLength)
+ throws IOException {
+ return buffer.isDirect()
+ ? compressDirect(buffer, prependWithUncompressedLength)
+ : compressHeap(buffer, prependWithUncompressedLength);
+ }
+
+ private ByteBuf compressDirect(ByteBuf input, boolean prependWithUncompressedLength)
+ throws IOException {
int maxCompressedLength = compressor.maxCompressedLength(input.readableBytes());
// If the input is direct we will allocate a direct output buffer as well as this will allow us
// to use
// LZ4Compressor.compress and so eliminate memory copies.
- ByteBuf output = input.alloc().directBuffer(INTEGER_BYTES + maxCompressedLength);
+ ByteBuf output =
+ input
+ .alloc()
+ .directBuffer(
+ (prependWithUncompressedLength ? INTEGER_BYTES : 0) + maxCompressedLength);
try {
ByteBuffer in = inputNioBuffer(input);
// Increase reader index.
input.readerIndex(input.writerIndex());
- output.writeInt(in.remaining());
+ if (prependWithUncompressedLength) {
+ output.writeInt(in.remaining());
+ }
ByteBuffer out = outputNioBuffer(output);
int written =
@@ -90,7 +109,8 @@
return output;
}
- private ByteBuf compressHeap(ByteBuf input) throws IOException {
+ private ByteBuf compressHeap(ByteBuf input, boolean prependWithUncompressedLength)
+ throws IOException {
int maxCompressedLength = compressor.maxCompressedLength(input.readableBytes());
// Not a direct buffer so use byte arrays...
@@ -103,9 +123,14 @@
// Allocate a heap buffer from the ByteBufAllocator as we may use a PooledByteBufAllocator and
// so
// can eliminate the overhead of allocate a new byte[].
- ByteBuf output = input.alloc().heapBuffer(INTEGER_BYTES + maxCompressedLength);
+ ByteBuf output =
+ input
+ .alloc()
+ .heapBuffer((prependWithUncompressedLength ? INTEGER_BYTES : 0) + maxCompressedLength);
try {
- output.writeInt(len);
+ if (prependWithUncompressedLength) {
+ output.writeInt(len);
+ }
// calculate the correct offset.
int offset = output.arrayOffset() + output.writerIndex();
byte[] out = output.array();
@@ -124,16 +149,23 @@
@Override
Frame decompress(Frame frame) throws IOException {
ByteBuf input = frame.body;
- ByteBuf frameBody = input.isDirect() ? decompressDirect(input) : decompressHeap(input);
+ int uncompressedLength = input.readInt();
+ ByteBuf frameBody = decompress(input, uncompressedLength);
return frame.with(frameBody);
}
- private ByteBuf decompressDirect(ByteBuf input) throws IOException {
+ @Override
+ ByteBuf decompress(ByteBuf buffer, int uncompressedLength) throws IOException {
+ return buffer.isDirect()
+ ? decompressDirect(buffer, uncompressedLength)
+ : decompressHeap(buffer, uncompressedLength);
+ }
+
+ private ByteBuf decompressDirect(ByteBuf input, int uncompressedLength) throws IOException {
// If the input is direct we will allocate a direct output buffer as well as this will allow us
// to use
// LZ4Compressor.decompress and so eliminate memory copies.
int readable = input.readableBytes();
- int uncompressedLength = input.readInt();
ByteBuffer in = inputNioBuffer(input);
// Increase reader index.
input.readerIndex(input.writerIndex());
@@ -141,7 +173,7 @@
try {
ByteBuffer out = outputNioBuffer(output);
int read = decompressor.decompress(in, in.position(), out, out.position(), out.remaining());
- if (read != readable - INTEGER_BYTES) throw new IOException("Compressed lengths mismatch");
+ if (read != readable) throw new IOException("Compressed lengths mismatch");
// Set the writer index so the amount of written bytes is reflected
output.writerIndex(output.writerIndex() + uncompressedLength);
@@ -153,11 +185,10 @@
return output;
}
- private ByteBuf decompressHeap(ByteBuf input) throws IOException {
+ private ByteBuf decompressHeap(ByteBuf input, int uncompressedLength) throws IOException {
// Not a direct buffer so use byte arrays...
byte[] in = input.array();
int len = input.readableBytes();
- int uncompressedLength = input.readInt();
int inOffset = input.arrayOffset() + input.readerIndex();
// Increase reader index.
input.readerIndex(input.writerIndex());
@@ -170,7 +201,7 @@
int offset = output.arrayOffset() + output.writerIndex();
byte out[] = output.array();
int read = decompressor.decompress(in, inOffset, out, offset, uncompressedLength);
- if (read != len - INTEGER_BYTES) throw new IOException("Compressed lengths mismatch");
+ if (read != len) throw new IOException("Compressed lengths mismatch");
// Set the writer index so the amount of written bytes is reflected
output.writerIndex(output.writerIndex() + uncompressedLength);
diff --git a/driver-core/src/main/java/com/datastax/driver/core/Message.java b/driver-core/src/main/java/com/datastax/driver/core/Message.java
index 74e3adb..205c333 100644
--- a/driver-core/src/main/java/com/datastax/driver/core/Message.java
+++ b/driver-core/src/main/java/com/datastax/driver/core/Message.java
@@ -310,7 +310,7 @@
@ChannelHandler.Sharable
static class ProtocolEncoder extends MessageToMessageEncoder<Request> {
- private final ProtocolVersion protocolVersion;
+ final ProtocolVersion protocolVersion;
ProtocolEncoder(ProtocolVersion version) {
this.protocolVersion = version;
@@ -319,35 +319,13 @@
@Override
protected void encode(ChannelHandlerContext ctx, Request request, List<Object> out)
throws Exception {
- EnumSet<Frame.Header.Flag> flags = EnumSet.noneOf(Frame.Header.Flag.class);
- if (request.isTracingRequested()) flags.add(Frame.Header.Flag.TRACING);
- if (protocolVersion == ProtocolVersion.NEWEST_BETA) flags.add(Frame.Header.Flag.USE_BETA);
- Map<String, ByteBuffer> customPayload = request.getCustomPayload();
- if (customPayload != null) {
- if (protocolVersion.compareTo(ProtocolVersion.V4) < 0)
- throw new UnsupportedFeatureException(
- protocolVersion, "Custom payloads are only supported since native protocol V4");
- flags.add(Frame.Header.Flag.CUSTOM_PAYLOAD);
- }
-
+ EnumSet<Frame.Header.Flag> flags = computeFlags(request);
+ int messageSize = encodedSize(request);
+ ByteBuf body = ctx.alloc().buffer(messageSize);
@SuppressWarnings("unchecked")
Coder<Request> coder = (Coder<Request>) request.type.coder;
- int messageSize = coder.encodedSize(request, protocolVersion);
- int payloadLength = -1;
- if (customPayload != null) {
- payloadLength = CBUtil.sizeOfBytesMap(customPayload);
- messageSize += payloadLength;
- }
- ByteBuf body = ctx.alloc().buffer(messageSize);
- if (customPayload != null) {
- CBUtil.writeBytesMap(customPayload, body);
- if (logger.isTraceEnabled()) {
- logger.trace(
- "Sending payload: {} ({} bytes total)", printPayload(customPayload), payloadLength);
- }
- }
-
coder.encode(request, body, protocolVersion);
+
if (body.capacity() != messageSize) {
logger.debug(
"Detected buffer resizing while encoding {} message ({} => {}), "
@@ -360,6 +338,50 @@
out.add(
Frame.create(protocolVersion, request.type.opcode, request.getStreamId(), flags, body));
}
+
+ EnumSet<Frame.Header.Flag> computeFlags(Request request) {
+ EnumSet<Frame.Header.Flag> flags = EnumSet.noneOf(Frame.Header.Flag.class);
+ if (request.isTracingRequested()) flags.add(Frame.Header.Flag.TRACING);
+ if (protocolVersion == ProtocolVersion.NEWEST_BETA) flags.add(Frame.Header.Flag.USE_BETA);
+ Map<String, ByteBuffer> customPayload = request.getCustomPayload();
+ if (customPayload != null) {
+ if (protocolVersion.compareTo(ProtocolVersion.V4) < 0)
+ throw new UnsupportedFeatureException(
+ protocolVersion, "Custom payloads are only supported since native protocol V4");
+ flags.add(Frame.Header.Flag.CUSTOM_PAYLOAD);
+ }
+ return flags;
+ }
+
+ int encodedSize(Request request) {
+ @SuppressWarnings("unchecked")
+ Coder<Request> coder = (Coder<Request>) request.type.coder;
+ int messageSize = coder.encodedSize(request, protocolVersion);
+ int payloadLength = -1;
+ if (request.getCustomPayload() != null) {
+ payloadLength = CBUtil.sizeOfBytesMap(request.getCustomPayload());
+ messageSize += payloadLength;
+ }
+ return messageSize;
+ }
+
+ void encode(Request request, ByteBuf destination) {
+ @SuppressWarnings("unchecked")
+ Coder<Request> coder = (Coder<Request>) request.type.coder;
+
+ Map<String, ByteBuffer> customPayload = request.getCustomPayload();
+ if (customPayload != null) {
+ CBUtil.writeBytesMap(customPayload, destination);
+ if (logger.isTraceEnabled()) {
+ logger.trace(
+ "Sending payload: {} ({} bytes total)",
+ printPayload(customPayload),
+ CBUtil.sizeOfBytesMap(customPayload));
+ }
+ }
+
+ coder.encode(request, destination, protocolVersion);
+ }
}
// private stuff to debug custom payloads
diff --git a/driver-core/src/main/java/com/datastax/driver/core/MessageToSegmentEncoder.java b/driver-core/src/main/java/com/datastax/driver/core/MessageToSegmentEncoder.java
new file mode 100644
index 0000000..41924e6
--- /dev/null
+++ b/driver-core/src/main/java/com/datastax/driver/core/MessageToSegmentEncoder.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
+
+class MessageToSegmentEncoder extends ChannelOutboundHandlerAdapter {
+
+ private final ByteBufAllocator allocator;
+ private final Message.ProtocolEncoder requestEncoder;
+
+ private SegmentBuilder segmentBuilder;
+
+ MessageToSegmentEncoder(ByteBufAllocator allocator, Message.ProtocolEncoder requestEncoder) {
+ this.allocator = allocator;
+ this.requestEncoder = requestEncoder;
+ }
+
+ @Override
+ public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
+ super.handlerAdded(ctx);
+ this.segmentBuilder = new SegmentBuilder(ctx, allocator, requestEncoder);
+ }
+
+ @Override
+ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
+ throws Exception {
+ if (msg instanceof Message.Request) {
+ segmentBuilder.addRequest(((Message.Request) msg), promise);
+ } else {
+ super.write(ctx, msg, promise);
+ }
+ }
+
+ @Override
+ public void flush(ChannelHandlerContext ctx) throws Exception {
+ segmentBuilder.flush();
+ super.flush(ctx);
+ }
+}
diff --git a/driver-core/src/main/java/com/datastax/driver/core/Segment.java b/driver-core/src/main/java/com/datastax/driver/core/Segment.java
new file mode 100644
index 0000000..231c500
--- /dev/null
+++ b/driver-core/src/main/java/com/datastax/driver/core/Segment.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * A container of {@link Frame}s in protocol v5 and above. This is a new protocol construct that
+ * allows checksumming and compressing multiple messages together.
+ *
+ * <p>{@link #getPayload()} contains either:
+ *
+ * <ul>
+ * <li>a sequence of encoded {@link Frame}s, all concatenated together. In this case, {@link
+ * #isSelfContained()} return true.
+ * <li>or a slice of an encoded large {@link Frame} (if that frame is longer than {@link
+ * #MAX_PAYLOAD_LENGTH}). In this case, {@link #isSelfContained()} returns false.
+ * </ul>
+ *
+ * The payload is not compressed; compression is handled at a lower level when encoding or decoding
+ * this object.
+ *
+ * <p>Naming is provisional: "segment" is not the official name, I picked it arbitrarily for the
+ * driver code to avoid a name clash. It's possible that this type will be renamed to "frame", and
+ * {@link Frame} to something else, at some point in the future (this is an ongoing discussion on
+ * the server ticket).
+ */
+class Segment {
+
+ static int MAX_PAYLOAD_LENGTH = 128 * 1024 - 1;
+
+ private final ByteBuf payload;
+ private final boolean isSelfContained;
+
+ Segment(ByteBuf payload, boolean isSelfContained) {
+ this.payload = payload;
+ this.isSelfContained = isSelfContained;
+ }
+
+ public ByteBuf getPayload() {
+ return payload;
+ }
+
+ public boolean isSelfContained() {
+ return isSelfContained;
+ }
+}
diff --git a/driver-core/src/main/java/com/datastax/driver/core/SegmentBuilder.java b/driver-core/src/main/java/com/datastax/driver/core/SegmentBuilder.java
new file mode 100644
index 0000000..ddf0235
--- /dev/null
+++ b/driver-core/src/main/java/com/datastax/driver/core/SegmentBuilder.java
@@ -0,0 +1,257 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPromise;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
+import java.util.ArrayList;
+import java.util.List;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Abstracts the details of batching a sequence of {@link Message.Request}s into one or more {@link
+ * Segment}s before sending them out on the network.
+ *
+ * <p>This class is not thread-safe.
+ */
+class SegmentBuilder {
+
+ private static final Logger logger = LoggerFactory.getLogger(SegmentBuilder.class);
+
+ private final ChannelHandlerContext context;
+ private final ByteBufAllocator allocator;
+ private final int maxPayloadLength;
+ private final Message.ProtocolEncoder requestEncoder;
+
+ private final List<Frame.Header> currentPayloadHeaders = new ArrayList<Frame.Header>();
+ private final List<Message.Request> currentPayloadBodies = new ArrayList<Message.Request>();
+ private final List<ChannelPromise> currentPayloadPromises = new ArrayList<ChannelPromise>();
+ private int currentPayloadLength;
+
+ SegmentBuilder(
+ ChannelHandlerContext context,
+ ByteBufAllocator allocator,
+ Message.ProtocolEncoder requestEncoder) {
+ this(context, allocator, requestEncoder, Segment.MAX_PAYLOAD_LENGTH);
+ }
+
+ /** Exposes the max length for unit tests; in production, this is hard-coded. */
+ @VisibleForTesting
+ SegmentBuilder(
+ ChannelHandlerContext context,
+ ByteBufAllocator allocator,
+ Message.ProtocolEncoder requestEncoder,
+ int maxPayloadLength) {
+ this.context = context;
+ this.allocator = allocator;
+ this.requestEncoder = requestEncoder;
+ this.maxPayloadLength = maxPayloadLength;
+ }
+
+ /**
+ * Adds a new request. It will be encoded into one or more segments, that will be passed to {@link
+ * #processSegment(Segment, ChannelPromise)} at some point in the future.
+ *
+ * <p>The caller <b>must</b> invoke {@link #flush()} after the last request.
+ */
+ public void addRequest(Message.Request request, ChannelPromise promise) {
+
+ // Wrap the request into a legacy frame, append that frame to the payload.
+ int frameHeaderLength = Frame.Header.lengthFor(requestEncoder.protocolVersion);
+ int frameBodyLength = requestEncoder.encodedSize(request);
+ int frameLength = frameHeaderLength + frameBodyLength;
+
+ Frame.Header header =
+ new Frame.Header(
+ requestEncoder.protocolVersion,
+ requestEncoder.computeFlags(request),
+ request.getStreamId(),
+ request.type.opcode,
+ frameBodyLength);
+
+ if (frameLength > maxPayloadLength) {
+ // Large request: split into multiple dedicated segments and process them immediately:
+ ByteBuf frame = allocator.ioBuffer(frameLength);
+ header.encodeInto(frame);
+ requestEncoder.encode(request, frame);
+
+ int sliceCount =
+ (frameLength / maxPayloadLength) + (frameLength % maxPayloadLength == 0 ? 0 : 1);
+
+ logger.trace(
+ "Splitting large request ({} bytes) into {} segments: {}",
+ frameLength,
+ sliceCount,
+ request);
+
+ List<ChannelPromise> segmentPromises = split(promise, sliceCount);
+ int i = 0;
+ do {
+ ByteBuf part = frame.readSlice(Math.min(maxPayloadLength, frame.readableBytes()));
+ part.retain();
+ process(part, false, segmentPromises.get(i++));
+ } while (frame.isReadable());
+ // We've retained each slice, and won't reference this buffer anymore
+ frame.release();
+ } else {
+ // Small request: append to an existing segment, together with other messages.
+ if (currentPayloadLength + frameLength > maxPayloadLength) {
+ // Current segment is full, process and start a new one:
+ processCurrentPayload();
+ resetCurrentPayload();
+ }
+ // Append frame to current segment
+ logger.trace(
+ "Adding {}th request to self-contained segment: {}",
+ currentPayloadHeaders.size() + 1,
+ request);
+ currentPayloadHeaders.add(header);
+ currentPayloadBodies.add(request);
+ currentPayloadPromises.add(promise);
+ currentPayloadLength += frameLength;
+ }
+ }
+
+ /**
+ * Signals that we're done adding requests.
+ *
+ * <p>This must be called after adding the last request, it will possibly trigger the generation
+ * of one last segment.
+ */
+ public void flush() {
+ if (currentPayloadLength > 0) {
+ processCurrentPayload();
+ resetCurrentPayload();
+ }
+ }
+
+ /** What to do whenever a full segment is ready. */
+ protected void processSegment(Segment segment, ChannelPromise segmentPromise) {
+ context.write(segment, segmentPromise);
+ }
+
+ private void process(ByteBuf payload, boolean isSelfContained, ChannelPromise segmentPromise) {
+ processSegment(new Segment(payload, isSelfContained), segmentPromise);
+ }
+
+ private void processCurrentPayload() {
+ int requestCount = currentPayloadHeaders.size();
+ assert currentPayloadBodies.size() == requestCount
+ && currentPayloadPromises.size() == requestCount;
+ logger.trace("Emitting new self-contained segment with {} frame(s)", requestCount);
+ ByteBuf payload = this.allocator.ioBuffer(currentPayloadLength);
+ for (int i = 0; i < requestCount; i++) {
+ Frame.Header header = currentPayloadHeaders.get(i);
+ Message.Request request = currentPayloadBodies.get(i);
+ header.encodeInto(payload);
+ requestEncoder.encode(request, payload);
+ }
+ process(payload, true, merge(currentPayloadPromises));
+ }
+
+ private void resetCurrentPayload() {
+ currentPayloadHeaders.clear();
+ currentPayloadBodies.clear();
+ currentPayloadPromises.clear();
+ currentPayloadLength = 0;
+ }
+
+ // Merges multiple promises into a single one, that will notify all of them when done.
+ // This is used when multiple requests are sent as a single segment.
+ private ChannelPromise merge(List<ChannelPromise> framePromises) {
+ if (framePromises.size() == 1) {
+ return framePromises.get(0);
+ }
+ ChannelPromise segmentPromise = context.newPromise();
+ final ImmutableList<ChannelPromise> dependents = ImmutableList.copyOf(framePromises);
+ segmentPromise.addListener(
+ new GenericFutureListener<Future<? super Void>>() {
+ @Override
+ public void operationComplete(Future<? super Void> future) throws Exception {
+ if (future.isSuccess()) {
+ for (ChannelPromise framePromise : dependents) {
+ framePromise.setSuccess();
+ }
+ } else {
+ Throwable cause = future.cause();
+ for (ChannelPromise framePromise : dependents) {
+ framePromise.setFailure(cause);
+ }
+ }
+ }
+ });
+ return segmentPromise;
+ }
+
+ // Splits a single promise into multiple ones. The original promise will complete when all the
+ // splits have.
+ // This is used when a single request is sliced into multiple segment.
+ private List<ChannelPromise> split(ChannelPromise framePromise, int sliceCount) {
+ // We split one frame into multiple slices. When all slices are written, the frame is written.
+ List<ChannelPromise> slicePromises = new ArrayList<ChannelPromise>(sliceCount);
+ for (int i = 0; i < sliceCount; i++) {
+ slicePromises.add(context.newPromise());
+ }
+ GenericFutureListener<Future<Void>> sliceListener =
+ new SliceWriteListener(framePromise, slicePromises);
+ for (int i = 0; i < sliceCount; i++) {
+ slicePromises.get(i).addListener(sliceListener);
+ }
+ return slicePromises;
+ }
+
+ static class SliceWriteListener implements GenericFutureListener<Future<Void>> {
+
+ private final ChannelPromise parentPromise;
+ private final List<ChannelPromise> slicePromises;
+
+ // All slices are written to the same channel, and the segment is built from the Flusher which
+ // also runs on the same event loop, so we don't need synchronization.
+ private int remainingSlices;
+
+ SliceWriteListener(ChannelPromise parentPromise, List<ChannelPromise> slicePromises) {
+ this.parentPromise = parentPromise;
+ this.slicePromises = slicePromises;
+ this.remainingSlices = slicePromises.size();
+ }
+
+ @Override
+ public void operationComplete(Future<Void> future) {
+ if (!parentPromise.isDone()) {
+ if (future.isSuccess()) {
+ remainingSlices -= 1;
+ if (remainingSlices == 0) {
+ parentPromise.setSuccess();
+ }
+ } else {
+ // If any slice fails, we can immediately mark the whole frame as failed:
+ parentPromise.setFailure(future.cause());
+ // Cancel any remaining slice, Netty will not send the bytes.
+ for (ChannelPromise slicePromise : slicePromises) {
+ slicePromise.cancel(/*Netty ignores this*/ false);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/driver-core/src/main/java/com/datastax/driver/core/SegmentCodec.java b/driver-core/src/main/java/com/datastax/driver/core/SegmentCodec.java
new file mode 100644
index 0000000..66122e4
--- /dev/null
+++ b/driver-core/src/main/java/com/datastax/driver/core/SegmentCodec.java
@@ -0,0 +1,210 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import com.datastax.driver.core.exceptions.CrcMismatchException;
+import com.google.common.annotations.VisibleForTesting;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import java.io.IOException;
+import java.util.List;
+
+class SegmentCodec {
+
+ private static final int COMPRESSED_HEADER_LENGTH = 5;
+ private static final int UNCOMPRESSED_HEADER_LENGTH = 3;
+ static final int CRC24_LENGTH = 3;
+ static final int CRC32_LENGTH = 4;
+
+ private final ByteBufAllocator allocator;
+ private final boolean compress;
+ private final FrameCompressor compressor;
+
+ SegmentCodec(ByteBufAllocator allocator, ProtocolOptions.Compression compression) {
+ this.allocator = allocator;
+ this.compress = compression != ProtocolOptions.Compression.NONE;
+ this.compressor = compression.compressor();
+ }
+
+ /** The length of the segment header, excluding the 3-byte trailing CRC. */
+ int headerLength() {
+ return compress ? COMPRESSED_HEADER_LENGTH : UNCOMPRESSED_HEADER_LENGTH;
+ }
+
+ void encode(Segment segment, List<Object> out) throws IOException {
+ ByteBuf uncompressedPayload = segment.getPayload();
+ int uncompressedPayloadLength = uncompressedPayload.readableBytes();
+ assert uncompressedPayloadLength <= Segment.MAX_PAYLOAD_LENGTH;
+ ByteBuf encodedPayload;
+ if (compress) {
+ uncompressedPayload.markReaderIndex();
+ ByteBuf compressedPayload = compressor.compress(uncompressedPayload);
+ if (compressedPayload.readableBytes() >= uncompressedPayloadLength) {
+ // Skip compression if it's not worth it
+ uncompressedPayload.resetReaderIndex();
+ encodedPayload = uncompressedPayload;
+ compressedPayload.release();
+ // By convention, this is how we signal this to the server:
+ uncompressedPayloadLength = 0;
+ } else {
+ encodedPayload = compressedPayload;
+ uncompressedPayload.release();
+ }
+ } else {
+ encodedPayload = uncompressedPayload;
+ }
+ int payloadLength = encodedPayload.readableBytes();
+
+ ByteBuf header =
+ encodeHeader(payloadLength, uncompressedPayloadLength, segment.isSelfContained());
+
+ int payloadCrc = Crc.computeCrc32(encodedPayload);
+ ByteBuf trailer = allocator.ioBuffer(CRC32_LENGTH);
+ for (int i = 0; i < CRC32_LENGTH; i++) {
+ trailer.writeByte(payloadCrc & 0xFF);
+ payloadCrc >>= 8;
+ }
+
+ out.add(header);
+ out.add(encodedPayload);
+ out.add(trailer);
+ }
+
+ @VisibleForTesting
+ ByteBuf encodeHeader(int payloadLength, int uncompressedLength, boolean isSelfContained) {
+ assert payloadLength <= Segment.MAX_PAYLOAD_LENGTH;
+ assert !compress || uncompressedLength <= Segment.MAX_PAYLOAD_LENGTH;
+
+ int headerLength = headerLength();
+
+ long headerData = payloadLength;
+ int flagOffset = 17;
+ if (compress) {
+ headerData |= (long) uncompressedLength << 17;
+ flagOffset += 17;
+ }
+ if (isSelfContained) {
+ headerData |= 1L << flagOffset;
+ }
+
+ int headerCrc = Crc.computeCrc24(headerData, headerLength);
+
+ ByteBuf header = allocator.ioBuffer(headerLength + CRC24_LENGTH);
+ // Write both data and CRC in little-endian order
+ for (int i = 0; i < headerLength; i++) {
+ int shift = i * 8;
+ header.writeByte((int) (headerData >> shift & 0xFF));
+ }
+ for (int i = 0; i < CRC24_LENGTH; i++) {
+ int shift = i * 8;
+ header.writeByte(headerCrc >> shift & 0xFF);
+ }
+ return header;
+ }
+
+ /**
+ * Decodes a segment header and checks its CRC. It is assumed that the caller has already checked
+ * that there are enough bytes.
+ */
+ Header decodeHeader(ByteBuf buffer) throws CrcMismatchException {
+ int headerLength = headerLength();
+ assert buffer.readableBytes() >= headerLength + CRC24_LENGTH;
+
+ // Read header data (little endian):
+ long headerData = 0;
+ for (int i = 0; i < headerLength; i++) {
+ headerData |= (buffer.readByte() & 0xFFL) << 8 * i;
+ }
+
+ // Read CRC (little endian) and check it:
+ int expectedHeaderCrc = 0;
+ for (int i = 0; i < CRC24_LENGTH; i++) {
+ expectedHeaderCrc |= (buffer.readByte() & 0xFF) << 8 * i;
+ }
+ int actualHeaderCrc = Crc.computeCrc24(headerData, headerLength);
+ if (actualHeaderCrc != expectedHeaderCrc) {
+ throw new CrcMismatchException(
+ String.format(
+ "CRC mismatch on header %s. Received %s, computed %s.",
+ Long.toHexString(headerData),
+ Integer.toHexString(expectedHeaderCrc),
+ Integer.toHexString(actualHeaderCrc)));
+ }
+
+ int payloadLength = (int) headerData & Segment.MAX_PAYLOAD_LENGTH;
+ headerData >>= 17;
+ int uncompressedPayloadLength;
+ if (compress) {
+ uncompressedPayloadLength = (int) headerData & Segment.MAX_PAYLOAD_LENGTH;
+ headerData >>= 17;
+ } else {
+ uncompressedPayloadLength = -1;
+ }
+ boolean isSelfContained = (headerData & 1) == 1;
+ return new Header(payloadLength, uncompressedPayloadLength, isSelfContained);
+ }
+
+ /**
+ * Decodes the rest of a segment from a previously decoded header, and checks the payload's CRC.
+ * It is assumed that the caller has already checked that there are enough bytes.
+ */
+ Segment decode(Header header, ByteBuf buffer) throws CrcMismatchException, IOException {
+ assert buffer.readableBytes() == header.payloadLength + CRC32_LENGTH;
+
+ // Extract payload:
+ ByteBuf encodedPayload = buffer.readSlice(header.payloadLength);
+ encodedPayload.retain();
+
+ // Read and check CRC:
+ int expectedPayloadCrc = 0;
+ for (int i = 0; i < CRC32_LENGTH; i++) {
+ expectedPayloadCrc |= (buffer.readByte() & 0xFF) << 8 * i;
+ }
+ buffer.release(); // done with this (we retained the payload independently)
+ int actualPayloadCrc = Crc.computeCrc32(encodedPayload);
+ if (actualPayloadCrc != expectedPayloadCrc) {
+ encodedPayload.release();
+ throw new CrcMismatchException(
+ String.format(
+ "CRC mismatch on payload. Received %s, computed %s.",
+ Integer.toHexString(expectedPayloadCrc), Integer.toHexString(actualPayloadCrc)));
+ }
+
+ // Decompress payload if needed:
+ ByteBuf payload;
+ if (compress && header.uncompressedPayloadLength > 0) {
+ payload = compressor.decompress(encodedPayload, header.uncompressedPayloadLength);
+ encodedPayload.release();
+ } else {
+ payload = encodedPayload;
+ }
+
+ return new Segment(payload, header.isSelfContained);
+ }
+
+ /** Temporary holder for header data. During decoding, it is convenient to store it separately. */
+ static class Header {
+ final int payloadLength;
+ final int uncompressedPayloadLength;
+ final boolean isSelfContained;
+
+ public Header(int payloadLength, int uncompressedPayloadLength, boolean isSelfContained) {
+ this.payloadLength = payloadLength;
+ this.uncompressedPayloadLength = uncompressedPayloadLength;
+ this.isSelfContained = isSelfContained;
+ }
+ }
+}
diff --git a/driver-core/src/main/java/com/datastax/driver/core/SegmentToBytesEncoder.java b/driver-core/src/main/java/com/datastax/driver/core/SegmentToBytesEncoder.java
new file mode 100644
index 0000000..f4cd8b4
--- /dev/null
+++ b/driver-core/src/main/java/com/datastax/driver/core/SegmentToBytesEncoder.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import java.util.List;
+
+@ChannelHandler.Sharable
+class SegmentToBytesEncoder extends MessageToMessageEncoder<Segment> {
+
+ private final SegmentCodec codec;
+
+ SegmentToBytesEncoder(SegmentCodec codec) {
+ super(Segment.class);
+ this.codec = codec;
+ }
+
+ @Override
+ protected void encode(ChannelHandlerContext ctx, Segment segment, List<Object> out)
+ throws Exception {
+ codec.encode(segment, out);
+ }
+}
diff --git a/driver-core/src/main/java/com/datastax/driver/core/SegmentToFrameDecoder.java b/driver-core/src/main/java/com/datastax/driver/core/SegmentToFrameDecoder.java
new file mode 100644
index 0000000..095d383
--- /dev/null
+++ b/driver-core/src/main/java/com/datastax/driver/core/SegmentToFrameDecoder.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import com.datastax.driver.core.Frame.Header;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import java.util.ArrayList;
+import java.util.List;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Converts the segments decoded by {@link BytesToSegmentDecoder} into legacy frames understood by
+ * the rest of the driver.
+ */
+class SegmentToFrameDecoder extends MessageToMessageDecoder<Segment> {
+
+ private static final Logger logger = LoggerFactory.getLogger(SegmentToFrameDecoder.class);
+
+ // Accumulated state when we are reading a sequence of slices
+ private Header pendingHeader;
+ private final List<ByteBuf> accumulatedSlices = new ArrayList<ByteBuf>();
+ private int accumulatedLength;
+
+ SegmentToFrameDecoder() {
+ super(Segment.class);
+ }
+
+ @Override
+ protected void decode(ChannelHandlerContext ctx, Segment segment, List<Object> out) {
+ if (segment.isSelfContained()) {
+ decodeSelfContained(segment, out);
+ } else {
+ decodeSlice(segment, ctx.alloc(), out);
+ }
+ }
+
+ private void decodeSelfContained(Segment segment, List<Object> out) {
+ ByteBuf payload = segment.getPayload();
+ int frameCount = 0;
+ do {
+ Header header = Header.decode(payload);
+ ByteBuf body = payload.readSlice(header.bodyLength);
+ body.retain();
+ out.add(new Frame(header, body));
+ frameCount += 1;
+ } while (payload.isReadable());
+ payload.release();
+ logger.trace("Decoded self-contained segment into {} frame(s)", frameCount);
+ }
+
+ private void decodeSlice(Segment segment, ByteBufAllocator allocator, List<Object> out) {
+ assert pendingHeader != null ^ (accumulatedSlices.isEmpty() && accumulatedLength == 0);
+ ByteBuf payload = segment.getPayload();
+ if (pendingHeader == null) { // first slice
+ pendingHeader = Header.decode(payload); // note: this consumes the header data
+ }
+ accumulatedSlices.add(payload);
+ accumulatedLength += payload.readableBytes();
+ logger.trace(
+ "StreamId {}: decoded slice {}, {}/{} bytes",
+ pendingHeader.streamId,
+ accumulatedSlices.size(),
+ accumulatedLength,
+ pendingHeader.bodyLength);
+ assert accumulatedLength <= pendingHeader.bodyLength;
+ if (accumulatedLength == pendingHeader.bodyLength) {
+ // We've received enough data to reassemble the whole message
+ CompositeByteBuf body = allocator.compositeBuffer(accumulatedSlices.size());
+ body.addComponents(true, accumulatedSlices);
+ out.add(new Frame(pendingHeader, body));
+ // Reset our state
+ pendingHeader = null;
+ accumulatedSlices.clear();
+ accumulatedLength = 0;
+ }
+ }
+}
diff --git a/driver-core/src/main/java/com/datastax/driver/core/SnappyCompressor.java b/driver-core/src/main/java/com/datastax/driver/core/SnappyCompressor.java
index 6f0c9f3..7b9e2b8 100644
--- a/driver-core/src/main/java/com/datastax/driver/core/SnappyCompressor.java
+++ b/driver-core/src/main/java/com/datastax/driver/core/SnappyCompressor.java
@@ -54,9 +54,12 @@
@Override
Frame compress(Frame frame) throws IOException {
- ByteBuf input = frame.body;
- ByteBuf frameBody = input.isDirect() ? compressDirect(input) : compressHeap(input);
- return frame.with(frameBody);
+ return frame.with(compress(frame.body));
+ }
+
+ @Override
+ ByteBuf compress(ByteBuf buffer) throws IOException {
+ return buffer.isDirect() ? compressDirect(buffer) : compressHeap(buffer);
}
private ByteBuf compressDirect(ByteBuf input) throws IOException {
@@ -117,6 +120,13 @@
return frame.with(frameBody);
}
+ @Override
+ ByteBuf decompress(ByteBuf buffer, int uncompressedLength) throws IOException {
+ // Note that the Snappy algorithm already encodes the uncompressed length, we don't need the
+ // provided one.
+ return buffer.isDirect() ? decompressDirect(buffer) : decompressHeap(buffer);
+ }
+
private ByteBuf decompressDirect(ByteBuf input) throws IOException {
ByteBuffer in = inputNioBuffer(input);
// Increase reader index.
diff --git a/driver-core/src/main/java/com/datastax/driver/core/exceptions/CrcMismatchException.java b/driver-core/src/main/java/com/datastax/driver/core/exceptions/CrcMismatchException.java
new file mode 100644
index 0000000..a79a49f
--- /dev/null
+++ b/driver-core/src/main/java/com/datastax/driver/core/exceptions/CrcMismatchException.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core.exceptions;
+
+/**
+ * Thrown when the checksums in a server response don't match (protocol v5 or above).
+ *
+ * <p>This indicates a data corruption issue, either due to a hardware issue on the client, or on
+ * the network between the server and the client. It is not recoverable: the driver will drop the
+ * connection.
+ */
+public class CrcMismatchException extends DriverException {
+
+ private static final long serialVersionUID = 0;
+
+ public CrcMismatchException(String message) {
+ super(message);
+ }
+
+ public CrcMismatchException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ @Override
+ public CrcMismatchException copy() {
+ return new CrcMismatchException(getMessage(), this);
+ }
+}
diff --git a/driver-core/src/test/java/com/datastax/driver/core/BytesToSegmentDecoderTest.java b/driver-core/src/test/java/com/datastax/driver/core/BytesToSegmentDecoderTest.java
new file mode 100644
index 0000000..8b2bcb9
--- /dev/null
+++ b/driver-core/src/test/java/com/datastax/driver/core/BytesToSegmentDecoderTest.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import com.datastax.driver.core.ProtocolOptions.Compression;
+import com.datastax.driver.core.exceptions.CrcMismatchException;
+import com.google.common.base.Strings;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.ByteBufUtil;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.DecoderException;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+public class BytesToSegmentDecoderTest {
+
+ // Hard-coded test data, the values were generated with our encoding methods.
+ // We're not really testing the decoding itself here, only that our subclass calls the
+ // LengthFieldBasedFrameDecoder parent constructor with the right parameters.
+ private static final ByteBuf REGULAR_HEADER = byteBuf("04000201f9f2");
+ private static final ByteBuf REGULAR_PAYLOAD = byteBuf("00000001");
+ private static final ByteBuf REGULAR_TRAILER = byteBuf("1fd6022d");
+ private static final ByteBuf REGULAR_WRONG_HEADER = byteBuf("04000202f9f2");
+ private static final ByteBuf REGULAR_WRONG_TRAILER = byteBuf("1fd6022e");
+
+ private static final ByteBuf MAX_HEADER = byteBuf("ffff03254047");
+ private static final ByteBuf MAX_PAYLOAD =
+ byteBuf(Strings.repeat("01", Segment.MAX_PAYLOAD_LENGTH));
+ private static final ByteBuf MAX_TRAILER = byteBuf("a05c2f13");
+
+ private static final ByteBuf LZ4_HEADER = byteBuf("120020000491c94f");
+ private static final ByteBuf LZ4_PAYLOAD_UNCOMPRESSED =
+ byteBuf("00000001000000010000000100000001");
+ private static final ByteBuf LZ4_PAYLOAD_COMPRESSED =
+ byteBuf("f00100000001000000010000000100000001");
+ private static final ByteBuf LZ4_TRAILER = byteBuf("2bd67f90");
+
+ private EmbeddedChannel channel;
+
+ @BeforeMethod(groups = "unit")
+ public void setup() {
+ channel = new EmbeddedChannel();
+ }
+
+ @Test(groups = "unit")
+ public void should_decode_regular_segment() {
+ channel.pipeline().addLast(newDecoder(Compression.NONE));
+ channel.writeInbound(Unpooled.wrappedBuffer(REGULAR_HEADER, REGULAR_PAYLOAD, REGULAR_TRAILER));
+ Segment segment = (Segment) channel.readInbound();
+ assertThat(segment.isSelfContained()).isTrue();
+ assertThat(segment.getPayload()).isEqualTo(REGULAR_PAYLOAD);
+ }
+
+ @Test(groups = "unit")
+ public void should_decode_max_length_segment() {
+ channel.pipeline().addLast(newDecoder(Compression.NONE));
+ channel.writeInbound(Unpooled.wrappedBuffer(MAX_HEADER, MAX_PAYLOAD, MAX_TRAILER));
+ Segment segment = (Segment) channel.readInbound();
+ assertThat(segment.isSelfContained()).isTrue();
+ assertThat(segment.getPayload()).isEqualTo(MAX_PAYLOAD);
+ }
+
+ @Test(groups = "unit")
+ public void should_decode_segment_from_multiple_incoming_chunks() {
+ channel.pipeline().addLast(newDecoder(Compression.NONE));
+ // Send the header in two slices, to cover the case where the length can't be read the first
+ // time:
+ ByteBuf headerStart = REGULAR_HEADER.slice(0, 3);
+ ByteBuf headerEnd = REGULAR_HEADER.slice(3, 3);
+ channel.writeInbound(headerStart);
+ channel.writeInbound(headerEnd);
+ channel.writeInbound(REGULAR_PAYLOAD.duplicate());
+ channel.writeInbound(REGULAR_TRAILER.duplicate());
+ Segment segment = (Segment) channel.readInbound();
+ assertThat(segment.isSelfContained()).isTrue();
+ assertThat(segment.getPayload()).isEqualTo(REGULAR_PAYLOAD);
+ }
+
+ @Test(groups = "unit")
+ public void should_decode_compressed_segment() {
+ channel.pipeline().addLast(newDecoder(Compression.LZ4));
+ // We need a contiguous buffer for this one, because of how our decompressor operates
+ ByteBuf buffer = Unpooled.wrappedBuffer(LZ4_HEADER, LZ4_PAYLOAD_COMPRESSED, LZ4_TRAILER).copy();
+ channel.writeInbound(buffer);
+ Segment segment = (Segment) channel.readInbound();
+ assertThat(segment.isSelfContained()).isTrue();
+ assertThat(segment.getPayload()).isEqualTo(LZ4_PAYLOAD_UNCOMPRESSED);
+ }
+
+ @Test(groups = "unit")
+ public void should_surface_header_crc_mismatch() {
+ try {
+ channel.pipeline().addLast(newDecoder(Compression.NONE));
+ channel.writeInbound(
+ Unpooled.wrappedBuffer(REGULAR_WRONG_HEADER, REGULAR_PAYLOAD, REGULAR_TRAILER));
+ } catch (DecoderException exception) {
+ assertThat(exception).hasCauseInstanceOf(CrcMismatchException.class);
+ }
+ }
+
+ @Test(groups = "unit")
+ public void should_surface_trailer_crc_mismatch() {
+ try {
+ channel.pipeline().addLast(newDecoder(Compression.NONE));
+ channel.writeInbound(
+ Unpooled.wrappedBuffer(REGULAR_HEADER, REGULAR_PAYLOAD, REGULAR_WRONG_TRAILER));
+ } catch (DecoderException exception) {
+ assertThat(exception).hasCauseInstanceOf(CrcMismatchException.class);
+ }
+ }
+
+ private BytesToSegmentDecoder newDecoder(Compression compression) {
+ return new BytesToSegmentDecoder(new SegmentCodec(ByteBufAllocator.DEFAULT, compression));
+ }
+
+ private static ByteBuf byteBuf(String hex) {
+ return Unpooled.unreleasableBuffer(
+ Unpooled.unmodifiableBuffer(Unpooled.wrappedBuffer(ByteBufUtil.decodeHexDump(hex))));
+ }
+}
diff --git a/driver-core/src/test/java/com/datastax/driver/core/SegmentBuilderTest.java b/driver-core/src/test/java/com/datastax/driver/core/SegmentBuilderTest.java
new file mode 100644
index 0000000..40adfb6
--- /dev/null
+++ b/driver-core/src/test/java/com/datastax/driver/core/SegmentBuilderTest.java
@@ -0,0 +1,308 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.when;
+
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.embedded.EmbeddedChannel;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+public class SegmentBuilderTest {
+
+ private static final Message.ProtocolEncoder REQUEST_ENCODER =
+ new Message.ProtocolEncoder(ProtocolVersion.V5);
+
+ // The constant names denote the total encoded size, including the frame header
+ private static final Message.Request _38B_REQUEST = new Requests.Query("SELECT * FROM table");
+ private static final Message.Request _51B_REQUEST =
+ new Requests.Query("SELECT * FROM table WHERE id = 1");
+ private static final Message.Request _1KB_REQUEST =
+ new Requests.Query(
+ "SELECT * FROM table WHERE id = ?",
+ new Requests.QueryProtocolOptions(
+ Message.Request.Type.QUERY,
+ ConsistencyLevel.ONE,
+ new ByteBuffer[] {ByteBuffer.allocate(967)},
+ Collections.<String, ByteBuffer>emptyMap(),
+ false,
+ -1,
+ null,
+ ConsistencyLevel.SERIAL,
+ Long.MIN_VALUE,
+ Integer.MIN_VALUE),
+ false);
+
+ private static final EmbeddedChannel MOCK_CHANNEL = new EmbeddedChannel();
+ private static final ChannelHandlerContext CONTEXT = Mockito.mock(ChannelHandlerContext.class);
+
+ @BeforeClass(groups = "unit")
+ public static void setup() {
+ // This is the only method called by our test implementation
+ when(CONTEXT.newPromise())
+ .thenAnswer(
+ new Answer<ChannelPromise>() {
+ @Override
+ public ChannelPromise answer(InvocationOnMock invocation) {
+ return MOCK_CHANNEL.newPromise();
+ }
+ });
+ }
+
+ @Test(groups = "unit")
+ public void should_concatenate_frames_when_under_limit() {
+ TestSegmentBuilder builder = new TestSegmentBuilder(CONTEXT, 100);
+
+ ChannelPromise requestPromise1 = newPromise();
+ builder.addRequest(_38B_REQUEST, requestPromise1);
+ ChannelPromise requestPromise2 = newPromise();
+ builder.addRequest(_51B_REQUEST, requestPromise2);
+ // Nothing produced yet since we would still have room for more frames
+ assertThat(builder.segments).isEmpty();
+
+ builder.flush();
+ assertThat(builder.segments).hasSize(1);
+ assertThat(builder.segmentPromises).hasSize(1);
+ Segment segment = builder.segments.get(0);
+ assertThat(segment.getPayload().readableBytes()).isEqualTo(38 + 51);
+ assertThat(segment.isSelfContained()).isTrue();
+ ChannelPromise segmentPromise = builder.segmentPromises.get(0);
+ assertForwards(segmentPromise, requestPromise1, requestPromise2);
+ }
+
+ @Test(groups = "unit")
+ public void should_start_new_segment_when_over_limit() {
+ TestSegmentBuilder builder = new TestSegmentBuilder(CONTEXT, 100);
+
+ ChannelPromise requestPromise1 = newPromise();
+ builder.addRequest(_38B_REQUEST, requestPromise1);
+ ChannelPromise requestPromise2 = newPromise();
+ builder.addRequest(_51B_REQUEST, requestPromise2);
+ ChannelPromise requestPromise3 = newPromise();
+ builder.addRequest(_38B_REQUEST, requestPromise3);
+ // Adding the 3rd frame brings the total size over 100, so a first segment should be emitted
+ // with the first two messages:
+ assertThat(builder.segments).hasSize(1);
+
+ ChannelPromise requestPromise4 = newPromise();
+ builder.addRequest(_38B_REQUEST, requestPromise4);
+ builder.flush();
+ assertThat(builder.segments).hasSize(2);
+
+ Segment segment1 = builder.segments.get(0);
+ assertThat(segment1.getPayload().readableBytes()).isEqualTo(38 + 51);
+ assertThat(segment1.isSelfContained()).isTrue();
+ ChannelPromise segmentPromise1 = builder.segmentPromises.get(0);
+ assertForwards(segmentPromise1, requestPromise1, requestPromise2);
+ Segment segment2 = builder.segments.get(1);
+ assertThat(segment2.getPayload().readableBytes()).isEqualTo(38 + 38);
+ assertThat(segment2.isSelfContained()).isTrue();
+ ChannelPromise segmentPromise2 = builder.segmentPromises.get(1);
+ assertForwards(segmentPromise2, requestPromise3, requestPromise4);
+ }
+
+ @Test(groups = "unit")
+ public void should_start_new_segment_when_at_limit() {
+ TestSegmentBuilder builder = new TestSegmentBuilder(CONTEXT, 38 + 51);
+
+ ChannelPromise requestPromise1 = newPromise();
+ builder.addRequest(_38B_REQUEST, requestPromise1);
+ ChannelPromise requestPromise2 = newPromise();
+ builder.addRequest(_51B_REQUEST, requestPromise2);
+ ChannelPromise requestPromise3 = newPromise();
+ builder.addRequest(_38B_REQUEST, requestPromise3);
+ assertThat(builder.segments).hasSize(1);
+
+ ChannelPromise requestPromise4 = newPromise();
+ builder.addRequest(_51B_REQUEST, requestPromise4);
+ builder.flush();
+ assertThat(builder.segments).hasSize(2);
+
+ Segment segment1 = builder.segments.get(0);
+ assertThat(segment1.getPayload().readableBytes()).isEqualTo(38 + 51);
+ assertThat(segment1.isSelfContained()).isTrue();
+ ChannelPromise segmentPromise1 = builder.segmentPromises.get(0);
+ assertForwards(segmentPromise1, requestPromise1, requestPromise2);
+ Segment segment2 = builder.segments.get(1);
+ assertThat(segment2.getPayload().readableBytes()).isEqualTo(38 + 51);
+ assertThat(segment2.isSelfContained()).isTrue();
+ ChannelPromise segmentPromise2 = builder.segmentPromises.get(1);
+ assertForwards(segmentPromise2, requestPromise3, requestPromise4);
+ }
+
+ @Test(groups = "unit")
+ public void should_split_large_frame() {
+ TestSegmentBuilder builder = new TestSegmentBuilder(CONTEXT, 100);
+
+ ChannelPromise parentPromise = newPromise();
+ builder.addRequest(_1KB_REQUEST, parentPromise);
+
+ assertThat(builder.segments).hasSize(11);
+ assertThat(builder.segmentPromises).hasSize(11);
+ for (int i = 0; i < 11; i++) {
+ Segment slice = builder.segments.get(i);
+ assertThat(slice.getPayload().readableBytes()).isEqualTo(i == 10 ? 24 : 100);
+ assertThat(slice.isSelfContained()).isFalse();
+ }
+ }
+
+ @Test(groups = "unit")
+ public void should_succeed_parent_write_if_all_slices_successful() {
+ TestSegmentBuilder builder = new TestSegmentBuilder(CONTEXT, 100);
+
+ ChannelPromise parentPromise = newPromise();
+ builder.addRequest(_1KB_REQUEST, parentPromise);
+
+ assertThat(builder.segments).hasSize(11);
+ assertThat(builder.segmentPromises).hasSize(11);
+
+ for (int i = 0; i < 11; i++) {
+ assertThat(parentPromise.isDone()).isFalse();
+ builder.segmentPromises.get(i).setSuccess();
+ }
+
+ assertThat(parentPromise.isDone()).isTrue();
+ }
+
+ @Test(groups = "unit")
+ public void should_fail_parent_write_if_any_slice_fails() {
+ TestSegmentBuilder builder = new TestSegmentBuilder(CONTEXT, 100);
+
+ ChannelPromise parentPromise = newPromise();
+ builder.addRequest(_1KB_REQUEST, parentPromise);
+
+ assertThat(builder.segments).hasSize(11);
+
+ // Complete a few slices successfully
+ for (int i = 0; i < 5; i++) {
+ builder.segmentPromises.get(i).setSuccess();
+ }
+ assertThat(parentPromise.isDone()).isFalse();
+
+ // Fail a slice, the parent should fail immediately
+ Exception mockException = new Exception("test");
+ builder.segmentPromises.get(5).setFailure(mockException);
+ assertThat(parentPromise.isDone()).isTrue();
+ assertThat(parentPromise.cause()).isEqualTo(mockException);
+
+ // The remaining slices should have been cancelled
+ for (int i = 6; i < 11; i++) {
+ assertThat(builder.segmentPromises.get(i).isCancelled()).isTrue();
+ }
+ }
+
+ @Test(groups = "unit")
+ public void should_split_large_frame_when_exact_multiple() {
+ TestSegmentBuilder builder = new TestSegmentBuilder(CONTEXT, 256);
+
+ ChannelPromise parentPromise = newPromise();
+ builder.addRequest(_1KB_REQUEST, parentPromise);
+
+ assertThat(builder.segments).hasSize(4);
+ assertThat(builder.segmentPromises).hasSize(4);
+ for (int i = 0; i < 4; i++) {
+ Segment slice = builder.segments.get(i);
+ assertThat(slice.getPayload().readableBytes()).isEqualTo(256);
+ assertThat(slice.isSelfContained()).isFalse();
+ }
+ }
+
+ @Test(groups = "unit")
+ public void should_mix_small_frames_and_large_frames() {
+ TestSegmentBuilder builder = new TestSegmentBuilder(CONTEXT, 100);
+
+ ChannelPromise requestPromise1 = newPromise();
+ builder.addRequest(_38B_REQUEST, requestPromise1);
+ ChannelPromise requestPromise2 = newPromise();
+ builder.addRequest(_51B_REQUEST, requestPromise2);
+
+ // Large frame: process immediately, does not impact accumulated small frames
+ ChannelPromise requestPromise3 = newPromise();
+ builder.addRequest(_1KB_REQUEST, requestPromise3);
+ assertThat(builder.segments).hasSize(11);
+
+ // Another small frames bring us above the limit
+ ChannelPromise requestPromise4 = newPromise();
+ builder.addRequest(_38B_REQUEST, requestPromise4);
+ assertThat(builder.segments).hasSize(12);
+
+ // One last frame and finish
+ ChannelPromise requestPromise5 = newPromise();
+ builder.addRequest(_38B_REQUEST, requestPromise5);
+ builder.flush();
+ assertThat(builder.segments).hasSize(13);
+ assertThat(builder.segmentPromises).hasSize(13);
+
+ for (int i = 0; i < 11; i++) {
+ Segment slice = builder.segments.get(i);
+ assertThat(slice.getPayload().readableBytes()).isEqualTo(i == 10 ? 24 : 100);
+ assertThat(slice.isSelfContained()).isFalse();
+ }
+
+ Segment smallMessages1 = builder.segments.get(11);
+ assertThat(smallMessages1.getPayload().readableBytes()).isEqualTo(38 + 51);
+ assertThat(smallMessages1.isSelfContained()).isTrue();
+ ChannelPromise segmentPromise1 = builder.segmentPromises.get(11);
+ assertForwards(segmentPromise1, requestPromise1, requestPromise2);
+ Segment smallMessages2 = builder.segments.get(12);
+ assertThat(smallMessages2.getPayload().readableBytes()).isEqualTo(38 + 38);
+ assertThat(smallMessages2.isSelfContained()).isTrue();
+ ChannelPromise segmentPromise2 = builder.segmentPromises.get(12);
+ assertForwards(segmentPromise2, requestPromise4, requestPromise5);
+ }
+
+ private static ChannelPromise newPromise() {
+ return MOCK_CHANNEL.newPromise();
+ }
+
+ private void assertForwards(ChannelPromise segmentPromise, ChannelPromise... requestPromises) {
+ for (ChannelPromise requestPromise : requestPromises) {
+ assertThat(requestPromise.isDone()).isFalse();
+ }
+ segmentPromise.setSuccess();
+ for (ChannelPromise requestPromise : requestPromises) {
+ assertThat(requestPromise.isSuccess()).isTrue();
+ }
+ }
+
+ // Test implementation that simply stores segments and promises in the order they were produced.
+ static class TestSegmentBuilder extends SegmentBuilder {
+
+ List<Segment> segments = new ArrayList<Segment>();
+ List<ChannelPromise> segmentPromises = new ArrayList<ChannelPromise>();
+
+ TestSegmentBuilder(ChannelHandlerContext context, int maxPayloadLength) {
+ super(context, ByteBufAllocator.DEFAULT, REQUEST_ENCODER, maxPayloadLength);
+ }
+
+ @Override
+ protected void processSegment(Segment segment, ChannelPromise segmentPromise) {
+ segments.add(segment);
+ segmentPromises.add(segmentPromise);
+ }
+ }
+}
diff --git a/driver-core/src/test/java/com/datastax/driver/core/SegmentCodecTest.java b/driver-core/src/test/java/com/datastax/driver/core/SegmentCodecTest.java
new file mode 100644
index 0000000..1d81702
--- /dev/null
+++ b/driver-core/src/test/java/com/datastax/driver/core/SegmentCodecTest.java
@@ -0,0 +1,142 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.fail;
+
+import com.datastax.driver.core.ProtocolOptions.Compression;
+import com.datastax.driver.core.SegmentCodec.Header;
+import com.datastax.driver.core.exceptions.CrcMismatchException;
+import com.google.common.base.Strings;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.UnpooledByteBufAllocator;
+import org.testng.annotations.Test;
+
+public class SegmentCodecTest {
+
+ public static final SegmentCodec CODEC_NO_COMPRESSION =
+ new SegmentCodec(UnpooledByteBufAllocator.DEFAULT, Compression.NONE);
+ public static final SegmentCodec CODEC_LZ4 =
+ new SegmentCodec(UnpooledByteBufAllocator.DEFAULT, Compression.LZ4);
+
+ @Test(groups = "unit")
+ public void should_encode_uncompressed_header() {
+ ByteBuf header = CODEC_NO_COMPRESSION.encodeHeader(5, -1, true);
+
+ byte byte0 = header.getByte(2);
+ byte byte1 = header.getByte(1);
+ byte byte2 = header.getByte(0);
+
+ assertThat(bits(byte0) + bits(byte1) + bits(byte2))
+ .isEqualTo(
+ "000000" // padding (6 bits)
+ + "1" // selfContainedFlag
+ + "00000000000000101" // length (17 bits)
+ );
+ }
+
+ @Test(groups = "unit")
+ public void should_encode_compressed_header() {
+ ByteBuf header = CODEC_LZ4.encodeHeader(5, 12, true);
+
+ byte byte0 = header.getByte(4);
+ byte byte1 = header.getByte(3);
+ byte byte2 = header.getByte(2);
+ byte byte3 = header.getByte(1);
+ byte byte4 = header.getByte(0);
+
+ assertThat(bits(byte0) + bits(byte1) + bits(byte2) + bits(byte3) + bits(byte4))
+ .isEqualTo(
+ "00000" // padding (5 bits)
+ + "1" // selfContainedFlag
+ + "00000000000001100" // uncompressed length (17 bits)
+ + "00000000000000101" // compressed length (17 bits)
+ );
+ }
+
+ /**
+ * Checks that we correctly use 8 bytes when we left-shift the uncompressed length, to avoid
+ * overflows.
+ */
+ @Test(groups = "unit")
+ public void should_encode_compressed_header_when_aligned_uncompressed_length_overflows() {
+ ByteBuf header = CODEC_LZ4.encodeHeader(5, Segment.MAX_PAYLOAD_LENGTH, true);
+
+ byte byte0 = header.getByte(4);
+ byte byte1 = header.getByte(3);
+ byte byte2 = header.getByte(2);
+ byte byte3 = header.getByte(1);
+ byte byte4 = header.getByte(0);
+
+ assertThat(bits(byte0) + bits(byte1) + bits(byte2) + bits(byte3) + bits(byte4))
+ .isEqualTo(
+ "00000" // padding (5 bits)
+ + "1" // selfContainedFlag
+ + "11111111111111111" // uncompressed length (17 bits)
+ + "00000000000000101" // compressed length (17 bits)
+ );
+ }
+
+ @Test(groups = "unit")
+ public void should_decode_uncompressed_payload() {
+ // Assembling the test data manually would have little value because it would be very similar to
+ // our production code. So simply use that production code, assuming it's correct.
+ ByteBuf buffer = CODEC_NO_COMPRESSION.encodeHeader(5, -1, true);
+ Header header = CODEC_NO_COMPRESSION.decodeHeader(buffer);
+ assertThat(header.payloadLength).isEqualTo(5);
+ assertThat(header.uncompressedPayloadLength).isEqualTo(-1);
+ assertThat(header.isSelfContained).isTrue();
+ }
+
+ @Test(groups = "unit")
+ public void should_decode_compressed_payload() {
+ ByteBuf buffer = CODEC_LZ4.encodeHeader(5, 12, true);
+ Header header = CODEC_LZ4.decodeHeader(buffer);
+ assertThat(header.payloadLength).isEqualTo(5);
+ assertThat(header.uncompressedPayloadLength).isEqualTo(12);
+ assertThat(header.isSelfContained).isTrue();
+ }
+
+ @Test(groups = "unit")
+ public void should_fail_to_decode_if_corrupted() {
+ ByteBuf buffer = CODEC_NO_COMPRESSION.encodeHeader(5, -1, true);
+
+ // Flip a random byte
+ for (int bitOffset = 0; bitOffset < 47; bitOffset++) {
+ int byteOffset = bitOffset / 8;
+ int shift = bitOffset % 8;
+
+ ByteBuf slice = buffer.slice(buffer.readerIndex() + byteOffset, 1);
+ slice.markReaderIndex();
+ byte byteToCorrupt = slice.readByte();
+ slice.resetReaderIndex();
+ slice.writerIndex(slice.readerIndex());
+ slice.writeByte((byteToCorrupt & 0xFF) ^ (1 << shift));
+
+ try {
+ CODEC_NO_COMPRESSION.decodeHeader(buffer.duplicate());
+ fail("Expected CrcMismatchException");
+ } catch (CrcMismatchException e) {
+ // expected
+ }
+ }
+ }
+
+ private static String bits(byte b) {
+ return Strings.padStart(Integer.toBinaryString(b & 0xFF), 8, '0');
+ }
+}
diff --git a/driver-core/src/test/java/com/datastax/driver/core/SegmentToFrameDecoderTest.java b/driver-core/src/test/java/com/datastax/driver/core/SegmentToFrameDecoderTest.java
new file mode 100644
index 0000000..bef4a5b
--- /dev/null
+++ b/driver-core/src/test/java/com/datastax/driver/core/SegmentToFrameDecoderTest.java
@@ -0,0 +1,120 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.datastax.driver.core;
+
+import static com.datastax.driver.core.Message.Response.Type.READY;
+import static com.datastax.driver.core.Message.Response.Type.RESULT;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import com.datastax.driver.core.Frame.Header;
+import com.datastax.driver.core.Frame.Header.Flag;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.UnpooledByteBufAllocator;
+import io.netty.channel.embedded.EmbeddedChannel;
+import java.util.EnumSet;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+public class SegmentToFrameDecoderTest {
+
+ private static final ByteBuf SMALL_BODY_1 = buffer(128);
+ private static final Header SMALL_HEADER_1 =
+ new Header(
+ ProtocolVersion.V5,
+ EnumSet.noneOf(Flag.class),
+ 2,
+ READY.opcode,
+ SMALL_BODY_1.readableBytes());
+
+ private static final ByteBuf SMALL_BODY_2 = buffer(1024);
+ private static final Header SMALL_HEADER_2 =
+ new Header(
+ ProtocolVersion.V5,
+ EnumSet.noneOf(Flag.class),
+ 7,
+ RESULT.opcode,
+ SMALL_BODY_2.readableBytes());
+
+ private static final ByteBuf LARGE_BODY = buffer(256 * 1024);
+ private static final Header LARGE_HEADER =
+ new Header(
+ ProtocolVersion.V5,
+ EnumSet.noneOf(Flag.class),
+ 12,
+ RESULT.opcode,
+ LARGE_BODY.readableBytes());
+
+ private EmbeddedChannel channel;
+
+ @BeforeMethod(groups = "unit")
+ public void setup() {
+ channel = new EmbeddedChannel();
+ channel.pipeline().addLast(new SegmentToFrameDecoder());
+ }
+
+ @Test(groups = "unit")
+ public void should_decode_self_contained() {
+ ByteBuf payload = UnpooledByteBufAllocator.DEFAULT.buffer();
+ appendFrame(SMALL_HEADER_1, SMALL_BODY_1, payload);
+ appendFrame(SMALL_HEADER_2, SMALL_BODY_2, payload);
+
+ channel.writeInbound(new Segment(payload, true));
+
+ Frame frame1 = (Frame) channel.readInbound();
+ Header header1 = frame1.header;
+ assertThat(header1.streamId).isEqualTo(SMALL_HEADER_1.streamId);
+ assertThat(header1.opcode).isEqualTo(SMALL_HEADER_1.opcode);
+ assertThat(frame1.body).isEqualTo(SMALL_BODY_1);
+
+ Frame frame2 = (Frame) channel.readInbound();
+ Header header2 = frame2.header;
+ assertThat(header2.streamId).isEqualTo(SMALL_HEADER_2.streamId);
+ assertThat(header2.opcode).isEqualTo(SMALL_HEADER_2.opcode);
+ assertThat(frame2.body).isEqualTo(SMALL_BODY_2);
+ }
+
+ @Test(groups = "unit")
+ public void should_decode_sequence_of_slices() {
+ ByteBuf encodedFrame = UnpooledByteBufAllocator.DEFAULT.buffer();
+ appendFrame(LARGE_HEADER, LARGE_BODY, encodedFrame);
+
+ do {
+ ByteBuf payload =
+ encodedFrame.readSlice(
+ Math.min(Segment.MAX_PAYLOAD_LENGTH, encodedFrame.readableBytes()));
+ channel.writeInbound(new Segment(payload, false));
+ } while (encodedFrame.isReadable());
+
+ Frame frame = (Frame) channel.readInbound();
+ Header header = frame.header;
+ assertThat(header.streamId).isEqualTo(LARGE_HEADER.streamId);
+ assertThat(header.opcode).isEqualTo(LARGE_HEADER.opcode);
+ assertThat(frame.body).isEqualTo(LARGE_BODY);
+ }
+
+ private static final ByteBuf buffer(int length) {
+ ByteBuf buffer = UnpooledByteBufAllocator.DEFAULT.buffer(length);
+ // Contents don't really matter, keep all zeroes
+ buffer.writerIndex(buffer.readerIndex() + length);
+ return buffer;
+ }
+
+ private static void appendFrame(Header frameHeader, ByteBuf frameBody, ByteBuf payload) {
+ frameHeader.encodeInto(payload);
+ // this method doesn't affect the body's indices:
+ payload.writeBytes(frameBody, frameBody.readerIndex(), frameBody.readableBytes());
+ }
+}