blob: 626c282bbb2e14c284c5c079ed594b3d23b07843 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from decimal import Decimal
import os
import pydoc
import shutil
import tempfile
import time
import unittest
from typing import cast
import io
from contextlib import redirect_stdout
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import col, lit, count, sum, mean, struct
from pyspark.sql.pandas.utils import pyarrow_version_less_than_minimum
from pyspark.sql.types import (
StringType,
IntegerType,
DoubleType,
LongType,
StructType,
StructField,
BooleanType,
DateType,
TimestampType,
TimestampNTZType,
FloatType,
DayTimeIntervalType,
)
from pyspark.storagelevel import StorageLevel
from pyspark.errors import (
AnalysisException,
IllegalArgumentException,
PySparkTypeError,
)
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
SQLTestUtils,
have_pyarrow,
have_pandas,
pandas_requirement_message,
pyarrow_requirement_message,
)
from pyspark.testing.utils import QuietTest
class DataFrameTestsMixin:
def test_range(self):
    """Check ``spark.range`` row counts for empty, descending, large-step and negative ranges."""
    # Each entry is (positional arguments to spark.range, expected row count).
    expectations = [
        ((1, 1), 0),
        ((1, 0, -1), 1),
        ((0, 1 << 40, 1 << 39), 2),
        ((-2,), 0),
        ((3,), 3),
    ]
    for range_args, expected_count in expectations:
        self.assertEqual(self.spark.range(*range_args).count(), expected_count)
def test_duplicated_column_names(self):
    """Rows with duplicate column names are readable by position, but the column is ambiguous."""
    df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
    first_row = df.select("*").first()
    self.assertEqual(1, first_row[0])
    self.assertEqual(2, first_row[1])
    self.assertEqual("Row(c=1, c=2)", str(first_row))
    # Referring to the duplicated column by index, attribute or key must fail.
    ambiguous_refs = (lambda: df[0], lambda: df.c, lambda: df["c"])
    for make_ref in ambiguous_refs:
        self.assertRaises(AnalysisException, lambda: df.select(make_ref()).first())
def test_freqItems(self):
    """``freqItems`` with support 0.4 reports the values that fill half of the rows."""
    # Every even-indexed row is (1, -2.0), so those values occur in 50% of the 100 rows.
    vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
    df = self.spark.createDataFrame(vals)
    items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
    # assertIn reports the actual container contents on failure.
    self.assertIn(1, items[0])
    self.assertIn(-2.0, items[1])
def test_help_command(self):
    """Regression test for SPARK-5464: rendering help for DataFrame objects must not fail."""
    json_records = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
    df = self.spark.read.json(json_records)
    # pydoc.render_doc() reproduces the help() failure mode without printing any output.
    pydoc.render_doc(df)
    pydoc.render_doc(df.foo)
    pydoc.render_doc(df.take(1))
def test_drop(self):
    """``drop`` accepts column names or Columns and ignores columns that do not exist."""
    df = self.spark.createDataFrame([("A", 50, "Y"), ("B", 60, "Y")], ["name", "age", "active"])
    # Each case is (arguments passed to drop, columns expected to remain).
    cases = [
        (("active",), ["name", "age"]),
        (("active", "nonexistent_column"), ["name", "age"]),
        (("name", "age", "active"), []),
        ((col("name"),), ["age", "active"]),
        ((col("name"), col("age")), ["active"]),
        ((col("name"), col("age"), col("random")), ["active"]),
    ]
    for to_drop, remaining in cases:
        self.assertEqual(df.drop(*to_drop).columns, remaining)
def test_with_columns_renamed(self):
    """``withColumnsRenamed`` renames mapped columns, skips unknown ones and requires a dict."""
    df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])

    # Both columns are renamed.
    self.assertEqual(
        df.withColumnsRenamed({"name": "naam", "age": "leeftijd"}).columns,
        ["naam", "leeftijd"],
    )

    # An entry for a column that does not exist is ignored.
    self.assertEqual(
        df.withColumnsRenamed({"name": "naam", "address": "adres"}).columns,
        ["naam", "age"],
    )

    # A tuple is rejected in place of a mapping.
    with self.assertRaises(PySparkTypeError) as ctx:
        df.withColumnsRenamed(("name", "x"))
    self.check_error(
        exception=ctx.exception,
        error_class="NOT_DICT",
        message_parameters={"arg_name": "colsMap", "arg_type": "tuple"},
    )
def test_drop_duplicates(self):
    """``dropDuplicates`` deduplicates on all columns or on a list of columns.

    SPARK-36034: passing ``subset`` with an incorrect type (a bare string instead of
    a list or tuple) must raise a ``PySparkTypeError``.
    """
    df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
    # The two rows differ only in "age", so they are duplicates only when "name" is considered alone.
    self.assertEqual(df.dropDuplicates().count(), 2)
    self.assertEqual(df.dropDuplicates(["name"]).count(), 1)
    self.assertEqual(df.dropDuplicates(["name", "age"]).count(), 2)
    with self.assertRaises(PySparkTypeError) as pe:
        df.dropDuplicates("name")
    self.check_error(
        exception=pe.exception,
        error_class="NOT_LIST_OR_TUPLE",
        message_parameters={"arg_name": "subset", "arg_type": "str"},
    )
def test_drop_duplicates_with_ambiguous_reference(self):
    """``drop`` on a join result whose output has two ``name`` columns.

    Despite the test name, this exercises ``DataFrame.drop``. Dropping by the string
    ``"name"`` removes both ``name`` columns of the join output.
    """
    people = self.spark.createDataFrame(
        [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]
    )
    heights = self.spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")])
    joined = people.join(heights, people.name == heights.name, "inner")
    self.assertEqual(joined.drop("name", "age").columns, ["height"])
    self.assertEqual(joined.drop("name", joined.age, "unknown").columns, ["height"])
    self.assertEqual(joined.drop("name", "age", joined.height).columns, [])
def test_dropna(self):
    """Exercise ``dropna`` with ``how``, ``thresh`` and ``subset`` combinations."""
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True),
        ]
    )

    def remaining(row, **dropna_kwargs):
        # Rows left after dropna(**dropna_kwargs) on a single-row DataFrame.
        return self.spark.createDataFrame([row], schema).dropna(**dropna_kwargs).count()

    # A row without nulls is kept.
    self.assertEqual(remaining(("Alice", 50, 80.1)), 1)

    # A single null drops the row under the default and under how="any".
    self.assertEqual(remaining(("Alice", None, 80.1)), 0)
    self.assertEqual(remaining(("Alice", None, 80.1), how="any"), 0)

    # how="all" drops only rows whose values are all null.
    self.assertEqual(remaining(("Alice", None, 80.1), how="all"), 1)
    self.assertEqual(remaining((None, None, None), how="all"), 0)

    # how combined with subset looks only at the listed columns.
    self.assertEqual(remaining(("Alice", 50, None), how="any", subset=["name", "age"]), 1)
    self.assertEqual(remaining(("Alice", None, None), how="any", subset=["name", "age"]), 0)

    # thresh keeps rows that have at least that many non-null values.
    self.assertEqual(remaining(("Alice", None, 80.1), thresh=2), 1)
    self.assertEqual(remaining(("Alice", None, None), thresh=2), 0)

    # thresh combined with subset.
    self.assertEqual(remaining(("Alice", 50, None), thresh=2, subset=["name", "age"]), 1)
    self.assertEqual(remaining(("Alice", None, 180.9), thresh=2, subset=["name", "age"]), 0)

    # thresh takes precedence over how.
    self.assertEqual(
        remaining(("Alice", 50, None), how="any", thresh=2, subset=["name", "age"]), 1
    )

    # subset of an unsupported type is rejected.
    with self.assertRaises(PySparkTypeError) as pe:
        self.spark.createDataFrame([("Alice", 50, None)], schema).dropna(subset=10)
    self.check_error(
        exception=pe.exception,
        error_class="NOT_LIST_OR_STR_OR_TUPLE",
        message_parameters={"arg_name": "subset", "arg_type": "int"},
    )
