#include <var_opt_union.hpp>

#include <catch.hpp>

#include <vector>
#include <string>
#include <sstream>
#include <fstream>
#include <cmath>
#include <random>

// Directory holding serialized test images (e.g. files produced by the Java library).
// The build may override it by defining TEST_BINARY_INPUT_PATH; it is only read, never modified.
#ifdef TEST_BINARY_INPUT_PATH
static const std::string testBinaryInputPath = TEST_BINARY_INPUT_PATH;
#else
static const std::string testBinaryInputPath = "test/";
#endif

namespace datasketches {
  
static constexpr double EPS{1e-13}; // absolute margin for Approx() comparisons of weight estimates

// Builds a var_opt_sketch<int> with capacity k, fed the items 0, 1, ..., n-1, each with weight 1.0.
// The 64-bit counter is converted to the sketch's int item type explicitly; every caller in this
// file uses n far below INT_MAX, so the conversion is value-preserving.
static var_opt_sketch<int> create_unweighted_sketch(uint32_t k, uint64_t n) {
  var_opt_sketch<int> sk(k);
  for (uint64_t i = 0; i < n; ++i) {
    sk.update(static_cast<int>(i), 1.0);
  }
  return sk;
}

// Requires that two sketches have the same k, n and sample count, and yield the same sequence
// of (item, weight) pairs when iterated.
// if exact_compare = false, checks for equivalence -- specific R region values may differ but
// R region weights must match. Weights are always compared with operator==, i.e. exactly.
template<typename T, typename S, typename A>
static void check_if_equal(var_opt_sketch<T,S,A>& sk1, var_opt_sketch<T,S,A>& sk2, bool exact_compare = true) {
  REQUIRE(sk1.get_k() == sk2.get_k());
  REQUIRE(sk1.get_n() == sk2.get_n());
  REQUIRE(sk1.get_num_samples() == sk2.get_num_samples());

  auto it1 = sk1.begin();
  auto it2 = sk2.begin();
  size_t i = 0; // position of the current sample, reported if an assertion below fails

  while ((it1 != sk1.end()) && (it2 != sk2.end())) {
    INFO("sample index " << i);
    const std::pair<const T&, const double> p1 = *it1;
    const std::pair<const T&, const double> p2 = *it2;
    if (exact_compare) {
      REQUIRE(p1.first == p2.first); // data values
    }
    REQUIRE(p1.second == p2.second); // weight values
    ++i;
    ++it1;
    ++it2;
  }

  REQUIRE((it1 == sk1.end() && it2 == sk2.end())); // iterators must end at the same time
}

// compare serialization and deserialization results, checking string and stream methods to
// ensure that the resulting binary images are compatible.
// if exact_compare = false, checks for equivalence -- specific R region values may differ but
// R region weights must match
template<typename T, typename S, typename A>
static void compare_serialization_deserialization(var_opt_union<T,S,A>& vo_union, bool exact_compare = true) {
  // Use the caller's full template signature so a non-default serde/allocator round-trips through
  // the same types instead of silently falling back to the defaults. The aliases also keep
  // template-argument commas out of the Catch macro arguments below.
  using union_type = var_opt_union<T,S,A>;
  using sketch_type = var_opt_sketch<T,S,A>;

  std::vector<uint8_t> bytes = vo_union.serialize();

  // byte image -> union
  union_type u_from_bytes = union_type::deserialize(bytes.data(), bytes.size());
  sketch_type sk1 = vo_union.get_result();
  sketch_type sk2 = u_from_bytes.get_result();
  check_if_equal(sk1, sk2, exact_compare);

  // byte image -> stream -> union
  std::string str(bytes.begin(), bytes.end());
  std::stringstream ss;
  ss.str(str);

  union_type u_from_stream = union_type::deserialize(ss);
  sk2 = u_from_stream.get_result();
  check_if_equal(sk1, sk2, exact_compare);

  // union -> stream -> union, reusing the same stream
  ss.seekg(0); // didn't put anything so only reset read position
  vo_union.serialize(ss);
  u_from_stream = union_type::deserialize(ss);
  sk2 = u_from_stream.get_result();
  check_if_equal(sk1, sk2, exact_compare);

  // stream contents -> byte image -> union
  std::string str_from_stream = ss.str();
  union_type u_from_str = union_type::deserialize(str_from_stream.c_str(), str_from_stream.size());
  sk2 = u_from_str.get_result();
  check_if_equal(sk1, sk2, exact_compare);

  // check truncated input, too
  REQUIRE_THROWS_AS(union_type::deserialize(bytes.data(), bytes.size() - 5), std::out_of_range);
  std::string str_trunc(bytes.begin(), bytes.end() - 5);
  ss.str(str_trunc);
  ss.clear(); // str() does not reset error flags left by the earlier reads; start from a good state
  // next line may throw either std::invalid_argument or std::runtime_error
  REQUIRE_THROWS_AS(union_type::deserialize(ss), std::exception);
}

TEST_CASE("varopt union: bad prelongs", "[var_opt_union]") {
  // Serialize a valid union, then zero byte 0 (preamble longs) so it is too small.
  var_opt_union<int> u(32);
  var_opt_sketch<int> source = create_unweighted_sketch(32, 33);
  u.update(source);

  std::vector<uint8_t> image = u.serialize();
  image[0] = 0;

  // the byte-array path must reject the image...
  REQUIRE_THROWS_AS(var_opt_union<int>::deserialize(image.data(), image.size()), std::invalid_argument);

  // ...and so must the stream path, fed the identical corrupted bytes
  std::stringstream stream{std::string(image.begin(), image.end())};
  REQUIRE_THROWS_AS(var_opt_union<int>::deserialize(stream), std::invalid_argument);
}

TEST_CASE("varopt union: bad serialization version", "[var_opt_union]") {
  // Serialize a valid union, then zero byte 1 (serialization version).
  var_opt_union<int> u(32);
  var_opt_sketch<int> source = create_unweighted_sketch(16, 16);
  u.update(source);

  std::vector<uint8_t> image = u.serialize();
  image[1] = 0;

  // byte-array path
  REQUIRE_THROWS_AS(var_opt_union<int>::deserialize(image.data(), image.size()), std::invalid_argument);

  // stream path over the identical corrupted bytes
  std::stringstream stream{std::string(image.begin(), image.end())};
  REQUIRE_THROWS_AS(var_opt_union<int>::deserialize(stream), std::invalid_argument);
}

TEST_CASE("varopt union: invalid k", "[var_opt_union]") {
  // k = 0 is rejected
  REQUIRE_THROWS_AS(var_opt_union<int>(0), std::invalid_argument);
  // k = 2^31 is rejected as too large. Shift an unsigned literal so the value is formed directly,
  // instead of shifting 1 into the sign bit of an int and relying on conversion of the negative result.
  REQUIRE_THROWS_AS(var_opt_union<int>(1U << 31), std::invalid_argument);
}

TEST_CASE("varopt union: bad family", "[var_opt_union]") {
  // Serialize a valid union, then zero byte 2 (family id).
  var_opt_union<int> u(15);
  var_opt_sketch<int> source = create_unweighted_sketch(16, 16);
  u.update(source);

  std::vector<uint8_t> image = u.serialize();
  image[2] = 0;

  // byte-array path
  REQUIRE_THROWS_AS(var_opt_union<int>::deserialize(image.data(), image.size()), std::invalid_argument);

  // stream path over the identical corrupted bytes
  std::stringstream stream{std::string(image.begin(), image.end())};
  REQUIRE_THROWS_AS(var_opt_union<int>::deserialize(stream), std::invalid_argument);
}

TEST_CASE("varopt union: empty union", "[var_opt_union]") {
  const uint32_t k = 2048;
  var_opt_union<std::string> u(k);
  var_opt_sketch<std::string> empty_input(k);
  u.update(empty_input);

  // merging only an empty sketch must give an empty result that still reports k
  var_opt_sketch<std::string> merged = u.get_result();
  REQUIRE(merged.is_empty());
  REQUIRE(merged.get_n() == 0);
  REQUIRE(merged.get_num_samples() == 0);
  REQUIRE(merged.get_k() == k);
}

TEST_CASE("varopt union: two exact sketches", "[var_opt_union]") {
  uint64_t n = 4; // 2n < k: every item of both inputs fits within k
  uint32_t k = 10;
  var_opt_sketch<int> sk1(k), sk2(k);

  for (uint64_t i = 1; i <= n; ++i) {
    // convert to the sketch's signed item type first, then negate; negating the unsigned
    // counter would wrap modulo 2^64 and only come back to -i through further conversions
    const int item = static_cast<int>(i);
    const double weight = static_cast<double>(i);
    sk1.update(item, weight);
    sk2.update(-item, weight);
  }

  var_opt_union<int> u(k);
  u.update(sk1);
  u.update(sk2);

  var_opt_sketch<int> result = u.get_result();
  REQUIRE(result.get_n() == 2 * n);
  REQUIRE(result.get_num_samples() == 2 * n); // nothing had to be sampled away
  REQUIRE(result.get_k() == k);
}

TEST_CASE("varopt union: heavy sampling sketch", "[var_opt_union]") {
  uint64_t n1 = 20;
  uint32_t k1 = 10;
  uint64_t n2 = 6;
  uint32_t k2 = 5;
  var_opt_sketch<int64_t> sk1(k1), sk2(k2);
  for (uint64_t i = 1; i <= n1; ++i) {
    sk1.update(static_cast<int64_t>(i), static_cast<double>(i));
  }

  // Items in sk2 are negative. Convert the unsigned counters to int64_t BEFORE negating:
  // negating a uint64_t wraps modulo 2^64 (e.g. -n2 == 2^64 - 6), which is not the intended value.
  for (uint64_t i = 1; i < n2; ++i) { // we'll add a very heavy one later
    sk2.update(-static_cast<int64_t>(i), i + 1000.0);
  }
  sk2.update(-static_cast<int64_t>(n2), 1000000.0);

  var_opt_union<int64_t> u(k1);
  u.update(sk1);
  u.update(sk2);

  var_opt_sketch<int64_t> result = u.get_result();
  REQUIRE(result.get_n() == n1 + n2);
  REQUIRE(result.get_k() == k2); // heavy enough the result pulls back to k2

  u.reset();
  result = u.get_result();
  REQUIRE(result.get_n() == 0);
  REQUIRE(result.get_k() == k1); // union reset so empty result reflects max_k
}

TEST_CASE("varopt union: identical sampling sketches", "[var_opt_union]") {
  const uint32_t k = 20;
  const uint64_t n = 50;
  const auto accept_all = [](int) { return true; };

  // feed the same sampling-mode sketch in twice
  var_opt_union<int> u(k);
  var_opt_sketch<int> sk = create_unweighted_sketch(k, n);
  u.update(sk);
  u.update(sk);

  var_opt_sketch<int> result = u.get_result();
  subset_summary summary = result.estimate_subset_sum(accept_all);
  REQUIRE(result.get_n() == 2 * n);
  REQUIRE(summary.total_sketch_weight == Approx(2.0 * n).margin(EPS));

  // add another sketch, such that sketch_tau < outer_tau
  var_opt_sketch<int> small_sk = create_unweighted_sketch(k, k + 1); // tau = (k + 1) / k
  u.update(small_sk);
  result = u.get_result();
  summary = result.estimate_subset_sum(accept_all);
  REQUIRE(result.get_n() == (2 * n) + k + 1);
  REQUIRE(summary.total_sketch_weight == Approx((2.0 * n) + k + 1).margin(EPS));
}

TEST_CASE("varopt union: small sampling sketch", "[var_opt_union]") {
  const uint32_t k_small = 16;
  const uint32_t k_max = 128;
  const uint64_t n1 = 32;
  const uint64_t n2 = 64;

  var_opt_union<float> u(k_max);

  // first input: n1 unit-weight items 0..n1-1, plus one heavy item stored as -1
  var_opt_sketch<float> sk(k_small);
  for (uint64_t i = 0; i < n1; ++i) {
    sk.update(static_cast<float>(i));
  }
  sk.update(-1.0f, static_cast<double>(n1 * n1)); // add a heavy item
  u.update(sk);

  // another one, but different n to get a different per-item weight
  var_opt_sketch<float> sk2(k_small);
  for (uint64_t i = 0; i < n2; ++i) {
    sk2.update(static_cast<float>(i));
  }
  u.update(sk2);

  // should trigger migrate_marked_items_by_decreasing_k()
  var_opt_sketch<float> result = u.get_result();
  REQUIRE(result.get_n() == n1 + n2 + 1);

  // only non-negative items are light; the heavy one is excluded by the predicate
  const double light_weight = 1.0 * (n1 + n2);
  subset_summary summary = result.estimate_subset_sum([](float x) { return x >= 0; });
  REQUIRE(summary.estimate == Approx(light_weight).margin(EPS));
  REQUIRE(summary.total_sketch_weight == Approx(light_weight + (n1 * n1)).margin(EPS));
  REQUIRE(result.get_k() < k_max);

  // check that mark information is preserved as expected
  compare_serialization_deserialization(u, false);
}

TEST_CASE("varopt union: serialize empty", "[var_opt_union]") {
  // a union that never received any input must survive every serialization path
  var_opt_union<std::string> empty_union(100);
  compare_serialization_deserialization(empty_union);
}

TEST_CASE("varopt union: serialize exact", "[var_opt_union]") {
  const uint32_t k = 100;
  // k / 2 items: fewer items than the sketch capacity
  var_opt_sketch<int> half_full = create_unweighted_sketch(k, k / 2);
  var_opt_union<int> u(k);
  u.update(half_full);
  compare_serialization_deserialization(u);
}

TEST_CASE("varopt union: serialize sampling", "[var_opt_union]") {
  const uint32_t k = 100;
  // 2k items: more items than the sketch capacity
  var_opt_sketch<int> overfull = create_unweighted_sketch(k, 2 * k);
  var_opt_union<int> u(k);
  u.update(overfull);
  compare_serialization_deserialization(u);
}

TEST_CASE("varopt union: deserialize from java", "[var_opt_union]") {
  // binary image written by the Java library; stream exceptions turn any read failure into a throw
  std::ifstream input;
  input.exceptions(std::ios::failbit | std::ios::badbit);
  input.open(testBinaryInputPath + "varopt_union_double_sampling.sk", std::ios::binary);
  var_opt_union<double> u = var_opt_union<double>::deserialize(input);

  // as in the "small sampling sketch" test, producing the result must reduce k
  var_opt_sketch<double> result = u.get_result();
  REQUIRE_FALSE(result.is_empty());
  REQUIRE(result.get_n() == 97);

  const double light_weight = 96.0; // light items only -- the heavy one is excluded
  subset_summary summary = result.estimate_subset_sum([](double x) { return x >= 0; });
  REQUIRE(summary.estimate == Approx(light_weight).margin(EPS));
  REQUIRE(summary.total_sketch_weight == Approx(light_weight + 1024.0).margin(EPS));
  REQUIRE(result.get_k() < 128);
}

}
