[SPARK-54169][GEO][SQL] Introduce Geography and Geometry types in Arrow writer
### What changes were proposed in this pull request?
Add Arrow serialization/deserialization support for `Geography` and `Geometry` types.
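As a minimal sketch of the mapping introduced here (assumptions: it is run from test code with access to the internal `org.apache.spark.sql.util.ArrowUtils` helpers, and the positional arguments match the internal `toArrowField` signature), a `GEOMETRY(4326)` column is encoded as an Arrow struct with an `srid` int child and a `wkb` binary child, where the `wkb` child carries metadata that lets `fromArrowField` map it back to `GeometryType`:
```scala
import org.apache.spark.sql.types.GeometryType
import org.apache.spark.sql.util.ArrowUtils

// GEOMETRY(4326) -> struct<srid: int not null, wkb: binary not null>, with the wkb child
// tagged as {"geometry": "true", "srid": "4326"} so the struct is recognized on read.
val field = ArrowUtils.toArrowField("g", GeometryType(4326), true, "UTC", false)

// Reading the field back recovers the original Spark SQL type.
assert(ArrowUtils.fromArrowField(field) == GeometryType(4326))
```
`GEOGRAPHY` uses the same struct layout, with a `geography` metadata tag on the `wkb` child instead.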
### Why are the changes needed?
This enables geospatial types for clients (Spark Connect, Thrift Server, etc.) that consume result sets in Arrow format.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added unit tests (see the roundtrip sketch after this list):
- `ArrowUtilsSuite`
- `ArrowWriterSuite`
- `ArrowEncoderSuite`
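These suites exercise a write/read roundtrip roughly like the sketch below (illustrative values; the WKB literal and the use of the internal `Geometry` and `ArrowWriter` APIs mirror the added tests, so this is a sketch rather than a definitive snippet):
```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.Geometry
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.types.{GeometryType, StructType}
import org.apache.spark.sql.vectorized.ArrowColumnVector

// Write a single GEOMETRY value through the Arrow writer, then read it back.
val schema = new StructType().add("g", GeometryType(4326))
val writer = ArrowWriter.create(schema, "UTC")

val wkb = "010100000000000000000031400000000000001C40" // POINT(17 7) in WKB hex
  .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
val expected = Geometry.fromWkb(wkb, 4326).getValue

writer.write(InternalRow(expected))
writer.finish()

// The geometry column round-trips as a struct vector; getGeometry reassembles the value.
val reader = new ArrowColumnVector(writer.root.getFieldVectors.get(0))
assert(reader.getGeometry(0).getBytes.sameElements(expected.getBytes))
writer.root.close()
```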
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52863 from uros-db/geo-arrow-serde.
Authored-by: Uros Bojanic <uros.bojanic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
diff --git a/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala b/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala
new file mode 100644
index 0000000..44bdffe
--- /dev/null
+++ b/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/* The MaybeNull class is a utility that introduces controlled nullability into a sequence
+ * of invocations. It is designed to return a `null` value at a specified interval while returning
+ * the provided value for all other invocations.
+ */
+case class MaybeNull(interval: Int) {
+ assert(interval > 1)
+ private var invocations = 0
+ def apply[T](value: T): T = {
+ val result = if (invocations % interval == 0) {
+ null.asInstanceOf[T]
+ } else {
+ value
+ }
+ invocations += 1
+ result
+ }
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 6caabf2..23d8a0b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -143,6 +143,43 @@
largeVarTypes)).asJava)
case udt: UserDefinedType[_] =>
toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
+ case g: GeometryType =>
+ val fieldType =
+ new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
+
+ // WKB field is tagged with additional metadata so we can identify that the arrow
+ // struct actually represents a geometry schema.
+ val wkbFieldType = new FieldType(
+ false,
+ toArrowType(BinaryType, timeZoneId, largeVarTypes),
+ null,
+ Map("geometry" -> "true", "srid" -> g.srid.toString).asJava)
+
+ new Field(
+ name,
+ fieldType,
+ Seq(
+ toArrowField("srid", IntegerType, false, timeZoneId, largeVarTypes),
+ new Field("wkb", wkbFieldType, Seq.empty[Field].asJava)).asJava)
+
+ case g: GeographyType =>
+ val fieldType =
+ new FieldType(nullable, ArrowType.Struct.INSTANCE, null, null)
+
+ // WKB field is tagged with additional metadata so we can identify that the arrow
+ // struct actually represents a geography schema.
+ val wkbFieldType = new FieldType(
+ false,
+ toArrowType(BinaryType, timeZoneId, largeVarTypes),
+ null,
+ Map("geography" -> "true", "srid" -> g.srid.toString).asJava)
+
+ new Field(
+ name,
+ fieldType,
+ Seq(
+ toArrowField("srid", IntegerType, false, timeZoneId, largeVarTypes),
+ new Field("wkb", wkbFieldType, Seq.empty[Field].asJava)).asJava)
case _: VariantType =>
val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
// The metadata field is tagged with additional metadata so we can identify that the arrow
@@ -175,6 +212,26 @@
}
}
+ def isGeometryField(field: Field): Boolean = {
+ assert(field.getType.isInstanceOf[ArrowType.Struct])
+ field.getChildren.asScala
+ .map(_.getName)
+ .asJava
+ .containsAll(Seq("wkb", "srid").asJava) && field.getChildren.asScala.exists { child =>
+ child.getName == "wkb" && child.getMetadata.getOrDefault("geometry", "false") == "true"
+ }
+ }
+
+ def isGeographyField(field: Field): Boolean = {
+ assert(field.getType.isInstanceOf[ArrowType.Struct])
+ field.getChildren.asScala
+ .map(_.getName)
+ .asJava
+ .containsAll(Seq("wkb", "srid").asJava) && field.getChildren.asScala.exists { child =>
+ child.getName == "wkb" && child.getMetadata.getOrDefault("geography", "false") == "true"
+ }
+ }
+
def fromArrowField(field: Field): DataType = {
field.getType match {
case _: ArrowType.Map =>
@@ -188,6 +245,26 @@
ArrayType(elementType, containsNull = elementField.isNullable)
case ArrowType.Struct.INSTANCE if isVariantField(field) =>
VariantType
+ case ArrowType.Struct.INSTANCE if isGeometryField(field) =>
+ // We expect the type metadata to be associated with the wkb field.
+ val metadataField =
+ field.getChildren.asScala.filter { child => child.getName == "wkb" }.head
+ val srid = metadataField.getMetadata.get("srid").toInt
+ if (srid == GeometryType.MIXED_SRID) {
+ GeometryType("ANY")
+ } else {
+ GeometryType(srid)
+ }
+ case ArrowType.Struct.INSTANCE if isGeographyField(field) =>
+ // We expect the type metadata to be associated with the wkb field.
+ val metadataField =
+ field.getChildren.asScala.filter { child => child.getName == "wkb" }.head
+ val srid = metadataField.getMetadata.get("srid").toInt
+ if (srid == GeographyType.MIXED_SRID) {
+ GeographyType("ANY")
+ } else {
+ GeographyType(srid)
+ }
case ArrowType.Struct.INSTANCE =>
val fields = field.getChildren().asScala.map { child =>
val dt = fromArrowField(child)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
index 3cf4b84..0a9942c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
@@ -111,6 +111,10 @@
return toPhysVal(Geometry.fromWkb(wkb));
}
+ public static GeometryVal stGeomFromWKB(byte[] wkb, int srid) {
+ return toPhysVal(Geometry.fromWkb(wkb, srid));
+ }
+
// ST_SetSrid
public static GeographyVal stSetSrid(GeographyVal geo, int srid) {
// We only allow setting the SRID to geographic values.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index 66116d7..019bc25 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -25,9 +25,12 @@
import org.apache.spark.SparkUnsupportedOperationException;
import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.sql.catalyst.util.STUtils;
import org.apache.spark.sql.util.ArrowUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
import org.apache.spark.unsafe.types.UTF8String;
/**
@@ -146,6 +149,30 @@
super(type);
}
+ @Override
+ public GeographyVal getGeography(int rowId) {
+ if (isNullAt(rowId)) return null;
+
+ GeographyType gt = (GeographyType) this.type;
+ int srid = getChild(0).getInt(rowId);
+ byte[] bytes = getChild(1).getBinary(rowId);
+ gt.assertSridAllowedForType(srid);
+ // TODO(GEO-602): Geography does not yet support different SRIDs; once it does,
+ // we need to update this.
+ return (bytes == null) ? null : STUtils.stGeogFromWKB(bytes);
+ }
+
+ @Override
+ public GeometryVal getGeometry(int rowId) {
+ if (isNullAt(rowId)) return null;
+
+ GeometryType gt = (GeometryType) this.type;
+ int srid = getChild(0).getInt(rowId);
+ byte[] bytes = getChild(1).getBinary(rowId);
+ gt.assertSridAllowedForType(srid);
+ return (bytes == null) ? null : STUtils.stGeomFromWKB(bytes, srid);
+ }
+
public ArrowColumnVector(ValueVector vector) {
this(ArrowUtils.fromArrowField(vector.getField()));
initAccessor(vector);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 275fece..8d68e74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -24,6 +24,7 @@
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.STUtils
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
@@ -92,6 +93,16 @@
createFieldWriter(vector.getChildByOrdinal(ordinal))
}
new StructWriter(vector, children.toArray)
+ case (dt: GeometryType, vector: StructVector) =>
+ val children = (0 until vector.size()).map { ordinal =>
+ createFieldWriter(vector.getChildByOrdinal(ordinal))
+ }
+ new GeometryWriter(dt, vector, children.toArray)
+ case (dt: GeographyType, vector: StructVector) =>
+ val children = (0 until vector.size()).map { ordinal =>
+ createFieldWriter(vector.getChildByOrdinal(ordinal))
+ }
+ new GeographyWriter(dt, vector, children.toArray)
case (dt, _) =>
throw ExecutionErrors.unsupportedDataTypeError(dt)
}
@@ -446,6 +457,42 @@
}
}
+private[arrow] class GeographyWriter(
+ dt: GeographyType,
+ valueVector: StructVector,
+ children: Array[ArrowFieldWriter]) extends StructWriter(valueVector, children) {
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setIndexDefined(count)
+
+ val geom = STUtils.deserializeGeog(input.getGeography(ordinal), dt)
+ val bytes = geom.getBytes
+ val srid = geom.getSrid
+
+ val row = InternalRow(srid, bytes)
+ children(0).write(row, 0)
+ children(1).write(row, 1)
+ }
+}
+
+private[arrow] class GeometryWriter(
+ dt: GeometryType,
+ valueVector: StructVector,
+ children: Array[ArrowFieldWriter]) extends StructWriter(valueVector, children) {
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setIndexDefined(count)
+
+ val geom = STUtils.deserializeGeom(input.getGeometry(ordinal), dt)
+ val bytes = geom.getBytes
+ val srid = geom.getSrid
+
+ val row = InternalRow(srid, bytes)
+ children(0).write(row, 0)
+ children(1).write(row, 1)
+ }
+}
+
private[arrow] class MapWriter(
val valueVector: MapVector,
val structVector: StructVector,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
index 7124c94..8011e69 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
@@ -49,6 +49,10 @@
roundtrip(BinaryType)
roundtrip(DecimalType.SYSTEM_DEFAULT)
roundtrip(DateType)
+ roundtrip(GeometryType("ANY"))
+ roundtrip(GeometryType(4326))
+ roundtrip(GeographyType("ANY"))
+ roundtrip(GeographyType(4326))
roundtrip(YearMonthIntervalType())
roundtrip(DayTimeIntervalType())
checkError(
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index b29d73b..bc840df 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -46,7 +46,7 @@
import org.apache.spark.sql.connect.test.ConnectFunSuite
import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
import org.apache.spark.unsafe.types.VariantVal
-import org.apache.spark.util.SparkStringUtils
+import org.apache.spark.util.{MaybeNull, SparkStringUtils}
/**
* Tests for encoding external data to and from arrow.
@@ -218,20 +218,6 @@
}
}
- private case class MaybeNull(interval: Int) {
- assert(interval > 1)
- private var invocations = 0
- def apply[T](value: T): T = {
- val result = if (invocations % interval == 0) {
- null.asInstanceOf[T]
- } else {
- value
- }
- invocations += 1
- result
- }
- }
-
private def javaBigDecimal(i: Int): java.math.BigDecimal = {
javaBigDecimal(i, DecimalType.DEFAULT_SCALE)
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 6b9d08f..a5b5c39 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -51,6 +51,13 @@
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
+ <artifactId>spark-common-utils_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <classifier>tests</classifier>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index 99d2455..2c0c049 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -17,16 +17,23 @@
package org.apache.spark.sql.execution.arrow
+import scala.jdk.CollectionConverters._
+
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
import org.apache.spark.sql.YearUDT
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder}
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.util.{Geography => InternalGeography, Geometry => InternalGeometry}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, GeographyVal, GeometryVal, UTF8String}
+import org.apache.spark.util.MaybeNull
class ArrowWriterSuite extends SparkFunSuite {
@@ -52,8 +59,16 @@
}
writer.finish()
+ val dataModified = data.map { datum =>
+ dt match {
+ case _: GeometryType => datum.asInstanceOf[GeometryVal].getBytes
+ case _: GeographyType => datum.asInstanceOf[GeographyVal].getBytes
+ case _ => datum
+ }
+ }
+
val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
- data.zipWithIndex.foreach {
+ dataModified.zipWithIndex.foreach {
case (null, rowId) => assert(reader.isNullAt(rowId))
case (datum, rowId) =>
val value = datatype match {
@@ -74,12 +89,31 @@
case _: YearMonthIntervalType => reader.getInt(rowId)
case _: DayTimeIntervalType => reader.getLong(rowId)
case CalendarIntervalType => reader.getInterval(rowId)
+ case _: GeometryType => reader.getGeometry(rowId).getBytes
+ case _: GeographyType => reader.getGeography(rowId).getBytes
}
assert(value === datum)
}
writer.root.close()
}
+
+ val wkbs = Seq("010100000000000000000031400000000000001c40",
+ "010100000000000000000034400000000000003540")
+ .map { x =>
+ x.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+ }
+
+ val geographies = wkbs.map(x => InternalGeography.fromWkb(x, 4326).getValue)
+ val geometries = wkbs.map(x => InternalGeometry.fromWkb(x, 0).getValue)
+ val mixedGeometries = wkbs.zip(Seq(0, 4326)).map {
+ case (g, srid) => InternalGeometry.fromWkb(g, srid).getValue
+ }
+
+ check(GeometryType(0), geometries)
+ check(GeographyType(4326), geographies)
+ check(GeometryType("ANY"), mixedGeometries)
+ check(GeographyType("ANY"), geographies)
check(BooleanType, Seq(true, null, false))
check(ByteType, Seq(1.toByte, 2.toByte, null, 4.toByte))
check(ShortType, Seq(1.toShort, 2.toShort, null, 4.toShort))
@@ -110,6 +144,245 @@
check(new YearUDT, Seq(2020, 2021, null, 2022))
}
+ test("nested geographies") {
+ def check(
+ dt: StructType,
+ data: Seq[InternalRow]): Unit = {
+ val writer = ArrowWriter.create(dt.asInstanceOf[StructType], "UTC")
+
+ // Write data to arrow.
+ data.toSeq.foreach { datum =>
+ writer.write(datum)
+ }
+ writer.finish()
+
+ // Create arrow vector readers.
+ val vectors = writer.root.getFieldVectors.asScala
+ .map { new ArrowColumnVector(_) }.toArray.asInstanceOf[Array[ColumnVector]]
+
+ val batch = new ColumnarBatch(vectors, writer.root.getRowCount.toInt)
+
+ data.zipWithIndex.foreach { case (datum, i) =>
+ // Read data from arrow.
+ val internalRow = batch.getRow(i)
+
+ // All nullable results must first be checked for null.
+ if (datum.getStruct(0, 4) == null || internalRow.getStruct(0, 4) == null) {
+ assert(datum.getStruct(0, 4) == null && internalRow.getStruct(0, 4) == null)
+ } else {
+ val expectedStruct = datum.getStruct(0, 4)
+ val actualStruct = internalRow.getStruct(0, 4)
+ assert(expectedStruct.getInt(0) === actualStruct.getInt(0))
+ assert(expectedStruct.getInt(2) === actualStruct.getInt(2))
+
+ if (expectedStruct.getGeography(1) == null ||
+ actualStruct.getGeography(1) == null) {
+ assert(expectedStruct.getGeography(1) == null && actualStruct.getGeography(1) == null)
+ } else {
+ assert(expectedStruct.getGeography(1).getBytes ===
+ actualStruct.getGeography(1).getBytes)
+ }
+ if (expectedStruct.getGeography(3) == null ||
+ actualStruct.getGeography(3) == null) {
+ assert(expectedStruct.getGeography(3) == null && actualStruct.getGeography(3) == null)
+ } else {
+ assert(expectedStruct.getGeography(3).getBytes ===
+ actualStruct.getGeography(3).getBytes)
+ }
+
+ if (datum.getArray(1) == null ||
+ internalRow.getArray(1) == null) {
+ assert(internalRow.getArray(1) == null && datum.getArray(1) == null)
+ } else {
+ internalRow.getArray(1).toSeq[GeographyVal](GeographyType(4326))
+ .zip(datum.getArray(1).toSeq[GeographyVal](GeographyType(4326))).foreach {
+ case (actual, expected) =>
+ assert(actual.getBytes === expected.getBytes)
+ }
+ }
+
+ if (datum.getMap(2) == null ||
+ internalRow.getMap(2) == null) {
+ assert(internalRow.getMap(2) == null && datum.getMap(2) == null)
+ } else {
+ assert(internalRow.getMap(2).keyArray().toSeq(StringType) ===
+ datum.getMap(2).keyArray().toSeq(StringType))
+ internalRow.getMap(2).valueArray().toSeq[GeographyVal](GeographyType("ANY"))
+ .zip(datum.getMap(2).valueArray().toSeq[GeographyVal](GeographyType("ANY"))).foreach {
+ case (actual, expected) =>
+ assert((actual == null && expected == null) ||
+ actual.getBytes === expected.getBytes)
+ }
+ }
+ }
+ }
+
+ writer.root.close()
+ }
+
+ val point1 = "010100000000000000000031400000000000001C40"
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+ val point2 = "010100000000000000000035400000000000001E40"
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+
+ val schema = new StructType()
+ .add(
+ "s",
+ new StructType()
+ .add("i1", "int")
+ .add("g0", "geography(4326)")
+ .add("i2", "int")
+ .add("g1", "geography(4326)"))
+ .add("a", "array<geography(4326)>")
+ .add("m", "map<string, geography(ANY)>")
+
+ val maybeNull5 = MaybeNull(5)
+ val maybeNull7 = MaybeNull(7)
+ val maybeNull11 = MaybeNull(11)
+ val maybeNull13 = MaybeNull(13)
+ val maybeNull17 = MaybeNull(17)
+
+ val nestedGeographySerializer = ExpressionEncoder(toRowEncoder(schema)).createSerializer()
+ val data = Iterator
+ .tabulate(100)(i =>
+ nestedGeographySerializer.apply(
+ (Row(
+ maybeNull5(
+ Row(
+ i,
+ maybeNull7(org.apache.spark.sql.types.Geography.fromWKB(point1)),
+ i + 1,
+ maybeNull11(org.apache.spark.sql.types.Geography.fromWKB(point2, 4326)))),
+ maybeNull7((0 until 10).map(j =>
+ org.apache.spark.sql.types.Geography.fromWKB(point2, 4326))),
+ maybeNull13(
+ Map((i.toString, maybeNull17(
+ org.apache.spark.sql.types.Geography.fromWKB(point1, 4326)))))))))
+ .map(_.copy()).toSeq
+
+ check(schema, data)
+ }
+
+ test("nested geometries") {
+ def check(
+ dt: StructType,
+ data: Seq[InternalRow]): Unit = {
+ val writer = ArrowWriter.create(dt.asInstanceOf[StructType], "UTC")
+
+ // Write data to arrow.
+ data.toSeq.foreach { datum =>
+ writer.write(datum)
+ }
+ writer.finish()
+
+ // Create arrow vector readers.
+ val vectors = writer.root.getFieldVectors.asScala
+ .map { new ArrowColumnVector(_) }.toArray.asInstanceOf[Array[ColumnVector]]
+
+ val batch = new ColumnarBatch(vectors, writer.root.getRowCount.toInt)
+ data.zipWithIndex.foreach { case (datum, i) =>
+ // Read data from arrow.
+ val internalRow = batch.getRow(i)
+
+ // All nullable results must first be checked for null.
+ if (datum.getStruct(0, 4) == null || internalRow.getStruct(0, 4) == null) {
+ assert(datum.getStruct(0, 4) == null && internalRow.getStruct(0, 4) == null)
+ } else {
+ val expectedStruct = datum.getStruct(0, 4)
+ val actualStruct = internalRow.getStruct(0, 4)
+ assert(expectedStruct.getInt(0) === actualStruct.getInt(0))
+ assert(expectedStruct.getInt(2) === actualStruct.getInt(2))
+
+ if (expectedStruct.getGeometry(1) == null ||
+ actualStruct.getGeometry(1) == null) {
+ assert(expectedStruct.getGeometry(1) == null && actualStruct.getGeometry(1) == null)
+ } else {
+ assert(expectedStruct.getGeometry(1).getBytes ===
+ actualStruct.getGeometry(1).getBytes)
+ }
+ if (expectedStruct.getGeometry(3) == null ||
+ actualStruct.getGeometry(3) == null) {
+ assert(expectedStruct.getGeometry(3) == null && actualStruct.getGeometry(3) == null)
+ } else {
+ assert(expectedStruct.getGeometry(3).getBytes ===
+ actualStruct.getGeometry(3).getBytes)
+ }
+
+ if (datum.getArray(1) == null ||
+ internalRow.getArray(1) == null) {
+ assert(internalRow.getArray(1) == null && datum.getArray(1) == null)
+ } else {
+ internalRow.getArray(1).toSeq[GeometryVal](GeometryType(0))
+ .zip(datum.getArray(1).toSeq[GeometryVal](GeometryType(0))).foreach {
+ case (actual, expected) =>
+ assert(actual.getBytes === expected.getBytes)
+ }
+ }
+
+ if (datum.getMap(2) == null ||
+ internalRow.getMap(2) == null) {
+ assert(internalRow.getMap(2) == null && datum.getMap(2) == null)
+ } else {
+ assert(internalRow.getMap(2).keyArray().toSeq(StringType) ===
+ datum.getMap(2).keyArray().toSeq(StringType))
+ internalRow.getMap(2).valueArray().toSeq[GeometryVal](GeometryType("ANY"))
+ .zip(datum.getMap(2).valueArray().toSeq[GeometryVal](GeometryType("ANY"))).foreach {
+ case (actual, expected) =>
+ assert((actual == null && expected == null) ||
+ actual.getBytes === expected.getBytes)
+ }
+ }
+ }
+ }
+
+ writer.root.close()
+ }
+
+ val point1 = "010100000000000000000031400000000000001C40"
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+ val point2 = "010100000000000000000035400000000000001E40"
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+
+ val schema = new StructType()
+ .add(
+ "s",
+ new StructType()
+ .add("i1", "int")
+ .add("g0", "geometry(0)")
+ .add("i2", "int")
+ .add("g4326", "geometry(4326)"))
+ .add("a", "array<geometry(0)>")
+ .add("m", "map<string, geometry(ANY)>")
+
+ val maybeNull5 = MaybeNull(5)
+ val maybeNull7 = MaybeNull(7)
+ val maybeNull11 = MaybeNull(11)
+ val maybeNull13 = MaybeNull(13)
+ val maybeNull17 = MaybeNull(17)
+
+ val nestedGeometrySerializer = ExpressionEncoder(toRowEncoder(schema)).createSerializer()
+ val data = Iterator
+ .tabulate(100) { i =>
+ val mixedSrid = if (i % 2 == 0) 0 else 4326
+
+ nestedGeometrySerializer.apply(
+ (Row(
+ maybeNull5(
+ Row(
+ i,
+ maybeNull7(org.apache.spark.sql.types.Geometry.fromWKB(point1, 0)),
+ i + 1,
+ maybeNull11(org.apache.spark.sql.types.Geometry.fromWKB(point2, 4326)))),
+ maybeNull7((0 until 10).map(j =>
+ org.apache.spark.sql.types.Geometry.fromWKB(point2, 0))),
+ maybeNull13(
+ Map((i.toString, maybeNull17(
+ org.apache.spark.sql.types.Geometry.fromWKB(point1, mixedSrid))))))))
+ }.map(_.copy()).toSeq
+
+ check(schema, data)
+ }
+
test("get multiple") {
def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = {
val datatype = dt match {