def test_fillna(self):
    """Exercise ``fillna`` with int/float/bool/str values, ``subset``, dicts and bad input."""
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True),
            StructField("spy", BooleanType(), True),
        ]
    )
    all_null = (None, None, None, None)
    name_only = ("Alice", None, None, None)

    def fill_first(row, value, **fillna_kwargs):
        # First row of a one-row DataFrame after fillna(value, **fillna_kwargs).
        return self.spark.createDataFrame([row], schema).fillna(value, **fillna_kwargs).first()

    # Existing non-null values are never overwritten.
    self.assertEqual(fill_first(("Alice", 10, 80.1, True), 50).age, 10)

    # An int fills both numeric columns.
    filled = fill_first(name_only, 50)
    self.assertEqual(filled.age, 50)
    self.assertEqual(filled.height, 50.0)

    # A float fills both numeric columns; the int column receives 50.
    filled = fill_first(name_only, 50.1)
    self.assertEqual(filled.age, 50)
    self.assertEqual(filled.height, 50.1)

    # A bool fills the boolean column but not the numeric ones.
    filled = fill_first(name_only, True)
    self.assertEqual(filled.age, None)
    self.assertEqual(filled.spy, True)

    # A string fills the string column but not the numeric ones.
    filled = fill_first(all_null, "hello")
    self.assertEqual(filled.name, "hello")
    self.assertEqual(filled.age, None)

    # With subset, only listed columns whose type matches the value are filled.
    filled = fill_first(all_null, 50, subset=["name", "age"])
    self.assertEqual(filled.name, None)
    self.assertEqual(filled.age, 50)
    self.assertEqual(filled.height, None)
    self.assertEqual(filled.spy, None)

    filled = fill_first(all_null, "haha", subset=["name", "age"])
    self.assertEqual(filled.name, "haha")
    self.assertEqual(filled.age, None)
    self.assertEqual(filled.height, None)
    self.assertEqual(filled.spy, None)

    filled = fill_first(all_null, True, subset=["name", "spy"])
    self.assertEqual(filled.name, None)
    self.assertEqual(filled.age, None)
    self.assertEqual(filled.height, None)
    self.assertEqual(filled.spy, True)

    # A dict maps column names to fill values, including booleans.
    bool_df = self.spark.createDataFrame([Row(a=None), Row(a=True)])
    self.assertEqual(bool_df.fillna({"a": True}).first().a, True)

    # A list is not an accepted fill value.
    with self.assertRaises(PySparkTypeError) as pe:
        bool_df.fillna(["a", True])
    self.check_error(
        exception=pe.exception,
        error_class="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_STR",
        message_parameters={"arg_name": "value", "arg_type": "list"},
    )

    # subset must be a list or tuple.
    with self.assertRaises(PySparkTypeError) as pe:
        bool_df.fillna(50, subset=10)
    self.check_error(
        exception=pe.exception,
        error_class="NOT_LIST_OR_TUPLE",
        message_parameters={"arg_name": "subset", "arg_type": "int"},
    )
def test_repartitionByRange_dataframe(self):
    """Range repartitioning orders rows by the given columns across partitions.

    Covers ``repartitionByRange(numPartitions, *cols)`` with fewer, equal and more
    partitions than rows, ``repartitionByRange(*cols)`` without an explicit count,
    and the type check on ``numPartitions``.
    """
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True),
        ]
    )
    df1 = self.spark.createDataFrame(
        [("Bob", 27, 66.0), ("Alice", 10, 10.0), ("Bob", 10, 66.0)], schema
    )
    # The same rows in (name, age) order, which is what the range partitioning should yield.
    df2 = self.spark.createDataFrame(
        [("Alice", 10, 10.0), ("Bob", 10, 66.0), ("Bob", 27, 66.0)], schema
    )

    # repartitionByRange(numPartitions, *cols) with fewer partitions than rows.
    df3 = df1.repartitionByRange(2, "name", "age")
    self.assertEqual(df3.rdd.getNumPartitions(), 2)
    self.assertEqual(df3.rdd.first(), df2.rdd.first())
    self.assertEqual(df3.rdd.take(3), df2.rdd.take(3))

    # repartitionByRange(numPartitions, *cols) with one partition per row.
    df4 = df1.repartitionByRange(3, "name", "age")
    self.assertEqual(df4.rdd.getNumPartitions(), 3)
    self.assertEqual(df4.rdd.first(), df2.rdd.first())
    self.assertEqual(df4.rdd.take(3), df2.rdd.take(3))

    # repartitionByRange(numPartitions, *cols) with more partitions requested than rows.
    df5 = df1.repartitionByRange(5, "name", "age")
    self.assertEqual(df5.rdd.first(), df2.rdd.first())
    self.assertEqual(df5.rdd.take(3), df2.rdd.take(3))

    # repartitionByRange(*cols): no explicit partition count is passed.
    df6 = df1.repartitionByRange("name", "age")
    self.assertEqual(df6.rdd.first(), df2.rdd.first())
    self.assertEqual(df6.rdd.take(3), df2.rdd.take(3))

    # numPartitions of an unsupported type is rejected.
    with self.assertRaises(PySparkTypeError) as pe:
        df1.repartitionByRange([10], "name", "age")
    self.check_error(
        exception=pe.exception,
        error_class="NOT_COLUMN_OR_INT_OR_STR",
        message_parameters={"arg_name": "numPartitions", "arg_type": "list"},
    )
def test_replace(self):
    """Exercise ``DataFrame.replace`` with scalars, lists, tuples, dicts, subsets and bad input."""
    from datetime import datetime

    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True),
        ]
    )

    def replaced(row, *replace_args, **replace_kwargs):
        # First row of a one-row DataFrame after replace(*replace_args, **replace_kwargs).
        df = self.spark.createDataFrame([row], schema)
        return df.replace(*replace_args, **replace_kwargs).first()

    # An int replacement also applies to the double column.
    result = replaced(("Alice", 10, 10.0), 10, 20)
    self.assertEqual(result.age, 20)
    self.assertEqual(result.height, 20.0)

    # A double replacement also applies to the int column (as 82).
    result = replaced(("Alice", 80, 80.0), 80.0, 82.1)
    self.assertEqual(result.age, 82)
    self.assertEqual(result.height, 82.1)

    # A string replacement leaves the numeric columns alone.
    result = replaced(("Alice", 10, 80.1), "Alice", "Ann")
    self.assertEqual(result.name, "Ann")
    self.assertEqual(result.age, 10)

    # subset given as a single column name, with and without a matching value.
    self.assertEqual(replaced(("Alice", 10, 80.1), 10, 20, subset="age").age, 20)
    self.assertEqual(replaced(("Alice", 10, 80.1), 10, 20, subset="height").age, 10)

    # Only columns listed in subset are touched; "height" also holds 10.0 but stays unchanged.
    result = replaced(("Alice", 10, 10.0), 10, 20, subset=["name", "age"])
    self.assertEqual(result.name, "Alice")
    self.assertEqual(result.age, 20)
    self.assertEqual(result.height, 10.0)

    # A subset with no matching value changes nothing.
    result = replaced(("Alice", 10, None), 10, 20, subset=["name", "height"])
    self.assertEqual(result.name, "Alice")
    self.assertEqual(result.age, 10)
    self.assertEqual(result.height, None)

    # List and dict forms of to_replace.
    self.assertTupleEqual(replaced(("Alice", 10, 80.1), ["Alice"], ["Ann"]), ("Ann", 10, 80.1))
    self.assertTupleEqual(replaced(("Alice", 10, 80.1), {10: 11}), ("Alice", 11, 80.1))

    # Backward compatibility: a positional value next to a dict is accepted and not applied.
    dummy_value = 1
    self.assertTupleEqual(
        replaced(("Alice", 10, 80.1), {"Alice": "Bob"}, dummy_value), ("Bob", 10, 80.1)
    )

    # A dict mixing int and float keys.
    self.assertTupleEqual(
        replaced(("Alice", 10, 80.1), {10: -10, 80.1: 90.5}), ("Alice", -10, 90.5)
    )

    # Tuple form of to_replace.
    self.assertTupleEqual(
        replaced(("Alice", 10, 80.1), ("Alice",), ("Bob",)), ("Bob", 10, 80.1)
    )

    # Several values across several columns at once.
    self.assertTupleEqual(
        replaced(("Alice", 10, 80.0), (10, 80.0), (20, 90)), ("Alice", 20, 90.0)
    )

    # Mixed int/float numerics, in tuple and in dict form.
    self.assertTupleEqual(
        replaced(("Alice", 10, 80.0), (10, 80), (20, 90.5)), ("Alice", 20, 90.5)
    )
    self.assertTupleEqual(
        replaced(("Alice", 10, 80.0), {10: 20, 80: 90.5}), ("Alice", 20, 90.5)
    )

    # Booleans: replace False with True on computed boolean columns.
    bool_row = (
        self.spark.createDataFrame([("Alice", 10, 80.0)], schema)
        .selectExpr("name = 'Bob'", "age <= 15")
        .replace(False, True)
        .first()
    )
    self.assertTupleEqual(bool_row, (True, True))

    # Replacing a string with None yields a null that dropna then removes.
    without_alice = (
        self.spark.createDataFrame([("Alice", 10, 80.0)], schema)
        .replace("Alice", None)
        .dropna()
    )
    self.assertEqual(without_alice.count(), 0)

    # Numbers and None mixed in the replacement list.
    self.assertTupleEqual(
        replaced(("Alice", 10, 80.0), [10, 80], [20, None]), ("Alice", 20, None)
    )

    # subset that is not a list, tuple or None.
    with self.assertRaises(TypeError):
        replaced(("Alice", 10, 80.1), {10: 11}, subset=1)

    # to_replace and value of different lengths.
    with self.assertRaises(ValueError):
        replaced(("Alice", 10, 80.1), ["Alice", "Bob"], ["Eve"])

    # Unsupported value types.
    with self.assertRaises(TypeError):
        replaced(("Alice", 10, 80.1), datetime.now(), datetime.now())

    # Mixed-type replacements, in list and in dict form.
    with self.assertRaises(ValueError):
        replaced(("Alice", 10, 80.1), ["Alice", 10], ["Eve", 20])
    with self.assertRaises(ValueError):
        replaced(("Alice", 10, 80.1), {"Alice": "Bob", 10: 20})

    # value is required unless to_replace is a dict.
    with self.assertRaises(PySparkTypeError) as pe:
        self.spark.createDataFrame([("Alice", 10, 80.0)], schema).replace(["Alice", "Bob"])
    self.check_error(
        exception=pe.exception,
        error_class="ARGUMENT_REQUIRED",
        message_parameters={"arg_name": "value", "condition": "`to_replace` is dict"},
    )

    # A function is not a valid to_replace.
    with self.assertRaises(PySparkTypeError) as pe:
        self.spark.createDataFrame([("Alice", 10, 80.0)], schema).replace(lambda x: x + 1, 10)
    self.check_error(
        exception=pe.exception,
        error_class="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
        message_parameters={"arg_name": "to_replace", "arg_type": "function"},
    )
