/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.cassandra.cql3.functions;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.apache.cassandra.cql3.CQL3Type;
import org.apache.cassandra.db.marshal.*;
import org.apache.cassandra.exceptions.InvalidRequestException;

/**
 * Factory methods for aggregate functions.
 */
public abstract class AggregateFcts
{
    public static Collection<AggregateFunction> all()
    {
        Collection<AggregateFunction> functions = new ArrayList<>();

        functions.add(countRowsFunction);

        // sum for primitives
        functions.add(sumFunctionForByte);
        functions.add(sumFunctionForShort);
        functions.add(sumFunctionForInt32);
        functions.add(sumFunctionForLong);
        functions.add(sumFunctionForFloat);
        functions.add(sumFunctionForDouble);
        functions.add(sumFunctionForDecimal);
        functions.add(sumFunctionForVarint);
        functions.add(sumFunctionForCounter);

        // avg for primitives
        functions.add(avgFunctionForByte);
        functions.add(avgFunctionForShort);
        functions.add(avgFunctionForInt32);
        functions.add(avgFunctionForLong);
        functions.add(avgFunctionForFloat);
        functions.add(avgFunctionForDouble);
        functions.add(avgFunctionForDecimal);
        functions.add(avgFunctionForVarint);
        functions.add(avgFunctionForCounter);

        // count, max, and min for all standard types
        for (CQL3Type type : CQL3Type.Native.values())
        {
            if (type != CQL3Type.Native.VARCHAR) // varchar and text both mapping to UTF8Type
            {
                functions.add(AggregateFcts.makeCountFunction(type.getType()));
                if (type != CQL3Type.Native.COUNTER)
                {
                    functions.add(AggregateFcts.makeMaxFunction(type.getType()));
                    functions.add(AggregateFcts.makeMinFunction(type.getType()));
                }
                else
                {
                    functions.add(AggregateFcts.maxFunctionForCounter);
                    functions.add(AggregateFcts.minFunctionForCounter);
                }
            }
        }

        return functions;
    }

    /**
     * Checks if the specified function is the count rows (e.g. COUNT(*) or COUNT(1)) function.
     *
     * @param function the function to check
     * @return <code>true</code> if the specified function is the count rows one, <code>false</code> otherwise.
     */
    public static boolean isCountRows(Function function)
    {
        return function == countRowsFunction;
    }

