/*
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 */
package org.apache.qpid.test.utils;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.Collection;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A basic implementation of a TCP traffic forwarder between ports.
 * It is intended for use in tests.
 * <p>
 * Connections accepted on the local port are forwarded to {@code remoteHost:remotePort}. Each accepted client
 * gets a pair of forwarding tasks (client-to-server and server-to-client), which run on a fixed-size thread pool
 * shared with the single accepting task.
 */
public class TCPTunneler implements AutoCloseable
{
    private static final Logger LOGGER = LoggerFactory.getLogger(TCPTunneler.class);

    private final TCPWorker _tcpWorker;
    private final ExecutorService _executor;

    /**
     * Creates a tunneler. No socket is opened until {@link #start()} is called.
     *
     * @param localPort                 port to listen on for client connections
     * @param remoteHost                host that client traffic is forwarded to
     * @param remotePort                port on {@code remoteHost} that client traffic is forwarded to
     * @param numberOfConcurrentClients number of clients that can be forwarded concurrently; the thread pool is
     *                                  sized to two forwarding threads per client plus one accepting thread
     */
    public TCPTunneler(final int localPort, final String remoteHost,
                       final int remotePort,
                       final int numberOfConcurrentClients)
    {
        _executor = Executors.newFixedThreadPool(numberOfConcurrentClients * 2 + 1);
        _tcpWorker = new TCPWorker(localPort, remoteHost, remotePort, _executor);
    }

    /**
     * Binds the local port and starts accepting client connections.
     *
     * @throws IOException never thrown by the current implementation; failures to bind the local port or to
     *                     start the accepting task are reported as {@link RuntimeException}
     */
    public void start() throws IOException
    {
        _tcpWorker.start();
    }

    /**
     * Stops forwarding data from the given client to the server. Data the client sends afterwards is read
     * and discarded; server-to-client forwarding is unaffected.
     *
     * @param clientAddress address identifying the client's tunnel
     * @throws IllegalArgumentException if no tunnel exists for {@code clientAddress}
     */
    public void stopClientToServerForwarding(final InetSocketAddress clientAddress)
    {
        _tcpWorker.stopClientToServerForwarding(clientAddress);
    }

    /**
     * Stops accepting connections, closes all tunnels and the listening socket, then shuts down the executor.
     */
    public void stop()
    {
        try
        {
            _tcpWorker.stop();
        }
        finally
        {
            _executor.shutdown();
        }
    }

    /**
     * Registers a listener. The listener is immediately told about every client that is already connected.
     *
     * @param listener listener to add
     */
    public void addClientListener(TunnelListener listener)
    {
        _tcpWorker.addClientListener(listener);
    }

    /**
     * Unregisters a previously added listener.
     *
     * @param listener listener to remove
     */
    public void removeClientListener(TunnelListener listener)
    {
        _tcpWorker.removeClientListener(listener);
    }

    /**
     * Closes the tunnel of the given client, if one exists. A {@code null} address is ignored.
     *
     * @param address address identifying the client's tunnel
     */
    public void disconnect(InetSocketAddress address)
    {
        LOGGER.info("Disconnecting {}", address);
        if (address != null)
        {
            _tcpWorker.disconnect(address);
        }
    }

    /**
     * Equivalent to {@link #stop()}.
     */
    @Override
    public void close() throws Exception
    {
        stop();
    }

    /**
     * Callback notified when tunnelled clients connect and disconnect.
     */
    public interface TunnelListener
    {
        void clientConnected(InetSocketAddress clientAddress);

        void clientDisconnected(InetSocketAddress clientAddress);
    }

    /**
     * Owns the listening socket. It accepts clients and creates a {@link SocketTunnel} for each of them.
     */
    private static class TCPWorker implements Runnable
    {
        private final String _remoteHost;
        private final int _remotePort;
        private final int _localPort;
        private final String _remoteHostPort;
        private final AtomicBoolean _closed;
        private final Collection<SocketTunnel> _tunnels;
        private final Collection<TunnelListener> _tunnelListeners;
        private final TunnelListener _notifyingListener;
        private final ExecutorService _executor;
        private volatile ServerSocket _serverSocket;
        // volatile: written by start() but also read by stop(), which may run on another thread
        private volatile int _actualLocalPort;