def test_with_column_with_existing_name(self):
    """``withColumn`` with an existing column name replaces that column in place."""
    selected = self.df.withColumn("key", self.df.key).select("key").collect()
    self.assertEqual([row.key for row in selected], list(range(100)))
# regression test for SPARK-10417
def test_column_iterator(self):
    """Iterating over a Column must raise ``TypeError`` (SPARK-10417)."""

    def iterate_over_column():
        for _ in self.df.key:
            break

    self.assertRaises(TypeError, iterate_over_column)
def test_with_columns(self):
    """``withColumns`` adds or replaces several columns at once and validates its argument."""
    expected_pairs = [(i, str(i)) for i in range(100)]

    # A single existing column, re-added under its own name.
    rows = self.df.withColumns({"key": self.df.key}).select("key").collect()
    self.assertEqual([r.key for r in rows], list(range(100)))

    # Both existing columns, re-added under their own names.
    rows = (
        self.df.withColumns({"key": self.df.key, "value": self.df.value})
        .select("key", "value")
        .collect()
    )
    self.assertEqual([(r.key, r.value) for r in rows], expected_pairs)

    # Both columns copied under new names.
    rows = (
        self.df.withColumns({"key_alias": self.df.key, "value_alias": self.df.value})
        .select("key_alias", "value_alias")
        .collect()
    )
    self.assertEqual([(r.key_alias, r.value_alias) for r in rows], expected_pairs)

    # A list instead of a dict, or no argument at all, is rejected.
    self.assertRaises(TypeError, self.df.withColumns, ["key"])
    self.assertRaises(Exception, self.df.withColumns)
def test_generic_hints(self):
    """``hint`` returns a DataFrame, and a broadcast hint yields exactly one BroadcastHashJoin."""
    left = self.spark.range(10e10).toDF("id")
    right = self.spark.range(10e10).toDF("id")
    self.assertIsInstance(left.hint("broadcast"), type(left))
    # Extra (dummy) hint parameters are accepted.
    self.assertIsInstance(left.hint("broadcast", "foo", "bar"), type(left))
    with io.StringIO() as captured, redirect_stdout(captured):
        left.join(right.hint("broadcast"), "id").explain(True)
        plan_text = captured.getvalue()
    self.assertEqual(1, plan_text.count("BroadcastHashJoin"))
# add tests for SPARK-23647 (test more types for hint)
def test_extended_hint_types(self):
    """SPARK-23647: hint parameters may be floats, strings or lists and appear in explain()."""
    df = self.spark.range(10e10).toDF("id")
    list_param = ["itworks1", "itworks2", "itworks3"]
    hinted_df = df.hint("my awesome hint", 1.2345, "what", list_param)
    self.assertIsInstance(df.hint("broadcast", []), type(df))
    self.assertIsInstance(df.hint("broadcast", ["foo", "bar"]), type(df))
    with io.StringIO() as captured, redirect_stdout(captured):
        hinted_df.explain(True)
        explain_output = captured.getvalue()
    for expected_token in ("1.2345", "what", "itworks"):
        self.assertGreaterEqual(explain_output.count(expected_token), 1)
