/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "OpenSSLSocket.h"

#ifdef HAVE_OPENSSL
    #include <openssl/ssl.h>
    #include <openssl/tls1.h>
    #include <openssl/x509.h>
    #include <openssl/x509v3.h>
    #include <openssl/bio.h>
#endif

#include <decaf/net/SocketImpl.h>
#include <decaf/io/IOException.h>
#include <decaf/net/SocketException.h>
#include <decaf/lang/Boolean.h>
#include <decaf/lang/exceptions/NullPointerException.h>
#include <decaf/lang/exceptions/IndexOutOfBoundsException.h>
#include <decaf/internal/util/StringUtils.h>
#include <decaf/internal/net/SocketFileDescriptor.h>
#include <decaf/internal/net/ssl/openssl/OpenSSLParameters.h>
#include <decaf/internal/net/ssl/openssl/OpenSSLSocketException.h>
#include <decaf/internal/net/ssl/openssl/OpenSSLSocketInputStream.h>
#include <decaf/internal/net/ssl/openssl/OpenSSLSocketOutputStream.h>
#include <decaf/util/concurrent/Mutex.h>

using namespace decaf;
using namespace decaf::lang;
using namespace decaf::lang::exceptions;
using namespace decaf::io;
using namespace decaf::net;
using namespace decaf::net::ssl;
using namespace decaf::util::concurrent;
using namespace decaf::internal;
using namespace decaf::internal::util;
using namespace decaf::internal::net;
using namespace decaf::internal::net::ssl;
using namespace decaf::internal::net::ssl::openssl;

////////////////////////////////////////////////////////////////////////////////
namespace decaf {
namespace internal {
namespace net {
namespace ssl {
namespace openssl {

    /**
     * Private per-socket state used by OpenSSLSocket: handshake progress flags, the
     * host name recorded by connect() for hostname verification, and the lock that
     * serializes handshake attempts.
     */
    class SocketData {
    public:

        // Set by startHandshake() as soon as a handshake attempt begins; never reset.
        bool handshakeStarted;

        // Set by startHandshake() after the OpenSSL handshake step reported no error.
        bool handshakeCompleted;

        // Host name passed to OpenSSLSocket::connect(); empty until connect() runs.
        std::string commonName;

        // Serializes startHandshake() and setUseClientMode().
        Mutex handshakeLock;

    public:

        SocketData() :
            handshakeStarted(false),
            handshakeCompleted(false),
            commonName(),
            handshakeLock() {
        }

        ~SocketData() {}

#ifdef HAVE_OPENSSL
        /**
         * OpenSSL certificate verification callback. It passes OpenSSL's own
         * verification verdict through unchanged; when 'verified' is 0 this is the
         * place to collect debugging details from 'store' about the failure.
         */
        static int verifyCallback(int verified, X509_STORE_CTX* store DECAF_UNUSED) {
            return verified;
        }
#endif

    };

}}}}}

////////////////////////////////////////////////////////////////////////////////
/**
 * Creates an unconnected SSL socket.
 *
 * @param parameters the OpenSSL session parameters; this socket takes ownership and
 *                   deletes them in its destructor.
 *
 * @throws NullPointerException if parameters is NULL.
 */
