blob: 6f2282374c805fef333a2bf8c94ca14f024f6001 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.pig.piggybank.evaluation;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.pig.EvalFunc;
import org.apache.pig.PigException;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.builtin.AVG;
import org.apache.pig.builtin.BigDecimalAvg;
import org.apache.pig.builtin.BigDecimalMax;
import org.apache.pig.builtin.BigDecimalMin;
import org.apache.pig.builtin.BigDecimalSum;
import org.apache.pig.builtin.COUNT;
import org.apache.pig.builtin.DoubleAvg;
import org.apache.pig.builtin.DoubleMax;
import org.apache.pig.builtin.DoubleMin;
import org.apache.pig.builtin.DoubleSum;
import org.apache.pig.builtin.FloatAvg;
import org.apache.pig.builtin.FloatMax;
import org.apache.pig.builtin.FloatMin;
import org.apache.pig.builtin.IntAvg;
import org.apache.pig.builtin.IntMax;
import org.apache.pig.builtin.IntMin;
import org.apache.pig.builtin.LongAvg;
import org.apache.pig.builtin.LongMax;
import org.apache.pig.builtin.LongMin;
import org.apache.pig.builtin.LongSum;
import org.apache.pig.builtin.MAX;
import org.apache.pig.builtin.MIN;
import org.apache.pig.builtin.StringMax;
import org.apache.pig.builtin.StringMin;
import org.apache.pig.builtin.SUM;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.DataType;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.logicalLayer.FrontendException;
import org.apache.pig.impl.logicalLayer.schema.Schema;
import org.apache.pig.impl.logicalLayer.schema.Schema.FieldSchema;
/**
* Given an aggregate function, a bag, and possibly a window definition,
* produce output that matches SQL OVER. It is the responsibility of the caller
* to have already ordered the bag as required by their operation.
* The aggregate, and window definition are passed in the constructor. The bag
* is passed to exec each time.
*
* <p>Usage: Over(bag, function_to_call[, window_start, window_end[, function specific args]])
*
* <p>bag - The bag to be called. Most functions assume this is a bag with tuples
* of a single field.
* <p>function_to_call - Can be one of the following: <ul>
* <li>count</li>
* <li>sum(double)</li>
* <li>sum(float)</li>
* <li>sum(int)</li>
* <li>sum(long)</li>
* <li>sum(bytearray)</li>
* <li>sum(bigdecimal)</li>
* <li>avg(double)</li>
* <li>avg(float)</li>
* <li>avg(long)</li>
* <li>avg(int)</li>
* <li>avg(bytearray)</li>
* <li>avg(bigdecimal)</li>
* <li>min(double)</li>
* <li>min(float)</li>
* <li>min(long)</li>
* <li>min(int)</li>
* <li>min(chararray)</li>
* <li>min(bytearray)</li>
* <li>min(bigdecimal)</li>
* <li>max(double)</li>
* <li>max(float)</li>
* <li>max(long)</li>
* <li>max(int)</li>
* <li>max(chararray)</li>
* <li>max(bytearray)</li>
* <li>max(bigdecimal)</li>
* <li>row_number</li>
* <li>first_value</li>
* <li>last_value</li>
* <li>lead</li>
* <li>lag</li>
* <li>rank</li>
* <li>dense_rank</li>
* <li>ntile</li>
* <li>percent_rank</li>
* <li>cume_dist</li>
* </ul>
* <p>window_start - optional - Record to start window on for the function. -1
* indicates 'unbounded preceding', i.e. the beginning of the bag. A positive
* integer indicates that number of records before the current record. 0
* indicates the current record. If not specified -1 is the default.
* <p>window_end - optional - Record to end window on for the function. -1
* indicates 'unbounded following', i.e. the end of the bag. A positive
* integer indicates that number of records after the current record. 0
* indicates the current record. If not specified 0 is the default.
* <p>function_specific_args - may be optional - The following functions accept
* or require additional arguments: <ul>
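* <li>lead - two optional arguments, first the number of records ahead
* of the current one to lead (default 1), second the default value returned
* when the lead extends beyond the end of the bag (null if not given).
* The window frame bounds are not consulted.</li>
* <li>lag - two optional arguments, first the number of records behind
* the current one to lag (default 1), second the default value returned
* when the lag extends beyond the beginning of the bag (null if not given).
* The window frame bounds are not consulted.</li>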
* <li>rank - one or more required, the numbers of the fields the bag is
* ordered by</li>
* <li>dense_rank - one or more required, the numbers of the fields the bag is
* ordered by</li>
* <li>ntile - one required, the number of buckets to split the data
* into</li>
* <li>percent_rank - one or more required, the numbers of the fields the bag
* is ordered by</li>
* <li>cume_dist - optionally the number of the field the bag is ordered by;
* it is currently ignored, so tied rows are not grouped together</li>
* </ul>
*
* <p>Example Usage:
* <p>To do a cumulative sum:
* <p><pre> A = load 'T' AS (si:chararray, i:int, d:long, f:float, s:chararray);
* C = foreach (group A by si) {
* Aord = order A by d;
* generate flatten(Stitch(Aord, Over(Aord.f, 'sum(float)')));
* }
* D = foreach C generate s, $5;</pre>
* <p> This is equivalent to the SQL statement
* <p><tt>select s, sum(f) over (partition by si order by d) from T;</tt>
*
* <p>To find the record 3 ahead of the current record, using a window between
* the current row and 3 records ahead and a default value of 0.
* <p><pre> A = load 'T' AS (si:chararray, i:int, d:long, f:float, s:chararray);
* C = foreach (group A by si) {
* Aord = order A by i;
* generate flatten(Stitch(Aord, Over(Aord.i, 'lead', 0, 3, 3, 0)));
* }
* D = foreach C generate s, $9;</pre>
* <p> This is equivalent to the SQL statement
* <p><tt>select s, lead(i, 3, 0) over (partition by si order by i rows between
* current row and 3 following) from T;</tt>
*
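* <p>Over accepts a constructor argument specifying the name and type,
* colon-separated, of its return schema. A bare type name (no colon) names the
* result field 'result'. If the argument is 'true' instead, the return schema
* is derived from the input: the first field of the input bag (descending
* through nested single-field tuples and bags) is used, keeping its type and
* appending '_over' to its alias.</p>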
*
* <p><pre>
* DEFINE IOver org.apache.pig.piggybank.evaluation.Over('state_rk:int');
* cities = LOAD 'cities' AS (city:chararray, state:chararray, pop:int);
* -- Decorate each city with its population rank within the state it belongs to:
* ranked = FOREACH(GROUP cities BY state) {
* c_ord = ORDER cities BY pop DESC;
* GENERATE FLATTEN(Stitch(c_ord,
* IOver(c_ord, 'rank', -1, -1, 2))); -- beginning (-1) to end (-1) on third field (2)
* };
* DESCRIBE ranked;
* -- ranked: {stitched::city: chararray,stitched::state: chararray,stitched::pop: int,stitched::state_rk: int}
* DUMP ranked;
* -- ...
* -- (Nashville,Tennessee,609644,2)
* -- (Houston,Texas,2145146,1)
* -- (San Antonio,Texas,1359758,2)
* -- (Dallas,Texas,1223229,3)
* -- (Austin,Texas,820611,4)
* -- ...
* </pre></p>
*/
public class Over extends EvalFunc<DataBag> {
/** Shared factory for result tuples and for OverBag's padding tuples. */
private static final TupleFactory mTupleFactory = TupleFactory.getInstance();
/** Frame start: -1 = unbounded preceding, otherwise rows before the current row. */
private int rowsBefore;
/** Frame end: -1 = unbounded following, otherwise rows after the current row. */
private int rowsAfter;
/** Window function name from arg 2, e.g. "sum(long)" or "rank". */
private String agg;
/** Whether init() has parsed the arguments and created {@link #func}. */
private boolean initialized;
/** The window function evaluated once per input row. */
private EvalFunc<? extends Object> func;
/** Function-specific arguments (input fields 5 and up), or null if none. */
private Object[] udfArgs;
/** Result type declared via the constructor typespec; UNKNOWN if none was given. */
private byte returnType;
/** Result field name declared via the constructor typespec. */
private String returnName;
/** When true, outputSchema derives the result field from the input schema. */
private boolean searchInnerType;
/**
 * Creates an Over with no declared return type; its output schema is then a
 * bag of an unspecified (null) type.
 */
public Over() {
    func = null;
    udfArgs = null;
    initialized = false;
    searchInnerType = false;
    returnType = DataType.UNKNOWN;
}
/**
 * Creates an Over with a declared return schema.
 *
 * @param typespec "name:type" (e.g. "state_rk:int"); "true" to derive the
 *        result field from the input bag's schema; or a bare type name, in
 *        which case the result field is called "result"
 */
public Over(String typespec) {
    this();
    if (typespec.contains(":")) {
        String[] nameAndType = typespec.split(":", 2);
        returnName = nameAndType[0];
        returnType = DataType.findTypeByName(nameAndType[1]);
        return;
    }
    if (Boolean.parseBoolean(typespec)) {
        // Only reached when the spec parsed as true.
        searchInnerType = true;
        return;
    }
    returnName = "result";
    returnType = DataType.findTypeByName(typespec);
}
/**
 * Evaluates the configured window function once for every row of the input
 * bag.
 *
 * @param input field 0: the bag, already ordered as the function requires;
 *        field 1: the function name; optional fields 2 and 3: frame start and
 *        end; any further fields: function-specific arguments. Fields 1 and
 *        up are only parsed on the first successful call and reused afterwards.
 * @return a bag with one single-field tuple per input row, in input order,
 *         holding the function's value for that row; null if the input bag
 *         itself is null
 * @throws IOException if fewer than two inputs are given, field 0 is not a
 *         bag, or the function arguments are invalid
 */
@Override
public DataBag exec(Tuple input) throws IOException {
    if (input == null || input.size() < 2) {
        int errCode = 2107; // TODO not sure this is the right one
        // input may be null here, so it must not be dereferenced blindly.
        String msg = "Over expected 2 or more inputs but received "
            + (input == null ? 0 : input.size());
        throw new ExecException(msg, errCode, PigException.INPUT);
    }
    Object bagArg = input.get(0);
    if (bagArg == null) {
        // Null bag in, null out, rather than a NullPointerException in OverBag.
        return null;
    }
    if (!(bagArg instanceof DataBag)) {
        int errCode = 2107; // TODO not sure this is the right one
        String msg = "Over expected a bag for arg 1 but received " +
            DataType.findTypeName(bagArg);
        throw new ExecException(msg, errCode, PigException.INPUT);
    }
    DataBag inbag = (DataBag) bagArg;

    if (!initialized) {
        init(input);
    } else if (func instanceof ResetableEvalFunc) {
        // Stateful functions track the row position; start over for this bag.
        ((ResetableEvalFunc<?>) func).reset();
    }

    // Copy the bag once into a special bag that we can offset into, so we
    // don't have to copy once for each row.
    OverBag tmpbag = new OverBag(inbag, rowsBefore, rowsAfter);
    Tuple tmptuple = mTupleFactory.newTuple(1);
    tmptuple.set(0, tmpbag);
    DataBag outbag = BagFactory.getInstance().newDefaultBag();
    long rowCount = inbag.size();
    for (int i = 0; i < rowCount; i++) {
        tmpbag.setCurrentRow(i);
        Tuple t = mTupleFactory.newTuple(1);
        t.set(0, func.exec(tmptuple));
        outbag.add(t);
    }
    return outbag;
}
/**
 * Declares the output schema: a bag of single-field tuples whose field is
 * chosen by the constructor argument.
 * <ul>
 * <li>"name:type" or a bare type: a field of that name ("result" if omitted)
 * and type;</li>
 * <li>"true": the first input field, descended through single-field tuples
 * and bags, with "_over" appended to its alias;</li>
 * <li>no argument: a bag of an unspecified (null) type.</li>
 * </ul>
 */
@Override
public Schema outputSchema(Schema inputSch) {
    try {
        FieldSchema field;
        if (searchInnerType) {
            if (inputSch == null || inputSch.getFields().isEmpty()) {
                // Nothing to derive the type from; fall back to an untyped bag.
                return Schema.generateNestedSchema(DataType.BAG, DataType.NULL);
            }
            field = findInnerField(inputSch.getField(0));
        } else if (returnType == DataType.UNKNOWN) {
            return Schema.generateNestedSchema(DataType.BAG, DataType.NULL);
        } else {
            field = new Schema.FieldSchema(returnName, returnType);
        }
        Schema outputTupleSchema = new Schema(field);
        return new Schema(new Schema.FieldSchema(getSchemaName(this
                .getClass().getName().toLowerCase(), inputSch),
                outputTupleSchema, DataType.BAG));
    } catch (FrontendException fe) {
        throw new RuntimeException("Unable to create nested schema", fe);
    }
}

/**
 * Descends from {@code start} through tuple/bag fields that have exactly one
 * inner field and returns a copy of the field reached.
 *
 * <p>If a field with several inner fields is met, that field is returned
 * unchanged. Otherwise the innermost field gets "_over" appended to its
 * alias. Uses only local state, so outputSchema can be called repeatedly
 * even if an earlier call failed.
 */
private static FieldSchema findInnerField(FieldSchema start)
        throws FrontendException {
    FieldSchema field = new FieldSchema(start);
    while (true) {
        Schema inner = field.schema;
        if (inner != null && inner.getFields().size() > 1) {
            // No single inner field to pick: keep this field as it is.
            return field;
        }
        boolean nested = field.type == DataType.TUPLE
            || field.type == DataType.BAG;
        if (nested && inner != null && !inner.getFields().isEmpty()) {
            field = new FieldSchema(inner.getField(0));
        } else {
            // Leaf reached (or a tuple/bag without a usable inner schema).
            field.alias = field.alias + "_over";
            return field;
        }
    }
}
/**
 * Parses the function name (arg 2), the window bounds (args 3 and 4) and any
 * function-specific arguments (args 5 and up), and creates the window
 * function.
 *
 * <p>{@code initialized} is set only once everything has succeeded. An
 * invalid argument therefore does not leave the UDF marked initialized with a
 * null {@code func}.
 *
 * @throws ExecException if an argument has the wrong type or the function
 *         name is unknown
 */
private void init(Tuple input) throws IOException {
    // Look for the aggregate in arg 2; null falls through to "Unknown aggregate".
    Object aggArg = input.get(1);
    if (aggArg != null && !(aggArg instanceof String)) {
        int errCode = 2107; // TODO not sure this is the right one
        String msg = "Over expected a string for arg 2 but received " +
            DataType.findTypeName(aggArg);
        throw new ExecException(msg, errCode, PigException.INPUT);
    }
    agg = (String) aggArg;
    // Frame start; unbounded preceding (-1) is the default.
    rowsBefore = getIntArg(input, 2, -1);
    // Frame end; the current row (0) is the default.
    rowsAfter = getIntArg(input, 3, 0);
    // Any additional arguments are handed to the window function's
    // constructor.
    if (input.size() > 4) {
        udfArgs = new Object[input.size() - 4];
        for (int i = 0; i < udfArgs.length; i++) {
            udfArgs[i] = input.get(i + 4);
        }
    }
    func = createFunc(agg, udfArgs);
    initialized = true;
}

/** Returns field {@code pos} of input as an int, or dflt if the tuple is shorter. */
private static int getIntArg(Tuple input, int pos, int dflt)
        throws ExecException {
    if (input.size() <= pos) {
        return dflt;
    }
    Object value = input.get(pos);
    if (!(value instanceof Integer)) {
        int errCode = 2107; // TODO not sure this is the right one
        String msg = "Over expected an integer for arg " + (pos + 1)
            + " but received " + DataType.findTypeName(value);
        throw new ExecException(msg, errCode, PigException.INPUT);
    }
    return (Integer) value;
}

/**
 * Maps a function name (case-insensitive) to a new window function instance.
 *
 * @param agg  function name as given to Over, e.g. "sum(long)" or "rank"
 * @param args function-specific arguments, or null if none were given
 * @throws ExecException if agg names no known function or args are invalid
 */
private static EvalFunc<? extends Object> createFunc(String agg,
        Object[] args) throws IOException {
    if ("count".equalsIgnoreCase(agg)) {
        return new COUNT();
    } else if ("sum(double)".equalsIgnoreCase(agg) ||
            "sum(float)".equalsIgnoreCase(agg)) {
        return new DoubleSum();
    } else if ("sum(int)".equalsIgnoreCase(agg) ||
            "sum(long)".equalsIgnoreCase(agg)) {
        return new LongSum();
    } else if ("sum(bytearray)".equalsIgnoreCase(agg)) {
        return new SUM();
    } else if ("sum(bigdecimal)".equalsIgnoreCase(agg)) {
        return new BigDecimalSum();
    } else if ("avg(double)".equalsIgnoreCase(agg)) {
        return new DoubleAvg();
    } else if ("avg(float)".equalsIgnoreCase(agg)) {
        return new FloatAvg();
    } else if ("avg(long)".equalsIgnoreCase(agg)) {
        return new LongAvg();
    } else if ("avg(int)".equalsIgnoreCase(agg)) {
        return new IntAvg();
    } else if ("avg(bytearray)".equalsIgnoreCase(agg)) {
        return new AVG();
    } else if ("avg(bigdecimal)".equalsIgnoreCase(agg)) {
        return new BigDecimalAvg();
    } else if ("min(double)".equalsIgnoreCase(agg)) {
        return new DoubleMin();
    } else if ("min(float)".equalsIgnoreCase(agg)) {
        return new FloatMin();
    } else if ("min(long)".equalsIgnoreCase(agg)) {
        return new LongMin();
    } else if ("min(int)".equalsIgnoreCase(agg)) {
        return new IntMin();
    } else if ("min(chararray)".equalsIgnoreCase(agg)) {
        return new StringMin();
    } else if ("min(bytearray)".equalsIgnoreCase(agg)) {
        return new MIN();
    } else if ("min(bigdecimal)".equalsIgnoreCase(agg)) {
        return new BigDecimalMin();
    } else if ("max(double)".equalsIgnoreCase(agg)) {
        return new DoubleMax();
    } else if ("max(float)".equalsIgnoreCase(agg)) {
        return new FloatMax();
    } else if ("max(long)".equalsIgnoreCase(agg)) {
        return new LongMax();
    } else if ("max(int)".equalsIgnoreCase(agg)) {
        return new IntMax();
    } else if ("max(chararray)".equalsIgnoreCase(agg)) {
        return new StringMax();
    } else if ("max(bytearray)".equalsIgnoreCase(agg)) {
        return new MAX();
    } else if ("max(bigdecimal)".equalsIgnoreCase(agg)) {
        return new BigDecimalMax();
    } else if ("row_number".equalsIgnoreCase(agg)) {
        return new RowNumber();
    } else if ("first_value".equalsIgnoreCase(agg)) {
        return new FirstValue();
    } else if ("last_value".equalsIgnoreCase(agg)) {
        return new LastValue();
    } else if ("lead".equalsIgnoreCase(agg)) {
        return new Lead(args);
    } else if ("lag".equalsIgnoreCase(agg)) {
        return new Lag(args);
    } else if ("rank".equalsIgnoreCase(agg)) {
        return new Rank(args);
    } else if ("dense_rank".equalsIgnoreCase(agg)) {
        return new DenseRank(args);
    } else if ("ntile".equalsIgnoreCase(agg)) {
        return new Ntile(args);
    } else if ("percent_rank".equalsIgnoreCase(agg)) {
        return new PercentRank(args);
    } else if ("cume_dist".equalsIgnoreCase(agg)) {
        // CumeDist takes no arguments; any given are ignored.
        return new CumeDist();
    } else if ("debug".equalsIgnoreCase(agg)) {
        return new Debug();
    }
    throw new ExecException("Unknown aggregate " + agg, 2107,
        PigException.INPUT);
}
/**
 * Read-only DataBag that exposes the window frame of one "current row" of an
 * input bag.
 *
 * <p>The input is copied once into a list. exec() then moves the current row
 * with setCurrentRow() instead of building a new bag for every row. The frame
 * starts at the first row when {@code before == -1}, otherwise
 * {@code before} rows before the current row. It ends at the last row when
 * {@code after == -1}, otherwise {@code after} rows after the current row.
 * Frame positions that fall outside the bag are counted by size(), and the
 * iterator yields them as new one-field tuples with no value set.
 */
static class OverBag implements DataBag {
    private List<Tuple> tuples;
    private int before;
    private int after;
    private int currentRow;

    OverBag(DataBag bag, int before, int after) {
        addAll(bag);
        this.before = before;
        this.after = after;
        currentRow = 0;
    }

    /** Number of frame positions, including any outside the bag. */
    public long size() {
        return endPosition() - startPosition();
    }

    public boolean isSorted() {
        return false; // don't actually know
    }

    public boolean isDistinct() {
        return false;
    }

    /** Iterates over the frame of the current row. */
    public Iterator<Tuple> iterator() {
        return new OverBagIterator(tuples, currentRow, startPosition(),
            endPosition());
    }

    public void add(Tuple t) {
        throw new UnsupportedOperationException("OverBag.add not implemented");
    }

    /** Replaces this bag's contents with a copy of b's tuples (used by the constructor). */
    public void addAll(DataBag b) {
        tuples = new ArrayList<Tuple>((int) b.size());
        for (Tuple t : b) {
            tuples.add(t);
        }
    }

    public void clear() {
        throw new UnsupportedOperationException("OverBag.clear not implemented");
    }

    public void markStale(boolean stale) {
        throw new UnsupportedOperationException(
            "OverBag.markStale not implemented");
    }

    public void readFields(java.io.DataInput in) {
        throw new UnsupportedOperationException(
            "OverBag.readFields not implemented");
    }

    public void write(java.io.DataOutput out) {
        throw new UnsupportedOperationException("OverBag.write not implemented");
    }

    public int compareTo(Object o) {
        throw new UnsupportedOperationException(
            "OverBag.compareTo not implemented");
    }

    void setCurrentRow(int cr) {
        currentRow = cr;
    }

    // First frame position (inclusive); may be negative.
    private int startPosition() {
        return (before == -1 ? 0 : currentRow - before);
    }

    // One past the last frame position; may exceed tuples.size().
    private int endPosition() {
        return (after == -1 ? tuples.size() : currentRow + after + 1);
    }

    /**
     * Iterator over frame positions [begin, end). Fields are read directly by
     * the window functions in Over.
     */
    static class OverBagIterator implements Iterator<Tuple> {
        List<Tuple> tuples;
        int currentRow; // from the UDFs perspective
        int begin;
        int end;
        int nextRow; // next row this iterator will return

        OverBagIterator(List<Tuple> tuples,
                int currentRow,
                int begin,
                int end) {
            this.tuples = tuples;
            this.currentRow = currentRow;
            this.begin = begin;
            this.end = end;
            nextRow = begin;
        }

        public boolean hasNext() {
            return nextRow < end;
        }

        /**
         * Returns the next frame position's tuple. Positions outside the bag
         * yield a new one-field tuple with no value set.
         *
         * @throws java.util.NoSuchElementException past the end of the frame
         */
        public Tuple next() {
            if (!hasNext()) {
                throw new java.util.NoSuchElementException(
                    "OverBagIterator is past the end of the window frame");
            }
            int row = nextRow++;
            if (row < 0 || row >= tuples.size()) {
                // Frame reaches before the beginning or past the end of the bag.
                return mTupleFactory.newTuple(1);
            }
            return tuples.get(row);
        }

        public void remove() {
            throw new UnsupportedOperationException(
                "OverBagIterator.remove not implemented");
        }
    }

    public long spill() {
        return 0;
    }

    public long getMemorySize() {
        return 0;
    }
}
/**
 * Base class for window functions that carry a row position from one exec()
 * call to the next. Over.exec() calls reset() before each bag after the first.
 * For the first bag, the constructor's reset() establishes the initial state.
 */
private static abstract class ResetableEvalFunc<K> extends EvalFunc<K> {
    /** Row position maintained by the subclass across exec() calls. */
    protected int currentRow;

    ResetableEvalFunc() {
        // Runs before subclass fields are assigned, which is why subclasses
        // such as Lead and Lag call reset() again at the end of their own
        // constructors.
        reset();
    }

    /** Restores the start-of-bag state; subclasses may pick another start row. */
    void reset() {
        this.currentRow = 0;
    }
}
/**
 * ROW_NUMBER(): 1 for the first row of the bag, 2 for the second, and so on.
 * Must be called exactly once per row, in order; it ignores its input.
 */
private static class RowNumber extends ResetableEvalFunc<Integer> {
    @Override
    public Integer exec(Tuple input) throws IOException {
        currentRow += 1;
        return currentRow;
    }
}
/**
 * FIRST_VALUE(): field 0 of the first entry of the current row's frame, or
 * null for an empty frame. If the frame starts before the bag, the first entry
 * is OverBag's padding tuple, so the result is null as well.
 */
private static class FirstValue extends EvalFunc<Object> {
    @Override
    public Object exec(Tuple input) throws IOException {
        DataBag frame = (DataBag) input.get(0);
        Object first = null;
        if (frame.size() != 0) {
            first = frame.iterator().next().get(0);
        }
        return first;
    }
}
/**
 * LAST_VALUE(): field 0 of the last entry of the current row's frame.
 *
 * <p>Frame positions outside the bag are treated the way OverBag's iterator
 * treats them, as empty padding. So when the frame ends past the last row
 * (e.g. window_end &gt; 0 near the end of the bag), the result is null
 * instead of an IndexOutOfBoundsException. This matches first_value, which
 * yields null when the frame starts before the first row.
 */
private static class LastValue extends EvalFunc<Object> {
    @Override
    public Object exec(Tuple input) throws IOException {
        DataBag inbag = (DataBag) input.get(0);
        OverBag.OverBagIterator iter =
            (OverBag.OverBagIterator) inbag.iterator();
        int last = iter.end - 1;
        if (last < iter.begin || last < 0 || last >= iter.tuples.size()) {
            // Empty frame, or its last position lies outside the bag.
            return null;
        }
        return iter.tuples.get(last).get(0);
    }
}
/**
 * LEAD(offset, default): field 0 of the row {@code rowsAhead} rows after the
 * current one (offset defaults to 1). When that row lies past the end of the
 * bag, returns {@code deflt} (null if not given). The frame bounds are not
 * consulted.
 *
 * <p>Must be called exactly once per row, in order: the position is tracked
 * in {@code currentRow}, not read from the frame.
 */
private static class Lead extends ResetableEvalFunc<Object> {
    int rowsAhead;
    Object deflt;

    Lead(Object[] args) throws IOException {
        rowsAhead = 1;
        deflt = null;
        int argCount = (args == null) ? 0 : args.length;
        if (argCount >= 1) {
            Object offset = args[0];
            try {
                rowsAhead = (Integer) offset;
            } catch (ClassCastException cce) {
                int errCode = 2107; // TODO not sure this is the right one
                String msg = "Lead expected an integer for arg 2 " +
                    " but received " + DataType.findTypeName(offset);
                throw new ExecException(msg, errCode, PigException.INPUT);
            }
        }
        if (argCount >= 2) {
            deflt = args[1];
        }
        // The superclass constructor ran reset() before rowsAhead was set.
        reset();
    }

    @Override
    public Object exec(Tuple input) throws IOException {
        DataBag inbag = (DataBag) input.get(0);
        List<Tuple> rows = ((OverBag.OverBagIterator) inbag.iterator()).tuples;
        if (currentRow >= rows.size()) {
            // The lead runs past the end of the bag.
            return deflt;
        }
        Tuple ahead = rows.get(currentRow);
        currentRow++;
        return ahead.get(0);
    }

    /** Starts rowsAhead rows past the first row of the bag. */
    @Override
    void reset() {
        currentRow = rowsAhead;
    }
}
/**
 * LAG(offset, default): field 0 of the row {@code rowsBehind} rows before the
 * current one (offset defaults to 1). When that row lies before the start of
 * the bag, returns {@code deflt} (null if not given). The frame bounds are not
 * consulted.
 *
 * <p>Must be called exactly once per row, in order: the position is tracked
 * in {@code currentRow}, not read from the frame.
 */
private static class Lag extends ResetableEvalFunc<Object> {
    int rowsBehind;
    Object deflt;

    Lag(Object[] args) throws IOException {
        rowsBehind = 1;
        deflt = null;
        int argCount = (args == null) ? 0 : args.length;
        if (argCount >= 1) {
            Object offset = args[0];
            try {
                rowsBehind = (Integer) offset;
            } catch (ClassCastException cce) {
                int errCode = 2107; // TODO not sure this is the right one
                String msg = "Lag expected an integer for arg 2 " +
                    " but received " + DataType.findTypeName(offset);
                throw new ExecException(msg, errCode, PigException.INPUT);
            }
        }
        if (argCount >= 2) {
            deflt = args[1];
        }
        // The superclass constructor ran reset() before rowsBehind was set.
        reset();
    }

    @Override
    public Object exec(Tuple input) throws IOException {
        DataBag inbag = (DataBag) input.get(0);
        List<Tuple> rows = ((OverBag.OverBagIterator) inbag.iterator()).tuples;
        // Advance for the next call regardless of which branch is taken.
        int behind = currentRow++;
        if (behind < 0) {
            // The lag reaches before the beginning of the bag.
            return deflt;
        }
        return rows.get(behind).get(0);
    }

    /** Starts rowsBehind rows before the first row of the bag. */
    @Override
    void reset() {
        currentRow = -rowsBehind;
    }
}
// Makes some serious assumptions about how many times it's called, don't
// call it any extra times.
/**
 * Shared logic for the rank-style functions (rank, dense_rank, percent_rank).
 *
 * <p>exec() must be called exactly once per row of the bag, in order. It
 * tracks its position in {@code currentRow} rather than reading it from the
 * frame. On each call it compares the ordering fields of the current row with
 * those of the last row that started a new key. Subclasses decide how the
 * rank advances on a key change (incrementRank) and how it is reported
 * (calculateRank).
 */
private static abstract class BaseRank<T> extends ResetableEvalFunc<T> {
    // Ordering-field values of the most recent row that started a new key.
    Object[] lastKey;
    // Field positions (from the function-specific args) to compare rows on.
    int[] orderFields;
    // Rank state maintained by incrementRank(); 0 means no row seen yet.
    int lastRankUsed;
    // Number of rows seen so far that share lastKey.
    int timesThisRankUsed;
    /**
     * @param args one or more Integer field positions the bag is ordered by
     * @throws ExecException if args is null/empty or holds a non-Integer
     */
    protected BaseRank(Object[] args) throws IOException {
        if (args == null || args.length < 1) {
            throw new ExecException(
                "Rank args must contain ordering column numbers, "
                + "e.g. rank(1, 2)", 2107, PigException.INPUT);
        }
        lastKey = new Object[args.length];
        orderFields = new int[args.length];
        for (int i = 0; i < args.length; i++) {
            try {
                orderFields[i] = (Integer)args[i];
            } catch (ClassCastException cce) {
                // NOTE: i is the zero-based index into the function-specific
                // args, not the position among all of Over's arguments.
                throw new ExecException(
                    "Rank expected column number in arg " + i +
                    " but received " + DataType.findTypeName(args[i]),
                    2107, PigException.INPUT);
            }
        }
        reset();
    }
    @Override
    public T exec(Tuple input) throws IOException {
        DataBag inbag = (DataBag)input.get(0);
        OverBag.OverBagIterator iter =
            (OverBag.OverBagIterator)inbag.iterator();
        if (lastRankUsed == 0) {
            // First record: remember its key and give it rank 1.
            for (int i = 0; i < lastKey.length; i++) {
                lastKey[i] = iter.tuples.get(0).get(orderFields[i]);
            }
            lastRankUsed = 1;
        } else {
            // Check to see if the keys have changed
            boolean keyChange = false;
            for (int i = 0; i < lastKey.length && !keyChange; i++) {
                Object currentKey =
                    iter.tuples.get(currentRow).get(orderFields[i]);
                // Two nulls count as equal; otherwise compare with equals().
                if (lastKey[i] == null) {
                    keyChange |= currentKey != null;
                } else {
                    keyChange |= !lastKey[i].equals(currentKey);
                }
            }
            if (keyChange) {
                // New key: advance the rank, then start counting its rows.
                incrementRank();
                timesThisRankUsed = 1;
                for (int i = 0; i < lastKey.length; i++) {
                    lastKey[i] =
                        iter.tuples.get(currentRow).get(orderFields[i]);
                }
            } else {
                timesThisRankUsed++;
            }
        }
        // Move on to the row the next call will examine.
        currentRow++;
        return calculateRank(iter);
    }
    @Override
    void reset() {
        super.reset();
        lastRankUsed = 0;
        timesThisRankUsed = 1;
    }
    // Advances lastRankUsed when the ordering key changes.
    abstract protected void incrementRank();
    // Converts the current rank state into this function's result value.
    abstract protected T calculateRank(OverBag.OverBagIterator iter);
}
private static class Rank extends BaseRank<Integer> {
Rank(Object[] args) throws IOException {
super(args);
}
protected void incrementRank() {
lastRankUsed += timesThisRankUsed;
}
protected Integer calculateRank(OverBag.OverBagIterator iter) {
return lastRankUsed;
}
}
private static class DenseRank extends BaseRank<Integer> {
DenseRank(Object[] args) throws IOException {
super(args);
}
protected void incrementRank() {
lastRankUsed++;
}
protected Integer calculateRank(OverBag.OverBagIterator iter) {
return lastRankUsed;
}
}
private static class PercentRank extends BaseRank<Double> {
PercentRank(Object[] args) throws IOException {
super(args);
}
protected void incrementRank() {
lastRankUsed += timesThisRankUsed;
}
protected Double calculateRank(OverBag.OverBagIterator iter) {
return ((double)lastRankUsed - 1.0 ) /
((double)iter.tuples.size() - 1.0);
}
}
/**
 * CUME_DIST(): the number of rows up to and including the current one,
 * divided by the number of rows in the bag. It takes no arguments and must be
 * called exactly once per row, in order.
 *
 * <p>NOTE(review): rows with equal ordering values are not treated as peers.
 * SQL CUME_DIST gives tied rows the same, larger value; here each row just
 * counts its own position.
 */
private static class CumeDist extends ResetableEvalFunc<Double> {
    @Override
    public Double exec(Tuple input) throws IOException {
        DataBag inbag = (DataBag) input.get(0);
        int bagRows = ((OverBag.OverBagIterator) inbag.iterator()).tuples.size();
        currentRow++;
        return (double) currentRow / bagRows;
    }
}
// Makes some serious assumptions about how many times it's called, don't
// call it any extra times.
/**
 * NTILE(n): splits the bag's rows, in order, into {@code numBuckets} buckets
 * numbered from 1. When there are at least as many buckets as rows, each row
 * gets its own bucket.
 *
 * @throws ExecException from the constructor unless exactly one positive
 *         Integer argument is given
 */
private static class Ntile extends ResetableEvalFunc<Integer> {
    int numBuckets;

    protected Ntile(Object[] args) throws IOException {
        if (args == null || args.length != 1) {
            throw new ExecException(
                "Ntile args must contain arg describing how to split data, "
                + "e.g. ntile(4)", 2107, PigException.INPUT);
        }
        if (!(args[0] instanceof Integer)) {
            throw new ExecException(
                "Ntile expected integer argument but received " +
                DataType.findTypeName(args[0]), 2107, PigException.INPUT);
        }
        numBuckets = (Integer) args[0];
        if (numBuckets <= 0) {
            throw new ExecException(
                "Ntile expected a positive number of buckets but received "
                + numBuckets, 2107, PigException.INPUT);
        }
        reset();
    }

    @Override
    public Integer exec(Tuple input) throws IOException {
        DataBag inbag = (DataBag) input.get(0);
        OverBag.OverBagIterator iter =
            (OverBag.OverBagIterator) inbag.iterator();
        int rows = iter.tuples.size();
        int val;
        if (numBuckets >= rows) {
            val = currentRow + 1;
        } else {
            // Multiply in long: currentRow * numBuckets can exceed
            // Integer.MAX_VALUE for large bags; the quotient is < numBuckets.
            val = (int) ((long) currentRow * numBuckets / rows) + 1;
        }
        currentRow++;
        return val;
    }
}
/**
 * Diagnostic function: prints the current row's frame bounds and contents to
 * standard output and always returns the string "bla".
 */
private static class Debug extends EvalFunc<String> {
    @Override
    public String exec(Tuple input) throws IOException {
        DataBag frameBag = (DataBag) input.get(0);
        OverBag.OverBagIterator frame =
            (OverBag.OverBagIterator) frameBag.iterator();
        System.out.println("Current row " + frame.currentRow + " begin "
            + frame.begin + " end " + frame.end + " nextRow " +
            frame.nextRow + " size " + frame.tuples.size());
        System.out.print("{");
        for (; frame.hasNext(); ) {
            Tuple row = frame.next();
            System.out.print((row == null ? "null" : row.toString()) + ",");
        }
        System.out.println("}");
        return "bla";
    }
}
}