def test_unpivot(self):
    """SPARK-39877: test the ``DataFrame.unpivot`` method and its ``melt`` alias.

    Identifier columns may be given as a name, a Column, or a list/tuple of either.
    When the value columns are ``None``, all non-identifier columns are unpivoted.
    Expected rows are spelled with the actual output column names (``var``/``val``).
    """
    df = self.spark.createDataFrame(
        [
            (1, 10, 1.0, "one"),
            (2, 20, 2.0, "two"),
            (3, 30, 3.0, "three"),
        ],
        ["id", "int", "double", "str"],
    )

    with self.subTest(desc="with none identifier"):
        with self.assertRaisesRegex(AssertionError, "ids must not be None"):
            df.unpivot(None, ["int", "double"], "var", "val")

    with self.subTest(desc="with no identifier"):
        for ids in [[], ()]:
            with self.subTest(ids=ids):
                actual = df.unpivot(ids, ["int", "double"], "var", "val")
                self.assertEqual(actual.schema.simpleString(), "struct<var:string,val:double>")
                self.assertEqual(
                    actual.collect(),
                    [
                        Row(var="int", val=10.0),
                        Row(var="double", val=1.0),
                        Row(var="int", val=20.0),
                        Row(var="double", val=2.0),
                        Row(var="int", val=30.0),
                        Row(var="double", val=3.0),
                    ],
                )

    with self.subTest(desc="with single identifier column"):
        for ids in ["id", ["id"], ("id",)]:
            with self.subTest(ids=ids):
                actual = df.unpivot(ids, ["int", "double"], "var", "val")
                self.assertEqual(
                    actual.schema.simpleString(),
                    "struct<id:bigint,var:string,val:double>",
                )
                self.assertEqual(
                    actual.collect(),
                    [
                        Row(id=1, var="int", val=10.0),
                        Row(id=1, var="double", val=1.0),
                        Row(id=2, var="int", val=20.0),
                        Row(id=2, var="double", val=2.0),
                        Row(id=3, var="int", val=30.0),
                        Row(id=3, var="double", val=3.0),
                    ],
                )

    with self.subTest(desc="with multiple identifier columns"):
        for ids in [["id", "double"], ("id", "double")]:
            with self.subTest(ids=ids):
                actual = df.unpivot(ids, ["int", "double"], "var", "val")
                self.assertEqual(
                    actual.schema.simpleString(),
                    "struct<id:bigint,double:double,var:string,val:double>",
                )
                self.assertEqual(
                    actual.collect(),
                    [
                        Row(id=1, double=1.0, var="int", val=10.0),
                        Row(id=1, double=1.0, var="double", val=1.0),
                        Row(id=2, double=2.0, var="int", val=20.0),
                        Row(id=2, double=2.0, var="double", val=2.0),
                        Row(id=3, double=3.0, var="int", val=30.0),
                        Row(id=3, double=3.0, var="double", val=3.0),
                    ],
                )

    with self.subTest(desc="with no identifier columns but none value columns"):
        # Select only columns that share a common data type (double).
        actual = df.select("id", "int", "double").unpivot([], None, "var", "val")
        self.assertEqual(actual.schema.simpleString(), "struct<var:string,val:double>")
        self.assertEqual(
            actual.collect(),
            [
                Row(var="id", val=1.0),
                Row(var="int", val=10.0),
                Row(var="double", val=1.0),
                Row(var="id", val=2.0),
                Row(var="int", val=20.0),
                Row(var="double", val=2.0),
                Row(var="id", val=3.0),
                Row(var="int", val=30.0),
                Row(var="double", val=3.0),
            ],
        )

    with self.subTest(desc="with single identifier columns but none value columns"):
        for ids in ["id", ["id"], ("id",)]:
            with self.subTest(ids=ids):
                # Select only columns that share a common data type (double).
                actual = df.select("id", "int", "double").unpivot(ids, None, "var", "val")
                self.assertEqual(
                    actual.schema.simpleString(), "struct<id:bigint,var:string,val:double>"
                )
                self.assertEqual(
                    actual.collect(),
                    [
                        Row(id=1, var="int", val=10.0),
                        Row(id=1, var="double", val=1.0),
                        Row(id=2, var="int", val=20.0),
                        Row(id=2, var="double", val=2.0),
                        Row(id=3, var="int", val=30.0),
                        Row(id=3, var="double", val=3.0),
                    ],
                )

    with self.subTest(desc="with multiple identifier columns but none given value columns"):
        for ids in [["id", "str"], ("id", "str")]:
            with self.subTest(ids=ids):
                actual = df.unpivot(ids, None, "var", "val")
                self.assertEqual(
                    actual.schema.simpleString(),
                    "struct<id:bigint,str:string,var:string,val:double>",
                )
                self.assertEqual(
                    actual.collect(),
                    [
                        Row(id=1, str="one", var="int", val=10.0),
                        Row(id=1, str="one", var="double", val=1.0),
                        Row(id=2, str="two", var="int", val=20.0),
                        Row(id=2, str="two", var="double", val=2.0),
                        Row(id=3, str="three", var="int", val=30.0),
                        Row(id=3, str="three", var="double", val=3.0),
                    ],
                )

    with self.subTest(desc="with single value column"):
        for values in ["int", ["int"], ("int",)]:
            with self.subTest(values=values):
                actual = df.unpivot("id", values, "var", "val")
                self.assertEqual(
                    actual.schema.simpleString(), "struct<id:bigint,var:string,val:bigint>"
                )
                self.assertEqual(
                    actual.collect(),
                    [
                        Row(id=1, var="int", val=10),
                        Row(id=2, var="int", val=20),
                        Row(id=3, var="int", val=30),
                    ],
                )

    with self.subTest(desc="with multiple value columns"):
        for values in [["int", "double"], ("int", "double")]:
            with self.subTest(values=values):
                actual = df.unpivot("id", values, "var", "val")
                self.assertEqual(
                    actual.schema.simpleString(), "struct<id:bigint,var:string,val:double>"
                )
                self.assertEqual(
                    actual.collect(),
                    [
                        Row(id=1, var="int", val=10.0),
                        Row(id=1, var="double", val=1.0),
                        Row(id=2, var="int", val=20.0),
                        Row(id=2, var="double", val=2.0),
                        Row(id=3, var="int", val=30.0),
                        Row(id=3, var="double", val=3.0),
                    ],
                )

    with self.subTest(desc="with columns"):
        for ids in [df.id, [df.id], (df.id,)]:
            for values in [[df.int, df.double], (df.int, df.double)]:
                with self.subTest(ids=ids, values=values):
                    self.assertEqual(
                        df.unpivot(ids, values, "var", "val").collect(),
                        df.unpivot("id", ["int", "double"], "var", "val").collect(),
                    )

    with self.subTest(desc="with column names and columns"):
        for ids in [[df.id, "str"], (df.id, "str")]:
            for values in [[df.int, "double"], (df.int, "double")]:
                with self.subTest(ids=ids, values=values):
                    self.assertEqual(
                        df.unpivot(ids, values, "var", "val").collect(),
                        df.unpivot(["id", "str"], ["int", "double"], "var", "val").collect(),
                    )

    with self.subTest(desc="melt alias"):
        self.assertEqual(
            df.unpivot("id", ["int", "double"], "var", "val").collect(),
            df.melt("id", ["int", "double"], "var", "val").collect(),
        )
def test_unpivot_negative(self):
    """SPARK-39877: ``unpivot`` rejects empty value columns and values without a common type."""
    df = self.spark.createDataFrame(
        [
            (1, 10, 1.0, "one"),
            (2, 20, 2.0, "two"),
            (3, 30, 3.0, "three"),
        ],
        ["id", "int", "double", "str"],
    )
    missing_values_pattern = (
        r"\[UNPIVOT_REQUIRES_VALUE_COLUMNS] At least one value column "
        r"needs to be specified for UNPIVOT, all columns specified as ids.*"
    )
    type_mismatch_pattern = (
        r"\[UNPIVOT_VALUE_DATA_TYPE_MISMATCH\] Unpivot value columns must share "
        r"a least common type, some types do not: .*"
    )

    with self.subTest(desc="with no value columns"):
        for empty_values in [[], ()]:
            with self.subTest(values=empty_values):
                with self.assertRaisesRegex(AnalysisException, missing_values_pattern):
                    df.unpivot("id", empty_values, "var", "val").collect()

    with self.subTest(desc="with value columns without common data type"):
        with self.assertRaisesRegex(AnalysisException, type_mismatch_pattern):
            df.unpivot("id", ["int", "str"], "var", "val").collect()
def test_observe(self):
    """SPARK-36263: ``DataFrame.observe(Observation, *Column)`` collects metrics and validates input."""
    from pyspark.sql import Observation

    df = self.spark.createDataFrame(
        [
            (1, 1.0, "one"),
            (2, 2.0, "two"),
            (3, 3.0, "three"),
        ],
        ["id", "val", "label"],
    )
    unnamed_observation = Observation()
    named_observation = Observation("metric")
    observed = (
        df.orderBy("id")
        .observe(
            named_observation,
            count(lit(1)).alias("cnt"),
            sum(col("id")).alias("sum"),
            mean(col("val")).alias("mean"),
        )
        .observe(unnamed_observation, count(lit(1)).alias("rows"))
    )

    # Observing is transparent: the collected rows are unchanged.
    expected_rows = [
        {"id": 1, "val": 1.0, "label": "one"},
        {"id": 2, "val": 2.0, "label": "two"},
        {"id": 3, "val": 3.0, "label": "three"},
    ]
    self.assertEqual(expected_rows, [row.asDict() for row in observed.collect()])

    # After the action above, both observations expose their metrics.
    self.assertEqual(named_observation.get, dict(cnt=3, sum=6, mean=2.0))
    self.assertEqual(unnamed_observation.get, dict(rows=3))

    # An Observation name, if given, must be a non-empty string.
    with self.assertRaisesRegex(TypeError, "name should be a string"):
        Observation(123)
    with self.assertRaisesRegex(ValueError, "name should not be empty"):
        Observation("")

    # observe() requires at least one expression...
    with self.assertRaisesRegex(ValueError, "'exprs' should not be empty"):
        df.observe(Observation())

    # ...and every expression must be a Column (not None, not a column name).
    for bad_exprs in [(None,), ("id",), (lit(1), None), (lit(1), "id")]:
        with self.subTest(args=bad_exprs):
            with self.assertRaisesRegex(ValueError, "all 'exprs' should be Column"):
                df.observe(Observation(), *bad_exprs)
def test_observe_str(self):
    """SPARK-38760: ``DataFrame.observe(str, *Column)`` reports metrics on a streaming query.

    Metrics are read back through a ``StreamingQueryListener``. The query and the
    listener are always stopped/removed, even if an assertion fails, so neither
    leaks into other tests that use the same session.
    """
    from pyspark.sql.streaming import StreamingQueryListener

    observed_metrics = None

    class TestListener(StreamingQueryListener):
        def onQueryStarted(self, event):
            pass

        def onQueryProgress(self, event):
            # Keep the observed metrics from the latest progress event.
            nonlocal observed_metrics
            observed_metrics = event.progress.observedMetrics

        def onQueryTerminated(self, event):
            pass

    listener = TestListener()
    self.spark.streams.addListener(listener)
    try:
        df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
        df = df.observe("metric", count(lit(1)).alias("cnt"), sum(col("value")).alias("sum"))
        q = df.writeStream.format("noop").queryName("test").start()
        try:
            self.assertTrue(q.isActive)
            # Let the query run long enough to report progress.
            time.sleep(10)
        finally:
            q.stop()
    finally:
        self.spark.streams.removeListener(listener)

    self.assertIsInstance(observed_metrics, dict)
    self.assertIn("metric", observed_metrics)
    row = observed_metrics["metric"]
    self.assertIsInstance(row, Row)
    self.assertTrue(hasattr(row, "cnt"))
    self.assertTrue(hasattr(row, "sum"))
    self.assertGreaterEqual(row.cnt, 0)
    self.assertGreaterEqual(row.sum, 0)
