blob: b3aac54b4e32b11627bc0c0926cff19a61ec8456 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.datasketches.sampling;
import static org.apache.datasketches.common.Util.LS;
import static org.apache.datasketches.sampling.PreambleUtil.EBPPS_SER_VER;
import static org.apache.datasketches.sampling.PreambleUtil.EMPTY_FLAG_MASK;
import static org.apache.datasketches.sampling.PreambleUtil.HAS_PARTIAL_ITEM_MASK;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.datasketches.common.ArrayOfItemsSerDe;
import org.apache.datasketches.common.Family;
import org.apache.datasketches.common.SketchesArgumentException;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory;
/**
 * An implementation of an Exact and Bounded Sampling Proportional to Size sketch.
 *
 * <p>From: "Exact PPS Sampling with Bounded Sample Size",
 * B. Hentschel, P. J. Haas, Y. Tian. Information Processing Letters, 2023.
 *
 * <p>This sketch samples data from a stream of items proportional to the weight of each item.
 * The probability that an item appears in the result is proportional to that item's share of
 * the total weight seen by the sketch, and the returned sample is never larger than k.
 *
 * <p>The sample may be smaller than k, and the resulting sample size may include a
 * probabilistic component, meaning the sample size is not always constant.
 *
 * <p>Instances are mutable and reuse internal scratch storage; this class performs no
 * synchronization of its own.
 *
 * @param <T> the item class type
 * @author Jon Malkin
 */
public final class EbppsItemsSketch<T> {
  /** Largest accepted value of k (inclusive). */
  private static final int MAX_K = Integer.MAX_VALUE - 2;
  /** Byte offset of the serialized C value; part of the sample state, not the preamble. */
  private static final int EBPPS_C_DOUBLE = 40;
  /** Byte offset at which the serialized items begin. */
  private static final int EBPPS_ITEMS_START = 48;

  private int k_;                          // max size of sketch, in items
  private long n_;                         // total number of (nonzero-weight) items processed
  private double cumulativeWt_;            // total weight of items processed by the sketch
  private double wtMax_;                   // maximum weight seen so far
  private double rho_;                     // latest scaling parameter for downsampling
  private EbppsItemsSample<T> sample_;     // object holding the current state of the sample
  private final EbppsItemsSample<T> tmp_;  // scratch sample reused to wrap one incoming item

  /**
   * Constructor
   * @param k The maximum number of samples to retain
   * @throws SketchesArgumentException if k is not in [1, MAX_K]
   */
  public EbppsItemsSketch(final int k) {
    checkK(k);
    k_ = k;
    rho_ = 1.0;
    sample_ = new EbppsItemsSample<>(k);
    tmp_ = new EbppsItemsSample<>(1);
  }

  // private copy constructor; the sample is copied via EbppsItemsSample's copy constructor
  // and a fresh scratch sample is allocated so the copy shares no scratch state
  private EbppsItemsSketch(final EbppsItemsSketch<T> other) {
    k_ = other.k_;
    n_ = other.n_;
    rho_ = other.rho_;
    cumulativeWt_ = other.cumulativeWt_;
    wtMax_ = other.wtMax_;
    sample_ = new EbppsItemsSample<>(other.sample_);
    tmp_ = new EbppsItemsSample<>(1);
  }

  // private constructor for heapify; takes ownership of the provided sample
  private EbppsItemsSketch(final EbppsItemsSample<T> sample,
                           final int k,
                           final long n,
                           final double cumWt,
                           final double maxWt,
                           final double rho) {
    k_ = k;
    n_ = n;
    cumulativeWt_ = cumWt;
    wtMax_ = maxWt;
    rho_ = rho;
    sample_ = sample;
    tmp_ = new EbppsItemsSample<>(1);
  }

