blob: 2189771b4707ffa2fa576ea3242b56dd56de6283 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.datasketches.sampling;
import static org.apache.datasketches.sampling.PreambleUtil.PREAMBLE_LONGS_BYTE;
import static org.apache.datasketches.sampling.PreambleUtil.SER_VER_BYTE;
import static org.apache.datasketches.sampling.VarOptItemsSketchTest.EPS;
import static org.apache.datasketches.sampling.VarOptItemsSketchTest.getUnweightedLongsVIS;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
import org.testng.annotations.Test;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.datasketches.ArrayOfLongsSerDe;
import org.apache.datasketches.ArrayOfStringsSerDe;
import org.apache.datasketches.Family;
import org.apache.datasketches.SketchesArgumentException;
/**
 * Tests for {@link VarOptItemsUnion}: rejection of corrupted serialized images, unions of empty,
 * exact-mode and sampling-mode {@link VarOptItemsSketch} inputs, unions fed with
 * {@link ReservoirItemsSketch} inputs, and serialization round trips.
 *
 * @author Jon Malkin
 */
@SuppressWarnings("javadoc")
public class VarOptItemsUnionTest {

  /**
   * Heapifying a union image whose serialization-version byte has been zeroed must throw
   * {@link SketchesArgumentException}.
   */
  @Test(expectedExceptions = SketchesArgumentException.class)
  public void checkBadSerVer() {
    final int k = 25;
    final int n = 30;
    final VarOptItemsUnion<Long> union = VarOptItemsUnion.newInstance(k);
    union.update(getUnweightedLongsVIS(k, n));
    final byte[] bytes = union.toByteArray(new ArrayOfLongsSerDe());
    final WritableMemory mem = WritableMemory.writableWrap(bytes);

    mem.putByte(SER_VER_BYTE, (byte) 0); // corrupt the serialization version

    VarOptItemsUnion.heapify(mem, new ArrayOfLongsSerDe());
    fail(); // only reached if heapify() accepted the corrupted image
  }

  /**
   * Heapifying a union image whose preamble-longs count is below the VarOpt family's minimum
   * must throw {@link SketchesArgumentException}.
   */
  @Test(expectedExceptions = SketchesArgumentException.class)
  public void checkBadPreLongs() {
    final int k = 25;
    final int n = 30;
    final VarOptItemsUnion<Long> union = VarOptItemsUnion.newInstance(k);
    union.update(getUnweightedLongsVIS(k, n));
    final byte[] bytes = union.toByteArray(new ArrayOfLongsSerDe());
    final WritableMemory mem = WritableMemory.writableWrap(bytes);

    // corrupt the preLongs count to one less than the family's minimum
    mem.putByte(PREAMBLE_LONGS_BYTE, (byte) (Family.VAROPT.getMinPreLongs() - 1));

    VarOptItemsUnion.heapify(mem, new ArrayOfLongsSerDe());
    fail(); // only reached if heapify() accepted the corrupted image
  }

  /**
   * Unioning a serialized empty sketch (supplied as {@link Memory}) yields an empty result with
   * no H or R items and a NaN tau.
   */
  @Test
  public void unionEmptySketch() {
    final int k = 2048;
    final ArrayOfStringsSerDe serDe = new ArrayOfStringsSerDe();

    // we'll union from Memory for good measure
    final byte[] sketchBytes = VarOptItemsSketch.<String>newInstance(k).toByteArray(serDe);
    final Memory mem = Memory.wrap(sketchBytes);

    final VarOptItemsUnion<String> union = VarOptItemsUnion.newInstance(k);
    union.update(mem, serDe);

    final VarOptItemsSketch<String> result = union.getResult();
    assertEquals(result.getN(), 0);
    assertEquals(result.getHRegionCount(), 0);
    assertEquals(result.getRRegionCount(), 0);
    assertTrue(Double.isNaN(result.getTau()));
  }

