blob: 29b209912aed12cc34c9366ef5f4f420e4eed39d [file] [log] [blame]
/*
* Copyright 2013 LinkedIn, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package datafu.pig.bags;
import java.io.IOException;
import java.util.Random;
import org.apache.pig.EvalFunc;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.DataType;
import org.apache.pig.data.Tuple;
import org.apache.pig.impl.logicalLayer.FrontendException;
import org.apache.pig.impl.logicalLayer.schema.Schema;
/**
* Create a new bag by performing a weighted sampling without replacement
* from the input bag. Optionally takes two additional arguments to specify
* the index of the column to use as a scoring column (uses enumerated bag
* order if not specified) and a limit on the number of items to return.
* n.b.
* <ul>
* <li>When no scoring column is specified, items from the top of the bag are
* more likely to be chosen than items from the bottom.
* <li>High scores are more likely to be chosen when using a scoring column.
* </ul>
*
* <p>
* Example:
* <pre>
* {@code
* define WeightedSample datafu.pig.bags.WeightedSample();
*
* -- input:
* -- ({(a,1),(b,2),(c,3),(d,4),(e,5)})
* input = LOAD 'input' AS (A: bag{T: tuple(name:chararray,score:int)});
*
* output1 = FOREACH input GENERATE WeightedSample(A);
* -- output1:
* -- no scoring column specified, so uses bag order
* -- ({(a,1),(b,2),(e,5),(d,4),(c,3)}) -- one possible random result
*
* -- sample using the second column (index 1) and keep only the top 3
* -- scoring column is specified, so bias towards higher scores
* -- only keep the first 3
* output2 = FOREACH input GENERATE WeightedSample(A,1,3);
* -- output2:
* -- ({(e,5),(d,4),(b,2)})
* }
* </pre>
*/
public class WeightedSample extends EvalFunc<DataBag>
{
  /** System property that, when set, supplies a fixed RNG seed so that tests are repeatable. */
  private static final String SEED_PROPERTY = "pigunit.randseed";

  BagFactory bagFactory = BagFactory.getInstance();

  public WeightedSample() {
  }

