blob: 44f8d434fa67b08734cb8f4eb25516fbb0b1c3cc [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
extern "C"
{
#include <postgres.h>
#include <access/xact.h>
#include <executor/spi.h>
#include <fmgr.h>
#include <lib/dshash.h>
#include <libpq/crypt.h>
#include <libpq/libpq-be.h>
#include <libpq/libpq.h>
#include <miscadmin.h>
#include <postmaster/bgworker.h>
#include <postmaster/postmaster.h>
#include <storage/ipc.h>
#include <storage/latch.h>
#include <storage/lwlock.h>
#include <storage/procsignal.h>
#include <storage/shmem.h>
#include <utils/backend_status.h>
#include <utils/builtins.h>
#include <utils/dsa.h>
#include <utils/guc.h>
#include <utils/memutils.h>
#include <utils/snapmgr.h>
#include <utils/timestamp.h>
#include <utils/wait_event.h>
}
#undef Abs
#if PG_VERSION_NUM >= 150000
# define PGRN_HAVE_SHMEM_REQUEST_HOOK
#endif
#include <arrow/buffer.h>
#include <arrow/builder.h>
#include <arrow/flight/server_middleware.h>
#include <arrow/flight/sql/server.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/reader.h>
#include <arrow/ipc/writer.h>
#include <arrow/table_builder.h>
#include <arrow/util/base64.h>
#include <algorithm>
#include <cinttypes>
#include <condition_variable>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <map>
#include <mutex>
#include <random>
#include <sstream>
#include <type_traits>
#include <arpa/inet.h>
// AFS_FUNC names the enclosing function in debug output: GCC/Clang's
// __PRETTY_FUNCTION__ includes the full signature, other compilers fall back
// to the standard __func__.
#ifdef __GNUC__
# define AFS_FUNC __PRETTY_FUNCTION__
#else
# define AFS_FUNC __func__
#endif
// P(...) logs a printf-style message at DEBUG5 only when built with
// AFS_DEBUG. Otherwise it expands to nothing, so its arguments are not
// evaluated at all.
#ifdef AFS_DEBUG
# define P(...) ereport(DEBUG5, errmsg_internal(__VA_ARGS__))
#else
# define P(...)
#endif
// Entry points that need C linkage: the module magic block, the module init
// hook, and (presumably) the background worker main functions registered by
// name elsewhere in this file.
extern "C"
{
PG_MODULE_MAGIC;
extern PGDLLEXPORT void _PG_init(void);
extern PGDLLEXPORT void afs_executor(Datum datum) pg_attribute_noreturn();
extern PGDLLEXPORT void afs_server(Datum datum) pg_attribute_noreturn();
extern PGDLLEXPORT void afs_main(Datum datum) pg_attribute_noreturn();
}
namespace {
// Fixed identifiers. The pointers themselves are const so these names can
// never be re-pointed at runtime by accident.
static const char* const LibraryName = "arrow_flight_sql";
static const char* const SharedDataName = "arrow-flight-sql: shared data";
static const char* const Tag = "arrow-flight-sql";
// Endpoint URI: default value and the variable holding the configured value.
static const char* const URIDefault = "grpc://127.0.0.1:15432";
static char* URI;
// Session timeout: default value and configured value (unit presumably
// seconds — confirm against the GUC definition).
static const int SessionTimeoutDefault = 300;
static int SessionTimeout;
// Upper bound of rows per Arrow record batch: default and configured value.
static const int MaxNRowsPerRecordBatchDefault = 1 * 1024 * 1024;
static int MaxNRowsPerRecordBatch;
static volatile sig_atomic_t GotSIGTERM = false;
void
afs_sigterm(SIGNAL_ARGS)
{
auto errnoSaved = errno;
GotSIGTERM = true;
SetLatch(MyLatch);
errno = errnoSaved;
}
static volatile sig_atomic_t GotSIGHUP = false;
void
afs_sighup(SIGNAL_ARGS)
{
auto errnoSaved = errno;
GotSIGHUP = true;
SetLatch(MyLatch);
errno = errnoSaved;
}
static volatile sig_atomic_t GotSIGUSR1 = false;
void
afs_sigusr1(SIGNAL_ARGS)
{
procsignal_sigusr1_handler(postgres_signal_arg);
auto errnoSaved = errno;
GotSIGUSR1 = true;
SetLatch(MyLatch);
errno = errnoSaved;
}
#ifdef PGRN_HAVE_SHMEM_REQUEST_HOOK
// Hook installed before ours; chained so other extensions keep working.
static shmem_request_hook_type PreviousShmemRequestHook = nullptr;
#endif
// Name of the named LWLock tranche; its single lock becomes Processor::lock_.
static const char* LWLockTrancheName = "arrow-flight-sql: lwlock tranche";
// Ask PostgreSQL for the shared resources this module needs (one named
// LWLock), after giving any previously installed hook its turn.
void
afs_shmem_request_hook(void)
{
#ifdef PGRN_HAVE_SHMEM_REQUEST_HOOK
    if (PreviousShmemRequestHook != nullptr)
    {
        PreviousShmemRequestHook();
    }
#endif
    const int nLocks = 1;
    RequestNamedLWLockTranche(LWLockTrancheName, nLocks);
}
// Switches CurrentMemoryContext to `memoryContext` for the lifetime of this
// object; on destruction switches back and deletes `memoryContext`.
// The object owns the context, so it is non-copyable (a copy would delete
// the same context twice). Like any C++ destructor, ~ScopedMemoryContext()
// does not run if a PostgreSQL error longjmps past it.
class ScopedMemoryContext {
public:
    explicit ScopedMemoryContext(MemoryContext memoryContext)
        : memoryContext_(memoryContext),
          oldMemoryContext_(MemoryContextSwitchTo(memoryContext))
    {
    }
    ~ScopedMemoryContext()
    {
        MemoryContextSwitchTo(oldMemoryContext_);
        MemoryContextDelete(memoryContext_);
    }
    ScopedMemoryContext(const ScopedMemoryContext&) = delete;
    ScopedMemoryContext& operator=(const ScopedMemoryContext&) = delete;
private:
    MemoryContext memoryContext_;
    MemoryContext oldMemoryContext_;
};
// Scope guards for PostgreSQL resources. Each is non-copyable: a copy would
// commit the transaction, pop the snapshot or free the plan twice.

// Starts a transaction command on construction, commits it on destruction.
struct ScopedTransaction {
    ScopedTransaction() { StartTransactionCommand(); }
    ~ScopedTransaction() { CommitTransactionCommand(); }
    ScopedTransaction(const ScopedTransaction&) = delete;
    ScopedTransaction& operator=(const ScopedTransaction&) = delete;
};
// Pushes the transaction snapshot as the active snapshot for the scope.
struct ScopedSnapshot {
    ScopedSnapshot() { PushActiveSnapshot(GetTransactionSnapshot()); }
    ~ScopedSnapshot() { PopActiveSnapshot(); }
    ScopedSnapshot(const ScopedSnapshot&) = delete;
    ScopedSnapshot& operator=(const ScopedSnapshot&) = delete;
};
// Frees an SPI plan at end of scope; a null plan is ignored.
struct ScopedPlan {
    ScopedPlan(SPIPlanPtr plan) : plan_(plan) {}
    ~ScopedPlan()
    {
        if (plan_)
        {
            SPI_freeplan(plan_);
        }
    }
    ScopedPlan(const ScopedPlan&) = delete;
    ScopedPlan& operator=(const ScopedPlan&) = delete;
    SPIPlanPtr plan_;
};
// Control fields of a SharedRingBuffer. Plain data so it can be embedded in
// a dshash entry (SessionData::bufferData) in shared memory.
struct SharedRingBufferData {
// DSA allocation holding the bytes, or InvalidDsaPointer when unallocated.
dsa_pointer pointer;
// Allocated size in bytes; at most total - 1 bytes can be stored.
size_t total;
// Offset of the next byte to read.
size_t head;
// Offset of the next byte to write.
size_t tail;
};
// Byte ring buffer whose storage is a DSA allocation, so processes attached
// to the same DSA area can exchange data through it. The control fields live
// in a SharedRingBufferData (e.g. SessionData::bufferData); this class is a
// lightweight view over them and does no locking itself.
//
// One byte of storage is always left unused so that head == tail means
// "empty": with `total` bytes allocated, at most total - 1 bytes are stored.
// (Naive implementation; can be improved later.)
class SharedRingBuffer {
public:
    // Mark `data` as having no storage.
    static void initialize_data(SharedRingBufferData* data)
    {
        data->pointer = InvalidDsaPointer;
        data->total = 0;
        data->head = 0;
        data->tail = 0;
    }
    // Free the storage of `data` (if any) and mark it as having none.
    static void free_data(SharedRingBufferData* data, dsa_area* area)
    {
        if (data->pointer != InvalidDsaPointer)
            dsa_free(area, data->pointer);
        initialize_data(data);
    }
    SharedRingBuffer(SharedRingBufferData* data, dsa_area* area)
        : data_(data), area_(area)
    {
    }
    // Allocate `total` bytes of storage and empty the buffer. Storage from an
    // earlier allocation is released first so that it is not leaked.
    void allocate(size_t total)
    {
        free_data(data_, area_);
        data_->pointer = dsa_allocate(area_, total);
        data_->total = total;
        data_->head = 0;
        data_->tail = 0;
    }
    // Release the storage; the buffer becomes unallocated.
    void free() { free_data(data_, area_); }
    // Number of bytes that can currently be read.
    size_t size() const
    {
        if (data_->head <= data_->tail)
        {
            return data_->tail - data_->head;
        }
        else
        {
            return (data_->total - data_->head) + data_->tail;
        }
    }
    // Number of bytes that can currently be written. Zero when no storage is
    // allocated (total - size() - 1 would otherwise underflow).
    size_t rest_size() const
    {
        if (data_->total == 0)
        {
            return 0;
        }
        return data_->total - size() - 1;
    }
    // Copy up to `n` bytes from `data` into the buffer, wrapping around the
    // end of the storage if needed. Returns the number of bytes written,
    // which is less than `n` when the buffer fills up (0 when full).
    size_t write(const void* data, size_t n)
    {
        P("%s: %s: before: (%zu:%zu) %zu", Tag, AFS_FUNC, data_->head, data_->tail, n);
        if (rest_size() == 0)
        {
            P("%s: %s: after: no space: (%zu:%zu) %zu:0",
              Tag,
              AFS_FUNC,
              data_->head,
              data_->tail,
              n);
            return 0;
        }
        size_t writtenSize = 0;
        auto output = address();
        if (data_->head <= data_->tail)
        {
            // Free space runs from tail to the end of the storage. When head
            // is at offset 0 the last byte must stay unused, otherwise tail
            // would wrap onto head and the buffer would look empty.
            auto restSize = data_->total - data_->tail;
            if (data_->head == 0)
            {
                restSize--;
            }
            const auto firstHalfWriteSize = std::min(n, restSize);
            P("%s: %s: first half: (%zu:%zu) %zu:%zu",
              Tag,
              AFS_FUNC,
              data_->head,
              data_->tail,
              n,
              firstHalfWriteSize);
            memcpy(output + data_->tail, data, firstHalfWriteSize);
            data_->tail = (data_->tail + firstHalfWriteSize) % data_->total;
            n -= firstHalfWriteSize;
            writtenSize += firstHalfWriteSize;
        }
        if (n > 0 && rest_size() > 0)
        {
            // Here tail < head: free space runs from tail up to head - 1.
            const auto lastHalfWriteSize = std::min(n, data_->head - data_->tail - 1);
            P("%s: %s: last half: (%zu:%zu) %zu:%zu",
              Tag,
              AFS_FUNC,
              data_->head,
              data_->tail,
              n,
              lastHalfWriteSize);
            memcpy(output + data_->tail,
                   static_cast<const uint8_t*>(data) + writtenSize,
                   lastHalfWriteSize);
            data_->tail += lastHalfWriteSize;
            n -= lastHalfWriteSize;
            writtenSize += lastHalfWriteSize;
        }
        P("%s: %s: after: (%zu:%zu) %zu:%zu",
          Tag,
          AFS_FUNC,
          data_->head,
          data_->tail,
          n,
          writtenSize);
        return writtenSize;
    }
    // Copy up to `n` bytes out of the buffer into `output`, wrapping around
    // the end of the storage if needed. Returns the number of bytes read
    // (0 when the buffer is empty).
    size_t read(size_t n, void* output)
    {
        P("%s: %s: before: (%zu:%zu) %zu", Tag, AFS_FUNC, data_->head, data_->tail, n);
        size_t readSize = 0;
        const auto input = address();
        if (data_->head > data_->tail)
        {
            // Data wraps: first read from head to the end of the storage.
            const auto firstHalfReadSize = std::min(n, data_->total - data_->head);
            P("%s: %s: first half: (%zu:%zu) %zu:%zu",
              Tag,
              AFS_FUNC,
              data_->head,
              data_->tail,
              n,
              firstHalfReadSize);
            memcpy(output, input + data_->head, firstHalfReadSize);
            data_->head = (data_->head + firstHalfReadSize) % data_->total;
            n -= firstHalfReadSize;
            readSize += firstHalfReadSize;
        }
        if (n > 0 && data_->head != data_->tail)
        {
            // Here head < tail: data runs from head up to tail - 1.
            const auto lastHalfReadSize = std::min(n, data_->tail - data_->head);
            P("%s: %s: last half: (%zu:%zu) %zu:%zu",
              Tag,
              AFS_FUNC,
              data_->head,
              data_->tail,
              n,
              lastHalfReadSize);
            memcpy(static_cast<uint8_t*>(output) + readSize,
                   input + data_->head,
                   lastHalfReadSize);
            data_->head += lastHalfReadSize;
            n -= lastHalfReadSize;
            readSize += lastHalfReadSize;
        }
        P("%s: %s: after: (%zu:%zu) %zu:%zu",
          Tag,
          AFS_FUNC,
          data_->head,
          data_->tail,
          n,
          readSize);
        return readSize;
    }
private:
    SharedRingBufferData* data_;
    dsa_area* area_;
    // Local address of the storage in this process's mapping of the area.
    uint8_t* address()
    {
        return static_cast<uint8_t*>(dsa_get_address(area_, data_->pointer));
    }
};
// Pending request for the executor. Stored in SessionData::action (shared
// memory) and dispatched by Executor::signaled(), which maps each value to
// the Executor method of the same name.
enum class Action
{
// No pending request; Executor::signaled() resets the field to this value
// after reading it.
None,
Select,
Update,
Prepare,
ClosePreparedStatement,
SetParameters,
SelectPreparedStatement,
UpdatePreparedStatement,
};
// Human-readable name of `action` for log and error messages; values outside
// the enumeration yield "Action::Unknown".
const char*
action_name(Action action)
{
    const char* name = "Action::Unknown";
    switch (action)
    {
        case Action::None:
            name = "Action::None";
            break;
        case Action::Select:
            name = "Action::Select";
            break;
        case Action::Update:
            name = "Action::Update";
            break;
        case Action::Prepare:
            name = "Action::Prepare";
            break;
        case Action::ClosePreparedStatement:
            name = "Action::ClosePreparedStatement";
            break;
        case Action::SetParameters:
            name = "Action::SetParameters";
            break;
        case Action::SelectPreparedStatement:
            name = "Action::SelectPreparedStatement";
            break;
        case Action::UpdatePreparedStatement:
            name = "Action::UpdatePreparedStatement";
            break;
    }
    return name;
}
// Replace the NUL-terminated string stored at `pointer` in `area` with a copy
// of `input`, freeing any previous allocation. An empty `input` leaves
// `pointer` as InvalidDsaPointer instead of allocating, so readers must treat
// an invalid pointer as "no value".
void
dsa_pointer_set_string(dsa_pointer& pointer, dsa_area* area, const std::string& input)
{
    if (DsaPointerIsValid(pointer))
    {
        dsa_free(area, pointer);
    }
    pointer = InvalidDsaPointer;
    if (!input.empty())
    {
        const auto nBytes = input.size() + 1; // Include the terminating NUL.
        pointer = dsa_allocate(area, nBytes);
        memcpy(dsa_get_address(area, pointer), input.c_str(), nBytes);
    }
}
// Put only data (don't add methods) to use with dshash.
// Entries live in a dshash table in shared memory. The first member is the
// key (SessionsParams uses sizeof(uint64_t) with dshash_memcmp/memhash).
// String fields are DSA pointers written with dsa_pointer_set_string(), so
// an empty string is stored as InvalidDsaPointer.
struct SessionData {
// Session ID; the dshash key.
uint64_t id;
// First error reported by the executor (Executor::set_error_message), or
// InvalidDsaPointer when none.
dsa_pointer errorMessage;
// Executor process PID; InvalidPid after session_data_initialize().
pid_t executorPID;
// Set by Executor::open()/close_internal() once the executor finished
// connecting, successfully or not.
bool initialized;
// Connection parameters set by session_data_initialize().
dsa_pointer databaseName;
dsa_pointer userName;
dsa_pointer password;
dsa_pointer clientAddress;
// Pending request for the executor; reset to Action::None by
// Executor::signaled() when it picks the request up.
Action action;
// Presumably the SQL text for the Select/Update actions — confirm.
dsa_pointer selectQuery;
dsa_pointer updateQuery;
// Number of affected rows reported for an update; -1 after initialization.
int64_t nUpdatedRecords;
// Presumably the SQL text / handle for prepared statement actions — confirm.
dsa_pointer prepareQuery;
dsa_pointer preparedStatementHandle;
// Ring buffer used to stream data between executor and server.
SharedRingBufferData bufferData;
};
// Initialize a session entry with its connection parameters and an empty
// request/result state.
//
// The entry is assumed to be freshly inserted into the dshash table, whose
// non-key fields dshash leaves uninitialized. Every DSA pointer is therefore
// reset to InvalidDsaPointer *before* dsa_pointer_set_string() runs. That
// helper frees any valid-looking pointer it is given, so passing it garbage
// could free unrelated DSA memory.
void
session_data_initialize(SessionData* session,
                        dsa_area* area,
                        const std::string& databaseName,
                        const std::string& userName,
                        const std::string& password,
                        const std::string& clientAddress)
{
    session->errorMessage = InvalidDsaPointer;
    session->executorPID = InvalidPid;
    session->initialized = false;
    session->databaseName = InvalidDsaPointer;
    session->userName = InvalidDsaPointer;
    session->password = InvalidDsaPointer;
    session->clientAddress = InvalidDsaPointer;
    dsa_pointer_set_string(session->databaseName, area, databaseName);
    dsa_pointer_set_string(session->userName, area, userName);
    dsa_pointer_set_string(session->password, area, password);
    dsa_pointer_set_string(session->clientAddress, area, clientAddress);
    session->action = Action::None;
    session->selectQuery = InvalidDsaPointer;
    session->updateQuery = InvalidDsaPointer;
    session->nUpdatedRecords = -1;
    session->prepareQuery = InvalidDsaPointer;
    session->preparedStatementHandle = InvalidDsaPointer;
    SharedRingBuffer::initialize_data(&(session->bufferData));
}
// Release every DSA allocation owned by `session`: its strings and its ring
// buffer storage. Freed pointers are reset to InvalidDsaPointer.
void
session_data_finalize(SessionData* session, dsa_area* area)
{
    auto release = [area](dsa_pointer& pointer) {
        if (DsaPointerIsValid(pointer))
        {
            dsa_free(area, pointer);
            pointer = InvalidDsaPointer;
        }
    };
    release(session->errorMessage);
    release(session->databaseName);
    release(session->userName);
    release(session->password);
    // Allocated by session_data_initialize() like the other connection
    // parameters, so it must be released here too.
    release(session->clientAddress);
    release(session->selectQuery);
    release(session->updateQuery);
    release(session->prepareQuery);
    release(session->preparedStatementHandle);
    SharedRingBuffer::free_data(&(session->bufferData), area);
}
// Releases the dshash lock held on `data` (an entry of `sessions`) when the
// guard goes out of scope. Non-copyable so the lock is released exactly once.
class SessionReleaser {
public:
    explicit SessionReleaser(dshash_table* sessions, SessionData* data)
        : sessions_(sessions), data_(data)
    {
    }
    ~SessionReleaser() { dshash_release_lock(sessions_, data_); }
    SessionReleaser(const SessionReleaser&) = delete;
    SessionReleaser& operator=(const SessionReleaser&) = delete;
private:
    dshash_table* sessions_;
    SessionData* data_;
};
// dshash parameters of the session table. Entries are SessionData, keyed by
// their leading uint64_t session ID. tranche_id is filled in at runtime.
//
// PostgreSQL 17 added a copy_function member between hash_function and
// tranche_id. Without the extra initializer, this positional initializer
// would leave copy_function NULL there.
static dshash_parameters SessionsParams = {
    sizeof(uint64_t),
    sizeof(SessionData),
    dshash_memcmp,
    dshash_memhash,
#if PG_VERSION_NUM >= 170000
    dshash_memcpy,
#endif
    0, // Set later because this is determined dynamically.
};
// Fixed-size state published in the main shared memory segment under
// SharedDataName; workers use it to attach to the dynamic shared state.
struct SharedData {
int trancheID;
// Handle of the DSA area (dsa_attach() in WorkerProcessor).
dsa_handle handle;
// LWLock tranche used for the sessions dshash table.
int sessionsTrancheID;
// Handle of the sessions dshash table (dshash_attach() in WorkerProcessor).
dshash_table_handle sessionsHandle;
// PID of the server process that executors signal; InvalidPid when none.
pid_t serverPID;
pid_t mainPID;
};
// Base class of the actors that cooperate through SessionData entries in
// shared memory. It holds the DSA attachment (area_), the LWLock that guards
// session fields (lock_), and the machinery for waiting until a peer process
// has consumed or produced ring-buffer data.
class Processor {
public:
    // What wait() watches: Read waits until rest_size() changes, Written
    // waits until size() changes.
    enum class WaitMode
    {
        Read,
        Written,
    };
    // `tag` labels messages. `runInPGThread` selects how wait() blocks:
    // true = the process latch, false = conditionVariable_, which is woken
    // by signaled().
    Processor(const char* tag, bool runInPGThread)
        : tag_(tag),
          runInPGThread_(runInPGThread),
          sharedData_(nullptr),
          area_(nullptr),
          lock_(),
          mutex_(),
          conditionVariable_()
    {
    }
    virtual ~Processor()
    {
        if (area_)
        {
            dsa_detach(area_);
        }
    }
    const char* tag() { return tag_; }
    void lock_acquire() { LWLockAcquire(lock_, LW_EXCLUSIVE); }
    void lock_release() { LWLockRelease(lock_); }
    SharedRingBuffer create_shared_ring_buffer(SessionData* session)
    {
        return SharedRingBuffer(&(session->bufferData), area_);
    }
    // Send SIGUSR1 to the peer process, then block until the watched size of
    // `buffer` (see WaitMode) differs from its value at entry.
    //
    // Returns IOError when the peer PID is InvalidPid. It also returns OK
    // *without* the buffer having changed in two cases: SIGTERM was received
    // (PG-thread mode), or interrupts are pending (thread mode). Callers must
    // re-check the buffer.
    arrow::Status wait(SessionData* session, SharedRingBuffer* buffer, WaitMode mode)
    {
        const bool read = (mode == WaitMode::Read);
        const char* tag = read ? "wait read" : "wait written";
        auto peerPID = peer_pid(session);
        auto peerName = peer_name(session);
        if (ARROW_PREDICT_FALSE(peerPID == InvalidPid))
        {
            return arrow::Status::IOError(
                Tag, ": ", tag_, ": ", tag, ": ", peerName, ": not alive");
        }
        P("%s: %s: %s: %s: kill: %d", Tag, tag_, tag, peerName, peerPID);
        kill(peerPID, SIGUSR1);
        // The two lambdas have distinct closure types. Unary + converts each
        // to a plain function pointer so the conditional operator has a
        // common type on every compiler.
        auto get_target_size =
            read ? +[](SharedRingBuffer* buffer) { return buffer->rest_size(); }
                 : +[](SharedRingBuffer* buffer) { return buffer->size(); };
        auto targetSize = get_target_size(buffer);
        if (runInPGThread_)
        {
            while (true)
            {
                int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH;
                WaitLatch(MyLatch, events, -1, PG_WAIT_EXTENSION);
                if (GotSIGTERM)
                {
                    break;
                }
                ResetLatch(MyLatch);
                if (GotSIGUSR1)
                {
                    GotSIGUSR1 = false;
                    P("%s: %s: %s: %s: wait: %zu:%zu",
                      Tag,
                      tag_,
                      tag,
                      peerName,
                      get_target_size(buffer),
                      targetSize);
                    if (get_target_size(buffer) != targetSize)
                    {
                        break;
                    }
                }
                // TODO: Convert PG error to arrow::Status.
                CHECK_FOR_INTERRUPTS();
            }
        }
        else
        {
            std::unique_lock<std::mutex> lock(mutex_);
            conditionVariable_.wait(lock, [&] {
                P("%s: %s: %s: %s: wait: %zu:%zu",
                  Tag,
                  tag_,
                  tag,
                  peerName,
                  get_target_size(buffer),
                  targetSize);
                if (INTERRUPTS_PENDING_CONDITION())
                {
                    return true;
                }
                return get_target_size(buffer) != targetSize;
            });
        }
        return arrow::Status::OK();
    }
    // Wake threads blocked in wait(). mutex_ is held while notifying so the
    // notification cannot fall between a waiter evaluating its predicate and
    // going to sleep (a lost wake-up that would block the waiter forever).
    virtual void signaled()
    {
        P("%s: %s: signaled: before", Tag, tag_);
        {
            std::lock_guard<std::mutex> lock(mutex_);
            conditionVariable_.notify_all();
        }
        P("%s: %s: signaled: after", Tag, tag_);
    }
protected:
    // Store a copy of `input` in `pointer` inside this processor's DSA area.
    void set_shared_string(dsa_pointer& pointer, const std::string& input)
    {
        dsa_pointer_set_string(pointer, area_, input);
    }
    // Peer process to signal/wait for; overridden by subclasses.
    virtual pid_t peer_pid(SessionData* session) { return InvalidPid; }
    virtual const char* peer_name(SessionData* session) { return "unknown"; }
    const char* tag_;
    bool runInPGThread_;
    SharedData* sharedData_;
    dsa_area* area_;
    LWLock* lock_;
    std::mutex mutex_;
    std::condition_variable conditionVariable_;
};
// Holds `processor`'s LWLock (exclusive) for the guard's lifetime.
// Non-copyable so the lock is released exactly once.
struct ProcessorLockGuard {
    explicit ProcessorLockGuard(Processor* processor) : processor_(processor)
    {
        processor_->lock_acquire();
    }
    ~ProcessorLockGuard() { processor_->lock_release(); }
    ProcessorLockGuard(const ProcessorLockGuard&) = delete;
    ProcessorLockGuard& operator=(const ProcessorLockGuard&) = delete;
private:
    Processor* processor_;
};
class Proxy;
// arrow::io::InputStream reading bytes from a session's shared ring buffer.
// The primitive Read(int64_t, void*) is defined out of line; Close() only
// marks the stream closed.
class SharedRingBufferInputStream : public arrow::io::InputStream {
public:
    SharedRingBufferInputStream(Processor* processor, SessionData* session)
        : processor_(processor), session_(session), position_(0), is_open_(true)
    {
    }
    arrow::Status Close() override
    {
        is_open_ = false;
        return arrow::Status::OK();
    }
    bool closed() const override { return !is_open_; }
    arrow::Result<int64_t> Tell() const override { return position_; }
    arrow::Result<int64_t> Read(int64_t nBytes, void* out) override;
    // Allocate nBytes, fill them through Read(nBytes, out), then shrink the
    // buffer to the number of bytes actually read.
    arrow::Result<std::shared_ptr<arrow::Buffer>> Read(int64_t nBytes) override
    {
        ARROW_ASSIGN_OR_RAISE(auto resizable, arrow::AllocateResizableBuffer(nBytes));
        ARROW_ASSIGN_OR_RAISE(auto nRead, Read(nBytes, resizable->mutable_data()));
        const bool shrinkToFit = false;
        ARROW_RETURN_NOT_OK(resizable->Resize(nRead, shrinkToFit));
        resizable->ZeroPadding();
        std::shared_ptr<arrow::Buffer> result = std::move(resizable);
        return result;
    }
private:
    Processor* processor_;
    SessionData* session_;
    int64_t position_;
    bool is_open_;
};
class Executor;
// arrow::io::OutputStream writing bytes into a session's shared ring buffer.
// The primitive Write(const void*, int64_t) is defined out of line; the
// other Write overloads come from arrow::io::OutputStream.
class SharedRingBufferOutputStream : public arrow::io::OutputStream {
public:
    SharedRingBufferOutputStream(Processor* processor, SessionData* session)
        : processor_(processor), session_(session), position_(0), is_open_(true)
    {
    }
    // Only marks the stream closed; nothing is flushed or freed.
    arrow::Status Close() override
    {
        is_open_ = false;
        return arrow::Status::OK();
    }
    bool closed() const override { return !is_open_; }
    arrow::Result<int64_t> Tell() const override { return position_; }
    using arrow::io::OutputStream::Write;
    arrow::Status Write(const void* data, int64_t nBytes) override;
private:
    Processor* processor_;
    SessionData* session_;
    int64_t position_;
    bool is_open_;
};
// Processor for processes that attach to shared state created elsewhere.
// The constructor attaches to the SharedData struct, the DSA area, the
// sessions dshash table and the named LWLock. It raises a PostgreSQL ERROR
// if the shared data has not been created yet.
class WorkerProcessor : public Processor {
public:
explicit WorkerProcessor(const char* tag, bool runInPGThread)
: Processor(tag, runInPGThread), sessions_(nullptr)
{
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
bool found;
// NOTE(review): sizeof(sharedData_) is the size of a *pointer*, not of
// SharedData. ShmemInitStruct() errors if this size differs from the one
// used when the struct was created, so it must stay in sync with the
// creating call (not visible here). Both should presumably use
// sizeof(SharedData) — fix them together.
sharedData_ = static_cast<SharedData*>(
ShmemInitStruct(SharedDataName, sizeof(sharedData_), &found));
if (!found)
{
// Release before ereport(): ERROR does not return.
LWLockRelease(AddinShmemInitLock);
ereport(ERROR,
errcode(ERRCODE_INTERNAL_ERROR),
errmsg("%s: %s: shared data isn't created yet", Tag, tag_));
}
area_ = dsa_attach(sharedData_->handle);
// SessionsParams is a file-level global; its tranche is only known at
// runtime.
SessionsParams.tranche_id = sharedData_->sessionsTrancheID;
sessions_ =
dshash_attach(area_, &SessionsParams, sharedData_->sessionsHandle, nullptr);
lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);
LWLockRelease(AddinShmemInitLock);
}
~WorkerProcessor() override { dshash_detach(sessions_); }
protected:
// Free the session's DSA allocations and remove it from the table.
// dshash_delete_entry() expects the entry to be locked by the caller.
void delete_session(SessionData* session)
{
session_data_finalize(session, area_);
dshash_delete_entry(sessions_, session);
}
protected:
dshash_table* sessions_;
};
// Maps an Arrow data type to the PostgreSQL type OID used for a prepared
// statement parameter (see PreparedStatement::prepare_pg_types()).
//
// Unsigned integers are widened to the smallest signed PostgreSQL integer
// type that holds every value: uint8 -> int2, uint16 -> int4,
// uint32 -> int8. uint64 maps to int8, which cannot represent values above
// INT64_MAX. Types without an overload here are reported as NotImplemented
// by arrow::TypeVisitor.
class ArrowPGTypeConverter : public arrow::TypeVisitor {
public:
    explicit ArrowPGTypeConverter() : oid_(InvalidOid) {}
    // OID chosen by the last Visit(); InvalidOid before any visit.
    Oid oid() const { return oid_; }
    arrow::Status Visit(const arrow::Int8Type& type)
    {
        oid_ = INT2OID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::UInt8Type& type)
    {
        oid_ = INT2OID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::Int16Type& type)
    {
        oid_ = INT2OID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::UInt16Type& type)
    {
        // uint16 values above 32767 do not fit in int2.
        oid_ = INT4OID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::Int32Type& type)
    {
        oid_ = INT4OID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::UInt32Type& type)
    {
        // uint32 values above 2147483647 do not fit in int4.
        oid_ = INT8OID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::Int64Type& type)
    {
        oid_ = INT8OID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::UInt64Type& type)
    {
        // No wider integer type exists; out-of-range values are rejected
        // when converting the value itself.
        oid_ = INT8OID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::FloatType& type)
    {
        oid_ = FLOAT4OID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::DoubleType& type)
    {
        oid_ = FLOAT8OID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::StringType& type)
    {
        oid_ = TEXTOID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::BinaryType& type)
    {
        oid_ = BYTEAOID;
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::TimestampType& type)
    {
        oid_ = TIMESTAMPOID;
        return arrow::Status::OK();
    }
private:
    Oid oid_;
};
// Converts the value at row `i_row` of an Arrow array into a PostgreSQL
// Datum, written to the `datum` reference given at construction. Integer
// Datums are built with the width of the signed PostgreSQL type that can
// hold the Arrow type's full range. The caller handles nulls before
// visiting.
class ArrowPGValueConverter : public arrow::ArrayVisitor {
public:
    explicit ArrowPGValueConverter(int64_t i_row, Datum& datum)
        : i_row_(i_row), datum_(datum)
    {
    }
    arrow::Status Visit(const arrow::Int8Array& array)
    {
        // int8 parameters are bound as int2.
        datum_ = Int16GetDatum(array.Value(i_row_));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::UInt8Array& array)
    {
        // uint8 parameters are bound as int2.
        datum_ = Int16GetDatum(static_cast<int16_t>(array.Value(i_row_)));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::Int16Array& array)
    {
        datum_ = Int16GetDatum(array.Value(i_row_));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::UInt16Array& array)
    {
        // uint16 values above 32767 only fit in int4.
        datum_ = Int32GetDatum(static_cast<int32_t>(array.Value(i_row_)));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::Int32Array& array)
    {
        datum_ = Int32GetDatum(array.Value(i_row_));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::UInt32Array& array)
    {
        // uint32 values above 2147483647 only fit in int8.
        datum_ = Int64GetDatum(static_cast<int64_t>(array.Value(i_row_)));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::Int64Array& array)
    {
        datum_ = Int64GetDatum(array.Value(i_row_));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::UInt64Array& array)
    {
        const auto value = array.Value(i_row_);
        // PostgreSQL has no unsigned 64-bit type: reject values that would
        // otherwise wrap around to negative int8 values.
        if (value > static_cast<uint64_t>(INT64_MAX))
        {
            return arrow::Status::Invalid("uint64 value is out of int8 range: ", value);
        }
        datum_ = Int64GetDatum(static_cast<int64_t>(value));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::FloatArray& array)
    {
        datum_ = Float4GetDatum(array.Value(i_row_));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::DoubleArray& array)
    {
        datum_ = Float8GetDatum(array.Value(i_row_));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::StringArray& array)
    {
        auto value = array.GetView(i_row_);
        datum_ = PointerGetDatum(cstring_to_text_with_len(value.data(), value.length()));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::BinaryArray& array)
    {
        // text and bytea share the varlena representation, so the text
        // constructor also produces a valid bytea value.
        auto value = array.GetView(i_row_);
        datum_ = PointerGetDatum(cstring_to_text_with_len(value.data(), value.length()));
        return arrow::Status::OK();
    }
    arrow::Status Visit(const arrow::TimestampArray& array)
    {
        const auto unit =
            static_cast<const arrow::TimestampType&>(*array.type()).unit();
        Timestamp value = 0;
        switch (unit)
        {
            case arrow::TimeUnit::SECOND:
                value += array.Value(i_row_) * 1000000;
                break;
            case arrow::TimeUnit::MILLI:
                value += array.Value(i_row_) * 1000;
                break;
            case arrow::TimeUnit::MICRO:
                value += array.Value(i_row_);
                break;
            case arrow::TimeUnit::NANO:
                value += array.Value(i_row_) / 1000;
                break;
            default:
                return arrow::Status::NotImplemented("Unsupported time unit: ", unit);
        }
        // Arrow uses UNIX epoch (1970-01-01) but PostgreSQL uses 2000-01-01.
        value -= (POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * USECS_PER_DAY;
        datum_ = TimestampGetDatum(value);
        return arrow::Status::OK();
    }
private:
    int64_t i_row_;
    Datum& datum_;
};
class PGArrowValueConverter : public arrow::ArrayVisitor {
public:
explicit PGArrowValueConverter(Form_pg_attribute attribute) : attribute_(attribute) {}
arrow::Result<std::shared_ptr<arrow::DataType>> convert_type() const
{
switch (attribute_->atttypid)
{
case INT2OID:
return arrow::int16();
case INT4OID:
return arrow::int32();
case INT8OID:
return arrow::int64();
case FLOAT4OID:
return arrow::float32();
case FLOAT8OID:
return arrow::float64();
case VARCHAROID:
case TEXTOID:
return arrow::utf8();
case BYTEAOID:
return arrow::binary();
case TIMESTAMPOID:
return arrow::timestamp(arrow::TimeUnit::MICRO);
default:
return arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
attribute_->atttypid);
}
}
arrow::Status convert_value(arrow::ArrayBuilder* builder, Datum datum) const
{
switch (attribute_->atttypid)
{
case INT2OID:
return static_cast<arrow::Int16Builder*>(builder)->Append(
DatumGetInt16(datum));
case INT4OID:
return static_cast<arrow::Int32Builder*>(builder)->Append(
DatumGetInt32(datum));
case INT8OID:
return static_cast<arrow::Int64Builder*>(builder)->Append(
DatumGetInt64(datum));
case FLOAT4OID:
return static_cast<arrow::FloatBuilder*>(builder)->Append(
DatumGetFloat4(datum));
case FLOAT8OID:
return static_cast<arrow::DoubleBuilder*>(builder)->Append(
DatumGetFloat8(datum));
case VARCHAROID:
case TEXTOID:
return static_cast<arrow::StringBuilder*>(builder)->Append(
VARDATA_ANY(datum), VARSIZE_ANY_EXHDR(datum));
case BYTEAOID:
return static_cast<arrow::BinaryBuilder*>(builder)->Append(
VARDATA_ANY(datum), VARSIZE_ANY_EXHDR(datum));
case TIMESTAMPOID:
// Arrow uses UNIX epoch (1970-01-01) but PostgreSQL
// uses 2000-01-01.
return static_cast<arrow::TimestampBuilder*>(builder)->Append(
DatumGetTimestamp(datum) +
(POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * USECS_PER_DAY);
default:
return arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
attribute_->atttypid);
}
}
private:
Form_pg_attribute attribute_;
};
// A prepared statement: SQL text plus the parameter record batches most
// recently supplied with set_parameters(). The methods call SPI directly, so
// they assume an active SPI connection.
class PreparedStatement {
public:
    explicit PreparedStatement(std::string query)
        : query_(std::move(query)), parameters_()
    {
    }
    ~PreparedStatement() {}
    // Called after each successful execution in select() with the opaque
    // `writeData` pointer.
    using WriteFunc = std::add_pointer<arrow::Status(void*)>::type;
    // Execute the statement once per row of every stored parameter batch,
    // calling write(writeData) after each execution. Executes nothing when
    // no parameters have been set.
    arrow::Status select(WriteFunc write, void* writeData)
    {
        for (const auto& recordBatch : parameters_)
        {
            SPIExecuteOptions options = {};
            std::vector<Oid> pgTypes;
            ARROW_RETURN_NOT_OK(prepare(options, pgTypes, recordBatch->schema()));
            ARROW_ASSIGN_OR_RAISE(auto plan, prepare_plan(pgTypes));
            ScopedPlan scopedPlan(plan);
            ARROW_RETURN_NOT_OK(
                execute(plan, recordBatch, options, [&]() { return write(writeData); }));
        }
        return arrow::Status::OK();
    }
    // Replace the stored parameters with every record batch read from the
    // Arrow IPC stream `input`.
    arrow::Status set_parameters(std::shared_ptr<SharedRingBufferInputStream>& input)
    {
        parameters_.clear();
        ARROW_ASSIGN_OR_RAISE(auto reader,
                              arrow::ipc::RecordBatchStreamReader::Open(input));
        while (true)
        {
            std::shared_ptr<arrow::RecordBatch> recordBatch;
            ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
            if (!recordBatch)
            {
                break;
            }
            parameters_.push_back(std::move(recordBatch));
        }
        return arrow::Status::OK();
    }
    // Execute the statement once per parameter row read from the Arrow IPC
    // stream `input`; returns the total number of rows processed by SPI.
    arrow::Result<int64_t> update(std::shared_ptr<SharedRingBufferInputStream>& input)
    {
        ARROW_ASSIGN_OR_RAISE(auto reader,
                              arrow::ipc::RecordBatchStreamReader::Open(input));
        SPIExecuteOptions options = {};
        std::vector<Oid> pgTypes;
        ARROW_RETURN_NOT_OK(prepare(options, pgTypes, reader->schema()));
        ARROW_ASSIGN_OR_RAISE(auto plan, prepare_plan(pgTypes));
        ScopedPlan scopedPlan(plan);
        int64_t nUpdatedRecords = 0;
        while (true)
        {
            std::shared_ptr<arrow::RecordBatch> recordBatch;
            ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
            if (!recordBatch)
            {
                break;
            }
            ARROW_RETURN_NOT_OK(execute(plan, recordBatch, options, [&nUpdatedRecords]() {
                nUpdatedRecords += SPI_processed;
                return arrow::Status::OK();
            }));
        }
        return nUpdatedRecords;
    }
private:
    // Prepare query_ with the given parameter types. SPI_prepare() returns
    // NULL (with SPI_result set) on failure; that is turned into a Status
    // instead of handing a NULL plan to the executor.
    arrow::Result<SPIPlanPtr> prepare_plan(std::vector<Oid>& pgTypes)
    {
        auto plan =
            SPI_prepare(query_.c_str(), static_cast<int>(pgTypes.size()), pgTypes.data());
        if (!plan)
        {
            return arrow::Status::Invalid("failed to prepare a statement: ",
                                          SPI_result_code_string(SPI_result),
                                          ": ",
                                          query_);
        }
        return plan;
    }
    // Append the PostgreSQL type OID of every field of `schema` to pgTypes.
    arrow::Status prepare_pg_types(std::vector<Oid>& pgTypes,
                                   const std::shared_ptr<arrow::Schema>& schema)
    {
        ArrowPGTypeConverter converter;
        for (const auto& field : schema->fields())
        {
            ARROW_RETURN_NOT_OK(field->type()->Accept(&converter));
            pgTypes.push_back(converter.oid());
        }
        return arrow::Status::OK();
    }
    // Fill `options` (one constant parameter per schema field, no row limit)
    // and `pgTypes` for `schema`.
    arrow::Status prepare(SPIExecuteOptions& options,
                          std::vector<Oid>& pgTypes,
                          const std::shared_ptr<arrow::Schema>& schema)
    {
        if (schema->num_fields() > 0)
        {
            options.params = makeParamList(schema->num_fields());
        }
        options.read_only = false;
        options.tcount = 0;
        ARROW_RETURN_NOT_OK(prepare_pg_types(pgTypes, schema));
        for (size_t i = 0; i < pgTypes.size(); ++i)
        {
            options.params->params[i].pflags = PARAM_FLAG_CONST;
            options.params->params[i].ptype = pgTypes[i];
        }
        return arrow::Status::OK();
    }
    // Run `plan` once per row of `recordBatch`, binding that row as the
    // parameters and calling onSuccess() after each successful execution.
    template <typename OnSuccessFunc>
    arrow::Status execute(SPIPlanPtr plan,
                          const std::shared_ptr<arrow::RecordBatch>& recordBatch,
                          SPIExecuteOptions& options,
                          OnSuccessFunc onSuccess)
    {
        const auto& columns = recordBatch->columns();
        for (int64_t i = 0; i < recordBatch->num_rows(); ++i)
        {
            ARROW_RETURN_NOT_OK(assign_parameters(recordBatch, i, columns, options));
            auto result = SPI_execute_plan_extended(plan, &options);
            if (result <= 0)
            {
                return arrow::Status::Invalid("failed to run a prepared statement: ",
                                              SPI_result_code_string(result),
                                              ": ",
                                              query_);
            }
            ARROW_RETURN_NOT_OK(onSuccess());
        }
        return arrow::Status::OK();
    }
    // Copy row `i_row` of `columns` into options.params (null flag + value).
    arrow::Status assign_parameters(
        const std::shared_ptr<arrow::RecordBatch>& recordBatch,
        int64_t i_row,
        const std::vector<std::shared_ptr<arrow::Array>>& columns,
        SPIExecuteOptions& options)
    {
        int64_t i_column = 0;
        for (const auto& column : columns)
        {
            auto param = &(options.params->params[i_column]);
            param->isnull = column->IsNull(i_row);
            if (!param->isnull)
            {
                ArrowPGValueConverter converter(i_row, param->value);
                ARROW_RETURN_NOT_OK(column->Accept(&converter));
            }
            ++i_column;
        }
        return arrow::Status::OK();
    }
    std::string query_;
    std::vector<std::shared_ptr<arrow::RecordBatch>> parameters_;
};
class Executor : public WorkerProcessor {
public:
explicit Executor(uint64_t sessionID)
    : WorkerProcessor("executor", true),
      sessionID_(sessionID),
      session_(nullptr),
      connected_(false),
      closed_(false),
      nextPreparedStatementID_(1),
      preparedStatements_()
{
    // Attaching to the session itself happens in open().
}
// Clean up only if close() was never called: close_internal() sets closed_,
// so it never runs twice. The destructor passes unlockSession = false.
~Executor() override
{
    if (closed_)
    {
        return;
    }
    close_internal(false);
}
void open()
{
const char* tag = "open";
// pg_usleep(5000000);
// pg_usleep(5000000);
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": opening").c_str());
session_ = static_cast<SessionData*>(dshash_find(sessions_, &sessionID_, false));
auto databaseName =
static_cast<const char*>(dsa_get_address(area_, session_->databaseName));
auto userName =
static_cast<const char*>(dsa_get_address(area_, session_->userName));
auto password =
static_cast<const char*>(dsa_get_address(area_, session_->password));
auto clientAddress =
static_cast<const char*>(dsa_get_address(area_, session_->clientAddress));
BackgroundWorkerInitializeConnection(databaseName, userName, 0);
CurrentResourceOwner = ResourceOwnerCreate(nullptr, "arrow-flight-sql: Executor");
if (!check_password(databaseName, userName, password, clientAddress))
{
session_->initialized = true;
signal_server(tag);
return;
}
{
SharedRingBuffer buffer(&(session_->bufferData), area_);
// TODO: Customizable.
buffer.allocate(1L * 1024L * 1024L);
}
SetCurrentStatementStartTimestamp();
SPI_connect();
pgstat_report_activity(STATE_IDLE, NULL);
session_->initialized = true;
connected_ = true;
signal_server(tag);
}
// Close the executor and release the session's dshash lock.
void close() { close_internal(true); }
// Handle a wake-up from the server: take the pending action from the session
// (resetting it to Action::None under the processor lock) and run it.
// With no pending action this falls back to Processor::signaled().
// If the action raises a PostgreSQL error, the error is recorded as the
// session's error message (unless one is already set), activity is reset to
// idle, and the error is re-thrown.
// NOTE(review): PG_TRY/PG_CATCH use longjmp; any C++ objects with
// destructors live in the called methods are not destroyed when an error
// unwinds through them — confirm the action methods tolerate that.
void signaled() override
{
Action action;
{
ProcessorLockGuard lock(this);
action = session_->action;
session_->action = Action::None;
}
P("%s: %s: signaled: before: %s", Tag, tag_, action_name(action));
PG_TRY();
{
switch (action)
{
case Action::Select:
select();
break;
case Action::Update:
update();
break;
case Action::Prepare:
prepare();
break;
case Action::ClosePreparedStatement:
close_prepared_statement();
break;
case Action::SetParameters:
set_parameters();
break;
case Action::SelectPreparedStatement:
select_prepared_statement();
break;
case Action::UpdatePreparedStatement:
update_prepared_statement();
break;
default:
Processor::signaled();
break;
}
}
PG_CATCH();
{
// Report the error to the server before propagating it.
if (session_ && !DsaPointerIsValid(session_->errorMessage))
{
auto error = CopyErrorData();
set_error_message(std::string("failed to run: ") + action_name(action) +
": " + error->message,
"unexpected error");
FreeErrorData(error);
}
pgstat_report_activity(STATE_IDLE, NULL);
PG_RE_THROW();
}
PG_END_TRY();
pgstat_report_activity(STATE_IDLE, NULL);
P("%s: %s: signaled: after: %s", Tag, tag_, action_name(action));
}
protected:
// The executor's only peer is the server process.
pid_t peer_pid(SessionData* session) override { return sharedData_->serverPID; }
const char* peer_name(SessionData* session) override { return "server"; }
private:
// Send SIGUSR1 to the server so it re-checks this session's state; does
// nothing when no server PID is registered. `tag` only labels the debug log.
// serverPID lives in shared memory and is read exactly once. Re-reading it
// after the check could observe InvalidPid (-1), and kill(-1, SIGUSR1)
// signals every process this user may signal.
void signal_server(const char* tag)
{
    const pid_t serverPID = sharedData_->serverPID;
    if (serverPID == InvalidPid)
    {
        return;
    }
    P("%s: %s: %s: kill server: %d", Tag, tag_, tag, serverPID);
    kill(serverPID, SIGUSR1);
}
// Record `message` as the session's error unless one is already recorded
// (the first error wins), then signal the server. `tag` labels the signal.
// The "already set?" check and the store are done under the processor lock,
// so they are a single step for everyone accessing errorMessage under that
// lock. When an error was already recorded, nothing is signaled.
void set_error_message(const std::string& message, const char* tag)
{
    {
        ProcessorLockGuard lock(this);
        if (DsaPointerIsValid(session_->errorMessage))
        {
            return;
        }
        set_shared_string(session_->errorMessage, message);
    }
    signal_server(tag);
}
// Shut the executor down; called once via close() or the destructor
// (closed_ is set first).
// - Connected: finish SPI, free the ring buffer and delete the session
//   entry. delete_session() also frees the buffer via
//   session_data_finalize(), so the explicit free() is redundant but
//   harmless.
// - Not connected: keep the session and record "failed to connect" (a no-op
//   if check_password() already recorded an error). Mark the session
//   initialized, optionally release its dshash lock, and signal the server.
// Finally release and delete the resource owner created in open().
// NOTE(review): session_ is dereferenced unconditionally; this crashes if
// open() never attached a session — confirm callers always call open() first.
void close_internal(bool unlockSession)
{
const char* tag = "close";
closed_ = true;
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": closing").c_str());
preparedStatements_.clear();
if (connected_)
{
SPI_finish();
{
SharedRingBuffer buffer(&(session_->bufferData), area_);
buffer.free();
}
delete_session(session_);
}
else
{
set_error_message("failed to connect", tag);
session_->initialized = true;
if (unlockSession)
{
dshash_release_lock(sessions_, session_);
}
signal_server(tag);
}
if (CurrentResourceOwner)
{
// Detach first so nothing registers on it during the release phases.
auto resourceOwner = CurrentResourceOwner;
CurrentResourceOwner = nullptr;
ResourceOwnerRelease(
resourceOwner, RESOURCE_RELEASE_BEFORE_LOCKS, false, true);
ResourceOwnerRelease(resourceOwner, RESOURCE_RELEASE_LOCKS, false, true);
ResourceOwnerRelease(
resourceOwner, RESOURCE_RELEASE_AFTER_LOCKS, false, true);
ResourceOwnerDelete(resourceOwner);
}
}
// Authenticate the client against pg_hba.conf. Only "trust" and plain
// "password" methods are supported. Returns true when the client may
// connect. On failure, records the reason with set_error_message() and
// returns false. Temporary allocations go to a private memory context that
// is deleted on return.
bool check_password(const char* databaseName,
                    const char* userName,
                    const char* password,
                    const char* clientAddress)
{
    const char* tag = "check password";
    MemoryContext memoryContext =
        AllocSetContextCreate(CurrentMemoryContext,
                              "arrow-flight-sql: Executor::check_password()",
                              ALLOCSET_DEFAULT_SIZES);
    ScopedMemoryContext scopedMemoryContext(memoryContext);
    Port port = {};
    port.database_name = pstrdup(databaseName);
    port.user_name = pstrdup(userName);
    if (!fill_client_address(&port, clientAddress))
    {
        return false;
    }
    // load_hba() returns false when pg_hba.conf can't be parsed. Matching
    // against stale or empty rules would yield a misleading result.
    if (!load_hba())
    {
        set_error_message("failed to load pg_hba.conf", tag);
        return false;
    }
    hba_getauthmethod(&port);
    if (!port.hba)
    {
        set_error_message("failed to get auth method", tag);
        return false;
    }
    // The password helpers set logDetail on failure. Guard against NULL
    // anyway, because appending a NULL char* to std::string is undefined.
    auto detail = [](const char* logDetail) {
        return logDetail ? logDetail : "(no detail)";
    };
    switch (port.hba->auth_method)
    {
        case uaMD5:
            // TODO
            set_error_message("MD5 auth method isn't supported yet", tag);
            return false;
        case uaSCRAM:
            // TODO
            set_error_message("SCRAM auth method isn't supported yet", tag);
            return false;
        case uaPassword:
        {
            const char* logDetail = nullptr;
            auto shadowPassword = get_role_password(port.user_name, &logDetail);
            if (!shadowPassword)
            {
                set_error_message(
                    std::string("failed to get password: ") + detail(logDetail), tag);
                return false;
            }
            auto result = plain_crypt_verify(
                port.user_name, shadowPassword, password, &logDetail);
            if (result != STATUS_OK)
            {
                set_error_message(
                    std::string("failed to verify password: ") + detail(logDetail),
                    tag);
                return false;
            }
            return true;
        }
        case uaTrust:
            return true;
        default:
            set_error_message(std::string("unsupported auth method: ") +
                                  hba_authname(port.hba->auth_method),
                              tag);
            return false;
    }
}
// Parses a gRPC peer string into port->raddr.
// Accepted forms:
//   "ipv4:127.0.0.1:40468"
//   "ipv6:[::1]:40468"  (some gRPC versions percent-encode the brackets as
//                        "%5B"/"%5D")
// family = text before the first ':'; port = text after the LAST ':';
// host = everything in between, with IPv6 brackets removed. Splitting the
// host at the first ':' (as before) made every IPv6 address unparsable.
// @return false (after set_error_message()) when the address is invalid.
bool fill_client_address(Port* port, const char* clientAddress)
{
	const char* tag = "fill client address";
	const std::string address(clientAddress);
	const auto familyEnd = address.find(':');
	const std::string clientFamily(address.substr(0, familyEnd));
	std::string clientHost("");
	std::string clientPort("");
	if (familyEnd != std::string::npos)
	{
		const std::string rest(address.substr(familyEnd + 1));
		const auto portStart = rest.rfind(':');
		if (portStart == std::string::npos)
		{
			clientHost = rest;
		}
		else
		{
			clientHost = rest.substr(0, portStart);
			clientPort = rest.substr(portStart + 1);
		}
	}
	if (!(clientFamily == "ipv4" || clientFamily == "ipv6"))
	{
		set_error_message(
			std::string("client family must be ipv4 or ipv6: ") + clientFamily, tag);
		return false;
	}
	// Removes `open`...`close` around clientHost; returns whether it did.
	auto stripEnclosing = [&clientHost](const std::string& open,
	                                    const std::string& close) {
		if (clientHost.size() >= open.size() + close.size() &&
		    clientHost.compare(0, open.size(), open) == 0 &&
		    clientHost.compare(
				clientHost.size() - close.size(), close.size(), close) == 0)
		{
			clientHost = clientHost.substr(
				open.size(), clientHost.size() - open.size() - close.size());
			return true;
		}
		return false;
	};
	if (clientFamily == "ipv6")
	{
		if (!stripEnclosing("[", "]") && !stripEnclosing("%5B", "%5D"))
		{
			stripEnclosing("%5b", "%5d");
		}
	}
	auto clientPortStart = clientPort.c_str();
	char* clientPortEnd = nullptr;
	auto clientPortNumber = std::strtoul(clientPortStart, &clientPortEnd, 10);
	if (clientPortEnd[0] != '\0')
	{
		set_error_message(std::string("client port is invalid: ") + clientPort, tag);
		return false;
	}
	if (clientPortNumber == 0)
	{
		set_error_message(std::string("client port must not 0"), tag);
		return false;
	}
	if (clientPortNumber > 65535)
	{
		set_error_message(std::string("client port is too large: ") +
		                      std::to_string(clientPortNumber),
		                  tag);
		return false;
	}
	const auto networkPort = htons(static_cast<uint16_t>(clientPortNumber));
	if (clientFamily == "ipv4")
	{
		auto raddr = reinterpret_cast<sockaddr_in*>(&(port->raddr.addr));
		port->raddr.salen = sizeof(sockaddr_in);
		raddr->sin_family = AF_INET;
		raddr->sin_port = networkPort;
		// inet_pton() returns 1 on success, 0 for a malformed address and -1
		// for an unsupported family.
		if (inet_pton(AF_INET, clientHost.c_str(), &(raddr->sin_addr)) != 1)
		{
			set_error_message(
				std::string("client IPv4 address is invalid: ") + clientHost, tag);
			return false;
		}
	}
	else
	{
		auto raddr = reinterpret_cast<sockaddr_in6*>(&(port->raddr.addr));
		port->raddr.salen = sizeof(sockaddr_in6);
		raddr->sin6_family = AF_INET6;
		raddr->sin6_port = networkPort;
		raddr->sin6_flowinfo = 0;
		if (inet_pton(AF_INET6, clientHost.c_str(), &(raddr->sin6_addr)) != 1)
		{
			set_error_message(
				std::string("client IPv6 address is invalid: ") + clientHost, tag);
			return false;
		}
		raddr->sin6_scope_id = 0;
	}
	return true;
}
// Runs the SELECT query that the proxy stored in session_->selectQuery and
// streams the result (schema stream + data stream) into the session's shared
// ring buffer via write(). Success signals the server; any failure is
// reported through set_error_message().
void select()
{
	const char* tag = "select";
	const std::string prefix = std::string(Tag) + ": " + tag_ + ": " + tag;
	if (!DsaPointerIsValid(session_->selectQuery))
	{
		set_error_message(prefix + ": query is missing", tag);
		return;
	}
	pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": selecting").c_str());
	// Copy the query out of shared memory and release the shared allocation.
	std::string query;
	{
		ProcessorLockGuard guard(this);
		const auto sharedQuery = session_->selectQuery;
		query = static_cast<const char*>(dsa_get_address(area_, sharedQuery));
		dsa_free(area_, sharedQuery);
		session_->selectQuery = InvalidDsaPointer;
	}
	P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
	ScopedTransaction scopedTransaction;
	ScopedSnapshot scopedSnapshot;
	SetCurrentStatementStartTimestamp();
	const auto spiResult = SPI_execute(query.c_str(), true, 0);
	if (spiResult <= 0)
	{
		set_error_message(prefix + ": failed to run a query: <" + query +
		                      ">: " + SPI_result_code_string(spiResult),
		                  tag);
		return;
	}
	pgstat_report_activity(
		STATE_RUNNING, (std::string(Tag) + ": " + tag + ": writing").c_str());
	const auto writeStatus = write(tag);
	if (!writeStatus.ok())
	{
		set_error_message(prefix + ": failed to write: " + writeStatus.ToString(),
		                  tag);
		return;
	}
	signal_server(tag);
}
// Serializes the current SPI result (SPI_tuptable / SPI_processed) into the
// session's shared ring buffer as two Arrow IPC streams:
//   1. a schema-only stream (one empty record batch) so the proxy can learn
//      the schema, then
//   2. a data stream with record batches of at most MaxNRowsPerRecordBatch rows.
// @param tag  caller label for debug logging.
// @return the first error from conversion or IPC writing, or OK.
arrow::Status write(const char* tag)
{
	// SPI_tuptable is NULL when the last SPI command returned no rows set;
	// dereferencing it below would crash the backend.
	if (!SPI_tuptable)
	{
		return arrow::Status::Invalid(std::string(Tag) + ": " + tag_ + ": " + tag +
		                              ": write: no result set");
	}
	SharedRingBufferOutputStream output(this, session_);
	auto tupleDesc = SPI_tuptable->tupdesc;
	const int nAttributes = tupleDesc->natts;
	const uint64_t nTuples = SPI_processed;
	std::vector<PGArrowValueConverter> converters;
	converters.reserve(nAttributes);
	std::vector<std::shared_ptr<arrow::Field>> fields;
	fields.reserve(nAttributes);
	for (int i = 0; i < nAttributes; ++i)
	{
		auto attribute = TupleDescAttr(tupleDesc, i);
		converters.emplace_back(attribute);
		ARROW_ASSIGN_OR_RAISE(auto type, converters.back().convert_type());
		fields.push_back(arrow::field(
			NameStr(attribute->attname), std::move(type), !attribute->attnotnull));
	}
	auto schema = arrow::schema(fields);
	ARROW_ASSIGN_OR_RAISE(
		auto builder,
		arrow::RecordBatchBuilder::Make(schema, arrow::default_memory_pool()));
	auto options = arrow::ipc::IpcWriteOptions::Defaults();
	options.emit_dictionary_deltas = true;
	// Write schema only stream format data to return only schema.
	ARROW_ASSIGN_OR_RAISE(auto writer,
	                      arrow::ipc::MakeStreamWriter(&output, schema, options));
	// Build an empty record batch to write schema.
	ARROW_ASSIGN_OR_RAISE(auto recordBatch, builder->Flush());
	P("%s: %s: %s: write: schema: WriteRecordBatch", Tag, tag_, tag);
	ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
	P("%s: %s: %s: write: schema: Close", Tag, tag_, tag);
	ARROW_RETURN_NOT_OK(writer->Close());
	// Write another stream format data with record batches.
	ARROW_ASSIGN_OR_RAISE(writer,
	                      arrow::ipc::MakeStreamWriter(&output, schema, options));
	bool needLastFlush = false;
	for (uint64_t iTuple = 0; iTuple < nTuples; ++iTuple)
	{
		// uint64_t values need PRIu64, not %d, in printf-style formats.
		P("%s: %s: %s: write: data: record batch: %" PRIu64 "/%" PRIu64,
		  Tag,
		  tag_,
		  tag,
		  iTuple,
		  nTuples);
		for (int iAttribute = 0; iAttribute < nAttributes; ++iAttribute)
		{
			P("%s: %s: %s: write: data: record batch: %" PRIu64 "/%" PRIu64
			  ": %d/%d",
			  Tag,
			  tag_,
			  tag,
			  iTuple,
			  nTuples,
			  iAttribute,
			  nAttributes);
			bool isNull;
			auto datum = SPI_getbinval(
				SPI_tuptable->vals[iTuple], tupleDesc, iAttribute + 1, &isNull);
			auto arrayBuilder = builder->GetField(iAttribute);
			if (isNull)
			{
				ARROW_RETURN_NOT_OK(arrayBuilder->AppendNull());
			}
			else
			{
				ARROW_RETURN_NOT_OK(
					converters[iAttribute].convert_value(arrayBuilder, datum));
			}
		}
		if (((iTuple + 1) % MaxNRowsPerRecordBatch) == 0)
		{
			ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush());
			P("%s: %s: %s: write: data: WriteRecordBatch: %" PRIu64 "/%" PRIu64,
			  Tag,
			  tag_,
			  tag,
			  iTuple,
			  nTuples);
			ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
			needLastFlush = false;
		}
		else
		{
			needLastFlush = true;
		}
	}
	if (needLastFlush)
	{
		ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush());
		P("%s: %s: %s: write: data: WriteRecordBatch", Tag, tag_, tag);
		ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
	}
	P("%s: %s: %s, write: data: Close", Tag, tag_, tag);
	ARROW_RETURN_NOT_OK(writer->Close());
	return output.Close();
}
// Executes the data-modifying statement stored in session_->updateQuery and
// publishes the affected-row count in session_->nUpdatedRecords. Success
// signals the server; failures go through set_error_message().
void update()
{
	const char* tag = "update";
	const std::string prefix = std::string(Tag) + ": " + tag_ + ": " + tag;
	if (!DsaPointerIsValid(session_->updateQuery))
	{
		set_error_message(prefix + ": query is missing", tag);
		return;
	}
	pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": updating").c_str());
	// Take a private copy of the statement and free the shared one.
	std::string query;
	{
		ProcessorLockGuard guard(this);
		const auto sharedQuery = session_->updateQuery;
		query = static_cast<const char*>(dsa_get_address(area_, sharedQuery));
		dsa_free(area_, sharedQuery);
		session_->updateQuery = InvalidDsaPointer;
	}
	P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
	ScopedTransaction scopedTransaction;
	ScopedSnapshot scopedSnapshot;
	SetCurrentStatementStartTimestamp();
	const auto spiResult = SPI_execute(query.c_str(), false, 0);
	if (spiResult <= 0)
	{
		set_error_message(prefix + ": failed to run a query: <" + query +
		                      ">: " + SPI_result_code_string(spiResult),
		                  tag);
		return;
	}
	session_->nUpdatedRecords = SPI_processed;
	signal_server(tag);
}
// Registers the query in session_->prepareQuery as a new prepared statement.
// The handle is the decimal string of nextPreparedStatementID_ (then
// incremented); it is published in session_->preparedStatementHandle and the
// server is signalled.
void prepare()
{
	const char* tag = "prepare";
	if (!DsaPointerIsValid(session_->prepareQuery))
	{
		set_error_message(
			std::string(Tag) + ": " + tag_ + ": " + tag + ": query is missing", tag);
		return;
	}
	pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": preparing").c_str());
	std::string query;
	{
		ProcessorLockGuard guard(this);
		const auto sharedQuery = session_->prepareQuery;
		query = static_cast<const char*>(dsa_get_address(area_, sharedQuery));
		dsa_free(area_, sharedQuery);
		session_->prepareQuery = InvalidDsaPointer;
	}
	P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
	const std::string handle = std::to_string(nextPreparedStatementID_);
	++nextPreparedStatementID_;
	preparedStatements_.emplace(handle, PreparedStatement(std::move(query)));
	{
		ProcessorLockGuard guard(this);
		set_shared_string(session_->preparedStatementHandle, handle);
	}
	signal_server(tag);
}
// Moves the prepared-statement handle from shared memory into `handle` and
// clears session_->preparedStatementHandle.
// @return false (after set_error_message()) when no handle was provided.
bool extract_handle(std::string& handle, const char* tag)
{
	const bool hasHandle = DsaPointerIsValid(session_->preparedStatementHandle);
	if (hasHandle)
	{
		ProcessorLockGuard guard(this);
		const auto sharedHandle = session_->preparedStatementHandle;
		handle = static_cast<const char*>(dsa_get_address(area_, sharedHandle));
		dsa_free(area_, sharedHandle);
		session_->preparedStatementHandle = InvalidDsaPointer;
		return true;
	}
	set_error_message(
		std::string(Tag) + ": " + tag_ + ": " + tag + ": handle is missing", tag);
	return false;
}
// Drops the prepared statement whose handle the proxy placed in shared
// memory. Signals the server on success; reports an unknown handle through
// set_error_message().
void close_prepared_statement()
{
	const char* tag = "close prepared statement";
	pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": " + tag).c_str());
	std::string handle;
	if (extract_handle(handle, tag))
	{
		P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
		const auto nErased = preparedStatements_.erase(handle);
		if (nErased == 0)
		{
			set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag +
			                      ": nonexistent handle: <" + handle + ">",
			                  tag);
			return;
		}
		signal_server(tag);
	}
}
// Extracts the handle from shared memory (see extract_handle()) and looks it
// up in preparedStatements_.
// @param handle  receives the extracted handle (even when lookup fails).
// @return the prepared statement, or nullptr after set_error_message().
PreparedStatement* find_prepared_statement(std::string& handle, const char* tag)
{
	if (!extract_handle(handle, tag))
	{
		return nullptr;
	}
	// preparedStatements_ is a process-local std::map: prepare() and
	// close_prepared_statement() modify it without the processor lock, so no
	// lock is needed for this lookup. The previous version held
	// ProcessorLockGuard across set_error_message(), which takes the same lock
	// again. If that lock is a (non-reentrant) LWLock, this self-deadlocks.
	auto it = preparedStatements_.find(handle);
	if (it == preparedStatements_.end())
	{
		set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag +
		                      ": nonexistent handle: <" + handle + ">",
		                  tag);
		return nullptr;
	}
	return &(it->second);
}
void set_parameters()
{
const char* tag = "set parameters";
pgstat_report_activity(STATE_RUNNING,
(std::string(Tag) + ": setting parameters").c_str());
std::string handle;
auto preparedStatement = find_prepared_statement(handle, tag);
P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
if (!preparedStatement)
{
return;
}
auto input = std::make_shared<SharedRingBufferInputStream>(this, session_);
auto status = preparedStatement->set_parameters(input);
if (status.ok())
{
signal_server(tag);
}
else
{
set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag +
": failed to set parameters: <" + handle +
">: " + status.ToString(),
tag);
}
}
void select_prepared_statement()
{
const char* tag = "select prepared statement";
pgstat_report_activity(
STATE_RUNNING, (std::string(Tag) + ": selecting prepared statement").c_str());
std::string handle;
auto preparedStatement = find_prepared_statement(handle, tag);
P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
if (!preparedStatement)
{
return;
}
ScopedTransaction scopedTransaction;
ScopedSnapshot scopedSnapshot;
struct Data {
Executor* executor;
const char* tag;
} data = {this, tag};
auto status = preparedStatement->select(
[](void* data) {
auto d = static_cast<Data*>(data);
return d->executor->write(d->tag);
},
&data);
if (status.ok())
{
signal_server(tag);
}
else
{
set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag +
": failed to select a prepared statement: <" + handle +
">: " + status.ToString(),
tag);
}
}
void update_prepared_statement()
{
const char* tag = "update prepared statement";
pgstat_report_activity(
STATE_RUNNING, (std::string(Tag) + ": updating prepared statement").c_str());
std::string handle;
auto preparedStatement = find_prepared_statement(handle, tag);
P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
if (!preparedStatement)
{
return;
}
ScopedTransaction scopedTransaction;
ScopedSnapshot scopedSnapshot;
auto input = std::make_shared<SharedRingBufferInputStream>(this, session_);
auto n_updated_records_result = preparedStatement->update(input);
if (n_updated_records_result.ok())
{
session_->nUpdatedRecords = *n_updated_records_result;
signal_server(tag);
}
else
{
set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag +
": failed to update a prepared statement: <" + handle +
">: " + n_updated_records_result.status().ToString(),
tag);
}
}
// Presumably the ID of the session this executor serves (not read in the
// visible code) — TODO confirm.
uint64_t sessionID_;
// Session record in the shared sessions_ hash table; requests and results
// are exchanged with the proxy through its fields.
SessionData* session_;
// Whether connecting succeeded; close_internal() uses it to choose between
// SPI/session cleanup and reporting "failed to connect".
bool connected_;
// Set to true by close_internal().
bool closed_;
// Source of prepared statement handles: prepare() uses its decimal string.
uint64_t nextPreparedStatementID_;
// Prepared statements created by prepare(), keyed by handle. Process-local;
// cleared by close_internal().
std::map<std::string, PreparedStatement> preparedStatements_;
};
// Copies nBytes from `data` into the session's shared ring buffer. When the
// buffer is full, waits (WaitMode::Read) for the peer to consume data and
// retries until every byte has been written.
// @return IOError if the stream is closed, a wait() error, or OK.
arrow::Status
SharedRingBufferOutputStream::Write(const void* data, int64_t nBytes)
{
	if (ARROW_PREDICT_FALSE(!is_open_))
	{
		return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() +
		                              ": SharedRingBufferOutputStream is closed");
	}
	if (ARROW_PREDICT_FALSE(nBytes <= 0))
	{
		return arrow::Status::OK();
	}
	auto ringBuffer = processor_->create_shared_ring_buffer(session_);
	auto cursor = static_cast<const uint8_t*>(data);
	auto remaining = static_cast<size_t>(nBytes);
	for (;;)
	{
		processor_->lock_acquire();
		const auto nWritten = ringBuffer.write(cursor, remaining);
		processor_->lock_release();
		position_ += nWritten;
		remaining -= nWritten;
		cursor += nWritten;
		if (ARROW_PREDICT_TRUE(remaining == 0))
		{
			return arrow::Status::OK();
		}
		// Buffer is full: block until the reader frees some space.
		ARROW_RETURN_NOT_OK(
			processor_->wait(session_, &ringBuffer, Processor::WaitMode::Read));
	}
}
/// Proxy used by the Flight SQL server side (FlightSQLServer,
/// HeaderAuthServerMiddlewareFactory). It creates sessions and forwards each
/// request to the session's executor.
///
/// Every request method follows the same pattern:
///   1. store the input in the SessionData fields (under the processor lock),
///   2. set session->action and send SIGUSR1 to the executor, and
///   3. wait on conditionVariable_ until the executor produces a result or
///      an error message.
class Proxy : public WorkerProcessor {
public:
	explicit Proxy()
		: WorkerProcessor("proxy", false), randomSeed_(), randomEngine_(randomSeed_())
	{
	}

