blob: c161caae8ca45df4793e8d45376b22e229e38a88 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
mod common;
use std::{pin::Pin, sync::Arc};
use crate::common::fixture::TestFixture;
use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray};
use arrow_flight::{
Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest,
HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket,
decode::FlightRecordBatchStream,
encode::FlightDataEncoderBuilder,
flight_service_server::{FlightService, FlightServiceServer},
sql::{
ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, Any,
CommandGetCatalogs, CommandGetDbSchemas, CommandGetTableTypes, CommandGetTables,
CommandPreparedStatementQuery, CommandStatementQuery, DoPutPreparedStatementResult,
ProstMessageExt, SqlInfo,
server::{FlightSqlService, PeekableFlightDataStream},
},
utils::batches_to_flight_data,
};
use arrow_ipc::writer::IpcWriteOptions;
use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit};
use assert_cmd::Command;
use bytes::Bytes;
use futures::{Stream, TryStreamExt};
use prost::Message;
use tonic::{Request, Response, Status, Streaming};
// SQL text sent by the `statement-query` tests; the server asserts it
// receives exactly this string.
const QUERY: &str = "SELECT * FROM table;";

/// Build a [`Command`] that invokes the `flight_sql_client` CLI binary
/// produced by this crate's build.
fn flight_sql_client_cmd() -> Command {
    let binary = assert_cmd::cargo::cargo_bin!("flight_sql_client");
    Command::new(binary)
}
#[tokio::test]
async fn test_simple() {
let test_server = FlightSqlServiceImpl::default();
let fixture = TestFixture::new(test_server.service()).await;
let addr = fixture.addr;
let stdout = tokio::task::spawn_blocking(move || {
flight_sql_client_cmd()
.env_clear()
.env("RUST_BACKTRACE", "1")
.env("RUST_LOG", "warn")
.arg("--host")
.arg(addr.ip().to_string())
.arg("--port")
.arg(addr.port().to_string())
.arg("statement-query")
.arg(QUERY)
.assert()
.success()
.get_output()
.stdout
.clone()
})
.await
.unwrap();
fixture.shutdown_and_wait().await;
assert_eq!(
std::str::from_utf8(&stdout).unwrap().trim(),
"+--------------+-----------+---------------------------+-----------------------------+\
\n| field_string | field_int | field_timestamp_nano_notz | field_timestamp_nano_berlin |\
\n+--------------+-----------+---------------------------+-----------------------------+\
\n| Hello | 42 | | |\
\n| lovely | | 1970-01-01T00:00:00 | 1970-01-01T01:00:00+01:00 |\
\n| FlightSQL! | 1337 | 2024-10-30T11:36:57 | 2024-10-30T12:36:57+01:00 |\
\n+--------------+-----------+---------------------------+-----------------------------+",
);
}
#[tokio::test]
async fn test_get_catalogs() {
let test_server = FlightSqlServiceImpl::default();
let fixture = TestFixture::new(test_server.service()).await;
let addr = fixture.addr;
let stdout = tokio::task::spawn_blocking(move || {
flight_sql_client_cmd()
.env_clear()
.env("RUST_BACKTRACE", "1")
.env("RUST_LOG", "warn")
.arg("--host")
.arg(addr.ip().to_string())
.arg("--port")
.arg(addr.port().to_string())
.arg("catalogs")
.assert()
.success()
.get_output()
.stdout
.clone()
})
.await
.unwrap();
fixture.shutdown_and_wait().await;
assert_eq!(
std::str::from_utf8(&stdout).unwrap().trim(),
"+--------------+\
\n| catalog_name |\
\n+--------------+\
\n| catalog_a |\
\n| catalog_b |\
\n+--------------+",
);
}
#[tokio::test]
async fn test_get_db_schemas() {
let test_server = FlightSqlServiceImpl::default();
let fixture = TestFixture::new(test_server.service()).await;
let addr = fixture.addr;
let stdout = tokio::task::spawn_blocking(move || {
flight_sql_client_cmd()
.env_clear()
.env("RUST_BACKTRACE", "1")
.env("RUST_LOG", "warn")
.arg("--host")
.arg(addr.ip().to_string())
.arg("--port")
.arg(addr.port().to_string())
.arg("db-schemas")
.arg("catalog_a")
.assert()
.success()
.get_output()
.stdout
.clone()
})
.await
.unwrap();
fixture.shutdown_and_wait().await;
assert_eq!(
std::str::from_utf8(&stdout).unwrap().trim(),
"+--------------+----------------+\
\n| catalog_name | db_schema_name |\
\n+--------------+----------------+\
\n| catalog_a | schema_1 |\
\n| catalog_a | schema_2 |\
\n+--------------+----------------+",
);
}
#[tokio::test]
async fn test_get_tables() {
let test_server = FlightSqlServiceImpl::default();
let fixture = TestFixture::new(test_server.service()).await;
let addr = fixture.addr;
let stdout = tokio::task::spawn_blocking(move || {
flight_sql_client_cmd()
.env_clear()
.env("RUST_BACKTRACE", "1")
.env("RUST_LOG", "warn")
.arg("--host")
.arg(addr.ip().to_string())
.arg("--port")
.arg(addr.port().to_string())
.arg("tables")
.arg("catalog_a")
.assert()
.success()
.get_output()
.stdout
.clone()
})
.await
.unwrap();
fixture.shutdown_and_wait().await;
assert_eq!(
std::str::from_utf8(&stdout).unwrap().trim(),
"+--------------+----------------+------------+------------+\
\n| catalog_name | db_schema_name | table_name | table_type |\
\n+--------------+----------------+------------+------------+\
\n| catalog_a | schema_1 | table_1 | TABLE |\
\n| catalog_a | schema_2 | table_2 | VIEW |\
\n+--------------+----------------+------------+------------+",
);
}
#[tokio::test]
async fn test_get_tables_db_filter() {
let test_server = FlightSqlServiceImpl::default();
let fixture = TestFixture::new(test_server.service()).await;
let addr = fixture.addr;
let stdout = tokio::task::spawn_blocking(move || {
flight_sql_client_cmd()
.env_clear()
.env("RUST_BACKTRACE", "1")
.env("RUST_LOG", "warn")
.arg("--host")
.arg(addr.ip().to_string())
.arg("--port")
.arg(addr.port().to_string())
.arg("tables")
.arg("catalog_a")
.arg("--db-schema-filter")
.arg("schema_2")
.assert()
.success()
.get_output()
.stdout
.clone()
})
.await
.unwrap();
fixture.shutdown_and_wait().await;
assert_eq!(
std::str::from_utf8(&stdout).unwrap().trim(),
"+--------------+----------------+------------+------------+\
\n| catalog_name | db_schema_name | table_name | table_type |\
\n+--------------+----------------+------------+------------+\
\n| catalog_a | schema_2 | table_2 | VIEW |\
\n+--------------+----------------+------------+------------+",
);
}
#[tokio::test]
async fn test_get_tables_types() {
let test_server = FlightSqlServiceImpl::default();
let fixture = TestFixture::new(test_server.service()).await;
let addr = fixture.addr;
let stdout = tokio::task::spawn_blocking(move || {
flight_sql_client_cmd()
.env_clear()
.env("RUST_BACKTRACE", "1")
.env("RUST_LOG", "warn")
.arg("--host")
.arg(addr.ip().to_string())
.arg("--port")
.arg(addr.port().to_string())
.arg("table-types")
.assert()
.success()
.get_output()
.stdout
.clone()
})
.await
.unwrap();
fixture.shutdown_and_wait().await;
assert_eq!(
std::str::from_utf8(&stdout).unwrap().trim(),
"+--------------+\
\n| table_type |\
\n+--------------+\
\n| SYSTEM_TABLE |\
\n| TABLE |\
\n| VIEW |\
\n+--------------+",
);
}
const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1";
const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle";
const UPDATED_PREPARED_STATEMENT_HANDLE: &str = "updated_prepared_statement_handle";
async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) {
let fixture = TestFixture::new(test_server.service()).await;
let addr = fixture.addr;
let stdout = tokio::task::spawn_blocking(move || {
flight_sql_client_cmd()
.env_clear()
.env("RUST_BACKTRACE", "1")
.env("RUST_LOG", "warn")
.arg("--host")
.arg(addr.ip().to_string())
.arg("--port")
.arg(addr.port().to_string())
.arg("prepared-statement-query")
.arg(PREPARED_QUERY)
.args(["-p", "$1=string"])
.args(["-p", "$2=64"])
.assert()
.success()
.get_output()
.stdout
.clone()
})
.await
.unwrap();
fixture.shutdown_and_wait().await;
assert_eq!(
std::str::from_utf8(&stdout).unwrap().trim(),
"+--------------+-----------+---------------------------+-----------------------------+\
\n| field_string | field_int | field_timestamp_nano_notz | field_timestamp_nano_berlin |\
\n+--------------+-----------+---------------------------+-----------------------------+\
\n| Hello | 42 | | |\
\n| lovely | | 1970-01-01T00:00:00 | 1970-01-01T01:00:00+01:00 |\
\n| FlightSQL! | 1337 | 2024-10-30T11:36:57 | 2024-10-30T12:36:57+01:00 |\
\n+--------------+-----------+---------------------------+-----------------------------+",
);
}
#[tokio::test]
pub async fn test_do_put_prepared_statement_stateless() {
test_do_put_prepared_statement(FlightSqlServiceImpl {
stateless_prepared_statements: true,
})
.await
}
#[tokio::test]
pub async fn test_do_put_prepared_statement_stateful() {
test_do_put_prepared_statement(FlightSqlServiceImpl {
stateless_prepared_statements: false,
})
.await
}
/// Minimal in-process Flight SQL server used as the target for the CLI
/// integration tests in this file. It serves a fixed fake result set and
/// fixed catalog/schema/table metadata.
#[derive(Clone)]
pub struct FlightSqlServiceImpl {
    /// Whether to emulate stateless (true) or stateful (false) behavior for
    /// prepared statements. stateful servers will not return an updated
    /// handle after executing `DoPut(CommandPreparedStatementQuery)`
    stateless_prepared_statements: bool,
}

