blob: df143d3b294391c4e75112a1c01e4ad25c02aac5 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <tuple> // IWYU pragma: keep
#include <utility>
#include <vector>
#include <boost/optional/optional.hpp>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "kudu/gutil/stringprintf.h"
#include "kudu/gutil/strings/substitute.h"
#include "kudu/util/interval_tree.h"
#include "kudu/util/interval_tree-inl.h"
#include "kudu/util/test_util.h"
using std::pair;
using std::string;
using std::vector;
using strings::Substitute;
namespace kudu {
// Test harness.
class TestIntervalTree : public KuduTest {
};
// Simple interval class for integer intervals.
struct IntInterval {
IntInterval(int left, int right, int id = -1)
: left(left),
right(right),
id(id) {
}
bool Intersects(const IntInterval &other) const {
if (other.left > right) return false;
if (left > other.right) return false;
return true;
}
string ToString() const {
return strings::Substitute("[$0, $1]($2) ", left, right, id);
}
int left, right, id;
};
// A wrapper around an int which can be compared with IntTraits::compare()
// but also can keep a counter of how many times it has been compared. Used
// for TestBigO below.
struct CountingQueryPoint {
explicit CountingQueryPoint(int v)
: val(v),
count(new int(0)) {
}
int val;
std::shared_ptr<int> count;
};
// Traits definition for intervals made up of ints on either end.
struct IntTraits {
typedef int point_type;
typedef IntInterval interval_type;
static point_type get_left(const IntInterval &x) {
return x.left;
}
static point_type get_right(const IntInterval &x) {
return x.right;
}
static int compare(int a, int b) {
if (a < b) return -1;
if (a > b) return 1;
return 0;
}
static int compare(const CountingQueryPoint& q, int b) {
(*q.count)++;
return compare(q.val, b);
}
static int compare(int a, const CountingQueryPoint& b) {
return -compare(b, a);
}
};
// Compare intervals in an arbitrary but consistent way - this is only
// used for verifying that the two algorithms come up with the same results.
// It's not necessary to define this to use an interval tree.
static bool CompareIntervals(const IntInterval &a, const IntInterval &b) {
return std::make_tuple(a.left, a.right, a.id) <
std::make_tuple(b.left, b.right, b.id);
}
// Stringify a list of int intervals, for easy test error reporting.
static string Stringify(const vector<IntInterval> &intervals) {
string ret;
bool first = true;
for (const IntInterval &interval : intervals) {
if (!first) {
ret.append(",");
}
ret.append(interval.ToString());
}
return ret;
}
// Find any intervals in 'intervals' which contain 'query_point' by brute force.
static void FindContainingBruteForce(const vector<IntInterval> &intervals,
int query_point,
vector<IntInterval> *results) {
for (const IntInterval &i : intervals) {
if (query_point >= i.left && query_point <= i.right) {
results->push_back(i);
}
}
}
// Find any intervals in 'intervals' which intersect 'query_interval' by brute force.
static void FindIntersectingBruteForce(const vector<IntInterval> &intervals,
IntInterval query_interval,
vector<IntInterval> *results) {
for (const IntInterval &i : intervals) {
if (query_interval.Intersects(i)) {
results->push_back(i);
}
}
}
// Verify that IntervalTree::FindContainingPoint yields the same results as the naive
// brute-force O(n) algorithm.
static void VerifyFindContainingPoint(const vector<IntInterval> all_intervals,
const IntervalTree<IntTraits> &tree,
int query_point) {
vector<IntInterval> results;
tree.FindContainingPoint(query_point, &results);
std::sort(results.begin(), results.end(), CompareIntervals);
vector<IntInterval> brute_force;
FindContainingBruteForce(all_intervals, query_point, &brute_force);
std::sort(brute_force.begin(), brute_force.end(), CompareIntervals);
SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" (q=%d)", query_point));
EXPECT_EQ(Stringify(brute_force), Stringify(results));
}
// Verify that IntervalTree::FindIntersectingInterval yields the same results as the naive
// brute-force O(n) algorithm.
static void VerifyFindIntersectingInterval(const vector<IntInterval> all_intervals,
const IntervalTree<IntTraits> &tree,
const IntInterval &query_interval) {
vector<IntInterval> results;
tree.FindIntersectingInterval(query_interval, &results);
std::sort(results.begin(), results.end(), CompareIntervals);
vector<IntInterval> brute_force;
FindIntersectingBruteForce(all_intervals, query_interval, &brute_force);
std::sort(brute_force.begin(), brute_force.end(), CompareIntervals);
SCOPED_TRACE(Stringify(all_intervals) +
StringPrintf(" (q=[%d,%d])", query_interval.left, query_interval.right));
EXPECT_EQ(Stringify(brute_force), Stringify(results));
}
static vector<IntInterval> CreateRandomIntervals(int n = 100) {
vector<IntInterval> intervals;
for (int i = 0; i < n; i++) {
int l = rand() % 100; // NOLINT(runtime/threadsafe_fn)
int r = l + rand() % 20; // NOLINT(runtime/threadsafe_fn)
intervals.emplace_back(l, r, i);
}
return intervals;
}
TEST_F(TestIntervalTree, TestBasic) {
vector<IntInterval> intervals;
intervals.emplace_back(1, 2, 1);
intervals.emplace_back(3, 4, 2);
intervals.emplace_back(1, 4, 3);
IntervalTree<IntTraits> t(intervals);
for (int i = 0; i <= 5; i++) {
VerifyFindContainingPoint(intervals, t, i);
for (int j = i; j <= 5; j++) {
VerifyFindIntersectingInterval(intervals, t, IntInterval(i, j, 0));
}
}
}
TEST_F(TestIntervalTree, TestRandomized) {
SeedRandom();
// Generate 100 random intervals spanning 0-200 and build an interval tree from them.
vector<IntInterval> intervals = CreateRandomIntervals();
IntervalTree<IntTraits> t(intervals);
// Test that we get the correct result on every possible query.
for (int i = -1; i < 201; i++) {
VerifyFindContainingPoint(intervals, t, i);
}
// Test that we get the correct result for random intervals
for (int i = 0; i < 100; i++) {
int l = rand() % 100; // NOLINT(runtime/threadsafe_fn)
int r = l + rand() % 100; // NOLINT(runtime/threadsafe_fn)
VerifyFindIntersectingInterval(intervals, t, IntInterval(l, r));
}
}
TEST_F(TestIntervalTree, TestEmpty) {
vector<IntInterval> empty;
IntervalTree<IntTraits> t(empty);
VerifyFindContainingPoint(empty, t, 1);
VerifyFindIntersectingInterval(empty, t, IntInterval(1, 2, 0));
}
TEST_F(TestIntervalTree, TestBigO) {
#ifndef NDEBUG
LOG(WARNING) << "big-O results are not valid if DCHECK is enabled";
return;
#endif
SeedRandom();
LOG(INFO) << "num_int\tnum_q\tresults\tsimple\tbatch";
for (int num_intervals = 1; num_intervals < 2000; num_intervals *= 2) {
vector<IntInterval> intervals = CreateRandomIntervals(num_intervals);
IntervalTree<IntTraits> t(intervals);
for (int num_queries = 1; num_queries < 2000; num_queries *= 2) {
vector<CountingQueryPoint> queries;
for (int i = 0; i < num_queries; i++) {
queries.emplace_back(rand() % 100);
}
std::sort(queries.begin(), queries.end(),
[](const CountingQueryPoint& a,
const CountingQueryPoint& b) {
return a.val < b.val;
});
// Test using batch algorithm.
int num_results_batch = 0;
t.ForEachIntervalContainingPoints(
queries,
[&](CountingQueryPoint query_point, const IntInterval& interval) {
num_results_batch++;
});
int num_comparisons_batch = 0;
for (const auto& q : queries) {
num_comparisons_batch += *q.count;
*q.count = 0;
}
// Test using one-by-one queries.
int num_results_simple = 0;
for (auto& q : queries) {
vector<IntInterval> intervals;
t.FindContainingPoint(q, &intervals);
num_results_simple += intervals.size();
}
int num_comparisons_simple = 0;
for (const auto& q : queries) {
num_comparisons_simple += *q.count;
}
ASSERT_EQ(num_results_simple, num_results_batch);
LOG(INFO) << num_intervals << "\t" << num_queries << "\t" << num_results_simple << "\t"
<< num_comparisons_simple << "\t" << num_comparisons_batch;
}
}
}
TEST_F(TestIntervalTree, TestMultiQuery) {
SeedRandom();
const int kNumQueries = 1;
vector<IntInterval> intervals = CreateRandomIntervals(10);
IntervalTree<IntTraits> t(intervals);
// Generate random queries.
vector<int> queries;
for (int i = 0; i < kNumQueries; i++) {
queries.push_back(rand() % 100);
}
std::sort(queries.begin(), queries.end());
vector<pair<string, int>> results_simple;
for (int q : queries) {
vector<IntInterval> intervals;
t.FindContainingPoint(q, &intervals);
for (const auto& interval : intervals) {
results_simple.emplace_back(interval.ToString(), q);
}
}
vector<pair<string, int>> results_batch;
t.ForEachIntervalContainingPoints(
queries,
[&](int query_point, const IntInterval& interval) {
results_batch.emplace_back(interval.ToString(), query_point);
});
// Check the property that, when the batch query points are in sorted order,
// the results are grouped by interval, and within each interval, sorted by
// query point. Each interval may have at most two groups.
boost::optional<pair<string, int>> prev = boost::none;
std::map<string, int> intervals_seen;
for (int i = 0; i < results_batch.size(); i++) {
const auto& cur = results_batch[i];
// If it's another query point hitting the same interval,
// make sure the query points are returned in order.
if (prev && prev->first == cur.first) {
EXPECT_GE(cur.second, prev->second) << prev->first;
} else {
// It's the start of a new interval's data. Make sure that we don't
// see the same interval twice.
EXPECT_LE(++intervals_seen[cur.first], 2)
<< "Saw more than two groups for interval " << cur.first;
}
prev = cur;
}
std::sort(results_simple.begin(), results_simple.end());
std::sort(results_batch.begin(), results_batch.end());
ASSERT_EQ(results_simple, results_batch);
}
} // namespace kudu