/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.sedona_sql.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.sedona_sql.expressions.implicits._
import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
import scala.reflect.runtime.universe.TypeTag
import scala.reflect.runtime.universe.Type
import scala.reflect.runtime.universe.typeOf
/**
* This is the base class for wrapping Java/Scala functions as a catalyst expression in Spark SQL.
* @param fSeq The functions to be wrapped. Subclasses can simply pass a function to this constructor,
* and the function will be converted to [[InferrableFunction]] by [[InferrableFunctionConverter]]
* automatically.
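*
* A hypothetical subclass might look like this (a sketch only; `ST_HypotheticalLength` is not a real
* Sedona function, and `withNewChildrenInternal` is only needed on Spark versions that declare it):
* {{{
* case class ST_HypotheticalLength(inputExpressions: Seq[Expression])
*   extends InferredExpression((geom: Geometry) => geom.getLength) {
*   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
*     copy(inputExpressions = newChildren)
* }
* }}}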
*/
abstract class InferredExpression(fSeq: InferrableFunction *)
extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with FoldableExpression
with Serializable {
def inputExpressions: Seq[Expression]
lazy val f: InferrableFunction = fSeq match {
// If there is only one function, simply use it and let org.apache.sedona.sql.UDF.Catalog handle default arguments.
case Seq(f) => f
// If there are multiple overloaded functions, find the one with the same number of arguments as the input
// expressions. Please note that the Catalog won't be able to handle default arguments in this case. We'll
// move default argument handling from Catalog to this class in the future.
case _ => fSeq.find(f => f.sparkInputTypes.size == inputExpressions.size) match {
case Some(f) => f
case None => throw new IllegalArgumentException(s"No overloaded function ${getClass.getName} has ${inputExpressions.size} arguments")
}
}
override def children: Seq[Expression] = inputExpressions
override def toString: String = s" **${getClass.getName}** "
override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] = f.sparkInputTypes
override def dataType: DataType = f.sparkReturnType
private lazy val argExtractors: Array[InternalRow => Any] = f.buildExtractors(inputExpressions)
private lazy val evaluator: InternalRow => Any = f.evaluatorBuilder(argExtractors)
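// eval converts the raw result to Catalyst's internal format via the inferred serializer, while
// evalWithoutSerialization (from SerdeAware) returns the raw value so that callers which understand
// Sedona objects can skip the serde round trip.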
override def eval(input: InternalRow): Any = f.serializer(evaluator(input))
override def evalWithoutSerialization(input: InternalRow): Any = evaluator(input)
}
// This is a compile-time type shield for the types we are able to infer. Anything other than
// these types will cause a compilation error. This is the Scala 2 way of making a union type.
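// For example (illustrative only, not part of the API), a context bound on InferrableType restricts
// a type parameter to the instances declared below:
//   def requireInferrable[T: InferrableType](): Unit = ()
//   requireInferrable[Geometry]()       // compiles
//   requireInferrable[java.util.Date]() // does not compile: no implicit InferrableType[java.util.Date]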
class InferrableType[T: TypeTag]
object InferrableType {
implicit val geometryInstance: InferrableType[Geometry] =
new InferrableType[Geometry] {}
implicit val geometryArrayInstance: InferrableType[Array[Geometry]] =
new InferrableType[Array[Geometry]] {}
implicit val javaDoubleInstance: InferrableType[java.lang.Double] =
new InferrableType[java.lang.Double] {}
implicit val javaIntegerInstance: InferrableType[java.lang.Integer] =
new InferrableType[java.lang.Integer] {}
implicit val javaLongInstance: InferrableType[java.lang.Long] =
new InferrableType[java.lang.Long] {}
implicit val doubleInstance: InferrableType[Double] =
new InferrableType[Double] {}
implicit val booleanInstance: InferrableType[Boolean] =
new InferrableType[Boolean] {}
implicit val booleanOptInstance: InferrableType[Option[Boolean]] =
new InferrableType[Option[Boolean]] {}
implicit val intInstance: InferrableType[Int] =
new InferrableType[Int] {}
implicit val longInstance: InferrableType[Long] =
new InferrableType[Long] {}
implicit val stringInstance: InferrableType[String] =
new InferrableType[String] {}
implicit val binaryInstance: InferrableType[Array[Byte]] =
new InferrableType[Array[Byte]] {}
implicit val intArrayInstance: InferrableType[Array[Int]] =
new InferrableType[Array[Int]] {}
implicit val javaIntArrayInstance: InferrableType[Array[java.lang.Integer]] =
new InferrableType[Array[java.lang.Integer]] {}
implicit val longArrayInstance: InferrableType[Array[Long]] =
new InferrableType[Array[Long]] {}
implicit val javaLongArrayInstance: InferrableType[Array[java.lang.Long]] =
new InferrableType[Array[java.lang.Long]] {}
implicit val doubleArrayInstance: InferrableType[Array[Double]] =
new InferrableType[Array[Double]] {}
implicit val javaDoubleListInstance: InferrableType[java.util.List[java.lang.Double]] =
new InferrableType[java.util.List[java.lang.Double]] {}
implicit val javaGeomListInstance: InferrableType[java.util.List[Geometry]] =
new InferrableType[java.util.List[Geometry]] {}
}
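/**
* Reflection-based helpers that map the Scala/Java types of a wrapped function onto Catalyst:
* argument extractors, result serializers and Spark SQL data types. Used by [[InferrableFunction.apply]].
*/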
object InferredTypes {
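/**
* Builds an extractor that evaluates a child expression against an input row and converts the result
* into the Java/Scala value expected by the wrapped function, e.g. deserializing a GeometryUDT value
* into a JTS [[Geometry]]. Types without special handling fall back to the expression's raw eval result.
*/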
def buildArgumentExtractor(t: Type): Expression => InternalRow => Any = {
if (t =:= typeOf[Geometry]) {
expr => input => expr.toGeometry(input)
} else if (t =:= typeOf[Array[Geometry]]) {
expr => input => expr.toGeometryArray(input)
} else if (InferredRasterExpression.isRasterType(t)) {
InferredRasterExpression.rasterExtractor
} else if (t =:= typeOf[Array[Double]]) {
expr => input => expr.eval(input).asInstanceOf[ArrayData].toDoubleArray()
} else if (t =:= typeOf[String]) {
expr => input => expr.asString(input)
} else if (t =:= typeOf[Array[Long]]) {
expr => input => expr.eval(input).asInstanceOf[ArrayData].toLongArray()
} else if (t =:= typeOf[Array[Int]]) {
expr => input => expr.eval(input).asInstanceOf[ArrayData] match {
case null => null
case arrayData: ArrayData => arrayData.toIntArray()
}
} else if (t =:= typeOf[java.util.List[Geometry]]) {
expr => input => expr.toGeometryList(input)
} else if (t =:= typeOf[java.util.List[java.lang.Double]]) {
expr => input => expr.toDoubleList(input)
} else {
expr => input => expr.eval(input)
}
}
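/**
* Builds a serializer that converts the wrapped function's return value into Catalyst's internal
* representation, e.g. a JTS [[Geometry]] into its serialized form or a [[String]] into a [[UTF8String]].
* Null results are passed through as null.
*/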
def buildSerializer(t: Type): Any => Any = {
if (t =:= typeOf[Geometry]) {
output =>
if (output != null) {
output.asInstanceOf[Geometry].toGenericArrayData
} else {
null
}
} else if (InferredRasterExpression.isRasterType(t)) {
InferredRasterExpression.rasterSerializer
} else if (t =:= typeOf[String]) {
output =>
if (output != null) {
UTF8String.fromString(output.asInstanceOf[String])
} else {
null
}
} else if (t =:= typeOf[Array[java.lang.Long]] || t =:= typeOf[Array[Long]] ||
t =:= typeOf[Array[Double]]) {
output =>
if (output != null) {
ArrayData.toArrayData(output)
} else {
null
}
} else if (t =:= typeOf[java.util.List[java.lang.Double]]) {
output =>
if (output != null) {
ArrayData.toArrayData(output.asInstanceOf[java.util.List[java.lang.Double]].map(elem => elem))
} else {
null
}
} else if (t =:= typeOf[Array[Geometry]] || t =:= typeOf[java.util.List[Geometry]]) {
output =>
// Handle both supported geometry collection return types. A java.util.List cannot be cast to
// Array[Geometry], so convert it through the imported `collection AsScalaIterable` conversion instead.
output match {
case null => null
case geometries: Array[Geometry] => ArrayData.toArrayData(geometries.map(_.toGenericArrayData))
case geometries: java.util.List[_] =>
ArrayData.toArrayData(geometries.asInstanceOf[java.util.List[Geometry]].map(_.toGenericArrayData))
}
} else if (InferredRasterExpression.isRasterArrayType(t)) {
InferredRasterExpression.rasterArraySerializer
} else if (t =:= typeOf[Option[Boolean]]) {
output =>
if (output != null) {
output.asInstanceOf[Option[Boolean]].orNull
} else {
null
}
} else {
output => output
}
}
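/**
* Maps a Scala/Java type of the wrapped function to the corresponding Spark SQL [[DataType]],
* throwing an [[IllegalArgumentException]] for types that cannot be inferred.
*/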
def inferSparkType(t: Type): DataType = {
if (t =:= typeOf[Geometry]) {
GeometryUDT
} else if (t =:= typeOf[Array[Geometry]] || t =:= typeOf[java.util.List[Geometry]]) {
DataTypes.createArrayType(GeometryUDT)
} else if (InferredRasterExpression.isRasterType(t)) {
InferredRasterExpression.rasterUDT
} else if (InferredRasterExpression.isRasterArrayType(t)) {
InferredRasterExpression.rasterUDTArray
} else if (t =:= typeOf[java.lang.Double]) {
DoubleType
} else if (t =:= typeOf[java.lang.Integer]) {
IntegerType
} else if (t =:= typeOf[Double]) {
DoubleType
} else if (t =:= typeOf[Int]) {
IntegerType
} else if (t =:= typeOf[Long] || t =:= typeOf[java.lang.Long]) {
LongType
} else if (t =:= typeOf[String]) {
StringType
} else if (t =:= typeOf[Array[Byte]]) {
BinaryType
} else if (t =:= typeOf[Array[Int]] || t =:= typeOf[Array[java.lang.Integer]]) {
DataTypes.createArrayType(IntegerType)
} else if (t =:= typeOf[Array[Long]] || t =:= typeOf[Array[java.lang.Long]]) {
DataTypes.createArrayType(LongType)
} else if (t =:= typeOf[Array[Double]] || t =:= typeOf[java.util.List[java.lang.Double]]) {
DataTypes.createArrayType(DoubleType)
} else if (t =:= typeOf[Option[Boolean]]) {
BooleanType
} else if (t =:= typeOf[Boolean]) {
BooleanType
} else {
throw new IllegalArgumentException(s"Cannot infer spark type for $t")
}
}
}
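/**
* A wrapped Java/Scala function together with everything [[InferredExpression]] needs to evaluate it:
* the Spark SQL input and return types, a serializer for the result, builders for the per-argument
* extractors, and a builder producing the row-level evaluator.
*/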
case class InferrableFunction(sparkInputTypes: Seq[AbstractDataType],
sparkReturnType: DataType,
serializer: Any => Any,
argExtractorBuilders: Seq[Expression => InternalRow => Any],
evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any) {
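// Pair each extractor builder with its child expression; surplus child expressions without a
// matching builder are dropped.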
def buildExtractors(expressions: Seq[Expression]): Array[InternalRow => Any] = {
argExtractorBuilders.zipAll(expressions, null, null).flatMap {
case (null, _) => None
case (builder, expr) => Some(builder(expr))
}.toArray
}
}
object InferrableFunction {
/**
* Infer input types and return type from a type tag, and construct builders for the argument extractors.
* @param typeTag Type tag of the function.
* @param evaluatorBuilder Builder for the evaluator.
* @return InferrableFunction.
*/
def apply(typeTag: TypeTag[_], evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any): InferrableFunction = {
val argTypes = typeTag.tpe.typeArgs.init
val returnType = typeTag.tpe.typeArgs.last
val sparkInputTypes: Seq[AbstractDataType] = argTypes.map(InferredTypes.inferSparkType)
val sparkReturnType: DataType = InferredTypes.inferSparkType(returnType)
val serializer = InferredTypes.buildSerializer(returnType)
val argExtractorBuilders = argTypes.map(InferredTypes.buildArgumentExtractor)
InferrableFunction(sparkInputTypes, sparkReturnType, serializer, argExtractorBuilders, evaluatorBuilder)
}
/**
* A variant for wrapping a binary function that allows its second argument to be null: the function
* is invoked whenever the first argument is non-null, and the result is null otherwise.
* @param f Function to be wrapped as a catalyst expression.
* @param typeTag Type tag of the function.
* @tparam R Return type of the function.
* @tparam A1 Type of the first argument.
* @tparam A2 Type of the second argument.
* @return InferrableFunction.
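*
* For example (a hypothetical wrapping, not an existing Sedona function):
* {{{
* InferrableFunction.allowRightNull((geom: Geometry, other: Geometry) =>
*   if (other == null) geom.isEmpty else geom.intersects(other))
* }}}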
*/
def allowRightNull[R, A1, A2](f: (A1, A2) => R)(implicit typeTag: TypeTag[(A1, A2) => R]): InferrableFunction = {
apply(typeTag, extractors => {
val func = f.asInstanceOf[(Any, Any) => Any]
val extractor1 = extractors(0)
val extractor2 = extractors(1)
input => {
val arg1 = extractor1(input)
val arg2 = extractor2(input)
if (arg1 != null) {
func(arg1, arg2)
} else {
null
}
}
})
}
}