| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| extern "C" |
| { |
| #include <postgres.h> |
| |
| #include <access/xact.h> |
| #include <executor/spi.h> |
| #include <fmgr.h> |
| #include <lib/dshash.h> |
| #include <libpq/crypt.h> |
| #include <libpq/libpq-be.h> |
| #include <libpq/libpq.h> |
| #include <miscadmin.h> |
| #include <postmaster/bgworker.h> |
| #include <postmaster/postmaster.h> |
| #include <storage/ipc.h> |
| #include <storage/latch.h> |
| #include <storage/lwlock.h> |
| #include <storage/procsignal.h> |
| #include <storage/shmem.h> |
| #include <utils/backend_status.h> |
| #include <utils/builtins.h> |
| #include <utils/dsa.h> |
| #include <utils/guc.h> |
| #include <utils/memutils.h> |
| #include <utils/snapmgr.h> |
| #include <utils/timestamp.h> |
| #include <utils/wait_event.h> |
| } |
| |
// Drop the Abs macro before the Arrow headers below are included.
// NOTE(review): presumably Abs comes from the PostgreSQL headers above and
// clashes with an identifier used by Arrow -- confirm.
#ifdef Abs
#  undef Abs
#endif

// Defined when building against PG_VERSION_NUM >= 150000. It guards the
// code in this file that uses shmem_request_hook_type.
#if defined(PG_VERSION_NUM) && PG_VERSION_NUM >= 150000
#  define PGRN_HAVE_SHMEM_REQUEST_HOOK
#endif
| |
| #include <arrow/buffer.h> |
| #include <arrow/builder.h> |
| #include <arrow/flight/server_middleware.h> |
| #include <arrow/flight/sql/server.h> |
| #include <arrow/io/memory.h> |
| #include <arrow/ipc/reader.h> |
| #include <arrow/ipc/writer.h> |
| #include <arrow/table_builder.h> |
| #include <arrow/util/base64.h> |
| |
| #include <cinttypes> |
| #include <condition_variable> |
| #include <fstream> |
| #include <iterator> |
| #include <map> |
| #include <random> |
| #include <sstream> |
| #include <type_traits> |
| |
| #include <arpa/inet.h> |
| |
// AFS_FUNC: name of the enclosing function for trace messages. The
// decorated form is used when the compiler defines __GNUC__.
#if defined(__GNUC__)
#  define AFS_FUNC __PRETTY_FUNCTION__
#else
#  define AFS_FUNC __func__
#endif

