// blob: f09b08f32ab32b8bb9c358408b2d20e7e2d0735d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.qpid.jms.transports.netty;
import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import org.apache.qpid.jms.transports.TransportOptions;
import org.apache.qpid.jms.transports.TransportSupport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.HandshakeComplete;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
/**
 * Base Server implementation used to create Netty based server implementations for
 * unit testing aspects of the client code.
 * <p>
 * The server can be run as a plain TCP server or as a WebSocket server, each optionally
 * secured with SSL. Subclasses supply the handler that implements the actual test
 * behavior via {@link #getServerHandler()}.
 */
public abstract class NettyServer implements AutoCloseable {

    private static final Logger LOG = LoggerFactory.getLogger(NettyServer.class);

    static final int PORT = Integer.parseInt(System.getProperty("port", "8007"));
    static final String WEBSOCKET_PATH = "/";
    static final int SERVER_CHOOSES_PORT = 0;

    /** Grace period, in milliseconds, granted to each event loop group on shutdown. */
    private static final int SHUTDOWN_TIMEOUT_MS = 100;

    private EventLoopGroup bossGroup;
    private EventLoopGroup workerGroup;
    private Channel serverChannel;
    private final TransportOptions options;
    private final boolean secure;
    private int serverPort = SERVER_CHOOSES_PORT;
    private final boolean needClientAuth;
    private final boolean webSocketServer;
    private int maxFrameSize = NettyTcpTransport.DEFAULT_MAX_FRAME_SIZE;
    private String webSocketPath = WEBSOCKET_PATH;
    private volatile boolean fragmentWrites;
    private volatile SslHandler sslHandler;
    private volatile HandshakeComplete handshakeComplete;
    private final CountDownLatch handshakeCompletion = new CountDownLatch(1);
    private final AtomicBoolean started = new AtomicBoolean();

    /**
     * Creates a plain (non WebSocket) server that does not require client authentication.
     *
     * @param options transport options used to build the SSL context and engine when secure
     * @param secure  whether an {@link SslHandler} is installed on accepted channels
     */
    public NettyServer(TransportOptions options, boolean secure) {
        this(options, secure, false);
    }

    /**
     * Creates a plain (non WebSocket) server.
     *
     * @param options        transport options used to build the SSL context and engine when secure
     * @param secure         whether an {@link SslHandler} is installed on accepted channels
     * @param needClientAuth value passed to {@link SSLEngine#setNeedClientAuth(boolean)}
     */
    public NettyServer(TransportOptions options, boolean secure, boolean needClientAuth) {
        this(options, secure, needClientAuth, false);
    }

    /**
     * Creates a server.
     *
     * @param options         transport options used to build the SSL context and engine when secure
     * @param secure          whether an {@link SslHandler} is installed on accepted channels
     * @param needClientAuth  value passed to {@link SSLEngine#setNeedClientAuth(boolean)}
     * @param webSocketServer whether the HTTP / WebSocket handlers are installed on accepted channels
     */
    public NettyServer(TransportOptions options, boolean secure, boolean needClientAuth, boolean webSocketServer) {
        this.secure = secure;
        this.options = options;
        this.needClientAuth = needClientAuth;
        this.webSocketServer = webSocketServer;
    }

    /** @return true if this server was created as a secure (SSL) server. */
    public boolean isSecureServer() {
        return secure;
    }

    /** @return true if this server was created as a WebSocket server. */
    public boolean isWebSocketServer() {
        return webSocketServer;
    }

    /** @return the path handed to the WebSocket protocol handler and used in the connection URI. */
    public String getWebSocketPath() {
        return webSocketPath;
    }

    /** @param webSocketPath the path handed to the WebSocket protocol handler for new channels. */
    public void setWebSocketPath(String webSocketPath) {
        this.webSocketPath = webSocketPath;
    }

    /** @return the max frame size handed to the WebSocket protocol handler. */
    public int getMaxFrameSize() {
        return maxFrameSize;
    }

