add SimpleRandomSample to sampling, which implements a scalable simple random sampling algorithm
diff --git a/src/java/datafu/pig/sampling/SimpleRandomSample.java b/src/java/datafu/pig/sampling/SimpleRandomSample.java
new file mode 100644
index 0000000..169e0ab
--- /dev/null
+++ b/src/java/datafu/pig/sampling/SimpleRandomSample.java
@@ -0,0 +1,404 @@
+package datafu.pig.sampling;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.Random;
+
+import org.apache.pig.AlgebraicEvalFunc;
+import org.apache.pig.EvalFunc;
+import org.apache.pig.data.BagFactory;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.data.DataType;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.data.TupleFactory;
+import org.apache.pig.impl.logicalLayer.FrontendException;
+import org.apache.pig.impl.logicalLayer.schema.Schema;
+
+/**
+ * Scalable simple random sampling.
+ * <p/>
+ * This UDF implements a scalable simple random sampling algorithm described in
+ *
+ * <pre>
+ * X. Meng, Scalable Simple Random Sampling and Stratified Sampling, ICML 2013.
+ * </pre>
+ *
+ * It takes a sampling probability p as input and outputs a simple random sample of size
+ * exactly ceil(p*n) with probability at least 99.99%, where n is the size of the
+ * population. This UDF is very useful for stratified sampling. For example,
+ *
+ * <pre>
+ * DEFINE SRS datafu.pig.sampling.SimpleRandomSample('0.01');
+ * examples = LOAD ...
+ * grouped = GROUP examples BY label;
+ * sampled = FOREACH grouped GENERATE FLATTEN(SRS(examples));
+ * STORE sampled ...
+ * </pre>
+ *
+ * We note that, in a Java Hadoop job, we can output pre-selected records directly using
+ * MultipleOutputs. However, this feature is not available in a Pig UDF. So we still let
+ * pre-selected records go through the sort phase. However, as long as the sample size is
+ * not huge, this should not be a big problem.
+ * <p/>
+ * Every item receives a uniform random key. Items whose key is below a lower threshold
+ * are selected right away, items whose key lies between the lower and an upper threshold
+ * are wait-listed together with their key, and the remaining items are dropped. The
+ * final stage tops the selection up from the wait-list, smallest key first, until
+ * ceil(p*n) items have been collected.
+ *
+ * @author ximeng
+ *
+ */
+public class SimpleRandomSample extends AlgebraicEvalFunc<DataBag>
+{
+  public static final TupleFactory tupleFactory = TupleFactory.getInstance();
+  public static final BagFactory bagFactory = BagFactory.getInstance();
+
+  /**
+   * Creates the UDF without constructor arguments.
+   */
+  public SimpleRandomSample()
+  {
+  }
+
+  /**
+   * Creates the UDF with the given sampling probability.
+   *
+   * @param samplingProbability decimal string holding a probability in [0, 1]
+   * @throws NumberFormatException if the string cannot be parsed as a double
+   * @throws IllegalArgumentException if the parsed value is NaN or outside [0, 1]
+   */
+  public SimpleRandomSample(String samplingProbability)
+  {
+    // AlgebraicEvalFunc requires subclasses to pass their constructor arguments up; it
+    // reuses them whenever it has to instantiate the stage classes itself.
+    super(samplingProbability);
+
+    double p = Double.parseDouble(samplingProbability);
+
+    // Negated range test so that NaN, which fails every comparison, is rejected too.
+    if (!(p >= 0.0 && p <= 1.0))
+    {
+      throw new IllegalArgumentException("Sampling probability must be inside [0, 1].");
+    }
+  }
+
+  @Override
+  public String getInitial()
+  {
+    return Initial.class.getName();
+  }
+
+  @Override
+  public String getIntermed()
+  {
+    return Intermediate.class.getName();
+  }
+
+  @Override
+  public String getFinal()
+  {
+    return Final.class.getName();
+  }
+
+  /**
+   * Declares the output as a bag carrying the same tuple schema as the input bag.
+   *
+   * @param input schema of the UDF arguments; its first field must be a bag
+   * @return schema consisting of a single bag field
+   * @throws RuntimeException if the first field is not a bag or cannot be resolved
+   */
+  @Override
+  public Schema outputSchema(Schema input)
+  {
+    try
+    {
+      Schema.FieldSchema inputFieldSchema = input.getField(0);
+
+      if (inputFieldSchema.type != DataType.BAG)
+      {
+        throw new RuntimeException("Expected a BAG as input");
+      }
+
+      return new Schema(new Schema.FieldSchema(getSchemaName(this.getClass()
+                                                                 .getName()
+                                                                 .toLowerCase(), input),
+                                               inputFieldSchema.schema,
+                                               DataType.BAG));
+    }
+    catch (FrontendException e)
+    {
+      // The cause is preserved in the rethrown exception, so nothing is printed here.
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * First stage. Draws a random key for every input item and emits the tuple
+   * (item count, selected items, wait-listed scored items).
+   */
+  public static class Initial extends EvalFunc<Tuple>
+  {
+    private double _samplingProbability;
+
+    // Default seeding; folding nanoTime() into a small modulus would leave fewer than a
+    // million possible key streams, which parallel tasks could easily share.
+    private final Random _random = new Random();
+
+    public Initial()
+    {
+    }
+
+    public Initial(String samplingProbability)
+    {
+      _samplingProbability = Double.parseDouble(samplingProbability);
+    }
+
+    @Override
+    public Tuple exec(Tuple input) throws IOException
+    {
+      DataBag selected = bagFactory.newDefaultBag();
+      DataBag waiting = bagFactory.newSortedBag(new ScoredTupleComparator());
+      long n = 0L;
+
+      DataBag items = (DataBag) input.get(0);
+
+      if (items != null)
+      {
+        n = items.size();
+
+        double q1 = getQ1(n, _samplingProbability);
+        double q2 = getQ2(n, _samplingProbability);
+
+        for (Tuple item : items)
+        {
+          double key = _random.nextDouble();
+
+          if (key < q1)
+          {
+            selected.add(item);
+          }
+          else if (key < q2)
+          {
+            // Keep the key with the item; later stages re-test it against tighter bounds.
+            waiting.add(new ScoredTuple(key, item).getIntermediateTuple(tupleFactory));
+          }
+        }
+      }
+
+      // All three fields are emitted even for a null bag, because Intermediate and Final
+      // read positions 0 to 2 of every partial result unconditionally.
+      Tuple output = tupleFactory.newTuple();
+      output.append(n);
+      output.append(selected);
+      output.append(waiting);
+
+      return output;
+    }
+  }
+
+  /**
+   * Middle stage. Merges partial results, re-tests the wait-listed items against the
+   * thresholds for the merged item count and emits a tuple of the same layout as
+   * {@link Initial}.
+   */
+  public static class Intermediate extends EvalFunc<Tuple>
+  {
+    private double _samplingProbability;
+
+    public Intermediate()
+    {
+    }
+
+    public Intermediate(String samplingProbability)
+    {
+      _samplingProbability = Double.parseDouble(samplingProbability);
+    }
+
+    @Override
+    public Tuple exec(Tuple input) throws IOException
+    {
+      DataBag bag = (DataBag) input.get(0);
+      DataBag selected = bagFactory.newDefaultBag();
+      DataBag aggWaiting = bagFactory.newSortedBag(new ScoredTupleComparator());
+      DataBag waiting = bagFactory.newSortedBag(new ScoredTupleComparator());
+
+      long n = 0L;
+
+      for (Tuple innerTuple : bag)
+      {
+        n += (Long) innerTuple.get(0);
+
+        selected.addAll((DataBag) innerTuple.get(1));
+
+        // q1 grows and q2 shrinks as n grows, so thresholds taken from the running count
+        // never select or drop an item that the thresholds for the full count would not.
+        reclassify((DataBag) innerTuple.get(2),
+                   getQ1(n, _samplingProbability),
+                   getQ2(n, _samplingProbability),
+                   selected,
+                   aggWaiting);
+      }
+
+      // Second pass with the thresholds for the complete count seen by this call.
+      reclassify(aggWaiting,
+                 getQ1(n, _samplingProbability),
+                 getQ2(n, _samplingProbability),
+                 selected,
+                 waiting);
+
+      Tuple output = tupleFactory.newTuple();
+      output.append(n);
+      output.append(selected);
+      output.append(waiting);
+
+      // Report the wait-list that is actually emitted, not the pre-filter aggregate.
+      System.err.println("Read " + n + " items, selected " + selected.size()
+          + ", and wait-listed " + waiting.size() + ".");
+
+      return output;
+    }
+  }
+
+  /**
+   * Last stage. Merges all partial results and fills the selection from the wait-list,
+   * smallest score first, until ceil(p*n) items are selected or the wait-list runs out.
+   */
+  public static class Final extends EvalFunc<DataBag>
+  {
+    private double _samplingProbability;
+
+    public Final()
+    {
+    }
+
+    public Final(String samplingProbability)
+    {
+      _samplingProbability = Double.parseDouble(samplingProbability);
+    }
+
+    @Override
+    public DataBag exec(Tuple input) throws IOException
+    {
+      DataBag bag = (DataBag) input.get(0);
+      long n = 0L;
+      DataBag selected = bagFactory.newDefaultBag();
+      DataBag waiting = bagFactory.newSortedBag(new ScoredTupleComparator());
+
+      for (Tuple innerTuple : bag)
+      {
+        n += (Long) innerTuple.get(0);
+        selected.addAll((DataBag) innerTuple.get(1));
+        waiting.addAll((DataBag) innerTuple.get(2));
+      }
+
+      long sampleSize = (long) Math.ceil(_samplingProbability * n);
+      long nNeeded = sampleSize - selected.size();
+
+      for (Tuple scored : waiting)
+      {
+        if (nNeeded <= 0)
+        {
+          break;
+        }
+        selected.add(ScoredTuple.fromIntermediateTuple(scored).getTuple());
+        nNeeded--;
+      }
+
+      return selected;
+    }
+  }
+
+  /**
+   * Orders intermediate scored tuples by ascending score.
+   */
+  private static class ScoredTupleComparator implements Comparator<Tuple>
+  {
+
+    @Override
+    public int compare(Tuple o1, Tuple o2)
+    {
+      try
+      {
+        ScoredTuple t1 = ScoredTuple.fromIntermediateTuple(o1);
+        ScoredTuple t2 = ScoredTuple.fromIntermediateTuple(o2);
+        return t1.getScore().compareTo(t2.getScore());
+      }
+      catch (Exception e)
+      {
+        // Exception rather than Throwable: JVM errors must propagate unwrapped.
+        throw new RuntimeException("Cannot compare " + o1 + " and " + o2 + ".", e);
+      }
+    }
+  }
+
+  /**
+   * Moves scored tuples into the selected bag (score below q1) or the wait-list (score
+   * below q2) and stops at the first score that reaches q2.
+   * <p/>
+   * NOTE(review): the early stop assumes ascending score order; the bags passed here are
+   * created as sorted bags, but confirm that the order survives serialization.
+   *
+   * @param candidates intermediate scored tuples
+   * @param q1 scores below this bound are selected
+   * @param q2 scores below this bound are wait-listed
+   * @param selected receives the plain tuples of selected items
+   * @param waiting receives the still-scored tuples of wait-listed items
+   * @throws IOException if a scored tuple cannot be decoded
+   */
+  private static void reclassify(Iterable<Tuple> candidates,
+                                 double q1,
+                                 double q2,
+                                 DataBag selected,
+                                 DataBag waiting) throws IOException
+  {
+    for (Tuple t : candidates)
+    {
+      ScoredTuple scored = ScoredTuple.fromIntermediateTuple(t);
+
+      if (scored.getScore() < q1)
+      {
+        selected.add(scored.getTuple());
+      }
+      else if (scored.getScore() < q2)
+      {
+        waiting.add(t);
+      }
+      else
+      {
+        break;
+      }
+    }
+  }
+
+  /**
+   * Lower threshold: an item whose random key is below this value is selected outright.
+   * For n = 0 the result is NaN, which no key compares below.
+   *
+   * @param n number of items seen so far
+   * @param p sampling probability
+   * @return p + t - sqrt(t^2 + 3tp) with t = 20 / (3n)
+   */
+  private static double getQ1(long n, double p)
+  {
+    // NOTE(review): 20/3 and 10 look like the failure-probability terms of the cited
+    // paper, matching the 99.99% stated in the class comment -- confirm against the paper.
+    double t1 = 20.0 / (3.0 * n);
+    double q1 = p + t1 - Math.sqrt(t1 * t1 + 3.0 * t1 * p);
+    return q1;
+  }
+
+  /**
+   * Upper threshold: an item whose random key is at or above this value is dropped.
+   *
+   * @param n number of items seen so far
+   * @param p sampling probability
+   * @return p + t + sqrt(t^2 + 2tp) with t = 10 / n
+   */
+  private static double getQ2(long n, double p)
+  {
+    double t2 = 10.0 / n;
+    double q2 = p + t2 + Math.sqrt(t2 * t2 + 2.0 * t2 * p);
+    return q2;
+  }
+}
diff --git a/test/pig/datafu/test/pig/sampling/SimpleRandomSampleTest.java b/test/pig/datafu/test/pig/sampling/SimpleRandomSampleTest.java
new file mode 100644
index 0000000..b20d89d
--- /dev/null
+++ b/test/pig/datafu/test/pig/sampling/SimpleRandomSampleTest.java
@@ -0,0 +1,135 @@
+package datafu.test.pig.sampling;
+
+import org.adrianwalker.multilinestring.Multiline;
+import org.apache.pig.pigunit.PigTest;
+import org.testng.annotations.Test;
+
+import datafu.pig.sampling.SimpleRandomSample;
+import datafu.test.pig.PigTests;
+
+/**
+ * Tests for {@link SimpleRandomSample}.
+ * 
+ * @author ximeng
+ *
+ */
+public class SimpleRandomSampleTest extends PigTests
+{
+  /** Tab-separated rows (A_id, B_id, C) that both tests write to 'input'. */
+  private static final String[] INPUT_ROWS = {
+    "A1\tB1\t1",
+    "A1\tB1\t4",
+    "A1\tB3\t4",
+    "A1\tB4\t4",
+    "A2\tB1\t4",
+    "A2\tB2\t4",
+    "A3\tB1\t3",
+    "A3\tB1\t1",
+    "A3\tB3\t77",
+    "A4\tB1\t3",
+    "A4\tB2\t3",
+    "A4\tB3\t59",
+    "A4\tB4\t29",
+    "A5\tB1\t4",
+    "A6\tB2\t3",
+    "A6\tB2\t55",
+    "A6\tB3\t1",
+    "A7\tB1\t39",
+    "A7\tB2\t27",
+    "A7\tB3\t85",
+    "A8\tB1\t4",
+    "A8\tB2\t45",
+    "A9\tB3\t92",
+    "A9\tB3\t0",
+    "A9\tB6\t42",
+    "A9\tB5\t1",
+    "A10\tB1\t7",
+    "A10\tB2\t23",
+    "A10\tB2\t1",
+    "A10\tB2\t31",
+    "A10\tB6\t41",
+    "A10\tB7\t52"
+  };
+
+  /**
+   * register $JAR_PATH
+   * 
+   * DEFINE SRS datafu.pig.sampling.SimpleRandomSample('$SAMPLING_PROBABILITY');
+   * 
+   * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int);
+   * 
+   * sampled = FOREACH (GROUP data ALL) GENERATE SRS(data) as sample_data;
+   * 
+   * sampled = FOREACH sampled GENERATE COUNT(sample_data) AS sample_count;
+   * 
+   * STORE sampled INTO 'output';
+   */
+  @Multiline
+  private String simpleRandomSampleTest;
+
+  /**
+   * Samples the whole input as a single group and expects exactly ceil(p * n) rows.
+   */
+  @Test
+  public void simpleRandomSampleTest() throws Exception
+  {
+    writeLinesToFile("input", INPUT_ROWS);
+
+    double samplingProbability = 0.3;
+    int expectedSize = (int) Math.ceil(samplingProbability * INPUT_ROWS.length);
+
+    PigTest script =
+        createPigTestFromString(simpleRandomSampleTest,
+                                "SAMPLING_PROBABILITY=" + samplingProbability);
+    script.runScript();
+
+    assertOutput(script, "sampled", "(" + expectedSize + ")");
+  }
+
+  /**
+   * register $JAR_PATH
+   * 
+   * DEFINE SRS datafu.pig.sampling.SimpleRandomSample('$SAMPLING_PROBABILITY');
+   * 
+   * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int);
+   * 
+   * sampled = FOREACH (GROUP data BY A_id) GENERATE group, SRS(data) as sample_data;
+   * 
+   * sampled = FOREACH sampled GENERATE group, COUNT(sample_data) AS sample_count;
+   * 
+   * sampled = ORDER sampled BY group;
+   * 
+   * STORE sampled INTO 'output';
+   */
+  @Multiline
+  private String stratifiedSampleTest;
+
+  /**
+   * Samples every A_id group separately and checks the per-group sample counts.
+   */
+  @Test
+  public void stratifiedSampleTest() throws Exception
+  {
+    writeLinesToFile("input", INPUT_ROWS);
+
+    double samplingProbability = 0.5;
+
+    PigTest script =
+        createPigTestFromString(stratifiedSampleTest,
+                                "SAMPLING_PROBABILITY=" + samplingProbability);
+    script.runScript();
+
+    assertOutput(script,
+                 "sampled",
+                 "(A1,2)",
+                 "(A10,3)",
+                 "(A2,1)",
+                 "(A3,2)",
+                 "(A4,2)",
+                 "(A5,1)",
+                 "(A6,2)",
+                 "(A7,2)",
+                 "(A8,1)",
+                 "(A9,2)");
+  }
+}