blob: f7dfe3e14a7d1a526fa8dbc25b823e2233251515 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CHANNEL_HPP
#define CHANNEL_HPP
#include "drill/common.hpp"
#include "drill/drillClient.hpp"
#include "streamSocket.hpp"
#include "errmsgs.hpp"
#if defined(IS_SSL_ENABLED)
#include <openssl/ssl.h>
#include <openssl/err.h>
#endif
namespace Drill {
class UserProperties;
/// @brief Holds a connection string together with the protocol, host and port
/// that a channel connects to.
///
/// Only the inline accessors are defined here; constructors, destructor and the
/// parsing/lookup routines are implemented outside this header.
class ConnectionEndpoint {
    public:
        /// @param connStr The connection string.
        ConnectionEndpoint(const char* connStr);

        /// @param host The host to connect to.
        /// @param port The port to connect to.
        ConnectionEndpoint(const char* host, const char* port);

        ~ConnectionEndpoint();

        /// @brief Parse the connection string and set up the host and port to
        /// connect to.
        ///
        /// @return The connection status of the operation.
        connectionStatus_t getDrillbitEndpoint();

        /// @return The stored protocol string.
        const std::string& getProtocol() const {
            return m_protocol;
        }

        /// @return The stored host string.
        const std::string& getHost() const {
            return m_host;
        }

        /// @return The stored port string.
        const std::string& getPort() const {
            return m_port;
        }

        /// @return The stored error pointer (m_pError), returned as is.
        DrillClientError* getError() {
            return m_pError;
        }

    private:
        // Helpers implemented outside this header.
        void parseConnectString();
        bool isDirectConnection();
        bool isZookeeperConnection();
        connectionStatus_t getDrillbitEndpointFromZk();
        connectionStatus_t handleError(connectionStatus_t status, std::string msg);

        // NOTE: member order is significant for construction; keep as is.
        std::string m_connectString;
        std::string m_pathToDrill;
        std::string m_protocol;
        std::string m_hostPortStr;
        std::string m_host;
        std::string m_port;
        DrillClientError* m_pError;
};
/// @brief Base class for per-channel context; carries the user properties
/// pointer handed in at construction.
class ChannelContext {
    public:
        /// @param props The user properties. The pointer is stored as given.
        ChannelContext(DrillUserProperties* props)
            : m_properties(props)
        {
        }

        // Virtual so derived contexts are destroyed correctly through a base pointer.
        virtual ~ChannelContext() {
        }

        /// @return The user properties pointer supplied to the constructor.
        const DrillUserProperties* getUserProperties() const {
            return m_properties;
        }

    protected:
        // Not deleted by this class (the destructor above is empty).
        DrillUserProperties* m_properties;
};
/// @brief Channel context for SSL/TLS connections.
///
/// Owns the boost::asio SSL context and keeps two pieces of state used during
/// certificate verification: an optional hostname override and the result of
/// the most recent hostname verification.
class SSLChannelContext: public ChannelContext{
    public:
        /// @brief Maps a TLS version string to the boost::asio context method.
        ///
        /// @param version One of "tlsv1", "tlsv11", "tlsv12", "tlsv1+", "tlsv11+",
        ///                "tlsv12+". Any other value falls back to TLS 1.2.
        ///
        /// @return The boost::asio SSL context method to construct the context with.
        static boost::asio::ssl::context::method getTlsVersion(const std::string & version){
            if (version == "tlsv12") {
                return boost::asio::ssl::context::tlsv12;
            } else if (version == "tlsv11") {
                return boost::asio::ssl::context::tlsv11;
            } else if (version == "tlsv1") {
                return boost::asio::ssl::context::tlsv1;
            } else if ((version == "tlsv1+") || (version == "tlsv11+") || (version == "tlsv12+")) {
                // SSLv2 and SSLv3 are disabled, so this is the equivalent of 'tls' only mode.
                // In boost version 1.64+, they've added support for context::tls; method.
                return boost::asio::ssl::context::sslv23;
            } else {
                // Unrecognized version string: default to TLS 1.2.
                return boost::asio::ssl::context::tlsv12;
            }
        }

