# blob: 76a3d36787ec308da2d572abd9d2154c74eb715b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from math import floor, ceil
from datasketches import ebpps_sketch, PyIntsSerDe, PyStringsSerDe
class EbppsTest(unittest.TestCase):
    """Example-style test that walks through the ebpps_sketch Python API."""

    def test_ebpps_example(self):
        """Exercise update, iteration, merge, printing and (de)serialization."""
        k = 50  # a small value so we can easily fill the sketch
        sk = ebpps_sketch(k)

        # ebpps sampling reduces to standard reservoir sampling
        # if the items are all equally weighted, although the
        # algorithm will be significantly slower than an optimized
        # reservoir sampler
        n = 5 * k
        for i in range(n):
            sk.update(i)

        # we should have k samples since they're equally weighted
        self.assertAlmostEqual(k, sk.c, places=12)

        # we can also add a heavy item, using a negative value to
        # be able to identify the item later. Keep in mind that
        # "heavy" is a relative concept, so using a fixed
        # multiple of n may not be considered a heavy item for
        # larger values of n
        sk.update(-1, 1000 * n)
        self.assertEqual(k, sk.k)
        self.assertLess(sk.c, k)
        self.assertEqual(n + 1, sk.n)
        self.assertFalse(sk.is_empty())

        # we can easily get the list of items in the sample
        items = list(sk)
        self.assertLessEqual(len(items), k)

        # iterating the sketch directly must yield either floor(c) or
        # ceil(c) items; the loop variable is irrelevant, so just count
        count = sum(1 for _ in sk)
        self.assertIn(count, (floor(sk.c), ceil(sk.c)))

        # next we'll create a second, smaller sketch with
        # only heavier items relative to the previous sketch
        k2 = 5
        sk2 = ebpps_sketch(k2)
        # for weight, use the sum of all items >=0 from before (n)
        wt = n
        for i in range(k2 + 1):
            sk2.update((2 * n) + i, wt)

        # now merge the sketches, noting how the
        # resulting k will be the smaller of the two.
        input_sk_c = sk.c
        sk.merge(sk2)
        self.assertEqual(n + k2 + 2, sk.n)
        self.assertFalse(sk.is_empty())
        self.assertEqual(sk.k, k2)
        self.assertLess(sk.k, k)

        # the expected number of samples post-merge may be larger than
        # with the input sketch
        self.assertGreater(sk.c, input_sk_c)

        # we can dump a summary of information in the sketch
        print(sk)

        # if we want to print the list of items, there must be a
        # __str__() method for each item (which need not be the same
        # type; they're all generic python objects when used from
        # python), otherwise you may trigger an exception.
        # to_string() is provided as a convenience to avoid direct
        # calls to __str__() with parameters.
        print(sk.to_string(True))

        # finally, we can serialize the sketch by providing an
        # appropriate serde class.
        expected_size = sk.get_serialized_size_bytes(PyIntsSerDe())
        b = sk.serialize(PyIntsSerDe())
        self.assertEqual(expected_size, len(b))

        # if we try to deserialize with the wrong serde, things break.
        # assertRaises is used instead of try/self.fail()/bare except:
        # self.fail() raises AssertionError, which a bare except would
        # swallow, so the check could never actually fail.
        with self.assertRaises(Exception):
            ebpps_sketch.deserialize(b, PyStringsSerDe())

        # using the correct serde gives us back a copy of the original
        rebuilt = ebpps_sketch.deserialize(b, PyIntsSerDe())
        self.assertEqual(sk.k, rebuilt.k)
        self.assertEqual(sk.c, rebuilt.c)
        self.assertEqual(sk.n, rebuilt.n)

        # check that printing works as expected
        self.assertGreater(len(sk.to_string(True)), 0)
        self.assertEqual(len(str(sk)), len(sk.to_string()))
# Run this module's tests via unittest when the file is executed directly
# as a script rather than imported.
if __name__ == '__main__':
    unittest.main()