blob: b84469eabecdd8a03e975698226b583c80bd5966 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use std::collections::HashMap;
use async_trait::async_trait;
use mockall::automock;
use slog::{info, o, Logger};
use tokio::sync::Mutex;
use tonic::metadata::AsciiMetadataValue;
use tonic::transport::{Channel, Endpoint};
use crate::conf::ClientOption;
use crate::error::ErrorKind;
use crate::model::common::Endpoints;
use crate::pb::{QueryRouteRequest, QueryRouteResponse, SendMessageRequest, SendMessageResponse};
use crate::{error::ClientError, pb::messaging_service_client::MessagingServiceClient};
/// RPC operations a session exposes to the rest of the client.
///
/// `#[automock]` generates a mock implementation of this trait for use in tests.
// NOTE(review): mockall's documentation requires `#[automock]` to appear *before*
// `#[async_trait]`. Here the order is reversed, so mockall presumably sees the signatures
// after `async_trait` has rewritten them (boxed futures). Confirm the generated mock is
// what callers expect before relying on it; swapping the order changes the mock's API.
#[async_trait]
#[automock]
pub(crate) trait RPCClient {
    /// Operation label used when reporting `query_route` failures.
    const OPERATION_QUERY_ROUTE: &'static str = "rpc.query_route";
    /// Operation label used when reporting `send_message` failures.
    const OPERATION_SEND_MESSAGE: &'static str = "rpc.send_message";
    /// Sends a `QueryRouteRequest` and returns the server's `QueryRouteResponse`.
    async fn query_route(
        &mut self,
        request: QueryRouteRequest,
    ) -> Result<QueryRouteResponse, ClientError>;
    /// Sends a `SendMessageRequest` and returns the server's `SendMessageResponse`.
    async fn send_message(
        &mut self,
        request: SendMessageRequest,
    ) -> Result<SendMessageResponse, ClientError>;
}
/// A gRPC client bound to one set of endpoints, tagging each outgoing request with
/// this client's identity headers.
// NOTE(review): `#[allow(dead_code)]` has no stated reason. In this module the `logger`
// field is never read, which is presumably why it is here; confirm before removing it.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub(crate) struct Session {
    /// Logger tagged with `component = "session"` and the peer URL.
    logger: Logger,
    /// Client identifier, sent as the `x-mq-client-id` header when it is valid ASCII.
    client_id: String,
    /// Generated gRPC stub used to issue the RPCs.
    stub: MessagingServiceClient<Channel>,
}
impl Session {
const OPERATION_CREATE: &'static str = "session.create_session";
const HTTP_SCHEMA: &'static str = "http";
const HTTPS_SCHEMA: &'static str = "https";
async fn new(
logger: &Logger,
endpoints: &Endpoints,
client_id: String,
option: &ClientOption,
) -> Result<Self, ClientError> {
let peer = endpoints.endpoint_url().to_owned();
let mut channel_endpoints = Vec::new();
for endpoint in endpoints.inner().addresses.clone() {
channel_endpoints.push(Self::build_endpoint(endpoint.host, endpoint.port, option)?);
}
if channel_endpoints.is_empty() {
return Err(ClientError::new(
ErrorKind::Connect,
"No endpoint available.",
Self::OPERATION_CREATE,
)
.with_context("peer", peer.clone()));
}
let channel = if channel_endpoints.len() == 1 {
channel_endpoints[0].connect().await.map_err(|e| {
ClientError::new(
ErrorKind::Connect,
"Failed to connect to peer.",
Self::OPERATION_CREATE,
)
.set_source(e)
.with_context("peer", peer.clone())
})?
} else {
Channel::balance_list(channel_endpoints.into_iter())
};
let stub = MessagingServiceClient::new(channel);
info!(
logger,
"create session success, peer={}",
endpoints.endpoint_url()
);
Ok(Session {
logger: logger.new(o!("component" => "session", "peer" => peer.clone())),
client_id,
stub,
})
}
fn build_endpoint(
host: String,
port: i32,
option: &ClientOption,
) -> Result<Endpoint, ClientError> {
let url = if option.enable_tls() {
format!("{}://{}:{}", Self::HTTPS_SCHEMA, host, port)
} else {
format!("{}://{}:{}", Self::HTTP_SCHEMA, host, port)
};
let endpoint = Endpoint::from_shared(url.clone())
.map_err(|e| {
ClientError::new(
ErrorKind::Connect,
"Failed to create channel endpoint.",
Self::OPERATION_CREATE,
)
.set_source(e)
.with_context("peer", url)
})?
// TODO tls config
// .tls_config(tls)
// .map_err(|e| {
// ClientError::new(
// ErrorKind::Connect,
// "Failed to configure TLS.".to_string(),
// OPERATION,
// )
// .set_source(e)
// .with_context("peer", &peer_addr)
// })?
.connect_timeout(std::time::Duration::from_secs(3))
.tcp_nodelay(true);
Ok(endpoint)
}
fn sign(&self, metadata: &mut tonic::metadata::MetadataMap) {
let _ = AsciiMetadataValue::try_from(&self.client_id)
.map(|v| metadata.insert("x-mq-client-id", v));
metadata.insert("x-mq-language", AsciiMetadataValue::from_static("RUST"));
metadata.insert(
"x-mq-client-version",
AsciiMetadataValue::from_static("5.0.0"),
);
metadata.insert(
"x-mq-protocol-version",
AsciiMetadataValue::from_static("2.0.0"),
);
}
}
#[async_trait]
impl RPCClient for Session {
    /// Signs and sends a route query, returning the decoded response body.
    ///
    /// Any gRPC failure becomes an `ErrorKind::ClientInternal` error carrying the
    /// returned status as its source.
    async fn query_route(
        &mut self,
        request: QueryRouteRequest,
    ) -> Result<QueryRouteResponse, ClientError> {
        let mut rpc_request = tonic::Request::new(request);
        self.sign(rpc_request.metadata_mut());
        self.stub
            .query_route(rpc_request)
            .await
            .map(tonic::Response::into_inner)
            .map_err(|status| {
                ClientError::new(
                    ErrorKind::ClientInternal,
                    "Query topic route rpc failed.",
                    Self::OPERATION_QUERY_ROUTE,
                )
                .set_source(status)
            })
    }

