blob: eb2416190a3cce4cc64603b06d23ea3d9b909f57 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
extern "C"
{
#include <postgres.h>
#include <access/xact.h>
#include <executor/spi.h>
#include <fmgr.h>
#include <lib/dshash.h>
#include <libpq/crypt.h>
#include <libpq/libpq-be.h>
#include <libpq/libpq.h>
#include <miscadmin.h>
#include <postmaster/bgworker.h>
#include <postmaster/postmaster.h>
#include <storage/ipc.h>
#include <storage/latch.h>
#include <storage/lwlock.h>
#include <storage/procsignal.h>
#include <storage/shmem.h>
#include <utils/backend_status.h>
#include <utils/builtins.h>
#include <utils/dsa.h>
#include <utils/guc.h>
#include <utils/memutils.h>
#include <utils/snapmgr.h>
#include <utils/timestamp.h>
#include <utils/wait_event.h>
}
#undef Abs
#if PG_VERSION_NUM >= 150000
# define AFS_HAVE_SHMEM_REQUEST_HOOK
#endif
#include <arrow/buffer.h>
#include <arrow/builder.h>
#include <arrow/flight/server_middleware.h>
#include <arrow/flight/sql/server.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/reader.h>
#include <arrow/ipc/writer.h>
#include <arrow/table_builder.h>
#include <arrow/util/base64.h>
#include <arrow/util/string.h>
#include <cinttypes>
#include <condition_variable>
#include <fstream>
#include <iterator>
#include <map>
#include <random>
#include <sstream>
#include <type_traits>
#include <arpa/inet.h>
#ifdef __GNUC__
# define AFS_FUNC __PRETTY_FUNCTION__
#else
# define AFS_FUNC __func__
#endif
#ifdef _WIN32
# define PRIsize "Iu"
#else
# define PRIsize "zu"
#endif
// #define AFS_VERBOSE
#ifdef AFS_VERBOSE
# define P(...) ereport(DEBUG5, errmsg_internal(__VA_ARGS__))
#else
# define P(...)
#endif
extern "C"
{
PG_MODULE_MAGIC;
extern PGDLLEXPORT void _PG_init(void);
extern PGDLLEXPORT void afs_executor(Datum datum) pg_attribute_noreturn();
extern PGDLLEXPORT void afs_server(Datum datum) pg_attribute_noreturn();
extern PGDLLEXPORT void afs_main(Datum datum) pg_attribute_noreturn();
}
namespace {
// Identifiers for this module. SharedDataName is the key passed to
// ShmemInitStruct() when looking up the module's shared data; Tag is the
// prefix used in this module's log and error messages.
static const char* LibraryName = "arrow_flight_sql";
static const char* SharedDataName = "arrow-flight-sql: shared data";
static const char* Tag = "arrow-flight-sql";
// Each *Default constant below is paired with a mutable variable of the same
// base name.
// NOTE(review): presumably the mutable variables are configuration settings
// registered elsewhere in this file -- confirm.
static const char* URIDefault = "grpc://127.0.0.1:15432";
static char* URI;
// NOTE(review): unit of the session timeout is not visible here (seconds?) --
// confirm against the code that registers/uses it.
static const int SessionTimeoutDefault = 300;
static int SessionTimeout;
// Default is 1 Mi (1024 * 1024) rows per record batch.
static const int MaxNRowsPerRecordBatchDefault = 1 * 1024 * 1024;
static int MaxNRowsPerRecordBatch;
// Raised by afs_sigterm(); written from signal context, hence
// volatile sig_atomic_t.
static volatile sig_atomic_t GotSIGTERM = false;

// SIGTERM handler: records that the signal arrived and sets this process's
// latch. errno is saved on entry and restored on exit so the interrupted code
// does not observe a changed value.
void
afs_sigterm(SIGNAL_ARGS)
{
    const auto savedErrno = errno;

    GotSIGTERM = true;
    SetLatch(MyLatch);

    errno = savedErrno;
}
// Raised by afs_sighup(); written from signal context, hence
// volatile sig_atomic_t.
static volatile sig_atomic_t GotSIGHUP = false;

// SIGHUP handler: records that the signal arrived and sets this process's
// latch. errno is saved on entry and restored on exit.
void
afs_sighup(SIGNAL_ARGS)
{
    const auto savedErrno = errno;

    GotSIGHUP = true;
    SetLatch(MyLatch);

    errno = savedErrno;
}
// Raised by afs_sigusr1(); written from signal context, hence
// volatile sig_atomic_t.
static volatile sig_atomic_t GotSIGUSR1 = false;

// SIGUSR1 handler: first forwards the signal to PostgreSQL's
// procsignal_sigusr1_handler(), then records that the signal arrived and sets
// this process's latch. errno is saved after the forwarded call and restored
// on exit.
void
afs_sigusr1(SIGNAL_ARGS)
{
    procsignal_sigusr1_handler(postgres_signal_arg);

    const auto savedErrno = errno;

    GotSIGUSR1 = true;
    SetLatch(MyLatch);

    errno = savedErrno;
}
#ifdef AFS_DEBUG
// SIGSEGV handler, compiled only when AFS_DEBUG is defined: raises a FATAL
// report that carries a backtrace.
void
afs_sigsegv(SIGNAL_ARGS)
{
    ereport(FATAL,
            errcode(ERRCODE_INTERNAL_ERROR),
            errmsg("CRASHED!!!"),
            errbacktrace());
}
#endif
// Hook that was installed before ours, if any. Only exists when
// AFS_HAVE_SHMEM_REQUEST_HOOK is defined (PG_VERSION_NUM >= 150000).
#ifdef AFS_HAVE_SHMEM_REQUEST_HOOK
static shmem_request_hook_type PreviousShmemRequestHook = nullptr;
#endif

// Name of the named LWLock tranche this module requests and later looks up.
static const char* LWLockTrancheName = "arrow-flight-sql: lwlock tranche";

// Requests one LWLock in the tranche named LWLockTrancheName. When the
// shmem request hook is available, the previously installed hook (if any) is
// invoked first.
void
afs_shmem_request_hook(void)
{
#ifdef AFS_HAVE_SHMEM_REQUEST_HOOK
    if (PreviousShmemRequestHook != nullptr)
    {
        PreviousShmemRequestHook();
    }
#endif

    RequestNamedLWLockTranche(LWLockTrancheName, 1);
}
// Scope guard for a PostgreSQL memory context.
//
// On construction the given context becomes the current one (the previously
// current context is remembered). On destruction the previous context is
// restored and the given context is deleted, i.e. the guard takes ownership
// of `memoryContext`.
//
// The guard is non-copyable: a copy would switch back and delete the same
// context a second time.
class ScopedMemoryContext {
  public:
    explicit ScopedMemoryContext(MemoryContext memoryContext)
        : memoryContext_(memoryContext),
          oldMemoryContext_(MemoryContextSwitchTo(memoryContext))
    {
    }

    ScopedMemoryContext(const ScopedMemoryContext&) = delete;
    ScopedMemoryContext& operator=(const ScopedMemoryContext&) = delete;

    ~ScopedMemoryContext()
    {
        // Switch away first: the context being deleted must not stay current.
        MemoryContextSwitchTo(oldMemoryContext_);
        MemoryContextDelete(memoryContext_);
    }

  private:
    MemoryContext memoryContext_;
    MemoryContext oldMemoryContext_;
};
// Scope guard that calls StartTransactionCommand() on construction and
// CommitTransactionCommand() on destruction.
//
// Non-copyable: a copy would commit a second time for a single start.
struct ScopedTransaction {
    ScopedTransaction() { StartTransactionCommand(); }
    ScopedTransaction(const ScopedTransaction&) = delete;
    ScopedTransaction& operator=(const ScopedTransaction&) = delete;
    ~ScopedTransaction() { CommitTransactionCommand(); }
};
// Scope guard that pushes the transaction snapshot as the active snapshot on
// construction and pops the active snapshot on destruction.
//
// Non-copyable: a copy would pop a second time for a single push.
struct ScopedSnapshot {
    ScopedSnapshot() { PushActiveSnapshot(GetTransactionSnapshot()); }
    ScopedSnapshot(const ScopedSnapshot&) = delete;
    ScopedSnapshot& operator=(const ScopedSnapshot&) = delete;
    ~ScopedSnapshot() { PopActiveSnapshot(); }
};
// Scope guard that frees an SPI plan with SPI_freeplan() on destruction.
//
// Non-copyable: a copy would free the same plan twice. A null plan (e.g. a
// failed SPI_prepare()) is tolerated and simply not freed.
struct ScopedPlan {
    ScopedPlan(SPIPlanPtr plan) : plan_(plan) {}
    ScopedPlan(const ScopedPlan&) = delete;
    ScopedPlan& operator=(const ScopedPlan&) = delete;
    ~ScopedPlan()
    {
        if (plan_)
        {
            SPI_freeplan(plan_);
        }
    }

    SPIPlanPtr plan_;
};
// Bookkeeping for one ring buffer; manipulated through SharedRingBuffer.
struct SharedRingBufferData {
// DSA allocation that backs the buffer (InvalidDsaPointer when unallocated).
dsa_pointer pointer;
// Number of bytes allocated for the buffer (0 when unallocated).
size_t total;
// Offset of the next byte to read; advanced by SharedRingBuffer::read().
size_t head;
// Offset of the next byte to write; advanced by SharedRingBuffer::write().
// head == tail means the buffer is empty.
size_t tail;
};
// Naive ring buffer implementation. We can improve this later.
//
// We can't call PostgreSQL API in this class.
class SharedRingBuffer {
public:
static void initialize_data(SharedRingBufferData* data)
{
data->pointer = InvalidDsaPointer;
data->total = 0;
data->head = 0;
data->tail = 0;
}
static void allocate_data(SharedRingBufferData* data, dsa_area* area, size_t total)
{
data->pointer = dsa_allocate(area, total);
data->total = total;
data->head = 0;
data->tail = 0;
}
static void free_data(SharedRingBufferData* data, dsa_area* area)
{
if (data->pointer != InvalidDsaPointer)
dsa_free(area, data->pointer);
initialize_data(data);
}
SharedRingBuffer(SharedRingBufferData* data, void* address, const char* tag)
: data_(data), address_(static_cast<uint8_t*>(address)), tag_(tag)
{
}
size_t size() const
{
if (data_->head <= data_->tail)
{
return data_->tail - data_->head;
}
else
{
return (data_->total - data_->head) + data_->tail;
}
}
size_t rest_size() const { return data_->total - size() - 1; }
// Need LWLockAcquire() by caller.
size_t write(const void* data, size_t n)
{
#ifdef AFS_VERBOSE
const char* tag = "write";
#endif
P("%s: %s: %s: before: (%" PRIsize ":%" PRIsize ") %" PRIsize,
Tag,
tag_,
tag,
data_->head,
data_->tail,
n);
if (rest_size() == 0)
{
P("%s: %s: %s: after: no space: (%" PRIsize ":%" PRIsize ") %" PRIsize ":0",
Tag,
tag_,
tag,
data_->head,
data_->tail,
n);
return 0;
}
size_t writtenSize = 0;
auto output = address_;
if (data_->head <= data_->tail)
{
auto restSize = data_->total - data_->tail;
if (data_->head == 0)
{
restSize--;
}
const auto firstHalfWriteSize = std::min(n, restSize);
P("%s: %s: %s: first half: (%" PRIsize ":%" PRIsize ") %" PRIsize
":%" PRIsize,
Tag,
tag_,
tag,
data_->head,
data_->tail,
n,
firstHalfWriteSize);
memcpy(output + data_->tail, data, firstHalfWriteSize);
data_->tail = (data_->tail + firstHalfWriteSize) % data_->total;
n -= firstHalfWriteSize;
writtenSize += firstHalfWriteSize;
}
if (n > 0 && rest_size() > 0)
{
const auto lastHalfWriteSize = std::min(n, data_->head - data_->tail - 1);
P("%s: %s: %s: last half: (%" PRIsize ":%" PRIsize ") %" PRIsize ":%" PRIsize,
Tag,
tag_,
tag,
data_->head,
data_->tail,
n,
lastHalfWriteSize);
memcpy(output + data_->tail,
static_cast<const uint8_t*>(data) + writtenSize,
lastHalfWriteSize);
data_->tail += lastHalfWriteSize;
n -= lastHalfWriteSize;
writtenSize += lastHalfWriteSize;
}
P("%s: %s: %s: after: (%" PRIsize ":%" PRIsize ") %" PRIsize ":%" PRIsize
" [%s] [%s]",
Tag,
tag_,
tag,
data_->head,
data_->tail,
n,
writtenSize,
arrow::HexEncode(reinterpret_cast<const uint8_t*>(data), writtenSize).c_str(),
arrow::HexEncode(reinterpret_cast<uint8_t*>(output), writtenSize).c_str());
return writtenSize;
}
// Need LWLockAcquire() by caller.
size_t read(size_t n, void* output)
{
#ifdef AFS_VERBOSE
const char* tag = "read";
#endif
P("%s: %s: %s: before: (%" PRIsize ":%" PRIsize ") %" PRIsize,
Tag,
tag_,
tag,
data_->head,
data_->tail,
n);
size_t readSize = 0;
const auto input = address_;
if (data_->head > data_->tail)
{
const auto firstHalfReadSize = std::min(n, data_->total - data_->head);
P("%s: %s: %s: first half: (%" PRIsize ":%" PRIsize ") %" PRIsize
":%" PRIsize,
Tag,
tag_,
tag,
data_->head,
data_->tail,
n,
firstHalfReadSize);
memcpy(output, input + data_->head, firstHalfReadSize);
data_->head = (data_->head + firstHalfReadSize) % data_->total;
n -= firstHalfReadSize;
readSize += firstHalfReadSize;
}
if (n > 0 && data_->head != data_->tail)
{
const auto lastHalfReadSize = std::min(n, data_->tail - data_->head);
P("%s: %s: %s: last half: (%" PRIsize ":%" PRIsize ") %" PRIsize ":%" PRIsize,
Tag,
tag_,
tag,
data_->head,
data_->tail,
n,
lastHalfReadSize);
memcpy(static_cast<uint8_t*>(output) + readSize,
input + data_->head,
lastHalfReadSize);
data_->head += lastHalfReadSize;
n -= lastHalfReadSize;
readSize += lastHalfReadSize;
}
P("%s: %s: %s: after: (%" PRIsize ":%" PRIsize ") %" PRIsize ":%" PRIsize
" [%s] [%s]",
Tag,
tag_,
tag,
data_->head,
data_->tail,
n,
readSize,
arrow::HexEncode(reinterpret_cast<const uint8_t*>(output), readSize).c_str(),
arrow::HexEncode(reinterpret_cast<const uint8_t*>(input), readSize).c_str());
return readSize;
}
private:
SharedRingBufferData* data_;
uint8_t* address_;
const char* tag_;
};
// Kinds of request recorded in a session (see SharedSessionData::action).
enum class Action
{
    None,
    Select,
    Update,
    Prepare,
    ClosePreparedStatement,
    SetParameters,
    SelectPreparedStatement,
    UpdatePreparedStatement,
};

// Returns a static, human-readable name for `action`
// ("Action::Unknown" for any value that is not a declared enumerator).
const char*
action_name(Action action)
{
    // Indexed by the enumerator's numeric value; keep in declaration order.
    static const char* const names[] = {
        "Action::None",
        "Action::Select",
        "Action::Update",
        "Action::Prepare",
        "Action::ClosePreparedStatement",
        "Action::SetParameters",
        "Action::SelectPreparedStatement",
        "Action::UpdatePreparedStatement",
    };
    constexpr size_t nNames = sizeof(names) / sizeof(names[0]);
    static_assert(static_cast<size_t>(Action::UpdatePreparedStatement) + 1 == nNames,
                  "names must have one entry per Action enumerator");

    // Negative values become huge after the cast and fall through to Unknown.
    const auto index = static_cast<size_t>(action);
    if (index < nNames)
    {
        return names[index];
    }
    return "Action::Unknown";
}
// Replaces the string stored behind `pointer` with a copy of `input`.
//
// Any existing allocation is freed first. An empty `input` leaves `pointer`
// as InvalidDsaPointer; otherwise a NUL-terminated copy is allocated in
// `area`.
void
dsa_pointer_set_string(dsa_pointer& pointer, dsa_area* area, const std::string& input)
{
    if (DsaPointerIsValid(pointer))
    {
        dsa_free(area, pointer);
        pointer = InvalidDsaPointer;
    }

    if (!input.empty())
    {
        // +1 to include the terminating NUL.
        const auto nBytes = input.size() + 1;
        pointer = dsa_allocate(area, nBytes);
        memcpy(dsa_get_address(area, pointer), input.c_str(), nBytes);
    }
}
// Session data used in only one process.
//
// They are based on data in SharedSessionData and process local data.
struct LocalSessionData {
    LocalSessionData() = default;

    // Session identifier (0 until assigned).
    uint64_t id = 0;
    // Starts out false; set by code outside this struct.
    bool valid = false;
    // PID of the peer process (InvalidPid until known).
    pid_t peerPID = InvalidPid;
    // Ring buffer bookkeeping and the address of its storage
    // (both null until attached).
    SharedRingBufferData* bufferData = nullptr;
    void* bufferAddress = nullptr;
    // Error text, if any has been recorded.
    std::optional<std::string> errorMessage = std::nullopt;
};
// Session data shared with multiple processes. LWLockAcquire() with
// Processor's lock_ is needed when we change this data.
//
// This struct must have only data (don't add methods) to use with
// dshash.
struct SharedSessionData {
// Session identifier. NOTE(review): presumably the dshash key
// (SessionsParams uses a uint64_t-sized key) -- confirm.
uint64_t id;
// Error text as a DSA string (InvalidDsaPointer when unset).
dsa_pointer errorMessage;
// PID of the executor process (InvalidPid after initialization).
pid_t executorPID;
// Lifecycle flags; all false after shared_session_data_initialize().
bool started;
bool initialized;
bool finished;
// Connection information, stored as DSA strings by
// shared_session_data_initialize().
dsa_pointer databaseName;
dsa_pointer userName;
dsa_pointer password;
dsa_pointer clientAddress;
// Current request kind (Action::None after initialization).
Action action;
// Request payloads/results as DSA strings (InvalidDsaPointer when unset).
dsa_pointer selectQuery;
dsa_pointer updateQuery;
// -1 after initialization.
int64_t nUpdatedRecords;
dsa_pointer prepareQuery;
dsa_pointer preparedStatementHandle;
bool setParametersFinished;
// Ring buffer bookkeeping for this session (see SharedRingBuffer).
SharedRingBufferData bufferData;
};
// Puts a freshly obtained `session` into its initial state and stores the
// given connection information as DSA strings allocated from `area`.
//
// `session` is treated as uninitialized memory: nothing it currently holds is
// freed. session->id is not touched.
void
shared_session_data_initialize(SharedSessionData* session,
                               dsa_area* area,
                               const std::string& databaseName,
                               const std::string& userName,
                               const std::string& password,
                               const std::string& clientAddress)
{
    session->errorMessage = InvalidDsaPointer;
    session->executorPID = InvalidPid;
    session->started = false;
    session->initialized = false;
    session->finished = false;

    // dsa_pointer_set_string() frees whatever valid-looking pointer it finds
    // in its target. These fields have not been initialized yet, so they must
    // be invalidated first; otherwise garbage (or a stale pointer left by a
    // previous user of this memory) would be passed to dsa_free().
    session->databaseName = InvalidDsaPointer;
    session->userName = InvalidDsaPointer;
    session->password = InvalidDsaPointer;
    session->clientAddress = InvalidDsaPointer;
    dsa_pointer_set_string(session->databaseName, area, databaseName);
    dsa_pointer_set_string(session->userName, area, userName);
    dsa_pointer_set_string(session->password, area, password);
    dsa_pointer_set_string(session->clientAddress, area, clientAddress);

    session->action = Action::None;
    session->selectQuery = InvalidDsaPointer;
    session->updateQuery = InvalidDsaPointer;
    session->nUpdatedRecords = -1;
    session->prepareQuery = InvalidDsaPointer;
    session->preparedStatementHandle = InvalidDsaPointer;
    session->setParametersFinished = false;
    SharedRingBuffer::initialize_data(&(session->bufferData));
}
// Releases every DSA allocation owned by `session` (all string fields and the
// ring buffer storage) back to `area`.
//
// Each freed pointer is reset to InvalidDsaPointer so that calling this
// function again, or re-using the memory, cannot free the same allocation
// twice.
void
shared_session_data_finalize(SharedSessionData* session, dsa_area* area)
{
    const auto release = [area](dsa_pointer& pointer) {
        if (DsaPointerIsValid(pointer))
        {
            dsa_free(area, pointer);
            pointer = InvalidDsaPointer;
        }
    };

    release(session->errorMessage);
    release(session->databaseName);
    release(session->userName);
    release(session->password);
    // Allocated by shared_session_data_initialize() like the three fields
    // above, so it has to be released here as well.
    release(session->clientAddress);
    release(session->selectQuery);
    release(session->updateQuery);
    release(session->prepareQuery);
    release(session->preparedStatementHandle);
    SharedRingBuffer::free_data(&(session->bufferData), area);
}
// Scope guard that calls dshash_release_lock(sessions, data) on destruction,
// i.e. releases the lock held on a session entry of the sessions table.
//
// Non-copyable: a copy would release the same lock twice.
class SharedSessionReleaser {
  public:
    explicit SharedSessionReleaser(dshash_table* sessions, SharedSessionData* data)
        : sessions_(sessions), data_(data)
    {
    }

