blob: c3de3d48dbcc1932c1713713fdd69c8dcc1d68c5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark.functions;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.Set;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
import org.apache.iceberg.util.BucketUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampNTZType;
import org.apache.spark.sql.types.TimestampType;
import org.apache.spark.unsafe.types.UTF8String;
/**
* A Spark function implementation for the Iceberg bucket transform.
*
* <p>Example usage: {@code SELECT system.bucket(128, 'abc')}, which returns the bucket 122.
*
* <p>Note that for performance reasons, the given input number of buckets is not validated in the
* implementations used in code-gen. The number of buckets must be positive to give meaningful
* results.
*/
public class BucketFunction implements UnboundFunction {
// Position of the numBuckets argument within the function's input struct / row.
private static final int NUM_BUCKETS_ORDINAL = 0;
// Position of the value (column) argument within the function's input struct / row.
private static final int VALUE_ORDINAL = 1;
// Spark types accepted for the numBuckets argument; bind() rejects any other type.
private static final Set<DataType> SUPPORTED_NUM_BUCKETS_TYPES =
ImmutableSet.of(DataTypes.ByteType, DataTypes.ShortType, DataTypes.IntegerType);
/**
 * Binds the bucket function to a concrete implementation based on the argument types.
 *
 * <p>The input must consist of exactly two fields: the number of buckets followed by the value
 * to bucket. The number of buckets must be a tinyint, smallint or int. The implementation is
 * selected from the type of the value field.
 *
 * @param inputType struct describing the arguments the function is called with
 * @return the bound function matching the type of the value argument
 * @throws UnsupportedOperationException if the number of arguments is not two, if the number of
 *     buckets has an unsupported type, or if the value type cannot be bucketed
 */
@Override
@SuppressWarnings("checkstyle:CyclomaticComplexity")
public BoundFunction bind(StructType inputType) {
  if (inputType.size() != 2) {
    throw new UnsupportedOperationException(
        "Wrong number of inputs (expected numBuckets and value)");
  }

  StructField numBucketsField = inputType.fields()[NUM_BUCKETS_ORDINAL];
  StructField valueField = inputType.fields()[VALUE_ORDINAL];

  if (!SUPPORTED_NUM_BUCKETS_TYPES.contains(numBucketsField.dataType())) {
    // Use "smallint", the name used by the other messages of this function, not "shortint".
    throw new UnsupportedOperationException(
        "Expected number of buckets to be tinyint, smallint or int");
  }

  DataType type = valueField.dataType();
  if (type instanceof DateType) {
    // dates keep their own SQL type so that the bound function reports it as its input type
    return new BucketInt(type);
  } else if (type instanceof ByteType
      || type instanceof ShortType
      || type instanceof IntegerType) {
    // the narrower integral types are bound as int
    return new BucketInt(DataTypes.IntegerType);
  } else if (type instanceof LongType
      || type instanceof TimestampType
      || type instanceof TimestampNTZType) {
    // all of these are bound to the long-based implementation with their own SQL type
    return new BucketLong(type);
  } else if (type instanceof DecimalType) {
    return new BucketDecimal(type);
  } else if (type instanceof StringType) {
    return new BucketString();
  } else if (type instanceof BinaryType) {
    return new BucketBinary();
  } else {
    throw new UnsupportedOperationException(
        "Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary");
  }
}
/**
 * Returns a multi-line usage text for this function: a summary line followed by one line per
 * argument.
 *
 * @return the usage text, with lines separated by a newline character
 */
@Override
public String description() {
  String summary = name() + "(numBuckets, col) - Call Iceberg's bucket transform";
  String numBucketsHelp =
      " numBuckets :: number of buckets to divide the rows into, e.g. bucket(100, 34) -> 79 "
          + "(must be a tinyint, smallint, or int)";
  String colHelp =
      " col :: column to bucket "
          + "(must be a date, integer, long, timestamp, decimal, string, or binary)";
  return String.join("\n", summary, numBucketsHelp, colHelp);
}
/**
 * Returns the name of this function.
 *
 * @return the constant string {@code "bucket"}
 */
@Override
public String name() {
return "bucket";
}
/**
 * Common base of the bound bucket implementations: fixes the function name and the integer
 * result type, and provides the mapping from a hash value to a bucket.
 */
public abstract static class BucketBase extends BaseScalarFunction<Integer> {

  /**
   * Maps a hash value to a bucket.
   *
   * <p>The sign bit of the hash is cleared before taking the remainder, so the dividend is never
   * negative. The number of buckets is not validated here; a value of zero makes the remainder
   * operation throw an {@link ArithmeticException}.
   *
   * @param numBuckets number of buckets
   * @param hashedValue hash of the value being bucketed
   * @return the non-negative hash modulo {@code numBuckets}
   */
  public static int apply(int numBuckets, int hashedValue) {
    int nonNegativeHash = hashedValue & Integer.MAX_VALUE;
    return nonNegativeHash % numBuckets;
  }

