// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flatbuf;
using Apache.Arrow.Types;
using Apache.Arrow.Memory;
using Google.FlatBuffers;
using Type = System.Type;
namespace Apache.Arrow.Ipc
{
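/// <summary>
/// Shared base for the Arrow IPC reader implementations. Derived classes supply the message
/// framing (stream or file format); this class turns Flatbuffer messages into the schema,
/// dictionaries, and record batches exposed to callers.
/// </summary>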
internal abstract class ArrowReaderImplementation : IDisposable
{
public Schema Schema
{
get
{
if (!HasReadSchema)
{
ReadSchema();
}
return _schema;
}
}
protected internal bool HasReadSchema => _schema != null;
private protected DictionaryMemo _dictionaryMemo;
private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo();
private protected readonly MemoryAllocator _allocator;
private readonly ICompressionCodecFactory _compressionCodecFactory;
private protected Schema _schema;
private protected ArrowReaderImplementation() : this(null, null)
{ }
private protected ArrowReaderImplementation(MemoryAllocator allocator, ICompressionCodecFactory compressionCodecFactory)
{
_allocator = allocator ?? MemoryAllocator.Default.Value;
_compressionCodecFactory = compressionCodecFactory;
}
public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}
protected virtual void Dispose(bool disposing)
{
}
public abstract ValueTask ReadSchemaAsync(CancellationToken cancellationToken);
public abstract void ReadSchema();
public abstract ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken);
public abstract RecordBatch ReadNextRecordBatch();
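/// <summary>
/// Reads the root Flatbuffer message from <paramref name="bb"/> and returns its header as
/// <typeparamref name="T"/>, throwing if the header type does not match the requested type.
/// </summary>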
internal static T ReadMessage<T>(ByteBuffer bb)
where T : struct, IFlatbufferObject
{
Type returnType = typeof(T);
Flatbuf.Message msg = Flatbuf.Message.GetRootAsMessage(bb);
if (MatchEnum(msg.HeaderType, returnType))
{
return msg.Header<T>().Value;
}
else
{
throw new Exception($"Requested type '{returnType.Name}' " +
$"did not match type found at offset => '{msg.HeaderType}'");
}
}
private static bool MatchEnum(Flatbuf.MessageHeader messageHeader, Type flatBuffType)
{
switch (messageHeader)
{
case Flatbuf.MessageHeader.RecordBatch:
return flatBuffType == typeof(Flatbuf.RecordBatch);
case Flatbuf.MessageHeader.DictionaryBatch:
return flatBuffType == typeof(Flatbuf.DictionaryBatch);
case Flatbuf.MessageHeader.Schema:
return flatBuffType == typeof(Flatbuf.Schema);
case Flatbuf.MessageHeader.Tensor:
return flatBuffType == typeof(Flatbuf.Tensor);
case Flatbuf.MessageHeader.NONE:
throw new ArgumentException("MessageHeader NONE has no matching flatbuf types", nameof(messageHeader));
default:
throw new ArgumentException($"Unexpected MessageHeader value", nameof(messageHeader));
}
}
/// <summary>
/// Create a record batch or dictionary batch from Flatbuf.Message.
/// </summary>
/// <remarks>
/// This method adds data to _dictionaryMemo and returns null when the message type is DictionaryBatch.
/// </remarks>
/// <returns>
/// The record batch when the message type is RecordBatch.
/// Null when the message type is not RecordBatch.
/// </returns>
protected RecordBatch CreateArrowObjectFromMessage(
Flatbuf.Message message, ByteBuffer bodyByteBuffer, IMemoryOwner<byte> memoryOwner)
{
switch (message.HeaderType)
{
case Flatbuf.MessageHeader.Schema:
// TODO: Read schema and verify equality?
break;
case Flatbuf.MessageHeader.DictionaryBatch:
Flatbuf.DictionaryBatch dictionaryBatch = message.Header<Flatbuf.DictionaryBatch>().Value;
ReadDictionaryBatch(message.Version, dictionaryBatch, bodyByteBuffer, memoryOwner);
break;
case Flatbuf.MessageHeader.RecordBatch:
Flatbuf.RecordBatch rb = message.Header<Flatbuf.RecordBatch>().Value;
List<IArrowArray> arrays = BuildArrays(message.Version, Schema, bodyByteBuffer, rb);
return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length);
default:
// NOTE: Skip unsupported message type
Debug.WriteLine($"Skipping unsupported message type '{message.HeaderType}'");
break;
}
return null;
}
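/// <summary>
/// Wraps a <see cref="ReadOnlyMemory{Byte}"/> in a FlatBuffers <see cref="ByteBuffer"/>
/// backed by a <see cref="ReadOnlyMemoryBufferAllocator"/>.
/// </summary>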
internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory<byte> buffer)
{
return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer), 0);
}
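/// <summary>
/// Reads a dictionary batch and registers its single array in <see cref="DictionaryMemo"/>,
/// either replacing the existing dictionary for the id or appending to it when the batch is a delta.
/// </summary>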
private void ReadDictionaryBatch(
MetadataVersion version,
Flatbuf.DictionaryBatch dictionaryBatch,
ByteBuffer bodyByteBuffer,
IMemoryOwner<byte> memoryOwner)
{
long id = dictionaryBatch.Id;
IArrowType valueType = DictionaryMemo.GetDictionaryType(id);
Flatbuf.RecordBatch? recordBatch = dictionaryBatch.Data;
if (!recordBatch.HasValue)
{
throw new InvalidDataException("Dictionary must contain RecordBatch");
}
Field valueField = new Field("dummy", valueType, true);
var schema = new Schema(new[] { valueField }, default);
IList<IArrowArray> arrays = BuildArrays(version, schema, bodyByteBuffer, recordBatch.Value);
if (arrays.Count != 1)
{
throw new InvalidDataException("Dictionary record batch must contain only one field");
}
if (dictionaryBatch.IsDelta)
{
DictionaryMemo.AddDeltaDictionary(id, arrays[0], _allocator);
}
else
{
DictionaryMemo.AddOrReplaceDictionary(id, arrays[0]);
}
}
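/// <summary>
/// Builds one array per top-level field of <paramref name="schema"/> from the field nodes
/// and buffers described by the record batch message.
/// </summary>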
private List<IArrowArray> BuildArrays(
MetadataVersion version,
Schema schema,
ByteBuffer messageBuffer,
Flatbuf.RecordBatch recordBatchMessage)
{
var arrays = new List<IArrowArray>(recordBatchMessage.NodesLength);
if (recordBatchMessage.NodesLength == 0)
{
return arrays;
}
using var bufferCreator = GetBufferCreator(recordBatchMessage.Compression);
var recordBatchEnumerator = new RecordBatchEnumerator(in recordBatchMessage);
int schemaFieldIndex = 0;
do
{
Field field = schema.GetFieldByIndex(schemaFieldIndex++);
Flatbuf.FieldNode fieldNode = recordBatchEnumerator.CurrentNode;
ArrayData arrayData = LoadField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator);
arrays.Add(ArrowArrayFactory.BuildArray(arrayData));
} while (recordBatchEnumerator.MoveNextNode());
return arrays;
}
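/// <summary>
/// Returns an <see cref="IBufferCreator"/> for the record batch body: a pass-through creator
/// when the body is uncompressed, or a decompressing creator when buffer-level compression is declared.
/// </summary>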
private IBufferCreator GetBufferCreator(BodyCompression? compression)
{
if (!compression.HasValue)
{
return NoOpBufferCreator.Instance;
}
var method = compression.Value.Method;
if (method != BodyCompressionMethod.BUFFER)
{
throw new NotImplementedException($"Compression method {method} is not supported");
}
var codec = compression.Value.Codec;
if (_compressionCodecFactory == null)
{
throw new Exception(
$"Body is compressed with codec {codec} but no {nameof(ICompressionCodecFactory)} has been configured to decompress buffers");
}
var decompressor = codec switch
{
Apache.Arrow.Flatbuf.CompressionType.LZ4_FRAME => _compressionCodecFactory.CreateCodec(CompressionCodecType.Lz4Frame),
Apache.Arrow.Flatbuf.CompressionType.ZSTD => _compressionCodecFactory.CreateCodec(CompressionCodecType.Zstd),
_ => throw new NotImplementedException($"Compression codec {codec} is not supported")
};
return new DecompressingBufferCreator(decompressor, _allocator);
}
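/// <summary>
/// Loads the <see cref="ArrayData"/> for a single field, consuming the number of buffers
/// dictated by the field's type and recursively loading any child fields.
/// </summary>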
private ArrayData LoadField(
MetadataVersion version,
ref RecordBatchEnumerator recordBatchEnumerator,
Field field,
in Flatbuf.FieldNode fieldNode,
ByteBuffer bodyData,
IBufferCreator bufferCreator)
{
int fieldLength = (int)fieldNode.Length;
int fieldNullCount = (int)fieldNode.NullCount;
if (fieldLength < 0)
{
throw new InvalidDataException("Field length must be >= 0"); // TODO:Localize exception message
}
if (fieldNullCount < 0)
{
throw new InvalidDataException("Null count must be >= 0"); // TODO:Localize exception message
}
int buffers;
switch (field.DataType.TypeId)
{
case ArrowTypeId.Null:
return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, System.Array.Empty<ArrowBuffer>());
case ArrowTypeId.Union:
if (version < MetadataVersion.V5)
{
if (fieldNullCount > 0)
{
if (recordBatchEnumerator.CurrentBuffer.Length > 0)
{
// With older metadata we can get a validity bitmap. Fixing up union data is hard,
// so we will just quit.
throw new NotSupportedException("Cannot read pre-1.0.0 Union array with top-level validity bitmap");
}
}
recordBatchEnumerator.MoveNextBuffer();
}
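// Dense unions have type id and offset buffers; sparse unions have only a type id buffer.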
buffers = ((UnionType)field.DataType).Mode == Types.UnionMode.Dense ? 2 : 1;
break;
case ArrowTypeId.Struct:
case ArrowTypeId.FixedSizeList:
buffers = 1;
break;
case ArrowTypeId.String:
case ArrowTypeId.Binary:
case ArrowTypeId.LargeString:
case ArrowTypeId.LargeBinary:
case ArrowTypeId.ListView:
buffers = 3;
break;
case ArrowTypeId.StringView:
case ArrowTypeId.BinaryView:
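// View arrays have a validity buffer and a views buffer plus a variable number of data buffers.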
buffers = checked((int)(2 + recordBatchEnumerator.CurrentVariadicCount));
recordBatchEnumerator.MoveNextVariadicCount();
break;
default:
buffers = 2;
break;
}
ArrowBuffer[] arrowBuff = new ArrowBuffer[buffers];
for (int i = 0; i < buffers; i++)
{
arrowBuff[i] = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator);
recordBatchEnumerator.MoveNextBuffer();
}
ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator);
IArrowArray dictionary = null;
if (field.DataType.TypeId == ArrowTypeId.Dictionary)
{
long id = DictionaryMemo.GetId(field);
dictionary = DictionaryMemo.GetDictionary(id);
}
return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data);
}
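/// <summary>
/// Loads the child arrays of a nested field (struct, list, union, etc.), or returns null
/// for non-nested types.
/// </summary>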
private ArrayData[] GetChildren(
MetadataVersion version,
ref RecordBatchEnumerator recordBatchEnumerator,
Field field,
ByteBuffer bodyData,
IBufferCreator bufferCreator)
{
if (!(field.DataType is NestedType type)) return null;
int childrenCount = type.Fields.Count;
var children = new ArrayData[childrenCount];
for (int index = 0; index < childrenCount; index++)
{
recordBatchEnumerator.MoveNextNode();
Flatbuf.FieldNode childFieldNode = recordBatchEnumerator.CurrentNode;
Field childField = type.Fields[index];
children[index] = LoadField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator);
}
return children;
}
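/// <summary>
/// Slices the message body at the buffer's offset and length and materializes it through
/// the supplied <see cref="IBufferCreator"/>, returning an empty buffer for zero-length buffers.
/// </summary>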
private ArrowBuffer BuildArrowBuffer(ByteBuffer bodyData, Flatbuf.Buffer buffer, IBufferCreator bufferCreator)
{
if (buffer.Length <= 0)
{
return ArrowBuffer.Empty;
}
int offset = (int)buffer.Offset;
int length = (int)buffer.Length;
var data = bodyData.ToReadOnlyMemory(offset, length);
return bufferCreator.CreateBuffer(data);
}
}
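/// <summary>
/// Cursor over the flat lists of field nodes, buffers, and variadic buffer counts in a
/// record batch message, advanced as fields are loaded in schema order.
/// </summary>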
internal struct RecordBatchEnumerator
{
private Flatbuf.RecordBatch RecordBatch { get; }
internal int CurrentBufferIndex { get; private set; }
internal int CurrentNodeIndex { get; private set; }
internal int CurrentVariadicCountIndex { get; private set; }
internal Flatbuf.Buffer CurrentBuffer => RecordBatch.Buffers(CurrentBufferIndex).GetValueOrDefault();
internal Flatbuf.FieldNode CurrentNode => RecordBatch.Nodes(CurrentNodeIndex).GetValueOrDefault();
internal long CurrentVariadicCount => RecordBatch.VariadicBufferCounts(CurrentVariadicCountIndex);
internal bool MoveNextBuffer()
{
return ++CurrentBufferIndex < RecordBatch.BuffersLength;
}
internal bool MoveNextNode()
{
return ++CurrentNodeIndex < RecordBatch.NodesLength;
}
internal bool MoveNextVariadicCount()
{
return ++CurrentVariadicCountIndex < RecordBatch.VariadicBufferCountsLength;
}
internal RecordBatchEnumerator(in Flatbuf.RecordBatch recordBatch)
{
RecordBatch = recordBatch;
CurrentBufferIndex = 0;
CurrentNodeIndex = 0;
CurrentVariadicCountIndex = 0;
}
}
}