/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.nifi.websocket.jetty;

import org.apache.nifi.processor.Processor;
import org.apache.nifi.websocket.BinaryMessageConsumer;
import org.apache.nifi.websocket.ConnectedListener;
import org.apache.nifi.websocket.TextMessageConsumer;
import org.apache.nifi.websocket.WebSocketClientService;
import org.apache.nifi.websocket.WebSocketServerService;
import org.apache.nifi.websocket.WebSocketSessionInfo;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;

import java.net.ServerSocket;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Locale;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;


/**
 * Integration test that starts a {@link JettyWebSocketServer} and a {@link JettyWebSocketClient}
 * on the local machine and verifies that connected events, text messages and binary messages
 * are delivered in both directions.
 *
 * <p>Subclasses can change the transport by overriding {@link #isSecure()} and can adjust the
 * service configuration through {@link #customizeServer()} and {@link #customizeClient()}.</p>
 */
public class ITJettyWebSocketCommunication {

    /** Maximum time, in seconds, to wait for each expected asynchronous event. */
    private static final long EVENT_TIMEOUT_SECONDS = 5;

    /** Endpoint id under which the client-side processor is registered. */
    private static final String CLIENT_ID = "client1";

    /** Payload sent by the client (as text and as binary) and expected by the server. */
    private static final String MESSAGE_FROM_CLIENT = "Message from client.";

    /** Payload sent by the server (as text and as binary) and expected by the client. */
    private static final String MESSAGE_FROM_SERVER = "Message from server.";

    protected int serverPort;
    protected String serverPath = "/test";
    protected WebSocketServerService serverService;
    protected ControllerServiceTestContext serverServiceContext;
    protected WebSocketClientService clientService;
    protected ControllerServiceTestContext clientServiceContext;

    /**
     * Indicates whether the test uses a secure connection. The value selects the
     * {@code wss} or {@code ws} URI scheme for the client and is the value every received
     * {@link WebSocketSessionInfo#isSecure()} is expected to report.
     *
     * @return {@code false} in this base class
     */
    protected boolean isSecure() {
        return false;
    }

    /**
     * Starts the server first and then the client, so that the client URI can be built from
     * the port chosen for the server.
     *
     * @throws Exception if either service fails to initialize or start
     */
    @Before
    public void setup() throws Exception {
        setupServer();

        setupClient();
    }

    /**
     * Picks a free port, configures a {@link JettyWebSocketServer} with basic authentication
     * and starts it.
     *
     * @throws Exception if no port can be obtained or the server fails to start
     */
    private void setupServer() throws Exception {
        // Find an open port. The probe socket is closed before the server binds to the port,
        // so the port is only known to have been free at the time of the probe.
        try (final ServerSocket serverSocket = new ServerSocket(0)) {
            serverPort = serverSocket.getLocalPort();
        }

        // Fail with a descriptive message rather than a bare NullPointerException
        // when the credentials file is not on the classpath.
        final URL usersProperties = getClass().getResource("/users.properties");
        assertNotNull("Test resource /users.properties should be available on the classpath.", usersProperties);

        serverService = new JettyWebSocketServer();
        serverServiceContext = new ControllerServiceTestContext(serverService, "JettyWebSocketServer1");
        serverServiceContext.setCustomValue(JettyWebSocketServer.LISTEN_PORT, String.valueOf(serverPort));
        serverServiceContext.setCustomValue(JettyWebSocketServer.BASIC_AUTH, "true");
        serverServiceContext.setCustomValue(JettyWebSocketServer.USERS_PROPERTIES_FILE, usersProperties.getPath());
        serverServiceContext.setCustomValue(JettyWebSocketServer.AUTH_ROLES, "user,test");

        customizeServer();

        serverService.initialize(serverServiceContext.getInitializationContext());
        serverService.startServer(serverServiceContext.getConfigurationContext());
    }

    /**
     * Hook invoked after the default server properties have been set and before the server
     * is initialized and started. Does nothing in this base class.
     */
    protected void customizeServer() {
    }

