blob: 02cbaf84578afcac6dfb4ac8cd907cb58abac069 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.network;
import org.apache.kafka.common.memory.MemoryPool;
import org.apache.kafka.common.memory.SimpleMemoryPool;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.security.auth.SecurityProtocol;
import org.apache.kafka.common.security.ssl.SslFactory;
import org.apache.kafka.common.security.ssl.mock.TestKeyManagerFactory;
import org.apache.kafka.common.security.ssl.mock.TestProvider;
import org.apache.kafka.common.security.ssl.mock.TestTrustManagerFactory;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.test.TestCondition;
import org.apache.kafka.test.TestSslUtils;
import org.apache.kafka.test.TestUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.Provider;
import java.security.Security;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

import javax.net.ssl.SSLEngine;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/**
 * SSL variant of {@link SelectorTest}. These tests use a test harness that runs a simple socket server that
 * echoes back responses; both the echo server and the client selector are configured for SSL.
 */
public class SslSelectorTest extends SelectorTest {

    private Map<String, Object> sslClientConfigs;

    @Before
    public void setUp() throws Exception {
        File trustStoreFile = File.createTempFile("truststore", ".jks");
        // Only needed while the test JVM is alive; avoid leaking one temp file per test into java.io.tmpdir.
        trustStoreFile.deleteOnExit();
        Map<String, Object> sslServerConfigs = TestSslUtils.createSslConfig(false, true, Mode.SERVER, trustStoreFile, "server");
        this.server = new EchoServer(SecurityProtocol.SSL, sslServerConfigs);
        this.server.start();
        this.time = new MockTime();
        sslClientConfigs = TestSslUtils.createSslConfig(false, false, Mode.CLIENT, trustStoreFile, "client");
        this.channelBuilder = new SslChannelBuilder(Mode.CLIENT, null, false);
        this.channelBuilder.configure(sslClientConfigs);
        this.metrics = new Metrics();
        this.selector = new Selector(5000, metrics, time, "MetricGroup", channelBuilder, new LogContext());
    }

    @After
    public void tearDown() throws Exception {
        this.selector.close();
        this.server.close();
        this.metrics.close();
    }

    @Override
    public SecurityProtocol securityProtocol() {
        // NOTE(review): this fixture's server and client are SSL (see setUp), yet PLAINTEXT is reported here.
        // Left unchanged because how the parent test uses this value is not visible from this file -- confirm.
        return SecurityProtocol.PLAINTEXT;
    }

    /**
     * Verifies that an SSL connection works when the server's key/trust manager factories come from a custom
     * JCA provider ({@link TestProvider}).
     *
     * The provider is registered JVM-wide, so it is always removed in a {@code finally}. Otherwise a failing
     * assertion would leave it installed for every subsequent test. The locally created server, selector and
     * metrics are likewise closed on every path.
     */
    @Test
    public void testConnectionWithCustomKeyManager() throws Exception {
        Provider provider = new TestProvider();
        Security.addProvider(provider);
        try {
            int requestSize = 100 * 1024;
            final String node = "0";
            String request = TestUtils.randomString(requestSize);

            Map<String, Object> sslServerConfigs = TestSslUtils.createSslConfig(
                    TestKeyManagerFactory.ALGORITHM,
                    TestTrustManagerFactory.ALGORITHM
            );
            EchoServer server = new EchoServer(SecurityProtocol.SSL, sslServerConfigs);
            server.start();
            Metrics metrics = new Metrics();
            Selector selector = null;
            try {
                Time time = new MockTime();
                File trustStoreFile = new File(TestKeyManagerFactory.TestKeyManager.mockTrustStoreFile);
                Map<String, Object> sslClientConfigs = TestSslUtils.createSslConfig(true, true, Mode.CLIENT, trustStoreFile, "client");

                ChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
                channelBuilder.configure(sslClientConfigs);
                selector = new Selector(5000, metrics, time, "MetricGroup", channelBuilder, new LogContext());
                selector.connect(node, new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE);
                while (!selector.connected().contains(node))
                    selector.poll(10000L);
                while (!selector.isChannelReady(node))
                    selector.poll(10000L);

                selector.send(createSend(node, request));

                waitForBytesBuffered(selector, node);

                selector.close(node);
                super.verifySelectorEmpty(selector);
            } finally {
                if (selector != null)
                    selector.close();
                metrics.close();
                server.close();
            }
        } finally {
            Security.removeProvider(provider.getName());
        }
    }