def test_sample(self):
    """``sample`` validates argument types and rejects a negative fraction."""
    one_row = self.spark.range(1)
    self.assertRaisesRegex(
        TypeError, "should be a bool, float and number", lambda: one_row.sample()
    )
    self.assertRaises(TypeError, lambda: one_row.sample("a"))
    self.assertRaises(TypeError, lambda: one_row.sample(seed="abc"))
    self.assertRaises(IllegalArgumentException, lambda: one_row.sample(-1.0).count())
def test_toDF_with_schema_string(self):
    """``RDD.toDF`` accepts a DDL schema string or a DataType and validates rows against it."""
    data = [Row(key=i, value=str(i)) for i in range(100)]
    rdd = self.sc.parallelize(data, 5)

    # Each case is (schema string, expected simpleString, expected collected rows).
    struct_cases = [
        ("key: int, value: string", "struct<key:int,value:string>", data),
        # Different but compatible field types can be used.
        (
            "key: string, value: string",
            "struct<key:string,value:string>",
            [Row(key=str(i), value=str(i)) for i in range(100)],
        ),
        # Field names can differ, and surrounding whitespace is allowed.
        (" a: int, b: string ", "struct<a:int,b:string>", data),
    ]
    for schema_string, expected_schema, expected_rows in struct_cases:
        df = rdd.toDF(schema_string)
        self.assertEqual(df.schema.simpleString(), expected_schema)
        self.assertEqual(df.collect(), expected_rows)

    # The number of fields must match.
    self.assertRaisesRegex(
        Exception, "Length of object", lambda: rdd.toDF("key: int").collect()
    )

    # Mismatched field types fail at runtime.
    self.assertRaisesRegex(
        Exception,
        "FloatType\\(\\) can not accept",
        lambda: rdd.toDF("key: float, value: string").collect(),
    )

    # Flat values are wrapped into a single "value" column, whether the type is given
    # as a type string or as a DataType.
    keys = rdd.map(lambda row: row.key)
    expected_keys = [Row(key=i) for i in range(100)]
    for flat_schema in ("int", IntegerType()):
        df = keys.toDF(flat_schema)
        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
        self.assertEqual(df.collect(), expected_keys)
def test_join_without_on(self):
    """An inner join without a condition fails unless ``spark.sql.crossJoin.enabled`` is set."""
    left = self.spark.range(1).toDF("a")
    right = self.spark.range(1).toDF("b")

    def unconditioned_inner_join():
        return left.join(right, how="inner").collect()

    with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
        self.assertRaises(AnalysisException, unconditioned_inner_join)
    with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
        self.assertEqual(unconditioned_inner_join(), [Row(a=0, b=0)])
# Regression test for SPARK-14761: invalid join methods must be rejected when `on` is None.
def test_invalid_join_method(self):
    """An unknown ``how`` value is rejected even when no join condition is given."""
    people = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
    heights = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"])
    self.assertRaises(
        IllegalArgumentException,
        lambda: people.join(heights, how="invalid-join-type"),
    )
# Cartesian products require cross join syntax
def test_require_cross(self):
    """With cross joins disabled, a conditionless join fails while ``crossJoin`` still works."""
    left = self.spark.createDataFrame([(1, "1")], ("key", "value"))
    right = self.spark.createDataFrame([(1, "1")], ("key", "value"))
    with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
        # A join without a condition needs the explicit cross-join syntax...
        self.assertRaises(AnalysisException, lambda: left.join(right).collect())
        # ...which crossJoin provides.
        self.assertEqual(1, left.crossJoin(right).count())
def test_cache_dataframe(self):
    """Move a DataFrame through cache/unpersist/persist and check ``storageLevel`` each time."""
    df = self.spark.createDataFrame([(2, 2), (3, 3)])
    # Each step is (action to run, storage level expected after it).
    transitions = [
        (lambda: None, StorageLevel.NONE),
        (lambda: df.cache(), StorageLevel.MEMORY_AND_DISK),
        (lambda: df.unpersist(), StorageLevel.NONE),
        (lambda: df.persist(), StorageLevel.MEMORY_AND_DISK_DESER),
        (lambda: df.unpersist(blocking=True), StorageLevel.NONE),
        (lambda: df.persist(StorageLevel.DISK_ONLY), StorageLevel.DISK_ONLY),
    ]
    try:
        for action, expected_level in transitions:
            action()
            self.assertEqual(df.storageLevel, expected_level)
    finally:
        # Release the cache even if an assertion above failed.
        df.unpersist()
        self.assertEqual(df.storageLevel, StorageLevel.NONE)
def test_cache_table(self):
    """Exercise catalog ``cacheTable``/``uncacheTable``/``clearCache``/``isCached``."""
    spark = self.spark

    def check_cached(name, expected):
        assert_fn = self.assertTrue if expected else self.assertFalse
        assert_fn(spark.catalog.isCached(name))

    with self.tempView("tab1", "tab2"):
        spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1")
        spark.createDataFrame([(2, 4), (3, 4)]).createOrReplaceTempView("tab2")
        check_cached("tab1", False)
        check_cached("tab2", False)

        spark.catalog.cacheTable("tab1")
        check_cached("tab1", True)
        check_cached("tab2", False)

        spark.catalog.cacheTable("tab2")
        spark.catalog.uncacheTable("tab1")
        check_cached("tab1", False)
        check_cached("tab2", True)

        spark.catalog.clearCache()
        check_cached("tab1", False)
        check_cached("tab2", False)

        # Each catalog entry point must fail with an error naming the missing table.
        # The lambda is invoked immediately by assertRaisesRegex, so late binding
        # of ``method_name`` is not an issue here.
        for method_name in ("isCached", "cacheTable", "uncacheTable"):
            self.assertRaisesRegex(
                AnalysisException,
                "does_not_exist",
                lambda: getattr(spark.catalog, method_name)("does_not_exist"),
            )
def _to_pandas(self):
    """Build a 4-row DataFrame of mixed column types and return its ``toPandas()``.

    Columns cover int, string, boolean, float, date, timestamp, timestamp_ntz and
    day-time interval. The second row leaves every date/time/interval column null.
    """
    from datetime import datetime, date, timedelta

    columns = [
        ("a", IntegerType()),
        ("b", StringType()),
        ("c", BooleanType()),
        ("d", FloatType()),
        ("dt", DateType()),
        ("ts", TimestampType()),
        ("ts_ntz", TimestampNTZType()),
        ("dt_interval", DayTimeIntervalType()),
    ]
    schema = StructType()
    for column_name, column_type in columns:
        schema = schema.add(column_name, column_type)

    # The same wall-clock value feeds both the ts and ts_ntz columns.
    ts_1969 = datetime(1969, 1, 1, 1, 1, 1)
    ts_2012 = datetime(2012, 3, 3, 3, 3, 3)
    ts_2100 = datetime(2100, 4, 4, 4, 4, 4)
    rows = [
        (1, "foo", True, 3.0, date(1969, 1, 1), ts_1969, ts_1969, timedelta(days=1)),
        (2, "foo", True, 5.0, None, None, None, None),
        (
            3,
            "bar",
            False,
            -1.0,
            date(2012, 3, 3),
            ts_2012,
            ts_2012,
            timedelta(hours=-1, milliseconds=421),
        ),
        (4, "bar", False, 6.0, date(2100, 4, 4), ts_2100, ts_2100, timedelta(microseconds=123)),
    ]
    return self.spark.createDataFrame(rows, schema).toPandas()
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
def test_to_pandas(self):
    """``toPandas()`` maps each column built by ``_to_pandas`` to the expected dtype."""
    import numpy as np

    pdf = self._to_pandas()
    types = pdf.dtypes
    # ``dtypes`` is a Series labelled by column name, so positional lookups must
    # go through ``iloc``; integer ``[]`` fallback to positions is deprecated in
    # pandas 2.x.
    self.assertEqual(types.iloc[0], np.int32)
    self.assertEqual(types.iloc[1], object)
    self.assertEqual(types.iloc[2], bool)
    self.assertEqual(types.iloc[3], np.float32)
    self.assertEqual(types.iloc[4], object)  # datetime.date
    self.assertEqual(types.iloc[5], "datetime64[ns]")
    self.assertEqual(types.iloc[6], "datetime64[ns]")
    self.assertEqual(types.iloc[7], "timedelta64[ns]")
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
def test_to_pandas_with_duplicated_column_names(self):
    """Run the duplicated-column check with Arrow disabled, then enabled."""
    for use_arrow in (False, True):
        conf = {"spark.sql.execution.arrow.pyspark.enabled": use_arrow}
        with self.sql_conf(conf):
            self.check_to_pandas_with_duplicated_column_names()
