| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| import os |
| import shutil |
| import tempfile |
| import unittest |
| from dataclasses import dataclass |
| from typing import Iterator, Optional |
| |
| from pyspark.errors import ( |
| PySparkAttributeError, |
| PythonException, |
| PySparkTypeError, |
| AnalysisException, |
| PySparkPicklingError, |
| ) |
| from pyspark.util import PythonEvalType |
| from pyspark.sql.functions import ( |
| array, |
| create_map, |
| array, |
| lit, |
| named_struct, |
| udf, |
| udtf, |
| AnalyzeArgument, |
| AnalyzeResult, |
| OrderingColumn, |
| PartitioningColumn, |
| SkipRestOfInputTableException, |
| ) |
| from pyspark.sql.types import ( |
| ArrayType, |
| BooleanType, |
| DataType, |
| IntegerType, |
| LongType, |
| MapType, |
| NullType, |
| Row, |
| StringType, |
| StructField, |
| StructType, |
| ) |
| from pyspark.testing import assertDataFrameEqual, assertSchemaEqual |
| from pyspark.testing.sqlutils import ( |
| have_pandas, |
| have_pyarrow, |
| pandas_requirement_message, |
| pyarrow_requirement_message, |
| ReusedSQLTestCase, |
| ) |
| |
| |
| class BaseUDTFTestsMixin: |
| def test_simple_udtf(self): |
| class TestUDTF: |
| def eval(self): |
| yield "hello", "world" |
| |
| func = udtf(TestUDTF, returnType="c1: string, c2: string") |
| rows = func().collect() |
| self.assertEqual(rows, [Row(c1="hello", c2="world")]) |
| |
| def test_udtf_yield_single_row_col(self): |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, |
| |
| func = udtf(TestUDTF, returnType="a: int") |
| rows = func(lit(1)).collect() |
| self.assertEqual(rows, [Row(a=1)]) |
| |
| def test_udtf_yield_multi_cols(self): |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, a + 1 |
| |
| func = udtf(TestUDTF, returnType="a: int, b: int") |
| rows = func(lit(1)).collect() |
| self.assertEqual(rows, [Row(a=1, b=2)]) |
| |
| def test_udtf_yield_multi_rows(self): |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, |
| yield a + 1, |
| |
| func = udtf(TestUDTF, returnType="a: int") |
| rows = func(lit(1)).collect() |
| self.assertEqual(rows, [Row(a=1), Row(a=2)]) |
| |
| def test_udtf_yield_multi_row_col(self): |
| class TestUDTF: |
| def eval(self, a: int, b: int): |
| yield a, b, a + b |
| yield a, b, a - b |
| yield a, b, b - a |
| |
| func = udtf(TestUDTF, returnType="a: int, b: int, c: int") |
| rows = func(lit(1), lit(2)).collect() |
| self.assertEqual(rows, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, b=2, c=1)]) |
| |
| def test_udtf_decorator(self): |
| @udtf(returnType="a: int, b: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, a + 1 |
| |
| rows = TestUDTF(lit(1)).collect() |
| self.assertEqual(rows, [Row(a=1, b=2)]) |
| |
| def test_udtf_registration(self): |
| class TestUDTF: |
| def eval(self, a: int, b: int): |
| yield a, b, a + b |
| yield a, b, a - b |
| yield a, b, b - a |
| |
| func = udtf(TestUDTF, returnType="a: int, b: int, c: int") |
| self.spark.udtf.register("testUDTF", func) |
| df = self.spark.sql("SELECT * FROM testUDTF(1, 2)") |
| self.assertEqual( |
| df.collect(), [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, b=2, c=1)] |
| ) |
| |
| def test_udtf_with_lateral_join(self): |
| class TestUDTF: |
| def eval(self, a: int, b: int) -> Iterator: |
| yield a, b, a + b |
| yield a, b, a - b |
| |
| func = udtf(TestUDTF, returnType="a: int, b: int, c: int") |
| self.spark.udtf.register("testUDTF", func) |
| df = self.spark.sql( |
| "SELECT f.* FROM values (0, 1), (1, 2) t(a, b), LATERAL testUDTF(a, b) f" |
| ) |
| expected = self.spark.createDataFrame( |
| [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=["a", "b", "c"] |
| ) |
| self.assertEqual(df.collect(), expected.collect()) |
| |
| def test_udtf_eval_with_return_stmt(self): |
| class TestUDTF: |
| def eval(self, a: int, b: int): |
| return [(a, a + 1), (b, b + 1)] |
| |
| func = udtf(TestUDTF, returnType="a: int, b: int") |
| rows = func(lit(1), lit(2)).collect() |
| self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)]) |
| |
| def test_udtf_eval_returning_non_tuple(self): |
| @udtf(returnType="a: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a |
| |
| with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): |
| TestUDTF(lit(1)).collect() |
| |
| @udtf(returnType="a: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| return (a,) |
| |
| with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): |
| TestUDTF(lit(1)).collect() |
| |
| def test_udtf_with_invalid_return_value(self): |
| @udtf(returnType="x: int") |
| class TestUDTF: |
| def eval(self, a): |
| return a |
| |
| with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): |
| TestUDTF(lit(1)).collect() |
| |
| def test_udtf_with_zero_arg_and_invalid_return_value(self): |
| @udtf(returnType="x: int") |
| class TestUDTF: |
| def eval(self): |
| return 1 |
| |
| with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): |
| TestUDTF().collect() |
| |
| def test_udtf_with_invalid_return_value_in_terminate(self): |
| @udtf(returnType="x: int") |
| class TestUDTF: |
| def eval(self, a): |
| ... |
| |
| def terminate(self): |
| return 1 |
| |
| with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): |
| TestUDTF(lit(1)).collect() |
| |
| def test_udtf_eval_with_no_return(self): |
| @udtf(returnType="a: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| ... |
| |
| self.assertEqual(TestUDTF(lit(1)).collect(), []) |
| |
| @udtf(returnType="a: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| return |
| |
| self.assertEqual(TestUDTF(lit(1)).collect(), []) |
| |
| def test_udtf_with_conditional_return(self): |
| class TestUDTF: |
| def eval(self, a: int): |
| if a > 5: |
| yield a, |
| |
| func = udtf(TestUDTF, returnType="a: int") |
| self.spark.udtf.register("test_udtf", func) |
| self.assertEqual( |
| self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL test_udtf(id)").collect(), |
| [Row(id=6, a=6), Row(id=7, a=7)], |
| ) |
| |
| def test_udtf_with_empty_yield(self): |
| @udtf(returnType="a: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield |
| |
| assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=None)]) |
| |
| def test_udtf_with_none_output(self): |
| @udtf(returnType="a: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, |
| yield None, |
| |
| self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=1), Row(a=None)]) |
| df = self.spark.createDataFrame([(0, 1), (1, 2)], schema=["a", "b"]) |
| self.assertEqual(TestUDTF(lit(1)).join(df, "a", "inner").collect(), [Row(a=1, b=2)]) |
| assertDataFrameEqual( |
| TestUDTF(lit(1)).join(df, "a", "left"), [Row(a=None, b=None), Row(a=1, b=2)] |
| ) |
| |
| def test_udtf_with_none_input(self): |
| @udtf(returnType="a: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, |
| |
| self.assertEqual(TestUDTF(lit(None)).collect(), [Row(a=None)]) |
| self.spark.udtf.register("testUDTF", TestUDTF) |
| df = self.spark.sql("SELECT * FROM testUDTF(null)") |
| self.assertEqual(df.collect(), [Row(a=None)]) |
| |
    # Expected error-message substrings shared by the argument-mismatch test
    # cases below (and available to subclasses of this mixin).
    tooManyPositionalArguments = "too many positional arguments"
    missingARequiredArgument = "missing a required argument"
    multipleValuesForArgument = "multiple values for argument"
| |
| def test_udtf_with_wrong_num_input(self): |
| @udtf(returnType="a: int, b: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, a + 1 |
| |
| with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.missingARequiredArgument): |
| TestUDTF().collect() |
| |
| with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.tooManyPositionalArguments): |
| TestUDTF(lit(1), lit(2)).collect() |
| |
| def test_udtf_init_with_additional_args(self): |
| @udtf(returnType="x int") |
| class TestUDTF: |
| def __init__(self, a: int): |
| ... |
| |
| def eval(self, a: int): |
| yield a, |
| |
| with self.assertRaisesRegex(PythonException, r".*constructor has more than one argument.*"): |
| TestUDTF(lit(1)).show() |
| |
| def test_udtf_terminate_with_additional_args(self): |
| @udtf(returnType="x int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, |
| |
| def terminate(self, a: int): |
| ... |
| |
| with self.assertRaisesRegex( |
| PythonException, r"terminate\(\) missing 1 required positional argument: 'a'" |
| ): |
| TestUDTF(lit(1)).show() |
| |
| def test_udtf_with_wrong_num_output(self): |
| err_msg = ( |
| r"\[UDTF_RETURN_SCHEMA_MISMATCH\] The number of columns in the " |
| "result does not match the specified schema." |
| ) |
| |
| # Output less columns than specified return schema |
| @udtf(returnType="a: int, b: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, |
| |
| with self.assertRaisesRegex(PythonException, err_msg): |
| TestUDTF(lit(1)).collect() |
| |
| # Output more columns than specified return schema |
| @udtf(returnType="a: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, a + 1 |
| |
| with self.assertRaisesRegex(PythonException, err_msg): |
| TestUDTF(lit(1)).collect() |
| |
| def test_udtf_with_empty_output_schema_and_non_empty_output(self): |
| @udtf(returnType=StructType()) |
| class TestUDTF: |
| def eval(self): |
| yield 1, |
| |
| with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): |
| TestUDTF().collect() |
| |
| def test_udtf_with_non_empty_output_schema_and_empty_output(self): |
| @udtf(returnType="a: int") |
| class TestUDTF: |
| def eval(self): |
| yield tuple() |
| |
| with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): |
| TestUDTF().collect() |
| |
| def test_udtf_init(self): |
| @udtf(returnType="a: int, b: int, c: string") |
| class TestUDTF: |
| def __init__(self): |
| self.key = "test" |
| |
| def eval(self, a: int): |
| yield a, a + 1, self.key |
| |
| rows = TestUDTF(lit(1)).collect() |
| self.assertEqual(rows, [Row(a=1, b=2, c="test")]) |
| |
    def test_udtf_terminate(self):
        """terminate() runs once per partition after all eval() calls and may
        emit additional rows (here, per-partition aggregates)."""

        @udtf(returnType="key: string, value: float")
        class TestUDTF:
            def __init__(self):
                # Running aggregates accumulated across eval() calls.
                self._count = 0
                self._sum = 0

            def eval(self, x: int):
                self._count += 1
                self._sum += x
                yield "input", float(x)

            def terminate(self):
                yield "count", float(self._count)
                yield "avg", self._sum / self._count

        self.assertEqual(
            TestUDTF(lit(1)).collect(),
            [Row(key="input", value=1), Row(key="count", value=1.0), Row(key="avg", value=1.0)],
        )

        self.spark.udtf.register("test_udtf", TestUDTF)
        # range(0, 10, 1, 2) creates two partitions, so terminate() fires twice
        # (count=5 each). Per the expected rows below, the terminate rows carry
        # the lateral `id` of each partition's final row (4 and 9).
        df = self.spark.sql(
            "SELECT id, key, value FROM range(0, 10, 1, 2), "
            "LATERAL test_udtf(id) WHERE key != 'input'"
        )
        self.assertEqual(
            df.collect(),
            [
                Row(id=4, key="count", value=5.0),
                Row(id=4, key="avg", value=2.0),
                Row(id=9, key="count", value=5.0),
                Row(id=9, key="avg", value=7.0),
            ],
        )
| |
| def test_udtf_cleanup_with_exception_in_eval(self): |
| with tempfile.TemporaryDirectory(prefix="test_udtf_cleanup_with_exception_in_eval") as d: |
| path = os.path.join(d, "file.txt") |
| |
| @udtf(returnType="x: int") |
| class TestUDTF: |
| def __init__(self): |
| self.path = path |
| |
| def eval(self, x: int): |
| raise Exception("eval error") |
| |
| def terminate(self): |
| with open(self.path, "a") as f: |
| f.write("terminate") |
| |
| def cleanup(self): |
| with open(self.path, "a") as f: |
| f.write("cleanup") |
| |
| with self.assertRaisesRegex(PythonException, "eval error"): |
| TestUDTF(lit(1)).show() |
| |
| with open(path, "r") as f: |
| data = f.read() |
| |
| # Only cleanup method should be called. |
| self.assertEqual(data, "cleanup") |
| |
| def test_udtf_cleanup_with_exception_in_terminate(self): |
| with tempfile.TemporaryDirectory( |
| prefix="test_udtf_cleanup_with_exception_in_terminate" |
| ) as d: |
| path = os.path.join(d, "file.txt") |
| |
| @udtf(returnType="x: int") |
| class TestUDTF: |
| def __init__(self): |
| self.path = path |
| |
| def eval(self, x: int): |
| yield (x,) |
| |
| def terminate(self): |
| raise Exception("terminate error") |
| |
| def cleanup(self): |
| with open(self.path, "a") as f: |
| f.write("cleanup") |
| |
| with self.assertRaisesRegex(PythonException, "terminate error"): |
| TestUDTF(lit(1)).show() |
| |
| with open(path, "r") as f: |
| data = f.read() |
| |
| self.assertEqual(data, "cleanup") |
| |
| def test_init_with_exception(self): |
| @udtf(returnType="x: int") |
| class TestUDTF: |
| def __init__(self): |
| raise Exception("error") |
| |
| def eval(self): |
| yield 1, |
| |
| with self.assertRaisesRegex( |
| PythonException, |
| r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " |
| r"in the '__init__' method: error", |
| ): |
| TestUDTF().show() |
| |
| def test_eval_with_exception(self): |
| @udtf(returnType="x: int") |
| class TestUDTF: |
| def eval(self): |
| raise Exception("error") |
| |
| with self.assertRaisesRegex( |
| PythonException, |
| r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " |
| r"in the 'eval' method: error", |
| ): |
| TestUDTF().show() |
| |
| def test_terminate_with_exceptions(self): |
| @udtf(returnType="a: int, b: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, a + 1 |
| |
| def terminate(self): |
| raise ValueError("terminate error") |
| |
| with self.assertRaisesRegex( |
| PythonException, |
| r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " |
| r"in the 'terminate' method: terminate error", |
| ): |
| TestUDTF(lit(1)).collect() |
| |
| def test_udtf_terminate_with_wrong_num_output(self): |
| err_msg = ( |
| r"\[UDTF_RETURN_SCHEMA_MISMATCH\] The number of columns in the result " |
| "does not match the specified schema." |
| ) |
| |
| @udtf(returnType="a: int, b: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, a + 1 |
| |
| def terminate(self): |
| yield 1, 2, 3 |
| |
| with self.assertRaisesRegex(PythonException, err_msg): |
| TestUDTF(lit(1)).show() |
| |
| @udtf(returnType="a: int, b: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, a + 1 |
| |
| def terminate(self): |
| yield 1, |
| |
| with self.assertRaisesRegex(PythonException, err_msg): |
| TestUDTF(lit(1)).show() |
| |
| def test_udtf_determinism(self): |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a, |
| |
| func = udtf(TestUDTF, returnType="x: int") |
| # The UDTF is marked as non-deterministic by default. |
| self.assertFalse(func.deterministic) |
| func = func.asDeterministic() |
| self.assertTrue(func.deterministic) |
| |
| def test_nondeterministic_udtf(self): |
| import random |
| |
| class RandomUDTF: |
| def eval(self, a: int): |
| yield a + int(random.random()), |
| |
| random_udtf = udtf(RandomUDTF, returnType="x: int") |
| assertDataFrameEqual(random_udtf(lit(1)), [Row(x=1)]) |
| self.spark.udtf.register("random_udtf", random_udtf) |
| assertDataFrameEqual(self.spark.sql("select * from random_udtf(1)"), [Row(x=1)]) |
| |
| def test_udtf_with_nondeterministic_input(self): |
| from pyspark.sql.functions import rand |
| |
| @udtf(returnType="x: int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield 1 if a > 100 else 0, |
| |
| assertDataFrameEqual(TestUDTF(rand(0) * 100), [Row(x=0)]) |
| |
| def test_udtf_with_invalid_return_type(self): |
| @udtf(returnType="int") |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a + 1, |
| |
| with self.assertRaises(PySparkTypeError) as e: |
| TestUDTF(lit(1)).collect() |
| |
| self.check_error( |
| exception=e.exception, |
| error_class="UDTF_RETURN_TYPE_MISMATCH", |
| message_parameters={"name": "TestUDTF", "return_type": "IntegerType()"}, |
| ) |
| |
| @udtf(returnType=MapType(StringType(), IntegerType())) |
| class TestUDTF: |
| def eval(self, a: int): |
| yield a + 1, |
| |
| with self.assertRaises(PySparkTypeError) as e: |
| TestUDTF(lit(1)).collect() |
| |
| self.check_error( |
| exception=e.exception, |
| error_class="UDTF_RETURN_TYPE_MISMATCH", |
| message_parameters={ |
| "name": "TestUDTF", |
| "return_type": "MapType(StringType(), IntegerType(), True)", |
| }, |
| ) |
| |
| def test_udtf_with_struct_input_type(self): |
| @udtf(returnType="x: string") |
| class TestUDTF: |
| def eval(self, person): |
| yield f"{person.name}: {person.age}", |
| |
| self.spark.udtf.register("test_udtf", TestUDTF) |
| self.assertEqual( |
| self.spark.sql( |
| "select * from test_udtf(named_struct('name', 'Alice', 'age', 1))" |
| ).collect(), |
| [Row(x="Alice: 1")], |
| ) |
| |
| def test_udtf_with_array_input_type(self): |
| @udtf(returnType="x: string") |
| class TestUDTF: |
| def eval(self, args): |
| yield str(args), |
| |
| self.spark.udtf.register("test_udtf", TestUDTF) |
| self.assertEqual( |
| self.spark.sql("select * from test_udtf(array(1, 2, 3))").collect(), |
| [Row(x="[1, 2, 3]")], |
| ) |
| |
| def test_udtf_with_map_input_type(self): |
| @udtf(returnType="x: string") |
| class TestUDTF: |
| def eval(self, m): |
| yield str(m), |
| |
| self.spark.udtf.register("test_udtf", TestUDTF) |
| self.assertEqual( |
| self.spark.sql("select * from test_udtf(map('key', 'value'))").collect(), |
| [Row(x="{'key': 'value'}")], |
| ) |
| |
| def test_udtf_with_struct_output_types(self): |
| @udtf(returnType="x: struct<a:int,b:int>") |
| class TestUDTF: |
| def eval(self, x: int): |
| yield {"a": x, "b": x + 1}, |
| |
| self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x=Row(a=1, b=2))]) |
| |
| def test_udtf_with_array_output_types(self): |
| @udtf(returnType="x: array<int>") |
| class TestUDTF: |
| def eval(self, x: int): |
| yield [x, x + 1, x + 2], |
| |
| self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x=[1, 2, 3])]) |
| |
| def test_udtf_with_map_output_types(self): |
| @udtf(returnType="x: map<int,string>") |
| class TestUDTF: |
| def eval(self, x: int): |
| yield {x: str(x)}, |
| |
| self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x={1: "1"})]) |
| |
| def test_udtf_with_empty_output_types(self): |
| @udtf(returnType=StructType()) |
| class TestUDTF: |
| def eval(self): |
| yield tuple() |
| |
| assertDataFrameEqual(TestUDTF(), [Row()]) |
| |
| def _check_result_or_exception( |
| self, func_handler, ret_type, expected, *, err_type=PythonException |
| ): |
| func = udtf(func_handler, returnType=ret_type) |
| if not isinstance(expected, str): |
| assertDataFrameEqual(func(), expected) |
| else: |
| with self.assertRaisesRegex(err_type, expected): |
| func().collect() |
| |
| def test_numeric_output_type_casting(self): |
| class TestUDTF: |
| def eval(self): |
| yield 1, |
| |
| for i, (ret_type, expected) in enumerate( |
| [ |
| ("x: boolean", [Row(x=None)]), |
| ("x: tinyint", [Row(x=1)]), |
| ("x: smallint", [Row(x=1)]), |
| ("x: int", [Row(x=1)]), |
| ("x: bigint", [Row(x=1)]), |
| ("x: string", [Row(x="1")]), # int to string is ok, but string to int is None |
| ( |
| "x: date", |
| "AttributeError", |
| ), # AttributeError: 'int' object has no attribute 'toordinal' |
| ( |
| "x: timestamp", |
| "AttributeError", |
| ), # AttributeError: 'int' object has no attribute 'tzinfo' |
| ("x: byte", [Row(x=1)]), |
| ("x: binary", [Row(x=None)]), |
| ("x: float", [Row(x=None)]), |
| ("x: double", [Row(x=None)]), |
| ("x: decimal(10, 0)", [Row(x=None)]), |
| ("x: array<int>", [Row(x=None)]), |
| ("x: map<string,int>", [Row(x=None)]), |
| ("x: struct<a:int>", "UNEXPECTED_TUPLE_WITH_STRUCT"), |
| ] |
| ): |
| with self.subTest(ret_type=ret_type): |
| self._check_result_or_exception(TestUDTF, ret_type, expected) |
| |
    def test_numeric_string_output_type_casting(self):
        """Verify casting of the numeric string "1" against each declared
        return type; strings are not parsed into numbers (those become NULL),
        while string and binary targets succeed."""

        class TestUDTF:
            def eval(self):
                yield "1",

        for ret_type, expected in [
            ("x: boolean", [Row(x=None)]),
            ("x: tinyint", [Row(x=None)]),
            ("x: smallint", [Row(x=None)]),
            ("x: int", [Row(x=None)]),
            ("x: bigint", [Row(x=None)]),
            ("x: string", [Row(x="1")]),
            ("x: date", "AttributeError"),
            ("x: timestamp", "AttributeError"),
            ("x: byte", [Row(x=None)]),
            ("x: binary", [Row(x=bytearray(b"1"))]),
            ("x: float", [Row(x=None)]),
            ("x: double", [Row(x=None)]),
            ("x: decimal(10, 0)", [Row(x=None)]),
            ("x: array<int>", [Row(x=None)]),
            ("x: map<string,int>", [Row(x=None)]),
            ("x: struct<a:int>", "UNEXPECTED_TUPLE_WITH_STRUCT"),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)
