blob: 119467cffe1d951c676750b34639002843528732 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.qpid.protonj2.client.transport;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.Principal;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.qpid.protonj2.buffer.ProtonBuffer;
import org.apache.qpid.protonj2.buffer.ProtonBufferAllocator;
import org.apache.qpid.protonj2.buffer.ProtonNettyByteBuffer;
import org.apache.qpid.protonj2.buffer.ProtonNettyByteBufferAllocator;
import org.apache.qpid.protonj2.client.SslOptions;
import org.apache.qpid.protonj2.client.TransportOptions;
import org.apache.qpid.protonj2.client.util.IOExceptionSupport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
/**
 * TCP based transport that uses Netty as the underlying IO layer.
 * <p>
 * When the supplied {@link SslOptions} report SSL as enabled an {@link SslHandler} is added to
 * the channel pipeline, and the transport is only reported as connected once the SSL handshake
 * has completed successfully.
 */
public class TcpTransport implements Transport {

    private static final Logger LOG = LoggerFactory.getLogger(TcpTransport.class);

    protected final AtomicBoolean connected = new AtomicBoolean();
    protected final AtomicBoolean closed = new AtomicBoolean();
    // Released once the connection either succeeds, fails, or the transport is closed.
    protected final CountDownLatch connectedLatch = new CountDownLatch(1);
    protected final TransportOptions options;
    protected final SslOptions sslOptions;
    protected final Bootstrap bootstrap;

    protected Channel channel;
    protected volatile IOException failureCause;
    protected String host;
    protected int port;
    protected TransportListener listener;

    /**
     * Create a new {@link TcpTransport} instance with the given configuration.
     *
     * @param bootstrap
     *      the Netty {@link Bootstrap} that this transport's IO layer is bound to.
     * @param options
     *      the {@link TransportOptions} used to configure the socket connection.
     * @param sslOptions
     *      the {@link SslOptions} to use if the options indicate SSL is enabled.
     *
     * @throws IllegalArgumentException if any of the arguments is {@code null}.
     */
    public TcpTransport(Bootstrap bootstrap, TransportOptions options, SslOptions sslOptions) {
        if (options == null) {
            throw new IllegalArgumentException("Transport Options cannot be null");
        }

        if (sslOptions == null) {
            throw new IllegalArgumentException("Transport SSL Options cannot be null");
        }

        if (bootstrap == null) {
            throw new IllegalArgumentException("A transport must have an assigned Bootstrap before connect.");
        }

        this.sslOptions = sslOptions;
        this.options = options;
        this.bootstrap = bootstrap;
    }

    /**
     * Starts a connection attempt to the given host and port. The attempt completes
     * asynchronously; use {@link #awaitConnect()} to wait for the outcome.
     * <p>
     * A {@code port} value of zero or less selects the configured default port: the default
     * SSL port when SSL is enabled, otherwise the default TCP port.
     *
     * @param host
     *      the remote host to connect to, must not be {@code null} or empty.
     * @param port
     *      the remote port, or a value {@code <= 0} to use the configured default.
     * @param listener
     *      the {@link TransportListener} that receives transport events.
     *
     * @return this transport instance.
     *
     * @throws IllegalStateException if the transport has already been closed.
     * @throws IllegalArgumentException if the listener or host is missing, or no non-negative
     *      port can be determined from the arguments and configured defaults.
     * @throws IOException declared by the {@link Transport} contract.
     */
    @Override
    public TcpTransport connect(String host, int port, TransportListener listener) throws IOException {
        if (closed.get()) {
            throw new IllegalStateException("Transport has already been closed");
        }

        if (listener == null) {
            throw new IllegalArgumentException("A transport listener must be set before connection attempts.");
        }

        if (host == null || host.isEmpty()) {
            throw new IllegalArgumentException("Transport host value cannot be null");
        }

        // Resolve the port exactly as it will be used so that validation covers both the
        // SSL and the plain TCP case, and fails before any transport state is modified.
        final int effectivePort = resolvePort(port);
        if (effectivePort < 0) {
            throw new IllegalArgumentException("Transport port value must be a non-negative int value or a default port configured");
        }

        this.host = host;
        this.listener = listener;
        this.port = effectivePort;

        bootstrap.handler(new ChannelInitializer<>() {

            @Override
            public void initChannel(Channel transportChannel) throws Exception {
                channel = transportChannel;

                configureChannel(transportChannel);

                try {
                    listener.transportInitialized(TcpTransport.this);
                } catch (Throwable initError) {
                    LOG.warn("Error during initialization of channel from Transport Listener", initError);
                    handleTransportFailure(transportChannel, IOExceptionSupport.create(initError));
                    throw initError;
                }
            }
        });

        configureNetty(bootstrap, options);

        // A failed connect is routed into the pipeline's exceptionCaught handling.
        bootstrap.connect(getHost(), getPort()).addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);

