#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
User-defined function related classes and functions
"""
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

import warnings
import sys
import functools
from typing import cast, Callable, Any, List, TYPE_CHECKING, Optional, Union
from pyspark.util import PythonEvalType
from pyspark.sql.connect.expressions import (
ColumnReference,
CommonInlineUserDefinedFunction,
Expression,
NamedArgumentExpression,
PythonUDF,
)
from pyspark.sql.connect.column import Column
from pyspark.sql.types import DataType, StringType, _parse_datatype_string
from pyspark.sql.udf import (
UDFRegistration as PySparkUDFRegistration,
UserDefinedFunction as PySparkUserDefinedFunction,
)
from pyspark.sql.pandas.utils import (
    require_minimum_pandas_version,
    require_minimum_pyarrow_version,
)
from pyspark.errors import PySparkTypeError, PySparkRuntimeError

if TYPE_CHECKING:
from pyspark.sql.connect._typing import (
ColumnOrName,
DataTypeOrString,
UserDefinedFunctionLike,
)
from pyspark.sql.connect.session import SparkSession


def _create_py_udf(
f: Callable[..., Any],
returnType: "DataTypeOrString",
useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
    is_arrow_enabled = False
    if useArrow is None:
        try:
from pyspark.sql.connect.session import SparkSession
session = SparkSession.active()
is_arrow_enabled = (
str(session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled")).lower()
== "true"
)
except PySparkRuntimeError as e:
if e.getCondition() == "NO_ACTIVE_OR_DEFAULT_SESSION":
pass # Just uses the default if no session found.
else:
raise e
else:
is_arrow_enabled = useArrow
if is_arrow_enabled:
try:
require_minimum_pandas_version()
require_minimum_pyarrow_version()
except ImportError:
is_arrow_enabled = False
warnings.warn(
"Arrow optimization failed to enable because PyArrow or Pandas is not installed. "
"Falling back to a non-Arrow-optimized UDF.",
RuntimeWarning,
)
eval_type: Optional[int] = None
if useArrow is None:
# If the user doesn't explicitly set useArrow
from pyspark.sql.pandas.typehints import infer_eval_type_from_func
try:
# Try to infer the eval type from type hints
eval_type = infer_eval_type_from_func(f)
except Exception:
warnings.warn("Cannot infer the eval type from type hints. ", UserWarning)
if eval_type is None:
if is_arrow_enabled:
# Arrow optimized Python UDF
eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
else:
# Fallback to Regular Python UDF
eval_type = PythonEvalType.SQL_BATCHED_UDF
return _create_udf(f, returnType, eval_type)
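

# A minimal sketch (not executed here) of how the resolution above plays out for
# callers of :func:`pyspark.sql.functions.udf`; it assumes an active Connect
# session, and the lambdas and return types are purely illustrative:
#
#   from pyspark.sql.functions import udf
#
#   plus_one = udf(lambda x: x + 1, "int")
#   # useArrow=None: consults spark.sql.execution.pythonUDF.arrow.enabled and the
#   # function's type hints to pick SQL_ARROW_BATCHED_UDF or SQL_BATCHED_UDF.
#
#   plus_one_arrow = udf(lambda x: x + 1, "int", useArrow=True)
#   # Forces Arrow optimization, falling back with a RuntimeWarning when PyArrow
#   # or Pandas is not installed.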


def _create_udf(
f: Callable[..., Any],
returnType: "DataTypeOrString",
evalType: int,
name: Optional[str] = None,
deterministic: bool = True,
) -> "UserDefinedFunctionLike":
# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(
f, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
)
return udf_obj._wrapped()


class UserDefinedFunction:
"""
    User defined function in Python

    Notes
-----
The constructor of this class is not supposed to be directly called.
Use :meth:`pyspark.sql.functions.udf` or :meth:`pyspark.sql.functions.pandas_udf`
to create this instance.
"""

    def __init__(
self,
func: Callable[..., Any],
returnType: "DataTypeOrString" = StringType(),
name: Optional[str] = None,
evalType: int = PythonEvalType.SQL_BATCHED_UDF,
deterministic: bool = True,
):
if not callable(func):
raise PySparkTypeError(
errorClass="NOT_CALLABLE",
messageParameters={"arg_name": "func", "arg_type": type(func).__name__},
)
if not isinstance(returnType, (DataType, str)):
raise PySparkTypeError(
errorClass="NOT_DATATYPE_OR_STR",
messageParameters={
"arg_name": "returnType",
"arg_type": type(returnType).__name__,
},
)
if not isinstance(evalType, int):
raise PySparkTypeError(
errorClass="NOT_INT",
messageParameters={"arg_name": "evalType", "arg_type": type(evalType).__name__},
)
self.func = func
self._returnType = returnType
self._returnType_placeholder: Optional[DataType] = None
self._name = name or (
func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
)
self.evalType = evalType
self.deterministic = deterministic

    @property
def returnType(self) -> DataType:
# Make sure this is called after Connect Session is initialized.
        # ``_parse_datatype_string`` accesses the Connect Server to parse a DDL-formatted string.
if self._returnType_placeholder is None:
if isinstance(self._returnType, DataType):
self._returnType_placeholder = self._returnType
else:
self._returnType_placeholder = _parse_datatype_string(self._returnType)
PySparkUserDefinedFunction._check_return_type(self._returnType_placeholder, self.evalType)
return self._returnType_placeholder
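
    # For illustration (a sketch; parsing a DDL string happens server-side, so a
    # live Connect session is required):
    #
    #   UserDefinedFunction(len, returnType="int").returnType
    #   # -> IntegerType(), parsed lazily on first access and cached in
    #   #    ``_returnType_placeholder``.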

    def _build_common_inline_user_defined_function(
self, *args: "ColumnOrName", **kwargs: "ColumnOrName"
) -> CommonInlineUserDefinedFunction:
def to_expr(col: "ColumnOrName") -> Expression:
if isinstance(col, Column):
return col._expr
else:
return ColumnReference(col) # type: ignore[arg-type]
arg_exprs: List[Expression] = [to_expr(arg) for arg in args] + [
NamedArgumentExpression(key, to_expr(value)) for key, value in kwargs.items()
]
py_udf = PythonUDF(
output_type=self.returnType,
eval_type=self.evalType,
func=self.func,
python_ver="%d.%d" % sys.version_info[:2],
)
return CommonInlineUserDefinedFunction(
function_name=self._name,
function=py_udf,
deterministic=self.deterministic,
arguments=arg_exprs,
)
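
    # Sketch of the argument mapping above (the column name "s" and the literal
    # are hypothetical):
    #
    #   from pyspark.sql.connect.functions import col, lit
    #
    #   repeat = UserDefinedFunction(lambda s, n: s * n, returnType="string")
    #   expr = repeat._build_common_inline_user_defined_function(col("s"), n=lit(2))
    #   # arguments: [<expression for col("s")>,
    #   #             NamedArgumentExpression("n", <expression for lit(2)>)]
    #   # function:  a PythonUDF carrying the callable and the client's
    #   #            "major.minor" Python version.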

    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
return Column(self._build_common_inline_user_defined_function(*args, **kwargs))

    # This function is for improving the online help system in the interactive interpreter.
# For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
# argument annotation. (See: SPARK-19161)
def _wrapped(self) -> "UserDefinedFunctionLike":
"""
Wrap this udf with a function and attach docstring from func
"""
        # It is possible for a callable instance without a __name__ and/or
        # __module__ attribute to be wrapped here, for example, functools.partial.
        # In that case, we should avoid copying those attributes from the wrapped
        # function onto the wrapper. So, we exclude these attribute names from the
        # default set to assign and then set them manually after wrapping.
assignments = tuple(
a for a in functools.WRAPPER_ASSIGNMENTS if a != "__name__" and a != "__module__"
)
@functools.wraps(self.func, assigned=assignments)
def wrapper(*args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
return self(*args, **kwargs)
wrapper.__name__ = self._name
wrapper.__module__ = (
self.func.__module__
if hasattr(self.func, "__module__")
else self.func.__class__.__module__
)
wrapper.func = self.func # type: ignore[attr-defined]
wrapper.returnType = self.returnType # type: ignore[attr-defined]
wrapper.evalType = self.evalType # type: ignore[attr-defined]
wrapper.deterministic = self.deterministic # type: ignore[attr-defined]
wrapper.asNondeterministic = functools.wraps( # type: ignore[attr-defined]
self.asNondeterministic
)(lambda: self.asNondeterministic()._wrapped())
wrapper._unwrapped = self # type: ignore[attr-defined]
return wrapper # type: ignore[return-value]
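
    # For example (a sketch; resolving the "int" return type needs a live
    # session), the wrapper keeps the callable's identity and docstring so that
    # ``help()`` stays useful:
    #
    #   def my_len(s):
    #       """Return the length of s."""
    #       return len(s)
    #
    #   wrapped = UserDefinedFunction(my_len, "int")._wrapped()
    #   wrapped.__name__   # 'my_len'
    #   wrapped.__doc__    # 'Return the length of s.'
    #   wrapped("name")    # Column applying the UDF to column "name"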

    def asNondeterministic(self) -> "UserDefinedFunction":
"""
        Updates UserDefinedFunction to nondeterministic.

        .. versionadded:: 3.4.0
"""
self.deterministic = False
return self
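
    # Usage sketch (assumes ``udf`` from pyspark.sql.functions): marking a UDF
    # nondeterministic tells the optimizer not to duplicate, reorder, or prune
    # its invocations:
    #
    #   import random
    #   rand_udf = udf(lambda: random.random(), "double").asNondeterministic()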


class UDFRegistration:
"""
Wrapper for user-defined function registration.
"""

    def __init__(self, sparkSession: "SparkSession"):
self.sparkSession = sparkSession

    def register(
self,
name: str,
f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
returnType: Optional["DataTypeOrString"] = None,
) -> "UserDefinedFunctionLike":
        # Check whether the input is an already-built user-defined function or a
        # plain Python function.
if hasattr(f, "asNondeterministic"):
if returnType is not None:
raise PySparkTypeError(
errorClass="CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF",
messageParameters={"arg_name": "f", "return_type": str(returnType)},
)
f = cast("UserDefinedFunctionLike", f)
if f.evalType not in [
PythonEvalType.SQL_BATCHED_UDF,
PythonEvalType.SQL_ARROW_BATCHED_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_ARROW_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
]:
raise PySparkTypeError(
errorClass="INVALID_UDF_EVAL_TYPE",
messageParameters={
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_ARROW_UDF, "
"SQL_SCALAR_PANDAS_ITER_UDF, SQL_SCALAR_ARROW_ITER_UDF, "
"SQL_GROUPED_AGG_PANDAS_UDF or SQL_GROUPED_AGG_ARROW_UDF"
},
)
self.sparkSession._client.register_udf(
f.func, f.returnType, name, f.evalType, f.deterministic
)
return f
else:
if returnType is None:
returnType = StringType()
py_udf = _create_udf(
f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name
)
self.sparkSession._client.register_udf(py_udf.func, returnType, name)
return py_udf

    register.__doc__ = PySparkUDFRegistration.register.__doc__
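
    # Illustrative usage (assumes an active Connect session ``spark``; the names
    # are placeholders):
    #
    #   spark.udf.register("plus_one", lambda x: x + 1, "int")
    #   spark.sql("SELECT plus_one(41)")
    #   # Registers a fresh SQL_BATCHED_UDF under the given name.
    #
    #   slen = udf(lambda s: len(s), "int")
    #   spark.udf.register("slen", slen)
    #   # Re-registers an existing UDF; returnType must be omitted in this form,
    #   # otherwise CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF is raised.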

    def registerJavaFunction(
self,
name: str,
javaClassName: str,
returnType: Optional["DataTypeOrString"] = None,
) -> None:
self.sparkSession._client.register_java(name, javaClassName, returnType)

    registerJavaFunction.__doc__ = PySparkUDFRegistration.registerJavaFunction.__doc__

    def registerJavaUDAF(self, name: str, javaClassName: str) -> None:
self.sparkSession._client.register_java(name, javaClassName, aggregate=True)

    registerJavaUDAF.__doc__ = PySparkUDFRegistration.registerJavaUDAF.__doc__
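

# Illustrative usage of the Java registration paths above (the class names are
# placeholders for JVM classes that must exist on the server's classpath):
#
#   spark.udf.registerJavaFunction(
#       "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", "integer"
#   )
#   spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")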