blob: 29c5356a04fd77805e91211fb41597100e493821 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <limits>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include <gtest/gtest.h>
#include "arrow/array.h"
#include "arrow/buffer-builder.h"
#include "arrow/buffer.h"
#include "arrow/extension_type.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
#include "arrow/testing/gtest_common.h"
#include "arrow/testing/util.h"
#include "arrow/type.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
namespace arrow {
// Array class produced by UUIDType::MakeArray. It adds no members of its
// own and simply inherits every ExtensionArray constructor.
class UUIDArray : public ExtensionArray {
 public:
  using ExtensionArray::ExtensionArray;
};
class UUIDType : public ExtensionType {
public:
UUIDType() : ExtensionType(fixed_size_binary(16)) {}
std::string extension_name() const override { return "uuid"; }
bool ExtensionEquals(const ExtensionType& other) const override {
const auto& other_ext = static_cast<const ExtensionType&>(other);
if (other_ext.extension_name() != this->extension_name()) {
return false;
}
return true;
}
std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
DCHECK_EQ(data->type->id(), Type::EXTENSION);
DCHECK_EQ("uuid", static_cast<const ExtensionType&>(*data->type).extension_name());
return std::make_shared<UUIDArray>(data);
}
Status Deserialize(std::shared_ptr<DataType> storage_type,
const std::string& serialized,
std::shared_ptr<DataType>* out) const override {
if (serialized != "uuid-type-unique-code") {
return Status::Invalid("Type identifier did not match");
}
DCHECK(storage_type->Equals(*fixed_size_binary(16)));
*out = std::make_shared<UUIDType>();
return Status::OK();
}
std::string Serialize() const override { return "uuid-type-unique-code"; }
};
// Convenience factory: a new UUIDType instance, returned as a DataType.
std::shared_ptr<DataType> uuid() {
  std::shared_ptr<DataType> uuid_type = std::make_shared<UUIDType>();
  return uuid_type;
}
// Array class produced by Parametric1Type::MakeArray; inherits all
// ExtensionArray constructors and adds nothing else.
class Parametric1Array : public ExtensionArray {
 public:
  using ExtensionArray::ExtensionArray;
};
// Array class produced by Parametric2Type::MakeArray; inherits all
// ExtensionArray constructors and adds nothing else.
class Parametric2Array : public ExtensionArray {
 public:
  using ExtensionArray::ExtensionArray;
};
// A parametric type where the extension_name() is always the same
class Parametric1Type : public ExtensionType {
public:
explicit Parametric1Type(int32_t parameter)
: ExtensionType(int32()), parameter_(parameter) {}
int32_t parameter() const { return parameter_; }
std::string extension_name() const override { return "parametric-type-1"; }
bool ExtensionEquals(const ExtensionType& other) const override {
const auto& other_ext = static_cast<const ExtensionType&>(other);
if (other_ext.extension_name() != this->extension_name()) {
return false;
}
return this->parameter() == static_cast<const Parametric1Type&>(other).parameter();
}
std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
return std::make_shared<Parametric1Array>(data);
}
Status Deserialize(std::shared_ptr<DataType> storage_type,
const std::string& serialized,
std::shared_ptr<DataType>* out) const override {
DCHECK_EQ(4, serialized.size());
const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data());
DCHECK(storage_type->Equals(int32()));
*out = std::make_shared<Parametric1Type>(parameter);
return Status::OK();
}
std::string Serialize() const override {
std::string result(" ");
memcpy(&result[0], &parameter_, sizeof(int32_t));
return result;
}
private:
int32_t parameter_;
};
// A parametric type where the extension_name() is different for each
// parameter, and must be separately registered
class Parametric2Type : public ExtensionType {
public:
explicit Parametric2Type(int32_t parameter)
: ExtensionType(int32()), parameter_(parameter) {}
int32_t parameter() const { return parameter_; }
std::string extension_name() const override {
std::stringstream ss;
ss << "parametric-type-2<param=" << parameter_ << ">";
return ss.str();
}
bool ExtensionEquals(const ExtensionType& other) const override {
const auto& other_ext = static_cast<const ExtensionType&>(other);
if (other_ext.extension_name() != this->extension_name()) {
return false;
}
return this->parameter() == static_cast<const Parametric2Type&>(other).parameter();
}
std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
return std::make_shared<Parametric2Array>(data);
}
Status Deserialize(std::shared_ptr<DataType> storage_type,
const std::string& serialized,
std::shared_ptr<DataType>* out) const override {
DCHECK_EQ(4, serialized.size());
const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data());
DCHECK(storage_type->Equals(int32()));
*out = std::make_shared<Parametric2Type>(parameter);
return Status::OK();
}
std::string Serialize() const override {
std::string result(" ");
memcpy(&result[0], &parameter_, sizeof(int32_t));
return result;
}
private:
int32_t parameter_;
};
// Fixture that registers the "uuid" extension type before each test and
// removes it afterwards.
class TestExtensionType : public ::testing::Test {
 public:
  void SetUp() {
    auto uuid_type = std::make_shared<UUIDType>();
    ASSERT_OK(RegisterExtensionType(uuid_type));
  }

