blob: 02afa63595648a51e68f7550ea0810e079639cec [file] [log] [blame]
/*
* Copyright 2017, Yahoo! Inc.
* Licensed under the terms of the Apache License 2.0. See LICENSE file at the project root for terms.
*/
package com.yahoo.sketches.pig.sampling;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.DataType;
import org.apache.pig.data.Tuple;
import org.apache.pig.impl.logicalLayer.FrontendException;
import org.apache.pig.impl.logicalLayer.schema.Schema;
import com.yahoo.sketches.sampling.ReservoirItemsSketch;
import com.yahoo.sketches.sampling.ReservoirItemsUnion;
import com.yahoo.sketches.sampling.SamplingPigUtil;
/**
* This is a Pig UDF that unions reservoir samples. It implements
* the <tt>Accumulator</tt> interface for better performance. Input is
* assumed to come from the reservoir sampling UDF or to be in a compatible format:
* <tt>(n, k, {(samples)})</tt>
* where <tt>n</tt> is the total number of items presented to the sketch and <tt>k</tt> is the
* maximum size of the sketch.
*
* @author Jon Malkin
*/
public class ReservoirUnion extends AccumulatorEvalFunc<Tuple> {
  private static final int DEFAULT_TARGET_K = 1024;

  private final int maxK_;
  private ReservoirItemsUnion<Tuple> union_;

  /**
   * Reservoir sampling constructor.
   * @param kStr String indicating the maximum number of desired entries in the reservoir.
   * @throws NumberFormatException if <tt>kStr</tt> is not a parseable integer
   * @throws IllegalArgumentException if the parsed value is less than 2
   */
  public ReservoirUnion(final String kStr) {
    maxK_ = Integer.parseInt(kStr);

    if (maxK_ < 2) {
      throw new IllegalArgumentException("ReservoirUnion requires max reservoir size >= 2: "
              + maxK_);
    }
  }

  /**
   * Package-private constructor using a default maximum reservoir size of
   * <tt>DEFAULT_TARGET_K</tt> (1024).
   */
  ReservoirUnion() { maxK_ = DEFAULT_TARGET_K; }

  // We could overload exec() for easy cases, but we still need to compare the incoming
  // reservoir's k vs max k and possibly downsample.

  /**
   * Adds every reservoir in the bag at field 0 of <tt>inputTuple</tt> to the running union.
   * Each bag element must be a tuple <tt>(n, k, {(samples)})</tt>. A null or empty input
   * tuple, or a null field 0, is ignored.
   *
   * @param inputTuple tuple whose first field is a bag of reservoir tuples
   * @throws IOException (as <tt>ExecException</tt>) if a reservoir tuple is null or has fewer
   *         than 3 fields
   */
  @Override
  public void accumulate(final Tuple inputTuple) throws IOException {
    if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
      return;
    }

    final DataBag reservoirs = (DataBag) inputTuple.get(0);

    if (union_ == null) {
      union_ = ReservoirItemsUnion.newInstance(maxK_);
    }

    try {
      for (final Tuple t : reservoirs) {
        // Reject a null tuple explicitly so it fails the same way as a short tuple
        // (ExecException) rather than surfacing as a bare NullPointerException.
        if (t == null) {
          throw new ExecException("Cannot update union with given reservoir");
        }
        // t.size() < 3 surfaces as IndexOutOfBoundsException from get(), wrapped below
        final long n = (long) t.get(0);
        final int k = (int) t.get(1);
        final DataBag sampleBag = (DataBag) t.get(2);
        final ArrayList<Tuple> samples = ReservoirSampling.dataBagToArrayList(sampleBag);
        union_.update(n, k, samples);
      }
    } catch (final IndexOutOfBoundsException e) {
      throw new ExecException("Cannot update union with given reservoir", e);
    }
  }

  /**
   * Returns the current union result as a tuple <tt>(n, k, {(samples)})</tt>.
   *
   * @return the result tuple, or null if nothing has been accumulated or the union
   *         produced no result
   */
  @Override
  public Tuple getValue() {
    if (union_ == null) {
      return null;
    }

    final ReservoirItemsSketch<Tuple> resultSketch = union_.getResult();
    // The union is created as soon as a non-null bag arrives, even an empty one, so it
    // may never have been updated. NOTE(review): assumes getResult() returns null in
    // that case -- confirm against ReservoirItemsUnion.
    if (resultSketch == null) {
      return null;
    }

    // newDefaultBag(List<Tuple>) does *not* copy values
    final List<Tuple> data = SamplingPigUtil.getRawSamplesAsList(resultSketch);
    final DataBag sampleBag = BagFactory.getInstance().newDefaultBag(data);

    return ReservoirSampling.createResultTuple(resultSketch.getN(), resultSketch.getK(), sampleBag);
  }

  /**
   * Discards the accumulated union so the next group starts fresh.
   */
  @Override
  public void cleanup() {
    union_ = null;
  }

  /**
   * Validates format of input schema and returns a matching schema
   * @param input Expects input to be a bag of sketches: <tt>(n, k, {(samples...)})</tt>
   * @return Schema wrapping the reservoir tuple schema <tt>(long, int, bag)</tt> found in the
   *         input, or null if the input schema is missing or does not match that layout
   */
  @Override
  public Schema outputSchema(final Schema input) {
    if (input != null && input.size() > 0) {
      try {
        Schema source = input;

        // if we have a bag, grab one level down to get a tuple
        if (source.size() == 1 && source.getField(0).type == DataType.BAG) {
          source = source.getField(0).schema;
        }

        // a bag/tuple field may carry no nested schema; don't dereference null
        if (source != null && source.size() == 1
                && source.getField(0).type == DataType.TUPLE) {
          source = source.getField(0).schema;
        }

        if (source == null) {
          return null;
        }

        final List<Schema.FieldSchema> fields = source.getFields();
        if (fields.size() == 3
                && fields.get(0).type == DataType.LONG
                && fields.get(1).type == DataType.INTEGER
                && fields.get(2).type == DataType.BAG) {
          return new Schema(new Schema.FieldSchema(getSchemaName(this
                  .getClass().getName().toLowerCase(), source), source, DataType.TUPLE));
        }
      } catch (final FrontendException e) {
        throw new RuntimeException(e);
      }
    }

    return null;
  }
}