blob: 0b8c810da70b72bc03abd05fd5b5c0d7cb670562 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file parallel_for.cc
 * \brief An implementation to run a loop in parallel.
*/
#include <dmlc/logging.h>
#include <tvm/support/parallel_for.h>
#include <algorithm>
#include <cstdint>
#include <functional>
#include <future>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
namespace tvm {
namespace support {
/*!
 * \brief Round-robin partitioner: deals the indices begin, begin + step, begin + 2 * step, ...
 *        (all strictly less than end) into buckets, the k-th index going to bucket
 *        k % num_threads.
 * \param begin First index (inclusive).
 * \param end Upper bound (exclusive).
 * \param step Stride between consecutive indices; must be >= 1 whenever begin < end,
 *        otherwise the range would never terminate.
 * \param num_threads Maximum number of buckets; values below 1 are treated as 1 (e.g. when
 *        std::thread::hardware_concurrency() reports 0).
 * \return One vector of indices per bucket; there are min(num_threads, #indices) buckets, so
 *         the result is empty when begin >= end.
 */
std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int num_threads) {
  std::vector<std::vector<int>> ret;
  if (begin >= end) {
    // Empty range: nothing to schedule regardless of step (avoids dividing by a zero step).
    return ret;
  }
  CHECK_GE(step, 1) << "Infinite loop condition with begin: " << begin << " end: " << end
                    << " step: " << step;
  // Work in 64-bit: both end - begin and begin + k * step can overflow int.
  const int64_t span = static_cast<int64_t>(end) - static_cast<int64_t>(begin);
  const int64_t total_task_count = (span + step - 1) / step;  // ceil(span / step) >= 1
  const int64_t workers = std::max<int64_t>(num_threads, 1);
  // Buckets are only created for threads that actually receive work.
  const int64_t num_partitions = std::min(workers, total_task_count);
  ret.resize(static_cast<size_t>(num_partitions));
  for (int64_t task = 0; task < total_task_count; ++task) {
    // begin + task * step < end <= INT_MAX, so the narrowing cast is lossless.
    ret[static_cast<size_t>(task % num_partitions)].push_back(
        static_cast<int>(begin + task * step));
  }
  return ret;
}
/*!
 * \brief Runs f(i) for every index produced by `partitioner`, one std::thread per partition.
 * \param begin First index, forwarded to `partitioner`.
 * \param end End bound, forwarded to `partitioner`.
 * \param f Callback invoked once per index; each worker thread receives its own copy of it.
 * \param step Stride, forwarded to `partitioner`.
 * \param partitioner Splits the range into per-thread lists of indices; it is asked for
 *        std::thread::hardware_concurrency() partitions (at least 1).
 * \note Only one parallel_for may be active at a time: a concurrent or nested call fails the
 *       CHECK below. After all workers have been joined, the first std::exception raised by
 *       `f` is reported through LOG(FATAL).
 */
void parallel_for(int begin, int end, const std::function<void(int)>& f, int step,
                  const PartitionerFuncType partitioner) {
  static bool GLOBAL_PARALLEL_FOR_FLAG{false};
  static std::mutex M_GLOBAL_PARALLEL_FOR_FLAG;
  {
    std::lock_guard<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
    CHECK(!GLOBAL_PARALLEL_FOR_FLAG) << "There's another parallel_for running. Maybe you're "
                                     << "currently inside another parallel_for loop.";
    GLOBAL_PARALLEL_FOR_FLAG = true;
  }
  // Clears the flag on every exit path, including exceptions thrown by the partitioner,
  // by thread creation, or by LOG(FATAL) below; otherwise one failed call would make every
  // later parallel_for fail the CHECK above. Constructed only after this call owns the flag.
  struct FlagReleaser {
    ~FlagReleaser() {
      std::lock_guard<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
      GLOBAL_PARALLEL_FOR_FLAG = false;
    }
  } flag_releaser;

  // hardware_concurrency() returns 0 when the value is not computable; never request zero
  // partitions (the default round-robin partitioner takes a modulo by this count).
  const int default_num_threads =
      std::max(1, static_cast<int>(std::thread::hardware_concurrency()));
  const auto& run_partitions = partitioner(begin, end, step, default_num_threads);

  std::vector<std::thread> threads;
  threads.reserve(run_partitions.size());
  std::vector<std::future<void>> res_vec;
  res_vec.reserve(run_partitions.size());
  auto join_all = [&threads]() {
    for (auto& thread : threads) {
      if (thread.joinable()) {
        thread.join();
      }
    }
  };
  try {
    for (const auto& run_partition : run_partitions) {
      std::packaged_task<void(const std::vector<int>&, const std::function<void(int)>&)> task(
          [](const std::vector<int>& indices, const std::function<void(int)>& func) {
            for (const auto& i : indices) {
              func(i);
            }
          });
      res_vec.emplace_back(task.get_future());
      // run_partitions outlives every worker (all are joined before it goes out of scope),
      // so the index list is passed by reference instead of being copied into the thread.
      threads.emplace_back(std::move(task), std::cref(run_partition), f);
    }
  } catch (...) {
    // Destroying a still-joinable std::thread calls std::terminate, so wait for the workers
    // that did start before propagating the spawn failure.
    join_all();
    throw;
  }
  join_all();
  try {
    for (auto&& i : res_vec) {
      i.get();
    }
  } catch (const std::exception& e) {
    LOG(FATAL) << "Parallel_for error with " << e.what();
  }
}
} // namespace support
} // namespace tvm