Vectorized variance aggregators (#10390)
* wip vectorize
* close but not quite
* faster
* unit tests
* fix complex types for variance
diff --git a/benchmarks/pom.xml b/benchmarks/pom.xml
index f00b931..4d21685 100644
--- a/benchmarks/pom.xml
+++ b/benchmarks/pom.xml
@@ -83,6 +83,11 @@
<version>${project.parent.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.druid.extensions</groupId>
+ <artifactId>druid-stats</artifactId>
+ <version>${project.parent.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-core</artifactId>
<version>${project.parent.version}</version>
@@ -172,7 +177,7 @@
<dependency>
<groupId>org.apache.druid.extensions</groupId>
<artifactId>druid-protobuf-extensions</artifactId>
- <version>0.20.0-SNAPSHOT</version>
+ <version>${project.parent.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/VarianceBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/VarianceBenchmark.java
new file mode 100644
index 0000000..85b7c4d
--- /dev/null
+++ b/benchmarks/src/test/java/org/apache/druid/benchmark/VarianceBenchmark.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.benchmark;
+
+import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.util.Random;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+@State(Scope.Benchmark)
+@Fork(value = 1)
+@Warmup(iterations = 5)
+@Measurement(iterations = 5)
+public class VarianceBenchmark
+{
+ @Param({"128", "256", "512", "1024"})
+ int vectorSize;
+
+ private float[] randomValues;
+
+ @Setup
+ public void setup()
+ {
+ randomValues = new float[vectorSize];
+ Random r = ThreadLocalRandom.current();
+ for (int i = 0; i < vectorSize; i++) {
+ randomValues[i] = r.nextFloat();
+ }
+ }
+
+ @Benchmark
+ @BenchmarkMode(Mode.AverageTime)
+ @OutputTimeUnit(TimeUnit.NANOSECONDS)
+ public void collectVarianceOneByOne(Blackhole blackhole)
+ {
+ VarianceAggregatorCollector collector = new VarianceAggregatorCollector();
+ for (float v : randomValues) {
+ collector.add(v);
+ }
+ blackhole.consume(collector);
+ }
+
+ @Benchmark
+ @BenchmarkMode(Mode.AverageTime)
+ @OutputTimeUnit(TimeUnit.NANOSECONDS)
+ public void collectVarianceInBatch(Blackhole blackhole)
+ {
+ double sum = 0, nvariance = 0;
+ for (float v : randomValues) {
+ sum += v;
+ }
+ double mean = sum / randomValues.length;
+ for (float v : randomValues) {
+ nvariance += (v - mean) * (v - mean);
+ }
+ VarianceAggregatorCollector collector = new VarianceAggregatorCollector(randomValues.length, sum, nvariance);
+ blackhole.consume(collector);
+ }
+}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java
index ce0edb0..6526a86 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java
@@ -76,6 +76,7 @@
if (other == null || other.count == 0) {
return;
}
+
if (this.count == 0) {
this.nvariance = other.nvariance;
this.count = other.count;
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
index e9b59b4..2894c01 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
@@ -22,6 +22,7 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
@@ -35,12 +36,15 @@
import org.apache.druid.query.aggregation.NoopAggregator;
import org.apache.druid.query.aggregation.NoopBufferAggregator;
import org.apache.druid.query.aggregation.ObjectAggregateCombiner;
+import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
+import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@@ -83,7 +87,8 @@
this.inputType = inputType;
}
- public VarianceAggregatorFactory(String name, String fieldName)
+ @VisibleForTesting
+ VarianceAggregatorFactory(String name, String fieldName)
{
this(name, fieldName, null, null);
}
@@ -131,7 +136,7 @@
return new VarianceAggregator.DoubleVarianceAggregator(selector);
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.LongVarianceAggregator(selector);
- } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) {
+ } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.ObjectVarianceAggregator(selector);
}
throw new IAE(
@@ -156,17 +161,43 @@
return new VarianceBufferAggregator.DoubleVarianceAggregator(selector);
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.LongVarianceAggregator(selector);
- } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) {
+ } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.ObjectVarianceAggregator(selector);
}
throw new IAE(
"Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
fieldName,
- inputType
+ type
);
}
@Override
+ public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
+ {
+ final String type = getTypeString(selectorFactory);
+ if (ValueType.FLOAT.name().equalsIgnoreCase(type)) {
+ return new VarianceFloatVectorAggregator(selectorFactory.makeValueSelector(fieldName));
+ } else if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) {
+ return new VarianceDoubleVectorAggregator(selectorFactory.makeValueSelector(fieldName));
+ } else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
+ return new VarianceLongVectorAggregator(selectorFactory.makeValueSelector(fieldName));
+ } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
+ return new VarianceObjectVectorAggregator(selectorFactory.makeObjectSelector(fieldName));
+ }
+ throw new IAE(
+ "Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
+ fieldName,
+ type
+ );
+ }
+
+ @Override
+ public boolean canVectorize(ColumnInspector columnInspector)
+ {
+ return true;
+ }
+
+ @Override
public Object combine(Object lhs, Object rhs)
{
return VarianceAggregatorCollector.combineValues(lhs, rhs);
@@ -340,11 +371,11 @@
return Objects.hash(fieldName, name, estimator, inputType, isVariancePop);
}
- private String getTypeString(ColumnSelectorFactory metricFactory)
+ private String getTypeString(ColumnInspector columnInspector)
{
String type = inputType;
if (type == null) {
- ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName);
+ ColumnCapabilities capabilities = columnInspector.getColumnCapabilities(fieldName);
if (capabilities != null) {
type = StringUtils.toLowerCase(capabilities.getType().name());
} else {
@@ -353,5 +384,4 @@
}
return type;
}
-
}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
index 51ec0b1..065ad2a 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
@@ -35,25 +35,19 @@
public abstract class VarianceBufferAggregator implements BufferAggregator
{
private static final int COUNT_OFFSET = 0;
- private static final int SUM_OFFSET = Long.BYTES;
+ private static final int SUM_OFFSET = COUNT_OFFSET + Long.BYTES;
private static final int NVARIANCE_OFFSET = SUM_OFFSET + Double.BYTES;
@Override
public void init(final ByteBuffer buf, final int position)
{
- buf.putLong(position + COUNT_OFFSET, 0)
- .putDouble(position + SUM_OFFSET, 0)
- .putDouble(position + NVARIANCE_OFFSET, 0);
+ doInit(buf, position);
}
@Override
- public Object get(final ByteBuffer buf, final int position)
+ public VarianceAggregatorCollector get(final ByteBuffer buf, final int position)
{
- VarianceAggregatorCollector holder = new VarianceAggregatorCollector();
- holder.count = buf.getLong(position);
- holder.sum = buf.getDouble(position + SUM_OFFSET);
- holder.nvariance = buf.getDouble(position + NVARIANCE_OFFSET);
- return holder;
+ return getVarianceCollector(buf, position);
}
@Override
@@ -79,6 +73,51 @@
{
}
+ public static void doInit(ByteBuffer buf, int position)
+ {
+ buf.putLong(position + COUNT_OFFSET, 0)
+ .putDouble(position + SUM_OFFSET, 0)
+ .putDouble(position + NVARIANCE_OFFSET, 0);
+ }
+
+ public static long getCount(ByteBuffer buf, int position)
+ {
+ return buf.getLong(position + COUNT_OFFSET);
+ }
+
+ public static double getSum(ByteBuffer buf, int position)
+ {
+ return buf.getDouble(position + SUM_OFFSET);
+ }
+
+ public static double getVariance(ByteBuffer buf, int position)
+ {
+ return buf.getDouble(position + NVARIANCE_OFFSET);
+ }
+ public static VarianceAggregatorCollector getVarianceCollector(ByteBuffer buf, int position)
+ {
+ return new VarianceAggregatorCollector(
+ getCount(buf, position),
+ getSum(buf, position),
+ getVariance(buf, position)
+ );
+ }
+
+ public static void writeNVariance(ByteBuffer buf, int position, long count, double sum, double nvariance)
+ {
+ buf.putLong(position + COUNT_OFFSET, count);
+ buf.putDouble(position + SUM_OFFSET, sum);
+ if (count > 1) {
+ buf.putDouble(position + NVARIANCE_OFFSET, nvariance);
+ }
+ }
+
+ public static void writeCountAndSum(ByteBuffer buf, int position, long count, double sum)
+ {
+ buf.putLong(position + COUNT_OFFSET, count);
+ buf.putDouble(position + SUM_OFFSET, sum);
+ }
+
public static final class FloatVarianceAggregator extends VarianceBufferAggregator
{
private final boolean noNulls = NullHandling.replaceWithDefault();
@@ -94,10 +133,9 @@
{
if (noNulls || !selector.isNull()) {
float v = selector.getFloat();
- long count = buf.getLong(position + COUNT_OFFSET) + 1;
- double sum = buf.getDouble(position + SUM_OFFSET) + v;
- buf.putLong(position, count);
- buf.putDouble(position + SUM_OFFSET, sum);
+ long count = getCount(buf, position) + 1;
+ double sum = getSum(buf, position) + v;
+ writeCountAndSum(buf, position, count, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@@ -128,10 +166,9 @@
{
if (noNulls || !selector.isNull()) {
double v = selector.getDouble();
- long count = buf.getLong(position + COUNT_OFFSET) + 1;
- double sum = buf.getDouble(position + SUM_OFFSET) + v;
- buf.putLong(position, count);
- buf.putDouble(position + SUM_OFFSET, sum);
+ long count = getCount(buf, position) + 1;
+ double sum = getSum(buf, position) + v;
+ writeCountAndSum(buf, position, count, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@@ -162,10 +199,9 @@
{
if (noNulls || !selector.isNull()) {
long v = selector.getLong();
- long count = buf.getLong(position + COUNT_OFFSET) + 1;
- double sum = buf.getDouble(position + SUM_OFFSET) + v;
- buf.putLong(position, count);
- buf.putDouble(position + SUM_OFFSET, sum);
+ long count = getCount(buf, position) + 1;
+ double sum = getSum(buf, position) + v;
+ writeCountAndSum(buf, position, count, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@@ -195,7 +231,7 @@
{
VarianceAggregatorCollector holder2 = (VarianceAggregatorCollector) selector.getObject();
Preconditions.checkState(holder2 != null);
- long count = buf.getLong(position + COUNT_OFFSET);
+ long count = getCount(buf, position);
if (count == 0) {
buf.putLong(position, holder2.count);
buf.putDouble(position + SUM_OFFSET, holder2.sum);
@@ -203,7 +239,7 @@
return;
}
- double sum = buf.getDouble(position + SUM_OFFSET);
+ double sum = getSum(buf, position);
double nvariance = buf.getDouble(position + NVARIANCE_OFFSET);
final double ratio = count / (double) holder2.count;
@@ -213,9 +249,7 @@
count += holder2.count;
sum += holder2.sum;
- buf.putLong(position, count);
- buf.putDouble(position + SUM_OFFSET, sum);
- buf.putDouble(position + NVARIANCE_OFFSET, nvariance);
+ writeNVariance(buf, position, count, sum, nvariance);
}
@Override
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregator.java
new file mode 100644
index 0000000..37c8739
--- /dev/null
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregator.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.VectorValueSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized implementation of {@link VarianceBufferAggregator} for doubles.
+ */
+public class VarianceDoubleVectorAggregator implements VectorAggregator
+{
+ private final VectorValueSelector selector;
+ private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
+
+ public VarianceDoubleVectorAggregator(VectorValueSelector selector)
+ {
+ this.selector = selector;
+ }
+
+ @Override
+ public void init(ByteBuffer buf, int position)
+ {
+ VarianceBufferAggregator.doInit(buf, position);
+ }
+
+ @Override
+ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+ {
+ double[] vector = selector.getDoubleVector();
+ long count = 0;
+ double sum = 0, nvariance = 0;
+ boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+ for (int i = startRow; i < endRow; i++) {
+ if (nulls == null || !nulls[i]) {
+ count++;
+ sum += vector[i];
+ }
+ }
+ double mean = sum / count;
+ if (count > 1) {
+ for (int i = startRow; i < endRow; i++) {
+ if (nulls == null || !nulls[i]) {
+ nvariance += (vector[i] - mean) * (vector[i] - mean);
+ }
+ }
+ }
+
+ VarianceAggregatorCollector previous = new VarianceAggregatorCollector(
+ VarianceBufferAggregator.getCount(buf, position),
+ VarianceBufferAggregator.getSum(buf, position),
+ VarianceBufferAggregator.getVariance(buf, position)
+ );
+ previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
+ VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+ }
+
+ @Override
+ public void aggregate(
+ ByteBuffer buf,
+ int numRows,
+ int[] positions,
+ @Nullable int[] rows,
+ int positionOffset
+ )
+ {
+ double[] vector = selector.getDoubleVector();
+ boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+ for (int i = 0; i < numRows; i++) {
+ int position = positions[i] + positionOffset;
+ int row = rows != null ? rows[i] : i;
+ if (nulls == null || !nulls[row]) {
+ VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
+ previous.add(vector[row]);
+ VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+ }
+ }
+ }
+
+ @Nullable
+ @Override
+ public VarianceAggregatorCollector get(ByteBuffer buf, int position)
+ {
+ return VarianceBufferAggregator.getVarianceCollector(buf, position);
+ }
+
+ @Override
+ public void close()
+ {
+ // Nothing to close.
+ }
+}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregator.java
new file mode 100644
index 0000000..957926e
--- /dev/null
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregator.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.VectorValueSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized implementation of {@link VarianceBufferAggregator} for floats.
+ */
+public class VarianceFloatVectorAggregator implements VectorAggregator
+{
+ private final VectorValueSelector selector;
+ private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
+
+ public VarianceFloatVectorAggregator(VectorValueSelector selector)
+ {
+ this.selector = selector;
+ }
+
+ @Override
+ public void init(ByteBuffer buf, int position)
+ {
+ VarianceBufferAggregator.doInit(buf, position);
+ }
+
+ @Override
+ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+ {
+ float[] vector = selector.getFloatVector();
+ long count = 0;
+ double sum = 0, nvariance = 0;
+ boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+ for (int i = startRow; i < endRow; i++) {
+ if (nulls == null || !nulls[i]) {
+ count++;
+ sum += vector[i];
+ }
+ }
+ double mean = sum / count;
+ if (count > 1) {
+ for (int i = startRow; i < endRow; i++) {
+ if (nulls == null || !nulls[i]) {
+ nvariance += (vector[i] - mean) * (vector[i] - mean);
+ }
+ }
+ }
+
+ VarianceAggregatorCollector previous = new VarianceAggregatorCollector(
+ VarianceBufferAggregator.getCount(buf, position),
+ VarianceBufferAggregator.getSum(buf, position),
+ VarianceBufferAggregator.getVariance(buf, position)
+ );
+ previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
+ VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+ }
+
+ @Override
+ public void aggregate(
+ ByteBuffer buf,
+ int numRows,
+ int[] positions,
+ @Nullable int[] rows,
+ int positionOffset
+ )
+ {
+ float[] vector = selector.getFloatVector();
+ boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+ for (int i = 0; i < numRows; i++) {
+ int position = positions[i] + positionOffset;
+ int row = rows != null ? rows[i] : i;
+ if (nulls == null || !nulls[row]) {
+ VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
+ previous.add(vector[row]);
+ VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+ }
+ }
+ }
+
+ @Nullable
+ @Override
+ public VarianceAggregatorCollector get(ByteBuffer buf, int position)
+ {
+ return VarianceBufferAggregator.getVarianceCollector(buf, position);
+ }
+
+ @Override
+ public void close()
+ {
+ // Nothing to close.
+ }
+}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregator.java
new file mode 100644
index 0000000..69941b6
--- /dev/null
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregator.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.VectorValueSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized implementation of {@link VarianceBufferAggregator} for longs.
+ */
+public class VarianceLongVectorAggregator implements VectorAggregator
+{
+ private final VectorValueSelector selector;
+ private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
+
+ public VarianceLongVectorAggregator(VectorValueSelector selector)
+ {
+ this.selector = selector;
+ }
+
+ @Override
+ public void init(ByteBuffer buf, int position)
+ {
+ VarianceBufferAggregator.doInit(buf, position);
+ }
+
+ @Override
+ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+ {
+ long[] vector = selector.getLongVector();
+ long count = 0;
+ double sum = 0, nvariance = 0;
+ boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+ for (int i = startRow; i < endRow; i++) {
+ if (nulls == null || !nulls[i]) {
+ count++;
+ sum += vector[i];
+ }
+ }
+ double mean = sum / count;
+ if (count > 1) {
+ for (int i = startRow; i < endRow; i++) {
+ if (nulls == null || !nulls[i]) {
+ nvariance += (vector[i] - mean) * (vector[i] - mean);
+ }
+ }
+ }
+
+ VarianceAggregatorCollector previous = new VarianceAggregatorCollector(
+ VarianceBufferAggregator.getCount(buf, position),
+ VarianceBufferAggregator.getSum(buf, position),
+ VarianceBufferAggregator.getVariance(buf, position)
+ );
+ previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
+ VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+ }
+
+ @Override
+ public void aggregate(
+ ByteBuffer buf,
+ int numRows,
+ int[] positions,
+ @Nullable int[] rows,
+ int positionOffset
+ )
+ {
+ long[] vector = selector.getLongVector();
+ boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+ for (int i = 0; i < numRows; i++) {
+ int position = positions[i] + positionOffset;
+ int row = rows != null ? rows[i] : i;
+ if (nulls == null || !nulls[row]) {
+ VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
+ previous.add(vector[row]);
+ VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+ }
+ }
+ }
+
+ @Nullable
+ @Override
+ public VarianceAggregatorCollector get(ByteBuffer buf, int position)
+ {
+ return VarianceBufferAggregator.getVarianceCollector(buf, position);
+ }
+
+ @Override
+ public void close()
+ {
+ // Nothing to close.
+ }
+}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregator.java
new file mode 100644
index 0000000..1a7dfb0
--- /dev/null
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregator.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.VectorObjectSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized implementation of {@link VarianceBufferAggregator} for {@link VarianceAggregatorCollector}.
+ */
+public class VarianceObjectVectorAggregator implements VectorAggregator
+{
+ private final VectorObjectSelector selector;
+
+ public VarianceObjectVectorAggregator(VectorObjectSelector selector)
+ {
+ this.selector = selector;
+ }
+
+ @Override
+ public void init(ByteBuffer buf, int position)
+ {
+ VarianceBufferAggregator.doInit(buf, position);
+ }
+
+ @Override
+ public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+ {
+ VarianceAggregatorCollector[] vector = (VarianceAggregatorCollector[]) selector.getObjectVector();
+ VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
+ for (int i = startRow; i < endRow; i++) {
+ previous.fold(vector[i]);
+ }
+ VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+ }
+
+ @Override
+ public void aggregate(
+ ByteBuffer buf,
+ int numRows,
+ int[] positions,
+ @Nullable int[] rows,
+ int positionOffset
+ )
+ {
+ VarianceAggregatorCollector[] vector = (VarianceAggregatorCollector[]) selector.getObjectVector();
+ for (int i = 0; i < numRows; i++) {
+ int position = positions[i] + positionOffset;
+ int row = rows != null ? rows[i] : i;
+ VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
+ previous.fold(vector[row]);
+ VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+ }
+ }
+
+ @Nullable
+ @Override
+ public VarianceAggregatorCollector get(ByteBuffer buf, int position)
+ {
+ return VarianceBufferAggregator.getVarianceCollector(buf, position);
+ }
+
+ @Override
+ public void close()
+ {
+ // Nothing to close.
+ }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryUnitTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryUnitTest.java
new file mode 100644
index 0000000..25a5130
--- /dev/null
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryUnitTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.query.aggregation.Aggregator;
+import org.apache.druid.query.aggregation.BufferAggregator;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Answers;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+@RunWith(MockitoJUnitRunner.class)
+public class VarianceAggregatorFactoryUnitTest extends InitializedNullHandlingTest
+{
+ private static final String NAME = "NAME";
+ private static final String FIELD_NAME = "FIELD_NAME";
+ private static final String DOUBLE = "double";
+ private static final String LONG = "long";
+ private static final String VARIANCE = "variance";
+ private static final String UNKNOWN = "unknown";
+
+ @Mock
+ private ColumnCapabilities capabilities;
+ @Mock
+ private VectorColumnSelectorFactory selectorFactory;
+ @Mock(answer = Answers.RETURNS_MOCKS)
+ private ColumnSelectorFactory metricFactory;
+
+ private VarianceAggregatorFactory target;
+
+ @Before
+ public void setup()
+ {
+ target = new VarianceAggregatorFactory(NAME, FIELD_NAME);
+ }
+
+ @Test
+ public void factorizeVectorShouldReturnFloatVectorAggregator()
+ {
+ VectorAggregator agg = target.factorizeVector(selectorFactory);
+ Assert.assertNotNull(agg);
+ Assert.assertEquals(VarianceFloatVectorAggregator.class, agg.getClass());
+ }
+
+ @Test
+ public void factorizeVectorForDoubleShouldReturnFloatVectorAggregator()
+ {
+ target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, DOUBLE);
+ VectorAggregator agg = target.factorizeVector(selectorFactory);
+ Assert.assertNotNull(agg);
+ Assert.assertEquals(VarianceDoubleVectorAggregator.class, agg.getClass());
+ }
+
+ @Test
+ public void factorizeVectorForLongShouldReturnFloatVectorAggregator()
+ {
+ target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, LONG);
+ VectorAggregator agg = target.factorizeVector(selectorFactory);
+ Assert.assertNotNull(agg);
+ Assert.assertEquals(VarianceLongVectorAggregator.class, agg.getClass());
+ }
+
+ @Test
+ public void factorizeVectorForVarianceShouldReturnObjectVectorAggregator()
+ {
+ target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, VARIANCE);
+ VectorAggregator agg = target.factorizeVector(selectorFactory);
+ Assert.assertNotNull(agg);
+ Assert.assertEquals(VarianceObjectVectorAggregator.class, agg.getClass());
+ }
+
+ @Test
+ public void factorizeVectorForComplexShouldReturnObjectVectorAggregator()
+ {
+ mockType(ValueType.COMPLEX);
+ VectorAggregator agg = target.factorizeVector(selectorFactory);
+ Assert.assertNotNull(agg);
+ Assert.assertEquals(VarianceObjectVectorAggregator.class, agg.getClass());
+ }
+
+ @Test
+ public void factorizeBufferedForComplexShouldReturnObjectVectorAggregator()
+ {
+ mockType(ValueType.COMPLEX);
+ BufferAggregator agg = target.factorizeBuffered(metricFactory);
+ Assert.assertNotNull(agg);
+ Assert.assertEquals(VarianceBufferAggregator.ObjectVarianceAggregator.class, agg.getClass());
+ }
+
+ @Test
+ public void factorizeForComplexShouldReturnObjectVectorAggregator()
+ {
+ mockType(ValueType.COMPLEX);
+ Aggregator agg = target.factorize(metricFactory);
+ Assert.assertNotNull(agg);
+ Assert.assertEquals(VarianceAggregator.ObjectVarianceAggregator.class, agg.getClass());
+ }
+
+ @Test(expected = IAE.class)
+ public void factorizeVectorForUnknownColumnShouldThrowIAE()
+ {
+ target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, UNKNOWN);
+ target.factorizeVector(selectorFactory);
+ }
+
+ @Test(expected = IAE.class)
+ public void factorizeBufferedForUnknownColumnShouldThrowIAE()
+ {
+ target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, UNKNOWN);
+ target.factorizeBuffered(metricFactory);
+ }
+
+ @Test
+ public void equalsContract()
+ {
+ EqualsVerifier.forClass(VarianceAggregatorFactory.class)
+ .usingGetClass()
+ .verify();
+ }
+
+ private void mockType(ValueType type)
+ {
+ Mockito.doReturn(capabilities).when(selectorFactory).getColumnCapabilities(FIELD_NAME);
+ Mockito.doReturn(capabilities).when(metricFactory).getColumnCapabilities(FIELD_NAME);
+ Mockito.doReturn(type).when(capabilities).getType();
+ }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregatorTest.java
new file mode 100644
index 0000000..4204c2a
--- /dev/null
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregatorTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.segment.vector.VectorValueSelector;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ThreadLocalRandom;
+
+@RunWith(MockitoJUnitRunner.class)
+public class VarianceDoubleVectorAggregatorTest extends InitializedNullHandlingTest
+{
+ private static final int START_ROW = 1;
+ private static final int POSITION = 2;
+ private static final int UNINIT_POSITION = 512;
+ private static final double EPSILON = 1e-10;
+ private static final double[] VALUES = new double[]{7.8d, 11, 23.67, 60, 123};
+ private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
+
+ @Mock
+ private VectorValueSelector selector;
+ private ByteBuffer buf;
+
+ private VarianceDoubleVectorAggregator target;
+
+ @Before
+ public void setup()
+ {
+ byte[] randomBytes = new byte[1024];
+ ThreadLocalRandom.current().nextBytes(randomBytes);
+ buf = ByteBuffer.wrap(randomBytes);
+ Mockito.doReturn(VALUES).when(selector).getDoubleVector();
+ target = new VarianceDoubleVectorAggregator(selector);
+ clearBufferForPositions(0, POSITION);
+ }
+
+ @Test
+ public void initValueShouldInitZero()
+ {
+ target.init(buf, UNINIT_POSITION);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
+ Assert.assertEquals(0, collector.count);
+ Assert.assertEquals(0, collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregate()
+ {
+ target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
+ Assert.assertEquals(VALUES.length - START_ROW, collector.count);
+ Assert.assertEquals(217.67, collector.sum, EPSILON);
+ Assert.assertEquals(7565.211675, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregateWithNulls()
+ {
+ mockNullsVector();
+ target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
+ Assert.assertEquals(
+ VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2),
+ collector.count
+ );
+ Assert.assertEquals(NullHandling.replaceWithDefault() ? 217.67 : 134, collector.sum, EPSILON);
+ Assert.assertEquals(NullHandling.replaceWithDefault() ? 7565.211675 : 6272, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregateBatchWithoutRows()
+ {
+ int[] positions = new int[]{0, 43, 70};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, null, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ Assert.assertEquals(1, collector.count);
+ Assert.assertEquals(VALUES[i], collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ }
+
+ @Test
+ public void aggregateBatchWithRows()
+ {
+ int[] positions = new int[]{0, 43, 70};
+ int[] rows = new int[]{3, 2, 0};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, rows, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ Assert.assertEquals(1, collector.count);
+ Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ }
+
+ @Test
+ public void aggregateBatchWithRowsAndNulls()
+ {
+ mockNullsVector();
+ int[] positions = new int[]{0, 43, 70};
+ int[] rows = new int[]{3, 2, 0};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, rows, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]];
+ Assert.assertEquals(isNull ? 0 : 1, collector.count);
+ Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ }
+
+ @Test
+ public void getShouldReturnAllZeros()
+ {
+ VarianceAggregatorCollector collector = target.get(buf, POSITION);
+ Assert.assertEquals(0, collector.count);
+ Assert.assertEquals(0, collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ private void clearBufferForPositions(int offset, int... positions)
+ {
+ for (int position : positions) {
+ VarianceBufferAggregator.doInit(buf, offset + position);
+ }
+ }
+
+ private void mockNullsVector()
+ {
+ if (!NullHandling.replaceWithDefault()) {
+ Mockito.doReturn(NULLS).when(selector).getNullVector();
+ }
+ }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregatorTest.java
new file mode 100644
index 0000000..ed2f0a3
--- /dev/null
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregatorTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.segment.vector.VectorValueSelector;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ThreadLocalRandom;
+
+@RunWith(MockitoJUnitRunner.class)
+public class VarianceFloatVectorAggregatorTest extends InitializedNullHandlingTest
+{
+ private static final int START_ROW = 1;
+ private static final int POSITION = 2;
+ private static final int UNINIT_POSITION = 512;
+ private static final double EPSILON = 1e-8;
+ private static final float[] VALUES = new float[]{7.8F, 11, 23.67F, 60, 123};
+ private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
+
+ @Mock
+ private VectorValueSelector selector;
+ private ByteBuffer buf;
+
+ private VarianceFloatVectorAggregator target;
+
+ @Before
+ public void setup()
+ {
+ byte[] randomBytes = new byte[1024];
+ ThreadLocalRandom.current().nextBytes(randomBytes);
+ buf = ByteBuffer.wrap(randomBytes);
+ Mockito.doReturn(VALUES).when(selector).getFloatVector();
+ target = new VarianceFloatVectorAggregator(selector);
+ clearBufferForPositions(0, POSITION);
+ }
+
+ @Test
+ public void initValueShouldInitZero()
+ {
+ target.init(buf, UNINIT_POSITION);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
+ Assert.assertEquals(0, collector.count);
+ Assert.assertEquals(0, collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregate()
+ {
+ target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
+ Assert.assertEquals(VALUES.length - START_ROW, collector.count);
+ Assert.assertEquals(217.67000007, collector.sum, EPSILON);
+ Assert.assertEquals(7565.2116703, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregateWithNulls()
+ {
+ mockNullsVector();
+ target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
+ Assert.assertEquals(
+ VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2),
+ collector.count
+ );
+ Assert.assertEquals(NullHandling.replaceWithDefault() ? 217.67000007 : 134, collector.sum, EPSILON);
+ Assert.assertEquals(NullHandling.replaceWithDefault() ? 7565.2116703 : 6272, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregateBatchWithoutRows()
+ {
+ int[] positions = new int[]{0, 43, 70};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, null, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ Assert.assertEquals(1, collector.count);
+ Assert.assertEquals(VALUES[i], collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ }
+
+ @Test
+ public void aggregateBatchWithRows()
+ {
+ int[] positions = new int[]{0, 43, 70};
+ int[] rows = new int[]{3, 2, 0};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, rows, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ Assert.assertEquals(1, collector.count);
+ Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ }
+
+ @Test
+ public void aggregateBatchWithRowsAndNulls()
+ {
+ mockNullsVector();
+ int[] positions = new int[]{0, 43, 70};
+ int[] rows = new int[]{3, 2, 0};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, rows, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]];
+ Assert.assertEquals(isNull ? 0 : 1, collector.count);
+ Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ }
+
+ @Test
+ public void getShouldReturnAllZeros()
+ {
+ VarianceAggregatorCollector collector = target.get(buf, POSITION);
+ Assert.assertEquals(0, collector.count);
+ Assert.assertEquals(0, collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ private void clearBufferForPositions(int offset, int... positions)
+ {
+ for (int position : positions) {
+ VarianceBufferAggregator.doInit(buf, offset + position);
+ }
+ }
+
+ private void mockNullsVector()
+ {
+ if (!NullHandling.replaceWithDefault()) {
+ Mockito.doReturn(NULLS).when(selector).getNullVector();
+ }
+ }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java
index 7755f32..a91e635 100644
--- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java
@@ -20,6 +20,7 @@
package org.apache.druid.query.aggregation.variance;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import org.apache.druid.data.input.Row;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.PeriodGranularity;
@@ -63,14 +64,12 @@
private final QueryRunner<Row> runner;
private final GroupByQueryRunnerFactory factory;
private final String testName;
+ private final GroupByQuery.Builder queryBuilder;
@Parameterized.Parameters(name = "{0}")
public static Collection<Object[]> constructorFeeder()
{
- // Use GroupByQueryRunnerTest's constructorFeeder, but remove vectorized tests, since this aggregator
- // can't vectorize yet.
return GroupByQueryRunnerTest.constructorFeeder().stream()
- .filter(constructor -> !((boolean) constructor[4]) /* !vectorize */)
.map(
constructor ->
new Object[]{
@@ -94,13 +93,14 @@
this.config = config;
this.factory = factory;
this.runner = factory.mergeRunners(Execs.directExecutor(), ImmutableList.of(runner));
+ this.queryBuilder = GroupByQuery.builder()
+ .setContext(ImmutableMap.of("vectorize", config.isVectorize()));
}
@Test
public void testGroupByVarianceOnly()
{
- GroupByQuery query = GroupByQuery
- .builder()
+ GroupByQuery query = queryBuilder
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
@@ -141,8 +141,7 @@
@Test
public void testGroupBy()
{
- GroupByQuery query = GroupByQuery
- .builder()
+ GroupByQuery query = queryBuilder
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
@@ -191,8 +190,7 @@
new String[]{"alias", "rows", "index", "index_var", "index_stddev"}
);
- GroupByQuery query = GroupByQuery
- .builder()
+ GroupByQuery query = queryBuilder
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setInterval("2011-04-02/2011-04-04")
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
@@ -244,8 +242,7 @@
public void testGroupByZtestPostAgg()
{
// test postaggs from 'teststats' package in here since we've already gone to the trouble of setting up the test
- GroupByQuery query = GroupByQuery
- .builder()
+ GroupByQuery query = queryBuilder
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
@@ -286,8 +283,7 @@
public void testGroupByTestPvalueZscorePostAgg()
{
// test postaggs from 'teststats' package in here since we've already gone to the trouble of setting up the test
- GroupByQuery query = GroupByQuery
- .builder()
+ GroupByQuery query = queryBuilder
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
@@ -308,7 +304,14 @@
.build();
VarianceTestHelper.RowBuilder builder =
- new VarianceTestHelper.RowBuilder(new String[]{"alias", "rows", "idx", "index_stddev", "index_var", "pvalueZscore"});
+ new VarianceTestHelper.RowBuilder(new String[]{
+ "alias",
+ "rows",
+ "idx",
+ "index_stddev",
+ "index_var",
+ "pvalueZscore"
+ });
List<ResultRow> expectedResults = builder
.add("2011-04-01", "automotive", 1L, 135.0, 0.0, 0.0, 1.0)
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregatorTest.java
new file mode 100644
index 0000000..d47bf5c
--- /dev/null
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregatorTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.segment.vector.VectorValueSelector;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ThreadLocalRandom;
+
+@RunWith(MockitoJUnitRunner.class)
+public class VarianceLongVectorAggregatorTest extends InitializedNullHandlingTest
+{
+ private static final int START_ROW = 1;
+ private static final int POSITION = 2;
+ private static final int UNINIT_POSITION = 512;
+ private static final double EPSILON = 1e-10;
+ private static final long[] VALUES = new long[]{7, 11, 23, 60, 123};
+ private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
+
+ @Mock
+ private VectorValueSelector selector;
+ private ByteBuffer buf;
+
+ private VarianceLongVectorAggregator target;
+
+ @Before
+ public void setup()
+ {
+ byte[] randomBytes = new byte[1024];
+ ThreadLocalRandom.current().nextBytes(randomBytes);
+ buf = ByteBuffer.wrap(randomBytes);
+ Mockito.doReturn(VALUES).when(selector).getLongVector();
+ target = new VarianceLongVectorAggregator(selector);
+ clearBufferForPositions(0, POSITION);
+ }
+
+ @Test
+ public void initValueShouldInitZero()
+ {
+ target.init(buf, UNINIT_POSITION);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
+ Assert.assertEquals(0, collector.count);
+ Assert.assertEquals(0, collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregate()
+ {
+ target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
+ Assert.assertEquals(VALUES.length - START_ROW, collector.count);
+ Assert.assertEquals(217, collector.sum, EPSILON);
+ Assert.assertEquals(7606.75, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregateWithNulls()
+ {
+ mockNullsVector();
+ target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
+ Assert.assertEquals(
+ VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2),
+ collector.count
+ );
+ Assert.assertEquals(NullHandling.replaceWithDefault() ? 217 : 134, collector.sum, EPSILON);
+ Assert.assertEquals(NullHandling.replaceWithDefault() ? 7606.75 : 6272, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregateBatchWithoutRows()
+ {
+ int[] positions = new int[]{0, 43, 70};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, null, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ Assert.assertEquals(1, collector.count);
+ Assert.assertEquals(VALUES[i], collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ }
+
+ @Test
+ public void aggregateBatchWithRows()
+ {
+ int[] positions = new int[]{0, 43, 70};
+ int[] rows = new int[]{3, 2, 0};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, rows, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ Assert.assertEquals(1, collector.count);
+ Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ }
+
+ @Test
+ public void aggregateBatchWithRowsAndNulls()
+ {
+ mockNullsVector();
+ int[] positions = new int[]{0, 43, 70};
+ int[] rows = new int[]{3, 2, 0};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, rows, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]];
+ Assert.assertEquals(isNull ? 0 : 1, collector.count);
+ Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ }
+
+ @Test
+ public void getShouldReturnAllZeros()
+ {
+ VarianceAggregatorCollector collector = target.get(buf, POSITION);
+ Assert.assertEquals(0, collector.count);
+ Assert.assertEquals(0, collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+ private void clearBufferForPositions(int offset, int... positions)
+ {
+ for (int position : positions) {
+ VarianceBufferAggregator.doInit(buf, offset + position);
+ }
+ }
+
+ private void mockNullsVector()
+ {
+ if (!NullHandling.replaceWithDefault()) {
+ Mockito.doReturn(NULLS).when(selector).getNullVector();
+ }
+ }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregatorTest.java
new file mode 100644
index 0000000..0e6694a
--- /dev/null
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregatorTest.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.segment.vector.VectorObjectSelector;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ThreadLocalRandom;
+
+@RunWith(MockitoJUnitRunner.class)
+public class VarianceObjectVectorAggregatorTest extends InitializedNullHandlingTest
+{
+ private static final int START_ROW = 1;
+ private static final int POSITION = 2;
+ private static final int UNINIT_POSITION = 512;
+ private static final double EPSILON = 1e-10;
+ private static final VarianceAggregatorCollector[] VALUES = new VarianceAggregatorCollector[]{
+ new VarianceAggregatorCollector(1, 7.8, 0),
+ new VarianceAggregatorCollector(1, 11, 0),
+ new VarianceAggregatorCollector(1, 23.67, 0),
+ null,
+ new VarianceAggregatorCollector(2, 183, 1984.5)
+ };
+ private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
+
+ @Mock
+ private VectorObjectSelector selector;
+ private ByteBuffer buf;
+
+ private VarianceObjectVectorAggregator target;
+
+ @Before
+ public void setup()
+ {
+ byte[] randomBytes = new byte[1024];
+ ThreadLocalRandom.current().nextBytes(randomBytes);
+ buf = ByteBuffer.wrap(randomBytes);
+ Mockito.doReturn(VALUES).when(selector).getObjectVector();
+ target = new VarianceObjectVectorAggregator(selector);
+ clearBufferForPositions(0, POSITION);
+ }
+
+ @Test
+ public void initValueShouldInitZero()
+ {
+ target.init(buf, UNINIT_POSITION);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
+ Assert.assertEquals(0, collector.count);
+ Assert.assertEquals(0, collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregate()
+ {
+ target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
+ Assert.assertEquals(4, collector.count);
+ Assert.assertEquals(217.67, collector.sum, EPSILON);
+ Assert.assertEquals(7565.211675, collector.nvariance, EPSILON);
+ }
+
+ @Test
+ public void aggregateBatchWithoutRows()
+ {
+ int[] positions = new int[]{0, 43, 70};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, null, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ Assert.assertEquals(VALUES[i], collector);
+ }
+ }
+
+ @Test
+ public void aggregateBatchWithRows()
+ {
+ int[] positions = new int[]{0, 43, 70};
+ int[] rows = new int[]{3, 2, 0};
+ int positionOffset = 2;
+ clearBufferForPositions(positionOffset, positions);
+ target.aggregate(buf, 3, positions, rows, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
+ buf,
+ positions[i] + positionOffset
+ );
+ VarianceAggregatorCollector expectedCollector = VALUES[rows[i]];
+ Assert.assertEquals(expectedCollector == null ? new VarianceAggregatorCollector() : expectedCollector, collector);
+ }
+ }
+
+ @Test
+ public void getShouldReturnAllZeros()
+ {
+ VarianceAggregatorCollector collector = target.get(buf, POSITION);
+ Assert.assertEquals(0, collector.count);
+ Assert.assertEquals(0, collector.sum, EPSILON);
+ Assert.assertEquals(0, collector.nvariance, EPSILON);
+ }
+
+ private void clearBufferForPositions(int offset, int... positions)
+ {
+ for (int position : positions) {
+ VarianceBufferAggregator.doInit(buf, offset + position);
+ }
+ }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceTimeseriesQueryTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceTimeseriesQueryTest.java
index 9c52961..fd28b38 100644
--- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceTimeseriesQueryTest.java
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceTimeseriesQueryTest.java
@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.variance;
+import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryPlus;
@@ -46,31 +47,32 @@
@Parameterized.Parameters(name = "{0}:descending={1}")
public static Iterable<Object[]> constructorFeeder()
{
- // Use TimeseriesQueryRunnerTest's constructorFeeder, but remove vectorized tests, since this aggregator
- // can't vectorize yet.
return StreamSupport.stream(TimeseriesQueryRunnerTest.constructorFeeder().spliterator(), false)
- .filter(constructor -> !((boolean) constructor[2]) /* !vectorize */)
- .map(constructor -> new Object[]{constructor[0], constructor[1], constructor[3]})
+ .map(constructor -> new Object[]{constructor[0], constructor[1], constructor[2], constructor[3]})
.collect(Collectors.toList());
}
private final QueryRunner runner;
private final boolean descending;
+ private final Druids.TimeseriesQueryBuilder queryBuilder;
public VarianceTimeseriesQueryTest(
QueryRunner runner,
boolean descending,
+ boolean vectorize,
List<AggregatorFactory> aggregatorFactories
)
{
this.runner = runner;
this.descending = descending;
+ this.queryBuilder = Druids.newTimeseriesQueryBuilder()
+ .context(ImmutableMap.of("vectorize", vectorize ? "force" : "false"));
}
@Test
public void testTimeseriesWithNullFilterOnNonExistentDimension()
{
- TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
+ TimeseriesQuery query = queryBuilder
.dataSource(QueryRunnerTestHelper.DATA_SOURCE)
.granularity(QueryRunnerTestHelper.DAY_GRAN)
.filters("bobby", null)
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
index cfb945b..344bbb9 100644
--- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
@@ -285,16 +285,16 @@
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
- .dataSource(CalciteTests.DATASOURCE3)
- .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
- .granularity(Granularities.ALL)
- .aggregators(
- ImmutableList.of(
- new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
- new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
- new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
- )
- )
+ .dataSource(CalciteTests.DATASOURCE3)
+ .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+ .granularity(Granularities.ALL)
+ .aggregators(
+ ImmutableList.of(
+ new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
+ new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
+ new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
+ )
+ )
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
@@ -335,22 +335,22 @@
holder1.getVariance(false),
holder2.getVariance(false).floatValue(),
holder3.getVariance(false).longValue(),
- }
+ }
);
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
- .dataSource(CalciteTests.DATASOURCE3)
- .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
- .granularity(Granularities.ALL)
- .aggregators(
- ImmutableList.of(
- new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
- new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
- new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
- )
- )
+ .dataSource(CalciteTests.DATASOURCE3)
+ .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+ .granularity(Granularities.ALL)
+ .aggregators(
+ ImmutableList.of(
+ new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
+ new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
+ new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
+ )
+ )
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
@@ -391,28 +391,29 @@
Math.sqrt(holder1.getVariance(true)),
(float) Math.sqrt(holder2.getVariance(true)),
(long) Math.sqrt(holder3.getVariance(true)),
- }
+ }
);
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
- .dataSource(CalciteTests.DATASOURCE3)
- .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
- .granularity(Granularities.ALL)
- .aggregators(
- ImmutableList.of(
- new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
- new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
- new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
- )
- )
- .postAggregators(
- ImmutableList.of(
- new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
- new StandardDeviationPostAggregator("a1", "a1:agg", "population"),
- new StandardDeviationPostAggregator("a2", "a2:agg", "population"))
- )
+ .dataSource(CalciteTests.DATASOURCE3)
+ .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+ .granularity(Granularities.ALL)
+ .aggregators(
+ ImmutableList.of(
+ new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
+ new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
+ new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
+ )
+ )
+ .postAggregators(
+ ImmutableList.of(
+ new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
+ new StandardDeviationPostAggregator("a1", "a1:agg", "population"),
+ new StandardDeviationPostAggregator("a2", "a2:agg", "population")
+ )
+ )
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
@@ -453,7 +454,7 @@
Math.sqrt(holder1.getVariance(false)),
(float) Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
- }
+ }
);
assertResultsEquals(expectedResults, results);
@@ -464,9 +465,9 @@
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
- new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
- new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
- new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
+ new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
+ new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
+ new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
)
)
.postAggregators(
@@ -514,7 +515,7 @@
Math.sqrt(holder1.getVariance(false)),
(float) Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
- }
+ }
);
assertResultsEquals(expectedResults, results);
@@ -530,9 +531,9 @@
)
.aggregators(
ImmutableList.of(
- new VarianceAggregatorFactory("a0:agg", "v0", "sample", "double"),
- new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
- new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
+ new VarianceAggregatorFactory("a0:agg", "v0", "sample", "double"),
+ new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
+ new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
)
)
.postAggregators(
@@ -560,41 +561,41 @@
authenticationResult
).toList();
List<Object[]> expectedResults = NullHandling.sqlCompatible()
- ? ImmutableList.of(
- new Object[] {"a", 0f},
- new Object[] {null, 0f},
- new Object[] {"", 0f},
- new Object[] {"abc", null}
+ ? ImmutableList.of(
+ new Object[]{"a", 0f},
+ new Object[]{null, 0f},
+ new Object[]{"", 0f},
+ new Object[]{"abc", null}
) : ImmutableList.of(
- new Object[] {"a", 0.5f},
- new Object[] {"", 0.0033333334f},
- new Object[] {"abc", 0f}
+ new Object[]{"a", 0.5f},
+ new Object[]{"", 0.0033333334f},
+ new Object[]{"abc", 0f}
);
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
GroupByQuery.builder()
- .setDataSource(CalciteTests.DATASOURCE3)
- .setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
- .setGranularity(Granularities.ALL)
- .setDimensions(new DefaultDimensionSpec("dim2", "_d0"))
- .setAggregatorSpecs(
- new VarianceAggregatorFactory("a0:agg", "f1", "sample", "float")
- )
- .setLimitSpec(
- DefaultLimitSpec
- .builder()
- .orderBy(
- new OrderByColumnSpec(
- "a0:agg",
- OrderByColumnSpec.Direction.DESCENDING,
- StringComparators.NUMERIC
- )
- )
- .build()
- )
- .setContext(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
- .build(),
+ .setDataSource(CalciteTests.DATASOURCE3)
+ .setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(new DefaultDimensionSpec("dim2", "_d0"))
+ .setAggregatorSpecs(
+ new VarianceAggregatorFactory("a0:agg", "f1", "sample", "float")
+ )
+ .setLimitSpec(
+ DefaultLimitSpec
+ .builder()
+ .orderBy(
+ new OrderByColumnSpec(
+ "a0:agg",
+ OrderByColumnSpec.Direction.DESCENDING,
+ StringComparators.NUMERIC
+ )
+ )
+ .build()
+ )
+ .setContext(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
+ .build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@@ -622,7 +623,7 @@
Arrays.asList(
QueryRunnerTestHelper.ROWS_COUNT,
QueryRunnerTestHelper.INDEX_DOUBLE_SUM,
- new VarianceAggregatorFactory("variance", "index")
+ new VarianceAggregatorFactory("variance", "index", null, null)
)
)
.descending(true)
@@ -648,9 +649,18 @@
{
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
- Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
+ Object[] expectedResult = expectedResults.get(i);
+ Object[] result = results.get(i);
+ Assert.assertEquals(expectedResult.length, result.length);
+ for (int j = 0; j < expectedResult.length; j++) {
+ if (expectedResult[j] instanceof Float) {
+ Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10);
+ } else if (expectedResult[j] instanceof Double) {
+ Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10);
+ } else {
+ Assert.assertEquals(expectedResult[j], result[j]);
+ }
+ }
}
}
-
-
}