    SharedSessionReleaser(const SharedSessionReleaser&) = delete;
    SharedSessionReleaser& operator=(const SharedSessionReleaser&) = delete;

    ~SharedSessionReleaser() { dshash_release_lock(sessions_, data_); }

  private:
    dshash_table* sessions_;
    SharedSessionData* data_;
};
// Parameters of the dshash table that stores SharedSessionData entries.
// The initializer is positional; the last field (tranche_id) is filled in at
// run time (see WorkerProcessor's constructor).
static dshash_parameters SessionsParams = {
// Key size. NOTE(review): presumably the key is SharedSessionData::id --
// confirm.
sizeof(uint64_t),
// Entry size.
sizeof(SharedSessionData),
// Keys are compared and hashed as raw bytes.
dshash_memcmp,
dshash_memhash,
#if PG_VERSION_NUM >= 170000
// PostgreSQL 17 and later take an additional key-copy function.
dshash_memcpy,
#endif
/* .tranche_id = */ 0, // Set later because this is determined dynamically.
};
// Module-wide data looked up with ShmemInitStruct(SharedDataName, ...).
struct SharedData {
// NOTE(review): presumably the LWLock tranche ID used for the DSA area --
// its use is outside this section; confirm.
int trancheID;
// Handle passed to dsa_attach() by worker processes.
dsa_handle handle;
// Tranche ID copied into SessionsParams.tranche_id before attaching to the
// sessions table.
int sessionsTrancheID;
// Handle passed to dshash_attach() for the sessions table.
dshash_table_handle sessionsHandle;
// NOTE(review): presumably the PIDs of the server and main processes --
// they are not read in this section; confirm.
pid_t serverPID;
pid_t mainPID;
};
// Base class of the processes that exchange data through a SharedRingBuffer.
//
// It holds the shared state handles (shared data, DSA area, LWLock) and
// implements wait(), which signals the peer process with SIGUSR1 and then
// delegates to wait_internal(). Subclasses provide read()/write(), the peer
// lookup and the actual waiting.
class Processor {
  public:
    // What wait() is waiting for. See get_waiting_buffer_size() for the
    // buffer quantity associated with each mode.
    enum class WaitMode
    {
        Read,
        Written,
    };

    // `tag` labels this processor in log and error messages. The shared
    // state members start out null; subclasses fill them in.
    Processor(const char* tag)
        : tag_(tag), sharedData_(nullptr), area_(nullptr), lock_(nullptr)
    {
    }

    // Detaches from the DSA area if one was attached.
    virtual ~Processor()
    {
        if (area_ != nullptr)
        {
            dsa_detach(area_);
        }
    }

    const char* tag() { return tag_; }

    // Acquire (exclusively) / release the processor's LWLock.
    void lock_acquire() { LWLockAcquire(lock_, LW_EXCLUSIVE); }
    void lock_release() { LWLockRelease(lock_); }

    // Executor and Server must implement this. The base version reads
    // nothing and reports 0 bytes.
    virtual arrow::Result<size_t> read(LocalSessionData* session,
                                       SharedRingBuffer* buffer,
                                       size_t n,
                                       void* output)
    {
        return 0;
    }

    // Executor and Server must implement this. The base version writes
    // nothing and reports 0 bytes.
    virtual arrow::Result<size_t> write(LocalSessionData* session,
                                        SharedRingBuffer* buffer,
                                        const void* data,
                                        size_t n)
    {
        return 0;
    }

    // Sends SIGUSR1 to the peer process of `session` and then waits via
    // wait_internal(). Returns IOError when no peer PID is available.
    arrow::Status wait(LocalSessionData* session, SharedRingBuffer* buffer, WaitMode mode)
    {
        const char* tag = "wait written";
        if (mode == WaitMode::Read)
        {
            tag = "wait read";
        }

        const auto peerPID = peer_pid(session);
        if (ARROW_PREDICT_FALSE(peerPID == InvalidPid))
        {
            return arrow::Status::IOError(
                Tag, ": ", tag_, ": ", tag, ": ", peer_name(), ": not alive");
        }

        P("%s: %s: %s: %s: kill: %d", Tag, tag_, tag, peer_name(), peerPID);
        kill(peerPID, SIGUSR1);
        return wait_internal(session, buffer, mode, tag);
    }

  protected:
    // Stores `input` behind `pointer` using this processor's DSA area.
    void set_shared_string(dsa_pointer& pointer, const std::string& input)
    {
        dsa_pointer_set_string(pointer, area_, input);
    }

    // PID of the peer process for `session`; the base version knows none.
    virtual pid_t peer_pid(LocalSessionData* session) { return InvalidPid; }

    // Name of the peer used in messages.
    virtual const char* peer_name() { return "unknown"; }

    // Performs the actual wait; the base version is not implemented.
    virtual arrow::Status wait_internal(LocalSessionData* session,
                                        SharedRingBuffer* buffer,
                                        WaitMode mode,
                                        const char* tag)
    {
        return arrow::Status::NotImplemented("wait_internal");
    }

    // Buffer quantity associated with `mode`: the free space for
    // WaitMode::Read, the stored size otherwise.
    size_t get_waiting_buffer_size(SharedRingBuffer* buffer, WaitMode mode)
    {
        return (mode == WaitMode::Read) ? buffer->rest_size() : buffer->size();
    }

    const char* tag_;
    SharedData* sharedData_;
    dsa_area* area_;
    LWLock* lock_;
};
// Scope guard for a Processor's lock: lock_acquire() on construction,
// lock_release() on destruction.
//
// Non-copyable: a copy would release the lock a second time.
struct ProcessorLockGuard {
    ProcessorLockGuard(Processor* processor) : processor_(processor)
    {
        processor_->lock_acquire();
    }

    ProcessorLockGuard(const ProcessorLockGuard&) = delete;
    ProcessorLockGuard& operator=(const ProcessorLockGuard&) = delete;

    ~ProcessorLockGuard() { processor_->lock_release(); }

  private:
    Processor* processor_;
};
// arrow::io::InputStream bound to a Processor and a session. The raw
// Read(nBytes, out) overload is defined outside this class body.
class SharedRingBufferInputStream : public arrow::io::InputStream {
  public:
    SharedRingBufferInputStream(Processor* processor,
                                std::shared_ptr<LocalSessionData> localSession)
        : arrow::io::InputStream(),
          processor_(processor),
          localSession_(std::move(localSession)),
          position_(0),
          is_open_(true)
    {
    }

    // Marks the stream as closed; always succeeds.
    arrow::Status Close() override
    {
        is_open_ = false;
        return arrow::Status::OK();
    }

    bool closed() const override { return !is_open_; }

    // Current value of the stream's position counter.
    arrow::Result<int64_t> Tell() const override { return position_; }

    arrow::Result<int64_t> Read(int64_t nBytes, void* out) override;

    // Reads up to `nBytes` into a newly allocated buffer, shrunk to the
    // number of bytes actually read.
    arrow::Result<std::shared_ptr<arrow::Buffer>> Read(int64_t nBytes) override
    {
        ARROW_ASSIGN_OR_RAISE(auto buffer, arrow::AllocateResizableBuffer(nBytes));
        ARROW_ASSIGN_OR_RAISE(auto readBytes, Read(nBytes, buffer->mutable_data()));
        // false: do not shrink the allocation itself.
        ARROW_RETURN_NOT_OK(buffer->Resize(readBytes, false));
        buffer->ZeroPadding();
        return std::move(buffer);
    }

  private:
    Processor* processor_;
    std::shared_ptr<LocalSessionData> localSession_;
    int64_t position_;
    bool is_open_;
};
// arrow::io::OutputStream bound to a Processor and a session. The raw
// Write(data, nBytes) overload is defined outside this class body.
class SharedRingBufferOutputStream : public arrow::io::OutputStream {
  public:
    SharedRingBufferOutputStream(Processor* processor,
                                 std::shared_ptr<LocalSessionData> localSession)
        : arrow::io::OutputStream(),
          processor_(processor),
          localSession_(std::move(localSession)),
          position_(0),
          is_open_(true)
    {
    }

    // Marks the stream as closed; always succeeds.
    arrow::Status Close() override
    {
        is_open_ = false;
        return arrow::Status::OK();
    }

    bool closed() const override { return !is_open_; }

    // Current value of the stream's position counter.
    arrow::Result<int64_t> Tell() const override { return position_; }

    arrow::Status Write(const void* data, int64_t nBytes) override;

    // Keep the base class's other Write() overloads visible.
    using arrow::io::OutputStream::Write;

  private:
    Processor* processor_;
    std::shared_ptr<LocalSessionData> localSession_;
    int64_t position_;
    bool is_open_;
};
// Processor that attaches to already existing shared state: the shared data
// struct, the DSA area, the sessions dshash table and the module's LWLock.
class WorkerProcessor : public Processor {
  public:
    // Attaches to the shared state while holding AddinShmemInitLock.
    // Raises ERROR (after releasing the lock) when the shared data struct
    // has not been created yet.
    explicit WorkerProcessor(const char* tag) : Processor(tag), sessions_(nullptr)
    {
        LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);

        bool found;
        // NOTE(review): sizeof(sharedData_) is the size of a pointer, not
        // sizeof(SharedData). ShmemInitStruct() requires the same size at
        // every call site for a given name, so this must be changed together
        // with the code that creates the struct (outside this section).
        sharedData_ = static_cast<SharedData*>(
            ShmemInitStruct(SharedDataName, sizeof(sharedData_), &found));
        if (!found)
        {
            LWLockRelease(AddinShmemInitLock);
            ereport(ERROR,
                    errcode(ERRCODE_INTERNAL_ERROR),
                    errmsg("%s: %s: shared data isn't created yet", Tag, tag_));
        }

        area_ = dsa_attach(sharedData_->handle);

        // The tranche ID is only known at run time; fill it in before
        // attaching to the sessions table.
        SessionsParams.tranche_id = sharedData_->sessionsTrancheID;
        sessions_ =
            dshash_attach(area_, &SessionsParams, sharedData_->sessionsHandle, nullptr);

        lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);

        LWLockRelease(AddinShmemInitLock);
    }

    // Detaches from the sessions table.
    ~WorkerProcessor() override { dshash_detach(sessions_); }

  protected:
    dshash_table* sessions_;
};
// Arrow type visitor that maps an Arrow data type to a PostgreSQL type OID.
//
// After a successful Accept(), oid() returns the mapped OID. 8- and 16-bit
// integers (signed or unsigned) map to INT2OID, 32-bit to INT4OID and 64-bit
// to INT8OID. Types without a Visit() overload here fall back to the
// arrow::TypeVisitor base implementation.
class ArrowPGTypeConverter : public arrow::TypeVisitor {
  public:
    explicit ArrowPGTypeConverter() : oid_(InvalidOid) {}

    // OID recorded by the most recent successful visit
    // (InvalidOid before any visit).
    Oid oid() const { return oid_; }

    arrow::Status Visit(const arrow::Int8Type&) override { return assign(INT2OID); }
    arrow::Status Visit(const arrow::UInt8Type&) override { return assign(INT2OID); }
    arrow::Status Visit(const arrow::Int16Type&) override { return assign(INT2OID); }
    arrow::Status Visit(const arrow::UInt16Type&) override { return assign(INT2OID); }
    arrow::Status Visit(const arrow::Int32Type&) override { return assign(INT4OID); }
    arrow::Status Visit(const arrow::UInt32Type&) override { return assign(INT4OID); }
    arrow::Status Visit(const arrow::Int64Type&) override { return assign(INT8OID); }
    arrow::Status Visit(const arrow::UInt64Type&) override { return assign(INT8OID); }
    arrow::Status Visit(const arrow::FloatType&) override { return assign(FLOAT4OID); }
    arrow::Status Visit(const arrow::DoubleType&) override { return assign(FLOAT8OID); }
    arrow::Status Visit(const arrow::StringType&) override { return assign(TEXTOID); }
    arrow::Status Visit(const arrow::BinaryType&) override { return assign(BYTEAOID); }
    arrow::Status Visit(const arrow::TimestampType&) override
    {
        return assign(TIMESTAMPOID);
    }

  private:
    // Records `oid` as the result and reports success.
    arrow::Status assign(Oid oid)
    {
        oid_ = oid;
        return arrow::Status::OK();
    }

    Oid oid_;
};
// Arrow array visitor that converts the value at one row of an array into a
// PostgreSQL Datum, written to the Datum passed to the constructor.
//
// Array types without a Visit() overload here fall back to the
// arrow::ArrayVisitor base implementation.
class ArrowPGValueConverter : public arrow::ArrayVisitor {
  public:
    // `i_row` selects the row to convert; `datum` receives the result.
    explicit ArrowPGValueConverter(int64_t i_row, Datum& datum)
        : i_row_(i_row), datum_(datum)
    {
    }

    arrow::Status Visit(const arrow::Int8Array& array) override
    {
        return store(Int8GetDatum(array.Value(i_row_)));
    }

    arrow::Status Visit(const arrow::UInt8Array& array) override
    {
        return store(UInt8GetDatum(array.Value(i_row_)));
    }

    arrow::Status Visit(const arrow::Int16Array& array) override
    {
        return store(Int16GetDatum(array.Value(i_row_)));
    }

    arrow::Status Visit(const arrow::UInt16Array& array) override
    {
        return store(UInt16GetDatum(array.Value(i_row_)));
    }

    arrow::Status Visit(const arrow::Int32Array& array) override
    {
        return store(Int32GetDatum(array.Value(i_row_)));
    }

    arrow::Status Visit(const arrow::UInt32Array& array) override
    {
        return store(UInt32GetDatum(array.Value(i_row_)));
    }

    arrow::Status Visit(const arrow::Int64Array& array) override
    {
        return store(Int64GetDatum(array.Value(i_row_)));
    }

    arrow::Status Visit(const arrow::UInt64Array& array) override
    {
        return store(UInt64GetDatum(array.Value(i_row_)));
    }

    arrow::Status Visit(const arrow::FloatArray& array) override
    {
        return store(Float4GetDatum(array.Value(i_row_)));
    }

    arrow::Status Visit(const arrow::DoubleArray& array) override
    {
        return store(Float8GetDatum(array.Value(i_row_)));
    }

    // String and binary values are copied into a newly allocated varlena.
    arrow::Status Visit(const arrow::StringArray& array) override
    {
        auto value = array.GetView(i_row_);
        return store(
            PointerGetDatum(cstring_to_text_with_len(value.data(), value.length())));
    }

    arrow::Status Visit(const arrow::BinaryArray& array) override
    {
        auto value = array.GetView(i_row_);
        return store(
            PointerGetDatum(cstring_to_text_with_len(value.data(), value.length())));
    }

    // Timestamps are scaled to microseconds and shifted from the UNIX epoch
    // to the PostgreSQL epoch.
    arrow::Status Visit(const arrow::TimestampArray& array) override
    {
        const auto unit =
            std::static_pointer_cast<arrow::TimestampType>(array.type())->unit();
        const auto raw = array.Value(i_row_);
        Timestamp value = 0;
        switch (unit)
        {
            case arrow::TimeUnit::SECOND:
                value = raw * 1000000;
                break;
            case arrow::TimeUnit::MILLI:
                value = raw * 1000;
                break;
            case arrow::TimeUnit::MICRO:
                value = raw;
                break;
            case arrow::TimeUnit::NANO:
                value = raw / 1000;
                break;
            default:
                return arrow::Status::NotImplemented("Unsupported time unit: ", unit);
        }
        // Arrow uses UNIX epoch (1970-01-01) but PostgreSQL uses 2000-01-01.
        value -= (POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * USECS_PER_DAY;
        return store(TimestampGetDatum(value));
    }

  private:
    // Writes `value` to the output Datum and reports success.
    arrow::Status store(Datum value)
    {
        datum_ = value;
        return arrow::Status::OK();
    }

    int64_t i_row_;
    Datum& datum_;
};
// Base class of the per-column helpers that append values taken from
// SPI_tuptable to an Arrow array builder.
class ArrowArrayBuilderBase {
  public:
    // Creates the helper matching the PostgreSQL type of `attribute`.
    // `iAttribute` is the 0-based column index.
    static arrow::Result<std::unique_ptr<ArrowArrayBuilderBase>> make(
        Form_pg_attribute attribute, int iAttribute, arrow::ArrayBuilder* builder);

    // Returns the Arrow data type used for the PostgreSQL type of
    // `attribute`.
    static arrow::Result<std::shared_ptr<arrow::DataType>> arrow_type(
        Form_pg_attribute attribute);

    ArrowArrayBuilderBase(Form_pg_attribute attribute, int iAttribute)
        : attribute_(attribute), iAttribute_(iAttribute)
    {
    }

    virtual ~ArrowArrayBuilderBase() = default;

    // Appends the tuples in [iTuple, iTupleEnd). Columns declared NOT NULL
    // take the path that skips null handling.
    arrow::Status build(uint64_t iTuple, uint64_t iTupleEnd)
    {
        return attribute_->attnotnull ? build_not_null(iTuple, iTupleEnd)
                                      : build_may_null(iTuple, iTupleEnd);
    }

  protected:
    Form_pg_attribute attribute_;
    int iAttribute_;

