blob: 3d3623e6f6861b6562e1296d1725ccf6ad649fec [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
from random import randint, seed
from .streaming_heap import StreamingHeap
class CountSketch:
    """
    Count sketch for frequency estimation and (relaxed) heavy-hitter detection
    over a stream of integer items, optionally weighted.

    Denote F2 = ||f||_2 where f is the frequency vector underlying the data stream,
    eg. if stream = (1, 1, 2, 1, 5) then f = (0, 3, 1, 0, 0, 1).
    If the stream is weighted, i.e. we receive (item, weight) pairs, then the same idea applies:
    stream = [(1, 4), (2, 1), (1, -1), (5, 1)] also has f = (0, 3, 1, 0, 0, 1) -- nb. using
    0-based indexing for consistency with python.

    The a-heavy hitters problem is to identify all items i for which f[i] > a*F2.
    Solving this problem requires linear space in the worst case so a relaxed version is solved:
    to return b-heavy hitters for b = a - t and t > 0.
    Hence, we permit some _false positives_: items that are flagged as being b-heavy, but may not
    be a-heavy.

    A CountSketch can achieve this aim in small space by maintaining a sketch that is used to
    estimate f[i]. Specifically, the CountSketch returns an estimate g[i] for which
        f[i] - epsilon*F2 <= g[i] <= f[i] + epsilon*F2.
    The value epsilon is between 0 and 1 and is the worst-case error for estimating the frequency
    f[i] using the value g[i]. epsilon can be explicitly calculated from the sketch parameters
    self.w and self.d as seen in ``get_epsilon``.

    We return a set of heavy indices by using the sketch and finding the (relaxed) b-heavy hitters.
    The CountSketch guarantee ensures f[i] - epsilon*F2 <= g[i] so we opt to find all i such that
        g[i] >= (b - epsilon)*F2   or equivalently   g[i] >= ((a - t) - epsilon)*F2.
    We *set* phi ( = b - t ) and do not use the (a - t) setup, and epsilon is explicitly known.

    Note that there are two sources of approximation:
    - the ``t`` is the heavy hitter approximate relaxation parameter
    - the ``epsilon`` is the frequency estimation parameter.

    Attributes:
    - self.p :: int - a large prime used to generate the random hashes.
    - self.w :: int - the number of hash buckets for the table.
    - self.d :: int - the number of levels that are repeated.
    - self.table :: np.ndarray of type float that is updated on viewing data.
    """

    # TODO: add string update functionality.
    def __init__(self, num_buckets: int, num_levels: int, phi, rng_seed=100):
        """
        :param num_buckets: number of columns in the sketch table
        :param num_levels: number of rows in the sketch table
        :param phi: the threshold for heavy hitters; must satisfy 0 < phi <= 1
        :param rng_seed: seed for the randomisation of the hash functions
        :raises ValueError: if num_buckets or num_levels is below 1, or phi is not positive
        :raises AssertionError: if phi is larger than 1
        """
        # All validation happens first so that a bad argument never touches the RNG
        # or allocates anything.
        if num_buckets < 1 or num_levels < 1:
            raise ValueError(
                f"num_buckets={num_buckets} and num_levels={num_levels} must both be at least 1."
            )
        if not phi <= 1.0:
            # Raised explicitly (rather than via ``assert``) so the check survives ``python -O``
            # while keeping the exception type of the previous assertion.
            raise AssertionError(f"Phi={phi:.5f} but cannot be larger than 1.")
        if phi <= 0.0:
            raise ValueError(f"Phi={phi:.5f} but must be strictly positive.")

        self.p = 2 ** 31 - 1
        self.w = num_buckets
        self.d = num_levels
        self.table = np.zeros((self.d, self.w), dtype=float)
        self.phi = phi

        # Seeding immediately before drawing the hash parameters means two sketches built with
        # the same seed and shape share hash functions, which is what makes them mergeable.
        # NOTE: this reseeds the module-level ``random`` generator.
        self.rng_seed = rng_seed
        seed(rng_seed)
        self._init_hashes()

        # Sum of all weights seen; aka the L1 norm of the underlying frequency vector.
        self.total_weight = 0.0
        self.max_num_items = np.ceil(1. / self.phi).astype(np.int64)
        self.heavy_hitter_detector = StreamingHeap(self.max_num_items)
        # A list of all sketches that have been merged with self. ``merge`` folds the other
        # table directly into ``self.table`` and therefore does not append to this list.
        self.merged_sketches = []

    def _init_hashes(self) -> None:
        """
        Initialises the hash functions for bucket and sign selection by generating (a,b) pairs
        for the hash family. A new (a,b) pair is necessary for each of the hash functions.
        """
        self.bucket_hash_params = {level: self._get_ab_hash() for level in range(self.d)}
        self.sign_hash_params = {level: self._get_ab_hash() for level in range(self.d)}

    def _get_ab_hash(self):
        """
        We need 0 <= a <= self.p - 1 and 0 <= b <= self.p - 1 as discussed on page 18.

        :return: the (a,b) pair used to define the hash family
        """
        # TODO - check the nonzero property on a is correct
        return randint(0, self.p - 1), randint(0, self.p - 1)  # random (a,b) pairs from ([p], [p])

    @staticmethod
    def _validate_item(item) -> int:
        """
        Checks that ``item`` is an ``int`` or a NumPy integer and returns it as a python ``int``.

        :raises TypeError: if item is any other type
        """
        if not isinstance(item, (int, np.integer)):
            raise TypeError("Input item must be an int.")
        return int(item)

    def _bucket_hash(self, x: int, a: int, b: int, buckets: int) -> int:
        """
        Generic function for generating 2-wise independent hash functions.
        We use this function for the bucket locations with buckets <- self.w.
        It is also used to generate the random signs with buckets <- 2 -- see ``_sign_hash``.

        :param x: stream item
        :param a: multiplier of the hash function
        :param b: offset of the hash function
        :param buckets: size of the range that the hash maps into
        :return: the hash value (aka bucket index) for item x observed in the stream.
        """
        # Work in arbitrary-precision python ints: ``a`` can be close to 2**31, so a
        # fixed-width NumPy integer ``x`` would overflow in ``a * x``.
        h = (int(a) * int(x) + int(b)) % self.p
        return h % buckets

    def _sign_hash(self, x: int, a: int, b: int) -> float:
        """
        Returns the 2-wise independent sign hash.
        This generates the signs when items are put into buckets.

        :param x: stream item
        :param a: multiplier of the hash function
        :param b: offset of the hash function
        :return: the sign +1.0 or -1.0 used for the hashing.
        """
        # Maps the two-bucket hash {0, 1} onto {-1.0, +1.0}.
        return 2. * self._bucket_hash(x, a, b, 2) - 1.

    def _insert(self, item: int, weight=1.0) -> None:
        """
        Inserts the item into the sketch table.

        :param item: the stream item (``int`` or NumPy integer)
        :param weight: the weight added for this item
        :raises TypeError: if item is not an integer
        """
        item = self._validate_item(item)
        for level in range(self.d):
            a_bucket, b_bucket = self.bucket_hash_params[level]  # (a,b) pair for *bucket* hashes
            a_sign, b_sign = self.sign_hash_params[level]  # (a,b) pair for *sign* hashes
            bucket = self._bucket_hash(item, a_bucket, b_bucket, self.w)
            sign = self._sign_hash(item, a_sign, b_sign)
            self.table[level, bucket] += sign * weight

    def update(self, item: int, weight=1.0) -> None:
        """
        Updates the sketch by:
        1. Inserting (item, weight) into self.table
        2. Estimating the current frequency and pushing it into the heavy_hitter_detector
        (nb. item only pushed into the detector if it has large enough frequency.
        This logic is dealt with in the detector class itself.)

        :param item: the stream item (``int`` or NumPy integer)
        :param weight: the weight added for this item
        :raises TypeError: if item is not an integer; the sketch is left unchanged in that case
        """
        # Insert first: it validates the item, so a rejected item cannot change total_weight.
        self._insert(item, weight)
        self.total_weight += weight
        self.heavy_hitter_detector.push(item, self.get_estimate(item))

    def get_epsilon(self):
        """
        :return: an estimate of the worst-case epsilon error achieved by the sketch for
        frequency estimation.
        """
        return np.sqrt(np.e / self.w)

    def get_estimate(self, item: int):
        """
        :param item: the identifier of the item whose frequency is being queried.
        :return: the count sketch frequency estimate, i.e. the median over the levels.
        :raises TypeError: if item is not an integer
        """
        estimates = self._get_frequency_estimate(item)
        for sketch in self.merged_sketches:
            estimates += sketch._get_frequency_estimate(item)
        # TODO - replace this with np.max(0.0, np.median(_estimate))?
        return np.median(estimates)

    def _get_frequency_estimate(self, item: int):
        """
        :param item: the identifier of the item whose frequency is being queried.
        :return: an array of the frequency estimates where each entry corresponds to a level
        of the sketch.
        :raises TypeError: if item is not an integer
        """
        item = self._validate_item(item)
        estimates = np.empty(self.d, dtype=float)
        for level in range(self.d):
            a_bucket, b_bucket = self.bucket_hash_params[level]
            a_sign, b_sign = self.sign_hash_params[level]
            bucket = self._bucket_hash(item, a_bucket, b_bucket, self.w)
            # Multiplying by the sign again undoes the sign applied at insertion time.
            estimates[level] = self.table[level, bucket] * self._sign_hash(item, a_sign, b_sign)
        return estimates

    def get_frequent_items(self):
        """
        :return: a list of the heavy hitters and their frequencies.
        """
        heap = self.heavy_hitter_detector.heap
        # Refresh every stored frequency with the current sketch estimate before reporting.
        for key in list(heap.keys()):
            heap[key] = self.get_estimate(key)
        return list(heap.items())

    def get_total_weight(self):
        """
        Returns the total weight inserted on the stream.
        Is equivalent to the mass inserted over the stream, or the ell_1 norm of the
        underlying frequency vector.
        """
        return self.total_weight

    def is_empty(self) -> bool:
        """
        Returns True if the sketch self.table is empty (all zeros) and is False otherwise.
        """
        return not np.any(self.table)

    def merge(self, other_count_sketch) -> None:
        """
        Merges another count sketch into self.

        The count sketch is a linear sketch so we can simply add the data structure
        other_count_sketch.table to self.table. This is only meaningful when both sketches
        use the same table shape and the same hash functions.
        Then we need to adjust the heavy hitter candidates.

        :param other_count_sketch: another count sketch object to be merged into self.
        :raises TypeError: if the argument is not a count sketch object
        :raises ValueError: if the two sketches have different shapes or hash functions
        """
        if not isinstance(other_count_sketch, CountSketch):
            raise TypeError("Argument must be a count sketch object")
        # Every check runs before any mutation so a failed merge leaves self untouched.
        if self.table.shape != other_count_sketch.table.shape:
            raise ValueError(
                f"Cannot merge sketches of shape {self.table.shape} "
                f"and {other_count_sketch.table.shape}."
            )
        if (self.p != other_count_sketch.p
                or self.bucket_hash_params != other_count_sketch.bucket_hash_params
                or self.sign_hash_params != other_count_sketch.sign_hash_params):
            raise ValueError("Cannot merge sketches built with different hash functions.")

        self.table += other_count_sketch.table
        self.total_weight += other_count_sketch.total_weight

        # Naive merge of the heavy hitter candidates: re-score each candidate of the other
        # sketch against the merged table and offer it to our own detector.
        # NOTE(review): StreamingHeap is defined elsewhere; this relies only on the ``heap``
        # mapping and ``push`` method already used in this class -- confirm against that class.
        for item in list(other_count_sketch.heavy_hitter_detector.heap.keys()):
            self.heavy_hitter_detector.push(item, self.get_estimate(item))

    def __str__(self) -> str:
        return (
            '### Count sketch summary:\n'
            f' num. buckets : {self.w}\n'
            f' num. levels : {self.d}\n'
            f' worst-case error : {(self.get_epsilon()):.5f}\n'
            f' Heavy hitter threshold : {self.phi:.5f}\n'
            '### End sketch summary')