| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.cassandra.cql3.functions; |
| |
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;
import java.nio.ByteBuffer;
import java.util.List;

import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.ByteType;
import org.apache.cassandra.db.marshal.CounterColumnType;
import org.apache.cassandra.db.marshal.DecimalType;
import org.apache.cassandra.db.marshal.DoubleType;
import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.db.marshal.Int32Type;
import org.apache.cassandra.db.marshal.IntegerType;
import org.apache.cassandra.db.marshal.LongType;
import org.apache.cassandra.db.marshal.ShortType;
| |
| /** |
| * Factory methods for aggregate functions. |
| */ |
| public abstract class AggregateFcts |
| { |
| /** |
| * Checks if the specified function is the count rows (e.g. COUNT(*) or COUNT(1)) function. |
| * |
| * @param function the function to check |
| * @return <code>true</code> if the specified function is the count rows one, <code>false</code> otherwise. |
| */ |
| public static boolean isCountRows(Function function) |
| { |
| return function == countRowsFunction; |
| } |
| |
| /** |
| * The function used to count the number of rows of a result set. This function is called when COUNT(*) or COUNT(1) |
| * is specified. |
| */ |
| public static final AggregateFunction countRowsFunction = |
| new NativeAggregateFunction("countRows", LongType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private long count; |
| |
| public void reset() |
| { |
| count = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return ((LongType) returnType()).decompose(Long.valueOf(count)); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| count++; |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * The SUM function for decimal values. |
| */ |
| public static final AggregateFunction sumFunctionForDecimal = |
| new NativeAggregateFunction("sum", DecimalType.instance, DecimalType.instance) |
| { |
| @Override |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private BigDecimal sum = BigDecimal.ZERO; |
| |
| public void reset() |
| { |
| sum = BigDecimal.ZERO; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return ((DecimalType) returnType()).decompose(sum); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| BigDecimal number = ((BigDecimal) argTypes().get(0).compose(value)); |
| sum = sum.add(number); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * The AVG function for decimal values. |
| */ |
| public static final AggregateFunction avgFunctionForDecimal = |
| new NativeAggregateFunction("avg", DecimalType.instance, DecimalType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private BigDecimal sum = BigDecimal.ZERO; |
| |
| private int count; |
| |
| public void reset() |
| { |
| count = 0; |
| sum = BigDecimal.ZERO; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| if (count == 0) |
| return ((DecimalType) returnType()).decompose(BigDecimal.ZERO); |
| |
| return ((DecimalType) returnType()).decompose(sum.divide(BigDecimal.valueOf(count))); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| count++; |
| BigDecimal number = ((BigDecimal) argTypes().get(0).compose(value)); |
| sum = sum.add(number); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * The SUM function for varint values. |
| */ |
| public static final AggregateFunction sumFunctionForVarint = |
| new NativeAggregateFunction("sum", IntegerType.instance, IntegerType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private BigInteger sum = BigInteger.ZERO; |
| |
| public void reset() |
| { |
| sum = BigInteger.ZERO; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return ((IntegerType) returnType()).decompose(sum); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| BigInteger number = ((BigInteger) argTypes().get(0).compose(value)); |
| sum = sum.add(number); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * The AVG function for varint values. |
| */ |
| public static final AggregateFunction avgFunctionForVarint = |
| new NativeAggregateFunction("avg", IntegerType.instance, IntegerType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private BigInteger sum = BigInteger.ZERO; |
| |
| private int count; |
| |
| public void reset() |
| { |
| count = 0; |
| sum = BigInteger.ZERO; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| if (count == 0) |
| return ((IntegerType) returnType()).decompose(BigInteger.ZERO); |
| |
| return ((IntegerType) returnType()).decompose(sum.divide(BigInteger.valueOf(count))); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| count++; |
| BigInteger number = ((BigInteger) argTypes().get(0).compose(value)); |
| sum = sum.add(number); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * The SUM function for int32 values. |
| */ |
| public static final AggregateFunction sumFunctionForByte = |
| new NativeAggregateFunction("sum", ByteType.instance, ByteType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private byte sum; |
| |
| public void reset() |
| { |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return ((ByteType) returnType()).decompose(sum); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| Number number = ((Number) argTypes().get(0).compose(value)); |
| sum += number.byteValue(); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * AVG function for int32 values. |
| */ |
| public static final AggregateFunction avgFunctionForByte = |
| new NativeAggregateFunction("avg", ByteType.instance, ByteType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private byte sum; |
| |
| private int count; |
| |
| public void reset() |
| { |
| count = 0; |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| int avg = count == 0 ? 0 : sum / count; |
| |
| return ((ByteType) returnType()).decompose((byte) avg); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| count++; |
| Number number = ((Number) argTypes().get(0).compose(value)); |
| sum += number.byteValue(); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * The SUM function for int32 values. |
| */ |
| public static final AggregateFunction sumFunctionForShort = |
| new NativeAggregateFunction("sum", ShortType.instance, ShortType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private short sum; |
| |
| public void reset() |
| { |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return ((ShortType) returnType()).decompose(sum); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| Number number = ((Number) argTypes().get(0).compose(value)); |
| sum += number.shortValue(); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * AVG function for int32 values. |
| */ |
| public static final AggregateFunction avgFunctionForShort = |
| new NativeAggregateFunction("avg", ShortType.instance, ShortType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private short sum; |
| |
| private int count; |
| |
| public void reset() |
| { |
| count = 0; |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| int avg = count == 0 ? 0 : sum / count; |
| |
| return ((ShortType) returnType()).decompose((short) avg); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| count++; |
| Number number = ((Number) argTypes().get(0).compose(value)); |
| sum += number.shortValue(); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * The SUM function for int32 values. |
| */ |
| public static final AggregateFunction sumFunctionForInt32 = |
| new NativeAggregateFunction("sum", Int32Type.instance, Int32Type.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private int sum; |
| |
| public void reset() |
| { |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return ((Int32Type) returnType()).decompose(sum); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| Number number = ((Number) argTypes().get(0).compose(value)); |
| sum += number.intValue(); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * AVG function for int32 values. |
| */ |
| public static final AggregateFunction avgFunctionForInt32 = |
| new NativeAggregateFunction("avg", Int32Type.instance, Int32Type.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private int sum; |
| |
| private int count; |
| |
| public void reset() |
| { |
| count = 0; |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| int avg = count == 0 ? 0 : sum / count; |
| |
| return ((Int32Type) returnType()).decompose(avg); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| count++; |
| Number number = ((Number) argTypes().get(0).compose(value)); |
| sum += number.intValue(); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * The SUM function for long values. |
| */ |
| public static final AggregateFunction sumFunctionForLong = |
| new NativeAggregateFunction("sum", LongType.instance, LongType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new LongSumAggregate(); |
| } |
| }; |
| |
| /** |
| * AVG function for long values. |
| */ |
| public static final AggregateFunction avgFunctionForLong = |
| new NativeAggregateFunction("avg", LongType.instance, LongType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new LongAvgAggregate(); |
| } |
| }; |
| |
| /** |
| * The SUM function for float values. |
| */ |
| public static final AggregateFunction sumFunctionForFloat = |
| new NativeAggregateFunction("sum", FloatType.instance, FloatType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private float sum; |
| |
| public void reset() |
| { |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return ((FloatType) returnType()).decompose(sum); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| Number number = ((Number) argTypes().get(0).compose(value)); |
| sum += number.floatValue(); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * AVG function for float values. |
| */ |
| public static final AggregateFunction avgFunctionForFloat = |
| new NativeAggregateFunction("avg", FloatType.instance, FloatType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private float sum; |
| |
| private int count; |
| |
| public void reset() |
| { |
| count = 0; |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| float avg = count == 0 ? 0 : sum / count; |
| |
| return ((FloatType) returnType()).decompose(avg); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| count++; |
| Number number = ((Number) argTypes().get(0).compose(value)); |
| sum += number.floatValue(); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * The SUM function for double values. |
| */ |
| public static final AggregateFunction sumFunctionForDouble = |
| new NativeAggregateFunction("sum", DoubleType.instance, DoubleType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private double sum; |
| |
| public void reset() |
| { |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return ((DoubleType) returnType()).decompose(sum); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| Number number = ((Number) argTypes().get(0).compose(value)); |
| sum += number.doubleValue(); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * AVG function for double values. |
| */ |
| public static final AggregateFunction avgFunctionForDouble = |
| new NativeAggregateFunction("avg", DoubleType.instance, DoubleType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private double sum; |
| |
| private int count; |
| |
| public void reset() |
| { |
| count = 0; |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| double avg = count == 0 ? 0 : sum / count; |
| |
| return ((DoubleType) returnType()).decompose(avg); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| count++; |
| Number number = ((Number) argTypes().get(0).compose(value)); |
| sum += number.doubleValue(); |
| } |
| }; |
| } |
| }; |
| |
| /** |
| * The SUM function for counter column values. |
| */ |
| public static final AggregateFunction sumFunctionForCounter = |
| new NativeAggregateFunction("sum", CounterColumnType.instance, CounterColumnType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new LongSumAggregate(); |
| } |
| }; |
| |
| /** |
| * AVG function for counter column values. |
| */ |
| public static final AggregateFunction avgFunctionForCounter = |
| new NativeAggregateFunction("avg", CounterColumnType.instance, CounterColumnType.instance) |
| { |
| public Aggregate newAggregate() |
| { |
| return new LongAvgAggregate(); |
| } |
| }; |
| |
| /** |
| * Creates a MAX function for the specified type. |
| * |
| * @param inputType the function input and output type |
| * @return a MAX function for the specified type. |
| */ |
| public static AggregateFunction makeMaxFunction(final AbstractType<?> inputType) |
| { |
| return new NativeAggregateFunction("max", inputType, inputType) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private ByteBuffer max; |
| |
| public void reset() |
| { |
| max = null; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return max; |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| if (max == null || returnType().compare(max, value) < 0) |
| max = value; |
| } |
| }; |
| } |
| }; |
| } |
| |
| /** |
| * Creates a MIN function for the specified type. |
| * |
| * @param inputType the function input and output type |
| * @return a MIN function for the specified type. |
| */ |
| public static AggregateFunction makeMinFunction(final AbstractType<?> inputType) |
| { |
| return new NativeAggregateFunction("min", inputType, inputType) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private ByteBuffer min; |
| |
| public void reset() |
| { |
| min = null; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return min; |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| if (min == null || returnType().compare(min, value) > 0) |
| min = value; |
| } |
| }; |
| } |
| }; |
| } |
| |
| /** |
| * Creates a COUNT function for the specified type. |
| * |
| * @param inputType the function input type |
| * @return a COUNT function for the specified type. |
| */ |
| public static AggregateFunction makeCountFunction(AbstractType<?> inputType) |
| { |
| return new NativeAggregateFunction("count", LongType.instance, inputType) |
| { |
| public Aggregate newAggregate() |
| { |
| return new Aggregate() |
| { |
| private long count; |
| |
| public void reset() |
| { |
| count = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return ((LongType) returnType()).decompose(count); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| count++; |
| } |
| }; |
| } |
| }; |
| } |
| |
| private static class LongSumAggregate implements AggregateFunction.Aggregate |
| { |
| private long sum; |
| |
| public void reset() |
| { |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| return LongType.instance.decompose(sum); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| Number number = LongType.instance.compose(value); |
| sum += number.longValue(); |
| } |
| } |
| |
| private static class LongAvgAggregate implements AggregateFunction.Aggregate |
| { |
| private long sum; |
| |
| private int count; |
| |
| public void reset() |
| { |
| count = 0; |
| sum = 0; |
| } |
| |
| public ByteBuffer compute(int protocolVersion) |
| { |
| long avg = count == 0 ? 0 : sum / count; |
| |
| return LongType.instance.decompose(avg); |
| } |
| |
| public void addInput(int protocolVersion, List<ByteBuffer> values) |
| { |
| ByteBuffer value = values.get(0); |
| |
| if (value == null) |
| return; |
| |
| count++; |
| Number number = LongType.instance.compose(value); |
| sum += number.longValue(); |
| } |
| } |
| } |