blob: ab566f578cbbb56d78d0f7232097b308c7e1559b [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Integration test for "mid level" Client
mod common;
use crate::common::fixture::TestFixture;
use arrow_array::{RecordBatch, UInt64Array};
use arrow_flight::{
Action, ActionType, CancelFlightInfoRequest, CancelFlightInfoResult, CancelStatus, Criteria,
Empty, FlightClient, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo,
HandshakeRequest, HandshakeResponse, PollInfo, PutResult, RenewFlightEndpointRequest, Ticket,
decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, error::FlightError,
};
use arrow_schema::{DataType, Field, Schema};
use bytes::Bytes;
use common::server::TestFlightServer;
use futures::{Future, StreamExt, TryStreamExt};
use prost::Message;
use tonic::Status;
use std::sync::Arc;
/// Happy path for `handshake`: the payload configured on the server comes back
/// to the caller, and the server records the client's payload and headers.
#[tokio::test]
async fn test_handshake() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let sent_payload = Bytes::from("foo-request-payload");
        let returned_payload = Bytes::from("bar-response-payload");

        // canned reply handed back by the test server
        server.set_handshake_response(Ok(HandshakeResponse {
            payload: returned_payload.clone(),
            protocol_version: 0,
        }));

        let actual_payload = client.handshake(sent_payload.clone()).await.unwrap();
        assert_eq!(actual_payload, returned_payload);

        // the server must have seen the payload wrapped in a HandshakeRequest
        let expected_request = HandshakeRequest {
            payload: sent_payload,
            protocol_version: 0,
        };
        assert_eq!(server.take_handshake_request(), Some(expected_request));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// An error status configured for `handshake` on the server is reported to the
/// caller with the same code, message and details.
#[tokio::test]
async fn test_handshake_error() {
    do_test(|server, mut client| async move {
        let expected_status = Status::unauthenticated("DENIED");
        server.set_handshake_response(Err(expected_status.clone()));

        let payload = "foo-request-payload".to_string().into_bytes();
        let err = client.handshake(payload).await.unwrap_err();

        expect_status(err, expected_status);
    })
    .await;
}
/// Verifies that all headers sent from the client are in the request_metadata
///
/// Panics when the client has no headers configured, when the server recorded
/// no request metadata, or when any client header is missing from (or differs
/// in) the metadata of the last request the server recorded.
fn ensure_metadata(client: &FlightClient, test_server: &TestFlightServer) {
    let sent_headers = client.metadata().clone().into_headers();
    // a client without headers would make the loop below vacuous
    assert!(!sent_headers.is_empty());

    let received_headers = test_server
        .take_last_request_metadata()
        .expect("No headers in server")
        .into_headers();

    for (name, value) in &sent_headers {
        assert_eq!(
            received_headers.get(name),
            Some(value),
            "Missing / Mismatched metadata {name:?} sent {sent_headers:?} got {received_headers:?}"
        );
    }
}
/// Builds a `FlightInfo` fixture that echoes `request` as its descriptor,
/// carries fixed byte / record totals and leaves every other field empty.
fn test_flight_info(request: &FlightDescriptor) -> FlightInfo {
    let flight_descriptor = Some(request.clone());
    FlightInfo {
        flight_descriptor,
        total_records: 456,
        total_bytes: 123,
        schema: Bytes::new(),
        app_metadata: Bytes::new(),
        endpoint: Vec::new(),
        ordered: false,
    }
}
/// Happy path for `get_flight_info`: the configured `FlightInfo` is returned
/// and the server records the descriptor and headers that were sent.
#[tokio::test]
async fn test_get_flight_info() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let descriptor = FlightDescriptor::new_cmd(b"My Command".to_vec());
        let canned_info = test_flight_info(&descriptor);
        server.set_get_flight_info_response(Ok(canned_info.clone()));

        let actual_info = client.get_flight_info(descriptor.clone()).await.unwrap();

        assert_eq!(actual_info, canned_info);
        assert_eq!(server.take_get_flight_info_request(), Some(descriptor));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// An error status configured for `get_flight_info` reaches the caller.
#[tokio::test]
async fn test_get_flight_info_error() {
    do_test(|server, mut client| async move {
        let expected_status = Status::unauthenticated("DENIED");
        server.set_get_flight_info_response(Err(expected_status.clone()));

        let descriptor = FlightDescriptor::new_cmd(b"My Command".to_vec());
        let err = client.get_flight_info(descriptor).await.unwrap_err();

        expect_status(err, expected_status);
    })
    .await;
}
/// Builds a `PollInfo` fixture wrapping `test_flight_info(request)` with
/// progress `1.0`, no descriptor and no expiration time.
fn test_poll_info(request: &FlightDescriptor) -> PollInfo {
    let info = Some(test_flight_info(request));
    PollInfo {
        info,
        progress: Some(1.0),
        flight_descriptor: None,
        expiration_time: None,
    }
}
/// Happy path for `poll_flight_info`: the configured `PollInfo` is returned
/// and the server records the descriptor and headers that were sent.
#[tokio::test]
async fn test_poll_flight_info() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let descriptor = FlightDescriptor::new_cmd(b"My Command".to_vec());
        let canned_poll = test_poll_info(&descriptor);
        server.set_poll_flight_info_response(Ok(canned_poll.clone()));

        let actual_poll = client.poll_flight_info(descriptor.clone()).await.unwrap();

        assert_eq!(actual_poll, canned_poll);
        assert_eq!(server.take_poll_flight_info_request(), Some(descriptor));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// An error status configured for `poll_flight_info` reaches the caller.
#[tokio::test]
async fn test_poll_flight_info_error() {
    do_test(|server, mut client| async move {
        let expected_status = Status::unauthenticated("DENIED");
        server.set_poll_flight_info_response(Err(expected_status.clone()));

        let descriptor = FlightDescriptor::new_cmd(b"My Command".to_vec());
        let err = client.poll_flight_info(descriptor).await.unwrap_err();

        expect_status(err, expected_status);
    })
    .await;
}
// TODO more negative tests (like if there are endpoints defined, etc)
/// Happy path for `do_get`: checks the response header, that trailers only
/// appear once the stream is drained, the decoded batches, and what the server
/// recorded for the request.
#[tokio::test]
async fn test_do_get() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let ticket = Ticket {
            ticket: Bytes::from("my awesome flight ticket"),
        };

        let batch = RecordBatch::try_from_iter(vec![(
            "col",
            Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
        )])
        .unwrap();
        server.set_do_get_response(vec![Ok(batch.clone())]);

        let mut batch_stream = client
            .do_get(ticket.clone())
            .await
            .expect("error making request");

        // response headers are available as soon as the call returns
        let headers = batch_stream.headers();
        let header_value = headers
            .get("test-resp-header")
            .expect("header exists")
            .to_str()
            .unwrap();
        assert_eq!(header_value, "some_val");

        // trailers are not available before stream exhaustion
        assert!(batch_stream.trailers().is_none());

        // drain by mutable reference so the stream can be inspected afterwards
        let received: Vec<_> = (&mut batch_stream)
            .try_collect()
            .await
            .expect("Error streaming data");
        assert_eq!(received, vec![batch]);

        let trailers = batch_stream.trailers().expect("stream exhausted");
        let trailer_value = trailers
            .get("test-trailer")
            .expect("trailer exists")
            .to_str()
            .unwrap();
        assert_eq!(trailer_value, "trailer_val");

        assert_eq!(server.take_do_get_request(), Some(ticket));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// With no `do_get` response configured, the call fails with an internal
/// status while the server still records the ticket and headers.
#[tokio::test]
async fn test_do_get_error() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let ticket = Ticket {
            ticket: Bytes::from("my awesome flight ticket"),
        };

        let err = client.do_get(ticket.clone()).await.unwrap_err();
        expect_status(err, Status::internal("No do_get response configured"));

        // the request reached the server even though it failed
        assert_eq!(server.take_do_get_request(), Some(ticket));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// An error that follows a good batch in the `do_get` response stream is
/// reported when the stream is collected.
#[tokio::test]
async fn test_do_get_error_in_record_batch_stream() {
    do_test(|server, mut client| async move {
        let ticket = Ticket {
            ticket: Bytes::from("my awesome flight ticket"),
        };

        let batch = RecordBatch::try_from_iter(vec![(
            "col",
            Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
        )])
        .unwrap();

        // one good batch, then a failure
        let expected_status = Status::data_loss("she's dead jim");
        server.set_do_get_response(vec![Ok(batch), Err(expected_status.clone())]);

        let batch_stream = client
            .do_get(ticket.clone())
            .await
            .expect("error making request");

        let collected: Result<Vec<_>, FlightError> = batch_stream.try_collect().await;
        expect_status(collected.unwrap_err(), expected_status);

        // the request reached the server even though the stream failed
        assert_eq!(server.take_do_get_request(), Some(ticket));
    })
    .await;
}
/// Happy path for `do_put`: every configured `PutResult` is streamed back and
/// the server records the uploaded `FlightData` and the headers.
#[tokio::test]
async fn test_do_put() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let put_results = vec![
            PutResult {
                app_metadata: Bytes::from("foo-metadata1"),
            },
            PutResult {
                app_metadata: Bytes::from("bar-metadata2"),
            },
        ];
        server.set_do_put_response(put_results.iter().cloned().map(Ok).collect());

        // the upload: one record batch already encoded as FlightData messages
        let uploaded = test_flight_data().await;
        let upload_stream = futures::stream::iter(uploaded.clone()).map(Ok);

        let result_stream = client
            .do_put(upload_stream)
            .await
            .expect("error making request");
        let received: Vec<_> = result_stream
            .try_collect()
            .await
            .expect("Error streaming data");

        assert_eq!(received, put_results);
        assert_eq!(server.take_do_put_request(), Some(uploaded));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// With no `do_put` response configured, the call itself fails with an
/// internal status while the server still records the upload and headers.
#[tokio::test]
async fn test_do_put_error_server() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let uploaded = test_flight_data().await;
        let upload_stream = futures::stream::iter(uploaded.clone()).map(Ok);

        let outcome = client.do_put(upload_stream).await;
        let Err(err) = outcome else {
            panic!("unexpected success")
        };
        expect_status(err, Status::internal("No do_put response configured"));

        // the request reached the server even though it failed
        assert_eq!(server.take_do_put_request(), Some(uploaded));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// An error that the server emits after a good `PutResult` is reported when
/// the `do_put` response stream is collected.
#[tokio::test]
async fn test_do_put_error_stream_server() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        // server replies with one good result followed by an error
        let expected_status = Status::invalid_argument("bad arg");
        server.set_do_put_response(vec![
            Ok(PutResult {
                app_metadata: Bytes::from("foo-metadata"),
            }),
            Err(expected_status.clone()),
        ]);

        let uploaded = test_flight_data().await;
        let upload_stream = futures::stream::iter(uploaded.clone()).map(Ok);

        let result_stream = client
            .do_put(upload_stream)
            .await
            .expect("error making request");

        let collected: Result<Vec<_>, _> = result_stream.try_collect().await;
        let Err(err) = collected else {
            panic!("unexpected success")
        };
        expect_status(err, expected_status);

        // the request reached the server even though the stream failed
        assert_eq!(server.take_do_put_request(), Some(uploaded));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// When the client's own upload stream yields an error, that error is what the
/// `do_put` response stream reports; the server keeps the messages sent so far.
#[tokio::test]
async fn test_do_put_error_client() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let client_status = Status::invalid_argument("bad arg: client");

        // upload stream: the good FlightData messages, then one client-side error
        let uploaded = test_flight_data().await;
        let failing_tail = futures::stream::iter(vec![Err(FlightError::from(
            client_status.clone(),
        ))]);
        let upload_stream = futures::stream::iter(uploaded.clone())
            .map(Ok)
            .chain(failing_tail);

        // the server itself answers with a single good message
        server.set_do_put_response(vec![Ok(PutResult {
            app_metadata: Bytes::from("foo-metadata"),
        })]);

        let result_stream = client
            .do_put(upload_stream)
            .await
            .expect("error making request");

        let collected: Result<Vec<_>, _> = result_stream.try_collect().await;
        let Err(err) = collected else {
            panic!("unexpected success")
        };

        // the reported error is the one produced on the client side
        expect_status(err, client_status);

        // the server received every message sent before the client error
        assert_eq!(server.take_do_put_request(), Some(uploaded));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// When both the upload stream and the server produce errors during `do_put`,
/// the caller sees the client-side error rather than the server's.
#[tokio::test]
async fn test_do_put_error_client_and_server() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let client_status = Status::invalid_argument("bad arg: client");
        let server_status = Status::invalid_argument("bad arg: server");

        // upload stream: the good FlightData messages, then one client-side error
        let uploaded = test_flight_data().await;
        let failing_tail = futures::stream::iter(vec![Err(FlightError::from(
            client_status.clone(),
        ))]);
        let upload_stream = futures::stream::iter(uploaded.clone())
            .map(Ok)
            .chain(failing_tail);

        // the server answers with an error too (e.g. because it got truncated data)
        server.set_do_put_response(vec![Err(server_status)]);

        let result_stream = client
            .do_put(upload_stream)
            .await
            .expect("error making request");

        let collected: Result<Vec<_>, _> = result_stream.try_collect().await;
        let Err(err) = collected else {
            panic!("unexpected success")
        };

        // the client's error is reported, not the server's
        expect_status(err, client_status);

        // the server received every message sent before the client error
        assert_eq!(server.take_do_put_request(), Some(uploaded));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// Happy path for `do_exchange`: the batches decoded from the response equal
/// the batches decoded directly from the configured `FlightData`, and the
/// server records the uploaded messages and the headers.
#[tokio::test]
async fn test_do_exchange() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        // both directions carry a record batch encoded as FlightData messages
        let uploaded = test_flight_data().await;
        let downloaded = test_flight_data2().await;
        server.set_do_exchange_response(downloaded.iter().cloned().map(Ok).collect());

        let upload_stream = futures::stream::iter(uploaded.clone()).map(Ok);
        let batch_stream = client
            .do_exchange(upload_stream)
            .await
            .expect("error making request");
        let received: Vec<_> = batch_stream
            .try_collect()
            .await
            .expect("Error streaming data");

        // decode the configured FlightData locally to get the expected batches
        let reference_stream = futures::stream::iter(downloaded).map(Ok);
        let expected_batches: Vec<_> =
            FlightRecordBatchStream::new_from_flight_data(reference_stream)
                .try_collect()
                .await
                .unwrap();

        assert_eq!(received, expected_batches);
        assert_eq!(server.take_do_exchange_request(), Some(uploaded));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// With no `do_exchange` response configured, the call itself fails with an
/// internal status while the server still records the upload and headers.
#[tokio::test]
async fn test_do_exchange_error() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let uploaded = test_flight_data().await;
        let upload_stream = futures::stream::iter(uploaded.clone()).map(Ok);

        let outcome = client.do_exchange(upload_stream).await;
        let Err(err) = outcome else {
            panic!("unexpected success")
        };
        expect_status(err, Status::internal("No do_exchange response configured"));

        // the request reached the server even though it failed
        assert_eq!(server.take_do_exchange_request(), Some(uploaded));
        ensure_metadata(&client, &server);
    })
    .await;
}
#[tokio::test]
async fn test_do_exchange_error_stream() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let input_flight_data = test_flight_data().await;
let e = Status::invalid_argument("the error");
let response = test_flight_data2()
.await
.into_iter()
.enumerate()
.map(|(i, m)| {
if i == 0 {
Ok(m)
} else {
// make all messages after the first an error
Err(e.clone())
}
})
.collect();
test_server.set_do_exchange_response(response);
let response_stream = client
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await
.expect("error making request");
let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};
expect_status(response, e);
// server still got the request
assert_eq!(
test_server.take_do_exchange_request(),
Some(input_flight_data)
);
ensure_metadata(&client, &test_server);
})
.await;
}
/// When the client's own upload stream yields an error during `do_exchange`,
/// that error is what the response stream reports; the server keeps the
/// messages sent so far.
#[tokio::test]
async fn test_do_exchange_error_stream_client() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let client_status = Status::invalid_argument("bad arg: client");

        // upload stream: the good FlightData messages, then one client-side error
        let uploaded = test_flight_data().await;
        let failing_tail = futures::stream::iter(vec![Err(FlightError::from(
            client_status.clone(),
        ))]);
        let upload_stream = futures::stream::iter(uploaded.clone())
            .map(Ok)
            .chain(failing_tail);

        // the server itself answers with a single good message
        let server_message = FlightData::new()
            .with_descriptor(FlightDescriptor::new_cmd("Sample command"))
            .with_data_body("body".as_bytes())
            .with_data_header("header".as_bytes())
            .with_app_metadata("metadata".as_bytes());
        server.set_do_exchange_response(vec![Ok(server_message)]);

        let batch_stream = client
            .do_exchange(upload_stream)
            .await
            .expect("error making request");

        let collected: Result<Vec<_>, _> = batch_stream.try_collect().await;
        let Err(err) = collected else {
            panic!("unexpected success")
        };

        // the reported error is the one produced on the client side
        expect_status(err, client_status);

        // the server received every message sent before the client error
        assert_eq!(server.take_do_exchange_request(), Some(uploaded));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// When both the upload stream and the server produce errors during
