| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import unittest |
| |
| from pyspark.sql import Row |
| from pyspark.testing.sqlutils import ReusedSQLTestCase |
| |
| |
class SQLTestsMixin:
    """Checks for ``self.spark.sql``: plain queries, ``args`` parameters and ``{}`` kwargs.

    The host test case is expected to supply ``self.spark``, ``self.tempView`` and the
    ``unittest`` assertion methods used below.
    """

    def test_simple(self):
        # A constant arithmetic query yields a single row holding the sum.
        rows = self.spark.sql("SELECT 1 + 1").collect()
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0][0], 2)

    def test_args_dict(self):
        # A named parameter marker is bound from a dict passed as ``args``.
        with self.tempView("test"):
            self.spark.range(10).createOrReplaceTempView("test")
            query = "SELECT * FROM IDENTIFIER(:table_name)"
            result = self.spark.sql(query, args={"table_name": "test"})

            self.assertEqual(result.count(), 10)
            self.assertEqual(result.limit(5).count(), 5)
            self.assertEqual(result.offset(5).count(), 5)

            self.assertEqual(result.take(1), [Row(id=0)])
            self.assertEqual(result.tail(1), [Row(id=9)])

    def test_args_list(self):
        # Positional ``?`` markers are bound from a list passed as ``args``.
        with self.tempView("test"):
            self.spark.range(10).createOrReplaceTempView("test")
            lower, upper = 1, 6
            result = self.spark.sql(
                "SELECT * FROM test WHERE ? < id AND id < ?", args=[lower, upper]
            )

            self.assertEqual(result.count(), 4)
            self.assertEqual(result.limit(3).count(), 3)
            self.assertEqual(result.offset(3).count(), 1)

            self.assertEqual(result.take(1), [Row(id=2)])
            self.assertEqual(result.tail(1), [Row(id=5)])

    def test_kwargs_literal(self):
        # ``{name}`` placeholders are filled from keyword arguments, alongside a named
        # parameter marker bound through ``args``.
        with self.tempView("test"):
            self.spark.range(10).createOrReplaceTempView("test")

            query = (
                "SELECT * FROM IDENTIFIER(:table_name) WHERE {m1} < id AND id < {m2} OR id = {m3}"
            )
            literals = {"m1": 3, "m2": 7, "m3": 9}
            result = self.spark.sql(query, args={"table_name": "test"}, **literals)

            self.assertEqual(result.count(), 4)
            self.assertEqual(result.collect(), [Row(id=i) for i in (4, 5, 6, 9)])
            self.assertEqual(result.take(1), [Row(id=4)])
            self.assertEqual(result.tail(1), [Row(id=9)])

    def test_kwargs_literal_multiple_ref(self):
        # The same ``{m}`` placeholder is referenced three times in one query.
        with self.tempView("test"):
            self.spark.range(10).createOrReplaceTempView("test")

            query = "SELECT * FROM IDENTIFIER(:table_name) WHERE {m} = id OR id > {m} OR {m} < 0"
            result = self.spark.sql(query, args={"table_name": "test"}, m=6)

            self.assertEqual(result.count(), 4)
            self.assertEqual(result.collect(), [Row(id=i) for i in (6, 7, 8, 9)])
            self.assertEqual(result.take(1), [Row(id=6)])
            self.assertEqual(result.tail(1), [Row(id=9)])

    def test_kwargs_dataframe(self):
        # A DataFrame passed as a keyword argument is used as the ``{df}`` relation.
        source = self.spark.range(10)
        filtered = self.spark.sql("SELECT * FROM {df} WHERE id > 4", df=source)

        self.assertEqual(source.schema, filtered.schema)
        self.assertEqual(filtered.count(), 5)
        self.assertEqual(filtered.take(1), [Row(id=5)])
        self.assertEqual(filtered.tail(1), [Row(id=9)])

    def test_kwargs_dataframe_with_column(self):
        # Columns of the DataFrame kwarg are referenced as ``{df.id}`` and ``{df[id]}``;
        # the parameter dict is deliberately passed positionally here.
        source = self.spark.range(10)
        query = "SELECT * FROM {df} WHERE {df.id} > :m1 AND {df[id]} < :m2"
        filtered = self.spark.sql(query, {"m1": 4, "m2": 9}, df=source)

        self.assertEqual(source.schema, filtered.schema)
        self.assertEqual(filtered.count(), 4)
        self.assertEqual(filtered.take(1), [Row(id=5)])
        self.assertEqual(filtered.tail(1), [Row(id=8)])

    def test_nested_view(self):
        # Each view is defined by a parameterized query over the previous view.
        with self.tempView("v1", "v2", "v3", "v4"):
            self.spark.range(10).createOrReplaceTempView("v1")

            query = "SELECT * FROM IDENTIFIER(:view) WHERE id > :m"
            chain = (("v1", 1, "v2"), ("v2", 2, "v3"), ("v3", 3, "v4"))
            for parent, threshold, child in chain:
                derived = self.spark.sql(query, args={"view": parent, "m": threshold})
                derived.createOrReplaceTempView(child)

            result = self.spark.sql("select * from v4")
            self.assertEqual(result.count(), 6)
            self.assertEqual(result.take(1), [Row(id=4)])
            self.assertEqual(result.tail(1), [Row(id=9)])

    def test_nested_dataframe(self):
        # Each DataFrame is produced by a parameterized query over the previous one.
        base = self.spark.range(10)
        query = "SELECT * FROM {df} WHERE id > ?"

        # Build every layer first, then verify them in creation order.
        layers = []
        current = base
        for bound in (1, 2, 3):
            current = self.spark.sql(query, args=[bound], df=current)
            layers.append(current)

        # (expected row count, expected first id) for each layer; the last id is always 9.
        expectations = ((8, 2), (7, 3), (6, 4))
        for layer, (expected_count, first_id) in zip(layers, expectations):
            self.assertEqual(base.schema, layer.schema)
            self.assertEqual(layer.count(), expected_count)
            self.assertEqual(layer.take(1), [Row(id=first_id)])
            self.assertEqual(layer.tail(1), [Row(id=9)])

    def test_lit_time(self):
        # A TIME literal is expected to come back as a ``datetime.time`` value.
        import datetime

        first_row = self.spark.sql("select TIME '12:34:56'").first()
        value = first_row[0]
        self.assertEqual(value, datetime.time(12, 34, 56))
| |
| |
class SQLTests(SQLTestsMixin, ReusedSQLTestCase):
    """Concrete test case combining ``SQLTestsMixin`` with ``ReusedSQLTestCase``."""
| |
| |
if __name__ == "__main__":
    # Star-import the test module by its package path before handing over to unittest
    # (presumably this same file -- confirm).
    from pyspark.sql.tests.test_sql import *  # noqa: F401

    # Default to unittest's built-in runner; switch to the XML report runner only when
    # ``xmlrunner`` can be imported and instantiated without an ImportError.
    runner = None
    try:
        import xmlrunner  # type: ignore

        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is optional; keep the default runner.
        pass
    unittest.main(testRunner=runner, verbosity=2)