blob: 51f12f3ab7cf60ce44fe83c7a23505c3a35ad43a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.websocket.jetty;
import org.apache.nifi.processor.Processor;
import org.apache.nifi.websocket.BinaryMessageConsumer;
import org.apache.nifi.websocket.ConnectedListener;
import org.apache.nifi.websocket.TextMessageConsumer;
import org.apache.nifi.websocket.WebSocketClientService;
import org.apache.nifi.websocket.WebSocketServerService;
import org.apache.nifi.websocket.WebSocketSessionInfo;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;

import java.net.ServerSocket;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
/**
 * Integration test that starts a {@link JettyWebSocketServer} on a free local port (with basic authentication
 * enabled), connects a {@link JettyWebSocketClient} to it, and exchanges text and binary messages in both
 * directions.
 *
 * <p>{@link #isSecure()}, {@link #customizeServer()} and {@link #customizeClient()} are protected hooks so that
 * subclasses can rerun the same scenarios with different service settings.</p>
 */
public class ITJettyWebSocketCommunication {

    private static final String TEXT_MESSAGE_FROM_CLIENT = "Message from client.";
    private static final String TEXT_MESSAGE_FROM_SERVER = "Message from server.";

    protected int serverPort;
    protected String serverPath = "/test";
    protected WebSocketServerService serverService;
    protected ControllerServiceTestContext serverServiceContext;
    protected WebSocketClientService clientService;
    protected ControllerServiceTestContext clientServiceContext;

    /**
     * Indicates whether the client should use the {@code wss} scheme and whether every session is expected to
     * report itself as secure.
     *
     * @return {@code false} in this class
     */
    protected boolean isSecure() {
        return false;
    }

    @Before
    public void setup() throws Exception {
        setupServer();
        setupClient();
    }

    /** Configures and starts the server on a free port, with basic authentication backed by users.properties. */
    private void setupServer() throws Exception {
        // Find an open port. The socket is closed before Jetty binds the port, so there is a small window in
        // which another process could take it.
        try (final ServerSocket serverSocket = new ServerSocket(0)) {
            serverPort = serverSocket.getLocalPort();
        }

        serverService = new JettyWebSocketServer();
        serverServiceContext = new ControllerServiceTestContext(serverService, "JettyWebSocketServer1");
        serverServiceContext.setCustomValue(JettyWebSocketServer.LISTEN_PORT, String.valueOf(serverPort));
        serverServiceContext.setCustomValue(JettyWebSocketServer.BASIC_AUTH, "true");
        serverServiceContext.setCustomValue(JettyWebSocketServer.USERS_PROPERTIES_FILE,
                resolveTestResource("/users.properties"));
        serverServiceContext.setCustomValue(JettyWebSocketServer.AUTH_ROLES, "user,test");

        customizeServer();

        serverService.initialize(serverServiceContext.getInitializationContext());
        serverService.startServer(serverServiceContext.getConfigurationContext());
    }

    /**
     * Resolves a classpath resource to a file-system path.
     *
     * <p>The conversion goes through {@link URL#toURI()} so that percent-escapes (for example, spaces in the
     * checkout directory) are decoded. {@link URL#getPath()} would return them still encoded.</p>
     *
     * @param name absolute classpath resource name
     * @return the file-system path of the resource
     * @throws NullPointerException if the resource is not on the classpath
     */
    private String resolveTestResource(final String name) throws Exception {
        final URL resource = Objects.requireNonNull(getClass().getResource(name), "Test resource not found: " + name);
        return Paths.get(resource.toURI()).toString();
    }

    /** Hook for subclasses to adjust {@link #serverServiceContext} before the server is initialized; no-op here. */
    protected void customizeServer() {
    }

    /** Configures and starts the client, pointing it at the test server and authenticating as {@code user2}. */
    private void setupClient() throws Exception {
        clientService = new JettyWebSocketClient();
        clientServiceContext = new ControllerServiceTestContext(clientService, "JettyWebSocketClient1");
        clientServiceContext.setCustomValue(JettyWebSocketClient.WS_URI,
                (isSecure() ? "wss" : "ws") + "://localhost:" + serverPort + serverPath);
        clientServiceContext.setCustomValue(JettyWebSocketClient.USER_NAME, "user2");
        clientServiceContext.setCustomValue(JettyWebSocketClient.USER_PASSWORD, "password2");

        customizeClient();

        clientService.initialize(clientServiceContext.getInitializationContext());
        clientService.startClient(clientServiceContext.getConfigurationContext());
    }