/// `do_exchange`, the caller sees the client-side error rather than the server's.
#[tokio::test]
async fn test_do_exchange_error_client_and_server() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let client_status = Status::invalid_argument("bad arg: client");
        let server_status = Status::invalid_argument("bad arg: server");

        // upload stream: the good FlightData messages, then one client-side error
        let uploaded = test_flight_data().await;
        let failing_tail = futures::stream::iter(vec![Err(FlightError::from(
            client_status.clone(),
        ))]);
        let upload_stream = futures::stream::iter(uploaded.clone())
            .map(Ok)
            .chain(failing_tail);

        // the server answers with an error too (e.g. because it got truncated data)
        server.set_do_exchange_response(vec![Err(server_status)]);

        let batch_stream = client
            .do_exchange(upload_stream)
            .await
            .expect("error making request");

        let collected: Result<Vec<_>, _> = batch_stream.try_collect().await;
        let Err(err) = collected else {
            panic!("unexpected success")
        };

        // the client's error is reported, not the server's
        expect_status(err, client_status);

        // the server received every message sent before the client error
        assert_eq!(server.take_do_exchange_request(), Some(uploaded));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// Happy path for `get_schema`: the configured schema is returned and the
/// server records the descriptor and headers that were sent.
#[tokio::test]
async fn test_get_schema() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let expected_schema = Schema::new(vec![Field::new("foo", DataType::Int64, true)]);
        server.set_get_schema_response(Ok(expected_schema.clone()));

        let descriptor = FlightDescriptor::new_cmd("my command");
        let actual_schema = client
            .get_schema(descriptor.clone())
            .await
            .expect("error making request");

        assert_eq!(actual_schema, expected_schema);
        assert_eq!(server.take_get_schema_request(), Some(descriptor));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// An error status configured for `get_schema` reaches the caller, and the
