blob: 441fa58d6eac989e7b0d4c135e302a497b39b8a2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.cql3.functions;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.apache.cassandra.cql3.CQL3Type;
import org.apache.cassandra.db.marshal.*;
import org.apache.cassandra.exceptions.InvalidRequestException;
/**
* Factory methods for aggregate functions.
*/
public abstract class AggregateFcts
{
public static Collection<AggregateFunction> all()
{
Collection<AggregateFunction> functions = new ArrayList<>();
functions.add(countRowsFunction);
// sum for primitives
functions.add(sumFunctionForByte);
functions.add(sumFunctionForShort);
functions.add(sumFunctionForInt32);
functions.add(sumFunctionForLong);
functions.add(sumFunctionForFloat);
functions.add(sumFunctionForDouble);
functions.add(sumFunctionForDecimal);
functions.add(sumFunctionForVarint);
functions.add(sumFunctionForCounter);
// avg for primitives
functions.add(avgFunctionForByte);
functions.add(avgFunctionForShort);
functions.add(avgFunctionForInt32);
functions.add(avgFunctionForLong);
functions.add(avgFunctionForFloat);
functions.add(avgFunctionForDouble);
functions.add(avgFunctionForDecimal);
functions.add(avgFunctionForVarint);
functions.add(avgFunctionForCounter);
// count, max, and min for all standard types
for (CQL3Type type : CQL3Type.Native.values())
{
if (type != CQL3Type.Native.VARCHAR) // varchar and text both mapping to UTF8Type
{
functions.add(AggregateFcts.makeCountFunction(type.getType()));
if (type != CQL3Type.Native.COUNTER)
{
functions.add(AggregateFcts.makeMaxFunction(type.getType()));
functions.add(AggregateFcts.makeMinFunction(type.getType()));
}
else
{
functions.add(AggregateFcts.maxFunctionForCounter);
functions.add(AggregateFcts.minFunctionForCounter);
}
}
}
return functions;
}
/**
* Checks if the specified function is the count rows (e.g. COUNT(*) or COUNT(1)) function.
*
* @param function the function to check
* @return <code>true</code> if the specified function is the count rows one, <code>false</code> otherwise.
*/
public static boolean isCountRows(Function function)
{
return function == countRowsFunction;
}
/**
* The function used to count the number of rows of a result set. This function is called when COUNT(*) or COUNT(1)
* is specified.
*/
public static final AggregateFunction countRowsFunction =
new NativeAggregateFunction("countRows", LongType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private long count;
public void reset()
{
count = 0;
}
public ByteBuffer compute(int protocolVersion)
{
return LongType.instance.decompose(count);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
count++;
}
};
}
};
/**
* The SUM function for decimal values.
*/
public static final AggregateFunction sumFunctionForDecimal =
new NativeAggregateFunction("sum", DecimalType.instance, DecimalType.instance)
{
@Override
public Aggregate newAggregate()
{
return new Aggregate()
{
private BigDecimal sum = BigDecimal.ZERO;
public void reset()
{
sum = BigDecimal.ZERO;
}
public ByteBuffer compute(int protocolVersion)
{
return ((DecimalType) returnType()).decompose(sum);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
BigDecimal number = DecimalType.instance.compose(value);
sum = sum.add(number);
}
};
}
};
/**
* The AVG function for decimal values.
*/
public static final AggregateFunction avgFunctionForDecimal =
new NativeAggregateFunction("avg", DecimalType.instance, DecimalType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private BigDecimal avg = BigDecimal.ZERO;
private int count;
public void reset()
{
count = 0;
avg = BigDecimal.ZERO;
}
public ByteBuffer compute(int protocolVersion)
{
return DecimalType.instance.decompose(avg);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
count++;
BigDecimal number = DecimalType.instance.compose(value);
// avg = avg + (value - sum) / count.
avg = avg.add(number.subtract(avg).divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN));
}
};
}
};
/**
* The SUM function for varint values.
*/
public static final AggregateFunction sumFunctionForVarint =
new NativeAggregateFunction("sum", IntegerType.instance, IntegerType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private BigInteger sum = BigInteger.ZERO;
public void reset()
{
sum = BigInteger.ZERO;
}
public ByteBuffer compute(int protocolVersion)
{
return ((IntegerType) returnType()).decompose(sum);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
BigInteger number = IntegerType.instance.compose(value);
sum = sum.add(number);
}
};
}
};
/**
* The AVG function for varint values.
*/
public static final AggregateFunction avgFunctionForVarint =
new NativeAggregateFunction("avg", IntegerType.instance, IntegerType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private BigInteger sum = BigInteger.ZERO;
private int count;
public void reset()
{
count = 0;
sum = BigInteger.ZERO;
}
public ByteBuffer compute(int protocolVersion)
{
if (count == 0)
return IntegerType.instance.decompose(BigInteger.ZERO);
return IntegerType.instance.decompose(sum.divide(BigInteger.valueOf(count)));
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
count++;
BigInteger number = ((BigInteger) argTypes().get(0).compose(value));
sum = sum.add(number);
}
};
}
};
/**
* The SUM function for byte values (tinyint).
*/
public static final AggregateFunction sumFunctionForByte =
new NativeAggregateFunction("sum", ByteType.instance, ByteType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private byte sum;
public void reset()
{
sum = 0;
}
public ByteBuffer compute(int protocolVersion)
{
return ((ByteType) returnType()).decompose(sum);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
Number number = ((Number) argTypes().get(0).compose(value));
sum += number.byteValue();
}
};
}
};
/**
* AVG function for byte values (tinyint).
*/
public static final AggregateFunction avgFunctionForByte =
new NativeAggregateFunction("avg", ByteType.instance, ByteType.instance)
{
public Aggregate newAggregate()
{
return new AvgAggregate(ByteType.instance)
{
public ByteBuffer compute(int protocolVersion) throws InvalidRequestException
{
return ByteType.instance.decompose((byte) computeInternal());
}
};
}
};
/**
* The SUM function for short values (smallint).
*/
public static final AggregateFunction sumFunctionForShort =
new NativeAggregateFunction("sum", ShortType.instance, ShortType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private short sum;
public void reset()
{
sum = 0;
}
public ByteBuffer compute(int protocolVersion)
{
return ((ShortType) returnType()).decompose(sum);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
Number number = ((Number) argTypes().get(0).compose(value));
sum += number.shortValue();
}
};
}
};
/**
* AVG function for for short values (smallint).
*/
public static final AggregateFunction avgFunctionForShort =
new NativeAggregateFunction("avg", ShortType.instance, ShortType.instance)
{
public Aggregate newAggregate()
{
return new AvgAggregate(ShortType.instance)
{
public ByteBuffer compute(int protocolVersion)
{
return ShortType.instance.decompose((short) computeInternal());
}
};
}
};
/**
* The SUM function for int32 values.
*/
public static final AggregateFunction sumFunctionForInt32 =
new NativeAggregateFunction("sum", Int32Type.instance, Int32Type.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private int sum;
public void reset()
{
sum = 0;
}
public ByteBuffer compute(int protocolVersion)
{
return ((Int32Type) returnType()).decompose(sum);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
Number number = ((Number) argTypes().get(0).compose(value));
sum += number.intValue();
}
};
}
};
/**
* AVG function for int32 values.
*/
public static final AggregateFunction avgFunctionForInt32 =
new NativeAggregateFunction("avg", Int32Type.instance, Int32Type.instance)
{
public Aggregate newAggregate()
{
return new AvgAggregate(Int32Type.instance)
{
public ByteBuffer compute(int protocolVersion)
{
return Int32Type.instance.decompose((int) computeInternal());
}
};
}
};
/**
* The SUM function for long values.
*/
public static final AggregateFunction sumFunctionForLong =
new NativeAggregateFunction("sum", LongType.instance, LongType.instance)
{
public Aggregate newAggregate()
{
return new LongSumAggregate();
}
};
/**
* AVG function for long values.
*/
public static final AggregateFunction avgFunctionForLong =
new NativeAggregateFunction("avg", LongType.instance, LongType.instance)
{
public Aggregate newAggregate()
{
return new AvgAggregate(LongType.instance)
{
public ByteBuffer compute(int protocolVersion)
{
return LongType.instance.decompose(computeInternal());
}
};
}
};
/**
* The SUM function for float values.
*/
public static final AggregateFunction sumFunctionForFloat =
new NativeAggregateFunction("sum", FloatType.instance, FloatType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private float sum;
public void reset()
{
sum = 0;
}
public ByteBuffer compute(int protocolVersion)
{
return ((FloatType) returnType()).decompose(sum);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
Number number = ((Number) argTypes().get(0).compose(value));
sum += number.floatValue();
}
};
}
};
/**
* AVG function for float values.
*/
public static final AggregateFunction avgFunctionForFloat =
new NativeAggregateFunction("avg", FloatType.instance, FloatType.instance)
{
public Aggregate newAggregate()
{
return new FloatAvgAggregate(FloatType.instance)
{
public ByteBuffer compute(int protocolVersion) throws InvalidRequestException
{
return FloatType.instance.decompose((float) computeInternal());
}
};
}
};
/**
* The SUM function for double values.
*/
public static final AggregateFunction sumFunctionForDouble =
new NativeAggregateFunction("sum", DoubleType.instance, DoubleType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private double sum;
public void reset()
{
sum = 0;
}
public ByteBuffer compute(int protocolVersion)
{
return ((DoubleType) returnType()).decompose(sum);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
Number number = ((Number) argTypes().get(0).compose(value));
sum += number.doubleValue();
}
};
}
};
/**
* Average aggregate for floating point umbers, using double arithmetics and Kahan's algorithm
* to calculate sum by default, switching to BigDecimal on sum overflow. Resulting number is
* converted to corresponding representation by concrete implementations.
*/
private static abstract class FloatAvgAggregate implements AggregateFunction.Aggregate
{
private double sum;
private double compensation;
private double simpleSum;
private int count;
private BigDecimal bigSum = null;
private boolean overflow = false;
private final AbstractType numberType;
public FloatAvgAggregate(AbstractType numberType)
{
this.numberType = numberType;
}
public void reset()
{
sum = 0;
compensation = 0;
simpleSum = 0;
count = 0;
bigSum = null;
overflow = false;
}
public double computeInternal()
{
if (count == 0)
return 0d;
if (overflow)
{
return bigSum.divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN).doubleValue();
}
else
{
// correctly compute final sum if it's NaN from consequently
// adding same-signed infinite values.
double tmp = sum + compensation;
if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
sum = simpleSum;
else
sum = tmp;
return sum / count;
}
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
count++;
double number = ((Number) numberType.compose(value)).doubleValue();
if (overflow)
{
bigSum = bigSum.add(BigDecimal.valueOf(number));
}
else
{
simpleSum += number;
double prev = sum;
double tmp = number - compensation;
double rounded = sum + tmp;
compensation = (rounded - sum) - tmp;
sum = rounded;
if (Double.isInfinite(sum) && !Double.isInfinite(number))
{
overflow = true;
bigSum = BigDecimal.valueOf(prev).add(BigDecimal.valueOf(number));
}
}
}
}
/**
* AVG function for double values.
*/
public static final AggregateFunction avgFunctionForDouble =
new NativeAggregateFunction("avg", DoubleType.instance, DoubleType.instance)
{
public Aggregate newAggregate()
{
return new FloatAvgAggregate(DoubleType.instance)
{
public ByteBuffer compute(int protocolVersion) throws InvalidRequestException
{
return DoubleType.instance.decompose(computeInternal());
}
};
}
};
/**
* The SUM function for counter column values.
*/
public static final AggregateFunction sumFunctionForCounter =
new NativeAggregateFunction("sum", CounterColumnType.instance, CounterColumnType.instance)
{
public Aggregate newAggregate()
{
return new LongSumAggregate();
}
};
/**
* AVG function for counter column values.
*/
public static final AggregateFunction avgFunctionForCounter =
new NativeAggregateFunction("avg", CounterColumnType.instance, CounterColumnType.instance)
{
public Aggregate newAggregate()
{
return new AvgAggregate(LongType.instance)
{
public ByteBuffer compute(int protocolVersion) throws InvalidRequestException
{
return CounterColumnType.instance.decompose(computeInternal());
}
};
}
};
/**
* The MIN function for counter column values.
*/
public static final AggregateFunction minFunctionForCounter =
new NativeAggregateFunction("min", CounterColumnType.instance, CounterColumnType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private Long min;
public void reset()
{
min = null;
}
public ByteBuffer compute(int protocolVersion)
{
return min != null ? LongType.instance.decompose(min) : null;
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
long lval = LongType.instance.compose(value);
if (min == null || lval < min)
min = lval;
}
};
}
};
/**
* MAX function for counter column values.
*/
public static final AggregateFunction maxFunctionForCounter =
new NativeAggregateFunction("max", CounterColumnType.instance, CounterColumnType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private Long max;
public void reset()
{
max = null;
}
public ByteBuffer compute(int protocolVersion)
{
return max != null ? LongType.instance.decompose(max) : null;
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
long lval = LongType.instance.compose(value);
if (max == null || lval > max)
max = lval;
}
};
}
};
/**
* Creates a MAX function for the specified type.
*
* @param inputType the function input and output type
* @return a MAX function for the specified type.
*/
public static AggregateFunction makeMaxFunction(final AbstractType<?> inputType)
{
return new NativeAggregateFunction("max", inputType, inputType)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private ByteBuffer max;
public void reset()
{
max = null;
}
public ByteBuffer compute(int protocolVersion)
{
return max;
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
if (max == null || returnType().compare(max, value) < 0)
max = value;
}
};
}
};
}
/**
* Creates a MIN function for the specified type.
*
* @param inputType the function input and output type
* @return a MIN function for the specified type.
*/
public static AggregateFunction makeMinFunction(final AbstractType<?> inputType)
{
return new NativeAggregateFunction("min", inputType, inputType)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private ByteBuffer min;
public void reset()
{
min = null;
}
public ByteBuffer compute(int protocolVersion)
{
return min;
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
if (min == null || returnType().compare(min, value) > 0)
min = value;
}
};
}
};
}
/**
* Creates a COUNT function for the specified type.
*
* @param inputType the function input type
* @return a COUNT function for the specified type.
*/
public static AggregateFunction makeCountFunction(AbstractType<?> inputType)
{
return new NativeAggregateFunction("count", LongType.instance, inputType)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private long count;
public void reset()
{
count = 0;
}
public ByteBuffer compute(int protocolVersion)
{
return ((LongType) returnType()).decompose(count);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
count++;
}
};
}
};
}
private static class LongSumAggregate implements AggregateFunction.Aggregate
{
private long sum;
public void reset()
{
sum = 0;
}
public ByteBuffer compute(int protocolVersion)
{
return LongType.instance.decompose(sum);
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
Number number = LongType.instance.compose(value);
sum += number.longValue();
}
}
/**
* Average aggregate class, collecting the sum using long arithmetics, falling back
* to BigInteger on long overflow. Resulting number is converted to corresponding
* representation by concrete implementations.
*/
private static abstract class AvgAggregate implements AggregateFunction.Aggregate
{
private long sum;
private int count;
private BigInteger bigSum = null;
private boolean overflow = false;
private final AbstractType numberType;
public AvgAggregate(AbstractType type)
{
this.numberType = type;
}
public void reset()
{
count = 0;
sum = 0L;
overflow = false;
bigSum = null;
}
long computeInternal()
{
if (overflow)
{
return bigSum.divide(BigInteger.valueOf(count)).longValue();
}
else
{
return count == 0 ? 0 : (sum / count);
}
}
public void addInput(int protocolVersion, List<ByteBuffer> values)
{
ByteBuffer value = values.get(0);
if (value == null)
return;
count++;
long number = ((Number) numberType.compose(value)).longValue();
if (overflow)
{
bigSum = bigSum.add(BigInteger.valueOf(number));
}
else
{
long prev = sum;
sum += number;
if (((prev ^ sum) & (number ^ sum)) < 0)
{
overflow = true;
bigSum = BigInteger.valueOf(prev).add(BigInteger.valueOf(number));
}
}
}
}
}