| /* |
| Licensed to the Apache Software Foundation (ASF) under one |
| or more contributor license agreements. See the NOTICE file |
| distributed with this work for additional information |
| regarding copyright ownership. The ASF licenses this file |
| to you under the Apache License, Version 2.0 (the |
| "License"); you may not use this file except in compliance |
| with the License. You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| */ |
| |
| #include "mockssandra.hpp" |
| |
| #include <assert.h> |
| #include <stdio.h> |
| |
| #include "control_connection.hpp" // For host queries |
| #include "memory.hpp" |
| #include "scoped_lock.hpp" |
| #include "tracing_data_handler.hpp" // For tracing query |
| #include "utils.hpp" |
| #include "uuids.hpp" |
| |
| #include <openssl/bio.h> |
| #include <openssl/dh.h> |
| #include <openssl/x509v3.h> |
| |
| #ifdef WIN32 |
| #include "winsock.h" |
| #endif |
| |
| using datastax::internal::bind_callback; |
| using datastax::internal::escape_id; |
| using datastax::internal::Map; |
| using datastax::internal::Memory; |
| using datastax::internal::OStringStream; |
| using datastax::internal::ScopedMutex; |
| using datastax::internal::trim; |
| using datastax::internal::core::UuidGen; |
| |
// Size, in bytes, of the stack buffers used to move data through the SSL
// memory BIOs.
#define SSL_BUF_SIZE 8192
// Version strings; presumably the versions the mock server advertises
// (TODO confirm against their users later in the file).
#define CASSANDRA_VERSION "3.11.4"
#define DSE_VERSION "6.7.1"
#define DSE_CASSANDRA_VERSION "4.0.0.671"

// Decide whether the SSL library can cap the protocol version directly.
// LibreSSL has to be detected explicitly because it also defines
// OPENSSL_VERSION_NUMBER (as 2.0.0).
#if defined(OPENSSL_VERSION_NUMBER) && !defined(LIBRESSL_VERSION_NUMBER)
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
#define SSL_CAN_SET_MAX_VERSION
#endif
#elif (LIBRESSL_VERSION_NUMBER >= 0x20302000L)
#define SSL_CAN_SET_MAX_VERSION
#endif

