#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
from pyspark.sql import Row
from pyspark.testing.sqlutils import ReusedSQLTestCase
class SQLTestsMixin:
    """Checks for ``self.spark.sql``.

    Covers plain queries, named (``:name``) and positional (``?``) parameter
    markers supplied through ``args``, and ``{name}`` placeholders filled from
    keyword arguments (literals, DataFrames and DataFrame columns).

    The host class must provide ``self.spark``, ``self.tempView`` and the
    ``assertEqual`` assertion method.
    """

    def test_simple(self):
        # A constant expression yields exactly one row holding its value.
        rows = self.spark.sql("SELECT 1 + 1").collect()
        self.assertEqual(len(rows), 1)
        first_value = rows[0][0]
        self.assertEqual(first_value, 2)

    def test_args_dict(self):
        # The named marker :table_name is bound from a dict passed as `args`.
        query = "SELECT * FROM IDENTIFIER(:table_name)"
        params = {"table_name": "test"}
        with self.tempView("test"):
            self.spark.range(10).createOrReplaceTempView("test")
            result = self.spark.sql(query, args=params)
            self.assertEqual(result.count(), 10)
            self.assertEqual(result.limit(5).count(), 5)
            self.assertEqual(result.offset(5).count(), 5)
            self.assertEqual(result.take(1), [Row(id=0)])
            self.assertEqual(result.tail(1), [Row(id=9)])

    def test_args_list(self):
        # Positional `?` markers are bound, in order, from a list passed as `args`.
        query = "SELECT * FROM test WHERE ? < id AND id < ?"
        bounds = [1, 6]
        with self.tempView("test"):
            self.spark.range(10).createOrReplaceTempView("test")
            result = self.spark.sql(query, args=bounds)
            self.assertEqual(result.count(), 4)
            self.assertEqual(result.limit(3).count(), 3)
            self.assertEqual(result.offset(3).count(), 1)
            self.assertEqual(result.take(1), [Row(id=2)])
            self.assertEqual(result.tail(1), [Row(id=5)])

    def test_kwargs_literal(self):
        # {m1}/{m2}/{m3} come from keyword arguments; :table_name still comes from `args`.
        query = "SELECT * FROM IDENTIFIER(:table_name) WHERE {m1} < id AND id < {m2} OR id = {m3}"
        literals = {"m1": 3, "m2": 7, "m3": 9}
        with self.tempView("test"):
            self.spark.range(10).createOrReplaceTempView("test")
            result = self.spark.sql(query, args={"table_name": "test"}, **literals)
            self.assertEqual(result.count(), 4)
            expected = [Row(id=i) for i in (4, 5, 6, 9)]
            self.assertEqual(result.collect(), expected)
            self.assertEqual(result.take(1), expected[:1])
            self.assertEqual(result.tail(1), expected[-1:])

    def test_kwargs_literal_multiple_ref(self):
        # A single keyword argument may be referenced several times in one query.
        query = "SELECT * FROM IDENTIFIER(:table_name) WHERE {m} = id OR id > {m} OR {m} < 0"
        with self.tempView("test"):
            self.spark.range(10).createOrReplaceTempView("test")
            result = self.spark.sql(query, args={"table_name": "test"}, m=6)
            self.assertEqual(result.count(), 4)
            expected = [Row(id=i) for i in (6, 7, 8, 9)]
            self.assertEqual(result.collect(), expected)
            self.assertEqual(result.take(1), expected[:1])
            self.assertEqual(result.tail(1), expected[-1:])

    def test_kwargs_dataframe(self):
        # A DataFrame passed as a keyword argument is usable as a relation via {df}.
        source = self.spark.range(10)
        filtered = self.spark.sql("SELECT * FROM {df} WHERE id > 4", df=source)
        self.assertEqual(source.schema, filtered.schema)
        self.assertEqual(filtered.count(), 5)
        self.assertEqual(filtered.take(1), [Row(id=5)])
        self.assertEqual(filtered.tail(1), [Row(id=9)])

    def test_kwargs_dataframe_with_column(self):
        # Columns are reachable as {df.id} and {df[id]}; the marker values are
        # deliberately passed positionally (second argument) rather than as `args=`.
        source = self.spark.range(10)
        query = "SELECT * FROM {df} WHERE {df.id} > :m1 AND {df[id]} < :m2"
        filtered = self.spark.sql(query, {"m1": 4, "m2": 9}, df=source)
        self.assertEqual(source.schema, filtered.schema)
        self.assertEqual(filtered.count(), 4)
        self.assertEqual(filtered.take(1), [Row(id=5)])
        self.assertEqual(filtered.tail(1), [Row(id=8)])

    def test_nested_view(self):
        # Each temp view filters the previous one with a larger lower bound,
        # so v4 ends up holding ids 4..9.
        view_query = "SELECT * FROM IDENTIFIER(:view) WHERE id > :m"
        steps = (("v1", "v2", 1), ("v2", "v3", 2), ("v3", "v4", 3))
        with self.tempView("v1", "v2", "v3", "v4"):
            self.spark.range(10).createOrReplaceTempView("v1")
            for source_view, target_view, lower_bound in steps:
                layered = self.spark.sql(
                    view_query,
                    args={"view": source_view, "m": lower_bound},
                )
                layered.createOrReplaceTempView(target_view)
            result = self.spark.sql("select * from v4")
            self.assertEqual(result.count(), 6)
            self.assertEqual(result.take(1), [Row(id=4)])
            self.assertEqual(result.tail(1), [Row(id=9)])

    def test_nested_dataframe(self):
        # Same layering as the nested-view test, but each step feeds the
        # previous DataFrame back in through {df} with a positional marker.
        query = "SELECT * FROM {df} WHERE id > ?"
        base = self.spark.range(10)
        chain = [base]
        for lower_bound in (1, 2, 3):
            chain.append(self.spark.sql(query, args=[lower_bound], df=chain[-1]))

        # (expected row count, expected first id) for each derived frame, in order.
        expectations = ((8, 2), (7, 3), (6, 4))
        for frame, (row_count, first_id) in zip(chain[1:], expectations):
            self.assertEqual(base.schema, frame.schema)
            self.assertEqual(frame.count(), row_count)
            self.assertEqual(frame.take(1), [Row(id=first_id)])
            self.assertEqual(frame.tail(1), [Row(id=9)])

    def test_lit_time(self):
        # A TIME literal is returned to Python as a datetime.time value.
        from datetime import time

        row = self.spark.sql("select TIME '12:34:56'").first()
        actual = row[0]
        self.assertEqual(actual, time(12, 34, 56))
class SQLTests(SQLTestsMixin, ReusedSQLTestCase):
    """Concrete test case combining ``SQLTestsMixin`` with ``ReusedSQLTestCase``."""
if __name__ == "__main__":
    from pyspark.sql.tests.test_sql import *  # noqa: F401

    # Start with no explicit runner (unittest then picks its default) and
    # switch to XML reports only when the optional xmlrunner package imports.
    testRunner = None
    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is optional; keep the default runner selected above.
        pass
    unittest.main(testRunner=testRunner, verbosity=2)