        /// @brief Applies Minimum TLS protocol restrictions.
        /// tlsv11+ means restrict to TLS version 1.1 and higher.
        /// tlsv12+ means restrict to TLS version 1.2 and higher.
        /// Please note that SSL_OP_NO_TLSv tags are deprecated in openSSL 1.1.0.
        ///
        /// @param in_ver The protocol version.
        ///
        /// @return The SSL context options. For any other version string (or when
        ///         SSL support is compiled out) only SSL_OP_NO_SSLv3 is returned.
        static long ApplyMinTLSRestriction(const std::string & in_ver){
#if defined(IS_SSL_ENABLED)
            if (in_ver == "tlsv11+") {
                return SSL_OP_NO_TLSv1;
            } else if (in_ver == "tlsv12+") {
                return (SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1);
            }
#endif
            return SSL_OP_NO_SSLv3;
        }

        /// @brief Builds the SSL context and applies the baseline options.
        ///
        /// @param props               The user properties (stored by the base class).
        /// @param tlsVersion          The context method, see getTlsVersion().
        /// @param verifyMode          The peer verification mode to set on the context.
        /// @param hostnameOverride    Hostname to verify against; empty means no override.
        /// @param customSSLCtxOptions Extra SSL context options OR-ed into the defaults,
        ///                            e.g. the result of ApplyMinTLSRestriction().
        SSLChannelContext(DrillUserProperties *props,
                boost::asio::ssl::context::method tlsVersion,
                boost::asio::ssl::verify_mode verifyMode,
                const std::string& hostnameOverride,
                const long customSSLCtxOptions = 0) :
            ChannelContext(props),
            m_SSLContext(tlsVersion),
            m_hostnameOverride(hostnameOverride),
            m_certHostnameVerificationStatus(true)
        {
            m_SSLContext.set_default_verify_paths();
            // SSLv2/SSLv3 are always disabled; callers may only tighten this further.
            m_SSLContext.set_options(
                    boost::asio::ssl::context::default_workarounds
                    | boost::asio::ssl::context::no_sslv2
                    | boost::asio::ssl::context::no_sslv3
                    | boost::asio::ssl::context::single_dh_use
                    | customSSLCtxOptions
                    );
            m_SSLContext.set_verify_mode(verifyMode);
        }

        ~SSLChannelContext(){}

        /// @return The underlying boost::asio SSL context.
        boost::asio::ssl::context& getSslContext(){ return m_SSLContext;}

        /// @brief Check the certificate host name verification status.
        ///
        /// @return FALSE if the verification has failed, TRUE otherwise.
        bool GetCertificateHostnameVerificationStatus() const { return m_certHostnameVerificationStatus; }

        /// @brief Set the certificate host name verification status.
        ///
        /// @param in_result The host name verification status.
        void SetCertHostnameVerificationStatus(bool in_result) { m_certHostnameVerificationStatus = in_result; }

        /// @brief Returns the overridden hostname used for certificate verification
        ///
        /// @return the hostname override, or empty if the hostname should not be overridden.
        const std::string& GetHostnameOverride() const { return m_hostnameOverride; }

    private:
        boost::asio::ssl::context m_SSLContext;

        // The hostname to verify. Unused if empty.
        std::string m_hostnameOverride;

        // The flag to indicate the host name verification result.
        bool m_certHostnameVerificationStatus;
};
// Short aliases for the context types; the channel classes in this header
// refer to the contexts through these names.
typedef ChannelContext ChannelContext_t;
typedef SSLChannelContext SSLChannelContext_t;
/***
* The Channel class encapsulates a connection to a drillbit. Based on
* the connection string and the options, the connection will be either
* a simple socket or a socket using an ssl stream. The class also encapsulates
* connecting to a drillbit directly or through ZooKeeper.
* The channel class owns the socket and the io_service that the applications
* will use to communicate with the server.
***/
/// @brief Abstract base for a connection to a drillbit.
///
/// Concrete channels implement init(); the remaining connection steps are
/// declared here and implemented outside this header. Derived classes can
/// customize handshake error reporting and pre-handshake socket setup through
/// the protected virtual hooks.
class Channel{
    friend class ChannelFactory;
    public:
        /// @param ioService The io_service to use for this channel.
        /// @param connStr   The connection string.
        Channel(boost::asio::io_service& ioService, const char* connStr);

