blob: e32fd32012a683bb75a3a7daca0dbfa1c428cfd7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/logging.h>
#include <tvm/support/parallel_for.h>
#include <thread>
#include <vector>
// Checks that parallel_for visits the indices of [begin, end) for a small range, a
// larger range, and a non-unit step, and leaves every other entry untouched.
TEST(ParallelFor, Basic) {
  using tvm::support::parallel_for;
  // `a` is the serially-computed reference and `b` is filled by parallel_for. Both start
  // zeroed, so comparing the whole containers also catches writes outside [begin, end)
  // and never reads an uninitialized element.
  std::vector<int> a(1000, 0);
  std::vector<int> b(1000, 0);
  // Small range: only indices [0, 10) may be written.
  for (int i = 0; i < 10; i++) {
    a[i] = i;
  }
  parallel_for(0, 10, [&b](int i) { b[i] = i; });
  EXPECT_EQ(a, b);
  // Larger range covering the whole buffer.
  for (int i = 0; i < 1000; i++) {
    a[i] = i;
  }
  parallel_for(0, 1000, [&b](int i) { b[i] = i; });
  EXPECT_EQ(a, b);
  // step != 1: only even indices are visited, so odd entries must keep their values.
  for (int i = 0; i < 1000; i += 2) {
    a[i] *= 2;
  }
  parallel_for(
      0, 1000, [&b](int i) { b[i] *= 2; }, 2);
  EXPECT_EQ(a, b);
}
// Checks parallel_for used as the outer loop around a serial loop, and as the inner
// loop inside a serial loop; both must reproduce the serially-computed matrix.
TEST(ParallelFor, NestedWithNormalForLoop) {
  using tvm::support::parallel_for;
  // static so the lambdas below can read it without capturing it.
  static constexpr int kN = 500;
  // Three kN x kN int matrices are ~3 MB in total, which can overflow default stack
  // limits (e.g. 1 MB on Windows), so they live on the heap rather than as local arrays.
  using Matrix = std::vector<std::vector<int>>;
  Matrix a(kN, std::vector<int>(kN));
  Matrix b(kN, std::vector<int>(kN));
  Matrix c(kN, std::vector<int>(kN));
  // Serial reference.
  for (int i = 0; i < kN; i++) {
    for (int j = 0; j < kN; j++) {
      a[i][j] = i * j;
    }
  }
  // Parallel outer loop, serial inner loop: each task writes a distinct row.
  parallel_for(0, kN, [&b](int i) {
    for (int j = 0; j < kN; j++) {
      b[i][j] = i * j;
    }
  });
  EXPECT_EQ(a, b);
  // Serial outer loop, parallel inner loop. The row index is captured by value: the
  // task only needs the current row, not a reference to the loop counter.
  for (int i = 0; i < kN; i++) {
    parallel_for(0, kN, [&c, i](int j) { c[i][j] = i * j; });
  }
  EXPECT_EQ(a, c);
}
// Nested parallel_for is currently unsupported: calling parallel_for from inside a
// parallel_for task is expected to surface as an exception to the outer caller.
TEST(ParallelFor, NestedWithParallelFor) {
  using tvm::support::parallel_for;
  EXPECT_THROW(parallel_for(0, 100,
                            [](int) {
                              parallel_for(0, 100, [](int) {
                                // Blank loop
                              });
                            }),
               std::exception);
}
// A failure raised inside a parallel_for task must propagate to the caller as an
// exception derived from std::exception.
TEST(ParallelFor, Exception) {
  using tvm::support::parallel_for;
  EXPECT_THROW(parallel_for(0, 100, [](int) { LOG(FATAL) << "error"; }), std::exception);
}
// Checks that parallel_for_dynamic runs every task id in [0, 1000) when work is spread
// over all hardware threads.
TEST(ParallelForDynamic, Basic) {
  using tvm::support::parallel_for_dynamic;
  // -1 is a sentinel: a task that never ran is reported deterministically instead of
  // reading an uninitialized element.
  std::vector<int> a(1000, -1);
  // std::thread::hardware_concurrency() returns 0 when the value is not well defined
  // or not computable; always request at least one worker.
  const unsigned hw_threads = std::thread::hardware_concurrency();
  const int num_threads = hw_threads > 0 ? static_cast<int>(hw_threads) : 1;
  parallel_for_dynamic(0, 1000, num_threads, [&a](int /*thread_id*/, int i) { a[i] = i; });
  for (int i = 0; i < 1000; i++) {
    EXPECT_EQ(a[i], i) << "task " << i << " was not run";
  }
}
// Only thread_id 0 fails. With num_threads = 1 that is the sole worker (presumably the
// calling thread, per the test name), and its failure must reach the caller as an
// exception.
TEST(ParallelForDynamic, ExceptionOnMain) {
  using tvm::support::parallel_for_dynamic;
  const int num_threads = 1;
  EXPECT_THROW(parallel_for_dynamic(0, 10, num_threads,
                                    [](int thread_id, int /*task_id*/) {
                                      if (thread_id == 0) {
                                        LOG(FATAL) << "Error";
                                      }
                                    }),
               std::exception);
}
// Every task fails, so the first failure may come from any of the worker threads; it
// must still propagate to the caller as an exception.
TEST(ParallelForDynamic, ExceptionOnArbitrary) {
  using tvm::support::parallel_for_dynamic;
  const int num_threads = 3;
  EXPECT_THROW(parallel_for_dynamic(0, 100, num_threads,
                                    [](int /*thread_id*/, int /*task_id*/) {
                                      LOG(FATAL) << "Error";
                                    }),
               std::exception);
}