	/// Creates a session, asks the main process (SIGUSR1) to start an executor
	/// and waits until the session is initialized, disappears, or an interrupt
	/// is pending.
	/// @return the new session ID, or an error status.
	arrow::Result<uint64_t> connect(const std::string& databaseName,
	                                const std::string& userName,
	                                const std::string& password,
	                                const std::string& clientAddress)
	{
		auto session = create_session(databaseName, userName, password, clientAddress);
		auto id = session->id;
		dshash_release_lock(sessions_, session);
		kill(sharedData_->mainPID, SIGUSR1);
		{
			std::unique_lock<std::mutex> lock(mutex_);
			conditionVariable_.wait(lock, [&] {
				if (INTERRUPTS_PENDING_CONDITION())
				{
					return true;
				}
				session = static_cast<SessionData*>(dshash_find(sessions_, &id, false));
				if (!session)
				{
					return true;
				}
				const auto initialized = session->initialized;
				dshash_release_lock(sessions_, session);
				return initialized;
			});
		}
		session = static_cast<SessionData*>(dshash_find(sessions_, &id, false));
		if (!session)
		{
			return stale_session_error(id);
		}
		SessionReleaser sessionReleaser(sessions_, session);
		if (DsaPointerIsValid(session->errorMessage))
		{
			return report_session_error(session);
		}
		if (INTERRUPTS_PENDING_CONDITION())
		{
			return arrow::Status::Invalid("interrupted");
		}
		return id;
	}