    /** Hook for subclasses to adjust {@link #clientServiceContext} before the client is initialized; no-op here. */
    protected void customizeClient() {
    }

    /**
     * Stops the client and then the server.
     *
     * <p>The server is stopped even if stopping the client throws. A service is skipped if {@link #setup()}
     * failed before creating it, so the original setup failure is not masked by a NullPointerException.</p>
     */
    @After
    public void teardown() throws Exception {
        try {
            if (clientService != null) {
                clientService.stopClient();
            }
        } finally {
            if (serverService != null) {
                serverService.stopServer();
            }
        }
    }

    protected interface MockWebSocketProcessor extends Processor, ConnectedListener, TextMessageConsumer, BinaryMessageConsumer {
    }

    private boolean isWindowsEnvironment() {
        // Locale.ROOT keeps the comparison independent of the default locale (e.g. Turkish dotless i).
        return System.getProperty("os.name").toLowerCase(Locale.ROOT).startsWith("windows");
    }

    /**
     * Creates a mock processor whose callbacks validate each event and count down the matching latch.
     *
     * <p>Note: the callbacks are invoked by the WebSocket service, presumably on one of its own threads. An
     * assertion failure inside them will therefore most likely surface in the test as a latch-await timeout,
     * not as the assertion message itself.</p>
     *
     * @param identifier      value returned by {@code getIdentifier()}
     * @param connectedLatch  counted down on a valid connected event
     * @param sessionIdRef    receives the session id of every connected event
     * @param textLatch       counted down when the expected text message is consumed
     * @param binaryLatch     counted down when the expected binary message is consumed
     * @param expectedMessage the text (and, as bytes, the binary payload) this processor expects to receive
     * @return the stubbed mock
     */
    private MockWebSocketProcessor mockProcessor(final String identifier,
                                                 final CountDownLatch connectedLatch,
                                                 final AtomicReference<String> sessionIdRef,
                                                 final CountDownLatch textLatch,
                                                 final CountDownLatch binaryLatch,
                                                 final String expectedMessage) {
        final MockWebSocketProcessor processor = mock(MockWebSocketProcessor.class);
        doReturn(identifier).when(processor).getIdentifier();
        doAnswer(invocation -> assertConnectedEvent(connectedLatch, sessionIdRef, invocation))
                .when(processor).connected(any(WebSocketSessionInfo.class));
        doAnswer(invocation -> assertConsumeTextMessage(textLatch, expectedMessage, invocation))
                .when(processor).consume(any(WebSocketSessionInfo.class), anyString());
        doAnswer(invocation -> assertConsumeBinaryMessage(binaryLatch, expectedMessage, invocation))
                .when(processor).consume(any(WebSocketSessionInfo.class), any(byte[].class), anyInt(), anyInt());
        return processor;
    }

