# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from pyspark.sql.functions import expr
from tests import (
csv_point1_input_location,
csv_point_input_location,
csv_polygon1_input_location,
)
from tests.test_base import TestBase


class TestPredicate(TestBase):
def test_st_contains(self):
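        """ST_Contains keeps only the points that fall inside the query envelope."""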
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
result_df = self.spark.sql(
"select * from pointdf where ST_Contains(ST_PolygonFromEnvelope(1.0,100.0,1000.0,1100.0), pointdf.arealandmark)"
)
result_df.show()
assert result_df.count() == 999

    def test_st_intersects(self):
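        """ST_Intersects selects the same 999 points as the ST_Contains envelope query."""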
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
result_df = self.spark.sql(
"select * from pointdf where ST_Intersects(ST_PolygonFromEnvelope(1.0,100.0,1000.0,1100.0), pointdf.arealandmark)"
)
result_df.show()
assert result_df.count() == 999

    def test_st_within(self):
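        """ST_Within mirrors ST_Contains with the arguments swapped: point first, polygon second."""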
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
result_df = self.spark.sql(
"select * from pointdf where ST_Within(pointdf.arealandmark, ST_PolygonFromEnvelope(1.0,100.0,1000.0,1100.0))"
)
result_df.show()
assert result_df.count() == 999

    def test_st_equals_for_st_point(self):
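        """ST_Equals matches the five fixture points that coincide exactly
        with POINT (100.1 200.1)."""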
point_df_csv = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point1_input_location)
)
point_df_csv.createOrReplaceTempView("pointtable")
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as point from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
        equal_df = self.spark.sql(
            "select * from pointdf where ST_Equals(pointdf.point, ST_Point(100.1, 200.1))"
        )
        equal_df.show()
        assert equal_df.count() == 5, f"Expected 5 values but got {equal_df.count()}"

    def test_st_equals_for_polygon(self):
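        """ST_Equals holds for envelopes built from the same bounds,
        whichever corner order is passed to ST_PolygonFromEnvelope."""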
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon1_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
        equal_df1 = self.spark.sql(
            "select * from polygondf where ST_Equals(polygondf.polygonshape, ST_PolygonFromEnvelope(100.01,200.01,100.5,200.5))"
        )
        equal_df1.show()
        assert equal_df1.count() == 5, f"Expected 5 values but got {equal_df1.count()}"
        equal_df_2 = self.spark.sql(
            "select * from polygondf where ST_Equals(polygondf.polygonshape, ST_PolygonFromEnvelope(100.5,200.5,100.01,200.01))"
        )
        equal_df_2.show()
        assert equal_df_2.count() == 5, f"Expected 5 values but got {equal_df_2.count()}"

    def test_st_equals_for_st_point_and_st_polygon(self):
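        """A point is never spatially equal to a polygon."""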
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon1_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
        equal_df = self.spark.sql(
            "select * from polygondf where ST_Equals(polygondf.polygonshape, ST_Point(91.01,191.01))"
        )
        equal_df.show()
        assert equal_df.count() == 0, f"Expected 0 values but got {equal_df.count()}"

    def test_st_equals_for_st_linestring_and_st_polygon(self):
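        """A closed linestring tracing the polygon's boundary is not equal
        to the polygon itself: their dimensions differ."""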
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon1_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
string = "100.01,200.01,100.5,200.01,100.5,200.5,100.01,200.5,100.01,200.01"
        equal_df = self.spark.sql(
            f"select * from polygondf where ST_Equals(polygondf.polygonshape, ST_LineStringFromText('{string}', ','))"
        )
        equal_df.show()
        assert equal_df.count() == 0, f"Expected 0 values but got {equal_df.count()}"

    def test_st_equals_for_st_polygon_from_envelope_and_st_polygon_from_text(self):
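        """The same ring built with ST_PolygonFromText is spatially equal
        to the polygons built with ST_PolygonFromEnvelope."""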
polygon_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_polygon1_input_location)
)
polygon_csv_df.createOrReplaceTempView("polygontable")
polygon_df = self.spark.sql(
"select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
)
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
string = "100.01,200.01,100.5,200.01,100.5,200.5,100.01,200.5,100.01,200.01"
        equal_df = self.spark.sql(
            f"select * from polygondf where ST_Equals(polygondf.polygonshape, ST_PolygonFromText('{string}', ','))"
        )
        equal_df.show()
        assert equal_df.count() == 5, f"Expected 5 values but got {equal_df.count()}"

    def test_st_crosses(self):
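        """A line that enters and leaves a polygon crosses it; two polygons
        whose intersection has the same dimension as the inputs do not."""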
crosses_test_table = self.spark.sql(
"select ST_GeomFromWKT('POLYGON((1 1, 4 1, 4 4, 1 4, 1 1))') as a,ST_GeomFromWKT('LINESTRING(1 5, 5 1)') as b"
)
crosses_test_table.createOrReplaceTempView("crossesTesttable")
        crosses = self.spark.sql("select ST_Crosses(a, b) from crossesTesttable")
not_crosses_test_table = self.spark.sql(
"select ST_GeomFromWKT('POLYGON((1 1, 4 1, 4 4, 1 4, 1 1))') as a,ST_GeomFromWKT('POLYGON((2 2, 5 2, 5 5, 2 5, 2 2))') as b"
)
not_crosses_test_table.createOrReplaceTempView("notCrossesTesttable")
        not_crosses = self.spark.sql(
            "select ST_Crosses(a, b) from notCrossesTesttable"
        )
assert crosses.take(1)[0][0]
assert not not_crosses.take(1)[0][0]

    def test_st_touches(self):
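        """ST_Touches is true when the geometries share boundary points but
        no interior points; exactly one fixture point lies on the envelope's
        boundary."""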
point_csv_df = (
self.spark.read.format("csv")
.option("delimiter", ",")
.option("header", "false")
.load(csv_point_input_location)
)
point_csv_df.createOrReplaceTempView("pointtable")
point_df = self.spark.sql(
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable"
)
point_df.createOrReplaceTempView("pointdf")
result_df = self.spark.sql(
"select * from pointdf where ST_Touches(pointdf.arealandmark, ST_PolygonFromEnvelope(0.0,99.0,1.1,101.1))"
)
result_df.show()
assert result_df.count() == 1

    def test_st_relate(self):
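        """Two-argument ST_Relate returns the DE-9IM intersection matrix;
        the three-argument form tests that matrix against a pattern."""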
        base_df = self.spark.sql(
            "SELECT ST_GeomFromWKT('LINESTRING (1 1, 5 5)') AS g1, ST_GeomFromWKT('POLYGON ((3 3, 3 7, 7 7, 7 3, 3 3))') AS g2, '1010F0212' AS im"
        )
        actual = base_df.selectExpr("ST_Relate(g1, g2)").take(1)[0][0]
        assert actual == "1010F0212"
        actual = base_df.selectExpr("ST_Relate(g1, g2, im)").take(1)[0][0]
assert actual

    def test_st_relate_match(self):
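        """ST_RelateMatch checks a concrete DE-9IM matrix against an
        intersection-matrix pattern."""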
        actual = self.spark.sql(
            "SELECT ST_RelateMatch('101202FFF', 'TTTTTTFFF')"
        ).take(1)[0][0]
assert actual

    def test_st_overlaps(self):
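        """ST_Overlaps requires a partial overlap of equal dimension;
        the identical polygons c and d therefore do not overlap."""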
test_table = self.spark.sql(
"select ST_GeomFromWKT('POLYGON((2.5 2.5, 2.5 4.5, 4.5 4.5, 4.5 2.5, 2.5 2.5))') as a,ST_GeomFromWKT('POLYGON((4 4, 4 6, 6 6, 6 4, 4 4))') as b, ST_GeomFromWKT('POLYGON((5 5, 4 6, 6 6, 6 4, 5 5))') as c, ST_GeomFromWKT('POLYGON((5 5, 4 6, 6 6, 6 4, 5 5))') as d"
)
test_table.createOrReplaceTempView("testtable")
overlaps = self.spark.sql("select ST_Overlaps(a,b) from testtable")
not_overlaps = self.spark.sql("select ST_Overlaps(c,d) from testtable")
assert overlaps.take(1)[0][0]
assert not not_overlaps.take(1)[0][0]

    def test_st_ordering_equals_ok(self):
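        """ST_OrderingEquals needs the same vertices in the same order:
        c is a different shape and d starts the ring at a different vertex."""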
test_table = self.spark.sql(
"select ST_GeomFromWKT('POLYGON((2 0, 0 2, -2 0, 2 0))') as a, ST_GeomFromWKT('POLYGON((2 0, 0 2, -2 0, 2 0))') as b, ST_GeomFromWKT('POLYGON((2 0, 0 2, -2 0, 0 -2, 2 0))') as c, ST_GeomFromWKT('POLYGON((0 2, -2 0, 2 0, 0 2))') as d"
)
test_table.createOrReplaceTempView("testorderingequals")
order_equals = self.spark.sql(
"select ST_OrderingEquals(a,b) from testorderingequals"
)
not_order_equals_diff_geom = self.spark.sql(
"select ST_OrderingEquals(a,c) from testorderingequals"
)
not_order_equals_diff_order = self.spark.sql(
"select ST_OrderingEquals(a,d) from testorderingequals"
)
assert order_equals.take(1)[0][0]
assert not not_order_equals_diff_geom.take(1)[0][0]
assert not not_order_equals_diff_order.take(1)[0][0]

    def test_st_dwithin(self):
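        """Two points two planar units apart are within distance 3."""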
test_table = self.spark.sql(
"select ST_GeomFromWKT('POINT (0 0)') as origin, ST_GeomFromWKT('POINT (2 0)') as point_1"
)
test_table.createOrReplaceTempView("test_table")
        is_within = self.spark.sql(
            "select ST_DWithin(origin, point_1, 3) from test_table"
        ).head()[0]
        assert is_within is True

    def test_dwithin_use_sphere(self):
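        """With useSpheroid=true the distance is measured in meters on the
        spheroid; Seattle and New York are well over 2,000 km apart."""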
test_table = self.spark.sql(
"select ST_GeomFromWKT('POINT (-122.335167 47.608013)') as seattle, ST_GeomFromWKT('POINT (-73.935242 40.730610)') as ny"
)
test_table.createOrReplaceTempView("test_table")
        is_within = self.spark.sql(
            "select ST_DWithin(seattle, ny, 2000000, true) from test_table"
        ).head()[0]
        assert is_within is False

    def test_dwithin_use_sphere_complex_boolean_expression(self):
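        """The useSpheroid flag can itself be a column expression. Even ids
        measure in spheroid meters, so only the near-coincident pair (about
        1.1 km apart) matches; odd ids measure in planar units, so all ten
        pairs match: 5 * 1 + 5 * 10 = 55."""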
expected = 55
df_point = self.spark.range(10).withColumn("pt", expr("ST_Point(id, id)"))
        df_point_2 = self.spark.range(10).withColumn(
            "pt_2", expr("ST_Point(id, id + 0.01)")
        )
actual = (
df_point.alias("a")
.join(
                df_point_2.alias("b"),
                expr("ST_DWithin(pt, pt_2, 10000, a.`id` % 2 = 0)"),
)
.count()
)
assert expected == actual