	/// @return whether `sessionID` currently names an existing session.
	bool is_valid_session(uint64_t sessionID)
	{
		auto session = find_session(sessionID);
		if (session)
		{
			dshash_release_lock(sessions_, session);
			return true;
		}
		else
		{
			return false;
		}
	}

	/// Runs `query` on the session's executor and returns the result schema.
	/// The data itself is read afterwards with read().
	arrow::Result<std::shared_ptr<arrow::Schema>> select(uint64_t sessionID,
	                                                     const std::string& query)
	{
		const char* tag = "select";
		auto session = find_session(sessionID);
		if (!session)
		{
			return stale_session_error(sessionID);
		}
		SessionReleaser sessionReleaser(sessions_, session);
		// Same locking as every other request method; the executor reads
		// selectQuery under its processor lock.
		lock_acquire();
		set_shared_string(session->selectQuery, query);
		session->action = Action::Select;
		lock_release();
		if (session->executorPID != InvalidPid)
		{
			P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
			kill(session->executorPID, SIGUSR1);
		}
		{
			auto buffer = create_shared_ring_buffer(session);
			std::unique_lock<std::mutex> lock(mutex_);
			conditionVariable_.wait(lock, [&] {
				P("%s: %s: %s: wait", Tag, tag_, tag);
				return DsaPointerIsValid(session->errorMessage) || buffer.size() > 0;
			});
		}
		if (DsaPointerIsValid(session->errorMessage))
		{
			return report_session_error(session);
		}
		P("%s: %s: %s: open", Tag, tag_, tag);
		auto schema = read_schema(session, tag);
		P("%s: %s: %s: schema", Tag, tag_, tag);
		return schema;
	}

