Split tracer inner segment sender and receiver into traits. (#37)
diff --git a/src/context/tracer.rs b/src/context/tracer.rs
index 9e14052..c3f7abb 100644
--- a/src/context/tracer.rs
+++ b/src/context/tracer.rs
@@ -19,13 +19,14 @@
context::trace_context::TracingContext, reporter::DynReporter, reporter::Reporter,
skywalking_proto::v3::SegmentObject,
};
+use std::error::Error;
use std::future::Future;
use std::pin::Pin;
+use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Weak;
use std::task::{Context, Poll};
use std::{collections::LinkedList, sync::Arc};
use tokio::sync::OnceCell;
-use tokio::task::JoinError;
use tokio::{
sync::{
mpsc::{self},
@@ -33,6 +34,7 @@
},
task::JoinHandle,
};
+use tonic::async_trait;
static GLOBAL_TRACER: OnceCell<Tracer> = OnceCell::const_new();
@@ -65,12 +67,67 @@
global_tracer().reporting(shutdown_signal)
}
+pub trait SegmentSender: Send + Sync + 'static {
+ fn send(&self, segment: SegmentObject) -> Result<(), Box<dyn Error>>;
+}
+
+impl SegmentSender for () {
+ fn send(&self, _segment: SegmentObject) -> Result<(), Box<dyn Error>> {
+ Ok(())
+ }
+}
+
+impl SegmentSender for mpsc::UnboundedSender<SegmentObject> {
+ fn send(&self, segment: SegmentObject) -> Result<(), Box<dyn Error>> {
+ Ok(self.send(segment)?)
+ }
+}
+
+#[async_trait]
+pub trait SegmentReceiver: Send + Sync + 'static {
+ async fn recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>>;
+
+ async fn try_recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>>;
+}
+
+#[async_trait]
+impl SegmentReceiver for () {
+ async fn recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>> {
+ Ok(None)
+ }
+
+ async fn try_recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>> {
+ Ok(None)
+ }
+}
+
+#[async_trait]
+impl SegmentReceiver for Mutex<mpsc::UnboundedReceiver<SegmentObject>> {
+ async fn recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>> {
+ Ok(self.lock().await.recv().await)
+ }
+
+ async fn try_recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>> {
+ use mpsc::error::TryRecvError;
+
+ match self.lock().await.try_recv() {
+ Ok(segment) => Ok(Some(segment)),
+ Err(e) => match e {
+ TryRecvError::Empty => Ok(None),
+ TryRecvError::Disconnected => Err(Box::new(e)),
+ },
+ }
+ }
+}
+
struct Inner {
service_name: String,
instance_name: String,
- segment_sender: mpsc::UnboundedSender<SegmentObject>,
- segment_receiver: Mutex<mpsc::UnboundedReceiver<SegmentObject>>,
+ segment_sender: Box<dyn SegmentSender>,
+ segment_receiver: Box<dyn SegmentReceiver>,
reporter: Box<Mutex<DynReporter>>,
+ is_reporting: AtomicBool,
+ is_closed: AtomicBool,
}
/// Skywalking tracer.
@@ -87,14 +144,30 @@
reporter: impl Reporter + Send + Sync + 'static,
) -> Self {
let (segment_sender, segment_receiver) = mpsc::unbounded_channel();
+ Self::new_with_channel(
+ service_name,
+ instance_name,
+ reporter,
+ (segment_sender, Mutex::new(segment_receiver)),
+ )
+ }
+ /// New with service info, reporter, and custom channel.
+ pub fn new_with_channel(
+ service_name: impl ToString,
+ instance_name: impl ToString,
+ reporter: impl Reporter + Send + Sync + 'static,
+ channel: (impl SegmentSender, impl SegmentReceiver),
+ ) -> Self {
Self {
inner: Arc::new(Inner {
service_name: service_name.to_string(),
instance_name: instance_name.to_string(),
- segment_sender,
- segment_receiver: Mutex::new(segment_receiver),
+ segment_sender: Box::new(channel.0),
+ segment_receiver: Box::new(channel.1),
reporter: Box::new(Mutex::new(reporter)),
+ is_reporting: Default::default(),
+ is_closed: Default::default(),
}),
}
}
@@ -131,73 +204,97 @@
/// Finalize the trace context.
pub(crate) fn finalize_context(&self, context: &mut TracingContext) {
+ if self.inner.is_closed.load(Ordering::Relaxed) {
+ tracing::warn!("tracer closed");
+ return;
+ }
+
let segment_object = context.convert_segment_object();
- if self.inner.segment_sender.send(segment_object).is_err() {
- tracing::error!("segment object channel has closed");
+ if let Err(err) = self.inner.segment_sender.send(segment_object) {
+ tracing::error!(?err, "send segment object failed");
}
}
/// Start to reporting, quit when shutdown_signal received.
///
/// Accept a `shutdown_signal` argument as a graceful shutdown signal.
+ ///
+ /// # Panics
+ ///
+ /// Panic if call more than once.
pub fn reporting(
&self,
shutdown_signal: impl Future<Output = ()> + Send + Sync + 'static,
) -> Reporting {
+ if self.inner.is_reporting.swap(true, Ordering::Relaxed) {
+ panic!("reporting already called");
+ }
+
Reporting {
handle: tokio::spawn(self.clone().do_reporting(shutdown_signal)),
}
}
- async fn do_reporting(self, shutdown_signal: impl Future<Output = ()> + Send + Sync + 'static) {
+ async fn do_reporting(
+ self,
+ shutdown_signal: impl Future<Output = ()> + Send + Sync + 'static,
+ ) -> crate::Result<()> {
let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel();
let handle = tokio::spawn(async move {
loop {
- let mut segment_receiver = self.inner.segment_receiver.lock().await;
- let mut segments = LinkedList::new();
-
tokio::select! {
- segment = segment_receiver.recv() => {
- drop(segment_receiver);
-
- if let Some(segment) = segment {
- // TODO Implement batch collect in future.
- segments.push_back(segment);
- Self::report_segment_object(&self.inner.reporter, segments).await;
- } else {
- break;
+ segment = self.inner.segment_receiver.recv() => {
+ match segment {
+ Ok(Some(segment)) => {
+ // TODO Implement batch collect in future.
+ let mut segments = LinkedList::new();
+ segments.push_back(segment);
+ Self::report_segment_object(&self.inner.reporter, segments).await;
+ }
+ Ok(None) => break,
+ Err(err) => return Err(err.into()),
}
}
_ = shutdown_rx.recv() => break,
}
}
+ self.inner.is_closed.store(true, Ordering::Relaxed);
+
// Flush.
- let mut segment_receiver = self.inner.segment_receiver.lock().await;
let mut segments = LinkedList::new();
- while let Ok(segment) = segment_receiver.try_recv() {
- segments.push_back(segment);
+ loop {
+ match self.inner.segment_receiver.try_recv().await {
+ Ok(Some(segment)) => {
+ segments.push_back(segment);
+ }
+ Ok(None) => break,
+ Err(err) => return Err(err.into()),
+ }
}
Self::report_segment_object(&self.inner.reporter, segments).await;
+
+ Ok::<_, crate::Error>(())
});
shutdown_signal.await;
if shutdown_tx.send(()).is_err() {
- tracing::error!("Shutdown signal send failed");
+ tracing::error!("shutdown signal send failed");
}
- if let Err(e) = handle.await {
- tracing::error!("Tokio handle join failed: {:?}", e);
- }
+
+ handle.await??;
+
+ Ok(())
}
async fn report_segment_object(
reporter: &Mutex<DynReporter>,
segments: LinkedList<SegmentObject>,
) {
- if let Err(e) = reporter.lock().await.collect(segments).await {
- tracing::error!("Collect failed: {:?}", e);
+ if let Err(err) = reporter.lock().await.collect(segments).await {
+ tracing::error!(?err, "collect failed");
}
}
@@ -221,22 +318,29 @@
/// Created by [Tracer::reporting].
pub struct Reporting {
- handle: JoinHandle<()>,
+ handle: JoinHandle<crate::Result<()>>,
}
impl Future for Reporting {
- type Output = Result<(), JoinError>;
+ type Output = crate::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- Pin::new(&mut self.handle).poll(cx)
+ Pin::new(&mut self.handle).poll(cx).map(|r| r?)
}
}
#[cfg(test)]
mod tests {
use super::*;
+ use std::future;
trait AssertSend: Send {}
impl AssertSend for Tracer {}
+
+ #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+ async fn custom_channel() {
+ let tracer = Tracer::new_with_channel("service_name", "instance_name", (), ((), ()));
+ tracer.reporting(future::ready(())).await.unwrap();
+ }
}
diff --git a/src/error/mod.rs b/src/error/mod.rs
index 5246571..406a240 100644
--- a/src/error/mod.rs
+++ b/src/error/mod.rs
@@ -33,4 +33,10 @@
#[error("tonic status: {0}")]
TonicStatus(#[from] tonic::Status),
+
+ #[error("tokio join failed: {0}")]
+ TokioJoin(#[from] tokio::task::JoinError),
+
+ #[error(transparent)]
+ Other(#[from] Box<dyn std::error::Error + Send + 'static>),
}
diff --git a/src/reporter/log.rs b/src/reporter/log.rs
index 0d002a2..2e0a98a 100644
--- a/src/reporter/log.rs
+++ b/src/reporter/log.rs
@@ -54,7 +54,7 @@
impl Default for LogReporter {
fn default() -> Self {
Self {
- tip: "Collect".to_string(),
+ tip: "collect".to_string(),
used: Used::Println,
}
}
diff --git a/src/reporter/mod.rs b/src/reporter/mod.rs
index 29ea40c..950b146 100644
--- a/src/reporter/mod.rs
+++ b/src/reporter/mod.rs
@@ -28,3 +28,13 @@
pub trait Reporter {
async fn collect(&mut self, segments: LinkedList<SegmentObject>) -> Result<(), Box<dyn Error>>;
}
+
+#[async_trait]
+impl Reporter for () {
+ async fn collect(
+ &mut self,
+ _segments: LinkedList<SegmentObject>,
+ ) -> Result<(), Box<dyn Error>> {
+ Ok(())
+ }
+}