/// request headers are still recorded by the server.
#[tokio::test]
async fn test_get_schema_error() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let expected_status = Status::unauthenticated("DENIED");
        server.set_get_schema_response(Err(expected_status.clone()));

        let descriptor = FlightDescriptor::new_cmd("my command");
        let err = client.get_schema(descriptor).await.unwrap_err();

        expect_status(err, expected_status);
        ensure_metadata(&client, &server);
    })
    .await;
}
/// Happy path for `list_flights`: every configured `FlightInfo` is streamed
/// back and the server records the query as a `Criteria` plus the headers.
#[tokio::test]
async fn test_list_flights() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let expected_infos = vec![
            test_flight_info(&FlightDescriptor::new_cmd("foo")),
            test_flight_info(&FlightDescriptor::new_cmd("bar")),
        ];
        server.set_list_flights_response(expected_infos.iter().cloned().map(Ok).collect());

        let info_stream = client
            .list_flights("query")
            .await
            .expect("error making request");
        let received: Vec<_> = info_stream
            .try_collect()
            .await
            .expect("Error streaming data");

        assert_eq!(received, expected_infos);

        // the query expression travels to the server inside a Criteria message
        let expected_criteria = Criteria {
            expression: "query".into(),
        };
        assert_eq!(server.take_list_flights_request(), Some(expected_criteria));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// With no `list_flights` response configured, the call fails with an internal
