blob: f28629962f8b2c182bdd6cf06505bfd013a44694 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Server;
using Apache.Arrow.Types;
using Arrow.Flight.Protocol.Sql;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using Grpc.Core;
using Microsoft.Extensions.Logging;
namespace Apache.Arrow.Flight.Sql;
/// <summary>
/// Base class for Arrow Flight SQL servers. It decodes the Flight SQL command carried by a
/// Flight descriptor, ticket or action and dispatches it to the abstract handler that a
/// concrete server implements.
/// </summary>
public abstract class FlightSqlServer : FlightServer
{
    /// <summary>Optional logger; <c>null</c> when no <see cref="ILoggerFactory"/> was supplied.</summary>
    private ILogger? Logger { get; }

    /// <summary>Result schema for <c>CommandGetCatalogs</c>.</summary>
    public static readonly Schema CatalogSchema = new(new List<Field> { new("catalog_name", StringType.Default, false) }, null);

    /// <summary>Result schema for <c>CommandGetTableTypes</c>.</summary>
    public static readonly Schema TableTypesSchema = new(new List<Field> { new("table_type", StringType.Default, false) }, null);

    /// <summary>Result schema for <c>CommandGetDbSchemas</c>.</summary>
    public static readonly Schema DbSchemaFlightSchema = new(new List<Field> { new("catalog_name", StringType.Default, true), new("db_schema_name", StringType.Default, false) }, null);

    /// <summary>Result schema for <c>CommandGetPrimaryKeys</c>.</summary>
    public static readonly Schema PrimaryKeysSchema = new(new List<Field>
    {
        new("catalog_name", StringType.Default, true),
        new("db_schema_name", StringType.Default, true),
        new("table_name", StringType.Default, false),
        new("column_name", StringType.Default, false),
        new("key_sequence", Int32Type.Default, false),
        new("key_name", StringType.Default, true)
    }, null);

    /// <summary>
    /// Result schema shared by <c>CommandGetExportedKeys</c>, <c>CommandGetImportedKeys</c>
    /// and <c>CommandGetCrossReference</c>.
    /// </summary>
    public static readonly Schema KeyImportExportSchema = new(new List<Field>
    {
        new("pk_catalog_name", StringType.Default, true),
        new("pk_db_schema_name", StringType.Default, true),
        new("pk_table_name", StringType.Default, false),
        new("pk_column_name", StringType.Default, false),
        new("fk_catalog_name", StringType.Default, true),
        new("fk_db_schema_name", StringType.Default, true),
        new("fk_table_name", StringType.Default, false),
        new("fk_column_name", StringType.Default, false),
        new("key_sequence", Int32Type.Default, false),
        new("fk_key_name", StringType.Default, true),
        new("pk_key_name", StringType.Default, true),
        new("update_rule", UInt8Type.Default, false),
        new("delete_rule", UInt8Type.Default, false)
    }, null);

    /// <summary>Result schema for <c>CommandGetXdbcTypeInfo</c>.</summary>
    public static readonly Schema TypeInfoSchema = new(new List<Field>
    {
        new("type_name", StringType.Default, false),
        new("data_type", Int32Type.Default, false),
        new("column_size", Int32Type.Default, true),
        new("literal_prefix", StringType.Default, true),
        new("literal_suffix", StringType.Default, true),
        new("create_params", new ListType(new Field("item", StringType.Default, false)), true),
        new("nullable", Int32Type.Default, false),
        new("case_sensitive", BooleanType.Default, false),
        new("searchable", Int32Type.Default, false),
        new("unsigned_attribute", BooleanType.Default, true),
        new("fixed_prec_scale", BooleanType.Default, false),
        new("auto_increment", BooleanType.Default, true),
        new("local_type_name", StringType.Default, true),
        new("minimum_scale", Int32Type.Default, true),
        new("maximum_scale", Int32Type.Default, true),
        new("sql_data_type", Int32Type.Default, false),
        new("datetime_subcode", Int32Type.Default, true),
        new("num_prec_radix", Int32Type.Default, true),
        new("interval_precision", Int32Type.Default, true)
    }, null);

