#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import decimal
import datetime
from pyspark.sql.types import (
Row,
StructField,
StructType,
MapType,
NullType,
DateType,
TimeType,
TimestampType,
TimestampNTZType,
ByteType,
BinaryType,
ShortType,
IntegerType,
FloatType,
DayTimeIntervalType,
StringType,
DoubleType,
LongType,
DecimalType,
BooleanType,
)
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
if should_test_connect:
import pandas as pd
from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import DistributedSequenceID, LiteralExpression
from pyspark.util import (
JVM_BYTE_MIN,
JVM_BYTE_MAX,
JVM_SHORT_MIN,
JVM_SHORT_MAX,
JVM_INT_MIN,
JVM_INT_MAX,
JVM_LONG_MIN,
JVM_LONG_MAX,
)
from pyspark.errors.exceptions.connect import SparkConnectException
class SparkConnectColumnTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
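    """Column API parity tests for Spark Connect.

    Most cases build the same expression twice, against a Spark Connect
    session (``self.connect``) and a classic session (``self.spark``),
    and assert that both produce identical results.
    """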
def test_column_operator(self):
# SPARK-41351: Column needs to support !=
df = self.connect.range(10)
self.assertEqual(9, len(df.filter(df.id != CF.lit(1)).collect()))
def test_columns(self):
# SPARK-41036: test `columns` API for python client.
query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
df = self.connect.sql(query)
df2 = self.spark.sql(query)
self.assertEqual(["id", "name"], df.columns)
self.assert_eq(
df.filter(df.name.rlike("20")).toPandas(), df2.filter(df2.name.rlike("20")).toPandas()
)
self.assert_eq(
df.filter(df.name.like("20")).toPandas(), df2.filter(df2.name.like("20")).toPandas()
)
self.assert_eq(
df.filter(df.name.ilike("20")).toPandas(), df2.filter(df2.name.ilike("20")).toPandas()
)
self.assert_eq(
df.filter(df.name.contains("20")).toPandas(),
df2.filter(df2.name.contains("20")).toPandas(),
)
self.assert_eq(
df.filter(df.name.startswith("2")).toPandas(),
df2.filter(df2.name.startswith("2")).toPandas(),
)
self.assert_eq(
df.filter(df.name.endswith("0")).toPandas(),
df2.filter(df2.name.endswith("0")).toPandas(),
)
self.assert_eq(
df.select(df.name.substr(0, 1).alias("col")).toPandas(),
df2.select(df2.name.substr(0, 1).alias("col")).toPandas(),
)
self.assert_eq(
df.select(df.name.substr(0, 1).name("col")).toPandas(),
df2.select(df2.name.substr(0, 1).name("col")).toPandas(),
)
df3 = self.connect.sql("SELECT cast(null as int) as name")
df4 = self.spark.sql("SELECT cast(null as int) as name")
self.assert_eq(
df3.filter(df3.name.isNull()).toPandas(),
df4.filter(df4.name.isNull()).toPandas(),
)
self.assert_eq(
df3.filter(df3.name.isNotNull()).toPandas(),
df4.filter(df4.name.isNotNull()).toPandas(),
)
# check error
with self.assertRaises(PySparkTypeError) as pe:
df.name.substr(df.id, 10)
self.check_error(
exception=pe.exception,
errorClass="NOT_SAME_TYPE",
messageParameters={
"arg_name1": "startPos",
"arg_name2": "length",
"arg_type1": "Column",
"arg_type2": "int",
},
)
with self.assertRaises(PySparkTypeError) as pe:
df.name.substr(10.5, 10.5)
self.check_error(
exception=pe.exception,
errorClass="NOT_COLUMN_OR_INT",
messageParameters={
"arg_name": "startPos",
"arg_type": "float",
},
)
def test_column_with_null(self):
# SPARK-41751: test isNull, isNotNull, eqNullSafe
query = """
SELECT * FROM VALUES
(1, 1, NULL), (2, NULL, NULL), (3, 3, 1)
AS tab(a, b, c)
"""
# +---+----+----+
# | a| b| c|
# +---+----+----+
# | 1| 1|NULL|
# | 2|NULL|NULL|
# | 3| 3| 1|
# +---+----+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
# test isNull
self.assert_eq(
cdf.select(cdf.a.isNull(), cdf["b"].isNull(), CF.col("c").isNull()).toPandas(),
sdf.select(sdf.a.isNull(), sdf["b"].isNull(), SF.col("c").isNull()).toPandas(),
)
# test isNotNull
self.assert_eq(
cdf.select(cdf.a.isNotNull(), cdf["b"].isNotNull(), CF.col("c").isNotNull()).toPandas(),
sdf.select(sdf.a.isNotNull(), sdf["b"].isNotNull(), SF.col("c").isNotNull()).toPandas(),
)
# test eqNullSafe
self.assert_eq(
cdf.select(cdf.a.eqNullSafe(cdf.b), cdf["b"].eqNullSafe(CF.col("c"))).toPandas(),
sdf.select(sdf.a.eqNullSafe(sdf.b), sdf["b"].eqNullSafe(SF.col("c"))).toPandas(),
)
def test_invalid_ops(self):
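        # Python's `in`, boolean (`and`/`or`/`not`), and iteration protocols
        # cannot be overloaded to return Columns, so Column raises
        # descriptive errors instead of silently misbehaving.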
query = """
SELECT * FROM VALUES
(1, 1, 0, NULL), (2, NULL, 1, 2.0), (3, 3, 4, 3.5)
AS tab(a, b, c, d)
"""
cdf = self.connect.sql(query)
with self.assertRaisesRegex(
ValueError,
"Cannot apply 'in' operator against a column",
):
1 in cdf.a
with self.assertRaisesRegex(
ValueError,
"Cannot convert column into bool",
):
cdf.a > 2 and cdf.b < 1
with self.assertRaisesRegex(
ValueError,
"Cannot convert column into bool",
):
cdf.a > 2 or cdf.b < 1
with self.assertRaisesRegex(
ValueError,
"Cannot convert column into bool",
):
not (cdf.a > 2)
with self.assertRaisesRegex(
TypeError,
"Column is not iterable",
):
for x in cdf.a:
pass
def test_datetime(self):
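        # Comparison operators should accept datetime.date and
        # datetime.datetime literals on both timestamp and date columns.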
query = """
SELECT * FROM VALUES
(TIMESTAMP('2022-12-22 15:50:00'), DATE('2022-12-25'), 1.1),
(TIMESTAMP('2022-12-22 18:50:00'), NULL, 2.2),
(TIMESTAMP('2022-12-23 15:50:00'), DATE('2022-12-24'), 3.3),
(NULL, DATE('2022-12-22'), NULL)
AS tab(a, b, c)
"""
# +-------------------+----------+----+
# | a| b| c|
# +-------------------+----------+----+
# |2022-12-22 15:50:00|2022-12-25| 1.1|
# |2022-12-22 18:50:00| NULL| 2.2|
# |2022-12-23 15:50:00|2022-12-24| 3.3|
# | NULL|2022-12-22|NULL|
# +-------------------+----------+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
# datetime.date
self.assert_eq(
cdf.select(cdf.a < datetime.date(2022, 12, 23)).toPandas(),
sdf.select(sdf.a < datetime.date(2022, 12, 23)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.a != datetime.date(2022, 12, 23)).toPandas(),
sdf.select(sdf.a != datetime.date(2022, 12, 23)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.a == datetime.date(2022, 12, 22)).toPandas(),
sdf.select(sdf.a == datetime.date(2022, 12, 22)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.b < datetime.date(2022, 12, 23)).toPandas(),
sdf.select(sdf.b < datetime.date(2022, 12, 23)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.b >= datetime.date(2022, 12, 23)).toPandas(),
sdf.select(sdf.b >= datetime.date(2022, 12, 23)).toPandas(),
)
# datetime.datetime
self.assert_eq(
cdf.select(cdf.a < datetime.datetime(2022, 12, 22, 17, 0, 0)).toPandas(),
sdf.select(sdf.a < datetime.datetime(2022, 12, 22, 17, 0, 0)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.a > datetime.datetime(2022, 12, 22, 17, 0, 0)).toPandas(),
sdf.select(sdf.a > datetime.datetime(2022, 12, 22, 17, 0, 0)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.b >= datetime.datetime(2022, 12, 23, 17, 0, 0)).toPandas(),
sdf.select(sdf.b >= datetime.datetime(2022, 12, 23, 17, 0, 0)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.b < datetime.datetime(2022, 12, 23, 17, 0, 0)).toPandas(),
sdf.select(sdf.b < datetime.datetime(2022, 12, 23, 17, 0, 0)).toPandas(),
)
def test_decimal(self):
# SPARK-41701: test decimal
query = """
SELECT * FROM VALUES
(1, 1, 0, NULL), (2, NULL, 1, 2.0), (3, 3, 4, 3.5)
AS tab(a, b, c, d)
"""
# +---+----+---+----+
# | a| b| c| d|
# +---+----+---+----+
# | 1| 1| 0|NULL|
# | 2|NULL| 1| 2.0|
# | 3| 3| 4| 3.5|
# +---+----+---+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
self.assert_eq(
cdf.select(cdf.a < decimal.Decimal(3)).toPandas(),
sdf.select(sdf.a < decimal.Decimal(3)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.a != decimal.Decimal(2)).toPandas(),
sdf.select(sdf.a != decimal.Decimal(2)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.a == decimal.Decimal(2)).toPandas(),
sdf.select(sdf.a == decimal.Decimal(2)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.b < decimal.Decimal(2.5)).toPandas(),
sdf.select(sdf.b < decimal.Decimal(2.5)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.d >= decimal.Decimal(3.0)).toPandas(),
sdf.select(sdf.d >= decimal.Decimal(3.0)).toPandas(),
)
def test_none(self):
# SPARK-41783: test none
query = """
SELECT * FROM VALUES
(1, 1, NULL), (2, NULL, 1), (NULL, 3, 4)
AS tab(a, b, c)
"""
# +----+----+----+
# | a| b| c|
# +----+----+----+
# | 1| 1|NULL|
# | 2|NULL| 1|
# |NULL| 3| 4|
# +----+----+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
self.assert_eq(
cdf.select(cdf.b > None, CF.col("c") >= None).toPandas(),
sdf.select(sdf.b > None, SF.col("c") >= None).toPandas(),
)
self.assert_eq(
cdf.select(cdf.b < None, CF.col("c") <= None).toPandas(),
sdf.select(sdf.b < None, SF.col("c") <= None).toPandas(),
)
self.assert_eq(
cdf.select(cdf.b.eqNullSafe(None), CF.col("c").eqNullSafe(None)).toPandas(),
sdf.select(sdf.b.eqNullSafe(None), SF.col("c").eqNullSafe(None)).toPandas(),
)
def test_simple_binary_expressions(self):
"""Test complex expression"""
query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
cdf = self.connect.sql(query)
pdf = (
cdf.select(cdf.id).where(cdf.id % CF.lit(30) == CF.lit(0)).sort(cdf.id.asc()).toPandas()
)
self.assertEqual(len(pdf.index), 4)
res = pd.DataFrame(data={"id": [0, 30, 60, 90]})
self.assertTrue(pdf.equals(res), f"{pdf.to_string()} != {res.to_string()}")
def test_literal_with_acceptable_type(self):
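        # Every (value, dataType) pair below is valid; LiteralExpression
        # should accept it and keep the declared data type as-is.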
for value, dataType in [
(b"binary\0\0asas", BinaryType()),
(True, BooleanType()),
(False, BooleanType()),
(0, ByteType()),
(JVM_BYTE_MIN, ByteType()),
(JVM_BYTE_MAX, ByteType()),
(0, ShortType()),
(JVM_SHORT_MIN, ShortType()),
(JVM_SHORT_MAX, ShortType()),
(0, IntegerType()),
(JVM_INT_MIN, IntegerType()),
(JVM_INT_MAX, IntegerType()),
(0, LongType()),
(JVM_LONG_MIN, LongType()),
(JVM_LONG_MAX, LongType()),
(0.0, FloatType()),
(1.234567, FloatType()),
(float("nan"), FloatType()),
(float("inf"), FloatType()),
(float("-inf"), FloatType()),
(0.0, DoubleType()),
(1.234567, DoubleType()),
(float("nan"), DoubleType()),
(float("inf"), DoubleType()),
(float("-inf"), DoubleType()),
(decimal.Decimal(0.0), DecimalType()),
(decimal.Decimal(1.234567), DecimalType()),
("sss", StringType()),
(datetime.date(2022, 12, 13), DateType()),
(datetime.datetime.now(), DateType()),
(datetime.time(1, 0, 0), TimeType()),
(datetime.datetime.now(), TimestampType()),
(datetime.datetime.now(), TimestampNTZType()),
(datetime.timedelta(1, 2, 3), DayTimeIntervalType()),
]:
lit = LiteralExpression(value=value, dataType=dataType)
self.assertEqual(dataType, lit._dataType)
def test_literal_with_unsupported_type(self):
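        # Mismatched types and out-of-range numeric values should be rejected
        # by LiteralExpression's internal assertions.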
for value, dataType in [
(b"binary\0\0asas", BooleanType()),
(True, StringType()),
(False, DoubleType()),
(JVM_BYTE_MIN - 1, ByteType()),
(JVM_BYTE_MAX + 1, ByteType()),
(JVM_SHORT_MIN - 1, ShortType()),
(JVM_SHORT_MAX + 1, ShortType()),
(JVM_INT_MIN - 1, IntegerType()),
(JVM_INT_MAX + 1, IntegerType()),
(JVM_LONG_MIN - 1, LongType()),
(JVM_LONG_MAX + 1, LongType()),
(0.1, DecimalType()),
(datetime.date(2022, 12, 13), TimestampType()),
(datetime.timedelta(1, 2, 3), DateType()),
({1: 2}, MapType(IntegerType(), IntegerType())),
(
{"a": "xyz", "b": 1},
StructType([StructField("a", StringType()), StructField("b", IntegerType())]),
),
]:
with self.assertRaises(AssertionError):
LiteralExpression(value=value, dataType=dataType)
def test_literal_null(self):
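        # None is an acceptable value for every type below, and the declared
        # type should carry through to the DataFrame schema.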
for dataType in [
NullType(),
BinaryType(),
BooleanType(),
ByteType(),
ShortType(),
IntegerType(),
LongType(),
FloatType(),
DoubleType(),
DecimalType(),
DateType(),
TimeType(),
TimestampType(),
TimestampNTZType(),
DayTimeIntervalType(),
]:
lit_null = LiteralExpression(value=None, dataType=dataType)
self.assertTrue(lit_null._value is None)
self.assertEqual(dataType, lit_null._dataType)
cdf = self.connect.range(0, 1).select(Column(lit_null))
self.assertEqual(dataType, cdf.schema.fields[0].dataType)
for value, dataType in [
("123", NullType()),
(123, NullType()),
(None, MapType(IntegerType(), IntegerType())),
(None, StructType([StructField("a", StringType())])),
]:
with self.assertRaises(AssertionError):
LiteralExpression(value=value, dataType=dataType)
def test_literal_integers(self):
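        # Literal integers should get the same inferred type (INT when the
        # value fits, LONG otherwise) in Spark Connect and classic PySpark.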
cdf = self.connect.range(0, 1)
sdf = self.spark.range(0, 1)
cdf1 = cdf.select(
CF.lit(0),
CF.lit(1),
CF.lit(-1),
CF.lit(JVM_INT_MAX),
CF.lit(JVM_INT_MIN),
CF.lit(JVM_INT_MAX + 1),
CF.lit(JVM_INT_MIN - 1),
CF.lit(JVM_LONG_MAX),
CF.lit(JVM_LONG_MIN),
CF.lit(JVM_LONG_MAX - 1),
CF.lit(JVM_LONG_MIN + 1),
)
sdf1 = sdf.select(
SF.lit(0),
SF.lit(1),
SF.lit(-1),
SF.lit(JVM_INT_MAX),
SF.lit(JVM_INT_MIN),
SF.lit(JVM_INT_MAX + 1),
SF.lit(JVM_INT_MIN - 1),
SF.lit(JVM_LONG_MAX),
SF.lit(JVM_LONG_MIN),
SF.lit(JVM_LONG_MAX - 1),
SF.lit(JVM_LONG_MIN + 1),
)
self.assertEqual(cdf1.schema, sdf1.schema)
self.assert_eq(cdf1.toPandas(), sdf1.toPandas())
        # negative tests: integer literals outside the LONG range are rejected
with self.assertRaises(PySparkValueError) as pe:
cdf.select(CF.lit(JVM_LONG_MAX + 1)).show()
self.check_error(
exception=pe.exception,
errorClass="VALUE_NOT_BETWEEN",
messageParameters={"arg_name": "value", "min": "-9223372036854775808", "max": "32767"},
)
with self.assertRaises(PySparkValueError) as pe:
cdf.select(CF.lit(JVM_LONG_MIN - 1)).show()
self.check_error(
exception=pe.exception,
errorClass="VALUE_NOT_BETWEEN",
messageParameters={"arg_name": "value", "min": "-9223372036854775808", "max": "32767"},
)
def test_cast(self):
# SPARK-41412: test basic Column.cast
query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
df = self.connect.sql(query)
df2 = self.spark.sql(query)
self.assert_eq(
df.select(df.id.cast("string")).toPandas(), df2.select(df2.id.cast("string")).toPandas()
)
self.assert_eq(
df.select(df.id.astype("string")).toPandas(),
df2.select(df2.id.astype("string")).toPandas(),
)
for x in [
StringType(),
ShortType(),
IntegerType(),
LongType(),
FloatType(),
DoubleType(),
ByteType(),
DecimalType(10, 2),
BooleanType(),
DayTimeIntervalType(),
]:
self.assert_eq(
df.select(df.id.cast(x)).toPandas(), df2.select(df2.id.cast(x)).toPandas()
)
with self.assertRaises(PySparkTypeError) as pe:
df.id.cast(10)
self.check_error(
exception=pe.exception,
errorClass="NOT_DATATYPE_OR_STR",
messageParameters={"arg_name": "dataType", "arg_type": "int"},
)
def test_isin(self):
# SPARK-41526: test Column.isin
query = """
SELECT * FROM VALUES
(1, 1, 0, NULL), (2, NULL, 1, 2.0), (3, 3, 4, 3.5)
AS tab(a, b, c, d)
"""
# +---+----+---+----+
# | a| b| c| d|
# +---+----+---+----+
# | 1| 1| 0|NULL|
# | 2|NULL| 1| 2.0|
# | 3| 3| 4| 3.5|
# +---+----+---+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
# test literals
self.assert_eq(
cdf.select(cdf.b.isin(1, 2, 3)).toPandas(),
sdf.select(sdf.b.isin(1, 2, 3)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.b.isin([1, 2, 3])).toPandas(),
sdf.select(sdf.b.isin([1, 2, 3])).toPandas(),
)
self.assert_eq(
            cdf.select(cdf.b.isin({1, 2, 3})).toPandas(),
            sdf.select(sdf.b.isin({1, 2, 3})).toPandas(),
)
self.assert_eq(
cdf.select(cdf.d.isin([1.0, None, 3.5])).toPandas(),
sdf.select(sdf.d.isin([1.0, None, 3.5])).toPandas(),
)
# test columns
self.assert_eq(
cdf.select(cdf.a.isin(cdf.b)).toPandas(),
sdf.select(sdf.a.isin(sdf.b)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.a.isin(cdf.b, cdf.c)).toPandas(),
sdf.select(sdf.a.isin(sdf.b, sdf.c)).toPandas(),
)
# test columns mixed with literals
self.assert_eq(
cdf.select(cdf.a.isin(cdf.b, 4, 5, 6)).toPandas(),
sdf.select(sdf.a.isin(sdf.b, 4, 5, 6)).toPandas(),
)
def test_between(self):
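        # between() should accept int, float, Decimal, datetime, and date bounds.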
query = """
SELECT * FROM VALUES
(TIMESTAMP('2022-12-22 15:50:00'), DATE('2022-12-25'), 1.1),
(TIMESTAMP('2022-12-22 18:50:00'), NULL, 2.2),
(TIMESTAMP('2022-12-23 15:50:00'), DATE('2022-12-24'), 3.3),
(NULL, DATE('2022-12-22'), NULL)
AS tab(a, b, c)
"""
# +-------------------+----------+----+
# | a| b| c|
# +-------------------+----------+----+
# |2022-12-22 15:50:00|2022-12-25| 1.1|
# |2022-12-22 18:50:00| NULL| 2.2|
# |2022-12-23 15:50:00|2022-12-24| 3.3|
# | NULL|2022-12-22|NULL|
# +-------------------+----------+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
self.assert_eq(
cdf.select(cdf.c.between(0, 2)).toPandas(),
sdf.select(sdf.c.between(0, 2)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.c.between(1.1, 2.2)).toPandas(),
sdf.select(sdf.c.between(1.1, 2.2)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.c.between(decimal.Decimal(0), decimal.Decimal(2))).toPandas(),
sdf.select(sdf.c.between(decimal.Decimal(0), decimal.Decimal(2))).toPandas(),
)
self.assert_eq(
cdf.select(
cdf.a.between(
datetime.datetime(2022, 12, 22, 17, 0, 0),
datetime.datetime(2022, 12, 23, 6, 0, 0),
)
).toPandas(),
sdf.select(
sdf.a.between(
datetime.datetime(2022, 12, 22, 17, 0, 0),
datetime.datetime(2022, 12, 23, 6, 0, 0),
)
).toPandas(),
)
self.assert_eq(
cdf.select(
cdf.b.between(datetime.date(2022, 12, 23), datetime.date(2022, 12, 24))
).toPandas(),
sdf.select(
sdf.b.between(datetime.date(2022, 12, 23), datetime.date(2022, 12, 24))
).toPandas(),
)
def test_column_bitwise_ops(self):
# SPARK-41751: test bitwiseAND, bitwiseOR, bitwiseXOR
query = """
SELECT * FROM VALUES
(1, 1, 0), (2, NULL, 1), (3, 3, 4)
AS tab(a, b, c)
"""
# +---+----+---+
# | a| b| c|
# +---+----+---+
# | 1| 1| 0|
# | 2|NULL| 1|
# | 3| 3| 4|
# +---+----+---+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
# test bitwiseAND
self.assert_eq(
cdf.select(cdf.a.bitwiseAND(cdf.b), cdf["a"].bitwiseAND(CF.col("c"))).toPandas(),
sdf.select(sdf.a.bitwiseAND(sdf.b), sdf["a"].bitwiseAND(SF.col("c"))).toPandas(),
)
# test bitwiseOR
self.assert_eq(
cdf.select(cdf.a.bitwiseOR(cdf.b), cdf["a"].bitwiseOR(CF.col("c"))).toPandas(),
sdf.select(sdf.a.bitwiseOR(sdf.b), sdf["a"].bitwiseOR(SF.col("c"))).toPandas(),
)
# test bitwiseXOR
self.assert_eq(
cdf.select(cdf.a.bitwiseXOR(cdf.b), cdf["a"].bitwiseXOR(CF.col("c"))).toPandas(),
sdf.select(sdf.a.bitwiseXOR(sdf.b), sdf["a"].bitwiseXOR(SF.col("c"))).toPandas(),
)
def test_column_accessor(self):
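        # Covers dot/bracket field access and getItem/getField on structs and
        # maps, indexing and get() on arrays, and slicing on strings.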
query = """
SELECT STRUCT(a, b, c) AS x, y, z, c FROM VALUES
(float(1.0), double(1.0), '2022', MAP('b', '123', 'a', 'kk'), ARRAY(1, 2, 3)),
(float(2.0), double(2.0), '2018', MAP('a', 'xy'), ARRAY(-1, -2, -3)),
(float(3.0), double(3.0), NULL, MAP('a', 'ab'), ARRAY(-1, 0, 1))
AS tab(a, b, c, y, z)
"""
# +----------------+-------------------+------------+----+
# | x| y| z| c|
# +----------------+-------------------+------------+----+
# |{1.0, 1.0, 2022}|{b -> 123, a -> kk}| [1, 2, 3]|2022|
# |{2.0, 2.0, 2018}| {a -> xy}|[-1, -2, -3]|2018|
# |{3.0, 3.0, null}| {a -> ab}| [-1, 0, 1]|NULL|
# +----------------+-------------------+------------+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
# test struct
self.assert_eq(
cdf.select(cdf.x.a, cdf.x["b"], cdf["x"].c).toPandas(),
sdf.select(sdf.x.a, sdf.x["b"], sdf["x"].c).toPandas(),
)
self.assert_eq(
cdf.select(CF.col("x").a, cdf.x.b, CF.col("x")["c"]).toPandas(),
sdf.select(SF.col("x").a, sdf.x.b, SF.col("x")["c"]).toPandas(),
)
self.assert_eq(
cdf.select(cdf.x.getItem("a"), cdf.x.getItem("b"), cdf["x"].getField("c")).toPandas(),
sdf.select(sdf.x.getItem("a"), sdf.x.getItem("b"), sdf["x"].getField("c")).toPandas(),
)
# test map
self.assert_eq(
cdf.select(cdf.y.a, cdf.y["b"], cdf["y"].c).toPandas(),
sdf.select(sdf.y.a, sdf.y["b"], sdf["y"].c).toPandas(),
)
self.assert_eq(
cdf.select(CF.col("y").a, cdf.y.b, CF.col("y")["c"]).toPandas(),
sdf.select(SF.col("y").a, sdf.y.b, SF.col("y")["c"]).toPandas(),
)
self.assert_eq(
cdf.select(cdf.y.getItem("a"), cdf.y.getItem("b"), cdf["y"].getField("c")).toPandas(),
sdf.select(sdf.y.getItem("a"), sdf.y.getItem("b"), sdf["y"].getField("c")).toPandas(),
)
# test array
self.assert_eq(
cdf.select(cdf.z[0], cdf.z[1], cdf["z"][2]).toPandas(),
sdf.select(sdf.z[0], sdf.z[1], sdf["z"][2]).toPandas(),
)
self.assert_eq(
cdf.select(CF.col("z")[0], CF.get(cdf.z, 10), CF.get(CF.col("z"), -10)).toPandas(),
sdf.select(SF.col("z")[0], SF.get(sdf.z, 10), SF.get(SF.col("z"), -10)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.z.getItem(0), cdf.z.getItem(1), cdf["z"].getField(2)).toPandas(),
sdf.select(sdf.z.getItem(0), sdf.z.getItem(1), sdf["z"].getField(2)).toPandas(),
)
# test string with slice
self.assert_eq(
cdf.select(cdf.c[0:1], cdf["c"][2:10]).toPandas(),
sdf.select(sdf.c[0:1], sdf["c"][2:10]).toPandas(),
)
def test_column_arithmetic_ops(self):
# SPARK-41761: test arithmetic ops
query = """
SELECT * FROM VALUES
(1, 1, 0, NULL), (2, NULL, 1, 2.0), (3, 3, 4, 3.5)
AS tab(a, b, c, d)
"""
# +---+----+---+----+
# | a| b| c| d|
# +---+----+---+----+
# | 1| 1| 0|NULL|
# | 2|NULL| 1| 2.0|
# | 3| 3| 4| 3.5|
# +---+----+---+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
self.assert_eq(
cdf.select(
cdf.a + cdf["b"] - 1, cdf.a - cdf["b"] * cdf["c"] / 2, cdf.d / cdf.b / 3
).toPandas(),
sdf.select(
sdf.a + sdf["b"] - 1, sdf.a - sdf["b"] * sdf["c"] / 2, sdf.d / sdf.b / 3
).toPandas(),
)
self.assert_eq(
cdf.select((-cdf.a)).toPandas(),
sdf.select((-sdf.a)).toPandas(),
)
self.assert_eq(
cdf.select(3 - cdf.a + cdf["b"] * cdf["c"] - cdf.d / cdf.b).toPandas(),
sdf.select(3 - sdf.a + sdf["b"] * sdf["c"] - sdf.d / sdf.b).toPandas(),
)
self.assert_eq(
cdf.select(cdf.a % cdf["b"], cdf["a"] % 2, CF.try_mod(CF.lit(12), cdf.c)).toPandas(),
sdf.select(sdf.a % sdf["b"], sdf["a"] % 2, SF.try_mod(SF.lit(12), sdf.c)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.a ** cdf["b"], cdf.d**2, 2**cdf.c).toPandas(),
sdf.select(sdf.a ** sdf["b"], sdf.d**2, 2**sdf.c).toPandas(),
)
def test_column_field_ops(self):
# SPARK-41767: test withField, dropFields
query = """
SELECT STRUCT(a, b, c, d) AS x, e FROM VALUES
(float(1.0), double(1.0), '2022', 1, 0),
(float(2.0), double(2.0), '2018', NULL, 2),
(float(3.0), double(3.0), NULL, 3, NULL)
AS tab(a, b, c, d, e)
"""
# +----------------------+----+
# | x| e|
# +----------------------+----+
# | {1.0, 1.0, 2022, 1}| 0|
# |{2.0, 2.0, 2018, null}| 2|
# | {3.0, 3.0, null, 3}|NULL|
# +----------------------+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
# add field
self.compare_by_show(
cdf.select(cdf.x.withField("z", cdf.e)),
sdf.select(sdf.x.withField("z", sdf.e)),
truncate=100,
)
self.compare_by_show(
cdf.select(cdf.x.withField("z", CF.col("e"))),
sdf.select(sdf.x.withField("z", SF.col("e"))),
truncate=100,
)
self.compare_by_show(
cdf.select(cdf.x.withField("z", CF.lit("xyz"))),
sdf.select(sdf.x.withField("z", SF.lit("xyz"))),
truncate=100,
)
# replace field
self.compare_by_show(
cdf.select(cdf.x.withField("a", cdf.e)),
sdf.select(sdf.x.withField("a", sdf.e)),
truncate=100,
)
self.compare_by_show(
cdf.select(cdf.x.withField("a", CF.col("e"))),
sdf.select(sdf.x.withField("a", SF.col("e"))),
truncate=100,
)
self.compare_by_show(
cdf.select(cdf.x.withField("a", CF.lit("xyz"))),
sdf.select(sdf.x.withField("a", SF.lit("xyz"))),
truncate=100,
)
# drop field
self.compare_by_show(
cdf.select(cdf.x.dropFields("a")),
sdf.select(sdf.x.dropFields("a")),
truncate=100,
)
self.compare_by_show(
cdf.select(cdf.x.dropFields("z")),
sdf.select(sdf.x.dropFields("z")),
truncate=100,
)
self.compare_by_show(
cdf.select(cdf.x.dropFields("a", "b", "z")),
sdf.select(sdf.x.dropFields("a", "b", "z")),
truncate=100,
)
# check error
# invalid column: not a struct column
with self.assertRaises(SparkConnectException):
cdf.select(cdf.e.withField("a", CF.lit(1))).show()
# invalid column: not a struct column
with self.assertRaises(SparkConnectException):
cdf.select(cdf.e.dropFields("a")).show()
# cannot drop all fields in struct
with self.assertRaises(SparkConnectException):
cdf.select(cdf.x.dropFields("a", "b", "c", "d")).show()
with self.assertRaises(PySparkTypeError) as pe:
cdf.select(cdf.x.withField(CF.col("a"), cdf.e)).show()
self.check_error(
exception=pe.exception,
errorClass="NOT_STR",
messageParameters={"arg_name": "fieldName", "arg_type": "Column"},
)
with self.assertRaises(PySparkTypeError) as pe:
cdf.select(cdf.x.withField("a", 2)).show()
self.check_error(
exception=pe.exception,
errorClass="NOT_COLUMN",
messageParameters={"arg_name": "col", "arg_type": "int"},
)
with self.assertRaises(PySparkTypeError) as pe:
cdf.select(cdf.x.dropFields("a", 1, 2)).show()
self.check_error(
exception=pe.exception,
errorClass="NOT_STR",
messageParameters={"arg_name": "fieldName", "arg_type": "int"},
)
with self.assertRaises(PySparkValueError) as pe:
cdf.select(cdf.x.dropFields()).show()
self.check_error(
exception=pe.exception,
errorClass="CANNOT_BE_EMPTY",
messageParameters={"item": "dropFields"},
)
def test_column_string_ops(self):
# SPARK-41764: test string ops
query = """
SELECT * FROM VALUES
(1, 'abcdef', 'ghij', 'hello world', 'a'),
(2, 'abcd', 'efghij', 'how are you', 'd')
AS tab(a, b, c, d, e)
"""
# +---+------+------+-----------+---+
# | a| b| c| d| e|
# +---+------+------+-----------+---+
# | 1|abcdef| ghij|hello world| a|
# | 2| abcd|efghij|how are you| d|
# +---+------+------+-----------+---+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
self.assert_eq(
cdf.select(
cdf.b.startswith("a"), cdf["c"].startswith("g"), cdf["b"].startswith(cdf.e)
).toPandas(),
sdf.select(
sdf.b.startswith("a"), sdf["c"].startswith("g"), sdf["b"].startswith(sdf.e)
).toPandas(),
)
self.assert_eq(
cdf.select(
cdf.b.endswith("a"), cdf["c"].endswith("j"), cdf["b"].endswith(cdf.e)
).toPandas(),
sdf.select(
sdf.b.endswith("a"), sdf["c"].endswith("j"), sdf["b"].endswith(sdf.e)
).toPandas(),
)
self.assert_eq(
cdf.select(
cdf.b.contains("a"), cdf["c"].contains("j"), cdf["b"].contains(cdf.e)
).toPandas(),
sdf.select(
sdf.b.contains("a"), sdf["c"].contains("j"), sdf["b"].contains(sdf.e)
).toPandas(),
)
def test_with_field_column_name(self):
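        # Replacing a nested field with withField must keep "a.b" resolvable
        # and update its value.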
data = [Row(a=Row(b=1, c=2))]
cdf = self.connect.createDataFrame(data)
cdf1 = cdf.withColumn("a", cdf["a"].withField("b", CF.lit(3))).select("a.b")
sdf = self.spark.createDataFrame(data)
sdf1 = sdf.withColumn("a", sdf["a"].withField("b", SF.lit(3))).select("a.b")
self.assertEqual(cdf1.schema, sdf1.schema)
self.assertEqual(cdf1.collect(), sdf1.collect())
def test_distributed_sequence_id(self):
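        # DistributedSequenceID should number rows consecutively from 0,
        # matching a plain 0-9 index column here.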
cdf = self.connect.range(10)
expected = self.connect.range(0, 10).selectExpr("id as index", "id")
self.assertEqual(
cdf.select(Column(DistributedSequenceID()).alias("index"), "*").collect(),
expected.collect(),
)
def test_lambda_str_representation(self):
from pyspark.sql.connect.expressions import UnresolvedNamedLambdaVariable
        # Forcibly reset the internal increasing id, otherwise the string
        # representation varies with it from run to run.
UnresolvedNamedLambdaVariable._nextVarNameId = 0
c = CF.array_sort(
"data",
lambda x, y: CF.when(x.isNull() | y.isNull(), CF.lit(0)).otherwise(
CF.length(y) - CF.length(x)
),
)
self.assertEqual(
str(c),
(
"""Column<'array_sort(data, LambdaFunction(CASE WHEN or(isNull(x_0), """
"""isNull(y_1)) THEN 0 ELSE -(length(y_1), length(x_0)) END, x_0, y_1))'>"""
),
)
def test_cast_default_column_name(self):
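        # Casting literals should produce the same auto-generated column
        # names (e.g. "CAST(123 AS STRING)") as classic PySpark.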
cdf = self.connect.range(1).select(
CF.lit(b"123").cast("STRING"),
CF.lit(123).cast("STRING"),
CF.lit(123).cast("LONG"),
CF.lit(123).cast("DOUBLE"),
)
sdf = self.spark.range(1).select(
SF.lit(b"123").cast("STRING"),
SF.lit(123).cast("STRING"),
SF.lit(123).cast("LONG"),
SF.lit(123).cast("DOUBLE"),
)
self.assertEqual(cdf.columns, sdf.columns)
if __name__ == "__main__":
import unittest
from pyspark.sql.tests.connect.test_connect_column import * # noqa: F401
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)