        /// @param ioService The io_service to use for this channel.
        /// @param host      The host to connect to.
        /// @param port      The port to connect to.
        Channel(boost::asio::io_service& ioService, const char* host, const char* port);

        virtual ~Channel();

        /// @brief Channel-specific initialization, supplied by each concrete channel.
        virtual connectionStatus_t init()=0;

        connectionStatus_t connect();

        /// @return true only while the channel state is CHANNEL_CONNECTED.
        bool isConnected() const { return m_state == CHANNEL_CONNECTED; }

        template <typename SettableSocketOption> void setOption(SettableSocketOption& option);

        /// @return The stored error pointer (m_pError), returned as is.
        DrillClientError* getError(){ return m_pError;}

        /// @brief Closes the protocol layer if the channel was initialized or connected.
        ///
        /// Calling this in any other state is a no-op.
        void close(){
            if(m_state==CHANNEL_INITIALIZED||m_state==CHANNEL_CONNECTED){
                m_pSocket->protocolClose();
                m_state=CHANNEL_CLOSED;
            }
        } // Not OK to use the channel after this call.

        boost::asio::io_service& getIOService(){
            return m_ioService;
        }

        // returns a reference to the underlying socket
        // This access should really be removed and encapsulated in calls that
        // manage async_send and async_recv
        // Until then we will let DrillClientImpl have direct access
        streamSocket_t& getInnerSocket(){
            return m_pSocket->getInnerSocket();
        }

        AsioStreamSocket& getSocketStream(){
            return *m_pSocket;
        }

        ConnectionEndpoint* getEndpoint(){return m_pEndpoint;}

        ChannelContext_t* getChannelContext(){ return m_pContext; }

    protected:
        connectionStatus_t handleError(connectionStatus_t status, std::string msg);

        /// @brief Handle protocol handshake exceptions.
        ///
        /// @param in_err The error.
        ///
        /// @return the connectionStatus.
        virtual connectionStatus_t HandleProtocolHandshakeException(const boost::system::system_error& in_err){
            return handleError(CONN_HANDSHAKE_FAILED, in_err.what());
        }

        /// @brief Hook for channel-specific socket setup; the base version does nothing.
        ///
        /// @return CONN_SUCCESS.
        virtual connectionStatus_t setSocketInformation() {
            return CONN_SUCCESS;
        }

        // NOTE: member order is significant for construction; keep as is.
        boost::asio::io_service& m_ioService;
        boost::asio::io_service m_ioServiceFallback; // used if m_ioService is not provided
        AsioStreamSocket* m_pSocket;
        ConnectionEndpoint *m_pEndpoint;
        ChannelContext_t *m_pContext;

    private:
        typedef enum channelState{
            CHANNEL_UNINITIALIZED=1,
            CHANNEL_INITIALIZED,
            CHANNEL_CONNECTED,
            CHANNEL_CLOSED
        } channelState_t;

        connectionStatus_t connectInternal();

        /// @brief Runs the socket's protocol handshake and converts a thrown
        /// boost::system::system_error into a connection status.
        ///
        /// @param useSystemConfig Forwarded unchanged to the socket's protocolHandshake().
        ///
        /// @return CONN_SUCCESS, or the status produced by
        ///         HandleProtocolHandshakeException() if the handshake threw.
        connectionStatus_t protocolHandshake(bool useSystemConfig){
            connectionStatus_t status = CONN_SUCCESS;
            try{
                m_pSocket->protocolHandshake(useSystemConfig);
            } catch (const boost::system::system_error& e) {
                // Caught by const reference: avoids copying the exception and
                // slicing any type derived from system_error.
                status = HandleProtocolHandshakeException(e);
            }
            return status;
        }

