blob: 04cc395ad11bb17b0a93144d1e7c77f13c88ba70 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.remote.io.socket.ssl;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.DelimiterBasedFrameDecoder;
import io.netty.handler.codec.Delimiters;
import io.netty.handler.codec.string.StringDecoder;
import io.netty.handler.codec.string.StringEncoder;
import io.netty.handler.ssl.SslHandler;
import org.apache.nifi.security.util.SslContextFactory;
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder;
import org.apache.nifi.security.util.TlsConfiguration;
import org.apache.nifi.security.util.TlsPlatform;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.condition.EnabledIf;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
 * Tests for {@link SSLSocketChannel} in client and server modes using TLSv1.2 and, when the platform supports it,
 * TLSv1.3. The opposite end of each connection is a Netty channel whose pipeline uses an {@link SslHandler}
 * backed by an {@link SSLEngine} created from the same {@link SSLContext}.
 */
@Timeout(value = 15)
public class SSLSocketChannelTest {
    private static final String LOCALHOST = "localhost";

    private static final int GROUP_THREADS = 1;

    private static final boolean CLIENT_CHANNEL = true;

    private static final boolean SERVER_CHANNEL = false;

    /** Timeout applied to channels in successful exchanges; also used as a per-test timeout in milliseconds */
    private static final int CHANNEL_TIMEOUT = 15000;

    /** Short timeout applied to channels in tests that expect connection or handshake failures */
    private static final int CHANNEL_FAILURE_TIMEOUT = 100;

    /** Milliseconds to wait for an expected message or for the server to acknowledge its response */
    private static final int CHANNEL_POLL_TIMEOUT = 5000;

    private static final int MAX_MESSAGE_LENGTH = 1024;

    private static final long SHUTDOWN_TIMEOUT = 100;

    private static final String TLS_1_3 = "TLSv1.3";

    private static final String TLS_1_2 = "TLSv1.2";

    /** Newline-terminated so the Netty line-delimiter frame decoder can frame it */
    private static final String MESSAGE = "PING\n";

    private static final Charset MESSAGE_CHARSET = StandardCharsets.UTF_8;

    private static final byte[] MESSAGE_BYTES = MESSAGE.getBytes(MESSAGE_CHARSET);

    private static final int FIRST_BYTE_OFFSET = 1;

    private static final int SINGLE_COUNT_DOWN = 1;

    /** Name of the condition method referenced by {@link EnabledIf} */
    private static final String TLS_1_3_SUPPORTED = "isTls13Supported";

    private static SSLContext sslContext;

    /**
     * Condition method for {@link EnabledIf}.
     *
     * @return true when the platform reports TLSv1.3 among its supported protocols
     */
    public static boolean isTls13Supported() {
        return TlsPlatform.getSupportedProtocols().contains(TLS_1_3);
    }

    /**
     * Builds a shared {@link SSLContext} from a temporary key store and trust store configuration.
     *
     * @throws GeneralSecurityException if the SSL context cannot be created
     */
    @BeforeAll
    public static void setConfiguration() throws GeneralSecurityException {
        final TlsConfiguration tlsConfiguration = new TemporaryKeyStoreBuilder().build();
        sslContext = SslContextFactory.createSslContext(tlsConfiguration);
    }

