blob: bb8b5fc037e6402805908ff58af67dedd333347c [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
extern "C"
{
#include <postgres.h>
#include <access/xact.h>
#include <executor/spi.h>
#include <fmgr.h>
#include <miscadmin.h>
#include <postmaster/bgworker.h>
#include <storage/ipc.h>
#include <storage/latch.h>
#include <storage/lwlock.h>
#include <storage/procsignal.h>
#include <storage/shmem.h>
#include <utils/backend_status.h>
#include <utils/dsa.h>
#include <utils/guc.h>
#include <utils/snapmgr.h>
#include <utils/wait_event.h>
}
#undef Abs
#include <arrow/buffer.h>
#include <arrow/builder.h>
#include <arrow/flight/server_middleware.h>
#include <arrow/flight/sql/server.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/reader.h>
#include <arrow/ipc/writer.h>
#include <arrow/table_builder.h>
#include <arrow/util/base64.h>
#include <condition_variable>
#include <sstream>
#ifdef __GNUC__
#define AFS_FUNC __PRETTY_FUNCTION__
#else
#define AFS_FUNC __func__
#endif
#ifdef AFS_DEBUG
#define P(...) ereport(DEBUG5, errmsg_internal(__VA_ARGS__))
#else
#define P(...)
#endif
extern "C"
{
PG_MODULE_MAGIC;
extern PGDLLEXPORT void _PG_init(void);
extern PGDLLEXPORT void afs_executor(Datum datum) pg_attribute_noreturn();
extern PGDLLEXPORT void afs_server(Datum datum) pg_attribute_noreturn();
extern PGDLLEXPORT void afs_main(Datum datum) pg_attribute_noreturn();
}
namespace {
static const char* LibraryName = "arrow_flight_sql";
static const char* SharedDataName = "arrow-flight-sql: shared data";
static const char* Tag = "arrow-flight-sql";
static const char* URIDefault = "grpc://127.0.0.1:15432";
static char* URI;
static const int SessionTimeoutDefault = 300;
static int SessionTimeout;
static const int MaxNRowsPerRecordBatchDefault = 1 * 1024 * 1024;
static int MaxNRowsPerRecordBatch;
static volatile sig_atomic_t GotSIGTERM = false;
void afs_sigterm(SIGNAL_ARGS)
{
auto errnoSaved = errno;
GotSIGTERM = true;
SetLatch(MyLatch);
errno = errnoSaved;
}
static volatile sig_atomic_t GotSIGHUP = false;
void afs_sighup(SIGNAL_ARGS)
{
auto errnoSaved = errno;
GotSIGHUP = true;
SetLatch(MyLatch);
errno = errnoSaved;
}
static volatile sig_atomic_t GotSIGUSR1 = false;
void afs_sigusr1(SIGNAL_ARGS)
{
procsignal_sigusr1_handler(postgres_signal_arg);
auto errnoSaved = errno;
GotSIGUSR1 = true;
SetLatch(MyLatch);
errno = errnoSaved;
}
static shmem_request_hook_type PreviousShmemRequestHook = nullptr;
static const char* LWLockTrancheName = "arrow-flight-sql: lwlock tranche";
void
afs_shmem_request_hook(void)
{
if (PreviousShmemRequestHook)
PreviousShmemRequestHook();
RequestNamedLWLockTranche(LWLockTrancheName, 1);
}
struct ConnectData {
dsa_pointer databaseName;
dsa_pointer userName;
dsa_pointer password;
};
struct SharedRingBufferData {
dsa_pointer pointer;
size_t total;
size_t head;
size_t tail;
};
class SharedRingBuffer {
public:
static void initialize_data(SharedRingBufferData* data)
{
data->pointer = InvalidDsaPointer;
data->total = 0;
data->head = 0;
data->tail = 0;
}
SharedRingBuffer(SharedRingBufferData* data, dsa_area* area)
: data_(data), area_(area)
{
}
void allocate(size_t total)
{
data_->pointer = dsa_allocate(area_, total);
data_->total = total;
data_->head = 0;
data_->tail = 0;
}
void free()
{
dsa_free(area_, data_->pointer);
initialize_data(data_);
}
size_t size() const
{
if (data_->head <= data_->tail)
{
return data_->tail - data_->head;
}
else
{
return (data_->total - data_->head) + data_->tail;
}
}
size_t rest_size() const { return data_->total - size() - 1; }
size_t write(const void* data, size_t n)
{
P("%s: %s: before: (%d:%d) %d", Tag, AFS_FUNC, data_->head, data_->tail, n);
if (rest_size() == 0)
{
P("%s: %s: after: no space: (%d:%d) %d:0",
Tag,
AFS_FUNC,
data_->head,
data_->tail,
n);
return 0;
}
size_t writtenSize = 0;
auto output = address();
if (data_->head <= data_->tail)
{
auto restSize = data_->total - data_->tail;
if (data_->head == 0)
{
restSize--;
}
const auto firstHalfWriteSize = std::min(n, restSize);
P("%s: %s: first half: (%d:%d) %d:%d",
Tag,
AFS_FUNC,
data_->head,
data_->tail,
n,
firstHalfWriteSize);
memcpy(output + data_->tail, data, firstHalfWriteSize);
data_->tail = (data_->tail + firstHalfWriteSize) % data_->total;
n -= firstHalfWriteSize;
writtenSize += firstHalfWriteSize;
}
if (n > 0 && rest_size() > 0)
{
const auto lastHalfWriteSize = std::min(n, data_->head - data_->tail - 1);
P("%s: %s: last half: (%d:%d) %d:%d",
Tag,
AFS_FUNC,
data_->head,
data_->tail,
n,
lastHalfWriteSize);
memcpy(output + data_->tail,
static_cast<const uint8_t*>(data) + writtenSize,
lastHalfWriteSize);
data_->tail += lastHalfWriteSize;
n -= lastHalfWriteSize;
writtenSize += lastHalfWriteSize;
}
P("%s: %s: after: (%d:%d) %d:%d",
Tag,
AFS_FUNC,
data_->head,
data_->tail,
n,
writtenSize);
return writtenSize;
}
size_t read(size_t n, void* output)
{
P("%s: %s: before: (%d:%d) %d", Tag, AFS_FUNC, data_->head, data_->tail, n);
size_t readSize = 0;
const auto input = address();
if (data_->head > data_->tail)
{
const auto firstHalfReadSize = std::min(n, data_->total - data_->head);
P("%s: %s: first half: (%d:%d) %d:%d",
Tag,
AFS_FUNC,
data_->head,
data_->tail,
n,
firstHalfReadSize);
memcpy(output, input + data_->head, firstHalfReadSize);
data_->head = (data_->head + firstHalfReadSize) % data_->total;
n -= firstHalfReadSize;
readSize += firstHalfReadSize;
}
if (n > 0 && data_->head != data_->tail)
{
const auto lastHalfReadSize = std::min(n, data_->tail - data_->head);
P("%s: %s: last half: (%d:%d) %d:%d",
Tag,
AFS_FUNC,
data_->head,
data_->tail,
n,
lastHalfReadSize);
memcpy(static_cast<uint8_t*>(output) + readSize,
input + data_->head,
lastHalfReadSize);
data_->head += lastHalfReadSize;
n -= lastHalfReadSize;
readSize += lastHalfReadSize;
}
P("%s: %s: after: (%d:%d) %d:%d",
Tag,
AFS_FUNC,
data_->head,
data_->tail,
n,
readSize);
return readSize;
}
private:
SharedRingBufferData* data_;
dsa_area* area_;
uint8_t* address()
{
return static_cast<uint8_t*>(dsa_get_address(area_, data_->pointer));
}
};
struct ExecuteData {
dsa_pointer query;
SharedRingBufferData bufferData;
};
struct SharedData {
dsa_handle handle;
pid_t executorPID;
pid_t serverPID;
pid_t mainPID;
ConnectData connectData;
ExecuteData executeData;
};
class Processor {
public:
Processor(const char* tag)
: tag_(tag),
sharedData_(nullptr),
area_(nullptr),
lock_(),
mutex_(),
conditionVariable_()
{
}
virtual ~Processor() { dsa_detach(area_); }
const char* tag() { return tag_; }
SharedRingBuffer create_shared_ring_buffer()
{
return SharedRingBuffer(&(sharedData_->executeData.bufferData), area_);
}
void lock_acquire(LWLockMode mode) { LWLockAcquire(lock_, LW_EXCLUSIVE); }
void lock_release() { LWLockRelease(lock_); }
void wait_executor_written(SharedRingBuffer* buffer)
{
if (ARROW_PREDICT_FALSE(sharedData_->executorPID == InvalidPid))
{
ereport(ERROR,
errcode(ERRCODE_INTERNAL_ERROR),
errmsg("%s: %s: executor isn't alive", Tag, tag_));
}
P("%s: %s: %s: kill executor: %d", Tag, tag_, AFS_FUNC, sharedData_->executorPID);
kill(sharedData_->executorPID, SIGUSR1);
auto size = buffer->size();
std::unique_lock<std::mutex> lock(mutex_);
conditionVariable_.wait(lock, [&] {
P("%s: %s: %s: wait: write: %d:%d",
Tag,
tag_,
AFS_FUNC,
buffer->size(),
size);
return buffer->size() != size;
});
}
void wait_server_read(SharedRingBuffer* buffer)
{
if (ARROW_PREDICT_FALSE(sharedData_->serverPID == InvalidPid))
{
ereport(ERROR,
errcode(ERRCODE_INTERNAL_ERROR),
errmsg("%s: %s: server isn't alive", Tag, tag_));
}
P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC, sharedData_->serverPID);
kill(sharedData_->serverPID, SIGUSR1);
auto restSize = buffer->rest_size();
while (true)
{
int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH;
WaitLatch(MyLatch, events, -1, PG_WAIT_EXTENSION);
if (GotSIGTERM)
{
break;
}
ResetLatch(MyLatch);
if (GotSIGUSR1)
{
GotSIGUSR1 = false;
P("%s: %s: %s: wait: read: %d:%d",
Tag,
tag_,
AFS_FUNC,
buffer->rest_size(),
restSize);
if (buffer->rest_size() != restSize)
{
break;
}
}
CHECK_FOR_INTERRUPTS();
}
}
void signaled()
{
P("%s: %s: signaled: before", Tag, tag_);
conditionVariable_.notify_all();
P("%s: %s: signaled: after", Tag, tag_);
}
protected:
const char* tag_;
SharedData* sharedData_;
dsa_area* area_;
LWLock* lock_;
std::mutex mutex_;
std::condition_variable conditionVariable_;
};
class SharedRingBufferInputStream : public arrow::io::InputStream {
public:
SharedRingBufferInputStream(Processor* processor)
: arrow::io::InputStream(), processor_(processor), position_(0), is_open_(true)
{
}
arrow::Status Close() override
{
is_open_ = false;
return arrow::Status::OK();
}
bool closed() const override { return !is_open_; }
arrow::Result<int64_t> Tell() const override { return position_; }
arrow::Result<int64_t> Read(int64_t nBytes, void* out) override
{
if (ARROW_PREDICT_FALSE(!is_open_))
{
return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() +
": SharedRingBufferInputStream is closed");
}
auto buffer = std::move(processor_->create_shared_ring_buffer());
size_t rest = static_cast<size_t>(nBytes);
while (true)
{
processor_->lock_acquire(LW_EXCLUSIVE);
auto readBytes = buffer.read(rest, out);
processor_->lock_release();
position_ += readBytes;
rest -= readBytes;
out = static_cast<uint8_t*>(out) + readBytes;
if (ARROW_PREDICT_TRUE(rest == 0))
{
break;
}
processor_->wait_executor_written(&buffer);
}
return nBytes;
}
arrow::Result<std::shared_ptr<arrow::Buffer>> Read(int64_t nBytes) override
{
ARROW_ASSIGN_OR_RAISE(auto buffer, arrow::AllocateResizableBuffer(nBytes));
ARROW_ASSIGN_OR_RAISE(auto readBytes, Read(nBytes, buffer->mutable_data()));
ARROW_RETURN_NOT_OK(buffer->Resize(readBytes, false));
buffer->ZeroPadding();
return std::move(buffer);
}
private:
Processor* processor_;
int64_t position_;
bool is_open_;
};
class SharedRingBufferOutputStream : public arrow::io::OutputStream {
public:
SharedRingBufferOutputStream(Processor* processor)
: arrow::io::OutputStream(), processor_(processor), position_(0), is_open_(true)
{
}
arrow::Status Close() override
{
is_open_ = false;
return arrow::Status::OK();
}
bool closed() const override { return !is_open_; }
arrow::Result<int64_t> Tell() const override { return position_; }
arrow::Status Write(const void* data, int64_t nBytes) override
{
if (ARROW_PREDICT_FALSE(!is_open_))
{
return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() +
": SharedRingBufferOutputStream is closed");
}
if (ARROW_PREDICT_TRUE(nBytes > 0))
{
auto buffer = std::move(processor_->create_shared_ring_buffer());
size_t rest = static_cast<size_t>(nBytes);
while (true)
{
processor_->lock_acquire(LW_EXCLUSIVE);
auto writtenSize = buffer.write(data, rest);
processor_->lock_release();
position_ += writtenSize;
rest -= writtenSize;
data = static_cast<const uint8_t*>(data) + writtenSize;
if (ARROW_PREDICT_TRUE(rest == 0))
{
break;
}
processor_->wait_server_read(&buffer);
}
}
return arrow::Status::OK();
}
using arrow::io::OutputStream::Write;
private:
Processor* processor_;
int64_t position_;
bool is_open_;
};
class WorkerProcessor : public Processor {
public:
explicit WorkerProcessor(const char* tag) : Processor(tag)
{
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
bool found;
auto sharedData = static_cast<SharedData*>(
ShmemInitStruct(SharedDataName, sizeof(SharedData), &found));
if (!found)
{
LWLockRelease(AddinShmemInitLock);
ereport(ERROR,
errcode(ERRCODE_INTERNAL_ERROR),
errmsg("%s: %s: shared data isn't created yet", Tag, tag_));
}
auto area = dsa_attach(sharedData->handle);
lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);
LWLockRelease(AddinShmemInitLock);
sharedData_ = sharedData;
area_ = area;
}
};
class Executor : public WorkerProcessor {
public:
explicit Executor() : WorkerProcessor("executor") {}
void open()
{
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": opening").c_str());
LWLockAcquire(lock_, LW_EXCLUSIVE);
BackgroundWorkerInitializeConnection(
static_cast<const char*>(
dsa_get_address(area_, sharedData_->connectData.databaseName)),
static_cast<const char*>(
dsa_get_address(area_, sharedData_->connectData.userName)),
0);
unsetSharedString(sharedData_->connectData.databaseName);
unsetSharedString(sharedData_->connectData.userName);
unsetSharedString(sharedData_->connectData.password);
{
SharedRingBuffer buffer(&(sharedData_->executeData.bufferData), area_);
// TODO: Customizable.
buffer.allocate(1L * 1024L * 1024L);
}
LWLockRelease(lock_);
StartTransactionCommand();
SPI_connect();
pgstat_report_activity(STATE_IDLE, NULL);
}
void close()
{
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": closing").c_str());
SPI_finish();
CommitTransactionCommand();
LWLockAcquire(lock_, LW_EXCLUSIVE);
{
SharedRingBuffer buffer(&(sharedData_->executeData.bufferData), area_);
buffer.free();
}
sharedData_->executorPID = InvalidPid;
LWLockRelease(lock_);
pgstat_report_activity(STATE_IDLE, NULL);
}
void signaled()
{
P("%s: %s: signaled: before: %d", Tag, tag_, sharedData_->executeData.query);
P("signaled: before: %d", sharedData_->executeData.query);
if (DsaPointerIsValid(sharedData_->executeData.query))
{
execute();
}
else
{
Processor::signaled();
}
P("%s: %s: signaled: after: %d", Tag, tag_, sharedData_->executeData.query);
}
private:
void unsetSharedString(dsa_pointer& pointer)
{
if (!DsaPointerIsValid(pointer))
{
return;
}
dsa_free(area_, pointer);
pointer = InvalidDsaPointer;
}
void execute()
{
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": executing").c_str());
PushActiveSnapshot(GetTransactionSnapshot());
LWLockAcquire(lock_, LW_EXCLUSIVE);
auto query = static_cast<const char*>(
dsa_get_address(area_, sharedData_->executeData.query));
SetCurrentStatementStartTimestamp();
P("%s: %s: execute: %s", Tag, tag_, query);
auto result = SPI_execute(query, true, 0);
dsa_free(area_, sharedData_->executeData.query);
sharedData_->executeData.query = InvalidDsaPointer;
LWLockRelease(lock_);
if (result == SPI_OK_SELECT)
{
pgstat_report_activity(STATE_RUNNING,
(std::string(Tag) + ": writing").c_str());
auto status = write();
if (!status.ok())
{
ereport(ERROR,
errcode(ERRCODE_INTERNAL_ERROR),
errmsg("%s: %s: failed to write", Tag, tag_));
}
}
PopActiveSnapshot();
if (sharedData_->serverPID != InvalidPid)
{
P("%s: %s: kill server: %s", Tag, tag_, sharedData_->serverPID);
kill(sharedData_->serverPID, SIGUSR1);
}
pgstat_report_activity(STATE_IDLE, NULL);
}
arrow::Status write()
{
SharedRingBufferOutputStream output(this);
std::vector<std::shared_ptr<arrow::Field>> fields;
for (int i = 0; i < SPI_tuptable->tupdesc->natts; ++i)
{
auto attribute = TupleDescAttr(SPI_tuptable->tupdesc, i);
std::shared_ptr<arrow::DataType> type;
switch (attribute->atttypid)
{
case INT4OID:
type = arrow::int32();
break;
default:
return arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
attribute->atttypid);
}
fields.push_back(
arrow::field(NameStr(attribute->attname), type, !attribute->attnotnull));
}
auto schema = arrow::schema(fields);
ARROW_ASSIGN_OR_RAISE(
auto builder,
arrow::RecordBatchBuilder::Make(schema, arrow::default_memory_pool()));
auto option = arrow::ipc::IpcWriteOptions::Defaults();
option.emit_dictionary_deltas = true;
// Write schema only stream format data to return only schema.
ARROW_ASSIGN_OR_RAISE(auto writer,
arrow::ipc::MakeStreamWriter(&output, schema, option));
// Build an empty record batch to write schema.
ARROW_ASSIGN_OR_RAISE(auto recordBatch, builder->Flush());
P("%s: %s: write: schema: WriteRecordBatch", Tag, tag_);
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
P("%s: %s: write: schema: Close", Tag, tag_);
ARROW_RETURN_NOT_OK(writer->Close());
// Write another stream format data with record batches.
ARROW_ASSIGN_OR_RAISE(writer,
arrow::ipc::MakeStreamWriter(&output, schema, option));
bool needLastFlush = false;
for (uint64_t iTuple = 0; iTuple < SPI_processed; ++iTuple)
{
P("%s: %s: write: data: record batch: %d/%d",
Tag,
tag_,
iTuple,
SPI_processed);
for (uint64_t iAttribute = 0; iAttribute < SPI_tuptable->tupdesc->natts;
++iAttribute)
{
P("%s: %s: write: data: record batch: %d/%d: %d/%d",
Tag,
tag_,
iTuple,
SPI_processed,
iAttribute,
SPI_tuptable->tupdesc->natts);
bool isNull;
auto datum = SPI_getbinval(SPI_tuptable->vals[iTuple],
SPI_tuptable->tupdesc,
iAttribute + 1,
&isNull);
if (isNull)
{
auto arrayBuilder = builder->GetField(iAttribute);
ARROW_RETURN_NOT_OK(arrayBuilder->AppendNull());
}
else
{
auto arrayBuilder =
builder->GetFieldAs<arrow::Int32Builder>(iAttribute);
ARROW_RETURN_NOT_OK(arrayBuilder->Append(DatumGetInt32(datum)));
}
}
if (((iTuple + 1) % MaxNRowsPerRecordBatch) == 0) {
ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush());
P("%s: %s: write: data: WriteRecordBatch: %d/%d", Tag, tag_, iTuple, SPI_processed);
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
needLastFlush = false;
} else {
needLastFlush = true;
}
}
if (needLastFlush) {
ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush());
P("%s: %s: write: data: WriteRecordBatch", Tag, tag_);
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
}
P("%s: %s: write: data: Close", Tag, tag_);
ARROW_RETURN_NOT_OK(writer->Close());
return output.Close();
}
};
class Proxy : public WorkerProcessor {
public:
explicit Proxy() : WorkerProcessor("proxy") {}
arrow::Status connect(const std::string& databaseName,
const std::string& userName,
const std::string& password)
{
if (sharedData_->executorPID != InvalidPid)
{
return arrow::Status::OK();
}
LWLockAcquire(lock_, LW_EXCLUSIVE);
setSharedString(sharedData_->connectData.databaseName, databaseName);
setSharedString(sharedData_->connectData.userName, userName);
setSharedString(sharedData_->connectData.password, password);
LWLockRelease(lock_);
kill(sharedData_->mainPID, SIGUSR1);
{
std::unique_lock<std::mutex> lock(mutex_);
conditionVariable_.wait(
lock, [&] { return sharedData_->executorPID != InvalidPid; });
}
return arrow::Status::OK();
}
arrow::Result<std::shared_ptr<arrow::Schema>> execute(const std::string& query)
{
LWLockAcquire(lock_, LW_EXCLUSIVE);
setSharedString(sharedData_->executeData.query, query);
LWLockRelease(lock_);
if (sharedData_->executorPID != InvalidPid)
{
P("%s: %s: execute: kill executor: %d", Tag, tag_, sharedData_->executorPID);
kill(sharedData_->executorPID, SIGUSR1);
}
P("%s: %s: execute: open", Tag, tag_);
auto input = std::make_shared<SharedRingBufferInputStream>(this);
// Read schema only stream format data.
ARROW_ASSIGN_OR_RAISE(auto reader,
arrow::ipc::RecordBatchStreamReader::Open(input));
while (true)
{
std::shared_ptr<arrow::RecordBatch> recordBatch;
P("%s: %s: execute: read next", Tag, tag_);
ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
if (!recordBatch)
{
break;
}
}
P("%s: %s: execute: schema", Tag, tag_);
return reader->schema();
}
arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> read()
{
auto input = std::make_shared<SharedRingBufferInputStream>(this);
// Read another stream format data with record batches.
return arrow::ipc::RecordBatchStreamReader::Open(input);
}
private:
void setSharedString(dsa_pointer& pointer, const std::string& input)
{
if (input.empty())
{
return;
}
pointer = dsa_allocate(area_, input.size() + 1);
memcpy(dsa_get_address(area_, pointer), input.c_str(), input.size() + 1);
}
};
class MainProcessor : public Processor {
public:
MainProcessor() : Processor("main")
{
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
bool found;
auto sharedData = static_cast<SharedData*>(
ShmemInitStruct(SharedDataName, sizeof(SharedData), &found));
if (found)
{
LWLockRelease(AddinShmemInitLock);
ereport(ERROR,
errcode(ERRCODE_INTERNAL_ERROR),
errmsg("%s: %s: shared data is already created", Tag, tag_));
}
auto area = dsa_create(LWLockNewTrancheId());
sharedData->handle = dsa_get_handle(area);
sharedData->executorPID = InvalidPid;
sharedData->serverPID = InvalidPid;
sharedData->mainPID = MyProcPid;
sharedData->connectData.databaseName = InvalidDsaPointer;
sharedData->connectData.userName = InvalidDsaPointer;
sharedData->connectData.password = InvalidDsaPointer;
SharedRingBuffer::initialize_data(&(sharedData->executeData.bufferData));
lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);
LWLockRelease(AddinShmemInitLock);
sharedData_ = sharedData;
area_ = area;
}
BackgroundWorkerHandle* start_server()
{
BackgroundWorker worker = {0};
snprintf(worker.bgw_name, BGW_MAXLEN, "%s: server", Tag);
snprintf(worker.bgw_type, BGW_MAXLEN, "%s: server", Tag);
worker.bgw_flags = BGWORKER_SHMEM_ACCESS;
worker.bgw_start_time = BgWorkerStart_ConsistentState;
worker.bgw_restart_time = BGW_NEVER_RESTART;
snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_server");
worker.bgw_main_arg = 0;
worker.bgw_notify_pid = MyProcPid;
BackgroundWorkerHandle* handle;
if (!RegisterDynamicBackgroundWorker(&worker, &handle))
{
ereport(ERROR,
errcode(ERRCODE_INTERNAL_ERROR),
errmsg("%s: %s: failed to start server", Tag, tag_));
}
WaitForBackgroundWorkerStartup(handle, &(sharedData_->serverPID));
return handle;
}
void process_connect_request()
{
if (!DsaPointerIsValid(sharedData_->connectData.databaseName))
{
return;
}
BackgroundWorker worker = {0};
// TODO: Add executor ID to bgw_name
snprintf(worker.bgw_name, BGW_MAXLEN, "%s: executor", Tag);
snprintf(worker.bgw_type, BGW_MAXLEN, "%s: executor", Tag);
worker.bgw_flags = BGWORKER_SHMEM_ACCESS | BGWORKER_BACKEND_DATABASE_CONNECTION;
worker.bgw_start_time = BgWorkerStart_ConsistentState;
worker.bgw_restart_time = BGW_NEVER_RESTART;
snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_executor");
worker.bgw_main_arg = 0;
worker.bgw_notify_pid = MyProcPid;
BackgroundWorkerHandle* handle;
if (!RegisterDynamicBackgroundWorker(&worker, &handle))
{
ereport(ERROR,
errcode(ERRCODE_INTERNAL_ERROR),
errmsg("%s: %s: failed to start executor", Tag, tag_));
}
WaitForBackgroundWorkerStartup(handle, &(sharedData_->executorPID));
kill(sharedData_->serverPID, SIGUSR1);
}
};
class HeaderAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory {
public:
explicit HeaderAuthServerMiddlewareFactory(Proxy* proxy)
: arrow::flight::ServerMiddlewareFactory(), proxy_(proxy)
{
}
arrow::Status StartCall(const arrow::flight::CallInfo& info,
const arrow::flight::CallHeaders& incoming_headers,
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware)
{
std::string databaseName("postgres");
auto databaseHeader = incoming_headers.find("x-flight-sql-database");
if (databaseHeader != incoming_headers.end())
{
databaseName = databaseHeader->second;
}
std::string userName("");
std::string password("");
auto authorizationHeader = incoming_headers.find("authorization");
if (authorizationHeader != incoming_headers.end())
{
std::stringstream decodedStream(
arrow::util::base64_decode(authorizationHeader->second));
std::getline(decodedStream, userName, ':');
std::getline(decodedStream, password);
}
auto status = proxy_->connect(databaseName, userName, password);
if (status.ok())
{
return status;
}
else
{
return arrow::flight::MakeFlightError(
arrow::flight::FlightStatusCode::Unauthenticated, status.ToString());
}
}
private:
Proxy* proxy_;
};
class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
public:
explicit FlightSQLServer(Proxy* proxy)
: arrow::flight::sql::FlightSqlServerBase(), proxy_(proxy)
{
}
~FlightSQLServer() override {}
arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQuery& command,
const arrow::flight::FlightDescriptor& descriptor)
{
const auto& query = command.query;
ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->execute(query));
ARROW_ASSIGN_OR_RAISE(auto ticket,
arrow::flight::sql::CreateStatementQueryTicket(query));
std::vector<arrow::flight::FlightEndpoint> endpoints{
arrow::flight::FlightEndpoint{std::move(ticket), {}}};
ARROW_ASSIGN_OR_RAISE(
auto result,
arrow::flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1));
return std::make_unique<arrow::flight::FlightInfo>(result);
}
arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> DoGetStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQueryTicket& command)
{
ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->read());
return std::make_unique<arrow::flight::RecordBatchStream>(reader);
}
private:
Proxy* proxy_;
};
arrow::Status
afs_server_internal(Proxy* proxy)
{
ARROW_ASSIGN_OR_RAISE(auto location, arrow::flight::Location::Parse(URI));
arrow::flight::FlightServerOptions options(location);
options.middleware.push_back(
{"header-auth", std::make_shared<HeaderAuthServerMiddlewareFactory>(proxy)});
FlightSQLServer flightSQLServer(proxy);
ARROW_RETURN_NOT_OK(flightSQLServer.Init(options));
while (!GotSIGTERM)
{
WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
ResetLatch(MyLatch);
if (GotSIGHUP)
{
GotSIGHUP = false;
ProcessConfigFile(PGC_SIGHUP);
}
if (GotSIGUSR1)
{
GotSIGUSR1 = false;
proxy->signaled();
}
CHECK_FOR_INTERRUPTS();
}
// TODO: Use before_shmem_exit()
auto deadline = std::chrono::system_clock::now() + std::chrono::microseconds(10);
return flightSQLServer.Shutdown(&deadline);
}
} // namespace
extern "C" void
afs_executor(Datum arg)
{
pqsignal(SIGTERM, afs_sigterm);
pqsignal(SIGHUP, afs_sighup);
pqsignal(SIGUSR1, afs_sigusr1);
BackgroundWorkerUnblockSignals();
{
Executor executor;
executor.open();
while (!GotSIGTERM)
{
int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH;
const long timeout = SessionTimeout * 1000;
if (timeout >= 0)
{
events |= WL_TIMEOUT;
}
auto conditions = WaitLatch(MyLatch, events, timeout, PG_WAIT_EXTENSION);
if (conditions & WL_TIMEOUT)
{
break;
}
ResetLatch(MyLatch);
if (GotSIGHUP)
{
GotSIGHUP = false;
ProcessConfigFile(PGC_SIGHUP);
}
if (GotSIGUSR1)
{
GotSIGUSR1 = false;
executor.signaled();
}
CHECK_FOR_INTERRUPTS();
}
// TODO: Use before_shmem_exit()
executor.close();
}
proc_exit(0);
}
extern "C" void
afs_server(Datum arg)
{
pqsignal(SIGTERM, afs_sigterm);
pqsignal(SIGHUP, afs_sighup);
pqsignal(SIGUSR1, afs_sigusr1);
BackgroundWorkerUnblockSignals();
{
arrow::Status status;
{
Proxy proxy;
status = afs_server_internal(&proxy);
}
if (!status.ok())
{
ereport(ERROR,
errcode(ERRCODE_INTERNAL_ERROR),
errmsg("%s: server: failed: %s", Tag, status.ToString().c_str()));
}
}
proc_exit(0);
}
extern "C" void
afs_main(Datum arg)
{
pqsignal(SIGTERM, afs_sigterm);
pqsignal(SIGHUP, afs_sighup);
pqsignal(SIGUSR1, afs_sigusr1);
BackgroundWorkerUnblockSignals();
{
MainProcessor processor;
auto serverHandle = processor.start_server();
while (!GotSIGTERM)
{
WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
ResetLatch(MyLatch);
if (GotSIGHUP)
{
GotSIGHUP = false;
ProcessConfigFile(PGC_SIGHUP);
}
if (GotSIGUSR1)
{
GotSIGUSR1 = false;
processor.process_connect_request();
}
CHECK_FOR_INTERRUPTS();
}
WaitForBackgroundWorkerShutdown(serverHandle);
}
proc_exit(0);
}
extern "C" void
_PG_init(void)
{
if (!process_shared_preload_libraries_in_progress)
{
return;
}
DefineCustomStringVariable("arrow_flight_sql.uri",
"Apache Arrow Flight SQL endpoint URI.",
(std::string("default: ") + URIDefault).c_str(),
&URI,
URIDefault,
PGC_POSTMASTER,
0,
NULL,
NULL,
NULL);
DefineCustomIntVariable("arrow_flight_sql.session_timeout",
"Maximum session duration in seconds.",
"The default is 300 seconds. "
"-1 means no timeout.",
&SessionTimeout,
SessionTimeoutDefault,
-1,
INT_MAX,
PGC_USERSET,
GUC_UNIT_S,
NULL,
NULL,
NULL);
DefineCustomIntVariable("arrow_flight_sql.max_n_rows_per_record_batch",
"The maximum number of rows per record batch.",
"The default is 1 * 1024 * 1024 rows.",
&MaxNRowsPerRecordBatch,
MaxNRowsPerRecordBatchDefault,
1,
INT_MAX,
PGC_USERSET,
0,
NULL,
NULL,
NULL);
PreviousShmemRequestHook = shmem_request_hook;
shmem_request_hook = afs_shmem_request_hook;
BackgroundWorker worker = {0};
snprintf(worker.bgw_name, BGW_MAXLEN, "%s: main", Tag);
snprintf(worker.bgw_type, BGW_MAXLEN, "%s: main", Tag);
worker.bgw_flags = BGWORKER_SHMEM_ACCESS;
worker.bgw_start_time = BgWorkerStart_ConsistentState;
worker.bgw_restart_time = BGW_NEVER_RESTART;
snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_main");
worker.bgw_main_arg = 0;
worker.bgw_notify_pid = 0;
RegisterBackgroundWorker(&worker);
}