# blob: a8ce71df422296d5b7eeb77ccad7e1430b9ef30d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from sedona.core.enums import FileDataSplitter, GridType, IndexType
from sedona.core.enums.join_build_side import JoinBuildSide
from sedona.core.geom.envelope import Envelope
from sedona.core.spatialOperator import JoinQuery
from sedona.core.spatialOperator.join_params import JoinParams
from tests.spatial_operator.test_join_base import TestJoinBase
from tests.tools import tests_resource
# --- Input files -----------------------------------------------------------
# The rectangle dataset under test. The query-window set resolves to the very
# same file, so it simply reuses the path computed for the input.
input_location = os.path.join(tests_resource, "zcta510-small.csv")
query_window_set = input_location
queryPolygonSet = os.path.join(tests_resource, "primaryroads-polygon.csv")

# --- Reader / query settings -----------------------------------------------
# Only ``splitter`` is referenced by the tests visible in this file; the other
# names are kept because they are part of the module namespace.
splitter = FileDataSplitter.CSV
offset = 0
distance = 0.001

# --- Dataset facts ----------------------------------------------------------
# NOTE(review): presumably the record count and bounding envelope of the input
# dataset -- neither is asserted on in the visible tests; confirm before use.
inputCount = 3000
inputBoundary = Envelope(-171.090042, 145.830505, -14.373765, 49.00127)

# --- Expected join sizes ----------------------------------------------------
# ``match_with_original_duplicates_count`` is the expected value when
# ``expect_to_preserve_original_duplicates(grid_type)`` is true; otherwise the
# tests expect ``match_count``.
match_count = 17599
match_with_original_duplicates_count = 17738
def pytest_generate_tests(metafunc):
    """Parametrize a test method from its class-level ``params`` table.

    pytest calls this hook once for every test function it collects in this
    module. The owning class is expected to expose ``params``: a mapping from
    test-method name to a list of dicts, each dict holding one set of keyword
    arguments for that test.

    Tests without a usable entry (module-level functions, classes that define
    no ``params``, methods missing from the table, or an empty list) are left
    unparametrized instead of aborting collection with an ``AttributeError``,
    ``KeyError`` or ``IndexError``.

    :param metafunc: the pytest ``Metafunc`` object describing the test
        function currently being collected.
    """
    # ``metafunc.cls`` is None for plain functions; getattr handles that too.
    per_class_params = getattr(metafunc.cls, "params", None)
    if not per_class_params:
        return

    funcarglist = per_class_params.get(metafunc.function.__name__)
    if not funcarglist:
        return

    # Sorting gives a deterministic column order. Every dict in the list is
    # expected to carry the same keys as the first one.
    argnames = sorted(funcarglist[0])
    metafunc.parametrize(
        argnames,
        [[funcargs[name] for name in argnames] for funcargs in funcarglist],
    )
# Partitioning configurations shared by every test: one entry per grid type.
# (The QUADTREE entry used to be listed twice, which only re-ran an identical
# test case.)
parameters = [
    dict(num_partitions=11, grid_type=GridType.QUADTREE),
    dict(num_partitions=11, grid_type=GridType.KDBTREE),
]

# Index-based tests additionally vary the index type. The result lists every
# partitioning configuration for QUADTREE first, then again for RTREE.
params_dyn = [
    {**param, "index_type": index_type}
    for index_type in (IndexType.QUADTREE, IndexType.RTREE)
    for param in parameters
]
class TestRectangleJoin(TestJoinBase):
    """Join tests that pair the rectangle dataset with itself.

    Every test builds two rectangle RDDs from ``input_location`` (a query side
    and a spatial side), hands both to ``partition_rdds`` together with the
    parametrized grid type, runs a join and compares the number of matches
    with the expected module-level counts.
    """

    # Consumed by ``pytest_generate_tests``: test-method name -> list of
    # keyword-argument sets, one test invocation per entry.
    params = {
        "test_nested_loop": parameters,
        "test_dynamic_index_int": params_dyn,
        "test_index_int": params_dyn,
    }

    def _expected_match_count(self, grid_type):
        """Return the number of matches expected for ``grid_type``.

        ``match_with_original_duplicates_count`` applies when
        ``expect_to_preserve_original_duplicates(grid_type)`` is true,
        ``match_count`` otherwise.
        """
        if self.expect_to_preserve_original_duplicates(grid_type):
            return match_with_original_duplicates_count
        return match_count

    def test_nested_loop(self, num_partitions, grid_type):
        """Join without building an index and check the grouped result size."""
        query_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions)
        spatial_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions)
        self.partition_rdds(query_rdd, spatial_rdd, grid_type)

        result = JoinQuery.SpatialJoinQuery(
            spatial_rdd, query_rdd, False, True).collect()

        self.sanity_check_join_results(result)
        assert self._expected_match_count(grid_type) == self.count_join_results(result)

    def test_dynamic_index_int(self, num_partitions, grid_type, index_type):
        """Join via ``JoinQuery.spatialJoin`` with ``JoinParams`` and check the flat result size."""
        query_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions)
        spatial_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions)
        self.partition_rdds(query_rdd, spatial_rdd, grid_type)

        join_params = JoinParams(True, index_type, JoinBuildSide.LEFT)
        # Note the argument order: query side first here, spatial side first
        # in the ``SpatialJoinQuery`` calls of the other tests.
        result = JoinQuery.spatialJoin(query_rdd, spatial_rdd, join_params).collect()

        self.sanity_check_flat_join_results(result)
        # The collected result is flat, so its length is the match count.
        assert self._expected_match_count(grid_type) == len(result)

    def test_index_int(self, num_partitions, grid_type, index_type):
        """Build an index on the spatial RDD, join, and check the grouped result size."""
        query_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions)
        spatial_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions)
        self.partition_rdds(query_rdd, spatial_rdd, grid_type)
        spatial_rdd.buildIndex(index_type, True)

        # NOTE(review): an index is built above, yet the third positional
        # argument is False exactly as in ``test_nested_loop`` -- presumably
        # that is the use-index flag; confirm this is intentional.
        result = JoinQuery.SpatialJoinQuery(
            spatial_rdd, query_rdd, False, True).collect()

        self.sanity_check_join_results(result)
        assert self._expected_match_count(grid_type) == self.count_join_results(result)