  /**
   * Two exact-mode sketches whose combined item count stays below k union into a result that
   * keeps every item in H and has an empty R region.
   */
  @Test
  public void unionTwoExactSketches() {
    final int n = 4; // 2n < k
    final int k = 10;
    final VarOptItemsSketch<Integer> sk1 = VarOptItemsSketch.newInstance(k);
    final VarOptItemsSketch<Integer> sk2 = VarOptItemsSketch.newInstance(k);

    for (int i = 1; i <= n; ++i) {
      sk1.update(i, i);
      sk2.update(-i, i);
    }

    final VarOptItemsUnion<Integer> union = VarOptItemsUnion.newInstance(k);
    union.update(sk1);
    union.update(sk2);

    final VarOptItemsSketch<Integer> result = union.getResult();
    assertEquals(result.getN(), 2 * n);
    assertEquals(result.getHRegionCount(), 2 * n);
    assertEquals(result.getRRegionCount(), 0);
  }

  /**
   * A smaller-k sampling sketch containing one very heavy item pulls the result's k down to that
   * sketch's k. A subsequent {@code reset()} restores an empty union with the original k.
   */
  @Test
  public void unionHeavySamplingSketch() {
    final int n1 = 20;
    final int k1 = 10;
    final int n2 = 6;
    final int k2 = 5;
    final VarOptItemsSketch<Integer> sk1 = VarOptItemsSketch.newInstance(k1);
    final VarOptItemsSketch<Integer> sk2 = VarOptItemsSketch.newInstance(k2);

    for (int i = 1; i <= n1; ++i) {
      sk1.update(i, i);
    }

    for (int i = 1; i < n2; ++i) { // we'll add a very heavy one later
      sk2.update(-i, i + 1000.0);
    }
    sk2.update(-n2, 1000000.0);

    final VarOptItemsUnion<Integer> union = VarOptItemsUnion.newInstance(k1);
    union.update(sk1);
    union.update(sk2);

    VarOptItemsSketch<Integer> result = union.getResult();
    assertEquals(result.getN(), n1 + n2);
    assertEquals(result.getK(), k2); // heavy enough it'll pull back to k2
    assertEquals(result.getHRegionCount(), 1);
    assertEquals(result.getRRegionCount(), k2 - 1);

    union.reset();
    assertEquals(union.getOuterTau(), 0.0);
    result = union.getResult();
    assertEquals(result.getK(), k1);
    assertEquals(result.getN(), 0);
  }

  /**
   * Unioning the same unweighted sampling sketch twice, then a sketch with a smaller tau,
   * preserves the total weight in R: every input item has weight 1.0, so R's weight equals the
   * total input count.
   */
  @Test
  public void unionIdenticalSamplingSketches() {
    final int k = 20;
    final int n = 50;
    VarOptItemsSketch<Long> sketch = getUnweightedLongsVIS(k, n);

    final VarOptItemsUnion<Long> union = VarOptItemsUnion.newInstance(k);
    union.update(sketch);
    union.update(sketch);

    VarOptItemsSketch<Long> result = union.getResult();
    double expectedWeight = 2.0 * n; // unweighted, aka uniform weight of 1.0
    assertEquals(result.getN(), 2 * n);
    // R weight is a floating-point sum of weight-corrected items; compare with a tolerance
    assertEquals(result.getTotalWtR(), expectedWeight, EPS);

    // add another sketch, such that sketchTau < outerTau
    sketch = getUnweightedLongsVIS(k, k + 1); // tau = (k + 1) / k
    union.update(sketch);
    result = union.getResult();
    expectedWeight = (2.0 * n) + k + 1;
    assertEquals(result.getN(), (2 * n) + k + 1);
    assertEquals(result.getTotalWtR(), expectedWeight, EPS);

    union.reset();
    assertEquals(union.getOuterTau(), 0.0);
    result = union.getResult();
    assertEquals(result.getK(), k);
    assertEquals(result.getN(), 0);
  }

