blob: be898a45fdfd7fdb13c4c03ebe7ad3bcc2e8389f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//! Implement the [`ClusterFilter`].
use tower_service::Service;
use core::fmt;
use std::task::Poll;
use bytes::Bytes;
use futures::Future;
use pin_project::pin_project;
use tower::Layer;
use crate::invocation::Metadata;
use crate::invocation::Request;
use crate::{boxed, status::Status, BoxBody, Error};
use super::ClusterFilter;
impl<F> ClusterFilter for F
where
F: FnMut(Request<()>) -> Result<Request<()>, Status>,
{
fn call(&mut self, req: Request<()>) -> Result<Request<()>, crate::status::Status> {
self(req)
}
}
/// Create a [`ClusterFilterLayer`].
///
/// See [`ClusterFilter`] for more details.
pub fn cluster_filter<F>(filter: F) -> ClusterFilterLayer<F> {
ClusterFilterLayer { filter }
}
/// A [`ClusterFilter`] can be used as a [`Layer`],
/// is created by calling [`cluster_filter`].
///
/// See [`ClusterFilterService`] for more details.
#[derive(Debug, Clone, Copy)]
pub struct ClusterFilterLayer<F> {
filter: F,
}
impl<S, F: Clone> Layer<S> for ClusterFilterLayer<F> {
type Service = ClusterFilterService<S, F>;
fn layer(&self, inner: S) -> Self::Service {
ClusterFilterService::new(inner, self.filter.clone())
}
}
/// The service will call the filter `CF` to preprocess the HTTP Request, and then pass the request to the service `S`.
///
/// See [`ClusterFilter`] for more details.
#[derive(Clone, Copy)]
pub struct ClusterFilterService<S, CF> {
inner: S,
filter: CF,
}
impl<S, CF> ClusterFilterService<S, CF> {
fn new(inner: S, filter: CF) -> Self {
Self { inner, filter }
}
}
impl<S, CF> fmt::Debug for ClusterFilterService<S, CF>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ClusterFilterService")
.field("inner", &self.inner)
.field("filter", &format_args!("{}", std::any::type_name::<CF>()))
.finish()
}
}
/// A service which call [`ClusterFilter`] and then pass the result to the inner [`Service`].
impl<S, CF, ReqBody, RespBody> Service<http::Request<ReqBody>> for ClusterFilterService<S, CF>
where
CF: ClusterFilter,
S: Service<http::Request<ReqBody>, Response = http::Response<RespBody>>,
S::Error: Into<Error>,
ReqBody: http_body::Body<Data = Bytes> + Send + 'static,
ReqBody::Error: Into<Error>,
RespBody: http_body::Body<Data = Bytes> + Send + 'static,
RespBody::Error: Into<Error>,
{
type Response = http::Response<BoxBody>;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
// It's a bad practice to modify the HTTP Request's body, so we extract it from the original
// request and recreate a Dubbo's HTTP Request, which can avoid exposing the message to the
// [`ClusterFilter`]. Dubbo's [`invocation::Request`] don't preserve the URI, Method and Version of the
// original HTTP Request, so we extract them and set them back later when recreate the HTTP Request.
let uri = req.uri().clone();
let method = req.method().clone();
let version = req.version();
let (parts, msg) = req.into_parts();
let req = Request::from_parts(Metadata::from_headers(parts.headers), ());
match self.filter.call(req) {
Ok(req) => {
let (metadata, _) = req.into_parts();
let req = Request::from_parts(metadata, msg);
let http_req = req.into_http(uri, method, version);
ResponseFuture::furure(self.inner.call(http_req))
}
Err(status) => ResponseFuture::status(status),
}
}
}
/// A general Future for [`ClusterFilterService`].
#[pin_project]
#[derive(Debug)]
pub struct ResponseFuture<F> {
#[pin]
kind: Kind<F>,
}
#[pin_project(project = KindProj)]
#[derive(Debug)]
enum Kind<F> {
Future(#[pin] F),
Status(Option<Status>),
}
impl<F> ResponseFuture<F> {
fn furure(future: F) -> Self {
Self {
kind: Kind::Future(future),
}
}
fn status(status: Status) -> Self {
Self {
kind: Kind::Status(Some(status)),
}
}
}
impl<F, RespBody, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<http::Response<RespBody>, E>>,
E: Into<Error>,
RespBody: http_body::Body<Data = Bytes> + Send + 'static,
RespBody::Error: Into<Error>,
{
type Output = Result<http::Response<BoxBody>, E>;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
match self.project().kind.project() {
KindProj::Future(fut) => fut
.poll(cx)
.map(|result| result.map(|resp| resp.map(boxed))),
KindProj::Status(status) => {
Poll::Ready(Ok(status.take().unwrap().to_http().map(boxed)))
}
}
}
}
#[cfg(test)]
mod test {
use tower::service_fn;
use tower::ServiceBuilder;
use tower::ServiceExt;
use crate::empty_body;
use crate::status::Code;
use crate::status::Status;
use crate::BoxBody;
use super::*;
const USER_AGENT: &str = "user-agent";
const USER_AGENT_VAL: &str = "dubbo-test";
#[derive(Clone, Copy)]
struct MyFilter;
impl ClusterFilter for MyFilter {
fn call(&mut self, req: Request<()>) -> Result<Request<()>, crate::status::Status> {
assert_eq!(
req.metadata()
.get_ref()
.get(USER_AGENT)
.expect("missing user-agent."),
USER_AGENT_VAL
);
Ok::<_, Status>(req)
}
}
#[tokio::test]
async fn doesnt_change_anything() {
let svc = service_fn(|req: http::Request<BoxBody>| async move {
assert_eq!(
req.headers().get(USER_AGENT).expect("missing user-agent."),
USER_AGENT_VAL
);
Ok::<_, Status>(http::Response::new(empty_body()))
});
let svc = ClusterFilterService::new(svc, MyFilter);
let req = http::Request::builder()
.header(USER_AGENT, USER_AGENT_VAL)
.body(empty_body())
.unwrap();
svc.oneshot(req).await.unwrap();
}
#[tokio::test]
async fn add_cluster_filter_to_service() {
let svc = service_fn(|req: http::Request<BoxBody>| async move {
assert_eq!(
req.headers().get(USER_AGENT).expect("missing user-agent."),
USER_AGENT_VAL
);
Ok::<_, Status>(http::Response::new(empty_body()))
});
let svc = ServiceBuilder::new()
.layer(cluster_filter(MyFilter))
.service(svc);
let req = http::Request::builder()
.header(USER_AGENT, USER_AGENT_VAL)
.body(empty_body())
.unwrap();
svc.oneshot(req).await.unwrap();
}
#[tokio::test]
async fn handle_status_as_response() {
let msg = "PermissionDenied from ClusterFilter.";
let expected = Status::new(Code::PermissionDenied, msg).to_http();
let svc = service_fn(|_: http::Request<BoxBody>| async move {
Ok::<_, Status>(http::Response::new(empty_body()))
});
let svc = ClusterFilterService::new(svc, |_: Request<()>| -> Result<Request<()>, Status> {
Err(Status::new(Code::PermissionDenied, msg))
});
let resp = svc
.oneshot(http::Request::builder().body(empty_body()).unwrap())
.await
.unwrap();
assert_eq!(resp.headers(), expected.headers());
assert_eq!(resp.status(), expected.status());
assert_eq!(resp.version(), expected.version());
}
}