  /**
   * Returns a sketch instance of this class from the given srcMem,
   * which must be a Memory representation of this sketch class.
   *
   * @param <T> The type of item this sketch contains
   * @param srcMem a Memory representation of a sketch of this class.
   * <a href="{@docRoot}/resources/dictionary.html#mem">See Memory</a>
   * @param serDe An instance of ArrayOfItemsSerDe
   * @return a sketch instance of this class
   * @throws SketchesArgumentException if the preamble, state values, or size of srcMem
   *     are inconsistent with a valid serialized sketch
   */
  public static <T> EbppsItemsSketch<T> heapify(final Memory srcMem,
                                                final ArrayOfItemsSerDe<T> serDe) {
    final int numPreLongs = PreambleUtil.getAndCheckPreLongs(srcMem);
    final int serVer = PreambleUtil.extractSerVer(srcMem);
    final int familyId = PreambleUtil.extractFamilyID(srcMem);
    final int flags = PreambleUtil.extractFlags(srcMem);
    final boolean isEmpty = (flags & EMPTY_FLAG_MASK) != 0;
    final boolean hasPartialItem = (flags & HAS_PARTIAL_ITEM_MASK) != 0;

    // Check values
    if (isEmpty) {
      if (numPreLongs != Family.EBPPS.getMinPreLongs()) {
        throw new SketchesArgumentException("Possible corruption: Must be " + Family.EBPPS.getMinPreLongs()
            + " for an empty sketch. Found: " + numPreLongs);
      }
    } else {
      if (numPreLongs != Family.EBPPS.getMaxPreLongs()) {
        throw new SketchesArgumentException("Possible corruption: Must be "
            + Family.EBPPS.getMaxPreLongs() + " for a non-empty sketch. Found: " + numPreLongs);
      }
    }
    if (serVer != EBPPS_SER_VER) {
      throw new SketchesArgumentException(
          "Possible Corruption: Ser Ver must be " + EBPPS_SER_VER + ": " + serVer);
    }
    final int reqFamilyId = Family.EBPPS.getID();
    if (familyId != reqFamilyId) {
      throw new SketchesArgumentException(
          "Possible Corruption: FamilyID must be " + reqFamilyId + ": " + familyId);
    }

    final int k = PreambleUtil.extractK(srcMem);
    if (k < 1 || k > MAX_K) {
      throw new SketchesArgumentException("Possible Corruption: k must be at least 1 "
          + "and no greater than " + MAX_K + ". Found: " + k);
    }

    if (isEmpty) {
      return new EbppsItemsSketch<>(k);
    }

    // A non-empty image must hold at least the full preamble plus the serialized C value.
    // Check up front so truncated input fails with a descriptive exception instead of a
    // bounds error from reading C or from building a negative-size region below.
    if (srcMem.getCapacity() < EBPPS_ITEMS_START) {
      throw new SketchesArgumentException("Possible Corruption: Memory capacity must be at least "
          + EBPPS_ITEMS_START + " bytes for a non-empty sketch. Found: " + srcMem.getCapacity());
    }

    final long n = PreambleUtil.extractN(srcMem);
    if (n < 0) {
      throw new SketchesArgumentException("Possible Corruption: n cannot be negative: " + n);
    }

    final double cumWt = PreambleUtil.extractEbppsCumulativeWeight(srcMem);
    if (cumWt < 0.0 || Double.isNaN(cumWt) || Double.isInfinite(cumWt)) {
      throw new SketchesArgumentException("Possible Corruption: cumWt must be nonnegative and finite: " + cumWt);
    }

    final double maxWt = PreambleUtil.extractEbppsMaxWeight(srcMem);
    if (maxWt < 0.0 || Double.isNaN(maxWt) || Double.isInfinite(maxWt)) {
      throw new SketchesArgumentException("Possible Corruption: maxWt must be nonnegative and finite: " + maxWt);
    }

    final double rho = PreambleUtil.extractEbppsRho(srcMem);
    if (rho < 0.0 || rho > 1.0 || Double.isNaN(rho) || Double.isInfinite(rho)) {
      throw new SketchesArgumentException("Possible Corruption: rho must be in [0.0, 1.0]: " + rho);
    }

    // extract C (part of sample_, not the preamble)
    // due to numeric precision issues, c may occasionally be very slightly larger than k
    final double c = srcMem.getDouble(EBPPS_C_DOUBLE);
    if (c < 0 || c >= (k + 1) || Double.isNaN(c) || Double.isInfinite(c)) {
      throw new SketchesArgumentException("Possible Corruption: c must be between 0 and k: " + c);
    }

    // extract items: ceil(C) items are stored, the last of which is the partial item
    // when the partial-item flag is set
    final int numTotalItems = (int) Math.ceil(c);
    final int numFullItems = (int) Math.floor(c); // floor() not strictly necessary
    final int offsetBytes = EBPPS_ITEMS_START;
    final T[] rawItems = serDe.deserializeFromMemory(
        srcMem.region(offsetBytes, srcMem.getCapacity() - offsetBytes), 0, numTotalItems);
    final List<T> itemsList = Arrays.asList(rawItems);
    final ArrayList<T> data;
    final T partialItem;
    if (hasPartialItem) {
      if (numFullItems >= numTotalItems) {
        throw new SketchesArgumentException("Possible Corruption: Expected partial item but none found");
      }

      data = new ArrayList<>(itemsList.subList(0, numFullItems));
      partialItem = itemsList.get(numFullItems); // 0-based, so last item
    } else {
      data = new ArrayList<>(itemsList);
      partialItem = null; // just to be explicit
    }

    final EbppsItemsSample<T> sample = new EbppsItemsSample<>(data, partialItem, c);

    return new EbppsItemsSketch<>(sample, k, n, cumWt, maxWt, rho);
  }

