blob: 7800d5756b04e1a23fa7e7596a263b833c047e05 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark import pandas as ps
from pyspark.errors import ParseException
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.testing.utils import assertDataFrameEqual
class SQLTestsMixin:
    """Test cases for ``ps.sql`` formatted SQL queries.

    Mixed into a ``unittest.TestCase`` subclass: the tests rely on ``self``
    providing ``assertRaises`` / ``assertRaisesRegex``.
    """

    def test_error_variable_not_exist(self):
        # A ``{placeholder}`` with no matching keyword argument must raise a
        # KeyError whose message names the missing variable.
        with self.assertRaisesRegex(KeyError, "variable_foo"):
            ps.sql("select * from {variable_foo}")

    def test_error_bad_sql(self):
        # Syntactically invalid SQL is reported as a ParseException.
        with self.assertRaises(ParseException):
            ps.sql("this is not valid sql")

    def test_series_not_referred(self):
        # A Series placeholder used without its parent DataFrame in the query
        # (the FROM clause here is ``range(10)``) is expected to raise ValueError.
        psdf = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
        # The braces are escaped because ``{``/``}`` are regex metacharacters;
        # the unescaped form only matched because ``{ser}`` happens not to be a
        # valid quantifier.
        with self.assertRaisesRegex(ValueError, r"The series in \{ser\}"):
            ps.sql("SELECT {ser} FROM range(10)", ser=psdf.A)

    def test_sql_with_index_col(self):
        import pandas as pd

        # Index: move the index into a column with reset_index(), then ask
        # ps.sql to restore it via ``index_col``. Rows with A > 1 are the
        # rows at positions 1 and 2 of the original frame.
        psdf = ps.DataFrame(
            {"A": [1, 2, 3], "B": [4, 5, 6]}, index=pd.Index(["a", "b", "c"], name="index")
        )
        psdf_reset_index = psdf.reset_index()
        actual = ps.sql(
            "select * from {psdf_reset_index} where A > 1",
            index_col="index",
            psdf_reset_index=psdf_reset_index,
        )
        expected = psdf.iloc[[1, 2]]
        assertDataFrameEqual(actual, expected)

        # MultiIndex: same round-trip, with ``index_col`` given as a list of
        # level names.
        psdf = ps.DataFrame(
            {"A": [1, 2, 3], "B": [4, 5, 6]},
            index=pd.MultiIndex.from_tuples(
                [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
            ),
        )
        psdf_reset_index = psdf.reset_index()
        actual = ps.sql(
            "select * from {psdf_reset_index} where A > 1",
            index_col=["index1", "index2"],
            psdf_reset_index=psdf_reset_index,
        )
        expected = psdf.iloc[[1, 2]]
        assertDataFrameEqual(actual, expected)

    def test_sql_with_pandas_objects(self):
        import pandas as pd

        # Plain pandas objects (a DataFrame and one of its columns) can be
        # used as placeholders.
        pdf = pd.DataFrame({"a": [1, 2, 3, 4]})
        assertDataFrameEqual(
            ps.sql("SELECT {col} + 1 as a FROM {tbl}", col=pdf.a, tbl=pdf), pdf + 1
        )

    def test_sql_with_python_objects(self):
        # A Python string is substituted as a SQL string literal.
        assertDataFrameEqual(
            ps.sql("SELECT {col} as a FROM range(1)", col="lit"), ps.DataFrame({"a": ["lit"]})
        )
        # A tuple is substituted as an IN-list. ``col`` is passed but never
        # referenced, so unused keyword arguments must be tolerated.
        assertDataFrameEqual(
            ps.sql("SELECT id FROM range(10) WHERE id IN {pred}", col="lit", pred=(1, 2, 3)),
            ps.DataFrame({"id": [1, 2, 3]}),
        )
        # Embedded single quotes must survive substitution unchanged.
        assertDataFrameEqual(
            ps.sql("SELECT {col} as a FROM range(1)", col="a'''c''d"),
            ps.DataFrame({"a": ["a'''c''d"]}),
        )
        # An unused quote-heavy argument must not break the rest of the query.
        assertDataFrameEqual(
            ps.sql("SELECT id FROM range(10) WHERE id IN {pred}", col="a'''c''d", pred=(1, 2, 3)),
            ps.DataFrame({"id": [1, 2, 3]}),
        )

    def test_sql_with_pandas_on_spark_objects(self):
        # Columns can be referenced either as separate Series kwargs
        # (``{col}``) or by attribute access on the table placeholder
        # (``{tbl.a}``); both forms must select the same data.
        psdf = ps.DataFrame({"a": [1, 2, 3, 4]})
        assertDataFrameEqual(ps.sql("SELECT {col} FROM {tbl}", col=psdf.a, tbl=psdf), psdf)
        assertDataFrameEqual(ps.sql("SELECT {tbl.a} FROM {tbl}", tbl=psdf), psdf)

        # Same checks with multiple columns.
        psdf = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
        assertDataFrameEqual(
            ps.sql("SELECT {col}, {col2} FROM {tbl}", col=psdf.A, col2=psdf.B, tbl=psdf), psdf
        )
        assertDataFrameEqual(ps.sql("SELECT {tbl.A}, {tbl.B} FROM {tbl}", tbl=psdf), psdf)
class SQLTests(SQLTestsMixin, PandasOnSparkTestCase, SQLTestUtils):
    """Runs the ``SQLTestsMixin`` cases on top of ``PandasOnSparkTestCase`` and ``SQLTestUtils``."""
if __name__ == "__main__":
    import unittest

    from pyspark.pandas.tests.test_sql import *  # noqa: F401

    # Default to unittest's built-in text runner; switch to an XML-reporting
    # runner only when the optional ``xmlrunner`` package is importable.
    runner = None
    try:
        import xmlrunner

        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        pass  # xmlrunner is optional; keep the default runner (None)

    unittest.main(testRunner=runner, verbosity=2)