Merge pull request #43 from apache/wrapped_compact_theta_sketch

Wrapped compact theta sketch
diff --git a/Dockerfile b/Dockerfile
index 0e6e1aa..bd58322 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -18,7 +18,7 @@
 ARG BASE_IMAGE_VERSION=latest
 
 ARG DATASKETCHES_CPP_HASH=8135b65408947694e13bd131038889e439847aa2
-ARG DATASKETCHES_CPP_VERSION=3.0.0
+ARG DATASKETCHES_CPP_VERSION=3.1.0
 
 FROM postgres:$BASE_IMAGE_VERSION
 
diff --git a/src/theta_sketch_c_adapter.cpp b/src/theta_sketch_c_adapter.cpp
index 8e8d825..dd635e9 100644
--- a/src/theta_sketch_c_adapter.cpp
+++ b/src/theta_sketch_c_adapter.cpp
@@ -32,6 +32,7 @@
 typedef datasketches::theta_union_alloc<palloc_allocator<uint64_t>> theta_union_pg;
 typedef datasketches::theta_intersection_alloc<palloc_allocator<uint64_t>> theta_intersection_pg;
 typedef datasketches::theta_a_not_b_alloc<palloc_allocator<uint64_t>> theta_a_not_b_pg;
+typedef datasketches::wrapped_compact_theta_sketch_alloc<palloc_allocator<uint64_t>> wrapped_compact_theta_sketch_pg;
 
 void* theta_sketch_new_default() {
   try {
@@ -175,9 +176,9 @@
   }
 }
 
-void theta_union_update(void* unionptr, const void* sketchptr) {
+void theta_union_update(void* unionptr, const void* buffer, unsigned length) {
   try {
-    static_cast<theta_union_pg*>(unionptr)->update(std::move(*static_cast<const theta_sketch_pg*>(sketchptr)));
+    static_cast<theta_union_pg*>(unionptr)->update(wrapped_compact_theta_sketch_pg::wrap(buffer, length));
   } catch (std::exception& e) {
     pg_error(e.what());
   }
@@ -210,9 +211,9 @@
   }
 }
 
