| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| from pyspark.sql import Column, functions as F, is_remote |
| |
| from typing import Union, TYPE_CHECKING |
| |
| if TYPE_CHECKING: |
| from pyspark.sql._typing import ColumnOrName |
| |
| """ |
| Internal Spark functions used in Pandas API on Spark & PySpark Native Plotting. |
| """ |
| |
| |
class InternalFunction:
    """Namespace for internal Spark SQL functions.

    These helpers back Pandas API on Spark and PySpark native plotting. Each
    wrapper dispatches by name to a JVM-side function registered in
    ``PythonSQLUtils.internalFn`` (classic mode) or through the Spark Connect
    expression builder (remote mode).
    """

    @staticmethod
    def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> Column:
        """Invoke the internal function ``name`` over ``cols`` and return a Column.

        Routes through Spark Connect when running remotely; otherwise calls
        ``PythonSQLUtils.internalFn`` on the active JVM gateway.
        """
        if is_remote():
            # Spark Connect path: build the expression client-side.
            from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns

            return _invoke_function_over_columns(name, *cols)

        # Classic path: go through the Py4J gateway on the active context.
        from pyspark import SparkContext
        from pyspark.sql.classic.column import Column, _to_seq, _to_java_column

        ctx = SparkContext._active_spark_context
        jc = ctx._jvm.PythonSQLUtils.internalFn(  # type: ignore
            name, _to_seq(ctx, cols, _to_java_column)  # type: ignore
        )
        return Column(jc)

    @staticmethod
    def timestamp_ntz_to_long(col: Column) -> Column:
        """Convert a TIMESTAMP_NTZ column to its long (microseconds) representation."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("timestamp_ntz_to_long", col)

    @staticmethod
    def product(col: Column, dropna: bool) -> Column:
        """Pandas-style product aggregate; ``dropna`` controls NULL handling."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("pandas_product", col, F.lit(dropna))

    @staticmethod
    def stddev(col: Column, ddof: int) -> Column:
        """Pandas-style standard deviation with ``ddof`` delta degrees of freedom."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("pandas_stddev", col, F.lit(ddof))

    @staticmethod
    def var(col: Column, ddof: int) -> Column:
        """Pandas-style variance with ``ddof`` delta degrees of freedom."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("pandas_var", col, F.lit(ddof))

    @staticmethod
    def skew(col: Column) -> Column:
        """Pandas-style sample skewness."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("pandas_skew", col)

    @staticmethod
    def kurt(col: Column) -> Column:
        """Pandas-style sample kurtosis."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("pandas_kurt", col)

    @staticmethod
    def mode(col: Column, dropna: bool) -> Column:
        """Pandas-style mode aggregate; ``dropna`` controls NULL handling."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("pandas_mode", col, F.lit(dropna))

    @staticmethod
    def covar(col1: Column, col2: Column, ddof: int) -> Column:
        """Pandas-style covariance of two columns with ``ddof`` degrees of freedom."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("pandas_covar", col1, col2, F.lit(ddof))

    @staticmethod
    def ewm(col: Column, alpha: float, ignorena: bool) -> Column:
        """Exponentially weighted moving aggregate with smoothing factor ``alpha``."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("ewm", col, F.lit(alpha), F.lit(ignorena))

    @staticmethod
    def null_index(col: Column) -> Column:
        """Internal ``null_index`` function over ``col``."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("null_index", col)

    @staticmethod
    def distributed_id() -> Column:
        """Distributed (non-sequential) id column."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("distributed_id")

    @staticmethod
    def distributed_sequence_id() -> Column:
        """Distributed sequential id column (consecutive across partitions)."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("distributed_sequence_id")

    @staticmethod
    def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
        """Collect the top ``num`` values of ``col``; ``reverse`` flips the order."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("collect_top_k", col, F.lit(num), F.lit(reverse))

    @staticmethod
    def array_binary_search(col: Column, value: Column) -> Column:
        """Binary-search ``value`` in the sorted array column ``col``."""
        invoke = InternalFunction._invoke_internal_function_over_columns
        return invoke("array_binary_search", col, value)

    @staticmethod
    def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
        """Build an interval of ``e`` units, where ``unit`` is one of
        YEAR, MONTH, WEEK, DAY, HOUR, MINUTE, SECOND.

        Raises ``KeyError`` for an unsupported unit name.
        """
        # Map the SQL-style unit name onto the matching F.make_interval kwarg.
        kwarg = {
            "YEAR": "years",
            "MONTH": "months",
            "WEEK": "weeks",
            "DAY": "days",
            "HOUR": "hours",
            "MINUTE": "mins",
            "SECOND": "secs",
        }[unit]
        return F.make_interval(**{kwarg: F.lit(e)})

    @staticmethod
    def vector_get(vec: Column, idx: Column) -> Column:
        """Element access into an ML vector UDT column, with bounds checking.

        Dense vectors are indexed directly; sparse vectors are searched via
        ``array_binary_search`` over their indices, with absent positions
        reading as 0.0. Out-of-range indices raise an error at execution time.
        """
        raw = F.unwrap_udt(vec)
        dense = raw.getField("type") == F.lit(1)  # type tag 1 marks a dense vector
        vals = raw.getField("values")
        # Dense vectors carry their length implicitly; sparse ones store it.
        length = F.when(dense, F.array_size(vals)).otherwise(raw.getField("size"))
        pos = InternalFunction.array_binary_search(raw.getField("indices"), idx)
        elem = (
            F.when(dense, F.get(vals, idx))
            .when(pos >= 0, F.get(vals, pos))  # stored sparse entry
            .otherwise(F.lit(0.0))  # implicit sparse zero
        )

        return F.when((0 <= idx) & (idx < length), elem).otherwise(
            F.raise_error(
                F.printf(F.lit("Vector index must be in [0, %s), but got %s"), length, idx)
            )
        )

    # The main difference between this function and `array_max` + `array_position`:
    # this function ignores NaN/NULL values, while `array_max` treats NaN values
    # as the largest values.
    @staticmethod
    def array_argmax(arr: Column) -> Column:
        """Index of the maximum non-NaN, non-NULL element of ``arr`` (-1 if none)."""

        def step(state: Column, elem: Column) -> Column:
            # state fields: v = best value, i = its index, j = current position.
            best = state.getField("v")
            best_idx = state.getField("i")
            pos = state.getField("j")
            take = (~elem.isNaN()) & (~elem.isNull()) & (elem > best)
            return F.when(
                take,
                F.struct(elem.alias("v"), pos.alias("i"), pos + 1),
            ).otherwise(F.struct(best.alias("v"), best_idx.alias("i"), pos + 1))

        zero = F.struct(
            F.lit(float("-inf")).alias("v"),  # max value
            F.lit(-1).alias("i"),  # index of max value
            F.lit(0).alias("j"),  # current index
        )
        return F.aggregate(arr, zero, step, lambda state: state.getField("i"))