// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

// All rights reserved.

#include <algorithm>
#include <cstdlib>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <tuple>  // IWYU pragma: keep
#include <utility>
#include <vector>

#include <boost/optional/optional.hpp>
#include <glog/logging.h>
#include <gtest/gtest.h>

#include "kudu/gutil/stringprintf.h"
#include "kudu/gutil/strings/substitute.h"
#include "kudu/util/interval_tree.h"
#include "kudu/util/interval_tree-inl.h"
#include "kudu/util/test_util.h"

using std::pair;
using std::string;
using std::vector;
using strings::Substitute;

namespace kudu {

// Test fixture shared by all interval tree tests in this file. It declares no
// members of its own; everything it offers is inherited from KuduTest.
class TestIntervalTree : public KuduTest {
};

// An interval over ints with inclusive endpoints, tagged with an 'id'
// (defaulting to -1) that shows up in the string form.
struct IntInterval {
  IntInterval(int left, int right, int id = -1)
      : left(left), right(right), id(id) {}

  // Returns true when this interval and 'other' have at least one point in
  // common. Endpoints are inclusive, so intervals that merely touch intersect.
  bool Intersects(const IntInterval &other) const {
    return other.left <= right && left <= other.right;
  }

  // Formats the interval as "[left, right](id) " (note the trailing space).
  string ToString() const {
    return strings::Substitute("[$0, $1]($2) ", left, right, id);
  }

  int left;
  int right;
  int id;
};

// A wrapper around an int which can be compared with IntTraits::compare()
// but also can keep a counter of how many times it has been compared. Used
// for TestBigO below.
//
// The counter is held through a shared_ptr, so copies of a query point all
// refer to (and update) the same counter.
struct CountingQueryPoint {
  // Wraps 'v' and starts the comparison counter at zero.
  explicit CountingQueryPoint(int v)
      : val(v),
        count(std::make_shared<int>(0)) {
  }

  // The wrapped query value.
  int val;
  // Comparison counter, shared between all copies of this query point.
  std::shared_ptr<int> count;
};

// Traits describing IntInterval to the interval tree: the point and interval
// types, endpoint accessors, and three-way comparisons between points.
struct IntTraits {
  using point_type = int;
  using interval_type = IntInterval;

  // Endpoint accessors.
  static point_type get_left(const IntInterval &x) { return x.left; }
  static point_type get_right(const IntInterval &x) { return x.right; }

  // Three-way comparison: -1 if a < b, 1 if a > b, 0 if they are equal.
  static int compare(int a, int b) {
    return (a > b) - (a < b);
  }

  // Same comparison for a counting query point on the left-hand side; bumps
  // the query point's comparison counter before comparing.
  static int compare(const CountingQueryPoint& q, int b) {
    ++(*q.count);
    return compare(q.val, b);
  }

