blob: 096550dfb40be50f454046df1eeec0a16364e174 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use futures_util::{future, stream, StreamExt, TryStreamExt};
use http::HeaderValue;
use http_body::Body;
use prost::Message;
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::{
invocation::Request,
status::Status,
triple::{
client::triple::get_codec,
codec::{Decoder, Encoder},
compression::{CompressionEncoding, COMPRESSIONS},
decode::Decoding,
encode::encode_server,
server::service::{ClientStreamingSvc, ServerStreamingSvc, StreamingSvc, UnarySvc},
},
BoxBody,
};
pub const GRPC_ACCEPT_ENCODING: &str = "grpc-accept-encoding";
pub const GRPC_ENCODING: &str = "grpc-encoding";
pub struct TripleServer<M1, M2> {
_pd: PhantomData<(M1, M2)>,
compression: Option<CompressionEncoding>,
}
impl<M1, M2> TripleServer<M1, M2> {
pub fn new() -> Self {
Self {
_pd: PhantomData,
compression: Some(CompressionEncoding::Gzip),
}
}
}
impl<M1, M2> TripleServer<M1, M2>
where
M1: Message + for<'a> Deserialize<'a> + Default + 'static,
M2: Message + Serialize + Default + 'static,
{
pub async fn client_streaming<S, B>(
&mut self,
mut service: S,
req: http::Request<B>,
) -> http::Response<BoxBody>
where
S: ClientStreamingSvc<M1, Response = M2>,
B: Body + Send + 'static,
B::Error: Into<crate::Error> + Send,
{
let content_type = req
.headers()
.get("content-type")
.cloned()
.unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
let content_type_str = content_type.to_str().unwrap();
let (decoder, encoder): (
Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
) = get_codec(content_type_str);
let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
if self.compression.is_none() || accept_encoding.is_none() {
accept_encoding = None;
}
// Get grpc_encoding from http_header, decompress message.
let compression = match self.get_encoding_from_req(req.headers()) {
Ok(val) => val,
Err(status) => return status.to_http(),
};
let req_stream = req.map(|body| Decoding::new(body, decoder, compression, true));
let resp = service.call(Request::from_http(req_stream)).await;
let (mut parts, resp_body) = match resp {
Ok(v) => v.into_http().into_parts(),
Err(err) => return err.to_http(),
};
let resp_body = encode_server(
encoder,
stream::once(future::ready(resp_body)).map(Ok).into_stream(),
accept_encoding,
true,
);
parts
.headers
.insert(http::header::CONTENT_TYPE, content_type);
if let Some(encoding) = accept_encoding {
parts
.headers
.insert(GRPC_ENCODING, encoding.into_header_value());
}
parts.status = http::StatusCode::OK;
http::Response::from_parts(parts, BoxBody::new(resp_body))
}
pub async fn bidi_streaming<S, B>(
&mut self,
mut service: S,
req: http::Request<B>,
) -> http::Response<BoxBody>
where
S: StreamingSvc<M1, Response = M2>,
S::ResponseStream: Send + 'static,
B: Body + Send + 'static,
B::Error: Into<crate::Error> + Send,
{
let content_type = req
.headers()
.get("content-type")
.cloned()
.unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
let content_type_str = content_type.to_str().unwrap();
let (decoder, encoder): (
Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
) = get_codec(content_type_str);
// Firstly, get grpc_accept_encoding from http_header, get compression
// Secondly, if server enable compression and compression is valid, this method should compress response
let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
if self.compression.is_none() || accept_encoding.is_none() {
accept_encoding = None;
}
// Get grpc_encoding from http_header, decompress message.
let compression = match self.get_encoding_from_req(req.headers()) {
Ok(val) => val,
Err(status) => return status.to_http(),
};
let req_stream = req.map(|body| Decoding::new(body, decoder, compression, true));
let resp = service.call(Request::from_http(req_stream)).await;
let (mut parts, resp_body) = match resp {
Ok(v) => v.into_http().into_parts(),
Err(err) => return err.to_http(),
};
let resp_body = encode_server(encoder, resp_body, compression, true);
parts
.headers
.insert(http::header::CONTENT_TYPE, content_type);
if let Some(encoding) = accept_encoding {
parts
.headers
.insert(GRPC_ENCODING, encoding.into_header_value());
}
parts.status = http::StatusCode::OK;
http::Response::from_parts(parts, BoxBody::new(resp_body))
}
pub async fn server_streaming<S, B>(
&mut self,
mut service: S,
req: http::Request<B>,
) -> http::Response<BoxBody>
where
S: ServerStreamingSvc<M1, Response = M2>,
S::ResponseStream: Send + 'static,
B: Body + Send + 'static,
B::Error: Into<crate::Error> + Send,
{
let content_type = req
.headers()
.get("content-type")
.cloned()
.unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
let content_type_str = content_type.to_str().unwrap();
let (decoder, encoder): (
Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
) = get_codec(content_type_str);
// Firstly, get grpc_accept_encoding from http_header, get compression
// Secondly, if server enable compression and compression is valid, this method should compress response
let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
if self.compression.is_none() || accept_encoding.is_none() {
accept_encoding = None;
}
// Get grpc_encoding from http_header, decompress message.
let compression = match self.get_encoding_from_req(req.headers()) {
Ok(val) => val,
Err(status) => return status.to_http(),
};
let req_stream = req.map(|body| Decoding::new(body, decoder, compression, true));
let (parts, mut body) = Request::from_http(req_stream).into_parts();
let msg = body.try_next().await.unwrap().ok_or_else(|| {
crate::status::Status::new(crate::status::Code::Unknown, "request wrong".to_string())
});
let msg = match msg {
Ok(v) => v,
Err(err) => return err.to_http(),
};
let resp = service.call(Request::from_parts(parts, msg)).await;
let (mut parts, resp_body) = match resp {
Ok(v) => v.into_http().into_parts(),
Err(err) => return err.to_http(),
};
let resp_body = encode_server(encoder, resp_body, compression, true);
parts
.headers
.insert(http::header::CONTENT_TYPE, content_type);
if let Some(encoding) = accept_encoding {
parts
.headers
.insert(GRPC_ENCODING, encoding.into_header_value());
}
parts.status = http::StatusCode::OK;
http::Response::from_parts(parts, BoxBody::new(resp_body))
}
pub async fn unary<S, B>(
&mut self,
mut service: S,
req: http::Request<B>,
) -> http::Response<BoxBody>
where
S: UnarySvc<M1, Response = M2>,
B: Body + Send + 'static,
B::Error: Into<crate::Error> + Send,
{
let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
if self.compression.is_none() || accept_encoding.is_none() {
accept_encoding = None;
}
let compression = match self.get_encoding_from_req(req.headers()) {
Ok(val) => val,
Err(status) => return status.to_http(),
};
let content_type = req
.headers()
.get("content-type")
.cloned()
.unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
let content_type_str = content_type.to_str().unwrap();
//Determine whether to use the gRPC mode to handle request data
let handle_request_as_grpc = content_type_str.contains("grpc");
let (decoder, encoder): (
Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
) = get_codec(content_type_str);
let req_stream =
req.map(|body| Decoding::new(body, decoder, compression, handle_request_as_grpc));
let (parts, mut body) = Request::from_http(req_stream).into_parts();
let msg = body.try_next().await.unwrap().ok_or_else(|| {
crate::status::Status::new(crate::status::Code::Unknown, "request wrong".to_string())
});
let msg = match msg {
Ok(v) => v,
Err(err) => return err.to_http(),
};
let resp = service.call(Request::from_parts(parts, msg)).await;
let (mut parts, resp_body) = match resp {
Ok(v) => v.into_http().into_parts(),
Err(err) => return err.to_http(),
};
let resp_body = encode_server(
encoder,
stream::once(future::ready(resp_body)).map(Ok).into_stream(),
accept_encoding,
handle_request_as_grpc,
);
parts
.headers
.insert(http::header::CONTENT_TYPE, content_type);
if let Some(encoding) = accept_encoding {
parts
.headers
.insert(GRPC_ENCODING, encoding.into_header_value());
}
parts.status = http::StatusCode::OK;
http::Response::from_parts(parts, BoxBody::new(resp_body))
}
fn get_encoding_from_req(
&self,
header: &http::HeaderMap,
) -> Result<Option<CompressionEncoding>, crate::status::Status> {
let encoding = match header.get(GRPC_ENCODING) {
Some(val) => val.to_str().unwrap(),
None => return Ok(None),
};
let compression = match COMPRESSIONS.get(encoding) {
Some(val) => val.to_owned(),
None => {
let status = crate::status::Status::new(
crate::status::Code::Unimplemented,
format!("grpc-accept-encoding: {} not support!", encoding),
);
return Err(status);
}
};
Ok(compression)
}
}