        public TCPWorker(final int localPort,
                         final String remoteHost,
                         final int remotePort,
                         final ExecutorService executor)
        {
            _closed = new AtomicBoolean();
            _remoteHost = remoteHost;
            _remotePort = remotePort;
            _localPort = localPort;
            _remoteHostPort = _remoteHost + ":" + _remotePort;
            _executor = executor;
            _tunnels = new CopyOnWriteArrayList<>();
            _tunnelListeners = new CopyOnWriteArrayList<>();
            // Fans tunnel events out to the registered listeners and forgets tunnels once they disconnect
            _notifyingListener = new TunnelListener()
            {
                @Override
                public void clientConnected(final InetSocketAddress clientAddress)
                {
                    notifyClientConnected(clientAddress);
                }

                @Override
                public void clientDisconnected(final InetSocketAddress clientAddress)
                {
                    try
                    {
                        notifyClientDisconnected(clientAddress);
                    }
                    finally
                    {
                        removeTunnel(clientAddress);
                    }
                }
            };
        }

        /**
         * Accepting loop. It runs until the worker is stopped or accepting fails.
         */
        @Override
        public void run()
        {
            final Thread currentThread = Thread.currentThread();
            final String threadName = currentThread.getName();
            try
            {
                currentThread.setName("TCPTunnelerAcceptingThread");
                while (!_closed.get())
                {
                    // stop() may null the volatile field at any moment, so read it exactly once per iteration
                    final ServerSocket serverSocket = _serverSocket;
                    if (serverSocket == null)
                    {
                        break;
                    }
                    final Socket acceptedSocket = serverSocket.accept();
                    LOGGER.debug("Client opened socket {}", acceptedSocket);

                    if (_closed.get())
                    {
                        // stop() ran while accept() was returning; a tunnel created now would never be closed
                        SocketTunnel.closeSocket(acceptedSocket);
                        break;
                    }
                    createTunnel(acceptedSocket);
                }
            }
            catch (IOException e)
            {
                if (!_closed.get())
                {
                    LOGGER.error("Exception in accepting thread", e);
                }
            }
            finally
            {
                // _closed is intentionally not set here: a later stop() must still be able to close open tunnels
                closeServerSocket();
                currentThread.setName(threadName);
            }
        }

        /**
         * Binds the listening socket and submits the accepting loop to the executor.
         *
         * @throws RuntimeException if the port cannot be bound or the accepting task cannot be started
         */
        public void start()
        {
            _actualLocalPort = _localPort;
            ServerSocket serverSocket = null;
            try
            {
                serverSocket = new ServerSocket();
                // SO_REUSEADDR only has defined behaviour when it is enabled before the socket is bound
                serverSocket.setReuseAddress(true);
                serverSocket.bind(new InetSocketAddress(_localPort));
                _actualLocalPort = serverSocket.getLocalPort();
            }
            catch (IOException e)
            {
                closeServerSocket(serverSocket);
                throw new RuntimeException("Cannot start TCPTunneler on port " + _actualLocalPort, e);
            }

            _serverSocket = serverSocket;
            LOGGER.info("Starting TCPTunneler forwarding from port {} to {}", _actualLocalPort, _remoteHostPort);

            try
            {
                _executor.execute(this);
            }
            catch (RuntimeException e)
            {
                closeServerSocket();
                throw new RuntimeException("Cannot start acceptor thread for TCPTunneler on port " + _actualLocalPort,
                                           e);
            }
        }

        /**
         * Closes all tunnels and the listening socket. Only the first call has any effect.
         */
        public void stop()
        {
            if (_closed.compareAndSet(false, true))
            {
                LOGGER.info("Stopping TCPTunneler forwarding from port {} to {}",
                            _actualLocalPort,
                            _remoteHostPort);
                try
                {
                    for (SocketTunnel tunnel : _tunnels)
                    {
                        tunnel.close();
                    }
                }
                finally
                {
                    closeServerSocket();
                }

                LOGGER.info("TCPTunneler forwarding from port {} to {} is stopped",
                            _actualLocalPort,
                            _remoteHostPort);
            }
        }

        public void addClientListener(TunnelListener listener)
        {
            _tunnelListeners.add(listener);
            // Replay the clients that connected before this listener was registered
            for (SocketTunnel socketTunnel : _tunnels)
            {
                try
                {
                    listener.clientConnected(socketTunnel.getClientAddress());
                }
                catch (Exception e)
                {
                    LOGGER.warn("Exception on notifying client listener about connected client", e);
                }
            }
        }

        public void removeClientListener(TunnelListener listener)
        {
            _tunnelListeners.remove(listener);
        }

        public void disconnect(final InetSocketAddress address)
        {
            final SocketTunnel client = removeTunnel(address);
            if (client != null && !client.isClosed())
            {
                client.close();
                LOGGER.info("Tunnel for {} is disconnected", address);
            }
            else
            {
                LOGGER.info("Tunnel for {} not found", address);
            }
        }

