expression aggregator (#11104)
* add experimental expression aggregator
* add test
* fix lgtm
* fix test
* adjust test
* use not null constant
* array_set_concat docs
* add equals and hashcode and tostring
* fix it
* spelling
* do multi-value magic for expression agg, more javadocs, tests
* formatting
* fix inspection
* more better
* nullable
diff --git a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java
index dd64629..5edb2fe 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java
@@ -183,9 +183,9 @@
class LongArrayExpr extends ConstantExpr<Long[]>
{
- LongArrayExpr(Long[] value)
+ LongArrayExpr(@Nullable Long[] value)
{
- super(ExprType.LONG_ARRAY, Preconditions.checkNotNull(value, "value"));
+ super(ExprType.LONG_ARRAY, value);
}
@Override
@@ -320,9 +320,9 @@
class DoubleArrayExpr extends ConstantExpr<Double[]>
{
- DoubleArrayExpr(Double[] value)
+ DoubleArrayExpr(@Nullable Double[] value)
{
- super(ExprType.DOUBLE_ARRAY, Preconditions.checkNotNull(value, "value"));
+ super(ExprType.DOUBLE_ARRAY, value);
}
@Override
@@ -426,9 +426,9 @@
class StringArrayExpr extends ConstantExpr<String[]>
{
- StringArrayExpr(String[] value)
+ StringArrayExpr(@Nullable String[] value)
{
- super(ExprType.STRING_ARRAY, Preconditions.checkNotNull(value, "value"));
+ super(ExprType.STRING_ARRAY, value);
}
@Override
diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java
index 52e5730..1a9a962 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java
@@ -23,15 +23,357 @@
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.common.guava.GuavaUtils;
import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.UOE;
import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
import java.util.Arrays;
+import java.util.List;
/**
* Generic result holder for evaluated {@link Expr} containing the value and {@link ExprType} of the value to allow
*/
public abstract class ExprEval<T>
{
+ private static final int NULL_LENGTH = -1;
+
+  /**
+   * Deserialize an expression result stored in a {@link ByteBuffer}, e.g. for an agg, reading the layout that is
+   * written by {@link #serialize(ByteBuffer, int, ExprEval, int)}:
+   *
+   * <pre>
+   * LONG, DOUBLE:  | type (byte) | is null (byte) | value (8 bytes, only read if not null) |
+   * STRING:        | type (byte) | length (int, negative means null) | utf8 bytes |
+   * LONG_ARRAY,
+   * DOUBLE_ARRAY:  | type (byte) | length (int, negative means null) | per element: is null (byte), value if not null |
+   * STRING_ARRAY:  | type (byte) | length (int, negative means null) | per element: length (int), utf8 bytes |
+   * </pre>
+   *
+   * Numeric values are read with absolute offsets. String bytes are read by temporarily moving the buffer position,
+   * which is restored to its previous value after each read.
+   *
+   * This should be refactored to be consolidated with some of the standard type handling of aggregators probably
+   *
+   * @param buffer   buffer holding the serialized value
+   * @param position absolute offset in the buffer of the type byte of the serialized value
+   * @return the decoded value
+   * @throws UOE if the byte at {@code position} does not map to any {@link ExprType}
+   */
+  public static ExprEval deserialize(ByteBuffer buffer, int position)
+  {
+    // | expression type (byte) | expression bytes |
+    final byte typeId = buffer.get(position);
+    final ExprType type = ExprType.fromByte(typeId);
+    if (type == null) {
+      // no type is registered for this id; switching on null would otherwise fail with a bare NullPointerException
+      throw new UOE("Unknown expression type id [%s] at position [%s]", typeId, position);
+    }
+    int offset = position + 1;
+    switch (type) {
+      case LONG:
+        // | expression type (byte) | is null (byte) | long bytes |
+        if (buffer.get(offset++) == NullHandling.IS_NOT_NULL_BYTE) {
+          return of(buffer.getLong(offset));
+        }
+        return ofLong(null);
+      case DOUBLE:
+        // | expression type (byte) | is null (byte) | double bytes |
+        if (buffer.get(offset++) == NullHandling.IS_NOT_NULL_BYTE) {
+          return of(buffer.getDouble(offset));
+        }
+        return ofDouble(null);
+      case STRING:
+        // | expression type (byte) | string length (int) | string bytes |
+        final int length = buffer.getInt(offset);
+        if (length < 0) {
+          return of(null);
+        }
+        final byte[] stringBytes = new byte[length];
+        // bulk get is relative, so move the position to the string bytes and put it back afterwards
+        final int oldPosition = buffer.position();
+        buffer.position(offset + Integer.BYTES);
+        buffer.get(stringBytes, 0, length);
+        buffer.position(oldPosition);
+        return of(StringUtils.fromUtf8(stringBytes));
+      case LONG_ARRAY:
+        // | expression type (byte) | array length (int) | array bytes |
+        final int longArrayLength = buffer.getInt(offset);
+        offset += Integer.BYTES;
+        if (longArrayLength < 0) {
+          return ofLongArray(null);
+        }
+        final Long[] longs = new Long[longArrayLength];
+        for (int i = 0; i < longArrayLength; i++) {
+          final byte isNull = buffer.get(offset);
+          offset += Byte.BYTES;
+          if (isNull == NullHandling.IS_NOT_NULL_BYTE) {
+            // | is null (byte) | long bytes |
+            longs[i] = buffer.getLong(offset);
+            offset += Long.BYTES;
+          } else {
+            // | is null (byte) |
+            longs[i] = null;
+          }
+        }
+        return ofLongArray(longs);
+      case DOUBLE_ARRAY:
+        // | expression type (byte) | array length (int) | array bytes |
+        final int doubleArrayLength = buffer.getInt(offset);
+        offset += Integer.BYTES;
+        if (doubleArrayLength < 0) {
+          return ofDoubleArray(null);
+        }
+        final Double[] doubles = new Double[doubleArrayLength];
+        for (int i = 0; i < doubleArrayLength; i++) {
+          final byte isNull = buffer.get(offset);
+          offset += Byte.BYTES;
+          if (isNull == NullHandling.IS_NOT_NULL_BYTE) {
+            // | is null (byte) | double bytes |
+            doubles[i] = buffer.getDouble(offset);
+            offset += Double.BYTES;
+          } else {
+            // | is null (byte) |
+            doubles[i] = null;
+          }
+        }
+        return ofDoubleArray(doubles);
+      case STRING_ARRAY:
+        // | expression type (byte) | array length (int) | array bytes |
+        final int stringArrayLength = buffer.getInt(offset);
+        offset += Integer.BYTES;
+        if (stringArrayLength < 0) {
+          return ofStringArray(null);
+        }
+        final String[] stringArray = new String[stringArrayLength];
+        for (int i = 0; i < stringArrayLength; i++) {
+          final int stringElementLength = buffer.getInt(offset);
+          offset += Integer.BYTES;
+          if (stringElementLength < 0) {
+            // | string length (int) |
+            stringArray[i] = null;
+          } else {
+            // | string length (int) | string bytes |
+            final byte[] stringElementBytes = new byte[stringElementLength];
+            final int oldPosition2 = buffer.position();
+            buffer.position(offset);
+            buffer.get(stringElementBytes, 0, stringElementLength);
+            buffer.position(oldPosition2);
+            stringArray[i] = StringUtils.fromUtf8(stringElementBytes);
+            offset += stringElementLength;
+          }
+        }
+        return ofStringArray(stringArray);
+      default:
+        throw new UOE("how can this be?");
+    }
+  }
+
+  /**
+   * Write an expression result to a bytebuffer, throwing an {@link ISE} if the data exceeds a maximum size. Primitive
+   * numeric types are not validated to be lower than max size, so it is expected to be at least 10 bytes. Callers
+   * of this method should enforce this themselves (instead of doing it here, which might be done every row)
+   *
+   * The first byte written is always the {@link ExprType} id, followed by a type specific payload:
+   *
+   * <pre>
+   * LONG, DOUBLE:  | type (byte) | is null (byte) | value (8 bytes, only written if not null) |
+   * STRING:        | type (byte) | length (int, NULL_LENGTH for null) | utf8 bytes |
+   * LONG_ARRAY,
+   * DOUBLE_ARRAY:  | type (byte) | length (int, NULL_LENGTH for null) | per element: is null (byte), value if not null |
+   * STRING_ARRAY:  | type (byte) | length (int, NULL_LENGTH for null) | per element: length (int), utf8 bytes |
+   * </pre>
+   *
+   * All writes use absolute offsets or restore the buffer position after a bulk write. Numeric arrays are fully
+   * size checked before any array bytes are written. String arrays are checked element by element, so earlier
+   * elements may already have been written when a later element fails the check.
+   *
+   * This should be refactored to be consolidated with some of the standard type handling of aggregators probably
+   *
+   * @param buffer       buffer to write to
+   * @param position     absolute offset in the buffer at which the type byte is written
+   * @param eval         value to write
+   * @param maxSizeBytes maximum number of bytes, including the type byte, the serialized value may occupy
+   * @throws ISE if a string or array value requires more than {@code maxSizeBytes}
+   */
+  public static void serialize(ByteBuffer buffer, int position, ExprEval<?> eval, int maxSizeBytes)
+  {
+    int offset = position;
+    buffer.put(offset++, eval.type().getId());
+    switch (eval.type()) {
+      case LONG:
+        // | expression type (byte) | is null (byte) | long bytes |
+        if (eval.isNumericNull()) {
+          buffer.put(offset, NullHandling.IS_NULL_BYTE);
+        } else {
+          buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
+          buffer.putLong(offset, eval.asLong());
+        }
+        break;
+      case DOUBLE:
+        // | expression type (byte) | is null (byte) | double bytes |
+        if (eval.isNumericNull()) {
+          buffer.put(offset, NullHandling.IS_NULL_BYTE);
+        } else {
+          buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
+          buffer.putDouble(offset, eval.asDouble());
+        }
+        break;
+      case STRING:
+        final byte[] stringBytes = StringUtils.toUtf8Nullable(eval.asString());
+        if (stringBytes != null) {
+          // | expression type (byte) | string length (int) | string bytes |
+          checkMaxBytes(eval.type(), 1 + Integer.BYTES + stringBytes.length, maxSizeBytes);
+          buffer.putInt(offset, stringBytes.length);
+          offset += Integer.BYTES;
+          // bulk put is relative, so move the position to the target offset and put it back afterwards
+          final int oldPosition = buffer.position();
+          buffer.position(offset);
+          buffer.put(stringBytes, 0, stringBytes.length);
+          buffer.position(oldPosition);
+        } else {
+          checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
+          buffer.putInt(offset, NULL_LENGTH);
+        }
+        break;
+      case LONG_ARRAY:
+        final Long[] longs = eval.asLongArray();
+        if (longs == null) {
+          // | expression type (byte) | array length (int) |
+          checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
+          buffer.putInt(offset, NULL_LENGTH);
+        } else {
+          // | expression type (byte) | array length (int) | array bytes |
+          checkMaxBytes(eval.type(), numericArraySizeBytes(longs, Long.BYTES), maxSizeBytes);
+          buffer.putInt(offset, longs.length);
+          offset += Integer.BYTES;
+          for (Long aLong : longs) {
+            if (aLong != null) {
+              // | is null (byte) | long bytes |
+              buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
+              buffer.putLong(offset, aLong);
+              offset += Long.BYTES;
+            } else {
+              // | is null (byte) |
+              buffer.put(offset++, NullHandling.IS_NULL_BYTE);
+            }
+          }
+        }
+        break;
+      case DOUBLE_ARRAY:
+        final Double[] doubles = eval.asDoubleArray();
+        if (doubles == null) {
+          // | expression type (byte) | array length (int) |
+          checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
+          buffer.putInt(offset, NULL_LENGTH);
+        } else {
+          // | expression type (byte) | array length (int) | array bytes |
+          checkMaxBytes(eval.type(), numericArraySizeBytes(doubles, Double.BYTES), maxSizeBytes);
+          buffer.putInt(offset, doubles.length);
+          offset += Integer.BYTES;
+          for (Double aDouble : doubles) {
+            if (aDouble != null) {
+              // | is null (byte) | double bytes |
+              buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
+              buffer.putDouble(offset, aDouble);
+              offset += Double.BYTES;
+            } else {
+              // | is null (byte) |
+              buffer.put(offset++, NullHandling.IS_NULL_BYTE);
+            }
+          }
+        }
+        break;
+      case STRING_ARRAY:
+        final String[] strings = eval.asStringArray();
+        if (strings == null) {
+          // | expression type (byte) | array length (int) |
+          checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
+          buffer.putInt(offset, NULL_LENGTH);
+        } else {
+          // | expression type (byte) | array length (int) | array bytes |
+          int sizeBytes = 1 + Integer.BYTES;
+          // validate the header before writing it, an empty array has no elements to trigger a check later
+          checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
+          buffer.putInt(offset, strings.length);
+          offset += Integer.BYTES;
+          for (String string : strings) {
+            if (string == null) {
+              // | string length (int) |
+              sizeBytes += Integer.BYTES;
+              checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
+              buffer.putInt(offset, NULL_LENGTH);
+              offset += Integer.BYTES;
+            } else {
+              // | string length (int) | string bytes |
+              final byte[] stringElementBytes = StringUtils.toUtf8(string);
+              sizeBytes += Integer.BYTES + stringElementBytes.length;
+              checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
+              buffer.putInt(offset, stringElementBytes.length);
+              offset += Integer.BYTES;
+              final int oldPosition = buffer.position();
+              buffer.position(offset);
+              buffer.put(stringElementBytes, 0, stringElementBytes.length);
+              buffer.position(oldPosition);
+              offset += stringElementBytes.length;
+            }
+          }
+        }
+        break;
+      default:
+        throw new UOE("how can this be?");
+    }
+  }
+
+  /**
+   * Compute the exact number of bytes {@link #serialize} writes for a non-null numeric array: the type byte, the
+   * array length, and for every element a null marker byte followed by the value bytes when the element is not null.
+   */
+  private static int numericArraySizeBytes(Object[] values, int elementBytes)
+  {
+    // | expression type (byte) | array length (int) |
+    int sizeBytes = 1 + Integer.BYTES;
+    for (Object value : values) {
+      // | is null (byte) | value bytes, only if not null |
+      sizeBytes += Byte.BYTES + (value == null ? 0 : elementBytes);
+    }
+    return sizeBytes;
+  }
+
+  /**
+   * Guard used while serializing: fails when the number of bytes needed for a value is greater than the allowed
+   * maximum.
+   *
+   * @throws ISE if {@code sizeBytes} is larger than {@code maxSizeBytes}
+   */
+  private static void checkMaxBytes(ExprType type, int sizeBytes, int maxSizeBytes)
+  {
+    final boolean fits = sizeBytes <= maxSizeBytes;
+    if (fits) {
+      return;
+    }
+    throw new ISE("Unable to serialize [%s], size [%s] is larger than max [%s]", type, sizeBytes, maxSizeBytes);
+  }
+
+  /**
+   * Converts a {@link List} into an array whose element type is chosen from the non-null elements of the list:
+   * {@code Long[]} when the common type is Long or Integer, {@code Double[]} when it is Float or Double, and
+   * {@code String[]} for everything else (including a list that contains only nulls). Null elements stay null.
+   *
+   * Selector types are not consistent in their treatment of null, [], and [null] for multi-valued strings. If
+   * homogenizeMultiValueStrings is true, null and [] are both converted to [null]. If it is false, null is returned
+   * as null and [] is returned as an empty object array.
+   *
+   * @param val                         list to convert, may be null
+   * @param homogenizeMultiValueStrings whether null and empty lists should become a single element null string array
+   */
+  @Nullable
+  public static Object coerceListToArray(@Nullable List<?> val, boolean homogenizeMultiValueStrings)
+  {
+    // null and empty lists are the ambiguous cases, handle them up front
+    if (val == null || val.isEmpty()) {
+      if (homogenizeMultiValueStrings) {
+        return new String[]{null};
+      }
+      return val == null ? null : val.toArray();
+    }
+
+    // with at least 1 element, conversion is unambiguous regardless of the selector
+    Class<?> elementType = null;
+    for (Object element : val) {
+      if (element != null) {
+        elementType = convertType(elementType, element.getClass());
+      }
+    }
+
+    final int size = val.size();
+    int i = 0;
+    if (elementType == Long.class || elementType == Integer.class) {
+      final Long[] longs = new Long[size];
+      for (Object element : val) {
+        longs[i++] = element == null ? null : ((Number) element).longValue();
+      }
+      return longs;
+    }
+    if (elementType == Float.class || elementType == Double.class) {
+      final Double[] doubles = new Double[size];
+      for (Object element : val) {
+        doubles[i++] = element == null ? null : ((Number) element).doubleValue();
+      }
+      return doubles;
+    }
+    // default to string
+    final String[] strings = new String[size];
+    for (Object element : val) {
+      strings[i++] = element == null ? null : element.toString();
+    }
+    return strings;
+  }
+
+  /**
+   * Pick the type that can represent both the type chosen so far and the type of the next element, used to choose an
+   * array type for a set of objects of unknown type. This is used to assist in preparing native java objects for
+   * {@link Expr.ObjectBinding} which will later be wrapped in {@link ExprEval} when evaluating
+   * {@link IdentifierExpr}.
+   *
+   * Rules, in the order they are applied:
+   * <ul>
+   *   <li>only {@link Number} subclasses and {@link String} are accepted for the next type</li>
+   *   <li>with no existing type, the next type is used as is</li>
+   *   <li>if either type is string the result is string, because everything can be converted to a string</li>
+   *   <li>an existing Integer always gives way to the next type</li>
+   *   <li>an existing Float is kept unless the next type is Double</li>
+   *   <li>an existing Long is kept only against Integer, otherwise the next type is used</li>
+   *   <li>anything else resolves to Double</li>
+   * </ul>
+   *
+   * @throws UOE if the next type is neither a number nor a string
+   */
+  private static Class convertType(@Nullable Class existing, Class next)
+  {
+    final boolean supported = Number.class.isAssignableFrom(next) || next == String.class;
+    if (!supported) {
+      throw new UOE("Invalid array expression type: %s", next);
+    }
+    if (existing == null) {
+      return next;
+    }
+    // string wins everything
+    if (existing == String.class || next == String.class) {
+      return String.class;
+    }
+    // all numbers win over Integer
+    if (existing == Integer.class) {
+      return next;
+    }
+    // doubles win over floats
+    if (existing == Float.class) {
+      return next == Double.class ? next : existing;
+    }
+    // long beats int, double and float win over longs
+    if (existing == Long.class) {
+      return next == Integer.class ? existing : next;
+    }
+    // otherwise double
+    return Double.class;
+  }
+
public static ExprEval ofLong(@Nullable Number longValue)
{
return new LongExprEval(longValue);
@@ -118,6 +460,13 @@
return new StringArrayExprEval((String[]) val);
}
+ if (val instanceof List) {
+ // do not convert empty lists to arrays with a single null element here, because that should have been done
+ // by the selectors preparing their ObjectBindings if necessary. If we get to this point it was legitimately
+ // empty
+ return bestEffortOf(coerceListToArray((List<?>) val, false));
+ }
+
return new StringExprEval(val == null ? null : String.valueOf(val));
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java
index eaacf56..80b5e03 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java
@@ -19,6 +19,8 @@
package org.apache.druid.math.expr;
+import it.unimi.dsi.fastutil.bytes.Byte2ObjectArrayMap;
+import it.unimi.dsi.fastutil.bytes.Byte2ObjectMap;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.column.ValueType;
@@ -29,13 +31,32 @@
*/
public enum ExprType
{
- DOUBLE,
- LONG,
- STRING,
- DOUBLE_ARRAY,
- LONG_ARRAY,
- STRING_ARRAY;
+ DOUBLE((byte) 0x01),
+ LONG((byte) 0x02),
+ STRING((byte) 0x03),
+ DOUBLE_ARRAY((byte) 0x04),
+ LONG_ARRAY((byte) 0x05),
+ STRING_ARRAY((byte) 0x06);
+  /**
+   * Lookup from the byte id of a type to the type itself, populated once from {@link #values()}
+   */
+  private static final Byte2ObjectMap<ExprType> TYPE_BYTES;
+
+  static {
+    final ExprType[] types = ExprType.values();
+    TYPE_BYTES = new Byte2ObjectArrayMap<>(types.length);
+    for (int i = 0; i < types.length; i++) {
+      TYPE_BYTES.put(types[i].getId(), types[i]);
+    }
+  }
+
+  /**
+   * Byte identifier of this type, as supplied to the enum constant
+   */
+  final byte id;
+
+  ExprType(byte id)
+  {
+    this.id = id;
+  }
+
+  /**
+   * @return the byte identifier of this type
+   */
+  public byte getId()
+  {
+    return this.id;
+  }
public boolean isNumeric()
{
@@ -47,6 +68,11 @@
return isScalar(this);
}
+  /**
+   * Look up the type registered for a byte id, the inverse of {@link #getId()}.
+   *
+   * NOTE(review): the result is whatever the backing map returns for an id it does not contain, presumably null -
+   * callers should confirm and check before using the result.
+   */
+  public static ExprType fromByte(byte id)
+  {
+    final ExprType match = TYPE_BYTES.get(id);
+    return match;
+  }
+
/**
* The expression system does not distinguish between {@link ValueType#FLOAT} and {@link ValueType#DOUBLE}, and
* cannot currently handle {@link ValueType#COMPLEX} inputs. This method will convert {@link ValueType#FLOAT} to
@@ -177,5 +203,4 @@
}
return elementType;
}
-
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java
index baa5768..7b9fe527 100644
--- a/core/src/main/java/org/apache/druid/math/expr/Function.java
+++ b/core/src/main/java/org/apache/druid/math/expr/Function.java
@@ -42,6 +42,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
+import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@@ -384,13 +385,13 @@
@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
- return ImmutableSet.of(args.get(1));
+ return ImmutableSet.of(getScalarArgument(args));
}
@Override
public Set<Expr> getArrayInputs(List<Expr> args)
{
- return ImmutableSet.of(args.get(0));
+ return ImmutableSet.of(getArrayArgument(args));
}
@Override
@@ -402,14 +403,24 @@
@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
- final ExprEval arrayExpr = args.get(0).eval(bindings);
- final ExprEval scalarExpr = args.get(1).eval(bindings);
+ final ExprEval arrayExpr = getArrayArgument(args).eval(bindings);
+ final ExprEval scalarExpr = getScalarArgument(args).eval(bindings);
if (arrayExpr.asArray() == null) {
return ExprEval.of(null);
}
return doApply(arrayExpr, scalarExpr);
}
+    /**
+     * Select the scalar operand from the function arguments, by default the second argument. Subclasses may override
+     * this to read the scalar from a different position.
+     */
+    Expr getScalarArgument(List<Expr> args)
+    {
+      final int scalarPosition = 1;
+      return args.get(scalarPosition);
+    }
+
+    /**
+     * Select the array operand from the function arguments, by default the first argument. Subclasses may override
+     * this to read the array from a different position.
+     */
+    Expr getArrayArgument(List<Expr> args)
+    {
+      final int arrayPosition = 0;
+      return args.get(arrayPosition);
+    }
+
abstract ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr);
}
@@ -450,8 +461,11 @@
final ExprEval arrayExpr1 = args.get(0).eval(bindings);
final ExprEval arrayExpr2 = args.get(1).eval(bindings);
- if (arrayExpr1.asArray() == null || arrayExpr2.asArray() == null) {
- return ExprEval.of(null);
+ if (arrayExpr1.asArray() == null) {
+ return arrayExpr1;
+ }
+ if (arrayExpr2.asArray() == null) {
+ return arrayExpr2;
}
return doApply(arrayExpr1, arrayExpr2);
@@ -460,6 +474,118 @@
abstract ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr);
}
+  /**
+   * Scaffolding for a 2 argument {@link Function} which takes one array and one scalar input and produces a new
+   * array containing the scalar. How the scalar is combined with the array is left to {@link #add}.
+   */
+  abstract class ArrayAddElementFunction extends ArrayScalarFunction
+  {
+    @Override
+    public boolean hasArrayOutput()
+    {
+      return true;
+    }
+
+    @Nullable
+    @Override
+    public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
+    {
+      // use the array form of the array argument's type, falling back to the type itself if there is none
+      final ExprType inputType = getArrayArgument(args).getOutputType(inspector);
+      final ExprType asArray = ExprType.asArrayType(inputType);
+      return asArray != null ? asArray : inputType;
+    }
+
+    @Override
+    ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
+    {
+      final ExprType arrayType = arrayExpr.type();
+      switch (arrayType) {
+        case STRING:
+        case STRING_ARRAY: {
+          final String[] strings = arrayExpr.asStringArray();
+          final String element = scalarExpr.asString();
+          return ExprEval.ofStringArray(add(strings, element).toArray(String[]::new));
+        }
+        case LONG:
+        case LONG_ARRAY: {
+          final Long[] longs = arrayExpr.asLongArray();
+          // a null numeric scalar is added as a null element instead of being coerced to a number
+          final Long element = scalarExpr.isNumericNull() ? null : scalarExpr.asLong();
+          return ExprEval.ofLongArray(add(longs, element).toArray(Long[]::new));
+        }
+        case DOUBLE:
+        case DOUBLE_ARRAY: {
+          final Double[] doubles = arrayExpr.asDoubleArray();
+          final Double element = scalarExpr.isNumericNull() ? null : scalarExpr.asDouble();
+          return ExprEval.ofDoubleArray(add(doubles, element).toArray(Double[]::new));
+        }
+      }
+
+      throw new RE("Unable to add to unknown array type %s", arrayType);
+    }
+
+    /**
+     * Combine the elements of the array with the scalar value, returning the elements of the resulting array
+     */
+    abstract <T> Stream<T> add(T[] array, @Nullable T val);
+  }
+
+  /**
+   * Base scaffolding for functions which accept 2 array arguments and produce a single array from them. How the two
+   * arrays are combined is left to {@link #merge}.
+   */
+  abstract class ArraysMergeFunction extends ArraysFunction
+  {
+    @Override
+    public Set<Expr> getArrayInputs(List<Expr> args)
+    {
+      // every argument is an array input
+      return ImmutableSet.copyOf(args);
+    }
+
+    @Override
+    public boolean hasArrayOutput()
+    {
+      return true;
+    }
+
+    @Nullable
+    @Override
+    public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
+    {
+      // use the array form of the first argument's type, falling back to the type itself if there is none
+      final ExprType firstArgType = args.get(0).getOutputType(inspector);
+      final ExprType asArray = ExprType.asArrayType(firstArgType);
+      return asArray != null ? asArray : firstArgType;
+    }
+
+    @Override
+    ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr)
+    {
+      final Object[] left = lhsExpr.asArray();
+      final Object[] right = rhsExpr.asArray();
+
+      if (left == null) {
+        return ExprEval.of(null);
+      }
+      if (right == null) {
+        // nothing to merge in, the left side is the result
+        return lhsExpr;
+      }
+
+      // the type of the left side decides the type of the output array
+      final ExprType lhsType = lhsExpr.type();
+      switch (lhsType) {
+        case STRING:
+        case STRING_ARRAY: {
+          final String[] merged = merge(lhsExpr.asStringArray(), rhsExpr.asStringArray()).toArray(String[]::new);
+          return ExprEval.ofStringArray(merged);
+        }
+        case LONG:
+        case LONG_ARRAY: {
+          final Long[] merged = merge(lhsExpr.asLongArray(), rhsExpr.asLongArray()).toArray(Long[]::new);
+          return ExprEval.ofLongArray(merged);
+        }
+        case DOUBLE:
+        case DOUBLE_ARRAY: {
+          final Double[] merged = merge(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray()).toArray(Double[]::new);
+          return ExprEval.ofDoubleArray(merged);
+        }
+      }
+      throw new RE("Unable to concatenate to unknown type %s", lhsType);
+    }
+
+    /**
+     * Combine the elements of both arrays, returning the elements of the resulting array
+     */
+    abstract <T> Stream<T> merge(T[] array1, T[] array2);
+  }
+
abstract class ReduceFunction implements Function
{
private final DoubleBinaryOperator doubleReducer;
@@ -3168,7 +3294,7 @@
}
}
- class ArrayAppendFunction extends ArrayScalarFunction
+ class ArrayAppendFunction extends ArrayAddElementFunction
{
@Override
public String name()
@@ -3177,48 +3303,7 @@
}
@Override
- public boolean hasArrayOutput()
- {
- return true;
- }
-
- @Nullable
- @Override
- public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
- {
- ExprType arrayType = args.get(0).getOutputType(inspector);
- return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
- }
-
- @Override
- ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
- {
- switch (arrayExpr.type()) {
- case STRING:
- case STRING_ARRAY:
- return ExprEval.ofStringArray(this.append(arrayExpr.asStringArray(), scalarExpr.asString()).toArray(String[]::new));
- case LONG:
- case LONG_ARRAY:
- return ExprEval.ofLongArray(
- this.append(
- arrayExpr.asLongArray(),
- scalarExpr.isNumericNull() ? null : scalarExpr.asLong()).toArray(Long[]::new
- )
- );
- case DOUBLE:
- case DOUBLE_ARRAY:
- return ExprEval.ofDoubleArray(
- this.append(
- arrayExpr.asDoubleArray(),
- scalarExpr.isNumericNull() ? null : scalarExpr.asDouble()).toArray(Double[]::new
- )
- );
- }
-
- throw new RE("Unable to append to unknown type %s", arrayExpr.type());
- }
-
- private <T> Stream<T> append(T[] array, T val)
+ <T> Stream<T> add(T[] array, @Nullable T val)
{
List<T> l = new ArrayList<>(Arrays.asList(array));
l.add(val);
@@ -3226,7 +3311,36 @@
}
}
- class ArrayConcatFunction extends ArraysFunction
+  /**
+   * "array_prepend" function: takes the scalar as its first argument and the array as its second, and produces an
+   * array with the scalar placed in front of the existing elements.
+   */
+  class ArrayPrependFunction extends ArrayAddElementFunction
+  {
+    @Override
+    public String name()
+    {
+      return "array_prepend";
+    }
+
+    @Override
+    Expr getScalarArgument(List<Expr> args)
+    {
+      // scalar comes first for prepend
+      return args.get(0);
+    }
+
+    @Override
+    Expr getArrayArgument(List<Expr> args)
+    {
+      // array comes second for prepend
+      return args.get(1);
+    }
+
+    @Override
+    <T> Stream<T> add(T[] array, @Nullable T val)
+    {
+      final List<T> prepended = new ArrayList<>(array.length + 1);
+      prepended.add(val);
+      Collections.addAll(prepended, array);
+      return prepended.stream();
+    }
+  }
+
+ class ArrayConcatFunction extends ArraysMergeFunction
{
@Override
public String name()
@@ -3235,59 +3349,7 @@
}
@Override
- public Set<Expr> getArrayInputs(List<Expr> args)
- {
- return ImmutableSet.copyOf(args);
- }
-
- @Override
- public boolean hasArrayOutput()
- {
- return true;
- }
-
- @Nullable
- @Override
- public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
- {
- ExprType arrayType = args.get(0).getOutputType(inspector);
- return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
- }
-
- @Override
- ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr)
- {
- final Object[] array1 = lhsExpr.asArray();
- final Object[] array2 = rhsExpr.asArray();
-
- if (array1 == null) {
- return ExprEval.of(null);
- }
- if (array2 == null) {
- return lhsExpr;
- }
-
- switch (lhsExpr.type()) {
- case STRING:
- case STRING_ARRAY:
- return ExprEval.ofStringArray(
- cat(lhsExpr.asStringArray(), rhsExpr.asStringArray()).toArray(String[]::new)
- );
- case LONG:
- case LONG_ARRAY:
- return ExprEval.ofLongArray(
- cat(lhsExpr.asLongArray(), rhsExpr.asLongArray()).toArray(Long[]::new)
- );
- case DOUBLE:
- case DOUBLE_ARRAY:
- return ExprEval.ofDoubleArray(
- cat(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray()).toArray(Double[]::new)
- );
- }
- throw new RE("Unable to concatenate to unknown type %s", lhsExpr.type());
- }
-
- private <T> Stream<T> cat(T[] array1, T[] array2)
+ <T> Stream<T> merge(T[] array1, T[] array2)
{
List<T> l = new ArrayList<>(Arrays.asList(array1));
l.addAll(Arrays.asList(array2));
@@ -3295,6 +3357,40 @@
}
}
+  /**
+   * "array_set_add" function: adds the scalar to the array treated as a set, so duplicate values (of the input array
+   * and of the added scalar) appear only once in the output. Elements are produced in {@link HashSet} iteration
+   * order.
+   */
+  class ArraySetAddFunction extends ArrayAddElementFunction
+  {
+    @Override
+    public String name()
+    {
+      return "array_set_add";
+    }
+
+    @Override
+    <T> Stream<T> add(T[] array, @Nullable T val)
+    {
+      final List<T> existing = Arrays.asList(array);
+      final Set<T> distinct = new HashSet<>(existing);
+      distinct.add(val);
+      return distinct.stream();
+    }
+  }
+
+  /**
+   * "array_set_add_all" function: combines two arrays treated as sets, so values present more than once across the
+   * inputs appear only once in the output. Elements are produced in {@link HashSet} iteration order.
+   */
+  class ArraySetAddAllFunction extends ArraysMergeFunction
+  {
+    @Override
+    public String name()
+    {
+      return "array_set_add_all";
+    }
+
+    @Override
+    <T> Stream<T> merge(T[] array1, T[] array2)
+    {
+      final List<T> first = Arrays.asList(array1);
+      final Set<T> distinct = new HashSet<>(first);
+      for (T element : array2) {
+        distinct.add(element);
+      }
+      return distinct.stream();
+    }
+  }
+
class ArrayContainsFunction extends ArraysFunction
{
@Override
@@ -3438,93 +3534,4 @@
throw new RE("Unable to slice to unknown type %s", expr.type());
}
}
-
- class ArrayPrependFunction implements Function
- {
- @Override
- public String name()
- {
- return "array_prepend";
- }
-
- @Override
- public void validateArguments(List<Expr> args)
- {
- if (args.size() != 2) {
- throw new IAE("Function[%s] needs 2 arguments", name());
- }
- }
-
- @Nullable
- @Override
- public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
- {
- ExprType arrayType = args.get(1).getOutputType(inspector);
- return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
- }
-
- @Override
- public Set<Expr> getScalarInputs(List<Expr> args)
- {
- return ImmutableSet.of(args.get(0));
- }
-
- @Override
- public Set<Expr> getArrayInputs(List<Expr> args)
- {
- return ImmutableSet.of(args.get(1));
- }
-
- @Override
- public boolean hasArrayInputs()
- {
- return true;
- }
-
- @Override
- public boolean hasArrayOutput()
- {
- return true;
- }
-
- @Override
- public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
- {
- final ExprEval scalarExpr = args.get(0).eval(bindings);
- final ExprEval arrayExpr = args.get(1).eval(bindings);
- if (arrayExpr.asArray() == null) {
- return ExprEval.of(null);
- }
- switch (arrayExpr.type()) {
- case STRING:
- case STRING_ARRAY:
- return ExprEval.ofStringArray(this.prepend(scalarExpr.asString(), arrayExpr.asStringArray()).toArray(String[]::new));
- case LONG:
- case LONG_ARRAY:
- return ExprEval.ofLongArray(
- this.prepend(
- scalarExpr.isNumericNull() ? null : scalarExpr.asLong(),
- arrayExpr.asLongArray()).toArray(Long[]::new
- )
- );
- case DOUBLE:
- case DOUBLE_ARRAY:
- return ExprEval.ofDoubleArray(
- this.prepend(
- scalarExpr.isNumericNull() ? null : scalarExpr.asDouble(),
- arrayExpr.asDoubleArray()).toArray(Double[]::new
- )
- );
- }
-
- throw new RE("Unable to prepend to unknown type %s", arrayExpr.type());
- }
-
- private <T> Stream<T> prepend(T val, T[] array)
- {
- List<T> l = new ArrayList<>(Arrays.asList(array));
- l.add(0, val);
- return l.stream();
- }
- }
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/Parser.java b/core/src/main/java/org/apache/druid/math/expr/Parser.java
index c9388bf..b0c923c 100644
--- a/core/src/main/java/org/apache/druid/math/expr/Parser.java
+++ b/core/src/main/java/org/apache/druid/math/expr/Parser.java
@@ -22,6 +22,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Supplier;
+import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
@@ -35,12 +36,14 @@
import org.apache.druid.math.expr.antlr.ExprLexer;
import org.apache.druid.math.expr.antlr.ExprParser;
+import javax.annotation.Nullable;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
public class Parser
@@ -96,17 +99,33 @@
}
/**
- * Parse a string into a flattened {@link Expr}. There is some overhead to this, and these objects are all immutable,
- * so re-use instead of re-creating whenever possible.
+ * Create a memoized lazy supplier to parse a string into a flattened {@link Expr}. There is some overhead to this,
+ * and these objects are all immutable, so this assists in the goal of re-using instead of re-creating whenever
+ * possible.
+ *
+ * Lazy form of {@link #parse(String, ExprMacroTable)}
+ *
* @param in expression to parse
* @param macroTable additional extensions to expression language
- * @return
+ */
+  public static Supplier<Expr> lazyParse(@Nullable String in, ExprMacroTable macroTable)
+  {
+    return Suppliers.memoize(
+        () -> {
+          if (in == null) {
+            // nothing to parse, the supplier yields null
+            return null;
+          }
+          return Parser.parse(in, macroTable);
+        }
+    );
+  }
+
+ /**
+ * Parse a string into a flattened {@link Expr}. There is some overhead to this, and these objects are all immutable,
+ * so re-use instead of re-creating whenever possible.
+ *
+ * @param in expression to parse
+ * @param macroTable additional extensions to expression language
*/
public static Expr parse(String in, ExprMacroTable macroTable)
{
return parse(in, macroTable, true);
}
+
@VisibleForTesting
public static Expr parse(String in, ExprMacroTable macroTable, boolean withFlatten)
{
@@ -164,47 +183,43 @@
/**
* Applies a transformation to an {@link Expr} given a list of known (or uknown) multi-value input columns that are
* used in a scalar manner, walking the {@link Expr} tree and lifting array variables into the {@link LambdaExpr} of
- * {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr}
- * @param expr expression to visit and rewrite
- * @param bindingsToApply
- * @return
+ * {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr} as necessary.
+ *
+ * This function applies a transformation for "map" style uses, such as column selectors, where the supplied
+ * expression will be transformed to return an array of results instead of the scalar result (or appropriately
+ * rewritten into existing apply expressions to produce correct results when referenced from a scalar context).
+ *
+ * This function and {@link #foldUnappliedBindings(Expr, Expr.BindingAnalysis, List, String)} exist to handle
+ * "multi-valued" string dimensions, which exist in a superposition of both single and multi-valued during realtime
+ * ingestion, until they are written to a segment and become locked into either single or multi-valued. This also
+ * means that multi-valued-ness can vary for a column from segment to segment, so this family of transformation
+ * functions exist so that multi-valued strings can be expressed in either an array or scalar context, which is
+ * important because the writer of the query might not actually know if the column is definitively always single or
+ * multi-valued (and it might in fact not be).
+ *
+ * @see #foldUnappliedBindings(Expr, Expr.BindingAnalysis, List, String)
*/
- public static Expr applyUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List<String> bindingsToApply)
+ public static Expr applyUnappliedBindings(
+ Expr expr,
+ Expr.BindingAnalysis bindingAnalysis,
+ List<String> bindingsToApply
+ )
{
if (bindingsToApply.isEmpty()) {
// nothing to do, expression is fine as is
return expr;
}
// filter the list of bindings to those which are used in this expression
- List<String> unappliedBindingsInExpression = bindingsToApply.stream()
- .filter(x -> bindingAnalysis.getRequiredBindings().contains(x))
- .collect(Collectors.toList());
+ List<String> unappliedBindingsInExpression =
+ bindingsToApply.stream()
+ .filter(x -> bindingAnalysis.getRequiredBindings().contains(x))
+ .collect(Collectors.toList());
// any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten
- Expr newExpr = expr.visit(
- childExpr -> {
- if (childExpr instanceof ApplyFunctionExpr) {
- // try to lift unapplied arguments into the apply function lambda
- return liftApplyLambda((ApplyFunctionExpr) childExpr, unappliedBindingsInExpression);
- } else if (childExpr instanceof FunctionExpr) {
- // check array function arguments for unapplied identifiers to transform if necessary
- FunctionExpr fnExpr = (FunctionExpr) childExpr;
- Set<Expr> arrayInputs = fnExpr.function.getArrayInputs(fnExpr.args);
- List<Expr> newArgs = new ArrayList<>();
- for (Expr arg : fnExpr.args) {
- if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) {
- Expr newArg = applyUnappliedBindings(arg, bindingAnalysis, unappliedBindingsInExpression);
- newArgs.add(newArg);
- } else {
- newArgs.add(arg);
- }
- }
-
- FunctionExpr newFnExpr = new FunctionExpr(fnExpr.function, fnExpr.function.name(), newArgs);
- return newFnExpr;
- }
- return childExpr;
- }
+ Expr newExpr = rewriteUnappliedSubExpressions(
+ expr,
+ unappliedBindingsInExpression,
+ (arg) -> applyUnappliedBindings(arg, bindingAnalysis, bindingsToApply)
);
Expr.BindingAnalysis newExprBindings = newExpr.analyzeInputs();
@@ -221,9 +236,123 @@
return applyUnapplied(newExpr, remainingUnappliedBindings);
}
+
+ /**
+ * Applies a transformation to an {@link Expr} given a list of known (or unknown) multi-value input columns that are
+ * used in a scalar manner, walking the {@link Expr} tree and lifting array variables into the {@link LambdaExpr} of
+ * {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr} as necessary.
+ *
+ * This function applies a transformation for "fold" style uses, such as aggregators, where the supplied
+ * expression will be transformed to accumulate the result of applying the expression to each value of the unapplied
+ * input (or appropriately rewritten into existing apply expressions to produce correct results when referenced from
+ * a scalar context). This rewriting assumes that there exists some accumulator variable, which is re-used as the
+ * accumulator for this fold rewrite, so that evaluating each expression can be accumulated into the larger external
+ * fold operation that an aggregator might be performing.
+ *
+ * This function and {@link #applyUnappliedBindings(Expr, Expr.BindingAnalysis, List)} exist to handle
+ * "multi-valued" string dimensions, which exist in a superposition of both single and multi-valued during realtime
+ * ingestion, until they are written to a segment and become locked into either single or multi-valued. This also
+ * means that multi-valued-ness can vary for a column from segment to segment, so this family of transformation
+ * functions exist so that multi-valued strings can be expressed in either an array or scalar context, which is
+ * important because the writer of the query might not actually know if the column is definitively always single or
+ * multi-valued (and it might in fact not be).
+ *
+ * @see #applyUnappliedBindings(Expr, Expr.BindingAnalysis, List)
+ */
+ public static Expr foldUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List<String> bindingsToApply, String accumulatorId)
+ {
+ if (bindingsToApply.isEmpty()) {
+ // nothing to do, expression is fine as is
+ return expr;
+ }
+
+ // filter the list of bindings to those which are used in this expression
+ List<String> unappliedBindingsInExpression =
+ bindingsToApply.stream()
+ .filter(x -> bindingAnalysis.getRequiredBindings().contains(x))
+ .collect(Collectors.toList());
+
+ Expr newExpr = rewriteUnappliedSubExpressions(
+ expr,
+ unappliedBindingsInExpression,
+ (arg) -> foldUnappliedBindings(arg, bindingAnalysis, bindingsToApply, accumulatorId)
+ );
+
+ Expr.BindingAnalysis newExprBindings = newExpr.analyzeInputs();
+ final Set<String> expectedArrays = newExprBindings.getArrayVariables();
+
+ List<String> remainingUnappliedBindings =
+ unappliedBindingsInExpression.stream().filter(x -> !expectedArrays.contains(x)).collect(Collectors.toList());
+
+ // if lifting the lambdas got rid of all missing bindings, return the transformed expression
+ if (remainingUnappliedBindings.isEmpty()) {
+ return newExpr;
+ }
+
+ return foldUnapplied(newExpr, remainingUnappliedBindings, accumulatorId);
+ }
+
+ /**
+ * Any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten to "lift"
+ * the identifier variables and transform the function.
+ *
+ * For example:
+ * if "y" is unapplied:
+ * map((x) -> x + y, x) => cartesian_map((x,y) -> x + y, x, y)
+ *
+ * @see #liftApplyLambda(ApplyFunctionExpr, List)
+ *
+ * Array functions on expressions using unapplied identifiers might also need to be transformed, so we recursively
+ * call the unapplied binding transformation function (supplied to this method) on that expression to ensure proper
+ * transformation and rewrite of these array expressions.
+ *
+ * For example:
+ * if "y" is unapplied:
+ * array_length(filter((x) -> x > y, x))
+ */
+ private static Expr rewriteUnappliedSubExpressions(
+ Expr expr,
+ List<String> unappliedBindingsInExpression,
+ UnaryOperator<Expr> applyUnappliedFn
+ )
+ {
+ // any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten
+ return expr.visit(
+ childExpr -> {
+ if (childExpr instanceof ApplyFunctionExpr) {
+ // try to lift unapplied arguments into the apply function lambda
+ return liftApplyLambda((ApplyFunctionExpr) childExpr, unappliedBindingsInExpression);
+ } else if (childExpr instanceof FunctionExpr) {
+ // check array function arguments for unapplied identifiers to transform if necessary
+ FunctionExpr fnExpr = (FunctionExpr) childExpr;
+ Set<Expr> arrayInputs = fnExpr.function.getArrayInputs(fnExpr.args);
+ List<Expr> newArgs = new ArrayList<>();
+ for (Expr arg : fnExpr.args) {
+ if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) {
+ Expr newArg = applyUnappliedFn.apply(arg);
+ newArgs.add(newArg);
+ } else {
+ newArgs.add(arg);
+ }
+ }
+
+ FunctionExpr newFnExpr = new FunctionExpr(fnExpr.function, fnExpr.function.name(), newArgs);
+ return newFnExpr;
+ }
+ return childExpr;
+ }
+ );
+ }
+
/**
* translate an {@link Expr} into an {@link ApplyFunctionExpr} for {@link ApplyFunction.MapFunction} or
- * {@link ApplyFunction.CartesianMapFunction} if there are multiple unbound arguments to be applied
+ * {@link ApplyFunction.CartesianMapFunction} if there are multiple unbound arguments to be applied.
+ *
+ * For example:
+ * if "x" is unapplied:
+ * x + y => map((x) -> x + y, x)
+ * if "x" and "y" are unapplied:
+ * x + y => cartesian_map((x, y) -> x + y, x, y)
*/
private static Expr applyUnapplied(Expr expr, List<String> unappliedBindings)
{
@@ -275,6 +404,72 @@
}
/**
+ * translate an {@link Expr} into an {@link ApplyFunctionExpr} for {@link ApplyFunction.FoldFunction} or
+ * {@link ApplyFunction.CartesianFoldFunction} if there are multiple unbound arguments to be applied.
+ *
+ * This assumes a known {@link IdentifierExpr} is an "accumulator", which is re-used as the accumulator variable and
+ * input for the translated fold.
+ *
+ * For example given an accumulator "__acc":
+ * if "x" is unapplied:
+ * __acc + x => fold((x, __acc) -> __acc + x, x, __acc)
+ * if "x" and "y" are unapplied:
+ * __acc + x + y => cartesian_fold((x, y, __acc) -> __acc + x + y, x, y, __acc)
+ *
+ */
+ private static Expr foldUnapplied(Expr expr, List<String> unappliedBindings, String accumulatorId)
+ {
+
+ // filter to get list of IdentifierExpr that are backed by the unapplied bindings
+ final List<IdentifierExpr> args = expr.analyzeInputs()
+ .getFreeVariables()
+ .stream()
+ .filter(x -> unappliedBindings.contains(x.getBinding()))
+ .collect(Collectors.toList());
+
+ final List<IdentifierExpr> lambdaArgs = new ArrayList<>();
+
+ // construct lambda args from list of args to apply. Identifiers in a lambda body have artificial 'binding' values
+ // that are the same as the 'identifier', because the bindings are supplied by the wrapping apply function.
+ // Replacements are done by binding rather than identifier because repeats of the same input should not result
+ // in a cartesian product
+ final Map<String, IdentifierExpr> toReplace = new HashMap<>();
+ for (IdentifierExpr applyFnArg : args) {
+ if (!toReplace.containsKey(applyFnArg.getBinding())) {
+ IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getBinding());
+ lambdaArgs.add(lambdaRewrite);
+ toReplace.put(applyFnArg.getBinding(), lambdaRewrite);
+ }
+ }
+
+ lambdaArgs.add(new IdentifierExpr(accumulatorId));
+
+ // rewrite identifiers in the expression which will become the lambda body, so they match the lambda identifiers we
+ // are constructing
+ Expr newExpr = expr.visit(childExpr -> {
+ if (childExpr instanceof IdentifierExpr) {
+ if (toReplace.containsKey(((IdentifierExpr) childExpr).getBinding())) {
+ return toReplace.get(((IdentifierExpr) childExpr).getBinding());
+ }
+ }
+ return childExpr;
+ });
+
+
+ // wrap an expression in either fold or cartesian_fold to apply any unapplied identifiers
+ final LambdaExpr lambdaExpr = new LambdaExpr(lambdaArgs, newExpr);
+ final ApplyFunction fn;
+ if (lambdaArgs.size() == 2) {
+ fn = new ApplyFunction.FoldFunction();
+ } else {
+ fn = new ApplyFunction.CartesianFoldFunction();
+ }
+
+ final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(lambdaArgs));
+ return magic;
+ }
+
+ /**
* Performs partial lifting of free identifiers of the lambda expression of an {@link ApplyFunctionExpr}, constrained
* by a list of "unapplied" identifiers, and translating them into arguments of a new {@link LambdaExpr} and
* {@link ApplyFunctionExpr} as appropriate.
diff --git a/core/src/main/java/org/apache/druid/math/expr/SettableObjectBinding.java b/core/src/main/java/org/apache/druid/math/expr/SettableObjectBinding.java
new file mode 100644
index 0000000..8b414e5
--- /dev/null
+++ b/core/src/main/java/org/apache/druid/math/expr/SettableObjectBinding.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.math.expr;
+
+import com.google.common.collect.Maps;
+
+import javax.annotation.Nullable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Simple map backed object binding
+ */
+public class SettableObjectBinding implements Expr.ObjectBinding
+{
+ private final Map<String, Object> bindings;
+
+ public SettableObjectBinding()
+ {
+ this.bindings = new HashMap<>();
+ }
+
+ public SettableObjectBinding(int expectedSize)
+ {
+ this.bindings = Maps.newHashMapWithExpectedSize(expectedSize);
+ }
+
+ @Nullable
+ @Override
+ public Object get(String name)
+ {
+ return bindings.get(name);
+ }
+
+ public SettableObjectBinding withBinding(String name, @Nullable Object value)
+ {
+ bindings.put(name, value);
+ return this;
+ }
+}
diff --git a/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java b/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java
new file mode 100644
index 0000000..b15f321
--- /dev/null
+++ b/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.math.expr;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ExprEvalTest extends InitializedNullHandlingTest
+{
+ private static int MAX_SIZE_BYTES = 1 << 13;
+
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
+ ByteBuffer buffer = ByteBuffer.allocate(1 << 16);
+
+ @Test
+ public void testStringSerde()
+ {
+ assertExpr(0, "hello");
+ assertExpr(1234, "hello");
+ assertExpr(0, ExprEval.bestEffortOf(null));
+ }
+
+ @Test
+ public void testStringSerdeTooBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.STRING, 16, 10));
+ assertExpr(0, ExprEval.of("hello world"), 10);
+ }
+
+
+ @Test
+ public void testLongSerde()
+ {
+ assertExpr(0, 1L);
+ assertExpr(1234, 1L);
+ assertExpr(1234, ExprEval.ofLong(null));
+ }
+
+ @Test
+ public void testDoubleSerde()
+ {
+ assertExpr(0, 1.123);
+ assertExpr(1234, 1.123);
+ assertExpr(1234, ExprEval.ofDouble(null));
+ }
+
+ @Test
+ public void testStringArraySerde()
+ {
+ assertExpr(0, new String[] {"hello", "hi", "hey"});
+ assertExpr(1024, new String[] {"hello", null, "hi", "hey"});
+ assertExpr(2048, new String[] {});
+ }
+
+ @Test
+ public void testStringArraySerdeToBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.STRING_ARRAY, 14, 10));
+ assertExpr(0, ExprEval.ofStringArray(new String[] {"hello", "hi", "hey"}), 10);
+ }
+
+ @Test
+ public void testLongArraySerde()
+ {
+ assertExpr(0, new Long[] {1L, 2L, 3L});
+ assertExpr(1234, new Long[] {1L, 2L, null, 3L});
+ assertExpr(1234, new Long[] {});
+ }
+
+ @Test
+ public void testLongArraySerdeTooBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.LONG_ARRAY, 29, 10));
+ assertExpr(0, ExprEval.ofLongArray(new Long[] {1L, 2L, 3L}), 10);
+ }
+
+ @Test
+ public void testDoubleArraySerde()
+ {
+ assertExpr(0, new Double[] {1.1, 2.2, 3.3});
+ assertExpr(1234, new Double[] {1.1, 2.2, null, 3.3});
+ assertExpr(1234, new Double[] {});
+ }
+
+ @Test
+ public void testDoubleArraySerdeTooBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.DOUBLE_ARRAY, 29, 10));
+ assertExpr(0, ExprEval.ofDoubleArray(new Double[] {1.1, 2.2, 3.3}), 10);
+ }
+
+ @Test
+ public void test_coerceListToArray()
+ {
+ Assert.assertNull(ExprEval.coerceListToArray(null, false));
+ Assert.assertArrayEquals(new Object[0], (Object[]) ExprEval.coerceListToArray(ImmutableList.of(), false));
+ Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(null, true));
+ Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(ImmutableList.of(), true));
+
+ List<Long> longList = ImmutableList.of(1L, 2L, 3L);
+ Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(longList, false));
+
+ List<Integer> intList = ImmutableList.of(1, 2, 3);
+ Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(intList, false));
+
+ List<Float> floatList = ImmutableList.of(1.0f, 2.0f, 3.0f);
+ Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(floatList, false));
+
+ List<Double> doubleList = ImmutableList.of(1.0, 2.0, 3.0);
+ Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(doubleList, false));
+
+ List<String> stringList = ImmutableList.of("a", "b", "c");
+ Assert.assertArrayEquals(new String[]{"a", "b", "c"}, (String[]) ExprEval.coerceListToArray(stringList, false));
+
+ List<String> withNulls = new ArrayList<>();
+ withNulls.add("a");
+ withNulls.add(null);
+ withNulls.add("c");
+ Assert.assertArrayEquals(new String[]{"a", null, "c"}, (String[]) ExprEval.coerceListToArray(withNulls, false));
+
+ List<Long> withNumberNulls = new ArrayList<>();
+ withNumberNulls.add(1L);
+ withNumberNulls.add(null);
+ withNumberNulls.add(3L);
+
+ Assert.assertArrayEquals(new Long[]{1L, null, 3L}, (Long[]) ExprEval.coerceListToArray(withNumberNulls, false));
+
+ List<Object> withStringMix = ImmutableList.of(1L, "b", 3L);
+ Assert.assertArrayEquals(
+ new String[]{"1", "b", "3"},
+ (String[]) ExprEval.coerceListToArray(withStringMix, false)
+ );
+
+ List<Number> withIntsAndLongs = ImmutableList.of(1, 2L, 3);
+ Assert.assertArrayEquals(
+ new Long[]{1L, 2L, 3L},
+ (Long[]) ExprEval.coerceListToArray(withIntsAndLongs, false)
+ );
+
+ List<Number> withFloatsAndLongs = ImmutableList.of(1, 2L, 3.0f);
+ Assert.assertArrayEquals(
+ new Double[]{1.0, 2.0, 3.0},
+ (Double[]) ExprEval.coerceListToArray(withFloatsAndLongs, false)
+ );
+
+ List<Number> withDoublesAndLongs = ImmutableList.of(1, 2L, 3.0);
+ Assert.assertArrayEquals(
+ new Double[]{1.0, 2.0, 3.0},
+ (Double[]) ExprEval.coerceListToArray(withDoublesAndLongs, false)
+ );
+
+ List<Number> withFloatsAndDoubles = ImmutableList.of(1L, 2.0f, 3.0);
+ Assert.assertArrayEquals(
+ new Double[]{1.0, 2.0, 3.0},
+ (Double[]) ExprEval.coerceListToArray(withFloatsAndDoubles, false)
+ );
+
+ List<String> withAllNulls = new ArrayList<>();
+ withAllNulls.add(null);
+ withAllNulls.add(null);
+ withAllNulls.add(null);
+ Assert.assertArrayEquals(
+ new String[]{null, null, null},
+ (String[]) ExprEval.coerceListToArray(withAllNulls, false)
+ );
+ }
+
+ private void assertExpr(int position, Object expected)
+ {
+ assertExpr(position, ExprEval.bestEffortOf(expected));
+ }
+
+ private void assertExpr(int position, ExprEval expected)
+ {
+ assertExpr(position, expected, MAX_SIZE_BYTES);
+ }
+
+ private void assertExpr(int position, ExprEval expected, int maxSizeBytes)
+ {
+ ExprEval.serialize(buffer, position, expected, maxSizeBytes);
+ if (ExprType.isArray(expected.type())) {
+ Assert.assertArrayEquals(expected.asArray(), ExprEval.deserialize(buffer, position).asArray());
+ } else {
+ Assert.assertEquals(expected.value(), ExprEval.deserialize(buffer, position).value());
+ }
+ }
+}
diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
index bd729fe..1bd423f 100644
--- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
+++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
@@ -291,6 +291,27 @@
}
@Test
+ public void testArraySetAdd()
+ {
+ assertArrayExpr("array_set_add([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
+ assertArrayExpr("array_set_add([1, 2, 3], 'bar')", new Long[]{null, 1L, 2L, 3L});
+ assertArrayExpr("array_set_add([1, 2, 2], 1)", new Long[]{1L, 2L});
+ assertArrayExpr("array_set_add([], 1)", new String[]{"1"});
+ assertArrayExpr("array_set_add(<LONG>[], 1)", new Long[]{1L});
+ assertArrayExpr("array_set_add(<LONG>[], null)", new Long[]{null});
+ }
+
+ @Test
+ public void testArraySetAddAll()
+ {
+ assertArrayExpr("array_set_add_all([1, 2, 3], [2, 4, 6])", new Long[]{1L, 2L, 3L, 4L, 6L});
+ assertArrayExpr("array_set_add_all([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
+ assertArrayExpr("array_set_add_all(0, [1, 2, 3])", new Long[]{0L, 1L, 2L, 3L});
+ assertArrayExpr("array_set_add_all(map(y -> y * 3, b), [1, 2, 3])", new Long[]{1L, 2L, 3L, 6L, 9L, 12L, 15L});
+ assertArrayExpr("array_set_add_all(0, 1)", new Long[]{0L, 1L});
+ }
+
+ @Test
public void testArrayToString()
{
assertExpr("array_to_string([1, 2, 3], ',')", "1,2,3");
diff --git a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java
index 6998a0e..51f991f 100644
--- a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java
+++ b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java
@@ -529,6 +529,48 @@
}
@Test
+ public void testFoldUnapplied()
+ {
+ validateFoldUnapplied("x + __acc", "(+ x __acc)", "(+ x __acc)", ImmutableList.of(), "__acc");
+ validateFoldUnapplied("x + __acc", "(+ x __acc)", "(+ x __acc)", ImmutableList.of("z"), "__acc");
+ validateFoldUnapplied(
+ "x + __acc",
+ "(+ x __acc)",
+ "(fold ([x, __acc] -> (+ x __acc)), [x, __acc])",
+ ImmutableList.of("x"),
+ "__acc"
+ );
+ validateFoldUnapplied(
+ "x + y + __acc",
+ "(+ (+ x y) __acc)",
+ "(cartesian_fold ([x, y, __acc] -> (+ (+ x y) __acc)), [x, y, __acc])",
+ ImmutableList.of("x", "y"),
+ "__acc"
+ );
+ validateFoldUnapplied(
+ "__acc + z + fold((x, acc) -> acc + x + y, x, 0)",
+ "(+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))",
+ "(fold ([z, __acc] -> (+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))), [z, __acc])",
+ ImmutableList.of("z"),
+ "__acc"
+ );
+ validateFoldUnapplied(
+ "__acc + z + fold((x, acc) -> acc + x + y, x, 0)",
+ "(+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))",
+ "(fold ([z, __acc] -> (+ (+ __acc z) (cartesian_fold ([x, y, acc] -> (+ (+ acc x) y)), [x, y, 0]))), [z, __acc])",
+ ImmutableList.of("y", "z"),
+ "__acc"
+ );
+ validateFoldUnapplied(
+ "__acc + fold((x, acc) -> x + y + acc, x, __acc)",
+ "(+ __acc (fold ([x, acc] -> (+ (+ x y) acc)), [x, __acc]))",
+ "(+ __acc (cartesian_fold ([x, y, acc] -> (+ (+ x y) acc)), [x, y, __acc]))",
+ ImmutableList.of("y"),
+ "__acc"
+ );
+ }
+
+ @Test
public void testUniquify()
{
validateParser("x-x", "(- x x)", ImmutableList.of("x"), ImmutableSet.of("x", "x_0"));
@@ -666,6 +708,33 @@
Assert.assertEquals(transformed.stringify(), transformedRoundTrip.stringify());
}
+ private void validateFoldUnapplied(
+ String expression,
+ String unapplied,
+ String applied,
+ List<String> identifiers,
+ String accumulator
+ )
+ {
+ final Expr parsed = Parser.parse(expression, ExprMacroTable.nil());
+ Expr.BindingAnalysis deets = parsed.analyzeInputs();
+ Parser.validateExpr(parsed, deets);
+ final Expr transformed = Parser.foldUnappliedBindings(parsed, deets, identifiers, accumulator);
+ Assert.assertEquals(expression, unapplied, parsed.toString());
+ Assert.assertEquals(applied, applied, transformed.toString());
+
+ final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false);
+ final Expr parsedRoundTrip = Parser.parse(parsedNoFlatten.stringify(), ExprMacroTable.nil());
+ Expr.BindingAnalysis roundTripDeets = parsedRoundTrip.analyzeInputs();
+ Parser.validateExpr(parsedRoundTrip, roundTripDeets);
+ final Expr transformedRoundTrip = Parser.foldUnappliedBindings(parsedRoundTrip, roundTripDeets, identifiers, accumulator);
+ Assert.assertEquals(expression, unapplied, parsedRoundTrip.toString());
+ Assert.assertEquals(applied, applied, transformedRoundTrip.toString());
+
+ Assert.assertEquals(parsed.stringify(), parsedRoundTrip.stringify());
+ Assert.assertEquals(transformed.stringify(), transformedRoundTrip.stringify());
+ }
+
private void validateConstantExpression(String expression, Object expected)
{
Expr parsed = Parser.parse(expression, ExprMacroTable.nil());
diff --git a/core/src/test/java/org/apache/druid/math/expr/VectorExprSanityTest.java b/core/src/test/java/org/apache/druid/math/expr/VectorExprSanityTest.java
index 6d769e0..7b005ee 100644
--- a/core/src/test/java/org/apache/druid/math/expr/VectorExprSanityTest.java
+++ b/core/src/test/java/org/apache/druid/math/expr/VectorExprSanityTest.java
@@ -415,29 +415,6 @@
.toArray(String[][]::new);
}
- static class SettableObjectBinding implements Expr.ObjectBinding
- {
- private final Map<String, Object> bindings;
-
- SettableObjectBinding()
- {
- this.bindings = new HashMap<>();
- }
-
- @Nullable
- @Override
- public Object get(String name)
- {
- return bindings.get(name);
- }
-
- public SettableObjectBinding withBinding(String name, @Nullable Object value)
- {
- bindings.put(name, value);
- return this;
- }
- }
-
static class SettableVectorInputBinding implements Expr.VectorInputBinding
{
private final Map<String, boolean[]> nulls;
diff --git a/docs/misc/math-expr.md b/docs/misc/math-expr.md
index 0fbe40d..174e711 100644
--- a/docs/misc/math-expr.md
+++ b/docs/misc/math-expr.md
@@ -177,14 +177,17 @@
| array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `-1` or `null` if `druid.generic.useDefaultValueForNull=false`if no matching elements exist in the array. |
| array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `-1` or `null` if `druid.generic.useDefaultValueForNull=false` if no matching elements exist in the array. |
| array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array |
-| array_append(arr1,expr) | appends expr to arr, the resulting array type determined by the type of the first array |
+| array_append(arr,expr) | appends expr to arr, the resulting array type determined by the type of the first array |
| array_concat(arr1,arr2) | concatenates 2 arrays, the resulting array type determined by the type of the first array |
+| array_set_add(arr,expr) | adds expr to arr and converts the array to a new array composed of the unique set of elements. The resulting array type is determined by the type of the array |
+| array_set_add_all(arr1,arr2) | combines the unique set of elements of 2 arrays, the resulting array type determined by the type of the first array |
| array_slice(arr,start,end) | return the subarray of arr from the 0 based index start(inclusive) to end(exclusive), or `null`, if start is less than 0, greater than length of arr or less than end|
| array_to_string(arr,str) | joins all elements of arr by the delimiter specified by str |
| string_to_array(str1,str2) | splits str1 into an array on the delimiter specified by str2 |
## Apply functions
+Apply functions allow for special 'lambda' expressions to be defined and applied to array inputs to enable free-form transformations.
| function | description |
| --- | --- |
@@ -197,6 +200,26 @@
| all(lambda,arr) | returns 1 if all elements in the array matches the lambda expression, else 0 |
+### Lambda expressions syntax
+Lambda expressions are a sort of function definition, where new identifiers can be defined and passed as input to the expression body.
+```
+(identifier1 ...) -> expr
+```
+e.g.
+```
+(x, y) -> x + y
+```
+The identifier arguments of a lambda expression correspond to the elements of the array it is being applied to. For example:
+```
+map((x) -> x + 1, some_multi_value_column)
+```
+will map each element of `some_multi_value_column` to the identifier `x` so that the lambda expression body can be evaluated for each `x`. The scoping rules are that lambda arguments will override identifiers which are defined externally from the lambda expression body. Using the same example:
+
+```
+map((x) -> x + 1, x)
+```
+in this case, the `x` when evaluating `x + 1` is the lambda argument, thus an element of the multi-valued column `x`, rather than the column `x` itself.
+
## Reduction functions
Reduction functions operate on zero or more expressions and return a single expression. If no expressions are passed as
diff --git a/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java b/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java
index 795ea5b..155b8e7 100644
--- a/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java
+++ b/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java
@@ -27,6 +27,7 @@
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
+import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMinAggregatorFactory;
@@ -120,7 +121,8 @@
@JsonSubTypes.Type(name = "floatAny", value = FloatAnyAggregatorFactory.class),
@JsonSubTypes.Type(name = "doubleAny", value = DoubleAnyAggregatorFactory.class),
@JsonSubTypes.Type(name = "stringAny", value = StringAnyAggregatorFactory.class),
- @JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class)
+ @JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class),
+ @JsonSubTypes.Type(name = "expression", value = ExpressionLambdaAggregatorFactory.class)
})
public interface AggregatorFactoryMixin
{
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java
index 3c7b8d4..34a6ead 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java
@@ -137,6 +137,9 @@
// GROUPING aggregator
public static final byte GROUPING_CACHE_TYPE_ID = 0x46;
+ // expression lambda aggregator
+ public static final byte EXPRESSION_LAMBDA_CACHE_TYPE_ID = 0x47;
+
/**
* returns the list of dependent postAggregators that should be calculated in order to calculate given postAgg
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java
new file mode 100644
index 0000000..0305c8a
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation;
+
+import org.apache.druid.math.expr.Expr;
+
+import javax.annotation.Nullable;
+
+/**
+ * On-heap {@link Aggregator} that evaluates an {@link Expr} once per {@link #aggregate()} call against a set of
+ * {@link ExpressionLambdaAggregatorInputBindings}, storing the result back into those bindings as the accumulator
+ * that the next evaluation will see.
+ */
+public class ExpressionLambdaAggregator implements Aggregator
+{
+  private final Expr lambda;
+  private final ExpressionLambdaAggregatorInputBindings bindings;
+
+  public ExpressionLambdaAggregator(Expr lambda, ExpressionLambdaAggregatorInputBindings bindings)
+  {
+    this.bindings = bindings;
+    this.lambda = lambda;
+  }
+
+  @Override
+  public void aggregate()
+  {
+    // the bindings expose both the inputs and the current accumulator; the result becomes the new accumulator
+    bindings.accumulate(lambda.eval(bindings));
+  }
+
+  @Override
+  public boolean isNull()
+  {
+    return bindings.getAccumulator().isNumericNull();
+  }
+
+  @Nullable
+  @Override
+  public Object get()
+  {
+    return bindings.getAccumulator().value();
+  }
+
+  @Override
+  public long getLong()
+  {
+    return bindings.getAccumulator().asLong();
+  }
+
+  @Override
+  public double getDouble()
+  {
+    return bindings.getAccumulator().asDouble();
+  }
+
+  @Override
+  public float getFloat()
+  {
+    // narrowed from the double representation of the accumulator
+    return (float) bindings.getAccumulator().asDouble();
+  }
+
+  @Override
+  public void close()
+  {
+    // nothing to close
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
new file mode 100644
index 0000000..2da1abd
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
@@ -0,0 +1,536 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation;
+
+import com.fasterxml.jackson.annotation.JacksonInject;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Supplier;
+import com.google.common.base.Suppliers;
+import com.google.common.collect.Iterables;
+import org.apache.druid.java.util.common.HumanReadableBytes;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.guava.Comparators;
+import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.ExprMacroTable;
+import org.apache.druid.math.expr.ExprType;
+import org.apache.druid.math.expr.Parser;
+import org.apache.druid.math.expr.SettableObjectBinding;
+import org.apache.druid.query.cache.CacheKeyBuilder;
+import org.apache.druid.query.expression.ExprUtils;
+import org.apache.druid.segment.ColumnInspector;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.virtual.ExpressionPlan;
+import org.apache.druid.segment.virtual.ExpressionPlanner;
+import org.apache.druid.segment.virtual.ExpressionSelectors;
+
+import javax.annotation.Nullable;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * {@link AggregatorFactory} whose aggregation logic is defined by expressions: an initial value, a 'fold'
+ * expression which accumulates input rows, a 'combine' expression which merges intermediate results, and optional
+ * 'compare' and 'finalize' expressions. When {@link #getFields()} is null the factory acts as the combining/merging
+ * form and evaluates the combine expression instead of the fold expression.
+ */
+public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
+{
+  private static final String FINALIZE_IDENTIFIER = "o";
+  private static final String COMPARE_O1 = "o1";
+  private static final String COMPARE_O2 = "o2";
+  private static final String DEFAULT_ACCUMULATOR_ID = "__acc";
+
+  // minimum permitted agg size is 10 bytes so it is at least large enough to hold primitive numerics (long, double)
+  // | expression type byte | is_null byte | primitive value (8 bytes) |
+  private static final int MIN_SIZE_BYTES = 10;
+  private static final HumanReadableBytes DEFAULT_MAX_SIZE_BYTES = new HumanReadableBytes(1L << 10);
+
+  private final String name;
+  @Nullable
+  private final Set<String> fields;
+  private final String accumulatorId;
+  private final String foldExpressionString;
+  private final String initialValueExpressionString;
+  private final String initialCombineValueExpressionString;
+
+  private final String combineExpressionString;
+  @Nullable
+  private final String compareExpressionString;
+  @Nullable
+  private final String finalizeExpressionString;
+
+  private final ExprMacroTable macroTable;
+  private final Supplier<ExprEval<?>> initialValue;
+  private final Supplier<ExprEval<?>> initialCombineValue;
+  private final Supplier<Expr> foldExpression;
+  private final Supplier<Expr> combineExpression;
+  private final Supplier<Expr> compareExpression;
+  private final Supplier<Expr> finalizeExpression;
+  private final HumanReadableBytes maxSizeBytes;
+
+  private final Supplier<SettableObjectBinding> compareBindings =
+      Suppliers.memoize(() -> new SettableObjectBinding(2));
+  private final Supplier<SettableObjectBinding> combineBindings =
+      Suppliers.memoize(() -> new SettableObjectBinding(2));
+  private final Supplier<SettableObjectBinding> finalizeBindings =
+      Suppliers.memoize(() -> new SettableObjectBinding(1));
+
+  @JsonCreator
+  public ExpressionLambdaAggregatorFactory(
+      @JsonProperty("name") String name,
+      @JsonProperty("fields") @Nullable final Set<String> fields,
+      @JsonProperty("accumulatorIdentifier") @Nullable final String accumulatorIdentifier,
+      @JsonProperty("initialValue") final String initialValue,
+      @JsonProperty("initialCombineValue") @Nullable final String initialCombineValue,
+      @JsonProperty("fold") final String foldExpression,
+      @JsonProperty("combine") @Nullable final String combineExpression,
+      @JsonProperty("compare") @Nullable final String compareExpression,
+      @JsonProperty("finalize") @Nullable final String finalizeExpression,
+      @JsonProperty("maxSizeBytes") @Nullable final HumanReadableBytes maxSizeBytes,
+      @JacksonInject ExprMacroTable macroTable
+  )
+  {
+    Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
+
+    this.name = name;
+    this.fields = fields;
+    this.accumulatorId = accumulatorIdentifier != null ? accumulatorIdentifier : DEFAULT_ACCUMULATOR_ID;
+
+    this.initialValueExpressionString = initialValue;
+    this.initialCombineValueExpressionString = initialCombineValue == null ? initialValue : initialCombineValue;
+    this.foldExpressionString = foldExpression;
+    if (combineExpression != null) {
+      this.combineExpressionString = combineExpression;
+    } else {
+      // if the combine expression is null, allow single input aggregator expressions to be rewritten to replace the
+      // field with the aggregator name. Fields is null for the combining/merging aggregator, but the expression should
+      // already be set with the rewritten value at that point
+      Preconditions.checkArgument(
+          fields != null && fields.size() == 1,
+          "Must have a single input field if no combine expression is supplied"
+      );
+      this.combineExpressionString = StringUtils.replace(foldExpression, Iterables.getOnlyElement(fields), name);
+    }
+    this.compareExpressionString = compareExpression;
+    this.finalizeExpressionString = finalizeExpression;
+    this.macroTable = macroTable;
+
+    this.initialValue = Suppliers.memoize(() -> {
+      Expr parsed = Parser.parse(initialValue, macroTable);
+      Preconditions.checkArgument(parsed.isLiteral(), "initial value must be constant");
+      return parsed.eval(ExprUtils.nilBindings());
+    });
+    this.initialCombineValue = Suppliers.memoize(() -> {
+      Expr parsed = Parser.parse(this.initialCombineValueExpressionString, macroTable);
+      Preconditions.checkArgument(parsed.isLiteral(), "initial combining value must be constant");
+      return parsed.eval(ExprUtils.nilBindings());
+    });
+    this.foldExpression = Parser.lazyParse(foldExpressionString, macroTable);
+    this.combineExpression = Parser.lazyParse(combineExpressionString, macroTable);
+    this.compareExpression = Parser.lazyParse(compareExpressionString, macroTable);
+    this.finalizeExpression = Parser.lazyParse(finalizeExpressionString, macroTable);
+    this.maxSizeBytes = maxSizeBytes != null ? maxSizeBytes : DEFAULT_MAX_SIZE_BYTES;
+    Preconditions.checkArgument(
+        this.maxSizeBytes.getBytesInInt() >= MIN_SIZE_BYTES,
+        "maxSizeBytes must be at least %s bytes, but was %s",
+        MIN_SIZE_BYTES,
+        this.maxSizeBytes.getBytesInInt()
+    );
+  }
+
+  @JsonProperty
+  @Override
+  public String getName()
+  {
+    return name;
+  }
+
+  @JsonProperty
+  @Nullable
+  public Set<String> getFields()
+  {
+    return fields;
+  }
+
+  @JsonProperty
+  public String getAccumulatorIdentifier()
+  {
+    return accumulatorId;
+  }
+
+  @JsonProperty("initialValue")
+  public String getInitialValueExpressionString()
+  {
+    return initialValueExpressionString;
+  }
+
+  @JsonProperty("initialCombineValue")
+  public String getInitialCombineValueExpressionString()
+  {
+    return initialCombineValueExpressionString;
+  }
+
+  @JsonProperty("fold")
+  public String getFoldExpressionString()
+  {
+    return foldExpressionString;
+  }
+
+  @JsonProperty("combine")
+  public String getCombineExpressionString()
+  {
+    return combineExpressionString;
+  }
+
+  @JsonProperty("compare")
+  @Nullable
+  public String getCompareExpressionString()
+  {
+    return compareExpressionString;
+  }
+
+  @JsonProperty("finalize")
+  @Nullable
+  public String getFinalizeExpressionString()
+  {
+    return finalizeExpressionString;
+  }
+
+  @JsonProperty("maxSizeBytes")
+  public HumanReadableBytes getMaxSizeBytes()
+  {
+    return maxSizeBytes;
+  }
+
+  @Override
+  public byte[] getCacheKey()
+  {
+    return new CacheKeyBuilder(AggregatorUtil.EXPRESSION_LAMBDA_CACHE_TYPE_ID)
+        .appendStrings(fields)
+        // the same expression strings mean something different under a different accumulator identifier, so it
+        // has to be part of the key
+        .appendString(accumulatorId)
+        .appendString(initialValueExpressionString)
+        .appendString(initialCombineValueExpressionString)
+        .appendString(foldExpressionString)
+        .appendString(combineExpressionString)
+        .appendString(compareExpressionString)
+        .appendString(finalizeExpressionString)
+        .appendInt(maxSizeBytes.getBytesInInt())
+        .build();
+  }
+
+  @Override
+  public Aggregator factorize(ColumnSelectorFactory metricFactory)
+  {
+    FactorizePlan thePlan = new FactorizePlan(metricFactory);
+    return new ExpressionLambdaAggregator(
+        thePlan.getExpression(),
+        thePlan.getBindings()
+    );
+  }
+
+  @Override
+  public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
+  {
+    FactorizePlan thePlan = new FactorizePlan(metricFactory);
+    return new ExpressionLambdaBufferAggregator(
+        thePlan.getExpression(),
+        thePlan.getInitialValue(),
+        thePlan.getBindings(),
+        maxSizeBytes.getBytesInInt()
+    );
+  }
+
+  @Override
+  public Comparator getComparator()
+  {
+    Expr compareExpr = compareExpression.get();
+    if (compareExpr != null) {
+      return (o1, o2) ->
+          compareExpr.eval(compareBindings.get().withBinding(COMPARE_O1, o1).withBinding(COMPARE_O2, o2)).asInt();
+    }
+    // no compare expression was supplied: pick a comparator based on the type of the initial value
+    switch (initialValue.get().type()) {
+      case LONG:
+        return LongSumAggregator.COMPARATOR;
+      case DOUBLE:
+        return DoubleSumAggregator.COMPARATOR;
+      default:
+        return Comparators.naturalNullsFirst();
+    }
+  }
+
+  @Nullable
+  @Override
+  public Object combine(@Nullable Object lhs, @Nullable Object rhs)
+  {
+    // arbitrarily assign lhs and rhs to accumulator and aggregator name inputs to re-use combine function
+    return combineExpression.get().eval(
+        combineBindings.get().withBinding(accumulatorId, lhs).withBinding(name, rhs)
+    ).value();
+  }
+
+  @Override
+  public Object deserialize(Object object)
+  {
+    return object;
+  }
+
+  @Nullable
+  @Override
+  public Object finalizeComputation(@Nullable Object object)
+  {
+    Expr finalizeExpr = finalizeExpression.get();
+    if (finalizeExpr != null) {
+      return finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, object)).value();
+    }
+    return object;
+  }
+
+  @Override
+  public List<String> requiredFields()
+  {
+    if (fields == null) {
+      // combining form: the inputs are whatever the combine expression references
+      return combineExpression.get().analyzeInputs().getRequiredBindingsList();
+    }
+    return foldExpression.get().analyzeInputs().getRequiredBindingsList();
+  }
+
+  @Override
+  public AggregatorFactory getCombiningFactory()
+  {
+    return new ExpressionLambdaAggregatorFactory(
+        name,
+        null,
+        accumulatorId,
+        initialValueExpressionString,
+        initialCombineValueExpressionString,
+        foldExpressionString,
+        combineExpressionString,
+        compareExpressionString,
+        finalizeExpressionString,
+        maxSizeBytes,
+        macroTable
+    );
+  }
+
+  @Override
+  public List<AggregatorFactory> getRequiredColumns()
+  {
+    return Collections.singletonList(
+        new ExpressionLambdaAggregatorFactory(
+            name,
+            fields,
+            accumulatorId,
+            initialValueExpressionString,
+            initialCombineValueExpressionString,
+            foldExpressionString,
+            combineExpressionString,
+            compareExpressionString,
+            finalizeExpressionString,
+            maxSizeBytes,
+            macroTable
+        )
+    );
+  }
+
+  @Override
+  public ValueType getType()
+  {
+    if (fields == null) {
+      // combining form: the accumulator starts from the initial combine value, so report that type
+      return ExprType.toValueType(initialCombineValue.get().type());
+    }
+    return ExprType.toValueType(initialValue.get().type());
+  }
+
+  @Override
+  public ValueType getFinalizedType()
+  {
+    Expr finalizeExpr = finalizeExpression.get();
+    ExprEval<?> initialVal = initialCombineValue.get();
+    if (finalizeExpr != null) {
+      // evaluate the finalize expression against the initial combine value to learn the type it produces
+      return ExprType.toValueType(
+          finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, initialVal)).type()
+      );
+    }
+    return ExprType.toValueType(initialVal.type());
+  }
+
+  @Override
+  public int getMaxIntermediateSize()
+  {
+    // numeric expressions are either longs or doubles, with strings or arrays max size is unknown
+    // for numeric arguments, the first 2 bytes are used for expression type byte and is_null byte
+    return getType().isNumeric() ? 2 + Long.BYTES : maxSizeBytes.getBytesInInt();
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    ExpressionLambdaAggregatorFactory that = (ExpressionLambdaAggregatorFactory) o;
+    return maxSizeBytes.equals(that.maxSizeBytes)
+           && name.equals(that.name)
+           && Objects.equals(fields, that.fields)
+           && accumulatorId.equals(that.accumulatorId)
+           && foldExpressionString.equals(that.foldExpressionString)
+           && initialValueExpressionString.equals(that.initialValueExpressionString)
+           && initialCombineValueExpressionString.equals(that.initialCombineValueExpressionString)
+           && combineExpressionString.equals(that.combineExpressionString)
+           && Objects.equals(compareExpressionString, that.compareExpressionString)
+           && Objects.equals(finalizeExpressionString, that.finalizeExpressionString);
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(
+        name,
+        fields,
+        accumulatorId,
+        foldExpressionString,
+        initialValueExpressionString,
+        initialCombineValueExpressionString,
+        combineExpressionString,
+        compareExpressionString,
+        finalizeExpressionString,
+        maxSizeBytes
+    );
+  }
+
+  @Override
+  public String toString()
+  {
+    return "ExpressionLambdaAggregatorFactory{" +
+           "name='" + name + '\'' +
+           ", fields=" + fields +
+           ", accumulatorId='" + accumulatorId + '\'' +
+           ", foldExpressionString='" + foldExpressionString + '\'' +
+           ", initialValueExpressionString='" + initialValueExpressionString + '\'' +
+           ", initialCombineValueExpressionString='" + initialCombineValueExpressionString + '\'' +
+           ", combineExpressionString='" + combineExpressionString + '\'' +
+           ", compareExpressionString='" + compareExpressionString + '\'' +
+           ", finalizeExpressionString='" + finalizeExpressionString + '\'' +
+           ", maxSizeBytes=" + maxSizeBytes +
+           '}';
+  }
+
+  /**
+   * Determine how to factorize the aggregator
+   */
+  private class FactorizePlan
+  {
+    private final ExpressionPlan plan;
+
+    private final ExprEval<?> seed;
+    private final ExpressionLambdaAggregatorInputBindings bindings;
+
+    FactorizePlan(ColumnSelectorFactory metricFactory)
+    {
+      final List<String> columns;
+
+      if (fields != null) {
+        // if fields are set, we are accumulating from raw inputs, use fold expression
+        plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory), foldExpression.get());
+        seed = initialValue.get();
+        columns = plan.getAnalysis().getRequiredBindingsList();
+      } else {
+        // else we are merging intermediary results, use combine expression
+        plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory), combineExpression.get());
+        seed = initialCombineValue.get();
+        columns = plan.getAnalysis().getRequiredBindingsList();
+      }
+
+      bindings = new ExpressionLambdaAggregatorInputBindings(
+          ExpressionSelectors.createBindings(metricFactory, columns),
+          accumulatorId,
+          seed
+      );
+    }
+
+    public Expr getExpression()
+    {
+      if (fields == null) {
+        return plan.getExpression();
+      }
+      // for fold expressions, check to see if it needs transformation due to scalar use of multi-valued or unknown
+      // inputs
+      return plan.getAppliedFoldExpression(accumulatorId);
+    }
+
+    public ExprEval<?> getInitialValue()
+    {
+      return seed;
+    }
+
+    public ExpressionLambdaAggregatorInputBindings getBindings()
+    {
+      return bindings;
+    }
+
+    /**
+     * Wraps {@code inspector} so that the accumulator identifier is reported with the type of the initial value,
+     * while every other column is delegated to the wrapped inspector.
+     */
+    private ColumnInspector inspectorWithAccumulator(ColumnInspector inspector)
+    {
+      return new ColumnInspector()
+      {
+        @Nullable
+        @Override
+        public ColumnCapabilities getColumnCapabilities(String column)
+        {
+          if (accumulatorId.equals(column)) {
+            return ColumnCapabilitiesImpl.createDefault().setType(ExprType.toValueType(initialValue.get().type()));
+          }
+          return inspector.getColumnCapabilities(column);
+        }
+
+        @Nullable
+        @Override
+        public ExprType getType(String name)
+        {
+          if (accumulatorId.equals(name)) {
+            return initialValue.get().type();
+          }
+          return inspector.getType(name);
+        }
+      };
+    }
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorInputBindings.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorInputBindings.java
new file mode 100644
index 0000000..5e4864e
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorInputBindings.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation;
+
+import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExprEval;
+
+import javax.annotation.Nullable;
+
+/**
+ * Special {@link Expr.ObjectBinding} for use with {@link ExpressionLambdaAggregatorFactory}.
+ * In addition to the 'normal' bindings backed by the underlying selector inputs, it holds the value of a special
+ * 'accumulator' variable, so the result of one expression evaluation can be fed forward and made available to the
+ * next evaluation under the accumulator identifier.
+ */
+public class ExpressionLambdaAggregatorInputBindings implements Expr.ObjectBinding
+{
+  private final String accumulatorIdentifier;
+  private final Expr.ObjectBinding inputBindings;
+  private ExprEval<?> accumulator;
+
+  public ExpressionLambdaAggregatorInputBindings(
+      Expr.ObjectBinding inputBindings,
+      String accumulatorIdentifier,
+      ExprEval<?> initialValue
+  )
+  {
+    this.inputBindings = inputBindings;
+    this.accumulatorIdentifier = accumulatorIdentifier;
+    this.accumulator = initialValue;
+  }
+
+  /**
+   * Resolves the accumulator identifier to the current accumulator value; every other name is delegated to the
+   * wrapped input bindings.
+   */
+  @Nullable
+  @Override
+  public Object get(String name)
+  {
+    return accumulatorIdentifier.equals(name) ? accumulator.value() : inputBindings.get(name);
+  }
+
+  /**
+   * Stores the result of an evaluation as the accumulator for the next one. Equivalent to
+   * {@link #setAccumulator(ExprEval)}.
+   */
+  public void accumulate(ExprEval<?> eval)
+  {
+    this.accumulator = eval;
+  }
+
+  public ExprEval<?> getAccumulator()
+  {
+    return accumulator;
+  }
+
+  /**
+   * Replaces the current accumulator, e.g. with a value that was deserialized from an aggregation buffer.
+   */
+  public void setAccumulator(ExprEval<?> acc)
+  {
+    this.accumulator = acc;
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java
new file mode 100644
index 0000000..357dd4b
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation;
+
+import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExprEval;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * {@link BufferAggregator} counterpart of {@link ExpressionLambdaAggregator}: the accumulator is kept in the
+ * aggregation buffer, written with {@code ExprEval.serialize} and read back with {@code ExprEval.deserialize}
+ * every time it is needed.
+ */
+public class ExpressionLambdaBufferAggregator implements BufferAggregator
+{
+  private final Expr lambda;
+  private final ExprEval<?> initialValue;
+  private final ExpressionLambdaAggregatorInputBindings bindings;
+  private final int maxSizeBytes;
+
+  public ExpressionLambdaBufferAggregator(
+      Expr lambda,
+      ExprEval<?> initialValue,
+      ExpressionLambdaAggregatorInputBindings bindings,
+      int maxSizeBytes
+  )
+  {
+    this.lambda = lambda;
+    this.bindings = bindings;
+    this.initialValue = initialValue;
+    this.maxSizeBytes = maxSizeBytes;
+  }
+
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    // seed the slot with the initial accumulator value
+    ExprEval.serialize(buf, position, initialValue, maxSizeBytes);
+  }
+
+  @Override
+  public void aggregate(ByteBuffer buf, int position)
+  {
+    // load the stored accumulator into the bindings so the expression can see it, then overwrite the slot with the
+    // newly computed value
+    bindings.setAccumulator(readAccumulator(buf, position));
+    ExprEval.serialize(buf, position, lambda.eval(bindings), maxSizeBytes);
+  }
+
+  @Nullable
+  @Override
+  public Object get(ByteBuffer buf, int position)
+  {
+    return readAccumulator(buf, position).value();
+  }
+
+  @Override
+  public long getLong(ByteBuffer buf, int position)
+  {
+    return readAccumulator(buf, position).asLong();
+  }
+
+  @Override
+  public double getDouble(ByteBuffer buf, int position)
+  {
+    return readAccumulator(buf, position).asDouble();
+  }
+
+  @Override
+  public float getFloat(ByteBuffer buf, int position)
+  {
+    return (float) readAccumulator(buf, position).asDouble();
+  }
+
+  @Override
+  public void close()
+  {
+    // nothing to close
+  }
+
+  /**
+   * Reads back the accumulator previously written at {@code position}.
+   */
+  private static ExprEval<?> readAccumulator(ByteBuffer buf, int position)
+  {
+    return ExprEval.deserialize(buf, position);
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java
index b864432..c540018 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java
@@ -23,7 +23,6 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
-import com.google.common.base.Suppliers;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Parser;
@@ -72,7 +71,7 @@
this.fieldName = fieldName;
this.expression = expression;
this.storeDoubleAsFloat = ColumnHolder.storeDoubleAsFloat();
- this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable));
+ this.fieldExpression = Parser.lazyParse(expression, macroTable);
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
Preconditions.checkArgument(
fieldName == null ^ expression == null,
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java
index 380ceb1..03b9f92 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java
@@ -23,7 +23,6 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
-import com.google.common.base.Suppliers;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Parser;
@@ -63,7 +62,7 @@
this.name = name;
this.fieldName = fieldName;
this.expression = expression;
- this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable));
+ this.fieldExpression = Parser.lazyParse(expression, macroTable);
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
Preconditions.checkArgument(
fieldName == null ^ expression == null,
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java
index 7d148d5..bf297e1 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java
@@ -23,7 +23,6 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
-import com.google.common.base.Suppliers;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Parser;
@@ -69,7 +68,7 @@
this.name = name;
this.fieldName = fieldName;
this.expression = expression;
- this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable));
+ this.fieldExpression = Parser.lazyParse(expression, macroTable);
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
Preconditions.checkArgument(
fieldName == null ^ expression == null,
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java
index 978bf3e..34cbb16 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java
@@ -92,7 +92,7 @@
ordering,
macroTable,
ImmutableMap.of(),
- Suppliers.memoize(() -> Parser.parse(expression, macroTable))
+ Parser.lazyParse(expression, macroTable)
);
}
diff --git a/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java b/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java
index 2f65cc6..6692733 100644
--- a/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java
+++ b/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java
@@ -25,7 +25,6 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Supplier;
-import com.google.common.base.Suppliers;
import com.google.common.collect.RangeSet;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
@@ -53,7 +52,7 @@
{
this.expression = expression;
this.filterTuning = filterTuning;
- this.parsed = Suppliers.memoize(() -> Parser.parse(expression, macroTable));
+ this.parsed = Parser.lazyParse(expression, macroTable);
}
@VisibleForTesting
diff --git a/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java b/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java
index 2ace9b0..caa8daf 100644
--- a/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java
+++ b/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java
@@ -26,6 +26,7 @@
import com.google.common.base.Suppliers;
import org.apache.druid.data.input.Row;
import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.segment.column.ColumnHolder;
@@ -106,7 +107,7 @@
} else {
Object raw = row.getRaw(column);
if (raw instanceof List) {
- return ExpressionSelectors.coerceListToArray((List) raw);
+ return ExprEval.coerceListToArray((List) raw, true);
}
return raw;
}
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlan.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlan.java
index 38a3fc3..b1ab3a9 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlan.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlan.java
@@ -122,6 +122,14 @@
return expression;
}
+  public Expr getAppliedFoldExpression(String accumulatorId)
+  {
+    // the expression is only rewritten when the plan carries the NEEDS_APPLIED trait, else it is returned as is
+    return is(Trait.NEEDS_APPLIED)
+           ? Parser.foldUnappliedBindings(expression, analysis, unappliedInputs, accumulatorId)
+           : expression;
+  }
+
public Expr.BindingAnalysis getAnalysis()
{
return analysis;
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java
index 0ff00b6..d910d7b 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java
@@ -24,7 +24,6 @@
import com.google.common.base.Supplier;
import com.google.common.collect.Iterables;
import org.apache.druid.common.config.NullHandling;
-import org.apache.druid.java.util.common.UOE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.Parser;
@@ -242,13 +241,26 @@
* provides the set of identifiers which need a binding (list of required columns), and context of whether or not they
* are used as array or scalar inputs
*/
- private static Expr.ObjectBinding createBindings(
+ public static Expr.ObjectBinding createBindings(
Expr.BindingAnalysis bindingAnalysis,
ColumnSelectorFactory columnSelectorFactory
)
{
- final Map<String, Supplier<Object>> suppliers = new HashMap<>();
final List<String> columns = bindingAnalysis.getRequiredBindingsList();
+ return createBindings(columnSelectorFactory, columns);
+ }
+
+ /**
+ * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and the list of column names which need a
+ * binding. A value supplier is created for each listed column, chosen according to the column capabilities reported
+ * by the {@link ColumnSelectorFactory}
+ */
+ public static Expr.ObjectBinding createBindings(
+ ColumnSelectorFactory columnSelectorFactory,
+ List<String> columns
+ )
+ {
+ final Map<String, Supplier<Object>> suppliers = new HashMap<>();
for (String columnName : columns) {
final ColumnCapabilities columnCapabilities = columnSelectorFactory.getColumnCapabilities(columnName);
final ValueType nativeType = columnCapabilities != null ? columnCapabilities.getType() : null;
@@ -269,8 +281,8 @@
columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(columnName, columnName)),
multiVal
);
- } else if (nativeType == null) {
- // Unknown ValueType. Try making an Object selector and see if that gives us anything useful.
+ } else if (nativeType == null || ValueType.isArray(nativeType)) {
+ // Unknown ValueType or array type. Try making an Object selector and see if that gives us anything useful.
supplier = supplierFromObjectSelector(columnSelectorFactory.makeColumnValueSelector(columnName));
} else {
// Unhandleable ValueType (COMPLEX).
@@ -370,10 +382,10 @@
// Might be Numbers and Strings. Use a selector that double-checks.
return () -> {
final Object val = selector.getObject();
- if (val instanceof Number || val instanceof String) {
+ if (val instanceof Number || val instanceof String || (val != null && val.getClass().isArray())) {
return val;
} else if (val instanceof List) {
- return coerceListToArray((List) val);
+ return ExprEval.coerceListToArray((List) val, true);
} else {
return null;
}
@@ -382,7 +394,7 @@
return () -> {
final Object val = selector.getObject();
if (val != null) {
- return coerceListToArray((List) val);
+ return ExprEval.coerceListToArray((List) val, true);
}
return null;
};
@@ -393,70 +405,6 @@
}
/**
- * Selectors are not consistent in treatment of null, [], and [null], so coerce [] to [null]
- */
- public static Object coerceListToArray(@Nullable List<?> val)
- {
- if (val != null && val.size() > 0) {
- Class coercedType = null;
-
- for (Object elem : val) {
- if (elem != null) {
- coercedType = convertType(coercedType, elem.getClass());
- }
- }
-
- if (coercedType == Long.class || coercedType == Integer.class) {
- return val.stream().map(x -> x != null ? ((Number) x).longValue() : null).toArray(Long[]::new);
- }
- if (coercedType == Float.class || coercedType == Double.class) {
- return val.stream().map(x -> x != null ? ((Number) x).doubleValue() : null).toArray(Double[]::new);
- }
- // default to string
- return val.stream().map(x -> x != null ? x.toString() : null).toArray(String[]::new);
- }
- return new String[]{null};
- }
-
- private static Class convertType(@Nullable Class existing, Class next)
- {
- if (Number.class.isAssignableFrom(next) || next == String.class) {
- if (existing == null) {
- return next;
- }
- // string wins everything
- if (existing == String.class) {
- return existing;
- }
- if (next == String.class) {
- return next;
- }
- // all numbers win over Integer
- if (existing == Integer.class) {
- return next;
- }
- if (existing == Float.class) {
- // doubles win over floats
- if (next == Double.class) {
- return next;
- }
- return existing;
- }
- if (existing == Long.class) {
- if (next == Integer.class) {
- // long beats int
- return existing;
- }
- // double and float win over longs
- return next;
- }
- // otherwise double
- return Double.class;
- }
- throw new UOE("Invalid array expression type: %s", next);
- }
-
- /**
* Coerces {@link ExprEval} value back to selector friendly {@link List} if the evaluated expression result is an
* array type
*/
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java
index 4a7635c..343cd8a 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java
@@ -72,7 +72,7 @@
this.name = Preconditions.checkNotNull(name, "name");
this.expression = Preconditions.checkNotNull(expression, "expression");
this.outputType = outputType;
- this.parsedExpression = Suppliers.memoize(() -> Parser.parse(expression, macroTable));
+ this.parsedExpression = Parser.lazyParse(expression, macroTable);
}
/**
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
new file mode 100644
index 0000000..5b5c296
--- /dev/null
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
@@ -0,0 +1,570 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.java.util.common.HumanReadableBytes;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.query.Druids;
+import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
+import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
+import org.apache.druid.query.expression.TestExprMacroTable;
+import org.apache.druid.query.timeseries.TimeseriesQuery;
+import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
+import org.apache.druid.segment.TestHelper;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.io.IOException;
+
+public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandlingTest
+{
+ private static ObjectMapper MAPPER = TestHelper.makeJsonMapper();
+
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
+ @Test
+ public void testSerde() throws IOException
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ "customAccumulator",
+ "0.0",
+ "10.0",
+ "customAccumulator + some_column + some_other_column",
+ "customAccumulator + expr_agg_name",
+ "if (o1 > o2, if (o1 == o2, 0, 1), -1)",
+ "o + 100",
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(agg, MAPPER.readValue(MAPPER.writeValueAsBytes(agg), ExpressionLambdaAggregatorFactory.class));
+ }
+
+ @Test
+ public void testEqualsAndHashCode()
+ {
+ EqualsVerifier.forClass(ExpressionLambdaAggregatorFactory.class)
+ .usingGetClass()
+ .withIgnoredFields(
+ "macroTable",
+ "initialValue",
+ "initialCombineValue",
+ "foldExpression",
+ "combineExpression",
+ "compareExpression",
+ "finalizeExpression",
+ "compareBindings",
+ "combineBindings",
+ "finalizeBindings"
+ )
+ .verify();
+ }
+
+ @Test
+ public void testInitialValueMustBeConstant()
+ {
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("initial value must be constant");
+
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "x + y",
+ null,
+ "__acc + some_column + some_other_column",
+ "__acc + expr_agg_name",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ agg.getType();
+ }
+
+ @Test
+ public void testInitialCombineValueMustBeConstant()
+ {
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("initial combining value must be constant");
+
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0.0",
+ "x + y",
+ "__acc + some_column + some_other_column",
+ "__acc + expr_agg_name",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ agg.getFinalizedType();
+ }
+
+ @Test
+ public void testSingleInputCombineExpressionIsOptional()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("x"),
+ null,
+ "0",
+ null,
+ "__acc + x",
+ null,
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(1L, agg.combine(0L, 1L));
+ }
+
+ @Test
+ public void testFinalizeCanDo()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("x"),
+ null,
+ "0",
+ null,
+ "__acc + x",
+ null,
+ null,
+ "o + 100",
+ null,
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(100L, agg.finalizeComputation(0L));
+ }
+
+ @Test
+ public void testFinalizeCanDoArrays()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("x"),
+ null,
+ "0",
+ null,
+ "array_set_add(__acc, x)",
+ "array_set_add_all(__acc, expr_agg_name)",
+ null,
+ "array_to_string(o, ',')",
+ null,
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals("a,b,c", agg.finalizeComputation(new String[]{"a", "b", "c"}));
+ Assert.assertEquals("a,b,c", agg.finalizeComputation(ImmutableList.of("a", "b", "c")));
+ }
+
+ @Test
+ public void testStringType()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "''",
+ "''",
+ "concat(__acc, some_column, some_other_column)",
+ "concat(__acc, expr_agg_name)",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(ValueType.STRING, agg.getType());
+ Assert.assertEquals(ValueType.STRING, agg.getCombiningFactory().getType());
+ Assert.assertEquals(ValueType.STRING, agg.getFinalizedType());
+ }
+
+ @Test
+ public void testLongType()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0",
+ null,
+ "__acc + some_column + some_other_column",
+ "__acc + expr_agg_name",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(ValueType.LONG, agg.getType());
+ Assert.assertEquals(ValueType.LONG, agg.getCombiningFactory().getType());
+ Assert.assertEquals(ValueType.LONG, agg.getFinalizedType());
+ }
+
+ @Test
+ public void testDoubleType()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0.0",
+ null,
+ "__acc + some_column + some_other_column",
+ "__acc + expr_agg_name",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(ValueType.DOUBLE, agg.getType());
+ Assert.assertEquals(ValueType.DOUBLE, agg.getCombiningFactory().getType());
+ Assert.assertEquals(ValueType.DOUBLE, agg.getFinalizedType());
+ }
+
+ @Test
+ public void testStringArrayType()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "''",
+ "<STRING>[]",
+ "concat(__acc, some_column, some_other_column)",
+ "array_set_add(__acc, expr_agg_name)",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(ValueType.STRING, agg.getType());
+ Assert.assertEquals(ValueType.STRING_ARRAY, agg.getCombiningFactory().getType());
+ Assert.assertEquals(ValueType.STRING_ARRAY, agg.getFinalizedType());
+ }
+
+ @Test
+ public void testStringArrayTypeFinalized()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "''",
+ "<STRING>[]",
+ "concat(__acc, some_column, some_other_column)",
+ "array_set_add(__acc, expr_agg_name)",
+ null,
+ "array_to_string(o, ';')",
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(ValueType.STRING, agg.getType());
+ Assert.assertEquals(ValueType.STRING_ARRAY, agg.getCombiningFactory().getType());
+ Assert.assertEquals(ValueType.STRING, agg.getFinalizedType());
+ }
+
+ @Test
+ public void testLongArrayType()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0",
+ "<LONG>[]",
+ "__acc + some_column + some_other_column",
+ "array_set_add(__acc, expr_agg_name)",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(ValueType.LONG, agg.getType());
+ Assert.assertEquals(ValueType.LONG_ARRAY, agg.getCombiningFactory().getType());
+ Assert.assertEquals(ValueType.LONG_ARRAY, agg.getFinalizedType());
+ }
+
+ @Test
+ public void testLongArrayTypeFinalized()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0",
+ "<LONG>[]",
+ "__acc + some_column + some_other_column",
+ "array_set_add(__acc, expr_agg_name)",
+ null,
+ "array_to_string(o, ';')",
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(ValueType.LONG, agg.getType());
+ Assert.assertEquals(ValueType.LONG_ARRAY, agg.getCombiningFactory().getType());
+ Assert.assertEquals(ValueType.STRING, agg.getFinalizedType());
+ }
+
+ @Test
+ public void testDoubleArrayType()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0.0",
+ "<DOUBLE>[]",
+ "__acc + some_column + some_other_column",
+ "array_set_add(__acc, expr_agg_name)",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(ValueType.DOUBLE, agg.getType());
+ Assert.assertEquals(ValueType.DOUBLE_ARRAY, agg.getCombiningFactory().getType());
+ Assert.assertEquals(ValueType.DOUBLE_ARRAY, agg.getFinalizedType());
+ }
+
+ @Test
+ public void testDoubleArrayTypeFinalized()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0.0",
+ "<DOUBLE>[]",
+ "__acc + some_column + some_other_column",
+ "array_set_add(__acc, expr_agg_name)",
+ null,
+ "array_to_string(o, ';')",
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(ValueType.DOUBLE, agg.getType());
+ Assert.assertEquals(ValueType.DOUBLE_ARRAY, agg.getCombiningFactory().getType());
+ Assert.assertEquals(ValueType.STRING, agg.getFinalizedType());
+ }
+
+ @Test
+ public void testResultArraySignature()
+ {
+ final TimeseriesQuery query =
+ Druids.newTimeseriesQueryBuilder()
+ .dataSource("dummy")
+ .intervals("2000/3000")
+ .granularity(Granularities.HOUR)
+ .aggregators(
+ new ExpressionLambdaAggregatorFactory(
+ "string_expr",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "''",
+ "''",
+ "concat(__acc, some_column, some_other_column)",
+ "concat(__acc, string_expr)",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "double_expr",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0.0",
+ null,
+ "__acc + some_column + some_other_column",
+ "__acc + double_expr",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "long_expr",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0",
+ null,
+ "__acc + some_column + some_other_column",
+ "__acc + long_expr",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "string_array_expr",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "<STRING>[]",
+ "<STRING>[]",
+ "array_set_add(__acc, concat(some_column, some_other_column))",
+ "array_set_add_all(__acc, string_array_expr)",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "double_array_expr",
+ ImmutableSet.of("some_column", "some_other_column_expr"),
+ null,
+ "0.0",
+ "<DOUBLE>[]",
+ "__acc + some_column + some_other_column",
+ "array_set_add(__acc, double_array)",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "long_array_expr",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0",
+ "<LONG>[]",
+ "__acc + some_column + some_other_column",
+ "array_set_add(__acc, long_array_expr)",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "string_array_expr_finalized",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "''",
+ "<STRING>[]",
+ "concat(__acc, some_column, some_other_column)",
+ "array_set_add(__acc, string_array_expr)",
+ null,
+ "array_to_string(o, ';')",
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "double_array_expr_finalized",
+ ImmutableSet.of("some_column", "some_other_column_expr"),
+ null,
+ "0.0",
+ "<DOUBLE>[]",
+ "__acc + some_column + some_other_column",
+ "array_set_add(__acc, double_array)",
+ null,
+ "array_to_string(o, ';')",
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "long_array_expr_finalized",
+ ImmutableSet.of("some_column", "some_other_column"),
+ null,
+ "0",
+ "<LONG>[]",
+ "__acc + some_column + some_other_column",
+ "array_set_add(__acc, long_array_expr)",
+ null,
+ "fold((x, acc) -> x + acc, o, 0)",
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ )
+ )
+ .postAggregators(
+ new FieldAccessPostAggregator("string-array-expr-access", "string_array_expr_finalized"),
+ new FinalizingFieldAccessPostAggregator("string-array-expr-finalize", "string_array_expr_finalized"),
+ new FieldAccessPostAggregator("double-array-expr-access", "double_array_expr_finalized"),
+ new FinalizingFieldAccessPostAggregator("double-array-expr-finalize", "double_array_expr_finalized"),
+ new FieldAccessPostAggregator("long-array-expr-access", "long_array_expr_finalized"),
+ new FinalizingFieldAccessPostAggregator("long-array-expr-finalize", "long_array_expr_finalized")
+ )
+ .build();
+
+ Assert.assertEquals(
+ RowSignature.builder()
+ .addTimeColumn()
+ .add("string_expr", ValueType.STRING)
+ .add("double_expr", ValueType.DOUBLE)
+ .add("long_expr", ValueType.LONG)
+ .add("string_array_expr", ValueType.STRING_ARRAY)
+ // type does not equal finalized type. (combining factory type does equal finalized type,
+ // but this signature doesn't use combining factory)
+ .add("double_array_expr", null)
+ // type does not equal finalized type. (combining factory type does equal finalized type,
+ // but this signature doesn't use combining factory)
+ .add("long_array_expr", null)
+ // string because fold type equals finalized type, even though merge type is array
+ .add("string_array_expr_finalized", ValueType.STRING)
+ // type does not equal finalized type. (fold type is double, but the finalize expression
+ // produces a string)
+ .add("double_array_expr_finalized", null)
+ // long because fold type equals finalized type, even though merge type is array
+ .add("long_array_expr_finalized", ValueType.LONG)
+ // fold type is string
+ .add("string-array-expr-access", ValueType.STRING)
+ // finalized type is string
+ .add("string-array-expr-finalize", ValueType.STRING)
+ // double because fold type is double
+ .add("double-array-expr-access", ValueType.DOUBLE)
+ // string because finalize type is string
+ .add("double-array-expr-finalize", ValueType.STRING)
+ // long because fold type is long
+ .add("long-array-expr-access", ValueType.LONG)
+ // finalized type is long
+ .add("long-array-expr-finalize", ValueType.LONG)
+ .build(),
+ new TimeseriesQueryQueryToolChest().resultArraySignature(query)
+ );
+ }
+}
diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java
index bc0f3b2..8d0c864 100644
--- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java
@@ -25,6 +25,7 @@
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
@@ -67,6 +68,7 @@
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
+import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.FloatSumAggregatorFactory;
import org.apache.druid.query.aggregation.JavaScriptAggregatorFactory;
@@ -11138,6 +11140,704 @@
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
}
+ @Test
+ public void testGroupByWithExpressionAggregator()
+ {
+ // expression agg not yet vectorized
+ cannotVectorize();
+ GroupByQuery query = makeQueryBuilder()
+ .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
+ .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
+ .setDimensions(new DefaultDimensionSpec("quality", "alias"))
+ .setAggregatorSpecs(
+ new ExpressionLambdaAggregatorFactory(
+ "rows",
+ Collections.emptySet(),
+ null,
+ "0",
+ null,
+ "__acc + 1",
+ "__acc + rows",
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "idx",
+ ImmutableSet.of("index"),
+ null,
+ "0.0",
+ null,
+ "__acc + index",
+ null,
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ )
+ )
+ .setGranularity(QueryRunnerTestHelper.DAY_GRAN)
+ .build();
+
+ List<ResultRow> expectedResults = Arrays.asList(
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "automotive",
+ "rows",
+ 1L,
+ "idx",
+ 135.88510131835938d
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "business",
+ "rows",
+ 1L,
+ "idx",
+ 118.57034
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "entertainment",
+ "rows",
+ 1L,
+ "idx",
+ 158.747224
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "health",
+ "rows",
+ 1L,
+ "idx",
+ 120.134704
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "mezzanine",
+ "rows",
+ 3L,
+ "idx",
+ 2871.8866900000003d
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "news",
+ "rows",
+ 1L,
+ "idx",
+ 121.58358d
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "premium",
+ "rows",
+ 3L,
+ "idx",
+ 2900.798647d
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "technology",
+ "rows",
+ 1L,
+ "idx",
+ 78.622547d
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "travel",
+ "rows",
+ 1L,
+ "idx",
+ 119.922742d
+ ),
+
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "automotive",
+ "rows",
+ 1L,
+ "idx",
+ 147.42593d
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "business",
+ "rows",
+ 1L,
+ "idx",
+ 112.987027d
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "entertainment",
+ "rows",
+ 1L,
+ "idx",
+ 166.016049d
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "health",
+ "rows",
+ 1L,
+ "idx",
+ 113.446008d
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "mezzanine",
+ "rows",
+ 3L,
+ "idx",
+ 2448.830613d
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "news",
+ "rows",
+ 1L,
+ "idx",
+ 114.290141d
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "premium",
+ "rows",
+ 3L,
+ "idx",
+ 2506.415148d
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "technology",
+ "rows",
+ 1L,
+ "idx",
+ 97.387433d
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "travel",
+ "rows",
+ 1L,
+ "idx",
+ 126.411364d
+ )
+ );
+
+ Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
+ TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
+ }
+
+ @Test
+ public void testGroupByWithExpressionAggregatorWithArrays()
+ {
+ // expression agg not yet vectorized
+ cannotVectorize();
+
+ // array types don't work with group by v1
+ if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) {
+ expectedException.expect(IllegalStateException.class);
+ expectedException.expectMessage("Unable to handle type[STRING_ARRAY] for AggregatorFactory[class org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory]");
+ }
+
+ GroupByQuery query = makeQueryBuilder()
+ .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
+ .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
+ .setDimensions(new DefaultDimensionSpec("quality", "alias"))
+ .setAggregatorSpecs(
+ new ExpressionLambdaAggregatorFactory(
+ "rows",
+ Collections.emptySet(),
+ null,
+ "0",
+ null,
+ "__acc + 1",
+ "__acc + rows",
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "idx",
+ ImmutableSet.of("index"),
+ null,
+ "0.0",
+ null,
+ "__acc + index",
+ null,
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "array_agg_distinct",
+ ImmutableSet.of(QueryRunnerTestHelper.MARKET_DIMENSION),
+ "acc",
+ "[]",
+ null,
+ "array_set_add(acc, market)",
+ "array_set_add_all(acc, array_agg_distinct)",
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ )
+ )
+ .setGranularity(QueryRunnerTestHelper.DAY_GRAN)
+ .build();
+
+ List<ResultRow> expectedResults = Arrays.asList(
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "automotive",
+ "rows",
+ 1L,
+ "idx",
+ 135.88510131835938d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "business",
+ "rows",
+ 1L,
+ "idx",
+ 118.57034,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "entertainment",
+ "rows",
+ 1L,
+ "idx",
+ 158.747224,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "health",
+ "rows",
+ 1L,
+ "idx",
+ 120.134704,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "mezzanine",
+ "rows",
+ 3L,
+ "idx",
+ 2871.8866900000003d,
+ "array_agg_distinct",
+ new String[] {"upfront", "spot", "total_market"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "news",
+ "rows",
+ 1L,
+ "idx",
+ 121.58358d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "premium",
+ "rows",
+ 3L,
+ "idx",
+ 2900.798647d,
+ "array_agg_distinct",
+ new String[] {"upfront", "spot", "total_market"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "technology",
+ "rows",
+ 1L,
+ "idx",
+ 78.622547d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "travel",
+ "rows",
+ 1L,
+ "idx",
+ 119.922742d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "automotive",
+ "rows",
+ 1L,
+ "idx",
+ 147.42593d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "business",
+ "rows",
+ 1L,
+ "idx",
+ 112.987027d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "entertainment",
+ "rows",
+ 1L,
+ "idx",
+ 166.016049d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "health",
+ "rows",
+ 1L,
+ "idx",
+ 113.446008d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "mezzanine",
+ "rows",
+ 3L,
+ "idx",
+ 2448.830613d,
+ "array_agg_distinct",
+ new String[] {"upfront", "spot", "total_market"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "news",
+ "rows",
+ 1L,
+ "idx",
+ 114.290141d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "premium",
+ "rows",
+ 3L,
+ "idx",
+ 2506.415148d,
+ "array_agg_distinct",
+ new String[] {"upfront", "spot", "total_market"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "technology",
+ "rows",
+ 1L,
+ "idx",
+ 97.387433d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "travel",
+ "rows",
+ 1L,
+ "idx",
+ 126.411364d,
+ "array_agg_distinct",
+ new String[] {"spot"}
+ )
+ );
+
+ Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
+ TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
+ }
+
+ @Test
+ public void testGroupByExpressionAggregatorArrayMultiValue()
+ {
+ // expression agg not yet vectorized
+ cannotVectorize();
+
+ // array types don't work with group by v1
+ if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) {
+ expectedException.expect(IllegalStateException.class);
+ expectedException.expectMessage("Unable to handle type[STRING_ARRAY] for AggregatorFactory[class org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory]");
+ }
+
+ GroupByQuery query = makeQueryBuilder()
+ .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
+ .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
+ .setDimensions(new DefaultDimensionSpec("quality", "alias"))
+ .setAggregatorSpecs(
+ new ExpressionLambdaAggregatorFactory(
+ "array_agg_distinct",
+ ImmutableSet.of(QueryRunnerTestHelper.PLACEMENTISH_DIMENSION),
+ "acc",
+ "[]",
+ null,
+ "array_set_add(acc, placementish)",
+ "array_set_add_all(acc, array_agg_distinct)",
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ )
+ )
+ .setGranularity(QueryRunnerTestHelper.DAY_GRAN)
+ .build();
+
+ List<ResultRow> expectedResults = Arrays.asList(
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "automotive",
+ "array_agg_distinct",
+ new String[] {"a", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "business",
+ "array_agg_distinct",
+ new String[] {"b", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "entertainment",
+ "array_agg_distinct",
+ new String[] {"e", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "health",
+ "array_agg_distinct",
+ new String[] {"h", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "mezzanine",
+ "array_agg_distinct",
+ new String[] {"m", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "news",
+ "array_agg_distinct",
+ new String[] {"n", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "premium",
+ "array_agg_distinct",
+ new String[] {"p", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "technology",
+ "array_agg_distinct",
+ new String[] {"t", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-01",
+ "alias",
+ "travel",
+ "array_agg_distinct",
+ new String[] {"t", "preferred"}
+ ),
+
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "automotive",
+ "array_agg_distinct",
+ new String[] {"a", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "business",
+ "array_agg_distinct",
+ new String[] {"b", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "entertainment",
+ "array_agg_distinct",
+ new String[] {"e", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "health",
+ "array_agg_distinct",
+ new String[] {"h", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "mezzanine",
+ "array_agg_distinct",
+ new String[] {"m", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "news",
+ "array_agg_distinct",
+ new String[] {"n", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "premium",
+ "array_agg_distinct",
+ new String[] {"p", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "technology",
+ "array_agg_distinct",
+ new String[] {"t", "preferred"}
+ ),
+ makeRow(
+ query,
+ "2011-04-02",
+ "alias",
+ "travel",
+ "array_agg_distinct",
+ new String[] {"t", "preferred"}
+ )
+ );
+
+ Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
+ TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
+ }
+
private static ResultRow makeRow(final GroupByQuery query, final String timestamp, final Object... vals)
{
return GroupByQueryRunnerTestHelper.createExpectedRow(query, timestamp, vals);
diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByTimeseriesQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByTimeseriesQueryRunnerTest.java
index 8b80fc4..50810c0 100644
--- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByTimeseriesQueryRunnerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByTimeseriesQueryRunnerTest.java
@@ -303,32 +303,4 @@
// Skip this test because the timeseries test expects a day that doesn't have a filter match to be filled in,
// but group by just doesn't return a value if the filter doesn't match.
}
-
- @Override
- public void testTimeseriesWithTimestampResultFieldContextForArrayResponse()
- {
- // Cannot vectorize with an expression virtual column
- if (!vectorize) {
- super.testTimeseriesWithTimestampResultFieldContextForArrayResponse();
- }
- }
-
- @Override
- public void testTimeseriesWithTimestampResultFieldContextForMapResponse()
- {
- // Cannot vectorize with an expression virtual column
- if (!vectorize) {
- super.testTimeseriesWithTimestampResultFieldContextForMapResponse();
- }
- }
-
- @Override
- @Test
- public void testTimeseriesWithPostAggregatorReferencingTimestampResultField()
- {
- // Cannot vectorize with an expression virtual column
- if (!vectorize) {
- super.testTimeseriesWithPostAggregatorReferencingTimestampResultField();
- }
- }
}
diff --git a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java
index c1ae6d4..58b2a90 100644
--- a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java
@@ -21,6 +21,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.primitives.Doubles;
@@ -44,6 +45,7 @@
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
+import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory;
@@ -2935,6 +2937,104 @@
assertExpectedResults(expectedResults, results);
}
+ @Test
+ public void testTimeseriesWithExpressionAggregator() // day-granularity timeseries over four expression lambda aggregators
+ {
+ // expression agg cannot vectorize
+ cannotVectorize();
+ TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
+ .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
+ .granularity(QueryRunnerTestHelper.DAY_GRAN)
+ .intervals(QueryRunnerTestHelper.FIRST_TO_THIRD)
+ .aggregators(
+ Arrays.asList(
+ new ExpressionLambdaAggregatorFactory( // "diy_count": seeded with "0", evaluates "__acc + 1"; no input fields declared
+ "diy_count",
+ ImmutableSet.of(),
+ null,
+ "0",
+ null,
+ "__acc + 1",
+ "__acc + diy_count", // refers to this aggregator's own output name
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory( // "diy_sum": seeded with "0.0", evaluates "__acc + index" with declared field "index"
+ "diy_sum",
+ ImmutableSet.of("index"),
+ null,
+ "0.0",
+ null,
+ "__acc + index",
+ null,
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory( // "diy_decomposed_sum": same "__acc + index" expression, plus array-based extra expressions
+ "diy_decomposed_sum",
+ ImmutableSet.of("index"),
+ null,
+ "0.0",
+ "<DOUBLE>[]", // NOTE(review): presumably the combine-side initial value -- confirm against the constructor
+ "__acc + index",
+ "array_concat(__acc, diy_decomposed_sum)",
+ null,
+ "fold((x, acc) -> x + acc, o, 0.0)", // NOTE(review): presumably the finalize step, summing the elements of o -- confirm
+ null,
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory( // "array_agg_distinct": declares field MARKET_DIMENSION and reads "market"
+ "array_agg_distinct",
+ ImmutableSet.of(QueryRunnerTestHelper.MARKET_DIMENSION),
+ "acc", // identifier used by the expressions below; the other aggregators pass null here and use "__acc"
+ "[]",
+ null,
+ "array_set_add(acc, market)",
+ "array_set_add_all(acc, array_agg_distinct)",
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ )
+ )
+ )
+ .descending(descending)
+ .context(makeContext())
+ .build();
+
+ List<Result<TimeseriesResultValue>> expectedResults = Arrays.asList( // two expected rows: 2011-04-01 and 2011-04-02
+ new Result<>(
+ DateTimes.of("2011-04-01"),
+ new TimeseriesResultValue(
+ ImmutableMap.of(
+ "diy_count", 13L,
+ "diy_sum", 6626.151569,
+ "diy_decomposed_sum", 6626.151569, // expected to equal diy_sum
+ "array_agg_distinct", new String[] {"upfront", "spot", "total_market"}
+ )
+ )
+ ),
+ new Result<>(
+ DateTimes.of("2011-04-02"),
+ new TimeseriesResultValue(
+ ImmutableMap.of(
+ "diy_count", 13L,
+ "diy_sum", 5833.209718,
+ "diy_decomposed_sum", 5833.209718, // expected to equal diy_sum
+ "array_agg_distinct", new String[] {"upfront", "spot", "total_market"}
+ )
+ )
+ )
+ );
+
+ Iterable<Result<TimeseriesResultValue>> results = runner.run(QueryPlus.wrap(query)).toList();
+ assertExpectedResults(expectedResults, results);
+ }
+
private Map<String, Object> makeContext()
{
return makeContext(ImmutableMap.of());
diff --git a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java
index f98160a..eb8709d 100644
--- a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java
@@ -22,6 +22,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
@@ -52,6 +53,7 @@
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
+import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMinAggregatorFactory;
@@ -5964,6 +5966,110 @@
assertExpectedResults(expectedResults, query);
}
+ @Test
+ public void testExpressionAggregator() // topN on the market dimension over four expression lambda aggregators
+ {
+ // sorted by array length of array_agg_distinct
+ TopNQuery query = new TopNQueryBuilder()
+ .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
+ .granularity(QueryRunnerTestHelper.ALL_GRAN)
+ .dimension(QueryRunnerTestHelper.MARKET_DIMENSION)
+ .metric("array_agg_distinct") // ranks on the array-valued aggregator defined below
+ .threshold(4) // threshold is 4, but only three market values appear in the expected results
+ .intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
+ .aggregators(
+ Lists.newArrayList(
+ Arrays.asList(
+ new ExpressionLambdaAggregatorFactory( // "diy_count": seeded with "0", evaluates "__acc + 1"; no input fields declared
+ "diy_count",
+ Collections.emptySet(),
+ null,
+ "0",
+ null,
+ "__acc + 1",
+ "__acc + diy_count", // refers to this aggregator's own output name
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory( // "diy_sum": seeded with "0.0", evaluates "__acc + index" with declared field "index"
+ "diy_sum",
+ ImmutableSet.of("index"),
+ null,
+ "0.0",
+ null,
+ "__acc + index",
+ null,
+ null,
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory( // "diy_decomposed_sum": same "__acc + index" expression, plus array-based extra expressions
+ "diy_decomposed_sum",
+ ImmutableSet.of("index"),
+ null,
+ "0.0",
+ "<DOUBLE>[]", // NOTE(review): presumably the combine-side initial value -- confirm against the constructor
+ "__acc + index",
+ "array_concat(__acc, diy_decomposed_sum)",
+ null,
+ "fold((x, acc) -> x + acc, o, 0.0)", // NOTE(review): presumably the finalize step, summing the elements of o -- confirm
+ null,
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory( // "array_agg_distinct": declares field QUALITY_DIMENSION and reads "quality"
+ "array_agg_distinct",
+ ImmutableSet.of(QueryRunnerTestHelper.QUALITY_DIMENSION),
+ "acc", // identifier used by the expressions below; the other aggregators pass null here and use "__acc"
+ "[]",
+ null,
+ "array_set_add(acc, quality)",
+ "array_set_add_all(acc, array_agg_distinct)",
+ "if(array_length(o1) > array_length(o2), 1, if (array_length(o1) == array_length(o2), 0, -1))", // yields 1 / 0 / -1 from the two array lengths
+ null,
+ null,
+ TestExprMacroTable.INSTANCE
+ )
+ )
+ )
+ )
+ .build();
+
+ List<Result<TopNResultValue>> expectedResults = Collections.singletonList( // a single result holding three rows
+ new Result<>(
+ DateTimes.of("2011-01-12T00:00:00.000Z"),
+ new TopNResultValue(
+ Arrays.<Map<String, Object>>asList(
+ ImmutableMap.<String, Object>builder() // "spot" comes first: its expected array has nine entries
+ .put(QueryRunnerTestHelper.MARKET_DIMENSION, "spot")
+ .put("diy_count", 837L)
+ .put("diy_sum", 95606.57232284546D)
+ .put("diy_decomposed_sum", 95606.57232284546D)
+ .put("array_agg_distinct", new String[]{"mezzanine", "news", "premium", "business", "entertainment", "health", "technology", "automotive", "travel"})
+ .build(),
+ ImmutableMap.<String, Object>builder() // two-entry array, same length as the "upfront" row
+ .put(QueryRunnerTestHelper.MARKET_DIMENSION, "total_market")
+ .put("diy_count", 186L)
+ .put("diy_sum", 215679.82879638672D)
+ .put("diy_decomposed_sum", 215679.82879638672D)
+ .put("array_agg_distinct", new String[]{"mezzanine", "premium"})
+ .build(),
+ ImmutableMap.<String, Object>builder() // two-entry array, same length as the "total_market" row
+ .put(QueryRunnerTestHelper.MARKET_DIMENSION, "upfront")
+ .put("diy_count", 186L)
+ .put("diy_sum", 192046.1060180664D)
+ .put("diy_decomposed_sum", 192046.1060180664D)
+ .put("array_agg_distinct", new String[]{"mezzanine", "premium"})
+ .build()
+ )
+ )
+ )
+ );
+ assertExpectedResults(expectedResults, query);
+ }
+
private static Map<String, Object> makeRowWithNulls(
String dimName,
@Nullable Object dimValue,
diff --git a/processing/src/test/java/org/apache/druid/segment/TestHelper.java b/processing/src/test/java/org/apache/druid/segment/TestHelper.java
index a2e409d..59f739a 100644
--- a/processing/src/test/java/org/apache/druid/segment/TestHelper.java
+++ b/processing/src/test/java/org/apache/druid/segment/TestHelper.java
@@ -32,6 +32,7 @@
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Sequence;
+import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.Result;
import org.apache.druid.query.expression.TestExprMacroTable;
@@ -352,7 +353,9 @@
final Object expectedValue = expectedMap.get(key);
final Object actualValue = actualMap.get(key);
- if (expectedValue instanceof Float || expectedValue instanceof Double) {
+ if (expectedValue != null && expectedValue.getClass().isArray()) { // array-valued expectation: compare element by element
+ Assert.assertArrayEquals(StringUtils.format("%s: key[%s]", msg, key), (Object[]) expectedValue, (Object[]) actualValue); // same failure message as the numeric branch so the failing key is reported
+ } else if (expectedValue instanceof Float || expectedValue instanceof Double) {
Assert.assertEquals(
StringUtils.format("%s: key[%s]", msg, key),
((Number) expectedValue).doubleValue(),
@@ -382,7 +385,23 @@
final Object expectedValue = expected.get(i);
final Object actualValue = actual.get(i);
- if (expectedValue instanceof Float || expectedValue instanceof Double) {
+
+ if (expectedValue != null && expectedValue.getClass().isArray()) {
+ // spilled results will materialize into lists, coerce them back to arrays if we expected arrays
+ if (actualValue instanceof List) {
+ Assert.assertArrayEquals( // replaces the deprecated assertEquals(String, Object[], Object[]) overload; matches the else branch
+ message,
+ (Object[]) expectedValue,
+ (Object[]) ExprEval.coerceListToArray((List) actualValue, true)
+ );
+ } else {
+ Assert.assertArrayEquals(
+ message,
+ (Object[]) expectedValue,
+ (Object[]) actualValue
+ );
+ }
+ } else if (expectedValue instanceof Float || expectedValue instanceof Double) {
Assert.assertEquals(
message,
((Number) expectedValue).doubleValue(),
diff --git a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java
index 64da13d..e1a9f27 100644
--- a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java
@@ -244,77 +244,6 @@
}
@Test
- public void test_coerceListToArray()
- {
- List<Long> longList = ImmutableList.of(1L, 2L, 3L);
- Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExpressionSelectors.coerceListToArray(longList));
-
- List<Integer> intList = ImmutableList.of(1, 2, 3);
- Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExpressionSelectors.coerceListToArray(intList));
-
- List<Float> floatList = ImmutableList.of(1.0f, 2.0f, 3.0f);
- Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExpressionSelectors.coerceListToArray(floatList));
-
- List<Double> doubleList = ImmutableList.of(1.0, 2.0, 3.0);
- Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExpressionSelectors.coerceListToArray(doubleList));
-
- List<String> stringList = ImmutableList.of("a", "b", "c");
- Assert.assertArrayEquals(new String[]{"a", "b", "c"}, (String[]) ExpressionSelectors.coerceListToArray(stringList));
-
- List<String> withNulls = new ArrayList<>();
- withNulls.add("a");
- withNulls.add(null);
- withNulls.add("c");
- Assert.assertArrayEquals(new String[]{"a", null, "c"}, (String[]) ExpressionSelectors.coerceListToArray(withNulls));
-
- List<Long> withNumberNulls = new ArrayList<>();
- withNumberNulls.add(1L);
- withNumberNulls.add(null);
- withNumberNulls.add(3L);
-
- Assert.assertArrayEquals(new Long[]{1L, null, 3L}, (Long[]) ExpressionSelectors.coerceListToArray(withNumberNulls));
-
- List<Object> withStringMix = ImmutableList.of(1L, "b", 3L);
- Assert.assertArrayEquals(
- new String[]{"1", "b", "3"},
- (String[]) ExpressionSelectors.coerceListToArray(withStringMix)
- );
-
- List<Number> withIntsAndLongs = ImmutableList.of(1, 2L, 3);
- Assert.assertArrayEquals(
- new Long[]{1L, 2L, 3L},
- (Long[]) ExpressionSelectors.coerceListToArray(withIntsAndLongs)
- );
-
- List<Number> withFloatsAndLongs = ImmutableList.of(1, 2L, 3.0f);
- Assert.assertArrayEquals(
- new Double[]{1.0, 2.0, 3.0},
- (Double[]) ExpressionSelectors.coerceListToArray(withFloatsAndLongs)
- );
-
- List<Number> withDoublesAndLongs = ImmutableList.of(1, 2L, 3.0);
- Assert.assertArrayEquals(
- new Double[]{1.0, 2.0, 3.0},
- (Double[]) ExpressionSelectors.coerceListToArray(withDoublesAndLongs)
- );
-
- List<Number> withFloatsAndDoubles = ImmutableList.of(1L, 2.0f, 3.0);
- Assert.assertArrayEquals(
- new Double[]{1.0, 2.0, 3.0},
- (Double[]) ExpressionSelectors.coerceListToArray(withFloatsAndDoubles)
- );
-
- List<String> withAllNulls = new ArrayList<>();
- withAllNulls.add(null);
- withAllNulls.add(null);
- withAllNulls.add(null);
- Assert.assertArrayEquals(
- new String[]{null, null, null},
- (String[]) ExpressionSelectors.coerceListToArray(withAllNulls)
- );
- }
-
- @Test
public void test_coerceEvalToSelectorObject()
{
Assert.assertEquals(
diff --git a/website/.spelling b/website/.spelling
index 315e5ff..d717fbf 100644
--- a/website/.spelling
+++ b/website/.spelling
@@ -1100,6 +1100,8 @@
arr2
array_append
array_concat
+array_set_add
+array_set_add_all
array_contains
array_length
array_offset