aod intersection
diff --git a/tuple/CMakeLists.txt b/tuple/CMakeLists.txt
index 580ffa9..48dedae 100644
--- a/tuple/CMakeLists.txt
+++ b/tuple/CMakeLists.txt
@@ -38,7 +38,8 @@
list(APPEND tuple_HEADERS "include/tuple_intersection.hpp;include/tuple_intersection_impl.hpp")
list(APPEND tuple_HEADERS "include/tuple_a_not_b.hpp;include/tuple_a_not_b_impl.hpp")
list(APPEND tuple_HEADERS "include/array_of_doubles_sketch.hpp;include/array_of_doubles_sketch_impl.hpp")
-list(APPEND tuple_HEADERS "include/array_of_doubles_union.hpp")
+list(APPEND tuple_HEADERS "include/array_of_doubles_union.hpp;include/array_of_doulbes_union_impl.hpp")
+list(APPEND tuple_HEADERS "include/array_of_doubles_intersection.hpp;include/array_of_doulbes_intersection_impl.hpp")
list(APPEND tuple_HEADERS "include/theta_update_sketch_base.hpp;include/theta_update_sketch_base_impl.hpp")
list(APPEND tuple_HEADERS "include/theta_union_base.hpp;include/theta_union_base_impl.hpp")
list(APPEND tuple_HEADERS "include/theta_intersection_base.hpp;include/theta_intersection_base_impl.hpp")
@@ -68,6 +69,9 @@
${CMAKE_CURRENT_SOURCE_DIR}/include/array_of_doubles_sketch.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/array_of_doubles_sketch_impl.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/array_of_doubles_union.hpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/include/array_of_doubles_union_impl.hpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/include/array_of_doubles_intersection.hpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/include/array_of_doubles_intersection_impl.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/theta_update_sketch_base.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/theta_update_sketch_base_impl.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/theta_union_base.hpp
diff --git a/tuple/include/array_of_doubles_intersection.hpp b/tuple/include/array_of_doubles_intersection.hpp
new file mode 100644
index 0000000..008d9d6
--- /dev/null
+++ b/tuple/include/array_of_doubles_intersection.hpp
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef ARRAY_OF_DOUBLES_INTERSECTION_HPP_
+#define ARRAY_OF_DOUBLES_INTERSECTION_HPP_
+
+#include <vector>
+#include <memory>
+
+#include "array_of_doubles_sketch.hpp"
+#include "tuple_intersection.hpp"
+
+namespace datasketches {
+
+template<
+ typename Policy,
+ typename Allocator = std::allocator<double>
+>
+class array_of_doubles_intersection: public tuple_intersection<std::vector<double, Allocator>, Policy, AllocVectorDouble<Allocator>> {
+public:
+ using Summary = std::vector<double, Allocator>;
+ using AllocSummary = AllocVectorDouble<Allocator>;
+ using Base = tuple_intersection<Summary, Policy, AllocSummary>;
+ using CompactSketch = compact_array_of_doubles_sketch_alloc<Allocator>;
+ using resize_factor = theta_constants::resize_factor;
+
+ explicit array_of_doubles_intersection(uint64_t seed = DEFAULT_SEED, const Policy& policy = Policy(), const Allocator& allocator = Allocator());
+
+ CompactSketch get_result(bool ordered = true) const;
+};
+
+} /* namespace datasketches */
+
+#include "array_of_doubles_intersection_impl.hpp"
+
+#endif
diff --git a/tuple/include/array_of_doubles_intersection_impl.hpp b/tuple/include/array_of_doubles_intersection_impl.hpp
new file mode 100644
index 0000000..7cd2472
--- /dev/null
+++ b/tuple/include/array_of_doubles_intersection_impl.hpp
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+namespace datasketches {
+
+template<typename P, typename A>
+array_of_doubles_intersection<P, A>::array_of_doubles_intersection(uint64_t seed, const P& policy, const A& allocator):
+Base(seed, policy, allocator) {}
+
+template<typename P, typename A>
+auto array_of_doubles_intersection<P, A>::get_result(bool ordered) const -> CompactSketch {
+ return compact_array_of_doubles_sketch_alloc<A>(this->state_.get_policy().get_policy().get_num_values(), Base::get_result(ordered));
+}
+
+} /* namespace datasketches */
diff --git a/tuple/include/array_of_doubles_union.hpp b/tuple/include/array_of_doubles_union.hpp
index 2b4c77c..a70a015 100644
--- a/tuple/include/array_of_doubles_union.hpp
+++ b/tuple/include/array_of_doubles_union.hpp
@@ -28,9 +28,9 @@
namespace datasketches {
-template<typename A>
-struct array_of_doubles_union_policy {
- array_of_doubles_union_policy(uint8_t num_values = 1): num_values_(num_values) {}
+template<typename A = std::allocator<double>>
+struct array_of_doubles_union_policy_alloc {
+ array_of_doubles_union_policy_alloc(uint8_t num_values = 1): num_values_(num_values) {}
void operator()(std::vector<double, A>& summary, const std::vector<double, A>& other) const {
for (size_t i = 0; i < summary.size(); ++i) {
@@ -45,10 +45,12 @@
uint8_t num_values_;
};
+using array_of_doubles_union_policy = array_of_doubles_union_policy_alloc<>;
+
template<typename Allocator = std::allocator<double>>
-class array_of_doubles_union_alloc: public tuple_union<std::vector<double, Allocator>, array_of_doubles_union_policy<Allocator>, AllocVectorDouble<Allocator>> {
+class array_of_doubles_union_alloc: public tuple_union<std::vector<double, Allocator>, array_of_doubles_union_policy_alloc<Allocator>, AllocVectorDouble<Allocator>> {
public:
- using Policy = array_of_doubles_union_policy<Allocator>;
+ using Policy = array_of_doubles_union_policy_alloc<Allocator>;
using Base = tuple_union<std::vector<double, Allocator>, Policy, AllocVectorDouble<Allocator>>;
using CompactSketch = compact_array_of_doubles_sketch_alloc<Allocator>;
using resize_factor = theta_constants::resize_factor;
@@ -63,9 +65,9 @@
};
template<typename Allocator>
-class array_of_doubles_union_alloc<Allocator>::builder: public tuple_base_builder<builder, array_of_doubles_union_policy<Allocator>, Allocator> {
+class array_of_doubles_union_alloc<Allocator>::builder: public tuple_base_builder<builder, array_of_doubles_union_policy_alloc<Allocator>, Allocator> {
public:
- builder(const array_of_doubles_union_policy<Allocator>& policy = array_of_doubles_union_policy<Allocator>(), const Allocator& allocator = Allocator());
+ builder(const array_of_doubles_union_policy_alloc<Allocator>& policy = array_of_doubles_union_policy_alloc<Allocator>(), const Allocator& allocator = Allocator());
array_of_doubles_union_alloc<Allocator> build() const;
};
diff --git a/tuple/include/array_of_doubles_union_impl.hpp b/tuple/include/array_of_doubles_union_impl.hpp
index 80a385a..57899d9 100644
--- a/tuple/include/array_of_doubles_union_impl.hpp
+++ b/tuple/include/array_of_doubles_union_impl.hpp
@@ -32,8 +32,8 @@
// builder
template<typename A>
-array_of_doubles_union_alloc<A>::builder::builder(const array_of_doubles_union_policy<A>& policy, const A& allocator):
-tuple_base_builder<builder, array_of_doubles_union_policy<A>, A>(policy, allocator) {}
+array_of_doubles_union_alloc<A>::builder::builder(const Policy& policy, const A& allocator):
+tuple_base_builder<builder, Policy, A>(policy, allocator) {}
template<typename A>
array_of_doubles_union_alloc<A> array_of_doubles_union_alloc<A>::builder::build() const {
diff --git a/tuple/include/theta_intersection_base.hpp b/tuple/include/theta_intersection_base.hpp
index 1313e11..c034590 100644
--- a/tuple/include/theta_intersection_base.hpp
+++ b/tuple/include/theta_intersection_base.hpp
@@ -44,6 +44,8 @@
bool has_result() const;
+ const Policy& get_policy() const;
+
private:
Policy policy_;
bool is_valid_;
diff --git a/tuple/include/theta_intersection_base_impl.hpp b/tuple/include/theta_intersection_base_impl.hpp
index 970ec50..0b5817c 100644
--- a/tuple/include/theta_intersection_base_impl.hpp
+++ b/tuple/include/theta_intersection_base_impl.hpp
@@ -113,4 +113,9 @@
return is_valid_;
}
+template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
+const P& theta_intersection_base<EN, EK, P, S, CS, A>::get_policy() const {
+ return policy_;
+}
+
} /* namespace datasketches */
diff --git a/tuple/include/tuple_intersection.hpp b/tuple/include/tuple_intersection.hpp
index e03e288..966ea9f 100644
--- a/tuple/include/tuple_intersection.hpp
+++ b/tuple/include/tuple_intersection.hpp
@@ -61,6 +61,7 @@
void operator()(Entry& internal_entry, Entry&& incoming_entry) const {
policy_(internal_entry.second, std::move(incoming_entry.second));
}
+ const Policy& get_policy() const { return policy_; }
Policy policy_;
};
@@ -92,7 +93,7 @@
*/
bool has_result() const;
-private:
+protected:
State state_;
};
diff --git a/tuple/test/array_of_doubles_sketch_test.cpp b/tuple/test/array_of_doubles_sketch_test.cpp
index 7c2857f..98ecf6a 100644
--- a/tuple/test/array_of_doubles_sketch_test.cpp
+++ b/tuple/test/array_of_doubles_sketch_test.cpp
@@ -25,6 +25,7 @@
#include <catch.hpp>
#include <array_of_doubles_sketch.hpp>
#include <array_of_doubles_union.hpp>
+#include <array_of_doubles_intersection.hpp>
namespace datasketches {
@@ -263,4 +264,20 @@
REQUIRE(result.get_estimate() == Approx(1500).margin(0.01));
}
+TEST_CASE("aod intersection: half overlap", "[tuple_sketch]") {
+ std::vector<double> a = {1};
+
+ auto update_sketch1 = update_array_of_doubles_sketch::builder().build();
+ for (int i = 0; i < 1000; ++i) update_sketch1.update(i, a);
+
+ auto update_sketch2 = update_array_of_doubles_sketch::builder().build();
+ for (int i = 500; i < 1500; ++i) update_sketch2.update(i, a);
+
+ array_of_doubles_intersection<array_of_doubles_union_policy> intersection;
+ intersection.update(update_sketch1);
+ intersection.update(update_sketch2);
+ auto result = intersection.get_result();
+ REQUIRE(result.get_estimate() == Approx(500).margin(0.01));
+}
+
} /* namespace datasketches */