/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.nifi.processors.websocket;

import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.annotation.behavior.TriggerSerially;
import org.apache.nifi.annotation.lifecycle.OnStopped;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.AbstractSessionFactoryProcessor;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.ProcessSessionFactory;
import org.apache.nifi.processor.ProcessorInitializationContext;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.websocket.BinaryMessageConsumer;
import org.apache.nifi.websocket.ConnectedListener;
import org.apache.nifi.websocket.TextMessageConsumer;
import org.apache.nifi.websocket.WebSocketClientService;
import org.apache.nifi.websocket.WebSocketConfigurationException;
import org.apache.nifi.websocket.WebSocketConnectedMessage;
import org.apache.nifi.websocket.WebSocketMessage;
import org.apache.nifi.websocket.WebSocketService;
import org.apache.nifi.websocket.WebSocketSessionInfo;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_ENDPOINT_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_LOCAL_ADDRESS;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_MESSAGE_TYPE;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_REMOTE_ADDRESS;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_SESSION_ID;

/**
 * Base class for processors that bridge a {@link WebSocketService} and the flow.
 *
 * <p>The processor registers itself with the service from {@link #onTrigger} and is then called back
 * by the service through {@link ConnectedListener}, {@link TextMessageConsumer} and
 * {@link BinaryMessageConsumer}. Each callback is turned into a FlowFile that is routed to
 * {@link #REL_CONNECTED}, {@link #REL_MESSAGE_TEXT} or {@link #REL_MESSAGE_BINARY}.</p>
 */