    @Test
    public void testClientServerCommunication() throws Exception {
        assumeFalse(isWindowsEnvironment());

        // Expectations.
        final CountDownLatch serverIsConnectedByClient = new CountDownLatch(1);
        final CountDownLatch clientConnectedServer = new CountDownLatch(1);
        final CountDownLatch serverReceivedTextMessageFromClient = new CountDownLatch(1);
        final CountDownLatch serverReceivedBinaryMessageFromClient = new CountDownLatch(1);
        final CountDownLatch clientReceivedTextMessageFromServer = new CountDownLatch(1);
        final CountDownLatch clientReceivedBinaryMessageFromServer = new CountDownLatch(1);
        final AtomicReference<String> serverSessionIdRef = new AtomicReference<>();
        final AtomicReference<String> clientSessionIdRef = new AtomicReference<>();

        final MockWebSocketProcessor serverProcessor = mockProcessor("serverProcessor1", serverIsConnectedByClient,
                serverSessionIdRef, serverReceivedTextMessageFromClient, serverReceivedBinaryMessageFromClient,
                TEXT_MESSAGE_FROM_CLIENT);
        serverService.registerProcessor(serverPath, serverProcessor);

        final String clientId = "client1";
        final MockWebSocketProcessor clientProcessor = mockProcessor("clientProcessor1", clientConnectedServer,
                clientSessionIdRef, clientReceivedTextMessageFromServer, clientReceivedBinaryMessageFromServer,
                TEXT_MESSAGE_FROM_SERVER);
        clientService.registerProcessor(clientId, clientProcessor);

        clientService.connect(clientId);

        assertTrue("WebSocket client should be able to fire connected event.", clientConnectedServer.await(5, TimeUnit.SECONDS));
        assertTrue("WebSocket server should be able to fire connected event.", serverIsConnectedByClient.await(5, TimeUnit.SECONDS));

        clientService.sendMessage(clientId, clientSessionIdRef.get(), sender -> sender.sendString(TEXT_MESSAGE_FROM_CLIENT));
        clientService.sendMessage(clientId, clientSessionIdRef.get(),
                sender -> sender.sendBinary(ByteBuffer.wrap(TEXT_MESSAGE_FROM_CLIENT.getBytes(StandardCharsets.UTF_8))));

        assertTrue("WebSocket server should be able to consume text message.", serverReceivedTextMessageFromClient.await(5, TimeUnit.SECONDS));
        assertTrue("WebSocket server should be able to consume binary message.", serverReceivedBinaryMessageFromClient.await(5, TimeUnit.SECONDS));

        serverService.sendMessage(serverPath, serverSessionIdRef.get(), sender -> sender.sendString(TEXT_MESSAGE_FROM_SERVER));
        serverService.sendMessage(serverPath, serverSessionIdRef.get(),
                sender -> sender.sendBinary(ByteBuffer.wrap(TEXT_MESSAGE_FROM_SERVER.getBytes(StandardCharsets.UTF_8))));

        assertTrue("WebSocket client should be able to consume text message.", clientReceivedTextMessageFromServer.await(5, TimeUnit.SECONDS));
        assertTrue("WebSocket client should be able to consume binary message.", clientReceivedBinaryMessageFromServer.await(5, TimeUnit.SECONDS));

        clientService.deregisterProcessor(clientId, clientProcessor);
        serverService.deregisterProcessor(serverPath, serverProcessor);
    }

    @Test
    public void testClientServerCommunicationRecovery() throws Exception {
        assumeFalse(isWindowsEnvironment());

        // Expectations.
        final CountDownLatch serverIsConnectedByClient = new CountDownLatch(1);
        final CountDownLatch clientConnectedServer = new CountDownLatch(1);
        final CountDownLatch serverReceivedTextMessageFromClient = new CountDownLatch(1);
        final CountDownLatch serverReceivedBinaryMessageFromClient = new CountDownLatch(1);
        final CountDownLatch clientReceivedTextMessageFromServer = new CountDownLatch(1);
        final CountDownLatch clientReceivedBinaryMessageFromServer = new CountDownLatch(1);
        final AtomicReference<String> serverSessionIdRef = new AtomicReference<>();
        final AtomicReference<String> clientSessionIdRef = new AtomicReference<>();

        final MockWebSocketProcessor serverProcessor = mockProcessor("serverProcessor1", serverIsConnectedByClient,
                serverSessionIdRef, serverReceivedTextMessageFromClient, serverReceivedBinaryMessageFromClient,
                TEXT_MESSAGE_FROM_CLIENT);
        serverService.registerProcessor(serverPath, serverProcessor);

        final String clientId = "client1";
        final MockWebSocketProcessor clientProcessor = mockProcessor("clientProcessor1", clientConnectedServer,
                clientSessionIdRef, clientReceivedTextMessageFromServer, clientReceivedBinaryMessageFromServer,
                TEXT_MESSAGE_FROM_SERVER);
        clientService.registerProcessor(clientId, clientProcessor);

        clientService.connect(clientId, Collections.emptyMap());

        assertTrue("WebSocket client should be able to fire connected event.", clientConnectedServer.await(5, TimeUnit.SECONDS));
        assertTrue("WebSocket server should be able to fire connected event.", serverIsConnectedByClient.await(5, TimeUnit.SECONDS));

        // Nothing happens if maintenance is executed while sessions are alive.
        ((JettyWebSocketClient) clientService).maintainSessions();

        // Restart server.
        serverService.stopServer();
        serverService.startServer(serverServiceContext.getConfigurationContext());

        // Sessions will be recreated with the same session ids.
        ((JettyWebSocketClient) clientService).maintainSessions();

        clientService.sendMessage(clientId, clientSessionIdRef.get(), sender -> sender.sendString(TEXT_MESSAGE_FROM_CLIENT));
        clientService.sendMessage(clientId, clientSessionIdRef.get(),
                sender -> sender.sendBinary(ByteBuffer.wrap(TEXT_MESSAGE_FROM_CLIENT.getBytes(StandardCharsets.UTF_8))));

        assertTrue("WebSocket server should be able to consume text message.", serverReceivedTextMessageFromClient.await(5, TimeUnit.SECONDS));
        assertTrue("WebSocket server should be able to consume binary message.", serverReceivedBinaryMessageFromClient.await(5, TimeUnit.SECONDS));

        serverService.sendMessage(serverPath, serverSessionIdRef.get(), sender -> sender.sendString(TEXT_MESSAGE_FROM_SERVER));
        serverService.sendMessage(serverPath, serverSessionIdRef.get(),
                sender -> sender.sendBinary(ByteBuffer.wrap(TEXT_MESSAGE_FROM_SERVER.getBytes(StandardCharsets.UTF_8))));

        assertTrue("WebSocket client should be able to consume text message.", clientReceivedTextMessageFromServer.await(5, TimeUnit.SECONDS));
        assertTrue("WebSocket client should be able to consume binary message.", clientReceivedBinaryMessageFromServer.await(5, TimeUnit.SECONDS));

        clientService.deregisterProcessor(clientId, clientProcessor);
        serverService.deregisterProcessor(serverPath, serverProcessor);
    }

