blob: bf6363aa3e82dd70b475183e5c353eb90d587362 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.websocket;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_ENDPOINT_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_FAILURE_DETAIL;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_LOCAL_ADDRESS;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_MESSAGE_TYPE;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_REMOTE_ADDRESS;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_SESSION_ID;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.nifi.controller.ControllerService;
import org.apache.nifi.provenance.ProvenanceEventRecord;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.apache.nifi.websocket.AbstractWebSocketSession;
import org.apache.nifi.websocket.SendMessage;
import org.apache.nifi.websocket.WebSocketMessage;
import org.apache.nifi.websocket.WebSocketService;
import org.apache.nifi.websocket.WebSocketSession;
import org.junit.Test;
/**
 * Unit tests for {@link PutWebSocket}.
 *
 * <p>Each test registers a Mockito spy as the controller service, enqueues FlowFiles carrying
 * the WebSocket routing attributes, runs the processor, and then checks how the FlowFiles were
 * routed and how many provenance events were recorded.</p>
 */
public class TestPutWebSocket {

    /** Identifier under which the spied controller service is registered. */
    private static final String SERVICE_ID = "ws-service";

    /** Endpoint id placed on every enqueued FlowFile. */
    private static final String ENDPOINT_ID = "client-1";

    /** Payload used for every enqueued FlowFile. */
    private static final String TEXT_MESSAGE = "message from server.";

    /**
     * Asserts that the WebSocket attributes of a FlowFile match the given session, service,
     * endpoint and message type.
     *
     * @param webSocketSession session whose id and addresses are expected on the FlowFile
     * @param serviceId expected controller service id attribute
     * @param endpointId expected endpoint id attribute
     * @param ff the FlowFile to inspect
     * @param messageType expected message type, or {@code null} if the attribute must be absent
     */
    private void assertFlowFile(WebSocketSession webSocketSession, String serviceId, String endpointId, MockFlowFile ff, WebSocketMessage.Type messageType) {
        assertEquals(serviceId, ff.getAttribute(ATTR_WS_CS_ID));
        assertEquals(webSocketSession.getSessionId(), ff.getAttribute(ATTR_WS_SESSION_ID));
        assertEquals(endpointId, ff.getAttribute(ATTR_WS_ENDPOINT_ID));
        assertEquals(webSocketSession.getLocalAddress().toString(), ff.getAttribute(ATTR_WS_LOCAL_ADDRESS));
        assertEquals(webSocketSession.getRemoteAddress().toString(), ff.getAttribute(ATTR_WS_REMOTE_ADDRESS));
        assertEquals(messageType != null ? messageType.name() : null, ff.getAttribute(ATTR_WS_MESSAGE_TYPE));
    }

    /**
     * Creates a spied {@link AbstractWebSocketSession} stubbed with a fixed session id, local and
     * remote addresses, and transit URI.
     *
     * @return the stubbed session
     */
    private WebSocketSession getWebSocketSession() {
        final WebSocketSession webSocketSession = spy(AbstractWebSocketSession.class);
        when(webSocketSession.getSessionId()).thenReturn("ws-session-id");
        when(webSocketSession.getLocalAddress()).thenReturn(new InetSocketAddress("localhost", 12345));
        when(webSocketSession.getRemoteAddress()).thenReturn(new InetSocketAddress("example.com", 80));
        when(webSocketSession.getTransitUri()).thenReturn("ws://example.com/web-socket");
        return webSocketSession;
    }

    /** Creates a spy of the given service type whose identifier is {@link #SERVICE_ID}. */
    private static <T extends ControllerService> T newService(final Class<T> serviceType) {
        final T service = spy(serviceType);
        when(service.getIdentifier()).thenReturn(SERVICE_ID);
        return service;
    }

    /** Adds the service to the runner under {@link #SERVICE_ID} and enables it. */
    private static void registerService(final TestRunner runner, final ControllerService service) throws Exception {
        runner.addControllerService(SERVICE_ID, service);
        runner.enableControllerService(service);
    }

    /** Stubs {@code sendMessage} so that the supplied {@link SendMessage} callback is invoked with the given session. */
    private static void deliverMessagesTo(final WebSocketService service, final WebSocketSession session) throws Exception {
        doAnswer(invocation -> {
            final SendMessage sendMessage = invocation.getArgument(2);
            sendMessage.send(session);
            return null;
        }).when(service).sendMessage(anyString(), anyString(), any(SendMessage.class));
    }

    /**
     * Builds the FlowFile attributes that route a message: controller service id, endpoint id and,
     * when {@code sessionId} is not {@code null}, the session id.
     */
    private static Map<String, String> routingAttributes(final String serviceId, final String sessionId) {
        final Map<String, String> attributes = new HashMap<>();
        attributes.put(ATTR_WS_CS_ID, serviceId);
        attributes.put(ATTR_WS_ENDPOINT_ID, ENDPOINT_ID);
        if (sessionId != null) {
            attributes.put(ATTR_WS_SESSION_ID, sessionId);
        }
        return attributes;
    }

    /**
     * Asserts that nothing went to success, exactly one FlowFile went to failure carrying a
     * failure-detail attribute, and no provenance events were recorded.
     */
    private static void assertSingleFailureWithDetail(final TestRunner runner) {
        final List<MockFlowFile> succeededFlowFiles = runner.getFlowFilesForRelationship(PutWebSocket.REL_SUCCESS);
        assertEquals(0, succeededFlowFiles.size());
        final List<MockFlowFile> failedFlowFiles = runner.getFlowFilesForRelationship(PutWebSocket.REL_FAILURE);
        assertEquals(1, failedFlowFiles.size());
        final MockFlowFile failedFlowFile = failedFlowFiles.iterator().next();
        assertNotNull(failedFlowFile.getAttribute(ATTR_WS_FAILURE_DETAIL));
        final List<ProvenanceEventRecord> provenanceEvents = runner.getProvenanceEvents();
        assertEquals(0, provenanceEvents.size());
    }