// P(...): printf-style trace helper. It reports at DEBUG5 when AFS_DEBUG
// is defined and expands to nothing otherwise.
#if defined(AFS_DEBUG)
#  define P(...) ereport(DEBUG5, errmsg_internal(__VA_ARGS__))
#else
#  define P(...)
#endif
| |
| extern "C" |
| { |
| PG_MODULE_MAGIC; |
| |
| extern PGDLLEXPORT void _PG_init(void); |
| extern PGDLLEXPORT void afs_executor(Datum datum) pg_attribute_noreturn(); |
| extern PGDLLEXPORT void afs_server(Datum datum) pg_attribute_noreturn(); |
| extern PGDLLEXPORT void afs_main(Datum datum) pg_attribute_noreturn(); |
| } |
| |
| namespace { |
// Names used to identify this module.
static const char* LibraryName = "arrow_flight_sql";
// Name passed to ShmemInitStruct() for the SharedData struct.
static const char* SharedDataName = "arrow-flight-sql: shared data";
// Prefix of log/trace/error messages emitted by this file.
static const char* Tag = "arrow-flight-sql";

// Endpoint setting and its default.
// NOTE(review): URI, SessionTimeout and MaxNRowsPerRecordBatch presumably
// are bound to custom GUC variables elsewhere -- confirm.
static const char* URIDefault = "grpc://127.0.0.1:15432";
static char* URI;

// Session timeout setting and its default.
// NOTE(review): unit is presumably seconds -- confirm.
static const int SessionTimeoutDefault = 300;
static int SessionTimeout;

// Upper bound of rows per record batch and its default (1Mi rows).
static const int MaxNRowsPerRecordBatchDefault = 1 * 1024 * 1024;
static int MaxNRowsPerRecordBatch;
| |
| static volatile sig_atomic_t GotSIGTERM = false; |
| void |
| afs_sigterm(SIGNAL_ARGS) |
| { |
| auto errnoSaved = errno; |
| GotSIGTERM = true; |
| SetLatch(MyLatch); |
| errno = errnoSaved; |
| } |
| |
| static volatile sig_atomic_t GotSIGHUP = false; |
| void |
| afs_sighup(SIGNAL_ARGS) |
| { |
| auto errnoSaved = errno; |
| GotSIGHUP = true; |
| SetLatch(MyLatch); |
| errno = errnoSaved; |
| } |
| |
| static volatile sig_atomic_t GotSIGUSR1 = false; |
| void |
| afs_sigusr1(SIGNAL_ARGS) |
| { |
| procsignal_sigusr1_handler(postgres_signal_arg); |
| auto errnoSaved = errno; |
| GotSIGUSR1 = true; |
| SetLatch(MyLatch); |
| errno = errnoSaved; |
| } |
| |
| #ifdef PGRN_HAVE_SHMEM_REQUEST_HOOK |
| static shmem_request_hook_type PreviousShmemRequestHook = nullptr; |
| #endif |
| static const char* LWLockTrancheName = "arrow-flight-sql: lwlock tranche"; |
| void |
| afs_shmem_request_hook(void) |
| { |
| #ifdef PGRN_HAVE_SHMEM_REQUEST_HOOK |
| if (PreviousShmemRequestHook) |
| PreviousShmemRequestHook(); |
| #endif |
| RequestNamedLWLockTranche(LWLockTrancheName, 1); |
| } |
| |
| class ScopedMemoryContext { |
| public: |
| explicit ScopedMemoryContext(MemoryContext memoryContext) |
| : memoryContext_(memoryContext), oldMemoryContext_(nullptr) |
| { |
| oldMemoryContext_ = MemoryContextSwitchTo(memoryContext_); |
| } |
| |
| ~ScopedMemoryContext() |
| { |
| MemoryContextSwitchTo(oldMemoryContext_); |
| MemoryContextDelete(memoryContext_); |
| } |
| |
| private: |
| MemoryContext memoryContext_; |
| MemoryContext oldMemoryContext_; |
| }; |
| |
| struct ScopedTransaction { |
| ScopedTransaction() { StartTransactionCommand(); } |
| ~ScopedTransaction() { CommitTransactionCommand(); } |
| }; |
| |
| struct ScopedSnapshot { |
| ScopedSnapshot() { PushActiveSnapshot(GetTransactionSnapshot()); } |
| ~ScopedSnapshot() { PopActiveSnapshot(); } |
| }; |
| |
| struct ScopedPlan { |
| ScopedPlan(SPIPlanPtr plan) : plan_(plan) {} |
| ~ScopedPlan() { SPI_freeplan(plan_); } |
| SPIPlanPtr plan_; |
| }; |
| |
| struct SharedRingBufferData { |
| dsa_pointer pointer; |
| size_t total; |
| size_t head; |
| size_t tail; |
| }; |
| |
| // Naive ring buffer implementation. We can improve this later. |
| class SharedRingBuffer { |
| public: |
| static void initialize_data(SharedRingBufferData* data) |
| { |
| data->pointer = InvalidDsaPointer; |
| data->total = 0; |
| data->head = 0; |
| data->tail = 0; |
| } |
| |
| static void free_data(SharedRingBufferData* data, dsa_area* area) |
| { |
| if (data->pointer != InvalidDsaPointer) |
| dsa_free(area, data->pointer); |
| initialize_data(data); |
| } |
| |
| SharedRingBuffer(SharedRingBufferData* data, dsa_area* area) |
| : data_(data), area_(area) |
| { |
| } |
| |
| void allocate(size_t total) |
| { |
| data_->pointer = dsa_allocate(area_, total); |
| data_->total = total; |
| data_->head = 0; |
| data_->tail = 0; |
| } |
| |
| void free() { free_data(data_, area_); } |
| |
| size_t size() const |
| { |
| if (data_->head <= data_->tail) |
| { |
| return data_->tail - data_->head; |
| } |
| else |
| { |
| return (data_->total - data_->head) + data_->tail; |
| } |
| } |
| |
| size_t rest_size() const { return data_->total - size() - 1; } |
| |
| size_t write(const void* data, size_t n) |
| { |
| P("%s: %s: before: (%d:%d) %d", Tag, AFS_FUNC, data_->head, data_->tail, n); |
| if (rest_size() == 0) |
| { |
| P("%s: %s: after: no space: (%d:%d) %d:0", |
| Tag, |
| AFS_FUNC, |
| data_->head, |
| data_->tail, |
| n); |
| return 0; |
| } |
| |
| size_t writtenSize = 0; |
| auto output = address(); |
| if (data_->head <= data_->tail) |
| { |
| auto restSize = data_->total - data_->tail; |
| if (data_->head == 0) |
| { |
| restSize--; |
| } |
| const auto firstHalfWriteSize = std::min(n, restSize); |
| P("%s: %s: first half: (%d:%d) %d:%d", |
| Tag, |
| AFS_FUNC, |
| data_->head, |
| data_->tail, |
| n, |
| firstHalfWriteSize); |
| memcpy(output + data_->tail, data, firstHalfWriteSize); |
| data_->tail = (data_->tail + firstHalfWriteSize) % data_->total; |
| n -= firstHalfWriteSize; |
| writtenSize += firstHalfWriteSize; |
| } |
| if (n > 0 && rest_size() > 0) |
| { |
| const auto lastHalfWriteSize = std::min(n, data_->head - data_->tail - 1); |
| P("%s: %s: last half: (%d:%d) %d:%d", |
| Tag, |
| AFS_FUNC, |
| data_->head, |
| data_->tail, |
| n, |
| lastHalfWriteSize); |
| memcpy(output + data_->tail, |
| static_cast<const uint8_t*>(data) + writtenSize, |
| lastHalfWriteSize); |
| data_->tail += lastHalfWriteSize; |
| n -= lastHalfWriteSize; |
| writtenSize += lastHalfWriteSize; |
| } |
| P("%s: %s: after: (%d:%d) %d:%d", |
| Tag, |
| AFS_FUNC, |
| data_->head, |
| data_->tail, |
| n, |
| writtenSize); |
| return writtenSize; |
| } |
| |
| size_t read(size_t n, void* output) |
| { |
| P("%s: %s: before: (%d:%d) %d", Tag, AFS_FUNC, data_->head, data_->tail, n); |
| size_t readSize = 0; |
| const auto input = address(); |
| if (data_->head > data_->tail) |
| { |
| const auto firstHalfReadSize = std::min(n, data_->total - data_->head); |
| P("%s: %s: first half: (%d:%d) %d:%d", |
| Tag, |
| AFS_FUNC, |
| data_->head, |
| data_->tail, |
| n, |
| firstHalfReadSize); |
| memcpy(output, input + data_->head, firstHalfReadSize); |
| data_->head = (data_->head + firstHalfReadSize) % data_->total; |
| n -= firstHalfReadSize; |
| readSize += firstHalfReadSize; |
| } |
| if (n > 0 && data_->head != data_->tail) |
| { |
| const auto lastHalfReadSize = std::min(n, data_->tail - data_->head); |
| P("%s: %s: last half: (%d:%d) %d:%d", |
| Tag, |
| AFS_FUNC, |
| data_->head, |
| data_->tail, |
| n, |
| lastHalfReadSize); |
| memcpy(static_cast<uint8_t*>(output) + readSize, |
| input + data_->head, |
| lastHalfReadSize); |
| data_->head += lastHalfReadSize; |
| n -= lastHalfReadSize; |
| readSize += lastHalfReadSize; |
| } |
| P("%s: %s: after: (%d:%d) %d:%d", |
| Tag, |
| AFS_FUNC, |
| data_->head, |
| data_->tail, |
| n, |
| readSize); |
| return readSize; |
| } |
| |
| private: |
| SharedRingBufferData* data_; |
| dsa_area* area_; |
| |
| uint8_t* address() |
| { |
| return static_cast<uint8_t*>(dsa_get_address(area_, data_->pointer)); |
| } |
| }; |
| |
// Request kind stored in SessionData::action. Executor::signaled() reads
// it and resets it to None.
enum class Action
{
  None,
  Select,
  Update,
  Prepare,
  ClosePreparedStatement,
  SetParameters,
  SelectPreparedStatement,
  UpdatePreparedStatement,
};
| |
| const char* |
| action_name(Action action) |
| { |
| switch (action) |
| { |
| case Action::None: |
| return "Action::None"; |
| case Action::Select: |
| return "Action::Select"; |
| case Action::Update: |
| return "Action::Update"; |
| case Action::Prepare: |
| return "Action::Prepare"; |
| case Action::ClosePreparedStatement: |
| return "Action::ClosePreparedStatement"; |
| case Action::SetParameters: |
| return "Action::SetParameters"; |
| case Action::SelectPreparedStatement: |
| return "Action::SelectPreparedStatement"; |
| case Action::UpdatePreparedStatement: |
| return "Action::UpdatePreparedStatement"; |
| default: |
| return "Action::Unknown"; |
| } |
| } |
| |
| void |
| dsa_pointer_set_string(dsa_pointer& pointer, dsa_area* area, const std::string& input) |
| { |
| if (DsaPointerIsValid(pointer)) |
| { |
| dsa_free(area, pointer); |
| pointer = InvalidDsaPointer; |
| } |
| if (input.empty()) |
| { |
| return; |
| } |
| pointer = dsa_allocate(area, input.size() + 1); |
| memcpy(dsa_get_address(area, pointer), input.c_str(), input.size() + 1); |
| } |
| |
| // Put only data (don't add methods) to use with dshash. |
| struct SessionData { |
| uint64_t id; |
| dsa_pointer errorMessage; |
| pid_t executorPID; |
| bool initialized; |
| dsa_pointer databaseName; |
| dsa_pointer userName; |
| dsa_pointer password; |
| dsa_pointer clientAddress; |
| Action action; |
| dsa_pointer selectQuery; |
| dsa_pointer updateQuery; |
| int64_t nUpdatedRecords; |
| dsa_pointer prepareQuery; |
| dsa_pointer preparedStatementHandle; |
| SharedRingBufferData bufferData; |
| }; |
| |
| void |
| session_data_initialize(SessionData* session, |
| dsa_area* area, |
| const std::string& databaseName, |
| const std::string& userName, |
| const std::string& password, |
| const std::string& clientAddress) |
| { |
| session->errorMessage = InvalidDsaPointer; |
| session->executorPID = InvalidPid; |
| session->initialized = false; |
| dsa_pointer_set_string(session->databaseName, area, databaseName); |
| dsa_pointer_set_string(session->userName, area, userName); |
| dsa_pointer_set_string(session->password, area, password); |
| dsa_pointer_set_string(session->clientAddress, area, clientAddress); |
| session->action = Action::None; |
| session->selectQuery = InvalidDsaPointer; |
| session->updateQuery = InvalidDsaPointer; |
| session->nUpdatedRecords = -1; |
| session->prepareQuery = InvalidDsaPointer; |
| session->preparedStatementHandle = InvalidDsaPointer; |
| SharedRingBuffer::initialize_data(&(session->bufferData)); |
| } |
| |
| void |
| session_data_finalize(SessionData* session, dsa_area* area) |
| { |
| if (DsaPointerIsValid(session->errorMessage)) |
| dsa_free(area, session->errorMessage); |
| if (DsaPointerIsValid(session->databaseName)) |
| dsa_free(area, session->databaseName); |
| if (DsaPointerIsValid(session->userName)) |
| dsa_free(area, session->userName); |
| if (DsaPointerIsValid(session->password)) |
| dsa_free(area, session->password); |
| if (DsaPointerIsValid(session->selectQuery)) |
| dsa_free(area, session->selectQuery); |
| if (DsaPointerIsValid(session->updateQuery)) |
| dsa_free(area, session->updateQuery); |
| if (DsaPointerIsValid(session->prepareQuery)) |
| dsa_free(area, session->prepareQuery); |
| if (DsaPointerIsValid(session->preparedStatementHandle)) |
| dsa_free(area, session->preparedStatementHandle); |
| SharedRingBuffer::free_data(&(session->bufferData), area); |
| } |
| |
| class SessionReleaser { |
| public: |
| explicit SessionReleaser(dshash_table* sessions, SessionData* data) |
| : sessions_(sessions), data_(data) |
| { |
| } |
| |
| ~SessionReleaser() { dshash_release_lock(sessions_, data_); } |
| |
| private: |
| dshash_table* sessions_; |
| SessionData* data_; |
| }; |
| |
| static dshash_parameters SessionsParams = { |
| sizeof(uint64_t), |
| sizeof(SessionData), |
| dshash_memcmp, |
| dshash_memhash, |
| 0, // Set later because this is determined dynamically. |
| }; |
| |
| struct SharedData { |
| int trancheID; |
| dsa_handle handle; |
| int sessionsTrancheID; |
| dshash_table_handle sessionsHandle; |
| pid_t serverPID; |
| pid_t mainPID; |
| }; |
| |
| class Processor { |
| public: |
| enum class WaitMode |
| { |
| Read, |
| Written, |
| }; |
| |
| Processor(const char* tag, bool runInPGThread) |
| : tag_(tag), |
| runInPGThread_(runInPGThread), |
| sharedData_(nullptr), |
| area_(nullptr), |
| lock_(), |
| mutex_(), |
| conditionVariable_() |
| { |
| } |
| |
| virtual ~Processor() |
| { |
| if (area_) |
| { |
| dsa_detach(area_); |
| } |
| } |
| |
| const char* tag() { return tag_; } |
| |
| void lock_acquire() { LWLockAcquire(lock_, LW_EXCLUSIVE); } |
| |
| void lock_release() { LWLockRelease(lock_); } |
| |
| SharedRingBuffer create_shared_ring_buffer(SessionData* session) |
| { |
| return SharedRingBuffer(&(session->bufferData), area_); |
| } |
| |
| arrow::Status wait(SessionData* session, SharedRingBuffer* buffer, WaitMode mode) |
| { |
| const bool read = (mode == WaitMode::Read); |
| const char* tag = read ? "wait read" : "wait written"; |
| auto peerPID = peer_pid(session); |
| auto peerName = peer_name(session); |
| |
| if (ARROW_PREDICT_FALSE(peerPID == InvalidPid)) |
| { |
| return arrow::Status::IOError( |
| Tag, ": ", tag_, ": ", tag, ": ", peerName, ": not alive"); |
| } |
| |
| P("%s: %s: %s: %s: kill: %d", Tag, tag_, tag, peerName, peerPID); |
| kill(peerPID, SIGUSR1); |
| auto get_target_size = |
| read ? [](SharedRingBuffer* buffer) { return buffer->rest_size(); } |
| : [](SharedRingBuffer* buffer) { return buffer->size(); }; |
| auto targetSize = get_target_size(buffer); |
| if (runInPGThread_) |
| { |
| while (true) |
| { |
| int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH; |
| WaitLatch(MyLatch, events, -1, PG_WAIT_EXTENSION); |
| if (GotSIGTERM) |
| { |
| break; |
| } |
| ResetLatch(MyLatch); |
| |
| if (GotSIGUSR1) |
| { |
| GotSIGUSR1 = false; |
| P("%s: %s: %s: %s: wait: %d:%d", |
| Tag, |
| tag_, |
| tag, |
| peerName, |
| get_target_size(buffer), |
| targetSize); |
| if (get_target_size(buffer) != targetSize) |
| { |
| break; |
| } |
| } |
| |
| // TODO: Convert PG error to arrow::Status. |
| CHECK_FOR_INTERRUPTS(); |
| } |
| } |
| else |
| { |
| std::unique_lock<std::mutex> lock(mutex_); |
| conditionVariable_.wait(lock, [&] { |
| P("%s: %s: %s: %s: wait: %d:%d", |
| Tag, |
| tag_, |
| tag, |
| peerName, |
| get_target_size(buffer), |
| targetSize); |
| if (INTERRUPTS_PENDING_CONDITION()) |
| { |
| return true; |
| } |
| return get_target_size(buffer) != targetSize; |
| }); |
| } |
| return arrow::Status::OK(); |
| } |
| |
| virtual void signaled() |
| { |
| P("%s: %s: signaled: before", Tag, tag_); |
| conditionVariable_.notify_all(); |
| P("%s: %s: signaled: after", Tag, tag_); |
| } |
| |
| protected: |
| void set_shared_string(dsa_pointer& pointer, const std::string& input) |
| { |
| dsa_pointer_set_string(pointer, area_, input); |
| } |
| |
| virtual pid_t peer_pid(SessionData* session) { return InvalidPid; } |
| |
| virtual const char* peer_name(SessionData* session) { return "unknown"; } |
| |
| const char* tag_; |
| bool runInPGThread_; |
| SharedData* sharedData_; |
| dsa_area* area_; |
| LWLock* lock_; |
| std::mutex mutex_; |
| std::condition_variable conditionVariable_; |
| }; |
| |
| struct ProcessorLockGuard { |
| ProcessorLockGuard(Processor* processor) : processor_(processor) |
| { |
| processor_->lock_acquire(); |
| } |
| ~ProcessorLockGuard() { processor_->lock_release(); } |
| |
| private: |
| Processor* processor_; |
| }; |
| |
| class Proxy; |
| class SharedRingBufferInputStream : public arrow::io::InputStream { |
| public: |
| SharedRingBufferInputStream(Processor* processor, SessionData* session) |
| : arrow::io::InputStream(), |
| processor_(processor), |
| session_(session), |
| position_(0), |
| is_open_(true) |
| { |
| } |
| |
| arrow::Status Close() override |
| { |
| is_open_ = false; |
| return arrow::Status::OK(); |
| } |
| |
| bool closed() const override { return !is_open_; } |
| |
| arrow::Result<int64_t> Tell() const override { return position_; } |
| |
| arrow::Result<int64_t> Read(int64_t nBytes, void* out) override; |
| |
| arrow::Result<std::shared_ptr<arrow::Buffer>> Read(int64_t nBytes) override |
| { |
| ARROW_ASSIGN_OR_RAISE(auto buffer, arrow::AllocateResizableBuffer(nBytes)); |
| ARROW_ASSIGN_OR_RAISE(auto readBytes, Read(nBytes, buffer->mutable_data())); |
| ARROW_RETURN_NOT_OK(buffer->Resize(readBytes, false)); |
| buffer->ZeroPadding(); |
| return std::move(buffer); |
| } |
| |
| private: |
| Processor* processor_; |
| SessionData* session_; |
| int64_t position_; |
| bool is_open_; |
| }; |
| |
| class Executor; |
| class SharedRingBufferOutputStream : public arrow::io::OutputStream { |
| public: |
| SharedRingBufferOutputStream(Processor* processor, SessionData* session) |
| : arrow::io::OutputStream(), |
| processor_(processor), |
| session_(session), |
| position_(0), |
| is_open_(true) |
| { |
| } |
| |
| arrow::Status Close() override |
| { |
| is_open_ = false; |
| return arrow::Status::OK(); |
| } |
| |
| bool closed() const override { return !is_open_; } |
| |
| arrow::Result<int64_t> Tell() const override { return position_; } |
| |
| arrow::Status Write(const void* data, int64_t nBytes) override; |
| |
| using arrow::io::OutputStream::Write; |
| |
| private: |
| Processor* processor_; |
| SessionData* session_; |
| int64_t position_; |
| bool is_open_; |
| }; |
| |
| class WorkerProcessor : public Processor { |
| public: |
| explicit WorkerProcessor(const char* tag, bool runInPGThread) |
| : Processor(tag, runInPGThread), sessions_(nullptr) |
| { |
| LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE); |
| bool found; |
| sharedData_ = static_cast<SharedData*>( |
| ShmemInitStruct(SharedDataName, sizeof(sharedData_), &found)); |
| if (!found) |
| { |
| LWLockRelease(AddinShmemInitLock); |
| ereport(ERROR, |
| errcode(ERRCODE_INTERNAL_ERROR), |
| errmsg("%s: %s: shared data isn't created yet", Tag, tag_)); |
| } |
| area_ = dsa_attach(sharedData_->handle); |
| SessionsParams.tranche_id = sharedData_->sessionsTrancheID; |
| sessions_ = |
| dshash_attach(area_, &SessionsParams, sharedData_->sessionsHandle, nullptr); |
| lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock); |
| LWLockRelease(AddinShmemInitLock); |
| } |
| |
| ~WorkerProcessor() override { dshash_detach(sessions_); } |
| |
| protected: |
| void delete_session(SessionData* session) |
| { |
| session_data_finalize(session, area_); |
| dshash_delete_entry(sessions_, session); |
| } |
| |
| protected: |
| dshash_table* sessions_; |
| }; |
| |
| class ArrowPGTypeConverter : public arrow::TypeVisitor { |
| public: |
| explicit ArrowPGTypeConverter() : oid_(InvalidOid) {} |
| |
| Oid oid() const { return oid_; } |
| |
| arrow::Status Visit(const arrow::Int8Type& type) |
| { |
| oid_ = INT2OID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::UInt8Type& type) |
| { |
| oid_ = INT2OID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::Int16Type& type) |
| { |
| oid_ = INT2OID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::UInt16Type& type) |
| { |
| oid_ = INT2OID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::Int32Type& type) |
| { |
| oid_ = INT4OID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::UInt32Type& type) |
| { |
| oid_ = INT4OID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::Int64Type& type) |
| { |
| oid_ = INT8OID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::UInt64Type& type) |
| { |
| oid_ = INT8OID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::FloatType& type) |
| { |
| oid_ = FLOAT4OID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::DoubleType& type) |
| { |
| oid_ = FLOAT8OID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::StringType& type) |
| { |
| oid_ = TEXTOID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::BinaryType& type) |
| { |
| oid_ = BYTEAOID; |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::TimestampType& type) |
| { |
| oid_ = TIMESTAMPOID; |
| return arrow::Status::OK(); |
| } |
| |
| private: |
| Oid oid_; |
| }; |
| |
| class ArrowPGValueConverter : public arrow::ArrayVisitor { |
| public: |
| explicit ArrowPGValueConverter(int64_t i_row, Datum& datum) |
| : i_row_(i_row), datum_(datum) |
| { |
| } |
| |
| arrow::Status Visit(const arrow::Int8Array& array) |
| { |
| datum_ = Int8GetDatum(array.Value(i_row_)); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::UInt8Array& array) |
| { |
| datum_ = UInt8GetDatum(array.Value(i_row_)); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::Int16Array& array) |
| { |
| datum_ = Int16GetDatum(array.Value(i_row_)); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::UInt16Array& array) |
| { |
| datum_ = UInt16GetDatum(array.Value(i_row_)); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::Int32Array& array) |
| { |
| datum_ = Int32GetDatum(array.Value(i_row_)); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::UInt32Array& array) |
| { |
| datum_ = UInt32GetDatum(array.Value(i_row_)); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::Int64Array& array) |
| { |
| datum_ = Int64GetDatum(array.Value(i_row_)); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::UInt64Array& array) |
| { |
| datum_ = UInt64GetDatum(array.Value(i_row_)); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::FloatArray& array) |
| { |
| datum_ = Float4GetDatum(array.Value(i_row_)); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::DoubleArray& array) |
| { |
| datum_ = Float8GetDatum(array.Value(i_row_)); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::StringArray& array) |
| { |
| auto value = array.GetView(i_row_); |
| datum_ = PointerGetDatum(cstring_to_text_with_len(value.data(), value.length())); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::BinaryArray& array) |
| { |
| auto value = array.GetView(i_row_); |
| datum_ = PointerGetDatum(cstring_to_text_with_len(value.data(), value.length())); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status Visit(const arrow::TimestampArray& array) |
| { |
| const auto unit = |
| std::static_pointer_cast<arrow::TimestampType>(array.type())->unit(); |
| Timestamp value = 0; |
| switch (unit) |
| { |
| case arrow::TimeUnit::SECOND: |
| value += array.Value(i_row_) * 1000000; |
| break; |
| case arrow::TimeUnit::MILLI: |
| value += array.Value(i_row_) * 1000; |
| break; |
| case arrow::TimeUnit::MICRO: |
| value += array.Value(i_row_); |
| break; |
| case arrow::TimeUnit::NANO: |
| value += array.Value(i_row_) / 1000; |
| break; |
| default: |
| return arrow::Status::NotImplemented("Unsupported time unit: ", unit); |
| } |
| // Arrow uses UNIX epoch (1970-01-01) but PostgreSQL uses 2000-01-01. |
| value -= (POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * USECS_PER_DAY; |
| datum_ = TimestampGetDatum(value); |
| return arrow::Status::OK(); |
| } |
| |
| private: |
| int64_t i_row_; |
| Datum& datum_; |
| }; |
| |
| class PGArrowValueConverter : public arrow::ArrayVisitor { |
| public: |
| explicit PGArrowValueConverter(Form_pg_attribute attribute) : attribute_(attribute) {} |
| |
| arrow::Result<std::shared_ptr<arrow::DataType>> convert_type() const |
| { |
| switch (attribute_->atttypid) |
| { |
| case INT2OID: |
| return arrow::int16(); |
| case INT4OID: |
| return arrow::int32(); |
| case INT8OID: |
| return arrow::int64(); |
| case FLOAT4OID: |
| return arrow::float32(); |
| case FLOAT8OID: |
| return arrow::float64(); |
| case VARCHAROID: |
| case TEXTOID: |
| return arrow::utf8(); |
| case BYTEAOID: |
| return arrow::binary(); |
| case TIMESTAMPOID: |
| return arrow::timestamp(arrow::TimeUnit::MICRO); |
| default: |
| return arrow::Status::NotImplemented("Unsupported PostgreSQL type: ", |
| attribute_->atttypid); |
| } |
| } |
| |
| arrow::Status convert_value(arrow::ArrayBuilder* builder, Datum datum) const |
| { |
| switch (attribute_->atttypid) |
| { |
| case INT2OID: |
| return static_cast<arrow::Int16Builder*>(builder)->Append( |
| DatumGetInt16(datum)); |
| case INT4OID: |
| return static_cast<arrow::Int32Builder*>(builder)->Append( |
| DatumGetInt32(datum)); |
| case INT8OID: |
| return static_cast<arrow::Int64Builder*>(builder)->Append( |
| DatumGetInt64(datum)); |
| case FLOAT4OID: |
| return static_cast<arrow::FloatBuilder*>(builder)->Append( |
| DatumGetFloat4(datum)); |
| case FLOAT8OID: |
| return static_cast<arrow::DoubleBuilder*>(builder)->Append( |
| DatumGetFloat8(datum)); |
| case VARCHAROID: |
| case TEXTOID: |
| return static_cast<arrow::StringBuilder*>(builder)->Append( |
| VARDATA_ANY(datum), VARSIZE_ANY_EXHDR(datum)); |
| case BYTEAOID: |
| return static_cast<arrow::BinaryBuilder*>(builder)->Append( |
| VARDATA_ANY(datum), VARSIZE_ANY_EXHDR(datum)); |
| case TIMESTAMPOID: |
| // Arrow uses UNIX epoch (1970-01-01) but PostgreSQL |
| // uses 2000-01-01. |
| return static_cast<arrow::TimestampBuilder*>(builder)->Append( |
| DatumGetTimestamp(datum) + |
| (POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * USECS_PER_DAY); |
| default: |
| return arrow::Status::NotImplemented("Unsupported PostgreSQL type: ", |
| attribute_->atttypid); |
| } |
| } |
| |
| private: |
| Form_pg_attribute attribute_; |
| }; |
| |
| class PreparedStatement { |
| public: |
| explicit PreparedStatement(std::string query) |
| : query_(std::move(query)), parameters_() |
| { |
| } |
| |
| ~PreparedStatement() {} |
| |
| using WriteFunc = std::add_pointer<arrow::Status(void*)>::type; |
| arrow::Status select(WriteFunc write, void* writeData) |
| { |
| for (const auto& recordBatch : parameters_) |
| { |
| SPIExecuteOptions options = {}; |
| std::vector<Oid> pgTypes; |
| ARROW_RETURN_NOT_OK(prepare(options, pgTypes, recordBatch->schema())); |
| auto plan = SPI_prepare(query_.c_str(), pgTypes.size(), pgTypes.data()); |
| ScopedPlan scopedPlan(plan); |
| ARROW_RETURN_NOT_OK( |
| execute(plan, recordBatch, options, [&]() { return write(writeData); })); |
| } |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status set_parameters(std::shared_ptr<SharedRingBufferInputStream>& input) |
| { |
| parameters_.clear(); |
| ARROW_ASSIGN_OR_RAISE(auto reader, |
| arrow::ipc::RecordBatchStreamReader::Open(input)); |
| while (true) |
| { |
| std::shared_ptr<arrow::RecordBatch> recordBatch; |
| ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch)); |
| if (!recordBatch) |
| { |
| break; |
| } |
| parameters_.push_back(std::move(recordBatch)); |
| } |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Result<int64_t> update(std::shared_ptr<SharedRingBufferInputStream>& input) |
| { |
| ARROW_ASSIGN_OR_RAISE(auto reader, |
| arrow::ipc::RecordBatchStreamReader::Open(input)); |
| SPIExecuteOptions options = {}; |
| std::vector<Oid> pgTypes; |
| ARROW_RETURN_NOT_OK(prepare(options, pgTypes, reader->schema())); |
| auto plan = SPI_prepare(query_.c_str(), pgTypes.size(), pgTypes.data()); |
| ScopedPlan scopedPlan(plan); |
| |
| int64_t nUpdatedRecords = 0; |
| while (true) |
| { |
| std::shared_ptr<arrow::RecordBatch> recordBatch; |
| ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch)); |
| if (!recordBatch) |
| { |
| break; |
| } |
| ARROW_RETURN_NOT_OK(execute(plan, recordBatch, options, [&nUpdatedRecords]() { |
| nUpdatedRecords += SPI_processed; |
| return arrow::Status::OK(); |
| })); |
| } |
| return nUpdatedRecords; |
| } |
| |
| private: |
| arrow::Status prepare_pg_types(std::vector<Oid>& pgTypes, |
| const std::shared_ptr<arrow::Schema>& schema) |
| { |
| ArrowPGTypeConverter converter; |
| for (const auto& field : schema->fields()) |
| { |
| ARROW_RETURN_NOT_OK(field->type()->Accept(&converter)); |
| pgTypes.push_back(converter.oid()); |
| } |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status prepare(SPIExecuteOptions& options, |
| std::vector<Oid>& pgTypes, |
| const std::shared_ptr<arrow::Schema>& schema) |
| { |
| if (schema->num_fields() > 0) |
| { |
| options.params = makeParamList(schema->num_fields()); |
| } |
| options.read_only = false; |
| options.tcount = 0; |
| ARROW_RETURN_NOT_OK(prepare_pg_types(pgTypes, schema)); |
| for (size_t i = 0; i < pgTypes.size(); ++i) |
| { |
| options.params->params[i].pflags = PARAM_FLAG_CONST; |
| options.params->params[i].ptype = pgTypes[i]; |
| } |
| return arrow::Status::OK(); |
| } |
| |
| template <typename OnSuccessFunc> |
| arrow::Status execute(SPIPlanPtr plan, |
| const std::shared_ptr<arrow::RecordBatch>& recordBatch, |
| SPIExecuteOptions& options, |
| OnSuccessFunc onSuccess) |
| { |
| const auto& columns = recordBatch->columns(); |
| for (int64_t i = 0; i < recordBatch->num_rows(); ++i) |
| { |
| ARROW_RETURN_NOT_OK(assign_parameters(recordBatch, i, columns, options)); |
| auto result = SPI_execute_plan_extended(plan, &options); |
| if (result <= 0) |
| { |
| return arrow::Status::Invalid("failed to run a prepared statement: ", |
| SPI_result_code_string(result), |
| ": ", |
| query_); |
| } |
| ARROW_RETURN_NOT_OK(onSuccess()); |
| } |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status assign_parameters( |
| const std::shared_ptr<arrow::RecordBatch>& recordBatch, |
| int64_t i_row, |
| const std::vector<std::shared_ptr<arrow::Array>>& columns, |
| SPIExecuteOptions& options) |
| { |
| int64_t i_column = 0; |
| for (const auto& column : columns) |
| { |
| auto param = &(options.params->params[i_column]); |
| param->isnull = column->IsNull(i_row); |
| if (!param->isnull) |
| { |
| ArrowPGValueConverter converter(i_row, param->value); |
| ARROW_RETURN_NOT_OK(column->Accept(&converter)); |
| } |
| ++i_column; |
| } |
| return arrow::Status::OK(); |
| } |
| |
| std::string query_; |
| std::vector<std::shared_ptr<arrow::RecordBatch>> parameters_; |
| }; |
| |
| class Executor : public WorkerProcessor { |
| public: |
| explicit Executor(uint64_t sessionID) |
| : WorkerProcessor("executor", true), |
| sessionID_(sessionID), |
| session_(nullptr), |
| connected_(false), |
| closed_(false), |
| nextPreparedStatementID_(1), |
| preparedStatements_() |
| { |
| } |
| |
| ~Executor() |
| { |
| if (!closed_) |
| { |
| close_internal(false); |
| } |
| } |
| |
| void open() |
| { |
| const char* tag = "open"; |
| // pg_usleep(5000000); |
| // pg_usleep(5000000); |
| pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": opening").c_str()); |
| session_ = static_cast<SessionData*>(dshash_find(sessions_, &sessionID_, false)); |
| auto databaseName = |
| static_cast<const char*>(dsa_get_address(area_, session_->databaseName)); |
| auto userName = |
| static_cast<const char*>(dsa_get_address(area_, session_->userName)); |
| auto password = |
| static_cast<const char*>(dsa_get_address(area_, session_->password)); |
| auto clientAddress = |
| static_cast<const char*>(dsa_get_address(area_, session_->clientAddress)); |
| BackgroundWorkerInitializeConnection(databaseName, userName, 0); |
| CurrentResourceOwner = ResourceOwnerCreate(nullptr, "arrow-flight-sql: Executor"); |
| if (!check_password(databaseName, userName, password, clientAddress)) |
| { |
| session_->initialized = true; |
| signal_server(tag); |
| return; |
| } |
| { |
| SharedRingBuffer buffer(&(session_->bufferData), area_); |
| // TODO: Customizable. |
| buffer.allocate(1L * 1024L * 1024L); |
| } |
| SetCurrentStatementStartTimestamp(); |
| SPI_connect(); |
| pgstat_report_activity(STATE_IDLE, NULL); |
| session_->initialized = true; |
| connected_ = true; |
| signal_server(tag); |
| } |
| |
| void close() { close_internal(true); } |
| |
| void signaled() override |
| { |
| Action action; |
| { |
| ProcessorLockGuard lock(this); |
| action = session_->action; |
| session_->action = Action::None; |
| } |
| P("%s: %s: signaled: before: %s", Tag, tag_, action_name(action)); |
| PG_TRY(); |
| { |
| switch (action) |
| { |
| case Action::Select: |
| select(); |
| break; |
| case Action::Update: |
| update(); |
| break; |
| case Action::Prepare: |
| prepare(); |
| break; |
| case Action::ClosePreparedStatement: |
| close_prepared_statement(); |
| break; |
| case Action::SetParameters: |
| set_parameters(); |
| break; |
| case Action::SelectPreparedStatement: |
| select_prepared_statement(); |
| break; |
| case Action::UpdatePreparedStatement: |
| update_prepared_statement(); |
| break; |
| default: |
| Processor::signaled(); |
| break; |
| } |
| } |
| PG_CATCH(); |
| { |
| if (session_ && !DsaPointerIsValid(session_->errorMessage)) |
| { |
| auto error = CopyErrorData(); |
| set_error_message(std::string("failed to run: ") + action_name(action) + |
| ": " + error->message, |
| "unexpected error"); |
| FreeErrorData(error); |
| } |
| pgstat_report_activity(STATE_IDLE, NULL); |
| PG_RE_THROW(); |
| } |
| PG_END_TRY(); |
| pgstat_report_activity(STATE_IDLE, NULL); |
| P("%s: %s: signaled: after: %s", Tag, tag_, action_name(action)); |
| } |
| |
| protected: |
| pid_t peer_pid(SessionData* session) override { return sharedData_->serverPID; } |
| |
| const char* peer_name(SessionData* session) override { return "server"; } |
| |
| private: |
| void signal_server(const char* tag) |
| { |
| if (sharedData_->serverPID == InvalidPid) |
| { |
| return; |
| } |
| P("%s: %s: %s: kill server: %d", Tag, tag_, tag, sharedData_->serverPID); |
| kill(sharedData_->serverPID, SIGUSR1); |
| } |
| |
| void set_error_message(const std::string& message, const char* tag) |
| { |
| if (DsaPointerIsValid(session_->errorMessage)) |
| { |
| return; |
| } |
| { |
| ProcessorLockGuard lock(this); |
| set_shared_string(session_->errorMessage, message); |
| } |
| signal_server(tag); |
| } |
| |
| void close_internal(bool unlockSession) |
| { |
| const char* tag = "close"; |
| closed_ = true; |
| pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": closing").c_str()); |
| preparedStatements_.clear(); |
| if (connected_) |
| { |
| SPI_finish(); |
| { |
| SharedRingBuffer buffer(&(session_->bufferData), area_); |
| buffer.free(); |
| } |
| delete_session(session_); |
| } |
| else |
| { |
| set_error_message("failed to connect", tag); |
| session_->initialized = true; |
| if (unlockSession) |
| { |
| dshash_release_lock(sessions_, session_); |
| } |
| signal_server(tag); |
| } |
| if (CurrentResourceOwner) |
| { |
| auto resourceOwner = CurrentResourceOwner; |
| CurrentResourceOwner = nullptr; |
| ResourceOwnerRelease( |
| resourceOwner, RESOURCE_RELEASE_BEFORE_LOCKS, false, true); |
| ResourceOwnerRelease(resourceOwner, RESOURCE_RELEASE_LOCKS, false, true); |
| ResourceOwnerRelease( |
| resourceOwner, RESOURCE_RELEASE_AFTER_LOCKS, false, true); |
| ResourceOwnerDelete(resourceOwner); |
| } |
| } |
| |
| bool check_password(const char* databaseName, |
| const char* userName, |
| const char* password, |
| const char* clientAddress) |
| { |
| const char* tag = "check password"; |
| MemoryContext memoryContext = |
| AllocSetContextCreate(CurrentMemoryContext, |
| "arrow-flight-sql: Executor::check_password()", |
| ALLOCSET_DEFAULT_SIZES); |
| ScopedMemoryContext scopedMemoryContext(memoryContext); |
| Port port = {}; |
| port.database_name = pstrdup(databaseName); |
| port.user_name = pstrdup(userName); |
| if (!fill_client_address(&port, clientAddress)) |
| { |
| return false; |
| } |
| load_hba(); |
| hba_getauthmethod(&port); |
| if (!port.hba) |
| { |
| set_error_message("failed to get auth method", tag); |
| return false; |
| } |
| switch (port.hba->auth_method) |
| { |
| case uaMD5: |
| // TODO |
| set_error_message("MD5 auth method isn't supported yet", tag); |
| return false; |
| case uaSCRAM: |
| // TODO |
| set_error_message("SCRAM auth method isn't supported yet", tag); |
| return false; |
| case uaPassword: |
| { |
| const char* logDetail = nullptr; |
| auto shadowPassword = get_role_password(port.user_name, &logDetail); |
| if (!shadowPassword) |
| { |
| set_error_message(std::string("failed to get password: ") + logDetail, |
| tag); |
| return false; |
| } |
| auto result = plain_crypt_verify( |
| port.user_name, shadowPassword, password, &logDetail); |
| if (result != STATUS_OK) |
| { |
| set_error_message( |
| std::string("failed to verify password: ") + logDetail, tag); |
| return false; |
| } |
| return true; |
| } |
| case uaTrust: |
| return true; |
| default: |
| set_error_message(std::string("unsupported auth method: ") + |
| hba_authname(port.hba->auth_method), |
| tag); |
| return false; |
| } |
| } |
| |
| bool fill_client_address(Port* port, const char* clientAddress) |
| { |
| const char* tag = "fill client address"; |
| // clientAddress: "ipv4:127.0.0.1:40468" |
| // family: "ipv4" |
| // host: "127.0.0.1" |
| // port: "40468" |
| std::stringstream clientAddressStream{std::string(clientAddress)}; |
| std::string clientFamily(""); |
| std::string clientHost(""); |
| std::string clientPort(""); |
| std::getline(clientAddressStream, clientFamily, ':'); |
| std::getline(clientAddressStream, clientHost, ':'); |
| std::getline(clientAddressStream, clientPort); |
| if (!(clientFamily == "ipv4" || clientFamily == "ipv6")) |
| { |
| set_error_message( |
| std::string("client family must be ipv4 or ipv6: ") + clientFamily, tag); |
| return false; |
| } |
| auto clientPortStart = clientPort.c_str(); |
| char* clientPortEnd = nullptr; |
| auto clientPortNumber = std::strtoul(clientPortStart, &clientPortEnd, 10); |
| if (clientPortEnd[0] != '\0') |
| { |
| set_error_message(std::string("client port is invalid: ") + clientPort, tag); |
| return false; |
| } |
| if (clientPortNumber == 0) |
| { |
| set_error_message(std::string("client port must not 0"), tag); |
| return false; |
| } |
| if (clientPortNumber > 65535) |
| { |
| set_error_message(std::string("client port is too large: ") + |
| std::to_string(clientPortNumber), |
| tag); |
| return false; |
| } |
| if (clientFamily == "ipv4") |
| { |
| auto raddr = reinterpret_cast<sockaddr_in*>(&(port->raddr.addr)); |
| port->raddr.salen = sizeof(sockaddr_in); |
| raddr->sin_family = AF_INET; |
| raddr->sin_port = htons(clientPortNumber); |
| if (inet_pton(AF_INET, clientHost.c_str(), &(raddr->sin_addr)) == 0) |
| { |
| set_error_message( |
| std::string("client IPv4 address is invalid: ") + clientHost, tag); |
| return false; |
| } |
| } |
| else if (clientFamily == "ipv6") |
| { |
| auto raddr = reinterpret_cast<sockaddr_in6*>(&(port->raddr.addr)); |
| port->raddr.salen = sizeof(sockaddr_in6); |
| raddr->sin6_family = AF_INET6; |
| raddr->sin6_port = htons(clientPortNumber); |
| raddr->sin6_flowinfo = 0; |
| if (inet_pton(AF_INET6, clientHost.c_str(), &(raddr->sin6_addr)) == 0) |
| { |
| set_error_message( |
| std::string("client IPv6 address is invalid: ") + clientHost, tag); |
| return false; |
| } |
| raddr->sin6_scope_id = 0; |
| } |
| return true; |
| } |
| |
| void select() |
| { |
| const char* tag = "select"; |
| if (!DsaPointerIsValid(session_->selectQuery)) |
| { |
| set_error_message( |
| std::string(Tag) + ": " + tag_ + ": " + tag + ": query is missing", tag); |
| return; |
| } |
| |
| pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": selecting").c_str()); |
| |
| std::string query; |
| { |
| ProcessorLockGuard lock(this); |
| query = |
| static_cast<const char*>(dsa_get_address(area_, session_->selectQuery)); |
| dsa_free(area_, session_->selectQuery); |
| session_->selectQuery = InvalidDsaPointer; |
| } |
| P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str()); |
| |
| { |
| ScopedTransaction scopedTransaction; |
| ScopedSnapshot scopedSnapshot; |
| |
| SetCurrentStatementStartTimestamp(); |
| auto result = SPI_execute(query.c_str(), true, 0); |
| |
| if (result > 0) |
| { |
| pgstat_report_activity( |
| STATE_RUNNING, (std::string(Tag) + ": " + tag + ": writing").c_str()); |
| auto status = write(tag); |
| if (status.ok()) |
| { |
| signal_server(tag); |
| } |
| else |
| { |
| set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag + |
| ": failed to write: " + status.ToString(), |
| tag); |
| } |
| } |
| else |
| { |
| set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag + |
| ": failed to run a query: <" + query + |
| ">: " + SPI_result_code_string(result), |
| tag); |
| } |
| } |
| } |
| |
| arrow::Status write(const char* tag) |
| { |
| SharedRingBufferOutputStream output(this, session_); |
| std::vector<PGArrowValueConverter> converters; |
| std::vector<std::shared_ptr<arrow::Field>> fields; |
| for (int i = 0; i < SPI_tuptable->tupdesc->natts; ++i) |
| { |
| auto attribute = TupleDescAttr(SPI_tuptable->tupdesc, i); |
| converters.emplace_back(attribute); |
| const auto& converter = converters[converters.size() - 1]; |
| ARROW_ASSIGN_OR_RAISE(auto type, converter.convert_type()); |
| fields.push_back(arrow::field( |
| NameStr(attribute->attname), std::move(type), !attribute->attnotnull)); |
| } |
| auto schema = arrow::schema(fields); |
| ARROW_ASSIGN_OR_RAISE( |
| auto builder, |
| arrow::RecordBatchBuilder::Make(schema, arrow::default_memory_pool())); |
| auto options = arrow::ipc::IpcWriteOptions::Defaults(); |
| options.emit_dictionary_deltas = true; |
| |
| // Write schema only stream format data to return only schema. |
| ARROW_ASSIGN_OR_RAISE(auto writer, |
| arrow::ipc::MakeStreamWriter(&output, schema, options)); |
| // Build an empty record batch to write schema. |
| ARROW_ASSIGN_OR_RAISE(auto recordBatch, builder->Flush()); |
| P("%s: %s: %s: write: schema: WriteRecordBatch", Tag, tag_, tag); |
| ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch)); |
| P("%s: %s: %s: write: schema: Close", Tag, tag_, tag); |
| ARROW_RETURN_NOT_OK(writer->Close()); |
| |
| // Write another stream format data with record batches. |
| ARROW_ASSIGN_OR_RAISE(writer, |
| arrow::ipc::MakeStreamWriter(&output, schema, options)); |
| bool needLastFlush = false; |
| for (uint64_t iTuple = 0; iTuple < SPI_processed; ++iTuple) |
| { |
| P("%s: %s: %s: write: data: record batch: %d/%d", |
| Tag, |
| tag_, |
| tag, |
| iTuple, |
| SPI_processed); |
| for (int iAttribute = 0; iAttribute < SPI_tuptable->tupdesc->natts; |
| ++iAttribute) |
| { |
| P("%s: %s: %s: write: data: record batch: %d/%d: %d/%d", |
| Tag, |
| tag_, |
| tag, |
| iTuple, |
| SPI_processed, |
| iAttribute, |
| SPI_tuptable->tupdesc->natts); |
| bool isNull; |
| auto datum = SPI_getbinval(SPI_tuptable->vals[iTuple], |
| SPI_tuptable->tupdesc, |
| iAttribute + 1, |
| &isNull); |
| auto arrayBuilder = builder->GetField(iAttribute); |
| if (isNull) |
| { |
| ARROW_RETURN_NOT_OK(arrayBuilder->AppendNull()); |
| } |
| else |
| { |
| ARROW_RETURN_NOT_OK( |
| converters[iAttribute].convert_value(arrayBuilder, datum)); |
| } |
| } |
| |
| if (((iTuple + 1) % MaxNRowsPerRecordBatch) == 0) |
| { |
| ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush()); |
| P("%s: %s: %s: write: data: WriteRecordBatch: %d/%d", |
| Tag, |
| tag_, |
| tag, |
| iTuple, |
| SPI_processed); |
| ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch)); |
| needLastFlush = false; |
| } |
| else |
| { |
| needLastFlush = true; |
| } |
| } |
| if (needLastFlush) |
| { |
| ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush()); |
| P("%s: %s: %s: write: data: WriteRecordBatch", Tag, tag_, tag); |
| ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch)); |
| } |
| P("%s: %s: %s, write: data: Close", Tag, tag_, tag); |
| ARROW_RETURN_NOT_OK(writer->Close()); |
| return output.Close(); |
| } |
| |
| void update() |
| { |
| const char* tag = "update"; |
| if (!DsaPointerIsValid(session_->updateQuery)) |
| { |
| set_error_message( |
| std::string(Tag) + ": " + tag_ + ": " + tag + ": query is missing", tag); |
| return; |
| } |
| |
| pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": updating").c_str()); |
| |
| std::string query; |
| { |
| ProcessorLockGuard lock(this); |
| query = |
| static_cast<const char*>(dsa_get_address(area_, session_->updateQuery)); |
| dsa_free(area_, session_->updateQuery); |
| session_->updateQuery = InvalidDsaPointer; |
| } |
| P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str()); |
| |
| { |
| ScopedTransaction scopedTransaction; |
| ScopedSnapshot scopedSnapshot; |
| |
| SetCurrentStatementStartTimestamp(); |
| auto result = SPI_execute(query.c_str(), false, 0); |
| if (result > 0) |
| { |
| session_->nUpdatedRecords = SPI_processed; |
| signal_server(tag); |
| } |
| else |
| { |
| set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag + |
| ": failed to run a query: <" + query + |
| ">: " + SPI_result_code_string(result), |
| tag); |
| } |
| } |
| } |
| |
| void prepare() |
| { |
| const char* tag = "prepare"; |
| if (!DsaPointerIsValid(session_->prepareQuery)) |
| { |
| set_error_message( |
| std::string(Tag) + ": " + tag_ + ": " + tag + ": query is missing", tag); |
| return; |
| } |
| |
| pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": preparing").c_str()); |
| |
| std::string query; |
| { |
| ProcessorLockGuard lock(this); |
| query = |
| static_cast<const char*>(dsa_get_address(area_, session_->prepareQuery)); |
| dsa_free(area_, session_->prepareQuery); |
| session_->prepareQuery = InvalidDsaPointer; |
| } |
| P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str()); |
| std::string handle(std::to_string(nextPreparedStatementID_++)); |
| preparedStatements_.insert( |
| std::make_pair(handle, PreparedStatement(std::move(query)))); |
| { |
| ProcessorLockGuard lock(this); |
| set_shared_string(session_->preparedStatementHandle, handle); |
| } |
| signal_server(tag); |
| } |
| |
| bool extract_handle(std::string& handle, const char* tag) |
| { |
| if (!DsaPointerIsValid(session_->preparedStatementHandle)) |
| { |
| set_error_message( |
| std::string(Tag) + ": " + tag_ + ": " + tag + ": handle is missing", tag); |
| return false; |
| } |
| |
| { |
| ProcessorLockGuard lock(this); |
| handle = static_cast<const char*>( |
| dsa_get_address(area_, session_->preparedStatementHandle)); |
| dsa_free(area_, session_->preparedStatementHandle); |
| session_->preparedStatementHandle = InvalidDsaPointer; |
| } |
| |
| return true; |
| } |
| |
| void close_prepared_statement() |
| { |
| const char* tag = "close prepared statement"; |
| |
| pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": " + tag).c_str()); |
| |
| std::string handle; |
| if (!extract_handle(handle, tag)) |
| { |
| return; |
| } |
| P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str()); |
| if (preparedStatements_.erase(handle) > 0) |
| { |
| signal_server(tag); |
| } |
| else |
| { |
| set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag + |
| ": nonexistent handle: <" + handle + ">", |
| tag); |
| } |
| } |
| |
| PreparedStatement* find_prepared_statement(std::string& handle, const char* tag) |
| { |
| if (!extract_handle(handle, tag)) |
| { |
| return nullptr; |
| } |
| |
| ProcessorLockGuard lock(this); |
| auto it = preparedStatements_.find(handle); |
| if (it == preparedStatements_.end()) |
| { |
| set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag + |
| ": nonexistent handle: <" + handle + ">", |
| tag); |
| return nullptr; |
| } |
| else |
| { |
| return &(it->second); |
| } |
| } |
| |
| void set_parameters() |
| { |
| const char* tag = "set parameters"; |
| |
| pgstat_report_activity(STATE_RUNNING, |
| (std::string(Tag) + ": setting parameters").c_str()); |
| std::string handle; |
| auto preparedStatement = find_prepared_statement(handle, tag); |
| P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str()); |
| |
| if (!preparedStatement) |
| { |
| return; |
| } |
| |
| auto input = std::make_shared<SharedRingBufferInputStream>(this, session_); |
| auto status = preparedStatement->set_parameters(input); |
| if (status.ok()) |
| { |
| signal_server(tag); |
| } |
| else |
| { |
| set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag + |
| ": failed to set parameters: <" + handle + |
| ">: " + status.ToString(), |
| tag); |
| } |
| } |
| |
| void select_prepared_statement() |
| { |
| const char* tag = "select prepared statement"; |
| |
| pgstat_report_activity( |
| STATE_RUNNING, (std::string(Tag) + ": selecting prepared statement").c_str()); |
| |
| std::string handle; |
| auto preparedStatement = find_prepared_statement(handle, tag); |
| P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str()); |
| |
| if (!preparedStatement) |
| { |
| return; |
| } |
| |
| ScopedTransaction scopedTransaction; |
| ScopedSnapshot scopedSnapshot; |
| |
| struct Data { |
| Executor* executor; |
| const char* tag; |
| } data = {this, tag}; |
| auto status = preparedStatement->select( |
| [](void* data) { |
| auto d = static_cast<Data*>(data); |
| return d->executor->write(d->tag); |
| }, |
| &data); |
| if (status.ok()) |
| { |
| signal_server(tag); |
| } |
| else |
| { |
| set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag + |
| ": failed to select a prepared statement: <" + handle + |
| ">: " + status.ToString(), |
| tag); |
| } |
| } |
| |
| void update_prepared_statement() |
| { |
| const char* tag = "update prepared statement"; |
| |
| pgstat_report_activity( |
| STATE_RUNNING, (std::string(Tag) + ": updating prepared statement").c_str()); |
| |
| std::string handle; |
| auto preparedStatement = find_prepared_statement(handle, tag); |
| P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str()); |
| |
| if (!preparedStatement) |
| { |
| return; |
| } |
| |
| ScopedTransaction scopedTransaction; |
| ScopedSnapshot scopedSnapshot; |
| |
| auto input = std::make_shared<SharedRingBufferInputStream>(this, session_); |
| auto n_updated_records_result = preparedStatement->update(input); |
| if (n_updated_records_result.ok()) |
| { |
| session_->nUpdatedRecords = *n_updated_records_result; |
| signal_server(tag); |
| } |
| else |
| { |
| set_error_message(std::string(Tag) + ": " + tag_ + ": " + tag + |
| ": failed to update a prepared statement: <" + handle + |
| ">: " + n_updated_records_result.status().ToString(), |
| tag); |
| } |
| } |
| |
| uint64_t sessionID_; |
| SessionData* session_; |
| bool connected_; |
| bool closed_; |
| uint64_t nextPreparedStatementID_; |
| std::map<std::string, PreparedStatement> preparedStatements_; |
| }; |
| |
| arrow::Status |
| SharedRingBufferOutputStream::Write(const void* data, int64_t nBytes) |
| { |
| if (ARROW_PREDICT_FALSE(!is_open_)) |
| { |
| return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() + |
| ": SharedRingBufferOutputStream is closed"); |
| } |
| if (ARROW_PREDICT_TRUE(nBytes > 0)) |
| { |
| auto buffer = processor_->create_shared_ring_buffer(session_); |
| size_t rest = static_cast<size_t>(nBytes); |
| while (true) |
| { |
| processor_->lock_acquire(); |
| auto writtenSize = buffer.write(data, rest); |
| processor_->lock_release(); |
| |
| position_ += writtenSize; |
| rest -= writtenSize; |
| data = static_cast<const uint8_t*>(data) + writtenSize; |
| |
| if (ARROW_PREDICT_TRUE(rest == 0)) |
| { |
| break; |
| } |
| |
| ARROW_RETURN_NOT_OK( |
| processor_->wait(session_, &buffer, Processor::WaitMode::Read)); |
| } |
| } |
| return arrow::Status::OK(); |
| } |
| |
| class Proxy : public WorkerProcessor { |
| public: |
| explicit Proxy() |
| : WorkerProcessor("proxy", false), randomSeed_(), randomEngine_(randomSeed_()) |
| { |
| } |
| |
| arrow::Result<uint64_t> connect(const std::string& databaseName, |
| const std::string& userName, |
| const std::string& password, |
| const std::string& clientAddress) |
| { |
| auto session = create_session(databaseName, userName, password, clientAddress); |
| auto id = session->id; |
| dshash_release_lock(sessions_, session); |
| kill(sharedData_->mainPID, SIGUSR1); |
| { |
| std::unique_lock<std::mutex> lock(mutex_); |
| conditionVariable_.wait(lock, [&] { |
| if (INTERRUPTS_PENDING_CONDITION()) |
| { |
| return true; |
| } |
| session = static_cast<SessionData*>(dshash_find(sessions_, &id, false)); |
| if (!session) |
| { |
| return true; |
| } |
| const auto initialized = session->initialized; |
| dshash_release_lock(sessions_, session); |
| return initialized; |
| }); |
| } |
| session = static_cast<SessionData*>(dshash_find(sessions_, &id, false)); |
| if (!session) |
| { |
| return arrow::Status::Invalid("session is stale: ", id); |
| } |
| SessionReleaser sessionReleaser(sessions_, session); |
| if (DsaPointerIsValid(session->errorMessage)) |
| { |
| return report_session_error(session); |
| } |
| if (INTERRUPTS_PENDING_CONDITION()) |
| { |
| return arrow::Status::Invalid("interrupted"); |
| } |
| return id; |
| } |
| |
| bool is_valid_session(uint64_t sessionID) |
| { |
| auto session = find_session(sessionID); |
| if (session) |
| { |
| dshash_release_lock(sessions_, session); |
| return true; |
| } |
| else |
| { |
| return false; |
| } |
| } |
| |
| arrow::Result<std::shared_ptr<arrow::Schema>> select(uint64_t sessionID, |
| const std::string& query) |
| { |
| const char* tag = "select"; |
| auto session = find_session(sessionID); |
| SessionReleaser sessionReleaser(sessions_, session); |
| set_shared_string(session->selectQuery, query); |
| session->action = Action::Select; |
| if (session->executorPID != InvalidPid) |
| { |
| P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID); |
| kill(session->executorPID, SIGUSR1); |
| } |
| { |
| auto buffer = create_shared_ring_buffer(session); |
| std::unique_lock<std::mutex> lock(mutex_); |
| conditionVariable_.wait(lock, [&] { |
| P("%s: %s: %s: wait", Tag, tag_, tag); |
| return DsaPointerIsValid(session->errorMessage) || buffer.size() > 0; |
| }); |
| } |
| if (DsaPointerIsValid(session->errorMessage)) |
| { |
| return report_session_error(session); |
| } |
| P("%s: %s: %s: open", Tag, tag_, tag); |
| auto schema = read_schema(session, tag); |
| P("%s: %s: %s: schema", Tag, tag_, tag); |
| return schema; |
| } |
| |
| arrow::Result<int64_t> update(uint64_t sessionID, const std::string& query) |
| { |
| #ifdef AFS_DEBUG |
| const char* tag = "update"; |
| #endif |
| auto session = find_session(sessionID); |
| SessionReleaser sessionReleaser(sessions_, session); |
| lock_acquire(); |
| set_shared_string(session->updateQuery, query); |
| session->action = Action::Update; |
| session->nUpdatedRecords = -1; |
| lock_release(); |
| if (session->executorPID != InvalidPid) |
| { |
| P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID); |
| kill(session->executorPID, SIGUSR1); |
| } |
| { |
| std::unique_lock<std::mutex> lock(mutex_); |
| conditionVariable_.wait(lock, [&] { |
| P("%s: %s: %s: wait", Tag, tag_, tag); |
| return DsaPointerIsValid(session->errorMessage) || |
| session->nUpdatedRecords >= 0; |
| }); |
| } |
| if (DsaPointerIsValid(session->errorMessage)) |
| { |
| return report_session_error(session); |
| } |
| P("%s: %s: %s: done: %ld", Tag, tag_, tag, session->nUpdatedRecords); |
| return session->nUpdatedRecords; |
| } |
| |
| arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> read(uint64_t sessionID) |
| { |
| auto session = find_session(sessionID); |
| SessionReleaser sessionReleaser(sessions_, session); |
| auto input = std::make_shared<SharedRingBufferInputStream>(this, session); |
| // Read another stream format data with record batches. |
| return arrow::ipc::RecordBatchStreamReader::Open(input); |
| } |
| |
| arrow::Result<arrow::flight::sql::ActionCreatePreparedStatementResult> prepare( |
| uint64_t sessionID, const std::string& query) |
| { |
| #ifdef AFS_DEBUG |
| const char* tag = "prepare"; |
| #endif |
| auto session = find_session(sessionID); |
| SessionReleaser sessionReleaser(sessions_, session); |
| lock_acquire(); |
| set_shared_string(session->prepareQuery, query); |
| session->action = Action::Prepare; |
| set_shared_string(session->preparedStatementHandle, std::string("")); |
| lock_release(); |
| if (session->executorPID != InvalidPid) |
| { |
| P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID); |
| kill(session->executorPID, SIGUSR1); |
| } |
| { |
| std::unique_lock<std::mutex> lock(mutex_); |
| conditionVariable_.wait(lock, [&] { |
| P("%s: %s: %s: wait", Tag, tag_, tag); |
| return DsaPointerIsValid(session->errorMessage) || |
| DsaPointerIsValid(session->preparedStatementHandle); |
| }); |
| } |
| if (DsaPointerIsValid(session->errorMessage)) |
| { |
| return report_session_error(session); |
| } |
| std::string handle(static_cast<const char*>( |
| dsa_get_address(area_, session->preparedStatementHandle))); |
| arrow::flight::sql::ActionCreatePreparedStatementResult result = { |
| nullptr, |
| nullptr, |
| std::move(handle), |
| }; |
| P("%s: %s: %s: done", Tag, tag_, tag); |
| return result; |
| } |
| |
| arrow::Status close_prepared_statement(uint64_t sessionID, const std::string& handle) |
| { |
| #ifdef AFS_DEBUG |
| const char* tag = "close prepared statement"; |
| #endif |
| auto session = find_session(sessionID); |
| SessionReleaser sessionReleaser(sessions_, session); |
| lock_acquire(); |
| set_shared_string(session->preparedStatementHandle, handle); |
| session->action = Action::ClosePreparedStatement; |
| lock_release(); |
| if (session->executorPID != InvalidPid) |
| { |
| P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID); |
| kill(session->executorPID, SIGUSR1); |
| } |
| { |
| std::unique_lock<std::mutex> lock(mutex_); |
| conditionVariable_.wait(lock, [&] { |
| P("%s: %s: %s: wait", Tag, tag_, tag); |
| return DsaPointerIsValid(session->errorMessage) || |
| !DsaPointerIsValid(session->preparedStatementHandle); |
| }); |
| } |
| if (DsaPointerIsValid(session->errorMessage)) |
| { |
| return report_session_error(session); |
| } |
| P("%s: %s: %s: done", Tag, tag_, tag); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Status set_parameters(uint64_t sessionID, |
| const std::string& handle, |
| arrow::flight::FlightMessageReader* reader, |
| arrow::flight::FlightMetadataWriter* writer) |
| { |
| #ifdef AFS_DEBUG |
| const char* tag = "set parameters"; |
| #endif |
| auto session = find_session(sessionID); |
| SessionReleaser sessionReleaser(sessions_, session); |
| lock_acquire(); |
| set_shared_string(session->preparedStatementHandle, handle); |
| session->action = Action::SetParameters; |
| lock_release(); |
| if (session->executorPID != InvalidPid) |
| { |
| P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID); |
| kill(session->executorPID, SIGUSR1); |
| } |
| { |
| ARROW_ASSIGN_OR_RAISE(const auto& schema, reader->GetSchema()); |
| SharedRingBufferOutputStream output(this, session); |
| auto options = arrow::ipc::IpcWriteOptions::Defaults(); |
| options.emit_dictionary_deltas = true; |
| ARROW_ASSIGN_OR_RAISE(auto writer, |
| arrow::ipc::MakeStreamWriter(&output, schema, options)); |
| while (true) |
| { |
| ARROW_ASSIGN_OR_RAISE(const auto& chunk, reader->Next()); |
| if (!chunk.data) |
| { |
| break; |
| } |
| ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*(chunk.data))); |
| } |
| ARROW_RETURN_NOT_OK(writer->Close()); |
| } |
| if (session->executorPID != InvalidPid) |
| { |
| P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID); |
| kill(session->executorPID, SIGUSR1); |
| } |
| { |
| auto buffer = create_shared_ring_buffer(session); |
| std::unique_lock<std::mutex> lock(mutex_); |
| conditionVariable_.wait(lock, [&] { |
| P("%s: %s: %s: wait", Tag, tag_, tag); |
| return DsaPointerIsValid(session->errorMessage) || buffer.size() == 0; |
| }); |
| } |
| if (DsaPointerIsValid(session->errorMessage)) |
| { |
| return report_session_error(session); |
| } |
| P("%s: %s: %s: done", Tag, tag_, tag); |
| return arrow::Status::OK(); |
| } |
| |
| arrow::Result<std::shared_ptr<arrow::Schema>> select_prepared_statement( |
| uint64_t sessionID, const std::string& handle) |
| { |
| const char* tag = "select prepared statement"; |
| auto session = find_session(sessionID); |
| SessionReleaser sessionReleaser(sessions_, session); |
| lock_acquire(); |
| set_shared_string(session->preparedStatementHandle, handle); |
| session->action = Action::SelectPreparedStatement; |
| lock_release(); |
| if (session->executorPID != InvalidPid) |
| { |
| P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID); |
| kill(session->executorPID, SIGUSR1); |
| } |
| { |
| auto buffer = create_shared_ring_buffer(session); |
| std::unique_lock<std::mutex> lock(mutex_); |
| conditionVariable_.wait(lock, [&] { |
| P("%s: %s: %s: wait", Tag, tag_, tag); |
| return DsaPointerIsValid(session->errorMessage) || buffer.size() > 0; |
| }); |
| } |
| if (DsaPointerIsValid(session->errorMessage)) |
| { |
| return report_session_error(session); |
| } |
| P("%s: %s: %s: open", Tag, tag_, tag); |
| auto schema = read_schema(session, tag); |
| P("%s: %s: %s: schema", Tag, tag_, tag); |
| return schema; |
| } |
| |
| arrow::Result<int64_t> update_prepared_statement( |
| uint64_t sessionID, |
| const std::string& handle, |
| arrow::flight::FlightMessageReader* reader) |
| { |
| #ifdef AFS_DEBUG |
| const char* tag = "update prepared statement"; |
| #endif |
| auto session = find_session(sessionID); |
| SessionReleaser sessionReleaser(sessions_, session); |
| lock_acquire(); |
| set_shared_string(session->preparedStatementHandle, handle); |
| session->action = Action::UpdatePreparedStatement; |
| session->nUpdatedRecords = -1; |
| lock_release(); |
| if (session->executorPID != InvalidPid) |
| { |
| P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID); |
| kill(session->executorPID, SIGUSR1); |
| } |
| { |
| ARROW_ASSIGN_OR_RAISE(const auto& schema, reader->GetSchema()); |
| SharedRingBufferOutputStream output(this, session); |
| auto options = arrow::ipc::IpcWriteOptions::Defaults(); |
| options.emit_dictionary_deltas = true; |
| ARROW_ASSIGN_OR_RAISE(auto writer, |
| arrow::ipc::MakeStreamWriter(&output, schema, options)); |
| while (true) |
| { |
| ARROW_ASSIGN_OR_RAISE(const auto& chunk, reader->Next()); |
| if (!chunk.data) |
| { |
| break; |
| } |
| ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*(chunk.data))); |
| } |
| ARROW_RETURN_NOT_OK(writer->Close()); |
| } |
| if (session->executorPID != InvalidPid) |
| { |
| P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID); |
| kill(session->executorPID, SIGUSR1); |
| } |
| { |
| std::unique_lock<std::mutex> lock(mutex_); |
| conditionVariable_.wait(lock, [&] { |
| P("%s: %s: %s: wait", Tag, tag_, tag); |
| return DsaPointerIsValid(session->errorMessage) || |
| session->nUpdatedRecords >= 0; |
| }); |
| } |
| if (DsaPointerIsValid(session->errorMessage)) |
| { |
| return report_session_error(session); |
| } |
| P("%s: %s: %s: done: %ld", Tag, tag_, tag, session->nUpdatedRecords); |
| return session->nUpdatedRecords; |
| } |
| |
| protected: |
| pid_t peer_pid(SessionData* session) override { return session->executorPID; } |
| |
| const char* peer_name(SessionData* session) override { return "executor"; } |
| |
| private: |
| SessionData* create_session(const std::string& databaseName, |
| const std::string& userName, |
| const std::string& password, |
| const std::string& clientAddress) |
| { |
| lock_acquire(); |
| uint64_t id = 0; |
| SessionData* session = nullptr; |
| do |
| { |
| id = randomEngine_(); |
| if (id == 0) |
| { |
| continue; |
| } |
| bool found = false; |
| session = |
| static_cast<SessionData*>(dshash_find_or_insert(sessions_, &id, &found)); |
| if (!found) |
| { |
| break; |
| } |
| } while (true); |
| session_data_initialize( |
| session, area_, databaseName, userName, password, clientAddress); |
| lock_release(); |
| return session; |
| } |
| |
| SessionData* find_session(uint64_t sessionID) |
| { |
| return static_cast<SessionData*>(dshash_find(sessions_, &sessionID, false)); |
| } |
| |
| arrow::Status report_session_error(SessionData* session) |
| { |
| auto status = arrow::Status::Invalid( |
| static_cast<const char*>(dsa_get_address(area_, session->errorMessage))); |
| P("%s: %s: %s: kill SIGTERM executor: %d", |
| Tag, |
| tag_, |
| AFS_FUNC, |
| session->executorPID); |
| kill(session->executorPID, SIGTERM); |
| return status; |
| } |
| |
| arrow::Result<std::shared_ptr<arrow::Schema>> read_schema(SessionData* session, |
| const char* tag) |
| { |
| auto input = std::make_shared<SharedRingBufferInputStream>(this, session); |
| // Read schema only stream format data. |
| ARROW_ASSIGN_OR_RAISE(auto reader, |
| arrow::ipc::RecordBatchStreamReader::Open(input)); |
| while (true) |
| { |
| std::shared_ptr<arrow::RecordBatch> recordBatch; |
| P("%s: %s: %s: read next", Tag, tag_, tag); |
| ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch)); |
| if (!recordBatch) |
| { |
| break; |
| } |
| } |
| return reader->schema(); |
| } |
| |
| std::random_device randomSeed_; |
| std::mt19937_64 randomEngine_; |
| }; |
| |
| arrow::Result<int64_t> |
| SharedRingBufferInputStream::Read(int64_t nBytes, void* out) |
| { |
| if (ARROW_PREDICT_FALSE(!is_open_)) |
| { |
| return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() + |
| ": SharedRingBufferInputStream is closed"); |
| } |
| auto buffer = processor_->create_shared_ring_buffer(session_); |
| size_t rest = static_cast<size_t>(nBytes); |
| while (true) |
| { |
| processor_->lock_acquire(); |
| auto readBytes = buffer.read(rest, out); |
| processor_->lock_release(); |
| |
| position_ += readBytes; |
| rest -= readBytes; |
| out = static_cast<uint8_t*>(out) + readBytes; |
| if (ARROW_PREDICT_TRUE(rest == 0)) |
| { |
| break; |
| } |
| |
| ARROW_RETURN_NOT_OK( |
| processor_->wait(session_, &buffer, Processor::WaitMode::Written)); |
| if (INTERRUPTS_PENDING_CONDITION()) |
| { |
| return arrow::Status::IOError(std::string(Tag) + ": " + processor_->tag() + |
| ": interrupted"); |
| } |
| } |
| return nBytes; |
| } |
| |
| class MainProcessor : public Processor { |
| public: |
| MainProcessor() : Processor("main", true), sessions_(nullptr) |
| { |
| LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE); |
| bool found; |
| sharedData_ = static_cast<SharedData*>( |
| ShmemInitStruct(SharedDataName, sizeof(sharedData_), &found)); |
| if (found) |
| { |
| LWLockRelease(AddinShmemInitLock); |
| ereport(ERROR, |
| errcode(ERRCODE_INTERNAL_ERROR), |
| errmsg("%s: %s: shared data is already created", Tag, tag_)); |
| } |
| sharedData_->trancheID = LWLockNewTrancheId(); |
| sharedData_->sessionsTrancheID = LWLockNewTrancheId(); |
| area_ = dsa_create(sharedData_->trancheID); |
| sharedData_->handle = dsa_get_handle(area_); |
| SessionsParams.tranche_id = sharedData_->sessionsTrancheID; |
| sessions_ = dshash_create(area_, &SessionsParams, nullptr); |
| sharedData_->sessionsHandle = dshash_get_hash_table_handle(sessions_); |
| sharedData_->serverPID = InvalidPid; |
| sharedData_->mainPID = MyProcPid; |
| lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock); |
| LWLockRelease(AddinShmemInitLock); |
| } |
| |
| ~MainProcessor() override { dshash_destroy(sessions_); } |
| |
| BackgroundWorkerHandle* start_server() |
| { |
| BackgroundWorker worker = {}; |
| snprintf(worker.bgw_name, BGW_MAXLEN, "%s: server", Tag); |
| snprintf(worker.bgw_type, BGW_MAXLEN, "%s: server", Tag); |
| worker.bgw_flags = BGWORKER_SHMEM_ACCESS; |
| worker.bgw_start_time = BgWorkerStart_ConsistentState; |
| worker.bgw_restart_time = BGW_NEVER_RESTART; |
| snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName); |
| snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_server"); |
| worker.bgw_main_arg = 0; |
| worker.bgw_notify_pid = MyProcPid; |
| BackgroundWorkerHandle* handle; |
| if (!RegisterDynamicBackgroundWorker(&worker, &handle)) |
| { |
| ereport(ERROR, |
| errcode(ERRCODE_INTERNAL_ERROR), |
| errmsg("%s: %s: failed to start server", Tag, tag_)); |
| } |
| WaitForBackgroundWorkerStartup(handle, &(sharedData_->serverPID)); |
| return handle; |
| } |
| |
| void process_connect_requests() |
| { |
| dshash_seq_status sessionsStatus; |
| dshash_seq_init(&sessionsStatus, sessions_, false); |
| SessionData* session; |
| while ((session = static_cast<SessionData*>(dshash_seq_next(&sessionsStatus)))) |
| { |
| if (session->initialized) |
| { |
| continue; |
| } |
| |
| BackgroundWorker worker = {}; |
| snprintf( |
| worker.bgw_name, BGW_MAXLEN, "%s: executor: %" PRIu64, Tag, session->id); |
| snprintf( |
| worker.bgw_type, BGW_MAXLEN, "%s: executor: %" PRIu64, Tag, session->id); |
| worker.bgw_flags = |
| BGWORKER_SHMEM_ACCESS | BGWORKER_BACKEND_DATABASE_CONNECTION; |
| worker.bgw_start_time = BgWorkerStart_ConsistentState; |
| worker.bgw_restart_time = BGW_NEVER_RESTART; |
| snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName); |
| snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_executor"); |
| worker.bgw_main_arg = Int64GetDatum(session->id); |
| worker.bgw_notify_pid = MyProcPid; |
| BackgroundWorkerHandle* handle; |
| if (RegisterDynamicBackgroundWorker(&worker, &handle)) |
| { |
| WaitForBackgroundWorkerStartup(handle, &(session->executorPID)); |
| } |
| else |
| { |
| set_shared_string( |
| session->errorMessage, |
| std::string(Tag) + ": " + tag_ + |
| ": failed to start executor: " + std::to_string(session->id)); |
| } |
| } |
| dshash_seq_term(&sessionsStatus); |
| kill(sharedData_->serverPID, SIGUSR1); |
| } |
| |
| private: |
| dshash_table* sessions_; |
| }; |
| |
| class HeaderAuthServerMiddleware : public arrow::flight::ServerMiddleware { |
| public: |
| explicit HeaderAuthServerMiddleware(uint64_t sessionID) : sessionID_(sessionID) {} |
| |
| void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override |
| { |
| outgoing_headers->AddHeader("authorization", |
| std::string("Bearer ") + std::to_string(sessionID_)); |
| } |
| |
| void CallCompleted(const arrow::Status& status) override {} |
| |
| std::string name() const override { return "HeaderAuthServerMiddleware"; } |
| |
| uint64_t session_id() { return sessionID_; } |
| |
| private: |
| uint64_t sessionID_; |
| }; |
| |
| class HeaderAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory { |
| public: |
| explicit HeaderAuthServerMiddlewareFactory(Proxy* proxy) |
| : arrow::flight::ServerMiddlewareFactory(), proxy_(proxy) |
| { |
| } |
| |
| arrow::Status StartCall( |
| const arrow::flight::CallInfo& info, |
| #if ARROW_VERSION_MAJOR >= 13 |
| const arrow::flight::ServerCallContext& context, |
| #else |
| const arrow::flight::CallHeaders& incoming_headers, |
| #endif |
| std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override |
| { |
| std::string databaseName("postgres"); |
| #if ARROW_VERSION_MAJOR >= 13 |
| const auto& incomingHeaders = context.incoming_headers(); |
| #else |
| const auto& incomingHeaders = incoming_headers; |
| #endif |
| auto databaseHeader = incomingHeaders.find("x-flight-sql-database"); |
| if (databaseHeader != incomingHeaders.end()) |
| { |
| databaseName = databaseHeader->second; |
| } |
| auto authorizationHeader = incomingHeaders.find("authorization"); |
| if (authorizationHeader == incomingHeaders.end()) |
| { |
| return arrow::flight::MakeFlightError( |
| arrow::flight::FlightStatusCode::Unauthenticated, |
| "No authorization header"); |
| } |
| auto value = authorizationHeader->second; |
| std::stringstream valueStream{std::string(value)}; |
| std::string type(""); |
| std::getline(valueStream, type, ' '); |
| if (type == "Basic") |
| { |
| std::stringstream decodedStream( |
| arrow::util::base64_decode(value.substr(valueStream.tellg()))); |
| std::string userName(""); |
| std::string password(""); |
| std::getline(decodedStream, userName, ':'); |
| std::getline(decodedStream, password); |
| #if ARROW_VERSION_MAJOR >= 13 |
| const auto& clientAddress = context.peer(); |
| #else |
| // 192.0.0.1 is one of reserved IPv4 addresses for documentation. |
| std::string clientAddress("ipv4:192.0.2.1:2929"); |
| #endif |
| auto sessionIDResult = |
| proxy_->connect(databaseName, userName, password, clientAddress); |
| if (!sessionIDResult.status().ok()) |
| { |
| return arrow::flight::MakeFlightError( |
| arrow::flight::FlightStatusCode::Unauthenticated, |
| sessionIDResult.status().ToString()); |
| } |
| auto sessionID = *sessionIDResult; |
| *middleware = std::make_shared<HeaderAuthServerMiddleware>(sessionID); |
| return arrow::Status::OK(); |
| } |
| else if (type == "Bearer") |
| { |
| std::string sessionIDString(value.substr(valueStream.tellg())); |
| if (sessionIDString.size() == 0) |
| { |
| return arrow::flight::MakeFlightError( |
| arrow::flight::FlightStatusCode::Unauthorized, |
| std::string("invalid Bearer token")); |
| } |
| auto start = sessionIDString.c_str(); |
| char* end = nullptr; |
| uint64_t sessionID = std::strtoull(start, &end, 10); |
| if (end[0] != '\0') |
| { |
| return arrow::flight::MakeFlightError( |
| arrow::flight::FlightStatusCode::Unauthorized, |
| std::string("invalid Bearer token")); |
| } |
| if (!proxy_->is_valid_session(sessionID)) |
| { |
| return arrow::flight::MakeFlightError( |
| arrow::flight::FlightStatusCode::Unauthorized, |
| std::string("invalid Bearer token")); |
| } |
| *middleware = std::make_shared<HeaderAuthServerMiddleware>(sessionID); |
| return arrow::Status::OK(); |
| } |
| else |
| { |
| return arrow::flight::MakeFlightError( |
| arrow::flight::FlightStatusCode::Unauthenticated, |
| std::string("authorization header must use Basic or Bearer: <") + type + |
| std::string(">")); |
| } |
| } |
| |
| private: |
| Proxy* proxy_; |
| }; |
| |
| class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase { |
| public: |
| explicit FlightSQLServer(Proxy* proxy) |
| : arrow::flight::sql::FlightSqlServerBase(), proxy_(proxy) |
| { |
| } |
| |
| ~FlightSQLServer() override {} |
| |
| arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoStatement( |
| const arrow::flight::ServerCallContext& context, |
| const arrow::flight::sql::StatementQuery& command, |
| const arrow::flight::FlightDescriptor& descriptor) override |
| { |
| ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); |
| const auto& query = command.query; |
| ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->select(sessionID, query)); |
| ARROW_ASSIGN_OR_RAISE(auto ticket, |
| arrow::flight::sql::CreateStatementQueryTicket(query)); |
| std::vector<arrow::flight::FlightEndpoint> endpoints{ |
| arrow::flight::FlightEndpoint{arrow::flight::Ticket{std::move(ticket)}, {}}}; |
| ARROW_ASSIGN_OR_RAISE( |
| auto result, |
| arrow::flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); |
| return std::make_unique<arrow::flight::FlightInfo>(result); |
| } |
| |
| arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> DoGetStatement( |
| const arrow::flight::ServerCallContext& context, |
| const arrow::flight::sql::StatementQueryTicket& command) override |
| { |
| ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); |
| ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->read(sessionID)); |
| return std::make_unique<arrow::flight::RecordBatchStream>(reader); |
| } |
| |
| arrow::Result<int64_t> DoPutCommandStatementUpdate( |
| const arrow::flight::ServerCallContext& context, |
| const arrow::flight::sql::StatementUpdate& command) override |
| { |
| ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); |
| const auto& query = command.query; |
| return proxy_->update(sessionID, query); |
| } |
| |
| arrow::Result<arrow::flight::sql::ActionCreatePreparedStatementResult> |
| CreatePreparedStatement( |
| const arrow::flight::ServerCallContext& context, |
| const arrow::flight::sql::ActionCreatePreparedStatementRequest& request) override |
| { |
| ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); |
| const auto& query = request.query; |
| return proxy_->prepare(sessionID, query); |
| } |
| |
| arrow::Status ClosePreparedStatement( |
| const arrow::flight::ServerCallContext& context, |
| const arrow::flight::sql::ActionClosePreparedStatementRequest& request) override |
| { |
| ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); |
| const auto& handle = request.prepared_statement_handle; |
| return proxy_->close_prepared_statement(sessionID, handle); |
| } |
| |
| arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> |
| GetFlightInfoPreparedStatement( |
| const arrow::flight::ServerCallContext& context, |
| const arrow::flight::sql::PreparedStatementQuery& command, |
| const arrow::flight::FlightDescriptor& descriptor) override |
| { |
| ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); |
| const auto& handle = command.prepared_statement_handle; |
| ARROW_ASSIGN_OR_RAISE(auto schema, |
| proxy_->select_prepared_statement(sessionID, handle)); |
| ARROW_ASSIGN_OR_RAISE(auto ticket, |
| arrow::flight::sql::CreateStatementQueryTicket(handle)); |
| std::vector<arrow::flight::FlightEndpoint> endpoints{ |
| arrow::flight::FlightEndpoint{arrow::flight::Ticket{std::move(ticket)}, {}}}; |
| ARROW_ASSIGN_OR_RAISE( |
| auto result, |
| arrow::flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); |
| return std::make_unique<arrow::flight::FlightInfo>(result); |
| } |
| |
| arrow::Status DoPutPreparedStatementQuery( |
| const arrow::flight::ServerCallContext& context, |
| const arrow::flight::sql::PreparedStatementQuery& command, |
| arrow::flight::FlightMessageReader* reader, |
| arrow::flight::FlightMetadataWriter* writer) override |
| { |
| ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); |
| const auto& handle = command.prepared_statement_handle; |
| return proxy_->set_parameters(sessionID, handle, reader, writer); |
| } |
| |
| arrow::Result<int64_t> DoPutPreparedStatementUpdate( |
| const arrow::flight::ServerCallContext& context, |
| const arrow::flight::sql::PreparedStatementUpdate& command, |
| arrow::flight::FlightMessageReader* reader) override |
| { |
| ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); |
| const auto& handle = command.prepared_statement_handle; |
| return proxy_->update_prepared_statement(sessionID, handle, reader); |
| } |
| |
| private: |
| arrow::Result<uint64_t> session_id(const arrow::flight::ServerCallContext& context) |
| { |
| auto middleware = reinterpret_cast<HeaderAuthServerMiddleware*>( |
| context.GetMiddleware("header-auth")); |
| if (!middleware) |
| { |
| return arrow::flight::MakeFlightError( |
| arrow::flight::FlightStatusCode::Unauthenticated, "no authorization"); |
| } |
| return middleware->session_id(); |
| } |
| |
| Proxy* proxy_; |
| }; |
| |
| arrow::Status |
| afs_server_internal(Proxy* proxy) |
| { |
| ARROW_ASSIGN_OR_RAISE(auto location, arrow::flight::Location::Parse(URI)); |
| arrow::flight::FlightServerOptions options(location); |
| if (EnableSSL) |
| { |
| arrow::flight::CertKeyPair certificate; |
| if (ssl_cert_file) |
| { |
| std::ifstream input(ssl_cert_file); |
| if (input) |
| { |
| certificate.pem_cert = |
| std::string(std::istreambuf_iterator<char>{input}, {}); |
| } |
| } |
| if (ssl_key_file) |
| { |
| std::ifstream input(ssl_key_file); |
| if (input) |
| { |
| certificate.pem_key = |
| std::string(std::istreambuf_iterator<char>{input}, {}); |
| } |
| } |
| if (!certificate.pem_cert.empty() && !certificate.pem_key.empty()) |
| { |
| options.tls_certificates.push_back(std::move(certificate)); |
| } |
| } |
| options.auth_handler = std::make_unique<arrow::flight::NoOpAuthHandler>(); |
| options.middleware.push_back( |
| {"header-auth", std::make_shared<HeaderAuthServerMiddlewareFactory>(proxy)}); |
| FlightSQLServer flightSQLServer(proxy); |
| ARROW_RETURN_NOT_OK(flightSQLServer.Init(options)); |
| |
| ereport(LOG, (errmsg("listening on %s for Apache Arrow Flight SQL", URI))); |
| |
| while (!GotSIGTERM) |
| { |
| WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION); |
| ResetLatch(MyLatch); |
| |
| if (GotSIGHUP) |
| { |
| GotSIGHUP = false; |
| ProcessConfigFile(PGC_SIGHUP); |
| } |
| |
| if (GotSIGUSR1) |
| { |
| GotSIGUSR1 = false; |
| proxy->signaled(); |
| } |
| |
| CHECK_FOR_INTERRUPTS(); |
| } |
| |
| // TODO: Use before_shmem_exit() |
| auto deadline = std::chrono::system_clock::now() + std::chrono::microseconds(10); |
| return flightSQLServer.Shutdown(&deadline); |
| } |
| |
| } // namespace |
| |
| static void |
| afs_executor_before_shmem_exit(int code, Datum arg) |
| { |
| // TODO: This doesn't work. We need to improve |
| // BackgroundWorkerInitializeConnection() failed case. |
| auto executor = reinterpret_cast<Executor*>(DatumGetPointer(arg)); |
| executor->close(); |
| delete executor; |
| } |
| |
| extern "C" void |
| afs_executor(Datum arg) |
| { |
| pqsignal(SIGTERM, afs_sigterm); |
| pqsignal(SIGHUP, afs_sighup); |
| pqsignal(SIGUSR1, afs_sigusr1); |
| BackgroundWorkerUnblockSignals(); |
| |
| auto executor = new Executor(DatumGetInt64(arg)); |
| before_shmem_exit(afs_executor_before_shmem_exit, PointerGetDatum(executor)); |
| executor->open(); |
| |
| while (!GotSIGTERM) |
| { |
| int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH; |
| const long timeout = SessionTimeout * 1000; |
| if (timeout >= 0) |
| { |
| events |= WL_TIMEOUT; |
| } |
| auto conditions = WaitLatch(MyLatch, events, timeout, PG_WAIT_EXTENSION); |
| |
| if (conditions & WL_TIMEOUT) |
| { |
| break; |
| } |
| |
| ResetLatch(MyLatch); |
| |
| if (GotSIGHUP) |
| { |
| GotSIGHUP = false; |
| ProcessConfigFile(PGC_SIGHUP); |
| } |
| |
| if (GotSIGUSR1) |
| { |
| GotSIGUSR1 = false; |
| executor->signaled(); |
| } |
| |
| CHECK_FOR_INTERRUPTS(); |
| } |
| executor->close(); |
| |
| proc_exit(0); |
| } |
| |
| extern "C" void |
| afs_server(Datum arg) |
| { |
| pqsignal(SIGTERM, afs_sigterm); |
| pqsignal(SIGHUP, afs_sighup); |
| pqsignal(SIGUSR1, afs_sigusr1); |
| BackgroundWorkerUnblockSignals(); |
| |
| { |
| arrow::Status status; |
| { |
| Proxy proxy; |
| status = afs_server_internal(&proxy); |
| } |
| if (!status.ok()) |
| { |
| ereport(ERROR, |
| errcode(ERRCODE_INTERNAL_ERROR), |
| errmsg("%s: server: failed: %s", Tag, status.ToString().c_str())); |
| } |
| } |
| |
| proc_exit(0); |
| } |
| |
| extern "C" void |
| afs_main(Datum arg) |
| { |
| pqsignal(SIGTERM, afs_sigterm); |
| pqsignal(SIGHUP, afs_sighup); |
| pqsignal(SIGUSR1, afs_sigusr1); |
| BackgroundWorkerUnblockSignals(); |
| |
| { |
| MainProcessor processor; |
| auto serverHandle = processor.start_server(); |
| while (!GotSIGTERM) |
| { |
| WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION); |
| ResetLatch(MyLatch); |
| |
| if (GotSIGHUP) |
| { |
| GotSIGHUP = false; |
| ProcessConfigFile(PGC_SIGHUP); |
| } |
| |
| if (GotSIGUSR1) |
| { |
| GotSIGUSR1 = false; |
| processor.process_connect_requests(); |
| } |
| |
| CHECK_FOR_INTERRUPTS(); |
| } |
| WaitForBackgroundWorkerShutdown(serverHandle); |
| } |
| |
| proc_exit(0); |
| } |
| |
| extern "C" void |
| _PG_init(void) |
| { |
| if (!process_shared_preload_libraries_in_progress) |
| { |
| return; |
| } |
| |
| DefineCustomStringVariable("arrow_flight_sql.uri", |
| "Apache Arrow Flight SQL endpoint URI.", |
| (std::string("default: ") + URIDefault).c_str(), |
| &URI, |
| URIDefault, |
| PGC_POSTMASTER, |
| 0, |
| NULL, |
| NULL, |
| NULL); |
| |
| DefineCustomIntVariable("arrow_flight_sql.session_timeout", |
| "Maximum session duration in seconds.", |
| "The default is 300 seconds. " |
| "-1 means no timeout.", |
| &SessionTimeout, |
| SessionTimeoutDefault, |
| -1, |
| INT_MAX, |
| PGC_USERSET, |
| GUC_UNIT_S, |
| NULL, |
| NULL, |
| NULL); |
| |
| DefineCustomIntVariable("arrow_flight_sql.max_n_rows_per_record_batch", |
| "The maximum number of rows per record batch.", |
| "The default is 1 * 1024 * 1024 rows.", |
| &MaxNRowsPerRecordBatch, |
| MaxNRowsPerRecordBatchDefault, |
| 1, |
| INT_MAX, |
| PGC_USERSET, |
| 0, |
| NULL, |
| NULL, |
| NULL); |
| |
| #ifdef PGRN_HAVE_SHMEM_REQUEST_HOOK |
| PreviousShmemRequestHook = shmem_request_hook; |
| shmem_request_hook = afs_shmem_request_hook; |
| #else |
| afs_shmem_request_hook(); |
| #endif |
| |
| BackgroundWorker worker = {}; |
| snprintf(worker.bgw_name, BGW_MAXLEN, "%s: main", Tag); |
| snprintf(worker.bgw_type, BGW_MAXLEN, "%s: main", Tag); |
| worker.bgw_flags = BGWORKER_SHMEM_ACCESS; |
| worker.bgw_start_time = BgWorkerStart_ConsistentState; |
| worker.bgw_restart_time = BGW_NEVER_RESTART; |
| snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName); |
| snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_main"); |
| worker.bgw_main_arg = 0; |
| worker.bgw_notify_pid = 0; |
| RegisterBackgroundWorker(&worker); |
| } |