// blob: 3279f7030557a5f170022ce622f2c03aa4d3987a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using FlatBuffers;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Types;
namespace Apache.Arrow.Ipc
{
internal abstract class ArrowReaderImplementation : IDisposable
{
/// <summary>
/// Schema used by this reader when building record batches. Readable publicly;
/// the setter is protected, so only this class and its subclasses can assign it.
/// </summary>
public Schema Schema { get; protected set; }
/// <summary>True once <see cref="Schema"/> holds a non-null value.</summary>
protected bool HasReadSchema => Schema != null;
/// <summary>
/// Runs <see cref="Dispose(bool)"/> with <c>true</c> and then tells the GC that
/// this instance no longer needs finalization.
/// </summary>
public void Dispose()
{
    Dispose(disposing: true);
    GC.SuppressFinalize(this);
}

/// <summary>
/// Cleanup hook for subclasses. The base implementation is intentionally empty.
/// </summary>
/// <param name="disposing">
/// <c>true</c> when invoked from <see cref="Dispose()"/>.
/// </param>
protected virtual void Dispose(bool disposing)
{
}
/// <summary>
/// Asynchronous read of the next <see cref="RecordBatch"/>; the behavior is supplied by subclasses.
/// </summary>
/// <param name="cancellationToken">Cancellation token passed in by the caller.</param>
public abstract ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken);
/// <summary>
/// Synchronous read of the next <see cref="RecordBatch"/>; the behavior is supplied by subclasses.
/// </summary>
public abstract RecordBatch ReadNextRecordBatch();
/// <summary>
/// Treats <paramref name="bb"/> as a flatbuffer <c>Message</c> and returns its header
/// as <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Flatbuffer header struct the caller expects to find.</typeparam>
/// <param name="bb">Buffer positioned at the root of the message.</param>
/// <returns>The header value of the message.</returns>
/// <exception cref="Exception">
/// The header type recorded in the message does not correspond to <typeparamref name="T"/>.
/// </exception>
internal static T ReadMessage<T>(ByteBuffer bb)
    where T : struct, IFlatbufferObject
{
    Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(bb);
    Type requestedType = typeof(T);

    // Guard first: bail out when the stored header type is not what was asked for.
    if (!MatchEnum(message.HeaderType, requestedType))
    {
        throw new Exception($"Requested type '{requestedType.Name}' " +
            $"did not match type found at offset => '{message.HeaderType}'");
    }

    return message.Header<T>().Value;
}
/// <summary>
/// Checks whether <paramref name="flatBuffType"/> is the flatbuffer struct type that
/// corresponds to <paramref name="messageHeader"/>.
/// </summary>
/// <param name="messageHeader">Header discriminator read from a message.</param>
/// <param name="flatBuffType">CLR type the caller wants to read the header as.</param>
/// <returns><c>true</c> when the type matches the header kind; otherwise <c>false</c>.</returns>
/// <exception cref="ArgumentException">
/// <paramref name="messageHeader"/> is <c>NONE</c> or a value this method does not recognise.
/// </exception>
private static bool MatchEnum(Flatbuf.MessageHeader messageHeader, Type flatBuffType)
{
    // Resolve the single type that would be acceptable for this header kind ...
    Type expectedType;
    switch (messageHeader)
    {
        case Flatbuf.MessageHeader.RecordBatch:
            expectedType = typeof(Flatbuf.RecordBatch);
            break;
        case Flatbuf.MessageHeader.DictionaryBatch:
            expectedType = typeof(Flatbuf.DictionaryBatch);
            break;
        case Flatbuf.MessageHeader.Schema:
            expectedType = typeof(Flatbuf.Schema);
            break;
        case Flatbuf.MessageHeader.Tensor:
            expectedType = typeof(Flatbuf.Tensor);
            break;
        case Flatbuf.MessageHeader.NONE:
            throw new ArgumentException("MessageHeader NONE has no matching flatbuf types", nameof(messageHeader));
        default:
            throw new ArgumentException("Unexpected MessageHeader value", nameof(messageHeader));
    }

    // ... and compare once.
    return flatBuffType == expectedType;
}
/// <summary>
/// Converts a decoded flatbuffer message into a <see cref="RecordBatch"/> when the message
/// carries one.
/// </summary>
/// <param name="message">Message whose header type selects the handling.</param>
/// <param name="bodyByteBuffer">Message body that the array buffers are sliced from.</param>
/// <param name="memoryOwner">Owner handed to the resulting <see cref="RecordBatch"/>.</param>
/// <returns>
/// A record batch for <c>RecordBatch</c> messages; <c>null</c> for every other header type
/// (schema, dictionary batch, and unsupported kinds are skipped).
/// </returns>
/// <exception cref="InvalidDataException">
/// The record batch length stored in the message is negative or does not fit in an <see cref="int"/>.
/// </exception>
protected RecordBatch CreateArrowObjectFromMessage(
    Flatbuf.Message message, ByteBuffer bodyByteBuffer, IMemoryOwner<byte> memoryOwner)
{
    // NOTE(review): on every path that returns null, memoryOwner is not passed on to
    // anything here - confirm that callers dispose it themselves.
    switch (message.HeaderType)
    {
        case Flatbuf.MessageHeader.Schema:
            // TODO: Read schema and verify equality?
            break;
        case Flatbuf.MessageHeader.DictionaryBatch:
            // TODO: not supported currently
            Debug.WriteLine("Dictionaries are not yet supported.");
            break;
        case Flatbuf.MessageHeader.RecordBatch:
            Flatbuf.RecordBatch rb = message.Header<Flatbuf.RecordBatch>().Value;

            // The wire value is wider than int; a plain (int) cast would silently wrap
            // an out-of-range row count, so reject it before narrowing.
            long rowCount = rb.Length;
            if (rowCount < 0 || rowCount > int.MaxValue)
            {
                throw new InvalidDataException(
                    $"Record batch length must be between 0 and {int.MaxValue}");
            }

            List<IArrowArray> arrays = BuildArrays(Schema, bodyByteBuffer, rb);
            return new RecordBatch(Schema, memoryOwner, arrays, (int)rowCount);
        default:
            // NOTE: Skip unsupported message type
            Debug.WriteLine($"Skipping unsupported message type '{message.HeaderType}'");
            break;
    }

    return null;
}
/// <summary>
/// Wraps <paramref name="buffer"/> in a flatbuffer <c>ByteBuffer</c> positioned at 0,
/// backed by a <c>ReadOnlyMemoryBufferAllocator</c>.
/// </summary>
/// <param name="buffer">Memory to expose through the byte buffer.</param>
/// <returns>A new <c>ByteBuffer</c> over <paramref name="buffer"/>.</returns>
internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory<byte> buffer)
{
    var allocator = new ReadOnlyMemoryBufferAllocator(buffer);
    return new ByteBuffer(allocator, 0);
}
/// <summary>
/// Produces one array per top-level schema field by walking the field nodes of
/// <paramref name="recordBatchMessage"/>.
/// </summary>
/// <param name="schema">Schema whose fields are consumed in index order.</param>
/// <param name="messageBuffer">Message body the array buffers are sliced from.</param>
/// <param name="recordBatchMessage">Flatbuffer record batch describing nodes and buffers.</param>
/// <returns>The arrays that were built; empty when the record batch has no nodes.</returns>
private List<IArrowArray> BuildArrays(
    Schema schema,
    ByteBuffer messageBuffer,
    Flatbuf.RecordBatch recordBatchMessage)
{
    int nodeCount = recordBatchMessage.NodesLength;
    var result = new List<IArrowArray>(nodeCount);

    if (nodeCount != 0)
    {
        var cursor = new RecordBatchEnumerator(in recordBatchMessage);

        // The loop body always runs at least once; it stops when the node cursor
        // cannot advance any further. Child nodes are consumed inside the Load* calls.
        for (int fieldIndex = 0; ; fieldIndex++)
        {
            Field field = schema.GetFieldByIndex(fieldIndex);
            Flatbuf.FieldNode node = cursor.CurrentNode;

            ArrayData data;
            if (field.DataType.IsFixedPrimitive())
            {
                data = LoadPrimitiveField(ref cursor, field, in node, messageBuffer);
            }
            else
            {
                data = LoadVariableField(ref cursor, field, in node, messageBuffer);
            }

            result.Add(ArrowArrayFactory.BuildArray(data));

            if (!cursor.MoveNextNode())
            {
                break;
            }
        }
    }

    return result;
}
/// <summary>
/// Builds the <see cref="ArrayData"/> for a field that the caller classified via
/// <c>IsFixedPrimitive()</c>: one leading (null) buffer, then - unless the type is
/// <c>Struct</c> - one values buffer, followed by any child arrays.
/// </summary>
/// <param name="recordBatchEnumerator">Cursor over the record batch; advanced as buffers and child nodes are consumed.</param>
/// <param name="field">Schema field being loaded.</param>
/// <param name="fieldNode">Node carrying the length and null count for this field.</param>
/// <param name="bodyData">Message body the buffers are sliced from.</param>
/// <returns>Array data with offset 0.</returns>
/// <exception cref="Exception">The cursor cannot advance past the first buffer.</exception>
/// <exception cref="InvalidDataException">
/// The node's length or null count is negative or does not fit in an <see cref="int"/>.
/// </exception>
private ArrayData LoadPrimitiveField(
    ref RecordBatchEnumerator recordBatchEnumerator,
    Field field,
    in Flatbuf.FieldNode fieldNode,
    ByteBuffer bodyData)
{
    ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer);
    if (!recordBatchEnumerator.MoveNextBuffer())
    {
        throw new Exception("Unable to move to the next buffer.");
    }

    // Validate the wide wire values BEFORE narrowing them. Checking the sign only
    // after an unchecked (int) cast lets out-of-range values wrap around to an
    // innocent-looking non-negative number.
    long rawLength = fieldNode.Length;
    long rawNullCount = fieldNode.NullCount;

    if (rawLength < 0)
    {
        throw new InvalidDataException("Field length must be >= 0"); // TODO:Localize exception message
    }

    if (rawNullCount < 0)
    {
        throw new InvalidDataException("Null count length must be >= 0"); // TODO:Localize exception message
    }

    if (rawLength > int.MaxValue)
    {
        throw new InvalidDataException($"Field length must be <= {int.MaxValue}"); // TODO:Localize exception message
    }

    if (rawNullCount > int.MaxValue)
    {
        throw new InvalidDataException($"Null count length must be <= {int.MaxValue}"); // TODO:Localize exception message
    }

    int fieldLength = (int)rawLength;
    int fieldNullCount = (int)rawNullCount;

    ArrowBuffer[] arrowBuff;
    if (field.DataType.TypeId == ArrowTypeId.Struct)
    {
        // Struct fields contribute only the leading buffer; their data lives in the children.
        arrowBuff = new[] { nullArrowBuffer };
    }
    else
    {
        ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer);
        // Result deliberately ignored: this may legitimately be the last buffer of the batch.
        recordBatchEnumerator.MoveNextBuffer();
        arrowBuff = new[] { nullArrowBuffer, valueArrowBuffer };
    }

    ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData);
    return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children);
}
/// <summary>
/// Builds the <see cref="ArrayData"/> for a field that the caller did not classify as
/// fixed primitive: three consecutive buffers (null, offsets, values) followed by any
/// child arrays.
/// </summary>
/// <param name="recordBatchEnumerator">Cursor over the record batch; advanced as buffers and child nodes are consumed.</param>
/// <param name="field">Schema field being loaded.</param>
/// <param name="fieldNode">Node carrying the length and null count for this field.</param>
/// <param name="bodyData">Message body the buffers are sliced from.</param>
/// <returns>Array data with offset 0.</returns>
/// <exception cref="Exception">The cursor cannot advance to the offsets or values buffer.</exception>
/// <exception cref="InvalidDataException">
/// The node's length or null count is negative or does not fit in an <see cref="int"/>.
/// </exception>
private ArrayData LoadVariableField(
    ref RecordBatchEnumerator recordBatchEnumerator,
    Field field,
    in Flatbuf.FieldNode fieldNode,
    ByteBuffer bodyData)
{
    ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer);
    if (!recordBatchEnumerator.MoveNextBuffer())
    {
        throw new Exception("Unable to move to the next buffer.");
    }

    ArrowBuffer offsetArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer);
    if (!recordBatchEnumerator.MoveNextBuffer())
    {
        throw new Exception("Unable to move to the next buffer.");
    }

    ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer);
    // Result deliberately ignored: this may legitimately be the last buffer of the batch.
    recordBatchEnumerator.MoveNextBuffer();

    // Validate the wide wire values BEFORE narrowing them. Checking the sign only
    // after an unchecked (int) cast lets out-of-range values wrap around to an
    // innocent-looking non-negative number.
    long rawLength = fieldNode.Length;
    long rawNullCount = fieldNode.NullCount;

    if (rawLength < 0)
    {
        throw new InvalidDataException("Field length must be >= 0"); // TODO: Localize exception message
    }

    if (rawNullCount < 0)
    {
        throw new InvalidDataException("Null count length must be >= 0"); //TODO: Localize exception message
    }

    if (rawLength > int.MaxValue)
    {
        throw new InvalidDataException($"Field length must be <= {int.MaxValue}"); // TODO: Localize exception message
    }

    if (rawNullCount > int.MaxValue)
    {
        throw new InvalidDataException($"Null count length must be <= {int.MaxValue}"); // TODO: Localize exception message
    }

    int fieldLength = (int)rawLength;
    int fieldNullCount = (int)rawNullCount;

    ArrowBuffer[] arrowBuff = new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer };
    ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData);
    return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children);
}
/// <summary>
/// Loads the child arrays of <paramref name="field"/> when its data type is a
/// <c>NestedType</c>, consuming one field node per child.
/// </summary>
/// <param name="recordBatchEnumerator">Cursor over the record batch; advanced to each child's node before it is loaded.</param>
/// <param name="field">Parent field whose children should be loaded.</param>
/// <param name="bodyData">Message body the buffers are sliced from.</param>
/// <returns>
/// One <see cref="ArrayData"/> per child field, in declaration order; <c>null</c> when
/// the field's type is not nested.
/// </returns>
private ArrayData[] GetChildren(
    ref RecordBatchEnumerator recordBatchEnumerator,
    Field field,
    ByteBuffer bodyData)
{
    if (field.DataType is NestedType nested)
    {
        int count = nested.Fields.Count;
        var result = new ArrayData[count];

        int slot = 0;
        while (slot < count)
        {
            // Step onto this child's node; the parent's node was the current one before.
            recordBatchEnumerator.MoveNextNode();
            Flatbuf.FieldNode node = recordBatchEnumerator.CurrentNode;
            Field childField = nested.Fields[slot];

            if (childField.DataType.IsFixedPrimitive())
            {
                result[slot] = LoadPrimitiveField(ref recordBatchEnumerator, childField, in node, bodyData);
            }
            else
            {
                result[slot] = LoadVariableField(ref recordBatchEnumerator, childField, in node, bodyData);
            }

            slot++;
        }

        return result;
    }

    return null;
}
/// <summary>
/// Slices the region described by <paramref name="buffer"/> out of
/// <paramref name="bodyData"/> and wraps it in an <see cref="ArrowBuffer"/>.
/// </summary>
/// <param name="bodyData">Message body to slice.</param>
/// <param name="buffer">Offset/length descriptor taken from the record batch.</param>
/// <returns>
/// <see cref="ArrowBuffer.Empty"/> when the descriptor's length is zero or negative;
/// otherwise a buffer over the described region.
/// </returns>
/// <exception cref="InvalidDataException">
/// The descriptor's offset is negative, or its offset or length does not fit in an <see cref="int"/>.
/// </exception>
private ArrowBuffer BuildArrowBuffer(ByteBuffer bodyData, Flatbuf.Buffer buffer)
{
    long rawLength = buffer.Length;
    if (rawLength <= 0)
    {
        return ArrowBuffer.Empty;
    }

    // Range-check before narrowing: an unchecked (int) cast of an out-of-range value
    // would wrap and make the slice below address the wrong region.
    long rawOffset = buffer.Offset;
    if (rawOffset < 0 || rawOffset > int.MaxValue)
    {
        throw new InvalidDataException($"Buffer offset must be between 0 and {int.MaxValue}");
    }

    if (rawLength > int.MaxValue)
    {
        throw new InvalidDataException($"Buffer length must be <= {int.MaxValue}");
    }

    var data = bodyData.ToReadOnlyMemory((int)rawOffset, (int)rawLength);
    return new ArrowBuffer(data);
}
}
/// <summary>
/// Mutable cursor over a flatbuffer record batch that tracks two independent positions:
/// the current buffer index and the current field-node index. Both start at 0.
/// </summary>
internal struct RecordBatchEnumerator
{
    /// <summary>Creates a cursor positioned on the first buffer and the first node.</summary>
    /// <param name="recordBatch">Record batch to walk.</param>
    internal RecordBatchEnumerator(in Flatbuf.RecordBatch recordBatch)
    {
        RecordBatch = recordBatch;
        CurrentBufferIndex = 0;
        CurrentNodeIndex = 0;
    }

