blob: 20eaef772818ec541730b5402d38b0a0372d447a [file] [log] [blame]
// This file will be removed when the code is accepted into the Thrift library.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "config.h"
#ifdef HAVE_SASL_SASL_H
#include <stdint.h>
#include <sstream>
#include <boost/shared_ptr.hpp>
#include <boost/scoped_ptr.hpp>
#include <thrift/transport/TBufferTransports.h>
#include "transport/TSaslTransport.h"
#include "common/names.h"
// Default size, in bytes, for the memory buffer used to stage reads.
const int32_t DEFAULT_MEM_BUF_SIZE = 32 * 1024;
namespace apache { namespace thrift { namespace transport {
TSaslTransport::TSaslTransport(boost::shared_ptr<TTransport> transport)
: transport_(transport),
memBuf_(new TMemoryBuffer(DEFAULT_MEM_BUF_SIZE)),
sasl_(NULL),
shouldWrap_(false),
isClient_(false) {
}
TSaslTransport::TSaslTransport(boost::shared_ptr<sasl::TSasl> saslClient,
boost::shared_ptr<TTransport> transport)
: transport_(transport),
memBuf_(new TMemoryBuffer()),
sasl_(saslClient),
shouldWrap_(false),
isClient_(true) {
}
TSaslTransport::~TSaslTransport() {
delete memBuf_;
}
bool TSaslTransport::isOpen() {
return transport_->isOpen();
}
bool TSaslTransport::peek(){
return (transport_->peek());
}
string TSaslTransport::getUsername() {
return sasl_->getUsername();
}
void TSaslTransport::doSaslNegotiation() {
NegotiationStatus status = TSASL_INVALID;
uint32_t resLength;
try {
// Setup Sasl context.
setupSaslNegotiationState();
// Initiate SASL message.
handleSaslStartMessage();
// SASL connection handshake
while (!sasl_->isComplete()) {
uint8_t* message = receiveSaslMessage(&status, &resLength);
if (status == TSASL_COMPLETE) {
if (isClient_) {
if (!sasl_->isComplete()) {
// Server sent COMPLETE out of order.
throw TTransportException("Received COMPLETE but no handshake occurred");
}
break; // handshake complete
}
} else if (status != TSASL_OK) {
stringstream ss;
ss << "Expected COMPLETE or OK, got " << status;
throw TTransportException(ss.str());
}
uint32_t challengeLength;
uint8_t* challenge = sasl_->evaluateChallengeOrResponse(
message, resLength, &challengeLength);
sendSaslMessage(sasl_->isComplete() ? TSASL_COMPLETE : TSASL_OK,
challenge, challengeLength);
}
// If the server isn't complete yet, we need to wait for its response.
// This will occur with ANONYMOUS auth, for example, where we send an
// initial response and are immediately complete.
if (isClient_ && (status == TSASL_INVALID || status == TSASL_OK)) {
receiveSaslMessage(&status, &resLength);
if (status != TSASL_COMPLETE) {
stringstream ss;
ss << "Expected COMPLETE or OK, got " << status;
throw TTransportException(ss.str());
}
}
// TODO : need to set the shouldWrap_ based on QOP
/*
String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP);
if (qop != null && !qop.equalsIgnoreCase("auth"))
shouldWrap_ = true;
*/
} catch (const TException& e) {
// If we hit an exception, that means the Sasl negotiation failed. We explicitly
// reset the negotiation state here since the caller may retry an open() which would
// start a new connection negotiation.
resetSaslNegotiationState();
throw e;
}
}
void TSaslTransport::open() {
// Only client should open the underlying transport.
if (isClient_ && !transport_->isOpen()) {
transport_->open();
}
// Start the SASL negotiation protocol.
doSaslNegotiation();
}
void TSaslTransport::close() {
transport_->close();
}
uint32_t TSaslTransport::readLength() {
uint8_t lenBuf[PAYLOAD_LENGTH_BYTES];
transport_->readAll(lenBuf, PAYLOAD_LENGTH_BYTES);
int32_t len = decodeInt(lenBuf, 0);
if (len < 0) {
throw TTransportException("Frame size has negative value");
}
return static_cast<uint32_t>(len);
}
void TSaslTransport::shrinkBuffer() {
// readEnd() returns the number of bytes already read, i.e. the number of 'junk' bytes
// taking up space at the front of the memory buffer.
uint32_t read_end = memBuf_->readEnd();
// If the size of the junk space at the beginning of the buffer is too large, and
// there's no data left in the buffer to read (number of bytes read == number of bytes
// written), then shrink the buffer back to the default. We don't want to do this on
// every read that exhausts the buffer, since the layer above often reads in small
// chunks, which is why we only resize if there's too much junk. The write and read
// pointers will eventually catch up after every RPC, so we will always take this path
// eventually once the buffer becomes sufficiently full.
//
// readEnd() may reset the write / read pointers (but only once if there's no
// intervening read or write between calls), so needs to be called a second time to
// get their current position.
if (read_end > DEFAULT_MEM_BUF_SIZE && memBuf_->writeEnd() == memBuf_->readEnd()) {
memBuf_->resetBuffer(DEFAULT_MEM_BUF_SIZE);
}
}
uint32_t TSaslTransport::read(uint8_t* buf, uint32_t len) {
uint32_t read_bytes = memBuf_->read(buf, len);
if (read_bytes > 0) {
shrinkBuffer();
return read_bytes;
}
// if there's not enough data in cache, read from underlying transport
uint32_t dataLength = readLength();
// Fast path
if (len == dataLength && !shouldWrap_) {
transport_->readAll(buf, len);
return len;
}
uint8_t* tmpBuf = new uint8_t[dataLength];
transport_->readAll(tmpBuf, dataLength);
if (shouldWrap_) {
tmpBuf = sasl_->unwrap(tmpBuf, 0, dataLength, &dataLength);
}
// We will consume all the data, no need to put it in the memory buffer.
if (len == dataLength) {
memcpy(buf, tmpBuf, len);
delete[] tmpBuf;
return len;
}
memBuf_->write(tmpBuf, dataLength);
memBuf_->flush();
delete[] tmpBuf;
uint32_t ret = memBuf_->read(buf, len);
shrinkBuffer();
return ret;
}
void TSaslTransport::writeLength(uint32_t length) {
uint8_t lenBuf[PAYLOAD_LENGTH_BYTES];
encodeInt(length, lenBuf, 0);
transport_->write(lenBuf, PAYLOAD_LENGTH_BYTES);
}
void TSaslTransport::write(const uint8_t* buf, uint32_t len) {
const uint8_t* newBuf;
if (shouldWrap_) {
newBuf = sasl_->wrap(const_cast<uint8_t*>(buf), 0, len, &len);
} else {
newBuf = buf;
}
writeLength(len);
transport_->write(newBuf, len);
}
void TSaslTransport::flush() {
transport_->flush();
}
void TSaslTransport::sendSaslMessage(const NegotiationStatus status,
const uint8_t* payload, const uint32_t length, bool flush) {
uint8_t messageHeader[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];
uint8_t dummy = 0;
if (payload == NULL) {
payload = &dummy;
}
messageHeader[0] = (uint8_t)status;
encodeInt(length, messageHeader, STATUS_BYTES);
transport_->write(messageHeader, HEADER_LENGTH);
transport_->write(payload, length);
if (flush) transport_->flush();
}
uint8_t* TSaslTransport::receiveSaslMessage(NegotiationStatus* status,
uint32_t* length) {
uint8_t messageHeader[HEADER_LENGTH];
// read header
transport_->readAll(messageHeader, HEADER_LENGTH);
// get payload status
*status = (NegotiationStatus)messageHeader[0];
if ((*status < TSASL_START) || (*status > TSASL_COMPLETE)) {
throw TTransportException("invalid sasl status");
} else if (*status == TSASL_BAD || *status == TSASL_ERROR) {
throw TTransportException("sasl Peer indicated failure: ");
}
// get the length
*length = decodeInt(messageHeader, STATUS_BYTES);
// get payload
protoBuf_.reset(new uint8_t[*length]);
transport_->readAll(protoBuf_.get(), *length);
return protoBuf_.get();
}
}
}
}
#endif