blob: 6b15a2a38b0c70b07b503029ec54baffa40d2ed1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.transport;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.EnumSet;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.MessageToMessageEncoder;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.metrics.ClientMessageSizeMetrics;
import org.apache.cassandra.transport.messages.ErrorMessage;
import org.apache.cassandra.utils.ByteBufferUtil;
public class Envelope
{
// Mask applied to the first header byte to obtain the protocol version number.
// The same byte is also handed to Message.Direction.extractFromVersion by the decoder.
public static final byte PROTOCOL_VERSION_MASK = 0x7f;
// Header of this envelope: protocol version, flags, stream id, message type and body size.
public final Header header;
// Body bytes of this envelope; retain() and release() delegate to this buffer.
public final ByteBuf body;
/**
 * Builds an envelope from an already constructed header and body.
 *
 * An on-wire message envelope consists of a header followed by a body.
 * In native protocol version 3 and later the header is laid out as follows:
 *
 *   0         8        16        24        32         40
 *   +---------+---------+---------+---------+---------+
 *   | version |  flags  |      stream       | opcode  |
 *   +---------+---------+---------+---------+---------+
 *   |                length                 |
 *   +---------+---------+---------+---------+
 *
 * @param header the envelope header
 * @param body   the envelope body; stored as-is, no copy is made and the reference count is not changed
 */
public Envelope(Header header, ByteBuf body)
{
    this.body = body;
    this.header = header;
}
/**
 * Retains the body buffer by delegating to {@code body.retain()}.
 */
public void retain()
{
    this.body.retain();
}
/**
 * Releases the body buffer by delegating to {@code body.release()}.
 *
 * @return whatever {@code body.release()} returns
 */
public boolean release()
{
    final boolean released = this.body.release();
    return released;
}
/**
 * Creates a new envelope that shares this envelope's header instance and whose body wraps the result
 * of {@code ByteBufferUtil.clone} applied to the current body (presumably an independent copy of the
 * body bytes - confirm against ByteBufferUtil).
 *
 * @return the new envelope
 */
@VisibleForTesting
public Envelope clone()
{
    ByteBuf copiedBody = Unpooled.wrappedBuffer(ByteBufferUtil.clone(body.nioBuffer()));
    return new Envelope(header, copiedBody);
}
/**
 * Creates an envelope for the given body, building its header from the supplied values.
 * The body size recorded in the header is {@code body.readableBytes()} at the time of the call.
 *
 * @param type     message type to record in the header
 * @param streamId stream id to record in the header
 * @param version  protocol version to record in the header
 * @param flags    header flags; the set is stored as-is, not copied
 * @param body     envelope body
 * @return the new envelope
 */
public static Envelope create(Message.Type type, int streamId, ProtocolVersion version, EnumSet<Header.Flag> flags, ByteBuf body)
{
    return new Envelope(new Header(version, flags, streamId, type, body.readableBytes()), body);
}
/**
 * Serializes this envelope's header into a newly allocated buffer.
 * Used by V4 and earlier in Encoder.encode.
 *
 * The length field is taken from {@code body.readableBytes()} at the time of the call.
 *
 * @return a buffer obtained from {@code CBUtil.allocator} holding the encoded header
 */
public ByteBuf encodeHeader()
{
    ByteBuf out = CBUtil.allocator.buffer(Header.LENGTH);
    Message.Type messageType = header.type;

    out.writeByte(messageType.direction.addToVersion(header.version.asInt()));
    out.writeByte(Header.Flag.serialize(header.flags));

    // Pre-v3 headers carry a one byte stream id. We keep writing them so that drivers connecting
    // with the v1/v2 protocol receive a proper error message. See CASSANDRA-11464.
    if (header.version.isGreaterOrEqualTo(ProtocolVersion.V3))
    {
        out.writeShort(header.streamId);
    }
    else
    {
        out.writeByte(header.streamId);
    }

    out.writeByte(messageType.opcode);
    out.writeInt(body.readableBytes());
    return out;
}
/**
 * Writes this envelope's header into {@code buf} at its current position, advancing the position.
 * Used by V5 and later.
 *
 * The stream id is written as a short for protocol V3 and above, and as a single byte otherwise.
 * The length field is taken from {@code body.readableBytes()} at the time of the call.
 *
 * @param buf destination buffer
 */
public void encodeHeaderInto(ByteBuffer buf)
{
    final Message.Type messageType = header.type;

    buf.put((byte) messageType.direction.addToVersion(header.version.asInt()));
    buf.put((byte) Envelope.Header.Flag.serialize(header.flags));

    if (header.version.isGreaterOrEqualTo(ProtocolVersion.V3))
    {
        buf.putShort((short) header.streamId);
    }
    else
    {
        buf.put((byte) header.streamId);
    }

    buf.put((byte) messageType.opcode);
    buf.putInt(body.readableBytes());
}
/**
 * Writes the complete envelope, header first and then body, into {@code buf}.
 * Used by V5 and later.
 *
 * @param buf destination buffer
 */
public void encodeInto(ByteBuffer buf)
{
    this.encodeHeaderInto(buf);
    ByteBuffer bodyBytes = body.nioBuffer();
    buf.put(bodyBytes);
}
public static class Header
{
// 9 bytes in protocol version 3 and later
public static final int LENGTH = 9;
// Size in bytes of the body length field, the last field of the header.
public static final int BODY_LENGTH_SIZE = 4;
// Protocol version of the message.
public final ProtocolVersion version;
// Flags set on the message; this set is mutable and is stored without copying.
public final EnumSet<Flag> flags;
// Stream id of the message.
public final int streamId;
// Message type; supplies the opcode and direction used when encoding the header.
public final Message.Type type;
// Body length as supplied at construction: body.readableBytes() in Envelope.create,
// or the length field read off the wire in Decoder.
public final long bodySizeInBytes;
/**
 * Creates a header from its individual fields. All arguments are stored as-is; in particular the
 * flags set is not copied.
 *
 * @param version         protocol version
 * @param flags           header flags
 * @param streamId        stream id
 * @param type            message type
 * @param bodySizeInBytes length of the body in bytes
 */
private Header(ProtocolVersion version, EnumSet<Flag> flags, int streamId, Message.Type type, long bodySizeInBytes)
{
    this.type = type;
    this.version = version;
    this.streamId = streamId;
    this.flags = flags;
    this.bodySizeInBytes = bodySizeInBytes;
}
/**
 * Header flags. Each constant maps to the bit at position {@code ordinal()} in the serialized
 * flags value, so the declaration order is part of the wire format.
 */
public enum Flag
{
    // The order of that enum matters!!
    COMPRESSED,
    TRACING,
    CUSTOM_PAYLOAD,
    WARNING,
    USE_BETA;