| |
    def test_string_output_type_casting(self):
        """Verify casting of a non-numeric string ("hello"); only string and
        binary targets succeed, other types become NULL or raise."""

        class TestUDTF:
            def eval(self):
                yield "hello",

        for ret_type, expected in [
            ("x: boolean", [Row(x=None)]),
            ("x: tinyint", [Row(x=None)]),
            ("x: smallint", [Row(x=None)]),
            ("x: int", [Row(x=None)]),
            ("x: bigint", [Row(x=None)]),
            ("x: string", [Row(x="hello")]),
            ("x: date", "AttributeError"),
            ("x: timestamp", "AttributeError"),
            ("x: byte", [Row(x=None)]),
            ("x: binary", [Row(x=bytearray(b"hello"))]),
            ("x: float", [Row(x=None)]),
            ("x: double", [Row(x=None)]),
            ("x: decimal(10, 0)", [Row(x=None)]),
            ("x: array<int>", [Row(x=None)]),
            ("x: map<string,int>", [Row(x=None)]),
            ("x: struct<a:int>", "UNEXPECTED_TUPLE_WITH_STRUCT"),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)
| |
    def test_array_output_type_casting(self):
        """Verify casting of a mixed-type list [0, 1.1, 2]; array targets cast
        element-wise (mismatched elements become NULL), struct targets map the
        list positionally onto fields, scalar targets become NULL."""

        class TestUDTF:
            def eval(self):
                yield [0, 1.1, 2],

        for ret_type, expected in [
            ("x: boolean", [Row(x=None)]),
            ("x: tinyint", [Row(x=None)]),
            ("x: smallint", [Row(x=None)]),
            ("x: int", [Row(x=None)]),
            ("x: bigint", [Row(x=None)]),
            ("x: string", [Row(x="[0, 1.1, 2]")]),
            ("x: date", "AttributeError"),
            ("x: timestamp", "AttributeError"),
            ("x: byte", [Row(x=None)]),
            ("x: binary", [Row(x=None)]),
            ("x: float", [Row(x=None)]),
            ("x: double", [Row(x=None)]),
            ("x: decimal(10, 0)", [Row(x=None)]),
            ("x: array<int>", [Row(x=[0, None, 2])]),
            ("x: array<double>", [Row(x=[None, 1.1, None])]),
            ("x: array<string>", [Row(x=["0", "1.1", "2"])]),
            ("x: array<boolean>", [Row(x=[None, None, None])]),
            ("x: array<array<int>>", [Row(x=[None, None, None])]),
            ("x: map<string,int>", [Row(x=None)]),
            ("x: struct<a:int,b:int,c:int>", [Row(x=Row(a=0, b=None, c=2))]),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)
