blob: 3a1c3039b20b41085cfbdc2882eeb431f63e2693 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.controller.queue.clustered.server
import org.apache.nifi.events.EventReporter
import org.apache.nifi.reporting.Severity
import org.apache.nifi.security.util.SslContextFactory
import org.apache.nifi.security.util.TemporaryKeyStoreBuilder
import org.apache.nifi.security.util.TlsConfiguration
import org.junit.After
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLServerSocket
/**
 * Tests for {@link ConnectionLoadBalanceServer}. Covers two areas:
 * <ul>
 *   <li>the TLS settings of its listening socket;</li>
 *   <li>how its {@code CommunicateAction} reports repeated {@link SSLPeerUnverifiedException}s.</li>
 * </ul>
 */
@RunWith(JUnit4.class)
class ConnectionLoadBalanceServerTest extends GroovyTestCase {
    private static final String HOSTNAME = "localhost"
    // NOTE(review): fixed port; this test fails if another process on the host already holds it
    private static final int PORT = 54321
    private static final int NUM_THREADS = 1
    private static final int TIMEOUT_MS = 1000

    // TLS material shared by every test in the class; built once in setUpOnce()
    private static TlsConfiguration tlsConfiguration
    private static SSLContext sslContext

    // Server under test; tearDown() stops it if a test created one
    private ConnectionLoadBalanceServer lbServer

    /**
     * Builds the TLS configuration (via {@code TemporaryKeyStoreBuilder}) and the SSLContext derived from it,
     * once for the whole class.
     */
    @BeforeClass
    static void setUpOnce() throws Exception {
        tlsConfiguration = new TemporaryKeyStoreBuilder().build()
        sslContext = SslContextFactory.createSslContext(tlsConfiguration)
    }

    @Before
    void setUp() {
        // No per-test setup is needed; shared TLS state is created in setUpOnce()
    }

    /**
     * Stops the load-balance server if the test started one. This is the single place the server is stopped.
     */
    @After
    void tearDown() {
        if (lbServer) {
            lbServer.stop()
        }
    }

    /**
     * Starts a server with the shared SSLContext and checks that its server socket requires client authentication.
     */
    @Test
    void testRequestPeerListShouldUseTLS() {
        // Arrange
        def mockLBP = [
                receiveFlowFiles: { Socket s, InputStream i, OutputStream o -> null }
        ] as LoadBalanceProtocol
        def mockER = [:] as EventReporter

        // Use the class-level SSLContext from setUpOnce() instead of building (and shadowing it with) a second one
        lbServer = new ConnectionLoadBalanceServer(HOSTNAME, PORT, sslContext, NUM_THREADS, mockLBP, mockER, TIMEOUT_MS)

        // Act
        lbServer.start()

        // Assert
        // The LB server overrides socket creation; the resulting SSLServerSocket must demand client certificates
        SSLServerSocket socket = lbServer.serverSocket as SSLServerSocket
        assert socket.needClientAuth

        // Clean up is handled by tearDown(), so the server is not stopped twice
    }

    /**
     * Feeds the same {@link SSLPeerUnverifiedException} to {@code CommunicateAction.handleTlsError} repeatedly.
     * It then checks that more of those calls were counted as debug-level than as error-level. The check is
     * skipped when the run was slowed by external delays.
     */
    @Test
    void testShouldHandleSSLPeerUnverifiedException() {
        // Arrange
        final long testStartMillis = System.currentTimeMillis()
        final int CONNECTION_ATTEMPTS = 100
        // If this test takes longer than 3 seconds, it's likely because of external delays, which would invalidate the assertions
        final long MAX_TEST_DURATION_MILLIS = 3000

        final String peerDescription = "Test peer"
        final SSLPeerUnverifiedException e = new SSLPeerUnverifiedException("Test exception")

        InputStream socketInputStream = new ByteArrayInputStream("This is the socket input stream".bytes)
        OutputStream socketOutputStream = new ByteArrayOutputStream()
        Socket mockSocket = [
                getInputStream : { -> socketInputStream },
                getOutputStream: { -> socketOutputStream },
        ] as Socket

        LoadBalanceProtocol mockLBProtocol = [
                receiveFlowFiles: { Socket s, InputStream i, OutputStream o -> null }
        ] as LoadBalanceProtocol

        EventReporter mockER = [
                reportEvent: { Severity s, String c, String m -> }
        ] as EventReporter

        // Tally of handleTlsError results: true is counted as "error", false as "debug"
        def output = [debug: 0, error: 0]

        ConnectionLoadBalanceServer.CommunicateAction communicateAction = new ConnectionLoadBalanceServer.CommunicateAction(mockLBProtocol, mockSocket, mockER)
        // Override the threshold to 100 ms. The loop below runs ~100 x 10 ms, so it spans several threshold periods.
        // NOTE(review): presumably errors are escalated at most once per threshold period; confirm in CommunicateAction
        communicateAction.EXCEPTION_THRESHOLD_MILLIS = 100

        try {
            // Act
            CONNECTION_ATTEMPTS.times {
                boolean printedError = communicateAction.handleTlsError(peerDescription, e)
                if (printedError) {
                    output.error++
                } else {
                    output.debug++
                }
                sleep(10)
            }

            // Assert
            // Only enforce if the test completed in a reasonable amount of time (i.e. external delays did not influence the timing)
            long testStopMillis = System.currentTimeMillis()
            long testDurationMillis = testStopMillis - testStartMillis
            if (testDurationMillis <= MAX_TEST_DURATION_MILLIS) {
                assert output.debug > output.error
            }
        } finally {
            // Clean up even if handleTlsError throws or the assertion fails
            communicateAction.stop()
        }
    }
}