#include <time.h>
#include <unistd.h>
#include <dmlc/logging.h>
#include <cstdio>
#include <gtest/gtest.h>
#include <thread>
#include <chrono>
#include <vector>

#include <mxnet/engine.h>
#include "../src/engine/engine_impl.h"
#include <dmlc/timer.h>

/**
 * describes one unit of simulated work over a shared array of doubles;
 * EvaluateWorload applies it as follows
 *  n = reads.size()
 *  data[write] = (data[reads[0]] + ... + data[reads[n-1]]) / (n + 1)
 *  std::this_thread::sleep_for(std::chrono::microseconds(time));
 */
struct Workload {
  // indices of the data entries that are summed
  std::vector<int> reads;
  // index of the data entry that receives the result
  int write;
  // sleep duration in microseconds; no sleep happens when it is <= 0
  int time;
};

/**
 * generate a list of random workloads using rand()
 *
 *  num_workloads       number of workloads to produce
 *  num_var             read and write indices are drawn from [0, num_var)
 *  min_read, max_read  the number of reads is drawn from [min_read, max_read)
 *  min_time, max_time  the sleep time is drawn from [min_time, max_time)
 *  workloads           output list; its previous content is discarded
 *
 * An empty range (max <= min) collapses to its lower bound instead of taking
 * a remainder by zero. When num_workloads or num_var is not positive the
 * output is left empty.
 */
void GenerateWorkload(int num_workloads, int num_var,
                      int min_read, int max_read,
                      int min_time, int max_time,
                      std::vector<Workload>* workloads) {
  workloads->clear();
  // nothing can be generated without workloads or variables; this also
  // avoids `rand() % 0` and a resize() with a negative count
  if (num_workloads <= 0 || num_var <= 0) return;

  const int read_span = max_read - min_read;
  const int time_span = max_time - min_time;

  workloads->resize(num_workloads);
  for (int i = 0; i < num_workloads; ++i) {
    auto& wl = workloads->at(i);
    // rand() is always consumed in the order: write index, read count,
    // read indices, time -- so a given seed yields the same workloads
    // for every input that has non-empty ranges
    wl.write = rand() % num_var;
    const int r = rand();
    const int num_read =
        read_span > 0 ? min_read + (r % read_span) : min_read;
    if (num_read > 0) wl.reads.reserve(num_read);
    for (int j = 0; j < num_read; ++j) {
      wl.reads.push_back(rand() % num_var);
    }
    const int t = rand();
    wl.time = time_span > 0 ? min_time + (t % time_span) : min_time;
  }
}

/**
 * apply one workload to the data array:
 * sum the entries named by wl.reads, store sum / (number of reads + 1)
 * into the entry named by wl.write, then sleep for wl.time microseconds
 * when wl.time is positive
 */
void EvaluateWorload(const Workload& wl, std::vector<double>* data) {
  std::vector<double>& values = *data;

  // accumulate in the order the read indices are listed
  double sum = 0;
  for (size_t k = 0; k < wl.reads.size(); ++k) {
    sum += values.at(wl.reads[k]);
  }

  const double divisor = wl.reads.size() + 1;
  values.at(wl.write) = sum / divisor;

  if (wl.time <= 0) return;
  std::this_thread::sleep_for(std::chrono::microseconds(wl.time));
}

/**
 * evaluate a list of workload, return the time used
 */
double EvaluateWorloads(const std::vector<Workload>& workloads,
                      mxnet::Engine* engine,
                      std::vector<double>* data) {
  using namespace mxnet;
  double t = dmlc::GetTime();
  std::vector<Engine::VarHandle> vars;
  if (engine) {
    for (size_t i = 0; i < data->size(); ++i) {
      vars.push_back(engine->NewVariable());
    }
  }

  for (const auto& wl : workloads) {
    if (wl.reads.size() == 0) continue;
    if (engine == NULL) {
      EvaluateWorload(wl, data);
    } else {
      auto func = [wl,data](RunContext ctx, Engine::CallbackOnComplete cb) {
        EvaluateWorload(wl, data); cb();
      };
      std::vector<Engine::VarHandle> reads;
      for (auto i : wl.reads) {
        if (i != wl.write) reads.push_back(vars[i]);
      }
      engine->PushAsync(func, Context::CPU(), reads, {vars[wl.write]});
    }
  }

  if (engine) {
    engine->WaitForAll();
  }
  return dmlc::GetTime() - t;
}

/**
 * run the same random workloads without an engine (baseline) and through
 * each engine, expect every engine to leave the data array exactly equal to
 * the baseline, and log the accumulated run time of each configuration
 */
TEST(Engine, RandSumExpr) {
  std::vector<Workload> workloads;
  int num_repeat = 5;
  const int num_engine = 4;

  std::vector<double> t(num_engine, 0.0);
  std::vector<mxnet::Engine*> engine(num_engine);

  // slot 0 is the baseline: a null engine makes EvaluateWorloads run the
  // workloads directly
  // NOTE(review): the engines created here are never destroyed in this test
  engine[0] = NULL;
  engine[1] = mxnet::engine::CreateNaiveEngine();
  engine[2] = mxnet::engine::CreateThreadedEnginePooled();
  engine[3] = mxnet::engine::CreateThreadedEnginePerDevice();

  for (int repeat = 0; repeat < num_repeat; ++repeat) {
    // the seed depends on the wall clock, so log it: a failing run can then
    // be reproduced by seeding with the logged value
    const unsigned seed = static_cast<unsigned>(time(NULL)) + repeat;
    LOG(INFO) << "repeat " << repeat << " uses srand seed " << seed;
    srand(seed);
    int num_var = 100;
    GenerateWorkload(10000, num_var, 2, 20, 1, 10, &workloads);
    // every configuration starts from its own array filled with 1.0
    std::vector<std::vector<double>> data(num_engine);
    for (int i = 0; i < num_engine; ++i) {
      data[i].resize(num_var, 1.0);
      t[i] += EvaluateWorloads(workloads, engine[i], &data[i]);
    }

    // results are compared for exact equality against the baseline
    for (int i = 1; i < num_engine; ++i) {
      for (int j = 0; j < num_var; ++j) EXPECT_EQ(data[0][j], data[i][j]);
    }
    LOG(INFO) << "data: " << data[0][1] << " " << data[0][2] << "...";
  }


  LOG(INFO) << "baseline\t\t"  << t[0] << " sec";
  LOG(INFO) << "NaiveEngine\t\t"  << t[1] << " sec";
  LOG(INFO) << "ThreadedEnginePooled\t" << t[2] << " sec";
  LOG(INFO) << "ThreadedEnginePerDevice\t" << t[3] << " sec";
}

