blob: 5c6acaffa324b70243eb6c745b88b9bbb81c0a46 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
from itertools import chain
import re
import math
from py4j.protocol import Py4JJavaError
from pyspark.sql import Row, Window, types
from pyspark.sql.functions import (
udf,
input_file_name,
col,
percentile_approx,
lit,
assert_true,
sum_distinct,
sumDistinct,
shiftleft,
shiftLeft,
shiftRight,
shiftright,
shiftrightunsigned,
shiftRightUnsigned,
octet_length,
bit_length,
sec,
csc,
cot,
make_date,
date_add,
date_sub,
add_months,
array_repeat,
size,
slice,
least,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
class FunctionsTests(ReusedSQLTestCase):
def test_explode(self):
    # Exercise explode, explode_outer and posexplode_outer on an array column and a
    # map column whose three rows hold a populated, an empty and a null value.
    from pyspark.sql.functions import explode, explode_outer, posexplode_outer

    rows = [
        Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
        Row(a=1, intlist=[], mapfield={}),
        Row(a=1, intlist=None, mapfield=None),
    ]
    data = self.spark.createDataFrame(self.sc.parallelize(rows))

    # Plain explode: the first three output rows carry the array elements in order.
    exploded_ints = data.select(explode(data.intlist).alias("a")).select("a").collect()
    for position, value in enumerate((1, 2, 3)):
        self.assertEqual(exploded_ints[position][0], value)

    exploded_map = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
    first_entry = exploded_map[0]
    self.assertEqual(first_entry[0], "a")
    self.assertEqual(first_entry[1], "b")

    # The *_outer variants are expected to yield an all-null row for the empty and
    # the null inputs.
    pos_ints = data.select(posexplode_outer("intlist")).collect()
    self.assertEqual(
        [tuple(r) for r in pos_ints], [(0, 1), (1, 2), (2, 3), (None, None), (None, None)]
    )

    pos_map = data.select(posexplode_outer("mapfield")).collect()
    self.assertEqual(
        [tuple(r) for r in pos_map], [(0, "a", "b"), (None, None, None), (None, None, None)]
    )

    outer_ints = data.select(explode_outer("intlist")).collect()
    self.assertEqual([r[0] for r in outer_ints], [1, 2, 3, None, None])

    outer_map = data.select(explode_outer("mapfield")).collect()
    self.assertEqual([tuple(r) for r in outer_map], [("a", "b"), (None, None), (None, None)])
def test_basic_functions(self):
    # Smoke-test basic DataFrame actions, the is_cached flag around persist/unpersist/
    # cache, and querying the same data through a temporary view.
    json_lines = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
    frame = self.spark.read.json(json_lines)
    frame.count()
    frame.collect()
    frame.schema

    # cache and checkpoint
    self.assertFalse(frame.is_cached)
    frame.persist()
    frame.unpersist(True)
    frame.cache()
    self.assertTrue(frame.is_cached)
    self.assertEqual(2, frame.count())

    with self.tempView("temp"):
        frame.createOrReplaceTempView("temp")
        selected = self.spark.sql("select foo from temp")
        selected.count()
        selected.collect()
def test_corr(self):
    # Correlation between a = i and b = sqrt(i) for i in 0..9 must match the
    # reference value to within 1e-6.
    # (math is already imported at module level, so no local import is needed.)
    df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
    corr = df.stat.corr("a", "b")
    # assertLess reports both operands on failure, unlike assertTrue(<comparison>).
    self.assertLess(abs(corr - 0.95734012), 1e-6)
def test_sampleby(self):
    # Stratified sampling on b = i % 3 with a fixed seed; the test pins the number
    # of sampled rows to 35.
    df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(100)]).toDF()
    sampled = df.stat.sampleBy("b", fractions={0: 0.5, 1: 0.5}, seed=0)
    # assertEqual shows the actual count on failure instead of "False is not true".
    self.assertEqual(sampled.count(), 35)
def test_cov(self):
    # Covariance between a = i and b = 2 * i for i in 0..9 must equal 55 / 3 to
    # within 1e-6.
    df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
    cov = df.stat.cov("a", "b")
    # assertLess reports both operands on failure, unlike assertTrue(<comparison>).
    self.assertLess(abs(cov - 55.0 / 3), 1e-6)