    virtual arrow::Status build_not_null(uint64_t iTuple, uint64_t iTupleEnd) = 0;
    virtual arrow::Status build_may_null(uint64_t iTuple, uint64_t iTupleEnd) = 0;
};
template <typename ArrowType, typename Enable = void>
class ArrowArrayBuilder;

// Builder helper for Arrow types that have a C value type (integers,
// floating point numbers, timestamps). Values are collected into a vector
// and appended to the Arrow builder in one call.
template <typename ArrowType>
class ArrowArrayBuilder<ArrowType, arrow::enable_if_has_c_type<ArrowType>>
    : public ArrowArrayBuilderBase {
  private:
    using CType = typename arrow::TypeTraits<ArrowType>::CType;
    using BuilderType = typename arrow::TypeTraits<ArrowType>::BuilderType;

  public:
    // `builder` must be the Arrow builder for ArrowType.
    ArrowArrayBuilder(Form_pg_attribute attribute,
                      int iAttribute,
                      arrow::ArrayBuilder* builder)
        : ArrowArrayBuilderBase(attribute, iAttribute),
          builder_(static_cast<BuilderType*>(builder)),
          values_(),
          validBytes_()
    {
    }

  private:
    BuilderType* builder_;
    std::vector<CType> values_;
    std::vector<uint8_t> validBytes_;

    // Fetches this column's Datum from tuple `iRow` of SPI_tuptable.
    // SPI attribute numbers are 1-based, hence the + 1.
    Datum fetch(uint64_t iRow, bool* isNull)
    {
        return SPI_getbinval(
            SPI_tuptable->vals[iRow], SPI_tuptable->tupdesc, iAttribute_ + 1, isNull);
    }

    // Appends tuples [iTuple, iTupleEnd) without looking at the null flag.
    arrow::Status build_not_null(uint64_t iTuple, uint64_t iTupleEnd)
    {
        const uint64_t nTuples = iTupleEnd - iTuple;
        values_.resize(nTuples);
        for (uint64_t offset = 0; offset < nTuples; ++offset)
        {
            bool isNull;
            values_[offset] = convert_value(fetch(iTuple + offset, &isNull));
        }
        return builder_->AppendValues(values_.data(), values_.size());
    }

    // Appends tuples [iTuple, iTupleEnd); the validity bytes are only passed
    // to the builder when at least one null was seen.
    arrow::Status build_may_null(uint64_t iTuple, uint64_t iTupleEnd)
    {
        const uint64_t nTuples = iTupleEnd - iTuple;
        values_.resize(nTuples);
        validBytes_.resize(nTuples);
        bool haveNull = false;
        for (uint64_t offset = 0; offset < nTuples; ++offset)
        {
            bool isNull;
            const auto datum = fetch(iTuple + offset, &isNull);
            if (isNull)
            {
                haveNull = true;
                validBytes_[offset] = 0;
                continue;
            }
            validBytes_[offset] = 1;
            values_[offset] = convert_value(datum);
        }
        if (!haveNull)
        {
            return builder_->AppendValues(values_.data(), values_.size());
        }
        return builder_->AppendValues(
            values_.data(), values_.size(), validBytes_.data());
    }

    // Extracts the C value for ArrowType from `datum`.
    CType convert_value(Datum datum)
    {
        if constexpr (std::is_same_v<ArrowType, arrow::Int16Type>)
        {
            return DatumGetInt16(datum);
        }
        else if constexpr (std::is_same_v<ArrowType, arrow::Int32Type>)
        {
            return DatumGetInt32(datum);
        }
        else if constexpr (std::is_same_v<ArrowType, arrow::Int64Type>)
        {
            return DatumGetInt64(datum);
        }
        else if constexpr (std::is_same_v<ArrowType, arrow::FloatType>)
        {
            return DatumGetFloat4(datum);
        }
        else if constexpr (std::is_same_v<ArrowType, arrow::DoubleType>)
        {
            return DatumGetFloat8(datum);
        }
        else if constexpr (std::is_same_v<ArrowType, arrow::TimestampType>)
        {
            // Arrow uses UNIX epoch (1970-01-01) but PostgreSQL
            // uses 2000-01-01.
            return DatumGetTimestamp(datum) +
                   ((POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * USECS_PER_DAY);
        }
        else
        {
            // Only reached (at compile time) for an ArrowType that has no
            // conversion above.
            static_assert(!std::is_same_v<ArrowType, ArrowType>,
                          "no Datum conversion for this Arrow type");
        }
    }
};
template <typename ArrowType>
class ArrowArrayBuilder<ArrowType, arrow::enable_if_base_binary<ArrowType>>
: public ArrowArrayBuilderBase {
public:
ArrowArrayBuilder(Form_pg_attribute attribute,
int iAttribute,
arrow::ArrayBuilder* builder)
: ArrowArrayBuilderBase(attribute, iAttribute),
builder_(static_cast<BuilderType*>(builder)),
values_(),
validBytes_()
{
}
private:
using BuilderType = typename arrow::TypeTraits<ArrowType>::BuilderType;
BuilderType* builder_;
std::vector<std::string> values_;
std::vector<uint8_t> validBytes_;
arrow::Status build_not_null(uint64_t iTuple, uint64_t iTupleEnd)
{
uint64_t n = iTupleEnd - iTuple;
values_.resize(n);
for (uint64_t i = 0; i < n; ++i)
{
bool isNull;
auto datum = SPI_getbinval(SPI_tuptable->vals[iTuple + i],
SPI_tuptable->tupdesc,
iAttribute_ + 1,
&isNull);
values_[i] = std::string(VARDATA_ANY(datum), VARSIZE_ANY_EXHDR(datum));
}
return builder_->AppendValues(values_);
}
arrow::Status build_may_null(uint64_t iTuple, uint64_t iTupleEnd)
{
bool haveNull = false;
uint64_t n = iTupleEnd - iTuple;
values_.resize(n);
validBytes_.resize(n);
for (uint64_t i = 0; i < n; ++i)
{
bool isNull;
auto datum = SPI_getbinval(SPI_tuptable->vals[iTuple + i],
SPI_tuptable->tupdesc,
iAttribute_ + 1,
&isNull);
if (isNull)
{
haveNull = true;
validBytes_[i] = 0;
}
else
{
validBytes_[i] = 1;
values_[i] = std::string(VARDATA_ANY(datum), VARSIZE_ANY_EXHDR(datum));
}
}
if (haveNull)
{
return builder_->AppendValues(values_, validBytes_.data());
}
else
{
return builder_->AppendValues(values_);
}
}
};
// Creates the builder helper that matches the PostgreSQL type of
// `attribute`. VARCHAR and TEXT share the string helper. Returns
// NotImplemented for any other type OID.
arrow::Result<std::unique_ptr<ArrowArrayBuilderBase>>
ArrowArrayBuilderBase::make(Form_pg_attribute attribute,
                            int iAttribute,
                            arrow::ArrayBuilder* builder)
{
    const auto typeOid = attribute->atttypid;
    switch (typeOid)
    {
        // Integers
        case INT2OID:
            return std::make_unique<ArrowArrayBuilder<arrow::Int16Type>>(
                attribute, iAttribute, builder);
        case INT4OID:
            return std::make_unique<ArrowArrayBuilder<arrow::Int32Type>>(
                attribute, iAttribute, builder);
        case INT8OID:
            return std::make_unique<ArrowArrayBuilder<arrow::Int64Type>>(
                attribute, iAttribute, builder);

        // Floating point numbers
        case FLOAT4OID:
            return std::make_unique<ArrowArrayBuilder<arrow::FloatType>>(
                attribute, iAttribute, builder);
        case FLOAT8OID:
            return std::make_unique<ArrowArrayBuilder<arrow::DoubleType>>(
                attribute, iAttribute, builder);

        // Variable-length data
        case VARCHAROID:
        case TEXTOID:
            return std::make_unique<ArrowArrayBuilder<arrow::StringType>>(
                attribute, iAttribute, builder);
        case BYTEAOID:
            return std::make_unique<ArrowArrayBuilder<arrow::BinaryType>>(
                attribute, iAttribute, builder);

        // Date/time
        case TIMESTAMPOID:
            return std::make_unique<ArrowArrayBuilder<arrow::TimestampType>>(
                attribute, iAttribute, builder);

        default:
            return arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
                                                 typeOid);
    }
}
// Returns the Arrow data type used for the PostgreSQL type of `attribute`.
// VARCHAR and TEXT both map to utf8; timestamps use microsecond resolution.
// Returns NotImplemented for any other type OID.
arrow::Result<std::shared_ptr<arrow::DataType>>
ArrowArrayBuilderBase::arrow_type(Form_pg_attribute attribute)
{
    const auto typeOid = attribute->atttypid;
    switch (typeOid)
    {
        // Integers
        case INT2OID:
            return arrow::int16();
        case INT4OID:
            return arrow::int32();
        case INT8OID:
            return arrow::int64();

        // Floating point numbers
        case FLOAT4OID:
            return arrow::float32();
        case FLOAT8OID:
            return arrow::float64();

        // Variable-length data
        case VARCHAROID:
        case TEXTOID:
            return arrow::utf8();
        case BYTEAOID:
            return arrow::binary();

        // Date/time
        case TIMESTAMPOID:
            return arrow::timestamp(arrow::TimeUnit::MICRO);

        default:
            return arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
                                                 typeOid);
    }
}
class PreparedStatement {
public:
explicit PreparedStatement(std::string query)
: query_(std::move(query)), parameters_()
{
}
~PreparedStatement() {}
using WriteFunc = std::add_pointer<arrow::Status(void*)>::type;
arrow::Status select(WriteFunc write, void* writeData)
{
if (parameters_.empty())
{
auto result = SPI_execute(query_.c_str(), false, 0);
if (result <= 0)
{
return arrow::Status::Invalid("failed to run a query: ",
SPI_result_code_string(result),
": ",
query_);
}
return write(writeData);
}
for (const auto& recordBatch : parameters_)
{
SPIExecuteOptions options = {};
std::vector<Oid> pgTypes;
ARROW_RETURN_NOT_OK(prepare(options, pgTypes, recordBatch->schema()));
auto plan = SPI_prepare(query_.c_str(), pgTypes.size(), pgTypes.data());
ScopedPlan scopedPlan(plan);
ARROW_RETURN_NOT_OK(
execute(plan, recordBatch, options, [&]() { return write(writeData); }));
}
return arrow::Status::OK();
}
arrow::Status set_parameters(std::shared_ptr<SharedRingBufferInputStream>& input)
{
parameters_.clear();
ARROW_ASSIGN_OR_RAISE(auto reader,
arrow::ipc::RecordBatchStreamReader::Open(input));
while (true)
{
std::shared_ptr<arrow::RecordBatch> recordBatch;
ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
if (!recordBatch)
{
break;
}
parameters_.push_back(std::move(recordBatch));
}
return arrow::Status::OK();
}
arrow::Result<int64_t> update(std::shared_ptr<SharedRingBufferInputStream>& input)
{
ARROW_ASSIGN_OR_RAISE(auto reader,
arrow::ipc::RecordBatchStreamReader::Open(input));
SPIExecuteOptions options = {};
std::vector<Oid> pgTypes;
ARROW_RETURN_NOT_OK(prepare(options, pgTypes, reader->schema()));
auto plan = SPI_prepare(query_.c_str(), pgTypes.size(), pgTypes.data());
ScopedPlan scopedPlan(plan);
int64_t nUpdatedRecords = 0;
while (true)
{
std::shared_ptr<arrow::RecordBatch> recordBatch;
ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
if (!recordBatch)
{
break;
}
ARROW_RETURN_NOT_OK(execute(plan, recordBatch, options, [&nUpdatedRecords]() {
nUpdatedRecords += SPI_processed;
return arrow::Status::OK();
}));
}
return nUpdatedRecords;
}
private:
// Appends to `pgTypes` the PostgreSQL type OID that ArrowPGTypeConverter
// yields for each field of `schema`, in field order.
//
// Stops at the first field the converter rejects and returns its status;
// OIDs appended before that point stay in `pgTypes`.
arrow::Status prepare_pg_types(std::vector<Oid>& pgTypes,
                               const std::shared_ptr<arrow::Schema>& schema)
{
    ArrowPGTypeConverter typeConverter;
    const int nFields = schema->num_fields();
    for (int iField = 0; iField < nFields; ++iField)
    {
        const auto& arrowType = schema->field(iField)->type();
        ARROW_RETURN_NOT_OK(arrowType->Accept(&typeConverter));
        pgTypes.push_back(typeConverter.oid());
    }
    return arrow::Status::OK();
}
// Fills `options` and `pgTypes` for executing the query with parameters
// described by `schema`.
//
// A parameter list is allocated only when the schema has fields. Each
// parameter is flagged PARAM_FLAG_CONST and typed with the OID computed
// by prepare_pg_types(). The execution is configured as not read-only and
// with tcount 0.
arrow::Status prepare(SPIExecuteOptions& options,
                      std::vector<Oid>& pgTypes,
                      const std::shared_ptr<arrow::Schema>& schema)
{
    const int nFields = schema->num_fields();
    if (nFields > 0)
    {
        options.params = makeParamList(nFields);
    }
    options.read_only = false;
    options.tcount = 0;
    ARROW_RETURN_NOT_OK(prepare_pg_types(pgTypes, schema));
    size_t iParam = 0;
    for (const auto pgType : pgTypes)
    {
        auto& param = options.params->params[iParam];
        param.pflags = PARAM_FLAG_CONST;
        param.ptype = pgType;
        ++iParam;
    }
    return arrow::Status::OK();
}
// Executes `plan` once for every row of `recordBatch`.
//
// For each row the parameters in `options` are refreshed from the row,
// the plan is run and `onSuccess()` is called. The first failure
// (parameter conversion, SPI or `onSuccess`) ends the loop and is
// returned.
template <typename OnSuccessFunc>
arrow::Status execute(SPIPlanPtr plan,
                      const std::shared_ptr<arrow::RecordBatch>& recordBatch,
                      SPIExecuteOptions& options,
                      OnSuccessFunc onSuccess)
{
    const auto& columns = recordBatch->columns();
    const int64_t nRows = recordBatch->num_rows();
    for (int64_t iRow = 0; iRow < nRows; ++iRow)
    {
        ARROW_RETURN_NOT_OK(assign_parameters(recordBatch, iRow, columns, options));
        const auto resultCode = SPI_execute_plan_extended(plan, &options);
        if (resultCode > 0)
        {
            ARROW_RETURN_NOT_OK(onSuccess());
            continue;
        }
        return arrow::Status::Invalid("failed to run a prepared statement: ",
                                      SPI_result_code_string(resultCode),
                                      ": ",
                                      query_);
    }
    return arrow::Status::OK();
}
// Copies row `i_row` of every column into the matching parameter slot of
// `options.params` (column N goes to parameter N).
//
// NULL cells only set the parameter's isnull flag. Other cells are
// converted by ArrowPGValueConverter, whose failure is returned as is.
arrow::Status assign_parameters(
    const std::shared_ptr<arrow::RecordBatch>& recordBatch,
    int64_t i_row,
    const std::vector<std::shared_ptr<arrow::Array>>& columns,
    SPIExecuteOptions& options)
{
    for (size_t iColumn = 0; iColumn < columns.size(); ++iColumn)
    {
        const auto& column = columns[iColumn];
        auto& param = options.params->params[iColumn];
        param.isnull = column->IsNull(i_row);
        if (param.isnull)
        {
            continue;
        }
        ArrowPGValueConverter valueConverter(i_row, param.value);
        ARROW_RETURN_NOT_OK(column->Accept(&valueConverter));
    }
    return arrow::Status::OK();
}
// SQL text that select() and update() execute or prepare.
std::string query_;
// Record batches bound by set_parameters(); select() runs the query once
// per row of each batch. Empty means "no parameters bound".
std::vector<std::shared_ptr<arrow::RecordBatch>> parameters_;
};
class Executor : public WorkerProcessor {
public:
// Creates the executor for session `sessionID`. The local session data is
// allocated here and records the session ID and the server PID taken from
// the shared data as its peer.
explicit Executor(uint64_t sessionID)
    : WorkerProcessor("executor"),
      sessionID_{sessionID},
      localSession_{std::make_shared<LocalSessionData>()},
      needFinish_{false},
      closed_{false},
      nextPreparedStatementID_{1},
      preparedStatements_{}
{
    auto& local = *localSession_;
    local.id = sessionID_;
    local.peerPID = sharedData_->serverPID;
}
// Makes sure close() has run exactly once by the time the executor is
// destroyed.
~Executor()
{
    if (closed_)
    {
        return;
    }
    close();
}
// Connects this background worker to the session's database, checks the
// password and publishes the result to the server.
//
// The session entry returned by find_session() is released with
// dshash_release_lock() on both the success and the error path, and the
// server is signaled in both cases. A PostgreSQL error raised inside the
// PG_TRY block is recorded as "failed to connect" and then re-thrown.
void open()
{
const char* tag = "open";
// Use this when you want to attach to an executor process.
// sleep(5);
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": opening").c_str());
auto session = find_session();
PG_TRY();
{
// The connection parameters are stored as DSA pointers in the shared
// session; resolve them to addresses in this process.
auto databaseName =
static_cast<const char*>(dsa_get_address(area_, session->databaseName));
auto userName =
static_cast<const char*>(dsa_get_address(area_, session->userName));
auto password =
static_cast<const char*>(dsa_get_address(area_, session->password));
auto clientAddress =
static_cast<const char*>(dsa_get_address(area_, session->clientAddress));
BackgroundWorkerInitializeConnection(databaseName, userName, 0);
CurrentResourceOwner =
ResourceOwnerCreate(nullptr, "arrow-flight-sql: Executor");
// The ring buffer is allocated only when authentication succeeds. On
// failure check_password() has already stored an error message.
if (check_password(session, databaseName, userName, password, clientAddress))
{
// TODO: Customizable.
SharedRingBuffer::allocate_data(
&(session->bufferData), area_, 1L * 1024L * 1024L);
}
// The session is marked initialized whether or not authentication
// succeeded.
session->initialized = true;
localSession_->valid = true;
localSession_->bufferData = &(session->bufferData);
localSession_->bufferAddress =
dsa_get_address(area_, session->bufferData.pointer);
P("%s: %s: open: %" PRIu64, Tag, tag_, session->bufferData.pointer);
dshash_release_lock(sessions_, session);
signal_server(tag);
}
PG_CATCH();
{
// Record the failure, mark the session initialized and release the
// entry before propagating the PostgreSQL error.
set_error_message(session, "failed to connect", tag);
session->initialized = true;
dshash_release_lock(sessions_, session);
signal_server(tag);
PG_RE_THROW();
}
PG_END_TRY();
}
// Shuts the executor down: drops all prepared statements, finishes a
// still-open SPI connection and releases and deletes the current
// resource owner (if any). Marks the executor as closed first so that
// the destructor doesn't close it again.
void close()
{
    closed_ = true;
    const std::string closingActivity = std::string(Tag) + ": closing";
    pgstat_report_activity(STATE_RUNNING, closingActivity.c_str());
    preparedStatements_.clear();
    if (needFinish_)
    {
        SPI_finish();
        needFinish_ = false;
    }
    if (auto owner = CurrentResourceOwner)
    {
        // Detach the owner before releasing it phase by phase.
        CurrentResourceOwner = nullptr;
        ResourceOwnerRelease(owner, RESOURCE_RELEASE_BEFORE_LOCKS, false, true);
        ResourceOwnerRelease(owner, RESOURCE_RELEASE_LOCKS, false, true);
        ResourceOwnerRelease(owner, RESOURCE_RELEASE_AFTER_LOCKS, false, true);
        ResourceOwnerDelete(owner);
    }
    pgstat_report_activity(STATE_IDLE, nullptr);
}
// Reads up to `n` bytes from `buffer` into `output` while holding the
// processor lock. `session` is not used here.
arrow::Result<size_t> read(LocalSessionData* session,
                           SharedRingBuffer* buffer,
                           size_t n,
                           void* output) override
{
    ProcessorLockGuard guard(this);
    auto nReadBytes = buffer->read(n, output);
    return nReadBytes;
}
// Writes up to `n` bytes from `data` into `buffer` while holding the
// processor lock. `session` is not used here.
arrow::Result<size_t> write(LocalSessionData* session,
                            SharedRingBuffer* buffer,
                            const void* data,
                            size_t n) override
{
    ProcessorLockGuard guard(this);
    auto nWrittenBytes = buffer->write(data, n);
    return nWrittenBytes;
}
// Handles one wake-up: takes the pending action out of the shared session
// and runs the matching handler.
//
// A PostgreSQL error raised by a handler is copied into the session's
// error message (unless one is already set) and then re-thrown.
void signaled()
{
Action action;
{
// Fetch and clear the requested action while holding both the session
// entry and the processor lock.
auto session = find_session();
SharedSessionReleaser sessionReleaser(sessions_, session);
ProcessorLockGuard lock(this);
action = session->action;
session->action = Action::None;
}
P("%s: %s: signaled: before: %s", Tag, tag_, action_name(action));
PG_TRY();
{
switch (action)
{
case Action::Select:
select();
break;
case Action::Update:
update();
break;
case Action::Prepare:
prepare();
break;
case Action::ClosePreparedStatement:
close_prepared_statement();
break;
case Action::SetParameters:
set_parameters();
break;
case Action::SelectPreparedStatement:
select_prepared_statement();
break;
case Action::UpdatePreparedStatement:
update_prepared_statement();
break;
default:
// TODO: Report an error
break;
}
}
PG_CATCH();
{
auto session = find_session();
SharedSessionReleaser sessionReleaser(sessions_, session);
// Keep an error message that a handler has already stored.
if (!DsaPointerIsValid(session->errorMessage))
{
auto error = CopyErrorData();
set_error_message(session,
std::string("failed to run: ") + action_name(action) +
": " + error->message,
"unexpected error");
FreeErrorData(error);
}
pgstat_report_activity(STATE_IDLE, nullptr);
PG_RE_THROW();
}
PG_END_TRY();
pgstat_report_activity(STATE_IDLE, nullptr);
P("%s: %s: signaled: after: %s", Tag, tag_, action_name(action));
}
protected:
// The executor's peer is always the server process recorded in the shared
// data; `session` is not consulted.
pid_t peer_pid(LocalSessionData* session) override { return sharedData_->serverPID; }
// Name of the peer, used in log messages (see wait_internal()).
const char* peer_name() override { return "server"; }
// Blocks on the process latch until the peer has signaled (SIGUSR1) and
// get_waiting_buffer_size(buffer, mode) is greater than 0, or until
// SIGTERM has been received.
//
// `session` is not used here. Always returns OK; callers that need to
// tell an interruption apart have to check for pending interrupts
// themselves.
arrow::Status wait_internal(LocalSessionData* session,
                            SharedRingBuffer* buffer,
                            WaitMode mode,
                            const char* tag) override
{
    while (true)
    {
        int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH;
        WaitLatch(MyLatch, events, -1, PG_WAIT_EXTENSION);
        if (GotSIGTERM)
        {
            // Leave without resetting the latch.
            break;
        }
        ResetLatch(MyLatch);
        if (GotSIGUSR1)
        {
            GotSIGUSR1 = false;
            P("%s: %s: %s: %s: wait: %" PRIsize,
              Tag,
              tag_,
              tag,
              peer_name(),
              get_waiting_buffer_size(buffer, mode));
            if (get_waiting_buffer_size(buffer, mode) > 0)
            {
                break;
            }
        }
        // TODO: Convert PG error to arrow::Status.
        CHECK_FOR_INTERRUPTS();
    }
    return arrow::Status::OK();
}
private:
// Looks this executor's session up in the shared session table and
// returns the entry. The caller is responsible for releasing it
// (dshash_release_lock() or SharedSessionReleaser).
//
// Raises a PostgreSQL ERROR when the session doesn't exist.
SharedSessionData* find_session()
{
    void* entry = dshash_find(sessions_, &sessionID_, false);
    if (entry)
    {
        return static_cast<SharedSessionData*>(entry);
    }
    ereport(ERROR,
            errcode(ERRCODE_INTERNAL_ERROR),
            errmsg("%s: %s: unknown session: %" PRIu64, Tag, tag_, sessionID_));
    return nullptr;
}
// Sends SIGUSR1 to the server process recorded in the shared data.
// Does nothing when no server PID is set. `tag` is only used for logging.
void signal_server(const char* tag)
{
    if (sharedData_->serverPID != InvalidPid)
    {
        P("%s: %s: %s: kill server: %d", Tag, tag_, tag, sharedData_->serverPID);
        kill(sharedData_->serverPID, SIGUSR1);
    }
}
// Stores `message` as the session's error message and signals the server.
// The first error wins: when a message is already set, nothing is
// changed and the server isn't signaled.
void set_error_message(SharedSessionData* session,
                       const std::string& message,
                       const char* tag)
{
    if (!DsaPointerIsValid(session->errorMessage))
    {
        {
            // Hold the processor lock only while the shared string is
            // written.
            ProcessorLockGuard guard(this);
            set_shared_string(session->errorMessage, message);
        }
        signal_server(tag);
    }
}
// Authenticates the client against the auth method that pg_hba selects
// for the given database, user and client address.
//
// Returns true for "trust" and for a verified plain "password". Every
// other outcome (bad client address, no HBA entry, MD5/SCRAM and other
// methods, missing or wrong password) stores an error message in the
// session and returns false.
bool check_password(SharedSessionData* session,
                    const char* databaseName,
                    const char* userName,
                    const char* password,
                    const char* clientAddress)
{
    const char* tag = "check password";
    MemoryContext memoryContext =
        AllocSetContextCreate(CurrentMemoryContext,
                              "arrow-flight-sql: Executor::check_password()",
                              ALLOCSET_DEFAULT_SIZES);
    ScopedMemoryContext scopedMemoryContext(memoryContext);
    // Reports `message` for this session and yields the failure result.
    auto fail = [&](const std::string& message) {
        set_error_message(session, message, tag);
        return false;
    };
    Port port = {};
    port.database_name = pstrdup(databaseName);
    port.user_name = pstrdup(userName);
    if (!fill_client_address(session, &port, clientAddress))
    {
        // fill_client_address() has already reported the error.
        return false;
    }
    load_hba();
    hba_getauthmethod(&port);
    if (!port.hba)
    {
        return fail("failed to get auth method");
    }
    switch (port.hba->auth_method)
    {
        case uaTrust:
            return true;
        case uaPassword:
        {
            const char* logDetail = nullptr;
            auto shadowPassword = get_role_password(port.user_name, &logDetail);
            if (!shadowPassword)
            {
                return fail(std::string("failed to get password: ") + logDetail);
            }
            auto verified = plain_crypt_verify(
                port.user_name, shadowPassword, password, &logDetail);
            if (verified != STATUS_OK)
            {
                return fail(std::string("failed to verify password: ") + logDetail);
            }
            return true;
        }
        case uaMD5:
            // TODO
            return fail("MD5 auth method isn't supported yet");
        case uaSCRAM:
            // TODO
            return fail("SCRAM auth method isn't supported yet");
        default:
            return fail(std::string("unsupported auth method: ") +
                        hba_authname(port.hba->auth_method));
    }
}
// Parses `clientAddress` ("<family>:<host>:<port>") and stores the result
// in `port->raddr`.
//
// Examples:
//   "ipv4:127.0.0.1:40468"  -> family "ipv4", host "127.0.0.1", port "40468"
//   "ipv6:[::1]:40468"      -> family "ipv6", host "::1",       port "40468"
//   "ipv6:%5B::1%5D:40468"  -> family "ipv6", host "::1",       port "40468"
//
// The family ends at the first ':' and the port starts after the last
// ':', so that the colons inside an IPv6 host don't break the split.
// Brackets around an IPv6 host (plain or percent-encoded) are removed
// before the host is handed to inet_pton().
//
// Returns true on success. On failure an error message is stored in the
// session and false is returned.
bool fill_client_address(SharedSessionData* session,
                         Port* port,
                         const char* clientAddress)
{
    const char* tag = "fill client address";
    const std::string address(clientAddress);
    std::string clientFamily;
    std::string clientHost;
    std::string clientPort;
    const auto familyEnd = address.find(':');
    clientFamily = address.substr(0, familyEnd);
    if (familyEnd != std::string::npos)
    {
        const std::string hostAndPort = address.substr(familyEnd + 1);
        const auto portSeparator = hostAndPort.rfind(':');
        clientHost = hostAndPort.substr(0, portSeparator);
        if (portSeparator != std::string::npos)
        {
            clientPort = hostAndPort.substr(portSeparator + 1);
        }
    }
    if (!(clientFamily == "ipv4" || clientFamily == "ipv6"))
    {
        set_error_message(
            session,
            std::string("client family must be ipv4 or ipv6: ") + clientFamily,
            tag);
        return false;
    }
    auto clientPortStart = clientPort.c_str();
    char* clientPortEnd = nullptr;
    auto clientPortNumber = std::strtoul(clientPortStart, &clientPortEnd, 10);
    if (clientPortEnd[0] != '\0')
    {
        set_error_message(
            session, std::string("client port is invalid: ") + clientPort, tag);
        return false;
    }
    if (clientPortNumber == 0)
    {
        // An empty port string also ends up here.
        set_error_message(session, std::string("client port must not 0"), tag);
        return false;
    }
    if (clientPortNumber > 65535)
    {
        set_error_message(session,
                          std::string("client port is too large: ") +
                              std::to_string(clientPortNumber),
                          tag);
        return false;
    }
    if (clientFamily == "ipv4")
    {
        auto raddr = reinterpret_cast<sockaddr_in*>(&(port->raddr.addr));
        port->raddr.salen = sizeof(sockaddr_in);
        raddr->sin_family = AF_INET;
        raddr->sin_port = htons(clientPortNumber);
        // inet_pton() returns 1 only on success; 0 and -1 are both errors.
        if (inet_pton(AF_INET, clientHost.c_str(), &(raddr->sin_addr)) != 1)
        {
            set_error_message(
                session,
                std::string("client IPv4 address is invalid: ") + clientHost,
                tag);
            return false;
        }
    }
    else if (clientFamily == "ipv6")
    {
        // Strip "[...]" or its percent-encoded form from the host.
        static const char* const brackets[][2] = {
            {"[", "]"}, {"%5B", "%5D"}, {"%5b", "%5d"}};
        for (const auto& bracket : brackets)
        {
            const std::string open(bracket[0]);
            const std::string close(bracket[1]);
            if (clientHost.size() >= open.size() + close.size() &&
                clientHost.compare(0, open.size(), open) == 0 &&
                clientHost.compare(
                    clientHost.size() - close.size(), close.size(), close) == 0)
            {
                clientHost = clientHost.substr(
                    open.size(), clientHost.size() - open.size() - close.size());
                break;
            }
        }
        auto raddr = reinterpret_cast<sockaddr_in6*>(&(port->raddr.addr));
        port->raddr.salen = sizeof(sockaddr_in6);
        raddr->sin6_family = AF_INET6;
        raddr->sin6_port = htons(clientPortNumber);
        raddr->sin6_flowinfo = 0;
        if (inet_pton(AF_INET6, clientHost.c_str(), &(raddr->sin6_addr)) != 1)
        {
            set_error_message(
                session,
                std::string("client IPv6 address is invalid: ") + clientHost,
                tag);
            return false;
        }
        raddr->sin6_scope_id = 0;
    }
    return true;
}
// Runs the query stored in the session's selectQuery read-only through
// SPI and streams the result to the server with write_record_batches().
//
// The shared query string is copied and freed before it is executed. Any
// failure is reported through the session's error message. The server is
// signaled once the query has been handled.
void select()
{
    const char* tag = "select";
    pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": selecting").c_str());
    auto session = find_session();
    SharedSessionReleaser sessionReleaser(sessions_, session);
    // Common prefix of every error message reported below.
    const std::string prefix = std::string(Tag) + ": " + tag_ + ": " + tag;
    if (!DsaPointerIsValid(session->selectQuery))
    {
        set_error_message(session, prefix + ": query is missing", tag);
        return;
    }
    std::string query;
    {
        // Take ownership of the query text and clear the shared slot.
        ProcessorLockGuard guard(this);
        query =
            static_cast<const char*>(dsa_get_address(area_, session->selectQuery));
        dsa_free(area_, session->selectQuery);
        session->selectQuery = InvalidDsaPointer;
    }
    P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
    {
        ScopedTransaction scopedTransaction;
        ScopedSnapshot scopedSnapshot;
        SetCurrentStatementStartTimestamp();
        SPI_connect();
        needFinish_ = true;
        const auto resultCode = SPI_execute(query.c_str(), true, 0);
        if (resultCode <= 0)
        {
            set_error_message(session,
                              prefix + ": failed to run a query: <" + query +
                                  ">: " + SPI_result_code_string(resultCode),
                              tag);
        }
        else
        {
            pgstat_report_activity(
                STATE_RUNNING, (std::string(Tag) + ": " + tag + ": writing").c_str());
            const auto writeStatus = write_record_batches(tag);
            if (!writeStatus.ok())
            {
                set_error_message(
                    session,
                    prefix + ": failed to write: " + writeStatus.ToString(),
                    tag);
            }
        }
        SPI_finish();
        needFinish_ = false;
        signal_server(tag);
    }
}
// Writes the current SPI result (SPI_tuptable / SPI_processed) to the
// session's ring buffer as two Arrow IPC streams:
//
//   1. a stream with one empty record batch, which carries the schema;
//   2. a stream with the rows, at most MaxNRowsPerRecordBatch rows per
//      record batch.
//
// Returns Invalid when there is no result set to write, and otherwise the
// first error from type mapping, building or writing. `tag` is only used
// for logging.
arrow::Status write_record_batches(const char* tag)
{
    // SPI leaves SPI_tuptable NULL when the last command produced no
    // result set. Report that instead of dereferencing a null pointer.
    if (!SPI_tuptable)
    {
        return arrow::Status::Invalid("no result set is available to write");
    }
    const auto tupleDesc = SPI_tuptable->tupdesc;
    const int nAttributes = tupleDesc->natts;
    SharedRingBufferOutputStream output(this, localSession_);
    std::vector<std::shared_ptr<arrow::Field>> fields;
    for (int i = 0; i < nAttributes; ++i)
    {
        auto attribute = TupleDescAttr(tupleDesc, i);
        ARROW_ASSIGN_OR_RAISE(auto type,
                              ArrowArrayBuilderBase::arrow_type(attribute));
        fields.push_back(arrow::field(
            NameStr(attribute->attname), std::move(type), !attribute->attnotnull));
    }
    auto schema = arrow::schema(fields);
    ARROW_ASSIGN_OR_RAISE(
        auto builder,
        arrow::RecordBatchBuilder::Make(schema, arrow::default_memory_pool()));
    std::vector<std::unique_ptr<ArrowArrayBuilderBase>> builders;
    for (int i = 0; i < nAttributes; ++i)
    {
        auto attribute = TupleDescAttr(tupleDesc, i);
        ARROW_ASSIGN_OR_RAISE(
            auto array_builder,
            ArrowArrayBuilderBase::make(attribute, i, builder->GetField(i)));
        builders.push_back(std::move(array_builder));
    }
    auto options = arrow::ipc::IpcWriteOptions::Defaults();
    options.emit_dictionary_deltas = true;
    // Write schema only stream format data to return only schema.
    ARROW_ASSIGN_OR_RAISE(auto writer,
                          arrow::ipc::MakeStreamWriter(&output, schema, options));
    // Build an empty record batch to write schema.
    ARROW_ASSIGN_OR_RAISE(auto recordBatch, builder->Flush());
    P("%s: %s: %s: write: schema: WriteRecordBatch", Tag, tag_, tag);
    ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
    P("%s: %s: %s: write: schema: Close", Tag, tag_, tag);
    ARROW_RETURN_NOT_OK(writer->Close());
    // Write another stream format data with record batches.
    ARROW_ASSIGN_OR_RAISE(writer,
                          arrow::ipc::MakeStreamWriter(&output, schema, options));
    // Each iteration handles the rows [iTuple, iTupleEnd). The end is
    // clamped to SPI_processed, so the last (possibly shorter) batch is
    // written here too and no separate tail handling is needed.
    for (uint64_t iTuple = 0; iTuple < SPI_processed;
         iTuple += MaxNRowsPerRecordBatch)
    {
        uint64_t iTupleEnd = iTuple + MaxNRowsPerRecordBatch;
        if (iTupleEnd >= SPI_processed)
        {
            iTupleEnd = SPI_processed;
        }
        P("%s: %s: %s: write: data: record batch: %" PRIu64 "/%" PRIu64,
          Tag,
          tag_,
          tag,
          iTuple,
          iTupleEnd);
        for (int iAttribute = 0; iAttribute < nAttributes; ++iAttribute)
        {
            P("%s: %s: %s: write: data: record batch: %" PRIu64 "/%" PRIu64 ": %d/%d",
              Tag,
              tag_,
              tag,
              iTuple,
              iTupleEnd,
              iAttribute,
              nAttributes);
            ARROW_RETURN_NOT_OK(builders[iAttribute]->build(iTuple, iTupleEnd));
        }
        ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush());
        P("%s: %s: %s: write: data: WriteRecordBatch: %" PRIu64 "/%" PRIu64,
          Tag,
          tag_,
          tag,
          iTuple,
          iTupleEnd);
        ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
    }
    P("%s: %s: %s, write: data: Close", Tag, tag_, tag);
    ARROW_RETURN_NOT_OK(writer->Close());
    return output.Close();
}
// Runs the statement stored in the session's updateQuery through SPI and
// publishes the number of processed rows in session->nUpdatedRecords.
//
// The shared query string is copied and freed before it is executed. A
// failure is reported through the session's error message. The server is
// signaled once the statement has been handled.
void update()
{
    const char* tag = "update";
    pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": updating").c_str());
    auto session = find_session();
    SharedSessionReleaser sessionReleaser(sessions_, session);
    // Common prefix of every error message reported below.
    const std::string prefix = std::string(Tag) + ": " + tag_ + ": " + tag;
    if (!DsaPointerIsValid(session->updateQuery))
    {
        set_error_message(session, prefix + ": query is missing", tag);
        return;
    }
    std::string query;
    {
        // Take ownership of the query text and clear the shared slot.
        ProcessorLockGuard guard(this);
        query =
            static_cast<const char*>(dsa_get_address(area_, session->updateQuery));
        dsa_free(area_, session->updateQuery);
        session->updateQuery = InvalidDsaPointer;
    }
    P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
    {
        ScopedTransaction scopedTransaction;
        ScopedSnapshot scopedSnapshot;
        SetCurrentStatementStartTimestamp();
        SPI_connect();
        needFinish_ = true;
        const auto resultCode = SPI_execute(query.c_str(), false, 0);
        if (resultCode <= 0)
        {
            set_error_message(session,
                              prefix + ": failed to run a query: <" + query +
                                  ">: " + SPI_result_code_string(resultCode),
                              tag);
        }
        else
        {
            session->nUpdatedRecords = SPI_processed;
        }
        SPI_finish();
        needFinish_ = false;
        signal_server(tag);
    }
}
// Registers the query stored in the session's prepareQuery as a new
// prepared statement and publishes its handle in
// session->preparedStatementHandle.
//
// Handles are the decimal form of a per-executor counter. The shared
// query string is copied and freed before it is registered.
void prepare()
{
    const char* tag = "prepare";
    pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": preparing").c_str());
    auto session = find_session();
    SharedSessionReleaser sessionReleaser(sessions_, session);
    if (!DsaPointerIsValid(session->prepareQuery))
    {
        const std::string prefix = std::string(Tag) + ": " + tag_ + ": " + tag;
        set_error_message(session, prefix + ": query is missing", tag);
        return;
    }
    std::string query;
    {
        // Take ownership of the query text and clear the shared slot.
        ProcessorLockGuard guard(this);
        query =
            static_cast<const char*>(dsa_get_address(area_, session->prepareQuery));
        dsa_free(area_, session->prepareQuery);
        session->prepareQuery = InvalidDsaPointer;
    }
    P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
    const std::string handle = std::to_string(nextPreparedStatementID_++);
    preparedStatements_.emplace(handle, PreparedStatement(std::move(query)));
    {
        ProcessorLockGuard guard(this);
        set_shared_string(session->preparedStatementHandle, handle);
    }
    signal_server(tag);
}
// Moves the prepared statement handle out of the shared session into
// `handle`: the shared string is copied, freed and the slot is cleared.
//
// Returns false (after reporting "handle is missing") when the session
// has no handle; `handle` is left untouched in that case.
bool extract_handle(SharedSessionData* session, std::string& handle, const char* tag)
{
    if (DsaPointerIsValid(session->preparedStatementHandle))
    {
        ProcessorLockGuard guard(this);
        handle = static_cast<const char*>(
            dsa_get_address(area_, session->preparedStatementHandle));
        dsa_free(area_, session->preparedStatementHandle);
        session->preparedStatementHandle = InvalidDsaPointer;
        return true;
    }
    set_error_message(
        session,
        std::string(Tag) + ": " + tag_ + ": " + tag + ": handle is missing",
        tag);
    return false;
}
// Removes the prepared statement named by the session's handle.
//
// A missing handle is reported by extract_handle() and ends the call
// without signaling here. An unknown handle is reported as "nonexistent
// handle". In every other case the server is signaled at the end.
void close_prepared_statement()
{
    const char* tag = "close prepared statement";
    pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": " + tag).c_str());
    auto session = find_session();
    SharedSessionReleaser sessionReleaser(sessions_, session);
    std::string handle;
    if (!extract_handle(session, handle, tag))
    {
        return;
    }
    P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
    const auto nErased = preparedStatements_.erase(handle);
    if (nErased == 0)
    {
        set_error_message(session,
                          std::string(Tag) + ": " + tag_ + ": " + tag +
                              ": nonexistent handle: <" + handle + ">",
                          tag);
    }
    signal_server(tag);
}
// Extracts the handle from the session into `handle` and returns the
// matching prepared statement.
//
// Returns nullptr when the handle is missing or unknown; the error has
// been reported through the session in both cases.
PreparedStatement* find_prepared_statement(SharedSessionData* session,
                                           std::string& handle,
                                           const char* tag)
{
    if (!extract_handle(session, handle, tag))
    {
        return nullptr;
    }
    // NOTE(review): set_error_message() below takes its own
    // ProcessorLockGuard while this one is still held - confirm that
    // nested acquisition is supported.
    ProcessorLockGuard guard(this);
    const auto found = preparedStatements_.find(handle);
    if (found != preparedStatements_.end())
    {
        return &(found->second);
    }
    set_error_message(session,
                      std::string(Tag) + ": " + tag_ + ": " + tag +
                          ": nonexistent handle: <" + handle + ">",
                      tag);
    return nullptr;
}
void set_parameters()
{
const char* tag = "set parameters";
pgstat_report_activity(STATE_RUNNING,
(std::string(Tag) + ": setting parameters").c_str());
auto session = find_session();
SharedSessionReleaser sessionReleaser(sessions_, session);
std::string handle;
auto preparedStatement = find_prepared_statement(session, handle, tag);
P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
if (!preparedStatement)
{
return;
}
auto input = std::make_shared<SharedRingBufferInputStream>(this, localSession_);
auto status = preparedStatement->set_parameters(input);
if (!status.ok())
{
set_error_message(session,
std::string(Tag) + ": " + tag_ + ": " + tag +
": failed to set parameters: <" + handle +
">: " + status.ToString(),
tag);
}
session->setParametersFinished = true;
signal_server(tag);
}
void select_prepared_statement()
{
const char* tag = "select prepared statement";
pgstat_report_activity(
STATE_RUNNING, (std::string(Tag) + ": selecting prepared statement").c_str());
auto session = find_session();
SharedSessionReleaser sessionReleaser(sessions_, session);
std::string handle;
auto preparedStatement = find_prepared_statement(session, handle, tag);
P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
if (!preparedStatement)
{
return;
}
ScopedTransaction scopedTransaction;
ScopedSnapshot scopedSnapshot;
SetCurrentStatementStartTimestamp();
SPI_connect();
needFinish_ = true;
struct Data {
Executor* executor;
const char* tag;
} data = {this, tag};
auto status = preparedStatement->select(
[](void* data) {
auto d = static_cast<Data*>(data);
return d->executor->write_record_batches(d->tag);
},
&data);
if (!status.ok())
{
set_error_message(session,
std::string(Tag) + ": " + tag_ + ": " + tag +
": failed to select a prepared statement: <" + handle +
">: " + status.ToString(),
tag);
}
SPI_finish();
needFinish_ = false;
signal_server(tag);
}
// Runs the prepared statement named by the session's handle with the
// parameters read from the session's ring buffer and publishes the
// number of updated records in session->nUpdatedRecords.
//
// A failure is reported through the session's error message. The server
// is signaled at the end unless the statement couldn't be found.
void update_prepared_statement()
{
    const char* tag = "update prepared statement";
    pgstat_report_activity(
        STATE_RUNNING, (std::string(Tag) + ": updating prepared statement").c_str());
    auto session = find_session();
    SharedSessionReleaser sessionReleaser(sessions_, session);
    std::string handle;
    auto statement = find_prepared_statement(session, handle, tag);
    P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
    if (!statement)
    {
        return;
    }
    ScopedTransaction scopedTransaction;
    ScopedSnapshot scopedSnapshot;
    SetCurrentStatementStartTimestamp();
    SPI_connect();
    needFinish_ = true;
    auto input = std::make_shared<SharedRingBufferInputStream>(this, localSession_);
    auto updated = statement->update(input);
    if (!updated.ok())
    {
        const std::string prefix = std::string(Tag) + ": " + tag_ + ": " + tag;
        set_error_message(session,
                          prefix + ": failed to update a prepared statement: <" +
                              handle + ">: " + updated.status().ToString(),
                          tag);
    }
    else
    {
        session->nUpdatedRecords = *updated;
    }
    SPI_finish();
    needFinish_ = false;
    signal_server(tag);
}
// Key of this executor's entry in the shared session table.
uint64_t sessionID_;
// Process-local view of the session (ID, peer PID, ring buffer location).
std::shared_ptr<LocalSessionData> localSession_;
// True while an SPI connection opened by a handler hasn't been finished
// yet; close() finishes it in that case.
bool needFinish_;
// Set by close() so that the destructor doesn't close again.
bool closed_;
// Source of prepared statement handles; incremented by prepare().
uint64_t nextPreparedStatementID_;
// Prepared statements of this session, keyed by handle.
std::map<std::string, PreparedStatement> preparedStatements_;
};
// Writes all `nBytes` bytes of `data` to the session's shared ring buffer.
//
// The processor may accept fewer bytes than requested, so the write is
// repeated until nothing is left. When nothing could be written, the
// stream waits on the processor and retries. Returns IOError when the
// stream is closed or when an interrupt is pending, and forwards errors
// from the processor.
arrow::Status
SharedRingBufferOutputStream::Write(const void* data, int64_t nBytes)
{
if (ARROW_PREDICT_FALSE(!is_open_))
{
return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() +
": SharedRingBufferOutputStream is closed");
}
if (ARROW_PREDICT_TRUE(nBytes > 0))
{
SharedRingBuffer buffer(
localSession_->bufferData, localSession_->bufferAddress, processor_->tag());
// Number of bytes that still have to be written.
size_t rest = static_cast<size_t>(nBytes);
while (true)
{
ARROW_ASSIGN_OR_RAISE(
auto writtenBytes,
processor_->write(localSession_.get(), &buffer, data, rest));
if (INTERRUPTS_PENDING_CONDITION())
{
return arrow::Status::IOError(std::string(Tag) + ": " +
processor_->tag() + ": interrupted");
}
if (writtenBytes == 0)
{
// Nothing was accepted: wait, then try the same data again.
// NOTE(review): presumably WaitMode::Read waits for the peer to read
// and free space - confirm against Processor::wait().
ARROW_RETURN_NOT_OK(processor_->wait(
localSession_.get(), &buffer, Processor::WaitMode::Read));
if (INTERRUPTS_PENDING_CONDITION())
{
return arrow::Status::IOError(std::string(Tag) + ": " +
processor_->tag() + ": interrupted");
}
continue;
}
// Advance past the bytes that were accepted.
position_ += writtenBytes;
rest -= writtenBytes;
data = static_cast<const uint8_t*>(data) + writtenBytes;
if (ARROW_PREDICT_TRUE(rest == 0))
{
break;
}
}
}
return arrow::Status::OK();
}
// There is only one Proxy object in a PostgreSQL instance. The Proxy
// object communicates with multiple "executor" processes.
//
// This class must be thread safe. We call PostgreSQL API only in
// signaled(). We can't call PostgreSQL API in other methods because
// they are called from gRPC threads not the main thread. PostgreSQL
// API isn't thread safe and must be called only in the main thread.
//
// If we need to call PostgreSQL API from gRPC threads, we need to
// create a request and process it in signaled().
class Proxy : public WorkerProcessor {
public:
// Creates the proxy with empty session and request containers. The random
// engine used for session IDs is seeded from the random device.
explicit Proxy()
: WorkerProcessor("proxy"),
mutex_(),
conditionVariable_(),
randomSeed_(),
randomEngine_(randomSeed_()),
localSessions_(),
connectRequests_(),
selectRequests_(),
updateRequests_(),
prepareRequests_(),
closePreparedStatementRequests_(),
setParametersRequests_(),
selectPreparedStatementRequests_(),
updatePreparedStatementRequests_(),
readRequests_(),
writeRequests_()
{
}
private:
// Guards localSessions_ and the request lists.
std::mutex mutex_;
// Notified at the end of signaled(); request methods such as connect()
// and select() wait on it.
std::condition_variable conditionVariable_;
// Seed source for randomEngine_.
std::random_device randomSeed_;
// Generates session IDs (see assign_session_id()).
std::mt19937_64 randomEngine_;
// Process-local session data, keyed by session ID.
std::map<uint64_t, std::shared_ptr<LocalSessionData>> localSessions_;
public:
// Opens an Arrow IPC stream reader on the ring buffer of session
// `sessionID`. Returns Invalid when the session is unknown.
arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> open_reader(
    uint64_t sessionID)
{
    // Only the lookup needs the mutex; the reader is opened without it.
    std::unique_lock<std::mutex> lock(mutex_);
    const auto found = localSessions_.find(sessionID);
    if (found == localSessions_.end())
    {
        return arrow::Status::Invalid("unknown session: ", sessionID);
    }
    std::shared_ptr<LocalSessionData> localSession = found->second;
    lock.unlock();
    auto input =
        std::make_shared<SharedRingBufferInputStream>(this, std::move(localSession));
    // Read another stream format data with record batches.
    return arrow::ipc::RecordBatchStreamReader::Open(input);
}
// This is only called from PostgreSQL (main) thread. We can use
// PostgreSQL API only in this method and methods called from this
// method.
//
// Processes every kind of queued request in a fixed order, synchronizes
// the sessions and finally wakes up all threads waiting on the condition
// variable.
void signaled()
{
process_connect_requests();
process_select_requests();
process_update_requests();
process_prepare_requests();
process_close_prepared_statement_requests();
process_set_parameters_requests();
process_select_prepared_statement_requests();
process_update_prepared_statement_requests();
process_read_requests();
process_write_requests();
sync_sessions();
conditionVariable_.notify_all();
}
// Don't use PostgreSQL API.
//
// Tells whether `sessionID` is known to this proxy.
bool is_valid_session(uint64_t sessionID)
{
    const auto localSession = find_session(sessionID);
    return localSession.ok();
}
private:
// Don't use PostgreSQL API.
//
// Returns the local session data registered for `id`, or Invalid when
// there is none. The lookup is done under the mutex.
arrow::Result<std::shared_ptr<LocalSessionData>> find_session(uint64_t id)
{
    std::lock_guard<std::mutex> lock(mutex_);
    const auto found = localSessions_.find(id);
    if (found != localSessions_.end())
    {
        return found->second;
    }
    return arrow::Status::Invalid("Unknown session: ", id);
}
// Don't use PostgreSQL API.
//
// Turns a pending error message of `localSession` into an Invalid status
// and clears the message; returns OK when there is none.
//
// Pass `locked = true` when the caller already holds the mutex; otherwise
// the mutex is taken here before the message is consumed.
arrow::Status check_local_session_error(LocalSessionData* localSession,
                                        bool locked = false)
{
    if (!localSession->errorMessage.has_value())
    {
        return arrow::Status::OK();
    }
    // Consumes the message. Must run with the mutex held.
    auto consume = [localSession]() -> arrow::Status {
        if (!localSession->errorMessage.has_value())
        {
            return arrow::Status::OK();
        }
        auto message = std::move(localSession->errorMessage.value());
        localSession->errorMessage = std::nullopt;
        return arrow::Status::Invalid(message);
    };
    if (locked)
    {
        return consume();
    }
    std::lock_guard<std::mutex> lock(mutex_);
    return consume();
}
// Don't use PostgreSQL API.
//
// Convenience overload for shared session data; the mutex is taken by the
// pointer overload.
arrow::Status check_local_session_error(
    const std::shared_ptr<LocalSessionData>& localSession)
{
    LocalSessionData* rawLocalSession = localSession.get();
    return check_local_session_error(rawLocalSession);
}
private:
// A pending request to open a session, queued by connect() and handled by
// process_connect_requests().
struct ConnectRequest {
    ConnectRequest(uint64_t sessionID,
                   std::string databaseName,
                   std::string userName,
                   std::string password,
                   std::string clientAddress)
        : sessionID(sessionID),
          databaseName(std::move(databaseName)),
          userName(std::move(userName)),
          password(std::move(password)),
          clientAddress(std::move(clientAddress))
    {
    }

