blob: 375ce15686eed95ccb1562dbf8142f376fba768d [file] [log] [blame]
use crate::pb::{self, *};
use async_stream::try_stream;
use futures_util::{stream, StreamExt, TryStreamExt};
use http::header::{HeaderMap, HeaderName, HeaderValue};
use http_body::Body;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tonic::{body::BoxBody, transport::NamedService, Code, Request, Response, Status};
use tower::Service;
pub use pb::test_service_server::TestServiceServer;
pub use pb::unimplemented_service_server::UnimplementedServiceServer;
#[derive(Default, Clone)]
pub struct TestService;
type Result<T> = std::result::Result<Response<T>, Status>;
type Streaming<T> = Request<tonic::Streaming<T>>;
type Stream<T> =
Pin<Box<dyn futures_core::Stream<Item = std::result::Result<T, Status>> + Send + 'static>>;
type BoxFuture<T, E> = Pin<Box<dyn Future<Output = std::result::Result<T, E>> + Send + 'static>>;
#[tonic::async_trait]
impl pb::test_service_server::TestService for TestService {
async fn empty_call(&self, _request: Request<Empty>) -> Result<Empty> {
Ok(Response::new(Empty {}))
}
async fn unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse> {
let req = request.into_inner();
if let Some(echo_status) = req.response_status {
let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
return Err(status);
}
let res_size = if req.response_size >= 0 {
req.response_size as usize
} else {
let status = Status::new(Code::InvalidArgument, "response_size cannot be negative");
return Err(status);
};
let res = SimpleResponse {
payload: Some(Payload {
body: vec![0; res_size],
..Default::default()
}),
..Default::default()
};
Ok(Response::new(res))
}
async fn cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse> {
unimplemented!()
}
type StreamingOutputCallStream = Stream<StreamingOutputCallResponse>;
async fn streaming_output_call(
&self,
req: Request<StreamingOutputCallRequest>,
) -> Result<Self::StreamingOutputCallStream> {
let StreamingOutputCallRequest {
response_parameters,
..
} = req.into_inner();
let stream = try_stream! {
for param in response_parameters {
tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await;
let payload = crate::server_payload(param.size as usize);
yield StreamingOutputCallResponse { payload: Some(payload) };
}
};
Ok(Response::new(
Box::pin(stream) as Self::StreamingOutputCallStream
))
}
async fn streaming_input_call(
&self,
req: Streaming<StreamingInputCallRequest>,
) -> Result<StreamingInputCallResponse> {
let mut stream = req.into_inner();
let mut aggregated_payload_size = 0;
while let Some(msg) = stream.try_next().await? {
aggregated_payload_size += msg.payload.unwrap().body.len() as i32;
}
let res = StreamingInputCallResponse {
aggregated_payload_size,
};
Ok(Response::new(res))
}
type FullDuplexCallStream = Stream<StreamingOutputCallResponse>;
async fn full_duplex_call(
&self,
req: Streaming<StreamingOutputCallRequest>,
) -> Result<Self::FullDuplexCallStream> {
let mut stream = req.into_inner();
if let Some(first_msg) = stream.message().await? {
if let Some(echo_status) = first_msg.response_status {
let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
return Err(status);
}
let single_message = stream::iter(vec![Ok(first_msg)]);
let mut stream = single_message.chain(stream);
let stream = try_stream! {
while let Some(msg) = stream.try_next().await? {
if let Some(echo_status) = msg.response_status {
let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
Err(status)?;
}
for param in msg.response_parameters {
tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await;
let payload = crate::server_payload(param.size as usize);
yield StreamingOutputCallResponse { payload: Some(payload) };
}
}
};
Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream))
} else {
let stream = stream::empty();
Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream))
}
}
type HalfDuplexCallStream = Stream<StreamingOutputCallResponse>;
async fn half_duplex_call(
&self,
_: Streaming<StreamingOutputCallRequest>,
) -> Result<Self::HalfDuplexCallStream> {
Err(Status::unimplemented("TODO"))
}
async fn unimplemented_call(&self, _: Request<Empty>) -> Result<Empty> {
Err(Status::unimplemented(""))
}
}
#[derive(Default)]
pub struct UnimplementedService;
#[tonic::async_trait]
impl pb::unimplemented_service_server::UnimplementedService for UnimplementedService {
async fn unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty> {
Err(Status::unimplemented(""))
}
}
#[derive(Clone, Default)]
pub struct EchoHeadersSvc<S> {
inner: S,
}
impl<S: NamedService> NamedService for EchoHeadersSvc<S> {
const NAME: &'static str = S::NAME;
}
impl<S> EchoHeadersSvc<S> {
pub fn new(inner: S) -> Self {
Self { inner }
}
}
impl<S> Service<http::Request<hyper::Body>> for EchoHeadersSvc<S>
where
S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>> + Send,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<Self::Response, Self::Error>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
Ok(()).into()
}
fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future {
let echo_header = req
.headers()
.get("x-grpc-test-echo-initial")
.map(Clone::clone);
let echo_trailer = req
.headers()
.get("x-grpc-test-echo-trailing-bin")
.map(Clone::clone)
.map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v));
let call = self.inner.call(req);
Box::pin(async move {
let mut res = call.await?;
if let Some(echo_header) = echo_header {
res.headers_mut()
.insert("x-grpc-test-echo-initial", echo_header);
Ok(res
.map(|b| MergeTrailers::new(b, echo_trailer))
.map(BoxBody::new))
} else {
Ok(res)
}
})
}
}
pub struct MergeTrailers<B> {
inner: B,
trailer: Option<(HeaderName, HeaderValue)>,
}
impl<B> MergeTrailers<B> {
pub fn new(inner: B, trailer: Option<(HeaderName, HeaderValue)>) -> Self {
Self { inner, trailer }
}
}
impl<B: Body + Unpin> Body for MergeTrailers<B> {
type Data = B::Data;
type Error = B::Error;
fn poll_data(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<std::result::Result<Self::Data, Self::Error>>> {
Pin::new(&mut self.inner).poll_data(cx)
}
fn poll_trailers(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<Option<HeaderMap>, Self::Error>> {
Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| {
h.map(|mut headers| {
if let Some((key, value)) = &self.trailer {
headers.insert(key.clone(), value.clone());
}
headers
})
})
}
}