def test_crosstab(self):
    # Build a contingency table of a = i % 3 against b = i % 2 for i in 1..6.
    # Every (a, b) pair occurs exactly once in that input, so each count cell is 1.
    df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
    ct = sorted(df.stat.crosstab("a", "b").collect(), key=lambda x: x[0])
    for i, row in enumerate(ct):
        # The first column holds the value of "a" rendered as a string.
        self.assertEqual(row[0], str(i))
        # The original used assertTrue(row[n], 1), which treats 1 as the failure
        # message and passes for any truthy count; compare the values instead.
        self.assertEqual(row[1], 1)
        self.assertEqual(row[2], 1)
def test_math_functions(self):
    # Compare Spark's cos/sin/pow/hypot with Python's math module. Arguments are
    # passed as Column objects, as column names and, for the second operand, as
    # plain numeric literals.
    from pyspark.sql import functions as F

    df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
    domain = range(10)

    cosines = [math.cos(i) for i in domain]
    sines = [math.sin(i) for i in domain]
    squares = [math.pow(i, 2) for i in domain]
    hypot_ab = [math.hypot(i, 2 * i) for i in domain]
    hypot_a2 = [math.hypot(i, 2) for i in domain]

    # (expected values, Spark column expression), in the order they are checked.
    cases = [
        (cosines, F.cos(df.a)),
        (cosines, F.cos("a")),
        (sines, F.sin(df.a)),
        (sines, F.sin(df["a"])),
        ([math.pow(i, 2 * i) for i in domain], F.pow(df.a, df.b)),
        (squares, F.pow(df.a, 2)),
        (squares, F.pow(df.a, 2.0)),
        (hypot_ab, F.hypot(df.a, df.b)),
        (hypot_ab, F.hypot("a", "b")),
        (hypot_a2, F.hypot("a", 2)),
        (hypot_a2, F.hypot(df.a, 2)),
    ]
    for expected, column in cases:
        SQLTestUtils.assert_close(expected, df.select(column).collect())
def test_inverse_trig_functions(self):
    # The repr of acosh/asinh/atanh applied to column "a" must contain the
    # upper-case function name wrapped around the column, for both a column name
    # and a Column argument.
    from pyspark.sql import functions

    named_funcs = (
        ("ACOSH", functions.acosh),
        ("ASINH", functions.asinh),
        ("ATANH", functions.atanh),
    )
    inputs = ("a", functions.col("a"))
    for sql_name, func in named_funcs:
        for arg in inputs:
            self.assertIn(f"{sql_name}(a)", repr(func(arg)))
def test_reciprocal_trig_functions(self):
    # SPARK-36683: Tests for reciprocal trig functions (SEC, CSC and COT)
    angles = [
        0.0,
        math.pi / 6,
        math.pi / 4,
        math.pi / 3,
        math.pi / 2,
        math.pi,
        3 * math.pi / 2,
        2 * math.pi,
    ]
    df = self.spark.createDataFrame(angles, types.DoubleType())

    def to_reciprocal_trig(func):
        # 1 / func(angle) per angle, with +inf wherever func(angle) is exactly zero.
        reciprocals = []
        for angle in angles:
            value = func(angle)
            reciprocals.append(math.inf if value == 0 else 1.0 / value)
        return reciprocals

    for py_func, spark_func in ((math.cos, sec), (math.sin, csc), (math.tan, cot)):
        SQLTestUtils.assert_close(
            to_reciprocal_trig(py_func), df.select(spark_func(df.value)).collect()
        )
def test_rand_functions(self):
    # Check the value ranges produced by rand()/randn() and that a fixed seed,
    # including 0, gives repeatable results.
    from pyspark.sql import functions

    df = self.df

    # Use unittest assertions rather than bare `assert`, which is removed when
    # Python runs with -O and would silently disable these checks.
    for row in df.select("key", functions.rand()).collect():
        self.assertTrue(0.0 <= row[1] <= 1.0, "got: %s" % row[1])
    for row in df.select("key", functions.randn(5)).collect():
        self.assertTrue(-4.0 <= row[1] <= 4.0, "got: %s" % row[1])

    # If the specified seed is 0, we should use it.
    # https://issues.apache.org/jira/browse/SPARK-9691
    rnd1 = df.select("key", functions.rand(0)).collect()
    rnd2 = df.select("key", functions.rand(0)).collect()
    self.assertEqual(sorted(rnd1), sorted(rnd2))

    rndn1 = df.select("key", functions.randn(0)).collect()
    rndn2 = df.select("key", functions.randn(0)).collect()
    self.assertEqual(sorted(rndn1), sorted(rndn2))