/// status while the server still records the criteria and headers.
#[tokio::test]
async fn test_list_flights_error() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let outcome = client.list_flights("query").await;
        let Err(err) = outcome else {
            panic!("unexpected success")
        };
        expect_status(err, Status::internal("No list_flights response configured"));

        // the request reached the server even though it failed
        let expected_criteria = Criteria {
            expression: "query".into(),
        };
        assert_eq!(server.take_list_flights_request(), Some(expected_criteria));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// An error that follows a good `FlightInfo` in the `list_flights` response
/// stream is reported when the stream is collected.
#[tokio::test]
async fn test_list_flights_error_in_stream() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        // one good entry, then a failure
        let expected_status = Status::data_loss("she's dead jim");
        server.set_list_flights_response(vec![
            Ok(test_flight_info(&FlightDescriptor::new_cmd("foo"))),
            Err(expected_status.clone()),
        ]);

        let info_stream = client
            .list_flights("other query")
            .await
            .expect("error making request");

        let collected: Result<Vec<_>, FlightError> = info_stream.try_collect().await;
        expect_status(collected.unwrap_err(), expected_status);

        // the request reached the server even though the stream failed
        let expected_criteria = Criteria {
            expression: "other query".into(),
        };
        assert_eq!(server.take_list_flights_request(), Some(expected_criteria));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// Happy path for `list_actions`: every configured `ActionType` is streamed
