blob: ae6e2e4b03b6db0c0a8bde4e2dbd73d94e132001 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Server;
using Grpc.Core;
using Grpc.Core.Utils;
namespace Apache.Arrow.Flight.TestWeb
{
public class TestFlightServer : FlightServer
{
private readonly FlightStore _flightStore;
public TestFlightServer(FlightStore flightStore)
{
_flightStore = flightStore;
}
public override async Task DoAction(FlightAction request, IAsyncStreamWriter<FlightResult> responseStream, ServerCallContext context)
{
switch (request.Type)
{
case "test":
await responseStream.WriteAsync(new FlightResult("test data"));
break;
default:
throw new NotImplementedException();
}
}
public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context)
{
var flightDescriptor = FlightDescriptor.CreatePathDescriptor(ticket.Ticket.ToStringUtf8());
if(_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder))
{
var batches = flightHolder.GetRecordBatches();
foreach(var batch in batches)
{
await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata);
}
}
}
public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter<FlightPutResult> responseStream, ServerCallContext context)
{
var flightDescriptor = await requestStream.FlightDescriptor;
if(!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder))
{
flightHolder = new FlightHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}");
_flightStore.Flights.Add(flightDescriptor, flightHolder);
}
while (await requestStream.MoveNext())
{
flightHolder.AddBatch(new RecordBatchWithMetadata(requestStream.Current, requestStream.ApplicationMetadata.FirstOrDefault()));
await responseStream.WriteAsync(FlightPutResult.Empty);
}
}
public override Task<FlightInfo> GetFlightInfo(FlightDescriptor request, ServerCallContext context)
{
if(_flightStore.Flights.TryGetValue(request, out var flightHolder))
{
return Task.FromResult(flightHolder.GetFlightInfo());
}
throw new RpcException(new Status(StatusCode.NotFound, "Flight not found"));
}
public override Task<Schema> GetSchema(FlightDescriptor request, ServerCallContext context)
{
if(_flightStore.Flights.TryGetValue(request, out var flightHolder))
{
return Task.FromResult(flightHolder.GetFlightInfo().Schema);
}
throw new RpcException(new Status(StatusCode.NotFound, "Flight not found"));
}
public override async Task ListActions(IAsyncStreamWriter<FlightActionType> responseStream, ServerCallContext context)
{
await responseStream.WriteAsync(new FlightActionType("get", "get a flight"));
await responseStream.WriteAsync(new FlightActionType("put", "add a flight"));
await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight"));
await responseStream.WriteAsync(new FlightActionType("test", "test action"));
}
public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter<FlightInfo> responseStream, ServerCallContext context)
{
var flightInfos = _flightStore.Flights.Select(x => x.Value.GetFlightInfo()).ToList();
foreach(var flightInfo in flightInfos)
{
await responseStream.WriteAsync(flightInfo);
}
}
}
}