blob: 16c186826d51c7d2d3f1ca25967c99ea8eec33a2 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.core.operator.transform.function;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet;
import it.unimi.dsi.fastutil.doubles.DoubleSet;
import it.unimi.dsi.fastutil.floats.FloatArrayList;
import it.unimi.dsi.fastutil.floats.FloatList;
import it.unimi.dsi.fastutil.floats.FloatOpenHashSet;
import it.unimi.dsi.fastutil.floats.FloatSet;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.pinot.core.operator.blocks.ProjectionBlock;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.segment.spi.datasource.DataSource;
import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.apache.pinot.spi.data.FieldSpec.DataType;
/**
 * The {@code valueIn} transform function: takes a multi-valued expression as its first argument and one or more
 * literals as the remaining arguments, and for every document keeps only the values of the expression that are
 * contained in the literal list. The relative order of the surviving values is preserved.
 *
 * <p>The result metadata and dictionary are taken from the first argument. When the first argument is
 * dictionary-encoded, filtering is done on dictionary ids; otherwise it is done on raw values of the stored type.
 *
 * <p>The per-type value sets are built lazily on first use, and the per-type result buffers are reused across calls.
 * A returned buffer may be longer than the number of documents in the block; only the first
 * {@code projectionBlock.getNumDocs()} entries are written by a given call.
 */
public class ValueInTransformFunction extends BaseTransformFunction {
  public static final String FUNCTION_NAME = "valueIn";

  // First argument: the multi-valued expression being filtered.
  private TransformFunction _mainTransformFunction;
  private TransformResultMetadata _resultMetadata;
  private Dictionary _dictionary;

  // Lazily built lookup sets and reusable per-block result buffers, one pair per supported representation.
  private IntSet _dictIdSet;
  private int[][] _dictIds;
  private IntSet _intValueSet;
  private int[][] _intValues;
  private LongSet _longValueSet;
  private long[][] _longValues;
  private FloatSet _floatValueSet;
  private float[][] _floatValues;
  private DoubleSet _doubleValueSet;
  private double[][] _doubleValues;
  // Literal arguments in their string form; populated in init() and the source for all typed sets above.
  private Set<String> _stringValueSet;
  private String[][] _stringValues;

  @Override
  public String getName() {
    return FUNCTION_NAME;
  }

  /**
   * Validates the arguments and records the main expression and the literal values to keep.
   *
   * @param arguments the main multi-valued expression followed by at least one literal
   * @param dataSourceMap not used by this function
   * @throws IllegalArgumentException if fewer than 2 arguments are given, or if the first argument is a literal or
   *         is single-valued
   * @throws ClassCastException if any argument after the first is not a {@code LiteralTransformFunction}
   */
  @Override
  public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
    // Check that there are more than 1 arguments
    int numArguments = arguments.size();
    if (numArguments < 2) {
      throw new IllegalArgumentException("At least 2 arguments are required for VALUE_IN transform function");
    }

    // Check that the first argument is a multi-valued column or transform function
    TransformFunction firstArgument = arguments.get(0);
    if (firstArgument instanceof LiteralTransformFunction || firstArgument.getResultMetadata().isSingleValue()) {
      throw new IllegalArgumentException(
          "The first argument of VALUE_IN transform function must be a multi-valued column or a transform function");
    }
    _mainTransformFunction = firstArgument;
    _resultMetadata = _mainTransformFunction.getResultMetadata();
    _dictionary = _mainTransformFunction.getDictionary();

