blob: 51e7a751f4d39d28092d9774c0e4b8303c3b37f0 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.core.operator.transform.function;
import com.google.common.base.Preconditions;
import java.sql.Timestamp;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionInvoker;
import org.apache.pinot.common.function.FunctionUtils;
import org.apache.pinot.common.utils.PinotDataType;
import org.apache.pinot.core.operator.blocks.ProjectionBlock;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.segment.spi.datasource.DataSource;
import org.apache.pinot.spi.data.FieldSpec.DataType;
/**
* Wrapper transform function on the annotated scalar function.
*/
public class ScalarTransformFunctionWrapper extends BaseTransformFunction {
  private final String _name;
  private final FunctionInvoker _functionInvoker;
  private final PinotDataType _resultType;
  private final TransformResultMetadata _resultMetadata;

  // Argument array handed to the invoker. Literal slots are filled once in init(); non-literal slots are
  // overwritten for every document before each invocation.
  private Object[] _arguments;
  // Number of arguments that are not literals, and for each of them: its position in _arguments, the transform
  // function producing it, and the (boxed) values fetched from the current projection block.
  private int _numNonLiteralArguments;
  private int[] _nonLiteralIndices;
  private TransformFunction[] _nonLiteralFunctions;
  private Object[][] _nonLiteralValues;

  // Result buffers, reused across calls and reallocated only when a block has more documents than they can hold.
  private int[] _intResults;
  private float[] _floatResults;
  private double[] _doubleResults;
  private long[] _longResults;
  private String[] _stringResults;
  private byte[][] _bytesResults;
  private int[][] _intMVResults;
  private long[][] _longMVResults;
  private float[][] _floatMVResults;
  private double[][] _doubleMVResults;
  private String[][] _stringMVResults;

  /**
   * Creates a wrapper for the given scalar function.
   *
   * @param functionInfo the annotated scalar function to wrap
   * @throws IllegalArgumentException if any parameter class of the method has no matching {@link PinotDataType}
   */
  public ScalarTransformFunctionWrapper(FunctionInfo functionInfo) {
    _name = functionInfo.getMethod().getName();
    _functionInvoker = new FunctionInvoker(functionInfo);
    Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
    PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
    int numParameters = parameterClasses.length;
    for (int i = 0; i < numParameters; i++) {
      Preconditions.checkArgument(parameterTypes[i] != null, "Unsupported parameter class: %s for method: %s",
          parameterClasses[i], functionInfo.getMethod());
    }
    Class<?> resultClass = _functionInvoker.getResultClass();
    PinotDataType resultType = FunctionUtils.getParameterType(resultClass);
    if (resultType != null) {
      _resultType = resultType;
      _resultMetadata =
          new TransformResultMetadata(FunctionUtils.getDataType(resultClass), _resultType.isSingleValue(), false);
    } else {
      // Handle unrecognized result class with STRING
      _resultType = PinotDataType.STRING;
      _resultMetadata = new TransformResultMetadata(DataType.STRING, true, false);
    }
  }

  @Override
  public String getName() {
    return _name;
  }