    @Test
    public void testDisconnectWithIntermediateBufferedBytes() throws Exception {
        int requestSize = 100 * 1024;
        final String node = "0";
        String request = TestUtils.randomString(requestSize);

        // Replace the default selector with one whose transport layer deliberately leaves bytes buffered.
        this.selector.close();

        this.channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
        this.channelBuilder.configure(sslClientConfigs);
        this.selector = new Selector(5000, metrics, time, "MetricGroup", channelBuilder, new LogContext());
        connect(node, new InetSocketAddress("localhost", server.port));
        selector.send(createSend(node, request));

        waitForBytesBuffered(selector, node);

        selector.close(node);
        verifySelectorEmpty();
    }

    /**
     * Polls {@code selector} (non-blocking) until the channel for {@code node} reports buffered bytes, failing
     * with an assertion after 2 seconds. An {@link IOException} from {@code poll} aborts the wait immediately.
     */
    private void waitForBytesBuffered(Selector selector, String node) throws Exception {
        TestUtils.waitForCondition(() -> {
            try {
                selector.poll(0L);
                return selector.channel(node).hasBytesBuffered();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }, 2000L, "Failed to reach socket state with bytes buffered");
    }

    @Test
    public void testBytesBufferedChannelWithNoIncomingBytes() throws Exception {
        verifyNoUnnecessaryPollWithBytesBuffered(key ->
            key.interestOps(key.interestOps() & ~SelectionKey.OP_READ));
    }

    @Test
    public void testBytesBufferedChannelAfterMute() throws Exception {
        verifyNoUnnecessaryPollWithBytesBuffered(key -> ((KafkaChannel) key.attachment()).mute());
    }

    /**
     * Puts node1 into a state with a few buffered bytes but not enough to make progress, disables reads on it
     * via {@code disableRead}, then drives 100 echo round trips on node2. It asserts that node1's key is
     * processed only once over all those polls.
     *
     * @param disableRead how to stop reads on node1's selection key (clear OP_READ or mute the channel)
     */
    private void verifyNoUnnecessaryPollWithBytesBuffered(Consumer<SelectionKey> disableRead)
            throws Exception {
        this.selector.close();

        String node1 = "1";
        String node2 = "2";
        final AtomicInteger node1Polls = new AtomicInteger();
        this.channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
        this.channelBuilder.configure(sslClientConfigs);
        this.selector = new Selector(5000, metrics, time, "MetricGroup", channelBuilder, new LogContext()) {
            @Override
            void pollSelectionKeys(Set<SelectionKey> selectionKeys, boolean isImmediatelyConnected, long currentTimeNanos) {
                // Count how many times node1's key is handed to pollSelectionKeys.
                for (SelectionKey key : selectionKeys) {
                    KafkaChannel channel = (KafkaChannel) key.attachment();
                    if (channel != null && channel.id().equals(node1))
                        node1Polls.incrementAndGet();
                }
                super.pollSelectionKeys(selectionKeys, isImmediatelyConnected, currentTimeNanos);
            }
        };

        // Get node1 into bytes buffered state and then disable read on the socket.
        // Truncate the read buffers to ensure that there is buffered data, but not enough to make progress.
        int largeRequestSize = 100 * 1024;
        connect(node1, new InetSocketAddress("localhost", server.port));
        selector.send(createSend(node1, TestUtils.randomString(largeRequestSize)));
        waitForBytesBuffered(selector, node1);
        TestSslChannelBuilder.TestSslTransportLayer.transportLayers.get(node1).truncateReadBuffer();
        disableRead.accept(selector.channel(node1).selectionKey());

        // Clear poll count and count the polls from now on
        node1Polls.set(0);

        // Process sends and receives on node2. Test verifies that we don't process node1
        // unnecessarily on each of these polls.
        connect(node2, new InetSocketAddress("localhost", server.port));
        int received = 0;
        String request = TestUtils.randomString(10);
        selector.send(createSend(node2, request));
        // Bound the loop so that a broken echo path fails the test instead of hanging the build forever.
        long deadline = System.currentTimeMillis() + 30000;
        while (received < 100) {
            assertTrue("Timed out after receiving " + received + " of 100 responses on node " + node2,
                    System.currentTimeMillis() < deadline);
            received += selector.completedReceives().size();
            if (!selector.completedSends().isEmpty()) {
                selector.send(createSend(node2, request));
            }
            selector.poll(5);
        }

        // Verify that pollSelectionKeys was invoked once to process buffered data
        // but not again since there isn't sufficient data to process.
        assertEquals(1, node1Polls.get());
        selector.close(node1);
        selector.close(node2);
        verifySelectorEmpty();
    }

    /**
     * Renegotiation is not supported since it is potentially unsafe and it has been removed in TLS 1.3.
     * The test asserts that a server-initiated renegotiation causes the client connection to be disconnected.
     */
    @Test
    public void testRenegotiationFails() throws Exception {
        String node = "0";
        // create connections
        InetSocketAddress addr = new InetSocketAddress("localhost", server.port);
        selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);

        // send echo requests and receive responses
        while (!selector.isChannelReady(node)) {
            selector.poll(1000L);
        }
        selector.send(createSend(node, node + "-" + 0));
        selector.poll(0L);
        server.renegotiate();
        selector.send(createSend(node, node + "-" + 1));
        long expiryTime = System.currentTimeMillis() + 2000;

        List<String> disconnected = new ArrayList<>();
        while (!disconnected.contains(node) && System.currentTimeMillis() < expiryTime) {
            selector.poll(10);
            disconnected.addAll(selector.disconnected().keySet());
        }
        assertTrue("Renegotiation should cause disconnection", disconnected.contains(node));
    }