| |
    def test_map_output_type_casting(self):
        """Verify casting of a dict value; map targets cast values per-entry
        (mismatches become NULL), struct targets pick fields by key, scalar
        targets become NULL."""

        class TestUDTF:
            def eval(self):
                yield {"a": 0, "b": 1.1, "c": 2},

        for ret_type, expected in [
            ("x: boolean", [Row(x=None)]),
            ("x: tinyint", [Row(x=None)]),
            ("x: smallint", [Row(x=None)]),
            ("x: int", [Row(x=None)]),
            ("x: bigint", [Row(x=None)]),
            ("x: string", [Row(x="{a=0, b=1.1, c=2}")]),
            ("x: date", "AttributeError"),
            ("x: timestamp", "AttributeError"),
            ("x: byte", [Row(x=None)]),
            ("x: binary", [Row(x=None)]),
            ("x: float", [Row(x=None)]),
            ("x: double", [Row(x=None)]),
            ("x: decimal(10, 0)", [Row(x=None)]),
            ("x: array<string>", [Row(x=None)]),
            ("x: map<string,string>", [Row(x={"a": "0", "b": "1.1", "c": "2"})]),
            ("x: map<string,boolean>", [Row(x={"a": None, "b": None, "c": None})]),
            ("x: map<string,int>", [Row(x={"a": 0, "b": None, "c": 2})]),
            ("x: map<string,float>", [Row(x={"a": None, "b": 1.1, "c": None})]),
            ("x: map<string,map<string,int>>", [Row(x={"a": None, "b": None, "c": None})]),
            ("x: struct<a:int>", [Row(x=Row(a=0))]),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)
| |
    def test_struct_output_type_casting_dict(self):
        """Verify casting of a dict into struct return types: fields are
        matched by key and cast individually (mismatches become NULL)."""

        class TestUDTF:
            def eval(self):
                yield {"a": 0, "b": 1.1, "c": 2},

        for ret_type, expected in [
            ("x: boolean", [Row(x=None)]),
            ("x: tinyint", [Row(x=None)]),
            ("x: smallint", [Row(x=None)]),
            ("x: int", [Row(x=None)]),
            ("x: bigint", [Row(x=None)]),
            ("x: string", [Row(x="{a=0, b=1.1, c=2}")]),
            ("x: date", "AttributeError"),
            ("x: timestamp", "AttributeError"),
            ("x: byte", [Row(x=None)]),
            ("x: binary", [Row(x=None)]),
            ("x: float", [Row(x=None)]),
            ("x: double", [Row(x=None)]),
            ("x: decimal(10, 0)", [Row(x=None)]),
            ("x: array<string>", [Row(x=None)]),
            ("x: map<string,string>", [Row(x={"a": "0", "b": "1.1", "c": "2"})]),
            ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", c="2"))]),
            ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=None, c=2))]),
            ("x: struct<a:float,b:float,c:float>", [Row(Row(a=None, b=1.1, c=None))]),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)
| |
| def test_struct_output_type_casting_row(self): |
| from py4j.protocol import Py4JJavaError |
| |
| self.check_struct_output_type_casting_row(Py4JJavaError) |
| |
    def check_struct_output_type_casting_row(self, error_type):
        """Shared body for Row-output casting checks.

        `error_type` is the engine-specific exception class expected when a
        Row is cast to a non-struct target type (classic vs connect differ).
        """

        class TestUDTF:
            def eval(self):
                yield Row(a=0, b=1.1, c=2),

        # Non-struct targets cannot serialize a Row value.
        err = ("PickleException", error_type)

        for ret_type, expected in [
            ("x: boolean", err),
            ("x: tinyint", err),
            ("x: smallint", err),
            ("x: int", err),
            ("x: bigint", err),
            ("x: string", err),
            ("x: date", "ValueError"),
            ("x: timestamp", "ValueError"),
            ("x: byte", err),
            ("x: binary", err),
            ("x: float", err),
            ("x: double", err),
            ("x: decimal(10, 0)", err),
            ("x: array<string>", err),
            ("x: map<string,string>", err),
            ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", c="2"))]),
            ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=None, c=2))]),
            ("x: struct<a:float,b:float,c:float>", [Row(Row(a=None, b=1.1, c=None))]),
        ]:
            with self.subTest(ret_type=ret_type):
                if isinstance(expected, tuple):
                    self._check_result_or_exception(
                        TestUDTF, ret_type, expected[0], err_type=expected[1]
                    )
                else:
                    self._check_result_or_exception(TestUDTF, ret_type, expected)
| |
| def test_inconsistent_output_types(self): |
| class TestUDTF: |
| def eval(self): |
| yield 1, |
| yield [1, 2], |
| |
| for ret_type, expected in [ |
| ("x: int", [Row(x=1), Row(x=None)]), |
| ("x: array<int>", [Row(x=None), Row(x=[1, 2])]), |
| ]: |
| with self.subTest(ret_type=ret_type): |
| assertDataFrameEqual(udtf(TestUDTF, returnType=ret_type)(), expected) |
| |
    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_udtf_with_pandas_input_type(self):
        """A regular (non-arrow) UDTF receives plain Python scalars, so a
        handler annotated with pandas Series fails when calling Series
        methods on an int."""
        import pandas as pd

        @udtf(returnType="corr: double")
        class TestUDTF:
            def eval(self, s1: pd.Series, s2: pd.Series):
                yield s1.corr(s2)

        self.spark.udtf.register("test_udtf", TestUDTF)
        # TODO(SPARK-43968): check during compile time instead of runtime
        with self.assertRaisesRegex(
            PythonException, "AttributeError: 'int' object has no attribute 'corr'"
        ):
            self.spark.sql(
                "select * from values (1, 2), (2, 3) t(a, b), lateral test_udtf(a, b)"
            ).collect()
| |
| def test_udtf_register_error(self): |
| @udf |
| def upper(s: str): |
| return s.upper() |
| |
| with self.assertRaises(PySparkTypeError) as e: |
| self.spark.udtf.register("test_udf", upper) |
| |
| self.check_error( |
| exception=e.exception, |
| error_class="INVALID_UDTF_EVAL_TYPE", |
| message_parameters={ |
| "name": "test_udf", |
| "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF", |
| }, |
| ) |
| |
| def test_udtf_pickle_error(self): |
| with tempfile.TemporaryDirectory(prefix="test_udtf_pickle_error") as d: |
| file = os.path.join(d, "file.txt") |
| file_obj = open(file, "w") |
| |
| @udtf(returnType="x: int") |
| class TestUDTF: |
| def eval(self): |
| file_obj |
| yield 1, |
| |
| with self.assertRaisesRegex(PySparkPicklingError, "UDTF_SERIALIZATION_ERROR"): |
| TestUDTF().collect() |
| |
| def test_udtf_access_spark_session(self): |
| df = self.spark.range(10) |
| |
| @udtf(returnType="x: int") |
| class TestUDTF: |
| def eval(self): |
| df.collect() |
| yield 1, |
| |
| with self.assertRaisesRegex(PySparkPicklingError, "UDTF_SERIALIZATION_ERROR"): |
| TestUDTF().collect() |
| |
| def test_udtf_no_eval(self): |
| with self.assertRaises(PySparkAttributeError) as e: |
| |
| @udtf(returnType="a: int, b: int") |
| class TestUDTF: |
| def run(self, a: int): |
| yield a, a + 1 |
| |
| self.check_error( |
| exception=e.exception, |
| error_class="INVALID_UDTF_NO_EVAL", |
| message_parameters={"name": "TestUDTF"}, |
| ) |
| |
| def test_udtf_with_no_handler_class(self): |
| with self.assertRaises(PySparkTypeError) as e: |
| |
| @udtf(returnType="a: int") |
| def test_udtf(a: int): |
| yield a, |
| |
| self.check_error( |
| exception=e.exception, |
| error_class="INVALID_UDTF_HANDLER_TYPE", |
| message_parameters={"type": "function"}, |
| ) |
| |
| with self.assertRaises(PySparkTypeError) as e: |
| udtf(1, returnType="a: int") |
| |
| self.check_error( |
| exception=e.exception, |
| error_class="INVALID_UDTF_HANDLER_TYPE", |
| message_parameters={"type": "int"}, |
| ) |
| |
| def test_udtf_with_table_argument_query(self): |
| class TestUDTF: |
| def eval(self, row: Row): |
| if row["id"] > 5: |
| yield row["id"], |
| |
| func = udtf(TestUDTF, returnType="a: int") |
| self.spark.udtf.register("test_udtf", func) |
| self.assertEqual( |
| self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 8)))").collect(), |
| [Row(a=6), Row(a=7)], |
| ) |
| |
| def test_udtf_with_int_and_table_argument_query(self): |
| class TestUDTF: |
| def eval(self, i: int, row: Row): |
| if row["id"] > i: |
| yield row["id"], |
| |
| func = udtf(TestUDTF, returnType="a: int") |
| self.spark.udtf.register("test_udtf", func) |
| self.assertEqual( |
| self.spark.sql( |
| "SELECT * FROM test_udtf(5, TABLE (SELECT id FROM range(0, 8)))" |
| ).collect(), |
| [Row(a=6), Row(a=7)], |
| ) |
| |
| def test_udtf_with_table_argument_identifier(self): |
| class TestUDTF: |
| def eval(self, row: Row): |
| if row["id"] > 5: |
| yield row["id"], |
| |
| func = udtf(TestUDTF, returnType="a: int") |
| self.spark.udtf.register("test_udtf", func) |
| |
| with self.tempView("v"): |
| self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id FROM range(0, 8)") |
| self.assertEqual( |
| self.spark.sql("SELECT * FROM test_udtf(TABLE (v))").collect(), |
| [Row(a=6), Row(a=7)], |
| ) |
| |
| def test_udtf_with_int_and_table_argument_identifier(self): |
| class TestUDTF: |
| def eval(self, i: int, row: Row): |
| if row["id"] > i: |
| yield row["id"], |
| |
| func = udtf(TestUDTF, returnType="a: int") |
| self.spark.udtf.register("test_udtf", func) |
| |
| with self.tempView("v"): |
| self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id FROM range(0, 8)") |
| self.assertEqual( |
| self.spark.sql("SELECT * FROM test_udtf(5, TABLE (v))").collect(), |
| [Row(a=6), Row(a=7)], |
| ) |
| |
| def test_udtf_with_table_argument_unknown_identifier(self): |
| class TestUDTF: |
| def eval(self, row: Row): |
| if row["id"] > 5: |
| yield row["id"], |
| |
| func = udtf(TestUDTF, returnType="a: int") |
| self.spark.udtf.register("test_udtf", func) |
| |
| with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"): |
| self.spark.sql("SELECT * FROM test_udtf(TABLE (v))").collect() |
| |
| def test_udtf_with_table_argument_malformed_query(self): |
| class TestUDTF: |
| def eval(self, row: Row): |
| if row["id"] > 5: |
| yield row["id"], |
| |
| func = udtf(TestUDTF, returnType="a: int") |
| self.spark.udtf.register("test_udtf", func) |
| |
| with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"): |
| self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT * FROM v))").collect() |
| |
    def test_udtf_with_table_argument_cte_inside(self):
        """A TABLE argument may wrap a query that defines its own CTE."""

        class TestUDTF:
            def eval(self, row: Row):
                # range(0, 8) filtered to ids greater than 5 leaves 6 and 7.
                if row["id"] > 5:
                    yield row["id"],

        func = udtf(TestUDTF, returnType="a: int")
        self.spark.udtf.register("test_udtf", func)
        self.assertEqual(
            self.spark.sql(
                """
                SELECT * FROM test_udtf(TABLE (
                  WITH t AS (
                    SELECT id FROM range(0, 8)
                  )
                  SELECT * FROM t
                ))
                """
            ).collect(),
            [Row(a=6), Row(a=7)],
        )
