#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from decimal import Decimal
import os
import time
import unittest
from typing import cast

from pyspark.sql import Row
import pyspark.sql.functions as F
from pyspark.sql.types import (
    DecimalType,
    StructType,
    StructField,
    StringType,
    DateType,
    TimeType,
    TimestampType,
    TimestampNTZType,
)
from pyspark.errors import (
    PySparkTypeError,
    PySparkValueError,
)
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
    ReusedSQLTestCase,
    have_pandas,
    have_pyarrow,
    pandas_requirement_message,
    pyarrow_requirement_message,
)


class DataFrameCreationTestsMixin:
    def test_create_str_from_dict(self):
        data = [
            {"broker": {"teamId": 3398, "contactEmail": "abc.xyz@123.ca"}},
        ]
        for schema in [
            StructType([StructField("broker", StringType())]),
            "broker: string",
        ]:
            df = self.spark.createDataFrame(data, schema=schema)
            self.assertEqual(
                df.first().broker,
                """{'teamId': 3398, 'contactEmail': 'abc.xyz@123.ca'}""",
            )

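    # array.array("l") is expected to round-trip as an array<bigint> column;
    # the test data covers the minimum and maximum of a signed 64-bit long.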
    def test_create_dataframe_from_array_of_long(self):
        import array

        data = [Row(longarray=array.array("l", [-9223372036854775808, 0, 9223372036854775807]))]
        df = self.spark.createDataFrame(data)
        self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))

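    # Plain datetime.time values should be inferred as the SQL TIME type
    # (TimeType), checked below against the resulting schema.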
    def test_create_dataframe_from_datetime_time(self):
        import datetime

        df = self.spark.createDataFrame(
            [
                (datetime.time(1, 2, 3),),
                (datetime.time(4, 5, 6),),
                (datetime.time(7, 8, 9),),
            ],
            ["t"],
        )
        self.assertIsInstance(df.schema["t"].dataType, TimeType)
        self.assertEqual(df.count(), 3)

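    # Inference from a pandas DataFrame should map datetime columns to
    # TimestampType and date columns to DateType; an explicit DDL schema can
    # also request timestamp_ntz (TimestampNTZType) instead.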
    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
    def test_create_dataframe_from_pandas_with_timestamp(self):
        import pandas as pd
        from datetime import datetime

        pdf = pd.DataFrame(
            {"ts": [datetime(2017, 10, 31, 1, 1, 1)], "d": [pd.Timestamp.now().date()]},
            columns=["d", "ts"],
        )
        # test that types are inferred correctly without specifying a schema
        df = self.spark.createDataFrame(pdf)
        self.assertIsInstance(df.schema["ts"].dataType, TimestampType)
        self.assertIsInstance(df.schema["d"].dataType, DateType)
        # test that an explicit schema also accepts a pandas DataFrame as input
        df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
        self.assertIsInstance(df.schema["ts"].dataType, TimestampType)
        self.assertIsInstance(df.schema["d"].dataType, DateType)
        df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp_ntz")
        self.assertIsInstance(df.schema["ts"].dataType, TimestampNTZType)
        self.assertIsInstance(df.schema["d"].dataType, DateType)

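    # Runs only when pandas is missing: the `import pandas` inside the context
    # manager raises, and the regex also accepts Spark's minimum-version error.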
    @unittest.skipIf(have_pandas, "Required Pandas was found.")
    def test_create_dataframe_required_pandas_not_found(self):
        with self.quiet():
            with self.assertRaisesRegex(
                ImportError, "(Pandas >= .* must be installed|No module named '?pandas'?)"
            ):
                import pandas as pd
                from datetime import datetime

                pdf = pd.DataFrame(
                    {"ts": [datetime(2017, 10, 31, 1, 1, 1)], "d": [pd.Timestamp.now().date()]}
                )
                self.spark.createDataFrame(pdf)

    # Regression test for SPARK-23360
    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
    def test_create_dataframe_from_pandas_with_dst(self):
        import pandas as pd
        from pandas.testing import assert_frame_equal
        from datetime import datetime

        pdf = pd.DataFrame({"time": [datetime(2015, 10, 31, 22, 30)]})

        df = self.spark.createDataFrame(pdf)
        assert_frame_equal(pdf, df.toPandas())

        orig_env_tz = os.environ.get("TZ", None)
        try:
            tz = "America/Los_Angeles"
            os.environ["TZ"] = tz
            time.tzset()
            with self.sql_conf({"spark.sql.session.timeZone": tz}):
                df = self.spark.createDataFrame(pdf)
                assert_frame_equal(pdf, df.toPandas())
        finally:
            del os.environ["TZ"]
            if orig_env_tz is not None:
                os.environ["TZ"] = orig_env_tz
            time.tzset()

    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
    def test_create_dataframe_from_pandas_with_day_time_interval(self):
        # SPARK-37277: Test DayTimeIntervalType in createDataFrame without Arrow.
        import pandas as pd
        from datetime import timedelta

        df = self.spark.createDataFrame(pd.DataFrame({"a": [timedelta(microseconds=123)]}))
        self.assertEqual(df.toPandas().a.iloc[0], timedelta(microseconds=123))

    # test for SPARK-36337
    def test_create_nan_decimal_dataframe(self):
        self.assertEqual(
            self.spark.createDataFrame(data=[Decimal("NaN")], schema="decimal").collect(),
            [Row(value=None)],
        )

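    # For a non-nullable decimal column, Decimal("NaN") cannot be stored as
    # NULL, so createDataFrame should raise instead of silently coercing.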
    def test_check_decimal_nan(self):
        data = [Row(dec=Decimal("NaN"))]
        schema = StructType([StructField("dec", DecimalType(), False)])
        with self.assertRaises(PySparkValueError):
            self.spark.createDataFrame(data=data, schema=schema)

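    # Invalid arguments should surface as PySparkTypeError with a specific
    # error class: an int is not a valid schema, and a DataFrame is not valid data.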
    def test_invalid_argument_create_dataframe(self):
        with self.assertRaises(PySparkTypeError) as pe:
            self.spark.createDataFrame([(1, 2)], schema=123)

        self.check_error(
            exception=pe.exception,
            errorClass="NOT_LIST_OR_NONE_OR_STRUCT",
            messageParameters={"arg_name": "schema", "arg_type": "int"},
        )

        with self.assertRaises(PySparkTypeError) as pe:
            self.spark.createDataFrame(self.spark.range(1))

        self.check_error(
            exception=pe.exception,
            errorClass="INVALID_TYPE",
            messageParameters={"arg_name": "data", "arg_type": "DataFrame"},
        )

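    # A None in the only sample row leaves that column's type undetermined,
    # which should surface as CANNOT_DETERMINE_TYPE.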
    def test_partial_inference_failure(self):
        with self.assertRaises(PySparkValueError) as pe:
            self.spark.createDataFrame([(None, 1)])

        self.check_error(
            exception=pe.exception,
            errorClass="CANNOT_DETERMINE_TYPE",
            messageParameters={},
        )

    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        cast(str, pandas_requirement_message or pyarrow_requirement_message),
    )
    def test_schema_inference_from_pandas_with_dict(self):
        # SPARK-47543: test for verifying that inferring `dict` as `MapType` works properly.
        import pandas as pd

        pdf = pd.DataFrame({"str_col": ["second"], "dict_col": [{"first": 0.7, "second": 0.3}]})

        with self.sql_conf(
            {
                "spark.sql.execution.arrow.pyspark.enabled": True,
                "spark.sql.execution.arrow.pyspark.fallback.enabled": False,
                "spark.sql.execution.pandas.inferPandasDictAsMap": True,
            }
        ):
            sdf = self.spark.createDataFrame(pdf)
            self.assertEqual(
                sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).collect(),
                [Row(str_col="second", dict_col={"first": 0.7, "second": 0.3}, test=0.3)],
            )

            # An empty dict cannot be inferred as a MapType and should fail
            pdf_empty_struct = pd.DataFrame({"str_col": ["second"], "dict_col": [{}]})
            with self.assertRaises(PySparkValueError) as pe:
                self.spark.createDataFrame(pdf_empty_struct)

            self.check_error(
                exception=pe.exception,
                errorClass="CANNOT_INFER_EMPTY_SCHEMA",
                messageParameters={},
            )

            # A dict with mixed value types should fail
            pdf_different_type = pd.DataFrame(
                {"str_col": ["second"], "dict_col": [{"first": 0.7, "second": "0.3"}]}
            )
            self.assertRaises(
                PySparkValueError, lambda: self.spark.createDataFrame(pdf_different_type)
            )

        with self.sql_conf(
            {
                "spark.sql.execution.arrow.pyspark.enabled": False,
                "spark.sql.execution.pandas.inferPandasDictAsMap": True,
            }
        ):
            sdf = self.spark.createDataFrame(pdf)
            self.assertEqual(
                sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).collect(),
                [Row(str_col="second", dict_col={"first": 0.7, "second": 0.3}, test=0.3)],
            )

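    # An explicitly empty StructType should be accepted both for no rows and
    # for a single empty Row.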
    def test_empty_schema(self):
        schema = StructType()
        for data in [[], [Row()]]:
            with self.subTest(data=data):
                sdf = self.spark.createDataFrame(data, schema)
                assertDataFrameEqual(sdf, data)


class DataFrameCreationTests(
    DataFrameCreationTestsMixin,
    ReusedSQLTestCase,
):
    pass


if __name__ == "__main__":
    from pyspark.sql.tests.test_creation import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)