/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.druid.query.aggregation;

import com.fasterxml.jackson.annotation.JacksonInject;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Comparators;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.math.expr.SettableObjectBinding;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.virtual.ExpressionPlan;
import org.apache.druid.segment.virtual.ExpressionPlanner;
import org.apache.druid.segment.virtual.ExpressionSelectors;

import javax.annotation.Nullable;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
 * {@link AggregatorFactory} whose aggregation logic is expressed with Druid native expressions:
 *
 * <ul>
 *   <li>{@code fold} accumulates raw input rows (the columns named in {@code fields}) into an accumulator bound to the
 *   accumulator identifier, starting from the constant {@code initialValue}.</li>
 *   <li>{@code combine} merges intermediate values, starting from the constant {@code initialCombineValue} (which
 *   defaults to {@code initialValue}). When omitted, it is derived from {@code fold} by substituting the single input
 *   field with the aggregator name.</li>
 *   <li>{@code compare} (optional) compares two values bound as {@code o1} and {@code o2}.</li>
 *   <li>{@code finalize} (optional) transforms a value bound as {@code o}.</li>
 * </ul>
 *
 * A factory with {@code fields == null} is the combining form produced by {@link #getCombiningFactory()}; it
 * aggregates with the combine expression and initial combine value instead of the fold expression and initial value.
 */
public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
{
  private static final String FINALIZE_IDENTIFIER = "o";
  private static final String COMPARE_O1 = "o1";
  private static final String COMPARE_O2 = "o2";
  private static final String DEFAULT_ACCUMULATOR_ID = "__acc";

  // minimum permitted agg size is 10 bytes so it is at least large enough to hold primitive numerics (long, double)
  // | expression type byte | is_null byte | primitive value (8 bytes) |
  private static final int MIN_SIZE_BYTES = 10;
  private static final HumanReadableBytes DEFAULT_MAX_SIZE_BYTES = new HumanReadableBytes(1L << 10);

  private final String name;
  @Nullable
  private final Set<String> fields;
  private final String accumulatorId;
  private final String foldExpressionString;
  private final String initialValueExpressionString;
  private final String initialCombineValueExpressionString;

  private final String combineExpressionString;
  @Nullable
  private final String compareExpressionString;
  @Nullable
  private final String finalizeExpressionString;

  private final ExprMacroTable macroTable;
  private final Supplier<ExprEval<?>> initialValue;
  private final Supplier<ExprEval<?>> initialCombineValue;
  private final Supplier<Expr> foldExpression;
  private final Supplier<Expr> combineExpression;
  private final Supplier<Expr> compareExpression;
  private final Supplier<Expr> finalizeExpression;
  private final HumanReadableBytes maxSizeBytes;

  // NOTE(review): these bindings are memoized once per factory and re-bound on every call to getComparator's
  // comparator, combine, and finalizeComputation; if withBinding mutates the shared instance, concurrent use of one
  // factory from several threads could race -- confirm callers are single-threaded.
  private final Supplier<SettableObjectBinding> compareBindings =
      Suppliers.memoize(() -> new SettableObjectBinding(2));
  private final Supplier<SettableObjectBinding> combineBindings =
      Suppliers.memoize(() -> new SettableObjectBinding(2));
  private final Supplier<SettableObjectBinding> finalizeBindings =
      Suppliers.memoize(() -> new SettableObjectBinding(1));
  private final Supplier<Expr.InputBindingInspector> finalizeInspector;

  /**
   * @param name                  output name of the aggregator; must not be null
   * @param fields                input columns for the fold expression, or null for the combining form
   * @param accumulatorIdentifier identifier the accumulator is bound to; defaults to {@code __acc}
   * @param initialValue          constant expression used to seed the fold accumulator
   * @param initialCombineValue   constant expression used to seed the combine accumulator; defaults to initialValue
   * @param foldExpression        expression that folds an input row into the accumulator
   * @param combineExpression     expression that merges two intermediate values; required unless exactly one field
   * @param compareExpression     optional expression comparing {@code o1} and {@code o2}
   * @param finalizeExpression    optional expression transforming {@code o}
   * @param maxSizeBytes          max intermediate size for non-numeric types; defaults to 1 KiB, minimum 10 bytes
   * @param macroTable            macros available when parsing the expressions
   */
  @JsonCreator
  public ExpressionLambdaAggregatorFactory(
      @JsonProperty("name") String name,
      @JsonProperty("fields") @Nullable final Set<String> fields,
      @JsonProperty("accumulatorIdentifier") @Nullable final String accumulatorIdentifier,
      @JsonProperty("initialValue") final String initialValue,
      @JsonProperty("initialCombineValue") @Nullable final String initialCombineValue,
      @JsonProperty("fold") final String foldExpression,
      @JsonProperty("combine") @Nullable final String combineExpression,
      @JsonProperty("compare") @Nullable final String compareExpression,
      @JsonProperty("finalize") @Nullable final String finalizeExpression,
      @JsonProperty("maxSizeBytes") @Nullable final HumanReadableBytes maxSizeBytes,
      @JacksonInject ExprMacroTable macroTable
  )
  {
    Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");

    this.name = name;
    this.fields = fields;
    this.accumulatorId = accumulatorIdentifier != null ? accumulatorIdentifier : DEFAULT_ACCUMULATOR_ID;

    this.initialValueExpressionString = initialValue;
    this.initialCombineValueExpressionString = initialCombineValue == null ? initialValue : initialCombineValue;
    this.foldExpressionString = foldExpression;
    if (combineExpression != null) {
      this.combineExpressionString = combineExpression;
    } else {
      // if the combine expression is null, allow single input aggregator expressions to be rewritten to replace the
      // field with the aggregator name. Fields is null for the combining/merging aggregator, but the expression should
      // already be set with the rewritten value at that point
      Preconditions.checkArgument(
          fields != null && fields.size() == 1,
          "Must have a single input field if no combine expression is supplied"
      );
      // NOTE(review): this is a plain substring replacement, so a field name that also occurs inside another
      // identifier or a string literal in the fold expression would be rewritten there too -- consider a
      // token-aware rewrite.
      this.combineExpressionString = StringUtils.replace(foldExpression, Iterables.getOnlyElement(fields), name);
    }
    this.compareExpressionString = compareExpression;
    this.finalizeExpressionString = finalizeExpression;
    this.macroTable = macroTable;

    // initial values must be literals so they can be evaluated once, without any input bindings
    this.initialValue = Suppliers.memoize(() -> {
      Expr parsed = Parser.parse(this.initialValueExpressionString, macroTable);
      Preconditions.checkArgument(parsed.isLiteral(), "initial value must be constant");
      return parsed.eval(ExprUtils.nilBindings());
    });
    this.initialCombineValue = Suppliers.memoize(() -> {
      Expr parsed = Parser.parse(this.initialCombineValueExpressionString, macroTable);
      Preconditions.checkArgument(parsed.isLiteral(), "initial combining value must be constant");
      return parsed.eval(ExprUtils.nilBindings());
    });
    this.foldExpression = Parser.lazyParse(foldExpressionString, macroTable);
    this.combineExpression = Parser.lazyParse(combineExpressionString, macroTable);
    this.compareExpression = Parser.lazyParse(compareExpressionString, macroTable);
    // the finalize expression sees the intermediate value, whose type is taken from the initial combine value
    this.finalizeInspector = Suppliers.memoize(
        () -> InputBindings.inspectorFromTypeMap(
            ImmutableMap.of(FINALIZE_IDENTIFIER, this.initialCombineValue.get().type())
        )
    );
    this.finalizeExpression = Parser.lazyParse(finalizeExpressionString, macroTable);
    this.maxSizeBytes = maxSizeBytes != null ? maxSizeBytes : DEFAULT_MAX_SIZE_BYTES;
    Preconditions.checkArgument(
        this.maxSizeBytes.getBytesInInt() >= MIN_SIZE_BYTES,
        "maxSizeBytes must be at least %s bytes, got %s",
        MIN_SIZE_BYTES,
        this.maxSizeBytes.getBytesInInt()
    );
  }

  @JsonProperty
  @Override
  public String getName()
  {
    return name;
  }

  /**
   * @return the input columns for the fold expression, or null for the combining form
   */
  @JsonProperty
  @Nullable
  public Set<String> getFields()
  {
    return fields;
  }

  /**
   * @return the accumulator identifier; never null because the constructor substitutes a default
   */
  @JsonProperty
  public String getAccumulatorIdentifier()
  {
    return accumulatorId;
  }

  @JsonProperty("initialValue")
  public String getInitialValueExpressionString()
  {
    return initialValueExpressionString;
  }

  @JsonProperty("initialCombineValue")
  public String getInitialCombineValueExpressionString()
  {
    return initialCombineValueExpressionString;
  }

  @JsonProperty("fold")
  public String getFoldExpressionString()
  {
    return foldExpressionString;
  }

  @JsonProperty("combine")
  public String getCombineExpressionString()
  {
    return combineExpressionString;
  }

  @JsonProperty("compare")
  @Nullable
  public String getCompareExpressionString()
  {
    return compareExpressionString;
  }

  @JsonProperty("finalize")
  @Nullable
  public String getFinalizeExpressionString()
  {
    return finalizeExpressionString;
  }

  @JsonProperty("maxSizeBytes")
  public HumanReadableBytes getMaxSizeBytes()
  {
    return maxSizeBytes;
  }

  /**
   * Builds a cache key from every expression that affects results: fold, combine, compare and finalize, plus the
   * inputs, initial values and size limit.
   */
  @Override
  public byte[] getCacheKey()
  {
    return new CacheKeyBuilder(AggregatorUtil.EXPRESSION_LAMBDA_CACHE_TYPE_ID)
        .appendStrings(fields)
        .appendString(initialValueExpressionString)
        .appendString(initialCombineValueExpressionString)
        .appendCacheable(foldExpression.get())
        .appendCacheable(combineExpression.get())
        // previously the combine expression was appended twice here, so the compare expression never reached the key
        .appendCacheable(compareExpression.get())
        .appendCacheable(finalizeExpression.get())
        .appendInt(maxSizeBytes.getBytesInInt())
        .build();
  }

  @Override
  public Aggregator factorize(ColumnSelectorFactory metricFactory)
  {
    FactorizePlan thePlan = new FactorizePlan(metricFactory);
    return new ExpressionLambdaAggregator(
        thePlan.getExpression(),
        thePlan.getBindings(),
        maxSizeBytes.getBytesInInt()
    );
  }

  @Override
  public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
  {
    FactorizePlan thePlan = new FactorizePlan(metricFactory);
    return new ExpressionLambdaBufferAggregator(
        thePlan.getExpression(),
        thePlan.getInitialValue(),
        thePlan.getBindings(),
        maxSizeBytes.getBytesInInt()
    );
  }

  /**
   * Returns a comparator for values of this aggregator. If a compare expression is set, it is evaluated with the two
   * operands bound as {@code o1} and {@code o2} and its result is read as an int. Otherwise the choice depends on the
   * intermediate type: long or double comparators for numerics, natural ordering with nulls first for anything else.
   */
  @Override
  public Comparator getComparator()
  {
    Expr compareExpr = compareExpression.get();
    if (compareExpr != null) {
      return (o1, o2) ->
          compareExpr.eval(compareBindings.get().withBinding(COMPARE_O1, o1).withBinding(COMPARE_O2, o2)).asInt();
    }
    // use the same seed that determines getType(), so the combining form picks a comparator for its own type
    switch (intermediateInitialValue().type()) {
      case LONG:
        return LongSumAggregator.COMPARATOR;
      case DOUBLE:
        return DoubleSumAggregator.COMPARATOR;
      default:
        return Comparators.naturalNullsFirst();
    }
  }

  /**
   * Merges two values by evaluating the combine expression.
   */
  @Nullable
  @Override
  public Object combine(@Nullable Object lhs, @Nullable Object rhs)
  {
    // arbitrarily assign lhs and rhs to accumulator and aggregator name inputs to re-use combine function
    return combineExpression.get().eval(
        combineBindings.get().withBinding(accumulatorId, lhs).withBinding(name, rhs)
    ).value();
  }

  @Override
  public Object deserialize(Object object)
  {
    return object;
  }

  /**
   * Applies the finalize expression, with the input bound as {@code o}. If there is no finalize expression, the
   * input is returned unchanged.
   */
  @Nullable
  @Override
  public Object finalizeComputation(@Nullable Object object)
  {
    final Expr finalizeExpr = finalizeExpression.get();
    if (finalizeExpr != null) {
      return finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, object)).value();
    }
    return object;
  }

  /**
   * @return the bindings required by the combine expression for the combining form, otherwise by the fold expression
   */
  @Override
  public List<String> requiredFields()
  {
    if (fields == null) {
      return combineExpression.get().analyzeInputs().getRequiredBindingsList();
    }
    return foldExpression.get().analyzeInputs().getRequiredBindingsList();
  }

  /**
   * @return a copy of this factory with {@code fields} set to null, which aggregates with the combine expression
   */
  @Override
  public AggregatorFactory getCombiningFactory()
  {
    return new ExpressionLambdaAggregatorFactory(
        name,
        null,
        accumulatorId,
        initialValueExpressionString,
        initialCombineValueExpressionString,
        foldExpressionString,
        combineExpressionString,
        compareExpressionString,
        finalizeExpressionString,
        maxSizeBytes,
        macroTable
    );
  }

  @Override
  public List<AggregatorFactory> getRequiredColumns()
  {
    return Collections.singletonList(
        new ExpressionLambdaAggregatorFactory(
            name,
            fields,
            accumulatorId,
            initialValueExpressionString,
            initialCombineValueExpressionString,
            foldExpressionString,
            combineExpressionString,
            compareExpressionString,
            finalizeExpressionString,
            maxSizeBytes,
            macroTable
        )
    );
  }

  /**
   * @return the intermediate type: the initial combine value's type for the combining form, otherwise the initial
   * value's type
   */
  @Override
  public ValueType getType()
  {
    return ExprType.toValueType(intermediateInitialValue().type());
  }

  /**
   * @return the finalize expression's output type when it can be inferred, otherwise the initial combine value's type
   */
  @Override
  public ValueType getFinalizedType()
  {
    Expr finalizeExpr = finalizeExpression.get();
    ExprEval<?> initialVal = initialCombineValue.get();
    if (finalizeExpr != null) {
      ExprType type = finalizeExpr.getOutputType(finalizeInspector.get());
      if (type == null) {
        type = initialVal.type();
      }
      return ExprType.toValueType(type);
    }
    return ExprType.toValueType(initialVal.type());
  }

  @Override
  public int getMaxIntermediateSize()
  {
    // numeric expressions are either longs or doubles, with strings or arrays max size is unknown
    // for numeric arguments, the first 2 bytes are used for expression type byte and is_null byte
    return getType().isNumeric() ? 2 + Long.BYTES : maxSizeBytes.getBytesInInt();
  }

  @Override
  public boolean equals(Object o)
  {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    ExpressionLambdaAggregatorFactory that = (ExpressionLambdaAggregatorFactory) o;
    return maxSizeBytes.equals(that.maxSizeBytes)
           && name.equals(that.name)
           && Objects.equals(fields, that.fields)
           && accumulatorId.equals(that.accumulatorId)
           && foldExpressionString.equals(that.foldExpressionString)
           && initialValueExpressionString.equals(that.initialValueExpressionString)
           && initialCombineValueExpressionString.equals(that.initialCombineValueExpressionString)
           && combineExpressionString.equals(that.combineExpressionString)
           && Objects.equals(compareExpressionString, that.compareExpressionString)
           && Objects.equals(finalizeExpressionString, that.finalizeExpressionString);
  }

  @Override
  public int hashCode()
  {
    return Objects.hash(
        name,
        fields,
        accumulatorId,
        foldExpressionString,
        initialValueExpressionString,
        initialCombineValueExpressionString,
        combineExpressionString,
        compareExpressionString,
        finalizeExpressionString,
        maxSizeBytes
    );
  }

  @Override
  public String toString()
  {
    return "ExpressionLambdaAggregatorFactory{" +
           "name='" + name + '\'' +
           ", fields=" + fields +
           ", accumulatorId='" + accumulatorId + '\'' +
           ", foldExpressionString='" + foldExpressionString + '\'' +
           ", initialValueExpressionString='" + initialValueExpressionString + '\'' +
           ", initialCombineValueExpressionString='" + initialCombineValueExpressionString + '\'' +
           ", combineExpressionString='" + combineExpressionString + '\'' +
           ", compareExpressionString='" + compareExpressionString + '\'' +
           ", finalizeExpressionString='" + finalizeExpressionString + '\'' +
           ", maxSizeBytes=" + maxSizeBytes +
           '}';
  }

  /**
   * Returns the value that seeds the intermediate accumulator, and therefore defines the intermediate type. This is
   * the initial combine value for the combining form ({@code fields == null}) and the initial value otherwise.
   */
  private ExprEval<?> intermediateInitialValue()
  {
    return fields == null ? initialCombineValue.get() : initialValue.get();
  }

  /**
   * Determine how to factorize the aggregator: which expression to run (fold for raw inputs, combine for
   * intermediate results), which seed value to start from, and which input bindings to create.
   */
  private class FactorizePlan
  {
    private final ExpressionPlan plan;

    private final ExprEval<?> seed;
    private final ExpressionLambdaAggregatorInputBindings bindings;

    FactorizePlan(ColumnSelectorFactory metricFactory)
    {
      final Expr aggregateExpression;
      if (fields != null) {
        // if fields are set, we are accumulating from raw inputs, use fold expression
        aggregateExpression = foldExpression.get();
        seed = initialValue.get();
      } else {
        // else we are merging intermediary results, use combine expression
        aggregateExpression = combineExpression.get();
        seed = initialCombineValue.get();
      }

      // the accumulator binding is seeded with `seed`, so describe it to the planner with the seed's type
      plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory, seed.type()), aggregateExpression);
      final List<String> columns = plan.getAnalysis().getRequiredBindingsList();

      bindings = new ExpressionLambdaAggregatorInputBindings(
          ExpressionSelectors.createBindings(metricFactory, columns),
          accumulatorId,
          seed
      );
    }

    public Expr getExpression()
    {
      if (fields == null) {
        return plan.getExpression();
      }
      // for fold expressions, check to see if it needs transformation due to scalar use of multi-valued or unknown
      // inputs
      return plan.getAppliedFoldExpression(accumulatorId);
    }

    public ExprEval<?> getInitialValue()
    {
      return seed;
    }

    public ExpressionLambdaAggregatorInputBindings getBindings()
    {
      return bindings;
    }

    /**
     * Wraps {@code inspector} so that the accumulator identifier reports {@code accumulatorType}; every other column
     * is delegated to {@code inspector}.
     */
    private ColumnInspector inspectorWithAccumulator(ColumnInspector inspector, ExprType accumulatorType)
    {
      return new ColumnInspector()
      {
        @Nullable
        @Override
        public ColumnCapabilities getColumnCapabilities(String column)
        {
          if (accumulatorId.equals(column)) {
            return ColumnCapabilitiesImpl.createDefault().setType(ExprType.toValueType(accumulatorType));
          }
          return inspector.getColumnCapabilities(column);
        }

        @Nullable
        @Override
        public ExprType getType(String column)
        {
          if (accumulatorId.equals(column)) {
            return accumulatorType;
          }
          return inspector.getType(column);
        }
      };
    }
  }
}
