| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include "adbc_driver_manager.h" |
| #include <adbc.h> |
| |
| #include <algorithm> |
| #include <cstring> |
| #include <string> |
| #include <unordered_map> |
| #include <utility> |
| |
| #if defined(_WIN32) |
| #include <windows.h> // Must come first |
| |
| #include <libloaderapi.h> |
| #include <strsafe.h> |
| #else |
| #include <dlfcn.h> |
| #endif // defined(_WIN32) |
| |
| namespace { |
| |
| // Platform-specific helpers |
| |
#if defined(_WIN32)
/// Append a description of the calling thread's last Windows error to
/// `buffer`, formatted as "(<code>) <system message>".
void GetWinError(std::string* buffer) {
  // Capture first: any later Win32 call (including FormatMessage) may
  // overwrite the thread's last-error value.
  const DWORD rc = GetLastError();
  LPSTR message = nullptr;

  // Call the ANSI variant explicitly. With UNICODE defined, the generic
  // FormatMessage maps to FormatMessageW, which would produce a wide string
  // that we then read as char*.
  const DWORD length = FormatMessageA(
      FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
          FORMAT_MESSAGE_IGNORE_INSERTS,
      /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
      reinterpret_cast<LPSTR>(&message), /*nSize=*/0, /*Arguments=*/nullptr);

  (*buffer) += '(';
  (*buffer) += std::to_string(rc);
  (*buffer) += ") ";
  if (length != 0 && message != nullptr) {
    (*buffer) += message;
  } else {
    // FormatMessageA failed; `message` was not allocated.
    (*buffer) += "unknown error";
  }
  if (message != nullptr) {
    LocalFree(message);
  }
}

#endif  // defined(_WIN32)
| |
| // Error handling |
| |
| void ReleaseError(struct AdbcError* error) { |
| if (error) { |
| if (error->message) delete[] error->message; |
| error->message = nullptr; |
| error->release = nullptr; |
| } |
| } |
| |
| void SetError(struct AdbcError* error, const std::string& message) { |
| if (!error) return; |
| if (error->message) { |
| // Append |
| std::string buffer = error->message; |
| buffer.reserve(buffer.size() + message.size() + 1); |
| buffer += '\n'; |
| buffer += message; |
| error->release(error); |
| |
| error->message = new char[buffer.size() + 1]; |
| buffer.copy(error->message, buffer.size()); |
| error->message[buffer.size()] = '\0'; |
| } else { |
| error->message = new char[message.size() + 1]; |
| message.copy(error->message, message.size()); |
| error->message[message.size()] = '\0'; |
| } |
| error->release = ReleaseError; |
| } |
| |
| // Driver state |
| |
| /// Hold the driver DLL and the driver release callback in the driver struct. |
| struct ManagerDriverState { |
| // The original release callback |
| AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error); |
| |
| #if defined(_WIN32) |
| // The loaded DLL |
| HMODULE handle; |
| #endif // defined(_WIN32) |
| }; |
| |
| /// Unload the driver DLL. |
| static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* error) { |
| AdbcStatusCode status = ADBC_STATUS_OK; |
| |
| if (!driver->private_manager) return status; |
| ManagerDriverState* state = |
| reinterpret_cast<ManagerDriverState*>(driver->private_manager); |
| |
| if (state->driver_release) { |
| status = state->driver_release(driver, error); |
| } |
| |
| #if defined(_WIN32) |
| // TODO(apache/arrow-adbc#204): causes tests to segfault |
| // if (!FreeLibrary(state->handle)) { |
| // std::string message = "FreeLibrary() failed: "; |
| // GetWinError(&message); |
| // SetError(error, message); |
| // } |
| #endif // defined(_WIN32) |
| |
| driver->private_manager = nullptr; |
| delete state; |
| return status; |
| } |
| |
| // Default stubs |
| |
| AdbcStatusCode DatabaseSetOption(struct AdbcDatabase* database, const char* key, |
| const char* value, struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode ConnectionCommit(struct AdbcConnection*, struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, uint32_t* info_codes, |
| size_t info_codes_length, struct ArrowArrayStream* out, |
| struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode ConnectionGetObjects(struct AdbcConnection*, int, const char*, const char*, |
| const char*, const char**, const char*, |
| struct ArrowArrayStream*, struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection*, const char*, const char*, |
| const char*, struct ArrowSchema*, |
| struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection*, struct ArrowArrayStream*, |
| struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode ConnectionReadPartition(struct AdbcConnection* connection, |
| const uint8_t* serialized_partition, |
| size_t serialized_length, |
| struct ArrowArrayStream* out, |
| struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode ConnectionRollback(struct AdbcConnection*, struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode ConnectionSetOption(struct AdbcConnection*, const char*, const char*, |
| struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode StatementBind(struct AdbcStatement*, struct ArrowArray*, |
| struct ArrowSchema*, struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, |
| struct ArrowSchema* schema, |
| struct AdbcPartitions* partitions, |
| int64_t* rows_affected, |
| struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement* statement, |
| struct ArrowSchema* schema, |
| struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode StatementPrepare(struct AdbcStatement*, struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode StatementSetOption(struct AdbcStatement*, const char*, const char*, |
| struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement*, const char*, |
| struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement*, const uint8_t*, size_t, |
| struct AdbcError* error) { |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| /// Temporary state while the database is being configured. |
| struct TempDatabase { |
| std::unordered_map<std::string, std::string> options; |
| std::string driver; |
| // Default name (see adbc.h) |
| std::string entrypoint = "AdbcDriverInit"; |
| AdbcDriverInitFunc init_func = nullptr; |
| }; |
| |
/// Temporary state while the connection is being configured: options set
/// between AdbcConnectionNew() and AdbcConnectionInit(), before the
/// connection is bound to a database (and hence to a driver).
struct TempConnection {
  // Replayed through the driver's ConnectionSetOption() in AdbcConnectionInit().
  std::unordered_map<std::string, std::string> options;
};
| } // namespace |
| |
| // Direct implementations of API methods |
| |
| AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { |
| // Allocate a temporary structure to store options pre-Init |
| database->private_data = new TempDatabase(); |
| database->private_driver = nullptr; |
| return ADBC_STATUS_OK; |
| } |
| |
| AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, |
| const char* value, struct AdbcError* error) { |
| if (database->private_driver) { |
| return database->private_driver->DatabaseSetOption(database, key, value, error); |
| } |
| |
| TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data); |
| if (std::strcmp(key, "driver") == 0) { |
| args->driver = value; |
| } else if (std::strcmp(key, "entrypoint") == 0) { |
| args->entrypoint = value; |
| } else { |
| args->options[key] = value; |
| } |
| return ADBC_STATUS_OK; |
| } |
| |
| AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* database, |
| AdbcDriverInitFunc init_func, |
| struct AdbcError* error) { |
| if (database->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| |
| TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data); |
| args->init_func = init_func; |
| return ADBC_STATUS_OK; |
| } |
| |
| AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { |
| if (!database->private_data) { |
| SetError(error, "Must call AdbcDatabaseNew first"); |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data); |
| if (args->init_func) { |
| // Do nothing |
| } else if (args->driver.empty()) { |
| SetError(error, "Must provide 'driver' parameter"); |
| return ADBC_STATUS_INVALID_ARGUMENT; |
| } |
| |
| database->private_driver = new AdbcDriver; |
| std::memset(database->private_driver, 0, sizeof(AdbcDriver)); |
| AdbcStatusCode status; |
| // So we don't confuse a driver into thinking it's initialized already |
| database->private_data = nullptr; |
| if (args->init_func) { |
| status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_0_0, |
| database->private_driver, error); |
| } else { |
| status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), |
| ADBC_VERSION_1_0_0, database->private_driver, error); |
| } |
| if (status != ADBC_STATUS_OK) { |
| // Restore private_data so it will be released by AdbcDatabaseRelease |
| database->private_data = args; |
| if (database->private_driver->release) { |
| database->private_driver->release(database->private_driver, error); |
| } |
| delete database->private_driver; |
| database->private_driver = nullptr; |
| return status; |
| } |
| status = database->private_driver->DatabaseNew(database, error); |
| if (status != ADBC_STATUS_OK) { |
| if (database->private_driver->release) { |
| database->private_driver->release(database->private_driver, error); |
| } |
| delete database->private_driver; |
| database->private_driver = nullptr; |
| return status; |
| } |
| for (const auto& option : args->options) { |
| status = database->private_driver->DatabaseSetOption(database, option.first.c_str(), |
| option.second.c_str(), error); |
| if (status != ADBC_STATUS_OK) { |
| delete args; |
| // Release the database |
| std::ignore = database->private_driver->DatabaseRelease(database, error); |
| if (database->private_driver->release) { |
| database->private_driver->release(database->private_driver, error); |
| } |
| delete database->private_driver; |
| database->private_driver = nullptr; |
| // Should be redundant, but ensure that AdbcDatabaseRelease |
| // below doesn't think that it contains a TempDatabase |
| database->private_data = nullptr; |
| return status; |
| } |
| } |
| delete args; |
| return database->private_driver->DatabaseInit(database, error); |
| } |
| |
| AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, |
| struct AdbcError* error) { |
| if (!database->private_driver) { |
| if (database->private_data) { |
| TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data); |
| delete args; |
| database->private_data = nullptr; |
| return ADBC_STATUS_OK; |
| } |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| auto status = database->private_driver->DatabaseRelease(database, error); |
| if (database->private_driver->release) { |
| database->private_driver->release(database->private_driver, error); |
| } |
| delete database->private_driver; |
| database->private_data = nullptr; |
| database->private_driver = nullptr; |
| return status; |
| } |
| |
| AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, |
| struct AdbcError* error) { |
| if (!connection->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return connection->private_driver->ConnectionCommit(connection, error); |
| } |
| |
| AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, |
| uint32_t* info_codes, size_t info_codes_length, |
| struct ArrowArrayStream* out, |
| struct AdbcError* error) { |
| if (!connection->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return connection->private_driver->ConnectionGetInfo(connection, info_codes, |
| info_codes_length, out, error); |
| } |
| |
| AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, |
| const char* catalog, const char* db_schema, |
| const char* table_name, const char** table_types, |
| const char* column_name, |
| struct ArrowArrayStream* stream, |
| struct AdbcError* error) { |
| if (!connection->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return connection->private_driver->ConnectionGetObjects( |
| connection, depth, catalog, db_schema, table_name, table_types, column_name, stream, |
| error); |
| } |
| |
| AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, |
| const char* catalog, const char* db_schema, |
| const char* table_name, |
| struct ArrowSchema* schema, |
| struct AdbcError* error) { |
| if (!connection->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return connection->private_driver->ConnectionGetTableSchema( |
| connection, catalog, db_schema, table_name, schema, error); |
| } |
| |
| AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, |
| struct ArrowArrayStream* stream, |
| struct AdbcError* error) { |
| if (!connection->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); |
| } |
| |
| AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, |
| struct AdbcDatabase* database, |
| struct AdbcError* error) { |
| if (!connection->private_data) { |
| SetError(error, "Must call AdbcConnectionNew first"); |
| return ADBC_STATUS_INVALID_STATE; |
| } else if (!database->private_driver) { |
| SetError(error, "Database is not initialized"); |
| return ADBC_STATUS_INVALID_ARGUMENT; |
| } |
| TempConnection* args = reinterpret_cast<TempConnection*>(connection->private_data); |
| connection->private_data = nullptr; |
| std::unordered_map<std::string, std::string> options = std::move(args->options); |
| delete args; |
| |
| auto status = database->private_driver->ConnectionNew(connection, error); |
| if (status != ADBC_STATUS_OK) return status; |
| connection->private_driver = database->private_driver; |
| |
| for (const auto& option : options) { |
| status = database->private_driver->ConnectionSetOption( |
| connection, option.first.c_str(), option.second.c_str(), error); |
| if (status != ADBC_STATUS_OK) return status; |
| } |
| return connection->private_driver->ConnectionInit(connection, database, error); |
| } |
| |
| AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, |
| struct AdbcError* error) { |
| // Allocate a temporary structure to store options pre-Init, because |
| // we don't get access to the database (and hence the driver |
| // function table) until then |
| connection->private_data = new TempConnection; |
| connection->private_driver = nullptr; |
| return ADBC_STATUS_OK; |
| } |
| |
| AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, |
| const uint8_t* serialized_partition, |
| size_t serialized_length, |
| struct ArrowArrayStream* out, |
| struct AdbcError* error) { |
| if (!connection->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return connection->private_driver->ConnectionReadPartition( |
| connection, serialized_partition, serialized_length, out, error); |
| } |
| |
| AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, |
| struct AdbcError* error) { |
| if (!connection->private_driver) { |
| if (connection->private_data) { |
| TempConnection* args = reinterpret_cast<TempConnection*>(connection->private_data); |
| delete args; |
| connection->private_data = nullptr; |
| return ADBC_STATUS_OK; |
| } |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| auto status = connection->private_driver->ConnectionRelease(connection, error); |
| connection->private_driver = nullptr; |
| return status; |
| } |
| |
| AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, |
| struct AdbcError* error) { |
| if (!connection->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return connection->private_driver->ConnectionRollback(connection, error); |
| } |
| |
| AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, |
| const char* value, struct AdbcError* error) { |
| if (!connection->private_data) { |
| SetError(error, "AdbcConnectionSetOption: must AdbcConnectionNew first"); |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| if (!connection->private_driver) { |
| // Init not yet called, save the option |
| TempConnection* args = reinterpret_cast<TempConnection*>(connection->private_data); |
| args->options[key] = value; |
| return ADBC_STATUS_OK; |
| } |
| return connection->private_driver->ConnectionSetOption(connection, key, value, error); |
| } |
| |
| AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, |
| struct ArrowArray* values, struct ArrowSchema* schema, |
| struct AdbcError* error) { |
| if (!statement->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return statement->private_driver->StatementBind(statement, values, schema, error); |
| } |
| |
| AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, |
| struct ArrowArrayStream* stream, |
| struct AdbcError* error) { |
| if (!statement->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return statement->private_driver->StatementBindStream(statement, stream, error); |
| } |
| |
| // XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' |
| AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, |
| ArrowSchema* schema, |
| struct AdbcPartitions* partitions, |
| int64_t* rows_affected, |
| struct AdbcError* error) { |
| if (!statement->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return statement->private_driver->StatementExecutePartitions( |
| statement, schema, partitions, rows_affected, error); |
| } |
| |
| AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, |
| struct ArrowArrayStream* out, |
| int64_t* rows_affected, |
| struct AdbcError* error) { |
| if (!statement->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, |
| error); |
| } |
| |
| AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, |
| struct ArrowSchema* schema, |
| struct AdbcError* error) { |
| if (!statement->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return statement->private_driver->StatementGetParameterSchema(statement, schema, error); |
| } |
| |
| AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, |
| struct AdbcStatement* statement, |
| struct AdbcError* error) { |
| if (!connection->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| auto status = connection->private_driver->StatementNew(connection, statement, error); |
| statement->private_driver = connection->private_driver; |
| return status; |
| } |
| |
| AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, |
| struct AdbcError* error) { |
| if (!statement->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return statement->private_driver->StatementPrepare(statement, error); |
| } |
| |
| AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, |
| struct AdbcError* error) { |
| if (!statement->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| auto status = statement->private_driver->StatementRelease(statement, error); |
| statement->private_driver = nullptr; |
| return status; |
| } |
| |
| AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, |
| const char* value, struct AdbcError* error) { |
| if (!statement->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return statement->private_driver->StatementSetOption(statement, key, value, error); |
| } |
| |
| AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, |
| const char* query, struct AdbcError* error) { |
| if (!statement->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return statement->private_driver->StatementSetSqlQuery(statement, query, error); |
| } |
| |
| AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, |
| const uint8_t* plan, size_t length, |
| struct AdbcError* error) { |
| if (!statement->private_driver) { |
| return ADBC_STATUS_INVALID_STATE; |
| } |
| return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, |
| error); |
| } |
| |
| const char* AdbcStatusCodeMessage(AdbcStatusCode code) { |
| #define STRINGIFY(s) #s |
| #define STRINGIFY_VALUE(s) STRINGIFY(s) |
| #define CASE(CONSTANT) \ |
| case CONSTANT: \ |
| return #CONSTANT " (" STRINGIFY_VALUE(CONSTANT) ")"; |
| |
| switch (code) { |
| CASE(ADBC_STATUS_OK); |
| CASE(ADBC_STATUS_UNKNOWN); |
| CASE(ADBC_STATUS_NOT_IMPLEMENTED); |
| CASE(ADBC_STATUS_NOT_FOUND); |
| CASE(ADBC_STATUS_ALREADY_EXISTS); |
| CASE(ADBC_STATUS_INVALID_ARGUMENT); |
| CASE(ADBC_STATUS_INVALID_STATE); |
| CASE(ADBC_STATUS_INVALID_DATA); |
| CASE(ADBC_STATUS_INTEGRITY); |
| CASE(ADBC_STATUS_INTERNAL); |
| CASE(ADBC_STATUS_IO); |
| CASE(ADBC_STATUS_CANCELLED); |
| CASE(ADBC_STATUS_TIMEOUT); |
| CASE(ADBC_STATUS_UNAUTHENTICATED); |
| CASE(ADBC_STATUS_UNAUTHORIZED); |
| default: |
| return "(invalid code)"; |
| } |
| #undef CASE |
| #undef STRINGIFY_VALUE |
| #undef STRINGIFY |
| } |
| |
| AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, |
| int version, void* raw_driver, struct AdbcError* error) { |
| AdbcDriverInitFunc init_func; |
| std::string error_message; |
| |
| if (version != ADBC_VERSION_1_0_0) { |
| SetError(error, "Only ADBC 1.0.0 is supported"); |
| return ADBC_STATUS_NOT_IMPLEMENTED; |
| } |
| |
| auto* driver = reinterpret_cast<struct AdbcDriver*>(raw_driver); |
| |
| if (!entrypoint) { |
| // Default entrypoint (see adbc.h) |
| entrypoint = "AdbcDriverInit"; |
| } |
| |
| #if defined(_WIN32) |
| |
| HMODULE handle = LoadLibraryExA(driver_name, NULL, 0); |
| if (!handle) { |
| error_message += driver_name; |
| error_message += ": LoadLibraryExA() failed: "; |
| GetWinError(&error_message); |
| |
| std::string full_driver_name = driver_name; |
| full_driver_name += ".lib"; |
| handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); |
| if (!handle) { |
| error_message += '\n'; |
| error_message += full_driver_name; |
| error_message += ": LoadLibraryExA() failed: "; |
| GetWinError(&error_message); |
| } |
| } |
| if (!handle) { |
| SetError(error, error_message); |
| return ADBC_STATUS_INTERNAL; |
| } |
| |
| void* load_handle = reinterpret_cast<void*>(GetProcAddress(handle, entrypoint)); |
| init_func = reinterpret_cast<AdbcDriverInitFunc>(load_handle); |
| if (!init_func) { |
| std::string message = "GetProcAddress("; |
| message += entrypoint; |
| message += ") failed: "; |
| GetWinError(&message); |
| if (!FreeLibrary(handle)) { |
| message += "\nFreeLibrary() failed: "; |
| GetWinError(&message); |
| } |
| SetError(error, message); |
| return ADBC_STATUS_INTERNAL; |
| } |
| |
| #else |
| |
| #if defined(__APPLE__) |
| static const std::string kPlatformLibraryPrefix = "lib"; |
| static const std::string kPlatformLibrarySuffix = ".dylib"; |
| #else |
| static const std::string kPlatformLibraryPrefix = "lib"; |
| static const std::string kPlatformLibrarySuffix = ".so"; |
| #endif // defined(__APPLE__) |
| |
| void* handle = dlopen(driver_name, RTLD_NOW | RTLD_LOCAL); |
| if (!handle) { |
| error_message = "dlopen() failed: "; |
| error_message += dlerror(); |
| |
| // If applicable, append the shared library prefix/extension and |
| // try again (this way you don't have to hardcode driver names by |
| // platform in the application) |
| const std::string driver_str = driver_name; |
| |
| std::string full_driver_name; |
| if (driver_str.size() < kPlatformLibraryPrefix.size() || |
| driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != |
| 0) { |
| full_driver_name += kPlatformLibraryPrefix; |
| } |
| full_driver_name += driver_name; |
| if (driver_str.size() < kPlatformLibrarySuffix.size() || |
| driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), |
| kPlatformLibrarySuffix.size(), kPlatformLibrarySuffix) != 0) { |
| full_driver_name += kPlatformLibrarySuffix; |
| } |
| handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); |
| if (!handle) { |
| error_message += "\ndlopen() failed: "; |
| error_message += dlerror(); |
| } |
| } |
| if (!handle) { |
| SetError(error, error_message); |
| // AdbcDatabaseInit tries to call this if set |
| driver->release = nullptr; |
| return ADBC_STATUS_INTERNAL; |
| } |
| |
| void* load_handle = dlsym(handle, entrypoint); |
| if (!load_handle) { |
| std::string message = "dlsym("; |
| message += entrypoint; |
| message += ") failed: "; |
| message += dlerror(); |
| SetError(error, message); |
| return ADBC_STATUS_INTERNAL; |
| } |
| init_func = reinterpret_cast<AdbcDriverInitFunc>(load_handle); |
| |
| #endif // defined(_WIN32) |
| |
| AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); |
| if (status == ADBC_STATUS_OK) { |
| ManagerDriverState* state = new ManagerDriverState; |
| state->driver_release = driver->release; |
| #if defined(_WIN32) |
| state->handle = handle; |
| #endif // defined(_WIN32) |
| driver->release = &ReleaseDriver; |
| driver->private_manager = state; |
| } else { |
| #if defined(_WIN32) |
| if (!FreeLibrary(handle)) { |
| std::string message = "FreeLibrary() failed: "; |
| GetWinError(&message); |
| SetError(error, message); |
| } |
| #endif // defined(_WIN32) |
| } |
| return status; |
| } |
| |
| AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, |
| void* raw_driver, struct AdbcError* error) { |
| #define FILL_DEFAULT(DRIVER, STUB) \ |
| if (!DRIVER->STUB) { \ |
| DRIVER->STUB = &STUB; \ |
| } |
| #define CHECK_REQUIRED(DRIVER, STUB) \ |
| if (!DRIVER->STUB) { \ |
| SetError(error, "Driver does not implement required function Adbc" #STUB); \ |
| return ADBC_STATUS_INTERNAL; \ |
| } |
| |
| auto result = init_func(version, raw_driver, error); |
| if (result != ADBC_STATUS_OK) { |
| return result; |
| } |
| |
| if (version == ADBC_VERSION_1_0_0) { |
| auto* driver = reinterpret_cast<struct AdbcDriver*>(raw_driver); |
| CHECK_REQUIRED(driver, DatabaseNew); |
| CHECK_REQUIRED(driver, DatabaseInit); |
| CHECK_REQUIRED(driver, DatabaseRelease); |
| FILL_DEFAULT(driver, DatabaseSetOption); |
| |
| CHECK_REQUIRED(driver, ConnectionNew); |
| CHECK_REQUIRED(driver, ConnectionInit); |
| CHECK_REQUIRED(driver, ConnectionRelease); |
| FILL_DEFAULT(driver, ConnectionCommit); |
| FILL_DEFAULT(driver, ConnectionGetInfo); |
| FILL_DEFAULT(driver, ConnectionGetObjects); |
| FILL_DEFAULT(driver, ConnectionGetTableSchema); |
| FILL_DEFAULT(driver, ConnectionGetTableTypes); |
| FILL_DEFAULT(driver, ConnectionReadPartition); |
| FILL_DEFAULT(driver, ConnectionRollback); |
| FILL_DEFAULT(driver, ConnectionSetOption); |
| |
| FILL_DEFAULT(driver, StatementExecutePartitions); |
| CHECK_REQUIRED(driver, StatementExecuteQuery); |
| CHECK_REQUIRED(driver, StatementNew); |
| CHECK_REQUIRED(driver, StatementRelease); |
| FILL_DEFAULT(driver, StatementBind); |
| FILL_DEFAULT(driver, StatementGetParameterSchema); |
| FILL_DEFAULT(driver, StatementPrepare); |
| FILL_DEFAULT(driver, StatementSetOption); |
| FILL_DEFAULT(driver, StatementSetSqlQuery); |
| FILL_DEFAULT(driver, StatementSetSubstraitPlan); |
| } |
| |
| return ADBC_STATUS_OK; |
| |
| #undef FILL_DEFAULT |
| #undef CHECK_REQUIRED |
| } |