blob: 95a460114bb55abafc3aba53f9efeda8ee1e8c0a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "kudu/security/tls_handshake.h"
#include <pthread.h>
#include <sched.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <algorithm>
#include <atomic>
#include <cerrno>
#include <csignal>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "kudu/gutil/macros.h"
#include "kudu/security/tls_context.h"
#include "kudu/util/countdown_latch.h"
#include "kudu/util/monotime.h"
#include "kudu/util/net/sockaddr.h"
#include "kudu/util/net/socket.h"
#include "kudu/util/random.h"
#include "kudu/util/random_util.h"
#include "kudu/util/scoped_cleanup.h"
#include "kudu/util/status.h"
#include "kudu/util/test_macros.h"
#include "kudu/util/test_util.h"
using std::string;
using std::thread;
using std::unique_ptr;
using std::vector;
namespace kudu {
namespace security {
// Time budget used for the blocking socket reads/writes (and the receive
// timeout) in the echo server and the tests below.
const MonoDelta kTimeout(MonoDelta::FromSeconds(10));
// Size of one echo round-trip, in bytes (32 MiB). Chosen so that a single
// chunk does not fit into a socket's default-sized send buffer (the size that
// setsockopt() with SO_SNDBUF controls).
constexpr size_t kEchoChunkSize = size_t{32} << 20;
// Fixture for the TLS socket tests. It owns the client-side TLS context that
// ConnectClient() uses to dial a server and run the client half of the
// handshake.
class TlsSocketTest : public KuduTest {
 public:
  // Runs the base-class setup, then initializes 'client_tls_'. No certificate
  // or key is generated for the client side here.
  void SetUp() override {
    KuduTest::SetUp();
    ASSERT_OK(client_tls_.Init());
  }

 protected:
  // Connects to 'addr' and completes a client-side TLS handshake, storing the
  // resulting socket in '*sock'. Failures are reported through gtest
  // assertions, so call it under NO_FATALS().
  void ConnectClient(const Sockaddr& addr, unique_ptr<Socket>* sock);

