blob: b53841796417a168fa42c150250419791765cd09 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.transport;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.*;
import org.apache.cassandra.transport.ClientResourceLimits.Overload;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.MessageToMessageDecoder;
import org.apache.cassandra.auth.AllowAllAuthenticator;
import org.apache.cassandra.auth.AllowAllAuthorizer;
import org.apache.cassandra.auth.AllowAllNetworkAuthorizer;
import org.apache.cassandra.concurrent.NamedThreadFactory;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.cql3.QueryProcessor;
import org.apache.cassandra.metrics.ClientMetrics;
import org.apache.cassandra.net.*;
import org.apache.cassandra.net.proxy.InboundProxyHandler;
import org.apache.cassandra.service.NativeTransportService;
import org.apache.cassandra.transport.CQLMessageHandler.MessageConsumer;
import org.apache.cassandra.transport.messages.*;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.concurrent.NonBlockingRateLimiter;
import org.apache.cassandra.utils.concurrent.Condition;
import static org.apache.cassandra.config.EncryptionOptions.TlsEncryptionPolicy.UNENCRYPTED;
import static org.apache.cassandra.io.util.FileUtils.ONE_MIB;
import static org.apache.cassandra.net.FramingTest.randomishBytes;
import static org.apache.cassandra.transport.Flusher.MAX_FRAMED_PAYLOAD_SIZE;
import static org.apache.cassandra.utils.concurrent.Condition.newOneTimeCondition;
import static org.apache.cassandra.utils.concurrent.NonBlockingRateLimiter.NO_OP_LIMITER;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public class CQLConnectionTest
{
// Logger used to record the random seed and any unexpected test failures
private static final Logger logger = LoggerFactory.getLogger(CQLConnectionTest.class);
// Source of randomness for message counts and envelope payloads; re-seeded in setup()
private Random random;
// Address the test server binds to and the test client connects to (loopback, see setup())
private InetAddress address;
// Port discovered in setup() by briefly opening a ServerSocket on port 0
private int port;
// Buffer allocator handed to the frame decoders created by Codec
private BufferPoolAllocator alloc;
/**
 * Prepares the environment shared by every test: initialises the {@link DatabaseDescriptor},
 * installs allow-all auth components, seeds the random source (logging the seed), picks a free
 * local port and configures the native transport receive queue and max frame size.
 *
 * @throws RuntimeException if no local port could be obtained or the thread is interrupted
 */
@Before
public void setup()
{
    DatabaseDescriptor.toolInitialization();
    DatabaseDescriptor.setAuthenticator(new AllowAllAuthenticator());
    DatabaseDescriptor.setAuthorizer(new AllowAllAuthorizer());
    DatabaseDescriptor.setNetworkAuthorizer(new AllowAllNetworkAuthorizer());

    long seed = new SecureRandom().nextLong();
    logger.info("seed: {}", seed);
    random = new Random(seed);

    address = InetAddress.getLoopbackAddress();
    try
    {
        // bind to port 0 to let the OS pick a free port, then close the socket again
        try (ServerSocket serverSocket = new ServerSocket(0))
        {
            port = serverSocket.getLocalPort();
        }
        Thread.sleep(250);
    }
    catch (InterruptedException e)
    {
        // restore the interrupt status so the test runner can still observe the interruption
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
    catch (IOException e)
    {
        throw new RuntimeException(e);
    }

    alloc = GlobalBufferPoolAllocator.instance;
    // set connection-local queue size to 0 so that all capacity is allocated from reserves
    DatabaseDescriptor.setNativeTransportReceiveQueueCapacityInBytes(0);
    // set transport to max frame size possible
    DatabaseDescriptor.setNativeTransportMaxFrameSize(256 * (int) ONE_MIB);
}
/**
 * Rewrites the protocol version byte of inbound payloads to an unsupported value (99) and
 * verifies that the client fails to connect, receives the expected error message, and that
 * no capacity was allocated on the server.
 */
@Test
public void handleErrorDuringNegotiation() throws Throwable
{
    int messageCount = 0;
    Codec codec = Codec.crc(alloc);
    AllocationObserver observer = new AllocationObserver();
    InboundProxyHandler.Controller controller = new InboundProxyHandler.Controller();
    // Force protocol version to an unsupported version
    controller.withPayloadTransform(msg -> {
        ByteBuf bb = (ByteBuf)msg;
        bb.setByte(0, 99 & Envelope.PROTOCOL_VERSION_MASK);
        return msg;
    });

    ServerConfigurator configurator = ServerConfigurator.builder()
                                                        .withAllocationObserver(observer)
                                                        .withProxyController(controller)
                                                        .build();
    Server server = server(configurator);
    Client client = new Client(codec, messageCount);
    server.start();
    try
    {
        client.connect(address, port);
        assertFalse(client.isConnected());
        assertThat(client.getConnectionError())
            .isNotNull()
            .matches(message ->
                     message.error.getMessage()
                                  .equals("Invalid or unsupported protocol version (99); " +
                                          "supported versions are (3/v3, 4/v4, 5/v5, 6/v6-beta)"));
    }
    finally
    {
        // always tear down both ends, even when an assertion above fails, so that a
        // failing run does not leave a listening server or an open channel behind
        try
        {
            client.stop();
        }
        finally
        {
            server.stop();
        }
    }

    // the failure happens before any capacity is allocated
    observer.verifier().accept(0);
}
/**
 * Flips a bit in the middle of each inbound buffer once negotiation has completed and
 * checks, for both codecs and two message counts, that the server reports an
 * unrecoverable CRC mismatch.
 */
@Test
public void handleFrameCorruptionAfterNegotiation() throws Throwable
{
    // A corrupt messaging frame should terminate the connection as clients
    // generally don't track which stream IDs are present in the frame, and the
    // server has no way to signal which streams are affected.
    // Before closing, the server should send an ErrorMessage to inform the
    // client of the corrupt message.
    Function<ByteBuf, ByteBuf> flipMiddleBit = frame -> {
        int middle = frame.readableBytes() / 2;
        flipBit(frame, middle);
        return frame;
    };
    IntFunction<Envelope> envelopes = streamId -> randomEnvelope(streamId, Message.Type.OPTIONS);
    Predicate<ErrorMessage> isCrcMismatch =
        err -> err.error.getMessage().contains("unrecoverable CRC mismatch detected in frame body");

    // expected allocated bytes are 0 as the errors happen before allocation.
    // we don't do more rounds with higher message count as the connection
    // will be closed when the first corrupt frame is encountered
    for (int messageCount : new int[]{ 10, 100 })
    {
        testFrameCorruption(messageCount, Codec.crc(alloc), envelopes, flipMiddleBit, 0, isCrcMismatch);
        testFrameCorruption(messageCount, Codec.lz4(alloc), envelopes, flipMiddleBit, 0, isCrcMismatch);
    }
}
// Sends two envelopes, each with a body of 2 * MAX_FRAMED_PAYLOAD_SIZE bytes, and flips a
// single bit 100 bytes past the first MAX_FRAMED_PAYLOAD_SIZE bytes seen by the proxy.
// The server is expected to answer with a CRC mismatch error.
@Test
public void handleCorruptionOfLargeMessageFrame() throws Throwable
{
// A corrupt messaging frame should terminate the connection as clients
// generally don't track which stream IDs are present in the frame, and the
// server has no way to signal which streams are affected.
// Before closing, the server should send an ErrorMessage to inform the
// client of the corrupt message.
// Client needs to expect multiple responses or else awaitResponses returns
// after the error is first received and we race between handling the exception
// caused by remote disconnection and checking the connection status.
Function<ByteBuf, ByteBuf> corruptor = new Function<ByteBuf, ByteBuf>()
{
// Don't corrupt the first frame as this would fail early and bypass capacity allocation.
// Instead, allow enough bytes to fill the first frame through untouched. Then, corrupt
// a byte which will be in the second frame of the large message.
// Running total of readable bytes in the buffers passed through unmodified so far
int seenBytes = 0;
// Index of the flipped byte; also acts as the "corruption already injected" flag (0 = not yet)
int corruptedByte = 0;
public ByteBuf apply(ByteBuf msg)
{
// If we've already injected some corruption, pass through
if (corruptedByte > 0)
return msg;
// Will the current buffer size take us into the second frame? If so, corrupt it
if (seenBytes + msg.readableBytes() > MAX_FRAMED_PAYLOAD_SIZE + 100)
{
// offset within this buffer at which MAX_FRAMED_PAYLOAD_SIZE total bytes have been seen
int frameBoundary = MAX_FRAMED_PAYLOAD_SIZE - seenBytes;
// target a byte 100 past that offset, relative to the buffer's reader index
corruptedByte = msg.readerIndex() + frameBoundary + 100;
flipBit(msg, corruptedByte);
}
else
{
seenBytes += msg.readableBytes();
}
return msg;
}
};
// both the min and max size are the same, fixing the requested body size of every envelope
int totalBytesPerEnvelope = MAX_FRAMED_PAYLOAD_SIZE * 2;
IntFunction<Envelope> envelopeProvider = i -> randomEnvelope(i, Message.Type.OPTIONS, totalBytesPerEnvelope, totalBytesPerEnvelope);
Predicate<ErrorMessage> errorCheck =
error -> error.error.getMessage().contains("unrecoverable CRC mismatch detected in frame body");
testFrameCorruption(2, Codec.crc(alloc), envelopeProvider, corruptor, totalBytesPerEnvelope, errorCheck);
}
/**
 * Runs the acquire/release round trip for three ranges of message counts,
 * each once with the CRC codec and once with the LZ4 codec.
 */
@Test
public void testAquireAndRelease()
{
    // { min, max } message counts handed to acquireAndRelease
    int[][] messageCountRanges = { { 10, 100 }, { 100, 1000 }, { 1000, 10000 } };
    for (int[] range : messageCountRanges)
    {
        acquireAndRelease(range[0], range[1], Codec.crc(alloc));
        acquireAndRelease(range[0], range[1], Codec.lz4(alloc));
    }
}
/**
 * Sends a random number of OPTIONS envelopes (at least {@code minMessages}, fewer than
 * {@code maxMessages}) through a server using the fixed decoder and test consumer, expects a
 * RESULT response for each and verifies the allocation totals with a strict observer.
 *
 * @param minMessages inclusive lower bound of the message count
 * @param maxMessages exclusive upper bound of the message count
 * @param codec       frame encoder/decoder pair used by client and test consumer
 */
private void acquireAndRelease(int minMessages, int maxMessages, Codec codec)
{
    int spread = maxMessages - minMessages;
    final int total = minMessages + random.nextInt(spread);
    logger.info("Sending total of {} messages", total);

    AllocationObserver allocations = new AllocationObserver();
    ServerConfigurator configurator =
        ServerConfigurator.builder()
                          .withConsumer(new TestConsumer(new ResultMessage.Void(), codec.encoder))
                          .withAllocationObserver(allocations)
                          .withDecoder(new FixedDecoder())
                          .build();

    IntFunction<Envelope> envelopes = streamId -> randomEnvelope(streamId, Message.Type.OPTIONS);
    Predicate<Envelope.Header> isResult = header -> header.type == Message.Type.RESULT;
    runTest(configurator, codec, total, envelopes, isResult, allocations.verifier());
}
/**
 * Gives every even-numbered stream an invalid envelope header and checks, for both
 * codecs and three message counts, that the remaining envelopes are still processed.
 */
@Test
public void testRecoverableEnvelopeDecodingErrors()
{
    // If an error is encountered while decoding an Envelope header,
    // it should be possible to continue processing subsequent Envelopes
    // by skipping the Envelope body. For instance, a ProtocolException
    // caused by an invalid opcode or version flag in the header should
    // not require the connection to be terminated. Instead, an error
    // response should be returned with the correct stream id and further
    // Envelopes processed as normal.

    // every other message should error while extracting the Envelope header
    IntPredicate evenStreams = streamId -> streamId % 2 == 0;
    for (int messageCount : new int[]{ 10, 100, 1000 })
    {
        testEnvelopeDecodingErrors(messageCount, evenStreams, Codec.crc(alloc));
        testEnvelopeDecodingErrors(messageCount, evenStreams, Codec.lz4(alloc));
    }
}
/**
 * Sends {@code messageCount} envelopes, overwriting header byte 4 with 99 for every stream
 * selected by {@code shouldError}, and checks that each response is either a RESULT or an
 * ERROR on one of the selected streams. The observer is non-strict, so allocation totals
 * are not compared; only outstanding usage is checked.
 *
 * @param messageCount number of envelopes to send
 * @param shouldError  selects the stream ids whose header is mutated
 * @param codec        frame encoder/decoder pair used by client and test consumer
 */
private void testEnvelopeDecodingErrors(int messageCount, IntPredicate shouldError, Codec codec)
{
    AllocationObserver allocations = new AllocationObserver(false);

    // mutate the request from the erroring streams to have an invalid opcode (99)
    Consumer<ByteBuffer> invalidOpcode = header -> header.put(4, (byte) 99);
    IntFunction<Envelope> envelopes = mutatedEnvelopeProvider(shouldError, invalidOpcode);

    Predicate<Envelope.Header> responseMatcher = header -> {
        if (header.type == Message.Type.RESULT)
            return true;
        return header.type == Message.Type.ERROR && shouldError.test(header.streamId);
    };

    ServerConfigurator configurator =
        ServerConfigurator.builder()
                          .withConsumer(new TestConsumer(new ResultMessage.Void(), codec.encoder))
                          .withAllocationObserver(allocations)
                          .withDecoder(new FixedDecoder())
                          .build();

    runTest(configurator, codec, messageCount, envelopes, responseMatcher, allocations.verifier());
}
/**
 * Gives the first ten streams an invalid opcode and expects the connection to be
 * closed with an "Unknown opcode 99" error.
 */
@Test
public void testUnrecoverableEnvelopeDecodingErrors()
{
    // If multiple consecutive Envelopes in a Frame cause protocol
    // exceptions during decoding, we fail fast and close the connection.
    // The reason for this is that while some protocol errors may be
    // non-fatal (e.g. an incorrect opcode, or missing BETA flag), a
    // badly behaved client could also include garbage which may render
    // any following bytes in the Frame unusable, even though the Frame
    // level CRC32 is valid for the payload.
    final IntPredicate leadingStreams = streamId -> streamId < 10;

    // mutate the request from the erroring streams to have an invalid opcode (99)
    Consumer<ByteBuffer> invalidOpcode = header -> header.put(4, (byte) 99);
    IntFunction<Envelope> envelopes = mutatedEnvelopeProvider(leadingStreams, invalidOpcode);

    Predicate<ErrorMessage> isUnknownOpcode = err -> err.error.getMessage().contains("Unknown opcode 99");
    testFrameCorruption(100, Codec.crc(alloc), envelopes, Function.identity(), 0, isUnknownOpcode);
}
/**
 * Writes -10 into the header at offset 5 for every even-numbered stream and expects
 * the server to report an invalid body length before closing the connection.
 */
@Test
public void testNegativeEnvelopeBodySize()
{
    // A negative value for the body length of an envelope is essentially a
    // fatal exception as the stream of bytes is unrecoverable

    // every other message should error while extracting the Envelope header
    IntPredicate evenStreams = streamId -> streamId % 2 == 0;

    // set the bodyLength byte to a negative value
    Consumer<ByteBuffer> negativeLength = header -> header.putInt(5, -10);
    IntFunction<Envelope> envelopes = mutatedEnvelopeProvider(evenStreams, negativeLength);

    Predicate<ErrorMessage> isInvalidLength =
        err -> err.error.getMessage().contains("Invalid value for envelope header body length field: -10");
    testFrameCorruption(100, Codec.crc(alloc), envelopes, Function.identity(), 0, isInvalidLength);
}
/**
 * Runs the mid-frame message decoding error scenario for three message counts,
 * each once with the CRC codec and once with the LZ4 codec.
 */
@Test
public void testRecoverableMessageDecodingErrors()
{
    // If an error is encountered while decoding a CQL message body
    // then it is usually safe to continue processing subsequent
    // Envelopes provided that the error is localised to the message
    // body. If, following such an error, we are able to successfully
    // extract an Envelope header from the Frame payload we continue
    // processing as normal. However, if the subsequent header cannot
    // be extracted, we infer that the corruption of the previous message
    // has rendered the entire Frame unrecoverable and close the client
    // connection.
    for (int messageCount : new int[]{ 10, 100, 1000 })
    {
        recoverableMessageDecodingErrorEncounteredMidFrame(messageCount, Codec.crc(alloc));
        recoverableMessageDecodingErrorEncounteredMidFrame(messageCount, Codec.lz4(alloc));
    }
}
/**
 * Sends {@code messageCount} OPTIONS envelopes to a server whose decoder throws a
 * {@link ProtocolException} for the single stream id {@code messageCount / 2}; every
 * response must be a RESULT, or an ERROR on that one stream.
 *
 * @param messageCount number of envelopes to send
 * @param codec        frame encoder/decoder pair used by client and test consumer
 */
private void recoverableMessageDecodingErrorEncounteredMidFrame(int messageCount, Codec codec)
{
    // Message bodies are consistent with Envelope headers, but decoding a message
    // mid-frame generates an error. A concrete example would be a BatchMessage
    // which contains SELECT statements.
    final int failingStream = messageCount / 2;

    AllocationObserver allocations = new AllocationObserver();
    ProtocolException decodingFailure =
        new ProtocolException("An exception was encountered when decoding a CQL message");
    Message.Decoder<Message.Request> failingDecoder =
        new FixedDecoder(streamId -> streamId == failingStream, decodingFailure);

    Predicate<Envelope.Header> responseMatcher = header -> {
        if (header.type == Message.Type.RESULT)
            return true;
        return header.type == Message.Type.ERROR && header.streamId == failingStream;
    };

    ServerConfigurator configurator =
        ServerConfigurator.builder()
                          .withConsumer(new TestConsumer(new ResultMessage.Void(), codec.encoder))
                          .withAllocationObserver(allocations)
                          .withDecoder(failingDecoder)
                          .build();

    IntFunction<Envelope> envelopes = streamId -> randomEnvelope(streamId, Message.Type.OPTIONS);
    runTest(configurator, codec, messageCount, envelopes, responseMatcher, allocations.verifier());
}
/**
 * Makes the decoder throw for the first ten streams and expects the connection to
 * be closed with an error carrying the decoder's exception message.
 */
@Test
public void testUnrecoverableMessageDecodingErrors()
{
    // If multiple consecutive CQL Messages in a Frame cause protocol
    // exceptions during message decoding, we fail fast and close
    // the connection. The reason for this is that while some protocol
    // errors may be non-fatal (e.g. a SELECT statement contained in a
    // BatchMessage, or unknown consistency level), a badly behaved
    // client could also send garbage which may render any following
    // bytes in the Frame unusable, even though the Frame level CRC32
    // is valid for the payload.
    final ProtocolException decodingFailure = new ProtocolException("Unknown opcode 99");
    final IntPredicate leadingStreams = streamId -> streamId < 10;
    Message.Decoder<Message.Request> failingDecoder = new FixedDecoder(leadingStreams, decodingFailure);

    IntFunction<Envelope> envelopes = streamId -> randomEnvelope(streamId, Message.Type.OPTIONS);
    Function<ByteBuf, ByteBuf> untouched = Function.identity();
    Predicate<ErrorMessage> carriesFailure =
        err -> err.error.getMessage().contains(decodingFailure.getMessage());

    testFrameCorruption(100, Codec.crc(alloc), envelopes, untouched, 0, failingDecoder, carriesFailure);
}
/**
 * Starts a server with the given configurator, connects a client, sends
 * {@code messageCount} envelopes and checks every response header against
 * {@code responseMatcher}. Finally hands the total number of body bytes sent to
 * {@code allocationVerifier}. Client and server are always stopped.
 *
 * @param configurator       pipeline configurator for the server under test
 * @param codec              frame encoder/decoder pair used by the client
 * @param messageCount       number of envelopes to send (and responses to wait for)
 * @param envelopeProvider   creates the envelope for a given stream id
 * @param responseMatcher    predicate every received response header must satisfy
 * @param allocationVerifier receives the sum of the sent envelopes' body sizes
 * @throws AssertionError if anything fails; an unexpected exception is attached as the cause
 */
private void runTest(ServerConfigurator configurator,
                     Codec codec,
                     int messageCount,
                     IntFunction<Envelope> envelopeProvider,
                     Predicate<Envelope.Header> responseMatcher,
                     LongConsumer allocationVerifier)
{
    Server server = server(configurator);
    Client client = new Client(codec, messageCount);
    try
    {
        server.start();
        client.connect(address, port);
        assertTrue(configurator.waitUntilReady());

        for (int i = 0; i < messageCount; i++)
            client.send(envelopeProvider.apply(i));

        long totalBytes = client.sendSize;

        // verify that all messages went through the pipeline & our test message consumer
        client.awaitResponses();
        Envelope response;
        while ((response = client.pollResponses()) != null)
        {
            try
            {
                assertThat(response.header).matches(responseMatcher);
            }
            finally
            {
                // release only after the envelope has been inspected
                response.release();
            }
        }

        // verify that we did have to acquire some resources from the global/endpoint reserves
        allocationVerifier.accept(totalBytes);
    }
    catch (Throwable t)
    {
        logger.error("Unexpected error", t);
        // keep the original failure attached so the test report shows why the run failed
        throw new AssertionError("Unexpected error", t);
    }
    finally
    {
        // make sure the server is stopped even if stopping the client throws
        try
        {
            client.stop();
        }
        finally
        {
            server.stop();
        }
    }
}
/**
 * Convenience overload which runs the frame corruption scenario without supplying a
 * request decoder; the seven-argument overload substitutes a {@link FixedDecoder}
 * when it is given {@code null}.
 */
private void testFrameCorruption(int messageCount,
                                 Codec codec,
                                 IntFunction<Envelope> envelopeProvider,
                                 Function<ByteBuf, ByteBuf> transform,
                                 long expectedBytesAllocated,
                                 Predicate<ErrorMessage> errorPredicate)
{
    final Message.Decoder<Message.Request> noExplicitDecoder = null;
    testFrameCorruption(messageCount,
                        codec,
                        envelopeProvider,
                        transform,
                        expectedBytesAllocated,
                        noExplicitDecoder,
                        errorPredicate);
}
/**
 * Connects a client, installs {@code transform} on the server's inbound proxy once the
 * connection is established, sends {@code messageCount} envelopes and verifies that the
 * client has been disconnected after receiving an ERROR response accepted by
 * {@code errorPredicate}. Client and server are always stopped.
 *
 * @param messageCount           number of envelopes to send
 * @param codec                  frame encoder/decoder pair used by the client
 * @param envelopeProvider       creates the envelope for a given stream id
 * @param transform              payload transform applied by the inbound proxy
 * @param expectedBytesAllocated value passed to the allocation verifier; the observer created
 *                               here is non-strict, so totals are not compared against it and
 *                               only the absence of outstanding allocations is asserted
 * @param requestDecoder         decoder for the server, or {@code null} for a default {@link FixedDecoder}
 * @param errorPredicate         check applied to the received error message
 * @throws AssertionError if an exception is thrown; it is attached as the cause
 */
private void testFrameCorruption(int messageCount,
                                 Codec codec,
                                 IntFunction<Envelope> envelopeProvider,
                                 Function<ByteBuf, ByteBuf> transform,
                                 long expectedBytesAllocated,
                                 Message.Decoder<Message.Request> requestDecoder,
                                 Predicate<ErrorMessage> errorPredicate)
{
    AllocationObserver observer = new AllocationObserver(false);
    InboundProxyHandler.Controller controller = new InboundProxyHandler.Controller();

    if (requestDecoder == null)
        requestDecoder = new FixedDecoder();

    ServerConfigurator configurator = ServerConfigurator.builder()
                                                        .withAllocationObserver(observer)
                                                        .withProxyController(controller)
                                                        .withDecoder(requestDecoder)
                                                        .build();
    Server server = server(configurator);
    Client client = new Client(codec, messageCount);
    server.start();
    try
    {
        client.connect(address, port);
        assertTrue(client.isConnected());

        // Only install the transform after protocol negotiation is complete
        controller.withPayloadTransform(transform);

        for (int i = 0; i < messageCount; i++)
            client.send(envelopeProvider.apply(i));

        client.awaitResponses();
        // Client has disconnected
        assertFalse(client.isConnected());
        // But before it did, it sent an error response
        Envelope received = client.inboundMessages.poll();
        assertNotNull(received);
        try
        {
            Message.Response response = Message.responseDecoder().decode(client.channel, received);
            assertEquals(Message.Type.ERROR, response.type);
            assertTrue(errorPredicate.test((ErrorMessage) response));
        }
        finally
        {
            // the envelope has been removed from the client's queue, so it must be released here
            received.release();
        }

        observer.verifier().accept(expectedBytesAllocated);
    }
    catch (Exception e)
    {
        logger.error("Unexpected error", e);
        // keep the original failure attached so the test report shows why the run failed
        throw new AssertionError("Unexpected error", e);
    }
    finally
    {
        // release the client's channel and any queued envelopes, then stop the server
        try
        {
            client.stop();
        }
        finally
        {
            server.stop();
        }
    }
}
/**
 * Builds (but does not start) a server bound to the test address and port which uses
 * the supplied pipeline configurator, and initialises {@link ClientMetrics} with it.
 *
 * @param configurator pipeline configurator for the new server
 * @return the newly built server
 */
private Server server(ServerConfigurator configurator)
{
    Server instance = new Server.Builder().withHost(address)
                                          .withPort(port)
                                          .withPipelineConfigurator(configurator)
                                          .build();

    ClientMetrics.instance.init(Collections.singleton(instance));
    return instance;
}
/**
 * Creates an envelope with a random body, using size bounds of 100 and 1024.
 *
 * @param streamId stream id for the envelope header
 * @param type     message type for the envelope header
 * @return the new envelope
 */
private Envelope randomEnvelope(int streamId, Message.Type type)
{
    final int smallestBody = 100;
    final int largestBody = 1024;
    return randomEnvelope(streamId, type, smallestBody, largestBody);
}
/**
 * Creates a protocol V5 envelope, flagged USE_BETA, whose body wraps the bytes
 * produced by {@code randomishBytes} for the given size bounds.
 *
 * @param streamId stream id for the envelope header
 * @param type     message type for the envelope header
 * @param minSize  minimum size passed to {@code randomishBytes}
 * @param maxSize  maximum size passed to {@code randomishBytes}
 * @return the new envelope
 */
private Envelope randomEnvelope(int streamId, Message.Type type, int minSize, int maxSize)
{
    byte[] payload = randomishBytes(random, minSize, maxSize);
    EnumSet<Envelope.Header.Flag> flags = EnumSet.of(Envelope.Header.Flag.USE_BETA);
    ByteBuf body = Unpooled.wrappedBuffer(payload);
    return Envelope.create(type, streamId, ProtocolVersion.V5, flags, body);
}
/**
 * Returns a provider of random OPTIONS envelopes whose encoded header is passed to
 * {@code headerMutator} whenever the stream id satisfies {@code streamIdMatcher}.
 *
 * @param streamIdMatcher selects the stream ids whose header is mutated
 * @param headerMutator   mutation applied to the encoded header bytes
 * @return function creating the (possibly mutating) envelope for a stream id
 */
private IntFunction<Envelope> mutatedEnvelopeProvider(IntPredicate streamIdMatcher, Consumer<ByteBuffer> headerMutator)
{
    // enables tests to mutate Envelope headers as they're serialized into a Frame
    // payload. For instance, a test may modify the header length or set the opcode
    // to something invalid to simulate a buggy client. Frame level CRCs will remain
    // valid, so this can be used to exercise the CQL encoding layer.
    return streamId -> {
        Envelope original = randomEnvelope(streamId, Message.Type.OPTIONS);
        return new MutableEnvelope(original)
        {
            @Override
            Consumer<ByteBuffer> headerTransform()
            {
                return streamIdMatcher.test(streamId) ? headerMutator : super.headerTransform();
            }
        };
    };
}
/**
 * Toggles bit 4 (mask {@code 1 << 4}) of the byte at the given absolute index.
 *
 * @param buf   buffer to modify in place
 * @param index absolute index of the byte to alter
 */
private void flipBit(ByteBuf buf, int index)
{
    final int mask = 1 << 4;
    int toggled = buf.getByte(index) ^ mask;
    buf.setByte(index, toggled);
}
/**
 * An {@link Envelope} which, after encoding its header, hands the encoded header
 * bytes to {@link #headerTransform()} so that subclasses can modify them in place.
 */
private static class MutableEnvelope extends Envelope
{
    public MutableEnvelope(Envelope source)
    {
        super(source.header, source.body);
    }

    /**
     * @return the mutation to apply to the encoded header; a no-op unless overridden
     */
    Consumer<ByteBuffer> headerTransform()
    {
        return ignored -> {};
    }

    @Override
    public void encodeHeaderInto(ByteBuffer buf)
    {
        final int headerStart = buf.position();
        super.encodeHeaderInto(buf);
        final int headerEnd = buf.position();

        // take a view which shares the output buffer's backing bytes
        // but covers only the header that has just been written
        buf.position(headerStart);
        ByteBuffer headerBytes = buf.slice();
        headerBytes.limit(headerEnd - headerStart);
        // put the output buffer back to where the encoder left it
        buf.position(headerEnd);

        // let the transformation rewrite the header bytes in place
        headerTransform().accept(headerBytes);
    }
}
/**
 * Every CQL Envelope received will be parsed as an OptionsMessage, which is trivial to execute
 * on the server. This means we can randomise the actual content of the CQL messages to test
 * resource allocation/release (which is based purely on request size), without having to
 * worry about processing of the actual messages.
 *
 * Optionally, decoding throws a fixed exception for stream ids matched by a predicate.
 */
static class FixedDecoder extends Message.Decoder<Message.Request>
{
    // selects the stream ids for which decode throws
    IntPredicate isErrorStream;
    // exception thrown for the selected streams
    ProtocolException error;

    /** Creates a decoder which never throws. */
    FixedDecoder()
    {
        this(streamId -> false, null);
    }

    FixedDecoder(IntPredicate isErrorStream, ProtocolException error)
    {
        this.isErrorStream = isErrorStream;
        this.error = error;
    }

    Message.Request decode(Channel channel, Envelope source)
    {
        final int streamId = source.header.streamId;
        if (isErrorStream.test(streamId))
            throw error;

        Message.Request options = new OptionsMessage();
        options.setSource(source);
        options.setStreamId(streamId);
        // attach the connection stored on the channel
        options.attach(channel.attr(Connection.attributeKey).get());
        return options;
    }
}
// A simple consumer which "serves" a static response and employs a naive flusher
static class TestConsumer implements MessageConsumer<Message.Request>
{
final Message.Response fixedResponse;
final Envelope responseTemplate;
final FrameEncoder frameEncoder;
SimpleClient.SimpleFlusher flusher;
TestConsumer(Message.Response fixedResponse, FrameEncoder frameEncoder)
{
this.fixedResponse = fixedResponse;
this.responseTemplate = fixedResponse.encode(ProtocolVersion.V5);
this.frameEncoder = frameEncoder;
}
public void accept(Channel channel, Message.Request message, Dispatcher.FlushItemConverter toFlushItem, Overload backpressure)
{
if (flusher == null)
flusher = new SimpleClient.SimpleFlusher(frameEncoder);
Envelope response = Envelope.create(responseTemplate.header.type,
message.getStreamId(),
ProtocolVersion.V5,
responseTemplate.header.flags,
responseTemplate.body.copy());
flusher.enqueue(response);
// Schedule the proto-flusher to collate any messages to be served
// and flush them to the outbound pipeline
flusher.schedule(channel.pipeline().lastContext());
// this simulates the release of the allocated resources that a real flusher would do
Flusher.FlushItem.Framed item = (Flusher.FlushItem.Framed)toFlushItem.toFlushItem(channel, message, fixedResponse);
item.release();
}
}
static class ServerConfigurator extends PipelineConfigurator
{
private final Condition pipelineReady = newOneTimeCondition();
private final MessageConsumer<Message.Request> consumer;
private final AllocationObserver allocationObserver;
private final Message.Decoder<Message.Request> decoder;
private final InboundProxyHandler.Controller proxyController;
public ServerConfigurator(Builder builder)
{
super(NativeTransportService.useEpoll(), false, false, UNENCRYPTED);
this.consumer = builder.consumer;
this.decoder = builder.decoder;
this.allocationObserver = builder.observer;
this.proxyController = builder.proxyController;
}
static Builder builder()
{
return new Builder();
}
static class Builder
{
MessageConsumer<Message.Request> consumer;
AllocationObserver observer;
Message.Decoder<Message.Request> decoder;
InboundProxyHandler.Controller proxyController;
Builder withConsumer(MessageConsumer<Message.Request> consumer)
{
this.consumer = consumer;
return this;
}
Builder withDecoder(Message.Decoder<Message.Request> decoder)
{
this.decoder = decoder;
return this;
}
Builder withAllocationObserver(AllocationObserver observer)
{
this.observer = observer;
return this;
}
Builder withProxyController(InboundProxyHandler.Controller proxyController)
{
this.proxyController = proxyController;
return this;
}
ServerConfigurator build()
{
return new ServerConfigurator(this);
}
}
protected Message.Decoder<Message.Request> messageDecoder()
{
return decoder == null ? super.messageDecoder() : decoder;
}
protected void onInitialPipelineReady(ChannelPipeline pipeline)
{
if (proxyController != null)
{
InboundProxyHandler proxy = new InboundProxyHandler(proxyController);
pipeline.addFirst("PROXY", proxy);
}
}
protected void onNegotiationComplete(ChannelPipeline pipeline)
{
pipelineReady.signalAll();
}
private boolean waitUntilReady() throws InterruptedException
{
return pipelineReady.await(10, TimeUnit.SECONDS);
}
protected ClientResourceLimits.ResourceProvider resourceProvider(ClientResourceLimits.Allocator limits)
{
final ClientResourceLimits.ResourceProvider.Default delegate =
new ClientResourceLimits.ResourceProvider.Default(limits);
if (null == allocationObserver)
return delegate;
return new ClientResourceLimits.ResourceProvider()
{
public ResourceLimits.Limit globalLimit()
{
return allocationObserver.global(delegate.globalLimit());
}
public AbstractMessageHandler.WaitQueue globalWaitQueue()
{
return delegate.globalWaitQueue();
}
public ResourceLimits.Limit endpointLimit()
{
return allocationObserver.endpoint(delegate.endpointLimit());
}
public AbstractMessageHandler.WaitQueue endpointWaitQueue()
{
return delegate.endpointWaitQueue();
}
@Override
public NonBlockingRateLimiter requestRateLimiter()
{
return NO_OP_LIMITER;
}
public void release()
{
delegate.release();
}
};
}
protected MessageConsumer<Message.Request> messageConsumer()
{
return consumer == null ? super.messageConsumer() : consumer;
}
}
/**
 * Records the bytes allocated from and released to the endpoint and global limits by
 * wrapping them in {@link InstrumentedLimit}s, and provides a verifier for those totals.
 */
static class AllocationObserver
{
    // created on the first call to endpoint(...) / global(...)
    volatile InstrumentedLimit endpoint;
    volatile InstrumentedLimit global;
    // when true, the verifier also compares the recorded totals with the expected byte count
    final boolean strict;

    /** Creates a strict observer. */
    AllocationObserver()
    {
        this(true);
    }

    AllocationObserver(boolean strict)
    {
        this.strict = strict;
    }

    long endpointAllocationTotal()
    {
        InstrumentedLimit limit = endpoint;
        if (limit == null)
            return 0;
        return limit.totalAllocated.get();
    }

    long endpointReleaseTotal()
    {
        InstrumentedLimit limit = endpoint;
        if (limit == null)
            return 0;
        return limit.totalReleased.get();
    }

    long globalAllocationTotal()
    {
        InstrumentedLimit limit = global;
        if (limit == null)
            return 0;
        return limit.totalAllocated.get();
    }

    long globalReleaseTotal()
    {
        InstrumentedLimit limit = global;
        if (limit == null)
            return 0;
        return limit.totalReleased.get();
    }

    /** Returns the instrumented endpoint limit, wrapping {@code delegate} on first call only. */
    synchronized InstrumentedLimit endpoint(ResourceLimits.Limit delegate)
    {
        if (endpoint != null)
            return endpoint;
        endpoint = new InstrumentedLimit(delegate);
        return endpoint;
    }

    /** Returns the instrumented global limit, wrapping {@code delegate} on first call only. */
    synchronized InstrumentedLimit global(ResourceLimits.Limit delegate)
    {
        if (global != null)
            return global;
        global = new InstrumentedLimit(delegate);
        return global;
    }

    /**
     * @return a consumer which, given the expected number of bytes, asserts the recorded
     *         totals (strict mode only) and that nothing remains allocated
     */
    LongConsumer verifier()
    {
        return expectedBytes -> {
            // if strict mode (the default), verify that we did have to acquire the expected resources
            // from the global/endpoint reserves and that we released the same amount. If any errors
            // were encountered before allocation (i.e. decoding Envelope headers), the message bytes
            // are never allocated (and so neither are they released).
            if (strict)
            {
                assertThat(endpointAllocationTotal()).isEqualTo(expectedBytes);
                assertThat(globalAllocationTotal()).isEqualTo(expectedBytes);
                // and that we released it all
                assertThat(endpointReleaseTotal()).isEqualTo(expectedBytes);
                assertThat(globalReleaseTotal()).isEqualTo(expectedBytes);
            }

            // assert that we definitely have no outstanding resources acquired from the reserves
            ClientResourceLimits.Allocator allocator =
                ClientResourceLimits.getAllocatorForEndpoint(FBUtilities.getJustLocalAddress());
            assertThat(allocator.endpointUsing()).isEqualTo(0);
            assertThat(allocator.globallyUsing()).isEqualTo(0);
        };
    }
}
static class InstrumentedLimit extends DelegatingLimit
{
AtomicLong totalAllocated = new AtomicLong(0);
AtomicLong totalReleased = new AtomicLong(0);
InstrumentedLimit(ResourceLimits.Limit wrapped)
{
super(wrapped);
}
public boolean tryAllocate(long amount)
{
totalAllocated.addAndGet(amount);
return super.tryAllocate(amount);
}
public ResourceLimits.Outcome release(long amount)
{
totalReleased.addAndGet(amount);
return super.release(amount);
}
}
/**
 * A {@link ResourceLimits.Limit} which forwards every call unchanged to the limit
 * it was constructed with; intended as a base for decorators.
 */
static class DelegatingLimit implements ResourceLimits.Limit
{
    private final ResourceLimits.Limit delegate;

    DelegatingLimit(ResourceLimits.Limit wrapped)
    {
        delegate = wrapped;
    }

    public long limit()
    {
        return delegate.limit();
    }

    public long setLimit(long newLimit)
    {
        return delegate.setLimit(newLimit);
    }

    public long remaining()
    {
        return delegate.remaining();
    }

    public long using()
    {
        return delegate.using();
    }

    public boolean tryAllocate(long amount)
    {
        return delegate.tryAllocate(amount);
    }

    public void allocate(long amount)
    {
        delegate.allocate(amount);
    }

    public ResourceLimits.Outcome release(long amount)
    {
        return delegate.release(amount);
    }
}
/**
 * Pairs the frame encoder and frame decoder of one framing format.
 */
static class Codec
{
    final FrameEncoder encoder;
    final FrameDecoder decoder;

    Codec(FrameEncoder frameEncoder, FrameDecoder frameDecoder)
    {
        encoder = frameEncoder;
        decoder = frameDecoder;
    }

    /** @return a codec using the CRC frame encoder and a new CRC decoder backed by {@code alloc} */
    static Codec crc(BufferPoolAllocator alloc)
    {
        FrameDecoder crcDecoder = new FrameDecoderCrc(alloc);
        return new Codec(FrameEncoderCrc.instance, crcDecoder);
    }

    /** @return a codec using the fast LZ4 frame encoder and a fast LZ4 decoder backed by {@code alloc} */
    static Codec lz4(BufferPoolAllocator alloc)
    {
        FrameDecoder lz4Decoder = FrameDecoderLZ4.fast(alloc);
        return new Codec(FrameEncoderLZ4.fastInstance, lz4Decoder);
    }
}
/**
 * Minimal test client. It performs the STARTUP handshake using the pre-V5 handlers and,
 * on receipt of READY, swaps the pipeline over to the frame encoder/decoder of the supplied
 * {@link Codec}. Decoded response envelopes are stashed in {@link #inboundMessages} and
 * counted down on {@link #responsesReceived}.
 */
static class Client
{
    private final Codec codec;
    private Channel channel;
    // Event loop group backing the channel. Retained so that stop() can shut it down;
    // otherwise every client would leave its event loop threads running.
    // Assigned in connect() before the bootstrap hands any work to the event loop.
    private EventLoopGroup eventLoopGroup;
    final int expectedResponses;
    final CountDownLatch responsesReceived;
    private volatile boolean connected = false;
    final Queue<Envelope> inboundMessages = new LinkedBlockingQueue<>();
    // sum of the body sizes of all envelopes passed to send()
    long sendSize = 0;
    SimpleClient.SimpleFlusher flusher;
    // set when the server answers the handshake with an ERROR
    ErrorMessage connectionError;
    // set when an IOException reaches the end of the pipeline after the handshake
    Throwable disconnectionError;

    Client(Codec codec, int expectedResponses)
    {
        this.codec = codec;
        this.expectedResponses = expectedResponses;
        this.responsesReceived = new CountDownLatch(expectedResponses);
        flusher = new SimpleClient.SimpleFlusher(codec.encoder);
    }

    /**
     * Connects to the server, sends STARTUP and waits up to 10 seconds for the handshake to
     * finish, either with READY (after which {@link #isConnected()} is true) or with an ERROR
     * (after which {@link #getConnectionError()} is set and the client has been stopped).
     *
     * @throws IOException      if the TCP connection cannot be established
     * @throws RuntimeException if the handshake does not finish within 10 seconds
     */
    private void connect(InetAddress address, int port) throws IOException, InterruptedException
    {
        final CountDownLatch ready = new CountDownLatch(1);
        eventLoopGroup = new NioEventLoopGroup(0, new NamedThreadFactory("TEST-CLIENT"));
        Bootstrap bootstrap = new Bootstrap()
                              .group(eventLoopGroup)
                              .channel(io.netty.channel.socket.nio.NioSocketChannel.class)
                              .option(ChannelOption.TCP_NODELAY, true);
        bootstrap.handler(new ChannelInitializer<Channel>()
        {
            protected void initChannel(Channel channel) throws Exception
            {
                BufferPoolAllocator allocator = GlobalBufferPoolAllocator.instance;
                channel.config().setOption(ChannelOption.ALLOCATOR, allocator);
                ChannelPipeline pipeline = channel.pipeline();
                // Outbound handlers to enable us to send the initial STARTUP
                pipeline.addLast("envelopeEncoder", Envelope.Encoder.instance);
                pipeline.addLast("messageEncoder", PreV5Handlers.ProtocolEncoder.instance);
                pipeline.addLast("envelopeDecoder", new Envelope.Decoder());
                // Inbound handler to perform the handshake & modify the pipeline on receipt of a READY
                pipeline.addLast("handshake", new MessageToMessageDecoder<Envelope>()
                {
                    final Envelope.Decoder decoder = new Envelope.Decoder();
                    protected void decode(ChannelHandlerContext ctx, Envelope msg, List<Object> out) throws Exception
                    {
                        // Handle ERROR responses during initial connection and protocol negotiation
                        if ( msg.header.type == Message.Type.ERROR)
                        {
                            connectionError = (ErrorMessage)Message.responseDecoder()
                                                                   .decode(ctx.channel(), msg);
                            msg.release();
                            logger.info("ERROR");
                            stop();
                            ready.countDown();
                            return;
                        }

                        // As soon as we receive a READY message, modify the pipeline
                        assert msg.header.type == Message.Type.READY;
                        msg.release();

                        // just split the messaging into cql messages and stash them for verification
                        FrameDecoder.FrameProcessor processor = frame -> {
                            if (frame instanceof FrameDecoder.IntactFrame)
                            {
                                ByteBuffer bytes = ((FrameDecoder.IntactFrame)frame).contents.get();
                                while(bytes.hasRemaining())
                                {
                                    ByteBuf buffer = Unpooled.wrappedBuffer(bytes);
                                    try
                                    {
                                        inboundMessages.add(decoder.decode(buffer));
                                        responsesReceived.countDown();
                                    }
                                    catch (Exception e)
                                    {
                                        throw new IOException(e);
                                    }
                                    // skip past the bytes the envelope decoder consumed
                                    bytes.position(bytes.position() + buffer.readerIndex());
                                }
                            }
                            return true;
                        };

                        // for testing purposes, don't actually encode CQL messages,
                        // we supply messaging frames directly to this client
                        channel.pipeline().remove("envelopeEncoder");
                        channel.pipeline().remove("messageEncoder");
                        channel.pipeline().remove("envelopeDecoder");

                        // replace this handshake handler with an inbound message frame decoder
                        channel.pipeline().replace(this, "frameDecoder", codec.decoder);
                        // add an outbound message frame encoder
                        channel.pipeline().addLast("frameEncoder", codec.encoder);
                        channel.pipeline().addLast("errorHandler", new ChannelInboundHandlerAdapter()
                        {
                            @Override
                            public void exceptionCaught(final ChannelHandlerContext ctx, Throwable cause) throws Exception
                            {
                                if (cause instanceof IOException)
                                {
                                    connected = false;
                                    disconnectionError = cause;
                                }
                            }
                        });
                        codec.decoder.activate(processor);
                        connected = true;
                        // Schedule the proto-flusher to collate any messages that have been
                        // written, via enqueue(Envelope message), and flush them to the outbound pipeline
                        flusher.schedule(channel.pipeline().lastContext());
                        ready.countDown();
                    }
                });
            }
        });

        ChannelFuture future = bootstrap.connect(address, port);

        // Wait until the connection attempt succeeds or fails.
        channel = future.awaitUninterruptibly().channel();
        if (!future.isSuccess())
        {
            eventLoopGroup.shutdownGracefully();
            throw new IOException("Connection Error", future.cause());
        }

        // Send an initial STARTUP message to kick off the handshake with the server
        Map<String, String> options = new HashMap<>();
        options.put(StartupMessage.CQL_VERSION, QueryProcessor.CQL_VERSION.toString());
        if (codec.encoder instanceof FrameEncoderLZ4)
            options.put(StartupMessage.COMPRESSION, "LZ4");
        Connection connection = new Connection(channel, ProtocolVersion.V5, (ch, connection1) -> {});
        channel.attr(Connection.attributeKey).set(connection);
        channel.writeAndFlush(new StartupMessage(options)).sync();

        if (!ready.await(10, TimeUnit.SECONDS))
            throw new RuntimeException("Failed to establish client connection in 10s");
    }

    /** Queues the envelope for the flusher and adds its body size to {@link #sendSize}. */
    void send(Envelope request)
    {
        flusher.enqueue(request);
        sendSize += request.header.bodySizeInBytes;
    }

    /**
     * Waits up to one second for the expected number of responses. The outcome of the
     * wait is deliberately not reported; callers inspect the received messages instead.
     */
    private void awaitResponses() throws InterruptedException
    {
        responsesReceived.await(1, TimeUnit.SECONDS);
    }

    private boolean isConnected()
    {
        return connected;
    }

    private ErrorMessage getConnectionError()
    {
        return connectionError;
    }

    /** @return the next received envelope, or {@code null} if none is queued */
    private Envelope pollResponses()
    {
        return inboundMessages.poll();
    }

    /**
     * Closes the channel if it is still open, releases all queued outbound and inbound
     * envelopes and initiates shutdown of the event loop group. Safe to call when the
     * client never connected.
     */
    private void stop()
    {
        if (channel != null && channel.isOpen())
            channel.close().awaitUninterruptibly();

        flusher.releaseAll();

        Envelope f;
        while ((f = inboundMessages.poll()) != null)
            f.release();

        // shutdownGracefully only initiates the shutdown and returns a future, so this
        // is safe even when stop() is invoked from the event loop itself (handshake error path)
        if (eventLoopGroup != null)
            eventLoopGroup.shutdownGracefully();
    }
}
}