# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import numpy as np
import shapely
from pyspark.sql.functions import expr
from shapely.geometry import Point, Polygon, LineString, box
from tests.test_base import TestBase
from sedona.spark.geopandas import GeoSeries
from sedona.spark.geopandas.sindex import SpatialIndex
from packaging.version import parse as parse_version
@pytest.mark.skipif(
parse_version(shapely.__version__) < parse_version("2.0.0"),
reason=f"Tests require shapely>=2.0.0, but found v{shapely.__version__}",
)
class TestSpatialIndex(TestBase):
"""Tests for the spatial index functionality in GeoSeries."""
def setup_method(self):
"""Set up test data."""
# Create a GeoSeries with point geometries
self.points = GeoSeries(
[Point(0, 0), Point(1, 1), Point(2, 2), Point(3, 3), Point(4, 4)]
)
# Create a GeoSeries with polygon geometries
self.polygons = GeoSeries(
[
Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]),
Polygon([(1, 1), (2, 1), (2, 2), (1, 2)]),
Polygon([(2, 2), (3, 2), (3, 3), (2, 3)]),
Polygon([(3, 3), (4, 3), (4, 4), (3, 4)]),
Polygon([(4, 4), (5, 4), (5, 5), (4, 5)]),
]
)
# Create a GeoSeries with line geometries
self.lines = GeoSeries(
[
LineString([(0, 0), (1, 1)]),
LineString([(1, 1), (2, 2)]),
LineString([(2, 2), (3, 3)]),
LineString([(3, 3), (4, 4)]),
LineString([(4, 4), (5, 5)]),
]
)
def test_construct_from_geoseries(self):
# Construct from a GeoSeries
gs = GeoSeries([Point(x, x) for x in range(5)])
sindex = SpatialIndex(gs)
result = sindex.query(Point(2, 2))
        # A SpatialIndex constructed from a GeoSeries returns geometries, not indices
assert result == [Point(2, 2)]
def test_construct_from_pyspark_dataframe(self):
        # Construct from a PySpark DataFrame
df = self.spark.createDataFrame(
[(Point(x, x),) for x in range(5)], ["geometry"]
)
sindex = SpatialIndex(df, column_name="geometry")
result = sindex.query(Point(2, 2))
assert result == [Point(2, 2)]
def test_construct_from_nparray(self):
# Construct from np.array
array = np.array([Point(x, x) for x in range(5)])
sindex = SpatialIndex(array)
result = sindex.query(Point(2, 2))
        # Returns integer indices, matching upstream geopandas behavior
        assert np.array_equal(result, np.array([2]))
def test_geoseries_sindex_property_exists(self):
"""Test that the sindex property exists on GeoSeries."""
assert hasattr(self.points, "sindex")
assert hasattr(self.polygons, "sindex")
assert hasattr(self.lines, "sindex")
def test_geodataframe_sindex_property_exists(self):
"""Test that the sindex property exists on GeoDataFrame."""
assert hasattr(self.points.to_geoframe(), "sindex")
assert hasattr(self.polygons.to_geoframe(), "sindex")
assert hasattr(self.lines.to_geoframe(), "sindex")
def test_query_with_point(self):
"""Test querying the spatial index with a point geometry."""
# Create a list of Shapely geometries - squares around points (0,0), (1,1), etc.
geometries = [
Polygon(
[
(i - 0.5, j - 0.5),
(i + 0.5, j - 0.5),
(i + 0.5, j + 0.5),
(i - 0.5, j + 0.5),
]
)
for i, j in [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
]
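        # Each square is a 1x1 box centered on its point, so the square around
        # (2, 2) spans [1.5, 2.5] x [1.5, 2.5]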
# Create a spatial index from the geometries
geom_array = np.array(geometries, dtype=object)
sindex = SpatialIndex(geom_array)
# Test query with a point that should intersect with one polygon
query_point = Point(2.2, 2.2)
result_indices = sindex.query(query_point)
assert len(result_indices) == 1
# Test query with a point that intersects no polygons
empty_point = Point(10, 10)
empty_results = sindex.query(empty_point)
assert len(empty_results) == 0
def test_query_with_spark_dataframe(self):
"""Test querying the spatial index with a Spark DataFrame."""
# Create a spatial DataFrame with polygons
polygons_data = [
(1, "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))"),
(2, "POLYGON((1 1, 2 1, 2 2, 1 2, 1 1))"),
(3, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"),
(4, "POLYGON((3 3, 4 3, 4 4, 3 4, 3 3))"),
(5, "POLYGON((4 4, 5 4, 5 5, 4 5, 4 4))"),
]
df = self.spark.createDataFrame(polygons_data, ["id", "wkt"])
spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
# Create a SpatialIndex from the DataFrame
sindex = SpatialIndex(spatial_df, index_type="strtree", column_name="geometry")
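        # "strtree" selects a Sort-Tile-Recursive (STR) packed R-tree, the same
        # index structure geopandas uses by default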
        # Test query with a point that should intersect exactly one polygon
        query_point = Point(2.2, 2.2)
        result_indices = sindex.query(query_point, "intersects")
        # Should find at least one result: the polygon containing the point
        assert len(result_indices) > 0
        # Test query with the contains predicate: the query box fully contains
        # only the third polygon, POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))
        query_box = box(1.5, 1.5, 3.5, 3.5)
        box_results = sindex.query(query_box, predicate="contains")
        assert len(box_results) == 1
# Test with a point outside any polygon
outside_point = Point(10, 10)
outside_results = sindex.query(outside_point)
# Verify no results for point outside
assert len(outside_results) == 0
def test_nearest_method(self):
"""Test the nearest method for finding k-nearest neighbors."""
# Create a spatial DataFrame with points
points_data = [
(1, "POINT(0 0)"),
(2, "POINT(1 1)"),
(3, "POINT(2 2)"),
(4, "POINT(3 3)"),
(5, "POINT(4 4)"),
]
df = self.spark.createDataFrame(points_data, ["id", "wkt"])
spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
# Create a SpatialIndex from the DataFrame
spark_sindex = SpatialIndex(
spatial_df, index_type="strtree", column_name="geometry"
)
# Test finding single nearest neighbor
query_point = Point(1.2, 1.2)
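        # POINT(1 1) is nearest to (1.2, 1.2) at distance sqrt(0.08) ~= 0.283;
        # POINT(2 2) is next at sqrt(1.28) ~= 1.131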
nearest_result = spark_sindex.nearest(query_point)
assert len(nearest_result) == 1
# The nearest point should have id=2 (POINT(1 1))
assert nearest_result[0].wkt == "POINT (1 1)"
# Test finding k=2 nearest neighbors
nearest_2_results = spark_sindex.nearest(query_point, k=2)
assert len(nearest_2_results) == 2
# Test with return_distance=True
nearest_with_dist = spark_sindex.nearest(query_point, k=2, return_distance=True)
assert len(nearest_with_dist) == 2 # Returns tuple of (rows, distances)
rows, distances = nearest_with_dist
assert len(rows) == 2
assert len(distances) == 2
assert all(d >= 0 for d in distances) # Distances should be non-negative
def test_nearest_spark_with_various_geometries(self):
"""Test nearest with different geometry types in Spark mode."""
# Create a spatial DataFrame with mixed geometry types
mixed_data = [
(1, "POINT(0 0)"),
(2, "LINESTRING(1 1, 2 2)"),
(3, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"),
(4, "POINT(3 3)"),
(5, "POLYGON((4 4, 5 4, 5 5, 4 5, 4 4))"),
]
df = self.spark.createDataFrame(mixed_data, ["id", "wkt"])
spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
spark_sindex = SpatialIndex(
spatial_df, index_type="strtree", column_name="geometry"
)
# Test with point query
query_point = Point(2.5, 2.5)
nearest_geom = spark_sindex.nearest(query_point)
# Should find polygon containing the point
assert len(nearest_geom) == 1
assert "POLYGON" in nearest_geom[0].wkt
# Test with linestring query
query_line = LineString([(1.5, 1.5), (2.5, 2.5)])
nearest_to_line = spark_sindex.nearest(query_line)
assert len(nearest_to_line) == 1
def test_nearest_spark_with_less_results_than_k(self):
"""Test nearest when less or no results are expected."""
# Create DataFrame with only one point
single_point = [(1, "POINT(0 0)")]
df = self.spark.createDataFrame(single_point, ["id", "wkt"])
spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
spark_sindex = SpatialIndex(
spatial_df, index_type="strtree", column_name="geometry"
)
# Query with k=5 when only 1 point exists
results = spark_sindex.nearest(Point(0.1, 0.1), k=5)
assert len(results) == 1 # Should only return the one available point
def test_nearest_spark_with_distance_verification(self):
"""Test that nearest returns results in correct distance order."""
# Create points in a grid pattern
grid_points = [
(1, "POINT(0 0)"),
(2, "POINT(1 0)"),
(3, "POINT(0 1)"),
(4, "POINT(1 1)"),
(5, "POINT(2 2)"),
]
df = self.spark.createDataFrame(grid_points, ["id", "wkt"])
spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
spark_sindex = SpatialIndex(
spatial_df, index_type="strtree", column_name="geometry"
)
# Query point at (0.5, 0.5)
query_point = Point(0.5, 0.5)
# Get 3 nearest with distances
results, distances = spark_sindex.nearest(
query_point, k=3, return_distance=True
)
# Verify distances are in ascending order
assert distances[0] <= distances[1] <= distances[2]
# Manually calculate expected distances
expected_nearest_points = [
Point(0, 0),
Point(1, 0),
Point(0, 1),
Point(1, 1),
Point(2, 2),
]
expected_distances = [p.distance(query_point) for p in expected_nearest_points]
expected_distances.sort()
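        # All four unit-square corners are sqrt(0.5) ~= 0.7071 from (0.5, 0.5),
        # so the three smallest expected distances are identical; POINT(2 2) is
        # sqrt(4.5) ~= 2.1213 away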
# Verify the first 3 distances match our calculations
for i in range(3):
assert abs(distances[i] - expected_distances[i]) < 0.0001
def test_nearest_spark_with_identical_distances(self):
"""Test nearest when multiple geometries have identical distances."""
        # Create four points equidistant from the center, plus one farther point
equidistant_points = [
(1, "POINT(0 1)"), # 1 unit from center
(2, "POINT(1 0)"), # 1 unit from center
(3, "POINT(0 -1)"), # 1 unit from center
(4, "POINT(-1 0)"), # 1 unit from center
(5, "POINT(2 2)"), # Center point
]
df = self.spark.createDataFrame(equidistant_points, ["id", "wkt"])
spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
spark_sindex = SpatialIndex(
spatial_df, index_type="strtree", column_name="geometry"
)
# Query center point
query_point = Point(0, 0)
# Get the 4 nearest points
results = spark_sindex.nearest(query_point, k=4)
assert len(results) == 4
        # With return_distance, verify the nearest returned distance is exactly 1.0
results, distances = spark_sindex.nearest(
query_point, k=4, return_distance=True
)
assert min(distances) == 1.0
def test_intersection_method(self):
"""Test the intersection method for finding geometries that intersect a bounding box."""
# Create a spatial DataFrame with polygons
polygons_data = [
(1, "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))"),
(2, "POLYGON((1 1, 2 1, 2 2, 1 2, 1 1))"),
(3, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"),
(4, "POLYGON((3 3, 4 3, 4 4, 3 4, 3 3))"),
(5, "POLYGON((4 4, 5 4, 5 5, 4 5, 4 4))"),
]
df = self.spark.createDataFrame(polygons_data, ["id", "wkt"])
spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
# Create a SpatialIndex from the DataFrame
spark_sindex = SpatialIndex(
spatial_df, index_type="strtree", column_name="geometry"
)
# Test intersection with a bounding box that should intersect with middle polygons
bounds = (
1.5,
1.5,
3.5,
3.5,
) # Should intersect with polygons at (2,2) and (3,3)
result_rows = spark_sindex.intersection(bounds)
# Verify correct results are returned
expected = [
Polygon([(1, 1), (2, 1), (2, 2), (1, 2), (1, 1)]),
Polygon([(2, 2), (3, 2), (3, 3), (2, 3), (2, 2)]),
Polygon([(3, 3), (4, 3), (4, 4), (3, 4), (3, 3)]),
]
assert result_rows == expected
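        # (shapely 2.x defines == on geometries as structural equality: same
        # type and same coordinate sequence)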
# Test with bounds that don't intersect any geometry
empty_bounds = (10, 10, 11, 11)
empty_results = spark_sindex.intersection(empty_bounds)
assert len(empty_results) == 0
# Test with bounds that cover all geometries
full_bounds = (-1, -1, 6, 6)
full_results = spark_sindex.intersection(full_bounds)
expected = [
Polygon([(0, 0), (1, 0), (1, 1), (0, 1), (0, 0)]),
Polygon([(1, 1), (2, 1), (2, 2), (1, 2), (1, 1)]),
Polygon([(2, 2), (3, 2), (3, 3), (2, 3), (2, 2)]),
Polygon([(3, 3), (4, 3), (4, 4), (3, 4), (3, 3)]),
Polygon([(4, 4), (5, 4), (5, 5), (4, 5), (4, 4)]),
]
assert full_results == expected
def test_intersection_with_points(self):
"""Test the intersection method with point geometries."""
# Create a spatial DataFrame with points
points_data = [
(1, "POINT(0 0)"),
(2, "POINT(1 1)"),
(3, "POINT(2 2)"),
(4, "POINT(3 3)"),
(5, "POINT(4 4)"),
]
df = self.spark.createDataFrame(points_data, ["id", "wkt"])
spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
spark_sindex = SpatialIndex(
spatial_df, index_type="strtree", column_name="geometry"
)
# Test with bounds containing points 2, 3
bounds = (0.5, 0.5, 2.5, 2.5)
results = spark_sindex.intersection(bounds)
        # Only POINT(1 1) and POINT(2 2) fall within the bounds
assert len(results) == 2
def test_intersection_with_linestrings(self):
"""Test the intersection method with linestring geometries."""
# Create a spatial DataFrame with linestrings
lines_data = [
(1, "LINESTRING(0 0, 1 1)"),
(2, "LINESTRING(1 1, 2 2)"),
(3, "LINESTRING(2 2, 3 3)"),
(4, "LINESTRING(3 3, 4 4)"),
(5, "LINESTRING(4 4, 5 5)"),
]
df = self.spark.createDataFrame(lines_data, ["id", "wkt"])
spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
spark_sindex = SpatialIndex(
spatial_df, index_type="strtree", column_name="geometry"
)
# Test with bounds crossing lines 2, 3
bounds = (1.5, 1.5, 2.5, 2.5)
results = spark_sindex.intersection(bounds)
        # Only LINESTRING(1 1, 2 2) and LINESTRING(2 2, 3 3) intersect the bounds
assert len(results) == 2
def test_intersection_with_mixed_geometries(self):
"""Test the intersection method with mixed geometry types."""
# Create a spatial DataFrame with mixed geometry types
mixed_data = [
(1, "POINT(0 0)"),
(2, "LINESTRING(1 1, 2 2)"),
(3, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"),
(4, "POINT(3 3)"),
(5, "LINESTRING(4 4, 5 5)"),
]
df = self.spark.createDataFrame(mixed_data, ["id", "wkt"])
spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
spark_sindex = SpatialIndex(
spatial_df, index_type="strtree", column_name="geometry"
)
        # Test with bounds that intersect the polygon (id 3) and POINT(3 3) (id 4)
bounds = (2.5, 2.5, 3.5, 3.5)
results = spark_sindex.intersection(bounds)
        # Both the polygon and the point should be returned
        assert len(results) == 2
    # Test adapted from the geopandas docstring
def test_geoseries_sindex_intersection(self):
gs = GeoSeries([Point(x, x) for x in range(10)])
result = gs.sindex.intersection(box(1, 1, 3, 3).bounds)
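        # box(1, 1, 3, 3).bounds == (1.0, 1.0, 3.0, 3.0)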
# Unlike original geopandas, this returns geometries instead of indices
expected = [Point(1, 1), Point(2, 2), Point(3, 3)]
assert result == expected