    /**
     * The function used to count the number of rows of a result set. This function is called when COUNT(*) or COUNT(1)
     * is specified.
     */
    public static final AggregateFunction countRowsFunction =
            new NativeAggregateFunction("countRows", LongType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private long count;

                        public void reset()
                        {
                            count = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return LongType.instance.decompose(count);
                        }

                        public void addInput(int protocolVersion, List<ByteBuffer> values)
                        {
                            count++;
                        }
                    };
                }
            };

    /**
     * The SUM function for decimal values.
     */
    public static final AggregateFunction sumFunctionForDecimal =
            new NativeAggregateFunction("sum", DecimalType.instance, DecimalType.instance)
            {
                @Override
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private BigDecimal sum = BigDecimal.ZERO;

                        public void reset()
                        {
                            sum = BigDecimal.ZERO;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((DecimalType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List<ByteBuffer> values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            BigDecimal number = DecimalType.instance.compose(value);
                            sum = sum.add(number);
                        }
                    };
                }
            };

    /**
     * The AVG function for decimal values.
     */
    public static final AggregateFunction avgFunctionForDecimal =
            new NativeAggregateFunction("avg", DecimalType.instance, DecimalType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private BigDecimal avg = BigDecimal.ZERO;

                        private int count;

                        public void reset()
                        {
                            count = 0;
                            avg = BigDecimal.ZERO;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return DecimalType.instance.decompose(avg);
                        }

                        public void addInput(int protocolVersion, List<ByteBuffer> values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            count++;
                            BigDecimal number = DecimalType.instance.compose(value);

                            // avg = avg + (value - sum) / count.
                            avg = avg.add(number.subtract(avg).divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN));
                        }
                    };
                }
            };


    /**
     * The SUM function for varint values.
     */
    public static final AggregateFunction sumFunctionForVarint =
            new NativeAggregateFunction("sum", IntegerType.instance, IntegerType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private BigInteger sum = BigInteger.ZERO;

                        public void reset()
                        {
                            sum = BigInteger.ZERO;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((IntegerType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List<ByteBuffer> values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            BigInteger number = IntegerType.instance.compose(value);
                            sum = sum.add(number);
                        }
                    };
                }
            };

    /**
     * The AVG function for varint values.
     */
    public static final AggregateFunction avgFunctionForVarint =
            new NativeAggregateFunction("avg", IntegerType.instance, IntegerType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private BigInteger sum = BigInteger.ZERO;

                        private int count;

                        public void reset()
                        {
                            count = 0;
                            sum = BigInteger.ZERO;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            if (count == 0)
                                return IntegerType.instance.decompose(BigInteger.ZERO);

                            return IntegerType.instance.decompose(sum.divide(BigInteger.valueOf(count)));
                        }

                        public void addInput(int protocolVersion, List<ByteBuffer> values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            count++;
                            BigInteger number = ((BigInteger) argTypes().get(0).compose(value));
                            sum = sum.add(number);
                        }
                    };
                }
            };

    /**
     * The SUM function for byte values (tinyint).
     */
    public static final AggregateFunction sumFunctionForByte =
            new NativeAggregateFunction("sum", ByteType.instance, ByteType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private byte sum;

                        public void reset()
                        {
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((ByteType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List<ByteBuffer> values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.byteValue();
                        }
                    };
                }
            };

    /**
     * AVG function for byte values (tinyint).
     */
    public static final AggregateFunction avgFunctionForByte =
            new NativeAggregateFunction("avg", ByteType.instance, ByteType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new AvgAggregate(ByteType.instance)
                    {
                        public ByteBuffer compute(int protocolVersion) throws InvalidRequestException
                        {
                            return ByteType.instance.decompose((byte) computeInternal());
                        }
                    };
                }
            };

    /**
     * The SUM function for short values (smallint).
     */
    public static final AggregateFunction sumFunctionForShort =
            new NativeAggregateFunction("sum", ShortType.instance, ShortType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private short sum;

                        public void reset()
                        {
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((ShortType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List<ByteBuffer> values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.shortValue();
                        }
                    };
                }
            };

    /**
     * AVG function for for short values (smallint).
     */
    public static final AggregateFunction avgFunctionForShort =
            new NativeAggregateFunction("avg", ShortType.instance, ShortType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new AvgAggregate(ShortType.instance)
                    {
                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ShortType.instance.decompose((short) computeInternal());
                        }
                    };
                }
            };

    /**
     * The SUM function for int32 values.
     */
    public static final AggregateFunction sumFunctionForInt32 =
            new NativeAggregateFunction("sum", Int32Type.instance, Int32Type.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private int sum;

                        public void reset()
                        {
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((Int32Type) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List<ByteBuffer> values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.intValue();
                        }
                    };
                }
            };

    /**
     * AVG function for int32 values.
     */
    public static final AggregateFunction avgFunctionForInt32 =
            new NativeAggregateFunction("avg", Int32Type.instance, Int32Type.instance)
            {
                public Aggregate newAggregate()
                {
                    return new AvgAggregate(Int32Type.instance)
                    {
                        public ByteBuffer compute(int protocolVersion)
                        {
                            return Int32Type.instance.decompose((int) computeInternal());
                        }
                    };
                }
            };

    /**
     * The SUM function for long values.
     */
    public static final AggregateFunction sumFunctionForLong =
            new NativeAggregateFunction("sum", LongType.instance, LongType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new LongSumAggregate();
                }
            };

    /**
     * AVG function for long values.
     */
    public static final AggregateFunction avgFunctionForLong =
            new NativeAggregateFunction("avg", LongType.instance, LongType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new AvgAggregate(LongType.instance)
                    {
                        public ByteBuffer compute(int protocolVersion)
                        {
                            return LongType.instance.decompose(computeInternal());
                        }
                    };
                }
            };

    /**
     * The SUM function for float values.
     */
    public static final AggregateFunction sumFunctionForFloat =
            new NativeAggregateFunction("sum", FloatType.instance, FloatType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private float sum;

                        public void reset()
                        {
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((FloatType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List<ByteBuffer> values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.floatValue();
                        }
                    };
                }
            };

    /**
     * AVG function for float values.
     */
    public static final AggregateFunction avgFunctionForFloat =
            new NativeAggregateFunction("avg", FloatType.instance, FloatType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new FloatAvgAggregate(FloatType.instance)
                    {
                        public ByteBuffer compute(int protocolVersion) throws InvalidRequestException
                        {
                            return FloatType.instance.decompose((float) computeInternal());
                        }
                    };
                }
            };

    /**
     * The SUM function for double values.
     */
    public static final AggregateFunction sumFunctionForDouble =
            new NativeAggregateFunction("sum", DoubleType.instance, DoubleType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private double sum;

                        public void reset()
                        {
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((DoubleType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List<ByteBuffer> values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.doubleValue();
                        }
                    };
                }
            };
    /**
     * Average aggregate for floating point umbers, using double arithmetics and Kahan's algorithm
     * to calculate sum by default, switching to BigDecimal on sum overflow. Resulting number is
     * converted to corresponding representation by concrete implementations.
     */
    private static abstract class FloatAvgAggregate implements AggregateFunction.Aggregate
    {
        private double sum;
        private double compensation;
        private double simpleSum;

        private int count;

        private BigDecimal bigSum = null;
        private boolean overflow = false;

        private final AbstractType numberType;

        public FloatAvgAggregate(AbstractType numberType)
        {
            this.numberType = numberType;
        }

        public void reset()
        {
            sum = 0;
            compensation = 0;
            simpleSum = 0;

            count = 0;
            bigSum = null;
            overflow = false;
        }

        public double computeInternal()
        {
            if (count == 0)
                return 0d;

            if (overflow)
            {
                return bigSum.divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN).doubleValue();
            }
            else
            {
                // correctly compute final sum if it's NaN from consequently
                // adding same-signed infinite values.
                double tmp = sum + compensation;
                if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
                    sum = simpleSum;
                else
                    sum = tmp;

                return sum / count;
            }
        }

        public void addInput(int protocolVersion, List<ByteBuffer> values)
        {
            ByteBuffer value = values.get(0);

            if (value == null)
                return;

            count++;

            double number = ((Number) numberType.compose(value)).doubleValue();

            if (overflow)
            {
                bigSum = bigSum.add(BigDecimal.valueOf(number));
            }
            else
            {
                simpleSum += number;
                double prev = sum;
                double tmp = number - compensation;
                double rounded = sum + tmp;
                compensation = (rounded - sum) - tmp;
                sum = rounded;

                if (Double.isInfinite(sum) && !Double.isInfinite(number))
                {
                    overflow = true;
                    bigSum = BigDecimal.valueOf(prev).add(BigDecimal.valueOf(number));
                }
            }
        }
    }

    /**
     * AVG function for double values.
     */
    public static final AggregateFunction avgFunctionForDouble =
            new NativeAggregateFunction("avg", DoubleType.instance, DoubleType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new FloatAvgAggregate(DoubleType.instance)
                    {
                        public ByteBuffer compute(int protocolVersion) throws InvalidRequestException
                        {
                            return DoubleType.instance.decompose(computeInternal());
                        }
                    };
                }
            };

    /**
     * The SUM function for counter column values.
     */
    public static final AggregateFunction sumFunctionForCounter =
    new NativeAggregateFunction("sum", CounterColumnType.instance, CounterColumnType.instance)
    {
        public Aggregate newAggregate()
        {
            return new LongSumAggregate();
        }
    };

    /**
     * AVG function for counter column values.
     */
    public static final AggregateFunction avgFunctionForCounter =
    new NativeAggregateFunction("avg", CounterColumnType.instance, CounterColumnType.instance)
    {
        public Aggregate newAggregate()
        {
            return new AvgAggregate(LongType.instance)
            {
                public ByteBuffer compute(int protocolVersion) throws InvalidRequestException
                {
                    return CounterColumnType.instance.decompose(computeInternal());
                }
            };
        }
    };

    /**
     * The MIN function for counter column values.
     */
    public static final AggregateFunction minFunctionForCounter =
    new NativeAggregateFunction("min", CounterColumnType.instance, CounterColumnType.instance)
    {
        public Aggregate newAggregate()
        {
            return new Aggregate()
            {
                private Long min;

                public void reset()
                {
                    min = null;
                }

                public ByteBuffer compute(int protocolVersion)
                {
                    return min != null ? LongType.instance.decompose(min) : null;
                }

                public void addInput(int protocolVersion, List<ByteBuffer> values)
                {
                    ByteBuffer value = values.get(0);

                    if (value == null)
                        return;

                    long lval = LongType.instance.compose(value);

                    if (min == null || lval < min)
                        min = lval;
                }
            };
        }
    };

    /**
     * MAX function for counter column values.
     */
    public static final AggregateFunction maxFunctionForCounter =
    new NativeAggregateFunction("max", CounterColumnType.instance, CounterColumnType.instance)
    {
        public Aggregate newAggregate()
        {
            return new Aggregate()
            {
                private Long max;

                public void reset()
                {
                    max = null;
                }

                public ByteBuffer compute(int protocolVersion)
                {
                    return max != null ? LongType.instance.decompose(max) : null;
                }

                public void addInput(int protocolVersion, List<ByteBuffer> values)
                {
                    ByteBuffer value = values.get(0);

                    if (value == null)
                        return;

                    long lval = LongType.instance.compose(value);

                    if (max == null || lval > max)
                        max = lval;
                }
            };
        }
    };

    /**
     * Creates a MAX function for the specified type.
     *
     * @param inputType the function input and output type
     * @return a MAX function for the specified type.
     */
    public static AggregateFunction makeMaxFunction(final AbstractType<?> inputType)
    {
        return new NativeAggregateFunction("max", inputType, inputType)
        {
            public Aggregate newAggregate()
            {
                return new Aggregate()
                {
                    private ByteBuffer max;

                    public void reset()
                    {
                        max = null;
                    }

                    public ByteBuffer compute(int protocolVersion)
                    {
                        return max;
                    }

                    public void addInput(int protocolVersion, List<ByteBuffer> values)
                    {
                        ByteBuffer value = values.get(0);

                        if (value == null)
                            return;

                        if (max == null || returnType().compare(max, value) < 0)
                            max = value;
                    }
                };
            }
        };
    }

    /**
     * Creates a MIN function for the specified type.
     *
     * @param inputType the function input and output type
     * @return a MIN function for the specified type.
     */
    public static AggregateFunction makeMinFunction(final AbstractType<?> inputType)
    {
        return new NativeAggregateFunction("min", inputType, inputType)
        {
            public Aggregate newAggregate()
            {
                return new Aggregate()
                {
                    private ByteBuffer min;

                    public void reset()
                    {
                        min = null;
                    }

                    public ByteBuffer compute(int protocolVersion)
                    {
                        return min;
                    }

                    public void addInput(int protocolVersion, List<ByteBuffer> values)
                    {
                        ByteBuffer value = values.get(0);

                        if (value == null)
                            return;

                        if (min == null || returnType().compare(min, value) > 0)
                            min = value;
                    }
                };
            }
        };
    }

    /**
     * Creates a COUNT function for the specified type.
     *
     * @param inputType the function input type
     * @return a COUNT function for the specified type.
     */
    public static AggregateFunction makeCountFunction(AbstractType<?> inputType)
    {
        return new NativeAggregateFunction("count", LongType.instance, inputType)
        {
            public Aggregate newAggregate()
            {
                return new Aggregate()
                {
                    private long count;

                    public void reset()
                    {
                        count = 0;
                    }

                    public ByteBuffer compute(int protocolVersion)
                    {
                        return ((LongType) returnType()).decompose(count);
                    }

                    public void addInput(int protocolVersion, List<ByteBuffer> values)
                    {
                        ByteBuffer value = values.get(0);

                        if (value == null)
                            return;

                        count++;
                    }
                };
            }
        };
    }

    private static class LongSumAggregate implements AggregateFunction.Aggregate
    {
        private long sum;

        public void reset()
        {
            sum = 0;
        }

        public ByteBuffer compute(int protocolVersion)
        {
            return LongType.instance.decompose(sum);
        }

        public void addInput(int protocolVersion, List<ByteBuffer> values)
        {
            ByteBuffer value = values.get(0);

            if (value == null)
                return;

            Number number = LongType.instance.compose(value);
            sum += number.longValue();
        }
    }

    /**
     * Average aggregate class, collecting the sum using long arithmetics, falling back
     * to BigInteger on long overflow. Resulting number is converted to corresponding
     * representation by concrete implementations.
     */
    private static abstract class AvgAggregate implements AggregateFunction.Aggregate
    {
        private long sum;
        private int count;
        private BigInteger bigSum = null;
        private boolean overflow = false;

        private final AbstractType numberType;

        public AvgAggregate(AbstractType type)
        {
            this.numberType = type;
        }

        public void reset()
        {
            count = 0;
            sum = 0L;
            overflow = false;
            bigSum = null;
        }

        long computeInternal()
        {
            if (overflow)
            {
                return bigSum.divide(BigInteger.valueOf(count)).longValue();
            }
            else
            {
                return count == 0 ? 0 : (sum / count);
            }
        }

        public void addInput(int protocolVersion, List<ByteBuffer> values)
        {
            ByteBuffer value = values.get(0);

            if (value == null)
                return;

            count++;
            long number = ((Number) numberType.compose(value)).longValue();
            if (overflow)
            {
                bigSum = bigSum.add(BigInteger.valueOf(number));
            }
            else
            {
                long prev = sum;
                sum += number;

                if (((prev ^ sum) & (number ^ sum)) < 0)
                {
                    overflow = true;
                    bigSum = BigInteger.valueOf(prev).add(BigInteger.valueOf(number));
                }
            }
        }
    }
}