    private static final Flag[] ALL_VALUES = values();

    /**
     * Decodes a flags bit field into the set of flags whose bit is set.
     * Bits that do not correspond to a declared flag are ignored.
     *
     * @param flags bit field to decode
     * @return a new, mutable set holding the decoded flags
     */
    public static EnumSet<Flag> deserialize(int flags)
    {
        EnumSet<Flag> decoded = EnumSet.noneOf(Flag.class);
        for (Flag candidate : ALL_VALUES)
        {
            int bit = 1 << candidate.ordinal();
            if ((flags & bit) == bit)
                decoded.add(candidate);
        }
        return decoded;
    }

    /**
     * Encodes a set of flags as a bit field, setting the bit at each flag's ordinal position.
     *
     * @param flags flags to encode
     * @return the bit field
     */
    public static int serialize(EnumSet<Flag> flags)
    {
        int bits = 0;
        for (Flag candidate : ALL_VALUES)
        {
            if (flags.contains(candidate))
                bits |= 1 << candidate.ordinal();
        }
        return bits;
    }
}
}
/**
 * Returns a new envelope that shares this envelope's header instance but carries {@code newBody}.
 * The header is not recomputed, so its recorded body size is left unchanged.
 *
 * @param newBody body of the new envelope
 * @return the new envelope
 */
public Envelope with(ByteBuf newBody)
{
    final Header sharedHeader = this.header;
    return new Envelope(sharedHeader, newBody);
}
public static class Decoder extends ByteToMessageDecoder
{
// Upper bound on header + body length of a single envelope; read once, when this class is
// initialized, from DatabaseDescriptor.getNativeTransportMaxFrameSize().
private static final int MAX_TOTAL_LENGTH = DatabaseDescriptor.getNativeTransportMaxFrameSize();
// The fields below are the state decode(ByteBuf) keeps while skipping an oversized envelope.
// True while bytes of an envelope that exceeded MAX_TOTAL_LENGTH are still being skipped.
private boolean discardingTooLongMessage;
// Total (header + body) length of the oversized envelope; included in the error message built by fail().
private long tooLongTotalLength;
// Number of bytes of the oversized envelope that still have to be skipped.
private long bytesToDiscard;
// Stream id of the oversized envelope; attached to the error raised by fail().
private int tooLongStreamId;
/**
 * Used by protocol V5 and later to extract a CQL message header from the buffer containing
 * it, without modifying the position of the underlying buffer. This essentially mirrors the
 * pre-V5 code in {@link Decoder#decode(ByteBuf)}, with three differences:
 * <ul>
 *  <li>The input is a ByteBuffer rather than a ByteBuf</li>
 *  <li>This cannot return null, as V5 always deals with entire CQL messages. Coalescing of bytes
 *      off the wire happens at the layer below, in {@link org.apache.cassandra.net.FrameDecoder}</li>
 *  <li>This method never throws {@link ProtocolException}. Instead, a subclass of
 *      {@link HeaderExtractionResult} is returned which may provide either a {@link Header} or a
 *      {@link ProtocolException}, depending on the result of its {@link HeaderExtractionResult#isSuccess()}
 *      method.</li>
 * </ul>
 *
 * @param buffer ByteBuffer containing the message envelope; must have at least
 *               {@link Header#LENGTH} bytes remaining
 * @return The result of attempting to extract a header from the input buffer.
 * @throws IllegalArgumentException if fewer than {@link Header#LENGTH} bytes remain in the buffer
 */
HeaderExtractionResult extractHeader(ByteBuffer buffer)
{
    Preconditions.checkArgument(buffer.remaining() >= Header.LENGTH,
                                "Undersized buffer supplied. Expected %s, actual %s",
                                Header.LENGTH,
                                buffer.remaining());

    // Only absolute reads are used so that the buffer's position is left untouched.
    final int start = buffer.position();
    final int versionByte = buffer.get(start);
    final int rawFlags = buffer.get(start + 1);
    final int streamId = buffer.getShort(start + 2);
    final int opcode = buffer.get(start + 4);
    final long bodyLength = buffer.getInt(start + 5);

    // A negative length makes the header unusable: return an error straight away. The negative
    // value is passed through unchanged in the result.
    // NOTE(review): an earlier comment here said the length is reported as 0; the code does not do
    // that - presumably the caller relies on seeing the negative value, confirm against the caller.
    if (bodyLength < 0)
    {
        ProtocolException invalidLength = new ProtocolException("Invalid value for envelope header body length field: " + bodyLength);
        return new HeaderExtractionResult.Error(invalidLength, streamId, bodyLength);
    }

    final int versionNum = versionByte & PROTOCOL_VERSION_MASK;
    final Message.Direction direction = Message.Direction.extractFromVersion(versionByte);

    try
    {
        // Any of the three calls below may throw a protocol exception: when the version number is
        // unsupported, when invalid flags are set for the version, or when the opcode is unknown.
        ProtocolVersion version = ProtocolVersion.decode(versionNum, DatabaseDescriptor.getNativeTransportAllowOlderProtocols());
        EnumSet<Header.Flag> decodedFlags = decodeFlags(version, rawFlags);
        Message.Type type = Message.Type.fromOpcode(opcode, direction);
        return new HeaderExtractionResult.Success(new Header(version, decodedFlags, streamId, type, bodyLength));
    }
    catch (ProtocolException e)
    {
        // Including the streamId and bodyLength is a best effort to allow the caller
        // to send a meaningful response to the client and continue processing the
        // rest of the frame. It's possible that these are bogus and may have contributed
        // to the ProtocolException. If so, the upstream CQLMessageHandler should run into
        // further errors and once it breaches its threshold for consecutive errors, it will
        // cause the channel to be closed.
        return new HeaderExtractionResult.Error(e, streamId, bodyLength);
    }
}
/**
 * Outcome of {@code extractHeader}: either a successfully decoded {@link Header} or the
 * {@link ProtocolException} that prevented decoding. The stream id and body length read from
 * the wire are available in both cases.
 */
public static abstract class HeaderExtractionResult
{
    enum Outcome { SUCCESS, ERROR }

