Merge pull request #52 from matthayes/master

Add more sampling tests
diff --git a/src/java/datafu/pig/bags/CountEach.java b/src/java/datafu/pig/bags/CountEach.java
index c346f10..2bf42e1 100644
--- a/src/java/datafu/pig/bags/CountEach.java
+++ b/src/java/datafu/pig/bags/CountEach.java
@@ -20,6 +20,7 @@
 import java.util.Map;
 
 import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
 import org.apache.pig.EvalFunc;
 import org.apache.pig.data.BagFactory;
 import org.apache.pig.data.DataBag;
@@ -53,7 +54,7 @@
  * } 
  * </pre>
  */
-public class CountEach extends EvalFunc<DataBag> implements Accumulator<DataBag>
+public class CountEach extends AccumulatorEvalFunc<DataBag>
 {
   private boolean flatten = false;
   private Map<Tuple, Integer> counts = new HashMap<Tuple, Integer>();
@@ -110,21 +111,6 @@
   }
   
   @Override
-  public DataBag exec(Tuple input) throws IOException
-  {
-    try
-    {
-      accumulate(input);
-      
-      return getValue();
-    }
-    finally
-    {
-      cleanup();
-    }
-  }
-  
-  @Override
   public Schema outputSchema(Schema input)
   {
     try {
diff --git a/src/java/datafu/pig/bags/Enumerate.java b/src/java/datafu/pig/bags/Enumerate.java
index 59a9246..9bbca9c 100644
--- a/src/java/datafu/pig/bags/Enumerate.java
+++ b/src/java/datafu/pig/bags/Enumerate.java
@@ -19,6 +19,7 @@
 import java.io.IOException;
 
 import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
 import org.apache.pig.data.BagFactory;
 import org.apache.pig.data.DataBag;
 import org.apache.pig.data.DataType;
@@ -53,7 +54,7 @@
  * }
  * </pre>
  */
-public class Enumerate extends SimpleEvalFunc<DataBag> implements Accumulator<DataBag>
+public class Enumerate extends AccumulatorEvalFunc<DataBag>
 {
   private final int start;
   
@@ -72,15 +73,10 @@
     cleanup();
   }
   
-  public DataBag call(DataBag inputBag) throws IOException
+  @Override
+  public void accumulate(Tuple arg0) throws IOException
   {
-    cleanup();
-    outputBag = BagFactory.getInstance().newDefaultBag();
-    enumerateBag(inputBag);
-    return getValue();
-  }
-  
-  public void enumerateBag(DataBag inputBag){
+    DataBag inputBag = (DataBag)arg0.get(0);
     for (Tuple t : inputBag) {
       Tuple t1 = TupleFactory.getInstance().newTuple(t.getAll());
       t1.append(i);
@@ -94,13 +90,6 @@
       count++;
     }
   }
-  
-  @Override
-  public void accumulate(Tuple arg0) throws IOException
-  {
-    DataBag inputBag = (DataBag)arg0.get(0);
-    enumerateBag(inputBag);
-  }
 
   @Override
   public void cleanup()
diff --git a/src/java/datafu/pig/hash/MD5.java b/src/java/datafu/pig/hash/MD5.java
index 0f104c2..2bf6293 100644
--- a/src/java/datafu/pig/hash/MD5.java
+++ b/src/java/datafu/pig/hash/MD5.java
@@ -20,27 +20,56 @@
 import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
 
+import org.apache.commons.codec.binary.Base64;
+
 import datafu.pig.util.SimpleEvalFunc;
 
 /**
- * Computes the MD5 value of a string and outputs it in hex. 
+ * Computes the MD5 value of a string and outputs it in hex (by default).
+ * A method can be provided to the constructor, which may be either 'hex' or 'base64'.
  */
 public class MD5 extends SimpleEvalFunc<String>
 {
-    private final MessageDigest md5er;
-
-    public MD5()
+  private final MessageDigest md5er;
+  private final boolean isBase64;
+  
+  public MD5()
+  {
+    this("hex");
+  }
+  
+  public MD5(String method)
+  {
+    if ("hex".equals(method))
     {
-        try {
-            md5er = MessageDigest.getInstance("md5");
-        }
-        catch (NoSuchAlgorithmException e) {
-            throw new RuntimeException(e);
-        }
+      isBase64 = false;
     }
-
-    public String call(String val)
+    else if ("base64".equals(method))
     {
-        return new BigInteger(1, md5er.digest(val.getBytes())).toString(16);
+      isBase64 = true;
     }
+    else
+    {
+      throw new IllegalArgumentException("Expected either hex or base64");
+    }
+    
+    try {
+      md5er = MessageDigest.getInstance("md5");
+    }
+    catch (NoSuchAlgorithmException e) {
+      throw new RuntimeException(e);
+    }
+  }
+  
+  /** Returns the MD5 digest of val, Base64-encoded when isBase64 is set, otherwise as hex. */
+  public String call(String val)
+  {
+    byte[] digest = md5er.digest(val.getBytes());
+    if (isBase64)
+    {
+      return new String(Base64.encodeBase64(digest));
+    }
+    // BigInteger renders the hex form, so leading zeros of the digest are not printed
+    return new BigInteger(1, digest).toString(16);
+  }
 }
diff --git a/src/java/datafu/pig/hash/MD5Base64.java b/src/java/datafu/pig/hash/MD5Base64.java
deleted file mode 100644
index 0ef4b44..0000000
--- a/src/java/datafu/pig/hash/MD5Base64.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright 2010 LinkedIn, Inc
- * 
- * Licensed under the Apache License, Version 2.0 (the "License"); you may not
- * use this file except in compliance with the License. You may obtain a copy of
- * the License at
- * 
- * http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations under
- * the License.
- */
- 
-package datafu.pig.hash;
-
-import java.security.MessageDigest;
-import java.security.NoSuchAlgorithmException;
-
-import org.apache.commons.codec.binary.Base64;
-
-import datafu.pig.util.SimpleEvalFunc;
-
-/**
- * Computes the MD5 value of a string and outputs it in Base64 encoding.
- */
-public class MD5Base64 extends SimpleEvalFunc<String>
-{
-    private final MessageDigest md5er;
-
-    public MD5Base64()
-    {
-        try {
-            md5er = MessageDigest.getInstance("md5");
-        }
-        catch (NoSuchAlgorithmException e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    public String call(String val)
-    {
-        return new String(Base64.encodeBase64(md5er.digest(val.getBytes())));
-    }
-}
diff --git a/src/java/datafu/pig/linkanalysis/PageRank.java b/src/java/datafu/pig/linkanalysis/PageRank.java
index dc2e35f..482d854 100644
--- a/src/java/datafu/pig/linkanalysis/PageRank.java
+++ b/src/java/datafu/pig/linkanalysis/PageRank.java
@@ -25,6 +25,7 @@
 import java.util.Map;
 
 import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
 import org.apache.pig.EvalFunc;
 import org.apache.pig.data.BagFactory;
 import org.apache.pig.data.DataBag;
@@ -135,7 +136,7 @@
  * </pre>
  * </p> 
  */
-public class PageRank extends EvalFunc<DataBag> implements Accumulator<DataBag>
+public class PageRank extends AccumulatorEvalFunc<DataBag>
 {
   private final datafu.pig.linkanalysis.PageRankImpl graph = new datafu.pig.linkanalysis.PageRankImpl();
 
@@ -367,21 +368,6 @@
     }
   }
 
-  @Override
-  public DataBag exec(Tuple input) throws IOException
-  {
-    try
-    {
-      accumulate(input);
-      
-      return getValue();
-    }
-    finally
-    {
-      cleanup();
-    }
-  }
-
   private ProgressIndicator getProgressIndicator()
   {
     return new ProgressIndicator()
diff --git a/src/java/datafu/pig/sampling/ReservoirSample.java b/src/java/datafu/pig/sampling/ReservoirSample.java
index b1dce9d..23e70bd 100644
--- a/src/java/datafu/pig/sampling/ReservoirSample.java
+++ b/src/java/datafu/pig/sampling/ReservoirSample.java
@@ -18,8 +18,8 @@
 
 import java.io.IOException;
 
+import org.apache.pig.AccumulatorEvalFunc;
 import org.apache.pig.Algebraic;
-import org.apache.pig.AlgebraicEvalFunc;
 import org.apache.pig.EvalFunc;
 import org.apache.pig.data.BagFactory;
 import org.apache.pig.data.DataBag;
@@ -37,7 +37,7 @@
  * @author wvaughan
  *
  */
-public class ReservoirSample extends AlgebraicEvalFunc<DataBag> implements Algebraic
+public class ReservoirSample extends AccumulatorEvalFunc<DataBag> implements Algebraic
 {
   Integer numSamples;
   private Reservoir reservoir;
@@ -56,24 +56,42 @@
   }
 
   @Override
-  public DataBag exec(Tuple input) throws IOException 
-  {    
+  public void accumulate(Tuple input) throws IOException
+  {
     DataBag samples = (DataBag) input.get(0);
-    if (samples == null || samples.size() <= numSamples) {
-      return samples;
-    }
-    
     for (Tuple sample : samples) {
       getReservoir().consider(new ScoredTuple(Math.random(), sample));
-    }    
-    
+    }  
+  }
+
+  @Override
+  public void cleanup()
+  {
+    this.reservoir = null; // discard the current reservoir; assumes getReservoir() builds a new one on demand — TODO confirm
+  }
+
+  @Override
+  public DataBag getValue()
+  {
     DataBag output = BagFactory.getInstance().newDefaultBag();  
     for (ScoredTuple sample : getReservoir()) {
       output.add(sample.getTuple());
     }
-
     return output;
   }
+
+  /**
+   * Returns a null bag, or one with at most numSamples tuples, unchanged; otherwise samples via super.exec.
+   */
+  @Override
+  public DataBag exec(Tuple input) throws IOException
+  {
+    DataBag samples = (DataBag) input.get(0);
+    if (samples == null || samples.size() <= numSamples) {
+      return samples; // nothing to sample (null) or the whole bag already fits
+    }
+    return super.exec(input);
+  }
   
   @Override
   public Schema outputSchema(Schema input) {
@@ -146,7 +164,11 @@
       DataBag output = BagFactory.getInstance().newDefaultBag();
       
       DataBag samples = (DataBag) input.get(0);
-      if (samples == null || samples.size() <= numSamples) {
+      if (samples == null)
+      {
+        // do nothing
+      }
+      else if (samples.size() <= numSamples) {
         // no need to construct a reservoir, so just emit intermediate tuples
         for (Tuple sample : samples) {
           // add the score on to the intermediate tuple
@@ -201,7 +223,7 @@
         }
       }
       
-      DataBag output = BagFactory.getInstance().newDefaultBag();  
+      DataBag output = BagFactory.getInstance().newDefaultBag();
       for (ScoredTuple scoredTuple : getReservoir()) {
         // add the score on to the intermediate tuple
         output.add(scoredTuple.getIntermediateTuple(tupleFactory));
diff --git a/src/java/datafu/pig/sessions/SessionCount.java b/src/java/datafu/pig/sessions/SessionCount.java
index 671f3e7..80c4cf8 100644
--- a/src/java/datafu/pig/sessions/SessionCount.java
+++ b/src/java/datafu/pig/sessions/SessionCount.java
@@ -19,6 +19,7 @@
 import java.io.IOException;
 
 import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
 import org.apache.pig.EvalFunc;
 import org.apache.pig.data.DataBag;
 import org.apache.pig.data.Tuple;
@@ -56,7 +57,7 @@
  * </pre>
  * 
  */
-public class SessionCount extends EvalFunc<Long> implements Accumulator<Long>
+public class SessionCount extends AccumulatorEvalFunc<Long>
 {
   private final long millis;
   private DateTime last_date;
@@ -99,14 +100,4 @@
     this.last_date = null;
     this.sum = 0;
   }
-
-  @Override
-  public Long exec(Tuple input) throws IOException
-  {
-    accumulate(input);
-    Long result = getValue();
-    cleanup();
-
-    return result;
-  }
 }
diff --git a/src/java/datafu/pig/sessions/Sessionize.java b/src/java/datafu/pig/sessions/Sessionize.java
index a9aeaae..d83c2f4 100644
--- a/src/java/datafu/pig/sessions/Sessionize.java
+++ b/src/java/datafu/pig/sessions/Sessionize.java
@@ -20,6 +20,7 @@
 import java.util.UUID;
 
 import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
 import org.apache.pig.EvalFunc;
 import org.apache.pig.data.BagFactory;
 import org.apache.pig.data.DataBag;
@@ -67,7 +68,7 @@
  * </pre>
  * </p>
  */
-public class Sessionize extends EvalFunc<DataBag> implements Accumulator<DataBag>
+public class Sessionize extends AccumulatorEvalFunc<DataBag>
 {
   private final long millis;
 
@@ -84,16 +85,6 @@
   }
 
   @Override
-  public DataBag exec(Tuple input) throws IOException
-  {
-    accumulate(input);
-    DataBag outputBag = getValue();
-    cleanup();
-
-    return outputBag;
-  }
-
-  @Override
   public void accumulate(Tuple input) throws IOException
   {
     for (Tuple t : (DataBag) input.get(0)) {
diff --git a/src/java/datafu/pig/stats/StreamingQuantile.java b/src/java/datafu/pig/stats/StreamingQuantile.java
index 9f6a233..02eae00 100644
--- a/src/java/datafu/pig/stats/StreamingQuantile.java
+++ b/src/java/datafu/pig/stats/StreamingQuantile.java
@@ -21,6 +21,7 @@
 import java.util.List;
 
 import org.apache.pig.Accumulator;
+import org.apache.pig.AccumulatorEvalFunc;
 import org.apache.pig.data.DataBag;
 import org.apache.pig.data.DataType;
 import org.apache.pig.data.Tuple;
@@ -117,8 +118,8 @@
  * @see StreamingMedian
  * @see Quantile
  */
-public class StreamingQuantile extends SimpleEvalFunc<Tuple> implements Accumulator<Tuple> {
-
+public class StreamingQuantile extends AccumulatorEvalFunc<Tuple>
+{
   private final int numQuantiles;
   private final QuantileEstimator estimator;
   private List<Double> quantiles;
@@ -242,14 +243,6 @@
     return t;
   }
 
-  public Tuple call(DataBag b) throws IOException
-  {
-    accumulate(TupleFactory.getInstance().newTuple(b));
-    Tuple ret = getValue();
-    cleanup();
-    return ret;
-  }
-
   @Override
   public Schema outputSchema(Schema input)
   {
diff --git a/test/pig/datafu/test/pig/hash/HashTests.java b/test/pig/datafu/test/pig/hash/HashTests.java
index bd4b4cf..2dc9de8 100644
--- a/test/pig/datafu/test/pig/hash/HashTests.java
+++ b/test/pig/datafu/test/pig/hash/HashTests.java
@@ -42,7 +42,7 @@
   /**
   register $JAR_PATH
 
-  define MD5 datafu.pig.hash.MD5Base64();
+  define MD5 datafu.pig.hash.MD5('base64');
   
   data_in = LOAD 'input' as (val:chararray);
   
diff --git a/test/pig/datafu/test/pig/sampling/SamplingTests.java b/test/pig/datafu/test/pig/sampling/SamplingTests.java
index 8231053..ab84832 100644
--- a/test/pig/datafu/test/pig/sampling/SamplingTests.java
+++ b/test/pig/datafu/test/pig/sampling/SamplingTests.java
@@ -1,9 +1,26 @@
 package datafu.test.pig.sampling;
 
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import junit.framework.Assert;
+
 import org.adrianwalker.multilinestring.Multiline;
+import org.apache.pig.backend.executionengine.ExecException;
+import org.apache.pig.data.BagFactory;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.data.TupleFactory;
 import org.apache.pig.pigunit.PigTest;
 import org.testng.annotations.Test;
 
+import datafu.pig.sampling.ReservoirSample;
+import datafu.pig.sampling.SampleByKey;
+import datafu.pig.sampling.WeightedSample;
 import datafu.test.pig.PigTests;
 
 
@@ -68,6 +85,41 @@
         "({(a,100),(c,5),(b,1)})");
   }
   
+  /**
+   * Runs WeightedSample.exec on 100 equally weighted tuples; expects 10 distinct in-range picks.
+   */
+  @Test
+  public void weightedSampleLimitExecTest() throws IOException
+  {
+    WeightedSample sampler = new WeightedSample();
+
+    DataBag bag = BagFactory.getInstance().newDefaultBag();
+    for (int i=0; i<100; i++)
+    {
+      Tuple t = TupleFactory.getInstance().newTuple(2);
+      t.set(0, i);
+      t.set(1, 1); // score is equal for all
+      bag.add(t);
+    }
+
+    Tuple input = TupleFactory.getInstance().newTuple(3);
+    input.set(0, bag);
+    input.set(1, 1); // use index 1 for score
+    input.set(2, 10); // get 10 items
+
+    DataBag result = sampler.exec(input);
+    Assert.assertEquals(10, result.size());
+
+    // every sampled value must come from the input range and appear only once
+    Set<Integer> found = new HashSet<Integer>();
+    for (Tuple t : result)
+    {
+      Integer i = (Integer)t.get(0);
+      Assert.assertTrue(String.format("Value %d out of range",i), i>=0 && i<100);
+      Assert.assertTrue(String.format("Found duplicate of %d",i), found.add(i));
+    }
+  }
+  
   /**
   register $JAR_PATH
   
@@ -171,6 +223,45 @@
                    
   }
   
+  /**
+   * Feeds 10,000 keys, five tuples each, through SampleByKey.exec with a 0.10
+   * sampling setting. Checks that roughly 1,000 keys are accepted and that every
+   * accepted key had all five of its tuples accepted.
+   */
+  @Test
+  public void sampleByKeyExecTest() throws Exception
+  {
+    final int totalKeys = 10000;
+    final int tuplesPerKey = 5;
+    SampleByKey sampler = new SampleByKey("thesalt","0.10");
+    Map<Integer,Integer> keptPerKey = new HashMap<Integer,Integer>();
+
+    for (int key=0; key<totalKeys; key++)
+    {
+      for (int copy=0; copy<tuplesPerKey; copy++)
+      {
+        Tuple t = TupleFactory.getInstance().newTuple(1);
+        t.set(0, key);
+        if (!sampler.exec(t))
+        {
+          continue;
+        }
+        // tally how many tuples of this key were accepted
+        Integer seen = keptPerKey.get(key);
+        keptPerKey.put(key, seen == null ? 1 : seen+1);
+      }
+    }
+
+    // 10% sample, so the number of accepted keys should be close to 1000
+    Assert.assertTrue(Math.abs(1000-keptPerKey.size()) < 50);
+
+    // every accepted key must have had all five of its tuples accepted
+    for (Integer kept : keptPerKey.values())
+    {
+      Assert.assertEquals(5, kept.intValue());
+    }
+  }
+  
   /**
   register $JAR_PATH
 
@@ -230,4 +321,103 @@
       assertOutput(test, "sampled", "("+reservoirSize+")");
     }
   }
+  
+  /**
+   * Calls ReservoirSample.exec, configured for 10 samples, on a bag of 100 tuples;
+   * expects 10 distinct values, all taken from the input range.
+   */
+  @Test
+  public void reservoirSampleExecTest() throws IOException
+  {
+    ReservoirSample sampler = new ReservoirSample("10");
+
+    DataBag bag = BagFactory.getInstance().newDefaultBag();
+    for (int i=0; i<100; i++)
+    {
+      Tuple t = TupleFactory.getInstance().newTuple(1);
+      t.set(0, i);
+      bag.add(t);
+    }
+
+    Tuple input = TupleFactory.getInstance().newTuple(bag);
+    DataBag result = sampler.exec(input);
+    Assert.assertEquals(10, result.size());
+
+    // every sampled value must come from the input range and appear only once
+    Set<Integer> found = new HashSet<Integer>();
+    for (Tuple t : result)
+    {
+      Integer i = (Integer)t.get(0);
+      Assert.assertTrue(String.format("Value %d out of range",i), i>=0 && i<100);
+      Assert.assertTrue(String.format("Found duplicate of %d",i), found.add(i));
+    }
+  }
+  
+  /**
+   * Feeds 100 single-tuple bags to ReservoirSample.accumulate, one call per bag,
+   * then expects getValue to return 10 distinct values taken from the input range.
+   */
+  @Test
+  public void reservoirSampleAccumulateTest() throws IOException
+  {
+    ReservoirSample sampler = new ReservoirSample("10");
+
+    for (int i=0; i<100; i++)
+    {
+      Tuple t = TupleFactory.getInstance().newTuple(1);
+      t.set(0, i);
+      DataBag bag = BagFactory.getInstance().newDefaultBag();
+      bag.add(t);
+      sampler.accumulate(TupleFactory.getInstance().newTuple(bag));
+    }
+
+    DataBag result = sampler.getValue();
+    Assert.assertEquals(10, result.size());
+
+    // every sampled value must come from the input range and appear only once
+    Set<Integer> found = new HashSet<Integer>();
+    for (Tuple t : result)
+    {
+      Integer i = (Integer)t.get(0);
+      Assert.assertTrue(String.format("Value %d out of range",i), i>=0 && i<100);
+      Assert.assertTrue(String.format("Found duplicate of %d",i), found.add(i));
+    }
+  }
+  
+  /**
+   * Chains the algebraic stages by hand: Initial, then Intermediate, then Final,
+   * each fed the previous stage's output tuple wrapped in a bag. Expects 10
+   * distinct values taken from the 100 input tuples.
+   */
+  @Test
+  public void reservoirSampleAlgebraicTest() throws IOException
+  {
+    ReservoirSample.Initial initialSampler = new ReservoirSample.Initial("10");
+    ReservoirSample.Intermediate intermediateSampler = new ReservoirSample.Intermediate("10");
+    ReservoirSample.Final finalSampler = new ReservoirSample.Final("10");
+
+    DataBag bag = BagFactory.getInstance().newDefaultBag();
+    for (int i=0; i<100; i++)
+    {
+      Tuple t = TupleFactory.getInstance().newTuple(1);
+      t.set(0, i);
+      bag.add(t);
+    }
+
+    Tuple intermediateTuple = initialSampler.exec(TupleFactory.getInstance().newTuple(bag));
+    DataBag intermediateBag = BagFactory.getInstance().newDefaultBag(Arrays.asList(intermediateTuple));
+    intermediateTuple = intermediateSampler.exec(TupleFactory.getInstance().newTuple(intermediateBag));
+    intermediateBag = BagFactory.getInstance().newDefaultBag(Arrays.asList(intermediateTuple));
+    DataBag result = finalSampler.exec(TupleFactory.getInstance().newTuple(intermediateBag));
+    Assert.assertEquals(10, result.size());
+
+    // every sampled value must come from the input range and appear only once
+    Set<Integer> found = new HashSet<Integer>();
+    for (Tuple t : result)
+    {
+      Integer i = (Integer)t.get(0);
+      Assert.assertTrue(String.format("Value %d out of range",i), i>=0 && i<100);
+      Assert.assertTrue(String.format("Found duplicate of %d",i), found.add(i));
+    }
+  }
 }