    /**
     * A FlowFile without a session id attribute is routed to success and records no provenance
     * event when the service's {@code sendMessage} is left unstubbed.
     */
    @Test
    public void testSessionIsNotSpecified() throws Exception {
        final TestRunner runner = TestRunners.newTestRunner(PutWebSocket.class);
        final WebSocketService service = newService(WebSocketService.class);
        registerService(runner, service);

        runner.enqueue(TEXT_MESSAGE, routingAttributes(SERVICE_ID, null));
        runner.run();

        // Since NIFI-3318, not specifying a session id sends to all clients, so the FlowFile
        // is expected on success rather than failure.
        final List<MockFlowFile> succeededFlowFiles = runner.getFlowFilesForRelationship(PutWebSocket.REL_SUCCESS);
        assertEquals(1, succeededFlowFiles.size());
        final List<MockFlowFile> failedFlowFiles = runner.getFlowFilesForRelationship(PutWebSocket.REL_FAILURE);
        assertEquals(0, failedFlowFiles.size());
        final List<ProvenanceEventRecord> provenanceEvents = runner.getProvenanceEvents();
        assertEquals(0, provenanceEvents.size());
    }

    /**
     * A FlowFile whose controller service id attribute does not match any registered service is
     * routed to failure with a failure detail.
     */
    @Test
    public void testServiceIsNotFound() throws Exception {
        final TestRunner runner = TestRunners.newTestRunner(PutWebSocket.class);
        final ControllerService service = newService(ControllerService.class);
        final WebSocketSession webSocketSession = getWebSocketSession();
        registerService(runner, service);

        runner.enqueue(TEXT_MESSAGE, routingAttributes("different-service-id", webSocketSession.getSessionId()));
        runner.run();

        assertSingleFailureWithDetail(runner);
    }

    /**
     * A FlowFile that references a registered controller service which is not a
     * {@link WebSocketService} is routed to failure with a failure detail.
     */
    @Test
    public void testServiceIsNotWebSocketService() throws Exception {
        final TestRunner runner = TestRunners.newTestRunner(PutWebSocket.class);
        final ControllerService service = newService(ControllerService.class);
        final WebSocketSession webSocketSession = getWebSocketSession();
        registerService(runner, service);

        runner.enqueue(TEXT_MESSAGE, routingAttributes(SERVICE_ID, webSocketSession.getSessionId()));
        runner.run();

        assertSingleFailureWithDetail(runner);
    }

    /**
     * When the session throws an {@link IOException} from {@code sendString}, the FlowFile is
     * routed to failure with a failure detail.
     */
    @Test
    public void testSendFailure() throws Exception {
        final TestRunner runner = TestRunners.newTestRunner(PutWebSocket.class);
        final WebSocketService service = newService(WebSocketService.class);
        final WebSocketSession webSocketSession = getWebSocketSession();
        deliverMessagesTo(service, webSocketSession);
        doThrow(new IOException("Sending message failed.")).when(webSocketSession).sendString(anyString());
        registerService(runner, service);

        runner.enqueue(TEXT_MESSAGE, routingAttributes(SERVICE_ID, webSocketSession.getSessionId()));
        runner.run();

        assertSingleFailureWithDetail(runner);
    }

    /**
     * One TEXT and one BINARY FlowFile are both routed to success with the expected WebSocket
     * attributes, and two provenance events are recorded in total.
     */
    @Test
    public void testSuccess() throws Exception {
        final TestRunner runner = TestRunners.newTestRunner(PutWebSocket.class);
        final WebSocketService service = newService(WebSocketService.class);
        final WebSocketSession webSocketSession = getWebSocketSession();
        deliverMessagesTo(service, webSocketSession);
        registerService(runner, service);
        // The message type property is an expression that references the FlowFile's message type attribute.
        runner.setProperty(PutWebSocket.PROP_WS_MESSAGE_TYPE, "${" + ATTR_WS_MESSAGE_TYPE + "}");

        // Enqueue 1st file as Text.
        final Map<String, String> attributes = routingAttributes(SERVICE_ID, webSocketSession.getSessionId());
        attributes.put(ATTR_WS_MESSAGE_TYPE, WebSocketMessage.Type.TEXT.name());
        runner.enqueue(TEXT_MESSAGE, attributes);

        // Enqueue 2nd file as Binary; encode explicitly so the payload does not depend on the platform charset.
        attributes.put(ATTR_WS_MESSAGE_TYPE, WebSocketMessage.Type.BINARY.name());
        runner.enqueue(TEXT_MESSAGE.getBytes(StandardCharsets.UTF_8), attributes);

        runner.run(2);

        final List<MockFlowFile> succeededFlowFiles = runner.getFlowFilesForRelationship(PutWebSocket.REL_SUCCESS);
        assertEquals(2, succeededFlowFiles.size());
        assertFlowFile(webSocketSession, SERVICE_ID, ENDPOINT_ID, succeededFlowFiles.get(0), WebSocketMessage.Type.TEXT);
        assertFlowFile(webSocketSession, SERVICE_ID, ENDPOINT_ID, succeededFlowFiles.get(1), WebSocketMessage.Type.BINARY);
        final List<MockFlowFile> failedFlowFiles = runner.getFlowFilesForRelationship(PutWebSocket.REL_FAILURE);
        assertEquals(0, failedFlowFiles.size());
        final List<ProvenanceEventRecord> provenanceEvents = runner.getProvenanceEvents();
        assertEquals(2, provenanceEvents.size());
    }
}