[SPARK-54169][GEO][SQL] Introduce Geography and Geometry types in Arrow writer

### What changes were proposed in this pull request?
Add Arrow serialization/deserialization support for `Geography` and `Geometry` types.

### Why are the changes needed?
This enables clients that consume result sets in Arrow format (Spark Connect / Thrift Server / etc.) to work with geospatial types; a sketch of the resulting Arrow layout is shown below.
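
A `GEOMETRY`/`GEOGRAPHY` column is encoded as an Arrow struct with a non-nullable `srid` (int) child and a non-nullable `wkb` (binary) child; the `wkb` child carries extra field metadata so readers can distinguish the struct from a plain struct column. A minimal sketch of the round trip through the `ArrowUtils` converters touched in this patch (illustrative only, not part of the change):

```scala
import org.apache.spark.sql.types.GeometryType
import org.apache.spark.sql.util.ArrowUtils

// GeometryType(4326) maps to Struct<srid: Int32 not null, wkb: Binary not null>, where
// the wkb child is tagged with field metadata {"geometry" -> "true", "srid" -> "4326"}.
val field = ArrowUtils.toArrowField(
  "g", GeometryType(4326), nullable = true, timeZoneId = "UTC", largeVarTypes = false)
assert(ArrowUtils.fromArrowField(field) == GeometryType(4326))
```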

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added unit tests:
- `ArrowUtilsSuite`
- `ArrowWriterSuite`
- `ArrowEncoderSuite`

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #52863 from uros-db/geo-arrow-serde.

Authored-by: Uros Bojanic <uros.bojanic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
diff --git a/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala b/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala
new file mode 100644
index 0000000..44bdffe
--- /dev/null
+++ b/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * A test utility that introduces controlled nullability into a sequence of invocations:
+ * it returns `null` once every `interval` invocations, and the provided value otherwise.
+ */
+case class MaybeNull(interval: Int) {
+  assert(interval > 1)
+  private var invocations = 0
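+  // For example, MaybeNull(5) returns null on invocations 0, 5, 10, ..., and the given
+  // value otherwise.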
+  def apply[T](value: T): T = {
+    val result = if (invocations % interval == 0) {
+      null.asInstanceOf[T]
+    } else {
+      value
+    }
+    invocations += 1
+    result
+  }
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 6caabf2..23d8a0b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -143,6 +143,43 @@
               largeVarTypes)).asJava)
       case udt: UserDefinedType[_] =>
         toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
+      case g: GeometryType =>
+        val fieldType =
+          new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
+
+        // The WKB field is tagged with additional metadata so we can identify that the
+        // arrow struct actually represents a geometry schema.
+        val wkbFieldType = new FieldType(
+          false,
+          toArrowType(BinaryType, timeZoneId, largeVarTypes),
+          null,
+          Map("geometry" -> "true", "srid" -> g.srid.toString).asJava)
+
+        new Field(
+          name,
+          fieldType,
+          Seq(
+            toArrowField("srid", IntegerType, false, timeZoneId, largeVarTypes),
+            new Field("wkb", wkbFieldType, Seq.empty[Field].asJava)).asJava)
+
+      case g: GeographyType =>
+        val fieldType =
+          new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
+
+        // The WKB field is tagged with additional metadata so we can identify that the
+        // arrow struct actually represents a geography schema.
+        val wkbFieldType = new FieldType(
+          false,
+          toArrowType(BinaryType, timeZoneId, largeVarTypes),
+          null,
+          Map("geography" -> "true", "srid" -> g.srid.toString).asJava)
+
+        new Field(
+          name,
+          fieldType,
+          Seq(
+            toArrowField("srid", IntegerType, false, timeZoneId, largeVarTypes),
+            new Field("wkb", wkbFieldType, Seq.empty[Field].asJava)).asJava)
       case _: VariantType =>
         val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
         // The metadata field is tagged with additional metadata so we can identify that the arrow
@@ -175,6 +212,26 @@
     }
   }
 
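+  // A struct field encodes a geometry iff it has both "wkb" and "srid" children and its
+  // "wkb" child carries the {"geometry": "true"} metadata written by toArrowField.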
+  def isGeometryField(field: Field): Boolean = {
+    assert(field.getType.isInstanceOf[ArrowType.Struct])
+    val childNames = field.getChildren.asScala.map(_.getName).toSet
+    Set("wkb", "srid").subsetOf(childNames) && field.getChildren.asScala.exists { child =>
+      child.getName == "wkb" && child.getMetadata.getOrDefault("geometry", "false") == "true"
+    }
+  }
+
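+  // Same structural check as above, keyed on the {"geography": "true"} metadata instead.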
+  def isGeographyField(field: Field): Boolean = {
+    assert(field.getType.isInstanceOf[ArrowType.Struct])
+    val childNames = field.getChildren.asScala.map(_.getName).toSet
+    Set("wkb", "srid").subsetOf(childNames) && field.getChildren.asScala.exists { child =>
+      child.getName == "wkb" && child.getMetadata.getOrDefault("geography", "false") == "true"
+    }
+  }
+
   def fromArrowField(field: Field): DataType = {
     field.getType match {
       case _: ArrowType.Map =>
@@ -188,6 +245,26 @@
         ArrayType(elementType, containsNull = elementField.isNullable)
       case ArrowType.Struct.INSTANCE if isVariantField(field) =>
         VariantType
+      case ArrowType.Struct.INSTANCE if isGeometryField(field) =>
+        // The type metadata is expected to be attached to the "wkb" field.
+        val metadataField =
+          field.getChildren.asScala.filter { child => child.getName == "wkb" }.head
+        val srid = metadataField.getMetadata.get("srid").toInt
+        if (srid == GeometryType.MIXED_SRID) {
+          GeometryType("ANY")
+        } else {
+          GeometryType(srid)
+        }
+      case ArrowType.Struct.INSTANCE if isGeographyField(field) =>
+        // The type metadata is expected to be attached to the "wkb" field.
+        val metadataField =
+          field.getChildren.asScala.filter { child => child.getName == "wkb" }.head
+        val srid = metadataField.getMetadata.get("srid").toInt
+        if (srid == GeographyType.MIXED_SRID) {
+          GeographyType("ANY")
+        } else {
+          GeographyType(srid)
+        }
       case ArrowType.Struct.INSTANCE =>
         val fields = field.getChildren().asScala.map { child =>
           val dt = fromArrowField(child)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
index 3cf4b84..0a9942c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
@@ -111,6 +111,10 @@
     return toPhysVal(Geometry.fromWkb(wkb));
   }
 
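+  // Overload of ST_GeomFromWKB that parses the WKB bytes with an explicitly provided SRID.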
+  public static GeometryVal stGeomFromWKB(byte[] wkb, int srid) {
+    return toPhysVal(Geometry.fromWkb(wkb, srid));
+  }
+
   // ST_SetSrid
   public static GeographyVal stSetSrid(GeographyVal geo, int srid) {
     // We only allow setting the SRID to geographic values.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index 66116d7..019bc25 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -25,9 +25,12 @@
 
 import org.apache.spark.SparkUnsupportedOperationException;
 import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.sql.catalyst.util.STUtils;
 import org.apache.spark.sql.util.ArrowUtils;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
@@ -146,6 +149,30 @@
      super(type);
   }
 
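+  // Geospatial accessors. A geospatial value is stored as a (srid, wkb) struct:
+  // child 0 holds the SRID and child 1 holds the WKB bytes.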
+  @Override
+  public GeographyVal getGeography(int rowId) {
+    if (isNullAt(rowId)) return null;
+
+    GeographyType gt = (GeographyType) this.type;
+    int srid = getChild(0).getInt(rowId);
+    byte[] bytes = getChild(1).getBinary(rowId);
+    gt.assertSridAllowedForType(srid);
+    // TODO(GEO-602): Geography does not yet support different SRIDs; once it does,
+    // this needs to be updated.
+    return (bytes == null) ? null : STUtils.stGeogFromWKB(bytes);
+  }
+
+  @Override
+  public GeometryVal getGeometry(int rowId) {
+    if (isNullAt(rowId)) return null;
+
+    GeometryType gt = (GeometryType) this.type;
+    int srid = getChild(0).getInt(rowId);
+    byte[] bytes = getChild(1).getBinary(rowId);
+    gt.assertSridAllowedForType(srid);
+    return (bytes == null) ? null : STUtils.stGeomFromWKB(bytes, srid);
+  }
+
   public ArrowColumnVector(ValueVector vector) {
     this(ArrowUtils.fromArrowField(vector.getField()));
     initAccessor(vector);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 275fece..8d68e74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -24,6 +24,7 @@
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.STUtils
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
@@ -92,6 +93,16 @@
           createFieldWriter(vector.getChildByOrdinal(ordinal))
         }
         new StructWriter(vector, children.toArray)
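+      // Geospatial types reuse the struct machinery: each value is written as a
+      // (srid, wkb) struct, matching the Arrow schema produced by ArrowUtils.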
+      case (dt: GeometryType, vector: StructVector) =>
+        val children = (0 until vector.size()).map { ordinal =>
+          createFieldWriter(vector.getChildByOrdinal(ordinal))
+        }
+        new GeometryWriter(dt, vector, children.toArray)
+      case (dt: GeographyType, vector: StructVector) =>
+        val children = (0 until vector.size()).map { ordinal =>
+          createFieldWriter(vector.getChildByOrdinal(ordinal))
+        }
+        new GeographyWriter(dt, vector, children.toArray)
       case (dt, _) =>
         throw ExecutionErrors.unsupportedDataTypeError(dt)
     }
@@ -446,6 +457,42 @@
   }
 }
 
+private[arrow] class GeographyWriter(
+    dt: GeographyType,
+    valueVector: StructVector,
+    children: Array[ArrowFieldWriter]) extends StructWriter(valueVector, children) {
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setIndexDefined(count)
+
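+    // Split the geography into its SRID and WKB bytes, then route them to the
+    // "srid" and "wkb" child writers via a temporary two-field row.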
+    val geom = STUtils.deserializeGeog(input.getGeography(ordinal), dt)
+    val bytes = geom.getBytes
+    val srid = geom.getSrid
+
+    val row = InternalRow(srid, bytes)
+    children(0).write(row, 0)
+    children(1).write(row, 1)
+  }
+}
+
+private[arrow] class GeometryWriter(
+    dt: GeometryType,
+    valueVector: StructVector,
+    children: Array[ArrowFieldWriter]) extends StructWriter(valueVector, children) {
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setIndexDefined(count)
+
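+    // Same as GeographyWriter: split the geometry into (srid, wkb) and write the children.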
+    val geom = STUtils.deserializeGeom(input.getGeometry(ordinal), dt)
+    val bytes = geom.getBytes
+    val srid = geom.getSrid
+
+    val row = InternalRow(srid, bytes)
+    children(0).write(row, 0)
+    children(1).write(row, 1)
+  }
+}
+
 private[arrow] class MapWriter(
     val valueVector: MapVector,
     val structVector: StructVector,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
index 7124c94..8011e69 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
@@ -49,6 +49,10 @@
     roundtrip(BinaryType)
     roundtrip(DecimalType.SYSTEM_DEFAULT)
     roundtrip(DateType)
+    roundtrip(GeometryType("ANY"))
+    roundtrip(GeometryType(4326))
+    roundtrip(GeographyType("ANY"))
+    roundtrip(GeographyType(4326))
     roundtrip(YearMonthIntervalType())
     roundtrip(DayTimeIntervalType())
     checkError(
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index b29d73b..bc840df 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -46,7 +46,7 @@
 import org.apache.spark.sql.connect.test.ConnectFunSuite
 import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
 import org.apache.spark.unsafe.types.VariantVal
-import org.apache.spark.util.SparkStringUtils
+import org.apache.spark.util.{MaybeNull, SparkStringUtils}
 
 /**
  * Tests for encoding external data to and from arrow.
@@ -218,20 +218,6 @@
     }
   }
 
-  private case class MaybeNull(interval: Int) {
-    assert(interval > 1)
-    private var invocations = 0
-    def apply[T](value: T): T = {
-      val result = if (invocations % interval == 0) {
-        null.asInstanceOf[T]
-      } else {
-        value
-      }
-      invocations += 1
-      result
-    }
-  }
-
   private def javaBigDecimal(i: Int): java.math.BigDecimal = {
     javaBigDecimal(i, DecimalType.DEFAULT_SCALE)
   }
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 6b9d08f..a5b5c39 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -51,6 +51,13 @@
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
+      <artifactId>spark-common-utils_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <classifier>tests</classifier>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index 99d2455..2c0c049 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -17,16 +17,23 @@
 
 package org.apache.spark.sql.execution.arrow
 
+import scala.jdk.CollectionConverters._
+
 import org.apache.arrow.vector.VectorSchemaRoot
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.YearUDT
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder}
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.util.{Geography => InternalGeography, Geometry => InternalGeometry}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, GeographyVal, GeometryVal, UTF8String}
+import org.apache.spark.util.MaybeNull
 
 class ArrowWriterSuite extends SparkFunSuite {
 
@@ -52,8 +59,16 @@
       }
       writer.finish()
 
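+      // Geospatial values are compared through their raw WKB bytes below.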
+      val dataModified = data.map { datum =>
+        dt match {
+          case _: GeometryType => datum.asInstanceOf[GeometryVal].getBytes
+          case _: GeographyType => datum.asInstanceOf[GeographyVal].getBytes
+          case _ => datum
+        }
+      }
+
       val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
-      data.zipWithIndex.foreach {
+      dataModified.zipWithIndex.foreach {
         case (null, rowId) => assert(reader.isNullAt(rowId))
         case (datum, rowId) =>
           val value = datatype match {
@@ -74,12 +89,31 @@
             case _: YearMonthIntervalType => reader.getInt(rowId)
             case _: DayTimeIntervalType => reader.getLong(rowId)
             case CalendarIntervalType => reader.getInterval(rowId)
+            case _: GeometryType => reader.getGeometry(rowId).getBytes
+            case _: GeographyType => reader.getGeography(rowId).getBytes
           }
           assert(value === datum)
       }
 
       writer.root.close()
     }
+
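+    // Two POINT geometries, hex-encoded as WKB and decoded into raw bytes.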
+    val wkbs = Seq("010100000000000000000031400000000000001c40",
+      "010100000000000000000034400000000000003540")
+      .map { x =>
+        x.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+      }
+
+    val geographies = wkbs.map(x => InternalGeography.fromWkb(x, 4326).getValue)
+    val geometries = wkbs.map(x => InternalGeometry.fromWkb(x, 0).getValue)
+    val mixedGeometries = wkbs.zip(Seq(0, 4326)).map {
+      case (g, srid) => InternalGeometry.fromWkb(g, srid).getValue
+    }
+
+    check(GeometryType(0), geometries)
+    check(GeographyType(4326), geographies)
+    check(GeometryType("ANY"), mixedGeometries)
+    check(GeographyType("ANY"), geographies)
     check(BooleanType, Seq(true, null, false))
     check(ByteType, Seq(1.toByte, 2.toByte, null, 4.toByte))
     check(ShortType, Seq(1.toShort, 2.toShort, null, 4.toShort))
@@ -110,6 +144,245 @@
     check(new YearUDT, Seq(2020, 2021, null, 2022))
   }
 
+  test("nested geographies") {
+    def check(dt: StructType, data: Seq[InternalRow]): Unit = {
+      val writer = ArrowWriter.create(dt, "UTC")
+
+      // Write data to arrow.
+      data.foreach { datum =>
+        writer.write(datum)
+      }
+      writer.finish()
+
+      // Create arrow vector readers.
+      val vectors = writer.root.getFieldVectors.asScala
+        .map { new ArrowColumnVector(_) }.toArray.asInstanceOf[Array[ColumnVector]]
+
+      val batch = new ColumnarBatch(vectors, writer.root.getRowCount.toInt)
+
+      data.zipWithIndex.foreach { case (datum, i) =>
+        // Read data from arrow.
+        val internalRow = batch.getRow(i)
+
+        // For nullable values, first check whether either side is null.
+        if (datum.getStruct(0, 4) == null || internalRow.getStruct(0, 4) == null) {
+          assert(datum.getStruct(0, 4) == null && internalRow.getStruct(0, 4) == null)
+        } else {
+          val expectedStruct = datum.getStruct(0, 4)
+          val actualStruct = internalRow.getStruct(0, 4)
+          assert(expectedStruct.getInt(0) === actualStruct.getInt(0))
+          assert(expectedStruct.getInt(2) === actualStruct.getInt(2))
+
+          if (expectedStruct.getGeography(1) == null ||
+            actualStruct.getGeography(1) == null) {
+            assert(expectedStruct.getGeography(1) == null && actualStruct.getGeography(1) == null)
+          } else {
+            assert(expectedStruct.getGeography(1).getBytes ===
+              actualStruct.getGeography(1).getBytes)
+          }
+          if (expectedStruct.getGeography(3) == null ||
+            actualStruct.getGeography(3) == null) {
+            assert(expectedStruct.getGeography(3) == null && actualStruct.getGeography(3) == null)
+          } else {
+            assert(expectedStruct.getGeography(3).getBytes ===
+              actualStruct.getGeography(3).getBytes)
+          }
+
+          if (datum.getArray(1) == null ||
+            internalRow.getArray(1) == null) {
+            assert(internalRow.getArray(1) == null && datum.getArray(1) == null)
+          } else {
+            internalRow.getArray(1).toSeq[GeographyVal](GeographyType(4326))
+              .zip(datum.getArray(1).toSeq[GeographyVal](GeographyType(4326))).foreach {
+                case (actual, expected) =>
+                  assert(actual.getBytes === expected.getBytes)
+              }
+          }
+
+          if (datum.getMap(2) == null ||
+            internalRow.getMap(2) == null) {
+            assert(internalRow.getMap(2) == null && datum.getMap(2) == null)
+          } else {
+            assert(internalRow.getMap(2).keyArray().toSeq(StringType) ===
+              datum.getMap(2).keyArray().toSeq(StringType))
+            internalRow.getMap(2).valueArray().toSeq[GeographyVal](GeographyType("ANY"))
+              .zip(datum.getMap(2).valueArray().toSeq[GeographyVal](GeographyType("ANY"))).foreach {
+                case (actual, expected) =>
+                  assert((actual == null && expected == null) ||
+                    actual.getBytes === expected.getBytes)
+              }
+          }
+        }
+      }
+
+      writer.root.close()
+    }
+
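+    // WKB bytes for two sample points, decoded from hex strings.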
+    val point1 = "010100000000000000000031400000000000001C40"
+      .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+    val point2 = "010100000000000000000035400000000000001E40"
+      .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+
+    val schema = new StructType()
+      .add(
+        "s",
+        new StructType()
+          .add("i1", "int")
+          .add("g0", "geography(4326)")
+          .add("i2", "int")
+          .add("g1", "geography(4326)"))
+      .add("a", "array<geography(4326)>")
+      .add("m", "map<string, geography(ANY)>")
+
+    val maybeNull5 = MaybeNull(5)
+    val maybeNull7 = MaybeNull(7)
+    val maybeNull11 = MaybeNull(11)
+    val maybeNull13 = MaybeNull(13)
+    val maybeNull17 = MaybeNull(17)
+
+    val nestedGeographySerializer = ExpressionEncoder(toRowEncoder(schema)).createSerializer()
+    val data = Iterator
+      .tabulate(100)(i =>
+        nestedGeographySerializer.apply(
+          (Row(
+            maybeNull5(
+              Row(
+                i,
+                maybeNull7(org.apache.spark.sql.types.Geography.fromWKB(point1)),
+                i + 1,
+                maybeNull11(org.apache.spark.sql.types.Geography.fromWKB(point2, 4326)))),
+            maybeNull7((0 until 10).map(j =>
+              org.apache.spark.sql.types.Geography.fromWKB(point2, 4326))),
+            maybeNull13(
+              Map((i.toString, maybeNull17(
+                org.apache.spark.sql.types.Geography.fromWKB(point1, 4326)))))))))
+      .map(_.copy()).toSeq
+
+    check(schema, data)
+  }
+
+  test("nested geometries") {
+    def check(dt: StructType, data: Seq[InternalRow]): Unit = {
+      val writer = ArrowWriter.create(dt, "UTC")
+
+      // Write data to arrow.
+      data.foreach { datum =>
+        writer.write(datum)
+      }
+      writer.finish()
+
+      // Create arrow vector readers.
+      val vectors = writer.root.getFieldVectors.asScala
+        .map { new ArrowColumnVector(_) }.toArray.asInstanceOf[Array[ColumnVector]]
+
+      val batch = new ColumnarBatch(vectors, writer.root.getRowCount.toInt)
+      data.zipWithIndex.foreach { case (datum, i) =>
+        // Read data from arrow.
+        val internalRow = batch.getRow(i)
+
+        // For nullable values, first check whether either side is null.
+        if (datum.getStruct(0, 4) == null || internalRow.getStruct(0, 4) == null) {
+          assert(datum.getStruct(0, 4) == null && internalRow.getStruct(0, 4) == null)
+        } else {
+          val expectedStruct = datum.getStruct(0, 4)
+          val actualStruct = internalRow.getStruct(0, 4)
+          assert(expectedStruct.getInt(0) === actualStruct.getInt(0))
+          assert(expectedStruct.getInt(2) === actualStruct.getInt(2))
+
+          if (expectedStruct.getGeometry(1) == null ||
+            actualStruct.getGeometry(1) == null) {
+            assert(expectedStruct.getGeometry(1) == null && actualStruct.getGeometry(1) == null)
+          } else {
+            assert(expectedStruct.getGeometry(1).getBytes ===
+              actualStruct.getGeometry(1).getBytes)
+          }
+          if (expectedStruct.getGeometry(3) == null ||
+            actualStruct.getGeometry(3) == null) {
+            assert(expectedStruct.getGeometry(3) == null && actualStruct.getGeometry(3) == null)
+          } else {
+            assert(expectedStruct.getGeometry(3).getBytes ===
+              actualStruct.getGeometry(3).getBytes)
+          }
+
+          if (datum.getArray(1) == null ||
+            internalRow.getArray(1) == null) {
+            assert(internalRow.getArray(1) == null && datum.getArray(1) == null)
+          } else {
+            internalRow.getArray(1).toSeq[GeometryVal](GeometryType(0))
+              .zip(datum.getArray(1).toSeq[GeometryVal](GeometryType(0))).foreach {
+                case (actual, expected) =>
+                  assert(actual.getBytes === expected.getBytes)
+              }
+          }
+
+          if (datum.getMap(2) == null ||
+            internalRow.getMap(2) == null) {
+            assert(internalRow.getMap(2) == null && datum.getMap(2) == null)
+          } else {
+            assert(internalRow.getMap(2).keyArray().toSeq(StringType) ===
+              datum.getMap(2).keyArray().toSeq(StringType))
+            internalRow.getMap(2).valueArray().toSeq[GeometryVal](GeometryType("ANY"))
+              .zip(datum.getMap(2).valueArray().toSeq[GeometryVal](GeometryType("ANY"))).foreach {
+                case (actual, expected) =>
+                  assert((actual == null && expected == null) ||
+                    actual.getBytes === expected.getBytes)
+              }
+          }
+        }
+      }
+
+      writer.root.close()
+    }
+
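+    // WKB bytes for two sample points, decoded from hex strings.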
+    val point1 = "010100000000000000000031400000000000001C40"
+      .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+    val point2 = "010100000000000000000035400000000000001E40"
+      .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+
+    val schema = new StructType()
+      .add(
+        "s",
+        new StructType()
+          .add("i1", "int")
+          .add("g0", "geometry(0)")
+          .add("i2", "int")
+          .add("g4326", "geometry(4326)"))
+      .add("a", "array<geometry(0)>")
+      .add("m", "map<string, geometry(ANY)>")
+
+    val maybeNull5 = MaybeNull(5)
+    val maybeNull7 = MaybeNull(7)
+    val maybeNull11 = MaybeNull(11)
+    val maybeNull13 = MaybeNull(13)
+    val maybeNull17 = MaybeNull(17)
+
+    val nestedGeometrySerializer = ExpressionEncoder(toRowEncoder(schema)).createSerializer()
+    val data = Iterator
+      .tabulate(100) { i =>
+        val mixedSrid = if (i % 2 == 0) 0 else 4326
+
+        nestedGeometrySerializer.apply(
+          (Row(
+            maybeNull5(
+              Row(
+                i,
+                maybeNull7(org.apache.spark.sql.types.Geometry.fromWKB(point1, 0)),
+                i + 1,
+                maybeNull11(org.apache.spark.sql.types.Geometry.fromWKB(point2, 4326)))),
+            maybeNull7((0 until 10).map(j =>
+              org.apache.spark.sql.types.Geometry.fromWKB(point2, 0))),
+            maybeNull13(
+              Map((i.toString, maybeNull17(
+                org.apache.spark.sql.types.Geometry.fromWKB(point1, mixedSrid))))))))
+      }.map(_.copy()).toSeq
+
+    check(schema, data)
+  }
+
   test("get multiple") {
     def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = {
       val datatype = dt match {