// Libraries that can cap the version also provide the generic TLS method;
// older ones only offer the SSLv23 method.
#ifdef SSL_CAN_SET_MAX_VERSION
#define SSL_SERVER_METHOD TLS_server_method
#else
#define SSL_SERVER_METHOD SSLv23_server_method
#endif
| |
| namespace mockssandra { |
| |
| namespace { |
| |
| template <class T> |
| struct FreeDeleterImpl {}; |
| |
| #define MAKE_DELETER(type, free_func) \ |
| template <> \ |
| struct FreeDeleterImpl<type> { \ |
| static void free(type* ptr) { free_func(ptr); } \ |
| }; |
| |
| MAKE_DELETER(BIO, BIO_free) |
| MAKE_DELETER(DH, DH_free) |
| MAKE_DELETER(EVP_PKEY, EVP_PKEY_free) |
| MAKE_DELETER(EVP_PKEY_CTX, EVP_PKEY_CTX_free) |
| MAKE_DELETER(X509, X509_free) |
| MAKE_DELETER(X509_REQ, X509_REQ_free) |
| MAKE_DELETER(X509_EXTENSION, X509_EXTENSION_free) |
| |
| template <class T> |
| struct FreeDeleter { |
| void operator()(T* ptr) const { FreeDeleterImpl<T>::free(ptr); } |
| }; |
| |
| template <class T> |
| class Scoped : public ScopedPtr<T, FreeDeleter<T> > { |
| public: |
| Scoped(T* ptr = NULL) |
| : ScopedPtr<T, FreeDeleter<T> >(ptr) {} |
| }; |
| |
| void print_ssl_error() { |
| unsigned long err = ERR_get_error(); |
| char buf[256]; |
| ERR_error_string_n(err, buf, sizeof(buf)); |
| fprintf(stderr, "%s\n", buf); |
| } |
| |
| X509* load_cert(const String& cert) { |
| X509* x509 = NULL; |
| Scoped<BIO> bio(BIO_new_mem_buf(const_cast<char*>(cert.c_str()), cert.length())); |
| if (PEM_read_bio_X509(bio.get(), &x509, NULL, NULL) == NULL) { |
| print_ssl_error(); |
| return NULL; |
| } |
| return x509; |
| } |
| |
| EVP_PKEY* load_private_key(const String& key) { |
| EVP_PKEY* pkey = NULL; |
| Scoped<BIO> bio(BIO_new_mem_buf(const_cast<char*>(key.c_str()), key.length())); |
| if (!PEM_read_bio_PrivateKey(bio.get(), &pkey, NULL, NULL)) { |
| print_ssl_error(); |
| return NULL; |
| } |
| return pkey; |
| } |
| |
| DH* dh_parameters() { |
| // Generated using the following command: `openssl dhparam -C 2048` |
| // Prime length of 2048 chosen to bypass client-side error: |
| // `SSL3_CHECK_CERT_AND_ALGORITHM:dh key too small` |
| |
| // Note: This is not generated, programmatically, using something like the following: |
| // `DH_generate_parameters_ex(dh, 2048, DH_GENERATOR_5, NULL)` |
| // because DH prime generation takes a *REALLY* long time. |
| static const char* dh_parameters_pem = |
| "-----BEGIN DH PARAMETERS-----\n" |
| "MIIBCAKCAQEAusYypYO7u8mHelHjpDuUy7hjBgPw/KS03iSRnP5SNMB6OxVFslXv\n" |
| "s6McqEf218Fqpzi18tWA7fq3fvlT+Nx1Tda+Za5C8o5niRYxHks5N+RfnnrFf7vn\n" |
| "0lxrzsXP6es08Ts/UGMsp1nEaCSd/gjDglPgjdC1V/KmBsbT+8IwpbzPPdir0/jA\n" |
| "r+DXssZRZl7JtymGHXPkXTSBhsqSHamfzGRnAQFWToKAinqAdhY7pN/8krwvRj04\n" |
| "VYp84xAy2M6mWWqUm/kokN9QjAiT/DZRxZK8VhY7O9+oATo7/YPCMd9Em417O13k\n" |
| "+F0o/8IMaQvpmtlAsLc2ZKwGqqG+HD2dOwIBAg==\n" |
| "-----END DH PARAMETERS-----"; |
| Scoped<BIO> bio(BIO_new_mem_buf(const_cast<char*>(dh_parameters_pem), |
| -1)); // Use null terminator for length |
| return PEM_read_bio_DHparams(bio.get(), NULL, NULL, NULL); |
| } |
| |
| } // namespace |
| |
| String Ssl::generate_key() { |
| Scoped<EVP_PKEY_CTX> pctx(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, NULL)); |
| |
| EVP_PKEY_keygen_init(pctx.get()); |
| EVP_PKEY_CTX_set_rsa_keygen_bits(pctx.get(), 2048); |
| |
| Scoped<EVP_PKEY> pkey; |
| { // Generate RSA key |
| EVP_PKEY* temp = NULL; |
| EVP_PKEY_keygen(pctx.get(), &temp); |
| pkey.reset(temp); |
| } |
| |
| Scoped<BIO> bio(BIO_new(BIO_s_mem())); |
| PEM_write_bio_PrivateKey(bio.get(), pkey.get(), NULL, NULL, 0, NULL, NULL); |
| |
| BUF_MEM* mem = NULL; |
| BIO_get_mem_ptr(bio.get(), &mem); |
| return String(mem->data, mem->length); |
| } |
| |
| String Ssl::generate_cert(const String& key, String cn, String ca_cert, String ca_key) { |
| // Assign the proper default hostname |
| if (cn.empty()) { |
| #ifdef WIN32 |
| char win_hostname[64]; |
| gethostname(win_hostname, 64); |
| cn = win_hostname; |
| #else |
| cn = "localhost"; |
| #endif |
| } |
| |
| Scoped<EVP_PKEY> pkey(load_private_key(key)); |
| if (!pkey) return ""; |
| |
| Scoped<X509_REQ> x509_req; |
| if (!ca_cert.empty() && !ca_key.empty()) { |
| x509_req.reset(X509_REQ_new()); |
| X509_REQ_set_version(x509_req.get(), 2); |
| X509_REQ_set_pubkey(x509_req.get(), pkey.get()); |
| |
| X509_NAME* name = X509_REQ_get_subject_name(x509_req.get()); |
| X509_NAME_add_entry_by_txt(name, "C", MBSTRING_ASC, |
| reinterpret_cast<const unsigned char*>("US"), -1, -1, 0); |
| X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC, |
| reinterpret_cast<const unsigned char*>(cn.c_str()), -1, -1, 0); |
| X509_REQ_sign(x509_req.get(), pkey.get(), EVP_sha256()); |
| } |
| |
| Scoped<X509> x509(X509_new()); |
| X509_set_version(x509.get(), 2); |
| ASN1_INTEGER_set(X509_get_serialNumber(x509.get()), 0); |
| X509_gmtime_adj(X509_get_notBefore(x509.get()), 0); |
| X509_gmtime_adj(X509_get_notAfter(x509.get()), static_cast<long>(60 * 60 * 24 * 365)); |
| X509_set_pubkey(x509.get(), pkey.get()); |
| |
| if (x509_req) { |
| X509_set_subject_name(x509.get(), X509_REQ_get_subject_name(x509_req.get())); |
| |
| Scoped<X509> x509_ca(load_cert(ca_cert)); |
| if (!x509_ca) return ""; |
| X509_set_issuer_name(x509.get(), X509_get_issuer_name(x509_ca.get())); |
| |
| Scoped<EVP_PKEY> pkey_ca(load_private_key(ca_key)); |
| if (!pkey_ca) return ""; |
| X509_sign(x509.get(), pkey_ca.get(), EVP_sha256()); |
| } else { |
| if (cn == "CA") { // Set the purpose as a CA certificate. |
| X509V3_CTX x509v3_ctx; |
| X509V3_set_ctx_nodb(&x509v3_ctx); |
| X509V3_set_ctx(&x509v3_ctx, x509.get(), x509.get(), NULL, NULL, 0); |
| |
| Scoped<X509_EXTENSION> x509_ex(X509V3_EXT_conf_nid(NULL, &x509v3_ctx, NID_basic_constraints, |
| const_cast<char*>("critical,CA:TRUE"))); |
| if (!x509_ex) return ""; |
| |
| X509_add_ext(x509.get(), x509_ex.get(), -1); |
| } |
| X509_NAME* name = X509_get_subject_name(x509.get()); |
| X509_NAME_add_entry_by_txt(name, "C", MBSTRING_ASC, |
| reinterpret_cast<const unsigned char*>("US"), -1, -1, 0); |
| X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC, |
| reinterpret_cast<const unsigned char*>(cn.c_str()), -1, -1, 0); |
| X509_set_issuer_name(x509.get(), name); |
| X509_sign(x509.get(), pkey.get(), EVP_sha256()); |
| } |
| |
| String result; |
| { // Write cert into string |
| Scoped<BIO> bio(BIO_new(BIO_s_mem())); |
| PEM_write_bio_X509(bio.get(), x509.get()); |
| BUF_MEM* mem = NULL; |
| BIO_get_mem_ptr(bio.get(), &mem); |
| result.append(mem->data, mem->length); |
| } |
| |
| return result; |
| } |
| |
| namespace internal { |
| |
| struct WriteReq { |
| WriteReq(const char* data, size_t len, ClientConnection* connection) |
| : data(data, len) |
| , connection(connection) { |
| req.data = this; |
| } |
| const String data; |
| ClientConnection* const connection; |
| uv_write_t req; |
| }; |
| |
| Tcp::Tcp(void* data) { tcp_.data = data; } |
| |
| int Tcp::init(uv_loop_t* loop) { return uv_tcp_init(loop, &tcp_); } |
| |
| int Tcp::bind(const struct sockaddr* addr) { return uv_tcp_bind(&tcp_, addr, 0); } |
| |
| uv_handle_t* Tcp::as_handle() { return reinterpret_cast<uv_handle_t*>(&tcp_); } |
| |
| uv_stream_t* Tcp::as_stream() { return reinterpret_cast<uv_stream_t*>(&tcp_); } |
| |
| ClientConnection::ClientConnection(ServerConnection* server) |
| : tcp_(this) |
| , server_(server) |
| , ssl_(server->ssl_context() ? SSL_new(server->ssl_context()) : NULL) |
| , incoming_bio_(ssl_ ? BIO_new(BIO_s_mem()) : NULL) |
| , outgoing_bio_(ssl_ ? BIO_new(BIO_s_mem()) : NULL) { |
| tcp_.init(server->loop()); |
| if (ssl_) { |
| SSL_set_bio(ssl_, incoming_bio_, outgoing_bio_); |
| } |
| } |
| |
| ClientConnection::~ClientConnection() { |
| if (ssl_) SSL_free(ssl_); |
| } |
| |
| int ClientConnection::write(const String& data) { return write(data.data(), data.length()); } |
| |
| int ClientConnection::write(const char* data, size_t len) { |
| if (ssl_) { |
| return ssl_write(data, len); |
| } else { |
| return internal_write(data, len); |
| } |
| } |
| |
| void ClientConnection::close() { |
| if (!uv_is_closing(tcp_.as_handle())) { |
| uv_close(tcp_.as_handle(), on_close); |
| } |
| } |
| |
| int ClientConnection::accept() { |
| int rc = server_->accept(tcp_.as_stream()); |
| if (rc != 0) return rc; |
| return uv_read_start(tcp_.as_stream(), on_alloc, on_read); |
| } |
| |
| const char* ClientConnection::sni_server_name() const { |
| if (ssl_) { |
| return SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name); |
| } |
| return NULL; |
| } |
| |
| void ClientConnection::on_close(uv_handle_t* handle) { |
| ClientConnection* connection = static_cast<ClientConnection*>(handle->data); |
| connection->handle_close(); |
| } |
| |
// Runs after the TCP handle has been closed. The connection destroys itself
// here, so it must not be touched once this returns.
void ClientConnection::handle_close() {
  on_close();            // No-argument overload; presumably the hook for derived connections
  server_->remove(this); // Unregister from the owning server
  delete this;           // Must be last: the object is gone after this line
}
| |
| void ClientConnection::on_alloc(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) { |
| buf->base = static_cast<char*>(Memory::malloc(suggested_size)); |
| buf->len = suggested_size; |
| } |
| |
| void ClientConnection::on_read(uv_stream_t* stream, ssize_t nread, const uv_buf_t* buf) { |
| ClientConnection* connection = static_cast<ClientConnection*>(stream->data); |
| connection->handle_read(nread, buf); |
| Memory::free(buf->base); |
| } |
| |
| void ClientConnection::handle_read(ssize_t nread, const uv_buf_t* buf) { |
| if (nread < 0) { |
| if (nread != UV_EOF && nread != UV_ECONNRESET) { |
| fprintf(stderr, "Read failure: %s\n", uv_strerror(nread)); |
| } |
| close(); |
| return; |
| } |
| if (ssl_) { |
| on_ssl_read(buf->base, nread); |
| } else { |
| on_read(buf->base, nread); |
| } |
| } |
| |
| void ClientConnection::on_write(uv_write_t* req, int status) { |
| WriteReq* write = static_cast<WriteReq*>(req->data); |
| write->connection->handle_write(status); |
| delete write; |
| } |
| |
| void ClientConnection::handle_write(int status) { |
| if (status != 0) { |
| fprintf(stderr, "Write failure: %s\n", uv_strerror(status)); |
| close(); |
| return; |
| } |
| |
| on_write(); |
| } |
| |
| int ClientConnection::internal_write(const char* data, size_t len) { |
| uv_buf_t buf; |
| WriteReq* write = new WriteReq(data, len, this); |
| buf.base = const_cast<char*>(write->data.data()); |
| buf.len = write->data.length(); |
| int rc = uv_write(&write->req, tcp_.as_stream(), &buf, 1, on_write); |
| if (rc != 0) { |
| delete write; |
| } |
| return rc; |
| } |
| |
| int ClientConnection::ssl_write(const char* data, size_t len) { |
| if (has_ssl_error(SSL_write(ssl_, data, len))) { |
| return -1; |
| } |
| |
| char buf[SSL_BUF_SIZE]; |
| int num_bytes; |
| while ((num_bytes = BIO_read(outgoing_bio_, buf, sizeof(buf))) > 0) { |
| int rc = internal_write(buf, num_bytes); |
| if (rc != 0) { |
| return rc; |
| } |
| } |
| |
| return 0; |
| } |
| |
| bool ClientConnection::is_handshake_done() { return SSL_is_init_finished(ssl_) != 0; } |
| |
| bool ClientConnection::has_ssl_error(int rc) { |
| if (rc > 0) return false; |
| |
| int err = SSL_get_error(ssl_, rc); |
| if (err == SSL_ERROR_ZERO_RETURN) { |
| close(); |
| } else if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_NONE) { |
| const char* data; |
| int flags; |
| int err; |
| String error; |
| while ((err = ERR_get_error_line_data(NULL, NULL, &data, &flags)) != 0) { |
| char buf[256]; |
| ERR_error_string_n(err, buf, sizeof(buf)); |
| if (!error.empty()) error.push_back(','); |
| error.append(buf); |
| if (flags & ERR_TXT_STRING) { |
| error.push_back(':'); |
| error.append(data); |
| } |
| } |
| fprintf(stderr, "SSL error: %s\n", error.c_str()); |
| close(); |
| return true; |
| } |
| |
| return false; |
| } |
| |
| void ClientConnection::on_ssl_read(const char* data, size_t len) { |
| int rc; |
| BIO_write(incoming_bio_, data, len); |
| |
| if (!is_handshake_done()) { |
| int rc = SSL_accept(ssl_); |
| if (has_ssl_error(rc)) { |
| return; |
| } |
| |
| char buf[SSL_BUF_SIZE]; |
| bool data_written = false; |
| int num_bytes; |
| while ((num_bytes = BIO_read(outgoing_bio_, buf, sizeof(buf))) > 0) { |
| data_written = true; |
| internal_write(buf, num_bytes); |
| } |
| |
| if (is_handshake_done() && data_written) { |
| return; // Handshake is not completed; ingore remaining data |
| } |
| } else { |
| char buf[SSL_BUF_SIZE]; |
| while ((rc = SSL_read(ssl_, buf, sizeof(buf))) > 0) { |
| on_read(buf, rc); |
| } |
| has_ssl_error(rc); |
| } |
| } |
| |
// Creates a closed server for `address`; `factory` is stored for creating the
// client connections. Nothing is bound or listened on here.
ServerConnection::ServerConnection(const Address& address, const ClientConnectionFactory& factory)
    : tcp_(this)
    , event_loop_(NULL)
    , state_(STATE_CLOSED)
    , rc_(0)
    , address_(address)
    , factory_(factory)
    , ssl_context_(NULL)
    , connection_attempts_(0) {
  // The mutex and condition variable guard and signal state changes.
  uv_mutex_init(&mutex_);
  uv_cond_init(&cond_);
}
| |
| ServerConnection::~ServerConnection() { |
| uv_mutex_destroy(&mutex_); |
| uv_cond_destroy(&cond_); |
| if (ssl_context_) { |
| SSL_CTX_free(ssl_context_); |
| } |
| } |
| |
| uv_loop_t* ServerConnection::loop() { |
| ScopedMutex l(&mutex_); |
| return event_loop_->loop(); |
| } |
| |
| bool ServerConnection::use_ssl(const String& key, const String& cert, |
| const String& ca_cert /*= ""*/, |
| bool require_client_cert /*= false*/) { |
| if (ssl_context_) { |
| SSL_CTX_free(ssl_context_); |
| } |
| |
| if ((ssl_context_ = SSL_CTX_new(SSL_SERVER_METHOD())) == NULL) { |
| print_ssl_error(); |
| return false; |
| } |
| |
| SSL_CTX_set_default_passwd_cb_userdata(ssl_context_, (void*)""); |
| SSL_CTX_set_default_passwd_cb(ssl_context_, on_password); |
| SSL_CTX_set_verify(ssl_context_, SSL_VERIFY_NONE, NULL); |
| |
| { // Load server certificate |
| Scoped<X509> x509(load_cert(cert)); |
| if (!x509) return false; |
| if (SSL_CTX_use_certificate(ssl_context_, x509.get()) <= 0) { |
| print_ssl_error(); |
| return false; |
| } |
| } |
| |
| if (!ca_cert.empty()) { // Load CA certificate |
| |
| { // Add CA certificate to chain to send to the client |
| Scoped<X509> x509(load_cert(ca_cert)); |
| if (!x509) return false; |
| |
| if (SSL_CTX_add_extra_chain_cert(ssl_context_, x509.release()) <= |
| 0) { // Certificate freed by function |
| print_ssl_error(); |
| return false; |
| } |
| } |
| |
| if (require_client_cert) { |
| Scoped<X509> x509(load_cert(ca_cert)); |
| if (!x509) return false; |
| |
| // Add CA certificate to chain to validate peer certificate |
| X509_STORE* cert_store = SSL_CTX_get_cert_store(ssl_context_); |
| if (X509_STORE_add_cert(cert_store, x509.get()) <= 0) { |
| print_ssl_error(); |
| return false; |
| } |
| |
| SSL_CTX_set_verify(ssl_context_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL); |
| } |
| } |
| |
| Scoped<EVP_PKEY> pkey(load_private_key(key)); |
| if (!pkey) return false; |
| |
| if (SSL_CTX_use_PrivateKey(ssl_context_, pkey.get()) <= 0) { |
| print_ssl_error(); |
| return false; |
| } |
| |
| Scoped<DH> dh(dh_parameters()); |
| if (!dh || !SSL_CTX_set_tmp_dh(ssl_context_, dh.get())) { |
| print_ssl_error(); |
| return false; |
| } |
| |
| return true; |
| } |
| |
| // Weaken the SSL connection, enforcing that it can only use TLS1.0 at max. |
| // This is used for testing client-side enforcement of more secure TLS |
| // protocols. |
| void ServerConnection::weaken_ssl() { |
| if (!ssl_context_) { |
| return; |
| } |
| |
| #ifdef SSL_CAN_SET_MAX_VERSION |
| SSL_CTX_set_max_proto_version(ssl_context_, TLS1_VERSION); |
| #else |
| SSL_CTX_set_options(ssl_context_, SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1_2); |
| #endif |
| } |
| |
| using datastax::internal::core::Task; |
| |
| class RunListen : public Task { |
| public: |
| RunListen(ServerConnection* server) |
| : server_(server) {} |
| |
| virtual void run(EventLoop* event_loop) { server_->internal_listen(); } |
| |
| private: |
| ServerConnection* server_; |
| }; |
| |
| class RunClose : public Task { |
| public: |
| RunClose(ServerConnection* server) |
| : server_(server) {} |
| |
| virtual void run(EventLoop* event_loop) { server_->internal_close(); } |
| |
| private: |
| ServerConnection* server_; |
| }; |
| |
| class RunTask : public Task { |
| public: |
| RunTask(const ServerConnectionTask::Ptr& task, const ServerConnection::Ptr& connection) |
| : task_(task) |
| , connection_(connection) {} |
| |
| virtual void run(EventLoop* event_loop) { task_->run(connection_.get()); } |
| |
| private: |
| ServerConnectionTask::Ptr task_; |
| ServerConnection::Ptr connection_; |
| }; |
| |
| void ServerConnection::listen(EventLoopGroup* event_loop_group) { |
| ScopedMutex l(&mutex_); |
| if (state_ != STATE_CLOSED) return; |
| rc_ = 0; |
| state_ = STATE_PENDING; |
| event_loop_ = event_loop_group->add(new RunListen(this)); |
| } |
| |
| int ServerConnection::wait_listen() { |
| ScopedMutex l(&mutex_); |
| while (state_ == STATE_PENDING) { |
| uv_cond_wait(&cond_, l.get()); |
| } |
| return rc_; |
| } |
| |
| void ServerConnection::close() { |
| ScopedMutex l(&mutex_); |
| if (state_ != STATE_LISTENING && state_ != STATE_PENDING) return; |
| state_ = STATE_CLOSING; |
| event_loop_->add(new RunClose(this)); |
| } |
| |
| void ServerConnection::wait_close() { |
| ScopedMutex l(&mutex_); |
| while (state_ == STATE_CLOSING) { |
| uv_cond_wait(&cond_, l.get()); |
| } |
| } |
| |
| unsigned ServerConnection::connection_attempts() const { return connection_attempts_.load(); } |
| void ServerConnection::run(const ServerConnectionTask::Ptr& task) { |
| ScopedMutex l(&mutex_); |
| if (state_ != STATE_LISTENING) return; |
| event_loop_->add(new RunTask(task, Ptr(this))); |
| } |
| |
| void ServerConnection::internal_listen() { |
| int rc = 0; |
| |
| rc = tcp_.init(loop()); |
| if (rc != 0) { |
| fprintf(stderr, "Unable to initialize socket\n"); |
| signal_listen(rc); |
| return; |
| } |
| |
| inc_ref(); // For the TCP handle |
| |
| Address::SocketStorage storage; |
| rc = tcp_.bind(address_.to_sockaddr(&storage)); |
| if (rc != 0) { |
| fprintf(stderr, "Unable to bind address %s\n", address_.to_string(true).c_str()); |
| uv_close(tcp_.as_handle(), on_close); |
| signal_listen(rc); |
| return; |
| } |
| |
| rc = uv_listen(tcp_.as_stream(), 128, on_connection); |
| if (rc != 0) { |
| fprintf(stderr, "Unable to listen on address %s\n", address_.to_string(true).c_str()); |
| uv_close(tcp_.as_handle(), on_close); |
| signal_listen(rc); |
| return; |
| } |
| |
| signal_listen(rc); |
| } |
| |
| int ServerConnection::accept(uv_stream_t* client) { return uv_accept(tcp_.as_stream(), client); } |
| |
| void ServerConnection::remove(ClientConnection* connection) { |
| clients_.erase(std::remove(clients_.begin(), clients_.end(), connection), clients_.end()); |
| maybe_close(); |
| } |
| |
| void ServerConnection::internal_close() { |
| for (ClientConnections::iterator it = clients_.begin(), end = clients_.end(); it != end; ++it) { |
| (*it)->close(); |
| } |
| maybe_close(); |
| } |
| |
| void ServerConnection::maybe_close() { |
| ScopedMutex l(&mutex_); |
| if (state_ == STATE_CLOSING && clients_.empty() && !uv_is_closing(tcp_.as_handle())) { |
| uv_close(tcp_.as_handle(), on_close); |
| } |
| } |
| |
| void ServerConnection::signal_listen(int rc) { |
| ScopedMutex l(&mutex_); |
| if (rc != 0) { |
| rc_ = rc; |
| state_ = STATE_CLOSING; |
| } else { |
| state_ = STATE_LISTENING; |
| } |
| uv_cond_signal(&cond_); |
| } |
| |
| void ServerConnection::signal_close() { |
| ScopedMutex l(&mutex_); |
| event_loop_ = NULL; |
| state_ = STATE_CLOSED; |
| uv_cond_signal(&cond_); |
| } |
| |
| void ServerConnection::on_connection(uv_stream_t* server, int status) { |
| ServerConnection* self = static_cast<ServerConnection*>(server->data); |
| self->handle_connection(status); |
| } |
| |
| void ServerConnection::handle_connection(int status) { |
| connection_attempts_.fetch_add(1); |
| |
| if (status != 0) { |
| fprintf(stderr, "Listen failure: %s\n", uv_strerror(status)); |
| return; |
| } |
| |
| ClientConnection* connection = factory_.create(this); |
| if (connection && connection->on_accept() != 0) { |
| delete connection; |
| return; |
| } |
| clients_.push_back(connection); |
| } |
| |
| void ServerConnection::on_close(uv_handle_t* handle) { |
| ServerConnection* self = static_cast<ServerConnection*>(handle->data); |
| self->handle_close(); |
| } |
| |
| void ServerConnection::handle_close() { |
| signal_close(); |
| dec_ref(); |
| } |
| |
| int ServerConnection::on_password(char* buf, int size, int rwflag, void* password) { |
| strncpy(buf, (char*)(password), size); |
| buf[size - 1] = '\0'; |
| return strlen(buf); |
| } |
| |
| } // namespace internal |
| |
// Bounds check used by the decoders. It expects a `const char* end` in the
// enclosing scope. If `pos` lies past `end`, it prints `error` and makes the
// enclosing function return the sentinel `end + 1`.
#define CHECK(pos, error)                               \
  do {                                                  \
    if (end < (pos)) {                                  \
      fprintf(stderr, "Decoding error: %s\n", (error)); \
      return end + 1;                                   \
    }                                                   \
  } while (0)