    // Session that the request opens.
    uint64_t sessionID;
    // Connection parameters handed over to the shared session.
    std::string databaseName;
    std::string userName;
    std::string password;
    std::string clientAddress;
    // True once the shared session has been set up for this request.
    bool processing = false;
    // True once the request has been completed, with or without an error.
    bool finished = false;
};
// Connect requests that haven't finished yet.
std::list<std::shared_ptr<ConnectRequest>> connectRequests_;
// Can use PostgreSQL API.
//
// Advances every queued connect request by one step:
//
//   * A new request gets a shared session entry and is marked as
//     processing. A session ID that already exists in the shared table
//     finishes the request with a "duplicated session ID" error.
//   * A processing request is finished once its shared session is marked
//     initialized. When no error message is set, the local session is
//     filled in with the executor PID and the ring buffer location.
//
// Finished requests are removed from the list. When at least one new
// session was set up, the process recorded as mainPID is signaled.
void process_connect_requests()
{
std::lock_guard<std::mutex> lock(mutex_);
bool haveRequest = false;
// The iterator is advanced explicitly because finished requests are
// erased while iterating.
for (auto it = connectRequests_.begin(); it != connectRequests_.end();)
{
auto& request = *it;
// Note: this shadows the mutex guard above for the rest of the loop body.
ProcessorLockGuard lock(this);
if (request->processing)
{
auto session = static_cast<SharedSessionData*>(
dshash_find(sessions_, &(request->sessionID), false));
if (!session)
{
++it;
continue;
}
// Copy the flag: it is still needed after the entry is released.
const auto initialized = session->initialized;
if (initialized && !DsaPointerIsValid(session->errorMessage))
{
auto& localSession = localSessions_.find(request->sessionID)->second;
localSession->valid = true;
localSession->peerPID = session->executorPID;
localSession->bufferData = &(session->bufferData);
localSession->bufferAddress =
dsa_get_address(area_, session->bufferData.pointer);
P("%s: %s: connect: %" PRIu64,
Tag,
tag_,
session->bufferData.pointer);
}
dshash_release_lock(sessions_, session);
if (!initialized)
{
// Not initialized yet: look at it again on the next signal.
++it;
continue;
}
request->finished = true;
it = connectRequests_.erase(it);
}
else
{
bool found;
auto session = static_cast<SharedSessionData*>(
dshash_find_or_insert(sessions_, &(request->sessionID), &found));
if (found)
{
request->finished = true;
auto& localSession = localSessions_.find(request->sessionID)->second;
localSession->errorMessage = std::string("duplicated session ID: ") +
std::to_string(request->sessionID);
it = connectRequests_.erase(it);
}
else
{
request->processing = true;
shared_session_data_initialize(session,
area_,
request->databaseName,
request->userName,
request->password,
request->clientAddress);
haveRequest = true;
++it;
}
dshash_release_lock(sessions_, session);
}
}
if (haveRequest)
{
kill(sharedData_->mainPID, SIGUSR1);
}
}
// Don't use PostgreSQL API.
//
// Picks a random, non-zero session ID that isn't in use yet and registers
// a not-yet-valid local session for it. Returns the new ID.
uint64_t assign_session_id()
{
    std::lock_guard<std::mutex> lock(mutex_);
    for (;;)
    {
        const auto id = randomEngine_();
        if (id == 0)
        {
            continue;
        }
        auto localSession = std::make_shared<LocalSessionData>();
        const bool inserted =
            localSessions_.insert(std::make_pair(id, localSession)).second;
        if (!inserted)
        {
            // The ID is already taken: draw another one.
            continue;
        }
        localSession->id = id;
        localSession->valid = false;
        return id;
    }
}
public:
// Don't use PostgreSQL API.
//
// Opens a new session: assigns a session ID, queues a connect request,
// signals this process and waits until the request has finished, an error
// has been reported for the session or an interrupt is pending.
//
// Returns the session ID, the reported error, or Invalid("interrupted").
arrow::Result<uint64_t> connect(const std::string& databaseName,
                                const std::string& userName,
                                const std::string& password,
                                const std::string& clientAddress)
{
    const auto id = assign_session_id();
    ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(id));
    const auto request = std::make_shared<ConnectRequest>(
        id, databaseName, userName, password, clientAddress);
    {
        std::lock_guard<std::mutex> lock(mutex_);
        connectRequests_.push_back(request);
    }
    kill(MyProcPid, SIGUSR1);
    {
        std::unique_lock<std::mutex> lock(mutex_);
        conditionVariable_.wait(lock, [&] {
            return localSession->errorMessage.has_value() ||
                   INTERRUPTS_PENDING_CONDITION() || request->finished;
        });
    }
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::Invalid("interrupted");
    }
    return id;
}
private:
// A pending query for a session, queued by select() and handled by
// process_select_requests().
struct SelectRequest {
    SelectRequest(std::shared_ptr<LocalSessionData> localSession, std::string query)
        : localSession(std::move(localSession)), query(std::move(query))
    {
    }

