#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

from typing import TYPE_CHECKING, Any, Union, Sequence, List, Optional, Tuple, cast, Iterable

from pyspark.sql.column import Column
from pyspark.sql.window import (
    Window as ParentWindow,
    WindowSpec as ParentWindowSpec,
)
from pyspark.sql.connect.expressions import Expression, SortOrder
from pyspark.sql.connect.functions import builtin as F

if TYPE_CHECKING:
    from pyspark.sql.connect._typing import ColumnOrName

__all__ = ["Window", "WindowSpec"]


def _to_cols(cols: Tuple[Union["ColumnOrName", Sequence["ColumnOrName"]], ...]) -> List[Column]:
    # Accept either varargs of columns or a single list of columns, and
    # normalize everything into a list of Column.
    if len(cols) == 1 and isinstance(cols[0], list):
        cols = cols[0]  # type: ignore[assignment]
    return [F._to_col(c) for c in cast(Iterable["ColumnOrName"], cols)]
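
# For example (an illustration, not executed on import), both call styles
# below yield the same [Column, Column] result:
#
#     _to_cols(("a", "b"))     # varargs style
#     _to_cols((["a", "b"],))  # single-list style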


class WindowFrame:
    def __init__(self, isRowFrame: bool, start: int, end: int) -> None:
        super().__init__()

        assert isinstance(isRowFrame, bool)
        assert isinstance(start, int)
        assert isinstance(end, int)

        self._isRowFrame = isRowFrame
        self._start = start
        self._end = end

    def __repr__(self) -> str:
        if self._isRowFrame:
            return f"WindowFrame(ROW_FRAME, {self._start}, {self._end})"
        else:
            return f"WindowFrame(RANGE_FRAME, {self._start}, {self._end})"


class WindowSpec(ParentWindowSpec):
    def __new__(
        cls,
        partitionSpec: Sequence[Expression],
        orderSpec: Sequence[SortOrder],
        frame: Optional[WindowFrame],
    ) -> "WindowSpec":
        self = object.__new__(cls)
        self.__init__(partitionSpec, orderSpec, frame)  # type: ignore[misc]
        return self

    def __getnewargs__(self) -> Tuple[Any, ...]:
        return (self._partitionSpec, self._orderSpec, self._frame)

    def __init__(
        self,
        partitionSpec: Sequence[Expression],
        orderSpec: Sequence[SortOrder],
        frame: Optional[WindowFrame],
    ) -> None:
        assert isinstance(partitionSpec, list) and all(
            isinstance(p, Expression) for p in partitionSpec
        )
        assert isinstance(orderSpec, list) and all(isinstance(s, SortOrder) for s in orderSpec)
        assert frame is None or isinstance(frame, WindowFrame)

        self._partitionSpec = partitionSpec
        self._orderSpec = orderSpec
        self._frame = frame

    def partitionBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
        return WindowSpec(
            partitionSpec=[c._expr for c in _to_cols(cols)],  # type: ignore[misc]
            orderSpec=self._orderSpec,
            frame=self._frame,
        )

    def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
        return WindowSpec(
            partitionSpec=self._partitionSpec,
            orderSpec=[cast(SortOrder, F._sort_col(c)._expr) for c in _to_cols(cols)],
            frame=self._frame,
        )
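
    # Both methods accept plain column names, Column objects, or a single
    # list, e.g. (an illustration; `sf` is assumed to be an alias for
    # pyspark.sql.functions): spec.orderBy(sf.col("salary").desc()) or
    # spec.partitionBy(["dept", "team"]).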

    def rowsBetween(self, start: int, end: int) -> "WindowSpec":
        if start <= Window._PRECEDING_THRESHOLD:
            start = Window.unboundedPreceding
        if end >= Window._FOLLOWING_THRESHOLD:
            end = Window.unboundedFollowing
        return WindowSpec(
            partitionSpec=self._partitionSpec,
            orderSpec=self._orderSpec,
            frame=WindowFrame(isRowFrame=True, start=start, end=end),
        )

    def rangeBetween(self, start: int, end: int) -> "WindowSpec":
        if start <= Window._PRECEDING_THRESHOLD:
            start = Window.unboundedPreceding
        if end >= Window._FOLLOWING_THRESHOLD:
            end = Window.unboundedFollowing
        return WindowSpec(
            partitionSpec=self._partitionSpec,
            orderSpec=self._orderSpec,
            frame=WindowFrame(isRowFrame=False, start=start, end=end),
        )
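
    # Note: in both methods above, boundaries at or beyond the
    # _PRECEDING_THRESHOLD/_FOLLOWING_THRESHOLD sentinels are normalized to
    # Window.unboundedPreceding/Window.unboundedFollowing, so for example
    # rowsBetween(-sys.maxsize, 0) describes the same frame as
    # rowsBetween(Window.unboundedPreceding, 0).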

    def __repr__(self) -> str:
        strs: List[str] = []
        if len(self._partitionSpec) > 0:
            str_p = ", ".join([str(p) for p in self._partitionSpec])
            strs.append(f"PartitionBy({str_p})")
        if len(self._orderSpec) > 0:
            str_s = ", ".join([str(s) for s in self._orderSpec])
            strs.append(f"OrderBy({str_s})")
        if self._frame is not None:
            strs.append(str(self._frame))
        return "WindowSpec(" + ", ".join(strs) + ")"


class Window(ParentWindow):
    _spec = WindowSpec(partitionSpec=[], orderSpec=[], frame=None)

    @staticmethod
    def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
        return Window._spec.partitionBy(*cols)

    @staticmethod
    def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
        return Window._spec.orderBy(*cols)

    @staticmethod
    def rowsBetween(start: int, end: int) -> "WindowSpec":
        return Window._spec.rowsBetween(start, end)

    @staticmethod
    def rangeBetween(start: int, end: int) -> "WindowSpec":
        return Window._spec.rangeBetween(start, end)
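
# Typical usage (a sketch assuming an active Spark Connect session and a
# DataFrame `df` with "dept" and "salary" columns; none of these names come
# from this module):
#
#     >>> from pyspark.sql import functions as sf
#     >>> w = (
#     ...     Window.partitionBy("dept")
#     ...     .orderBy("salary")
#     ...     .rowsBetween(Window.unboundedPreceding, Window.currentRow)
#     ... )
#     >>> df.withColumn("running_total", sf.sum("salary").over(w)).show()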


def _test() -> None:
    import os
    import sys
    import doctest
    from pyspark.sql import SparkSession as PySparkSession
    import pyspark.sql.window

    globs = pyspark.sql.window.__dict__.copy()
    globs["spark"] = (
        PySparkSession.builder.appName("sql.connect.window tests")
        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
        .getOrCreate()
    )

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.window,
        globs=globs,
        optionflags=doctest.ELLIPSIS
        | doctest.NORMALIZE_WHITESPACE
        | doctest.IGNORE_EXCEPTION_DETAIL,
    )

    globs["spark"].stop()

    if failure_count:
        sys.exit(-1)
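

# To run this module's doctests against a remote Spark Connect server instead
# of the local[4] default, point SPARK_CONNECT_TESTING_REMOTE at it, e.g.
# (the URL and invocation below are an illustration):
#
#     SPARK_CONNECT_TESTING_REMOTE="sc://localhost:15002" python -m pyspark.sql.connect.window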

if __name__ == "__main__":
    _test()