  // Mirror image of the overload above: compares with the operands swapped
  // (which also bumps the counter) and negates the result.
  static int compare(int a, const CountingQueryPoint& b) {
    const int swapped = compare(b, a);
    return -swapped;
  }
};

// Strict weak ordering over intervals: lexicographic on (left, right, id).
// The order itself is arbitrary; it only needs to be consistent so that the
// results of the two algorithms can be sorted and compared. Defining such an
// ordering is not required in order to use an interval tree.
static bool CompareIntervals(const IntInterval &a, const IntInterval &b) {
  const auto lhs_key = std::tie(a.left, a.right, a.id);
  const auto rhs_key = std::tie(b.left, b.right, b.id);
  return lhs_key < rhs_key;
}

// Stringify a list of int intervals, for easy test error reporting.
// The result is the concatenation of each interval's ToString() output,
// with a "," between consecutive intervals. An empty list yields "".
static string Stringify(const vector<IntInterval> &intervals) {
  string ret;
  bool first = true;
  for (const IntInterval &interval : intervals) {
    if (!first) {
      ret.append(",");
    }
    // Every interval after this one must be preceded by a separator.
    first = false;
    ret.append(interval.ToString());
  }
  return ret;
}

// Linear scan over 'intervals': appends to 'results' every interval whose
// inclusive range [left, right] contains 'query_point', in input order.
static void FindContainingBruteForce(const vector<IntInterval> &intervals,
                                     int query_point,
                                     vector<IntInterval> *results) {
  for (const IntInterval &candidate : intervals) {
    const bool before_start = query_point < candidate.left;
    const bool past_end = query_point > candidate.right;
    if (before_start || past_end) {
      continue;
    }
    results->push_back(candidate);
  }
}


// Find any intervals in 'intervals' which intersect 'query_interval' by brute force.
static void FindIntersectingBruteForce(const vector<IntInterval> &intervals,
                                       IntInterval query_interval,
                                       vector<IntInterval> *results) {
  for (const IntInterval &i : intervals) {
    if (query_interval.Intersects(i)) {
      results->push_back(i);
    }
  }
}


// Verify that IntervalTree::FindContainingPoint yields the same results as the naive
// brute-force O(n) algorithm.
//
// 'all_intervals' is the list that the brute-force scan runs over; 'tree' is
// the interval tree being checked against it. Both result sets are sorted with
// CompareIntervals and compared in stringified form, so the order in which
// either algorithm returns its results does not matter.
//
// 'all_intervals' is taken by const reference: this helper is called once per
// query point and never modifies the list, so copying it would be pure overhead.
static void VerifyFindContainingPoint(const vector<IntInterval> &all_intervals,
                                      const IntervalTree<IntTraits> &tree,
                                      int query_point) {
  vector<IntInterval> results;
  tree.FindContainingPoint(query_point, &results);
  std::sort(results.begin(), results.end(), CompareIntervals);

  vector<IntInterval> brute_force;
  FindContainingBruteForce(all_intervals, query_point, &brute_force);
  std::sort(brute_force.begin(), brute_force.end(), CompareIntervals);

  // Attach the full interval list and the query point to any failure message.
  SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" (q=%d)", query_point));
  EXPECT_EQ(Stringify(brute_force), Stringify(results));
}

// Verify that IntervalTree::FindIntersectingInterval yields the same results as the naive
// brute-force O(n) algorithm.
//
// 'all_intervals' is the list that the brute-force scan runs over; 'tree' is
// the interval tree being checked against it. Both result sets are sorted with
// CompareIntervals and compared in stringified form, so the order in which
// either algorithm returns its results does not matter.
//
// 'all_intervals' is taken by const reference: this helper is called once per
// query interval and never modifies the list, so copying it would be pure overhead.
static void VerifyFindIntersectingInterval(const vector<IntInterval> &all_intervals,
                                           const IntervalTree<IntTraits> &tree,
                                           const IntInterval &query_interval) {
  vector<IntInterval> results;
  tree.FindIntersectingInterval(query_interval, &results);
  std::sort(results.begin(), results.end(), CompareIntervals);

  vector<IntInterval> brute_force;
  FindIntersectingBruteForce(all_intervals, query_interval, &brute_force);
  std::sort(brute_force.begin(), brute_force.end(), CompareIntervals);

  // Attach the full interval list and the query interval to any failure message.
  SCOPED_TRACE(Stringify(all_intervals) +
               StringPrintf(" (q=[%d,%d])", query_interval.left, query_interval.right));
  EXPECT_EQ(Stringify(brute_force), Stringify(results));
}

// Builds 'n' random intervals using rand(). Each left endpoint is drawn from
// [0, 99] and each right endpoint is the left endpoint plus an offset drawn
// from [0, 19]. Interval ids are the creation indexes 0..n-1.
static vector<IntInterval> CreateRandomIntervals(int n = 100) {
  vector<IntInterval> generated;
  int next_id = 0;
  while (next_id < n) {
    const int left = rand() % 100; // NOLINT(runtime/threadsafe_fn)
    const int width = rand() % 20; // NOLINT(runtime/threadsafe_fn)
    generated.emplace_back(left, left + width, next_id);
    ++next_id;
  }
  return generated;
}

