aod union, allocator fixes
diff --git a/tuple/include/array_of_doubles_sketch.hpp b/tuple/include/array_of_doubles_sketch.hpp
index f1ada69..a51b46b 100644
--- a/tuple/include/array_of_doubles_sketch.hpp
+++ b/tuple/include/array_of_doubles_sketch.hpp
@@ -30,34 +30,36 @@
// equivalent of ArrayOfDoublesSketch in Java
-template<typename A = std::allocator<std::vector<double>>>
+template<typename A = std::allocator<double>>
class array_of_doubles_update_policy {
public:
array_of_doubles_update_policy(uint8_t num_values = 1, const A& allocator = A()):
- allocator(allocator), num_values(num_values) {}
- std::vector<double> create() const {
- return std::vector<double>(num_values, 0, allocator);
+ allocator_(allocator), num_values_(num_values) {}
+ std::vector<double, A> create() const {
+ return std::vector<double, A>(num_values_, 0, allocator_);
}
- void update(std::vector<double>& summary, const std::vector<double>& update) const {
- for (uint8_t i = 0; i < num_values; ++i) summary[i] += update[i];
+ void update(std::vector<double, A>& summary, const std::vector<double, A>& update) const {
+ for (uint8_t i = 0; i < num_values_; ++i) summary[i] += update[i];
}
uint8_t get_num_values() const {
- return num_values;
+ return num_values_;
}
private:
- A allocator;
- uint8_t num_values;
+ A allocator_;
+ uint8_t num_values_;
};
// forward declaration
template<typename A> class compact_array_of_doubles_sketch_alloc;
-template<typename A = std::allocator<std::vector<double>>>
-class update_array_of_doubles_sketch_alloc: public update_tuple_sketch<std::vector<double>, std::vector<double>,
-array_of_doubles_update_policy<A>, A> {
+template<typename A> using AllocVectorDouble = typename std::allocator_traits<A>::template rebind_alloc<std::vector<double, A>>;
+
+template<typename A = std::allocator<double>>
+class update_array_of_doubles_sketch_alloc: public update_tuple_sketch<std::vector<double, A>, std::vector<double, A>,
+array_of_doubles_update_policy<A>, AllocVectorDouble<A>> {
public:
- using Base = update_tuple_sketch<std::vector<double>, std::vector<double>, array_of_doubles_update_policy<A>, A>;
+ using Base = update_tuple_sketch<std::vector<double, A>, std::vector<double, A>, array_of_doubles_update_policy<A>, AllocVectorDouble<A>>;
using resize_factor = typename Base::resize_factor;
class builder;
@@ -81,10 +83,10 @@
update_array_of_doubles_sketch_alloc<A> build() const;
};
-template<typename A = std::allocator<std::vector<double>>>
-class compact_array_of_doubles_sketch_alloc: public compact_tuple_sketch<std::vector<double>, A> {
+template<typename A = std::allocator<double>>
+class compact_array_of_doubles_sketch_alloc: public compact_tuple_sketch<std::vector<double, A>, AllocVectorDouble<A>> {
public:
- using Base = compact_tuple_sketch<std::vector<double>, A>;
+ using Base = compact_tuple_sketch<std::vector<double, A>, AllocVectorDouble<A>>;
using Entry = typename Base::Entry;
using AllocEntry = typename Base::AllocEntry;
using AllocU64 = typename Base::AllocU64;
@@ -109,6 +111,7 @@
// for internal use
compact_array_of_doubles_sketch_alloc(bool is_empty, bool is_ordered, uint16_t seed_hash, uint64_t theta, std::vector<Entry, AllocEntry>&& entries, uint8_t num_values);
+ compact_array_of_doubles_sketch_alloc(uint8_t num_values, Base&& base);
private:
uint8_t num_values_;
};
diff --git a/tuple/include/array_of_doubles_sketch_impl.hpp b/tuple/include/array_of_doubles_sketch_impl.hpp
index 455a52c..16fa925 100644
--- a/tuple/include/array_of_doubles_sketch_impl.hpp
+++ b/tuple/include/array_of_doubles_sketch_impl.hpp
@@ -59,6 +59,10 @@
Base(is_empty, is_ordered, seed_hash, theta, std::move(entries)), num_values_(num_values) {}
template<typename A>
+compact_array_of_doubles_sketch_alloc<A>::compact_array_of_doubles_sketch_alloc(uint8_t num_values, Base&& base):
+Base(std::move(base)), num_values_(num_values) {}
+
+template<typename A>
uint8_t compact_array_of_doubles_sketch_alloc<A>::get_num_values() const {
return num_values_;
}
@@ -172,7 +176,7 @@
std::vector<uint64_t, AllocU64> keys(num_entries, 0, allocator);
is.read(reinterpret_cast<char*>(keys.data()), num_entries * sizeof(uint64_t));
for (size_t i = 0; i < num_entries; ++i) {
- std::vector<double> summary(num_values, 0, allocator);
+ std::vector<double, A> summary(num_values, 0, allocator);
is.read(reinterpret_cast<char*>(summary.data()), num_values * sizeof(double));
entries.push_back(Entry(keys[i], std::move(summary)));
}
@@ -221,7 +225,7 @@
std::vector<uint64_t, AllocU64> keys(num_entries, 0, allocator);
ptr += copy_from_mem(ptr, keys.data(), sizeof(uint64_t) * num_entries);
for (size_t i = 0; i < num_entries; ++i) {
- std::vector<double> summary(num_values, 0, allocator);
+ std::vector<double, A> summary(num_values, 0, allocator);
ptr += copy_from_mem(ptr, summary.data(), num_values * sizeof(double));
entries.push_back(Entry(keys[i], std::move(summary)));
}
diff --git a/tuple/include/array_of_doubles_union.hpp b/tuple/include/array_of_doubles_union.hpp
index eafcab1..2b4c77c 100644
--- a/tuple/include/array_of_doubles_union.hpp
+++ b/tuple/include/array_of_doubles_union.hpp
@@ -23,24 +23,57 @@
#include <vector>
#include <memory>
+#include "array_of_doubles_sketch.hpp"
#include "tuple_union.hpp"
namespace datasketches {
+template<typename A>
struct array_of_doubles_union_policy {
- void operator()(std::vector<double>& summary, const std::vector<double>& other) const {
+ array_of_doubles_union_policy(uint8_t num_values = 1): num_values_(num_values) {}
+
+ void operator()(std::vector<double, A>& summary, const std::vector<double, A>& other) const {
for (size_t i = 0; i < summary.size(); ++i) {
summary[i] += other[i];
}
}
+
+ uint8_t get_num_values() const {
+ return num_values_;
+ }
+private:
+ uint8_t num_values_;
};
-template<typename Allocator = std::allocator<std::vector<double>>>
-using array_of_doubles_union_alloc = tuple_union<std::vector<double>, array_of_doubles_union_policy, Allocator>;
+template<typename Allocator = std::allocator<double>>
+class array_of_doubles_union_alloc: public tuple_union<std::vector<double, Allocator>, array_of_doubles_union_policy<Allocator>, AllocVectorDouble<Allocator>> {
+public:
+ using Policy = array_of_doubles_union_policy<Allocator>;
+ using Base = tuple_union<std::vector<double, Allocator>, Policy, AllocVectorDouble<Allocator>>;
+ using CompactSketch = compact_array_of_doubles_sketch_alloc<Allocator>;
+ using resize_factor = theta_constants::resize_factor;
+
+ class builder;
+
+ CompactSketch get_result(bool ordered = true) const;
+
+private:
+ // for builder
+ array_of_doubles_union_alloc(uint8_t lg_cur_size, uint8_t lg_nom_size, resize_factor rf, uint64_t theta, uint64_t seed, const Policy& policy, const Allocator& allocator);
+};
+
+template<typename Allocator>
+class array_of_doubles_union_alloc<Allocator>::builder: public tuple_base_builder<builder, array_of_doubles_union_policy<Allocator>, Allocator> {
+public:
+ builder(const array_of_doubles_union_policy<Allocator>& policy = array_of_doubles_union_policy<Allocator>(), const Allocator& allocator = Allocator());
+ array_of_doubles_union_alloc<Allocator> build() const;
+};
// alias with default allocator
using array_of_doubles_union = array_of_doubles_union_alloc<>;
} /* namespace datasketches */
+#include "array_of_doubles_union_impl.hpp"
+
#endif
diff --git a/tuple/include/array_of_doubles_union_impl.hpp b/tuple/include/array_of_doubles_union_impl.hpp
new file mode 100644
index 0000000..80a385a
--- /dev/null
+++ b/tuple/include/array_of_doubles_union_impl.hpp
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+namespace datasketches {
+
+template<typename A>
+array_of_doubles_union_alloc<A>::array_of_doubles_union_alloc(uint8_t lg_cur_size, uint8_t lg_nom_size, resize_factor rf, uint64_t theta, uint64_t seed, const Policy& policy, const A& allocator):
+Base(lg_cur_size, lg_nom_size, rf, theta, seed, policy, allocator)
+{}
+
+template<typename A>
+auto array_of_doubles_union_alloc<A>::get_result(bool ordered) const -> CompactSketch {
+ return compact_array_of_doubles_sketch_alloc<A>(this->state_.get_policy().get_policy().get_num_values(), Base::get_result(ordered));
+}
+
+// builder
+
+template<typename A>
+array_of_doubles_union_alloc<A>::builder::builder(const array_of_doubles_union_policy<A>& policy, const A& allocator):
+tuple_base_builder<builder, array_of_doubles_union_policy<A>, A>(policy, allocator) {}
+
+template<typename A>
+array_of_doubles_union_alloc<A> array_of_doubles_union_alloc<A>::builder::build() const {
+ return array_of_doubles_union_alloc<A>(this->starting_lg_size(), this->lg_k_, this->rf_, this->starting_theta(), this->seed_, this->policy_, this->allocator_);
+}
+
+} /* namespace datasketches */
diff --git a/tuple/include/theta_union_base.hpp b/tuple/include/theta_union_base.hpp
index 72d6eca..3072630 100644
--- a/tuple/include/theta_union_base.hpp
+++ b/tuple/include/theta_union_base.hpp
@@ -45,6 +45,8 @@
CompactSketch get_result(bool ordered = true) const;
+ const Policy& get_policy() const;
+
private:
Policy policy_;
hash_table table_;
diff --git a/tuple/include/theta_union_base_impl.hpp b/tuple/include/theta_union_base_impl.hpp
index 2bc458d..a86ba3e 100644
--- a/tuple/include/theta_union_base_impl.hpp
+++ b/tuple/include/theta_union_base_impl.hpp
@@ -76,4 +76,9 @@
return CS(table_.is_empty_, ordered, compute_seed_hash(table_.seed_), theta, std::move(entries));
}
+template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
+const P& theta_union_base<EN, EK, P, S, CS, A>::get_policy() const {
+ return policy_;
+}
+
} /* namespace datasketches */
diff --git a/tuple/include/tuple_union.hpp b/tuple/include/tuple_union.hpp
index 983c724..d9eff26 100644
--- a/tuple/include/tuple_union.hpp
+++ b/tuple/include/tuple_union.hpp
@@ -57,6 +57,7 @@
void operator()(Entry& internal_entry, Entry&& incoming_entry) const {
policy_(internal_entry.second, std::move(incoming_entry.second));
}
+ const Policy& get_policy() const { return policy_; }
Policy policy_;
};
@@ -79,7 +80,7 @@
*/
CompactSketch get_result(bool ordered = true) const;
-private:
+protected:
State state_;
// for builder
diff --git a/tuple/test/array_of_doubles_sketch_test.cpp b/tuple/test/array_of_doubles_sketch_test.cpp
index 63dcea3..7c2857f 100644
--- a/tuple/test/array_of_doubles_sketch_test.cpp
+++ b/tuple/test/array_of_doubles_sketch_test.cpp
@@ -256,7 +256,7 @@
auto update_sketch2 = update_array_of_doubles_sketch::builder().build();
for (int i = 500; i < 1500; ++i) update_sketch2.update(i, a);
- auto u = tuple_union<std::vector<double>, array_of_doubles_union_policy>::builder().build();
+ auto u = array_of_doubles_union::builder().build();
u.update(update_sketch1);
u.update(update_sketch2);
auto result = u.get_result();