blob: 87cf48650c4d0ba03b4320c4c80112c4ef47ffde [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samza.sql.udf;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.List;
import org.apache.calcite.adapter.enumerable.CallImplementor;
import org.apache.calcite.adapter.enumerable.EnumUtils;
import org.apache.calcite.adapter.enumerable.NullPolicy;
import org.apache.calcite.adapter.enumerable.RexImpTable;
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.linq4j.tree.ConstantExpression;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.ExpressionType;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.schema.Function;
import org.apache.calcite.schema.FunctionParameter;
import org.apache.calcite.schema.ImplementableFunction;
import org.apache.calcite.schema.ScalarFunction;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
import static org.apache.calcite.schema.impl.ReflectiveFunctionBase.builder;
/**
 * Operator to extract nested rows or fields from a struct row type using a dotted path.
 * The goal of this operator is two-fold.
 * First, it is a temporary fix for https://issues.apache.org/jira/browse/CALCITE-4065 to extract a row from a row.
 * Second, it enables a smooth, backward compatible migration from existing UDFs that rely on the legacy row format.
 *
 * <p>Signature: {@code GetNestedField(<ROW>, <VARCHAR>)}. The second operand must be a constant string whose
 * dot-separated segments each name a field of the row reached by the previous segment, e.g. {@code 'a.b.c'}.
 */
public class GetNestedField extends SqlUserDefinedFunction {
  /** Shared instance of the operator, backed by {@link ExtractFunction}. */
  public static final SqlFunction INSTANCE = new GetNestedField(new ExtractFunction());

  /** Regex separating the segments of a field path, as passed to {@link String#split(String)}. */
  private static final String PATH_SEPARATOR_REGEX = "\\.";

  /**
   * Creates the operator.
   *
   * @param function function supplying the parameter list and the code-generation implementor
   */
  public GetNestedField(Function function) {
    super(new SqlIdentifier("GetNestedField", SqlParserPos.ZERO), null, null, null, ImmutableList.of(), function);
  }

  /**
   * Splits a dotted field path into its segments. Type inference and code generation both go through this
   * method so that they always interpret a path the same way.
   *
   * @param path dotted field path, e.g. {@code "a.b.c"}
   * @return the path segments
   */
  private static String[] splitPath(String path) {
    return path.split(PATH_SEPARATOR_REGEX);
  }

  /** The operator takes exactly two operands: the row and the field path. */
  @Override
  public SqlOperandCountRange getOperandCountRange() {
    return SqlOperandCountRanges.of(2);
  }

  /**
   * Validates that the first operand is a (non-star) ROW and that the second operand is a string.
   *
   * @param callBinding binding of the call being validated
   * @param throwOnFailure whether to throw a validation error instead of returning {@code false}
   * @return {@code true} if both operands have acceptable types
   */
  @Override
  public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
    final SqlNode left = callBinding.operand(0);
    final SqlNode right = callBinding.operand(1);
    final RelDataType type = callBinding.getValidator().deriveType(callBinding.getScope(), left);
    // getSqlIdentifier() is only consulted once the type is known to be ROW (short-circuit), as before.
    final boolean isRow = type.getSqlTypeName() == SqlTypeName.ROW && !type.getSqlIdentifier().isStar();
    if (!isRow && throwOnFailure) {
      throw callBinding.newValidationSignatureError();
    }
    // STRING has a single operand family, hence formal operand index 0 even though this is the second operand.
    return isRow && OperandTypes.STRING.checkSingleOperandType(callBinding, right, 0, throwOnFailure);
  }

  /**
   * Infers the type of the field addressed by the path in operand 2, walking one nested row per path segment.
   * If the input row is nullable, the result is made nullable as well.
   *
   * @param opBinding binding of the call whose return type is inferred
   * @return the type of the addressed nested field
   * @throws NullPointerException if the path is not a literal or a segment names a missing field
   * @throws IllegalStateException if a non-final path segment resolves to a non-ROW type
   * @throws AssertionError if the first operand is not a ROW
   */
  @Override
  public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
    final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
    final RelDataType recordType = opBinding.getOperandType(0);
    switch (recordType.getSqlTypeName()) {
      case ROW: {
        final String fieldName = opBinding.getOperandLiteralValue(1, String.class);
        // getOperandLiteralValue yields null when operand 2 is not a literal; report that clearly instead of
        // failing with an anonymous NullPointerException inside split().
        Preconditions.checkNotNull(fieldName,
            "Operand 2 has to be a constant string path, record %s", recordType);
        final String[] fieldNameChain = splitPath(fieldName);
        RelDataType relDataType = recordType;
        for (String segment : fieldNameChain) {
          // Only ROW types have fields; check explicitly (as the code-gen implementor does) so a path that walks
          // into a scalar reports which type it hit rather than failing inside getField.
          Preconditions.checkState(relDataType.getSqlTypeName() == SqlTypeName.ROW,
              "Must be ROW found %s Original String %s Original row %s",
              relDataType, Arrays.toString(fieldNameChain), recordType);
          final RelDataTypeField field = relDataType.getField(segment, true, true);
          // Report the nested type actually searched at this depth, not the outer record.
          Preconditions.checkNotNull(field,
              "Can not find %s within record %s Original String %s Original row %s",
              segment, relDataType, Arrays.toString(fieldNameChain), recordType);
          relDataType = field.getType();
        }
        return recordType.isNullable() ? typeFactory.createTypeWithNullability(relDataType, true) : relDataType;
      }
      default:
        throw new AssertionError("First Operand is suppose to be a Row Struct");
    }
  }

  /**
   * Function backing {@link GetNestedField}: supplies the parameter list and generates code that walks the
   * nested rows at runtime.
   */
  private static class ExtractFunction implements ScalarFunction, ImplementableFunction {
    private final JavaTypeFactoryImpl javaTypeFactory = new JavaTypeFactoryImpl();

    /**
     * Returns an implementor that generates, for each path segment, a cast of the current value to
     * {@code Object[]} followed by an array access at the field's index, then converts the final value to the
     * Java class of the addressed field. {@link NullPolicy#ARG0} handles a null input row.
     */
    @Override
    public CallImplementor getImplementor() {
      return RexImpTable.createImplementor((translator, call, translatedOperands) -> {
        Preconditions.checkState(translatedOperands.size() == 2 && call.operands.size() == 2,
            "Expected 2 operands found %s", Math.min(translatedOperands.size(), call.getOperands().size()));
        final Expression op0 = translatedOperands.get(0);
        final Expression op1 = translatedOperands.get(1);
        Preconditions.checkState(op1.getNodeType().equals(ExpressionType.Constant),
            "Operand 2 has to be constant and got %s", op1.getNodeType());
        Preconditions.checkState(op1.type.equals(String.class), "Operand 2 has to be String and got %s", op1.type);
        final String fieldName = (String) ((ConstantExpression) op1).value;
        final String[] fieldNameChain = splitPath(fieldName);
        final RelDataType rowType = call.operands.get(0).getType();
        Preconditions.checkState(rowType.getSqlTypeName().equals(SqlTypeName.ROW),
            "Expected first operand to be ROW found %s", rowType);
        RelDataType relDataType = rowType;
        Expression currentExpression = op0;
        for (String segment : fieldNameChain) {
          Preconditions.checkState(relDataType.getSqlTypeName() == SqlTypeName.ROW,
              "Must be ROW found %s", relDataType);
          final RelDataTypeField t = relDataType.getField(segment, true, true);
          Preconditions.checkNotNull(t,
              "Notfound %s in the following struct %s Original String %s Original row %s",
              segment, relDataType, Arrays.toString(fieldNameChain), rowType);
          // The generated code treats each (nested) row as an Object[]: cast, then index by field position.
          currentExpression = Expressions.arrayIndex(Expressions.convert_(currentExpression, Object[].class),
              Expressions.constant(t.getIndex()));
          relDataType = t.getType();
        }
        final Type fieldType = javaTypeFactory.getJavaClass(relDataType);
        return EnumUtils.convert(currentExpression, fieldType);
      }, NullPolicy.ARG0, false);
    }

    /** Not used: the return type is computed by {@link GetNestedField#inferReturnType}. */
    @Override
    public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
      throw new IllegalStateException("should not be called");
    }

    /** Declares the two parameters: the row (as {@code Object[]}) and the dotted path string. */
    @Override
    public List<FunctionParameter> getParameters() {
      return builder().add(Object[].class, "row").add(String.class, "path").build();
    }
  }

  /** Describes the accepted signature for validation error messages. */
  @Override
  public String getAllowedSignatures(String opNameToUse) {
    return opNameToUse + "(<ROW>, <VARCHAR>)";
  }
}