    /**
     * Configures a {@link JettyWebSocketClient} pointing at the local server and starts it.
     *
     * @throws Exception if the client fails to initialize or start
     */
    private void setupClient() throws Exception {
        final String scheme = isSecure() ? "wss" : "ws";

        clientService = new JettyWebSocketClient();
        clientServiceContext = new ControllerServiceTestContext(clientService, "JettyWebSocketClient1");
        clientServiceContext.setCustomValue(JettyWebSocketClient.WS_URI, scheme + "://localhost:" + serverPort + serverPath);

        // NOTE(review): presumably these credentials match an entry in users.properties - confirm.
        clientServiceContext.setCustomValue(JettyWebSocketClient.USER_NAME, "user2");
        clientServiceContext.setCustomValue(JettyWebSocketClient.USER_PASSWORD, "password2");

        customizeClient();

        clientService.initialize(clientServiceContext.getInitializationContext());
        clientService.startClient(clientServiceContext.getConfigurationContext());
    }

    /**
     * Hook invoked after the default client properties have been set and before the client
     * is initialized and started. Does nothing in this base class.
     */
    protected void customizeClient() {
    }

    /**
     * Stops the client and then the server. The server is stopped even when stopping the
     * client fails, and a service that was never created (because setup failed part-way)
     * is skipped.
     *
     * @throws Exception if stopping either service fails
     */
    @After
    public void teardown() throws Exception {
        try {
            if (clientService != null) {
                clientService.stopClient();
            }
        } finally {
            if (serverService != null) {
                serverService.stopServer();
            }
        }
    }

    /** Combined processor type so a single mock can receive connected, text and binary callbacks. */
    protected interface MockWebSocketProcessor extends Processor, ConnectedListener, TextMessageConsumer, BinaryMessageConsumer {
    }

    /** Latches and session id captured for one side (client or server) of the conversation. */
    private static final class EndpointExpectations {
        final CountDownLatch connected = new CountDownLatch(1);
        final CountDownLatch textReceived = new CountDownLatch(1);
        final CountDownLatch binaryReceived = new CountDownLatch(1);
        final AtomicReference<String> sessionId = new AtomicReference<>();
    }

    /**
     * @return {@code true} when the {@code os.name} system property starts with "windows",
     *         compared case-insensitively and independently of the default locale
     */
    private boolean isWindowsEnvironment() {
        return System.getProperty("os.name").toLowerCase(Locale.ROOT).startsWith("windows");
    }

    /**
     * Verifies that a client can connect to the server and that text and binary messages
     * flow in both directions. Skipped on Windows.
     *
     * @throws Exception if any service call fails or the test is interrupted while waiting
     */
    @Test
    public void testClientServerCommunication() throws Exception {
        assumeFalse(isWindowsEnvironment());

        final EndpointExpectations server = new EndpointExpectations();
        final EndpointExpectations client = new EndpointExpectations();

        final MockWebSocketProcessor serverProcessor = mockProcessor("serverProcessor1", MESSAGE_FROM_CLIENT, server);
        serverService.registerProcessor(serverPath, serverProcessor);

        final MockWebSocketProcessor clientProcessor = mockProcessor("clientProcessor1", MESSAGE_FROM_SERVER, client);
        clientService.registerProcessor(CLIENT_ID, clientProcessor);

        clientService.connect(CLIENT_ID);

        assertConnected(client, server);
        assertMessagesExchanged(client, server);

        clientService.deregisterProcessor(CLIENT_ID, clientProcessor);
        serverService.deregisterProcessor(serverPath, serverProcessor);
    }

