[SEDONA-478] Make Sedona geometry functions and spatial join working without GeoTools (#1398)
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
index e3152e4..1d5c1ab 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
@@ -19,6 +19,7 @@
package org.apache.sedona.sql
import org.apache.sedona.sql.UDF.RasterUdafCatalog
+import org.apache.sedona.sql.utils.GeoToolsCoverageAvailability.{gridClassName, isGeoToolsAvailable}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.sedona_sql.UDT.RasterUdtRegistratorWrapper
import org.apache.spark.sql.{SparkSession, functions}
@@ -26,19 +27,6 @@
object RasterRegistrator {
val logger: Logger = LoggerFactory.getLogger(getClass)
- private val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"
-
- // Helper method to check if GridCoverage2D is available
- private def isGeoToolsAvailable: Boolean = {
- try {
- Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader)
- true
- } catch {
- case _: ClassNotFoundException =>
- logger.warn("Geotools was not found on the classpath. Raster operations will not be available.")
- false
- }
- }
def registerAll(sparkSession: SparkSession): Unit = {
if (isGeoToolsAvailable) {
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/utils/GeoToolsCoverageAvailability.scala b/spark/common/src/main/scala/org/apache/sedona/sql/utils/GeoToolsCoverageAvailability.scala
new file mode 100644
index 0000000..1d197c2
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/utils/GeoToolsCoverageAvailability.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.utils
+
+import org.apache.sedona.sql.RasterRegistrator.logger
+
+/**
+ * A helper object to check if GeoTools GridCoverage2D is usable at runtime.
+ *
+ * Raster support is optional: callers consult [[isGeoToolsAvailable]] before touching any
+ * code path that references GeoTools classes, so that geometry-only deployments keep working
+ * when the GeoTools jars are absent.
+ */
+object GeoToolsCoverageAvailability {
+  /** Fully qualified name of the GeoTools class used as the availability probe. */
+  val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"
+
+  /**
+   * `true` when [[gridClassName]] can be loaded and initialized, `false` otherwise.
+   *
+   * The probe runs on first access and a warning is logged when it fails.
+   */
+  lazy val isGeoToolsAvailable: Boolean = {
+    // The thread context class loader may be null; Class.forName would then search only the
+    // bootstrap loader and wrongly report GeoTools as missing. Fall back to our own loader.
+    val loader = Option(Thread.currentThread().getContextClassLoader)
+      .getOrElse(getClass.getClassLoader)
+    try {
+      Class.forName(gridClassName, true, loader)
+      true
+    } catch {
+      case _: ClassNotFoundException =>
+        logger.warn("Geotools was not found on the classpath. Raster operations will not be available.")
+        false
+      case e: LinkageError =>
+        // The probe class was found but could not be linked or initialized (for example a
+        // NoClassDefFoundError for one of its dependencies, or an ExceptionInInitializerError).
+        // Treat this like a missing GeoTools instead of failing every caller of this flag.
+        logger.warn("Geotools was found on the classpath but could not be loaded. Raster operations will not be available.", e)
+        false
+    }
+  }
+}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala
new file mode 100644
index 0000000..2d3349d
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
+import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.{RasterEnhancer, RasterInputExpressionEnhancer}
+import org.apache.spark.sql.types.{ArrayType, DataTypes, UserDefinedType}
+
+import scala.reflect.runtime.universe.{Type, typeOf}
+import org.geotools.coverage.grid.GridCoverage2D
+
+/**
+ * Raster-specific pieces of the expression type inference machinery.
+ *
+ * Everything that mentions `GridCoverage2D` directly lives here, so this object must only be
+ * referenced on code paths where GeoTools is known to be on the classpath.
+ */
+object InferrableRasterTypes {
+  /** Evidence that a single raster may be used as an inferred argument or return type. */
+  implicit val gridCoverage2DInstance: InferrableType[GridCoverage2D] =
+    new InferrableType[GridCoverage2D] {}
+
+  /** Evidence that an array of rasters may be used as an inferred argument or return type. */
+  implicit val gridCoverage2DArrayInstance: InferrableType[Array[GridCoverage2D]] =
+    new InferrableType[Array[GridCoverage2D]] {}
+
+  /** Whether `t` is exactly `GridCoverage2D`. */
+  def isRasterType(t: Type): Boolean = t =:= typeOf[GridCoverage2D]
+
+  /** Whether `t` is exactly `Array[GridCoverage2D]`. */
+  def isRasterArrayType(t: Type): Boolean = t =:= typeOf[Array[GridCoverage2D]]
+
+  /** Spark SQL data type used for a single raster value. */
+  val rasterUDT: UserDefinedType[_] = RasterUDT
+
+  /** Spark SQL data type used for an array of raster values. */
+  val rasterUDTArray: ArrayType = DataTypes.createArrayType(RasterUDT)
+
+  /** Evaluates `expr` against `input` and converts the result to a raster via `toRaster`. */
+  def rasterExtractor(expr: Expression)(input: InternalRow): Any = expr.toRaster(input)
+
+  /**
+   * Serializes a single raster result.
+   *
+   * @param output a `GridCoverage2D`, or `null`
+   * @return the result of the raster's `serialize` extension, or `null` for a `null` input
+   */
+  def rasterSerializer(output: Any): Any =
+    if (output == null) {
+      null
+    } else {
+      output.asInstanceOf[GridCoverage2D].serialize
+    }
+
+  /**
+   * Serializes an array of raster results into Spark `ArrayData`.
+   *
+   * Each raster is disposed right after it has been serialized.
+   *
+   * @param output an `Array[GridCoverage2D]`, or `null`
+   * @return `ArrayData` holding the serialized rasters, or `null` for a `null` input
+   */
+  def rasterArraySerializer(output: Any): Any =
+    if (output == null) {
+      null
+    } else {
+      val payloads = output.asInstanceOf[Array[GridCoverage2D]].map { coverage =>
+        val payload = coverage.serialize
+        coverage.dispose(true)
+        payload
+      }
+      ArrayData.toArrayData(payloads)
+    }
+}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
index 28096c4..6b9f89c 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
@@ -22,13 +22,11 @@
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.sedona_sql.expressions.implicits._
-import org.apache.spark.sql.sedona_sql.expressions.raster.implicits._
-import org.geotools.coverage.grid.GridCoverage2D
import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
import scala.reflect.runtime.universe.TypeTag
@@ -75,14 +73,10 @@
// This is a compile time type shield for the types we are able to infer. Anything
// other than these types will cause a compilation error. This is the Scala
// 2 way of making a union type.
-sealed class InferrableType[T: TypeTag]
+class InferrableType[T: TypeTag]
object InferrableType {
implicit val geometryInstance: InferrableType[Geometry] =
new InferrableType[Geometry] {}
- implicit val gridCoverage2DInstance: InferrableType[GridCoverage2D] =
- new InferrableType[GridCoverage2D] {}
- implicit val gridCoverage2DArrayInstance: InferrableType[Array[GridCoverage2D]] =
- new InferrableType[Array[GridCoverage2D]] {}
implicit val geometryArrayInstance: InferrableType[Array[Geometry]] =
new InferrableType[Array[Geometry]] {}
implicit val javaDoubleInstance: InferrableType[java.lang.Double] =
@@ -127,8 +121,8 @@
expr => input => expr.toGeometry(input)
} else if (t =:= typeOf[Array[Geometry]]) {
expr => input => expr.toGeometryArray(input)
- } else if (t =:= typeOf[GridCoverage2D]) {
- expr => input => expr.toRaster(input)
+ } else if (InferredRasterExpression.isRasterType(t)) {
+ InferredRasterExpression.rasterExtractor
} else if (t =:= typeOf[Array[Double]]) {
expr => input => expr.eval(input).asInstanceOf[ArrayData].toDoubleArray()
} else if (t =:= typeOf[String]) {
@@ -156,14 +150,8 @@
} else {
null
}
- } else if (t =:= typeOf[GridCoverage2D]) {
- output => {
- if (output != null) {
- output.asInstanceOf[GridCoverage2D].serialize
- } else {
- null
- }
- }
+ } else if (InferredRasterExpression.isRasterType(t)) {
+ InferredRasterExpression.rasterSerializer
} else if (t =:= typeOf[String]) {
output =>
if (output != null) {
@@ -194,19 +182,8 @@
} else {
null
}
- } else if (t =:= typeOf[Array[GridCoverage2D]]) {
- output =>
- if (output != null) {
- val rasters = output.asInstanceOf[Array[GridCoverage2D]]
- val serialized = rasters.map { raster =>
- val serialized = raster.serialize
- raster.dispose(true)
- serialized
- }
- ArrayData.toArrayData(serialized)
- } else {
- null
- }
+ } else if (InferredRasterExpression.isRasterArrayType(t)) {
+ InferredRasterExpression.rasterArraySerializer
} else if (t =:= typeOf[Option[Boolean]]) {
output =>
if (output != null) {
@@ -224,10 +201,10 @@
GeometryUDT
} else if (t =:= typeOf[Array[Geometry]] || t =:= typeOf[java.util.List[Geometry]]) {
DataTypes.createArrayType(GeometryUDT)
- } else if (t =:= typeOf[GridCoverage2D]) {
- RasterUDT
- } else if (t =:= typeOf[Array[GridCoverage2D]]) {
- DataTypes.createArrayType(RasterUDT)
+ } else if (InferredRasterExpression.isRasterType(t)) {
+ InferredRasterExpression.rasterUDT
+ } else if (InferredRasterExpression.isRasterArrayType(t)) {
+ InferredRasterExpression.rasterUDTArray
} else if (t =:= typeOf[java.lang.Double]) {
DoubleType
} else if (t =:= typeOf[java.lang.Integer]) {
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredRasterExpression.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredRasterExpression.scala
new file mode 100644
index 0000000..9c6875a
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredRasterExpression.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.sedona.sql.utils.GeoToolsCoverageAvailability.isGeoToolsAvailable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.types.{ArrayType, UserDefinedType}
+
+import scala.reflect.runtime.universe.{Type, typeOf}
+
+/**
+ * GeoTools-safe facade over [[InferrableRasterTypes]].
+ *
+ * Every member first consults `isGeoToolsAvailable` and only delegates to
+ * [[InferrableRasterTypes]] when it is `true`; otherwise it yields `false`, `null`, or a
+ * function that always returns `null`.
+ */
+object InferredRasterExpression {
+  /** `true` only when GeoTools is available and `t` is the raster type. */
+  def isRasterType(t: Type): Boolean =
+    if (isGeoToolsAvailable) InferrableRasterTypes.isRasterType(t) else false
+
+  /** `true` only when GeoTools is available and `t` is the raster array type. */
+  def isRasterArrayType(t: Type): Boolean =
+    if (isGeoToolsAvailable) InferrableRasterTypes.isRasterArrayType(t) else false
+
+  /** Data type for a single raster, or `null` when GeoTools is unavailable. */
+  def rasterUDT: UserDefinedType[_] =
+    if (!isGeoToolsAvailable) {
+      null
+    } else {
+      InferrableRasterTypes.rasterUDT
+    }
+
+  /** Data type for an array of rasters, or `null` when GeoTools is unavailable. */
+  def rasterUDTArray: ArrayType =
+    if (!isGeoToolsAvailable) {
+      null
+    } else {
+      InferrableRasterTypes.rasterUDTArray
+    }
+
+  /** Extracts a raster from an expression and a row; always `null` without GeoTools. */
+  val rasterExtractor: Expression => InternalRow => Any =
+    if (!isGeoToolsAvailable) {
+      _ => _ => null
+    } else {
+      InferrableRasterTypes.rasterExtractor
+    }
+
+  /** Serializes a single raster result; always `null` without GeoTools. */
+  val rasterSerializer: Any => Any =
+    if (!isGeoToolsAvailable) {
+      (_: Any) => null
+    } else {
+      InferrableRasterTypes.rasterSerializer
+    }
+
+  /** Serializes an array of raster results; always `null` without GeoTools. */
+  val rasterArraySerializer: Any => Any =
+    if (!isGeoToolsAvailable) {
+      (_: Any) => null
+    } else {
+      InferrableRasterTypes.rasterArraySerializer
+    }
+}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
index 85719ce..a8baca9 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
@@ -22,13 +22,10 @@
import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
-import org.apache.spark.sql.types.{ByteType, DataTypes}
+import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.{Geometry, GeometryFactory, Point}
-import java.util
-
object implicits {
implicit class InputExpressionEnhancer(inputExpression: Expression) {
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/GeometryFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/GeometryFunctions.scala
index fa8390a..e13e81d 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/GeometryFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/GeometryFunctions.scala
@@ -22,6 +22,7 @@
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
case class RS_ConvexHull(inputExpressions: Seq[Expression]) extends InferredExpression(GeometryFunctions.convexHull _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
index bd30844..42fb3fd 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
@@ -23,6 +23,7 @@
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
/// Calculate Normalized Difference between two bands
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctionEditors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctionEditors.scala
index 10ea368..fc87706 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctionEditors.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctionEditors.scala
@@ -21,6 +21,7 @@
import org.apache.sedona.common.raster.PixelFunctionEditors
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
case class RS_SetValues(inputExpressions: Seq[Expression]) extends InferredExpression(
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala
index f5499f2..22315ed 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala
@@ -26,6 +26,7 @@
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, DoubleType, IntegerType, StructType}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala
index f0039c6..b7ffbbe 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala
@@ -21,6 +21,7 @@
import org.apache.sedona.common.raster.RasterAccessors
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
case class RS_NumBands(inputExpressions: Seq[Expression]) extends InferredExpression(RasterAccessors.numBands _) {
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala
index 11b4152..b64a9b5 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala
@@ -25,6 +25,7 @@
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import org.geotools.coverage.grid.GridCoverage2D
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandEditors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandEditors.scala
index de782a8..de7f57d 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandEditors.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandEditors.scala
@@ -21,6 +21,7 @@
import org.apache.sedona.common.raster.RasterBandEditors
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
case class RS_SetBandNoDataValue(inputExpressions: Seq[Expression]) extends InferredExpression(
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
index ae6c9e1..1e4a6a8 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
@@ -25,6 +25,7 @@
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.{RasterEnhancer, RasterInputExpressionEnhancer}
import org.apache.spark.sql.types.{ArrayType, BooleanType, Decimal, IntegerType, NullType, StructType}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala
index 3b13e03..db77310 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala
@@ -21,6 +21,7 @@
import org.apache.sedona.common.raster.RasterEditors
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
case class RS_SetSRID(inputExpressions: Seq[Expression]) extends InferredExpression(RasterEditors.setSrid _) {
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala
index 4d9e375..07a0673 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala
@@ -21,6 +21,7 @@
import org.apache.sedona.common.raster.RasterOutputs
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
case class RS_AsGeoTiff(inputExpressions: Seq[Expression])
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index e34b1b8..7086812 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -34,7 +34,7 @@
leftShapeExpr: Expression,
rightRdd: RDD[UnsafeRow],
rightShapeExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) = {
- if (leftShapeExpr.dataType.acceptsType(RasterUDT) || rightShapeExpr.dataType.acceptsType(RasterUDT)) {
+ if (leftShapeExpr.dataType.isInstanceOf[RasterUDT] || rightShapeExpr.dataType.isInstanceOf[RasterUDT]) {
(toWGS84EnvelopeRDD(leftRdd, leftShapeExpr),
toWGS84EnvelopeRDD(rightRdd, rightShapeExpr))
} else {
@@ -60,7 +60,7 @@
// transformation for both sides. We use expanded WGS84 envelope as the joined geometries and perform a
// coarse-grained spatial join.
val spatialRdd = new SpatialRDD[Geometry]
- val wgs84EnvelopeRdd = if (shapeExpression.dataType.acceptsType(RasterUDT)) {
+ val wgs84EnvelopeRdd = if (shapeExpression.dataType.isInstanceOf[RasterUDT]) {
rdd.map { row =>
val raster = RasterSerializer.deserialize(shapeExpression.eval(row).asInstanceOf[Array[Byte]])
val shape = JoinedGeometryRaster.rasterToWGS84Envelope(raster)