blob: 184db270b809b869f42de7dab21644f5023c1d6b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.datasketches.sampling;
import static org.apache.datasketches.sampling.PreambleUtil.FAMILY_BYTE;
import static org.apache.datasketches.sampling.PreambleUtil.PREAMBLE_LONGS_BYTE;
import static org.apache.datasketches.sampling.PreambleUtil.RESERVOIR_SIZE_INT;
import static org.apache.datasketches.sampling.PreambleUtil.RESERVOIR_SIZE_SHORT;
import static org.apache.datasketches.sampling.PreambleUtil.SER_VER_BYTE;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
import org.testng.annotations.Test;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.datasketches.Family;
import org.apache.datasketches.SketchesArgumentException;
@SuppressWarnings("javadoc")
public class ReservoirLongsUnionTest {
@Test
public void checkEmptyUnion() {
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(1024);
final byte[] unionBytes = rlu.toByteArray();
// will intentionally break if changing empty union serialization
assertEquals(unionBytes.length, 8);
println(rlu.toString());
}
@Test
public void checkInstantiation() {
final int n = 100;
final int k = 25;
// create empty unions
ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(k);
assertNull(rlu.getResult());
rlu.update(5);
assertNotNull(rlu.getResult());
// pass in a sketch, as both an object and memory
final ReservoirLongsSketch rls = ReservoirLongsSketch.newInstance(k);
for (int i = 0; i < n; ++i) {
rls.update(i);
}
rlu.reset();
assertEquals(rlu.getResult().getN(), 0);
rlu.update(rls);
assertEquals(rlu.getResult().getN(), rls.getN());
final byte[] sketchBytes = rls.toByteArray();
final Memory mem = Memory.wrap(sketchBytes);
rlu = ReservoirLongsUnion.newInstance(rls.getK());
rlu.update(mem);
assertNotNull(rlu.getResult());
println(rlu.toString());
}
/*
@Test
public void checkReadOnlyInstantiation() {
final int k = 100;
final ReservoirLongsUnion union = ReservoirLongsUnion.newInstance(k);
for (long i = 0; i < 2 * k; ++i) {
union.update(i);
}
final byte[] unionBytes = union.toByteArray();
final Memory mem = Memory.wrap(unionBytes);
final ReservoirLongsUnion rlu;
rlu = ReservoirLongsUnion.heapify(mem);
assertNotNull(rlu);
assertEquals(rlu.getMaxK(), k);
ReservoirLongsSketchTest.validateReservoirEquality(rlu.getResult(), union.getResult());
}
*/
@Test
public void checkNullUpdate() {
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(1024);
assertNull(rlu.getResult());
// null sketch
rlu.update((ReservoirLongsSketch) null);
assertNull(rlu.getResult());
// null memory
rlu.update((Memory) null);
assertNull(rlu.getResult());
// valid input
rlu.update(5L);
assertNotNull(rlu.getResult());
}
@Test
public void checkSerialization() {
final int n = 100;
final int k = 25;
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(k);
for (int i = 0; i < n; ++i) {
rlu.update(i);
}
final byte[] unionBytes = rlu.toByteArray();
final Memory mem = Memory.wrap(unionBytes);
final ReservoirLongsUnion rebuiltUnion = ReservoirLongsUnion.heapify(mem);
validateUnionEquality(rlu, rebuiltUnion);
}
@Test
public void checkVersionConversionWithEmptyGadget() {
final int k = 32768;
final short encK = ReservoirSize.computeSize(k);
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(k);
final byte[] unionBytesOrig = rlu.toByteArray();
// get a new byte[], manually revert to v1, then reconstruct
final byte[] unionBytes = rlu.toByteArray();
final WritableMemory unionMem = WritableMemory.writableWrap(unionBytes);
unionMem.putByte(SER_VER_BYTE, (byte) 1);
unionMem.putInt(RESERVOIR_SIZE_INT, 0); // zero out all 4 bytes
unionMem.putShort(RESERVOIR_SIZE_SHORT, encK);
println(PreambleUtil.preambleToString(unionMem));
final ReservoirLongsUnion rebuilt = ReservoirLongsUnion.heapify(unionMem);
final byte[] rebuiltBytes = rebuilt.toByteArray();
assertEquals(unionBytesOrig.length, rebuiltBytes.length);
for (int i = 0; i < unionBytesOrig.length; ++i) {
assertEquals(unionBytesOrig[i], rebuiltBytes[i]);
}
}
@Test
public void checkVersionConversionWithGadget() {
final long n = 32;
final int k = 256;
final short encK = ReservoirSize.computeSize(k);
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(k);
for (long i = 0; i < n; ++i) {
rlu.update(i);
}
final byte[] unionBytesOrig = rlu.toByteArray();
// get a new byte[], manually revert to v1, then reconstruct
final byte[] unionBytes = rlu.toByteArray();
final WritableMemory unionMem = WritableMemory.writableWrap(unionBytes);
unionMem.putByte(SER_VER_BYTE, (byte) 1);
unionMem.putInt(RESERVOIR_SIZE_INT, 0); // zero out all 4 bytes
unionMem.putShort(RESERVOIR_SIZE_SHORT, encK);
// force gadget header to v1, too
final int offset = Family.RESERVOIR_UNION.getMaxPreLongs() << 3;
unionMem.putByte(offset + SER_VER_BYTE, (byte) 1);
unionMem.putInt(offset + RESERVOIR_SIZE_INT, 0); // zero out all 4 bytes
unionMem.putShort(offset + RESERVOIR_SIZE_SHORT, encK);
final ReservoirLongsUnion rebuilt = ReservoirLongsUnion.heapify(unionMem);
final byte[] rebuiltBytes = rebuilt.toByteArray();
assertEquals(unionBytesOrig.length, rebuiltBytes.length);
for (int i = 0; i < unionBytesOrig.length; ++i) {
assertEquals(unionBytesOrig[i], rebuiltBytes[i]);
}
}
//@SuppressWarnings("null") // this is the point of the test
@Test(expectedExceptions = java.lang.NullPointerException.class)
public void checkNullMemoryInstantiation() {
ReservoirLongsUnion.heapify(null);
}
@Test
public void checkDownsampledUpdate() {
final int bigK = 1024;
final int bigN = 131072;
final int smallK = 256;
final int smallN = 2048;
final ReservoirLongsSketch sketch1 = getBasicSketch(smallN, smallK);
final ReservoirLongsSketch sketch2 = getBasicSketch(bigN, bigK);
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(smallK);
assertEquals(rlu.getMaxK(), smallK);
rlu.update(sketch1);
assertNotNull(rlu.getResult());
assertEquals(rlu.getResult().getK(), smallK);
rlu.update(sketch2);
assertEquals(rlu.getResult().getK(), smallK);
assertEquals(rlu.getResult().getNumSamples(), smallK);
}
@Test
public void checkUnionResetWithInitialSmallK() {
final int maxK = 25;
final int sketchK = 10;
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(maxK);
ReservoirLongsSketch rls = getBasicSketch(2 * sketchK, sketchK); // in sampling mode
rlu.update(rls);
assertEquals(rlu.getMaxK(), maxK);
assertNotNull(rlu.getResult());
assertEquals(rlu.getResult().getK(), sketchK);
rlu.reset();
assertNotNull(rlu.getResult());
// feed in sketch in sampling mode, with larger k than old gadget
rls = getBasicSketch(2 * maxK, maxK + 1);
rlu.update(rls);
assertEquals(rlu.getMaxK(), maxK);
assertNotNull(rlu.getResult());
assertEquals(rlu.getResult().getK(), maxK);
}
@Test
public void checkNewGadget() {
final int maxK = 1024;
final int bigK = 1536;
final int smallK = 128;
// downsample input sketch, use as gadget (exact mode, but irrelevant here)
final ReservoirLongsSketch bigKSketch = getBasicSketch(maxK / 2, bigK);
final byte[] bigKBytes = bigKSketch.toByteArray();
final Memory bigKMem = Memory.wrap(bigKBytes);
ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(maxK);
rlu.update(bigKMem);
assertNotNull(rlu.getResult());
assertEquals(rlu.getResult().getK(), maxK);
assertEquals(rlu.getResult().getN(), maxK / 2);
// sketch k < maxK but in sampling mode
final ReservoirLongsSketch smallKSketch = getBasicSketch(maxK, smallK);
final byte[] smallKBytes = smallKSketch.toByteArray();
final Memory smallKMem = Memory.wrap(smallKBytes);
rlu = ReservoirLongsUnion.newInstance(maxK);
rlu.update(smallKMem);
assertNotNull(rlu.getResult());
assertTrue(rlu.getResult().getK() < maxK);
assertEquals(rlu.getResult().getK(), smallK);
assertEquals(rlu.getResult().getN(), maxK);
// sketch k < maxK and in exact mode
final ReservoirLongsSketch smallKExactSketch = getBasicSketch(smallK, smallK);
final byte[] smallKExactBytes = smallKExactSketch.toByteArray();
final Memory smallKExactMem = Memory.wrap(smallKExactBytes);
rlu = ReservoirLongsUnion.newInstance(maxK);
rlu.update(smallKExactMem);
assertNotNull(rlu.getResult());
assertEquals(rlu.getResult().getK(), maxK);
assertEquals(rlu.getResult().getN(), smallK);
}
@Test
public void checkStandardMergeNoCopy() {
final int k = 1024;
final int n1 = 256;
final int n2 = 256;
final ReservoirLongsSketch sketch1 = getBasicSketch(n1, k);
final ReservoirLongsSketch sketch2 = getBasicSketch(n2, k);
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(k);
rlu.update(sketch1);
rlu.update(sketch2);
assertNotNull(rlu.getResult());
assertEquals(rlu.getResult().getK(), k);
assertEquals(rlu.getResult().getN(), n1 + n2);
assertEquals(rlu.getResult().getNumSamples(), n1 + n2);
// creating from Memory should avoid a copy
final int n3 = 2048;
final ReservoirLongsSketch sketch3 = getBasicSketch(n3, k);
final byte[] sketch3Bytes = sketch3.toByteArray();
final Memory mem = Memory.wrap(sketch3Bytes);
rlu.update(mem);
assertEquals(rlu.getResult().getK(), k);
assertEquals(rlu.getResult().getN(), n1 + n2 + n3);
assertEquals(rlu.getResult().getNumSamples(), k);
}
@Test
public void checkStandardMergeWithCopy() {
// this will check the other code route to a standard merge,
// but will copy sketch2 to be non-destructive.
final int k = 1024;
final int n1 = 768;
final int n2 = 2048;
final ReservoirLongsSketch sketch1 = getBasicSketch(n1, k);
final ReservoirLongsSketch sketch2 = getBasicSketch(n2, k);
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(k);
rlu.update(sketch1);
rlu.update(sketch2);
rlu.update(10);
assertNotNull(rlu.getResult());
assertEquals(rlu.getResult().getK(), k);
assertEquals(rlu.getResult().getN(), n1 + n2 + 1);
assertEquals(rlu.getResult().getNumSamples(), k);
}
@Test
public void checkWeightedMerge() {
final int k = 1024;
final int n1 = 16384;
final int n2 = 2048;
final ReservoirLongsSketch sketch1 = getBasicSketch(n1, k);
final ReservoirLongsSketch sketch2 = getBasicSketch(n2, k);
ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(k);
rlu.update(sketch1);
rlu.update(sketch2);
assertNotNull(rlu.getResult());
assertEquals(rlu.getResult().getK(), k);
assertEquals(rlu.getResult().getN(), n1 + n2);
assertEquals(rlu.getResult().getNumSamples(), k);
// now merge into the sketch for updating -- results should match
rlu = ReservoirLongsUnion.newInstance(k);
rlu.update(sketch2);
rlu.update(sketch1);
assertNotNull(rlu.getResult());
assertEquals(rlu.getResult().getK(), k);
assertEquals(rlu.getResult().getN(), n1 + n2);
assertEquals(rlu.getResult().getNumSamples(), k);
}
@Test(expectedExceptions = SketchesArgumentException.class)
public void checkBadPreLongs() {
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(1024);
final WritableMemory mem = WritableMemory.writableWrap(rlu.toByteArray());
mem.putByte(PREAMBLE_LONGS_BYTE, (byte) 0); // corrupt the preLongs count
ReservoirLongsUnion.heapify(mem);
fail();
}
@Test(expectedExceptions = SketchesArgumentException.class)
public void checkBadSerVer() {
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(1024);
final WritableMemory mem = WritableMemory.writableWrap(rlu.toByteArray());
mem.putByte(SER_VER_BYTE, (byte) 0); // corrupt the serialization version
ReservoirLongsUnion.heapify(mem);
fail();
}
@Test(expectedExceptions = SketchesArgumentException.class)
public void checkBadFamily() {
final ReservoirLongsUnion rlu = ReservoirLongsUnion.newInstance(1024);
final WritableMemory mem = WritableMemory.writableWrap(rlu.toByteArray());
mem.putByte(FAMILY_BYTE, (byte) 0); // corrupt the family ID
ReservoirLongsUnion.heapify(mem);
fail();
}
private static void validateUnionEquality(final ReservoirLongsUnion rlu1,
final ReservoirLongsUnion rlu2) {
assertEquals(rlu1.getMaxK(), rlu2.getMaxK());
ReservoirLongsSketchTest.validateReservoirEquality(rlu1.getResult(), rlu2.getResult());
}
private static ReservoirLongsSketch getBasicSketch(final int n, final int k) {
final ReservoirLongsSketch rls = ReservoirLongsSketch.newInstance(k);
for (int i = 0; i < n; ++i) {
rls.update(i);
}
return rls;
}
/**
* Wrapper around System.out.println() allowing a simple way to disable logging in tests
*
* @param msg The message to print
*/
private static void println(final String msg) {
//System.out.println(msg);
}
}