    /**
     * Server-side SSL variant of the parent test. It uses a 900-byte memory pool and two senders that each send
     * a 900-byte payload, so only one payload fits at a time. The test checks that the selector reports
     * out-of-memory while one receive holds the pool, and reads the second payload once that memory is released.
     */
    @Override
    public void testMuteOnOOM() throws Exception {
        //clean up default selector, replace it with one that uses a finite mem pool
        selector.close();
        MemoryPool pool = new SimpleMemoryPool(900, 900, false, null);
        //the initial channel builder is for clients, we need a server one
        File trustStoreFile = File.createTempFile("truststore", ".jks");
        // Only needed while the test JVM is alive; avoid leaking temp files into java.io.tmpdir.
        trustStoreFile.deleteOnExit();
        Map<String, Object> sslServerConfigs = TestSslUtils.createSslConfig(false, true, Mode.SERVER, trustStoreFile, "server");
        channelBuilder = new SslChannelBuilder(Mode.SERVER, null, false);
        channelBuilder.configure(sslServerConfigs);
        selector = new Selector(NetworkReceive.UNLIMITED, 5000, metrics, time, "MetricGroup",
                new HashMap<String, String>(), true, false, channelBuilder, pool, new LogContext());

        try (ServerSocketChannel ss = ServerSocketChannel.open()) {
            ss.bind(new InetSocketAddress(0));

            InetSocketAddress serverAddress = (InetSocketAddress) ss.getLocalAddress();

            SslSender sender1 = createSender(serverAddress, randomPayload(900));
            SslSender sender2 = createSender(serverAddress, randomPayload(900));
            sender1.start();
            sender2.start();

            SocketChannel channelX = ss.accept(); //not defined if it's 1 or 2
            channelX.configureBlocking(false);
            SocketChannel channelY = ss.accept();
            channelY.configureBlocking(false);
            selector.register("clientX", channelX);
            selector.register("clientY", channelY);

            boolean handshaked = false;
            NetworkReceive firstReceive = null;
            long deadline = System.currentTimeMillis() + 5000;
            //keep calling poll until:
            //1. both senders have completed the handshakes (so server selector has tried reading both payloads)
            //2. a single payload is actually read out completely (the other is too big to fit)
            while (System.currentTimeMillis() < deadline) {
                selector.poll(10);

                List<NetworkReceive> completed = selector.completedReceives();
                if (firstReceive == null) {
                    if (!completed.isEmpty()) {
                        assertEquals("expecting a single request", 1, completed.size());
                        firstReceive = completed.get(0);
                        assertTrue(selector.isMadeReadProgressLastPoll());
                        assertEquals(0, pool.availableMemory());
                    }
                } else {
                    assertTrue("only expecting single request", completed.isEmpty());
                }

                handshaked = sender1.waitForHandshake(1) && sender2.waitForHandshake(1);

                if (handshaked && firstReceive != null && selector.isOutOfMemory())
                    break;
            }
            assertTrue("could not initiate connections within timeout", handshaked);

            selector.poll(10);
            assertTrue(selector.completedReceives().isEmpty());
            assertEquals(0, pool.availableMemory());
            assertNotNull("First receive not complete", firstReceive);
            assertTrue("Selector not out of memory", selector.isOutOfMemory());

            firstReceive.close();
            assertEquals(900, pool.availableMemory()); //memory has been released back to pool

            List<NetworkReceive> completed = Collections.emptyList();
            deadline = System.currentTimeMillis() + 5000;
            while (System.currentTimeMillis() < deadline && completed.isEmpty()) {
                selector.poll(1000);
                completed = selector.completedReceives();
            }
            assertEquals("could not read remaining request within timeout", 1, completed.size());
            assertEquals(0, pool.availableMemory());
            assertFalse(selector.isOutOfMemory());
        }
    }

