blob: b0d07ba54757e36eedbb30b8fc054b74a8c387e0 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Server;
using Google.Protobuf;
using Grpc.Core;
using Grpc.Core.Utils;
namespace Apache.Arrow.Flight.TestWeb
{
    /// <summary>
    /// In-memory <see cref="FlightServer"/> used by the Flight tests. Every RPC reads from or
    /// writes to the <c>Flights</c> map of the supplied <see cref="FlightStore"/>.
    /// </summary>
    public class TestFlightServer : FlightServer
    {
        // Shared by every lookup that fails, so all RPCs report a missing flight identically.
        private const string FlightNotFoundMessage = "Flight not found";

        private readonly FlightStore _flightStore;

        /// <summary>
        /// Creates a server backed by <paramref name="flightStore"/>.
        /// </summary>
        /// <param name="flightStore">Store whose <c>Flights</c> map holds the served flights.</param>
        public TestFlightServer(FlightStore flightStore)
        {
            _flightStore = flightStore;
        }

        /// <summary>
        /// Executes a named action. Only the <c>"test"</c> action is implemented; it writes a
        /// single <c>"test data"</c> result.
        /// </summary>
        /// <exception cref="RpcException">
        /// Status <see cref="StatusCode.Unimplemented"/> for any other action type.
        /// </exception>
        public override async Task DoAction(FlightAction request, IAsyncStreamWriter<FlightResult> responseStream, ServerCallContext context)
        {
            switch (request.Type)
            {
                case "test":
                    await responseStream.WriteAsync(new FlightResult("test data"));
                    break;
                default:
                    // A non-RpcException escaping a handler reaches the client as StatusCode.Unknown;
                    // report Unimplemented explicitly so clients can tell what went wrong.
                    throw new RpcException(new Status(StatusCode.Unimplemented, $"Action type '{request.Type}' is not implemented"));
            }
        }

