| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| use arrow_flight::sql::DoPutPreparedStatementResult; |
| use arrow_flight::sql::server::PeekableFlightDataStream; |
| use base64::Engine; |
| use base64::prelude::BASE64_STANDARD; |
| use core::str; |
| use futures::{Stream, TryStreamExt, stream}; |
| use once_cell::sync::Lazy; |
| use prost::Message; |
| use std::collections::HashSet; |
| use std::pin::Pin; |
| use std::str::FromStr; |
| use std::sync::Arc; |
| use tonic::metadata::MetadataValue; |
| use tonic::transport::Server; |
| use tonic::transport::{Certificate, Identity, ServerTlsConfig}; |
| use tonic::{Request, Response, Status, Streaming}; |
| |
| use arrow_array::builder::StringBuilder; |
| use arrow_array::{ArrayRef, RecordBatch}; |
| use arrow_flight::encode::FlightDataEncoderBuilder; |
| use arrow_flight::sql::metadata::{ |
| SqlInfoData, SqlInfoDataBuilder, XdbcTypeInfo, XdbcTypeInfoData, XdbcTypeInfoDataBuilder, |
| }; |
| use arrow_flight::sql::{ |
| ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest, |
| ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, |
| ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, |
| ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, |
| ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs, |
| CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, |
| CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, |
| CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, |
| CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan, |
| CommandStatementUpdate, Nullable, ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, |
| XdbcDataType, server::FlightSqlService, |
| }; |
| use arrow_flight::utils::batches_to_flight_data; |
| use arrow_flight::{ |
| Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, |
| HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket, flight_service_server::FlightService, |
| flight_service_server::FlightServiceServer, |
| }; |
| use arrow_ipc::writer::IpcWriteOptions; |
| use arrow_schema::{ArrowError, DataType, Field, Schema}; |
| |
/// Builds a `Status::internal` whose message reads `<desc>: <err> at <file>:<line>`.
///
/// The location comes from `file!()` / `line!()`, so it points at the place the
/// macro is invoked.
macro_rules! status {
    ($desc:expr, $err:expr) => {{
        let location = concat!(file!(), ":", line!());
        Status::internal(format!("{}: {} at {}", $desc, $err, location))
    }};
}
| |
/// Token handed out by `do_handshake` and expected back by `check_token`.
const FAKE_TOKEN: &str = "uuid_token";

/// Handle returned for every prepared statement this example creates.
const FAKE_HANDLE: &str = "uuid_handle";