  void TearDown() {
    // A test may have unregistered "uuid" itself; only unregister when it is
    // still present.
    if (!GetExtensionType("uuid")) {
      return;
    }
    ASSERT_OK(UnregisterExtensionType("uuid"));
  }
};
// Checks registry lookup and the Serialize/Deserialize round trip of UUIDType.
TEST_F(TestExtensionType, ExtensionTypeTest) {
  // A name that was never registered is not found...
  auto missing = GetExtensionType("uuid-unknown");
  ASSERT_EQ(missing, nullptr);

  // ...whereas the type registered by the fixture is.
  auto found = GetExtensionType("uuid");
  ASSERT_NE(found, nullptr);

  auto uuid_type = uuid();
  ASSERT_EQ(uuid_type->id(), Type::EXTENSION);

  const auto& as_extension = static_cast<const ExtensionType&>(*uuid_type);
  std::string payload = as_extension.Serialize();

  std::shared_ptr<DataType> restored;
  ASSERT_OK(as_extension.Deserialize(fixed_size_binary(16), payload, &restored));

  // The restored type equals the extension type but not its bare storage type.
  ASSERT_TRUE(restored->Equals(*uuid_type));
  ASSERT_FALSE(restored->Equals(*fixed_size_binary(16)));
}
auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch,
std::shared_ptr<RecordBatch>* out) {
std::shared_ptr<io::BufferOutputStream> out_stream;
ASSERT_OK(io::BufferOutputStream::Create(1024, default_memory_pool(), &out_stream));
ASSERT_OK(ipc::WriteRecordBatchStream({batch}, out_stream.get()));
std::shared_ptr<Buffer> complete_ipc_stream;
ASSERT_OK(out_stream->Finish(&complete_ipc_stream));
io::BufferReader reader(complete_ipc_stream);
std::shared_ptr<RecordBatchReader> batch_reader;
ASSERT_OK(ipc::RecordBatchStreamReader::Open(&reader, &batch_reader));
ASSERT_OK(batch_reader->ReadNext(out));
};
// Builds a four-element uuid extension array (one null, three 16-byte values)
// by parsing the storage array from JSON and swapping the type on a copy of
// its ArrayData.
std::shared_ptr<Array> ExampleUUID() {
  auto storage_arr = ArrayFromJSON(
      fixed_size_binary(16),
      "[null, \"abcdefghijklmno0\", \"abcdefghijklmno1\", \"abcdefghijklmno2\"]");

  auto uuid_data = storage_arr->data()->Copy();
  uuid_data->type = uuid();
  return MakeArray(uuid_data);
}
// Round-trips a uuid column through the IPC stream format, first as a
// top-level field and then nested inside a list.
TEST_F(TestExtensionType, IpcRoundtrip) {
  auto uuid_arr = ExampleUUID();
  std::shared_ptr<RecordBatch> read_batch;

  // Top-level extension column
  auto flat_batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {uuid_arr});
  RoundtripBatch(flat_batch, &read_batch);
  CompareBatch(*flat_batch, *read_batch, false /* compare_metadata */);

  // Wrap type in a ListArray and ensure it also makes it
  auto list_offsets = ArrayFromJSON(int32(), "[0, 0, 2, 4]");
  std::shared_ptr<Array> list_arr;
  ASSERT_OK(
      ListArray::FromArrays(*list_offsets, *uuid_arr, default_memory_pool(), &list_arr));

  auto nested_batch =
      RecordBatch::Make(schema({field("f0", list(uuid()))}), 3, {list_arr});
  RoundtripBatch(nested_batch, &read_batch);
  CompareBatch(*nested_batch, *read_batch, false /* compare_metadata */);
}
// Reads back a stream whose extension type is no longer registered and checks
// that the column comes back as the plain storage type.
TEST_F(TestExtensionType, UnrecognizedExtension) {
  auto uuid_arr = ExampleUUID();
  auto storage_arr = static_cast<const ExtensionArray&>(*uuid_arr).storage();
  auto ext_batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {uuid_arr});

  // Write full IPC stream including schema, then unregister type, then read
  // and ensure that a plain instance of the storage type is created
  std::shared_ptr<io::BufferOutputStream> sink;
  ASSERT_OK(io::BufferOutputStream::Create(1024, default_memory_pool(), &sink));
  ASSERT_OK(ipc::WriteRecordBatchStream({ext_batch}, sink.get()));

  std::shared_ptr<Buffer> stream_bytes;
  ASSERT_OK(sink->Finish(&stream_bytes));

  ASSERT_OK(UnregisterExtensionType("uuid"));

  // Expected result: a storage-typed field that still carries the extension
  // name and serialized metadata as key/value pairs.
  auto expected_metadata =
      key_value_metadata({{"ARROW:extension:name", "uuid"},
                          {"ARROW:extension:metadata", "uuid-type-unique-code"}});
  auto storage_field = field("f0", fixed_size_binary(16), true, expected_metadata);
  auto expected_batch = RecordBatch::Make(schema({storage_field}), 4, {storage_arr});

  io::BufferReader source(stream_bytes);
  std::shared_ptr<RecordBatchReader> stream_reader;
  ASSERT_OK(ipc::RecordBatchStreamReader::Open(&source, &stream_reader));

  std::shared_ptr<RecordBatch> read_batch;
  ASSERT_OK(stream_reader->ReadNext(&read_batch));
  CompareBatch(*expected_batch, *read_batch);
}
// Parses `json_data` as an int32 array, then returns an array built from a
// copy of its ArrayData with the type replaced by `type`.
std::shared_ptr<Array> ExampleParametric(std::shared_ptr<DataType> type,
                                         const std::string& json_data) {
  auto storage_arr = ArrayFromJSON(int32(), json_data);
  auto retyped = storage_arr->data()->Copy();
  retyped->type = type;
  return MakeArray(retyped);
}
// Round-trips a batch containing two Parametric1Type columns (same extension
// name, different parameters) and two Parametric2Type columns (a distinct
// extension name per parameter) through the IPC stream format.
TEST_F(TestExtensionType, ParametricTypes) {
  auto p1_type = std::make_shared<Parametric1Type>(6);
  auto p1 = ExampleParametric(p1_type, "[null, 1, 2, 3]");

  auto p2_type = std::make_shared<Parametric1Type>(12);
  auto p2 = ExampleParametric(p2_type, "[2, null, 3, 4]");

  auto p3_type = std::make_shared<Parametric2Type>(2);
  auto p3 = ExampleParametric(p3_type, "[5, 6, 7, 8]");

  auto p4_type = std::make_shared<Parametric2Type>(3);
  auto p4 = ExampleParametric(p4_type, "[5, 6, 7, 9]");

  // Parametric1Type shares one name across all parameters, so a single
  // registration (with an arbitrary parameter) covers p1 and p2.
  // Parametric2Type instances each have their own name and are registered
  // individually.
  auto p1_registered = std::make_shared<Parametric1Type>(-1);
  ASSERT_OK(RegisterExtensionType(p1_registered));
  ASSERT_OK(RegisterExtensionType(p3_type));
  ASSERT_OK(RegisterExtensionType(p4_type));

  auto batch = RecordBatch::Make(schema({field("f0", p1_type), field("f1", p2_type),
                                         field("f2", p3_type), field("f3", p4_type)}),
                                 4, {p1, p2, p3, p4});

  std::shared_ptr<RecordBatch> read_batch;
  RoundtripBatch(batch, &read_batch);
  CompareBatch(*batch, *read_batch, false /* compare_metadata */);

  // Remove the registrations made above so they do not leak into the global
  // registry; otherwise running this test again in the same process would
  // fail when it tries to register the same names a second time.
  ASSERT_OK(UnregisterExtensionType(p1_registered->extension_name()));
  ASSERT_OK(UnregisterExtensionType(p3_type->extension_name()));
  ASSERT_OK(UnregisterExtensionType(p4_type->extension_name()));
}
} // namespace arrow