blob: a4d742491a9f7d171368a92d1c3677687b82c0a6 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <adbc.h>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest-matchers.h>
#include <gtest/gtest-param-test.h>
#include <gtest/gtest.h>
#include <nanoarrow/nanoarrow.h>
#include <algorithm>
#include <cstring>
#include <random>
#include "validation/adbc_validation.h"
#include "validation/adbc_validation_util.h"
using adbc_validation::IsOkStatus;
// Evaluates EXPR exactly once (EXPR is expected to yield an AdbcStatusCode).
// If the result is not ADBC_STATUS_OK, that status is returned from the
// enclosing function, so the enclosing function must return a type that
// AdbcStatusCode converts to. The do/while(false) wrapper makes the macro
// behave like a single statement (safe after an unbraced `if`).
#define CHECK_OK(EXPR) \
do { \
if (auto adbc_status = (EXPR); adbc_status != ADBC_STATUS_OK) { \
return adbc_status; \
} \
} while (false)
namespace {
// Produces a random, UUID-shaped string: 16 random bytes rendered as 32
// uppercase hex digits grouped 8-4-4-4-12 with '-' separators (36 chars).
// No RFC 4122 version/variant bits are set; it only needs to be unique
// enough to name a throwaway schema.
std::string GetUuid() {
  static std::random_device seed_source;
  static std::mt19937 engine(seed_source());
  std::uniform_int_distribution<int> nibble(0, 15);
  static const char kHexDigits[] = "0123456789ABCDEF";

  std::string out;
  for (int byte = 0; byte < 16; ++byte) {
    // Separators precede bytes 4, 6, 8 and 10 -> 8-4-4-4-12 grouping.
    const bool starts_group = byte == 4 || byte == 6 || byte == 8 || byte == 10;
    if (starts_group) {
      out.push_back('-');
    }
    // Two hex digits (one per nibble) for each byte.
    for (int half = 0; half < 2; ++half) {
      out.push_back(kHexDigits[nibble(engine)]);
    }
  }
  return out;
}
}  // namespace
class SnowflakeQuirks : public adbc_validation::DriverQuirks {
public:
SnowflakeQuirks() {
uri_ = std::getenv("ADBC_SNOWFLAKE_URI");
if (uri_ == nullptr || std::strlen(uri_) == 0) {
skip_ = true;
}
}
AdbcStatusCode SetupDatabase(struct AdbcDatabase* database,
struct AdbcError* error) const override {
EXPECT_THAT(AdbcDatabaseSetOption(database, "uri", uri_, error), IsOkStatus(error));
EXPECT_THAT(AdbcDatabaseSetOption(
database, "adbc.snowflake.sql.client_option.use_high_precision",
"false", error),
IsOkStatus(error));
EXPECT_THAT(AdbcDatabaseSetOption(database, "adbc.snowflake.sql.schema",
schema_.c_str(), error),
IsOkStatus(error));
return ADBC_STATUS_OK;
}
AdbcStatusCode DropTable(struct AdbcConnection* connection, const std::string& name,
struct AdbcError* error) const override {
adbc_validation::Handle<struct AdbcStatement> statement;
CHECK_OK(AdbcStatementNew(connection, &statement.value, error));
std::string drop = "DROP TABLE IF EXISTS \"";
drop += name;
drop += "\"";
CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, drop.c_str(), error));
CHECK_OK(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error));
CHECK_OK(AdbcStatementRelease(&statement.value, error));
return ADBC_STATUS_OK;
}
AdbcStatusCode CreateSampleTable(struct AdbcConnection* connection,
const std::string& name,
struct AdbcError* error) const override {
adbc_validation::Handle<struct AdbcStatement> statement;
CHECK_OK(AdbcStatementNew(connection, &statement.value, error));
std::string create = "CREATE TABLE \"";
create += name;
create += "\" (int64s INT, strings TEXT)";
CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), error));
CHECK_OK(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error));
std::string insert = "INSERT INTO \"";
insert += name;
insert += "\" VALUES (42, 'foo'), (-42, NULL), (NULL, '')";
CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, insert.c_str(), error));
CHECK_OK(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error));
CHECK_OK(AdbcStatementRelease(&statement.value, error));
return ADBC_STATUS_OK;
}
ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override {
switch (ingest_type) {
case NANOARROW_TYPE_INT8:
case NANOARROW_TYPE_UINT8:
case NANOARROW_TYPE_INT16:
case NANOARROW_TYPE_UINT16:
case NANOARROW_TYPE_INT32:
case NANOARROW_TYPE_UINT32:
case NANOARROW_TYPE_INT64:
case NANOARROW_TYPE_UINT64:
return NANOARROW_TYPE_INT64;
case NANOARROW_TYPE_FLOAT:
case NANOARROW_TYPE_DOUBLE:
return NANOARROW_TYPE_DOUBLE;
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
return NANOARROW_TYPE_STRING;
default:
return ingest_type;
}
}
std::string BindParameter(int index) const override { return "?"; }
bool supports_bulk_ingest(const char* /*mode*/) const override { return true; }
bool supports_concurrent_statements() const override { return true; }
bool supports_transactions() const override { return true; }
bool supports_get_sql_info() const override { return false; }
bool supports_get_objects() const override { return true; }
bool supports_metadata_current_catalog() const override { return false; }
bool supports_metadata_current_db_schema() const override { return false; }
bool supports_partitioned_data() const override { return false; }
bool supports_dynamic_parameter_binding() const override { return true; }
bool supports_error_on_incompatible_schema() const override { return false; }
bool ddl_implicit_commit_txn() const override { return true; }
std::string db_schema() const override { return schema_; }
const char* uri_;
bool skip_{false};
std::string schema_{"ADBC_TESTING"};
};
// Database-level validation suite. When no Snowflake URI is configured, every
// test is skipped and no setup or teardown work is done.
class SnowflakeTest : public ::testing::Test, public adbc_validation::DatabaseTest {
 public:
  const adbc_validation::DriverQuirks* quirks() const override { return &quirks_; }