	/// Runs a data-modifying `query` and returns the affected-row count.
	arrow::Result<int64_t> update(uint64_t sessionID, const std::string& query)
	{
#ifdef AFS_DEBUG
		const char* tag = "update";
#endif
		auto session = find_session(sessionID);
		if (!session)
		{
			return stale_session_error(sessionID);
		}
		SessionReleaser sessionReleaser(sessions_, session);
		lock_acquire();
		set_shared_string(session->updateQuery, query);
		session->action = Action::Update;
		session->nUpdatedRecords = -1;
		lock_release();
		if (session->executorPID != InvalidPid)
		{
			P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
			kill(session->executorPID, SIGUSR1);
		}
		{
			std::unique_lock<std::mutex> lock(mutex_);
			conditionVariable_.wait(lock, [&] {
				P("%s: %s: %s: wait", Tag, tag_, tag);
				return DsaPointerIsValid(session->errorMessage) ||
				       session->nUpdatedRecords >= 0;
			});
		}
		if (DsaPointerIsValid(session->errorMessage))
		{
			return report_session_error(session);
		}
		P("%s: %s: %s: done: %ld", Tag, tag_, tag, session->nUpdatedRecords);
		return session->nUpdatedRecords;
	}

	/// Opens a reader over the data stream the executor writes after the
	/// schema stream.
	arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> read(uint64_t sessionID)
	{
		auto session = find_session(sessionID);
		if (!session)
		{
			return stale_session_error(sessionID);
		}
		SessionReleaser sessionReleaser(sessions_, session);
		auto input = std::make_shared<SharedRingBufferInputStream>(this, session);
		// Read another stream format data with record batches.
		return arrow::ipc::RecordBatchStreamReader::Open(input);
	}