    // Session that runs the query.
    std::shared_ptr<LocalSessionData> localSession;
    // SQL text to run.
    std::string query;
    // True once the request has been handed over (or has failed).
    bool finished = false;
};
// Select requests that haven't been processed yet.
std::list<std::shared_ptr<SelectRequest>> selectRequests_;
// Can use PostgreSQL API.
void process_select_requests()
{
#ifdef AFS_VERBOSE
const char* tag = "select";
#endif
std::lock_guard<std::mutex> lock(mutex_);
for (auto& request : selectRequests_)
{
request->finished = true;
auto session = static_cast<SharedSessionData*>(
dshash_find(sessions_, &(request->localSession->id), false));
if (!session)
{
request->localSession->errorMessage =
std::string("stolen session: ") +
std::to_string(request->localSession->id);
continue;
}
auto executorPID = session->executorPID;
{
SharedSessionReleaser sessionReleaser(sessions_, session);
ProcessorLockGuard lock(this);
set_shared_string(session->selectQuery, request->query);
session->action = Action::Select;
}
P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, executorPID);
kill(executorPID, SIGUSR1);
}
selectRequests_.clear();
}
// Opens an Arrow IPC stream reader on top of the session's shared ring
// buffer, consumes record batches until the reader yields a null batch,
// and returns the schema of the stream. `tag` is only used for logging.
arrow::Result<std::shared_ptr<arrow::Schema>> read_schema(
    std::shared_ptr<LocalSessionData> localSession, const char* tag)
{
    auto stream =
        std::make_shared<SharedRingBufferInputStream>(this, std::move(localSession));
    // Read schema only stream format data.
    ARROW_ASSIGN_OR_RAISE(auto streamReader,
                          arrow::ipc::RecordBatchStreamReader::Open(stream));
    for (;;)
    {
        std::shared_ptr<arrow::RecordBatch> batch;
        P("%s: %s: %s: read next", Tag, tag_, tag);
        ARROW_RETURN_NOT_OK(streamReader->ReadNext(&batch));
        if (batch == nullptr)
        {
            break;
        }
    }
    return streamReader->schema();
}
public:
// Queues a SelectRequest for the given session, wakes this process up
// with SIGUSR1 and blocks until the request is finished, the local
// session has an error, or an interrupt is pending. On success the
// schema is then read back through read_schema().
arrow::Result<std::shared_ptr<arrow::Schema>> select(uint64_t sessionID,
                                                     const std::string& query)
{
    const char* tag = "select";
    ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(sessionID));
    auto request = std::make_shared<SelectRequest>(localSession, query);
    {
        std::lock_guard<std::mutex> guard(mutex_);
        selectRequests_.push_back(request);
    }
    kill(MyProcPid, SIGUSR1);
    {
        std::unique_lock<std::mutex> waitLock(mutex_);
        conditionVariable_.wait(waitLock, [&] {
            return localSession->errorMessage.has_value() ||
                   INTERRUPTS_PENDING_CONDITION() || request->finished;
        });
    }
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::Invalid("interrupted");
    }
    P("%s: %s: %s: open", Tag, tag_, tag);
    auto schema = read_schema(localSession, tag);
    P("%s: %s: %s: schema", Tag, tag_, tag);
    return schema;
}
private:
// One queued "run this update query" request, created by update() and
// driven to completion by process_update_requests().
struct UpdateRequest {
    UpdateRequest(std::shared_ptr<LocalSessionData> session, std::string sql)
        : localSession(std::move(session)), query(std::move(sql))
    {
    }
    // Local session the request belongs to.
    std::shared_ptr<LocalSessionData> localSession;
    // Query text that gets copied into the shared session.
    std::string query;
    // True while the request has been dispatched and its result is awaited.
    bool processing = false;
    // True once the request has been completed (with or without an error).
    bool finished = false;
    // Record count copied from the shared session on completion.
    int64_t nUpdatedRecords = 0;
};
// Queue of pending update requests; update() and
// process_update_requests() both touch it while holding mutex_.
std::list<std::shared_ptr<UpdateRequest>> updateRequests_;
// Can use PostgreSQL API.
//
// Advances every queued update request by one step while holding mutex_.
// A request goes through two phases:
//   1. not yet processing: the query is copied into the shared session,
//      the action is set to Action::Update, nUpdatedRecords is reset to
//      -1, the executor is signaled with SIGUSR1 and the request is
//      marked as processing (it stays in the queue);
//   2. processing: the shared session is polled. A valid errorMessage or
//      a non-negative nUpdatedRecords completes the request and removes
//      it from the queue; otherwise it stays queued for the next call.
// A request whose shared session no longer exists is completed with a
// "stolen session" error on its local session.
void process_update_requests()
{
#ifdef AFS_VERBOSE
const char* tag = "update";
#endif
std::lock_guard<std::mutex> lock(mutex_);
// No increment in the loop header: each branch either erases the
// current element (which yields the next iterator) or advances manually.
for (auto it = updateRequests_.begin(); it != updateRequests_.end();)
{
// Reference to the list element; it must not be used after erase().
auto& request = *it;
auto session = static_cast<SharedSessionData*>(
dshash_find(sessions_, &(request->localSession->id), false));
if (!session)
{
request->finished = true;
request->localSession->errorMessage =
std::string("stolen session: ") +
std::to_string(request->localSession->id);
it = updateRequests_.erase(it);
continue;
}
if (request->processing)
{
// Presumably releases the dshash entry lock taken by dshash_find()
// when it goes out of scope - confirm against SharedSessionReleaser.
SharedSessionReleaser sessionReleaser(sessions_, session);
if (DsaPointerIsValid(session->errorMessage))
{
// The executor reported an error: finish without a record count.
request->processing = false;
request->finished = true;
it = updateRequests_.erase(it);
}
else if (session->nUpdatedRecords >= 0)
{
// -1 was stored at dispatch time, so >= 0 means a result was published.
request->processing = false;
request->finished = true;
request->nUpdatedRecords = session->nUpdatedRecords;
it = updateRequests_.erase(it);
}
else
{
// No result yet: keep the request queued.
++it;
}
}
else
{
// Copy the PID while the session entry is still held.
auto executorPID = session->executorPID;
{
SharedSessionReleaser sessionReleaser(sessions_, session);
ProcessorLockGuard lock(this);
set_shared_string(session->updateQuery, request->query);
session->action = Action::Update;
// Sentinel that the polling branch above waits to see replaced.
session->nUpdatedRecords = -1;
}
P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, executorPID);
kill(executorPID, SIGUSR1);
request->processing = true;
++it;
}
}
}
public:
// Queues an UpdateRequest for the given session, wakes this process up
// with SIGUSR1 and blocks until the request is finished, the local
// session has an error, or an interrupt is pending.
//
// Returns the record count stored in the request (nUpdatedRecords), the
// local session error if one was recorded, or Invalid("interrupted")
// when an interrupt is pending.
arrow::Result<int64_t> update(uint64_t sessionID, const std::string& query)
{
#ifdef AFS_VERBOSE
    const char* tag = "update";
#endif
    ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(sessionID));
    auto request = std::make_shared<UpdateRequest>(localSession, query);
    {
        std::lock_guard<std::mutex> lock(mutex_);
        updateRequests_.push_back(request);
    }
    kill(MyProcPid, SIGUSR1);
    {
        std::unique_lock<std::mutex> lock(mutex_);
        conditionVariable_.wait(lock, [&] {
            if (localSession->errorMessage.has_value())
            {
                return true;
            }
            if (INTERRUPTS_PENDING_CONDITION())
            {
                return true;
            }
            return request->finished;
        });
    }
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::Invalid("interrupted");
    }
    // nUpdatedRecords is a signed int64_t, so it is formatted with PRId64
    // (PRIu64 would mismatch the argument type).
    P("%s: %s: %s: done: %" PRId64, Tag, tag_, tag, request->nUpdatedRecords);
    return request->nUpdatedRecords;
}
private:
// One queued "prepare this statement" request, created by prepare() and
// driven to completion by process_prepare_requests().
struct PrepareRequest {
    PrepareRequest(std::shared_ptr<LocalSessionData> session, std::string sql)
        : localSession(std::move(session)), query(std::move(sql))
    {
    }
    // Local session the request belongs to.
    std::shared_ptr<LocalSessionData> localSession;
    // Query text that gets copied into the shared session.
    std::string query;
    // True while the request has been dispatched and its result is awaited.
    bool processing = false;
    // True once the request has been completed (with or without an error).
    bool finished = false;
    // Prepared statement handle copied from the shared session on success.
    std::string handle{};
};
// Queue of pending prepare requests; prepare() and
// process_prepare_requests() both touch it while holding mutex_.
std::list<std::shared_ptr<PrepareRequest>> prepareRequests_;
// Can use PostgreSQL API.
//
// Advances every queued prepare request by one step while holding
// mutex_. A request that is not yet processing is dispatched: the query
// is copied into the shared session, the action is set to
// Action::Prepare and the executor is signaled with SIGUSR1. A request
// that is already processing is polled: a valid errorMessage or a valid
// preparedStatementHandle completes it and removes it from the queue.
// A request whose shared session no longer exists is completed with a
// "stolen session" error on its local session.
void process_prepare_requests()
{
#ifdef AFS_VERBOSE
const char* tag = "prepare";
#endif
std::lock_guard<std::mutex> lock(mutex_);
// No increment in the loop header: each branch either erases the
// current element (which yields the next iterator) or advances manually.
for (auto it = prepareRequests_.begin(); it != prepareRequests_.end();)
{
// Reference to the list element; it must not be used after erase().
auto& request = *it;
auto session = static_cast<SharedSessionData*>(
dshash_find(sessions_, &(request->localSession->id), false));
if (!session)
{
request->finished = true;
request->localSession->errorMessage =
std::string("stolen session: ") +
std::to_string(request->localSession->id);
it = prepareRequests_.erase(it);
continue;
}
if (request->processing)
{
// Presumably releases the dshash entry lock taken by dshash_find()
// when it goes out of scope - confirm against SharedSessionReleaser.
SharedSessionReleaser sessionReleaser(sessions_, session);
if (DsaPointerIsValid(session->errorMessage))
{
// The executor reported an error: finish without a handle.
request->processing = false;
request->finished = true;
it = prepareRequests_.erase(it);
}
else if (DsaPointerIsValid(session->preparedStatementHandle))
{
// A handle was published: copy it out of the DSA area as a C string.
request->processing = false;
request->finished = true;
request->handle = static_cast<const char*>(
dsa_get_address(area_, session->preparedStatementHandle));
it = prepareRequests_.erase(it);
}
else
{
// No result yet: keep the request queued.
++it;
}
}
else
{
// Copy the PID while the session entry is still held.
auto executorPID = session->executorPID;
{
SharedSessionReleaser sessionReleaser(sessions_, session);
ProcessorLockGuard lock(this);
set_shared_string(session->prepareQuery, request->query);
session->action = Action::Prepare;
// NOTE(review): nUpdatedRecords is not read by the polling branch of
// this function; this reset may be a leftover - confirm.
session->nUpdatedRecords = -1;
// Presumably leaves the handle invalid so that the polling branch can
// tell when a new handle is published - confirm against
// set_shared_string().
set_shared_string(session->preparedStatementHandle, std::string(""));
}
P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, executorPID);
kill(executorPID, SIGUSR1);
request->processing = true;
++it;
}
}
}
public:
// Queues a PrepareRequest for the given session, wakes this process up
// with SIGUSR1 and blocks until the request is finished, the local
// session has an error, or an interrupt is pending.
//
// On success the result carries the handle stored in the request; the
// two schema fields of the result are left null.
arrow::Result<arrow::flight::sql::ActionCreatePreparedStatementResult> prepare(
    uint64_t sessionID, const std::string& query)
{
#ifdef AFS_VERBOSE
    const char* tag = "prepare";
#endif
    ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(sessionID));
    auto request = std::make_shared<PrepareRequest>(localSession, query);
    {
        std::lock_guard<std::mutex> guard(mutex_);
        prepareRequests_.push_back(request);
    }
    kill(MyProcPid, SIGUSR1);
    {
        std::unique_lock<std::mutex> waitLock(mutex_);
        conditionVariable_.wait(waitLock, [&] {
            return localSession->errorMessage.has_value() ||
                   INTERRUPTS_PENDING_CONDITION() || request->finished;
        });
    }
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::Invalid("interrupted");
    }
    arrow::flight::sql::ActionCreatePreparedStatementResult result = {
        nullptr,
        nullptr,
        std::move(request->handle),
    };
    P("%s: %s: %s: done", Tag, tag_, tag);
    return result;
}
private:
// One queued "close this prepared statement" request, created by
// close_prepared_statement() and driven to completion by
// process_close_prepared_statement_requests().
struct ClosePreparedStatementRequest {
    ClosePreparedStatementRequest(std::shared_ptr<LocalSessionData> session,
                                  std::string statementHandle)
        : localSession(std::move(session)), handle(std::move(statementHandle))
    {
    }
    // Local session the request belongs to.
    std::shared_ptr<LocalSessionData> localSession;
    // Handle of the prepared statement to close.
    std::string handle;
    // True while the request has been dispatched and its result is awaited.
    bool processing = false;
    // True once the request has been completed (with or without an error).
    bool finished = false;
};
// Queue of pending close requests; close_prepared_statement() and
// process_close_prepared_statement_requests() both touch it while
// holding mutex_.
std::list<std::shared_ptr<ClosePreparedStatementRequest>>
    closePreparedStatementRequests_;