    private final Outcome outcome;
    private final int streamId;
    private final long bodyLength;

    private HeaderExtractionResult(Outcome outcome, int streamId, long bodyLength)
    {
        this.outcome = outcome;
        this.streamId = streamId;
        this.bodyLength = bodyLength;
    }

    /** @return true if a header was extracted, false if extraction produced an error */
    boolean isSuccess()
    {
        return Outcome.SUCCESS == outcome;
    }

    /** @return the stream id read from the header bytes */
    int streamId()
    {
        return this.streamId;
    }

    /** @return the body length read from the header bytes */
    long bodyLength()
    {
        return this.bodyLength;
    }

    /**
     * @return the extracted header
     * @throws IllegalStateException unless overridden by a result that carries a header
     */
    Header header()
    {
        String message = String.format("Unable to provide header from extraction result : %s", outcome);
        throw new IllegalStateException(message);
    }

    /**
     * @return the error that prevented extraction
     * @throws IllegalStateException unless overridden by a result that carries an error
     */
    ProtocolException error()
    {
        String message = String.format("Unable to provide error from extraction result : %s", outcome);
        throw new IllegalStateException(message);
    }

    /** Result carrying a decoded header; stream id and body length are taken from that header. */
    private static class Success extends HeaderExtractionResult
    {
        private final Header header;

