blob: 9bbb7edf1c37be53c69e4c3cca97173d1cff2fbb [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using FlatBuffers;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Apache.Arrow.Ipc
{
/// <summary>
/// Base class for Arrow IPC readers. Holds the <see cref="Schema"/> read so far and converts
/// Flatbuffers IPC messages plus their body bytes into <see cref="RecordBatch"/> instances;
/// subclasses implement the actual reading of messages.
/// </summary>
internal abstract class ArrowReaderImplementation : IDisposable
{
    /// <summary>
    /// Schema of the stream; <c>null</c> until a subclass assigns it.
    /// </summary>
    public Schema Schema { get; protected set; }

    /// <summary>True once <see cref="Schema"/> holds a non-null value.</summary>
    protected bool HasReadSchema => Schema != null;

    /// <summary>
    /// Releases resources held by the reader by delegating to <see cref="Dispose(bool)"/>.
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Override to release resources owned by a concrete reader. The base implementation does nothing.
    /// </summary>
    /// <param name="disposing"><c>true</c> when invoked from <see cref="Dispose()"/>.</param>
    protected virtual void Dispose(bool disposing)
    {
    }

    /// <summary>Asynchronously reads the next record batch.</summary>
    /// <param name="cancellationToken">Token used to cancel the read.</param>
    public abstract ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken);

    /// <summary>Synchronously reads the next record batch.</summary>
    public abstract RecordBatch ReadNextRecordBatch();

    /// <summary>
    /// Parses <paramref name="bb"/> as a Flatbuffers <c>Message</c> and returns its header as
    /// <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">
    /// Expected header table type: <c>Flatbuf.RecordBatch</c>, <c>Flatbuf.DictionaryBatch</c>,
    /// <c>Flatbuf.Schema</c> or <c>Flatbuf.Tensor</c>.
    /// </typeparam>
    /// <param name="bb">Buffer containing the serialized message.</param>
    /// <returns>The message header interpreted as <typeparamref name="T"/>.</returns>
    /// <exception cref="InvalidDataException">The header type does not correspond to <typeparamref name="T"/>.</exception>
    /// <exception cref="ArgumentException">The header type is <c>NONE</c> or an unrecognized value.</exception>
    protected static T ReadMessage<T>(ByteBuffer bb)
        where T : struct, IFlatbufferObject
    {
        var returnType = typeof(T);
        var msg = Flatbuf.Message.GetRootAsMessage(bb);

        if (!MatchEnum(msg.HeaderType, returnType))
        {
            // A header of the wrong kind means the input is not what the caller expected;
            // InvalidDataException is more specific than the base Exception type (and still
            // derives from it, so existing catch (Exception) handlers keep working).
            throw new InvalidDataException($"Requested type '{returnType.Name}' " +
                $"did not match type found at offset => '{msg.HeaderType}'");
        }

        return msg.Header<T>().Value;
    }

    /// <summary>
    /// Returns whether <paramref name="flatBuffType"/> is the Flatbuffers table type that
    /// corresponds to <paramref name="messageHeader"/>.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="messageHeader"/> is <c>NONE</c> or not a recognized header value.
    /// </exception>
    private static bool MatchEnum(Flatbuf.MessageHeader messageHeader, Type flatBuffType)
    {
        switch (messageHeader)
        {
            case Flatbuf.MessageHeader.RecordBatch:
                return flatBuffType == typeof(Flatbuf.RecordBatch);
            case Flatbuf.MessageHeader.DictionaryBatch:
                return flatBuffType == typeof(Flatbuf.DictionaryBatch);
            case Flatbuf.MessageHeader.Schema:
                return flatBuffType == typeof(Flatbuf.Schema);
            case Flatbuf.MessageHeader.Tensor:
                return flatBuffType == typeof(Flatbuf.Tensor);
            case Flatbuf.MessageHeader.NONE:
                throw new ArgumentException("MessageHeader NONE has no matching flatbuf types", nameof(messageHeader));
            default:
                // Include the offending value so the failure is diagnosable.
                throw new ArgumentException($"Unexpected MessageHeader value '{messageHeader}'", nameof(messageHeader));
        }
    }

    /// <summary>
    /// Converts a decoded IPC message into a <see cref="RecordBatch"/> when it is a record batch
    /// message; other message kinds are skipped.
    /// </summary>
    /// <param name="message">The decoded message header.</param>
    /// <param name="bodyByteBuffer">The message body the header's buffers point into.</param>
    /// <param name="memoryOwner">Owner of the body memory; handed to the created batch.</param>
    /// <returns>
    /// A new <see cref="RecordBatch"/> for record batch messages; <c>null</c> for schema,
    /// dictionary batch and any other message type.
    /// </returns>
    protected RecordBatch CreateArrowObjectFromMessage(
        Flatbuf.Message message, ByteBuffer bodyByteBuffer, IMemoryOwner<byte> memoryOwner)
    {
        switch (message.HeaderType)
        {
            case Flatbuf.MessageHeader.Schema:
                // TODO: Read schema and verify equality?
                break;
            case Flatbuf.MessageHeader.DictionaryBatch:
                // TODO: not supported currently
                Debug.WriteLine("Dictionaries are not yet supported.");
                break;
            case Flatbuf.MessageHeader.RecordBatch:
            {
                // NOTE(review): relies on Schema having been assigned before the first
                // record batch arrives — confirm subclasses guarantee this.
                var rb = message.Header<Flatbuf.RecordBatch>().Value;
                List<IArrowArray> arrays = BuildArrays(Schema, bodyByteBuffer, rb);
                return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length);
            }
            default:
                // NOTE: Skip unsupported message type
                Debug.WriteLine($"Skipping unsupported message type '{message.HeaderType}'");
                break;
        }

        return null;
    }

    /// <summary>
    /// Wraps <paramref name="buffer"/> in a Flatbuffers <see cref="ByteBuffer"/> starting at position 0.
    /// </summary>
    internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory<byte> buffer)
    {
        return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer), 0);
    }

    /// <summary>
    /// Builds one array per field node of <paramref name="recordBatchMessage"/>, pairing node
    /// <c>n</c> with field <c>n</c> of <paramref name="schema"/> and consuming the record batch's
    /// buffer descriptors sequentially.
    /// </summary>
    // NOTE(review): this maps field nodes to schema fields one-to-one; presumably nested types
    // whose children contribute extra nodes/buffers are not supported here — confirm.
    private List<IArrowArray> BuildArrays(
        Schema schema,
        ByteBuffer messageBuffer,
        Flatbuf.RecordBatch recordBatchMessage)
    {
        var arrays = new List<IArrowArray>(recordBatchMessage.NodesLength);
        int bufferIndex = 0;

        for (var n = 0; n < recordBatchMessage.NodesLength; n++)
        {
            Field field = schema.GetFieldByIndex(n);
            Flatbuf.FieldNode fieldNode = recordBatchMessage.Nodes(n).GetValueOrDefault();

            ArrayData arrayData = field.DataType.IsFixedPrimitive() ?
                LoadPrimitiveField(field, fieldNode, recordBatchMessage, messageBuffer, ref bufferIndex) :
                LoadVariableField(field, fieldNode, recordBatchMessage, messageBuffer, ref bufferIndex);

            arrays.Add(ArrowArrayFactory.BuildArray(arrayData));
        }

        return arrays;
    }

    /// <summary>
    /// Loads a fixed-width primitive field, consuming two buffer descriptors
    /// (validity bitmap, values) starting at <paramref name="bufferIndex"/>.
    /// </summary>
    /// <param name="bufferIndex">Index of the next unread buffer descriptor; advanced by 2.</param>
    /// <exception cref="InvalidDataException">The node's length or null count is negative or too large.</exception>
    private ArrayData LoadPrimitiveField(
        Field field,
        Flatbuf.FieldNode fieldNode,
        Flatbuf.RecordBatch recordBatch,
        ByteBuffer bodyData,
        ref int bufferIndex)
    {
        var nullBitmapBuffer = recordBatch.Buffers(bufferIndex++).GetValueOrDefault();
        var valueBuffer = recordBatch.Buffers(bufferIndex++).GetValueOrDefault();

        ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, nullBitmapBuffer);
        ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, valueBuffer);

        return CreateArrayData(field, fieldNode, new[] { nullArrowBuffer, valueArrowBuffer });
    }

    /// <summary>
    /// Loads a non-fixed-width field, consuming three buffer descriptors
    /// (validity bitmap, offsets, values) starting at <paramref name="bufferIndex"/>.
    /// </summary>
    /// <param name="bufferIndex">Index of the next unread buffer descriptor; advanced by 3.</param>
    /// <exception cref="InvalidDataException">The node's length or null count is negative or too large.</exception>
    // NOTE(review): every non-fixed-primitive type is assumed to use exactly three buffers — confirm
    // against the data types IsFixedPrimitive() excludes.
    private ArrayData LoadVariableField(
        Field field,
        Flatbuf.FieldNode fieldNode,
        Flatbuf.RecordBatch recordBatch,
        ByteBuffer bodyData,
        ref int bufferIndex)
    {
        var nullBitmapBuffer = recordBatch.Buffers(bufferIndex++).GetValueOrDefault();
        var offsetBuffer = recordBatch.Buffers(bufferIndex++).GetValueOrDefault();
        var valueBuffer = recordBatch.Buffers(bufferIndex++).GetValueOrDefault();

        ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, nullBitmapBuffer);
        ArrowBuffer offsetArrowBuffer = BuildArrowBuffer(bodyData, offsetBuffer);
        ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, valueBuffer);

        return CreateArrayData(field, fieldNode, new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer });
    }

    /// <summary>
    /// Validates the node's length and null count and creates the <see cref="ArrayData"/>
    /// (offset 0) over <paramref name="buffers"/>. Shared by the primitive and variable loaders.
    /// </summary>
    /// <exception cref="InvalidDataException">Length or null count is negative or exceeds <see cref="int.MaxValue"/>.</exception>
    private static ArrayData CreateArrayData(Field field, Flatbuf.FieldNode fieldNode, ArrowBuffer[] buffers)
    {
        int fieldLength = ToNonNegativeInt32(fieldNode.Length, "Field length must be >= 0", "Field length");
        int fieldNullCount = ToNonNegativeInt32(fieldNode.NullCount, "Null count length must be >= 0", "Null count");

        return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, buffers);
    }

    /// <summary>
    /// Creates an <see cref="ArrowBuffer"/> over the slice of <paramref name="bodyData"/> described
    /// by <paramref name="buffer"/>; a descriptor with length &lt;= 0 yields <see cref="ArrowBuffer.Empty"/>.
    /// </summary>
    /// <exception cref="InvalidDataException">The descriptor's offset or length cannot be represented as a non-negative <see cref="int"/>.</exception>
    private ArrowBuffer BuildArrowBuffer(ByteBuffer bodyData, Flatbuf.Buffer buffer)
    {
        if (buffer.Length <= 0)
        {
            return ArrowBuffer.Empty;
        }

        // The descriptor stores 64-bit values; narrow them explicitly instead of letting an
        // unchecked cast silently wrap to a different slice.
        int offset = ToNonNegativeInt32(buffer.Offset, "Buffer offset must be >= 0", "Buffer offset");
        int length = ToNonNegativeInt32(buffer.Length, "Buffer length must be >= 0", "Buffer length");

        var data = bodyData.ToReadOnlyMemory(offset, length);
        return new ArrowBuffer(data);
    }

    /// <summary>
    /// Narrows a 64-bit count or offset read from the IPC metadata to <see cref="int"/>.
    /// Rejects negative values with <paramref name="negativeMessage"/> and values above
    /// <see cref="int.MaxValue"/> with a message built from <paramref name="description"/>.
    /// </summary>
    /// <exception cref="InvalidDataException">The value is negative or exceeds <see cref="int.MaxValue"/>.</exception>
    private static int ToNonNegativeInt32(long value, string negativeMessage, string description)
    {
        if (value < 0)
        {
            throw new InvalidDataException(negativeMessage); // TODO: Localize exception message
        }

        if (value > int.MaxValue)
        {
            throw new InvalidDataException(
                $"{description} {value} exceeds the maximum supported value {int.MaxValue}"); // TODO: Localize exception message
        }

        return (int)value;
    }
}
}