add WeightedSample udf and tests; add CountEach udf and tests
diff --git a/src/java/datafu/pig/bags/CountEach.java b/src/java/datafu/pig/bags/CountEach.java
new file mode 100644
index 0000000..d44386f
--- /dev/null
+++ b/src/java/datafu/pig/bags/CountEach.java
@@ -0,0 +1,199 @@
+/*
+ * Copyright 2013 LinkedIn, Inc
+ * 
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ * 
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package datafu.pig.bags;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import org.apache.pig.EvalFunc;
+import org.apache.pig.data.BagFactory;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.data.DataType;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.data.TupleFactory;
+import org.apache.pig.impl.logicalLayer.FrontendException;
+import org.apache.pig.impl.logicalLayer.schema.Schema;
+
+/**
+ * This UDF takes an input bag and generates a count 
+ * of the number of times each distinct tuple appears.
+ * 
+ * <p>
+ * Each output tuple is either {@code (original_tuple, count)} or, when the UDF
+ * is constructed with the argument {@code 'flatten'}, the fields of the original
+ * tuple followed by the count. Counts are of type {@code int}. Distinct tuples
+ * appear in the output in the order in which they first occur in the input bag.
+ * </p>
+ * 
+ * Example:
+ * <pre>
+ * {@code
+ * DEFINE CountEach datafu.pig.bags.CountEach();
+ * DEFINE CountEachFlatten datafu.pig.bags.CountEach('flatten');
+ * 
+ * -- input: 
+ * -- ({(A),(A),(C),(B)})
+ * data = LOAD 'input' AS (B: bag {T: tuple(alpha:CHARARRAY)});
+ * 
+ * -- counted: 
+ * -- ({((A),2),((C),1),((B),1)})
+ * counted = FOREACH data GENERATE CountEach(B); 
+ * 
+ * -- counted_flatten: 
+ * -- ({(A,2),(C,1),(B,1)})
+ * counted_flatten = FOREACH data GENERATE CountEachFlatten(B);
+ * } 
+ * </pre>
+ */
+public class CountEach extends EvalFunc<DataBag>
+{
+  /** When true, the count is appended to the tuple's own fields. */
+  private final boolean flatten;
+  
+  /**
+   * Creates a CountEach that emits {@code (tuple, count)} pairs.
+   */
+  public CountEach() {
+    this.flatten = false;
+  }
+  
+  /**
+   * Creates a CountEach configured by the argument given in DEFINE.
+   * 
+   * @param arg {@code 'flatten'} (case-insensitive) to append the count to the
+   *            fields of each distinct tuple, or {@code null} for the default
+   *            {@code (tuple, count)} output
+   * @throws IllegalArgumentException if {@code arg} is any other value
+   */
+  public CountEach(String arg) {
+    if (arg == null) {
+      this.flatten = false;
+    } else if (arg.equalsIgnoreCase("flatten")) {
+      this.flatten = true;
+    } else {
+      throw new IllegalArgumentException("Unrecognized argument to CountEach: '" + arg
+                                         + "' (expected 'flatten')");
+    }
+  }
+
+  /**
+   * Counts the occurrences of each distinct tuple in the bag held in the first
+   * field of {@code input}.
+   * 
+   * @param input a tuple whose first field is the bag to count
+   * @return a new bag with one tuple per distinct input tuple, in order of
+   *         first appearance
+   * @throws IllegalArgumentException if the bag is null
+   */
+  @Override
+  public DataBag exec(Tuple input) throws IOException {
+    DataBag inputBag = (DataBag)input.get(0);
+    if (inputBag == null) throw new IllegalArgumentException("Expected a bag, got null");
+    
+    // LinkedHashMap (rather than HashMap) keeps distinct tuples in order of first
+    // appearance, so the output order is deterministic.
+    Map<Tuple, Integer> counts = new LinkedHashMap<Tuple, Integer>();
+    for (Tuple tuple : inputBag) {
+      Integer count = counts.get(tuple);
+      counts.put(tuple, count == null ? 1 : count + 1);
+    }
+    
+    TupleFactory tupleFactory = TupleFactory.getInstance();
+    DataBag output = BagFactory.getInstance().newDefaultBag();
+    for (Map.Entry<Tuple, Integer> entry : counts.entrySet()) {
+      // newTuple(List) copies the field list, so appending the count below
+      // does not modify the tuple taken from the input bag
+      Tuple innerTuple = tupleFactory.newTuple(entry.getKey().getAll());
+      Tuple outputTuple;
+      if (flatten) {
+        innerTuple.append(entry.getValue());
+        outputTuple = innerTuple;
+      } else {
+        outputTuple = tupleFactory.newTuple();
+        outputTuple.append(innerTuple);
+        outputTuple.append(entry.getValue());
+      }
+      output.add(outputTuple);
+    }
+
+    return output;
+  }
+  
+  /**
+   * Declares the output as a bag of {@code (tuple_schema:tuple(...), count:int)}
+   * or, in flatten mode, a bag of the input tuple's fields followed by
+   * {@code count:int}.
+   * 
+   * @throws RuntimeException if the input is not a single bag whose elements are tuples
+   */
+  @Override
+  public Schema outputSchema(Schema input)
+  {
+    try {
+      if (input == null || input.size() != 1)
+      {
+        throw new RuntimeException("Expected input to have one field");
+      }
+      
+      Schema.FieldSchema bagFieldSchema = input.getField(0);
+
+      if (bagFieldSchema.type != DataType.BAG)
+      {
+        throw new RuntimeException("Expected a BAG as input");
+      }
+      
+      // the bag's inner schema may be null or empty when no element schema was
+      // declared; handle that the same way as a tuple with no schema
+      Schema inputBagSchema = bagFieldSchema.schema;
+      Schema inputTupleSchema = null;
+      if (inputBagSchema != null && inputBagSchema.size() > 0)
+      {
+        Schema.FieldSchema tupleFieldSchema = inputBagSchema.getField(0);
+        if (tupleFieldSchema.type != DataType.TUPLE)
+        {
+          throw new RuntimeException(String.format("Expected input bag to contain a TUPLE, but instead found %s",
+                                                   DataType.findTypeName(tupleFieldSchema.type)));
+        }
+        inputTupleSchema = tupleFieldSchema.schema;
+      }
+      if (inputTupleSchema == null) inputTupleSchema = new Schema();
+      
+      Schema outputTupleSchema;
+      
+      if (this.flatten) {
+        outputTupleSchema = inputTupleSchema.clone();
+        outputTupleSchema.add(new Schema.FieldSchema("count", DataType.INTEGER));
+      } else {
+        outputTupleSchema = new Schema();
+        outputTupleSchema.add(new Schema.FieldSchema("tuple_schema", inputTupleSchema.clone(), DataType.TUPLE));
+        outputTupleSchema.add(new Schema.FieldSchema("count", DataType.INTEGER));
+      }
+      
+      return new Schema(new Schema.FieldSchema(
+            getSchemaName(this.getClass().getName().toLowerCase(), input),
+            outputTupleSchema, 
+            DataType.BAG));
+    }
+    catch (CloneNotSupportedException e) {
+      throw new RuntimeException(e);
+    }
+    catch (FrontendException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+}
diff --git a/src/java/datafu/pig/bags/WeightedSample.java b/src/java/datafu/pig/bags/WeightedSample.java
new file mode 100644
index 0000000..29b2099
--- /dev/null
+++ b/src/java/datafu/pig/bags/WeightedSample.java
@@ -0,0 +1,229 @@
+/*
+ * Copyright 2013 LinkedIn, Inc
+ * 
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ * 
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package datafu.pig.bags;
+
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.pig.EvalFunc;
+import org.apache.pig.data.BagFactory;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.data.DataType;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.impl.logicalLayer.FrontendException;
+import org.apache.pig.impl.logicalLayer.schema.Schema;
+
+/**
+ * Create a new bag by performing a weighted sampling without replacement
+ * from the input bag. Optionally takes two additional arguments to specify
+ * the index of the column to use as a scoring column (uses enumerated bag 
+ * order if not specified) and a limit on the number of items to return.
+ * n.b.
+ * <ul>
+ * <li>When no scoring column is specified, items from the top of the bag are
+ * more likely to be chosen than items from the bottom.
+ * <li>High scores are more likely to be chosen when using a scoring column.
+ * Null, NaN, zero and negative scores are all given the tiny positive weight
+ * {@code Double.MIN_NORMAL}.
+ * </ul>
+ * 
+ * <p>
+ * If the system property {@code pigunit.randseed} is set, every call seeds its
+ * random number generator with that value, which makes results repeatable in tests.
+ * 
+ * <p>
+ * Example:
+ * <pre>
+ * {@code
+ * define WeightedSample datafu.pig.bags.WeightedSample();
+ * 
+ * -- input:
+ * -- ({(a,1),(b,2),(c,3),(d,4),(e,5)})
+ * data = LOAD 'input' AS (A: bag{T: tuple(name:chararray,score:int)});
+ * 
+ * output1 = FOREACH data GENERATE WeightedSample(A);
+ * -- output1:
+ * -- no scoring column specified, so uses bag order
+ * -- ({(a,1),(b,2),(e,5),(d,4),(c,3)}) -- example of random
+ * 
+ * -- sample using the second column (index 1) and keep only the top 3
+ * -- scoring column is specified, so bias towards higher scores
+ * -- only keep the first 3
+ * output2 = FOREACH data GENERATE WeightedSample(A,1,3);
+ * -- output2:
+ * -- ({(e,5),(d,4),(b,2)})
+ * }
+ * </pre>
+ */
+public class WeightedSample extends EvalFunc<DataBag>
+{
+  BagFactory bagFactory = BagFactory.getInstance();
+
+  public WeightedSample() {
+  }
+
+  /**
+   * Samples the bag held in the first field of {@code input}.
+   * 
+   * @param input (bag), (bag, scoreIndex) or (bag, scoreIndex, limit)
+   * @return a new bag of at most {@code limit} tuples (all tuples if no limit is
+   *         given) in the order they were drawn; empty for a null or empty bag
+   *         or a limit of zero or less
+   */
+  @Override
+  public DataBag exec(Tuple input) throws IOException {
+    DataBag output = bagFactory.newDefaultBag();
+
+    DataBag samples = (DataBag) input.get(0);
+    if (samples == null || samples.size() == 0) {
+      return output; // if we are given null we will return an empty bag
+    }
+    int numSamples = (int) samples.size();
+
+    int limitSamples = numSamples;
+    if (input.size() == 3) {
+      // sample limit included
+      limitSamples = Math.min(((Number)input.get(2)).intValue(), numSamples);
+    }
+    if (limitSamples <= 0) {
+      return output;
+    }
+
+    Tuple[] tuples = new Tuple[numSamples];
+    int tupleIndex = 0;
+    for (Tuple tuple : samples) {
+      tuples[tupleIndex] = tuple;
+      tupleIndex++;
+    }
+
+    // a single tuple needs no sampling; the limit is at least 1 here
+    if (numSamples == 1) {
+      output.add(tuples[0]);
+      return output;
+    }
+
+    double[] scores = new double[numSamples];
+    if (input.size() <= 1) {
+      // no scoring column specified, so weight by rank order: the first tuple
+      // gets weight numSamples + 1 and each later tuple one less
+      for (int i = 0; i < numSamples; i++) {
+        scores[i] = numSamples + 1 - i;
+      }
+    } else {
+      int scoreIndex = ((Number)input.get(1)).intValue();
+      for (int i = 0; i < numSamples; i++) {
+        Object rawScore = tuples[i].get(scoreIndex);
+        double score = rawScore == null ? Double.NaN : ((Number)rawScore).doubleValue();
+        // null, NaN, zero and negative scores would corrupt the cumulative
+        // distribution, so clamp them to the smallest positive normal double
+        scores[i] = score > Double.MIN_NORMAL ? score : Double.MIN_NORMAL;
+      }
+    }
+
+    /*
+     * Here's how the algorithm works:
+     * 
+     * 1. Create a cumulative distribution of the scores
+     * 2. Draw a random number
+     * 3. Find the interval in which the drawn number falls
+     * 4. Select the element encompassing that interval
+     * 5. Remove the selected element from consideration
+     * 6. Repeat 1-5 k times
+     * 
+     * Rather than removing the selected element (#5), which is expensive for an
+     * array, the element at the front of the remaining range (index k) is moved
+     * into the selected slot of both scores and tuples, and the range then
+     * starts at k + 1.
+     * 
+     * This is an O(k*n) algorithm, where k is the number of elements to sample and n is
+     * the number of scores.
+     */
+    Random rng;
+    // the system property random seed is used to enable repeatable tests
+    String randSeed = System.getProperty("pigunit.randseed");
+    if (randSeed != null) {
+      rng = new Random(Long.parseLong(randSeed));
+    } else {
+      rng = new Random();
+    }
+    for (int k = 0; k < limitSamples; k++) {
+      double val = rng.nextDouble();
+      int idx = find_cumsum_interval(scores, val, k, numSamples);
+      if (idx == numSamples) {
+        // no interval matched (e.g. the total weight is not finite), so fall
+        // back to a uniformly chosen remaining element
+        idx = rng.nextInt(numSamples - k) + k;
+      }
+
+      output.add(tuples[idx]);
+
+      // move the front element into the selected slot, then shrink the range
+      scores[idx] = scores[k];
+      tuples[idx] = tuples[k];
+    }
+
+    return output;
+  }
+
+  /**
+   * Finds the first index i in [begin, end) at which the running sum of
+   * scores[begin..i], divided by the sum of scores[begin..end), exceeds val.
+   * 
+   * @param scores the weights
+   * @param val the value to locate, typically drawn from [0, 1)
+   * @param begin the first index to consider (inclusive)
+   * @param end the last index to consider (exclusive)
+   * @return the matching index, or {@code end} if none matches
+   */
+  public int find_cumsum_interval(double[] scores, double val, int begin, int end) {
+    double sum = 0.0;
+    double cumsum = 0.0;
+    for (int i = begin; i < end; i++) {
+      sum += scores[i];
+    }
+
+    for (int i = begin; i < end; i++) {
+      cumsum += scores[i];
+      if ((cumsum / sum) > val)
+        return i;
+    }
+    return end;
+  }
+
+  /**
+   * Declares the output as a bag with the same schema as the input bag.
+   * 
+   * @throws RuntimeException if there is no input field or the first one is not a bag
+   */
+  @Override
+  public Schema outputSchema(Schema input) {
+    try {
+      if (input == null || input.size() < 1) {
+        throw new RuntimeException("Expected at least one input field");
+      }
+      Schema.FieldSchema inputFieldSchema = input.getField(0);
+
+      if (inputFieldSchema.type != DataType.BAG) {
+        throw new RuntimeException("Expected a BAG as input");
+      }
+
+      return new Schema(new Schema.FieldSchema(getSchemaName(this.getClass().getName().toLowerCase(), input),
+                                               inputFieldSchema.schema, DataType.BAG));
+    } catch (FrontendException e) {
+      throw new RuntimeException(e);
+    }
+  }
+}
diff --git a/test/pig/datafu/test/pig/bags/BagTests.java b/test/pig/datafu/test/pig/bags/BagTests.java
index 8a0a223..d6c1226 100644
--- a/test/pig/datafu/test/pig/bags/BagTests.java
+++ b/test/pig/datafu/test/pig/bags/BagTests.java
@@ -360,4 +360,85 @@
     assertOutput(test, "data2",
                  "({(Z,1,0),(A,1,0),(B,2,0),(C,3,0),(D,4,0),(E,5,0)})");
   }
+  
+  /** System property that WeightedSample reads to seed its random number generator. */
+  private static final String RANDSEED_PROPERTY = "pigunit.randseed";
+  
+  @Test
+  public void weightedSampleTest() throws Exception
+  {
+    runWeightedSampleTest("test/pig/datafu/test/pig/bags/weightedSampleTest.pig",
+                          "({(c,5),(a,100),(b,1),(d,2)})");
+  }
+  
+  @Test
+  public void weightedSampleScoreTest() throws Exception
+  {
+    runWeightedSampleTest("test/pig/datafu/test/pig/bags/weightedSampleScoreTest.pig",
+                          "({(a,100),(c,5),(b,1),(d,2)})");
+  }
+  
+  @Test
+  public void weightedSampleScoreLimitTest() throws Exception
+  {
+    runWeightedSampleTest("test/pig/datafu/test/pig/bags/weightedSampleScoreLimitTest.pig",
+                          "({(a,100),(c,5),(b,1)})");
+  }
+  
+  /**
+   * Runs a WeightedSample script over a fixed input bag with a fixed random seed
+   * and checks its "data2" alias. The seed property is restored afterwards so it
+   * does not leak into other tests running in the same JVM.
+   */
+  private void runWeightedSampleTest(String scriptPath, String... expectedOutput) throws Exception
+  {
+    String previousSeed = System.getProperty(RANDSEED_PROPERTY);
+    // the system property random seed is used to enable repeatable tests
+    System.setProperty(RANDSEED_PROPERTY, "1");
+    try {
+      PigTest test = createPigTest(scriptPath);
+
+      writeLinesToFile("input", 
+                       "({(a, 100),(b, 1),(c, 5),(d, 2)})");
+
+      test.runScript();
+
+      assertOutput(test, "data2", expectedOutput);
+    } finally {
+      if (previousSeed == null) {
+        System.clearProperty(RANDSEED_PROPERTY);
+      } else {
+        System.setProperty(RANDSEED_PROPERTY, previousSeed);
+      }
+    }
+  }
+  
+  @Test 
+  public void countEachTest() throws Exception
+  {
+    PigTest test = createPigTest("test/pig/datafu/test/pig/bags/countEachTest.pig");
+
+    writeLinesToFile("input", 
+                     "({(A),(B),(A),(C),(A),(B)})");
+
+    test.runScript();
+
+    assertOutput(test, "data3",
+        "({((A),3),((B),2),((C),1)})");
+  }
+  
+  @Test 
+  public void countEachFlattenTest() throws Exception
+  {
+    PigTest test = createPigTest("test/pig/datafu/test/pig/bags/countEachFlattenTest.pig");
+
+    writeLinesToFile("input", 
+                     "({(A),(B),(A),(C),(A),(B)})");
+
+    test.runScript();
+
+    assertOutput(test, "data3",
+        "({(A,3),(B,2),(C,1)})");
+  }
+
 }