  /**
   * Unions two small-k sampling sketches into a large-k union. One of them also carries a heavy
   * item. Only the light items' weight (n1 + n2) should end up in the result's R region.
   */
  @Test
  public void unionSmallSamplingSketch() {
    final int kSmall = 16;
    final int n1 = 32;
    final int n2 = 64;
    final int kMax = 128;

    // small k sketch, but sampling
    VarOptItemsSketch<Long> sketch = getUnweightedLongsVIS(kSmall, n1);
    // add a heavy item of weight n1 squared; written as n1 * n1 because '^' is bitwise XOR in Java
    sketch.update(-1L, n1 * n1);

    final VarOptItemsUnion<Long> union = VarOptItemsUnion.newInstance(kMax);
    union.update(sketch);

    // another one, but different n to get a different per-item weight
    sketch = getUnweightedLongsVIS(kSmall, n2);
    union.update(sketch);

    // should trigger migrateMarkedItemsByDecreasingK()
    final VarOptItemsSketch<Long> result = union.getResult();
    assertEquals(result.getN(), n1 + n2 + 1);
    assertEquals(result.getTotalWtR(), 96.0, EPS); // n1+n2 light items, ignore the heavy one
  }

  /**
   * Checks that, starting from identical unions, merging unit-weight items as a varopt sketch or
   * as an exact-mode reservoir sketch gives equivalent results. The test also checks that null
   * and empty reservoir inputs leave a union unchanged.
   */
  @Test
  public void unionExactReservoirSketch() {
    // build a varopt union which contains both heavy and light items, then copy it and
    // compare unioning:
    // 1. A varopt sketch of items with weight 1.0
    // 2. A reservoir sample made of the same input items as above
    // and we should find that the resulting unions are equivalent.
    final int k = 20;
    final long n = 2 * k;

    final VarOptItemsSketch<Long> baseVis = VarOptItemsSketch.newInstance(k);
    for (long i = 1; i <= n; ++i) {
      baseVis.update(-i, i);
    }
    baseVis.update(-n - 1L, n * n);
    baseVis.update(-n - 2L, n * n);
    baseVis.update(-n - 3L, n * n);

    final VarOptItemsUnion<Long> union1 = VarOptItemsUnion.newInstance(k);
    union1.update(baseVis);

    // union2 starts as a serialization round-trip copy of union1
    final ArrayOfLongsSerDe serDe = new ArrayOfLongsSerDe();
    final Memory unionImg = Memory.wrap(union1.toByteArray(serDe));
    final VarOptItemsUnion<Long> union2 = VarOptItemsUnion.heapify(unionImg, serDe);

    compareUnionsExact(union1, union2); // sanity check

    final VarOptItemsSketch<Long> vis = VarOptItemsSketch.newInstance(k);
    final ReservoirItemsSketch<Long> ris = ReservoirItemsSketch.newInstance(k);

    union2.update((ReservoirItemsSketch<Long>) null);
    union2.update(ris); // empty
    compareUnionsExact(union1, union2); // union2 should be unchanged

    // fewer than k items, so the reservoir stays in exact mode
    for (long i = 1; i < (k - 1); ++i) {
      ris.update(i);
      vis.update(i, 1.0);
    }

    union1.update(vis);
    union2.update(ris);
    compareUnionsEquivalent(union1, union2);
  }

  /**
   * Like {@link #unionExactReservoirSketch()}, but the reservoir and varopt inputs are merged
   * first and are in sampling mode. They are then merged again, and finally a sketch with heavy
   * items is added to both unions.
   */
  @Test
  public void unionSamplingReservoirSketch() {
    // Like unionExactReservoirSketch, but merge in reservoir first, with reservoir in sampling mode
    final int k = 20;
    final long n = k * k;

    final VarOptItemsUnion<Long> union1 = VarOptItemsUnion.newInstance(k);
    final VarOptItemsUnion<Long> union2 = VarOptItemsUnion.newInstance(k);
    compareUnionsExact(union1, union2); // sanity check

    final VarOptItemsSketch<Long> vis = VarOptItemsSketch.newInstance(k);
    final ReservoirItemsSketch<Long> ris = ReservoirItemsSketch.newInstance(k);
    for (long i = 1; i < n; ++i) {
      ris.update(i);
      vis.update(i, 1.0);
    }

    union1.update(vis);
    union2.update(ris);
    compareUnionsEquivalent(union1, union2);

    // repeat to trigger equal tau scenario
    union1.update(vis);
    union2.update(ris);
    compareUnionsEquivalent(union1, union2);

    // create and add a sketch with some heavy items
    final VarOptItemsSketch<Long> newVis = VarOptItemsSketch.newInstance(k);
    for (long i = 1; i <= n; ++i) {
      newVis.update(-i, i);
    }
    newVis.update(-n - 1L, n * n);
    newVis.update(-n - 2L, n * n);
    newVis.update(-n - 3L, n * n);

    union1.update(newVis);
    union2.update(newVis);
    compareUnionsEquivalent(union1, union2);
  }