| |
    def test_udtf_with_table_argument_cte_outside(self):
        """TABLE arguments can reference a CTE defined outside the UDTF call."""

        class TestUDTF:
            def eval(self, row: Row):
                # range(0, 8) filtered to ids greater than 5 leaves 6 and 7.
                if row["id"] > 5:
                    yield row["id"],

        func = udtf(TestUDTF, returnType="a: int")
        self.spark.udtf.register("test_udtf", func)
        # Reference the outer CTE through a subquery in the TABLE argument.
        self.assertEqual(
            self.spark.sql(
                """
                WITH t AS (
                  SELECT id FROM range(0, 8)
                )
                SELECT * FROM test_udtf(TABLE (SELECT id FROM t))
                """
            ).collect(),
            [Row(a=6), Row(a=7)],
        )

        # Reference the outer CTE directly by identifier.
        self.assertEqual(
            self.spark.sql(
                """
                WITH t AS (
                  SELECT id FROM range(0, 8)
                )
                SELECT * FROM test_udtf(TABLE (t))
                """
            ).collect(),
            [Row(a=6), Row(a=7)],
        )
| |
    # TODO(SPARK-44233): Fix the subquery resolution.
    @unittest.skip("Fails to resolve the subquery.")
    def test_udtf_with_table_argument_lateral_join(self):
        """A TABLE argument referencing the left side of a lateral join (skipped)."""

        class TestUDTF:
            def eval(self, row: Row):
                if row["id"] > 5:
                    yield row["id"],

        func = udtf(TestUDTF, returnType="a: int")
        self.spark.udtf.register("test_udtf", func)
        self.assertEqual(
            self.spark.sql(
                """
                SELECT * FROM
                  range(0, 8) AS t,
                  LATERAL test_udtf(TABLE (t))
                """
            ).collect(),
            [Row(a=6), Row(a=7)],
        )
| |
    def test_udtf_with_table_argument_multiple(self):
        """Multiple TABLE arguments require an opt-in conf; result is a cross product."""

        class TestUDTF:
            def eval(self, a: Row, b: Row):
                yield a[0], b[0]

        func = udtf(TestUDTF, returnType="a: int, b: int")
        self.spark.udtf.register("test_udtf", func)

        query = """
          SELECT * FROM test_udtf(
            TABLE (SELECT id FROM range(0, 2)),
            TABLE (SELECT id FROM range(0, 3)))
        """

        # By default more than one TABLE argument fails analysis.
        with self.sql_conf({"spark.sql.tvf.allowMultipleTableArguments.enabled": False}):
            with self.assertRaisesRegex(
                AnalysisException, "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS"
            ):
                self.spark.sql(query).collect()

        # With the conf enabled, the 2x3 inputs produce all six combinations.
        with self.sql_conf({"spark.sql.tvf.allowMultipleTableArguments.enabled": True}):
            self.assertEqual(
                self.spark.sql(query).collect(),
                [
                    Row(a=0, b=0),
                    Row(a=1, b=0),
                    Row(a=0, b=1),
                    Row(a=1, b=1),
                    Row(a=0, b=2),
                    Row(a=1, b=2),
                ],
            )
| |
    def test_docstring(self):
        # Verify that udtf() preserves the wrapped class's docstrings, exposed
        # through the `.func` attribute. The inner docstrings below are the
        # fixture data being asserted on — do not edit them.
        class TestUDTF:
            """A UDTF for test."""

            def __init__(self):
                """Initialize the UDTF"""
                ...

            @staticmethod
            def analyze(x: AnalyzeArgument) -> AnalyzeResult:
                """Analyze the argument."""
                ...

            def eval(self, x: int):
                """Evaluate the input row."""
                yield x + 1,

            def terminate(self):
                """Terminate the UDTF."""
                ...

        cls = udtf(TestUDTF).func
        self.assertIn("A UDTF for test", cls.__doc__)
        self.assertIn("Initialize the UDTF", cls.__init__.__doc__)
        self.assertIn("Analyze the argument", cls.analyze.__doc__)
        self.assertIn("Evaluate the input row", cls.eval.__doc__)
        self.assertIn("Terminate the UDTF", cls.terminate.__doc__)
| |
| def test_simple_udtf_with_analyze(self): |
| class TestUDTF: |
| @staticmethod |
| def analyze() -> AnalyzeResult: |
| return AnalyzeResult(StructType().add("c1", StringType()).add("c2", StringType())) |
| |
| def eval(self): |
| yield "hello", "world" |
| |
| func = udtf(TestUDTF) |
| self.spark.udtf.register("test_udtf", func) |
| |
| expected = [Row(c1="hello", c2="world")] |
| assertDataFrameEqual(func(), expected) |
| assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf()"), expected) |
| |
    def test_udtf_with_analyze(self):
        """analyze() receives constant scalar arguments with value and type info."""

        class TestUDTF:
            @staticmethod
            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
                assert isinstance(a, AnalyzeArgument)
                assert isinstance(a.dataType, DataType)
                # Literal scalar arguments carry their constant value.
                assert a.value is not None
                assert a.isTable is False
                assert a.isConstantExpression is True
                return AnalyzeResult(StructType().add("a", a.dataType))

            def eval(self, a):
                yield a,

        func = udtf(TestUDTF)
        self.spark.udtf.register("test_udtf", func)

        # Each entry: (query DataFrame, expected schema, expected rows).
        for i, (df, expected_schema, expected_results) in enumerate(
            [
                (func(lit(1)), StructType().add("a", IntegerType()), [Row(a=1)]),
                # another data type
                (func(lit("x")), StructType().add("a", StringType()), [Row(a="x")]),
                # array type
                (
                    func(array(lit(1), lit(2), lit(3))),
                    StructType().add("a", ArrayType(IntegerType(), containsNull=False)),
                    [Row(a=[1, 2, 3])],
                ),
                # map type
                (
                    func(create_map(lit("x"), lit(1), lit("y"), lit(2))),
                    StructType().add(
                        "a", MapType(StringType(), IntegerType(), valueContainsNull=False)
                    ),
                    [Row(a={"x": 1, "y": 2})],
                ),
                # struct type
                (
                    func(named_struct(lit("x"), lit(1), lit("y"), lit(2))),
                    StructType().add(
                        "a",
                        StructType()
                        .add("x", IntegerType(), nullable=False)
                        .add("y", IntegerType(), nullable=False),
                    ),
                    [Row(a=Row(x=1, y=2))],
                ),
                # use SQL
                (
                    self.spark.sql("SELECT * from test_udtf(1)"),
                    StructType().add("a", IntegerType()),
                    [Row(a=1)],
                ),
            ]
        ):
            with self.subTest(query_no=i):
                assertSchemaEqual(df.schema, expected_schema)
                assertDataFrameEqual(df, expected_results)
| |
| def test_udtf_with_analyze_decorator(self): |
| @udtf |
| class TestUDTF: |
| @staticmethod |
| def analyze() -> AnalyzeResult: |
| return AnalyzeResult(StructType().add("c1", StringType()).add("c2", StringType())) |
| |
| def eval(self): |
| yield "hello", "world" |
| |
| self.spark.udtf.register("test_udtf", TestUDTF) |
| |
| expected = [Row(c1="hello", c2="world")] |
| assertDataFrameEqual(TestUDTF(), expected) |
| assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf()"), expected) |
| |
| def test_udtf_with_analyze_decorator_parens(self): |
| @udtf() |
| class TestUDTF: |
| @staticmethod |
| def analyze() -> AnalyzeResult: |
| return AnalyzeResult(StructType().add("c1", StringType()).add("c2", StringType())) |
| |
| def eval(self): |
| yield "hello", "world" |
| |
| self.spark.udtf.register("test_udtf", TestUDTF) |
| |
| expected = [Row(c1="hello", c2="world")] |
| assertDataFrameEqual(TestUDTF(), expected) |
| assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf()"), expected) |
| |
| def test_udtf_with_analyze_multiple_arguments(self): |
| class TestUDTF: |
| @staticmethod |
| def analyze(a: AnalyzeArgument, b: AnalyzeArgument) -> AnalyzeResult: |
| return AnalyzeResult(StructType().add("a", a.dataType).add("b", b.dataType)) |
| |
| def eval(self, a, b): |
| yield a, b |
| |
| func = udtf(TestUDTF) |
| self.spark.udtf.register("test_udtf", func) |
| |
| for i, (df, expected_schema, expected_results) in enumerate( |
| [ |
| ( |
| func(lit(1), lit("x")), |
| StructType().add("a", IntegerType()).add("b", StringType()), |
| [Row(a=1, b="x")], |
| ), |
| ( |
| self.spark.sql("SELECT * FROM test_udtf(1, 'x')"), |
| StructType().add("a", IntegerType()).add("b", StringType()), |
| [Row(a=1, b="x")], |
| ), |
| ] |
| ): |
| with self.subTest(query_no=i): |
| assertSchemaEqual(df.schema, expected_schema) |
| assertDataFrameEqual(df, expected_results) |
| |
    # NOTE(review): "arbitary" is a typo ("arbitrary"); kept as-is because the
    # method name is the test ID.
    def test_udtf_with_analyze_arbitary_number_arguments(self):
        """A *args analyze() produces one output column per provided argument."""

        class TestUDTF:
            @staticmethod
            def analyze(*args: AnalyzeArgument) -> AnalyzeResult:
                return AnalyzeResult(
                    StructType([StructField(f"col{i}", a.dataType) for i, a in enumerate(args)])
                )

            def eval(self, *args):
                yield args

        func = udtf(TestUDTF)
        self.spark.udtf.register("test_udtf", func)

        # NOTE(review): expected Row field names (a, b) differ from the schema
        # columns (col0, col1); the comparison appears to be by value — confirm.
        for i, (df, expected_schema, expected_results) in enumerate(
            [
                (
                    func(lit(1)),
                    StructType().add("col0", IntegerType()),
                    [Row(a=1)],
                ),
                (
                    self.spark.sql("SELECT * FROM test_udtf(1, 'x')"),
                    StructType().add("col0", IntegerType()).add("col1", StringType()),
                    [Row(a=1, b="x")],
                ),
                # Zero arguments yield an empty schema and a single empty row.
                (func(), StructType(), [Row()]),
            ]
        ):
            with self.subTest(query_no=i):
                assertSchemaEqual(df.schema, expected_schema)
                assertDataFrameEqual(df, expected_results)
| |
| def test_udtf_with_analyze_table_argument(self): |
| class TestUDTF: |
| @staticmethod |
| def analyze(a: AnalyzeArgument) -> AnalyzeResult: |
| assert isinstance(a, AnalyzeArgument) |
| assert isinstance(a.dataType, StructType) |
| assert a.value is None |
| assert a.isTable is True |
| assert a.isConstantExpression is False |
| return AnalyzeResult(StructType().add("a", a.dataType[0].dataType)) |
| |
| def eval(self, a: Row): |
| if a["id"] > 5: |
| yield a["id"], |
| |
| func = udtf(TestUDTF) |
| self.spark.udtf.register("test_udtf", func) |
| |
| df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 8)))") |
| assertSchemaEqual(df.schema, StructType().add("a", LongType())) |
| assertDataFrameEqual(df, [Row(a=6), Row(a=7)]) |
| |
    def test_udtf_with_analyze_table_argument_adding_columns(self):
        """analyze() can extend the input table's schema with computed columns."""

        class TestUDTF:
            @staticmethod
            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
                assert isinstance(a.dataType, StructType)
                assert a.isTable is True
                assert a.isConstantExpression is False
                # Append a computed boolean column to the input schema.
                return AnalyzeResult(a.dataType.add("is_even", BooleanType()))

            def eval(self, a: Row):
                yield a["id"], a["id"] % 2 == 0

        func = udtf(TestUDTF)
        self.spark.udtf.register("test_udtf", func)

        df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 4)))")
        assertSchemaEqual(
            df.schema,
            StructType().add("id", LongType(), nullable=False).add("is_even", BooleanType()),
        )
        # NOTE(review): expected Row field name `a` differs from schema column
        # `id`; the comparison appears to be by value — confirm.
        assertDataFrameEqual(
            df,
            [
                Row(a=0, is_even=True),
                Row(a=1, is_even=False),
                Row(a=2, is_even=True),
                Row(a=3, is_even=False),
            ],
        )
