#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
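"""
Tests for the unified ``@udf`` decorator, which infers the eval type of a
user-defined function (regular Python, Arrow-optimized, pandas scalar /
iterator / grouped-agg, or Arrow scalar / iterator / grouped-agg UDF) from
its type hints.
"""
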
import unittest
from typing import Iterator, Tuple

from pyspark.sql import functions as sf
from pyspark.sql.window import Window
from pyspark.sql.functions import udf
from pyspark.sql.types import LongType
from pyspark.testing.utils import (
    have_pandas,
    have_pyarrow,
    pandas_requirement_message,
    pyarrow_requirement_message,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.util import PythonEvalType


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message,
)
class UnifiedUDFTestsMixin:
    def test_scalar_pandas_udf(self):
        import pandas as pd

        @udf(returnType=LongType())
        def pd_add1(ser: pd.Series) -> pd.Series:
            assert isinstance(ser, pd.Series)
            return ser + 1

        self.assertEqual(pd_add1.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
        df = self.spark.range(0, 10)
        expected = df.select((df.id + 1).alias("res")).collect()
        result1 = df.select(pd_add1("id").alias("res")).collect()
        self.assertEqual(result1, expected)

        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add1")
        self.spark.udf.register("pd_add1", pd_add1)
        result2 = self.spark.sql("SELECT pd_add1(id) AS res FROM range(0, 10)").collect()
        self.assertEqual(result2, expected)
        self.spark.sql("DROP TEMPORARY FUNCTION pd_add1")

    def test_scalar_pandas_udf_II(self):
        import pandas as pd

        @udf(returnType=LongType())
        def pd_add(ser1: pd.Series, ser2: pd.Series) -> pd.Series:
            assert isinstance(ser1, pd.Series)
            assert isinstance(ser2, pd.Series)
            return ser1 + ser2

        self.assertEqual(pd_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
        df = self.spark.range(0, 10)
        expected = df.select((df.id + df.id).alias("res")).collect()
        result1 = df.select(pd_add("id", "id").alias("res")).collect()
        self.assertEqual(result1, expected)

        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add")
        self.spark.udf.register("pd_add", pd_add)
        result2 = self.spark.sql("SELECT pd_add(id, id) AS res FROM range(0, 10)").collect()
        self.assertEqual(result2, expected)
        self.spark.sql("DROP TEMPORARY FUNCTION pd_add")

    def test_scalar_pandas_iter_udf(self):
        import pandas as pd

        @udf(returnType=LongType())
        def pd_add1_iter(it: Iterator[pd.Series]) -> Iterator[pd.Series]:
            for ser in it:
                assert isinstance(ser, pd.Series)
                yield ser + 1

        self.assertEqual(pd_add1_iter.evalType, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
        df = self.spark.range(0, 10)
        expected = df.select((df.id + 1).alias("res")).collect()
        result1 = df.select(pd_add1_iter("id").alias("res")).collect()
        self.assertEqual(result1, expected)

        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add1_iter")
        self.spark.udf.register("pd_add1_iter", pd_add1_iter)
        result2 = self.spark.sql("SELECT pd_add1_iter(id) AS res FROM range(0, 10)").collect()
        self.assertEqual(result2, expected)
        self.spark.sql("DROP TEMPORARY FUNCTION pd_add1_iter")

    def test_scalar_pandas_iter_udf_II(self):
        import pandas as pd

        @udf(returnType=LongType())
        def pd_add_iter(it: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
            for ser1, ser2 in it:
                assert isinstance(ser1, pd.Series)
                assert isinstance(ser2, pd.Series)
                yield ser1 + ser2

        self.assertEqual(pd_add_iter.evalType, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
        df = self.spark.range(0, 10)
        expected = df.select((df.id + df.id).alias("res")).collect()
        result1 = df.select(pd_add_iter("id", "id").alias("res")).collect()
        self.assertEqual(result1, expected)

        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add_iter")
        self.spark.udf.register("pd_add_iter", pd_add_iter)
        result2 = self.spark.sql("SELECT pd_add_iter(id, id) AS res FROM range(0, 10)").collect()
        self.assertEqual(result2, expected)
        self.spark.sql("DROP TEMPORARY FUNCTION pd_add_iter")

    def test_grouped_agg_pandas_udf(self):
        import pandas as pd

        @udf(returnType=LongType())
        def pd_max(ser: pd.Series) -> int:
            assert isinstance(ser, pd.Series)
            return ser.max()

        self.assertEqual(pd_max.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
        df = self.spark.range(0, 10)
        expected = df.select(sf.max("id").alias("res")).collect()
        result1 = df.select(pd_max("id").alias("res")).collect()
        self.assertEqual(result1, expected)

        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_max")
        self.spark.udf.register("pd_max", pd_max)
        result2 = self.spark.sql("SELECT pd_max(id) AS res FROM range(0, 10)").collect()
        self.assertEqual(result2, expected)
        self.spark.sql("DROP TEMPORARY FUNCTION pd_max")

    def test_window_agg_pandas_udf(self):
        import pandas as pd

        @udf(returnType=LongType())
        def pd_win_max(ser: pd.Series) -> int:
            assert isinstance(ser, pd.Series)
            return ser.max()

        self.assertEqual(pd_win_max.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
        df = (
            self.spark.range(10)
            .withColumn("vs", sf.array([sf.lit(i * 1.0) + sf.col("id") for i in range(20, 30)]))
            .withColumn("v", sf.explode("vs"))
            .drop("vs")
            .withColumn("w", sf.lit(1.0))
        )
        w = (
            Window.partitionBy("id")
            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
            .orderBy("v")
        )
        expected = df.withColumn("res", sf.max("v").over(w)).collect()
        result1 = df.withColumn("res", pd_win_max("v").over(w)).collect()
        self.assertEqual(result1, expected)

        with self.tempView("pd_tbl"):
            df.createOrReplaceTempView("pd_tbl")
            self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_win_max")
            self.spark.udf.register("pd_win_max", pd_win_max)
            result2 = self.spark.sql(
                """
                SELECT *, pd_win_max(v) OVER (
                    PARTITION BY id
                    ORDER BY v
                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
                ) AS res FROM pd_tbl
                """
            ).collect()
            self.assertEqual(result2, expected)
            self.spark.sql("DROP TEMPORARY FUNCTION pd_win_max")

    def test_scalar_arrow_udf(self):
        import pyarrow as pa

        @udf(returnType=LongType())
        def pa_add1(arr: pa.Array) -> pa.Array:
            assert isinstance(arr, pa.Array)
            return pa.compute.add(arr, 1)

        self.assertEqual(pa_add1.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
        df = self.spark.range(0, 10)
        expected = df.select((df.id + 1).alias("res")).collect()
        result1 = df.select(pa_add1("id").alias("res")).collect()
        self.assertEqual(result1, expected)

        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add1")
        self.spark.udf.register("pa_add1", pa_add1)
        result2 = self.spark.sql("SELECT pa_add1(id) AS res FROM range(0, 10)").collect()
        self.assertEqual(result2, expected)
        self.spark.sql("DROP TEMPORARY FUNCTION pa_add1")

    def test_scalar_arrow_udf_II(self):
        import pyarrow as pa

        @udf(returnType=LongType())
        def pa_add(arr1: pa.Array, arr2: pa.Array) -> pa.Array:
            assert isinstance(arr1, pa.Array)
            assert isinstance(arr2, pa.Array)
            return pa.compute.add(arr1, arr2)

        self.assertEqual(pa_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
        df = self.spark.range(0, 10)
        expected = df.select((df.id + df.id).alias("res")).collect()
        result1 = df.select(pa_add("id", "id").alias("res")).collect()
        self.assertEqual(result1, expected)

        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add")
        self.spark.udf.register("pa_add", pa_add)
        result2 = self.spark.sql("SELECT pa_add(id, id) AS res FROM range(0, 10)").collect()
        self.assertEqual(result2, expected)
        self.spark.sql("DROP TEMPORARY FUNCTION pa_add")

    def test_scalar_arrow_iter_udf(self):
        import pyarrow as pa

        @udf(returnType=LongType())
        def pa_add1_iter(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
            for arr in it:
                assert isinstance(arr, pa.Array)
                yield pa.compute.add(arr, 1)

        self.assertEqual(pa_add1_iter.evalType, PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF)
        df = self.spark.range(0, 10)
        expected = df.select((df.id + 1).alias("res")).collect()
        result1 = df.select(pa_add1_iter("id").alias("res")).collect()
        self.assertEqual(result1, expected)

        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add1_iter")
        self.spark.udf.register("pa_add1_iter", pa_add1_iter)
        result2 = self.spark.sql("SELECT pa_add1_iter(id) AS res FROM range(0, 10)").collect()
        self.assertEqual(result2, expected)
        self.spark.sql("DROP TEMPORARY FUNCTION pa_add1_iter")

    def test_scalar_arrow_iter_udf_II(self):
        import pyarrow as pa

        @udf(returnType=LongType())
        def pa_add_iter(it: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
            for arr1, arr2 in it:
                assert isinstance(arr1, pa.Array)
                assert isinstance(arr2, pa.Array)
                yield pa.compute.add(arr1, arr2)

        self.assertEqual(pa_add_iter.evalType, PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF)
        df = self.spark.range(0, 10)
        expected = df.select((df.id + df.id).alias("res")).collect()
        result1 = df.select(pa_add_iter("id", "id").alias("res")).collect()
        self.assertEqual(result1, expected)

        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add_iter")
        self.spark.udf.register("pa_add_iter", pa_add_iter)
        result2 = self.spark.sql("SELECT pa_add_iter(id, id) AS res FROM range(0, 10)").collect()
        self.assertEqual(result2, expected)
        self.spark.sql("DROP TEMPORARY FUNCTION pa_add_iter")

    def test_grouped_agg_arrow_udf(self):
        import pyarrow as pa

        @udf(returnType=LongType())
        def pa_max(arr: pa.Array) -> pa.Scalar:
            assert isinstance(arr, pa.Array)
            return pa.compute.max(arr)

        self.assertEqual(pa_max.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
        df = self.spark.range(0, 10)
        expected = df.select(sf.max("id").alias("res")).collect()
        result1 = df.select(pa_max("id").alias("res")).collect()
        self.assertEqual(result1, expected)

        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_max")
        self.spark.udf.register("pa_max", pa_max)
        result2 = self.spark.sql("SELECT pa_max(id) AS res FROM range(0, 10)").collect()
        self.assertEqual(result2, expected)
        self.spark.sql("DROP TEMPORARY FUNCTION pa_max")

    def test_window_agg_arrow_udf(self):
        import pyarrow as pa

        @udf(returnType=LongType())
        def pa_win_max(arr: pa.Array) -> pa.Scalar:
            assert isinstance(arr, pa.Array)
            return pa.compute.max(arr)

        self.assertEqual(pa_win_max.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
        df = (
            self.spark.range(10)
            .withColumn("vs", sf.array([sf.lit(i * 1.0) + sf.col("id") for i in range(20, 30)]))
            .withColumn("v", sf.explode("vs"))
            .drop("vs")
            .withColumn("w", sf.lit(1.0))
        )
        w = (
            Window.partitionBy("id")
            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
            .orderBy("v")
        )
        expected = df.withColumn("res", sf.max("v").over(w)).collect()
        result1 = df.withColumn("res", pa_win_max("v").over(w)).collect()
        self.assertEqual(result1, expected)

        with self.tempView("pa_tbl"):
            df.createOrReplaceTempView("pa_tbl")
            self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_win_max")
            self.spark.udf.register("pa_win_max", pa_win_max)
            result2 = self.spark.sql(
                """
                SELECT *, pa_win_max(v) OVER (
                    PARTITION BY id
                    ORDER BY v
                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
                ) AS res FROM pa_tbl
                """
            ).collect()
            self.assertEqual(result2, expected)
            self.spark.sql("DROP TEMPORARY FUNCTION pa_win_max")

    def test_regular_python_udf(self):
        import pandas as pd
        import pyarrow as pa

        @udf(returnType=LongType())
        def f1(x):
            return x + 1

        @udf(returnType=LongType())
        def f2(x: int) -> int:
            return x + 1

        # Cannot infer a vectorized UDF type from these hints, so it falls back
        # to a regular Python UDF.
        @udf(returnType=LongType())
        def f3(x: int) -> pd.Series:
            return x + 1

        # Cannot infer a vectorized UDF type from these hints either.
        @udf(returnType=LongType())
        def f4(x: int) -> pa.Array:
            return x + 1

        # useArrow is explicitly set to False, so the pandas type hints are ignored.
        @udf(returnType=LongType(), useArrow=False)
        def f5(x: pd.Series) -> pd.Series:
            return x + 1

        # useArrow is explicitly set to False, so the Arrow type hints are ignored.
        @udf(returnType=LongType(), useArrow=False)
        def f6(x: pa.Array) -> pa.Array:
            return x + 1

        expected = self.spark.range(10).select((sf.col("id") + 1).alias("res")).collect()
        for f in [f1, f2, f3, f4, f5, f6]:
            self.assertEqual(f.evalType, PythonEvalType.SQL_BATCHED_UDF)
            result = self.spark.range(10).select(f("id").alias("res")).collect()
            self.assertEqual(result, expected)

    def test_arrow_optimized_python_udf(self):
        import pandas as pd
        import pyarrow as pa

        @udf(returnType=LongType(), useArrow=True)
        def f1(x):
            return x + 1

        @udf(returnType=LongType(), useArrow=True)
        def f2(x: int) -> int:
            return x + 1

        # useArrow is explicitly set to True, which takes precedence over the
        # pandas type hints.
        @udf(returnType=LongType(), useArrow=True)
        def f3(x: pd.Series) -> pd.Series:
            return x + 1

        # useArrow is explicitly set to True, which takes precedence over the
        # Arrow type hints.
        @udf(returnType=LongType(), useArrow=True)
        def f4(x: pa.Array) -> pa.Array:
            return x + 1

        expected = self.spark.range(10).select((sf.col("id") + 1).alias("res")).collect()
        for f in [f1, f2, f3, f4]:
            self.assertEqual(f.evalType, PythonEvalType.SQL_ARROW_BATCHED_UDF)
            result = self.spark.range(10).select(f("id").alias("res")).collect()
            self.assertEqual(result, expected)


class UnifiedUDFTests(UnifiedUDFTestsMixin, ReusedSQLTestCase):
    pass


if __name__ == "__main__":
    from pyspark.sql.tests.test_unified_udf import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)