#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import uuid
import datetime
import decimal
import math
from pyspark.testing.connectutils import (
PlanOnlyTestFixture,
should_test_connect,
connect_requirement_message,
)
from pyspark.errors import PySparkValueError
if should_test_connect:
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import WriteOperation, Read
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.expressions import LiteralExpression
from pyspark.sql.connect.functions import col, lit, max, min, sum
from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.types import (
StringType,
StructType,
StructField,
IntegerType,
MapType,
ArrayType,
DoubleType,
)
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectPlanTests(PlanOnlyTestFixture):
"""These test cases exercise the interface to the proto plan
generation but do not call Spark."""
def test_sql_project(self):
plan = self.connect.sql("SELECT 1")._plan.to_proto(self.connect)
self.assertEqual(plan.root.sql.query, "SELECT 1")
def test_simple_project(self):
plan = self.connect.readTable(table_name=self.tbl_name)._plan.to_proto(self.connect)
self.assertIsNotNone(plan.root, "Root relation must be set")
self.assertIsNotNone(plan.root.read)
def test_join_using_columns(self):
left_input = self.connect.readTable(table_name=self.tbl_name)
right_input = self.connect.readTable(table_name=self.tbl_name)
plan = left_input.join(other=right_input, on="join_column")._plan.to_proto(self.connect)
self.assertEqual(len(plan.root.join.using_columns), 1)
plan2 = left_input.join(other=right_input, on=["col1", "col2"])._plan.to_proto(self.connect)
self.assertEqual(len(plan2.root.join.using_columns), 2)
def test_join_condition(self):
left_input = self.connect.readTable(table_name=self.tbl_name)
right_input = self.connect.readTable(table_name=self.tbl_name)
plan = left_input.join(
other=right_input, on=left_input.name == right_input.name
)._plan.to_proto(self.connect)
self.assertIsNotNone(plan.root.join.join_condition)
plan = left_input.join(
other=right_input,
on=[left_input.name == right_input.name, left_input.age == right_input.age],
)._plan.to_proto(self.connect)
self.assertIsNotNone(plan.root.join.join_condition)
def test_crossjoin(self):
# SPARK-41227: Test CrossJoin
left_input = self.connect.readTable(table_name=self.tbl_name)
right_input = self.connect.readTable(table_name=self.tbl_name)
crossJoin_plan = left_input.crossJoin(other=right_input)._plan.to_proto(self.connect)
join_plan = left_input.join(other=right_input, how="cross")._plan.to_proto(self.connect)
self.assertEqual(
crossJoin_plan.root.join.left.read.named_table,
join_plan.root.join.left.read.named_table,
)
self.assertEqual(
crossJoin_plan.root.join.right.read.named_table,
join_plan.root.join.right.read.named_table,
)
self.assertEqual(
crossJoin_plan.root.join.join_type,
join_plan.root.join.join_type,
)
def test_filter(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3)._plan.to_proto(self.connect)
self.assertIsNotNone(plan.root.filter)
self.assertTrue(
isinstance(
plan.root.filter.condition.unresolved_function, proto.Expression.UnresolvedFunction
)
)
self.assertEqual(plan.root.filter.condition.unresolved_function.function_name, ">")
self.assertEqual(len(plan.root.filter.condition.unresolved_function.arguments), 2)
def test_filter_with_string_expr(self):
"""SPARK-41297: filter supports SQL expression"""
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter("id < 10")._plan.to_proto(self.connect)
self.assertIsNotNone(plan.root.filter)
self.assertIsNotNone(plan.root.filter.condition.expression_string)
self.assertEqual(plan.root.filter.condition.expression_string.expression, "id < 10")
def test_fill_na(self):
# SPARK-41128: Test fill na
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.fillna(value=1)._plan.to_proto(self.connect)
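        # An int fill value is carried as a single long literal with no target
        # columns listed.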
self.assertEqual(len(plan.root.fill_na.values), 1)
self.assertEqual(plan.root.fill_na.values[0].long, 1)
self.assertEqual(plan.root.fill_na.cols, [])
plan = df.na.fill(value="xyz")._plan.to_proto(self.connect)
self.assertEqual(len(plan.root.fill_na.values), 1)
self.assertEqual(plan.root.fill_na.values[0].string, "xyz")
self.assertEqual(plan.root.fill_na.cols, [])
plan = df.na.fill(value="xyz", subset=["col_a", "col_b"])._plan.to_proto(self.connect)
self.assertEqual(len(plan.root.fill_na.values), 1)
self.assertEqual(plan.root.fill_na.values[0].string, "xyz")
self.assertEqual(plan.root.fill_na.cols, ["col_a", "col_b"])
plan = df.na.fill(value=True, subset=("col_a", "col_b", "col_c"))._plan.to_proto(
self.connect
)
self.assertEqual(len(plan.root.fill_na.values), 1)
self.assertEqual(plan.root.fill_na.values[0].boolean, True)
self.assertEqual(plan.root.fill_na.cols, ["col_a", "col_b", "col_c"])
plan = df.fillna({"col_a": 1.5, "col_b": "abc"})._plan.to_proto(self.connect)
self.assertEqual(len(plan.root.fill_na.values), 2)
self.assertEqual(plan.root.fill_na.values[0].double, 1.5)
self.assertEqual(plan.root.fill_na.values[1].string, "abc")
self.assertEqual(plan.root.fill_na.cols, ["col_a", "col_b"])
def test_drop_na(self):
# SPARK-41148: Test drop na
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.dropna()._plan.to_proto(self.connect)
self.assertEqual(plan.root.drop_na.cols, [])
self.assertEqual(plan.root.drop_na.HasField("min_non_nulls"), False)
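        # thresh=N maps directly onto min_non_nulls=N.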
plan = df.na.drop(thresh=2, subset=("col_a", "col_b"))._plan.to_proto(self.connect)
self.assertEqual(plan.root.drop_na.cols, ["col_a", "col_b"])
self.assertEqual(plan.root.drop_na.min_non_nulls, 2)
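        # how="all" keeps a row when at least one value is non-null, which the
        # proto expresses as min_non_nulls=1.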
plan = df.dropna(how="all", subset="col_c")._plan.to_proto(self.connect)
self.assertEqual(plan.root.drop_na.cols, ["col_c"])
self.assertEqual(plan.root.drop_na.min_non_nulls, 1)
def test_replace(self):
# SPARK-41315: Test replace
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.replace(10, 20)._plan.to_proto(self.connect)
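        # Numeric old/new replacement values are normalized to double literals.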
self.assertEqual(plan.root.replace.cols, [])
self.assertEqual(plan.root.replace.replacements[0].old_value.double, 10.0)
self.assertEqual(plan.root.replace.replacements[0].new_value.double, 20.0)
plan = df.na.replace((1, 2, 3), (4, 5, 6), subset=("col_a", "col_b"))._plan.to_proto(
self.connect
)
self.assertEqual(plan.root.replace.cols, ["col_a", "col_b"])
self.assertEqual(plan.root.replace.replacements[0].old_value.double, 1.0)
self.assertEqual(plan.root.replace.replacements[0].new_value.double, 4.0)
self.assertEqual(plan.root.replace.replacements[1].old_value.double, 2.0)
self.assertEqual(plan.root.replace.replacements[1].new_value.double, 5.0)
self.assertEqual(plan.root.replace.replacements[2].old_value.double, 3.0)
self.assertEqual(plan.root.replace.replacements[2].new_value.double, 6.0)
plan = df.replace(["Alice", "Bob"], ["A", "B"], subset="col_x")._plan.to_proto(self.connect)
self.assertEqual(plan.root.replace.cols, ["col_x"])
self.assertEqual(plan.root.replace.replacements[0].old_value.string, "Alice")
self.assertEqual(plan.root.replace.replacements[0].new_value.string, "A")
self.assertEqual(plan.root.replace.replacements[1].old_value.string, "Bob")
self.assertEqual(plan.root.replace.replacements[1].new_value.string, "B")
def test_unpivot(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = (
df.filter(df.col_name > 3)
.unpivot(["id"], ["name"], "variable", "value")
._plan.to_proto(self.connect)
)
self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
self.assertEqual(plan.root.unpivot.HasField("values"), True)
self.assertTrue(
all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values.values)
)
self.assertEqual(
plan.root.unpivot.values.values[0].unresolved_attribute.unparsed_identifier, "name"
)
self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
self.assertEqual(plan.root.unpivot.value_column_name, "value")
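        # With values=None the optional values field is omitted entirely.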
plan = (
df.filter(df.col_name > 3)
.unpivot(["id"], None, "variable", "value")
._plan.to_proto(self.connect)
)
self.assertTrue(len(plan.root.unpivot.ids) == 1)
self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
self.assertEqual(plan.root.unpivot.HasField("values"), False)
self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
self.assertEqual(plan.root.unpivot.value_column_name, "value")
def test_melt(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = (
df.filter(df.col_name > 3)
.melt(["id"], ["name"], "variable", "value")
._plan.to_proto(self.connect)
)
self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
self.assertEqual(plan.root.unpivot.HasField("values"), True)
self.assertTrue(
all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values.values)
)
self.assertEqual(
plan.root.unpivot.values.values[0].unresolved_attribute.unparsed_identifier, "name"
)
self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
self.assertEqual(plan.root.unpivot.value_column_name, "value")
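        # Unlike values=None above, an empty values list still sets the values
        # field, just with no entries.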
plan = (
df.filter(df.col_name > 3)
.melt(["id"], [], "variable", "value")
._plan.to_proto(self.connect)
)
self.assertTrue(len(plan.root.unpivot.ids) == 1)
self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
self.assertEqual(plan.root.unpivot.HasField("values"), True)
self.assertTrue(len(plan.root.unpivot.values.values) == 0)
self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
self.assertEqual(plan.root.unpivot.value_column_name, "value")
def test_random_split(self):
# SPARK-41440: test randomSplit(weights, seed).
from typing import List
df = self.connect.readTable(table_name=self.tbl_name)
def checkRelations(relations: List["DataFrame"]):
self.assertTrue(len(relations) == 3)
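            # Weights [1.0, 2.0, 3.0] normalize to 1/6, 1/3 and 1/2, so the
            # three samples cover the cumulative bound ranges [0, 1/6),
            # [1/6, 1/2) and [1/2, 1.0).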
plan = relations[0]._plan.to_proto(self.connect)
self.assertEqual(plan.root.sample.lower_bound, 0.0)
self.assertEqual(plan.root.sample.upper_bound, 0.16666666666666666)
self.assertEqual(plan.root.sample.with_replacement, False)
self.assertEqual(plan.root.sample.HasField("seed"), True)
self.assertEqual(plan.root.sample.deterministic_order, True)
plan = relations[1]._plan.to_proto(self.connect)
self.assertEqual(plan.root.sample.lower_bound, 0.16666666666666666)
self.assertEqual(plan.root.sample.upper_bound, 0.5)
self.assertEqual(plan.root.sample.with_replacement, False)
self.assertEqual(plan.root.sample.HasField("seed"), True)
self.assertEqual(plan.root.sample.deterministic_order, True)
plan = relations[2]._plan.to_proto(self.connect)
self.assertEqual(plan.root.sample.lower_bound, 0.5)
self.assertEqual(plan.root.sample.upper_bound, 1.0)
self.assertEqual(plan.root.sample.with_replacement, False)
self.assertEqual(plan.root.sample.HasField("seed"), True)
self.assertEqual(plan.root.sample.deterministic_order, True)
relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0], 1)
checkRelations(relations)
relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0])
checkRelations(relations)
def test_observe(self):
# SPARK-41527: test DataFrame.observe()
df = self.connect.readTable(table_name=self.tbl_name)
plan = (
df.filter(df.col_name > 3)
.observe("my_metric", min("id"), max("id"), sum("id"))
._plan.to_proto(self.connect)
)
self.assertEqual(plan.root.collect_metrics.name, "my_metric")
self.assertTrue(
all(isinstance(c, proto.Expression) for c in plan.root.collect_metrics.metrics)
)
self.assertEqual(
plan.root.collect_metrics.metrics[0].unresolved_function.function_name, "min"
)
self.assertTrue(
len(plan.root.collect_metrics.metrics[0].unresolved_function.arguments) == 1
)
self.assertTrue(
all(
isinstance(c, proto.Expression)
for c in plan.root.collect_metrics.metrics[0].unresolved_function.arguments
)
)
self.assertEqual(
plan.root.collect_metrics.metrics[0]
.unresolved_function.arguments[0]
.unresolved_attribute.unparsed_identifier,
"id",
)
from pyspark.sql.connect.observation import Observation
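        # A minimal DataFrame wrapper; it reports itself as non-streaming so
        # that observe() with an Observation can be exercised plan-only.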
class MockDF(DataFrame):
def __new__(cls, df: DataFrame) -> "DataFrame":
self = object.__new__(cls)
self.__init__(df) # type: ignore[misc]
return self
def __init__(self, df: DataFrame):
super().__init__(df._plan, df._session)
@property
def isStreaming(self) -> bool:
return False
plan = (
MockDF(df.filter(df.col_name > 3))
.observe(Observation("my_metric"), min("id"), max("id"), sum("id"))
._plan.to_proto(self.connect)
)
self.assertEqual(plan.root.collect_metrics.name, "my_metric")
self.assertTrue(
all(isinstance(c, proto.Expression) for c in plan.root.collect_metrics.metrics)
)
self.assertEqual(
plan.root.collect_metrics.metrics[0].unresolved_function.function_name, "min"
)
self.assertTrue(
len(plan.root.collect_metrics.metrics[0].unresolved_function.arguments) == 1
)
self.assertTrue(
all(
isinstance(c, proto.Expression)
for c in plan.root.collect_metrics.metrics[0].unresolved_function.arguments
)
)
self.assertEqual(
plan.root.collect_metrics.metrics[0]
.unresolved_function.arguments[0]
.unresolved_attribute.unparsed_identifier,
"id",
)
def test_summary(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)
self.assertEqual(plan.root.summary.statistics, [])
plan = (
df.filter(df.col_name > 3)
.summary("count", "mean", "stddev", "min", "25%")
._plan.to_proto(self.connect)
)
self.assertEqual(
plan.root.summary.statistics,
["count", "mean", "stddev", "min", "25%"],
)
def test_describe(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3).describe()._plan.to_proto(self.connect)
self.assertEqual(plan.root.describe.cols, [])
plan = df.filter(df.col_name > 3).describe("col_a", "col_b")._plan.to_proto(self.connect)
self.assertEqual(
plan.root.describe.cols,
["col_a", "col_b"],
)
def test_crosstab(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3).crosstab("col_a", "col_b")._plan.to_proto(self.connect)
self.assertEqual(plan.root.crosstab.col1, "col_a")
self.assertEqual(plan.root.crosstab.col2, "col_b")
plan = df.stat.crosstab("col_a", "col_b")._plan.to_proto(self.connect)
self.assertEqual(plan.root.crosstab.col1, "col_a")
self.assertEqual(plan.root.crosstab.col2, "col_b")
def test_freqItems(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = (
df.filter(df.col_name > 3).freqItems(["col_a", "col_b"], 1)._plan.to_proto(self.connect)
)
self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
self.assertEqual(plan.root.freq_items.support, 1)
plan = df.filter(df.col_name > 3).freqItems(["col_a", "col_b"])._plan.to_proto(self.connect)
self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
self.assertEqual(plan.root.freq_items.support, 0.01)
plan = df.stat.freqItems(["col_a", "col_b"], 1)._plan.to_proto(self.connect)
self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
self.assertEqual(plan.root.freq_items.support, 1)
plan = df.stat.freqItems(["col_a", "col_b"])._plan.to_proto(self.connect)
self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
self.assertEqual(plan.root.freq_items.support, 0.01)
def test_limit(self):
df = self.connect.readTable(table_name=self.tbl_name)
limit_plan = df.limit(10)._plan.to_proto(self.connect)
self.assertEqual(limit_plan.root.limit.limit, 10)
def test_offset(self):
df = self.connect.readTable(table_name=self.tbl_name)
offset_plan = df.offset(10)._plan.to_proto(self.connect)
self.assertEqual(offset_plan.root.offset.offset, 10)
def test_sample(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3).sample(fraction=0.3)._plan.to_proto(self.connect)
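        # A bare fraction f is encoded as the sampling bound range [0.0, f).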
self.assertEqual(plan.root.sample.lower_bound, 0.0)
self.assertEqual(plan.root.sample.upper_bound, 0.3)
self.assertEqual(plan.root.sample.with_replacement, False)
self.assertEqual(plan.root.sample.HasField("seed"), False)
self.assertEqual(plan.root.sample.deterministic_order, False)
plan = (
df.filter(df.col_name > 3)
.sample(withReplacement=True, fraction=0.4, seed=-1)
._plan.to_proto(self.connect)
)
self.assertEqual(plan.root.sample.lower_bound, 0.0)
self.assertEqual(plan.root.sample.upper_bound, 0.4)
self.assertEqual(plan.root.sample.with_replacement, True)
self.assertEqual(plan.root.sample.seed, -1)
self.assertEqual(plan.root.sample.deterministic_order, False)
def test_sort(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3).sort("col_a", "col_b")._plan.to_proto(self.connect)
self.assertEqual(
[f.child.unresolved_attribute.unparsed_identifier for f in plan.root.sort.order],
["col_a", "col_b"],
)
self.assertEqual(plan.root.sort.is_global, True)
self.assertEqual(
plan.root.sort.order[0].direction,
proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING,
)
self.assertEqual(
            plan.root.sort.order[0].null_ordering,
proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST,
)
self.assertEqual(
plan.root.sort.order[1].direction,
proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING,
)
self.assertEqual(
            plan.root.sort.order[1].null_ordering,
proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST,
)
plan = df.filter(df.col_name > 3).orderBy("col_a", "col_b")._plan.to_proto(self.connect)
self.assertEqual(
[f.child.unresolved_attribute.unparsed_identifier for f in plan.root.sort.order],
["col_a", "col_b"],
)
self.assertEqual(plan.root.sort.is_global, True)
plan = (
df.filter(df.col_name > 3)
.sortWithinPartitions(df.col_a.desc(), df.col_b.asc())
._plan.to_proto(self.connect)
)
self.assertEqual(
[f.child.unresolved_attribute.unparsed_identifier for f in plan.root.sort.order],
["col_a", "col_b"],
)
self.assertEqual(plan.root.sort.is_global, False)
self.assertEqual(
plan.root.sort.order[0].direction,
proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_DESCENDING,
)
self.assertEqual(
            plan.root.sort.order[0].null_ordering,
proto.Expression.SortOrder.NullOrdering.SORT_NULLS_LAST,
)
self.assertEqual(
plan.root.sort.order[1].direction,
proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING,
)
self.assertEqual(
            plan.root.sort.order[1].null_ordering,
proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST,
)
def test_drop(self):
# SPARK-41169: test drop
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3).drop("col_a", "col_b")._plan.to_proto(self.connect)
self.assertEqual(
plan.root.drop.column_names,
["col_a", "col_b"],
)
plan = df.filter(df.col_name > 3).drop(df.col_x, "col_b")._plan.to_proto(self.connect)
self.assertEqual(
[f.unresolved_attribute.unparsed_identifier for f in plan.root.drop.columns],
["col_x"],
)
self.assertEqual(
plan.root.drop.column_names,
["col_b"],
)
def test_deduplicate(self):
df = self.connect.readTable(table_name=self.tbl_name)
distinct_plan = df.distinct()._plan.to_proto(self.connect)
self.assertTrue(distinct_plan.root.deduplicate.HasField("input"), "input must be set")
self.assertEqual(distinct_plan.root.deduplicate.all_columns_as_keys, True)
self.assertEqual(len(distinct_plan.root.deduplicate.column_names), 0)
deduplicate_on_all_columns_plan = df.dropDuplicates()._plan.to_proto(self.connect)
self.assertEqual(deduplicate_on_all_columns_plan.root.deduplicate.all_columns_as_keys, True)
self.assertEqual(len(deduplicate_on_all_columns_plan.root.deduplicate.column_names), 0)
deduplicate_on_all_columns_plan = df.drop_duplicates()._plan.to_proto(self.connect)
self.assertEqual(deduplicate_on_all_columns_plan.root.deduplicate.all_columns_as_keys, True)
self.assertEqual(len(deduplicate_on_all_columns_plan.root.deduplicate.column_names), 0)
deduplicate_on_subset_columns_plan = df.dropDuplicates(["name", "height"])._plan.to_proto(
self.connect
)
self.assertEqual(
deduplicate_on_subset_columns_plan.root.deduplicate.all_columns_as_keys, False
)
self.assertEqual(len(deduplicate_on_subset_columns_plan.root.deduplicate.column_names), 2)
def test_relation_alias(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.alias("table_alias")._plan.to_proto(self.connect)
self.assertEqual(plan.root.subquery_alias.alias, "table_alias")
self.assertIsNotNone(plan.root.subquery_alias.input)
def test_range(self):
plan = self.connect.range(start=10, end=20, step=3, num_partitions=4)._plan.to_proto(
self.connect
)
self.assertEqual(plan.root.range.start, 10)
self.assertEqual(plan.root.range.end, 20)
self.assertEqual(plan.root.range.step, 3)
self.assertEqual(plan.root.range.num_partitions, 4)
plan = self.connect.range(start=10, end=20)._plan.to_proto(self.connect)
self.assertEqual(plan.root.range.start, 10)
self.assertEqual(plan.root.range.end, 20)
self.assertEqual(plan.root.range.step, 1)
self.assertFalse(plan.root.range.HasField("num_partitions"))
def test_datasource_read(self):
reader = DataFrameReader(self.connect)
df = reader.load(path="test_path", format="text", schema="id INT", op1="opv", op2="opv2")
plan = df._plan.to_proto(self.connect)
data_source = plan.root.read.data_source
self.assertEqual(data_source.format, "text")
self.assertEqual(data_source.schema, "id INT")
self.assertEqual(len(data_source.options), 2)
self.assertEqual(data_source.options.get("op1"), "opv")
self.assertEqual(data_source.options.get("op2"), "opv2")
self.assertEqual(len(data_source.paths), 1)
self.assertEqual(data_source.paths[0], "test_path")
def test_all_the_plans(self):
df = self.connect.readTable(table_name=self.tbl_name)
df = df.select(df.col1).filter(df.col2 == 2).sort(df.col3.asc())
plan = df._plan.to_proto(self.connect)
self.assertIsNotNone(plan.root, "Root relation must be set")
self.assertIsNotNone(plan.root.read)
def test_union(self):
df1 = self.connect.readTable(table_name=self.tbl_name)
df2 = self.connect.readTable(table_name=self.tbl_name)
plan1 = df1.union(df2)._plan.to_proto(self.connect)
self.assertTrue(plan1.root.set_op.is_all)
self.assertEqual(proto.SetOperation.SET_OP_TYPE_UNION, plan1.root.set_op.set_op_type)
plan2 = df1.union(df2)._plan.to_proto(self.connect)
self.assertTrue(plan2.root.set_op.is_all)
self.assertEqual(proto.SetOperation.SET_OP_TYPE_UNION, plan2.root.set_op.set_op_type)
plan3 = df1.unionByName(df2, True)._plan.to_proto(self.connect)
self.assertTrue(plan3.root.set_op.by_name)
self.assertEqual(proto.SetOperation.SET_OP_TYPE_UNION, plan3.root.set_op.set_op_type)
def test_subtract(self):
# SPARK-41453: test `subtract` API for Python client.
df1 = self.connect.readTable(table_name=self.tbl_name)
df2 = self.connect.readTable(table_name=self.tbl_name)
plan1 = df1.subtract(df2)._plan.to_proto(self.connect)
self.assertTrue(not plan1.root.set_op.is_all)
self.assertEqual(proto.SetOperation.SET_OP_TYPE_EXCEPT, plan1.root.set_op.set_op_type)
def test_except(self):
# SPARK-41010: test `except` API for Python client.
df1 = self.connect.readTable(table_name=self.tbl_name)
df2 = self.connect.readTable(table_name=self.tbl_name)
plan1 = df1.exceptAll(df2)._plan.to_proto(self.connect)
self.assertTrue(plan1.root.set_op.is_all)
self.assertEqual(proto.SetOperation.SET_OP_TYPE_EXCEPT, plan1.root.set_op.set_op_type)
def test_intersect(self):
# SPARK-41010: test `intersect` API for Python client.
df1 = self.connect.readTable(table_name=self.tbl_name)
df2 = self.connect.readTable(table_name=self.tbl_name)
plan1 = df1.intersect(df2)._plan.to_proto(self.connect)
self.assertFalse(plan1.root.set_op.is_all)
self.assertEqual(proto.SetOperation.SET_OP_TYPE_INTERSECT, plan1.root.set_op.set_op_type)
plan2 = df1.intersectAll(df2)._plan.to_proto(self.connect)
self.assertTrue(plan2.root.set_op.is_all)
self.assertEqual(proto.SetOperation.SET_OP_TYPE_INTERSECT, plan2.root.set_op.set_op_type)
def test_coalesce_and_repartition(self):
# SPARK-41026: test Coalesce and Repartition API in Python client.
df = self.connect.readTable(table_name=self.tbl_name)
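        # Both coalesce() and repartition() lower to the Repartition relation;
        # only the shuffle flag differs.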
plan1 = df.coalesce(10)._plan.to_proto(self.connect)
self.assertEqual(10, plan1.root.repartition.num_partitions)
self.assertFalse(plan1.root.repartition.shuffle)
plan2 = df.repartition(20)._plan.to_proto(self.connect)
self.assertTrue(plan2.root.repartition.shuffle)
with self.assertRaises(PySparkValueError) as pe:
df.coalesce(-1)._plan.to_proto(self.connect)
self.check_error(
exception=pe.exception,
error_class="VALUE_NOT_POSITIVE",
message_parameters={"arg_name": "numPartitions", "arg_value": "-1"},
)
with self.assertRaises(PySparkValueError) as pe:
df.repartition(-1)._plan.to_proto(self.connect)
self.check_error(
exception=pe.exception,
error_class="VALUE_NOT_POSITIVE",
message_parameters={"arg_name": "numPartitions", "arg_value": "-1"},
)
def test_repartition_by_expression(self):
# SPARK-41354: test dataframe.repartition(expressions)
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.repartition(10, "col_a", "col_b")._plan.to_proto(self.connect)
self.assertEqual(10, plan.root.repartition_by_expression.num_partitions)
self.assertEqual(
[
f.unresolved_attribute.unparsed_identifier
for f in plan.root.repartition_by_expression.partition_exprs
],
["col_a", "col_b"],
)
def test_repartition_by_range(self):
# SPARK-41354: test dataframe.repartitionByRange(expressions)
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.repartitionByRange(10, "col_a", "col_b")._plan.to_proto(self.connect)
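        # repartitionByRange wraps each partitioning column in a sort_order
        # expression.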
self.assertEqual(10, plan.root.repartition_by_expression.num_partitions)
self.assertEqual(
[
f.sort_order.child.unresolved_attribute.unparsed_identifier
for f in plan.root.repartition_by_expression.partition_exprs
],
["col_a", "col_b"],
)
def test_to(self):
# SPARK-41464: test `to` API in Python client.
df = self.connect.readTable(table_name=self.tbl_name)
schema = StructType(
[
StructField("col1", IntegerType(), True),
StructField("col2", StringType(), True),
StructField("map1", MapType(StringType(), IntegerType(), True), True),
StructField("array1", ArrayType(IntegerType(), True), True),
]
)
new_plan = df.to(schema)._plan.to_proto(self.connect)
self.assertEqual(pyspark_types_to_proto_types(schema), new_plan.root.to_schema.schema)
def test_write_operation(self):
wo = WriteOperation(self.connect.readTable("name")._plan)
wo.mode = "overwrite"
wo.source = "parquet"
p = wo.command(None)
self.assertIsNotNone(p)
self.assertFalse(p.write_operation.HasField("path"))
self.assertFalse(p.write_operation.HasField("table"))
wo.path = "path"
p = wo.command(None)
self.assertIsNotNone(p)
self.assertTrue(p.write_operation.HasField("path"))
self.assertFalse(p.write_operation.HasField("table"))
wo.path = None
wo.table_name = "table"
wo.table_save_method = "save_as_table"
p = wo.command(None)
self.assertFalse(p.write_operation.HasField("path"))
self.assertTrue(p.write_operation.HasField("table"))
wo.bucket_cols = ["a", "b", "c"]
p = wo.command(None)
self.assertFalse(p.write_operation.HasField("bucket_by"))
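        # bucket_by is only emitted once both bucket_cols and num_buckets are
        # set.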
wo.num_buckets = 10
p = wo.command(None)
self.assertTrue(p.write_operation.HasField("bucket_by"))
# Unsupported save mode
wo.mode = "unknown"
with self.assertRaises(ValueError):
wo.command(None)
def test_column_regexp(self):
# SPARK-41438: test colRegex
df = self.connect.readTable(table_name=self.tbl_name)
col = df.colRegex("col_name")
self.assertIsInstance(col, Column)
self.assertEqual("Column<'UnresolvedRegex(col_name)'>", str(col))
col_plan = col.to_plan(self.session.client)
self.assertIsNotNone(col_plan)
self.assertEqual(col_plan.unresolved_regex.col_name, "col_name")
def test_print(self):
# SPARK-41717: test print
self.assertEqual(
self.connect.sql("SELECT 1")._plan.print().strip(),
"<SQL query='SELECT 1', args='None', named_args='None', views='None'>",
)
self.assertEqual(
self.connect.range(1, 10)._plan.print().strip(),
"<Range start='1', end='10', step='1', num_partitions='None'>",
)
def test_repr(self):
# SPARK-41717: test __repr_html__
self.assertIn("query: SELECT 1", self.connect.sql("SELECT 1")._plan._repr_html_().strip())
expected = (
"<b>Range</b><br/>",
"start: 1 <br/>",
"end: 10 <br/>",
"step: 1 <br/>",
"num_partitions: None <br/>",
)
actual = self.connect.range(1, 10)._plan._repr_html_().strip()
for line in expected:
self.assertIn(line, actual)
def test_select_with_columns_and_strings(self):
df = self.connect.with_plan(Read("table"))
self.assertIsNotNone(df.select(col("name"))._plan.to_proto(self.connect))
self.assertIsNotNone(df.select("name"))
self.assertIsNotNone(df.select("name", "name2"))
self.assertIsNotNone(df.select(col("name"), col("name2")))
self.assertIsNotNone(df.select(col("name"), "name2"))
self.assertIsNotNone(df.select("*"))
def test_join_with_join_type(self):
df_left = self.connect.with_plan(Read("table"))
df_right = self.connect.with_plan(Read("table"))
for join_type_str, join_type in [
(None, proto.Join.JoinType.JOIN_TYPE_INNER),
("inner", proto.Join.JoinType.JOIN_TYPE_INNER),
("outer", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
("full", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
("fullouter", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
("full_outer", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
("left", proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER),
("leftouter", proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER),
("left_outer", proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER),
("right", proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER),
("rightouter", proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER),
("right_outer", proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER),
("semi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
("leftsemi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
("left_semi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
("anti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
("leftanti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
("left_anti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
("cross", proto.Join.JoinType.JOIN_TYPE_CROSS),
]:
joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.to_proto(
self.connect
)
self.assertEqual(joined_df.root.join.join_type, join_type)
def test_simple_column_expressions(self):
df = self.connect.with_plan(Read("table"))
c1 = df.col_name
self.assertIsInstance(c1, Column)
c2 = df["col_name"]
self.assertIsInstance(c2, Column)
c3 = col("col_name")
self.assertIsInstance(c3, Column)
        # cp1 and cp2 originate from the DataFrame and must be identical; cp3
        # comes from a bare col() and matches only on the unparsed identifier.
cp1 = c1.to_plan(None)
cp2 = c2.to_plan(None)
cp3 = c3.to_plan(None)
self.assertIsNotNone(cp1)
self.assertEqual(cp1, cp2)
self.assertEqual(
cp2.unresolved_attribute.unparsed_identifier,
cp3.unresolved_attribute.unparsed_identifier,
)
self.assertTrue(cp2.unresolved_attribute.HasField("plan_id"))
self.assertFalse(cp3.unresolved_attribute.HasField("plan_id"))
def test_null_literal(self):
null_lit = lit(None)
null_lit_p = null_lit.to_plan(None)
self.assertEqual(null_lit_p.literal.HasField("null"), True)
def test_binary_literal(self):
val = b"binary\0\0asas"
bin_lit = lit(val)
bin_lit_p = bin_lit.to_plan(None)
self.assertEqual(bin_lit_p.literal.binary, val)
def test_uuid_literal(self):
val = uuid.uuid4()
with self.assertRaises(TypeError):
lit(val)
def test_column_literals(self):
df = self.connect.with_plan(Read("table"))
lit_df = df.select(lit(10))
self.assertIsNotNone(lit_df._plan.to_proto(None))
self.assertIsNotNone(lit(10).to_plan(None))
plan = lit(10).to_plan(None)
        self.assertEqual(plan.literal.integer, 10)
plan = lit(1 << 33).to_plan(None)
self.assertEqual(plan.literal.long, 1 << 33)
def test_numeric_literal_types(self):
int_lit = lit(10)
float_lit = lit(10.1)
decimal_lit = lit(decimal.Decimal(99))
self.assertIsNotNone(int_lit.to_plan(None))
self.assertIsNotNone(float_lit.to_plan(None))
self.assertIsNotNone(decimal_lit.to_plan(None))
def test_float_nan_inf(self):
na_lit = lit(float("nan"))
self.assertIsNotNone(na_lit.to_plan(None))
inf_lit = lit(float("inf"))
self.assertIsNotNone(inf_lit.to_plan(None))
inf_lit = lit(float("-inf"))
self.assertIsNotNone(inf_lit.to_plan(None))
def test_datetime_literal_types(self):
"""Test the different timestamp, date, and timedelta types."""
datetime_lit = lit(datetime.datetime.now())
p = datetime_lit.to_plan(None)
self.assertIsNotNone(datetime_lit.to_plan(None))
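        # Literal timestamps are encoded as microseconds since the Unix epoch,
        # so any current datetime is far above 10**13.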
self.assertGreater(p.literal.timestamp, 10000000000000)
date_lit = lit(datetime.date.today())
time_delta = lit(datetime.timedelta(days=1, seconds=2, microseconds=3))
self.assertIsNotNone(date_lit.to_plan(None))
self.assertIsNotNone(time_delta.to_plan(None))
# (24 * 3600 + 2) * 1000000 + 3
self.assertEqual(86402000003, time_delta.to_plan(None).literal.day_time_interval)
def test_list_to_literal(self):
"""Test conversion of lists to literals"""
empty_list = []
single_type = [1, 2, 3, 4]
multi_type = ["ooo", 1, "asas", 2.3]
empty_list_lit = lit(empty_list)
single_type_lit = lit(single_type)
multi_type_lit = lit(multi_type)
p = empty_list_lit.to_plan(None)
self.assertIsNotNone(p)
p = single_type_lit.to_plan(None)
self.assertIsNotNone(p)
p = multi_type_lit.to_plan(None)
self.assertIsNotNone(p)
def test_column_alias(self) -> None:
# SPARK-40809: Support for Column Aliases
col0 = col("a").alias("martin")
self.assertEqual("Column<'a AS martin'>", str(col0))
col0 = col("a").alias("martin", metadata={"pii": True})
plan = col0.to_plan(self.session.client)
self.assertIsNotNone(plan)
self.assertEqual(plan.alias.metadata, '{"pii": true}')
def test_column_expressions(self):
"""Test a more complex combination of expressions and their translation into
the protobuf structure."""
df = self.connect.with_plan(Read("table"))
expr = lit(10) < lit(10)
expr_plan = expr.to_plan(None)
self.assertIsNotNone(expr_plan.unresolved_function)
self.assertEqual(expr_plan.unresolved_function.function_name, "<")
expr = df.id % lit(10) == lit(10)
expr_plan = expr.to_plan(None)
self.assertIsNotNone(expr_plan.unresolved_function)
self.assertEqual(expr_plan.unresolved_function.function_name, "==")
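        # "%" binds tighter than "==", so the root function is "==" with the
        # "%" expression as its first argument.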
lit_fun = expr_plan.unresolved_function.arguments[1]
self.assertIsInstance(lit_fun, proto.Expression)
self.assertIsInstance(lit_fun.literal, proto.Expression.Literal)
self.assertEqual(lit_fun.literal.integer, 10)
mod_fun = expr_plan.unresolved_function.arguments[0]
self.assertIsInstance(mod_fun, proto.Expression)
self.assertIsInstance(mod_fun.unresolved_function, proto.Expression.UnresolvedFunction)
self.assertEqual(len(mod_fun.unresolved_function.arguments), 2)
self.assertIsInstance(mod_fun.unresolved_function.arguments[0], proto.Expression)
self.assertIsInstance(
mod_fun.unresolved_function.arguments[0].unresolved_attribute,
proto.Expression.UnresolvedAttribute,
)
self.assertEqual(
mod_fun.unresolved_function.arguments[0].unresolved_attribute.unparsed_identifier, "id"
)
def test_literal_expression_with_arrays(self):
l0 = LiteralExpression._from_value(["x", "y", "z"]).to_plan(None).literal
self.assertTrue(l0.array.element_type.HasField("string"))
self.assertEqual(len(l0.array.elements), 3)
self.assertEqual(l0.array.elements[0].string, "x")
self.assertEqual(l0.array.elements[1].string, "y")
self.assertEqual(l0.array.elements[2].string, "z")
l1 = LiteralExpression._from_value([3, -3]).to_plan(None).literal
self.assertTrue(l1.array.element_type.HasField("integer"))
self.assertEqual(len(l1.array.elements), 2)
self.assertEqual(l1.array.elements[0].integer, 3)
self.assertEqual(l1.array.elements[1].integer, -3)
l2 = LiteralExpression._from_value([float("nan"), -3.0, 0.0]).to_plan(None).literal
self.assertTrue(l2.array.element_type.HasField("double"))
self.assertEqual(len(l2.array.elements), 3)
self.assertTrue(math.isnan(l2.array.elements[0].double))
self.assertEqual(l2.array.elements[1].double, -3.0)
self.assertEqual(l2.array.elements[2].double, 0.0)
l3 = LiteralExpression._from_value([[3, 4], [5, 6, 7]]).to_plan(None).literal
self.assertTrue(l3.array.element_type.HasField("array"))
self.assertTrue(l3.array.element_type.array.element_type.HasField("integer"))
self.assertEqual(len(l3.array.elements), 2)
self.assertEqual(len(l3.array.elements[0].array.elements), 2)
self.assertEqual(len(l3.array.elements[1].array.elements), 3)
l4 = (
LiteralExpression._from_value([[float("inf"), 0.4], [0.5, float("nan")], []])
.to_plan(None)
.literal
)
self.assertTrue(l4.array.element_type.HasField("array"))
self.assertTrue(l4.array.element_type.array.element_type.HasField("double"))
self.assertEqual(len(l4.array.elements), 3)
self.assertEqual(len(l4.array.elements[0].array.elements), 2)
self.assertEqual(len(l4.array.elements[1].array.elements), 2)
self.assertEqual(len(l4.array.elements[2].array.elements), 0)
def test_literal_to_any_conversion(self):
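        # Round trip: Python value -> proto literal -> Python value must be
        # lossless for every supported literal type.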
for value in [
b"binary\0\0asas",
True,
False,
0,
12,
-1,
0.0,
1.234567,
decimal.Decimal(0.0),
decimal.Decimal(1.234567),
"sss",
datetime.date(2022, 12, 13),
datetime.datetime.now(),
datetime.timedelta(1, 2, 3),
[1, 2, 3, 4, 5, 6],
[-1.0, 2.0, 3.0],
["x", "y", "z"],
[[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]],
]:
lit = LiteralExpression._from_value(value)
proto_lit = lit.to_plan(None).literal
value2 = LiteralExpression._to_value(proto_lit)
self.assertEqual(value, value2)
with self.assertRaises(AssertionError):
lit = LiteralExpression._from_value(1.234567)
proto_lit = lit.to_plan(None).literal
LiteralExpression._to_value(proto_lit, StringType())
with self.assertRaises(AssertionError):
lit = LiteralExpression._from_value("1.234567")
proto_lit = lit.to_plan(None).literal
LiteralExpression._to_value(proto_lit, DoubleType())
with self.assertRaises(AssertionError):
# build a array<string> proto literal, but with incorrect elements
proto_lit = proto.Expression().literal
proto_lit.array.element_type.CopyFrom(pyspark_types_to_proto_types(StringType()))
proto_lit.array.elements.append(
LiteralExpression("string", StringType()).to_plan(None).literal
)
proto_lit.array.elements.append(
LiteralExpression(1.234, DoubleType()).to_plan(None).literal
)
            LiteralExpression._to_value(proto_lit, DoubleType())
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_plan import * # noqa: F401
try:
import xmlrunner # type: ignore
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)