blob: d40b0e93734760721a3aeeef611038b7c960a757 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <boost/algorithm/string.hpp>
#include <boost/asio.hpp>
#include <boost/assign.hpp>
#include <boost/bind.hpp>
#include "drill/drillc.hpp"
#include "drill/drillError.hpp"
#include "clientlib/channel.hpp"
#include "clientlib/drillClientImpl.hpp"
#include "clientlib/errmsgs.hpp"
#include "clientlib/logger.hpp"
#include "clientlib/rpcMessage.hpp"
#include "clientlib/utils.hpp"
#include "protobuf/GeneralRPC.pb.h"
#include "protobuf/UserBitShared.pb.h"
namespace Drill {
class DrillTestClient {
public:
DrillTestClient(Channel* pChannel):
m_handshakeStatus(exec::user::SUCCESS),
m_wbuf(MAX_SOCK_RD_BUFSIZE),
m_rbuf(0){
m_pChannel=pChannel;
m_pError=NULL;
m_coordinationId=Utils::s_randomNumber()%1729+1;
}
connectionStatus_t recvHandshake(){
if(m_rbuf==NULL){
m_rbuf = Utils::allocateBuffer(MAX_SOCK_RD_BUFSIZE);
}
m_pChannel->getIOService().reset();
m_pChannel->getSocketStream().asyncRead(
boost::asio::buffer(m_rbuf, LEN_PREFIX_BUFLEN),
boost::bind(
&DrillTestClient::handleHandshake,
this,
m_rbuf,
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)
);
DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "DrillClientImpl::recvHandshake: async read waiting for server handshake response.\n";)
m_pChannel->getIOService().run();
if(m_rbuf!=NULL){
Utils::freeBuffer(m_rbuf, MAX_SOCK_RD_BUFSIZE); m_rbuf=NULL;
}
if (m_pError != NULL) {
DRILL_MT_LOG(DRILL_LOG(LOG_ERROR) << "DrillClientImpl::recvHandshake: failed to complete handshake with server."
<< m_pError->msg << "\n";)
return static_cast<connectionStatus_t>(m_pError->status);
}
return CONN_SUCCESS;
}
void doReadFromSocket(ByteBuf_t inBuf, size_t bytesToRead, boost::system::error_code& errorCode) {
// Check if bytesToRead is zero
if(0 == bytesToRead) {
return;
}
// Read all the bytes. In case when all the bytes were not read the proper
// errorCode will be set.
while(1){
size_t dataBytesRead = m_pChannel->getSocketStream().readSome(boost::asio::buffer(inBuf, bytesToRead), errorCode);
// Update the state
bytesToRead -= dataBytesRead;
inBuf += dataBytesRead;
// Check if errorCode is EINTR then just retry otherwise break from loop
if(EINTR != errorCode.value()) break;
// Check if all the data is read then break from loop
if(0 == bytesToRead) break;
}
}
void handleHandshake(ByteBuf_t inBuf,
const boost::system::error_code& err,
size_t bytes_transferred) {
boost::system::error_code error=err;
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Deadline timer cancelled." << std::endl;)
if(!error){
rpc::InBoundRpcMessage msg;
uint32_t length = 0;
std::size_t bytes_read = rpc::lengthDecode(m_rbuf, length);
if(length>0){
const size_t leftover = LEN_PREFIX_BUFLEN - bytes_read;
const ByteBuf_t b = m_rbuf + LEN_PREFIX_BUFLEN;
const size_t bytesToRead=length - leftover;
doReadFromSocket(b, bytesToRead, error);
// Check if any error happen while reading the message bytes. If yes then return before decoding the Msg
if(error) {
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. "
<< " Failed to read entire handshake message. with error: "
<< error.message().c_str() << "\n";)
handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Failed to read entire handshake message"));
return;
}
// Decode the bytes into a valid RPC Message
if (!decode(m_rbuf+bytes_read, length, msg)) {
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. Cannot decode handshake.\n";)
handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "Cannot decode handshake"));
return;
}
}else{
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "DrillClientImpl::handleHandshake: ERR_CONN_RDFAIL. No handshake.\n";)
handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, "No handshake"));
return;
}
exec::user::BitToUserHandshake b2u;
b2u.ParseFromArray(msg.m_pbody.data(), msg.m_pbody.size());
this->m_handshakeErrorId=b2u.errorid();
this->m_handshakeErrorMsg=b2u.errormessage();
}else{
// boost error
if(error==boost::asio::error::eof){ // Server broke off the connection
handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_NOHSHAKE, DRILL_RPC_VERSION));
}else{
handleConnError(CONN_FAILURE, getMessage(ERR_CONN_RDFAIL, error.message().c_str()));
}
return;
}
return;
}
connectionStatus_t handleConnError(connectionStatus_t status, const std::string& msg){
DrillClientError* pErr = new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg);
if(m_pError!=NULL){ delete m_pError; m_pError=NULL;}
m_pError=pErr;
return status;
}
connectionStatus_t sendSyncCommon(rpc::OutBoundRpcMessage& msg) {
encode(m_wbuf, msg);
boost::system::error_code ec;
doWriteToSocket(reinterpret_cast<char*>(m_wbuf.data()), m_wbuf.size(), ec);
if(!ec) {
return CONN_SUCCESS;
} else {
return handleConnError(CONN_FAILURE, getMessage(ERR_CONN_WFAIL, ec.message().c_str()));
}
}
void doWriteToSocket(const char* dataPtr, size_t bytesToWrite,
boost::system::error_code& errorCode) {
if(0 == bytesToWrite) {
return;
}
// Write all the bytes to socket. In case of error when all bytes are not successfully written
// proper errorCode will be set.
while(1) {
size_t bytesWritten = m_pChannel->getSocketStream().writeSome(boost::asio::buffer(dataPtr, bytesToWrite), errorCode);
// Update the state
bytesToWrite -= bytesWritten;
dataPtr += bytesWritten;
if(EINTR != errorCode.value()) break;
// Check if all the data is written then break from loop
if(0 == bytesToWrite) break;
}
}
connectionStatus_t validateHandshake(DrillUserProperties* properties){
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "validateHandShake\n";)
exec::user::UserToBitHandshake u2b;
u2b.set_channel(exec::shared::USER);
u2b.set_rpc_version(DRILL_RPC_VERSION);
u2b.set_support_listening(true);
u2b.set_support_timeout(DrillClientConfig::getHeartbeatFrequency() > 0);
u2b.set_sasl_support(exec::user::SASL_PRIVACY);
// Adding version info
exec::user::RpcEndpointInfos* infos = u2b.mutable_client_infos();
infos->set_name(DrillClientConfig::getClientName());
infos->set_application(DrillClientConfig::getApplicationName());
infos->set_version(DRILL_VERSION_STRING);
infos->set_majorversion(DRILL_VERSION_MAJOR);
infos->set_minorversion(DRILL_VERSION_MINOR);
infos->set_patchversion(DRILL_VERSION_PATCH);
if(properties != NULL && properties->size()>0){
std::string username;
std::string err;
if(!properties->validate(err)){
DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Invalid user input:" << err << std::endl;)
}
exec::user::UserProperties* userProperties = u2b.mutable_properties();
std::map<char,int>::iterator it;
for (std::map<std::string,std::string>::const_iterator propIter=properties->begin(); propIter!=properties->end(); ++propIter){
std::string currKey=propIter->first;
std::string currVal=propIter->second;
std::map<std::string,uint32_t>::const_iterator it=DrillUserProperties::USER_PROPERTIES.find(currKey);
if(it==DrillUserProperties::USER_PROPERTIES.end()){
DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << "Connection property ("<< currKey
<< ") is unknown" << std::endl;)
exec::user::Property* connProp = userProperties->add_properties();
connProp->set_key(currKey);
connProp->set_value(currVal);
continue;
}
if(IS_BITSET((*it).second,USERPROP_FLAGS_SERVERPROP)){
exec::user::Property* connProp = userProperties->add_properties();
connProp->set_key(currKey);
connProp->set_value(currVal);
//Username(but not the password) also needs to be set in UserCredentials
if(IS_BITSET((*it).second,USERPROP_FLAGS_USERNAME)){
exec::shared::UserCredentials* creds = u2b.mutable_credentials();
username=currVal;
creds->set_user_name(username);
//u2b.set_credentials(&creds);
}
if(IS_BITSET((*it).second,USERPROP_FLAGS_PASSWORD)){
DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << currKey << ": ********** " << std::endl;)
}else{
DRILL_MT_LOG(DRILL_LOG(LOG_INFO) << currKey << ":" << currVal << std::endl;)
}
}// Server properties
}
}
{
boost::lock_guard<boost::mutex> lock(this->m_dcMutex);
uint64_t coordId = ++m_coordinationId;
rpc::OutBoundRpcMessage out_msg(exec::rpc::REQUEST, exec::user::HANDSHAKE, coordId, &u2b);
sendSyncCommon(out_msg);
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Sent handshake request message. Coordination id: " << coordId << "\n";)
}
connectionStatus_t ret = recvHandshake();
if(ret!=CONN_SUCCESS){
return ret;
}
switch(this->m_handshakeStatus) {
case exec::user::SUCCESS:
// reset io_service after handshake is validated before running queries
m_pChannel->getIOService().reset();
return CONN_SUCCESS;
case exec::user::RPC_VERSION_MISMATCH:
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Invalid rpc version. Expected "
<< DRILL_RPC_VERSION << ", actual "<< 0 << "." << std::endl;)
return handleConnError(CONN_BAD_RPC_VER, getMessage(ERR_CONN_BAD_RPC_VER, DRILL_RPC_VERSION,
0,
this->m_handshakeErrorId.c_str(),
this->m_handshakeErrorMsg.c_str()));
case exec::user::AUTH_FAILED:
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Authentication failed." << std::endl;)
return handleConnError(CONN_AUTH_FAILED, getMessage(ERR_CONN_AUTHFAIL,
this->m_handshakeErrorId.c_str(),
this->m_handshakeErrorMsg.c_str()));
case exec::user::UNKNOWN_FAILURE:
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown error during handshake." << std::endl;)
return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
this->m_handshakeErrorId.c_str(),
this->m_handshakeErrorMsg.c_str()));
case exec::user::AUTH_REQUIRED:
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Server requires SASL authentication." << std::endl;)
return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
this->m_handshakeErrorId.c_str(),
this->m_handshakeErrorMsg.c_str()));
default:
DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Unknown return status." << std::endl;)
return handleConnError(CONN_HANDSHAKE_FAILED, getMessage(ERR_CONN_UNKNOWN_ERR,
this->m_handshakeErrorId.c_str(),
this->m_handshakeErrorMsg.c_str()));
}
}
DrillClientError* m_pError;
private:
Channel* m_pChannel;
int32_t m_coordinationId;
std::string m_handshakeErrorId;
std::string m_handshakeErrorMsg;
exec::user::HandshakeStatus m_handshakeStatus;
DataBuf m_wbuf;
ByteBuf_t m_rbuf;
boost::mutex m_dcMutex;
};
} // namespace Drill
using namespace Drill;
int main(int argc, char* argv[]){
Channel *pChannel = NULL;
ChannelContext *pChannelContext = NULL;
std::string connectStr = "zk=localhost:2181/drill/drillbits1";
//std::string connectStr = "drillbit=localhost:31090";
channelType_t type;
boost::asio::io_service ioService;
bool isSSL = argc==2 && !(strcmp(argv[1], "ssl"));
type = CHANNEL_TYPE_SOCKET;
if(isSSL){
type = CHANNEL_TYPE_SSLSTREAM;
}
Drill::DrillUserProperties props;
props.setProperty(USERPROP_USERNAME, "admin");
props.setProperty(USERPROP_PASSWORD, "admin");
props.setProperty(USERPROP_CERTFILEPATH, "../../../test/ssl/drillTestCert.pem");
pChannel = ChannelFactory::getChannel(type, ioService, connectStr.c_str(), &props);
if(pChannel != NULL){
connectionStatus_t connStat;
connStat = pChannel->init();
if(connStat != CONN_SUCCESS){
std::cout << "Init Failed." << std::endl;
return -1;
}
connStat = pChannel->connect();
if(connStat != CONN_SUCCESS){
std::cout << "Connect Failed." << std::endl;
std::cout << pChannel->getError()->msg << std::endl;
return -1;
}
} else{
std::cout << "Channel creation failed." << std::endl;
return -1;
}
std::cout << "Connected." << std::endl;
std::cout << "Starting Drill handshake" << std::endl;
DrillTestClient client(pChannel);
connectionStatus_t stat = client.validateHandshake(&props);
if(stat == CONN_SUCCESS){
std::cout << "Handshake validated." << std::endl;
} else{
if(client.m_pError != NULL){
std::cout << "Handshake failed: " << client.m_pError->msg << ". " << std::endl;
} else{
std::cout << "Handshake failed with unknown error" << ". " << std::endl;
}
}
return 0;
}