// blob: 73dc53060b63576c887316db27aaed0caa7e2d02 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file threaded_engine_test.cc
* \brief threaded engine tests
*/
#include <time.h>
#include <unistd.h>
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <mxnet/engine.h>
#include <dmlc/timer.h>
#include <cstdio>
#include <thread>
#include <chrono>
#include <vector>
#include "../src/engine/engine_impl.h"
/**
 * Represents the following workload:
 *   n = reads.size()
 *   data[write] = (data[reads[0]] + ... + data[reads[n-1]]) / (n + 1)
 *   std::this_thread::sleep_for(std::chrono::microseconds(time));
 */
struct Workload {
  /** Indices of the variables that are read and summed. */
  std::vector<int> reads;
  /** Index of the variable that receives the result. */
  int write = 0;
  /** Sleep duration in microseconds; no sleep when not positive. */
  int time = 0;
};
/**
 * State for rand_r(); a fixed start value makes the generated workloads
 * reproducible. Declared as unsigned int because that is the type rand_r()
 * takes a pointer to.
 */
static unsigned int seed_ = 0xdeadbeef;
/**
* generate a list of workloads
*/
void GenerateWorkload(int num_workloads, int num_var,
int min_read, int max_read,
int min_time, int max_time,
std::vector<Workload>* workloads) {
workloads->clear();
workloads->resize(num_workloads);
for (int i = 0; i < num_workloads; ++i) {
auto& wl = workloads->at(i);
wl.write = rand_r(&seed_) % num_var;
int r = rand_r(&seed_);
int num_read = min_read + (r % (max_read - min_read));
for (int j = 0; j < num_read; ++j) {
wl.reads.push_back(rand_r(&seed_) % num_var);
}
wl.time = min_time + rand_r(&seed_) % (max_time - min_time);
}
}
/**
 * Evaluate a single workload.
 *
 * Sums data[r] for every r in wl.reads, stores that sum divided by
 * (number of reads + 1) into data[wl.write], then sleeps for wl.time
 * microseconds when wl.time is positive.
 *
 * \param wl the workload to run
 * \param data variable values; indexed with bounds-checked at()
 */
void EvaluateWorload(const Workload& wl, std::vector<double>* data) {
  double sum = 0;
  for (auto it = wl.reads.begin(); it != wl.reads.end(); ++it) {
    sum += data->at(*it);
  }
  const auto divisor = wl.reads.size() + 1;
  data->at(wl.write) = sum / divisor;

  if (wl.time <= 0) return;
  std::this_thread::sleep_for(std::chrono::microseconds(wl.time));
}
/**
* evaluate a list of workload, return the time used
*/
double EvaluateWorloads(const std::vector<Workload>& workloads,
mxnet::Engine* engine,
std::vector<double>* data) {
using namespace mxnet;
double t = dmlc::GetTime();
std::vector<Engine::VarHandle> vars;
if (engine) {
for (size_t i = 0; i < data->size(); ++i) {
vars.push_back(engine->NewVariable());
}
}
for (const auto& wl : workloads) {
if (wl.reads.size() == 0) continue;
if (engine == NULL) {
EvaluateWorload(wl, data);
} else {
auto func = [wl, data](RunContext ctx, Engine::CallbackOnComplete cb) {
EvaluateWorload(wl, data); cb();
};
std::vector<Engine::VarHandle> reads;
for (auto i : wl.reads) {
if (i != wl.write) reads.push_back(vars[i]);
}
engine->PushAsync(func, Context::CPU(), reads, {vars[wl.write]});
}
}
if (engine) {
engine->WaitForAll();
}
return dmlc::GetTime() - t;
}
/**
 * Runs the same randomly generated workloads without an engine (baseline)
 * and on each engine implementation, expects every engine to produce the
 * same data as the baseline, and logs the accumulated time per engine.
 */
