blob: ab223b04661c8ddaf9bc90949d2dc6306a33cef0 [file] [log] [blame]
//! A future wrapper ensuring the wrapped future keeps being polled even if the
//! wrapper itself is dropped (cancelled) before the future completes.
//!
//! This implementation is forked from: https://github.com/influxdata/influxdb_iox/blob/885767aa0a6010de592bde9992945b01389eb994/cache_system/src/cancellation_safe_future.rs
//! Here is the copyright and license disclaimer:
//! Copyright (c) 2020 InfluxData. Licensed under Apache-2.0.
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use futures::future::BoxFuture;
use lazy_static::lazy_static;
use prometheus::{register_int_counter_vec, IntCounterVec};
use runtime::RuntimeRef;
lazy_static! {
    /// Counts how many `CancellationSafeFuture`s were dropped before their
    /// inner future completed, labelled by the caller-supplied token.
    static ref FUTURE_CANCEL_COUNTER: IntCounterVec = {
        let name = "future_cancel_counter";
        let help = "Counter of future cancel";
        let labels = &["token"];
        register_int_counter_vec!(name, help, labels).unwrap()
    };
}
/// Wrapper around a future that cannot be cancelled.
///
/// When the wrapper is dropped/cancelled before the inner future has
/// completed, the inner future is spawned onto `runtime` to _rescue_ it, and
/// the cancellation is counted in `FUTURE_CANCEL_COUNTER` under `token`.
pub struct CancellationSafeFuture<F, T>
where
    F: Future + Send + 'static,
    F::Output: Send,
    // `Unpin` keeps the whole wrapper `Unpin` (every other field is `Unpin`),
    // which lets `poll` update fields through `Pin<&mut Self>` without unsafe.
    T: AsRef<str> + 'static + Send + Unpin,
{
    /// Label value recorded in `FUTURE_CANCEL_COUNTER` when the inner future
    /// has to be rescued on drop.
    token: T,
    /// Mark if the inner future finished. If not, we must spawn a helper task
    /// on drop.
    done: bool,
    /// Inner future.
    ///
    /// Wrapped in an `Option` so we can extract it during drop. Inside that
    /// option however we also need a pinned box because once this wrapper
    /// is polled, it will be pinned in memory -- even during drop. Now the
    /// inner future does not necessarily implement `Unpin`, so we need a
    /// heap allocation to pin it in memory even when we move it out of this
    /// option.
    inner: Option<BoxFuture<'static, F::Output>>,
    /// The runtime that receives the inner future if the wrapper is dropped
    /// before the inner future completes.
    // NOTE: fields are dropped in declaration order after `Drop::drop` runs,
    // so `runtime` is released last.
    runtime: RuntimeRef,
}
impl<F, T> Drop for CancellationSafeFuture<F, T>
where
    F: Future + Send + 'static,
    F::Output: Send,
    T: AsRef<str> + 'static + Send + Unpin,
{
    /// Rescues an unfinished inner future instead of cancelling it.
    ///
    /// If the inner future has already resolved there is nothing to do.
    /// Otherwise the cancellation is counted under `token` and the inner
    /// future is spawned onto `runtime` as a detached task.
    fn drop(&mut self) {
        if self.done {
            return;
        }

        let label = self.token.as_ref();
        FUTURE_CANCEL_COUNTER.with_label_values(&[label]).inc();

        // `new` always populates `inner`, and nothing clears it while `done`
        // is still false, so it is present here.
        let orphan = self.inner.take().unwrap();
        // Detach: the join handle is discarded immediately and never awaited.
        drop(self.runtime.spawn(orphan));
    }
}
impl<F, T> CancellationSafeFuture<F, T>
where
    F: Future + Send,
    F::Output: Send,
    T: AsRef<str> + 'static + Send + Unpin,
{
    /// Create a new future that is protected from cancellation.
    ///
    /// `fut` is boxed and pinned right away. If the returned
    /// [`CancellationSafeFuture`] is dropped (i.e. cancelled) before `fut` has
    /// completed, `fut` is spawned onto `runtime` and keeps running in the
    /// background. This also applies when the wrapper was never polled at all.
    /// The spawned task's handle is discarded, so its output is not observed
    /// by anyone.
    ///
    /// # Arguments
    ///
    /// * `fut` - the future to protect.
    /// * `token` - label value used for the cancellation metric when `fut`
    ///   has to be rescued on drop.
    /// * `runtime` - runtime that receives `fut` if the wrapper is dropped
    ///   before `fut` completes.
    pub fn new(fut: F, token: T, runtime: RuntimeRef) -> Self {
        Self {
            token,
            done: false,
            inner: Some(Box::pin(fut)),
            runtime,
        }
    }
}
impl<F, T> Future for CancellationSafeFuture<F, T>
where
    F: Future + Send,
    F::Output: Send,
    T: AsRef<str> + 'static + Send + Unpin,
{
    type Output = F::Output;

    /// Polls the wrapped future.
    ///
    /// Once the inner future resolves, the wrapper is marked as done, so
    /// dropping it no longer spawns anything. The boxed inner future, and
    /// whatever state it still owns, is released immediately instead of being
    /// kept alive until the wrapper itself is dropped.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned `Poll::Ready`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        assert!(!self.done, "Polling future that already returned");
        // `inner` is only cleared once `done` is set (or during drop), so it
        // is present here.
        let res = match self.inner.as_mut().unwrap().as_mut().poll(cx) {
            Poll::Ready(res) => res,
            Poll::Pending => return Poll::Pending,
        };
        self.done = true;
        // Safe to release now: `Drop` ignores `inner` once `done` is true.
        self.inner = None;
        Poll::Ready(res)
    }
}
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::Duration,
    };

    use runtime::Builder;
    use tokio::sync::Barrier;

    use super::*;

    /// Builds a runtime with two worker threads for the tests below.
    ///
    /// Panics, reporting the builder's error, if the runtime cannot be built.
    fn rt() -> RuntimeRef {
        let rt = Builder::default()
            .worker_threads(2)
            .thread_name("test_spawn_join")
            .enable_all()
            .build()
            .expect("failed to build test runtime");
        Arc::new(rt)
    }

    /// Awaiting the wrapper runs the inner future to completion.
    #[test]
    fn test_happy_path() {
        let runtime = rt();
        let runtime_clone = runtime.clone();
        runtime.block_on(async move {
            let done = Arc::new(AtomicBool::new(false));
            let done_captured = Arc::clone(&done);
            let fut = CancellationSafeFuture::new(
                async move {
                    done_captured.store(true, Ordering::SeqCst);
                },
                "test",
                runtime_clone,
            );
            fut.await;
            assert!(done.load(Ordering::SeqCst));
        })
    }

    /// Dropping a never-polled wrapper still runs the inner future.
    #[test]
    fn test_cancel_future() {
        let runtime = rt();
        let runtime_clone = runtime.clone();
        runtime.block_on(async move {
            let barrier = Arc::new(Barrier::new(2));
            let barrier_captured = Arc::clone(&barrier);
            let fut = CancellationSafeFuture::new(
                async move {
                    barrier_captured.wait().await;
                },
                "test",
                runtime_clone,
            );
            drop(fut);
            tokio::time::timeout(Duration::from_secs(5), barrier.wait())
                .await
                .unwrap();
        });
    }

    /// Dropping a wrapper whose inner future was already polled (and is
    /// pending) still lets the inner future finish.
    #[test]
    fn test_cancel_future_after_poll() {
        let runtime = rt();
        let runtime_clone = runtime.clone();
        runtime.block_on(async move {
            let barrier = Arc::new(Barrier::new(2));
            let barrier_captured = Arc::clone(&barrier);
            let fut = CancellationSafeFuture::new(
                async move {
                    barrier_captured.wait().await;
                },
                "test",
                runtime_clone,
            );
            // The inner future blocks on the barrier, so it is polled at least
            // once, stays pending, and is then dropped when the timeout fires.
            assert!(tokio::time::timeout(Duration::from_millis(10), fut)
                .await
                .is_err());
            tokio::time::timeout(Duration::from_secs(5), barrier.wait())
                .await
                .unwrap();
        });
    }
}