@TriggerSerially
public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionFactoryProcessor implements ConnectedListener, TextMessageConsumer, BinaryMessageConsumer {

    /** Component logger, captured in {@link #init}. */
    protected volatile ComponentLog logger;
    /** Session factory captured on the first {@link #onTrigger} call; used to create sessions in the callbacks. */
    protected volatile ProcessSessionFactory processSessionFactory;

    /** Service this processor is currently registered with, or {@code null} when not registered. */
    protected WebSocketService webSocketService;
    /** Endpoint id used when registering with {@link #webSocketService}. */
    protected String endpointId;

    public static final Relationship REL_CONNECTED = new Relationship.Builder()
            .name("connected")
            .description("The WebSocket session is established")
            .build();

    public static final Relationship REL_MESSAGE_TEXT = new Relationship.Builder()
            .name("text message")
            .description("The WebSocket text message output")
            .build();

    public static final Relationship REL_MESSAGE_BINARY = new Relationship.Builder()
            .name("binary message")
            .description("The WebSocket binary message output")
            .build();

    /**
     * Builds the set of relationships declared by this base class.
     *
     * @return a new mutable set containing {@link #REL_CONNECTED}, {@link #REL_MESSAGE_TEXT}
     *         and {@link #REL_MESSAGE_BINARY}
     */
    static Set<Relationship> getAbstractRelationships() {
        final Set<Relationship> relationships = new HashSet<>();
        relationships.add(REL_CONNECTED);
        relationships.add(REL_MESSAGE_TEXT);
        relationships.add(REL_MESSAGE_BINARY);
        return relationships;
    }

    /**
     * Captures the component logger so that callbacks can log without a process context.
     *
     * @param context the initialization context (not used here)
     */
    @Override
    protected void init(final ProcessorInitializationContext context) {
        logger = getLogger();
    }

    /**
     * Callback executed by {@link #registerProcessorToService} once the processor has been registered.
     */
    @FunctionalInterface
    public interface WebSocketFunction {
        void execute(final WebSocketService webSocketService) throws IOException, WebSocketConfigurationException;
    }

    /**
     * Emits a FlowFile to {@link #REL_CONNECTED} for a newly established session.
     *
     * @param sessionInfo the session that was connected; its transit URI is set as a side effect
     */
    @Override
    public void connected(WebSocketSessionInfo sessionInfo) {
        final WebSocketMessage message = new WebSocketConnectedMessage(sessionInfo);
        sessionInfo.setTransitUri(getTransitUri(sessionInfo));
        enqueueMessage(message);
    }

    /**
     * Emits a FlowFile carrying a received text message.
     *
     * @param sessionInfo the session the message arrived on; its transit URI is set as a side effect
     * @param messageStr  the text payload
     */
    @Override
    public void consume(WebSocketSessionInfo sessionInfo, String messageStr) {
        final WebSocketMessage message = new WebSocketMessage(sessionInfo);
        sessionInfo.setTransitUri(getTransitUri(sessionInfo));
        message.setPayload(messageStr);
        enqueueMessage(message);
    }

    /**
     * Emits a FlowFile carrying a received binary message.
     *
     * @param sessionInfo the session the message arrived on; its transit URI is set as a side effect
     * @param payload     buffer holding the message bytes
     * @param offset      offset of the first message byte within {@code payload}
     * @param length      number of message bytes
     */
    @Override
    public void consume(WebSocketSessionInfo sessionInfo, byte[] payload, int offset, int length) {
        final WebSocketMessage message = new WebSocketMessage(sessionInfo);
        sessionInfo.setTransitUri(getTransitUri(sessionInfo));
        message.setPayload(payload, offset, length);
        enqueueMessage(message);
    }

    /**
     * Connects to the remote endpoint when the given service is a {@link WebSocketClientService}.
     *
     * <p>{@code @OnScheduled} can not report error messages well on bulletin since it's an async method,
     * so this is invoked from {@link #onTrigger} instead. For any other service type this method does
     * nothing, because a WebSocket server is already started at WebSocketServerService.</p>
     *
     * <p>When the processor has an incoming connection, one FlowFile is pulled, its attributes are passed
     * to {@code connect}, and the FlowFile is removed whether or not the connect succeeds. If no FlowFile
     * is available the session is rolled back and no connection attempt is made.</p>
     *
     * @param webSocketService the service the processor is registered with
     * @param context          the process context, used to check for an incoming connection
     * @throws IOException if the client service fails to connect
     */
    public void onWebSocketServiceReady(final WebSocketService webSocketService, final ProcessContext context) throws IOException {
        if (!(webSocketService instanceof WebSocketClientService)) {
            return;
        }

        final WebSocketClientService webSocketClientService = (WebSocketClientService) webSocketService;
        if (!context.hasIncomingConnection()) {
            webSocketClientService.connect(endpointId);
            return;
        }

        final ProcessSession session = processSessionFactory.createSession();
        final FlowFile flowFile = session.get();
        if (flowFile == null) {
            // Nothing is queued on the incoming connection: there are no attributes to connect with.
            // Close the unused session instead of dereferencing a null FlowFile.
            logger.debug("No FlowFile available on the incoming connection; skipping WebSocket connect");
            session.rollback();
            return;
        }

        try {
            webSocketClientService.connect(endpointId, flowFile.getAttributes());
        } finally {
            // The FlowFile only carries connection attributes; drop it regardless of the outcome.
            session.remove(flowFile);
            session.commitAsync();
        }
    }

    /**
     * Resolves the service and endpoint id, registers this processor, then runs the callback.
     *
     * @param context           the process context used to resolve the service and endpoint id
     * @param afterRegistration callback executed with the resolved service after registration
     * @throws IOException                     if the callback fails with an I/O error
     * @throws WebSocketConfigurationException if registration or the callback rejects the configuration
     */
    protected void registerProcessorToService(final ProcessContext context, final WebSocketFunction afterRegistration) throws IOException, WebSocketConfigurationException {
        webSocketService = getWebSocketService(context);
        endpointId = getEndpointId(context);
        webSocketService.registerProcessor(endpointId, this);

        afterRegistration.execute(webSocketService);
    }

    /**
     * @param context the process context
     * @return the WebSocket service this processor should register with
     */
    protected abstract WebSocketService getWebSocketService(final ProcessContext context);

    /**
     * @param context the process context
     * @return the endpoint id to register this processor under
     */
    protected abstract String getEndpointId(final ProcessContext context);

    /**
     * @return {@code true} if a service and a non-empty endpoint id are set and the service reports
     *         this processor as registered for that endpoint
     */
    protected boolean isProcessorRegisteredToService() {
        return webSocketService != null
                && !StringUtils.isEmpty(endpointId)
                && webSocketService.isProcessorRegistered(endpointId, this);
    }

    /**
     * Deregisters this processor from the service when the processor is stopped.
     *
     * @param context the process context (not used here)
     */
    @OnStopped
    public void onStopped(final ProcessContext context) {
        deregister();
    }

    /**
     * Deregisters this processor from the current service, if any.
     *
     * <p>The service reference is cleared only when deregistration succeeds; a failure is logged
     * as a warning and the reference is kept.</p>
     */
    private void deregister() {
        if (webSocketService == null) {
            return;
        }

        try {
            // Deregister processor, so that it won't receive messages anymore.
            webSocketService.deregisterProcessor(endpointId, this);
            webSocketService = null;
        } catch (WebSocketConfigurationException e) {
            logger.warn("Failed to deregister processor {} due to: {}", new Object[]{this, e}, e);
        }
    }

    /**
     * Ensures the processor is registered with the service and connected, then yields.
     *
     * <p>On the first call the session factory is remembered for use by the message callbacks.
     * If the processor is not yet registered it is registered and {@link #onWebSocketServiceReady}
     * is run; otherwise {@link #onWebSocketServiceReady} is run directly. On failure the processor
     * is deregistered so the next trigger can retry.</p>
     *
     * @param context        the process context
     * @param sessionFactory factory for creating process sessions
     * @throws ProcessException if registering or connecting fails
     */
    @Override
    public final void onTrigger(final ProcessContext context, final ProcessSessionFactory sessionFactory) {
        if (processSessionFactory == null) {
            processSessionFactory = sessionFactory;
        }

        if (!isProcessorRegisteredToService()) {
            try {
                registerProcessorToService(context, webSocketService -> onWebSocketServiceReady(webSocketService, context));
            } catch (IOException | WebSocketConfigurationException e) {
                // Deregister processor if it failed so that it can retry next onTrigger.
                deregister();
                context.yield();
                throw new ProcessException("Failed to register processor to WebSocket service due to: " + e, e);
            }

        } else {
            try {
                onWebSocketServiceReady(webSocketService, context);
            } catch (IOException e) {
                deregister();
                context.yield();
                throw new ProcessException("Failed to renew session and connect to WebSocket service due to: " + e, e);
            }
        }

        context.yield();//nothing really to do here since handling WebSocket messages is done at ControllerService.
    }


    /**
     * Converts an incoming message into a FlowFile and routes it to the matching relationship.
     *
     * <p>The FlowFile receives the WebSocket attributes (service id, session id, endpoint id, local and
     * remote address, and message type when present) and, if the message has a payload, the payload
     * bytes as content. A provenance RECEIVE event is reported with the session's transit URI.
     * Any exception is logged and the session is rolled back.</p>
     *
     * @param incomingMessage the connected event or text/binary message to enqueue
     */
    private void enqueueMessage(final WebSocketMessage incomingMessage) {
        final ProcessSession session = processSessionFactory.createSession();
        try {
            FlowFile messageFlowFile = session.create();

            final Map<String, String> attrs = new HashMap<>();
            attrs.put(ATTR_WS_CS_ID, webSocketService.getIdentifier());
            final WebSocketSessionInfo sessionInfo = incomingMessage.getSessionInfo();
            attrs.put(ATTR_WS_SESSION_ID, sessionInfo.getSessionId());
            attrs.put(ATTR_WS_ENDPOINT_ID, endpointId);
            attrs.put(ATTR_WS_LOCAL_ADDRESS, sessionInfo.getLocalAddress().toString());
            attrs.put(ATTR_WS_REMOTE_ADDRESS, sessionInfo.getRemoteAddress().toString());
            final WebSocketMessage.Type messageType = incomingMessage.getType();
            if (messageType != null) {
                attrs.put(ATTR_WS_MESSAGE_TYPE, messageType.name());
            }

            messageFlowFile = session.putAllAttributes(messageFlowFile, attrs);

            final byte[] payload = incomingMessage.getPayload();
            if (payload != null) {
                messageFlowFile = session.write(messageFlowFile, out ->
                        out.write(payload, incomingMessage.getOffset(), incomingMessage.getLength())
                );
            }

            session.getProvenanceReporter().receive(messageFlowFile, getTransitUri(sessionInfo));

            if (incomingMessage instanceof WebSocketConnectedMessage) {
                session.transfer(messageFlowFile, REL_CONNECTED);
            } else {
                // A non-connected message must carry a type to be routed.
                switch (Objects.requireNonNull(messageType)) {
                    case TEXT:
                        session.transfer(messageFlowFile, REL_MESSAGE_TEXT);
                        break;
                    case BINARY:
                        session.transfer(messageFlowFile, REL_MESSAGE_BINARY);
                        break;
                }
            }
            session.commitAsync();

        } catch (Exception e) {
            logger.error("Unable to fully process input due to " + e, e);
            session.rollback();
        }
    }

    /**
     * @param sessionInfo the WebSocket session
     * @return the transit URI to record for the given session
     */
    protected abstract String getTransitUri(final WebSocketSessionInfo sessionInfo);

}