def test_string_functions(self):
    # Unary string functions must give the same result for a column name and for
    # a Column; substr() must reject an (int, Column) argument pair.
    from pyspark.sql import functions
    from pyspark.sql.functions import col, lit

    df = self.spark.createDataFrame([["nick"]], schema=["name"])

    with self.assertRaisesRegex(TypeError, "must be the same type"):
        df.select(col("name").substr(0, lit(1)))

    unary_names = (
        "upper",
        "lower",
        "ascii",
        "base64",
        "unbase64",
        "ltrim",
        "rtrim",
        "trim",
    )
    for name in unary_names:
        func = getattr(functions, name)
        by_name = df.select(func("name")).first()[0]
        by_column = df.select(func(col("name"))).first()[0]
        self.assertEqual(by_name, by_column)
def test_octet_length_function(self):
    # SPARK-36751: add octet length api for python
    # One ASCII string and one single-code-point string outside the BMP.
    samples = [("cat",), ("\U0001F408",)]
    df = self.spark.createDataFrame(samples, ["cat"])
    byte_counts = df.select(octet_length("cat")).collect()
    self.assertEqual([Row(3), Row(4)], byte_counts)
def test_bit_length_function(self):
    # SPARK-36751: add bit length api for python
    # One ASCII string and one single-code-point string outside the BMP.
    samples = [("cat",), ("\U0001F408",)]
    df = self.spark.createDataFrame(samples, ["cat"])
    bit_counts = df.select(bit_length("cat")).collect()
    self.assertEqual([Row(24), Row(32)], bit_counts)
def test_array_contains_function(self):
    # array_contains is True for the array holding "1" and False for the empty array.
    from pyspark.sql.functions import array_contains

    df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ["data"])
    membership = df.select(array_contains(df.data, "1").alias("b"))
    self.assertEqual([Row(b=True), Row(b=False)], membership.collect())
def test_between_function(self):
    # Column.between with column bounds keeps the rows where b <= a <= c; the third
    # row (a == c) shows the upper bound is inclusive.
    rows = [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]
    df = self.sc.parallelize(rows).toDF()
    in_range = df.filter(df.a.between(df.b, df.c)).collect()
    self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], in_range)
def test_dayofweek(self):
    # dayofweek of 2017-11-06 is expected to be 2.
    from pyspark.sql.functions import dayofweek

    moment = datetime.datetime(2017, 11, 6)
    df = self.spark.createDataFrame([Row(date=moment)])
    first_row = df.select(dayofweek(df.date)).first()
    self.assertEqual(first_row[0], 2)
# Test added for SPARK-37738; change Python API to accept both col & int as input
def test_date_add_function(self):
    start = datetime.date(2021, 12, 27)
    # Note: a Python number is converted to a LongType column, which the function
    # does not support, so the schema string declares the column as integer.
    df = self.spark.createDataFrame([Row(date=start, add=2)], "date date, add integer")
    # The day count is given as a Column, as a column name and as a Python int.
    checks = df.select(
        date_add(df.date, df.add) == datetime.date(2021, 12, 29),
        date_add(df.date, "add") == datetime.date(2021, 12, 29),
        date_add(df.date, 3) == datetime.date(2021, 12, 30),
    ).first()
    self.assertTrue(all(checks))
# Test added for SPARK-37738; change Python API to accept both col & int as input
def test_date_sub_function(self):
    start = datetime.date(2021, 12, 27)
    # Note: a Python number is converted to a LongType column, which the function
    # does not support, so the schema string declares the column as integer.
    df = self.spark.createDataFrame([Row(date=start, sub=2)], "date date, sub integer")
    # The day count is given as a Column, as a column name and as a Python int.
    checks = df.select(
        date_sub(df.date, df.sub) == datetime.date(2021, 12, 25),
        date_sub(df.date, "sub") == datetime.date(2021, 12, 25),
        date_sub(df.date, 3) == datetime.date(2021, 12, 24),
    ).first()
    self.assertTrue(all(checks))
# Test added for SPARK-37738; change Python API to accept both col & int as input
def test_add_months_function(self):
    start = datetime.date(2021, 12, 27)
    # Note: a Python number is converted to a LongType column, which the function
    # does not support, so the schema string declares the column as integer.
    df = self.spark.createDataFrame([Row(date=start, add=2)], "date date, add integer")
    # The month count is given as a Column, as a column name and as a Python int.
    checks = df.select(
        add_months(df.date, df.add) == datetime.date(2022, 2, 27),
        add_months(df.date, "add") == datetime.date(2022, 2, 27),
        add_months(df.date, 3) == datetime.date(2022, 3, 27),
    ).first()
    self.assertTrue(all(checks))