    /**
     * Verifies that, after the server has been restarted, running the client's session
     * maintenance allows messages to flow in both directions again using the previously
     * captured client session id. Skipped on Windows.
     *
     * @throws Exception if any service call fails or the test is interrupted while waiting
     */
    @Test
    public void testClientServerCommunicationRecovery() throws Exception {
        assumeFalse(isWindowsEnvironment());

        final EndpointExpectations server = new EndpointExpectations();
        final EndpointExpectations client = new EndpointExpectations();

        final MockWebSocketProcessor serverProcessor = mockProcessor("serverProcessor1", MESSAGE_FROM_CLIENT, server);
        serverService.registerProcessor(serverPath, serverProcessor);

        final MockWebSocketProcessor clientProcessor = mockProcessor("clientProcessor1", MESSAGE_FROM_SERVER, client);
        clientService.registerProcessor(CLIENT_ID, clientProcessor);

        clientService.connect(CLIENT_ID, Collections.emptyMap());

        assertConnected(client, server);

        // Nothing happens if maintenance is executed while sessions are alive.
        ((JettyWebSocketClient) clientService).maintainSessions();

        // Restart server.
        serverService.stopServer();
        serverService.startServer(serverServiceContext.getConfigurationContext());

        // Sessions will be recreated with the same session ids.
        ((JettyWebSocketClient) clientService).maintainSessions();

        assertMessagesExchanged(client, server);

        clientService.deregisterProcessor(CLIENT_ID, clientProcessor);
        serverService.deregisterProcessor(serverPath, serverProcessor);
    }

    /**
     * Creates a mocked processor whose callbacks validate what they receive and release the
     * corresponding latch in {@code expectations}.
     *
     * @param identifier value returned by the mock's {@code getIdentifier()}
     * @param expectedMessage payload the processor expects to receive, as text and as binary
     * @param expectations latches to release and session id reference to populate
     * @return the configured mock
     * @throws Exception declared because the stubbed callback methods may declare checked exceptions
     */
    private MockWebSocketProcessor mockProcessor(final String identifier, final String expectedMessage,
                                                 final EndpointExpectations expectations) throws Exception {
        final MockWebSocketProcessor processor = mock(MockWebSocketProcessor.class);
        doReturn(identifier).when(processor).getIdentifier();

        doAnswer(invocation -> assertConnectedEvent(expectations.connected, expectations.sessionId, invocation))
                .when(processor).connected(any(WebSocketSessionInfo.class));

        doAnswer(invocation -> assertConsumeTextMessage(expectations.textReceived, expectedMessage, invocation))
                .when(processor).consume(any(WebSocketSessionInfo.class), anyString());

        doAnswer(invocation -> assertConsumeBinaryMessage(expectations.binaryReceived, expectedMessage, invocation))
                .when(processor).consume(any(WebSocketSessionInfo.class), any(byte[].class), anyInt(), anyInt());

        return processor;
    }

    /**
     * Waits for the connected event on the client side and then on the server side.
     *
     * @param client expectations of the client-side processor
     * @param server expectations of the server-side processor
     * @throws InterruptedException if interrupted while waiting
     */
    private void assertConnected(final EndpointExpectations client, final EndpointExpectations server) throws InterruptedException {
        assertTrue("WebSocket client should be able to fire connected event.", client.connected.await(EVENT_TIMEOUT_SECONDS, TimeUnit.SECONDS));
        assertTrue("WebSocket server should be able to fire connected event.", server.connected.await(EVENT_TIMEOUT_SECONDS, TimeUnit.SECONDS));
    }

    /**
     * Sends a text and a binary message from the client to the server, waits for the server to
     * consume both, then does the same from the server to the client. Session ids are read
     * from the expectations at send time, so the most recently captured id is used.
     *
     * @param client expectations of the client-side processor
     * @param server expectations of the server-side processor
     * @throws Exception if sending fails or the thread is interrupted while waiting
     */
    private void assertMessagesExchanged(final EndpointExpectations client, final EndpointExpectations server) throws Exception {
        clientService.sendMessage(CLIENT_ID, client.sessionId.get(), sender -> sender.sendString(MESSAGE_FROM_CLIENT));
        clientService.sendMessage(CLIENT_ID, client.sessionId.get(),
                sender -> sender.sendBinary(ByteBuffer.wrap(MESSAGE_FROM_CLIENT.getBytes(StandardCharsets.UTF_8))));

        assertTrue("WebSocket server should be able to consume text message.", server.textReceived.await(EVENT_TIMEOUT_SECONDS, TimeUnit.SECONDS));
        assertTrue("WebSocket server should be able to consume binary message.", server.binaryReceived.await(EVENT_TIMEOUT_SECONDS, TimeUnit.SECONDS));

        serverService.sendMessage(serverPath, server.sessionId.get(), sender -> sender.sendString(MESSAGE_FROM_SERVER));
        serverService.sendMessage(serverPath, server.sessionId.get(),
                sender -> sender.sendBinary(ByteBuffer.wrap(MESSAGE_FROM_SERVER.getBytes(StandardCharsets.UTF_8))));

        assertTrue("WebSocket client should be able to consume text message.", client.textReceived.await(EVENT_TIMEOUT_SECONDS, TimeUnit.SECONDS));
        assertTrue("WebSocket client should be able to consume binary message.", client.binaryReceived.await(EVENT_TIMEOUT_SECONDS, TimeUnit.SECONDS));
    }