/**
 * print the given integer to stdout; the run context argument is not used
 */
void Foo(mxnet::RunContext, int i) {
  std::printf("The fox says %d\n", i);
}

/**
 * walk through the operator/variable life cycle on the engine returned by
 * mxnet::Engine::Get(): NewVariable, NewOperator, Push, WaitForVar,
 * WaitForAll, DeleteOperator and DeleteVariable. The test has no gtest
 * assertions; it passes when every step returns.
 *
 * NOTE(review): in the NewOperator calls the first variable list is
 * presumably the read-only variables and the second the mutated ones,
 * mirroring the PushAsync call in EvaluateWorloads -- confirm in engine.h
 */
TEST(Engine, basics) {
  auto&& engine = mxnet::Engine::Get();
  auto&& var = engine->NewVariable();
  std::vector<mxnet::Engine::OprHandle> oprs;

  // Test #1: ten operators with var in the first list; each prints its
  // index, sleeps one second and then signals completion
  printf("============= Test #1 ==============\n");
  for (int i = 0; i < 10; ++i) {
    oprs.push_back(engine->NewOperator(
        [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
          Foo(ctx, i);
          std::this_thread::sleep_for(std::chrono::seconds{1});
          cb();
        },
        {var}, {}));
    engine->Push(oprs.at(i), mxnet::Context{});
  }
  // operators and the variable are only deleted after all pushed work is done
  engine->WaitForAll();
  printf("Going to push delete\n");
  // std::this_thread::sleep_for(std::chrono::seconds{1});
  for (auto&& i : oprs) {
    engine->DeleteOperator(i);
  }
  // the variable is deleted with a callback that does nothing
  engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
  engine->WaitForAll();

  // Test #2: same shape as Test #1, but var is in the second list and each
  // operator sleeps 500 ms
  printf("============= Test #2 ==============\n");
  var = engine->NewVariable();
  oprs.clear();
  for (int i = 0; i < 10; ++i) {
    oprs.push_back(engine->NewOperator(
        [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
          Foo(ctx, i);
          std::this_thread::sleep_for(std::chrono::milliseconds{500});
          cb();
        },
        {}, {var}));
    engine->Push(oprs.at(i), mxnet::Context{});
  }
  // std::this_thread::sleep_for(std::chrono::seconds{1});
  engine->WaitForAll();
  for (auto&& i : oprs) {
    engine->DeleteOperator(i);
  }
  // unlike the other sections there is no WaitForAll after this deletion
  engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);

  // Test #3: wait on a fresh variable that no operator was pushed on,
  // then delete it
  printf("============= Test #3 ==============\n");
  var = engine->NewVariable();
  oprs.clear();
  engine->WaitForVar(var);
  engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
  engine->WaitForAll();

  // Test #4: one operator created with FnProperty::kCopyFromGPU and var in
  // the second list; it sleeps two seconds before printing. Per the log
  // messages, WaitForVar is expected to block until it has finished
  printf("============= Test #4 ==============\n");
  var = engine->NewVariable();
  oprs.clear();
  oprs.push_back(engine->NewOperator(
      [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
        std::this_thread::sleep_for(std::chrono::seconds{2});
        Foo(ctx, 42);
        cb();
      },
      {}, {var}, mxnet::FnProperty::kCopyFromGPU));
  engine->Push(oprs.at(0), mxnet::Context{});
  LOG(INFO) << "IO operator pushed, should wait for 2 seconds.";
  engine->WaitForVar(var);
  LOG(INFO) << "OK, here I am.";
  for (auto&& i : oprs) {
    engine->DeleteOperator(i);
  }
  engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
  engine->WaitForAll();

  // Test #5: one operator with var in the first list; it prints and then
  // sleeps two seconds. Per the log messages, WaitForVar is expected to
  // return without waiting for it, while WaitForAll waits the two seconds
  printf("============= Test #5 ==============\n");
  var = engine->NewVariable();
  oprs.clear();
  oprs.push_back(engine->NewOperator(
      [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
        Foo(ctx, 42);
        std::this_thread::sleep_for(std::chrono::seconds{2});
        cb();
      },
      {var}, {}));
  engine->Push(oprs.at(0), mxnet::Context{});
  LOG(INFO) << "Operator pushed, should not wait.";
  engine->WaitForVar(var);
  LOG(INFO) << "OK, here I am.";
  engine->WaitForAll();
  LOG(INFO) << "That was 2 seconds.";
  for (auto&& i : oprs) {
    engine->DeleteOperator(i);
  }
  engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
  engine->WaitForAll();
  // drop the handles that were just deleted
  var = nullptr;
  oprs.clear();
  LOG(INFO) << "All pass";
}