def test_make_date(self):
    # SPARK-36554: expose make_date expression
    df = self.spark.createDataFrame([(2020, 6, 26)], ["Y", "M", "D"])
    expected = datetime.date(2020, 6, 26)
    # First with Column arguments, then with column names.
    for year, month, day in ((df.Y, df.M, df.D), ("Y", "M", "D")):
        made = df.select(make_date(year, month, day)).first()
        self.assertEqual(made[0], expected)
def test_expr(self):
    # expr("length(a)") produces a column named after the expression text whose
    # value is the length of the 13-character string in column a.
    from pyspark.sql import functions

    df = self.spark.createDataFrame([Row(a="length string", b=75)])
    first_row = df.select(functions.expr("length(a)")).collect()[0]
    self.assertEqual(13, first_row.asDict()["length(a)"])
# add test for SPARK-10577 (test broadcast join hint)
def test_functions_broadcast(self):
    from pyspark.sql.functions import broadcast

    def executed_plan(frame):
        # JVM-side physical plan of the given DataFrame.
        return frame._jdf.queryExecution().executedPlan()

    df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
    df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))

    # equijoin - should be converted into broadcast join
    equi_plan = executed_plan(df1.join(broadcast(df2), "key"))
    self.assertEqual(1, equi_plan.toString().count("BroadcastHashJoin"))

    # no join key -- should not be a broadcast join
    cross_plan = executed_plan(df1.crossJoin(broadcast(df2)))
    self.assertEqual(0, cross_plan.toString().count("BroadcastHashJoin"))

    # planner should not crash without a join
    executed_plan(broadcast(df1))
def test_first_last_ignorenulls(self):
    # first()/last() with and without the ignore-nulls flag on ids 0..99 where
    # every multiple of 3 (including both 0 and 99) has been replaced by null.
    from pyspark.sql import functions

    ids = self.spark.range(0, 100)
    masked = ids.select(functions.when(ids.id % 3 == 0, None).otherwise(ids.id).alias("id"))
    summary = masked.select(
        functions.first(masked.id, False).alias("a"),
        functions.first(masked.id, True).alias("b"),
        functions.last(masked.id, False).alias("c"),
        functions.last(masked.id, True).alias("d"),
    )
    self.assertEqual([Row(a=None, b=1, c=None, d=98)], summary.collect())
def test_approxQuantile(self):
    # Check the shape and element types returned by DataFrame.stat.approxQuantile
    # for a single column name, a list of names and a tuple of names, and that
    # non-string column specifications raise TypeError.
    df = self.sc.parallelize([Row(a=i, b=i + 10) for i in range(10)]).toDF()
    probabilities = [0.1, 0.5, 0.9]

    def check_quantiles(values):
        # One float per requested probability, returned as a list.
        self.assertIsInstance(values, list)
        self.assertEqual(len(values), len(probabilities))
        self.assertTrue(all(isinstance(q, float) for q in values))

    # A single column name yields a flat list of quantiles. (The original looped
    # over ["a", "a"], which only repeated the identical case.)
    check_quantiles(df.stat.approxQuantile("a", probabilities, 0.1))

    # A list or tuple of names yields one quantile list per column.
    for names in (["a", "b"], ("a", "b")):
        per_column = df.stat.approxQuantile(names, probabilities, 0.1)
        self.assertIsInstance(per_column, list)
        self.assertEqual(len(per_column), 2)
        for quantiles in per_column:
            check_quantiles(quantiles)

    # Column specifications that are not strings are rejected.
    for bad_names in (123, ("a", 123), ["a", 123]):
        with self.assertRaises(TypeError):
            df.stat.approxQuantile(bad_names, [0.1, 0.9], 0.1)
def test_sorting_functions_with_column(self):
    # Each sorting function must accept a Column or a column name, return a Column,
    # and render the expected ordering clause in its string form.
    from pyspark.sql import functions
    from pyspark.sql.column import Column

    operands = [col("x"), "x"]

    def check(sort_func, rendered):
        for operand in operands:
            ordering = sort_func(operand)
            self.assertIsInstance(ordering, Column)
            self.assertIn(rendered, str(ordering))

    # For the *_nulls_* variants the expected clause is derived from the function name.
    explicit_null_order = (
        functions.asc_nulls_first,
        functions.asc_nulls_last,
        functions.desc_nulls_first,
        functions.desc_nulls_last,
    )
    for fun in explicit_null_order:
        check(fun, f"""'x {fun.__name__.replace("_", " ").upper()}'""")

    check(functions.asc, """'x ASC NULLS FIRST'""")
    check(functions.desc, """'x DESC NULLS LAST'""")
