blob: 0b0bd851ac8ef0ad7773ed7596081364ed46764b [file] [log] [blame]
use crate::pb::{
messaging_service_client::MessagingServiceClient, QueryRouteRequest, QueryRouteResponse,
SendMessageRequest, SendMessageResponse,
};
use rustls::client;
use tonic::{
metadata::MetadataMap,
transport::{Channel, ClientTlsConfig},
Request, Response,
};
use crate::credentials::CredentialProvider;
use std::collections::HashMap;
use std::rc::Rc;
use std::{
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex, RwLock,
},
thread,
};
static CLIENT_SEQUENCE: AtomicUsize = AtomicUsize::new(0);
pub struct ClientConfig {
region: String,
service_name: String,
resource_namespace: Option<String>,
credential_provider: Option<Box<dyn CredentialProvider>>,
tenant_id: Option<String>,
connect_timeout: std::time::Duration,
io_timeout: std::time::Duration,
long_polling_timeout: std::time::Duration,
group: Option<String>,
client_id: String,
tracing: bool,
}
fn build_client_id() -> String {
let mut client_id = String::new();
match gethostname::gethostname().into_string() {
Ok(hostname) => {
client_id.push_str(&hostname);
}
Err(_) => {
client_id.push_str("localhost");
}
};
client_id.push('@');
let pid = std::process::id();
client_id.push_str(&pid.to_string());
client_id.push('#');
let sequence = CLIENT_SEQUENCE.fetch_add(1usize, Ordering::Relaxed);
client_id.push_str(&sequence.to_string());
client_id
}
impl Default for ClientConfig {
fn default() -> Self {
let client_id = build_client_id();
Self {
region: String::from("cn-hangzhou"),
service_name: String::from("RocketMQ"),
resource_namespace: None,
credential_provider: None,
tenant_id: None,
connect_timeout: std::time::Duration::from_secs(3),
io_timeout: std::time::Duration::from_secs(3),
long_polling_timeout: std::time::Duration::from_secs(3),
group: None,
client_id,
tracing: false,
}
}
}
pub struct RpcClient {
client_config: Arc<RwLock<ClientConfig>>,
stub: MessagingServiceClient<Channel>,
peer_address: String,
// client_config: std::rc::Rc<ClientConfig>,
}
impl RpcClient {
pub async fn new(
target: String,
client_config: Arc<RwLock<ClientConfig>>,
) -> Result<RpcClient, Box<dyn std::error::Error>> {
let config = Arc::clone(&client_config);
let mut channel = Channel::from_shared(target.clone())?
.tcp_nodelay(true)
.connect_timeout(config.read().unwrap().connect_timeout);
if target.starts_with("https://") {
channel = channel.tls_config(ClientTlsConfig::new())?;
}
let channel = channel.connect().await?;
let stub = MessagingServiceClient::new(channel);
Ok(RpcClient {
client_config,
stub,
peer_address: target,
})
}
fn add_metadata(meta: &mut MetadataMap) {}
pub async fn query_route(
&mut self,
request: QueryRouteRequest,
) -> Result<Response<QueryRouteResponse>, Box<dyn std::error::Error>> {
let mut req = Request::new(request);
RpcClient::add_metadata(req.metadata_mut());
Ok(self.stub.query_route(req).await?)
}
pub async fn send_message(
&mut self,
request: SendMessageRequest,
) -> Result<Response<SendMessageResponse>, Box<dyn std::error::Error>> {
let req = Request::new(request);
Ok(self.stub.send_message(req).await?)
}
}
#[derive(Default)]
pub struct ClientManager {
client_config: Arc<RwLock<ClientConfig>>,
clients: Mutex<HashMap<String, Rc<RpcClient>>>,
}
impl ClientManager {
pub async fn start(&self) {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(100));
let handle = tokio::spawn(async move {
for i in 0..3 {
interval.tick().await;
println!("Tick");
}
});
let _result = handle.await;
}
pub async fn get_rpc_client(
&'static mut self,
endpoint: &str,
) -> Result<Rc<RpcClient>, Box<dyn std::error::Error>> {
let mut rpc_clients = self.clients.lock()?;
let key = endpoint.to_owned();
match rpc_clients.get(&key) {
Some(value) => {
return Ok(Rc::clone(value));
}
None => {
let rpc_client =
RpcClient::new(key.clone(), Arc::clone(&self.client_config)).await?;
let client = Rc::new(rpc_client);
rpc_clients.insert(key, Rc::clone(&client));
Ok(client)
}
}
}
}
#[cfg(test)]
mod test {
use super::*;
use std::collections::HashSet;
use crate::pb::{Code, Resource};
#[tokio::test]
async fn test_connect() {
let target = "http://127.0.0.1:5001";
let client_config = Arc::new(RwLock::new(ClientConfig::default()));
let _rpc_client = RpcClient::new(target.to_owned(), client_config)
.await
.expect("Should be able to connect");
}
#[tokio::test]
async fn test_connect_staging() {
let client_config = Arc::new(RwLock::new(ClientConfig::default()));
let target = "https://mq-inst-1080056302921134-bxuibml7.mq.cn-hangzhou.aliyuncs.com:80";
let _rpc_client = RpcClient::new(target.to_owned(), client_config)
.await
.expect("Failed to connect to staging proxy server");
}
#[tokio::test]
async fn test_query_route() {
let target = "http://127.0.0.1:5001";
let client_config = Arc::new(RwLock::new(ClientConfig::default()));
let mut rpc_client = RpcClient::new(target.to_owned(), client_config)
.await
.expect("Should be able to connect");
let topic = Resource {
resource_namespace: String::from("arn"),
name: String::from("TestTopic"),
};
let request = QueryRouteRequest {
topic: Some(topic),
endpoints: None,
};
let reply = rpc_client
.query_route(request)
.await
.expect("Failed to query route");
let route_response = reply.into_inner();
assert_eq!(route_response.status.unwrap().code, Code::Ok as i32);
}
#[test]
fn test_build_client_id() {
let mut set = HashSet::new();
let cnt = 1000;
for _ in 0..cnt {
let client_id = build_client_id();
set.insert(client_id);
}
assert_eq!(cnt, set.len());
}
#[tokio::test]
async fn test_periodic_task() {
let client_manager = ClientManager::default();
client_manager.start().await;
}
}