blob: b7dd24597525ac4dd9d1d16911611987b3cdd723 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Sql.Client;
using Arrow.Flight.Protocol.Sql;
using Google.Protobuf;
using Grpc.Core;
namespace Apache.Arrow.Flight.Sql;
/// <summary>
/// Client-side representation of a Flight SQL prepared statement. Holds the server-issued handle,
/// the dataset and parameter schemas, and an optional parameter batch, and forwards execute,
/// update, schema and close operations through a <see cref="FlightSqlClient"/>.
/// </summary>
public class PreparedStatement : IDisposable, IAsyncDisposable
{
    private readonly FlightSqlClient _client;
    private readonly string _handle;
    private RecordBatch? _recordsBatch;
    private bool _isClosed;

    /// <summary>Schema of the result set produced by executing this statement.</summary>
    public Schema DatasetSchema { get; }

    /// <summary>Schema that parameter batches bound to this statement are expected to follow.</summary>
    public Schema ParameterSchema { get; }

    /// <summary>True once the statement has been closed via <see cref="CloseAsync"/> or disposal.</summary>
    public bool IsClosed => _isClosed;

    /// <summary>The server-issued prepared statement handle, kept as a UTF-8 string.</summary>
    public string Handle => _handle;

    /// <summary>The parameter batch last supplied through <see cref="SetParameters"/>, or null if none.</summary>
    public RecordBatch? ParametersBatch => _recordsBatch;

    /// <summary>
    /// Initializes a new instance of the <see cref="PreparedStatement"/> class.
    /// </summary>
    /// <param name="client">The Flight SQL client used for executing SQL operations.</param>
    /// <param name="handle">The handle representing the prepared statement.</param>
    /// <param name="datasetSchema">The schema of the result dataset.</param>
    /// <param name="parameterSchema">The schema of the parameters for this prepared statement.</param>
    /// <exception cref="ArgumentNullException">Thrown if any argument is null.</exception>
    public PreparedStatement(FlightSqlClient client, string handle, Schema datasetSchema, Schema parameterSchema)
    {
        _client = client ?? throw new ArgumentNullException(nameof(client));
        _handle = handle ?? throw new ArgumentNullException(nameof(handle));
        DatasetSchema = datasetSchema ?? throw new ArgumentNullException(nameof(datasetSchema));
        ParameterSchema = parameterSchema ?? throw new ArgumentNullException(nameof(parameterSchema));
        _isClosed = false;
    }