/// back and the server records an `Empty` request plus the headers.
#[tokio::test]
async fn test_list_actions() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let expected_actions = vec![
            ActionType {
                r#type: "type 1".into(),
                description: "awesomeness".into(),
            },
            ActionType {
                r#type: "type 2".into(),
                description: "more awesomeness".into(),
            },
        ];
        server.set_list_actions_response(expected_actions.iter().cloned().map(Ok).collect());

        let action_stream = client.list_actions().await.expect("error making request");
        let received: Vec<_> = action_stream
            .try_collect()
            .await
            .expect("Error streaming data");

        assert_eq!(received, expected_actions);
        assert_eq!(server.take_list_actions_request(), Some(Empty {}));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// With no `list_actions` response configured, the call fails with an internal
/// status while the server still records the request and headers.
#[tokio::test]
async fn test_list_actions_error() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let outcome = client.list_actions().await;
        let Err(err) = outcome else {
            panic!("unexpected success")
        };
        expect_status(err, Status::internal("No list_actions response configured"));

        // the request reached the server even though it failed
        assert_eq!(server.take_list_actions_request(), Some(Empty {}));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// An error that follows a good `ActionType` in the `list_actions` response
/// stream is reported when the stream is collected.
#[tokio::test]
async fn test_list_actions_error_in_stream() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        // one good entry, then a failure
        let expected_status = Status::data_loss("she's dead jim");
        let first_action = ActionType {
            r#type: "type 1".into(),
            description: "awesomeness".into(),
        };
        server.set_list_actions_response(vec![Ok(first_action), Err(expected_status.clone())]);

        let action_stream = client.list_actions().await.expect("error making request");

        let collected: Result<Vec<_>, FlightError> = action_stream.try_collect().await;
        expect_status(collected.unwrap_err(), expected_status);

        // the request reached the server even though the stream failed
        assert_eq!(server.take_list_actions_request(), Some(Empty {}));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// Happy path for `do_action`: the body of every configured result is streamed