    /// <summary>
    /// Result schema for <c>CommandGetSqlInfo</c>. It is currently incomplete: only
    /// <c>info_name</c> is present.
    /// </summary>
    public static readonly Schema SqlInfoSchema = new(new List<Field>
    {
        new("info_name", UInt32Type.Default, false)
        // TODO: once union serialization is supported by Arrow Flight for .NET, add the
        // non-nullable "value" dense-union column required by the Flight SQL spec, with children:
        // string_value (utf8), bool_value (bool), bigint_value (int64), int32_bitmask (int32),
        // string_list (list<utf8>) and int32_to_int32_list_map (map<int32, list<int32>>).
    }, null);

    // Base columns of the CommandGetTables result; GetTableSchema optionally appends "table_schema".
    private static readonly Schema s_tableSchema = new(new List<Field>
    {
        new("catalog_name", StringType.Default, true),
        new("db_schema_name", StringType.Default, true),
        new("table_name", StringType.Default, false),
        new("table_type", StringType.Default, false)
    }, null);

    /// <summary>
    /// Returns the result schema for <c>CommandGetTables</c>.
    /// </summary>
    /// <param name="includeTableSchemaField">
    /// When <c>true</c>, a non-nullable binary <c>table_schema</c> column is appended to the base columns.
    /// </param>
    /// <returns>The shared base schema, or a new schema with the extra column.</returns>
    public static Schema GetTableSchema(bool includeTableSchemaField)
    {
        if (!includeTableSchemaField)
        {
            return s_tableSchema;
        }

        var fields = s_tableSchema.FieldsList.ToList();
        fields.Add(new Field("table_schema", BinaryType.Default, false));
        return new Schema(fields, s_tableSchema.Metadata);
    }

    /// <summary>
    /// Decodes the Flight SQL command packed as a protobuf <see cref="Any"/> in a ticket.
    /// </summary>
    /// <param name="ticket">The ticket received by DoGet.</param>
    /// <returns>
    /// The unpacked command. Returns <c>null</c> in any of these cases:
    /// <list type="bullet">
    /// <item>the ticket is null;</item>
    /// <item>the ticket bytes are not a valid <see cref="Any"/>;</item>
    /// <item>the packed type is not one this class recognizes for DoGet.</item>
    /// </list>
    /// </returns>
    public static IMessage? GetCommand(FlightTicket ticket)
    {
        if (ticket is null)
        {
            // Mirror GetCommand(FlightDescriptor?): absent input means "no Flight SQL command".
            return null;
        }

        try
        {
            return GetCommand(Any.Parser.ParseFrom(ticket.Ticket));
        }
        catch (InvalidProtocolBufferException)
        {
            // The ticket is not a Flight SQL command; callers treat null as "not supported".
            return null;
        }
    }

    /// <summary>
    /// Awaits the descriptor of a DoPut stream and decodes the Flight SQL command it carries.
    /// </summary>
    /// <param name="requestStream">The DoPut request stream.</param>
    /// <returns>The unpacked command, or <c>null</c> if the descriptor does not carry one.</returns>
    public static async Task<IMessage?> GetCommand(FlightServerRecordBatchStreamReader requestStream)
    {
        return GetCommand(await requestStream.FlightDescriptor.ConfigureAwait(false));
    }

    /// <summary>
    /// Decodes the Flight SQL command carried by a command-type Flight descriptor.
    /// </summary>
    /// <param name="request">The descriptor; may be <c>null</c>.</param>
    /// <returns>
    /// The unpacked command. Returns <c>null</c> when the descriptor is null, is not of type
    /// <see cref="FlightDescriptorType.Command"/>, or cannot be unpacked.
    /// </returns>
    public static IMessage? GetCommand(FlightDescriptor? request)
    {
        if (request == null) return null;
        if (request.Type == FlightDescriptorType.Command && request.ParsedAndUnpackedMessage() is { } command)
        {
            return command;
        }

        return null;
    }