    /** @param maxFrameSize the max frame size handed to the WebSocket protocol handler for new channels. */
    public void setMaxFrameSize(int maxFrameSize) {
        this.maxFrameSize = maxFrameSize;
    }

    /**
     * Controls whether outbound buffers are split into two WebSocket frames.
     *
     * @param fragmentWrites true to write each buffer as a binary frame plus a continuation frame
     * @throws IllegalStateException if this is not a WebSocket server
     */
    public void setFragmentWrites(boolean fragmentWrites) {
        if (!webSocketServer) {
            throw new IllegalStateException("Only applicable to WebSocket servers");
        }
        this.fragmentWrites = fragmentWrites;
    }

    /** @return true if outbound buffers are split into two WebSocket frames. */
    public boolean isFragmentWrites() {
        return fragmentWrites;
    }

    /**
     * Waits for a WebSocket handshake completion event to be observed.
     *
     * @param delayMs maximum time to wait, in milliseconds
     * @return true if the event was observed before the timeout elapsed
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public boolean awaitHandshakeCompletion(long delayMs) throws InterruptedException {
        return handshakeCompletion.await(delayMs, TimeUnit.MILLISECONDS);
    }

    /** @return the most recently observed handshake completion event, or null if none was seen. */
    public HandshakeComplete getHandshakeComplete() {
        return handshakeComplete;
    }

    /**
     * Builds the URI a client uses to reach this server on localhost.
     * <p>
     * The scheme is one of {@code amqp}, {@code amqps}, {@code amqpws} or {@code amqpwss}
     * depending on the secure / WebSocket flags; the path is only set for WebSocket servers.
     *
     * @return the connection URI
     * @throws IllegalStateException if the server has not been started
     * @throws Exception if the URI cannot be constructed
     */
    protected URI getConnectionURI() throws Exception {
        if (!started.get()) {
            throw new IllegalStateException("Cannot get URI of non-started server");
        }

        final int port = getServerPort();
        final boolean webSocket = isWebSocketServer();
        final boolean tls = isSecureServer();

        final String scheme;
        if (webSocket) {
            scheme = tls ? "amqpwss" : "amqpws";
        } else {
            scheme = tls ? "amqps" : "amqp";
        }
        final String path = webSocket ? getWebSocketPath() : null;

        return new URI(scheme, null, "localhost", port, path, null, null);
    }

    /**
     * Starts the server on the currently configured port (initially server-chosen).
     *
     * @throws Exception if the server fails to bind
     */
    public void start() throws Exception {
        start(serverPort);
    }

    /**
     * Starts the server on the given port; a no-op if the server is already started.
     * <p>
     * If binding fails the event loops created for this attempt are shut down and the
     * server is returned to the non-started state so that it can be started again.
     *
     * @param listenOn the port to bind, or zero to let the system pick a free port
     * @throws Exception if the server fails to bind
     */
    public void start(int listenOn) throws Exception {
        if (!started.compareAndSet(false, true)) {
            return;
        }

        // Basic server configuration with NIO only options.
        bossGroup = new NioEventLoopGroup(1);
        workerGroup = new NioEventLoopGroup();

        boolean bound = false;
        try {
            ServerBootstrap server = new ServerBootstrap();
            server.group(bossGroup, workerGroup);
            server.channel(NioServerSocketChannel.class);
            server.option(ChannelOption.SO_BACKLOG, 100);
            server.handler(new LoggingHandler(LogLevel.INFO));
            server.childHandler(new ChannelInitializer<Channel>() {
                @Override
                public void initChannel(Channel ch) throws Exception {
                    if (isSecureServer()) {
                        SSLContext context = TransportSupport.createJdkSslContext(options);
                        SSLEngine engine = TransportSupport.createJdkSslEngine(null, context, options);
                        engine.setUseClientMode(false);
                        engine.setNeedClientAuth(needClientAuth);
                        sslHandler = new SslHandler(engine);
                        ch.pipeline().addLast(sslHandler);
                    }

                    if (webSocketServer) {
                        ch.pipeline().addLast(new HttpServerCodec());
                        ch.pipeline().addLast(new HttpObjectAggregator(65536));
                        ch.pipeline().addLast(new WebSocketServerProtocolHandler(getWebSocketPath(), "amqp", true, maxFrameSize));
                    }

                    ch.pipeline().addLast(new NettyServerOutboundHandler());
                    ch.pipeline().addLast(new NettyServerInboundHandler());
                    ch.pipeline().addLast(getServerHandler());
                }
            });

            // Start the server using specified port. If value is zero the server
            // will select a free port and so we update the server port value after
            // in order to reflect the correct value.
            serverChannel = server.bind(listenOn).sync().channel();
            serverPort = ((InetSocketAddress) serverChannel.localAddress()).getPort();
            bound = true;
        } finally {
            if (!bound) {
                // Do not leak the event loop threads or leave the server claiming to be
                // started when the bind did not succeed.
                shutdownEventLoops();
                serverChannel = null;
                started.set(false);
            }
        }
    }