  void SetUp() override {
    if (!quirks_.skip_) {
      ASSERT_NO_FATAL_FAILURE(SetUpTest());
      return;
    }
    GTEST_SKIP();
  }

  void TearDown() override {
    // SetUp never ran SetUpTest() when skipping, so there is nothing to undo.
    if (quirks_.skip_) return;
    ASSERT_NO_FATAL_FAILURE(TearDownTest());
  }

 protected:
  SnowflakeQuirks quirks_;
};
ADBCV_TEST_DATABASE(SnowflakeTest)
// Connection-level validation suite. When no Snowflake URI is configured, every
// test is skipped and no setup or teardown work is done.
class SnowflakeConnectionTest : public ::testing::Test,
                                public adbc_validation::ConnectionTest {
 public:
  const adbc_validation::DriverQuirks* quirks() const override { return &quirks_; }

  void SetUp() override {
    if (!quirks_.skip_) {
      ASSERT_NO_FATAL_FAILURE(SetUpTest());
      return;
    }
    GTEST_SKIP();
  }

  void TearDown() override {
    // SetUp never ran SetUpTest() when skipping, so there is nothing to undo.
    if (quirks_.skip_) return;
    ASSERT_NO_FATAL_FAILURE(TearDownTest());
  }

  // The driver supports these metadata calls, but the returned values are not
  // validated here, so the shared tests are skipped.
  void TestMetadataCurrentCatalog() { GTEST_SKIP(); }
  void TestMetadataCurrentDbSchema() { GTEST_SKIP(); }

 protected:
  SnowflakeQuirks quirks_;
};
ADBCV_TEST_CONNECTION(SnowflakeConnectionTest)
// Statement-level validation suite.
//
// All tests in the suite share one schema, "ADBC_TESTING_<random>", created in
// the ADBC_TESTING database by SetUpTestCase and dropped by TearDownTestCase
// (presumably so concurrent runs don't collide on table names).
class SnowflakeStatementTest : public ::testing::Test,
                               public adbc_validation::StatementTest {
 public:
  const adbc_validation::DriverQuirks* quirks() const override { return &quirks_; }
  void SetUp() override {
    if (quirks_.skip_) {
      GTEST_SKIP();
    }
    ASSERT_NO_FATAL_FAILURE(SetUpTest());
  }
  void TearDown() override {
    if (!quirks_.skip_) {
      ASSERT_NO_FATAL_FAILURE(TearDownTest());
    }
  }

  void TestSqlIngestInterval() { GTEST_SKIP(); }
  void TestSqlIngestDuration() { GTEST_SKIP(); }
  void TestSqlIngestColumnEscaping() { GTEST_SKIP(); }

  // Creates the per-suite schema and points quirks_ at it.
  // will need to be updated to SetUpTestSuite when gtest is upgraded
  static void SetUpTestCase() {
    if (quirks_.skip_) {
      GTEST_SKIP();
    }
    const std::string schema_name = "ADBC_TESTING_" + GetUuid();
    bool created = false;
    ExecuteOnFreshConnection("CREATE SCHEMA " + SchemaIdentifier(schema_name),
                             &created);
    // Switch schemas only once the schema really exists, so TearDownTestCase
    // can tell whether there is anything to drop.
    if (created) {
      quirks_.schema_ = schema_name;
    }
  }

  // Drops the per-suite schema (if one was created) and restores the default.
  // will need to be updated to TearDownTestSuite when gtest is upgraded
  static void TearDownTestCase() {
    if (quirks_.skip_) {
      GTEST_SKIP();
    }
    // gtest runs TearDownTestCase even when SetUpTestCase failed. In that case
    // schema_ still names the shared default schema, which must never be
    // dropped.
    if (quirks_.schema_ == kDefaultSchema) {
      return;
    }
    bool dropped = false;
    ExecuteOnFreshConnection("DROP SCHEMA " + SchemaIdentifier(quirks_.schema_),
                             &dropped);
    if (dropped) {
      quirks_.schema_ = kDefaultSchema;
    }
  }

 protected:
  void ValidateIngestedTemporalData(struct ArrowArrayView* values, ArrowType type,
                                    enum ArrowTimeUnit unit,
                                    const char* timezone) override {
    switch (type) {
      case NANOARROW_TYPE_TIMESTAMP: {
        // The expected values are the same for every time unit, so `unit` is
        // not consulted.
        std::vector<std::optional<int64_t>> expected = {std::nullopt, -42, 0, 42};
        ASSERT_NO_FATAL_FAILURE(
            adbc_validation::CompareArray<std::int64_t>(values, expected));
        break;
      }
      default:
        FAIL() << "ValidateIngestedTemporalData not implemented for type " << type;
    }
  }

  static SnowflakeQuirks quirks_;

 private:
  // Default schema name; must match the default of SnowflakeQuirks::schema_.
  static constexpr const char* kDefaultSchema = "ADBC_TESTING";

  // Builds IDENTIFIER('"ADBC_TESTING"."<schema>"'), i.e. <schema> inside the
  // ADBC_TESTING database, for use in schema DDL.
  static std::string SchemaIdentifier(const std::string& schema) {
    return "IDENTIFIER('\"ADBC_TESTING\".\"" + schema + "\"')";
  }

  // Opens a throwaway database and connection configured via quirks_, runs
  // `query` once, and releases everything. `*executed` becomes true as soon as
  // the query itself succeeds, so callers know whether its side effect
  // happened even if a later release step fails.
  // NOTE(review): if an ASSERT fails part-way, handles created so far are not
  // released; consider adbc_validation::Handle for RAII cleanup.
  static void ExecuteOnFreshConnection(const std::string& query, bool* executed) {
    *executed = false;
    struct AdbcError error;
    struct AdbcDatabase db;
    struct AdbcConnection connection;
    struct AdbcStatement statement;
    std::memset(&error, 0, sizeof(error));
    std::memset(&db, 0, sizeof(db));
    std::memset(&connection, 0, sizeof(connection));
    std::memset(&statement, 0, sizeof(statement));

    ASSERT_THAT(AdbcDatabaseNew(&db, &error), IsOkStatus(&error));
    ASSERT_THAT(quirks_.SetupDatabase(&db, &error), IsOkStatus(&error));
    ASSERT_THAT(AdbcDatabaseInit(&db, &error), IsOkStatus(&error));
    ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
    ASSERT_THAT(AdbcConnectionInit(&connection, &db, &error), IsOkStatus(&error));

    ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
    ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query.c_str(), &error),
                IsOkStatus(&error));
    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
                IsOkStatus(&error));
    *executed = true;

    ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
    ASSERT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
    ASSERT_THAT(AdbcDatabaseRelease(&db, &error), IsOkStatus(&error));
  }
};
SnowflakeQuirks SnowflakeStatementTest::quirks_;
ADBCV_TEST_STATEMENT(SnowflakeStatementTest)