blob: d4c8c86618ae7370273aa660131ed854aa8b7966 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <boost/algorithm/string.hpp>
#include <boost/filesystem.hpp>
#include <boost/thread.hpp>
#include <boost/thread/mutex.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <openssl/err.h>
#include <thrift/concurrency/Thread.h>
#include <thrift/concurrency/ThreadManager.h>
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/transport/TSocket.h>
#include <thrift/transport/TSSLServerSocket.h>
#include <thrift/transport/TSSLSocket.h>
#include <thrift/transport/TServerSocket.h>
#include <gflags/gflags.h>
#include <sstream>
#include "gen-cpp/Types_types.h"
#include "kudu/security/openssl_util.h"
#include "rpc/TAcceptQueueServer.h"
#include "rpc/auth-provider.h"
#include "rpc/authentication.h"
#include "rpc/thrift-server.h"
#include "rpc/thrift-thread.h"
#include "transport/THttpServer.h"
#include "util/condition-variable.h"
#include "util/debug-util.h"
#include "util/metrics.h"
#include "util/network-util.h"
#include "util/openssl-util.h"
#include "util/os-util.h"
#include "util/uid-util.h"
#include "common/names.h"
namespace posix_time = boost::posix_time;
using namespace boost::algorithm;
using boost::filesystem::exists;
using boost::get_system_time;
using boost::system_time;
using boost::uuids::uuid;
using namespace apache::thrift::concurrency;
using namespace apache::thrift::protocol;
using namespace apache::thrift::server;
using namespace apache::thrift::transport;
using namespace apache::thrift;
DECLARE_string(principal);
DECLARE_string(keytab_file);
namespace impala {
// Specifies the allowed set of values for --ssl_minimum_version. To keep consistent with
// Apache Kudu, specifying a single version enables all versions including and succeeding
// that one (e.g. TLSv1.1 enables v1.1 and v1.2).
map<string, SSLProtocol> SSLProtoVersions::PROTO_MAP = {
{"tlsv1.2", TLSv1_2},
{"tlsv1.1", TLSv1_1},
{"tlsv1", TLSv1_0}};
// A generic wrapper for OpenSSL structures.
template <typename T>
using c_unique_ptr = std::unique_ptr<T, std::function<void(T*)>>;
Status SSLProtoVersions::StringToProtocol(const string& in, SSLProtocol* protocol) {
for (const auto& proto : SSLProtoVersions::PROTO_MAP) {
if (iequals(in, proto.first)) {
*protocol = proto.second;
return Status::OK();
}
}
return Status(Substitute("Unknown TLS version: '$0'", in));
}
bool SSLProtoVersions::IsSupported(const SSLProtocol& protocol) {
DCHECK_LE(protocol, TLSv1_2);
int max_supported_tls_version = MaxSupportedTlsVersion();
DCHECK_GE(max_supported_tls_version, TLS1_VERSION);
switch (max_supported_tls_version) {
case TLS1_VERSION:
return protocol == TLSv1_0;
case TLS1_1_VERSION:
return protocol == TLSv1_0 || protocol == TLSv1_1;
default:
DCHECK_GE(max_supported_tls_version, TLS1_2_VERSION);
return true;
}
}
Status ThriftServer::ThriftServerEventProcessor::StartAndWaitForServer() {
// Locking here protects against missed notifications if Supervise executes quickly
unique_lock<mutex> lock(signal_lock_);
thrift_server_->started_ = false;
stringstream name;
name << "supervise-" << thrift_server_->name_;
RETURN_IF_ERROR(Thread::Create("thrift-server", name.str(),
&ThriftServer::ThriftServerEventProcessor::Supervise, this,
&thrift_server_->server_thread_));
timespec deadline;
TimeFromNowMillis(ThriftServer::ThriftServerEventProcessor::TIMEOUT_MS, &deadline);
// Loop protects against spurious wakeup. Locks provide necessary fences to ensure
// visibility.
while (!signal_fired_) {
// Yields lock and allows supervision thread to continue and signal
if (!signal_cond_.WaitUntil(lock, deadline)) {
stringstream ss;
ss << "ThriftServer '" << thrift_server_->name_ << "' (on port: "
<< thrift_server_->port_ << ") did not start within "
<< ThriftServer::ThriftServerEventProcessor::TIMEOUT_MS << "ms";
LOG(ERROR) << ss.str();
return Status(ss.str());
}
}
// started_ == true only if preServe was called. May be false if there was an exception
// after preServe that was caught by Supervise, causing it to reset the error condition.
if (thrift_server_->started_ == false) {
stringstream ss;
ss << "ThriftServer '" << thrift_server_->name_ << "' (on port: "
<< thrift_server_->port_ << ") did not start correctly ";
LOG(ERROR) << ss.str();
return Status(ss.str());
}
return Status::OK();
}
void ThriftServer::ThriftServerEventProcessor::Supervise() {
DCHECK(thrift_server_->server_.get() != NULL);
try {
thrift_server_->server_->serve();
} catch (TException& e) {
LOG(ERROR) << "ThriftServer '" << thrift_server_->name_ << "' (on port: "
<< thrift_server_->port_ << ") exited due to TException: " << e.what();
}
{
// signal_lock_ ensures mutual exclusion of access to thrift_server_
lock_guard<mutex> lock(signal_lock_);
thrift_server_->started_ = false;
// There may not be anyone waiting on this signal (if the
// exception occurs after startup). That's not a problem, this is
// just to avoid waiting for the timeout in case of a bind
// failure, for example.
signal_fired_ = true;
}
signal_cond_.NotifyAll();
}
void ThriftServer::ThriftServerEventProcessor::preServe() {
// Acquire the signal lock to ensure that StartAndWaitForServer is
// waiting on signal_cond_ when we notify.
lock_guard<mutex> lock(signal_lock_);
signal_fired_ = true;
// This is the (only) success path - if this is not reached within TIMEOUT_MS,
// StartAndWaitForServer will indicate failure.
thrift_server_->started_ = true;
// Should only be one thread waiting on signal_cond_, but wake all just in case.
signal_cond_.NotifyAll();
}
// This thread-local variable contains the current connection context for whichever
// thrift server is currently serving a request on the current thread. This includes
// connection state such as the connection identifier and the username.
__thread ThriftServer::ConnectionContext* __connection_context__;
bool ThriftServer::HasThreadConnectionContext() {
return __connection_context__ != nullptr;
}
const TUniqueId& ThriftServer::GetThreadConnectionId() {
return __connection_context__->connection_id;
}
const ThriftServer::ConnectionContext* ThriftServer::GetThreadConnectionContext() {
return __connection_context__;
}
void* ThriftServer::ThriftServerEventProcessor::createContext(
boost::shared_ptr<TProtocol> input, boost::shared_ptr<TProtocol> output) {
boost::shared_ptr<ConnectionContext> connection_ptr =
boost::shared_ptr<ConnectionContext>(new ConnectionContext);
thrift_server_->auth_provider_->SetupConnectionContext(connection_ptr,
thrift_server_->transport_type_, input->getTransport().get(),
output->getTransport().get());
{
connection_ptr->server_name = thrift_server_->name_;
lock_guard<mutex> l(thrift_server_->connection_contexts_lock_);
uuid connection_uuid = thrift_server_->uuid_generator_();
UUIDToTUniqueId(connection_uuid, &connection_ptr->connection_id);
// Add the connection to the connection map.
__connection_context__ = connection_ptr.get();
thrift_server_->connection_contexts_[connection_ptr.get()] = connection_ptr;
}
if (thrift_server_->connection_handler_ != NULL) {
thrift_server_->connection_handler_->ConnectionStart(*__connection_context__);
}
if (thrift_server_->metrics_enabled_) {
thrift_server_->num_current_connections_metric_->Increment(1L);
thrift_server_->total_connections_metric_->Increment(1L);
}
// Store the __connection_context__ in the per-client context. If only this were
// accessible from RPC method calls, we wouldn't have to
// mess around with thread locals.
return (void*)__connection_context__;
}
void ThriftServer::ThriftServerEventProcessor::processContext(void* context,
boost::shared_ptr<TTransport> transport) {
__connection_context__ = reinterpret_cast<ConnectionContext*>(context);
}
bool ThriftServer::ThriftServerEventProcessor::IsIdleContext(void* context) {
__connection_context__ = reinterpret_cast<ConnectionContext*>(context);
if (thrift_server_->connection_handler_ != nullptr) {
return thrift_server_->connection_handler_->IsIdleConnection(*__connection_context__);
}
return false;
}
void ThriftServer::ThriftServerEventProcessor::deleteContext(void* context,
boost::shared_ptr<TProtocol> input, boost::shared_ptr<TProtocol> output) {
__connection_context__ = reinterpret_cast<ConnectionContext*>(context);
if (thrift_server_->connection_handler_ != NULL) {
thrift_server_->connection_handler_->ConnectionEnd(*__connection_context__);
}
{
lock_guard<mutex> l(thrift_server_->connection_contexts_lock_);
thrift_server_->connection_contexts_.erase(__connection_context__);
}
if (thrift_server_->metrics_enabled_) {
thrift_server_->num_current_connections_metric_->Increment(-1L);
}
}
ThriftServer::ThriftServer(const string& name,
const boost::shared_ptr<TProcessor>& processor, int port, AuthProvider* auth_provider,
MetricGroup* metrics, int max_concurrent_connections, int64_t queue_timeout_ms,
int64_t idle_poll_period_ms, TransportType transport_type)
: started_(false),
port_(port),
ssl_enabled_(false),
max_concurrent_connections_(max_concurrent_connections),
queue_timeout_ms_(queue_timeout_ms),
idle_poll_period_ms_(idle_poll_period_ms),
name_(name),
metrics_name_(Substitute("impala.thrift-server.$0", name_)),
server_(NULL),
processor_(processor),
connection_handler_(NULL),
metrics_(NULL),
auth_provider_(auth_provider),
transport_type_(transport_type) {
if (auth_provider_ == NULL) {
auth_provider_ = AuthManager::GetInstance()->GetInternalAuthProvider();
}
if (metrics != NULL) {
metrics_enabled_ = true;
stringstream count_ss;
count_ss << metrics_name_ << ".connections-in-use";
num_current_connections_metric_ = metrics->AddGauge(count_ss.str(), 0);
stringstream max_ss;
max_ss << metrics_name_ << ".total-connections";
total_connections_metric_ = metrics->AddCounter(max_ss.str(), 0);
metrics_ = metrics;
} else {
metrics_enabled_ = false;
}
}
namespace {
/// Factory subclass to override getPassword() which provides a password string to Thrift
/// to decrypt the private key file.
class ImpalaSslSocketFactory : public TSSLSocketFactory {
public:
ImpalaSslSocketFactory(SSLProtocol version, const string& password)
: TSSLSocketFactory(version), password_(password) {}
void ciphers(const string& enable) override {
SCOPED_OPENSSL_NO_PENDING_ERRORS;
TSSLSocketFactory::ciphers(enable);
// The following was taken from be/src/kudu/security/tls_context.cc, bugs fixed here
// may also need to be fixed there.
// Enable ECDH curves. For OpenSSL 1.1.0 and up, this is done automatically.
#ifndef OPENSSL_NO_ECDH
#if OPENSSL_VERSION_NUMBER < 0x10002000L
// OpenSSL 1.0.1 and below only support setting a single ECDH curve at once.
// We choose prime256v1 because it's the first curve listed in the "modern
// compatibility" section of the Mozilla Server Side TLS recommendations,
// accessed Feb. 2017.
c_unique_ptr<EC_KEY> ecdh{
EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), &EC_KEY_free};
if (ecdh == nullptr) {
throw TSSLException(
"failed to create prime256v1 curve: " + kudu::security::GetOpenSSLErrors());
}
int rc = SSL_CTX_set_tmp_ecdh(ctx_->get(), ecdh.get());
if (rc <= 0) {
throw new TSSLException(
"failed to set ECDH curve: " + kudu::security::GetOpenSSLErrors());
}
#elif OPENSSL_VERSION_NUMBER < 0x10100000L
// OpenSSL 1.0.2 provides the set_ecdh_auto API which internally figures out
// the best curve to use.
int rc = SSL_CTX_set_ecdh_auto(ctx_->get(), 1);
if (rc <= 0) {
throw TSSLException(
"failed to configure ECDH support: " + kudu::security::GetOpenSSLErrors());
}
#endif
#endif
}
protected:
virtual void getPassword(string& output, int size) override {
output = password_;
if (output.size() > size) output.resize(size);
}
private:
/// The password string.
const string password_;
};
}
Status ThriftServer::CreateSocket(boost::shared_ptr<TServerSocket>* socket) {
if (ssl_enabled()) {
if (!SSLProtoVersions::IsSupported(version_)) {
return Status(TErrorCode::SSL_SOCKET_CREATION_FAILED,
Substitute("TLS ($0) version not supported (maximum supported version is $1)",
version_, MaxSupportedTlsVersion()));
}
try {
// This 'factory' is only called once, since CreateSocket() is only called from
// Start(). The c'tor may throw if there is an error initializing the SSL context.
boost::shared_ptr<TSSLSocketFactory> socket_factory(
new ImpalaSslSocketFactory(version_, key_password_));
socket_factory->overrideDefaultPasswordCallback();
if (!cipher_list_.empty()) socket_factory->ciphers(cipher_list_);
socket_factory->loadCertificate(certificate_path_.c_str());
socket_factory->loadPrivateKey(private_key_path_.c_str());
socket->reset(new TSSLServerSocket(port_, socket_factory));
} catch (const TException& e) {
return Status(TErrorCode::SSL_SOCKET_CREATION_FAILED, e.what());
}
return Status::OK();
} else {
socket->reset(new TServerSocket(port_));
return Status::OK();
}
}
Status ThriftServer::EnableSsl(SSLProtocol version, const string& certificate,
const string& private_key, const string& pem_password_cmd,
const std::string& ciphers) {
DCHECK(!started_);
if (certificate.empty()) return Status(TErrorCode::SSL_CERTIFICATE_PATH_BLANK);
if (private_key.empty()) return Status(TErrorCode::SSL_PRIVATE_KEY_PATH_BLANK);
if (!exists(certificate)) {
return Status(TErrorCode::SSL_CERTIFICATE_NOT_FOUND, certificate);
}
// TODO: Consider warning if private key file is world-readable
if (!exists(private_key)) {
return Status(TErrorCode::SSL_PRIVATE_KEY_NOT_FOUND, private_key);
}
ssl_enabled_ = true;
certificate_path_ = certificate;
private_key_path_ = private_key;
cipher_list_ = ciphers;
version_ = version;
if (!pem_password_cmd.empty()) {
if (!RunShellProcess(pem_password_cmd, &key_password_, true, {"JAVA_TOOL_OPTIONS"})) {
return Status(TErrorCode::SSL_PASSWORD_CMD_FAILED, pem_password_cmd, key_password_);
} else {
LOG(INFO) << "Command '" << pem_password_cmd << "' executed successfully, "
<< ".PEM password retrieved";
}
}
return Status::OK();
}
Status ThriftServer::Start() {
DCHECK(!started_);
boost::shared_ptr<TProtocolFactory> protocol_factory(new TBinaryProtocolFactory());
boost::shared_ptr<ThreadFactory> thread_factory(
new ThriftThreadFactory("thrift-server", name_));
// Note - if you change the transport types here, you must check that the
// logic in createContext is still accurate.
boost::shared_ptr<TServerSocket> server_socket;
boost::shared_ptr<TTransportFactory> transport_factory;
RETURN_IF_ERROR(CreateSocket(&server_socket));
RETURN_IF_ERROR(auth_provider_->GetServerTransportFactory(
transport_type_, metrics_name_, metrics_, &transport_factory));
server_.reset(new TAcceptQueueServer(processor_, server_socket, transport_factory,
protocol_factory, thread_factory, name_, max_concurrent_connections_,
queue_timeout_ms_, idle_poll_period_ms_));
if (metrics_ != NULL) {
(static_cast<TAcceptQueueServer*>(server_.get()))
->InitMetrics(metrics_, metrics_name_);
}
boost::shared_ptr<ThriftServer::ThriftServerEventProcessor> event_processor(
new ThriftServer::ThriftServerEventProcessor(this));
server_->setServerEventHandler(event_processor);
RETURN_IF_ERROR(event_processor->StartAndWaitForServer());
// If port_ was 0, figure out which port the server is listening on after starting.
port_ = server_socket->getPort();
LOG(INFO) << "ThriftServer '" << name_ << "' started on port: " << port_
<< (ssl_enabled() ? "s" : "");
DCHECK(started_);
return Status::OK();
}
void ThriftServer::Join() {
DCHECK(server_thread_ != NULL);
DCHECK(started_);
server_thread_->Join();
}
void ThriftServer::StopForTesting() {
DCHECK(server_thread_ != NULL);
DCHECK(server_);
server_->stop();
if (started_) Join();
}
}