impl Default for FlightSqlServiceImpl {
    /// Defaults to the stateless prepared-statement behavior.
    fn default() -> Self {
        Self {
            stateless_prepared_statements: true,
        }
    }
}
impl FlightSqlServiceImpl {
    /// Return an [`FlightServiceServer`] that can be used with a
    /// [`Server`](tonic::transport::Server)
    pub fn service(&self) -> FlightServiceServer<Self> {
        // wrap up tonic goop
        FlightServiceServer::new(self.clone())
    }

    /// Schema shared by the fake result batch and the prepared statement's
    /// advertised dataset schema.
    fn schema() -> Schema {
        let fields = vec![
            Field::new("field_string", DataType::Utf8, false),
            Field::new("field_int", DataType::Int64, true),
            Field::new(
                "field_timestamp_nano_notz",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                true,
            ),
            Field::new(
                "field_timestamp_nano_berlin",
                DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("Europe/Berlin"))),
                true,
            ),
        ];
        Schema::new(fields)
    }

    /// Build the canonical three-row batch that every query returns.
    fn fake_result() -> Result<RecordBatch, ArrowError> {
        let strings = StringArray::from(vec!["Hello", "lovely", "FlightSQL!"]);
        let ints = Int64Array::from(vec![Some(42), None, Some(1337)]);
        let ts_notz =
            TimestampNanosecondArray::from(vec![None, Some(0), Some(1730288217000000000)]);
        // Same instants as above, tagged with the Europe/Berlin timezone.
        let ts_berlin = ts_notz.clone().with_timezone(Arc::from("Europe/Berlin"));
        let columns = vec![
            Arc::new(strings) as ArrayRef,
            Arc::new(ints) as ArrayRef,
            Arc::new(ts_notz) as ArrayRef,
            Arc::new(ts_berlin) as ArrayRef,
        ];
        RecordBatch::try_new(Arc::new(Self::schema()), columns)
    }

    /// Fabricate a `CreatePreparedStatement` response advertising the fixed
    /// handle plus the dataset and parameter schemas.
    fn create_fake_prepared_stmt() -> Result<ActionCreatePreparedStatementResult, ArrowError> {
        let parameter_schema = Schema::new(vec![
            Field::new("$1", DataType::Utf8, false),
            Field::new("$2", DataType::Int64, true),
        ]);
        Ok(ActionCreatePreparedStatementResult {
            prepared_statement_handle: PREPARED_STATEMENT_HANDLE.to_string().into(),
            dataset_schema: serialize_schema(&Self::schema())?,
            parameter_schema: serialize_schema(&parameter_schema)?,
        })
    }

    /// Describe the fake result as a two-endpoint [`FlightInfo`]; the tickets
    /// "part_1" and "part_2" are decoded again by `do_get_fallback`.
    fn fake_flight_info(&self) -> Result<FlightInfo, ArrowError> {
        let batch = Self::fake_result()?;
        let mut info = FlightInfo::new()
            .try_with_schema(batch.schema_ref())
            .expect("encoding schema");
        // Endpoint order matters: part_1 holds rows 0..2, part_2 row 2.
        for part in ["part_1", "part_2"] {
            let ticket = Ticket::new(
                FetchResults {
                    handle: String::from(part),
                }
                .as_any()
                .encode_to_vec(),
            );
            info = info.with_endpoint(FlightEndpoint::new().with_ticket(ticket));
        }
        Ok(info
            .with_total_records(batch.num_rows() as i64)
            .with_total_bytes(batch.get_array_memory_size() as i64)
            .with_ordered(false))
    }
}
/// Serialize `schema` into the Arrow IPC wire bytes used by Flight SQL
/// schema fields.
fn serialize_schema(schema: &Schema) -> Result<Bytes, ArrowError> {
    let message = IpcMessage::try_from(SchemaAsIpc::new(schema, &IpcWriteOptions::default()))?;
    Ok(message.0)
}
#[tonic::async_trait]
impl FlightSqlService for FlightSqlServiceImpl {
type FlightService = FlightSqlServiceImpl;
async fn do_handshake(
&self,
_request: Request<Streaming<HandshakeRequest>>,
) -> Result<
Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
Status,
> {
Err(Status::unimplemented("do_handshake not implemented"))
}
async fn do_get_fallback(
&self,
_request: Request<Ticket>,
message: Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let part = message.unpack::<FetchResults>().unwrap().unwrap().handle;
let batch = Self::fake_result().unwrap();
let batch = match part.as_str() {
"part_1" => batch.slice(0, 2),
"part_2" => batch.slice(2, 1),
ticket => panic!("Invalid ticket: {ticket:?}"),
};
let schema = batch.schema_ref();
let batches = vec![batch.clone()];
let flight_data = batches_to_flight_data(schema, batches)
.unwrap()
.into_iter()
.map(Ok);
let stream: Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send>> =
Box::pin(futures::stream::iter(flight_data));
let resp = Response::new(stream);
Ok(resp)
}
async fn get_flight_info_catalogs(
&self,
query: CommandGetCatalogs,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let flight_descriptor = request.into_inner();
let ticket = Ticket {
ticket: query.as_any().encode_to_vec().into(),
};
let endpoint = FlightEndpoint::new().with_ticket(ticket);
let flight_info = FlightInfo::new()
.try_with_schema(&query.into_builder().schema())
.unwrap()
.with_endpoint(endpoint)
.with_descriptor(flight_descriptor);
Ok(Response::new(flight_info))
}
async fn get_flight_info_schemas(
&self,
query: CommandGetDbSchemas,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let flight_descriptor = request.into_inner();
let ticket = Ticket {
ticket: query.as_any().encode_to_vec().into(),
};
let endpoint = FlightEndpoint::new().with_ticket(ticket);
let flight_info = FlightInfo::new()
.try_with_schema(&query.into_builder().schema())
.unwrap()
.with_endpoint(endpoint)
.with_descriptor(flight_descriptor);
Ok(Response::new(flight_info))
}
async fn get_flight_info_tables(
&self,
query: CommandGetTables,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let flight_descriptor = request.into_inner();
let ticket = Ticket {
ticket: query.as_any().encode_to_vec().into(),
};
let endpoint = FlightEndpoint::new().with_ticket(ticket);
let flight_info = FlightInfo::new()
.try_with_schema(&query.into_builder().schema())
.unwrap()
.with_endpoint(endpoint)
.with_descriptor(flight_descriptor);
Ok(Response::new(flight_info))
}
async fn get_flight_info_table_types(
&self,
query: CommandGetTableTypes,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let flight_descriptor = request.into_inner();
let ticket = Ticket {
ticket: query.as_any().encode_to_vec().into(),
};
let endpoint = FlightEndpoint::new().with_ticket(ticket);
let flight_info = FlightInfo::new()
.try_with_schema(&query.into_builder().schema())
.unwrap()
.with_endpoint(endpoint)
.with_descriptor(flight_descriptor);
Ok(Response::new(flight_info))
}
async fn get_flight_info_statement(
&self,
query: CommandStatementQuery,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
assert_eq!(query.query, QUERY);
let resp = Response::new(self.fake_flight_info().unwrap());
Ok(resp)
}
async fn get_flight_info_prepared_statement(
&self,
cmd: CommandPreparedStatementQuery,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
if self.stateless_prepared_statements {
assert_eq!(
cmd.prepared_statement_handle,
UPDATED_PREPARED_STATEMENT_HANDLE.as_bytes()
);
} else {
assert_eq!(
cmd.prepared_statement_handle,
PREPARED_STATEMENT_HANDLE.as_bytes()
);
}
let resp = Response::new(self.fake_flight_info().unwrap());
Ok(resp)
}
async fn do_get_catalogs(
&self,
query: CommandGetCatalogs,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let mut builder = query.into_builder();
for catalog_name in ["catalog_a", "catalog_b"] {
builder.append(catalog_name);
}
let schema = builder.schema();
let batch = builder.build();
let stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.build(futures::stream::once(async { batch }))
.map_err(Status::from);
Ok(Response::new(Box::pin(stream)))
}
async fn do_get_schemas(
&self,
query: CommandGetDbSchemas,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let mut builder = query.into_builder();
for (catalog_name, schema_name) in [
("catalog_a", "schema_1"),
("catalog_a", "schema_2"),
("catalog_b", "schema_3"),
] {
builder.append(catalog_name, schema_name);
}
let schema = builder.schema();
let batch = builder.build();
let stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.build(futures::stream::once(async { batch }))
.map_err(Status::from);
Ok(Response::new(Box::pin(stream)))
}
async fn do_get_tables(
&self,
query: CommandGetTables,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let mut builder = query.into_builder();
for (catalog_name, schema_name, table_name, table_type, schema) in [
(
"catalog_a",
"schema_1",
"table_1",
"TABLE",
Arc::new(Schema::empty()),
),
(
"catalog_a",
"schema_2",
"table_2",
"VIEW",
Arc::new(Schema::empty()),
),
(
"catalog_b",
"schema_3",
"table_3",
"TABLE",
Arc::new(Schema::empty()),
),
] {
builder
.append(catalog_name, schema_name, table_name, table_type, &schema)
.unwrap();
}
let schema = builder.schema();
let batch = builder.build();
let stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.build(futures::stream::once(async { batch }))
.map_err(Status::from);
Ok(Response::new(Box::pin(stream)))
}
async fn do_get_table_types(
&self,
query: CommandGetTableTypes,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let mut builder = query.into_builder();
for table_type in ["TABLE", "VIEW", "SYSTEM_TABLE"] {
builder.append(table_type);
}
let schema = builder.schema();
let batch = builder.build();
let stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.build(futures::stream::once(async { batch }))
.map_err(Status::from);
Ok(Response::new(Box::pin(stream)))
}
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
request: Request<PeekableFlightDataStream>,
) -> Result<DoPutPreparedStatementResult, Status> {
// just make sure decoding the parameters works
let parameters = FlightRecordBatchStream::new_from_flight_data(
request.into_inner().map_err(|e| e.into()),
)
.try_collect::<Vec<_>>()
.await?;
for (left, right) in parameters[0].schema().flattened_fields().iter().zip(vec![
Field::new("$1", DataType::Utf8, false),
Field::new("$2", DataType::Int64, true),
]) {
if left.name() != right.name() || left.data_type() != right.data_type() {
return Err(Status::invalid_argument(format!(
"Parameters did not match parameter schema\ngot {}",
parameters[0].schema(),
)));
}
}
let handle = if self.stateless_prepared_statements {
UPDATED_PREPARED_STATEMENT_HANDLE.to_string().into()
} else {
PREPARED_STATEMENT_HANDLE.to_string().into()
};
let result = DoPutPreparedStatementResult {
prepared_statement_handle: Some(handle),
};
Ok(result)
}
async fn do_action_create_prepared_statement(
&self,
_query: ActionCreatePreparedStatementRequest,
_request: Request<Action>,
) -> Result<ActionCreatePreparedStatementResult, Status> {
Self::create_fake_prepared_stmt()
.map_err(|e| Status::internal(format!("Unable to serialize schema: {e}")))
}
async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
}
/// Custom protobuf message used as the opaque ticket payload linking
/// `fake_flight_info` and `do_get_fallback`: it names which slice of the
/// fake result ("part_1" or "part_2") a DoGet call should return.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FetchResults {
    /// Partition handle, e.g. "part_1".
    #[prost(string, tag = "1")]
    pub handle: ::prost::alloc::string::String,
}
impl ProstMessageExt for FetchResults {
    /// Type URL under which `FetchResults` is packed into a protobuf `Any`.
    fn type_url() -> &'static str {
        "type.googleapis.com/arrow.flight.protocol.sql.FetchResults"
    }

    /// Pack this message into a protobuf `Any` envelope.
    fn as_any(&self) -> Any {
        let value = self.encode_to_vec();
        Any {
            type_url: Self::type_url().to_string(),
            value: value.into(),
        }
    }
}