	/// Asks the executor to register `query` as a prepared statement and
	/// returns its handle.
	arrow::Result<arrow::flight::sql::ActionCreatePreparedStatementResult> prepare(
		uint64_t sessionID, const std::string& query)
	{
#ifdef AFS_DEBUG
		const char* tag = "prepare";
#endif
		auto session = find_session(sessionID);
		if (!session)
		{
			return stale_session_error(sessionID);
		}
		SessionReleaser sessionReleaser(sessions_, session);
		lock_acquire();
		set_shared_string(session->prepareQuery, query);
		session->action = Action::Prepare;
		set_shared_string(session->preparedStatementHandle, std::string(""));
		lock_release();
		if (session->executorPID != InvalidPid)
		{
			P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
			kill(session->executorPID, SIGUSR1);
		}
		{
			std::unique_lock<std::mutex> lock(mutex_);
			conditionVariable_.wait(lock, [&] {
				P("%s: %s: %s: wait", Tag, tag_, tag);
				return DsaPointerIsValid(session->errorMessage) ||
				       DsaPointerIsValid(session->preparedStatementHandle);
			});
		}
		if (DsaPointerIsValid(session->errorMessage))
		{
			return report_session_error(session);
		}
		std::string handle(static_cast<const char*>(
			dsa_get_address(area_, session->preparedStatementHandle)));
		arrow::flight::sql::ActionCreatePreparedStatementResult result = {
			nullptr,
			nullptr,
			std::move(handle),
		};
		P("%s: %s: %s: done", Tag, tag_, tag);
		return result;
	}

