blob: a7aaad4fb2d84366e22957b8caa8ef9a09d997e1 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using System;
using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Server;
using Apache.Arrow.Flight.Sql;
using Apache.Arrow.Types;
using Arrow.Flight.Protocol.Sql;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using Grpc.Core;
namespace Apache.Arrow.Flight.TestWeb;
public class TestFlightSqlServer : FlightServer
{
private readonly FlightStore _flightStore;
public TestFlightSqlServer(FlightStore flightStore)
{
_flightStore = flightStore;
}
public override async Task DoAction(FlightAction request, IAsyncStreamWriter<FlightResult> responseStream,
ServerCallContext context)
{
switch (request.Type)
{
case "test":
await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false);
break;
case SqlAction.GetPrimaryKeysRequest:
await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false);
break;
case SqlAction.CancelFlightInfoRequest:
var cancelRequest = new FlightInfoCancelResult();
cancelRequest.SetStatus(1);
await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray()))
.ConfigureAwait(false);
break;
case SqlAction.BeginTransactionRequest:
case SqlAction.CommitRequest:
case SqlAction.RollbackRequest:
await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id")))
.ConfigureAwait(false);
break;
case SqlAction.CreateRequest:
case SqlAction.CloseRequest:
var schema = new Schema.Builder()
.Field(f => f.Name("id").DataType(Int32Type.Default))
.Field(f => f.Name("name").DataType(StringType.Default))
.Build();
var datasetSchemaBytes = SchemaExtensions.SerializeSchema(schema);
var parameterSchemaBytes = SchemaExtensions.SerializeSchema(schema);
var preparedStatementResponse = new ActionCreatePreparedStatementResult
{
PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement"),
DatasetSchema = ByteString.CopyFrom(datasetSchemaBytes),
ParameterSchema = ByteString.CopyFrom(parameterSchemaBytes)
};
byte[] packedResult = Any.Pack(preparedStatementResponse).Serialize().ToByteArray();
var flightResult = new FlightResult(packedResult);
await responseStream.WriteAsync(flightResult).ConfigureAwait(false);
break;
default:
throw new NotImplementedException();
}
}
public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream,
ServerCallContext context)
{
FlightDescriptor flightDescriptor = FlightDescriptor.CreateCommandDescriptor(ticket.Ticket.ToStringUtf8());
if (_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder))
{
var batches = flightHolder.GetRecordBatches();
foreach (var batch in batches)
{
await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata);
}
}
}
public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter<FlightPutResult> responseStream, ServerCallContext context)
{
var flightDescriptor = await requestStream.FlightDescriptor;
if (!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder))
{
flightHolder = new FlightHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}");
_flightStore.Flights.Add(flightDescriptor, flightHolder);
}
int affectedRows = 0;
while (await requestStream.MoveNext())
{
// Process the record batch (if needed) here
// Increment the affected row count for demonstration purposes
affectedRows += requestStream.Current.Column(0).Length; // Example of counting rows in the first column
}
// Create a DoPutUpdateResult with the affected row count
var updateResult = new DoPutUpdateResult
{
RecordCount = affectedRows // Set the actual affected row count
};
// Serialize the DoPutUpdateResult into a ByteString
var metadata = updateResult.ToByteString();
// Send the metadata back as part of the FlightPutResult
var flightPutResult = new FlightPutResult(metadata);
await responseStream.WriteAsync(flightPutResult);
}
public override Task<FlightInfo> GetFlightInfo(FlightDescriptor request, ServerCallContext context)
{
if (_flightStore.Flights.TryGetValue(request, out var flightHolder))
{
return Task.FromResult(flightHolder.GetFlightInfo());
}
if (_flightStore.Flights.Count > 0)
{
return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo());
}
throw new RpcException(new Status(StatusCode.NotFound, "Flight not found"));
}
public override Task<Schema> GetSchema(FlightDescriptor request, ServerCallContext context)
{
if (_flightStore.Flights.TryGetValue(request, out var flightHolder))
{
return Task.FromResult(flightHolder.GetFlightInfo().Schema);
}
if (_flightStore.Flights.Count > 0)
{
return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo().Schema);
}
throw new RpcException(new Status(StatusCode.NotFound, "Flight not found"));
}
}