def test_sort_with_nulls_order(self):
    # Order a name column containing one null with each of the four explicit
    # null-ordering functions and compare against the expected sequence.
    from pyspark.sql import functions

    df = self.spark.createDataFrame(
        [("Tom", 80), (None, 60), ("Alice", 50)], ["name", "height"]
    )
    cases = [
        (functions.asc_nulls_first, [None, "Alice", "Tom"]),
        (functions.asc_nulls_last, ["Alice", "Tom", None]),
        (functions.desc_nulls_first, [None, "Tom", "Alice"]),
        (functions.desc_nulls_last, ["Tom", "Alice", None]),
    ]
    for order_func, expected_names in cases:
        ordered = df.select(df.name).orderBy(order_func("name")).collect()
        self.assertEqual(ordered, [Row(name=n) for n in expected_names])
def test_input_file_name_reset_for_rdd(self):
    # Run a job over a text file first, then check that input_file_name() on a
    # source that is not a file does not report that earlier path.
    rdd = self.sc.textFile("python/test_support/hello/hello.txt").map(lambda x: {"data": x})
    df = self.spark.createDataFrame(rdd, "data STRING")
    df.select(input_file_name().alias("file")).collect()

    non_file_df = self.spark.range(100).select(input_file_name())
    results = non_file_df.collect()
    # assertEqual reports the actual row count on failure.
    self.assertEqual(len(results), 100)

    # [SPARK-24605]: if everything was properly reset after the last job, this should return
    # empty string rather than the file read in the last job.
    for result in results:
        self.assertEqual(result[0], "")
def test_slice(self):
    # slice() must accept start/length as Python ints, as literal Columns, as
    # column names, and as computed Column expressions.
    df = self.spark.createDataFrame(
        [([1, 2, 3], 2, 2), ([4, 5], 2, 2)],
        ["x", "index", "len"],
    )

    def collect_sliced(column):
        return df.select(column.alias("sliced")).collect()

    expected = [Row(sliced=[2, 3]), Row(sliced=[5])]
    outcomes = [
        collect_sliced(slice(df.x, 2, 2)) == expected,
        collect_sliced(slice(df.x, lit(2), lit(2))) == expected,
        collect_sliced(slice("x", "index", "len")) == expected,
    ]
    self.assertTrue(all(outcomes))

    # Start position computed from the array size.
    self.assertEqual(
        collect_sliced(slice(df.x, size(df.x) - 1, lit(1))),
        [Row(sliced=[2]), Row(sliced=[4])],
    )
    # Length computed from the array size.
    self.assertEqual(
        collect_sliced(slice(df.x, lit(1), size(df.x) - 1)),
        [Row(sliced=[1, 2]), Row(sliced=[4])],
    )
def test_array_repeat(self):
    # array_repeat must accept the repeat count as a Python int, as a literal
    # Column and as a column name.
    df = self.spark.range(1).withColumn("repeat_n", lit(3))
    expected = [Row(val=[0, 0, 0])]
    counts = (3, lit(3), "repeat_n")
    outcomes = [
        df.select(array_repeat("id", count).alias("val")).collect() == expected
        for count in counts
    ]
    self.assertTrue(all(outcomes))
def test_input_file_name_udf(self):
    # input_file_name() must still report the source path when it is selected
    # alongside a Python UDF column.
    df = self.spark.read.text("python/test_support/hello/hello.txt")
    df = df.select(udf(lambda x: x)("value"), input_file_name().alias("file"))
    file_name = df.collect()[0].file
    # assertIn shows both strings on failure, unlike assertTrue(a in b).
    self.assertIn("python/test_support/hello/hello.txt", file_name)
def test_least(self):
    # least() must accept Column attributes, literal Columns and column names;
    # each argument set has 1 as its minimum.
    df = self.spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"])
    expected = [Row(least=1)]
    argument_sets = (
        (df.a, df.b, df.c),
        (lit(3), lit(5), lit(1)),
        ("a", "b", "c"),
    )
    outcomes = [
        df.select(least(*args).alias("least")).collect() == expected for args in argument_sets
    ]
    self.assertTrue(all(outcomes))