  /** Every bucket implementation produces an int. */
  @Override
  public DataType resultType() {
    return DataTypes.IntegerType;
  }

  /** Every bucket implementation is named {@code "bucket"}. */
  @Override
  public String name() {
    return "bucket";
  }
}
/**
 * Bucket implementation for values read as an int.
 *
 * <p>Used for both int and date - tinyint and smallint are upcasted to int by Spark.
 */
public static class BucketInt extends BucketBase {
  private final DataType sqlType;

  /**
   * @param sqlType SQL type of the value argument, reported in {@link #inputTypes()} and in the
   *     canonical name
   */
  public BucketInt(DataType sqlType) {
    this.sqlType = sqlType;
  }

  /**
   * Magic method used in codegen: hashes the value and maps the hash to a bucket.
   *
   * @param numBuckets number of buckets (not validated)
   * @param value value to bucket
   * @return the bucket of {@code value}
   */
  public static int invoke(int numBuckets, int value) {
    int hashed = hash(value);
    return apply(numBuckets, hashed);
  }

  /** Hashes an int via {@code BucketUtil}. Visible for testing. */
  public static int hash(int value) {
    return BucketUtil.hash(value);
  }

  @Override
  public DataType[] inputTypes() {
    DataType[] types = new DataType[2];
    types[NUM_BUCKETS_ORDINAL] = DataTypes.IntegerType;
    types[VALUE_ORDINAL] = sqlType;
    return types;
  }

  @Override
  public String canonicalName() {
    return "iceberg.bucket(" + sqlType.catalogString() + ")";
  }

  /**
   * Row-based evaluation.
   *
   * @param input row holding numBuckets and the value
   * @return the bucket, or null when either argument is null
   */
  @Override
  public Integer produceResult(InternalRow input) {
    // null in, null out: this matches what Spark does in the code-generated versions
    boolean anyNull = input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL);
    if (anyNull) {
      return null;
    }

    int numBuckets = input.getInt(NUM_BUCKETS_ORDINAL);
    int value = input.getInt(VALUE_ORDINAL);
    return invoke(numBuckets, value);
  }
}
/**
 * Bucket implementation for values read as a long.
 *
 * <p>Used for both BigInt and Timestamp.
 */
public static class BucketLong extends BucketBase {
  private final DataType sqlType;

  /**
   * @param sqlType SQL type of the value argument, reported in {@link #inputTypes()} and in the
   *     canonical name
   */
  public BucketLong(DataType sqlType) {
    this.sqlType = sqlType;
  }

  /**
   * Magic function for usage with codegen - needs to be static.
   *
   * @param numBuckets number of buckets (not validated)
   * @param value value to bucket
   * @return the bucket of {@code value}
   */
  public static int invoke(int numBuckets, long value) {
    int hashed = hash(value);
    return apply(numBuckets, hashed);
  }

  /** Hashes a long via {@code BucketUtil}. Visible for testing. */
  public static int hash(long value) {
    return BucketUtil.hash(value);
  }

  @Override
  public DataType[] inputTypes() {
    DataType[] types = new DataType[2];
    types[NUM_BUCKETS_ORDINAL] = DataTypes.IntegerType;
    types[VALUE_ORDINAL] = sqlType;
    return types;
  }

  @Override
  public String canonicalName() {
    return "iceberg.bucket(" + sqlType.catalogString() + ")";
  }

  /**
   * Row-based evaluation.
   *
   * @param input row holding numBuckets and the value
   * @return the bucket, or null when either argument is null
   */
  @Override
  public Integer produceResult(InternalRow input) {
    if (!input.isNullAt(NUM_BUCKETS_ORDINAL) && !input.isNullAt(VALUE_ORDINAL)) {
      int numBuckets = input.getInt(NUM_BUCKETS_ORDINAL);
      long value = input.getLong(VALUE_ORDINAL);
      return invoke(numBuckets, value);
    }

    return null;
  }
}
/** Bucket implementation for string values. */
public static class BucketString extends BucketBase {

  /**
   * Magic function for usage with codegen.
   *
   * @param numBuckets number of buckets (not validated)
   * @param value string to bucket, may be null
   * @return the bucket of {@code value}, or null when {@code value} is null
   */
  public static Integer invoke(int numBuckets, UTF8String value) {
    if (value != null) {
      // TODO - We can probably hash the bytes directly given they're already UTF-8 input.
      String asJavaString = value.toString();
      return apply(numBuckets, hash(asJavaString));
    }

    return null;
  }

  /** Hashes a string via {@code BucketUtil}. Visible for testing. */
  public static int hash(String value) {
    return BucketUtil.hash(value);
  }

