# blob: 969d9cfdc0d804d5c0401b3b4af35079c92b2582 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import os.path
import json
from shapely.geometry import Point
from shapely.geometry import LineString
from shapely.geometry.base import BaseGeometry
from shapely.wkt import loads as wkt_loads
import geopandas
from tests.test_base import TestBase
from tests import geoparquet_input_location
from tests import plain_parquet_input_location
from tests import legacy_parquet_input_location
class TestGeoParquet(TestBase):
    """Tests for Sedona's geoparquet data source: round-tripping with
    GeoPandas, spatial filter pushdown, metadata inspection, and legacy files."""

    def test_interoperability_with_geopandas(self, tmp_path):
        """GeoParquet written by Sedona is readable by GeoPandas, and vice versa."""
        rows = [
            [1, Point(0, 0), LineString([(1, 2), (3, 4), (5, 6)])],
            [2, LineString([(1, 2), (3, 4), (5, 6)]), Point(1, 1)],
            [3, Point(1, 1), LineString([(1, 2), (3, 4), (5, 6)])]
        ]
        sedona_df = self.spark.createDataFrame(data=rows, schema=["id", "g0", "g1"]).repartition(1)
        sedona_out = os.path.join(tmp_path, "test.parquet")
        sedona_df.write.format("geoparquet").save(sedona_out)

        # Sedona-written file: both geometry columns should load with geometry dtype.
        roundtrip = geopandas.read_parquet(sedona_out)
        for column in ("g0", "g1"):
            assert roundtrip.dtypes[column].name == 'geometry'

        # GeoPandas-written file: geometries should come back as shapely objects.
        source_frame = geopandas.GeoDataFrame([
            {'g': wkt_loads('POINT (1 2)'), 'i': 10},
            {'g': wkt_loads('LINESTRING (1 2, 3 4)'), 'i': 20}
        ], geometry='g')
        geopandas_out = os.path.join(tmp_path, "test_2.parquet")
        source_frame.to_parquet(geopandas_out)
        loaded = self.spark.read.format("geoparquet").load(geopandas_out)
        assert loaded.count() == 2
        first_row = loaded.collect()[0]
        assert isinstance(first_row['g'], BaseGeometry)

    def test_load_geoparquet_with_spatial_filter(self):
        """A spatial predicate on the geoparquet reader selects exactly the matching feature."""
        filtered = self.spark.read.format("geoparquet").load(geoparquet_input_location)\
            .where("ST_Contains(geometry, ST_GeomFromText('POINT (35.174722 -6.552465)'))")
        result = filtered.collect()
        assert len(result) == 1
        assert result[0]['name'] == 'Tanzania'

    def test_load_plain_parquet_file(self):
        """Reading plain (non-geo) parquet through the geoparquet reader must fail clearly."""
        with pytest.raises(Exception) as excinfo:
            self.spark.read.format("geoparquet").load(plain_parquet_input_location)
        assert "does not contain valid geo metadata" in str(excinfo.value)

    def test_inspect_geoparquet_metadata(self):
        """The geoparquet.metadata source exposes the file-level geo metadata."""
        metadata_df = self.spark.read.format("geoparquet.metadata").load(geoparquet_input_location)
        metadata_rows = metadata_df.collect()
        assert len(metadata_rows) == 1
        meta = metadata_rows[0]
        assert meta['path'].endswith('.parquet')
        # Version is expected to be a dotted triple like "1.0.0".
        assert len(meta['version'].split('.')) == 3
        assert meta['primary_column'] == 'geometry'
        geometry_meta = meta['columns']['geometry']
        assert geometry_meta['encoding'] == 'WKB'
        assert len(geometry_meta['bbox']) == 4
        # CRS is stored as a JSON object string (PROJJSON).
        assert isinstance(json.loads(geometry_meta['crs']), dict)

    def test_reading_legacy_parquet_files(self):
        """legacyMode=true decodes geometries written by old Sedona versions,
        including those nested inside struct columns."""
        legacy_df = self.spark.read.format("geoparquet").option("legacyMode", "true").load(legacy_parquet_input_location)
        legacy_rows = legacy_df.collect()
        assert len(legacy_rows) > 0
        for legacy_row in legacy_rows:
            assert isinstance(legacy_row['geom'], BaseGeometry)
            assert isinstance(legacy_row['struct_geom']['g0'], BaseGeometry)
            assert isinstance(legacy_row['struct_geom']['g1'], BaseGeometry)