    /**
     * Connects and waits for handshake to complete. This is required since SslTransportLayer
     * implementation requires the channel to be ready before send is invoked (unlike plaintext
     * where send can be invoked straight after connect)
     */
    protected void connect(String node, InetSocketAddress serverAddr) throws IOException {
        blockingConnect(node, serverAddr);
    }

    private SslSender createSender(InetSocketAddress serverAddress, byte[] payload) {
        return new SslSender(serverAddress, payload);
    }

    /**
     * Client channel builder whose transport layers are {@link TestSslTransportLayer}s. Those layers throttle
     * socket reads so that tests can reach a "bytes buffered" state.
     */
    private static class TestSslChannelBuilder extends SslChannelBuilder {

        public TestSslChannelBuilder(Mode mode) {
            super(mode, null, false);
        }

        @Override
        protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, String host) throws IOException {
            SocketChannel socketChannel = (SocketChannel) key.channel();
            SSLEngine sslEngine = sslFactory.createSslEngine(host, socketChannel.socket().getPort());
            return new TestSslTransportLayer(id, key, sslEngine);
        }

        /*
         * TestSslTransportLayer reads from the socket only on every other call while OP_READ interest is set.
         * While OP_READ interest is cleared, it stays muted and returns 0. This increases the chance that there
         * will be bytes buffered in the transport layer after read().
         */
        static class TestSslTransportLayer extends SslTransportLayer {
            // Every layer built so far, keyed by channel id, so tests can reach a specific channel's layer.
            // Entries are never removed; a newer layer with the same id replaces the older one.
            static Map<String, TestSslTransportLayer> transportLayers = new HashMap<>();
            boolean muteSocket = false;

            public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine) throws IOException {
                super(channelId, key, sslEngine);
                transportLayers.put(channelId, this);
            }

            @Override
            protected int readFromSocketChannel() throws IOException {
                if (muteSocket) {
                    // Skip this read; unmute for the next call only if the key still wants reads.
                    if ((selectionKey().interestOps() & SelectionKey.OP_READ) != 0)
                        muteSocket = false;
                    return 0;
                }
                muteSocket = true;
                return super.readFromSocketChannel();
            }

            // Leave one byte in network read buffer so that some buffered bytes are present,
            // but not enough to make progress on a read.
            void truncateReadBuffer() throws Exception {
                netReadBuffer().position(1);
                appReadBuffer().position(0);
                muteSocket = true;
            }
        }
    }
}