#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import json
import warnings
from typing import (
cast,
overload,
Any,
Callable,
Iterable,
List,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)
from pyspark.sql.column import Column as ParentColumn
from pyspark.errors import PySparkAttributeError, PySparkTypeError, PySparkValueError
from pyspark.errors.utils import with_origin_to_class
from pyspark.sql.types import DataType
from pyspark.sql.utils import get_active_spark_context, enum_to_value
if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
from pyspark.core.context import SparkContext
from pyspark.sql._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral
from pyspark.sql.window import WindowSpec
__all__ = ["Column"]
def _create_column_from_literal(
literal: Union["LiteralType", "DecimalLiteral", "DateTimeLiteral", "ParentColumn"]
) -> "JavaObject":
from py4j.java_gateway import JVMView
sc = get_active_spark_context()
return cast(JVMView, sc._jvm).functions.lit(enum_to_value(literal))
def _create_column_from_name(name: str) -> "JavaObject":
from py4j.java_gateway import JVMView
sc = get_active_spark_context()
return cast(JVMView, sc._jvm).functions.col(name)
def _to_java_column(col: "ColumnOrName") -> "JavaObject":
if isinstance(col, Column):
jcol = col._jc
elif isinstance(col, str):
jcol = _create_column_from_name(col)
else:
raise PySparkTypeError(
errorClass="NOT_COLUMN_OR_STR",
messageParameters={"arg_name": "col", "arg_type": type(col).__name__},
)
return jcol
@overload
def _to_seq(sc: "SparkContext", cols: Iterable["JavaObject"]) -> "JavaObject":
...
@overload
def _to_seq(
sc: "SparkContext",
cols: Iterable["ColumnOrName"],
converter: Optional[Callable[["ColumnOrName"], "JavaObject"]],
) -> "JavaObject":
...
def _to_seq(
sc: "SparkContext",
cols: Union[Iterable["ColumnOrName"], Iterable["JavaObject"]],
converter: Optional[Callable[["ColumnOrName"], "JavaObject"]] = None,
) -> "JavaObject":
"""
    Convert a list of Columns (or names) into a JVM Seq of Columns.
    An optional `converter` can be used to convert items in `cols`
    into JVM Column objects.
"""
if converter:
cols = [converter(c) for c in cols]
assert sc._jvm is not None
return sc._jvm.PythonUtils.toSeq(cols)
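# A minimal sketch of how `_to_seq` is typically combined with
# `_to_java_column` (internal helpers; assumes an active SparkSession):
#
#   sc = get_active_spark_context()
#   jseq = _to_seq(sc, ["a", "b"], _to_java_column)  # JVM Seq of Columns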
def _to_list(
sc: "SparkContext",
cols: List["ColumnOrName"],
converter: Optional[Callable[["ColumnOrName"], "JavaObject"]] = None,
) -> "JavaObject":
"""
Convert a list of Columns (or names) into a JVM (Scala) List of Columns.
    An optional `converter` can be used to convert items in `cols`
    into JVM Column objects.
"""
if converter:
cols = [converter(c) for c in cols]
assert sc._jvm is not None
return sc._jvm.PythonUtils.toList(cols)
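# Same pattern as `_to_seq`, but yielding a JVM (Scala) List (sketch,
# assuming a DataFrame `df` with columns `a` and `b`):
#
#   sc = get_active_spark_context()
#   jlist = _to_list(sc, [df.a, df.b], _to_java_column)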
def _unary_op(name: str, self: ParentColumn) -> ParentColumn:
"""Create a method for given unary operator"""
jc = getattr(self._jc, name)()
return Column(jc)
def _func_op(name: str, self: ParentColumn) -> ParentColumn:
from py4j.java_gateway import JVMView
sc = get_active_spark_context()
jc = getattr(cast(JVMView, sc._jvm).functions, name)(self._jc)
return Column(jc)
def _bin_func_op(
name: str,
self: ParentColumn,
other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
reverse: bool = False,
) -> ParentColumn:
from py4j.java_gateway import JVMView
sc = get_active_spark_context()
fn = getattr(cast(JVMView, sc._jvm).functions, name)
jc = other._jc if isinstance(other, ParentColumn) else _create_column_from_literal(other)
njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
return Column(njc)
def _bin_op(
name: str,
self: ParentColumn,
other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> ParentColumn:
"""Create a method for given binary operator"""
jc = other._jc if isinstance(other, ParentColumn) else enum_to_value(other)
njc = getattr(self._jc, name)(jc)
return Column(njc)
def _reverse_op(
name: str,
self: ParentColumn,
other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> ParentColumn:
"""Create a method for binary operator (this object is on right side)"""
jother = _create_column_from_literal(other)
jc = getattr(jother, name)(self._jc)
return Column(jc)
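# These factories back the operator overloads on `Column` below. A sketch of
# the mapping, assuming a DataFrame `df` with a numeric column `x`:
#
#   df.x + 1   # _bin_op("plus", df.x, 1)      -> self._jc.plus(1)
#   1 - df.x   # _reverse_op("minus", df.x, 1) -> lit(1).minus(df.x._jc)
#   -df.x      # _func_op("negate", df.x)      -> functions.negate(df.x._jc)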
@with_origin_to_class
class Column(ParentColumn):
def __new__(
cls,
jc: "JavaObject",
) -> "Column":
self = object.__new__(cls)
self.__init__(jc) # type: ignore[misc]
return self
def __init__(self, jc: "JavaObject") -> None:
self._jc = jc
# arithmetic operators
def __neg__(self) -> ParentColumn:
return _func_op("negate", self)
def __add__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("plus", self, other)
def __sub__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("minus", self, other)
def __mul__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("multiply", self, other)
def __div__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("divide", self, other)
def __truediv__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("divide", self, other)
def __mod__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("mod", self, other)
def __radd__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("plus", self, other)
def __rsub__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _reverse_op("minus", self, other)
def __rmul__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("multiply", self, other)
def __rdiv__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _reverse_op("divide", self, other)
def __rtruediv__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _reverse_op("divide", self, other)
def __rmod__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _reverse_op("mod", self, other)
def __pow__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_func_op("pow", self, other)
def __rpow__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_func_op("pow", self, other, reverse=True)
    # comparison operators
def __eq__( # type: ignore[override]
self,
other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> ParentColumn:
return _bin_op("equalTo", self, other)
def __ne__( # type: ignore[override]
self,
other: Any,
) -> ParentColumn:
return _bin_op("notEqual", self, other)
def __lt__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("lt", self, other)
def __le__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("leq", self, other)
def __ge__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("geq", self, other)
def __gt__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("gt", self, other)
def eqNullSafe(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("eqNullSafe", self, other)
# `and`, `or`, `not` cannot be overloaded in Python,
# so use bitwise operators as boolean operators
def __and__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
from pyspark.sql.functions import lit
return _bin_op("and", self, lit(other))
def __or__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
from pyspark.sql.functions import lit
return _bin_op("or", self, lit(other))
def __invert__(self) -> ParentColumn:
return _func_op("not", self)
def __rand__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
from pyspark.sql.functions import lit
return _bin_op("and", self, lit(other))
def __ror__(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
from pyspark.sql.functions import lit
return _bin_op("or", self, lit(other))
# container operators
def __contains__(self, item: Any) -> None:
raise PySparkValueError(
errorClass="CANNOT_APPLY_IN_FOR_COLUMN",
messageParameters={},
)
# bitwise operators
def bitwiseOR(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("bitwiseOR", self, other)
def bitwiseAND(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("bitwiseAND", self, other)
def bitwiseXOR(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("bitwiseXOR", self, other)
def getItem(self, key: Any) -> ParentColumn:
if isinstance(key, Column):
warnings.warn(
"A column as 'key' in getItem is deprecated as of Spark 3.0, and will not "
"be supported in the future release. Use `column[key]` or `column.key` syntax "
"instead.",
FutureWarning,
)
return self[key]
def getField(self, name: Any) -> ParentColumn:
if isinstance(name, Column):
warnings.warn(
"A column as 'name' in getField is deprecated as of Spark 3.0, and will not "
"be supported in the future release. Use `column[name]` or `column.name` syntax "
"instead.",
FutureWarning,
)
return self[name]
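    # Preferred spellings with plain (non-Column) keys (sketch, assuming a
    # DataFrame `df` with a map column `m` and a struct column `s`):
    #
    #   df.select(df.m["key"], df.m.getItem("key"))    # map lookup
    #   df.select(df.s.field, df.s.getField("field"))  # struct field access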
def withField(self, fieldName: str, col: ParentColumn) -> ParentColumn:
if not isinstance(fieldName, str):
raise PySparkTypeError(
errorClass="NOT_STR",
messageParameters={"arg_name": "fieldName", "arg_type": type(fieldName).__name__},
)
if not isinstance(col, Column):
raise PySparkTypeError(
errorClass="NOT_COLUMN",
messageParameters={"arg_name": "col", "arg_type": type(col).__name__},
)
return Column(self._jc.withField(fieldName, col._jc))
def dropFields(self, *fieldNames: str) -> ParentColumn:
sc = get_active_spark_context()
jc = self._jc.dropFields(_to_seq(sc, fieldNames))
return Column(jc)
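    # A minimal struct-manipulation sketch (assumes a DataFrame `df` with a
    # struct column `s` that has fields `a` and `b`):
    #
    #   from pyspark.sql.functions import lit
    #   df.withColumn("s", df.s.withField("c", lit(1)))  # add or replace `c`
    #   df.withColumn("s", df.s.dropFields("a"))         # remove `a`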
def __getattr__(self, item: Any) -> ParentColumn:
if item.startswith("__"):
raise PySparkAttributeError(
errorClass="CANNOT_ACCESS_TO_DUNDER",
messageParameters={},
)
return self[item]
def __getitem__(self, k: Any) -> ParentColumn:
if isinstance(k, slice):
if k.step is not None:
raise PySparkValueError(
errorClass="SLICE_WITH_STEP",
messageParameters={},
)
return self.substr(k.start, k.stop)
else:
return _bin_op("apply", self, k)
def __iter__(self) -> None:
raise PySparkTypeError(
errorClass="NOT_ITERABLE", messageParameters={"objectName": "Column"}
)
# string methods
def contains(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("contains", self, other)
def startswith(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("startsWith", self, other)
def endswith(
self, other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"]
) -> ParentColumn:
return _bin_op("endsWith", self, other)
def like(self: ParentColumn, other: str) -> ParentColumn:
njc = getattr(self._jc, "like")(enum_to_value(other))
return Column(njc)
def rlike(self: ParentColumn, other: str) -> ParentColumn:
njc = getattr(self._jc, "rlike")(enum_to_value(other))
return Column(njc)
def ilike(self: ParentColumn, other: str) -> ParentColumn:
njc = getattr(self._jc, "ilike")(enum_to_value(other))
return Column(njc)
def substr(
self, startPos: Union[int, ParentColumn], length: Union[int, ParentColumn]
) -> ParentColumn:
startPos = enum_to_value(startPos)
length = enum_to_value(length)
if type(startPos) != type(length):
raise PySparkTypeError(
errorClass="NOT_SAME_TYPE",
messageParameters={
"arg_name1": "startPos",
"arg_name2": "length",
"arg_type1": type(startPos).__name__,
"arg_type2": type(length).__name__,
},
)
if isinstance(startPos, int):
jc = self._jc.substr(startPos, length)
elif isinstance(startPos, Column):
jc = self._jc.substr(startPos._jc, cast(ParentColumn, length)._jc)
else:
raise PySparkTypeError(
errorClass="NOT_COLUMN_OR_INT",
messageParameters={"arg_name": "startPos", "arg_type": type(startPos).__name__},
)
return Column(jc)
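    # A minimal usage sketch (assumes a DataFrame `df` with a string column
    # `name`; positions are 1-based, as in SQL SUBSTRING):
    #
    #   df.select(df.name.substr(1, 3).alias("prefix"))
    #
    #   from pyspark.sql.functions import lit
    #   df.select(df.name.substr(lit(1), lit(3)))  # both args must share a type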
def isin(self, *cols: Any) -> ParentColumn:
if len(cols) == 1 and isinstance(cols[0], (list, set)):
cols = cast(Tuple, cols[0])
cols = cast(
Tuple,
[c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols],
)
sc = get_active_spark_context()
jc = getattr(self._jc, "isin")(_to_seq(sc, cols))
return Column(jc)
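    # A minimal usage sketch (assumes a DataFrame `df` with a column `name`):
    #
    #   df.filter(df.name.isin("Alice", "Bob"))
    #   df.filter(df.name.isin(["Alice", "Bob"]))  # a single list or set also works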
# order
def asc(self) -> ParentColumn:
return _unary_op("asc", self)
def asc_nulls_first(self) -> ParentColumn:
return _unary_op("asc_nulls_first", self)
def asc_nulls_last(self) -> ParentColumn:
return _unary_op("asc_nulls_last", self)
def desc(self) -> ParentColumn:
return _unary_op("desc", self)
def desc_nulls_first(self) -> ParentColumn:
return _unary_op("desc_nulls_first", self)
def desc_nulls_last(self) -> ParentColumn:
return _unary_op("desc_nulls_last", self)
def isNull(self) -> ParentColumn:
return _unary_op("isNull", self)
def isNotNull(self) -> ParentColumn:
return _unary_op("isNotNull", self)
def isNaN(self) -> ParentColumn:
return _unary_op("isNaN", self)
def alias(self, *alias: str, **kwargs: Any) -> ParentColumn:
metadata = kwargs.pop("metadata", None)
        assert not kwargs, "Unexpected kwargs were passed: %s" % kwargs
sc = get_active_spark_context()
if len(alias) == 1:
if metadata:
assert sc._jvm is not None
jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(json.dumps(metadata))
return Column(getattr(self._jc, "as")(alias[0], jmeta))
else:
return Column(getattr(self._jc, "as")(alias[0]))
else:
if metadata is not None:
raise PySparkValueError(
errorClass="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
messageParameters={"arg_name": "metadata"},
)
return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
def name(self, *alias: str, **kwargs: Any) -> ParentColumn:
return self.alias(*alias, **kwargs)
def cast(self, dataType: Union[DataType, str]) -> ParentColumn:
if isinstance(dataType, str):
jc = self._jc.cast(dataType)
elif isinstance(dataType, DataType):
from pyspark.sql import SparkSession
spark = SparkSession._getActiveSessionOrCreate()
jdt = spark._jsparkSession.parseDataType(dataType.json())
jc = self._jc.cast(jdt)
else:
raise PySparkTypeError(
errorClass="NOT_DATATYPE_OR_STR",
messageParameters={"arg_name": "dataType", "arg_type": type(dataType).__name__},
)
return Column(jc)
def try_cast(self, dataType: Union[DataType, str]) -> ParentColumn:
if isinstance(dataType, str):
jc = self._jc.try_cast(dataType)
elif isinstance(dataType, DataType):
from pyspark.sql import SparkSession
spark = SparkSession._getActiveSessionOrCreate()
jdt = spark._jsparkSession.parseDataType(dataType.json())
jc = self._jc.try_cast(jdt)
else:
raise PySparkTypeError(
errorClass="NOT_DATATYPE_OR_STR",
messageParameters={"arg_name": "dataType", "arg_type": type(dataType).__name__},
)
return Column(jc)
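    # A minimal usage sketch (assumes a DataFrame `df` with a column `age`).
    # Under ANSI mode `cast` can fail on invalid input, whereas `try_cast`
    # returns NULL instead of raising:
    #
    #   from pyspark.sql.types import StringType
    #   df.select(df.age.cast("string"), df.age.cast(StringType()))
    #   df.select(df.age.try_cast("int"))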
def astype(self, dataType: Union[DataType, str]) -> ParentColumn:
return self.cast(dataType)
def between(
self,
lowerBound: Union[ParentColumn, "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
upperBound: Union[ParentColumn, "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
) -> ParentColumn:
return (self >= lowerBound) & (self <= upperBound)
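    # `between` is exactly the conjunction above (sketch, assuming a DataFrame
    # `df` with a column `age`); the two filters are equivalent:
    #
    #   df.filter(df.age.between(2, 4))
    #   df.filter((df.age >= 2) & (df.age <= 4))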
def when(self, condition: ParentColumn, value: Any) -> ParentColumn:
if not isinstance(condition, Column):
raise PySparkTypeError(
errorClass="NOT_COLUMN",
messageParameters={"arg_name": "condition", "arg_type": type(condition).__name__},
)
v = value._jc if isinstance(value, Column) else enum_to_value(value)
jc = self._jc.when(condition._jc, v)
return Column(jc)
def otherwise(self, value: Any) -> ParentColumn:
v = value._jc if isinstance(value, Column) else enum_to_value(value)
jc = self._jc.otherwise(v)
return Column(jc)
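    # A minimal usage sketch (assumes a DataFrame `df` with a column `age`).
    # A CASE WHEN chain starts from `functions.when`; `Column.when` and
    # `Column.otherwise` extend it:
    #
    #   from pyspark.sql import functions as sf
    #   df.select(sf.when(df.age > 4, 1).when(df.age == 3, -1).otherwise(0))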
def over(self, window: "WindowSpec") -> ParentColumn:
from pyspark.sql.classic.window import WindowSpec
if not isinstance(window, WindowSpec):
raise PySparkTypeError(
errorClass="NOT_WINDOWSPEC",
messageParameters={"arg_name": "window", "arg_type": type(window).__name__},
)
jc = self._jc.over(window._jspec)
return Column(jc)
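    # A minimal window sketch (assumes a DataFrame `df` with columns `name`
    # and `age`):
    #
    #   from pyspark.sql import Window
    #   from pyspark.sql import functions as sf
    #   w = Window.partitionBy("name").orderBy("age")
    #   df.select(df.name, sf.rank().over(w).alias("rank"))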
def __nonzero__(self) -> None:
raise PySparkValueError(
errorClass="CANNOT_CONVERT_COLUMN_INTO_BOOL",
messageParameters={},
)
__bool__ = __nonzero__
def __repr__(self) -> str:
return "Column<'%s'>" % self._jc.toString()
def _test() -> None:
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.column
    # Column inherits its docstrings from the parent class, but doctest cannot
    # detect inherited docstrings, so we run the parent class's doctests here directly.
globs = pyspark.sql.column.__dict__.copy()
spark = (
SparkSession.builder.master("local[4]").appName("sql.classic.column tests").getOrCreate()
)
globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.column,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
)
spark.stop()
if failure_count:
sys.exit(-1)
if __name__ == "__main__":
_test()