def check_to_pandas_with_duplicated_column_names(self):
    """Two columns that share the name ``v`` both convert to int32."""
    import numpy as np

    dtypes = self.spark.sql("select 1 v, 1 v").toPandas().dtypes
    for position in (0, 1):
        self.assertEqual(dtypes.iloc[position], np.int32)
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
def test_to_pandas_on_cross_join(self):
    """Run the cross-join conversion check with Arrow disabled, then enabled."""
    for use_arrow in (False, True):
        conf = {"spark.sql.execution.arrow.pyspark.enabled": use_arrow}
        with self.sql_conf(conf):
            self.check_to_pandas_on_cross_join()
def check_to_pandas_on_cross_join(self):
    """A condition-less left join keeps int32 for both same-named ``v`` columns."""
    import numpy as np

    query = """
select t1.*, t2.* from (
select explode(sequence(1, 3)) v
) t1 left join (
select explode(sequence(1, 3)) v
) t2
"""
    # The join has no condition, so cross joins must be allowed for it to plan.
    with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
        dtypes = self.spark.sql(query).toPandas().dtypes
        self.assertEqual(dtypes.iloc[0], np.int32)
        self.assertEqual(dtypes.iloc[1], np.int32)
@unittest.skipIf(have_pandas, "Required Pandas was found.")
def test_to_pandas_required_pandas_not_found(self):
    """Without pandas installed, ``toPandas()`` raises an informative ImportError."""
    with QuietTest(self.sc), self.assertRaisesRegex(ImportError, "Pandas >= .* must be installed"):
        self._to_pandas()
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
def test_to_pandas_avoid_astype(self):
    """Nullable int columns holding nulls are not forced to int32/float32.

    16777220 is above 2**24, which float32 cannot represent exactly, so the
    column must come back as float64.
    """
    import numpy as np

    schema = StructType().add("a", IntegerType()).add("b", StringType()).add("c", IntegerType())
    data = [(1, "foo", 16777220), (None, "bar", None)]
    df = self.spark.createDataFrame(data, schema)
    types = df.toPandas().dtypes
    # ``dtypes`` is labelled by column name; use ``iloc`` for positional access
    # (integer ``[]`` fallback to positions is deprecated in pandas 2.x).
    self.assertEqual(types.iloc[0], np.float64)  # doesn't convert to np.int32 due to NaN value.
    self.assertEqual(types.iloc[1], object)
    self.assertEqual(types.iloc[2], np.float64)
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
def test_to_pandas_from_empty_dataframe(self):
    """Run the empty-DataFrame dtype check with Arrow enabled, then disabled."""
    for use_arrow in (True, False):
        conf = {"spark.sql.execution.arrow.pyspark.enabled": use_arrow}
        with self.sql_conf(conf):
            self.check_to_pandas_from_empty_dataframe()
def check_to_pandas_from_empty_dataframe(self):
    """An empty result converts to the same dtypes as its non-empty counterpart.

    SPARK-29188 covers the default path and SPARK-30537 the Arrow-enabled path.
    """
    import numpy as np

    sql = """
SELECT CAST(1 AS TINYINT) AS tinyint,
CAST(1 AS SMALLINT) AS smallint,
CAST(1 AS INT) AS int,
CAST(1 AS BIGINT) AS bigint,
CAST(0 AS FLOAT) AS float,
CAST(0 AS DOUBLE) AS double,
CAST(1 AS BOOLEAN) AS boolean,
CAST('foo' AS STRING) AS string,
CAST('2019-01-01' AS TIMESTAMP) AS timestamp,
CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz,
INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval
"""
    nonempty_dtypes = self.spark.sql(sql).toPandas().dtypes
    # filter("False") keeps the schema but drops every row.
    empty_dtypes = self.spark.sql(sql).filter("False").toPandas().dtypes
    self.assertTrue(np.all(empty_dtypes == nonempty_dtypes))
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
def test_to_pandas_from_null_dataframe(self):
    """Run the all-null DataFrame dtype check with Arrow enabled, then disabled."""
    for use_arrow in (True, False):
        conf = {"spark.sql.execution.arrow.pyspark.enabled": use_arrow}
        with self.sql_conf(conf):
            self.check_to_pandas_from_null_dataframe()
def check_to_pandas_from_null_dataframe(self):
    """A row of nulls converts to pandas with the expected per-column dtypes.

    SPARK-29188 covers the default path and SPARK-30537 the Arrow-enabled path.
    Integer columns come back as float64 because they hold NaN.
    """
    import numpy as np

    sql = """
SELECT CAST(NULL AS TINYINT) AS tinyint,
CAST(NULL AS SMALLINT) AS smallint,
CAST(NULL AS INT) AS int,
CAST(NULL AS BIGINT) AS bigint,
CAST(NULL AS FLOAT) AS float,
CAST(NULL AS DOUBLE) AS double,
CAST(NULL AS BOOLEAN) AS boolean,
CAST(NULL AS STRING) AS string,
CAST(NULL AS TIMESTAMP) AS timestamp,
CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz,
INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval
"""
    pdf = self.spark.sql(sql).toPandas()
    types = pdf.dtypes
    # ``dtypes`` is labelled by column name; positional lookups go through
    # ``iloc`` because integer ``[]`` fallback is deprecated in pandas 2.x.
    self.assertEqual(types.iloc[0], np.float64)
    self.assertEqual(types.iloc[1], np.float64)
    self.assertEqual(types.iloc[2], np.float64)
    self.assertEqual(types.iloc[3], np.float64)
    self.assertEqual(types.iloc[4], np.float32)
    self.assertEqual(types.iloc[5], np.float64)
    self.assertEqual(types.iloc[6], object)
    self.assertEqual(types.iloc[7], object)
    self.assertTrue(np.can_cast(np.datetime64, types.iloc[8]))
    self.assertTrue(np.can_cast(np.datetime64, types.iloc[9]))
    self.assertTrue(np.can_cast(np.timedelta64, types.iloc[10]))
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
def test_to_pandas_from_mixed_dataframe(self):
    """Run the partially-null DataFrame dtype check with Arrow enabled, then disabled."""
    for use_arrow in (True, False):
        conf = {"spark.sql.execution.arrow.pyspark.enabled": use_arrow}
        with self.sql_conf(conf):
            self.check_to_pandas_from_mixed_dataframe()
def check_to_pandas_from_mixed_dataframe(self):
    """Dtypes match whether the result holds some nulls or only nulls.

    SPARK-29188 covers the default path and SPARK-30537 the Arrow-enabled path.
    """
    import numpy as np

    sql = """
SELECT CAST(col1 AS TINYINT) AS tinyint,
CAST(col2 AS SMALLINT) AS smallint,
CAST(col3 AS INT) AS int,
CAST(col4 AS BIGINT) AS bigint,
CAST(col5 AS FLOAT) AS float,
CAST(col6 AS DOUBLE) AS double,
CAST(col7 AS BOOLEAN) AS boolean,
CAST(col8 AS STRING) AS string,
timestamp_seconds(col9) AS timestamp,
timestamp_seconds(col10) AS timestamp_ntz,
INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval
FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
(NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
"""
    some_nulls_dtypes = self.spark.sql(sql).toPandas().dtypes
    # Keep only the all-NULL row of the VALUES clause.
    only_nulls_dtypes = self.spark.sql(sql).filter("tinyint is null").toPandas().dtypes
    self.assertTrue(np.all(only_nulls_dtypes == some_nulls_dtypes))
@unittest.skipIf(
    not have_pandas or not have_pyarrow or pyarrow_version_less_than_minimum("2.0.0"),
    pandas_requirement_message
    or pyarrow_requirement_message
    or "Pyarrow version must be 2.0.0 or higher",
)
def test_to_pandas_for_array_of_struct(self):
    """Run the array-of-struct conversion check with Arrow enabled, then disabled."""
    for use_arrow in (True, False):
        conf = {"spark.sql.execution.arrow.pyspark.enabled": use_arrow}
        with self.sql_conf(conf):
            self.check_to_pandas_for_array_of_struct(use_arrow)
def check_to_pandas_for_array_of_struct(self, is_arrow_enabled):
    """SPARK-38098: an array<struct> column converts to a pandas Series.

    Elements come back as numpy arrays on the Arrow path and as Python lists
    otherwise.
    """
    import numpy as np
    import pandas as pd

    rows = [[[("a", 2, 3.0), ("a", 2, 3.0)]], [[("b", 5, 6.0), ("b", 5, 6.0)]]]
    ddl = "array_struct_col Array<struct<col1:string, col2:long, col3:double>>"
    pdf = self.spark.createDataFrame(rows, ddl).toPandas()

    column = pdf["array_struct_col"]
    self.assertEqual(type(pdf), pd.DataFrame)
    self.assertEqual(type(column), pd.Series)
    expected_element_type = np.ndarray if is_arrow_enabled else list
    self.assertEqual(type(column[0]), expected_element_type)