  /**
   * Updates this sketch with the given data item with weight 1.0.
   * @param item an item from a stream of items
   */
  public void update(final T item) {
    update(item, 1.0);
  }

  /**
   * Updates this sketch with the given data item with the given weight.
   * Items with weight 0.0 are ignored entirely and do not count toward n.
   * @param item an item from a stream of items
   * @param weight the weight of the item
   * @throws SketchesArgumentException if weight is negative, NaN, or infinite
   */
  public void update(final T item, final double weight) {
    if (weight < 0.0 || Double.isNaN(weight) || Double.isInfinite(weight)) {
      throw new SketchesArgumentException("Item weights must be nonnegative and finite. "
          + "Found: " + weight);
    }
    if (weight == 0.0) {
      return;
    }

    final double newCumWt = cumulativeWt_ + weight;
    final double newWtMax = Math.max(wtMax_, weight);
    final double newRho = Math.min(1.0 / newWtMax, k_ / newCumWt);

    // rescale existing entries to the new rho before adding the new item
    if (cumulativeWt_ > 0.0) {
      sample_.downsample((newRho / rho_));
    }

    tmp_.replaceContent(item, newRho * weight);
    sample_.merge(tmp_);

    cumulativeWt_ = newCumWt;
    wtMax_ = newWtMax;
    rho_ = newRho;
    ++n_;
  }

  /* Merging
   * There is a trivial merge algorithm that involves downsampling each sketch A and B
   * as A.cum_wt / (A.cum_wt + B.cum_wt) and B.cum_wt / (A.cum_wt + B.cum_wt),
   * respectively. That merge does preserve first-order probabilities, specifically
   * the probability proportional to size property, and like all other known merge
   * algorithms distorts second-order probabilities (co-occurrences). There are
   * pathological cases, most obvious with k=2 and A.cum_wt == B.cum_wt where that
   * approach will always take exactly 1 item from A and 1 from B, meaning the
   * co-occurrence rate for two items from either sketch is guaranteed to be 0.0.
   *
   * With EBPPS, once an item is accepted into the sketch we no longer need to
   * track the item's weight: All accepted items are treated equally. As a result, we
   * can take inspiration from the reservoir sampling merge in the datasketches-java
   * library. We always merge the lighter sketch into the heavier one, swapping the
   * roles as needed, and then effectively call update() with the items of the lighter
   * sketch, adjusting each item's weight appropriately.
   * Merging smaller into larger is essential to ensure that no item has a
   * contribution to C > 1.0.
   */

  /**
   * Merges the provided sketch into the current one. The resulting k is the smaller of the
   * two sketches' k values. Merging a sketch into itself merges an independent copy of it.
   * @param other the sketch to merge into the current object
   */
  public void merge(final EbppsItemsSketch<T> other) {
    // internalMerge() reads other's sample and C after modifying this sketch's sample, so
    // merging a sketch into itself must use an independent copy to avoid aliasing.
    final EbppsItemsSketch<T> src = (other == this) ? new EbppsItemsSketch<>(this) : other;

    if (src.getCumulativeWeight() == 0.0) {
      return;
    } else if (src.getCumulativeWeight() > cumulativeWt_) {
      // need to swap this with src
      // make a copy of src, merge into it, and take the result
      final EbppsItemsSketch<T> copy = new EbppsItemsSketch<>(src);
      copy.internalMerge(this);
      k_ = copy.k_;
      n_ = copy.n_;
      cumulativeWt_ = copy.cumulativeWt_;
      wtMax_ = copy.wtMax_;
      rho_ = copy.rho_;
      sample_ = copy.sample_;
    } else {
      internalMerge(src);
    }
  }

