| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| # mypy: disable-error-code="override" |
| |
from pyspark.errors.exceptions.base import (
    SessionNotSameException,
    PySparkIndexError,
)
| from pyspark.resource import ResourceProfile |
| from pyspark.sql.connect.utils import check_dependencies |
| |
| check_dependencies(__name__) |
| |
| from typing import ( |
| Any, |
| Dict, |
| Iterator, |
| List, |
| Optional, |
| Tuple, |
| Union, |
| Sequence, |
| TYPE_CHECKING, |
| overload, |
| Callable, |
| cast, |
| Type, |
| ) |
| |
| import sys |
| import random |
| import pyarrow as pa |
| import json |
| import warnings |
| from collections.abc import Iterable |
| from functools import cached_property |
| |
| from pyspark import _NoValue |
| from pyspark._globals import _NoValueType |
| from pyspark.util import is_remote_only |
| from pyspark.sql.types import Row, StructType, _create_row |
| from pyspark.sql.dataframe import ( |
| DataFrame as ParentDataFrame, |
| DataFrameNaFunctions as ParentDataFrameNaFunctions, |
| DataFrameStatFunctions as ParentDataFrameStatFunctions, |
| ) |
| |
| from pyspark.errors import ( |
| PySparkTypeError, |
| PySparkAttributeError, |
| PySparkValueError, |
| PySparkNotImplementedError, |
| PySparkRuntimeError, |
| ) |
| from pyspark.util import PythonEvalType |
| from pyspark.storagelevel import StorageLevel |
| import pyspark.sql.connect.plan as plan |
| from pyspark.sql.connect.conversion import ArrowTableToRowsConversion |
| from pyspark.sql.connect.group import GroupedData |
| from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2 |
| from pyspark.sql.connect.streaming.readwriter import DataStreamWriter |
| from pyspark.sql.column import Column |
| from pyspark.sql.connect.expressions import ( |
| ColumnReference, |
| UnresolvedRegex, |
| UnresolvedStar, |
| ) |
| from pyspark.sql.connect.functions import builtin as F |
| from pyspark.sql.pandas.types import from_arrow_schema |
| from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] |
| |
| |
| if TYPE_CHECKING: |
| from pyspark.sql.connect._typing import ( |
| ColumnOrName, |
| ColumnOrNameOrOrdinal, |
| LiteralType, |
| PrimitiveType, |
| OptionalPrimitiveType, |
| PandasMapIterFunction, |
| ArrowMapIterFunction, |
| ) |
| from pyspark.core.rdd import RDD |
| from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike |
| from pyspark.sql.connect.observation import Observation |
| from pyspark.sql.connect.session import SparkSession |
| from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame |
| |
| |
| class DataFrame(ParentDataFrame): |
| def __new__( |
| cls, |
| plan: plan.LogicalPlan, |
| session: "SparkSession", |
| ) -> "DataFrame": |
| self = object.__new__(cls) |
| self.__init__(plan, session) # type: ignore[misc] |
| return self |
| |
| def __init__( |
| self, |
| plan: plan.LogicalPlan, |
| session: "SparkSession", |
| ): |
| """Creates a new data frame""" |
| self._plan = plan |
| if self._plan is None: |
| raise PySparkRuntimeError( |
| error_class="MISSING_VALID_PLAN", |
| message_parameters={"operator": "__init__"}, |
| ) |
| |
| self._session: "SparkSession" = session # type: ignore[assignment] |
| if self._session is None: |
| raise PySparkRuntimeError( |
| error_class="NO_ACTIVE_SESSION", |
| message_parameters={"operator": "__init__"}, |
| ) |
| |
        # Check whether _repr_html_ is supported or not; we use it to avoid calling
        # the RPC twice from __repr__ and _repr_html_ when eager evaluation is enabled.
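        # For example, in a notebook with
        #   spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")
        # _repr_html_ renders the eager-eval HTML table; once it has run,
        # __repr__ only returns the schema string instead of issuing a second RPC.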
| self._support_repr_html = False |
| self._cached_schema: Optional[StructType] = None |
| |
| def __reduce__(self) -> Tuple: |
| """ |
        Custom method for serializing the DataFrame object using Pickle. Since the DataFrame
        overrides the "__getattr__" method, the default serialization does not work.
| |
| Returns |
| ------- |
| The tuple containing the information needed to reconstruct the object. |
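
        Examples
        --------
        A sketch of the pickle round-trip this enables (illustrative only; assumes
        an active Spark Connect session has produced a DataFrame named df):

        >>> import pickle
        >>> restored = pickle.loads(pickle.dumps(df))  # doctest: +SKIP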
| |
| """ |
| return ( |
| DataFrame, |
| ( |
| self._plan, |
| self._session, |
| ), |
| { |
| "_support_repr_html": self._support_repr_html, |
| "_cached_schema": self._cached_schema, |
| }, |
| ) |
| |
| def __repr__(self) -> str: |
| if not self._support_repr_html: |
| ( |
| repl_eager_eval_enabled, |
| repl_eager_eval_max_num_rows, |
| repl_eager_eval_truncate, |
| ) = self._session._client.get_configs( |
| "spark.sql.repl.eagerEval.enabled", |
| "spark.sql.repl.eagerEval.maxNumRows", |
| "spark.sql.repl.eagerEval.truncate", |
| ) |
| if repl_eager_eval_enabled == "true": |
| return self._show_string( |
| n=int(cast(str, repl_eager_eval_max_num_rows)), |
| truncate=int(cast(str, repl_eager_eval_truncate)), |
| vertical=False, |
| ) |
| return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) |
| |
| def _repr_html_(self) -> Optional[str]: |
| if not self._support_repr_html: |
| self._support_repr_html = True |
| ( |
| repl_eager_eval_enabled, |
| repl_eager_eval_max_num_rows, |
| repl_eager_eval_truncate, |
| ) = self._session._client.get_configs( |
| "spark.sql.repl.eagerEval.enabled", |
| "spark.sql.repl.eagerEval.maxNumRows", |
| "spark.sql.repl.eagerEval.truncate", |
| ) |
| if repl_eager_eval_enabled == "true": |
| table, _ = DataFrame( |
| plan.HtmlString( |
| child=self._plan, |
| num_rows=int(cast(str, repl_eager_eval_max_num_rows)), |
| truncate=int(cast(str, repl_eager_eval_truncate)), |
| ), |
| session=self._session, |
| )._to_table() |
| return table[0][0].as_py() |
| else: |
| return None |
| |
| @property |
| def write(self) -> "DataFrameWriter": |
| return DataFrameWriter(self._plan, self._session) |
| |
| def isEmpty(self) -> bool: |
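        # Project zero columns so only row existence is transferred; fetching a
        # single row is enough to decide emptiness.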
| return len(self.select().take(1)) == 0 |
| |
| @overload |
| def select(self, *cols: "ColumnOrName") -> ParentDataFrame: |
| ... |
| |
| @overload |
| def select(self, __cols: Union[List[Column], List[str]]) -> ParentDataFrame: |
| ... |
| |
| def select(self, *cols: "ColumnOrName") -> ParentDataFrame: # type: ignore[misc] |
| if len(cols) == 1 and isinstance(cols[0], list): |
| cols = cols[0] |
| return DataFrame( |
| plan.Project(self._plan, [F._to_col(c) for c in cols]), |
| session=self._session, |
| ) |
| |
| def selectExpr(self, *expr: Union[str, List[str]]) -> ParentDataFrame: |
| sql_expr = [] |
| if len(expr) == 1 and isinstance(expr[0], list): |
| expr = expr[0] # type: ignore[assignment] |
| for element in expr: |
| if isinstance(element, str): |
| sql_expr.append(F.expr(element)) |
| else: |
| sql_expr.extend([F.expr(e) for e in element]) |
| |
| return DataFrame(plan.Project(self._plan, sql_expr), session=self._session) |
| |
| def agg(self, *exprs: Union[Column, Dict[str, str]]) -> ParentDataFrame: |
| if not exprs: |
| raise PySparkValueError( |
| error_class="CANNOT_BE_EMPTY", |
| message_parameters={"item": "exprs"}, |
| ) |
| |
| if len(exprs) == 1 and isinstance(exprs[0], dict): |
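            # e.g. df.agg({"age": "max"}) is rewritten into
            # df.groupBy().agg(max(col("age")))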
| measures = [F._invoke_function(f, F.col(e)) for e, f in exprs[0].items()] |
| return self.groupBy().agg(*measures) |
| else: |
| # other expressions |
            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
| exprs = cast(Tuple[Column, ...], exprs) |
| return self.groupBy().agg(*exprs) |
| |
| def alias(self, alias: str) -> ParentDataFrame: |
| return DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session) |
| |
| def colRegex(self, colName: str) -> Column: |
| from pyspark.sql.connect.column import Column as ConnectColumn |
| |
| if not isinstance(colName, str): |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={"arg_name": "colName", "arg_type": type(colName).__name__}, |
| ) |
| return ConnectColumn(UnresolvedRegex(colName, self._plan._plan_id)) |
| |
| @property |
| def dtypes(self) -> List[Tuple[str, str]]: |
| return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] |
| |
| @property |
| def columns(self) -> List[str]: |
| return self.schema.names |
| |
| @property |
| def sparkSession(self) -> "SparkSession": |
| return self._session |
| |
| def count(self) -> int: |
| table, _ = self.agg( |
| F._invoke_function("count", F.lit(1)) |
| )._to_table() # type: ignore[operator] |
| return table[0][0].as_py() |
| |
| def crossJoin(self, other: ParentDataFrame) -> ParentDataFrame: |
| self._check_same_session(other) |
| return DataFrame( |
| plan.Join( |
| left=self._plan, right=other._plan, on=None, how="cross" # type: ignore[arg-type] |
| ), |
| session=self._session, |
| ) |
| |
| def _check_same_session(self, other: ParentDataFrame) -> None: |
| if self._session.session_id != other._session.session_id: # type: ignore[attr-defined] |
| raise SessionNotSameException( |
| error_class="SESSION_NOT_SAME", |
| message_parameters={}, |
| ) |
| |
| def coalesce(self, numPartitions: int) -> ParentDataFrame: |
| if not numPartitions > 0: |
| raise PySparkValueError( |
| error_class="VALUE_NOT_POSITIVE", |
| message_parameters={"arg_name": "numPartitions", "arg_value": str(numPartitions)}, |
| ) |
| return DataFrame( |
| plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=False), |
| self._session, |
| ) |
| |
| @overload |
| def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame: |
| ... |
| |
| @overload |
| def repartition(self, *cols: "ColumnOrName") -> ParentDataFrame: |
| ... |
| |
| def repartition( # type: ignore[misc] |
| self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName" |
| ) -> ParentDataFrame: |
| if isinstance(numPartitions, int): |
| if not numPartitions > 0: |
| raise PySparkValueError( |
| error_class="VALUE_NOT_POSITIVE", |
| message_parameters={ |
| "arg_name": "numPartitions", |
| "arg_value": str(numPartitions), |
| }, |
| ) |
| if len(cols) == 0: |
| return DataFrame( |
| plan.Repartition(self._plan, numPartitions, shuffle=True), |
| self._session, |
| ) |
| else: |
| return DataFrame( |
| plan.RepartitionByExpression( |
| self._plan, numPartitions, [F._to_col(c) for c in cols] |
| ), |
| self.sparkSession, |
| ) |
| elif isinstance(numPartitions, (str, Column)): |
| cols = (numPartitions,) + cols |
| return DataFrame( |
| plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]), |
| self.sparkSession, |
| ) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_STR", |
| message_parameters={ |
| "arg_name": "numPartitions", |
| "arg_type": type(numPartitions).__name__, |
| }, |
| ) |
| |
| @overload |
| def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame: |
| ... |
| |
| @overload |
| def repartitionByRange(self, *cols: "ColumnOrName") -> ParentDataFrame: |
| ... |
| |
| def repartitionByRange( # type: ignore[misc] |
| self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName" |
| ) -> ParentDataFrame: |
| if isinstance(numPartitions, int): |
| if not numPartitions > 0: |
| raise PySparkValueError( |
| error_class="VALUE_NOT_POSITIVE", |
| message_parameters={ |
| "arg_name": "numPartitions", |
| "arg_value": str(numPartitions), |
| }, |
| ) |
| if len(cols) == 0: |
| raise PySparkValueError( |
| error_class="CANNOT_BE_EMPTY", |
| message_parameters={"item": "cols"}, |
| ) |
| else: |
| return DataFrame( |
| plan.RepartitionByExpression( |
| self._plan, numPartitions, [F._sort_col(c) for c in cols] |
| ), |
| self.sparkSession, |
| ) |
| elif isinstance(numPartitions, (str, Column)): |
| return DataFrame( |
| plan.RepartitionByExpression( |
| self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)] |
| ), |
| self.sparkSession, |
| ) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_INT_OR_STR", |
| message_parameters={ |
| "arg_name": "numPartitions", |
| "arg_type": type(numPartitions).__name__, |
| }, |
| ) |
| |
| def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame: |
| if subset is not None and not isinstance(subset, (list, tuple)): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OR_TUPLE", |
| message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__}, |
| ) |
| |
| if subset is None: |
| return DataFrame( |
| plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session |
| ) |
| else: |
| return DataFrame( |
| plan.Deduplicate(child=self._plan, column_names=subset), session=self._session |
| ) |
| |
| drop_duplicates = dropDuplicates |
| |
| def dropDuplicatesWithinWatermark(self, subset: Optional[List[str]] = None) -> ParentDataFrame: |
| if subset is not None and not isinstance(subset, (list, tuple)): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OR_TUPLE", |
| message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__}, |
| ) |
| |
| if subset is None: |
| return DataFrame( |
| plan.Deduplicate(child=self._plan, all_columns_as_keys=True, within_watermark=True), |
| session=self._session, |
| ) |
| else: |
| return DataFrame( |
| plan.Deduplicate(child=self._plan, column_names=subset, within_watermark=True), |
| session=self._session, |
| ) |
| |
| def distinct(self) -> ParentDataFrame: |
| return DataFrame( |
| plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session |
| ) |
| |
| @overload |
| def drop(self, cols: "ColumnOrName") -> ParentDataFrame: |
| ... |
| |
| @overload |
| def drop(self, *cols: str) -> ParentDataFrame: |
| ... |
| |
| def drop(self, *cols: "ColumnOrName") -> ParentDataFrame: # type: ignore[misc] |
| _cols = list(cols) |
| if any(not isinstance(c, (str, Column)) for c in _cols): |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_STR", |
| message_parameters={"arg_name": "cols", "arg_type": type(cols).__name__}, |
| ) |
| |
| return DataFrame( |
| plan.Drop( |
| child=self._plan, |
| columns=_cols, |
| ), |
| session=self._session, |
| ) |
| |
| def filter(self, condition: Union[Column, str]) -> ParentDataFrame: |
| if isinstance(condition, str): |
| expr = F.expr(condition) |
| else: |
| expr = condition |
| return DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session) |
| |
| def first(self) -> Optional[Row]: |
| return self.head() |
| |
| @overload # type: ignore[no-overload-impl] |
| def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": |
| ... |
| |
| @overload |
| def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData": |
| ... |
| |
| def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> GroupedData: |
| if len(cols) == 1 and isinstance(cols[0], list): |
| cols = cols[0] |
| |
| _cols: List[Column] = [] |
| for c in cols: |
| if isinstance(c, Column): |
| _cols.append(c) |
| elif isinstance(c, str): |
| _cols.append(self[c]) |
| elif isinstance(c, int) and not isinstance(c, bool): |
| if c < 1: |
| raise PySparkIndexError( |
| error_class="INDEX_NOT_POSITIVE", message_parameters={"index": str(c)} |
| ) |
| # ordinal is 1-based |
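                # e.g. df.groupBy(1, "name") groups by the first column and "name"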
| _cols.append(self[c - 1]) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_STR", |
| message_parameters={"arg_name": "cols", "arg_type": type(c).__name__}, |
| ) |
| |
| return GroupedData(df=self, group_type="groupby", grouping_cols=_cols) |
| |
| groupby = groupBy # type: ignore[assignment] |
| |
| @overload |
| def rollup(self, *cols: "ColumnOrName") -> "GroupedData": |
| ... |
| |
| @overload |
| def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": |
| ... |
| |
| def rollup(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] |
| _cols: List[Column] = [] |
| for c in cols: |
| if isinstance(c, Column): |
| _cols.append(c) |
| elif isinstance(c, str): |
| _cols.append(self[c]) |
| elif isinstance(c, int) and not isinstance(c, bool): |
| if c < 1: |
| raise PySparkIndexError( |
| error_class="INDEX_NOT_POSITIVE", message_parameters={"index": str(c)} |
| ) |
| # ordinal is 1-based |
| _cols.append(self[c - 1]) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_STR", |
| message_parameters={"arg_name": "cols", "arg_type": type(c).__name__}, |
| ) |
| |
| return GroupedData(df=self, group_type="rollup", grouping_cols=_cols) |
| |
| @overload |
| def cube(self, *cols: "ColumnOrName") -> "GroupedData": |
| ... |
| |
| @overload |
| def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": |
| ... |
| |
| def cube(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] |
| _cols: List[Column] = [] |
| for c in cols: |
| if isinstance(c, Column): |
| _cols.append(c) |
| elif isinstance(c, str): |
| _cols.append(self[c]) |
| elif isinstance(c, int) and not isinstance(c, bool): |
| if c < 1: |
| raise PySparkIndexError( |
| error_class="INDEX_NOT_POSITIVE", message_parameters={"index": str(c)} |
| ) |
| # ordinal is 1-based |
| _cols.append(self[c - 1]) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_STR", |
| message_parameters={"arg_name": "cols", "arg_type": type(c).__name__}, |
| ) |
| |
| return GroupedData(df=self, group_type="cube", grouping_cols=_cols) |
| |
| def groupingSets( |
| self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName" |
| ) -> "GroupedData": |
| gsets: List[List[Column]] = [] |
| for grouping_set in groupingSets: |
| gset: List[Column] = [] |
| for c in grouping_set: |
| if isinstance(c, Column): |
| gset.append(c) |
| elif isinstance(c, str): |
| gset.append(self[c]) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_STR", |
| message_parameters={ |
| "arg_name": "groupingSets", |
| "arg_type": type(c).__name__, |
| }, |
| ) |
| gsets.append(gset) |
| |
| gcols: List[Column] = [] |
| for c in cols: |
| if isinstance(c, Column): |
| gcols.append(c) |
| elif isinstance(c, str): |
| gcols.append(self[c]) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_STR", |
| message_parameters={"arg_name": "cols", "arg_type": type(c).__name__}, |
| ) |
| |
| return GroupedData( |
| df=self, group_type="grouping_sets", grouping_cols=gcols, grouping_sets=gsets |
| ) |
| |
| @overload |
| def head(self) -> Optional[Row]: |
| ... |
| |
| @overload |
| def head(self, n: int) -> List[Row]: |
| ... |
| |
| def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]: |
| if n is None: |
| rs = self.head(1) |
| return rs[0] if rs else None |
| return self.take(n) |
| |
| def take(self, num: int) -> List[Row]: |
| return self.limit(num).collect() |
| |
| def join( |
| self, |
| other: ParentDataFrame, |
| on: Optional[Union[str, List[str], Column, List[Column]]] = None, |
| how: Optional[str] = None, |
| ) -> ParentDataFrame: |
| self._check_same_session(other) |
| if how is not None and isinstance(how, str): |
| how = how.lower().replace("_", "") |
| return DataFrame( |
| plan.Join(left=self._plan, right=other._plan, on=on, how=how), # type: ignore[arg-type] |
| session=self._session, |
| ) |
| |
| def _joinAsOf( |
| self, |
| other: ParentDataFrame, |
| leftAsOfColumn: Union[str, Column], |
| rightAsOfColumn: Union[str, Column], |
| on: Optional[Union[str, List[str], Column, List[Column]]] = None, |
| how: Optional[str] = None, |
| *, |
| tolerance: Optional[Column] = None, |
| allowExactMatches: bool = True, |
| direction: str = "backward", |
| ) -> ParentDataFrame: |
| self._check_same_session(other) |
| if how is None: |
| how = "inner" |
| assert isinstance(how, str), "how should be a string" |
| |
| if tolerance is not None: |
| assert isinstance(tolerance, Column), "tolerance should be Column" |
| |
| def _convert_col(df: ParentDataFrame, col: "ColumnOrName") -> Column: |
| if isinstance(col, Column): |
| return col |
| else: |
| return df._col(col) # type: ignore[operator] |
| |
| return DataFrame( |
| plan.AsOfJoin( |
| left=self._plan, |
| right=other._plan, # type: ignore[arg-type] |
| left_as_of=_convert_col(self, leftAsOfColumn), |
| right_as_of=_convert_col(other, rightAsOfColumn), |
| on=on, |
| how=how, |
| tolerance=tolerance, |
| allow_exact_matches=allowExactMatches, |
| direction=direction, |
| ), |
| session=self._session, |
| ) |
| |
| def limit(self, n: int) -> ParentDataFrame: |
| return DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session) |
| |
| def tail(self, num: int) -> List[Row]: |
| return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect() |
| |
| def _sort_cols( |
| self, |
| cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]], |
| kwargs: Dict[str, Any], |
| ) -> List[Column]: |
| """Return a JVM Seq of Columns that describes the sort order""" |
| if cols is None: |
| raise PySparkValueError( |
| error_class="CANNOT_BE_EMPTY", |
| message_parameters={"item": "cols"}, |
| ) |
| |
| if len(cols) == 1 and isinstance(cols[0], list): |
| cols = cols[0] |
| |
| _cols: List[Column] = [] |
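        # Integer ordinals are 1-based; a negative ordinal sorts that column in
        # descending order, e.g. df.sort(1, -2) ~ ORDER BY 1, 2 DESC.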
| for c in cols: |
| if isinstance(c, int) and not isinstance(c, bool): |
| # ordinal is 1-based |
| if c > 0: |
| _c = self[c - 1] |
| # negative ordinal means sort by desc |
| elif c < 0: |
| _c = self[-c - 1].desc() |
| else: |
| raise PySparkIndexError( |
| error_class="ZERO_INDEX", |
| message_parameters={}, |
| ) |
| else: |
| _c = c # type: ignore[assignment] |
| _cols.append(F._to_col(cast("ColumnOrName", _c))) |
| |
| ascending = kwargs.get("ascending", True) |
| if isinstance(ascending, (bool, int)): |
| if not ascending: |
| _cols = [c.desc() for c in _cols] |
| elif isinstance(ascending, list): |
| _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)] |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_BOOL_OR_LIST", |
| message_parameters={"arg_name": "ascending", "arg_type": type(ascending).__name__}, |
| ) |
| |
| return [F._sort_col(c) for c in _cols] |
| |
| def sort( |
| self, |
| *cols: Union[int, str, Column, List[Union[int, str, Column]]], |
| **kwargs: Any, |
| ) -> ParentDataFrame: |
| return DataFrame( |
| plan.Sort( |
| self._plan, |
| columns=self._sort_cols(cols, kwargs), |
| is_global=True, |
| ), |
| session=self._session, |
| ) |
| |
| orderBy = sort |
| |
| def sortWithinPartitions( |
| self, |
| *cols: Union[int, str, Column, List[Union[int, str, Column]]], |
| **kwargs: Any, |
| ) -> ParentDataFrame: |
| return DataFrame( |
| plan.Sort( |
| self._plan, |
| columns=self._sort_cols(cols, kwargs), |
| is_global=False, |
| ), |
| session=self._session, |
| ) |
| |
| def sample( |
| self, |
| withReplacement: Optional[Union[float, bool]] = None, |
| fraction: Optional[Union[int, float]] = None, |
| seed: Optional[int] = None, |
| ) -> ParentDataFrame: |
| # For the cases below: |
| # sample(True, 0.5 [, seed]) |
| # sample(True, fraction=0.5 [, seed]) |
| # sample(withReplacement=False, fraction=0.5 [, seed]) |
| is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float) |
| |
| # For the case below: |
        #     sample(fraction=0.5 [, seed])
| is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float) |
| |
| # For the case below: |
| # sample(0.5 [, seed]) |
| is_withReplacement_omitted_args = isinstance(withReplacement, float) |
| |
| if not ( |
| is_withReplacement_set |
| or is_withReplacement_omitted_kwargs |
| or is_withReplacement_omitted_args |
| ): |
| argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]] |
| raise PySparkTypeError( |
| error_class="NOT_BOOL_OR_FLOAT_OR_INT", |
| message_parameters={ |
| "arg_name": "withReplacement (optional), " |
| + "fraction (required) and seed (optional)", |
| "arg_type": ", ".join(argtypes), |
| }, |
| ) |
| |
| if is_withReplacement_omitted_args: |
| if fraction is not None: |
| seed = cast(int, fraction) |
| fraction = withReplacement |
| withReplacement = None |
| |
| if withReplacement is None: |
| withReplacement = False |
| |
| seed = int(seed) if seed is not None else random.randint(0, sys.maxsize) |
| |
| return DataFrame( |
| plan.Sample( |
| child=self._plan, |
| lower_bound=0.0, |
| upper_bound=fraction, # type: ignore[arg-type] |
| with_replacement=withReplacement, # type: ignore[arg-type] |
| seed=seed, |
| ), |
| session=self._session, |
| ) |
| |
| def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame: |
| return self.withColumnsRenamed({existing: new}) |
| |
| def withColumnsRenamed(self, colsMap: Dict[str, str]) -> ParentDataFrame: |
| if not isinstance(colsMap, dict): |
| raise PySparkTypeError( |
| error_class="NOT_DICT", |
| message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__}, |
| ) |
| |
| return DataFrame(plan.WithColumnsRenamed(self._plan, colsMap), self._session) |
| |
| def _show_string( |
| self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False |
| ) -> str: |
| if not isinstance(n, int) or isinstance(n, bool): |
| raise PySparkTypeError( |
| error_class="NOT_INT", |
| message_parameters={"arg_name": "n", "arg_type": type(n).__name__}, |
| ) |
| if not isinstance(vertical, bool): |
| raise PySparkTypeError( |
| error_class="NOT_BOOL", |
| message_parameters={"arg_name": "vertical", "arg_type": type(vertical).__name__}, |
| ) |
| |
| _truncate: int = -1 |
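        # truncate=True maps to a display width of 20 characters, truncate=False
        # to 0 (no truncation), and an integer value is used as the width directly.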
| if isinstance(truncate, bool) and truncate: |
| _truncate = 20 |
| else: |
| try: |
| _truncate = int(truncate) |
| except ValueError: |
| raise PySparkTypeError( |
| error_class="NOT_BOOL", |
| message_parameters={ |
| "arg_name": "truncate", |
| "arg_type": type(truncate).__name__, |
| }, |
| ) |
| |
| table, _ = DataFrame( |
| plan.ShowString( |
| child=self._plan, |
| num_rows=n, |
| truncate=_truncate, |
| vertical=vertical, |
| ), |
| session=self._session, |
| )._to_table() |
| return table[0][0].as_py() |
| |
| def withColumns(self, colsMap: Dict[str, Column]) -> ParentDataFrame: |
| if not isinstance(colsMap, dict): |
| raise PySparkTypeError( |
| error_class="NOT_DICT", |
| message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__}, |
| ) |
| |
| names: List[str] = [] |
| columns: List[Column] = [] |
| for columnName, column in colsMap.items(): |
| names.append(columnName) |
| columns.append(column) |
| |
| return DataFrame( |
| plan.WithColumns( |
| self._plan, |
| columnNames=names, |
| columns=columns, |
| ), |
| session=self._session, |
| ) |
| |
| def withColumn(self, colName: str, col: Column) -> ParentDataFrame: |
| if not isinstance(col, Column): |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN", |
| message_parameters={"arg_name": "col", "arg_type": type(col).__name__}, |
| ) |
| return DataFrame( |
| plan.WithColumns( |
| self._plan, |
| columnNames=[colName], |
| columns=[col], |
| ), |
| session=self._session, |
| ) |
| |
| def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> ParentDataFrame: |
| if not isinstance(metadata, dict): |
| raise PySparkTypeError( |
| error_class="NOT_DICT", |
| message_parameters={"arg_name": "metadata", "arg_type": type(metadata).__name__}, |
| ) |
| |
| return DataFrame( |
| plan.WithColumns( |
| self._plan, |
| columnNames=[columnName], |
| columns=[self[columnName]], |
| metadata=[json.dumps(metadata)], |
| ), |
| session=self._session, |
| ) |
| |
| def unpivot( |
| self, |
| ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]], |
| values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]], |
| variableColumnName: str, |
| valueColumnName: str, |
| ) -> ParentDataFrame: |
| assert ids is not None, "ids must not be None" |
| |
| def _convert_cols( |
| cols: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]] |
| ) -> List[Column]: |
| if cols is None: |
| return [] |
| elif isinstance(cols, (tuple, list)): |
| return [F._to_col(c) for c in cols] |
| else: |
| return [F._to_col(cols)] |
| |
| return DataFrame( |
| plan.Unpivot( |
| self._plan, |
| _convert_cols(ids), |
| _convert_cols(values) if values is not None else None, |
| variableColumnName, |
| valueColumnName, |
| ), |
| self._session, |
| ) |
| |
| melt = unpivot |
| |
| def withWatermark(self, eventTime: str, delayThreshold: str) -> ParentDataFrame: |
| # TODO: reuse error handling code in sql.DataFrame.withWatermark() |
| if not eventTime or type(eventTime) is not str: |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={"arg_name": "eventTime", "arg_type": type(eventTime).__name__}, |
| ) |
| if not delayThreshold or type(delayThreshold) is not str: |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={ |
| "arg_name": "delayThreshold", |
| "arg_type": type(delayThreshold).__name__, |
| }, |
| ) |
| |
| return DataFrame( |
| plan.WithWatermark( |
| self._plan, |
| event_time=eventTime, |
| delay_threshold=delayThreshold, |
| ), |
| session=self._session, |
| ) |
| |
| def hint( |
| self, name: str, *parameters: Union["PrimitiveType", "Column", List["PrimitiveType"]] |
| ) -> ParentDataFrame: |
| if len(parameters) == 1 and isinstance(parameters[0], list): |
| parameters = parameters[0] # type: ignore[assignment] |
| |
| if not isinstance(name, str): |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={"arg_name": "name", "arg_type": type(name).__name__}, |
| ) |
| |
| allowed_types = (str, float, int, Column, list) |
| allowed_primitive_types = (str, float, int) |
| allowed_types_repr = ", ".join( |
| [t.__name__ for t in allowed_types[:-1]] |
| + ["list[" + t.__name__ + "]" for t in allowed_primitive_types] |
| ) |
| for p in parameters: |
| if not isinstance(p, allowed_types): |
| raise PySparkTypeError( |
| error_class="INVALID_ITEM_FOR_CONTAINER", |
| message_parameters={ |
| "arg_name": "parameters", |
| "allowed_types": allowed_types_repr, |
| "item_type": type(p).__name__, |
| }, |
| ) |
| if isinstance(p, list): |
| if not all(isinstance(e, allowed_primitive_types) for e in p): |
| raise PySparkTypeError( |
| error_class="INVALID_ITEM_FOR_CONTAINER", |
| message_parameters={ |
| "arg_name": "parameters", |
| "allowed_types": allowed_types_repr, |
| "item_type": type(p).__name__ + "[" + type(p[0]).__name__ + "]", |
| }, |
| ) |
| |
| return DataFrame( |
| plan.Hint(self._plan, name, [F.lit(p) for p in list(parameters)]), |
| session=self._session, |
| ) |
| |
| def randomSplit( |
| self, |
| weights: List[float], |
| seed: Optional[int] = None, |
| ) -> List[ParentDataFrame]: |
| for w in weights: |
| if w < 0.0: |
| raise PySparkValueError( |
| error_class="VALUE_NOT_POSITIVE", |
| message_parameters={"arg_name": "weights", "arg_value": str(w)}, |
| ) |
| seed = seed if seed is not None else random.randint(0, sys.maxsize) |
| total = sum(weights) |
| if total <= 0: |
| raise PySparkValueError( |
| error_class="VALUE_NOT_POSITIVE", |
| message_parameters={"arg_name": "sum(weights)", "arg_value": str(total)}, |
| ) |
| proportions = list(map(lambda x: x / total, weights)) |
| normalizedCumWeights = [0.0] |
| for v in proportions: |
| normalizedCumWeights.append(normalizedCumWeights[-1] + v) |
| j = 1 |
| length = len(normalizedCumWeights) |
| splits = [] |
| while j < length: |
| lowerBound = normalizedCumWeights[j - 1] |
| upperBound = normalizedCumWeights[j] |
| samplePlan = DataFrame( |
| plan.Sample( |
| child=self._plan, |
| lower_bound=lowerBound, |
| upper_bound=upperBound, |
| with_replacement=False, |
| seed=int(seed), |
| deterministic_order=True, |
| ), |
| session=self._session, |
| ) |
| splits.append(samplePlan) |
| j += 1 |
| |
| return splits # type: ignore[return-value] |
| |
| def observe( |
| self, |
| observation: Union["Observation", str], |
| *exprs: Column, |
| ) -> ParentDataFrame: |
| from pyspark.sql.connect.observation import Observation |
| |
| if len(exprs) == 0: |
| raise PySparkValueError( |
| error_class="CANNOT_BE_EMPTY", |
| message_parameters={"item": "exprs"}, |
| ) |
| if not all(isinstance(c, Column) for c in exprs): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OF_COLUMN", |
| message_parameters={"arg_name": "exprs"}, |
| ) |
| |
| if isinstance(observation, Observation): |
| return observation._on(self, *exprs) |
| elif isinstance(observation, str): |
| return DataFrame( |
| plan.CollectMetrics(self._plan, observation, list(exprs)), |
| self._session, |
| ) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_OBSERVATION_OR_STR", |
| message_parameters={ |
| "arg_name": "observation", |
| "arg_type": type(observation).__name__, |
| }, |
| ) |
| |
| def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: |
| print(self._show_string(n, truncate, vertical)) |
| |
| def union(self, other: ParentDataFrame) -> ParentDataFrame: |
| self._check_same_session(other) |
| return self.unionAll(other) |
| |
| def unionAll(self, other: ParentDataFrame) -> ParentDataFrame: |
| self._check_same_session(other) |
| return DataFrame( |
| plan.SetOperation( |
| self._plan, other._plan, "union", is_all=True # type: ignore[arg-type] |
| ), |
| session=self._session, |
| ) |
| |
| def unionByName( |
| self, other: ParentDataFrame, allowMissingColumns: bool = False |
| ) -> ParentDataFrame: |
| self._check_same_session(other) |
| return DataFrame( |
| plan.SetOperation( |
| self._plan, |
| other._plan, # type: ignore[arg-type] |
| "union", |
| by_name=True, |
| allow_missing_columns=allowMissingColumns, |
| ), |
| session=self._session, |
| ) |
| |
| def subtract(self, other: ParentDataFrame) -> ParentDataFrame: |
| self._check_same_session(other) |
| return DataFrame( |
| plan.SetOperation( |
| self._plan, other._plan, "except", is_all=False # type: ignore[arg-type] |
| ), |
| session=self._session, |
| ) |
| |
| def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame: |
| self._check_same_session(other) |
| return DataFrame( |
| plan.SetOperation( |
| self._plan, other._plan, "except", is_all=True # type: ignore[arg-type] |
| ), |
| session=self._session, |
| ) |
| |
| def intersect(self, other: ParentDataFrame) -> ParentDataFrame: |
| self._check_same_session(other) |
| return DataFrame( |
| plan.SetOperation( |
| self._plan, other._plan, "intersect", is_all=False # type: ignore[arg-type] |
| ), |
| session=self._session, |
| ) |
| |
| def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame: |
| self._check_same_session(other) |
| return DataFrame( |
| plan.SetOperation( |
| self._plan, other._plan, "intersect", is_all=True # type: ignore[arg-type] |
| ), |
| session=self._session, |
| ) |
| |
| def where(self, condition: Union[Column, str]) -> ParentDataFrame: |
| if not isinstance(condition, (str, Column)): |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_STR", |
| message_parameters={"arg_name": "condition", "arg_type": type(condition).__name__}, |
| ) |
| return self.filter(condition) |
| |
| @property |
| def na(self) -> ParentDataFrameNaFunctions: |
| return DataFrameNaFunctions(self) |
| |
| def fillna( |
| self, |
| value: Union["LiteralType", Dict[str, "LiteralType"]], |
| subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None, |
| ) -> ParentDataFrame: |
| if not isinstance(value, (float, int, str, bool, dict)): |
| raise PySparkTypeError( |
| error_class="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_STR", |
| message_parameters={"arg_name": "value", "arg_type": type(value).__name__}, |
| ) |
| if isinstance(value, dict): |
| if len(value) == 0: |
| raise PySparkValueError( |
| error_class="CANNOT_BE_EMPTY", |
| message_parameters={"item": "value"}, |
| ) |
| for c, v in value.items(): |
| if not isinstance(c, str): |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={ |
| "arg_name": "key type of dict", |
| "arg_type": type(c).__name__, |
| }, |
| ) |
| if not isinstance(v, (bool, int, float, str)): |
| raise PySparkTypeError( |
| error_class="NOT_BOOL_OR_FLOAT_OR_INT_OR_STR", |
| message_parameters={ |
| "arg_name": "value type of dict", |
| "arg_type": type(v).__name__, |
| }, |
| ) |
| |
| _cols: List[str] = [] |
| if subset is not None: |
| if isinstance(subset, str): |
| _cols = [subset] |
| elif isinstance(subset, (tuple, list)): |
| for c in subset: |
| if not isinstance(c, str): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OR_STR_OR_TUPLE", |
| message_parameters={"arg_name": "cols", "arg_type": type(c).__name__}, |
| ) |
| _cols = list(subset) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OR_TUPLE", |
| message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__}, |
| ) |
| |
| if isinstance(value, dict): |
| _cols = list(value.keys()) |
| _values = [value[c] for c in _cols] |
| else: |
| _values = [value] |
| |
| return DataFrame( |
| plan.NAFill(child=self._plan, cols=_cols, values=_values), |
| session=self._session, |
| ) |
| |
| def dropna( |
| self, |
| how: str = "any", |
| thresh: Optional[int] = None, |
| subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None, |
| ) -> ParentDataFrame: |
| min_non_nulls: Optional[int] = None |
| |
| if how is not None: |
| if not isinstance(how, str): |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={"arg_name": "how", "arg_type": type(how).__name__}, |
| ) |
| if how == "all": |
| min_non_nulls = 1 |
| elif how == "any": |
| min_non_nulls = None |
| else: |
| raise PySparkValueError( |
| error_class="CANNOT_BE_EMPTY", |
| message_parameters={"arg_name": "how", "arg_value": str(how)}, |
| ) |
| |
| if thresh is not None: |
| if not isinstance(thresh, int): |
| raise PySparkTypeError( |
| error_class="NOT_INT", |
| message_parameters={"arg_name": "thresh", "arg_type": type(thresh).__name__}, |
| ) |
| |
| # 'thresh' overwrites 'how' |
| min_non_nulls = thresh |
| |
| _cols: List[str] = [] |
| if subset is not None: |
| if isinstance(subset, str): |
| _cols = [subset] |
| elif isinstance(subset, (tuple, list)): |
| for c in subset: |
| if not isinstance(c, str): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OR_STR_OR_TUPLE", |
| message_parameters={"arg_name": "cols", "arg_type": type(c).__name__}, |
| ) |
| _cols = list(subset) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OR_STR_OR_TUPLE", |
| message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__}, |
| ) |
| |
| return DataFrame( |
| plan.NADrop(child=self._plan, cols=_cols, min_non_nulls=min_non_nulls), |
| session=self._session, |
| ) |
| |
| def replace( |
| self, |
| to_replace: Union[ |
| "LiteralType", List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"] |
| ], |
| value: Optional[ |
| Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType] |
| ] = _NoValue, |
| subset: Optional[List[str]] = None, |
| ) -> ParentDataFrame: |
| if value is _NoValue: |
| if isinstance(to_replace, dict): |
| value = None |
| else: |
| raise PySparkTypeError( |
| error_class="ARGUMENT_REQUIRED", |
| message_parameters={"arg_name": "value", "condition": "`to_replace` is dict"}, |
| ) |
| |
| # Helper functions |
| def all_of(types: Union[Type, Tuple[Type, ...]]) -> Callable[[Iterable], bool]: |
| """Given a type or tuple of types and a sequence of xs |
| check if each x is instance of type(s) |
| |
| >>> all_of(bool)([True, False]) |
| True |
| >>> all_of(str)(["a", 1]) |
| False |
| """ |
| |
| def all_of_(xs: Iterable) -> bool: |
| return all(isinstance(x, types) for x in xs) |
| |
| return all_of_ |
| |
| all_of_bool = all_of(bool) |
| all_of_str = all_of(str) |
| all_of_numeric = all_of((float, int)) |
| |
| # Validate input types |
| valid_types = (bool, float, int, str, list, tuple) |
| if not isinstance(to_replace, valid_types + (dict,)): |
| raise PySparkTypeError( |
| error_class="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_LIST_OR_STR_OR_TUPLE", |
| message_parameters={ |
| "arg_name": "to_replace", |
| "arg_type": type(to_replace).__name__, |
| }, |
| ) |
| |
| if ( |
| not isinstance(value, valid_types) |
| and value is not None |
| and not isinstance(to_replace, dict) |
| ): |
| raise PySparkTypeError( |
| error_class="NOT_BOOL_OR_FLOAT_OR_INT_OR_LIST_OR_NONE_OR_STR_OR_TUPLE", |
| message_parameters={"arg_name": "value", "arg_type": type(value).__name__}, |
| ) |
| |
| if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): |
| if len(to_replace) != len(value): |
| raise PySparkValueError( |
| error_class="LENGTH_SHOULD_BE_THE_SAME", |
| message_parameters={ |
| "arg1": "to_replace", |
| "arg2": "value", |
| "arg1_length": str(len(to_replace)), |
| "arg2_length": str(len(value)), |
| }, |
| ) |
| |
| if not (subset is None or isinstance(subset, (list, tuple, str))): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OR_STR_OR_TUPLE", |
| message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__}, |
| ) |
| |
| # Reshape input arguments if necessary |
| if isinstance(to_replace, (float, int, str)): |
| to_replace = [to_replace] |
| |
| if isinstance(to_replace, dict): |
| rep_dict = to_replace |
| if value is not None: |
| warnings.warn("to_replace is a dict and value is not None. value will be ignored.") |
| else: |
| if isinstance(value, (float, int, str)) or value is None: |
| value = [value for _ in range(len(to_replace))] |
| rep_dict = dict(zip(to_replace, cast("Iterable[Optional[Union[float, str]]]", value))) |
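        # e.g. df.replace(10, 20) reshapes to rep_dict == {10: 20}; the ints are
        # converted to floats further below before building the plan.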
| |
| if isinstance(subset, str): |
| subset = [subset] |
| |
| # Verify we were not passed in mixed type generics. |
| if not any( |
| all_of_type(rep_dict.keys()) |
| and all_of_type(x for x in rep_dict.values() if x is not None) |
| for all_of_type in [all_of_bool, all_of_str, all_of_numeric] |
| ): |
| raise PySparkValueError( |
| error_class="MIXED_TYPE_REPLACEMENT", |
| message_parameters={}, |
| ) |
| |
| def _convert_int_to_float(v: Any) -> Any: |
| # a bool is also an int |
| if v is not None and not isinstance(v, bool) and isinstance(v, int): |
| return float(v) |
| else: |
| return v |
| |
| _replacements = [] |
| for k, v in rep_dict.items(): |
| _k = _convert_int_to_float(k) |
| _v = _convert_int_to_float(v) |
| _replacements.append((F.lit(_k), F.lit(_v))) |
| |
| return DataFrame( |
| plan.NAReplace( |
| child=self._plan, |
| cols=subset, |
| replacements=_replacements, |
| ), |
| session=self._session, |
| ) |
| |
| @property |
| def stat(self) -> ParentDataFrameStatFunctions: |
| return DataFrameStatFunctions(self) |
| |
| def summary(self, *statistics: str) -> ParentDataFrame: |
| _statistics: List[str] = list(statistics) |
| for s in _statistics: |
| if not isinstance(s, str): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OF_STR", |
| message_parameters={"arg_name": "statistics", "arg_type": type(s).__name__}, |
| ) |
| return DataFrame( |
| plan.StatSummary(child=self._plan, statistics=_statistics), |
| session=self._session, |
| ) |
| |
| def describe(self, *cols: Union[str, List[str]]) -> ParentDataFrame: |
| if len(cols) == 1 and isinstance(cols[0], list): |
| cols = cols[0] # type: ignore[assignment] |
| |
| _cols = [] |
| for column in cols: |
| if isinstance(column, str): |
| _cols.append(column) |
| else: |
| _cols.extend([s for s in column]) |
| return DataFrame( |
| plan.StatDescribe(child=self._plan, cols=_cols), |
| session=self._session, |
| ) |
| |
| def cov(self, col1: str, col2: str) -> float: |
| if not isinstance(col1, str): |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={"arg_name": "col1", "arg_type": type(col1).__name__}, |
| ) |
| if not isinstance(col2, str): |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={"arg_name": "col2", "arg_type": type(col2).__name__}, |
| ) |
| table, _ = DataFrame( |
| plan.StatCov(child=self._plan, col1=col1, col2=col2), |
| session=self._session, |
| )._to_table() |
| return table[0][0].as_py() |
| |
| def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float: |
| if not isinstance(col1, str): |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={"arg_name": "col1", "arg_type": type(col1).__name__}, |
| ) |
| if not isinstance(col2, str): |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={"arg_name": "col2", "arg_type": type(col2).__name__}, |
| ) |
| if not method: |
| method = "pearson" |
| if not method == "pearson": |
| raise PySparkValueError( |
| error_class="VALUE_NOT_PEARSON", |
| message_parameters={"arg_name": "method", "arg_value": method}, |
| ) |
| table, _ = DataFrame( |
| plan.StatCorr(child=self._plan, col1=col1, col2=col2, method=method), |
| session=self._session, |
| )._to_table() |
| return table[0][0].as_py() |
| |
| def approxQuantile( |
| self, |
| col: Union[str, List[str], Tuple[str]], |
| probabilities: Union[List[float], Tuple[float]], |
| relativeError: float, |
| ) -> Union[List[float], List[List[float]]]: |
| if not isinstance(col, (str, list, tuple)): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OR_STR_OR_TUPLE", |
| message_parameters={"arg_name": "col", "arg_type": type(col).__name__}, |
| ) |
| |
| isStr = isinstance(col, str) |
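        # A single column name yields a flat list of quantiles; a list or tuple
        # of names yields one such list per column.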
| |
| if isinstance(col, tuple): |
| col = list(col) |
| elif isStr: |
| col = [cast(str, col)] |
| |
| for c in col: |
| if not isinstance(c, str): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OF_STR", |
| message_parameters={"arg_name": "columns", "arg_type": type(c).__name__}, |
| ) |
| |
| if not isinstance(probabilities, (list, tuple)): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OR_TUPLE", |
| message_parameters={ |
| "arg_name": "probabilities", |
| "arg_type": type(probabilities).__name__, |
| }, |
| ) |
| if isinstance(probabilities, tuple): |
| probabilities = list(probabilities) |
| for p in probabilities: |
| if not isinstance(p, (float, int)) or p < 0 or p > 1: |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OF_FLOAT_OR_INT", |
| message_parameters={ |
| "arg_name": "probabilities", |
| "arg_type": type(p).__name__, |
| }, |
| ) |
| |
| if not isinstance(relativeError, (float, int)): |
| raise PySparkTypeError( |
| error_class="NOT_FLOAT_OR_INT", |
| message_parameters={ |
| "arg_name": "relativeError", |
| "arg_type": type(relativeError).__name__, |
| }, |
| ) |
| if relativeError < 0: |
| raise PySparkValueError( |
| error_class="NEGATIVE_VALUE", |
| message_parameters={ |
| "arg_name": "relativeError", |
| "arg_value": str(relativeError), |
| }, |
| ) |
| relativeError = float(relativeError) |
| table, _ = DataFrame( |
| plan.StatApproxQuantile( |
| child=self._plan, |
| cols=list(col), |
| probabilities=probabilities, |
| relativeError=relativeError, |
| ), |
| session=self._session, |
| )._to_table() |
| jaq = [q.as_py() for q in table[0][0]] |
| jaq_list = [list(j) for j in jaq] |
| return jaq_list[0] if isStr else jaq_list |
| |
| def crosstab(self, col1: str, col2: str) -> ParentDataFrame: |
| if not isinstance(col1, str): |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={"arg_name": "col1", "arg_type": type(col1).__name__}, |
| ) |
| if not isinstance(col2, str): |
| raise PySparkTypeError( |
| error_class="NOT_STR", |
| message_parameters={"arg_name": "col2", "arg_type": type(col2).__name__}, |
| ) |
| return DataFrame( |
| plan.StatCrosstab(child=self._plan, col1=col1, col2=col2), |
| session=self._session, |
| ) |
| |
| def freqItems( |
| self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None |
| ) -> ParentDataFrame: |
| if isinstance(cols, tuple): |
| cols = list(cols) |
| if not isinstance(cols, list): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OR_TUPLE", |
| message_parameters={"arg_name": "cols", "arg_type": type(cols).__name__}, |
| ) |
| if not support: |
| support = 0.01 |
| return DataFrame( |
| plan.StatFreqItems(child=self._plan, cols=cols, support=support), |
| session=self._session, |
| ) |
| |
| def sampleBy( |
| self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None |
| ) -> ParentDataFrame: |
| if not isinstance(col, (str, Column)): |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_STR", |
| message_parameters={"arg_name": "col", "arg_type": type(col).__name__}, |
| ) |
| if not isinstance(fractions, dict): |
| raise PySparkTypeError( |
| error_class="NOT_DICT", |
| message_parameters={"arg_name": "fractions", "arg_type": type(fractions).__name__}, |
| ) |
| |
| _fractions = [] |
| for k, v in fractions.items(): |
| if not isinstance(k, (float, int, str)): |
| raise PySparkTypeError( |
| error_class="DISALLOWED_TYPE_FOR_CONTAINER", |
| message_parameters={ |
| "arg_name": "fractions", |
| "arg_type": type(fractions).__name__, |
| "allowed_types": "float, int, str", |
| "item_type": type(k).__name__, |
| }, |
| ) |
| _fractions.append((F.lit(k), float(v))) |
| |
| seed = seed if seed is not None else random.randint(0, sys.maxsize) |
| return DataFrame( |
| plan.StatSampleBy( |
| child=self._plan, |
| col=F._to_col(col), |
| fractions=_fractions, |
| seed=seed, |
| ), |
| session=self._session, |
| ) |
| |
| def _ipython_key_completions_(self) -> List[str]: |
| """Returns the names of columns in this :class:`DataFrame`. |
| |
| Examples |
| -------- |
| >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"]) |
| >>> df._ipython_key_completions_() |
| ['age', 'name'] |
| |
        Column names that are not legal Python identifiers are returned as well.
| >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age 1", "name?1"]) |
| >>> df._ipython_key_completions_() |
| ['age 1', 'name?1'] |
| """ |
| return self.columns |
| |
| def __getattr__(self, name: str) -> "Column": |
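        # These attributes are backed by the JVM in the classic DataFrame and have
        # no Spark Connect equivalent, so fail fast with a dedicated error.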
| if name in ["_jseq", "_jdf", "_jmap", "_jcols", "rdd", "toJSON"]: |
| raise PySparkAttributeError( |
| error_class="JVM_ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name} |
| ) |
| |
| if name not in self.columns: |
| raise PySparkAttributeError( |
| error_class="ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name} |
| ) |
| |
| return self._col(name) |
| |
| @overload |
| def __getitem__(self, item: Union[int, str]) -> Column: |
| ... |
| |
| @overload |
| def __getitem__(self, item: Union[Column, List, Tuple]) -> ParentDataFrame: |
| ... |
| |
| def __getitem__( |
| self, item: Union[int, str, Column, List, Tuple] |
| ) -> Union[Column, ParentDataFrame]: |
| from pyspark.sql.connect.column import Column as ConnectColumn |
| |
| if isinstance(item, str): |
| if item == "*": |
| return ConnectColumn( |
| UnresolvedStar( |
| unparsed_target=None, |
| plan_id=self._plan._plan_id, |
| ) |
| ) |
| else: |
| # TODO: revisit vanilla Spark's Dataset.col |
| # if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) { |
| # colRegex(colName) |
| # } else { |
| # ConnectColumn(addDataFrameIdToCol(resolve(colName))) |
| # } |
| |
| # validate the column name |
| if not hasattr(self._session, "is_mock_session"): |
| from pyspark.sql.connect.types import verify_col_name |
| |
| # Try best to verify the column name with cached schema |
| # If fails, fall back to the server side validation |
| if not verify_col_name(item, self.schema): |
| self.select(item).isLocal() |
| |
| return self._col(item) |
| elif isinstance(item, Column): |
| return self.filter(item) |
| elif isinstance(item, (list, tuple)): |
| return self.select(*item) |
| elif isinstance(item, int): |
| return F.col(self.columns[item]) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE", |
| message_parameters={"arg_name": "item", "arg_type": type(item).__name__}, |
| ) |
| |
| def _col(self, name: str) -> Column: |
| from pyspark.sql.connect.column import Column as ConnectColumn |
| |
| return ConnectColumn( |
| ColumnReference( |
| unparsed_identifier=name, |
| plan_id=self._plan._plan_id, |
| ) |
| ) |
| |
| def __dir__(self) -> List[str]: |
| attrs = set(dir(DataFrame)) |
| attrs.update(self.columns) |
| return sorted(attrs) |
| |
| def collect(self) -> List[Row]: |
| table, schema = self._to_table() |
| |
| schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True) |
| |
| assert schema is not None and isinstance(schema, StructType) |
| |
| return ArrowTableToRowsConversion.convert(table, schema) |
| |
| def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]: |
| query = self._plan.to_proto(self._session.client) |
| table, schema = self._session.client.to_table(query, self._plan.observations) |
| assert table is not None |
| return (table, schema) |
| |
| def toArrow(self) -> "pa.Table": |
| table, _ = self._to_table() |
| return table |
| |
| def toPandas(self) -> "PandasDataFrameLike": |
| query = self._plan.to_proto(self._session.client) |
| return self._session.client.to_pandas(query, self._plan.observations) |
| |
| @property |
| def schema(self) -> StructType: |
        # Schema caching is correct in most cases. Connect is lazy by nature, which means we
        # only resolve the plan when it is submitted for execution or analysis; intermediate
        # resolved plans are not cached. If the input of the plan (a changed table, a view
        # redefinition, etc.) changes between the schema() call and a subsequent action, the
        # cached schema might be inconsistent with the final schema.
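        # For example, if a temp view is re-created with a different set of columns
        # after .schema has been called on a DataFrame that reads it, the cached
        # schema will no longer match what a later collect() returns.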
| if self._cached_schema is None: |
| query = self._plan.to_proto(self._session.client) |
| self._cached_schema = self._session.client.schema(query) |
| return self._cached_schema |
| |
| def isLocal(self) -> bool: |
| query = self._plan.to_proto(self._session.client) |
| result = self._session.client._analyze(method="is_local", plan=query).is_local |
| assert result is not None |
| return result |
| |
| @cached_property |
| def isStreaming(self) -> bool: |
| query = self._plan.to_proto(self._session.client) |
| result = self._session.client._analyze(method="is_streaming", plan=query).is_streaming |
| assert result is not None |
| return result |
| |
| def _tree_string(self, level: Optional[int] = None) -> str: |
| query = self._plan.to_proto(self._session.client) |
| result = self._session.client._analyze( |
| method="tree_string", plan=query, level=level |
| ).tree_string |
| assert result is not None |
| return result |
| |
| def printSchema(self, level: Optional[int] = None) -> None: |
| print(self._tree_string(level)) |
| |
| def inputFiles(self) -> List[str]: |
| query = self._plan.to_proto(self._session.client) |
| result = self._session.client._analyze(method="input_files", plan=query).input_files |
| assert result is not None |
| return result |
| |
| def to(self, schema: StructType) -> ParentDataFrame: |
| assert schema is not None |
| return DataFrame( |
| plan.ToSchema(child=self._plan, schema=schema), |
| session=self._session, |
| ) |
| |
| def toDF(self, *cols: str) -> ParentDataFrame: |
| for col_ in cols: |
| if not isinstance(col_, str): |
| raise PySparkTypeError( |
| error_class="NOT_LIST_OF_STR", |
| message_parameters={"arg_name": "cols", "arg_type": type(col_).__name__}, |
| ) |
| return DataFrame(plan.ToDF(self._plan, list(cols)), self._session) |
| |
| def transform( |
| self, func: Callable[..., ParentDataFrame], *args: Any, **kwargs: Any |
| ) -> ParentDataFrame: |
| result = func(self, *args, **kwargs) |
| assert isinstance( |
| result, DataFrame |
| ), "Func returned an instance of type [%s], " "should have been DataFrame." % type(result) |
| return result |
| |
| def _explain_string( |
| self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None |
| ) -> str: |
| if extended is not None and mode is not None: |
| raise PySparkValueError( |
| error_class="CANNOT_SET_TOGETHER", |
| message_parameters={"arg_list": "extended and mode"}, |
| ) |
| |
| # For the no argument case: df.explain() |
| is_no_argument = extended is None and mode is None |
| |
| # For the cases below: |
| # explain(True) |
| # explain(extended=False) |
| is_extended_case = isinstance(extended, bool) and mode is None |
| |
| # For the case when extended is mode: |
| # df.explain("formatted") |
| is_extended_as_mode = isinstance(extended, str) and mode is None |
| |
| # For the mode specified: |
| # df.explain(mode="formatted") |
| is_mode_case = extended is None and isinstance(mode, str) |
| |
| if not (is_no_argument or is_extended_case or is_extended_as_mode or is_mode_case): |
| argtypes = [str(type(arg)) for arg in [extended, mode] if arg is not None] |
| raise PySparkTypeError( |
| error_class="NOT_BOOL_OR_STR", |
| message_parameters={ |
| "arg_name": "extended (optional) and mode (optional)", |
| "arg_type": ", ".join(argtypes), |
| }, |
| ) |
| |
        # Set the explain mode depending on the given arguments.
| if is_no_argument: |
| explain_mode = "simple" |
| elif is_extended_case: |
| explain_mode = "extended" if extended else "simple" |
| elif is_mode_case: |
| explain_mode = cast(str, mode) |
| elif is_extended_as_mode: |
| explain_mode = cast(str, extended) |
| |
| query = self._plan.to_proto(self._session.client) |
| return self._session.client.explain_string(query, explain_mode) |
| |
| def explain( |
| self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None |
| ) -> None: |
| print(self._explain_string(extended=extended, mode=mode)) |
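
    # Illustrative mapping of the supported call patterns (derived from the argument handling
    # in _explain_string above); the mode strings themselves are interpreted server-side:
    #
    #   df.explain()                  -> "simple"
    #   df.explain(True)              -> "extended"
    #   df.explain(extended=False)    -> "simple"
    #   df.explain("codegen")         -> "codegen"
    #   df.explain(mode="formatted")  -> "formatted"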
| |
| def createTempView(self, name: str) -> None: |
| command = plan.CreateView( |
| child=self._plan, name=name, is_global=False, replace=False |
| ).command(session=self._session.client) |
| self._session.client.execute_command(command, self._plan.observations) |
| |
| def createOrReplaceTempView(self, name: str) -> None: |
| command = plan.CreateView( |
| child=self._plan, name=name, is_global=False, replace=True |
| ).command(session=self._session.client) |
| self._session.client.execute_command(command, self._plan.observations) |
| |
| def createGlobalTempView(self, name: str) -> None: |
| command = plan.CreateView( |
| child=self._plan, name=name, is_global=True, replace=False |
| ).command(session=self._session.client) |
| self._session.client.execute_command(command, self._plan.observations) |
| |
| def createOrReplaceGlobalTempView(self, name: str) -> None: |
| command = plan.CreateView( |
| child=self._plan, name=name, is_global=True, replace=True |
| ).command(session=self._session.client) |
| self._session.client.execute_command(command, self._plan.observations) |
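
    # Each of the four view methods above builds a CreateView command that is executed eagerly
    # via the client; only the is_global/replace flags differ. A small sketch (view names are
    # illustrative):
    #
    #   df.createOrReplaceTempView("people")
    #   spark.sql("SELECT count(*) FROM people")         # session-scoped view
    #   df.createGlobalTempView("people_global")
    #   spark.sql("SELECT * FROM global_temp.people_global")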
| |
| def cache(self) -> ParentDataFrame: |
| return self.persist() |
| |
| def persist( |
| self, |
        storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_DESER,
| ) -> ParentDataFrame: |
| relation = self._plan.plan(self._session.client) |
| self._session.client._analyze( |
| method="persist", relation=relation, storage_level=storageLevel |
| ) |
| return self |
| |
| @property |
| def storageLevel(self) -> StorageLevel: |
| relation = self._plan.plan(self._session.client) |
| storage_level = self._session.client._analyze( |
| method="get_storage_level", relation=relation |
| ).storage_level |
| assert storage_level is not None |
| return storage_level |
| |
| def unpersist(self, blocking: bool = False) -> ParentDataFrame: |
| relation = self._plan.plan(self._session.client) |
| self._session.client._analyze(method="unpersist", relation=relation, blocking=blocking) |
| return self |
| |
| @property |
| def is_cached(self) -> bool: |
| return self.storageLevel != StorageLevel.NONE |
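
    # Sketch of the caching round trip provided by the methods above; storage levels are the
    # standard pyspark.StorageLevel constants and persist/unpersist are eager Analyze calls:
    #
    #   df.persist(StorageLevel.MEMORY_ONLY)  # returns the same DataFrame
    #   df.storageLevel                        # queried from the server
    #   df.is_cached                           # True while storageLevel != StorageLevel.NONE
    #   df.unpersist()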
| |
| def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: |
| query = self._plan.to_proto(self._session.client) |
| |
| schema: Optional[StructType] = None |
| for schema_or_table in self._session.client.to_table_as_iterator( |
| query, self._plan.observations |
| ): |
| if isinstance(schema_or_table, StructType): |
| assert schema is None |
| schema = schema_or_table |
| else: |
| assert isinstance(schema_or_table, pa.Table) |
| table = schema_or_table |
| if schema is None: |
| schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=True) |
| yield from ArrowTableToRowsConversion.convert(table, schema) |
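
    # Note on toLocalIterator above: results are streamed from the server as Arrow tables and
    # converted to Rows batch by batch; the stream may start with an explicit schema, otherwise
    # the schema is inferred from the first Arrow table. The prefetchPartitions argument is
    # accepted for API compatibility but is not used by this Connect implementation.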
| |
| def pandas_api( |
| self, index_col: Optional[Union[str, List[str]]] = None |
| ) -> "PandasOnSparkDataFrame": |
| from pyspark.pandas.namespace import _get_index_map |
| from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame |
| from pyspark.pandas.internal import InternalFrame |
| |
| index_spark_columns, index_names = _get_index_map(self, index_col) |
| internal = InternalFrame( |
| spark_frame=self, |
| index_spark_columns=index_spark_columns, |
| index_names=index_names, # type: ignore[arg-type] |
| ) |
| return PandasOnSparkDataFrame(internal) |
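
    # pandas_api above wraps this DataFrame in a pandas-on-Spark InternalFrame; if index_col is
    # given, those columns become the pandas-on-Spark index, otherwise a default index is
    # attached by pandas-on-Spark.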
| |
| def registerTempTable(self, name: str) -> None: |
| warnings.warn("Deprecated in 2.0, use createOrReplaceTempView instead.", FutureWarning) |
| self.createOrReplaceTempView(name) |
| |
| def _map_partitions( |
| self, |
| func: "PandasMapIterFunction", |
| schema: Union[StructType, str], |
| evalType: int, |
| barrier: bool, |
| profile: Optional[ResourceProfile], |
| ) -> ParentDataFrame: |
| from pyspark.sql.connect.udf import UserDefinedFunction |
| |
| _validate_pandas_udf(func, evalType) |
| udf_obj = UserDefinedFunction( |
| func, |
| returnType=schema, |
| evalType=evalType, |
| ) |
| |
| return DataFrame( |
| plan.MapPartitions( |
| child=self._plan, |
| function=udf_obj, |
| cols=self.columns, |
| is_barrier=barrier, |
| profile=profile, |
| ), |
| session=self._session, |
| ) |
| |
| def mapInPandas( |
| self, |
| func: "PandasMapIterFunction", |
| schema: Union[StructType, str], |
| barrier: bool = False, |
| profile: Optional[ResourceProfile] = None, |
| ) -> ParentDataFrame: |
| return self._map_partitions( |
| func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, barrier, profile |
| ) |
| |
| def mapInArrow( |
| self, |
| func: "ArrowMapIterFunction", |
| schema: Union[StructType, str], |
| barrier: bool = False, |
| profile: Optional[ResourceProfile] = None, |
| ) -> ParentDataFrame: |
| return self._map_partitions( |
| func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, barrier, profile |
| ) |
| |
| def foreach(self, f: Callable[[Row], None]) -> None: |
| def foreach_func(row: Any) -> None: |
| f(row) |
| |
| self.select(F.struct(*self.schema.fieldNames()).alias("row")).select( |
| F.udf(foreach_func, StructType())("row") # type: ignore[arg-type] |
| ).collect() |
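
    # Note on foreach above: on Connect it is implemented by packing each row into a struct
    # column, applying a Python UDF to that column, and collecting; the user function runs on
    # the executors and its return value is discarded.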
| |
| def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: |
| schema = self.schema |
        field_converters = [
            ArrowTableToRowsConversion._create_converter(field.dataType)
            for field in schema.fields
        ]
| |
| def foreach_partition_func(itr: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]: |
| def flatten() -> Iterator[Row]: |
| for table in itr: |
| columnar_data = [column.to_pylist() for column in table.columns] |
| for i in range(0, table.num_rows): |
| values = [ |
| field_converters[j](columnar_data[j][i]) |
| for j in range(table.num_columns) |
| ] |
| yield _create_row(fields=schema.fieldNames(), values=values) |
| |
| f(flatten()) |
| return iter([]) |
| |
| self.mapInArrow(foreach_partition_func, schema=StructType()).collect() |
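
    # Note on foreachPartition above: it reuses mapInArrow with an empty result schema; each
    # incoming Arrow batch is unpacked column by column, converted to Rows with the per-field
    # converters, and handed to the user function as a single iterator per partition. Nothing
    # is returned to the driver.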
| |
| @property |
| def writeStream(self) -> DataStreamWriter: |
| return DataStreamWriter(plan=self._plan, session=self._session) |
| |
| def sameSemantics(self, other: ParentDataFrame) -> bool: |
| if not isinstance(other, DataFrame): |
| raise PySparkTypeError( |
| error_class="NOT_DATAFRAME", |
| message_parameters={"arg_name": "other", "arg_type": type(other).__name__}, |
| ) |
| self._check_same_session(other) |
| return self._session.client.same_semantics( |
| plan=self._plan.to_proto(self._session.client), |
| other=other._plan.to_proto(other._session.client), |
| ) |
| |
| def semanticHash(self) -> int: |
| return self._session.client.semantic_hash( |
| plan=self._plan.to_proto(self._session.client), |
| ) |
| |
| def writeTo(self, table: str) -> "DataFrameWriterV2": |
| return DataFrameWriterV2(self._plan, self._session, table) |
| |
| def offset(self, n: int) -> ParentDataFrame: |
| return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session) |
| |
| if not is_remote_only(): |
| |
| def checkpoint(self, eager: bool = True) -> "DataFrame": |
| raise PySparkNotImplementedError( |
| error_class="NOT_IMPLEMENTED", |
| message_parameters={"feature": "checkpoint()"}, |
| ) |
| |
| def localCheckpoint(self, eager: bool = True) -> "DataFrame": |
| raise PySparkNotImplementedError( |
| error_class="NOT_IMPLEMENTED", |
| message_parameters={"feature": "localCheckpoint()"}, |
| ) |
| |
| def toJSON(self, use_unicode: bool = True) -> "RDD[str]": |
| raise PySparkNotImplementedError( |
| error_class="NOT_IMPLEMENTED", |
| message_parameters={"feature": "toJSON()"}, |
| ) |
| |
| @property |
| def rdd(self) -> "RDD[Row]": |
| raise PySparkNotImplementedError( |
| error_class="NOT_IMPLEMENTED", |
| message_parameters={"feature": "rdd"}, |
| ) |
| |
| |
| class DataFrameNaFunctions(ParentDataFrameNaFunctions): |
| def __init__(self, df: ParentDataFrame): |
| self.df = df |
| |
| def fill( |
| self, |
| value: Union["LiteralType", Dict[str, "LiteralType"]], |
| subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None, |
| ) -> ParentDataFrame: |
| return self.df.fillna(value=value, subset=subset) # type: ignore[arg-type] |
| |
| def drop( |
| self, |
| how: str = "any", |
| thresh: Optional[int] = None, |
| subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None, |
| ) -> ParentDataFrame: |
| return self.df.dropna(how=how, thresh=thresh, subset=subset) |
| |
| def replace( |
| self, |
| to_replace: Union[List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]], |
| value: Optional[ |
| Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType] |
| ] = _NoValue, |
| subset: Optional[List[str]] = None, |
| ) -> ParentDataFrame: |
| return self.df.replace(to_replace, value, subset) # type: ignore[arg-type] |
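
# DataFrameNaFunctions above is a thin shim: each method forwards to the corresponding
# DataFrame method. Typical (illustrative) use through the df.na accessor:
#
#   df.na.fill(0, subset=["age"])
#   df.na.drop(how="all")
#   df.na.replace({"N/A": None})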
| |
| |
| class DataFrameStatFunctions(ParentDataFrameStatFunctions): |
| def __init__(self, df: ParentDataFrame): |
| self.df = df |
| |
| def cov(self, col1: str, col2: str) -> float: |
| return self.df.cov(col1, col2) |
| |
| def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float: |
| return self.df.corr(col1, col2, method) |
| |
| def approxQuantile( |
| self, |
| col: Union[str, List[str], Tuple[str]], |
| probabilities: Union[List[float], Tuple[float]], |
| relativeError: float, |
| ) -> Union[List[float], List[List[float]]]: |
| return self.df.approxQuantile(col, probabilities, relativeError) |
| |
| def crosstab(self, col1: str, col2: str) -> ParentDataFrame: |
| return self.df.crosstab(col1, col2) |
| |
| def freqItems( |
| self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None |
| ) -> ParentDataFrame: |
| return self.df.freqItems(cols, support) |
| |
| def sampleBy( |
| self, col: str, fractions: Dict[Any, float], seed: Optional[int] = None |
| ) -> ParentDataFrame: |
| return self.df.sampleBy(col, fractions, seed) |
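
# DataFrameStatFunctions above likewise only delegates to DataFrame. Illustrative use through
# the df.stat accessor (column names are hypothetical):
#
#   df.stat.corr("age", "height")
#   df.stat.approxQuantile("age", [0.25, 0.5, 0.75], 0.05)
#   df.stat.crosstab("age", "gender")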
| |
| |
| def _test() -> None: |
| import os |
| import sys |
| import doctest |
| from pyspark.util import is_remote_only |
| from pyspark.sql import SparkSession as PySparkSession |
| import pyspark.sql.dataframe |
| |
    # It inherits docstrings, but doctest cannot detect them, so we run
    # the parent class's doctests here directly.
| os.chdir(os.environ["SPARK_HOME"]) |
| |
| globs = pyspark.sql.dataframe.__dict__.copy() |
| |
| if not is_remote_only(): |
| del pyspark.sql.dataframe.DataFrame.toJSON.__doc__ |
| del pyspark.sql.dataframe.DataFrame.rdd.__doc__ |
| del pyspark.sql.dataframe.DataFrame.checkpoint.__doc__ |
| del pyspark.sql.dataframe.DataFrame.localCheckpoint.__doc__ |
| |
| globs["spark"] = ( |
| PySparkSession.builder.appName("sql.connect.dataframe tests") |
| .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) |
| .getOrCreate() |
| ) |
| |
| (failure_count, test_count) = doctest.testmod( |
| pyspark.sql.dataframe, |
| globs=globs, |
| optionflags=doctest.ELLIPSIS |
| | doctest.NORMALIZE_WHITESPACE |
| | doctest.IGNORE_EXCEPTION_DETAIL, |
| ) |
| |
| globs["spark"].stop() |
| |
| if failure_count: |
| sys.exit(-1) |
| |
| |
| if __name__ == "__main__": |
| _test() |