def test_overlay(self):
    # First check how overlay() calls render in the Column string form (a missing
    # length shows up as -1), then check the evaluated result on a real DataFrame.
    from pyspark.sql.functions import col, lit, overlay
    import re

    columns = [
        overlay(col("foo"), col("bar"), 1),
        overlay("x", "y", 3),
        overlay(col("x"), col("y"), 1, 3),
        overlay("x", "y", 2, 5),
        overlay("x", "y", lit(11)),
        overlay("x", "y", lit(2), lit(5)),
    ]
    actual = [
        match for column in columns for match in re.findall("(overlay\\(.*\\))", str(column))
    ]
    expected = [
        "overlay(foo, bar, 1, -1)",
        "overlay(x, y, 3, -1)",
        "overlay(x, y, 1, 3)",
        "overlay(x, y, 2, 5)",
        "overlay(x, y, 11, -1)",
        "overlay(x, y, 2, 5)",
    ]
    self.assertListEqual(actual, expected)

    df = self.spark.createDataFrame([("SPARK_SQL", "CORE", 7, 0)], ("x", "y", "pos", "len"))
    exp = [Row(ol="SPARK_CORESQL")]
    # Position and length as Python ints, literal Columns and column names.
    overlays = (
        overlay(df.x, df.y, 7, 0),
        overlay(df.x, df.y, lit(7), lit(0)),
        overlay("x", "y", "pos", "len"),
    )
    outcomes = [df.select(ol.alias("ol")).collect() == exp for ol in overlays]
    self.assertTrue(all(outcomes))
def test_percentile_approx(self):
    # Check how percentile_approx() calls render in the Column string form:
    # sequences of percentages become array(...) and an omitted accuracy shows
    # up as 10000.
    columns = [
        percentile_approx(col("foo"), lit(0.5)),
        percentile_approx(col("bar"), 0.25, 42),
        percentile_approx(col("bar"), [0.25, 0.5, 0.75]),
        percentile_approx(col("foo"), (0.05, 0.95), 100),
        percentile_approx("foo", 0.5),
        percentile_approx("bar", [0.1, 0.9], lit(10)),
    ]
    rendered = (re.findall("(percentile_approx\\(.*\\))", str(column)) for column in columns)
    actual = list(chain.from_iterable(rendered))
    expected = [
        "percentile_approx(foo, 0.5, 10000)",
        "percentile_approx(bar, 0.25, 42)",
        "percentile_approx(bar, array(0.25, 0.5, 0.75), 10000)",
        "percentile_approx(foo, array(0.05, 0.95), 100)",
        "percentile_approx(foo, 0.5, 10000)",
        "percentile_approx(bar, array(0.1, 0.9), 10)",
    ]
    self.assertListEqual(actual, expected)
def test_nth_value(self):
    # nth_value(value, 2) over a per-key ordered window, with the third argument
    # omitted, False and True.
    from pyspark.sql import Window
    from pyspark.sql.functions import nth_value

    df = self.spark.createDataFrame(
        [
            ("a", 0, None),
            ("a", 1, "x"),
            ("a", 2, "y"),
            ("a", 3, "z"),
            ("a", 4, None),
            ("b", 1, None),
            ("b", 2, None),
        ],
        schema=("key", "order", "value"),
    )
    w = Window.partitionBy("key").orderBy("order")
    second_values = [
        nth_value("value", 2).over(w),
        nth_value("value", 2, False).over(w),
        nth_value("value", 2, True).over(w),
    ]
    rs = df.select(df.key, df.order, *second_values).collect()

    expected = [
        ("a", 0, None, None, None),
        ("a", 1, "x", "x", None),
        ("a", 2, "x", "x", "y"),
        ("a", 3, "x", "x", "y"),
        ("a", 4, "x", "x", "y"),
        ("b", 1, None, None, None),
        ("b", 2, None, None, None),
    ]
    # Both sides are sorted so the comparison does not depend on collect() order.
    for actual_row, expected_row in zip(sorted(rs), sorted(expected)):
        self.assertEqual(tuple(actual_row), expected_row[: len(actual_row)])
def test_higher_order_function_failures(self):
    # transform() must raise ValueError for each unsupported callable shape.
    from pyspark.sql.functions import col, transform

    rejected = [
        lambda *x: lit(1),  # Should fail with varargs
        lambda **x: lit(1),  # Should fail with kwargs
        lambda: lit(1),  # Should fail with nullary function
        lambda x1, x2, x3, x4: lit(1),  # Should fail with quaternary function
        lambda x: 1,  # Should fail if function doesn't return Column
    ]
    for func in rejected:
        with self.assertRaises(ValueError):
            transform(col("foo"), func)