TEST(Engine, RandSumExpr) {
  std::vector<Workload> workloads;
  const int num_repeat = 5;
  const int num_engine = 4;
  std::vector<double> t(num_engine, 0.0);
  std::vector<mxnet::Engine*> engine(num_engine);
  engine[0] = nullptr;  // baseline: workloads run directly, no engine
  engine[1] = mxnet::engine::CreateNaiveEngine();
  engine[2] = mxnet::engine::CreateThreadedEnginePooled();
  engine[3] = mxnet::engine::CreateThreadedEnginePerDevice();

  for (int repeat = 0; repeat < num_repeat; ++repeat) {
    // NOTE: this seeds rand(); GenerateWorkload draws from rand_r() with its
    // own seed_, so the generated workloads do not depend on this call.
    srand(time(nullptr) + repeat);
    const int num_var = 100;
    GenerateWorkload(10000, num_var, 2, 20, 1, 10, &workloads);

    std::vector<std::vector<double>> data(num_engine);
    for (int i = 0; i < num_engine; ++i) {
      data[i].resize(num_var, 1.0);
      t[i] += EvaluateWorloads(workloads, engine[i], &data[i]);
    }
    for (int i = 1; i < num_engine; ++i) {
      for (int j = 0; j < num_var; ++j) EXPECT_EQ(data[0][j], data[i][j]);
    }
    LOG(INFO) << "data: " << data[0][1] << " " << data[0][2] << "...";
  }

  LOG(INFO) << "baseline\t\t" << t[0] << " sec";
  LOG(INFO) << "NaiveEngine\t\t" << t[1] << " sec";
  LOG(INFO) << "ThreadedEnginePooled\t" << t[2] << " sec";
  LOG(INFO) << "ThreadedEnginePerDevice\t" << t[3] << " sec";

  // The engines were created by this test and are not used afterwards;
  // destroy them instead of leaking them. Deleting the null baseline entry
  // is a no-op.
  for (mxnet::Engine* e : engine) {
    delete e;
  }
}
/** Prints the given number to stdout; the run context is not used. */
void Foo(mxnet::RunContext /*ctx*/, int i) {
  std::printf("The fox says %d\n", i);
}
/**
 * Exercises the basic operations of the engine returned by
 * mxnet::Engine::Get(): creating and deleting variables and operators,
 * pushing operators, and waiting on a single variable or on all work.
 */
TEST(Engine, basics) {
  auto engine = mxnet::Engine::Get();
  auto var = engine->NewVariable();
  std::vector<mxnet::Engine::OprHandle> oprs;

  // Deletes every operator currently held in `oprs`, then schedules the
  // deletion of the current `var`.
  const auto release_all = [&]() {
    for (auto& op : oprs) {
      engine->DeleteOperator(op);
    }
    engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
  };

  // Test #1: ten operators with `var` in their first variable list, each
  // sleeping for one second.
  printf("============= Test #1 ==============\n");
  for (int i = 0; i < 10; ++i) {
    auto op = engine->NewOperator(
        [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
          Foo(ctx, i);
          std::this_thread::sleep_for(std::chrono::seconds(1));
          cb();
        },
        {var}, {});
    oprs.push_back(op);
    engine->Push(op, mxnet::Context{});
  }
  engine->WaitForAll();
  printf("Going to push delete\n");
  // std::this_thread::sleep_for(std::chrono::seconds{1});
  release_all();
  engine->WaitForAll();

  // Test #2: ten operators with `var` in their second variable list, each
  // sleeping for half a second.
  printf("============= Test #2 ==============\n");
  var = engine->NewVariable();
  oprs.clear();
  for (int i = 0; i < 10; ++i) {
    auto op = engine->NewOperator(
        [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
          Foo(ctx, i);
          std::this_thread::sleep_for(std::chrono::milliseconds(500));
          cb();
        },
        {}, {var});
    oprs.push_back(op);
    engine->Push(op, mxnet::Context{});
  }
  // std::this_thread::sleep_for(std::chrono::seconds{1});
  engine->WaitForAll();
  release_all();

  // Test #3: wait on a variable that no operator was pushed for.
  printf("============= Test #3 ==============\n");
  var = engine->NewVariable();
  oprs.clear();
  engine->WaitForVar(var);
  release_all();  // `oprs` is empty here, so only the variable is deleted
  engine->WaitForAll();

  // Test #4: a single kCopyFromGPU operator with `var` in its second
  // variable list; WaitForVar is expected to block until it finishes.
  printf("============= Test #4 ==============\n");
  var = engine->NewVariable();
  oprs.clear();
  oprs.push_back(engine->NewOperator(
      [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
        std::this_thread::sleep_for(std::chrono::seconds(2));
        Foo(ctx, 42);
        cb();
      },
      {}, {var}, mxnet::FnProperty::kCopyFromGPU));
  engine->Push(oprs.front(), mxnet::Context{});
  LOG(INFO) << "IO operator pushed, should wait for 2 seconds.";
  engine->WaitForVar(var);
  LOG(INFO) << "OK, here I am.";
  release_all();
  engine->WaitForAll();

  // Test #5: a single operator with `var` in its first variable list;
  // WaitForVar is expected to return without waiting for it.
  printf("============= Test #5 ==============\n");
  var = engine->NewVariable();
  oprs.clear();
  oprs.push_back(engine->NewOperator(
      [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
        Foo(ctx, 42);
        std::this_thread::sleep_for(std::chrono::seconds(2));
        cb();
      },
      {var}, {}));
  engine->Push(oprs.front(), mxnet::Context{});
  LOG(INFO) << "Operator pushed, should not wait.";
  engine->WaitForVar(var);
  LOG(INFO) << "OK, here I am.";
  engine->WaitForAll();
  LOG(INFO) << "That was 2 seconds.";
  release_all();
  engine->WaitForAll();

  var = nullptr;
  oprs.clear();
  LOG(INFO) << "All pass";
}