# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import tempfile
import uuid
import pytest
import shapely
import pandas as pd
import geopandas as gpd
import pyspark.pandas as ps
from functools import partial
from sedona.spark.geopandas import GeoDataFrame, GeoSeries, read_file, read_parquet
from tests import tests_resource
from tests.geopandas.test_geopandas_base import TestGeopandasBase
from shapely.geometry import (
Point,
Polygon,
MultiPoint,
MultiLineString,
LineString,
MultiPolygon,
GeometryCollection,
LinearRing,
)
from packaging.version import parse as parse_version
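
# Shared test data that lives in Sedona's Spark (JVM) module; the relative path
# assumes the tests are run from the python project directory.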
TEST_DATA_DIR = os.path.join("..", "spark", "common", "src", "test", "resources")


@pytest.mark.skipif(
parse_version(shapely.__version__) < parse_version("2.0.0"),
reason=f"Tests require shapely>=2.0.0, but found v{shapely.__version__}",
)
class TestIO(TestGeopandasBase):
def setup_method(self):
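        # Fresh scratch directory for each test; output files from the write tests go here.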
self.tempdir = tempfile.mkdtemp()

    #########################################################
    # File reading tests
    #########################################################
# Modified version of Sedona's test_shapefile.py test_read_simple
@pytest.mark.parametrize(
"read_func",
[
partial(GeoDataFrame.from_file, format="shapefile"),
partial(read_file, format="Shapefile"),
],
)
def test_read_shapefile(self, read_func):
data_dir = os.path.join(tests_resource, "shapefiles/polygon")
df = read_func(data_dir)
assert df.count().item() == 10000
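        # Spot-check a 100-row sample of the result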
subset_df = GeoDataFrame(df.head(100))
# assert only one column
assert subset_df.shape[1] == 1
# assert all geometries are polygons or multipolygons
assert subset_df["geometry"].geom_type.isin(["Polygon", "MultiPolygon"]).all()
        # Passing a single .shp file (rather than the directory) should also work
        data_file = os.path.join(data_dir, "map.shp")
        df = read_func(data_file)
        assert df.count().item() == 10000

@pytest.mark.parametrize(
"read_func",
[
partial(GeoDataFrame.from_file, format="geojson"),
partial(read_file, format="GeoJSON"),
partial(read_file), # check format inference works
],
)
def test_read_geojson(self, read_func):
datafile = os.path.join(TEST_DATA_DIR, "geojson/test1.json")
df = read_func(datafile)
assert (df.count() == 1).all()
        # Reading the same file a second time should give the same result
        df = read_func(datafile)
        assert (df.count() == 1).all()

@pytest.mark.parametrize(
"read_func",
[
partial(GeoDataFrame.from_file, format="geoparquet"),
partial(read_file, format="GeoParquet"),
partial(read_file), # check format inference works
read_parquet,
],
)
def test_read_geoparquet(self, read_func):
input_location = os.path.join(TEST_DATA_DIR, "geoparquet/example1.parquet")
df = read_func(input_location)
# check that all column counts are 5
assert (df.count() == 5).all()
        # Reading the same file a second time should give the same result
        df = read_func(input_location)
        assert (df.count() == 5).all()

# From Sedona's GeoPackageReaderTest.scala
@pytest.mark.parametrize(
"read_func",
[
partial(GeoDataFrame.from_file, format="geopackage"),
partial(read_file, format="GeoPackage"),
partial(read_file), # check format inference works
],
)
def test_read_geopackage(self, read_func):
datafile = os.path.join(TEST_DATA_DIR, "geopackage/features.gpkg")
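        # A GeoPackage can hold multiple tables, so the reader needs a table_name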
table_name = "GB_Hex_5km_GS_CompressibleGround_v8"
expected_cnt = 4233
df = read_func(datafile, table_name=table_name)
assert df["geom"].count() == expected_cnt
        # Reading a different table from the same file should also work
        table_name = "GB_Hex_5km_GS_Landslides_v8"
        expected_cnt = 4228
        df = read_func(datafile, table_name=table_name)
        assert df["geom"].count() == expected_cnt

    #########################################################
    # File writing tests
    #########################################################
    def _get_next_temp_file_path(self, ext: str):
        # Build a unique, not-yet-created path in the scratch directory so each
        # writer gets its own output location (avoids relying on tempfile's
        # private _get_candidate_names() helper).
        return os.path.join(self.tempdir, f"{uuid.uuid4().hex}.{ext}")

@pytest.mark.parametrize(
"write_func",
[
partial(GeoDataFrame.to_file, driver="GeoParquet"),
partial(GeoDataFrame.to_file, driver="geoparquet"),
partial(GeoDataFrame.to_file), # check format inference works
GeoDataFrame.to_parquet,
],
)
def test_to_geoparquet(self, write_func):
sgpd_df = GeoDataFrame(
{"geometry": [Point(0, 0), LineString([(0, 0), (1, 1)])], "int": [1, 2]}
)
temp_file_path = self._get_next_temp_file_path("parquet")
self._apply_func(sgpd_df, write_func, temp_file_path)
        # Reading the file back with geopandas should reproduce the same GeoDataFrame
        gpd_df = gpd.read_parquet(temp_file_path)
        self.check_sgpd_df_equals_gpd_df(sgpd_df, gpd_df)

@pytest.mark.parametrize(
"write_func",
[
            partial(GeoDataFrame.to_file, driver="geojson"),  # with an integer index, index=None resolves to False
partial(GeoDataFrame.to_file, driver="GeoJSON", index=True),
partial(GeoDataFrame.to_file, driver="geojson", index=True),
partial(GeoDataFrame.to_file), # check format inference works
],
)
def test_to_geojson(self, write_func):
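        # Non-default index so the test can tell whether it survives the round trip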
sgpd_df = GeoDataFrame(
{"geometry": [Point(0, 0), LineString([(0, 0), (1, 1)])], "int": [1, 2]},
index=[1, 2],
)
temp_file_path = self._get_next_temp_file_path("json")
self._apply_func(sgpd_df, write_func, temp_file_path)
read_result = GeoDataFrame.from_file(
temp_file_path, format="geojson"
).to_geopandas()
        # If index=True was passed, the contents should come back in the same
        # order as the original GeoDataFrame
        if write_func.keywords.get("index") is True:
self.check_sgpd_df_equals_gpd_df(sgpd_df, read_result)
        else:
            # If the index was not written, just check that all rows are present
            # under a default integer index
            self.check_index_equal(read_result, pd.Index([0, 1]))

@pytest.mark.parametrize(
"write_func",
[
partial(GeoDataFrame.to_file, driver="geojson"),
],
)
def test_to_file_non_int_index(self, write_func):
sgpd_df = GeoDataFrame(
{"geometry": [Point(0, 0), LineString([(0, 0), (1, 1)])], "int": [1, 2]},
index=["a", "b"],
)
temp_file_path = self._get_next_temp_file_path("json")
self._apply_func(sgpd_df, write_func, temp_file_path)
read_result = GeoDataFrame.from_file(
temp_file_path, format="geojson"
).to_geopandas()
        # Since the index has a non-integer dtype, index=None resolves to True here
        self.check_sgpd_df_equals_gpd_df(sgpd_df, read_result)

@pytest.mark.parametrize(
"format",
[
"geojson",
"geoparquet",
],
)
def test_to_file_and_from_file_series(self, format):
sgpd_ser = GeoSeries([Point(0, 0), LineString([(0, 0), (1, 1)])])
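        # Map the format name to a file extension: geojson -> .json, geoparquet -> .parquet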
ext = format.replace("geo", "")
temp_file_path = self._get_next_temp_file_path(ext)
sgpd_ser.to_file(temp_file_path, driver=format, index=True)
read_result = GeoSeries.from_file(temp_file_path, format=format)
read_result = read_result.to_geopandas()
        # In geopandas, the series name is always read back as "geometry"
        sgpd_ser.name = "geometry"
        # Since index=True, the contents should be in the same order as the original GeoSeries
        self.check_sgpd_equals_gpd(sgpd_ser, read_result)

def _apply_func(self, obj, func, *args):
"""
Helper function to conditionally apply functions or methods to an object correctly.
"""
if type(func) == str:
return getattr(obj, func)(*args)
else:
return func(obj, *args)