// Can use PostgreSQL API.
//
// Advances every queued close-prepared-statement request by one step
// while holding mutex_. A request that is not yet processing is
// dispatched: the handle is copied into the shared session, the action
// is set to Action::ClosePreparedStatement and the executor is signaled
// with SIGUSR1. A request that is already processing is polled: it
// completes when the shared session has a valid errorMessage or when
// preparedStatementHandle is no longer valid.
// A request whose shared session no longer exists is completed with a
// "stolen session" error on its local session.
void process_close_prepared_statement_requests()
{
#ifdef AFS_VERBOSE
const char* tag = "close prepared statement";
#endif
std::lock_guard<std::mutex> lock(mutex_);
// No increment in the loop header: each branch either erases the
// current element (which yields the next iterator) or advances manually.
for (auto it = closePreparedStatementRequests_.begin();
it != closePreparedStatementRequests_.end();)
{
// Reference to the list element; it must not be used after erase().
auto& request = *it;
auto session = static_cast<SharedSessionData*>(
dshash_find(sessions_, &(request->localSession->id), false));
if (!session)
{
request->finished = true;
request->localSession->errorMessage =
std::string("stolen session: ") +
std::to_string(request->localSession->id);
it = closePreparedStatementRequests_.erase(it);
continue;
}
if (request->processing)
{
// Presumably releases the dshash entry lock taken by dshash_find()
// when it goes out of scope - confirm against SharedSessionReleaser.
SharedSessionReleaser sessionReleaser(sessions_, session);
if (DsaPointerIsValid(session->errorMessage))
{
// The executor reported an error: finish the request.
request->processing = false;
request->finished = true;
it = closePreparedStatementRequests_.erase(it);
}
else if (DsaPointerIsValid(session->preparedStatementHandle))
{
// The handle stored at dispatch time is still set: keep waiting.
// Presumably the executor clears it once the statement is closed -
// confirm against the executor side.
++it;
}
else
{
// The handle is gone: treat the close as done.
request->processing = false;
request->finished = true;
it = closePreparedStatementRequests_.erase(it);
}
}
else
{
// Copy the PID while the session entry is still held.
auto executorPID = session->executorPID;
{
SharedSessionReleaser sessionReleaser(sessions_, session);
ProcessorLockGuard lock(this);
set_shared_string(session->preparedStatementHandle, request->handle);
session->action = Action::ClosePreparedStatement;
}
P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, executorPID);
kill(executorPID, SIGUSR1);
request->processing = true;
++it;
}
}
}
public:
// Queues a ClosePreparedStatementRequest for the given session, wakes
// this process up with SIGUSR1 and blocks until the request is finished,
// the local session has an error, or an interrupt is pending.
//
// Returns OK on success, the local session error if one was recorded,
// or Invalid("interrupted") when an interrupt is pending.
arrow::Status close_prepared_statement(uint64_t sessionID, const std::string& handle)
{
#ifdef AFS_VERBOSE
    const char* tag = "close prepared statement";
#endif
    ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(sessionID));
    auto request =
        std::make_shared<ClosePreparedStatementRequest>(localSession, handle);
    {
        std::lock_guard<std::mutex> guard(mutex_);
        closePreparedStatementRequests_.push_back(request);
    }
    kill(MyProcPid, SIGUSR1);
    {
        std::unique_lock<std::mutex> waitLock(mutex_);
        conditionVariable_.wait(waitLock, [&] {
            return localSession->errorMessage.has_value() ||
                   INTERRUPTS_PENDING_CONDITION() || request->finished;
        });
    }
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::Invalid("interrupted");
    }
    P("%s: %s: %s: done", Tag, tag_, tag);
    return arrow::Status::OK();
}
private:
// One queued "bind parameters to this prepared statement" request,
// created by set_parameters() and driven to completion by
// process_set_parameters_requests().
struct SetParametersRequest {
    SetParametersRequest(std::shared_ptr<LocalSessionData> session,
                         std::string statementHandle)
        : localSession(std::move(session)), handle(std::move(statementHandle))
    {
    }
    // Local session the request belongs to.
    std::shared_ptr<LocalSessionData> localSession;
    // Handle of the prepared statement the parameters are for.
    std::string handle;
    // True while the request has been dispatched and its result is awaited.
    bool processing = false;
    // True once the request has been completed (with or without an error).
    bool finished = false;
};
// Queue of pending set-parameters requests; set_parameters() and
// process_set_parameters_requests() both touch it while holding mutex_.
std::list<std::shared_ptr<SetParametersRequest>> setParametersRequests_;
// Can use PostgreSQL API.
//
// Advances every queued set-parameters request by one step while
// holding mutex_. A request that is not yet processing is dispatched:
// the handle is copied into the shared session, the action is set to
// Action::SetParameters, setParametersFinished is reset to false and the
// executor is signaled with SIGUSR1. A request that is already
// processing is polled: a valid errorMessage or a true
// setParametersFinished completes it and removes it from the queue.
// A request whose shared session no longer exists is completed with a
// "stolen session" error on its local session.
void process_set_parameters_requests()
{
#ifdef AFS_VERBOSE
const char* tag = "set parameters";
#endif
std::lock_guard<std::mutex> lock(mutex_);
// No increment in the loop header: each branch either erases the
// current element (which yields the next iterator) or advances manually.
for (auto it = setParametersRequests_.begin();
it != setParametersRequests_.end();)
{
// Reference to the list element; it must not be used after erase().
auto& request = *it;
auto session = static_cast<SharedSessionData*>(
dshash_find(sessions_, &(request->localSession->id), false));
if (!session)
{
request->finished = true;
request->localSession->errorMessage =
std::string("stolen session: ") +
std::to_string(request->localSession->id);
it = setParametersRequests_.erase(it);
continue;
}
if (request->processing)
{
// Presumably releases the dshash entry lock taken by dshash_find()
// when it goes out of scope - confirm against SharedSessionReleaser.
SharedSessionReleaser sessionReleaser(sessions_, session);
if (DsaPointerIsValid(session->errorMessage))
{
// The executor reported an error: finish the request.
request->processing = false;
request->finished = true;
it = setParametersRequests_.erase(it);
}
else if (session->setParametersFinished)
{
// The flag was cleared at dispatch time, so true means it was set
// again afterwards.
request->processing = false;
request->finished = true;
it = setParametersRequests_.erase(it);
}
else
{
// No result yet: keep the request queued.
++it;
}
}
else
{
// Copy the PID while the session entry is still held.
auto executorPID = session->executorPID;
{
SharedSessionReleaser sessionReleaser(sessions_, session);
ProcessorLockGuard lock(this);
set_shared_string(session->preparedStatementHandle, request->handle);
session->action = Action::SetParameters;
// Completion flag that the polling branch above waits for.
session->setParametersFinished = false;
}
P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, executorPID);
kill(executorPID, SIGUSR1);
request->processing = true;
++it;
}
}
}
public:
// Binds parameters to a prepared statement.
//
// A SetParametersRequest is queued first and this process is signaled
// with SIGUSR1. The record batches delivered by `reader` are then
// written, in Arrow IPC stream format, to a SharedRingBufferOutputStream
// for the session, and the session's peer process is signaled. Finally
// the calling thread blocks until the request is finished, the local
// session has an error, or an interrupt is pending.
//
// The `writer` parameter is not used by this function.
arrow::Status set_parameters(uint64_t sessionID,
                             const std::string& handle,
                             arrow::flight::FlightMessageReader* reader,
                             arrow::flight::FlightMetadataWriter* writer)
{
#ifdef AFS_VERBOSE
    const char* tag = "set parameters";
#endif
    ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(sessionID));
    auto request = std::make_shared<SetParametersRequest>(localSession, handle);
    {
        std::lock_guard<std::mutex> guard(mutex_);
        setParametersRequests_.push_back(request);
    }
    kill(MyProcPid, SIGUSR1);
    auto executorPID = localSession->peerPID;
    {
        ARROW_ASSIGN_OR_RAISE(const auto& schema, reader->GetSchema());
        SharedRingBufferOutputStream sink(this, localSession);
        auto ipcOptions = arrow::ipc::IpcWriteOptions::Defaults();
        ipcOptions.emit_dictionary_deltas = true;
        ARROW_ASSIGN_OR_RAISE(auto streamWriter,
                              arrow::ipc::MakeStreamWriter(&sink, schema, ipcOptions));
        // Forward every chunk until the reader yields one without data.
        for (;;)
        {
            ARROW_ASSIGN_OR_RAISE(const auto& chunk, reader->Next());
            if (!chunk.data)
            {
                break;
            }
            ARROW_RETURN_NOT_OK(streamWriter->WriteRecordBatch(*(chunk.data)));
        }
        ARROW_RETURN_NOT_OK(streamWriter->Close());
        P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, executorPID);
        kill(executorPID, SIGUSR1);
    }
    {
        std::unique_lock<std::mutex> waitLock(mutex_);
        conditionVariable_.wait(waitLock, [&] {
            return localSession->errorMessage.has_value() ||
                   INTERRUPTS_PENDING_CONDITION() || request->finished;
        });
    }
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::Invalid("interrupted");
    }
    P("%s: %s: %s: done", Tag, tag_, tag);
    return arrow::Status::OK();
}
private:
// One queued "run this prepared SELECT" request, created by
// select_prepared_statement() and consumed by
// process_select_prepared_statement_requests().
struct SelectPreparedStatementRequest {
    SelectPreparedStatementRequest(std::shared_ptr<LocalSessionData> session,
                                   std::string statementHandle)
        : localSession(std::move(session)), handle(std::move(statementHandle))
    {
    }
    // Local session the request belongs to.
    std::shared_ptr<LocalSessionData> localSession;
    // Handle of the prepared statement to execute.
    std::string handle;
    // Becomes true once the request has been looked at by the processor.
    bool finished = false;
};
// Queue of pending prepared-select requests; select_prepared_statement()
// and process_select_prepared_statement_requests() both touch it while
// holding mutex_.
std::list<std::shared_ptr<SelectPreparedStatementRequest>>
    selectPreparedStatementRequests_;
// Can use PostgreSQL API.
void process_select_prepared_statement_requests()
{
#ifdef AFS_VERBOSE
const char* tag = "select prepared statement";
#endif
std::lock_guard<std::mutex> lock(mutex_);
for (auto& request : selectPreparedStatementRequests_)
{
request->finished = true;
auto session = static_cast<SharedSessionData*>(
dshash_find(sessions_, &(request->localSession->id), false));
if (!session)
{
request->localSession->errorMessage =
std::string("stolen session: ") +
std::to_string(request->localSession->id);
continue;
}
auto executorPID = session->executorPID;
{
SharedSessionReleaser sessionReleaser(sessions_, session);
ProcessorLockGuard lock(this);
set_shared_string(session->preparedStatementHandle, request->handle);
session->action = Action::SelectPreparedStatement;
}
P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, executorPID);
kill(executorPID, SIGUSR1);
}
selectPreparedStatementRequests_.clear();
}
public:
// Queues a SelectPreparedStatementRequest for the given session, wakes
// this process up with SIGUSR1 and blocks until the request is finished,
// the local session has an error, or an interrupt is pending. On success
// the schema is then read back through read_schema().
arrow::Result<std::shared_ptr<arrow::Schema>> select_prepared_statement(
    uint64_t sessionID, const std::string& handle)
{
    const char* tag = "select prepared statement";
    ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(sessionID));
    auto request =
        std::make_shared<SelectPreparedStatementRequest>(localSession, handle);
    {
        std::lock_guard<std::mutex> guard(mutex_);
        selectPreparedStatementRequests_.push_back(request);
    }
    kill(MyProcPid, SIGUSR1);
    {
        std::unique_lock<std::mutex> waitLock(mutex_);
        conditionVariable_.wait(waitLock, [&] {
            return localSession->errorMessage.has_value() ||
                   INTERRUPTS_PENDING_CONDITION() || request->finished;
        });
    }
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::Invalid("interrupted");
    }
    P("%s: %s: %s: open", Tag, tag_, tag);
    auto schema = read_schema(localSession, tag);
    P("%s: %s: %s: schema", Tag, tag_, tag);
    return schema;
}
private:
// One queued "run this prepared update" request, created by
// update_prepared_statement() and driven to completion by
// process_update_prepared_statement_requests().
struct UpdatePreparedStatementRequest {
    UpdatePreparedStatementRequest(std::shared_ptr<LocalSessionData> session,
                                   std::string statementHandle)
        : localSession(std::move(session)), handle(std::move(statementHandle))
    {
    }
    // Local session the request belongs to.
    std::shared_ptr<LocalSessionData> localSession;
    // Handle of the prepared statement to execute.
    std::string handle;
    // True while the request has been dispatched and its result is awaited.
    bool processing = false;
    // True once the request has been completed (with or without an error).
    bool finished = false;
    // Record count copied from the shared session on completion.
    int64_t nUpdatedRecords = 0;
};
// Queue of pending prepared-update requests; update_prepared_statement()
// and process_update_prepared_statement_requests() both touch it while
// holding mutex_.
std::list<std::shared_ptr<UpdatePreparedStatementRequest>>
    updatePreparedStatementRequests_;