def test_nested_higher_order_function(self):
    # SPARK-35382: lambda vars must be resolved properly in nested higher order functions
    from pyspark.sql.functions import flatten, struct, transform

    df = self.spark.sql("SELECT array(1, 2, 3) as numbers, array('a', 'b', 'c') as letters")

    # The inner lambda refers to the outer lambda's variable `number`.
    actual = df.select(
        flatten(
            transform(
                "numbers",
                lambda number: transform(
                    "letters", lambda letter: struct(number.alias("n"), letter.alias("l"))
                ),
            )
        )
    ).first()[0]

    # Every number paired with every letter, numbers varying slowest.
    expected = [(number, letter) for number in (1, 2, 3) for letter in ("a", "b", "c")]

    # assertEquals is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(actual, expected)
def test_window_functions(self):
    # Aggregate and ranking functions over a window partitioned by value and
    # ordered by key.
    from pyspark.sql import functions as F

    df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
    w = Window.partitionBy("value").orderBy("key")
    current_and_next = w.rowsBetween(0, 1)
    whole_partition = w.rowsBetween(float("-inf"), float("inf"))

    sel = df.select(
        df.value,
        df.key,
        F.max("key").over(current_and_next),
        F.min("key").over(current_and_next),
        F.count("key").over(whole_partition),
        F.row_number().over(w),
        F.rank().over(w),
        F.dense_rank().over(w),
        F.ntile(2).over(w),
    )
    expected = [
        ("1", 1, 1, 1, 1, 1, 1, 1, 1),
        ("2", 1, 1, 1, 3, 1, 1, 1, 1),
        ("2", 1, 2, 1, 3, 2, 1, 1, 1),
        ("2", 2, 2, 2, 3, 3, 3, 2, 2),
    ]
    for actual_row, expected_row in zip(sorted(sel.collect()), expected):
        self.assertEqual(tuple(actual_row), expected_row[: len(actual_row)])
def test_window_functions_without_partitionBy(self):
    # Same functions as the partitioned case, but over a window that is only
    # ordered (by key, then value) and has no partitionBy.
    from pyspark.sql import functions as F

    df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
    w = Window.orderBy("key", df.value)
    current_and_next = w.rowsBetween(0, 1)
    all_rows = w.rowsBetween(float("-inf"), float("inf"))

    sel = df.select(
        df.value,
        df.key,
        F.max("key").over(current_and_next),
        F.min("key").over(current_and_next),
        F.count("key").over(all_rows),
        F.row_number().over(w),
        F.rank().over(w),
        F.dense_rank().over(w),
        F.ntile(2).over(w),
    )
    expected = [
        ("1", 1, 1, 1, 4, 1, 1, 1, 1),
        ("2", 1, 1, 1, 4, 2, 2, 2, 1),
        ("2", 1, 2, 1, 4, 3, 2, 2, 2),
        ("2", 2, 2, 2, 4, 4, 4, 3, 2),
    ]
    for actual_row, expected_row in zip(sorted(sel.collect()), expected):
        self.assertEqual(tuple(actual_row), expected_row[: len(actual_row)])
def test_window_functions_cumulative_sum(self):
    # Sum over row frames built from Window's boundary constants, including
    # boundaries pushed one step past those constants.
    from pyspark.sql import functions as F

    df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"])

    def assert_windowed_sum(window, expected):
        rows = sorted(df.select(df.key, F.sum(df.value).over(window)).collect())
        for actual_row, expected_row in zip(rows, expected):
            self.assertEqual(tuple(actual_row), expected_row[: len(actual_row)])

    # Test cumulative sum
    assert_windowed_sum(
        Window.rowsBetween(Window.unboundedPreceding, 0), [("one", 1), ("two", 3)]
    )

    # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow
    assert_windowed_sum(
        Window.rowsBetween(Window.unboundedPreceding - 1, 0), [("one", 1), ("two", 3)]
    )

    # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow
    frame_end = Window.unboundedFollowing + 1
    assert_windowed_sum(
        Window.rowsBetween(Window.currentRow, frame_end), [("one", 3), ("two", 2)]
    )