// Checks point and interval queries against a small hand-written set of
// intervals, for every query point in [0, 5] and every query interval
// [lo, hi] with 0 <= lo <= hi <= 5.
TEST_F(TestIntervalTree, TestBasic) {
  vector<IntInterval> fixture;
  fixture.push_back(IntInterval(1, 2, 1));
  fixture.push_back(IntInterval(3, 4, 2));
  fixture.push_back(IntInterval(1, 4, 3));
  IntervalTree<IntTraits> tree(fixture);

  const int kMaxPoint = 5;
  for (int lo = 0; lo <= kMaxPoint; ++lo) {
    VerifyFindContainingPoint(fixture, tree, lo);

    for (int hi = lo; hi <= kMaxPoint; ++hi) {
      VerifyFindIntersectingInterval(fixture, tree, IntInterval(lo, hi, 0));
    }
  }
}

// Compares the interval tree against the brute-force algorithms on a randomly
// generated set of intervals.
TEST_F(TestIntervalTree, TestRandomized) {
  SeedRandom();

  // Generate the default number (100) of random intervals and build an
  // interval tree from them. Per CreateRandomIntervals, their endpoints fall
  // within [0, 118].
  vector<IntInterval> random_intervals = CreateRandomIntervals();
  IntervalTree<IntTraits> tree(random_intervals);

  // Check every query point from -1 through 200, which covers the whole range
  // the intervals can occupy plus points on either side of it.
  for (int point = -1; point <= 200; ++point) {
    VerifyFindContainingPoint(random_intervals, tree, point);
  }

  // Check 100 random query intervals.
  int remaining = 100;
  while (remaining-- > 0) {
    const int left = rand() % 100; // NOLINT(runtime/threadsafe_fn)
    const int right = left + rand() % 100; // NOLINT(runtime/threadsafe_fn)
    VerifyFindIntersectingInterval(random_intervals, tree, IntInterval(left, right));
  }
}

// A tree built from no intervals must agree with the brute-force algorithms
// (which find nothing) for both a point query and an interval query.
TEST_F(TestIntervalTree, TestEmpty) {
  vector<IntInterval> no_intervals;
  IntervalTree<IntTraits> tree(no_intervals);

  const int kQueryPoint = 1;
  VerifyFindContainingPoint(no_intervals, tree, kQueryPoint);

  const IntInterval query_interval(1, 2, 0);
  VerifyFindIntersectingInterval(no_intervals, tree, query_interval);
}

// For a range of tree sizes and query counts, logs how many point comparisons
// the one-by-one lookup (FindContainingPoint) and the batch lookup
// (ForEachIntervalContainingPoints) perform, and asserts that both report the
// same number of results. Comparisons are counted through CountingQueryPoint.
//
// The test body is skipped when NDEBUG is not defined.
TEST_F(TestIntervalTree, TestBigO) {
#ifndef NDEBUG
  LOG(WARNING) << "big-O results are not valid if DCHECK is enabled";
  return;
#endif
  SeedRandom();

  LOG(INFO) << "num_int\tnum_q\tresults\tsimple\tbatch";
  for (int num_intervals = 1; num_intervals < 2000; num_intervals *= 2) {
    vector<IntInterval> intervals = CreateRandomIntervals(num_intervals);
    IntervalTree<IntTraits> t(intervals);
    for (int num_queries = 1; num_queries < 2000; num_queries *= 2) {
      // Build the query points and sort them by value before running the
      // batch query.
      vector<CountingQueryPoint> queries;
      for (int i = 0; i < num_queries; i++) {
        queries.emplace_back(rand() % 100); // NOLINT(runtime/threadsafe_fn)
      }
      std::sort(queries.begin(), queries.end(),
                [](const CountingQueryPoint& a,
                   const CountingQueryPoint& b) {
                  return a.val < b.val;
                });

      // Test using batch algorithm. The callback only counts results, so it
      // takes the query point by reference rather than copying it (and its
      // shared_ptr) on every invocation.
      int num_results_batch = 0;
      t.ForEachIntervalContainingPoints(
          queries,
          [&](const CountingQueryPoint& /*query_point*/, const IntInterval& /*interval*/) {
            num_results_batch++;
          });
      // Collect the batch comparison counts and reset the counters so the
      // one-by-one pass below starts from zero.
      int num_comparisons_batch = 0;
      for (const auto& q : queries) {
        num_comparisons_batch += *q.count;
        *q.count = 0;
      }

      // Test using one-by-one queries.
      int num_results_simple = 0;
      for (auto& q : queries) {
        // Named differently from the outer 'intervals' to avoid shadowing it.
        vector<IntInterval> containing;
        t.FindContainingPoint(q, &containing);
        num_results_simple += static_cast<int>(containing.size());
      }
      int num_comparisons_simple = 0;
      for (const auto& q : queries) {
        num_comparisons_simple += *q.count;
      }
      ASSERT_EQ(num_results_simple, num_results_batch);

      LOG(INFO) << num_intervals << "\t" << num_queries << "\t" << num_results_simple << "\t"
                << num_comparisons_simple << "\t" << num_comparisons_batch;
    }
  }
}

