blob: 1db8e39e8b2f6a4f216df5db0cd886445a8e41d8 [file] [log] [blame]
/*
* Copyright 2016, Yahoo! Inc.
* Licensed under the terms of the Apache License 2.0. See LICENSE file at the project root for terms.
*/
package com.yahoo.sketches.pig.sampling;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.Algebraic;
import org.apache.pig.EvalFunc;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.DataType;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.logicalLayer.FrontendException;
import org.apache.pig.impl.logicalLayer.schema.Schema;

import com.yahoo.sketches.sampling.ReservoirItemsSketch;
import com.yahoo.sketches.sampling.ReservoirItemsUnion;
import com.yahoo.sketches.sampling.SamplingPigUtil;
/**
* This is a Pig UDF that applies reservoir sampling to input tuples. It implements both
* the <tt>Accumulator</tt> and <tt>Algebraic</tt> interfaces for efficient performance.
*
* @author Jon Malkin
*/
public class ReservoirSampling extends AccumulatorEvalFunc<Tuple> implements Algebraic {
  // defined for test consistency
  static final String N_ALIAS = "n";
  static final String K_ALIAS = "k";
  static final String SAMPLES_ALIAS = "samples";

  // Reservoir size used by the no-argument constructors.
  private static final int DEFAULT_TARGET_K = 1024;

  // Maximum number of samples kept in the reservoir.
  private final int targetK_;

  // Sketch filled by accumulate(); null until the first non-null input bag arrives,
  // and reset to null by cleanup().
  private ReservoirItemsSketch<Tuple> reservoir_;

  /**
   * Reservoir sampling constructor.
   *
   * @param kStr String indicating the maximum number of desired entries in the reservoir.
   * @throws NumberFormatException if kStr is not a parseable integer
   * @throws IllegalArgumentException if the parsed reservoir size is less than 2
   */
  public ReservoirSampling(final String kStr) {
    targetK_ = parseTargetK(kStr);
  }

  /** Creates a sampler that uses the default reservoir size. */
  ReservoirSampling() {
    targetK_ = DEFAULT_TARGET_K;
  }

  /**
   * Samples the bag held in field 0 of the input tuple.
   *
   * <p>If the bag has no more than k entries, it is returned as-is inside the result tuple and
   * no sketch is built. Otherwise the work is delegated to {@link AccumulatorEvalFunc#exec}.
   *
   * @param inputTuple tuple whose first field is a bag of records to sample
   * @return a tuple of (n, k, samples), or null if the input or its first field is null/missing
   * @throws IOException if Pig fails to read the input tuple
   */
  @Override
  public Tuple exec(final Tuple inputTuple) throws IOException {
    if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
      return null;
    }

    final DataBag samples = (DataBag) inputTuple.get(0);

    // if entire input data fits in reservoir, shortcut result
    if (samples.size() <= targetK_) {
      return createResultTuple(samples.size(), targetK_, samples);
    }