  // Client-side TLS context; initialized in SetUp().
  TlsContext client_tls_;
};
// Drives one side of a TLS handshake over the connected socket 'sock' until
// 'tls' reports completion. Peer verification is switched off first
// (TlsVerificationMode::VERIFY_NONE). 'side' is used only in the final log
// message.
//
// Each round feeds the bytes most recently read from the peer into
// TlsHandshake::Continue() and writes out whatever handshake data it produced
// (with a 10 second deadline). If the handshake is still incomplete, the
// round then performs a single Recv() of at most 1024 bytes for the next one.
// A TLS error other than "incomplete", or any socket error, is returned with
// a descriptive prefix.
Status DoNegotiationSide(Socket* sock, TlsHandshake* tls, const char* side) {
  tls->set_verification_mode(TlsVerificationMode::VERIFY_NONE);

  string inbound;
  for (;;) {
    string outbound;
    const Status step = tls->Continue(inbound, &outbound);
    const bool finished = step.ok();
    if (!finished && !step.IsIncomplete()) {
      RETURN_NOT_OK_PREPEND(step, "unexpected tls error");
    }

    // Even the final step may produce data that the peer still needs.
    if (!outbound.empty()) {
      const auto send_deadline = MonoTime::Now() + MonoDelta::FromSeconds(10);
      size_t bytes_sent;
      RETURN_NOT_OK_PREPEND(
          sock->BlockingWrite(reinterpret_cast<const uint8_t*>(outbound.data()),
                              outbound.size(), &bytes_sent, send_deadline),
          "error sending");
    }

    if (finished) {
      break;
    }

    uint8_t chunk[1024];
    int32_t nread = 0;
    RETURN_NOT_OK_PREPEND(sock->Recv(chunk, arraysize(chunk), &nread),
                          "error receiving");
    inbound.assign(reinterpret_cast<const char*>(chunk), nread);
  }

  LOG(INFO) << side << ": negotiation complete";
  return Status::OK();
}
void TlsSocketTest::ConnectClient(const Sockaddr& addr, unique_ptr<Socket>* sock) {
unique_ptr<Socket> client_sock(new Socket());
ASSERT_OK(client_sock->Init(0));
ASSERT_OK(client_sock->Connect(addr));
TlsHandshake client;
ASSERT_OK(client_tls_.InitiateHandshake(TlsHandshakeType::CLIENT, &client));
ASSERT_OK(DoNegotiationSide(client_sock.get(), &client, "client"));
ASSERT_OK(client.Finish(&client_sock));
*sock = std::move(client_sock);
}
class EchoServer {
public:
EchoServer()
: pthread_sync_(1) {
}
~EchoServer() {
Stop();
Join();
}
void Start() {
ASSERT_OK(server_tls_.Init());
ASSERT_OK(server_tls_.GenerateSelfSignedCertAndKey());
ASSERT_OK(listen_addr_.ParseString("127.0.0.1", 0));
ASSERT_OK(listener_.Init(0));
ASSERT_OK(listener_.BindAndListen(listen_addr_, /*listen_queue_size=*/10));
ASSERT_OK(listener_.GetSocketAddress(&listen_addr_));
thread_ = thread([&] {
pthread_ = pthread_self();
pthread_sync_.CountDown();
unique_ptr<Socket> sock(new Socket());
Sockaddr remote;
CHECK_OK(listener_.Accept(sock.get(), &remote, /*flags=*/0));
TlsHandshake server;
CHECK_OK(server_tls_.InitiateHandshake(TlsHandshakeType::SERVER, &server));
CHECK_OK(DoNegotiationSide(sock.get(), &server, "server"));
CHECK_OK(server.Finish(&sock));
CHECK_OK(sock->SetRecvTimeout(kTimeout));
unique_ptr<uint8_t[]> buf(new uint8_t[kEchoChunkSize]);
// An "echo" loop for kEchoChunkSize byte buffers.
while (!stop_) {
size_t n;
Status s = sock->BlockingRecv(buf.get(), kEchoChunkSize, &n, MonoTime::Now() + kTimeout);
if (!s.ok()) {
CHECK(stop_) << "unexpected error reading: " << s.ToString();
}
LOG(INFO) << "server echoing " << n << " bytes";
size_t written;
s = sock->BlockingWrite(buf.get(), n, &written, MonoTime::Now() + kTimeout);
if (!s.ok()) {
CHECK(stop_) << "unexpected error writing: " << s.ToString();
}
if (slow_read_) {
SleepFor(MonoDelta::FromMilliseconds(10));
}
}
});
}
void EnableSlowRead() {
slow_read_ = true;
}
const Sockaddr& listen_addr() const {
return listen_addr_;
}
bool stopped() const {
return stop_;
}
void Stop() {
stop_ = true;
}
void Join() {
thread_.join();
}
const pthread_t& pthread() {
pthread_sync_.Wait();
return pthread_;
}
private:
TlsContext server_tls_;
Socket listener_;
Sockaddr listen_addr_;
thread thread_;
pthread_t pthread_;
CountDownLatch pthread_sync_;
std::atomic<bool> stop_ { false };
bool slow_read_ = false;
};
// Signal handler that deliberately does nothing. Installing it lets SIGUSR2 be
// delivered to a thread, interrupting blocking system calls, without the
// signal's default action of terminating the process.
void handler(int /* signo */) {
  // Intentionally empty.
}
// Test for failures to handle EINTR during TLS connection negotiation and
// data send/receive. A helper thread keeps sending SIGUSR2 to the echo
// server's thread while the client exchanges data with it.
TEST_F(TlsSocketTest, TestTlsSocketInterrupted) {
  // Install a no-op SIGUSR2 handler. sa_flags is zeroed (no SA_RESTART), so
  // interrupted blocking syscalls fail with EINTR rather than being restarted.
  struct sigaction sa, sa_old;
  memset(&sa, 0, sizeof(sa));
  sa.sa_handler = &handler;
  sigaction(SIGUSR2, &sa, &sa_old);
  SCOPED_CLEANUP({ sigaction(SIGUSR2, &sa_old, nullptr); });

  EchoServer server;
  NO_FATALS(server.Start());

  // Start a thread to send signals to the server thread.
  thread killer([&]() {
    while (!server.stopped()) {
      // pthread_kill() reports failures via its return value, not errno, so
      // PCHECK would print a stale errno. ESRCH is tolerated once a stop has
      // been requested, because the server thread may exit between the
      // stopped() check above and this call.
      const int rc = pthread_kill(server.pthread(), SIGUSR2);
      CHECK(rc == 0 || (rc == ESRCH && server.stopped()))
          << "pthread_kill() failed: " << strerror(rc);
      SleepFor(MonoDelta::FromMicroseconds(rand() % 10));
    }
  });
  // Request a stop before joining so that 'killer' always terminates, even
  // when an assertion below returns early before server.Stop() is reached.
  SCOPED_CLEANUP({
    server.Stop();
    killer.join();
  });

  unique_ptr<Socket> client_sock;
  NO_FATALS(ConnectClient(server.listen_addr(), &client_sock));

  // Value-initialized (zero-filled) so the client never sends uninitialized
  // memory over the wire.
  unique_ptr<uint8_t[]> buf(new uint8_t[kEchoChunkSize]());
  for (int i = 0; i < 10; i++) {
    SleepFor(MonoDelta::FromMilliseconds(1));
    size_t nwritten;
    ASSERT_OK(client_sock->BlockingWrite(buf.get(), kEchoChunkSize, &nwritten,
                                         MonoTime::Now() + kTimeout));
    size_t n;
    ASSERT_OK(client_sock->BlockingRecv(buf.get(), kEchoChunkSize, &n,
                                        MonoTime::Now() + kTimeout));
  }
  server.Stop();
  ASSERT_OK(client_sock->Close());
  LOG(INFO) << "client done";
}
// Returns iovecs that together cover the 'len' bytes starting at 'buf', in
// order, split into consecutive chunks. Each chunk has a random size between
// 1 and 'max_chunk_size' bytes, drawn with 'rng'; the last chunk is truncated
// to whatever remains.
//
// The iovecs point into 'buf' and do not own it, so 'buf' must outlive the
// result. A non-positive 'len' yields an empty vector.
// NOTE(review): 'max_chunk_size' is assumed to be positive; confirm what
// Random::Uniform() does for 0.
vector<struct iovec> ChunkIOVec(Random* rng, uint8_t* buf, int len, int max_chunk_size) {
  vector<struct iovec> ret;
  uint8_t* p = buf;
  int rem = len;
  while (rem > 0) {
    // Named 'chunk_len' rather than 'len' so it no longer shadows the
    // parameter of the same name.
    const int chunk_len =
        std::min(static_cast<int>(rng->Uniform(max_chunk_size)) + 1, rem);
    ret.push_back({p, static_cast<size_t>(chunk_len)});
    p += chunk_len;
    rem -= chunk_len;
  }
  return ret;
}
// Regression test for KUDU-2218, a bug in which Writev would improperly handle
// partial writes in non-blocking mode.
TEST_F(TlsSocketTest, TestNonBlockingWritev) {
  Random rng(GetRandomSeed32());
  EchoServer server;
  // The server pauses 10ms after every echo. Together with the small SO_SNDBUF
  // set below, this is meant to make non-blocking Writev() see partial writes.
  server.EnableSlowRead();
  NO_FATALS(server.Start());

  unique_ptr<Socket> client_sock;
  NO_FATALS(ConnectClient(server.listen_addr(), &client_sock));

  int sndbuf = 16 * 1024;
  CHECK_ERR(setsockopt(client_sock->GetFd(), SOL_SOCKET, SO_SNDBUF, &sndbuf, sizeof(sndbuf)));

  unique_ptr<uint8_t[]> buf(new uint8_t[kEchoChunkSize]);
  unique_ptr<uint8_t[]> rbuf(new uint8_t[kEchoChunkSize]);
  RandomString(buf.get(), kEchoChunkSize, &rng);

  for (int i = 0; i < 10; i++) {
    ASSERT_OK(client_sock->SetNonBlocking(true));

    // Prepare an IOV with the input data split into a bunch of randomly-sized
    // chunks.
    vector<struct iovec> iov = ChunkIOVec(&rng, buf.get(), kEchoChunkSize, 1024 * 1024);

    // Loop calling writev until the iov is exhausted.
    int64_t rem = kEchoChunkSize;
    while (rem > 0) {
      CHECK(!iov.empty()) << rem;
      int64_t n;
      Status s = client_sock->Writev(&iov[0], iov.size(), &n);
      if (Socket::IsTemporarySocketError(s.posix_code())) {
        sched_yield();
        continue;
      }
      ASSERT_OK(s);
      // Validate 'n' before using it for any bookkeeping.
      ASSERT_GE(n, 0);
      ASSERT_LE(n, rem);
      rem -= n;
      // Drop the iovecs that were written completely, and advance into the
      // first partially-written one.
      while (n > 0) {
        if (static_cast<size_t>(n) < iov[0].iov_len) {
          iov[0].iov_len -= n;
          iov[0].iov_base = reinterpret_cast<uint8_t*>(iov[0].iov_base) + n;
          n = 0;
        } else {
          n -= iov[0].iov_len;
          iov.erase(iov.begin());
        }
      }
    }

    LOG(INFO) << "client waiting";
    size_t n;
    ASSERT_OK(client_sock->SetNonBlocking(false));
    ASSERT_OK(client_sock->BlockingRecv(rbuf.get(), kEchoChunkSize, &n,
                                        MonoTime::Now() + kTimeout));
    LOG(INFO) << "client got response";
    // Make sure the whole chunk came back before comparing contents.
    // Otherwise the memcmp could look at stale or uninitialized bytes.
    ASSERT_EQ(kEchoChunkSize, n);
    ASSERT_EQ(0, memcmp(buf.get(), rbuf.get(), kEchoChunkSize));
  }
  server.Stop();
  ASSERT_OK(client_sock->Close());
}
} // namespace security
} // namespace kudu