Merge pull request #52 from matthayes/master
Add more sampling tests
diff --git a/src/java/datafu/pig/bags/CountEach.java b/src/java/datafu/pig/bags/CountEach.java
index c346f10..2bf42e1 100644
--- a/src/java/datafu/pig/bags/CountEach.java
+++ b/src/java/datafu/pig/bags/CountEach.java
@@ -20,6 +20,7 @@
import java.util.Map;
import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.EvalFunc;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
@@ -53,7 +54,7 @@
* }
* </pre>
*/
-public class CountEach extends EvalFunc<DataBag> implements Accumulator<DataBag>
+public class CountEach extends AccumulatorEvalFunc<DataBag>
{
private boolean flatten = false;
private Map<Tuple, Integer> counts = new HashMap<Tuple, Integer>();
@@ -110,21 +111,6 @@
}
@Override
- public DataBag exec(Tuple input) throws IOException
- {
- try
- {
- accumulate(input);
-
- return getValue();
- }
- finally
- {
- cleanup();
- }
- }
-
- @Override
public Schema outputSchema(Schema input)
{
try {
diff --git a/src/java/datafu/pig/bags/Enumerate.java b/src/java/datafu/pig/bags/Enumerate.java
index 59a9246..9bbca9c 100644
--- a/src/java/datafu/pig/bags/Enumerate.java
+++ b/src/java/datafu/pig/bags/Enumerate.java
@@ -19,6 +19,7 @@
import java.io.IOException;
import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.DataType;
@@ -53,7 +54,7 @@
* }
* </pre>
*/
-public class Enumerate extends SimpleEvalFunc<DataBag> implements Accumulator<DataBag>
+public class Enumerate extends AccumulatorEvalFunc<DataBag>
{
private final int start;
@@ -72,15 +73,10 @@
cleanup();
}
- public DataBag call(DataBag inputBag) throws IOException
+ @Override
+ public void accumulate(Tuple arg0) throws IOException
{
- cleanup();
- outputBag = BagFactory.getInstance().newDefaultBag();
- enumerateBag(inputBag);
- return getValue();
- }
-
- public void enumerateBag(DataBag inputBag){
+ DataBag inputBag = (DataBag)arg0.get(0);
for (Tuple t : inputBag) {
Tuple t1 = TupleFactory.getInstance().newTuple(t.getAll());
t1.append(i);
@@ -94,13 +90,6 @@
count++;
}
}
-
- @Override
- public void accumulate(Tuple arg0) throws IOException
- {
- DataBag inputBag = (DataBag)arg0.get(0);
- enumerateBag(inputBag);
- }
@Override
public void cleanup()
diff --git a/src/java/datafu/pig/hash/MD5.java b/src/java/datafu/pig/hash/MD5.java
index 0f104c2..2bf6293 100644
--- a/src/java/datafu/pig/hash/MD5.java
+++ b/src/java/datafu/pig/hash/MD5.java
@@ -20,27 +20,56 @@
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
+import org.apache.commons.codec.binary.Base64;
+
import datafu.pig.util.SimpleEvalFunc;
/**
- * Computes the MD5 value of a string and outputs it in hex.
+ * Computes the MD5 value of a string and outputs it in hex (by default).
+ * A method can be provided to the constructor, which may be either 'hex' or 'base64'.
*/
public class MD5 extends SimpleEvalFunc<String>
{
- private final MessageDigest md5er;
-
- public MD5()
+ private final MessageDigest md5er;
+ private final boolean isBase64;
+
+ public MD5()
+ {
+ this("hex");
+ }
+
+ public MD5(String method)
+ {
+ if ("hex".equals(method))
{
- try {
- md5er = MessageDigest.getInstance("md5");
- }
- catch (NoSuchAlgorithmException e) {
- throw new RuntimeException(e);
- }
+ isBase64 = false;
}
-
- public String call(String val)
+ else if ("base64".equals(method))
{
- return new BigInteger(1, md5er.digest(val.getBytes())).toString(16);
+ isBase64 = true;
}
+ else
+ {
+ throw new IllegalArgumentException("Expected either hex or base64");
+ }
+
+ try {
+ md5er = MessageDigest.getInstance("md5");
+ }
+ catch (NoSuchAlgorithmException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public String call(String val)
+ {
+ if (isBase64)
+ {
+ return new String(Base64.encodeBase64(md5er.digest(val.getBytes()))); // NOTE(review): getBytes() and new String(byte[]) both use the platform default charset
+ }
+ else
+ {
+ return new BigInteger(1, md5er.digest(val.getBytes())).toString(16); // NOTE(review): toString(16) omits leading zeros, so the hex string can be shorter than 32 characters
+ }
+ }
}
diff --git a/src/java/datafu/pig/hash/MD5Base64.java b/src/java/datafu/pig/hash/MD5Base64.java
deleted file mode 100644
index 0ef4b44..0000000
--- a/src/java/datafu/pig/hash/MD5Base64.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright 2010 LinkedIn, Inc
- *
- * Licensed under the Apache License, Version 2.0 (the "License"); you may not
- * use this file except in compliance with the License. You may obtain a copy of
- * the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations under
- * the License.
- */
-
-package datafu.pig.hash;
-
-import java.security.MessageDigest;
-import java.security.NoSuchAlgorithmException;
-
-import org.apache.commons.codec.binary.Base64;
-
-import datafu.pig.util.SimpleEvalFunc;
-
-/**
- * Computes the MD5 value of a string and outputs it in Base64 encoding.
- */
-public class MD5Base64 extends SimpleEvalFunc<String>
-{
- private final MessageDigest md5er;
-
- public MD5Base64()
- {
- try {
- md5er = MessageDigest.getInstance("md5");
- }
- catch (NoSuchAlgorithmException e) {
- throw new RuntimeException(e);
- }
- }
-
- public String call(String val)
- {
- return new String(Base64.encodeBase64(md5er.digest(val.getBytes())));
- }
-}
diff --git a/src/java/datafu/pig/linkanalysis/PageRank.java b/src/java/datafu/pig/linkanalysis/PageRank.java
index dc2e35f..482d854 100644
--- a/src/java/datafu/pig/linkanalysis/PageRank.java
+++ b/src/java/datafu/pig/linkanalysis/PageRank.java
@@ -25,6 +25,7 @@
import java.util.Map;
import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.EvalFunc;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
@@ -135,7 +136,7 @@
* </pre>
* </p>
*/
-public class PageRank extends EvalFunc<DataBag> implements Accumulator<DataBag>
+public class PageRank extends AccumulatorEvalFunc<DataBag>
{
private final datafu.pig.linkanalysis.PageRankImpl graph = new datafu.pig.linkanalysis.PageRankImpl();
@@ -367,21 +368,6 @@
}
}
- @Override
- public DataBag exec(Tuple input) throws IOException
- {
- try
- {
- accumulate(input);
-
- return getValue();
- }
- finally
- {
- cleanup();
- }
- }
-
private ProgressIndicator getProgressIndicator()
{
return new ProgressIndicator()
diff --git a/src/java/datafu/pig/sampling/ReservoirSample.java b/src/java/datafu/pig/sampling/ReservoirSample.java
index b1dce9d..23e70bd 100644
--- a/src/java/datafu/pig/sampling/ReservoirSample.java
+++ b/src/java/datafu/pig/sampling/ReservoirSample.java
@@ -18,8 +18,8 @@
import java.io.IOException;
+import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.Algebraic;
-import org.apache.pig.AlgebraicEvalFunc;
import org.apache.pig.EvalFunc;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
@@ -37,7 +37,7 @@
* @author wvaughan
*
*/
-public class ReservoirSample extends AlgebraicEvalFunc<DataBag> implements Algebraic
+public class ReservoirSample extends AccumulatorEvalFunc<DataBag> implements Algebraic
{
Integer numSamples;
private Reservoir reservoir;
@@ -56,24 +56,42 @@
}
@Override
- public DataBag exec(Tuple input) throws IOException
- {
+ public void accumulate(Tuple input) throws IOException
+ {
DataBag samples = (DataBag) input.get(0);
- if (samples == null || samples.size() <= numSamples) {
- return samples;
- }
-
for (Tuple sample : samples) {
getReservoir().consider(new ScoredTuple(Math.random(), sample));
- }
-
+ }
+ }
+
+ @Override
+ public void cleanup()
+ {
+ this.reservoir = null;
+ }
+
+ @Override
+ public DataBag getValue()
+ {
DataBag output = BagFactory.getInstance().newDefaultBag();
for (ScoredTuple sample : getReservoir()) {
output.add(sample.getTuple());
}
-
return output;
}
+
+ @Override
+ public DataBag exec(Tuple input) throws IOException // non-accumulating entry point: short-circuits before delegating to super.exec
+ {
+ DataBag samples = (DataBag)input.get(0);
+ if (samples == null || samples.size() <= numSamples) { // null, or no larger than the requested sample size: return the bag unchanged, as the pre-refactor exec did
+ return samples;
+ }
+ else
+ {
+ return super.exec(input); // NOTE(review): presumably AccumulatorEvalFunc.exec runs accumulate/getValue/cleanup - confirm
+ }
+ }
@Override
public Schema outputSchema(Schema input) {
@@ -146,7 +164,11 @@
DataBag output = BagFactory.getInstance().newDefaultBag();
DataBag samples = (DataBag) input.get(0);
- if (samples == null || samples.size() <= numSamples) {
+ if (samples == null)
+ {
+ // do nothing
+ }
+ else if (samples.size() <= numSamples) {
// no need to construct a reservoir, so just emit intermediate tuples
for (Tuple sample : samples) {
// add the score on to the intermediate tuple
@@ -201,7 +223,7 @@
}
}
- DataBag output = BagFactory.getInstance().newDefaultBag();
+ DataBag output = BagFactory.getInstance().newDefaultBag();
for (ScoredTuple scoredTuple : getReservoir()) {
// add the score on to the intermediate tuple
output.add(scoredTuple.getIntermediateTuple(tupleFactory));
diff --git a/src/java/datafu/pig/sessions/SessionCount.java b/src/java/datafu/pig/sessions/SessionCount.java
index 671f3e7..80c4cf8 100644
--- a/src/java/datafu/pig/sessions/SessionCount.java
+++ b/src/java/datafu/pig/sessions/SessionCount.java
@@ -19,6 +19,7 @@
import java.io.IOException;
import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.EvalFunc;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.Tuple;
@@ -56,7 +57,7 @@
* </pre>
*
*/
-public class SessionCount extends EvalFunc<Long> implements Accumulator<Long>
+public class SessionCount extends AccumulatorEvalFunc<Long>
{
private final long millis;
private DateTime last_date;
@@ -99,14 +100,4 @@
this.last_date = null;
this.sum = 0;
}
-
- @Override
- public Long exec(Tuple input) throws IOException
- {
- accumulate(input);
- Long result = getValue();
- cleanup();
-
- return result;
- }
}
diff --git a/src/java/datafu/pig/sessions/Sessionize.java b/src/java/datafu/pig/sessions/Sessionize.java
index a9aeaae..d83c2f4 100644
--- a/src/java/datafu/pig/sessions/Sessionize.java
+++ b/src/java/datafu/pig/sessions/Sessionize.java
@@ -20,6 +20,7 @@
import java.util.UUID;
import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.EvalFunc;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
@@ -67,7 +68,7 @@
* </pre>
* </p>
*/
-public class Sessionize extends EvalFunc<DataBag> implements Accumulator<DataBag>
+public class Sessionize extends AccumulatorEvalFunc<DataBag>
{
private final long millis;
@@ -84,16 +85,6 @@
}
@Override
- public DataBag exec(Tuple input) throws IOException
- {
- accumulate(input);
- DataBag outputBag = getValue();
- cleanup();
-
- return outputBag;
- }
-
- @Override
public void accumulate(Tuple input) throws IOException
{
for (Tuple t : (DataBag) input.get(0)) {
diff --git a/src/java/datafu/pig/stats/StreamingQuantile.java b/src/java/datafu/pig/stats/StreamingQuantile.java
index 9f6a233..02eae00 100644
--- a/src/java/datafu/pig/stats/StreamingQuantile.java
+++ b/src/java/datafu/pig/stats/StreamingQuantile.java
@@ -21,6 +21,7 @@
import java.util.List;
import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.DataType;
import org.apache.pig.data.Tuple;
@@ -117,8 +118,8 @@
* @see StreamingMedian
* @see Quantile
*/
-public class StreamingQuantile extends SimpleEvalFunc<Tuple> implements Accumulator<Tuple> {
-
+public class StreamingQuantile extends AccumulatorEvalFunc<Tuple>
+{
private final int numQuantiles;
private final QuantileEstimator estimator;
private List<Double> quantiles;
@@ -242,14 +243,6 @@
return t;
}
- public Tuple call(DataBag b) throws IOException
- {
- accumulate(TupleFactory.getInstance().newTuple(b));
- Tuple ret = getValue();
- cleanup();
- return ret;
- }
-
@Override
public Schema outputSchema(Schema input)
{
diff --git a/test/pig/datafu/test/pig/hash/HashTests.java b/test/pig/datafu/test/pig/hash/HashTests.java
index bd4b4cf..2dc9de8 100644
--- a/test/pig/datafu/test/pig/hash/HashTests.java
+++ b/test/pig/datafu/test/pig/hash/HashTests.java
@@ -42,7 +42,7 @@
/**
register $JAR_PATH
- define MD5 datafu.pig.hash.MD5Base64();
+ define MD5 datafu.pig.hash.MD5('base64');
data_in = LOAD 'input' as (val:chararray);
diff --git a/test/pig/datafu/test/pig/sampling/SamplingTests.java b/test/pig/datafu/test/pig/sampling/SamplingTests.java
index 8231053..ab84832 100644
--- a/test/pig/datafu/test/pig/sampling/SamplingTests.java
+++ b/test/pig/datafu/test/pig/sampling/SamplingTests.java
@@ -1,9 +1,26 @@
package datafu.test.pig.sampling;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import junit.framework.Assert;
+
import org.adrianwalker.multilinestring.Multiline;
+import org.apache.pig.backend.executionengine.ExecException;
+import org.apache.pig.data.BagFactory;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.data.TupleFactory;
import org.apache.pig.pigunit.PigTest;
import org.testng.annotations.Test;
+import datafu.pig.sampling.ReservoirSample;
+import datafu.pig.sampling.SampleByKey;
+import datafu.pig.sampling.WeightedSample;
import datafu.test.pig.PigTests;
@@ -68,6 +85,41 @@
"({(a,100),(c,5),(b,1)})");
}
+ @Test
+ public void weightedSampleLimitExecTest() throws IOException // calls WeightedSample.exec directly on a 100-item bag with a limit of 10
+ {
+ WeightedSample sampler = new WeightedSample();
+
+ DataBag bag = BagFactory.getInstance().newDefaultBag();
+ for (int i=0; i<100; i++)
+ {
+ Tuple t = TupleFactory.getInstance().newTuple(2);
+ t.set(0, i);
+ t.set(1, 1); // score is equal for all
+ bag.add(t);
+ }
+
+ Tuple input = TupleFactory.getInstance().newTuple(3);
+ input.set(0, bag);
+ input.set(1, 1); // use index 1 for score
+ input.set(2, 10); // get 10 items
+
+ DataBag result = sampler.exec(input);
+
+ Assert.assertEquals(10, result.size()); // output size must equal the requested limit
+
+ // each sampled value must come from the input range (0..99) and appear only once
+ Set<Integer> found = new HashSet<Integer>();
+ for (Tuple t : result)
+ {
+ Integer i = (Integer)t.get(0);
+ System.out.println(i);
+ Assert.assertTrue(i>=0 && i<100);
+ Assert.assertFalse(String.format("Found duplicate of %d",i), found.contains(i));
+ found.add(i);
+ }
+ }
+
/**
register $JAR_PATH
@@ -171,6 +223,45 @@
}
+ @Test
+ public void sampleByKeyExecTest() throws Exception // calls SampleByKey.exec directly, 5 times for each of 10,000 integer keys
+ {
+ SampleByKey sampler = new SampleByKey("thesalt","0.10");
+
+ Map<Integer,Integer> valuesPerKey = new HashMap<Integer,Integer>(); // key -> number of calls for which exec returned true
+
+ // 10,000 keys total
+ for (int i=0; i<10000; i++)
+ {
+ // 5 tuples per key, each carrying only the key
+ for (int j=0; j<5; j++)
+ {
+ Tuple t = TupleFactory.getInstance().newTuple(1);
+ t.set(0, i);
+ if (sampler.exec(t))
+ {
+ if (valuesPerKey.containsKey(i))
+ {
+ valuesPerKey.put(i, valuesPerKey.get(i)+1);
+ }
+ else
+ {
+ valuesPerKey.put(i, 1);
+ }
+ }
+ }
+ }
+
+ // 10% sample of 10,000 keys, so expect roughly 1000 sampled keys (tolerance: fewer than 50 off)
+ Assert.assertTrue(Math.abs(1000-valuesPerKey.size()) < 50);
+
+ // a sampled key must have been accepted on all 5 calls, never on only some of them
+ for (Map.Entry<Integer, Integer> pair : valuesPerKey.entrySet())
+ {
+ Assert.assertEquals(5, pair.getValue().intValue());
+ }
+ }
+
/**
register $JAR_PATH
@@ -230,4 +321,103 @@
assertOutput(test, "sampled", "("+reservoirSize+")");
}
}
+
+ @Test
+ public void reservoirSampleExecTest() throws IOException // calls ReservoirSample.exec once with a single 100-item bag
+ {
+ ReservoirSample sampler = new ReservoirSample("10");
+
+ DataBag bag = BagFactory.getInstance().newDefaultBag();
+ for (int i=0; i<100; i++)
+ {
+ Tuple t = TupleFactory.getInstance().newTuple(1);
+ t.set(0, i);
+ bag.add(t);
+ }
+
+ Tuple input = TupleFactory.getInstance().newTuple(bag);
+
+ DataBag result = sampler.exec(input);
+
+ Assert.assertEquals(10, result.size()); // 100 inputs, sampler constructed with "10"
+
+ // each sampled value must come from the input range (0..99) and appear only once
+ Set<Integer> found = new HashSet<Integer>();
+ for (Tuple t : result)
+ {
+ Integer i = (Integer)t.get(0);
+ System.out.println(i);
+ Assert.assertTrue(i>=0 && i<100);
+ Assert.assertFalse(String.format("Found duplicate of %d",i), found.contains(i));
+ found.add(i);
+ }
+ }
+
+ @Test
+ public void reservoirSampleAccumulateTest() throws IOException // drives accumulate() 100 times with one-item bags, then reads getValue()
+ {
+ ReservoirSample sampler = new ReservoirSample("10");
+
+ for (int i=0; i<100; i++)
+ {
+ Tuple t = TupleFactory.getInstance().newTuple(1);
+ t.set(0, i);
+ DataBag bag = BagFactory.getInstance().newDefaultBag();
+ bag.add(t);
+ Tuple input = TupleFactory.getInstance().newTuple(bag);
+ sampler.accumulate(input);
+ }
+
+ DataBag result = sampler.getValue();
+
+ Assert.assertEquals(10, result.size()); // 100 inputs, sampler constructed with "10"
+
+ // each sampled value must come from the input range (0..99) and appear only once
+ Set<Integer> found = new HashSet<Integer>();
+ for (Tuple t : result)
+ {
+ Integer i = (Integer)t.get(0);
+ System.out.println(i);
+ Assert.assertTrue(i>=0 && i<100);
+ Assert.assertFalse(String.format("Found duplicate of %d",i), found.contains(i));
+ found.add(i);
+ }
+ }
+
+ @Test
+ public void reservoirSampleAlgebraicTest() throws IOException // chains the Initial, Intermediate and Final stages by hand
+ {
+ ReservoirSample.Initial initialSampler = new ReservoirSample.Initial("10");
+ ReservoirSample.Intermediate intermediateSampler = new ReservoirSample.Intermediate("10");
+ ReservoirSample.Final finalSampler = new ReservoirSample.Final("10");
+
+ DataBag bag = BagFactory.getInstance().newDefaultBag();
+ for (int i=0; i<100; i++)
+ {
+ Tuple t = TupleFactory.getInstance().newTuple(1);
+ t.set(0, i);
+ bag.add(t);
+ }
+
+ Tuple input = TupleFactory.getInstance().newTuple(bag);
+
+ Tuple intermediateTuple = initialSampler.exec(input); // Initial stage consumes the raw input bag
+ DataBag intermediateBag = BagFactory.getInstance().newDefaultBag(Arrays.asList(intermediateTuple)); // wrap the stage output in a single-element bag
+ intermediateTuple = intermediateSampler.exec(TupleFactory.getInstance().newTuple(intermediateBag)); // Intermediate stage consumes Initial's output
+ intermediateBag = BagFactory.getInstance().newDefaultBag(Arrays.asList(intermediateTuple));
+ DataBag result = finalSampler.exec(TupleFactory.getInstance().newTuple(intermediateBag)); // Final stage consumes Intermediate's output
+
+ Assert.assertEquals(10, result.size()); // 100 inputs, every stage constructed with "10"
+
+ // each sampled value must come from the input range (0..99) and appear only once
+ Set<Integer> found = new HashSet<Integer>();
+ for (Tuple t : result)
+ {
+ Integer i = (Integer)t.get(0);
+ System.out.println(i);
+ Assert.assertTrue(i>=0 && i<100);
+ Assert.assertFalse(String.format("Found duplicate of %d",i), found.contains(i));
+ found.add(i);
+ }
+ }
}