| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import glob |
| import os |
| import pydoc |
| import shutil |
| import tempfile |
| import unittest |
| from typing import cast |
| import io |
| from contextlib import redirect_stdout |
| |
| from pyspark.sql import Row, functions, DataFrame |
| from pyspark.sql.functions import ( |
| col, |
| lit, |
| count, |
| struct, |
| date_format, |
| to_date, |
| array, |
| explode, |
| when, |
| concat, |
| ) |
| from pyspark.sql.types import ( |
| StringType, |
| IntegerType, |
| LongType, |
| StructType, |
| StructField, |
| ) |
| from pyspark.storagelevel import StorageLevel |
| from pyspark.errors import ( |
| AnalysisException, |
| IllegalArgumentException, |
| PySparkTypeError, |
| PySparkValueError, |
| ) |
| from pyspark.testing import assertDataFrameEqual |
| from pyspark.testing.sqlutils import ( |
| ReusedSQLTestCase, |
| SPARK_HOME, |
| have_pyarrow, |
| have_pandas, |
| pandas_requirement_message, |
| pyarrow_requirement_message, |
| ) |
| |
| |
| class DataFrameTestsMixin: |
| def test_range(self): |
| self.assertEqual(self.spark.range(1, 1).count(), 0) |
| self.assertEqual(self.spark.range(1, 0, -1).count(), 1) |
| self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2) |
| self.assertEqual(self.spark.range(-2).count(), 0) |
| self.assertEqual(self.spark.range(3).count(), 3) |
| |
| def test_table(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| self.spark.table(None) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_STR", |
| messageParameters={"arg_name": "tableName", "arg_type": "NoneType"}, |
| ) |
| |
| def test_dataframe_star(self): |
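| # df1["*"] / df2["*"] expand to only the columns that originate from that |
| # DataFrame, even after joins and withColumn calls add more columns. |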
| df1 = self.spark.createDataFrame([{"a": 1}]) |
| df2 = self.spark.createDataFrame([{"a": 1, "b": "v"}]) |
| df3 = df2.withColumnsRenamed({"a": "x", "b": "y"}) |
| |
| df = df1.join(df2) |
| self.assertEqual(df.columns, ["a", "a", "b"]) |
| self.assertEqual(df.select(df1["*"]).columns, ["a"]) |
| self.assertEqual(df.select(df2["*"]).columns, ["a", "b"]) |
| |
| df = df1.join(df2).withColumn("c", lit(0)) |
| self.assertEqual(df.columns, ["a", "a", "b", "c"]) |
| self.assertEqual(df.select(df1["*"]).columns, ["a"]) |
| self.assertEqual(df.select(df2["*"]).columns, ["a", "b"]) |
| |
| df = df1.join(df2, "a") |
| self.assertEqual(df.columns, ["a", "b"]) |
| self.assertEqual(df.select(df1["*"]).columns, ["a"]) |
| self.assertEqual(df.select(df2["*"]).columns, ["a", "b"]) |
| |
| df = df1.join(df2, "a").withColumn("c", lit(0)) |
| self.assertEqual(df.columns, ["a", "b", "c"]) |
| self.assertEqual(df.select(df1["*"]).columns, ["a"]) |
| self.assertEqual(df.select(df2["*"]).columns, ["a", "b"]) |
| |
| df = df2.join(df3) |
| self.assertEqual(df.columns, ["a", "b", "x", "y"]) |
| self.assertEqual(df.select(df2["*"]).columns, ["a", "b"]) |
| self.assertEqual(df.select(df3["*"]).columns, ["x", "y"]) |
| |
| df = df2.join(df3).withColumn("c", lit(0)) |
| self.assertEqual(df.columns, ["a", "b", "x", "y", "c"]) |
| self.assertEqual(df.select(df2["*"]).columns, ["a", "b"]) |
| self.assertEqual(df.select(df3["*"]).columns, ["x", "y"]) |
| |
| def test_count_star(self): |
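| # count(df["*"]) and count(col("*")) are both resolved to count(1), |
| # regardless of the schema they are applied to. |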
| df1 = self.spark.createDataFrame([{"a": 1}]) |
| df2 = self.spark.createDataFrame([{"a": 1, "b": "v"}]) |
| df3 = df2.select(struct("a", "b").alias("s")) |
| |
| self.assertEqual(df1.select(count(df1["*"])).columns, ["count(1)"]) |
| self.assertEqual(df1.select(count(col("*"))).columns, ["count(1)"]) |
| |
| self.assertEqual(df2.select(count(df2["*"])).columns, ["count(1)"]) |
| self.assertEqual(df2.select(count(col("*"))).columns, ["count(1)"]) |
| |
| self.assertEqual(df3.select(count(df3["*"])).columns, ["count(1)"]) |
| self.assertEqual(df3.select(count(col("*"))).columns, ["count(1)"]) |
| |
| def test_self_join(self): |
| df1 = self.spark.range(10).withColumn("a", lit(0)) |
| df2 = df1.withColumnRenamed("a", "b") |
| df = df1.join(df2, df1["a"] == df2["b"]) |
| self.assertEqual(df.count(), 100) |
| df = df2.join(df1, df2["b"] == df1["a"]) |
| self.assertEqual(df.count(), 100) |
| |
| def test_self_join_II(self): |
| df = self.spark.createDataFrame([(1, 2), (3, 4)], schema=["a", "b"]) |
| df2 = df.select(df.a.alias("aa"), df.b) |
| df3 = df2.join(df, df2.b == df.b) |
| self.assertEqual(df3.columns, ["aa", "b", "a", "b"]) |
| self.assertEqual(df3.count(), 2) |
| |
| def test_self_join_III(self): |
| df1 = self.spark.range(10).withColumn("value", lit(1)) |
| df2 = df1.union(df1) |
| df3 = df1.join(df2, df1.id == df2.id, "left") |
| self.assertEqual(df3.columns, ["id", "value", "id", "value"]) |
| self.assertEqual(df3.count(), 20) |
| |
| def test_self_join_IV(self): |
| df1 = self.spark.range(10).withColumn("value", lit(1)) |
| df2 = df1.withColumn("value", lit(2)).union(df1.withColumn("value", lit(3))) |
| df3 = df1.join(df2, df1.id == df2.id, "right") |
| self.assertEqual(df3.columns, ["id", "value", "id", "value"]) |
| self.assertEqual(df3.count(), 20) |
| |
| def test_lateral_column_alias(self): |
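| # Lateral column alias: the alias "x" defined earlier in the select list |
| # can be referenced by the expression that defines "y". |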
| df1 = self.spark.range(10).select( |
| (col("id") + lit(1)).alias("x"), (col("x") + lit(1)).alias("y") |
| ) |
| df2 = self.spark.range(10).select(col("id").alias("x")) |
| df3 = df1.join(df2, df1.x == df2.x).select(df1.y) |
| self.assertEqual(df3.columns, ["y"]) |
| self.assertEqual(df3.count(), 9) |
| |
| def test_duplicated_column_names(self): |
| df = self.spark.createDataFrame([(1, 2)], ["c", "c"]) |
| row = df.select("*").first() |
| self.assertEqual(1, row[0]) |
| self.assertEqual(2, row[1]) |
| self.assertEqual("Row(c=1, c=2)", str(row)) |
| # Cannot access columns |
| self.assertRaises(AnalysisException, lambda: df.select(df[0]).first()) |
| self.assertRaises(AnalysisException, lambda: df.select(df.c).first()) |
| self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first()) |
| |
| def test_help_command(self): |
| # Regression test for SPARK-5464 |
| rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) |
| df = self.spark.read.json(rdd) |
| # render_doc() reproduces the help() exception without printing output |
| self.check_help_command(df) |
| |
| def check_help_command(self, df): |
| pydoc.render_doc(df) |
| pydoc.render_doc(df.foo) |
| pydoc.render_doc(df.take(1)) |
| |
| def test_drop(self): |
| df = self.spark.createDataFrame([("A", 50, "Y"), ("B", 60, "Y")], ["name", "age", "active"]) |
| self.assertEqual(df.drop("active").columns, ["name", "age"]) |
| self.assertEqual(df.drop("active", "nonexistent_column").columns, ["name", "age"]) |
| self.assertEqual(df.drop("name", "age", "active").columns, []) |
| self.assertEqual(df.drop(col("name")).columns, ["age", "active"]) |
| self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"]) |
| self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"]) |
| |
| def test_drop_nonexistent_col(self): |
| df1 = self.spark.createDataFrame( |
| [("a", "b", "c")], |
| schema="colA string, colB string, colC string", |
| ) |
| df2 = self.spark.createDataFrame( |
| [("c", "d", "e")], |
| schema="colC string, colD string, colE string", |
| ) |
| df3 = df1.join(df2, df1["colC"] == df2["colC"]).withColumn( |
| "colB", |
| when(df1["colB"] == "b", concat(df1["colB"].cast("string"), lit("x"))).otherwise( |
| df1["colB"] |
| ), |
| ) |
| df4 = df3.drop(df1["colB"]) |
| |
| self.assertEqual(df4.columns, ["colA", "colB", "colC", "colC", "colD", "colE"]) |
| self.assertEqual(df4.count(), 1) |
| |
| def test_drop_col_from_different_dataframe(self): |
| df1 = self.spark.range(10) |
| df2 = df1.withColumn("v0", lit(0)) |
| |
| # drop df2["id"] from df2 |
| self.assertEqual(df2.drop(df2["id"]).columns, ["v0"]) |
| |
| # drop df1["id"] from df2, which is semantically equal to df2["id"] |
| # note that df1.drop(df2["id"]) works in Classic, but not in Connect |
| self.assertEqual(df2.drop(df1["id"]).columns, ["v0"]) |
| |
| df3 = df2.select("*", lit(1).alias("v1")) |
| |
| # drop df3["id"] from df3 |
| self.assertEqual(df3.drop(df3["id"]).columns, ["v0", "v1"]) |
| |
| # drop df2["id"] from df3, which is semantically equal to df3["id"] |
| self.assertEqual(df3.drop(df2["id"]).columns, ["v0", "v1"]) |
| |
| # drop df1["id"] from df3, which is semantically equal to df3["id"] |
| self.assertEqual(df3.drop(df1["id"]).columns, ["v0", "v1"]) |
| |
| # drop df3["v0"] from df3 |
| self.assertEqual(df3.drop(df3["v0"]).columns, ["id", "v1"]) |
| |
| # drop df2["v0"] from df3, which is semantically equal to df3["v0"] |
| self.assertEqual(df3.drop(df2["v0"]).columns, ["id", "v1"]) |
| |
| def test_drop_join(self): |
| left_df = self.spark.createDataFrame( |
| [(1, "a"), (2, "b"), (3, "c")], |
| ["join_key", "value1"], |
| ) |
| right_df = self.spark.createDataFrame( |
| [(1, "aa"), (2, "bb"), (4, "dd")], |
| ["join_key", "value2"], |
| ) |
| joined_df = left_df.join( |
| right_df, |
| on=left_df["join_key"] == right_df["join_key"], |
| how="left", |
| ) |
| |
| dropped_1 = joined_df.drop(left_df["join_key"]) |
| self.assertEqual(dropped_1.columns, ["value1", "join_key", "value2"]) |
| self.assertEqual( |
| dropped_1.sort("value1").collect(), |
| [ |
| Row(value1="a", join_key=1, value2="aa"), |
| Row(value1="b", join_key=2, value2="bb"), |
| Row(value1="c", join_key=None, value2=None), |
| ], |
| ) |
| |
| dropped_2 = joined_df.drop(right_df["join_key"]) |
| self.assertEqual(dropped_2.columns, ["join_key", "value1", "value2"]) |
| self.assertEqual( |
| dropped_2.sort("value1").collect(), |
| [ |
| Row(join_key=1, value1="a", value2="aa"), |
| Row(join_key=2, value1="b", value2="bb"), |
| Row(join_key=3, value1="c", value2=None), |
| ], |
| ) |
| |
| def test_with_columns_renamed(self): |
| df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"]) |
| |
| # rename both columns |
| renamed_df1 = df.withColumnsRenamed({"name": "naam", "age": "leeftijd"}) |
| self.assertEqual(renamed_df1.columns, ["naam", "leeftijd"]) |
| |
| # rename one column with one missing name |
| renamed_df2 = df.withColumnsRenamed({"name": "naam", "address": "adres"}) |
| self.assertEqual(renamed_df2.columns, ["naam", "age"]) |
| |
| # negative test for incorrect type |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.withColumnsRenamed(("name", "x")) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_DICT", |
| messageParameters={"arg_name": "colsMap", "arg_type": "tuple"}, |
| ) |
| |
| def test_with_columns_renamed_with_duplicated_names(self): |
| df1 = self.spark.createDataFrame([(1, "v1")], ["id", "value"]) |
| df2 = self.spark.createDataFrame([(1, "x", "v2")], ["id", "a", "value"]) |
| join = df2.join(df1, on=["id"], how="left") |
| |
| self.assertEqual( |
| join.withColumnRenamed("id", "value").columns, |
| join.withColumnsRenamed({"id": "value"}).columns, |
| ) |
| self.assertEqual( |
| join.withColumnRenamed("a", "b").columns, |
| join.withColumnsRenamed({"a": "b"}).columns, |
| ) |
| self.assertEqual( |
| join.withColumnRenamed("value", "new_value").columns, |
| join.withColumnsRenamed({"value": "new_value"}).columns, |
| ) |
| self.assertEqual( |
| join.withColumnRenamed("x", "y").columns, |
| join.withColumnsRenamed({"x": "y"}).columns, |
| ) |
| |
| def test_ordering_of_with_columns_renamed(self): |
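| # The rename map is applied in order: {"id": "a", "a": "b"} first renames |
| # id -> a and then the intermediate a -> b, while {"a": "b", "id": "a"} |
| # makes "a" -> "b" a no-op before renaming id -> a. |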
| df = self.spark.range(10) |
| |
| df1 = df.withColumnsRenamed({"id": "a", "a": "b"}) |
| self.assertEqual(df1.columns, ["b"]) |
| |
| df2 = df.withColumnsRenamed({"a": "b", "id": "a"}) |
| self.assertEqual(df2.columns, ["a"]) |
| |
| def test_drop_duplicates(self): |
| # SPARK-36034: test that dropDuplicates throws a type error when an incorrect type is provided |
| df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"]) |
| |
| # shouldn't drop a non-null row |
| self.assertEqual(df.dropDuplicates().count(), 2) |
| |
| self.assertEqual(df.dropDuplicates(["name"]).count(), 1) |
| |
| self.assertEqual(df.dropDuplicates(["name", "age"]).count(), 2) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.dropDuplicates("name") |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_LIST_OR_TUPLE", |
| messageParameters={"arg_name": "subset", "arg_type": "str"}, |
| ) |
| |
| # Should raise proper error when taking non-string values |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.dropDuplicates([None]).show() |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_STR", |
| messageParameters={"arg_name": "subset", "arg_type": "NoneType"}, |
| ) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.dropDuplicates([1]).show() |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_STR", |
| messageParameters={"arg_name": "subset", "arg_type": "int"}, |
| ) |
| |
| def test_drop_duplicates_with_ambiguous_reference(self): |
| df1 = self.spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) |
| df2 = self.spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) |
| df3 = df1.join(df2, df1.name == df2.name, "inner") |
| |
| self.assertEqual(df3.drop("name", "age").columns, ["height"]) |
| self.assertEqual(df3.drop("name", df3.age, "unknown").columns, ["height"]) |
| self.assertEqual(df3.drop("name", "age", df3.height).columns, []) |
| |
| def test_drop_empty_column(self): |
| df = self.spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) |
| |
| self.assertEqual(df.drop().columns, ["age", "name"]) |
| self.assertEqual(df.drop(*[]).columns, ["age", "name"]) |
| |
| def test_drop_column_name_with_dot(self): |
| df = ( |
| self.spark.range(1, 3) |
| .withColumn("first.name", lit("Peter")) |
| .withColumn("city.name", lit("raleigh")) |
| .withColumn("state", lit("nc")) |
| ) |
| |
| self.assertEqual(df.drop("first.name").columns, ["id", "city.name", "state"]) |
| self.assertEqual(df.drop("city.name").columns, ["id", "first.name", "state"]) |
| self.assertEqual(df.drop("first.name", "city.name").columns, ["id", "state"]) |
| self.assertEqual( |
| df.drop("first.name", "city.name", "unknown.unknown").columns, ["id", "state"] |
| ) |
| self.assertEqual( |
| df.drop("unknown.unknown").columns, ["id", "first.name", "city.name", "state"] |
| ) |
| |
| def test_with_column_with_existing_name(self): |
| keys = self.df.withColumn("key", self.df.key).select("key").collect() |
| self.assertEqual([r.key for r in keys], list(range(100))) |
| |
| # regression test for SPARK-10417 |
| def test_column_iterator(self): |
| def foo(): |
| for x in self.df.key: |
| break |
| |
| self.assertRaises(TypeError, foo) |
| |
| def test_with_columns(self): |
| # With single column |
| keys = self.df.withColumns({"key": self.df.key}).select("key").collect() |
| self.assertEqual([r.key for r in keys], list(range(100))) |
| |
| # With key and value columns |
| kvs = ( |
| self.df.withColumns({"key": self.df.key, "value": self.df.value}) |
| .select("key", "value") |
| .collect() |
| ) |
| self.assertEqual([(r.key, r.value) for r in kvs], [(i, str(i)) for i in range(100)]) |
| |
| # Columns rename |
| kvs = ( |
| self.df.withColumns({"key_alias": self.df.key, "value_alias": self.df.value}) |
| .select("key_alias", "value_alias") |
| .collect() |
| ) |
| self.assertEqual( |
| [(r.key_alias, r.value_alias) for r in kvs], [(i, str(i)) for i in range(100)] |
| ) |
| |
| # Type check |
| self.assertRaises(TypeError, self.df.withColumns, ["key"]) |
| self.assertRaises(Exception, self.df.withColumns) |
| |
| def test_generic_hints(self): |
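| # hint() returns a new DataFrame; a "broadcast" hint on the join's right |
| # side should make the planner pick a BroadcastHashJoin. |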
| df1 = self.spark.range(10e10).toDF("id") |
| df2 = self.spark.range(10e10).toDF("id") |
| |
| self.assertIsInstance(df1.hint("broadcast"), type(df1)) |
| |
| # Dummy rules |
| self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), type(df1)) |
| |
| with io.StringIO() as buf, redirect_stdout(buf): |
| df1.join(df2.hint("broadcast"), "id").explain(True) |
| self.assertEqual(1, buf.getvalue().count("BroadcastHashJoin")) |
| |
| def test_coalesce_hints_with_string_parameter(self): |
| with self.sql_conf({"spark.sql.adaptive.coalescePartitions.enabled": False}): |
| df = self.spark.createDataFrame( |
| zip(["A", "B"] * 2**9, range(2**10)), |
| StructType([StructField("a", StringType()), StructField("n", IntegerType())]), |
| ) |
| with io.StringIO() as buf, redirect_stdout(buf): |
| # COALESCE |
| coalesce = df.hint("coalesce", 2) |
| coalesce.explain(True) |
| output = buf.getvalue() |
| self.assertGreaterEqual(output.count("Coalesce 2"), 1) |
| buf.truncate(0) |
| buf.seek(0) |
| |
| # REPARTITION_BY_RANGE |
| range_partitioned = df.hint("REPARTITION_BY_RANGE", 2, "a") |
| range_partitioned.explain(True) |
| output = buf.getvalue() |
| self.assertGreaterEqual(output.count("REPARTITION_BY_NUM"), 1) |
| buf.truncate(0) |
| buf.seek(0) |
| |
| # REBALANCE |
| rebalanced1 = df.hint("REBALANCE", "a") # just check this doesn't error |
| rebalanced1.explain(True) |
| rebalanced2 = df.hint("REBALANCE", 2) |
| rebalanced2.explain(True) |
| rebalanced3 = df.hint("REBALANCE", 2, "a") |
| rebalanced3.explain(True) |
| rebalanced4 = df.hint("REBALANCE", functions.col("a")) |
| rebalanced4.explain(True) |
| output = buf.getvalue() |
| self.assertGreaterEqual(output.count("REBALANCE_PARTITIONS_BY_NONE"), 1) |
| self.assertGreaterEqual(output.count("REBALANCE_PARTITIONS_BY_COL"), 3) |
| |
| # Tests for SPARK-23647 (hints should accept more parameter types) |
| def test_extended_hint_types(self): |
| df = self.spark.range(10e10).toDF("id") |
| such_a_nice_list = ["itworks1", "itworks2", "itworks3"] |
| int_list = [1, 2, 3] |
| hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list, int_list) |
| |
| self.assertIsInstance(df.hint("broadcast", []), type(df)) |
| self.assertIsInstance(df.hint("broadcast", ["foo", "bar"]), type(df)) |
| |
| with io.StringIO() as buf, redirect_stdout(buf): |
| hinted_df.explain(True) |
| explain_output = buf.getvalue() |
| self.assertGreaterEqual(explain_output.count("1.2345"), 1) |
| self.assertGreaterEqual(explain_output.count("what"), 1) |
| self.assertGreaterEqual(explain_output.count("itworks"), 1) |
| |
| def test_sample(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| self.spark.range(1).sample() |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_BOOL_OR_FLOAT_OR_INT", |
| messageParameters={ |
| "arg_name": "withReplacement (optional), fraction (required) and seed (optional)", |
| "arg_type": "NoneType, NoneType, NoneType", |
| }, |
| ) |
| |
| self.assertRaises(TypeError, lambda: self.spark.range(1).sample("a")) |
| |
| self.assertRaises(TypeError, lambda: self.spark.range(1).sample(seed="abc")) |
| |
| self.assertRaises( |
| IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0).count() |
| ) |
| |
| def test_sample_with_random_seed(self): |
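| # sample() without a seed still fixes one at call time, so re-evaluating |
| # the same DataFrame must yield the same count every time. |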
| df = self.spark.range(10000).sample(0.1) |
| cnts = [df.count() for _ in range(10)] |
| self.assertEqual(1, len(set(cnts))) |
| |
| def test_toDF_with_string(self): |
| df = self.spark.createDataFrame([("John", 30), ("Alice", 25), ("Bob", 28)]) |
| data = [("John", 30), ("Alice", 25), ("Bob", 28)] |
| |
| result = df.toDF("key", "value") |
| self.assertEqual(result.schema.simpleString(), "struct<key:string,value:bigint>") |
| self.assertEqual(result.collect(), data) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.toDF("key", None) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_LIST_OF_STR", |
| messageParameters={"arg_name": "cols", "arg_type": "NoneType"}, |
| ) |
| |
| def test_toDF_with_schema_string(self): |
| data = [Row(key=i, value=str(i)) for i in range(100)] |
| rdd = self.sc.parallelize(data, 5) |
| |
| df = rdd.toDF("key: int, value: string") |
| self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>") |
| self.assertEqual(df.collect(), data) |
| |
| # different but compatible field types can be used. |
| df = rdd.toDF("key: string, value: string") |
| self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>") |
| self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)]) |
| |
| # field names can differ. |
| df = rdd.toDF(" a: int, b: string ") |
| self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>") |
| self.assertEqual(df.collect(), data) |
| |
| # number of fields must match. |
| self.assertRaisesRegex( |
| Exception, |
| "FIELD_STRUCT_LENGTH_MISMATCH", |
| lambda: rdd.coalesce(1).toDF("key: int").collect(), |
| ) |
| |
| # field types mismatch will cause exception at runtime. |
| self.assertRaisesRegex( |
| Exception, |
| "FIELD_DATA_TYPE_UNACCEPTABLE", |
| lambda: rdd.coalesce(1).toDF("key: float, value: string").collect(), |
| ) |
| |
| # flat schema values will be wrapped into row. |
| df = rdd.map(lambda row: row.key).toDF("int") |
| self.assertEqual(df.schema.simpleString(), "struct<value:int>") |
| self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) |
| |
| # users can use DataType directly instead of data type string. |
| df = rdd.map(lambda row: row.key).toDF(IntegerType()) |
| self.assertEqual(df.schema.simpleString(), "struct<value:int>") |
| self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) |
| |
| def test_create_df_with_collation(self): |
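| # UNICODE_CI is a case-insensitive collation, so "Alice" and "alice" |
| # compare equal and distinct() keeps only one of them. |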
| schema = StructType([StructField("name", StringType("UNICODE_CI"), True)]) |
| df = self.spark.createDataFrame([("Alice",), ("alice",)], schema) |
| |
| self.assertEqual(df.select("name").distinct().count(), 1) |
| |
| def test_print_schema(self): |
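| # printSchema(level) limits how deep nested structs are printed: level 1 |
| # shows only the top-level fields, level 2 also shows the struct's fields. |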
| df = self.spark.createDataFrame([(1, (2, 2))], ["a", "b"]) |
| |
| with io.StringIO() as buf, redirect_stdout(buf): |
| df.printSchema(1) |
| self.assertEqual(1, buf.getvalue().count("long")) |
| self.assertEqual(0, buf.getvalue().count("_1")) |
| self.assertEqual(0, buf.getvalue().count("_2")) |
| |
| buf.truncate(0) |
| buf.seek(0) |
| |
| df.printSchema(2) |
| self.assertEqual(3, buf.getvalue().count("long")) |
| self.assertEqual(1, buf.getvalue().count("_1")) |
| self.assertEqual(1, buf.getvalue().count("_2")) |
| |
| def test_join_without_on(self): |
| df1 = self.spark.range(1).toDF("a") |
| df2 = self.spark.range(1).toDF("b") |
| |
| with self.sql_conf({"spark.sql.crossJoin.enabled": False}): |
| self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) |
| |
| with self.sql_conf({"spark.sql.crossJoin.enabled": True}): |
| actual = df1.join(df2, how="inner").collect() |
| expected = [Row(a=0, b=0)] |
| self.assertEqual(actual, expected) |
| |
| # Regression test for invalid join methods when on is None (SPARK-14761) |
| def test_invalid_join_method(self): |
| df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"]) |
| df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"]) |
| self.assertRaises(AnalysisException, lambda: df1.join(df2, how="invalid-join-type")) |
| |
| # Cartesian products require cross join syntax |
| def test_require_cross(self): |
| df1 = self.spark.createDataFrame([(1, "1")], ("key", "value")) |
| df2 = self.spark.createDataFrame([(1, "1")], ("key", "value")) |
| |
| with self.sql_conf({"spark.sql.crossJoin.enabled": False}): |
| # joins without conditions require cross join syntax |
| self.assertRaises(AnalysisException, lambda: df1.join(df2).collect()) |
| |
| # works with crossJoin |
| self.assertEqual(1, df1.crossJoin(df2).count()) |
| |
| def test_cache_dataframe(self): |
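| # cache() and persist() without arguments both default to |
| # MEMORY_AND_DISK_DESER; unpersist() resets the level to NONE. |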
| df = self.spark.createDataFrame([(2, 2), (3, 3)]) |
| try: |
| self.assertEqual(df.storageLevel, StorageLevel.NONE) |
| |
| df.cache() |
| self.assertEqual(df.storageLevel, StorageLevel.MEMORY_AND_DISK_DESER) |
| |
| df.unpersist() |
| self.assertEqual(df.storageLevel, StorageLevel.NONE) |
| |
| df.persist() |
| self.assertEqual(df.storageLevel, StorageLevel.MEMORY_AND_DISK_DESER) |
| |
| df.unpersist(blocking=True) |
| self.assertEqual(df.storageLevel, StorageLevel.NONE) |
| |
| df.persist(StorageLevel.DISK_ONLY) |
| self.assertEqual(df.storageLevel, StorageLevel.DISK_ONLY) |
| finally: |
| df.unpersist() |
| self.assertEqual(df.storageLevel, StorageLevel.NONE) |
| |
| def test_cache_table(self): |
| spark = self.spark |
| tables = ["tab1", "tab2", "tab3"] |
| with self.tempView(*tables): |
| for i, tab in enumerate(tables): |
| spark.createDataFrame([(2, i), (3, i)]).createOrReplaceTempView(tab) |
| self.assertFalse(spark.catalog.isCached(tab)) |
| spark.catalog.cacheTable("tab1") |
| spark.catalog.cacheTable("tab3", StorageLevel.OFF_HEAP) |
| self.assertTrue(spark.catalog.isCached("tab1")) |
| self.assertFalse(spark.catalog.isCached("tab2")) |
| self.assertTrue(spark.catalog.isCached("tab3")) |
| spark.catalog.cacheTable("tab2") |
| spark.catalog.uncacheTable("tab1") |
| spark.catalog.uncacheTable("tab3") |
| self.assertFalse(spark.catalog.isCached("tab1")) |
| self.assertTrue(spark.catalog.isCached("tab2")) |
| self.assertFalse(spark.catalog.isCached("tab3")) |
| spark.catalog.clearCache() |
| self.assertFalse(spark.catalog.isCached("tab1")) |
| self.assertFalse(spark.catalog.isCached("tab2")) |
| self.assertFalse(spark.catalog.isCached("tab3")) |
| self.assertRaisesRegex( |
| AnalysisException, |
| "does_not_exist", |
| lambda: spark.catalog.isCached("does_not_exist"), |
| ) |
| self.assertRaisesRegex( |
| AnalysisException, |
| "does_not_exist", |
| lambda: spark.catalog.cacheTable("does_not_exist"), |
| ) |
| self.assertRaisesRegex( |
| AnalysisException, |
| "does_not_exist", |
| lambda: spark.catalog.uncacheTable("does_not_exist"), |
| ) |
| |
| def test_repr_behaviors(self): |
| import re |
| |
| pattern = re.compile(r"^ *\|", re.MULTILINE) |
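| # The "|" at the start of the expected strings below only marks the left |
| # margin; the regex above strips it (and any indentation) before comparing. |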
| df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) |
| |
| # test when eager evaluation is enabled and _repr_html_ will not be called |
| with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): |
| expected1 = """+-----+-----+ |
| || key|value| |
| |+-----+-----+ |
| || 1| 1| |
| ||22222|22222| |
| |+-----+-----+ |
| |""" |
| self.assertEqual(re.sub(pattern, "", expected1), df.__repr__()) |
| with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): |
| expected2 = """+---+-----+ |
| ||key|value| |
| |+---+-----+ |
| || 1| 1| |
| ||222| 222| |
| |+---+-----+ |
| |""" |
| self.assertEqual(re.sub(pattern, "", expected2), df.__repr__()) |
| with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): |
| expected3 = """+---+-----+ |
| ||key|value| |
| |+---+-----+ |
| || 1| 1| |
| |+---+-----+ |
| |only showing top 1 row""" |
| self.assertEqual(re.sub(pattern, "", expected3), df.__repr__()) |
| |
| # test when eager evaluation is enabled and _repr_html_ will be called |
| with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): |
| expected1 = """<table border='1'> |
| |<tr><th>key</th><th>value</th></tr> |
| |<tr><td>1</td><td>1</td></tr> |
| |<tr><td>22222</td><td>22222</td></tr> |
| |</table> |
| |""" |
| self.assertEqual(re.sub(pattern, "", expected1), df._repr_html_()) |
| with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): |
| expected2 = """<table border='1'> |
| |<tr><th>key</th><th>value</th></tr> |
| |<tr><td>1</td><td>1</td></tr> |
| |<tr><td>222</td><td>222</td></tr> |
| |</table> |
| |""" |
| self.assertEqual(re.sub(pattern, "", expected2), df._repr_html_()) |
| with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): |
| expected3 = """<table border='1'> |
| |<tr><th>key</th><th>value</th></tr> |
| |<tr><td>1</td><td>1</td></tr> |
| |</table> |
| |only showing top 1 row |
| |""" |
| self.assertEqual(re.sub(pattern, "", expected3), df._repr_html_()) |
| |
| # test when eager evaluation is disabled and _repr_html_ will be called |
| with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}): |
| expected = "DataFrame[key: bigint, value: string]" |
| self.assertEqual(None, df._repr_html_()) |
| self.assertEqual(expected, df.__repr__()) |
| with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): |
| self.assertEqual(None, df._repr_html_()) |
| self.assertEqual(expected, df.__repr__()) |
| with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): |
| self.assertEqual(None, df._repr_html_()) |
| self.assertEqual(expected, df.__repr__()) |
| |
| def test_same_semantics_error(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| self.spark.range(10).sameSemantics(1) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_DATAFRAME", |
| messageParameters={"arg_name": "other", "arg_type": "int"}, |
| ) |
| |
| def test_input_files(self): |
| tpath = tempfile.mkdtemp() |
| shutil.rmtree(tpath) |
| try: |
| self.spark.range(1, 100, 1, 10).write.parquet(tpath) |
| # read parquet file and get the input files list |
| input_files_list = self.spark.read.parquet(tpath).inputFiles() |
| |
| # input files list should contain 10 entries |
| self.assertEqual(len(input_files_list), 10) |
| # all file paths in list must contain tpath |
| for file_path in input_files_list: |
| self.assertTrue(tpath in file_path) |
| finally: |
| shutil.rmtree(tpath) |
| |
| def test_df_show(self): |
| # SPARK-35408: ensure better diagnostics if incorrect parameters are passed |
| # to DataFrame.show |
| |
| df = self.spark.createDataFrame([("foo",)]) |
| df.show(5) |
| df.show(5, True) |
| df.show(5, 1, True) |
| df.show(n=5, truncate="1", vertical=False) |
| df.show(n=5, truncate=1.5, vertical=False) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.show(True) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_INT", |
| messageParameters={"arg_name": "n", "arg_type": "bool"}, |
| ) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.show(vertical="foo") |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_BOOL", |
| messageParameters={"arg_name": "vertical", "arg_type": "str"}, |
| ) |
| |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.show(truncate="foo") |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_BOOL", |
| messageParameters={"arg_name": "truncate", "arg_type": "str"}, |
| ) |
| |
| def test_df_merge_into(self): |
| filename_pattern = ( |
| "sql/catalyst/target/scala-*/test-classes/org/apache/spark/sql/connector/catalog/" |
| "InMemoryRowLevelOperationTableCatalog.class" |
| ) |
| if not bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))): |
| raise unittest.SkipTest( |
| "org.apache.spark.sql.connector.catalog.InMemoryRowLevelOperationTableCatalog' " |
| "is not available. Will skip the related tests" |
| ) |
| |
| try: |
| # InMemoryRowLevelOperationTableCatalog is a test catalog that is included in the |
| # catalyst-test package. If Spark complains that it can't find this class, make sure |
| # the catalyst-test JAR exists in the "SPARK_HOME/assembly/target/scala-x.xx/jars" |
| # directory. If not, build it with `build/sbt test:assembly` and copy it over. |
| self.spark.conf.set( |
| "spark.sql.catalog.testcat", |
| "org.apache.spark.sql.connector.catalog.InMemoryRowLevelOperationTableCatalog", |
| ) |
| with self.table("testcat.ns1.target"): |
| |
| def reset_target_table(): |
| self.spark.createDataFrame( |
| [(1, "Alice"), (2, "Bob")], ["id", "name"] |
| ).write.mode("overwrite").saveAsTable("testcat.ns1.target") |
| |
| source = self.spark.createDataFrame( |
| [(1, "Charlie"), (3, "David")], ["id", "name"] |
| ) # type: DataFrame |
| |
| from pyspark.sql.functions import col |
| |
| # Matched -> update, NotMatched -> insert, NotMatchedBySource -> delete |
| reset_target_table() |
| # fmt: off |
| source.mergeInto("testcat.ns1.target", source.id == col("target.id")) \ |
| .whenMatched(source.id == 1).update({"name": source.name}) \ |
| .whenNotMatched().insert({"id": source.id, "name": source.name}) \ |
| .whenNotMatchedBySource().delete() \ |
| .merge() |
| # fmt: on |
| self.assertEqual( |
| self.spark.table("testcat.ns1.target").orderBy("id").collect(), |
| [Row(id=1, name="Charlie"), Row(id=3, name="David")], |
| ) |
| |
| # Matched -> updateAll, NotMatched -> insertAll, NotMatchedBySource -> update |
| reset_target_table() |
| # fmt: off |
| source.mergeInto("testcat.ns1.target", source.id == col("target.id")) \ |
| .whenMatched(source.id == 1).updateAll() \ |
| .whenNotMatched(source.id == 3).insertAll() \ |
| .whenNotMatchedBySource(col("target.id") == lit(2)) \ |
| .update({"name": lit("not_matched")}) \ |
| .merge() |
| # fmt: on |
| self.assertEqual( |
| self.spark.table("testcat.ns1.target").orderBy("id").collect(), |
| [ |
| Row(id=1, name="Charlie"), |
| Row(id=2, name="not_matched"), |
| Row(id=3, name="David"), |
| ], |
| ) |
| |
| # Matched -> delete, NotMatchedBySource -> delete |
| reset_target_table() |
| # fmt: off |
| self.spark.createDataFrame([(1, "AliceJr")], ["id", "name"]) \ |
| .write.mode("append").saveAsTable("testcat.ns1.target") |
| source.mergeInto("testcat.ns1.target", source.id == col("target.id")) \ |
| .whenMatched(col("target.name") != lit("AliceJr")).delete() \ |
| .whenNotMatchedBySource().delete() \ |
| .merge() |
| # fmt: on |
| self.assertEqual( |
| self.spark.table("testcat.ns1.target").orderBy("id").collect(), |
| [Row(id=1, name="AliceJr")], |
| ) |
| finally: |
| self.spark.conf.unset("spark.sql.catalog.testcat") |
| |
| @unittest.skipIf( |
| not have_pandas or not have_pyarrow, |
| cast(str, pandas_requirement_message or pyarrow_requirement_message), |
| ) |
| def test_pandas_api(self): |
| import pandas as pd |
| from pandas.testing import assert_frame_equal |
| |
| sdf = self.spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["Col1", "Col2"]) |
| psdf_from_sdf = sdf.pandas_api() |
| psdf_from_sdf_with_index = sdf.pandas_api(index_col="Col1") |
| pdf = pd.DataFrame({"Col1": ["a", "b", "c"], "Col2": [1, 2, 3]}) |
| pdf_with_index = pdf.set_index("Col1") |
| |
| assert_frame_equal(pdf, psdf_from_sdf.to_pandas()) |
| assert_frame_equal(pdf_with_index, psdf_from_sdf_with_index.to_pandas()) |
| |
| def test_to(self): |
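| # DataFrame.to(schema) reorders columns, projects subsets, upcasts |
| # compatible types and applies nullability to match the target schema. |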
| schema = StructType( |
| [StructField("i", StringType(), True), StructField("j", IntegerType(), True)] |
| ) |
| df = self.spark.createDataFrame([("a", 1)], schema) |
| |
| schema1 = StructType([StructField("j", StringType()), StructField("i", StringType())]) |
| df1 = df.to(schema1) |
| self.assertEqual(schema1, df1.schema) |
| self.assertEqual(df.count(), df1.count()) |
| |
| schema2 = StructType([StructField("j", LongType())]) |
| df2 = df.to(schema2) |
| self.assertEqual(schema2, df2.schema) |
| self.assertEqual(df.count(), df2.count()) |
| |
| schema3 = StructType([StructField("struct", schema1, False)]) |
| df3 = df.select(struct("i", "j").alias("struct")).to(schema3) |
| self.assertEqual(schema3, df3.schema) |
| self.assertEqual(df.count(), df3.count()) |
| |
| # incompatible field nullability |
| schema4 = StructType([StructField("j", LongType(), False)]) |
| self.assertRaisesRegex( |
| AnalysisException, "NULLABLE_COLUMN_OR_FIELD", lambda: df.to(schema4).count() |
| ) |
| |
| # field cannot upcast |
| schema5 = StructType([StructField("i", LongType())]) |
| self.assertRaisesRegex( |
| AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: df.to(schema5).count() |
| ) |
| |
| def test_colregex(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| self.spark.range(10).colRegex(10) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_STR", |
| messageParameters={"arg_name": "colName", "arg_type": "int"}, |
| ) |
| |
| def test_where(self): |
| with self.assertRaises(PySparkTypeError) as pe: |
| self.spark.range(10).where(10) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_COLUMN_OR_STR", |
| messageParameters={"arg_name": "condition", "arg_type": "int"}, |
| ) |
| |
| def test_duplicate_field_names(self): |
| data = [ |
| Row(Row("a", 1), Row(2, 3, "b", 4, "c", "d")), |
| Row(Row("w", 6), Row(7, 8, "x", 9, "y", "z")), |
| ] |
| schema = ( |
| StructType() |
| .add("struct", StructType().add("x", StringType()).add("x", IntegerType())) |
| .add( |
| "struct", |
| StructType() |
| .add("a", IntegerType()) |
| .add("x", IntegerType()) |
| .add("x", StringType()) |
| .add("y", IntegerType()) |
| .add("y", StringType()) |
| .add("x", StringType()), |
| ) |
| ) |
| df = self.spark.createDataFrame(data, schema=schema) |
| |
| self.assertEqual(df.schema, schema) |
| self.assertEqual(df.collect(), data) |
| |
| def test_union_classmethod_usage(self): |
| df = self.spark.range(1) |
| self.assertEqual(DataFrame.union(df, df).collect(), [Row(id=0), Row(id=0)]) |
| |
| def test_isinstance_dataframe(self): |
| self.assertIsInstance(self.spark.range(1), DataFrame) |
| |
| def test_local_checkpoint_dataframe(self): |
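| # localCheckpoint() truncates the logical plan, so the checkpointed |
| # DataFrame is read back as a scan over an ExistingRDD. |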
| with io.StringIO() as buf, redirect_stdout(buf): |
| self.spark.range(1).localCheckpoint().explain() |
| self.assertIn("ExistingRDD", buf.getvalue()) |
| |
| def test_local_checkpoint_dataframe_with_storage_level(self): |
| # We don't have a way to reach into the server and assert the storage level server side, |
| # but this test should at least surface unexpected errors in the API. |
| df = self.spark.range(10).localCheckpoint(eager=True, storageLevel=StorageLevel.DISK_ONLY) |
| df.collect() |
| |
| def test_transpose(self): |
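| # transpose() turns the index column's values into column names (sorted |
| # ascending) and emits one row per remaining column, keyed by "key". |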
| df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}]) |
| |
| # default index column |
| transposed_df = df.transpose() |
| expected_schema = StructType( |
| [StructField("key", StringType(), False), StructField("x", StringType(), True)] |
| ) |
| expected_data = [Row(key="b", x="y"), Row(key="c", x="z")] |
| expected_df = self.spark.createDataFrame(expected_data, schema=expected_schema) |
| assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True) |
| |
| # specified index column |
| transposed_df = df.transpose("c") |
| expected_schema = StructType( |
| [StructField("key", StringType(), False), StructField("z", StringType(), True)] |
| ) |
| expected_data = [Row(key="a", z="x"), Row(key="b", z="y")] |
| expected_df = self.spark.createDataFrame(expected_data, schema=expected_schema) |
| assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True) |
| |
| # enforce transpose max values |
| with self.sql_conf({"spark.sql.transposeMaxValues": 0}): |
| with self.assertRaises(AnalysisException) as pe: |
| df.transpose().collect() |
| self.check_error( |
| exception=pe.exception, |
| errorClass="TRANSPOSE_EXCEED_ROW_LIMIT", |
| messageParameters={"maxValues": "0", "config": "spark.sql.transposeMaxValues"}, |
| ) |
| |
| # enforce ascending order based on index column values for transposed columns |
| df = self.spark.createDataFrame([{"a": "z"}, {"a": "y"}, {"a": "x"}]) |
| transposed_df = df.transpose() |
| expected_schema = StructType( |
| [ |
| StructField("key", StringType(), False), |
| StructField("x", StringType(), True), |
| StructField("y", StringType(), True), |
| StructField("z", StringType(), True), |
| ] |
| ) # z, y, x -> x, y, z |
| expected_df = self.spark.createDataFrame([], schema=expected_schema) |
| assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True) |
| |
| # enforce AtomicType Attribute for index column values |
| df = self.spark.createDataFrame([{"a": ["x", "x"], "b": "y", "c": "z"}]) |
| with self.assertRaises(AnalysisException) as pe: |
| df.transpose().collect() |
| self.check_error( |
| exception=pe.exception, |
| errorClass="TRANSPOSE_INVALID_INDEX_COLUMN", |
| messageParameters={ |
| "reason": "Index column must be of atomic type, " |
| "but found: ArrayType(StringType,true)" |
| }, |
| ) |
| |
| # enforce least common type for non-index columns |
| df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": 1}]) |
| with self.assertRaises(AnalysisException) as pe: |
| df.transpose().collect() |
| self.check_error( |
| exception=pe.exception, |
| errorClass="TRANSPOSE_NO_LEAST_COMMON_TYPE", |
| messageParameters={"dt1": '"STRING"', "dt2": '"BIGINT"'}, |
| ) |
| |
| def test_transpose_with_invalid_index_columns(self): |
| # SPARK-50602: invalid index columns |
| df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}]) |
| |
| with self.assertRaises(AnalysisException) as pe: |
| df.transpose(col("a") + 1).collect() |
| self.check_error( |
| exception=pe.exception, |
| errorClass="TRANSPOSE_INVALID_INDEX_COLUMN", |
| messageParameters={"reason": "Index column must be an atomic attribute"}, |
| ) |
| |
| def test_metadata_column(self): |
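| # metadataColumn() resolves the data source's hidden metadata column, |
| # which here shares the name "index" with a regular data column but |
| # holds different values. |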
| with self.sql_conf( |
| {"spark.sql.catalog.testcat": "org.apache.spark.sql.connector.catalog.InMemoryCatalog"} |
| ): |
| tbl = "testcat.t" |
| with self.table(tbl): |
| self.spark.sql( |
| f""" |
| CREATE TABLE {tbl} (index bigint, data string) |
| PARTITIONED BY (bucket(4, index), index) |
| """ |
| ) |
| self.spark.sql(f"""INSERT INTO {tbl} VALUES (1, 'a'), (2, 'b'), (3, 'c')""") |
| |
| df = self.spark.sql(f"""SELECT * FROM {tbl}""") |
| assertDataFrameEqual( |
| df.select(df.metadataColumn("index")), |
| [Row(0), Row(0), Row(0)], |
| ) |
| |
| def test_with_column_and_generator(self): |
| # SPARK-51451: Generators should be available with withColumn |
| df = self.spark.createDataFrame([("082017",)], ["dt"]).select( |
| to_date(col("dt"), "MMyyyy").alias("dt") |
| ) |
| df_dt = df.withColumn("dt", date_format(col("dt"), "MM/dd/yyyy")) |
| monthArray = [lit(x) for x in range(0, 12)] |
| df_month_y = df_dt.withColumn("month_y", explode(array(monthArray))) |
| |
| assertDataFrameEqual( |
| df_month_y, |
| [Row(dt="08/01/2017", month_y=i) for i in range(12)], |
| ) |
| |
| df_dt_month_y = df.withColumns( |
| { |
| "dt": date_format(col("dt"), "MM/dd/yyyy"), |
| "month_y": explode(array(monthArray)), |
| } |
| ) |
| |
| assertDataFrameEqual( |
| df_dt_month_y, |
| [Row(dt="08/01/2017", month_y=i) for i in range(12)], |
| ) |
| |
| |
| class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase): |
| def test_query_execution_unsupported_in_classic(self): |
| with self.assertRaises(PySparkValueError) as pe: |
| self.spark.range(1).executionInfo |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF", |
| messageParameters={"member": "queryExecution"}, |
| ) |
| |
| |
| if __name__ == "__main__": |
| from pyspark.sql.tests.test_dataframe import * # noqa: F401 |
| |
| try: |
| import xmlrunner # type: ignore |
| |
| testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) |
| except ImportError: |
| testRunner = None |
| unittest.main(testRunner=testRunner, verbosity=2) |