blob: c516464eb214e5007588cc1515de68a84ab80184 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.websocket;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.annotation.behavior.TriggerSerially;
import org.apache.nifi.annotation.lifecycle.OnStopped;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.AbstractSessionFactoryProcessor;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.ProcessSessionFactory;
import org.apache.nifi.processor.ProcessorInitializationContext;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.websocket.BinaryMessageConsumer;
import org.apache.nifi.websocket.ConnectedListener;
import org.apache.nifi.websocket.TextMessageConsumer;
import org.apache.nifi.websocket.WebSocketClientService;
import org.apache.nifi.websocket.WebSocketConfigurationException;
import org.apache.nifi.websocket.WebSocketConnectedMessage;
import org.apache.nifi.websocket.WebSocketMessage;
import org.apache.nifi.websocket.WebSocketService;
import org.apache.nifi.websocket.WebSocketSessionInfo;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_ENDPOINT_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_LOCAL_ADDRESS;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_MESSAGE_TYPE;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_REMOTE_ADDRESS;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_SESSION_ID;
@TriggerSerially
public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionFactoryProcessor implements ConnectedListener, TextMessageConsumer, BinaryMessageConsumer {
protected volatile ComponentLog logger;
protected volatile ProcessSessionFactory processSessionFactory;
protected WebSocketService webSocketService;
protected String endpointId;
public static final Relationship REL_CONNECTED = new Relationship.Builder()
.name("connected")
.description("The WebSocket session is established")
.build();
public static final Relationship REL_MESSAGE_TEXT = new Relationship.Builder()
.name("text message")
.description("The WebSocket text message output")
.build();
public static final Relationship REL_MESSAGE_BINARY = new Relationship.Builder()
.name("binary message")
.description("The WebSocket binary message output")
.build();
static Set<Relationship> getAbstractRelationships() {
final Set<Relationship> relationships = new HashSet<>();
relationships.add(REL_CONNECTED);
relationships.add(REL_MESSAGE_TEXT);
relationships.add(REL_MESSAGE_BINARY);
return relationships;
}
@Override
protected void init(final ProcessorInitializationContext context) {
logger = getLogger();
}
@FunctionalInterface
public interface WebSocketFunction {
void execute(final WebSocketService webSocketService) throws IOException, WebSocketConfigurationException;
}
@Override
public void connected(WebSocketSessionInfo sessionInfo) {
final WebSocketMessage message = new WebSocketConnectedMessage(sessionInfo);
sessionInfo.setTransitUri(getTransitUri(sessionInfo));
enqueueMessage(message);
}
@Override
public void consume(WebSocketSessionInfo sessionInfo, String messageStr) {
final WebSocketMessage message = new WebSocketMessage(sessionInfo);
sessionInfo.setTransitUri(getTransitUri(sessionInfo));
message.setPayload(messageStr);
enqueueMessage(message);
}
@Override
public void consume(WebSocketSessionInfo sessionInfo, byte[] payload, int offset, int length) {
final WebSocketMessage message = new WebSocketMessage(sessionInfo);
sessionInfo.setTransitUri(getTransitUri(sessionInfo));
message.setPayload(payload, offset, length);
enqueueMessage(message);
}
// @OnScheduled can not report error messages well on bulletin since it's an async method.
// So, let's do it in onTrigger().
public void onWebSocketServiceReady(final WebSocketService webSocketService, final ProcessContext context) throws IOException {
if (webSocketService instanceof WebSocketClientService) {
// If it's a ws client, then connect to the remote here.
// Otherwise, ws server is already started at WebSocketServerService
WebSocketClientService webSocketClientService = (WebSocketClientService) webSocketService;
if (context.hasIncomingConnection()) {
final ProcessSession session = processSessionFactory.createSession();
final FlowFile flowFile = session.get();
final Map<String, String> attributes = flowFile.getAttributes();
try {
webSocketClientService.connect(endpointId, attributes);
} finally {
session.remove(flowFile);
session.commitAsync();
}
} else {
webSocketClientService.connect(endpointId);
}
}
}
protected void registerProcessorToService(final ProcessContext context, final WebSocketFunction afterRegistration) throws IOException, WebSocketConfigurationException {
webSocketService = getWebSocketService(context);
endpointId = getEndpointId(context);
webSocketService.registerProcessor(endpointId, this);
afterRegistration.execute(webSocketService);
}
protected abstract WebSocketService getWebSocketService(final ProcessContext context);
protected abstract String getEndpointId(final ProcessContext context);
protected boolean isProcessorRegisteredToService() {
return webSocketService != null
&& !StringUtils.isEmpty(endpointId)
&& webSocketService.isProcessorRegistered(endpointId, this);
}
@OnStopped
public void onStopped(final ProcessContext context) {
deregister();
}
private void deregister() {
if (webSocketService == null) {
return;
}
try {
// Deregister processor, so that it won't receive messages anymore.
webSocketService.deregisterProcessor(endpointId, this);
webSocketService = null;
} catch (WebSocketConfigurationException e) {
logger.warn("Failed to deregister processor {} due to: {}", new Object[]{this, e}, e);
}
}
@Override
public final void onTrigger(final ProcessContext context, final ProcessSessionFactory sessionFactory) {
if (processSessionFactory == null) {
processSessionFactory = sessionFactory;
}
if (!isProcessorRegisteredToService()) {
try {
registerProcessorToService(context, webSocketService -> onWebSocketServiceReady(webSocketService, context));
} catch (IOException | WebSocketConfigurationException e) {
// Deregister processor if it failed so that it can retry next onTrigger.
deregister();
context.yield();
throw new ProcessException("Failed to register processor to WebSocket service due to: " + e, e);
}
} else {
try {
onWebSocketServiceReady(webSocketService, context);
} catch (IOException e) {
deregister();
context.yield();
throw new ProcessException("Failed to renew session and connect to WebSocket service due to: " + e, e);
}
}
context.yield();//nothing really to do here since handling WebSocket messages is done at ControllerService.
}
private void enqueueMessage(final WebSocketMessage incomingMessage) {
final ProcessSession session = processSessionFactory.createSession();
try {
FlowFile messageFlowFile = session.create();
final Map<String, String> attrs = new HashMap<>();
attrs.put(ATTR_WS_CS_ID, webSocketService.getIdentifier());
final WebSocketSessionInfo sessionInfo = incomingMessage.getSessionInfo();
attrs.put(ATTR_WS_SESSION_ID, sessionInfo.getSessionId());
attrs.put(ATTR_WS_ENDPOINT_ID, endpointId);
attrs.put(ATTR_WS_LOCAL_ADDRESS, sessionInfo.getLocalAddress().toString());
attrs.put(ATTR_WS_REMOTE_ADDRESS, sessionInfo.getRemoteAddress().toString());
final WebSocketMessage.Type messageType = incomingMessage.getType();
if (messageType != null) {
attrs.put(ATTR_WS_MESSAGE_TYPE, messageType.name());
}
messageFlowFile = session.putAllAttributes(messageFlowFile, attrs);
final byte[] payload = incomingMessage.getPayload();
if (payload != null) {
messageFlowFile = session.write(messageFlowFile, out ->
out.write(payload, incomingMessage.getOffset(), incomingMessage.getLength())
);
}
session.getProvenanceReporter().receive(messageFlowFile, getTransitUri(sessionInfo));
if (incomingMessage instanceof WebSocketConnectedMessage) {
session.transfer(messageFlowFile, REL_CONNECTED);
} else {
switch (Objects.requireNonNull(messageType)) {
case TEXT:
session.transfer(messageFlowFile, REL_MESSAGE_TEXT);
break;
case BINARY:
session.transfer(messageFlowFile, REL_MESSAGE_BINARY);
break;
}
}
session.commitAsync();
} catch (Exception e) {
logger.error("Unable to fully process input due to " + e, e);
session.rollback();
}
}
protected abstract String getTransitUri(final WebSocketSessionInfo sessionInfo);
}