blob: 054cca1e02554bdcde42028fe95f141291fe0790 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tests.test_base import TestBase
from sedona.spark.core.enums import GridType
from sedona.spark.core.SpatialRDD import (
LineStringRDD,
PointRDD,
PolygonRDD,
RectangleRDD,
)
from sedona.spark.core.SpatialRDD.spatial_rdd import SpatialRDD
class TestJoinBase(TestBase):
    """Shared helpers for spatial join tests.

    Bundles factories for the typed spatial RDDs, a helper that spatially
    partitions a pair of RDDs with a common partitioner, and counting /
    sanity-check utilities for collected join results.
    """

    # Not read anywhere in this class; presumably consulted by the concrete
    # join tests -- TODO confirm against subclasses.
    use_legacy_apis = False

    def create_point_rdd(self, location, splitter, num_partitions):
        """Build a ``PointRDD`` from the input at *location*.

        The literals ``1`` and ``False`` are passed positionally; presumably
        they are the coordinate column offset and the carry-input-data flag
        -- confirm against the ``PointRDD`` constructor.
        """
        return PointRDD(self.sc, location, 1, splitter, False, num_partitions)

    def create_linestring_rdd(self, location, splitter, num_partitions):
        """Build a ``LineStringRDD`` from the input at *location*.

        ``True`` is passed positionally; presumably the carry-input-data
        flag -- confirm against the ``LineStringRDD`` constructor.
        """
        return LineStringRDD(self.sc, location, splitter, True, num_partitions)

    def create_polygon_rdd(self, location, splitter, num_partitions):
        """Build a ``PolygonRDD`` from the input at *location*.

        ``True`` is passed positionally; presumably the carry-input-data
        flag -- confirm against the ``PolygonRDD`` constructor.
        """
        return PolygonRDD(self.sc, location, splitter, True, num_partitions)

    def create_rectangle_rdd(self, location, splitter, num_partitions):
        """Build a ``RectangleRDD`` from *location* and analyze it.

        Unlike the other factories, ``analyze()`` is invoked on the RDD
        before it is returned.
        """
        rectangle_rdd = RectangleRDD(self.sc, location, splitter, True, num_partitions)
        rectangle_rdd.analyze()
        return rectangle_rdd

    def partition_rdds(self, query_rdd: SpatialRDD, spatial_rdd: SpatialRDD, grid_type):
        """Spatially partition both RDDs using one partitioner.

        *spatial_rdd* is partitioned by *grid_type* first; *query_rdd* is then
        partitioned with the partitioner obtained from *spatial_rdd*.

        Returns whatever the ``spatialPartitioning`` call on *query_rdd*
        returns.
        """
        spatial_rdd.spatialPartitioning(grid_type)
        return query_rdd.spatialPartitioning(spatial_rdd.getPartitioner())

    def expect_to_preserve_original_duplicates(self, grid_type):
        """Return ``True`` when *grid_type* is ``QUADTREE`` or ``KDBTREE``."""
        return grid_type in (GridType.QUADTREE, GridType.KDBTREE)

    def count_join_results(self, results):
        """Return the total number of matches across all rows of *results*.

        Each row is indexed as a pair whose second item is a sized collection
        of matches; an empty *results* yields ``0``.
        """
        return sum(len(row[1]) for row in results)

    def sanity_check_join_results(self, results):
        """Assert every row has matches and each match intersects its key.

        For each row, the second item must be non-empty and the ``geom`` of
        the first item must intersect the ``geom`` of every match.
        """
        for row in results:
            matches = row[1]
            # A row with no matches should never appear in the join output.
            assert len(matches) > 0
            query = row[0]
            for match in matches:
                assert query.geom.intersects(match.geom)

    def sanity_check_flat_join_results(self, results):
        """Assert the two geometries of every flat result pair intersect."""
        for row in results:
            assert row[0].geom.intersects(row[1].geom)