| /* |
| Copyright (c) 2014-2017 DataStax |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| */ |
| |
| #include "socket.hpp" |
| |
| #include "logger.hpp" |
| |
| #define SSL_READ_SIZE 8192 |
| #define SSL_WRITE_SIZE 8192 |
| #define SSL_ENCRYPTED_BUFS_COUNT 16 |
| |
| #define MAX_BUFFER_REUSE_NO 8 |
| #define BUFFER_REUSE_SIZE 64 * 1024 |
| |
| using namespace datastax::internal; |
| using namespace datastax::internal::core; |
| |
| typedef Vector<uv_buf_t> UvBufVec; |
| |
| /** |
| * A basic socket write handler. |
| */ |
| class SocketWrite : public SocketWriteBase { |
| public: |
| SocketWrite(Socket* socket) |
| : SocketWriteBase(socket) {} |
| |
| size_t flush(); |
| }; |
| |
| size_t SocketWrite::flush() { |
| size_t total = 0; |
| if (!is_flushed_ && !buffers_.empty()) { |
| UvBufVec bufs; |
| |
| bufs.reserve(buffers_.size()); |
| |
| for (BufferVec::const_iterator it = buffers_.begin(), end = buffers_.end(); it != end; ++it) { |
| total += it->size(); |
| bufs.push_back(uv_buf_init(const_cast<char*>(it->data()), it->size())); |
| } |
| |
| is_flushed_ = true; |
| uv_stream_t* sock_stream = reinterpret_cast<uv_stream_t*>(tcp()); |
| uv_write(&req_, sock_stream, bufs.data(), bufs.size(), SocketWrite::on_write); |
| } |
| return total; |
| } |
| |
| SocketHandler::~SocketHandler() { |
| while (!buffer_reuse_list_.empty()) { |
| uv_buf_t buf = buffer_reuse_list_.top(); |
| Memory::free(buf.base); |
| buffer_reuse_list_.pop(); |
| } |
| } |
| |
| SocketWriteBase* SocketHandler::new_pending_write(Socket* socket) { |
| return new SocketWrite(socket); |
| } |
| |
| void SocketHandler::alloc_buffer(size_t suggested_size, uv_buf_t* buf) { |
| if (suggested_size <= BUFFER_REUSE_SIZE) { |
| if (!buffer_reuse_list_.empty()) { |
| *buf = buffer_reuse_list_.top(); |
| buffer_reuse_list_.pop(); |
| } else { |
| *buf = uv_buf_init(reinterpret_cast<char*>(Memory::malloc(BUFFER_REUSE_SIZE)), |
| BUFFER_REUSE_SIZE); |
| } |
| } else { |
| *buf = uv_buf_init(reinterpret_cast<char*>(Memory::malloc(suggested_size)), suggested_size); |
| } |
| } |
| |
| void SocketHandler::free_buffer(const uv_buf_t* buf) { |
| if (buf->len == BUFFER_REUSE_SIZE && buffer_reuse_list_.size() < MAX_BUFFER_REUSE_NO) { |
| buffer_reuse_list_.push(*buf); |
| return; |
| } |
| Memory::free(buf->base); |
| } |
| |
| /** |
| * A SSL socket write handler. |
| */ |
| class SslSocketWrite : public SocketWriteBase { |
| public: |
| SslSocketWrite(Socket* socket, SslSession* ssl_session) |
| : SocketWriteBase(socket) |
| , ssl_session_(ssl_session) |
| , encrypted_size_(0) {} |
| |
| virtual size_t flush(); |
| |
| private: |
| void encrypt(); |
| |
| static void on_write(uv_write_t* req, int status); |
| |
| private: |
| SslSession* ssl_session_; |
| size_t encrypted_size_; |
| }; |
| |
| size_t SslSocketWrite::flush() { |
| size_t total = 0; |
| if (!is_flushed_ && !buffers_.empty()) { |
| rb::RingBuffer::Position prev_pos = ssl_session_->outgoing().write_position(); |
| |
| encrypt(); |
| |
| SmallVector<uv_buf_t, SSL_ENCRYPTED_BUFS_COUNT> bufs; |
| total = encrypted_size_ = ssl_session_->outgoing().peek_multiple(prev_pos, &bufs); |
| |
| LOG_TRACE("Sending %u encrypted bytes", static_cast<unsigned int>(encrypted_size_)); |
| |
| uv_stream_t* sock_stream = reinterpret_cast<uv_stream_t*>(tcp()); |
| uv_write(&req_, sock_stream, bufs.data(), bufs.size(), SslSocketWrite::on_write); |
| |
| is_flushed_ = true; |
| } |
| return total; |
| } |
| |
| void SslSocketWrite::encrypt() { |
| char buf[SSL_WRITE_SIZE]; |
| |
| size_t copied = 0; |
| size_t offset = 0; |
| size_t total = 0; |
| |
| BufferVec::const_iterator it = buffers_.begin(), end = buffers_.end(); |
| |
| LOG_TRACE("Copying %u bufs", static_cast<unsigned int>(buffers_.size())); |
| |
| bool is_done = (it == end); |
| |
| while (!is_done) { |
| assert(it->size() > 0); |
| size_t size = it->size(); |
| |
| size_t to_copy = size - offset; |
| size_t available = SSL_WRITE_SIZE - copied; |
| if (available < to_copy) { |
| to_copy = available; |
| } |
| |
| memcpy(buf + copied, it->data() + offset, to_copy); |
| |
| copied += to_copy; |
| offset += to_copy; |
| total += to_copy; |
| |
| if (offset == size) { |
| ++it; |
| offset = 0; |
| } |
| |
| is_done = (it == end); |
| |
| if (is_done || copied == SSL_WRITE_SIZE) { |
| int rc = ssl_session_->encrypt(buf, copied); |
| if (rc <= 0 && ssl_session_->has_error()) { |
| LOG_ERROR("Unable to encrypt data: %s", ssl_session_->error_message().c_str()); |
| socket_->defunct(); |
| return; |
| } |
| copied = 0; |
| } |
| } |
| |
| LOG_TRACE("Copied %u bytes for encryption", static_cast<unsigned int>(total)); |
| } |
| |
| void SslSocketWrite::on_write(uv_write_t* req, int status) { |
| if (status == 0) { |
| SslSocketWrite* socket_write = static_cast<SslSocketWrite*>(req->data); |
| socket_write->ssl_session_->outgoing().read(NULL, socket_write->encrypted_size_); |
| } |
| SocketWriteBase::on_write(req, status); |
| } |
| |
| SocketWriteBase* SslSocketHandler::new_pending_write(Socket* socket) { |
| return new SslSocketWrite(socket, ssl_session_.get()); |
| } |
| |
| void SslSocketHandler::alloc_buffer(size_t suggested_size, uv_buf_t* buf) { |
| buf->base = ssl_session_->incoming().peek_writable(&suggested_size); |
| buf->len = suggested_size; |
| } |
| |
| void SslSocketHandler::on_read(Socket* socket, ssize_t nread, const uv_buf_t* buf) { |
| if (nread < 0) return; |
| |
| ssl_session_->incoming().commit(nread); |
| char decrypted[SSL_READ_SIZE]; |
| int rc = 0; |
| while ((rc = ssl_session_->decrypt(decrypted, sizeof(decrypted))) > 0) { |
| on_ssl_read(socket, decrypted, rc); |
| } |
| if (rc <= 0 && ssl_session_->has_error()) { |
| if (ssl_session_->error_code() == CASS_ERROR_SSL_CLOSED) { |
| LOG_DEBUG("SSL session closed"); |
| socket->close(); |
| } else { |
| LOG_ERROR("Unable to decrypt data: %s", ssl_session_->error_message().c_str()); |
| socket->defunct(); |
| } |
| } |
| } |
| |
| uv_tcp_t* SocketWriteBase::tcp() { return &socket_->tcp_; } |
| |
| void SocketWriteBase::on_close() { |
| for (RequestVec::iterator i = requests_.begin(), end = requests_.end(); i != end; ++i) { |
| (*i)->on_close(); |
| } |
| } |
| |
| int32_t SocketWriteBase::write(SocketRequest* request) { |
| size_t last_buffer_size = buffers_.size(); |
| int32_t request_size = request->encode(&buffers_); |
| if (request_size <= 0) { |
| buffers_.resize(last_buffer_size); // Rollback |
| return request_size; |
| } |
| |
| requests_.push_back(request); |
| |
| return request_size; |
| } |
| |
| void SocketWriteBase::on_write(uv_write_t* req, int status) { |
| SocketWriteBase* pending_write = static_cast<SocketWriteBase*>(req->data); |
| pending_write->handle_write(req, status); |
| } |
| |
| void SocketWriteBase::handle_write(uv_write_t* req, int status) { |
| Socket* socket = socket_; |
| |
| if (status != 0) { |
| if (!socket->is_closing()) { |
| LOG_ERROR("Socket write error '%s'", uv_strerror(status)); |
| socket->defunct(); |
| } |
| } |
| |
| if (socket->handler_) { |
| for (RequestVec::iterator i = requests_.begin(), end = requests_.end(); i != end; ++i) { |
| socket->handler_->on_write(socket, status, *i); |
| } |
| } |
| |
| socket->pending_writes_.remove(this); |
| |
| if (socket->free_writes_.size() < socket->max_reusable_write_objects_) { |
| clear(); |
| socket->free_writes_.push_back(this); |
| } else { |
| delete this; |
| } |
| |
| socket->flush(); |
| } |
| |
| Socket::Socket(const Address& address, size_t max_reusable_write_objects) |
| : is_defunct_(false) |
| , max_reusable_write_objects_(max_reusable_write_objects) |
| , address_(address) { |
| tcp_.data = this; |
| } |
| |
| Socket::~Socket() { cleanup_free_writes(); } |
| |
| void Socket::set_handler(SocketHandlerBase* handler) { |
| handler_.reset(handler); |
| cleanup_free_writes(); |
| free_writes_.clear(); |
| if (handler_) { |
| uv_read_start(reinterpret_cast<uv_stream_t*>(&tcp_), Socket::alloc_buffer, Socket::on_read); |
| } else { |
| uv_read_stop(reinterpret_cast<uv_stream_t*>(&tcp_)); |
| } |
| } |
| |
| int32_t Socket::write(SocketRequest* request) { |
| if (!handler_) { |
| return SocketRequest::SOCKET_REQUEST_ERROR_NO_HANDLER; |
| } |
| |
| if (is_closing()) { |
| return SocketRequest::SOCKET_REQUEST_ERROR_CLOSED; |
| } |
| |
| if (pending_writes_.is_empty() || pending_writes_.back()->is_flushed()) { |
| if (!free_writes_.empty()) { |
| pending_writes_.add_to_back(free_writes_.back()); |
| free_writes_.pop_back(); |
| } else { |
| pending_writes_.add_to_back(handler_->new_pending_write(this)); |
| } |
| } |
| |
| return pending_writes_.back()->write(request); |
| } |
| |
| int32_t Socket::write_and_flush(SocketRequest* request) { |
| int32_t result = write(request); |
| if (result > 0) { |
| flush(); |
| } |
| return result; |
| } |
| |
| size_t Socket::flush() { |
| if (pending_writes_.is_empty()) return 0; |
| |
| return pending_writes_.back()->flush(); |
| } |
| |
| bool Socket::is_closing() const { |
| return uv_is_closing(reinterpret_cast<const uv_handle_t*>(&tcp_)) != 0; |
| } |
| |
| void Socket::close() { |
| uv_handle_t* handle = reinterpret_cast<uv_handle_t*>(&tcp_); |
| if (!uv_is_closing(handle)) { |
| uv_close(handle, on_close); |
| } |
| } |
| |
| void Socket::defunct() { |
| close(); |
| is_defunct_ = true; |
| } |
| |
| void Socket::alloc_buffer(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) { |
| Socket* socket = static_cast<Socket*>(handle->data); |
| socket->handler_->alloc_buffer(suggested_size, buf); |
| } |
| |
| void Socket::on_read(uv_stream_t* client, ssize_t nread, const uv_buf_t* buf) { |
| Socket* socket = static_cast<Socket*>(client->data); |
| socket->handle_read(nread, buf); |
| } |
| |
| void Socket::handle_read(ssize_t nread, const uv_buf_t* buf) { |
| if (nread < 0) { |
| if (nread != UV_EOF) { |
| LOG_ERROR("Socket read error '%s'", uv_strerror(nread)); |
| } |
| defunct(); |
| } |
| handler_->on_read(this, nread, buf); |
| } |
| |
| void Socket::on_close(uv_handle_t* handle) { |
| Socket* socket = static_cast<Socket*>(handle->data); |
| socket->handle_close(); |
| } |
| |
| void Socket::handle_close() { |
| LOG_DEBUG("Socket(%p) to host %s closed", static_cast<void*>(this), address_.to_string().c_str()); |
| |
| while (!pending_writes_.is_empty()) { |
| SocketWriteBase* pending_write = pending_writes_.pop_front(); |
| pending_write->on_close(); |
| delete pending_write; |
| } |
| |
| if (handler_) { |
| handler_->on_close(); |
| } |
| dec_ref(); |
| } |
| |
| void Socket::cleanup_free_writes() { |
| for (SocketWriteVec::iterator i = free_writes_.begin(), end = free_writes_.end(); i != end; ++i) { |
| delete *i; |
| } |
| } |