/// back and the server records the action and headers that were sent.
#[tokio::test]
async fn test_do_action() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        // each payload is wrapped in an arrow_flight::Result for the server to return
        let expected_bodies = vec![Bytes::from("foo"), Bytes::from("blarg")];
        let canned_results = expected_bodies
            .iter()
            .map(|body| Ok(arrow_flight::Result::new(body.clone())))
            .collect();
        server.set_do_action_response(canned_results);

        let action = Action::new("action type", "action body");
        let body_stream = client
            .do_action(action.clone())
            .await
            .expect("error making request");
        let received: Vec<_> = body_stream
            .try_collect()
            .await
            .expect("Error streaming data");

        assert_eq!(received, expected_bodies);
        assert_eq!(server.take_do_action_request(), Some(action));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// With no `do_action` response configured, the call fails with an internal
/// status while the server still records the action and headers.
#[tokio::test]
async fn test_do_action_error() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let action = Action::new("action type", "action body");

        let outcome = client.do_action(action.clone()).await;
        let Err(err) = outcome else {
            panic!("unexpected success")
        };
        expect_status(err, Status::internal("No do_action response configured"));

        // the request reached the server even though it failed
        assert_eq!(server.take_do_action_request(), Some(action));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// An error that follows a good result in the `do_action` response stream is
