Ftr: Add the ClusterFilter (#79)
* Ftr: Add the ClusterFilter
* Ftr: Add the ClusterFilter
* Fix: change from crate::codegen::Request to crate::invocation::Request
diff --git a/dubbo/src/filter/clusterfilter.rs b/dubbo/src/filter/clusterfilter.rs
new file mode 100644
index 0000000..be898a4
--- /dev/null
+++ b/dubbo/src/filter/clusterfilter.rs
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+//! Runs [`ClusterFilter`]s as tower middleware ([`ClusterFilterLayer`] / [`ClusterFilterService`]).
+
+use tower_service::Service;
+
+use core::fmt;
+use std::task::Poll;
+
+use bytes::Bytes;
+use futures::Future;
+use pin_project::pin_project;
+use tower::Layer;
+
+use crate::invocation::Metadata;
+use crate::invocation::Request;
+use crate::{boxed, status::Status, BoxBody, Error};
+
+use super::ClusterFilter;
+
+/// Any `FnMut(Request<()>) -> Result<Request<()>, Status>` closure is itself a [`ClusterFilter`],
+/// so one-off filters need no dedicated type.
+impl<F: FnMut(Request<()>) -> Result<Request<()>, Status>> ClusterFilter for F {
+    fn call(&mut self, request: Request<()>) -> Result<Request<()>, Status> {
+        // The closure alone decides whether the request passes (possibly modified) or is rejected.
+        (self)(request)
+    }
+}
+
+/// Creates a [`ClusterFilterLayer`] around `filter`. Applying the layer requires `F: Clone`; the
+/// wrapped service is a [`Service`] when `F` implements [`ClusterFilter`] (closures do).
+#[must_use = "a layer does nothing until it is applied to a service"]
+pub fn cluster_filter<F>(filter: F) -> ClusterFilterLayer<F> {
+    ClusterFilterLayer { filter }
+}
+
+/// Tower [`Layer`] that wraps an inner service in a [`ClusterFilterService`] running filter `F`.
+///
+/// Created by [`cluster_filter`]; every call to `layer` hands the new service a clone of `F`.
+#[derive(Clone, Copy, Debug)]
+pub struct ClusterFilterLayer<F> {
+    /// The filter cloned into each service this layer produces.
+    filter: F,
+}
+
+/// Wraps `service` so that each produced service owns an independent clone of the filter.
+impl<S, F: Clone> Layer<S> for ClusterFilterLayer<F> {
+    type Service = ClusterFilterService<S, F>;
+    fn layer(&self, service: S) -> Self::Service {
+        ClusterFilterService::new(service, F::clone(&self.filter))
+    }
+}
+
+/// [`Service`] that runs filter `CF` on each request's metadata before calling the inner `S`.
+///
+/// When the filter returns `Err`, its [`Status`] is turned into the response and `S` is skipped.
+#[derive(Copy, Clone)]
+pub struct ClusterFilterService<S, CF> {
+    inner: S,   // the wrapped (downstream) service
+    filter: CF, // applied to every request before `inner`
+}
+
+impl<S, CF> ClusterFilterService<S, CF> {
+    /// Pairs the downstream service `inner` with the `filter` that runs before it.
+    fn new(inner: S, filter: CF) -> Self {
+        ClusterFilterService { filter, inner }
+    }
+}
+
+/// `CF` has no `Debug` bound (closure filters, for instance, cannot provide one), so only the
+/// filter's type name is printed; `inner` is formatted with its own `Debug` impl.
+impl<S: fmt::Debug, CF> fmt::Debug for ClusterFilterService<S, CF> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let mut out = f.debug_struct("ClusterFilterService");
+        out.field("inner", &self.inner);
+        out.field("filter", &format_args!("{}", std::any::type_name::<CF>()));
+        out.finish()
+    }
+}
+
+/// Runs the [`ClusterFilter`] on the request metadata and forwards the rebuilt request to the
+/// inner [`Service`]; if the filter rejects the request, its [`Status`] becomes the response.
+impl<S, CF, ReqBody, RespBody> Service<http::Request<ReqBody>> for ClusterFilterService<S, CF>
+where
+    CF: ClusterFilter,
+    S: Service<http::Request<ReqBody>, Response = http::Response<RespBody>>,
+    S::Error: Into<Error>,
+    ReqBody: http_body::Body<Data = Bytes> + Send + 'static,
+    ReqBody::Error: Into<Error>,
+    RespBody: http_body::Body<Data = Bytes> + Send + 'static,
+    RespBody::Error: Into<Error>,
+{
+    type Response = http::Response<BoxBody>;
+    type Error = S::Error;
+    type Future = ResponseFuture<S::Future>;
+
+    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
+        // The filter only sees the metadata built from the headers: the body is split off here
+        // so it is never exposed to the filter. Dubbo's `Request` keeps none of the URI, method,
+        // version or extensions, so those fields are moved out of `parts` and restored below.
+        let (parts, body) = req.into_parts();
+        let metadata = Metadata::from_headers(parts.headers);
+        match self.filter.call(Request::from_parts(metadata, ())) {
+            Ok(filtered) => {
+                let (metadata, _) = filtered.into_parts();
+                let req = Request::from_parts(metadata, body);
+                let mut http_req = req.into_http(parts.uri, parts.method, parts.version);
+                // Carry over extensions inserted by outer layers so `inner` can still read them.
+                // NOTE(review): assumes `into_http` starts from empty extensions — confirm.
+                *http_req.extensions_mut() = parts.extensions;
+                ResponseFuture::future(self.inner.call(http_req))
+            }
+            Err(status) => ResponseFuture::status(status),
+        }
+    }
+}
+
+/// Future returned by [`ClusterFilterService`]: it either drives the inner service's future or
+/// resolves immediately with the [`Status`] produced by a rejecting filter.
+#[pin_project]
+#[derive(Debug)]
+pub struct ResponseFuture<F> {
+    #[pin]
+    kind: Kind<F>,
+}
+
+#[pin_project(project = KindProj)]
+#[derive(Debug)]
+enum Kind<F> {
+    Future(#[pin] F),
+    Status(Option<Status>), // `Option` so `poll` can move the status out exactly once
+}
+
+impl<F> ResponseFuture<F> {
+    /// Wraps the inner service's future; its response body is boxed when polled.
+    fn future(future: F) -> Self {
+        let kind = Kind::Future(future);
+        Self { kind }
+    }
+
+    /// Builds an already-decided future that yields `status` converted into a response.
+    fn status(status: Status) -> Self {
+        let kind = Kind::Status(Some(status));
+        Self { kind }
+    }
+}
+
+/// Yields the inner response (body boxed) or the filter's [`Status`] rendered as a response.
+impl<F, RespBody, E> Future for ResponseFuture<F>
+where
+    F: Future<Output = Result<http::Response<RespBody>, E>>,
+    E: Into<Error>,
+    RespBody: http_body::Body<Data = Bytes> + Send + 'static,
+    RespBody::Error: Into<Error>,
+{
+    type Output = Result<http::Response<BoxBody>, E>;
+    fn poll(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Self::Output> {
+        match self.project().kind.project() {
+            KindProj::Future(fut) => fut.poll(cx).map(|res| res.map(|resp| resp.map(boxed))),
+            KindProj::Status(status) => {
+                // `take` leaves `None` behind, so a second poll after `Ready` hits the `expect`.
+                let status = status.take().expect("ResponseFuture polled after completion");
+                Poll::Ready(Ok(status.to_http().map(boxed)))
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    //! Unit tests for `ClusterFilterService` and the `cluster_filter` layer.
+    use tower::service_fn;
+    use tower::ServiceBuilder;
+    use tower::ServiceExt;
+
+    use crate::empty_body;
+    use crate::status::Code;
+    use crate::status::Status;
+    use crate::BoxBody;
+
+    use super::*;
+
+    const USER_AGENT: &str = "user-agent";
+    const USER_AGENT_VAL: &str = "dubbo-test";
+
+    /// Pass-through filter that asserts the `user-agent` metadata reached it.
+    #[derive(Clone, Copy)]
+    struct MyFilter;
+
+    impl ClusterFilter for MyFilter {
+        fn call(&mut self, req: Request<()>) -> Result<Request<()>, Status> {
+            let user_agent = req.metadata().get_ref().get(USER_AGENT);
+            let user_agent = user_agent.expect("missing user-agent.");
+            assert_eq!(user_agent, USER_AGENT_VAL);
+            Ok(req)
+        }
+    }
+
+    /// Inner service asserting that the `user-agent` header survived the filter.
+    fn user_agent_checker(
+    ) -> impl Service<http::Request<BoxBody>, Response = http::Response<BoxBody>, Error = Status> {
+        service_fn(|req: http::Request<BoxBody>| async move {
+            let user_agent = req.headers().get(USER_AGENT).expect("missing user-agent.");
+            assert_eq!(user_agent, USER_AGENT_VAL);
+            Ok::<_, Status>(http::Response::new(empty_body()))
+        })
+    }
+
+    /// Request carrying the `user-agent` header that both `MyFilter` and the checker expect.
+    fn request_with_user_agent() -> http::Request<BoxBody> {
+        http::Request::builder()
+            .header(USER_AGENT, USER_AGENT_VAL)
+            .body(empty_body())
+            .unwrap()
+    }
+
+    /// A pass-through filter applied directly must leave the headers intact.
+    #[tokio::test]
+    async fn doesnt_change_anything() {
+        let inner = user_agent_checker();
+        let svc = ClusterFilterService::new(inner, MyFilter);
+        svc.oneshot(request_with_user_agent()).await.unwrap();
+    }
+
+    /// Same check, with the filter installed through `cluster_filter` and `ServiceBuilder`.
+    #[tokio::test]
+    async fn add_cluster_filter_to_service() {
+        let svc = ServiceBuilder::new()
+            .layer(cluster_filter(MyFilter))
+            .service(user_agent_checker());
+        svc.oneshot(request_with_user_agent()).await.unwrap();
+    }
+
+    /// A rejecting filter's `Status` must come back as the HTTP response.
+    #[tokio::test]
+    async fn handle_status_as_response() {
+        let msg = "PermissionDenied from ClusterFilter.";
+        let expected = Status::new(Code::PermissionDenied, msg).to_http();
+
+        // Not expected to run: the filter below rejects every request.
+        let svc = service_fn(|_: http::Request<BoxBody>| async move {
+            Ok::<_, Status>(http::Response::new(empty_body()))
+        });
+        let reject = |_: Request<()>| -> Result<Request<()>, Status> {
+            Err(Status::new(Code::PermissionDenied, msg))
+        };
+        let resp = ClusterFilterService::new(svc, reject)
+            .oneshot(http::Request::builder().body(empty_body()).unwrap())
+            .await
+            .unwrap();
+
+        // Compare the observable parts of the responses; bodies are not compared.
+        assert_eq!(resp.status(), expected.status());
+        assert_eq!(resp.headers(), expected.headers());
+        assert_eq!(resp.version(), expected.version());
+    }
+}
diff --git a/dubbo/src/filter/mod.rs b/dubbo/src/filter/mod.rs
index 075781a..dece123 100644
--- a/dubbo/src/filter/mod.rs
+++ b/dubbo/src/filter/mod.rs
@@ -14,11 +14,85 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
-pub mod service;
+//! Filters which can preprocess or postprocess requests.
+//!
+//! TODO: Add the `ClusterFilter` according to the name.
+//!
+//! # Example
+//! ## ClusterFilter
+//! Two chained filters: `MyFilter1` (the outer layer, so it runs first) sets the `user-agent`
+//! metadata, and `MyFilter2` checks it and overwrites it before the request reaches the service.
+//! ```ignore
+//! use dubbo::filter::{clusterfilter::cluster_filter, ClusterFilter};
+//! use dubbo::invocation::Request;
+//! use dubbo::status::Status;
+//! use dubbo::{empty_body, BoxBody};
+//! use tower::{service_fn, ServiceBuilder, ServiceExt};
+//!
+//! const USER_AGENT: &str = "user-agent";
+//! const USER_AGENT_VAL: &str = "dubbo-test";
+//! const USER_AGENT_VAL_2: &str = "dubbo-test-2";
+//!
+//! #[derive(Clone, Copy)]
+//! struct MyFilter1;
+//!
+//! #[derive(Clone, Copy)]
+//! struct MyFilter2;
+//!
+//! impl ClusterFilter for MyFilter1 {
+//!     fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, Status> {
+//!         req.metadata_mut().get_mut().insert(
+//!             USER_AGENT.to_string(),
+//!             USER_AGENT_VAL.to_string(),
+//!         );
+//!         Ok(req)
+//!     }
+//! }
+//!
+//! impl ClusterFilter for MyFilter2 {
+//!     fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, Status> {
+//!         assert_eq!(
+//!             req.metadata()
+//!                 .get_ref()
+//!                 .get(USER_AGENT)
+//!                 .expect("missing user-agent."),
+//!             USER_AGENT_VAL
+//!         );
+//!         req.metadata_mut().get_mut().insert(
+//!             USER_AGENT.to_string(),
+//!             USER_AGENT_VAL_2.to_string(),
+//!         );
+//!         Ok(req)
+//!     }
+//! }
+//!
+//! #[tokio::main]
+//! async fn main() {
+//!     let svc = service_fn(|req: http::Request<BoxBody>| async move {
+//!         assert_eq!(
+//!             req.headers().get(USER_AGENT).map(|v| v.to_str().unwrap()),
+//!             Some(USER_AGENT_VAL_2)
+//!         );
+//!         Ok::<_, Status>(http::Response::new(empty_body()))
+//!     });
+//!     let svc = ServiceBuilder::new()
+//!         .layer(cluster_filter(MyFilter1))
+//!         .layer(cluster_filter(MyFilter2))
+//!         .service(svc);
+//!     let req = http::Request::builder()
+//!         .body(empty_body())
+//!         .unwrap();
+//!     svc.oneshot(req).await.unwrap();
+//! }
+//! ```
use crate::invocation::Request;
+pub mod clusterfilter;
+pub mod service;
+
+/// General-purpose request filter; its `call` has the same signature as [`ClusterFilter::call`].
+///
+/// TODO: Implement it.
pub trait Filter {
fn call(&mut self, req: Request<()>) -> Result<Request<()>, crate::status::Status>;
}
+
+/// A filter that **preprocesses** requests before they reach an inner [`Service`].
+///
+/// The filter only sees the request metadata (a `Request<()>` whose metadata is built from the
+/// HTTP headers); the message body is never exposed to it. Returning `Ok` forwards the (possibly
+/// modified) request, while returning `Err` short-circuits: the inner service is not called and
+/// the [`Status`](crate::status::Status) is converted into the response.
+///
+/// Filters are installed through a tower [`Layer`] (see `clusterfilter::cluster_filter`), so they
+/// compose with the other layers provided by [`tower-http`]. Any
+/// `FnMut(Request<()>) -> Result<Request<()>, Status>` closure implements this trait.
+///
+/// [`tower-http`]: https://docs.rs/tower-http/latest/tower_http/
+/// [`Service`]: https://docs.rs/tower/latest/tower/trait.Service.html
+/// [`Layer`]: https://docs.rs/tower/latest/tower/trait.Layer.html
+pub trait ClusterFilter {
+ /// Inspects or rewrites the metadata of `req`, or rejects the request with a status.
+ fn call(&mut self, req: Request<()>) -> Result<Request<()>, crate::status::Status>;
+}
diff --git a/dubbo/src/invocation.rs b/dubbo/src/invocation.rs
index 5a80e9a..87781a4 100644
--- a/dubbo/src/invocation.rs
+++ b/dubbo/src/invocation.rs
@@ -31,6 +31,26 @@
}
}
+    /// Returns a shared reference to the wrapped message.
+    pub fn get_ref(&self) -> &T {
+        &self.message
+    }
+
+    /// Returns a mutable reference to the wrapped message so it can be edited in place.
+    pub fn get_mut(&mut self) -> &mut T {
+        &mut self.message
+    }
+
+    /// Returns a shared reference to the request metadata.
+    pub fn metadata(&self) -> &Metadata {
+        &self.metadata
+    }
+
+    /// Returns a mutable reference to the request metadata, e.g. to insert entries from a filter.
+    pub fn metadata_mut(&mut self) -> &mut Metadata {
+        &mut self.metadata
+    }
+
pub fn into_inner(self) -> T {
self.message
}
@@ -166,6 +186,16 @@
}
}
+    /// Returns a shared reference to the underlying key/value map.
+    pub fn get_ref(&self) -> &HashMap<String, String> {
+        &self.inner
+    }
+
+    /// Returns a mutable reference to the underlying key/value map so entries can be edited.
+    pub fn get_mut(&mut self) -> &mut HashMap<String, String> {
+        &mut self.inner
+    }
+
pub fn from_headers(headers: http::HeaderMap) -> Self {
let mut h: HashMap<String, String> = HashMap::new();
for (k, v) in headers.into_iter() {
diff --git a/dubbo/src/status.rs b/dubbo/src/status.rs
index 926c1d6..68e15a4 100644
--- a/dubbo/src/status.rs
+++ b/dubbo/src/status.rs
@@ -262,12 +262,18 @@
}
impl Status {
- pub fn new(code: Code, message: String) -> Self {
- Status { code, message }
+    /// Creates a status from `code` and any message convertible into a `String`
+    /// (`&str`, `String`, ...), so callers can pass string literals directly.
+    pub fn new(code: Code, message: impl Into<String>) -> Self {
+        let message = message.into();
+        Status { code, message }
}
- pub fn with_message(self, message: String) -> Self {
- Status { message, ..self }
+    /// Returns `self` with its message replaced by `message`; the code is left unchanged.
+    /// Like [`Status::new`], `message` accepts any `Into<String>` value.
+    pub fn with_message(self, message: impl Into<String>) -> Self {
+        let message = message.into();
+        Status { message, ..self }
}
pub fn from_std_erro<T: std::error::Error>(err: T) -> Self {
diff --git a/examples/echo/Cargo.toml b/examples/echo/Cargo.toml
index 9e794b0..b968de8 100644
--- a/examples/echo/Cargo.toml
+++ b/examples/echo/Cargo.toml
@@ -25,6 +25,7 @@
dubbo = {path = "../../dubbo", version = "0.2.0"}
dubbo-config = {path = "../../config", version = "0.2.0"}
+tower = "0.4.13"
[build-dependencies]
dubbo-build = {path = "../../dubbo-build", version = "0.2.0"}
diff --git a/examples/echo/src/echo/server.rs b/examples/echo/src/echo/server.rs
index 2be45b0..6b01225 100644
--- a/examples/echo/src/echo/server.rs
+++ b/examples/echo/src/echo/server.rs
@@ -19,6 +19,8 @@
use std::pin::Pin;
use async_trait::async_trait;
+use dubbo::filter::clusterfilter::cluster_filter;
+use dubbo::filter::ClusterFilter;
use futures_util::Stream;
use futures_util::StreamExt;
use tokio::sync::mpsc;
@@ -31,6 +33,7 @@
echo_server::{register_server, Echo, EchoServer},
EchoRequest, EchoResponse,
};
+use tower::ServiceBuilder;
type ResponseStream =
Pin<Box<dyn Stream<Item = Result<EchoResponse, dubbo::status::Status>> + Send>>;
@@ -50,14 +53,16 @@
register_server(EchoServerImpl {
name: "echo".to_string(),
});
- let server = EchoServerImpl::default();
- let s = EchoServer::<EchoServerImpl>::with_filter(server, FakeFilter {});
+ let server = EchoServer::new(EchoServerImpl::default());
+ let server = ServiceBuilder::new()
+ .layer(cluster_filter(EchoClusterFilter))
+ .service(server);
dubbo::protocol::triple::TRIPLE_SERVICES
.write()
.unwrap()
.insert(
"grpc.examples.echo.Echo".to_string(),
- dubbo::utils::boxed_clone::BoxCloneService::new(s),
+ dubbo::utils::boxed_clone::BoxCloneService::new(server),
);
// Dubbo::new().start().await;
@@ -208,3 +213,17 @@
};
}
}
+
+/// Demo filter for the echo server: logs every call and stamps an `EchoClusterFilter` entry
+/// into the request metadata before the request is forwarded.
+#[derive(Clone, Copy)]
+struct EchoClusterFilter;
+
+impl ClusterFilter for EchoClusterFilter {
+    fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, dubbo::status::Status> {
+        const TAG: &str = "EchoClusterFilter";
+        println!("{}", TAG);
+        req.metadata_mut().get_mut().insert(TAG.to_string(), TAG.to_string());
+        Ok(req)
+    }
+}