| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <random>
#include <thread>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

#include "arrow/status.h"
#include "arrow/testing/future_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/task_group.h"
#include "arrow/util/thread_pool.h"
| |
| namespace arrow { |
| namespace internal { |
| |
// Produce `nsleeps` pseudo-random sleep durations (in seconds), each drawn
// uniformly from [min_seconds, max_seconds).  Uses a default-seeded engine,
// so the sequence is deterministic across runs.
static std::vector<double> RandomSleepDurations(int nsleeps, double min_seconds,
                                                double max_seconds) {
  std::default_random_engine rng;
  std::uniform_real_distribution<> dist(min_seconds, max_seconds);
  std::vector<double> durations;
  durations.reserve(nsleeps);
  for (int remaining = nsleeps; remaining > 0; --remaining) {
    durations.push_back(dist(rng));
  }
  return durations;
}
| |
| // Check TaskGroup behaviour with a bunch of all-successful tasks |
| void TestTaskGroupSuccess(std::shared_ptr<TaskGroup> task_group) { |
| const int NTASKS = 10; |
| auto sleeps = RandomSleepDurations(NTASKS, 1e-3, 4e-3); |
| |
| // Add NTASKS sleeps |
| std::atomic<int> count(0); |
| for (int i = 0; i < NTASKS; ++i) { |
| task_group->Append([&, i]() { |
| SleepFor(sleeps[i]); |
| count += i; |
| return Status::OK(); |
| }); |
| } |
| ASSERT_TRUE(task_group->ok()); |
| |
| ASSERT_OK(task_group->Finish()); |
| ASSERT_TRUE(task_group->ok()); |
| ASSERT_EQ(count.load(), NTASKS * (NTASKS - 1) / 2); |
| // Finish() is idempotent |
| ASSERT_OK(task_group->Finish()); |
| } |
| |
// Check TaskGroup behaviour with some successful and some failing tasks
void TestTaskGroupErrors(std::shared_ptr<TaskGroup> task_group) {
  const int NSUCCESSES = 2;
  const int NERRORS = 20;

  // Total number of child tasks that actually ran to completion
  std::atomic<int> count(0);

  // Snapshot of task_group->ok() taken after the successful tasks were
  // appended but before any failing task was appended
  auto task_group_was_ok = false;
  // Child tasks are appended from within a parent task so that, in the
  // threaded case, they can start running while appending is still ongoing.
  task_group->Append([&]() -> Status {
    for (int i = 0; i < NSUCCESSES; ++i) {
      task_group->Append([&]() {
        count++;
        return Status::OK();
      });
    }
    task_group_was_ok = task_group->ok();
    for (int i = 0; i < NERRORS; ++i) {
      task_group->Append([&]() {
        // Sleep a bit so that, in the threaded case, the first error is
        // noticed before all of the failing tasks have started
        SleepFor(1e-2);
        count++;
        return Status::Invalid("some message");
      });
    }

    return Status::OK();
  });

  // Task error is propagated
  ASSERT_RAISES(Invalid, task_group->Finish());
  ASSERT_TRUE(task_group_was_ok);
  ASSERT_FALSE(task_group->ok());
  if (task_group->parallelism() == 1) {
    // Serial: exactly two successes and an error; tasks appended after the
    // first failure are not run at all
    ASSERT_EQ(count.load(), 3);
  } else {
    // Parallel: at least two successes and an error; tasks scheduled after
    // the first failure may be skipped, so only an upper bound holds
    ASSERT_GE(count.load(), 3);
    ASSERT_LE(count.load(), 2 * task_group->parallelism());
  }
  // Finish() is idempotent
  ASSERT_RAISES(Invalid, task_group->Finish());
}
| |
// Check TaskGroup behaviour when a stop is requested while tasks are running
void TestTaskGroupCancel(std::shared_ptr<TaskGroup> task_group, StopSource* stop_source) {
  const int NSUCCESSES = 2;
  const int NCANCELS = 20;

  // Total number of child tasks that actually ran to completion
  std::atomic<int> count(0);

  // Snapshot of task_group->ok() taken before any stop was requested
  auto task_group_was_ok = false;
  // Child tasks are appended from within a parent task so that, in the
  // threaded case, they can start running while appending is still ongoing.
  task_group->Append([&]() -> Status {
    for (int i = 0; i < NSUCCESSES; ++i) {
      task_group->Append([&]() {
        count++;
        return Status::OK();
      });
    }
    task_group_was_ok = task_group->ok();
    for (int i = 0; i < NCANCELS; ++i) {
      task_group->Append([&]() {
        // Sleep a bit so that, in the threaded case, cancellation is
        // requested before all of these tasks have started
        SleepFor(1e-2);
        stop_source->RequestStop();
        count++;
        return Status::OK();
      });
    }

    return Status::OK();
  });

  // Cancellation is propagated
  ASSERT_RAISES(Cancelled, task_group->Finish());
  ASSERT_TRUE(task_group_was_ok);
  ASSERT_FALSE(task_group->ok());
  if (task_group->parallelism() == 1) {
    // Serial: exactly three successes (the two successful children plus the
    // first stop-requesting task); later tasks are skipped
    ASSERT_EQ(count.load(), NSUCCESSES + 1);
  } else {
    // Parallel: at least three successes; tasks scheduled after the stop
    // request may be skipped, so only an upper bound holds
    ASSERT_GE(count.load(), NSUCCESSES + 1);
    ASSERT_LE(count.load(), NSUCCESSES * task_group->parallelism());
  }
  // Finish() is idempotent
  ASSERT_RAISES(Cancelled, task_group->Finish());
}
| |
| class CopyCountingTask { |
| public: |
| explicit CopyCountingTask(std::shared_ptr<uint8_t> target) |
| : counter(0), target(std::move(target)) {} |
| |
| CopyCountingTask(const CopyCountingTask& other) |
| : counter(other.counter + 1), target(other.target) {} |
| |
| CopyCountingTask& operator=(const CopyCountingTask& other) { |
| counter = other.counter + 1; |
| target = other.target; |
| return *this; |
| } |
| |
| CopyCountingTask(CopyCountingTask&& other) = default; |
| CopyCountingTask& operator=(CopyCountingTask&& other) = default; |
| |
| Status operator()() { |
| *target = counter; |
| return Status::OK(); |
| } |
| |
| private: |
| uint8_t counter; |
| std::shared_ptr<uint8_t> target; |
| }; |
| |
// Check TaskGroup behaviour with tasks spawning other tasks
void TestTasksSpawnTasks(std::shared_ptr<TaskGroup> task_group) {
  const int N = 6;

  // Total number of tasks executed
  std::atomic<int> count(0);
  // Make a task that recursively spawns itself.
  // `make_task` must be a named std::function (not a bare lambda) so the
  // returned closure can capture it by reference and call it recursively.
  std::function<std::function<Status()>(int)> make_task = [&](int i) {
    return [&, i]() {
      count++;
      if (i > 0) {
        // Exercise parallelism by spawning two tasks at once and then sleeping
        task_group->Append(make_task(i - 1));
        task_group->Append(make_task(i - 1));
        SleepFor(1e-3);
      }
      return Status::OK();
    };
  };

  task_group->Append(make_task(N));

  ASSERT_OK(task_group->Finish());
  ASSERT_TRUE(task_group->ok());
  // The recursion forms a complete binary tree of depth N:
  // 2^(N+1) - 1 tasks in total
  ASSERT_EQ(count.load(), (1 << (N + 1)) - 1);
}
| |
// A task that keeps recursing until a barrier is set.
// Using a lambda for this doesn't play well with Thread Sanitizer.
struct BarrierTask {
  // Set to true by the test driver to stop the recursion
  std::atomic<bool>* barrier_;
  // Weak reference so that the task itself does not keep the TaskGroup alive
  std::weak_ptr<TaskGroup> weak_group_ptr_;
  // Status returned by each invocation of the task
  Status final_status_;

  Status operator()() {
    if (!barrier_->load()) {
      SleepFor(1e-5);
      // Note the TaskGroup should be kept alive by the fact this task
      // is still running...
      weak_group_ptr_.lock()->Append(*this);
    }
    return final_status_;
  }
};
| |
// Try to replicate subtle lifetime issues when destroying a TaskGroup
// where all tasks may not have finished running.
void StressTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
  const int NTASKS = 100;
  auto task_group = factory();
  // Weak pointer lets us observe destruction of the TaskGroup without
  // contributing to its lifetime
  auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);

  std::atomic<bool> barrier(false);

  // Each task re-appends itself until the barrier is set
  BarrierTask task{&barrier, weak_group_ptr, Status::OK()};

  for (int i = 0; i < NTASKS; ++i) {
    task_group->Append(task);
  }

  // Lose strong reference
  barrier.store(true);
  task_group.reset();

  // Wait for finish: still-running tasks should keep the TaskGroup alive,
  // and the weak pointer should expire only once they have all completed
  while (!weak_group_ptr.expired()) {
    SleepFor(1e-5);
  }
}
| |
// Same, but with also a failing task
void StressFailingTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
  const int NTASKS = 100;
  auto task_group = factory();
  // Weak pointer lets us observe destruction of the TaskGroup without
  // contributing to its lifetime
  auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);

  std::atomic<bool> barrier(false);

  // Each task re-appends itself until the barrier is set; the failing one
  // also reports an error, exercising early-termination paths
  BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
  BarrierTask failing_task{&barrier, weak_group_ptr, Status::Invalid("XXX")};

  for (int i = 0; i < NTASKS; ++i) {
    task_group->Append(task);
  }
  task_group->Append(failing_task);

  // Lose strong reference
  barrier.store(true);
  task_group.reset();

  // Wait for finish: the weak pointer should expire only once all
  // still-running tasks have completed
  while (!weak_group_ptr.expired()) {
    SleepFor(1e-5);
  }
}
| |
| void TestNoCopyTask(std::shared_ptr<TaskGroup> task_group) { |
| auto counter = std::make_shared<uint8_t>(0); |
| CopyCountingTask task(counter); |
| task_group->Append(std::move(task)); |
| ASSERT_OK(task_group->Finish()); |
| ASSERT_EQ(0, *counter); |
| } |
| |
void TestFinishNotSticky(std::function<std::shared_ptr<TaskGroup>()> factory) {
  // If a task is added that runs very quickly it might decrement the task counter back
  // down to 0 and mark the completion future as complete before all tasks are added.
  // The "finished future" of the task group could get stuck to complete.
  //
  // Instead the task group should not allow the finished future to be marked complete
  // until after FinishAsync has been called.
  const int NTASKS = 100;
  // Repeat the scenario many times to give a racy implementation a chance to fail
  for (int i = 0; i < NTASKS; ++i) {
    auto task_group = factory();
    // Add a task and let it complete
    task_group->Append([] { return Status::OK(); });
    // Wait a little bit, if the task group was going to lock the finish hopefully it
    // would do so here while we wait
    SleepFor(1e-2);

    // Add a new task that will still be running (it blocks until `ready` is set)
    std::atomic<bool> ready(false);
    std::mutex m;
    std::condition_variable cv;
    task_group->Append([&m, &cv, &ready] {
      std::unique_lock<std::mutex> lk(m);
      cv.wait(lk, [&ready] { return ready.load(); });
      return Status::OK();
    });

    // Ensure task group not finished already
    auto finished = task_group->FinishAsync();
    ASSERT_FALSE(finished.is_finished());

    // Unblock the pending task
    std::unique_lock<std::mutex> lk(m);
    ready = true;
    lk.unlock();
    cv.notify_one();

    // Now the group should be able to finish
    ASSERT_FINISHES_OK(finished);
  }
}
| |
| void TestFinishNeverStarted(std::shared_ptr<TaskGroup> task_group) { |
| // If we call FinishAsync we are done adding tasks so if we never added any it should be |
| // completed |
| auto finished = task_group->FinishAsync(); |
| ASSERT_TRUE(finished.Wait(1)); |
| } |
| |
| void TestFinishAlreadyCompleted(std::function<std::shared_ptr<TaskGroup>()> factory) { |
| // If we call FinishAsync we are done adding tasks so even if no tasks are running we |
| // should still be completed |
| const int NTASKS = 100; |
| for (int i = 0; i < NTASKS; ++i) { |
| auto task_group = factory(); |
| // Add a task and let it complete |
| task_group->Append([] { return Status::OK(); }); |
| // Wait a little bit, hopefully enough time for the task to finish on one of these |
| // iterations |
| SleepFor(1e-2); |
| auto finished = task_group->FinishAsync(); |
| ASSERT_FINISHES_OK(finished); |
| } |
| } |
| |
| TEST(SerialTaskGroup, Success) { TestTaskGroupSuccess(TaskGroup::MakeSerial()); } |
| |
| TEST(SerialTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeSerial()); } |
| |
| TEST(SerialTaskGroup, Cancel) { |
| StopSource stop_source; |
| TestTaskGroupCancel(TaskGroup::MakeSerial(stop_source.token()), &stop_source); |
| } |
| |
| TEST(SerialTaskGroup, TasksSpawnTasks) { TestTasksSpawnTasks(TaskGroup::MakeSerial()); } |
| |
| TEST(SerialTaskGroup, NoCopyTask) { TestNoCopyTask(TaskGroup::MakeSerial()); } |
| |
| TEST(SerialTaskGroup, FinishNeverStarted) { |
| TestFinishNeverStarted(TaskGroup::MakeSerial()); |
| } |
| |
| TEST(SerialTaskGroup, FinishAlreadyCompleted) { |
| TestFinishAlreadyCompleted([] { return TaskGroup::MakeSerial(); }); |
| } |
| |
| TEST(ThreadedTaskGroup, Success) { |
| auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool()); |
| TestTaskGroupSuccess(task_group); |
| } |
| |
| TEST(ThreadedTaskGroup, Errors) { |
| // Limit parallelism to ensure some tasks don't get started |
| // after the first failing ones |
| std::shared_ptr<ThreadPool> thread_pool; |
| ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4)); |
| |
| TestTaskGroupErrors(TaskGroup::MakeThreaded(thread_pool.get())); |
| } |
| |
| TEST(ThreadedTaskGroup, Cancel) { |
| std::shared_ptr<ThreadPool> thread_pool; |
| ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4)); |
| |
| StopSource stop_source; |
| TestTaskGroupCancel(TaskGroup::MakeThreaded(thread_pool.get(), stop_source.token()), |
| &stop_source); |
| } |
| |
| TEST(ThreadedTaskGroup, TasksSpawnTasks) { |
| auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool()); |
| TestTasksSpawnTasks(task_group); |
| } |
| |
| TEST(ThreadedTaskGroup, NoCopyTask) { |
| std::shared_ptr<ThreadPool> thread_pool; |
| ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4)); |
| TestNoCopyTask(TaskGroup::MakeThreaded(thread_pool.get())); |
| } |
| |
| TEST(ThreadedTaskGroup, StressTaskGroupLifetime) { |
| std::shared_ptr<ThreadPool> thread_pool; |
| ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16)); |
| |
| StressTaskGroupLifetime([&] { return TaskGroup::MakeThreaded(thread_pool.get()); }); |
| } |
| |
| TEST(ThreadedTaskGroup, StressFailingTaskGroupLifetime) { |
| std::shared_ptr<ThreadPool> thread_pool; |
| ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16)); |
| |
| StressFailingTaskGroupLifetime( |
| [&] { return TaskGroup::MakeThreaded(thread_pool.get()); }); |
| } |
| |
| TEST(ThreadedTaskGroup, FinishNotSticky) { |
| std::shared_ptr<ThreadPool> thread_pool; |
| ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16)); |
| |
| TestFinishNotSticky([&] { return TaskGroup::MakeThreaded(thread_pool.get()); }); |
| } |
| |
| TEST(ThreadedTaskGroup, FinishNeverStarted) { |
| std::shared_ptr<ThreadPool> thread_pool; |
| ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4)); |
| TestFinishNeverStarted(TaskGroup::MakeThreaded(thread_pool.get())); |
| } |
| |
| TEST(ThreadedTaskGroup, FinishAlreadyCompleted) { |
| std::shared_ptr<ThreadPool> thread_pool; |
| ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16)); |
| |
| TestFinishAlreadyCompleted([&] { return TaskGroup::MakeThreaded(thread_pool.get()); }); |
| } |
| |
| } // namespace internal |
| } // namespace arrow |