/// reported when the stream is collected.
#[tokio::test]
async fn test_do_action_error_in_stream() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        // one good result, then a failure
        let expected_status = Status::data_loss("she's dead jim");
        server.set_do_action_response(vec![
            Ok(arrow_flight::Result::new("foo")),
            Err(expected_status.clone()),
        ]);

        let action = Action::new("action type", "action body");
        let body_stream = client
            .do_action(action.clone())
            .await
            .expect("error making request");

        let collected: Result<Vec<_>, FlightError> = body_stream.try_collect().await;
        expect_status(collected.unwrap_err(), expected_status);

        // the request reached the server even though the stream failed
        assert_eq!(server.take_do_action_request(), Some(action));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// Happy path for `cancel_flight_info`: the request reaches the server as a
/// "CancelFlightInfo" action carrying the encoded request, and the encoded
/// `CancelFlightInfoResult` in the reply is decoded for the caller.
#[tokio::test]
async fn test_cancel_flight_info() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        // the server replies with the protobuf-encoded result as the action body
        let expected_result = CancelFlightInfoResult::new(CancelStatus::Cancelled);
        let encoded_result = arrow_flight::Result::new(expected_result.encode_to_vec());
        server.set_do_action_response(vec![Ok(encoded_result)]);

        let request = CancelFlightInfoRequest::new(FlightInfo::new());
        let actual_result = client
            .cancel_flight_info(request.clone())
            .await
            .expect("error making request");
        assert_eq!(actual_result, expected_result);

        let expected_action = Action::new("CancelFlightInfo", request.encode_to_vec());
        assert_eq!(server.take_do_action_request(), Some(expected_action));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// When the server answers the "CancelFlightInfo" action with an empty stream,
/// `cancel_flight_info` fails with a protocol error.
#[tokio::test]
async fn test_cancel_flight_info_error_no_response() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        // a configured but empty response stream
        server.set_do_action_response(Vec::new());

        let request = CancelFlightInfoRequest::new(FlightInfo::new());
        let err = client
            .cancel_flight_info(request.clone())
            .await
            .unwrap_err();
        let message = err.to_string();
        assert_eq!(
            message,
            "Protocol error: Received no response for cancel_flight_info call"
        );

        // the request reached the server even though the call failed
        let expected_action = Action::new("CancelFlightInfo", request.encode_to_vec());
        assert_eq!(server.take_do_action_request(), Some(expected_action));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// Happy path for `renew_flight_endpoint`: the request reaches the server as a
/// "RenewFlightEndpoint" action carrying the encoded request, and the encoded
/// `FlightEndpoint` in the reply is decoded for the caller.
#[tokio::test]
async fn test_renew_flight_endpoint() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        // the server replies with the protobuf-encoded endpoint as the action body
        let expected_endpoint = FlightEndpoint::new().with_app_metadata(vec![1]);
        let encoded_endpoint = arrow_flight::Result::new(expected_endpoint.encode_to_vec());
        server.set_do_action_response(vec![Ok(encoded_endpoint)]);

        // the endpoint being renewed has different metadata than the reply
        let endpoint_to_renew = FlightEndpoint::new().with_app_metadata(vec![0]);
        let request = RenewFlightEndpointRequest::new(endpoint_to_renew);
        let actual_endpoint = client
            .renew_flight_endpoint(request.clone())
            .await
            .expect("error making request");
        assert_eq!(actual_endpoint, expected_endpoint);

        let expected_action = Action::new("RenewFlightEndpoint", request.encode_to_vec());
        assert_eq!(server.take_do_action_request(), Some(expected_action));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// When the server answers the "RenewFlightEndpoint" action with an empty