| |
| const char* decode_int8(const char* input, const char* end, int8_t* value) { |
| CHECK(input + 1, "Unable to decode byte"); |
| *value = static_cast<int8_t>(input[0]); |
| return input + sizeof(int8_t); |
| } |
| |
| const char* decode_int16(const char* input, const char* end, int16_t* value) { |
| CHECK(input + 2, "Unable to decode signed short"); |
| *value = (static_cast<int16_t>(static_cast<uint8_t>(input[1])) << 0) | |
| (static_cast<int16_t>(static_cast<uint8_t>(input[0])) << 8); |
| return input + sizeof(int16_t); |
| } |
| |
| const char* decode_uint16(const char* input, const char* end, uint16_t* value) { |
| CHECK(input + 2, "Unable to decode unsigned short"); |
| *value = (static_cast<uint16_t>(static_cast<uint8_t>(input[1])) << 0) | |
| (static_cast<uint16_t>(static_cast<uint8_t>(input[0])) << 8); |
| return input + sizeof(uint16_t); |
| } |
| |
| const char* decode_int32(const char* input, const char* end, int32_t* value) { |
| CHECK(input + 4, "Unable to decode integer"); |
| *value = (static_cast<int32_t>(static_cast<uint8_t>(input[3])) << 0) | |
| (static_cast<int32_t>(static_cast<uint8_t>(input[2])) << 8) | |
| (static_cast<int32_t>(static_cast<uint8_t>(input[1])) << 16) | |
| (static_cast<int32_t>(static_cast<uint8_t>(input[0])) << 24); |
| return input + sizeof(int32_t); |
| } |
| |
| const char* decode_int64(const char* input, const char* end, int64_t* value) { |
| CHECK(input + 8, "Unable to decode long"); |
| *value = (static_cast<int64_t>(static_cast<uint8_t>(input[7])) << 0) | |
| (static_cast<int64_t>(static_cast<uint8_t>(input[6])) << 8) | |
| (static_cast<int64_t>(static_cast<uint8_t>(input[5])) << 16) | |
| (static_cast<int64_t>(static_cast<uint8_t>(input[4])) << 24) | |
| (static_cast<int64_t>(static_cast<uint8_t>(input[3])) << 32) | |
| (static_cast<int64_t>(static_cast<uint8_t>(input[2])) << 40) | |
| (static_cast<int64_t>(static_cast<uint8_t>(input[1])) << 48) | |
| (static_cast<int64_t>(static_cast<uint8_t>(input[0])) << 56); |
| return input + sizeof(int64_t); |
| } |
| |
| const char* decode_string(const char* input, const char* end, String* output) { |
| uint16_t len = 0; |
| const char* pos = decode_uint16(input, end, &len); |
| CHECK(pos + len, "Unable to decode string"); |
| output->assign(pos, len); |
| return pos + len; |
| } |
| |
| const char* decode_long_string(const char* input, const char* end, String* output) { |
| int32_t len = 0; |
| const char* pos = decode_int32(input, end, &len); |
| CHECK(pos + len, "Unable to decode long string"); |
| assert(len >= 0); |
| output->assign(pos, len); |
| return pos + len; |
| } |
| |
| const char* decode_bytes(const char* input, const char* end, String* output) { |
| int32_t len = 0; |
| const char* pos = decode_int32(input, end, &len); |
| if (len > 0) { |
| CHECK(pos + len, "Unable to decode bytes"); |
| output->assign(pos, len); |
| } |
| return pos + len; |
| } |
| |
| const char* decode_uuid(const char* input, CassUuid* output) { |
| output->time_and_version = static_cast<uint64_t>(static_cast<uint8_t>(input[3])); |
| output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[2])) << 8; |
| output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[1])) << 16; |
| output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[0])) << 24; |
| |
| output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[5])) << 32; |
| output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[4])) << 40; |
| |
| output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[7])) << 48; |
| output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[6])) << 56; |
| |
| output->clock_seq_and_node = 0; |
| for (size_t i = 0; i < 8; ++i) { |
| output->clock_seq_and_node |= static_cast<uint64_t>(static_cast<uint8_t>(input[15 - i])) |
| << (8 * i); |
| } |
| return input + 16; |
| } |
| |
| const char* decode_string_map(const char* input, const char* end, |
| Vector<std::pair<String, String> >* output) { |
| |
| uint16_t len = 0; |
| const char* pos = decode_uint16(input, end, &len); |
| output->reserve(len); |
| for (int i = 0; i < len; ++i) { |
| String key; |
| String value; |
| pos = decode_string(pos, end, &key); |
| pos = decode_string(pos, end, &value); |
| output->push_back(std::pair<String, String>(key, value)); |
| } |
| return pos; |
| } |
| |
| const char* decode_stringlist(const char* input, const char* end, Vector<String>* output) { |
| uint16_t len = 0; |
| const char* pos = decode_uint16(input, end, &len); |
| output->reserve(len); |
| for (int i = 0; i < len; ++i) { |
| String value; |
| pos = decode_string(pos, end, &value); |
| output->push_back(value); |
| } |
| return pos; |
| } |
| |
| const char* decode_values(const char* input, const char* end, Vector<String>* output) { |
| uint16_t len = 0; |
| const char* pos = decode_uint16(input, end, &len); |
| output->reserve(len); |
| for (int i = 0; i < len; ++i) { |
| String value; |
| pos = decode_bytes(pos, end, &value); |
| output->push_back(value); |
| } |
| return pos; |
| } |
| |
| const char* decode_values_with_names(const char* input, const char* end, Vector<String>* names, |
| Vector<String>* values) { |
| uint16_t len = 0; |
| const char* pos = decode_uint16(input, end, &len); |
| names->reserve(len); |
| values->reserve(len); |
| for (int i = 0; i < len; ++i) { |
| String name; |
| pos = decode_string(pos, end, &name); |
| names->push_back(name); |
| String value; |
| pos = decode_bytes(pos, end, &value); |
| values->push_back(value); |
| } |
| return pos; |
| } |
| |
| const char* decode_query_params_v1(const char* input, const char* end, bool is_execute, |
| QueryParameters* params) { |
| const char* pos = input; |
| if (is_execute) { |
| pos = decode_values(pos, end, ¶ms->values); |
| pos = decode_uint16(pos, end, ¶ms->consistency); |
| } else { |
| pos = decode_uint16(pos, end, ¶ms->consistency); |
| } |
| return pos; |
| } |
| |
| const char* decode_query_params_v2(const char* input, const char* end, QueryParameters* params) { |
| int8_t flags = 0; |
| const char* pos = input; |
| pos = decode_uint16(pos, end, ¶ms->consistency); |
| pos = decode_int8(pos, end, &flags); |
| params->flags = flags; |
| if (flags & QUERY_FLAG_VALUES) { |
| pos = decode_values(pos, end, ¶ms->values); |
| } |
| if (flags & QUERY_FLAG_PAGE_SIZE) { |
| pos = decode_int32(pos, end, ¶ms->result_page_size); |
| } |
| if (flags & QUERY_FLAG_PAGE_STATE) { |
| pos = decode_bytes(pos, end, ¶ms->paging_state); |
| } |
| if (flags & QUERY_FLAG_SERIAL_CONSISTENCY) { |
| pos = decode_uint16(pos, end, ¶ms->serial_consistency); |
| } |
| return pos; |
| } |
| |
| const char* decode_query_params_v3v4(const char* input, const char* end, QueryParameters* params) { |
| int8_t flags = 0; |
| const char* pos = input; |
| pos = decode_uint16(pos, end, ¶ms->consistency); |
| pos = decode_int8(pos, end, &flags); |
| params->flags = flags; |
| if (flags & QUERY_FLAG_VALUES && flags & QUERY_FLAG_NAMES_FOR_VALUES) { |
| pos = decode_values_with_names(pos, end, ¶ms->names, ¶ms->values); |
| } else if (flags & QUERY_FLAG_VALUES) { |
| pos = decode_values(pos, end, ¶ms->values); |
| } |
| if (flags & QUERY_FLAG_PAGE_SIZE) { |
| pos = decode_int32(pos, end, ¶ms->result_page_size); |
| } |
| if (flags & QUERY_FLAG_PAGE_STATE) { |
| pos = decode_bytes(pos, end, ¶ms->paging_state); |
| } |
| if (flags & QUERY_FLAG_SERIAL_CONSISTENCY) { |
| pos = decode_uint16(pos, end, ¶ms->serial_consistency); |
| } |
| if (flags & QUERY_FLAG_TIMESTAMP) { |
| pos = decode_int64(pos, end, ¶ms->timestamp); |
| } |
| return pos; |
| } |
| |
| const char* decode_query_params_v5(const char* input, const char* end, QueryParameters* params) { |
| int32_t flags = 0; |
| const char* pos = input; |
| pos = decode_uint16(pos, end, ¶ms->consistency); |
| pos = decode_int32(pos, end, &flags); |
| params->flags = flags; |
| if (flags & QUERY_FLAG_VALUES && flags & QUERY_FLAG_NAMES_FOR_VALUES) { |
| pos = decode_values_with_names(pos, end, ¶ms->names, ¶ms->values); |
| } else if (flags & QUERY_FLAG_VALUES) { |
| pos = decode_values(pos, end, ¶ms->values); |
| } |
| if (flags & QUERY_FLAG_PAGE_SIZE) { |
| pos = decode_int32(pos, end, ¶ms->result_page_size); |
| } |
| if (flags & QUERY_FLAG_PAGE_STATE) { |
| pos = decode_bytes(pos, end, ¶ms->paging_state); |
| } |
| if (flags & QUERY_FLAG_SERIAL_CONSISTENCY) { |
| pos = decode_uint16(pos, end, ¶ms->serial_consistency); |
| } |
| if (flags & QUERY_FLAG_TIMESTAMP) { |
| pos = decode_int64(pos, end, ¶ms->timestamp); |
| } |
| if (flags & QUERY_FLAG_KEYSPACE) { |
| pos = decode_string(pos, end, ¶ms->keyspace); |
| } |
| return pos; |
| } |
| |
| const char* decode_query_params(int version, const char* input, const char* end, bool is_execute, |
| QueryParameters* params) { |
| const char* pos = input; |
| if (version == 1) { |
| pos = decode_query_params_v1(pos, end, is_execute, params); |
| } else if (version == 2) { |
| pos = decode_query_params_v2(pos, end, params); |
| } else if (version == 3 || version == 4) { |
| pos = decode_query_params_v3v4(pos, end, params); |
| } else if (version == 5) { |
| pos = decode_query_params_v5(pos, end, params); |
| } else { |
| assert(false && "Unsupported protocol version"); |
| } |
| return pos; |
| } |
| |
| const char* decode_prepare_params(int version, const char* input, const char* end, |
| PrepareParameters* params) { |
| const char* pos = input; |
| if (version >= 5) { |
| pos = decode_int32(pos, end, ¶ms->flags); |
| if (params->flags & PREPARE_FLAGS_KEYSPACE) { |
| pos = decode_string(pos, end, ¶ms->keyspace); |
| } |
| } |
| return pos; |
| } |
| |
| int32_t encode_int8(int8_t value, String* output) { |
| output->push_back(static_cast<char>(value)); |
| return 1; |
| } |
| |
| int32_t encode_int16(int16_t value, String* output) { |
| output->push_back(static_cast<char>(value >> 8)); |
| output->push_back(static_cast<char>(value >> 0)); |
| return 2; |
| } |
| |
| int32_t encode_uint16(uint16_t value, String* output) { |
| output->push_back(static_cast<char>(value >> 8)); |
| output->push_back(static_cast<char>(value >> 0)); |
| return 2; |
| } |
| |
| int32_t encode_int32(int32_t value, String* output) { |
| output->push_back(static_cast<char>(value >> 24)); |
| output->push_back(static_cast<char>(value >> 16)); |
| output->push_back(static_cast<char>(value >> 8)); |
| output->push_back(static_cast<char>(value >> 0)); |
| return 4; |
| } |
| |
| int32_t encode_string(const String& value, String* output) { |
| int32_t size = encode_uint16(value.size(), output) + value.size(); |
| output->append(value); |
| return size + value.size(); |
| } |
| |
| int32_t encode_string_list(const Vector<String>& value, String* output) { |
| int32_t size = encode_int16(value.size(), output); |
| for (Vector<String>::const_iterator it = value.begin(), end = value.end(); it != end; ++it) { |
| size += encode_string(*it, output); |
| } |
| return size; |
| } |
| |
| int32_t encode_bytes(const String& value, String* output) { |
| int32_t size = encode_int32(value.size(), output) + value.size(); |
| output->append(value); |
| return size + value.size(); |
| } |
| |
| int32_t encode_inet(const Address& value, String* output) { |
| uint8_t buf[16]; |
| uint8_t len = value.to_inet(buf); |
| encode_int8(len, output); |
| for (uint8_t i = 0; i < len; ++i) { |
| output->push_back(static_cast<char>(buf[i])); |
| } |
| encode_int32(value.port(), output); |
| return 1 + len + 4; |
| } |
| |
| int32_t encode_uuid(CassUuid uuid, String* output) { |
| uint64_t time_and_version = uuid.time_and_version; |
| char buf[16]; |
| buf[3] = static_cast<char>(time_and_version & 0x00000000000000FFLL); |
| time_and_version >>= 8; |
| buf[2] = static_cast<char>(time_and_version & 0x00000000000000FFLL); |
| time_and_version >>= 8; |
| buf[1] = static_cast<char>(time_and_version & 0x00000000000000FFLL); |
| time_and_version >>= 8; |
| buf[0] = static_cast<char>(time_and_version & 0x00000000000000FFLL); |
| time_and_version >>= 8; |
| |
| buf[5] = static_cast<char>(time_and_version & 0x00000000000000FFLL); |
| time_and_version >>= 8; |
| buf[4] = static_cast<char>(time_and_version & 0x00000000000000FFLL); |
| time_and_version >>= 8; |
| |
| buf[7] = static_cast<char>(time_and_version & 0x00000000000000FFLL); |
| time_and_version >>= 8; |
| buf[6] = static_cast<char>(time_and_version & 0x000000000000000FFLL); |
| |
| uint64_t clock_seq_and_node = uuid.clock_seq_and_node; |
| for (size_t i = 0; i < 8; ++i) { |
| buf[15 - i] = static_cast<char>(clock_seq_and_node & 0x00000000000000FFL); |
| clock_seq_and_node >>= 8; |
| } |
| output->append(buf, 16); |
| return 16; |
| } |
| |
| int32_t encode_string_map(const Map<String, Vector<String> >& value, String* output) { |
| int32_t size = encode_uint16(value.size(), output); |
| for (Map<String, Vector<String> >::const_iterator it = value.begin(); it != value.end(); ++it) { |
| size += encode_string(it->first, output); |
| size += encode_string_list(it->second, output); |
| } |
| return size; |
| } |
| |
| static String encode_header(int8_t version, int8_t flags, int16_t stream, int8_t opcode, |
| int32_t len) { |
| String header; |
| encode_int8(0x80 | version, &header); |
| encode_int8(flags, &header); |
| if (version >= 3) { |
| encode_int16(stream, &header); |
| } else { |
| encode_int8(stream, &header); |
| } |
| encode_int8(opcode, &header); |
| if (flags & FLAG_TRACING) { |
| len += 16; // Add enough space for the tracing ID |
| } |
| encode_int32(len, &header); |
| if (flags & FLAG_TRACING) { |
| UuidGen gen; |
| CassUuid tracing_id; |
| gen.generate_random(&tracing_id); |
| encode_uuid(tracing_id, &header); |
| } |
| return header; |
| } |
| |
| Type Type::text() { return Type(TYPE_VARCHAR); } |
| |
| Type Type::inet() { return Type(TYPE_INET); } |
| |
| Type Type::uuid() { return Type(TYPE_UUID); } |
| |
| Type Type::list(const Type& sub_type) { |
| Type type(TYPE_LIST); |
| type.types_.push_back(sub_type); |
| return type; |
| } |
| |
| void Type::encode(int protocol_version, String* output) const { |
| switch (type_) { |
| case TYPE_VARCHAR: |
| case TYPE_INET: |
| case TYPE_UUID: |
| encode_int16(type_, output); |
| break; |
| case TYPE_LIST: |
| encode_int16(type_, output); |
| types_[0].encode(protocol_version, output); |
| break; |
| default: |
| assert(false && "Unsupported type"); |
| break; |
| }; |
| } |
| |
| void Column::encode(int protocol_version, String* output) const { |
| encode_string(name_, output); |
| type_.encode(protocol_version, output); |
| } |
| |
| void Collection::encode(int protocol_version, String* output) const { |
| encode_int32(values_.size(), output); |
| for (Vector<Value>::const_iterator it = values_.begin(), end = values_.end(); it != end; ++it) { |
| it->encode(protocol_version, output); |
| } |
| } |
| |
| Value::Value() |
| : type_(NUL) {} |
| |
| Value::Value(const String& value) |
| : type_(VALUE) |
| , value_(new String(value)) {} |
| |
| Value::Value(const Collection& collection) |
| : type_(COLLECTION) |
| , collection_(new Collection(collection)) {} |
| |
| Value::Value(const Value& other) |
| : type_(other.type_) { |
| if (type_ == VALUE) { |
| value_ = new String(*other.value_); |
| } else if (type_ == COLLECTION) { |
| collection_ = new Collection(*other.collection_); |
| } |
| } |
| |
| Value::~Value() { |
| if (type_ == VALUE) { |
| delete value_; |
| } else if (type_ == COLLECTION) { |
| delete collection_; |
| } |
| } |
| |
| void Value::encode(int protocol_version, String* output) const { |
| if (type_ == NUL) { |
| encode_int32(-1, output); |
| } else if (type_ == VALUE) { |
| encode_bytes(*value_, output); |
| } else if (type_ == COLLECTION) { |
| String buf; |
| collection_->encode(protocol_version, &buf); |
| encode_bytes(buf, output); |
| } |
| } |
| |
| Row::Builder& Row::Builder::text(const String& text) { |
| values_.push_back(Value(text)); |
| return *this; |
| } |
| |
| Row::Builder& Row::Builder::inet(const Address& inet) { |
| String value; |
| uint8_t buf[16]; |
| uint8_t len = inet.to_inet(buf); |
| for (uint8_t i = 0; i < len; ++i) { |
| value.push_back(static_cast<char>(buf[i])); |
| } |
| values_.push_back(Value(value)); |
| return *this; |
| } |
| |
| Row::Builder& Row::Builder::uuid(const CassUuid& uuid) { |
| String value; |
| encode_uuid(uuid, &value); |
| values_.push_back(Value(value)); |
| return *this; |
| } |
| |
| Row::Builder& Row::Builder::collection(const Collection& collection) { |
| values_.push_back(Value(collection)); |
| return *this; |
| } |
| |
| void Row::encode(int protocol_version, String* output) const { |
| for (Vector<Value>::const_iterator it = values_.begin(), end = values_.end(); it != end; ++it) { |
| it->encode(protocol_version, output); |
| } |
| } |
| |
| String ResultSet::encode(int protocol_version) const { |
| String body; |
| |
| encode_int32(RESULT_ROWS, &body); // Result type |
| |
| encode_int32(RESULT_FLAG_GLOBAL_TABLESPEC, &body); // Flags |
| encode_int32(columns_.size(), &body); // Column count |
| encode_string(keyspace_name_, &body); // Global spec keyspace name |
| encode_string(table_name_, &body); // Global spec table name |
| |
| // Columns |
| for (Vector<Column>::const_iterator it = columns_.begin(), end = columns_.end(); it != end; |
| ++it) { |
| it->encode(protocol_version, &body); |
| } |
| |
| encode_int32(rows_.size(), &body); // Row count |
| |
| // Rows |
| for (Vector<Row>::const_iterator it = rows_.begin(), end = rows_.end(); it != end; ++it) { |
| it->encode(protocol_version, &body); |
| } |
| |
| return body; |
| } |
| |
| Action::Builder& Action::Builder::reset() { |
| first_.reset(); |
| last_ = NULL; |
| return *this; |
| } |
| |
| Action::Builder& Action::Builder::execute(Action* action) { |
| if (!first_) { |
| first_.reset(action); |
| } |
| if (last_) { |
| last_->next = action; |
| } |
| last_ = action; |
| return *this; |
| } |
| |
| Action::Builder& Action::Builder::execute_if(Action* action) { |
| if (last_ && last_->is_predicate()) { |
| static_cast<Predicate*>(last_)->then = action; |
| } |
| return *this; |
| } |
| |
| Action::Builder& Action::Builder::nop() { return execute(new Nop()); } |
| |
| Action::Builder& Action::Builder::wait(uint64_t timeout) { return execute(new Wait(timeout)); } |
| |
| Action::Builder& Action::Builder::close() { return execute(new Close()); } |
| |
| Action::Builder& Action::Builder::error(int32_t code, const String& message) { |
| return execute(new SendError(code, message)); |
| } |
| |
| Action::Builder& Action::Builder::invalid_protocol() { |
| return error(ERROR_PROTOCOL_ERROR, "Invalid or unsupported protocol version"); |
| } |
| |
| Action::Builder& Action::Builder::invalid_opcode() { |
| return error(ERROR_PROTOCOL_ERROR, "Invalid opcode (or not implemented)"); |
| } |
| |
| Action::Builder& Action::Builder::ready() { return execute(new SendReady()); } |
| |
| Action::Builder& Action::Builder::authenticate(const String& class_name) { |
| return execute(new SendAuthenticate(class_name)); |
| } |
| |
| Action::Builder& Action::Builder::auth_challenge(const String& token) { |
| return execute(new SendAuthChallenge(token)); |
| } |
| |
| Action::Builder& Action::Builder::auth_success(const String& token) { |
| return execute(new SendAuthSuccess(token)); |
| } |
| |
| Action::Builder& Action::Builder::supported() { return execute(new SendSupported()); } |
| |
| Action::Builder& Action::Builder::up_event(const Address& address) { |
| return execute(new SendUpEvent(address)); |
| } |
| |
| Action::Builder& Action::Builder::void_result() { return execute(new VoidResult()); } |
| |
| Action::Builder& Action::Builder::empty_rows_result(int32_t row_count) { |
| return execute(new EmptyRowsResult(row_count)); |
| } |
| |
| Action::Builder& Action::Builder::no_result() { return execute(new NoResult()); } |
| |
| Action::Builder& Action::Builder::match_query(const Matches& matches) { |
| return execute(new MatchQuery(matches)); |
| } |
| |
| Action::Builder& Action::Builder::client_options() { return execute(new ClientOptions()); } |
| |
| Action::Builder& Action::Builder::system_local() { return execute(new SystemLocal()); } |
| |
| Action::Builder& Action::Builder::system_local_dse() { return execute(new SystemLocalDse()); } |
| |
| Action::Builder& Action::Builder::system_peers() { return execute(new SystemPeers()); } |
| |
| Action::Builder& Action::Builder::system_peers_dse() { return execute(new SystemPeersDse()); } |
| |
| Action::Builder& Action::Builder::system_traces() { return execute(new SystemTraces()); } |
| |
| Action::Builder& Action::Builder::use_keyspace(const String& keyspace) { |
| return execute((new UseKeyspace(keyspace))); |
| } |
| |
| Action::Builder& Action::Builder::use_keyspace(const Vector<String>& keyspaces) { |
| return execute((new UseKeyspace(keyspaces))); |
| } |
| |
| Action::Builder& Action::Builder::plaintext_auth(const String& username, const String& password) { |
| return execute((new PlaintextAuth(username, password))); |
| } |
| |
| Action::Builder& Action::Builder::validate_startup() { return execute(new ValidateStartup()); } |
| |
| Action::Builder& Action::Builder::validate_credentials() { |
| return execute(new ValidateCredentials()); |
| } |
| |
| Action::Builder& Action::Builder::validate_auth_response() { |
| return execute(new ValidateAuthResponse()); |
| } |
| |
| Action::Builder& Action::Builder::validate_register() { return execute(new ValidateRegister()); } |
| |
| Action::Builder& Action::Builder::validate_query() { return execute(new ValidateQuery()); } |
| |
| Action::Builder& Action::Builder::set_registered_for_events() { |
| return execute(new SetRegisteredForEvents()); |
| } |
| |
| Action::Builder& Action::Builder::set_protocol_version() { |
| return execute(new SetProtocolVersion()); |
| } |
| |
| Action* Action::Builder::build() { return first_.release(); } |
| |
| Action::PredicateBuilder Action::Builder::is_address(const Address& address) { |
| return PredicateBuilder(execute(new IsAddress(address))); |
| } |
| |
| Action::PredicateBuilder Action::Builder::is_address(const String& address, int port) { |
| return PredicateBuilder(execute(new IsAddress(Address(address, port)))); |
| } |
| |
| Action::PredicateBuilder Action::Builder::is_query(const String& query) { |
| return PredicateBuilder(execute(new IsQuery(query))); |
| } |
| |
| void Action::run(Request* request) const { on_run(request); } |
| |
| void Action::run_next(Request* request) const { |
| if (next) { |
| next->on_run(request); |
| } |
| } |
| |
| Request::Request(int8_t version, int8_t flags, int16_t stream, int8_t opcode, const String& body, |
| ClientConnection* client) |
| : version_(version) |
| , flags_(flags) |
| , stream_(stream) |
| , opcode_(opcode) |
| , body_(body) |
| , client_(client) |
| , timer_action_(NULL) { |
| (void)flags_; // TODO: Implement custom payload etc. |
| } |
| |
| void Request::write(int8_t opcode, const String& body) { write(stream_, opcode, body); } |
| |
| void Request::write(int16_t stream, int8_t opcode, const String& body) { |
| client_->write(encode_header(version_, flags_, stream, opcode, body.size()) + body); |
| } |
| |
| void Request::error(int32_t code, const String& message) { |
| String body; |
| encode_int32(code, &body); |
| encode_string(message, &body); |
| write(OPCODE_ERROR, body); |
| } |
| |
| void Request::wait(uint64_t timeout, const Action* action) { |
| inc_ref(); |
| timer_action_ = action; |
| timer_.start(client_->server()->loop(), timeout, bind_callback(&Request::on_timeout, this)); |
| } |
| |
| void Request::close() { client_->close(); } |
| |
| bool Request::decode_startup(Options* options) { |
| return decode_string_map(start(), end(), options) == end(); |
| } |
| |
| bool Request::decode_credentials(Credentials* creds) { |
| return decode_string_map(start(), end(), creds) == end(); |
| } |
| |
| bool Request::decode_auth_response(String* token) { |
| return decode_bytes(start(), end(), token) == end(); |
| } |
| |
| bool Request::decode_register(EventTypes* types) { |
| return decode_stringlist(start(), end(), types) == end(); |
| } |
| |
| bool Request::decode_query(String* query, QueryParameters* params) { |
| return decode_query_params(version_, decode_long_string(start(), end(), query), end(), false, |
| params) == end(); |
| } |
| |
| bool Request::decode_execute(String* id, QueryParameters* params) { |
| return decode_query_params(version_, decode_string(start(), end(), id), end(), true, params) == |
| end(); |
| } |
| |
| bool Request::decode_prepare(String* query, PrepareParameters* params) { |
| return decode_prepare_params(version_, decode_long_string(start(), end(), query), end(), |
| params) == end(); |
| } |
| |
| const Address& Request::address() const { return client_->server()->address(); } |
| |
| const Host& Request::host(const Address& address) const { |
| return client_->cluster()->host(address); |
| } |
| |
| Hosts Request::hosts() const { return client_->cluster()->hosts(); } |
| |
| void Request::on_timeout(Timer* timer) { |
| timer_action_->run_next(this); |
| dec_ref(); |
| } |
| |
| void SendError::on_run(Request* request) const { request->error(code, message); } |
| |
| void SendReady::on_run(Request* request) const { request->write(OPCODE_READY, String()); } |
| |
| void SendAuthenticate::on_run(Request* request) const { |
| String body; |
| encode_string(class_name, &body); |
| request->write(OPCODE_AUTHENTICATE, body); |
| } |
| |
| void SendAuthChallenge::on_run(Request* request) const { |
| String body; |
| encode_string(token, &body); |
| request->write(OPCODE_AUTH_CHALLENGE, body); |
| } |
| |
| void SendAuthSuccess::on_run(Request* request) const { |
| String body; |
| encode_string(token, &body); |
| request->write(OPCODE_AUTH_SUCCESS, body); |
| } |
| |
| void SendSupported::on_run(Request* request) const { |
| String body; |
| encode_uint16(0, &body); |
| request->write(OPCODE_SUPPORTED, body); |
| } |
| |
| void SendUpEvent::on_run(Request* request) const { |
| request->write(-1, OPCODE_EVENT, StatusChangeEvent::encode(StatusChangeEvent::UP, address)); |
| run_next(request); |
| } |
| |
| void VoidResult::on_run(Request* request) const { |
| String body; |
| encode_int32(RESULT_VOID, &body); |
| request->write(OPCODE_RESULT, body); |
| } |
| |
| void EmptyRowsResult::on_run(Request* request) const { |
| String query; |
| QueryParameters params; |
| if (!request->decode_query(&query, ¶ms)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid query message"); |
| } else { |
| String body; |
| encode_int32(RESULT_ROWS, &body); |
| encode_int32(0, &body); // Flags |
| encode_int32(0, &body); // Column count |
| encode_int32(row_count, &body); // Row count |
| request->write(OPCODE_RESULT, body); |
| } |
| } |
| |
| void NoResult::on_run(Request* request) const {} |
| |
| void MatchQuery::on_run(Request* request) const { |
| String query; |
| QueryParameters params; |
| if (!request->decode_query(&query, ¶ms)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid query message"); |
| return; |
| } else { |
| for (Matches::const_iterator it = matches.begin(), end = matches.end(); it != end; ++it) { |
| if (it->first == query) { |
| request->write(OPCODE_RESULT, it->second.encode(request->version())); |
| return; |
| } |
| } |
| } |
| run_next(request); |
| } |
| |
| void ClientOptions::on_run(Request* request) const { |
| String query; |
| QueryParameters params; |
| if (!request->decode_query(&query, ¶ms)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid query message"); |
| } else if (query == CLIENT_OPTIONS_QUERY) { |
| ResultSet::Builder builder("client", "options"); |
| Row::Builder row_builder; |
| for (Options::const_iterator it = request->client()->options().begin(), |
| end = request->client()->options().end(); |
| it != end; ++it) { |
| builder.column((*it).first, Type::text()); |
| row_builder.text((*it).second); |
| } |
| ResultSet client_options = builder.row(row_builder.build()).build(); |
| |
| request->write(OPCODE_RESULT, client_options.encode(request->version())); |
| } else { |
| run_next(request); |
| } |
| } |
| |
| void SystemLocal::on_run(Request* request) const { |
| String query; |
| QueryParameters params; |
| if (!request->decode_query(&query, ¶ms)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid query message"); |
| } else if (query.find(SELECT_LOCAL) != String::npos) { |
| const Host& host(request->host(request->address())); |
| |
| ResultSet local_rs = ResultSet::Builder("system", "local") |
| .column("key", Type::text()) |
| .column("data_center", Type::text()) |
| .column("rack", Type::text()) |
| .column("release_version", Type::text()) |
| .column("rpc_address", Type::inet()) |
| .column("partitioner", Type::text()) |
| .column("tokens", Type::list(Type::text())) |
| .row(Row::Builder() |
| .text(request->client()->server()->address().to_string()) |
| .text(host.dc) |
| .text(host.rack) |
| .text(CASSANDRA_VERSION) |
| .inet(request->client()->server()->address()) |
| .text(host.partitioner) |
| .collection(Collection::text(host.tokens)) |
| .build()) |
| .build(); |
| |
| request->write(OPCODE_RESULT, local_rs.encode(request->version())); |
| } else { |
| run_next(request); |
| } |
| } |
| |
| void SystemLocalDse::on_run(Request* request) const { |
| String query; |
| QueryParameters params; |
| if (!request->decode_query(&query, ¶ms)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid query message"); |
| } else if (query.find(SELECT_LOCAL) != String::npos) { |
| const Host& host(request->host(request->address())); |
| |
| ResultSet local_rs = ResultSet::Builder("system", "local") |
| .column("key", Type::text()) |
| .column("data_center", Type::text()) |
| .column("rack", Type::text()) |
| .column("dse_version", Type::text()) |
| .column("release_version", Type::text()) |
| .column("rpc_address", Type::inet()) |
| .column("partitioner", Type::text()) |
| .column("tokens", Type::list(Type::text())) |
| .row(Row::Builder() |
| .text(request->client()->server()->address().to_string()) |
| .text(host.dc) |
| .text(host.rack) |
| .text(DSE_VERSION) |
| .text(DSE_CASSANDRA_VERSION) |
| .inet(request->client()->server()->address()) |
| .text(host.partitioner) |
| .collection(Collection::text(host.tokens)) |
| .build()) |
| .build(); |
| |
| request->write(OPCODE_RESULT, local_rs.encode(request->version())); |
| } else { |
| run_next(request); |
| } |
| } |
| |
| void SystemPeers::on_run(Request* request) const { |
| String query; |
| QueryParameters params; |
| if (!request->decode_query(&query, ¶ms)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid query message"); |
| } else if (query.find(SELECT_PEERS) != String::npos) { |
| const String where_clause(" WHERE peer = '"); |
| ResultSet::Builder peers_builder = ResultSet::Builder("system", "peers") |
| .column("peer", Type::inet()) |
| .column("data_center", Type::text()) |
| .column("rack", Type::text()) |
| .column("release_version", Type::text()) |
| .column("rpc_address", Type::inet()) |
| .column("tokens", Type::list(Type::text())); |
| |
| size_t pos = query.find(where_clause); |
| if (pos == String::npos) { |
| Hosts hosts(request->hosts()); |
| for (Hosts::const_iterator it = hosts.begin(), end = hosts.end(); it != end; ++it) { |
| const Host& host(*it); |
| if (host.address == request->address()) { |
| continue; |
| } |
| peers_builder.row(Row::Builder() |
| .inet(host.address) |
| .text(host.dc) |
| .text(host.rack) |
| .text(CASSANDRA_VERSION) |
| .inet(host.address) |
| .collection(Collection::text(host.tokens)) |
| .build()); |
| } |
| ResultSet peers_rs = peers_builder.build(); |
| request->write(OPCODE_RESULT, peers_rs.encode(request->version())); |
| } else { |
| pos += where_clause.size(); |
| size_t end_pos = query.find("'", pos); |
| if (end_pos == String::npos) { |
| request->error(ERROR_INVALID_QUERY, "Invalid WHERE clause"); |
| return; |
| } |
| |
| String ip = query.substr(pos, end_pos - pos); |
| Address address(ip, request->address().port()); |
| if (!address.is_valid_and_resolved()) { |
| request->error(ERROR_INVALID_QUERY, "Invalid inet address in WHERE clause"); |
| return; |
| } |
| |
| const Host& host(request->host(address)); |
| ResultSet peers_rs = peers_builder |
| .row(Row::Builder() |
| .inet(host.address) |
| .text(host.dc) |
| .text(host.rack) |
| .text(CASSANDRA_VERSION) |
| .inet(host.address) |
| .collection(Collection::text(host.tokens)) |
| .build()) |
| .build(); |
| request->write(OPCODE_RESULT, peers_rs.encode(request->version())); |
| } |
| } else { |
| run_next(request); |
| } |
| } |
| |
| void SystemPeersDse::on_run(Request* request) const { |
| String query; |
| QueryParameters params; |
| if (!request->decode_query(&query, ¶ms)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid query message"); |
| } else if (query.find(SELECT_PEERS) != String::npos) { |
| const String where_clause(" WHERE peer = '"); |
| ResultSet::Builder peers_builder = ResultSet::Builder("system", "peers") |
| .column("peer", Type::inet()) |
| .column("data_center", Type::text()) |
| .column("rack", Type::text()) |
| .column("dse_version", Type::text()) |
| .column("release_version", Type::text()) |
| .column("rpc_address", Type::inet()) |
| .column("tokens", Type::list(Type::text())); |
| |
| size_t pos = query.find(where_clause); |
| if (pos == String::npos) { |
| Hosts hosts(request->hosts()); |
| for (Hosts::const_iterator it = hosts.begin(), end = hosts.end(); it != end; ++it) { |
| const Host& host(*it); |
| if (host.address == request->address()) { |
| continue; |
| } |
| peers_builder.row(Row::Builder() |
| .inet(host.address) |
| .text(host.dc) |
| .text(host.rack) |
| .text(DSE_VERSION) |
| .text(DSE_CASSANDRA_VERSION) |
| .inet(host.address) |
| .collection(Collection::text(host.tokens)) |
| .build()); |
| } |
| ResultSet peers_rs = peers_builder.build(); |
| request->write(OPCODE_RESULT, peers_rs.encode(request->version())); |
| } else { |
| pos += where_clause.size(); |
| size_t end_pos = query.find("'", pos); |
| if (end_pos == String::npos) { |
| request->error(ERROR_INVALID_QUERY, "Invalid WHERE clause"); |
| return; |
| } |
| |
| String ip = query.substr(pos, end_pos - pos); |
| Address address(ip, request->address().port()); |
| if (!address.is_valid_and_resolved()) { |
| request->error(ERROR_INVALID_QUERY, "Invalid inet address in WHERE clause"); |
| return; |
| } |
| |
| const Host& host(request->host(address)); |
| ResultSet peers_rs = peers_builder |
| .row(Row::Builder() |
| .inet(host.address) |
| .text(host.dc) |
| .text(host.rack) |
| .text(CASSANDRA_VERSION) |
| .inet(host.address) |
| .collection(Collection::text(host.tokens)) |
| .build()) |
| .build(); |
| request->write(OPCODE_RESULT, peers_rs.encode(request->version())); |
| } |
| } else { |
| run_next(request); |
| } |
| } |
| |
| void SystemTraces::on_run(Request* request) const { |
| String query; |
| QueryParameters params; |
| if (!request->decode_query(&query, ¶ms)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid query message"); |
| } else if (query.find(SELECT_TRACES_SESSION) != String::npos) { |
| if (params.values.empty() || params.values.front().size() < 16) { |
| request->error(ERROR_INVALID_QUERY, "Query expects a UUID parameter (tracing)"); |
| return; |
| } |
| CassUuid tracing_id; |
| decode_uuid(params.values.front().data(), &tracing_id); |
| ResultSet session_rs = ResultSet::Builder("system_traces", "session") |
| .column("session_id", Type::uuid()) |
| .row(Row::Builder().uuid(tracing_id).build()) |
| .build(); |
| request->write(OPCODE_RESULT, session_rs.encode(request->version())); |
| } else { |
| run_next(request); |
| } |
| } |
| |
| void UseKeyspace::on_run(Request* request) const { |
| String query; |
| QueryParameters params; |
| if (request->decode_query(&query, ¶ms)) { |
| trim(query); |
| if (query.compare(0, 3, "USE") == 0 || query.compare(0, 3, "use") == 0) { |
| String keyspace(query.substr(query.find_first_not_of(" \t", 3))); |
| for (Vector<String>::const_iterator it = keyspaces.begin(), end = keyspaces.end(); it != end; |
| ++it) { |
| String temp(*it); |
| if (keyspace == escape_id(temp)) { |
| String body; |
| encode_int32(RESULT_SET_KEYSPACE, &body); |
| encode_string(*it, &body); |
| request->client()->set_keyspace(*it); |
| request->write(OPCODE_RESULT, body); |
| return; |
| } |
| } |
| request->error(ERROR_INVALID_QUERY, "Keyspace '" + keyspace + "' does not exist"); |
| } else { |
| run_next(request); |
| } |
| } else { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid query message"); |
| } |
| } |
| |
| void PlaintextAuth::on_run(Request* request) const { |
| String token; |
| if (request->decode_auth_response(&token)) { |
| String username, password; |
| String::const_reverse_iterator last = token.rbegin(); |
| enum { PASSWORD, USERNAME } state = PASSWORD; |
| for (String::const_reverse_iterator it = token.rbegin(), end = token.rend(); it != end; ++it) { |
| if (*it == '\0') { |
| if (state == PASSWORD) { |
| password.assign(it.base(), last.base()); |
| state = USERNAME; |
| } else if (state == USERNAME) { |
| username.assign(it.base(), last.base()); |
| break; |
| } |
| last = it + 1; |
| } |
| } |
| |
| if (username == this->username && password == this->password) { |
| String body; |
| encode_int32(-1, &body); // Null bytes |
| request->write(OPCODE_AUTH_SUCCESS, body); |
| } else { |
| request->error(ERROR_BAD_CREDENTIALS, "Invalid credentials"); |
| } |
| } else { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid auth response message"); |
| } |
| } |
| |
| void ValidateStartup::on_run(Request* request) const { |
| Options options; |
| if (!request->decode_startup(&options)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid startup message"); |
| } else { |
| request->client()->set_options(options); |
| run_next(request); |
| } |
| } |
| |
| void ValidateCredentials::on_run(Request* request) const { |
| Credentials creds; |
| if (!request->decode_credentials(&creds)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid credentials message"); |
| } else { |
| run_next(request); |
| } |
| } |
| |
| void ValidateAuthResponse::on_run(Request* request) const { |
| String token; |
| if (!request->decode_auth_response(&token)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid auth response message"); |
| } else { |
| run_next(request); |
| } |
| } |
| |
| void ValidateRegister::on_run(Request* request) const { |
| EventTypes types; |
| if (!request->decode_register(&types)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid register message"); |
| } else { |
| run_next(request); |
| } |
| } |
| |
| void ValidateQuery::on_run(Request* request) const { |
| String query; |
| QueryParameters params; |
| if (!request->decode_query(&query, ¶ms)) { |
| request->error(ERROR_PROTOCOL_ERROR, "Invalid query message"); |
| } else { |
| run_next(request); |
| } |
| } |
| |
| void SetRegisteredForEvents::on_run(Request* request) const { |
| request->client()->set_registered_for_events(); |
| run_next(request); |
| } |
| |
| void SetProtocolVersion::on_run(Request* request) const { |
| request->client()->set_protocol_version(request->version()); |
| run_next(request); |
| } |
| |
| bool IsAddress::is_true(Request* request) const { |
| return request->client()->server()->address() == address; |
| } |
| |
| bool IsQuery::is_true(Request* request) const { |
| String query; |
| QueryParameters params; |
| return request->decode_query(&query, ¶ms) && query == this->query; |
| } |
| |
| RequestHandler::RequestHandler(RequestHandler::Builder* builder, |
| int lowest_supported_protocol_version, |
| int highest_supported_protocol_version) |
| : invalid_protocol_(builder->on_invalid_protocol().build()) |
| , invalid_opcode_(builder->on_invalid_opcode().build()) |
| , lowest_supported_protocol_version_(lowest_supported_protocol_version) |
| , highest_supported_protocol_version_(highest_supported_protocol_version) {} |
| |
| const RequestHandler* RequestHandler::Builder::build() { |
| RequestHandler* handler(new RequestHandler(this, lowest_supported_protocol_version_, |
| highest_supported_protocol_version_)); |
| |
| handler->actions_[OPCODE_STARTUP].reset(actions_[OPCODE_STARTUP].build()); |
| handler->actions_[OPCODE_OPTIONS].reset(actions_[OPCODE_OPTIONS].build()); |
| handler->actions_[OPCODE_CREDENTIALS].reset(actions_[OPCODE_CREDENTIALS].build()); |
| handler->actions_[OPCODE_QUERY].reset(actions_[OPCODE_QUERY].build()); |
| handler->actions_[OPCODE_PREPARE].reset(actions_[OPCODE_PREPARE].build()); |
| handler->actions_[OPCODE_EXECUTE].reset(actions_[OPCODE_EXECUTE].build()); |
| handler->actions_[OPCODE_REGISTER].reset(actions_[OPCODE_REGISTER].build()); |
| handler->actions_[OPCODE_AUTH_RESPONSE].reset(actions_[OPCODE_AUTH_RESPONSE].build()); |
| |
| return handler; |
| } |
| |
| void ProtocolHandler::decode(ClientConnection* client, const char* data, int32_t len) { |
| buffer_.append(data, len); |
| int32_t result = decode_frame(client, buffer_.data(), buffer_.size()); |
| if (result > 0) { |
| if (static_cast<size_t>(result) == buffer_.size()) { |
| buffer_.clear(); |
| } else { |
| // Not efficient, but concise. Copy the consumed part of the buffer |
| // forward then resize the buffer to what's left over. |
| std::copy(buffer_.begin() + result, buffer_.end(), buffer_.begin()); |
| buffer_.resize(buffer_.size() - result); |
| } |
| } |
| } |
| |
| int32_t ProtocolHandler::decode_frame(ClientConnection* client, const char* frame, int32_t len) { |
| const char* pos = frame; |
| const char* end = pos + len; |
| int32_t remaining = len; |
| |
| while (remaining > 0) { |
| switch (state_) { |
| case PROTOCOL_VERSION: |
| // Version requires a single byte and that's guaranteed by the loop check. |
| version_ = *pos++; |
| remaining--; |
| if (version_ < request_handler_->lowest_supported_protocol_version() || |
| version_ > request_handler_->highest_supported_protocol_version()) { |
| // Use the highest supported protocol version unless it's less than the lowest supported |
| // then use the original request's protocol version. |
| int8_t response_version = request_handler_->highest_supported_protocol_version(); |
| if (version_ < request_handler_->lowest_supported_protocol_version()) { |
| response_version = version_; |
| } |
| Request::Ptr request( |
| new Request(response_version, flags_, stream_, opcode_, String(), client)); |
| request_handler_->invalid_protocol(request.get()); |
| client->close(); |
| return -1; |
| } |
| state_ = HEADER; |
| break; |
| case HEADER: |
| if ((version_ == 1 || version_ == 2) && remaining >= 7) { |
| flags_ = *pos++; |
| stream_ = *pos++; |
| opcode_ = *pos++; |
| pos = decode_int32(pos, end, &length_); |
| remaining -= 7; |
| } else if (version_ >= 3 && remaining >= 8) { |
| flags_ = *pos++; |
| pos = decode_int16(pos, end, &stream_); |
| opcode_ = *pos++; |
| pos = decode_int32(pos, end, &length_); |
| remaining -= 8; |
| } else { |
| return len - remaining; |
| } |
| |
| if (length_ == 0) { |
| decode_body(client, pos, 0); |
| version_ = 0; |
| flags_ = 0; |
| opcode_ = 0; |
| length_ = 0; |
| state_ = PROTOCOL_VERSION; |
| } else { |
| state_ = BODY; |
| } |
| break; |
| case BODY: |
| if (remaining >= length_) { |
| decode_body(client, pos, length_); |
| pos += length_; |
| remaining -= length_; |
| } else { |
| return len - remaining; |
| } |
| version_ = 0; |
| flags_ = 0; |
| opcode_ = 0; |
| length_ = 0; |
| state_ = PROTOCOL_VERSION; |
| break; |
| } |
| } |
| |
| return len; // All bytes have been consumed |
| } |
| |
| void ProtocolHandler::decode_body(ClientConnection* client, const char* body, int32_t len) { |
| Request::Ptr request(new Request(version_, flags_, stream_, opcode_, String(body, len), client)); |
| request_handler_->run(request.get()); |
| } |
| |
| void ClientConnection::on_read(const char* data, size_t len) { handler_.decode(this, data, len); } |
| |
| Event::Event(const String& event_body) |
| : event_body_(event_body) {} |
| |
| void Event::run(internal::ServerConnection* server_connection) { |
| for (internal::ClientConnections::const_iterator it = server_connection->clients().begin(), |
| end = server_connection->clients().end(); |
| it != end; ++it) { |
| ClientConnection* client = static_cast<ClientConnection*>(*it); |
| if (client->is_registered_for_events() && client->protocol_version() > 0) { |
| client->write( |
| encode_header(client->protocol_version(), 0, -1, OPCODE_EVENT, event_body_.size()) + |
| event_body_); |
| } |
| } |
| } |
| |
| Event::Ptr TopologyChangeEvent::new_node(const Address& address) { |
| return Ptr(new TopologyChangeEvent(NEW_NODE, address)); |
| } |
| |
| Event::Ptr TopologyChangeEvent::moved_node(const Address& address) { |
| return Ptr(new TopologyChangeEvent(MOVED_NODE, address)); |
| } |
| |
| Event::Ptr TopologyChangeEvent::removed_node(const Address& address) { |
| return Ptr(new TopologyChangeEvent(REMOVED_NODE, address)); |
| } |
| |
| String TopologyChangeEvent::encode(TopologyChangeEvent::Type type, const Address& address) { |
| String body; |
| encode_string("TOPOLOGY_CHANGE", &body); |
| switch (type) { |
| case NEW_NODE: |
| encode_string("NEW_NODE", &body); |
| break; |
| case MOVED_NODE: |
| encode_string("MOVED_NODE", &body); |
| break; |
| case REMOVED_NODE: |
| encode_string("REMOVED_NODE", &body); |
| break; |
| }; |
| encode_inet(address, &body); |
| return body; |
| } |
| |
| Event::Ptr StatusChangeEvent::up(const Address& address) { |
| return Ptr(new StatusChangeEvent(UP, address)); |
| } |
| |
| Event::Ptr StatusChangeEvent::down(const Address& address) { |
| return Ptr(new StatusChangeEvent(DOWN, address)); |
| } |
| |
| String StatusChangeEvent::encode(Type type, const Address& address) { |
| String body; |
| encode_string("STATUS_CHANGE", &body); |
| switch (type) { |
| case UP: |
| encode_string("UP", &body); |
| break; |
| case DOWN: |
| encode_string("DOWN", &body); |
| break; |
| }; |
| encode_inet(address, &body); |
| return body; |
| } |
| |
| Event::Ptr SchemaChangeEvent::keyspace(SchemaChangeEvent::Type type, const String& keyspace_name) { |
| return Ptr(new SchemaChangeEvent(KEYSPACE, type, keyspace_name)); |
| } |
| |
| Event::Ptr SchemaChangeEvent::table(SchemaChangeEvent::Type type, const String& keyspace_name, |
| const String& table_name) { |
| return Ptr(new SchemaChangeEvent(TABLE, type, keyspace_name, table_name)); |
| } |
| |
| Event::Ptr SchemaChangeEvent::user_type(SchemaChangeEvent::Type type, const String& keyspace_name, |
| const String& user_type_name) { |
| return Ptr(new SchemaChangeEvent(USER_TYPE, type, keyspace_name, user_type_name)); |
| } |
| |
| Event::Ptr SchemaChangeEvent::function(SchemaChangeEvent::Type type, const String& keyspace_name, |
| const String& function_name, |
| const Vector<String>& args_types) { |
| return Ptr(new SchemaChangeEvent(FUNCTION, type, keyspace_name, function_name, args_types)); |
| } |
| |
| Event::Ptr SchemaChangeEvent::aggregate(SchemaChangeEvent::Type type, const String& keyspace_name, |
| const String& aggregate_name, |
| const Vector<String>& args_types) { |
| return Ptr(new SchemaChangeEvent(AGGREGATE, type, keyspace_name, aggregate_name, args_types)); |
| } |
| |
| String SchemaChangeEvent::encode(SchemaChangeEvent::Target target, SchemaChangeEvent::Type type, |
| const String& keyspace_name, const String& target_name, |
| const Vector<String>& arg_types) { |
| String body; |
| encode_string("SCHEMA_CHANGE", &body); |
| switch (type) { |
| case CREATED: |
| encode_string("CREATED", &body); |
| break; |
| case UPDATED: |
| encode_string("UPDATED", &body); |
| break; |
| case DROPPED: |
| encode_string("DROPPED", &body); |
| break; |
| } |
| switch (target) { |
| case KEYSPACE: |
| encode_string("KEYSPACE", &body); |
| encode_string(keyspace_name, &body); |
| break; |
| case TABLE: |
| encode_string("TABLE", &body); |
| encode_string(keyspace_name, &body); |
| encode_string(target_name, &body); |
| break; |
| ; |
| case USER_TYPE: |
| encode_string("TYPE", &body); |
| encode_string(keyspace_name, &body); |
| encode_string(target_name, &body); |
| break; |
| case FUNCTION: |
| encode_string("FUNCTION", &body); |
| encode_string(keyspace_name, &body); |
| encode_string(target_name, &body); |
| encode_string_list(arg_types, &body); |
| break; |
| case AGGREGATE: |
| encode_string("AGGREGATE", &body); |
| encode_string(keyspace_name, &body); |
| encode_string(target_name, &body); |
| encode_string_list(arg_types, &body); |
| break; |
| } |
| return body; |
| } |
| |
| void Cluster::init(AddressGenerator& generator, ClientConnectionFactory& factory, |
| size_t num_nodes_dc1, size_t num_nodes_dc2) { |
| for (size_t i = 0; i < num_nodes_dc1; ++i) { |
| create_and_add_server(generator, factory, "dc1"); |
| } |
| for (size_t i = 0; i < num_nodes_dc2; ++i) { |
| create_and_add_server(generator, factory, "dc2"); |
| } |
| } |
| |
| Cluster::~Cluster() { stop_all(); } |
| |
| String Cluster::use_ssl(const String& cn /*= ""*/) { |
| String key(Ssl::generate_key()); |
| String cert(Ssl::generate_cert(key, cn)); |
| for (size_t i = 0; i < servers_.size(); ++i) { |
| Server& server = servers_[i]; |
| if (!server.connection->use_ssl(key, cert)) { |
| return ""; |
| } |
| } |
| return cert; |
| } |
| |
| int Cluster::start_all(EventLoopGroup* event_loop_group) { |
| start_all_async(event_loop_group); |
| for (size_t i = 0; i < servers_.size(); ++i) { |
| Server& server = servers_[i]; |
| int rc = server.connection->wait_listen(); |
| if (rc != 0) return rc; |
| } |
| return 0; |
| } |
| |
| void Cluster::start_all_async(EventLoopGroup* event_loop_group) { |
| for (size_t i = 0; i < servers_.size(); ++i) { |
| Server& server = servers_[i]; |
| server.connection->listen(event_loop_group); |
| } |
| } |
| |
| void Cluster::stop_all() { |
| stop_all_async(); |
| for (size_t i = 0; i < servers_.size(); ++i) { |
| Server& server = servers_[i]; |
| server.connection->wait_close(); |
| } |
| } |
| |
| void Cluster::stop_all_async() { |
| for (size_t i = 0; i < servers_.size(); ++i) { |
| Server& server = servers_[i]; |
| server.connection->close(); |
| } |
| } |
| |
| int Cluster::start(EventLoopGroup* event_loop_group, size_t node) { |
| if (node < 1 || node > servers_.size()) { |
| return -1; |
| } |
| Server& server = servers_[node - 1]; |
| server.connection->listen(event_loop_group); |
| return server.connection->wait_listen(); |
| } |
| |
| void Cluster::start_async(EventLoopGroup* event_loop_group, size_t node) { |
| if (node < 1 || node > servers_.size()) { |
| return; |
| } |
| Server& server = servers_[node - 1]; |
| server.connection->listen(event_loop_group); |
| } |
| |
| void Cluster::stop(size_t node) { |
| if (node < 1 || node > servers_.size()) { |
| return; |
| } |
| Server& server = servers_[node - 1]; |
| server.connection->close(); |
| server.connection->wait_close(); |
| } |
| |
| void Cluster::stop_async(size_t node) { |
| if (node < 1 || node > servers_.size()) { |
| return; |
| } |
| Server& server = servers_[node - 1]; |
| server.connection->close(); |
| } |
| |
| int Cluster::add(EventLoopGroup* event_loop_group, size_t node) { |
| if (node < 1 || node > servers_.size()) { |
| return -1; |
| } |
| Server& server = servers_[node - 1]; |
| bool is_removed = server.is_removed.exchange(false); |
| server.connection->listen(event_loop_group); |
| int rc = server.connection->wait_listen(); |
| |
| // Send the added node event after starting the socket |
| if (is_removed) { // Only send topology change event if node was previously removed |
| event(TopologyChangeEvent::new_node(server.connection->address())); |
| } |
| |
| return rc; |
| } |
| |
| void Cluster::remove(size_t node) { |
| if (node < 1 || node > servers_.size()) { |
| return; |
| } |
| Server& server = servers_[node - 1]; |
| bool is_removed = server.is_removed.exchange(true); |
| |
| // Send the remove node event before closing the socket |
| if (!is_removed) { // Only send the topology change event if node was previously active |
| event(TopologyChangeEvent::removed_node(server.connection->address())); |
| } |
| |
| server.connection->close(); |
| server.connection->wait_close(); |
| } |
| |
| const Host& Cluster::host(const Address& address) const { |
| for (Servers::const_iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) { |
| if (it->host.address == address) { |
| return it->host; |
| } |
| } |
| |
| throw Exception(ERROR_PROTOCOL_ERROR, "Unable to find host " + address.to_string()); |
| } |
| |
| Hosts Cluster::hosts() const { |
| Hosts hosts; |
| hosts.reserve(servers_.size()); |
| for (Servers::const_iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) { |
| if (!it->is_removed.load()) { |
| hosts.push_back(it->host); |
| } |
| } |
| return hosts; |
| } |
| |
| unsigned Cluster::connection_attempts(size_t node) const { |
| if (node < 1 || node > servers_.size()) { |
| return 0; |
| } |
| const Server& server = servers_[node - 1]; |
| return server.connection->connection_attempts(); |
| } |
| |
| int Cluster::create_and_add_server(AddressGenerator& generator, ClientConnectionFactory& factory, |
| const String& dc) { |
| Address address(generator.next()); |
| Server server(Host(address, dc, "rack1", token_rng_), |
| internal::ServerConnection::Ptr(new internal::ServerConnection(address, factory))); |
| |
| servers_.push_back(server); |
| return static_cast<int>(servers_.size()); |
| } |
| |
| void Cluster::event(const Event::Ptr& event) { |
| for (Servers::const_iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) { |
| it->connection->run(internal::ServerConnectionTask::Ptr(event)); |
| } |
| } |
| |
| Address Ipv4AddressGenerator::next() { |
| char buf[32]; |
| sprintf(buf, "%d.%d.%d.%d", (ip_ >> 24) & 0xff, (ip_ >> 16) & 0xff, (ip_ >> 8) & 0xff, |
| ip_ & 0xff); |
| ip_++; |
| return Address(buf, port_); |
| } |
| |
| Host::Host(const Address& address, const String& dc, const String& rack, MT19937_64& token_rng, |
| int num_tokens) |
| : address(address) |
| , dc(dc) |
| , rack(rack) |
| , partitioner("org.apache.cassandra.dht.Murmur3Partitioner") { |
| // Only murmur tokens currently supported |
| for (int i = 0; i < num_tokens; ++i) { |
| OStringStream ss; |
| ss << static_cast<int64_t>(token_rng()); |
| tokens.push_back(ss.str()); |
| } |
| } |
| |
| SimpleEventLoopGroup::SimpleEventLoopGroup(size_t num_threads, |
| const String& thread_name /*= "mockssandra"*/) |
| : RoundRobinEventLoopGroup(num_threads) { |
| int rc = init(thread_name); |
| UNUSED_(rc); |
| assert(rc == 0 && "Unable to initialize simple event loop"); |
| run(); |
| } |
| |
| SimpleEventLoopGroup::~SimpleEventLoopGroup() { |
| close_handles(); |
| join(); |
| } |
| |
| SimpleRequestHandlerBuilder::SimpleRequestHandlerBuilder() |
| : RequestHandler::Builder() { |
| on(OPCODE_STARTUP).validate_startup().ready(); |
| on(OPCODE_OPTIONS).supported(); |
| on(OPCODE_CREDENTIALS).validate_credentials().ready(); |
| on(OPCODE_AUTH_RESPONSE).validate_auth_response().auth_success(""); |
| on(OPCODE_REGISTER) |
| .validate_register() |
| .set_protocol_version() |
| .set_registered_for_events() |
| .ready(); |
| on(OPCODE_QUERY).system_local().system_peers().empty_rows_result(1); |
| } |
| |
| AuthRequestHandlerBuilder::AuthRequestHandlerBuilder(const String& username, const String& password) |
| : SimpleRequestHandlerBuilder() { |
| on(mockssandra::OPCODE_STARTUP).validate_startup().authenticate("com.datastax.SomeAuthenticator"); |
| on(mockssandra::OPCODE_AUTH_RESPONSE).validate_auth_response().plaintext_auth(username, password); |
| } |
| |
| } // namespace mockssandra |