Split tracer inner segment sender and receiver into traits. (#37)

diff --git a/src/context/tracer.rs b/src/context/tracer.rs
index 9e14052..c3f7abb 100644
--- a/src/context/tracer.rs
+++ b/src/context/tracer.rs
@@ -19,13 +19,14 @@
     context::trace_context::TracingContext, reporter::DynReporter, reporter::Reporter,
     skywalking_proto::v3::SegmentObject,
 };
+use std::error::Error;
 use std::future::Future;
 use std::pin::Pin;
+use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Weak;
 use std::task::{Context, Poll};
 use std::{collections::LinkedList, sync::Arc};
 use tokio::sync::OnceCell;
-use tokio::task::JoinError;
 use tokio::{
     sync::{
         mpsc::{self},
@@ -33,6 +34,7 @@
     },
     task::JoinHandle,
 };
+use tonic::async_trait;
 
 static GLOBAL_TRACER: OnceCell<Tracer> = OnceCell::const_new();
 
@@ -65,12 +67,67 @@
     global_tracer().reporting(shutdown_signal)
 }
 
+pub trait SegmentSender: Send + Sync + 'static {
+    fn send(&self, segment: SegmentObject) -> Result<(), Box<dyn Error>>;
+}
+
+impl SegmentSender for () {
+    fn send(&self, _segment: SegmentObject) -> Result<(), Box<dyn Error>> {
+        Ok(())
+    }
+}
+
+impl SegmentSender for mpsc::UnboundedSender<SegmentObject> {
+    fn send(&self, segment: SegmentObject) -> Result<(), Box<dyn Error>> {
+        Ok(self.send(segment)?)
+    }
+}
+
+#[async_trait]
+pub trait SegmentReceiver: Send + Sync + 'static {
+    async fn recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>>;
+
+    async fn try_recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>>;
+}
+
+#[async_trait]
+impl SegmentReceiver for () {
+    async fn recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>> {
+        Ok(None)
+    }
+
+    async fn try_recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>> {
+        Ok(None)
+    }
+}
+
+#[async_trait]
+impl SegmentReceiver for Mutex<mpsc::UnboundedReceiver<SegmentObject>> {
+    async fn recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>> {
+        Ok(self.lock().await.recv().await)
+    }
+
+    async fn try_recv(&self) -> Result<Option<SegmentObject>, Box<dyn Error + Send>> {
+        use mpsc::error::TryRecvError;
+
+        match self.lock().await.try_recv() {
+            Ok(segment) => Ok(Some(segment)),
+            Err(e) => match e {
+                TryRecvError::Empty => Ok(None),
+                TryRecvError::Disconnected => Err(Box::new(e)),
+            },
+        }
+    }
+}
+
 struct Inner {
     service_name: String,
     instance_name: String,
-    segment_sender: mpsc::UnboundedSender<SegmentObject>,
-    segment_receiver: Mutex<mpsc::UnboundedReceiver<SegmentObject>>,
+    segment_sender: Box<dyn SegmentSender>,
+    segment_receiver: Box<dyn SegmentReceiver>,
     reporter: Box<Mutex<DynReporter>>,
+    is_reporting: AtomicBool,
+    is_closed: AtomicBool,
 }
 
 /// Skywalking tracer.
@@ -87,14 +144,30 @@
         reporter: impl Reporter + Send + Sync + 'static,
     ) -> Self {
         let (segment_sender, segment_receiver) = mpsc::unbounded_channel();
+        Self::new_with_channel(
+            service_name,
+            instance_name,
+            reporter,
+            (segment_sender, Mutex::new(segment_receiver)),
+        )
+    }
 
+    /// New with service info, reporter, and custom channel.
+    pub fn new_with_channel(
+        service_name: impl ToString,
+        instance_name: impl ToString,
+        reporter: impl Reporter + Send + Sync + 'static,
+        channel: (impl SegmentSender, impl SegmentReceiver),
+    ) -> Self {
         Self {
             inner: Arc::new(Inner {
                 service_name: service_name.to_string(),
                 instance_name: instance_name.to_string(),
-                segment_sender,
-                segment_receiver: Mutex::new(segment_receiver),
+                segment_sender: Box::new(channel.0),
+                segment_receiver: Box::new(channel.1),
                 reporter: Box::new(Mutex::new(reporter)),
+                is_reporting: Default::default(),
+                is_closed: Default::default(),
             }),
         }
     }
