add SimpleRandomSample to sampling, which implements a scalable simple random sampling algorithm
diff --git a/src/java/datafu/pig/sampling/SimpleRandomSample.java b/src/java/datafu/pig/sampling/SimpleRandomSample.java
new file mode 100644
index 0000000..169e0ab
--- /dev/null
+++ b/src/java/datafu/pig/sampling/SimpleRandomSample.java
@@ -0,0 +1,322 @@
+package datafu.pig.sampling;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.Random;
+
+import org.apache.pig.AlgebraicEvalFunc;
+import org.apache.pig.EvalFunc;
+import org.apache.pig.data.BagFactory;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.data.DataType;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.data.TupleFactory;
+import org.apache.pig.impl.logicalLayer.FrontendException;
+import org.apache.pig.impl.logicalLayer.schema.Schema;
+
+/**
+ * Scalable simple random sampling.
+ * <p/>
+ * This UDF implements a scalable simple random sampling algorithm described in
+ *
+ * <pre>
+ * X. Meng, Scalable Simple Random Sampling and Stratified Sampling, ICML 2013.
+ * </pre>
+ *
+ * It takes a sampling probability p as input and outputs a simple random sample of size
+ * exactly ceil(p*n) with probability at least 99.99%, where $n$ is the size of the
+ * population. This UDF is very useful for stratified sampling. For example,
+ *
+ * <pre>
+ * DEFINE SRS datafu.pig.sampling.SimpleRandomSample('0.01');
+ * examples = LOAD ...
+ * grouped = GROUP examples BY label;
+ * sampled = FOREACH grouped GENERATE FLATTEN(SRS(examples));
+ * STORE sampled ...
+ * </pre>
+ *
+ * We note that, in a Java Hadoop job, we can output pre-selected records directly using
+ * MultipleOutputs. However, this feature is not available in a Pig UDF. So we still let
+ * pre-selected records go through the sort phase. However, as long as the sample size is
+ * not huge, this should not be a big problem.
+ *
+ * @author ximeng
+ *
+ */
+public class SimpleRandomSample extends AlgebraicEvalFunc<DataBag>
+{
+ public static final TupleFactory tupleFactory = TupleFactory.getInstance();
+ public static final BagFactory bagFactory = BagFactory.getInstance();
+
+ public SimpleRandomSample()
+ {
+ }
+
+ public SimpleRandomSample(String samplingProbability)
+ {
+ Double p = Double.parseDouble(samplingProbability);
+
+ if (p < 0.0 || p > 1.0)
+ {
+ throw new IllegalArgumentException("Sampling probability must be inside [0, 1].");
+ }
+ }
+
+ @Override
+ public String getInitial()
+ {
+ return Initial.class.getName();
+ }
+
+ @Override
+ public String getIntermed()
+ {
+ return Intermediate.class.getName();
+ }
+
+ @Override
+ public String getFinal()
+ {
+ return Final.class.getName();
+ }
+
+ @Override
+ public Schema outputSchema(Schema input)
+ {
+ try
+ {
+ Schema.FieldSchema inputFieldSchema = input.getField(0);
+
+ if (inputFieldSchema.type != DataType.BAG)
+ {
+ throw new RuntimeException("Expected a BAG as input");
+ }
+
+ return new Schema(new Schema.FieldSchema(getSchemaName(this.getClass()
+ .getName()
+ .toLowerCase(), input),
+ inputFieldSchema.schema,
+ DataType.BAG));
+ }
+ catch (FrontendException e)
+ {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ }
+ }
+
+ static public class Initial extends EvalFunc<Tuple>
+ {
+ private double _samplingProbability;
+ private Random _random = new Random(System.nanoTime() % 996209L);
+
+ public Initial()
+ {
+ }
+
+ public Initial(String samplingProbability)
+ {
+ _samplingProbability = Double.parseDouble(samplingProbability);
+ }
+
+ @Override
+ public Tuple exec(Tuple input) throws IOException
+ {
+ Tuple output = tupleFactory.newTuple();
+ DataBag selected = bagFactory.newDefaultBag();
+ DataBag waiting = bagFactory.newSortedBag(new ScoredTupleComparator());
+
+ DataBag items = (DataBag) input.get(0);
+
+ if (items != null)
+ {
+ long n = items.size();
+
+ double q1 = getQ1(n, _samplingProbability);
+ double q2 = getQ2(n, _samplingProbability);
+
+ for (Tuple item : items)
+ {
+ double key = _random.nextDouble();
+
+ if (key < q1)
+ {
+ selected.add(item);
+ }
+ else if (key < q2)
+ {
+ waiting.add(new ScoredTuple(key, item).getIntermediateTuple(tupleFactory));
+ }
+ }
+
+ output.append(n);
+ output.append(selected);
+ output.append(waiting);
+ }
+
+ return output;
+ }
+ }
+
+ public static class Intermediate extends EvalFunc<Tuple>
+ {
+ public Intermediate()
+ {
+ }
+
+ public Intermediate(String samplingProbability)
+ {
+ _samplingProbability = Double.parseDouble(samplingProbability);
+ }
+
+ private double _samplingProbability;
+
+ @Override
+ public Tuple exec(Tuple input) throws IOException
+ {
+ DataBag bag = (DataBag) input.get(0);
+ DataBag selected = bagFactory.newDefaultBag();
+ DataBag aggWaiting = bagFactory.newSortedBag(new ScoredTupleComparator());
+ DataBag waiting = bagFactory.newSortedBag(new ScoredTupleComparator());
+ Tuple output = tupleFactory.newTuple();
+
+ long n = 0L;
+
+ for (Tuple innerTuple : bag)
+ {
+ n += (Long) innerTuple.get(0);
+
+ selected.addAll((DataBag) innerTuple.get(1));
+
+ double q1 = getQ1(n, _samplingProbability);
+ double q2 = getQ2(n, _samplingProbability);
+
+ for (Tuple t : (DataBag) innerTuple.get(2))
+ {
+ ScoredTuple scored = ScoredTuple.fromIntermediateTuple(t);
+
+ if (scored.getScore() < q1)
+ {
+ selected.add(scored.getTuple());
+ }
+ else if (scored.getScore() < q2)
+ {
+ aggWaiting.add(t);
+ }
+ else
+ {
+ break;
+ }
+ }
+ }
+
+ double q1 = getQ1(n, _samplingProbability);
+ double q2 = getQ2(n, _samplingProbability);
+
+ for (Tuple t : aggWaiting)
+ {
+ ScoredTuple scored = ScoredTuple.fromIntermediateTuple(t);
+
+ if (scored.getScore() < q1)
+ {
+ selected.add(scored.getTuple());
+ }
+ else if (scored.getScore() < q2)
+ {
+ waiting.add(t);
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ output.append(n);
+ output.append(selected);
+ output.append(waiting);
+
+ System.err.println("Read " + n + " items, selected " + selected.size()
+ + ", and wait-listed " + aggWaiting.size() + ".");
+
+ return output;
+ }
+ }
+
+ static public class Final extends EvalFunc<DataBag>
+ {
+ private double _samplingProbability;
+
+ public Final()
+ {
+ }
+
+ public Final(String samplingProbability)
+ {
+ _samplingProbability = Double.parseDouble(samplingProbability);
+ }
+
+ @Override
+ public DataBag exec(Tuple input) throws IOException
+ {
+ DataBag bag = (DataBag) input.get(0);
+ long n = 0L;
+ DataBag selected = bagFactory.newDefaultBag();
+ DataBag waiting = bagFactory.newSortedBag(new ScoredTupleComparator());
+
+ for (Tuple innerTuple : bag)
+ {
+ n += (Long) innerTuple.get(0);
+ selected.addAll((DataBag) innerTuple.get(1));
+ waiting.addAll((DataBag) innerTuple.get(2));
+ }
+
+ long sampleSize = (long) Math.ceil(_samplingProbability * n);
+ long nNeeded = sampleSize - selected.size();
+
+ for (Tuple scored : waiting)
+ {
+ if (nNeeded <= 0)
+ {
+ break;
+ }
+ selected.add(ScoredTuple.fromIntermediateTuple(scored).getTuple());
+ nNeeded--;
+ }
+
+ return selected;
+ }
+ }
+
+ private static class ScoredTupleComparator implements Comparator<Tuple>
+ {
+
+ @Override
+ public int compare(Tuple o1, Tuple o2)
+ {
+ try
+ {
+ ScoredTuple t1 = ScoredTuple.fromIntermediateTuple(o1);
+ ScoredTuple t2 = ScoredTuple.fromIntermediateTuple(o2);
+ return t1.getScore().compareTo(t2.getScore());
+ }
+ catch (Throwable e)
+ {
+ throw new RuntimeException("Cannot compare " + o1 + " and " + o2 + ".", e);
+ }
+ }
+ }
+
+ private static double getQ1(long n, double p)
+ {
+ double t1 = 20.0 / (3.0 * n);
+ double q1 = p + t1 - Math.sqrt(t1 * t1 + 3.0 * t1 * p);
+ return q1;
+ }
+
+ private static double getQ2(long n, double p)
+ {
+ double t2 = 10.0 / n;
+ double q2 = p + t2 + Math.sqrt(t2 * t2 + 2.0 * t2 * p);
+ return q2;
+ }
+}
diff --git a/test/pig/datafu/test/pig/sampling/SimpleRandomSampleTest.java b/test/pig/datafu/test/pig/sampling/SimpleRandomSampleTest.java
new file mode 100644
index 0000000..b20d89d
--- /dev/null
+++ b/test/pig/datafu/test/pig/sampling/SimpleRandomSampleTest.java
@@ -0,0 +1,155 @@
+package datafu.test.pig.sampling;
+
+import org.adrianwalker.multilinestring.Multiline;
+import org.apache.pig.pigunit.PigTest;
+import org.testng.annotations.Test;
+
+import datafu.pig.sampling.SimpleRandomSample;
+import datafu.test.pig.PigTests;
+
+/**
+ * Tests for {@link SimpleRandomSample}.
+ *
+ * @author ximeng
+ *
+ */
+public class SimpleRandomSampleTest extends PigTests
+{
+ /**
+ * register $JAR_PATH
+ *
+ * DEFINE SRS datafu.pig.sampling.SimpleRandomSample('$SAMPLING_PROBABILITY');
+ *
+ * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int);
+ *
+ * sampled = FOREACH (GROUP data ALL) GENERATE SRS(data) as sample_data;
+ *
+ * sampled = FOREACH sampled GENERATE COUNT(sample_data) AS sample_count;
+ *
+ * STORE sampled INTO 'output';
+ */
+ @Multiline
+ private String simpleRandomSampleTest;
+
+ @Test
+ public void simpleRandomSampleTest() throws Exception
+ {
+ writeLinesToFile("input",
+ "A1\tB1\t1",
+ "A1\tB1\t4",
+ "A1\tB3\t4",
+ "A1\tB4\t4",
+ "A2\tB1\t4",
+ "A2\tB2\t4",
+ "A3\tB1\t3",
+ "A3\tB1\t1",
+ "A3\tB3\t77",
+ "A4\tB1\t3",
+ "A4\tB2\t3",
+ "A4\tB3\t59",
+ "A4\tB4\t29",
+ "A5\tB1\t4",
+ "A6\tB2\t3",
+ "A6\tB2\t55",
+ "A6\tB3\t1",
+ "A7\tB1\t39",
+ "A7\tB2\t27",
+ "A7\tB3\t85",
+ "A8\tB1\t4",
+ "A8\tB2\t45",
+ "A9\tB3\t92",
+ "A9\tB3\t0",
+ "A9\tB6\t42",
+ "A9\tB5\t1",
+ "A10\tB1\t7",
+ "A10\tB2\t23",
+ "A10\tB2\t1",
+ "A10\tB2\t31",
+ "A10\tB6\t41",
+ "A10\tB7\t52");
+
+ int n = 32;
+ double p = 0.3;
+ int s = (int) Math.ceil(p * n);
+ PigTest test =
+ createPigTestFromString(simpleRandomSampleTest, "SAMPLING_PROBABILITY=" + p);
+
+ test.runScript();
+
+ assertOutput(test, "sampled", "(" + s + ")");
+ }
+
+ /**
+ * register $JAR_PATH
+ *
+ * DEFINE SRS datafu.pig.sampling.SimpleRandomSample('$SAMPLING_PROBABILITY');
+ *
+ * data = LOAD 'input' AS (A_id:chararray, B_id:chararray, C:int);
+ *
+ * sampled = FOREACH (GROUP data BY A_id) GENERATE group, SRS(data) as sample_data;
+ *
+ * sampled = FOREACH sampled GENERATE group, COUNT(sample_data) AS sample_count;
+ *
+ * sampled = ORDER sampled BY group;
+ *
+ * STORE sampled INTO 'output';
+ */
+ @Multiline
+ private String stratifiedSampleTest;
+
+ @Test
+ public void stratifiedSampleTest() throws Exception
+ {
+ writeLinesToFile("input",
+ "A1\tB1\t1",
+ "A1\tB1\t4",
+ "A1\tB3\t4",
+ "A1\tB4\t4",
+ "A2\tB1\t4",
+ "A2\tB2\t4",
+ "A3\tB1\t3",
+ "A3\tB1\t1",
+ "A3\tB3\t77",
+ "A4\tB1\t3",
+ "A4\tB2\t3",
+ "A4\tB3\t59",
+ "A4\tB4\t29",
+ "A5\tB1\t4",
+ "A6\tB2\t3",
+ "A6\tB2\t55",
+ "A6\tB3\t1",
+ "A7\tB1\t39",
+ "A7\tB2\t27",
+ "A7\tB3\t85",
+ "A8\tB1\t4",
+ "A8\tB2\t45",
+ "A9\tB3\t92",
+ "A9\tB3\t0",
+ "A9\tB6\t42",
+ "A9\tB5\t1",
+ "A10\tB1\t7",
+ "A10\tB2\t23",
+ "A10\tB2\t1",
+ "A10\tB2\t31",
+ "A10\tB6\t41",
+ "A10\tB7\t52");
+
+ double p = 0.5;
+
+ PigTest test =
+ createPigTestFromString(stratifiedSampleTest, "SAMPLING_PROBABILITY=" + p);
+ test.runScript();
+ assertOutput(test,
+ "sampled",
+ "(A1,2)",
+ "(A10,3)",
+ "(A2,1)",
+ "(A3,2)",
+ "(A4,2)",
+ "(A5,1)",
+ "(A6,2)",
+ "(A7,2)",
+ "(A8,1)",
+ "(A9,2)");
+ }
+}