/// stream, `renew_flight_endpoint` fails with a protocol error.
#[tokio::test]
async fn test_renew_flight_endpoint_error_no_response() {
    do_test(|server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        // a configured but empty response stream
        server.set_do_action_response(Vec::new());

        let request = RenewFlightEndpointRequest::new(FlightEndpoint::new());
        let err = client
            .renew_flight_endpoint(request.clone())
            .await
            .unwrap_err();
        let message = err.to_string();
        assert_eq!(
            message,
            "Protocol error: Received no response for renew_flight_endpoint call"
        );

        // the request reached the server even though the call failed
        let expected_action = Action::new("RenewFlightEndpoint", request.encode_to_vec());
        assert_eq!(server.take_do_action_request(), Some(expected_action));
        ensure_metadata(&client, &server);
    })
    .await;
}
/// Returns the `FlightData` messages produced by encoding a single record
/// batch with one `UInt64` column named "col" holding `[1, 2, 3, 4]`.
async fn test_flight_data() -> Vec<FlightData> {
    let batch = RecordBatch::try_from_iter(vec![(
        "col",
        Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
    )])
    .unwrap();

    // run the batch through the Flight encoder and gather every message it emits
    let batches = futures::stream::iter(vec![Ok(batch)]);
    let encoder = FlightDataEncoderBuilder::new().build(batches);
    encoder.try_collect().await.unwrap()
}
/// Returns the `FlightData` messages produced by encoding a single record
/// batch with one `UInt64` column named "col2" holding `[10, 23, 33]`.
async fn test_flight_data2() -> Vec<FlightData> {
    let batch = RecordBatch::try_from_iter(vec![(
        "col2",
        Arc::new(UInt64Array::from_iter([10, 23, 33])) as _,
    )])
    .unwrap();

    // run the batch through the Flight encoder and gather every message it emits
    let batches = futures::stream::iter(vec![Ok(batch)]);
    let encoder = FlightDataEncoderBuilder::new().build(batches);
    encoder.try_collect().await.unwrap()
}
/// Runs the future returned by the function, passing it a test server and client
///
/// Creates a `TestFlightServer`, hosts its service in a `TestFixture`, connects
/// a `FlightClient` over the fixture's channel, invokes `f` once with the
/// server and client, and shuts the fixture down after the returned future
/// completes.
///
/// `f` is called exactly once, so it is bounded by `FnOnce`. Every `Fn` closure
/// also satisfies `FnOnce`, so existing callers are unaffected, and closures
/// that move captured state into the returned future are accepted as well.
async fn do_test<F, Fut>(f: F)
where
    F: FnOnce(TestFlightServer, FlightClient) -> Fut,
    Fut: Future<Output = ()>,
{
    let test_server = TestFlightServer::new();
    let fixture = TestFixture::new(test_server.service()).await;
    let client = FlightClient::new(fixture.channel().await);

    // run the test function
    f(test_server, client).await;

    // cleanly shutdown the test fixture
    fixture.shutdown_and_wait().await
}
/// Asserts that `error` is a `FlightError::Tonic` whose status has the same
/// code, message and details as `expected`.
///
/// Panics if `error` is any other `FlightError` variant or if any of the three
/// compared parts differ.
fn expect_status(error: FlightError, expected: Status) {
    let actual = match error {
        FlightError::Tonic(status) => status,
        other => panic!("Expected FlightError::Tonic, got: {other:?}"),
    };

    // compare piece by piece so a failure shows which part of the status differs
    assert_eq!(
        actual.code(),
        expected.code(),
        "Got {actual:?} want {expected:?}"
    );
    assert_eq!(
        actual.message(),
        expected.message(),
        "Got {actual:?} want {expected:?}"
    );
    assert_eq!(
        actual.details(),
        expected.details(),
        "Got {actual:?} want {expected:?}"
    );
}