blob: 4913fb9294c2f49d25fd27df725bb04b6d5f113f [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <random>
#include <thread>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

#include "arrow/status.h"
#include "arrow/testing/future_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/task_group.h"
#include "arrow/util/thread_pool.h"
namespace arrow {
namespace internal {
// Generate `nsleeps` pseudo-random sleep durations (in seconds), drawn
// uniformly between `min_seconds` and `max_seconds`.
//
// The random engine is default-constructed and never seeded, so every call
// with the same arguments produces the same sequence.  A non-positive
// `nsleeps` yields an empty vector.
static std::vector<double> RandomSleepDurations(int nsleeps, double min_seconds,
                                                double max_seconds) {
  std::vector<double> sleeps;
  if (nsleeps <= 0) {
    return sleeps;
  }
  // The final size is known up front: allocate once instead of regrowing
  // the buffer while pushing.
  sleeps.reserve(static_cast<std::vector<double>::size_type>(nsleeps));
  std::default_random_engine engine;
  std::uniform_real_distribution<> sleep_dist(min_seconds, max_seconds);
  for (int i = 0; i < nsleeps; ++i) {
    sleeps.push_back(sleep_dist(engine));
  }
  return sleeps;
}
// Check TaskGroup behaviour with a bunch of all-successful tasks
void TestTaskGroupSuccess(std::shared_ptr<TaskGroup> task_group) {
const int NTASKS = 10;
auto sleeps = RandomSleepDurations(NTASKS, 1e-3, 4e-3);
// Add NTASKS sleeps
std::atomic<int> count(0);
for (int i = 0; i < NTASKS; ++i) {
task_group->Append([&, i]() {
SleepFor(sleeps[i]);
count += i;
return Status::OK();
});
}
ASSERT_TRUE(task_group->ok());
ASSERT_OK(task_group->Finish());
ASSERT_TRUE(task_group->ok());
ASSERT_EQ(count.load(), NTASKS * (NTASKS - 1) / 2);
// Finish() is idempotent
ASSERT_OK(task_group->Finish());
}
// Exercise a TaskGroup where a root task spawns a few succeeding tasks
// followed by many failing ones; Finish() must report the failure.
void TestTaskGroupErrors(std::shared_ptr<TaskGroup> task_group) {
  const int NSUCCESSES = 2;
  const int NERRORS = 20;

  std::atomic<int> count(0);
  bool task_group_was_ok = false;

  // A single root task schedules all the others from inside the group.
  task_group->Append([&]() -> Status {
    auto succeed = [&count]() {
      count++;
      return Status::OK();
    };
    auto fail = [&count]() {
      SleepFor(1e-2);
      count++;
      return Status::Invalid("some message");
    };

    for (int remaining = NSUCCESSES; remaining > 0; --remaining) {
      task_group->Append(succeed);
    }
    // Sampled after the successes were queued and before any failing task is.
    task_group_was_ok = task_group->ok();
    for (int remaining = NERRORS; remaining > 0; --remaining) {
      task_group->Append(fail);
    }
    return Status::OK();
  });

  // Task error is propagated
  ASSERT_RAISES(Invalid, task_group->Finish());
  ASSERT_TRUE(task_group_was_ok);
  ASSERT_FALSE(task_group->ok());

  if (task_group->parallelism() != 1) {
    // Parallel: at least two successes and an error
    ASSERT_GE(count.load(), NSUCCESSES + 1);
    ASSERT_LE(count.load(), 2 * task_group->parallelism());
  } else {
    // Serial: exactly two successes and an error
    ASSERT_EQ(count.load(), NSUCCESSES + 1);
  }

  // Finish() is idempotent
  ASSERT_RAISES(Invalid, task_group->Finish());
}
// Exercise a TaskGroup where a root task spawns a few plain tasks followed by
// many tasks that each request a stop on `stop_source`; Finish() must then
// report cancellation.
void TestTaskGroupCancel(std::shared_ptr<TaskGroup> task_group, StopSource* stop_source) {
  const int NSUCCESSES = 2;
  const int NCANCELS = 20;

  std::atomic<int> count(0);
  bool task_group_was_ok = false;

  // A single root task schedules all the others from inside the group.
  task_group->Append([&]() -> Status {
    auto succeed = [&count]() {
      count++;
      return Status::OK();
    };
    auto request_stop = [&count, stop_source]() {
      SleepFor(1e-2);
      stop_source->RequestStop();
      count++;
      return Status::OK();
    };

    for (int remaining = NSUCCESSES; remaining > 0; --remaining) {
      task_group->Append(succeed);
    }
    // Sampled after the plain tasks were queued and before any stop request is.
    task_group_was_ok = task_group->ok();
    for (int remaining = NCANCELS; remaining > 0; --remaining) {
      task_group->Append(request_stop);
    }
    return Status::OK();
  });

  // Cancellation is propagated
  ASSERT_RAISES(Cancelled, task_group->Finish());
  ASSERT_TRUE(task_group_was_ok);
  ASSERT_FALSE(task_group->ok());

  if (task_group->parallelism() != 1) {
    // Parallel: at least three successes
    ASSERT_GE(count.load(), NSUCCESSES + 1);
    ASSERT_LE(count.load(), NSUCCESSES * task_group->parallelism());
  } else {
    // Serial: exactly three successes
    ASSERT_EQ(count.load(), NSUCCESSES + 1);
  }

  // Finish() is idempotent
  ASSERT_RAISES(Cancelled, task_group->Finish());
}
// A callable that counts how many copy operations separate it from the
// instance that was originally constructed.  Copy construction and copy
// assignment each add one to the count; moves leave it untouched.  Invoking
// the task writes the current count into the shared target byte.
class CopyCountingTask {
 public:
  explicit CopyCountingTask(std::shared_ptr<uint8_t> target)
      : copies_(0), target_(std::move(target)) {}

  CopyCountingTask(const CopyCountingTask& other)
      : copies_(other.copies_ + 1), target_(other.target_) {}

  CopyCountingTask& operator=(const CopyCountingTask& other) {
    target_ = other.target_;
    copies_ = other.copies_ + 1;
    return *this;
  }

  CopyCountingTask(CopyCountingTask&& other) = default;
  CopyCountingTask& operator=(CopyCountingTask&& other) = default;

  // Publish the number of copies observed so far.
  Status operator()() {
    *target_ = copies_;
    return Status::OK();
  }

 private:
  uint8_t copies_;
  std::shared_ptr<uint8_t> target_;
};
// Check TaskGroup behaviour with tasks spawning other tasks
void TestTasksSpawnTasks(std::shared_ptr<TaskGroup> task_group) {
const int N = 6;
std::atomic<int> count(0);
// Make a task that recursively spawns itself
std::function<std::function<Status()>(int)> make_task = [&](int i) {
return [&, i]() {
count++;
if (i > 0) {
// Exercise parallelism by spawning two tasks at once and then sleeping
task_group->Append(make_task(i - 1));
task_group->Append(make_task(i - 1));
SleepFor(1e-3);
}
return Status::OK();
};
};
task_group->Append(make_task(N));
ASSERT_OK(task_group->Finish());
ASSERT_TRUE(task_group->ok());
ASSERT_EQ(count.load(), (1 << (N + 1)) - 1);
}
// A task that keeps re-appending a copy of itself to its task group until the
// barrier flag becomes true, and then returns `final_status_`.
// Using a lambda for this doesn't play well with Thread Sanitizer.
struct BarrierTask {
  std::atomic<bool>* barrier_;
  std::weak_ptr<TaskGroup> weak_group_ptr_;
  Status final_status_;

  Status operator()() {
    const bool released = barrier_->load();
    if (released) {
      return final_status_;
    }
    SleepFor(1e-5);
    // Note the TaskGroup should be kept alive by the fact this task
    // is still running...
    weak_group_ptr_.lock()->Append(*this);
    return final_status_;
  }
};
// Try to replicate subtle lifetime issues when destroying a TaskGroup
// where all tasks may not have finished running.
void StressTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
const int NTASKS = 100;
auto task_group = factory();
auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
std::atomic<bool> barrier(false);
BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
for (int i = 0; i < NTASKS; ++i) {
task_group->Append(task);
}
// Lose strong reference
barrier.store(true);
task_group.reset();
// Wait for finish
while (!weak_group_ptr.expired()) {
SleepFor(1e-5);
}
}
// Same, but with also a failing task
void StressFailingTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
const int NTASKS = 100;
auto task_group = factory();
auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
std::atomic<bool> barrier(false);
BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
BarrierTask failing_task{&barrier, weak_group_ptr, Status::Invalid("XXX")};
for (int i = 0; i < NTASKS; ++i) {
task_group->Append(task);
}
task_group->Append(failing_task);
// Lose strong reference
barrier.store(true);
task_group.reset();
// Wait for finish
while (!weak_group_ptr.expired()) {
SleepFor(1e-5);
}
}
// Check that a task handed over as an rvalue is run without ever being
// copied: the copy count it publishes must still be zero.
void TestNoCopyTask(std::shared_ptr<TaskGroup> task_group) {
  std::shared_ptr<uint8_t> observed_copies = std::make_shared<uint8_t>(0);
  task_group->Append(CopyCountingTask(observed_copies));
  ASSERT_OK(task_group->Finish());
  ASSERT_EQ(0, *observed_copies);
}
// Check that the completion future cannot get "stuck" in the finished state
// before FinishAsync() has been called.
//
// If a task is added that runs very quickly it might decrement the task counter back
// down to 0 and mark the completion future as complete before all tasks are added.
// The "finished future" of the task group could then get stuck in the completed state.
//
// Instead the task group should not allow the finished future to be marked complete
// until after FinishAsync has been called.
//
// `factory` is called once per iteration to obtain a fresh task group.
void TestFinishNotSticky(std::function<std::shared_ptr<TaskGroup>()> factory) {
  // State used to hold the second task open.  The task shares ownership of it, so the
  // task never refers to this function's stack frame and stays valid even if an
  // assertion below returns early.
  struct Gate {
    std::mutex mutex;
    std::condition_variable cv;
    bool open = false;  // guarded by `mutex`
  };

  const int NTASKS = 100;
  for (int i = 0; i < NTASKS; ++i) {
    auto task_group = factory();
    // Add a task and let it complete
    task_group->Append([] { return Status::OK(); });
    // Wait a little bit, if the task group was going to lock the finish hopefully it
    // would do so here while we wait
    SleepFor(1e-2);
    // Add a new task that will still be running
    auto gate = std::make_shared<Gate>();
    task_group->Append([gate] {
      std::unique_lock<std::mutex> lk(gate->mutex);
      gate->cv.wait(lk, [&gate] { return gate->open; });
      return Status::OK();
    });
    auto finished = task_group->FinishAsync();
    // Sample the state while the second task is still blocked, but release that task
    // *before* asserting: a failing ASSERT_* returns from this function right away,
    // which would otherwise leave the task waiting forever.
    const bool finished_too_early = finished.is_finished();
    {
      std::lock_guard<std::mutex> guard(gate->mutex);
      gate->open = true;
    }
    gate->cv.notify_one();
    // Ensure task group was not finished already
    ASSERT_FALSE(finished_too_early);
    ASSERT_FINISHES_OK(finished);
  }
}
// Calling FinishAsync means we are done adding tasks, so a group that never
// received any task should have its completion future done; we wait on that
// future with an argument of 1.
void TestFinishNeverStarted(std::shared_ptr<TaskGroup> task_group) {
  auto completion = task_group->FinishAsync();
  const auto completed_in_time = completion.Wait(1);
  ASSERT_TRUE(completed_in_time);
}
void TestFinishAlreadyCompleted(std::function<std::shared_ptr<TaskGroup>()> factory) {
// If we call FinishAsync we are done adding tasks so even if no tasks are running we
// should still be completed
const int NTASKS = 100;
for (int i = 0; i < NTASKS; ++i) {
auto task_group = factory();
// Add a task and let it complete
task_group->Append([] { return Status::OK(); });
// Wait a little bit, hopefully enough time for the task to finish on one of these
// iterations
SleepFor(1e-2);
auto finished = task_group->FinishAsync();
ASSERT_FINISHES_OK(finished);
}
}
// Serial task group: all tasks succeed.
TEST(SerialTaskGroup, Success) {
  auto task_group = TaskGroup::MakeSerial();
  TestTaskGroupSuccess(task_group);
}
// Serial task group: some tasks fail.
TEST(SerialTaskGroup, Errors) {
  auto task_group = TaskGroup::MakeSerial();
  TestTaskGroupErrors(task_group);
}
// Serial task group: tasks request a stop through the group's stop token.
TEST(SerialTaskGroup, Cancel) {
  StopSource stop_source;
  auto task_group = TaskGroup::MakeSerial(stop_source.token());
  TestTaskGroupCancel(task_group, &stop_source);
}
// Serial task group: tasks append further tasks.
TEST(SerialTaskGroup, TasksSpawnTasks) {
  auto task_group = TaskGroup::MakeSerial();
  TestTasksSpawnTasks(task_group);
}
// Serial task group: an rvalue task is not copied.
TEST(SerialTaskGroup, NoCopyTask) {
  auto task_group = TaskGroup::MakeSerial();
  TestNoCopyTask(task_group);
}
// Serial task group: FinishAsync on a group that never received a task.
TEST(SerialTaskGroup, FinishNeverStarted) {
  auto task_group = TaskGroup::MakeSerial();
  TestFinishNeverStarted(task_group);
}
// Serial task group: FinishAsync after the only task was appended earlier.
TEST(SerialTaskGroup, FinishAlreadyCompleted) {
  auto make_group = [] { return TaskGroup::MakeSerial(); };
  TestFinishAlreadyCompleted(make_group);
}
// Threaded task group on the CPU thread pool: all tasks succeed.
TEST(ThreadedTaskGroup, Success) {
  TestTaskGroupSuccess(TaskGroup::MakeThreaded(GetCpuThreadPool()));
}
// Threaded task group: some tasks fail.
TEST(ThreadedTaskGroup, Errors) {
  // Limit parallelism to ensure some tasks don't get started
  // after the first failing ones
  std::shared_ptr<ThreadPool> pool;
  ASSERT_OK_AND_ASSIGN(pool, ThreadPool::Make(4));
  auto task_group = TaskGroup::MakeThreaded(pool.get());
  TestTaskGroupErrors(task_group);
}
// Threaded task group: tasks request a stop through the group's stop token.
TEST(ThreadedTaskGroup, Cancel) {
  std::shared_ptr<ThreadPool> pool;
  ASSERT_OK_AND_ASSIGN(pool, ThreadPool::Make(4));
  StopSource stop_source;
  auto task_group = TaskGroup::MakeThreaded(pool.get(), stop_source.token());
  TestTaskGroupCancel(task_group, &stop_source);
}
// Threaded task group on the CPU thread pool: tasks append further tasks.
TEST(ThreadedTaskGroup, TasksSpawnTasks) {
  TestTasksSpawnTasks(TaskGroup::MakeThreaded(GetCpuThreadPool()));
}
// Threaded task group: an rvalue task is not copied.
TEST(ThreadedTaskGroup, NoCopyTask) {
  std::shared_ptr<ThreadPool> pool;
  ASSERT_OK_AND_ASSIGN(pool, ThreadPool::Make(4));
  auto task_group = TaskGroup::MakeThreaded(pool.get());
  TestNoCopyTask(task_group);
}
// Threaded task group: lifetime stress with only succeeding tasks.
TEST(ThreadedTaskGroup, StressTaskGroupLifetime) {
  std::shared_ptr<ThreadPool> pool;
  ASSERT_OK_AND_ASSIGN(pool, ThreadPool::Make(16));
  auto make_group = [&pool] { return TaskGroup::MakeThreaded(pool.get()); };
  StressTaskGroupLifetime(make_group);
}
// Threaded task group: lifetime stress including a failing task.
TEST(ThreadedTaskGroup, StressFailingTaskGroupLifetime) {
  std::shared_ptr<ThreadPool> pool;
  ASSERT_OK_AND_ASSIGN(pool, ThreadPool::Make(16));
  auto make_group = [&pool] { return TaskGroup::MakeThreaded(pool.get()); };
  StressFailingTaskGroupLifetime(make_group);
}
// Threaded task group: the completion future must not finish before FinishAsync.
TEST(ThreadedTaskGroup, FinishNotSticky) {
  std::shared_ptr<ThreadPool> pool;
  ASSERT_OK_AND_ASSIGN(pool, ThreadPool::Make(16));
  auto make_group = [&pool] { return TaskGroup::MakeThreaded(pool.get()); };
  TestFinishNotSticky(make_group);
}
// Threaded task group: FinishAsync on a group that never received a task.
TEST(ThreadedTaskGroup, FinishNeverStarted) {
  std::shared_ptr<ThreadPool> pool;
  ASSERT_OK_AND_ASSIGN(pool, ThreadPool::Make(4));
  auto task_group = TaskGroup::MakeThreaded(pool.get());
  TestFinishNeverStarted(task_group);
}
// Threaded task group: FinishAsync after the only task was appended earlier.
TEST(ThreadedTaskGroup, FinishAlreadyCompleted) {
  std::shared_ptr<ThreadPool> pool;
  ASSERT_OK_AND_ASSIGN(pool, ThreadPool::Make(16));
  auto make_group = [&pool] { return TaskGroup::MakeThreaded(pool.get()); };
  TestFinishAlreadyCompleted(make_group);
}
} // namespace internal
} // namespace arrow