// Checks that the batch query (ForEachIntervalContainingPoints) returns the
// same (interval, query point) pairs as issuing FindContainingPoint once per
// query point, and that the batch results come back in the expected grouping.
TEST_F(TestIntervalTree, TestMultiQuery) {
  SeedRandom();
  // NOTE(review): with a single query point, the "same interval, ordered query
  // points" branch below can never fire -- confirm whether a larger batch was
  // intended here.
  const int kNumQueries = 1;
  vector<IntInterval> intervals = CreateRandomIntervals(10);
  IntervalTree<IntTraits> t(intervals);

  // Generate random queries, sorted as the ordering check below requires.
  vector<int> queries;
  for (int i = 0; i < kNumQueries; i++) {
    queries.push_back(rand() % 100); // NOLINT(runtime/threadsafe_fn)
  }
  std::sort(queries.begin(), queries.end());

  // Reference results: one FindContainingPoint call per query point.
  vector<pair<string, int>> results_simple;
  for (int q : queries) {
    // Named differently from the outer 'intervals' to avoid shadowing it.
    vector<IntInterval> containing;
    t.FindContainingPoint(q, &containing);
    for (const auto& interval : containing) {
      results_simple.emplace_back(interval.ToString(), q);
    }
  }

  // Results under test: a single batch call, recorded in callback order.
  vector<pair<string, int>> results_batch;
  t.ForEachIntervalContainingPoints(
      queries,
      [&](int query_point, const IntInterval& interval) {
        results_batch.emplace_back(interval.ToString(), query_point);
      });

  // Check the property that, when the batch query points are in sorted order,
  // the results are grouped by interval, and within each interval, sorted by
  // query point. Each interval may have at most two groups.
  boost::optional<pair<string, int>> prev = boost::none;
  // Number of result groups seen so far for each interval (keyed by its string form).
  std::map<string, int> intervals_seen;
  for (const auto& cur : results_batch) {
    // If it's another query point hitting the same interval,
    // make sure the query points are returned in order.
    if (prev && prev->first == cur.first) {
      EXPECT_GE(cur.second, prev->second) << prev->first;
    } else {
      // It's the start of a new group of results for this interval. Make sure
      // that we never see more than two groups for the same interval.
      EXPECT_LE(++intervals_seen[cur.first], 2)
          << "Saw more than two groups for interval " << cur.first;
    }
    prev = cur;
  }

  // Once both are sorted, the batch results must equal the reference results.
  std::sort(results_simple.begin(), results_simple.end());
  std::sort(results_batch.begin(), results_batch.end());
  ASSERT_EQ(results_simple, results_batch);
}

} // namespace kudu
