#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import array
import datetime
import unittest
import random
import string
from pyspark.errors import PySparkValueError
from pyspark.sql.types import (
StructType,
StructField,
StringType,
IntegerType,
LongType,
MapType,
ArrayType,
Row,
)
from pyspark.testing.objects import MyObject, PythonOnlyUDT
from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
if should_test_connect:
import pandas as pd
import numpy as np
from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF
from pyspark.errors.exceptions.connect import ParseException
class SparkConnectCreationTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
def test_with_local_data(self):
"""SPARK-41114: Test creating a dataframe using local data"""
pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
df = self.connect.createDataFrame(pdf)
rows = df.filter(df.a == CF.lit(3)).collect()
        self.assertEqual(len(rows), 1)
self.assertEqual(rows[0][0], 3)
self.assertEqual(rows[0][1], "c")
# Check correct behavior for empty DataFrame
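        # An all-empty column gives schema inference nothing to work with, so this should raise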
pdf = pd.DataFrame({"a": []})
with self.assertRaises(ValueError):
self.connect.createDataFrame(pdf)
def test_with_local_ndarray(self):
"""SPARK-41446: Test creating a dataframe using local list"""
data = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
sdf = self.spark.createDataFrame(data)
cdf = self.connect.createDataFrame(data)
self.assertEqual(sdf.schema, cdf.schema)
self.assert_eq(sdf.toPandas(), cdf.toPandas())
for schema in [
StructType(
[
StructField("col1", IntegerType(), True),
StructField("col2", IntegerType(), True),
StructField("col3", IntegerType(), True),
StructField("col4", IntegerType(), True),
]
),
"struct<col1 int, col2 int, col3 int, col4 int>",
"col1 int, col2 int, col3 int, col4 int",
"col1 int, col2 long, col3 string, col4 long",
"col1 int, col2 string, col3 short, col4 long",
["a", "b", "c", "d"],
("x1", "x2", "x3", "x4"),
]:
with self.subTest(schema=schema):
sdf = self.spark.createDataFrame(data, schema=schema)
cdf = self.connect.createDataFrame(data, schema=schema)
self.assertEqual(sdf.schema, cdf.schema)
self.assert_eq(sdf.toPandas(), cdf.toPandas())
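        # Mismatched column-name counts should fail with AXIS_LENGTH_MISMATCH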
with self.assertRaises(PySparkValueError) as pe:
self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
self.check_error(
exception=pe.exception,
errorClass="AXIS_LENGTH_MISMATCH",
messageParameters={"expected_length": "5", "actual_length": "4"},
)
with self.assertRaises(ParseException):
self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
with self.assertRaises(PySparkValueError) as pe:
self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
self.check_error(
exception=pe.exception,
errorClass="AXIS_LENGTH_MISMATCH",
messageParameters={"expected_length": "3", "actual_length": "4"},
)
# test 1 dim ndarray
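        # (a 1-D float ndarray should map to a single double column, with NaN preserved)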
data = np.array([1.0, 2.0, np.nan, 3.0, 4.0, float("NaN"), 5.0])
self.assertEqual(data.ndim, 1)
sdf = self.spark.createDataFrame(data)
cdf = self.connect.createDataFrame(data)
self.assertEqual(sdf.schema, cdf.schema)
self.assert_eq(sdf.toPandas(), cdf.toPandas())
def test_with_local_list(self):
"""SPARK-41446: Test creating a dataframe using local list"""
data = [[1, 2, 3, 4]]
sdf = self.spark.createDataFrame(data)
cdf = self.connect.createDataFrame(data)
self.assertEqual(sdf.schema, cdf.schema)
self.assert_eq(sdf.toPandas(), cdf.toPandas())
for schema in [
"struct<col1 int, col2 int, col3 int, col4 int>",
"col1 int, col2 int, col3 int, col4 int",
"col1 int, col2 long, col3 string, col4 long",
"col1 int, col2 string, col3 short, col4 long",
["a", "b", "c", "d"],
("x1", "x2", "x3", "x4"),
]:
sdf = self.spark.createDataFrame(data, schema=schema)
cdf = self.connect.createDataFrame(data, schema=schema)
self.assertEqual(sdf.schema, cdf.schema)
self.assert_eq(sdf.toPandas(), cdf.toPandas())
with self.assertRaises(PySparkValueError) as pe:
self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
self.check_error(
exception=pe.exception,
errorClass="AXIS_LENGTH_MISMATCH",
messageParameters={"expected_length": "5", "actual_length": "4"},
)
with self.assertRaises(ParseException):
self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
with self.assertRaises(PySparkValueError) as pe:
self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
self.check_error(
exception=pe.exception,
errorClass="AXIS_LENGTH_MISMATCH",
messageParameters={"expected_length": "3", "actual_length": "4"},
)
def test_with_local_rows(self):
        # SPARK-41789, SPARK-41810: Test creating a dataframe with a list of rows and of dictionaries
rows = [
Row(course="dotNET", year=2012, earnings=10000),
Row(course="Java", year=2012, earnings=20000),
Row(course="dotNET", year=2012, earnings=5000),
Row(course="dotNET", year=2013, earnings=48000),
Row(course="Java", year=2013, earnings=30000),
Row(course="Scala", year=2022, earnings=None),
]
dicts = [row.asDict() for row in rows]
for data in [rows, dicts]:
sdf = self.spark.createDataFrame(data)
cdf = self.connect.createDataFrame(data)
self.assertEqual(sdf.schema, cdf.schema)
self.assert_eq(sdf.toPandas(), cdf.toPandas())
# test with rename
sdf = self.spark.createDataFrame(data, schema=["a", "b", "c"])
cdf = self.connect.createDataFrame(data, schema=["a", "b", "c"])
self.assertEqual(sdf.schema, cdf.schema)
self.assert_eq(sdf.toPandas(), cdf.toPandas())
def test_streaming_local_relation(self):
threshold_conf = "spark.sql.session.localRelationCacheThreshold"
old_threshold = self.connect.conf.get(threshold_conf)
threshold = 1024 * 1024
self.connect.conf.set(threshold_conf, threshold)
try:
suffix = "abcdef"
letters = string.ascii_lowercase
str = "".join(random.choice(letters) for i in range(threshold)) + suffix
data = [[0, str], [1, str]]
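            # The payload exceeds the cache threshold, so the second iteration
            # should reuse the server-side cached local relation instead of
            # re-uploading the data.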
            for _ in range(2):
cdf = self.connect.createDataFrame(data, ["a", "b"])
self.assert_eq(cdf.count(), len(data))
self.assert_eq(cdf.filter(f"endsWith(b, '{suffix}')").isEmpty(), False)
finally:
self.connect.conf.set(threshold_conf, old_threshold)
def test_with_atom_type(self):
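        # Note: (1) is just the int 1, not a tuple, so both datasets are equivalent lists of ints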
for data in [[(1), (2), (3)], [1, 2, 3]]:
for schema in ["long", "int", "short"]:
sdf = self.spark.createDataFrame(data, schema=schema)
cdf = self.connect.createDataFrame(data, schema=schema)
self.assertEqual(sdf.schema, cdf.schema)
self.assert_eq(sdf.toPandas(), cdf.toPandas())
def test_with_none_and_nan(self):
# SPARK-41855: make createDataFrame support None and NaN
# SPARK-41814: test with eqNullSafe
data1 = [Row(id=1, value=float("NaN")), Row(id=2, value=42.0), Row(id=3, value=None)]
data2 = [Row(id=1, value=np.nan), Row(id=2, value=42.0), Row(id=3, value=None)]
data3 = [
{"id": 1, "value": float("NaN")},
{"id": 2, "value": 42.0},
{"id": 3, "value": None},
]
data4 = [{"id": 1, "value": np.nan}, {"id": 2, "value": 42.0}, {"id": 3, "value": None}]
data5 = [(1, float("NaN")), (2, 42.0), (3, None)]
data6 = [(1, np.nan), (2, 42.0), (3, None)]
data7 = np.array([[1, float("NaN")], [2, 42.0], [3, None]])
data8 = np.array([[1, np.nan], [2, 42.0], [3, None]])
# +---+-----+
# | id|value|
# +---+-----+
# | 1| NaN|
# | 2| 42.0|
# | 3| NULL|
# +---+-----+
for data in [data1, data2, data3, data4, data5, data6, data7, data8]:
if isinstance(data[0], (Row, dict)):
# data1, data2, data3, data4
cdf = self.connect.createDataFrame(data)
sdf = self.spark.createDataFrame(data)
else:
# data5, data6, data7, data8
cdf = self.connect.createDataFrame(data, schema=["id", "value"])
sdf = self.spark.createDataFrame(data, schema=["id", "value"])
self.assert_eq(cdf.toPandas(), sdf.toPandas())
self.assert_eq(
cdf.select(
cdf["value"].eqNullSafe(None),
cdf["value"].eqNullSafe(float("NaN")),
cdf["value"].eqNullSafe(42.0),
).toPandas(),
sdf.select(
sdf["value"].eqNullSafe(None),
sdf["value"].eqNullSafe(float("NaN")),
sdf["value"].eqNullSafe(42.0),
).toPandas(),
)
# SPARK-41851: test with nanvl
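        # nanvl(a, b) returns a if a is not NaN, otherwise b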
data = [(1.0, float("nan")), (float("nan"), 2.0)]
cdf = self.connect.createDataFrame(data, ("a", "b"))
sdf = self.spark.createDataFrame(data, ("a", "b"))
self.assert_eq(cdf.toPandas(), sdf.toPandas())
self.assert_eq(
cdf.select(
CF.nanvl("a", "b").alias("r1"), CF.nanvl(cdf.a, cdf.b).alias("r2")
).toPandas(),
sdf.select(
SF.nanvl("a", "b").alias("r1"), SF.nanvl(sdf.a, sdf.b).alias("r2")
).toPandas(),
)
# SPARK-41852: test with pmod
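        # pmod(a, b) returns the positive remainder of a / b, e.g. pmod(-3.0, 4.0) == 1.0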
data = [
(1.0, float("nan")),
(float("nan"), 2.0),
(10.0, 3.0),
(float("nan"), float("nan")),
(-3.0, 4.0),
(-10.0, 3.0),
(-5.0, -6.0),
(7.0, -8.0),
(1.0, 2.0),
]
cdf = self.connect.createDataFrame(data, ("a", "b"))
sdf = self.spark.createDataFrame(data, ("a", "b"))
self.assert_eq(cdf.toPandas(), sdf.toPandas())
self.assert_eq(
cdf.select(CF.pmod("a", "b")).toPandas(),
sdf.select(SF.pmod("a", "b")).toPandas(),
)
def test_cast_with_ddl(self):
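        # Schemas can be given as DDL-formatted strings; both sessions must parse them identically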
data = [Row(date=datetime.date(2021, 12, 27), add=2)]
cdf = self.connect.createDataFrame(data, "date date, add integer")
sdf = self.spark.createDataFrame(data, "date date, add integer")
self.assertEqual(cdf.schema, sdf.schema)
def test_create_empty_df(self):
for schema in [
"STRING",
"x STRING",
"x STRING, y INTEGER",
StringType(),
StructType(
[
StructField("x", StringType(), True),
StructField("y", IntegerType(), True),
]
),
]:
cdf = self.connect.createDataFrame(data=[], schema=schema)
sdf = self.spark.createDataFrame(data=[], schema=schema)
self.assert_eq(cdf.toPandas(), sdf.toPandas())
# check error
with self.assertRaises(PySparkValueError) as pe:
self.connect.createDataFrame(data=[])
self.check_error(
exception=pe.exception,
errorClass="CANNOT_INFER_EMPTY_SCHEMA",
messageParameters={},
)
def test_create_dataframe_from_arrays(self):
        # SPARK-42021: createDataFrame supports array.array
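        # (typecode "i" should map to array<int> and "d" to array<double>)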
data1 = [Row(a=1, b=array.array("i", [1, 2, 3]), c=array.array("d", [4, 5, 6]))]
data2 = [(array.array("d", [1, 2, 3]), 2, "3")]
data3 = [{"a": 1, "b": array.array("i", [1, 2, 3])}]
for data in [data1, data2, data3]:
cdf = self.connect.createDataFrame(data)
sdf = self.spark.createDataFrame(data)
self.assertEqual(cdf.schema, sdf.schema)
self.assertEqual(cdf.collect(), sdf.collect())
    def test_timestamp_create_from_rows(self):
data = [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)]
cdf = self.connect.createDataFrame(data, ["date", "val"])
sdf = self.spark.createDataFrame(data, ["date", "val"])
self.assertEqual(cdf.schema, sdf.schema)
self.assertEqual(cdf.collect(), sdf.collect())
def test_create_dataframe_with_coercion(self):
data1 = [[1.33, 1], ["2.1", 1]]
data2 = [[True, 1], ["false", 1]]
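        # Mixed-type columns exercise inference-time coercion; both sessions must agree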
for data in [data1, data2]:
cdf = self.connect.createDataFrame(data, ["a", "b"])
sdf = self.spark.createDataFrame(data, ["a", "b"])
self.assertEqual(cdf.schema, sdf.schema)
self.assertEqual(cdf.collect(), sdf.collect())
def test_nested_type_create_from_rows(self):
data1 = [Row(a=1, b=Row(c=2, d=Row(e=3, f=Row(g=4, h=Row(i=5)))))]
# root
# |-- a: long (nullable = true)
# |-- b: struct (nullable = true)
# | |-- c: long (nullable = true)
# | |-- d: struct (nullable = true)
# | | |-- e: long (nullable = true)
# | | |-- f: struct (nullable = true)
# | | | |-- g: long (nullable = true)
# | | | |-- h: struct (nullable = true)
# | | | | |-- i: long (nullable = true)
data2 = [
(
1,
"a",
Row(
a=1,
b=[1, 2, 3],
c={"a": "b"},
d=Row(x=1, y="y", z=Row(o=1, p=2, q=Row(g=1.5))),
),
)
]
# root
# |-- _1: long (nullable = true)
# |-- _2: string (nullable = true)
# |-- _3: struct (nullable = true)
# | |-- a: long (nullable = true)
# | |-- b: array (nullable = true)
# | | |-- element: long (containsNull = true)
# | |-- c: map (nullable = true)
# | | |-- key: string
# | | |-- value: string (valueContainsNull = true)
# | |-- d: struct (nullable = true)
# | | |-- x: long (nullable = true)
# | | |-- y: string (nullable = true)
# | | |-- z: struct (nullable = true)
# | | | |-- o: long (nullable = true)
# | | | |-- p: long (nullable = true)
# | | | |-- q: struct (nullable = true)
# | | | | |-- g: double (nullable = true)
data3 = [
Row(
a=1,
b=[1, 2, 3],
c={"a": "b"},
d=Row(x=1, y="y", z=Row(1, 2, 3)),
e=list("hello connect"),
)
]
# root
# |-- a: long (nullable = true)
# |-- b: array (nullable = true)
# | |-- element: long (containsNull = true)
# |-- c: map (nullable = true)
# | |-- key: string
# | |-- value: string (valueContainsNull = true)
# |-- d: struct (nullable = true)
# | |-- x: long (nullable = true)
# | |-- y: string (nullable = true)
# | |-- z: struct (nullable = true)
# | | |-- _1: long (nullable = true)
# | | |-- _2: long (nullable = true)
# | | |-- _3: long (nullable = true)
# |-- e: array (nullable = true)
# | |-- element: string (containsNull = true)
data4 = [
{
"a": 1,
"b": Row(x=1, y=Row(z=2)),
"c": {"x": -1, "y": 2},
"d": [1, 2, 3, 4, 5],
}
]
# root
# |-- a: long (nullable = true)
# |-- b: struct (nullable = true)
# | |-- x: long (nullable = true)
# | |-- y: struct (nullable = true)
# | | |-- z: long (nullable = true)
# |-- c: map (nullable = true)
# | |-- key: string
# | |-- value: long (valueContainsNull = true)
# |-- d: array (nullable = true)
# | |-- element: long (containsNull = true)
data5 = [
{
"a": [Row(x=1, y="2"), Row(x=-1, y="-2")],
"b": [[1, 2, 3], [4, 5], [6]],
"c": {3: {4: {5: 6}}, 7: {8: {9: 0}}},
}
]
# root
# |-- a: array (nullable = true)
# | |-- element: struct (containsNull = true)
# | | |-- x: long (nullable = true)
# | | |-- y: string (nullable = true)
# |-- b: array (nullable = true)
# | |-- element: array (containsNull = true)
# | | |-- element: long (containsNull = true)
# |-- c: map (nullable = true)
# | |-- key: long
# | |-- value: map (valueContainsNull = true)
# | | |-- key: long
# | | |-- value: map (valueContainsNull = true)
# | | | |-- key: long
# | | | |-- value: long (valueContainsNull = true)
for data in [data1, data2, data3, data4, data5]:
with self.subTest(data=data):
cdf = self.connect.createDataFrame(data)
sdf = self.spark.createDataFrame(data)
self.assertEqual(cdf.schema, sdf.schema)
self.assertEqual(cdf.collect(), sdf.collect())
def test_create_df_from_objects(self):
data = [MyObject(1, "1"), MyObject(2, "2")]
# +---+-----+
# |key|value|
# +---+-----+
# | 1| 1|
# | 2| 2|
# +---+-----+
cdf = self.connect.createDataFrame(data)
sdf = self.spark.createDataFrame(data)
self.assertEqual(cdf.schema, sdf.schema)
self.assertEqual(cdf.collect(), sdf.collect())
def test_create_df_nullability(self):
data = [("asd", None)]
schema = StructType(
[
StructField("name", StringType(), nullable=True),
StructField("age", IntegerType(), nullable=False),
]
)
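        # "age" is declared non-nullable, so the None value must be rejected at creation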
with self.assertRaises(PySparkValueError):
self.spark.createDataFrame(data, schema)
def test_create_dataframe_from_pandas_with_ns_timestamp(self):
"""Truncate the timestamps for nanoseconds."""
from datetime import datetime, timezone, timedelta
from pandas import Timestamp
import pandas as pd
# Nanoseconds are truncated to microseconds in the serializer
# Arrow will throw an error if precision is lost
# (i.e., nanoseconds cannot be represented in microseconds)
pdf = pd.DataFrame(
{
"naive": [datetime(2019, 1, 1, 0)],
"aware": [
Timestamp(
year=2019, month=1, day=1, nanosecond=0, tz=timezone(timedelta(hours=-8))
)
],
}
)
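        # Both the Arrow and non-Arrow conversion paths should produce the same rows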
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
self.assertEqual(
self.connect.createDataFrame(pdf).collect(),
self.spark.createDataFrame(pdf).collect(),
)
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
self.assertEqual(
self.connect.createDataFrame(pdf).collect(),
self.spark.createDataFrame(pdf).collect(),
)
def test_schema_has_nullable(self):
schema_false = StructType().add("id", IntegerType(), False)
cdf1 = self.connect.createDataFrame([[1]], schema=schema_false)
sdf1 = self.spark.createDataFrame([[1]], schema=schema_false)
self.assertEqual(cdf1.schema, sdf1.schema)
self.assertEqual(cdf1.collect(), sdf1.collect())
schema_true = StructType().add("id", IntegerType(), True)
cdf2 = self.connect.createDataFrame([[1]], schema=schema_true)
sdf2 = self.spark.createDataFrame([[1]], schema=schema_true)
self.assertEqual(cdf2.schema, sdf2.schema)
self.assertEqual(cdf2.collect(), sdf2.collect())
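        # Round-trip through pandas: the explicit schema should preserve nullability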
pdf1 = cdf1.toPandas()
cdf3 = self.connect.createDataFrame(pdf1, cdf1.schema)
sdf3 = self.spark.createDataFrame(pdf1, sdf1.schema)
self.assertEqual(cdf3.schema, sdf3.schema)
self.assertEqual(cdf3.collect(), sdf3.collect())
pdf2 = cdf2.toPandas()
cdf4 = self.connect.createDataFrame(pdf2, cdf2.schema)
sdf4 = self.spark.createDataFrame(pdf2, sdf2.schema)
self.assertEqual(cdf4.schema, sdf4.schema)
self.assertEqual(cdf4.collect(), sdf4.collect())
def test_array_has_nullable(self):
for schemas, data in [
(
[StructType().add("arr", ArrayType(IntegerType(), False), True)],
[Row([1, 2]), Row([3]), Row(None)],
),
(
[
StructType().add("arr", ArrayType(IntegerType(), True), True),
"arr array<integer>",
],
[Row([1, None]), Row([3]), Row(None)],
),
(
[StructType().add("arr", ArrayType(IntegerType(), False), False)],
[Row([1, 2]), Row([3])],
),
(
[
StructType().add("arr", ArrayType(IntegerType(), True), False),
"arr array<integer> not null",
],
[Row([1, None]), Row([3])],
),
]:
for schema in schemas:
with self.subTest(schema=schema):
cdf = self.connect.createDataFrame(data, schema=schema)
sdf = self.spark.createDataFrame(data, schema=schema)
self.assertEqual(cdf.schema, sdf.schema)
self.assertEqual(cdf.collect(), sdf.collect())
def test_map_has_nullable(self):
for schemas, data in [
(
[StructType().add("map", MapType(StringType(), IntegerType(), False), True)],
[Row({"a": 1, "b": 2}), Row({"a": 3}), Row(None)],
),
(
[
StructType().add("map", MapType(StringType(), IntegerType(), True), True),
"map map<string, integer>",
],
[Row({"a": 1, "b": None}), Row({"a": 3}), Row(None)],
),
(
[StructType().add("map", MapType(StringType(), IntegerType(), False), False)],
[Row({"a": 1, "b": 2}), Row({"a": 3})],
),
(
[
StructType().add("map", MapType(StringType(), IntegerType(), True), False),
"map map<string, integer> not null",
],
[Row({"a": 1, "b": None}), Row({"a": 3})],
),
]:
for schema in schemas:
with self.subTest(schema=schema):
cdf = self.connect.createDataFrame(data, schema=schema)
sdf = self.spark.createDataFrame(data, schema=schema)
self.assertEqual(cdf.schema, sdf.schema)
self.assertEqual(cdf.collect(), sdf.collect())
def test_struct_has_nullable(self):
for schemas, data in [
(
[
StructType().add("struct", StructType().add("i", IntegerType(), False), True),
"struct struct<i: integer not null>",
],
[Row(Row(1)), Row(Row(2)), Row(None)],
),
(
[
StructType().add("struct", StructType().add("i", IntegerType(), True), True),
"struct struct<i: integer>",
],
[Row(Row(1)), Row(Row(2)), Row(Row(None)), Row(None)],
),
(
[
StructType().add("struct", StructType().add("i", IntegerType(), False), False),
"struct struct<i: integer not null> not null",
],
[Row(Row(1)), Row(Row(2))],
),
(
[
StructType().add("struct", StructType().add("i", IntegerType(), True), False),
"struct struct<i: integer> not null",
],
[Row(Row(1)), Row(Row(2)), Row(Row(None))],
),
]:
for schema in schemas:
with self.subTest(schema=schema):
cdf = self.connect.createDataFrame(data, schema=schema)
sdf = self.spark.createDataFrame(data, schema=schema)
self.assertEqual(cdf.schema, sdf.schema)
self.assertEqual(cdf.collect(), sdf.collect())
def test_large_client_data(self):
        # SPARK-42816: support more than 4MB message size.
        # 10 strings of 25 characters each: ~250 bytes per row
        cols = ["abcdefghijklmnoprstuvwxyz" for _ in range(10)]
        # 100k rows => ~25MB, well above the default 4MB gRPC message size
row_count = 100 * 1000
rows = [cols] * row_count
self.assertEqual(row_count, self.connect.createDataFrame(data=rows).count())
def test_simple_udt(self):
from pyspark.ml.linalg import MatrixUDT, VectorUDT
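        # With no rows, only schema agreement between the connect and classic sessions is checked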
for schema in [
StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
StructType().add("key", LongType()).add("val", MapType(LongType(), PythonOnlyUDT())),
StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
StructType().add("key", LongType()).add("vec", VectorUDT()),
StructType().add("key", LongType()).add("mat", MatrixUDT()),
]:
cdf = self.connect.createDataFrame(data=[], schema=schema)
sdf = self.spark.createDataFrame(data=[], schema=schema)
self.assertEqual(cdf.schema, sdf.schema)
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_creation import * # noqa: F401
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)