blob: 286f0ca00392459afcc1e96c913dfd4013040617 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <iostream>
#include <sstream>
#include <algorithm>
#include "conditional_forward.hpp"
namespace datasketches {
template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
theta_intersection_base<EN, EK, P, S, CS, A>::theta_intersection_base(uint64_t seed, const P& policy, const A& allocator):
policy_(policy),
is_valid_(false),
table_(0, 0, resize_factor::X1, theta_constants::MAX_THETA, seed, allocator, false)
{}
/**
 * Intersects the current state with the given sketch.
 *
 * Behaviour by case:
 * - If the state is already empty, the call returns immediately.
 * - The "empty" flag and theta are updated from the input first. An empty
 *   state always carries theta == theta_constants::MAX_THETA, otherwise theta
 *   becomes the minimum of the current theta and the theta of the input.
 * - If the input retains no entries, the table is reset to zero size.
 * - On the first update with retained entries, the entries are copied or moved
 *   (depending on the value category of SS) into a freshly sized table.
 * - On subsequent updates, only entries whose keys are found in both the table
 *   and the input (and are below theta) are kept; each matching pair is
 *   combined by the policy before being retained.
 *
 * NOTE: the internal state is modified before all validation checks finish,
 * so after an exception this object may hold partially processed data.
 *
 * @param sketch input sketch, taken by forwarding reference; entries are
 *        passed on through conditional_forward<SS>
 * @throws std::invalid_argument if the seed hash of a non-empty input does not
 *         match, or if the input looks corrupted (duplicate keys, or an entry
 *         count that disagrees with get_num_retained())
 */
template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
template<typename SS>
void theta_intersection_base<EN, EK, P, S, CS, A>::update(SS&& sketch) {
  if (table_.is_empty_) return; // intersection with the empty set stays empty
  if (!sketch.is_empty() && sketch.get_seed_hash() != compute_seed_hash(table_.seed_)) throw std::invalid_argument("seed hash mismatch");
  table_.is_empty_ |= sketch.is_empty();
  // An empty result must not keep a reduced theta left over from earlier
  // inputs: empty implies theta == MAX_THETA.
  if (table_.is_empty_) {
    table_.theta_ = theta_constants::MAX_THETA;
  } else {
    table_.theta_ = std::min(table_.theta_, sketch.get_theta64());
  }
  // nothing retained already: only theta and the empty flag could change
  if (is_valid_ && table_.num_entries_ == 0) return;
  if (sketch.get_num_retained() == 0) {
    is_valid_ = true;
    table_ = hash_table(0, 0, resize_factor::X1, table_.theta_, table_.seed_, table_.allocator_, table_.is_empty_);
    return;
  }
  if (!is_valid_) { // first update, copy or move incoming sketch
    is_valid_ = true;
    const uint8_t lg_size = lg_size_from_count(sketch.get_num_retained(), theta_update_sketch_base<EN, EK, A>::REBUILD_THRESHOLD);
    table_ = hash_table(lg_size, lg_size, resize_factor::X1, table_.theta_, table_.seed_, table_.allocator_, table_.is_empty_);
    // auto&& also accepts iterators that yield proxies or temporaries
    for (auto&& entry: sketch) {
      auto result = table_.find(EK()(entry));
      if (result.second) {
        throw std::invalid_argument("duplicate key, possibly corrupted input sketch");
      }
      table_.insert(result.first, conditional_forward<SS>(entry));
    }
    if (table_.num_entries_ != sketch.get_num_retained()) throw std::invalid_argument("num entries mismatch, possibly corrupted input sketch");
  } else { // intersection
    const uint32_t max_matches = std::min(table_.num_entries_, sketch.get_num_retained());
    std::vector<EN, A> matched_entries(table_.allocator_);
    matched_entries.reserve(max_matches);
    uint32_t match_count = 0;
    uint32_t count = 0; // number of input entries visited, used for sanity checks below
    for (auto&& entry: sketch) {
      if (EK()(entry) < table_.theta_) {
        auto result = table_.find(EK()(entry));
        if (result.second) {
          if (match_count == max_matches) throw std::invalid_argument("max matches exceeded, possibly corrupted input sketch");
          // let the policy combine the stored entry with the incoming one,
          // then move the combined entry out of the table
          policy_(*result.first, conditional_forward<SS>(entry));
          matched_entries.push_back(std::move(*result.first));
          ++match_count;
        }
      } else if (sketch.is_ordered()) {
        break; // early stop
      }
      ++count;
    }
    if (count > sketch.get_num_retained()) {
      throw std::invalid_argument(" more keys than expected, possibly corrupted input sketch");
    } else if (!sketch.is_ordered() && count < sketch.get_num_retained()) {
      // an ordered input may legitimately stop early, an unordered one may not
      throw std::invalid_argument(" fewer keys than expected, possibly corrupted input sketch");
    }
    if (match_count == 0) {
      table_ = hash_table(0, 0, resize_factor::X1, table_.theta_, table_.seed_, table_.allocator_, table_.is_empty_);
      // no common entries and no sampling applied: the result is exactly empty
      if (table_.theta_ == theta_constants::MAX_THETA) table_.is_empty_ = true;
    } else {
      // rebuild the table sized for the surviving entries only
      const uint8_t lg_size = lg_size_from_count(match_count, theta_update_sketch_base<EN, EK, A>::REBUILD_THRESHOLD);
      table_ = hash_table(lg_size, lg_size, resize_factor::X1, table_.theta_, table_.seed_, table_.allocator_, table_.is_empty_);
      for (uint32_t i = 0; i < match_count; i++) {
        auto result = table_.find(EK()(matched_entries[i]));
        table_.insert(result.first, std::move(matched_entries[i]));
      }
    }
  }
}
/**
 * Builds a result sketch of type CS from the current intersection state.
 *
 * Entries accepted by key_not_zero are copied out of the internal table;
 * the table itself is left untouched.
 *
 * @param ordered if true, the copied entries are sorted with comparator()
 * @return CS constructed from the empty flag, the ordered flag, the seed hash,
 *         theta and the copied entries
 * @throws std::invalid_argument if update() has not been called yet
 */
template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
CS theta_intersection_base<EN, EK, P, S, CS, A>::get_result(bool ordered) const {
  if (!is_valid_) {
    throw std::invalid_argument("calling get_result() before calling update() is undefined");
  }

  std::vector<EN, A> retained(table_.allocator_);
  const bool has_entries = table_.num_entries_ > 0;
  if (has_entries) {
    retained.reserve(table_.num_entries_);
    // skip vacant slots of the table
    std::copy_if(table_.begin(), table_.end(), std::back_inserter(retained), key_not_zero<EN, EK>());
    if (ordered) {
      std::sort(retained.begin(), retained.end(), comparator());
    }
  }

  const auto seed_hash = compute_seed_hash(table_.seed_);
  return CS(table_.is_empty_, ordered, seed_hash, table_.theta_, std::move(retained));
}
/**
 * Tells whether a result is available.
 *
 * @return the validity flag, which update() sets; while it is false,
 *         get_result() throws
 */
template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
bool theta_intersection_base<EN, EK, P, S, CS, A>::has_result() const
{
  return is_valid_;
}
/**
 * Gives read-only access to the policy object held by this intersection.
 *
 * @return const reference to the stored policy
 */
template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
const P& theta_intersection_base<EN, EK, P, S, CS, A>::get_policy() const
{
  return policy_;
}
} /* namespace datasketches */