blob: ba2e1e38397faa207a18aa9c60d88d5c15ac1067 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
**/
#include "tmb/internal/sqlite_connection.h"
#include <sqlite3.h>
#include <cassert>
#include "tmb/internal/sqlite_connection_pool.h"
namespace {
// Schema for the message bus, executed by InitializeDatabase():
//   client          - one row per client; disconnect_time stays NULL while
//                     the client is connected.
//   sendable        - (client_id, type_id) pairs registered via
//                     kRegisterSender.
//   receivable      - (client_id, type_id) pairs registered via
//                     kRegisterReceiver.
//   message         - each sent message, its metadata and its payload.
//   queued_message  - per-receiver queue entries that reference a message.
// The ON DELETE CASCADE clauses are only enforced because the
// SQLiteConnection constructor runs "PRAGMA foreign_keys = ON;".
const char kCreateTables[] =
"CREATE TABLE client(\n"
" client_id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
" connect_time DATETIME NOT NULL,\n"
" disconnect_time DATETIME NULL);\n"
"CREATE TABLE sendable(\n"
" client_id INTEGER REFERENCES client(client_id) ON DELETE CASCADE,\n"
" type_id UNSIGNED INTEGER NOT NULL,\n"
" PRIMARY KEY(client_id, type_id));\n"
"CREATE TABLE receivable(\n"
" client_id INTEGER REFERENCES client(client_id) ON DELETE CASCADE,\n"
" type_id UNSIGNED INTEGER NOT NULL,\n"
" PRIMARY KEY(client_id, type_id));\n"
"CREATE TABLE message(\n"
" message_id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
" sender_id INTEGER REFERENCES client(client_id) ON DELETE NO ACTION,\n"
" send_time DATETIME NOT NULL,\n"
" expiration_time DATETIME NULL,\n"
" priority UNSIGNED TINYINT NOT NULL,\n"
" type_id UNSIGNED INTEGER NOT NULL,\n"
" cancellable TINYINT NOT NULL,\n"
" payload BLOB NOT NULL);\n"
"CREATE TABLE queued_message(\n"
" receiver_id INTEGER REFERENCES client(client_id) ON DELETE CASCADE,\n"
" message_id INTEGER REFERENCES message(message_id)\n"
" ON DELETE CASCADE,\n"
" priority UNSIGNED TINYINT NOT NULL,\n"
" expiration_time DATETIME NULL,\n"
" has_expiration_time TINYINT NOT NULL,\n"
" PRIMARY KEY(receiver_id, message_id));\n";
// Secondary indices, created by InitializeDatabase() after the tables.
// queue_index lists its columns in the same order as the ORDER BY clause of
// kGetNextMessages / kGetNextMessagesWithLimit, presumably so that ordering
// can be read straight from the index.
// NOTE(review): "receiveable_index" is misspelled. The name is not referenced
// anywhere else in this file.
const char kCreateIndices[] =
"CREATE INDEX receiveable_index\n"
" ON receivable(type_id);\n"
"CREATE INDEX queue_index\n"
" ON queued_message(receiver_id, priority DESC,\n"
" has_expiration_time DESC, expiration_time ASC,\n"
" message_id ASC);\n"
"CREATE INDEX queue_cancellation_index\n"
" ON queued_message(message_id);\n"
"CREATE INDEX queue_expiration_index\n"
" ON queued_message(expiration_time);\n";
// Drops every table. Tables holding foreign keys are dropped before the
// tables they reference (queued_message first, client last).
const char kDropTables[] =
"DROP TABLE IF EXISTS queued_message;\n"
"DROP TABLE IF EXISTS message;\n"
"DROP TABLE IF EXISTS receivable;\n"
"DROP TABLE IF EXISTS sendable;\n"
"DROP TABLE IF EXISTS client;\n";
// NOTE(chasseur): For simplicity, we just check if the queued_message table
// exists. We could be more rigorous if necessary.
// Yields a single row whose only column is 1 if the table exists, else 0.
const char kCheckDatabaseInitialized[] =
"SELECT EXISTS (SELECT name FROM sqlite_master\n"
" WHERE type='table' AND name='queued_message');\n";
// Inserts a new client row. client_id is assigned by AUTOINCREMENT, and
// connect_time is set to the current time.
const char kInsertClient[] =
"INSERT INTO client(connect_time) VALUES(DATETIME('now'));\n";
// Param 1: client_id. Sets disconnect_time only for a client that is still
// connected. Rows that already have a disconnect_time are left unchanged.
const char kDisconnectClient[] =
"UPDATE client\n"
" SET disconnect_time = DATETIME('now')\n"
" WHERE disconnect_time IS NULL AND client_id = ?;\n";
// Param 1: client_id.
const char kDeleteSenderEntriesForClient[] =
"DELETE FROM sendable WHERE client_id = ?;\n";
// Param 1: client_id.
const char kDeleteReceiverEntriesForClient[] =
"DELETE FROM receivable WHERE client_id = ?;\n";
// Param 1: receiver client_id.
const char kDeleteQueuedMessagesForClient[] =
"DELETE FROM queued_message\n"
" WHERE receiver_id = ?;\n";
// Params: client_id, type_id.
const char kRegisterSender[] =
"INSERT INTO sendable(client_id, type_id) VALUES(?, ?);\n";
// Params: client_id, type_id.
const char kRegisterReceiver[] =
"INSERT INTO receivable(client_id, type_id) VALUES(?, ?);\n";
// Param 1: client_id. Yields 1 if the client exists and is not disconnected.
const char kCheckClientConnected[] =
"SELECT EXISTS (SELECT client_id FROM client\n"
" WHERE disconnect_time IS NULL and client_id = ?);\n";
// Params: client_id, type_id. Yields 1 if that sender registration exists.
const char kCheckSenderRegistered[] =
"SELECT EXISTS (SELECT client_id FROM sendable\n"
" WHERE client_id = ? AND type_id = ?);\n";
// Params: client_id, type_id. Yields 1 if that receiver registration exists.
const char kCheckReceiverRegistered[] =
"SELECT EXISTS (SELECT client_id FROM receivable\n"
" WHERE client_id = ? AND type_id = ?);\n";
// Params 1-8: client_ids. Counts how many of them are connected clients.
// NOTE(review): this and the other "IN (?, ?, ?, ?, ?, ?, ?, ?)" statements
// have exactly eight slots. Presumably callers handle larger sets in batches
// of eight and leave unused slots matching nothing; confirm in callers.
const char kCountClientsInSet[] =
"SELECT COUNT(*) FROM client\n"
" WHERE disconnect_time IS NULL\n"
" AND client_id IN (?, ?, ?, ?, ?, ?, ?, ?);\n";
// Param 1: type_id. Params 2-9: client_ids. Counts how many of those clients
// are registered to receive type_id.
// NOTE: unlike its siblings this statement has no trailing ';'. That is
// harmless, because sqlite3_prepare_v2 does not require one.
const char kCountReceiversInSet[] =
"SELECT COUNT(*) from receivable\n"
" WHERE type_id = ?\n"
" AND client_id IN (?, ?, ?, ?, ?, ?, ?, ?)\n";
// Param 1: type_id. Yields 1 if any client is registered to receive it.
const char kCheckAnyReceiver[] =
"SELECT EXISTS (SELECT client_id FROM receivable WHERE type_id = ?);\n";
// Params: sender_id, expiration time (seconds since the Unix epoch, converted
// by DATETIME(?, 'unixepoch')), priority, type_id, cancellable, payload.
// send_time is set to the current time.
const char kInsertMessage[] =
"INSERT INTO message(sender_id, send_time, expiration_time, priority,\n"
" type_id, cancellable, payload)\n"
" VALUES(?, DATETIME('now'), DATETIME(?, 'unixepoch'), ?, ?, ?, ?);\n";
// Params: message_id, priority, expiration (Unix epoch seconds),
// has_expiration_time, type_id. Queues the message for every client
// registered to receive type_id.
const char kQueueMessageForAllReceivers[] =
"INSERT INTO queued_message(receiver_id, message_id, priority,\n"
" expiration_time, has_expiration_time)\n"
" SELECT client_id, ?, ?, DATETIME(?, 'unixepoch'), ?\n"
" FROM receivable\n"
" WHERE type_id = ?;\n";
// Same params as kQueueMessageForAllReceivers, but queues for at most one
// registered receiver. LIMIT 1 is used without ORDER BY, so which receiver
// is chosen is not specified by the query.
const char kQueueMessageForAnyReceiver[] =
"INSERT INTO queued_message(receiver_id, message_id, priority,\n"
" expiration_time, has_expiration_time)\n"
" SELECT client_id, ?, ?, DATETIME(?, 'unixepoch'), ?\n"
" FROM receivable\n"
" WHERE type_id = ?\n"
" LIMIT 1;\n";
// Same params as kQueueMessageForAllReceivers, plus params 6-13: candidate
// client_ids. Queues for every candidate registered to receive type_id.
const char kQueueMessageForAllReceiversInSet[] =
"INSERT INTO queued_message(receiver_id, message_id, priority,\n"
" expiration_time, has_expiration_time)\n"
" SELECT client_id, ?, ?, DATETIME(?, 'unixepoch'), ?\n"
" FROM receivable\n"
" WHERE type_id = ?\n"
" AND client_id IN (?, ?, ?, ?, ?, ?, ?, ?);\n";
// Same params as kQueueMessageForAllReceiversInSet, but queues for at most
// one matching candidate (LIMIT 1, no ORDER BY).
const char kQueueMessageForAnyReceiverInSet[] =
"INSERT INTO queued_message(receiver_id, message_id, priority,\n"
" expiration_time, has_expiration_time)\n"
" SELECT client_id, ?, ?, DATETIME(?, 'unixepoch'), ?\n"
" FROM receivable\n"
" WHERE type_id = ?\n"
" AND client_id IN (?, ?, ?, ?, ?, ?, ?, ?)\n"
" LIMIT 1;\n";
// TODO(chasseur): This implicitly checks that a client at least exists in the
// clients table, which isn't actually necessary for the unchecked version.
// Find a way to avoid the nested SELECT without repeating common parameters.
// Params: message_id, priority, expiration (Unix epoch seconds),
// has_expiration_time, then params 5-12: client_ids. Receivable
// registrations are not consulted.
const char kQueueMessageForClientsInSetUnchecked[] =
"INSERT INTO queued_message(receiver_id, message_id, priority,\n"
" expiration_time, has_expiration_time)\n"
" SELECT client_id, ?, ?, DATETIME(?, 'unixepoch'), ?\n"
" FROM client\n"
" WHERE client_id IN (?, ?, ?, ?, ?, ?, ?, ?);\n";
// Params: receiver_id, minimum priority. Returns the receiver's unexpired
// queued messages ordered as follows:
//   1. highest priority first;
//   2. within a priority, messages that have an expiration time before those
//      that do not, earliest expiration first;
//   3. then by message_id.
const char kGetNextMessages[] =
"SELECT message.message_id, message.sender_id, message.send_time,\n"
" message.type_id, message.payload\n"
" FROM message, queued_message\n"
" WHERE message.message_id = queued_message.message_id\n"
" AND queued_message.receiver_id = ?\n"
" AND queued_message.priority >= ?\n"
" AND (queued_message.has_expiration_time = 0\n"
" OR queued_message.expiration_time >= DATETIME('now'))\n"
" ORDER BY queued_message.priority DESC,\n"
" queued_message.has_expiration_time DESC,\n"
" queued_message.expiration_time ASC,\n"
" queued_message.message_id ASC;\n";
// Same as kGetNextMessages, plus param 3: maximum number of rows (LIMIT).
const char kGetNextMessagesWithLimit[] =
"SELECT message.message_id, message.sender_id, message.send_time,\n"
" message.type_id, message.payload\n"
" FROM message, queued_message\n"
" WHERE message.message_id = queued_message.message_id\n"
" AND queued_message.receiver_id = ?\n"
" AND queued_message.priority >= ?\n"
" AND (queued_message.has_expiration_time = 0\n"
" OR queued_message.expiration_time >= DATETIME('now'))\n"
" ORDER BY queued_message.priority DESC,\n"
" queued_message.has_expiration_time DESC,\n"
" queued_message.expiration_time ASC,\n"
" queued_message.message_id ASC\n"
" LIMIT ?;\n";
// Param 1: receiver_id. Params 2-9: message_ids to remove from that
// receiver's queue.
const char kDeleteQueuedMessages[] =
"DELETE FROM queued_message\n"
" WHERE receiver_id = ? AND message_id IN (?, ?, ?, ?, ?, ?, ?, ?);\n";
// Removes queue entries whose expiration_time is in the past. Rows with a
// NULL expiration_time never satisfy the comparison and are kept.
const char kDeleteExpiredMessages[] =
"DELETE FROM queued_message WHERE expiration_time < DATETIME('now');\n";
// Params 1-8: message_ids. Deletes only the cancellable ones. Their
// queued_message rows are removed by ON DELETE CASCADE.
const char kCancelMessages[] =
"DELETE FROM message WHERE message_id IN (?, ?, ?, ?, ?, ?, ?, ?)\n"
" AND cancellable = 1;\n";
// Full-table reads of the current bus state.
const char kSelectConnectedClients[] =
"SELECT client_id FROM client WHERE disconnect_time IS NULL;\n";
const char kSelectAllSendable[] =
"SELECT client_id, type_id FROM sendable;\n";
const char kSelectAllReceivable[] =
"SELECT client_id, type_id FROM receivable;\n";
const char kSelectCancellableMessages[] =
"SELECT message_id FROM message WHERE cancellable = 1;\n";
// Every queue entry joined with the message it references.
const char kSelectAllQueuedMessages[] =
"SELECT queued_message.receiver_id, message.message_id,\n"
" message.payload, message.type_id, message.sender_id,\n"
" message.priority, message.send_time, message.expiration_time\n"
" FROM message, queued_message\n"
" WHERE message.message_id = queued_message.message_id;\n";
// Param 1: receiver_id.
const char kCountQueuedMessagesForClient[] =
"SELECT COUNT(*) FROM queued_message WHERE receiver_id = ?;\n";
} // namespace
namespace tmb {
namespace internal {
// Opens (creating if necessary) the SQLite database at `filename` and
// configures the connection:
//   - extended result codes are enabled;
//   - foreign-key enforcement is turned on;
//   - the busy timeout is set to kBusyTimeoutMS;
//   - the schema-independent statements are prepared (BEGIN, BEGIN IMMEDIATE,
//     COMMIT, ROLLBACK, and the "is the database initialized" check).
// If any step before the statement preparation fails, the handle is closed,
// internal_connection_ is left as nullptr and connected_ stays false.
// PrepareStatement() failures are only caught by assert().
SQLiteConnection::SQLiteConnection(const char *filename)
    : internal_connection_(nullptr),
      connected_(false),
      prepared_(false) {
  // Closes whatever handle is currently held and resets it to nullptr.
  // sqlite3_open_v2() may return a handle even when it fails, but it may also
  // leave it null (e.g. on out-of-memory). Both cases must be handled.
  auto release_connection = [this]() {
    if (internal_connection_ != nullptr) {
      const int close_rc = sqlite3_close(internal_connection_);
      assert(close_rc == SQLITE_OK);
      static_cast<void>(close_rc);  // Unused when NDEBUG disables assert().
      internal_connection_ = nullptr;
    }
  };

  // Open connection.
  int rc = sqlite3_open_v2(filename,
                           &internal_connection_,
                           SQLITE_OPEN_READWRITE
                               | SQLITE_OPEN_CREATE
                               | SQLITE_OPEN_NOMUTEX
                               | SQLITE_OPEN_PRIVATECACHE,
                           nullptr);
  if (rc != SQLITE_OK) {
    // Always bail out on a failed open, even when no handle was allocated.
    // Otherwise the configuration calls below would receive a null
    // connection.
    release_connection();
    return;
  }

  // Enable extended result codes for finer-grained error handling.
  rc = sqlite3_extended_result_codes(internal_connection_, 1);
  if (rc != SQLITE_OK) {
    release_connection();
    return;
  }

  // Turn on foreign-key enforcement (needed for the ON DELETE CASCADE
  // clauses in the schema).
  rc = sqlite3_exec(internal_connection_,
                    "PRAGMA foreign_keys = ON;",
                    nullptr, nullptr, nullptr);
  if (rc != SQLITE_OK) {
    release_connection();
    return;
  }

  // Set the busy retry timeout for transactions to wait on locks.
  rc = sqlite3_busy_timeout(internal_connection_, kBusyTimeoutMS);
  if (rc != SQLITE_OK) {
    release_connection();
    return;
  }

  // Prepare non-schema dependent SQL statements.
  PrepareStatement("BEGIN TRANSACTION;", &begin_xact_stmt_);
  PrepareStatement("BEGIN IMMEDIATE TRANSACTION;",
                   &begin_immediate_xact_stmt_);
  PrepareStatement("COMMIT;", &commit_xact_stmt_);
  PrepareStatement("ROLLBACK;", &rollback_xact_stmt_);
  PrepareStatement(kCheckDatabaseInitialized,
                   &check_database_initialized_stmt_);
  connected_ = true;
}
// Finalizes every prepared statement this connection owns, then closes the
// database handle.
// - The five schema-independent statements exist only when the constructor
//   succeeded (connected_).
// - The schema-dependent ones exist only after Prepare() ran (prepared_).
// All statements must be finalized before sqlite3_close() can succeed.
SQLiteConnection::~SQLiteConnection() {
  if (connected_) {
    sqlite3_stmt *const transaction_statements[] = {
        begin_xact_stmt_,
        begin_immediate_xact_stmt_,
        commit_xact_stmt_,
        rollback_xact_stmt_,
        check_database_initialized_stmt_};
    for (sqlite3_stmt *statement : transaction_statements) {
      sqlite3_finalize(statement);
    }
  }

  if (prepared_) {
    sqlite3_stmt *const schema_statements[] = {
        insert_client_stmt_,
        disconnect_client_stmt_,
        delete_sender_entries_for_client_stmt_,
        delete_receiver_entries_for_client_stmt_,
        delete_queued_messages_for_client_stmt_,
        register_sender_stmt_,
        register_receiver_stmt_,
        check_client_connected_stmt_,
        check_sender_registered_stmt_,
        check_receiver_registered_stmt_,
        count_clients_in_set_stmt_,
        count_receivers_in_set_stmt_,
        check_any_receiver_stmt_,
        insert_message_stmt_,
        queue_message_for_all_receivers_stmt_,
        queue_message_for_any_receiver_stmt_,
        queue_message_for_all_receivers_in_set_stmt_,
        queue_message_for_any_receiver_in_set_stmt_,
        queue_message_for_clients_in_set_unchecked_stmt_,
        get_next_messages_stmt_,
        get_next_messages_with_limit_stmt_,
        delete_queued_messages_stmt_,
        delete_expired_messages_stmt_,
        cancel_messages_stmt_,
        select_connected_clients_stmt_,
        select_all_sendable_stmt_,
        select_all_receivable_stmt_,
        select_cancellable_messages_stmt_,
        select_all_queued_messages_stmt_,
        count_queued_messages_for_client_stmt_};
    for (sqlite3_stmt *statement : schema_statements) {
      sqlite3_finalize(statement);
    }
  }

  if (connected_) {
    int close_rc = sqlite3_close(internal_connection_);
    assert(close_rc == SQLITE_OK);
  }
}
bool SQLiteConnection::CheckDatabaseInitialized() {
assert(connected_);
int rc;
do {
rc = sqlite3_step(check_database_initialized_stmt_);
} while (rc == SQLITE_BUSY);
assert(rc == SQLITE_ROW);
bool initialized = sqlite3_column_int(check_database_initialized_stmt_, 0)
!= 0;
do {
rc = sqlite3_step(check_database_initialized_stmt_);
} while (rc == SQLITE_BUSY);
assert(rc == SQLITE_DONE);
sqlite3_reset(check_database_initialized_stmt_);
return initialized;
}
void SQLiteConnection::InitializeDatabase(const bool use_wal) {
assert(connected_);
int rc;
// Write-ahead logging is only supported in SQLite 3.7.0+.
if (use_wal && (SQLITE_VERSION_NUMBER >= 3007000)) {
do {
rc = sqlite3_exec(internal_connection_, "PRAGMA journal_mode = WAL;",
nullptr, nullptr, nullptr);
} while (rc == SQLITE_BUSY);
assert(rc == SQLITE_OK);
}
do {
rc = sqlite3_step(begin_immediate_xact_stmt_);
} while (rc == SQLITE_BUSY);
assert(rc == SQLITE_DONE);
sqlite3_reset(begin_immediate_xact_stmt_);
do {
rc = sqlite3_exec(internal_connection_, kDropTables,
nullptr, nullptr, nullptr);
} while (rc == SQLITE_BUSY);
assert(rc == SQLITE_OK);
do {
rc = sqlite3_exec(internal_connection_, kCreateTables,
nullptr, nullptr, nullptr);
} while (rc == SQLITE_BUSY);
assert(rc == SQLITE_OK);
do {
rc = sqlite3_exec(internal_connection_, kCreateIndices,
nullptr, nullptr, nullptr);
} while (rc == SQLITE_BUSY);
assert(rc == SQLITE_OK);
do {
rc = sqlite3_step(commit_xact_stmt_);
} while (rc == SQLITE_BUSY);
assert(rc == SQLITE_DONE);
sqlite3_reset(commit_xact_stmt_);
if (!prepared_) {
Prepare();
}
}
void SQLiteConnection::Prepare() {
assert(connected_);
PrepareStatement(kInsertClient, &insert_client_stmt_);
PrepareStatement(kDisconnectClient, &disconnect_client_stmt_);
PrepareStatement(kDeleteSenderEntriesForClient,
&delete_sender_entries_for_client_stmt_);
PrepareStatement(kDeleteReceiverEntriesForClient,
&delete_receiver_entries_for_client_stmt_);
PrepareStatement(kDeleteQueuedMessagesForClient,
&delete_queued_messages_for_client_stmt_);
PrepareStatement(kRegisterSender, &register_sender_stmt_);
PrepareStatement(kRegisterReceiver, &register_receiver_stmt_);
PrepareStatement(kCheckClientConnected, &check_client_connected_stmt_);
PrepareStatement(kCheckSenderRegistered, &check_sender_registered_stmt_);
PrepareStatement(kCheckReceiverRegistered, &check_receiver_registered_stmt_);
PrepareStatement(kCountClientsInSet, &count_clients_in_set_stmt_);
PrepareStatement(kCountReceiversInSet, &count_receivers_in_set_stmt_);
PrepareStatement(kCheckAnyReceiver, &check_any_receiver_stmt_);
PrepareStatement(kInsertMessage, &insert_message_stmt_);
PrepareStatement(kQueueMessageForAllReceivers,
&queue_message_for_all_receivers_stmt_);
PrepareStatement(kQueueMessageForAnyReceiver,
&queue_message_for_any_receiver_stmt_);
PrepareStatement(kQueueMessageForAllReceiversInSet,
&queue_message_for_all_receivers_in_set_stmt_);
PrepareStatement(kQueueMessageForAnyReceiverInSet,
&queue_message_for_any_receiver_in_set_stmt_);
PrepareStatement(kQueueMessageForClientsInSetUnchecked,
&queue_message_for_clients_in_set_unchecked_stmt_);
PrepareStatement(kGetNextMessages, &get_next_messages_stmt_);
PrepareStatement(kGetNextMessagesWithLimit,
&get_next_messages_with_limit_stmt_);
PrepareStatement(kDeleteQueuedMessages, &delete_queued_messages_stmt_);
PrepareStatement(kDeleteExpiredMessages, &delete_expired_messages_stmt_);
PrepareStatement(kCancelMessages, &cancel_messages_stmt_);
PrepareStatement(kSelectConnectedClients, &select_connected_clients_stmt_);
PrepareStatement(kSelectAllSendable, &select_all_sendable_stmt_);
PrepareStatement(kSelectAllReceivable, &select_all_receivable_stmt_);
PrepareStatement(kSelectCancellableMessages,
&select_cancellable_messages_stmt_);
PrepareStatement(kSelectAllQueuedMessages,
&select_all_queued_messages_stmt_);
PrepareStatement(kCountQueuedMessagesForClient,
&count_queued_messages_for_client_stmt_);
prepared_ = true;
}
// Compiles `sql` into `*stmt` on this connection, retrying while
// sqlite3_prepare_v2() reports SQLITE_BUSY. A failure is only caught by
// assert(), so it goes unreported in NDEBUG builds.
void SQLiteConnection::PrepareStatement(const char *sql, sqlite3_stmt **stmt) {
  int prepare_rc;
  do {
    prepare_rc = sqlite3_prepare_v2(internal_connection_, sql, -1, stmt,
                                    nullptr);
  } while (prepare_rc == SQLITE_BUSY);
  assert(prepare_rc == SQLITE_OK);
}
// Hands the wrapped connection back to its pool when this object still holds
// one (valid()). An empty, moved-from wrapper does nothing.
PooledSQLiteConnection::~PooledSQLiteConnection() {
  if (!valid()) {
    return;
  }
  pool_->AddToPool(connection_);
}
// Move-assigns from `rhs`:
// - any connection currently held by this object is first returned to its
//   pool;
// - then `rhs`'s connection and pool are taken over, and `rhs` is left empty
//   (both pointers nullptr).
// Self-move-assignment leaves the object unchanged.
PooledSQLiteConnection& PooledSQLiteConnection::operator=(
    PooledSQLiteConnection &&rhs) {  // NOLINT(build/c++11)
  // Without this guard, `x = std::move(x)` would hand x's connection back to
  // the pool and then clear x through the `rhs` alias, silently dropping it.
  if (this == &rhs) {
    return *this;
  }
  if (valid()) {
    pool_->AddToPool(connection_);
  }
  connection_ = rhs.connection_;
  pool_ = rhs.pool_;
  rhs.connection_ = nullptr;
  rhs.pool_ = nullptr;
  return *this;
}
} // namespace internal
} // namespace tmb