        channelState_t m_state;
        DrillClientError* m_pError;
};
/// @brief Concrete Channel that supplies its own init(); both constructors just
/// forward their arguments to Channel.
class SocketChannel : public Channel {
    public:
        /// @param ioService The io_service to use for this channel.
        /// @param connStr   The connection string.
        SocketChannel(boost::asio::io_service& ioService, const char* connStr)
            : Channel(ioService, connStr)
        {
        }

        /// @param ioService The io_service to use for this channel.
        /// @param host      The host to connect to.
        /// @param port      The port to connect to.
        SocketChannel(boost::asio::io_service& ioService, const char* host, const char* port)
            : Channel(ioService, host, port)
        {
        }

        /// @brief Implements Channel::init(); defined outside this header.
        connectionStatus_t init();
};
/// @brief Concrete Channel for SSL connections.
///
/// When SSL support is compiled in, it overrides the handshake error hook to
/// produce SSL-specific messages and sets the SNI host name on the SSL handle.
/// The overrides treat m_pContext as an SSLChannelContext_t.
class SSLStreamChannel: public Channel{
    public:
        /// @param ioService The io_service to use for this channel.
        /// @param connStr   The connection string.
        SSLStreamChannel(boost::asio::io_service& ioService, const char* connStr)
            :Channel(ioService, connStr){
        }

        /// @param ioService The io_service to use for this channel.
        /// @param host      The host to connect to.
        /// @param port      The port to connect to.
        SSLStreamChannel(boost::asio::io_service& ioService, const char* host, const char* port)
            :Channel(ioService, host, port){
        }

        /// @brief Implements Channel::init(); defined outside this header.
        connectionStatus_t init();

    protected:
#if defined(IS_SSL_ENABLED)
        /// @brief Handle protocol handshake exceptions for SSL specific failures.
        ///
        /// The checks run in priority order: a recorded hostname verification
        /// failure wins, then certificate verification failure, then unsupported
        /// protocol; everything else is reported as a general SSL error.
        ///
        /// @param in_err The error.
        ///
        /// @return the connectionStatus.
        connectionStatus_t HandleProtocolHandshakeException(const boost::system::system_error& in_err) {
            const boost::system::error_code& errcode = in_err.code();
            SSLChannelContext_t* sslContext = static_cast<SSLChannelContext_t*>(m_pContext);

            if (!sslContext->GetCertificateHostnameVerificationStatus()){
                return handleError(
                        CONN_HANDSHAKE_FAILED,
                        getMessage(ERR_CONN_SSL_CN, in_err.what()));
            }

            // The OpenSSL reason code is only meaningful for errors in the SSL category.
            const bool isSslError = (boost::asio::error::get_ssl_category() == errcode.category());
            if (isSslError && SSL_R_CERTIFICATE_VERIFY_FAILED == ERR_GET_REASON(errcode.value())){
                return handleError(
                        CONN_HANDSHAKE_FAILED,
                        getMessage(ERR_CONN_SSL_CERTVERIFY, in_err.what()));
            }
            if (isSslError && SSL_R_UNSUPPORTED_PROTOCOL == ERR_GET_REASON(errcode.value())){
                return handleError(
                        CONN_HANDSHAKE_FAILED,
                        getMessage(ERR_CONN_SSL_PROTOVER, in_err.what()));
            }
            return handleError(
                    CONN_HANDSHAKE_FAILED,
                    getMessage(ERR_CONN_SSL_GENERAL, in_err.what()));
        }