// Can use PostgreSQL API.
//
// Advances every queued prepared-update request by one step while
// holding mutex_. A request that is not yet processing is dispatched:
// the handle is copied into the shared session, the action is set to
// Action::UpdatePreparedStatement, nUpdatedRecords is reset to -1 and
// the executor is signaled with SIGUSR1. A request that is already
// processing is polled: a valid errorMessage or a non-negative
// nUpdatedRecords completes it and removes it from the queue.
// A request whose shared session no longer exists is completed with a
// "stolen session" error on its local session.
void process_update_prepared_statement_requests()
{
#ifdef AFS_VERBOSE
const char* tag = "update prepared statement";
#endif
std::lock_guard<std::mutex> lock(mutex_);
// No increment in the loop header: each branch either erases the
// current element (which yields the next iterator) or advances manually.
for (auto it = updatePreparedStatementRequests_.begin();
it != updatePreparedStatementRequests_.end();)
{
// Reference to the list element; it must not be used after erase().
auto& request = *it;
auto session = static_cast<SharedSessionData*>(
dshash_find(sessions_, &(request->localSession->id), false));
if (!session)
{
request->finished = true;
request->localSession->errorMessage =
std::string("stolen session: ") +
std::to_string(request->localSession->id);
it = updatePreparedStatementRequests_.erase(it);
continue;
}
if (request->processing)
{
// Presumably releases the dshash entry lock taken by dshash_find()
// when it goes out of scope - confirm against SharedSessionReleaser.
SharedSessionReleaser sessionReleaser(sessions_, session);
if (DsaPointerIsValid(session->errorMessage))
{
// The executor reported an error: finish without a record count.
request->processing = false;
request->finished = true;
it = updatePreparedStatementRequests_.erase(it);
}
else if (session->nUpdatedRecords >= 0)
{
// -1 was stored at dispatch time, so >= 0 means a result was published.
request->processing = false;
request->finished = true;
request->nUpdatedRecords = session->nUpdatedRecords;
it = updatePreparedStatementRequests_.erase(it);
}
else
{
// No result yet: keep the request queued.
++it;
}
}
else
{
// Copy the PID while the session entry is still held.
auto executorPID = session->executorPID;
{
SharedSessionReleaser sessionReleaser(sessions_, session);
ProcessorLockGuard lock(this);
set_shared_string(session->preparedStatementHandle, request->handle);
session->action = Action::UpdatePreparedStatement;
// Sentinel that the polling branch above waits to see replaced.
session->nUpdatedRecords = -1;
}
P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, executorPID);
kill(executorPID, SIGUSR1);
request->processing = true;
++it;
}
}
}
public:
// Executes a prepared update statement with the parameters delivered by
// `reader`.
//
// An UpdatePreparedStatementRequest is queued first and this process is
// signaled with SIGUSR1. The record batches from `reader` are then
// written, in Arrow IPC stream format, to a SharedRingBufferOutputStream
// for the session, and the session's peer process is signaled. Finally
// the calling thread blocks until the request is finished, the local
// session has an error, or an interrupt is pending.
//
// Returns the record count stored in the request (nUpdatedRecords), the
// local session error if one was recorded, or Invalid("interrupted")
// when an interrupt is pending.
arrow::Result<int64_t> update_prepared_statement(
    uint64_t sessionID,
    const std::string& handle,
    arrow::flight::FlightMessageReader* reader)
{
#ifdef AFS_VERBOSE
    const char* tag = "update prepared statement";
#endif
    ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(sessionID));
    auto request =
        std::make_shared<UpdatePreparedStatementRequest>(localSession, handle);
    {
        std::lock_guard<std::mutex> lock(mutex_);
        updatePreparedStatementRequests_.push_back(request);
    }
    kill(MyProcPid, SIGUSR1);
    auto executorPID = localSession->peerPID;
    {
        ARROW_ASSIGN_OR_RAISE(const auto& schema, reader->GetSchema());
        SharedRingBufferOutputStream output(this, localSession);
        auto options = arrow::ipc::IpcWriteOptions::Defaults();
        options.emit_dictionary_deltas = true;
        ARROW_ASSIGN_OR_RAISE(auto writer,
                              arrow::ipc::MakeStreamWriter(&output, schema, options));
        // Forward every chunk until the reader yields one without data.
        while (true)
        {
            ARROW_ASSIGN_OR_RAISE(const auto& chunk, reader->Next());
            if (!chunk.data)
            {
                break;
            }
            ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*(chunk.data)));
        }
        ARROW_RETURN_NOT_OK(writer->Close());
        P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, executorPID);
        kill(executorPID, SIGUSR1);
    }
    {
        std::unique_lock<std::mutex> lock(mutex_);
        conditionVariable_.wait(lock, [&] {
            if (localSession->errorMessage.has_value())
            {
                return true;
            }
            if (INTERRUPTS_PENDING_CONDITION())
            {
                return true;
            }
            return request->finished;
        });
    }
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::Invalid("interrupted");
    }
    // nUpdatedRecords is a signed int64_t, so it is formatted with PRId64
    // (PRIu64 would mismatch the argument type).
    P("%s: %s: %s: done: %" PRId64, Tag, tag_, tag, request->nUpdatedRecords);
    return request->nUpdatedRecords;
}
private:
// One queued ring buffer read, created by read() and executed by
// process_read_requests(). The pointers are not owned by the request.
struct ReadRequest {
    ReadRequest(LocalSessionData* session,
                SharedRingBuffer* ringBuffer,
                size_t size,
                void* destination)
        : localSession(session), buffer(ringBuffer), n(size), output(destination)
    {
    }
    // Local session the read is performed for.
    LocalSessionData* localSession;
    // Ring buffer to read from.
    SharedRingBuffer* buffer;
    // Number of bytes passed to SharedRingBuffer::read().
    size_t n;
    // Destination passed to SharedRingBuffer::read().
    void* output;
    // Becomes true once process_read_requests() has handled the request.
    bool finished = false;
    // Value returned by SharedRingBuffer::read().
    size_t readBytes = 0;
};
// Queue of pending reads; read() and process_read_requests() both touch
// it while holding mutex_.
std::list<std::shared_ptr<ReadRequest>> readRequests_;
// Can use PostgreSQL API.
//
// Drains readRequests_ while holding mutex_: every request is marked as
// finished, the ring buffer read is performed under a ProcessorLockGuard
// and the session's peer process is signaled with SIGUSR1.
void process_read_requests()
{
#ifdef AFS_VERBOSE
    const char* tag = "read";
#endif
    std::lock_guard<std::mutex> guard(mutex_);
    for (auto& pending : readRequests_)
    {
        pending->finished = true;
        {
            ProcessorLockGuard processorLock(this);
            pending->readBytes = pending->buffer->read(pending->n, pending->output);
        }
        auto owner = pending->localSession;
        P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, owner->peerPID);
        kill(owner->peerPID, SIGUSR1);
    }
    readRequests_.clear();
}
public:
// Don't use PostgreSQL API.
//
// Queues a ReadRequest, wakes this process up with SIGUSR1 and blocks
// until the request is finished, the local session has an error, or an
// interrupt is pending.
//
// Returns the byte count recorded in the request, the local session
// error if one was recorded, or IOError("interrupted") when an interrupt
// is pending.
arrow::Result<size_t> read(LocalSessionData* localSession,
                           SharedRingBuffer* buffer,
                           size_t n,
                           void* output) override
{
    auto request = std::make_shared<ReadRequest>(localSession, buffer, n, output);
    {
        std::lock_guard<std::mutex> guard(mutex_);
        readRequests_.push_back(request);
    }
    kill(MyProcPid, SIGUSR1);
    {
        std::unique_lock<std::mutex> waitLock(mutex_);
        conditionVariable_.wait(waitLock, [&] {
            return localSession->errorMessage.has_value() ||
                   INTERRUPTS_PENDING_CONDITION() || request->finished;
        });
    }
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::IOError("interrupted");
    }
    return request->readBytes;
}
private:
// One queued ring buffer write, created by write() and executed by
// process_write_requests(). The pointers are not owned by the request.
struct WriteRequest {
    WriteRequest(LocalSessionData* session,
                 SharedRingBuffer* ringBuffer,
                 const void* source,
                 size_t size)
        : localSession(session), buffer(ringBuffer), data(source), n(size)
    {
    }
    // Local session the write is performed for.
    LocalSessionData* localSession;
    // Ring buffer to write to.
    SharedRingBuffer* buffer;
    // Source passed to SharedRingBuffer::write().
    const void* data;
    // Number of bytes passed to SharedRingBuffer::write().
    size_t n;
    // Becomes true once process_write_requests() has handled the request.
    bool finished = false;
    // Value returned by SharedRingBuffer::write(); stays 0 when the write
    // was skipped.
    size_t writtenBytes = 0;
};
// Queue of pending writes; write() and process_write_requests() both
// touch it while holding mutex_.
std::list<std::shared_ptr<WriteRequest>> writeRequests_;
// Can use PostgreSQL API.
void process_write_requests()
{
#ifdef AFS_VERBOSE
const char* tag = "write";
#endif
std::lock_guard<std::mutex> lock(mutex_);
for (auto& request : writeRequests_)
{
auto session = static_cast<SharedSessionData*>(
dshash_find(sessions_, &(request->localSession->id), false));
if (!session)
{
request->finished = true;
request->localSession->errorMessage =
std::string("stolen session: ") +
std::to_string(request->localSession->id);
continue;
}
{
SharedSessionReleaser sessionReleaser(sessions_, session);
if (!DsaPointerIsValid(session->errorMessage))
{
ProcessorLockGuard lock(this);
request->writtenBytes =
request->buffer->write(request->data, request->n);
}
request->finished = true;
}
P("%s: %s: %s: kill executor: %d",
Tag,
tag_,
tag,
request->localSession->peerPID);
kill(request->localSession->peerPID, SIGUSR1);
}
writeRequests_.clear();
}
public:
// Don't use PostgreSQL API.
//
// Queues a WriteRequest, wakes this process up with SIGUSR1 and blocks
// until the request is finished, the local session has an error, or an
// interrupt is pending.
//
// Returns the byte count recorded in the request, the local session
// error if one was recorded, or IOError("interrupted") when an interrupt
// is pending.
arrow::Result<size_t> write(LocalSessionData* localSession,
                            SharedRingBuffer* buffer,
                            const void* data,
                            size_t n) override
{
    auto request = std::make_shared<WriteRequest>(localSession, buffer, data, n);
    {
        std::lock_guard<std::mutex> guard(mutex_);
        writeRequests_.push_back(request);
    }
    kill(MyProcPid, SIGUSR1);
    {
        std::unique_lock<std::mutex> waitLock(mutex_);
        conditionVariable_.wait(waitLock, [&] {
            return localSession->errorMessage.has_value() ||
                   INTERRUPTS_PENDING_CONDITION() || request->finished;
        });
    }
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::IOError("interrupted");
    }
    return request->writtenBytes;
}
private:
// Can use PostgreSQL API.
//
// Walks all shared sessions and mirrors their state into localSessions_:
//   * a shared error message is copied to the matching local session (if
//     there is one) and then freed and cleared in the shared session;
//   * a finished session is removed from localSessions_, finalized and
//     deleted from the shared hash table.
void sync_sessions()
{
    dshash_seq_status sessionsStatus;
    // The scan is opened in exclusive mode because entries are modified
    // and dshash_delete_current() may only be used on an exclusive scan.
    dshash_seq_init(&sessionsStatus, sessions_, true);
    while (true)
    {
        auto session =
            static_cast<SharedSessionData*>(dshash_seq_next(&sessionsStatus));
        if (!session)
        {
            break;
        }
        if (DsaPointerIsValid(session->errorMessage))
        {
            std::lock_guard<std::mutex> lock(mutex_);
            // The shared session may have no local counterpart; never
            // dereference the end iterator.
            auto found = localSessions_.find(session->id);
            if (found != localSessions_.end() && found->second)
            {
                found->second->errorMessage = static_cast<const char*>(
                    dsa_get_address(area_, session->errorMessage));
            }
            dsa_free(area_, session->errorMessage);
            session->errorMessage = InvalidDsaPointer;
        }
        if (session->finished)
        {
            {
                std::lock_guard<std::mutex> lock(mutex_);
                // Erase by key: a no-op when the ID is not registered
                // locally, unlike erasing the end iterator.
                localSessions_.erase(session->id);
            }
            shared_session_data_finalize(session, area_);
            dshash_delete_current(&sessionsStatus);
        }
    }
    dshash_seq_term(&sessionsStatus);
}
protected:
// The peer of a session is the process stored in its peerPID field.
pid_t peer_pid(LocalSessionData* session) override
{
    return session->peerPID;
}
// Name of the peer process kind; used in log messages.
const char* peer_name() override
{
    return "executor";
}
// Blocks on conditionVariable_ until get_waiting_buffer_size(buffer,
// mode) becomes positive, the local session has an error, or an
// interrupt is pending. Every evaluation of the wait predicate is
// logged.
//
// Returns the local session error if one was recorded, an Invalid
// status mentioning "interrupted" when an interrupt is pending, and OK
// otherwise.
arrow::Status wait_internal(LocalSessionData* localSession,
                            SharedRingBuffer* buffer,
                            WaitMode mode,
                            const char* tag) override
{
    std::unique_lock<std::mutex> waitLock(mutex_);
    conditionVariable_.wait(waitLock, [&] {
        P("%s: %s: %s: %s: wait: %" PRIsize ": error: %s",
          Tag,
          tag_,
          tag,
          peer_name(),
          get_waiting_buffer_size(buffer, mode),
          localSession->errorMessage.has_value()
              ? localSession->errorMessage.value().c_str()
              : "");
        return localSession->errorMessage.has_value() ||
               INTERRUPTS_PENDING_CONDITION() ||
               get_waiting_buffer_size(buffer, mode) > 0;
    });
    ARROW_RETURN_NOT_OK(check_local_session_error(localSession, true));
    if (INTERRUPTS_PENDING_CONDITION())
    {
        return arrow::Status::Invalid(tag_, ": ", tag, ": ", "interrupted");
    }
    return arrow::Status::OK();
}
};
// Reads exactly nBytes bytes from the session's shared ring buffer into
// `out`, waiting for the peer to write more data whenever a read yields
// nothing.
//
// Returns nBytes on success. Returns IOError when the stream is closed
// or an interrupt is pending, Invalid when nBytes is negative, and
// propagates any error from the processor's read()/wait().
arrow::Result<int64_t>
SharedRingBufferInputStream::Read(int64_t nBytes, void* out)
{
    if (ARROW_PREDICT_FALSE(!is_open_))
    {
        return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() +
                                      ": SharedRingBufferInputStream is closed");
    }
    if (ARROW_PREDICT_FALSE(nBytes < 0))
    {
        // A negative size would otherwise turn into a huge size_t below.
        return arrow::Status::Invalid(std::string(Tag) + ": " + processor_->tag() +
                                      ": negative read size: " +
                                      std::to_string(nBytes));
    }
    if (nBytes == 0)
    {
        // Nothing to read. The loop below only terminates after a
        // non-empty read, so a zero-byte request must not enter it.
        return nBytes;
    }
    size_t rest = static_cast<size_t>(nBytes);
    SharedRingBuffer buffer(
        localSession_->bufferData, localSession_->bufferAddress, processor_->tag());
    while (true)
    {
        ARROW_ASSIGN_OR_RAISE(auto readBytes,
                              processor_->read(localSession_.get(), &buffer, rest, out));
        if (INTERRUPTS_PENDING_CONDITION())
        {
            return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() +
                                          ": interrupted");
        }
        if (readBytes == 0)
        {
            // Nothing available yet: wait until data has been written.
            ARROW_RETURN_NOT_OK(processor_->wait(
                localSession_.get(), &buffer, Processor::WaitMode::Written));
            if (INTERRUPTS_PENDING_CONDITION())
            {
                return arrow::Status::IOError(std::string(Tag) + ": " +
                                              processor_->tag() + ": interrupted");
            }
            continue;
        }
        position_ += readBytes;
        rest -= readBytes;
        out = static_cast<uint8_t*>(out) + readBytes;
        if (ARROW_PREDICT_TRUE(rest == 0))
        {
            break;
        }
    }
    return nBytes;
}
// Processor tagged "main". Its constructor creates the shared data
// structure, the DSA area and the shared session hash table. It starts
// the server background worker on request and, whenever it is signaled,
// starts executor background workers for new sessions and marks sessions
// whose executor has stopped as finished.
class MainProcessor : public Processor {
   public:
    // Creates and publishes all shared state while holding
    // AddinShmemInitLock. Raises a PostgreSQL ERROR when the shared data
    // already exists.
    MainProcessor() : Processor("main"), sessions_(nullptr), executorHandles_()
    {
        LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
        bool alreadyCreated;
        // NOTE(review): sizeof(sharedData_) is the size of the pointer, not
        // of SharedData; sizeof(SharedData) looks intended. The size has to
        // agree with every other ShmemInitStruct(SharedDataName, ...) call
        // and with the shared memory request, so it is left untouched here
        // - confirm and fix all places together.
        sharedData_ = static_cast<SharedData*>(
            ShmemInitStruct(SharedDataName, sizeof(sharedData_), &alreadyCreated));
        if (alreadyCreated)
        {
            LWLockRelease(AddinShmemInitLock);
            ereport(ERROR,
                    errcode(ERRCODE_INTERNAL_ERROR),
                    errmsg("%s: %s: shared data is already created", Tag, tag_));
        }
        sharedData_->trancheID = LWLockNewTrancheId();
        sharedData_->sessionsTrancheID = LWLockNewTrancheId();
        area_ = dsa_create(sharedData_->trancheID);
        sharedData_->handle = dsa_get_handle(area_);
        SessionsParams.tranche_id = sharedData_->sessionsTrancheID;
        sessions_ = dshash_create(area_, &SessionsParams, nullptr);
        sharedData_->sessionsHandle = dshash_get_hash_table_handle(sessions_);
        sharedData_->serverPID = InvalidPid;
        sharedData_->mainPID = MyProcPid;
        lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);
        LWLockRelease(AddinShmemInitLock);
    }

    ~MainProcessor() override { dshash_destroy(sessions_); }

    // Registers the "afs_server" dynamic background worker, waits for its
    // startup (storing its PID in the shared data) and returns its handle.
    // Raises a PostgreSQL ERROR when the worker cannot be registered.
    BackgroundWorkerHandle* start_server()
    {
        BackgroundWorker serverWorker = {};
        snprintf(serverWorker.bgw_name, BGW_MAXLEN, "%s: server", Tag);
        snprintf(serverWorker.bgw_type, BGW_MAXLEN, "%s: server", Tag);
        serverWorker.bgw_flags = BGWORKER_SHMEM_ACCESS;
        serverWorker.bgw_start_time = BgWorkerStart_ConsistentState;
        serverWorker.bgw_restart_time = BGW_NEVER_RESTART;
        snprintf(serverWorker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
        snprintf(serverWorker.bgw_function_name, BGW_MAXLEN, "afs_server");
        serverWorker.bgw_main_arg = 0;
        serverWorker.bgw_notify_pid = MyProcPid;
        BackgroundWorkerHandle* serverHandle;
        if (!RegisterDynamicBackgroundWorker(&serverWorker, &serverHandle))
        {
            ereport(ERROR,
                    errcode(ERRCODE_INTERNAL_ERROR),
                    errmsg("%s: %s: failed to start server", Tag, tag_));
        }
        WaitForBackgroundWorkerStartup(serverHandle, &(sharedData_->serverPID));
        return serverHandle;
    }

    // Entry point to call after this process has been signaled.
    void signaled() { update_sessions(); }

   private:
    // Scans all shared sessions. Started sessions are marked as finished
    // when their executor has stopped or when no handle is known for
    // them. Sessions that are not started yet get an "afs_executor"
    // background worker; a registration failure is recorded as the
    // session's error message. The server process is signaled at the end.
    void update_sessions()
    {
        const char* tag = "update sessions";
        P("%s: %s: %s: start", Tag, tag_, tag);
        dshash_seq_status scan;
        dshash_seq_init(&scan, sessions_, false);
        for (auto session = static_cast<SharedSessionData*>(dshash_seq_next(&scan));
             session != nullptr;
             session = static_cast<SharedSessionData*>(dshash_seq_next(&scan)))
        {
            if (session->started)
            {
                if (session->finished)
                {
                    continue;
                }
                auto known = executorHandles_.find(session->id);
                if (known == executorHandles_.end())
                {
                    // No handle is tracked for this session (for example
                    // because registering its executor failed).
                    session->finished = true;
                    continue;
                }
                pid_t pid;
                if (GetBackgroundWorkerPid(known->second, &pid) == BGWH_STOPPED)
                {
                    // A finished session
                    session->finished = true;
                    executorHandles_.erase(known);
                }
                continue;
            }
            // A new session request
            BackgroundWorker executorWorker = {};
            snprintf(executorWorker.bgw_name,
                     BGW_MAXLEN,
                     "%s: executor: %" PRIu64,
                     Tag,
                     session->id);
            snprintf(executorWorker.bgw_type,
                     BGW_MAXLEN,
                     "%s: executor: %" PRIu64,
                     Tag,
                     session->id);
            executorWorker.bgw_flags =
                BGWORKER_SHMEM_ACCESS | BGWORKER_BACKEND_DATABASE_CONNECTION;
            executorWorker.bgw_start_time = BgWorkerStart_ConsistentState;
            executorWorker.bgw_restart_time = BGW_NEVER_RESTART;
            snprintf(executorWorker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
            snprintf(executorWorker.bgw_function_name, BGW_MAXLEN, "afs_executor");
            executorWorker.bgw_main_arg = Int64GetDatum(session->id);
            executorWorker.bgw_notify_pid = MyProcPid;
            BackgroundWorkerHandle* executorHandle;
            P("%s: %s: %s: register executor: %" PRIu64, Tag, tag_, tag, session->id);
            // Marked as started before registering, so the session is not
            // picked up as new again even when registration fails.
            session->started = true;
            if (!RegisterDynamicBackgroundWorker(&executorWorker, &executorHandle))
            {
                P("%s: %s: %s: failed to start executor: %" PRIu64,
                  Tag,
                  tag_,
                  tag,
                  session->id);
                set_shared_string(
                    session->errorMessage,
                    std::string(Tag) + ": " + tag_ + ": " + tag +
                        ": failed to start executor: " + std::to_string(session->id));
                continue;
            }
            executorHandles_[session->id] = executorHandle;
            WaitForBackgroundWorkerStartup(executorHandle, &(session->executorPID));
            P("%s: %s: %s: started executor: %" PRIu64, Tag, tag_, tag, session->id);
        }
        dshash_seq_term(&scan);
        P("%s: %s: %s: end", Tag, tag_, tag);
        kill(sharedData_->serverPID, SIGUSR1);
    }

    // Shared session hash table created by the constructor.
    dshash_table* sessions_;
    // Background worker handles of started executors, keyed by session ID.
    std::map<uint64_t, BackgroundWorkerHandle*> executorHandles_;
};
// Per-call middleware that remembers the session ID of the call and
// sends it back to the client as an "authorization: Bearer <id>" header.
class HeaderAuthServerMiddleware : public arrow::flight::ServerMiddleware {
   public:
    explicit HeaderAuthServerMiddleware(uint64_t sessionID) : sessionID_(sessionID) {}

    // Adds the bearer token (the decimal session ID) to the outgoing headers.
    void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override
    {
        std::string bearer("Bearer ");
        bearer += std::to_string(sessionID_);
        outgoing_headers->AddHeader("authorization", bearer);
    }

    // Nothing to do when a call completes.
    void CallCompleted(const arrow::Status& status) override {}

    std::string name() const override { return "HeaderAuthServerMiddleware"; }

    // Session ID this middleware was created for.
    uint64_t session_id() { return sessionID_; }

   private:
    uint64_t sessionID_;
};
// Middleware factory that authenticates each Flight call from its
// "authorization" header.
//
//   * "Basic <base64(user:password)>" opens a new session through
//     Proxy::connect(); the database name comes from the
//     "x-flight-sql-database" header and defaults to "postgres".
//   * "Bearer <session id>" reuses an existing session after checking it
//     with Proxy::is_valid_session().
//
// On success a HeaderAuthServerMiddleware carrying the session ID is
// attached to the call. Any other or malformed header yields a Flight
// error status.
class HeaderAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory {
   public:
    explicit HeaderAuthServerMiddlewareFactory(Proxy* proxy)
        : arrow::flight::ServerMiddlewareFactory(), proxy_(proxy)
    {
    }

    arrow::Status StartCall(
        const arrow::flight::CallInfo& info,
#if ARROW_VERSION_MAJOR >= 13
        const arrow::flight::ServerCallContext& context,
#else
        const arrow::flight::CallHeaders& incoming_headers,
#endif
        std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override
    {
        std::string databaseName("postgres");
#if ARROW_VERSION_MAJOR >= 13
        const auto& incomingHeaders = context.incoming_headers();
#else
        const auto& incomingHeaders = incoming_headers;
#endif
        auto databaseHeader = incomingHeaders.find("x-flight-sql-database");
        if (databaseHeader != incomingHeaders.end())
        {
            databaseName = databaseHeader->second;
        }
        auto authorizationHeader = incomingHeaders.find("authorization");
        if (authorizationHeader == incomingHeaders.end())
        {
            return arrow::flight::MakeFlightError(
                arrow::flight::FlightStatusCode::Unauthenticated,
                "No authorization header");
        }
        auto value = authorizationHeader->second;
        // Split "<type> <credentials>" at the first space. A header that
        // has no space at all ("Basic", "Bearer") is treated as having
        // empty credentials instead of computing an out-of-range offset
        // from a failed stream position, which made substr() throw.
        const auto separator = value.find(' ');
        const bool hasCredentials = separator != value.npos;
        std::string type(value.substr(0, separator));
        auto credentials = value.substr(hasCredentials ? separator + 1 : value.size());
        if (type == "Basic")
        {
            std::stringstream decodedStream(arrow::util::base64_decode(credentials));
            std::string userName("");
            std::string password("");
            std::getline(decodedStream, userName, ':');
            std::getline(decodedStream, password);
#if ARROW_VERSION_MAJOR >= 13
            const auto& clientAddress = context.peer();
#else
            // 192.0.2.1 is one of reserved IPv4 addresses for documentation.
            std::string clientAddress("ipv4:192.0.2.1:2929");
#endif
            auto sessionIDResult =
                proxy_->connect(databaseName, userName, password, clientAddress);
            if (!sessionIDResult.status().ok())
            {
                return arrow::flight::MakeFlightError(
                    arrow::flight::FlightStatusCode::Unauthenticated,
                    sessionIDResult.status().ToString());
            }
            auto sessionID = *sessionIDResult;
            *middleware = std::make_shared<HeaderAuthServerMiddleware>(sessionID);
            return arrow::Status::OK();
        }
        else if (type == "Bearer")
        {
            std::string sessionIDString(credentials);
            if (sessionIDString.size() == 0)
            {
                return arrow::flight::MakeFlightError(
                    arrow::flight::FlightStatusCode::Unauthorized,
                    std::string("invalid Bearer token"));
            }
            auto start = sessionIDString.c_str();
            char* end = nullptr;
            uint64_t sessionID = std::strtoull(start, &end, 10);
            // The whole token has to be consumed by the number parser.
            if (end[0] != '\0')
            {
                return arrow::flight::MakeFlightError(
                    arrow::flight::FlightStatusCode::Unauthorized,
                    std::string("invalid Bearer token"));
            }
            if (!proxy_->is_valid_session(sessionID))
            {
                return arrow::flight::MakeFlightError(
                    arrow::flight::FlightStatusCode::Unauthorized,
                    std::string("invalid Bearer token"));
            }
            *middleware = std::make_shared<HeaderAuthServerMiddleware>(sessionID);
            return arrow::Status::OK();
        }
        else
        {
            return arrow::flight::MakeFlightError(
                arrow::flight::FlightStatusCode::Unauthenticated,
                std::string("authorization header must use Basic or Bearer: <") + type +
                    std::string(">"));
        }
    }

   private:
    // Not owned; used to open and validate sessions.
    Proxy* proxy_;
};
// Flight SQL server doesn't communicate an "executor" process
// directly. The Proxy object communicates an "executor" process.
class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
public:
explicit FlightSQLServer(Proxy* proxy)
: arrow::flight::sql::FlightSqlServerBase(), proxy_(proxy)
{
}
~FlightSQLServer() override {}
arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQuery& command,
const arrow::flight::FlightDescriptor& descriptor) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& query = command.query;
ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->select(sessionID, query));
ARROW_ASSIGN_OR_RAISE(auto ticket,
arrow::flight::sql::CreateStatementQueryTicket(query));
std::vector<arrow::flight::FlightEndpoint> endpoints{
create_endpoint(std::move(ticket))};
ARROW_ASSIGN_OR_RAISE(
auto result,
arrow::flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1));
return std::make_unique<arrow::flight::FlightInfo>(result);
}
arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> DoGetStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQueryTicket& command) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->open_reader(sessionID));
return std::make_unique<arrow::flight::RecordBatchStream>(reader);
}
arrow::Result<int64_t> DoPutCommandStatementUpdate(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementUpdate& command) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& query = command.query;
return proxy_->update(sessionID, query);
}
arrow::Result<arrow::flight::sql::ActionCreatePreparedStatementResult>
CreatePreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::ActionCreatePreparedStatementRequest& request) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& query = request.query;
return proxy_->prepare(sessionID, query);
}
arrow::Status ClosePreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::ActionClosePreparedStatementRequest& request) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& handle = request.prepared_statement_handle;
return proxy_->close_prepared_statement(sessionID, handle);
}
arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>>
GetFlightInfoPreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command,
const arrow::flight::FlightDescriptor& descriptor) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& handle = command.prepared_statement_handle;
ARROW_ASSIGN_OR_RAISE(auto schema,
proxy_->select_prepared_statement(sessionID, handle));
ARROW_ASSIGN_OR_RAISE(auto ticket,
arrow::flight::sql::CreateStatementQueryTicket(handle));
std::vector<arrow::flight::FlightEndpoint> endpoints{
create_endpoint(std::move(ticket))};
ARROW_ASSIGN_OR_RAISE(
auto result,
arrow::flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1));
return std::make_unique<arrow::flight::FlightInfo>(result);
}
arrow::Status DoPutPreparedStatementQuery(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command,
arrow::flight::FlightMessageReader* reader,
arrow::flight::FlightMetadataWriter* writer) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& handle = command.prepared_statement_handle;
return proxy_->set_parameters(sessionID, handle, reader, writer);
}
arrow::Result<int64_t> DoPutPreparedStatementUpdate(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementUpdate& command,
arrow::flight::FlightMessageReader* reader) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& handle = command.prepared_statement_handle;
return proxy_->update_prepared_statement(sessionID, handle, reader);
}
private:
arrow::flight::FlightEndpoint create_endpoint(std::string ticket)
{
#if ARROW_VERSION_MAJOR >= 14
return arrow::flight::FlightEndpoint{
arrow::flight::Ticket{std::move(ticket)}, {}, std::nullopt, ""};
#elif ARROW_VERSION_MAJOR >= 13
return arrow::flight::FlightEndpoint{
arrow::flight::Ticket{std::move(ticket)}, {}, std::nullopt};
#else
return arrow::flight::FlightEndpoint
{
arrow::flight::Ticket{std::move(ticket), {}};
#endif
}
arrow::Result<uint64_t> session_id(const arrow::flight::ServerCallContext& context)
{
auto middleware = reinterpret_cast<HeaderAuthServerMiddleware*>(
context.GetMiddleware("header-auth"));
if (!middleware)
{
return arrow::flight::MakeFlightError(
arrow::flight::FlightStatusCode::Unauthenticated, "no authorization");
}
return middleware->session_id();
}
Proxy* proxy_;
};
// Starts the Flight SQL server on URI and serves until SIGTERM is
// received, dispatching SIGHUP (configuration reload) and SIGUSR1
// (Proxy notifications) in between. Returns the status of the server
// initialization or of the final shutdown.
arrow::Status
afs_server_internal(Proxy* proxy)
{
	ARROW_ASSIGN_OR_RAISE(auto listenLocation, arrow::flight::Location::Parse(URI));
	arrow::flight::FlightServerOptions serverOptions(listenLocation);
	if (EnableSSL)
	{
		// Returns the whole content of the file at `path`, or an empty
		// string when `path` is null or the file can't be opened.
		const auto readWholeFile = [](const char* path) -> std::string {
			if (!path)
			{
				return {};
			}
			std::ifstream input(path);
			if (!input)
			{
				return {};
			}
			return std::string(std::istreambuf_iterator<char>{input}, {});
		};
		arrow::flight::CertKeyPair certificate;
		certificate.pem_cert = readWholeFile(ssl_cert_file);
		certificate.pem_key = readWholeFile(ssl_key_file);
		// The pair is registered only when both parts could be read.
		const bool incomplete =
			certificate.pem_cert.empty() || certificate.pem_key.empty();
		if (!incomplete)
		{
			serverOptions.tls_certificates.push_back(std::move(certificate));
		}
	}
	// Authentication is done by the "header-auth" middleware, not by
	// Flight's own handshake.
	serverOptions.auth_handler = std::make_unique<arrow::flight::NoOpAuthHandler>();
	serverOptions.middleware.push_back(
		{"header-auth", std::make_shared<HeaderAuthServerMiddlewareFactory>(proxy)});
	FlightSQLServer server(proxy);
	ARROW_RETURN_NOT_OK(server.Init(serverOptions));
	ereport(LOG, (errmsg("listening on %s for Apache Arrow Flight SQL", URI)));
	while (!GotSIGTERM)
	{
		WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
		ResetLatch(MyLatch);
		if (GotSIGHUP)
		{
			GotSIGHUP = false;
			ProcessConfigFile(PGC_SIGHUP);
		}
		// SIGUSR1 is emitted when:
		//   * The "main" process notifies that a new "executor" process
		//     is spawned
		//   * The "main" process notifies that an "executor" process is
		//     finished (may be crashed)
		//   * The "executor" process requests a read/write
		//   * The "executor" process notifies that a request is finished
		//
		// The flag is cleared before the Proxy is called so that a
		// signal arriving during the call triggers another round.
		while (GotSIGUSR1)
		{
			GotSIGUSR1 = false;
			proxy->signaled();
		}
		CHECK_FOR_INTERRUPTS();
	}
	// TODO: Use before_shmem_exit()
	auto shutdownDeadline =
		std::chrono::system_clock::now() + std::chrono::microseconds(10);
	return server.Shutdown(&shutdownDeadline);
}
} // namespace
static void
afs_executor_before_shmem_exit(int code, Datum arg)
{
// TODO: This doesn't work. We need to improve
// BackgroundWorkerInitializeConnection() failed case.
auto executor = reinterpret_cast<Executor*>(DatumGetPointer(arg));
executor->close();
delete executor;
}
// Entry point of an "executor" background worker.
//
// Creates an Executor from the int64 carried in `arg` (presumably the
// session ID -- confirm against Executor), opens it and then serves
// SIGUSR1 notifications until SIGTERM arrives or no latch event happens
// for arrow_flight_sql.session_timeout seconds.
extern "C" void
afs_executor(Datum arg)
{
	pqsignal(SIGTERM, afs_sigterm);
	pqsignal(SIGHUP, afs_sighup);
	pqsignal(SIGUSR1, afs_sigusr1);
#ifdef AFS_DEBUG
	pqsignal(SIGSEGV, afs_sigsegv);
#endif
	BackgroundWorkerUnblockSignals();
	auto executor = new Executor(DatumGetInt64(arg));
	before_shmem_exit(afs_executor_before_shmem_exit, PointerGetDatum(executor));
	executor->open();
	while (!GotSIGTERM)
	{
		int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH;
		// SessionTimeout is in seconds (-1: no timeout) but WaitLatch()
		// takes milliseconds. The GUC accepts values up to INT_MAX, so the
		// conversion is done in 64 bits to avoid signed int overflow, and
		// the result is capped at INT_MAX milliseconds.
		// NOTE(review): the cap assumes WaitLatch() expects the timeout
		// to fit in an int -- confirm for the supported PostgreSQL
		// versions.
		long timeout = -1;
		if (SessionTimeout >= 0)
		{
			const int64_t timeoutMs = static_cast<int64_t>(SessionTimeout) * 1000;
			timeout = static_cast<long>(timeoutMs > INT_MAX ? INT_MAX : timeoutMs);
			events |= WL_TIMEOUT;
		}
		auto conditions = WaitLatch(MyLatch, events, timeout, PG_WAIT_EXTENSION);
		if (conditions & WL_TIMEOUT)
		{
			// Nothing happened within the session timeout: finish this
			// executor.
			break;
		}
		ResetLatch(MyLatch);
		if (GotSIGHUP)
		{
			GotSIGHUP = false;
			ProcessConfigFile(PGC_SIGHUP);
		}
		while (GotSIGUSR1)
		{
			// SIGUSR1 is emitted by:
			// * The "server" process requests a new request
			GotSIGUSR1 = false;
			executor->signaled();
		}
		CHECK_FOR_INTERRUPTS();
	}
	// NOTE(review): afs_executor_before_shmem_exit() calls close() again
	// (and deletes the executor) while proc_exit() runs; presumably
	// Executor::close() is safe to call twice -- confirm.
	executor->close();
	proc_exit(0);
}
// Entry point of the "server" background worker: installs the signal
// handlers, runs the Flight SQL server with a Proxy that lives only for
// the duration of the run and reports a failed status as an ERROR.
extern "C" void
afs_server(Datum arg)
{
	pqsignal(SIGTERM, afs_sigterm);
	pqsignal(SIGHUP, afs_sighup);
	pqsignal(SIGUSR1, afs_sigusr1);
#ifdef AFS_DEBUG
	pqsignal(SIGSEGV, afs_sigsegv);
#endif
	BackgroundWorkerUnblockSignals();
	{
		// The Proxy is local to the lambda, so it is already destroyed
		// when the status is inspected below.
		const arrow::Status serverStatus = [] {
			Proxy proxy;
			return afs_server_internal(&proxy);
		}();
		if (!serverStatus.ok())
		{
			ereport(ERROR,
			        errcode(ERRCODE_INTERNAL_ERROR),
			        errmsg("%s: server: failed: %s", Tag, serverStatus.ToString().c_str()));
		}
	}
	proc_exit(0);
}
// Entry point of the "main" background worker registered by _PG_init().
//
// Starts the "server" worker through MainProcessor::start_server(), then
// loops on the process latch until SIGTERM: SIGHUP reloads the
// configuration and SIGUSR1 is forwarded to MainProcessor::signaled().
// After SIGTERM it waits for the "server" worker to shut down before
// exiting.
extern "C" void
afs_main(Datum arg)
{
pqsignal(SIGTERM, afs_sigterm);
pqsignal(SIGHUP, afs_sighup);
pqsignal(SIGUSR1, afs_sigusr1);
BackgroundWorkerUnblockSignals();
{
// Scoped so that the processor is destroyed before proc_exit().
MainProcessor processor;
auto serverHandle = processor.start_server();
while (!GotSIGTERM)
{
// No timeout: wake up only on latch set or postmaster death.
WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
ResetLatch(MyLatch);
if (GotSIGHUP)
{
GotSIGHUP = false;
ProcessConfigFile(PGC_SIGHUP);
}
while (GotSIGUSR1)
{
// SIGUSR1 is emitted when:
// * The "server" process has exited (may be crashed)
// * The "executor" process has exited (may be crashed)
// * The "server" process requests a new session
//
// The flag is cleared before the processor is called so that a
// signal arriving during the call triggers another round.
GotSIGUSR1 = false;
processor.signaled();
}
CHECK_FOR_INTERRUPTS();
}
WaitForBackgroundWorkerShutdown(serverHandle);
}
proc_exit(0);
}
// Module initialization. Does nothing unless the library is being loaded
// through shared_preload_libraries. Otherwise it:
//   * defines the arrow_flight_sql.* configuration parameters,
//   * arranges the shared memory request (through shmem_request_hook
//     when AFS_HAVE_SHMEM_REQUEST_HOOK is defined, directly otherwise),
//   * registers the "main" background worker whose entry point is
//     afs_main().
extern "C" void
_PG_init(void)
{
	if (!process_shared_preload_libraries_in_progress)
	{
		return;
	}

	// The GUC machinery keeps the description pointers it is given
	// instead of copying the strings, so the generated long description
	// must outlive this call: a pointer into a temporary std::string
	// would dangle as soon as DefineCustomStringVariable() returns.
	static const std::string uriLongDescription = std::string("default: ") + URIDefault;
	DefineCustomStringVariable("arrow_flight_sql.uri",
	                           "Apache Arrow Flight SQL endpoint URI.",
	                           uriLongDescription.c_str(),
	                           &URI,
	                           URIDefault,
	                           PGC_POSTMASTER,
	                           0,
	                           nullptr,
	                           nullptr,
	                           nullptr);

	DefineCustomIntVariable("arrow_flight_sql.session_timeout",
	                        "Maximum session duration in seconds.",
	                        "The default is 300 seconds. "
	                        "-1 means no timeout.",
	                        &SessionTimeout,
	                        SessionTimeoutDefault,
	                        -1,
	                        INT_MAX,
	                        PGC_USERSET,
	                        GUC_UNIT_S,
	                        nullptr,
	                        nullptr,
	                        nullptr);

	DefineCustomIntVariable("arrow_flight_sql.max_n_rows_per_record_batch",
	                        "The maximum number of rows per record batch.",
	                        "The default is 1 * 1024 * 1024 rows.",
	                        &MaxNRowsPerRecordBatch,
	                        MaxNRowsPerRecordBatchDefault,
	                        1,
	                        INT_MAX,
	                        PGC_USERSET,
	                        0,
	                        nullptr,
	                        nullptr,
	                        nullptr);

#ifdef AFS_HAVE_SHMEM_REQUEST_HOOK
	// Chain onto the hook; the previous one is kept in
	// PreviousShmemRequestHook.
	PreviousShmemRequestHook = shmem_request_hook;
	shmem_request_hook = afs_shmem_request_hook;
#else
	// No shmem_request_hook in this PostgreSQL version: request the
	// shared memory right away.
	afs_shmem_request_hook();
#endif

	BackgroundWorker worker = {};
	snprintf(worker.bgw_name, BGW_MAXLEN, "%s: main", Tag);
	snprintf(worker.bgw_type, BGW_MAXLEN, "%s: main", Tag);
	worker.bgw_flags = BGWORKER_SHMEM_ACCESS;
	worker.bgw_start_time = BgWorkerStart_ConsistentState;
	worker.bgw_restart_time = BGW_NEVER_RESTART;
	snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
	snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_main");
	worker.bgw_main_arg = 0;
	worker.bgw_notify_pid = 0;
	RegisterBackgroundWorker(&worker);
}