# blob: 4d8a276325c6bef1857704e5f28bd03d28d4e11a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from datasketches import density_sketch, KernelFunction
import numpy as np
class UnitSphereKernel(KernelFunction):
    """Indicator kernel on the open unit ball.

    Evaluates to 1.0 when the Euclidean distance between the two points is
    strictly less than 1.0, and to 0.0 otherwise.
    """

    def __call__(self, a: np.ndarray, b: np.ndarray) -> float:
        """Return 1.0 if ``a`` and ``b`` are less than 1.0 apart, else 0.0.

        Args:
            a: First point; must support elementwise subtraction with ``b``.
            b: Second point, same shape as ``a``.

        Returns:
            1.0 when ``np.linalg.norm(a - b) < 1.0``, otherwise 0.0.
        """
        # Strict inequality: points exactly 1.0 apart are outside the ball.
        return 1.0 if np.linalg.norm(a - b) < 1.0 else 0.0
class densityTest(unittest.TestCase):
    """Tests for ``density_sketch``: updates, iteration, serialization,
    merging and user-supplied kernels."""

    def test_density_sketch(self):
        """Check sketch state before and after updates, iterate the retained
        items, and compare estimates across a serialize/deserialize round
        trip."""
        k = 10
        dim = 3
        n = 1000

        # A freshly constructed sketch reports its parameters and is empty.
        sketch = density_sketch(k, dim)
        self.assertEqual(sketch.get_k(), k)
        self.assertEqual(sketch.get_dim(), dim)
        self.assertTrue(sketch.is_empty())
        self.assertFalse(sketch.is_estimation_mode())
        self.assertEqual(sketch.get_n(), 0)
        self.assertEqual(sketch.get_num_retained(), 0)

        # Feed n points lying on the main diagonal.
        for i in range(n):
            sketch.update([i, i, i])

        self.assertFalse(sketch.is_empty())
        self.assertTrue(sketch.is_estimation_mode())
        self.assertEqual(sketch.get_n(), n)
        self.assertGreater(sketch.get_num_retained(), k)
        self.assertLess(sketch.get_num_retained(), n)
        self.assertGreater(sketch.get_estimate([n - 1, n - 1, n - 1]), 0)

        # Each item produced by iteration is indexed as (vector, weight).
        # The loop variable deliberately avoids shadowing the builtin `tuple`.
        for item in sketch:
            vector = item[0]
            weight = item[1]
            self.assertEqual(len(vector), dim)
            self.assertGreaterEqual(weight, 1)

        # A deserialized copy must give the same estimate as the original.
        sk_bytes = sketch.serialize()
        sketch2 = density_sketch.deserialize(sk_bytes)
        query = [1.5, 2.5, 3.5]
        self.assertEqual(sketch.get_estimate(query), sketch2.get_estimate(query))

    def test_density_merge(self):
        """Merging two single-point sketches yields n == 2 with both points
        retained."""
        sketch1 = density_sketch(10, 2)
        sketch1.update([0, 0])

        sketch2 = density_sketch(10, 2)
        sketch2.update([0, 1])

        sketch1.merge(sketch2)
        self.assertEqual(sketch1.get_n(), 2)
        self.assertEqual(sketch1.get_num_retained(), 2)

    def test_custom_kernel(self):
        """Compare the default kernel with ``UnitSphereKernel`` and check that
        a custom kernel can be supplied when deserializing."""
        gaussian_sketch = density_sketch(10, 2)  # default kernel
        spherical_sketch = density_sketch(10, 2, UnitSphereKernel())

        p = [1, 1]
        gaussian_sketch.update(p)
        spherical_sketch.update(p)

        near = [1.001, 1]
        far = [2, 2]

        # Spherical kernel should return 1.0 for a nearby point, 0 farther
        # Gaussian kernel should return something nonzero when farther away
        self.assertEqual(spherical_sketch.get_estimate(near), 1.0)
        self.assertEqual(spherical_sketch.get_estimate(far), 0.0)
        self.assertGreater(gaussian_sketch.get_estimate(far), 0.0)

        # We can also use a custom kernel when deserializing
        sk_bytes = spherical_sketch.serialize()
        spherical_rebuilt = density_sketch.deserialize(sk_bytes, UnitSphereKernel())
        self.assertEqual(
            spherical_sketch.get_estimate(near),
            spherical_rebuilt.get_estimate(near),
        )
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()