  @Override
  public String canonicalName() {
    return "iceberg.bucket(string)";
  }

  @Override
  public DataType[] inputTypes() {
    DataType[] types = new DataType[2];
    types[NUM_BUCKETS_ORDINAL] = DataTypes.IntegerType;
    types[VALUE_ORDINAL] = DataTypes.StringType;
    return types;
  }

  /**
   * Row-based evaluation.
   *
   * @param input row holding numBuckets and the value
   * @return the bucket, or null when either argument is null
   */
  @Override
  public Integer produceResult(InternalRow input) {
    boolean anyNull = input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL);
    if (anyNull) {
      return null;
    }

    int numBuckets = input.getInt(NUM_BUCKETS_ORDINAL);
    UTF8String value = input.getUTF8String(VALUE_ORDINAL);
    return invoke(numBuckets, value);
  }
}
/** Bucket implementation for binary values. */
public static class BucketBinary extends BucketBase {

  /**
   * Static entry point taking the unwrapped arguments: wraps the bytes in a {@link ByteBuffer},
   * hashes the buffer and maps the hash to a bucket.
   *
   * @param numBuckets number of buckets (not validated)
   * @param value bytes to bucket, may be null
   * @return the bucket of {@code value}, or null when {@code value} is null
   */
  public static Integer invoke(int numBuckets, byte[] value) {
    if (value != null) {
      ByteBuffer wrapped = ByteBuffer.wrap(value);
      return apply(numBuckets, hash(wrapped));
    }

    return null;
  }

  /** Hashes a byte buffer via {@code BucketUtil}. Visible for testing. */
  public static int hash(ByteBuffer value) {
    return BucketUtil.hash(value);
  }

  @Override
  public String canonicalName() {
    return "iceberg.bucket(binary)";
  }

  @Override
  public DataType[] inputTypes() {
    DataType[] types = new DataType[2];
    types[NUM_BUCKETS_ORDINAL] = DataTypes.IntegerType;
    types[VALUE_ORDINAL] = DataTypes.BinaryType;
    return types;
  }

  /**
   * Row-based evaluation.
   *
   * @param input row holding numBuckets and the value
   * @return the bucket, or null when either argument is null
   */
  @Override
  public Integer produceResult(InternalRow input) {
    if (!input.isNullAt(NUM_BUCKETS_ORDINAL) && !input.isNullAt(VALUE_ORDINAL)) {
      int numBuckets = input.getInt(NUM_BUCKETS_ORDINAL);
      byte[] value = input.getBinary(VALUE_ORDINAL);
      return invoke(numBuckets, value);
    }

    return null;
  }
}
/**
 * Bucket implementation for decimal values.
 *
 * <p>The precision and scale of the bound decimal type are captured at construction time and are
 * used to read the value from the input row. Note that the canonical name does not include the
 * precision or scale.
 */
public static class BucketDecimal extends BucketBase {
  private final DataType sqlType;
  private final int precision;
  private final int scale;

  /**
   * @param sqlType SQL type of the value argument; it is cast to {@code DecimalType} to read its
   *     precision and scale
   */
  public BucketDecimal(DataType sqlType) {
    DecimalType decimalType = (DecimalType) sqlType;
    this.sqlType = sqlType;
    this.precision = decimalType.precision();
    this.scale = decimalType.scale();
  }

  /**
   * Magic method used in codegen: converts the value to a {@link BigDecimal}, hashes it and maps
   * the hash to a bucket.
   *
   * @param numBuckets number of buckets (not validated)
   * @param value decimal to bucket, may be null
   * @return the bucket of {@code value}, or null when {@code value} is null
   */
  public static Integer invoke(int numBuckets, Decimal value) {
    if (value != null) {
      BigDecimal asBigDecimal = value.toJavaBigDecimal();
      return apply(numBuckets, hash(asBigDecimal));
    }

    return null;
  }

  /** Hashes a decimal via {@code BucketUtil}. Visible for testing. */
  public static int hash(BigDecimal value) {
    return BucketUtil.hash(value);
  }

  @Override
  public String canonicalName() {
    return "iceberg.bucket(decimal)";
  }

  @Override
  public DataType[] inputTypes() {
    DataType[] types = new DataType[2];
    types[NUM_BUCKETS_ORDINAL] = DataTypes.IntegerType;
    types[VALUE_ORDINAL] = sqlType;
    return types;
  }

  /**
   * Row-based evaluation.
   *
   * @param input row holding numBuckets and the value
   * @return the bucket, or null when either argument is null
   */
  @Override
  public Integer produceResult(InternalRow input) {
    boolean anyNull = input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL);
    if (anyNull) {
      return null;
    }

    return invoke(
        input.getInt(NUM_BUCKETS_ORDINAL), input.getDecimal(VALUE_ORDINAL, precision, scale));
  }
}
}