	/// Asks the executor to drop the prepared statement `handle`.
	arrow::Status close_prepared_statement(uint64_t sessionID, const std::string& handle)
	{
#ifdef AFS_DEBUG
		const char* tag = "close prepared statement";
#endif
		auto session = find_session(sessionID);
		if (!session)
		{
			return stale_session_error(sessionID);
		}
		SessionReleaser sessionReleaser(sessions_, session);
		lock_acquire();
		set_shared_string(session->preparedStatementHandle, handle);
		session->action = Action::ClosePreparedStatement;
		lock_release();
		if (session->executorPID != InvalidPid)
		{
			P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
			kill(session->executorPID, SIGUSR1);
		}
		{
			std::unique_lock<std::mutex> lock(mutex_);
			conditionVariable_.wait(lock, [&] {
				P("%s: %s: %s: wait", Tag, tag_, tag);
				return DsaPointerIsValid(session->errorMessage) ||
				       !DsaPointerIsValid(session->preparedStatementHandle);
			});
		}
		if (DsaPointerIsValid(session->errorMessage))
		{
			return report_session_error(session);
		}
		P("%s: %s: %s: done", Tag, tag_, tag);
		return arrow::Status::OK();
	}

	/// Streams the client's parameter batches from `reader` to the executor
	/// and waits until the executor has consumed them.
	arrow::Status set_parameters(uint64_t sessionID,
	                             const std::string& handle,
	                             arrow::flight::FlightMessageReader* reader,
	                             arrow::flight::FlightMetadataWriter* writer)
	{
#ifdef AFS_DEBUG
		const char* tag = "set parameters";
#endif
		auto session = find_session(sessionID);
		if (!session)
		{
			return stale_session_error(sessionID);
		}
		SessionReleaser sessionReleaser(sessions_, session);
		lock_acquire();
		set_shared_string(session->preparedStatementHandle, handle);
		session->action = Action::SetParameters;
		lock_release();
		if (session->executorPID != InvalidPid)
		{
			P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
			kill(session->executorPID, SIGUSR1);
		}
		{
			ARROW_ASSIGN_OR_RAISE(const auto& schema, reader->GetSchema());
			SharedRingBufferOutputStream output(this, session);
			auto options = arrow::ipc::IpcWriteOptions::Defaults();
			options.emit_dictionary_deltas = true;
			// Named streamWriter so it does not shadow the `writer` parameter.
			ARROW_ASSIGN_OR_RAISE(auto streamWriter,
			                      arrow::ipc::MakeStreamWriter(&output, schema, options));
			while (true)
			{
				ARROW_ASSIGN_OR_RAISE(const auto& chunk, reader->Next());
				if (!chunk.data)
				{
					break;
				}
				ARROW_RETURN_NOT_OK(streamWriter->WriteRecordBatch(*(chunk.data)));
			}
			ARROW_RETURN_NOT_OK(streamWriter->Close());
		}
		if (session->executorPID != InvalidPid)
		{
			P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
			kill(session->executorPID, SIGUSR1);
		}
		{
			auto buffer = create_shared_ring_buffer(session);
			std::unique_lock<std::mutex> lock(mutex_);
			conditionVariable_.wait(lock, [&] {
				P("%s: %s: %s: wait", Tag, tag_, tag);
				return DsaPointerIsValid(session->errorMessage) || buffer.size() == 0;
			});
		}
		if (DsaPointerIsValid(session->errorMessage))
		{
			return report_session_error(session);
		}
		P("%s: %s: %s: done", Tag, tag_, tag);
		return arrow::Status::OK();
	}

	/// Executes the prepared SELECT `handle` and returns its result schema.
	arrow::Result<std::shared_ptr<arrow::Schema>> select_prepared_statement(
		uint64_t sessionID, const std::string& handle)
	{
		const char* tag = "select prepared statement";
		auto session = find_session(sessionID);
		if (!session)
		{
			return stale_session_error(sessionID);
		}
		SessionReleaser sessionReleaser(sessions_, session);
		lock_acquire();
		set_shared_string(session->preparedStatementHandle, handle);
		session->action = Action::SelectPreparedStatement;
		lock_release();
		if (session->executorPID != InvalidPid)
		{
			P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
			kill(session->executorPID, SIGUSR1);
		}
		{
			auto buffer = create_shared_ring_buffer(session);
			std::unique_lock<std::mutex> lock(mutex_);
			conditionVariable_.wait(lock, [&] {
				P("%s: %s: %s: wait", Tag, tag_, tag);
				return DsaPointerIsValid(session->errorMessage) || buffer.size() > 0;
			});
		}
		if (DsaPointerIsValid(session->errorMessage))
		{
			return report_session_error(session);
		}
		P("%s: %s: %s: open", Tag, tag_, tag);
		auto schema = read_schema(session, tag);
		P("%s: %s: %s: schema", Tag, tag_, tag);
		return schema;
	}

	/// Executes the prepared data-modifying statement `handle` with the
	/// parameter batches from `reader` and returns the affected-row count.
	arrow::Result<int64_t> update_prepared_statement(
		uint64_t sessionID,
		const std::string& handle,
		arrow::flight::FlightMessageReader* reader)
	{
#ifdef AFS_DEBUG
		const char* tag = "update prepared statement";
#endif
		auto session = find_session(sessionID);
		if (!session)
		{
			return stale_session_error(sessionID);
		}
		SessionReleaser sessionReleaser(sessions_, session);
		lock_acquire();
		set_shared_string(session->preparedStatementHandle, handle);
		session->action = Action::UpdatePreparedStatement;
		session->nUpdatedRecords = -1;
		lock_release();
		if (session->executorPID != InvalidPid)
		{
			P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
			kill(session->executorPID, SIGUSR1);
		}
		{
			ARROW_ASSIGN_OR_RAISE(const auto& schema, reader->GetSchema());
			SharedRingBufferOutputStream output(this, session);
			auto options = arrow::ipc::IpcWriteOptions::Defaults();
			options.emit_dictionary_deltas = true;
			ARROW_ASSIGN_OR_RAISE(auto streamWriter,
			                      arrow::ipc::MakeStreamWriter(&output, schema, options));
			while (true)
			{
				ARROW_ASSIGN_OR_RAISE(const auto& chunk, reader->Next());
				if (!chunk.data)
				{
					break;
				}
				ARROW_RETURN_NOT_OK(streamWriter->WriteRecordBatch(*(chunk.data)));
			}
			ARROW_RETURN_NOT_OK(streamWriter->Close());
		}
		if (session->executorPID != InvalidPid)
		{
			P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
			kill(session->executorPID, SIGUSR1);
		}
		{
			std::unique_lock<std::mutex> lock(mutex_);
			conditionVariable_.wait(lock, [&] {
				P("%s: %s: %s: wait", Tag, tag_, tag);
				return DsaPointerIsValid(session->errorMessage) ||
				       session->nUpdatedRecords >= 0;
			});
		}
		if (DsaPointerIsValid(session->errorMessage))
		{
			return report_session_error(session);
		}
		P("%s: %s: %s: done: %ld", Tag, tag_, tag, session->nUpdatedRecords);
		return session->nUpdatedRecords;
	}

protected:
	pid_t peer_pid(SessionData* session) override { return session->executorPID; }
	const char* peer_name(SessionData* session) override { return "executor"; }

private:
	/// Inserts a new session under a random non-zero, unused ID and returns it
	/// still locked by dshash_find_or_insert(); the caller must release it.
	SessionData* create_session(const std::string& databaseName,
	                            const std::string& userName,
	                            const std::string& password,
	                            const std::string& clientAddress)
	{
		lock_acquire();
		uint64_t id = 0;
		SessionData* session = nullptr;
		do
		{
			id = randomEngine_();
			if (id == 0)
			{
				continue;
			}
			bool found = false;
			session =
				static_cast<SessionData*>(dshash_find_or_insert(sessions_, &id, &found));
			if (!found)
			{
				break;
			}
		} while (true);
		session_data_initialize(
			session, area_, databaseName, userName, password, clientAddress);
		lock_release();
		return session;
	}

