// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Flight server implementation used by the Flight integration tests.
use core::str;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use arrow::{
array::ArrayRef,
buffer::Buffer,
datatypes::Schema,
datatypes::SchemaRef,
ipc::{self, reader, writer},
record_batch::RecordBatch,
};
use arrow_flight::{
flight_descriptor::DescriptorType, flight_service_server::FlightService,
flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData,
FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage,
PollInfo, PutResult, SchemaAsIpc, SchemaResult, Ticket,
};
use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt};
use tokio::sync::Mutex;
use tonic::{transport::Server, Request, Response, Status, Streaming};
type TonicStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
type Result<T = (), E = Error> = std::result::Result<T, E>;
/// Run the integration test scenario: start a Flight server on `port` and
/// serve until it shuts down.
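///
/// A minimal sketch of driving this from a test harness; the `use` path is
/// an assumption about where this module sits in the crate:
///
/// ```no_run
/// use arrow_integration_testing::flight_server_scenarios::integration_test::scenario_setup;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
///     // Port 0 asks the OS for a free port; the server prints
///     // "Server listening on localhost:<port>" once it is ready.
///     scenario_setup(0).await?;
///     Ok(())
/// }
/// ```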
pub async fn scenario_setup(port: u16) -> Result {
let addr = super::listen_on(port).await?;
let resolved_port = addr.port();
let service = FlightServiceImpl {
// See https://github.com/apache/arrow-rs/issues/6577
        // C# had trouble resolving addresses like 0.0.0.0:port
// server_location: format!("grpc+tcp://{addr}"),
server_location: format!("grpc+tcp://localhost:{resolved_port}"),
..Default::default()
};
let svc = FlightServiceServer::new(service);
let server = Server::builder().add_service(svc).serve(addr);
// NOTE: Log output used in tests to signal server is ready
println!("Server listening on localhost:{}", addr.port());
server.await?;
Ok(())
}
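
/// A dataset uploaded through DoPut, kept in memory so that later
/// GetFlightInfo and DoGet calls can serve it back.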
#[derive(Debug, Clone)]
struct IntegrationDataset {
schema: Schema,
chunks: Vec<RecordBatch>,
}
/// Flight service implementation for integration testing
#[derive(Clone, Default)]
pub struct FlightServiceImpl {
server_location: String,
uploaded_chunks: Arc<Mutex<HashMap<String, IntegrationDataset>>>,
}
impl FlightServiceImpl {
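    /// Build a FlightEndpoint that points clients back at this server for the
    /// given path.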
fn endpoint_from_path(&self, path: &str) -> FlightEndpoint {
super::endpoint(path, &self.server_location)
}
}
#[tonic::async_trait]
impl FlightService for FlightServiceImpl {
type HandshakeStream = TonicStream<Result<HandshakeResponse, Status>>;
type ListFlightsStream = TonicStream<Result<FlightInfo, Status>>;
type DoGetStream = TonicStream<Result<FlightData, Status>>;
type DoPutStream = TonicStream<Result<PutResult, Status>>;
type DoActionStream = TonicStream<Result<arrow_flight::Result, Status>>;
type ListActionsStream = TonicStream<Result<ActionType, Status>>;
type DoExchangeStream = TonicStream<Result<FlightData, Status>>;
async fn get_schema(
&self,
_request: Request<FlightDescriptor>,
) -> Result<Response<SchemaResult>, Status> {
Err(Status::unimplemented("Not yet implemented"))
}
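
    /// Stream a previously uploaded dataset back to the client: the schema
    /// first, then each chunk's dictionary batches followed by the record
    /// batch itself.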
async fn do_get(
&self,
request: Request<Ticket>,
) -> Result<Response<Self::DoGetStream>, Status> {
let ticket = request.into_inner();
let key = str::from_utf8(&ticket.ticket)
.map_err(|e| Status::invalid_argument(format!("Invalid ticket: {e:?}")))?;
let uploaded_chunks = self.uploaded_chunks.lock().await;
let flight = uploaded_chunks
.get(key)
            .ok_or_else(|| Status::not_found(format!("Could not find flight: {key}")))?;
let options = arrow::ipc::writer::IpcWriteOptions::default();
let mut dictionary_tracker = writer::DictionaryTracker::new(false);
let data_gen = writer::IpcDataGenerator::default();
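        // Encode the stored schema as the first message of the DoGet stream.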
let data = IpcMessage(
data_gen
.schema_to_bytes_with_dictionary_tracker(
&flight.schema,
&mut dictionary_tracker,
&options,
)
.ipc_message
.into(),
);
let schema_flight_data = FlightData {
data_header: data.0,
..Default::default()
};
let schema = std::iter::once(Ok(schema_flight_data));
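        // For each chunk, emit its dictionary batches first, then the record
        // batch; app_metadata carries the chunk index so clients can check
        // ordering.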
let batches = flight
.chunks
.iter()
.enumerate()
.flat_map(|(counter, batch)| {
let (encoded_dictionaries, encoded_batch) = data_gen
.encoded_batch(batch, &mut dictionary_tracker, &options)
.expect("DictionaryTracker configured above to not error on replacement");
let dictionary_flight_data = encoded_dictionaries.into_iter().map(Into::into);
let mut batch_flight_data: FlightData = encoded_batch.into();
// Only the record batch's FlightData gets app_metadata
let metadata = counter.to_string().into();
batch_flight_data.app_metadata = metadata;
dictionary_flight_data
.chain(std::iter::once(batch_flight_data))
.map(Ok)
});
let output = futures::stream::iter(schema.chain(batches).collect::<Vec<_>>());
Ok(Response::new(Box::pin(output) as Self::DoGetStream))
}
async fn handshake(
&self,
_request: Request<Streaming<HandshakeRequest>>,
) -> Result<Response<Self::HandshakeStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
}
async fn list_flights(
&self,
_request: Request<Criteria>,
) -> Result<Response<Self::ListFlightsStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
}
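
    /// Describe a previously uploaded dataset: its schema, a single endpoint
    /// on this server, and the total row count across all chunks.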
async fn get_flight_info(
&self,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let descriptor = request.into_inner();
match descriptor.r#type {
t if t == DescriptorType::Path as i32 => {
let path = &descriptor.path;
if path.is_empty() {
return Err(Status::invalid_argument("Invalid path"));
}
let uploaded_chunks = self.uploaded_chunks.lock().await;
let flight = uploaded_chunks.get(&path[0]).ok_or_else(|| {
                    Status::not_found(format!("Could not find flight: {}", path[0]))
})?;
let endpoint = self.endpoint_from_path(&path[0]);
let total_records: usize = flight.chunks.iter().map(|chunk| chunk.num_rows()).sum();
let options = arrow::ipc::writer::IpcWriteOptions::default();
let message = SchemaAsIpc::new(&flight.schema, &options)
.try_into()
.expect(
"Could not generate schema bytes from schema stored by a DoPut; \
this should be impossible",
);
let IpcMessage(schema) = message;
let info = FlightInfo {
schema,
flight_descriptor: Some(descriptor.clone()),
endpoint: vec![endpoint],
total_records: total_records as i64,
total_bytes: -1,
ordered: false,
app_metadata: vec![].into(),
};
Ok(Response::new(info))
}
            other => Err(Status::unimplemented(format!("Unsupported descriptor type: {other}"))),
}
}
async fn poll_flight_info(
&self,
_request: Request<FlightDescriptor>,
) -> Result<Response<PollInfo>, Status> {
Err(Status::unimplemented("Not yet implemented"))
}
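
    /// Accept an upload: the first FlightData must carry the flight
    /// descriptor and schema; the remaining messages are dictionary and
    /// record batches, which are consumed on a background task.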
async fn do_put(
&self,
request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoPutStream>, Status> {
let mut input_stream = request.into_inner();
let flight_data = input_stream
.message()
.await?
.ok_or_else(|| Status::invalid_argument("Must send some FlightData"))?;
let descriptor = flight_data
.flight_descriptor
.clone()
.ok_or_else(|| Status::invalid_argument("Must have a descriptor"))?;
if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() {
return Err(Status::invalid_argument("Must specify a path"));
}
let key = descriptor.path[0].clone();
let schema = Schema::try_from(&flight_data)
.map_err(|e| Status::invalid_argument(format!("Invalid schema: {e:?}")))?;
let schema_ref = Arc::new(schema.clone());
let (response_tx, response_rx) = mpsc::channel(10);
let uploaded_chunks = self.uploaded_chunks.clone();
        tokio::spawn(async move {
let mut error_tx = response_tx.clone();
if let Err(e) = save_uploaded_chunks(
uploaded_chunks,
schema_ref,
input_stream,
response_tx,
schema,
key,
)
.await
{
error_tx.send(Err(e)).await.expect("Error sending error")
}
});
Ok(Response::new(Box::pin(response_rx) as Self::DoPutStream))
}
async fn do_action(
&self,
_request: Request<Action>,
) -> Result<Response<Self::DoActionStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
}
async fn list_actions(
&self,
_request: Request<Empty>,
) -> Result<Response<Self::ListActionsStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
}
async fn do_exchange(
&self,
_request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoExchangeStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
}
}
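
/// Echo a record batch's app_metadata back to the DoPut client as a
/// PutResult acknowledgement.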
async fn send_app_metadata(
tx: &mut mpsc::Sender<Result<PutResult, Status>>,
app_metadata: &[u8],
) -> Result<(), Status> {
tx.send(Ok(PutResult {
app_metadata: app_metadata.to_vec().into(),
}))
.await
.map_err(|e| Status::internal(format!("Could not send PutResult: {e:?}")))
}
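
/// Decode an IPC record batch message body into a RecordBatch, resolving
/// dictionaries seen earlier in the stream.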
async fn record_batch_from_message(
message: ipc::Message<'_>,
data_body: &Buffer,
schema_ref: SchemaRef,
dictionaries_by_id: &HashMap<i64, ArrayRef>,
) -> Result<RecordBatch, Status> {
let ipc_batch = message
.header_as_record_batch()
.ok_or_else(|| Status::internal("Could not parse message header as record batch"))?;
let arrow_batch_result = reader::read_record_batch(
data_body,
ipc_batch,
schema_ref,
dictionaries_by_id,
None,
&message.version(),
);
arrow_batch_result
.map_err(|e| Status::internal(format!("Could not convert to RecordBatch: {e:?}")))
}
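
/// Decode an IPC dictionary batch message body and register it in
/// `dictionaries_by_id` for use by subsequent record batches.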
async fn dictionary_from_message(
message: ipc::Message<'_>,
data_body: &Buffer,
schema_ref: SchemaRef,
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
) -> Result<(), Status> {
let ipc_batch = message
.header_as_dictionary_batch()
.ok_or_else(|| Status::internal("Could not parse message header as dictionary batch"))?;
let dictionary_batch_result = reader::read_dictionary(
data_body,
ipc_batch,
&schema_ref,
dictionaries_by_id,
&message.version(),
);
dictionary_batch_result
.map_err(|e| Status::internal(format!("Could not convert to Dictionary: {e:?}")))
}
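
/// Drain the DoPut input stream, decoding dictionary and record batches and
/// acknowledging each record batch, then store the completed dataset under
/// `key`.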
async fn save_uploaded_chunks(
uploaded_chunks: Arc<Mutex<HashMap<String, IntegrationDataset>>>,
schema_ref: Arc<Schema>,
mut input_stream: Streaming<FlightData>,
mut response_tx: mpsc::Sender<Result<PutResult, Status>>,
schema: Schema,
key: String,
) -> Result<(), Status> {
let mut chunks = vec![];
let mut uploaded_chunks = uploaded_chunks.lock().await;
let mut dictionaries_by_id = HashMap::new();
while let Some(Ok(data)) = input_stream.next().await {
let message = arrow::ipc::root_as_message(&data.data_header[..])
.map_err(|e| Status::internal(format!("Could not parse message: {e:?}")))?;
match message.header_type() {
ipc::MessageHeader::Schema => {
return Err(Status::internal(
"Not expecting a schema when messages are read",
))
}
ipc::MessageHeader::RecordBatch => {
send_app_metadata(&mut response_tx, &data.app_metadata).await?;
let batch = record_batch_from_message(
message,
&Buffer::from(data.data_body.as_ref()),
schema_ref.clone(),
&dictionaries_by_id,
)
.await?;
chunks.push(batch);
}
ipc::MessageHeader::DictionaryBatch => {
dictionary_from_message(
message,
&Buffer::from(data.data_body.as_ref()),
schema_ref.clone(),
&mut dictionaries_by_id,
)
.await?;
}
t => {
return Err(Status::internal(format!(
"Reading types other than record batches not yet supported, \
unable to read {t:?}"
)));
}
}
}
let dataset = IntegrationDataset { schema, chunks };
uploaded_chunks.insert(key, dataset);
Ok(())
}