        Success(Header header)
        {
            super(Outcome.SUCCESS, header.streamId, header.bodySizeInBytes);
            this.header = header;
        }

        @Override
        Header header()
        {
            return this.header;
        }
    }

    /** Result carrying the protocol exception raised while decoding the header. */
    private static class Error extends HeaderExtractionResult
    {
        private final ProtocolException error;

        private Error(ProtocolException error, int streamId, long bodyLength)
        {
            super(Outcome.ERROR, streamId, bodyLength);
            this.error = error;
        }

        @Override
        ProtocolException error()
        {
            return this.error;
        }
    }
}
/**
 * Attempts to decode a single envelope from the readable bytes of {@code buffer}.
 *
 * Header fields are read with absolute-index getters, so the reader index only moves once a
 * complete envelope has been consumed, or while the bytes of an oversized envelope are being skipped.
 *
 * If the header announces a total length (header + body) greater than {@code MAX_TOTAL_LENGTH},
 * the decoder enters a discard mode: that many bytes are skipped, possibly over several calls,
 * and once they have all been skipped {@code fail()} throws.
 *
 * @param buffer the bytes received so far
 * @return the decoded envelope, whose body is a retained slice of {@code buffer}; or null when
 *         the buffer is empty, the header or body is not yet complete, or bytes are still being discarded
 * @throws ProtocolException if the beta-flag check in {@code decodeFlags} fails; the version decoding
 *         call is also expected to throw it for unsupported versions. An unknown opcode is rethrown
 *         through {@code ErrorMessage.wrap} together with the stream id.
 */
@VisibleForTesting
Envelope decode(ByteBuf buffer)
{
// Still skipping the remainder of an oversized envelope detected by an earlier call.
if (discardingTooLongMessage)
{
bytesToDiscard = discard(buffer, bytesToDiscard);
// If we have discarded everything, throw the exception
if (bytesToDiscard <= 0)
fail();
return null;
}
int readableBytes = buffer.readableBytes();
if (readableBytes == 0)
return null;
int idx = buffer.readerIndex();
// Check the first byte for the protocol version before we wait for a complete header. Protocol versions
// 1 and 2 use a shorter header, so we may never have a complete header's worth of bytes.
int firstByte = buffer.getByte(idx++);
Message.Direction direction = Message.Direction.extractFromVersion(firstByte);
int versionNum = firstByte & PROTOCOL_VERSION_MASK;
ProtocolVersion version = ProtocolVersion.decode(versionNum, DatabaseDescriptor.getNativeTransportAllowOlderProtocols());
// Wait until we have the complete header
if (readableBytes < Header.LENGTH)
return null;
int flags = buffer.getByte(idx++);
EnumSet<Header.Flag> decodedFlags = decodeFlags(version, flags);
int streamId = buffer.getShort(idx);
idx += 2;
// This throws a protocol exception if the opcode is unknown; it is rethrown wrapped together
// with the stream id that was just read.
Message.Type type;
try
{
type = Message.Type.fromOpcode(buffer.getByte(idx++), direction);
}
catch (ProtocolException e)
{
throw ErrorMessage.wrap(e, streamId);
}
// The length field is read as an unsigned 32-bit value, hence the long.
long bodyLength = buffer.getUnsignedInt(idx);
idx += Header.BODY_LENGTH_SIZE;
long totalLength = bodyLength + Header.LENGTH;
if (totalLength > MAX_TOTAL_LENGTH)
{
// Enter the discard mode and discard everything received so far.
discardingTooLongMessage = true;
tooLongStreamId = streamId;
tooLongTotalLength = totalLength;
bytesToDiscard = discard(buffer, totalLength);
if (bytesToDiscard <= 0)
fail();
return null;
}
// Wait until the whole body has arrived; nothing has been consumed from the buffer yet.
if (buffer.readableBytes() < totalLength)
return null;
ClientMessageSizeMetrics.bytesReceived.inc(totalLength);
ClientMessageSizeMetrics.bytesReceivedPerRequest.update(totalLength);
// extract body
// The int cast is safe here: totalLength, and therefore bodyLength, is at most MAX_TOTAL_LENGTH.
ByteBuf body = buffer.slice(idx, (int) bodyLength);
body.retain();
idx += bodyLength;
// Consume the envelope by moving the reader index past its header and body.
buffer.readerIndex(idx);
return new Envelope(new Header(version, decodedFlags, streamId, type, bodyLength), body);
}
/**
 * Decodes the header flags and checks them against the protocol version.
 *
 * @param version protocol version of the message
 * @param flags   raw flags bit field read from the header
 * @return the decoded flags
 * @throws ProtocolException if {@code version} is a beta version and the USE_BETA flag is not set
 */
private EnumSet<Header.Flag> decodeFlags(ProtocolVersion version, int flags)
{
    EnumSet<Header.Flag> decoded = Header.Flag.deserialize(flags);

    // Nothing more to check unless a beta version is used without the USE_BETA flag.
    if (!version.isBeta() || decoded.contains(Header.Flag.USE_BETA))
        return decoded;

    String message = String.format("Beta version of the protocol used (%s), but USE_BETA flag is unset", version);
    throw new ProtocolException(message, version);
}
/**
 * Netty entry point: delegates to {@code decode(ByteBuf)} and, if an envelope was produced,
 * adds it to {@code results}.
 *
 * @param ctx     channel handler context (unused)
 * @param buffer  bytes received so far
 * @param results list that receives the decoded envelope, if any
 */
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> results)
{
    Envelope decoded = decode(buffer);
    if (decoded != null)
        results.add(decoded);
}
/**
 * Leaves the discard mode and reports the oversized request. This method always throws: the
 * thrown exception is produced by {@code ErrorMessage.wrap} from an {@link InvalidRequestException}
 * and the stream id of the oversized envelope.
 */
private void fail()
{
    // Capture the length for the message before resetting to the initial state.
    final long rejectedLength = tooLongTotalLength;
    tooLongTotalLength = 0;
    discardingTooLongMessage = false;

    InvalidRequestException tooBig =
        new InvalidRequestException(String.format("Request is too big: length %d exceeds maximum allowed length %d.", rejectedLength, MAX_TOTAL_LENGTH));
    throw ErrorMessage.wrap(tooBig, tooLongStreamId);
}
}
/**
 * Skips up to {@code remainingToDiscard} readable bytes of {@code buffer}.
 *
 * @param buffer             buffer to skip bytes in
 * @param remainingToDiscard number of bytes that still have to be discarded
 * @return how much remains to be discarded after this call
 */