        /// <summary>
        /// Streams the schema and all stored record batches (with their metadata) of the flight
        /// identified by <paramref name="ticket"/>.
        /// </summary>
        /// <exception cref="RpcException">
        /// Status <see cref="StatusCode.NotFound"/> when no flight matches the ticket, consistent
        /// with <see cref="GetFlightInfo"/> and <see cref="GetSchema"/>.
        /// </exception>
        public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context)
        {
            // The ticket bytes are interpreted as a UTF-8 path descriptor.
            // NOTE(review): presumably this mirrors how FlightHolder builds its tickets — confirm.
            var flightDescriptor = FlightDescriptor.CreatePathDescriptor(ticket.Ticket.ToStringUtf8());
            var flightHolder = GetFlightHolderOrThrow(flightDescriptor);

            await responseStream.SetupStream(flightHolder.GetFlightInfo().Schema);
            foreach (var batch in flightHolder.GetRecordBatches())
            {
                await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata);
            }
        }

        /// <summary>
        /// Receives record batches for a flight and appends them to the store.
        /// </summary>
        /// <remarks>
        /// The flight is created on first use. Each received batch is acknowledged with a
        /// <see cref="FlightPutResult"/> that echoes the batch's first application-metadata entry,
        /// or <see cref="FlightPutResult.Empty"/> when there is none.
        /// </remarks>
        public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter<FlightPutResult> responseStream, ServerCallContext context)
        {
            var flightDescriptor = await requestStream.FlightDescriptor;
            if (!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder))
            {
                // NOTE(review): the third argument looks like the location advertised for this
                // flight (this server's host over grpc+tcp) — confirm against FlightHolder.
                flightHolder = new FlightHolder(flightDescriptor, await requestStream.Schema, $"grpc+tcp://{context.Host}");
                _flightStore.Flights.Add(flightDescriptor, flightHolder);
            }

            while (await requestStream.MoveNext())
            {
                var applicationMetadata = requestStream.ApplicationMetadata.FirstOrDefault();
                flightHolder.AddBatch(new RecordBatchWithMetadata(requestStream.Current, applicationMetadata));
                await responseStream.WriteAsync(
                    applicationMetadata == null ? FlightPutResult.Empty : new FlightPutResult(applicationMetadata));
            }
        }

        /// <summary>
        /// Returns the <see cref="FlightInfo"/> of the flight identified by <paramref name="request"/>.
        /// </summary>
        /// <exception cref="RpcException">Status <see cref="StatusCode.NotFound"/> when the flight is unknown.</exception>
        public override Task<FlightInfo> GetFlightInfo(FlightDescriptor request, ServerCallContext context)
        {
            return Task.FromResult(GetFlightHolderOrThrow(request).GetFlightInfo());
        }

        /// <summary>
        /// Echo-style handshake. For each request it replies "Hello handshake" when the payload is
        /// the UTF-8 string "Hello", and "Done" otherwise.
        /// </summary>
        public override async Task Handshake(IAsyncStreamReader<FlightHandshakeRequest> requestStream, IAsyncStreamWriter<FlightHandshakeResponse> responseStream, ServerCallContext context)
        {
            while (await requestStream.MoveNext().ConfigureAwait(false))
            {
                if (requestStream.Current.Payload.ToStringUtf8() == "Hello")
                {
                    await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Hello handshake"))).ConfigureAwait(false);
                }
                else
                {
                    await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Done"))).ConfigureAwait(false);
                }
            }
        }

        /// <summary>
        /// Returns the schema of the flight identified by <paramref name="request"/>.
        /// </summary>
        /// <exception cref="RpcException">Status <see cref="StatusCode.NotFound"/> when the flight is unknown.</exception>
        public override Task<Schema> GetSchema(FlightDescriptor request, ServerCallContext context)
        {
            return Task.FromResult(GetFlightHolderOrThrow(request).GetFlightInfo().Schema);
        }

        /// <summary>
        /// Advertises the action types known to this test server.
        /// </summary>
        /// <remarks>
        /// Only "test" is actually handled by <see cref="DoAction"/>; the others are listed so the
        /// listing can be exercised with several entries.
        /// </remarks>
        public override async Task ListActions(IAsyncStreamWriter<FlightActionType> responseStream, ServerCallContext context)
        {
            await responseStream.WriteAsync(new FlightActionType("get", "get a flight"));
            await responseStream.WriteAsync(new FlightActionType("put", "add a flight"));
            await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight"));
            await responseStream.WriteAsync(new FlightActionType("test", "test action"));
        }

        /// <summary>
        /// Streams the <see cref="FlightInfo"/> of every stored flight. <paramref name="request"/>
        /// criteria are ignored.
        /// </summary>
        public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter<FlightInfo> responseStream, ServerCallContext context)
        {
            // Materialize before awaiting writes, so the live store is not enumerated across awaits.
            var flightInfos = _flightStore.Flights.Select(x => x.Value.GetFlightInfo()).ToList();
            foreach (var flightInfo in flightInfos)
            {
                await responseStream.WriteAsync(flightInfo);
            }
        }

        /// <summary>
        /// Echo exchange. Every incoming record batch is written straight back together with its
        /// first application-metadata entry (if any).
        /// </summary>
        public override async Task DoExchange(FlightServerRecordBatchStreamReader requestStream, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context)
        {
            while (await requestStream.MoveNext().ConfigureAwait(false))
            {
                await responseStream.WriteAsync(requestStream.Current, requestStream.ApplicationMetadata.FirstOrDefault()).ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Looks up the stored flight for <paramref name="descriptor"/>.
        /// </summary>
        /// <exception cref="RpcException">Status <see cref="StatusCode.NotFound"/> when no such flight is stored.</exception>
        private FlightHolder GetFlightHolderOrThrow(FlightDescriptor descriptor)
        {
            if (_flightStore.Flights.TryGetValue(descriptor, out var flightHolder))
            {
                return flightHolder;
            }

            throw new RpcException(new Status(StatusCode.NotFound, FlightNotFoundMessage));
        }
    }
}