def test_create_dataframe_from_array_of_long(self):
    """An ``array.array('l')`` value round-trips its extreme values as a list."""
    import array

    extremes = [-9223372036854775808, 0, 9223372036854775807]
    df = self.spark.createDataFrame([Row(longarray=array.array("l", extremes))])
    self.assertEqual(df.first(), Row(longarray=extremes))
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
def test_create_dataframe_from_pandas_with_timestamp(self):
    """Date/timestamp columns from pandas get the right Spark types.

    Checked with an inferred schema, then with explicit ``timestamp`` and
    ``timestamp_ntz`` DDL schemas.
    """
    import pandas as pd
    from datetime import datetime

    pdf = pd.DataFrame(
        {"ts": [datetime(2017, 10, 31, 1, 1, 1)], "d": [pd.Timestamp.now().date()]},
        columns=["d", "ts"],
    )
    # (schema passed to createDataFrame, expected type of the "ts" column);
    # schema=None is createDataFrame's default, i.e. type inference.
    cases = [
        (None, TimestampType),
        ("d date, ts timestamp", TimestampType),
        ("d date, ts timestamp_ntz", TimestampNTZType),
    ]
    for schema, ts_type in cases:
        df = self.spark.createDataFrame(pdf, schema=schema)
        self.assertIsInstance(df.schema["ts"].dataType, ts_type)
        self.assertIsInstance(df.schema["d"].dataType, DateType)
@unittest.skipIf(have_pandas, "Required Pandas was found.")
def test_create_dataframe_required_pandas_not_found(self):
    """Without pandas, building a DataFrame from pandas fails with ImportError.

    The regex accepts either Spark's own message or the interpreter's
    missing-module message, depending on where the import fails.
    """
    expected = "(Pandas >= .* must be installed|No module named '?pandas'?)"
    with QuietTest(self.sc), self.assertRaisesRegex(ImportError, expected):
        import pandas as pd
        from datetime import datetime

        pdf = pd.DataFrame(
            {"ts": [datetime(2017, 10, 31, 1, 1, 1)], "d": [pd.Timestamp.now().date()]}
        )
        self.spark.createDataFrame(pdf)
# Regression test for SPARK-23360
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
def test_create_dataframe_from_pandas_with_dst(self):
    """A pandas timestamp round-trips through Spark unchanged.

    Checked under the default environment and again with both the process TZ
    and the Spark session time zone set to America/Los_Angeles.
    """
    import pandas as pd
    from pandas.testing import assert_frame_equal
    from datetime import datetime

    pdf = pd.DataFrame({"time": [datetime(2015, 10, 31, 22, 30)]})
    assert_frame_equal(pdf, self.spark.createDataFrame(pdf).toPandas())

    saved_tz = os.environ.get("TZ", None)
    try:
        tz = "America/Los_Angeles"
        # Change the Python process time zone; tzset() makes it take effect.
        os.environ["TZ"] = tz
        time.tzset()
        with self.sql_conf({"spark.sql.session.timeZone": tz}):
            assert_frame_equal(pdf, self.spark.createDataFrame(pdf).toPandas())
    finally:
        # Restore the caller's TZ (or its absence) and re-apply it.
        del os.environ["TZ"]
        if saved_tz is not None:
            os.environ["TZ"] = saved_tz
        time.tzset()
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
def test_create_dataframe_from_pandas_with_day_time_interval(self):
    """SPARK-37277: a timedelta column round-trips through createDataFrame without Arrow."""
    import pandas as pd
    from datetime import timedelta

    interval = timedelta(microseconds=123)
    df = self.spark.createDataFrame(pd.DataFrame({"a": [interval]}))
    self.assertEqual(df.toPandas().a.iloc[0], interval)
def test_repr_behaviors(self):
    """Check ``__repr__`` and ``_repr_html_`` under the ``spark.sql.repl.eagerEval.*`` confs.

    The ``truncate`` and ``maxNumRows`` confs are layered in nested ``sql_conf``
    blocks, so each inner expectation also sees the outer settings.
    """
    import re

    # Strips leading spaces plus one "|" from every line, so the expected
    # tables below can carry a "|" left margin.
    pattern = re.compile(r"^ *\|", re.MULTILINE)
    df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))

    # test when eager evaluation is enabled and _repr_html_ will not be called
    with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
        expected1 = """+-----+-----+
|| key|value|
|+-----+-----+
|| 1| 1|
||22222|22222|
|+-----+-----+
|"""
        self.assertEqual(re.sub(pattern, "", expected1), df.__repr__())
        with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
            expected2 = """+---+-----+
||key|value|
|+---+-----+
|| 1| 1|
||222| 222|
|+---+-----+
|"""
            self.assertEqual(re.sub(pattern, "", expected2), df.__repr__())
            with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                expected3 = """+---+-----+
||key|value|
|+---+-----+
|| 1| 1|
|+---+-----+
|only showing top 1 row
|"""
                self.assertEqual(re.sub(pattern, "", expected3), df.__repr__())

    # test when eager evaluation is enabled and _repr_html_ will be called
    with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
        expected1 = """<table border='1'>
|<tr><th>key</th><th>value</th></tr>
|<tr><td>1</td><td>1</td></tr>
|<tr><td>22222</td><td>22222</td></tr>
|</table>
|"""
        self.assertEqual(re.sub(pattern, "", expected1), df._repr_html_())
        with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
            expected2 = """<table border='1'>
|<tr><th>key</th><th>value</th></tr>
|<tr><td>1</td><td>1</td></tr>
|<tr><td>222</td><td>222</td></tr>
|</table>
|"""
            self.assertEqual(re.sub(pattern, "", expected2), df._repr_html_())
            with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                expected3 = """<table border='1'>
|<tr><th>key</th><th>value</th></tr>
|<tr><td>1</td><td>1</td></tr>
|</table>
|only showing top 1 row
|"""
                self.assertEqual(re.sub(pattern, "", expected3), df._repr_html_())

    # test when eager evaluation is disabled and _repr_html_ will be called
    with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
        # With eager evaluation off, _repr_html_ returns None and __repr__ shows
        # only the schema summary, whatever truncate/maxNumRows are set to.
        expected = "DataFrame[key: bigint, value: string]"
        self.assertEqual(None, df._repr_html_())
        self.assertEqual(expected, df.__repr__())
        with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
            self.assertEqual(None, df._repr_html_())
            self.assertEqual(expected, df.__repr__())
            with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                self.assertEqual(None, df._repr_html_())
                self.assertEqual(expected, df.__repr__())
def test_to_local_iterator(self):
    """``toLocalIterator()`` yields the same rows, in order, as ``collect()``."""
    df = self.spark.range(8, numPartitions=4)
    expected = df.collect()
    self.assertEqual(expected, list(df.toLocalIterator()))

    # 3 rows spread over 4 partitions leaves at least one partition empty.
    df = self.spark.range(3, numPartitions=4)
    iterator = df.toLocalIterator()
    expected = df.collect()
    self.assertEqual(expected, list(iterator))
def test_to_local_iterator_prefetch(self):
    """With ``prefetchPartitions=True`` the iterator still matches ``collect()``."""
    df = self.spark.range(8, numPartitions=4)
    expected = df.collect()
    self.assertEqual(expected, list(df.toLocalIterator(prefetchPartitions=True)))
def test_to_local_iterator_not_fully_consumed(self):
    """Run the partially-consumed iterator check with log output quieted."""
    quiet = QuietTest(self.sc)
    with quiet:
        self.check_to_local_iterator_not_fully_consumed()
def check_to_local_iterator_not_fully_consumed(self):
    """SPARK-23961: abandoning a partially-read ``toLocalIterator`` must not break later work."""
    # Large enough that writing to the socket will eventually block.
    df = self.spark.range(1 << 20, numPartitions=2)
    it = df.toLocalIterator()
    self.assertEqual(df.take(1)[0], next(it))
    # Drop the only reference; the socket is closed when the iterator is cleaned up.
    del it

    # Ordinary DataFrame operations must still work. zip() pulls from range()
    # first, so the row iterator is advanced exactly 8 times.
    first_eight = [row for _, row in zip(range(8), df.toLocalIterator())]
    self.assertEqual(df.take(8), first_eight)
def test_same_semantics_error(self):
    """``sameSemantics`` given an int raises PySparkTypeError with NOT_STR."""
    with QuietTest(self.sc):
        with self.assertRaises(PySparkTypeError) as ctx:
            self.spark.range(10).sameSemantics(1)
        self.check_error(
            exception=ctx.exception,
            error_class="NOT_STR",
            message_parameters={"arg_name": "other", "arg_type": "int"},
        )