        return this;
    }

    /**
     * Selects the port to connect to: the requested port when positive, otherwise the
     * configured default for the active protocol (SSL or plain TCP).
     */
    private int resolvePort(int requestedPort) {
        if (requestedPort > 0) {
            return requestedPort;
        }

        return sslOptions.sslEnabled() ? sslOptions.defaultSslPort() : options.defaultTcpPort();
    }

    /**
     * Blocks until the connection attempt has succeeded, failed, or the transport was closed.
     *
     * @throws InterruptedException if interrupted while waiting.
     * @throws IOException the recorded failure cause, or a generic error if the transport
     *      was closed before a connection was established.
     */
    @Override
    public void awaitConnect() throws InterruptedException, IOException {
        connectedLatch.await();

        if (!connected.get()) {
            if (failureCause != null) {
                throw failureCause;
            } else {
                throw new IOException("Transport was closed before a connection was established.");
            }
        }
    }

    @Override
    public boolean isConnected() {
        return connected.get();
    }

    @Override
    public boolean isSecure() {
        return sslOptions.sslEnabled();
    }

    @Override
    public String getHost() {
        return host;
    }

    @Override
    public int getPort() {
        return port;
    }

    /**
     * Closes the transport once; later calls have no effect. Releases any thread waiting in
     * {@link #awaitConnect()} and closes the channel if one was created.
     */
    @Override
    public void close() throws IOException {
        if (closed.compareAndSet(false, true)) {
            connected.set(false);
            connectedLatch.countDown();
            if (channel != null) {
                channel.close().syncUninterruptibly();
            }
        }
    }

    /**
     * Returns an allocator whose output buffers come from the channel's Netty allocator.
     * NOTE(review): the returned allocator dereferences {@code channel}, so it presumably must
     * only be used once the channel has been created.
     */
    @Override
    public ProtonBufferAllocator getBufferAllocator() {
        return new ProtonNettyByteBufferAllocator() {

            @Override
            public ProtonBuffer outputBuffer(int initialCapacity) {
                return new ProtonNettyByteBuffer(channel.alloc().ioBuffer(initialCapacity));
            }

            @Override
            public ProtonBuffer outputBuffer(int initialCapacity, int maximumCapacity) {
                return new ProtonNettyByteBuffer(channel.alloc().ioBuffer(initialCapacity, maximumCapacity));
            }
        };
    }

    @Override
    public TcpTransport write(ProtonBuffer output) throws IOException {
        return write(output, null);
    }

    /**
     * Writes the buffer to the channel without flushing.
     *
     * @param output
     *      the bytes to write.
     * @param onComplete
     *      optional callback run only if the write succeeds, may be {@code null}.
     *
     * @throws IOException if the transport is not connected (Netty backed buffers are released first).
     */
    @Override
    public TcpTransport write(ProtonBuffer output, Runnable onComplete) throws IOException {
        checkConnected(output);
        LOG.trace("Attempted write of buffer: {}", output);

        if (onComplete == null) {
            channel.write(toOutputBuffer(output), channel.voidPromise());
        } else {
            channel.write(toOutputBuffer(output), channel.newPromise().addListener(new GenericFutureListener<Future<? super Void>>() {

                @Override
                public void operationComplete(Future<? super Void> future) throws Exception {
                    if (future.isSuccess()) {
                        onComplete.run();
                    }
                }
            }));
        }

        return this;
    }

    @Override
    public TcpTransport writeAndFlush(ProtonBuffer output) throws IOException {
        return writeAndFlush(output, null);
    }

    /**
     * Writes the buffer to the channel and flushes it.
     *
     * @param output
     *      the bytes to write.
     * @param onComplete
     *      optional callback run only if the write succeeds, may be {@code null}.
     *
     * @throws IOException if the transport is not connected (Netty backed buffers are released first).
     */
    @Override
    public TcpTransport writeAndFlush(ProtonBuffer output, Runnable onComplete) throws IOException {
        checkConnected(output);
        LOG.trace("Attempted write and flush of buffer: {}", output);

        if (onComplete == null) {
            channel.writeAndFlush(toOutputBuffer(output), channel.voidPromise());
        } else {
            channel.writeAndFlush(toOutputBuffer(output), channel.newPromise().addListener(new GenericFutureListener<Future<? super Void>>() {

                @Override
                public void operationComplete(Future<? super Void> future) throws Exception {
                    if (future.isSuccess()) {
                        onComplete.run();
                    }
                }
            }));
        }

        return this;
    }

    /**
     * Flushes any pending writes on the channel.
     *
     * @throws IOException if the transport is not connected.
     */
    @Override
    public TcpTransport flush() throws IOException {
        checkConnected();
        LOG.trace("Attempted flush of pending writes");
        channel.flush();

        return this;
    }

    @Override
    public TransportListener getTransportListener() {
        return listener;
    }

    /** Returns a copy of the transport options so callers cannot mutate this transport's configuration. */
    @Override
    public TransportOptions getTransportOptions() {
        return options.clone();
    }

    /** Returns a copy of the SSL options so callers cannot mutate this transport's configuration. */
    @Override
    public SslOptions getSslOptions() {
        return sslOptions.clone();
    }

    /**
     * Returns the local principal of the SSL session.
     *
     * @return the local principal, or {@code null} if SSL is disabled, no channel exists yet,
     *      or no {@link SslHandler} is present in the pipeline.
     */
    @Override
    public Principal getLocalPrincipal() {
        Principal result = null;

        if (isSecure() && channel != null) {
            SslHandler sslHandler = channel.pipeline().get(SslHandler.class);
            if (sslHandler != null) {
                result = sslHandler.engine().getSession().getLocalPrincipal();
            }
        }

        return result;
    }

    /**
     * Converts the given buffer into a Netty {@link ByteBuf} for writing. A buffer that already
     * wraps a Netty buffer is unwrapped; any other buffer has its readable bytes copied into a
     * newly allocated IO buffer.
     */
    protected final ByteBuf toOutputBuffer(final ProtonBuffer output) throws IOException {
        final ByteBuf nettyBuf;

        if (output instanceof ProtonNettyByteBuffer) {
            nettyBuf = (ByteBuf) output.unwrap();
        } else {
            ProtonNettyByteBuffer wrapped = new ProtonNettyByteBuffer(channel.alloc().ioBuffer(output.getReadableBytes()));
            wrapped.writeBytes(output);
            nettyBuf = wrapped.unwrap();
        }

        return nettyBuf;
    }

    //----- Internal implementation details, can be overridden as needed -----//

    /**
     * Hook for subclasses to add pipeline handlers. Called after the SSL and logging handlers
     * (when configured) and before the transport's own channel handler.
     */
    protected void addAdditionalHandlers(ChannelPipeline pipeline) {
    }

    /** Creates the final handler in the pipeline that delivers inbound data to the listener. */
    protected ChannelInboundHandlerAdapter createChannelHandler() {
        return new NettyTcpTransportHandler();
    }

    //----- Event Handlers which can be overridden in subclasses -------------//

    /** Marks the transport connected, notifies the listener and releases {@link #awaitConnect()} waiters. */
    protected void handleConnected(Channel connectedChannel) throws Exception {
        LOG.trace("Channel has become active! Channel is {}", connectedChannel);
        channel = connectedChannel;
        connected.set(true);
        listener.transportConnected(this);
        connectedLatch.countDown();
    }

    /**
     * Records a transport failure and notifies the listener on the channel's event loop.
     * Failures reported after {@link #close()} has been called are only logged.
     */
    protected void handleTransportFailure(Channel failedChannel, Throwable cause) {
        if (!closed.get()) {
            LOG.trace("Transport indicates connection failure! Channel is {}", failedChannel);
            failureCause = IOExceptionSupport.create(cause);
            channel = failedChannel;
            connected.set(false);
            connectedLatch.countDown();

            LOG.trace("Firing onTransportError listener");
            if (channel.eventLoop().inEventLoop()) {
                listener.transportError(failureCause);
            } else {
                channel.eventLoop().execute(() -> {
                    listener.transportError(failureCause);
                });
            }
        } else {
            LOG.trace("Closed Transport signalled that the channel ended: {}", failedChannel);
        }
    }

    //----- State change handlers and checks ---------------------------------//

    /**
     * @throws IOException if the transport is not connected or its channel is inactive.
     */
    protected final void checkConnected() throws IOException {
        if (!connected.get() || !channel.isActive()) {
            throw new IOException("Cannot send to a non-connected transport.", failureCause);
        }
    }

    // Like checkConnected(), but first releases a Netty backed buffer that would otherwise leak
    // because it is never handed to the channel.
    private void checkConnected(ProtonBuffer output) throws IOException {
        if (!connected.get() || !channel.isActive()) {
            if (output instanceof ProtonNettyByteBuffer) {
                ReferenceCountUtil.release(output.unwrap());
            }
            throw new IOException("Cannot send to a non-connected transport.", failureCause);
        }
    }

    // Applies socket level options from the transport options to the bootstrap; a value of -1
    // for buffer sizes or traffic class leaves the Netty/OS default in place.
    private void configureNetty(Bootstrap bootstrap, TransportOptions options) {
        bootstrap.option(ChannelOption.TCP_NODELAY, options.tcpNoDelay());
        bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, options.connectTimeout());
        bootstrap.option(ChannelOption.SO_KEEPALIVE, options.tcpKeepAlive());
        bootstrap.option(ChannelOption.SO_LINGER, options.soLinger());

        if (options.sendBufferSize() != -1) {
            bootstrap.option(ChannelOption.SO_SNDBUF, options.sendBufferSize());
        }

        if (options.receiveBufferSize() != -1) {
            bootstrap.option(ChannelOption.SO_RCVBUF, options.receiveBufferSize());
            bootstrap.option(ChannelOption.RCVBUF_ALLOCATOR, new FixedRecvByteBufAllocator(options.receiveBufferSize()));
        }

        if (options.trafficClass() != -1) {
            bootstrap.option(ChannelOption.IP_TOS, options.trafficClass());
        }

        if (options.localAddress() != null || options.localPort() != 0) {
            if (options.localAddress() != null) {
                bootstrap.localAddress(options.localAddress(), options.localPort());
            } else {
                bootstrap.localAddress(options.localPort());
            }
        }
    }

    // Builds the channel pipeline: optional SSL handler, optional byte tracing, subclass
    // handlers and finally the transport's own inbound handler.
    private void configureChannel(final Channel newChannel) throws Exception {
        if (isSecure()) {
            final SslHandler sslHandler;
            try {
                sslHandler = SslSupport.createSslHandler(newChannel.alloc(), host, port, sslOptions);
            } catch (Exception ex) {
                LOG.warn("Error during initialization of channel from SSL Handler creation:", ex);
                // Report and throw the same exception instance so both paths see one failure.
                final IOException failure = IOExceptionSupport.create(ex);
                handleTransportFailure(newChannel, failure);
                throw failure;
            }

            newChannel.pipeline().addLast("ssl", sslHandler);
        }

        if (options.traceBytes()) {
            newChannel.pipeline().addLast("logger", new LoggingHandler(getClass()));
        }

        addAdditionalHandlers(newChannel.pipeline());

        newChannel.pipeline().addLast(createChannelHandler());
    }

    //----- Default implementation of Netty handler --------------------------//

    /**
     * Base inbound handler that translates Netty channel lifecycle events into transport
     * connected/failure events.
     */
    protected abstract class NettyDefaultHandler<E> extends SimpleChannelInboundHandler<E> {

        @Override
        public final void channelRegistered(ChannelHandlerContext context) throws Exception {
            channel = context.channel();
        }

        @Override
        public void channelActive(ChannelHandlerContext context) throws Exception {
            // In the Secure case we need to let the handshake complete before we
            // trigger the connected event.
            if (!isSecure()) {
                handleConnected(context.channel());
            } else {
                SslHandler sslHandler = context.pipeline().get(SslHandler.class);
                sslHandler.handshakeFuture().addListener(new GenericFutureListener<Future<Channel>>() {
                    @Override
                    public void operationComplete(Future<Channel> future) throws Exception {
                        if (future.isSuccess()) {
                            LOG.trace("SSL Handshake has completed: {}", channel);
                            handleConnected(channel);
                        } else {
                            LOG.trace("SSL Handshake has failed: {}", channel);
                            handleTransportFailure(channel, future.cause());
                        }
                    }
                });
            }
        }

        @Override
        public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
            ctx.flush();
        }

        @Override
        public void channelInactive(ChannelHandlerContext context) throws Exception {
            handleTransportFailure(context.channel(), new IOException("Remote closed connection unexpectedly"));
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext context, Throwable cause) throws Exception {
            handleTransportFailure(context.channel(), cause);
        }
    }

    //----- Handle binary data over socket connections -----------------------//

    /** Delivers each inbound {@link ByteBuf} to the listener wrapped as a {@link ProtonBuffer}. */
    protected class NettyTcpTransportHandler extends NettyDefaultHandler<ByteBuf> {

        @Override
        protected void channelRead0(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
            LOG.trace("New data read: {}", buffer);

            final ProtonNettyByteBuffer wrapped = new ProtonNettyByteBuffer(buffer);

            // Ensure the listener is always invoked on the channel's event loop.
            // NOTE(review): SimpleChannelInboundHandler releases the message after channelRead0
            // returns by default, so the deferred branch could observe a released buffer; it is
            // presumably unreachable because reads are delivered on the event loop — confirm.
            if (channel.eventLoop().inEventLoop()) {
                listener.transportRead(wrapped);
            } else {
                channel.eventLoop().execute(() -> {
                    listener.transportRead(wrapped);
                });
            }
        }
    }

    /**
     * Builds a URI describing the remote peer from the scheme, host and port.
     *
     * @return the remote URI, or {@code null} if no host is set or the URI cannot be built.
     */
    @Override
    public URI getRemoteURI() {
        if (host != null) {
            try {
                return new URI(getScheme(), null, host, port, null, null, null);
            } catch (URISyntaxException e) {
                // Best effort only: report "no URI" rather than failing the caller.
                LOG.trace("Unable to create remote URI for host {} and port {}", host, port, e);
            }
        }

        return null;
    }

    /** Returns the URI scheme for this transport: {@code "ssl"} when secure, otherwise {@code "tcp"}. */
    protected String getScheme() {
        return isSecure() ? "ssl" : "tcp";
    }
}