	/// Looks up a session with a shared dshash lock; returns nullptr if absent.
	SessionData* find_session(uint64_t sessionID)
	{
		return static_cast<SessionData*>(dshash_find(sessions_, &sessionID, false));
	}

	/// Error returned when a session ID no longer maps to a live session.
	static arrow::Status stale_session_error(uint64_t sessionID)
	{
		return arrow::Status::Invalid("session is stale: ", sessionID);
	}

	/// Converts the session's error message into a Status and asks the
	/// executor (if one was started) to terminate.
	arrow::Status report_session_error(SessionData* session)
	{
		auto status = arrow::Status::Invalid(
			static_cast<const char*>(dsa_get_address(area_, session->errorMessage)));
		// The error can be recorded before any executor started (e.g. the main
		// process failed to register the worker). kill(InvalidPid (-1), SIGTERM)
		// would then terminate every process this user may signal, including
		// the postmaster, so only signal a real PID.
		const auto executorPID = session->executorPID;
		if (executorPID != InvalidPid)
		{
			P("%s: %s: %s: kill SIGTERM executor: %d",
			  Tag,
			  tag_,
			  AFS_FUNC,
			  executorPID);
			kill(executorPID, SIGTERM);
		}
		return status;
	}

	/// Reads the schema-only IPC stream the executor writes first.
	arrow::Result<std::shared_ptr<arrow::Schema>> read_schema(SessionData* session,
	                                                          const char* tag)
	{
		auto input = std::make_shared<SharedRingBufferInputStream>(this, session);
		// Read schema only stream format data.
		ARROW_ASSIGN_OR_RAISE(auto reader,
		                      arrow::ipc::RecordBatchStreamReader::Open(input));
		while (true)
		{
			std::shared_ptr<arrow::RecordBatch> recordBatch;
			P("%s: %s: %s: read next", Tag, tag_, tag);
			ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
			if (!recordBatch)
			{
				break;
			}
		}
		return reader->schema();
	}

	std::random_device randomSeed_;
	std::mt19937_64 randomEngine_;
};
// Reads exactly nBytes from the session's shared ring buffer into `out`,
// waiting (WaitMode::Written) for the writer whenever the buffer runs dry.
// @return nBytes on success; IOError if the stream is closed or an interrupt
//         is pending; Invalid for a negative size; or a wait() error.
arrow::Result<int64_t>
SharedRingBufferInputStream::Read(int64_t nBytes, void* out)
{
	if (ARROW_PREDICT_FALSE(!is_open_))
	{
		return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() +
		                              ": SharedRingBufferInputStream is closed");
	}
	// A negative size would be converted to a huge size_t below, which could
	// read past the caller's buffer or wait forever for bytes that never come.
	if (ARROW_PREDICT_FALSE(nBytes < 0))
	{
		return arrow::Status::Invalid(std::string(Tag) + ": " + processor_->tag() +
		                              ": SharedRingBufferInputStream: negative read size: " +
		                              std::to_string(nBytes));
	}
	// Mirrors SharedRingBufferOutputStream::Write(), which skips empty writes.
	if (nBytes == 0)
	{
		return 0;
	}
	auto buffer = processor_->create_shared_ring_buffer(session_);
	size_t rest = static_cast<size_t>(nBytes);
	while (true)
	{
		processor_->lock_acquire();
		auto readBytes = buffer.read(rest, out);
		processor_->lock_release();
		position_ += readBytes;
		rest -= readBytes;
		out = static_cast<uint8_t*>(out) + readBytes;
		if (ARROW_PREDICT_TRUE(rest == 0))
		{
			break;
		}
		ARROW_RETURN_NOT_OK(
			processor_->wait(session_, &buffer, Processor::WaitMode::Written));
		if (INTERRUPTS_PENDING_CONDITION())
		{
			return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() +
			                              ": interrupted");
		}
	}
	return nBytes;
}
/// Processor for the main background worker. It creates the shared state
/// (SharedData, DSA area, sessions hash table), starts the server worker and
/// starts one executor worker per uninitialized session.
class MainProcessor : public Processor {
public:
	MainProcessor() : Processor("main", true), sessions_(nullptr)
	{
		LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
		bool found;
		// NOTE(review): sizeof(sharedData_) is the size of a *pointer*, not of
		// SharedData. ShmemInitStruct() rejects attaching with a different size,
		// so this has to be fixed together with every other ShmemInitStruct()
		// call for SharedDataName (and the shared memory request) — TODO.
		sharedData_ = static_cast<SharedData*>(
			ShmemInitStruct(SharedDataName, sizeof(sharedData_), &found));
		if (found)
		{
			LWLockRelease(AddinShmemInitLock);
			ereport(ERROR,
			        errcode(ERRCODE_INTERNAL_ERROR),
			        errmsg("%s: %s: shared data is already created", Tag, tag_));
		}
		sharedData_->trancheID = LWLockNewTrancheId();
		sharedData_->sessionsTrancheID = LWLockNewTrancheId();
		area_ = dsa_create(sharedData_->trancheID);
		sharedData_->handle = dsa_get_handle(area_);
		SessionsParams.tranche_id = sharedData_->sessionsTrancheID;
		sessions_ = dshash_create(area_, &SessionsParams, nullptr);
		sharedData_->sessionsHandle = dshash_get_hash_table_handle(sessions_);
		sharedData_->serverPID = InvalidPid;
		sharedData_->mainPID = MyProcPid;
		lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);
		LWLockRelease(AddinShmemInitLock);
	}

	~MainProcessor() override { dshash_destroy(sessions_); }

	/// Registers and waits for the "afs_server" dynamic background worker.
	/// sharedData_->serverPID is only set if the worker reports startup.
	BackgroundWorkerHandle* start_server()
	{
		BackgroundWorker worker = {};
		snprintf(worker.bgw_name, BGW_MAXLEN, "%s: server", Tag);
		snprintf(worker.bgw_type, BGW_MAXLEN, "%s: server", Tag);
		worker.bgw_flags = BGWORKER_SHMEM_ACCESS;
		worker.bgw_start_time = BgWorkerStart_ConsistentState;
		worker.bgw_restart_time = BGW_NEVER_RESTART;
		snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
		snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_server");
		worker.bgw_main_arg = 0;
		worker.bgw_notify_pid = MyProcPid;
		BackgroundWorkerHandle* handle;
		if (!RegisterDynamicBackgroundWorker(&worker, &handle))
		{
			ereport(ERROR,
			        errcode(ERRCODE_INTERNAL_ERROR),
			        errmsg("%s: %s: failed to start server", Tag, tag_));
		}
		WaitForBackgroundWorkerStartup(handle, &(sharedData_->serverPID));
		return handle;
	}

	/// Starts an "afs_executor" worker for every session whose `initialized`
	/// flag is still false, then wakes the server.
	/// A registration failure is recorded in the session's errorMessage.
	void process_connect_requests()
	{
		dshash_seq_status sessionsStatus;
		dshash_seq_init(&sessionsStatus, sessions_, false);
		SessionData* session;
		while ((session = static_cast<SessionData*>(dshash_seq_next(&sessionsStatus))))
		{
			if (session->initialized)
			{
				continue;
			}
			BackgroundWorker worker = {};
			snprintf(
				worker.bgw_name, BGW_MAXLEN, "%s: executor: %" PRIu64, Tag, session->id);
			snprintf(
				worker.bgw_type, BGW_MAXLEN, "%s: executor: %" PRIu64, Tag, session->id);
			worker.bgw_flags =
				BGWORKER_SHMEM_ACCESS | BGWORKER_BACKEND_DATABASE_CONNECTION;
			worker.bgw_start_time = BgWorkerStart_ConsistentState;
			worker.bgw_restart_time = BGW_NEVER_RESTART;
			snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
			snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_executor");
			worker.bgw_main_arg = Int64GetDatum(session->id);
			worker.bgw_notify_pid = MyProcPid;
			BackgroundWorkerHandle* handle;
			if (RegisterDynamicBackgroundWorker(&worker, &handle))
			{
				WaitForBackgroundWorkerStartup(handle, &(session->executorPID));
			}
			else
			{
				// NOTE(review): `initialized` stays false here, so Proxy::connect()
				// keeps waiting and this loop retries the registration next time.
				set_shared_string(
					session->errorMessage,
					std::string(Tag) + ": " + tag_ +
						": failed to start executor: " + std::to_string(session->id));
			}
		}
		dshash_seq_term(&sessionsStatus);
		// serverPID stays InvalidPid if the server never reported startup.
		// kill(InvalidPid (-1), ...) would signal every process this user may
		// signal, so only notify a real server.
		const auto serverPID = sharedData_->serverPID;
		if (serverPID != InvalidPid)
		{
			kill(serverPID, SIGUSR1);
		}
	}

private:
	dshash_table* sessions_;
};
class HeaderAuthServerMiddleware : public arrow::flight::ServerMiddleware {
public:
explicit HeaderAuthServerMiddleware(uint64_t sessionID) : sessionID_(sessionID) {}
void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override
{
outgoing_headers->AddHeader("authorization",
std::string("Bearer ") + std::to_string(sessionID_));
}
void CallCompleted(const arrow::Status& status) override {}
std::string name() const override { return "HeaderAuthServerMiddleware"; }
uint64_t session_id() { return sessionID_; }
private:
uint64_t sessionID_;
};
/// Authenticates every Flight call from its "authorization" header.
/// - "Basic <base64(user:password)>" creates a new session via
///   Proxy::connect() (database taken from "x-flight-sql-database", default
///   "postgres").
/// - "Bearer <session ID>" reuses an existing session.
/// On success a HeaderAuthServerMiddleware carrying the session ID is
/// installed for the call.
class HeaderAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory {
public:
	explicit HeaderAuthServerMiddlewareFactory(Proxy* proxy)
		: arrow::flight::ServerMiddlewareFactory(), proxy_(proxy)
	{
	}

