| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| from contextlib import redirect_stdout |
| import datetime |
| from enum import Enum |
| from inspect import getmembers, isfunction, isclass |
| import io |
| from itertools import chain |
| import math |
| import re |
| import unittest |
| |
| from pyspark.errors import PySparkTypeError, PySparkValueError, SparkRuntimeException |
| from pyspark.sql import Row, Window, functions as F, types |
| from pyspark.sql.avro.functions import from_avro, to_avro |
| from pyspark.sql.column import Column |
| from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull |
| from pyspark.sql.types import StructType, StructField, StringType |
| from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils |
| from pyspark.testing.utils import have_numpy, assertDataFrameEqual |
| |
| |
| class FunctionsTestsMixin: |
| def test_function_parity(self): |
| # This test compares the available list of functions in pyspark.sql.functions with those |
| # available in the Scala/Java DataFrame API in org.apache.spark.sql.functions. |
| # |
| # NOTE FOR DEVELOPERS: |
| # If this test fails, one of the following needs to happen: |
| # * If a function was added to org.apache.spark.sql.functions, it needs to be added either to |
| # pyspark.sql.functions or to the expected_missing_in_py set below. |
| # * If a function was added to pyspark.sql.functions that was already in |
| # org.apache.spark.sql.functions, it needs to be removed from expected_missing_in_py |
| # below. If the function has a different name, it needs to be added to the py_equiv_jvm |
| # mapping. |
| # * If the failure is not related to an added or removed function, the exclusion list |
| # jvm_excluded_fn likely needs to be updated. |
| |
| jvm_fn_set = {name for (name, value) in getmembers(self.sc._jvm.functions)} |
| py_fn_set = {name for (name, value) in getmembers(F, isfunction) if name[0] != "_"} |
| |
| # Functions on the JVM side that we do not expect to be available in Python because they are |
| # deprecated, irrelevant to Python, or already have Python equivalents. |
| jvm_excluded_fn = [ |
| "callUDF", # depreciated, use call_udf |
| "typedlit", # Scala only |
| "typedLit", # Scala only |
| "monotonicallyIncreasingId", # depreciated, use monotonically_increasing_id |
| "not", # equivalent to python ~expression |
| "any", # equivalent to python ~some |
| "len", # equivalent to python ~length |
| "udaf", # used for creating UDAF's which are not supported in PySpark |
| "partitioning$", # partitioning expressions for DSv2 |
| ] |
| |
| jvm_fn_set.difference_update(jvm_excluded_fn) |
| |
| # For functions that are named differently in PySpark, this is the mapping from the |
| # Python name to the JVM equivalent. |
| py_equiv_jvm = {"create_map": "map"} |
| for py_name, jvm_name in py_equiv_jvm.items(): |
| if py_name in py_fn_set: |
| py_fn_set.remove(py_name) |
| py_fn_set.add(jvm_name) |
| |
| missing_in_py = jvm_fn_set.difference(py_fn_set) |
| |
| # Functions that we expect to be missing in Python until they are added to PySpark. |
| expected_missing_in_py = set( |
| # TODO(SPARK-53108): Implement the time_diff function in Python |
| ["time_diff"] |
| ) |
| |
| self.assertEqual( |
| expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected" |
| ) |
| |
| def test_wildcard_import(self): |
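| # Check that F.__all__ (the wildcard-import surface) covers every public function and class |
| # in pyspark.sql.functions, except for deprecated/conflicting names and classes that should |
| # be imported from other modules. |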
| all_set = set(F.__all__) |
| |
| # { |
| # "abs", |
| # "acos", |
| # "acosh", |
| # "add_months", |
| # "aes_decrypt", |
| # "aes_encrypt", |
| # ..., |
| # } |
| fn_set = { |
| name |
| for (name, value) in getmembers(F, isfunction) |
| if name[0] != "_" and value.__module__ != "typing" |
| } |
| |
| deprecated_fn_list = [ |
| "approxCountDistinct", # deprecated |
| "bitwiseNOT", # deprecated |
| "countDistinct", # deprecated |
| "chr", # name conflict with builtin function |
| "random", # name conflict with builtin function |
| "shiftLeft", # deprecated |
| "shiftRight", # deprecated |
| "shiftRightUnsigned", # deprecated |
| "sumDistinct", # deprecated |
| "toDegrees", # deprecated |
| "toRadians", # deprecated |
| "uuid", # name conflict with builtin module |
| ] |
| unregistered_fn_list = [ |
| "chr", # name conflict with builtin function |
| "random", # name conflict with builtin function |
| "uuid", # name conflict with builtin module |
| ] |
| expected_fn_all_diff = set(deprecated_fn_list + unregistered_fn_list) |
| self.assertEqual(expected_fn_all_diff, fn_set - all_set) |
| |
| # { |
| # "AnalyzeArgument", |
| # "AnalyzeResult", |
| # ..., |
| # "UserDefinedFunction", |
| # "UserDefinedTableFunction", |
| # } |
| clz_set = { |
| name |
| for (name, value) in getmembers(F, isclass) |
| if name[0] != "_" and value.__module__ != "typing" |
| } |
| |
| expected_clz_all_diff = { |
| "ArrayType", # should be imported from pyspark.sql.types |
| "ByteType", # should be imported from pyspark.sql.types |
| "Column", # should be imported from pyspark.sql |
| "DataType", # should be imported from pyspark.sql.types |
| "NumericType", # should be imported from pyspark.sql.types |
| "PySparkTypeError", # should be imported from pyspark.errors |
| "PySparkValueError", # should be imported from pyspark.errors |
| "StringType", # should be imported from pyspark.sql.types |
| "StructType", # should be imported from pyspark.sql.types |
| } |
| self.assertEqual(expected_clz_all_diff, clz_set - all_set) |
| |
| unknown_set = all_set - (fn_set | clz_set) |
| self.assertEqual(unknown_set, set()) |
| |
| def test_explode(self): |
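| # explode/posexplode turn each array element (or map entry) into its own row; the *_outer |
| # variants additionally emit a row of NULLs for empty or NULL collections. |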
| d = [ |
| Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}), |
| Row(a=1, intlist=[], mapfield={}), |
| Row(a=1, intlist=None, mapfield=None), |
| ] |
| data = self.spark.createDataFrame(d) |
| |
| result = data.select(F.explode(data.intlist).alias("a")).select("a").collect() |
| self.assertEqual(result[0][0], 1) |
| self.assertEqual(result[1][0], 2) |
| self.assertEqual(result[2][0], 3) |
| |
| result = data.select(F.explode(data.mapfield).alias("a", "b")).select("a", "b").collect() |
| self.assertEqual(result[0][0], "a") |
| self.assertEqual(result[0][1], "b") |
| |
| result = [tuple(x) for x in data.select(F.posexplode_outer("intlist")).collect()] |
| self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)]) |
| |
| result = [tuple(x) for x in data.select(F.posexplode_outer("mapfield")).collect()] |
| self.assertEqual(result, [(0, "a", "b"), (None, None, None), (None, None, None)]) |
| |
| result = [x[0] for x in data.select(F.explode_outer("intlist")).collect()] |
| self.assertEqual(result, [1, 2, 3, None, None]) |
| |
| result = [tuple(x) for x in data.select(F.explode_outer("mapfield")).collect()] |
| self.assertEqual(result, [("a", "b"), (None, None), (None, None)]) |
| |
| def test_inline(self): |
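| # inline expands an array of structs into rows with one column per struct field; |
| # inline_outer additionally produces a row of NULLs for an empty or NULL array. |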
| d = [ |
| Row(structlist=[Row(b=1, c=2), Row(b=3, c=4)]), |
| Row(structlist=[Row(b=None, c=5), None]), |
| Row(structlist=[]), |
| ] |
| data = self.spark.createDataFrame(d) |
| |
| result = [tuple(x) for x in data.select(F.inline(data.structlist)).collect()] |
| self.assertEqual(result, [(1, 2), (3, 4), (None, 5), (None, None)]) |
| |
| result = [tuple(x) for x in data.select(F.inline_outer(data.structlist)).collect()] |
| self.assertEqual(result, [(1, 2), (3, 4), (None, 5), (None, None), (None, None)]) |
| |
| def test_basic_functions(self): |
| rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) |
| df = self.spark.read.json(rdd) |
| df.count() |
| df.collect() |
| df.schema |
| |
| # cache and checkpoint |
| self.assertFalse(df.is_cached) |
| df.persist() |
| df.unpersist(True) |
| df.cache() |
| self.assertTrue(df.is_cached) |
| self.assertEqual(2, df.count()) |
| |
| with self.tempView("temp"): |
| df.createOrReplaceTempView("temp") |
| df = self.spark.sql("select foo from temp") |
| df.count() |
| df.collect() |
| |
| def test_corr(self): |
| df = self.spark.createDataFrame([Row(a=i, b=math.sqrt(i)) for i in range(10)]) |
| corr = df.stat.corr("a", "b") |
| self.assertTrue(abs(corr - 0.95734012) < 1e-6) |
| |
| def test_sampleby(self): |
| df = self.spark.createDataFrame([Row(a=i, b=(i % 3)) for i in range(100)]) |
| sampled = df.stat.sampleBy("b", fractions={0: 0.5, 1: 0.5}, seed=0) |
| self.assertTrue(35 <= sampled.count() <= 36) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.sampleBy(10, fractions={0: 0.5, 1: 0.5}) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_STR", |
| messageParameters={"arg_name": "col", "arg_type": "int"}, |
| ) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.sampleBy("b", fractions=[0.5, 0.5]) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_DICT", |
| messageParameters={"arg_name": "fractions", "arg_type": "list"}, |
| ) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.sampleBy("b", fractions={None: 0.5, 1: 0.5}) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="DISALLOWED_TYPE_FOR_CONTAINER", |
| messageParameters={ |
| "arg_name": "fractions", |
| "arg_type": "dict", |
| "allowed_types": "float, int, str", |
| "item_type": "NoneType", |
| }, |
| ) |
| |
| def test_cov(self): |
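| # cov(a, 2a) = 2 * var(a); for 0..9 the sample variance is 82.5 / 9 = 55/6, so the |
| # expected covariance is 55/3. |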
| df = self.spark.createDataFrame([Row(a=i, b=2 * i) for i in range(10)]) |
| cov = df.stat.cov("a", "b") |
| self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.stat.cov(10, "b") |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_STR", |
| messageParameters={"arg_name": "col1", "arg_type": "int"}, |
| ) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.stat.cov("a", True) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_STR", |
| messageParameters={"arg_name": "col2", "arg_type": "bool"}, |
| ) |
| |
| def test_crosstab(self): |
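| # crosstab builds a contingency table of a against b; for i in 1..6 every (a, b) |
| # combination occurs exactly once. |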
| df = self.spark.createDataFrame([Row(a=i % 3, b=i % 2) for i in range(1, 7)]) |
| ct = df.stat.crosstab("a", "b").collect() |
| ct = sorted(ct, key=lambda x: x[0]) |
| for i, row in enumerate(ct): |
| self.assertEqual(row[0], str(i)) |
| self.assertEqual(row[1], 1) |
| self.assertEqual(row[2], 1) |
| |
| def test_math_functions(self): |
| df = self.spark.createDataFrame([Row(a=i, b=2 * i) for i in range(10)]) |
| |
| SQLTestUtils.assert_close( |
| [math.cos(i) for i in range(10)], df.select(F.cos(df.a)).collect() |
| ) |
| SQLTestUtils.assert_close([math.cos(i) for i in range(10)], df.select(F.cos("a")).collect()) |
| SQLTestUtils.assert_close( |
| [math.sin(i) for i in range(10)], df.select(F.sin(df.a)).collect() |
| ) |
| SQLTestUtils.assert_close( |
| [math.sin(i) for i in range(10)], df.select(F.sin(df["a"])).collect() |
| ) |
| SQLTestUtils.assert_close( |
| [math.pow(i, 2 * i) for i in range(10)], df.select(F.pow(df.a, df.b)).collect() |
| ) |
| SQLTestUtils.assert_close( |
| [math.pow(i, 2) for i in range(10)], df.select(F.pow(df.a, 2)).collect() |
| ) |
| SQLTestUtils.assert_close( |
| [math.pow(i, 2) for i in range(10)], df.select(F.pow(df.a, 2.0)).collect() |
| ) |
| SQLTestUtils.assert_close( |
| [math.hypot(i, 2 * i) for i in range(10)], df.select(F.hypot(df.a, df.b)).collect() |
| ) |
| SQLTestUtils.assert_close( |
| [math.hypot(i, 2 * i) for i in range(10)], df.select(F.hypot("a", "b")).collect() |
| ) |
| SQLTestUtils.assert_close( |
| [math.hypot(i, 2) for i in range(10)], df.select(F.hypot("a", 2)).collect() |
| ) |
| SQLTestUtils.assert_close( |
| [math.hypot(i, 2) for i in range(10)], df.select(F.hypot(df.a, 2)).collect() |
| ) |
| |
| def test_inverse_trig_functions(self): |
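| # Round-trip each hyperbolic function through its inverse. cosh is even, so |
| # acosh(cosh(x)) recovers |x| and the negative column maps back to positive values. |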
| df = self.spark.createDataFrame([Row(a=i * 0.2, b=i * -0.2) for i in range(10)]) |
| |
| def check(trig, inv, y_axis_symmetrical): |
| SQLTestUtils.assert_close( |
| [n * 0.2 for n in range(10)], |
| df.select(inv(trig(df.a))).collect(), |
| ) |
| if y_axis_symmetrical: |
| SQLTestUtils.assert_close( |
| [n * 0.2 for n in range(10)], |
| df.select(inv(trig(df.b))).collect(), |
| ) |
| else: |
| SQLTestUtils.assert_close( |
| [n * -0.2 for n in range(10)], |
| df.select(inv(trig(df.b))).collect(), |
| ) |
| |
| check(F.cosh, F.acosh, y_axis_symmetrical=True) |
| check(F.sinh, F.asinh, y_axis_symmetrical=False) |
| check(F.tanh, F.atanh, y_axis_symmetrical=False) |
| |
| def test_reciprocal_trig_functions(self): |
| # SPARK-36683: Tests for reciprocal trig functions (SEC, CSC and COT) |
| lst = [ |
| 0.0, |
| math.pi / 6, |
| math.pi / 4, |
| math.pi / 3, |
| math.pi / 2, |
| math.pi, |
| 3 * math.pi / 2, |
| 2 * math.pi, |
| ] |
| |
| df = self.spark.createDataFrame(lst, types.DoubleType()) |
| |
| def to_reciprocal_trig(func): |
| return [1.0 / func(i) if func(i) != 0 else math.inf for i in lst] |
| |
| SQLTestUtils.assert_close( |
| to_reciprocal_trig(math.cos), df.select(F.sec(df.value)).collect() |
| ) |
| SQLTestUtils.assert_close( |
| to_reciprocal_trig(math.sin), df.select(F.csc(df.value)).collect() |
| ) |
| SQLTestUtils.assert_close( |
| to_reciprocal_trig(math.tan), df.select(F.cot(df.value)).collect() |
| ) |
| |
| def test_rand_functions(self): |
| df = self.spark.createDataFrame([Row(key=i, value=str(i)) for i in range(100)]) |
| |
| rnd = df.select("key", F.rand()).collect() |
| for row in rnd: |
| assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1] |
| rndn = df.select("key", F.randn(5)).collect() |
| for row in rndn: |
| assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] |
| |
| # If the specified seed is 0, we should use it. |
| # https://issues.apache.org/jira/browse/SPARK-9691 |
| rnd1 = df.select("key", F.rand(0)).collect() |
| rnd2 = df.select("key", F.rand(0)).collect() |
| self.assertEqual(sorted(rnd1), sorted(rnd2)) |
| |
| rndn1 = df.select("key", F.randn(0)).collect() |
| rndn2 = df.select("key", F.randn(0)).collect() |
| self.assertEqual(sorted(rndn1), sorted(rndn2)) |
| |
| def test_time_trunc(self): |
| # SPARK-53110: test the time_trunc function. |
| df = self.spark.range(1).select( |
| F.lit("minute").alias("unit"), F.lit(datetime.time(1, 2, 3)).alias("time") |
| ) |
| result = datetime.time(1, 2, 0) |
| row_from_col = df.select(F.time_trunc(df.unit, df.time)).first() |
| self.assertIsInstance(row_from_col[0], datetime.time) |
| self.assertEqual(row_from_col[0], result) |
| row_from_name = df.select(F.time_trunc("unit", "time")).first() |
| self.assertIsInstance(row_from_name[0], datetime.time) |
| self.assertEqual(row_from_name[0], result) |
| |
| def test_try_parse_url(self): |
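| # try_parse_url extracts a URL component like parse_url, but returns NULL for an invalid |
| # URL instead of raising an error. |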
| df = self.spark.createDataFrame( |
| [("https://spark.apache.org/path?query=1", "QUERY", "query")], |
| ["url", "part", "key"], |
| ) |
| actual = df.select(F.try_parse_url(df.url, df.part, df.key)) |
| assertDataFrameEqual(actual, [Row("1")]) |
| df = self.spark.createDataFrame( |
| [("inva lid://spark.apache.org/path?query=1", "QUERY", "query")], |
| ["url", "part", "key"], |
| ) |
| actual = df.select(F.try_parse_url(df.url, df.part, df.key)) |
| assertDataFrameEqual(actual, [Row(None)]) |
| |
| def test_try_make_timestamp(self): |
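| # try_make_timestamp returns NULL instead of raising an error when a field is out of |
| # range (month=13 below). |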
| data = [(2024, 5, 22, 10, 30, 0)] |
| df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"]) |
| actual = df.select( |
| F.try_make_timestamp(df.year, df.month, df.day, df.hour, df.minute, df.second) |
| ) |
| assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))]) |
| |
| data = [(2024, 13, 22, 10, 30, 0)] |
| df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"]) |
| actual = df.select( |
| F.try_make_timestamp(df.year, df.month, df.day, df.hour, df.minute, df.second) |
| ) |
| assertDataFrameEqual(actual, [Row(None)]) |
| |
| def test_try_make_timestamp_ltz(self): |
| # use local timezone here to avoid flakiness |
| data = [(2024, 5, 22, 10, 30, 0, str(datetime.datetime.now().astimezone().tzinfo))] |
| df = self.spark.createDataFrame( |
| data, ["year", "month", "day", "hour", "minute", "second", "timezone"] |
| ) |
| actual = df.select( |
| F.try_make_timestamp_ltz( |
| df.year, df.month, df.day, df.hour, df.minute, df.second, df.timezone |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30, 0))]) |
| |
| # use local timezone here to avoid flakiness |
| data = [(2024, 13, 22, 10, 30, 0, str(datetime.datetime.now().astimezone().tzinfo))] |
| df = self.spark.createDataFrame( |
| data, ["year", "month", "day", "hour", "minute", "second", "timezone"] |
| ) |
| actual = df.select( |
| F.try_make_timestamp_ltz( |
| df.year, df.month, df.day, df.hour, df.minute, df.second, df.timezone |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(None)]) |
| |
| def test_try_make_timestamp_ntz(self): |
| """Test cases for try_make_timestamp_ntz with 6-parameter and date/time forms.""" |
| |
| # Test 1: Valid 6 positional arguments |
| data = [(2024, 5, 22, 10, 30, 0)] |
| result = datetime.datetime(2024, 5, 22, 10, 30) |
| df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"]) |
| actual = df.select( |
| F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second) |
| ) |
| assertDataFrameEqual(actual, [Row(result)]) |
| |
| # Test 2: Invalid input (month=13) - should return NULL |
| data = [(2024, 13, 22, 10, 30, 0)] |
| df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"]) |
| actual = df.select( |
| F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second) |
| ) |
| assertDataFrameEqual(actual, [Row(None)]) |
| |
| # Test 3: Date/time keyword arguments |
| df = self.spark.range(1).select( |
| F.lit(datetime.date(2024, 5, 22)).alias("date"), |
| F.lit(datetime.time(10, 30, 0)).alias("time"), |
| ) |
| actual = df.select(F.try_make_timestamp_ntz(date=df.date, time=df.time)) |
| assertDataFrameEqual(actual, [Row(result)]) |
| |
| # Test 4: All 6 keyword arguments |
| df_full = self.spark.createDataFrame( |
| [(2024, 5, 22, 10, 30, 45)], ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| actual = df_full.select( |
| F.try_make_timestamp_ntz( |
| years=df_full.year, |
| months=df_full.month, |
| days=df_full.day, |
| hours=df_full.hour, |
| mins=df_full.minute, |
| secs=df_full.second, |
| ) |
| ) |
| expected = datetime.datetime(2024, 5, 22, 10, 30, 45) |
| assertDataFrameEqual(actual, [Row(expected)]) |
| |
| # Test 5: Only year provided - should raise Exception for missing required parameters |
| with self.assertRaises(Exception): |
| F.try_make_timestamp_ntz(years=df_full.year) |
| |
| # Test 6: Partial parameters - should raise Exception for missing required parameters |
| with self.assertRaises(Exception): |
| F.try_make_timestamp_ntz(years=df_full.year, months=df_full.month, days=df_full.day) |
| |
| # Test 7: Partial parameters - should raise Exception for missing required parameters |
| with self.assertRaises(Exception): |
| F.try_make_timestamp_ntz( |
| years=df_full.year, months=df_full.month, days=df_full.day, hours=df_full.hour |
| ) |
| |
| # Test 8: Fractional seconds |
| df_frac = self.spark.createDataFrame( |
| [(2024, 5, 22, 10, 30, 45.123)], ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| actual = df_frac.select( |
| F.try_make_timestamp_ntz( |
| df_frac.year, |
| df_frac.month, |
| df_frac.day, |
| df_frac.hour, |
| df_frac.minute, |
| df_frac.second, |
| ) |
| ) |
| expected_frac = datetime.datetime(2024, 5, 22, 10, 30, 45, 123000) |
| assertDataFrameEqual(actual, [Row(expected_frac)]) |
| |
| # Test 9: Edge case - February 29 in leap year (full 6 parameters) |
| df_leap = self.spark.createDataFrame( |
| [(2024, 2, 29, 0, 0, 0)], ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| actual = df_leap.select( |
| F.try_make_timestamp_ntz( |
| df_leap.year, |
| df_leap.month, |
| df_leap.day, |
| df_leap.hour, |
| df_leap.minute, |
| df_leap.second, |
| ) |
| ) |
| expected_leap = datetime.datetime(2024, 2, 29, 0, 0, 0) |
| assertDataFrameEqual(actual, [Row(expected_leap)]) |
| |
| # Test 10: Edge case - February 29 in non-leap year (should return NULL) |
| df_non_leap = self.spark.createDataFrame( |
| [(2023, 2, 29, 0, 0, 0)], ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| actual = df_non_leap.select( |
| F.try_make_timestamp_ntz( |
| df_non_leap.year, |
| df_non_leap.month, |
| df_non_leap.day, |
| df_non_leap.hour, |
| df_non_leap.minute, |
| df_non_leap.second, |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(None)]) |
| |
| # Test 11: Minimum valid values (full 6 parameters) |
| df_min = self.spark.createDataFrame( |
| [(1, 1, 1, 0, 0, 0)], ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| actual = df_min.select( |
| F.try_make_timestamp_ntz( |
| df_min.year, df_min.month, df_min.day, df_min.hour, df_min.minute, df_min.second |
| ) |
| ) |
| expected_min = datetime.datetime(1, 1, 1, 0, 0, 0) |
| assertDataFrameEqual(actual, [Row(expected_min)]) |
| |
| # Test 12: Maximum valid hour/minute/second |
| df_max_time = self.spark.createDataFrame( |
| [(2024, 5, 22, 23, 59, 59)], ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| actual = df_max_time.select( |
| F.try_make_timestamp_ntz( |
| df_max_time.year, |
| df_max_time.month, |
| df_max_time.day, |
| df_max_time.hour, |
| df_max_time.minute, |
| df_max_time.second, |
| ) |
| ) |
| expected_max_time = datetime.datetime(2024, 5, 22, 23, 59, 59) |
| assertDataFrameEqual(actual, [Row(expected_max_time)]) |
| |
| # Test 13: Invalid hour (should return NULL) |
| df_invalid_hour = self.spark.createDataFrame( |
| [(2024, 5, 22, 25, 0, 0)], ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| actual = df_invalid_hour.select( |
| F.try_make_timestamp_ntz( |
| df_invalid_hour.year, |
| df_invalid_hour.month, |
| df_invalid_hour.day, |
| df_invalid_hour.hour, |
| df_invalid_hour.minute, |
| df_invalid_hour.second, |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(None)]) |
| |
| # Test 14: Valid date/time combination with NULL date |
| df = self.spark.range(1).select( |
| F.lit(None).cast("date").alias("date"), F.lit(datetime.time(10, 30, 0)).alias("time") |
| ) |
| actual = df.select(F.try_make_timestamp_ntz(date=df.date, time=df.time)) |
| assertDataFrameEqual(actual, [Row(None)]) |
| |
| # Test 15: Valid date/time combination with NULL time |
| df = self.spark.range(1).select( |
| F.lit(datetime.date(2024, 5, 22)).alias("date"), F.lit(None).cast("time").alias("time") |
| ) |
| actual = df.select(F.try_make_timestamp_ntz(date=df.date, time=df.time)) |
| assertDataFrameEqual(actual, [Row(None)]) |
| |
| # Test 16: Mixed parameter usage should raise PySparkValueError |
| with self.assertRaises(PySparkValueError) as context: |
| F.try_make_timestamp_ntz(years=df_full.year, date=df_full.year) |
| error_msg = str(context.exception) |
| self.assertIn("CANNOT_SET_TOGETHER", error_msg) |
| self.assertIn("years|months|days|hours|mins|secs and date|time", error_msg) |
| |
| def test_string_functions(self): |
| string_functions = [ |
| "upper", |
| "lower", |
| "ascii", |
| "base64", |
| "unbase64", |
| "ltrim", |
| "rtrim", |
| "trim", |
| ] |
| |
| df = self.spark.createDataFrame([["nick"]], schema=["name"]) |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.col("name").substr(0, F.lit(1)) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_SAME_TYPE", |
| messageParameters={ |
| "arg_name1": "startPos", |
| "arg_name2": "length", |
| "arg_type1": "int", |
| "arg_type2": "Column", |
| }, |
| ) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.col("name").substr("", "") |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_INT", |
| messageParameters={ |
| "arg_name": "startPos", |
| "arg_type": "str", |
| }, |
| ) |
| |
| for name in string_functions: |
| assertDataFrameEqual( |
| df.select(getattr(F, name)("name")), |
| df.select(getattr(F, name)(F.col("name"))), |
| ) |
| |
| def test_collation(self): |
| df = self.spark.createDataFrame([("a",), ("b",)], ["name"]) |
| actual = df.select(F.collation(F.collate("name", "UNICODE"))).distinct() |
| assertDataFrameEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual) |
| |
| def test_try_make_interval(self): |
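| # try_make_interval returns NULL instead of raising an error when the interval overflows |
| # (2147483647 years here). |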
| df = self.spark.createDataFrame([(2147483647,)], ["num"]) |
| actual = df.select(F.isnull(F.try_make_interval("num"))) |
| assertDataFrameEqual([Row(True)], actual) |
| |
| def test_octet_length_function(self): |
| # SPARK-36751: add octet length api for python |
| df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"]) |
| actual = df.select(F.octet_length("cat")) |
| assertDataFrameEqual([Row(3), Row(4)], actual) |
| |
| def test_bit_length_function(self): |
| # SPARK-36751: add bit length api for python |
| df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"]) |
| actual = df.select(F.bit_length("cat")) |
| assertDataFrameEqual([Row(24), Row(32)], actual) |
| |
| def test_array_contains_function(self): |
| df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ["data"]) |
| actual = df.select(F.array_contains(df.data, "1").alias("b")) |
| assertDataFrameEqual([Row(b=True), Row(b=False)], actual) |
| |
| def test_levenshtein_function(self): |
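| # levenshtein returns the edit distance; with a threshold, it returns -1 when the |
| # distance exceeds the threshold. |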
| df = self.spark.createDataFrame([("kitten", "sitting")], ["l", "r"]) |
| actual_without_threshold = df.select(F.levenshtein(df.l, df.r).alias("b")) |
| assertDataFrameEqual([Row(b=3)], actual_without_threshold) |
| actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 2).alias("b")) |
| assertDataFrameEqual([Row(b=-1)], actual_with_threshold) |
| |
| def test_between_function(self): |
| df = self.spark.createDataFrame( |
| [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)] |
| ) |
| assertDataFrameEqual( |
| [Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c)) |
| ) |
| |
| def test_dayofweek(self): |
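| # dayofweek numbers days 1 (Sunday) through 7 (Saturday); 2017-11-06 is a Monday, hence 2. |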
| dt = datetime.datetime(2017, 11, 6) |
| df = self.spark.createDataFrame([Row(date=dt)]) |
| row = df.select(F.dayofweek(df.date)).first() |
| self.assertEqual(row[0], 2) |
| |
| def test_monthname(self): |
| dt = datetime.datetime(2017, 11, 6) |
| df = self.spark.createDataFrame([Row(date=dt)]) |
| row = df.select(F.monthname(df.date)).first() |
| self.assertEqual(row[0], "Nov") |
| |
| def test_dayname(self): |
| dt = datetime.datetime(2017, 11, 6) |
| df = self.spark.createDataFrame([Row(date=dt)]) |
| row = df.select(F.dayname(df.date)).first() |
| self.assertEqual(row[0], "Mon") |
| |
| def test_hour(self): |
| # SPARK-52892: test the hour function with time. |
| df = self.spark.range(1).select(F.lit(datetime.time(12, 34, 56)).alias("time")) |
| row_from_col = df.select(F.hour(df.time)).first() |
| self.assertEqual(row_from_col[0], 12) |
| row_from_name = df.select(F.hour("time")).first() |
| self.assertEqual(row_from_name[0], 12) |
| |
| def test_minute(self): |
| # SPARK-52893: test the minute function with time. |
| df = self.spark.range(1).select(F.lit(datetime.time(12, 34, 56)).alias("time")) |
| row_from_col = df.select(F.minute(df.time)).first() |
| self.assertEqual(row_from_col[0], 34) |
| row_from_name = df.select(F.minute("time")).first() |
| self.assertEqual(row_from_name[0], 34) |
| |
| def test_second(self): |
| # SPARK-52894: test the second function with time. |
| df = self.spark.range(1).select(F.lit(datetime.time(12, 34, 56)).alias("time")) |
| row_from_col = df.select(F.second(df.time)).first() |
| self.assertEqual(row_from_col[0], 56) |
| row_from_name = df.select(F.second("time")).first() |
| self.assertEqual(row_from_name[0], 56) |
| |
| # Test added for SPARK-37738; change Python API to accept both col & int as input |
| def test_date_add_function(self): |
| dt = datetime.date(2021, 12, 27) |
| |
| # Note: an int in Python gets converted to a LongType column, which is not |
| # supported by date_add, so cast to integer explicitly. |
| df = self.spark.createDataFrame([Row(date=dt, add=2)], "date date, add integer") |
| |
| self.assertTrue( |
| all( |
| df.select( |
| F.date_add(df.date, df.add) == datetime.date(2021, 12, 29), |
| F.date_add(df.date, "add") == datetime.date(2021, 12, 29), |
| F.date_add(df.date, 3) == datetime.date(2021, 12, 30), |
| ).first() |
| ) |
| ) |
| |
| # Test added for SPARK-37738; change Python API to accept both col & int as input |
| def test_date_sub_function(self): |
| dt = datetime.date(2021, 12, 27) |
| |
| # Note: an int in Python gets converted to a LongType column, which is not |
| # supported by date_sub, so cast to integer explicitly. |
| df = self.spark.createDataFrame([Row(date=dt, sub=2)], "date date, sub integer") |
| |
| self.assertTrue( |
| all( |
| df.select( |
| F.date_sub(df.date, df.sub) == datetime.date(2021, 12, 25), |
| F.date_sub(df.date, "sub") == datetime.date(2021, 12, 25), |
| F.date_sub(df.date, 3) == datetime.date(2021, 12, 24), |
| ).first() |
| ) |
| ) |
| |
| # Test added for SPARK-37738; change Python API to accept both col & int as input |
| def test_add_months_function(self): |
| dt = datetime.date(2021, 12, 27) |
| |
| # Note: an int in Python gets converted to a LongType column, which is not |
| # supported by add_months, so cast to integer explicitly. |
| df = self.spark.createDataFrame([Row(date=dt, add=2)], "date date, add integer") |
| |
| self.assertTrue( |
| all( |
| df.select( |
| F.add_months(df.date, df.add) == datetime.date(2022, 2, 27), |
| F.add_months(df.date, "add") == datetime.date(2022, 2, 27), |
| F.add_months(df.date, 3) == datetime.date(2022, 3, 27), |
| ).first() |
| ) |
| ) |
| |
| def test_make_time(self): |
| # SPARK-52888: test the make_time function. |
| df = self.spark.createDataFrame([(1, 2, 3)], ["hour", "minute", "second"]) |
| result = datetime.time(1, 2, 3) |
| row_from_col = df.select(F.make_time(df.hour, df.minute, df.second)).first() |
| self.assertIsInstance(row_from_col[0], datetime.time) |
| self.assertEqual(row_from_col[0], result) |
| row_from_name = df.select(F.make_time("hour", "minute", "second")).first() |
| self.assertIsInstance(row_from_name[0], datetime.time) |
| self.assertEqual(row_from_name[0], result) |
| |
| def test_make_timestamp_ntz(self): |
| """Comprehensive test cases for make_timestamp_ntz with various arguments and edge cases.""" |
| |
| # Test 1: Basic 6 positional arguments |
| data = [(2024, 5, 22, 10, 30, 0)] |
| result = datetime.datetime(2024, 5, 22, 10, 30) |
| df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"]) |
| |
| actual = df.select( |
| F.make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second) |
| ) |
| assertDataFrameEqual(actual, [Row(result)]) |
| |
| # Test 2: All 6 keyword arguments |
| actual = df.select( |
| F.make_timestamp_ntz( |
| years=df.year, |
| months=df.month, |
| days=df.day, |
| hours=df.hour, |
| mins=df.minute, |
| secs=df.second, |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(result)]) |
| |
| # Test 3: Date/time keyword arguments |
| df_dt = self.spark.range(1).select( |
| F.lit(datetime.date(2024, 5, 22)).alias("date"), |
| F.lit(datetime.time(10, 30, 0)).alias("time"), |
| ) |
| actual = df_dt.select(F.make_timestamp_ntz(date=df_dt.date, time=df_dt.time)) |
| assertDataFrameEqual(actual, [Row(result)]) |
| |
| # Test 4: Fractional seconds with positional arguments |
| data_frac = [(2024, 5, 22, 10, 30, 45.123)] |
| result_frac = datetime.datetime(2024, 5, 22, 10, 30, 45, 123000) |
| df_frac = self.spark.createDataFrame( |
| data_frac, ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| |
| actual = df_frac.select( |
| F.make_timestamp_ntz( |
| df_frac.year, |
| df_frac.month, |
| df_frac.day, |
| df_frac.hour, |
| df_frac.minute, |
| df_frac.second, |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(result_frac)]) |
| |
| # Test 5: Fractional seconds with keyword arguments |
| actual = df_frac.select( |
| F.make_timestamp_ntz( |
| years=df_frac.year, |
| months=df_frac.month, |
| days=df_frac.day, |
| hours=df_frac.hour, |
| mins=df_frac.minute, |
| secs=df_frac.second, |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(result_frac)]) |
| |
| # Test 6: Fractional seconds with date/time arguments |
| df_dt_frac = self.spark.range(1).select( |
| F.lit(datetime.date(2024, 5, 22)).alias("date"), |
| F.lit(datetime.time(10, 30, 45, 123000)).alias("time"), |
| ) |
| actual = df_dt_frac.select(F.make_timestamp_ntz(date=df_dt_frac.date, time=df_dt_frac.time)) |
| assertDataFrameEqual(actual, [Row(result_frac)]) |
| |
| # Test 7: Edge case - February 29 in leap year |
| df_leap = self.spark.createDataFrame( |
| [(2024, 2, 29, 0, 0, 0)], ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| expected_leap = datetime.datetime(2024, 2, 29, 0, 0, 0) |
| actual = df_leap.select( |
| F.make_timestamp_ntz( |
| df_leap.year, |
| df_leap.month, |
| df_leap.day, |
| df_leap.hour, |
| df_leap.minute, |
| df_leap.second, |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(expected_leap)]) |
| |
| # Test 8: Maximum valid time values |
| df_max = self.spark.createDataFrame( |
| [(2024, 12, 31, 23, 59, 59)], ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| expected_max = datetime.datetime(2024, 12, 31, 23, 59, 59) |
| actual = df_max.select( |
| F.make_timestamp_ntz( |
| df_max.year, df_max.month, df_max.day, df_max.hour, df_max.minute, df_max.second |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(expected_max)]) |
| |
| # Test 9: Minimum valid values |
| df_min = self.spark.createDataFrame( |
| [(1, 1, 1, 0, 0, 0)], ["year", "month", "day", "hour", "minute", "second"] |
| ) |
| expected_min = datetime.datetime(1, 1, 1, 0, 0, 0) |
| actual = df_min.select( |
| F.make_timestamp_ntz( |
| df_min.year, df_min.month, df_min.day, df_min.hour, df_min.minute, df_min.second |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(expected_min)]) |
| |
| # Test 10: Mixed positional and keyword (should work for valid combinations) |
| actual = df.select( |
| F.make_timestamp_ntz( |
| df.year, df.month, df.day, hours=df.hour, mins=df.minute, secs=df.second |
| ) |
| ) |
| assertDataFrameEqual(actual, [Row(result)]) |
| |
| # Test 11: Using literal values |
| actual = self.spark.range(1).select( |
| F.make_timestamp_ntz(F.lit(2024), F.lit(5), F.lit(22), F.lit(10), F.lit(30), F.lit(0)) |
| ) |
| assertDataFrameEqual(actual, [Row(result)]) |
| |
| # Test 12: Using string column names |
| actual = df.select(F.make_timestamp_ntz("year", "month", "day", "hour", "minute", "second")) |
| assertDataFrameEqual(actual, [Row(result)]) |
| |
| # Error handling tests |
| |
| # Test 13: Mixing timestamp and date/time keyword arguments |
| with self.assertRaises(PySparkValueError) as context: |
| df_dt.select( |
| F.make_timestamp_ntz(years=df.year, date=df_dt.date, time=df_dt.time) |
| ).collect() |
| error_msg = str(context.exception) |
| self.assertIn("CANNOT_SET_TOGETHER", error_msg) |
| self.assertIn("years|months|days|hours|mins|secs and date|time", error_msg) |
| |
| # Test 14: Incomplete keyword arguments - should raise Exception for None values |
| with self.assertRaises(Exception): |
| F.make_timestamp_ntz(years=df.year, months=df.month, days=df.day) |
| |
| # Test 15: Only one keyword argument - should raise Exception for None values |
| with self.assertRaises(Exception): |
| F.make_timestamp_ntz(years=df.year) |
| |
| # Test 16: Only date without time - should raise Exception for None values |
| with self.assertRaises(Exception): |
| F.make_timestamp_ntz(date=df_dt.date) |
| |
| def test_make_date(self): |
| # SPARK-36554: expose make_date expression |
| df = self.spark.createDataFrame([(2020, 6, 26)], ["Y", "M", "D"]) |
| row_from_col = df.select(F.make_date(df.Y, df.M, df.D)).first() |
| self.assertEqual(row_from_col[0], datetime.date(2020, 6, 26)) |
| row_from_name = df.select(F.make_date("Y", "M", "D")).first() |
| self.assertEqual(row_from_name[0], datetime.date(2020, 6, 26)) |
| |
| def test_expr(self): |
| row = Row(a="length string", b=75) |
| df = self.spark.createDataFrame([row]) |
| result = df.select(F.expr("length(a)")).collect()[0].asDict() |
| self.assertEqual(13, result["length(a)"]) |
| |
| # add test for SPARK-10577 (test broadcast join hint) |
| def test_functions_broadcast(self): |
| df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) |
| df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) |
| |
| # equijoin - should be converted into broadcast join |
| with io.StringIO() as buf, redirect_stdout(buf): |
| df1.join(F.broadcast(df2), "key").explain(True) |
| self.assertGreaterEqual(buf.getvalue().count("Broadcast"), 1) |
| |
| # no equijoin key -- the broadcast hint still yields a broadcast nested loop join |
| with io.StringIO() as buf, redirect_stdout(buf): |
| df1.crossJoin(F.broadcast(df2)).explain(True) |
| self.assertGreaterEqual(buf.getvalue().count("Broadcast"), 1) |
| |
| # planner should not crash without a join |
| F.broadcast(df1).explain(True) |
| |
| def test_first_last_ignorenulls(self): |
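| # first/last take an ignorenulls flag: with False the result may be NULL, with True NULL |
| # values are skipped (here ids divisible by 3 are NULL, so 1 and 98 are returned). |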
| df = self.spark.range(0, 100) |
| df2 = df.select(F.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) |
| df3 = df2.select( |
| F.first(df2.id, False).alias("a"), |
| F.first(df2.id, True).alias("b"), |
| F.last(df2.id, False).alias("c"), |
| F.last(df2.id, True).alias("d"), |
| ) |
| assertDataFrameEqual([Row(a=None, b=1, c=None, d=98)], df3) |
| |
| def test_approxQuantile(self): |
| df = self.spark.createDataFrame([Row(a=i, b=i + 10) for i in range(10)]) |
| for f in ["a", "a"]: |
| aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1) |
| self.assertTrue(isinstance(aq, list)) |
| self.assertEqual(len(aq), 3) |
| self.assertTrue(all(isinstance(q, float) for q in aq)) |
| aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1) |
| self.assertTrue(isinstance(aqs, list)) |
| self.assertEqual(len(aqs), 2) |
| self.assertTrue(isinstance(aqs[0], list)) |
| self.assertEqual(len(aqs[0]), 3) |
| self.assertTrue(all(isinstance(q, float) for q in aqs[0])) |
| self.assertTrue(isinstance(aqs[1], list)) |
| self.assertEqual(len(aqs[1]), 3) |
| self.assertTrue(all(isinstance(q, float) for q in aqs[1])) |
| aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1) |
| self.assertTrue(isinstance(aqt, list)) |
| self.assertEqual(len(aqt), 2) |
| self.assertTrue(isinstance(aqt[0], list)) |
| self.assertEqual(len(aqt[0]), 3) |
| self.assertTrue(all(isinstance(q, float) for q in aqt[0])) |
| self.assertTrue(isinstance(aqt[1], list)) |
| self.assertEqual(len(aqt[1]), 3) |
| self.assertTrue(all(isinstance(q, float) for q in aqt[1])) |
| self.assertRaises(TypeError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1)) |
| self.assertRaises(TypeError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1)) |
| self.assertRaises(TypeError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1)) |
| |
| def test_sorting_functions_with_column(self): |
| self.check_sorting_functions_with_column(Column) |
| |
| def check_sorting_functions_with_column(self, tpe): |
| funs = [F.asc_nulls_first, F.asc_nulls_last, F.desc_nulls_first, F.desc_nulls_last] |
| exprs = [F.col("x"), "x"] |
| |
| for fun in funs: |
| for _expr in exprs: |
| res = fun(_expr) |
| self.assertIsInstance(res, tpe) |
| self.assertIn(f"""'x {fun.__name__.replace("_", " ").upper()}'""", str(res)) |
| |
| for _expr in exprs: |
| res = F.asc(_expr) |
| self.assertIsInstance(res, tpe) |
| self.assertIn("""'x ASC NULLS FIRST'""", str(res)) |
| |
| for _expr in exprs: |
| res = F.desc(_expr) |
| self.assertIsInstance(res, tpe) |
| self.assertIn("""'x DESC NULLS LAST'""", str(res)) |
| |
| def test_sort_with_nulls_order(self): |
| df = self.spark.createDataFrame( |
| [("Tom", 80), (None, 60), ("Alice", 50)], ["name", "height"] |
| ) |
| assertDataFrameEqual( |
| df.select(df.name).orderBy(F.asc_nulls_first("name")), |
| [Row(name=None), Row(name="Alice"), Row(name="Tom")], |
| ) |
| assertDataFrameEqual( |
| df.select(df.name).orderBy(F.asc_nulls_last("name")), |
| [Row(name="Alice"), Row(name="Tom"), Row(name=None)], |
| ) |
| assertDataFrameEqual( |
| df.select(df.name).orderBy(F.desc_nulls_first("name")), |
| [Row(name=None), Row(name="Tom"), Row(name="Alice")], |
| ) |
| assertDataFrameEqual( |
| df.select(df.name).orderBy(F.desc_nulls_last("name")), |
| [Row(name="Tom"), Row(name="Alice"), Row(name=None)], |
| ) |
| |
| def test_input_file_name_reset_for_rdd(self): |
| rdd = self.sc.textFile("python/test_support/hello/hello.txt").map(lambda x: {"data": x}) |
| df = self.spark.createDataFrame(rdd, "data STRING") |
| df.select(F.input_file_name().alias("file")).collect() |
| |
| non_file_df = self.spark.range(100).select(F.input_file_name()) |
| |
| results = non_file_df.collect() |
| self.assertTrue(len(results) == 100) |
| |
| # [SPARK-24605]: if everything was properly reset after the last job, this should return |
| # empty string rather than the file read in the last job. |
| for result in results: |
| self.assertEqual(result[0], "") |
| |
| def test_slice(self): |
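| # slice(x, start, length) takes length elements from array x beginning at 1-based index |
| # start; start and length may be ints, column names, or Column expressions. |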
| df = self.spark.createDataFrame( |
| [ |
| ( |
| [1, 2, 3], |
| 2, |
| 2, |
| ), |
| ( |
| [4, 5], |
| 2, |
| 2, |
| ), |
| ], |
| ["x", "index", "len"], |
| ) |
| |
| expected = [Row(sliced=[2, 3]), Row(sliced=[5])] |
| assertDataFrameEqual(df.select(F.slice(df.x, 2, 2).alias("sliced")), expected) |
| assertDataFrameEqual(df.select(F.slice(df.x, F.lit(2), F.lit(2)).alias("sliced")), expected) |
| assertDataFrameEqual(df.select(F.slice("x", "index", "len").alias("sliced")), expected) |
| |
| assertDataFrameEqual( |
| df.select(F.slice(df.x, F.size(df.x) - 1, F.lit(1)).alias("sliced")), |
| [Row(sliced=[2]), Row(sliced=[4])], |
| ) |
| assertDataFrameEqual( |
| df.select(F.slice(df.x, F.lit(1), F.size(df.x) - 1).alias("sliced")), |
| [Row(sliced=[1, 2]), Row(sliced=[4])], |
| ) |
| |
| def test_array_repeat(self): |
| df = self.spark.range(1) |
| df = df.withColumn("repeat_n", F.lit(3)) |
| |
| expected = [Row(val=[0, 0, 0])] |
| assertDataFrameEqual(df.select(F.array_repeat("id", 3).alias("val")), expected) |
| assertDataFrameEqual(df.select(F.array_repeat("id", F.lit(3)).alias("val")), expected) |
| assertDataFrameEqual(df.select(F.array_repeat("id", "repeat_n").alias("val")), expected) |
| |
| def test_input_file_name_udf(self): |
| df = self.spark.read.text("python/test_support/hello/hello.txt") |
| df = df.select(F.udf(lambda x: x)("value"), F.input_file_name().alias("file")) |
| file_name = df.collect()[0].file |
| self.assertTrue("python/test_support/hello/hello.txt" in file_name) |
| |
| def test_least(self): |
| df = self.spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) |
| |
| expected = [Row(least=1)] |
| assertDataFrameEqual(df.select(F.least(df.a, df.b, df.c).alias("least")), expected) |
| assertDataFrameEqual( |
| df.select(F.least(F.lit(3), F.lit(5), F.lit(1)).alias("least")), expected |
| ) |
| assertDataFrameEqual(df.select(F.least("a", "b", "c").alias("least")), expected) |
| |
| with self.assertRaises(PySparkValueError) as pe: |
| df.select(F.least(df.a).alias("least")).collect() |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="WRONG_NUM_COLUMNS", |
| messageParameters={"func_name": "least", "num_cols": "2"}, |
| ) |
| |
| def test_overlay(self): |
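| # overlay(src, replace, pos[, len]) replaces len characters of src starting at 1-based |
| # position pos; len defaults to -1, meaning the full length of the replacement is used. |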
| actual = list( |
| chain.from_iterable( |
| [ |
| re.findall("(overlay\\(.*\\))", str(x)) |
| for x in [ |
| F.overlay(F.col("foo"), F.col("bar"), 1), |
| F.overlay("x", "y", 3), |
| F.overlay(F.col("x"), F.col("y"), 1, 3), |
| F.overlay("x", "y", 2, 5), |
| F.overlay("x", "y", F.lit(11)), |
| F.overlay("x", "y", F.lit(2), F.lit(5)), |
| ] |
| ] |
| ) |
| ) |
| |
| expected = [ |
| "overlay(foo, bar, 1, -1)", |
| "overlay(x, y, 3, -1)", |
| "overlay(x, y, 1, 3)", |
| "overlay(x, y, 2, 5)", |
| "overlay(x, y, 11, -1)", |
| "overlay(x, y, 2, 5)", |
| ] |
| |
| self.assertListEqual(actual, expected) |
| |
| df = self.spark.createDataFrame([("SPARK_SQL", "CORE", 7, 0)], ("x", "y", "pos", "len")) |
| |
| exp = [Row(ol="SPARK_CORESQL")] |
| assertDataFrameEqual(df.select(F.overlay(df.x, df.y, 7, 0).alias("ol")), exp) |
| assertDataFrameEqual(df.select(F.overlay(df.x, df.y, F.lit(7), F.lit(0)).alias("ol")), exp) |
| assertDataFrameEqual(df.select(F.overlay("x", "y", "pos", "len").alias("ol")), exp) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.select(F.overlay(df.x, df.y, 7.5, 0).alias("ol")).collect() |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_INT_OR_STR", |
| messageParameters={"arg_name": "pos", "arg_type": "float"}, |
| ) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.select(F.overlay(df.x, df.y, 7, 0.5).alias("ol")).collect() |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_INT_OR_STR", |
| messageParameters={"arg_name": "len", "arg_type": "float"}, |
| ) |
| |
| def test_percentile(self): |
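| # percentile accepts a single percentage or a list/tuple of percentages, plus an optional |
| # frequency that defaults to 1. |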
| actual = list( |
| chain.from_iterable( |
| [ |
| re.findall("(percentile\\(.*\\))", str(x)) |
| for x in [ |
| F.percentile(F.col("foo"), F.lit(0.5)), |
| F.percentile(F.col("bar"), 0.25, 2), |
| F.percentile(F.col("bar"), [0.25, 0.5, 0.75]), |
| F.percentile(F.col("foo"), (0.05, 0.95), 100), |
| F.percentile("foo", 0.5), |
| F.percentile("bar", [0.1, 0.9], F.lit(10)), |
| ] |
| ] |
| ) |
| ) |
| |
| expected = [ |
| "percentile(foo, 0.5, 1)", |
| "percentile(bar, 0.25, 2)", |
| "percentile(bar, array(0.25, 0.5, 0.75), 1)", |
| "percentile(foo, array(0.05, 0.95), 100)", |
| "percentile(foo, 0.5, 1)", |
| "percentile(bar, array(0.1, 0.9), 10)", |
| ] |
| |
| self.assertListEqual(actual, expected) |
| |
| def test_median(self): |
| actual = list( |
| chain.from_iterable( |
| [ |
| re.findall("(median\\(.*\\))", str(x)) |
| for x in [ |
| F.median(F.col("foo")), |
| ] |
| ] |
| ) |
| ) |
| |
| expected = [ |
| "median(foo)", |
| ] |
| |
| self.assertListEqual(actual, expected) |
| |
| def test_percentile_approx(self): |
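| # percentile_approx accepts a single percentage or a list/tuple of percentages, plus an |
| # optional accuracy that defaults to 10000. |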
| actual = list( |
| chain.from_iterable( |
| [ |
| re.findall("(percentile_approx\\(.*\\))", str(x)) |
| for x in [ |
| F.percentile_approx(F.col("foo"), F.lit(0.5)), |
| F.percentile_approx(F.col("bar"), 0.25, 42), |
| F.percentile_approx(F.col("bar"), [0.25, 0.5, 0.75]), |
| F.percentile_approx(F.col("foo"), (0.05, 0.95), 100), |
| F.percentile_approx("foo", 0.5), |
| F.percentile_approx("bar", [0.1, 0.9], F.lit(10)), |
| ] |
| ] |
| ) |
| ) |
| |
| expected = [ |
| "percentile_approx(foo, 0.5, 10000)", |
| "percentile_approx(bar, 0.25, 42)", |
| "percentile_approx(bar, array(0.25, 0.5, 0.75), 10000)", |
| "percentile_approx(foo, array(0.05, 0.95), 100)", |
| "percentile_approx(foo, 0.5, 10000)", |
| "percentile_approx(bar, array(0.1, 0.9), 10)", |
| ] |
| |
| self.assertListEqual(actual, expected) |
| |
| def test_nth_value(self): |
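| # nth_value returns the n-th value of the ordered window frame; with ignoreNulls=True, |
| # NULL values are skipped when locating the n-th value. |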
| df = self.spark.createDataFrame( |
| [ |
| ("a", 0, None), |
| ("a", 1, "x"), |
| ("a", 2, "y"), |
| ("a", 3, "z"), |
| ("a", 4, None), |
| ("b", 1, None), |
| ("b", 2, None), |
| ], |
| schema=("key", "order", "value"), |
| ) |
| w = Window.partitionBy("key").orderBy("order") |
| |
| rs = df.select( |
| df.key, |
| df.order, |
| F.nth_value("value", 2).over(w), |
| F.nth_value("value", 2, False).over(w), |
| F.nth_value("value", 2, True).over(w), |
| ).collect() |
| |
| expected = [ |
| ("a", 0, None, None, None), |
| ("a", 1, "x", "x", None), |
| ("a", 2, "x", "x", "y"), |
| ("a", 3, "x", "x", "y"), |
| ("a", 4, "x", "x", "y"), |
| ("b", 1, None, None, None), |
| ("b", 2, None, None, None), |
| ] |
| |
| for r, ex in zip(sorted(rs), sorted(expected)): |
| self.assertEqual(tuple(r), ex[: len(r)]) |
| |
| def test_higher_order_function_failures(self): |
| # Should fail with varargs |
| with self.assertRaises(PySparkValueError) as pe: |
| F.transform(F.col("foo"), lambda *x: F.lit(1)) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="UNSUPPORTED_PARAM_TYPE_FOR_HIGHER_ORDER_FUNCTION", |
| messageParameters={"func_name": "<lambda>"}, |
| ) |
| |
| # Should fail with kwargs |
| with self.assertRaises(PySparkValueError) as pe: |
| F.transform(F.col("foo"), lambda **x: F.lit(1)) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="UNSUPPORTED_PARAM_TYPE_FOR_HIGHER_ORDER_FUNCTION", |
| messageParameters={"func_name": "<lambda>"}, |
| ) |
| |
| # Should fail with nullary function |
| with self.assertRaises(PySparkValueError) as pe: |
| F.transform(F.col("foo"), lambda: F.lit(1)) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="WRONG_NUM_ARGS_FOR_HIGHER_ORDER_FUNCTION", |
| messageParameters={"func_name": "<lambda>", "num_args": "0"}, |
| ) |
| |
| # Should fail with quaternary function |
| with self.assertRaises(PySparkValueError) as pe: |
| F.transform(F.col("foo"), lambda x1, x2, x3, x4: F.lit(1)) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="WRONG_NUM_ARGS_FOR_HIGHER_ORDER_FUNCTION", |
| messageParameters={"func_name": "<lambda>", "num_args": "4"}, |
| ) |
| |
| # Should fail if function doesn't return Column |
| with self.assertRaises(PySparkValueError) as pe: |
| F.transform(F.col("foo"), lambda x: 1) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="HIGHER_ORDER_FUNCTION_SHOULD_RETURN_COLUMN", |
| messageParameters={"func_name": "<lambda>", "return_type": "int"}, |
| ) |
| |
| def test_nested_higher_order_function(self): |
| # SPARK-35382: lambda vars must be resolved properly in nested higher order functions |
| df = self.spark.sql("SELECT array(1, 2, 3) as numbers, array('a', 'b', 'c') as letters") |
| |
| actual = df.select( |
| F.flatten( |
| F.transform( |
| "numbers", |
| lambda number: F.transform( |
| "letters", lambda letter: F.struct(number.alias("n"), letter.alias("l")) |
| ), |
| ) |
| ) |
| ).first()[0] |
| |
| expected = [ |
| (1, "a"), |
| (1, "b"), |
| (1, "c"), |
| (2, "a"), |
| (2, "b"), |
| (2, "c"), |
| (3, "a"), |
| (3, "b"), |
| (3, "c"), |
| ] |
| |
| self.assertEqual(actual, expected) |
| |
| def test_window_functions(self): |
| df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) |
| w = Window.partitionBy("value").orderBy("key") |
| |
| sel = df.select( |
| df.value, |
| df.key, |
| F.max("key").over(w.rowsBetween(0, 1)), |
| F.min("key").over(w.rowsBetween(0, 1)), |
| F.count("key").over(w.rowsBetween(float("-inf"), float("inf"))), |
| F.row_number().over(w), |
| F.rank().over(w), |
| F.dense_rank().over(w), |
| F.ntile(2).over(w), |
| ) |
| rs = sorted(sel.collect()) |
| expected = [ |
| ("1", 1, 1, 1, 1, 1, 1, 1, 1), |
| ("2", 1, 1, 1, 3, 1, 1, 1, 1), |
| ("2", 1, 2, 1, 3, 2, 1, 1, 1), |
| ("2", 2, 2, 2, 3, 3, 3, 2, 2), |
| ] |
| for r, ex in zip(rs, expected): |
| self.assertEqual(tuple(r), ex[: len(r)]) |
| |
| def test_window_functions_without_partitionBy(self): |
| df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) |
| w = Window.orderBy("key", df.value) |
| |
| sel = df.select( |
| df.value, |
| df.key, |
| F.max("key").over(w.rowsBetween(0, 1)), |
| F.min("key").over(w.rowsBetween(0, 1)), |
| F.count("key").over(w.rowsBetween(float("-inf"), float("inf"))), |
| F.row_number().over(w), |
| F.rank().over(w), |
| F.dense_rank().over(w), |
| F.ntile(2).over(w), |
| ) |
| rs = sorted(sel.collect()) |
| expected = [ |
| ("1", 1, 1, 1, 4, 1, 1, 1, 1), |
| ("2", 1, 1, 1, 4, 2, 2, 2, 1), |
| ("2", 1, 2, 1, 4, 3, 2, 2, 2), |
| ("2", 2, 2, 2, 4, 4, 4, 3, 2), |
| ] |
| for r, ex in zip(rs, expected): |
| self.assertEqual(tuple(r), ex[: len(r)]) |
| |
| def test_window_functions_cumulative_sum(self): |
| df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"]) |
| |
| # Test cumulative sum |
| sel = df.select( |
| df.key, F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0)) |
| ) |
| rs = sorted(sel.collect()) |
| expected = [("one", 1), ("two", 3)] |
| for r, ex in zip(rs, expected): |
| self.assertEqual(tuple(r), ex[: len(r)]) |
| |
| # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow |
| sel = df.select( |
| df.key, F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0)) |
| ) |
| rs = sorted(sel.collect()) |
| expected = [("one", 1), ("two", 3)] |
| for r, ex in zip(rs, expected): |
| self.assertEqual(tuple(r), ex[: len(r)]) |
| |
| # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow |
| frame_end = Window.unboundedFollowing + 1 |
| sel = df.select( |
| df.key, F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end)) |
| ) |
| rs = sorted(sel.collect()) |
| expected = [("one", 3), ("two", 2)] |
| for r, ex in zip(rs, expected): |
| self.assertEqual(tuple(r), ex[: len(r)]) |
| |
| def test_window_functions_moving_average(self): |
| data = [ |
| (datetime.datetime(2023, 1, 1), 20), |
| (datetime.datetime(2023, 1, 2), 22), |
| (datetime.datetime(2023, 1, 3), 21), |
| (datetime.datetime(2023, 1, 4), 23), |
| (datetime.datetime(2023, 1, 5), 24), |
| (datetime.datetime(2023, 1, 6), 26), |
| ] |
| df = self.spark.createDataFrame(data, ["date", "temperature"]) |
| |
| def to_sec(i): |
| return i * 86400 |
| |
| w = Window.orderBy(F.col("date").cast("timestamp").cast("long")).rangeBetween(-to_sec(3), 0) |
| res = df.withColumn("3_day_avg_temp", F.avg("temperature").over(w)) |
| rs = sorted(res.collect()) |
| expected = [ |
| (datetime.datetime(2023, 1, 1, 0, 0), 20, 20.0), |
| (datetime.datetime(2023, 1, 2, 0, 0), 22, 21.0), |
| (datetime.datetime(2023, 1, 3, 0, 0), 21, 21.0), |
| (datetime.datetime(2023, 1, 4, 0, 0), 23, 21.5), |
| (datetime.datetime(2023, 1, 5, 0, 0), 24, 22.5), |
| (datetime.datetime(2023, 1, 6, 0, 0), 26, 23.5), |
| ] |
| for r, ex in zip(rs, expected): |
| self.assertEqual(tuple(r), ex[: len(r)]) |
| |
| def test_window_time(self): |
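| # window_time returns the inclusive upper bound of a time window, i.e. window.end minus |
| # one microsecond. |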
| df = self.spark.createDataFrame( |
| [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)], ["date", "val"] |
| ) |
| |
| w = df.groupBy(F.window("date", "5 seconds", "5 seconds")).agg(F.sum("val").alias("sum")) |
| r = w.select( |
| w.window.end.cast("string").alias("end"), |
| F.window_time(w.window).cast("string").alias("window_time"), |
| "sum", |
| ).collect() |
| self.assertEqual( |
| r[0], Row(end="2016-03-11 09:00:10", window_time="2016-03-11 09:00:09.999999", sum=1) |
| ) |
| |
| def test_collect_functions(self): |
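| # collect_set drops duplicates while collect_list keeps them; neither guarantees element |
| # order, hence the sorted() comparisons. |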
| df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) |
| |
| self.assertEqual(sorted(df.select(F.collect_set(df.key).alias("r")).collect()[0].r), [1, 2]) |
| self.assertEqual( |
| sorted(df.select(F.collect_list(df.key).alias("r")).collect()[0].r), [1, 1, 1, 2] |
| ) |
| self.assertEqual( |
| sorted(df.select(F.collect_set(df.value).alias("r")).collect()[0].r), ["1", "2"] |
| ) |
| self.assertEqual( |
| sorted(df.select(F.collect_list(df.value).alias("r")).collect()[0].r), |
| ["1", "2", "2", "2"], |
| ) |
| |
| def test_listagg_functions(self): |
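| # listagg concatenates non-NULL string (or binary) values with an optional delimiter and |
| # returns NULL when every input is NULL; string_agg is an alias. |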
| df = self.spark.createDataFrame( |
| [(1, "1"), (2, "2"), (None, None), (1, "2")], ["key", "value"] |
| ) |
| df_with_bytes = self.spark.createDataFrame( |
| [(b"\x01",), (b"\x02",), (None,), (b"\x03",), (b"\x02",)], ["bytes"] |
| ) |
| df_with_nulls = self.spark.createDataFrame( |
| [(None,), (None,), (None,), (None,), (None,)], |
| StructType([StructField("nulls", StringType(), True)]), |
| ) |
| # listagg and string_agg are aliases |
| for listagg_ref in [F.listagg, F.string_agg]: |
| self.assertEqual(df.select(listagg_ref(df.key).alias("r")).collect()[0].r, "121") |
| self.assertEqual(df.select(listagg_ref(df.value).alias("r")).collect()[0].r, "122") |
| self.assertEqual( |
| df.select(listagg_ref(df.value, ",").alias("r")).collect()[0].r, "1,2,2" |
| ) |
| self.assertEqual( |
| df_with_bytes.select(listagg_ref(df_with_bytes.bytes, b"\x42").alias("r")) |
| .collect()[0] |
| .r, |
| b"\x01\x42\x02\x42\x03\x42\x02", |
| ) |
| self.assertEqual( |
| df_with_nulls.select(listagg_ref(df_with_nulls.nulls).alias("r")).collect()[0].r, |
| None, |
| ) |
| |
| def test_listagg_distinct_functions(self): |
| df = self.spark.createDataFrame( |
| [(1, "1"), (2, "2"), (None, None), (1, "2")], ["key", "value"] |
| ) |
| df_with_bytes = self.spark.createDataFrame( |
| [(b"\x01",), (b"\x02",), (None,), (b"\x03",), (b"\x02",)], ["bytes"] |
| ) |
| df_with_nulls = self.spark.createDataFrame( |
| [(None,), (None,), (None,), (None,), (None,)], |
| StructType([StructField("nulls", StringType(), True)]), |
| ) |
| # listagg_distinct and string_agg_distinct are aliases |
| for listagg_distinct_ref in [F.listagg_distinct, F.string_agg_distinct]: |
| self.assertEqual( |
| df.select(listagg_distinct_ref(df.key).alias("r")).collect()[0].r, "12" |
| ) |
| self.assertEqual( |
| df.select(listagg_distinct_ref(df.value).alias("r")).collect()[0].r, "12" |
| ) |
| self.assertEqual( |
| df.select(listagg_distinct_ref(df.value, ",").alias("r")).collect()[0].r, "1,2" |
| ) |
| self.assertEqual( |
| df_with_bytes.select(listagg_distinct_ref(df_with_bytes.bytes, b"\x42").alias("r")) |
| .collect()[0] |
| .r, |
| b"\x01\x42\x02\x42\x03", |
| ) |
| self.assertEqual( |
| df_with_nulls.select(listagg_distinct_ref(df_with_nulls.nulls).alias("r")) |
| .collect()[0] |
| .r, |
| None, |
| ) |
| |
| def test_datetime_functions(self): |
| df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol") |
| parse_result = df.select(F.to_date(F.col("dateCol"))).first() |
| self.assertEqual(datetime.date(2017, 1, 22), parse_result["to_date(dateCol)"]) |
| |
| def test_try_datetime_functions(self): |
| df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol") |
| parse_result = df.select(F.try_to_date(F.col("dateCol")).alias("tryToDateCol")).first() |
| self.assertEqual(datetime.date(2017, 1, 22), parse_result["tryToDateCol"]) |
| |
| def test_assert_true(self): |
| self.check_assert_true(SparkRuntimeException) |
| |
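    # The expected exception type is parameterized so subclasses can supply their own.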
| def check_assert_true(self, tpe): |
| df = self.spark.range(3) |
| |
| assertDataFrameEqual( |
| df.select(F.assert_true(df.id < 3)).toDF("val"), |
| [Row(val=None), Row(val=None), Row(val=None)], |
| ) |
| |
| with self.assertRaisesRegex(tpe, r"\[USER_RAISED_EXCEPTION\] too big"): |
| df.select(F.assert_true(df.id < 2, "too big")).toDF("val").collect() |
| |
| with self.assertRaisesRegex(tpe, r"\[USER_RAISED_EXCEPTION\] 2000000.0"): |
| df.select(F.assert_true(df.id < 2, df.id * 1e6)).toDF("val").collect() |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.select(F.assert_true(df.id < 2, 5)) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_STR", |
| messageParameters={"arg_name": "errMsg", "arg_type": "int"}, |
| ) |
| |
| def test_raise_error(self): |
| self.check_raise_error(SparkRuntimeException) |
| |
| def check_raise_error(self, tpe): |
| df = self.spark.createDataFrame([Row(id="foobar")]) |
| |
| with self.assertRaisesRegex(tpe, "foobar"): |
| df.select(F.raise_error(df.id)).collect() |
| |
| with self.assertRaisesRegex(tpe, "barfoo"): |
| df.select(F.raise_error("barfoo")).collect() |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.select(F.raise_error(None)) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_STR", |
| messageParameters={"arg_name": "errMsg", "arg_type": "NoneType"}, |
| ) |
| |
| def test_sum_distinct(self): |
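        # sumDistinct is the deprecated camelCase alias of sum_distinct; both must agree.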
| self.spark.range(10).select( |
| F.assert_true(F.sum_distinct(F.col("id")) == F.sumDistinct(F.col("id"))) |
| ).collect() |
| |
| def test_shiftleft(self): |
| self.spark.range(10).select( |
| F.assert_true(F.shiftLeft(F.col("id"), 2) == F.shiftleft(F.col("id"), 2)) |
| ).collect() |
| |
| def test_shiftright(self): |
| self.spark.range(10).select( |
| F.assert_true(F.shiftRight(F.col("id"), 2) == F.shiftright(F.col("id"), 2)) |
| ).collect() |
| |
| def test_shiftrightunsigned(self): |
| self.spark.range(10).select( |
| F.assert_true( |
| F.shiftRightUnsigned(F.col("id"), 2) == F.shiftrightunsigned(F.col("id"), 2) |
| ) |
| ).collect() |
| |
| def test_lit_time(self): |
| t = datetime.time(12, 34, 56) |
| actual = self.spark.range(1).select(F.lit(t)).first()[0] |
| self.assertEqual(actual, t) |
| |
| def test_lit_day_time_interval(self): |
| td = datetime.timedelta(days=1, hours=12, milliseconds=123) |
| actual = self.spark.range(1).select(F.lit(td)).first()[0] |
| self.assertEqual(actual, td) |
| |
| def test_lit_list(self): |
        # SPARK-40271: add support for list literals
| test_list = [1, 2, 3] |
| expected = [1, 2, 3] |
| actual = self.spark.range(1).select(F.lit(test_list)).first()[0] |
| self.assertEqual(actual, expected) |
| |
| test_list = [[1, 2, 3], [3, 4]] |
| expected = [[1, 2, 3], [3, 4]] |
| actual = self.spark.range(1).select(F.lit(test_list)).first()[0] |
| self.assertEqual(actual, expected) |
| |
| with self.sql_conf({"spark.sql.ansi.enabled": False}): |
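            # With ANSI mode disabled, mixed-type elements are coerced to a common string type.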
| test_list = ["a", 1, None, 1.0] |
| expected = ["a", "1", None, "1.0"] |
| actual = self.spark.range(1).select(F.lit(test_list)).first()[0] |
| self.assertEqual(actual, expected) |
| |
| test_list = [["a", 1, None, 1.0], [1, None, "b"]] |
| expected = [["a", "1", None, "1.0"], ["1", None, "b"]] |
| actual = self.spark.range(1).select(F.lit(test_list)).first()[0] |
| self.assertEqual(actual, expected) |
| |
| df = self.spark.range(10) |
| with self.assertRaises(PySparkValueError) as pe: |
| F.lit([df.id, df.id]) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="COLUMN_IN_LIST", |
| messageParameters={"func_name": "lit"}, |
| ) |
| |
    # SPARK-39832: regexp_replace accepts both Column and str for the pattern and replacement
| def test_regexp_replace(self): |
| df = self.spark.createDataFrame( |
| [("100-200", r"(\d+)", "--")], ["str", "pattern", "replacement"] |
| ) |
| self.assertTrue( |
| all( |
| df.select( |
| F.regexp_replace("str", r"(\d+)", "--") == "-----", |
| F.regexp_replace("str", F.col("pattern"), F.col("replacement")) == "-----", |
| ).first() |
| ) |
| ) |
| |
| @unittest.skipIf(not have_numpy, "NumPy not installed") |
| def test_lit_np_scalar(self): |
| import numpy as np |
| |
| dtype_to_spark_dtypes = [ |
| (np.int8, [("1", "tinyint")]), |
| (np.int16, [("1", "smallint")]), |
| (np.int32, [("1", "int")]), |
| (np.int64, [("1", "bigint")]), |
| (np.float32, [("1.0", "float")]), |
| (np.float64, [("1.0", "double")]), |
| (np.bool_, [("true", "boolean")]), |
| ] |
| for dtype, spark_dtypes in dtype_to_spark_dtypes: |
| with self.subTest(dtype): |
| self.assertEqual(self.spark.range(1).select(F.lit(dtype(1))).dtypes, spark_dtypes) |
| |
| @unittest.skipIf(not have_numpy, "NumPy not installed") |
| def test_np_scalar_input(self): |
| import numpy as np |
| |
| df = self.spark.createDataFrame([([1, 2, 3],), ([],)], ["data"]) |
| for dtype in [np.int8, np.int16, np.int32, np.int64]: |
| res = df.select(F.array_contains(df.data, dtype(1)).alias("b")) |
| assertDataFrameEqual([Row(b=True), Row(b=False)], res) |
| res = df.select(F.array_position(df.data, dtype(1)).alias("c")) |
| assertDataFrameEqual([Row(c=1), Row(c=0)], res) |
| |
| df = self.spark.createDataFrame([([1.0, 2.0, 3.0],), ([],)], ["data"]) |
| for dtype in [np.float32, np.float64]: |
| res = df.select(F.array_contains(df.data, dtype(1)).alias("b")) |
| assertDataFrameEqual([Row(b=True), Row(b=False)], res) |
| res = df.select(F.array_position(df.data, dtype(1)).alias("c")) |
| assertDataFrameEqual([Row(c=1), Row(c=0)], res) |
| |
| @unittest.skipIf(not have_numpy, "NumPy not installed") |
| def test_ndarray_input(self): |
| import numpy as np |
| |
| arr_dtype_to_spark_dtypes = [ |
| ("int8", [("b", "array<tinyint>")]), |
| ("int16", [("b", "array<smallint>")]), |
| ("int32", [("b", "array<int>")]), |
| ("int64", [("b", "array<bigint>")]), |
| ("float32", [("b", "array<float>")]), |
| ("float64", [("b", "array<double>")]), |
| ] |
| for t, expected_spark_dtypes in arr_dtype_to_spark_dtypes: |
| arr = np.array([1, 2]).astype(t) |
| self.assertEqual( |
| expected_spark_dtypes, self.spark.range(1).select(F.lit(arr).alias("b")).dtypes |
| ) |
| arr = np.array([1, 2]).astype(np.uint) |
| with self.assertRaises(PySparkTypeError) as pe: |
| self.spark.range(1).select(F.lit(arr).alias("b")) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="UNSUPPORTED_NUMPY_ARRAY_SCALAR", |
| messageParameters={ |
| "dtype": "uint64", |
| }, |
| ) |
| |
| @unittest.skipIf(not have_numpy, "NumPy not installed") |
| def test_bool_ndarray(self): |
| import numpy as np |
| |
| for arr in [ |
| np.array([], np.bool_), |
| np.array([True, False], np.bool_), |
| np.array([1, 0, 3], np.bool_), |
| ]: |
| self.assertEqual( |
| [("a", "array<boolean>")], |
| self.spark.range(1).select(F.lit(arr).alias("a")).dtypes, |
| ) |
| |
| @unittest.skipIf(not have_numpy, "NumPy not installed") |
| def test_str_ndarray(self): |
| import numpy as np |
| |
| for arr in [ |
| np.array([], np.str_), |
| np.array(["a"], np.str_), |
| np.array([1, 2, 3], np.str_), |
| ]: |
| self.assertEqual( |
| [("a", "array<string>")], |
| self.spark.range(1).select(F.lit(arr).alias("a")).dtypes, |
| ) |
| |
| @unittest.skipIf(not have_numpy, "NumPy not installed") |
| def test_empty_ndarray(self): |
| import numpy as np |
| |
| arr_dtype_to_spark_dtypes = [ |
| ("int8", [("b", "array<tinyint>")]), |
| ("int16", [("b", "array<smallint>")]), |
| ("int32", [("b", "array<int>")]), |
| ("int64", [("b", "array<bigint>")]), |
| ("float32", [("b", "array<float>")]), |
| ("float64", [("b", "array<double>")]), |
| ] |
| for t, expected_spark_dtypes in arr_dtype_to_spark_dtypes: |
| arr = np.array([]).astype(t) |
| self.assertEqual( |
| expected_spark_dtypes, self.spark.range(1).select(F.lit(arr).alias("b")).dtypes |
| ) |
| |
| def test_binary_math_function(self): |
| funcs, expected = zip( |
| *[(F.atan2, 0.13664), (F.hypot, 8.07527), (F.pow, 2.14359), (F.pmod, 1.1)] |
| ) |
| df = self.spark.range(1).select(*(func(1.1, 8) for func in funcs)) |
| for a, e in zip(df.first(), expected): |
| self.assertAlmostEqual(a, e, 5) |
| |
| def test_map_functions(self): |
        # SPARK-38496: check basic functionality of all map-type related functions
| expected = {"a": 1, "b": 2} |
| expected2 = {"c": 3, "d": 4} |
| df = self.spark.createDataFrame( |
| [(list(expected.keys()), list(expected.values()))], ["k", "v"] |
| ) |
| actual = ( |
| df.select( |
| F.expr("map('c', 3, 'd', 4) as dict2"), |
| F.map_from_arrays(df.k, df.v).alias("dict"), |
| "*", |
| ) |
| .select( |
| F.map_contains_key("dict", "a").alias("one"), |
| F.map_contains_key("dict", "d").alias("not_exists"), |
| F.map_keys("dict").alias("keys"), |
| F.map_values("dict").alias("values"), |
| F.map_entries("dict").alias("items"), |
| "*", |
| ) |
| .select( |
| F.map_concat("dict", "dict2").alias("merged"), |
| F.map_from_entries(F.arrays_zip("keys", "values")).alias("from_items"), |
| "*", |
| ) |
| .first() |
| ) |
| self.assertEqual(expected, actual["dict"]) |
| self.assertTrue(actual["one"]) |
| self.assertFalse(actual["not_exists"]) |
| self.assertEqual(list(expected.keys()), actual["keys"]) |
| self.assertEqual(list(expected.values()), actual["values"]) |
| self.assertEqual(expected, dict(actual["items"])) |
| self.assertEqual({**expected, **expected2}, dict(actual["merged"])) |
| self.assertEqual(expected, actual["from_items"]) |
| |
| def test_parse_json(self): |
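        # parse_json produces a VARIANT value; to_json converts it back to JSON text.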
| df = self.spark.createDataFrame([{"json": """{ "a" : 1 }"""}]) |
| actual = df.select( |
| F.to_json(F.parse_json(df.json)).alias("var"), |
| F.to_json(F.parse_json(F.lit("""{"b": [{"c": "str2"}]}"""))).alias("var_lit"), |
| ).first() |
| |
| self.assertEqual("""{"a":1}""", actual["var"]) |
| self.assertEqual("""{"b":[{"c":"str2"}]}""", actual["var_lit"]) |
| |
| def test_variant_expressions(self): |
| df = self.spark.createDataFrame( |
| [Row(json="""{ "a" : 1 }""", path="$.a"), Row(json="""{ "b" : 2 }""", path="$.b")] |
| ) |
| v = F.parse_json(df.json) |
| |
| def check(resultDf, expected): |
| self.assertEqual([r[0] for r in resultDf.collect()], expected) |
| |
| check(df.select(F.is_variant_null(v)), [False, False]) |
| check(df.select(F.schema_of_variant(v)), ["OBJECT<a: BIGINT>", "OBJECT<b: BIGINT>"]) |
| check(df.select(F.schema_of_variant_agg(v)), ["OBJECT<a: BIGINT, b: BIGINT>"]) |
| |
| check(df.select(F.variant_get(v, "$.a", "int")), [1, None]) |
| check(df.select(F.variant_get(v, "$.b", "int")), [None, 2]) |
| check(df.select(F.variant_get(v, "$.a", "double")), [1.0, None]) |
| |
        # variant_get with a non-literal path taken from a column
| check(df.select(F.variant_get(v, df.path, "int")), [1, 2]) |
| check(df.select(F.try_variant_get(v, df.path, "binary")), [None, None]) |
| |
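        # Casting the variant value 1 (BIGINT) to BINARY is invalid: variant_get raises an
        # error while try_variant_get (below) returns NULL instead.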
| with self.assertRaises(SparkRuntimeException) as ex: |
| df.select(F.variant_get(v, "$.a", "binary")).collect() |
| |
| self.check_error( |
| exception=ex.exception, |
| errorClass="INVALID_VARIANT_CAST", |
| messageParameters={"value": "1", "dataType": '"BINARY"'}, |
| ) |
| |
| check(df.select(F.try_variant_get(v, "$.a", "int")), [1, None]) |
| check(df.select(F.try_variant_get(v, "$.b", "int")), [None, 2]) |
| check(df.select(F.try_variant_get(v, "$.a", "double")), [1.0, None]) |
| check(df.select(F.try_variant_get(v, "$.a", "binary")), [None, None]) |
| |
| def test_schema_of_json(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.schema_of_json(1) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_STR", |
| messageParameters={"arg_name": "json", "arg_type": "int"}, |
| ) |
| |
| def test_try_parse_json(self): |
| df = self.spark.createDataFrame([{"json": """{ "a" : 1 }"""}, {"json": """{ a : 1 }"""}]) |
| actual = df.select( |
| F.to_json(F.try_parse_json(df.json)).alias("var"), |
| ).collect() |
| self.assertEqual("""{"a":1}""", actual[0]["var"]) |
| self.assertEqual(None, actual[1]["var"]) |
| |
| def test_try_to_time(self): |
| # SPARK-52891: test the try_to_time function. |
| df = self.spark.createDataFrame([("10:30:00", "HH:mm:ss")], ["time", "format"]) |
| result = datetime.time(10, 30, 0) |
| # Test without format. |
| row_from_col_no_format = df.select(F.try_to_time(df.time)).first() |
| self.assertIsInstance(row_from_col_no_format[0], datetime.time) |
| self.assertEqual(row_from_col_no_format[0], result) |
| row_from_name_no_format = df.select(F.try_to_time("time")).first() |
| self.assertIsInstance(row_from_name_no_format[0], datetime.time) |
| self.assertEqual(row_from_name_no_format[0], result) |
| # Test with format. |
| row_from_col_with_format = df.select(F.try_to_time(df.time, df.format)).first() |
| self.assertIsInstance(row_from_col_with_format[0], datetime.time) |
| self.assertEqual(row_from_col_with_format[0], result) |
| row_from_name_with_format = df.select(F.try_to_time("time", "format")).first() |
| self.assertIsInstance(row_from_name_with_format[0], datetime.time) |
| self.assertEqual(row_from_name_with_format[0], result) |
| # Test with malformed time. |
| df = self.spark.createDataFrame([("malformed", "HH:mm:ss")], ["time", "format"]) |
| row_from_col_no_format_malformed = df.select(F.try_to_time(df.time)).first() |
| self.assertIsNone(row_from_col_no_format_malformed[0]) |
| row_from_name_no_format_malformed = df.select(F.try_to_time("time")).first() |
| self.assertIsNone(row_from_name_no_format_malformed[0]) |
| row_from_col_with_format_malformed = df.select(F.try_to_time(df.time, df.format)).first() |
| self.assertIsNone(row_from_col_with_format_malformed[0]) |
| row_from_name_with_format_malformed = df.select(F.try_to_time("time", "format")).first() |
| self.assertIsNone(row_from_name_with_format_malformed[0]) |
| |
| def test_to_variant_object(self): |
| df = self.spark.createDataFrame([(1, {"a": 1})], "i int, v struct<a int>") |
| actual = df.select( |
| F.to_json(F.to_variant_object(df.v)).alias("var"), |
| ).collect() |
| self.assertEqual("""{"a":1}""", actual[0]["var"]) |
| |
| def test_schema_of_csv(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.schema_of_csv(1) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_STR", |
| messageParameters={"arg_name": "csv", "arg_type": "int"}, |
| ) |
| |
| def test_from_csv(self): |
| df = self.spark.range(10) |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.from_csv(df.id, 1) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_STR", |
| messageParameters={"arg_name": "schema", "arg_type": "int"}, |
| ) |
| |
| def test_schema_of_xml(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.schema_of_xml(1) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_STR", |
| messageParameters={"arg_name": "xml", "arg_type": "int"}, |
| ) |
| |
| def test_from_xml(self): |
| df = self.spark.range(10) |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.from_xml(df.id, 1) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_STR_OR_STRUCT", |
| messageParameters={"arg_name": "schema", "arg_type": "int"}, |
| ) |
| |
| def test_greatest(self): |
| df = self.spark.range(10) |
| with self.assertRaises(PySparkValueError) as pe: |
| F.greatest(df.id) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="WRONG_NUM_COLUMNS", |
| messageParameters={"func_name": "greatest", "num_cols": "2"}, |
| ) |
| |
| def test_when(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.when("id", 1) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN", |
| messageParameters={"arg_name": "condition", "arg_type": "str"}, |
| ) |
| |
| def test_window(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.window("date", 5) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_STR", |
| messageParameters={"arg_name": "windowDuration", "arg_type": "int"}, |
| ) |
| |
| def test_session_window(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.session_window("date", 5) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_STR", |
| messageParameters={"arg_name": "gapDuration", "arg_type": "int"}, |
| ) |
| |
| def test_current_user(self): |
| df = self.spark.range(1).select(F.current_user()) |
| self.assertIsInstance(df.first()[0], str) |
| self.assertEqual(df.schema.names[0], "current_user()") |
| df = self.spark.range(1).select(F.user()) |
| self.assertEqual(df.schema.names[0], "user()") |
| df = self.spark.range(1).select(F.session_user()) |
| self.assertEqual(df.schema.names[0], "session_user()") |
| |
| def test_bucket(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| F.bucket("5", "id") |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_INT", |
| messageParameters={"arg_name": "numBuckets", "arg_type": "str"}, |
| ) |
| |
| def test_to_time(self): |
| # SPARK-52890: test the to_time function. |
| df = self.spark.createDataFrame([("10:30:00", "HH:mm:ss")], ["time", "format"]) |
| result = datetime.time(10, 30, 0) |
| # Test without format. |
| row_from_col_no_format = df.select(F.to_time(df.time)).first() |
| self.assertIsInstance(row_from_col_no_format[0], datetime.time) |
| self.assertEqual(row_from_col_no_format[0], result) |
| row_from_name_no_format = df.select(F.to_time("time")).first() |
| self.assertIsInstance(row_from_name_no_format[0], datetime.time) |
| self.assertEqual(row_from_name_no_format[0], result) |
| # Test with format. |
| row_from_col_with_format = df.select(F.to_time(df.time, df.format)).first() |
| self.assertIsInstance(row_from_col_with_format[0], datetime.time) |
| self.assertEqual(row_from_col_with_format[0], result) |
| row_from_name_with_format = df.select(F.to_time("time", "format")).first() |
| self.assertIsInstance(row_from_name_with_format[0], datetime.time) |
| self.assertEqual(row_from_name_with_format[0], result) |
| |
| def test_to_timestamp_ltz(self): |
| df = self.spark.createDataFrame([("2016-12-31",)], ["e"]) |
| df = df.select(F.to_timestamp_ltz(df.e, F.lit("yyyy-MM-dd")).alias("r")) |
| self.assertIsInstance(df.first()[0], datetime.datetime) |
| |
| df = self.spark.createDataFrame([("2016-12-31",)], ["e"]) |
| df = df.select(F.to_timestamp_ltz(df.e).alias("r")) |
| self.assertIsInstance(df.first()[0], datetime.datetime) |
| |
| def test_to_timestamp_ntz(self): |
| df = self.spark.createDataFrame([("2016-12-31",)], ["e"]) |
| df = df.select(F.to_timestamp_ntz(df.e).alias("r")) |
| self.assertIsInstance(df.first()[0], datetime.datetime) |
| |
| def test_convert_timezone(self): |
| df = self.spark.createDataFrame([("2015-04-08",)], ["dt"]) |
| df = df.select( |
| F.convert_timezone(F.lit("America/Los_Angeles"), F.lit("Asia/Hong_Kong"), "dt") |
| ) |
| self.assertIsInstance(df.first()[0], datetime.datetime) |
| |
| def test_map_concat(self): |
| df = self.spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c') as map2") |
| self.assertEqual( |
| df.select(F.map_concat(["map1", "map2"]).alias("map3")).first()[0], |
| {1: "a", 2: "b", 3: "c"}, |
| ) |
| |
| def test_version(self): |
| self.assertIsInstance(self.spark.range(1).select(F.version()).first()[0], str) |
| |
| # SPARK-45216: Fix non-deterministic seeded Dataset APIs |
| def test_non_deterministic_with_seed(self): |
| df = self.spark.createDataFrame([([*range(0, 10, 1)],)], ["a"]) |
| |
| r = F.rand() |
| r2 = F.randn() |
| r3 = F.shuffle("a") |
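        # SPARK-45216: the seed is fixed when the expression is created, so selecting the
        # same Column twice must yield identical values within a row.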
| res = df.select(r, r, r2, r2, r3, r3).collect() |
| for i in range(3): |
| self.assertEqual(res[0][i * 2], res[0][i * 2 + 1]) |
| |
| def test_current_time(self): |
| # SPARK-52889: test the current_time function without precision. |
| df = self.spark.range(1).select(F.current_time()) |
| self.assertIsInstance(df.first()[0], datetime.time) |
| self.assertEqual(df.schema.names[0], "current_time(6)") |
| # SPARK-52889: test the current_time function with precision. |
| df = self.spark.range(1).select(F.current_time(3)) |
| self.assertIsInstance(df.first()[0], datetime.time) |
| self.assertEqual(df.schema.names[0], "current_time(3)") |
| |
| def test_current_timestamp(self): |
| df = self.spark.range(1).select(F.current_timestamp()) |
| self.assertIsInstance(df.first()[0], datetime.datetime) |
| self.assertEqual(df.schema.names[0], "current_timestamp()") |
| df = self.spark.range(1).select(F.now()) |
| self.assertIsInstance(df.first()[0], datetime.datetime) |
| self.assertEqual(df.schema.names[0], "now()") |
| |
| def test_json_tuple_empty_fields(self): |
| df = self.spark.createDataFrame( |
| [ |
| ("1", """{"f1": "value1", "f2": "value2"}"""), |
| ("2", """{"f1": "value12"}"""), |
| ], |
| ("key", "jstring"), |
| ) |
| self.assertRaisesRegex( |
| PySparkValueError, |
| "At least one field must be specified", |
| lambda: df.select(F.json_tuple(df.jstring)), |
| ) |
| |
| def test_avro_type_check(self): |
| parameters = ["data", "jsonFormatSchema", "options"] |
| expected_type = ["pyspark.sql.Column or str", "str", "dict, optional"] |
| dummyDF = self.spark.createDataFrame([Row(a=i, b=i) for i in range(5)]) |
| |
| # test from_avro type checks for each parameter |
| wrong_type_value = 1 |
| with self.assertRaises(PySparkTypeError) as pe1: |
| dummyDF.select(from_avro(wrong_type_value, "jsonSchema", None)) |
| with self.assertRaises(PySparkTypeError) as pe2: |
| dummyDF.select(from_avro("value", wrong_type_value, None)) |
| with self.assertRaises(PySparkTypeError) as pe3: |
| dummyDF.select(from_avro("value", "jsonSchema", wrong_type_value)) |
| from_avro_pes = [pe1, pe2, pe3] |
| for i in range(3): |
| self.check_error( |
| exception=from_avro_pes[i].exception, |
| errorClass="INVALID_TYPE", |
| messageParameters={"arg_name": parameters[i], "arg_type": expected_type[i]}, |
| ) |
| |
| # test to_avro type checks for each parameter |
| with self.assertRaises(PySparkTypeError) as pe4: |
| dummyDF.select(to_avro(wrong_type_value, "jsonSchema")) |
| with self.assertRaises(PySparkTypeError) as pe5: |
| dummyDF.select(to_avro("value", wrong_type_value)) |
| to_avro_pes = [pe4, pe5] |
| for i in range(2): |
| self.check_error( |
| exception=to_avro_pes[i].exception, |
| errorClass="INVALID_TYPE", |
| messageParameters={"arg_name": parameters[i], "arg_type": expected_type[i]}, |
| ) |
| |
| def test_enum_literals(self): |
| class IntEnum(Enum): |
| X = 1 |
| Y = 2 |
| Z = 3 |
| |
| id = F.col("id") |
| b = F.col("b") |
| |
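        # Enum arguments are resolved to their underlying values (e.g. IntEnum.X -> 1).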
| cols, expected = list( |
| zip( |
| (F.lit(IntEnum.X), 1), |
| (F.lit([IntEnum.X, IntEnum.Y]), [1, 2]), |
| (F.rand(IntEnum.X), 0.9531453492357947), |
| (F.randn(IntEnum.X), -1.1081822375859998), |
| (F.when(b, IntEnum.X), 1), |
| ) |
| ) |
| |
| result = ( |
| self.spark.range(1, 2) |
| .select(id, id.astype("string").alias("s"), id.astype("boolean").alias("b")) |
| .select(*cols) |
| .first() |
| ) |
| |
| for r, c, e in zip(result, cols, expected): |
| self.assertEqual(r, e, str(c)) |
| |
| def test_nullifzero_zeroifnull(self): |
| df = self.spark.createDataFrame([(0,), (1,)], ["a"]) |
| result = df.select(nullifzero(df.a).alias("r")) |
| assertDataFrameEqual([Row(r=None), Row(r=1)], result) |
| |
| df = self.spark.createDataFrame([(None,), (1,)], ["a"]) |
| result = df.select(zeroifnull(df.a).alias("r")) |
| assertDataFrameEqual([Row(r=0), Row(r=1)], result) |
| |
| def test_randstr_uniform(self): |
| df = self.spark.createDataFrame([(0,)], ["a"]) |
| result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)") |
| assertDataFrameEqual([Row(5)], result) |
| # The random seed is optional. |
| result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)") |
| assertDataFrameEqual([Row(5)], result) |
| |
| df = self.spark.createDataFrame([(0,)], ["a"]) |
| result = df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x")).selectExpr("x > 5") |
| assertDataFrameEqual([Row(True)], result) |
| # The random seed is optional. |
| result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5") |
| assertDataFrameEqual([Row(True)], result) |
| |
| def test_string_validation(self): |
| df = self.spark.createDataFrame([("abc",)], ["a"]) |
| # test is_valid_utf8 |
| result_is_valid_utf8 = df.select(F.is_valid_utf8(df.a).alias("r")) |
| assertDataFrameEqual([Row(r=True)], result_is_valid_utf8) |
| # test make_valid_utf8 |
| result_make_valid_utf8 = df.select(F.make_valid_utf8(df.a).alias("r")) |
| assertDataFrameEqual([Row(r="abc")], result_make_valid_utf8) |
| # test validate_utf8 |
| result_validate_utf8 = df.select(F.validate_utf8(df.a).alias("r")) |
| assertDataFrameEqual([Row(r="abc")], result_validate_utf8) |
| # test try_validate_utf8 |
| result_try_validate_utf8 = df.select(F.try_validate_utf8(df.a).alias("r")) |
| assertDataFrameEqual([Row(r="abc")], result_try_validate_utf8) |
| |
| |
| class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin): |
| pass |
| |
| |
| if __name__ == "__main__": |
| import unittest |
| from pyspark.sql.tests.test_functions import * # noqa: F401 |
| |
| try: |
| import xmlrunner |
| |
| testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) |
| except ImportError: |
| testRunner = None |
| unittest.main(testRunner=testRunner, verbosity=2) |