  /**
   * Merging equivalent unit-weight varopt and reservoir inputs into two identical unions gives
   * equivalent results in two cases: when the reservoir's tau is larger than the gadget's, and
   * when it is smaller.
   */
  @Test
  public void unionReservoirVariousTauValues() {
    final int k = 20;
    final long n = 2 * k;

    final VarOptItemsSketch<Long> baseVis = VarOptItemsSketch.newInstance(k);
    for (long i = 1; i <= n; ++i) {
      baseVis.update(-i, 1.0);
    }

    final VarOptItemsUnion<Long> union1 = VarOptItemsUnion.newInstance(k);
    union1.update(baseVis);

    // union2 starts as a serialization round-trip copy of union1
    final ArrayOfLongsSerDe serDe = new ArrayOfLongsSerDe();
    final Memory unionImg = Memory.wrap(union1.toByteArray(serDe));
    final VarOptItemsUnion<Long> union2 = VarOptItemsUnion.heapify(unionImg, serDe);

    compareUnionsExact(union1, union2); // sanity check

    // reservoir tau will be greater than gadget's tau
    VarOptItemsSketch<Long> vis = VarOptItemsSketch.newInstance(k);
    ReservoirItemsSketch<Long> ris = ReservoirItemsSketch.newInstance(k);
    for (long i = 1; i < (2 * n); ++i) {
      ris.update(i);
      vis.update(i, 1.0);
    }

    union1.update(vis);
    union2.update(ris);
    compareUnionsEquivalent(union1, union2);

    // reservoir tau will be smaller than gadget's tau
    vis = VarOptItemsSketch.newInstance(k);
    ris = ReservoirItemsSketch.newInstance(k);
    for (long i = 1; i <= (k + 1); ++i) {
      ris.update(i);
      vis.update(i, 1.0);
    }

    union1.update(vis);
    union2.update(ris);
    compareUnionsEquivalent(union1, union2);
  }

  /**
   * Null inputs leave a union empty. The empty union then serializes to 8 bytes and heapifies
   * back to an equivalent empty union.
   */
  @Test
  public void serializeEmptyUnion() {
    final int k = 100;
    final ArrayOfStringsSerDe serDe = new ArrayOfStringsSerDe();
    final VarOptItemsUnion<String> union = VarOptItemsUnion.newInstance(k);

    // null inputs to update() should leave the union empty
    union.update((VarOptItemsSketch<String>) null);
    union.update(null, serDe);

    final byte[] bytes = union.toByteArray(serDe);
    assertEquals(bytes.length, 8);

    final Memory mem = Memory.wrap(bytes);
    final VarOptItemsUnion<String> rebuilt = VarOptItemsUnion.heapify(mem, serDe);

    final VarOptItemsSketch<String> sketch = rebuilt.getResult();
    assertEquals(sketch.getN(), 0);

    assertEquals(rebuilt.toString(), union.toString());
  }

  /** A union of two exact-mode sketches survives a serialize/heapify round trip unchanged. */
  @Test
  public void serializeExactUnion() {
    final int n1 = 32;
    final int n2 = 64;
    final int k = 128;
    final VarOptItemsSketch<Long> sketch1 = getUnweightedLongsVIS(k, n1);
    final VarOptItemsSketch<Long> sketch2 = getUnweightedLongsVIS(k, n2);

    final VarOptItemsUnion<Long> union = VarOptItemsUnion.newInstance(k);
    union.update(sketch1);
    union.update(sketch2);

    final ArrayOfLongsSerDe serDe = new ArrayOfLongsSerDe();
    final byte[] unionBytes = union.toByteArray(serDe);
    final Memory mem = Memory.wrap(unionBytes);

    final VarOptItemsUnion<Long> rebuilt = VarOptItemsUnion.heapify(mem, serDe);
    compareUnionsExact(rebuilt, union);
    assertEquals(rebuilt.toString(), union.toString());
  }

