#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import tempfile
from pyspark.errors import AnalysisException
from pyspark.sql import Row
from pyspark.sql.functions import col, lit
from pyspark.sql.readwriter import DataFrameWriterV2
from pyspark.sql.types import (
StructType,
StructField,
StringType,
BinaryType,
ArrayType,
MapType,
)
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase
class ReadwriterTestsMixin:
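    """Shared DataFrameReader/DataFrameWriter test cases; ReadwriterTests below mixes
    them into ReusedSQLTestCase so they run against a real SparkSession."""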
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
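        # mkdtemp() followed by rmtree() leaves a unique path that no longer exists,
        # so the first write below can create it under the default errorifexists mode.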
try:
df.write.json(tmpPath)
actual = self.spark.read.json(tmpPath)
assertDataFrameEqual(df, actual)
schema = StructType([StructField("value", StringType(), True)])
actual = self.spark.read.json(tmpPath, schema)
assertDataFrameEqual(df.select("value"), actual)
df.write.json(tmpPath, "overwrite")
actual = self.spark.read.json(tmpPath)
assertDataFrameEqual(df, actual)
df.write.save(
format="json",
mode="overwrite",
path=tmpPath,
noUse="this options will not be used in save.",
)
actual = self.spark.read.load(
format="json", path=tmpPath, noUse="this options will not be used in load."
)
assertDataFrameEqual(df, actual)
with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
actual = self.spark.read.load(path=tmpPath)
assertDataFrameEqual(df, actual)
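            # Passing None as an option value should be tolerated rather than failing
            # the CSV write below.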
csvpath = os.path.join(tempfile.mkdtemp(), "data")
df.write.option("quote", None).format("csv").save(csvpath)
finally:
shutil.rmtree(tmpPath)
def test_save_and_load_builder(self):
df = self.df
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
try:
df.write.json(tmpPath)
actual = self.spark.read.json(tmpPath)
assertDataFrameEqual(df, actual)
schema = StructType([StructField("value", StringType(), True)])
actual = self.spark.read.json(tmpPath, schema)
assertDataFrameEqual(df.select("value"), actual)
df.write.mode("overwrite").json(tmpPath)
actual = self.spark.read.json(tmpPath)
assertDataFrameEqual(df, actual)
df.write.mode("overwrite").options(
noUse="this options will not be used in save."
).option("noUse", "this option will not be used in save.").format("json").save(
path=tmpPath
)
actual = self.spark.read.format("json").load(
                path=tmpPath, noUse="this option will not be used in load."
)
assertDataFrameEqual(df, actual)
with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
actual = self.spark.read.load(path=tmpPath)
assertDataFrameEqual(df, actual)
finally:
shutil.rmtree(tmpPath)
def test_bucketed_write(self):
data = [
(1, "foo", 3.0),
(2, "foo", 5.0),
(3, "bar", -1.0),
(4, "bar", 6.0),
]
df = self.spark.createDataFrame(data, ["x", "y", "z"])
def count_bucketed_cols(names, table="pyspark_bucket"):
"""Given a sequence of column names and a table name
query the catalog and return number o columns which are
used for bucketing
"""
cols = self.spark.catalog.listColumns(table)
num = len([c for c in cols if c.name in names and c.isBucket])
return num
with self.table("pyspark_bucket"):
# Test write with one bucketing column
df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
self.assertEqual(count_bucketed_cols(["x"]), 1)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
            # Test write with two bucketing columns
df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
# Test write with bucket and sort
df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
self.assertEqual(count_bucketed_cols(["x"]), 1)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
# Test write with a list of columns
df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
# Test write with bucket and sort with a list of columns
(
df.write.bucketBy(2, "x")
.sortBy(["y", "z"])
.mode("overwrite")
.saveAsTable("pyspark_bucket")
)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
# Test write with bucket and sort with multiple columns
(
df.write.bucketBy(2, "x")
.sortBy("y", "z")
.mode("overwrite")
.saveAsTable("pyspark_bucket")
)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
def test_cluster_by(self):
data = [
(1, "foo", 3.0),
(2, "foo", 5.0),
(3, "bar", -1.0),
(4, "bar", 6.0),
]
df = self.spark.createDataFrame(data, ["x", "y", "z"])
def get_cluster_by_cols(table="pyspark_cluster_by"):
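            # The catalog marks clustering columns with the isCluster flag.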
cols = self.spark.catalog.listColumns(table)
return [c.name for c in cols if c.isCluster]
table_name = "pyspark_cluster_by"
with self.table(table_name):
# Test write with one clustering column
df.write.clusterBy("x").mode("overwrite").saveAsTable(table_name)
self.assertEqual(get_cluster_by_cols(), ["x"])
self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
# Test write with two clustering columns
df.write.clusterBy("x", "y").mode("overwrite").option(
"overwriteSchema", "true"
).saveAsTable(table_name)
self.assertEqual(get_cluster_by_cols(), ["x", "y"])
self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
# Test write with a list of columns
df.write.clusterBy(["y", "z"]).mode("overwrite").option(
"overwriteSchema", "true"
).saveAsTable(table_name)
self.assertEqual(get_cluster_by_cols(), ["y", "z"])
self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
# Test write with a tuple of columns
df.write.clusterBy(("x", "z")).mode("overwrite").option(
"overwriteSchema", "true"
).saveAsTable(table_name)
self.assertEqual(get_cluster_by_cols(), ["x", "z"])
self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
def test_insert_into(self):
df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"])
with self.table("test_table"):
df.write.saveAsTable("test_table")
self.assertEqual(2, self.spark.sql("select * from test_table").count())
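            # insertInto appends by default and resolves columns by position, not name.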
df.write.insertInto("test_table")
self.assertEqual(4, self.spark.sql("select * from test_table").count())
df.write.mode("overwrite").insertInto("test_table")
self.assertEqual(2, self.spark.sql("select * from test_table").count())
df.write.insertInto("test_table", True)
self.assertEqual(2, self.spark.sql("select * from test_table").count())
df.write.insertInto("test_table", False)
self.assertEqual(4, self.spark.sql("select * from test_table").count())
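            # An explicit overwrite argument takes precedence over mode(), so this
            # appends even though mode("overwrite") was set.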
df.write.mode("overwrite").insertInto("test_table", False)
self.assertEqual(6, self.spark.sql("select * from test_table").count())
def test_cached_table(self):
with self.table("test_cached_table_1"):
self.spark.range(10).withColumn(
"value_1",
lit(1),
).write.saveAsTable("test_cached_table_1")
with self.table("test_cached_table_2"):
self.spark.range(10).withColumnRenamed("id", "index").withColumn(
"value_2", lit(2)
).write.saveAsTable("test_cached_table_2")
df1 = self.spark.read.table("test_cached_table_1")
df2 = self.spark.read.table("test_cached_table_2")
df3 = self.spark.read.table("test_cached_table_1")
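                # df1 and df3 are independent reads of the same table; the joins below
                # must still resolve each column unambiguously.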
join1 = df1.join(df2, on=df1.id == df2.index).select(df2.index, df2.value_2)
join2 = df3.join(join1, how="left", on=join1.index == df3.id)
self.assertEqual(join2.columns, ["id", "value_1", "index", "value_2"])
def test_binary_type(self):
"""Test that binary type in data sources respects binaryAsBytes config"""
schema = StructType(
[
StructField("id", StringType()),
StructField("bin", BinaryType()),
StructField("arr_bin", ArrayType(BinaryType())),
StructField("map_bin", MapType(StringType(), BinaryType())),
]
)
# Create DataFrame with binary data (can use either bytes or bytearray)
data = [Row(id="1", bin=b"hello", arr_bin=[b"a"], map_bin={"key": b"value"})]
df = self.spark.createDataFrame(data, schema)
tmpPath = tempfile.mkdtemp()
try:
# Write to parquet
df.write.mode("overwrite").parquet(tmpPath)
for conf_value in ["true", "false"]:
expected_type = bytes if conf_value == "true" else bytearray
expected_bin = b"hello" if conf_value == "true" else bytearray(b"hello")
expected_arr = b"a" if conf_value == "true" else bytearray(b"a")
expected_map = b"value" if conf_value == "true" else bytearray(b"value")
with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}):
result = self.spark.read.parquet(tmpPath).collect()
row = result[0]
# Check binary field
self.assertIsInstance(row.bin, expected_type)
self.assertEqual(row.bin, expected_bin)
# Check array of binary
self.assertIsInstance(row.arr_bin[0], expected_type)
self.assertEqual(row.arr_bin[0], expected_arr)
# Check map value
self.assertIsInstance(row.map_bin["key"], expected_type)
self.assertEqual(row.map_bin["key"], expected_map)
finally:
shutil.rmtree(tmpPath)
# "[SPARK-51182]: DataFrameWriter should throw dataPathNotSpecifiedError when path is not
# specified"
def test_save(self):
writer = self.df.write
with self.assertRaisesRegex(Exception, "'path' is not specified."):
writer.save()
class ReadwriterV2TestsMixin:
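    """Test cases for the DataFrameWriterV2 API exposed through DataFrame.writeTo()."""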
def test_api(self):
self.check_api(DataFrameWriterV2)
def check_api(self, tpe):
df = self.df
writer = df.writeTo("testcat.t")
self.assertIsInstance(writer, tpe)
self.assertIsInstance(writer.option("property", "value"), tpe)
self.assertIsInstance(writer.options(property="value"), tpe)
self.assertIsInstance(writer.using("source"), tpe)
self.assertIsInstance(writer.partitionedBy("id"), tpe)
self.assertIsInstance(writer.partitionedBy(col("id")), tpe)
self.assertIsInstance(writer.tableProperty("foo", "bar"), tpe)
def test_partitioning_functions(self):
self.check_partitioning_functions(DataFrameWriterV2)
self.partitioning_functions_user_error()
def check_partitioning_functions(self, tpe):
import datetime
from pyspark.sql.functions.partitioning import years, months, days, hours, bucket
df = self.spark.createDataFrame(
[(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value")
)
writer = df.writeTo("testcat.t")
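        # Every partitioning transform should be accepted by partitionedBy() and
        # return a writer of the same type for further chaining.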
self.assertIsInstance(writer.partitionedBy(years("ts")), tpe)
self.assertIsInstance(writer.partitionedBy(months("ts")), tpe)
self.assertIsInstance(writer.partitionedBy(days("ts")), tpe)
self.assertIsInstance(writer.partitionedBy(hours("ts")), tpe)
self.assertIsInstance(writer.partitionedBy(bucket(11, "id")), tpe)
self.assertIsInstance(writer.partitionedBy(bucket(11, col("id"))), tpe)
self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours(col("ts"))), tpe)
def partitioning_functions_user_error(self):
import datetime
from pyspark.sql.functions.partitioning import years, months, days, hours, bucket
df = self.spark.createDataFrame(
[(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value")
)
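        # Partition transforms are only valid inside partitionedBy(); using them in a
        # plain projection should fail with
        # PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY.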
with self.assertRaisesRegex(
Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY"
):
df.select(years("ts")).collect()
with self.assertRaisesRegex(
Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY"
):
df.select(months("ts")).collect()
with self.assertRaisesRegex(
Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY"
):
df.select(days("ts")).collect()
with self.assertRaisesRegex(
Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY"
):
df.select(hours("ts")).collect()
with self.assertRaisesRegex(
Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY"
):
df.select(bucket(2, "ts")).collect()
def test_create(self):
df = self.df
with self.table("test_table"):
df.writeTo("test_table").using("parquet").create()
self.assertEqual(100, self.spark.sql("select * from test_table").count())
def test_create_without_provider(self):
df = self.df
with self.table("test_table"):
df.writeTo("test_table").create()
self.assertEqual(100, self.spark.sql("select * from test_table").count())
def test_table_overwrite(self):
df = self.df
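        # Overwriting a table that does not exist should fail analysis.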
with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"):
df.writeTo("test_table").overwrite(lit(True))
def test_cluster_by(self):
data = [
(1, "foo", 3.0),
(2, "foo", 5.0),
(3, "bar", -1.0),
(4, "bar", 6.0),
]
df = self.spark.createDataFrame(data, ["x", "y", "z"])
def get_cluster_by_cols(table="pyspark_cluster_by"):
# Note that listColumns only returns top-level clustering columns and doesn't consider
# nested clustering columns as isCluster. This is fine for this test.
cols = self.spark.catalog.listColumns(table)
return [c.name for c in cols if c.isCluster]
table_name = "pyspark_cluster_by"
with self.table(table_name):
# Test write with one clustering column
df.writeTo(table_name).using("parquet").clusterBy("x").create()
self.assertEqual(get_cluster_by_cols(), ["x"])
self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
class ReadwriterTests(ReadwriterTestsMixin, ReusedSQLTestCase):
pass
class ReadwriterV2Tests(ReadwriterV2TestsMixin, ReusedSQLTestCase):
pass
if __name__ == "__main__":
import unittest
from pyspark.sql.tests.test_readwriter import * # noqa: F401
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)