def test_collect_functions(self):
    # collect_set keeps distinct values and collect_list keeps duplicates; results
    # are sorted before comparison, so element order is not part of the check.
    from pyspark.sql import functions

    df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
    cases = [
        (functions.collect_set, df.key, [1, 2]),
        (functions.collect_list, df.key, [1, 1, 1, 2]),
        (functions.collect_set, df.value, ["1", "2"]),
        (functions.collect_list, df.value, ["1", "2", "2", "2"]),
    ]
    for collect, column, expected in cases:
        collected = df.select(collect(column).alias("r")).collect()[0].r
        self.assertEqual(sorted(collected), expected)
def test_datetime_functions(self):
    # to_date parses the string column into a date, and the result column is
    # named "to_date(dateCol)".
    from pyspark.sql import functions
    from datetime import date

    source = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
    parsed = source.select(functions.to_date(functions.col("dateCol"))).first()
    self.assertEqual(date(2017, 1, 22), parsed["to_date(dateCol)"])
def test_assert_true(self):
    # assert_true yields nulls when the condition holds, raises a JVM
    # RuntimeException carrying the message otherwise, and rejects a message that
    # is neither a Column nor a str.
    from pyspark.sql.functions import assert_true

    df = self.spark.range(3)
    passing = df.select(assert_true(df.id < 3)).toDF("val").collect()
    self.assertEqual(passing, [Row(val=None), Row(val=None), Row(val=None)])

    # Error message given as a str, then as a Column expression.
    for err_msg, fragment in (("too big", "too big"), (df.id * 1e6, "2000000")):
        with self.assertRaises(Py4JJavaError) as cm:
            df.select(assert_true(df.id < 2, err_msg)).toDF("val").collect()
        self.assertIn("java.lang.RuntimeException", str(cm.exception))
        self.assertIn(fragment, str(cm.exception))

    with self.assertRaises(TypeError) as cm:
        df.select(assert_true(df.id < 2, 5))
    self.assertEqual("errMsg should be a Column or a str, got <class 'int'>", str(cm.exception))
def test_raise_error(self):
    # raise_error surfaces as a JVM RuntimeException carrying the message, whether
    # the message is a Column or a str; any other message type is a TypeError.
    from pyspark.sql.functions import raise_error

    df = self.spark.createDataFrame([Row(id="foobar")])

    for message, fragment in ((df.id, "foobar"), ("barfoo", "barfoo")):
        with self.assertRaises(Py4JJavaError) as cm:
            df.select(raise_error(message)).collect()
        self.assertIn("java.lang.RuntimeException", str(cm.exception))
        self.assertIn(fragment, str(cm.exception))

    with self.assertRaises(TypeError) as cm:
        df.select(raise_error(None))
    self.assertEqual(
        "errMsg should be a Column or a str, got <class 'NoneType'>", str(cm.exception)
    )
def test_sum_distinct(self):
    # sum_distinct and sumDistinct must agree; the comparison is wrapped in
    # assert_true so the check runs inside the query.
    ids = self.spark.range(10)
    same_result = sum_distinct(col("id")) == sumDistinct(col("id"))
    ids.select(assert_true(same_result)).collect()
def test_shiftleft(self):
    # shiftLeft and shiftleft must agree; the comparison is wrapped in assert_true
    # so the check runs inside the query.
    ids = self.spark.range(10)
    same_result = shiftLeft(col("id"), 2) == shiftleft(col("id"), 2)
    ids.select(assert_true(same_result)).collect()
def test_shiftright(self):
    # shiftRight and shiftright must agree; the comparison is wrapped in
    # assert_true so the check runs inside the query.
    ids = self.spark.range(10)
    same_result = shiftRight(col("id"), 2) == shiftright(col("id"), 2)
    ids.select(assert_true(same_result)).collect()
def test_shiftrightunsigned(self):
    # shiftRightUnsigned and shiftrightunsigned must agree; the comparison is
    # wrapped in assert_true so the check runs inside the query.
    ids = self.spark.range(10)
    same_result = shiftRightUnsigned(col("id"), 2) == shiftrightunsigned(col("id"), 2)
    ids.select(assert_true(same_result)).collect()
def test_lit_day_time_interval(self):
    # A timedelta passed through lit() must come back from Spark unchanged.
    interval = datetime.timedelta(days=1, hours=12, milliseconds=123)
    round_tripped = self.spark.range(1).select(lit(interval)).first()[0]
    self.assertEqual(round_tripped, interval)
if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_functions import *  # noqa: F401

    # Write XML reports when xmlrunner can be imported; otherwise testRunner stays
    # None and is passed to unittest.main() as such.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)