blob: fac90c9efee6ca5c895b7d15c7522dbd153736b7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark.functions;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.util.Set;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
import org.apache.iceberg.util.BinaryUtil;
import org.apache.iceberg.util.ByteBuffers;
import org.apache.iceberg.util.TruncateUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
/**
* A Spark function implementation for the Iceberg truncate transform.
*
* <p>Example usage: {@code SELECT system.truncate(1, 'abc')}, which returns the String 'a'.
*
* <p>Note that for performance reasons, the given input width is not validated in the
* implementations used in code-gen. The width must remain non-negative to give meaningful results.
*/
public class TruncateFunction implements UnboundFunction {
private static final int WIDTH_ORDINAL = 0;
private static final int VALUE_ORDINAL = 1;
private static final Set<DataType> SUPPORTED_WIDTH_TYPES =
ImmutableSet.of(DataTypes.ByteType, DataTypes.ShortType, DataTypes.IntegerType);
@Override
public BoundFunction bind(StructType inputType) {
if (inputType.size() != 2) {
throw new UnsupportedOperationException("Wrong number of inputs (expected width and value)");
}
StructField widthField = inputType.fields()[WIDTH_ORDINAL];
StructField valueField = inputType.fields()[VALUE_ORDINAL];
if (!SUPPORTED_WIDTH_TYPES.contains(widthField.dataType())) {
throw new UnsupportedOperationException(
"Expected truncation width to be tinyint, shortint or int");
}
DataType valueType = valueField.dataType();
if (valueType instanceof ByteType) {
return new TruncateTinyInt();
} else if (valueType instanceof ShortType) {
return new TruncateSmallInt();
} else if (valueType instanceof IntegerType) {
return new TruncateInt();
} else if (valueType instanceof LongType) {
return new TruncateBigInt();
} else if (valueType instanceof DecimalType) {
return new TruncateDecimal(
((DecimalType) valueType).precision(), ((DecimalType) valueType).scale());
} else if (valueType instanceof StringType) {
return new TruncateString();
} else if (valueType instanceof BinaryType) {
return new TruncateBinary();
} else {
throw new UnsupportedOperationException(
"Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary");
}
}
@Override
public String description() {
return name()
+ "(width, col) - Call Iceberg's truncate transform\n"
+ " width :: width for truncation, e.g. truncate(10, 255) -> 250 (must be an integer)\n"
+ " col :: column to truncate (must be an integer, decimal, string, or binary)";
}
@Override
public String name() {
return "truncate";
}
public abstract static class TruncateBase<T> extends BaseScalarFunction<T> {
@Override
public String name() {
return "truncate";
}
}
public static class TruncateTinyInt extends TruncateBase<Byte> {
public static byte invoke(int width, byte value) {
return TruncateUtil.truncateByte(width, value);
}
@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, DataTypes.ByteType};
}
@Override
public DataType resultType() {
return DataTypes.ByteType;
}
@Override
public String canonicalName() {
return "iceberg.truncate(tinyint)";
}
@Override
public Byte produceResult(InternalRow input) {
if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) {
return null;
} else {
return invoke(input.getInt(WIDTH_ORDINAL), input.getByte(VALUE_ORDINAL));
}
}
}
public static class TruncateSmallInt extends TruncateBase<Short> {
// magic method used in codegen
public static short invoke(int width, short value) {
return TruncateUtil.truncateShort(width, value);
}
@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, DataTypes.ShortType};
}
@Override
public DataType resultType() {
return DataTypes.ShortType;
}
@Override
public String canonicalName() {
return "iceberg.truncate(smallint)";
}
@Override
public Short produceResult(InternalRow input) {
if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) {
return null;
} else {
return invoke(input.getInt(WIDTH_ORDINAL), input.getShort(VALUE_ORDINAL));
}
}
}
public static class TruncateInt extends TruncateBase<Integer> {
// magic method used in codegen
public static int invoke(int width, int value) {
return TruncateUtil.truncateInt(width, value);
}
@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, DataTypes.IntegerType};
}
@Override
public DataType resultType() {
return DataTypes.IntegerType;
}
@Override
public String canonicalName() {
return "iceberg.truncate(int)";
}
@Override
public Integer produceResult(InternalRow input) {
if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) {
return null;
} else {
return invoke(input.getInt(WIDTH_ORDINAL), input.getInt(VALUE_ORDINAL));
}
}
}
public static class TruncateBigInt extends TruncateBase<Long> {
// magic function for usage with codegen
public static long invoke(int width, long value) {
return TruncateUtil.truncateLong(width, value);
}
@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, DataTypes.LongType};
}
@Override
public DataType resultType() {
return DataTypes.LongType;
}
@Override
public String canonicalName() {
return "iceberg.truncate(bigint)";
}
@Override
public Long produceResult(InternalRow input) {
if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) {
return null;
} else {
return invoke(input.getInt(WIDTH_ORDINAL), input.getLong(VALUE_ORDINAL));
}
}
}
public static class TruncateString extends TruncateBase<UTF8String> {
// magic function for usage with codegen
public static UTF8String invoke(int width, UTF8String value) {
if (value == null) {
return null;
}
return value.substring(0, width);
}
@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, DataTypes.StringType};
}
@Override
public DataType resultType() {
return DataTypes.StringType;
}
@Override
public String canonicalName() {
return "iceberg.truncate(string)";
}
@Override
public UTF8String produceResult(InternalRow input) {
if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) {
return null;
} else {
return invoke(input.getInt(WIDTH_ORDINAL), input.getUTF8String(VALUE_ORDINAL));
}
}
}
public static class TruncateBinary extends TruncateBase<byte[]> {
// magic method used in codegen
public static byte[] invoke(int width, byte[] value) {
if (value == null) {
return null;
}
return ByteBuffers.toByteArray(
BinaryUtil.truncateBinaryUnsafe(ByteBuffer.wrap(value), width));
}
@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, DataTypes.BinaryType};
}
@Override
public DataType resultType() {
return DataTypes.BinaryType;
}
@Override
public String canonicalName() {
return "iceberg.truncate(binary)";
}
@Override
public byte[] produceResult(InternalRow input) {
if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) {
return null;
} else {
return invoke(input.getInt(WIDTH_ORDINAL), input.getBinary(VALUE_ORDINAL));
}
}
}
public static class TruncateDecimal extends TruncateBase<Decimal> {
private final int precision;
private final int scale;
public TruncateDecimal(int precision, int scale) {
this.precision = precision;
this.scale = scale;
}
// magic method used in codegen
public static Decimal invoke(int width, Decimal value) {
if (value == null) {
return null;
}
return Decimal.apply(
TruncateUtil.truncateDecimal(BigInteger.valueOf(width), value.toJavaBigDecimal()));
}
@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, DataTypes.createDecimalType(precision, scale)};
}
@Override
public DataType resultType() {
return DataTypes.createDecimalType(precision, scale);
}
@Override
public String canonicalName() {
return String.format("iceberg.truncate(decimal(%d,%d))", precision, scale);
}
@Override
public Decimal produceResult(InternalRow input) {
if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) {
return null;
} else {
int width = input.getInt(WIDTH_ORDINAL);
Decimal value = input.getDecimal(VALUE_ORDINAL, precision, scale);
return invoke(width, value);
}
}
}
}