  // merge implementation called exclusively from public merge()
  private void internalMerge(final EbppsItemsSketch<T> other) {
    // assumes that other.cumulativeWt_ <= cumulativeWt_
    // and that other is not this, both of which must be ensured before calling this

    final double finalCumWt = cumulativeWt_ + other.cumulativeWt_;
    final double newWtMax = Math.max(wtMax_, other.wtMax_);
    k_ = Math.min(k_, other.k_);
    final long newN = n_ + other.n_;

    // Insert other's items with the cumulative weight
    // split between the input items. We repeat the same process
    // for full items and the partial item, scaling the input
    // weight appropriately.
    // We handle all C input items, meaning we always process
    // the partial item using a scaled down weight.
    // Handling the partial item by probabilistically including
    // it as a full item would be correct on average but would
    // introduce bias for any specific merge operation.
    final double avgWt = other.cumulativeWt_ / other.getC();
    final ArrayList<T> items = other.sample_.getFullItems();
    if (items != null) {
      for (T item : items) {
        // newWtMax is pre-computed
        final double newCumWt = cumulativeWt_ + avgWt;
        final double newRho = Math.min(1.0 / newWtMax, k_ / newCumWt);

        if (cumulativeWt_ > 0.0) {
          sample_.downsample(newRho / rho_);
        }

        tmp_.replaceContent(item, newRho * avgWt);
        sample_.merge(tmp_);

        cumulativeWt_ = newCumWt;
        rho_ = newRho;
      }
    }

    // insert partial item with weight scaled by the fractional part of C
    if (other.sample_.hasPartialItem()) {
      final double otherCFrac = other.getC() % 1;
      final double newCumWt = cumulativeWt_ + (otherCFrac * avgWt);
      final double newRho = Math.min(1.0 / newWtMax, k_ / newCumWt);

      if (cumulativeWt_ > 0.0) {
        sample_.downsample(newRho / rho_);
      }

      tmp_.replaceContent(other.sample_.getPartialItem(), newRho * otherCFrac * avgWt);
      sample_.merge(tmp_);

      // cumulativeWt_ will be assigned momentarily
      rho_ = newRho;
    }

    // avoid numeric issues by setting cumulative weight to the
    // pre-computed value
    cumulativeWt_ = finalCumWt;
    // rho_ was computed from newWtMax above, so wtMax_ must match it; otherwise a later
    // update() could compute a larger rho and request a downsampling ratio above 1.0
    wtMax_ = newWtMax;
    n_ = newN;
  }

  /**
   * Returns a copy of the current sample. The exact size may be
   * probabilistic, differing by at most 1 item.
   * @return the current sketch sample
   */
  public ArrayList<T> getResult() { return sample_.getSample(); }

  /**
   * Provides a human-readable summary of the sketch
   * @return a summary of information in the sketch
   */
  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder();

    sb.append(LS);
    final String thisSimpleName = this.getClass().getSimpleName();
    sb.append("### ").append(thisSimpleName).append(" SUMMARY: ").append(LS);
    sb.append(" k : ").append(k_).append(LS);
    sb.append(" n : ").append(n_).append(LS);
    sb.append(" Cum. weight : ").append(cumulativeWt_).append(LS);
    sb.append(" wtMax : ").append(wtMax_).append(LS);
    sb.append(" rho : ").append(rho_).append(LS);
    sb.append(" C : ").append(sample_.getC()).append(LS);
    sb.append("### END SKETCH SUMMARY").append(LS);