def test_input_files(self):
    """``inputFiles()`` lists every parquet file written under the output path."""
    # mkdtemp() reserves a unique path; the directory itself is removed right away
    # (presumably because the writer's default save mode rejects an existing path).
    tpath = tempfile.mkdtemp()
    shutil.rmtree(tpath)
    try:
        # range(..., numSlices=10): the test expects one input file per slice.
        self.spark.range(1, 100, 1, 10).write.parquet(tpath)
        input_files_list = self.spark.read.parquet(tpath).inputFiles()
        self.assertEqual(len(input_files_list), 10)
        # Every reported file must live under tpath.
        for file_path in input_files_list:
            self.assertIn(tpath, file_path)
    finally:
        # ignore_errors: if the write failed before creating tpath, a
        # FileNotFoundError here would otherwise mask the real failure.
        shutil.rmtree(tpath, ignore_errors=True)
def test_df_show(self):
    """SPARK-35408: ``DataFrame.show`` reports clear errors for bad argument types."""
    df = self.spark.createDataFrame([("foo",)])

    # Argument combinations that must be accepted.
    df.show(5)
    df.show(5, True)
    df.show(5, 1, True)
    df.show(n=5, truncate="1", vertical=False)
    df.show(n=5, truncate=1.5, vertical=False)

    # (positional args, keyword args, expected error class, expected parameters)
    rejected_calls = [
        ((True,), {}, "NOT_INT", {"arg_name": "n", "arg_type": "bool"}),
        ((), {"vertical": "foo"}, "NOT_BOOL", {"arg_name": "vertical", "arg_type": "str"}),
        ((), {"truncate": "foo"}, "NOT_BOOL", {"arg_name": "truncate", "arg_type": "str"}),
    ]
    for args, kwargs, error_class, params in rejected_calls:
        with self.assertRaises(PySparkTypeError) as pe:
            df.show(*args, **kwargs)
        self.check_error(
            exception=pe.exception,
            error_class=error_class,
            message_parameters=params,
        )
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
def test_pandas_api(self):
    """``pandas_api()`` matches plain pandas, with and without ``index_col``."""
    import pandas as pd
    from pandas.testing import assert_frame_equal

    sdf = self.spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["Col1", "Col2"])
    expected = pd.DataFrame({"Col1": ["a", "b", "c"], "Col2": [1, 2, 3]})

    assert_frame_equal(expected, sdf.pandas_api().to_pandas())
    assert_frame_equal(
        expected.set_index("Col1"),
        sdf.pandas_api(index_col="Col1").to_pandas(),
    )
# Regression test for SPARK-36337.
def test_create_nan_decimal_dataframe(self):
    """A ``Decimal("NaN")`` value ends up as null in a decimal column."""
    df = self.spark.createDataFrame(data=[Decimal("NaN")], schema="decimal")
    self.assertEqual(df.collect(), [Row(value=None)])
def test_to(self):
    """``DataFrame.to(schema)`` reorders/casts columns and rejects incompatible schemas."""
    schema = StructType(
        [StructField("i", StringType(), True), StructField("j", IntegerType(), True)]
    )
    df = self.spark.createDataFrame([("a", 1)], schema)

    def check_converted(expected_schema, converted):
        self.assertEqual(expected_schema, converted.schema)
        self.assertEqual(df.count(), converted.count())

    # Column reordering plus casting.
    schema1 = StructType([StructField("j", StringType()), StructField("i", StringType())])
    check_converted(schema1, df.to(schema1))

    # Projecting a single, widened column.
    schema2 = StructType([StructField("j", LongType())])
    check_converted(schema2, df.to(schema2))

    # Nested struct target.
    schema3 = StructType([StructField("struct", schema1, False)])
    check_converted(schema3, df.select(struct("i", "j").alias("struct")).to(schema3))

    # incompatible field nullability
    schema4 = StructType([StructField("j", LongType(), False)])
    self.assertRaisesRegex(
        AnalysisException, "NULLABLE_COLUMN_OR_FIELD", lambda: df.to(schema4).count()
    )

    # field cannot upcast
    schema5 = StructType([StructField("i", LongType())])
    self.assertRaisesRegex(
        AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: df.to(schema5).count()
    )
def test_repartition(self):
    """A list passed as ``numPartitions`` raises NOT_COLUMN_OR_STR."""
    people = [(14, "Tom"), (23, "Alice"), (16, "Bob")]
    df = self.spark.createDataFrame(people, ["age", "name"])
    with self.assertRaises(PySparkTypeError) as ctx:
        df.repartition([10], "name", "age").rdd.getNumPartitions()
    self.check_error(
        exception=ctx.exception,
        error_class="NOT_COLUMN_OR_STR",
        message_parameters={"arg_name": "numPartitions", "arg_type": "list"},
    )
def test_colregex(self):
    """``colRegex`` given a non-string raises NOT_STR for ``colName``."""
    df = self.spark.range(10)
    with self.assertRaises(PySparkTypeError) as ctx:
        df.colRegex(10)
    self.check_error(
        exception=ctx.exception,
        error_class="NOT_STR",
        message_parameters={"arg_name": "colName", "arg_type": "int"},
    )
def test_where(self):
    """``where`` given an int condition raises NOT_COLUMN_OR_STR."""
    df = self.spark.range(10)
    with self.assertRaises(PySparkTypeError) as ctx:
        df.where(10)
    self.check_error(
        exception=ctx.exception,
        error_class="NOT_COLUMN_OR_STR",
        message_parameters={"arg_name": "condition", "arg_type": "int"},
    )
class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
    """Check that a JVM query execution listener fires on ``collect`` and Arrow ``toPandas``.

    These tests are separate because they use 'spark.sql.queryExecutionListeners',
    which is static and immutable: it cannot be set or unset afterwards, for
    example via `spark.conf`. This suite therefore builds its own SparkSession.
    """

    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        spark_home = _find_spark_home()
        listener_class_glob = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "TestQueryExecutionListener.class"
        )
        # The listener class is only present when the Scala test classes were built.
        cls.has_listener = bool(glob.glob(os.path.join(spark_home, listener_class_glob)))
        if not cls.has_listener:
            return

        # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
        builder = SparkSession.builder.master("local[4]").appName(cls.__name__)
        builder = builder.config(
            "spark.sql.queryExecutionListeners",
            "org.apache.spark.sql.TestQueryExecutionListener",
        )
        cls.spark = builder.getOrCreate()

    def setUp(self):
        # skipTest() raises unittest.SkipTest itself, so no explicit raise is needed.
        if not self.has_listener:
            self.skipTest(
                "'org.apache.spark.sql.TestQueryExecutionListener' is not "
                "available. Will skip the related tests."
            )

    @classmethod
    def tearDownClass(cls):
        # cls.spark only exists when setUpClass found the listener class.
        if hasattr(cls, "spark"):
            cls.spark.stop()

    def tearDown(self):
        # Reset the JVM-side flag so each test starts from "not called".
        self.spark._jvm.OnSuccessCall.clear()

    def _wait_for_listener_bus(self):
        # Block (10000 ms timeout) until the listener bus has drained, so the
        # listener callback has had a chance to run before it is checked.
        self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty(10000)

    def test_query_execution_listener_on_collect(self):
        self.assertFalse(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should not be called before 'collect'",
        )
        self.spark.sql("SELECT * FROM range(1)").collect()
        self._wait_for_listener_bus()
        self.assertTrue(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should be called after 'collect'",
        )

    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        cast(str, pandas_requirement_message or pyarrow_requirement_message),
    )
    def test_query_execution_listener_on_collect_with_arrow(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
            self.assertFalse(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should not be "
                "called before 'toPandas'",
            )
            self.spark.sql("SELECT * FROM range(1)").toPandas()
            self._wait_for_listener_bus()
            self.assertTrue(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should be called after 'toPandas'",
            )
class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
    # Concrete test case: binds the DataFrameTestsMixin tests to
    # ReusedSQLTestCase, which presumably supplies the session and helpers the
    # mixin relies on (self.spark, self.sc, sql_conf, tempView, check_error).
    pass
if __name__ == "__main__":
    # Re-import the test classes by their fully-qualified module path.
    # NOTE(review): presumably so test names/reports use the package path
    # rather than __main__ — confirm against the other pyspark test modules.
    from pyspark.sql.tests.test_dataframe import *  # noqa: F401

    # Use xmlrunner (writing reports to target/test-reports) when it is
    # installed; otherwise fall back to unittest's default runner.
    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)