    /// <summary>
    /// Retrieves the schema associated with the prepared statement asynchronously.
    /// </summary>
    /// <param name="options">The options used to configure the Flight call.</param>
    /// <param name="cancellationToken">Cancellation token</param>
    /// <returns>A task representing the asynchronous operation, which returns the schema of the result set.</returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the statement is closed, when the schema is empty or invalid, or when the RPC fails.
    /// </exception>
    public async Task<Schema> GetSchemaAsync(FlightCallOptions? options = default, CancellationToken cancellationToken = default)
    {
        EnsureStatementIsNotClosed();
        try
        {
            var command = new CommandPreparedStatementQuery
            {
                PreparedStatementHandle = HandleAsByteString()
            };
            var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize());
            var schema = await _client.GetSchemaAsync(descriptor, options, cancellationToken).ConfigureAwait(false);
            if (schema == null || !schema.FieldsList.Any())
            {
                throw new InvalidOperationException("Schema is empty or invalid.");
            }
            return schema;
        }
        catch (RpcException ex)
        {
            throw new InvalidOperationException("Failed to retrieve the schema for the prepared statement", ex);
        }
    }

    /// <summary>
    /// Closes the prepared statement asynchronously by sending a ClosePreparedStatement action
    /// and draining its result stream.
    /// </summary>
    /// <param name="options">The options used to configure the Flight call.</param>
    /// <param name="cancellationToken">Cancellation token</param>
    /// <returns>A task representing the asynchronous operation.</returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown if the statement is already closed or if closing the prepared statement fails.
    /// </exception>
    public async Task CloseAsync(FlightCallOptions? options = default, CancellationToken cancellationToken = default)
    {
        EnsureStatementIsNotClosed();
        try
        {
            var closeRequest = new ActionClosePreparedStatementRequest
            {
                PreparedStatementHandle = HandleAsByteString()
            };
            var action = new FlightAction(SqlAction.CloseRequest, closeRequest.PackAndSerialize());
            await foreach (var _ in _client.DoActionAsync(action, options, cancellationToken).ConfigureAwait(false))
            {
                // The results carry nothing we need; draining them completes the action.
            }
            _isClosed = true;
        }
        catch (RpcException ex)
        {
            throw new InvalidOperationException("Failed to close the prepared statement", ex);
        }
    }

    /// <summary>
    /// Reads the result from an asynchronous stream of FlightData and merges each non-empty
    /// data body into the provided Protobuf message.
    /// </summary>
    /// <param name="results">The asynchronous stream of <see cref="FlightData"/> objects.</param>
    /// <param name="message">The Protobuf message to populate with the data from the stream.</param>
    /// <returns>A task that represents the asynchronous read operation.</returns>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="results"/> or <paramref name="message"/> is null.</exception>
    /// <exception cref="InvalidOperationException">Thrown if parsing the data fails.</exception>
    public async Task ReadResultAsync(IAsyncEnumerable<FlightData> results, IMessage message)
    {
        if (results == null) throw new ArgumentNullException(nameof(results));
        if (message == null) throw new ArgumentNullException(nameof(message));
        await foreach (var flightData in results.ConfigureAwait(false))
        {
            if (flightData.DataBody == null || flightData.DataBody.Length == 0)
            {
                continue;
            }
            try
            {
                // Merge the payload received from the server (not the message's own contents).
                message.MergeFrom(flightData.DataBody);
            }
            catch (InvalidProtocolBufferException ex)
            {
                throw new InvalidOperationException("Failed to parse the received FlightData into the specified message.", ex);
            }
        }
    }

    /// <summary>
    /// Parses the response of a prepared statement creation from the FlightData stream and
    /// builds a new <see cref="PreparedStatement"/> bound to <paramref name="client"/>.
    /// </summary>
    /// <remarks>
    /// A dataset or parameter schema that the server leaves empty is represented as an empty
    /// <see cref="Schema"/> (for example, a statement that takes no parameters).
    /// </remarks>
    /// <param name="client">The Flight SQL client.</param>
    /// <param name="results">The asynchronous stream of <see cref="FlightData"/> objects.</param>
    /// <returns>A task representing the asynchronous operation, which returns the populated <see cref="PreparedStatement"/>.</returns>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="client"/> or <paramref name="results"/> is null.</exception>
    /// <exception cref="InvalidOperationException">Thrown if the prepared statement handle or data is invalid.</exception>
    public async Task<PreparedStatement> ParseResponseAsync(FlightSqlClient client, IAsyncEnumerable<FlightData> results)
    {
        if (client == null)
            throw new ArgumentNullException(nameof(client));
        if (results == null)
            throw new ArgumentNullException(nameof(results));
        var preparedStatementResult = new ActionCreatePreparedStatementResult();
        await foreach (var flightData in results.ConfigureAwait(false))
        {
            if (flightData.DataBody == null || flightData.DataBody.Length == 0)
            {
                continue;
            }
            try
            {
                preparedStatementResult.MergeFrom(flightData.DataBody.ToByteArray());
            }
            catch (InvalidProtocolBufferException ex)
            {
                throw new InvalidOperationException("Failed to parse FlightData into ActionCreatePreparedStatementResult.", ex);
            }
        }
        if (preparedStatementResult.PreparedStatementHandle.Length == 0)
        {
            throw new InvalidOperationException("Received an empty or invalid PreparedStatementHandle.");
        }
        // The constructor rejects null schemas, so absent schemas must become empty ones.
        var datasetSchema = DeserializeSchemaOrEmpty(preparedStatementResult.DatasetSchema);
        var parameterSchema = DeserializeSchemaOrEmpty(preparedStatementResult.ParameterSchema);
        return new PreparedStatement(client, preparedStatementResult.PreparedStatementHandle.ToStringUtf8(),
            datasetSchema, parameterSchema);
    }

    /// <summary>
    /// Stores the parameter batch that <see cref="ExecuteAsync"/> will bind before executing.
    /// </summary>
    /// <param name="parameterBatch">The <see cref="RecordBatch"/> containing parameters to bind to the statement.</param>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="parameterBatch"/> is null.</exception>
    public void SetParameters(RecordBatch parameterBatch)
    {
        _recordsBatch = parameterBatch ?? throw new ArgumentNullException(nameof(parameterBatch));
    }

    /// <summary>
    /// Executes the prepared statement asynchronously and retrieves the query results as <see cref="FlightInfo"/>.
    /// If a parameter batch was set via <see cref="SetParameters"/>, it is bound first.
    /// </summary>
    /// <param name="options">Optional <see cref="FlightCallOptions"/> for the operation, which may include timeouts, headers, and other options for the call.</param>
    /// <param name="cancellationToken">Optional <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
    /// <returns>A <see cref="Task{FlightInfo}"/> whose result describes the executed query results.</returns>
    /// <exception cref="InvalidOperationException">Thrown if the prepared statement is closed or if there is an error during execution.</exception>
    /// <exception cref="OperationCanceledException">Thrown if the operation is canceled by the <paramref name="cancellationToken"/>.</exception>
    public async Task<FlightInfo> ExecuteAsync(FlightCallOptions? options = default, CancellationToken cancellationToken = default)
    {
        EnsureStatementIsNotClosed();
        var command = new CommandPreparedStatementQuery
        {
            PreparedStatementHandle = HandleAsByteString(),
        };
        var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize());
        cancellationToken.ThrowIfCancellationRequested();
        if (_recordsBatch != null)
        {
            await BindParametersAsync(descriptor, _recordsBatch, options, cancellationToken).ConfigureAwait(false);
        }
        cancellationToken.ThrowIfCancellationRequested();
        return await _client.GetFlightInfoAsync(descriptor, options, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Executes a prepared update statement asynchronously with the provided parameter batch.
    /// </summary>
    /// <remarks>
    /// The batch is sent with DoPut under a <c>CommandPreparedStatementUpdate</c> descriptor, and the
    /// server's application metadata is decoded as a <c>DoPutUpdateResult</c> to obtain the affected row count.
    /// Cancellation exceptions raised by the underlying client call are not translated by this method.
    /// </remarks>
    /// <param name="parameterBatch">
    /// A <see cref="RecordBatch"/> containing the parameters to be bound to the update statement.
    /// This batch should match the schema expected by the prepared statement.
    /// </param>
    /// <param name="options">The <see cref="FlightCallOptions"/> for this execution, containing headers and other options.</param>
    /// <param name="cancellationToken">Cancellation token</param>
    /// <returns>
    /// A <see cref="Task{TResult}"/> whose result is the number of rows affected by the update.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// Thrown if <paramref name="parameterBatch"/> is null.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// Thrown if the statement is closed, if binding fails, or if the server returns empty or unparseable metadata.
    /// </exception>
    /// <example>
    /// <code>
    /// var parameterBatch = CreateParameterBatch();
    /// var affectedRows = await preparedStatement.ExecuteUpdateAsync(parameterBatch, new FlightCallOptions());
    /// Console.WriteLine($"Rows affected: {affectedRows}");
    /// </code>
    /// </example>
    public async Task<long> ExecuteUpdateAsync(
        RecordBatch parameterBatch,
        FlightCallOptions? options = default,
        CancellationToken cancellationToken = default)
    {
        if (parameterBatch == null)
        {
            throw new ArgumentNullException(nameof(parameterBatch), "Parameter batch cannot be null.");
        }
        EnsureStatementIsNotClosed();

        // Flight SQL executes prepared updates via DoPut + CommandPreparedStatementUpdate;
        // DoPut with CommandPreparedStatementQuery only binds parameters for a later query.
        var command = new CommandPreparedStatementUpdate
        {
            PreparedStatementHandle = HandleAsByteString(),
        };
        var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize());
        var metadata = await BindParametersAsync(descriptor, parameterBatch, options, cancellationToken).ConfigureAwait(false);
        try
        {
            return ParseAffectedRows(metadata);
        }
        catch (InvalidProtocolBufferException ex)
        {
            throw new InvalidOperationException("Failed to execute the prepared update statement.", ex);
        }
    }

    /// <summary>
    /// Decodes DoPut application metadata as a <c>DoPutUpdateResult</c> and returns its record count.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when the metadata is null or empty.</exception>
    private long ParseAffectedRows(ByteString metadata)
    {
        if (metadata == null || metadata.Length == 0)
        {
            throw new InvalidOperationException("Server returned empty metadata, unable to determine affected row count.");
        }
        var updateResult = new DoPutUpdateResult();
        updateResult.MergeFrom(metadata);
        return updateResult.RecordCount;
    }

    /// <summary>
    /// Binds parameters to the prepared statement by streaming the given RecordBatch to the server asynchronously.
    /// </summary>
    /// <param name="descriptor">The <see cref="FlightDescriptor"/> that identifies the statement or command being executed.</param>
    /// <param name="parameterBatch">The <see cref="RecordBatch"/> containing the parameters to bind to the prepared statement.</param>
    /// <param name="options">The <see cref="FlightCallOptions"/> for the operation, which may include timeouts, headers, and other options for the call.</param>
    /// <param name="cancellationToken">Cancellation token</param>
    /// <returns>A <see cref="Task{ByteString}"/> whose result is the application metadata returned by the server.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="parameterBatch"/> is null.</exception>
    /// <exception cref="InvalidOperationException">Thrown if the statement is closed or if the DoPut RPC fails.</exception>
    public async Task<ByteString> BindParametersAsync(
        FlightDescriptor descriptor,
        RecordBatch parameterBatch,
        FlightCallOptions? options = default,
        CancellationToken cancellationToken = default)
    {
        if (parameterBatch == null)
        {
            throw new ArgumentNullException(nameof(parameterBatch), @"Parameter batch cannot be null.");
        }
        EnsureStatementIsNotClosed();
        try
        {
            // The RPC itself must be inside the try: it is the only call here that can fail.
            var putResult = await _client.DoPutAsync(descriptor, parameterBatch, options, cancellationToken).ConfigureAwait(false);
            return putResult.ApplicationMetadata;
        }
        catch (RpcException ex)
        {
            throw new InvalidOperationException("Failed to bind parameters to the prepared statement.", ex);
        }
    }

    /// <summary>
    /// Ensures that the statement is not already closed.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when the statement has been closed.</exception>
    private void EnsureStatementIsNotClosed()
    {
        if (_isClosed)
            throw new InvalidOperationException("Cannot execute a closed statement.");
    }

    /// <summary>
    /// Encodes the stored handle into the bytes sent to the server.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the handle is stored as a string produced by <c>ToStringUtf8</c>; if a server ever
    /// issues handle bytes that are not valid UTF-8, this round trip is lossy. Confirm servers use UTF-8 handles.
    /// </remarks>
    private ByteString HandleAsByteString() => ByteString.CopyFrom(_handle, Encoding.UTF8);

    /// <summary>
    /// Deserializes an IPC-encoded schema, or returns an empty schema when no bytes were supplied.
    /// </summary>
    private static Schema DeserializeSchemaOrEmpty(ByteString serializedSchema) =>
        serializedSchema.Length > 0
            ? SchemaExtensions.DeserializeSchema(serializedSchema.ToByteArray())
            : new Schema.Builder().Build();

    /// <summary>
    /// Disposes of the resources used by the prepared statement, closing it on the server if still open.
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Disposes of the resources used by the prepared statement.
    /// </summary>
    /// <param name="disposing">Whether the method is called from <see cref="Dispose()"/>.</param>
    protected virtual void Dispose(bool disposing)
    {
        if (_isClosed) return;
        if (disposing)
        {
            // Synchronous dispose must block; AsTask() is the supported way to block on a ValueTask,
            // and DisposeAsync uses ConfigureAwait(false) so this does not deadlock on a captured context.
            DisposeAsync().AsTask().GetAwaiter().GetResult();
        }
        _isClosed = true;
    }

    /// <summary>
    /// Asynchronously closes the prepared statement on the server if it is still open.
    /// </summary>
    public async ValueTask DisposeAsync()
    {
        if (!_isClosed)
        {
            await CloseAsync(new FlightCallOptions()).ConfigureAwait(false);
            _isClosed = true;
        }
        GC.SuppressFinalize(this);
    }
}