/// Value returned by the `do_put_statement_update` / `do_put_statement_ingest` handlers.
const FAKE_UPDATE_RESULT: i64 = 1;
| |
| static INSTANCE_SQL_DATA: Lazy<SqlInfoData> = Lazy::new(|| { |
| let mut builder = SqlInfoDataBuilder::new(); |
| // Server information |
| builder.append(SqlInfo::FlightSqlServerName, "Example Flight SQL Server"); |
| builder.append(SqlInfo::FlightSqlServerVersion, "1"); |
| // 1.3 comes from https://github.com/apache/arrow/blob/f9324b79bf4fc1ec7e97b32e3cce16e75ef0f5e3/format/Schema.fbs#L24 |
| builder.append(SqlInfo::FlightSqlServerArrowVersion, "1.3"); |
| builder.build().unwrap() |
| }); |
| |
| static INSTANCE_XBDC_DATA: Lazy<XdbcTypeInfoData> = Lazy::new(|| { |
| let mut builder = XdbcTypeInfoDataBuilder::new(); |
| builder.append(XdbcTypeInfo { |
| type_name: "INTEGER".into(), |
| data_type: XdbcDataType::XdbcInteger, |
| column_size: Some(32), |
| literal_prefix: None, |
| literal_suffix: None, |
| create_params: None, |
| nullable: Nullable::NullabilityNullable, |
| case_sensitive: false, |
| searchable: Searchable::Full, |
| unsigned_attribute: Some(false), |
| fixed_prec_scale: false, |
| auto_increment: Some(false), |
| local_type_name: Some("INTEGER".into()), |
| minimum_scale: None, |
| maximum_scale: None, |
| sql_data_type: XdbcDataType::XdbcInteger, |
| datetime_subcode: None, |
| num_prec_radix: Some(2), |
| interval_precision: None, |
| }); |
| builder.build().unwrap() |
| }); |
| |
| static TABLES: Lazy<Vec<&'static str>> = Lazy::new(|| vec!["flight_sql.example.table"]); |
| |
/// Stateless example implementation of a Flight SQL service.
///
/// The struct carries no data; every handler answers with canned values.
#[derive(Clone, Debug, Default)]
pub struct FlightSqlServiceImpl {}
| |
| impl FlightSqlServiceImpl { |
| #[allow(clippy::result_large_err)] |
| fn check_token<T>(&self, req: &Request<T>) -> Result<(), Status> { |
| let metadata = req.metadata(); |
| let auth = metadata.get("authorization").ok_or_else(|| { |
| Status::internal(format!("No authorization header! metadata = {metadata:?}")) |
| })?; |
| let str = auth |
| .to_str() |
| .map_err(|e| Status::internal(format!("Error parsing header: {e}")))?; |
| let authorization = str.to_string(); |
| let bearer = "Bearer "; |
| if !authorization.starts_with(bearer) { |
| Err(Status::internal("Invalid auth header!"))?; |
| } |
| let token = authorization[bearer.len()..].to_string(); |
| if token == FAKE_TOKEN { |
| Ok(()) |
| } else { |
| Err(Status::unauthenticated("invalid token ")) |
| } |
| } |
| |
| fn fake_result() -> Result<RecordBatch, ArrowError> { |
| let schema = Schema::new(vec![Field::new("salutation", DataType::Utf8, false)]); |
| let mut builder = StringBuilder::new(); |
| builder.append_value("Hello, FlightSQL!"); |
| let cols = vec![Arc::new(builder.finish()) as ArrayRef]; |
| RecordBatch::try_new(Arc::new(schema), cols) |
| } |
| } |
| |
| #[tonic::async_trait] |
| impl FlightSqlService for FlightSqlServiceImpl { |
| type FlightService = FlightSqlServiceImpl; |
| |
| async fn do_handshake( |
| &self, |
| request: Request<Streaming<HandshakeRequest>>, |
| ) -> Result< |
| Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>, |
| Status, |
| > { |
| let basic = "Basic "; |
| let authorization = request |
| .metadata() |
| .get("authorization") |
| .ok_or_else(|| Status::invalid_argument("authorization field not present"))? |
| .to_str() |
| .map_err(|e| status!("authorization not parsable", e))?; |
| if !authorization.starts_with(basic) { |
| Err(Status::invalid_argument(format!( |
| "Auth type not implemented: {authorization}" |
| )))?; |
| } |
| let base64 = &authorization[basic.len()..]; |
| let bytes = BASE64_STANDARD |
| .decode(base64) |
| .map_err(|e| status!("authorization not decodable", e))?; |
| let str = str::from_utf8(&bytes).map_err(|e| status!("authorization not parsable", e))?; |
| let parts: Vec<_> = str.split(':').collect(); |
| let (user, pass) = match parts.as_slice() { |
| [user, pass] => (user, pass), |
| _ => Err(Status::invalid_argument( |
| "Invalid authorization header".to_string(), |
| ))?, |
| }; |
| if user != &"admin" || pass != &"password" { |
| Err(Status::unauthenticated("Invalid credentials!"))? |
| } |
| |
| let result = HandshakeResponse { |
| protocol_version: 0, |
| payload: FAKE_TOKEN.into(), |
| }; |
| let result = Ok(result); |
| let output = futures::stream::iter(vec![result]); |
| |
| let token = format!("Bearer {FAKE_TOKEN}"); |
| let mut response: Response<Pin<Box<dyn Stream<Item = _> + Send>>> = |
| Response::new(Box::pin(output)); |
| response.metadata_mut().append( |
| "authorization", |
| MetadataValue::from_str(token.as_str()).unwrap(), |
| ); |
| return Ok(response); |
| } |
| |
| async fn do_get_fallback( |
| &self, |
| request: Request<Ticket>, |
| _message: Any, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| self.check_token(&request)?; |
| let batch = Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; |
| let schema = batch.schema_ref(); |
| let batches = vec![batch.clone()]; |
| let flight_data = batches_to_flight_data(schema, batches) |
| .map_err(|e| status!("Could not convert batches", e))? |
| .into_iter() |
| .map(Ok); |
| |
| let stream: Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send>> = |
| Box::pin(stream::iter(flight_data)); |
| let resp = Response::new(stream); |
| Ok(resp) |
| } |
| |
| async fn get_flight_info_statement( |
| &self, |
| _query: CommandStatementQuery, |
| _request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| Err(Status::unimplemented( |
| "get_flight_info_statement not implemented", |
| )) |
| } |
| |
| async fn get_flight_info_substrait_plan( |
| &self, |
| _query: CommandStatementSubstraitPlan, |
| _request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| Err(Status::unimplemented( |
| "get_flight_info_substrait_plan not implemented", |
| )) |
| } |
| |
| async fn get_flight_info_prepared_statement( |
| &self, |
| cmd: CommandPreparedStatementQuery, |
| request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| self.check_token(&request)?; |
| let handle = std::str::from_utf8(&cmd.prepared_statement_handle) |
| .map_err(|e| status!("Unable to parse handle", e))?; |
| |
| let batch = Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; |
| let schema = (*batch.schema()).clone(); |
| let num_rows = batch.num_rows(); |
| let num_bytes = batch.get_array_memory_size(); |
| |
| let fetch = FetchResults { |
| handle: handle.to_string(), |
| }; |
| let buf = fetch.as_any().encode_to_vec().into(); |
| let ticket = Ticket { ticket: buf }; |
| let endpoint = FlightEndpoint { |
| ticket: Some(ticket), |
| location: vec![], |
| expiration_time: None, |
| app_metadata: vec![].into(), |
| }; |
| let info = FlightInfo::new() |
| .try_with_schema(&schema) |
| .map_err(|e| status!("Unable to serialize schema", e))? |
| .with_descriptor(FlightDescriptor::new_cmd(vec![])) |
| .with_endpoint(endpoint) |
| .with_total_records(num_rows as i64) |
| .with_total_bytes(num_bytes as i64) |
| .with_ordered(false); |
| |
| let resp = Response::new(info); |
| Ok(resp) |
| } |
| |
| async fn get_flight_info_catalogs( |
| &self, |
| query: CommandGetCatalogs, |
| request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| let flight_descriptor = request.into_inner(); |
| let ticket = Ticket { |
| ticket: query.as_any().encode_to_vec().into(), |
| }; |
| let endpoint = FlightEndpoint::new().with_ticket(ticket); |
| |
| let flight_info = FlightInfo::new() |
| .try_with_schema(&query.into_builder().schema()) |
| .map_err(|e| status!("Unable to encode schema", e))? |
| .with_endpoint(endpoint) |
| .with_descriptor(flight_descriptor); |
| |
| Ok(tonic::Response::new(flight_info)) |
| } |
| |
| async fn get_flight_info_schemas( |
| &self, |
| query: CommandGetDbSchemas, |
| request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| let flight_descriptor = request.into_inner(); |
| let ticket = Ticket { |
| ticket: query.as_any().encode_to_vec().into(), |
| }; |
| let endpoint = FlightEndpoint::new().with_ticket(ticket); |
| |
| let flight_info = FlightInfo::new() |
| .try_with_schema(&query.into_builder().schema()) |
| .map_err(|e| status!("Unable to encode schema", e))? |
| .with_endpoint(endpoint) |
| .with_descriptor(flight_descriptor); |
| |
| Ok(tonic::Response::new(flight_info)) |
| } |
| |
| async fn get_flight_info_tables( |
| &self, |
| query: CommandGetTables, |
| request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| let flight_descriptor = request.into_inner(); |
| let ticket = Ticket { |
| ticket: query.as_any().encode_to_vec().into(), |
| }; |
| let endpoint = FlightEndpoint::new().with_ticket(ticket); |
| |
| let flight_info = FlightInfo::new() |
| .try_with_schema(&query.into_builder().schema()) |
| .map_err(|e| status!("Unable to encode schema", e))? |
| .with_endpoint(endpoint) |
| .with_descriptor(flight_descriptor); |
| |
| Ok(tonic::Response::new(flight_info)) |
| } |
| |
| async fn get_flight_info_table_types( |
| &self, |
| _query: CommandGetTableTypes, |
| _request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| Err(Status::unimplemented( |
| "get_flight_info_table_types not implemented", |
| )) |
| } |
| |
| async fn get_flight_info_sql_info( |
| &self, |
| query: CommandGetSqlInfo, |
| request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| let flight_descriptor = request.into_inner(); |
| let ticket = Ticket::new(query.as_any().encode_to_vec()); |
| let endpoint = FlightEndpoint::new().with_ticket(ticket); |
| |
| let flight_info = FlightInfo::new() |
| .try_with_schema(query.into_builder(&INSTANCE_SQL_DATA).schema().as_ref()) |
| .map_err(|e| status!("Unable to encode schema", e))? |
| .with_endpoint(endpoint) |
| .with_descriptor(flight_descriptor); |
| |
| Ok(tonic::Response::new(flight_info)) |
| } |
| |
| async fn get_flight_info_primary_keys( |
| &self, |
| _query: CommandGetPrimaryKeys, |
| _request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| Err(Status::unimplemented( |
| "get_flight_info_primary_keys not implemented", |
| )) |
| } |
| |
| async fn get_flight_info_exported_keys( |
| &self, |
| _query: CommandGetExportedKeys, |
| _request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| Err(Status::unimplemented( |
| "get_flight_info_exported_keys not implemented", |
| )) |
| } |
| |
| async fn get_flight_info_imported_keys( |
| &self, |
| _query: CommandGetImportedKeys, |
| _request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| Err(Status::unimplemented( |
| "get_flight_info_imported_keys not implemented", |
| )) |
| } |
| |
| async fn get_flight_info_cross_reference( |
| &self, |
| _query: CommandGetCrossReference, |
| _request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| Err(Status::unimplemented( |
| "get_flight_info_imported_keys not implemented", |
| )) |
| } |
| |
| async fn get_flight_info_xdbc_type_info( |
| &self, |
| query: CommandGetXdbcTypeInfo, |
| request: Request<FlightDescriptor>, |
| ) -> Result<Response<FlightInfo>, Status> { |
| let flight_descriptor = request.into_inner(); |
| let ticket = Ticket::new(query.as_any().encode_to_vec()); |
| let endpoint = FlightEndpoint::new().with_ticket(ticket); |
| |
| let flight_info = FlightInfo::new() |
| .try_with_schema(query.into_builder(&INSTANCE_XBDC_DATA).schema().as_ref()) |
| .map_err(|e| status!("Unable to encode schema", e))? |
| .with_endpoint(endpoint) |
| .with_descriptor(flight_descriptor); |
| |
| Ok(tonic::Response::new(flight_info)) |
| } |
| |
| // do_get |
| async fn do_get_statement( |
| &self, |
| _ticket: TicketStatementQuery, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| Err(Status::unimplemented("do_get_statement not implemented")) |
| } |
| |
| async fn do_get_prepared_statement( |
| &self, |
| _query: CommandPreparedStatementQuery, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| Err(Status::unimplemented( |
| "do_get_prepared_statement not implemented", |
| )) |
| } |
| |
| async fn do_get_catalogs( |
| &self, |
| query: CommandGetCatalogs, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| let catalog_names = TABLES |
| .iter() |
| .map(|full_name| full_name.split('.').collect::<Vec<_>>()[0].to_string()) |
| .collect::<HashSet<_>>(); |
| let mut builder = query.into_builder(); |
| for catalog_name in catalog_names { |
| builder.append(catalog_name); |
| } |
| let schema = builder.schema(); |
| let batch = builder.build(); |
| let stream = FlightDataEncoderBuilder::new() |
| .with_schema(schema) |
| .build(futures::stream::once(async { batch })) |
| .map_err(Status::from); |
| Ok(Response::new(Box::pin(stream))) |
| } |
| |
| async fn do_get_schemas( |
| &self, |
| query: CommandGetDbSchemas, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| let schemas = TABLES |
| .iter() |
| .map(|full_name| { |
| let parts = full_name.split('.').collect::<Vec<_>>(); |
| (parts[0].to_string(), parts[1].to_string()) |
| }) |
| .collect::<HashSet<_>>(); |
| |
| let mut builder = query.into_builder(); |
| for (catalog_name, schema_name) in schemas { |
| builder.append(catalog_name, schema_name); |
| } |
| |
| let schema = builder.schema(); |
| let batch = builder.build(); |
| let stream = FlightDataEncoderBuilder::new() |
| .with_schema(schema) |
| .build(futures::stream::once(async { batch })) |
| .map_err(Status::from); |
| Ok(Response::new(Box::pin(stream))) |
| } |
| |
| async fn do_get_tables( |
| &self, |
| query: CommandGetTables, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| let tables = TABLES |
| .iter() |
| .map(|full_name| { |
| let parts = full_name.split('.').collect::<Vec<_>>(); |
| ( |
| parts[0].to_string(), |
| parts[1].to_string(), |
| parts[2].to_string(), |
| ) |
| }) |
| .collect::<HashSet<_>>(); |
| |
| let dummy_schema = Schema::empty(); |
| let mut builder = query.into_builder(); |
| for (catalog_name, schema_name, table_name) in tables { |
| builder |
| .append( |
| catalog_name, |
| schema_name, |
| table_name, |
| "TABLE", |
| &dummy_schema, |
| ) |
| .map_err(Status::from)?; |
| } |
| |
| let schema = builder.schema(); |
| let batch = builder.build(); |
| let stream = FlightDataEncoderBuilder::new() |
| .with_schema(schema) |
| .build(futures::stream::once(async { batch })) |
| .map_err(Status::from); |
| Ok(Response::new(Box::pin(stream))) |
| } |
| |
| async fn do_get_table_types( |
| &self, |
| _query: CommandGetTableTypes, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| Err(Status::unimplemented("do_get_table_types not implemented")) |
| } |
| |
| async fn do_get_sql_info( |
| &self, |
| query: CommandGetSqlInfo, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| let builder = query.into_builder(&INSTANCE_SQL_DATA); |
| let schema = builder.schema(); |
| let batch = builder.build(); |
| let stream = FlightDataEncoderBuilder::new() |
| .with_schema(schema) |
| .build(futures::stream::once(async { batch })) |
| .map_err(Status::from); |
| Ok(Response::new(Box::pin(stream))) |
| } |
| |
| async fn do_get_primary_keys( |
| &self, |
| _query: CommandGetPrimaryKeys, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| Err(Status::unimplemented("do_get_primary_keys not implemented")) |
| } |
| |
| async fn do_get_exported_keys( |
| &self, |
| _query: CommandGetExportedKeys, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| Err(Status::unimplemented( |
| "do_get_exported_keys not implemented", |
| )) |
| } |
| |
| async fn do_get_imported_keys( |
| &self, |
| _query: CommandGetImportedKeys, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| Err(Status::unimplemented( |
| "do_get_imported_keys not implemented", |
| )) |
| } |
| |
| async fn do_get_cross_reference( |
| &self, |
| _query: CommandGetCrossReference, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| Err(Status::unimplemented( |
| "do_get_cross_reference not implemented", |
| )) |
| } |
| |
| async fn do_get_xdbc_type_info( |
| &self, |
| query: CommandGetXdbcTypeInfo, |
| _request: Request<Ticket>, |
| ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { |
| // create a builder with pre-defined Xdbc data: |
| let builder = query.into_builder(&INSTANCE_XBDC_DATA); |
| let schema = builder.schema(); |
| let batch = builder.build(); |
| let stream = FlightDataEncoderBuilder::new() |
| .with_schema(schema) |
| .build(futures::stream::once(async { batch })) |
| .map_err(Status::from); |
| Ok(Response::new(Box::pin(stream))) |
| } |
| |
| // do_put |
| async fn do_put_statement_update( |
| &self, |
| _ticket: CommandStatementUpdate, |
| _request: Request<PeekableFlightDataStream>, |
| ) -> Result<i64, Status> { |
| Ok(FAKE_UPDATE_RESULT) |
| } |
| |
| async fn do_put_statement_ingest( |
| &self, |
| _ticket: CommandStatementIngest, |
| _request: Request<PeekableFlightDataStream>, |
| ) -> Result<i64, Status> { |
| Ok(FAKE_UPDATE_RESULT) |
| } |
| |
| async fn do_put_substrait_plan( |
| &self, |
| _ticket: CommandStatementSubstraitPlan, |
| _request: Request<PeekableFlightDataStream>, |
| ) -> Result<i64, Status> { |
| Err(Status::unimplemented( |
| "do_put_substrait_plan not implemented", |
| )) |
| } |
| |
| async fn do_put_prepared_statement_query( |
| &self, |
| _query: CommandPreparedStatementQuery, |
| _request: Request<PeekableFlightDataStream>, |
| ) -> Result<DoPutPreparedStatementResult, Status> { |
| Err(Status::unimplemented( |
| "do_put_prepared_statement_query not implemented", |
| )) |
| } |
| |
| async fn do_put_prepared_statement_update( |
| &self, |
| _query: CommandPreparedStatementUpdate, |
| _request: Request<PeekableFlightDataStream>, |
| ) -> Result<i64, Status> { |
| Err(Status::unimplemented( |
| "do_put_prepared_statement_update not implemented", |
| )) |
| } |
| |
| async fn do_action_create_prepared_statement( |
| &self, |
| _query: ActionCreatePreparedStatementRequest, |
| request: Request<Action>, |
| ) -> Result<ActionCreatePreparedStatementResult, Status> { |
| self.check_token(&request)?; |
| let record_batch = |
| Self::fake_result().map_err(|e| status!("Error getting result schema", e))?; |
| let schema = record_batch.schema_ref(); |
| let message = SchemaAsIpc::new(schema, &IpcWriteOptions::default()) |
| .try_into() |
| .map_err(|e| status!("Unable to serialize schema", e))?; |
| let IpcMessage(schema_bytes) = message; |
| let res = ActionCreatePreparedStatementResult { |
| prepared_statement_handle: FAKE_HANDLE.into(), |
| dataset_schema: schema_bytes, |
| parameter_schema: Default::default(), // TODO: parameters |
| }; |
| Ok(res) |
| } |
| |
| async fn do_action_close_prepared_statement( |
| &self, |
| _query: ActionClosePreparedStatementRequest, |
| _request: Request<Action>, |
| ) -> Result<(), Status> { |
| Ok(()) |
| } |
| |
| async fn do_action_create_prepared_substrait_plan( |
| &self, |
| _query: ActionCreatePreparedSubstraitPlanRequest, |
| _request: Request<Action>, |
| ) -> Result<ActionCreatePreparedStatementResult, Status> { |
| Err(Status::unimplemented( |
| "Implement do_action_create_prepared_substrait_plan", |
| )) |
| } |
| |
| async fn do_action_begin_transaction( |
| &self, |
| _query: ActionBeginTransactionRequest, |
| _request: Request<Action>, |
| ) -> Result<ActionBeginTransactionResult, Status> { |
| Err(Status::unimplemented( |
| "Implement do_action_begin_transaction", |
| )) |
| } |
| |
| async fn do_action_end_transaction( |
| &self, |
| _query: ActionEndTransactionRequest, |
| _request: Request<Action>, |
| ) -> Result<(), Status> { |
| Err(Status::unimplemented("Implement do_action_end_transaction")) |
| } |
| |
| async fn do_action_begin_savepoint( |
| &self, |
| _query: ActionBeginSavepointRequest, |
| _request: Request<Action>, |
| ) -> Result<ActionBeginSavepointResult, Status> { |
| Err(Status::unimplemented("Implement do_action_begin_savepoint")) |
| } |
| |
| async fn do_action_end_savepoint( |
| &self, |
| _query: ActionEndSavepointRequest, |
| _request: Request<Action>, |
| ) -> Result<(), Status> { |
| Err(Status::unimplemented("Implement do_action_end_savepoint")) |
| } |
| |
| async fn do_action_cancel_query( |
| &self, |
| _query: ActionCancelQueryRequest, |
| _request: Request<Action>, |
| ) -> Result<ActionCancelQueryResult, Status> { |
| Err(Status::unimplemented("Implement do_action_cancel_query")) |
| } |
| |
| async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} |
| } |
| |
| /// This example shows how to run a FlightSql server |
| #[tokio::main] |
| async fn main() -> Result<(), Box<dyn std::error::Error>> { |
| let addr_str = "0.0.0.0:50051"; |
| let addr = addr_str.parse()?; |
| |
| println!("Listening on {addr:?}"); |
| |
| if std::env::var("USE_TLS").ok().is_some() { |
| let cert = std::fs::read_to_string("arrow-flight/examples/data/server.pem")?; |
| let key = std::fs::read_to_string("arrow-flight/examples/data/server.key")?; |
| let client_ca = std::fs::read_to_string("arrow-flight/examples/data/client_ca.pem")?; |
| |
| let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); |
| let tls_config = ServerTlsConfig::new() |
| .identity(Identity::from_pem(&cert, &key)) |
| .client_ca_root(Certificate::from_pem(&client_ca)); |
| |
| Server::builder() |
| .tls_config(tls_config)? |
| .add_service(svc) |
| .serve(addr) |
| .await?; |
| } else { |
| let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); |
| |
| Server::builder().add_service(svc).serve(addr).await?; |
| } |
| |
| Ok(()) |
| } |
| |
| #[derive(Clone, PartialEq, ::prost::Message)] |
| pub struct FetchResults { |
| #[prost(string, tag = "1")] |
| pub handle: ::prost::alloc::string::String, |
| } |
| |
| impl ProstMessageExt for FetchResults { |
| fn type_url() -> &'static str { |
| "type.googleapis.com/arrow.flight.protocol.sql.FetchResults" |
| } |
| |
| fn as_any(&self) -> Any { |
| Any { |
| type_url: FetchResults::type_url().to_string(), |
| value: ::prost::Message::encode_to_vec(self).into(), |
| } |
| } |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| use super::*; |
| use futures::{TryFutureExt, TryStreamExt}; |
| use hyper_util::rt::TokioIo; |
| use std::fs; |
| use std::future::Future; |
| use std::net::SocketAddr; |
| use std::path::PathBuf; |
| use std::time::Duration; |
| use tempfile::NamedTempFile; |
| use tokio::net::{TcpListener, UnixListener, UnixStream}; |
| use tokio_stream::wrappers::UnixListenerStream; |
| use tonic::transport::{Channel, ClientTlsConfig}; |
| |
| use arrow_cast::pretty::pretty_format_batches; |
| use arrow_flight::sql::client::FlightSqlServiceClient; |
| use tonic::transport::server::TcpIncoming; |
| use tonic::transport::{Certificate, Endpoint}; |
| use tower::service_fn; |
| |
| async fn bind_tcp() -> (TcpIncoming, SocketAddr) { |
| let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); |
| let addr = listener.local_addr().unwrap(); |
| let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); |
| (incoming, addr) |
| } |
| |
| fn endpoint(uri: String) -> Result<Endpoint, ArrowError> { |
| let endpoint = Endpoint::new(uri) |
| .map_err(|_| ArrowError::IpcError("Cannot create endpoint".to_string()))? |
| .connect_timeout(Duration::from_secs(20)) |
| .timeout(Duration::from_secs(20)) |
| .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait |
| .tcp_keepalive(Option::Some(Duration::from_secs(3600))) |
| .http2_keep_alive_interval(Duration::from_secs(300)) |
| .keep_alive_timeout(Duration::from_secs(20)) |
| .keep_alive_while_idle(true); |
| |
| Ok(endpoint) |
| } |
| |
| async fn auth_client(client: &mut FlightSqlServiceClient<Channel>) { |
| let token = client.handshake("admin", "password").await.unwrap(); |
| client.set_token(String::from_utf8(token.to_vec()).unwrap()); |
| } |
| |
| async fn test_uds_client<F, C>(f: F) |
| where |
| F: FnOnce(FlightSqlServiceClient<Channel>) -> C, |
| C: Future<Output = ()>, |
| { |
| let file = NamedTempFile::new().unwrap(); |
| let path = file.into_temp_path().to_str().unwrap().to_string(); |
| let _ = fs::remove_file(path.clone()); |
| |
| let uds = UnixListener::bind(path.clone()).unwrap(); |
| let stream = UnixListenerStream::new(uds); |
| |
| let service = FlightSqlServiceImpl {}; |
| let serve_future = Server::builder() |
| .add_service(FlightServiceServer::new(service)) |
| .serve_with_incoming(stream); |
| |
| let request_future = async { |
| let connector = |
| service_fn(move |_| UnixStream::connect(path.clone()).map_ok(TokioIo::new)); |
| let channel = Endpoint::try_from("http://example.com") |
| .unwrap() |
| .connect_with_connector(connector) |
| .await |
| .unwrap(); |
| let client = FlightSqlServiceClient::new(channel); |
| f(client).await |
| }; |
| |
| tokio::select! { |
| _ = serve_future => panic!("server returned first"), |
| _ = request_future => println!("Client finished!"), |
| } |
| } |
| |
| async fn test_http_client<F, C>(f: F) |
| where |
| F: FnOnce(FlightSqlServiceClient<Channel>) -> C, |
| C: Future<Output = ()>, |
| { |
| let (incoming, addr) = bind_tcp().await; |
| let uri = format!("http://{}:{}", addr.ip(), addr.port()); |
| |
| let service = FlightSqlServiceImpl {}; |
| let serve_future = Server::builder() |
| .add_service(FlightServiceServer::new(service)) |
| .serve_with_incoming(incoming); |
| |
| let request_future = async { |
| let endpoint = endpoint(uri).unwrap(); |
| let channel = endpoint.connect().await.unwrap(); |
| let client = FlightSqlServiceClient::new(channel); |
| f(client).await |
| }; |
| |
| tokio::select! { |
| _ = serve_future => panic!("server returned first"), |
| _ = request_future => println!("Client finished!"), |
| } |
| } |
| |
| async fn test_https_client<F, C>(f: F) |
| where |
| F: FnOnce(FlightSqlServiceClient<Channel>) -> C, |
| C: Future<Output = ()>, |
| { |
| let cert_dir = PathBuf::from("examples/data"); |
| |
| let cert = std::fs::read_to_string(cert_dir.join("server.pem")).unwrap(); |
| let key = std::fs::read_to_string(cert_dir.join("server.key")).unwrap(); |
| let ca_root = std::fs::read_to_string(cert_dir.join("ca_root.pem")).unwrap(); |
| |
| let tls_config = ServerTlsConfig::new() |
| .identity(Identity::from_pem(&cert, &key)) |
| .client_ca_root(Certificate::from_pem(&ca_root)); |
| |
| let (incoming, addr) = bind_tcp().await; |
| let uri = format!("https://{}:{}", addr.ip(), addr.port()); |
| |
| let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); |
| |
| let serve_future = Server::builder() |
| .tls_config(tls_config) |
| .unwrap() |
| .add_service(svc) |
| .serve_with_incoming(incoming); |
| |
| let request_future = async move { |
| let cert = std::fs::read_to_string(cert_dir.join("client.pem")).unwrap(); |
| let key = std::fs::read_to_string(cert_dir.join("client.key")).unwrap(); |
| |
| let tls_config = ClientTlsConfig::new() |
| .domain_name("localhost") |
| .ca_certificate(Certificate::from_pem(&ca_root)) |
| .identity(Identity::from_pem(cert, key)); |
| |
| let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap(); |
| let channel = endpoint.connect().await.unwrap(); |
| let client = FlightSqlServiceClient::new(channel); |
| f(client).await |
| }; |
| |
| tokio::select! { |
| _ = serve_future => panic!("server returned first"), |
| _ = request_future => println!("Client finished!"), |
| } |
| } |
| |
| async fn test_all_clients<F, C>(task: F) |
| where |
| F: FnOnce(FlightSqlServiceClient<Channel>) -> C + Copy, |
| C: Future<Output = ()>, |
| { |
| println!("testing uds client"); |
| test_uds_client(task).await; |
| println!("======="); |
| |
| println!("testing http client"); |
| test_http_client(task).await; |
| println!("======="); |
| |
| println!("testing https client"); |
| test_https_client(task).await; |
| println!("======="); |
| } |
| |
| #[tokio::test] |
| async fn test_select() { |
| test_all_clients(|mut client| async move { |
| auth_client(&mut client).await; |
| |
| let mut stmt = client.prepare("select 1;".to_string(), None).await.unwrap(); |
| |
| let flight_info = stmt.execute().await.unwrap(); |
| |
| let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone(); |
| let flight_data = client.do_get(ticket).await.unwrap(); |
| let batches: Vec<_> = flight_data.try_collect().await.unwrap(); |
| |
| let res = pretty_format_batches(batches.as_slice()).unwrap(); |
| let expected = r#" |
| +-------------------+ |
| | salutation | |
| +-------------------+ |
| | Hello, FlightSQL! | |
| +-------------------+"# |
| .trim() |
| .to_string(); |
| assert_eq!(res.to_string(), expected); |
| }) |
| .await |
| } |
| |
| #[tokio::test] |
| async fn test_execute_update() { |
| test_all_clients(|mut client| async move { |
| auth_client(&mut client).await; |
| let res = client |
| .execute_update("creat table test(a int);".to_string(), None) |
| .await |
| .unwrap(); |
| assert_eq!(res, FAKE_UPDATE_RESULT); |
| }) |
| .await |
| } |
| |
| #[tokio::test] |
| async fn test_auth() { |
| test_all_clients(|mut client| async move { |
| // no handshake |
| assert_contains( |
| client |
| .prepare("select 1;".to_string(), None) |
| .await |
| .unwrap_err() |
| .to_string(), |
| "No authorization header", |
| ); |
| |
| // Invalid credentials |
| assert_contains( |
| client |
| .handshake("admin", "password2") |
| .await |
| .unwrap_err() |
| .to_string(), |
| "Invalid credentials", |
| ); |
| |
| // Invalid Tokens |
| client.handshake("admin", "password").await.unwrap(); |
| client.set_token("wrong token".to_string()); |
| assert_contains( |
| client |
| .prepare("select 1;".to_string(), None) |
| .await |
| .unwrap_err() |
| .to_string(), |
| "invalid token", |
| ); |
| |
| client.clear_token(); |
| |
| // Successful call (token is automatically set by handshake) |
| client.handshake("admin", "password").await.unwrap(); |
| client.prepare("select 1;".to_string(), None).await.unwrap(); |
| }) |
| .await |
| } |
| |
/// Panics with `Expected '<actual>' to contain '<searched_for>'` unless
/// `searched_for` is a substring of `actual`.
fn assert_contains(actual: impl AsRef<str>, searched_for: impl AsRef<str>) {
    let (haystack, needle) = (actual.as_ref(), searched_for.as_ref());
    if !haystack.contains(needle) {
        panic!("Expected '{}' to contain '{}'", haystack, needle);
    }
}
| } |