  /**
   * A union holding a sampling-mode sketch with added heavy items survives a serialize/heapify
   * round trip unchanged.
   */
  @Test
  public void serializeSamplingUnion() {
    final int n = 256;
    final int k = 128;
    final VarOptItemsSketch<Long> sketch = getUnweightedLongsVIS(k, n);

    // append 8 heavy items: keys n+1 .. n+8 with weights 1000.0 .. 1007.0
    for (long i = 1; i <= 8; ++i) {
      sketch.update(n + i, 999.0 + i);
    }

    final VarOptItemsUnion<Long> union = VarOptItemsUnion.newInstance(k);
    union.update(sketch);

    final ArrayOfLongsSerDe serDe = new ArrayOfLongsSerDe();
    final byte[] unionBytes = union.toByteArray(serDe);
    final Memory mem = Memory.wrap(unionBytes);

    final VarOptItemsUnion<Long> rebuilt = VarOptItemsUnion.heapify(mem, serDe);
    compareUnionsExact(rebuilt, union);
    assertEquals(rebuilt.toString(), union.toString());
  }

  /**
   * Asserts that two unions are identical. The checks cover outer tau and the results' N, H and
   * R region counts and sample count, plus full weight and item arrays compared element by
   * element.
   *
   * @param u1 first union
   * @param u2 second union
   * @param <T> item type
   */
  private static <T> void compareUnionsExact(final VarOptItemsUnion<T> u1,
                                             final VarOptItemsUnion<T> u2) {
    assertEquals(u1.getOuterTau(), u2.getOuterTau());

    final VarOptItemsSketch<T> sketch1 = u1.getResult();
    final VarOptItemsSketch<T> sketch2 = u2.getResult();
    assertEquals(sketch1.getN(), sketch2.getN());
    assertEquals(sketch1.getHRegionCount(), sketch2.getHRegionCount());
    assertEquals(sketch1.getRRegionCount(), sketch2.getRRegionCount());

    final VarOptItemsSamples<T> s1 = sketch1.getSketchSamples();
    final VarOptItemsSamples<T> s2 = sketch2.getSketchSamples();
    assertEquals(s1.getNumSamples(), s2.getNumSamples());
    assertEquals(s1.weights(), s2.weights());
    assertEquals(s1.items(), s2.items());
  }

  /**
   * Asserts that two unions are statistically equivalent. Outer tau, N, region counts, sample
   * count and the full weight arrays must match. Items are compared only for the first
   * {@code getHRegionCount()} samples, because sampled (R) items may differ between runs.
   * NOTE(review): this assumes H-region samples come first in {@code getSketchSamples()} order.
   *
   * @param u1 first union
   * @param u2 second union
   * @param <T> item type
   */
  private static <T> void compareUnionsEquivalent(final VarOptItemsUnion<T> u1,
                                                  final VarOptItemsUnion<T> u2) {
    assertEquals(u1.getOuterTau(), u2.getOuterTau());

    final VarOptItemsSketch<T> sketch1 = u1.getResult();
    final VarOptItemsSketch<T> sketch2 = u2.getResult();
    assertEquals(sketch1.getN(), sketch2.getN());
    assertEquals(sketch1.getHRegionCount(), sketch2.getHRegionCount());
    assertEquals(sketch1.getRRegionCount(), sketch2.getRRegionCount());

    final VarOptItemsSamples<T> s1 = sketch1.getSketchSamples();
    final VarOptItemsSamples<T> s2 = sketch2.getSketchSamples();
    assertEquals(s1.getNumSamples(), s2.getNumSamples());
    assertEquals(s1.weights(), s2.weights());

    // only compare exact items; others can differ as long as weights match
    for (int i = 0; i < sketch1.getHRegionCount(); ++i) {
      assertEquals(s1.items(i), s2.items(i));
    }
  }

  /**
   * Wrapper around System.out.println() allowing a simple way to disable logging in tests
   * @param msg The message to print
   */
  @SuppressWarnings("unused")
  private static void println(final String msg) {
    //System.out.println(msg);
  }
}