blob: 95650e5316a4d5b89c313d8fcfe3781c94d0e4ce [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <bthread/countdown_event.h>
#include <cpp/sync_point.h>
#include <fmt/core.h>
#include <gen_cpp/cloud.pb.h>
#include <glog/logging.h>
#include <chrono>
#include <future>
#include <string>
#include "common/defer.h"
#include "common/simple_thread_pool.h"
namespace doris::cloud {
template <typename T>
class SyncExecutor {
public:
SyncExecutor(
std::shared_ptr<SimpleThreadPool> pool, std::string name_tag,
std::function<bool(const T&)> cancel = [](const T& /**/) { return false; })
: _pool(std::move(pool)), _cancel(std::move(cancel)), _name_tag(std::move(name_tag)) {}
auto add(std::function<T()> callback) -> SyncExecutor<T>& {
auto task = std::make_unique<Task>(std::move(callback), _cancel, _count);
_count.add_count();
// The actual task logic would be wrapped by one promise and passed to the threadpool.
// The result would be returned by the future once the task is finished.
// Or the task would be invalid if the whole task is cancelled.
int r = _pool->submit([this, t = task.get()]() { (*t)(_stop_token); });
CHECK(r == 0);
_res.emplace_back(std::move(task));
return *this;
}
std::vector<T> when_all(bool* finished) {
DORIS_CLOUD_DEFER {
_reset();
};
timespec current_time;
auto current_time_second = time(nullptr);
current_time.tv_sec = current_time_second + 300;
current_time.tv_nsec = 0;
// Wait for all tasks to complete
while (0 != _count.timed_wait(current_time)) {
current_time.tv_sec += 300;
LOG(WARNING) << _name_tag << " has already taken 5 min, cost: "
<< time(nullptr) - current_time_second << " seconds";
}
*finished = !_stop_token;
std::vector<T> res;
res.reserve(_res.size());
for (auto& task : _res) {
if (!task->valid()) {
*finished = false;
return res;
}
size_t max_wait_ms = 10000;
TEST_SYNC_POINT_CALLBACK("SyncExecutor::when_all.set_wait_time", &max_wait_ms);
// _count.timed_wait has already ensured that all tasks are completed.
// The 10 seconds here is just waiting for the task results to be returned,
// so 10 seconds is more than enough.
auto status = task->wait_for(max_wait_ms);
if (status == std::future_status::ready) {
res.emplace_back(task->get());
} else {
*finished = false;
LOG(WARNING) << _name_tag << " task timed out after 10 seconds";
return res;
}
}
return res;
}
private:
void _reset() {
_count.reset(0);
_res.clear();
_stop_token = false;
}
private:
class Task {
public:
Task(std::function<T()> callback, std::function<bool(const T&)> cancel,
bthread::CountdownEvent& count)
: _callback(std::move(callback)),
_cancel(std::move(cancel)),
_count(count),
_fut(_pro.get_future()) {}
void operator()(std::atomic_bool& stop_token) {
DORIS_CLOUD_DEFER {
_count.signal();
};
if (stop_token) {
_valid = false;
return;
}
T t = _callback();
// We'll return this task result to user even if this task return error
// So we don't set _valid to false here
if (_cancel(t)) {
stop_token = true;
}
_pro.set_value(std::move(t));
}
std::future_status wait_for(size_t milliseconds) {
return _fut.wait_for(std::chrono::milliseconds(milliseconds));
}
bool valid() { return _valid; }
T get() { return _fut.get(); }
private:
// It's guarantted that the valid function can only be called inside SyncExecutor's `when_all()` function
// and only be called when the _count.timed_wait function returned. So there would be no data race for
// _valid then it doesn't need to be one atomic bool.
bool _valid = true;
std::function<T()> _callback;
std::function<bool(const T&)> _cancel;
std::promise<T> _pro;
bthread::CountdownEvent& _count;
std::future<T> _fut;
};
std::vector<std::unique_ptr<Task>> _res;
// use CountdownEvent to do periodically log using CountdownEvent::time_wait()
bthread::CountdownEvent _count {0};
std::atomic_bool _stop_token {false};
std::shared_ptr<SimpleThreadPool> _pool;
std::function<bool(const T&)> _cancel;
std::string _name_tag;
};
} // namespace doris::cloud