blob: 33984657904e27abd98f6540a4eceae7b249a3c5 [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
from pyspark.errors import AnalysisException
from pyspark.sql import Row
from pyspark.sql import functions as sf
from pyspark.sql.types import IntegerType
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import (
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)
class DataFrameZipTestsMixin:
    """Shared test cases for ``DataFrame.zip()``.

    Only the classic code path is exercised here; Spark Connect raises
    ``NOT_IMPLEMENTED`` for this operation. The host test case must provide
    ``self.spark``.

    Most cases zip two projections of the same base DataFrame and compare the
    resulting columns and (sorted) rows. The ``*_throws`` cases zip projections
    of unrelated bases and expect an ``AnalysisException`` whose condition is
    ``ZIP_PLANS_NOT_MERGEABLE``.
    """

    def _assert_sorted_rows(self, zipped, expected):
        """Collect ``zipped``, sort the rows, and compare them with ``expected``."""
        self.assertEqual(sorted(zipped.collect()), expected)

    @staticmethod
    def _tuples_sorted_by_first(frame):
        """Collect ``frame`` as plain tuples ordered by their first field.

        Used when the schema has duplicate column names, so rows are compared
        positionally rather than as ``Row`` objects built from keywords.
        """
        ordered = sorted(frame.collect(), key=lambda rec: rec[0])
        return [tuple(rec) for rec in ordered]

    def test_zip_select_different_columns(self):
        base = self.spark.createDataFrame([(1, 2, 3), (4, 5, 6), (7, 8, 9)], ["a", "b", "c"])
        result = base.select("a").zip(base.select("b"))
        self.assertEqual(result.columns, ["a", "b"])
        self._assert_sorted_rows(result, [Row(a=1, b=2), Row(a=4, b=5), Row(a=7, b=8)])

    def test_zip_with_expressions(self):
        base = self.spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["a", "b"])
        incremented = base.select((sf.col("a") + 1).alias("a_plus_1"))
        doubled = base.select((sf.col("b") * 2).alias("b_times_2"))
        expected = [
            Row(a_plus_1=2, b_times_2=20),
            Row(a_plus_1=3, b_times_2=40),
            Row(a_plus_1=4, b_times_2=60),
        ]
        self._assert_sorted_rows(incremented.zip(doubled), expected)

    def test_zip_one_side_is_base(self):
        base = self.spark.createDataFrame([(1, 2), (3, 4)], ["a", "b"])
        summed = base.select((sf.col("a") + sf.col("b")).alias("sum"))
        self._assert_sorted_rows(
            base.zip(summed),
            [Row(a=1, b=2, sum=3), Row(a=3, b=4, sum=7)],
        )

    def test_zip_with_python_udf(self):
        base = self.spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["a", "b"])
        increment = sf.udf(lambda value: value + 1, IntegerType())
        from_a = base.select(increment(sf.col("a")).alias("a_plus_1"))
        from_b = base.select(increment(sf.col("b")).alias("b_plus_1"))
        expected = [
            Row(a_plus_1=2, b_plus_1=11),
            Row(a_plus_1=3, b_plus_1=21),
            Row(a_plus_1=4, b_plus_1=31),
        ]
        self._assert_sorted_rows(from_a.zip(from_b), expected)

    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        pandas_requirement_message or pyarrow_requirement_message,
    )
    def test_zip_with_pandas_udf(self):
        import pandas as pd

        # The function name and the pd.Series hints are left as-is because
        # pandas_udf may inspect them.
        @sf.pandas_udf(IntegerType())
        def plus_one(s: pd.Series) -> pd.Series:
            return s + 1

        base = self.spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["a", "b"])
        from_a = base.select(plus_one(sf.col("a")).alias("a_plus_1"))
        from_b = base.select(plus_one(sf.col("b")).alias("b_plus_1"))
        expected = [
            Row(a_plus_1=2, b_plus_1=11),
            Row(a_plus_1=3, b_plus_1=21),
            Row(a_plus_1=4, b_plus_1=31),
        ]
        self._assert_sorted_rows(from_a.zip(from_b), expected)

    def test_zip_different_bases_throws(self):
        two_cols = self.spark.createDataFrame([(1, 2)], ["a", "b"])
        three_cols = self.spark.createDataFrame([(3, 4, 5)], ["x", "y", "z"])
        with self.assertRaises(AnalysisException) as caught:
            # Resolving .schema is expected to raise; the value itself is unused.
            _ = two_cols.select("a").zip(three_cols.select("x")).schema
        self.assertEqual(caught.exception.getCondition(), "ZIP_PLANS_NOT_MERGEABLE")

    def test_zip_different_range_bases_throws(self):
        short_range = self.spark.range(10).toDF("id1")
        long_range = self.spark.range(20).toDF("id2")
        with self.assertRaises(AnalysisException) as caught:
            # Resolving .schema is expected to raise; the value itself is unused.
            _ = short_range.zip(long_range).schema
        self.assertEqual(caught.exception.getCondition(), "ZIP_PLANS_NOT_MERGEABLE")

    def test_zip_with_withColumn(self):
        base = self.spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["a", "b"])
        with_a = base.withColumn("a_plus_1", sf.col("a") + 1)
        with_b = base.withColumn("b_times_2", sf.col("b") * 2)
        combined = with_a.zip(with_b)
        # withColumn keeps every existing column, so "a" and "b" appear on both sides.
        self.assertEqual(combined.columns, ["a", "b", "a_plus_1", "a", "b", "b_times_2"])
        self.assertEqual(
            self._tuples_sorted_by_first(combined),
            [(1, 10, 2, 1, 10, 20), (2, 20, 3, 2, 20, 40), (3, 30, 4, 3, 30, 60)],
        )

    def test_zip_with_withColumnRenamed(self):
        base = self.spark.createDataFrame([(1, 2), (3, 4)], ["a", "b"])
        renamed_a = base.withColumnRenamed("a", "a1")
        renamed_b = base.withColumnRenamed("b", "b1")
        self._assert_sorted_rows(
            renamed_a.zip(renamed_b),
            [Row(a1=1, b=2, a=1, b1=2), Row(a1=3, b=4, a=3, b1=4)],
        )

    def test_zip_chained_withColumn(self):
        # Two stacked withColumn calls on the left side, a single one on the right.
        base = self.spark.createDataFrame([(1, 10), (2, 20)], ["a", "b"])
        with_a = base.withColumn("a_plus_1", sf.col("a") + 1).withColumn(
            "a_plus_2", sf.col("a") + 2
        )
        with_b = base.withColumn("b_times_2", sf.col("b") * 2)
        combined = with_a.zip(with_b)
        expected_columns = ["a", "b", "a_plus_1", "a_plus_2", "a", "b", "b_times_2"]
        self.assertEqual(combined.columns, expected_columns)
        self.assertEqual(
            self._tuples_sorted_by_first(combined),
            [(1, 10, 2, 3, 1, 10, 20), (2, 20, 3, 4, 2, 20, 40)],
        )

    def test_zip_longer_chain(self):
        # Three chained select() projections on the left versus one on the right.
        base = self.spark.createDataFrame([(1, 2, 3), (4, 5, 6)], ["a", "b", "c"])
        narrowed = base.select("a", "b", "c").select("a", "b").select("a")
        only_c = base.select("c")
        self._assert_sorted_rows(narrowed.zip(only_c), [Row(a=1, c=3), Row(a=4, c=6)])

    def test_zip_parent_with_chained_child(self):
        # The unprojected base is zipped with a projection of itself built in two steps.
        base = self.spark.createDataFrame([(1, 2), (3, 4)], ["a", "b"])
        shifted = base.select((sf.col("a") + 1).alias("a_plus_1"))
        child = shifted.select((sf.col("a_plus_1") * 2).alias("doubled"))
        self._assert_sorted_rows(
            base.zip(child),
            [Row(a=1, b=2, doubled=4), Row(a=3, b=4, doubled=8)],
        )
class DataFrameZipTests(DataFrameZipTestsMixin, ReusedSQLTestCase):
    """Concrete test case that runs the DataFrameZipTestsMixin cases on ReusedSQLTestCase."""
if __name__ == "__main__":
    # When this module is run directly as a script, hand off to PySpark's test entry point.
    from pyspark.testing import main as run_tests

    run_tests()