-void theta_intersection_update(void* interptr, const void* sketchptr) {
+void theta_intersection_update(void* interptr, const void* buffer, unsigned length) {
   try {
-    static_cast<theta_intersection_pg*>(interptr)->update(*static_cast<const theta_sketch_pg*>(sketchptr));
+    static_cast<theta_intersection_pg*>(interptr)->update(wrapped_compact_theta_sketch_pg::wrap(buffer, length));
   } catch (std::exception& e) {
     pg_error(e.what());
   }
@@ -227,12 +228,12 @@
   pg_unreachable();
 }
 
-void* theta_a_not_b(const void* sketchptr1, const void* sketchptr2) {
+void* theta_a_not_b(const void* buffer1, unsigned length1, const void* buffer2, unsigned length2) {
   try {
     theta_a_not_b_pg a_not_b;
     return new (palloc(sizeof(compact_theta_sketch_pg))) compact_theta_sketch_pg(a_not_b.compute(
-      *static_cast<const theta_sketch_pg*>(sketchptr1),
-      *static_cast<const theta_sketch_pg*>(sketchptr2)
+      wrapped_compact_theta_sketch_pg::wrap(buffer1, length1),
+      wrapped_compact_theta_sketch_pg::wrap(buffer2, length2)
     ));
   } catch (std::exception& e) {
     pg_error(e.what());
diff --git a/src/theta_sketch_c_adapter.h b/src/theta_sketch_c_adapter.h
index 2a14f98..5d6eef1 100644
--- a/src/theta_sketch_c_adapter.h
+++ b/src/theta_sketch_c_adapter.h
@@ -44,15 +44,15 @@
 void* theta_union_new_default();
 void* theta_union_new(unsigned lg_k);
 void theta_union_delete(void* unionptr);
-void theta_union_update(void* unionptr, const void* sketchptr);
+void theta_union_update(void* unionptr, const void* buffer, unsigned length);
 void* theta_union_get_result(const void* unionptr);
 
 void* theta_intersection_new_default();
 void theta_intersection_delete(void* interptr);
-void theta_intersection_update(void* interptr, const void* sketchptr);
+void theta_intersection_update(void* interptr, const void* buffer, unsigned length);
 void* theta_intersection_get_result(const void* interptr);
 
-void* theta_a_not_b(const void* sketchptr1, const void* sketchptr2);
+void* theta_a_not_b(const void* buffer1, unsigned length1, const void* buffer2, unsigned length2);
 
 #ifdef __cplusplus
 }
diff --git a/src/theta_sketch_pg_functions.c b/src/theta_sketch_pg_functions.c
index 00b8195..27ca259 100644
--- a/src/theta_sketch_pg_functions.c
+++ b/src/theta_sketch_pg_functions.c
@@ -165,7 +165,6 @@
 Datum pg_theta_sketch_intersection_agg(PG_FUNCTION_ARGS) {
   void* interptr;
   bytea* sketch_bytes;
-  void* sketchptr;
 
   MemoryContext oldcontext;
   MemoryContext aggcontext;
@@ -188,9 +187,7 @@
   }
 
   sketch_bytes = PG_GETARG_BYTEA_P(1);
-  sketchptr = theta_sketch_deserialize(VARDATA(sketch_bytes), VARSIZE(sketch_bytes) - VARHDRSZ);
-  theta_intersection_update(interptr, sketchptr);
-  theta_sketch_delete(sketchptr);
+  theta_intersection_update(interptr, VARDATA(sketch_bytes), VARSIZE(sketch_bytes) - VARHDRSZ);
 
   MemoryContextSwitchTo(oldcontext);
 
@@ -200,7 +197,6 @@
 Datum pg_theta_sketch_union_agg(PG_FUNCTION_ARGS) {
   void* unionptr;
   bytea* sketch_bytes;
-  void* sketchptr;
   int lg_k;
 
   MemoryContext oldcontext;
@@ -225,9 +221,7 @@
   }
 
   sketch_bytes = PG_GETARG_BYTEA_P(1);
-  sketchptr = theta_sketch_deserialize(VARDATA(sketch_bytes), VARSIZE(sketch_bytes) - VARHDRSZ);
-  theta_union_update(unionptr, sketchptr);
-  theta_sketch_delete(sketchptr);
+  theta_union_update(unionptr, VARDATA(sketch_bytes), VARSIZE(sketch_bytes) - VARHDRSZ);
 
   MemoryContextSwitchTo(oldcontext);
 
@@ -339,8 +333,6 @@
 Datum pg_theta_sketch_union(PG_FUNCTION_ARGS) {
   const bytea* bytes_in1;
   const bytea* bytes_in2;
-  void* sketchptr1;
-  void* sketchptr2;
   void* unionptr;
   void* sketchptr;
   struct ptr_with_size bytes_out;
@@ -350,15 +342,11 @@
   unionptr = lg_k ? theta_union_new(lg_k) : theta_union_new_default();
   if (!PG_ARGISNULL(0)) {
     bytes_in1 = PG_GETARG_BYTEA_P(0);
-    sketchptr1 = theta_sketch_deserialize(VARDATA(bytes_in1), VARSIZE(bytes_in1) - VARHDRSZ);
-    theta_union_update(unionptr, sketchptr1);
-    theta_sketch_delete(sketchptr1);
+    theta_union_update(unionptr, VARDATA(bytes_in1), VARSIZE(bytes_in1) - VARHDRSZ);
   }
   if (!PG_ARGISNULL(1)) {
     bytes_in2 = PG_GETARG_BYTEA_P(1);
-    sketchptr2 = theta_sketch_deserialize(VARDATA(bytes_in2), VARSIZE(bytes_in2) - VARHDRSZ);
-    theta_union_update(unionptr, sketchptr2);
-    theta_sketch_delete(sketchptr2);
+    theta_union_update(unionptr, VARDATA(bytes_in2), VARSIZE(bytes_in2) - VARHDRSZ);
   }
   sketchptr = theta_union_get_result(unionptr);
   theta_union_delete(unionptr);
@@ -371,8 +359,6 @@
 Datum pg_theta_sketch_intersection(PG_FUNCTION_ARGS) {
   const bytea* bytes_in1;
   const bytea* bytes_in2;
-  void* sketchptr1;
-  void* sketchptr2;
   void* interptr;
   void* sketchptr;
   struct ptr_with_size bytes_out;
@@ -380,15 +366,11 @@
   interptr = theta_intersection_new_default();
   if (!PG_ARGISNULL(0)) {
     bytes_in1 = PG_GETARG_BYTEA_P(0);
-    sketchptr1 = theta_sketch_deserialize(VARDATA(bytes_in1), VARSIZE(bytes_in1) - VARHDRSZ);
-    theta_intersection_update(interptr, sketchptr1);
-    theta_sketch_delete(sketchptr1);
+    theta_intersection_update(interptr, VARDATA(bytes_in1), VARSIZE(bytes_in1) - VARHDRSZ);
   }
   if (!PG_ARGISNULL(1)) {
     bytes_in2 = PG_GETARG_BYTEA_P(1);
-    sketchptr2 = theta_sketch_deserialize(VARDATA(bytes_in2), VARSIZE(bytes_in2) - VARHDRSZ);
-    theta_intersection_update(interptr, sketchptr2);
-    theta_sketch_delete(sketchptr2);
+    theta_intersection_update(interptr, VARDATA(bytes_in2), VARSIZE(bytes_in2) - VARHDRSZ);
   }
   sketchptr = theta_intersection_get_result(interptr);
   theta_intersection_delete(interptr);
@@ -401,8 +383,6 @@
 Datum pg_theta_sketch_a_not_b(PG_FUNCTION_ARGS) {
   const bytea* bytes_in1;
   const bytea* bytes_in2;
-  void* sketchptr1;
-  void* sketchptr2;
   void* sketchptr;
   struct ptr_with_size bytes_out;
 
@@ -411,12 +391,8 @@
   }
 
   bytes_in1 = PG_GETARG_BYTEA_P(0);
-  sketchptr1 = theta_sketch_deserialize(VARDATA(bytes_in1), VARSIZE(bytes_in1) - VARHDRSZ);
   bytes_in2 = PG_GETARG_BYTEA_P(1);
-  sketchptr2 = theta_sketch_deserialize(VARDATA(bytes_in2), VARSIZE(bytes_in2) - VARHDRSZ);
-  sketchptr = theta_a_not_b(sketchptr1, sketchptr2);
-  theta_sketch_delete(sketchptr1);
-  theta_sketch_delete(sketchptr2);
+  sketchptr = theta_a_not_b(VARDATA(bytes_in1), VARSIZE(bytes_in1) - VARHDRSZ, VARDATA(bytes_in2), VARSIZE(bytes_in2) - VARHDRSZ);
   bytes_out = theta_sketch_serialize(sketchptr, VARHDRSZ);
   theta_sketch_delete(sketchptr);
   SET_VARSIZE(bytes_out.ptr, bytes_out.size);