    // Unpacks the command types DoGet knows how to serve; returns null for any other packed type.
    private static IMessage? GetCommand(Any command)
    {
        if (command.Is(CommandPreparedStatementQuery.Descriptor))
        {
            return command.Unpack<CommandPreparedStatementQuery>();
        }

        if (command.Is(CommandGetSqlInfo.Descriptor))
        {
            return command.Unpack<CommandGetSqlInfo>();
        }

        if (command.Is(CommandGetCatalogs.Descriptor))
        {
            return command.Unpack<CommandGetCatalogs>();
        }

        if (command.Is(CommandGetTableTypes.Descriptor))
        {
            return command.Unpack<CommandGetTableTypes>();
        }

        if (command.Is(CommandGetTables.Descriptor))
        {
            return command.Unpack<CommandGetTables>();
        }

        if (command.Is(CommandGetDbSchemas.Descriptor))
        {
            return command.Unpack<CommandGetDbSchemas>();
        }

        if (command.Is(CommandGetPrimaryKeys.Descriptor))
        {
            return command.Unpack<CommandGetPrimaryKeys>();
        }

        if (command.Is(CommandGetExportedKeys.Descriptor))
        {
            return command.Unpack<CommandGetExportedKeys>();
        }

        if (command.Is(CommandGetImportedKeys.Descriptor))
        {
            return command.Unpack<CommandGetImportedKeys>();
        }

        if (command.Is(CommandGetCrossReference.Descriptor))
        {
            return command.Unpack<CommandGetCrossReference>();
        }

        if (command.Is(CommandGetXdbcTypeInfo.Descriptor))
        {
            return command.Unpack<CommandGetXdbcTypeInfo>();
        }

        return null;
    }

    /// <summary>
    /// Creates the server.
    /// </summary>
    /// <param name="factory">Optional logger factory; when <c>null</c>, no trace logging is emitted.</param>
    protected FlightSqlServer(ILoggerFactory? factory = null)
    {
        Logger = factory?.CreateLogger(typeof(FlightSqlServer));
    }

