blob: a34e516f730274eb750107a881b303ab4cef3ddd [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package tuple
import (
"errors"
"slices"
"github.com/apache/datasketches-go/internal"
"github.com/apache/datasketches-go/theta"
)
type intersectionOptions struct {
seed uint64
}
type IntersectionOptionFunc func(*intersectionOptions)
// WithIntersectionSeed sets the seed for the hash function.
func WithIntersectionSeed(seed uint64) IntersectionOptionFunc {
return func(i *intersectionOptions) {
i.seed = seed
}
}
// Intersection computes the intersection of sketches.
type Intersection[S Summary] struct {
hashtable *hashtable[S]
policy Policy[S]
entryLessFunc func(a, b entry[S]) int
isValid bool
}
// NewIntersection creates a new intersection.
func NewIntersection[S Summary](policy Policy[S], opts ...IntersectionOptionFunc) *Intersection[S] {
options := &intersectionOptions{
seed: theta.DefaultSeed,
}
for _, opt := range opts {
opt(options)
}
return &Intersection[S]{
hashtable: newHashtable[S](
0, 0, theta.ResizeX1, 1.0, theta.MaxTheta, options.seed, false,
),
entryLessFunc: func(a, b entry[S]) int {
if a.Hash < b.Hash {
return -1
} else if a.Hash > b.Hash {
return 1
}
return 0
},
policy: policy,
isValid: false,
}
}
// Update updates the intersection with a given sketch.
func (i *Intersection[S]) Update(sketch Sketch[S]) error {
if i.hashtable.isEmpty {
return nil
}
seedHash, err := internal.ComputeSeedHash(int64(i.hashtable.seed))
if err != nil {
return err
}
sketchSeedHash, err := sketch.SeedHash()
if err != nil {
return err
}
if !sketch.IsEmpty() && sketchSeedHash != uint16(seedHash) {
return errors.New("seed hash mismatch")
}
i.hashtable.isEmpty = i.hashtable.isEmpty || sketch.IsEmpty()
if i.hashtable.isEmpty {
i.hashtable.theta = theta.MaxTheta
} else {
i.hashtable.theta = min(i.hashtable.theta, sketch.Theta64())
}
if i.isValid && i.hashtable.numEntries == 0 {
return nil
}
if sketch.NumRetained() == 0 {
i.isValid = true
i.hashtable = newHashtable[S](
0, 0, theta.ResizeX1, 1.0, i.hashtable.theta, i.hashtable.seed, i.hashtable.isEmpty,
)
return nil
}
if !i.isValid { // first update, copy or move incoming sketch
i.isValid = true
lgSize := internal.LgSizeFromCount(sketch.NumRetained(), rebuildThreshold)
i.hashtable = newHashtable[S](lgSize, lgSize-1, theta.ResizeX1, 1.0, i.hashtable.theta, i.hashtable.seed, i.hashtable.isEmpty)
for hash, summary := range sketch.All() {
idx, err := i.hashtable.Find(hash)
if err == nil {
return errors.New("duplicate key, possibly corrupted input sketch")
}
i.hashtable.Insert(idx, entry[S]{
Hash: hash,
Summary: summary,
})
}
if i.hashtable.numEntries != sketch.NumRetained() {
return errors.New("num entries mismatch, possibly corrupted input sketch")
}
return nil
}
// intersection
var (
maxMatches = min(i.hashtable.numEntries, sketch.NumRetained())
matchesEntries = make([]entry[S], 0, maxMatches)
matchCount = 0
count = 0
)
for hash, summary := range sketch.All() {
if hash < i.hashtable.theta {
key, err := i.hashtable.Find(hash)
if err == nil {
if uint32(matchCount) == maxMatches {
return errors.New("max matches exceeded, possibly corrupted input sketch")
}
i.policy.Apply(i.hashtable.entries[key].Summary, summary)
matchesEntries = append(matchesEntries, i.hashtable.entries[key])
matchCount++
}
} else if sketch.IsOrdered() {
// early stop
break
}
count++
}
if count > int(sketch.NumRetained()) {
return errors.New("more keys than expected, possibly corrupted input sketch")
}
if !sketch.IsOrdered() && count < int(sketch.NumRetained()) {
return errors.New("fewer keys than expected, possibly corrupted input sketch")
}
if matchCount == 0 {
i.hashtable = newHashtable[S](
0, 0, theta.ResizeX1, 1.0, i.hashtable.theta, i.hashtable.seed, i.hashtable.isEmpty,
)
if i.hashtable.theta == theta.MaxTheta {
i.hashtable.isEmpty = true
}
} else {
lgSize := internal.LgSizeFromCount(uint32(matchCount), rebuildThreshold)
i.hashtable = newHashtable[S](
lgSize, lgSize-1, theta.ResizeX1, 1.0, i.hashtable.theta, i.hashtable.seed, i.hashtable.isEmpty,
)
for j := 0; j < matchCount; j++ {
key, err := i.hashtable.Find(matchesEntries[j].Hash)
if err != nil && err == ErrKeyNotFoundAndNoEmptySlots {
return err
}
i.hashtable.Insert(key, matchesEntries[j])
}
}
return nil
}
// Result produces a copy of the current state of the intersection.
func (i *Intersection[S]) Result(ordered bool) (*CompactSketch[S], error) {
if !i.isValid {
return nil, errors.New("calling Result() before calling Update() is undefined")
}
entries := make([]entry[S], 0, i.hashtable.numEntries)
if i.hashtable.numEntries > 0 {
for _, e := range i.hashtable.entries {
if e.Hash != 0 {
entries = append(entries, entry[S]{
Hash: e.Hash,
Summary: e.Summary.Clone().(S),
})
}
}
if ordered {
slices.SortFunc(entries, i.entryLessFunc)
}
}
seedHash, err := internal.ComputeSeedHash(int64(i.hashtable.seed))
if err != nil {
return nil, err
}
return newCompactSketch[S](
i.hashtable.isEmpty,
ordered,
uint16(seedHash),
i.hashtable.theta,
entries,
), nil
}
// OrderedResult produces a copy of the current state of the intersection.
func (i *Intersection[S]) OrderedResult() (*CompactSketch[S], error) {
return i.Result(true)
}
// HasResult returns true if the state of the intersection is defined.
func (i *Intersection[S]) HasResult() bool {
return i.isValid
}
// Policy returns the policy for processing matched summary during intersection.
func (i *Intersection[S]) Policy() Policy[S] {
return i.policy
}