    // Collect all values for the VALUE_IN transform function
    _stringValueSet = new HashSet<>(numArguments - 1);
    for (int i = 1; i < numArguments; i++) {
      _stringValueSet.add(((LiteralTransformFunction) arguments.get(i)).getLiteral());
    }
  }

  @Override
  public TransformResultMetadata getResultMetadata() {
    return _resultMetadata;
  }

  @Override
  public Dictionary getDictionary() {
    return _dictionary;
  }

  /**
   * Returns, for each document in the block, the dictionary ids of the main expression that correspond to one of the
   * literal values.
   *
   * @param projectionBlock the block to evaluate
   * @return the reusable buffer holding the filtered dictionary ids; valid for the first {@code getNumDocs()} entries
   */
  @Override
  public int[][] transformToDictIdsMV(ProjectionBlock projectionBlock) {
    int length = projectionBlock.getNumDocs();

    // Resolve the literals to dictionary ids once.
    if (_dictIdSet == null) {
      _dictIdSet = new IntOpenHashSet();
      assert _dictionary != null;
      for (String inValue : _stringValueSet) {
        int dictId = _dictionary.indexOf(inValue);
        // Literals for which the lookup returns a negative index are skipped.
        if (dictId >= 0) {
          _dictIdSet.add(dictId);
        }
      }
    }
    // Size the buffer on every call (not only on first use) so that a block larger than the first one fits.
    if (_dictIds == null || _dictIds.length < length) {
      _dictIds = new int[length][];
    }

    int[][] unFilteredDictIds = _mainTransformFunction.transformToDictIdsMV(projectionBlock);
    for (int i = 0; i < length; i++) {
      _dictIds[i] = filterInts(_dictIdSet, unFilteredDictIds[i]);
    }
    return _dictIds;
  }

  /**
   * Returns the filtered INT values per document. Delegates to the base implementation when the main expression is
   * dictionary-encoded or its stored type is not INT.
   *
   * @param projectionBlock the block to evaluate
   * @return the filtered values; when computed here, the reusable buffer valid for the first {@code getNumDocs()}
   *         entries
   * @throws NumberFormatException if a literal cannot be parsed as an int
   */
  @Override
  public int[][] transformToIntValuesMV(ProjectionBlock projectionBlock) {
    if (_dictionary != null || _resultMetadata.getDataType().getStoredType() != DataType.INT) {
      return super.transformToIntValuesMV(projectionBlock);
    }

    int length = projectionBlock.getNumDocs();

    // Parse the literals once.
    if (_intValueSet == null) {
      _intValueSet = new IntOpenHashSet();
      for (String inValue : _stringValueSet) {
        _intValueSet.add(Integer.parseInt(inValue));
      }
    }
    // Size the buffer on every call (not only on first use) so that a block larger than the first one fits.
    if (_intValues == null || _intValues.length < length) {
      _intValues = new int[length][];
    }

    int[][] unFilteredIntValues = _mainTransformFunction.transformToIntValuesMV(projectionBlock);
    for (int i = 0; i < length; i++) {
      _intValues[i] = filterInts(_intValueSet, unFilteredIntValues[i]);
    }
    return _intValues;
  }

  /**
   * Returns the filtered LONG values per document. Delegates to the base implementation when the main expression is
   * dictionary-encoded or its stored type is not LONG.
   *
   * @param projectionBlock the block to evaluate
   * @return the filtered values; when computed here, the reusable buffer valid for the first {@code getNumDocs()}
   *         entries
   * @throws NumberFormatException if a literal cannot be parsed as a long
   */
  @Override
  public long[][] transformToLongValuesMV(ProjectionBlock projectionBlock) {
    if (_dictionary != null || _resultMetadata.getDataType().getStoredType() != DataType.LONG) {
      return super.transformToLongValuesMV(projectionBlock);
    }

    int length = projectionBlock.getNumDocs();

    // Parse the literals once.
    if (_longValueSet == null) {
      _longValueSet = new LongOpenHashSet();
      for (String inValue : _stringValueSet) {
        _longValueSet.add(Long.parseLong(inValue));
      }
    }
    // Size the buffer on every call (not only on first use) so that a block larger than the first one fits.
    if (_longValues == null || _longValues.length < length) {
      _longValues = new long[length][];
    }

    long[][] unFilteredLongValues = _mainTransformFunction.transformToLongValuesMV(projectionBlock);
    for (int i = 0; i < length; i++) {
      _longValues[i] = filterLongs(_longValueSet, unFilteredLongValues[i]);
    }
    return _longValues;
  }

  /**
   * Returns the filtered FLOAT values per document. Delegates to the base implementation when the main expression is
   * dictionary-encoded or its stored type is not FLOAT.
   *
   * @param projectionBlock the block to evaluate
   * @return the filtered values; when computed here, the reusable buffer valid for the first {@code getNumDocs()}
   *         entries
   * @throws NumberFormatException if a literal cannot be parsed as a float
   */
  @Override
  public float[][] transformToFloatValuesMV(ProjectionBlock projectionBlock) {
    if (_dictionary != null || _resultMetadata.getDataType().getStoredType() != DataType.FLOAT) {
      return super.transformToFloatValuesMV(projectionBlock);
    }

    int length = projectionBlock.getNumDocs();

    // Parse the literals once.
    if (_floatValueSet == null) {
      _floatValueSet = new FloatOpenHashSet();
      for (String inValue : _stringValueSet) {
        _floatValueSet.add(Float.parseFloat(inValue));
      }
    }
    // Size the buffer on every call (not only on first use) so that a block larger than the first one fits.
    if (_floatValues == null || _floatValues.length < length) {
      _floatValues = new float[length][];
    }

    float[][] unFilteredFloatValues = _mainTransformFunction.transformToFloatValuesMV(projectionBlock);
    for (int i = 0; i < length; i++) {
      _floatValues[i] = filterFloats(_floatValueSet, unFilteredFloatValues[i]);
    }
    return _floatValues;
  }

  /**
   * Returns the filtered DOUBLE values per document. Delegates to the base implementation when the main expression
   * is dictionary-encoded or its stored type is not DOUBLE.
   *
   * @param projectionBlock the block to evaluate
   * @return the filtered values; when computed here, the reusable buffer valid for the first {@code getNumDocs()}
   *         entries
   * @throws NumberFormatException if a literal cannot be parsed as a double
   */
  @Override
  public double[][] transformToDoubleValuesMV(ProjectionBlock projectionBlock) {
    if (_dictionary != null || _resultMetadata.getDataType().getStoredType() != DataType.DOUBLE) {
      return super.transformToDoubleValuesMV(projectionBlock);
    }

    int length = projectionBlock.getNumDocs();

    // Parse the literals once.
    if (_doubleValueSet == null) {
      _doubleValueSet = new DoubleOpenHashSet();
      for (String inValue : _stringValueSet) {
        _doubleValueSet.add(Double.parseDouble(inValue));
      }
    }
    // Size the buffer on every call (not only on first use) so that a block larger than the first one fits.
    if (_doubleValues == null || _doubleValues.length < length) {
      _doubleValues = new double[length][];
    }

    double[][] unFilteredDoubleValues = _mainTransformFunction.transformToDoubleValuesMV(projectionBlock);
    for (int i = 0; i < length; i++) {
      _doubleValues[i] = filterDoubles(_doubleValueSet, unFilteredDoubleValues[i]);
    }
    return _doubleValues;
  }

  /**
   * Returns the filtered STRING values per document. Delegates to the base implementation when the main expression
   * is dictionary-encoded or its stored type is not STRING.
   *
   * @param projectionBlock the block to evaluate
   * @return the filtered values; when computed here, the reusable buffer valid for the first {@code getNumDocs()}
   *         entries
   */
  @Override
  public String[][] transformToStringValuesMV(ProjectionBlock projectionBlock) {
    if (_dictionary != null || _resultMetadata.getDataType().getStoredType() != DataType.STRING) {
      return super.transformToStringValuesMV(projectionBlock);
    }

    int length = projectionBlock.getNumDocs();

    // The literals are already strings, so only the buffer needs preparing.
    if (_stringValues == null || _stringValues.length < length) {
      _stringValues = new String[length][];
    }

    String[][] unFilteredStringValues = _mainTransformFunction.transformToStringValuesMV(projectionBlock);
    for (int i = 0; i < length; i++) {
      _stringValues[i] = filterStrings(_stringValueSet, unFilteredStringValues[i]);
    }
    return _stringValues;
  }

  /**
   * Keeps the elements of {@code source} that are contained in {@code intSet}, in their original order.
   *
   * @return {@code source} itself when every element is kept, otherwise a new array with the kept elements
   */
  private static int[] filterInts(IntSet intSet, int[] source) {
    IntList intList = new IntArrayList();
    for (int value : source) {
      if (intSet.contains(value)) {
        intList.add(value);
      }
    }
    if (intList.size() == source.length) {
      // Nothing was dropped: hand back the input array instead of copying it.
      return source;
    } else {
      return intList.toIntArray();
    }
  }

  /**
   * Keeps the elements of {@code source} that are contained in {@code longSet}, in their original order.
   *
   * @return {@code source} itself when every element is kept, otherwise a new array with the kept elements
   */
  private static long[] filterLongs(LongSet longSet, long[] source) {
    LongList longList = new LongArrayList();
    for (long value : source) {
      if (longSet.contains(value)) {
        longList.add(value);
      }
    }
    if (longList.size() == source.length) {
      // Nothing was dropped: hand back the input array instead of copying it.
      return source;
    } else {
      return longList.toLongArray();
    }
  }

  /**
   * Keeps the elements of {@code source} that are contained in {@code floatSet}, in their original order.
   *
   * @return {@code source} itself when every element is kept, otherwise a new array with the kept elements
   */
  private static float[] filterFloats(FloatSet floatSet, float[] source) {
    FloatList floatList = new FloatArrayList();
    for (float value : source) {
      if (floatSet.contains(value)) {
        floatList.add(value);
      }
    }
    if (floatList.size() == source.length) {
      // Nothing was dropped: hand back the input array instead of copying it.
      return source;
    } else {
      return floatList.toFloatArray();
    }
  }

  /**
   * Keeps the elements of {@code source} that are contained in {@code doubleSet}, in their original order.
   *
   * @return {@code source} itself when every element is kept, otherwise a new array with the kept elements
   */
  private static double[] filterDoubles(DoubleSet doubleSet, double[] source) {
    DoubleList doubleList = new DoubleArrayList();
    for (double value : source) {
      if (doubleSet.contains(value)) {
        doubleList.add(value);
      }
    }
    if (doubleList.size() == source.length) {
      // Nothing was dropped: hand back the input array instead of copying it.
      return source;
    } else {
      return doubleList.toDoubleArray();
    }
  }

  /**
   * Keeps the elements of {@code source} that are contained in {@code stringSet}, in their original order.
   *
   * @return {@code source} itself when every element is kept, otherwise a new array with the kept elements
   */
  private static String[] filterStrings(Set<String> stringSet, String[] source) {
    List<String> stringList = new ArrayList<>();
    for (String value : source) {
      if (stringSet.contains(value)) {
        stringList.add(value);
      }
    }
    if (stringList.size() == source.length) {
      // Nothing was dropped: hand back the input array instead of copying it.
      return source;
    } else {
      return stringList.toArray(new String[0]);
    }
  }
}