OpenSSLSocket::OpenSSLSocket(OpenSSLParameters* parameters) :
    SSLSocket(), data(new SocketData()), parameters(parameters), input(NULL), output(NULL) {

    if (parameters == NULL) {
        // A constructor that throws never runs the destructor, so release the
        // SocketData allocated in the initializer list here to avoid leaking it.
        delete this->data;
        throw NullPointerException(__FILE__, __LINE__,
            "The OpenSSL Parameters object instance passed was NULL.");
    }
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Creates an SSL socket whose base socket is constructed for the given remote address and port.
 *
 * @param parameters the OpenSSL session parameters; this socket takes ownership and
 *                   deletes them in its destructor.
 * @param address the remote address.
 * @param port the remote port.
 *
 * @throws NullPointerException if parameters is NULL.
 */
OpenSSLSocket::OpenSSLSocket(OpenSSLParameters* parameters, const InetAddress* address, int port) :
    SSLSocket(address, port), data(new SocketData()), parameters(parameters), input(NULL), output(NULL) {

    if (parameters == NULL) {
        // The destructor will not run for a throwing constructor; free SocketData now.
        delete this->data;
        throw NullPointerException(__FILE__, __LINE__, "The OpenSSL Parameters object instance passed was NULL.");
    }
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Creates an SSL socket whose base socket is constructed for the given remote address
 * and port, bound to the given local address and port.
 *
 * @param parameters the OpenSSL session parameters; this socket takes ownership and
 *                   deletes them in its destructor.
 * @param address the remote address.
 * @param port the remote port.
 * @param localAddress the local address to bind to.
 * @param localPort the local port to bind to.
 *
 * @throws NullPointerException if parameters is NULL.
 */
OpenSSLSocket::OpenSSLSocket(OpenSSLParameters* parameters, const InetAddress* address, int port, const InetAddress* localAddress, int localPort) :
    SSLSocket(address, port, localAddress, localPort), data(new SocketData()), parameters(parameters), input(NULL), output(NULL) {

    if (parameters == NULL) {
        // The destructor will not run for a throwing constructor; free SocketData now.
        delete this->data;
        throw NullPointerException(__FILE__, __LINE__, "The OpenSSL Parameters object instance passed was NULL.");
    }
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Creates an SSL socket whose base socket is constructed for the given remote host and port.
 *
 * @param parameters the OpenSSL session parameters; this socket takes ownership and
 *                   deletes them in its destructor.
 * @param host the remote host name.
 * @param port the remote port.
 *
 * @throws NullPointerException if parameters is NULL.
 */
OpenSSLSocket::OpenSSLSocket(OpenSSLParameters* parameters, const std::string& host, int port) :
    SSLSocket(host, port), data(new SocketData()), parameters(parameters), input(NULL), output(NULL) {

    if (parameters == NULL) {
        // The destructor will not run for a throwing constructor; free SocketData now.
        delete this->data;
        throw NullPointerException(__FILE__, __LINE__, "The OpenSSL Parameters object instance passed was NULL.");
    }
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Creates an SSL socket whose base socket is constructed for the given remote host
 * and port, bound to the given local address and port.
 *
 * @param parameters the OpenSSL session parameters; this socket takes ownership and
 *                   deletes them in its destructor.
 * @param host the remote host name.
 * @param port the remote port.
 * @param localAddress the local address to bind to.
 * @param localPort the local port to bind to.
 *
 * @throws NullPointerException if parameters is NULL.
 */
OpenSSLSocket::OpenSSLSocket(OpenSSLParameters* parameters, const std::string& host, int port, const InetAddress* localAddress, int localPort) :
    SSLSocket(host, port, localAddress, localPort), data(new SocketData()), parameters(parameters), input(NULL), output(NULL) {

    if (parameters == NULL) {
        // The destructor will not run for a throwing constructor; free SocketData now.
        delete this->data;
        throw NullPointerException(__FILE__, __LINE__, "The OpenSSL Parameters object instance passed was NULL.");
    }
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Closes the socket, marks the OpenSSL session as shut down and releases everything
 * this socket owns. Exceptions are swallowed so that destruction never throws.
 */
OpenSSLSocket::~OpenSSLSocket() {
    try {

        // Close via the base implementation.
        SSLSocket::close();

#ifdef HAVE_OPENSSL
        SSL* session = this->parameters->getSSL();
        if (session != NULL) {
            // Both shutdown directions are flagged as done before SSL_shutdown runs;
            // presumably this keeps it from trying to exchange close_notify over the
            // socket that was just closed.
            SSL_set_shutdown(session, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
            SSL_shutdown(session);
        }
#endif

        // Owned resources.
        delete data;
        delete parameters;
        delete input;
        delete output;
    }
    DECAF_CATCH_NOTHROW(Exception)
    DECAF_CATCHALL_NOTHROW()
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Connects the underlying socket and, once connected, binds its native descriptor to
 * the OpenSSL session through a socket BIO so it can be used by the OpenSSL APIs.
 * The host name is remembered so that startHandshake() can check it against the
 * peer certificate.
 *
 * @param host the remote host name.
 * @param port the remote port.
 * @param timeout the connect timeout, forwarded to SSLSocket::connect.
 *
 * @throws IOException if the connection or the BIO setup fails, or when built without OpenSSL.
 * @throws IllegalArgumentException rethrown unchanged from the base connect.
 */
void OpenSSLSocket::connect(const std::string& host, int port, int timeout) {

    try {

#ifdef HAVE_OPENSSL

        // Perform the actual Socket connection work
        SSLSocket::connect(host, port, timeout);

        if (isConnected()) {

            // Resolve the native descriptor before allocating the BIO so that an
            // invalid descriptor cannot leak a freshly created BIO.
            const SocketFileDescriptor* fd =
                dynamic_cast<const SocketFileDescriptor*>(this->impl->getFileDescriptor());

            if (fd == NULL) {
                throw SocketException(__FILE__, __LINE__, "Invalid File Descriptor returned from Socket");
            }

            BIO* bio = BIO_new(BIO_s_socket());
            if (!bio) {
                throw SocketException(__FILE__, __LINE__, "Failed to create SSL IO Bindings");
            }

            // BIO_NOCLOSE: the socket implementation, not the BIO, owns the descriptor.
            BIO_set_fd(bio, static_cast<int>(fd->getValue()), BIO_NOCLOSE);

            // The SSL object takes ownership of the BIO from here on.
            SSL_set_bio(this->parameters->getSSL(), bio, bio);

            // Later when startHandshake is called we will check for this common name
            // in the provided certificate
            this->data->commonName = host;
        }
#else
        throw SocketException( __FILE__, __LINE__, "Not Supported" );
#endif
    }
    DECAF_CATCH_RETHROW(IOException)
    DECAF_CATCH_RETHROW(IllegalArgumentException)
    DECAF_CATCH_EXCEPTION_CONVERT(Exception, IOException)
    DECAF_CATCHALL_THROW(IOException)
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Closes the socket and any input/output stream previously handed out by this
 * socket. Calling it on an already-closed socket does nothing.
 *
 * @throws IOException if closing fails.
 */
void OpenSSLSocket::close() {

    try {

        if (!isClosed()) {

            SSLSocket::close();

            // Also close the streams created by getInputStream()/getOutputStream().
            if (this->input != NULL) {
                this->input->close();
            }

            if (this->output != NULL) {
                this->output->close();
            }
        }
    }
    DECAF_CATCH_RETHROW(IOException)
    DECAF_CATCH_EXCEPTION_CONVERT(Exception, IOException)
    DECAF_CATCHALL_THROW(IOException)
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Returns the stream that reads decrypted data from this socket, creating it on the
 * first call. The socket keeps ownership of the stream (it is deleted in the destructor).
 * checkClosed() is consulted first to reject use of a closed socket.
 *
 * @throws IOException if the stream cannot be created.
 */
decaf::io::InputStream* OpenSSLSocket::getInputStream() {

    checkClosed();

    try {
        if (this->input != NULL) {
            return this->input;
        }

        this->input = new OpenSSLSocketInputStream(this);
        return this->input;
    }
    DECAF_CATCH_RETHROW(IOException)
    DECAF_CATCH_EXCEPTION_CONVERT(Exception, IOException)
    DECAF_CATCHALL_THROW(IOException)
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Returns the stream that writes encrypted data to this socket, creating it on the
 * first call. The socket keeps ownership of the stream (it is deleted in the destructor).
 * checkClosed() is consulted first to reject use of a closed socket.
 *
 * @throws IOException if the stream cannot be created.
 */
decaf::io::OutputStream* OpenSSLSocket::getOutputStream() {

    checkClosed();

    try {
        if (this->output != NULL) {
            return this->output;
        }

        this->output = new OpenSSLSocketOutputStream(this);
        return this->output;
    }
    DECAF_CATCH_RETHROW(IOException)
    DECAF_CATCH_EXCEPTION_CONVERT(Exception, IOException)
    DECAF_CATCHALL_THROW(IOException)
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Half-closing the input side is not offered for SSL sockets.
 *
 * @throws SocketException always.
 */
void OpenSSLSocket::shutdownInput() {
    throw SocketException(__FILE__, __LINE__,
                          "Not supported for SSL Sockets");
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Half-closing the output side is not offered for SSL sockets.
 *
 * @throws SocketException always.
 */
void OpenSSLSocket::shutdownOutput() {
    throw SocketException(__FILE__, __LINE__,
                          "Not supported for SSL Sockets");
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Out-of-band data is not available on SSL sockets.
 *
 * @throws SocketException always.
 */
void OpenSSLSocket::setOOBInline(bool value DECAF_UNUSED) {
    throw SocketException(__FILE__, __LINE__,
                          "Not supported for SSL Sockets");
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Urgent (out-of-band) data cannot be sent on SSL sockets.
 *
 * @throws SocketException always.
 */
void OpenSSLSocket::sendUrgentData(int data DECAF_UNUSED) {
    throw SocketException(__FILE__, __LINE__,
                          "Not supported for SSL Sockets");
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Returns a snapshot of this socket's SSL settings: enabled cipher suites and
 * protocols, configured server names, and the client authentication flags.
 */
decaf::net::ssl::SSLParameters OpenSSLSocket::getSSLParameters() const {

    std::vector<std::string> suites = this->getEnabledCipherSuites();
    std::vector<std::string> protocols = this->getEnabledProtocols();

    SSLParameters snapshot(suites, protocols);
    snapshot.setServerNames(parameters->getServerNames());
    snapshot.setNeedClientAuth(parameters->getNeedClientAuth());
    snapshot.setWantClientAuth(parameters->getWantClientAuth());

    return snapshot;
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Applies the cipher suites, protocols, server names and client authentication
 * settings carried by 'value' to this socket, so that a subsequent
 * getSSLParameters() reflects them.
 *
 * Client authentication follows the Java SSLSocket contract: if "need" is set it
 * takes precedence, otherwise "want" is applied. This method never leaves both enabled.
 */
void OpenSSLSocket::setSSLParameters(const decaf::net::ssl::SSLParameters& value) {
    this->parameters->setEnabledCipherSuites(value.getCipherSuites());
    this->parameters->setEnabledProtocols(value.getProtocols());
    this->parameters->setServerNames(value.getServerNames());

    // Clear the flag that is not being enabled first, so the result is the same
    // whether the parameter setters are plain assignments or clear each other.
    if (value.getNeedClientAuth()) {
        this->parameters->setWantClientAuth(false);
        this->parameters->setNeedClientAuth(true);
    } else {
        this->parameters->setNeedClientAuth(false);
        this->parameters->setWantClientAuth(value.getWantClientAuth());
    }
}

////////////////////////////////////////////////////////////////////////////////
std::vector<std::string> OpenSSLSocket::getSupportedCipherSuites() const {
    return this->parameters->getSupportedCipherSuites();
}

////////////////////////////////////////////////////////////////////////////////
std::vector<std::string> OpenSSLSocket::getSupportedProtocols() const {
    return this->parameters->getSupportedProtocols();
}

////////////////////////////////////////////////////////////////////////////////
std::vector<std::string> OpenSSLSocket::getEnabledCipherSuites() const {
    return this->parameters->getEnabledCipherSuites();
}

////////////////////////////////////////////////////////////////////////////////
void OpenSSLSocket::setEnabledCipherSuites(const std::vector<std::string>& suites) {
    this->parameters->setEnabledCipherSuites(suites);
}

////////////////////////////////////////////////////////////////////////////////
std::vector<std::string> OpenSSLSocket::getEnabledProtocols() const {
    return this->parameters->getEnabledProtocols();
}

////////////////////////////////////////////////////////////////////////////////
void OpenSSLSocket::setEnabledProtocols(const std::vector<std::string>& protocols) {
    this->parameters->setEnabledProtocols(protocols);
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Performs the TLS handshake on the connected socket. It acts as a client (SSL_connect)
 * or as a server (SSL_accept) depending on the configured client mode. Only the first
 * call makes a handshake attempt. Once an attempt has started, later calls return
 * without doing anything, even if that attempt failed.
 *
 * Peer verification is enforced unless the system property
 * "decaf.net.ssl.disablePeerVerification" is "true". In client mode, the host name
 * recorded by connect() is also checked against the peer certificate.
 *
 * On handshake failure the socket is closed and an OpenSSLSocketException is thrown.
 *
 * @throws IOException if the socket is not connected, is already closed, the handshake
 *         fails, or the library was built without OpenSSL.
 */
void OpenSSLSocket::startHandshake() {

    if (!this->isConnected()) {
        throw IOException(__FILE__, __LINE__, "Socket is not connected.");
    }

    if (this->isClosed()) {
        throw IOException(__FILE__, __LINE__, "Socket already closed.");
    }

    try {

#ifdef HAVE_OPENSSL
        synchronized( &(this->data->handshakeLock ) ) {

            if (this->data->handshakeStarted) {
                return;
            }

            this->data->handshakeStarted = true;

            SSL* ssl = this->parameters->getSSL();

            bool peerVerifyDisabled = Boolean::parseBoolean(
                System::getProperty("decaf.net.ssl.disablePeerVerification", "false"));

            if (this->parameters->getUseClientMode()) {

                if (!peerVerifyDisabled) {
                    // Check host https://wiki.openssl.org/index.php/Hostname_validation
                    // NOTE(review): commonName is only recorded by OpenSSLSocket::connect().
                    // Per the OpenSSL docs, an empty name clears the expected host list, so
                    // no hostname check is done for sockets connected any other way.
                    X509_VERIFY_PARAM* param = SSL_get0_param(ssl);

                    X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
                    X509_VERIFY_PARAM_set1_host(param, this->data->commonName.c_str(), 0);

                    // Since we are a client we want to enforce peer verification; the
                    // callback is a hook for collecting data on why a verify failed.
                    SSL_set_verify(ssl, SSL_VERIFY_PEER, SocketData::verifyCallback);
                } else {
                    SSL_set_verify(ssl, SSL_VERIFY_NONE, NULL);
                }

                // Advertise the first configured server name through the SNI extension.
                std::vector<std::string> serverNames = this->parameters->getServerNames();
                if (!serverNames.empty()) {
                    const std::string& serverName = serverNames.front();
                    SSL_set_tlsext_host_name(ssl, serverName.c_str());
                }

                int result = SSL_connect(ssl);

                if (SSL_get_error(ssl, result) != SSL_ERROR_NONE) {
                    SSLSocket::close();
                    throw OpenSSLSocketException(__FILE__, __LINE__);
                }

            } else { // We are in Server Mode.

                int mode = SSL_VERIFY_NONE;

                if (!peerVerifyDisabled) {

                    if (this->parameters->getWantClientAuth()) {
                        mode = SSL_VERIFY_PEER;
                    }

                    // "Need" overrides "want": also fail when the client sends no certificate.
                    if (this->parameters->getNeedClientAuth()) {
                        mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
                    }
                }

                SSL_set_verify(ssl, mode, SocketData::verifyCallback);

                int result = SSL_accept(ssl);

                // SSL_accept returns 1 on success, so its raw return value must not be
                // compared with SSL_ERROR_NONE (0). Classify it with SSL_get_error exactly
                // as the client branch does.
                if (SSL_get_error(ssl, result) != SSL_ERROR_NONE) {
                    SSLSocket::close();
                    throw OpenSSLSocketException(__FILE__, __LINE__);
                }
            }

            this->data->handshakeCompleted = true;
        }
#else
        throw IOException( __FILE__, __LINE__, "SSL Not Supported." );
#endif
    }
    DECAF_CATCH_RETHROW(IOException)
    DECAF_CATCHALL_THROW(IOException)
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Selects client or server handshake mode. The mode decides whether startHandshake()
 * uses SSL_connect or SSL_accept, so it cannot change once a handshake has started.
 *
 * @throws IllegalArgumentException if the handshake has already been started.
 */
void OpenSSLSocket::setUseClientMode(bool value) {

    synchronized( &( this->data->handshakeLock ) ) {
        if (!this->data->handshakeStarted) {
            this->parameters->setUseClientMode(value);
        } else {
            throw IllegalArgumentException(__FILE__, __LINE__,
                "Handshake has already been started cannot change mode.");
        }
    }
}

////////////////////////////////////////////////////////////////////////////////
bool OpenSSLSocket::getUseClientMode() const {
    return this->parameters->getUseClientMode();
}

////////////////////////////////////////////////////////////////////////////////
void OpenSSLSocket::setNeedClientAuth(bool value) {
    this->parameters->setNeedClientAuth(value);
}

////////////////////////////////////////////////////////////////////////////////
bool OpenSSLSocket::getNeedClientAuth() const {
    return this->parameters->getNeedClientAuth();
}

////////////////////////////////////////////////////////////////////////////////
void OpenSSLSocket::setWantClientAuth(bool value) {
    this->parameters->setWantClientAuth(value);
}

////////////////////////////////////////////////////////////////////////////////
bool OpenSSLSocket::getWantClientAuth() const {
    return this->parameters->getWantClientAuth();
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Reads up to 'length' bytes of decrypted data into buffer[offset, offset + length).
 * If the TLS handshake has not completed yet, it is started first.
 *
 * @param buffer destination buffer of 'size' bytes.
 * @param size the size of buffer.
 * @param offset position in buffer at which to store the first byte.
 * @param length maximum number of bytes to read.
 *
 * @return the number of bytes read; 0 if length is 0; -1 at end of stream (input
 *         shut down, or the peer closed the TLS session cleanly).
 *
 * @throws IOException if the socket is closed or OpenSSL reports an error.
 * @throws NullPointerException if buffer is NULL.
 * @throws IndexOutOfBoundsException if size, offset or length are out of range.
 */
int OpenSSLSocket::read(unsigned char* buffer, int size, int offset, int length) {

    try {
        if (this->isClosed()) {
            throw IOException(__FILE__, __LINE__, "The Stream has been closed");
        }

        if (this->isInputShutdown() == true) {
            return -1;
        }

        if (length == 0) {
            return 0;
        }

        if (buffer == NULL) {
            throw NullPointerException(__FILE__, __LINE__, "Buffer passed is Null");
        }

        if (size < 0) {
            throw IndexOutOfBoundsException(__FILE__, __LINE__,
                "size parameter out of Bounds: %d.", size);
        }

        if (offset > size || offset < 0) {
            throw IndexOutOfBoundsException(__FILE__, __LINE__,
                "offset parameter out of Bounds: %d.", offset);
        }

        if (length < 0 || length > size - offset) {
            throw IndexOutOfBoundsException(__FILE__, __LINE__,
                "length parameter out of Bounds: %d.", length);
        }

#ifdef HAVE_OPENSSL

        if (!this->data->handshakeCompleted) {
            this->startHandshake();
        }

        SSL* ssl = this->parameters->getSSL();

        // Read data from the socket.
        int result = SSL_read(ssl, buffer + offset, length);

        switch (SSL_get_error(ssl, result)) {
        case SSL_ERROR_NONE:
            return result;
        case SSL_ERROR_ZERO_RETURN:
            // The peer closed the TLS session cleanly (close_notify), so report end of
            // stream. shutdownInput() must not be used here: it is not supported for
            // SSL sockets and always throws, which would turn a clean EOF into an error.
            if (!isClosed()) {
                return -1;
            }
            break;
        default:
            break;
        }

        throw OpenSSLSocketException(__FILE__, __LINE__);
#else
        throw SocketException( __FILE__, __LINE__, "Not Supported" );
#endif
    }
    DECAF_CATCH_RETHROW(IOException)
    DECAF_CATCH_RETHROW(NullPointerException)
    DECAF_CATCH_RETHROW(IndexOutOfBoundsException)
    DECAF_CATCHALL_THROW(IOException)
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Encrypts and sends the bytes in buffer[offset, offset + length). If the TLS
 * handshake has not completed yet, it is started first. A zero length is a no-op.
 *
 * @param buffer source buffer of 'size' bytes.
 * @param size the size of buffer.
 * @param offset position in buffer of the first byte to send.
 * @param length number of bytes to send.
 *
 * @throws NullPointerException if buffer is NULL.
 * @throws IOException if the socket is closed, the connection breaks, or OpenSSL reports an error.
 * @throws IndexOutOfBoundsException if size, offset or length are out of range.
 */
void OpenSSLSocket::write(const unsigned char* buffer, int size, int offset, int length) {

    try {

        if (length == 0) {
            return;
        }

        if (buffer == NULL) {
            throw NullPointerException(__FILE__, __LINE__,
                "TcpSocketOutputStream::write - passed buffer is null");
        }

        if (isClosed()) {
            throw IOException(__FILE__, __LINE__,
                "TcpSocketOutputStream::write - This Stream has been closed.");
        }

        if (size < 0) {
            throw IndexOutOfBoundsException(__FILE__, __LINE__,
                "size parameter out of Bounds: %d.", size);
        }

        if (offset < 0 || offset > size) {
            throw IndexOutOfBoundsException(__FILE__, __LINE__,
                "offset parameter out of Bounds: %d.", offset);
        }

        if (length < 0 || length > size - offset) {
            throw IndexOutOfBoundsException(__FILE__, __LINE__,
                "length parameter out of Bounds: %d.", length);
        }

#ifdef HAVE_OPENSSL

        if (!this->data->handshakeCompleted) {
            this->startHandshake();
        }

        SSL* ssl = this->parameters->getSSL();
        const unsigned char* cursor = buffer + offset;
        int pending = length;

        // SSL_write can report a partial write (e.g. when partial writes are enabled on
        // the session), so keep sending until everything is out or the socket closes.
        while (pending > 0) {

            if (isClosed()) {
                break;
            }

            int sent = SSL_write(ssl, cursor, pending);
            int status = SSL_get_error(ssl, sent);

            if (status == SSL_ERROR_ZERO_RETURN) {
                throw SocketException(__FILE__, __LINE__, "The connection was broken unexpectedly.");
            }

            if (status != SSL_ERROR_NONE) {
                throw OpenSSLSocketException(__FILE__, __LINE__);
            }

            cursor += sent;
            pending -= sent;
        }
#else
        throw SocketException( __FILE__, __LINE__, "Not Supported" );
#endif
    }
    DECAF_CATCH_RETHROW(IOException)
    DECAF_CATCH_RETHROW(NullPointerException)
    DECAF_CATCH_RETHROW(IndexOutOfBoundsException)
    DECAF_CATCHALL_THROW(IOException)
}

////////////////////////////////////////////////////////////////////////////////
/**
 * Number of decrypted bytes already buffered inside OpenSSL and readable without
 * blocking (SSL_pending), or -1 once the socket is closed.
 *
 * @throws IOException when built without OpenSSL or if the query fails.
 */
int OpenSSLSocket::available() {

    try {

#ifdef HAVE_OPENSSL
        if (isClosed()) {
            return -1;
        }

        return SSL_pending(this->parameters->getSSL());
#else
        throw SocketException( __FILE__, __LINE__, "Not Supported" );
#endif
    }
    DECAF_CATCH_RETHROW(IOException)
    DECAF_CATCHALL_THROW(IOException)
}