    /**
     * @return the handler added at the end of every accepted channel's pipeline.
     */
    protected abstract ChannelHandler getServerHandler();

    /**
     * Stops the server; a no-op if the server is not started.
     * <p>
     * Closes the server channel and shuts down both event loop groups. If the calling
     * thread is interrupted while waiting for the channel to close, shutdown still
     * completes and the thread's interrupt status is restored before returning.
     *
     * @throws InterruptedException declared for compatibility with existing callers
     */
    public void stop() throws InterruptedException {
        if (!started.compareAndSet(true, false)) {
            return;
        }

        boolean interrupted = false;
        try {
            LOG.info("Syncing channel close");
            if (serverChannel != null) {
                serverChannel.close().sync();
            }
        } catch (InterruptedException e) {
            // Finish releasing resources first, then re-assert the interrupt below.
            interrupted = true;
        }

        // Shut down all event loops to terminate all threads.
        shutdownEventLoops();

        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /** Shuts down the boss and worker event loop groups, if they were created. */
    private void shutdownEventLoops() {
        if (bossGroup != null) {
            LOG.trace("Shutting down boss group");
            bossGroup.shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS).awaitUninterruptibly(SHUTDOWN_TIMEOUT_MS);
            LOG.trace("Boss group shut down");
        }

        if (workerGroup != null) {
            LOG.trace("Shutting down worker group");
            workerGroup.shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS).awaitUninterruptibly(SHUTDOWN_TIMEOUT_MS);
            LOG.trace("Worker group shut down");
        }
    }

    /**
     * Equivalent to {@link #stop()}.
     */
    @Override
    public void close() throws InterruptedException {
        stop();
    }

    /**
     * @return the port the server is bound to
     * @throws IllegalStateException if the server has not been started
     */
    public int getServerPort() {
        if (!started.get()) {
            throw new IllegalStateException("Cannot get server port of non-started server");
        }
        return serverPort;
    }

    /**
     * Outbound handler that wraps {@link ByteBuf} writes in WebSocket frames when the
     * server is a WebSocket server, optionally splitting each buffer into two frames.
     */
    private class NettyServerOutboundHandler extends ChannelOutboundHandlerAdapter {

        @Override
        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
            LOG.trace("NettyServerHandler: Channel write: {}", msg);

            if (!isWebSocketServer() || !(msg instanceof ByteBuf)) {
                ctx.write(msg, promise);
                return;
            }

            ByteBuf orig = (ByteBuf) msg;
            if (!isFragmentWrites()) {
                ctx.write(new BinaryWebSocketFrame(orig), promise);
                return;
            }

            // Split the payload in half: the first half is copied into a non-final
            // binary frame, the remainder of the original buffer is sent as the final
            // continuation frame which carries the caller's promise.
            int origIndex = orig.readerIndex();
            int split = orig.readableBytes() / 2;

            ByteBuf part1 = orig.copy(origIndex, split);
            LOG.trace("NettyServerHandler: Part1: {}", part1);

            orig.readerIndex(origIndex + split);
            LOG.trace("NettyServerHandler: Part2: {}", orig);

            BinaryWebSocketFrame frame1 = new BinaryWebSocketFrame(false, 0, part1);
            ctx.writeAndFlush(frame1);

            ContinuationWebSocketFrame frame2 = new ContinuationWebSocketFrame(true, 0, orig);
            ctx.write(frame2, promise);
        }
    }

    /**
     * Inbound handler that records WebSocket handshake completion, monitors the SSL
     * handshake, unwraps WebSocket frames and rejects stray HTTP requests.
     */
    private class NettyServerInboundHandler extends ChannelInboundHandlerAdapter {

        // NOTE(review): user events are consumed here and not forwarded to later
        // handlers in the pipeline - confirm that is intentional.
        @Override
        public void userEventTriggered(ChannelHandlerContext context, Object payload) {
            if (payload instanceof HandshakeComplete) {
                handshakeComplete = (HandshakeComplete) payload;
                handshakeCompletion.countDown();
            }
        }

        // NOTE(review): channelActive is not forwarded to later handlers in the
        // pipeline - confirm that is intentional.
        @Override
        public void channelActive(final ChannelHandlerContext ctx) {
            LOG.info("NettyServerHandler -> New active channel: {}", ctx.channel());

            SslHandler handler = ctx.pipeline().get(SslHandler.class);
            if (handler != null) {
                // Drop the connection if the SSL handshake does not succeed.
                handler.handshakeFuture().addListener(new GenericFutureListener<Future<Channel>>() {
                    @Override
                    public void operationComplete(Future<Channel> future) throws Exception {
                        LOG.info("Server -> SSL handshake completed. Succeeded: {}", future.isSuccess());
                        if (!future.isSuccess()) {
                            ctx.close();
                        }
                    }
                });
            }
        }

        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            LOG.info("NettyServerHandler: channel has gone inactive: {}", ctx.channel());
            ctx.close();
        }

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) {
            LOG.trace("NettyServerHandler: Channel read: {}", msg);

            if (msg instanceof WebSocketFrame) {
                // Hand only the frame payload to the next handler.
                WebSocketFrame frame = (WebSocketFrame) msg;
                ctx.fireChannelRead(frame.content());
            } else if (msg instanceof FullHttpRequest) {
                // Reject anything not on the WebSocket path
                FullHttpRequest request = (FullHttpRequest) msg;
                try {
                    sendHttpResponse(ctx, request, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
                } finally {
                    // The request is consumed here and not passed on, so its
                    // reference-counted content must be released by this handler.
                    request.release();
                }
            } else {
                // Forward anything else along to the next handler.
                ctx.fireChannelRead(msg);
            }
        }

        @Override
        public void channelReadComplete(ChannelHandlerContext ctx) {
            ctx.flush();
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            // Report the failure through the logger (with its stack trace) and close
            // the connection when an exception is raised.
            LOG.info("NettyServerHandler: NettyServerHandlerException caught on channel: {}", ctx.channel(), cause);
            ctx.close();
        }
    }

    /**
     * Writes the given HTTP response to the channel.
     * <p>
     * For any status other than 200 the status text is written as the response body and
     * the content length header is set. The channel is closed after the write when the
     * request is not keep-alive or the status is not 200.
     *
     * @param ctx      the context whose channel receives the response
     * @param request  the request being answered; only inspected for keep-alive
     * @param response the response to send
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest request, FullHttpResponse response) {
        final boolean ok = response.status().code() == 200;

        // Generate an error page if response getStatus code is not OK (200).
        if (!ok) {
            ByteBuf buf = Unpooled.copiedBuffer(response.status().toString(), StandardCharsets.UTF_8);
            try {
                response.content().writeBytes(buf);
            } finally {
                buf.release();
            }
            HttpUtil.setContentLength(response, response.content().readableBytes());
        }

        // Send the response and close the connection if necessary.
        ChannelFuture f = ctx.channel().writeAndFlush(response);
        if (!HttpUtil.isKeepAlive(request) || !ok) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * @return the SSL handler created for the most recently initialized channel, or null
     *         if no secure channel has been initialized.
     */
    protected SslHandler getSslHandler() {
        return sslHandler;
    }
}