    /// <summary>
    /// Lists actions supported by Flight SQL. For Flight RPC actions support
    /// implementers should extend this method to return additional supported actions.
    /// </summary>
    public override async Task ListActions(IAsyncStreamWriter<FlightActionType> responseStream, ServerCallContext context)
    {
        foreach (var actionType in FlightSqlUtils.FlightSqlActions)
        {
            await responseStream.WriteAsync(actionType).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Attempts to execute a valid Flight SQL command. For Flight RPC calls
    /// implementers should extend this method in order to handle RPC messages.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The descriptor does not carry a Flight SQL command handled here.
    /// </exception>
    public override Task<FlightInfo> GetFlightInfo(FlightDescriptor flightDescriptor, ServerCallContext context)
    {
        var sqlCommand = GetCommand(flightDescriptor);
        Logger?.LogTrace("Executing Flight SQL FlightInfo command: {DescriptorName}", sqlCommand?.Descriptor.Name);
        return sqlCommand switch
        {
            CommandStatementQuery command => GetStatementQueryFlightInfo(command, flightDescriptor, context),
            CommandPreparedStatementQuery command => GetPreparedStatementQueryFlightInfo(command, flightDescriptor, context),
            CommandGetCatalogs command => GetCatalogFlightInfo(command, flightDescriptor, context),
            CommandGetDbSchemas command => GetDbSchemaFlightInfo(command, flightDescriptor, context),
            CommandGetTables command => GetTablesFlightInfo(command, flightDescriptor, context),
            CommandGetTableTypes command => GetTableTypesFlightInfo(command, flightDescriptor, context),
            CommandGetSqlInfo command => GetSqlFlightInfo(command, flightDescriptor, context),
            CommandGetPrimaryKeys command => GetPrimaryKeysFlightInfo(command, flightDescriptor, context),
            CommandGetExportedKeys command => GetExportedKeysFlightInfo(command, flightDescriptor, context),
            CommandGetImportedKeys command => GetImportedKeysFlightInfo(command, flightDescriptor, context),
            CommandGetCrossReference command => GetCrossReferenceFlightInfo(command, flightDescriptor, context),
            CommandGetXdbcTypeInfo command => GetXdbcTypeFlightInfo(command, flightDescriptor, context),
            _ => throw new InvalidOperationException($"command type {sqlCommand?.Descriptor?.Name} not supported")
        };
    }

    /// <summary>
    /// Attempts to execute a valid Flight SQL command. For Flight RPC calls
    /// implementers should extend this method in order to handle RPC messages.
    /// </summary>
    /// <exception cref="RpcException">
    /// Thrown with <see cref="StatusCode.InvalidArgument"/> when the ticket does not carry a
    /// Flight SQL command handled here.
    /// </exception>
    public override Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context)
    {
        var sqlCommand = GetCommand(ticket);
        // Log the descriptor *name*, like the other handlers, rather than the descriptor object.
        Logger?.LogTrace("Executing Flight SQL DoGet command: {DescriptorName}", sqlCommand?.Descriptor.Name);
        return sqlCommand switch
        {
            CommandPreparedStatementQuery command => DoGetPreparedStatementQuery(command, responseStream, context),
            CommandGetSqlInfo command => DoGetSqlInfo(command, responseStream, context),
            CommandGetCatalogs command => DoGetCatalog(command, responseStream, context),
            CommandGetTableTypes command => DoGetTableType(command, responseStream, context),
            CommandGetTables command => DoGetTables(command, responseStream, context),
            CommandGetDbSchemas command => DoGetDbSchema(command, responseStream, context),
            CommandGetPrimaryKeys command => DoGetPrimaryKeys(command, responseStream, context),
            CommandGetExportedKeys command => DoGetExportedKeys(command, responseStream, context),
            CommandGetImportedKeys command => DoGetImportedKeys(command, responseStream, context),
            CommandGetCrossReference command => DoGetCrossReference(command, responseStream, context),
            CommandGetXdbcTypeInfo command => DoGetXbdcTypeInfo(command, responseStream, context),
            _ => throw new RpcException(new Status(StatusCode.InvalidArgument, $"DoGet command {sqlCommand?.Descriptor?.Name} is not supported."))
        };
    }

    /// <summary>
    /// Attempts to execute a valid Flight SQL command. For Flight RPC calls
    /// implementers should extend this method in order to handle RPC messages.
    /// </summary>
    /// <exception cref="NotImplementedException">The action type is not a Flight SQL action handled here.</exception>
    public override Task DoAction(FlightAction action, IAsyncStreamWriter<FlightResult> responseStream, ServerCallContext context)
    {
        Logger?.LogTrace("Executing Flight SQL DoAction: {ActionType}", action.Type);
        switch (action.Type)
        {
            case SqlAction.CreateRequest:
            {
                var command = FlightSqlUtils.ParseAndUnpack<ActionCreatePreparedStatementRequest>(action.Body);
                return CreatePreparedStatement(command, action, responseStream, context);
            }
            case SqlAction.CloseRequest:
            {
                var closeCommand = FlightSqlUtils.ParseAndUnpack<ActionClosePreparedStatementRequest>(action.Body);
                return ClosePreparedStatement(closeCommand, action, responseStream, context);
            }
            default:
                throw new NotImplementedException($"Action type {action.Type} not supported");
        }
    }

    /// <summary>
    /// Attempts to execute a valid Flight SQL command. For Flight RPC calls
    /// implementers should extend this method in order to handle RPC messages.
    /// </summary>
    /// <exception cref="NotImplementedException">
    /// The stream's descriptor does not carry a Flight SQL command, or carries one that DoPut
    /// does not handle.
    /// </exception>
    public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter<FlightPutResult> responseStream, ServerCallContext context)
    {
        if (await GetCommand(requestStream).ConfigureAwait(false) is { } command)
        {
            await DoPutInternal(command, requestStream, responseStream, context).ConfigureAwait(false);
        }
        else
        {
            throw new NotImplementedException("DoPut requires a FlightDescriptor carrying a Flight SQL command");
        }
    }