    return super.exec(inputTuple);
  }

  /**
   * Feeds every tuple of the bag in field 0 into the running reservoir.
   * Inputs that are null, empty, or have a null first field are ignored.
   *
   * @param inputTuple tuple whose first field is a bag of records
   * @throws IOException if Pig fails to read the input tuple
   */
  @Override
  public void accumulate(final Tuple inputTuple) throws IOException {
    if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
      return;
    }

    final DataBag samples = (DataBag) inputTuple.get(0);

    if (reservoir_ == null) {
      reservoir_ = ReservoirItemsSketch.newInstance(targetK_);
    }

    for (Tuple t : samples) {
      reservoir_.update(t);
    }
  }

  /**
   * Returns the current reservoir contents.
   *
   * @return a tuple of (n, k, samples), or null if no non-null bag has been accumulated
   */
  @Override
  public Tuple getValue() {
    if (reservoir_ == null) {
      return null;
    }

    final List<Tuple> data = SamplingPigUtil.getRawSamplesAsList(reservoir_);
    final DataBag sampleBag = BagFactory.getInstance().newDefaultBag(data);
    return createResultTuple(reservoir_.getN(), reservoir_.getK(), sampleBag);
  }

  /** Discards the accumulated reservoir. */
  @Override
  public void cleanup() {
    reservoir_ = null;
  }

  /**
   * Builds the output schema: a tuple of (n: long, k: int, samples: bag). The sample bag
   * reuses the input's record schema. If the input is a single bag field, the schema one
   * level down is used.
   *
   * @param input schema of the UDF input
   * @return the output schema, or null if input is null or empty
   */
  @Override
  public Schema outputSchema(final Schema input) {
    if (input == null || input.size() == 0) {
      return null;
    }

    try {
      Schema source = input;

      // if we have a bag, grab one level down to get a tuple
      if (source.size() == 1 && source.getField(0).type == DataType.BAG) {
        source = source.getField(0).schema;
      }

      final Schema recordSchema = new Schema();
      recordSchema.add(new Schema.FieldSchema(N_ALIAS, DataType.LONG));
      recordSchema.add(new Schema.FieldSchema(K_ALIAS, DataType.INTEGER));
      // this should add a bag to the output
      recordSchema.add(new Schema.FieldSchema(SAMPLES_ALIAS, source, DataType.BAG));

      // Locale.ROOT keeps the alias stable regardless of the JVM default locale.
      final String baseName = this.getClass().getName().toLowerCase(Locale.ROOT);
      return new Schema(new Schema.FieldSchema(getSchemaName(baseName, source), recordSchema,
          DataType.TUPLE));
    } catch (final FrontendException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Packs a sampling result into a 3-field tuple (n, k, samples).
   *
   * @param n number of items seen by the stream
   * @param k reservoir size
   * @param samples bag of sampled tuples
   * @return the result tuple
   */
  static Tuple createResultTuple(final long n, final int k, final DataBag samples) {
    final Tuple output = TupleFactory.getInstance().newTuple(3);

    try {
      output.set(0, n);
      output.set(1, k);
      output.set(2, samples);
    } catch (final ExecException e) {
      throw new RuntimeException("Pig error: " + e.getMessage(), e);
    }

    return output;
  }

  /**
   * Parses and validates a reservoir-size argument shared by every constructor.
   *
   * @param kStr string form of the target reservoir size
   * @return the parsed reservoir size
   * @throws NumberFormatException if kStr is not a parseable integer
   * @throws IllegalArgumentException if the parsed size is less than 2
   */
  private static int parseTargetK(final String kStr) {
    final int k = Integer.parseInt(kStr);
    if (k < 2) {
      throw new IllegalArgumentException("ReservoirSampling requires target reservoir size >= 2: "
          + k);
    }
    return k;
  }

  @Override
  public String getInitial() {
    return Initial.class.getName();
  }

  @Override
  public String getIntermed() {
    return IntermediateFinal.class.getName();
  }

  @Override
  public String getFinal() {
    return IntermediateFinal.class.getName();
  }

  /** Map-side stage of the Algebraic implementation: samples one input bag. */
  public static class Initial extends EvalFunc<Tuple> {
    private final int targetK_;

    public Initial() {
      targetK_ = DEFAULT_TARGET_K;
    }

    /**
     * Map-side constructor for reservoir sampling UDF
     *
     * @param kStr String indicating the maximum number of desired entries in the reservoir.
     * @throws NumberFormatException if kStr is not a parseable integer
     * @throws IllegalArgumentException if the parsed reservoir size is less than 2
     */
    public Initial(final String kStr) {
      targetK_ = parseTargetK(kStr);
    }

    /**
     * Produces a partial reservoir (n, k, samples) for the bag in field 0.
     *
     * @param inputTuple tuple whose first field is a bag of records
     * @return partial result tuple, or null if the input or its first field is null/missing
     * @throws IOException if Pig fails to read the input tuple
     */
    @Override
    public Tuple exec(final Tuple inputTuple) throws IOException {
      if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
        return null;
      }

      final DataBag records = (DataBag) inputTuple.get(0);

      // Small inputs are passed through whole; n <= k tells later stages nothing was dropped.
      if (records.size() <= targetK_) {
        return createResultTuple(records.size(), targetK_, records);
      }

      final ReservoirItemsSketch<Tuple> reservoir = ReservoirItemsSketch.newInstance(targetK_);
      for (Tuple t : records) {
        reservoir.update(t);
      }

      // newDefaultBag(List<Tuple>) does *not* copy values
      final List<Tuple> data = SamplingPigUtil.getRawSamplesAsList(reservoir);
      final DataBag outputBag = BagFactory.getInstance().newDefaultBag(data);
      return createResultTuple(records.size(), reservoir.getK(), outputBag);
    }
  }

  /** Combiner/reducer stage of the Algebraic implementation: unions partial reservoirs. */
  public static class IntermediateFinal extends EvalFunc<Tuple> {
    private final int targetK_;

    public IntermediateFinal() {
      targetK_ = DEFAULT_TARGET_K;
    }

    /**
     * Combiner and reducer side constructor for reservoir sampling UDF
     *
     * @param kStr String indicating the maximum number of desired entries in the reservoir.
     * @throws NumberFormatException if kStr is not a parseable integer
     * @throws IllegalArgumentException if the parsed reservoir size is less than 2
     */
    public IntermediateFinal(final String kStr) {
      targetK_ = parseTargetK(kStr);
    }

    /**
     * Merges a bag of partial (n, k, samples) tuples into a single reservoir.
     *
     * @param inputTuple tuple whose first field is a bag of partial result tuples
     * @return merged (n, k, samples) tuple; (0, targetK, empty bag) when nothing was sampled;
     *     null if the input or its first field is null/missing
     * @throws IOException if Pig fails to read the input tuples
     */
    @Override
    public Tuple exec(final Tuple inputTuple) throws IOException {
      if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
        return null;
      }

      final ReservoirItemsUnion<Tuple> union = ReservoirItemsUnion.newInstance(targetK_);

      final DataBag outerBag = (DataBag) inputTuple.get(0);
      for (Tuple reservoir : outerBag) {
        // Initial returns null for null input; skip rather than dereference it.
        if (reservoir == null) {
          continue;
        }

        final long n = (long) reservoir.get(0);
        final int k = (int) reservoir.get(1);
        final DataBag samples = (DataBag) reservoir.get(2);

        if (n <= k && k <= targetK_) {
          // The bag holds every item of its stream, so the items can be fed in one at a time.
          for (Tuple t : samples) {
            union.update(t);
          }
        } else {
          union.update(n, k, dataBagToArrayList(samples));
        }
      }

      final ReservoirItemsSketch<Tuple> result = union.getResult();
      if (result == null) {
        // Nothing reached the union (e.g. all partial reservoirs were empty). Report an empty
        // sample, matching the empty-input shortcut in ReservoirSampling.exec.
        return createResultTuple(0, targetK_, BagFactory.getInstance().newDefaultBag());
      }

      final List<Tuple> data = SamplingPigUtil.getRawSamplesAsList(result);
      final DataBag sampleBag = BagFactory.getInstance().newDefaultBag(data);
      return createResultTuple(result.getN(), result.getK(), sampleBag);
    }
  }

  /**
   * Copies the tuples of a bag into a new ArrayList, preserving iteration order.
   *
   * @param bag source bag
   * @return list holding the bag's tuples
   */
  static ArrayList<Tuple> dataBagToArrayList(final DataBag bag) {
    final int arrayLength = (int) bag.size();
    final ArrayList<Tuple> output = new ArrayList<>(arrayLength);
    for (Tuple t : bag) {
      output.add(t);
    }
    return output;
  }
}