| |
    def test_udtf_with_analyze_table_argument_repeating_rows(self):
        """analyze() can validate scalar/table arguments and raise custom errors."""

        class TestUDTF:
            @staticmethod
            def analyze(n, row) -> AnalyzeResult:
                # The first argument must be a constant integer in [1, 10].
                if n.value is None or not isinstance(n.value, int) or (n.value < 1 or n.value > 10):
                    raise Exception("The first argument must be a scalar integer between 1 and 10")

                if row.isTable is False:
                    raise Exception("The second argument must be a table argument")

                assert isinstance(row.dataType, StructType)
                # The output schema mirrors the input table's schema.
                return AnalyzeResult(row.dataType)

            def eval(self, n: int, row: Row):
                # Repeat each input row n times.
                for _ in range(n):
                    yield row

        func = udtf(TestUDTF)
        self.spark.udtf.register("test_udtf", func)

        expected_schema = StructType().add("id", LongType(), nullable=False)
        expected_results = [
            Row(a=0),
            Row(a=0),
            Row(a=1),
            Row(a=1),
            Row(a=2),
            Row(a=2),
            Row(a=3),
            Row(a=3),
        ]
        # Both a literal and a foldable expression (1 + 1) count as constants.
        for i, df in enumerate(
            [
                self.spark.sql("SELECT * FROM test_udtf(2, TABLE (SELECT id FROM range(0, 4)))"),
                self.spark.sql(
                    "SELECT * FROM test_udtf(1 + 1, TABLE (SELECT id FROM range(0, 4)))"
                ),
            ]
        ):
            with self.subTest(query_no=i):
                assertSchemaEqual(df.schema, expected_schema)
                assertDataFrameEqual(df, expected_results)

        # Out-of-range scalar: the analyze() error surfaces as AnalysisException.
        with self.assertRaisesRegex(
            AnalysisException, "The first argument must be a scalar integer between 1 and 10"
        ):
            self.spark.sql(
                "SELECT * FROM test_udtf(0, TABLE (SELECT id FROM range(0, 4)))"
            ).collect()

        # A TABLE argument in the scalar position also fails the validation.
        with self.sql_conf(
            {"spark.sql.tvf.allowMultipleTableArguments.enabled": True}
        ), self.assertRaisesRegex(
            AnalysisException, "The first argument must be a scalar integer between 1 and 10"
        ):
            self.spark.sql(
                """
                SELECT * FROM test_udtf(
                  TABLE (SELECT id FROM range(0, 1)),
                  TABLE (SELECT id FROM range(0, 4)))
                """
            ).collect()

        # A scalar in the table position trips the second validation.
        with self.assertRaisesRegex(
            AnalysisException, "The second argument must be a table argument"
        ):
            self.spark.sql("SELECT * FROM test_udtf(1, 'x')").collect()
| |
| def test_udtf_with_both_return_type_and_analyze(self): |
| class TestUDTF: |
| @staticmethod |
| def analyze() -> AnalyzeResult: |
| return AnalyzeResult(StructType().add("c1", StringType()).add("c2", StringType())) |
| |
| def eval(self): |
| yield "hello", "world" |
| |
| with self.assertRaises(PySparkAttributeError) as e: |
| udtf(TestUDTF, returnType="c1: string, c2: string") |
| |
| self.check_error( |
| exception=e.exception, |
| error_class="INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE", |
| message_parameters={"name": "TestUDTF"}, |
| ) |
| |
| def test_udtf_with_neither_return_type_nor_analyze(self): |
| class TestUDTF: |
| def eval(self): |
| yield "hello", "world" |
| |
| with self.assertRaises(PySparkAttributeError) as e: |
| udtf(TestUDTF) |
| |
| self.check_error( |
| exception=e.exception, |
| error_class="INVALID_UDTF_RETURN_TYPE", |
| message_parameters={"name": "TestUDTF"}, |
| ) |
| |
| def test_udtf_with_analyze_non_staticmethod(self): |
| class TestUDTF: |
| def analyze(self) -> AnalyzeResult: |
| return AnalyzeResult(StructType().add("c1", StringType()).add("c2", StringType())) |
| |
| def eval(self): |
| yield "hello", "world" |
| |
| with self.assertRaises(PySparkAttributeError) as e: |
| udtf(TestUDTF) |
| |
| self.check_error( |
| exception=e.exception, |
| error_class="INVALID_UDTF_RETURN_TYPE", |
| message_parameters={"name": "TestUDTF"}, |
| ) |
| |
| with self.assertRaises(PySparkAttributeError) as e: |
| udtf(TestUDTF, returnType="c1: string, c2: string") |
| |
| self.check_error( |
| exception=e.exception, |
| error_class="INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE", |
| message_parameters={"name": "TestUDTF"}, |
| ) |
| |
| def test_udtf_with_analyze_returning_non_struct(self): |
| class TestUDTF: |
| @staticmethod |
| def analyze(): |
| return StringType() |
| |
| def eval(self): |
| yield "hello", "world" |
| |
| func = udtf(TestUDTF) |
| |
| with self.assertRaisesRegex( |
| AnalysisException, |
| "'analyze' method expects a result of type pyspark.sql.udtf.AnalyzeResult, " |
| "but instead this method returned a value of type: " |
| "<class 'pyspark.sql.types.StringType'>", |
| ): |
| func().collect() |
| |
| def test_udtf_with_analyze_raising_an_exception(self): |
| class TestUDTF: |
| @staticmethod |
| def analyze() -> AnalyzeResult: |
| raise Exception("Failed to analyze.") |
| |
| def eval(self): |
| yield "hello", "world" |
| |
| func = udtf(TestUDTF) |
| |
| with self.assertRaisesRegex(AnalysisException, "Failed to analyze."): |
| func().collect() |
| |
| def test_udtf_with_analyze_null_literal(self): |
| class TestUDTF: |
| @staticmethod |
| def analyze(a: AnalyzeArgument) -> AnalyzeResult: |
| return AnalyzeResult(StructType().add("a", a.dataType)) |
| |
| def eval(self, a): |
| yield a, |
| |
| func = udtf(TestUDTF) |
| |
| df = func(lit(None)) |
| assertSchemaEqual(df.schema, StructType().add("a", NullType())) |
| assertDataFrameEqual(df, [Row(a=None)]) |
| |
| def test_udtf_with_analyze_taking_wrong_number_of_arguments(self): |
| class TestUDTF: |
| @staticmethod |
| def analyze(a: AnalyzeArgument, b: AnalyzeArgument) -> AnalyzeResult: |
| return AnalyzeResult(StructType().add("a", a.dataType).add("b", b.dataType)) |
| |
| def eval(self, a, b): |
| yield a, a + 1 |
| |
| func = udtf(TestUDTF) |
| |
| with self.assertRaisesRegex(AnalysisException, r"arguments"): |
| func(lit(1)).collect() |
| |
| with self.assertRaisesRegex(AnalysisException, r"arguments"): |
| func(lit(1), lit(2), lit(3)).collect() |
| |
    def test_udtf_with_analyze_taking_keyword_arguments(self):
        """A **kwargs analyze() accepts named arguments but no positional ones."""

        @udtf
        class TestUDTF:
            @staticmethod
            def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
                return AnalyzeResult(StructType().add("a", StringType()).add("b", StringType()))

            def eval(self, **kwargs):
                yield "hello", "world"

        self.spark.udtf.register("test_udtf", TestUDTF)

        expected = [Row(c1="hello", c2="world")]
        assertDataFrameEqual(TestUDTF(), expected)
        assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf()"), expected)
        assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf(a=>1)"), expected)

        # Positional arguments are rejected for a kwargs-only signature, both via
        # the DataFrame API and SQL.
        with self.assertRaisesRegex(
            AnalysisException, BaseUDTFTestsMixin.tooManyPositionalArguments
        ):
            TestUDTF(lit(1)).collect()

        with self.assertRaisesRegex(
            AnalysisException, BaseUDTFTestsMixin.tooManyPositionalArguments
        ):
            self.spark.sql("SELECT * FROM test_udtf(1, 'x')").collect()
| |
    def test_udtf_with_analyze_using_broadcast(self):
        """Broadcast variables are usable from analyze(), eval(), and terminate()."""

        colname = self.sc.broadcast("col1")

        @udtf
        class TestUDTF:
            @staticmethod
            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
                # Name the output column after the broadcast value.
                return AnalyzeResult(StructType().add(colname.value, a.dataType))

            def eval(self, a):
                assert colname.value == "col1"
                yield a,

            def terminate(self):
                assert colname.value == "col1"
                yield 100,

        self.spark.udtf.register("test_udtf", TestUDTF)

        for i, df in enumerate([TestUDTF(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]):
            with self.subTest(query_no=i):
                assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
                # One row from eval plus one emitted by terminate.
                assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
| |
    def test_udtf_with_analyze_using_accumulator(self):
        """Accumulator updates from analyze(), eval(), and terminate() all register."""

        test_accum = self.sc.accumulator(0)

        @udtf
        class TestUDTF:
            @staticmethod
            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
                test_accum.add(1)
                return AnalyzeResult(StructType().add("col1", a.dataType))

            def eval(self, a):
                test_accum.add(10)
                yield a,

            def terminate(self):
                test_accum.add(100)
                yield 100,

        self.spark.udtf.register("test_udtf", TestUDTF)

        for i, df in enumerate([TestUDTF(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]):
            with self.subTest(query_no=i):
                assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
                assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])

        # 222 = 2 queries x (1 analyze + 10 eval + 100 terminate): consistent with
        # each phase running exactly once per query.
        self.assertEqual(test_accum.value, 222)
| |
    def _add_pyfile(self, path):
        # Indirection for distributing a Python file to the cluster; presumably
        # overridden by other variants of this mixin — confirm.
        self.sc.addPyFile(path)
| |
    def test_udtf_with_analyze_using_pyfile(self):
        """A module shipped via addPyFile is importable inside analyze()/eval()."""

        with tempfile.TemporaryDirectory(prefix="test_udtf_with_analyze_using_pyfile") as d:
            pyfile_path = os.path.join(d, "my_pyfile.py")
            with open(pyfile_path, "w") as f:
                f.write("my_func = lambda: 'col1'")

            self._add_pyfile(pyfile_path)

            class TestUDTF:
                @staticmethod
                def call_my_func() -> str:
                    # Import lazily so resolution happens after distribution.
                    import my_pyfile

                    return my_pyfile.my_func()

                @staticmethod
                def analyze(a: AnalyzeArgument) -> AnalyzeResult:
                    return AnalyzeResult(StructType().add(TestUDTF.call_my_func(), a.dataType))

                def eval(self, a):
                    assert TestUDTF.call_my_func() == "col1"
                    yield a,

                def terminate(self):
                    assert TestUDTF.call_my_func() == "col1"
                    yield 100,

            test_udtf = udtf(TestUDTF)
            self.spark.udtf.register("test_udtf", test_udtf)

            for i, df in enumerate(
                [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]
            ):
                with self.subTest(query_no=i):
                    assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
                    assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
| |
    def test_udtf_with_analyze_using_zipped_package(self):
        """A zipped package shipped via addPyFile is importable inside the UDTF."""

        with tempfile.TemporaryDirectory(prefix="test_udtf_with_analyze_using_zipped_package") as d:
            package_path = os.path.join(d, "my_zipfile")
            os.mkdir(package_path)
            pyfile_path = os.path.join(package_path, "__init__.py")
            with open(pyfile_path, "w") as f:
                f.write("my_func = lambda: 'col1'")
            # Zip the package directory so it can be shipped as one file.
            shutil.make_archive(package_path, "zip", d, "my_zipfile")

            self._add_pyfile(f"{package_path}.zip")

            class TestUDTF:
                @staticmethod
                def call_my_func() -> str:
                    # Import lazily so resolution happens after distribution.
                    import my_zipfile

                    return my_zipfile.my_func()

                @staticmethod
                def analyze(a: AnalyzeArgument) -> AnalyzeResult:
                    return AnalyzeResult(StructType().add(TestUDTF.call_my_func(), a.dataType))

                def eval(self, a):
                    assert TestUDTF.call_my_func() == "col1"
                    yield a,

                def terminate(self):
                    assert TestUDTF.call_my_func() == "col1"
                    yield 100,

            test_udtf = udtf(TestUDTF)
            self.spark.udtf.register("test_udtf", test_udtf)

            for i, df in enumerate(
                [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]
            ):
                with self.subTest(query_no=i):
                    assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
                    assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
| |
    def _add_archive(self, path):
        # Indirection for distributing an archive to the cluster; presumably
        # overridden by other variants of this mixin — confirm.
        self.sc.addArchive(path)