  /**
   * Binds the argument transform functions.
   *
   * <p>Literal arguments are converted from their string form to the corresponding parameter type once, here.
   * Every other argument is recorded so that its values can be fetched per projection block.
   *
   * @throws IllegalArgumentException if the number of arguments does not match the method's parameter count
   */
  @Override
  public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
    int numArguments = arguments.size();
    PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
    Preconditions.checkArgument(numArguments == parameterTypes.length,
        "Wrong number of arguments for method: %s, expected: %s, actual: %s", _functionInvoker.getMethod(),
        parameterTypes.length, numArguments);
    _arguments = new Object[numArguments];
    _nonLiteralIndices = new int[numArguments];
    _nonLiteralFunctions = new TransformFunction[numArguments];
    // Reset so that re-initializing this instance does not accumulate stale non-literal arguments.
    _numNonLiteralArguments = 0;
    for (int i = 0; i < numArguments; i++) {
      TransformFunction transformFunction = arguments.get(i);
      if (transformFunction instanceof LiteralTransformFunction) {
        String literal = ((LiteralTransformFunction) transformFunction).getLiteral();
        _arguments[i] = parameterTypes[i].convert(literal, PinotDataType.STRING);
      } else {
        _nonLiteralIndices[_numNonLiteralArguments] = i;
        _nonLiteralFunctions[_numNonLiteralArguments] = transformFunction;
        _numNonLiteralArguments++;
      }
    }
    _nonLiteralValues = new Object[_numNonLiteralArguments][];
  }

  @Override
  public TransformResultMetadata getResultMetadata() {
    return _resultMetadata;
  }

  @Override
  public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.INT) {
      return super.transformToIntValuesSV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    if (_intResults == null || _intResults.length < length) {
      _intResults = new int[length];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      _intResults[i] = (int) _resultType.toInternal(invokeForDocument(i));
    }
    return _intResults;
  }

  @Override
  public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.LONG) {
      return super.transformToLongValuesSV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    if (_longResults == null || _longResults.length < length) {
      _longResults = new long[length];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      _longResults[i] = (long) _resultType.toInternal(invokeForDocument(i));
    }
    return _longResults;
  }

  @Override
  public float[] transformToFloatValuesSV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.FLOAT) {
      return super.transformToFloatValuesSV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    if (_floatResults == null || _floatResults.length < length) {
      _floatResults = new float[length];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      _floatResults[i] = (float) _resultType.toInternal(invokeForDocument(i));
    }
    return _floatResults;
  }

  @Override
  public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.DOUBLE) {
      return super.transformToDoubleValuesSV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    if (_doubleResults == null || _doubleResults.length < length) {
      _doubleResults = new double[length];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      _doubleResults[i] = (double) _resultType.toInternal(invokeForDocument(i));
    }
    return _doubleResults;
  }

  @Override
  public String[] transformToStringValuesSV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.STRING) {
      return super.transformToStringValuesSV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    if (_stringResults == null || _stringResults.length < length) {
      _stringResults = new String[length];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      Object result = invokeForDocument(i);
      // _resultType is STRING both for String-returning methods and for unrecognized result classes (see the
      // constructor), so toString() is used to turn arbitrary result objects into strings.
      _stringResults[i] =
          _resultType == PinotDataType.STRING ? result.toString() : (String) _resultType.toInternal(result);
    }
    return _stringResults;
  }

  @Override
  public byte[][] transformToBytesValuesSV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.BYTES) {
      return super.transformToBytesValuesSV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    if (_bytesResults == null || _bytesResults.length < length) {
      _bytesResults = new byte[length][];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      _bytesResults[i] = (byte[]) _resultType.toInternal(invokeForDocument(i));
    }
    return _bytesResults;
  }

  @Override
  public int[][] transformToIntValuesMV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.INT) {
      return super.transformToIntValuesMV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    // Grow the buffer when a block has more documents than a previous one, consistent with the other buffers.
    if (_intMVResults == null || _intMVResults.length < length) {
      _intMVResults = new int[length][];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      _intMVResults[i] = (int[]) _resultType.toInternal(invokeForDocument(i));
    }
    return _intMVResults;
  }

  @Override
  public long[][] transformToLongValuesMV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.LONG) {
      return super.transformToLongValuesMV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    if (_longMVResults == null || _longMVResults.length < length) {
      _longMVResults = new long[length][];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      _longMVResults[i] = (long[]) _resultType.toInternal(invokeForDocument(i));
    }
    return _longMVResults;
  }

  @Override
  public float[][] transformToFloatValuesMV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.FLOAT) {
      return super.transformToFloatValuesMV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    if (_floatMVResults == null || _floatMVResults.length < length) {
      _floatMVResults = new float[length][];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      _floatMVResults[i] = (float[]) _resultType.toInternal(invokeForDocument(i));
    }
    return _floatMVResults;
  }

  @Override
  public double[][] transformToDoubleValuesMV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.DOUBLE) {
      return super.transformToDoubleValuesMV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    if (_doubleMVResults == null || _doubleMVResults.length < length) {
      _doubleMVResults = new double[length][];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      _doubleMVResults[i] = (double[]) _resultType.toInternal(invokeForDocument(i));
    }
    return _doubleMVResults;
  }

  @Override
  public String[][] transformToStringValuesMV(ProjectionBlock projectionBlock) {
    if (_resultMetadata.getDataType().getStoredType() != DataType.STRING) {
      return super.transformToStringValuesMV(projectionBlock);
    }
    int length = projectionBlock.getNumDocs();
    if (_stringMVResults == null || _stringMVResults.length < length) {
      _stringMVResults = new String[length][];
    }
    getNonLiteralValues(projectionBlock);
    for (int i = 0; i < length; i++) {
      _stringMVResults[i] = (String[]) _resultType.toInternal(invokeForDocument(i));
    }
    return _stringMVResults;
  }

  /**
   * Fills the non-literal slots of {@link #_arguments} with the values of the given document and invokes the
   * wrapped function. {@link #getNonLiteralValues(ProjectionBlock)} must have been called for the current block.
   *
   * @param docIndex index of the document within the current projection block
   * @return the raw (not yet converted) result of the scalar function
   */
  private Object invokeForDocument(int docIndex) {
    for (int j = 0; j < _numNonLiteralArguments; j++) {
      _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][docIndex];
    }
    return _functionInvoker.invoke(_arguments);
  }

  /**
   * Helper method to fetch values for the non-literal transform functions based on the parameter types.
   *
   * <p>Values are fetched through the transform method matching each parameter's {@link PinotDataType}.
   * Primitive single-value results are boxed. BOOLEAN parameters are read as ints (1 means {@code true}), and
   * TIMESTAMP parameters are read as longs and wrapped in {@link Timestamp}.
   *
   * @throws IllegalStateException if a parameter type is not supported
   */
  private void getNonLiteralValues(ProjectionBlock projectionBlock) {
    PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes();
    for (int i = 0; i < _numNonLiteralArguments; i++) {
      PinotDataType parameterType = parameterTypes[_nonLiteralIndices[i]];
      TransformFunction transformFunction = _nonLiteralFunctions[i];
      switch (parameterType) {
        case INTEGER:
          _nonLiteralValues[i] = ArrayUtils.toObject(transformFunction.transformToIntValuesSV(projectionBlock));
          break;
        case LONG:
          _nonLiteralValues[i] = ArrayUtils.toObject(transformFunction.transformToLongValuesSV(projectionBlock));
          break;
        case FLOAT:
          _nonLiteralValues[i] = ArrayUtils.toObject(transformFunction.transformToFloatValuesSV(projectionBlock));
          break;
        case DOUBLE:
          _nonLiteralValues[i] = ArrayUtils.toObject(transformFunction.transformToDoubleValuesSV(projectionBlock));
          break;
        case BIG_DECIMAL:
          _nonLiteralValues[i] = transformFunction.transformToBigDecimalValuesSV(projectionBlock);
          break;
        case BOOLEAN: {
          int[] intValues = transformFunction.transformToIntValuesSV(projectionBlock);
          int numValues = intValues.length;
          Boolean[] booleanValues = new Boolean[numValues];
          for (int j = 0; j < numValues; j++) {
            booleanValues[j] = intValues[j] == 1;
          }
          _nonLiteralValues[i] = booleanValues;
          break;
        }
        case TIMESTAMP: {
          long[] longValues = transformFunction.transformToLongValuesSV(projectionBlock);
          int numValues = longValues.length;
          Timestamp[] timestampValues = new Timestamp[numValues];
          for (int j = 0; j < numValues; j++) {
            timestampValues[j] = new Timestamp(longValues[j]);
          }
          _nonLiteralValues[i] = timestampValues;
          break;
        }
        case STRING:
          _nonLiteralValues[i] = transformFunction.transformToStringValuesSV(projectionBlock);
          break;
        case BYTES:
          _nonLiteralValues[i] = transformFunction.transformToBytesValuesSV(projectionBlock);
          break;
        case PRIMITIVE_INT_ARRAY:
          _nonLiteralValues[i] = transformFunction.transformToIntValuesMV(projectionBlock);
          break;
        case PRIMITIVE_LONG_ARRAY:
          _nonLiteralValues[i] = transformFunction.transformToLongValuesMV(projectionBlock);
          break;
        case PRIMITIVE_FLOAT_ARRAY:
          _nonLiteralValues[i] = transformFunction.transformToFloatValuesMV(projectionBlock);
          break;
        case PRIMITIVE_DOUBLE_ARRAY:
          _nonLiteralValues[i] = transformFunction.transformToDoubleValuesMV(projectionBlock);
          break;
        case STRING_ARRAY:
          _nonLiteralValues[i] = transformFunction.transformToStringValuesMV(projectionBlock);
          break;
        default:
          throw new IllegalStateException("Unsupported parameter type: " + parameterType);
      }
    }
  }
}