    /// Signs and sends a message, returning the decoded response body.
    ///
    /// Any gRPC failure becomes an `ErrorKind::ClientInternal` error carrying the
    /// returned status as its source.
    async fn send_message(
        &mut self,
        request: SendMessageRequest,
    ) -> Result<SendMessageResponse, ClientError> {
        let mut rpc_request = tonic::Request::new(request);
        self.sign(rpc_request.metadata_mut());
        self.stub
            .send_message(rpc_request)
            .await
            .map(tonic::Response::into_inner)
            .map_err(|status| {
                ClientError::new(
                    ErrorKind::ClientInternal,
                    "Send message rpc failed.",
                    Self::OPERATION_SEND_MESSAGE,
                )
                .set_source(status)
            })
    }
}
/// Creates sessions on demand and caches one per endpoints URL.
#[derive(Debug)]
pub(crate) struct SessionManager {
    /// Logger tagged with `component = "session_manager"`.
    logger: Logger,
    /// Client id handed to every session this manager creates.
    client_id: String,
    /// Options used when creating new sessions.
    option: ClientOption,
    /// Sessions keyed by `Endpoints::endpoint_url()`. This is an async (tokio) mutex
    /// because the lock is held across session creation in `get_session`.
    session_map: Mutex<HashMap<String, Session>>,
}
impl SessionManager {
    /// Creates a manager with an empty session cache.
    ///
    /// Sessions are created lazily by `get_session`.
    pub(crate) fn new(logger: &Logger, client_id: String, option: &ClientOption) -> Self {
        SessionManager {
            logger: logger.new(o!("component" => "session_manager")),
            client_id,
            option: option.clone(),
            session_map: Mutex::new(HashMap::new()),
        }
    }

    /// Returns the session for `endpoints`, creating and caching it on first use.
    ///
    /// Sessions are keyed by `endpoints.endpoint_url()`. The map lock stays held while a
    /// new session is created, so concurrent callers for a missing key wait instead of
    /// creating duplicates. A failed creation is not cached.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating a new session.
    pub(crate) async fn get_session(&self, endpoints: &Endpoints) -> Result<Session, ClientError> {
        let mut session_map = self.session_map.lock().await;
        let endpoint_url = endpoints.endpoint_url().to_string();

        // Single lookup on the hit path (previously `contains_key` + `get().unwrap()`).
        if let Some(session) = session_map.get(&endpoint_url) {
            return Ok(session.clone());
        }

        let session = Session::new(
            &self.logger,
            endpoints,
            self.client_id.clone(),
            &self.option,
        )
        .await?;
        session_map.insert(endpoint_url, session.clone());
        Ok(session)
    }
}
#[cfg(test)]
mod tests {
    use crate::log::terminal_logger;
    use slog::debug;
    use wiremock_grpc::generate;

    use super::*;

    generate!("apache.rocketmq.v2", RocketMQMockServer);

    /// A single reachable address is connected eagerly, so creation must succeed.
    #[tokio::test]
    async fn session_new() {
        let server = RocketMQMockServer::start_default().await;
        let logger = terminal_logger();
        let session = Session::new(
            &logger,
            &Endpoints::from_url(&format!("localhost:{}", server.address().port())).unwrap(),
            "test_client".to_string(),
            &ClientOption::default(),
        )
        .await;
        debug!(logger, "session: {:?}", session);
        assert!(session.is_ok(), "session creation against the mock server failed");
    }

    /// Several addresses are combined into a balanced channel without connecting,
    /// so creation succeeds even though nothing listens on these ports.
    #[tokio::test]
    async fn session_new_multi_addr() {
        let logger = terminal_logger();
        let session = Session::new(
            &logger,
            &Endpoints::from_url("127.0.0.1:8080,127.0.0.1:8081").unwrap(),
            "test_client".to_string(),
            &ClientOption::default(),
        )
        .await;
        debug!(logger, "session: {:?}", session);
        assert!(session.is_ok(), "multi-address session creation failed");
    }

    /// The manager creates a session on first request.
    #[tokio::test]
    async fn session_manager_new() {
        let server = RocketMQMockServer::start_default().await;
        let logger = terminal_logger();
        let session_manager =
            SessionManager::new(&logger, "test_client".to_string(), &ClientOption::default());
        let session = session_manager
            .get_session(
                &Endpoints::from_url(&format!("localhost:{}", server.address().port())).unwrap(),
            )
            .await
            .unwrap();
        debug!(logger, "session: {:?}", session);
    }
}