# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from pyspark.sql import Row
from pyspark.sql.functions import broadcast, expr
from pyspark.sql.types import (
DoubleType,
IntegerType,
StringType,
StructField,
StructType,
)

from tests import (
csv_point1_input_location,
csv_point2_input_location,
csv_point_input_location,
csv_polygon1_input_location,
csv_polygon1_random_input_location,
csv_polygon2_input_location,
csv_polygon2_random_input_location,
csv_polygon_input_location,
overlap_polygon_input_location,
)
from tests.test_base import TestBase


class TestPredicateJoin(TestBase):
def test_st_contains_in_join(self):
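        """Join polygons and points on ST_Contains and verify the match count."""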
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_csv_df.show()
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_csv_df.show()
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
point_df.show()
range_join_df = self.spark.sql(
"select * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) "
)
range_join_df.explain()
range_join_df.show(3)
assert range_join_df.count() == 1000

    def test_st_intersects_in_a_join(self):
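        """Join polygons and points on ST_Intersects."""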
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_csv_df.show()
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_csv_df.show()
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
point_df.show()
range_join_df = self.spark.sql(
"select * from polygondf, pointdf where ST_Intersects(polygondf.polygonshape,pointdf.pointshape) "
)
range_join_df.explain()
range_join_df.show(3)
assert range_join_df.count() == 1000

    def test_st_touches_in_a_join(self):
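        """Join on ST_Touches; no point in this dataset merely touches a polygon boundary, so the result is empty."""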
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_csv_df.show()
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_csv_df.show()
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
point_df.show()
range_join_df = self.spark.sql(
"select * from polygondf, pointdf where ST_Touches(polygondf.polygonshape,pointdf.pointshape) "
)
range_join_df.explain()
range_join_df.show(3)
assert range_join_df.count() == 0

    def test_st_within_in_a_join(self):
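        """Join on ST_Within with the point geometry as the first argument, mirroring the ST_Contains test."""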
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_csv_df.show()
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_csv_df.show()
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
point_df.show()
range_join_df = self.spark.sql(
"select * from polygondf, pointdf where ST_Within(pointdf.pointshape, polygondf.polygonshape) "
)
range_join_df.explain()
range_join_df.show(3)
assert range_join_df.count() == 1000

    def test_st_overlaps_in_a_join(self):
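        """Join two polygon sets on ST_Overlaps."""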
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_csv_overlap_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(overlap_polygon_input_location)
)
polygon_csv_overlap_df.createOrReplaceTempView("polygonoverlaptable")
polygon_overlap_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygonoverlaptable._c0 as Decimal(24,20)),cast(polygonoverlaptable._c1 as Decimal(24,20)), cast(polygonoverlaptable._c2 as Decimal(24,20)), cast(polygonoverlaptable._c3 as Decimal(24,20))) as polygonshape from polygonoverlaptable"
)
polygon_overlap_df.createOrReplaceTempView("polygonodf")
range_join_df = self.spark.sql(
"select * from polygondf, polygonodf where ST_Overlaps(polygondf.polygonshape, polygonodf.polygonshape)"
)
range_join_df.explain()
range_join_df.show(3)
assert range_join_df.count() == 15

    def test_st_crosses_in_a_join(self):
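        """Join on ST_Crosses; a single point cannot cross a polygon, so the result is empty."""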
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_csv_df.show()
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_csv_df.show()
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
point_df.show()
range_join_df = self.spark.sql(
"select * from polygondf, pointdf where ST_Crosses(pointdf.pointshape, polygondf.polygonshape) "
)
range_join_df.explain()
range_join_df.show(3)
assert range_join_df.count() == 0

    def test_st_distance_radius_in_a_join(self):
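        """Distance join of two point sets with ST_Distance(...) <= 2."""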
point_csv_df_1 = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df_1.createOrReplaceTempView("pointtable")
point_csv_df_1.show()
point_df_1 = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable"
)
point_df_1.createOrReplaceTempView("pointdf1")
point_df_1.show()
point_csv_df_2 = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df_2.createOrReplaceTempView("pointtable")
point_csv_df_2.show()
point_df2 = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape2 from pointtable"
)
point_df2.createOrReplaceTempView("pointdf2")
point_df2.show()
distance_join_df = self.spark.sql(
"select * from pointdf1, pointdf2 where ST_Distance(pointdf1.pointshape1,pointdf2.pointshape2) <= 2"
)
distance_join_df.explain()
distance_join_df.show(10)
assert distance_join_df.count() == 2998

    def test_st_distance_less_radius_in_a_join(self):
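        """Distance join of two point sets with a strict ST_Distance(...) < 2 bound."""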
point_csv_df_1 = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df_1.createOrReplaceTempView("pointtable")
point_csv_df_1.show()
point_df1 = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable"
)
point_df1.createOrReplaceTempView("pointdf1")
point_df1.show()
point_csv_df2 = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df2.createOrReplaceTempView("pointtable")
point_csv_df2.show()
point_df2 = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape2 from pointtable"
)
point_df2.createOrReplaceTempView("pointdf2")
point_df2.show()
distance_join_df = self.spark.sql(
"select * from pointdf1, pointdf2 where ST_Distance(pointdf1.pointshape1,pointdf2.pointshape2) < 2"
)
distance_join_df.explain()
distance_join_df.show(10)
assert distance_join_df.count() == 2998

    def test_st_contains_in_a_range_and_join(self):
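        """Combine an ST_Contains join condition with an additional ST_Contains range filter on the polygon side."""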
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_csv_df.show()
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_csv_df.show()
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
point_df.show()
range_join_df = self.spark.sql(
"select * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) "
+ "and ST_Contains(ST_PolygonFromEnvelope(1.0,101.0,501.0,601.0), polygondf.polygonshape)"
)
range_join_df.explain()
range_join_df.show(3)
assert range_join_df.count() == 500

    def test_super_small_data_join(self):
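        """Run an ST_Within join over tiny in-memory DataFrames built with explicit schemas."""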
raw_point_df = self.spark.createDataFrame(
self.spark.sparkContext.parallelize(
[
Row(1, "40.0", "-120.0"),
Row(2, "30.0", "-110.0"),
Row(3, "20.0", "-100.0"),
]
),
StructType(
[
StructField("id", IntegerType(), True),
StructField("lat", StringType(), True),
StructField("lon", StringType(), True),
]
),
)
raw_point_df.createOrReplaceTempView("rawPointDf")
        point_df = self.spark.sql(
            "select id, ST_Point(cast(lat as Decimal(24,20)), cast(lon as Decimal(24,20))) AS latlon_point FROM rawPointDf"
        )
        point_df.createOrReplaceTempView("pointDf")
        point_df.show(truncate=False)
raw_polygon_df = self.spark.createDataFrame(
self.spark.sparkContext.parallelize(
[
Row("A", 25.0, -115.0, 35.0, -105.0),
Row("B", 25.0, -135.0, 35.0, -125.0),
]
),
StructType(
[
StructField("id", StringType(), True),
StructField("latmin", DoubleType(), True),
StructField("lonmin", DoubleType(), True),
StructField("latmax", DoubleType(), True),
StructField("lonmax", DoubleType(), True),
]
),
)
raw_polygon_df.createOrReplaceTempView("rawPolygonDf")
polygon_envelope_df = self.spark.sql(
"select id, ST_PolygonFromEnvelope("
+ "cast(latmin as Decimal(24,20)), cast(lonmin as Decimal(24,20)), "
+ "cast(latmax as Decimal(24,20)), cast(lonmax as Decimal(24,20))) AS polygon FROM rawPolygonDf"
)
polygon_envelope_df.createOrReplaceTempView("polygonDf")
within_envelope_df = self.spark.sql(
"select * FROM pointDf, polygonDf WHERE ST_Within(pointDf.latlon_point, polygonDf.polygon)"
)
assert within_envelope_df.count() == 1

    def test_st_equals_in_a_join_for_st_point(self):
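        """Join two point sets on ST_Equals; the inputs share exactly 100 equal points."""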
point_csv_df_1 = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point1_input_location)
)
point_csv_df_1.createOrReplaceTempView("pointtable1")
point_csv_df_1.show()
point_df1 = self.spark.sql(
"select ST_Point(cast(pointtable1._c0 as Decimal(24,20)),cast(pointtable1._c1 as Decimal(24,20)) ) as pointshape1 from pointtable1"
)
point_df1.createOrReplaceTempView("pointdf1")
point_df1.show()
point_csv_df2 = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point2_input_location)
)
point_csv_df2.createOrReplaceTempView("pointtable2")
point_csv_df2.show()
point_df2 = self.spark.sql(
"select ST_Point(cast(pointtable2._c0 as Decimal(24,20)),cast(pointtable2._c1 as Decimal(24,20))) as pointshape2 from pointtable2"
)
point_df2.createOrReplaceTempView("pointdf2")
point_df2.show()
equal_join_df = self.spark.sql(
"select * from pointdf1, pointdf2 where ST_Equals(pointdf1.pointshape1,pointdf2.pointshape2) "
)
equal_join_df.explain()
equal_join_df.show(3)
assert (
equal_join_df.count() == 100
), f"Expected 100 but got {equal_join_df.count()}"

    def test_st_equals_in_a_join_for_st_polygon(self):
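        """Join two polygon sets on ST_Equals."""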
polygon_csv_df1 = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon1_input_location)
)
polygon_csv_df1.createOrReplaceTempView("polygontable1")
polygon_csv_df1.show()
polygon_df1 = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable1._c0 as Decimal(24,20)),cast(polygontable1._c1 as Decimal(24,20)), cast(polygontable1._c2 as Decimal(24,20)), cast(polygontable1._c3 as Decimal(24,20))) as polygonshape1 from polygontable1"
)
polygon_df1.createOrReplaceTempView("polygondf1")
polygon_df1.show()
polygon_csv_df2 = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon2_input_location)
)
polygon_csv_df2.createOrReplaceTempView("polygontable2")
polygon_csv_df2.show()
polygon_df2 = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable2._c0 as Decimal(24,20)),cast(polygontable2._c1 as Decimal(24,20)), cast(polygontable2._c2 as Decimal(24,20)), cast(polygontable2._c3 as Decimal(24,20))) as polygonshape2 from polygontable2"
)
polygon_df2.createOrReplaceTempView("polygondf2")
polygon_df2.show()
equal_join_df = self.spark.sql(
"select * from polygondf1, polygondf2 where ST_Equals(polygondf1.polygonshape1,polygondf2.polygonshape2) "
)
equal_join_df.explain()
equal_join_df.show(3)
assert (
equal_join_df.count() == 100
), f"Expected 100 but got {equal_join_df.count()}"

    def test_st_equals_in_a_join_for_st_polygon_random_shuffle(self):
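        """Join on ST_Equals with randomly shuffled polygon inputs."""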
polygon_csv_df1 = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon1_random_input_location)
)
polygon_csv_df1.createOrReplaceTempView("polygontable1")
polygon_csv_df1.show()
polygon_df1 = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable1._c0 as Decimal(24,20)),cast(polygontable1._c1 as Decimal(24,20)), cast(polygontable1._c2 as Decimal(24,20)), cast(polygontable1._c3 as Decimal(24,20))) as polygonshape1 from polygontable1"
)
polygon_df1.createOrReplaceTempView("polygondf1")
polygon_df1.show()
polygon_csv_df2 = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon2_random_input_location)
)
polygon_csv_df2.createOrReplaceTempView("polygontable2")
polygon_csv_df2.show()
polygon_df2 = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable2._c0 as Decimal(24,20)),cast(polygontable2._c1 as Decimal(24,20)), cast(polygontable2._c2 as Decimal(24,20)), cast(polygontable2._c3 as Decimal(24,20))) as polygonshape2 from polygontable2"
)
polygon_df2.createOrReplaceTempView("polygondf2")
polygon_df2.show()
equal_join_df = self.spark.sql(
"select * from polygondf1, polygondf2 where ST_Equals(polygondf1.polygonshape1,polygondf2.polygonshape2) "
)
equal_join_df.explain()
equal_join_df.show(3)
assert (
equal_join_df.count() == 100
), f"Expected 100 but got {equal_join_df.count()}"

    def test_st_equals_in_a_join_for_st_point_and_st_polygon(self):
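        """ST_Equals between a point and a polygon never holds, so the join is empty."""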
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point1_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_csv_df.show()
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20)) ) as pointshape from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
point_df.show()
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon1_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_csv_df.show()
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
equal_join_df = self.spark.sql(
"select * from pointdf, polygondf where ST_Equals(pointdf.pointshape,polygondf.polygonshape) "
)
equal_join_df.explain()
equal_join_df.show(3)
assert equal_join_df.count() == 0, f"Expected 0 but got {equal_join_df.count()}"

    def test_st_contains_in_broadcast_join(self):
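        """Broadcast the polygon side of an ST_Contains join, via both the SQL hint and broadcast(); the point side's 9 partitions are preserved."""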
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_csv_df.show()
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df = polygon_df.repartition(7)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_csv_df.show()
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable"
)
point_df = point_df.repartition(9)
point_df.createOrReplaceTempView("pointdf")
point_df.show()
range_join_df = self.spark.sql(
"select /*+ BROADCAST(polygondf) */ * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) "
)
range_join_df.explain()
range_join_df.show(3)
assert range_join_df.rdd.getNumPartitions() == 9
assert range_join_df.count() == 1000
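        # The same broadcast join expressed through the DataFrame API.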
range_join_df = point_df.alias("pointdf").join(
broadcast(polygon_df).alias("polygondf"),
on=expr("ST_Contains(polygondf.polygonshape, pointdf.pointshape)"),
)
range_join_df.explain()
range_join_df.show(3)
assert range_join_df.rdd.getNumPartitions() == 9
assert range_join_df.count() == 1000