#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from itertools import chain
from typing import cast, Any, Callable, Union

import pandas as pd
import numpy as np
from pandas.api.types import is_list_like, CategoricalDtype

from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
from pyspark.pandas.base import column_op, IndexOpsMixin
from pyspark.pandas.data_type_ops.base import DataTypeOps
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.typedef import pandas_on_spark_type
from pyspark.sql import functions as F
from pyspark.sql.column import Column

class CategoricalOps(DataTypeOps):
    """
    The class for binary operations of pandas-on-Spark objects with categorical types.
    """

    @property
    def pretty_name(self) -> str:
        return "categoricals"

    def restore(self, col: pd.Series) -> pd.Series:
        """Restore column when to_pandas."""
        # Nulls come back from Spark as NaN; -1 is the code that from_codes treats as missing.
        return pd.Series(
            pd.Categorical.from_codes(
                col.replace(np.nan, -1).astype(int),
                categories=cast(CategoricalDtype, self.dtype).categories,
                ordered=cast(CategoricalDtype, self.dtype).ordered,
            )
        )
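    # A minimal illustrative sketch of the code-to-category round trip that
    # ``restore`` performs, assuming categories ["a", "b", "c"]; the -1 code used
    # for missing data comes back as NaN:
    #
    #   >>> list(pd.Categorical.from_codes([0, 2, -1], categories=["a", "b", "c"]))
    #   ['a', 'c', nan]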

    def prepare(self, col: pd.Series) -> pd.Series:
        """Prepare column when from_pandas."""
        return col.cat.codes
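    # A minimal illustrative sketch of what ``prepare`` keeps on the Spark side,
    # assuming a hypothetical pandas Series with categories ["a", "b", "c"]; only
    # the integer codes are stored, and missing values become -1:
    #
    #   >>> pser = pd.Series(["a", None, "c"], dtype=CategoricalDtype(["a", "b", "c"]))
    #   >>> list(pser.cat.codes)
    #   [0, -1, 2]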

    def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
        dtype, _ = pandas_on_spark_type(dtype)

        if isinstance(dtype, CategoricalDtype) and cast(CategoricalDtype, dtype).categories is None:
            # A CategoricalDtype without explicit categories keeps the existing categories,
            # so no cast is needed.
            return index_ops.copy()

        return _to_cat(index_ops).astype(dtype)
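    # A minimal illustrative sketch of ``astype`` from the user's side, assuming a
    # running SparkSession and the imports below (variable names are only for
    # illustration):
    #
    #   >>> import pyspark.pandas as ps
    #   >>> psser = ps.from_pandas(
    #   ...     pd.Series(["a", "b", "c"], dtype=CategoricalDtype(["a", "b", "c"]))
    #   ... )
    #   >>> psser.astype(str)         # codes are mapped back to categories, then cast
    #   >>> psser.astype("category")  # dtype without categories: returns a copy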

    def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
        return _compare(left, right, Column.__eq__, is_equality_comparison=True)

    def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
        return _compare(left, right, Column.__ne__, is_equality_comparison=True)

    def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
        return _compare(left, right, Column.__lt__)

    def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
        return _compare(left, right, Column.__le__)

    def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
        return _compare(left, right, Column.__gt__)

    def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
        return _compare(left, right, Column.__ge__)
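    # A minimal illustrative sketch of how these operators behave, assuming two
    # pandas-on-Spark Series ``left`` and ``right`` (hypothetical names) that share
    # the same CategoricalDtype; all six methods dispatch to ``_compare`` below:
    #
    #   >>> left == right   # allowed for ordered and unordered categoricals
    #   >>> left < right    # raises TypeError unless the dtype is ordered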


def _compare(
    left: IndexOpsLike,
    right: Any,
    f: Callable[..., Column],
    *,
    is_equality_comparison: bool = False,
) -> SeriesOrIndex:
    """
    Compare a Categorical operand `left` to `right` with the given Spark Column function.

    Parameters
    ----------
    left : A Categorical operand
    right : The other operand to compare with
    f : The Spark Column function to apply
    is_equality_comparison : True if it is an equality comparison, i.e. == or !=.
        False by default.

    Returns
    -------
    SeriesOrIndex
    """
    if isinstance(right, IndexOpsMixin) and isinstance(right.dtype, CategoricalDtype):
        if not is_equality_comparison:
            if not cast(CategoricalDtype, left.dtype).ordered:
                raise TypeError("Unordered Categoricals can only compare equality or not.")
        # Check if categoricals have the same dtype, same categories, and same ordered
        if hash(left.dtype) != hash(right.dtype):
            raise TypeError("Categoricals can only be compared if 'categories' are the same.")
        if cast(CategoricalDtype, left.dtype).ordered:
            # Same ordered dtype: the codes share one ordering, so compare the codes directly.
            return column_op(f)(left, right)
        else:
            # Unordered dtypes are equal regardless of category order, so compare the
            # underlying category values instead of the codes.
            return column_op(f)(_to_cat(left), _to_cat(right))
    elif not is_list_like(right):
        categories = cast(CategoricalDtype, left.dtype).categories
        if right not in categories:
            raise TypeError(
                "Cannot compare a Categorical with a scalar, which is not a category."
            )
        # Translate the scalar category to its code and compare on the code column.
        right_code = categories.get_loc(right)
        return column_op(f)(left, right_code)
    else:
        raise TypeError("Cannot compare a Categorical with the given type.")


def _to_cat(index_ops: IndexOpsLike) -> IndexOpsLike:
    """Map the stored integer category codes back to the category values."""
    categories = cast(CategoricalDtype, index_ops.dtype).categories
    if len(categories) == 0:
        scol = SF.lit(None)
    else:
        kvs = chain(
            *[(SF.lit(code), SF.lit(category)) for code, category in enumerate(categories)]
        )
        map_scol = F.create_map(*kvs)
        scol = map_scol[index_ops.spark.column]
    return index_ops._with_new_scol(scol)
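# A minimal illustrative sketch of the literal map ``_to_cat`` builds, assuming
# categories ["a", "b"]; indexing the map with the code column yields the category
# values, and codes with no entry (such as -1 for missing data) look up to NULL:
#
#   >>> F.create_map(SF.lit(0), SF.lit("a"), SF.lit(1), SF.lit("b"))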