        /**
         * Connects to the remote endpoint and starts forwarding between it and {@code localSocket}.
         * If this fails, both sockets are closed and the tunnel is not left registered.
         */
        private void createTunnel(final Socket localSocket)
        {
            Socket remoteSocket = null;
            SocketTunnel tunnel = null;
            try
            {
                LOGGER.debug("Opening socket to {} for {}", _remoteHostPort, localSocket);
                remoteSocket = new Socket(_remoteHost, _remotePort);
                LOGGER.debug("Opened socket to {} for {}", remoteSocket, localSocket);
                tunnel = new SocketTunnel(localSocket, remoteSocket, _notifyingListener);
                LOGGER.debug("Socket tunnel is created from {} to {}", localSocket, remoteSocket);
                _tunnels.add(tunnel);
                tunnel.start(_executor);
            }
            catch (Exception e)
            {
                LOGGER.error("Cannot forward i/o traffic between {} and {}", localSocket, _remoteHostPort, e);
                if (tunnel != null)
                {
                    // the tunnel may have been registered before start() failed
                    _tunnels.remove(tunnel);
                }
                SocketTunnel.closeSocket(localSocket);
                SocketTunnel.closeSocket(remoteSocket);
            }
        }

        private void notifyClientConnected(final InetSocketAddress clientAddress)
        {
            for (TunnelListener listener : _tunnelListeners)
            {
                try
                {
                    listener.clientConnected(clientAddress);
                }
                catch (Exception e)
                {
                    LOGGER.warn("Exception on notifying client listener about connected client", e);
                }
            }
        }

        private void notifyClientDisconnected(final InetSocketAddress clientAddress)
        {
            for (TunnelListener listener : _tunnelListeners)
            {
                try
                {
                    listener.clientDisconnected(clientAddress);
                }
                catch (Exception e)
                {
                    LOGGER.warn("Exception on notifying client listener about disconnected client", e);
                }
            }
        }

        public void stopClientToServerForwarding(final InetSocketAddress clientAddress)
        {
            SocketTunnel target = null;
            for (SocketTunnel tunnel : _tunnels)
            {
                if (tunnel.isClientAddress(clientAddress))
                {
                    target = tunnel;
                    break;
                }
            }
            if (target == null)
            {
                throw new IllegalArgumentException("Could not find tunnel for address " + clientAddress);
            }
            LOGGER.debug("Stopping forwarding from client {} to server", clientAddress);
            target.stopClientToServerForwarding();
        }

        /**
         * Detaches and closes the listening socket. The field is cleared before closing, so concurrent callers
         * (stop() and the accepting thread) never dereference a field another thread has nulled.
         */
        private void closeServerSocket()
        {
            final ServerSocket serverSocket = _serverSocket;
            _serverSocket = null;
            closeServerSocket(serverSocket);
        }

        private static void closeServerSocket(final ServerSocket serverSocket)
        {
            if (serverSocket != null)
            {
                try
                {
                    serverSocket.close();
                }
                catch (IOException e)
                {
                    LOGGER.warn("Exception on closing of accepting socket", e);
                }
            }
        }

        /**
         * Unregisters the tunnel for the given client.
         *
         * @return the removed tunnel, or {@code null} if none matched
         */
        private SocketTunnel removeTunnel(final InetSocketAddress clientAddress)
        {
            SocketTunnel client = null;
            for (SocketTunnel c : _tunnels)
            {
                if (c.isClientAddress(clientAddress))
                {
                    client = c;
                    break;
                }
            }
            if (client != null)
            {
                _tunnels.remove(client);
            }
            return client;
        }

    }

    /**
     * A bidirectional forwarding between one client socket and one server socket. When either direction ends,
     * the whole tunnel is closed and the listener is told the client disconnected (at most once).
     */
    private static class SocketTunnel
    {
        private final Socket _clientSocket;
        private final Socket _serverSocket;
        private final TunnelListener _tunnelListener;
        private final AtomicBoolean _closed;
        private final AutoClosingStreamForwarder _clientToServer;
        private final AutoClosingStreamForwarder _serverToClient;
        private final InetSocketAddress _clientSocketAddress;