| |
    def test_udtf_with_analyze_using_archive(self):
        """Files inside an archive added with an alias are readable from the UDTF."""

        from pyspark.core.files import SparkFiles

        with tempfile.TemporaryDirectory(prefix="test_udtf_with_analyze_using_archive") as d:
            archive_path = os.path.join(d, "my_archive")
            os.mkdir(archive_path)
            pyfile_path = os.path.join(archive_path, "my_file.txt")
            with open(pyfile_path, "w") as f:
                f.write("col1")
            shutil.make_archive(archive_path, "zip", d, "my_archive")

            # The "#my_files" suffix is the alias the archive is unpacked under.
            self._add_archive(f"{archive_path}.zip#my_files")

            class TestUDTF:
                @staticmethod
                def read_my_archive() -> str:
                    with open(
                        os.path.join(
                            SparkFiles.getRootDirectory(), "my_files", "my_archive", "my_file.txt"
                        ),
                        "r",
                    ) as my_file:
                        return my_file.read().strip()

                @staticmethod
                def analyze(a: AnalyzeArgument) -> AnalyzeResult:
                    return AnalyzeResult(StructType().add(TestUDTF.read_my_archive(), a.dataType))

                def eval(self, a):
                    assert TestUDTF.read_my_archive() == "col1"
                    yield a,

                def terminate(self):
                    assert TestUDTF.read_my_archive() == "col1"
                    yield 100,

            test_udtf = udtf(TestUDTF)
            self.spark.udtf.register("test_udtf", test_udtf)

            for i, df in enumerate(
                [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]
            ):
                with self.subTest(query_no=i):
                    assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
                    assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
| |
    def _add_file(self, path):
        # Indirection for distributing a plain file to the cluster; presumably
        # overridden by other variants of this mixin — confirm.
        self.sc.addFile(path)
| |
    def test_udtf_with_analyze_using_file(self):
        """A plain file added with addFile is readable from analyze()/eval()."""

        from pyspark.core.files import SparkFiles

        with tempfile.TemporaryDirectory(prefix="test_udtf_with_analyze_using_file") as d:
            file_path = os.path.join(d, "my_file.txt")
            with open(file_path, "w") as f:
                f.write("col1")

            self._add_file(file_path)

            class TestUDTF:
                @staticmethod
                def read_my_file() -> str:
                    with open(
                        os.path.join(SparkFiles.getRootDirectory(), "my_file.txt"), "r"
                    ) as my_file:
                        return my_file.read().strip()

                @staticmethod
                def analyze(a: AnalyzeArgument) -> AnalyzeResult:
                    return AnalyzeResult(StructType().add(TestUDTF.read_my_file(), a.dataType))

                def eval(self, a):
                    assert TestUDTF.read_my_file() == "col1"
                    yield a,

                def terminate(self):
                    assert TestUDTF.read_my_file() == "col1"
                    yield 100,

            test_udtf = udtf(TestUDTF)
            self.spark.udtf.register("test_udtf", test_udtf)

            for i, df in enumerate(
                [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]
            ):
                with self.subTest(query_no=i):
                    assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
                    assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
| |
| def test_udtf_with_named_arguments(self): |
| @udtf(returnType="a: int") |
| class TestUDTF: |
| def eval(self, a, b): |
| yield a, |
| |
| self.spark.udtf.register("test_udtf", TestUDTF) |
| |
| for i, df in enumerate( |
| [ |
| self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"), |
| self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"), |
| TestUDTF(a=lit(10), b=lit("x")), |
| TestUDTF(b=lit("x"), a=lit(10)), |
| ] |
| ): |
| with self.subTest(query_no=i): |
| assertDataFrameEqual(df, [Row(a=10)]) |
| |
    def test_udtf_with_named_arguments_negative(self):
        """Invalid named-argument combinations produce the expected errors."""

        @udtf(returnType="a: int")
        class TestUDTF:
            def eval(self, a, b):
                yield a,

        self.spark.udtf.register("test_udtf", TestUDTF)

        # Duplicate assignment of the same parameter name.
        with self.assertRaisesRegex(
            AnalysisException,
            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
        ):
            self.spark.sql("SELECT * FROM test_udtf(a => 10, a => 100)").show()

        # Positional argument after a named one.
        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
            self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show()

        # Unknown parameter name: raised from the Python side at runtime.
        with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.missingARequiredArgument):
            self.spark.sql("SELECT * FROM test_udtf(c => 'x')").show()

        # Same parameter supplied both positionally and by name.
        with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.multipleValuesForArgument):
            self.spark.sql("SELECT * FROM test_udtf(10, a => 100)").show()
| |
| def test_udtf_with_kwargs(self): |
| @udtf(returnType="a: int, b: string") |
| class TestUDTF: |
| def eval(self, **kwargs): |
| yield kwargs["a"], kwargs["b"] |
| |
| self.spark.udtf.register("test_udtf", TestUDTF) |
| |
| for i, df in enumerate( |
| [ |
| self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"), |
| self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"), |
| TestUDTF(a=lit(10), b=lit("x")), |
| TestUDTF(b=lit("x"), a=lit(10)), |
| ] |
| ): |
| with self.subTest(query_no=i): |
| assertDataFrameEqual(df, [Row(a=10, b="x")]) |
| |
| # negative |
| with self.assertRaisesRegex( |
| AnalysisException, |
| "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", |
| ): |
| self.spark.sql("SELECT * FROM test_udtf(a => 10, a => 100)").show() |
| |
| with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): |
| self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show() |
| |
    def test_udtf_with_analyze_kwargs(self):
        """A **kwargs analyze() receives named arguments with values and types."""

        @udtf
        class TestUDTF:
            @staticmethod
            def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
                assert isinstance(kwargs["a"].dataType, IntegerType)
                assert kwargs["a"].value == 10
                assert kwargs["a"].isTable is False
                assert kwargs["a"].isConstantExpression is True
                assert isinstance(kwargs["b"].dataType, StringType)
                assert kwargs["b"].value == "x"
                assert not kwargs["b"].isTable
                # Sort by argument name so the schema is deterministic regardless
                # of the order the arguments were passed in.
                return AnalyzeResult(
                    StructType(
                        [StructField(key, arg.dataType) for key, arg in sorted(kwargs.items())]
                    )
                )

            def eval(self, **kwargs):
                # Emit values in the same name-sorted order as the schema.
                yield tuple(value for _, value in sorted(kwargs.items()))

        self.spark.udtf.register("test_udtf", TestUDTF)

        for i, df in enumerate(
            [
                self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"),
                self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"),
                TestUDTF(a=lit(10), b=lit("x")),
                TestUDTF(b=lit("x"), a=lit(10)),
            ]
        ):
            with self.subTest(query_no=i):
                assertDataFrameEqual(df, [Row(a=10, b="x")])
| |
    def test_udtf_with_named_arguments_lateral_join(self):
        """Named arguments resolve against lateral-join columns in either order."""

        @udtf
        class TestUDTF:
            @staticmethod
            def analyze(a, b):
                return AnalyzeResult(StructType().add("a", a.dataType))

            def eval(self, a, b):
                yield a,

        self.spark.udtf.register("test_udtf", TestUDTF)

        # lateral join
        for i, df in enumerate(
            [
                self.spark.sql(
                    "SELECT f.* FROM "
                    "VALUES (0, 'x'), (1, 'y') t(a, b), LATERAL test_udtf(a => a, b => b) f"
                ),
                self.spark.sql(
                    "SELECT f.* FROM "
                    "VALUES (0, 'x'), (1, 'y') t(a, b), LATERAL test_udtf(b => b, a => a) f"
                ),
            ]
        ):
            with self.subTest(query_no=i):
                assertDataFrameEqual(df, [Row(a=0), Row(a=1)])
| |
    def test_udtf_with_named_arguments_and_defaults(self):
        """An eval() default fills in for an omitted optional argument."""

        @udtf
        class TestUDTF:
            @staticmethod
            def analyze(a: AnalyzeArgument, b: Optional[AnalyzeArgument] = None):
                assert isinstance(a.dataType, IntegerType)
                assert a.value == 10
                assert not a.isTable
                assert a.isConstantExpression is True
                if b is not None:
                    assert isinstance(b.dataType, StringType)
                    assert b.value == "z"
                    assert not b.isTable
                schema = StructType().add("a", a.dataType)
                # When "b" is omitted, the eval() default (int 100) determines
                # the column type; otherwise use the supplied argument's type.
                if b is None:
                    return AnalyzeResult(schema.add("b", IntegerType()))
                else:
                    return AnalyzeResult(schema.add("b", b.dataType))

            def eval(self, a, b=100):
                yield a, b

        self.spark.udtf.register("test_udtf", TestUDTF)

        # without "b"
        for i, df in enumerate(
            [
                self.spark.sql("SELECT * FROM test_udtf(10)"),
                self.spark.sql("SELECT * FROM test_udtf(a => 10)"),
                TestUDTF(lit(10)),
                TestUDTF(a=lit(10)),
            ]
        ):
            with self.subTest(with_b=False, query_no=i):
                assertDataFrameEqual(df, [Row(a=10, b=100)])

        # with "b"
        for i, df in enumerate(
            [
                self.spark.sql("SELECT * FROM test_udtf(10, b => 'z')"),
                self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'z')"),
                self.spark.sql("SELECT * FROM test_udtf(b => 'z', a => 10)"),
                TestUDTF(lit(10), b=lit("z")),
                TestUDTF(a=lit(10), b=lit("z")),
                TestUDTF(b=lit("z"), a=lit(10)),
            ]
        ):
            with self.subTest(with_b=True, query_no=i):
                assertDataFrameEqual(df, [Row(a=10, b="z")])
| |
| def test_udtf_with_table_argument_and_partition_by(self): |
| class TestUDTF: |
| def __init__(self): |
| self._sum = 0 |
| self._partition_col = None |
| |
| def eval(self, row: Row): |
| # Make sure that the PARTITION BY expressions were projected out. |
| assert len(row.asDict().items()) == 2 |
| assert "partition_col" in row |
| assert "input" in row |
| self._sum += row["input"] |
| if self._partition_col is not None and self._partition_col != row["partition_col"]: |
| # Make sure that all values of the partitioning column are the same |
| # for each row consumed by this method for this instance of the class. |
| raise Exception( |
| f"self._partition_col was {self._partition_col} but the row " |
| + f"value was {row['partition_col']}" |
| ) |
| self._partition_col = row["partition_col"] |
| |
| def terminate(self): |
| yield self._partition_col, self._sum |
| |
| # This is a basic example. |
| func = udtf(TestUDTF, returnType="partition_col: int, total: int") |
| self.spark.udtf.register("test_udtf", func) |
| |
| base_query = """ |
| WITH t AS ( |
| SELECT id AS partition_col, 1 AS input FROM range(1, 21) |
| UNION ALL |
| SELECT id AS partition_col, 2 AS input FROM range(1, 21) |
| ) |
| SELECT partition_col, total |
| FROM test_udtf({table_arg}) |
| ORDER BY 1, 2 |
| """ |
| |
| for table_arg in [ |
| "TABLE(t) PARTITION BY partition_col - 1", |
| "row => TABLE(t) PARTITION BY partition_col - 1", |
| ]: |
| with self.subTest(table_arg=table_arg): |
| assertDataFrameEqual( |
| self.spark.sql(base_query.format(table_arg=table_arg)), |
| [Row(partition_col=x, total=3) for x in range(1, 21)], |
| ) |
| |
| base_query = """ |
| WITH t AS ( |
| SELECT {str_first} AS partition_col, id AS input FROM range(0, 2) |
| UNION ALL |
| SELECT {str_second} AS partition_col, id AS input FROM range(0, 2) |
| ) |
| SELECT partition_col, total |
| FROM test_udtf({table_arg}) |
| ORDER BY 1, 2 |
| """ |
| |
| # These cases partition by constant values. |
| for str_first, str_second, result_first, result_second in ( |
| ("123", "456", 123, 456), |
| ("123", "NULL", None, 123), |
| ): |
| for table_arg in [ |
| "TABLE(t) PARTITION BY partition_col", |
| "row => TABLE(t) PARTITION BY partition_col", |
| ]: |
| with self.subTest(str_first=str_first, str_second=str_second, table_arg=table_arg): |
| assertDataFrameEqual( |
| self.spark.sql( |
| base_query.format( |
| str_first=str_first, str_second=str_second, table_arg=table_arg |
| ) |
| ), |
| [ |
| Row(partition_col=result_first, total=1), |
| Row(partition_col=result_second, total=1), |
| ], |
| ) |
| |
        # Combine a lateral join with a TABLE argument with PARTITION BY.
