| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.druid.query.aggregation; |
| |
| import com.fasterxml.jackson.annotation.JacksonInject; |
| import com.fasterxml.jackson.annotation.JsonCreator; |
| import com.fasterxml.jackson.annotation.JsonProperty; |
| import com.google.common.base.Preconditions; |
| import com.google.common.base.Supplier; |
| import com.google.common.base.Suppliers; |
| import com.google.common.collect.ImmutableMap; |
| import com.google.common.collect.Iterables; |
| import org.apache.druid.common.config.NullHandling; |
| import org.apache.druid.java.util.common.HumanReadableBytes; |
| import org.apache.druid.java.util.common.StringUtils; |
| import org.apache.druid.java.util.common.guava.Comparators; |
| import org.apache.druid.math.expr.Expr; |
| import org.apache.druid.math.expr.ExprEval; |
| import org.apache.druid.math.expr.ExprMacroTable; |
| import org.apache.druid.math.expr.ExprType; |
| import org.apache.druid.math.expr.InputBindings; |
| import org.apache.druid.math.expr.Parser; |
| import org.apache.druid.math.expr.SettableObjectBinding; |
| import org.apache.druid.query.cache.CacheKeyBuilder; |
| import org.apache.druid.query.expression.ExprUtils; |
| import org.apache.druid.segment.ColumnInspector; |
| import org.apache.druid.segment.ColumnSelectorFactory; |
| import org.apache.druid.segment.column.ColumnCapabilities; |
| import org.apache.druid.segment.column.ColumnCapabilitiesImpl; |
| import org.apache.druid.segment.column.ValueType; |
| import org.apache.druid.segment.virtual.ExpressionPlan; |
| import org.apache.druid.segment.virtual.ExpressionPlanner; |
| import org.apache.druid.segment.virtual.ExpressionSelectors; |
| |
| import javax.annotation.Nullable; |
| import java.util.Collections; |
| import java.util.Comparator; |
| import java.util.List; |
| import java.util.Objects; |
| import java.util.Set; |
| |
| public class ExpressionLambdaAggregatorFactory extends AggregatorFactory |
| { |
| private static final String FINALIZE_IDENTIFIER = "o"; |
| private static final String COMPARE_O1 = "o1"; |
| private static final String COMPARE_O2 = "o2"; |
| private static final String DEFAULT_ACCUMULATOR_ID = "__acc"; |
| |
| // minimum permitted agg size is 10 bytes so it is at least large enough to hold primitive numerics (long, double) |
| // | expression type byte | is_null byte | primitive value (8 bytes) | |
| private static final int MIN_SIZE_BYTES = 10; |
| private static final HumanReadableBytes DEFAULT_MAX_SIZE_BYTES = new HumanReadableBytes(1L << 10); |
| |
| private final String name; |
| @Nullable |
| private final Set<String> fields; |
| private final String accumulatorId; |
| private final String foldExpressionString; |
| private final String initialValueExpressionString; |
| private final String initialCombineValueExpressionString; |
| private final boolean isNullUnlessAggregated; |
| |
| private final String combineExpressionString; |
| @Nullable |
| private final String compareExpressionString; |
| @Nullable |
| private final String finalizeExpressionString; |
| |
| private final ExprMacroTable macroTable; |
| private final Supplier<ExprEval<?>> initialValue; |
| private final Supplier<ExprEval<?>> initialCombineValue; |
| private final Supplier<Expr> foldExpression; |
| private final Supplier<Expr> combineExpression; |
| private final Supplier<Expr> compareExpression; |
| private final Supplier<Expr> finalizeExpression; |
| private final HumanReadableBytes maxSizeBytes; |
| |
| private final Supplier<SettableObjectBinding> compareBindings = |
| Suppliers.memoize(() -> new SettableObjectBinding(2)); |
| private final Supplier<SettableObjectBinding> combineBindings = |
| Suppliers.memoize(() -> new SettableObjectBinding(2)); |
| private final Supplier<SettableObjectBinding> finalizeBindings = |
| Suppliers.memoize(() -> new SettableObjectBinding(1)); |
| private final Supplier<Expr.InputBindingInspector> finalizeInspector; |
| |
| @JsonCreator |
| public ExpressionLambdaAggregatorFactory( |
| @JsonProperty("name") String name, |
| @JsonProperty("fields") @Nullable final Set<String> fields, |
| @JsonProperty("accumulatorIdentifier") @Nullable final String accumulatorIdentifier, |
| @JsonProperty("initialValue") final String initialValue, |
| @JsonProperty("initialCombineValue") @Nullable final String initialCombineValue, |
| @JsonProperty("isNullUnlessAggregated") @Nullable final Boolean isNullUnlessAggregated, |
| @JsonProperty("fold") final String foldExpression, |
| @JsonProperty("combine") @Nullable final String combineExpression, |
| @JsonProperty("compare") @Nullable final String compareExpression, |
| @JsonProperty("finalize") @Nullable final String finalizeExpression, |
| @JsonProperty("maxSizeBytes") @Nullable final HumanReadableBytes maxSizeBytes, |
| @JacksonInject ExprMacroTable macroTable |
| ) |
| { |
| Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name"); |
| |
| this.name = name; |
| this.fields = fields; |
| this.accumulatorId = accumulatorIdentifier != null ? accumulatorIdentifier : DEFAULT_ACCUMULATOR_ID; |
| |
| this.initialValueExpressionString = initialValue; |
| this.initialCombineValueExpressionString = initialCombineValue == null ? initialValue : initialCombineValue; |
| this.isNullUnlessAggregated = isNullUnlessAggregated == null ? NullHandling.sqlCompatible() : isNullUnlessAggregated; |
| this.foldExpressionString = foldExpression; |
| if (combineExpression != null) { |
| this.combineExpressionString = combineExpression; |
| } else { |
| // if the combine expression is null, allow single input aggregator expressions to be rewritten to replace the |
| // field with the aggregator name. Fields is null for the combining/merging aggregator, but the expression should |
| // already be set with the rewritten value at that point |
| Preconditions.checkArgument( |
| fields != null && fields.size() == 1, |
| "Must have a single input field if no combine expression is supplied" |
| ); |
| this.combineExpressionString = StringUtils.replace(foldExpression, Iterables.getOnlyElement(fields), name); |
| } |
| this.compareExpressionString = compareExpression; |
| this.finalizeExpressionString = finalizeExpression; |
| this.macroTable = macroTable; |
| |
| this.initialValue = Suppliers.memoize(() -> { |
| Expr parsed = Parser.parse(initialValue, macroTable); |
| Preconditions.checkArgument(parsed.isLiteral(), "initial value must be constant"); |
| return parsed.eval(ExprUtils.nilBindings()); |
| }); |
| this.initialCombineValue = Suppliers.memoize(() -> { |
| Expr parsed = Parser.parse(this.initialCombineValueExpressionString, macroTable); |
| Preconditions.checkArgument(parsed.isLiteral(), "initial combining value must be constant"); |
| return parsed.eval(ExprUtils.nilBindings()); |
| }); |
| this.foldExpression = Parser.lazyParse(foldExpressionString, macroTable); |
| this.combineExpression = Parser.lazyParse(combineExpressionString, macroTable); |
| this.compareExpression = Parser.lazyParse(compareExpressionString, macroTable); |
| this.finalizeInspector = Suppliers.memoize( |
| () -> InputBindings.inspectorFromTypeMap( |
| ImmutableMap.of(FINALIZE_IDENTIFIER, this.initialCombineValue.get().type()) |
| ) |
| ); |
| this.finalizeExpression = Parser.lazyParse(finalizeExpressionString, macroTable); |
| this.maxSizeBytes = maxSizeBytes != null ? maxSizeBytes : DEFAULT_MAX_SIZE_BYTES; |
| Preconditions.checkArgument(this.maxSizeBytes.getBytesInInt() >= MIN_SIZE_BYTES); |
| |
| } |
| |
| @JsonProperty |
| @Override |
| public String getName() |
| { |
| return name; |
| } |
| |
| @JsonProperty |
| @Nullable |
| public Set<String> getFields() |
| { |
| return fields; |
| } |
| |
| @JsonProperty |
| @Nullable |
| public String getAccumulatorIdentifier() |
| { |
| return accumulatorId; |
| } |
| |
| @JsonProperty("initialValue") |
| public String getInitialValueExpressionString() |
| { |
| return initialValueExpressionString; |
| } |
| |
| @JsonProperty("initialCombineValue") |
| public String getInitialCombineValueExpressionString() |
| { |
| return initialCombineValueExpressionString; |
| } |
| |
| @JsonProperty("isNullUnlessAggregated") |
| public boolean getIsNullUnlessAggregated() |
| { |
| return isNullUnlessAggregated; |
| } |
| |
| @JsonProperty("fold") |
| public String getFoldExpressionString() |
| { |
| return foldExpressionString; |
| } |
| |
| @JsonProperty("combine") |
| public String getCombineExpressionString() |
| { |
| return combineExpressionString; |
| } |
| |
| @JsonProperty("compare") |
| @Nullable |
| public String getCompareExpressionString() |
| { |
| return compareExpressionString; |
| } |
| |
| @JsonProperty("finalize") |
| @Nullable |
| public String getFinalizeExpressionString() |
| { |
| return finalizeExpressionString; |
| } |
| |
| @JsonProperty("maxSizeBytes") |
| public HumanReadableBytes getMaxSizeBytes() |
| { |
| return maxSizeBytes; |
| } |
| |
| @Override |
| public byte[] getCacheKey() |
| { |
| return new CacheKeyBuilder(AggregatorUtil.EXPRESSION_LAMBDA_CACHE_TYPE_ID) |
| .appendStrings(fields) |
| .appendString(initialValueExpressionString) |
| .appendString(initialCombineValueExpressionString) |
| .appendCacheable(foldExpression.get()) |
| .appendCacheable(combineExpression.get()) |
| .appendCacheable(combineExpression.get()) |
| .appendCacheable(finalizeExpression.get()) |
| .appendInt(maxSizeBytes.getBytesInInt()) |
| .build(); |
| } |
| |
| @Override |
| public Aggregator factorize(ColumnSelectorFactory metricFactory) |
| { |
| FactorizePlan thePlan = new FactorizePlan(metricFactory); |
| return new ExpressionLambdaAggregator( |
| thePlan.getExpression(), |
| thePlan.getBindings(), |
| isNullUnlessAggregated, |
| maxSizeBytes.getBytesInInt() |
| ); |
| } |
| |
| @Override |
| public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) |
| { |
| FactorizePlan thePlan = new FactorizePlan(metricFactory); |
| return new ExpressionLambdaBufferAggregator( |
| thePlan.getExpression(), |
| thePlan.getInitialValue(), |
| thePlan.getBindings(), |
| isNullUnlessAggregated, |
| maxSizeBytes.getBytesInInt() |
| ); |
| } |
| |
| @Override |
| public Comparator getComparator() |
| { |
| Expr compareExpr = compareExpression.get(); |
| if (compareExpr != null) { |
| return (o1, o2) -> |
| compareExpr.eval(compareBindings.get().withBinding(COMPARE_O1, o1).withBinding(COMPARE_O2, o2)).asInt(); |
| } |
| switch (initialValue.get().type()) { |
| case LONG: |
| return LongSumAggregator.COMPARATOR; |
| case DOUBLE: |
| return DoubleSumAggregator.COMPARATOR; |
| default: |
| return Comparators.naturalNullsFirst(); |
| } |
| } |
| |
| @Nullable |
| @Override |
| public Object combine(@Nullable Object lhs, @Nullable Object rhs) |
| { |
| // arbitrarily assign lhs and rhs to accumulator and aggregator name inputs to re-use combine function |
| return combineExpression.get().eval( |
| combineBindings.get().withBinding(accumulatorId, lhs).withBinding(name, rhs) |
| ).value(); |
| } |
| |
| @Override |
| public Object deserialize(Object object) |
| { |
| return object; |
| } |
| |
| @Nullable |
| @Override |
| public Object finalizeComputation(@Nullable Object object) |
| { |
| Expr finalizeExpr; |
| finalizeExpr = finalizeExpression.get(); |
| if (finalizeExpr != null) { |
| return finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, object)).value(); |
| } |
| return object; |
| } |
| |
| @Override |
| public List<String> requiredFields() |
| { |
| if (fields == null) { |
| return combineExpression.get().analyzeInputs().getRequiredBindingsList(); |
| } |
| return foldExpression.get().analyzeInputs().getRequiredBindingsList(); |
| } |
| |
| @Override |
| public AggregatorFactory getCombiningFactory() |
| { |
| return new ExpressionLambdaAggregatorFactory( |
| name, |
| null, |
| accumulatorId, |
| initialValueExpressionString, |
| initialCombineValueExpressionString, |
| isNullUnlessAggregated, |
| foldExpressionString, |
| combineExpressionString, |
| compareExpressionString, |
| finalizeExpressionString, |
| maxSizeBytes, |
| macroTable |
| ); |
| } |
| |
| @Override |
| public List<AggregatorFactory> getRequiredColumns() |
| { |
| return Collections.singletonList( |
| new ExpressionLambdaAggregatorFactory( |
| name, |
| fields, |
| accumulatorId, |
| initialValueExpressionString, |
| initialCombineValueExpressionString, |
| isNullUnlessAggregated, |
| foldExpressionString, |
| combineExpressionString, |
| compareExpressionString, |
| finalizeExpressionString, |
| maxSizeBytes, |
| macroTable |
| ) |
| ); |
| } |
| |
| @Override |
| public ValueType getType() |
| { |
| if (fields == null) { |
| return ExprType.toValueType(initialCombineValue.get().type()); |
| } |
| return ExprType.toValueType(initialValue.get().type()); |
| } |
| |
| @Override |
| public ValueType getFinalizedType() |
| { |
| Expr finalizeExpr = finalizeExpression.get(); |
| ExprEval<?> initialVal = initialCombineValue.get(); |
| if (finalizeExpr != null) { |
| ExprType type = finalizeExpr.getOutputType(finalizeInspector.get()); |
| if (type == null) { |
| type = initialVal.type(); |
| } |
| return ExprType.toValueType(type); |
| } |
| return ExprType.toValueType(initialVal.type()); |
| } |
| |
| @Override |
| public int getMaxIntermediateSize() |
| { |
| // numeric expressions are either longs or doubles, with strings or arrays max size is unknown |
| // for numeric arguments, the first 2 bytes are used for expression type byte and is_null byte |
| return getType().isNumeric() ? 2 + Long.BYTES : maxSizeBytes.getBytesInInt(); |
| } |
| |
| @Override |
| public boolean equals(Object o) |
| { |
| if (this == o) { |
| return true; |
| } |
| if (o == null || getClass() != o.getClass()) { |
| return false; |
| } |
| ExpressionLambdaAggregatorFactory that = (ExpressionLambdaAggregatorFactory) o; |
| return maxSizeBytes.equals(that.maxSizeBytes) |
| && name.equals(that.name) |
| && Objects.equals(fields, that.fields) |
| && accumulatorId.equals(that.accumulatorId) |
| && foldExpressionString.equals(that.foldExpressionString) |
| && initialValueExpressionString.equals(that.initialValueExpressionString) |
| && initialCombineValueExpressionString.equals(that.initialCombineValueExpressionString) |
| && isNullUnlessAggregated == that.isNullUnlessAggregated |
| && combineExpressionString.equals(that.combineExpressionString) |
| && Objects.equals(compareExpressionString, that.compareExpressionString) |
| && Objects.equals(finalizeExpressionString, that.finalizeExpressionString); |
| } |
| |
| @Override |
| public int hashCode() |
| { |
| return Objects.hash( |
| name, |
| fields, |
| accumulatorId, |
| foldExpressionString, |
| initialValueExpressionString, |
| initialCombineValueExpressionString, |
| isNullUnlessAggregated, |
| combineExpressionString, |
| compareExpressionString, |
| finalizeExpressionString, |
| maxSizeBytes |
| ); |
| } |
| |
| @Override |
| public String toString() |
| { |
| return "ExpressionLambdaAggregatorFactory{" + |
| "name='" + name + '\'' + |
| ", fields=" + fields + |
| ", accumulatorId='" + accumulatorId + '\'' + |
| ", foldExpressionString='" + foldExpressionString + '\'' + |
| ", initialValueExpressionString='" + initialValueExpressionString + '\'' + |
| ", initialCombineValueExpressionString='" + initialCombineValueExpressionString + '\'' + |
| ", nullUnlessAggregated='" + isNullUnlessAggregated + '\'' + |
| ", combineExpressionString='" + combineExpressionString + '\'' + |
| ", compareExpressionString='" + compareExpressionString + '\'' + |
| ", finalizeExpressionString='" + finalizeExpressionString + '\'' + |
| ", maxSizeBytes=" + maxSizeBytes + |
| '}'; |
| } |
| |
| /** |
| * Determine how to factorize the aggregator |
| */ |
| private class FactorizePlan |
| { |
| private final ExpressionPlan plan; |
| |
| private final ExprEval<?> seed; |
| private final ExpressionLambdaAggregatorInputBindings bindings; |
| |
| FactorizePlan(ColumnSelectorFactory metricFactory) |
| { |
| final List<String> columns; |
| |
| if (fields != null) { |
| // if fields are set, we are accumulating from raw inputs, use fold expression |
| plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory), foldExpression.get()); |
| seed = initialValue.get(); |
| columns = plan.getAnalysis().getRequiredBindingsList(); |
| } else { |
| // else we are merging intermediary results, use combine expression |
| plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory), combineExpression.get()); |
| seed = initialCombineValue.get(); |
| columns = plan.getAnalysis().getRequiredBindingsList(); |
| } |
| |
| bindings = new ExpressionLambdaAggregatorInputBindings( |
| ExpressionSelectors.createBindings(metricFactory, columns), |
| accumulatorId, |
| seed |
| ); |
| } |
| |
| public Expr getExpression() |
| { |
| if (fields == null) { |
| return plan.getExpression(); |
| } |
| // for fold expressions, check to see if it needs transformation due to scalar use of multi-valued or unknown |
| // inputs |
| return plan.getAppliedFoldExpression(accumulatorId); |
| } |
| |
| public ExprEval<?> getInitialValue() |
| { |
| return seed; |
| } |
| |
| public ExpressionLambdaAggregatorInputBindings getBindings() |
| { |
| return bindings; |
| } |
| |
| private ColumnInspector inspectorWithAccumulator(ColumnInspector inspector) |
| { |
| return new ColumnInspector() |
| { |
| @Nullable |
| @Override |
| public ColumnCapabilities getColumnCapabilities(String column) |
| { |
| if (accumulatorId.equals(column)) { |
| return ColumnCapabilitiesImpl.createDefault().setType(ExprType.toValueType(initialValue.get().type())); |
| } |
| return inspector.getColumnCapabilities(column); |
| } |
| |
| @Nullable |
| @Override |
| public ExprType getType(String name) |
| { |
| if (accumulatorId.equals(name)) { |
| return initialValue.get().type(); |
| } |
| return inspector.getType(name); |
| } |
| }; |
| } |
| } |
| } |