blob: c7274c0810cfb6906f8c0bf82b8380cf13ec2f0f [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
from pyspark.errors import PySparkValueError
from pyspark.sql import functions as sf
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase
class TVFTestsMixin:
    """Shared test cases for the ``SparkSession.tvf`` table-valued-function API.

    Every test builds a DataFrame through ``self.spark.tvf.<function>(...)`` and compares it,
    via ``assertDataFrameEqual``, with the equivalent SQL query. The concrete test case that
    mixes this in must provide ``self.spark``, ``self.tempView`` and ``self.check_error``
    (for example by also inheriting from ``ReusedSQLTestCase``).
    """

    def _create_t1_and_t3_views(self):
        """Create and register the ``t1`` and ``t3`` temp views used by lateral-join tests.

        Call this inside ``with self.tempView("t1", "t3"):`` so that both registered views
        are cleaned up when the test finishes.

        Returns
        -------
        tuple
            ``(t1, t3)``: the DataFrames backing the registered views. ``t1`` has two int
            columns ``c1``/``c2``; ``t3`` has an int column ``c1`` and an array column ``c2``
            (including an empty array and a NULL ``c1``).
        """
        t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
        t1.createOrReplaceTempView("t1")
        t3 = self.spark.sql(
            "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, ARRAY(4)) "
            "AS t3(c1, c2)"
        )
        t3.createOrReplaceTempView("t3")
        return t1, t3

    def test_explode(self):
        actual = self.spark.tvf.explode(sf.array(sf.lit(1), sf.lit(2)))
        expected = self.spark.sql("""SELECT * FROM explode(array(1, 2))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.explode(
            sf.create_map(sf.lit("a"), sf.lit(1), sf.lit("b"), sf.lit(2))
        )
        expected = self.spark.sql("""SELECT * FROM explode(map('a', 1, 'b', 2))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # empty
        actual = self.spark.tvf.explode(sf.array())
        expected = self.spark.sql("""SELECT * FROM explode(array())""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.explode(sf.create_map())
        expected = self.spark.sql("""SELECT * FROM explode(map())""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # null
        actual = self.spark.tvf.explode(sf.lit(None).astype("array<int>"))
        expected = self.spark.sql("""SELECT * FROM explode(null :: array<int>)""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.explode(sf.lit(None).astype("map<string, int>"))
        expected = self.spark.sql("""SELECT * FROM explode(null :: map<string, int>)""")
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_explode_with_lateral_join(self):
        # The registered views are t1 and t3; "t2" below is only an alias, not a view.
        with self.tempView("t1", "t3"):
            t1, t3 = self._create_t1_and_t3_views()

            assertDataFrameEqual(
                t1.lateralJoin(
                    self.spark.tvf.explode(sf.array(sf.col("c1").outer(), sf.col("c2").outer()))
                    .toDF("c3")
                    .alias("t2")
                ),
                self.spark.sql("""SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)"""),
            )
            assertDataFrameEqual(
                t3.lateralJoin(self.spark.tvf.explode(sf.col("c2").outer()).toDF("v").alias("t2")),
                self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v)"""),
            )
            assertDataFrameEqual(
                self.spark.tvf.explode(sf.array(sf.lit(1), sf.lit(2)))
                .toDF("v")
                .lateralJoin(
                    self.spark.range(1).select((sf.col("v").outer() + sf.lit(1)).alias("v2"))
                ),
                self.spark.sql(
                    """SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1 AS v2)"""
                ),
            )

    def test_explode_outer(self):
        actual = self.spark.tvf.explode_outer(sf.array(sf.lit(1), sf.lit(2)))
        expected = self.spark.sql("""SELECT * FROM explode_outer(array(1, 2))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.explode_outer(
            sf.create_map(sf.lit("a"), sf.lit(1), sf.lit("b"), sf.lit(2))
        )
        expected = self.spark.sql("""SELECT * FROM explode_outer(map('a', 1, 'b', 2))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # empty
        actual = self.spark.tvf.explode_outer(sf.array())
        expected = self.spark.sql("""SELECT * FROM explode_outer(array())""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.explode_outer(sf.create_map())
        expected = self.spark.sql("""SELECT * FROM explode_outer(map())""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # null
        actual = self.spark.tvf.explode_outer(sf.lit(None).astype("array<int>"))
        expected = self.spark.sql("""SELECT * FROM explode_outer(null :: array<int>)""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.explode_outer(sf.lit(None).astype("map<string, int>"))
        expected = self.spark.sql("""SELECT * FROM explode_outer(null :: map<string, int>)""")
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_explode_outer_with_lateral_join(self):
        # The registered views are t1 and t3; "t2" below is only an alias, not a view.
        with self.tempView("t1", "t3"):
            t1, t3 = self._create_t1_and_t3_views()

            assertDataFrameEqual(
                t1.lateralJoin(
                    self.spark.tvf.explode_outer(
                        sf.array(sf.col("c1").outer(), sf.col("c2").outer())
                    )
                    .toDF("c3")
                    .alias("t2")
                ),
                self.spark.sql("""SELECT * FROM t1, LATERAL EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)"""),
            )
            assertDataFrameEqual(
                t3.lateralJoin(
                    self.spark.tvf.explode_outer(sf.col("c2").outer()).toDF("v").alias("t2")
                ),
                self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)"""),
            )
            assertDataFrameEqual(
                self.spark.tvf.explode_outer(sf.array(sf.lit(1), sf.lit(2)))
                .toDF("v")
                .lateralJoin(
                    self.spark.range(1).select((sf.col("v").outer() + sf.lit(1)).alias("v2"))
                ),
                self.spark.sql(
                    """
                    SELECT * FROM EXPLODE_OUTER(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1 AS v2)
                    """
                ),
            )

    def test_inline(self):
        actual = self.spark.tvf.inline(
            sf.array(sf.struct(sf.lit(1), sf.lit("a")), sf.struct(sf.lit(2), sf.lit("b")))
        )
        expected = self.spark.sql("""SELECT * FROM inline(array(struct(1, 'a'), struct(2, 'b')))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.inline(sf.array().astype("array<struct<a:int,b:int>>"))
        expected = self.spark.sql("""SELECT * FROM inline(array() :: array<struct<a:int,b:int>>)""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.inline(
            sf.array(
                sf.named_struct(sf.lit("a"), sf.lit(1), sf.lit("b"), sf.lit(2)),
                sf.lit(None),
                sf.named_struct(sf.lit("a"), sf.lit(3), sf.lit("b"), sf.lit(4)),
            )
        )
        expected = self.spark.sql(
            """
            SELECT * FROM
            inline(array(named_struct('a', 1, 'b', 2), null, named_struct('a', 3, 'b', 4)))
            """
        )
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_inline_with_lateral_join(self):
        with self.tempView("array_struct"):
            array_struct = self.spark.sql(
                """
                VALUES
                (1, ARRAY(STRUCT(1, 'a'), STRUCT(2, 'b'))),
                (2, ARRAY()),
                (3, ARRAY(STRUCT(3, 'c'))) AS array_struct(id, arr)
                """
            )
            array_struct.createOrReplaceTempView("array_struct")

            assertDataFrameEqual(
                array_struct.lateralJoin(self.spark.tvf.inline(sf.col("arr").outer())),
                self.spark.sql("""SELECT * FROM array_struct JOIN LATERAL INLINE(arr)"""),
            )
            assertDataFrameEqual(
                array_struct.lateralJoin(
                    self.spark.tvf.inline(sf.col("arr").outer()).toDF("k", "v").alias("t"),
                    sf.col("id") == sf.col("k"),
                    "left",
                ),
                self.spark.sql(
                    """
                    SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) t(k, v) ON id = k
                    """
                ),
            )

    def test_inline_outer(self):
        actual = self.spark.tvf.inline_outer(
            sf.array(sf.struct(sf.lit(1), sf.lit("a")), sf.struct(sf.lit(2), sf.lit("b")))
        )
        expected = self.spark.sql(
            """SELECT * FROM inline_outer(array(struct(1, 'a'), struct(2, 'b')))"""
        )
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.inline_outer(sf.array().astype("array<struct<a:int,b:int>>"))
        expected = self.spark.sql(
            """SELECT * FROM inline_outer(array() :: array<struct<a:int,b:int>>)"""
        )
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.inline_outer(
            sf.array(
                sf.named_struct(sf.lit("a"), sf.lit(1), sf.lit("b"), sf.lit(2)),
                sf.lit(None),
                sf.named_struct(sf.lit("a"), sf.lit(3), sf.lit("b"), sf.lit(4)),
            )
        )
        expected = self.spark.sql(
            """
            SELECT * FROM
            inline_outer(array(named_struct('a', 1, 'b', 2), null, named_struct('a', 3, 'b', 4)))
            """
        )
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_inline_outer_with_lateral_join(self):
        with self.tempView("array_struct"):
            array_struct = self.spark.sql(
                """
                VALUES
                (1, ARRAY(STRUCT(1, 'a'), STRUCT(2, 'b'))),
                (2, ARRAY()),
                (3, ARRAY(STRUCT(3, 'c'))) AS array_struct(id, arr)
                """
            )
            array_struct.createOrReplaceTempView("array_struct")

            assertDataFrameEqual(
                array_struct.lateralJoin(self.spark.tvf.inline_outer(sf.col("arr").outer())),
                self.spark.sql("""SELECT * FROM array_struct JOIN LATERAL INLINE_OUTER(arr)"""),
            )
            assertDataFrameEqual(
                array_struct.lateralJoin(
                    self.spark.tvf.inline_outer(sf.col("arr").outer()).toDF("k", "v").alias("t"),
                    sf.col("id") == sf.col("k"),
                    "left",
                ),
                self.spark.sql(
                    """
                    SELECT * FROM array_struct LEFT JOIN LATERAL INLINE_OUTER(arr) t(k, v) ON id = k
                    """
                ),
            )

    def test_json_tuple(self):
        actual = self.spark.tvf.json_tuple(sf.lit('{"a":1, "b":2}'), sf.lit("a"), sf.lit("b"))
        expected = self.spark.sql("""SELECT json_tuple('{"a":1, "b":2}', 'a', 'b')""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # Calling json_tuple with no field names is rejected on the Python side.
        with self.assertRaises(PySparkValueError) as pe:
            self.spark.tvf.json_tuple(sf.lit('{"a":1, "b":2}'))

        self.check_error(
            exception=pe.exception,
            errorClass="CANNOT_BE_EMPTY",
            messageParameters={"item": "field"},
        )

    def test_json_tuple_with_lateral_join(self):
        with self.tempView("json_table"):
            json_table = self.spark.sql(
                """
                VALUES
                ('1', '{"f1": "1", "f2": "2", "f3": 3, "f5": 5.23}'),
                ('2', '{"f1": "1", "f3": "3", "f2": 2, "f4": 4.01}'),
                ('3', '{"f1": 3, "f4": "4", "f3": "3", "f2": 2, "f5": 5.01}'),
                ('4', cast(null as string)),
                ('5', '{"f1": null, "f5": ""}'),
                ('6', '[invalid JSON string]') AS json_table(key, jstring)
                """
            )
            json_table.createOrReplaceTempView("json_table")

            assertDataFrameEqual(
                json_table.alias("t1")
                .lateralJoin(
                    self.spark.tvf.json_tuple(
                        sf.col("jstring").outer(),
                        sf.lit("f1"),
                        sf.lit("f2"),
                        sf.lit("f3"),
                        sf.lit("f4"),
                        sf.lit("f5"),
                    ).alias("t2")
                )
                .select("t1.key", "t2.*"),
                self.spark.sql(
                    """
                    SELECT t1.key, t2.* FROM json_table t1,
                    LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2
                    """
                ),
            )
            assertDataFrameEqual(
                json_table.alias("t1")
                .lateralJoin(
                    self.spark.tvf.json_tuple(
                        sf.col("jstring").outer(),
                        sf.lit("f1"),
                        sf.lit("f2"),
                        sf.lit("f3"),
                        sf.lit("f4"),
                        sf.lit("f5"),
                    ).alias("t2")
                )
                .where(sf.col("t2.c0").isNotNull())
                .select("t1.key", "t2.*"),
                self.spark.sql(
                    """
                    SELECT t1.key, t2.* FROM json_table t1,
                    LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2
                    WHERE t2.c0 IS NOT NULL
                    """
                ),
            )

    def test_posexplode(self):
        actual = self.spark.tvf.posexplode(sf.array(sf.lit(1), sf.lit(2)))
        expected = self.spark.sql("""SELECT * FROM posexplode(array(1, 2))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.posexplode(
            sf.create_map(sf.lit("a"), sf.lit(1), sf.lit("b"), sf.lit(2))
        )
        expected = self.spark.sql("""SELECT * FROM posexplode(map('a', 1, 'b', 2))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # empty
        actual = self.spark.tvf.posexplode(sf.array())
        expected = self.spark.sql("""SELECT * FROM posexplode(array())""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.posexplode(sf.create_map())
        expected = self.spark.sql("""SELECT * FROM posexplode(map())""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # null
        actual = self.spark.tvf.posexplode(sf.lit(None).astype("array<int>"))
        expected = self.spark.sql("""SELECT * FROM posexplode(null :: array<int>)""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.posexplode(sf.lit(None).astype("map<string, int>"))
        expected = self.spark.sql("""SELECT * FROM posexplode(null :: map<string, int>)""")
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_posexplode_with_lateral_join(self):
        # The registered views are t1 and t3, so those are the names tempView must drop.
        with self.tempView("t1", "t3"):
            t1, t3 = self._create_t1_and_t3_views()

            assertDataFrameEqual(
                t1.lateralJoin(
                    self.spark.tvf.posexplode(sf.array(sf.col("c1").outer(), sf.col("c2").outer()))
                ),
                self.spark.sql("""SELECT * FROM t1, LATERAL POSEXPLODE(ARRAY(c1, c2))"""),
            )
            assertDataFrameEqual(
                t3.lateralJoin(self.spark.tvf.posexplode(sf.col("c2").outer())),
                self.spark.sql("""SELECT * FROM t3, LATERAL POSEXPLODE(c2)"""),
            )
            assertDataFrameEqual(
                self.spark.tvf.posexplode(sf.array(sf.lit(1), sf.lit(2)))
                .toDF("p", "v")
                .lateralJoin(
                    self.spark.range(1).select((sf.col("v").outer() + sf.lit(1)).alias("v2"))
                ),
                self.spark.sql(
                    """
                    SELECT * FROM POSEXPLODE(ARRAY(1, 2)) t(p, v), LATERAL (SELECT v + 1 AS v2)
                    """
                ),
            )

    def test_posexplode_outer(self):
        actual = self.spark.tvf.posexplode_outer(sf.array(sf.lit(1), sf.lit(2)))
        expected = self.spark.sql("""SELECT * FROM posexplode_outer(array(1, 2))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.posexplode_outer(
            sf.create_map(sf.lit("a"), sf.lit(1), sf.lit("b"), sf.lit(2))
        )
        expected = self.spark.sql("""SELECT * FROM posexplode_outer(map('a', 1, 'b', 2))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # empty
        actual = self.spark.tvf.posexplode_outer(sf.array())
        expected = self.spark.sql("""SELECT * FROM posexplode_outer(array())""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.posexplode_outer(sf.create_map())
        expected = self.spark.sql("""SELECT * FROM posexplode_outer(map())""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # null
        actual = self.spark.tvf.posexplode_outer(sf.lit(None).astype("array<int>"))
        expected = self.spark.sql("""SELECT * FROM posexplode_outer(null :: array<int>)""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.posexplode_outer(sf.lit(None).astype("map<string, int>"))
        expected = self.spark.sql("""SELECT * FROM posexplode_outer(null :: map<string, int>)""")
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_posexplode_outer_with_lateral_join(self):
        # The registered views are t1 and t3, so those are the names tempView must drop.
        with self.tempView("t1", "t3"):
            t1, t3 = self._create_t1_and_t3_views()

            assertDataFrameEqual(
                t1.lateralJoin(
                    self.spark.tvf.posexplode_outer(
                        sf.array(sf.col("c1").outer(), sf.col("c2").outer())
                    )
                ),
                self.spark.sql("""SELECT * FROM t1, LATERAL POSEXPLODE_OUTER(ARRAY(c1, c2))"""),
            )
            assertDataFrameEqual(
                t3.lateralJoin(self.spark.tvf.posexplode_outer(sf.col("c2").outer())),
                self.spark.sql("""SELECT * FROM t3, LATERAL POSEXPLODE_OUTER(c2)"""),
            )
            assertDataFrameEqual(
                self.spark.tvf.posexplode_outer(sf.array(sf.lit(1), sf.lit(2)))
                .toDF("p", "v")
                .lateralJoin(
                    self.spark.range(1).select((sf.col("v").outer() + sf.lit(1)).alias("v2"))
                ),
                self.spark.sql(
                    """
                    SELECT * FROM POSEXPLODE_OUTER(ARRAY(1, 2)) t(p, v),
                    LATERAL (SELECT v + 1 AS v2)
                    """
                ),
            )

    def test_stack(self):
        actual = self.spark.tvf.stack(sf.lit(2), sf.lit(1), sf.lit(2), sf.lit(3))
        expected = self.spark.sql("""SELECT * FROM stack(2, 1, 2, 3)""")
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_stack_with_lateral_join(self):
        with self.tempView("t1", "t3"):
            t1, t3 = self._create_t1_and_t3_views()

            assertDataFrameEqual(
                t1.lateralJoin(
                    self.spark.tvf.stack(
                        sf.lit(2),
                        sf.lit("Key"),
                        sf.col("c1").outer(),
                        sf.lit("Value"),
                        sf.col("c2").outer(),
                    ).alias("t")
                ).select("t.*"),
                self.spark.sql(
                    """SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1, 'Value', c2) t"""
                ),
            )
            assertDataFrameEqual(
                t1.lateralJoin(
                    self.spark.tvf.stack(sf.lit(1), sf.col("c1").outer(), sf.col("c2").outer())
                    .toDF("x", "y")
                    .alias("t")
                ).select("t.*"),
                self.spark.sql("""SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, c2) t(x, y)"""),
            )
            assertDataFrameEqual(
                t1.join(t3, sf.col("t1.c1") == sf.col("t3.c1"))
                .lateralJoin(
                    self.spark.tvf.stack(
                        sf.lit(1), sf.col("t1.c2").outer(), sf.col("t3.c2").outer()
                    ).alias("t")
                )
                .select("t.*"),
                self.spark.sql(
                    """
                    SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1
                    JOIN LATERAL stack(1, t1.c2, t3.c2) t
                    """
                ),
            )

    def test_collations(self):
        actual = self.spark.tvf.collations()
        expected = self.spark.sql("""SELECT * FROM collations()""")
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_sql_keywords(self):
        actual = self.spark.tvf.sql_keywords()
        expected = self.spark.sql("""SELECT * FROM sql_keywords()""")
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_variant_explode(self):
        actual = self.spark.tvf.variant_explode(sf.parse_json(sf.lit('["hello", "world"]')))
        expected = self.spark.sql(
            """SELECT * FROM variant_explode(parse_json('["hello", "world"]'))"""
        )
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.variant_explode(sf.parse_json(sf.lit('{"a": true, "b": 3.14}')))
        expected = self.spark.sql(
            """SELECT * FROM variant_explode(parse_json('{"a": true, "b": 3.14}'))"""
        )
        assertDataFrameEqual(actual=actual, expected=expected)

        # empty
        actual = self.spark.tvf.variant_explode(sf.parse_json(sf.lit("[]")))
        expected = self.spark.sql("""SELECT * FROM variant_explode(parse_json('[]'))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.variant_explode(sf.parse_json(sf.lit("{}")))
        expected = self.spark.sql("""SELECT * FROM variant_explode(parse_json('{}'))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # null
        actual = self.spark.tvf.variant_explode(sf.lit(None).astype("variant"))
        expected = self.spark.sql("""SELECT * FROM variant_explode(null :: variant)""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # not a variant object/array
        actual = self.spark.tvf.variant_explode(sf.parse_json(sf.lit("1")))
        expected = self.spark.sql("""SELECT * FROM variant_explode(parse_json('1'))""")
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_variant_explode_with_lateral_join(self):
        with self.tempView("variant_table"):
            variant_table = self.spark.sql(
                """
                SELECT id, parse_json(v) AS v FROM VALUES
                (0, '["hello", "world"]'), (1, '{"a": true, "b": 3.14}'),
                (2, '[]'), (3, '{}'),
                (4, NULL), (5, '1')
                AS t(id, v)
                """
            )
            variant_table.createOrReplaceTempView("variant_table")

            assertDataFrameEqual(
                variant_table.alias("t1")
                .lateralJoin(self.spark.tvf.variant_explode(sf.col("v").outer()).alias("t"))
                .select("t1.id", "t.*"),
                self.spark.sql(
                    """
                    SELECT t1.id, t.* FROM variant_table AS t1,
                    LATERAL variant_explode(v) AS t
                    """
                ),
            )

    def test_variant_explode_outer(self):
        actual = self.spark.tvf.variant_explode_outer(sf.parse_json(sf.lit('["hello", "world"]')))
        expected = self.spark.sql(
            """SELECT * FROM variant_explode_outer(parse_json('["hello", "world"]'))"""
        )
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.variant_explode_outer(
            sf.parse_json(sf.lit('{"a": true, "b": 3.14}'))
        )
        expected = self.spark.sql(
            """SELECT * FROM variant_explode_outer(parse_json('{"a": true, "b": 3.14}'))"""
        )
        assertDataFrameEqual(actual=actual, expected=expected)

        # empty
        actual = self.spark.tvf.variant_explode_outer(sf.parse_json(sf.lit("[]")))
        expected = self.spark.sql("""SELECT * FROM variant_explode_outer(parse_json('[]'))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        actual = self.spark.tvf.variant_explode_outer(sf.parse_json(sf.lit("{}")))
        expected = self.spark.sql("""SELECT * FROM variant_explode_outer(parse_json('{}'))""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # null
        actual = self.spark.tvf.variant_explode_outer(sf.lit(None).astype("variant"))
        expected = self.spark.sql("""SELECT * FROM variant_explode_outer(null :: variant)""")
        assertDataFrameEqual(actual=actual, expected=expected)

        # not a variant object/array
        actual = self.spark.tvf.variant_explode_outer(sf.parse_json(sf.lit("1")))
        expected = self.spark.sql("""SELECT * FROM variant_explode_outer(parse_json('1'))""")
        assertDataFrameEqual(actual=actual, expected=expected)

    def test_variant_explode_outer_with_lateral_join(self):
        with self.tempView("variant_table"):
            variant_table = self.spark.sql(
                """
                SELECT id, parse_json(v) AS v FROM VALUES
                (0, '["hello", "world"]'), (1, '{"a": true, "b": 3.14}'),
                (2, '[]'), (3, '{}'),
                (4, NULL), (5, '1')
                AS t(id, v)
                """
            )
            variant_table.createOrReplaceTempView("variant_table")

            assertDataFrameEqual(
                variant_table.alias("t1")
                .lateralJoin(self.spark.tvf.variant_explode_outer(sf.col("v").outer()).alias("t"))
                .select("t1.id", "t.*"),
                self.spark.sql(
                    """
                    SELECT t1.id, t.* FROM variant_table AS t1,
                    LATERAL variant_explode_outer(v) AS t
                    """
                ),
            )
class TVFTests(TVFTestsMixin, ReusedSQLTestCase):
    """Run the shared ``TVFTestsMixin`` cases against the session from ``ReusedSQLTestCase``."""
if __name__ == "__main__":
    from pyspark.sql.tests.test_tvf import *  # noqa: F401

    # Default to unittest's built-in text runner; upgrade to an XML-report runner
    # only when the optional ``xmlrunner`` package can be imported.
    testRunner = None
    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        pass  # xmlrunner unavailable: keep the default runner chosen above.

    unittest.main(testRunner=testRunner, verbosity=2)