#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# mypy: disable-error-code="operator"
from pyspark.resource import ResourceProfile
from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
from typing import (
Any,
List,
Optional,
Type,
Sequence,
Union,
cast,
TYPE_CHECKING,
Mapping,
Dict,
Tuple,
)
import functools
import json
import pickle
from threading import Lock
from inspect import signature, isclass
import pyarrow as pa
from pyspark.serializers import CloudPickleSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.sql.types import DataType
import pyspark.sql.connect.proto as proto
from pyspark.sql.column import Column
from pyspark.sql.connect.conversion import storage_level_to_proto
from pyspark.sql.connect.expressions import Expression
from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
from pyspark.errors import (
PySparkValueError,
PySparkPicklingError,
IllegalArgumentException,
)
if TYPE_CHECKING:
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.observation import Observation
class LogicalPlan:
_lock: Lock = Lock()
_nextPlanId: int = 0
INDENT = 2
def __init__(self, child: Optional["LogicalPlan"]) -> None:
self._child = child
self._plan_id = LogicalPlan._fresh_plan_id()
@staticmethod
def _fresh_plan_id() -> int:
plan_id: Optional[int] = None
with LogicalPlan._lock:
plan_id = LogicalPlan._nextPlanId
LogicalPlan._nextPlanId += 1
assert plan_id is not None
return plan_id
def _create_proto_relation(self) -> proto.Relation:
plan = proto.Relation()
plan.common.plan_id = self._plan_id
return plan
def plan(self, session: "SparkConnectClient") -> proto.Relation: # type: ignore[empty-body]
...
def command(self, session: "SparkConnectClient") -> proto.Command: # type: ignore[empty-body]
...
def _verify(self, session: "SparkConnectClient") -> bool:
"""This method is used to verify that the current logical plan
can be serialized to Proto and back and afterwards is identical."""
plan = proto.Plan()
plan.root.CopyFrom(self.plan(session))
serialized_plan = plan.SerializeToString()
test_plan = proto.Plan()
test_plan.ParseFromString(serialized_plan)
return test_plan == plan
def to_proto(self, session: "SparkConnectClient", debug: bool = False) -> proto.Plan:
"""
Generates connect proto plan based on this LogicalPlan.
Parameters
----------
session : :class:`SparkConnectClient`
a session that connects to a remote Spark cluster.
debug : bool
if enabled, the proto plan will be printed.
"""
plan = proto.Plan()
plan.root.CopyFrom(self.plan(session))
if debug:
print(plan)
return plan
@property
def observations(self) -> Dict[str, "Observation"]:
if self._child is None:
return {}
else:
return self._child.observations
def _parameters_to_print(self, parameters: Mapping[str, Any]) -> Mapping[str, Any]:
"""
Extracts the parameters that can be printed. It looks up the signature
of this :class:`LogicalPlan`'s constructor, and retrieves the variables
from this instance by the same name (or the name prefixed with `_`) as defined
in the constructor.
Parameters
----------
parameters : map
Parameter mapping from ``inspect.signature(...).parameters``
Returns
-------
dict
A dictionary mapping each printable parameter name to the corresponding value
found in this :class:`LogicalPlan`.
Notes
-----
Parameters typed as :class:`LogicalPlan` are filtered out and considered
non-printable.
Examples
--------
The example below returns a dictionary built from `self._start`, `self._end`,
`self._step` and `self._num_partitions`.
>>> rg = Range(0, 10, 1)
>>> rg._parameters_to_print(signature(rg.__class__.__init__).parameters)
{'start': 0, 'end': 10, 'step': 1, 'num_partitions': None}
If a child plan is defined, it is not considered a printable parameter.
>>> project = Project(rg, "value")
>>> project._parameters_to_print(signature(project.__class__.__init__).parameters)
{'columns': ['value']}
"""
params = {}
for name, tpe in parameters.items():
# A parameter annotated with the LogicalPlan class itself is not printed
is_logical_plan = isclass(tpe.annotation) and issubclass(tpe.annotation, LogicalPlan)
# Look up the string argument defined as a forward reference e.g., "LogicalPlan"
is_forwardref_logical_plan = getattr(tpe.annotation, "__forward_arg__", "").endswith(
"LogicalPlan"
)
# Wrapped LogicalPlan, e.g., Optional[LogicalPlan]
is_nested_logical_plan = any(
isclass(a) and issubclass(a, LogicalPlan)
for a in getattr(tpe.annotation, "__args__", ())
)
# Wrapped forward reference of LogicalPlan, e.g., Optional["LogicalPlan"].
is_nested_forwardref_logical_plan = any(
getattr(a, "__forward_arg__", "").endswith("LogicalPlan")
for a in getattr(tpe.annotation, "__args__", ())
)
if (
not is_logical_plan
and not is_forwardref_logical_plan
and not is_nested_logical_plan
and not is_nested_forwardref_logical_plan
):
# Searches self.name or self._name
try:
params[name] = getattr(self, name)
except AttributeError:
try:
params[name] = getattr(self, "_" + name)
except AttributeError:
pass # Simply ignore
return params
def print(self, indent: int = 0) -> str:
"""
Print the simple string representation of the current :class:`LogicalPlan`.
Parameters
----------
indent : int
The number of leading spaces for the output string.
Returns
-------
str
Simple string representation of this :class:`LogicalPlan`.
"""
params = self._parameters_to_print(signature(self.__class__.__init__).parameters)
pretty_params = [f"{name}='{param}'" for name, param in params.items()]
if len(pretty_params) == 0:
pretty_str = ""
else:
pretty_str = " " + ", ".join(pretty_params)
return f"{' ' * indent}<{self.__class__.__name__}{pretty_str}>\n{self._child_print(indent)}"
def _repr_html_(self) -> str:
"""Returns a :class:`LogicalPlan` with HTML code. This is generally called in third-party
systems such as Jupyter.
Returns
-------
str
HTML representation of this :class:`LogicalPlan`.
"""
params = self._parameters_to_print(signature(self.__class__.__init__).parameters)
pretty_params = [
f"\n {name}: " f"{param} <br/>" for name, param in params.items()
]
if len(pretty_params) == 0:
pretty_str = ""
else:
pretty_str = "".join(pretty_params)
return f"""
<ul>
<li>
<b>{self.__class__.__name__}</b><br/>{pretty_str}
{self._child_repr()}
</li>
</ul>
"""
def _child_print(self, indent: int) -> str:
return self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
def _child_repr(self) -> str:
return self._child._repr_html_() if self._child is not None else ""
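# A minimal usage sketch (illustrative only, not used by the module): leaf plans
# such as ``Range`` take no child, while most other plans wrap their input via
# the ``child`` argument. ``print``/``_repr_html_`` render the tree locally,
# whereas ``plan``/``to_proto`` need a ``SparkConnectClient``.
#
#   >>> rg = Range(0, 10, 1)
#   >>> limited = Limit(rg, 5)
#   >>> print(limited.print())
#   <Limit limit='5'>
#     <Range start='0', end='10', step='1', num_partitions='None'>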
class DataSource(LogicalPlan):
"""A datasource with a format and optional a schema from which Spark reads data"""
def __init__(
self,
format: Optional[str] = None,
schema: Optional[str] = None,
options: Optional[Mapping[str, str]] = None,
paths: Optional[List[str]] = None,
predicates: Optional[List[str]] = None,
is_streaming: Optional[bool] = None,
) -> None:
super().__init__(None)
assert format is None or isinstance(format, str)
assert schema is None or isinstance(schema, str)
if options is not None:
for k, v in options.items():
assert isinstance(k, str)
assert isinstance(v, str)
if paths is not None:
assert isinstance(paths, list)
assert all(isinstance(path, str) for path in paths)
if predicates is not None:
assert isinstance(predicates, list)
assert all(isinstance(predicate, str) for predicate in predicates)
self._format = format
self._schema = schema
self._options = options
self._paths = paths
self._predicates = predicates
self._is_streaming = is_streaming
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
if self._format is not None:
plan.read.data_source.format = self._format
if self._schema is not None:
plan.read.data_source.schema = self._schema
if self._options is not None and len(self._options) > 0:
for k, v in self._options.items():
plan.read.data_source.options[k] = v
if self._paths is not None and len(self._paths) > 0:
plan.read.data_source.paths.extend(self._paths)
if self._predicates is not None and len(self._predicates) > 0:
plan.read.data_source.predicates.extend(self._predicates)
if self._is_streaming is not None:
plan.read.is_streaming = self._is_streaming
return plan
class Read(LogicalPlan):
def __init__(
self,
table_name: str,
options: Optional[Dict[str, str]] = None,
is_streaming: Optional[bool] = None,
) -> None:
super().__init__(None)
self.table_name = table_name
self.options = options or {}
self._is_streaming = is_streaming
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.read.named_table.unparsed_identifier = self.table_name
if self._is_streaming is not None:
plan.read.is_streaming = self._is_streaming
for k, v in self.options.items():
plan.read.named_table.options[k] = v
return plan
def print(self, indent: int = 0) -> str:
return f"{' ' * indent}<Read table_name={self.table_name}>\n"
class LocalRelation(LogicalPlan):
"""Creates a LocalRelation plan object based on a PyArrow Table."""
def __init__(
self,
table: Optional["pa.Table"],
schema: Optional[str] = None,
) -> None:
super().__init__(None)
if table is None:
assert schema is not None
else:
assert isinstance(table, pa.Table)
assert schema is None or isinstance(schema, str)
self._table = table
self._schema = schema
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
if self._table is not None:
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, self._table.schema) as writer:
for b in self._table.to_batches():
writer.write_batch(b)
plan.local_relation.data = sink.getvalue().to_pybytes()
if self._schema is not None:
plan.local_relation.schema = self._schema
return plan
def serialize(self, session: "SparkConnectClient") -> bytes:
p = self.plan(session)
return bytes(p.local_relation.SerializeToString())
def print(self, indent: int = 0) -> str:
return f"{' ' * indent}<LocalRelation>\n"
def _repr_html_(self) -> str:
return """
<ul>
<li><b>LocalRelation</b></li>
</ul>
"""
class CachedLocalRelation(LogicalPlan):
"""Creates a CachedLocalRelation plan object based on a hash of a LocalRelation."""
def __init__(self, hash: str) -> None:
super().__init__(None)
self._hash = hash
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
clr = plan.cached_local_relation
clr.hash = self._hash
return plan
def print(self, indent: int = 0) -> str:
return f"{' ' * indent}<CachedLocalRelation>\n"
def _repr_html_(self) -> str:
return """
<ul>
<li><b>CachedLocalRelation</b></li>
</ul>
"""
class ShowString(LogicalPlan):
def __init__(
self, child: Optional["LogicalPlan"], num_rows: int, truncate: int, vertical: bool
) -> None:
super().__init__(child)
self.num_rows = num_rows
self.truncate = truncate
self.vertical = vertical
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.show_string.input.CopyFrom(self._child.plan(session))
plan.show_string.num_rows = self.num_rows
plan.show_string.truncate = self.truncate
plan.show_string.vertical = self.vertical
return plan
class HtmlString(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], num_rows: int, truncate: int) -> None:
super().__init__(child)
self.num_rows = num_rows
self.truncate = truncate
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.html_string.input.CopyFrom(self._child.plan(session))
plan.html_string.num_rows = self.num_rows
plan.html_string.truncate = self.truncate
return plan
class Project(LogicalPlan):
"""Logical plan object for a projection.
All input arguments are directly serialized into the corresponding protocol buffer
objects. This class only provides very limited error handling and input validation.
To be compatible with PySpark, we only validate that the input arguments are all
expressions, so that they can be serialized and sent to the server.
"""
def __init__(
self,
child: Optional["LogicalPlan"],
columns: List[Column],
) -> None:
super().__init__(child)
assert all(isinstance(c, Column) for c in columns)
self._columns = columns
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.project.input.CopyFrom(self._child.plan(session))
plan.project.expressions.extend([c.to_plan(session) for c in self._columns])
return plan
class WithColumns(LogicalPlan):
"""Logical plan object for a withColumns operation."""
def __init__(
self,
child: Optional["LogicalPlan"],
columnNames: Sequence[str],
columns: Sequence[Column],
metadata: Optional[Sequence[str]] = None,
) -> None:
super().__init__(child)
assert isinstance(columnNames, list)
assert len(columnNames) > 0
assert all(isinstance(c, str) for c in columnNames)
assert isinstance(columns, list)
assert len(columns) == len(columnNames)
assert all(isinstance(c, Column) for c in columns)
if metadata is not None:
assert isinstance(metadata, list)
assert len(metadata) == len(columnNames)
for m in metadata:
assert isinstance(m, str)
# validate json string
assert m == "" or json.loads(m) is not None
self._columnNames = columnNames
self._columns = columns
self._metadata = metadata
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.with_columns.input.CopyFrom(self._child.plan(session))
for i in range(0, len(self._columnNames)):
alias = proto.Expression.Alias()
alias.expr.CopyFrom(self._columns[i].to_plan(session))
alias.name.append(self._columnNames[i])
if self._metadata is not None:
alias.metadata = self._metadata[i]
plan.with_columns.aliases.append(alias)
return plan
class WithWatermark(LogicalPlan):
"""Logical plan object for a WithWatermark operation."""
def __init__(self, child: Optional["LogicalPlan"], event_time: str, delay_threshold: str):
super().__init__(child)
self._event_time = event_time
self._delay_threshold = delay_threshold
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.with_watermark.input.CopyFrom(self._child.plan(session))
plan.with_watermark.event_time = self._event_time
plan.with_watermark.delay_threshold = self._delay_threshold
return plan
class CachedRemoteRelation(LogicalPlan):
"""Logical plan object for a DataFrame reference which represents a DataFrame that's been
cached on the server with a given id."""
def __init__(self, relationId: str):
super().__init__(None)
self._relationId = relationId
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.cached_remote_relation.relation_id = self._relationId
return plan
class Hint(LogicalPlan):
"""Logical plan object for a Hint operation."""
def __init__(
self,
child: Optional["LogicalPlan"],
name: str,
parameters: Sequence[Column],
) -> None:
super().__init__(child)
assert isinstance(name, str)
self._name = name
assert parameters is not None and isinstance(parameters, List)
for param in parameters:
assert isinstance(param, Column)
self._parameters = parameters
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.hint.input.CopyFrom(self._child.plan(session))
plan.hint.name = self._name
plan.hint.parameters.extend([param.to_plan(session) for param in self._parameters])
return plan
class Filter(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], filter: Column) -> None:
super().__init__(child)
self.filter = filter
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.filter.input.CopyFrom(self._child.plan(session))
plan.filter.condition.CopyFrom(self.filter.to_plan(session))
return plan
class Limit(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None:
super().__init__(child)
self.limit = limit
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.limit.input.CopyFrom(self._child.plan(session))
plan.limit.limit = self.limit
return plan
class Tail(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None:
super().__init__(child)
self.limit = limit
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.tail.input.CopyFrom(self._child.plan(session))
plan.tail.limit = self.limit
return plan
class Offset(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], offset: int = 0) -> None:
super().__init__(child)
self.offset = offset
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.offset.input.CopyFrom(self._child.plan(session))
plan.offset.offset = self.offset
return plan
class Deduplicate(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
all_columns_as_keys: bool = False,
column_names: Optional[List[str]] = None,
within_watermark: bool = False,
) -> None:
super().__init__(child)
self.all_columns_as_keys = all_columns_as_keys
self.column_names = column_names
self.within_watermark = within_watermark
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.deduplicate.input.CopyFrom(self._child.plan(session))
plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys
plan.deduplicate.within_watermark = self.within_watermark
if self.column_names is not None:
plan.deduplicate.column_names.extend(self.column_names)
return plan
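# Illustrative sketch ("t" is a placeholder table name): deduplicate either by a
# subset of columns or by all columns.
#
#   >>> by_id = Deduplicate(Read("t"), column_names=["id"])
#   >>> by_all = Deduplicate(Read("t"), all_columns_as_keys=True)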
class Sort(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
columns: List[Column],
is_global: bool,
) -> None:
super().__init__(child)
assert all(isinstance(c, Column) for c in columns)
assert isinstance(is_global, bool)
self.columns = columns
self.is_global = is_global
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.sort.input.CopyFrom(self._child.plan(session))
plan.sort.order.extend([c.to_plan(session).sort_order for c in self.columns])
plan.sort.is_global = self.is_global
return plan
class Drop(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
columns: List[Union[Column, str]],
) -> None:
super().__init__(child)
if len(columns) > 0:
assert all(isinstance(c, (Column, str)) for c in columns)
self._columns = columns
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.drop.input.CopyFrom(self._child.plan(session))
for c in self._columns:
if isinstance(c, Column):
plan.drop.columns.append(c.to_plan(session))
else:
plan.drop.column_names.append(c)
return plan
class Sample(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
lower_bound: float,
upper_bound: float,
with_replacement: bool,
seed: int,
deterministic_order: bool = False,
) -> None:
super().__init__(child)
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.with_replacement = with_replacement
self.seed = seed
self.deterministic_order = deterministic_order
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.sample.input.CopyFrom(self._child.plan(session))
plan.sample.lower_bound = self.lower_bound
plan.sample.upper_bound = self.upper_bound
plan.sample.with_replacement = self.with_replacement
plan.sample.seed = self.seed
plan.sample.deterministic_order = self.deterministic_order
return plan
class Aggregate(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
group_type: str,
grouping_cols: Sequence[Column],
aggregate_cols: Sequence[Column],
pivot_col: Optional[Column],
pivot_values: Optional[Sequence[Column]],
grouping_sets: Optional[Sequence[Sequence[Column]]],
) -> None:
super().__init__(child)
assert isinstance(group_type, str) and group_type in [
"groupby",
"rollup",
"cube",
"pivot",
"grouping_sets",
]
self._group_type = group_type
assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
self._grouping_cols = grouping_cols
assert isinstance(aggregate_cols, list) and all(
isinstance(c, Column) for c in aggregate_cols
)
self._aggregate_cols = aggregate_cols
if group_type == "pivot":
assert pivot_col is not None and isinstance(pivot_col, Column)
assert pivot_values is None or isinstance(pivot_values, list)
elif group_type == "grouping_sets":
assert grouping_sets is None or isinstance(grouping_sets, list)
else:
assert pivot_col is None
assert pivot_values is None
assert grouping_sets is None
self._pivot_col = pivot_col
self._pivot_values = pivot_values
self._grouping_sets = grouping_sets
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.aggregate.input.CopyFrom(self._child.plan(session))
plan.aggregate.grouping_expressions.extend(
[c.to_plan(session) for c in self._grouping_cols]
)
plan.aggregate.aggregate_expressions.extend(
[c.to_plan(session) for c in self._aggregate_cols]
)
if self._group_type == "groupby":
plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY
elif self._group_type == "rollup":
plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP
elif self._group_type == "cube":
plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_CUBE
elif self._group_type == "pivot":
plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_PIVOT
assert self._pivot_col is not None
plan.aggregate.pivot.col.CopyFrom(self._pivot_col.to_plan(session))
if self._pivot_values is not None and len(self._pivot_values) > 0:
plan.aggregate.pivot.values.extend(
[v.to_plan(session).literal for v in self._pivot_values]
)
elif self._group_type == "grouping_sets":
plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS
assert self._grouping_sets is not None
for grouping_set in self._grouping_sets:
plan.aggregate.grouping_sets.append(
proto.Aggregate.GroupingSets(
grouping_set=[c.to_plan(session) for c in grouping_set]
)
)
return plan
class Join(LogicalPlan):
def __init__(
self,
left: Optional["LogicalPlan"],
right: "LogicalPlan",
on: Optional[Union[str, List[str], Column, List[Column]]],
how: Optional[str],
) -> None:
super().__init__(left)
self.left = cast(LogicalPlan, left)
self.right = right
self.on = on
if how is None:
join_type = proto.Join.JoinType.JOIN_TYPE_INNER
elif how == "inner":
join_type = proto.Join.JoinType.JOIN_TYPE_INNER
elif how in ["outer", "full", "fullouter"]:
join_type = proto.Join.JoinType.JOIN_TYPE_FULL_OUTER
elif how in ["leftouter", "left"]:
join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER
elif how in ["rightouter", "right"]:
join_type = proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER
elif how in ["leftsemi", "semi"]:
join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI
elif how in ["leftanti", "anti"]:
join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI
elif how == "cross":
join_type = proto.Join.JoinType.JOIN_TYPE_CROSS
else:
raise IllegalArgumentException(
error_class="UNSUPPORTED_JOIN_TYPE",
message_parameters={"join_type": how},
)
self.how = join_type
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.join.left.CopyFrom(self.left.plan(session))
plan.join.right.CopyFrom(self.right.plan(session))
if self.on is not None:
if not isinstance(self.on, list):
if isinstance(self.on, str):
plan.join.using_columns.append(self.on)
else:
plan.join.join_condition.CopyFrom(self.on.to_plan(session))
elif len(self.on) > 0:
if isinstance(self.on[0], str):
plan.join.using_columns.extend(cast(List[str], self.on))
else:
merge_column = functools.reduce(lambda c1, c2: c1 & c2, self.on)
plan.join.join_condition.CopyFrom(cast(Column, merge_column).to_plan(session))
plan.join.join_type = self.how
return plan
@property
def observations(self) -> Dict[str, "Observation"]:
return dict(**super().observations, **self.right.observations)
def print(self, indent: int = 0) -> str:
i = " " * indent
o = " " * (indent + LogicalPlan.INDENT)
n = indent + LogicalPlan.INDENT * 2
return (
f"{i}<Join on={self.on} how={self.how}>\n{o}"
f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}"
)
def _repr_html_(self) -> str:
return f"""
<ul>
<li>
<b>Join</b><br />
Left: {self.left._repr_html_()}
Right: {self.right._repr_html_()}
</li>
</ul>
"""
class AsOfJoin(LogicalPlan):
def __init__(
self,
left: LogicalPlan,
right: LogicalPlan,
left_as_of: Column,
right_as_of: Column,
on: Optional[Union[str, List[str], Column, List[Column]]],
how: str,
tolerance: Optional[Column],
allow_exact_matches: bool,
direction: str,
) -> None:
super().__init__(left)
self.left = left
self.right = right
self.left_as_of = left_as_of
self.right_as_of = right_as_of
self.on = on
self.how = how
self.tolerance = tolerance
self.allow_exact_matches = allow_exact_matches
self.direction = direction
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.as_of_join.left.CopyFrom(self.left.plan(session))
plan.as_of_join.right.CopyFrom(self.right.plan(session))
plan.as_of_join.left_as_of.CopyFrom(self.left_as_of.to_plan(session))
plan.as_of_join.right_as_of.CopyFrom(self.right_as_of.to_plan(session))
if self.on is not None:
if not isinstance(self.on, list):
if isinstance(self.on, str):
plan.as_of_join.using_columns.append(self.on)
else:
plan.as_of_join.join_expr.CopyFrom(self.on.to_plan(session))
elif len(self.on) > 0:
if isinstance(self.on[0], str):
plan.as_of_join.using_columns.extend(cast(List[str], self.on))
else:
merge_column = functools.reduce(lambda c1, c2: c1 & c2, self.on)
plan.as_of_join.join_expr.CopyFrom(cast(Column, merge_column).to_plan(session))
plan.as_of_join.join_type = self.how
if self.tolerance is not None:
plan.as_of_join.tolerance.CopyFrom(self.tolerance.to_plan(session))
plan.as_of_join.allow_exact_matches = self.allow_exact_matches
plan.as_of_join.direction = self.direction
return plan
@property
def observations(self) -> Dict[str, "Observation"]:
return dict(**super().observations, **self.right.observations)
def print(self, indent: int = 0) -> str:
assert self.left is not None
assert self.right is not None
i = " " * indent
o = " " * (indent + LogicalPlan.INDENT)
n = indent + LogicalPlan.INDENT * 2
return (
f"{i}<AsOfJoin left_as_of={self.left_as_of}, right_as_of={self.right_as_of}, "
f"on={self.on} how={self.how}>\n{o}"
f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}"
)
def _repr_html_(self) -> str:
assert self.left is not None
assert self.right is not None
return f"""
<ul>
<li>
<b>AsOfJoin</b><br />
Left: {self.left._repr_html_()}
Right: {self.right._repr_html_()}
</li>
</ul>
"""
class SetOperation(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
other: Optional["LogicalPlan"],
set_op: str,
is_all: bool = True,
by_name: bool = False,
allow_missing_columns: bool = False,
) -> None:
super().__init__(child)
self.other = other
self.by_name = by_name
self.is_all = is_all
self.set_op = set_op
self.allow_missing_columns = allow_missing_columns
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
if self._child is not None:
plan.set_op.left_input.CopyFrom(self._child.plan(session))
if self.other is not None:
plan.set_op.right_input.CopyFrom(self.other.plan(session))
if self.set_op == "union":
plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_UNION
elif self.set_op == "intersect":
plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_INTERSECT
elif self.set_op == "except":
plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_EXCEPT
else:
raise PySparkValueError(
error_class="UNSUPPORTED_OPERATION",
message_parameters={"operation": self.set_op},
)
plan.set_op.is_all = self.is_all
plan.set_op.by_name = self.by_name
plan.set_op.allow_missing_columns = self.allow_missing_columns
return plan
@property
def observations(self) -> Dict[str, "Observation"]:
return dict(
**super().observations,
**(self.other.observations if self.other is not None else {}),
)
def print(self, indent: int = 0) -> str:
assert self._child is not None
assert self.other is not None
i = " " * indent
o = " " * (indent + LogicalPlan.INDENT)
n = indent + LogicalPlan.INDENT * 2
return (
f"{i}SetOperation\n{o}child1=\n{self._child.print(n)}"
f"\n{o}child2=\n{self.other.print(n)}"
)
def _repr_html_(self) -> str:
assert self._child is not None
assert self.other is not None
return f"""
<ul>
<li>
<b>SetOperation</b><br />
Left: {self._child._repr_html_()}
Right: {self.other._repr_html_()}
</li>
</ul>
"""
class Repartition(LogicalPlan):
"""Repartition Relation into a different number of partitions."""
def __init__(self, child: Optional["LogicalPlan"], num_partitions: int, shuffle: bool) -> None:
super().__init__(child)
self._num_partitions = num_partitions
self._shuffle = shuffle
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
if self._child is not None:
plan.repartition.input.CopyFrom(self._child.plan(session))
plan.repartition.shuffle = self._shuffle
plan.repartition.num_partitions = self._num_partitions
return plan
class RepartitionByExpression(LogicalPlan):
"""Repartition Relation into a different number of partitions using Expression"""
def __init__(
self,
child: Optional["LogicalPlan"],
num_partitions: Optional[int],
columns: List[Column],
) -> None:
super().__init__(child)
self.num_partitions = num_partitions
assert all(isinstance(c, Column) for c in columns)
self.columns = columns
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.repartition_by_expression.partition_exprs.extend(
[c.to_plan(session) for c in self.columns]
)
if self._child is not None:
plan.repartition_by_expression.input.CopyFrom(self._child.plan(session))
if self.num_partitions is not None:
plan.repartition_by_expression.num_partitions = self.num_partitions
return plan
class SubqueryAlias(LogicalPlan):
"""Alias for a relation."""
def __init__(self, child: Optional["LogicalPlan"], alias: str) -> None:
super().__init__(child)
self._alias = alias
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
if self._child is not None:
plan.subquery_alias.input.CopyFrom(self._child.plan(session))
plan.subquery_alias.alias = self._alias
return plan
class WithRelations(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
references: Sequence["LogicalPlan"],
) -> None:
super().__init__(child)
assert references is not None and len(references) > 0
assert all(isinstance(ref, LogicalPlan) for ref in references)
self._references = references
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
if self._child is not None:
plan.with_relations.root.CopyFrom(self._child.plan(session))
for ref in self._references:
plan.with_relations.references.append(ref.plan(session))
return plan
class SQL(LogicalPlan):
def __init__(
self,
query: str,
args: Optional[List[Column]] = None,
named_args: Optional[Dict[str, Column]] = None,
views: Optional[Sequence[SubqueryAlias]] = None,
) -> None:
super().__init__(None)
if args is not None:
assert isinstance(args, List)
assert all(isinstance(arg, Column) for arg in args)
if named_args is not None:
assert isinstance(named_args, Dict)
for k, arg in named_args.items():
assert isinstance(k, str)
assert isinstance(arg, Column)
if views is not None:
assert isinstance(views, List)
assert all(isinstance(v, SubqueryAlias) for v in views)
if len(views) > 0:
# reserved plan id for WithRelations
self._plan_id_with_rel = LogicalPlan._fresh_plan_id()
self._query = query
self._args = args
self._named_args = named_args
self._views = views
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.sql.query = self._query
if self._args is not None and len(self._args) > 0:
plan.sql.pos_arguments.extend([arg.to_plan(session) for arg in self._args])
if self._named_args is not None and len(self._named_args) > 0:
for k, arg in self._named_args.items():
plan.sql.named_arguments[k].CopyFrom(arg.to_plan(session))
if self._views is not None and len(self._views) > 0:
# build new plan like
# with_relations [id 10]
# root: sql [id 9]
# reference:
# view#1: [id 8]
# view#2: [id 5]
sql_plan = plan
plan = proto.Relation()
plan.common.plan_id = self._plan_id_with_rel
plan.with_relations.root.CopyFrom(sql_plan)
plan.with_relations.references.extend([v.plan(session) for v in self._views])
return plan
def command(self, session: "SparkConnectClient") -> proto.Command:
cmd = proto.Command()
cmd.sql_command.input.CopyFrom(self.plan(session))
return cmd
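# Illustrative sketch: a plain query produces a ``sql`` relation; when ``views``
# are passed (temporary views the query refers to), the relation is wrapped in a
# ``with_relations`` node using the plan id reserved in the constructor, as the
# comment inside ``plan`` above describes. "tmp_view" is a placeholder name.
#
#   >>> q = SQL("SELECT 1 AS one")
#   >>> v = SQL("SELECT * FROM tmp_view",
#   ...         views=[SubqueryAlias(Range(0, 10, 1), "tmp_view")])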
class Range(LogicalPlan):
def __init__(
self,
start: int,
end: int,
step: int,
num_partitions: Optional[int] = None,
) -> None:
super().__init__(None)
self._start = start
self._end = end
self._step = step
self._num_partitions = num_partitions
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.range.start = self._start
plan.range.end = self._end
plan.range.step = self._step
if self._num_partitions is not None:
plan.range.num_partitions = self._num_partitions
return plan
class ToSchema(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], schema: DataType) -> None:
super().__init__(child)
self._schema = schema
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.to_schema.input.CopyFrom(self._child.plan(session))
plan.to_schema.schema.CopyFrom(pyspark_types_to_proto_types(self._schema))
return plan
class WithColumnsRenamed(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], colsMap: Mapping[str, str]) -> None:
super().__init__(child)
self._colsMap = colsMap
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.with_columns_renamed.input.CopyFrom(self._child.plan(session))
if len(self._colsMap) > 0:
for k, v in self._colsMap.items():
rename = proto.WithColumnsRenamed.Rename()
rename.col_name = k
rename.new_col_name = v
plan.with_columns_renamed.renames.append(rename)
return plan
class Unpivot(LogicalPlan):
"""Logical plan object for a unpivot operation."""
def __init__(
self,
child: Optional["LogicalPlan"],
ids: List[Column],
values: Optional[List[Column]],
variable_column_name: str,
value_column_name: str,
) -> None:
super().__init__(child)
self.ids = ids
self.values = values
self.variable_column_name = variable_column_name
self.value_column_name = value_column_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.unpivot.input.CopyFrom(self._child.plan(session))
plan.unpivot.ids.extend([id.to_plan(session) for id in self.ids])
if self.values is not None:
plan.unpivot.values.values.extend([v.to_plan(session) for v in self.values])
plan.unpivot.variable_column_name = self.variable_column_name
plan.unpivot.value_column_name = self.value_column_name
return plan
class CollectMetrics(LogicalPlan):
"""Logical plan object for a CollectMetrics operation."""
def __init__(
self,
child: Optional["LogicalPlan"],
observation: Union[str, "Observation"],
exprs: List[Column],
) -> None:
super().__init__(child)
self._observation = observation
assert all(isinstance(e, Column) for e in exprs)
self._exprs = exprs
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.collect_metrics.input.CopyFrom(self._child.plan(session))
plan.collect_metrics.name = (
self._observation
if isinstance(self._observation, str)
else str(self._observation._name)
)
plan.collect_metrics.metrics.extend([e.to_plan(session) for e in self._exprs])
return plan
@property
def observations(self) -> Dict[str, "Observation"]:
from pyspark.sql.connect.observation import Observation
if isinstance(self._observation, Observation):
observations = {str(self._observation._name): self._observation}
else:
observations = {}
return dict(**super().observations, **observations)
class NAFill(LogicalPlan):
def __init__(
self, child: Optional["LogicalPlan"], cols: Optional[List[str]], values: List[Any]
) -> None:
super().__init__(child)
assert (
isinstance(values, list)
and len(values) > 0
and all(isinstance(v, (bool, int, float, str)) for v in values)
)
if cols is not None and len(cols) > 0:
assert isinstance(cols, list) and all(isinstance(c, str) for c in cols)
if len(values) > 1:
assert len(cols) == len(values)
self.cols = cols
self.values = values
def _convert_value(self, v: Any) -> proto.Expression.Literal:
value = proto.Expression.Literal()
if isinstance(v, bool):
value.boolean = v
elif isinstance(v, int):
value.long = v
elif isinstance(v, float):
value.double = v
else:
value.string = v
return value
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.fill_na.input.CopyFrom(self._child.plan(session))
if self.cols is not None and len(self.cols) > 0:
plan.fill_na.cols.extend(self.cols)
plan.fill_na.values.extend([self._convert_value(v) for v in self.values])
return plan
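# Illustrative sketch ("t"/"age" are placeholder names): fill values are converted
# to literals by ``_convert_value`` (bool -> boolean, int -> long, float -> double,
# everything else -> string).
#
#   >>> fill_zero = NAFill(Read("t"), cols=["age"], values=[0])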
class NADrop(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
cols: Optional[List[str]],
min_non_nulls: Optional[int],
) -> None:
super().__init__(child)
self.cols = cols
self.min_non_nulls = min_non_nulls
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.drop_na.input.CopyFrom(self._child.plan(session))
if self.cols is not None and len(self.cols) > 0:
plan.drop_na.cols.extend(self.cols)
if self.min_non_nulls is not None:
plan.drop_na.min_non_nulls = self.min_non_nulls
return plan
class NAReplace(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
cols: Optional[List[str]],
replacements: Sequence[Tuple[Column, Column]],
) -> None:
super().__init__(child)
self.cols = cols
assert replacements is not None and isinstance(replacements, List)
for k, v in replacements:
assert k is not None and isinstance(k, Column)
assert v is not None and isinstance(v, Column)
self.replacements = replacements
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.replace.input.CopyFrom(self._child.plan(session))
if self.cols is not None and len(self.cols) > 0:
plan.replace.cols.extend(self.cols)
if len(self.replacements) > 0:
for old_value, new_value in self.replacements:
replacement = proto.NAReplace.Replacement()
replacement.old_value.CopyFrom(old_value.to_plan(session).literal)
replacement.new_value.CopyFrom(new_value.to_plan(session).literal)
plan.replace.replacements.append(replacement)
return plan
class StatSummary(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], statistics: List[str]) -> None:
super().__init__(child)
self.statistics = statistics
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.summary.input.CopyFrom(self._child.plan(session))
plan.summary.statistics.extend(self.statistics)
return plan
class StatDescribe(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], cols: List[str]) -> None:
super().__init__(child)
self.cols = cols
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.describe.input.CopyFrom(self._child.plan(session))
plan.describe.cols.extend(self.cols)
return plan
class StatCov(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None:
super().__init__(child)
self._col1 = col1
self._col2 = col2
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.cov.input.CopyFrom(self._child.plan(session))
plan.cov.col1 = self._col1
plan.cov.col2 = self._col2
return plan
class StatApproxQuantile(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
cols: List[str],
probabilities: List[float],
relativeError: float,
) -> None:
super().__init__(child)
self._cols = cols
self._probabilities = probabilities
self._relativeError = relativeError
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.approx_quantile.input.CopyFrom(self._child.plan(session))
plan.approx_quantile.cols.extend(self._cols)
plan.approx_quantile.probabilities.extend(self._probabilities)
plan.approx_quantile.relative_error = self._relativeError
return plan
class StatCrosstab(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None:
super().__init__(child)
self.col1 = col1
self.col2 = col2
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.crosstab.input.CopyFrom(self._child.plan(session))
plan.crosstab.col1 = self.col1
plan.crosstab.col2 = self.col2
return plan
class StatFreqItems(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
cols: List[str],
support: float,
) -> None:
super().__init__(child)
self._cols = cols
self._support = support
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.freq_items.input.CopyFrom(self._child.plan(session))
plan.freq_items.cols.extend(self._cols)
plan.freq_items.support = self._support
return plan
class StatSampleBy(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
col: Column,
fractions: Sequence[Tuple[Column, float]],
seed: int,
) -> None:
super().__init__(child)
assert col is not None and isinstance(col, (Column, str))
assert fractions is not None and isinstance(fractions, List)
for k, v in fractions:
assert k is not None and isinstance(k, Column)
assert v is not None and isinstance(v, float)
assert seed is None or isinstance(seed, int)
self._col = col
self._fractions = fractions
self._seed = seed
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.sample_by.input.CopyFrom(self._child.plan(session))
plan.sample_by.col.CopyFrom(self._col._expr.to_plan(session))
if len(self._fractions) > 0:
for k, v in self._fractions:
fraction = proto.StatSampleBy.Fraction()
fraction.stratum.CopyFrom(k.to_plan(session).literal)
fraction.fraction = float(v)
plan.sample_by.fractions.append(fraction)
plan.sample_by.seed = self._seed
return plan
class StatCorr(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str, method: str) -> None:
super().__init__(child)
self._col1 = col1
self._col2 = col2
self._method = method
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.corr.input.CopyFrom(self._child.plan(session))
plan.corr.col1 = self._col1
plan.corr.col2 = self._col2
plan.corr.method = self._method
return plan
class ToDF(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], cols: Sequence[str]) -> None:
super().__init__(child)
self._cols = cols
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.to_df.input.CopyFrom(self._child.plan(session))
plan.to_df.column_names.extend(self._cols)
return plan
class CreateView(LogicalPlan):
def __init__(
self, child: Optional["LogicalPlan"], name: str, is_global: bool, replace: bool
) -> None:
super().__init__(child)
self._name = name
self._is_global = is_global
self._replace = replace
def command(self, session: "SparkConnectClient") -> proto.Command:
assert self._child is not None
plan = proto.Command()
plan.create_dataframe_view.replace = self._replace
plan.create_dataframe_view.is_global = self._is_global
plan.create_dataframe_view.name = self._name
plan.create_dataframe_view.input.CopyFrom(self._child.plan(session))
return plan
class WriteOperation(LogicalPlan):
def __init__(self, child: "LogicalPlan") -> None:
super(WriteOperation, self).__init__(child)
self.source: Optional[str] = None
self.path: Optional[str] = None
self.table_name: Optional[str] = None
self.table_save_method: Optional[str] = None
self.mode: Optional[str] = None
self.sort_cols: List[str] = []
self.partitioning_cols: List[str] = []
self.options: Dict[str, Optional[str]] = {}
self.num_buckets: int = -1
self.bucket_cols: List[str] = []
def command(self, session: "SparkConnectClient") -> proto.Command:
assert self._child is not None
plan = proto.Command()
plan.write_operation.input.CopyFrom(self._child.plan(session))
if self.source is not None:
plan.write_operation.source = self.source
plan.write_operation.sort_column_names.extend(self.sort_cols)
plan.write_operation.partitioning_columns.extend(self.partitioning_cols)
if self.num_buckets > 0:
plan.write_operation.bucket_by.bucket_column_names.extend(self.bucket_cols)
plan.write_operation.bucket_by.num_buckets = self.num_buckets
for k in self.options:
if self.options[k] is None:
plan.write_operation.options.pop(k, None)
else:
plan.write_operation.options[k] = cast(str, self.options[k])
if self.table_name is not None:
plan.write_operation.table.table_name = self.table_name
if self.table_save_method is not None:
tsm = self.table_save_method.lower()
if tsm == "save_as_table":
plan.write_operation.table.save_method = (
proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE # noqa: E501
)
elif tsm == "insert_into":
plan.write_operation.table.save_method = (
proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO
)
else:
raise PySparkValueError(
error_class="UNSUPPORTED_OPERATION",
message_parameters={"operation": tsm},
)
elif self.path is not None:
plan.write_operation.path = self.path
if self.mode is not None:
wm = self.mode.lower()
if wm == "append":
plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_APPEND
elif wm == "overwrite":
plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE
elif wm == "error":
plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS
elif wm == "ignore":
plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE
else:
raise PySparkValueError(
error_class="UNSUPPORTED_OPERATION",
message_parameters={"operation": self.mode},
)
return plan
def print(self, indent: int = 0) -> str:
i = " " * indent
return (
f"{i}"
f"<WriteOperation source='{self.source}' "
f"path='{self.path} "
f"table_name='{self.table_name}' "
f"table_save_method='{self.table_save_method}' "
f"mode='{self.mode}' "
f"sort_cols='{self.sort_cols}' "
f"partitioning_cols='{self.partitioning_cols}' "
f"num_buckets='{self.num_buckets}' "
f"bucket_cols='{self.bucket_cols}' "
f"options='{self.options}'>"
)
def _repr_html_(self) -> str:
return (
f"<uL><li>WriteOperation <br />source='{self.source}'<br />"
f"path: '{self.path}<br />"
f"table_name: '{self.table_name}' <br />"
f"table_save_method: '{self.table_save_method}' <br />"
f"mode: '{self.mode}' <br />"
f"sort_cols: '{self.sort_cols}' <br />"
f"partitioning_cols: '{self.partitioning_cols}' <br />"
f"num_buckets: '{self.num_buckets}' <br />"
f"bucket_cols: '{self.bucket_cols}' <br />"
f"options: '{self.options}'<br />"
f"</li></ul>"
)
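# Illustrative sketch (path is a placeholder): ``WriteOperation`` is configured by
# mutating its attributes after construction, which is how the DataFrame writer
# typically fills it in; ``command`` then validates ``mode`` and
# ``table_save_method`` and builds the proto command.
#
#   >>> wo = WriteOperation(Range(0, 10, 1))
#   >>> wo.source, wo.path, wo.mode = "parquet", "/tmp/out", "overwrite"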
class WriteOperationV2(LogicalPlan):
def __init__(self, child: "LogicalPlan", table_name: str) -> None:
super(WriteOperationV2, self).__init__(child)
self.table_name: Optional[str] = table_name
self.provider: Optional[str] = None
self.partitioning_columns: List[Column] = []
self.options: Dict[str, Optional[str]] = {}
self.table_properties: Dict[str, Optional[str]] = {}
self.mode: Optional[str] = None
self.overwrite_condition: Optional[Column] = None
def command(self, session: "SparkConnectClient") -> proto.Command:
assert self._child is not None
plan = proto.Command()
plan.write_operation_v2.input.CopyFrom(self._child.plan(session))
if self.table_name is not None:
plan.write_operation_v2.table_name = self.table_name
if self.provider is not None:
plan.write_operation_v2.provider = self.provider
plan.write_operation_v2.partitioning_columns.extend(
[c.to_plan(session) for c in self.partitioning_columns]
)
for k in self.options:
if self.options[k] is None:
plan.write_operation_v2.options.pop(k, None)
else:
plan.write_operation_v2.options[k] = cast(str, self.options[k])
for k in self.table_properties:
if self.table_properties[k] is None:
plan.write_operation_v2.table_properties.pop(k, None)
else:
plan.write_operation_v2.table_properties[k] = cast(str, self.table_properties[k])
if self.mode is not None:
wm = self.mode.lower()
if wm == "create":
plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE
elif wm == "overwrite":
plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_OVERWRITE
if self.overwrite_condition is not None:
plan.write_operation_v2.overwrite_condition.CopyFrom(
self.overwrite_condition.to_plan(session)
)
elif wm == "overwrite_partitions":
plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS
elif wm == "append":
plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_APPEND
elif wm == "replace":
plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_REPLACE
elif wm == "create_or_replace":
plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE
else:
raise PySparkValueError(
error_class="UNSUPPORTED_OPERATION",
message_parameters={"operation": self.mode},
)
return plan
class WriteStreamOperation(LogicalPlan):
def __init__(self, child: "LogicalPlan") -> None:
super(WriteStreamOperation, self).__init__(child)
self.write_op = proto.WriteStreamOperationStart()
def command(self, session: "SparkConnectClient") -> proto.Command:
assert self._child is not None
self.write_op.input.CopyFrom(self._child.plan(session))
cmd = proto.Command()
cmd.write_stream_operation_start.CopyFrom(self.write_op)
return cmd
# Catalog API (internal-only)
class CurrentDatabase(LogicalPlan):
def __init__(self) -> None:
super().__init__(None)
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.current_database.SetInParent()
return plan
class SetCurrentDatabase(LogicalPlan):
def __init__(self, db_name: str) -> None:
super().__init__(None)
self._db_name = db_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.set_current_database.db_name = self._db_name
return plan
class ListDatabases(LogicalPlan):
def __init__(self, pattern: Optional[str] = None) -> None:
super().__init__(None)
self._pattern = pattern
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.list_databases.SetInParent()
if self._pattern is not None:
plan.catalog.list_databases.pattern = self._pattern
return plan
class ListTables(LogicalPlan):
def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None) -> None:
super().__init__(None)
self._db_name = db_name
self._pattern = pattern
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.list_tables.SetInParent()
if self._db_name is not None:
plan.catalog.list_tables.db_name = self._db_name
if self._pattern is not None:
plan.catalog.list_tables.pattern = self._pattern
return plan
class ListFunctions(LogicalPlan):
def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None) -> None:
super().__init__(None)
self._db_name = db_name
self._pattern = pattern
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.list_functions.SetInParent()
if self._db_name is not None:
plan.catalog.list_functions.db_name = self._db_name
if self._pattern is not None:
plan.catalog.list_functions.pattern = self._pattern
return plan
class ListColumns(LogicalPlan):
def __init__(self, table_name: str, db_name: Optional[str] = None) -> None:
super().__init__(None)
self._table_name = table_name
self._db_name = db_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.list_columns.table_name = self._table_name
if self._db_name is not None:
plan.catalog.list_columns.db_name = self._db_name
return plan
class GetDatabase(LogicalPlan):
def __init__(self, db_name: str) -> None:
super().__init__(None)
self._db_name = db_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.get_database.db_name = self._db_name
return plan
class GetTable(LogicalPlan):
def __init__(self, table_name: str, db_name: Optional[str] = None) -> None:
super().__init__(None)
self._table_name = table_name
self._db_name = db_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.get_table.table_name = self._table_name
if self._db_name is not None:
plan.catalog.get_table.db_name = self._db_name
return plan
class GetFunction(LogicalPlan):
def __init__(self, function_name: str, db_name: Optional[str] = None) -> None:
super().__init__(None)
self._function_name = function_name
self._db_name = db_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.get_function.function_name = self._function_name
if self._db_name is not None:
plan.catalog.get_function.db_name = self._db_name
return plan
class DatabaseExists(LogicalPlan):
def __init__(self, db_name: str) -> None:
super().__init__(None)
self._db_name = db_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.database_exists.db_name = self._db_name
return plan
class TableExists(LogicalPlan):
def __init__(self, table_name: str, db_name: Optional[str] = None) -> None:
super().__init__(None)
self._table_name = table_name
self._db_name = db_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.table_exists.table_name = self._table_name
if self._db_name is not None:
plan.catalog.table_exists.db_name = self._db_name
return plan
class FunctionExists(LogicalPlan):
def __init__(self, function_name: str, db_name: Optional[str] = None) -> None:
super().__init__(None)
self._function_name = function_name
self._db_name = db_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.function_exists.function_name = self._function_name
if self._db_name is not None:
plan.catalog.function_exists.db_name = self._db_name
return plan
class CreateTable(LogicalPlan):
def __init__(
self,
table_name: str,
path: str,
source: Optional[str] = None,
description: Optional[str] = None,
schema: Optional[DataType] = None,
options: Mapping[str, str] = {},
) -> None:
super().__init__(None)
self._table_name = table_name
self._path = path
self._source = source
self._description = description
self._schema = schema
self._options = options
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.create_table.table_name = self._table_name
if self._path is not None:
plan.catalog.create_table.path = self._path
if self._source is not None:
plan.catalog.create_table.source = self._source
if self._description is not None:
plan.catalog.create_table.description = self._description
if self._schema is not None:
plan.catalog.create_table.schema.CopyFrom(pyspark_types_to_proto_types(self._schema))
        for k, v in self._options.items():
            if v is not None:
                plan.catalog.create_table.options[k] = v
return plan
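# Example (illustrative sketch): a catalog call along the lines of
#
#     spark.catalog.createTable(
#         "people", path="/tmp/people", source="parquet", description="demo table"
#     )
#
# is expected to build a CreateTable plan like the one above; `spark` is assumed to
# be an active Spark Connect session, and the table name and path are placeholders.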
class DropTempView(LogicalPlan):
def __init__(self, view_name: str) -> None:
super().__init__(None)
self._view_name = view_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.drop_temp_view.view_name = self._view_name
return plan
class DropGlobalTempView(LogicalPlan):
def __init__(self, view_name: str) -> None:
super().__init__(None)
self._view_name = view_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.drop_global_temp_view.view_name = self._view_name
return plan
class RecoverPartitions(LogicalPlan):
def __init__(self, table_name: str) -> None:
super().__init__(None)
self._table_name = table_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.recover_partitions.table_name = self._table_name
return plan
class IsCached(LogicalPlan):
def __init__(self, table_name: str) -> None:
super().__init__(None)
self._table_name = table_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.is_cached.table_name = self._table_name
return plan
class CacheTable(LogicalPlan):
def __init__(self, table_name: str, storage_level: Optional[StorageLevel] = None) -> None:
super().__init__(None)
self._table_name = table_name
self._storage_level = storage_level
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
_cache_table = proto.CacheTable(table_name=self._table_name)
if self._storage_level:
_cache_table.storage_level.CopyFrom(storage_level_to_proto(self._storage_level))
plan.catalog.cache_table.CopyFrom(_cache_table)
return plan
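# Example (illustrative sketch): caching through the catalog API, e.g.
#
#     from pyspark import StorageLevel
#     spark.catalog.cacheTable("people", storageLevel=StorageLevel.MEMORY_ONLY)
#
# should map to CacheTable("people", StorageLevel.MEMORY_ONLY), with the storage
# level serialized via storage_level_to_proto; `spark` is assumed to be an active
# Spark Connect session.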
class UncacheTable(LogicalPlan):
def __init__(self, table_name: str) -> None:
super().__init__(None)
self._table_name = table_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.uncache_table.table_name = self._table_name
return plan
class ClearCache(LogicalPlan):
def __init__(self) -> None:
super().__init__(None)
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.clear_cache.SetInParent()
return plan
class RefreshTable(LogicalPlan):
def __init__(self, table_name: str) -> None:
super().__init__(None)
self._table_name = table_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.refresh_table.table_name = self._table_name
return plan
class RefreshByPath(LogicalPlan):
def __init__(self, path: str) -> None:
super().__init__(None)
self._path = path
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.refresh_by_path.path = self._path
return plan
class CurrentCatalog(LogicalPlan):
def __init__(self) -> None:
super().__init__(None)
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.current_catalog.SetInParent()
return plan
class SetCurrentCatalog(LogicalPlan):
def __init__(self, catalog_name: str) -> None:
super().__init__(None)
self._catalog_name = catalog_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.set_current_catalog.catalog_name = self._catalog_name
return plan
class ListCatalogs(LogicalPlan):
def __init__(self, pattern: Optional[str] = None) -> None:
super().__init__(None)
self._pattern = pattern
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.catalog.list_catalogs.SetInParent()
if self._pattern is not None:
plan.catalog.list_catalogs.pattern = self._pattern
return plan
class MapPartitions(LogicalPlan):
"""Logical plan object for a mapPartitions-equivalent API: mapInPandas, mapInArrow."""
def __init__(
self,
child: Optional["LogicalPlan"],
function: "UserDefinedFunction",
cols: List[str],
is_barrier: bool,
profile: Optional[ResourceProfile],
) -> None:
super().__init__(child)
self._function = function._build_common_inline_user_defined_function(*cols)
self._is_barrier = is_barrier
self._profile = profile
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.map_partitions.input.CopyFrom(self._child.plan(session))
plan.map_partitions.func.CopyFrom(self._function.to_plan_udf(session))
plan.map_partitions.is_barrier = self._is_barrier
if self._profile is not None:
plan.map_partitions.profile_id = self._profile.id
return plan
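# Example (illustrative sketch): a DataFrame-level call such as
#
#     def double(batches):  # Iterator[pandas.DataFrame] -> Iterator[pandas.DataFrame]
#         for pdf in batches:
#             yield pdf.assign(value=pdf.value * 2)
#
#     df.mapInPandas(double, schema="value double", barrier=False)
#
# is expected to wrap `double` in a UserDefinedFunction and build a MapPartitions
# plan with is_barrier=False and no resource profile; `df` is assumed to be a
# Spark Connect DataFrame with a `value` column.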
class GroupMap(LogicalPlan):
"""Logical plan object for a Group Map API: apply, applyInPandas."""
def __init__(
self,
child: Optional["LogicalPlan"],
grouping_cols: Sequence[Column],
function: "UserDefinedFunction",
cols: List[str],
):
        assert isinstance(grouping_cols, list) and all(
            isinstance(c, Column) for c in grouping_cols
        )
super().__init__(child)
self._grouping_cols = grouping_cols
self._function = function._build_common_inline_user_defined_function(*cols)
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.group_map.input.CopyFrom(self._child.plan(session))
plan.group_map.grouping_expressions.extend(
[c.to_plan(session) for c in self._grouping_cols]
)
plan.group_map.func.CopyFrom(self._function.to_plan_udf(session))
return plan
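# Example (illustrative sketch): grouped-map execution, e.g.
#
#     def subtract_mean(pdf):  # pandas.DataFrame -> pandas.DataFrame
#         return pdf.assign(value=pdf.value - pdf.value.mean())
#
#     df.groupBy("id").applyInPandas(subtract_mean, schema="id long, value double")
#
# should produce a GroupMap plan whose grouping expressions come from the groupBy
# columns; `df` is assumed to have `id` and `value` columns.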
class CoGroupMap(LogicalPlan):
"""Logical plan object for a CoGroup Map API: applyInPandas."""
def __init__(
self,
input: Optional["LogicalPlan"],
input_grouping_cols: Sequence[Column],
other: Optional["LogicalPlan"],
other_grouping_cols: Sequence[Column],
function: "UserDefinedFunction",
):
assert isinstance(input_grouping_cols, list) and all(
isinstance(c, Column) for c in input_grouping_cols
)
assert isinstance(other_grouping_cols, list) and all(
isinstance(c, Column) for c in other_grouping_cols
)
super().__init__(input)
self._input_grouping_cols = input_grouping_cols
self._other_grouping_cols = other_grouping_cols
self._other = cast(LogicalPlan, other)
        # The function takes the entire cogrouped DataFrames as input, so no
        # column binding is needed (there are no input columns).
self._function = function._build_common_inline_user_defined_function()
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.co_group_map.input.CopyFrom(self._child.plan(session))
plan.co_group_map.input_grouping_expressions.extend(
[c.to_plan(session) for c in self._input_grouping_cols]
)
plan.co_group_map.other.CopyFrom(self._other.plan(session))
plan.co_group_map.other_grouping_expressions.extend(
[c.to_plan(session) for c in self._other_grouping_cols]
)
plan.co_group_map.func.CopyFrom(self._function.to_plan_udf(session))
return plan
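# Example (illustrative sketch): co-grouped execution, e.g.
#
#     def merge_frames(left, right):  # two pandas.DataFrames in, one pandas.DataFrame out
#         return pd.merge(left, right, on="id")
#
#     df1.groupBy("id").cogroup(df2.groupBy("id")).applyInPandas(
#         merge_frames, schema="id long, v1 double, v2 double"
#     )
#
# should yield a CoGroupMap plan with df1 as `input` and df2 as `other`; `df1`,
# `df2`, and `pd` (pandas) are assumed to be available in the calling session.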
class ApplyInPandasWithState(LogicalPlan):
"""Logical plan object for a applyInPandasWithState."""
def __init__(
self,
child: Optional["LogicalPlan"],
grouping_cols: Sequence[Column],
function: "UserDefinedFunction",
output_schema: str,
state_schema: str,
output_mode: str,
timeout_conf: str,
cols: List[str],
):
        assert isinstance(grouping_cols, list) and all(
            isinstance(c, Column) for c in grouping_cols
        )
super().__init__(child)
self._grouping_cols = grouping_cols
self._function = function._build_common_inline_user_defined_function(*cols)
self._output_schema = output_schema
self._state_schema = state_schema
self._output_mode = output_mode
self._timeout_conf = timeout_conf
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.apply_in_pandas_with_state.input.CopyFrom(self._child.plan(session))
plan.apply_in_pandas_with_state.grouping_expressions.extend(
[c.to_plan(session) for c in self._grouping_cols]
)
plan.apply_in_pandas_with_state.func.CopyFrom(self._function.to_plan_udf(session))
plan.apply_in_pandas_with_state.output_schema = self._output_schema
plan.apply_in_pandas_with_state.state_schema = self._state_schema
plan.apply_in_pandas_with_state.output_mode = self._output_mode
plan.apply_in_pandas_with_state.timeout_conf = self._timeout_conf
return plan
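# Example (illustrative sketch): stateful grouped-map streaming, e.g.
#
#     df.groupBy("id").applyInPandasWithState(
#         count_updates,
#         outputStructType="id long, count long",
#         stateStructType="count long",
#         outputMode="Update",
#         timeoutConf="NoTimeout",
#     )
#
# should carry the schema strings, output mode, and timeout configuration through
# to this plan verbatim; `df` and `count_updates` are placeholders for illustration.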
class PythonUDTF:
"""Represents a Python user-defined table function."""
def __init__(
self,
func: Type,
return_type: Optional[Union[DataType, str]],
eval_type: int,
python_ver: str,
) -> None:
self._func = func
self._name = func.__name__
        if return_type is None:
            self._return_type: Optional[DataType] = None
        elif isinstance(return_type, str):
            self._return_type = UnparsedDataType(return_type)
        else:
            self._return_type = return_type
self._eval_type = eval_type
self._python_ver = python_ver
def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF:
udtf = proto.PythonUDTF()
if self._return_type is not None:
udtf.return_type.CopyFrom(pyspark_types_to_proto_types(self._return_type))
udtf.eval_type = self._eval_type
try:
udtf.command = CloudPickleSerializer().dumps(self._func)
except pickle.PicklingError:
raise PySparkPicklingError(
error_class="UDTF_SERIALIZATION_ERROR",
message_parameters={
"name": self._name,
"message": "Please check the stack trace and "
"make sure the function is serializable.",
},
)
udtf.python_ver = self._python_ver
return udtf
def __repr__(self) -> str:
return (
f"PythonUDTF({self._name}, {self._return_type}, "
f"{self._eval_type}, {self._python_ver})"
)
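# Example (illustrative sketch): a UDTF declared with the `udtf` decorator, e.g.
#
#     from pyspark.sql.functions import udtf
#
#     @udtf(returnType="word: string")
#     class SplitWords:
#         def eval(self, text: str):
#             for word in text.split(" "):
#                 yield (word,)
#
# is expected to have its class cloudpickled into PythonUDTF.command before being
# attached to the CommonInlineUserDefinedTableFunction plan defined below.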
class CommonInlineUserDefinedTableFunction(LogicalPlan):
"""
    Logical plan object for a user-defined table function with
    an inline-defined function body.
"""
def __init__(
self,
function_name: str,
function: PythonUDTF,
deterministic: bool,
arguments: Sequence[Expression],
) -> None:
super().__init__(None)
self._function_name = function_name
self._deterministic = deterministic
self._arguments = arguments
self._function = function
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.common_inline_user_defined_table_function.function_name = self._function_name
plan.common_inline_user_defined_table_function.deterministic = self._deterministic
if len(self._arguments) > 0:
plan.common_inline_user_defined_table_function.arguments.extend(
[arg.to_plan(session) for arg in self._arguments]
)
plan.common_inline_user_defined_table_function.python_udtf.CopyFrom(
self._function.to_plan(session)
)
return plan
def udtf_plan(
self, session: "SparkConnectClient"
) -> "proto.CommonInlineUserDefinedTableFunction":
"""
Compared to `plan`, it returns a `proto.CommonInlineUserDefinedTableFunction`
instead of a `proto.Relation`.
"""
plan = proto.CommonInlineUserDefinedTableFunction()
plan.function_name = self._function_name
plan.deterministic = self._deterministic
if len(self._arguments) > 0:
plan.arguments.extend([arg.to_plan(session) for arg in self._arguments])
        plan.python_udtf.CopyFrom(self._function.to_plan(session))
return plan
def __repr__(self) -> str:
return f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])})"
class PythonDataSource:
"""Represents a user-defined Python data source."""
def __init__(self, data_source: Type, python_ver: str):
self._data_source = data_source
self._python_ver = python_ver
def to_plan(self, session: "SparkConnectClient") -> proto.PythonDataSource:
ds = proto.PythonDataSource()
ds.command = CloudPickleSerializer().dumps(self._data_source)
ds.python_ver = self._python_ver
return ds
class CommonInlineUserDefinedDataSource(LogicalPlan):
"""Logical plan object for a user-defined data source"""
def __init__(self, name: str, data_source: PythonDataSource) -> None:
super().__init__(None)
self._name = name
self._data_source = data_source
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.common_inline_user_defined_data_source.name = self._name
plan.common_inline_user_defined_data_source.python_data_source.CopyFrom(
self._data_source.to_plan(session)
)
return plan
def to_data_source_proto(
self, session: "SparkConnectClient"
) -> "proto.CommonInlineUserDefinedDataSource":
plan = proto.CommonInlineUserDefinedDataSource()
plan.name = self._name
plan.python_data_source.CopyFrom(self._data_source.to_plan(session))
return plan
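# Example (illustrative sketch): registering a Python data source, e.g.
#
#     from pyspark.sql.datasource import DataSource
#
#     class MyDataSource(DataSource):
#         @classmethod
#         def name(cls):
#             return "my_source"
#
#     spark.dataSource.register(MyDataSource)
#
# should cloudpickle the class into a PythonDataSource and wrap it in a
# CommonInlineUserDefinedDataSource; the stub class and `spark` session here are
# assumptions for illustration only.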
class CachedRelation(LogicalPlan):
def __init__(self, plan: proto.Relation) -> None:
        super().__init__(None)
self._plan = plan
# Update the plan ID based on the incremented counter.
self._plan.common.plan_id = self._plan_id
def plan(self, session: "SparkConnectClient") -> proto.Relation:
return self._plan