add WeightedSample udf and tests; add CountEach udf and tests
diff --git a/src/java/datafu/pig/bags/CountEach.java b/src/java/datafu/pig/bags/CountEach.java
new file mode 100644
index 0000000..d44386f
--- /dev/null
+++ b/src/java/datafu/pig/bags/CountEach.java
@@ -0,0 +1,151 @@
+/*
+ * Copyright 2013 LinkedIn, Inc
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package datafu.pig.bags;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.pig.EvalFunc;
+import org.apache.pig.data.BagFactory;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.data.DataType;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.data.TupleFactory;
+import org.apache.pig.impl.logicalLayer.FrontendException;
+import org.apache.pig.impl.logicalLayer.schema.Schema;
+
+/**
+ * This UDF takes an input bag and generates a count
+ * of the number of times each distinct tuple appears.
+ *
+ * Example:
+ * <pre>
+ * {@code
+ * DEFINE CountEach datafu.pig.bags.CountEach();
+ * DEFINE CountEachFlatten datafu.pig.bags.CountEach('flatten');
+ *
+ * -- input:
+ * -- ({(A),(A),(C),(B)})
+ * input = LOAD 'input' AS (B: bag {T: tuple(v:CHARARRAY)});
+ *
+ * -- output:
+ * -- {((A),2),((C),1),((B),1)}
+ * output = FOREACH input GENERATE CountEach(B);
+ *
+ * -- output_flatten:
+ * -- ({(A,2),(C,1),(B,1)})
+ * output_flatten = FOREACH input GENERATE CountEachFlatten(B);
+ * }
+ * </pre>
+ */
+public class CountEach extends EvalFunc<DataBag>
+{
+ private boolean flatten = false;
+
+ public CountEach() {
+
+ }
+
+ public CountEach(String arg) {
+ if (arg != null && arg.toLowerCase().equals("flatten")) {
+ flatten = true;
+ }
+ }
+
+ @Override
+ public DataBag exec(Tuple input) throws IOException {
+ DataBag inputBag = (DataBag)input.get(0);
+ if (inputBag == null) throw new IllegalArgumentException("Expected a bag, got null");
+
+ Map<Tuple, Integer> counts = new HashMap<Tuple, Integer>();
+ for (Tuple tuple : inputBag) {
+ if (!counts.containsKey(tuple)) {
+ counts.put(tuple, 0);
+ }
+ counts.put(tuple, counts.get(tuple)+1);
+ }
+
+ DataBag output = BagFactory.getInstance().newDefaultBag();
+ for (Tuple tuple : counts.keySet()) {
+ Tuple outputTuple = null;
+ Tuple innerTuple = TupleFactory.getInstance().newTuple(tuple.getAll());
+ if (flatten) {
+ innerTuple.append(counts.get(tuple));
+ outputTuple = innerTuple;
+ } else {
+ outputTuple = TupleFactory.getInstance().newTuple();
+ outputTuple.append(innerTuple);
+ outputTuple.append(counts.get(tuple));
+ }
+ output.add(outputTuple);
+ }
+
+ return output;
+ }
+
+ @Override
+ public Schema outputSchema(Schema input)
+ {
+ try {
+ if (input.size() != 1)
+ {
+ throw new RuntimeException("Expected input to have one field");
+ }
+
+ Schema.FieldSchema bagFieldSchema = input.getField(0);
+
+ if (bagFieldSchema.type != DataType.BAG)
+ {
+ throw new RuntimeException("Expected a BAG as input");
+ }
+
+ Schema inputBagSchema = bagFieldSchema.schema;
+
+ if (inputBagSchema.getField(0).type != DataType.TUPLE)
+ {
+ throw new RuntimeException(String.format("Expected input bag to contain a TUPLE, but instead found %s",
+ DataType.findTypeName(inputBagSchema.getField(0).type)));
+ }
+
+ Schema inputTupleSchema = inputBagSchema.getField(0).schema;
+ if (inputTupleSchema == null) inputTupleSchema = new Schema();
+
+ Schema outputTupleSchema = null;
+
+ if (this.flatten) {
+ outputTupleSchema = inputTupleSchema.clone();
+ outputTupleSchema.add(new Schema.FieldSchema("count", DataType.INTEGER));
+ } else {
+ outputTupleSchema = new Schema();
+ outputTupleSchema.add(new Schema.FieldSchema("tuple_schema", inputTupleSchema.clone(), DataType.TUPLE));
+ outputTupleSchema.add(new Schema.FieldSchema("count", DataType.INTEGER));
+ }
+
+ return new Schema(new Schema.FieldSchema(
+ getSchemaName(this.getClass().getName().toLowerCase(), input),
+ outputTupleSchema,
+ DataType.BAG));
+ }
+ catch (CloneNotSupportedException e) {
+ throw new RuntimeException(e);
+ }
+ catch (FrontendException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+}
diff --git a/src/java/datafu/pig/bags/WeightedSample.java b/src/java/datafu/pig/bags/WeightedSample.java
new file mode 100644
index 0000000..29b2099
--- /dev/null
+++ b/src/java/datafu/pig/bags/WeightedSample.java
@@ -0,0 +1,182 @@
+/*
+ * Copyright 2013 LinkedIn, Inc
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package datafu.pig.bags;
+
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.pig.EvalFunc;
+import org.apache.pig.data.BagFactory;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.data.DataType;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.impl.logicalLayer.FrontendException;
+import org.apache.pig.impl.logicalLayer.schema.Schema;
+
+/**
+ * Create a new bag by performing a weighted sampling without replacement
+ * from the input bag. Optionally takes two additional arguments to specify
+ * the index of the column to use as a scoring column (uses enumerated bag
+ * order if not specified) and a limit on the number of items to return.
+ * n.b.
+ * <ul>
+ * <li>When no scoring column is specified, items from the top of the bag are
+ * more likely to be chosen than items from the bottom.
+ * <li>High scores are more likely to be chosen when using a scoring column.
+ * </ul>
+ *
+ * <p>
+ * Example:
+ * <pre>
+ * {@code
+ * define WeightedSample datafu.pig.bags.WeightedSample();
+ *
+ * -- input:
+ * -- ({(a,1),(b,2),(c,3),(d,4),(e,5)})
+ * input = LOAD 'input' AS (A: bag{T: tuple(name:chararray,score:int)});
+ *
+ * output1 = FOREACH input GENERATE WeightedSample(A);
+ * -- output1:
+ * -- no scoring column specified, so uses bag order
+ * -- ({(a,1),(b,2),(e,5),(d,4),(c,3)}) -- example of random
+ *
+ * -- sample using the second column (index 1) and keep only the top 3
+ * -- scoring column is specified, so bias towards higher scores
+ * -- only keep the first 3
+ * output2 = FOREACH input GENERATE WeightedSample(A,1,3);
+ * -- output2:
+ * -- ({(e,5),(d,4),(b,2)})
+ * }
+ * </pre>
+ */
+public class WeightedSample extends EvalFunc<DataBag>
+{
+ BagFactory bagFactory = BagFactory.getInstance();
+
+ public WeightedSample() {
+ }
+
+ @Override
+ public DataBag exec(Tuple input) throws IOException {
+ DataBag output = bagFactory.newDefaultBag();
+
+ DataBag samples = (DataBag) input.get(0);
+ if (samples == null || samples.size() == 0) {
+ return output; // if we are given null we will return an empty bag
+ }
+ int numSamples = (int) samples.size();
+ if (numSamples == 1) return samples;
+
+ Tuple[] tuples = new Tuple[numSamples];
+ int tupleIndex = 0;
+ for (Tuple tuple : samples) {
+ tuples[tupleIndex] = tuple;
+ tupleIndex++;
+ }
+
+ double[] scores = new double[numSamples];
+ if (input.size() <= 1) {
+ // no scoring column specified, so use rank order
+ for (int i = 0; i < scores.length; i++) {
+ scores[i] = scores.length + 1 - i;
+ }
+ } else {
+ tupleIndex = 0;
+ int scoreIndex = ((Number)input.get(1)).intValue();
+ for (Tuple tuple : samples) {
+ double score = ((Number)tuple.get(scoreIndex)).doubleValue();
+ score = Math.max(score, Double.MIN_NORMAL); // negative scores cause problems
+ scores[tupleIndex] = score;
+ tupleIndex++;
+ }
+ }
+
+ int limitSamples = numSamples;
+ if (input.size() == 3) {
+ // sample limit included
+ limitSamples = Math.min(((Number)input.get(2)).intValue(), numSamples);
+ }
+
+ /*
+ * Here's how the algorithm works:
+ *
+ * 1. Create a cumulative distribution of the scores 2. Draw a random number 3. Find
+ * the interval in which the drawn number falls into 4. Select the element
+ * encompassing that interval 5. Remove the selected element from consideration 6.
+ * Repeat 1-5 k times
+ *
+ * However, rather than removing the element (#5), which is expensive for an array,
+ * this function performs some extra bookkeeping by replacing the selected element
+ * with an element from the front of the array and truncating the front. This
+ * complicates matters as the element positions have changed, so another mapping for
+ * positions is needed.
+ *
+ * This is an O(k*n) algorithm, where k is the number of elements to sample and n is
+ * the number of scores.
+ */
+ Random rng = new Random();
+ // the system property random seed is used to enable repeatable tests
+ if (System.getProperties().containsKey("pigunit.randseed")) {
+ long randSeed = Long.parseLong(System.getProperties().getProperty("pigunit.randseed"));
+ rng = new Random(randSeed);
+ }
+ for (int k = 0; k < limitSamples; k++) {
+ double val = rng.nextDouble();
+ int idx = find_cumsum_interval(scores, val, k, numSamples);
+ if (idx == numSamples)
+ idx = rng.nextInt(numSamples - k) + k;
+
+ output.add(tuples[idx]);
+
+ scores[idx] = scores[k];
+ tuples[idx] = tuples[k];
+ }
+
+ return output;
+ }
+
+ public int find_cumsum_interval(double[] scores, double val, int begin, int end) {
+ double sum = 0.0;
+ double cumsum = 0.0;
+ for (int i = begin; i < end; i++) {
+ sum += scores[i];
+ }
+
+ for (int i = begin; i < end; i++) {
+ cumsum += scores[i];
+ if ((cumsum / sum) > val)
+ return i;
+ }
+ return end;
+ }
+
+ @Override
+ public Schema outputSchema(Schema input) {
+ try {
+ Schema.FieldSchema inputFieldSchema = input.getField(0);
+
+ if (inputFieldSchema.type != DataType.BAG) {
+ throw new RuntimeException("Expected a BAG as input");
+ }
+
+ return new Schema(new Schema.FieldSchema(getSchemaName(this.getClass().getName().toLowerCase(), input),
+ inputFieldSchema.schema, DataType.BAG));
+ } catch (FrontendException e) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/test/pig/datafu/test/pig/bags/BagTests.java b/test/pig/datafu/test/pig/bags/BagTests.java
index 8a0a223..d6c1226 100644
--- a/test/pig/datafu/test/pig/bags/BagTests.java
+++ b/test/pig/datafu/test/pig/bags/BagTests.java
@@ -1,15 +1,8 @@
package datafu.test.pig.bags;
-import static org.testng.Assert.assertEquals;
-
-import org.apache.pig.data.BagFactory;
-import org.apache.pig.data.DataBag;
-import org.apache.pig.data.Tuple;
-import org.apache.pig.data.TupleFactory;
import org.apache.pig.pigunit.PigTest;
import org.testng.annotations.Test;
-import datafu.pig.bags.Enumerate;
import datafu.test.pig.PigTests;
@@ -243,73 +236,6 @@
}
@Test
- public void enumerateTest2() throws Exception
- {
- PigTest test = createPigTest("test/pig/datafu/test/pig/bags/enumerateTest.pig");
-
- writeLinesToFile("input",
- "({(10,{(1),(2),(3)}),(20,{(4),(5),(6)}),(30,{(7),(8)}),(40,{(9),(10),(11)}),(50,{(12),(13),(14),(15)})})",
- "({(11,{(11),(12),(13),(14)}),(21,{(15),(16),(17),(18)}),(31,{(19),(20)}),(41,{(21),(22),(23),(24)}),(51,{(25),(26),(27)})})");
-
- test.runScript();
-
- assertOutput(test, "data4",
- "(10,{(1),(2),(3)},0)",
- "(20,{(4),(5),(6)},1)",
- "(30,{(7),(8)},2)",
- "(40,{(9),(10),(11)},3)",
- "(50,{(12),(13),(14),(15)},4)",
- "(11,{(11),(12),(13),(14)},0)",
- "(21,{(15),(16),(17),(18)},1)",
- "(31,{(19),(20)},2)",
- "(41,{(21),(22),(23),(24)},3)",
- "(51,{(25),(26),(27)},4)");
- }
-
- /*
- * Testing "Accumulator" part of Enumeration by manually invoke accumulate() and getValue()
- */
- @Test
- public void enumerateAccumulatorTest() throws Exception
- {
- Enumerate enumurate = new Enumerate();
-
- Tuple tuple1 = TupleFactory.getInstance().newTuple(1);
- tuple1.set(0, 10);
-
- Tuple tuple2 = TupleFactory.getInstance().newTuple(1);
- tuple2.set(0, 20);
-
- Tuple tuple3 = TupleFactory.getInstance().newTuple(1);
- tuple3.set(0, 30);
-
- Tuple tuple4 = TupleFactory.getInstance().newTuple(1);
- tuple4.set(0, 40);
-
- Tuple tuple5 = TupleFactory.getInstance().newTuple(1);
- tuple5.set(0, 50);
-
- DataBag bag1 = BagFactory.getInstance().newDefaultBag();
- bag1.add(tuple1);
- bag1.add(tuple2);
- bag1.add(tuple3);
-
- DataBag bag2 = BagFactory.getInstance().newDefaultBag();
- bag2.add(tuple4);
- bag2.add(tuple5);
-
- Tuple inputTuple1 = TupleFactory.getInstance().newTuple(1);
- inputTuple1.set(0,bag1);
-
- Tuple inputTuple2 = TupleFactory.getInstance().newTuple(1);
- inputTuple2.set(0,bag2);
-
- enumurate.accumulate(inputTuple1);
- enumurate.accumulate(inputTuple2);
- assertEquals(enumurate.getValue().toString(), "{(10,0),(20,1),(30,2),(40,3),(50,4)}");
- }
-
- @Test
public void comprehensiveBagSplitAndEnumerate() throws Exception
{
PigTest test = createPigTest("test/pig/datafu/test/pig/bags/comprehensiveBagSplitAndEnumerate.pig");
@@ -360,4 +286,81 @@
assertOutput(test, "data2",
"({(Z,1,0),(A,1,0),(B,2,0),(C,3,0),(D,4,0),(E,5,0)})");
}
+
+ @Test
+ public void weightedSampleTest() throws Exception
+ {
+ // the system property random seed is used to enable repeatable tests
+ System.getProperties().setProperty("pigunit.randseed", "1");
+ PigTest test = createPigTest("test/pig/datafu/test/pig/bags/weightedSampleTest.pig");
+
+ writeLinesToFile("input",
+ "({(a, 100),(b, 1),(c, 5),(d, 2)})");
+
+ test.runScript();
+
+ assertOutput(test, "data2",
+ "({(c,5),(a,100),(b,1),(d,2)})");
+ }
+
+ @Test
+ public void weightedSampleScoreTest() throws Exception
+ {
+ // the system property random seed is used to enable repeatable tests
+ System.getProperties().setProperty("pigunit.randseed", "1");
+ PigTest test = createPigTest("test/pig/datafu/test/pig/bags/weightedSampleScoreTest.pig");
+
+ writeLinesToFile("input",
+ "({(a, 100),(b, 1),(c, 5),(d, 2)})");
+
+ test.runScript();
+
+ assertOutput(test, "data2",
+ "({(a,100),(c,5),(b,1),(d,2)})");
+ }
+
+ @Test
+ public void weightedSampleScoreLimitTest() throws Exception
+ {
+ // the system property random seed is used to enable repeatable tests
+ System.getProperties().setProperty("pigunit.randseed", "1");
+ PigTest test = createPigTest("test/pig/datafu/test/pig/bags/weightedSampleScoreLimitTest.pig");
+
+ writeLinesToFile("input",
+ "({(a, 100),(b, 1),(c, 5),(d, 2)})");
+
+ test.runScript();
+
+ assertOutput(test, "data2",
+ "({(a,100),(c,5),(b,1)})");
+ }
+
+ @Test
+ public void countEachTest() throws Exception
+ {
+ PigTest test = createPigTest("test/pig/datafu/test/pig/bags/countEachTest.pig");
+
+ writeLinesToFile("input",
+ "({(A),(B),(A),(C),(A),(B)})");
+
+ test.runScript();
+
+ assertOutput(test, "data3",
+ "({((A),3),((B),2),((C),1)})");
+ }
+
+ @Test
+ public void countEachFlattenTest() throws Exception
+ {
+ PigTest test = createPigTest("test/pig/datafu/test/pig/bags/countEachFlattenTest.pig");
+
+ writeLinesToFile("input",
+ "({(A),(B),(A),(C),(A),(B)})");
+
+ test.runScript();
+
+ assertOutput(test, "data3",
+ "({(A,3),(B,2),(C,1)})");
+ }
+
}
diff --git a/test/pig/datafu/test/pig/bags/countEachFlattenTest.pig b/test/pig/datafu/test/pig/bags/countEachFlattenTest.pig
new file mode 100644
index 0000000..1db7651
--- /dev/null
+++ b/test/pig/datafu/test/pig/bags/countEachFlattenTest.pig
@@ -0,0 +1,16 @@
+register $JAR_PATH
+
+define CountEach datafu.pig.bags.CountEach('flatten');
+
+data = LOAD 'input' AS (data: bag {T: tuple(v1:chararray)});
+
+data2 = FOREACH data GENERATE CountEach(data) as counted;
+describe data2;
+
+data3 = FOREACH data2 {
+ ordered = ORDER counted BY count DESC;
+ GENERATE ordered;
+}
+describe data3;
+
+STORE data3 INTO 'output';
diff --git a/test/pig/datafu/test/pig/bags/countEachTest.pig b/test/pig/datafu/test/pig/bags/countEachTest.pig
new file mode 100644
index 0000000..ef83284
--- /dev/null
+++ b/test/pig/datafu/test/pig/bags/countEachTest.pig
@@ -0,0 +1,16 @@
+register $JAR_PATH
+
+define CountEach datafu.pig.bags.CountEach();
+
+data = LOAD 'input' AS (data: bag {T: tuple(v1:chararray)});
+
+data2 = FOREACH data GENERATE CountEach(data) as counted;
+describe data2;
+
+data3 = FOREACH data2 {
+ ordered = ORDER counted BY count DESC;
+ GENERATE ordered;
+}
+describe data3;
+
+STORE data3 INTO 'output';
diff --git a/test/pig/datafu/test/pig/bags/weightedSampleScoreLimitTest.pig b/test/pig/datafu/test/pig/bags/weightedSampleScoreLimitTest.pig
new file mode 100644
index 0000000..3bd8a18
--- /dev/null
+++ b/test/pig/datafu/test/pig/bags/weightedSampleScoreLimitTest.pig
@@ -0,0 +1,10 @@
+register $JAR_PATH
+
+define WeightedSample datafu.pig.bags.WeightedSample();
+
+data = LOAD 'input' AS (A: bag {T: tuple(v1:chararray,v2:INT)});
+
+data2 = FOREACH data GENERATE WeightedSample(A,1,3);
+describe data2;
+
+STORE data2 INTO 'output';
diff --git a/test/pig/datafu/test/pig/bags/weightedSampleScoreTest.pig b/test/pig/datafu/test/pig/bags/weightedSampleScoreTest.pig
new file mode 100644
index 0000000..e231a8d
--- /dev/null
+++ b/test/pig/datafu/test/pig/bags/weightedSampleScoreTest.pig
@@ -0,0 +1,10 @@
+register $JAR_PATH
+
+define WeightedSample datafu.pig.bags.WeightedSample();
+
+data = LOAD 'input' AS (A: bag {T: tuple(v1:chararray,v2:INT)});
+
+data2 = FOREACH data GENERATE WeightedSample(A,1);
+describe data2;
+
+STORE data2 INTO 'output';
diff --git a/test/pig/datafu/test/pig/bags/weightedSampleTest.pig b/test/pig/datafu/test/pig/bags/weightedSampleTest.pig
new file mode 100644
index 0000000..4890127
--- /dev/null
+++ b/test/pig/datafu/test/pig/bags/weightedSampleTest.pig
@@ -0,0 +1,10 @@
+register $JAR_PATH
+
+define WeightedSample datafu.pig.bags.WeightedSample();
+
+data = LOAD 'input' AS (A: bag {T: tuple(v1:chararray,v2:INT)});
+
+data2 = FOREACH data GENERATE WeightedSample(A);
+describe data2;
+
+STORE data2 INTO 'output';