    @Test
    public void testClientConnectFailed() throws IOException {
        final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, "this-host-does-not-exist", 1, null, CLIENT_CHANNEL);
        sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
        assertThrows(Exception.class, sslSocketChannel::connect);
    }

    @Test
    public void testClientConnectHandshakeFailed() throws IOException {
        final String enabledProtocol = isTls13Supported() ? TLS_1_3 : TLS_1_2;
        final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
        try (final SocketChannel socketChannel = SocketChannel.open()) {
            final Channel serverChannel = startServer(group, enabledProtocol, getSingleCountDownLatch());
            final int port = getListeningPort(serverChannel);
            socketChannel.connect(new InetSocketAddress(LOCALHOST, port));

            final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
            final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel);
            sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);

            // Shut down the server group before connecting so the TLS handshake has no responding peer
            shutdownGroup(group);
            assertThrows(SSLException.class, sslSocketChannel::connect);
        } finally {
            // Repeated shutdown is harmless and covers failures before the intentional shutdown above
            shutdownGroup(group);
        }
    }

    @Test
    public void testClientConnectWriteReadTls12() throws Exception {
        assertChannelConnectedWriteReadClosed(TLS_1_2);
    }

    @EnabledIf(TLS_1_3_SUPPORTED)
    @Test
    public void testClientConnectWriteReadTls13() throws Exception {
        assertChannelConnectedWriteReadClosed(TLS_1_3);
    }

    @Test
    public void testClientConnectWriteAvailableReadTls12() throws Exception {
        assertChannelConnectedWriteAvailableRead(TLS_1_2);
    }

    @EnabledIf(TLS_1_3_SUPPORTED)
    @Test
    public void testClientConnectWriteAvailableReadTls13() throws Exception {
        assertChannelConnectedWriteAvailableRead(TLS_1_3);
    }

    @Test
    @Timeout(value = CHANNEL_TIMEOUT, unit = TimeUnit.MILLISECONDS)
    public void testServerReadWriteTls12() throws Exception {
        assertServerChannelConnectedReadClosed(TLS_1_2);
    }

    @EnabledIf(TLS_1_3_SUPPORTED)
    @Test
    @Timeout(value = CHANNEL_TIMEOUT, unit = TimeUnit.MILLISECONDS)
    public void testServerReadWriteTls13() throws Exception {
        assertServerChannelConnectedReadClosed(TLS_1_3);
    }

    /**
     * Accepts a connection from a Netty client on a plain {@link ServerSocketChannel}, wraps the accepted socket in
     * a server-mode {@link SSLSocketChannel}, and verifies that the message written by the client is read.
     *
     * @param enabledProtocol TLS protocol enabled on the client engine
     * @throws IOException on socket failures
     * @throws InterruptedException if interrupted while waiting for the message
     */
    private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException {
        final ExecutorService executor = Executors.newSingleThreadExecutor();
        final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
        try (final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open()) {
            serverSocketChannel.bind(new InetSocketAddress(LOCALHOST, 0));
            final int listeningPort = getPort(serverSocketChannel.getLocalAddress());

            final Channel channel = startClient(group, listeningPort, enabledProtocol);
            try (final SocketChannel socketChannel = serverSocketChannel.accept()) {
                final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, socketChannel, SERVER_CHANNEL);
                final BlockingQueue<String> queue = new LinkedBlockingQueue<>();

                // Read on a separate thread so this thread can drive the client write concurrently
                executor.execute(() -> {
                    final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
                    try {
                        final int messageBytesRead = sslSocketChannel.read(messageBytes);
                        if (messageBytesRead == MESSAGE_BYTES.length) {
                            queue.add(new String(messageBytes, MESSAGE_CHARSET));
                        }
                    } catch (final IOException e) {
                        throw new UncheckedIOException(e);
                    }
                });

                channel.writeAndFlush(MESSAGE).syncUninterruptibly();
                final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS);
                assertEquals(MESSAGE, messageRead, "Message not matched");
            } finally {
                channel.close();
            }
        } finally {
            // Interrupt and release the reader thread so it does not outlive the test
            executor.shutdownNow();
            shutdownGroup(group);
        }
    }

    /**
     * Connects a client {@link SSLSocketChannel}, exchanges one message with the Netty server, and closes it.
     *
     * @param enabledProtocol TLS protocol enabled on the server engine
     * @throws IOException on channel setup failures
     */
    private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException {
        final CountDownLatch countDownLatch = getSingleCountDownLatch();
        processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
            try {
                sslSocketChannel.connect();
                assertFalse(sslSocketChannel.isClosed());

                assertChannelWriteRead(sslSocketChannel, countDownLatch);

                sslSocketChannel.close();
                assertTrue(sslSocketChannel.isClosed());
            } catch (final IOException e) {
                throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
            }
        }));
    }

    /**
     * Same as {@link #assertChannelConnectedWriteReadClosed(String)}, but also calls
     * {@link SSLSocketChannel#available()} right after writing.
     *
     * @param enabledProtocol TLS protocol enabled on the server engine
     * @throws IOException on channel setup failures
     */
    private void assertChannelConnectedWriteAvailableRead(final String enabledProtocol) throws IOException {
        final CountDownLatch countDownLatch = getSingleCountDownLatch();
        processClientSslSocketChannel(enabledProtocol, countDownLatch, (sslSocketChannel -> {
            try {
                sslSocketChannel.connect();
                assertFalse(sslSocketChannel.isClosed());

                assertChannelWriteAvailableRead(sslSocketChannel, countDownLatch);

                sslSocketChannel.close();
                assertTrue(sslSocketChannel.isClosed());
            } catch (final IOException e) {
                throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
            }
        }));
    }

    private void assertChannelWriteAvailableRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
        sslSocketChannel.write(MESSAGE_BYTES);
        // Result intentionally ignored: exercises available() before the server response has been consumed
        sslSocketChannel.available();
        awaitCountDownLatch(countDownLatch);
        assertMessageRead(sslSocketChannel);
    }

    private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel, final CountDownLatch countDownLatch) throws IOException {
        sslSocketChannel.write(MESSAGE_BYTES);
        awaitCountDownLatch(countDownLatch);
        assertMessageRead(sslSocketChannel);
    }

    /**
     * Waits for the server to report that its response was written.
     *
     * @param countDownLatch latch counted down by the server handler
     * @throws IOException if the wait times out or is interrupted
     */
    private void awaitCountDownLatch(final CountDownLatch countDownLatch) throws IOException {
        try {
            if (!countDownLatch.await(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS)) {
                throw new IOException("Count Down timed out");
            }
        } catch (final InterruptedException e) {
            // Restore the interrupt status for callers further up the stack
            Thread.currentThread().interrupt();
            throw new IOException("Count Down Interrupted", e);
        }
    }

    /**
     * Reads the expected message by first reading a single byte, checking {@link SSLSocketChannel#available()},
     * and then reading the remaining bytes into the same buffer after the first byte.
     *
     * @param sslSocketChannel connected channel with the server response pending
     * @throws IOException on read failures
     */
    private void assertMessageRead(final SSLSocketChannel sslSocketChannel) throws IOException {
        final byte firstByteRead = (byte) sslSocketChannel.read();
        assertEquals(MESSAGE_BYTES[0], firstByteRead, "Channel Message first byte not matched");

        final int remainingLength = MESSAGE_BYTES.length - FIRST_BYTE_OFFSET;
        final int available = sslSocketChannel.available();
        assertEquals(remainingLength, available, "Available Bytes not matched");

        final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
        messageBytes[0] = firstByteRead;
        // Request only the space left after the offset so that offset + length stays within the buffer
        final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, remainingLength);
        assertEquals(remainingLength, messageBytesRead, "Channel Message Bytes Read not matched");

        final String message = new String(messageBytes, MESSAGE_CHARSET);
        assertEquals(MESSAGE, message, "Message not matched");
    }

    /**
     * Starts a Netty echo-style server, creates a client {@link SSLSocketChannel} pointed at it, and passes the
     * channel to the consumer. The event loop group is shut down afterwards.
     *
     * @param enabledProtocol TLS protocol enabled on the server engine
     * @param countDownLatch latch counted down after the server writes its response
     * @param channelConsumer test logic run against the client channel
     * @throws IOException on channel creation failures
     */
    private void processClientSslSocketChannel(final String enabledProtocol, final CountDownLatch countDownLatch, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
        final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
        try {
            final Channel channel = startServer(group, enabledProtocol, countDownLatch);
            final int port = getListeningPort(channel);
            final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
            sslSocketChannel.setTimeout(CHANNEL_TIMEOUT);
            channelConsumer.accept(sslSocketChannel);
        } finally {
            shutdownGroup(group);
        }
    }

    private int getListeningPort(final Channel serverChannel) {
        return getPort(serverChannel.localAddress());
    }

    /**
     * Extracts the port from a socket address.
     *
     * @param address local socket address
     * @return the port for an {@link InetSocketAddress}, otherwise 0
     */
    private int getPort(final SocketAddress address) {
        if (address instanceof InetSocketAddress) {
            return ((InetSocketAddress) address).getPort();
        }
        return 0;
    }

    /**
     * Connects a Netty client channel with a client-mode TLS engine to the given local port.
     *
     * @return connected client channel
     */
    private Channel startClient(final EventLoopGroup group, final int port, final String enabledProtocol) {
        final Bootstrap bootstrap = new Bootstrap();
        bootstrap.group(group);
        bootstrap.channel(NioSocketChannel.class);
        bootstrap.handler(new ChannelInitializer<Channel>() {
            @Override
            protected void initChannel(final Channel channel) {
                final ChannelPipeline pipeline = channel.pipeline();
                final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
                setPipelineHandlers(pipeline, sslEngine);
            }
        });
        return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel();
    }

    /**
     * Binds a Netty server on an ephemeral local port. For every line it receives, the server writes
     * {@link #MESSAGE} back and counts down the latch once that write completes.
     *
     * @return bound server channel
     */
    private Channel startServer(final EventLoopGroup group, final String enabledProtocol, final CountDownLatch countDownLatch) {
        final ServerBootstrap bootstrap = new ServerBootstrap();
        bootstrap.group(group);
        bootstrap.channel(NioServerSocketChannel.class);
        bootstrap.childHandler(new ChannelInitializer<Channel>() {
            @Override
            protected void initChannel(final Channel channel) {
                final ChannelPipeline pipeline = channel.pipeline();
                final SSLEngine sslEngine = createSslEngine(enabledProtocol, SERVER_CHANNEL);
                setPipelineHandlers(pipeline, sslEngine);
                pipeline.addLast(new SimpleChannelInboundHandler<String>() {
                    @Override
                    protected void channelRead0(final ChannelHandlerContext context, final String message) {
                        // Blocking on the write future here would run on the event loop, which Netty forbids;
                        // signal completion from a listener instead. Counting down on any outcome lets the
                        // client read fail fast rather than wait on the latch.
                        context.channel().writeAndFlush(MESSAGE).addListener(future -> countDownLatch.countDown());
                    }
                });
            }
        });
        final ChannelFuture bindFuture = bootstrap.bind(LOCALHOST, 0);
        bindFuture.syncUninterruptibly();
        return bindFuture.channel();
    }

    private SSLEngine createSslEngine(final String enabledProtocol, final boolean useClientMode) {
        final SSLEngine sslEngine = sslContext.createSSLEngine();
        sslEngine.setUseClientMode(useClientMode);
        sslEngine.setEnabledProtocols(new String[]{enabledProtocol});
        return sslEngine;
    }

    /**
     * Installs TLS, line framing, and String codecs on the pipeline.
     */
    private void setPipelineHandlers(final ChannelPipeline pipeline, final SSLEngine sslEngine) {
        pipeline.addLast(new SslHandler(sslEngine));
        pipeline.addLast(new DelimiterBasedFrameDecoder(MAX_MESSAGE_LENGTH, Delimiters.lineDelimiter()));
        pipeline.addLast(new StringDecoder());
        pipeline.addLast(new StringEncoder());
    }

    private void shutdownGroup(final EventLoopGroup group) {
        group.shutdownGracefully(SHUTDOWN_TIMEOUT, SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS).syncUninterruptibly();
    }

    private CountDownLatch getSingleCountDownLatch() {
        return new CountDownLatch(SINGLE_COUNT_DOWN);
    }
}