    /**
     * Mockito answer body for {@code connected}: validates the session info, records its
     * session id and releases the latch.
     *
     * @param latch latch counted down once the event has been validated
     * @param sessionIdRef receives the session id reported by the event
     * @param invocation the intercepted call; argument 0 is the {@link WebSocketSessionInfo}
     * @return always {@code null}
     */
    protected Object assertConnectedEvent(CountDownLatch latch, AtomicReference<String> sessionIdRef, InvocationOnMock invocation) {
        final WebSocketSessionInfo sessionInfo = invocation.getArgument(0);
        assertNotNull(sessionInfo.getLocalAddress());
        assertNotNull(sessionInfo.getRemoteAddress());
        assertNotNull(sessionInfo.getSessionId());
        assertEquals(isSecure(), sessionInfo.isSecure());
        sessionIdRef.set(sessionInfo.getSessionId());
        latch.countDown();
        return null;
    }

    /**
     * Mockito answer body for the text {@code consume} callback: validates the session info
     * and the received text, then releases the latch.
     *
     * @param latch latch counted down once the message has been validated
     * @param expectedMessage text the message must equal
     * @param invocation the intercepted call; argument 0 is the session info, argument 1 the text
     * @return always {@code null}
     */
    protected Object assertConsumeTextMessage(CountDownLatch latch, String expectedMessage, InvocationOnMock invocation) {
        final WebSocketSessionInfo sessionInfo = invocation.getArgument(0);
        assertNotNull(sessionInfo.getLocalAddress());
        assertNotNull(sessionInfo.getRemoteAddress());
        assertNotNull(sessionInfo.getSessionId());
        assertEquals(isSecure(), sessionInfo.isSecure());

        final String receivedMessage = invocation.getArgument(1);
        assertNotNull(receivedMessage);
        assertEquals(expectedMessage, receivedMessage);
        latch.countDown();
        return null;
    }

    /**
     * Mockito answer body for the binary {@code consume} callback: validates the session info,
     * the received bytes (decoded as UTF-8) and the reported offset and length, then releases
     * the latch.
     *
     * @param latch latch counted down once the message has been validated
     * @param expectedMessage text whose UTF-8 encoding the payload must equal
     * @param invocation the intercepted call; arguments are session info, byte array, offset, length
     * @return always {@code null}
     */
    protected Object assertConsumeBinaryMessage(CountDownLatch latch, String expectedMessage, InvocationOnMock invocation) {
        final WebSocketSessionInfo sessionInfo = invocation.getArgument(0);
        assertNotNull(sessionInfo.getLocalAddress());
        assertNotNull(sessionInfo.getRemoteAddress());
        assertNotNull(sessionInfo.getSessionId());
        assertEquals(isSecure(), sessionInfo.isSecure());

        final byte[] receivedMessage = invocation.getArgument(1);
        final byte[] expectedBinary = expectedMessage.getBytes(StandardCharsets.UTF_8);
        final int offset = invocation.getArgument(2);
        final int length = invocation.getArgument(3);
        assertNotNull(receivedMessage);
        assertEquals(expectedBinary.length, receivedMessage.length);
        assertEquals(expectedMessage, new String(receivedMessage, StandardCharsets.UTF_8));
        assertEquals(0, offset);
        assertEquals(expectedBinary.length, length);
        latch.countDown();
        return null;
    }

}