| func = udtf(TestUDTF, returnType="partition_col: int, total: int") |
| self.spark.udtf.register("test_udtf", func) |
| |
| base_query = """ |
| WITH t AS ( |
| SELECT id AS partition_col, 1 AS input FROM range(1, 3) |
| UNION ALL |
| SELECT id AS partition_col, 2 AS input FROM range(1, 3) |
| ) |
| SELECT v.a, v.b, f.partition_col, f.total |
| FROM VALUES (0, 1) AS v(a, b), |
| LATERAL test_udtf({table_arg}) f |
| ORDER BY 1, 2, 3, 4 |
| """ |
| |
| for table_arg in [ |
| "TABLE(t) PARTITION BY partition_col - 1", |
| "row => TABLE(t) PARTITION BY partition_col - 1", |
| ]: |
| with self.subTest(func_call=table_arg): |
| assertDataFrameEqual( |
| self.spark.sql(base_query.format(table_arg=table_arg)), |
| [ |
| Row(a=0, b=1, partition_col=1, total=3), |
| Row(a=0, b=1, partition_col=2, total=3), |
| ], |
| ) |
| |
    def test_udtf_with_table_argument_and_partition_by_and_order_by(self):
        """TABLE argument with PARTITION BY and ORDER BY, where the name of the
        partitioning column is also passed to the UDTF as a scalar argument."""

        class TestUDTF:
            def __init__(self):
                # Most recently consumed 'input' value (reflects the ORDER BY).
                self._last = None
                # Partition value seen by this instance; None before the first row.
                self._partition_col = None

            def eval(self, row: Row, partition_col: str):
                # Make sure that the PARTITION BY and ORDER BY expressions were projected out.
                assert len(row.asDict().items()) == 2
                assert "partition_col" in row
                assert "input" in row
                # Make sure that all values of the partitioning column are the same
                # for each row consumed by this method for this instance of the class.
                if self._partition_col is not None and self._partition_col != row[partition_col]:
                    raise Exception(
                        f"self._partition_col was {self._partition_col} but the row "
                        + f"value was {row[partition_col]}"
                    )
                self._last = row["input"]
                self._partition_col = row[partition_col]

            def terminate(self):
                # Emit one row per partition: (partition value, last input seen).
                yield self._partition_col, self._last

        func = udtf(TestUDTF, returnType="partition_col: int, last: int")
        self.spark.udtf.register("test_udtf", func)

        base_query = """
            WITH t AS (
              SELECT id AS partition_col, 1 AS input FROM range(1, 21)
              UNION ALL
              SELECT id AS partition_col, 2 AS input FROM range(1, 21)
            )
            SELECT partition_col, last
            FROM test_udtf(
              {table_arg},
              partition_col => 'partition_col')
            ORDER BY 1, 2
            """

        # Each ORDER BY variant determines which 'input' value arrives last within a
        # partition and is therefore reported by terminate().
        for order_by_str, result_val in (
            ("input ASC", 2),
            ("input + 1 ASC", 2),
            ("input DESC", 1),
            ("input - 1 DESC", 1),
        ):
            # Exercise both positional and named (row => ...) TABLE argument syntax.
            for table_arg in [
                f"TABLE(t) PARTITION BY partition_col - 1 ORDER BY {order_by_str}",
                f"row => TABLE(t) PARTITION BY partition_col - 1 ORDER BY {order_by_str}",
            ]:
                with self.subTest(table_arg=table_arg):
                    assertDataFrameEqual(
                        self.spark.sql(base_query.format(table_arg=table_arg)),
                        [Row(partition_col=x, last=result_val) for x in range(1, 21)],
                    )
| |
| def test_udtf_with_table_argument_with_single_partition(self): |
| class TestUDTF: |
| def __init__(self): |
| self._count = 0 |
| self._sum = 0 |
| self._last = None |
| |
| def eval(self, row: Row): |
| # Make sure that the rows arrive in the expected order. |
| if self._last is not None and self._last > row["input"]: |
| raise Exception( |
| f"self._last was {self._last} but the row value was {row['input']}" |
| ) |
| self._count += 1 |
| self._last = row["input"] |
| self._sum += row["input"] |
| |
| def terminate(self): |
| yield self._count, self._sum, self._last |
| |
| func = udtf(TestUDTF, returnType="count: int, total: int, last: int") |
| self.spark.udtf.register("test_udtf", func) |
| |
| base_query = """ |
| WITH t AS ( |
| SELECT id AS partition_col, 1 AS input FROM range(1, 21) |
| UNION ALL |
| SELECT id AS partition_col, 2 AS input FROM range(1, 21) |
| ) |
| SELECT count, total, last |
| FROM test_udtf({table_arg}) |
| ORDER BY 1, 2 |
| """ |
| |
| for table_arg in [ |
| "TABLE(t) WITH SINGLE PARTITION ORDER BY (input, partition_col)", |
| "row => TABLE(t) WITH SINGLE PARTITION ORDER BY (input, partition_col)", |
| ]: |
| with self.subTest(table_arg=table_arg): |
| assertDataFrameEqual( |
| self.spark.sql(base_query.format(table_arg=table_arg)), |
| [Row(count=40, total=60, last=2)], |
| ) |
| |
| def test_udtf_with_table_argument_with_single_partition_from_analyze(self): |
| @udtf |
| class TestUDTF: |
| def __init__(self): |
| self._count = 0 |
| self._sum = 0 |
| self._last = None |
| |
| @staticmethod |
| def analyze(*args, **kwargs): |
| return AnalyzeResult( |
| schema=StructType() |
| .add("count", IntegerType()) |
| .add("total", IntegerType()) |
| .add("last", IntegerType()), |
| withSinglePartition=True, |
| orderBy=[OrderingColumn("input"), OrderingColumn("partition_col")], |
| ) |
| |
| def eval(self, row: Row): |
| # Make sure that the rows arrive in the expected order. |
| if self._last is not None and self._last > row["input"]: |
| raise Exception( |
| f"self._last was {self._last} but the row value was {row['input']}" |
| ) |
| self._count += 1 |
| self._last = row["input"] |
| self._sum += row["input"] |
| |
| def terminate(self): |
| yield self._count, self._sum, self._last |
| |
| self.spark.udtf.register("test_udtf", TestUDTF) |
| |
| base_query = """ |
| WITH t AS ( |
| SELECT id AS partition_col, 1 AS input FROM range(1, 21) |
| UNION ALL |
| SELECT id AS partition_col, 2 AS input FROM range(1, 21) |
| ) |
| SELECT count, total, last |
| FROM test_udtf({table_arg}) |
| ORDER BY 1, 2 |
| """ |
| |
| for table_arg in ["TABLE(t)", "row => TABLE(t)"]: |
| with self.subTest(table_arg): |
| assertDataFrameEqual( |
| self.spark.sql(base_query.format(table_arg=table_arg)), |
| [Row(count=40, total=60, last=2)], |
| ) |
| |
    def test_udtf_with_table_argument_with_partition_by_and_order_by_from_analyze(self):
        """PARTITION BY and ORDER BY requested from the UDTF's analyze() method
        (including NULLs-last ordering) rather than from the SQL call site."""

        @udtf
        class TestUDTF:
            def __init__(self):
                # Partition value for this instance; each instance sees one partition.
                self._partition_col = None
                # Row count, running sum of non-NULL 'input', and last 'input' seen.
                self._count = 0
                self._sum = 0
                self._last = None

            @staticmethod
            def analyze(*args, **kwargs):
                # Partition on 'partition_col' and order rows by 'input' ascending
                # with NULLs last (overrideNullsFirst=False).
                return AnalyzeResult(
                    schema=StructType()
                    .add("partition_col", IntegerType())
                    .add("count", IntegerType())
                    .add("total", IntegerType())
                    .add("last", IntegerType()),
                    partitionBy=[PartitioningColumn("partition_col")],
                    orderBy=[
                        OrderingColumn(name="input", ascending=True, overrideNullsFirst=False)
                    ],
                )

            def eval(self, row: Row):
                # Make sure that the PARTITION BY and ORDER BY expressions were projected out.
                assert len(row.asDict().items()) == 2
                assert "partition_col" in row
                assert "input" in row
                # Make sure that all values of the partitioning column are the same
                # for each row consumed by this method for this instance of the class.
                if self._partition_col is not None and self._partition_col != row["partition_col"]:
                    raise Exception(
                        f"self._partition_col was {self._partition_col} but the row "
                        + f"value was {row['partition_col']}"
                    )
                # Make sure that the rows arrive in the expected order.
                if (
                    self._last is not None
                    and row["input"] is not None
                    and self._last > row["input"]
                ):
                    raise Exception(
                        f"self._last was {self._last} but the row value was {row['input']}"
                    )
                self._partition_col = row["partition_col"]
                self._count += 1
                self._last = row["input"]
                # NULL inputs count toward _count but are excluded from _sum.
                if row["input"] is not None:
                    self._sum += row["input"]

            def terminate(self):
                # Emit one output row per partition.
                yield self._partition_col, self._count, self._sum, self._last

        self.spark.udtf.register("test_udtf", TestUDTF)

        base_query = """
            WITH t AS (
              SELECT id AS partition_col, 1 AS input FROM range(1, 21)
              UNION ALL
              SELECT id AS partition_col, 2 AS input FROM range(1, 21)
              UNION ALL
              SELECT 42 AS partition_col, NULL AS input
              UNION ALL
              SELECT 42 AS partition_col, 1 AS input
              UNION ALL
              SELECT 42 AS partition_col, 2 AS input
            )
            SELECT partition_col, count, total, last
            FROM test_udtf({table_arg})
            ORDER BY 1, 2
            """

        # Partition 42 mixes NULL and non-NULL inputs: NULLs sort last, so 'last'
        # is None there, while 'total' sums only the non-NULL values.
        for table_arg in ["TABLE(t)", "row => TABLE(t)"]:
            with self.subTest(table_arg=table_arg):
                assertDataFrameEqual(
                    self.spark.sql(base_query.format(table_arg=table_arg)),
                    [Row(partition_col=x, count=2, total=3, last=2) for x in range(1, 21)]
                    + [Row(partition_col=42, count=3, total=3, last=None)],
                )
| |
| def test_udtf_with_prepare_string_from_analyze(self): |
| @dataclass |
| class AnalyzeResultWithBuffer(AnalyzeResult): |
| buffer: str = "" |
| |
| @udtf |
| class TestUDTF: |
| def __init__(self, analyze_result=None): |
| self._total = 0 |
| if analyze_result is not None: |
| self._buffer = analyze_result.buffer |
| else: |
| self._buffer = "" |
| |
| @staticmethod |
| def analyze(argument, _): |
| if ( |
| argument.value is None |
| or argument.isTable |
| or not isinstance(argument.value, str) |
| or len(argument.value) == 0 |
| ): |
| raise Exception("The first argument must be non-empty string") |
| assert argument.dataType == StringType() |
| assert not argument.isTable |
| return AnalyzeResultWithBuffer( |
| schema=StructType().add("total", IntegerType()).add("buffer", StringType()), |
| withSinglePartition=True, |
| buffer=argument.value, |
| ) |
| |
| def eval(self, argument, row: Row): |
| self._total += 1 |
| |
| def terminate(self): |
| yield self._total, self._buffer |
| |
| self.spark.udtf.register("test_udtf", TestUDTF) |
| |
| assertDataFrameEqual( |
| self.spark.sql( |
| """ |
| WITH t AS ( |
| SELECT id FROM range(1, 21) |
| ) |
| SELECT total, buffer |
| FROM test_udtf("abc", TABLE(t)) |
| """ |
| ), |
| [Row(count=20, buffer="abc")], |
| ) |
| |
    def test_udtf_with_skip_rest_of_input_table_exception(self):
        """Raising SkipRestOfInputTableException from eval() stops consuming the
        remaining rows of the current partition; terminate() still runs."""

        @udtf(returnType="current: int, total: int")
        class TestUDTF:
            def __init__(self):
                # Most recent 'id' consumed and the number of rows consumed so far.
                self._current = 0
                self._total = 0

            def eval(self, input: Row):
                self._current = input["id"]
                self._total += 1
                # Skip the rest of this partition after the fourth row.
                if self._total >= 4:
                    raise SkipRestOfInputTableException("Stop at self._total >= 4")

            def terminate(self):
                yield self._current, self._total

        self.spark.udtf.register("test_udtf", TestUDTF)

        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
        # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
        assertDataFrameEqual(
            self.spark.sql(
                """
                WITH t AS (
                  SELECT id FROM range(1, 21)
                )
                SELECT current, total
                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION ORDER BY id)
                """
            ),
            [Row(current=4, total=4)],
        )

        # Run a test case including PARTITION BY on the UDTF call. The
        # SkipRestOfInputTableException stops scanning rows within each partition
        # separately: floor(id / 10) yields three partitions (ids 1-9, 10-19, 20),
        # and the last partition has only one row, so it finishes with total=1.
        assertDataFrameEqual(
            self.spark.sql(
                """
                WITH t AS (
                  SELECT id FROM range(1, 21)
                )
                SELECT current, total
                FROM test_udtf(TABLE(t) PARTITION BY floor(id / 10) ORDER BY id)
                ORDER BY ALL
                """
            ),
            [Row(current=4, total=4), Row(current=13, total=4), Row(current=20, total=1)],
        )
