blob: d928585279798ec86f526cd318ee9889f4516e74 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.core.query.aggregation.function;
import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.apache.pinot.spi.data.FieldSpec.DataType;
import org.apache.pinot.spi.utils.CommonConstants;
import org.roaringbitmap.PeekableIntIterator;
import org.roaringbitmap.RoaringBitmap;
public class DistinctCountHLLAggregationFunction extends BaseSingleInputAggregationFunction<HyperLogLog, Long> {
protected final int _log2m;
public DistinctCountHLLAggregationFunction(List<ExpressionContext> arguments) {
super(arguments.get(0));
int numExpressions = arguments.size();
// This function expects 1 or 2 arguments.
Preconditions.checkArgument(numExpressions <= 2, "DistinctCountHLL expects 1 or 2 arguments, got: %s",
numExpressions);
if (arguments.size() == 2) {
_log2m = Integer.parseInt(arguments.get(1).getLiteral());
} else {
_log2m = CommonConstants.Helix.DEFAULT_HYPERLOGLOG_LOG2M;
}
}
public int getLog2m() {
return _log2m;
}
@Override
public AggregationFunctionType getType() {
return AggregationFunctionType.DISTINCTCOUNTHLL;
}
@Override
public AggregationResultHolder createAggregationResultHolder() {
return new ObjectAggregationResultHolder();
}
@Override
public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
}
@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
// Treat BYTES value as serialized HyperLogLog
DataType storedType = blockValSet.getValueType().getStoredType();
if (storedType == DataType.BYTES) {
byte[][] bytesValues = blockValSet.getBytesValuesSV();
try {
HyperLogLog hyperLogLog = aggregationResultHolder.getResult();
if (hyperLogLog != null) {
for (int i = 0; i < length; i++) {
hyperLogLog.addAll(ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[i]));
}
} else {
hyperLogLog = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[0]);
aggregationResultHolder.setValue(hyperLogLog);
for (int i = 1; i < length; i++) {
hyperLogLog.addAll(ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[i]));
}
}
} catch (Exception e) {
throw new RuntimeException("Caught exception while merging HyperLogLogs", e);
}
return;
}
// For dictionary-encoded expression, store dictionary ids into the bitmap
Dictionary dictionary = blockValSet.getDictionary();
if (dictionary != null) {
int[] dictIds = blockValSet.getDictionaryIdsSV();
getDictIdBitmap(aggregationResultHolder, dictionary).addN(dictIds, 0, length);
return;
}
// For non-dictionary-encoded expression, store values into the HyperLogLog
HyperLogLog hyperLogLog = getHyperLogLog(aggregationResultHolder);
switch (storedType) {
case INT:
int[] intValues = blockValSet.getIntValuesSV();
for (int i = 0; i < length; i++) {
hyperLogLog.offer(intValues[i]);
}
break;
case LONG:
long[] longValues = blockValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
hyperLogLog.offer(longValues[i]);
}
break;
case FLOAT:
float[] floatValues = blockValSet.getFloatValuesSV();
for (int i = 0; i < length; i++) {
hyperLogLog.offer(floatValues[i]);
}
break;
case DOUBLE:
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
hyperLogLog.offer(doubleValues[i]);
}
break;
case STRING:
String[] stringValues = blockValSet.getStringValuesSV();
for (int i = 0; i < length; i++) {
hyperLogLog.offer(stringValues[i]);
}
break;
default:
throw new IllegalStateException("Illegal data type for DISTINCT_COUNT_HLL aggregation function: " + storedType);
}
}
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
// Treat BYTES value as serialized HyperLogLog
DataType storedType = blockValSet.getValueType().getStoredType();
if (storedType == DataType.BYTES) {
byte[][] bytesValues = blockValSet.getBytesValuesSV();
try {
for (int i = 0; i < length; i++) {
HyperLogLog value = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[i]);
int groupKey = groupKeyArray[i];
HyperLogLog hyperLogLog = groupByResultHolder.getResult(groupKey);
if (hyperLogLog != null) {
hyperLogLog.addAll(value);
} else {
groupByResultHolder.setValueForKey(groupKey, value);
}
}
} catch (Exception e) {
throw new RuntimeException("Caught exception while merging HyperLogLogs", e);
}
return;
}
// For dictionary-encoded expression, store dictionary ids into the bitmap
Dictionary dictionary = blockValSet.getDictionary();
if (dictionary != null) {
int[] dictIds = blockValSet.getDictionaryIdsSV();
for (int i = 0; i < length; i++) {
getDictIdBitmap(groupByResultHolder, groupKeyArray[i], dictionary).add(dictIds[i]);
}
return;
}
// For non-dictionary-encoded expression, store values into the HyperLogLog
switch (storedType) {
case INT:
int[] intValues = blockValSet.getIntValuesSV();
for (int i = 0; i < length; i++) {
getHyperLogLog(groupByResultHolder, groupKeyArray[i]).offer(intValues[i]);
}
break;
case LONG:
long[] longValues = blockValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
getHyperLogLog(groupByResultHolder, groupKeyArray[i]).offer(longValues[i]);
}
break;
case FLOAT:
float[] floatValues = blockValSet.getFloatValuesSV();
for (int i = 0; i < length; i++) {
getHyperLogLog(groupByResultHolder, groupKeyArray[i]).offer(floatValues[i]);
}
break;
case DOUBLE:
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
getHyperLogLog(groupByResultHolder, groupKeyArray[i]).offer(doubleValues[i]);
}
break;
case STRING:
String[] stringValues = blockValSet.getStringValuesSV();
for (int i = 0; i < length; i++) {
getHyperLogLog(groupByResultHolder, groupKeyArray[i]).offer(stringValues[i]);
}
break;
default:
throw new IllegalStateException("Illegal data type for DISTINCT_COUNT_HLL aggregation function: " + storedType);
}
}
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
// Treat BYTES value as serialized HyperLogLog
DataType storedType = blockValSet.getValueType().getStoredType();
if (storedType == DataType.BYTES) {
byte[][] bytesValues = blockValSet.getBytesValuesSV();
try {
for (int i = 0; i < length; i++) {
HyperLogLog value = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[i]);
for (int groupKey : groupKeysArray[i]) {
HyperLogLog hyperLogLog = groupByResultHolder.getResult(groupKey);
if (hyperLogLog != null) {
hyperLogLog.addAll(value);
} else {
// Create a new HyperLogLog for the group
groupByResultHolder.setValueForKey(groupKey,
ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[i]));
}
}
}
} catch (Exception e) {
throw new RuntimeException("Caught exception while merging HyperLogLogs", e);
}
return;
}
// For dictionary-encoded expression, store dictionary ids into the bitmap
Dictionary dictionary = blockValSet.getDictionary();
if (dictionary != null) {
int[] dictIds = blockValSet.getDictionaryIdsSV();
for (int i = 0; i < length; i++) {
setDictIdForGroupKeys(groupByResultHolder, groupKeysArray[i], dictionary, dictIds[i]);
}
return;
}
// For non-dictionary-encoded expression, store values into the HyperLogLog
switch (storedType) {
case INT:
int[] intValues = blockValSet.getIntValuesSV();
for (int i = 0; i < length; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], intValues[i]);
}
break;
case LONG:
long[] longValues = blockValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], longValues[i]);
}
break;
case FLOAT:
float[] floatValues = blockValSet.getFloatValuesSV();
for (int i = 0; i < length; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], floatValues[i]);
}
break;
case DOUBLE:
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], doubleValues[i]);
}
break;
case STRING:
String[] stringValues = blockValSet.getStringValuesSV();
for (int i = 0; i < length; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], stringValues[i]);
}
break;
default:
throw new IllegalStateException("Illegal data type for DISTINCT_COUNT_HLL aggregation function: " + storedType);
}
}
@Override
public HyperLogLog extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
Object result = aggregationResultHolder.getResult();
if (result == null) {
return new HyperLogLog(_log2m);
}
if (result instanceof DictIdsWrapper) {
// For dictionary-encoded expression, convert dictionary ids to HyperLogLog
return convertToHyperLogLog((DictIdsWrapper) result);
} else {
// For non-dictionary-encoded expression, directly return the HyperLogLog
return (HyperLogLog) result;
}
}
@Override
public HyperLogLog extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
Object result = groupByResultHolder.getResult(groupKey);
if (result == null) {
return new HyperLogLog(_log2m);
}
if (result instanceof DictIdsWrapper) {
// For dictionary-encoded expression, convert dictionary ids to HyperLogLog
return convertToHyperLogLog((DictIdsWrapper) result);
} else {
// For non-dictionary-encoded expression, directly return the HyperLogLog
return (HyperLogLog) result;
}
}
@Override
public HyperLogLog merge(HyperLogLog intermediateResult1, HyperLogLog intermediateResult2) {
// Can happen when aggregating serialized HyperLogLog with non-default log2m
if (intermediateResult1.sizeof() != intermediateResult2.sizeof()) {
if (intermediateResult1.cardinality() == 0) {
return intermediateResult2;
} else {
Preconditions.checkState(intermediateResult2.cardinality() == 0,
"Cannot merge HyperLogLogs of different sizes");
return intermediateResult1;
}
}
try {
intermediateResult1.addAll(intermediateResult2);
} catch (Exception e) {
throw new RuntimeException("Caught exception while merging HyperLogLogs", e);
}
return intermediateResult1;
}
@Override
public ColumnDataType getIntermediateResultColumnType() {
return ColumnDataType.OBJECT;
}
@Override
public ColumnDataType getFinalResultColumnType() {
return ColumnDataType.LONG;
}
@Override
public Long extractFinalResult(HyperLogLog intermediateResult) {
return intermediateResult.cardinality();
}
/**
* Returns the dictionary id bitmap from the result holder or creates a new one if it does not exist.
*/
protected static RoaringBitmap getDictIdBitmap(AggregationResultHolder aggregationResultHolder,
Dictionary dictionary) {
DictIdsWrapper dictIdsWrapper = aggregationResultHolder.getResult();
if (dictIdsWrapper == null) {
dictIdsWrapper = new DictIdsWrapper(dictionary);
aggregationResultHolder.setValue(dictIdsWrapper);
}
return dictIdsWrapper._dictIdBitmap;
}
/**
* Returns the HyperLogLog from the result holder or creates a new one if it does not exist.
*/
protected HyperLogLog getHyperLogLog(AggregationResultHolder aggregationResultHolder) {
HyperLogLog hyperLogLog = aggregationResultHolder.getResult();
if (hyperLogLog == null) {
hyperLogLog = new HyperLogLog(_log2m);
aggregationResultHolder.setValue(hyperLogLog);
}
return hyperLogLog;
}
/**
* Returns the dictionary id bitmap for the given group key or creates a new one if it does not exist.
*/
protected static RoaringBitmap getDictIdBitmap(GroupByResultHolder groupByResultHolder, int groupKey,
Dictionary dictionary) {
DictIdsWrapper dictIdsWrapper = groupByResultHolder.getResult(groupKey);
if (dictIdsWrapper == null) {
dictIdsWrapper = new DictIdsWrapper(dictionary);
groupByResultHolder.setValueForKey(groupKey, dictIdsWrapper);
}
return dictIdsWrapper._dictIdBitmap;
}
/**
* Returns the HyperLogLog for the given group key or creates a new one if it does not exist.
*/
protected HyperLogLog getHyperLogLog(GroupByResultHolder groupByResultHolder, int groupKey) {
HyperLogLog hyperLogLog = groupByResultHolder.getResult(groupKey);
if (hyperLogLog == null) {
hyperLogLog = new HyperLogLog(_log2m);
groupByResultHolder.setValueForKey(groupKey, hyperLogLog);
}
return hyperLogLog;
}
/**
* Helper method to set dictionary id for the given group keys into the result holder.
*/
private static void setDictIdForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys,
Dictionary dictionary, int dictId) {
for (int groupKey : groupKeys) {
getDictIdBitmap(groupByResultHolder, groupKey, dictionary).add(dictId);
}
}
/**
* Helper method to set value for the given group keys into the result holder.
*/
private void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, Object value) {
for (int groupKey : groupKeys) {
getHyperLogLog(groupByResultHolder, groupKey).offer(value);
}
}
/**
* Helper method to read dictionary and convert dictionary ids to HyperLogLog for dictionary-encoded expression.
*/
private HyperLogLog convertToHyperLogLog(DictIdsWrapper dictIdsWrapper) {
HyperLogLog hyperLogLog = new HyperLogLog(_log2m);
Dictionary dictionary = dictIdsWrapper._dictionary;
RoaringBitmap dictIdBitmap = dictIdsWrapper._dictIdBitmap;
PeekableIntIterator iterator = dictIdBitmap.getIntIterator();
while (iterator.hasNext()) {
hyperLogLog.offer(dictionary.get(iterator.next()));
}
return hyperLogLog;
}
private static final class DictIdsWrapper {
final Dictionary _dictionary;
final RoaringBitmap _dictIdBitmap;
private DictIdsWrapper(Dictionary dictionary) {
_dictionary = dictionary;
_dictIdBitmap = new RoaringBitmap();
}
}
}