    private Flatbuf.RecordBatch RecordBatch { get; }

    /// <summary>Index of the buffer the cursor is on.</summary>
    internal int CurrentBufferIndex { get; private set; }

    /// <summary>Index of the field node the cursor is on.</summary>
    internal int CurrentNodeIndex { get; private set; }

    /// <summary>Buffer descriptor at <see cref="CurrentBufferIndex"/>, or the default value when the lookup yields none.</summary>
    internal Flatbuf.Buffer CurrentBuffer
    {
        get { return RecordBatch.Buffers(CurrentBufferIndex).GetValueOrDefault(); }
    }

    /// <summary>Field node at <see cref="CurrentNodeIndex"/>, or the default value when the lookup yields none.</summary>
    internal Flatbuf.FieldNode CurrentNode
    {
        get { return RecordBatch.Nodes(CurrentNodeIndex).GetValueOrDefault(); }
    }

    /// <summary>Advances the buffer index by one (always).</summary>
    /// <returns><c>true</c> while the new index is still below the buffer count.</returns>
    internal bool MoveNextBuffer()
    {
        CurrentBufferIndex += 1;
        return CurrentBufferIndex < RecordBatch.BuffersLength;
    }

    /// <summary>Advances the node index by one (always).</summary>
    /// <returns><c>true</c> while the new index is still below the node count.</returns>
    internal bool MoveNextNode()
    {
        CurrentNodeIndex += 1;
        return CurrentNodeIndex < RecordBatch.NodesLength;
    }
}
}