#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import tempfile
from pyspark.errors import AnalysisException
from pyspark.sql.functions import col, lit
from pyspark.sql.readwriter import DataFrameWriterV2
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase
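# Shared tests for the DataFrame reader/writer API (spark.read / df.write):
# save/load round-trips through JSON and CSV, bucketed and clustered writes,
# insertInto semantics, and reading saved tables back for joins.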
class ReadwriterTestsMixin:
    def test_save_and_load(self):
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        try:
            df.write.json(tmpPath)
            actual = self.spark.read.json(tmpPath)
            assertDataFrameEqual(df, actual)

            schema = StructType([StructField("value", StringType(), True)])
            actual = self.spark.read.json(tmpPath, schema)
            assertDataFrameEqual(df.select("value"), actual)

            df.write.json(tmpPath, "overwrite")
            actual = self.spark.read.json(tmpPath)
            assertDataFrameEqual(df, actual)

            df.write.save(
                format="json",
                mode="overwrite",
                path=tmpPath,
                noUse="this options will not be used in save.",
            )
            actual = self.spark.read.load(
                format="json", path=tmpPath, noUse="this options will not be used in load."
            )
            assertDataFrameEqual(df, actual)

            with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
                actual = self.spark.read.load(path=tmpPath)
                assertDataFrameEqual(df, actual)

            csvpath = os.path.join(tempfile.mkdtemp(), "data")
            df.write.option("quote", None).format("csv").save(csvpath)
        finally:
            shutil.rmtree(tmpPath)

    def test_save_and_load_builder(self):
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        try:
            df.write.json(tmpPath)
            actual = self.spark.read.json(tmpPath)
            assertDataFrameEqual(df, actual)

            schema = StructType([StructField("value", StringType(), True)])
            actual = self.spark.read.json(tmpPath, schema)
            assertDataFrameEqual(df.select("value"), actual)

            df.write.mode("overwrite").json(tmpPath)
            actual = self.spark.read.json(tmpPath)
            assertDataFrameEqual(df, actual)

            df.write.mode("overwrite").options(
                noUse="this options will not be used in save."
            ).option("noUse", "this option will not be used in save.").format("json").save(
                path=tmpPath
            )
            actual = self.spark.read.format("json").load(
                path=tmpPath, noUse="this options will not be used in load."
            )
            assertDataFrameEqual(df, actual)

            with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
                actual = self.spark.read.load(path=tmpPath)
                assertDataFrameEqual(df, actual)
        finally:
            shutil.rmtree(tmpPath)

    def test_bucketed_write(self):
        data = [
            (1, "foo", 3.0),
            (2, "foo", 5.0),
            (3, "bar", -1.0),
            (4, "bar", 6.0),
        ]
        df = self.spark.createDataFrame(data, ["x", "y", "z"])

        def count_bucketed_cols(names, table="pyspark_bucket"):
            """Given a sequence of column names and a table name,
            query the catalog and return the number of columns which
            are used for bucketing.
            """
            cols = self.spark.catalog.listColumns(table)
            num = len([c for c in cols if c.name in names and c.isBucket])
            return num

        with self.table("pyspark_bucket"):
            # Test write with one bucketing column
            df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x"]), 1)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Test write with two bucketing columns
            df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Test write with bucket and sort
            df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x"]), 1)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Test write with a list of columns
            df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Test write with bucket and sort with a list of columns
            (
                df.write.bucketBy(2, "x")
                .sortBy(["y", "z"])
                .mode("overwrite")
                .saveAsTable("pyspark_bucket")
            )
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Test write with bucket and sort with multiple columns
            (
                df.write.bucketBy(2, "x")
                .sortBy("y", "z")
                .mode("overwrite")
                .saveAsTable("pyspark_bucket")
            )
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

    def test_cluster_by(self):
        data = [
            (1, "foo", 3.0),
            (2, "foo", 5.0),
            (3, "bar", -1.0),
            (4, "bar", 6.0),
        ]
        df = self.spark.createDataFrame(data, ["x", "y", "z"])

        def get_cluster_by_cols(table="pyspark_cluster_by"):
            cols = self.spark.catalog.listColumns(table)
            return [c.name for c in cols if c.isCluster]

        table_name = "pyspark_cluster_by"
        with self.table(table_name):
            # Test write with one clustering column
            df.write.clusterBy("x").mode("overwrite").saveAsTable(table_name)
            self.assertEqual(get_cluster_by_cols(), ["x"])
            self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))

            # Test write with two clustering columns
            df.write.clusterBy("x", "y").mode("overwrite").option(
                "overwriteSchema", "true"
            ).saveAsTable(table_name)
            self.assertEqual(get_cluster_by_cols(), ["x", "y"])
            self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))

            # Test write with a list of columns
            df.write.clusterBy(["y", "z"]).mode("overwrite").option(
                "overwriteSchema", "true"
            ).saveAsTable(table_name)
            self.assertEqual(get_cluster_by_cols(), ["y", "z"])
            self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))

            # Test write with a tuple of columns
            df.write.clusterBy(("x", "z")).mode("overwrite").option(
                "overwriteSchema", "true"
            ).saveAsTable(table_name)
            self.assertEqual(get_cluster_by_cols(), ["x", "z"])
            self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))

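    # insertInto appends by default; the boolean second argument is the overwrite
    # flag and, when given explicitly, takes precedence over mode("overwrite").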
    def test_insert_into(self):
        df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"])
        with self.table("test_table"):
            df.write.saveAsTable("test_table")
            self.assertEqual(2, self.spark.sql("select * from test_table").count())

            df.write.insertInto("test_table")
            self.assertEqual(4, self.spark.sql("select * from test_table").count())

            df.write.mode("overwrite").insertInto("test_table")
            self.assertEqual(2, self.spark.sql("select * from test_table").count())

            df.write.insertInto("test_table", True)
            self.assertEqual(2, self.spark.sql("select * from test_table").count())

            df.write.insertInto("test_table", False)
            self.assertEqual(4, self.spark.sql("select * from test_table").count())

            df.write.mode("overwrite").insertInto("test_table", False)
            self.assertEqual(6, self.spark.sql("select * from test_table").count())

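    # Reading the same table twice and joining the results should still resolve
    # the join output columns correctly (df1/df3 below come from the same table).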
    def test_cached_table(self):
        with self.table("test_cached_table_1"):
            self.spark.range(10).withColumn(
                "value_1",
                lit(1),
            ).write.saveAsTable("test_cached_table_1")

            with self.table("test_cached_table_2"):
                self.spark.range(10).withColumnRenamed("id", "index").withColumn(
                    "value_2", lit(2)
                ).write.saveAsTable("test_cached_table_2")

                df1 = self.spark.read.table("test_cached_table_1")
                df2 = self.spark.read.table("test_cached_table_2")
                df3 = self.spark.read.table("test_cached_table_1")

                join1 = df1.join(df2, on=df1.id == df2.index).select(df2.index, df2.value_2)
                join2 = df3.join(join1, how="left", on=join1.index == df3.id)

                self.assertEqual(join2.columns, ["id", "value_1", "index", "value_2"])

# "[SPARK-51182]: DataFrameWriter should throw dataPathNotSpecifiedError when path is not
# specified"
def test_save(self):
writer = self.df.write
with self.assertRaisesRegex(Exception, "'path' is not specified."):
writer.save()
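# Shared tests for the DataFrameWriterV2 API (df.writeTo): builder methods,
# partitioning transforms, create/overwrite, and clustered table creation.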
class ReadwriterV2TestsMixin:
    def test_api(self):
        self.check_api(DataFrameWriterV2)

    def check_api(self, tpe):
        df = self.df
        writer = df.writeTo("testcat.t")
        self.assertIsInstance(writer, tpe)
        self.assertIsInstance(writer.option("property", "value"), tpe)
        self.assertIsInstance(writer.options(property="value"), tpe)
        self.assertIsInstance(writer.using("source"), tpe)
        self.assertIsInstance(writer.partitionedBy("id"), tpe)
        self.assertIsInstance(writer.partitionedBy(col("id")), tpe)
        self.assertIsInstance(writer.tableProperty("foo", "bar"), tpe)

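    # Partitioning transforms (years, months, days, hours, bucket) are only valid
    # inside partitionedBy(); selecting them directly should raise
    # PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY.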
    def test_partitioning_functions(self):
        self.check_partitioning_functions(DataFrameWriterV2)
        self.partitioning_functions_user_error()

    def check_partitioning_functions(self, tpe):
        import datetime
        from pyspark.sql.functions.partitioning import years, months, days, hours, bucket

        df = self.spark.createDataFrame(
            [(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value")
        )

        writer = df.writeTo("testcat.t")

        self.assertIsInstance(writer.partitionedBy(years("ts")), tpe)
        self.assertIsInstance(writer.partitionedBy(months("ts")), tpe)
        self.assertIsInstance(writer.partitionedBy(days("ts")), tpe)
        self.assertIsInstance(writer.partitionedBy(hours("ts")), tpe)
        self.assertIsInstance(writer.partitionedBy(bucket(11, "id")), tpe)
        self.assertIsInstance(writer.partitionedBy(bucket(11, col("id"))), tpe)
        self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours(col("ts"))), tpe)

    def partitioning_functions_user_error(self):
        import datetime
        from pyspark.sql.functions.partitioning import years, months, days, hours, bucket

        df = self.spark.createDataFrame(
            [(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value")
        )

        with self.assertRaisesRegex(
            Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY"
        ):
            df.select(years("ts")).collect()
        with self.assertRaisesRegex(
            Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY"
        ):
            df.select(months("ts")).collect()
        with self.assertRaisesRegex(
            Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY"
        ):
            df.select(days("ts")).collect()
        with self.assertRaisesRegex(
            Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY"
        ):
            df.select(hours("ts")).collect()
        with self.assertRaisesRegex(
            Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY"
        ):
            df.select(bucket(2, "ts")).collect()

    def test_create(self):
        df = self.df
        with self.table("test_table"):
            df.writeTo("test_table").using("parquet").create()
            self.assertEqual(100, self.spark.sql("select * from test_table").count())

    def test_create_without_provider(self):
        df = self.df
        with self.table("test_table"):
            df.writeTo("test_table").create()
            self.assertEqual(100, self.spark.sql("select * from test_table").count())

    def test_table_overwrite(self):
        df = self.df
        with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"):
            df.writeTo("test_table").overwrite(lit(True))

    def test_cluster_by(self):
        data = [
            (1, "foo", 3.0),
            (2, "foo", 5.0),
            (3, "bar", -1.0),
            (4, "bar", 6.0),
        ]
        df = self.spark.createDataFrame(data, ["x", "y", "z"])

        def get_cluster_by_cols(table="pyspark_cluster_by"):
            # Note that listColumns only returns top-level clustering columns and doesn't
            # consider nested clustering columns as isCluster. This is fine for this test.
            cols = self.spark.catalog.listColumns(table)
            return [c.name for c in cols if c.isCluster]

        table_name = "pyspark_cluster_by"
        with self.table(table_name):
            # Test write with one clustering column
            df.writeTo(table_name).using("parquet").clusterBy("x").create()
            self.assertEqual(get_cluster_by_cols(), ["x"])
            self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))

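# Concrete suites that run the mixins above against the reusable SparkSession
# provided by ReusedSQLTestCase; keeping the tests in mixins allows other test
# suites to reuse them with a different base class.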
class ReadwriterTests(ReadwriterTestsMixin, ReusedSQLTestCase):
    pass


class ReadwriterV2Tests(ReadwriterV2TestsMixin, ReusedSQLTestCase):
    pass

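# A sketch of how this module can be run standalone (assuming a working PySpark
# environment on PYTHONPATH):
#
#   python -m pyspark.sql.tests.test_readwriter
#
# xmlrunner is optional; without it the tests fall back to the default runner.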
if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_readwriter import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)