| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.nifi.processors.websocket; |
| |
| import org.apache.nifi.processor.ProcessSessionFactory; |
| import org.apache.nifi.processor.Relationship; |
| import org.apache.nifi.provenance.ProvenanceEventRecord; |
| import org.apache.nifi.provenance.ProvenanceEventType; |
| import org.apache.nifi.remote.io.socket.NetworkUtils; |
| import org.apache.nifi.reporting.InitializationException; |
| import org.apache.nifi.util.MockFlowFile; |
| import org.apache.nifi.util.MockProcessSession; |
| import org.apache.nifi.util.SharedSessionState; |
| import org.apache.nifi.util.TestRunner; |
| import org.apache.nifi.util.TestRunners; |
| import org.apache.nifi.websocket.AbstractWebSocketSession; |
| import org.apache.nifi.websocket.WebSocketClientService; |
| import org.apache.nifi.websocket.WebSocketMessage; |
| import org.apache.nifi.websocket.WebSocketSession; |
| import org.apache.nifi.websocket.jetty.JettyWebSocketClient; |
| import org.apache.nifi.websocket.jetty.JettyWebSocketServer; |
| import org.junit.Assert; |
| import org.junit.Test; |
| |
| import java.net.InetSocketAddress; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Set; |
| import java.util.concurrent.atomic.AtomicLong; |
| |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertTrue; |
| import static org.mockito.Mockito.doAnswer; |
| import static org.mockito.Mockito.mock; |
| import static org.mockito.Mockito.spy; |
| import static org.mockito.Mockito.when; |
| |
| |
| public class TestConnectWebSocket extends TestListenWebSocket { |
| |
| @Test |
| public void testSuccess() throws Exception { |
| final TestRunner runner = TestRunners.newTestRunner(ConnectWebSocket.class); |
| runner.setIncomingConnection(false); |
| |
| final ConnectWebSocket processor = (ConnectWebSocket) runner.getProcessor(); |
| |
| final SharedSessionState sharedSessionState = new SharedSessionState(processor, new AtomicLong(0)); |
| // Use this custom session factory implementation so that createdSessions can be read from test case, |
| // because MockSessionFactory doesn't expose it. |
| final Set<MockProcessSession> createdSessions = new HashSet<>(); |
| final ProcessSessionFactory sessionFactory = () -> { |
| final MockProcessSession session = new MockProcessSession(sharedSessionState, processor); |
| createdSessions.add(session); |
| return session; |
| }; |
| |
| final WebSocketClientService service = mock(WebSocketClientService.class); |
| |
| final WebSocketSession webSocketSession = spy(AbstractWebSocketSession.class); |
| when(webSocketSession.getSessionId()).thenReturn("ws-session-id"); |
| when(webSocketSession.getLocalAddress()).thenReturn(new InetSocketAddress("localhost", 12345)); |
| when(webSocketSession.getRemoteAddress()).thenReturn(new InetSocketAddress("example.com", 80)); |
| |
| final String serviceId = "ws-service"; |
| final String endpointId = "client-1"; |
| final String textMessageFromServer = "message from server."; |
| when(service.getIdentifier()).thenReturn(serviceId); |
| when(service.getTargetUri()).thenReturn("ws://example.com/web-socket"); |
| doAnswer(invocation -> { |
| processor.connected(webSocketSession); |
| // Two times. |
| processor.consume(webSocketSession, textMessageFromServer); |
| processor.consume(webSocketSession, textMessageFromServer); |
| // Three times. |
| final byte[] binaryMessage = textMessageFromServer.getBytes(); |
| processor.consume(webSocketSession, binaryMessage, 0, binaryMessage.length); |
| processor.consume(webSocketSession, binaryMessage, 0, binaryMessage.length); |
| processor.consume(webSocketSession, binaryMessage, 0, binaryMessage.length); |
| return null; |
| }).when(service).connect(endpointId); |
| runner.addControllerService(serviceId, service); |
| |
| runner.enableControllerService(service); |
| |
| runner.setProperty(ConnectWebSocket.PROP_WEBSOCKET_CLIENT_SERVICE, serviceId); |
| runner.setProperty(ConnectWebSocket.PROP_WEBSOCKET_CLIENT_ID, endpointId); |
| |
| processor.onTrigger(runner.getProcessContext(), sessionFactory); |
| |
| final Map<Relationship, List<MockFlowFile>> transferredFlowFiles = getAllTransferredFlowFiles(createdSessions, processor); |
| |
| List<MockFlowFile> connectedFlowFiles = transferredFlowFiles.get(AbstractWebSocketGatewayProcessor.REL_CONNECTED); |
| assertEquals(1, connectedFlowFiles.size()); |
| connectedFlowFiles.forEach(ff -> { |
| assertFlowFile(webSocketSession, serviceId, endpointId, ff, null); |
| }); |
| |
| List<MockFlowFile> textFlowFiles = transferredFlowFiles.get(AbstractWebSocketGatewayProcessor.REL_MESSAGE_TEXT); |
| assertEquals(2, textFlowFiles.size()); |
| textFlowFiles.forEach(ff -> { |
| assertFlowFile(webSocketSession, serviceId, endpointId, ff, WebSocketMessage.Type.TEXT); |
| }); |
| |
| List<MockFlowFile> binaryFlowFiles = transferredFlowFiles.get(AbstractWebSocketGatewayProcessor.REL_MESSAGE_BINARY); |
| assertEquals(3, binaryFlowFiles.size()); |
| binaryFlowFiles.forEach(ff -> { |
| assertFlowFile(webSocketSession, serviceId, endpointId, ff, WebSocketMessage.Type.BINARY); |
| }); |
| |
| final List<ProvenanceEventRecord> provenanceEvents = sharedSessionState.getProvenanceEvents(); |
| assertEquals(6, provenanceEvents.size()); |
| assertTrue(provenanceEvents.stream().allMatch(event -> ProvenanceEventType.RECEIVE.equals(event.getEventType()))); |
| } |
| |
| @Test |
| public void testDynamicUrlsParsedFromFlowFileAndAbleToConnect() throws InitializationException { |
| // Start websocket server |
| final int port = NetworkUtils.availablePort(); |
| TestRunner webSocketListener = getListenWebSocket(port); |
| webSocketListener.run(1, false); |
| |
| final TestRunner runner = TestRunners.newTestRunner(ConnectWebSocket.class); |
| |
| final String serviceId = "ws-service"; |
| final String endpointId = "client-1"; |
| |
| Map<String, String> attributes = new HashMap<>(); |
| attributes.put("dynamicUrlPart", "test"); |
| MockFlowFile flowFile = new MockFlowFile(1L); |
| flowFile.putAttributes(attributes); |
| runner.enqueue(flowFile); |
| |
| attributes.put("dynamicUrlPart", "test2"); |
| MockFlowFile flowFileWithWrongUrl = new MockFlowFile(2L); |
| flowFileWithWrongUrl.putAttributes(attributes); |
| runner.enqueue(flowFileWithWrongUrl); |
| |
| JettyWebSocketClient service = new JettyWebSocketClient(); |
| |
| |
| runner.addControllerService(serviceId, service); |
| runner.setProperty(service, JettyWebSocketClient.WS_URI, String.format("ws://localhost:%s/${dynamicUrlPart}", port)); |
| runner.enableControllerService(service); |
| |
| runner.setProperty(ConnectWebSocket.PROP_WEBSOCKET_CLIENT_SERVICE, serviceId); |
| runner.setProperty(ConnectWebSocket.PROP_WEBSOCKET_CLIENT_ID, endpointId); |
| |
| runner.run(1, false); |
| |
| final List<MockFlowFile> flowFilesForRelationship = runner.getFlowFilesForRelationship(ConnectWebSocket.REL_CONNECTED); |
| assertEquals(1, flowFilesForRelationship.size()); |
| |
| final AssertionError assertionError = Assert.assertThrows(AssertionError.class, () -> runner.run(1)); |
| assertTrue(assertionError.getCause().getLocalizedMessage().contains("Failed to renew session and connect to WebSocket service")); |
| |
| runner.stop(); |
| webSocketListener.stop(); |
| } |
| |
| private TestRunner getListenWebSocket(final int port) throws InitializationException { |
| final TestRunner runner = TestRunners.newTestRunner(ListenWebSocket.class); |
| |
| final String serviceId = "ws-server-service"; |
| JettyWebSocketServer service = new JettyWebSocketServer(); |
| runner.addControllerService(serviceId, service); |
| runner.setProperty(service, JettyWebSocketServer.LISTEN_PORT, String.valueOf(port)); |
| runner.enableControllerService(service); |
| |
| runner.setProperty(ListenWebSocket.PROP_WEBSOCKET_SERVER_SERVICE, serviceId); |
| runner.setProperty(ListenWebSocket.PROP_SERVER_URL_PATH, "/test"); |
| |
| return runner; |
| } |
| } |