#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

from pyspark.sql.utils import is_timestamp_ntz_preferred
from typing import (
cast,
TYPE_CHECKING,
Any,
Callable,
Union,
Sequence,
Tuple,
Optional,
)
import json
import decimal
import datetime
import warnings
from threading import Lock
import numpy as np
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.types import (
_from_numpy_type,
DateType,
ArrayType,
NullType,
BooleanType,
BinaryType,
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType,
DecimalType,
StringType,
DataType,
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
)
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.types import (
JVM_BYTE_MIN,
JVM_BYTE_MAX,
JVM_SHORT_MIN,
JVM_SHORT_MAX,
JVM_INT_MIN,
JVM_INT_MAX,
JVM_LONG_MIN,
JVM_LONG_MAX,
UnparsedDataType,
pyspark_types_to_proto_types,
proto_schema_to_pyspark_data_type,
)
from pyspark.errors import PySparkTypeError, PySparkValueError
if TYPE_CHECKING:
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.window import WindowSpec
class Expression:
"""
Expression base class.
"""
def __init__(self) -> None:
pass
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
...
def __repr__(self) -> str:
...
def alias(self, *alias: str, **kwargs: Any) -> "ColumnAlias":
metadata = kwargs.pop("metadata", None)
        assert not kwargs, "Unexpected kwargs were passed: %s" % kwargs
return ColumnAlias(self, list(alias), metadata)
def name(self) -> str:
...
class CaseWhen(Expression):
def __init__(
self, branches: Sequence[Tuple[Expression, Expression]], else_value: Optional[Expression]
):
super().__init__()
assert isinstance(branches, list)
for branch in branches:
assert (
isinstance(branch, tuple)
and len(branch) == 2
and all(isinstance(expr, Expression) for expr in branch)
)
self._branches = branches
if else_value is not None:
assert isinstance(else_value, Expression)
self._else_value = else_value
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
args = []
for condition, value in self._branches:
args.append(condition)
args.append(value)
if self._else_value is not None:
args.append(self._else_value)
unresolved_function = UnresolvedFunction(name="when", args=args)
return unresolved_function.to_plan(session)
def __repr__(self) -> str:
_cases = "".join([f" WHEN {c} THEN {v}" for c, v in self._branches])
_else = f" ELSE {self._else_value}" if self._else_value is not None else ""
return "CASE" + _cases + _else + " END"
class ColumnAlias(Expression):
def __init__(self, parent: Expression, alias: Sequence[str], metadata: Any):
super().__init__()
self._alias = alias
self._metadata = metadata
self._parent = parent
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
if len(self._alias) == 1:
exp = proto.Expression()
exp.alias.name.append(self._alias[0])
exp.alias.expr.CopyFrom(self._parent.to_plan(session))
if self._metadata:
exp.alias.metadata = json.dumps(self._metadata)
return exp
else:
if self._metadata:
raise PySparkValueError(
error_class="CANNOT_PROVIDE_METADATA",
message_parameters={},
)
exp = proto.Expression()
exp.alias.name.extend(self._alias)
exp.alias.expr.CopyFrom(self._parent.to_plan(session))
return exp
def __repr__(self) -> str:
return f"{self._parent} AS {','.join(self._alias)}"
class LiteralExpression(Expression):
"""A literal expression.
The Python types are converted best effort into the relevant proto types. On the Spark Connect
server side, the proto types are converted to the Catalyst equivalents."""
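    # A minimal usage sketch (illustrative only, not part of the class): type inference
    # for plain Python values follows the rules in `_infer_type` below, e.g.
    #
    #   LiteralExpression._from_value(42)       # inferred as IntegerType
    #   LiteralExpression._from_value(3.14)     # inferred as DoubleType
    #   LiteralExpression._from_value("spark")  # inferred as StringType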
def __init__(self, value: Any, dataType: DataType) -> None:
super().__init__()
assert isinstance(
dataType,
(
NullType,
BinaryType,
BooleanType,
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType,
DecimalType,
StringType,
DateType,
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
ArrayType,
),
)
if isinstance(dataType, NullType):
assert value is None
if value is not None:
if isinstance(dataType, BinaryType):
assert isinstance(value, (bytes, bytearray))
elif isinstance(dataType, BooleanType):
assert isinstance(value, (bool, np.bool_))
value = bool(value)
elif isinstance(dataType, ByteType):
assert isinstance(value, (int, np.int8))
assert JVM_BYTE_MIN <= int(value) <= JVM_BYTE_MAX
value = int(value)
elif isinstance(dataType, ShortType):
assert isinstance(value, (int, np.int8, np.int16))
assert JVM_SHORT_MIN <= int(value) <= JVM_SHORT_MAX
value = int(value)
elif isinstance(dataType, IntegerType):
assert isinstance(value, (int, np.int8, np.int16, np.int32))
assert JVM_INT_MIN <= int(value) <= JVM_INT_MAX
value = int(value)
elif isinstance(dataType, LongType):
assert isinstance(value, (int, np.int8, np.int16, np.int32, np.int64))
assert JVM_LONG_MIN <= int(value) <= JVM_LONG_MAX
value = int(value)
elif isinstance(dataType, FloatType):
assert isinstance(value, (float, np.float32))
value = float(value)
elif isinstance(dataType, DoubleType):
assert isinstance(value, (float, np.float32, np.float64))
value = float(value)
elif isinstance(dataType, DecimalType):
assert isinstance(value, decimal.Decimal)
elif isinstance(dataType, StringType):
assert isinstance(value, str)
elif isinstance(dataType, DateType):
assert isinstance(value, (datetime.date, datetime.datetime))
if isinstance(value, datetime.date):
value = DateType().toInternal(value)
else:
value = DateType().toInternal(value.date())
elif isinstance(dataType, TimestampType):
assert isinstance(value, datetime.datetime)
value = TimestampType().toInternal(value)
elif isinstance(dataType, TimestampNTZType):
assert isinstance(value, datetime.datetime)
value = TimestampNTZType().toInternal(value)
elif isinstance(dataType, DayTimeIntervalType):
assert isinstance(value, datetime.timedelta)
value = DayTimeIntervalType().toInternal(value)
assert value is not None
elif isinstance(dataType, ArrayType):
assert isinstance(value, list)
else:
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE",
message_parameters={"data_type": str(dataType)},
)
self._value = value
self._dataType = dataType
@classmethod
def _infer_type(cls, value: Any) -> DataType:
if value is None:
return NullType()
elif isinstance(value, (bytes, bytearray)):
return BinaryType()
elif isinstance(value, bool):
return BooleanType()
elif isinstance(value, int):
if JVM_INT_MIN <= value <= JVM_INT_MAX:
return IntegerType()
elif JVM_LONG_MIN <= value <= JVM_LONG_MAX:
return LongType()
else:
raise PySparkValueError(
error_class="VALUE_NOT_BETWEEN",
message_parameters={
"arg_name": "value",
"min": str(JVM_LONG_MIN),
"max": str(JVM_SHORT_MAX),
},
)
elif isinstance(value, float):
return DoubleType()
elif isinstance(value, str):
return StringType()
elif isinstance(value, decimal.Decimal):
return DecimalType()
elif isinstance(value, datetime.datetime) and is_timestamp_ntz_preferred():
return TimestampNTZType()
elif isinstance(value, datetime.datetime):
return TimestampType()
elif isinstance(value, datetime.date):
return DateType()
elif isinstance(value, datetime.timedelta):
return DayTimeIntervalType()
elif isinstance(value, np.generic):
dt = _from_numpy_type(value.dtype)
if dt is not None:
return dt
elif isinstance(value, np.bool_):
return BooleanType()
elif isinstance(value, list):
            # follow the 'infer_array_from_first_element' strategy in 'sql.types._infer_type';
            # right now, it is dedicated to pyspark.ml params like array<...>, array<array<...>>
if len(value) == 0:
raise PySparkValueError(
error_class="CANNOT_BE_EMPTY",
message_parameters={"item": "value"},
)
first = value[0]
if first is None:
raise PySparkTypeError(
error_class="CANNOT_INFER_ARRAY_TYPE",
message_parameters={},
)
return ArrayType(LiteralExpression._infer_type(first), True)
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE",
message_parameters={"data_type": type(value).__name__},
)
@classmethod
def _from_value(cls, value: Any) -> "LiteralExpression":
return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value))
@classmethod
def _to_value(
cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
) -> Any:
if literal.HasField("null"):
return None
elif literal.HasField("binary"):
assert dataType is None or isinstance(dataType, BinaryType)
return literal.binary
elif literal.HasField("boolean"):
assert dataType is None or isinstance(dataType, BooleanType)
return literal.boolean
elif literal.HasField("byte"):
assert dataType is None or isinstance(dataType, ByteType)
return literal.byte
elif literal.HasField("short"):
assert dataType is None or isinstance(dataType, ShortType)
return literal.short
elif literal.HasField("integer"):
assert dataType is None or isinstance(dataType, IntegerType)
return literal.integer
elif literal.HasField("long"):
assert dataType is None or isinstance(dataType, LongType)
return literal.long
elif literal.HasField("float"):
assert dataType is None or isinstance(dataType, FloatType)
return literal.float
elif literal.HasField("double"):
assert dataType is None or isinstance(dataType, DoubleType)
return literal.double
elif literal.HasField("decimal"):
assert dataType is None or isinstance(dataType, DecimalType)
return decimal.Decimal(literal.decimal.value)
elif literal.HasField("string"):
assert dataType is None or isinstance(dataType, StringType)
return literal.string
elif literal.HasField("date"):
            assert dataType is None or isinstance(dataType, DateType)
return DateType().fromInternal(literal.date)
elif literal.HasField("timestamp"):
assert dataType is None or isinstance(dataType, TimestampType)
return TimestampType().fromInternal(literal.timestamp)
elif literal.HasField("timestamp_ntz"):
assert dataType is None or isinstance(dataType, TimestampNTZType)
return TimestampNTZType().fromInternal(literal.timestamp_ntz)
elif literal.HasField("day_time_interval"):
assert dataType is None or isinstance(dataType, DayTimeIntervalType)
return DayTimeIntervalType().fromInternal(literal.day_time_interval)
elif literal.HasField("array"):
elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
if dataType is not None:
assert isinstance(dataType, ArrayType)
assert elementType == dataType.elementType
return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
raise PySparkTypeError(
error_class="UNSUPPORTED_LITERAL",
message_parameters={"literal": str(literal)},
)
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
"""Converts the literal expression to the literal in proto."""
expr = proto.Expression()
if self._value is None:
expr.literal.null.CopyFrom(pyspark_types_to_proto_types(self._dataType))
elif isinstance(self._dataType, BinaryType):
expr.literal.binary = bytes(self._value)
elif isinstance(self._dataType, BooleanType):
expr.literal.boolean = bool(self._value)
elif isinstance(self._dataType, ByteType):
expr.literal.byte = int(self._value)
elif isinstance(self._dataType, ShortType):
expr.literal.short = int(self._value)
elif isinstance(self._dataType, IntegerType):
expr.literal.integer = int(self._value)
elif isinstance(self._dataType, LongType):
expr.literal.long = int(self._value)
elif isinstance(self._dataType, FloatType):
expr.literal.float = float(self._value)
elif isinstance(self._dataType, DoubleType):
expr.literal.double = float(self._value)
elif isinstance(self._dataType, DecimalType):
expr.literal.decimal.value = str(self._value)
expr.literal.decimal.precision = self._dataType.precision
expr.literal.decimal.scale = self._dataType.scale
elif isinstance(self._dataType, StringType):
expr.literal.string = str(self._value)
elif isinstance(self._dataType, DateType):
expr.literal.date = int(self._value)
elif isinstance(self._dataType, TimestampType):
expr.literal.timestamp = int(self._value)
elif isinstance(self._dataType, TimestampNTZType):
expr.literal.timestamp_ntz = int(self._value)
elif isinstance(self._dataType, DayTimeIntervalType):
expr.literal.day_time_interval = int(self._value)
elif isinstance(self._dataType, ArrayType):
element_type = self._dataType.elementType
expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(element_type))
for v in self._value:
expr.literal.array.elements.append(
LiteralExpression(v, element_type).to_plan(session).literal
)
else:
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE",
message_parameters={"data_type": str(self._dataType)},
)
return expr
def __repr__(self) -> str:
return f"{self._value}"
class ColumnReference(Expression):
"""Represents a column reference. There is no guarantee that this column
actually exists. In the context of this project, we refer by its name and
treat it as an unresolved attribute. Attributes that have the same fully
qualified name are identical"""
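    # Illustrative sketch: the reference is kept purely by name until the server resolves it.
    #
    #   ref = ColumnReference("a.b")
    #   ref.name()   # "a.b"
    #   # ref.to_plan(session) emits an unresolved_attribute proto, assuming `session`
    #   # is a SparkConnectClient.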
def __init__(self, unparsed_identifier: str, plan_id: Optional[int] = None) -> None:
super().__init__()
assert isinstance(unparsed_identifier, str)
self._unparsed_identifier = unparsed_identifier
assert plan_id is None or isinstance(plan_id, int)
self._plan_id = plan_id
def name(self) -> str:
"""Returns the qualified name of the column reference."""
return self._unparsed_identifier
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
"""Returns the Proto representation of the expression."""
expr = proto.Expression()
expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
if self._plan_id is not None:
expr.unresolved_attribute.plan_id = self._plan_id
return expr
def __repr__(self) -> str:
return f"{self._unparsed_identifier}"
def __eq__(self, other: Any) -> bool:
return (
other is not None
and isinstance(other, ColumnReference)
and other._unparsed_identifier == self._unparsed_identifier
)
class UnresolvedStar(Expression):
def __init__(self, unparsed_target: Optional[str]):
super().__init__()
if unparsed_target is not None:
assert isinstance(unparsed_target, str) and unparsed_target.endswith(".*")
self._unparsed_target = unparsed_target
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = proto.Expression()
expr.unresolved_star.SetInParent()
if self._unparsed_target is not None:
expr.unresolved_star.unparsed_target = self._unparsed_target
return expr
def __repr__(self) -> str:
if self._unparsed_target is not None:
return f"unresolvedstar({self._unparsed_target})"
else:
return "unresolvedstar()"
def __eq__(self, other: Any) -> bool:
return (
other is not None
and isinstance(other, UnresolvedStar)
and other._unparsed_target == self._unparsed_target
)
class SQLExpression(Expression):
"""Returns Expression which contains a string which is a SQL expression
and server side will parse it by Catalyst
"""
def __init__(self, expr: str) -> None:
super().__init__()
self._expr: str = expr
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
"""Returns the Proto representation of the SQL expression."""
expr = proto.Expression()
expr.expression_string.expression = self._expr
return expr
def __eq__(self, other: Any) -> bool:
return other is not None and isinstance(other, SQLExpression) and other._expr == self._expr
class SortOrder(Expression):
def __init__(self, child: Expression, ascending: bool = True, nullsFirst: bool = True) -> None:
super().__init__()
self._child = child
self._ascending = ascending
self._nullsFirst = nullsFirst
def __repr__(self) -> str:
return (
str(self._child)
+ (" ASC" if self._ascending else " DESC")
+ (" NULLS FIRST" if self._nullsFirst else " NULLS LAST")
)
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
sort = proto.Expression()
sort.sort_order.child.CopyFrom(self._child.to_plan(session))
if self._ascending:
sort.sort_order.direction = (
proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING
)
else:
sort.sort_order.direction = (
proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_DESCENDING
)
if self._nullsFirst:
sort.sort_order.null_ordering = proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST
else:
sort.sort_order.null_ordering = proto.Expression.SortOrder.NullOrdering.SORT_NULLS_LAST
return sort
class UnresolvedFunction(Expression):
def __init__(
self,
name: str,
args: Sequence["Expression"],
is_distinct: bool = False,
) -> None:
super().__init__()
assert isinstance(name, str)
self._name = name
assert isinstance(args, list) and all(isinstance(arg, Expression) for arg in args)
self._args = args
assert isinstance(is_distinct, bool)
self._is_distinct = is_distinct
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
fun = proto.Expression()
fun.unresolved_function.function_name = self._name
if len(self._args) > 0:
fun.unresolved_function.arguments.extend([arg.to_plan(session) for arg in self._args])
fun.unresolved_function.is_distinct = self._is_distinct
return fun
def __repr__(self) -> str:
# Default print handling:
if self._is_distinct:
return f"{self._name}(distinct {', '.join([str(arg) for arg in self._args])})"
else:
return f"{self._name}({', '.join([str(arg) for arg in self._args])})"
class PythonUDF:
"""Represents a Python user-defined function."""
def __init__(
self,
output_type: Union[DataType, str],
eval_type: int,
func: Callable[..., Any],
python_ver: str,
) -> None:
self._output_type: DataType = (
UnparsedDataType(output_type) if isinstance(output_type, str) else output_type
)
self._eval_type = eval_type
self._func = func
self._python_ver = python_ver
def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDF:
if isinstance(self._output_type, UnparsedDataType):
parsed = session._analyze(
method="ddl_parse", ddl_string=self._output_type.data_type_string
).parsed
assert isinstance(parsed, DataType)
output_type = parsed
else:
output_type = self._output_type
expr = proto.PythonUDF()
expr.output_type.CopyFrom(pyspark_types_to_proto_types(output_type))
expr.eval_type = self._eval_type
expr.command = CloudPickleSerializer().dumps((self._func, output_type))
expr.python_ver = self._python_ver
return expr
def __repr__(self) -> str:
return f"{self._output_type}, {self._eval_type}, {self._func}, f{self._python_ver}"
class JavaUDF:
"""Represents a Java (aggregate) user-defined function."""
def __init__(
self,
class_name: str,
output_type: Optional[Union[DataType, str]] = None,
aggregate: bool = False,
) -> None:
self._class_name = class_name
self._output_type: Optional[DataType] = (
UnparsedDataType(output_type) if isinstance(output_type, str) else output_type
)
self._aggregate = aggregate
def to_plan(self, session: "SparkConnectClient") -> proto.JavaUDF:
expr = proto.JavaUDF()
expr.class_name = self._class_name
if self._output_type is not None:
expr.output_type.CopyFrom(pyspark_types_to_proto_types(self._output_type))
expr.aggregate = self._aggregate
return expr
def __repr__(self) -> str:
return f"{self._class_name}, {self._output_type}"
class CommonInlineUserDefinedFunction(Expression):
"""Represents a user-defined function with an inlined defined function body of any programming
languages."""
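    # Illustrative sketch (simplified from the udf registration machinery; eval_type 100 is
    # assumed to be PythonEvalType.SQL_BATCHED_UDF):
    #
    #   fn = PythonUDF(output_type="long", eval_type=100, func=lambda x: x + 1, python_ver="3.9")
    #   udf = CommonInlineUserDefinedFunction(
    #       function_name="plus_one", function=fn, arguments=[ColumnReference("id")]
    #   )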
def __init__(
self,
function_name: str,
function: Union[PythonUDF, JavaUDF],
deterministic: bool = False,
arguments: Sequence[Expression] = [],
):
super().__init__()
self._function_name = function_name
self._deterministic = deterministic
self._arguments = arguments
self._function = function
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = proto.Expression()
expr.common_inline_user_defined_function.function_name = self._function_name
expr.common_inline_user_defined_function.deterministic = self._deterministic
if len(self._arguments) > 0:
expr.common_inline_user_defined_function.arguments.extend(
[arg.to_plan(session) for arg in self._arguments]
)
expr.common_inline_user_defined_function.python_udf.CopyFrom(
cast(proto.PythonUDF, self._function.to_plan(session))
)
return expr
def to_plan_udf(self, session: "SparkConnectClient") -> "proto.CommonInlineUserDefinedFunction":
"""Compared to `to_plan`, it returns a CommonInlineUserDefinedFunction instead of an
Expression."""
expr = proto.CommonInlineUserDefinedFunction()
expr.function_name = self._function_name
expr.deterministic = self._deterministic
if len(self._arguments) > 0:
expr.arguments.extend([arg.to_plan(session) for arg in self._arguments])
expr.python_udf.CopyFrom(cast(proto.PythonUDF, self._function.to_plan(session)))
return expr
def to_plan_judf(
self, session: "SparkConnectClient"
) -> "proto.CommonInlineUserDefinedFunction":
expr = proto.CommonInlineUserDefinedFunction()
expr.function_name = self._function_name
expr.java_udf.CopyFrom(cast(proto.JavaUDF, self._function.to_plan(session)))
return expr
def __repr__(self) -> str:
return f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])})"
class WithField(Expression):
def __init__(
self,
structExpr: Expression,
fieldName: str,
valueExpr: Expression,
) -> None:
super().__init__()
assert isinstance(structExpr, Expression)
self._structExpr = structExpr
assert isinstance(fieldName, str)
self._fieldName = fieldName
assert isinstance(valueExpr, Expression)
self._valueExpr = valueExpr
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = proto.Expression()
expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
expr.update_fields.field_name = self._fieldName
expr.update_fields.value_expression.CopyFrom(self._valueExpr.to_plan(session))
return expr
def __repr__(self) -> str:
return f"WithField({self._structExpr}, {self._fieldName}, {self._valueExpr})"
class DropField(Expression):
def __init__(
self,
structExpr: Expression,
fieldName: str,
) -> None:
super().__init__()
assert isinstance(structExpr, Expression)
self._structExpr = structExpr
assert isinstance(fieldName, str)
self._fieldName = fieldName
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = proto.Expression()
expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
expr.update_fields.field_name = self._fieldName
return expr
def __repr__(self) -> str:
return f"DropField({self._structExpr}, {self._fieldName})"
class UnresolvedExtractValue(Expression):
def __init__(
self,
child: Expression,
extraction: Expression,
) -> None:
super().__init__()
assert isinstance(child, Expression)
self._child = child
assert isinstance(extraction, Expression)
self._extraction = extraction
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = proto.Expression()
expr.unresolved_extract_value.child.CopyFrom(self._child.to_plan(session))
expr.unresolved_extract_value.extraction.CopyFrom(self._extraction.to_plan(session))
return expr
def __repr__(self) -> str:
return f"UnresolvedExtractValue({str(self._child)}, {str(self._extraction)})"
class UnresolvedRegex(Expression):
def __init__(self, col_name: str, plan_id: Optional[int] = None) -> None:
super().__init__()
assert isinstance(col_name, str)
self.col_name = col_name
assert plan_id is None or isinstance(plan_id, int)
self._plan_id = plan_id
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = proto.Expression()
expr.unresolved_regex.col_name = self.col_name
if self._plan_id is not None:
expr.unresolved_regex.plan_id = self._plan_id
return expr
def __repr__(self) -> str:
return f"UnresolvedRegex({self.col_name})"
class CastExpression(Expression):
def __init__(
self,
expr: Expression,
data_type: Union[DataType, str],
) -> None:
super().__init__()
self._expr = expr
self._data_type = data_type
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
fun = proto.Expression()
fun.cast.expr.CopyFrom(self._expr.to_plan(session))
if isinstance(self._data_type, str):
fun.cast.type_str = self._data_type
else:
fun.cast.type.CopyFrom(pyspark_types_to_proto_types(self._data_type))
return fun
def __repr__(self) -> str:
return f"({self._expr} ({self._data_type}))"
class UnresolvedNamedLambdaVariable(Expression):
_lock: Lock = Lock()
_nextVarNameId: int = 0
def __init__(
self,
name_parts: Sequence[str],
) -> None:
super().__init__()
assert (
isinstance(name_parts, list)
and len(name_parts) > 0
and all(isinstance(p, str) for p in name_parts)
)
self._name_parts = name_parts
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = proto.Expression()
expr.unresolved_named_lambda_variable.name_parts.extend(self._name_parts)
return expr
def __repr__(self) -> str:
return f"(UnresolvedNamedLambdaVariable({', '.join(self._name_parts)})"
@staticmethod
def fresh_var_name(name: str) -> str:
        assert isinstance(name, str) and name != ""
_id: Optional[int] = None
with UnresolvedNamedLambdaVariable._lock:
_id = UnresolvedNamedLambdaVariable._nextVarNameId
UnresolvedNamedLambdaVariable._nextVarNameId += 1
assert _id is not None
return f"{name}_{_id}"
class LambdaFunction(Expression):
def __init__(
self,
function: Expression,
arguments: Sequence[UnresolvedNamedLambdaVariable],
) -> None:
super().__init__()
assert isinstance(function, Expression)
assert (
isinstance(arguments, list)
and len(arguments) > 0
and all(isinstance(arg, UnresolvedNamedLambdaVariable) for arg in arguments)
)
self._function = function
self._arguments = arguments
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = proto.Expression()
expr.lambda_function.function.CopyFrom(self._function.to_plan(session))
expr.lambda_function.arguments.extend(
[arg.to_plan(session).unresolved_named_lambda_variable for arg in self._arguments]
)
return expr
def __repr__(self) -> str:
return f"(LambdaFunction({str(self._function)}, {', '.join(self._arguments)})"
class WindowExpression(Expression):
def __init__(
self,
windowFunction: Expression,
windowSpec: "WindowSpec",
) -> None:
super().__init__()
from pyspark.sql.connect.window import WindowSpec
assert windowFunction is not None and isinstance(windowFunction, Expression)
assert windowSpec is not None and isinstance(windowSpec, WindowSpec)
self._windowFunction = windowFunction
self._windowSpec = windowSpec
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = proto.Expression()
expr.window.window_function.CopyFrom(self._windowFunction.to_plan(session))
if len(self._windowSpec._partitionSpec) > 0:
expr.window.partition_spec.extend(
[p.to_plan(session) for p in self._windowSpec._partitionSpec]
)
else:
warnings.warn(
"WARN WindowExpression: No Partition Defined for Window operation! "
"Moving all data to a single partition, this can cause serious "
"performance degradation."
)
if len(self._windowSpec._orderSpec) > 0:
expr.window.order_spec.extend(
[s.to_plan(session).sort_order for s in self._windowSpec._orderSpec]
)
if self._windowSpec._frame is not None:
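            # Frame boundaries are encoded as follows: 0 maps to CURRENT ROW,
            # JVM_LONG_MIN / JVM_LONG_MAX map to UNBOUNDED PRECEDING / FOLLOWING, and any
            # other offset is sent as a literal (int32 for ROW frames, checked against the
            # int range; int64 for RANGE frames).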
if self._windowSpec._frame._isRowFrame:
expr.window.frame_spec.frame_type = (
proto.Expression.Window.WindowFrame.FrameType.FRAME_TYPE_ROW
)
start = self._windowSpec._frame._start
if start == 0:
expr.window.frame_spec.lower.current_row = True
elif start == JVM_LONG_MIN:
expr.window.frame_spec.lower.unbounded = True
elif JVM_INT_MIN <= start <= JVM_INT_MAX:
expr.window.frame_spec.lower.value.literal.integer = start
else:
raise PySparkValueError(
error_class="VALUE_NOT_BETWEEN",
message_parameters={
"arg_name": "start",
"min": str(JVM_INT_MIN),
"max": str(JVM_INT_MAX),
},
)
end = self._windowSpec._frame._end
if end == 0:
expr.window.frame_spec.upper.current_row = True
elif end == JVM_LONG_MAX:
expr.window.frame_spec.upper.unbounded = True
elif JVM_INT_MIN <= end <= JVM_INT_MAX:
expr.window.frame_spec.upper.value.literal.integer = end
else:
raise PySparkValueError(
error_class="VALUE_NOT_BETWEEN",
message_parameters={
"arg_name": "end",
"min": str(JVM_INT_MIN),
"max": str(JVM_INT_MAX),
},
)
else:
expr.window.frame_spec.frame_type = (
proto.Expression.Window.WindowFrame.FrameType.FRAME_TYPE_RANGE
)
start = self._windowSpec._frame._start
if start == 0:
expr.window.frame_spec.lower.current_row = True
elif start == JVM_LONG_MIN:
expr.window.frame_spec.lower.unbounded = True
else:
expr.window.frame_spec.lower.value.literal.long = start
end = self._windowSpec._frame._end
if end == 0:
expr.window.frame_spec.upper.current_row = True
elif end == JVM_LONG_MAX:
expr.window.frame_spec.upper.unbounded = True
else:
expr.window.frame_spec.upper.value.literal.long = end
return expr
def __repr__(self) -> str:
return f"WindowExpression({str(self._windowFunction)}, ({str(self._windowSpec)}))"
class DistributedSequenceID(Expression):
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
unresolved_function = UnresolvedFunction(name="distributed_sequence_id", args=[])
return unresolved_function.to_plan(session)
def __repr__(self) -> str:
return "DistributedSequenceID()"
class CallFunction(Expression):
def __init__(self, name: str, args: Sequence["Expression"]):
super().__init__()
assert isinstance(name, str)
self._name = name
assert isinstance(args, list) and all(isinstance(arg, Expression) for arg in args)
self._args = args
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = proto.Expression()
expr.call_function.function_name = self._name
if len(self._args) > 0:
expr.call_function.arguments.extend([arg.to_plan(session) for arg in self._args])
return expr
def __repr__(self) -> str:
if len(self._args) > 0:
return f"CallFunction('{self._name}', {', '.join([str(arg) for arg in self._args])})"
else:
return f"CallFunction('{self._name}')"
class NamedArgumentExpression(Expression):
def __init__(self, key: str, value: Expression):
super().__init__()
assert isinstance(key, str)
self._key = key
assert isinstance(value, Expression)
self._value = value
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = proto.Expression()
expr.named_argument_expression.key = self._key
expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session))
return expr
def __repr__(self) -> str:
return f"{self._key} => {self._value}"