| |
| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.cassandra.transport; |
| |
| import java.io.IOException; |
| import java.nio.ByteBuffer; |
| import java.util.EnumSet; |
| import java.util.List; |
| |
| import com.google.common.annotations.VisibleForTesting; |
| import com.google.common.base.Preconditions; |
| |
| import io.netty.buffer.ByteBuf; |
| import io.netty.buffer.Unpooled; |
| import io.netty.channel.*; |
| import io.netty.handler.codec.ByteToMessageDecoder; |
| import io.netty.handler.codec.MessageToMessageDecoder; |
| import io.netty.handler.codec.MessageToMessageEncoder; |
| import org.apache.cassandra.config.DatabaseDescriptor; |
| import org.apache.cassandra.exceptions.InvalidRequestException; |
| import org.apache.cassandra.metrics.ClientMessageSizeMetrics; |
| import org.apache.cassandra.transport.messages.ErrorMessage; |
| import org.apache.cassandra.utils.ByteBufferUtil; |
| |
| public class Envelope |
| { |
/** Mask applied to the first header byte to isolate the protocol version number. */
public static final byte PROTOCOL_VERSION_MASK = 0x7f;

/** Decoded (or to-be-encoded) header describing this envelope. */
public final Header header;

/** Message body; a Netty reference-counted buffer, see {@link #retain()} and {@link #release()}. */
public final ByteBuf body;

/**
 * An on-wire message envelope consists of a header and a body.
 *
 * The header is defined the following way in native protocol version 3 and later:
 *
 *   0         8        16        24        32         40
 *   +---------+---------+---------+---------+---------+
 *   | version |  flags  |      stream       | opcode  |
 *   +---------+---------+---------+---------+---------+
 *   |                length                 |
 *   +---------+---------+---------+---------+
 *
 * @param header the envelope header
 * @param body   the envelope body; stored as-is, not copied
 */
public Envelope(Header header, ByteBuf body)
{
    this.header = header;
    this.body = body;
}
| |
| public void retain() |
| { |
| body.retain(); |
| } |
| |
| public boolean release() |
| { |
| return body.release(); |
| } |
| |
| @VisibleForTesting |
| public Envelope clone() |
| { |
| return new Envelope(header, Unpooled.wrappedBuffer(ByteBufferUtil.clone(body.nioBuffer()))); |
| } |
| |
| public static Envelope create(Message.Type type, int streamId, ProtocolVersion version, EnumSet<Header.Flag> flags, ByteBuf body) |
| { |
| Header header = new Header(version, flags, streamId, type, body.readableBytes()); |
| return new Envelope(header, body); |
| } |
| |
| // used by V4 and earlier in Encoder.encode |
| public ByteBuf encodeHeader() |
| { |
| ByteBuf buf = CBUtil.allocator.buffer(Header.LENGTH); |
| |
| Message.Type type = header.type; |
| buf.writeByte(type.direction.addToVersion(header.version.asInt())); |
| buf.writeByte(Header.Flag.serialize(header.flags)); |
| |
| // Continue to support writing pre-v3 headers so that we can give proper error messages to drivers that |
| // connect with the v1/v2 protocol. See CASSANDRA-11464. |
| if (header.version.isGreaterOrEqualTo(ProtocolVersion.V3)) |
| buf.writeShort(header.streamId); |
| else |
| buf.writeByte(header.streamId); |
| |
| buf.writeByte(type.opcode); |
| buf.writeInt(body.readableBytes()); |
| return buf; |
| } |
| |
| // Used by V5 and later |
| public void encodeHeaderInto(ByteBuffer buf) |
| { |
| buf.put((byte) header.type.direction.addToVersion(header.version.asInt())); |
| buf.put((byte) Envelope.Header.Flag.serialize(header.flags)); |
| |
| if (header.version.isGreaterOrEqualTo(ProtocolVersion.V3)) |
| buf.putShort((short) header.streamId); |
| else |
| buf.put((byte) header.streamId); |
| |
| buf.put((byte) header.type.opcode); |
| buf.putInt(body.readableBytes()); |
| } |
| |
| // Used by V5 and later |
| public void encodeInto(ByteBuffer buf) |
| { |
| encodeHeaderInto(buf); |
| buf.put(body.nioBuffer()); |
| } |
| |
| public static class Header |
| { |
/** Serialized header size in bytes: 9 in protocol version 3 and later. */
public static final int LENGTH = 9;

/** Width in bytes of the body-length field that ends the header. */
public static final int BODY_LENGTH_SIZE = 4;

public final ProtocolVersion version;
public final EnumSet<Flag> flags;
public final int streamId;
public final Message.Type type;
public final long bodySizeInBytes;

/**
 * Plain value constructor; every argument is stored as supplied (the flag set is not copied).
 */
private Header(ProtocolVersion version, EnumSet<Flag> flags, int streamId, Message.Type type, long bodySizeInBytes)
{
    this.version = version;
    this.flags = flags;
    this.streamId = streamId;
    this.type = type;
    this.bodySizeInBytes = bodySizeInBytes;
}
| |
/**
 * Header flags. Each constant maps to the bit at position {@code ordinal()} of the
 * serialized flags value, so the declaration order is part of the wire format.
 */
public enum Flag
{
    // The order of that enum matters!!
    COMPRESSED,
    TRACING,
    CUSTOM_PAYLOAD,
    WARNING,
    USE_BETA;

    // Cached once so that deserialize does not call values() on every invocation
    private static final Flag[] ALL_VALUES = values();

    /**
     * Converts a flags bitmask into the set of flags whose bit is set.
     * Bits that do not correspond to a declared flag are ignored.
     *
     * @param flags bitmask as read from the header
     * @return a new, mutable set of the flags present in {@code flags}
     */
    public static EnumSet<Flag> deserialize(int flags)
    {
        final EnumSet<Flag> decoded = EnumSet.noneOf(Flag.class);
        for (Flag candidate : ALL_VALUES)
        {
            final int bit = 1 << candidate.ordinal();
            if ((flags & bit) != 0)
            {
                decoded.add(candidate);
            }
        }
        return decoded;
    }

    /**
     * Converts a set of flags into its bitmask form.
     *
     * @param flags the flags to encode
     * @return a value with bit {@code ordinal()} set for every flag in the set
     */
    public static int serialize(EnumSet<Flag> flags)
    {
        int mask = 0;
        for (Flag present : flags)
        {
            mask |= (1 << present.ordinal());
        }
        return mask;
    }
}
| } |
| |
| public Envelope with(ByteBuf newBody) |
| { |
| return new Envelope(header, newBody); |
| } |
| |
| public static class Decoder extends ByteToMessageDecoder |
| { |
| private static final int MAX_TOTAL_LENGTH = DatabaseDescriptor.getNativeTransportMaxFrameSize(); |
| |
| private boolean discardingTooLongMessage; |
| private long tooLongTotalLength; |
| private long bytesToDiscard; |
| private int tooLongStreamId; |
| |
| /** |
| * Used by protocol V5 and later to extract a CQL message header from the buffer containing |
| * it, without modifying the position of the underlying buffer. This essentially mirrors the |
| * pre-V5 code in {@link Decoder#decode(ByteBuf)}, with three differences: |
| * <ul> |
| * <li>The input is a ByteBuffer rather than a ByteBuf</li> |
| * <li>This cannot return null, as V5 always deals with entire CQL messages. Coalescing of bytes |
| * off the wire happens at the layer below, in {@link org.apache.cassandra.net.FrameDecoder}</li> |
| * <li>This method never throws {@link ProtocolException}. Instead, a subclass of |
| * {@link HeaderExtractionResult} is returned which may provide either a {@link Header} or a |
| * {@link ProtocolException},depending on the result of its {@link HeaderExtractionResult#isSuccess()} |
| * method.</li> |
| *</ul> |
| * |
| * @param buffer ByteBuffer containing the message envelope |
| * @return The result of attempting to extract a header from the input buffer. |
| */ |
| HeaderExtractionResult extractHeader(ByteBuffer buffer) |
| { |
| Preconditions.checkArgument(buffer.remaining() >= Header.LENGTH, |
| "Undersized buffer supplied. Expected %s, actual %s", |
| Header.LENGTH, |
| buffer.remaining()); |
| int idx = buffer.position(); |
| int firstByte = buffer.get(idx++); |
| int versionNum = firstByte & PROTOCOL_VERSION_MASK; |
| int flags = buffer.get(idx++); |
| int streamId = buffer.getShort(idx); |
| idx += 2; |
| int opcode = buffer.get(idx++); |
| long bodyLength = buffer.getInt(idx); |
| |
| // if a negative length is read, return error but report length as 0 so we don't attempt to skip |
| if (bodyLength < 0) |
| return new HeaderExtractionResult.Error(new ProtocolException("Invalid value for envelope header body length field: " + bodyLength), |
| streamId, bodyLength); |
| |
| Message.Direction direction = Message.Direction.extractFromVersion(firstByte); |
| Message.Type type; |
| ProtocolVersion version; |
| EnumSet<Header.Flag> decodedFlags; |
| try |
| { |
| // This throws a protocol exception if the version number is unsupported, |
| // the opcode is unknown or invalid flags are set for the version |
| version = ProtocolVersion.decode(versionNum, DatabaseDescriptor.getNativeTransportAllowOlderProtocols()); |
| decodedFlags = decodeFlags(version, flags); |
| type = Message.Type.fromOpcode(opcode, direction); |
| return new HeaderExtractionResult.Success(new Header(version, decodedFlags, streamId, type, bodyLength)); |
| } |
| catch (ProtocolException e) |
| { |
| // Including the streamId and bodyLength is a best effort to allow the caller |
| // to send a meaningful response to the client and continue processing the |
| // rest of the frame. It's possible that these are bogus and may have contributed |
| // to the ProtocolException. If so, the upstream CQLMessageHandler should run into |
| // further errors and once it breaches its threshold for consecutive errors, it will |
| // cause the channel to be closed. |
| return new HeaderExtractionResult.Error(e, streamId, bodyLength); |
| } |
| } |
| |
| public static abstract class HeaderExtractionResult |
| { |
| enum Outcome { SUCCESS, ERROR }; |
| |
| private final Outcome outcome; |
| private final int streamId; |
| private final long bodyLength; |
| private HeaderExtractionResult(Outcome outcome, int streamId, long bodyLength) |
| { |
| this.outcome = outcome; |
| this.streamId = streamId; |
| this.bodyLength = bodyLength; |
| } |
| |
| boolean isSuccess() |
| { |
| return outcome == Outcome.SUCCESS; |
| } |
| |
| int streamId() |
| { |
| return streamId; |
| } |
| |
| long bodyLength() |
| { |
| return bodyLength; |
| } |
| |
| Header header() |
| { |
| throw new IllegalStateException(String.format("Unable to provide header from extraction result : %s", outcome)); |
| }; |
| |
| ProtocolException error() |
| { |
| throw new IllegalStateException(String.format("Unable to provide error from extraction result : %s", outcome)); |
| } |
| |
| private static class Success extends HeaderExtractionResult |
| { |
| private final Header header; |
| Success(Header header) |
| { |
| super(Outcome.SUCCESS, header.streamId, header.bodySizeInBytes); |
| this.header = header; |
| } |
| |
| @Override |
| Header header() |
| { |
| return header; |
| } |
| } |
| |
| private static class Error extends HeaderExtractionResult |
| { |
| private final ProtocolException error; |
| private Error(ProtocolException error, int streamId, long bodyLength) |
| { |
| super(Outcome.ERROR, streamId, bodyLength); |
| this.error = error; |
| } |
| |
| @Override |
| ProtocolException error() |
| { |
| return error; |
| } |
| } |
| } |
| |
| @VisibleForTesting |
| Envelope decode(ByteBuf buffer) |
| { |
| if (discardingTooLongMessage) |
| { |
| bytesToDiscard = discard(buffer, bytesToDiscard); |
| // If we have discarded everything, throw the exception |
| if (bytesToDiscard <= 0) |
| fail(); |
| return null; |
| } |
| |
| int readableBytes = buffer.readableBytes(); |
| if (readableBytes == 0) |
| return null; |
| |
| int idx = buffer.readerIndex(); |
| |
| // Check the first byte for the protocol version before we wait for a complete header. Protocol versions |
| // 1 and 2 use a shorter header, so we may never have a complete header's worth of bytes. |
| int firstByte = buffer.getByte(idx++); |
| Message.Direction direction = Message.Direction.extractFromVersion(firstByte); |
| int versionNum = firstByte & PROTOCOL_VERSION_MASK; |
| |
| ProtocolVersion version; |
| |
| try |
| { |
| version = ProtocolVersion.decode(versionNum, DatabaseDescriptor.getNativeTransportAllowOlderProtocols()); |
| } |
| catch (ProtocolException e) |
| { |
| // Skip the remaining useless bytes. Otherwise the channel closing logic may try to decode again. |
| buffer.skipBytes(readableBytes); |
| throw e; |
| } |
| |
| // Wait until we have the complete header |
| if (readableBytes < Header.LENGTH) |
| return null; |
| |
| int flags = buffer.getByte(idx++); |
| EnumSet<Header.Flag> decodedFlags = decodeFlags(version, flags); |
| |
| int streamId = buffer.getShort(idx); |
| idx += 2; |
| |
| // This throws a protocol exceptions if the opcode is unknown |
| Message.Type type; |
| try |
| { |
| type = Message.Type.fromOpcode(buffer.getByte(idx++), direction); |
| } |
| catch (ProtocolException e) |
| { |
| throw ErrorMessage.wrap(e, streamId); |
| } |
| |
| long bodyLength = buffer.getUnsignedInt(idx); |
| idx += Header.BODY_LENGTH_SIZE; |
| |
| long totalLength = bodyLength + Header.LENGTH; |
| if (totalLength > MAX_TOTAL_LENGTH) |
| { |
| // Enter the discard mode and discard everything received so far. |
| discardingTooLongMessage = true; |
| tooLongStreamId = streamId; |
| tooLongTotalLength = totalLength; |
| bytesToDiscard = discard(buffer, totalLength); |
| if (bytesToDiscard <= 0) |
| fail(); |
| return null; |
| } |
| |
| if (buffer.readableBytes() < totalLength) |
| return null; |
| |
| ClientMessageSizeMetrics.bytesReceived.inc(totalLength); |
| ClientMessageSizeMetrics.bytesReceivedPerRequest.update(totalLength); |
| |
| // extract body |
| ByteBuf body = buffer.slice(idx, (int) bodyLength); |
| body.retain(); |
| |
| idx += bodyLength; |
| buffer.readerIndex(idx); |
| |
| return new Envelope(new Header(version, decodedFlags, streamId, type, bodyLength), body); |
| } |
| |
| private EnumSet<Header.Flag> decodeFlags(ProtocolVersion version, int flags) |
| { |
| EnumSet<Header.Flag> decodedFlags = Header.Flag.deserialize(flags); |
| |
| if (version.isBeta() && !decodedFlags.contains(Header.Flag.USE_BETA)) |
| throw new ProtocolException(String.format("Beta version of the protocol used (%s), but USE_BETA flag is unset", version), |
| version); |
| return decodedFlags; |
| } |
| |
| @Override |
| protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> results) |
| { |
| Envelope envelope = decode(buffer); |
| if (envelope == null) |
| return; |
| |
| results.add(envelope); |
| } |
| |
| private void fail() |
| { |
| // Reset to the initial state and throw the exception |
| long tooLongTotalLength = this.tooLongTotalLength; |
| this.tooLongTotalLength = 0; |
| discardingTooLongMessage = false; |
| String msg = String.format("Request is too big: length %d exceeds maximum allowed length %d.", tooLongTotalLength, MAX_TOTAL_LENGTH); |
| throw ErrorMessage.wrap(new InvalidRequestException(msg), tooLongStreamId); |
| } |
| } |
| |
| // How much remains to be discarded |
| private static long discard(ByteBuf buffer, long remainingToDiscard) |
| { |
| int availableToDiscard = (int) Math.min(remainingToDiscard, buffer.readableBytes()); |
| buffer.skipBytes(availableToDiscard); |
| return remainingToDiscard - availableToDiscard; |
| } |
| |
| @ChannelHandler.Sharable |
| public static class Encoder extends MessageToMessageEncoder<Envelope> |
| { |
| public static final Encoder instance = new Envelope.Encoder(); |
| private Encoder(){} |
| |
| public void encode(ChannelHandlerContext ctx, Envelope source, List<Object> results) |
| { |
| ByteBuf serializedHeader = source.encodeHeader(); |
| int messageSize = serializedHeader.readableBytes() + source.body.readableBytes(); |
| ClientMessageSizeMetrics.bytesSent.inc(messageSize); |
| ClientMessageSizeMetrics.bytesSentPerResponse.update(messageSize); |
| |
| results.add(serializedHeader); |
| results.add(source.body); |
| } |
| } |
| |
| @ChannelHandler.Sharable |
| public static class Decompressor extends MessageToMessageDecoder<Envelope> |
| { |
| public static Decompressor instance = new Envelope.Decompressor(); |
| private Decompressor(){} |
| |
| public void decode(ChannelHandlerContext ctx, Envelope source, List<Object> results) |
| throws IOException |
| { |
| Connection connection = ctx.channel().attr(Connection.attributeKey).get(); |
| |
| if (!source.header.flags.contains(Header.Flag.COMPRESSED) || connection == null) |
| { |
| results.add(source); |
| return; |
| } |
| |
| org.apache.cassandra.transport.Compressor compressor = connection.getCompressor(); |
| if (compressor == null) |
| { |
| results.add(source); |
| return; |
| } |
| |
| results.add(compressor.decompress(source)); |
| } |
| } |
| |
| @ChannelHandler.Sharable |
| public static class Compressor extends MessageToMessageEncoder<Envelope> |
| { |
| public static Compressor instance = new Compressor(); |
| private Compressor(){} |
| |
| public void encode(ChannelHandlerContext ctx, Envelope source, List<Object> results) |
| throws IOException |
| { |
| Connection connection = ctx.channel().attr(Connection.attributeKey).get(); |
| |
| // Never compress STARTUP messages |
| if (source.header.type == Message.Type.STARTUP || connection == null) |
| { |
| results.add(source); |
| return; |
| } |
| |
| org.apache.cassandra.transport.Compressor compressor = connection.getCompressor(); |
| if (compressor == null) |
| { |
| results.add(source); |
| return; |
| } |
| source.header.flags.add(Header.Flag.COMPRESSED); |
| results.add(compressor.compress(source)); |
| } |
| } |
| } |