        /// @brief Sets the TLS SNI host name on the SSL handle.
        ///
        /// Uses the context's hostname override when it is non-empty, otherwise
        /// the endpoint's host.
        ///
        /// @return CONN_SUCCESS, or the status from handleError(CONN_SSLERROR, ...)
        ///         if the SNI name could not be set.
        connectionStatus_t setSocketInformation() {
            SSLChannelContext_t& context = *static_cast<SSLChannelContext_t*>(m_pContext);

            const char* sniProperty;
            if (!context.GetHostnameOverride().empty()){
                sniProperty = context.GetHostnameOverride().c_str();
            }
            else{
                sniProperty = m_pEndpoint->getHost().c_str();
            }

            if (!SSL_set_tlsext_host_name(((SslSocket *)m_pSocket)->getSocketStream().native_handle(), sniProperty)) {
                // ERR_func_error_string() may return NULL (and is deprecated in
                // OpenSSL 3.0); handing NULL to the message formatter is unsafe.
                // ERR_error_string_n() always writes a NUL-terminated description.
                char errDetails[256];
                ERR_error_string_n(ERR_get_error(), errDetails, sizeof(errDetails));
                const char* details = errDetails;
                return handleError(CONN_SSLERROR, getMessage(ERR_CONN_SSL_SNI, sniProperty, details));
            }
            return CONN_SUCCESS;
        }
#endif
};
/// @brief Static factory for Channel objects. Declared a friend of Channel.
///
/// All members are declared here and implemented outside this header.
class ChannelFactory {
    public:
        /// @param t         The kind of channel requested.
        /// @param ioService The io_service for the channel.
        /// @param connStr   The connection string.
        /// @param props     The user properties.
        ///
        /// @return A pointer to the channel.
        static Channel* getChannel(
                channelType_t t,
                boost::asio::io_service& ioService,
                const char* connStr,
                DrillUserProperties* props);

        /// @param t         The kind of channel requested.
        /// @param ioService The io_service for the channel.
        /// @param host      The host to connect to.
        /// @param port      The port to connect to.
        /// @param props     The user properties.
        ///
        /// @return A pointer to the channel.
        static Channel* getChannel(
                channelType_t t,
                boost::asio::io_service& ioService,
                const char* host,
                const char* port,
                DrillUserProperties* props);

    private:
        /// @param t     The kind of channel the context is for.
        /// @param props The user properties.
        ///
        /// @return A pointer to the channel context.
        static ChannelContext_t* getChannelContext(
                channelType_t t,
                DrillUserProperties* props);
};
/// @brief Hostname verification callback wrapper.
class DrillSSLHostnameVerifier{
public:
/// @brief The constructor.
///
/// @param in_channel The Channel.
DrillSSLHostnameVerifier(Channel* in_channel) : m_channel(in_channel){
DRILL_LOG(LOG_INFO)
<< "DrillSSLHostnameVerifier::DrillSSLHostnameVerifier: +++++ Enter +++++"
<< std::endl;
}
/// @brief Perform certificate verification.
///
/// @param in_preverified Pre-verified indicator.
/// @param in_ctx Verify context.
bool operator()(
bool in_preverified,
boost::asio::ssl::verify_context& in_ctx){
DRILL_LOG(LOG_INFO) << "DrillSSLHostnameVerifier::operator(): +++++ Enter +++++" << std::endl;
// Gets the channel context.
SSLChannelContext_t* context = (SSLChannelContext_t*)(m_channel->getChannelContext());
const char* hostname;
if (context->GetHostnameOverride().empty()) {
// Retrieve the host before we perform Host name verification.
// This is because host with ZK mode is selected after the connect() function is called.
hostname = m_channel->getEndpoint()->getHost().c_str();
} else {
hostname = context->GetHostnameOverride().c_str();
}
boost::asio::ssl::rfc2818_verification verifier(hostname);
// Perform verification.
bool verified = verifier(in_preverified, in_ctx);
DRILL_LOG(LOG_DEBUG)
<< "DrillSSLHostnameVerifier::operator(): Verification Result: "
<< verified
<< std::endl;
// Sets the result back to the context.
context->SetCertHostnameVerificationStatus(verified);
return verified;
}
private:
// The SSL channel.
Channel* m_channel;
};
} // namespace Drill
#endif // CHANNEL_HPP