// blob: a1766c38b1c9ea8c9e9c9bfa2fc69e7eccc50434 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "kudu/thrift/sasl_client_transport.h"
#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <ostream>
#include <string>
#include <glog/logging.h>
#include <thrift/transport/TTransport.h>
#include "kudu/gutil/endian.h"
#include "kudu/gutil/port.h"
#include "kudu/gutil/strings/human_readable.h"
#include "kudu/gutil/strings/substitute.h"
#include "kudu/rpc/sasl_common.h"
#include "kudu/rpc/sasl_helper.h"
#include "kudu/util/faststring.h"
#include "kudu/util/logging.h"
#include "kudu/util/slice.h"
#include "kudu/util/status.h"
#if defined(__APPLE__)
// Almost all functions in the SASL API are marked as deprecated
// since macOS 10.11.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif // #if defined(__APPLE__)
using apache::thrift::transport::TTransportException;
using std::shared_ptr;
using std::string;
using strings::Substitute;
namespace kudu {
using rpc::SaslMechanism;
using rpc::WrapSaslCall;
namespace thrift {
namespace {
// Size in bytes of a SASL negotiation message header: a one-byte status
// followed by a four-byte payload length.
constexpr uint32_t kSaslHeaderSize = sizeof(uint8_t) + sizeof(uint32_t);

// Size in bytes of a data frame header: a four-byte payload length.
constexpr uint32_t kFrameHeaderSize = sizeof(uint32_t);
// Trampoline for the SASL_CB_GETOPT callback. The callback context is the
// SaslClientTransport instance; the request is forwarded unchanged to its
// GetOptionCb() member and that member's return code is handed back.
int GetoptCb(SaslClientTransport* client_transport,
             const char* plugin_name,
             const char* option,
             const char** result,
             unsigned* len) {
  const int rc = client_transport->GetOptionCb(plugin_name, option, result, len);
  return rc;
}
// SASL_CB_CANON_USER callback: performs no canonicalization and simply copies
// the 'inlen' bytes of 'in' into 'out', reporting the copied length through
// 'out_len'. The connection, context, flags and realm arguments are ignored.
//
// The process is aborted (CHECK) if the input does not fit into the
// 'out_max'-byte output buffer.
int CanonUserCb(sasl_conn_t* /*conn*/,
                void* /*context*/,
                const char* in, unsigned inlen,
                unsigned /*flags*/,
                const char* /*user_realm*/,
                char* out, unsigned out_max, unsigned* out_len) {
  CHECK_LE(inlen, out_max);
  *out_len = inlen;
  memcpy(out, in, inlen);
  return SASL_OK;
}
// SASL_CB_USER callback: always answers with the empty user name.
//
// An empty user name makes the remote end fall back to the client's Kerberos
// principal, which is the desired identity. 'len' is optional and is only
// written when non-null.
int UserCb(void* /*context*/, int id, const char** result, unsigned* len) {
  CHECK_EQ(SASL_CB_USER, id);
  if (len != nullptr) {
    *len = 0;
  }
  *result = "";
  return SASL_OK;
}
} // anonymous namespace
// Builds a SASL client transport layered on top of 'transport'.
//
//   service_principal  - service name later passed to sasl_client_new().
//   server_fqdn        - remote host name, stored in the SASL helper.
//   transport          - underlying Thrift transport carrying the frames.
//   max_recv_buf_size  - largest frame / negotiation payload to accept.
//
// The constructor only prepares state (callback table, GSSAPI mechanism,
// empty write buffer with a reserved frame header); no I/O happens here.
SaslClientTransport::SaslClientTransport(string service_principal,
                                         const string& server_fqdn,
                                         shared_ptr<TTransport> transport,
                                         size_t max_recv_buf_size)
    : transport_(std::move(transport)),
      sasl_helper_(rpc::SaslHelper::CLIENT),
      sasl_callbacks_({
          // Option lookups are routed back to this instance.
          rpc::SaslBuildCallback(SASL_CB_GETOPT,
                                 reinterpret_cast<int (*)()>(&GetoptCb),
                                 this),
          // CanonUserCb ignores its context argument.
          rpc::SaslBuildCallback(SASL_CB_CANON_USER,
                                 reinterpret_cast<int (*)()>(&CanonUserCb),
                                 this),
          rpc::SaslBuildCallback(SASL_CB_USER,
                                 reinterpret_cast<int (*)()>(&UserCb),
                                 nullptr),
          // Terminator entry of the callback table.
          rpc::SaslBuildCallback(SASL_CB_LIST_END, nullptr, nullptr)
      }),
      needs_wrap_(false),
      max_recv_buf_size_(max_recv_buf_size),
      // Use a reasonable send buffer limit while negotiating; Negotiate()
      // replaces it with the negotiated value once the handshake is done.
      max_send_buf_size_(64 * 1024),
      service_principal_(std::move(service_principal)) {
  // Reserve room for the frame header at the front of the write buffer.
  ResetWriteBuf();

  sasl_helper_.set_server_fqdn(server_fqdn);
  sasl_helper_.EnableGSSAPI();
}
// Reports whether the underlying transport is open.
bool SaslClientTransport::isOpen() {
  const bool underlying_open = transport_->isOpen();
  return underlying_open;
}
// Returns true when there is something to read: either unread bytes remain in
// the current frame, or the underlying transport's peek() says so. The
// underlying transport is only consulted when the current frame is exhausted.
bool SaslClientTransport::peek() {
  if (!read_slice_.empty()) {
    return true;
  }
  return transport_->peek();
}
// Opens the underlying transport and then runs the SASL negotiation over it.
//
// If negotiation throws, the underlying transport is closed again before the
// exception is propagated to the caller unchanged.
void SaslClientTransport::open() {
  transport_->open();
  DCHECK(transport_->isOpen());

  try {
    Negotiate();
  } catch (...) {
    // Do not leave a half-negotiated connection open.
    transport_->close();
    throw;
  }
}
// Closes the underlying transport, then resets sasl_conn_ so that the SASL
// connection context of this session is dropped.
void SaslClientTransport::close() {
  transport_->close();
  sasl_conn_.reset();
}
// Reads exactly one frame from the underlying transport and exposes its
// payload through read_slice_.
//
// A frame is a 32-bit length in network byte order followed by that many
// payload bytes. Preconditions (DCHECKed): read_buf_ and read_slice_ are both
// empty.
//
// Throws TTransportException if the announced payload length exceeds
// max_recv_buf_size_, and SaslException if the frame cannot be SASL-decoded.
// Exceptions raised by the underlying transport propagate unchanged.
void SaslClientTransport::ReadFrame() {
  DCHECK_EQ(0, read_buf_.size());
  DCHECK(read_slice_.empty());

  uint8_t payload_len_buf[kFrameHeaderSize];
  transport_->readAll(payload_len_buf, kFrameHeaderSize);
  size_t payload_len = NetworkByteOrder::Load32(payload_len_buf);

  // Large frames are legal but noteworthy; rate-limit the warning.
  if (payload_len > 1024 * 1024) {
    KLOG_EVERY_N_SECS(WARNING, 60) << "Received large Thrift SASL frame: "
                                   << HumanReadableNumBytes::ToString(payload_len);
  }

  // Enforce the configured limit independently of the warning threshold above,
  // so that a limit smaller than 1 MiB is honored as well. This must happen
  // before any buffer is sized from the peer-supplied length.
  if (payload_len > max_recv_buf_size_) {
    throw TTransportException(Substitute("Thrift SASL frame is too long: $0/$1",
                                         HumanReadableNumBytes::ToString(payload_len),
                                         HumanReadableNumBytes::ToString(max_recv_buf_size_)));
  }

  // Keep the length header in front of the payload: when wrapping is enabled
  // the whole frame, header included, is handed to the SASL decoder.
  read_buf_.reserve(kFrameHeaderSize + payload_len);
  read_buf_.append(payload_len_buf, kFrameHeaderSize);
  read_buf_.resize(kFrameHeaderSize + payload_len);
  transport_->readAll(&read_buf_.data()[kFrameHeaderSize], payload_len);

  if (needs_wrap_) {
    // Point read_slice_ directly at the SASL library's internal buffer. This
    // avoids having to copy the decoded data back into read_buf_.
    Status s = rpc::SaslDecode(sasl_conn_.get(), read_buf_, &read_slice_);
    if (!s.ok()) {
      throw SaslException(s);
    }
    // The encoded frame is no longer needed once it has been decoded.
    ResetReadBuf();
  } else {
    // Plain frame: expose the payload in place, skipping the length header.
    read_slice_ = read_buf_;
    read_slice_.remove_prefix(kFrameHeaderSize);
  }
}
// Copies up to 'len' bytes of frame payload into 'buf' and returns the number
// of bytes copied.
//
// A new frame is fetched via ReadFrame() only when the current one has been
// fully consumed; a single call never spans two frames, so fewer than 'len'
// bytes may be returned. The read buffer is released as soon as the current
// frame is exhausted.
uint32_t SaslClientTransport::read(uint8_t* buf, uint32_t len) {
  if (read_slice_.empty()) {
    ReadFrame();
  }

  const size_t requested = static_cast<size_t>(len);
  const uint32_t copied = std::min(read_slice_.size(), requested);
  memcpy(buf, read_slice_.data(), copied);
  read_slice_.remove_prefix(copied);

  const bool frame_consumed = read_slice_.empty();
  if (frame_consumed) {
    ResetReadBuf();
  }
  return copied;
}
// Buffers 'len' bytes from 'buf' for sending. Nothing reaches the underlying
// transport until flush() is called, except that the data is split: whenever
// the buffer (frame header included) would grow beyond max_send_buf_size_, it
// is topped up to exactly that size and flushed as its own frame.
//
// NOTE(review): the remaining-room computation is unsigned and assumes
// max_send_buf_size_ is larger than the current buffer size (at least the
// frame header) - confirm the negotiated value can never be that small.
void SaslClientTransport::write(const uint8_t* buf, uint32_t len) {
  // The frame header slot must already be reserved at the front.
  DCHECK(write_buf_.size() >= kFrameHeaderSize);

  for (;;) {
    const bool fits_in_frame = write_buf_.size() + len <= max_send_buf_size_;
    if (fits_in_frame) {
      break;
    }
    // Fill the current frame to capacity and send it off.
    uint32_t room = max_send_buf_size_ - write_buf_.size();
    write_buf_.append(buf, room);
    flush();
    buf += room;
    len -= room;
  }

  write_buf_.append(buf, len);
}
void SaslClientTransport::flush() {
if (needs_wrap_) {
Slice plaintext(write_buf_);
plaintext.remove_prefix(kFrameHeaderSize);
Slice ciphertext;
Status s = rpc::SaslEncode(sasl_conn_.get(), plaintext, &ciphertext);
if (!s.ok()) {
throw SaslException(s);
}
// Note: when the SASL C library encodes the plaintext, it prefixes the
// ciphertext with the length. Since this happens to match the SASL/Thrift
// frame format, we can send the ciphertext unmodified to the remote server.
transport_->write(ciphertext.data(), ciphertext.size());
} else {
size_t payload_len = write_buf_.size() - kFrameHeaderSize;
NetworkByteOrder::Store32(write_buf_.data(), payload_len);
transport_->write(write_buf_.data(), write_buf_.size());
}
transport_->flush();
ResetWriteBuf();
}
// Runs the client side of the Thrift SASL handshake over the underlying
// transport.
//
// Sequence:
//   1. Create the SASL connection context and send the START + initial
//      response messages.
//   2. For every TSASL_OK challenge from the server, run sasl_client_step()
//      and send the response; stop once the local mechanism reports success.
//   3. Wait for the server's TSASL_COMPLETE message.
//   4. Record whether subsequent frames must be wrapped and the negotiated
//      maximum send buffer size.
//
// Throws SaslException on any SASL failure or protocol violation; exceptions
// from the underlying transport propagate unchanged.
void SaslClientTransport::Negotiate() {
  SetupSaslContext();

  faststring recv_buf;
  SendSaslStart();

  for (;;) {
    NegotiationStatus status = ReceiveSaslMessage(&recv_buf);

    if (status == TSASL_COMPLETE) {
      throw SaslException(
          Status::IllegalState("Received SASL COMPLETE status, but handshake is not finished"));
    }
    // ReceiveSaslMessage() only ever returns OK or COMPLETE.
    CHECK_EQ(status, TSASL_OK);

    const char* out = nullptr;
    unsigned out_len = 0;
    Status s = WrapSaslCall(sasl_conn_.get(), [&] {
        return sasl_client_step(sasl_conn_.get(),
                                reinterpret_cast<const char*>(recv_buf.data()),
                                recv_buf.size(),
                                nullptr,
                                &out,
                                &out_len);
    }, "calling sasl_client_step()");

    // An incomplete status means more round trips are needed; anything else
    // that is not OK is a real failure.
    if (PREDICT_FALSE(!s.IsIncomplete() && !s.ok())) {
      throw SaslException(std::move(s));
    }

    // Tag the reply with our own progress rather than echoing the status just
    // received from the server: the last client message of the handshake is
    // sent as COMPLETE, intermediate ones as OK.
    SendSaslMessage(s.ok() ? TSASL_COMPLETE : TSASL_OK, Slice(out, out_len));
    transport_->flush();

    if (s.ok()) {
      break;
    }
  }

  // The client side is done; the server must confirm completion.
  NegotiationStatus status = ReceiveSaslMessage(&recv_buf);
  if (status != TSASL_COMPLETE) {
    throw SaslException(
        Status::IllegalState("Received SASL OK status, but expected SASL COMPLETE"));
  }
  DCHECK_EQ(0, recv_buf.size());

  needs_wrap_ = rpc::NeedsWrap(sasl_conn_.get());
  max_send_buf_size_ = rpc::GetMaxSendBufferSize(sasl_conn_.get());

  VLOG(2) << "Thrift SASL GSSAPI negotiation complete"
          << "; needs wrap: " << (needs_wrap_ ? "true" : "false")
          << ", max send frame length: "
          << HumanReadableNumBytes::ToStringWithoutRounding(max_send_buf_size_)
          << ", max receive frame length: "
          << HumanReadableNumBytes::ToStringWithoutRounding(max_recv_buf_size_);
}
// Writes one SASL negotiation message to the underlying transport: a header
// made of the one-byte 'status' and the 32-bit payload length in network byte
// order, followed by the payload bytes (omitted when the payload is empty).
//
// The transport is not flushed here; callers flush when the message (or group
// of messages) is complete.
void SaslClientTransport::SendSaslMessage(NegotiationStatus status, Slice payload) {
  DCHECK_LT(payload.size(), std::numeric_limits<int32_t>::max());

  uint8_t header[kSaslHeaderSize];
  header[0] = status;
  NetworkByteOrder::Store32(&header[1], payload.size());
  transport_->write(header, kSaslHeaderSize);

  if (payload.empty()) {
    return;
  }
  transport_->write(payload.data(), payload.size());
}
// Reads one SASL negotiation message from the underlying transport, stores its
// payload in '*payload' (resized to the payload length) and returns the
// message status, which is always TSASL_OK or TSASL_COMPLETE.
//
// Throws SaslException when the peer reports TSASL_BAD / TSASL_ERROR, when the
// status byte is anything else unexpected (including TSASL_START, which a
// client should never receive), or when the announced payload is larger than
// max_recv_buf_size_. The status is validated before the length.
NegotiationStatus SaslClientTransport::ReceiveSaslMessage(faststring* payload) {
  // Fixed-length header: status byte followed by the payload length.
  uint8_t header[kSaslHeaderSize];
  transport_->readAll(header, kSaslHeaderSize);

  const uint8_t status_byte = header[0];
  size_t len = NetworkByteOrder::Load32(&header[1]);

  // Failure reported by the peer.
  if (status_byte == TSASL_BAD || status_byte == TSASL_ERROR) {
    throw SaslException(Status::RuntimeError("SASL peer indicated failure"));
  }

  // Anything other than OK / COMPLETE is a protocol violation at this point.
  const bool acceptable = status_byte == TSASL_OK || status_byte == TSASL_COMPLETE;
  if (!acceptable) {
    throw SaslException(Status::RuntimeError("Unexpected SASL status",
                                             std::to_string(status_byte)));
  }

  // Refuse oversized payloads before allocating space for them.
  if (len > max_recv_buf_size_) {
    throw SaslException(Status::RuntimeError(Substitute(
        "SASL negotiation message payload exceeds maximum length: $0/$1",
        HumanReadableNumBytes::ToString(len),
        HumanReadableNumBytes::ToString(max_recv_buf_size_))));
  }

  payload->resize(len);
  transport_->readAll(payload->data(), len);
  return static_cast<NegotiationStatus>(status_byte);
}
// Starts the SASL exchange: configures protection on the connection, calls
// sasl_client_start() for the GSSAPI mechanism, and sends the resulting
// mechanism name and initial response to the server.
//
// Throws SaslException if enabling protection fails or if
// sasl_client_start() reports anything other than success / incomplete.
void SaslClientTransport::SendSaslStart() {
  auto status = rpc::EnableProtection(sasl_conn_.get(),
                                      rpc::SaslProtection::kAuthentication,
                                      max_recv_buf_size_);
  if (PREDICT_FALSE(!status.ok())) {
    throw SaslException(std::move(status));
  }

  // Outputs of sasl_client_start(); all three are filled in on success.
  const char* initial_response = nullptr;
  unsigned initial_response_len = 0;
  const char* chosen_mech = nullptr;

  const auto start_client = [&] {
    return sasl_client_start(
        sasl_conn_.get(),                               // Context from sasl_client_new().
        SaslMechanism::name_of(SaslMechanism::GSSAPI),  // The mechanism to use.
        nullptr,                                        // Disables INTERACT return if NULL.
        &initial_response,
        &initial_response_len,
        &chosen_mech);
  };
  status = WrapSaslCall(sasl_conn_.get(), start_client, "calling sasl_client_start()");

  // "Incomplete" just means further steps are required, which is expected.
  if (PREDICT_FALSE(!status.IsIncomplete() && !status.ok())) {
    throw SaslException(std::move(status));
  }

  // Check that the SASL library is using the mechanism that we picked.
  DCHECK_EQ(SaslMechanism::value_of(chosen_mech), SaslMechanism::GSSAPI);

  // These two calls comprise a single message in the thrift-sasl protocol,
  // hence the single flush after both.
  SendSaslMessage(TSASL_START, Slice(chosen_mech));
  SendSaslMessage(TSASL_OK, Slice(initial_response, initial_response_len));
  transport_->flush();
}
// Answers a SASL option query by delegating to the SASL helper; arguments and
// return code are passed through unchanged.
int SaslClientTransport::GetOptionCb(const char* plugin_name,
                                     const char* option,
                                     const char** result,
                                     unsigned* len) {
  const int rc = sasl_helper_.GetOptionCb(plugin_name, option, result, len);
  return rc;
}
void SaslClientTransport::SetupSaslContext() {
sasl_conn_t* sasl_conn = nullptr;
Status s = WrapSaslCall(nullptr /* no conn */, [&] {
return sasl_client_new(
service_principal_.c_str(), // Registered name of the service using SASL. Required.
sasl_helper_.server_fqdn(), // The fully qualified domain name of the remote server.
nullptr, // Local and remote IP address strings. (we don't use
nullptr, // any mechanisms which require this info.)
sasl_callbacks_.data(), // Connection-specific callbacks.
0, // flags
&sasl_conn);
},
"calling sasl_client_new()");
if (!s.ok()) {
throw SaslException(s);
}
sasl_conn_.reset(sasl_conn);
}
// Empties the read buffer and asks it to give back its unused capacity.
void SaslClientTransport::ResetReadBuf() {
  read_buf_.clear();
  read_buf_.shrink_to_fit();
}
// Resets the write buffer to an empty frame: only the slot for the frame
// header remains, and unused capacity is given back.
void SaslClientTransport::ResetWriteBuf() {
  write_buf_.resize(kFrameHeaderSize);
  write_buf_.shrink_to_fit();
}
} // namespace thrift
} // namespace kudu
#if defined(__APPLE__)
#pragma GCC diagnostic pop
#endif // #if defined(__APPLE__)