    // Dispatches a decoded DoPut command to the matching abstract handler.
    private Task DoPutInternal(IMessage command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter<FlightPutResult> responseStream, ServerCallContext context)
    {
        Logger?.LogTrace("Executing Flight SQL DoPut: {DescriptorName}", command.Descriptor.Name);
        return command switch
        {
            CommandStatementUpdate statementUpdate => PutStatementUpdate(statementUpdate, requestStream, responseStream, context),
            CommandPreparedStatementQuery preparedStatementQuery => PutPreparedStatementQuery(preparedStatementQuery, requestStream, responseStream, context),
            CommandPreparedStatementUpdate preparedStatementUpdate => PutPreparedStatementUpdate(preparedStatementUpdate, requestStream, responseStream, context),
            _ => throw new NotImplementedException($"Command {command.Descriptor.Name} not supported")
        };
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="action"/> is one of the Flight SQL actions handled
    /// by <see cref="DoAction"/>: create or close a prepared statement.
    /// </summary>
    public static bool SupportsAction(FlightAction action)
    {
        return action.Type is SqlAction.CreateRequest or SqlAction.CloseRequest;
    }

    #region FlightInfo
    // Handlers invoked by GetFlightInfo, one per Flight SQL command type.
    protected abstract Task<FlightInfo> GetStatementQueryFlightInfo(CommandStatementQuery commandStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetPreparedStatementQueryFlightInfo(CommandPreparedStatementQuery preparedStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetCatalogFlightInfo(CommandGetCatalogs command, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetDbSchemaFlightInfo(CommandGetDbSchemas command, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetTablesFlightInfo(CommandGetTables command, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetTableTypesFlightInfo(CommandGetTableTypes command, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetSqlFlightInfo(CommandGetSqlInfo commandGetSqlInfo, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetPrimaryKeysFlightInfo(CommandGetPrimaryKeys command, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetExportedKeysFlightInfo(CommandGetExportedKeys command, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetImportedKeysFlightInfo(CommandGetImportedKeys command, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetCrossReferenceFlightInfo(CommandGetCrossReference command, FlightDescriptor flightDescriptor, ServerCallContext context);
    protected abstract Task<FlightInfo> GetXdbcTypeFlightInfo(CommandGetXdbcTypeInfo command, FlightDescriptor flightDescriptor, ServerCallContext context);
    #endregion

    #region DoGet
    // Handlers invoked by DoGet, one per command type recognized in tickets.
    protected abstract Task DoGetPreparedStatementQuery(CommandPreparedStatementQuery preparedStatementQuery, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    protected abstract Task DoGetSqlInfo(CommandGetSqlInfo getSqlInfo, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    protected abstract Task DoGetCatalog(CommandGetCatalogs command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    protected abstract Task DoGetTableType(CommandGetTableTypes command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    protected abstract Task DoGetTables(CommandGetTables command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    protected abstract Task DoGetPrimaryKeys(CommandGetPrimaryKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    protected abstract Task DoGetDbSchema(CommandGetDbSchemas command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    protected abstract Task DoGetExportedKeys(CommandGetExportedKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    protected abstract Task DoGetImportedKeys(CommandGetImportedKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    protected abstract Task DoGetCrossReference(CommandGetCrossReference command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    protected abstract Task DoGetXbdcTypeInfo(CommandGetXdbcTypeInfo command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
    #endregion

    #region DoAction
    // Handlers invoked by DoAction for the prepared-statement lifecycle actions.
    protected abstract Task CreatePreparedStatement(ActionCreatePreparedStatementRequest request, FlightAction action, IAsyncStreamWriter<FlightResult> streamWriter, ServerCallContext context);
    protected abstract Task ClosePreparedStatement(ActionClosePreparedStatementRequest request, FlightAction action, IAsyncStreamWriter<FlightResult> streamWriter, ServerCallContext context);
    #endregion

    #region DoPut
    // Handlers invoked by DoPut, one per supported update/query command.
    protected abstract Task PutPreparedStatementUpdate(CommandPreparedStatementUpdate command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter<FlightPutResult> responseStream, ServerCallContext context);
    protected abstract Task PutStatementUpdate(CommandStatementUpdate command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter<FlightPutResult> responseStream, ServerCallContext context);
    protected abstract Task PutPreparedStatementQuery(CommandPreparedStatementQuery command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter<FlightPutResult> responseStream, ServerCallContext context);
    #endregion
}