    /**
     * Mock answer for {@code connected(WebSocketSessionInfo)}.
     *
     * <p>Validates the session info, stores its session id in {@code sessionIdRef} and counts down
     * {@code latch}.</p>
     *
     * @return {@code null} (the stubbed method is void)
     */
    protected Object assertConnectedEvent(CountDownLatch latch, AtomicReference<String> sessionIdRef, InvocationOnMock invocation) {
        final WebSocketSessionInfo sessionInfo = invocation.getArgument(0);
        assertNotNull(sessionInfo.getLocalAddress());
        assertNotNull(sessionInfo.getRemoteAddress());
        assertNotNull(sessionInfo.getSessionId());
        assertEquals(isSecure(), sessionInfo.isSecure());
        sessionIdRef.set(sessionInfo.getSessionId());
        latch.countDown();
        return null;
    }

    /**
     * Mock answer for {@code consume(WebSocketSessionInfo, String)}.
     *
     * <p>Validates the session info, checks that the received text equals {@code expectedMessage} and counts
     * down {@code latch}.</p>
     *
     * @return {@code null} (the stubbed method is void)
     */
    protected Object assertConsumeTextMessage(CountDownLatch latch, String expectedMessage, InvocationOnMock invocation) {
        final WebSocketSessionInfo sessionInfo = invocation.getArgument(0);
        assertNotNull(sessionInfo.getLocalAddress());
        assertNotNull(sessionInfo.getRemoteAddress());
        assertNotNull(sessionInfo.getSessionId());
        assertEquals(isSecure(), sessionInfo.isSecure());
        final String receivedMessage = invocation.getArgument(1);
        assertNotNull(receivedMessage);
        assertEquals(expectedMessage, receivedMessage);
        latch.countDown();
        return null;
    }

    /**
     * Mock answer for {@code consume(WebSocketSessionInfo, byte[], int, int)}.
     *
     * <p>Validates the session info and checks that the payload is exactly the UTF-8 bytes of
     * {@code expectedMessage}, spanning the whole array (offset 0, full length). Then counts down
     * {@code latch}.</p>
     *
     * @return {@code null} (the stubbed method is void)
     */
    protected Object assertConsumeBinaryMessage(CountDownLatch latch, String expectedMessage, InvocationOnMock invocation) {
        final WebSocketSessionInfo sessionInfo = invocation.getArgument(0);
        assertNotNull(sessionInfo.getLocalAddress());
        assertNotNull(sessionInfo.getRemoteAddress());
        assertNotNull(sessionInfo.getSessionId());
        assertEquals(isSecure(), sessionInfo.isSecure());
        final byte[] receivedMessage = invocation.getArgument(1);
        final byte[] expectedBinary = expectedMessage.getBytes(StandardCharsets.UTF_8);
        final int offset = invocation.getArgument(2);
        final int length = invocation.getArgument(3);
        assertNotNull(receivedMessage);
        assertEquals(expectedBinary.length, receivedMessage.length);
        assertEquals(expectedMessage, new String(receivedMessage, StandardCharsets.UTF_8));
        assertEquals(0, offset);
        assertEquals(expectedBinary.length, length);
        latch.countDown();
        return null;
    }
}