// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/ipc/reader.h"
#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstring>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include <flatbuffers/flatbuffers.h> // IWYU pragma: export
#include "arrow/array.h"
#include "arrow/buffer.h"
#include "arrow/extension_type.h"
#include "arrow/io/interfaces.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/message.h"
#include "arrow/ipc/metadata_internal.h"
#include "arrow/ipc/util.h"
#include "arrow/ipc/writer.h"
#include "arrow/record_batch.h"
#include "arrow/sparse_tensor.h"
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/compression.h"
#include "arrow/util/endian.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/parallel.h"
#include "arrow/util/string.h"
#include "arrow/util/ubsan.h"
#include "arrow/visitor_inline.h"
#include "generated/File_generated.h" // IWYU pragma: export
#include "generated/Message_generated.h"
#include "generated/Schema_generated.h"
#include "generated/SparseTensor_generated.h"
namespace arrow {
namespace flatbuf = org::apache::arrow::flatbuf;
using internal::checked_cast;
using internal::checked_pointer_cast;
using internal::GetByteWidth;
namespace ipc {
using internal::FileBlock;
using internal::kArrowMagicBytes;
namespace {
enum class DictionaryKind { New, Delta, Replacement };
Status InvalidMessageType(MessageType expected, MessageType actual) {
return Status::IOError("Expected IPC message of type ", FormatMessageType(expected),
" but got ", FormatMessageType(actual));
}
#define CHECK_MESSAGE_TYPE(expected, actual) \
do { \
if ((actual) != (expected)) { \
return InvalidMessageType((expected), (actual)); \
} \
} while (0)
#define CHECK_HAS_BODY(message) \
do { \
if ((message).body() == nullptr) { \
return Status::IOError("Expected body in IPC message of type ", \
FormatMessageType((message).type())); \
} \
} while (0)
#define CHECK_HAS_NO_BODY(message) \
do { \
if ((message).body_length() != 0) { \
return Status::IOError("Unexpected body in IPC message of type ", \
FormatMessageType((message).type())); \
} \
} while (0)
} // namespace
// ----------------------------------------------------------------------
// Record batch read path
/// \brief Structure to keep common arguments to be passed through the read functions
struct IpcReadContext {
IpcReadContext(DictionaryMemo* memo, const IpcReadOptions& option, bool swap,
MetadataVersion version = MetadataVersion::V5,
Compression::type kind = Compression::UNCOMPRESSED)
: dictionary_memo(memo),
options(option),
metadata_version(version),
compression(kind),
swap_endian(swap) {}
DictionaryMemo* dictionary_memo;
const IpcReadOptions& options;
MetadataVersion metadata_version;
Compression::type compression;
/// \brief LoadRecordBatch() or LoadRecordBatchSubset() swaps endianness of elements
/// if this flag is true
const bool swap_endian;
};
/// Reads the buffers and field nodes for a single field out of a record batch
/// body. The field_index and buffer_index are incremented based on how much
/// of the batch is "consumed" (through nested data reconstruction, for
/// example)
class ArrayLoader {
public:
explicit ArrayLoader(const flatbuf::RecordBatch* metadata,
MetadataVersion metadata_version, const IpcReadOptions& options,
io::RandomAccessFile* file)
: metadata_(metadata),
metadata_version_(metadata_version),
file_(file),
max_recursion_depth_(options.max_recursion_depth) {}
Status ReadBuffer(int64_t offset, int64_t length, std::shared_ptr<Buffer>* out) {
if (skip_io_) {
return Status::OK();
}
if (offset < 0) {
return Status::Invalid("Negative offset for reading buffer ", buffer_index_);
}
if (length < 0) {
return Status::Invalid("Negative length for reading buffer ", buffer_index_);
}
// The IPC format requires buffers to start at 8-byte aligned offsets
if (!BitUtil::IsMultipleOf8(offset)) {
return Status::Invalid("Buffer ", buffer_index_,
" did not start on 8-byte aligned offset: ", offset);
}
return file_->ReadAt(offset, length).Value(out);
}
Status LoadType(const DataType& type) { return VisitTypeInline(type, this); }
Status Load(const Field* field, ArrayData* out) {
if (max_recursion_depth_ <= 0) {
return Status::Invalid("Max recursion depth reached");
}
field_ = field;
out_ = out;
out_->type = field_->type();
return LoadType(*field_->type());
}
Status SkipField(const Field* field) {
ArrayData dummy;
skip_io_ = true;
Status status = Load(field, &dummy);
skip_io_ = false;
return status;
}
Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) {
auto buffers = metadata_->buffers();
CHECK_FLATBUFFERS_NOT_NULL(buffers, "RecordBatch.buffers");
if (buffer_index >= static_cast<int>(buffers->size())) {
return Status::IOError("buffer_index out of range.");
}
const flatbuf::Buffer* buffer = buffers->Get(buffer_index);
if (buffer->length() == 0) {
// Should never return a null buffer here.
// (zero-sized buffer allocations are cheap)
return AllocateBuffer(0).Value(out);
} else {
return ReadBuffer(buffer->offset(), buffer->length(), out);
}
}
Status GetFieldMetadata(int field_index, ArrayData* out) {
auto nodes = metadata_->nodes();
CHECK_FLATBUFFERS_NOT_NULL(nodes, "Table.nodes");
// pop off a field
if (field_index >= static_cast<int>(nodes->size())) {
return Status::Invalid("Ran out of field metadata, likely malformed");
}
const flatbuf::FieldNode* node = nodes->Get(field_index);
out->length = node->length();
out->null_count = node->null_count();
out->offset = 0;
return Status::OK();
}
Status LoadCommon(Type::type type_id) {
// This only contains the length and null count, which we need in order to
// figure out what to do with the buffers. For example, if null_count == 0,
// then we can skip reading the validity bitmap from shared memory
RETURN_NOT_OK(GetFieldMetadata(field_index_++, out_));
if (internal::HasValidityBitmap(type_id, metadata_version_)) {
// Extract null_bitmap which is common to all arrays except for unions
// and nulls.
if (out_->null_count != 0) {
RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[0]));
}
buffer_index_++;
}
return Status::OK();
}
template <typename TYPE>
Status LoadPrimitive(Type::type type_id) {
out_->buffers.resize(2);
RETURN_NOT_OK(LoadCommon(type_id));
if (out_->length > 0) {
RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
} else {
buffer_index_++;
out_->buffers[1].reset(new Buffer(nullptr, 0));
}
return Status::OK();
}
template <typename TYPE>
Status LoadBinary(Type::type type_id) {
out_->buffers.resize(3);
RETURN_NOT_OK(LoadCommon(type_id));
RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
return GetBuffer(buffer_index_++, &out_->buffers[2]);
}
template <typename TYPE>
Status LoadList(const TYPE& type) {
out_->buffers.resize(2);
RETURN_NOT_OK(LoadCommon(type.id()));
RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
const int num_children = type.num_fields();
if (num_children != 1) {
return Status::Invalid("Wrong number of children: ", num_children);
}
return LoadChildren(type.fields());
}
Status LoadChildren(const std::vector<std::shared_ptr<Field>>& child_fields) {
ArrayData* parent = out_;
parent->child_data.resize(child_fields.size());
for (int i = 0; i < static_cast<int>(child_fields.size()); ++i) {
parent->child_data[i] = std::make_shared<ArrayData>();
--max_recursion_depth_;
RETURN_NOT_OK(Load(child_fields[i].get(), parent->child_data[i].get()));
++max_recursion_depth_;
}
out_ = parent;
return Status::OK();
}
Status Visit(const NullType& type) {
out_->buffers.resize(1);
// ARROW-6379: NullType has no buffers in the IPC payload
return GetFieldMetadata(field_index_++, out_);
}
template <typename T>
enable_if_t<std::is_base_of<FixedWidthType, T>::value &&
!std::is_base_of<FixedSizeBinaryType, T>::value &&
!std::is_base_of<DictionaryType, T>::value,
Status>
Visit(const T& type) {
return LoadPrimitive<T>(type.id());
}
template <typename T>
enable_if_base_binary<T, Status> Visit(const T& type) {
return LoadBinary<T>(type.id());
}
Status Visit(const FixedSizeBinaryType& type) {
out_->buffers.resize(2);
RETURN_NOT_OK(LoadCommon(type.id()));
return GetBuffer(buffer_index_++, &out_->buffers[1]);
}
template <typename T>
enable_if_var_size_list<T, Status> Visit(const T& type) {
return LoadList(type);
}
Status Visit(const MapType& type) {
RETURN_NOT_OK(LoadList(type));
return MapArray::ValidateChildData(out_->child_data);
}
Status Visit(const FixedSizeListType& type) {
out_->buffers.resize(1);
RETURN_NOT_OK(LoadCommon(type.id()));
const int num_children = type.num_fields();
if (num_children != 1) {
return Status::Invalid("Wrong number of children: ", num_children);
}
return LoadChildren(type.fields());
}
Status Visit(const StructType& type) {
out_->buffers.resize(1);
RETURN_NOT_OK(LoadCommon(type.id()));
return LoadChildren(type.fields());
}
Status Visit(const UnionType& type) {
int n_buffers = type.mode() == UnionMode::SPARSE ? 2 : 3;
out_->buffers.resize(n_buffers);
RETURN_NOT_OK(LoadCommon(type.id()));
// With metadata V4, we can get a validity bitmap.
// Trying to fix up union data to do without the top-level validity bitmap
// is hairy:
// - type ids must be rewritten to all have valid values (even for former
// null slots)
// - sparse union children must have their validity bitmaps rewritten
// by ANDing the top-level validity bitmap
// - dense union children must be rewritten (at least one of them)
// to insert the required null slots that were formerly omitted
// So instead we bail out.
if (out_->null_count != 0 && out_->buffers[0] != nullptr) {
return Status::Invalid(
"Cannot read pre-1.0.0 Union array with top-level validity bitmap");
}
out_->buffers[0] = nullptr;
out_->null_count = 0;
if (out_->length > 0) {
RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[1]));
if (type.mode() == UnionMode::DENSE) {
RETURN_NOT_OK(GetBuffer(buffer_index_ + 1, &out_->buffers[2]));
}
}
buffer_index_ += n_buffers - 1;
return LoadChildren(type.fields());
}
Status Visit(const DictionaryType& type) {
// out_->dictionary will be filled later in ResolveDictionaries()
return LoadType(*type.index_type());
}
Status Visit(const ExtensionType& type) { return LoadType(*type.storage_type()); }
private:
const flatbuf::RecordBatch* metadata_;
const MetadataVersion metadata_version_;
io::RandomAccessFile* file_;
int max_recursion_depth_;
int buffer_index_ = 0;
int field_index_ = 0;
bool skip_io_ = false;
const Field* field_;
ArrayData* out_;
};
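// Decompress a single IPC body buffer. A compressed buffer is laid out as an
// 8-byte little-endian uncompressed length followed by the compressed body,
// so any non-empty buffer shorter than 8 bytes is rejected as corrupted
// (null and zero-length buffers are passed through untouched).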
Result<std::shared_ptr<Buffer>> DecompressBuffer(const std::shared_ptr<Buffer>& buf,
const IpcReadOptions& options,
util::Codec* codec) {
if (buf == nullptr || buf->size() == 0) {
return buf;
}
if (buf->size() < 8) {
return Status::Invalid(
"Likely corrupted message, compressed buffers "
"are larger than 8 bytes by construction");
}
const uint8_t* data = buf->data();
int64_t compressed_size = buf->size() - sizeof(int64_t);
int64_t uncompressed_size = BitUtil::FromLittleEndian(util::SafeLoadAs<int64_t>(data));
ARROW_ASSIGN_OR_RAISE(auto uncompressed,
AllocateBuffer(uncompressed_size, options.memory_pool));
ARROW_ASSIGN_OR_RAISE(
int64_t actual_decompressed,
codec->Decompress(compressed_size, data + sizeof(int64_t), uncompressed_size,
uncompressed->mutable_data()));
if (actual_decompressed != uncompressed_size) {
return Status::Invalid("Failed to fully decompress buffer, expected ",
uncompressed_size, " bytes but decompressed ",
actual_decompressed);
}
return std::move(uncompressed);
}
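// Decompress every buffer in a vector of ArrayData, including nested child
// data, optionally in parallel depending on IpcReadOptions::use_threads.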
Status DecompressBuffers(Compression::type compression, const IpcReadOptions& options,
ArrayDataVector* fields) {
struct BufferAccumulator {
using BufferPtrVector = std::vector<std::shared_ptr<Buffer>*>;
void AppendFrom(const ArrayDataVector& fields) {
for (const auto& field : fields) {
for (auto& buffer : field->buffers) {
buffers_.push_back(&buffer);
}
AppendFrom(field->child_data);
}
}
BufferPtrVector Get(const ArrayDataVector& fields) && {
AppendFrom(fields);
return std::move(buffers_);
}
BufferPtrVector buffers_;
};
// Flatten all buffers
auto buffers = BufferAccumulator{}.Get(*fields);
std::unique_ptr<util::Codec> codec;
ARROW_ASSIGN_OR_RAISE(codec, util::Codec::Create(compression));
return ::arrow::internal::OptionalParallelFor(
options.use_threads, static_cast<int>(buffers.size()), [&](int i) {
ARROW_ASSIGN_OR_RAISE(*buffers[i],
DecompressBuffer(*buffers[i], options, codec.get()));
return Status::OK();
});
}
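// Load a record batch, optionally projecting it onto the fields selected by
// inclusion_mask. Excluded fields are still walked by the ArrayLoader so that
// its field and buffer indices stay in sync with the flatbuffer metadata.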
Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
const std::vector<bool>* inclusion_mask, const IpcReadContext& context,
io::RandomAccessFile* file) {
ArrayLoader loader(metadata, context.metadata_version, context.options, file);
ArrayDataVector columns(schema->num_fields());
ArrayDataVector filtered_columns;
FieldVector filtered_fields;
std::shared_ptr<Schema> filtered_schema;
for (int i = 0; i < schema->num_fields(); ++i) {
const Field& field = *schema->field(i);
if (!inclusion_mask || (*inclusion_mask)[i]) {
// Read field
auto column = std::make_shared<ArrayData>();
RETURN_NOT_OK(loader.Load(&field, column.get()));
if (metadata->length() != column->length) {
return Status::IOError("Array length did not match record batch length");
}
columns[i] = std::move(column);
if (inclusion_mask) {
filtered_columns.push_back(columns[i]);
filtered_fields.push_back(schema->field(i));
}
} else {
// Skip field. This logic must be executed to advance the state of the
// loader to the next field
RETURN_NOT_OK(loader.SkipField(&field));
}
}
// Dictionary resolution needs to happen on the unfiltered columns,
// because fields are mapped structurally (by path in the original schema).
RETURN_NOT_OK(ResolveDictionaries(columns, *context.dictionary_memo,
context.options.memory_pool));
if (inclusion_mask) {
filtered_schema = ::arrow::schema(std::move(filtered_fields), schema->metadata());
columns.clear();
} else {
filtered_schema = schema;
filtered_columns = std::move(columns);
}
if (context.compression != Compression::UNCOMPRESSED) {
RETURN_NOT_OK(
DecompressBuffers(context.compression, context.options, &filtered_columns));
}
// swap endian in a set of ArrayData if necessary (swap_endian == true)
if (context.swap_endian) {
for (int i = 0; i < static_cast<int>(filtered_columns.size()); ++i) {
ARROW_ASSIGN_OR_RAISE(filtered_columns[i],
arrow::internal::SwapEndianArrayData(filtered_columns[i]));
}
}
return RecordBatch::Make(std::move(filtered_schema), metadata->length(),
std::move(filtered_columns));
}
Result<std::shared_ptr<RecordBatch>> LoadRecordBatch(
const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
const std::vector<bool>& inclusion_mask, const IpcReadContext& context,
io::RandomAccessFile* file) {
if (inclusion_mask.size() > 0) {
return LoadRecordBatchSubset(metadata, schema, &inclusion_mask, context, file);
} else {
return LoadRecordBatchSubset(metadata, schema, /*inclusion_mask=*/nullptr, context, file);
}
}
// ----------------------------------------------------------------------
// Array loading
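// Extract the compression codec, if any, from the RecordBatch body
// compression metadata. Only the BUFFER compression method with the
// LZ4_FRAME or ZSTD codecs is supported.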
Status GetCompression(const flatbuf::RecordBatch* batch, Compression::type* out) {
*out = Compression::UNCOMPRESSED;
const flatbuf::BodyCompression* compression = batch->compression();
if (compression != nullptr) {
if (compression->method() != flatbuf::BodyCompressionMethod::BUFFER) {
// Forward compatibility
return Status::Invalid("This library only supports BUFFER compression method");
}
if (compression->codec() == flatbuf::CompressionType::LZ4_FRAME) {
*out = Compression::LZ4_FRAME;
} else if (compression->codec() == flatbuf::CompressionType::ZSTD) {
*out = Compression::ZSTD;
} else {
return Status::Invalid("Unsupported codec in RecordBatch::compression metadata");
}
return Status::OK();
}
return Status::OK();
}
Status GetCompressionExperimental(const flatbuf::Message* message,
Compression::type* out) {
*out = Compression::UNCOMPRESSED;
if (message->custom_metadata() != nullptr) {
// TODO: Ensure this deserialization only ever happens once
std::shared_ptr<KeyValueMetadata> metadata;
RETURN_NOT_OK(internal::GetKeyValueMetadata(message->custom_metadata(), &metadata));
int index = metadata->FindKey("ARROW:experimental_compression");
if (index != -1) {
// Arrow 0.17 stored the string in upper case; internal utils now require lower case
auto name = arrow::internal::AsciiToLower(metadata->value(index));
ARROW_ASSIGN_OR_RAISE(*out, util::Codec::GetCompressionType(name));
}
return internal::CheckCompressionSupported(*out);
}
return Status::OK();
}
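// Read an entire encapsulated IPC message (metadata plus body) from a stream.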
static Status ReadContiguousPayload(io::InputStream* file,
std::unique_ptr<Message>* message) {
ARROW_ASSIGN_OR_RAISE(*message, ReadMessage(file));
if (*message == nullptr) {
return Status::Invalid("Unable to read metadata at offset");
}
return Status::OK();
}
Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
const std::shared_ptr<Schema>& schema, const DictionaryMemo* dictionary_memo,
const IpcReadOptions& options, io::InputStream* file) {
std::unique_ptr<Message> message;
RETURN_NOT_OK(ReadContiguousPayload(file, &message));
CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
return ReadRecordBatch(*message->metadata(), schema, dictionary_memo, options,
reader.get());
}
Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
const Message& message, const std::shared_ptr<Schema>& schema,
const DictionaryMemo* dictionary_memo, const IpcReadOptions& options) {
CHECK_MESSAGE_TYPE(MessageType::RECORD_BATCH, message.type());
CHECK_HAS_BODY(message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
return ReadRecordBatch(*message.metadata(), schema, dictionary_memo, options,
reader.get());
}
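// Verify the flatbuffer-encoded RecordBatch header, determine the compression
// codec and metadata version, then load the columns from the message body.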
Result<std::shared_ptr<RecordBatch>> ReadRecordBatchInternal(
const Buffer& metadata, const std::shared_ptr<Schema>& schema,
const std::vector<bool>& inclusion_mask, IpcReadContext& context,
io::RandomAccessFile* file) {
const flatbuf::Message* message = nullptr;
RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
auto batch = message->header_as_RecordBatch();
if (batch == nullptr) {
return Status::IOError(
"Header-type of flatbuffer-encoded Message is not RecordBatch.");
}
Compression::type compression;
RETURN_NOT_OK(GetCompression(batch, &compression));
if (context.compression == Compression::UNCOMPRESSED &&
message->version() == flatbuf::MetadataVersion::V4) {
// Possibly obtain codec information from experimental serialization format
// in 0.17.x
RETURN_NOT_OK(GetCompressionExperimental(message, &compression));
}
context.compression = compression;
context.metadata_version = internal::GetMetadataVersion(message->version());
return LoadRecordBatch(batch, schema, inclusion_mask, context, file);
}
// If we are selecting only certain fields, populate an inclusion mask for fast lookups.
// Additionally, drop deselected fields from the reader's schema.
Status GetInclusionMaskAndOutSchema(const std::shared_ptr<Schema>& full_schema,
const std::vector<int>& included_indices,
std::vector<bool>* inclusion_mask,
std::shared_ptr<Schema>* out_schema) {
inclusion_mask->clear();
if (included_indices.empty()) {
*out_schema = full_schema;
return Status::OK();
}
inclusion_mask->resize(full_schema->num_fields(), false);
auto included_indices_sorted = included_indices;
std::sort(included_indices_sorted.begin(), included_indices_sorted.end());
FieldVector included_fields;
for (int i : included_indices_sorted) {
// Out of bounds indices are rejected with an error
if (i < 0 || i >= full_schema->num_fields()) {
return Status::Invalid("Out of bounds field index: ", i);
}
if (inclusion_mask->at(i)) continue;
inclusion_mask->at(i) = true;
included_fields.push_back(full_schema->field(i));
}
*out_schema = schema(std::move(included_fields), full_schema->endianness(),
full_schema->metadata());
return Status::OK();
}
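// Unpack a Schema message: reconstruct the schema (registering its dictionary
// fields in the memo), compute the field inclusion mask, and decide whether
// array data will need to be byte-swapped to the native endianness.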
Status UnpackSchemaMessage(const void* opaque_schema, const IpcReadOptions& options,
DictionaryMemo* dictionary_memo,
std::shared_ptr<Schema>* schema,
std::shared_ptr<Schema>* out_schema,
std::vector<bool>* field_inclusion_mask, bool* swap_endian) {
RETURN_NOT_OK(internal::GetSchema(opaque_schema, dictionary_memo, schema));
// If we are selecting only certain fields, populate the inclusion mask now
// for fast lookups
RETURN_NOT_OK(GetInclusionMaskAndOutSchema(*schema, options.included_fields,
field_inclusion_mask, out_schema));
*swap_endian = options.ensure_native_endian && !out_schema->get()->is_native_endian();
if (*swap_endian) {
// create a new schema with native endianness before swapping endian in ArrayData
*schema = schema->get()->WithEndianness(Endianness::Native);
*out_schema = out_schema->get()->WithEndianness(Endianness::Native);
}
return Status::OK();
}
Status UnpackSchemaMessage(const Message& message, const IpcReadOptions& options,
DictionaryMemo* dictionary_memo,
std::shared_ptr<Schema>* schema,
std::shared_ptr<Schema>* out_schema,
std::vector<bool>* field_inclusion_mask, bool* swap_endian) {
CHECK_MESSAGE_TYPE(MessageType::SCHEMA, message.type());
CHECK_HAS_NO_BODY(message);
return UnpackSchemaMessage(message.header(), options, dictionary_memo, schema,
out_schema, field_inclusion_mask, swap_endian);
}
Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
const Buffer& metadata, const std::shared_ptr<Schema>& schema,
const DictionaryMemo* dictionary_memo, const IpcReadOptions& options,
io::RandomAccessFile* file) {
std::shared_ptr<Schema> out_schema;
// An empty inclusion mask means all fields are included
std::vector<bool> inclusion_mask;
IpcReadContext context(const_cast<DictionaryMemo*>(dictionary_memo), options, false);
RETURN_NOT_OK(GetInclusionMaskAndOutSchema(schema, context.options.included_fields,
&inclusion_mask, &out_schema));
return ReadRecordBatchInternal(metadata, schema, inclusion_mask, context, file);
}
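// Read a DictionaryBatch message. The dictionary values are encoded as a
// record batch with a single column and are recorded in the DictionaryMemo as
// a new dictionary, a delta, or a replacement.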
Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
DictionaryKind* kind, io::RandomAccessFile* file) {
const flatbuf::Message* message = nullptr;
RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
const auto dictionary_batch = message->header_as_DictionaryBatch();
if (dictionary_batch == nullptr) {
return Status::IOError(
"Header-type of flatbuffer-encoded Message is not DictionaryBatch.");
}
// The dictionary is embedded in a record batch with a single column
const auto batch_meta = dictionary_batch->data();
CHECK_FLATBUFFERS_NOT_NULL(batch_meta, "DictionaryBatch.data");
Compression::type compression;
RETURN_NOT_OK(GetCompression(batch_meta, &compression));
if (compression == Compression::UNCOMPRESSED &&
message->version() == flatbuf::MetadataVersion::V4) {
// Possibly obtain codec information from experimental serialization format
// in 0.17.x
RETURN_NOT_OK(GetCompressionExperimental(message, &compression));
}
const int64_t id = dictionary_batch->id();
// Look up the dictionary value type, which must have been added to the
// DictionaryMemo already prior to invoking this function
ARROW_ASSIGN_OR_RAISE(auto value_type, context.dictionary_memo->GetDictionaryType(id));
// Load the dictionary data from the dictionary batch
ArrayLoader loader(batch_meta, internal::GetMetadataVersion(message->version()),
context.options, file);
auto dict_data = std::make_shared<ArrayData>();
const Field dummy_field("", value_type);
RETURN_NOT_OK(loader.Load(&dummy_field, dict_data.get()));
if (compression != Compression::UNCOMPRESSED) {
ArrayDataVector dict_fields{dict_data};
RETURN_NOT_OK(DecompressBuffers(compression, context.options, &dict_fields));
}
// swap endian in dict_data if necessary (swap_endian == true)
if (context.swap_endian) {
ARROW_ASSIGN_OR_RAISE(dict_data, ::arrow::internal::SwapEndianArrayData(dict_data));
}
if (dictionary_batch->isDelta()) {
if (kind != nullptr) {
*kind = DictionaryKind::Delta;
}
return context.dictionary_memo->AddDictionaryDelta(id, dict_data);
}
ARROW_ASSIGN_OR_RAISE(bool inserted,
context.dictionary_memo->AddOrReplaceDictionary(id, dict_data));
if (kind != nullptr) {
*kind = inserted ? DictionaryKind::New : DictionaryKind::Replacement;
}
return Status::OK();
}
Status ReadDictionary(const Message& message, const IpcReadContext& context,
DictionaryKind* kind) {
// Only invoke this method if we already know we have a dictionary message
DCHECK_EQ(message.type(), MessageType::DICTIONARY_BATCH);
CHECK_HAS_BODY(message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
return ReadDictionary(*message.metadata(), context, kind, reader.get());
}
// ----------------------------------------------------------------------
// RecordBatchStreamReader implementation
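// A stream is expected to carry a Schema message, then the dictionaries
// referenced by that schema, then zero or more RecordBatch messages, possibly
// interleaved with further dictionary batches (deltas or replacements).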
class RecordBatchStreamReaderImpl : public RecordBatchStreamReader {
public:
Status Open(std::unique_ptr<MessageReader> message_reader,
const IpcReadOptions& options) {
message_reader_ = std::move(message_reader);
options_ = options;
// Read schema
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message, ReadNextMessage());
if (!message) {
return Status::Invalid("Tried reading schema message, was null or length 0");
}
RETURN_NOT_OK(UnpackSchemaMessage(*message, options, &dictionary_memo_, &schema_,
&out_schema_, &field_inclusion_mask_,
&swap_endian_));
return Status::OK();
}
Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
if (!have_read_initial_dictionaries_) {
RETURN_NOT_OK(ReadInitialDictionaries());
}
if (empty_stream_) {
// ARROW-6006: Degenerate case where stream contains no data, we do not
// bother trying to read a RecordBatch message from the stream
*batch = nullptr;
return Status::OK();
}
// Continue to read other dictionaries, if any
std::unique_ptr<Message> message;
ARROW_ASSIGN_OR_RAISE(message, ReadNextMessage());
while (message != nullptr && message->type() == MessageType::DICTIONARY_BATCH) {
RETURN_NOT_OK(ReadDictionary(*message));
ARROW_ASSIGN_OR_RAISE(message, ReadNextMessage());
}
if (message == nullptr) {
// End of stream
*batch = nullptr;
return Status::OK();
}
CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
context, reader.get())
.Value(batch);
}
std::shared_ptr<Schema> schema() const override { return out_schema_; }
ReadStats stats() const override { return stats_; }
private:
Result<std::unique_ptr<Message>> ReadNextMessage() {
ARROW_ASSIGN_OR_RAISE(auto message, message_reader_->ReadNextMessage());
if (message) {
++stats_.num_messages;
switch (message->type()) {
case MessageType::RECORD_BATCH:
++stats_.num_record_batches;
break;
case MessageType::DICTIONARY_BATCH:
++stats_.num_dictionary_batches;
break;
default:
break;
}
}
return std::move(message);
}
// Read dictionary from dictionary batch
Status ReadDictionary(const Message& message) {
DictionaryKind kind;
IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
RETURN_NOT_OK(::arrow::ipc::ReadDictionary(message, context, &kind));
switch (kind) {
case DictionaryKind::New:
break;
case DictionaryKind::Delta:
++stats_.num_dictionary_deltas;
break;
case DictionaryKind::Replacement:
++stats_.num_replaced_dictionaries;
break;
}
return Status::OK();
}
Status ReadInitialDictionaries() {
// We must receive all dictionaries before reconstructing the
// first record batch. Subsequent dictionary deltas modify the memo
std::unique_ptr<Message> message;
// TODO(wesm): In future, we may want to reconcile the ids in the stream with
// those found in the schema
const auto num_dicts = dictionary_memo_.fields().num_dicts();
for (int i = 0; i < num_dicts; ++i) {
ARROW_ASSIGN_OR_RAISE(message, ReadNextMessage());
if (!message) {
if (i == 0) {
/// ARROW-6006: If we fail to find any dictionaries in the stream, then
/// it may be that the stream has a schema but no actual data. In such
/// case we communicate that we were unable to find the dictionaries
/// (but there was no failure otherwise), so the caller can decide what
/// to do
empty_stream_ = true;
break;
} else {
// ARROW-6126, the stream terminated before receiving the expected
// number of dictionaries
return Status::Invalid("IPC stream ended without reading the expected number (",
num_dicts, ") of dictionaries");
}
}
if (message->type() != MessageType::DICTIONARY_BATCH) {
return Status::Invalid("IPC stream did not have the expected number (", num_dicts,
") of dictionaries at the start of the stream");
}
RETURN_NOT_OK(ReadDictionary(*message));
}
have_read_initial_dictionaries_ = true;
return Status::OK();
}
std::unique_ptr<MessageReader> message_reader_;
IpcReadOptions options_;
std::vector<bool> field_inclusion_mask_;
bool have_read_initial_dictionaries_ = false;
// Flag set when we fail to observe any dictionaries in the stream (e.g. a
// schema-only stream); the reader then does not attempt to parse any messages
bool empty_stream_ = false;
ReadStats stats_;
DictionaryMemo dictionary_memo_;
std::shared_ptr<Schema> schema_, out_schema_;
bool swap_endian_;
};
// ----------------------------------------------------------------------
// Stream reader constructors
Result<std::shared_ptr<RecordBatchStreamReader>> RecordBatchStreamReader::Open(
std::unique_ptr<MessageReader> message_reader, const IpcReadOptions& options) {
// Private ctor
auto result = std::make_shared<RecordBatchStreamReaderImpl>();
RETURN_NOT_OK(result->Open(std::move(message_reader), options));
return result;
}
Result<std::shared_ptr<RecordBatchStreamReader>> RecordBatchStreamReader::Open(
io::InputStream* stream, const IpcReadOptions& options) {
return Open(MessageReader::Open(stream), options);
}
Result<std::shared_ptr<RecordBatchStreamReader>> RecordBatchStreamReader::Open(
const std::shared_ptr<io::InputStream>& stream, const IpcReadOptions& options) {
return Open(MessageReader::Open(stream), options);
}
// ----------------------------------------------------------------------
// Reader implementation
static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) {
return FileBlock{block->offset(), block->metaDataLength(), block->bodyLength()};
}
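// Reads the IPC file format: a footer at the end of the file records the
// location of every dictionary and record batch block, so record batches can
// be accessed in any order.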
class RecordBatchFileReaderImpl : public RecordBatchFileReader {
public:
RecordBatchFileReaderImpl() : file_(NULLPTR), footer_offset_(0), footer_(NULLPTR) {}
int num_record_batches() const override {
return static_cast<int>(internal::FlatBuffersVectorSize(footer_->recordBatches()));
}
MetadataVersion version() const override {
return internal::GetMetadataVersion(footer_->version());
}
Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(int i) override {
DCHECK_GE(i, 0);
DCHECK_LT(i, num_record_batches());
if (!read_dictionaries_) {
RETURN_NOT_OK(ReadDictionaries());
read_dictionaries_ = true;
}
ARROW_ASSIGN_OR_RAISE(auto message, ReadMessageFromBlock(GetRecordBatchBlock(i)));
CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
ARROW_ASSIGN_OR_RAISE(auto batch, ReadRecordBatchInternal(
*message->metadata(), schema_,
field_inclusion_mask_, context, reader.get()));
++stats_.num_record_batches;
return batch;
}
Result<int64_t> CountRows() override {
int64_t total = 0;
for (int i = 0; i < num_record_batches(); i++) {
ARROW_ASSIGN_OR_RAISE(auto outer_message,
ReadMessageFromBlock(GetRecordBatchBlock(i)));
auto metadata = outer_message->metadata();
const flatbuf::Message* message = nullptr;
RETURN_NOT_OK(
internal::VerifyMessage(metadata->data(), metadata->size(), &message));
auto batch = message->header_as_RecordBatch();
if (batch == nullptr) {
return Status::IOError(
"Header-type of flatbuffer-encoded Message is not RecordBatch.");
}
total += batch->length();
}
return total;
}
Status Open(const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
const IpcReadOptions& options) {
owned_file_ = file;
return Open(file.get(), footer_offset, options);
}
Status Open(io::RandomAccessFile* file, int64_t footer_offset,
const IpcReadOptions& options) {
file_ = file;
options_ = options;
footer_offset_ = footer_offset;
RETURN_NOT_OK(ReadFooter());
// Get the schema and record any observed dictionaries
RETURN_NOT_OK(UnpackSchemaMessage(footer_->schema(), options, &dictionary_memo_,
&schema_, &out_schema_, &field_inclusion_mask_,
&swap_endian_));
++stats_.num_messages;
return Status::OK();
}
std::shared_ptr<Schema> schema() const override { return out_schema_; }
std::shared_ptr<const KeyValueMetadata> metadata() const override { return metadata_; }
ReadStats stats() const override { return stats_; }
private:
FileBlock GetRecordBatchBlock(int i) const {
return FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i));
}
FileBlock GetDictionaryBlock(int i) const {
return FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i));
}
Result<std::unique_ptr<Message>> ReadMessageFromBlock(const FileBlock& block) {
if (!BitUtil::IsMultipleOf8(block.offset) ||
!BitUtil::IsMultipleOf8(block.metadata_length) ||
!BitUtil::IsMultipleOf8(block.body_length)) {
return Status::Invalid("Unaligned block in IPC file");
}
// TODO(wesm): this breaks integration tests, see ARROW-3256
// DCHECK_EQ((*out)->body_length(), block.body_length);
ARROW_ASSIGN_OR_RAISE(auto message,
ReadMessage(block.offset, block.metadata_length, file_));
++stats_.num_messages;
return std::move(message);
}
Status ReadDictionaries() {
// Read all the dictionaries
for (int i = 0; i < num_dictionaries(); ++i) {
ARROW_ASSIGN_OR_RAISE(auto message, ReadMessageFromBlock(GetDictionaryBlock(i)));
CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
DictionaryKind kind;
IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
RETURN_NOT_OK(ReadDictionary(*message->metadata(), context, &kind, reader.get()));
++stats_.num_dictionary_batches;
if (kind != DictionaryKind::New) {
return Status::Invalid(
"Unsupported dictionary replacement or "
"dictionary delta in IPC file");
}
}
return Status::OK();
}
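// An Arrow file ends with <footer flatbuffer> <int32: footer length> <magic>.
// Read the trailing length and magic first, then the footer itself.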
Status ReadFooter() {
const int32_t magic_size = static_cast<int>(strlen(kArrowMagicBytes));
if (footer_offset_ <= magic_size * 2 + 4) {
return Status::Invalid("File is too small: ", footer_offset_);
}
int file_end_size = static_cast<int>(magic_size + sizeof(int32_t));
ARROW_ASSIGN_OR_RAISE(auto buffer,
file_->ReadAt(footer_offset_ - file_end_size, file_end_size));
const int64_t expected_footer_size = magic_size + sizeof(int32_t);
if (buffer->size() < expected_footer_size) {
return Status::Invalid("Unable to read ", expected_footer_size, "from end of file");
}
if (memcmp(buffer->data() + sizeof(int32_t), kArrowMagicBytes, magic_size)) {
return Status::Invalid("Not an Arrow file");
}
int32_t footer_length =
BitUtil::FromLittleEndian(*reinterpret_cast<const int32_t*>(buffer->data()));
if (footer_length <= 0 || footer_length > footer_offset_ - magic_size * 2 - 4) {
return Status::Invalid("File is smaller than indicated metadata size");
}
// Now read the footer
ARROW_ASSIGN_OR_RAISE(
footer_buffer_,
file_->ReadAt(footer_offset_ - footer_length - file_end_size, footer_length));
const auto data = footer_buffer_->data();
const auto size = footer_buffer_->size();
if (!internal::VerifyFlatbuffers<flatbuf::Footer>(data, size)) {
return Status::IOError("Verification of flatbuffer-encoded Footer failed.");
}
footer_ = flatbuf::GetFooter(data);
auto fb_metadata = footer_->custom_metadata();
if (fb_metadata != nullptr) {
std::shared_ptr<KeyValueMetadata> md;
RETURN_NOT_OK(internal::GetKeyValueMetadata(fb_metadata, &md));
metadata_ = std::move(md); // const-ify
}
return Status::OK();
}
int num_dictionaries() const {
return static_cast<int>(internal::FlatBuffersVectorSize(footer_->dictionaries()));
}
io::RandomAccessFile* file_;
IpcReadOptions options_;
std::vector<bool> field_inclusion_mask_;
std::shared_ptr<io::RandomAccessFile> owned_file_;
// The location where the Arrow file layout ends. May be the end of the file
// or some other location if embedded in a larger file.
int64_t footer_offset_;
// Footer metadata
std::shared_ptr<Buffer> footer_buffer_;
const flatbuf::Footer* footer_;
std::shared_ptr<const KeyValueMetadata> metadata_;
bool read_dictionaries_ = false;
DictionaryMemo dictionary_memo_;
// Reconstructed schema, including any read dictionaries
std::shared_ptr<Schema> schema_;
// Schema with deselected fields dropped
std::shared_ptr<Schema> out_schema_;
ReadStats stats_;
bool swap_endian_;
};
Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
io::RandomAccessFile* file, const IpcReadOptions& options) {
ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
return Open(file, footer_offset, options);
}
Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
io::RandomAccessFile* file, int64_t footer_offset, const IpcReadOptions& options) {
auto result = std::make_shared<RecordBatchFileReaderImpl>();
RETURN_NOT_OK(result->Open(file, footer_offset, options));
return result;
}
Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
const std::shared_ptr<io::RandomAccessFile>& file, const IpcReadOptions& options) {
ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
return Open(file, footer_offset, options);
}
Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
const IpcReadOptions& options) {
auto result = std::make_shared<RecordBatchFileReaderImpl>();
RETURN_NOT_OK(result->Open(file, footer_offset, options));
return result;
}
Status Listener::OnEOS() { return Status::OK(); }
Status Listener::OnSchemaDecoded(std::shared_ptr<Schema> schema) { return Status::OK(); }
Status Listener::OnRecordBatchDecoded(std::shared_ptr<RecordBatch> record_batch) {
return Status::NotImplemented("OnRecordBatchDecoded() callback isn't implemented");
}
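// Push-based counterpart of RecordBatchStreamReader: callers feed arbitrary
// chunks of bytes through Consume() and decoded schemas and record batches
// are delivered via the Listener callbacks. Internally this is a small state
// machine: SCHEMA -> INITIAL_DICTIONARIES -> RECORD_BATCHES -> EOS.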
class StreamDecoder::StreamDecoderImpl : public MessageDecoderListener {
private:
enum State {
SCHEMA,
INITIAL_DICTIONARIES,
RECORD_BATCHES,
EOS,
};
public:
explicit StreamDecoderImpl(std::shared_ptr<Listener> listener, IpcReadOptions options)
: listener_(std::move(listener)),
options_(std::move(options)),
state_(State::SCHEMA),
message_decoder_(std::shared_ptr<StreamDecoderImpl>(this, [](void*) {}),
options_.memory_pool),
n_required_dictionaries_(0) {}
Status OnMessageDecoded(std::unique_ptr<Message> message) override {
++stats_.num_messages;
switch (state_) {
case State::SCHEMA:
ARROW_RETURN_NOT_OK(OnSchemaMessageDecoded(std::move(message)));
break;
case State::INITIAL_DICTIONARIES:
ARROW_RETURN_NOT_OK(OnInitialDictionaryMessageDecoded(std::move(message)));
break;
case State::RECORD_BATCHES:
ARROW_RETURN_NOT_OK(OnRecordBatchMessageDecoded(std::move(message)));
break;
case State::EOS:
break;
}
return Status::OK();
}
Status OnEOS() override {
state_ = State::EOS;
return listener_->OnEOS();
}
Status Consume(const uint8_t* data, int64_t size) {
return message_decoder_.Consume(data, size);
}
Status Consume(std::shared_ptr<Buffer> buffer) {
return message_decoder_.Consume(std::move(buffer));
}
std::shared_ptr<Schema> schema() const { return out_schema_; }
int64_t next_required_size() const { return message_decoder_.next_required_size(); }
ReadStats stats() const { return stats_; }
private:
Status OnSchemaMessageDecoded(std::unique_ptr<Message> message) {
RETURN_NOT_OK(UnpackSchemaMessage(*message, options_, &dictionary_memo_, &schema_,
&out_schema_, &field_inclusion_mask_,
&swap_endian_));
n_required_dictionaries_ = dictionary_memo_.fields().num_fields();
if (n_required_dictionaries_ == 0) {
state_ = State::RECORD_BATCHES;
RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_));
} else {
state_ = State::INITIAL_DICTIONARIES;
}
return Status::OK();
}
Status OnInitialDictionaryMessageDecoded(std::unique_ptr<Message> message) {
if (message->type() != MessageType::DICTIONARY_BATCH) {
return Status::Invalid("IPC stream did not have the expected number (",
dictionary_memo_.fields().num_fields(),
") of dictionaries at the start of the stream");
}
RETURN_NOT_OK(ReadDictionary(*message));
n_required_dictionaries_--;
if (n_required_dictionaries_ == 0) {
state_ = State::RECORD_BATCHES;
ARROW_RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_));
}
return Status::OK();
}
Status OnRecordBatchMessageDecoded(std::unique_ptr<Message> message) {
if (message->type() == MessageType::DICTIONARY_BATCH) {
return ReadDictionary(*message);
} else {
CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
ARROW_ASSIGN_OR_RAISE(
auto batch,
ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
context, reader.get()));
++stats_.num_record_batches;
return listener_->OnRecordBatchDecoded(std::move(batch));
}
}
// Read dictionary from dictionary batch
Status ReadDictionary(const Message& message) {
DictionaryKind kind;
IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
RETURN_NOT_OK(::arrow::ipc::ReadDictionary(message, context, &kind));
++stats_.num_dictionary_batches;
switch (kind) {
case DictionaryKind::New:
break;
case DictionaryKind::Delta:
++stats_.num_dictionary_deltas;
break;
case DictionaryKind::Replacement:
++stats_.num_replaced_dictionaries;
break;
}
return Status::OK();
}
std::shared_ptr<Listener> listener_;
const IpcReadOptions options_;
State state_;
MessageDecoder message_decoder_;
std::vector<bool> field_inclusion_mask_;
int n_required_dictionaries_;
DictionaryMemo dictionary_memo_;
std::shared_ptr<Schema> schema_, out_schema_;
ReadStats stats_;
bool swap_endian_;
};
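// Typical usage (sketch): implement Listener::OnRecordBatchDecoded(), build a
// StreamDecoder around that listener, then call Consume() with whatever bytes
// become available; next_required_size() hints how many more bytes the
// decoder needs before it can make further progress.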
StreamDecoder::StreamDecoder(std::shared_ptr<Listener> listener, IpcReadOptions options) {
impl_.reset(new StreamDecoderImpl(std::move(listener), options));
}
StreamDecoder::~StreamDecoder() {}
Status StreamDecoder::Consume(const uint8_t* data, int64_t size) {
return impl_->Consume(data, size);
}
Status StreamDecoder::Consume(std::shared_ptr<Buffer> buffer) {
return impl_->Consume(std::move(buffer));
}
std::shared_ptr<Schema> StreamDecoder::schema() const { return impl_->schema(); }
int64_t StreamDecoder::next_required_size() const { return impl_->next_required_size(); }
ReadStats StreamDecoder::stats() const { return impl_->stats(); }
Result<std::shared_ptr<Schema>> ReadSchema(io::InputStream* stream,
DictionaryMemo* dictionary_memo) {
std::unique_ptr<MessageReader> reader = MessageReader::Open(stream);
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message, reader->ReadNextMessage());
if (!message) {
return Status::Invalid("Tried reading schema message, was null or length 0");
}
CHECK_MESSAGE_TYPE(MessageType::SCHEMA, message->type());
return ReadSchema(*message, dictionary_memo);
}
Result<std::shared_ptr<Schema>> ReadSchema(const Message& message,
DictionaryMemo* dictionary_memo) {
std::shared_ptr<Schema> result;
RETURN_NOT_OK(internal::GetSchema(message.header(), dictionary_memo, &result));
return result;
}
Result<std::shared_ptr<Tensor>> ReadTensor(io::InputStream* file) {
std::unique_ptr<Message> message;
RETURN_NOT_OK(ReadContiguousPayload(file, &message));
return ReadTensor(*message);
}
Result<std::shared_ptr<Tensor>> ReadTensor(const Message& message) {
std::shared_ptr<DataType> type;
std::vector<int64_t> shape;
std::vector<int64_t> strides;
std::vector<std::string> dim_names;
CHECK_HAS_BODY(message);
RETURN_NOT_OK(internal::GetTensorMetadata(*message.metadata(), &type, &shape, &strides,
&dim_names));
return Tensor::Make(type, message.body(), shape, strides, dim_names);
}
namespace {
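// Reconstruct a SparseCOOIndex from the flatbuffer metadata. The indices form
// a (non_zero_length x ndim) tensor; if no strides are given, a row-major
// layout is assumed.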
Result<std::shared_ptr<SparseIndex>> ReadSparseCOOIndex(
const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
int64_t non_zero_length, io::RandomAccessFile* file) {
auto* sparse_index = sparse_tensor->sparseIndex_as_SparseTensorIndexCOO();
const auto ndim = static_cast<int64_t>(shape.size());
std::shared_ptr<DataType> indices_type;
RETURN_NOT_OK(internal::GetSparseCOOIndexMetadata(sparse_index, &indices_type));
const int64_t indices_elsize = GetByteWidth(*indices_type);
auto* indices_buffer = sparse_index->indicesBuffer();
ARROW_ASSIGN_OR_RAISE(auto indices_data,
file->ReadAt(indices_buffer->offset(), indices_buffer->length()));
std::vector<int64_t> indices_shape({non_zero_length, ndim});
auto* indices_strides = sparse_index->indicesStrides();
std::vector<int64_t> strides(2);
if (indices_strides && indices_strides->size() > 0) {
if (indices_strides->size() != 2) {
return Status::Invalid("Wrong size for indicesStrides in SparseCOOIndex");
}
strides[0] = indices_strides->Get(0);
strides[1] = indices_strides->Get(1);
} else {
// Row-major by default
strides[0] = indices_elsize * ndim;
strides[1] = indices_elsize;
}
return SparseCOOIndex::Make(
std::make_shared<Tensor>(indices_type, indices_data, indices_shape, strides),
sparse_index->isCanonical());
}
Result<std::shared_ptr<SparseIndex>> ReadSparseCSXIndex(
const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
int64_t non_zero_length, io::RandomAccessFile* file) {
if (shape.size() != 2) {
return Status::Invalid("Invalid shape length for a sparse matrix");
}
auto* sparse_index = sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX();
std::shared_ptr<DataType> indptr_type, indices_type;
RETURN_NOT_OK(
internal::GetSparseCSXIndexMetadata(sparse_index, &indptr_type, &indices_type));
const int indptr_byte_width = GetByteWidth(*indptr_type);
auto* indptr_buffer = sparse_index->indptrBuffer();
ARROW_ASSIGN_OR_RAISE(auto indptr_data,
file->ReadAt(indptr_buffer->offset(), indptr_buffer->length()));
auto* indices_buffer = sparse_index->indicesBuffer();
ARROW_ASSIGN_OR_RAISE(auto indices_data,
file->ReadAt(indices_buffer->offset(), indices_buffer->length()));
std::vector<int64_t> indices_shape({non_zero_length});
const auto indices_minimum_bytes = indices_shape[0] * GetByteWidth(*indices_type);
if (indices_minimum_bytes > indices_buffer->length()) {
return Status::Invalid("shape is inconsistent to the size of indices buffer");
}
switch (sparse_index->compressedAxis()) {
case flatbuf::SparseMatrixCompressedAxis::Row: {
std::vector<int64_t> indptr_shape({shape[0] + 1});
const int64_t indptr_minimum_bytes = indptr_shape[0] * indptr_byte_width;
if (indptr_minimum_bytes > indptr_buffer->length()) {
return Status::Invalid("shape is inconsistent to the size of indptr buffer");
}
return std::make_shared<SparseCSRIndex>(
std::make_shared<Tensor>(indptr_type, indptr_data, indptr_shape),
std::make_shared<Tensor>(indices_type, indices_data, indices_shape));
}
case flatbuf::SparseMatrixCompressedAxis::Column: {
std::vector<int64_t> indptr_shape({shape[1] + 1});
const int64_t indptr_minimum_bytes = indptr_shape[0] * indptr_byte_width;
if (indptr_minimum_bytes > indptr_buffer->length()) {
return Status::Invalid("shape is inconsistent to the size of indptr buffer");
}
return std::make_shared<SparseCSCIndex>(
std::make_shared<Tensor>(indptr_type, indptr_data, indptr_shape),
std::make_shared<Tensor>(indices_type, indices_data, indices_shape));
}
default:
return Status::Invalid("Invalid value of SparseMatrixCompressedAxis");
}
}
Result<std::shared_ptr<SparseIndex>> ReadSparseCSFIndex(
const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
io::RandomAccessFile* file) {
auto* sparse_index = sparse_tensor->sparseIndex_as_SparseTensorIndexCSF();
const auto ndim = static_cast<int64_t>(shape.size());
auto* indptr_buffers = sparse_index->indptrBuffers();
auto* indices_buffers = sparse_index->indicesBuffers();
std::vector<std::shared_ptr<Buffer>> indptr_data(ndim - 1);
std::vector<std::shared_ptr<Buffer>> indices_data(ndim);
std::shared_ptr<DataType> indptr_type, indices_type;
std::vector<int64_t> axis_order, indices_size;
RETURN_NOT_OK(internal::GetSparseCSFIndexMetadata(
sparse_index, &axis_order, &indices_size, &indptr_type, &indices_type));
for (int i = 0; i < static_cast<int>(indptr_buffers->size()); ++i) {
ARROW_ASSIGN_OR_RAISE(indptr_data[i], file->ReadAt(indptr_buffers->Get(i)->offset(),
indptr_buffers->Get(i)->length()));
}
for (int i = 0; i < static_cast<int>(indices_buffers->size()); ++i) {
ARROW_ASSIGN_OR_RAISE(indices_data[i],
file->ReadAt(indices_buffers->Get(i)->offset(),
indices_buffers->Get(i)->length()));
}
return SparseCSFIndex::Make(indptr_type, indices_type, indices_size, axis_order,
indptr_data, indices_data);
}
Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCOOIndex(
const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
const std::vector<std::string>& dim_names,
const std::shared_ptr<SparseCOOIndex>& sparse_index, int64_t non_zero_length,
const std::shared_ptr<Buffer>& data) {
return SparseCOOTensor::Make(sparse_index, type, data, shape, dim_names);
}
Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSRIndex(
const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
const std::vector<std::string>& dim_names,
const std::shared_ptr<SparseCSRIndex>& sparse_index, int64_t non_zero_length,
const std::shared_ptr<Buffer>& data) {
return SparseCSRMatrix::Make(sparse_index, type, data, shape, dim_names);
}
Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSCIndex(
const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
const std::vector<std::string>& dim_names,
const std::shared_ptr<SparseCSCIndex>& sparse_index, int64_t non_zero_length,
const std::shared_ptr<Buffer>& data) {
return SparseCSCMatrix::Make(sparse_index, type, data, shape, dim_names);
}
Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSFIndex(
const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
const std::vector<std::string>& dim_names,
const std::shared_ptr<SparseCSFIndex>& sparse_index,
const std::shared_ptr<Buffer>& data) {
return SparseCSFTensor::Make(sparse_index, type, data, shape, dim_names);
}
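// Parse the SparseTensor flatbuffer metadata common to all sparse formats and
// check that the data buffer starts on an 8-byte aligned offset.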
Status ReadSparseTensorMetadata(const Buffer& metadata,
std::shared_ptr<DataType>* out_type,
std::vector<int64_t>* out_shape,
std::vector<std::string>* out_dim_names,
int64_t* out_non_zero_length,
SparseTensorFormat::type* out_format_id,
const flatbuf::SparseTensor** out_fb_sparse_tensor,
const flatbuf::Buffer** out_buffer) {
RETURN_NOT_OK(internal::GetSparseTensorMetadata(
metadata, out_type, out_shape, out_dim_names, out_non_zero_length, out_format_id));
const flatbuf::Message* message = nullptr;
RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
auto sparse_tensor = message->header_as_SparseTensor();
if (sparse_tensor == nullptr) {
return Status::IOError(
"Header-type of flatbuffer-encoded Message is not SparseTensor.");
}
*out_fb_sparse_tensor = sparse_tensor;
auto buffer = sparse_tensor->data();
if (!BitUtil::IsMultipleOf8(buffer->offset())) {
return Status::Invalid(
"Buffer of sparse index data did not start on 8-byte aligned offset: ",
buffer->offset());
}
*out_buffer = buffer;
return Status::OK();
}
} // namespace
namespace internal {
namespace {
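// Number of body buffers expected for each sparse format: COO has indices and
// data; CSR/CSC have indptr, indices and data; CSF has (ndim - 1) indptr
// buffers, ndim indices buffers, and the data buffer.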
Result<size_t> GetSparseTensorBodyBufferCount(SparseTensorFormat::type format_id,
const size_t ndim) {
switch (format_id) {
case SparseTensorFormat::COO:
return 2;
case SparseTensorFormat::CSR:
return 3;
case SparseTensorFormat::CSC:
return 3;
case SparseTensorFormat::CSF:
return 2 * ndim;
default:
return Status::Invalid("Unrecognized sparse tensor format");
}
}
Status CheckSparseTensorBodyBufferCount(const IpcPayload& payload,
SparseTensorFormat::type sparse_tensor_format_id,
const size_t ndim) {
size_t expected_body_buffer_count = 0;
ARROW_ASSIGN_OR_RAISE(expected_body_buffer_count,
GetSparseTensorBodyBufferCount(sparse_tensor_format_id, ndim));
if (payload.body_buffers.size() != expected_body_buffer_count) {
return Status::Invalid("Invalid body buffer count for a sparse tensor");
}
return Status::OK();
}
} // namespace
Result<size_t> ReadSparseTensorBodyBufferCount(const Buffer& metadata) {
SparseTensorFormat::type format_id;
std::vector<int64_t> shape;
RETURN_NOT_OK(internal::GetSparseTensorMetadata(metadata, nullptr, &shape, nullptr,
nullptr, &format_id));
return GetSparseTensorBodyBufferCount(format_id, static_cast<size_t>(shape.size()));
}
Result<std::shared_ptr<SparseTensor>> ReadSparseTensorPayload(const IpcPayload& payload) {
std::shared_ptr<DataType> type;
std::vector<int64_t> shape;
std::vector<std::string> dim_names;
int64_t non_zero_length;
SparseTensorFormat::type sparse_tensor_format_id;
const flatbuf::SparseTensor* sparse_tensor;
const flatbuf::Buffer* buffer;
RETURN_NOT_OK(ReadSparseTensorMetadata(*payload.metadata, &type, &shape, &dim_names,
&non_zero_length, &sparse_tensor_format_id,
&sparse_tensor, &buffer));
RETURN_NOT_OK(CheckSparseTensorBodyBufferCount(payload, sparse_tensor_format_id,
static_cast<size_t>(shape.size())));
switch (sparse_tensor_format_id) {
case SparseTensorFormat::COO: {
std::shared_ptr<SparseCOOIndex> sparse_index;
std::shared_ptr<DataType> indices_type;
RETURN_NOT_OK(internal::GetSparseCOOIndexMetadata(
sparse_tensor->sparseIndex_as_SparseTensorIndexCOO(), &indices_type));
ARROW_ASSIGN_OR_RAISE(sparse_index,
SparseCOOIndex::Make(indices_type, shape, non_zero_length,
payload.body_buffers[0]));
return MakeSparseTensorWithSparseCOOIndex(type, shape, dim_names, sparse_index,
non_zero_length, payload.body_buffers[1]);
}
case SparseTensorFormat::CSR: {
std::shared_ptr<SparseCSRIndex> sparse_index;
std::shared_ptr<DataType> indptr_type;
std::shared_ptr<DataType> indices_type;
RETURN_NOT_OK(internal::GetSparseCSXIndexMetadata(
sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX(), &indptr_type,
&indices_type));
ARROW_CHECK_EQ(indptr_type, indices_type);
ARROW_ASSIGN_OR_RAISE(
sparse_index,
SparseCSRIndex::Make(indices_type, shape, non_zero_length,
payload.body_buffers[0], payload.body_buffers[1]));
return MakeSparseTensorWithSparseCSRIndex(type, shape, dim_names, sparse_index,
non_zero_length, payload.body_buffers[2]);
}
case SparseTensorFormat::CSC: {
std::shared_ptr<SparseCSCIndex> sparse_index;
std::shared_ptr<DataType> indptr_type;
std::shared_ptr<DataType> indices_type;
RETURN_NOT_OK(internal::GetSparseCSXIndexMetadata(
sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX(), &indptr_type,
&indices_type));
ARROW_CHECK_EQ(indptr_type, indices_type);
ARROW_ASSIGN_OR_RAISE(
sparse_index,
SparseCSCIndex::Make(indices_type, shape, non_zero_length,
payload.body_buffers[0], payload.body_buffers[1]));
return MakeSparseTensorWithSparseCSCIndex(type, shape, dim_names, sparse_index,
non_zero_length, payload.body_buffers[2]);
}
case SparseTensorFormat::CSF: {
std::shared_ptr<SparseCSFIndex> sparse_index;
std::shared_ptr<DataType> indptr_type, indices_type;
std::vector<int64_t> axis_order, indices_size;
RETURN_NOT_OK(internal::GetSparseCSFIndexMetadata(
sparse_tensor->sparseIndex_as_SparseTensorIndexCSF(), &axis_order,
&indices_size, &indptr_type, &indices_type));
ARROW_CHECK_EQ(indptr_type, indices_type);
const int64_t ndim = shape.size();
std::vector<std::shared_ptr<Buffer>> indptr_data(ndim - 1);
std::vector<std::shared_ptr<Buffer>> indices_data(ndim);
for (int64_t i = 0; i < ndim - 1; ++i) {
indptr_data[i] = payload.body_buffers[i];
}
for (int64_t i = 0; i < ndim; ++i) {
indices_data[i] = payload.body_buffers[i + ndim - 1];
}
ARROW_ASSIGN_OR_RAISE(sparse_index,
SparseCSFIndex::Make(indptr_type, indices_type, indices_size,
axis_order, indptr_data, indices_data));
return MakeSparseTensorWithSparseCSFIndex(type, shape, dim_names, sparse_index,
payload.body_buffers[2 * ndim - 1]);
}
default:
return Status::Invalid("Unsupported sparse index format");
}
}
} // namespace internal
Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(const Buffer& metadata,
io::RandomAccessFile* file) {
std::shared_ptr<DataType> type;
std::vector<int64_t> shape;
std::vector<std::string> dim_names;
int64_t non_zero_length;
SparseTensorFormat::type sparse_tensor_format_id;
const flatbuf::SparseTensor* sparse_tensor;
const flatbuf::Buffer* buffer;
RETURN_NOT_OK(ReadSparseTensorMetadata(metadata, &type, &shape, &dim_names,
&non_zero_length, &sparse_tensor_format_id,
&sparse_tensor, &buffer));
ARROW_ASSIGN_OR_RAISE(auto data, file->ReadAt(buffer->offset(), buffer->length()));
std::shared_ptr<SparseIndex> sparse_index;
switch (sparse_tensor_format_id) {
case SparseTensorFormat::COO: {
ARROW_ASSIGN_OR_RAISE(
sparse_index, ReadSparseCOOIndex(sparse_tensor, shape, non_zero_length, file));
return MakeSparseTensorWithSparseCOOIndex(
type, shape, dim_names, checked_pointer_cast<SparseCOOIndex>(sparse_index),
non_zero_length, data);
}
case SparseTensorFormat::CSR: {
ARROW_ASSIGN_OR_RAISE(
sparse_index, ReadSparseCSXIndex(sparse_tensor, shape, non_zero_length, file));
return MakeSparseTensorWithSparseCSRIndex(
type, shape, dim_names, checked_pointer_cast<SparseCSRIndex>(sparse_index),
non_zero_length, data);
}
case SparseTensorFormat::CSC: {
ARROW_ASSIGN_OR_RAISE(
sparse_index, ReadSparseCSXIndex(sparse_tensor, shape, non_zero_length, file));
return MakeSparseTensorWithSparseCSCIndex(
type, shape, dim_names, checked_pointer_cast<SparseCSCIndex>(sparse_index),
non_zero_length, data);
}
case SparseTensorFormat::CSF: {
ARROW_ASSIGN_OR_RAISE(sparse_index, ReadSparseCSFIndex(sparse_tensor, shape, file));
return MakeSparseTensorWithSparseCSFIndex(
type, shape, dim_names, checked_pointer_cast<SparseCSFIndex>(sparse_index),
data);
}
default:
return Status::Invalid("Unsupported sparse index format");
}
}
Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(const Message& message) {
CHECK_HAS_BODY(message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
return ReadSparseTensor(*message.metadata(), reader.get());
}
Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(io::InputStream* file) {
std::unique_ptr<Message> message;
RETURN_NOT_OK(ReadContiguousPayload(file, &message));
CHECK_MESSAGE_TYPE(MessageType::SPARSE_TENSOR, message->type());
CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
return ReadSparseTensor(*message->metadata(), reader.get());
}
///////////////////////////////////////////////////////////////////////////
// Helpers for fuzzing
namespace internal {
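// These helpers decode untrusted input and fully validate everything that was
// decoded; they serve as entry points for fuzzing harnesses.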
Status FuzzIpcStream(const uint8_t* data, int64_t size) {
auto buffer = std::make_shared<Buffer>(data, size);
io::BufferReader buffer_reader(buffer);
std::shared_ptr<RecordBatchReader> batch_reader;
ARROW_ASSIGN_OR_RAISE(batch_reader, RecordBatchStreamReader::Open(&buffer_reader));
while (true) {
std::shared_ptr<arrow::RecordBatch> batch;
RETURN_NOT_OK(batch_reader->ReadNext(&batch));
if (batch == nullptr) {
break;
}
RETURN_NOT_OK(batch->ValidateFull());
}
return Status::OK();
}
Status FuzzIpcFile(const uint8_t* data, int64_t size) {
auto buffer = std::make_shared<Buffer>(data, size);
io::BufferReader buffer_reader(buffer);
std::shared_ptr<RecordBatchFileReader> batch_reader;
ARROW_ASSIGN_OR_RAISE(batch_reader, RecordBatchFileReader::Open(&buffer_reader));
const int n_batches = batch_reader->num_record_batches();
for (int i = 0; i < n_batches; ++i) {
ARROW_ASSIGN_OR_RAISE(auto batch, batch_reader->ReadRecordBatch(i));
RETURN_NOT_OK(batch->ValidateFull());
}
return Status::OK();
}
Status FuzzIpcTensorStream(const uint8_t* data, int64_t size) {
auto buffer = std::make_shared<Buffer>(data, size);
io::BufferReader buffer_reader(buffer);
std::shared_ptr<Tensor> tensor;
while (true) {
ARROW_ASSIGN_OR_RAISE(tensor, ReadTensor(&buffer_reader));
if (tensor == nullptr) {
break;
}
RETURN_NOT_OK(tensor->Validate());
}
return Status::OK();
}
} // namespace internal
} // namespace ipc
} // namespace arrow