| |
| |
class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
    """Runs the shared UDTF test suite with Arrow-optimized execution disabled."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "false")

    @classmethod
    def tearDownClass(cls):
        # Always run the parent teardown, even if unsetting the conf fails.
        try:
            cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
        finally:
            super().tearDownClass()
| |
| |
@unittest.skipIf(
    not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
)
class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
    """Arrow-specific UDTF behavior on top of the shared suite: evalType
    selection, and how Python output values are cast to the declared return
    schema under Arrow-optimized execution."""

    def test_eval_type(self):
        """useArrow explicitly selects between the plain and Arrow eval types."""

        # NOTE(review): this nested helper only calls itself and is never executed
        # here (the test merely inspects evalType), so the recursion is inert.
        def upper(x: str):
            return upper(x)

        class TestUDTF:
            def eval(self, x: str):
                return upper(x)

        self.assertEqual(
            udtf(TestUDTF, returnType="x: string", useArrow=False).evalType,
            PythonEvalType.SQL_TABLE_UDF,
        )

        self.assertEqual(
            udtf(TestUDTF, returnType="x: string", useArrow=True).evalType,
            PythonEvalType.SQL_ARROW_TABLE_UDF,
        )

    def test_udtf_arrow_sql_conf(self):
        """The Arrow conf set via the SQL SET command controls the evalType of
        subsequently created UDTFs."""

        class TestUDTF:
            def eval(self):
                yield 1,

        # We do not use `self.sql_conf` here to test the SQL SET command
        # instead of using PySpark's `spark.conf.set`.
        old_value = self.spark.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")
        self.spark.sql("SET spark.sql.execution.pythonUDTF.arrow.enabled=False")
        self.assertEqual(udtf(TestUDTF, returnType="x: int").evalType, PythonEvalType.SQL_TABLE_UDF)
        self.spark.sql("SET spark.sql.execution.pythonUDTF.arrow.enabled=True")
        self.assertEqual(
            udtf(TestUDTF, returnType="x: int").evalType, PythonEvalType.SQL_ARROW_TABLE_UDF
        )
        # Restore the original value so later tests in this class are unaffected.
        self.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", old_value)

    def test_udtf_eval_returning_non_tuple(self):
        """With Arrow enabled, eval may yield or return bare values instead of
        tuples and they are still wrapped into single-column rows."""

        @udtf(returnType="a: int")
        class TestUDTF:
            def eval(self, a: int):
                yield a

        # When arrow is enabled, it can handle non-tuple return value.
        assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)])

        @udtf(returnType="a: int")
        class TestUDTF:
            def eval(self, a: int):
                return [a]

        assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)])

    def test_numeric_output_type_casting(self):
        """Casting an int output value to each declared return type: either the
        expected rows or an UDTF_ARROW_TYPE_CAST_ERROR."""

        class TestUDTF:
            def eval(self):
                yield 1,

        err = "UDTF_ARROW_TYPE_CAST_ERROR"

        for ret_type, expected in [
            ("x: boolean", [Row(x=True)]),
            ("x: tinyint", [Row(x=1)]),
            ("x: smallint", [Row(x=1)]),
            ("x: int", [Row(x=1)]),
            ("x: bigint", [Row(x=1)]),
            ("x: string", [Row(x="1")]),  # require arrow.cast
            ("x: date", err),
            ("x: byte", [Row(x=1)]),
            ("x: binary", [Row(x=bytearray(b"\x01"))]),
            ("x: float", [Row(x=1.0)]),
            ("x: double", [Row(x=1.0)]),
            ("x: decimal(10, 0)", err),
            ("x: array<int>", err),
            ("x: map<string,int>", err),
            ("x: struct<a:int>", err),
        ]:
            with self.subTest(ret_type=ret_type):
                # _check_result_or_exception is a helper on the base mixin (defined
                # earlier in this file): it asserts either the expected rows or an
                # exception whose message contains the given error string.
                self._check_result_or_exception(TestUDTF, ret_type, expected)

    def test_numeric_string_output_type_casting(self):
        """Casting a numeric string output value ("1") to each declared return
        type under Arrow."""

        class TestUDTF:
            def eval(self):
                yield "1",

        err = "UDTF_ARROW_TYPE_CAST_ERROR"

        for ret_type, expected in [
            ("x: boolean", [Row(x=True)]),
            ("x: tinyint", [Row(x=1)]),
            ("x: smallint", [Row(x=1)]),
            ("x: int", [Row(x=1)]),
            ("x: bigint", [Row(x=1)]),
            ("x: string", [Row(x="1")]),
            ("x: date", err),
            ("x: timestamp", err),
            ("x: byte", [Row(x=1)]),
            ("x: binary", [Row(x=bytearray(b"1"))]),
            ("x: float", [Row(x=1.0)]),
            ("x: double", [Row(x=1.0)]),
            ("x: decimal(10, 0)", [Row(x=1)]),
            ("x: array<string>", [Row(x=["1"])]),
            ("x: array<int>", [Row(x=[1])]),
            ("x: map<string,int>", err),
            ("x: struct<a:int>", err),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)

    def test_string_output_type_casting(self):
        """Casting a non-numeric string output value ("hello") to each declared
        return type; note a string casts to array<string> as its characters."""

        class TestUDTF:
            def eval(self):
                yield "hello",

        err = "UDTF_ARROW_TYPE_CAST_ERROR"

        for ret_type, expected in [
            ("x: boolean", err),
            ("x: tinyint", err),
            ("x: smallint", err),
            ("x: int", err),
            ("x: bigint", err),
            ("x: string", [Row(x="hello")]),
            ("x: date", err),
            ("x: timestamp", err),
            ("x: byte", err),
            ("x: binary", [Row(x=bytearray(b"hello"))]),
            ("x: float", err),
            ("x: double", err),
            ("x: decimal(10, 0)", err),
            ("x: array<string>", [Row(x=["h", "e", "l", "l", "o"])]),
            ("x: array<int>", err),
            ("x: map<string,int>", err),
            ("x: struct<a:int>", err),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)

    def test_array_output_type_casting(self):
        """Casting a mixed-numeric list output value to each declared return
        type; element-wise casts apply for array targets."""

        class TestUDTF:
            def eval(self):
                yield [0, 1.1, 2],

        err = "UDTF_ARROW_TYPE_CAST_ERROR"

        for ret_type, expected in [
            ("x: boolean", err),
            ("x: tinyint", err),
            ("x: smallint", err),
            ("x: int", err),
            ("x: bigint", err),
            ("x: string", err),
            ("x: date", err),
            ("x: timestamp", err),
            ("x: byte", err),
            ("x: binary", err),
            ("x: float", err),
            ("x: double", err),
            ("x: decimal(10, 0)", err),
            ("x: array<string>", [Row(x=["0", "1.1", "2"])]),
            ("x: array<boolean>", [Row(x=[False, True, True])]),
            ("x: array<int>", [Row(x=[0, 1, 2])]),
            ("x: array<float>", [Row(x=[0, 1.1, 2])]),
            ("x: array<array<int>>", err),
            ("x: map<string,int>", err),
            ("x: struct<a:int>", err),
            ("x: struct<a:int,b:int,c:int>", err),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)

    def test_map_output_type_casting(self):
        """Casting a dict output value to each declared return type; a dict can
        become a map (value-wise cast), its keys as array<string>, or a struct."""

        class TestUDTF:
            def eval(self):
                yield {"a": 0, "b": 1.1, "c": 2},

        err = "UDTF_ARROW_TYPE_CAST_ERROR"

        for ret_type, expected in [
            ("x: boolean", err),
            ("x: tinyint", err),
            ("x: smallint", err),
            ("x: int", err),
            ("x: bigint", err),
            ("x: string", err),
            ("x: date", err),
            ("x: timestamp", err),
            ("x: byte", err),
            ("x: binary", err),
            ("x: float", err),
            ("x: double", err),
            ("x: decimal(10, 0)", err),
            ("x: array<string>", [Row(x=["a", "b", "c"])]),
            ("x: map<string,string>", err),
            ("x: map<string,boolean>", err),
            ("x: map<string,int>", [Row(x={"a": 0, "b": 1, "c": 2})]),
            ("x: map<string,float>", [Row(x={"a": 0, "b": 1.1, "c": 2})]),
            ("x: map<string,map<string,int>>", err),
            ("x: struct<a:int>", [Row(x=Row(a=0))]),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)

    def test_struct_output_type_casting_dict(self):
        """Casting a dict output value to struct return types (field-wise cast,
        matched by key name)."""

        class TestUDTF:
            def eval(self):
                yield {"a": 0, "b": 1.1, "c": 2},

        err = "UDTF_ARROW_TYPE_CAST_ERROR"

        for ret_type, expected in [
            ("x: boolean", err),
            ("x: tinyint", err),
            ("x: smallint", err),
            ("x: int", err),
            ("x: bigint", err),
            ("x: string", err),
            ("x: date", err),
            ("x: timestamp", err),
            ("x: byte", err),
            ("x: binary", err),
            ("x: float", err),
            ("x: double", err),
            ("x: decimal(10, 0)", err),
            ("x: array<string>", [Row(x=["a", "b", "c"])]),
            ("x: map<string,string>", err),
            ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", c="2"))]),
            ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=1, c=2))]),
            ("x: struct<a:float,b:float,c:float>", [Row(Row(a=0, b=1.1, c=2))]),
            ("x: struct<a:struct<>,b:struct<>,c:struct<>>", err),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)

    def test_struct_output_type_casting_row(self):
        """Casting a Row output value to struct (and other) return types; mirrors
        the dict case above, except a Row casts to array<string> by its values."""

        class TestUDTF:
            def eval(self):
                yield Row(a=0, b=1.1, c=2),

        err = "UDTF_ARROW_TYPE_CAST_ERROR"

        for ret_type, expected in [
            ("x: boolean", err),
            ("x: tinyint", err),
            ("x: smallint", err),
            ("x: int", err),
            ("x: bigint", err),
            ("x: string", err),
            ("x: date", err),
            ("x: timestamp", err),
            ("x: byte", err),
            ("x: binary", err),
            ("x: float", err),
            ("x: double", err),
            ("x: decimal(10, 0)", err),
            ("x: array<string>", [Row(x=["0", "1.1", "2"])]),
            ("x: map<string,string>", err),
            ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", c="2"))]),
            ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=1, c=2))]),
            ("x: struct<a:float,b:float,c:float>", [Row(Row(a=0, b=1.1, c=2))]),
            ("x: struct<a:struct<>,b:struct<>,c:struct<>>", err),
        ]:
            with self.subTest(ret_type=ret_type):
                self._check_result_or_exception(TestUDTF, ret_type, expected)

    def test_inconsistent_output_types(self):
        """Rows of incompatible Python types from the same UDTF cannot be cast to
        a single declared type and must raise UDTF_ARROW_TYPE_CAST_ERROR."""

        class TestUDTF:
            def eval(self):
                yield 1,
                yield [1, 2],

        for ret_type in [
            "x: int",
            "x: array<int>",
        ]:
            with self.subTest(ret_type=ret_type):
                with self.assertRaisesRegex(PythonException, "UDTF_ARROW_TYPE_CAST_ERROR"):
                    udtf(TestUDTF, returnType=ret_type)().collect()
| |
| |
class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
    """Runs the shared and Arrow-specific UDTF suites with Arrow-optimized
    execution enabled."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")

    @classmethod
    def tearDownClass(cls):
        # Always run the parent teardown, even if unsetting the conf fails.
        try:
            cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
        finally:
            super().tearDownClass()
| |
| |
if __name__ == "__main__":
    # Re-export the test classes under the canonical module name so the
    # unittest loader discovers them.
    from pyspark.sql.tests.test_udtf import *  # noqa: F401

    # Prefer the XML-report runner when available (used by Spark's CI);
    # fall back to unittest's default runner otherwise.
    runner = None
    try:
        import xmlrunner  # type: ignore
    except ImportError:
        pass
    else:
        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    unittest.main(testRunner=runner, verbosity=2)