  /**
   * Draws a weighted sample, without replacement, from the bag in field 0 of {@code input}.
   *
   * <ul>
   * <li>Field 0: the bag to sample from. A null or empty bag yields an empty bag.
   * <li>Field 1 (optional): index of the numeric column inside each tuple that is used as
   * the weight. When absent, weights are assigned by position: the first tuple gets
   * {@code n + 1} and the last gets {@code 2}, where {@code n} is the bag size.
   * <li>Field 2 (optional, only read when exactly three fields are passed): maximum number
   * of tuples to return. Values below one produce an empty bag.
   * </ul>
   *
   * @param input the argument tuple described above
   * @return a bag holding the sampled tuples in the order they were drawn; a bag with a
   *         single tuple is returned as-is unless the limit excludes it
   * @throws IOException if a field of the input or of a sampled tuple cannot be read
   */
  @Override
  public DataBag exec(Tuple input) throws IOException {
    DataBag output = bagFactory.newDefaultBag();

    DataBag samples = (DataBag) input.get(0);
    if (samples == null || samples.size() == 0) {
      return output; // null or empty input yields an empty bag
    }

    int numSamples = (int) samples.size();
    if (numSamples == 1) {
      // Nothing to shuffle, but an explicit limit below one must still be honored so that
      // single-element bags behave like larger ones.
      if (input.size() == 3) {
        Object limitArg = input.get(2);
        if (limitArg instanceof Number && ((Number) limitArg).intValue() < 1) {
          return output;
        }
      }
      return samples;
    }

    // Materialize the bag once; both the scoring pass and the sampling loop index into it.
    Tuple[] tuples = new Tuple[numSamples];
    int tupleIndex = 0;
    for (Tuple tuple : samples) {
      tuples[tupleIndex++] = tuple;
    }

    double[] scores = new double[numSamples];
    if (input.size() <= 1) {
      // no scoring column specified, so weight by rank order (earlier tuples weigh more)
      for (int i = 0; i < scores.length; i++) {
        scores[i] = scores.length + 1 - i;
      }
    } else {
      int scoreIndex = ((Number) input.get(1)).intValue();
      for (int i = 0; i < numSamples; i++) {
        double score = ((Number) tuples[i].get(scoreIndex)).doubleValue();
        // Zero and negative scores break the cumulative distribution, and a single NaN
        // would turn the running total into NaN and silently make every draw uniform.
        // Clamp all of them to the smallest positive normal weight.
        scores[i] = (score > Double.MIN_NORMAL) ? score : Double.MIN_NORMAL;
      }
    }

    int limitSamples = numSamples;
    if (input.size() == 3) {
      // sample limit included
      limitSamples = Math.min(((Number) input.get(2)).intValue(), numSamples);
    }

    /*
     * How the sampling works:
     *
     * 1. Build a cumulative distribution of the remaining scores.
     * 2. Draw a random number in [0, 1).
     * 3. Find the interval of the distribution that contains the drawn number.
     * 4. Select the element owning that interval.
     * 5. Remove the selected element from consideration.
     * 6. Repeat steps 1-5 k times.
     *
     * Removing an element from the middle of an array is expensive, so step 5 instead
     * overwrites the selected slot with the element currently at the front of the live
     * range (index k) and then advances the front by one. The live range is always
     * [k, numSamples).
     *
     * This is an O(k*n) algorithm, where k is the number of elements to sample and n is
     * the number of scores.
     */
    // the system property random seed is used to enable repeatable tests
    String seedText = System.getProperty(SEED_PROPERTY);
    Random rng = (seedText != null) ? new Random(Long.parseLong(seedText)) : new Random();

    for (int k = 0; k < limitSamples; k++) {
      double val = rng.nextDouble();
      int idx = find_cumsum_interval(scores, val, k, numSamples);
      if (idx == numSamples) {
        // No interval matched (e.g. floating-point rounding); pick uniformly from the
        // live range instead.
        idx = rng.nextInt(numSamples - k) + k;
      }
      output.add(tuples[idx]);
      // Move the front element into the vacated slot; index k drops out of the live range.
      scores[idx] = scores[k];
      tuples[idx] = tuples[k];
    }
    return output;
  }

  /**
   * Finds the element of {@code scores[begin..end)} whose slice of the normalized cumulative
   * distribution contains {@code val}.
   *
   * @param scores weights of the candidate elements
   * @param val the drawn value to locate, compared against cumulative fractions of the total
   * @param begin first index (inclusive) of the range to consider
   * @param end last index (exclusive) of the range to consider
   * @return the first index {@code i} in the range for which the cumulative fraction of
   *         {@code scores[begin..i]} exceeds {@code val}, or {@code end} if there is none
   */
  public int find_cumsum_interval(double[] scores, double val, int begin, int end) {
    double sum = 0.0;
    for (int i = begin; i < end; i++) {
      sum += scores[i];
    }

    double cumsum = 0.0;
    for (int i = begin; i < end; i++) {
      cumsum += scores[i];
      if ((cumsum / sum) > val) {
        return i;
      }
    }
    return end;
  }

  /**
   * Declares the output as a bag with the same tuple schema as the input bag.
   *
   * @param input schema of the arguments; field 0 must be a bag
   * @return a schema with a single bag field that reuses the input bag's inner schema
   * @throws RuntimeException if field 0 is not a bag or the input schema cannot be read
   */
  @Override
  public Schema outputSchema(Schema input) {
    try {
      Schema.FieldSchema inputFieldSchema = input.getField(0);
      if (inputFieldSchema.type != DataType.BAG) {
        throw new RuntimeException("Expected a BAG as input");
      }
      return new Schema(new Schema.FieldSchema(getSchemaName(this.getClass().getName().toLowerCase(), input),
          inputFieldSchema.schema, DataType.BAG));
    } catch (FrontendException e) {
      // The cause travels with the rethrown exception, so no separate stack dump is needed.
      throw new RuntimeException(e);
    }
  }
}