blob: b983be0fd0da38a2201dd1eb570204ba70514acb [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using System;
using System.Buffers;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
namespace Apache.Arrow
{
// Helpers to read and write Memory<byte> from/to a Stream on netstandard
internal static partial class StreamExtensions
{
    /// <summary>
    /// Reads bytes from <paramref name="stream"/> into <paramref name="buffer"/>.
    /// </summary>
    /// <remarks>
    /// If the memory is backed by an array, the stream reads directly into that array segment.
    /// Otherwise the bytes are read into an array rented from <see cref="ArrayPool{T}.Shared"/>
    /// and then copied into <paramref name="buffer"/>; the rented array is always returned.
    /// </remarks>
    /// <param name="stream">The stream to read from.</param>
    /// <param name="buffer">The destination memory; at most <c>buffer.Length</c> bytes are requested.</param>
    /// <returns>The byte count reported by the underlying <c>Stream.Read</c> call.</returns>
    public static int Read(this Stream stream, Memory<byte> buffer)
    {
        // The out type is spelled out so that T can be inferred for TryGetArray<T>.
        if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> array))
        {
            return stream.Read(array.Array, array.Offset, array.Count);
        }

        // Rent may hand back a larger array than requested, so buffer.Length (not the
        // array length) bounds the read.
        byte[] sharedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
        try
        {
            int result = stream.Read(sharedBuffer, 0, buffer.Length);
            new Span<byte>(sharedBuffer, 0, result).CopyTo(buffer.Span);
            return result;
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(sharedBuffer);
        }
    }

    /// <summary>
    /// Asynchronously reads bytes from <paramref name="stream"/> into <paramref name="buffer"/>.
    /// </summary>
    /// <remarks>
    /// If the memory is backed by an array, the stream reads directly into that array segment.
    /// Otherwise the bytes are read into an array rented from <see cref="ArrayPool{T}.Shared"/>
    /// and copied into <paramref name="buffer"/> once the read completes. The rented array is
    /// returned to the pool whether the read succeeds, faults, or throws before producing a task.
    /// </remarks>
    /// <param name="stream">The stream to read from.</param>
    /// <param name="buffer">The destination memory; at most <c>buffer.Length</c> bytes are requested.</param>
    /// <param name="cancellationToken">Token forwarded to the underlying <c>Stream.ReadAsync</c> call.</param>
    /// <returns>A task-like value producing the byte count reported by the underlying read.</returns>
    public static ValueTask<int> ReadAsync(this Stream stream, Memory<byte> buffer, CancellationToken cancellationToken = default)
    {
        if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> array))
        {
            return new ValueTask<int>(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken));
        }

        byte[] sharedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
        Task<int> readTask;
        try
        {
            readTask = stream.ReadAsync(sharedBuffer, 0, buffer.Length, cancellationToken);
        }
        catch
        {
            // The stream threw before handing back a task, so FinishReadAsync (and its
            // finally) never runs; give the rented array back here and rethrow unchanged.
            ArrayPool<byte>.Shared.Return(sharedBuffer);
            throw;
        }

        return FinishReadAsync(readTask, sharedBuffer, buffer);

        // Awaits the pending read, copies the bytes read into the caller's memory,
        // and returns the rented array regardless of the outcome.
        async ValueTask<int> FinishReadAsync(Task<int> pendingRead, byte[] localBuffer, Memory<byte> localDestination)
        {
            try
            {
                int result = await pendingRead.ConfigureAwait(false);
                new Span<byte>(localBuffer, 0, result).CopyTo(localDestination.Span);
                return result;
            }
            finally
            {
                ArrayPool<byte>.Shared.Return(localBuffer);
            }
        }
    }

    /// <summary>
    /// Writes the contents of <paramref name="buffer"/> to <paramref name="stream"/>.
    /// </summary>
    /// <remarks>
    /// If the memory is backed by an array, that array segment is written directly. Otherwise
    /// the bytes are first copied into an array rented from <see cref="ArrayPool{T}.Shared"/>,
    /// which is returned after the write.
    /// </remarks>
    /// <param name="stream">The stream to write to.</param>
    /// <param name="buffer">The bytes to write.</param>
    public static void Write(this Stream stream, ReadOnlyMemory<byte> buffer)
    {
        if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> array))
        {
            stream.Write(array.Array, array.Offset, array.Count);
            return;
        }

        byte[] sharedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
        try
        {
            buffer.Span.CopyTo(sharedBuffer);
            stream.Write(sharedBuffer, 0, buffer.Length);
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(sharedBuffer);
        }
    }

    /// <summary>
    /// Asynchronously writes the contents of <paramref name="buffer"/> to <paramref name="stream"/>.
    /// </summary>
    /// <remarks>
    /// If the memory is backed by an array, that array segment is written directly. Otherwise
    /// the bytes are first copied into an array rented from <see cref="ArrayPool{T}.Shared"/>.
    /// The rented array is returned to the pool whether the write succeeds, faults, or throws
    /// before producing a task.
    /// </remarks>
    /// <param name="stream">The stream to write to.</param>
    /// <param name="buffer">The bytes to write.</param>
    /// <param name="cancellationToken">Token forwarded to the underlying <c>Stream.WriteAsync</c> call.</param>
    /// <returns>A task-like value that completes when the underlying write task completes.</returns>
    public static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
    {
        if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> array))
        {
            return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken));
        }

        byte[] sharedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
        Task writeTask;
        try
        {
            buffer.Span.CopyTo(sharedBuffer);
            writeTask = stream.WriteAsync(sharedBuffer, 0, buffer.Length, cancellationToken);
        }
        catch
        {
            // No task was produced, so FinishWriteAsync will not run; return the
            // rented array here and let the original exception propagate.
            ArrayPool<byte>.Shared.Return(sharedBuffer);
            throw;
        }

        return FinishWriteAsync(writeTask, sharedBuffer);
    }

    /// <summary>
    /// Awaits <paramref name="writeTask"/> and then returns <paramref name="localBuffer"/> to
    /// <see cref="ArrayPool{T}.Shared"/>, even if the write faults or is cancelled.
    /// </summary>
    /// <param name="writeTask">The in-flight write that is using <paramref name="localBuffer"/>.</param>
    /// <param name="localBuffer">The rented array to give back once the write has finished.</param>
    private static async ValueTask FinishWriteAsync(Task writeTask, byte[] localBuffer)
    {
        try
        {
            await writeTask.ConfigureAwait(false);
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(localBuffer);
        }
    }
}
}