| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| mod common; |
| |
| use crate::common::fixture::TestFixture; |
| use crate::common::utils::make_primitive_batch; |
| |
| use arrow_array::RecordBatch; |
| use arrow_flight::decode::FlightRecordBatchStream; |
| use arrow_flight::encode::FlightDataEncoderBuilder; |
| use arrow_flight::error::FlightError; |
| use arrow_flight::flight_service_server::FlightServiceServer; |
| use arrow_flight::sql::client::FlightSqlServiceClient; |
| use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream}; |
| use arrow_flight::sql::{ |
| ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionEndTransactionRequest, |
| CommandStatementIngest, EndTransaction, FallibleRequestStream, ProstMessageExt, SqlInfo, |
| TableDefinitionOptions, TableExistsOption, TableNotExistOption, |
| }; |
| use arrow_flight::{Action, FlightData, FlightDescriptor}; |
| use futures::{StreamExt, TryStreamExt}; |
| use prost::Message; |
| use std::collections::HashMap; |
| use std::sync::Arc; |
| use tokio::sync::Mutex; |
| use tonic::{IntoStreamingRequest, Request, Status}; |
| use uuid::Uuid; |
| |
| #[tokio::test] |
| pub async fn test_begin_end_transaction() { |
| let test_server = FlightSqlServiceImpl::new(); |
| let fixture = TestFixture::new(test_server.service()).await; |
| let channel = fixture.channel().await; |
| let mut flight_sql_client = FlightSqlServiceClient::new(channel); |
| |
| // begin commit |
| let transaction_id = flight_sql_client.begin_transaction().await.unwrap(); |
| flight_sql_client |
| .end_transaction(transaction_id, EndTransaction::Commit) |
| .await |
| .unwrap(); |
| |
| // begin rollback |
| let transaction_id = flight_sql_client.begin_transaction().await.unwrap(); |
| flight_sql_client |
| .end_transaction(transaction_id, EndTransaction::Rollback) |
| .await |
| .unwrap(); |
| |
| // unknown transaction id |
| let transaction_id = "UnknownTransactionId".to_string().into(); |
| assert!( |
| flight_sql_client |
| .end_transaction(transaction_id, EndTransaction::Commit) |
| .await |
| .is_err() |
| ); |
| } |
| |
| #[tokio::test] |
| pub async fn test_execute_ingest() { |
| let test_server = FlightSqlServiceImpl::new(); |
| let fixture = TestFixture::new(test_server.service()).await; |
| let channel = fixture.channel().await; |
| let mut flight_sql_client = FlightSqlServiceClient::new(channel); |
| let cmd = make_ingest_command(); |
| let expected_rows = 10; |
| let batches = vec![ |
| make_primitive_batch(5), |
| make_primitive_batch(3), |
| make_primitive_batch(2), |
| ]; |
| let actual_rows = flight_sql_client |
| .execute_ingest(cmd, futures::stream::iter(batches.clone()).map(Ok)) |
| .await |
| .expect("ingest should succeed"); |
| assert_eq!(actual_rows, expected_rows); |
| // make sure the batches made it through to the server |
| let ingested_batches = test_server.ingested_batches.lock().await.clone(); |
| assert_eq!(ingested_batches, batches); |
| } |
| |
| #[tokio::test] |
| pub async fn test_execute_ingest_error() { |
| let test_server = FlightSqlServiceImpl::new(); |
| let fixture = TestFixture::new(test_server.service()).await; |
| let channel = fixture.channel().await; |
| let mut flight_sql_client = FlightSqlServiceClient::new(channel); |
| let cmd = make_ingest_command(); |
| // send an error from the client |
| let batches = vec![ |
| Ok(make_primitive_batch(5)), |
| Err(FlightError::NotYetImplemented( |
| "Client error message".to_string(), |
| )), |
| ]; |
| // make sure the client returns the error from the client |
| let err = flight_sql_client |
| .execute_ingest(cmd, futures::stream::iter(batches)) |
| .await |
| .unwrap_err(); |
| assert_eq!( |
| err.to_string(), |
| "External error: Not yet implemented: Client error message" |
| ); |
| } |
| |
| #[tokio::test] |
| pub async fn test_do_put_empty_stream() { |
| // Test for https://github.com/apache/arrow-rs/issues/7329 |
| |
| let test_server = FlightSqlServiceImpl::new(); |
| let fixture = TestFixture::new(test_server.service()).await; |
| let channel = fixture.channel().await; |
| let mut flight_sql_client = FlightSqlServiceClient::new(channel); |
| let cmd = make_ingest_command(); |
| |
| // Create an empty request stream |
| let input_data = futures::stream::iter(vec![]); |
| let flight_descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); |
| let flight_data_encoder = FlightDataEncoderBuilder::default() |
| .with_flight_descriptor(Some(flight_descriptor)) |
| .build(input_data); |
| let flight_data: Vec<FlightData> = Box::pin(flight_data_encoder).try_collect().await.unwrap(); |
| let request_stream = futures::stream::iter(flight_data); |
| |
| // Execute a `do_put` and verify that the server error contains the expected message |
| let err = flight_sql_client.do_put(request_stream).await.unwrap_err(); |
| assert!( |
| err.to_string() |
| .contains("Unhandled Error: Command is missing."), |
| ); |
| } |
| |
| #[tokio::test] |
| pub async fn test_do_put_first_element_err() { |
| // Test for https://github.com/apache/arrow-rs/issues/7329 |
| |
| let test_server = FlightSqlServiceImpl::new(); |
| let fixture = TestFixture::new(test_server.service()).await; |
| let channel = fixture.channel().await; |
| let mut flight_sql_client = FlightSqlServiceClient::new(channel); |
| let cmd = make_ingest_command(); |
| |
| let (sender, _receiver) = futures::channel::oneshot::channel(); |
| |
| // Create a fallible request stream such that the 1st element is a FlightError |
| let input_data = futures::stream::iter(vec![ |
| Err(FlightError::NotYetImplemented("random error".to_string())), |
| Ok(make_primitive_batch(5)), |
| ]); |
| let flight_descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); |
| let flight_data_encoder = FlightDataEncoderBuilder::default() |
| .with_flight_descriptor(Some(flight_descriptor)) |
| .build(input_data); |
| let flight_data: FallibleRequestStream<FlightData, FlightError> = |
| FallibleRequestStream::new(sender, Box::pin(flight_data_encoder)); |
| let request_stream = flight_data.into_streaming_request(); |
| |
| // Execute a `do_put` and verify that the server error contains the expected message |
| let err = flight_sql_client.do_put(request_stream).await.unwrap_err(); |
| |
| assert!( |
| err.to_string() |
| .contains("Unhandled Error: Command is missing."), |
| ); |
| } |
| |
| #[tokio::test] |
| pub async fn test_do_put_missing_flight_descriptor() { |
| // Test for https://github.com/apache/arrow-rs/issues/7329 |
| |
| let test_server = FlightSqlServiceImpl::new(); |
| let fixture = TestFixture::new(test_server.service()).await; |
| let channel = fixture.channel().await; |
| let mut flight_sql_client = FlightSqlServiceClient::new(channel); |
| |
| // Create a request stream such that the flight descriptor is missing |
| let stream = futures::stream::iter(vec![Ok(make_primitive_batch(5))]); |
| let flight_data_encoder = FlightDataEncoderBuilder::default() |
| .with_flight_descriptor(None) |
| .build(stream); |
| let flight_data: Vec<FlightData> = Box::pin(flight_data_encoder).try_collect().await.unwrap(); |
| let request_stream = futures::stream::iter(flight_data); |
| |
| // Execute a `do_put` and verify that the server error contains the expected message |
| let err = flight_sql_client.do_put(request_stream).await.unwrap_err(); |
| assert!( |
| err.to_string() |
| .contains("Unhandled Error: Flight descriptor is missing."), |
| ); |
| } |
| |
| fn make_ingest_command() -> CommandStatementIngest { |
| CommandStatementIngest { |
| table_definition_options: Some(TableDefinitionOptions { |
| if_not_exist: TableNotExistOption::Create.into(), |
| if_exists: TableExistsOption::Fail.into(), |
| }), |
| table: String::from("test"), |
| schema: None, |
| catalog: None, |
| temporary: true, |
| transaction_id: None, |
| options: HashMap::default(), |
| } |
| } |
| |
| #[derive(Clone)] |
| pub struct FlightSqlServiceImpl { |
| transactions: Arc<Mutex<HashMap<String, ()>>>, |
| ingested_batches: Arc<Mutex<Vec<RecordBatch>>>, |
| } |
| |
| impl FlightSqlServiceImpl { |
| pub fn new() -> Self { |
| Self { |
| transactions: Arc::new(Mutex::new(HashMap::new())), |
| ingested_batches: Arc::new(Mutex::new(Vec::new())), |
| } |
| } |
| |
| /// Return an [`FlightServiceServer`] that can be used with a |
| /// [`Server`](tonic::transport::Server) |
| pub fn service(&self) -> FlightServiceServer<Self> { |
| // wrap up tonic goop |
| FlightServiceServer::new(self.clone()) |
| } |
| } |
| |
| impl Default for FlightSqlServiceImpl { |
| fn default() -> Self { |
| Self::new() |
| } |
| } |
| |
| #[tonic::async_trait] |
| impl FlightSqlService for FlightSqlServiceImpl { |
| type FlightService = FlightSqlServiceImpl; |
| |
| async fn do_action_begin_transaction( |
| &self, |
| _query: ActionBeginTransactionRequest, |
| _request: Request<Action>, |
| ) -> Result<ActionBeginTransactionResult, Status> { |
| let transaction_id = Uuid::new_v4().to_string(); |
| self.transactions |
| .lock() |
| .await |
| .insert(transaction_id.clone(), ()); |
| Ok(ActionBeginTransactionResult { |
| transaction_id: transaction_id.as_bytes().to_vec().into(), |
| }) |
| } |
| |
| async fn do_action_end_transaction( |
| &self, |
| query: ActionEndTransactionRequest, |
| _request: Request<Action>, |
| ) -> Result<(), Status> { |
| let transaction_id = String::from_utf8(query.transaction_id.to_vec()) |
| .map_err(|_| Status::invalid_argument("Invalid transaction id"))?; |
| if self |
| .transactions |
| .lock() |
| .await |
| .remove(&transaction_id) |
| .is_none() |
| { |
| return Err(Status::invalid_argument("Transaction id not found")); |
| } |
| Ok(()) |
| } |
| |
| async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} |
| |
| async fn do_put_statement_ingest( |
| &self, |
| _ticket: CommandStatementIngest, |
| request: Request<PeekableFlightDataStream>, |
| ) -> Result<i64, Status> { |
| let batches: Vec<RecordBatch> = FlightRecordBatchStream::new_from_flight_data( |
| request.into_inner().map_err(|e| e.into()), |
| ) |
| .try_collect() |
| .await?; |
| let affected_rows = batches.iter().map(|batch| batch.num_rows() as i64).sum(); |
| *self.ingested_batches.lock().await.as_mut() = batches; |
| Ok(affected_rows) |
| } |
| } |