	arrow::Status StartCall(
		const arrow::flight::CallInfo& info,
#if ARROW_VERSION_MAJOR >= 13
		const arrow::flight::ServerCallContext& context,
#else
		const arrow::flight::CallHeaders& incoming_headers,
#endif
		std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override
	{
		std::string databaseName("postgres");
#if ARROW_VERSION_MAJOR >= 13
		const auto& incomingHeaders = context.incoming_headers();
#else
		const auto& incomingHeaders = incoming_headers;
#endif
		auto databaseHeader = incomingHeaders.find("x-flight-sql-database");
		if (databaseHeader != incomingHeaders.end())
		{
			databaseName = databaseHeader->second;
		}
		auto authorizationHeader = incomingHeaders.find("authorization");
		if (authorizationHeader == incomingHeaders.end())
		{
			return arrow::flight::MakeFlightError(
				arrow::flight::FlightStatusCode::Unauthenticated,
				"No authorization header");
		}
		// Split "<type> <credentials>" at the first space. The previous
		// getline()+tellg() approach returned -1 from tellg() once the stream
		// hit EOF (e.g. a bare "Basic" or "Bearer" header), and substr(-1)
		// then threw std::out_of_range out of StartCall().
		auto value = authorizationHeader->second;
		const auto separator = value.find(' ');
		const std::string type(value.substr(0, separator));
		const auto credentials =
			value.substr(separator == std::string::npos ? value.size() : separator + 1);
		if (type == "Basic")
		{
			std::stringstream decodedStream(arrow::util::base64_decode(credentials));
			std::string userName("");
			std::string password("");
			std::getline(decodedStream, userName, ':');
			std::getline(decodedStream, password);
#if ARROW_VERSION_MAJOR >= 13
			const auto& clientAddress = context.peer();
#else
			// 192.0.2.1 is one of reserved IPv4 addresses for documentation.
			std::string clientAddress("ipv4:192.0.2.1:2929");
#endif
			auto sessionIDResult =
				proxy_->connect(databaseName, userName, password, clientAddress);
			if (!sessionIDResult.status().ok())
			{
				return arrow::flight::MakeFlightError(
					arrow::flight::FlightStatusCode::Unauthenticated,
					sessionIDResult.status().ToString());
			}
			auto sessionID = *sessionIDResult;
			*middleware = std::make_shared<HeaderAuthServerMiddleware>(sessionID);
			return arrow::Status::OK();
		}
		else if (type == "Bearer")
		{
			const std::string sessionIDString(credentials);
			// Session IDs are plain decimal numbers. Reject anything else up
			// front: std::strtoull() also accepts leading whitespace and a sign,
			// and "-1" would wrap to a huge ID.
			if (sessionIDString.empty() ||
			    sessionIDString.find_first_not_of("0123456789") != std::string::npos)
			{
				return arrow::flight::MakeFlightError(
					arrow::flight::FlightStatusCode::Unauthorized,
					std::string("invalid Bearer token"));
			}
			const uint64_t sessionID = std::strtoull(sessionIDString.c_str(), nullptr, 10);
			if (!proxy_->is_valid_session(sessionID))
			{
				return arrow::flight::MakeFlightError(
					arrow::flight::FlightStatusCode::Unauthorized,
					std::string("invalid Bearer token"));
			}
			*middleware = std::make_shared<HeaderAuthServerMiddleware>(sessionID);
			return arrow::Status::OK();
		}
		else
		{
			return arrow::flight::MakeFlightError(
				arrow::flight::FlightStatusCode::Unauthenticated,
				std::string("authorization header must use Basic or Bearer: <") + type +
					std::string(">"));
		}
	}

private:
	Proxy* proxy_;
};
class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
public:
explicit FlightSQLServer(Proxy* proxy)
: arrow::flight::sql::FlightSqlServerBase(), proxy_(proxy)
{
}
~FlightSQLServer() override {}
arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQuery& command,
const arrow::flight::FlightDescriptor& descriptor) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& query = command.query;
ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->select(sessionID, query));
ARROW_ASSIGN_OR_RAISE(auto ticket,
arrow::flight::sql::CreateStatementQueryTicket(query));
std::vector<arrow::flight::FlightEndpoint> endpoints{
arrow::flight::FlightEndpoint{arrow::flight::Ticket{std::move(ticket)}, {}}};
ARROW_ASSIGN_OR_RAISE(
auto result,
arrow::flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1));
return std::make_unique<arrow::flight::FlightInfo>(result);
}
arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> DoGetStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQueryTicket& command) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->read(sessionID));
return std::make_unique<arrow::flight::RecordBatchStream>(reader);
}
arrow::Result<int64_t> DoPutCommandStatementUpdate(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementUpdate& command) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& query = command.query;
return proxy_->update(sessionID, query);
}
arrow::Result<arrow::flight::sql::ActionCreatePreparedStatementResult>
CreatePreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::ActionCreatePreparedStatementRequest& request) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& query = request.query;
return proxy_->prepare(sessionID, query);
}
arrow::Status ClosePreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::ActionClosePreparedStatementRequest& request) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& handle = request.prepared_statement_handle;
return proxy_->close_prepared_statement(sessionID, handle);
}
arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>>
GetFlightInfoPreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command,
const arrow::flight::FlightDescriptor& descriptor) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& handle = command.prepared_statement_handle;
ARROW_ASSIGN_OR_RAISE(auto schema,
proxy_->select_prepared_statement(sessionID, handle));
ARROW_ASSIGN_OR_RAISE(auto ticket,
arrow::flight::sql::CreateStatementQueryTicket(handle));
std::vector<arrow::flight::FlightEndpoint> endpoints{
arrow::flight::FlightEndpoint{arrow::flight::Ticket{std::move(ticket)}, {}}};
ARROW_ASSIGN_OR_RAISE(
auto result,
arrow::flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1));
return std::make_unique<arrow::flight::FlightInfo>(result);
}
arrow::Status DoPutPreparedStatementQuery(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command,
arrow::flight::FlightMessageReader* reader,
arrow::flight::FlightMetadataWriter* writer) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& handle = command.prepared_statement_handle;
return proxy_->set_parameters(sessionID, handle, reader, writer);
}
arrow::Result<int64_t> DoPutPreparedStatementUpdate(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementUpdate& command,
arrow::flight::FlightMessageReader* reader) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& handle = command.prepared_statement_handle;
return proxy_->update_prepared_statement(sessionID, handle, reader);
}
private:
arrow::Result<uint64_t> session_id(const arrow::flight::ServerCallContext& context)
{
auto middleware = reinterpret_cast<HeaderAuthServerMiddleware*>(
context.GetMiddleware("header-auth"));
if (!middleware)
{
return arrow::flight::MakeFlightError(
arrow::flight::FlightStatusCode::Unauthenticated, "no authorization");
}
return middleware->session_id();
}
Proxy* proxy_;
};
/// Runs the Apache Arrow Flight SQL server until SIGTERM is received.
///
/// Listens on the location parsed from the arrow_flight_sql.uri setting (URI).
/// When EnableSSL is set (presumably PostgreSQL's `ssl` setting; TODO
/// confirm), the PEM files named by ssl_cert_file / ssl_key_file are offered
/// as the TLS certificate pair, but only if both can be read and are
/// non-empty. Authentication is done by the "header-auth" middleware, so the
/// built-in Flight auth handler is a no-op. While running, SIGHUP reloads the
/// configuration and SIGUSR1 calls proxy->signaled().
///
/// @param proxy  borrowed Proxy shared by the middleware factory and the server
/// @return status of initialization failure or of the final Shutdown()
arrow::Status
afs_server_internal(Proxy* proxy)
{
	ARROW_ASSIGN_OR_RAISE(auto location, arrow::flight::Location::Parse(URI));
	arrow::flight::FlightServerOptions options(location);

	if (EnableSSL)
	{
		// Replaces `out` with the whole content of `path`; leaves `out`
		// untouched when `path` is NULL or the file cannot be opened.
		auto slurp = [](const char* path, std::string& out) {
			if (!path)
			{
				return;
			}
			std::ifstream stream(path);
			if (!stream)
			{
				return;
			}
			out = std::string(std::istreambuf_iterator<char>{stream}, {});
		};

		arrow::flight::CertKeyPair keyPair;
		slurp(ssl_cert_file, keyPair.pem_cert);
		slurp(ssl_key_file, keyPair.pem_key);
		// A half-configured pair is silently ignored (no TLS certificate added).
		const bool haveCertAndKey =
			!keyPair.pem_cert.empty() && !keyPair.pem_key.empty();
		if (haveCertAndKey)
		{
			options.tls_certificates.push_back(std::move(keyPair));
		}
	}

	options.auth_handler = std::make_unique<arrow::flight::NoOpAuthHandler>();
	options.middleware.push_back(
		{"header-auth", std::make_shared<HeaderAuthServerMiddlewareFactory>(proxy)});

	FlightSQLServer server(proxy);
	ARROW_RETURN_NOT_OK(server.Init(options));
	ereport(LOG, (errmsg("listening on %s for Apache Arrow Flight SQL", URI)));

	while (!GotSIGTERM)
	{
		WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
		ResetLatch(MyLatch);
		if (GotSIGHUP)
		{
			GotSIGHUP = false;
			ProcessConfigFile(PGC_SIGHUP);
		}
		if (GotSIGUSR1)
		{
			GotSIGUSR1 = false;
			proxy->signaled();
		}
		CHECK_FOR_INTERRUPTS();
	}

	// TODO: Use before_shmem_exit()
	// NOTE(review): the deadline is only 10 microseconds away, so in-flight
	// calls get essentially no grace period — confirm this is intended.
	auto deadline = std::chrono::system_clock::now() + std::chrono::microseconds(10);
	return server.Shutdown(&deadline);
}
} // namespace
static void
afs_executor_before_shmem_exit(int code, Datum arg)
{
// TODO: This doesn't work. We need to improve
// BackgroundWorkerInitializeConnection() failed case.
auto executor = reinterpret_cast<Executor*>(DatumGetPointer(arg));
executor->close();
delete executor;
}
/// Entry point of an executor background worker.
///
/// Creates an Executor for the ID carried in `arg` and opens it. It then
/// waits on the process latch until SIGTERM arrives, or until a wait times
/// out because the latch was not set within arrow_flight_sql.session_timeout
/// seconds. While waiting, SIGHUP reloads the configuration and SIGUSR1 calls
/// Executor::signaled(). The Executor is closed and the process exits via
/// proc_exit(0).
///
/// @param arg  Datum carrying the int64 executor ID
extern "C" void
afs_executor(Datum arg)
{
	pqsignal(SIGTERM, afs_sigterm);
	pqsignal(SIGHUP, afs_sighup);
	pqsignal(SIGUSR1, afs_sigusr1);
	BackgroundWorkerUnblockSignals();
	// Freed by afs_executor_before_shmem_exit(), which runs during proc_exit().
	auto executor = new Executor(DatumGetInt64(arg));
	before_shmem_exit(afs_executor_before_shmem_exit, PointerGetDatum(executor));
	executor->open();
	while (!GotSIGTERM)
	{
		int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH;
		// SessionTimeout is in seconds and may be as large as INT_MAX, so the
		// conversion to milliseconds is done in 64-bit arithmetic; in int
		// arithmetic it would overflow (undefined behavior) above
		// INT_MAX / 1000. WaitLatch() only accepts timeouts up to INT_MAX
		// milliseconds, so larger values are clamped to that. A negative
		// SessionTimeout (-1) still means "no timeout".
		// It is recomputed every iteration because SIGHUP may change it.
		int64_t timeoutMs = static_cast<int64_t>(SessionTimeout) * 1000;
		if (timeoutMs > INT_MAX)
		{
			timeoutMs = INT_MAX;
		}
		const long timeout = static_cast<long>(timeoutMs);
		if (timeout >= 0)
		{
			events |= WL_TIMEOUT;
		}
		auto conditions = WaitLatch(MyLatch, events, timeout, PG_WAIT_EXTENSION);
		if (conditions & WL_TIMEOUT)
		{
			// The latch was not set within the timeout: end this executor.
			break;
		}
		ResetLatch(MyLatch);
		if (GotSIGHUP)
		{
			GotSIGHUP = false;
			ProcessConfigFile(PGC_SIGHUP);
		}
		if (GotSIGUSR1)
		{
			GotSIGUSR1 = false;
			executor->signaled();
		}
		CHECK_FOR_INTERRUPTS();
	}
	// NOTE(review): proc_exit() also runs afs_executor_before_shmem_exit(),
	// which calls close() again; this assumes Executor::close() tolerates
	// being called twice.
	executor->close();
	proc_exit(0);
}
/// Entry point of the Flight SQL server background worker.
///
/// Runs afs_server_internal() with a freshly constructed Proxy. A failing
/// status is reported with ereport(ERROR); otherwise the worker exits with
/// proc_exit(0).
extern "C" void
afs_server(Datum arg)
{
	pqsignal(SIGTERM, afs_sigterm);
	pqsignal(SIGHUP, afs_sighup);
	pqsignal(SIGUSR1, afs_sigusr1);
	BackgroundWorkerUnblockSignals();
	{
		// The Proxy lives only inside this lambda, so it is destroyed before
		// the ereport(ERROR) below. That call does not return normally and
		// would otherwise skip the Proxy's destructor.
		arrow::Status status = [] {
			Proxy proxy;
			return afs_server_internal(&proxy);
		}();
		if (!status.ok())
		{
			ereport(ERROR,
			        errcode(ERRCODE_INTERNAL_ERROR),
			        errmsg("%s: server: failed: %s", Tag, status.ToString().c_str()));
		}
	}
	proc_exit(0);
}
/// Entry point of the main background worker registered by _PG_init().
///
/// Starts the server through MainProcessor::start_server(), then waits on the
/// process latch until SIGTERM. While waiting, SIGHUP reloads the
/// configuration and SIGUSR1 calls MainProcessor::process_connect_requests().
/// Before exiting it waits for the started server worker to shut down.
extern "C" void
afs_main(Datum arg)
{
	pqsignal(SIGTERM, afs_sigterm);
	pqsignal(SIGHUP, afs_sighup);
	pqsignal(SIGUSR1, afs_sigusr1);
	BackgroundWorkerUnblockSignals();
	{
		// Scoped so the MainProcessor is destroyed before proc_exit(), which
		// ends the process without unwinding this frame.
		MainProcessor mainProcessor;
		auto handle = mainProcessor.start_server();
		for (;;)
		{
			if (GotSIGTERM)
			{
				break;
			}
			WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
			ResetLatch(MyLatch);
			if (GotSIGHUP)
			{
				GotSIGHUP = false;
				ProcessConfigFile(PGC_SIGHUP);
			}
			if (GotSIGUSR1)
			{
				GotSIGUSR1 = false;
				mainProcessor.process_connect_requests();
			}
			CHECK_FOR_INTERRUPTS();
		}
		WaitForBackgroundWorkerShutdown(handle);
	}
	proc_exit(0);
}
/// Module initializer.
///
/// Does nothing unless the library is being loaded through
/// shared_preload_libraries. Otherwise it defines the arrow_flight_sql.*
/// settings, arranges the shared-memory request (via shmem_request_hook on
/// PostgreSQL 15+, directly on older versions), and registers the "main"
/// background worker whose entry point is afs_main().
extern "C" void
_PG_init(void)
{
	if (!process_shared_preload_libraries_in_progress)
	{
		return;
	}
	// PostgreSQL keeps the short/long description pointers given to
	// DefineCustom*Variable() as-is (it does not copy them), so the text must
	// outlive this call. The c_str() of a temporary std::string would dangle
	// as soon as the statement ends; a function-local static lives for the
	// whole process.
	static const std::string uriLongDescription =
		std::string("default: ") + URIDefault;
	DefineCustomStringVariable("arrow_flight_sql.uri",
	                           "Apache Arrow Flight SQL endpoint URI.",
	                           uriLongDescription.c_str(),
	                           &URI,
	                           URIDefault,
	                           PGC_POSTMASTER,
	                           0,
	                           NULL,
	                           NULL,
	                           NULL);
	DefineCustomIntVariable("arrow_flight_sql.session_timeout",
	                        "Maximum session duration in seconds.",
	                        "The default is 300 seconds. "
	                        "-1 means no timeout.",
	                        &SessionTimeout,
	                        SessionTimeoutDefault,
	                        -1,
	                        INT_MAX,
	                        PGC_USERSET,
	                        GUC_UNIT_S,
	                        NULL,
	                        NULL,
	                        NULL);
	DefineCustomIntVariable("arrow_flight_sql.max_n_rows_per_record_batch",
	                        "The maximum number of rows per record batch.",
	                        "The default is 1 * 1024 * 1024 rows.",
	                        &MaxNRowsPerRecordBatch,
	                        MaxNRowsPerRecordBatchDefault,
	                        1,
	                        INT_MAX,
	                        PGC_USERSET,
	                        0,
	                        NULL,
	                        NULL,
	                        NULL);
#ifdef PGRN_HAVE_SHMEM_REQUEST_HOOK
	PreviousShmemRequestHook = shmem_request_hook;
	shmem_request_hook = afs_shmem_request_hook;
#else
	afs_shmem_request_hook();
#endif
	// Register the main worker; it starts once the database reaches a
	// consistent state and is never restarted automatically.
	BackgroundWorker worker = {};
	snprintf(worker.bgw_name, BGW_MAXLEN, "%s: main", Tag);
	snprintf(worker.bgw_type, BGW_MAXLEN, "%s: main", Tag);
	worker.bgw_flags = BGWORKER_SHMEM_ACCESS;
	worker.bgw_start_time = BgWorkerStart_ConsistentState;
	worker.bgw_restart_time = BGW_NEVER_RESTART;
	snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
	snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_main");
	worker.bgw_main_arg = 0;
	worker.bgw_notify_pid = 0;
	RegisterBackgroundWorker(&worker);
}