    return sb.toString();
  }

  /**
   * Returns the configured maximum sample size.
   * @return configured maximum sample size
   */
  public int getK() { return k_; }

  /**
   * Returns the number of items processed by the sketch, regardless
   * of item weight. Items submitted with weight 0.0 are not counted.
   * @return count of items processed by the sketch
   */
  public long getN() { return n_; }

  /**
   * Returns the cumulative weight of items processed by the sketch.
   * @return cumulative weight of items seen
   */
  public double getCumulativeWeight() { return cumulativeWt_; }

  /**
   * Returns the expected number of samples returned upon a call to
   * getResult(). The number is a floating point value, where the
   * fractional portion represents the probability of including a
   * "partial item" from the sample.
   *
   * <p>The value C should be no larger than the sketch's configured
   * value of k, although numerical precision limitations mean it
   * may exceed k by double precision floating point error margins
   * in certain cases.
   * @return The expected number of samples returned when querying the sketch
   */
  public double getC() { return sample_.getC(); }

  /**
   * Returns true if the sketch is empty.
   * @return empty flag
   */
  public boolean isEmpty() { return n_ == 0; }

  /**
   * Resets the sketch to its default, empty state. The configured k is retained.
   */
  public void reset() {
    n_ = 0;
    cumulativeWt_ = 0.0;
    wtMax_ = 0.0;
    rho_ = 1.0;
    sample_ = new EbppsItemsSample<>(k_);
  }

  /**
   * Returns the size of a byte array representation of this sketch. May fail for polymorphic item types.
   *
   * @param serDe An instance of ArrayOfItemsSerDe
   * @return the length of a byte array representation of this sketch
   */
  public int getSerializedSizeBytes(final ArrayOfItemsSerDe<? super T> serDe) {
    if (isEmpty()) {
      return Family.EBPPS.getMinPreLongs() << 3;
    } else if (sample_.getC() < 1.0) {
      // only the partial item is present
      return getSerializedSizeBytes(serDe, sample_.getPartialItem().getClass());
    } else {
      return getSerializedSizeBytes(serDe, sample_.getSample().get(0).getClass());
    }
  }

  /**
   * Returns the length of a byte array representation of this sketch. Copies contents into an array of the
   * specified class for serialization to allow for polymorphic types.
   *
   * @param serDe An instance of ArrayOfItemsSerDe
   * @param clazz The class represented by &lt;T&gt;
   * @return the length of a byte array representation of this sketch
   */
  public int getSerializedSizeBytes(final ArrayOfItemsSerDe<? super T> serDe, final Class<?> clazz) {
    if (n_ == 0) {
      return Family.EBPPS.getMinPreLongs() << 3;
    }

    final int preLongs = Family.EBPPS.getMaxPreLongs();
    final byte[] itemBytes = serDe.serializeToByteArray(sample_.getAllSamples(clazz));
    // in C++, c_ is serialized as part of the sample_ and not included in the header size
    return (preLongs << 3) + Double.BYTES + itemBytes.length;
  }

  /**
   * Returns a byte array representation of this sketch. May fail for polymorphic item types.
   *
   * @param serDe An instance of ArrayOfItemsSerDe
   * @return a byte array representation of this sketch
   */
  public byte[] toByteArray(final ArrayOfItemsSerDe<? super T> serDe) {
    if (n_ == 0) {
      // null class is ok since empty -- no need to call serDe
      return toByteArray(serDe, null);
    } else if (sample_.getC() < 1.0) {
      // only the partial item is present
      return toByteArray(serDe, sample_.getPartialItem().getClass());
    } else {
      return toByteArray(serDe, sample_.getSample().get(0).getClass());
    }
  }

  /**
   * Returns a byte array representation of this sketch. Copies contents into an array of the
   * specified class for serialization to allow for polymorphic types.
   *
   * @param serDe An instance of ArrayOfItemsSerDe
   * @param clazz The class represented by &lt;T&gt;
   * @return a byte array representation of this sketch
   */
  public byte[] toByteArray(final ArrayOfItemsSerDe<? super T> serDe, final Class<?> clazz) {
    final int preLongs, outBytes;
    final boolean empty = n_ == 0;
    byte[] itemBytes = null; // for serialized items from sample_

    if (empty) {
      // same preamble size that heapify() requires and getSerializedSizeBytes() reports
      preLongs = Family.EBPPS.getMinPreLongs();
      outBytes = preLongs << 3;
    } else {
      preLongs = Family.EBPPS.getMaxPreLongs();
      itemBytes = serDe.serializeToByteArray(sample_.getAllSamples(clazz));
      // in C++, c_ is serialized as part of the sample_ and not included in the header size
      outBytes = (preLongs << 3) + Double.BYTES + itemBytes.length;
    }
    final byte[] outArr = new byte[outBytes];
    final WritableMemory mem = WritableMemory.writableWrap(outArr);

    // Common header elements
    PreambleUtil.insertPreLongs(mem, preLongs);                  // Byte 0
    PreambleUtil.insertSerVer(mem, EBPPS_SER_VER);               // Byte 1
    PreambleUtil.insertFamilyID(mem, Family.EBPPS.getID());      // Byte 2
    if (empty) {
      PreambleUtil.insertFlags(mem, EMPTY_FLAG_MASK);            // Byte 3
    } else {
      PreambleUtil.insertFlags(mem, sample_.hasPartialItem() ? HAS_PARTIAL_ITEM_MASK : 0);
    }
    PreambleUtil.insertK(mem, k_);                               // Bytes 4-7

    // conditional elements
    if (!empty) {
      PreambleUtil.insertN(mem, n_);
      PreambleUtil.insertEbppsCumulativeWeight(mem, cumulativeWt_);
      PreambleUtil.insertEbppsMaxWeight(mem, wtMax_);
      PreambleUtil.insertEbppsRho(mem, rho_);

      // data from sample_ -- itemBytes includes the partial item
      mem.putDouble(EBPPS_C_DOUBLE, sample_.getC());
      mem.putByteArray(EBPPS_ITEMS_START, itemBytes, 0, itemBytes.length);
    }

    return outArr;
  }

  // validates a user-supplied k against the accepted range [1, MAX_K]
  private static void checkK(final int k) {
    if (k <= 0 || k > MAX_K) {
      throw new SketchesArgumentException("k must be strictly positive and no greater than " + MAX_K
          + ". Found: " + k);
    }
  }
}