@@ -131,73 +204,97 @@
 
     /// Finalize the trace context.
     pub(crate) fn finalize_context(&self, context: &mut TracingContext) {
+        if self.inner.is_closed.load(Ordering::Relaxed) {
+            tracing::warn!("tracer closed");
+            return;
+        }
+
         let segment_object = context.convert_segment_object();
-        if self.inner.segment_sender.send(segment_object).is_err() {
-            tracing::error!("segment object channel has closed");
+        if let Err(err) = self.inner.segment_sender.send(segment_object) {
+            tracing::error!(?err, "send segment object failed");
         }
     }
 
     /// Start to reporting, quit when shutdown_signal received.
     ///
     /// Accept a `shutdown_signal` argument as a graceful shutdown signal.
+    ///
+    /// # Panics
+    ///
+    /// Panic if call more than once.
     pub fn reporting(
         &self,
         shutdown_signal: impl Future<Output = ()> + Send + Sync + 'static,
     ) -> Reporting {
+        if self.inner.is_reporting.swap(true, Ordering::Relaxed) {
+            panic!("reporting already called");
+        }
+
         Reporting {
             handle: tokio::spawn(self.clone().do_reporting(shutdown_signal)),
         }
     }
 
-    async fn do_reporting(self, shutdown_signal: impl Future<Output = ()> + Send + Sync + 'static) {
+    async fn do_reporting(
+        self,
+        shutdown_signal: impl Future<Output = ()> + Send + Sync + 'static,
+    ) -> crate::Result<()> {
         let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel();
 
         let handle = tokio::spawn(async move {
             loop {
-                let mut segment_receiver = self.inner.segment_receiver.lock().await;
-                let mut segments = LinkedList::new();
-
                 tokio::select! {
-                    segment = segment_receiver.recv() => {
-                        drop(segment_receiver);
-
-                        if let Some(segment) = segment {
-                            // TODO Implement batch collect in future.
-                            segments.push_back(segment);
-                            Self::report_segment_object(&self.inner.reporter, segments).await;
-                        } else {
-                            break;
+                    segment = self.inner.segment_receiver.recv() => {
+                        match segment {
+                            Ok(Some(segment)) => {
+                                // TODO Implement batch collect in future.
+                                let mut segments = LinkedList::new();
+                                segments.push_back(segment);
+                                Self::report_segment_object(&self.inner.reporter, segments).await;
+                            }
+                            Ok(None) => break,
+                            Err(err) => return Err(err.into()),
                         }
                     }
                     _ =  shutdown_rx.recv() => break,
                 }
             }
 
+            self.inner.is_closed.store(true, Ordering::Relaxed);
+
             // Flush.
-            let mut segment_receiver = self.inner.segment_receiver.lock().await;
             let mut segments = LinkedList::new();
-            while let Ok(segment) = segment_receiver.try_recv() {
-                segments.push_back(segment);
+            loop {
+                match self.inner.segment_receiver.try_recv().await {
+                    Ok(Some(segment)) => {
+                        segments.push_back(segment);
+                    }
+                    Ok(None) => break,
+                    Err(err) => return Err(err.into()),
+                }
             }
             Self::report_segment_object(&self.inner.reporter, segments).await;
+
+            Ok::<_, crate::Error>(())
         });
 
         shutdown_signal.await;
 
         if shutdown_tx.send(()).is_err() {
-            tracing::error!("Shutdown signal send failed");
+            tracing::error!("shutdown signal send failed");
         }
-        if let Err(e) = handle.await {
-            tracing::error!("Tokio handle join failed: {:?}", e);
-        }
+
+        handle.await??;
+
+        Ok(())
     }
 
     async fn report_segment_object(
         reporter: &Mutex<DynReporter>,
         segments: LinkedList<SegmentObject>,
     ) {
-        if let Err(e) = reporter.lock().await.collect(segments).await {
-            tracing::error!("Collect failed: {:?}", e);
+        if let Err(err) = reporter.lock().await.collect(segments).await {
+            tracing::error!(?err, "collect failed");
         }
     }
 
@@ -221,22 +318,29 @@
 
 /// Created by [Tracer::reporting].
 pub struct Reporting {
-    handle: JoinHandle<()>,
+    handle: JoinHandle<crate::Result<()>>,
 }
 
 impl Future for Reporting {
-    type Output = Result<(), JoinError>;
+    type Output = crate::Result<()>;
 
     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        Pin::new(&mut self.handle).poll(cx)
+        Pin::new(&mut self.handle).poll(cx).map(|r| r?)
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use std::future;
 
     trait AssertSend: Send {}
 
     impl AssertSend for Tracer {}
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+    async fn custom_channel() {
+        let tracer = Tracer::new_with_channel("service_name", "instance_name", (), ((), ()));
+        tracer.reporting(future::ready(())).await.unwrap();
+    }
 }
diff --git a/src/error/mod.rs b/src/error/mod.rs
index 5246571..406a240 100644
--- a/src/error/mod.rs
+++ b/src/error/mod.rs
@@ -33,4 +33,10 @@
 
     #[error("tonic status: {0}")]
     TonicStatus(#[from] tonic::Status),
+
+    #[error("tokio join failed: {0}")]
+    TokioJoin(#[from] tokio::task::JoinError),
+
+    #[error(transparent)]
+    Other(#[from] Box<dyn std::error::Error + Send + 'static>),
 }
diff --git a/src/reporter/log.rs b/src/reporter/log.rs
index 0d002a2..2e0a98a 100644
--- a/src/reporter/log.rs
+++ b/src/reporter/log.rs
@@ -54,7 +54,7 @@
 impl Default for LogReporter {
     fn default() -> Self {
         Self {
-            tip: "Collect".to_string(),
+            tip: "collect".to_string(),
             used: Used::Println,
         }
     }
diff --git a/src/reporter/mod.rs b/src/reporter/mod.rs
index 29ea40c..950b146 100644
--- a/src/reporter/mod.rs
+++ b/src/reporter/mod.rs
@@ -28,3 +28,13 @@
 pub trait Reporter {
     async fn collect(&mut self, segments: LinkedList<SegmentObject>) -> Result<(), Box<dyn Error>>;
 }
+
+#[async_trait]
+impl Reporter for () {
+    async fn collect(
+        &mut self,
+        _segments: LinkedList<SegmentObject>,
+    ) -> Result<(), Box<dyn Error>> {
+        Ok(())
+    }
+}