| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more contributor license |
| * agreements. See the NOTICE file distributed with this work for additional information regarding |
| * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance with the License. You may obtain a |
| * copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software distributed under the License |
| * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express |
| * or implied. See the License for the specific language governing permissions and limitations under |
| * the License. |
| */ |
| package org.apache.geode.internal.net; |
| |
| import static org.apache.geode.distributed.ConfigurationProperties.CLUSTER_SSL_CIPHERS; |
| import static org.apache.geode.distributed.ConfigurationProperties.CLUSTER_SSL_ENABLED; |
| import static org.apache.geode.distributed.ConfigurationProperties.CLUSTER_SSL_PROTOCOLS; |
| import static org.apache.geode.distributed.ConfigurationProperties.CLUSTER_SSL_REQUIRE_AUTHENTICATION; |
| import static org.apache.geode.distributed.ConfigurationProperties.MCAST_PORT; |
| import static org.apache.geode.internal.security.SecurableCommunicationChannel.CLUSTER; |
| import static org.apache.geode.test.awaitility.GeodeAwaitility.await; |
| import static org.assertj.core.api.Assertions.assertThat; |
| import static org.assertj.core.api.Assertions.assertThatThrownBy; |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertFalse; |
| import static org.junit.Assert.assertNotEquals; |
| import static org.junit.Assert.assertNotNull; |
| import static org.junit.Assert.assertNull; |
| import static org.mockito.Mockito.mock; |
| |
| import java.io.DataInputStream; |
| import java.io.DataOutputStream; |
| import java.io.File; |
| import java.io.IOException; |
| import java.io.ObjectInputStream; |
| import java.io.ObjectOutputStream; |
| import java.net.ConnectException; |
| import java.net.InetAddress; |
| import java.net.InetSocketAddress; |
| import java.net.ServerSocket; |
| import java.net.Socket; |
| import java.net.SocketException; |
| import java.net.SocketTimeoutException; |
| import java.net.URL; |
| import java.nio.ByteBuffer; |
| import java.nio.channels.ServerSocketChannel; |
| import java.nio.channels.SocketChannel; |
| import java.util.Properties; |
| import java.util.concurrent.Semaphore; |
| import java.util.concurrent.TimeUnit; |
| import java.util.concurrent.atomic.AtomicReference; |
| |
| import javax.net.ssl.SSLContext; |
| import javax.net.ssl.SSLException; |
| |
| import org.apache.commons.io.FileUtils; |
| import org.junit.After; |
| import org.junit.Before; |
| import org.junit.Rule; |
| import org.junit.Test; |
| import org.junit.contrib.java.lang.system.RestoreSystemProperties; |
| import org.junit.experimental.categories.Category; |
| import org.junit.rules.ErrorCollector; |
| import org.junit.rules.TemporaryFolder; |
| import org.junit.rules.TestName; |
| |
| import org.apache.geode.distributed.internal.DMStats; |
| import org.apache.geode.distributed.internal.DistributionConfig; |
| import org.apache.geode.distributed.internal.DistributionConfigImpl; |
| import org.apache.geode.internal.ByteBufferOutputStream; |
| import org.apache.geode.internal.security.SecurableCommunicationChannel; |
| import org.apache.geode.internal.tcp.ByteBufferInputStream; |
| import org.apache.geode.test.dunit.IgnoredException; |
| import org.apache.geode.test.junit.categories.MembershipTest; |
| |
| /** |
| * Integration tests for SocketCreatorFactory with SSL. |
| * <p> |
| * <p> |
| * Renamed from {@code JSSESocketJUnitTest}. |
| * |
| * @see ClientSocketFactoryIntegrationTest |
| */ |
| @Category({MembershipTest.class}) |
| public class SSLSocketIntegrationTest { |
| |
| private static final String MESSAGE = SSLSocketIntegrationTest.class.getName() + " Message"; |
| |
| private AtomicReference<String> messageFromClient = new AtomicReference<>(); |
| |
| private DistributionConfig distributionConfig; |
| private SocketCreator socketCreator; |
| private InetAddress localHost; |
| private Thread serverThread; |
| private ServerSocket serverSocket; |
| private Socket clientSocket; |
| |
| @Rule |
| public ErrorCollector errorCollector = new ErrorCollector(); |
| |
| @Rule |
| public RestoreSystemProperties restoreSystemProperties = new RestoreSystemProperties(); |
| |
| @Rule |
| public TemporaryFolder temporaryFolder = new TemporaryFolder(); |
| |
| @Rule |
| public TestName testName = new TestName(); |
| |
| |
| private Throwable serverException; |
| |
| @Before |
| public void setUp() throws Exception { |
| IgnoredException.addIgnoredException("javax.net.ssl.SSLException: Read timed out"); |
| |
| File keystore = findTestKeystore(); |
| System.setProperty("javax.net.ssl.trustStore", keystore.getCanonicalPath()); |
| System.setProperty("javax.net.ssl.trustStorePassword", "password"); |
| System.setProperty("javax.net.ssl.keyStore", keystore.getCanonicalPath()); |
| System.setProperty("javax.net.ssl.keyStorePassword", "password"); |
| // System.setProperty("javax.net.debug", "ssl,handshake"); |
| |
| |
| Properties properties = new Properties(); |
| properties.setProperty(MCAST_PORT, "0"); |
| properties.setProperty(CLUSTER_SSL_ENABLED, "true"); |
| properties.setProperty(CLUSTER_SSL_REQUIRE_AUTHENTICATION, "true"); |
| properties.setProperty(CLUSTER_SSL_CIPHERS, "any"); |
| properties.setProperty(CLUSTER_SSL_PROTOCOLS, "TLSv1.2"); |
| |
| this.distributionConfig = new DistributionConfigImpl(properties); |
| |
| SocketCreatorFactory.setDistributionConfig(this.distributionConfig); |
| this.socketCreator = SocketCreatorFactory.getSocketCreatorForComponent(CLUSTER); |
| |
| this.localHost = InetAddress.getLocalHost(); |
| } |
| |
| @After |
| public void tearDown() throws Exception { |
| if (this.clientSocket != null) { |
| this.clientSocket.close(); |
| } |
| if (this.serverSocket != null) { |
| this.serverSocket.close(); |
| } |
| if (this.serverThread != null && this.serverThread.isAlive()) { |
| this.serverThread.interrupt(); |
| } |
| SocketCreatorFactory.close(); |
| } |
| |
| @Test |
| /** |
| * see GEODE-4087. Geode should not establish a default SSLContext, preventing apps from using |
| * different ssl settings via standard system properties. Since this test class sets these system |
| * properties to establish a default context we merely need to perform an equality check between |
| * the cluster's context and the default context and assert that they aren't the same. |
| */ |
| public void ensureSocketCreatorDoesNotOverrideDefaultSSLContext() throws Exception { |
| SSLContext defaultContext = SSLContext.getDefault(); |
| SSLContext clusterContext = SocketCreatorFactory |
| .getSocketCreatorForComponent(SecurableCommunicationChannel.CLUSTER).getSslContext(); |
| assertNotEquals(clusterContext, defaultContext); |
| } |
| |
| @Test |
| public void socketCreatorShouldUseSsl() throws Exception { |
| assertThat(this.socketCreator.useSSL()).isTrue(); |
| } |
| |
| @Test |
| public void securedSocketTransmissionShouldWork() throws Exception { |
| this.serverSocket = this.socketCreator.createServerSocket(0, 0, this.localHost); |
| this.serverThread = startServer(this.serverSocket, 15000); |
| |
| int serverPort = this.serverSocket.getLocalPort(); |
| this.clientSocket = this.socketCreator.connectForServer(this.localHost, serverPort); |
| |
| // transmit expected string from Client to Server |
| ObjectOutputStream output = new ObjectOutputStream(this.clientSocket.getOutputStream()); |
| output.writeObject(MESSAGE); |
| output.flush(); |
| |
| // this is the real assertion of this test |
| await().until(() -> { |
| return !serverThread.isAlive(); |
| }); |
| assertNull(serverException); |
| assertThat(this.messageFromClient.get()).isEqualTo(MESSAGE); |
| } |
| |
| @Test |
| public void testSecuredSocketTransmissionShouldWorkUsingNIO() throws Exception { |
| ServerSocketChannel serverChannel = ServerSocketChannel.open(); |
| serverSocket = serverChannel.socket(); |
| |
| InetSocketAddress addr = new InetSocketAddress(localHost, 0); |
| serverSocket.bind(addr, 10); |
| int serverPort = this.serverSocket.getLocalPort(); |
| |
| SocketCreator clusterSocketCreator = |
| SocketCreatorFactory.getSocketCreatorForComponent(SecurableCommunicationChannel.CLUSTER); |
| this.serverThread = startServerNIO(serverSocket, 15000); |
| |
| await().until(() -> serverThread.isAlive()); |
| |
| SocketChannel clientChannel = SocketChannel.open(); |
| await().until( |
| () -> clientChannel.connect(new InetSocketAddress(localHost, serverPort))); |
| |
| clientSocket = clientChannel.socket(); |
| NioSslEngine engine = |
| clusterSocketCreator.handshakeSSLSocketChannel(clientSocket.getChannel(), |
| clusterSocketCreator.createSSLEngine("localhost", 1234), 0, true, |
| ByteBuffer.allocate(65535), new BufferPool(mock(DMStats.class))); |
| clientChannel.configureBlocking(true); |
| |
| // transmit expected string from Client to Server |
| writeMessageToNIOSSLServer(clientChannel, engine); |
| writeMessageToNIOSSLServer(clientChannel, engine); |
| writeMessageToNIOSSLServer(clientChannel, engine); |
| // this is the real assertion of this test |
| await().until(() -> { |
| return !serverThread.isAlive(); |
| }); |
| assertNull(serverException); |
| // assertThat(this.messageFromClient.get()).isEqualTo(MESSAGE); |
| } |
| |
| private void writeMessageToNIOSSLServer(SocketChannel clientChannel, NioSslEngine engine) |
| throws IOException { |
| System.out.println("client sending Hello World message to server"); |
| ByteBufferOutputStream bbos = new ByteBufferOutputStream(5000); |
| DataOutputStream dos = new DataOutputStream(bbos); |
| dos.writeUTF("Hello world"); |
| dos.flush(); |
| bbos.flush(); |
| ByteBuffer buffer = bbos.getContentBuffer(); |
| System.out.println( |
| "client buffer position is " + buffer.position() + " and limit is " + buffer.limit()); |
| ByteBuffer wrappedBuffer = engine.wrap(buffer); |
| System.out.println("client wrapped buffer position is " + wrappedBuffer.position() |
| + " and limit is " + wrappedBuffer.limit()); |
| int bytesWritten = clientChannel.write(wrappedBuffer); |
| System.out.println("client bytes written is " + bytesWritten); |
| } |
| |
| private Thread startServerNIO(final ServerSocket serverSocket, int timeoutMillis) |
| throws Exception { |
| Thread serverThread = new Thread(new MyThreadGroup(this.testName.getMethodName()), () -> { |
| NioSslEngine engine = null; |
| Socket socket = null; |
| try { |
| ByteBuffer buffer = ByteBuffer.allocate(65535); |
| |
| socket = serverSocket.accept(); |
| SocketCreator sc = SocketCreatorFactory.getSocketCreatorForComponent(CLUSTER); |
| engine = |
| sc.handshakeSSLSocketChannel(socket.getChannel(), sc.createSSLEngine("localhost", 1234), |
| timeoutMillis, |
| false, |
| ByteBuffer.allocate(500), |
| new BufferPool(mock(DMStats.class))); |
| |
| readMessageFromNIOSSLClient(socket, buffer, engine); |
| readMessageFromNIOSSLClient(socket, buffer, engine); |
| readMessageFromNIOSSLClient(socket, buffer, engine); |
| } catch (Throwable throwable) { |
| throwable.printStackTrace(System.out); |
| serverException = throwable; |
| } finally { |
| if (engine != null && socket != null) { |
| final NioSslEngine nioSslEngine = engine; |
| engine.close(socket.getChannel()); |
| assertThatThrownBy(() -> { |
| nioSslEngine.unwrap(ByteBuffer.wrap(new byte[0])); |
| }) |
| .isInstanceOf(IllegalStateException.class); |
| } |
| } |
| }, this.testName.getMethodName() + "-server"); |
| |
| serverThread.start(); |
| return serverThread; |
| } |
| |
| private void readMessageFromNIOSSLClient(Socket socket, ByteBuffer buffer, NioSslEngine engine) |
| throws IOException { |
| |
| ByteBuffer unwrapped = engine.getUnwrappedBuffer(buffer); |
| // if we already have unencrypted data skip unwrapping |
| if (unwrapped.position() == 0) { |
| int bytesRead; |
| // if we already have encrypted data skip reading from the socket |
| if (buffer.position() == 0) { |
| bytesRead = socket.getChannel().read(buffer); |
| buffer.flip(); |
| } else { |
| bytesRead = buffer.remaining(); |
| } |
| System.out.println("server bytes read is " + bytesRead + ": buffer position is " |
| + buffer.position() + " and limit is " + buffer.limit()); |
| unwrapped = engine.unwrap(buffer); |
| unwrapped.flip(); |
| System.out.println("server unwrapped buffer position is " + unwrapped.position() |
| + " and limit is " + unwrapped.limit()); |
| } |
| ByteBufferInputStream bbis = new ByteBufferInputStream(unwrapped); |
| DataInputStream dis = new DataInputStream(bbis); |
| String welcome = dis.readUTF(); |
| if (unwrapped.position() >= unwrapped.limit()) { |
| unwrapped.position(0).limit(unwrapped.capacity()); |
| } |
| assertThat(welcome).isEqualTo("Hello world"); |
| System.out.println("server read Hello World message from client"); |
| } |
| |
| |
| @Test(expected = SocketTimeoutException.class) |
| public void handshakeCanTimeoutOnServer() throws Throwable { |
| this.serverSocket = this.socketCreator.createServerSocket(0, 0, this.localHost); |
| this.serverThread = startServer(this.serverSocket, 1000); |
| |
| int serverPort = this.serverSocket.getLocalPort(); |
| Socket socket = new Socket(); |
| socket.connect(new InetSocketAddress(localHost, serverPort)); |
| await().untilAsserted(() -> assertFalse(serverThread.isAlive())); |
| assertNotNull(serverException); |
| if (serverException instanceof SSLException |
| && serverException.getCause() instanceof SocketTimeoutException) { |
| throw serverException.getCause(); |
| } |
| throw serverException; |
| } |
| |
| @Test(expected = SocketTimeoutException.class) |
| public void handshakeWithPeerCanTimeout() throws Throwable { |
| ServerSocketChannel serverChannel = ServerSocketChannel.open(); |
| serverSocket = serverChannel.socket(); |
| |
| InetSocketAddress addr = new InetSocketAddress(localHost, 0); |
| serverSocket.bind(addr, 10); |
| int serverPort = this.serverSocket.getLocalPort(); |
| |
| this.serverThread = startServerNIO(this.serverSocket, 1000); |
| |
| Socket socket = new Socket(); |
| await().atMost(5, TimeUnit.MINUTES).until(() -> { |
| try { |
| socket.connect(new InetSocketAddress(localHost, serverPort)); |
| } catch (ConnectException e) { |
| return false; |
| } catch (SocketException e) { |
| return true; // server socket was closed |
| } |
| return true; |
| }); |
| await().untilAsserted(() -> assertFalse(serverThread.isAlive())); |
| assertNotNull(serverException); |
| throw serverException; |
| } |
| |
| @Test |
| public void configureClientSSLSocketCanTimeOut() throws Exception { |
| final Semaphore serverCoordination = new Semaphore(0); |
| |
| // configure a non-SSL server socket. We will connect |
| // a client SSL socket to it and demonstrate that the |
| // handshake times out |
| final ServerSocket serverSocket = new ServerSocket(); |
| serverSocket.bind(new InetSocketAddress(SocketCreator.getLocalHost(), 0)); |
| Thread serverThread = new Thread() { |
| @Override |
| public void run() { |
| serverCoordination.release(); |
| try (Socket clientSocket = serverSocket.accept()) { |
| System.out.println("server thread accepted a connection"); |
| serverCoordination.acquire(); |
| } catch (Exception e) { |
| System.err.println("accept failed"); |
| e.printStackTrace(); |
| } |
| try { |
| serverSocket.close(); |
| } catch (IOException e) { |
| // ignored |
| } |
| System.out.println("server thread is exiting"); |
| } |
| }; |
| serverThread.setName("SocketCreatorJUnitTest serverSocket thread"); |
| serverThread.setDaemon(true); |
| serverThread.start(); |
| |
| serverCoordination.acquire(); |
| |
| SocketCreator socketCreator = |
| SocketCreatorFactory.getSocketCreatorForComponent(SecurableCommunicationChannel.SERVER); |
| |
| int serverSocketPort = serverSocket.getLocalPort(); |
| try { |
| await("connect to server socket").until(() -> { |
| try { |
| Socket clientSocket = socketCreator.connectForClient( |
| SocketCreator.getLocalHost().getHostAddress(), serverSocketPort, 500); |
| clientSocket.close(); |
| System.err.println( |
| "client successfully connected to server but should not have been able to do so"); |
| return false; |
| } catch (SSLException | SocketTimeoutException e) { |
| IOException ioException = e; |
| // we need to verify that this timed out in the handshake |
| // code |
| if (e instanceof SSLException && e.getCause() instanceof SocketTimeoutException) { |
| ioException = (SocketTimeoutException) ioException.getCause(); |
| } |
| System.out.println("client connect attempt timed out - checking stack trace"); |
| StackTraceElement[] trace = ioException.getStackTrace(); |
| for (StackTraceElement element : trace) { |
| if (element.getMethodName().equals("configureClientSSLSocket")) { |
| System.out.println("client connect attempt timed out in the appropriate method"); |
| return true; |
| } |
| } |
| // it wasn't in the configuration method so we need to try again |
| } catch (IOException e) { |
| // server socket may not be in accept() yet, causing a connection-refused |
| // exception |
| } |
| return false; |
| }); |
| } finally { |
| serverCoordination.release(); |
| } |
| } |
| |
| private File findTestKeystore() throws IOException { |
| return copyKeystoreResourceToFile("/ssl/trusted.keystore"); |
| } |
| |
| public File copyKeystoreResourceToFile(final String name) throws IOException { |
| URL resource = getClass().getResource(name); |
| assertThat(resource).isNotNull(); |
| |
| File file = this.temporaryFolder.newFile(name.replaceFirst(".*/", "")); |
| FileUtils.copyURLToFile(resource, file); |
| return file; |
| } |
| |
| private Thread startServer(final ServerSocket serverSocket, int timeoutMillis) throws Exception { |
| Thread serverThread = new Thread(new MyThreadGroup(this.testName.getMethodName()), () -> { |
| try { |
| Socket socket = serverSocket.accept(); |
| SocketCreatorFactory.getSocketCreatorForComponent(CLUSTER).handshakeIfSocketIsSSL(socket, |
| timeoutMillis); |
| assertEquals(0, socket.getSoTimeout()); |
| ObjectInputStream ois = new ObjectInputStream(socket.getInputStream()); |
| messageFromClient.set((String) ois.readObject()); |
| } catch (Throwable throwable) { |
| serverException = throwable; |
| } |
| }, this.testName.getMethodName() + "-server"); |
| |
| serverThread.start(); |
| return serverThread; |
| } |
| |
| private class MyThreadGroup extends ThreadGroup { |
| |
| public MyThreadGroup(final String name) { |
| super(name); |
| } |
| |
| @Override |
| public void uncaughtException(final Thread thread, final Throwable throwable) { |
| errorCollector.addError(throwable); |
| } |
| } |
| |
| } |