blob: c448d2c57cdcf59c98efde9f98200d090d050126 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.datasketches.sampling;
import static org.apache.datasketches.sampling.PreambleUtil.FAMILY_BYTE;
import static org.apache.datasketches.sampling.PreambleUtil.PREAMBLE_LONGS_BYTE;
import static org.apache.datasketches.sampling.PreambleUtil.RESERVOIR_SIZE_INT;
import static org.apache.datasketches.sampling.PreambleUtil.RESERVOIR_SIZE_SHORT;
import static org.apache.datasketches.sampling.PreambleUtil.SER_VER_BYTE;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
import org.testng.annotations.Test;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.datasketches.Family;
import org.apache.datasketches.ResizeFactor;
import org.apache.datasketches.SketchesArgumentException;
import org.apache.datasketches.SketchesException;
import org.apache.datasketches.SketchesStateException;
@SuppressWarnings("javadoc")
public class ReservoirLongsSketchTest {
// Absolute tolerance for floating-point comparisons where an exact match is not expected
// (used when checking the implicit sample weight after the reservoir has filled).
private static final double EPS = 1e-8;
@Test(expectedExceptions = SketchesArgumentException.class)
public void checkInvalidK() {
  // the factory is expected to reject a reservoir size of zero
  final int invalidK = 0;
  ReservoirLongsSketch.newInstance(invalidK);
  fail(); // reached only if no exception was raised
}
@Test(expectedExceptions = SketchesArgumentException.class)
public void checkBadPreLongs() {
  final WritableMemory corrupted = getBasicSerializedRLS();
  // overwrite the preamble-longs count with zero so heapify sees an invalid header
  final byte badPreLongs = 0;
  corrupted.putByte(PREAMBLE_LONGS_BYTE, badPreLongs);
  ReservoirLongsSketch.heapify(corrupted);
  fail(); // reached only if heapify accepted the corrupted image
}
@Test(expectedExceptions = SketchesArgumentException.class)
public void checkBadSerVer() {
  final WritableMemory corrupted = getBasicSerializedRLS();
  // overwrite the serialization version with zero so heapify sees an invalid header
  final byte badSerVer = 0;
  corrupted.putByte(SER_VER_BYTE, badSerVer);
  ReservoirLongsSketch.heapify(corrupted);
  fail(); // reached only if heapify accepted the corrupted image
}
@Test(expectedExceptions = SketchesArgumentException.class)
public void checkBadFamily() {
  final WritableMemory corrupted = getBasicSerializedRLS();
  // overwrite the family ID with zero so heapify sees an invalid header
  final byte badFamilyId = 0;
  corrupted.putByte(FAMILY_BYTE, badFamilyId);
  ReservoirLongsSketch.heapify(corrupted);
  fail(); // reached only if heapify accepted the corrupted image
}
/**
 * Verifies the behavior of a sketch that has never been updated: it reports no samples,
 * serializes to the minimum preamble size, and deserializes back to a sketch holding
 * zero samples.
 */
@Test
public void checkEmptySketch() {
  final ReservoirLongsSketch rls = ReservoirLongsSketch.newInstance(5);
  // assertNull reports the offending value on failure, unlike assertTrue(x == null)
  assertNull(rls.getSamples());

  final byte[] sketchBytes = rls.toByteArray();
  final Memory mem = Memory.wrap(sketchBytes);

  // only minPreLongs longs (8 bytes each) are written, and they should deserialize to empty
  assertEquals(sketchBytes.length, Family.RESERVOIR.getMinPreLongs() << 3);
  final ReservoirLongsSketch loadedRls = ReservoirLongsSketch.heapify(mem);
  assertEquals(loadedRls.getNumSamples(), 0);

  println("Empty sketch:");
  println(rls.toString());

  // smoke checks: the static preamble printers must accept an empty-sketch image
  // without throwing; their output is intentionally not inspected
  ReservoirLongsSketch.toString(sketchBytes);
  ReservoirLongsSketch.toString(mem);
}
/**
 * Feeds fewer items than the reservoir can hold and checks that every item is retained,
 * in submission order, and that the sketch survives a serialization round trip.
 */
@Test
public void checkUnderFullReservoir() {
  final int capacity = 128;
  final int numItems = 64;
  final ReservoirLongsSketch sketch = ReservoirLongsSketch.newInstance(capacity);

  int next = 0;
  while (next < numItems) {
    sketch.update(next);
    ++next;
  }
  assertEquals(sketch.getNumSamples(), numItems);

  final long[] samples = sketch.getSamples();
  assertEquals(sketch.getNumSamples(), sketch.getN());
  assertNotNull(samples);
  assertEquals(samples.length, numItems);

  // below capacity the reservoir keeps items in the order they were submitted
  for (int idx = 0; idx < numItems; ++idx) {
    assertEquals(samples[idx], idx);
  }

  validateSerializeAndDeserialize(sketch);
}
/**
 * Feeds twice as many items as the reservoir can hold and checks that the number of
 * retained samples equals k and that the sketch survives a serialization round trip.
 */
@Test
public void checkFullReservoir() {
  final int capacity = 1000;
  final int numItems = 2000;

  // a small ResizeFactor is chosen so the internal storage has to grow several times
  final ReservoirLongsSketch sketch =
      ReservoirLongsSketch.newInstance(capacity, ResizeFactor.X2);

  int next = 0;
  while (next < numItems) {
    sketch.update(next);
    ++next;
  }

  assertEquals(sketch.getNumSamples(), sketch.getK());
  validateSerializeAndDeserialize(sketch);

  println("Full reservoir:");
  println(sketch.toString());
}
/**
 * Exercises downsampledCopy() at three fill levels of the source sketch:
 * fewer items than the target size, between the target size and the source size,
 * and more items than the source size.
 */
@Test
public void checkDownsampledCopy() {
  final int srcK = 256;
  final int smallK = 64;
  final ReservoirLongsSketch source = ReservoirLongsSketch.newInstance(srcK);

  // next value to feed; carried across all three phases so the stream is continuous
  int next = 0;

  // phase 1: n < smallK
  while (next < (smallK - 1)) {
    source.update(next);
    ++next;
  }
  ReservoirLongsSketch copy = source.downsampledCopy(smallK);
  assertEquals(copy.getK(), smallK);
  // contents should match the source; the differing k is not part of this comparison
  validateReservoirEquality(source, copy);

  // phase 2: smallK < n < srcK
  while (next < (srcK - 1)) {
    source.update(next);
    ++next;
  }
  assertEquals(source.getN(), srcK - 1);
  copy = source.downsampledCopy(smallK);
  assertEquals(copy.getN(), source.getN());
  assertEquals(copy.getNumSamples(), smallK);

  // phase 3: n > srcK
  while (next < (2 * srcK)) {
    source.update(next);
    ++next;
  }
  assertEquals(source.getN(), 2 * srcK);
  copy = source.downsampledCopy(smallK);
  assertEquals(copy.getN(), source.getN());
  assertEquals(copy.getNumSamples(), smallK);
}
/**
 * Checks that getInstance() rejects inconsistent argument combinations and that each
 * rejection carries the expected message fragment.
 */
@Test
public void checkBadConstructorArgs() {
  final int len = 128;
  final long[] data = new long[len];
  for (int idx = 0; idx < len; ++idx) {
    data[idx] = idx;
  }
  final ResizeFactor rf = ResizeFactor.X8;

  // no items
  expectGetInstanceFailure(null, 128, rf, 128, "null reservoir");

  // size too small
  expectGetInstanceFailure(data, 128, rf, 1, "size less than 2");

  // configured reservoir size smaller than items length
  expectGetInstanceFailure(data, 128, rf, 64, "max size less than array length");

  // too many items seen vs items length, full sketch
  expectGetInstanceFailure(data, 512, rf, 256, "too few samples");

  // too many items seen vs items length, under-full sketch
  expectGetInstanceFailure(data, 256, rf, 256, "too few samples");
}

// Calls getInstance() with the given arguments and requires a SketchesException whose
// message contains the given fragment; fails the test if no exception is thrown.
private static void expectGetInstanceFailure(final long[] data, final long itemsSeen,
    final ResizeFactor rf, final int k, final String expectedFragment) {
  try {
    ReservoirLongsSketch.getInstance(data, itemsSeen, rf, k);
    fail();
  } catch (final SketchesException e) {
    assertTrue(e.getMessage().contains(expectedFragment));
  }
}
/**
 * Builds a sketch whose items-seen counter is one update away from the limit, checks that
 * one more update is accepted and the following one is rejected, then checks that reset()
 * returns the sketch to a usable state.
 */
@Test
public void checkSketchCapacity() {
  final long[] reservoir = new long[64];
  final long nearLimit = (1L << 48) - 2;
  final ReservoirLongsSketch sketch = ReservoirLongsSketch.getInstance(
      reservoir, nearLimit, ResizeFactor.X8, reservoir.length);

  // first update is still within capacity
  sketch.update(0);

  // second update must be refused
  try {
    sketch.update(0);
    fail();
  } catch (final SketchesStateException e) {
    final String msg = e.getMessage();
    assertTrue(msg.contains("Sketch has exceeded capacity for total items seen"));
  }

  // after a reset the counter starts over and updates are accepted again
  sketch.reset();
  assertEquals(sketch.getN(), 0);
  sketch.update(1L);
  assertEquals(sketch.getN(), 1L);
}
/**
 * Checks the implicit per-sample weight both while the reservoir is under-full
 * and after more items than k have been seen.
 */
@Test
public void checkSampleWeight() {
  final int capacity = 32;
  final ReservoirLongsSketch sketch = ReservoirLongsSketch.newInstance(capacity);

  final int half = capacity / 2;
  for (int v = 0; v < half; ++v) {
    sketch.update(v);
  }
  // under-full: the weight is expected to be exactly 1.0, so no tolerance is used
  assertEquals(sketch.getImplicitSampleWeight(), 1.0);

  // feed k more items, for 3k/2 items seen in total
  for (int v = 0; v < capacity; ++v) {
    sketch.update(v);
  }
  final double deviation = Math.abs(sketch.getImplicitSampleWeight() - 1.5);
  assertTrue(deviation < EPS);
}
/*
@Test
public void checkReadOnlyHeapify() {
Memory sketchMem = getBasicSerializedRLS();
// Load from read-only and writable memory to ensure they deserialize identically
ReservoirLongsSketch rls = ReservoirLongsSketch.heapify(sketchMem.asReadOnlyMemory());
ReservoirLongsSketch fromWritable = ReservoirLongsSketch.heapify(sketchMem);
validateReservoirEquality(rls, fromWritable);
// Same with an empty sketch
final byte[] sketchBytes = ReservoirLongsSketch.newInstance(32).toByteArray();
sketchMem = new NativeMemory(sketchBytes);
rls = ReservoirLongsSketch.heapify(sketchMem.asReadOnlyMemory());
fromWritable = ReservoirLongsSketch.heapify(sketchMem);
validateReservoirEquality(rls, fromWritable);
}
*/
/**
 * Rewrites a serialized sketch into the version-1 header layout, heapifies it, and checks
 * that re-serializing the result reproduces the original bytes exactly.
 */
@Test
public void checkVersionConversion() {
  // the change from version 1 to 2 only impacts the first preamble long,
  // so an empty sketch is sufficient
  final int k = 32768;
  final short encodedK = ReservoirSize.computeSize(k);

  final ReservoirLongsSketch original = ReservoirLongsSketch.newInstance(k);
  final byte[] expectedBytes = original.toByteArray();

  // take a second, independent byte[] and manually revert its header to v1
  final byte[] v1Bytes = original.toByteArray();
  final WritableMemory v1Mem = WritableMemory.writableWrap(v1Bytes);
  v1Mem.putByte(SER_VER_BYTE, (byte) 1);
  v1Mem.putInt(RESERVOIR_SIZE_INT, 0); // zero out all 4 bytes
  v1Mem.putShort(RESERVOIR_SIZE_SHORT, encodedK);

  // reconstruct from the v1 image and serialize again
  final byte[] actualBytes = ReservoirLongsSketch.heapify(v1Mem).toByteArray();

  assertEquals(expectedBytes.length, actualBytes.length);
  int pos = 0;
  while (pos < expectedBytes.length) {
    assertEquals(expectedBytes[pos], actualBytes[pos]);
    ++pos;
  }
}
/**
 * Checks that a value written with insertValueAtPosition() is the value subsequently
 * returned by getValueAtPosition() for the same position.
 */
@Test
public void checkSetAndGetValue() {
  final int capacity = 20;
  final int pos = 5;
  final ReservoirLongsSketch sketch = ReservoirLongsSketch.newInstance(capacity);

  // fill exactly to capacity so that position i holds the value i
  int next = 0;
  while (next < capacity) {
    sketch.update(next);
    ++next;
  }
  assertEquals(sketch.getValueAtPosition(pos), pos);

  // overwrite the slot and read it back
  sketch.insertValueAtPosition(-1, pos);
  assertEquals(sketch.getValueAtPosition(pos), -1);
}
/**
 * Checks that positional reads and writes are rejected with SketchesArgumentException
 * when the sketch is empty or when the position is negative or k + 1.
 */
@Test
public void checkBadSetAndGetValue() {
  final int k = 20;
  final int tgtIdx = 5;
  final ReservoirLongsSketch rls = ReservoirLongsSketch.newInstance(k);

  // reading from a sketch that has seen no items
  expectArgumentException(() -> rls.getValueAtPosition(0));

  for (int v = 0; v < k; ++v) {
    rls.update(v);
  }
  assertEquals(rls.getValueAtPosition(tgtIdx), tgtIdx);

  // writes at a negative position and at k + 1
  expectArgumentException(() -> rls.insertValueAtPosition(-1, -1));
  expectArgumentException(() -> rls.insertValueAtPosition(-1, k + 1));

  // reads at a negative position and at k + 1
  expectArgumentException(() -> rls.getValueAtPosition(-1));
  expectArgumentException(() -> rls.getValueAtPosition(k + 1));
}

// Runs the action and requires it to throw SketchesArgumentException;
// fails the test if the action completes normally.
private static void expectArgumentException(final Runnable action) {
  try {
    action.run();
    fail();
  } catch (final SketchesArgumentException e) {
    // expected
  }
}
/**
 * Checks that forceIncrementItemsSeen() adds to the items-seen count and that an
 * increment of (2^48 - 1) on top of the existing count is rejected.
 */
@Test
public void checkForceIncrement() {
  final int capacity = 100;
  final ReservoirLongsSketch sketch = ReservoirLongsSketch.newInstance(capacity);

  final int fed = 2 * capacity;
  for (int v = 0; v < fed; ++v) {
    sketch.update(v);
  }
  assertEquals(sketch.getN(), 2 * capacity);

  // a modest increment is simply added to n
  sketch.forceIncrementItemsSeen(capacity);
  assertEquals(sketch.getN(), 3 * capacity);

  // a huge increment must be refused
  final long hugeDelta = (1L << 48) - 1;
  try {
    sketch.forceIncrementItemsSeen(hugeDelta);
    fail();
  } catch (final SketchesStateException e) {
    // expected
  }
}
/**
 * Exercises estimateSubsetSum() on an empty sketch, in exact mode (fewer than k items),
 * and in sampling mode with always-true, always-false and sign-based predicates.
 */
@Test
public void checkEstimateSubsetSum() {
  final int k = 10;
  final ReservoirLongsSketch sketch = ReservoirLongsSketch.newInstance(k);

  // empty sketch: estimate and total weight are both zero
  SampleSubsetSummary summary = sketch.estimateSubsetSum(v -> true);
  assertEquals(summary.getEstimate(), 0.0);
  assertEquals(summary.getTotalSketchWeight(), 0.0);

  // running count of items fed so far, kept as a double to compare against weights
  double totalFed = 0.0;

  // k - 1 items keeps the sketch in exact mode, so estimate and both bounds coincide
  for (long v = 1; v <= (k - 1); ++v) {
    sketch.update(v);
    totalFed += 1.0;
  }
  summary = sketch.estimateSubsetSum(v -> true);
  assertEquals(summary.getEstimate(), totalFed);
  assertEquals(summary.getLowerBound(), totalFed);
  assertEquals(summary.getUpperBound(), totalFed);
  assertEquals(summary.getTotalSketchWeight(), totalFed);

  // two more items push the sketch into sampling mode
  for (long v = k; v <= (k + 1); ++v) {
    sketch.update(v);
    totalFed += 1.0;
  }

  // always-true predicate: estimate equals the upper bound
  summary = sketch.estimateSubsetSum(v -> true);
  assertEquals(summary.getEstimate(), totalFed);
  assertEquals(summary.getUpperBound(), totalFed);
  assertTrue(summary.getLowerBound() < totalFed);
  assertEquals(summary.getTotalSketchWeight(), totalFed);

  // always-false predicate: estimate equals the lower bound, which is 0.0
  summary = sketch.estimateSubsetSum(v -> false);
  assertEquals(summary.getEstimate(), 0.0);
  assertEquals(summary.getLowerBound(), 0.0);
  assertTrue(summary.getUpperBound() > 0.0);
  assertEquals(summary.getTotalSketchWeight(), totalFed);

  // non-degenerate predicate: feed an equal number of negative items,
  // then select only the negative items
  for (long v = 1; v <= (k + 1); ++v) {
    sketch.update(-v);
    totalFed += 1.0;
  }
  summary = sketch.estimateSubsetSum(v -> v < 0);
  assertTrue(summary.getEstimate() >= summary.getLowerBound());
  assertTrue(summary.getEstimate() <= summary.getUpperBound());

  // the sample is random, so only fairly generous bounds are asserted
  assertTrue(summary.getLowerBound() < (totalFed / 1.4));
  assertTrue(summary.getUpperBound() > (totalFed / 2.6));
  assertEquals(summary.getTotalSketchWeight(), totalFed);
}
// Builds a k=10 sketch, feeds it the values 0..19, sanity-checks its counters, and returns
// its serialized form wrapped in writable memory so callers can corrupt individual bytes.
private static WritableMemory getBasicSerializedRLS() {
  final int capacity = 10;
  final int numItems = 20;

  final ReservoirLongsSketch sketch = ReservoirLongsSketch.newInstance(capacity);
  assertEquals(sketch.getNumSamples(), 0);

  int next = 0;
  while (next < numItems) {
    sketch.update(next);
    ++next;
  }

  assertEquals(sketch.getNumSamples(), Math.min(numItems, capacity));
  assertEquals(sketch.getN(), numItems);
  assertEquals(sketch.getK(), capacity);

  return WritableMemory.writableWrap(sketch.toByteArray());
}
// Serializes the sketch, checks the image size is (max preamble longs + number of samples)
// longs of 8 bytes each, then heapifies the image and compares the result with the input.
private static void validateSerializeAndDeserialize(final ReservoirLongsSketch rls) {
  final byte[] image = rls.toByteArray();
  final int expectedLongs = Family.RESERVOIR.getMaxPreLongs() + rls.getNumSamples();
  assertEquals(image.length, expectedLongs << 3);

  // the rebuilt sketch must hold the same samples as the one that was serialized
  final ReservoirLongsSketch rebuilt = ReservoirLongsSketch.heapify(Memory.wrap(image));
  validateReservoirEquality(rls, rebuilt);
}
// Asserts that two sketches report the same number of samples and, when that number is
// non-zero, that their sample arrays have equal length and equal values at every index.
// Other sketch state (such as k and n) is not compared here.
static void validateReservoirEquality(final ReservoirLongsSketch rls1,
    final ReservoirLongsSketch rls2) {
  assertEquals(rls1.getNumSamples(), rls2.getNumSamples());

  if (rls1.getNumSamples() != 0) {
    final long[] left = rls1.getSamples();
    final long[] right = rls2.getSamples();
    assertNotNull(left);
    assertNotNull(right);
    assertEquals(left.length, right.length);

    int idx = 0;
    while (idx < left.length) {
      assertEquals(left[idx], right[idx]);
      ++idx;
    }
  }
}
/**
 * Renders a byte array as lowercase hexadecimal text, one line per group of 8 bytes.
 *
 * <p>Within each group the bytes are written from the highest index down to the lowest,
 * two hex digits per byte, so a full group reads as a single 16-digit value. If the array
 * length is not a multiple of 8, the final line contains only the bytes that exist rather
 * than reading past the end of the array. Every line is terminated with the platform line
 * separator.</p>
 *
 * @param byteArr the bytes to render; must not be null
 * @return the formatted text, or an empty string for an empty array
 */
static String printBytesAsLongs(final byte[] byteArr) {
  final StringBuilder sb = new StringBuilder();
  for (int i = 0; i < byteArr.length; i += 8) {
    // clamp the top index so a trailing partial group stays inside the array
    final int top = Math.min(i + 7, byteArr.length - 1);
    for (int j = top; j >= i; --j) {
      final int unsigned = byteArr[j] & 0xFF;
      if (unsigned < 0x10) {
        sb.append('0'); // keep every byte exactly two hex digits wide
      }
      sb.append(Integer.toHexString(unsigned));
    }
    sb.append(System.lineSeparator());
  }
  return sb.toString();
}
/**
 * Logging hook for the tests in this class. The call to System.out.println() inside is
 * currently commented out, so this method does nothing and the tests run silently;
 * uncomment that line to see the messages.
 *
 * @param msg The message to print (unused while the print statement is commented out)
 */
private static void println(final String msg) {
//System.out.println(msg);
}
}