diff --git a/test/pig/datafu/test/pig/bags/countEachFlattenTest.pig b/test/pig/datafu/test/pig/bags/countEachFlattenTest.pig
new file mode 100644
index 0000000..1db7651
--- /dev/null
+++ b/test/pig/datafu/test/pig/bags/countEachFlattenTest.pig
@@ -0,0 +1,18 @@
+register $JAR_PATH
+
+define CountEach datafu.pig.bags.CountEach('flatten');
+
+data = LOAD 'input' AS (data: bag {T: tuple(v1:chararray)});
+
+-- flatten mode: each output tuple is the input tuple's fields followed by its count
+data2 = FOREACH data GENERATE CountEach(data) as counted;
+describe data2;
+
+-- order each bag by count so the expected output does not depend on the UDF's output order
+data3 = FOREACH data2 {
+  ordered = ORDER counted BY count DESC;
+  GENERATE ordered;
+}
+describe data3;
+
+STORE data3 INTO 'output';
diff --git a/test/pig/datafu/test/pig/bags/countEachTest.pig b/test/pig/datafu/test/pig/bags/countEachTest.pig
new file mode 100644
index 0000000..ef83284
--- /dev/null
+++ b/test/pig/datafu/test/pig/bags/countEachTest.pig
@@ -0,0 +1,18 @@
+register $JAR_PATH
+
+define CountEach datafu.pig.bags.CountEach();
+
+data = LOAD 'input' AS (data: bag {T: tuple(v1:chararray)});
+
+-- default mode: each output tuple is (input tuple, count)
+data2 = FOREACH data GENERATE CountEach(data) as counted;
+describe data2;
+
+-- order each bag by count so the expected output does not depend on the UDF's output order
+data3 = FOREACH data2 {
+  ordered = ORDER counted BY count DESC;
+  GENERATE ordered;
+}
+describe data3;
+
+STORE data3 INTO 'output';
diff --git a/test/pig/datafu/test/pig/bags/weightedSampleScoreLimitTest.pig b/test/pig/datafu/test/pig/bags/weightedSampleScoreLimitTest.pig
new file mode 100644
index 0000000..3bd8a18
--- /dev/null
+++ b/test/pig/datafu/test/pig/bags/weightedSampleScoreLimitTest.pig
@@ -0,0 +1,10 @@
+REGISTER $JAR_PATH
+
+DEFINE WeightedSample datafu.pig.bags.WeightedSample();
+
+data = LOAD 'input' AS (A: bag {T: tuple(v1:chararray, v2:int)});
+
+data2 = FOREACH data GENERATE WeightedSample(A, 1, 3);
+DESCRIBE data2;
+
+STORE data2 INTO 'output';
diff --git a/test/pig/datafu/test/pig/bags/weightedSampleScoreTest.pig b/test/pig/datafu/test/pig/bags/weightedSampleScoreTest.pig
new file mode 100644
index 0000000..e231a8d
--- /dev/null
+++ b/test/pig/datafu/test/pig/bags/weightedSampleScoreTest.pig
@@ -0,0 +1,10 @@
+REGISTER $JAR_PATH
+
+DEFINE WeightedSample datafu.pig.bags.WeightedSample();
+
+data = LOAD 'input' AS (A: bag {T: tuple(v1:chararray, v2:int)});
+
+data2 = FOREACH data GENERATE WeightedSample(A, 1);
+DESCRIBE data2;
+
+STORE data2 INTO 'output';
diff --git a/test/pig/datafu/test/pig/bags/weightedSampleTest.pig b/test/pig/datafu/test/pig/bags/weightedSampleTest.pig
new file mode 100644
index 0000000..4890127
--- /dev/null
+++ b/test/pig/datafu/test/pig/bags/weightedSampleTest.pig
@@ -0,0 +1,10 @@
+REGISTER $JAR_PATH
+
+DEFINE WeightedSample datafu.pig.bags.WeightedSample();
+
+data = LOAD 'input' AS (A: bag {T: tuple(v1:chararray, v2:int)});
+
+data2 = FOREACH data GENERATE WeightedSample(A);
+DESCRIBE data2;
+
+STORE data2 INTO 'output';