#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

import sys
from typing import TYPE_CHECKING, Union, Sequence, List, Optional
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import (
ColumnReference,
Expression,
SortOrder,
)
from pyspark.sql.connect.types import (
JVM_LONG_MIN,
JVM_LONG_MAX,
)
from pyspark.sql.window import Window as PySparkWindow, WindowSpec as PySparkWindowSpec
from pyspark.errors import PySparkTypeError

if TYPE_CHECKING:
    from pyspark.sql.connect._typing import ColumnOrName

__all__ = ["Window", "WindowSpec"]
class WindowFrame:
def __init__(self, isRowFrame: bool, start: int, end: int) -> None:
super().__init__()
assert isinstance(isRowFrame, bool)
assert isinstance(start, int)
assert isinstance(end, int)
self._isRowFrame = isRowFrame
self._start = start
self._end = end

    def __repr__(self) -> str:
if self._isRowFrame:
return f"WindowFrame(ROW_FRAME, {self._start}, {self._end})"
else:
return f"WindowFrame(RANGE_FRAME, {self._start}, {self._end})"
class WindowSpec:
def __init__(
self,
partitionSpec: Sequence[Expression],
orderSpec: Sequence[SortOrder],
frame: Optional[WindowFrame],
) -> None:
assert isinstance(partitionSpec, list) and all(
isinstance(p, Expression) for p in partitionSpec
)
assert isinstance(orderSpec, list) and all(isinstance(s, SortOrder) for s in orderSpec)
assert frame is None or isinstance(frame, WindowFrame)
self._partitionSpec = partitionSpec
self._orderSpec = orderSpec
self._frame = frame

    def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName"]]) -> "WindowSpec":
_cols: List[ColumnOrName] = []
for col in cols:
if isinstance(col, (str, Column)):
_cols.append(col)
elif isinstance(col, list):
for c in col:
if isinstance(c, (str, Column)):
_cols.append(c)
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_LIST_OR_STR",
message_parameters={"arg_name": "cols", "arg_type": type(c).__name__},
)
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_LIST_OR_STR",
message_parameters={"arg_name": "cols", "arg_type": type(col).__name__},
)
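
        # Column names become ColumnReference expressions; Column objects contribute
        # their underlying Connect expressions directly.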
newPartitionSpec: List[Expression] = []
for c in _cols:
if isinstance(c, Column):
newPartitionSpec.append(c._expr)
else:
newPartitionSpec.append(ColumnReference(c))
return WindowSpec(
partitionSpec=newPartitionSpec,
orderSpec=self._orderSpec,
frame=self._frame,
)

    def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName"]]) -> "WindowSpec":
_cols: List[ColumnOrName] = []
for col in cols:
if isinstance(col, (str, Column)):
_cols.append(col)
elif isinstance(col, list):
for c in col:
if isinstance(c, (str, Column)):
_cols.append(c)
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_LIST_OR_STR",
message_parameters={"arg_name": "cols", "arg_type": type(c).__name__},
)
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_LIST_OR_STR",
message_parameters={"arg_name": "cols", "arg_type": type(col).__name__},
)
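
        # Wrap each column in a SortOrder (ascending by default) unless the Column
        # already carries an explicit SortOrder expression.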
newOrderSpec: List[SortOrder] = []
for c in _cols:
if isinstance(c, Column):
if isinstance(c._expr, SortOrder):
newOrderSpec.append(c._expr)
else:
newOrderSpec.append(SortOrder(c._expr))
else:
newOrderSpec.append(SortOrder(ColumnReference(c)))
return WindowSpec(
partitionSpec=self._partitionSpec,
orderSpec=newOrderSpec,
frame=self._frame,
)

    def rowsBetween(self, start: int, end: int) -> "WindowSpec":
        # Treat values at or beyond the platform thresholds as unbounded boundaries.
        if start <= Window._PRECEDING_THRESHOLD:
            start = Window.unboundedPreceding
        if end >= Window._FOLLOWING_THRESHOLD:
            end = Window.unboundedFollowing
return WindowSpec(
partitionSpec=self._partitionSpec,
orderSpec=self._orderSpec,
frame=WindowFrame(isRowFrame=True, start=start, end=end),
)

    def rangeBetween(self, start: int, end: int) -> "WindowSpec":
        # Treat values at or beyond the platform thresholds as unbounded boundaries.
        if start <= Window._PRECEDING_THRESHOLD:
            start = Window.unboundedPreceding
        if end >= Window._FOLLOWING_THRESHOLD:
            end = Window.unboundedFollowing
return WindowSpec(
partitionSpec=self._partitionSpec,
orderSpec=self._orderSpec,
frame=WindowFrame(isRowFrame=False, start=start, end=end),
)

    def __repr__(self) -> str:
strs: List[str] = []
if len(self._partitionSpec) > 0:
str_p = ", ".join([str(p) for p in self._partitionSpec])
strs.append(f"PartitionBy({str_p})")
if len(self._orderSpec) > 0:
str_s = ", ".join([str(s) for s in self._orderSpec])
strs.append(f"OrderBy({str_s})")
if self._frame is not None:
strs.append(str(self._frame))
return "WindowSpec(" + ", ".join(strs) + ")"
WindowSpec.rangeBetween.__doc__ = PySparkWindowSpec.rangeBetween.__doc__
WindowSpec.rowsBetween.__doc__ = PySparkWindowSpec.rowsBetween.__doc__
WindowSpec.orderBy.__doc__ = PySparkWindowSpec.orderBy.__doc__
WindowSpec.partitionBy.__doc__ = PySparkWindowSpec.partitionBy.__doc__
WindowSpec.__doc__ = PySparkWindowSpec.__doc__
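

# Connect-side implementation of the Window entry point. The unbounded frame
# boundaries are the JVM long min/max values, matching the classic PySpark API.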
class Window:
_PRECEDING_THRESHOLD = max(-sys.maxsize, JVM_LONG_MIN)
_FOLLOWING_THRESHOLD = min(sys.maxsize, JVM_LONG_MAX)
unboundedPreceding: int = JVM_LONG_MIN
unboundedFollowing: int = JVM_LONG_MAX
currentRow: int = 0
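
    # An empty WindowSpec shared by the static helpers below; every call returns a
    # new WindowSpec, so this shared instance is never mutated.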
    _spec = WindowSpec(partitionSpec=[], orderSpec=[], frame=None)

    @staticmethod
    def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName"]]) -> "WindowSpec":
        return Window._spec.partitionBy(*cols)

    @staticmethod
    def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName"]]) -> "WindowSpec":
        return Window._spec.orderBy(*cols)

    @staticmethod
    def rowsBetween(start: int, end: int) -> "WindowSpec":
        return Window._spec.rowsBetween(start, end)

    @staticmethod
    def rangeBetween(start: int, end: int) -> "WindowSpec":
        return Window._spec.rangeBetween(start, end)


Window.orderBy.__doc__ = PySparkWindow.orderBy.__doc__
Window.rowsBetween.__doc__ = PySparkWindow.rowsBetween.__doc__
Window.rangeBetween.__doc__ = PySparkWindow.rangeBetween.__doc__
Window.partitionBy.__doc__ = PySparkWindow.partitionBy.__doc__
Window.__doc__ = PySparkWindow.__doc__


def _test() -> None:
    import sys
    import doctest
    from pyspark.sql import SparkSession as PySparkSession
    import pyspark.sql.connect.window

    globs = pyspark.sql.connect.window.__dict__.copy()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.window tests").remote("local[4]").getOrCreate()
)
(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.window,
globs=globs,
optionflags=doctest.ELLIPSIS
| doctest.NORMALIZE_WHITESPACE
| doctest.IGNORE_EXCEPTION_DETAIL,
)
globs["spark"].stop()
if failure_count:
sys.exit(-1)
if __name__ == "__main__":
_test()