#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import random
import time
import unittest
import datetime
from decimal import Decimal
from typing import Iterator, Tuple
from pyspark.util import PythonEvalType
from pyspark.sql.functions import arrow_udf, ArrowUDFType
from pyspark.sql import functions as F
from pyspark.sql.types import (
IntegerType,
ByteType,
StructType,
ShortType,
BooleanType,
LongType,
FloatType,
DoubleType,
DecimalType,
StringType,
ArrayType,
StructField,
Row,
MapType,
BinaryType,
YearMonthIntervalType,
)
from pyspark.errors import AnalysisException, PythonException
from pyspark.testing.utils import (
have_numpy,
numpy_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
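# Tests for scalar Arrow UDFs (arrow_udf), covering both ArrowUDFType.SCALAR and
# ArrowUDFType.SCALAR_ITER evaluation modes.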
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class ScalarArrowUDFTestsMixin:
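# Helper UDFs: nondeterministic random-valued UDFs (scalar and iterator variants)
# used by the nondeterminism tests below.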
@property
def nondeterministic_arrow_udf(self):
import pyarrow as pa
import numpy as np
@arrow_udf("double")
def random_udf(v):
return pa.array(np.random.random(len(v)))
return random_udf.asNondeterministic()
@property
def nondeterministic_arrow_iter_udf(self):
import pyarrow as pa
import numpy as np
@arrow_udf("double", ArrowUDFType.SCALAR_ITER)
def random_udf(it):
for v in it:
yield pa.array(np.random.random(len(v)))
return random_udf.asNondeterministic()
def test_arrow_udf_tokenize(self):
import pyarrow as pa
df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
tokenize = arrow_udf(
lambda s: pa.compute.ascii_split_whitespace(s),
ArrayType(StringType()),
)
result = df.select(tokenize("vals").alias("hi"))
self.assertEqual(tokenize.returnType, ArrayType(StringType()))
self.assertEqual([Row(hi=["hi", "boo"]), Row(hi=["bye", "boo"])], result.collect())
def test_arrow_udf_output_nested_arrays(self):
import pyarrow as pa
df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
tokenize = arrow_udf(
lambda s: pa.array([[v] for v in pa.compute.ascii_split_whitespace(s).to_pylist()]),
ArrayType(ArrayType(StringType())),
)
result = df.select(tokenize("vals").alias("hi"))
self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType())))
self.assertEqual([Row(hi=[["hi", "boo"]]), Row(hi=[["bye", "boo"]])], result.collect())
def test_arrow_udf_output_structs(self):
import pyarrow as pa
df = self.spark.range(10).select("id", F.lit("foo").alias("name"))
create_struct = arrow_udf(
lambda x, y: pa.StructArray.from_arrays([x, y], names=["c1", "c2"]),
"struct<c1:long, c2:string>",
)
self.assertEqual(
df.select(create_struct("id", "name").alias("res")).first(),
Row(res=Row(c1=0, c2="foo")),
)
def test_arrow_udf_output_nested_structs(self):
import pyarrow as pa
df = self.spark.range(10).select("id", F.lit("foo").alias("name"))
create_struct = arrow_udf(
lambda x, y: pa.StructArray.from_arrays(
[x, pa.StructArray.from_arrays([x, y], names=["c3", "c4"])], names=["c1", "c2"]
),
"struct<c1:long, c2:struct<c3:long, c4:string>>",
)
self.assertEqual(
df.select(create_struct("id", "name").alias("res")).first(),
Row(res=Row(c1=0, c2=Row(c3=0, c4="foo"))),
)
def test_arrow_udf_input_output_nested_structs(self):
# root
# |-- s: struct (nullable = false)
# | |-- a: integer (nullable = false)
# | |-- y: struct (nullable = false)
# | | |-- b: integer (nullable = false)
# | | |-- x: struct (nullable = false)
# | | | |-- c: integer (nullable = false)
# | | | |-- d: integer (nullable = false)
df = self.spark.sql(
"""
SELECT STRUCT(a, STRUCT(b, STRUCT(c, d) AS x) AS y) AS s
FROM VALUES
(1, 2, 3, 4),
(-1, -2, -3, -4)
AS tab(a, b, c, d)
"""
)
schema = StructType(
[
StructField("a", IntegerType(), False),
StructField(
"y",
StructType(
[
StructField("b", IntegerType(), False),
StructField(
"x",
StructType(
[
StructField("c", IntegerType(), False),
StructField("d", IntegerType(), False),
]
),
False,
),
]
),
False,
),
]
)
extract_y = arrow_udf(lambda s: s.field("y"), schema["y"].dataType)
result = df.select(extract_y("s").alias("y"))
self.assertEqual(
[Row(y=Row(b=2, x=Row(c=3, d=4))), Row(y=Row(b=-2, x=Row(c=-3, d=-4)))],
result.collect(),
)
def test_arrow_udf_input_nested_maps(self):
import pyarrow as pa
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"attributes", MapType(StringType(), MapType(StringType(), StringType())), True
),
]
)
data = [("1", {"personal": {"name": "John", "city": "New York"}})]
# root
# |-- id: string (nullable = true)
# |-- attributes: map (nullable = true)
# | |-- key: string
# | |-- value: map (valueContainsNull = true)
# | | |-- key: string
# | | |-- value: string (valueContainsNull = true)
df = self.spark.createDataFrame(data, schema)
str_repr = arrow_udf(lambda s: pa.array(str(x.as_py()) for x in s), StringType())
result = df.select(str_repr("attributes").alias("s"))
self.assertEqual(
[Row(s="[('personal', [('name', 'John'), ('city', 'New York')])]")],
result.collect(),
)
def test_arrow_udf_input_nested_arrays(self):
import pyarrow as pa
# root
# |-- a: array (nullable = false)
# | |-- element: array (containsNull = false)
# | | |-- element: integer (containsNull = true)
df = self.spark.sql(
"""
SELECT ARRAY(ARRAY(1,2,3),ARRAY(4,NULL),ARRAY(5,6,NULL)) AS arr
"""
)
str_repr = arrow_udf(lambda s: pa.array(str(x.as_py()) for x in s), StringType())
result = df.select(str_repr("arr").alias("s"))
self.assertEqual(
[Row(s="[[1, 2, 3], [4, None], [5, 6, None]]")],
result.collect(),
)
def test_arrow_udf_input_arrow_array_struct(self):
import pyarrow as pa
df = self.spark.createDataFrame(
[[[("a", 2, 3.0), ("a", 2, 3.0)]], [[("b", 5, 6.0), ("b", 5, 6.0)]]],
"array_struct_col array<struct<col1:string, col2:long, col3:double>>",
)
@arrow_udf("array<struct<col1:string, col2:long, col3:double>>")
def return_cols(cols):
assert isinstance(cols, pa.Array)
return cols
result = df.select(return_cols("array_struct_col"))
self.assertEqual(
[
Row(output=[Row(col1="a", col2=2, col3=3.0), Row(col1="a", col2=2, col3=3.0)]),
Row(output=[Row(col1="b", col2=5, col3=6.0), Row(col1="b", col2=5, col3=6.0)]),
],
result.collect(),
)
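# DATE input arrives as a pa.Date32Array; the UDF extracts the year and returns it as int32.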
def test_arrow_udf_input_dates(self):
import pyarrow as pa
df = self.spark.sql(
"""
SELECT * FROM VALUES
(1, DATE('2022-02-22')),
(2, DATE('2023-02-22')),
(3, DATE('2024-02-22'))
AS tab(i, date)
"""
)
@arrow_udf("int")
def extract_year(d):
assert isinstance(d, pa.Array)
assert isinstance(d, pa.Date32Array)
return pa.array([x.as_py().year for x in d], pa.int32())
result = df.select(extract_year("date").alias("year"))
self.assertEqual(
[Row(year=2022), Row(year=2023), Row(year=2024)],
result.collect(),
)
def test_arrow_udf_output_dates(self):
import pyarrow as pa
df = self.spark.sql(
"""
SELECT * FROM VALUES
(2022, 1, 5),
(2023, 2, 6),
(2024, 3, 7)
AS tab(y, m, d)
"""
)
@arrow_udf("date")
def build_date(y, m, d):
assert all(isinstance(x, pa.Array) for x in [y, m, d])
dates = [
datetime.date(int(y[i].as_py()), int(m[i].as_py()), int(d[i].as_py()))
for i in range(len(y))
]
return pa.array(dates, pa.date32())
result = df.select(build_date("y", "m", "d").alias("date"))
self.assertEqual(
[
Row(date=datetime.date(2022, 1, 5)),
Row(date=datetime.date(2023, 2, 6)),
Row(date=datetime.date(2024, 3, 7)),
],
result.collect(),
)
def test_arrow_udf_input_timestamps(self):
import pyarrow as pa
df = self.spark.sql(
"""
SELECT * FROM VALUES
(1, TIMESTAMP('2019-04-12 15:50:01')),
(2, TIMESTAMP('2020-04-12 15:50:02')),
(3, TIMESTAMP('2021-04-12 15:50:03'))
AS tab(i, ts)
"""
)
@arrow_udf("int")
def extract_second(d):
assert isinstance(d, pa.Array)
assert isinstance(d, pa.TimestampArray)
return pa.array([x.as_py().second for x in d], pa.int32())
result = df.select(extract_second("ts").alias("second"))
self.assertEqual(
[Row(second=1), Row(second=2), Row(second=3)],
result.collect(),
)
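# For TIMESTAMP (LTZ) output, the UDF builds session-time-zone datetimes, normalizes them
# to UTC, and returns pa.timestamp("us", "UTC"); the collected rows come back as local
# wall-clock datetimes.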
def test_arrow_udf_output_timestamps_ltz(self):
from zoneinfo import ZoneInfo
import pyarrow as pa
tz = self.spark.conf.get("spark.sql.session.timeZone")
df = self.spark.sql(
"""
SELECT * FROM VALUES
(2022, 1, 5, 15, 0, 1),
(2023, 2, 6, 16, 1, 2),
(2024, 3, 7, 17, 2, 3)
AS tab(y, m, d, h, mi, s)
"""
)
@arrow_udf("timestamp")
def build_ts(y, m, d, h, mi, s):
assert all(isinstance(x, pa.Array) for x in [y, m, d, h, mi, s])
dates = [
datetime.datetime(
int(y[i].as_py()),
int(m[i].as_py()),
int(d[i].as_py()),
int(h[i].as_py()),
int(mi[i].as_py()),
int(s[i].as_py()),
tzinfo=ZoneInfo(tz),
).astimezone(datetime.timezone.utc)
for i in range(len(y))
]
return pa.array(dates, pa.timestamp("us", "UTC"))
result = df.select(build_ts("y", "m", "d", "h", "mi", "s").alias("ts"))
self.assertEqual(
[
Row(ts=datetime.datetime(2022, 1, 5, 15, 0, 1)),
Row(ts=datetime.datetime(2023, 2, 6, 16, 1, 2)),
Row(ts=datetime.datetime(2024, 3, 7, 17, 2, 3)),
],
result.collect(),
)
def test_arrow_udf_output_timestamps_ntz(self):
import pyarrow as pa
df = self.spark.sql(
"""
SELECT * FROM VALUES
(2022, 1, 5, 15, 0, 1),
(2023, 2, 6, 16, 1, 2),
(2024, 3, 7, 17, 2, 3)
AS tab(y, m, d, h, mi, s)
"""
)
@arrow_udf("timestamp_ntz")
def build_ts(y, m, d, h, mi, s):
assert all(isinstance(x, pa.Array) for x in [y, m, d, h, mi, s])
dates = [
datetime.datetime(
int(y[i].as_py()),
int(m[i].as_py()),
int(d[i].as_py()),
int(h[i].as_py()),
int(mi[i].as_py()),
int(s[i].as_py()),
)
for i in range(len(y))
]
return pa.array(dates, pa.timestamp("us"))
result = df.select(build_ts("y", "m", "d", "h", "mi", "s").alias("ts"))
self.assertEqual(
[
Row(ts=datetime.datetime(2022, 1, 5, 15, 0, 1)),
Row(ts=datetime.datetime(2023, 2, 6, 16, 1, 2)),
Row(ts=datetime.datetime(2024, 3, 7, 17, 2, 3)),
],
result.collect(),
)
def test_arrow_udf_input_times(self):
import pyarrow as pa
df = self.spark.sql(
"""
SELECT * FROM VALUES
(1, TIME '12:34:56'),
(2, TIME '1:2:3'),
(3, TIME '0:58:59')
AS tab(i, ts)
"""
)
@arrow_udf("int")
def extract_second(v):
assert isinstance(v, pa.Array)
assert isinstance(v, pa.Time64Array), type(v)
return pa.array([t.as_py().second for t in v], pa.int32())
result = df.select(extract_second("ts").alias("sec"))
self.assertEqual(
[
Row(sec=56),
Row(sec=3),
Row(sec=59),
],
result.collect(),
)
def test_arrow_udf_output_times(self):
import pyarrow as pa
df = self.spark.sql(
"""
SELECT * FROM VALUES
(12, 34, 56),
(1, 2, 3),
(0, 58, 59)
AS tab(h, mi, s)
"""
)
@arrow_udf("time")
def build_time(h, mi, s):
assert all(isinstance(x, pa.Array) for x in [h, mi, s])
dates = [
datetime.time(
int(h[i].as_py()),
int(mi[i].as_py()),
int(s[i].as_py()),
)
for i in range(len(h))
]
return pa.array(dates, pa.time64("ns"))
result = df.select(build_time("h", "mi", "s").alias("t"))
self.assertEqual(
[
Row(t=datetime.time(12, 34, 56)),
Row(t=datetime.time(1, 2, 3)),
Row(t=datetime.time(0, 58, 59)),
],
result.collect(),
)
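# VariantType input is passed to Arrow UDFs as a pa.StructArray with binary 'value' and
# 'metadata' fields.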
def test_arrow_udf_input_variant(self):
import pyarrow as pa
@arrow_udf("int")
def scalar_f(v: pa.Array) -> pa.Array:
assert isinstance(v, pa.Array)
assert isinstance(v, pa.StructArray)
assert isinstance(v.field("metadata"), pa.BinaryArray)
assert isinstance(v.field("value"), pa.BinaryArray)
return pa.compute.binary_length(v.field("value"))
@arrow_udf("int")
def iter_f(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
for v in it:
assert isinstance(v, pa.Array)
assert isinstance(v, pa.StructArray)
assert isinstance(v.field("metadata"), pa.BinaryArray)
assert isinstance(v.field("value"), pa.BinaryArray)
yield pa.compute.binary_length(v.field("value"))
df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
expected = [Row(l=2) for i in range(10)]
for f in [scalar_f, iter_f]:
result = df.select(f("v").alias("l")).collect()
self.assertEqual(result, expected)
def test_arrow_udf_output_variant(self):
# referring to test_udf_with_variant_output in test_pandas_udf_scalar
import pyarrow as pa
# referring to to_arrow_type in pyspark.sql.pandas.types
fields = [
pa.field("value", pa.binary(), nullable=False),
pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}),
]
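# Hand-built variant encoding: bytes([12, i]) is a one-byte integer value and
# bytes([1, 0, 0]) is an empty metadata dictionary.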
@arrow_udf("variant")
def scalar_f(v: pa.Array) -> pa.Array:
assert isinstance(v, pa.Array)
v = pa.array([bytes([12, i.as_py()]) for i in v], pa.binary())
m = pa.array([bytes([1, 0, 0]) for i in v], pa.binary())
return pa.StructArray.from_arrays([v, m], fields=fields)
@arrow_udf("variant")
def iter_f(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
for v in it:
assert isinstance(v, pa.Array)
v = pa.array([bytes([12, i.as_py()]) for i in v])
m = pa.array([bytes([1, 0, 0]) for i in v])
yield pa.StructArray.from_arrays([v, m], fields=fields)
df = self.spark.range(0, 10)
expected = [Row(l=i) for i in range(10)]
for f in [scalar_f, iter_f]:
result = df.select(f("id").cast("int").alias("l")).collect()
self.assertEqual(result, expected)
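# Null round-trip tests: an identity UDF over a nullable column of each type, run in both
# SCALAR and SCALAR_ITER mode, must return the input unchanged, nulls included.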
def test_arrow_udf_null_boolean(self):
data = [(True,), (True,), (None,), (False,)]
schema = StructType().add("bool", BooleanType())
df = self.spark.createDataFrame(data, schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
bool_f = arrow_udf(lambda x: x, BooleanType(), udf_type)
res = df.select(bool_f(F.col("bool")))
self.assertEqual(df.collect(), res.collect())
def test_arrow_udf_null_byte(self):
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("byte", ByteType())
df = self.spark.createDataFrame(data, schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
byte_f = arrow_udf(lambda x: x, ByteType(), udf_type)
res = df.select(byte_f(F.col("byte")))
self.assertEqual(df.collect(), res.collect())
def test_arrow_udf_null_short(self):
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("short", ShortType())
df = self.spark.createDataFrame(data, schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
short_f = arrow_udf(lambda x: x, ShortType(), udf_type)
res = df.select(short_f(F.col("short")))
self.assertEqual(df.collect(), res.collect())
def test_arrow_udf_null_int(self):
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("int", IntegerType())
df = self.spark.createDataFrame(data, schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
int_f = arrow_udf(lambda x: x, IntegerType(), udf_type)
res = df.select(int_f(F.col("int")))
self.assertEqual(df.collect(), res.collect())
def test_arrow_udf_null_long(self):
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("long", LongType())
df = self.spark.createDataFrame(data, schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
long_f = arrow_udf(lambda x: x, LongType(), udf_type)
res = df.select(long_f(F.col("long")))
self.assertEqual(df.collect(), res.collect())
def test_arrow_udf_null_float(self):
data = [(3.0,), (5.0,), (-1.0,), (None,)]
schema = StructType().add("float", FloatType())
df = self.spark.createDataFrame(data, schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
float_f = arrow_udf(lambda x: x, FloatType(), udf_type)
res = df.select(float_f(F.col("float")))
self.assertEqual(df.collect(), res.collect())
def test_arrow_udf_null_double(self):
data = [(3.0,), (5.0,), (-1.0,), (None,)]
schema = StructType().add("double", DoubleType())
df = self.spark.createDataFrame(data, schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
double_f = arrow_udf(lambda x: x, DoubleType(), udf_type)
res = df.select(double_f(F.col("double")))
self.assertEqual(df.collect(), res.collect())
def test_arrow_udf_null_decimal(self):
data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
schema = StructType().add("decimal", DecimalType(38, 18))
df = self.spark.createDataFrame(data, schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
decimal_f = arrow_udf(lambda x: x, DecimalType(38, 18), udf_type)
res = df.select(decimal_f(F.col("decimal")))
self.assertEqual(df.collect(), res.collect())
def test_arrow_udf_null_string(self):
data = [("foo",), (None,), ("bar",), ("bar",)]
schema = StructType().add("str", StringType())
df = self.spark.createDataFrame(data, schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
str_f = arrow_udf(lambda x: x, StringType(), udf_type)
res = df.select(str_f(F.col("str")))
self.assertEqual(df.collect(), res.collect())
def test_arrow_udf_null_binary(self):
data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)]
schema = StructType().add("binary", BinaryType())
df = self.spark.createDataFrame(data, schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
binary_f = arrow_udf(lambda x: x, BinaryType(), udf_type)
res = df.select(binary_f(F.col("binary")))
self.assertEqual(df.collect(), res.collect())
def test_arrow_udf_null_array(self):
data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
df = self.spark.createDataFrame(data, schema=array_schema)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
array_f = arrow_udf(lambda x: x, ArrayType(IntegerType()), udf_type)
result = df.select(array_f(F.col("array")))
self.assertEqual(df.collect(), result.collect())
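# A single-row DataFrame repartitioned into 2 partitions leaves one partition empty;
# the UDF must handle the zero-row batch.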
def test_arrow_udf_empty_partition(self):
df = self.spark.createDataFrame([Row(id=1)]).repartition(2)
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
f = arrow_udf(lambda x: x, LongType(), udf_type)
res = df.select(f(F.col("id")))
self.assertEqual(df.collect(), res.collect())
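# Return types can be given as DDL type strings instead of DataType instances.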
def test_arrow_udf_datatype_string(self):
df = self.spark.range(10).select(
F.col("id").cast("string").alias("str"),
F.col("id").cast("int").alias("int"),
F.col("id").alias("long"),
F.col("id").cast("float").alias("float"),
F.col("id").cast("double").alias("double"),
F.col("id").cast("decimal").alias("decimal1"),
F.col("id").cast("decimal(10, 0)").alias("decimal2"),
F.col("id").cast("decimal(38, 18)").alias("decimal3"),
F.col("id").cast("boolean").alias("bool"),
)
def f(x):
return x
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
str_f = arrow_udf(f, "string", udf_type)
int_f = arrow_udf(f, "integer", udf_type)
long_f = arrow_udf(f, "long", udf_type)
float_f = arrow_udf(f, "float", udf_type)
double_f = arrow_udf(f, "double", udf_type)
decimal1_f = arrow_udf(f, "decimal", udf_type)
decimal2_f = arrow_udf(f, "decimal(10, 0)", udf_type)
decimal3_f = arrow_udf(f, "decimal(38, 18)", udf_type)
bool_f = arrow_udf(f, "boolean", udf_type)
res = df.select(
str_f(F.col("str")),
int_f(F.col("int")),
long_f(F.col("long")),
float_f(F.col("float")),
double_f(F.col("double")),
decimal1_f("decimal1"),
decimal2_f("decimal2"),
decimal3_f("decimal3"),
bool_f(F.col("bool")),
)
self.assertEqual(df.collect(), res.collect())
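# Registering an Arrow UDF with spark.udf.register (or spark.catalog.registerFunction)
# preserves its eval type and determinism flag and makes it callable from SQL.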
def test_udf_register_arrow_udf_basic(self):
import pyarrow as pa
scalar_original_add = arrow_udf(
lambda x, y: pa.compute.add(x, y).cast(pa.int32()), IntegerType()
)
self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
self.assertEqual(scalar_original_add.deterministic, True)
with self.temp_func("add1"):
new_add = self.spark.udf.register("add1", scalar_original_add)
self.assertEqual(new_add.deterministic, True)
self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
df = self.spark.range(10).select(
F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
)
res1 = df.select(new_add(F.col("a"), F.col("b")))
res2 = self.spark.sql(
"SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
)
expected = df.select(F.expr("a + b"))
self.assertEqual(expected.collect(), res1.collect())
self.assertEqual(expected.collect(), res2.collect())
@arrow_udf(LongType())
def scalar_iter_add(it: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
for a, b in it:
yield pa.compute.add(a, b)
with self.temp_func("add1"):
new_add = self.spark.udf.register("add1", scalar_iter_add)
res3 = df.select(new_add(F.col("a"), F.col("b")))
res4 = self.spark.sql(
"SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
)
expected = df.select(F.expr("a + b"))
self.assertEqual(expected.collect(), res3.collect())
self.assertEqual(expected.collect(), res4.collect())
def test_catalog_register_arrow_udf_basic(self):
import pyarrow as pa
scalar_original_add = arrow_udf(
lambda x, y: pa.compute.add(x, y).cast(pa.int32()), IntegerType()
)
self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
self.assertEqual(scalar_original_add.deterministic, True)
with self.temp_func("add1"):
new_add = self.spark.catalog.registerFunction("add1", scalar_original_add)
self.assertEqual(new_add.deterministic, True)
self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
df = self.spark.range(10).select(
F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
)
res1 = df.select(new_add(F.col("a"), F.col("b")))
res2 = self.spark.sql(
"SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
)
expected = df.select(F.expr("a + b"))
self.assertEqual(expected.collect(), res1.collect())
self.assertEqual(expected.collect(), res2.collect())
@arrow_udf(LongType())
def scalar_iter_add(it: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
for a, b in it:
yield pa.compute.add(a, b)
with self.temp_func("add1"):
new_add = self.spark.catalog.registerFunction("add1", scalar_iter_add)
res3 = df.select(new_add(F.col("a"), F.col("b")))
res4 = self.spark.sql(
"SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
)
expected = df.select(F.expr("a + b"))
self.assertEqual(expected.collect(), res3.collect())
self.assertEqual(expected.collect(), res4.collect())
def test_udf_register_nondeterministic_arrow_udf(self):
import pyarrow as pa
random_arrow_udf = arrow_udf(
lambda x: pa.compute.add(x, random.randint(6, 6)), LongType()
).asNondeterministic()
self.assertEqual(random_arrow_udf.deterministic, False)
self.assertEqual(random_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
with self.temp_func("randomArrowUDF"):
nondeterministic_arrow_udf = self.spark.udf.register("randomArrowUDF", random_arrow_udf)
self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
self.assertEqual(
nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF
)
[row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
self.assertEqual(row[0], 7)
def test_catalog_register_nondeterministic_arrow_udf(self):
import pyarrow as pa
random_arrow_udf = arrow_udf(
lambda x: pa.compute.add(x, random.randint(6, 6)), LongType()
).asNondeterministic()
self.assertEqual(random_arrow_udf.deterministic, False)
self.assertEqual(random_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
with self.temp_func("randomArrowUDF"):
nondeterministic_arrow_udf = self.spark.catalog.registerFunction(
"randomArrowUDF", random_arrow_udf
)
self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
self.assertEqual(
nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF
)
[row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
self.assertEqual(row[0], 7)
@unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_nondeterministic_arrow_udf(self):
import pyarrow as pa
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
@arrow_udf("double")
def scalar_plus_ten(v):
return pa.compute.add(v, 10)
@arrow_udf("double", ArrowUDFType.SCALAR_ITER)
def iter_plus_ten(it):
for v in it:
yield pa.compute.add(v, 10)
for plus_ten in [scalar_plus_ten, iter_plus_ten]:
random_udf = self.nondeterministic_arrow_udf
df = self.spark.range(10).withColumn("rand", random_udf("id"))
result1 = df.withColumn("plus_ten(rand)", plus_ten(df["rand"])).toPandas()
self.assertEqual(random_udf.deterministic, False)
self.assertTrue(result1["plus_ten(rand)"].equals(result1["rand"] + 10))
@unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_nondeterministic_arrow_udf_in_aggregate(self):
with self.quiet():
df = self.spark.range(10)
for random_udf in [
self.nondeterministic_arrow_udf,
self.nondeterministic_arrow_iter_udf,
]:
with self.assertRaisesRegex(AnalysisException, "Non-deterministic"):
df.groupby("id").agg(F.sum(random_udf("id"))).collect()
with self.assertRaisesRegex(AnalysisException, "Non-deterministic"):
df.agg(F.sum(random_udf("id"))).collect()
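# Chaining tests: scalar and iterator Arrow UDFs compose with each other in any combination.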
def test_arrow_udf_chained(self):
import pyarrow.compute as pc
scalar_f = arrow_udf(lambda x: pc.add(x, 1), LongType())
scalar_g = arrow_udf(lambda x: pc.subtract(x, 1), LongType())
iter_f = arrow_udf(
lambda it: map(lambda x: pc.add(x, 1), it),
LongType(),
ArrowUDFType.SCALAR_ITER,
)
iter_g = arrow_udf(
lambda it: map(lambda x: pc.subtract(x, 1), it),
LongType(),
ArrowUDFType.SCALAR_ITER,
)
df = self.spark.range(10)
expected = df.select(F.col("id").alias("res")).collect()
for f, g in [
(scalar_f, scalar_g),
(iter_f, iter_g),
(scalar_f, iter_g),
(iter_f, scalar_g),
]:
res = df.select(g(f(F.col("id"))).alias("res"))
self.assertEqual(expected, res.collect())
def test_arrow_udf_chained_ii(self):
import pyarrow.compute as pc
scalar_f = arrow_udf(lambda x: pc.add(x, 1), LongType())
iter_f = arrow_udf(
lambda it: map(lambda x: pc.add(x, 1), it),
LongType(),
ArrowUDFType.SCALAR_ITER,
)
df = self.spark.range(10)
expected = df.select((F.col("id") + 3).alias("res")).collect()
for f1, f2, f3 in [
(scalar_f, scalar_f, scalar_f),
(iter_f, iter_f, iter_f),
(scalar_f, iter_f, scalar_f),
(iter_f, scalar_f, scalar_f),
(iter_f, iter_f, iter_f),
]:
res = df.select(f1(f2(f3(F.col("id")))).alias("res"))
self.assertEqual(expected, res.collect())
def test_arrow_udf_chained_iii(self):
import pyarrow as pa
scalar_f = arrow_udf(lambda x: pa.compute.add(x, 1), LongType())
scalar_g = arrow_udf(lambda x: pa.compute.subtract(x, 1), LongType())
scalar_m = arrow_udf(lambda x, y: pa.compute.multiply(x, y), LongType())
iter_f = arrow_udf(
lambda it: map(lambda x: pa.compute.add(x, 1), it),
LongType(),
ArrowUDFType.SCALAR_ITER,
)
iter_g = arrow_udf(
lambda it: map(lambda x: pa.compute.subtract(x, 1), it),
LongType(),
ArrowUDFType.SCALAR_ITER,
)
@arrow_udf(LongType())
def iter_m(it: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
for a, b in it:
yield pa.compute.multiply(a, b)
df = self.spark.range(10)
expected = df.select(((F.col("id") + 1) * (F.col("id") - 1)).alias("res")).collect()
for f, g, m in [
(scalar_f, scalar_g, scalar_m),
(iter_f, iter_g, iter_m),
(scalar_f, iter_g, scalar_m),
(iter_f, scalar_g, scalar_m),
(iter_f, scalar_g, iter_m),
]:
res = df.select(m(f(F.col("id")), g(F.col("id"))).alias("res"))
self.assertEqual(expected, res.collect())
def test_arrow_udf_chained_struct_type(self):
import pyarrow as pa
return_type = StructType([StructField("id", LongType()), StructField("str", StringType())])
@arrow_udf(return_type)
def scalar_f(id):
return pa.StructArray.from_arrays([id, id.cast(pa.string())], names=["id", "str"])
@arrow_udf(return_type)
def iter_f(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
for id in it:
yield pa.StructArray.from_arrays([id, id.cast(pa.string())], names=["id", "str"])
scalar_g = arrow_udf(lambda x: x, return_type)
iter_g = arrow_udf(lambda x: x, return_type, ArrowUDFType.SCALAR_ITER)
df = self.spark.range(10)
expected = df.select(
F.struct(F.col("id"), F.col("id").cast("string").alias("str")).alias("res")
).collect()
for f, g in [
(scalar_f, scalar_g),
(iter_f, iter_g),
(scalar_f, iter_g),
(iter_f, scalar_g),
]:
res = df.select(g(f(F.col("id"))).alias("res"))
self.assertEqual(expected, res.collect())
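# Named arguments work both as DataFrame API keyword arguments and via SQL's
# name => value syntax.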
def test_arrow_udf_named_arguments(self):
import pyarrow as pa
@arrow_udf("int")
def test_udf(a, b):
return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())
with self.temp_func("test_udf"):
self.spark.udf.register("test_udf", test_udf)
expected = [Row(0), Row(101)]
for i, df in enumerate(
[
self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
]
):
with self.subTest(query_no=i):
self.assertEqual(expected, df.collect())
def test_arrow_udf_named_arguments_negative(self):
import pyarrow as pa
@arrow_udf("int")
def test_udf(a, b):
return pa.compute.add(a, b).cast(pa.int32())
with self.temp_func("test_udf"):
self.spark.udf.register("test_udf", test_udf)
with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
with self.assertRaisesRegex(
PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
):
self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
def test_arrow_udf_named_arguments_and_defaults(self):
import pyarrow as pa
@arrow_udf("int")
def test_udf(a, b=0):
return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())
with self.temp_func("test_udf"):
self.spark.udf.register("test_udf", test_udf)
# without "b"
expected = [Row(0), Row(1)]
for i, df in enumerate(
[
self.spark.range(2).select(test_udf(F.col("id"))),
self.spark.range(2).select(test_udf(a=F.col("id"))),
self.spark.sql("SELECT test_udf(id) FROM range(2)"),
self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
]
):
with self.subTest(with_b=False, query_no=i):
self.assertEqual(expected, df.collect())
# with "b"
expected = [Row(0), Row(101)]
for i, df in enumerate(
[
self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
]
):
with self.subTest(with_b=True, query_no=i):
self.assertEqual(expected, df.collect())
def test_arrow_udf_kwargs(self):
import pyarrow as pa
@arrow_udf("int")
def test_udf(a, **kwargs):
return pa.compute.add(a, pa.compute.multiply(kwargs["b"], 10)).cast(pa.int32())
with self.temp_func("test_udf"):
self.spark.udf.register("test_udf", test_udf)
expected = [Row(0), Row(101)]
for i, df in enumerate(
[
self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
]
):
with self.subTest(query_no=i):
self.assertEqual(expected, df.collect())
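# SCALAR_ITER UDFs consume an iterator of batches; with multiple input columns each
# element is a tuple of pa.Array values.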
def test_arrow_iter_udf_single_column(self):
import pyarrow as pa
@arrow_udf(LongType())
def add_one(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
for s in it:
yield pa.compute.add(s, 1)
df = self.spark.range(10)
expected = df.select((F.col("id") + 1).alias("res")).collect()
result = df.select(add_one("id").alias("res"))
self.assertEqual(expected, result.collect())
def test_arrow_iter_udf_two_columns(self):
import pyarrow as pa
@arrow_udf(LongType())
def multiple(it: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
for a, b in it:
yield pa.compute.multiply(a, b)
df = self.spark.range(10).select(
F.col("id").alias("a"),
(F.col("id") + 1).alias("b"),
)
expected = df.select((F.col("a") * F.col("b")).alias("res")).collect()
result = df.select(multiple("a", "b").alias("res"))
self.assertEqual(expected, result.collect())
def test_arrow_iter_udf_three_columns(self):
import pyarrow as pa
@arrow_udf(LongType())
def multiple(it: Iterator[Tuple[pa.Array, pa.Array, pa.Array]]) -> Iterator[pa.Array]:
for a, b, c in it:
yield pa.compute.multiply(pa.compute.multiply(a, b), c)
df = self.spark.range(10).select(
F.col("id").alias("a"),
(F.col("id") + 1).alias("b"),
(F.col("id") + 2).alias("c"),
)
expected = df.select((F.col("a") * F.col("b") * F.col("c")).alias("res")).collect()
result = df.select(multiple("a", "b", "c").alias("res"))
self.assertEqual(expected, result.collect())
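# Results are coerced to the declared return type: long -> int narrowing works in range
# and raises (ArrowInvalid) on overflow when collected.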
def test_return_type_coercion(self):
import pyarrow as pa
df = self.spark.range(10)
scalar_long = arrow_udf(lambda x: pa.compute.add(x, 1), LongType())
result1 = df.select(scalar_long("id").alias("res"))
self.assertEqual(10, len(result1.collect()))
# long -> int coercion
scalar_int1 = arrow_udf(lambda x: pa.compute.add(x, 1), IntegerType())
result2 = df.select(scalar_int1("id").alias("res"))
self.assertEqual(10, len(result2.collect()))
# long -> int coercion, overflow
scalar_int2 = arrow_udf(lambda x: pa.compute.add(x, 2147483647), IntegerType())
result3 = df.select(scalar_int2("id").alias("res"))
with self.assertRaises(Exception):
# pyarrow.lib.ArrowInvalid:
# Integer value 2147483652 not in range: -2147483648 to 2147483647
result3.collect()
def test_unsupported_return_types(self):
import pyarrow as pa
with self.quiet():
for udf_type in [ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER]:
with self.assertRaisesRegex(
NotImplementedError,
"Invalid return type.*scalar Arrow UDF.*ArrayType.*YearMonthIntervalType",
):
arrow_udf(lambda x: x, ArrayType(YearMonthIntervalType()), udf_type)
with self.assertRaisesRegex(
NotImplementedError,
"Invalid return type.*scalar Arrow UDF.*ArrayType.*YearMonthIntervalType",
):
@arrow_udf(ArrayType(YearMonthIntervalType()))
def func_a(a: pa.Array) -> pa.Array:
return a
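# Concrete test class; pins the session time zone to America/Los_Angeles so timestamp
# round-trips are deterministic.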
class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
ReusedSQLTestCase.setUpClass()
# Synchronize default timezone between Python and Java
cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
tz = "America/Los_Angeles"
os.environ["TZ"] = tz
time.tzset()
cls.sc.environment["TZ"] = tz
cls.spark.conf.set("spark.sql.session.timeZone", tz)
@classmethod
def tearDownClass(cls):
del os.environ["TZ"]
if cls.tz_prev is not None:
os.environ["TZ"] = cls.tz_prev
time.tzset()
ReusedSQLTestCase.tearDownClass()
if __name__ == "__main__":
from pyspark.sql.tests.arrow.test_arrow_udf_scalar import * # noqa: F401
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)