        public SocketTunnel(final Socket clientSocket,
                            final Socket serverSocket,
                            final TunnelListener tunnelListener) throws IOException
        {
            _clientSocket = clientSocket;
            // NOTE(review): getHostName() may do a reverse lookup, and this constructor re-resolves the name, so
            // the stored address can differ from the socket's raw remote address; callers must match accordingly.
            _clientSocketAddress =
                    new InetSocketAddress(clientSocket.getInetAddress().getHostName(), _clientSocket.getPort());
            _serverSocket = serverSocket;
            _closed = new AtomicBoolean();
            _tunnelListener = tunnelListener;
            _clientSocket.setKeepAlive(true);
            _serverSocket.setKeepAlive(true);
            _clientToServer = new AutoClosingStreamForwarder(new StreamForwarder(_clientSocket, _serverSocket));
            _serverToClient = new AutoClosingStreamForwarder(new StreamForwarder(_serverSocket, _clientSocket));
        }

        public void close()
        {
            if (_closed.compareAndSet(false, true))
            {
                try
                {
                    closeSocket(_serverSocket);
                    closeSocket(_clientSocket);
                }
                finally
                {
                    _tunnelListener.clientDisconnected(getClientAddress());
                }
            }
        }

        /**
         * Submits both forwarding directions to {@code executor}, then notifies the listener of the connection.
         */
        public void start(Executor executor) throws IOException
        {
            executor.execute(_clientToServer);
            executor.execute(_serverToClient);
            _tunnelListener.clientConnected(getClientAddress());
        }

        public void stopClientToServerForwarding()
        {
            _clientToServer.stopForwarding();
        }

        public boolean isClosed()
        {
            return _closed.get();
        }

        public boolean isClientAddress(final InetSocketAddress clientAddress)
        {
            return getClientAddress().equals(clientAddress);
        }

        public InetSocketAddress getClientAddress()
        {
            return _clientSocketAddress;
        }

        /**
         * Closes {@code socket}, logging rather than propagating any failure. A {@code null} socket is ignored.
         */
        private static void closeSocket(Socket socket)
        {
            if (socket != null)
            {
                try
                {
                    socket.close();
                }
                catch (IOException e)
                {
                    LOGGER.warn("Exception on closing of socket {}", socket, e);
                }
            }
        }

        /**
         * Runs a {@link StreamForwarder} under its own thread name and closes the enclosing tunnel when
         * forwarding ends.
         */
        private class AutoClosingStreamForwarder implements Runnable
        {
            private final StreamForwarder _streamForwarder;

            public AutoClosingStreamForwarder(StreamForwarder streamForwarder)
            {
                _streamForwarder = streamForwarder;
            }

            @Override
            public void run()
            {
                final Thread currentThread = Thread.currentThread();
                final String originalThreadName = currentThread.getName();
                try
                {
                    currentThread.setName(_streamForwarder.getName());
                    _streamForwarder.run();
                }
                finally
                {
                    // SocketTunnel.this.close(): one direction ending tears down the whole tunnel
                    close();
                    currentThread.setName(originalThreadName);
                }
            }

            public void stopForwarding()
            {
                _streamForwarder.stopForwarding();
            }
        }
    }

    /**
     * Copies bytes from one socket's input stream to another socket's output stream until end-of-stream or an
     * I/O error, then closes both streams. Once stopped, bytes are still read but discarded.
     */
    private static class StreamForwarder implements Runnable
    {
        private static final int BUFFER_SIZE = 4096;

        private final InputStream _inputStream;
        private final OutputStream _outputStream;
        private final String _name;
        private final AtomicBoolean _stopForwarding = new AtomicBoolean();

        public StreamForwarder(Socket input, Socket output) throws IOException
        {
            _inputStream = input.getInputStream();
            _outputStream = output.getOutputStream();
            _name = "Forwarder-" + input.getLocalSocketAddress() + "->" + output.getRemoteSocketAddress();
        }

        @Override
        public void run()
        {
            final byte[] buffer = new byte[BUFFER_SIZE];
            int bytesRead;
            try
            {
                while ((bytesRead = _inputStream.read(buffer)) != -1)
                {
                    if (_stopForwarding.get())
                    {
                        LOGGER.debug("Discarded {} byte(s)", bytesRead);
                    }
                    else
                    {
                        _outputStream.write(buffer, 0, bytesRead);
                        _outputStream.flush();
                        LOGGER.debug("Forwarded {} byte(s)", bytesRead);
                    }
                }
            }
            catch (IOException e)
            {
                LOGGER.warn("Exception on forwarding data for {}: {}", _name, e.getMessage());
            }
            finally
            {
                try
                {
                    _inputStream.close();
                }
                catch (IOException ignored)
                {
                    // best effort: the stream is being discarded anyway
                }

                try
                {
                    _outputStream.close();
                }
                catch (IOException ignored)
                {
                    // best effort: the stream is being discarded anyway
                }
            }
        }

        public String getName()
        {
            return _name;
        }

        public void stopForwarding()
        {
            _stopForwarding.set(true);
        }

    }
}
