blob: 368f68067485af70884e4b46deacc8d0ab1670c1 [file] [log] [blame]
// This file will be removed when the code is accepted into the Thrift library.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef IMPALA_TRANSPORT_TSSLTRANSPORT_H
#define IMPALA_TRANSPORT_TSSLTRANSPORT_H
#include <cstring>
#include <string>

#include <boost/scoped_array.hpp>
#include <boost/scoped_ptr.hpp>
#include <boost/shared_ptr.hpp>
#include <thrift/transport/TBufferTransports.h>
#include <thrift/transport/TTransport.h>
#include <thrift/transport/TVirtualTransport.h>

#include "transport/TSasl.h"
namespace apache { namespace thrift { namespace transport {
/**
 * Status codes carried by the SASL negotiation messages of TSaslTransport.
 * A status is written in front of each message's payload by sendSaslMessage().
 * The positive values presumably match the NegotiationStatus codes used by
 * Thrift's other SASL transport implementations, so they must not be
 * renumbered — confirm before changing any of them.
 */
enum NegotiationStatus {
  TSASL_INVALID  = -1,  // NOTE(review): presumably an "unset" sentinel, not a protocol status.
  TSASL_START    = 1,
  TSASL_OK       = 2,
  TSASL_BAD      = 3,   // receiveSaslMessage() is documented to throw on BAD ...
  TSASL_ERROR    = 4,   // ... and on ERROR.
  TSASL_COMPLETE = 5
};
// Field widths, in bytes, used when framing SASL negotiation messages.
// (A namespace-scope const int already has internal linkage, so each
// translation unit that includes this header gets its own copy.)

// NOTE(review): presumably the width of the mechanism-name length field
// in the START message — confirm against the .cpp.
const int MECHANISM_NAME_BYTES = 1;
// Width of the status code that precedes every payload.
const int STATUS_BYTES = 1;
// Width of the payload length field.
const int PAYLOAD_LENGTH_BYTES = 4;
// Size of the fixed message header: length field plus status code.
const int HEADER_LENGTH = PAYLOAD_LENGTH_BYTES + STATUS_BYTES;
/**
 * This transport implements the Simple Authentication and Security Layer (SASL),
 * see http://www.ietf.org/rfc/rfc2222.txt. It is based on, and depends on the
 * presence of, the cyrus-sasl library.
 *
 * The class is abstract: subclasses supply the client- or server-specific parts
 * of the negotiation (setupSaslNegotiationState(), resetSaslNegotiationState()
 * and handleSaslStartMessage()).
 */
class TSaslTransport : public TVirtualTransport<TSaslTransport> {
 public:
  /**
   * Constructs a new TSaslTransport to act as a server.
   * SetSaslServer must be called later to initialize the SASL endpoint underlying
   * this transport.
   *
   * @param transport The underlying transport that carries the SASL messages.
   */
  explicit TSaslTransport(boost::shared_ptr<TTransport> transport);

  /**
   * Constructs a new TSaslTransport to act as a client.
   *
   * @param saslClient The SASL implementation, already initialized as a client.
   * @param transport The underlying transport that carries the SASL messages.
   */
  TSaslTransport(boost::shared_ptr<sasl::TSasl> saslClient,
                 boost::shared_ptr<TTransport> transport);

  /**
   * Destroys the transport, releasing its reference to the TSasl object.
   */
  virtual ~TSaslTransport();

  /**
   * Whether this transport is open.
   */
  virtual bool isOpen();

  /**
   * Tests whether there is more data to read or if the remote side is
   * still open.
   */
  virtual bool peek();

  /**
   * Opens the transport for communications.
   *
   * @throws TTransportException if opening failed
   */
  virtual void open();

  /**
   * Closes the transport.
   */
  virtual void close();

  /**
   * Attempts to read up to the specified number of bytes into a buffer.
   *
   * @param buf Pointer to the location to write the data
   * @param len How many bytes to read
   * @return How many bytes were actually read
   * @throws TTransportException If an error occurs
   */
  uint32_t read(uint8_t* buf, uint32_t len);

  /**
   * Writes len bytes from buf, in their entirety, to the write buffer.
   *
   * Note: You must call flush() to ensure the data is actually written,
   * and available to be read back in the future. Destroying a TTransport
   * object does not automatically flush pending data--if you destroy a
   * TTransport object with written but unflushed data, that data may be
   * discarded.
   *
   * @param buf The data to write out
   * @param len Number of bytes of buf to write
   * @throws TTransportException if an error occurs
   */
  void write(const uint8_t* buf, uint32_t len);

  /**
   * Flushes any pending data to be written. Typically used with buffered
   * transport mechanisms.
   *
   * @throws TTransportException if an error occurs
   */
  virtual void flush();

