| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "OpenSSLSocket.h" |
| |
| #ifdef HAVE_OPENSSL |
| #include <openssl/ssl.h> |
| #include <openssl/tls1.h> |
| #include <openssl/x509.h> |
| #include <openssl/x509v3.h> |
| #include <openssl/bio.h> |
| #endif |
| |
| #include <decaf/net/SocketImpl.h> |
| #include <decaf/io/IOException.h> |
| #include <decaf/net/SocketException.h> |
| #include <decaf/lang/Boolean.h> |
| #include <decaf/lang/exceptions/NullPointerException.h> |
| #include <decaf/lang/exceptions/IndexOutOfBoundsException.h> |
| #include <decaf/internal/util/StringUtils.h> |
| #include <decaf/internal/net/SocketFileDescriptor.h> |
| #include <decaf/internal/net/ssl/openssl/OpenSSLParameters.h> |
| #include <decaf/internal/net/ssl/openssl/OpenSSLSocketException.h> |
| #include <decaf/internal/net/ssl/openssl/OpenSSLSocketInputStream.h> |
| #include <decaf/internal/net/ssl/openssl/OpenSSLSocketOutputStream.h> |
| #include <decaf/util/concurrent/Mutex.h> |
| |
| using namespace decaf; |
| using namespace decaf::lang; |
| using namespace decaf::lang::exceptions; |
| using namespace decaf::io; |
| using namespace decaf::net; |
| using namespace decaf::net::ssl; |
| using namespace decaf::util::concurrent; |
| using namespace decaf::internal; |
| using namespace decaf::internal::util; |
| using namespace decaf::internal::net; |
| using namespace decaf::internal::net::ssl; |
| using namespace decaf::internal::net::ssl::openssl; |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| namespace decaf { |
| namespace internal { |
| namespace net { |
| namespace ssl { |
| namespace openssl { |
| |
| class SocketData { |
| public: |
| |
| bool handshakeStarted; |
| bool handshakeCompleted; |
| std::string commonName; |
| |
| Mutex handshakeLock; |
| |
| public: |
| |
| SocketData() : handshakeStarted(false), |
| handshakeCompleted(false), |
| commonName(), |
| handshakeLock() { |
| } |
| |
| ~SocketData() {} |
| |
| #ifdef HAVE_OPENSSL |
| static int verifyCallback(int verified, X509_STORE_CTX* store DECAF_UNUSED) { |
| if (!verified) { |
| // Trap debug info here about why the Certificate failed to validate. |
| } |
| |
| return verified; |
| } |
| #endif |
| |
| }; |
| |
| }}}}} |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| OpenSSLSocket::OpenSSLSocket(OpenSSLParameters* parameters) : |
| SSLSocket(), data(new SocketData()), parameters(parameters), input(NULL), output(NULL) { |
| |
| if (parameters == NULL) { |
| throw NullPointerException(__FILE__, __LINE__, |
| "The OpenSSL Parameters object instance passed was NULL."); |
| } |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| OpenSSLSocket::OpenSSLSocket(OpenSSLParameters* parameters, const InetAddress* address, int port) : |
| SSLSocket(address, port), data(new SocketData()), parameters(parameters), input(NULL), output(NULL) { |
| |
| if (parameters == NULL) { |
| throw NullPointerException(__FILE__, __LINE__, "The OpenSSL Parameters object instance passed was NULL."); |
| } |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| OpenSSLSocket::OpenSSLSocket(OpenSSLParameters* parameters, const InetAddress* address, int port, const InetAddress* localAddress, int localPort) : |
| SSLSocket(address, port, localAddress, localPort), data(new SocketData()), parameters(parameters), input(NULL), output(NULL) { |
| |
| if (parameters == NULL) { |
| throw NullPointerException(__FILE__, __LINE__, "The OpenSSL Parameters object instance passed was NULL."); |
| } |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| OpenSSLSocket::OpenSSLSocket(OpenSSLParameters* parameters, const std::string& host, int port) : |
| SSLSocket(host, port), data(new SocketData()), parameters(parameters), input(NULL), output(NULL) { |
| |
| if (parameters == NULL) { |
| throw NullPointerException(__FILE__, __LINE__, "The OpenSSL Parameters object instance passed was NULL."); |
| } |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| OpenSSLSocket::OpenSSLSocket(OpenSSLParameters* parameters, const std::string& host, int port, const InetAddress* localAddress, int localPort) : |
| SSLSocket(host, port, localAddress, localPort), data(new SocketData()), parameters(parameters), input(NULL), output(NULL) { |
| |
| if (parameters == NULL) { |
| throw NullPointerException(__FILE__, __LINE__, "The OpenSSL Parameters object instance passed was NULL."); |
| } |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| OpenSSLSocket::~OpenSSLSocket() { |
| try { |
| |
| SSLSocket::close(); |
| |
| #ifdef HAVE_OPENSSL |
| if (this->parameters->getSSL()) { |
| SSL_set_shutdown(this->parameters->getSSL(), SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN); |
| SSL_shutdown(this->parameters->getSSL()); |
| } |
| #endif |
| |
| delete data; |
| delete parameters; |
| delete input; |
| delete output; |
| } |
| DECAF_CATCH_NOTHROW(Exception) |
| DECAF_CATCHALL_NOTHROW()} |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::connect(const std::string& host, int port, int timeout) { |
| |
| try { |
| |
| #ifdef HAVE_OPENSSL |
| |
| // Perform the actual Socket connection work |
| SSLSocket::connect(host, port, timeout); |
| |
| // If we actually connected then we can connect the Socket to an OpenSSL |
| // BIO filter so that we can use it in OpenSSL APIs. |
| if (isConnected()) { |
| |
| BIO* bio = BIO_new(BIO_s_socket()); |
| if (!bio) { |
| throw SocketException(__FILE__, __LINE__, "Failed to create SSL IO Bindings"); |
| } |
| |
| const SocketFileDescriptor* fd = dynamic_cast<const SocketFileDescriptor*>(this->impl->getFileDescriptor()); |
| |
| if (fd == NULL) { |
| throw SocketException(__FILE__, __LINE__, "Invalid File Descriptor returned from Socket"); |
| } |
| |
| BIO_set_fd(bio, (int )fd->getValue(), BIO_NOCLOSE); |
| SSL_set_bio(this->parameters->getSSL(), bio, bio); |
| |
| // Later when startHandshake is called we will check for this common name |
| // in the provided certificate |
| this->data->commonName = host; |
| } |
| #else |
| throw SocketException( __FILE__, __LINE__, "Not Supported" ); |
| #endif |
| } |
| DECAF_CATCH_RETHROW(IOException) |
| DECAF_CATCH_RETHROW(IllegalArgumentException) |
| DECAF_CATCH_EXCEPTION_CONVERT(Exception, IOException) |
| DECAF_CATCHALL_THROW(IOException) |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::close() { |
| |
| try { |
| |
| if (isClosed()) { |
| return; |
| } |
| |
| SSLSocket::close(); |
| |
| if (this->input != NULL) { |
| this->input->close(); |
| } |
| if (this->output != NULL) { |
| this->output->close(); |
| } |
| } |
| DECAF_CATCH_RETHROW(IOException) |
| DECAF_CATCH_EXCEPTION_CONVERT(Exception, IOException) |
| DECAF_CATCHALL_THROW(IOException) |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| decaf::io::InputStream* OpenSSLSocket::getInputStream() { |
| |
| checkClosed(); |
| |
| try { |
| if (this->input == NULL) { |
| this->input = new OpenSSLSocketInputStream(this); |
| } |
| |
| return this->input; |
| } |
| DECAF_CATCH_RETHROW(IOException) |
| DECAF_CATCH_EXCEPTION_CONVERT(Exception, IOException) |
| DECAF_CATCHALL_THROW(IOException) |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| decaf::io::OutputStream* OpenSSLSocket::getOutputStream() { |
| |
| checkClosed(); |
| |
| try { |
| if (this->output == NULL) { |
| this->output = new OpenSSLSocketOutputStream(this); |
| } |
| |
| return this->output; |
| } |
| DECAF_CATCH_RETHROW(IOException) |
| DECAF_CATCH_EXCEPTION_CONVERT(Exception, IOException) |
| DECAF_CATCHALL_THROW(IOException) |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::shutdownInput() { |
| throw SocketException(__FILE__, __LINE__, "Not supported for SSL Sockets"); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::shutdownOutput() { |
| throw SocketException(__FILE__, __LINE__, "Not supported for SSL Sockets"); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::setOOBInline(bool value DECAF_UNUSED) { |
| throw SocketException(__FILE__, __LINE__, "Not supported for SSL Sockets"); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::sendUrgentData(int data DECAF_UNUSED) { |
| throw SocketException(__FILE__, __LINE__, "Not supported for SSL Sockets"); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| decaf::net::ssl::SSLParameters OpenSSLSocket::getSSLParameters() const { |
| |
| SSLParameters params(this->getEnabledCipherSuites(), this->getEnabledProtocols()); |
| |
| params.setServerNames(this->parameters->getServerNames()); |
| params.setNeedClientAuth(this->parameters->getNeedClientAuth()); |
| params.setWantClientAuth(this->parameters->getWantClientAuth()); |
| |
| return params; |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::setSSLParameters(const decaf::net::ssl::SSLParameters& value) { |
| this->parameters->setEnabledCipherSuites(value.getCipherSuites()); |
| this->parameters->setEnabledProtocols(value.getProtocols()); |
| this->parameters->setServerNames(value.getServerNames()); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| std::vector<std::string> OpenSSLSocket::getSupportedCipherSuites() const { |
| return this->parameters->getSupportedCipherSuites(); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| std::vector<std::string> OpenSSLSocket::getSupportedProtocols() const { |
| return this->parameters->getSupportedProtocols(); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| std::vector<std::string> OpenSSLSocket::getEnabledCipherSuites() const { |
| return this->parameters->getEnabledCipherSuites(); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::setEnabledCipherSuites(const std::vector<std::string>& suites) { |
| this->parameters->setEnabledCipherSuites(suites); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| std::vector<std::string> OpenSSLSocket::getEnabledProtocols() const { |
| return this->parameters->getEnabledProtocols(); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::setEnabledProtocols(const std::vector<std::string>& protocols) { |
| this->parameters->setEnabledProtocols(protocols); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::startHandshake() { |
| |
| if (!this->isConnected()) { |
| throw IOException(__FILE__, __LINE__, "Socket is not connected."); |
| } |
| |
| if (this->isClosed()) { |
| throw IOException(__FILE__, __LINE__, "Socket already closed."); |
| } |
| |
| try { |
| |
| #ifdef HAVE_OPENSSL |
| synchronized( &(this->data->handshakeLock ) ) { |
| |
| if (this->data->handshakeStarted) { |
| return; |
| } |
| |
| this->data->handshakeStarted = true; |
| |
| bool peerVerifyDisabled = Boolean::parseBoolean(System::getProperty("decaf.net.ssl.disablePeerVerification", "false")); |
| |
| if (this->parameters->getUseClientMode()) { |
| |
| // Since we are a client we want to enforce peer verification, we set a |
| // callback so we can collect data on why a verify failed for debugging. |
| if (!peerVerifyDisabled) { |
| // Check host https://wiki.openssl.org/index.php/Hostname_validation |
| X509_VERIFY_PARAM *param = SSL_get0_param(this->parameters->getSSL()); |
| |
| X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS); |
| X509_VERIFY_PARAM_set1_host(param, this->data->commonName.c_str(), 0); |
| |
| SSL_set_verify(this->parameters->getSSL(), SSL_VERIFY_PEER, SocketData::verifyCallback); |
| } else { |
| SSL_set_verify(this->parameters->getSSL(), SSL_VERIFY_NONE, NULL); |
| } |
| |
| std::vector<std::string> serverNames = this->parameters->getServerNames(); |
| if (!serverNames.empty()) { |
| std::string serverName = serverNames.at(0); |
| SSL_set_tlsext_host_name(this->parameters->getSSL(), serverName.c_str()); |
| } |
| |
| int result = SSL_connect(this->parameters->getSSL()); |
| |
| // Checks the error status |
| switch (SSL_get_error(this->parameters->getSSL(), result)) { |
| case SSL_ERROR_NONE: |
| break; |
| case SSL_ERROR_SSL: |
| case SSL_ERROR_ZERO_RETURN: |
| case SSL_ERROR_SYSCALL: |
| default: |
| SSLSocket::close(); |
| throw OpenSSLSocketException(__FILE__, __LINE__); |
| } |
| |
| } else { // We are in Server Mode. |
| |
| int mode = SSL_VERIFY_NONE; |
| |
| if (!peerVerifyDisabled) { |
| |
| if (this->parameters->getWantClientAuth()) { |
| mode = SSL_VERIFY_PEER; |
| } |
| |
| if (this->parameters->getNeedClientAuth()) { |
| mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT; |
| } |
| } |
| |
| // Since we are a server we want to enforce peer verification, we set a |
| // callback so we can collect data on why a verify failed for debugging. |
| SSL_set_verify(this->parameters->getSSL(), mode, SocketData::verifyCallback); |
| |
| int result = SSL_accept(this->parameters->getSSL()); |
| |
| if (result != SSL_ERROR_NONE) { |
| SSLSocket::close(); |
| throw OpenSSLSocketException(__FILE__, __LINE__); |
| } |
| } |
| |
| this->data->handshakeCompleted = true; |
| } |
| #else |
| throw IOException( __FILE__, __LINE__, "SSL Not Supported." ); |
| #endif |
| } |
| DECAF_CATCH_RETHROW(IOException) |
| DECAF_CATCHALL_THROW(IOException) |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::setUseClientMode(bool value) { |
| |
| synchronized( &( this->data->handshakeLock ) ) { |
| if (this->data->handshakeStarted) { |
| throw IllegalArgumentException(__FILE__, __LINE__, |
| "Handshake has already been started cannot change mode."); |
| } |
| |
| this->parameters->setUseClientMode(value); |
| } |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| bool OpenSSLSocket::getUseClientMode() const { |
| return this->parameters->getUseClientMode(); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::setNeedClientAuth(bool value) { |
| this->parameters->setNeedClientAuth(value); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| bool OpenSSLSocket::getNeedClientAuth() const { |
| return this->parameters->getNeedClientAuth(); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::setWantClientAuth(bool value) { |
| this->parameters->setWantClientAuth(value); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| bool OpenSSLSocket::getWantClientAuth() const { |
| return this->parameters->getWantClientAuth(); |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| int OpenSSLSocket::read(unsigned char* buffer, int size, int offset, int length) { |
| |
| try { |
| if (this->isClosed()) { |
| throw IOException(__FILE__, __LINE__, "The Stream has been closed"); |
| } |
| |
| if (this->isInputShutdown() == true) { |
| return -1; |
| } |
| |
| if (length == 0) { |
| return 0; |
| } |
| |
| if (buffer == NULL) { |
| throw NullPointerException(__FILE__, __LINE__, "Buffer passed is Null"); |
| } |
| |
| if (size < 0) { |
| throw IndexOutOfBoundsException(__FILE__, __LINE__, |
| "size parameter out of Bounds: %d.", size); |
| } |
| |
| if (offset > size || offset < 0) { |
| throw IndexOutOfBoundsException(__FILE__, __LINE__, |
| "offset parameter out of Bounds: %d.", offset); |
| } |
| |
| if (length < 0 || length > size - offset) { |
| throw IndexOutOfBoundsException(__FILE__, __LINE__, |
| "length parameter out of Bounds: %d.", length); |
| } |
| |
| #ifdef HAVE_OPENSSL |
| |
| if (!this->data->handshakeCompleted) { |
| this->startHandshake(); |
| } |
| |
| // Read data from the socket. |
| int result = SSL_read(this->parameters->getSSL(), buffer + offset, length); |
| |
| switch (SSL_get_error(this->parameters->getSSL(), result)) { |
| case SSL_ERROR_NONE: |
| return result; |
| case SSL_ERROR_ZERO_RETURN: |
| if (!isClosed()) { |
| this->shutdownInput(); |
| return -1; |
| } |
| default: |
| throw OpenSSLSocketException(__FILE__, __LINE__); |
| } |
| #else |
| throw SocketException( __FILE__, __LINE__, "Not Supported" ); |
| #endif |
| } |
| DECAF_CATCH_RETHROW(IOException) |
| DECAF_CATCH_RETHROW(NullPointerException) |
| DECAF_CATCH_RETHROW(IndexOutOfBoundsException) |
| DECAF_CATCHALL_THROW(IOException) |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| void OpenSSLSocket::write(const unsigned char* buffer, int size, int offset, int length) { |
| |
| try { |
| |
| if (length == 0) { |
| return; |
| } |
| |
| if (buffer == NULL) { |
| throw NullPointerException(__FILE__, __LINE__, |
| "TcpSocketOutputStream::write - passed buffer is null"); |
| } |
| |
| if (isClosed()) { |
| throw IOException(__FILE__, __LINE__, |
| "TcpSocketOutputStream::write - This Stream has been closed."); |
| } |
| |
| if (size < 0) { |
| throw IndexOutOfBoundsException(__FILE__, __LINE__, |
| "size parameter out of Bounds: %d.", size); |
| } |
| |
| if (offset > size || offset < 0) { |
| throw IndexOutOfBoundsException(__FILE__, __LINE__, |
| "offset parameter out of Bounds: %d.", offset); |
| } |
| |
| if (length < 0 || length > size - offset) { |
| throw IndexOutOfBoundsException(__FILE__, __LINE__, |
| "length parameter out of Bounds: %d.", length); |
| } |
| |
| #ifdef HAVE_OPENSSL |
| |
| if (!this->data->handshakeCompleted) { |
| this->startHandshake(); |
| } |
| |
| int remaining = length; |
| |
| while (remaining > 0 && !isClosed()) { |
| |
| int written = SSL_write(this->parameters->getSSL(), buffer + offset, remaining); |
| |
| switch (SSL_get_error(this->parameters->getSSL(), written)) { |
| case SSL_ERROR_NONE: |
| offset += written; |
| remaining -= written; |
| break; |
| case SSL_ERROR_ZERO_RETURN: |
| throw SocketException(__FILE__, __LINE__, "The connection was broken unexpectedly."); |
| default: |
| throw OpenSSLSocketException(__FILE__, __LINE__); |
| } |
| } |
| #else |
| throw SocketException( __FILE__, __LINE__, "Not Supported" ); |
| #endif |
| } |
| DECAF_CATCH_RETHROW(IOException) |
| DECAF_CATCH_RETHROW(NullPointerException) |
| DECAF_CATCH_RETHROW(IndexOutOfBoundsException) |
| DECAF_CATCHALL_THROW(IOException) |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| int OpenSSLSocket::available() { |
| |
| try { |
| |
| #ifdef HAVE_OPENSSL |
| if (!isClosed()) { |
| return SSL_pending(this->parameters->getSSL()); |
| } |
| #else |
| throw SocketException( __FILE__, __LINE__, "Not Supported" ); |
| #endif |
| |
| return -1; |
| } |
| DECAF_CATCH_RETHROW(IOException) |
| DECAF_CATCHALL_THROW(IOException) |
| } |