blob: 1a2d2e89c5c060ac8c547d06ef114888a7f83590 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "kudu/rpc/rpc-test-base.h"
#include <string>
#include <boost/thread/thread.hpp>
#include <gtest/gtest.h>
#include <sasl/sasl.h>
#include "kudu/gutil/gscoped_ptr.h"
#include "kudu/gutil/map-util.h"
#include "kudu/rpc/constants.h"
#include "kudu/rpc/auth_store.h"
#include "kudu/rpc/sasl_client.h"
#include "kudu/rpc/sasl_common.h"
#include "kudu/rpc/sasl_server.h"
#include "kudu/util/monotime.h"
#include "kudu/util/net/sockaddr.h"
#include "kudu/util/net/socket.h"
using std::string;
namespace kudu {
namespace rpc {
// Fixture shared by every SASL negotiation test in this file. On top of the
// RpcTestBase setup it initializes the SASL library for kSaslAppName, so each
// test can construct SaslServer/SaslClient objects directly.
class TestSaslRpc : public RpcTestBase {
 public:
  void SetUp() OVERRIDE {
    // Base-class setup must run before SASL is initialized.
    RpcTestBase::SetUp();
    ASSERT_OK(SaslInit(kSaslAppName));
  }
};
// Smoke test: a SaslServer and a SaslClient, each constructed with fd -1
// (i.e. not attached to any socket), can both be initialized successfully.
TEST_F(TestSaslRpc, TestBasicInit) {
  SaslServer sasl_server(kSaslAppName, -1);
  ASSERT_OK(sasl_server.Init(kSaslAppName));

  SaslClient sasl_client(kSaslAppName, -1);
  ASSERT_OK(sasl_client.Init(kSaslAppName));
}
// A "Callable" that takes a Socket* param, for use with starting a thread.
// Can be used for SaslServer or SaslClient threads.
typedef void (*socket_callable_t)(Socket*);
// Server-thread body: waits for exactly one incoming connection on the
// listening socket 'acceptor', then runs 'server_runner' on the accepted
// connection. A failed Accept() aborts the process via CHECK_OK.
static void RunAcceptingDelegator(Socket* acceptor, socket_callable_t server_runner) {
  Socket accepted;
  Sockaddr peer_addr;
  CHECK_OK(acceptor->Accept(&accepted, &peer_addr, 0));
  server_runner(&accepted);
}
// Runs a SASL negotiation over a real socket pair.
//
// A server socket is bound and listened on a default-constructed Sockaddr
// (presumably an ephemeral port — confirm against Sockaddr's default ctor). A
// client socket then connects to the bound address. 'server_runner' runs on
// the accepted connection (via RunAcceptingDelegator) and 'client_runner' runs
// on the connecting socket, each in its own thread. Returns after both threads
// have been joined.
//
// Failure handling:
//  - Everything that can fail before a thread exists is checked with
//    ASSERT_OK, which reports the failure to gtest and returns.
//  - Once the server thread is running, it holds a pointer to 'server_sock'.
//    Returning early at that point would destroy the socket (and a joinable
//    boost::thread) out from under it. Later failures therefore use CHECK_OK,
//    which aborts instead of returning.
static void RunNegotiationTest(socket_callable_t server_runner, socket_callable_t client_runner) {
  Socket server_sock;
  ASSERT_OK(server_sock.Init(0));
  ASSERT_OK(server_sock.BindAndListen(Sockaddr(), 1));
  Sockaddr server_bind_addr;
  ASSERT_OK(server_sock.GetSocketAddress(&server_bind_addr));

  // Create the client socket before spawning any thread, so this failure can
  // still be reported with a non-aborting ASSERT.
  Socket client_sock;
  ASSERT_OK(client_sock.Init(0));

  boost::thread server(RunAcceptingDelegator, &server_sock, server_runner);

  // The server thread is now blocked in Accept() on 'server_sock'; we must not
  // return from this function without joining it.
  CHECK_OK(client_sock.Connect(server_bind_addr));
  boost::thread client(client_runner, &client_sock);

  LOG(INFO) << "Waiting for test threads to terminate...";
  client.join();
  LOG(INFO) << "Client thread terminated.";
  server.join();
  LOG(INFO) << "Server thread terminated.";
}
////////////////////////////////////////////////////////////////////////////////
// Server side of TestAnonNegotiation: enables the ANONYMOUS mechanism and
// negotiates over 'conn'. Any non-OK status aborts via CHECK_OK.
static void RunAnonNegotiationServer(Socket* conn) {
  SaslServer server(kSaslAppName, conn->GetFd());
  CHECK_OK(server.Init(kSaslAppName));
  CHECK_OK(server.EnableAnonymous());
  CHECK_OK(server.Negotiate());
}
// Client side of TestAnonNegotiation: enables the ANONYMOUS mechanism and
// negotiates over 'conn'. Any non-OK status aborts via CHECK_OK.
static void RunAnonNegotiationClient(Socket* conn) {
  SaslClient client(kSaslAppName, conn->GetFd());
  CHECK_OK(client.Init(kSaslAppName));
  CHECK_OK(client.EnableAnonymous());
  CHECK_OK(client.Negotiate());
}
// End-to-end ANONYMOUS negotiation between a SaslServer and a SaslClient
// talking over a connected socket pair.
TEST_F(TestSaslRpc, TestAnonNegotiation) {
  RunNegotiationTest(&RunAnonNegotiationServer, &RunAnonNegotiationClient);
}
////////////////////////////////////////////////////////////////////////////////
// Server side of TestPlainNegotiation:
//  - registers the single credential "danger"/"burrito" in an AuthStore;
//  - hands that store to the PLAIN mechanism and negotiates over 'conn';
//  - on success, CHECKs that client_features() contains
//    APPLICATION_FEATURE_FLAGS.
static void RunPlainNegotiationServer(Socket* conn) {
  SaslServer server(kSaslAppName, conn->GetFd());
  gscoped_ptr<AuthStore> store(new AuthStore());
  CHECK_OK(store->Add("danger", "burrito"));
  CHECK_OK(server.Init(kSaslAppName));
  // Ownership of the store moves into the server.
  CHECK_OK(server.EnablePlain(std::move(store)));
  CHECK_OK(server.Negotiate());
  const auto& features = server.client_features();
  CHECK(ContainsKey(features, APPLICATION_FEATURE_FLAGS));
}
// Client side of TestPlainNegotiation: authenticates with PLAIN as
// "danger"/"burrito". On success it CHECKs that server_features() contains
// APPLICATION_FEATURE_FLAGS.
static void RunPlainNegotiationClient(Socket* conn) {
  SaslClient client(kSaslAppName, conn->GetFd());
  CHECK_OK(client.Init(kSaslAppName));
  CHECK_OK(client.EnablePlain("danger", "burrito"));
  CHECK_OK(client.Negotiate());
  const auto& features = client.server_features();
  CHECK(ContainsKey(features, APPLICATION_FEATURE_FLAGS));
}
// End-to-end PLAIN negotiation over a socket pair in which the client's
// credentials match the server's AuthStore entry, so both sides succeed.
TEST_F(TestSaslRpc, TestPlainNegotiation) {
  RunNegotiationTest(&RunPlainNegotiationServer, &RunPlainNegotiationClient);
}
////////////////////////////////////////////////////////////////////////////////
// Server side of TestPlainFailingNegotiation. Its AuthStore knows only
// "danger"/"burrito", so negotiating with a client that presents a different
// user must end in NotAuthorized.
static void RunPlainFailingNegotiationServer(Socket* conn) {
  SaslServer server(kSaslAppName, conn->GetFd());
  gscoped_ptr<AuthStore> store(new AuthStore());
  CHECK_OK(store->Add("danger", "burrito"));
  CHECK_OK(server.Init(kSaslAppName));
  CHECK_OK(server.EnablePlain(std::move(store)));
  Status negotiation_status = server.Negotiate();
  ASSERT_TRUE(negotiation_status.IsNotAuthorized())
      << "Expected auth failure! Got: " << negotiation_status.ToString();
}
// Client side of TestPlainFailingNegotiation: attempts PLAIN as the user
// "unknown" and expects Negotiate() to report NotAuthorized.
static void RunPlainFailingNegotiationClient(Socket* conn) {
  SaslClient client(kSaslAppName, conn->GetFd());
  CHECK_OK(client.Init(kSaslAppName));
  CHECK_OK(client.EnablePlain("unknown", "burrito"));
  Status negotiation_status = client.Negotiate();
  ASSERT_TRUE(negotiation_status.IsNotAuthorized())
      << "Expected auth failure! Got: " << negotiation_status.ToString();
}
// Test that a PLAIN negotiation over a socket fails with NotAuthorized on both
// sides when the client presents a user ("unknown") that is not in the
// server's AuthStore (which only contains "danger").
TEST_F(TestSaslRpc, TestPlainFailingNegotiation) {
  RunNegotiationTest(RunPlainFailingNegotiationServer, RunPlainFailingNegotiationClient);
}
////////////////////////////////////////////////////////////////////////////////
// Server side of TestClientTimeout. It negotiates ANONYMOUS over 'conn'
// without setting a deadline, and expects Negotiate() to fail with a
// NetworkError once the client gives up and shuts the connection down.
static void RunTimeoutExpectingServer(Socket* conn) {
  SaslServer server(kSaslAppName, conn->GetFd());
  CHECK_OK(server.Init(kSaslAppName));
  CHECK_OK(server.EnableAnonymous());
  Status negotiation_status = server.Negotiate();
  ASSERT_TRUE(negotiation_status.IsNetworkError())
      << "Expected client to time out and close the connection. Got: "
      << negotiation_status.ToString();
}
// Client side of TestClientTimeout:
//  - negotiates ANONYMOUS with a deadline already 100ms in the past;
//  - expects Negotiate() to report TimedOut;
//  - then calls sock->Shutdown(true, true) so the server's Negotiate() sees the
//    connection go away.
//
// EXPECT_TRUE, not ASSERT_TRUE, is deliberate. A failing ASSERT_TRUE returns
// from this function immediately and would skip the Shutdown() below. The
// server side in TestClientTimeout (RunTimeoutExpectingServer) sets no
// deadline, so it would then block indefinitely. RunNegotiationTest would in
// turn hang in server.join(), because the client socket is only destroyed after
// that join. Recording a non-fatal failure makes the test fail instead of hang.
static void RunTimeoutNegotiationClient(Socket* sock) {
  SaslClient sasl_client(kSaslAppName, sock->GetFd());
  CHECK_OK(sasl_client.Init(kSaslAppName));
  CHECK_OK(sasl_client.EnableAnonymous());

  // A deadline in the past forces an immediate timeout.
  MonoTime deadline = MonoTime::Now(MonoTime::FINE);
  deadline.AddDelta(MonoDelta::FromMilliseconds(-100L));
  sasl_client.set_deadline(deadline);

  Status s = sasl_client.Negotiate();
  EXPECT_TRUE(s.IsTimedOut()) << "Expected timeout! Got: " << s.ToString();

  // Always unblock the peer, whether or not the expectation above held.
  CHECK_OK(sock->Shutdown(true, true));
}
// The client negotiates against an already-expired deadline and must report
// TimedOut. The server, which sees the client shut the socket down, must
// report NetworkError.
TEST_F(TestSaslRpc, TestClientTimeout) {
  RunNegotiationTest(&RunTimeoutExpectingServer, &RunTimeoutNegotiationClient);
}
////////////////////////////////////////////////////////////////////////////////
// Server side of TestServerTimeout:
//  - negotiates ANONYMOUS with a deadline already 100ms in the past;
//  - expects Negotiate() to report TimedOut;
//  - then closes 'sock' so the client's Negotiate() sees the connection go
//    away.
//
// EXPECT_TRUE, not ASSERT_TRUE, is deliberate. A failing ASSERT_TRUE would
// return before the explicit Close() below. The client side in
// TestServerTimeout (RunTimeoutExpectingClient) sets no deadline, so it would
// then depend on the Socket's owner eventually closing the connection. Keeping
// the Close() unconditional makes an unexpected status show up as a test
// failure rather than as a possible hang.
static void RunTimeoutNegotiationServer(Socket* sock) {
  SaslServer sasl_server(kSaslAppName, sock->GetFd());
  CHECK_OK(sasl_server.Init(kSaslAppName));
  CHECK_OK(sasl_server.EnableAnonymous());

  // A deadline in the past forces an immediate timeout.
  MonoTime deadline = MonoTime::Now(MonoTime::FINE);
  deadline.AddDelta(MonoDelta::FromMilliseconds(-100L));
  sasl_server.set_deadline(deadline);

  Status s = sasl_server.Negotiate();
  EXPECT_TRUE(s.IsTimedOut()) << "Expected timeout! Got: " << s.ToString();

  // Always unblock the peer, whether or not the expectation above held.
  CHECK_OK(sock->Close());
}
// Client side of TestServerTimeout. It negotiates ANONYMOUS over 'conn'
// without setting a deadline, and expects Negotiate() to fail with a
// NetworkError once the server times out and closes the connection.
static void RunTimeoutExpectingClient(Socket* conn) {
  SaslClient client(kSaslAppName, conn->GetFd());
  CHECK_OK(client.Init(kSaslAppName));
  CHECK_OK(client.EnableAnonymous());
  Status negotiation_status = client.Negotiate();
  ASSERT_TRUE(negotiation_status.IsNetworkError())
      << "Expected server to time out and close the connection. Got: "
      << negotiation_status.ToString();
}
// Mirror of TestClientTimeout. The server negotiates against an
// already-expired deadline and must report TimedOut. The client, which sees the
// server close the socket, must report NetworkError.
TEST_F(TestSaslRpc, TestServerTimeout) {
  RunNegotiationTest(&RunTimeoutNegotiationServer, &RunTimeoutExpectingClient);
}
////////////////////////////////////////////////////////////////////////////////
} // namespace rpc
} // namespace kudu