blob: e4839cb71e9d1f63aa140dde8a6ba29a4a828867 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::collections::VecDeque;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use futures::FutureExt;
use crate::*;
/// BoxedFuture is the type alias of [`futures::future::BoxFuture`].
#[cfg(not(target_arch = "wasm32"))]
pub type BoxedFuture<'a, T> = futures::future::BoxFuture<'a, T>;
/// BoxedFuture is the type alias of [`futures::future::LocalBoxFuture`].
///
/// wasm32 is single-threaded, so the boxed future doesn't need to be `Send`.
#[cfg(target_arch = "wasm32")]
pub type BoxedFuture<'a, T> = futures::future::LocalBoxFuture<'a, T>;
/// BoxedStaticFuture is the type alias of [`futures::future::BoxFuture`].
#[cfg(not(target_arch = "wasm32"))]
pub type BoxedStaticFuture<T> = futures::future::BoxFuture<'static, T>;
/// BoxedStaticFuture is the type alias of [`futures::future::LocalBoxFuture`].
///
/// wasm32 is single-threaded, so the boxed future doesn't need to be `Send`.
#[cfg(target_arch = "wasm32")]
pub type BoxedStaticFuture<T> = futures::future::LocalBoxFuture<'static, T>;
/// MaybeSend is a marker to determine whether a type is `Send` or not.
/// We use this trait to wrap the `Send` requirement for wasm32 target.
///
/// # Safety
///
/// [`MaybeSend`] is equivalent to `Send` on non-wasm32 target.
/// And it's empty trait on wasm32 target to indicate that a type is not `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub trait MaybeSend: Send {}
/// MaybeSend is a marker to determine whether a type is `Send` or not.
/// We use this trait to wrap the `Send` requirement for wasm32 target.
///
/// # Safety
///
/// [`MaybeSend`] is equivalent to `Send` on non-wasm32 target.
/// And it's empty trait on wasm32 target to indicate that a type is not `Send`.
#[cfg(target_arch = "wasm32")]
pub trait MaybeSend {}
// Blanket impl: off wasm32, every `Send` type automatically implements `MaybeSend`.
#[cfg(not(target_arch = "wasm32"))]
impl<T: Send> MaybeSend for T {}
// Blanket impl: on wasm32, every type qualifies since `Send` is not required there.
#[cfg(target_arch = "wasm32")]
impl<T> MaybeSend for T {}
/// ConcurrentTasks is used to execute tasks concurrently.
///
/// ConcurrentTasks has two generic types:
///
/// - `I` represents the input type of the task.
/// - `O` represents the output type of the task.
///
/// # Implementation Notes
///
/// The code patterns below are intentional; please do not modify them unless you fully understand these notes.
///
/// ```skip
/// let (i, o) = self
/// .tasks
/// .front_mut() // Use `front_mut` instead of `pop_front`
/// .expect("tasks must be available")
/// .await;
/// ...
/// match o {
/// Ok(o) => {
/// let _ = self.tasks.pop_front(); // `pop_front` after got `Ok(o)`
/// self.results.push_back(o)
/// }
/// Err(err) => {
/// if err.is_temporary() {
/// let task = self.create_task(i);
/// self.tasks
/// .front_mut()
/// .expect("tasks must be available")
/// .replace(task) // Use replace here to instead of `push_front`
/// } else {
/// self.clear();
/// self.errored = true;
/// }
/// return Err(err);
/// }
/// }
/// ```
///
/// Please keep in mind that there is no guarantee the task will be `await`ed until completion. It's possible
/// the task may be dropped before it resolves. Therefore, we should keep the `Task` in the `tasks` queue until
/// it is resolved.
///
/// For example, users may have a timeout for the task, and the task will be dropped if it exceeds the timeout.
/// If we `pop_front` the task before it resolves, the task will be canceled and the result will be lost.
pub struct ConcurrentTasks<I, O> {
/// The executor to execute the tasks.
///
/// If user doesn't provide an executor, the tasks will be executed with the default executor.
executor: Executor,
/// The factory to create the task.
///
/// Caller of ConcurrentTasks must provide a factory to create the task for executing.
///
/// The factory must accept an input and return a future that resolves to a tuple of input and
/// output result. If the given result is error, the error will be returned to users and the
/// task will be retried.
factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,
/// `tasks` holds the ongoing tasks.
///
/// Please keep in mind that all tasks are running in the background by `Executor`. We only need
/// to poll the tasks to see if they are ready.
///
/// Dropping task without `await` it will cancel the task.
tasks: VecDeque<Task<(I, Result<O>)>>,
/// `results` stores the successful results.
results: VecDeque<O>,
/// The maximum number of concurrent tasks.
concurrent: usize,
/// The maximum number of completed tasks that can be buffered.
prefetch: usize,
/// Tracks the number of tasks that have finished execution but have not yet been collected.
/// This count is subtracted from the total concurrency capacity, ensuring that the system
/// always schedules new tasks to maintain the user's desired concurrency level.
///
/// Example: If `concurrency = 10` and `completed_but_unretrieved = 3`,
/// the system can still spawn 7 new tasks (since 3 slots are "logically occupied"
/// by uncollected results).
///
/// Shared with the spawned futures (see `create_task`), hence the `Arc<AtomicUsize>`.
completed_but_unretrieved: Arc<AtomicUsize>,
/// Set to `true` after hitting the last unrecoverable error.
///
/// If concurrent tasks hit an unrecoverable error, it will stop executing new tasks and return
/// an unrecoverable error to users.
errored: bool,
}
impl<I: Send + 'static, O: Send + 'static> ConcurrentTasks<I, O> {
/// Create a new concurrent tasks with given executor, concurrent, prefetch and factory.
///
/// The factory is a function pointer that shouldn't capture any context.
pub fn new(
executor: Executor,
concurrent: usize,
prefetch: usize,
factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,
) -> Self {
Self {
executor,
factory,
tasks: VecDeque::with_capacity(concurrent),
results: VecDeque::with_capacity(concurrent),
concurrent,
prefetch,
completed_but_unretrieved: Arc::default(),
errored: false,
}
}
/// Return true if the tasks are running concurrently.
///
/// A concurrency of 0 or 1 degrades to plain sequential execution.
#[inline]
fn is_concurrent(&self) -> bool {
    self.concurrent >= 2
}
/// Drop every buffered result and every in-flight task.
///
/// Since dropping a task without awaiting it cancels it, clearing the
/// task queue aborts all ongoing work.
pub fn clear(&mut self) {
    self.results.clear();
    self.tasks.clear();
}
/// Check whether there is remaining space to push new tasks.
///
/// The effective capacity is `concurrent` plus up to `prefetch` extra
/// slots, one for each task that already finished but whose result has
/// not been retrieved yet.
#[inline]
pub fn has_remaining(&self) -> bool {
    let finished = self.completed_but_unretrieved.load(Ordering::Relaxed);
    let capacity = self.concurrent + finished.min(self.prefetch);
    self.tasks.len() < capacity
}
/// Check if there are remaining results to fetch.
#[inline]
pub fn has_result(&self) -> bool {
!self.results.is_empty()
}
/// Build a task for the given input and hand it to the executor.
///
/// The spawned future bumps `completed_but_unretrieved` the moment it
/// resolves, which lets `has_remaining` open up prefetch slots before
/// the caller has collected the output.
pub fn create_task(&self, input: I) -> Task<(I, Result<O>)> {
    // Clone the shared counter so the background future can update it on completion.
    let counter = self.completed_but_unretrieved.clone();
    let fut = (self.factory)(input).inspect(move |_| {
        counter.fetch_add(1, Ordering::Relaxed);
    });
    self.executor.execute(fut)
}
/// Execute the task with given input.
///
/// - Execute the task in the current thread if it is not concurrent.
/// - Execute the task in the background if there are available slots.
/// - Await the first task in the queue if there are no available slots.
///
/// An `Err` returned here belongs to a previously submitted task that
/// failed while we waited for a free slot; the current `input` is not
/// consumed in that case and can be resubmitted.
pub async fn execute(&mut self, input: I) -> Result<()> {
if self.errored {
return Err(Error::new(
ErrorKind::Unexpected,
"concurrent tasks met an unrecoverable error",
));
}
// Short path for non-concurrent case: run the task inline and buffer the result.
if !self.is_concurrent() {
let (_, o) = (self.factory)(input).await;
return match o {
Ok(o) => {
self.results.push_back(o);
Ok(())
}
// We don't need to rebuild the future if it's not concurrent.
Err(err) => Err(err),
};
}
if !self.has_remaining() {
// No free slot: await the oldest task. `front_mut` (not `pop_front`) keeps the
// task inside the queue, so it isn't canceled if this future is dropped
// mid-await (see the struct-level implementation notes).
let (i, o) = self
.tasks
.front_mut()
.expect("tasks must be available")
.await;
// The task's completion hook (see `create_task`) incremented the counter;
// balance it now that we are consuming the result.
self.completed_but_unretrieved
.fetch_sub(1, Ordering::Relaxed);
match o {
Ok(o) => {
// Only remove the task once we have its successful output in hand.
let _ = self.tasks.pop_front();
self.results.push_back(o)
}
Err(err) => {
// Retry this task if the error is temporary
if err.is_temporary() {
let task = self.create_task(i);
// `replace` instead of `push_front` so the retried task keeps its slot.
self.tasks
.front_mut()
.expect("tasks must be available")
.replace(task)
} else {
// Unrecoverable: cancel all pending work and refuse further submissions.
self.clear();
self.errored = true;
}
return Err(err);
}
}
}
self.tasks.push_back(self.create_task(input));
Ok(())
}
/// Fetch the successful result from the result queue.
///
/// Returns `None` once every submitted task has been consumed. A returned
/// temporary `Err` means the failed task was re-queued; call `next` again
/// to keep draining.
pub async fn next(&mut self) -> Option<Result<O>> {
if self.errored {
return Some(Err(Error::new(
ErrorKind::Unexpected,
"concurrent tasks met an unrecoverable error",
)));
}
// Prefer results that `execute` already collected.
if let Some(result) = self.results.pop_front() {
return Some(Ok(result));
}
if let Some(task) = self.tasks.front_mut() {
// Await the oldest in-flight task in place; per the struct-level notes the
// task must stay in the queue until it resolves so a drop doesn't cancel it.
let (i, o) = task.await;
// Balance the increment done by `create_task`'s completion hook.
self.completed_but_unretrieved
.fetch_sub(1, Ordering::Relaxed);
return match o {
Ok(o) => {
// Remove the task only after its successful output is in hand.
let _ = self.tasks.pop_front();
Some(Ok(o))
}
Err(err) => {
// Retry this task if the error is temporary
if err.is_temporary() {
let task = self.create_task(i);
// `replace` keeps the retried task at the front of the queue.
self.tasks
.front_mut()
.expect("tasks must be available")
.replace(task)
} else {
// Unrecoverable: drop all pending work and remember the failure.
self.clear();
self.errored = true;
}
Some(Err(err))
}
};
}
None
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use pretty_assertions::assert_eq;
use rand::Rng;
use tokio::time::sleep;
use super::*;
// End-to-end stress test: 10240 tasks with random delays and injected temporary
// failures must all eventually succeed and come back in submission order.
#[tokio::test]
async fn test_concurrent_tasks() {
let executor = Executor::new();
let mut tasks = ConcurrentTasks::new(executor, 16, 8, |(i, dur)| {
Box::pin(async move {
sleep(dur).await;
// ~9% chance to fail: `gen_range(0..100)` yields 0..=99 and 91..=99 trip this.
if rand::thread_rng().gen_range(0..100) > 90 {
return (
(i, dur),
Err(Error::new(ErrorKind::Unexpected, "I'm lucky").set_temporary()),
);
}
((i, dur), Ok(i))
})
});
let mut ans = vec![];
for i in 0..10240 {
// Sleep up to 10ms
let dur = Duration::from_millis(rand::thread_rng().gen_range(0..10));
// Resubmit until `execute` accepts the input; an `Err` here reports a
// temporary failure of an earlier task, which gets retried internally.
loop {
let res = tasks.execute((i, dur)).await;
if res.is_ok() {
break;
}
}
}
// Drain all results; temporary errors are re-queued, so just keep polling.
loop {
match tasks.next().await.transpose() {
Ok(Some(i)) => ans.push(i),
Ok(None) => break,
Err(_) => continue,
}
}
assert_eq!(ans, (0..10240).collect::<Vec<_>>())
}
// Verifies that `has_remaining` enforces backpressure: capacity is `concurrent`
// slots plus up to `prefetch` extra slots for completed-but-unretrieved tasks.
#[tokio::test]
async fn test_prefetch_backpressure() {
let executor = Executor::new();
let concurrent = 4;
let prefetch = 2;
// Create a slower task to ensure they don't complete immediately
let mut tasks = ConcurrentTasks::new(executor, concurrent, prefetch, |i: usize| {
Box::pin(async move {
sleep(Duration::from_millis(100)).await;
(i, Ok(i))
})
});
// Initially, we should have space for concurrent tasks
assert!(tasks.has_remaining(), "Should have space initially");
// Submit concurrent tasks
for i in 0..concurrent {
assert!(tasks.has_remaining(), "Should have space for task {i}");
tasks.execute(i).await.unwrap();
}
// Now we shouldn't have any more space (since no tasks have completed yet)
assert!(
!tasks.has_remaining(),
"Should not have space after submitting concurrent tasks"
);
// Wait for some tasks to complete
sleep(Duration::from_millis(150)).await;
// Now we should have space up to prefetch limit
for i in concurrent..concurrent + prefetch {
assert!(
tasks.has_remaining(),
"Should have space for prefetch task {i}"
);
tasks.execute(i).await.unwrap();
}
// Now has_remaining should return false
assert!(
!tasks.has_remaining(),
"Should not have remaining space after filling up prefetch buffer"
);
// Retrieve one result
let result = tasks.next().await;
assert!(result.is_some());
// Now there should be space for one more task
assert!(
tasks.has_remaining(),
"Should have remaining space after retrieving one result"
);
}
// With `prefetch = 0` the capacity is exactly `concurrent`: completed tasks
// open no extra slots, so space only frees up when a result is retrieved.
#[tokio::test]
async fn test_prefetch_zero() {
let executor = Executor::new();
let concurrent = 4;
let prefetch = 0; // No prefetching allowed
let mut tasks = ConcurrentTasks::new(executor, concurrent, prefetch, |i: usize| {
Box::pin(async move {
sleep(Duration::from_millis(10)).await;
(i, Ok(i))
})
});
// With prefetch=0, we can only submit up to concurrent tasks
for i in 0..concurrent {
tasks.execute(i).await.unwrap();
}
// Should not have space for more
assert!(
!tasks.has_remaining(),
"Should not have remaining space with prefetch=0"
);
// Retrieve one result
let result = tasks.next().await;
assert!(result.is_some());
// Now there should be space for exactly one more task
assert!(
tasks.has_remaining(),
"Should have remaining space after retrieving one result"
);
// Execute one more
tasks.execute(concurrent).await.unwrap();
// Should be full again
assert!(!tasks.has_remaining(), "Should be full again");
}
}