  /**
   * Returns the transport underlying this one.
   */
  boost::shared_ptr<TTransport> getUnderlyingTransport() {
    return transport_;
  }

  /**
   * Returns the username associated with the underlying SASL connection.
   *
   * @throws TTransportException if an error occurs
   */
  std::string getUsername();

 protected:
  /// Underlying transport.
  boost::shared_ptr<TTransport> transport_;

  /// Buffer for reading and writing.
  /// NOTE(review): raw pointer whose allocation and release live in the .cpp
  /// file. This class declares no copy constructor or assignment operator, so
  /// unless a base class forbids copying, a copy would share this pointer.
  /// Confirm ownership in the implementation.
  TMemoryBuffer* memBuf_;

  /// SASL implementation class. It is passed in to the transport constructor
  /// already initialized for either a client or a server.
  boost::shared_ptr<sasl::TSasl> sasl_;

  /// If true, data is wrapped (encrypted) by the SASL layer.
  bool shouldWrap_;

  /// True if this is a client.
  bool isClient_;

  /// Buffer to hold protocol info.
  boost::scoped_array<uint8_t> protoBuf_;

  /**
   * Stores x into buf[offset, offset + 4) in big-endian (network) byte order.
   *
   * std::memcpy is used instead of assigning through a
   * reinterpret_cast<uint32_t*>. buf + offset need not be 4-byte aligned, and
   * accessing a byte buffer through a uint32_t lvalue violates strict aliasing.
   * Both are undefined behavior, and unaligned stores trap on some CPUs.
   */
  void encodeInt(uint32_t x, uint8_t* buf, uint32_t offset) {
    const uint32_t networkOrder = htonl(x);
    std::memcpy(buf + offset, &networkOrder, sizeof(networkOrder));
  }

  /**
   * Loads the big-endian (network byte order) 32-bit integer stored at
   * buf[offset, offset + 4) and returns it in host byte order.
   * std::memcpy avoids the alignment and aliasing problems described on
   * encodeInt().
   */
  uint32_t decodeInt(const uint8_t* buf, uint32_t offset) {
    uint32_t networkOrder;
    std::memcpy(&networkOrder, buf + offset, sizeof(networkOrder));
    return ntohl(networkOrder);
  }

  /**
   * Performs the SASL negotiation.
   */
  void doSaslNegotiation();

  /**
   * Creates the SASL context for a server/client connection.
   */
  virtual void setupSaslNegotiationState() = 0;

  /**
   * Resets the negotiation state.
   */
  virtual void resetSaslNegotiationState() = 0;

  /**
   * Reads a complete Thrift SASL message.
   *
   * @param status Set to the status code of the received message.
   * @param length Set to the length of the received payload.
   * @return The payload of the message. The returned pointer is valid only
   *         until the next call.
   * @throws TTransportException
   *         Thrown if there is a failure reading from the underlying
   *         transport, or if a status code of BAD or ERROR is encountered.
   */
  uint8_t* receiveSaslMessage(NegotiationStatus* status, uint32_t* length);

  /**
   * Sends a message with SASL transport headers. The status is put before the
   * payload.
   *
   * @param status Status code to send.
   * @param payload Message payload.
   * @param length Number of payload bytes.
   * @param flush If false, flushing the underlying transport is delayed so that
   *        the following message can go out in the same packet if necessary.
   */
  void sendSaslMessage(const NegotiationStatus status,
                       const uint8_t* payload, const uint32_t length,
                       bool flush = true);

  /**
   * Reads the length prefix of the next SASL message from the underlying
   * transport. This is the counterpart of writeLength(), which writes the
   * length as 4 bytes.
   *
   * @return The length that was read.
   * @throws TTransportException presumably if reading from the underlying
   *         transport fails (mirroring writeLength()). Confirm in the .cpp.
   */
  uint32_t readLength();

  /**
   * Writes the given integer as 4 bytes to the underlying transport.
   *
   * @param length
   *        The length prefix of the next SASL message to write.
   * @throws TTransportException
   *         Thrown if writing to the underlying transport fails.
   */
  void writeLength(uint32_t length);

  /**
   * Processes the TSASL_START message of the negotiation. Pure virtual:
   * the client and server subclasses each implement their side of it.
   */
  virtual void handleSaslStartMessage() = 0;

  /// If memBuf_ is filled with bytes that are already read, and has crossed a size
  /// threshold (see implementation for exact value), resize the buffer to a default value.
  void shrinkBuffer();
};
}}} // apache::thrift::transport
#endif // #ifndef IMPALA_TRANSPORT_TSSLTRANSPORT_H