private static long discard(ByteBuf buffer, long remainingToDiscard)
{
    final long readable = buffer.readableBytes();
    // Never skip more than is currently readable.
    final int skipped = (int) (remainingToDiscard <= readable ? remainingToDiscard : readable);
    buffer.skipBytes(skipped);
    return remainingToDiscard - skipped;
}
/**
 * Outbound handler that turns an {@link Envelope} into two buffers, the encoded header followed
 * by the body, and records the outgoing message size in {@code ClientMessageSizeMetrics}.
 * A single shared instance is exposed through {@link #instance}.
 */
@ChannelHandler.Sharable
public static class Encoder extends MessageToMessageEncoder<Envelope>
{
    public static final Encoder instance = new Envelope.Encoder();

    private Encoder(){}

    /**
     * Encodes the header of {@code source}, updates the sent-bytes metrics and adds the header
     * buffer and the body buffer, in that order, to {@code results}.
     *
     * @param ctx     channel handler context (unused)
     * @param source  envelope to encode
     * @param results list that receives the header and body buffers
     */
    public void encode(ChannelHandlerContext ctx, Envelope source, List<Object> results)
    {
        final ByteBuf headerBytes = source.encodeHeader();
        final ByteBuf bodyBytes = source.body;

        final int totalSize = headerBytes.readableBytes() + bodyBytes.readableBytes();
        ClientMessageSizeMetrics.bytesSent.inc(totalSize);
        ClientMessageSizeMetrics.bytesSentPerResponse.update(totalSize);

        results.add(headerBytes);
        results.add(bodyBytes);
    }
}
@ChannelHandler.Sharable
public static class Decompressor extends MessageToMessageDecoder<Envelope>
{
public static Decompressor instance = new Envelope.Decompressor();
private Decompressor(){}
public void decode(ChannelHandlerContext ctx, Envelope source, List<Object> results)
throws IOException
{
Connection connection = ctx.channel().attr(Connection.attributeKey).get();
if (!source.header.flags.contains(Header.Flag.COMPRESSED) || connection == null)
{
results.add(source);
return;
}
org.apache.cassandra.transport.Compressor compressor = connection.getCompressor();
if (compressor == null)
{
results.add(source);
return;
}
results.add(compressor.decompress(source));
}
}
@ChannelHandler.Sharable
public static class Compressor extends MessageToMessageEncoder<Envelope>
{
public static Compressor instance = new Compressor();
private Compressor(){}
public void encode(ChannelHandlerContext ctx, Envelope source, List<Object> results)
throws IOException
{
Connection connection = ctx.channel().attr(Connection.attributeKey).get();
// Never compress STARTUP messages
if (source.header.type == Message.Type.STARTUP || connection == null)
{
results.add(source);
return;
}
org.apache.cassandra.transport.Compressor compressor = connection.getCompressor();
if (compressor == null)
{
results.add(source);
return;
}
source.header.flags.add(Header.Flag.COMPRESSED);
results.add(compressor.compress(source));
}
}
}