#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import json
import sys
import random
import warnings
from collections.abc import Iterable
from functools import reduce
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
overload,
TYPE_CHECKING,
)
from pyspark import _NoValue
from pyspark.resource import ResourceProfile
from pyspark._globals import _NoValueType
from pyspark.errors import (
PySparkTypeError,
PySparkValueError,
PySparkIndexError,
PySparkAttributeError,
)
from pyspark.util import (
_load_from_socket,
_local_iterator_from_socket,
)
from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.column import Column
from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.streaming import DataStreamWriter
from pyspark.sql.types import (
StructType,
Row,
_parse_datatype_json_string,
)
from pyspark.sql.dataframe import (
DataFrame as ParentDataFrame,
DataFrameNaFunctions as ParentDataFrameNaFunctions,
DataFrameStatFunctions as ParentDataFrameStatFunctions,
)
from pyspark.sql.utils import get_active_spark_context, toJArray
from pyspark.sql.pandas.conversion import PandasConversionMixin
from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
import pyarrow as pa
from pyspark.core.rdd import RDD
from pyspark.core.context import SparkContext
from pyspark._typing import PrimitiveType
from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
from pyspark.sql._typing import (
ColumnOrName,
ColumnOrNameOrOrdinal,
LiteralType,
OptionalPrimitiveType,
)
from pyspark.sql.pandas._typing import (
PandasMapIterFunction,
ArrowMapIterFunction,
DataFrameLike as PandasDataFrameLike,
)
from pyspark.sql.context import SQLContext
from pyspark.sql.session import SparkSession
from pyspark.sql.group import GroupedData
from pyspark.sql.observation import Observation
class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
def __new__(
cls,
jdf: "JavaObject",
sql_ctx: Union["SQLContext", "SparkSession"],
) -> "DataFrame":
self = object.__new__(cls)
self.__init__(jdf, sql_ctx) # type: ignore[misc]
return self
def __init__(
self,
jdf: "JavaObject",
sql_ctx: Union["SQLContext", "SparkSession"],
):
from pyspark.sql.context import SQLContext
self._sql_ctx: Optional["SQLContext"] = None
if isinstance(sql_ctx, SQLContext):
assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
assert isinstance(sql_ctx, SQLContext)
# We should remove this if-else branch in a future release, and rename
# sql_ctx to session in the constructor. This is an internal code path, but
# it was kept with a warning because it is used heavily by third-party libraries.
warnings.warn("DataFrame constructor is internal. Do not directly use it.")
self._sql_ctx = sql_ctx
session = sql_ctx.sparkSession
else:
session = sql_ctx
self._session: "SparkSession" = session
self._sc: "SparkContext" = sql_ctx._sc
self._jdf: "JavaObject" = jdf
self.is_cached = False
# initialized lazily
self._schema: Optional[StructType] = None
self._lazy_rdd: Optional["RDD[Row]"] = None
# Track whether _repr_html_ is supported; we use this to avoid calling _jdf twice
# from __repr__ and _repr_html_ when eager evaluation is enabled.
self._support_repr_html = False
@property
def sql_ctx(self) -> "SQLContext":
from pyspark.sql.context import SQLContext
warnings.warn(
"DataFrame.sql_ctx is an internal property, and will be removed "
"in future releases. Use DataFrame.sparkSession instead."
)
if self._sql_ctx is None:
self._sql_ctx = SQLContext._get_or_create(self._sc)
return self._sql_ctx
@property
def sparkSession(self) -> "SparkSession":
return self._session
@property
def rdd(self) -> "RDD[Row]":
from pyspark.core.rdd import RDD
if self._lazy_rdd is None:
jrdd = self._jdf.javaToPython()
self._lazy_rdd = RDD(
jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer())
)
return self._lazy_rdd
@property
def na(self) -> ParentDataFrameNaFunctions:
return DataFrameNaFunctions(self)
@property
def stat(self) -> ParentDataFrameStatFunctions:
return DataFrameStatFunctions(self)
def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
from pyspark.core.rdd import RDD
rdd = self._jdf.toJSON()
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
def registerTempTable(self, name: str) -> None:
warnings.warn("Deprecated in 2.0, use createOrReplaceTempView instead.", FutureWarning)
self._jdf.createOrReplaceTempView(name)
def createTempView(self, name: str) -> None:
self._jdf.createTempView(name)
def createOrReplaceTempView(self, name: str) -> None:
self._jdf.createOrReplaceTempView(name)
def createGlobalTempView(self, name: str) -> None:
self._jdf.createGlobalTempView(name)
def createOrReplaceGlobalTempView(self, name: str) -> None:
self._jdf.createOrReplaceGlobalTempView(name)
@property
def write(self) -> DataFrameWriter:
return DataFrameWriter(self)
@property
def writeStream(self) -> DataStreamWriter:
return DataStreamWriter(self)
@property
def schema(self) -> StructType:
if self._schema is None:
try:
self._schema = cast(
StructType, _parse_datatype_json_string(self._jdf.schema().json())
)
except Exception as e:
raise PySparkValueError(
error_class="CANNOT_PARSE_DATATYPE",
message_parameters={"error": str(e)},
)
return self._schema
def printSchema(self, level: Optional[int] = None) -> None:
if level:
print(self._jdf.schema().treeString(level))
else:
print(self._jdf.schema().treeString())
def explain(
self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None
) -> None:
if extended is not None and mode is not None:
raise PySparkValueError(
error_class="CANNOT_SET_TOGETHER",
message_parameters={"arg_list": "extended and mode"},
)
# For the no argument case: df.explain()
is_no_argument = extended is None and mode is None
# For the cases below:
# explain(True)
# explain(extended=False)
is_extended_case = isinstance(extended, bool) and mode is None
# For the case when extended is mode:
# df.explain("formatted")
is_extended_as_mode = isinstance(extended, str) and mode is None
# For the mode specified:
# df.explain(mode="formatted")
is_mode_case = extended is None and isinstance(mode, str)
if not (is_no_argument or is_extended_case or is_extended_as_mode or is_mode_case):
if (extended is not None) and (not isinstance(extended, (bool, str))):
raise PySparkTypeError(
error_class="NOT_BOOL_OR_STR",
message_parameters={
"arg_name": "extended",
"arg_type": type(extended).__name__,
},
)
if (mode is not None) and (not isinstance(mode, str)):
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "mode", "arg_type": type(mode).__name__},
)
# Sets an explain mode depending on a given argument
if is_no_argument:
explain_mode = "simple"
elif is_extended_case:
explain_mode = "extended" if extended else "simple"
elif is_mode_case:
explain_mode = cast(str, mode)
elif is_extended_as_mode:
explain_mode = cast(str, extended)
assert self._sc._jvm is not None
print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode))
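# A minimal usage sketch of the argument shapes handled above (hedged;
# assumes a DataFrame `df` from an active SparkSession):
#
#   df.explain()                  # no argument -> "simple" mode
#   df.explain(True)              # extended=True -> "extended" mode
#   df.explain("formatted")       # a string as `extended` is treated as the mode
#   df.explain(mode="cost")       # explicit mode keyword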
def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession)
def isLocal(self) -> bool:
return self._jdf.isLocal()
@property
def isStreaming(self) -> bool:
return self._jdf.isStreaming()
def isEmpty(self) -> bool:
return self._jdf.isEmpty()
def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
print(self._show_string(n, truncate, vertical))
def _show_string(
self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False
) -> str:
if not isinstance(n, int) or isinstance(n, bool):
raise PySparkTypeError(
error_class="NOT_INT",
message_parameters={"arg_name": "n", "arg_type": type(n).__name__},
)
if not isinstance(vertical, bool):
raise PySparkTypeError(
error_class="NOT_BOOL",
message_parameters={"arg_name": "vertical", "arg_type": type(vertical).__name__},
)
if isinstance(truncate, bool) and truncate:
return self._jdf.showString(n, 20, vertical)
else:
try:
int_truncate = int(truncate)
except ValueError:
raise PySparkTypeError(
error_class="NOT_BOOL",
message_parameters={
"arg_name": "truncate",
"arg_type": type(truncate).__name__,
},
)
return self._jdf.showString(n, int_truncate, vertical)
def __repr__(self) -> str:
if not self._support_repr_html and self.sparkSession._jconf.isReplEagerEvalEnabled():
vertical = False
return self._jdf.showString(
self.sparkSession._jconf.replEagerEvalMaxNumRows(),
self.sparkSession._jconf.replEagerEvalTruncate(),
vertical,
)
else:
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
def _repr_html_(self) -> Optional[str]:
"""Returns a :class:`DataFrame` with html code when you enabled eager evaluation
by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are
using support eager evaluation with HTML.
"""
if not self._support_repr_html:
self._support_repr_html = True
if self.sparkSession._jconf.isReplEagerEvalEnabled():
return self._jdf.htmlString(
self.sparkSession._jconf.replEagerEvalMaxNumRows(),
self.sparkSession._jconf.replEagerEvalTruncate(),
)
else:
return None
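# A hedged configuration sketch: the HTML rendering above is driven by the
# REPL eager-evaluation configs read from _jconf; in a notebook this might be
# enabled with:
#
#   spark.conf.set("spark.sql.repl.eagerEval.enabled", True)
#   spark.conf.set("spark.sql.repl.eagerEval.maxNumRows", 20)
#   df  # a supporting REPL then displays the HTML produced by _repr_html_()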
def checkpoint(self, eager: bool = True) -> ParentDataFrame:
jdf = self._jdf.checkpoint(eager)
return DataFrame(jdf, self.sparkSession)
def localCheckpoint(self, eager: bool = True) -> ParentDataFrame:
jdf = self._jdf.localCheckpoint(eager)
return DataFrame(jdf, self.sparkSession)
def withWatermark(self, eventTime: str, delayThreshold: str) -> ParentDataFrame:
if not eventTime or type(eventTime) is not str:
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "eventTime", "arg_type": type(eventTime).__name__},
)
if not delayThreshold or type(delayThreshold) is not str:
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={
"arg_name": "delayThreshold",
"arg_type": type(delayThreshold).__name__,
},
)
jdf = self._jdf.withWatermark(eventTime, delayThreshold)
return DataFrame(jdf, self.sparkSession)
def hint(
self, name: str, *parameters: Union["PrimitiveType", "Column", List["PrimitiveType"]]
) -> ParentDataFrame:
if len(parameters) == 1 and isinstance(parameters[0], list):
parameters = parameters[0] # type: ignore[assignment]
if not isinstance(name, str):
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "name", "arg_type": type(name).__name__},
)
allowed_types = (str, float, int, Column, list)
allowed_primitive_types = (str, float, int)
allowed_types_repr = ", ".join(
[t.__name__ for t in allowed_types[:-1]]
+ ["list[" + t.__name__ + "]" for t in allowed_primitive_types]
)
for p in parameters:
if not isinstance(p, allowed_types):
raise PySparkTypeError(
error_class="DISALLOWED_TYPE_FOR_CONTAINER",
message_parameters={
"arg_name": "parameters",
"arg_type": type(parameters).__name__,
"allowed_types": allowed_types_repr,
"item_type": type(p).__name__,
},
)
if isinstance(p, list):
if not all(isinstance(e, allowed_primitive_types) for e in p):
raise PySparkTypeError(
error_class="DISALLOWED_TYPE_FOR_CONTAINER",
message_parameters={
"arg_name": "parameters",
"arg_type": type(parameters).__name__,
"allowed_types": allowed_types_repr,
"item_type": type(p).__name__ + "[" + type(p[0]).__name__ + "]",
},
)
def _converter(parameter: Union[str, list, float, int, Column]) -> Any:
if isinstance(parameter, Column):
return _to_java_column(parameter)
elif isinstance(parameter, list):
# For list input, we assume all elements share a single type.
# An empty list is converted into an empty long[] on the JVM side.
gateway = self._sc._gateway
assert gateway is not None
jclass = gateway.jvm.long
if len(parameter) >= 1:
mapping = {
str: gateway.jvm.java.lang.String,
float: gateway.jvm.double,
int: gateway.jvm.long,
}
jclass = mapping[type(parameter[0])]
return toJArray(gateway, jclass, parameter)
else:
return parameter
jdf = self._jdf.hint(name, self._jseq(parameters, _converter))
return DataFrame(jdf, self.sparkSession)
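# A minimal usage sketch (hedged; `spark` is an active SparkSession): hint
# names and parameters are passed through to the optimizer, e.g. a broadcast
# join hint on the smaller side:
#
#   small = spark.range(100)
#   large = spark.range(1_000_000)
#   large.join(small.hint("broadcast"), "id").explain()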
def count(self) -> int:
return int(self._jdf.count())
def collect(self) -> List[Row]:
with SCCallSiteSync(self._sc):
sock_info = self._jdf.collectToPython()
return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))
def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
with SCCallSiteSync(self._sc):
sock_info = self._jdf.toPythonIterator(prefetchPartitions)
return _local_iterator_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))
def limit(self, num: int) -> ParentDataFrame:
jdf = self._jdf.limit(num)
return DataFrame(jdf, self.sparkSession)
def offset(self, num: int) -> ParentDataFrame:
jdf = self._jdf.offset(num)
return DataFrame(jdf, self.sparkSession)
def take(self, num: int) -> List[Row]:
return self.limit(num).collect()
def tail(self, num: int) -> List[Row]:
with SCCallSiteSync(self._sc):
sock_info = self._jdf.tailToPython(num)
return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))
def foreach(self, f: Callable[[Row], None]) -> None:
self.rdd.foreach(f)
def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
self.rdd.foreachPartition(f) # type: ignore[arg-type]
def cache(self) -> ParentDataFrame:
self.is_cached = True
self._jdf.cache()
return self
def persist(
self,
storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER),
) -> ParentDataFrame:
self.is_cached = True
javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
self._jdf.persist(javaStorageLevel)
return self
@property
def storageLevel(self) -> StorageLevel:
java_storage_level = self._jdf.storageLevel()
storage_level = StorageLevel(
java_storage_level.useDisk(),
java_storage_level.useMemory(),
java_storage_level.useOffHeap(),
java_storage_level.deserialized(),
java_storage_level.replication(),
)
return storage_level
def unpersist(self, blocking: bool = False) -> ParentDataFrame:
self.is_cached = False
self._jdf.unpersist(blocking)
return self
def coalesce(self, numPartitions: int) -> ParentDataFrame:
return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession)
@overload
def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
...
@overload
def repartition(self, *cols: "ColumnOrName") -> ParentDataFrame:
...
def repartition( # type: ignore[misc]
self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> ParentDataFrame:
if isinstance(numPartitions, int):
if len(cols) == 0:
return DataFrame(self._jdf.repartition(numPartitions), self.sparkSession)
else:
return DataFrame(
self._jdf.repartition(numPartitions, self._jcols(*cols)),
self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
cols = (numPartitions,) + cols
return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sparkSession)
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_STR",
message_parameters={
"arg_name": "numPartitions",
"arg_type": type(numPartitions).__name__,
},
)
@overload
def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
...
@overload
def repartitionByRange(self, *cols: "ColumnOrName") -> ParentDataFrame:
...
def repartitionByRange( # type: ignore[misc]
self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> ParentDataFrame:
if isinstance(numPartitions, int):
if len(cols) == 0:
raise PySparkValueError(
error_class="CANNOT_BE_EMPTY",
message_parameters={"item": "partition-by expression"},
)
else:
return DataFrame(
self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)),
self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
cols = (numPartitions,) + cols
return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sparkSession)
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_INT_OR_STR",
message_parameters={
"arg_name": "numPartitions",
"arg_type": type(numPartitions).__name__,
},
)
def distinct(self) -> ParentDataFrame:
return DataFrame(self._jdf.distinct(), self.sparkSession)
@overload
def sample(self, fraction: float, seed: Optional[int] = ...) -> ParentDataFrame:
...
@overload
def sample(
self,
withReplacement: Optional[bool],
fraction: float,
seed: Optional[int] = ...,
) -> ParentDataFrame:
...
def sample( # type: ignore[misc]
self,
withReplacement: Optional[Union[float, bool]] = None,
fraction: Optional[Union[int, float]] = None,
seed: Optional[int] = None,
) -> ParentDataFrame:
# For the cases below:
# sample(True, 0.5 [, seed])
# sample(True, fraction=0.5 [, seed])
# sample(withReplacement=False, fraction=0.5 [, seed])
is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float)
# For the case below:
# sample(fraction=0.5 [, seed])
is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float)
# For the case below:
# sample(0.5 [, seed])
is_withReplacement_omitted_args = isinstance(withReplacement, float)
if not (
is_withReplacement_set
or is_withReplacement_omitted_kwargs
or is_withReplacement_omitted_args
):
argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]]
raise PySparkTypeError(
error_class="NOT_BOOL_OR_FLOAT_OR_INT",
message_parameters={
"arg_name": "withReplacement (optional), "
+ "fraction (required) and seed (optional)",
"arg_type": ", ".join(argtypes),
},
)
if is_withReplacement_omitted_args:
if fraction is not None:
seed = cast(int, fraction)
fraction = withReplacement
withReplacement = None
seed = int(seed) if seed is not None else None
args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
jdf = self._jdf.sample(*args)
return DataFrame(jdf, self.sparkSession)
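# A minimal usage sketch of the three accepted call shapes documented above
# (hedged; assumes an existing DataFrame `df`):
#
#   df.sample(0.1)                                    # fraction only
#   df.sample(fraction=0.1, seed=3)                   # fraction as a keyword, with seed
#   df.sample(withReplacement=True, fraction=0.1, seed=3)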
def sampleBy(
self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
) -> ParentDataFrame:
if isinstance(col, str):
col = Column(col)
elif not isinstance(col, Column):
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_STR",
message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
)
if not isinstance(fractions, dict):
raise PySparkTypeError(
error_class="NOT_DICT",
message_parameters={"arg_name": "fractions", "arg_type": type(fractions).__name__},
)
for k, v in fractions.items():
if not isinstance(k, (float, int, str)):
raise PySparkTypeError(
error_class="DISALLOWED_TYPE_FOR_CONTAINER",
message_parameters={
"arg_name": "fractions",
"arg_type": type(fractions).__name__,
"allowed_types": "float, int, str",
"item_type": type(k).__name__,
},
)
fractions[k] = float(v)
col = col._jc
seed = seed if seed is not None else random.randint(0, sys.maxsize)
return DataFrame(
self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sparkSession
)
def randomSplit(
self, weights: List[float], seed: Optional[int] = None
) -> List[ParentDataFrame]:
for w in weights:
if w < 0.0:
raise PySparkValueError(
error_class="VALUE_NOT_POSITIVE",
message_parameters={"arg_name": "weights", "arg_value": str(w)},
)
seed = seed if seed is not None else random.randint(0, sys.maxsize)
df_array = self._jdf.randomSplit(
_to_list(self.sparkSession._sc, cast(List["ColumnOrName"], weights)), int(seed)
)
return [DataFrame(df, self.sparkSession) for df in df_array]
@property
def dtypes(self) -> List[Tuple[str, str]]:
return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
@property
def columns(self) -> List[str]:
return [f.name for f in self.schema.fields]
def colRegex(self, colName: str) -> Column:
if not isinstance(colName, str):
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "colName", "arg_type": type(colName).__name__},
)
jc = self._jdf.colRegex(colName)
return Column(jc)
def to(self, schema: StructType) -> ParentDataFrame:
assert schema is not None
jschema = self._jdf.sparkSession().parseDataType(schema.json())
return DataFrame(self._jdf.to(jschema), self.sparkSession)
def alias(self, alias: str) -> ParentDataFrame:
assert isinstance(alias, str), "alias should be a string"
return DataFrame(getattr(self._jdf, "as")(alias), self.sparkSession)
def crossJoin(self, other: ParentDataFrame) -> ParentDataFrame:
jdf = self._jdf.crossJoin(other._jdf)
return DataFrame(jdf, self.sparkSession)
def join(
self,
other: ParentDataFrame,
on: Optional[Union[str, List[str], Column, List[Column]]] = None,
how: Optional[str] = None,
) -> ParentDataFrame:
if on is not None and not isinstance(on, list):
on = [on] # type: ignore[assignment]
if on is not None:
if isinstance(on[0], str):
on = self._jseq(cast(List[str], on))
else:
assert isinstance(on[0], Column), "on should be Column or list of Column"
on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on))
on = on._jc
if on is None and how is None:
jdf = self._jdf.join(other._jdf)
else:
if how is None:
how = "inner"
if on is None:
on = self._jseq([])
assert isinstance(how, str), "how should be a string"
jdf = self._jdf.join(other._jdf, on, how)
return DataFrame(jdf, self.sparkSession)
# TODO(SPARK-22947): Fix the DataFrame API.
def _joinAsOf(
self,
other: ParentDataFrame,
leftAsOfColumn: Union[str, Column],
rightAsOfColumn: Union[str, Column],
on: Optional[Union[str, List[str], Column, List[Column]]] = None,
how: Optional[str] = None,
*,
tolerance: Optional[Column] = None,
allowExactMatches: bool = True,
direction: str = "backward",
) -> ParentDataFrame:
"""
Perform an as-of join.
This is similar to a left-join except that we match on the nearest
key rather than equal keys.
.. versionchanged:: 4.0.0
Supports Spark Connect.
Parameters
----------
other : :class:`DataFrame`
Right side of the join
leftAsOfColumn : str or :class:`Column`
a string for the as-of join column name, or a Column
rightAsOfColumn : str or :class:`Column`
a string for the as-of join column name, or a Column
on : str, list or :class:`Column`, optional
a string for the join column name, a list of column names,
a join expression (Column), or a list of Columns.
If `on` is a string or a list of strings indicating the name of the join column(s),
the column(s) must exist on both sides, and this performs an equi-join.
how : str, optional
default ``inner``. Must be one of: ``inner`` and ``left``.
tolerance : :class:`Column`, optional
an asof tolerance within this range; must be compatible
with the merge index.
allowExactMatches : bool, optional
default ``True``.
direction : str, optional
default ``backward``. Must be one of: ``backward``, ``forward``, and ``nearest``.
Examples
--------
The following performs an as-of join between ``left`` and ``right``.
>>> left = spark.createDataFrame([(1, "a"), (5, "b"), (10, "c")], ["a", "left_val"])
>>> right = spark.createDataFrame([(1, 1), (2, 2), (3, 3), (6, 6), (7, 7)],
... ["a", "right_val"])
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a"
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1),
Row(a=5, left_val='b', right_val=3),
Row(a=10, left_val='c', right_val=7)]
>>> from pyspark.sql import functions as sf
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", tolerance=sf.lit(1)
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1)]
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", how="left", tolerance=sf.lit(1)
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1),
Row(a=5, left_val='b', right_val=None),
Row(a=10, left_val='c', right_val=None)]
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", allowExactMatches=False
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=5, left_val='b', right_val=3),
Row(a=10, left_val='c', right_val=7)]
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", direction="forward"
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1),
Row(a=5, left_val='b', right_val=6)]
"""
if isinstance(leftAsOfColumn, str):
leftAsOfColumn = self[leftAsOfColumn]
left_as_of_jcol = leftAsOfColumn._jc
if isinstance(rightAsOfColumn, str):
rightAsOfColumn = other[rightAsOfColumn]
right_as_of_jcol = rightAsOfColumn._jc
if on is not None and not isinstance(on, list):
on = [on] # type: ignore[assignment]
if on is not None:
if isinstance(on[0], str):
on = self._jseq(cast(List[str], on))
else:
assert isinstance(on[0], Column), "on should be Column or list of Column"
on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on))
on = on._jc
if how is None:
how = "inner"
assert isinstance(how, str), "how should be a string"
if tolerance is not None:
assert isinstance(tolerance, Column), "tolerance should be Column"
tolerance = tolerance._jc
jdf = self._jdf.joinAsOf(
other._jdf,
left_as_of_jcol,
right_as_of_jcol,
on,
how,
tolerance,
allowExactMatches,
direction,
)
return DataFrame(jdf, self.sparkSession)
def sortWithinPartitions(
self,
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
return DataFrame(jdf, self.sparkSession)
def sort(
self,
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
return DataFrame(jdf, self.sparkSession)
orderBy = sort
def _jseq(
self,
cols: Sequence,
converter: Optional[Callable[..., Union["PrimitiveType", "JavaObject"]]] = None,
) -> "JavaObject":
"""Return a JVM Seq of Columns from a list of Column or names"""
return _to_seq(self.sparkSession._sc, cols, converter)
def _jmap(self, jm: Dict) -> "JavaObject":
"""Return a JVM Scala Map from a dict"""
return _to_scala_map(self.sparkSession._sc, jm)
def _jcols(self, *cols: "ColumnOrName") -> "JavaObject":
"""Return a JVM Seq of Columns from a list of Column or column names
If `cols` has only one list in it, cols[0] will be used as the list.
"""
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
return self._jseq(cols, _to_java_column)
def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> "JavaObject":
"""Return a JVM Seq of Columns from a list of Column or column names or column ordinals.
If `cols` has only one list in it, cols[0] will be used as the list.
"""
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
_cols = []
for c in cols:
if isinstance(c, int) and not isinstance(c, bool):
if c < 1:
raise PySparkIndexError(
error_class="INDEX_NOT_POSITIVE", message_parameters={"index": str(c)}
)
# ordinal is 1-based
_cols.append(self[c - 1])
else:
_cols.append(c) # type: ignore[arg-type]
return self._jseq(_cols, _to_java_column)
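# A hedged illustration of the 1-based ordinal handling above: an integer is
# resolved against this DataFrame's columns, so the two calls below should be
# equivalent:
#
#   df.groupBy(1).count()
#   df.groupBy(df.columns[0]).count()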
def _sort_cols(
self,
cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
kwargs: Dict[str, Any],
) -> "JavaObject":
"""Return a JVM Seq of Columns that describes the sort order"""
if not cols:
raise PySparkValueError(
error_class="CANNOT_BE_EMPTY",
message_parameters={"item": "column"},
)
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
jcols = []
for c in cols:
if isinstance(c, int) and not isinstance(c, bool):
# ordinal is 1-based
if c > 0:
_c = self[c - 1]
# negative ordinal means sort by desc
elif c < 0:
_c = self[-c - 1].desc()
else:
raise PySparkIndexError(
error_class="ZERO_INDEX",
message_parameters={},
)
else:
_c = c # type: ignore[assignment]
jcols.append(_to_java_column(cast("ColumnOrName", _c)))
ascending = kwargs.get("ascending", True)
if isinstance(ascending, (bool, int)):
if not ascending:
jcols = [jc.desc() for jc in jcols]
elif isinstance(ascending, list):
jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, jcols)]
else:
raise PySparkTypeError(
error_class="NOT_BOOL_OR_LIST",
message_parameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
)
return self._jseq(jcols)
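# A hedged illustration of the ordinal handling above: a positive integer is a
# 1-based column position and a negative one requests descending order, so the
# two calls below should be equivalent:
#
#   df.sort(1, -2)
#   df.sort(df[0].asc(), df[1].desc())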
def describe(self, *cols: Union[str, List[str]]) -> ParentDataFrame:
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0] # type: ignore[assignment]
jdf = self._jdf.describe(self._jseq(cols))
return DataFrame(jdf, self.sparkSession)
def summary(self, *statistics: str) -> ParentDataFrame:
if len(statistics) == 1 and isinstance(statistics[0], list):
statistics = statistics[0]
jdf = self._jdf.summary(self._jseq(statistics))
return DataFrame(jdf, self.sparkSession)
@overload
def head(self) -> Optional[Row]:
...
@overload
def head(self, n: int) -> List[Row]:
...
def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
if n is None:
rs = self.head(1)
return rs[0] if rs else None
return self.take(n)
def first(self) -> Optional[Row]:
return self.head()
@overload
def __getitem__(self, item: Union[int, str]) -> Column:
...
@overload
def __getitem__(self, item: Union[Column, List, Tuple]) -> ParentDataFrame:
...
def __getitem__(
self, item: Union[int, str, Column, List, Tuple]
) -> Union[Column, ParentDataFrame]:
if isinstance(item, str):
jc = self._jdf.apply(item)
return Column(jc)
elif isinstance(item, Column):
return self.filter(item)
elif isinstance(item, (list, tuple)):
return self.select(*item)
elif isinstance(item, int):
jc = self._jdf.apply(self.columns[item])
return Column(jc)
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_FLOAT_OR_INT_OR_LIST_OR_STR",
message_parameters={"arg_name": "item", "arg_type": type(item).__name__},
)
def __getattr__(self, name: str) -> Column:
if name not in self.columns:
raise PySparkAttributeError(
error_class="ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name}
)
jc = self._jdf.apply(name)
return Column(jc)
def __dir__(self) -> List[str]:
attrs = set(dir(DataFrame))
attrs.update(filter(lambda s: s.isidentifier(), self.columns))
return sorted(attrs)
@overload
def select(self, *cols: "ColumnOrName") -> ParentDataFrame:
...
@overload
def select(self, __cols: Union[List[Column], List[str]]) -> ParentDataFrame:
...
def select(self, *cols: "ColumnOrName") -> ParentDataFrame: # type: ignore[misc]
jdf = self._jdf.select(self._jcols(*cols))
return DataFrame(jdf, self.sparkSession)
@overload
def selectExpr(self, *expr: str) -> ParentDataFrame:
...
@overload
def selectExpr(self, *expr: List[str]) -> ParentDataFrame:
...
def selectExpr(self, *expr: Union[str, List[str]]) -> ParentDataFrame:
if len(expr) == 1 and isinstance(expr[0], list):
expr = expr[0] # type: ignore[assignment]
jdf = self._jdf.selectExpr(self._jseq(expr))
return DataFrame(jdf, self.sparkSession)
def filter(self, condition: "ColumnOrName") -> ParentDataFrame:
if isinstance(condition, str):
jdf = self._jdf.filter(condition)
elif isinstance(condition, Column):
jdf = self._jdf.filter(condition._jc)
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_STR",
message_parameters={"arg_name": "condition", "arg_type": type(condition).__name__},
)
return DataFrame(jdf, self.sparkSession)
@overload
def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
...
@overload
def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
...
def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc]
jgd = self._jdf.groupBy(self._jcols_ordinal(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self)
@overload
def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
...
@overload
def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
...
def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc]
jgd = self._jdf.rollup(self._jcols_ordinal(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self)
@overload
def cube(self, *cols: "ColumnOrName") -> "GroupedData":
...
@overload
def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
...
def cube(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
jgd = self._jdf.cube(self._jcols_ordinal(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self)
def groupingSets(
self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
) -> "GroupedData":
from pyspark.sql.group import GroupedData
jgrouping_sets = _to_seq(self._sc, [self._jcols(*inner) for inner in groupingSets])
jgd = self._jdf.groupingSets(jgrouping_sets, self._jcols(*cols))
return GroupedData(jgd, self)
def unpivot(
self,
ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]],
values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
variableColumnName: str,
valueColumnName: str,
) -> ParentDataFrame:
assert ids is not None, "ids must not be None"
def to_jcols(
cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
) -> "JavaObject":
if isinstance(cols, list):
return self._jcols(*cols)
if isinstance(cols, tuple):
return self._jcols(*list(cols))
return self._jcols(cols)
jids = to_jcols(ids)
if values is None:
jdf = self._jdf.unpivotWithSeq(jids, variableColumnName, valueColumnName)
else:
jvals = to_jcols(values)
jdf = self._jdf.unpivotWithSeq(jids, jvals, variableColumnName, valueColumnName)
return DataFrame(jdf, self.sparkSession)
def melt(
self,
ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]],
values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
variableColumnName: str,
valueColumnName: str,
) -> ParentDataFrame:
return self.unpivot(ids, values, variableColumnName, valueColumnName)
def agg(self, *exprs: Union[Column, Dict[str, str]]) -> ParentDataFrame:
return self.groupBy().agg(*exprs) # type: ignore[arg-type]
def observe(
self,
observation: Union["Observation", str],
*exprs: Column,
) -> ParentDataFrame:
from pyspark.sql import Observation
if len(exprs) == 0:
raise PySparkValueError(
error_class="CANNOT_BE_EMPTY",
message_parameters={"item": "exprs"},
)
if not all(isinstance(c, Column) for c in exprs):
raise PySparkTypeError(
error_class="NOT_LIST_OF_COLUMN",
message_parameters={"arg_name": "exprs"},
)
if isinstance(observation, Observation):
return observation._on(self, *exprs)
elif isinstance(observation, str):
return DataFrame(
self._jdf.observe(
observation, exprs[0]._jc, _to_seq(self._sc, [c._jc for c in exprs[1:]])
),
self.sparkSession,
)
else:
raise PySparkTypeError(
error_class="NOT_LIST_OF_COLUMN",
message_parameters={
"arg_name": "observation",
"arg_type": type(observation).__name__,
},
)
def union(self, other: ParentDataFrame) -> ParentDataFrame:
return DataFrame(self._jdf.union(other._jdf), self.sparkSession)
def unionAll(self, other: ParentDataFrame) -> ParentDataFrame:
return self.union(other)
def unionByName(
self, other: ParentDataFrame, allowMissingColumns: bool = False
) -> ParentDataFrame:
return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sparkSession)
def intersect(self, other: ParentDataFrame) -> ParentDataFrame:
return DataFrame(self._jdf.intersect(other._jdf), self.sparkSession)
def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame:
return DataFrame(self._jdf.intersectAll(other._jdf), self.sparkSession)
def subtract(self, other: ParentDataFrame) -> ParentDataFrame:
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sparkSession)
def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
if subset is not None and (not isinstance(subset, Iterable) or isinstance(subset, str)):
raise PySparkTypeError(
error_class="NOT_LIST_OR_TUPLE",
message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
)
if subset is None:
jdf = self._jdf.dropDuplicates()
else:
jdf = self._jdf.dropDuplicates(self._jseq(subset))
return DataFrame(jdf, self.sparkSession)
def dropDuplicatesWithinWatermark(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
if subset is not None and (not isinstance(subset, Iterable) or isinstance(subset, str)):
raise PySparkTypeError(
error_class="NOT_LIST_OR_TUPLE",
message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
)
if subset is None:
jdf = self._jdf.dropDuplicatesWithinWatermark()
else:
jdf = self._jdf.dropDuplicatesWithinWatermark(self._jseq(subset))
return DataFrame(jdf, self.sparkSession)
def dropna(
self,
how: str = "any",
thresh: Optional[int] = None,
subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
) -> ParentDataFrame:
if how is not None and how not in ["any", "all"]:
raise PySparkValueError(
error_class="VALUE_NOT_ANY_OR_ALL",
message_parameters={"arg_name": "how", "arg_type": how},
)
if subset is None:
subset = self.columns
elif isinstance(subset, str):
subset = [subset]
elif not isinstance(subset, (list, tuple)):
raise PySparkTypeError(
error_class="NOT_LIST_OR_STR_OR_TUPLE",
message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
)
if thresh is None:
thresh = len(subset) if how == "any" else 1
return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sparkSession)
@overload
def fillna(
self,
value: "LiteralType",
subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ...,
) -> ParentDataFrame:
...
@overload
def fillna(self, value: Dict[str, "LiteralType"]) -> ParentDataFrame:
...
def fillna(
self,
value: Union["LiteralType", Dict[str, "LiteralType"]],
subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
) -> ParentDataFrame:
if not isinstance(value, (float, int, str, bool, dict)):
raise PySparkTypeError(
error_class="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_STR",
message_parameters={"arg_name": "value", "arg_type": type(value).__name__},
)
# Note that bool is a subclass of int, so isinstance(value, int) is True for
# bools, but we don't want to convert bools to floats.
if not isinstance(value, bool) and isinstance(value, int):
value = float(value)
if isinstance(value, dict):
return DataFrame(self._jdf.na().fill(value), self.sparkSession)
elif subset is None:
return DataFrame(self._jdf.na().fill(value), self.sparkSession)
else:
if isinstance(subset, str):
subset = [subset]
elif not isinstance(subset, (list, tuple)):
raise PySparkTypeError(
error_class="NOT_LIST_OR_TUPLE",
message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
)
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sparkSession)
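# A minimal usage sketch (hedged; column names are illustrative): a scalar
# fills all compatible columns, optionally restricted by `subset`, while a
# dict maps column names to per-column replacement values:
#
#   df.fillna(0, subset=["age"])
#   df.fillna({"age": 0, "name": "unknown"})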
@overload
def replace(
self,
to_replace: "LiteralType",
value: "OptionalPrimitiveType",
subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
...
@overload
def replace(
self,
to_replace: List["LiteralType"],
value: List["OptionalPrimitiveType"],
subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
...
@overload
def replace(
self,
to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
...
@overload
def replace(
self,
to_replace: List["LiteralType"],
value: "OptionalPrimitiveType",
subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
...
def replace( # type: ignore[misc]
self,
to_replace: Union[
"LiteralType", List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]
],
value: Optional[
Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
] = _NoValue,
subset: Optional[List[str]] = None,
) -> ParentDataFrame:
if value is _NoValue:
if isinstance(to_replace, dict):
value = None
else:
raise PySparkTypeError(
error_class="ARGUMENT_REQUIRED",
message_parameters={"arg_name": "value", "condition": "`to_replace` is dict"},
)
# Helper functions
def all_of(types: Union[Type, Tuple[Type, ...]]) -> Callable[[Iterable], bool]:
"""Given a type or tuple of types and a sequence of xs
check if each x is instance of type(s)
>>> all_of(bool)([True, False])
True
>>> all_of(str)(["a", 1])
False
"""
def all_of_(xs: Iterable) -> bool:
return all(isinstance(x, types) for x in xs)
return all_of_
all_of_bool = all_of(bool)
all_of_str = all_of(str)
all_of_numeric = all_of((float, int))
# Validate input types
valid_types = (bool, float, int, str, list, tuple)
if not isinstance(to_replace, valid_types + (dict,)):
raise PySparkTypeError(
error_class="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
message_parameters={
"arg_name": "to_replace",
"arg_type": type(to_replace).__name__,
},
)
if (
not isinstance(value, valid_types)
and value is not None
and not isinstance(to_replace, dict)
):
raise PySparkTypeError(
error_class="NOT_BOOL_OR_FLOAT_OR_INT_OR_LIST_OR_NONE_OR_STR_OR_TUPLE",
message_parameters={
"arg_name": "value",
"arg_type": type(value).__name__,
},
)
if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
if len(to_replace) != len(value):
raise PySparkValueError(
error_class="LENGTH_SHOULD_BE_THE_SAME",
message_parameters={
"arg1": "to_replace",
"arg2": "value",
"arg1_length": str(len(to_replace)),
"arg2_length": str(len(value)),
},
)
if not (subset is None or isinstance(subset, (list, tuple, str))):
raise PySparkTypeError(
error_class="NOT_LIST_OR_STR_OR_TUPLE",
message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
)
# Reshape input arguments if necessary
if isinstance(to_replace, (float, int, str)):
to_replace = [to_replace]
if isinstance(to_replace, dict):
rep_dict = to_replace
if value is not None:
warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
else:
if isinstance(value, (float, int, str)) or value is None:
value = [value for _ in range(len(to_replace))]
rep_dict = dict(zip(to_replace, cast("Iterable[Optional[Union[float, str]]]", value)))
if isinstance(subset, str):
subset = [subset]
# Verify that we were not passed mixed replacement types.
if not any(
all_of_type(rep_dict.keys())
and all_of_type(x for x in rep_dict.values() if x is not None)
for all_of_type in [all_of_bool, all_of_str, all_of_numeric]
):
raise PySparkValueError(
error_class="MIXED_TYPE_REPLACEMENT",
message_parameters={},
)
if subset is None:
return DataFrame(self._jdf.na().replace("*", rep_dict), self.sparkSession)
else:
return DataFrame(
self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)),
self.sparkSession,
)
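# A minimal usage sketch (hedged; column names are illustrative): replacements
# must not mix value types, as enforced above, and the dict form maps old
# values to new ones:
#
#   df.replace(["Alice", "Bob"], ["A", "B"], subset=["name"])
#   df.replace({10: 20, 30: 40})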
@overload
def approxQuantile(
self,
col: str,
probabilities: Union[List[float], Tuple[float]],
relativeError: float,
) -> List[float]:
...
@overload
def approxQuantile(
self,
col: Union[List[str], Tuple[str]],
probabilities: Union[List[float], Tuple[float]],
relativeError: float,
) -> List[List[float]]:
...
def approxQuantile(
self,
col: Union[str, List[str], Tuple[str]],
probabilities: Union[List[float], Tuple[float]],
relativeError: float,
) -> Union[List[float], List[List[float]]]:
if not isinstance(col, (str, list, tuple)):
raise PySparkTypeError(
error_class="NOT_LIST_OR_STR_OR_TUPLE",
message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
)
isStr = isinstance(col, str)
if isinstance(col, tuple):
col = list(col)
elif isStr:
col = [cast(str, col)]
for c in col:
if not isinstance(c, str):
raise PySparkTypeError(
error_class="DISALLOWED_TYPE_FOR_CONTAINER",
message_parameters={
"arg_name": "col",
"arg_type": type(col).__name__,
"allowed_types": "str",
"item_type": type(c).__name__,
},
)
col = _to_list(self._sc, cast(List["ColumnOrName"], col))
if not isinstance(probabilities, (list, tuple)):
raise PySparkTypeError(
error_class="NOT_LIST_OR_TUPLE",
message_parameters={
"arg_name": "probabilities",
"arg_type": type(probabilities).__name__,
},
)
if isinstance(probabilities, tuple):
probabilities = list(probabilities)
for p in probabilities:
if not isinstance(p, (float, int)) or p < 0 or p > 1:
raise PySparkTypeError(
error_class="NOT_LIST_OF_FLOAT_OR_INT",
message_parameters={
"arg_name": "probabilities",
"arg_type": type(p).__name__,
},
)
probabilities = _to_list(self._sc, cast(List["ColumnOrName"], probabilities))
if not isinstance(relativeError, (float, int)):
raise PySparkTypeError(
error_class="NOT_FLOAT_OR_INT",
message_parameters={
"arg_name": "relativeError",
"arg_type": type(relativeError).__name__,
},
)
if relativeError < 0:
raise PySparkValueError(
error_class="NEGATIVE_VALUE",
message_parameters={
"arg_name": "relativeError",
"arg_value": str(relativeError),
},
)
relativeError = float(relativeError)
jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError)
jaq_list = [list(j) for j in jaq]
return jaq_list[0] if isStr else jaq_list
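# A minimal usage sketch (hedged; assumes numeric columns "age" and "height"):
# probabilities must lie in [0, 1] and relativeError must be non-negative
# (0 computes exact quantiles at higher cost):
#
#   df.approxQuantile("age", [0.25, 0.5, 0.75], 0.05)   # -> List[float]
#   df.approxQuantile(["age", "height"], [0.5], 0.0)    # -> List[List[float]]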
def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
if not isinstance(col1, str):
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "col1", "arg_type": type(col1).__name__},
)
if not isinstance(col2, str):
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "col2", "arg_type": type(col2).__name__},
)
if not method:
method = "pearson"
if not method == "pearson":
raise PySparkValueError(
error_class="VALUE_NOT_PEARSON",
message_parameters={"arg_name": "method", "arg_value": method},
)
return self._jdf.stat().corr(col1, col2, method)
def cov(self, col1: str, col2: str) -> float:
if not isinstance(col1, str):
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "col1", "arg_type": type(col1).__name__},
)
if not isinstance(col2, str):
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "col2", "arg_type": type(col2).__name__},
)
return self._jdf.stat().cov(col1, col2)
def crosstab(self, col1: str, col2: str) -> ParentDataFrame:
if not isinstance(col1, str):
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "col1", "arg_type": type(col1).__name__},
)
if not isinstance(col2, str):
raise PySparkTypeError(
error_class="NOT_STR",
message_parameters={"arg_name": "col2", "arg_type": type(col2).__name__},
)
return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sparkSession)
def freqItems(
self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
) -> ParentDataFrame:
if isinstance(cols, tuple):
cols = list(cols)
if not isinstance(cols, list):
raise PySparkTypeError(
error_class="NOT_LIST_OR_TUPLE",
message_parameters={"arg_name": "cols", "arg_type": type(cols).__name__},
)
if not support:
support = 0.01
return DataFrame(
self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sparkSession
)
def _ipython_key_completions_(self) -> List[str]:
"""Returns the names of columns in this :class:`DataFrame`.
Examples
--------
>>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
>>> df._ipython_key_completions_()
['age', 'name']
This also includes column names that are not legal Python identifiers.
>>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age 1", "name?1"])
>>> df._ipython_key_completions_()
['age 1', 'name?1']
"""
return self.columns
def withColumns(self, *colsMap: Dict[str, Column]) -> ParentDataFrame:
# The code below is structured this way to help enable kwargs in the future.
assert len(colsMap) == 1
colsMap = colsMap[0] # type: ignore[assignment]
if not isinstance(colsMap, dict):
raise PySparkTypeError(
error_class="NOT_DICT",
message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__},
)
col_names = list(colsMap.keys())
cols = list(colsMap.values())
return DataFrame(
self._jdf.withColumns(_to_seq(self._sc, col_names), self._jcols(*cols)),
self.sparkSession,
)
def withColumn(self, colName: str, col: Column) -> ParentDataFrame:
if not isinstance(col, Column):
raise PySparkTypeError(
error_class="NOT_COLUMN",
message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
)
return DataFrame(self._jdf.withColumn(colName, col._jc), self.sparkSession)
def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame:
return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sparkSession)
def withColumnsRenamed(self, colsMap: Dict[str, str]) -> ParentDataFrame:
if not isinstance(colsMap, dict):
raise PySparkTypeError(
error_class="NOT_DICT",
message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__},
)
col_names: List[str] = []
new_col_names: List[str] = []
for k, v in colsMap.items():
col_names.append(k)
new_col_names.append(v)
return DataFrame(
self._jdf.withColumnsRenamed(
_to_seq(self._sc, col_names), _to_seq(self._sc, new_col_names)
),
self.sparkSession,
)
def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> ParentDataFrame:
from py4j.java_gateway import JVMView
if not isinstance(metadata, dict):
raise PySparkTypeError(
error_class="NOT_DICT",
message_parameters={"arg_name": "metadata", "arg_type": type(metadata).__name__},
)
sc = get_active_spark_context()
jmeta = cast(JVMView, sc._jvm).org.apache.spark.sql.types.Metadata.fromJson(
json.dumps(metadata)
)
return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sparkSession)
@overload
def drop(self, cols: "ColumnOrName") -> ParentDataFrame:
...
@overload
def drop(self, *cols: str) -> ParentDataFrame:
...
def drop(self, *cols: "ColumnOrName") -> ParentDataFrame: # type: ignore[misc]
column_names: List[str] = []
java_columns: List["JavaObject"] = []
for c in cols:
if isinstance(c, str):
column_names.append(c)
elif isinstance(c, Column):
java_columns.append(c._jc)
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_STR",
message_parameters={"arg_name": "col", "arg_type": type(c).__name__},
)
jdf = self._jdf
if len(java_columns) > 0:
first_column, *remaining_columns = java_columns
jdf = jdf.drop(first_column, self._jseq(remaining_columns))
if len(column_names) > 0:
jdf = jdf.drop(self._jseq(column_names))
return DataFrame(jdf, self.sparkSession)
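# A minimal usage sketch (hedged; column names are illustrative): in this
# implementation, column names and Column objects may be mixed and are dropped
# via separate JVM calls as above:
#
#   df.drop("age")
#   df.drop(df.age, "name")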
def toDF(self, *cols: str) -> ParentDataFrame:
for col in cols:
if not isinstance(col, str):
raise PySparkTypeError(
error_class="NOT_LIST_OF_STR",
message_parameters={"arg_name": "cols", "arg_type": type(col).__name__},
)
jdf = self._jdf.toDF(self._jseq(cols))
return DataFrame(jdf, self.sparkSession)
def transform(
self, func: Callable[..., ParentDataFrame], *args: Any, **kwargs: Any
) -> ParentDataFrame:
result = func(self, *args, **kwargs)
assert isinstance(
result, DataFrame
), "Func returned an instance of type [%s], " "should have been DataFrame." % type(result)
return result
def sameSemantics(self, other: ParentDataFrame) -> bool:
if not isinstance(other, DataFrame):
raise PySparkTypeError(
error_class="NOT_DATAFRAME",
message_parameters={"arg_name": "other", "arg_type": type(other).__name__},
)
return self._jdf.sameSemantics(other._jdf)
def semanticHash(self) -> int:
return self._jdf.semanticHash()
def inputFiles(self) -> List[str]:
return list(self._jdf.inputFiles())
def where(self, condition: "ColumnOrName") -> ParentDataFrame:
return self.filter(condition)
# The two aliases below were added for pandas compatibility many years ago.
# There are too many differences compared to pandas and we cannot simply make
# the API "compatible" by adding aliases. Therefore, we stopped adding such
# aliases as of Spark 3.0. The two methods below remain only for legacy users.
@overload
def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
...
@overload
def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
...
def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc]
return self.groupBy(*cols)
def drop_duplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
return self.dropDuplicates(subset)
def writeTo(self, table: str) -> DataFrameWriterV2:
return DataFrameWriterV2(self, table)
def pandas_api(
self, index_col: Optional[Union[str, List[str]]] = None
) -> "PandasOnSparkDataFrame":
from pyspark.pandas.namespace import _get_index_map
from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
from pyspark.pandas.internal import InternalFrame
index_spark_columns, index_names = _get_index_map(self, index_col)
internal = InternalFrame(
spark_frame=self,
index_spark_columns=index_spark_columns,
index_names=index_names, # type: ignore[arg-type]
)
return PandasOnSparkDataFrame(internal)
def mapInPandas(
self,
func: "PandasMapIterFunction",
schema: Union[StructType, str],
barrier: bool = False,
profile: Optional[ResourceProfile] = None,
) -> ParentDataFrame:
return PandasMapOpsMixin.mapInPandas(self, func, schema, barrier, profile)
def mapInArrow(
self,
func: "ArrowMapIterFunction",
schema: Union[StructType, str],
barrier: bool = False,
profile: Optional[ResourceProfile] = None,
) -> ParentDataFrame:
return PandasMapOpsMixin.mapInArrow(self, func, schema, barrier, profile)
def toArrow(self) -> "pa.Table":
return PandasConversionMixin.toArrow(self)
def toPandas(self) -> "PandasDataFrameLike":
return PandasConversionMixin.toPandas(self)
def _to_scala_map(sc: "SparkContext", jm: Dict) -> "JavaObject":
"""
Convert a dict into a JVM Map.
"""
assert sc._jvm is not None
return sc._jvm.PythonUtils.toScalaMap(jm)
class DataFrameNaFunctions(ParentDataFrameNaFunctions):
def __init__(self, df: ParentDataFrame):
self.df = df
def drop(
self,
how: str = "any",
thresh: Optional[int] = None,
subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
) -> ParentDataFrame:
return self.df.dropna(how=how, thresh=thresh, subset=subset)
@overload
def fill(self, value: "LiteralType", subset: Optional[List[str]] = ...) -> ParentDataFrame:
...
@overload
def fill(self, value: Dict[str, "LiteralType"]) -> ParentDataFrame:
...
def fill(
self,
value: Union["LiteralType", Dict[str, "LiteralType"]],
subset: Optional[List[str]] = None,
) -> ParentDataFrame:
return self.df.fillna(value=value, subset=subset) # type: ignore[arg-type]
@overload
def replace(
self,
to_replace: List["LiteralType"],
value: List["OptionalPrimitiveType"],
subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
...
@overload
def replace(
self,
to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
...
@overload
def replace(
self,
to_replace: List["LiteralType"],
value: "OptionalPrimitiveType",
subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
...
def replace( # type: ignore[misc]
self,
to_replace: Union[List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]],
value: Optional[
Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
] = _NoValue,
subset: Optional[List[str]] = None,
) -> ParentDataFrame:
return self.df.replace(to_replace, value, subset) # type: ignore[arg-type]
class DataFrameStatFunctions(ParentDataFrameStatFunctions):
def __init__(self, df: ParentDataFrame):
self.df = df
@overload
def approxQuantile(
self,
col: str,
probabilities: Union[List[float], Tuple[float]],
relativeError: float,
) -> List[float]:
...
@overload
def approxQuantile(
self,
col: Union[List[str], Tuple[str]],
probabilities: Union[List[float], Tuple[float]],
relativeError: float,
) -> List[List[float]]:
...
def approxQuantile(
self,
col: Union[str, List[str], Tuple[str]],
probabilities: Union[List[float], Tuple[float]],
relativeError: float,
) -> Union[List[float], List[List[float]]]:
return self.df.approxQuantile(col, probabilities, relativeError)
def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
return self.df.corr(col1, col2, method)
def cov(self, col1: str, col2: str) -> float:
return self.df.cov(col1, col2)
def crosstab(self, col1: str, col2: str) -> ParentDataFrame:
return self.df.crosstab(col1, col2)
def freqItems(self, cols: List[str], support: Optional[float] = None) -> ParentDataFrame:
return self.df.freqItems(cols, support)
def sampleBy(
self, col: str, fractions: Dict[Any, float], seed: Optional[int] = None
) -> ParentDataFrame:
return self.df.sampleBy(col, fractions, seed)
def _test() -> None:
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.dataframe
# This module inherits docstrings, but doctest cannot detect them, so we run
# the parent class's doctests here directly.
globs = pyspark.sql.dataframe.__dict__.copy()
spark = (
SparkSession.builder.master("local[4]").appName("sql.classic.dataframe tests").getOrCreate()
)
globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.dataframe,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
)
spark.stop()
if failure_count:
sys.exit(-1)
if __name__ == "__main__":
_test()