| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| import abc |
| import collections |
| import functools |
| import inspect |
| from typing import Union, List, Type, Callable, TypeVar, Generic, Iterable |
| |
| from pyflink.java_gateway import get_gateway |
| from pyflink.metrics import MetricGroup |
| from pyflink.table import Expression |
| from pyflink.table.types import DataType, _to_java_type, _to_java_data_type |
| from pyflink.util import java_utils |
| |
# Public names of this module (the names exported by a star-import).
__all__ = [
    'FunctionContext',
    'AggregateFunction',
    'ScalarFunction',
    'TableFunction',
    'TableAggregateFunction',
    'udf',
    'udtf',
    'udaf',
    'udtaf',
]
| |
| |
class FunctionContext(object):
    """
    Used to obtain global runtime information about the context in which the
    user-defined function is executed. This class currently exposes the metric group
    it was created with.
    """

    def __init__(self, base_metric_group):
        """
        :param base_metric_group: the metric group to expose through
                                  :func:`get_metric_group`, or None if there is none.
        """
        self._base_metric_group = base_metric_group

    def get_metric_group(self) -> MetricGroup:
        """
        Returns the metric group for this parallel subtask.

        :return: the metric group this context was created with.
        :raises RuntimeError: if this context was created without a metric group.

        .. versionadded:: 1.11.0
        """
        metric_group = self._base_metric_group
        if metric_group is not None:
            return metric_group
        raise RuntimeError("Metric has not been enabled. You can enable "
                           "metric with the 'python.metric.enabled' configuration.")
| |
| |
class UserDefinedFunction(abc.ABC):
    """
    Base interface for user-defined function. It provides no-op lifecycle hooks
    (:func:`open` and :func:`close`) and a default determinism declaration that subclasses
    may override.

    .. versionadded:: 1.10.0
    """

    def open(self, function_context: FunctionContext):
        """
        Initialization method for the function. It is called before the actual working methods
        and thus suitable for one time setup work. The default implementation does nothing.

        :param function_context: the context of the function
        :type function_context: FunctionContext
        """
        return None

    def close(self):
        """
        Tear-down method for the user code. It is called after the last call to the main
        working methods. The default implementation does nothing.
        """
        return None

    def is_deterministic(self) -> bool:
        """
        Returns information about the determinism of the function's results.
        It returns true if and only if a call to this function is guaranteed to
        always return the same result given the same parameters. true is assumed by default.
        If the function is not pure functional like random(), date(), now(),
        this method must be overridden to return false.

        :return: the determinism of the function's results.
        """
        deterministic_by_default = True
        return deterministic_by_default
| |
| |
class ScalarFunction(UserDefinedFunction):
    """
    Base interface for user-defined scalar functions. A user-defined scalar function maps zero,
    one, or multiple scalar values to a new scalar value.

    .. versionadded:: 1.10.0
    """

    @abc.abstractmethod
    def eval(self, *args):
        """
        Method which defines the logic of the scalar function. Concrete subclasses have to
        override it.

        :param args: the input values of a single invocation.
        """
        return None
| |
| |
class TableFunction(UserDefinedFunction):
    """
    Base interface for user-defined table functions. A user-defined table function creates
    zero, one, or multiple rows from its input values.

    .. versionadded:: 1.11.0
    """

    @abc.abstractmethod
    def eval(self, *args):
        """
        Method which defines the logic of the table function. Concrete subclasses have to
        override it.

        :param args: the input values of a single invocation.
        """
        return None
| |
| |
# Type variables used by the aggregate function interfaces below:
# T is the type of the aggregation result, ACC the type of the accumulator.
T = TypeVar('T')
ACC = TypeVar('ACC')


class ImperativeAggregateFunction(UserDefinedFunction, Generic[T, ACC]):
    """
    Base interface for user-defined aggregate function and table aggregate function.

    This class is used for unified handling of imperative aggregating functions. Concrete
    implementations should extend from :class:`~pyflink.table.AggregateFunction` or
    :class:`~pyflink.table.TableAggregateFunction`.

    Only :func:`create_accumulator` and :func:`accumulate` are abstract; the remaining
    methods have default implementations that do nothing and return None.

    .. versionadded:: 1.13.0
    """

    @abc.abstractmethod
    def create_accumulator(self) -> ACC:
        """
        Creates and initializes the accumulator for this AggregateFunction.

        :return: the accumulator with the initial value
        """
        return None

    @abc.abstractmethod
    def accumulate(self, accumulator: ACC, *args):
        """
        Processes the input values and updates the provided accumulator instance.

        :param accumulator: the accumulator which contains the current aggregated results
        :param args: the input value (usually obtained from new arrived data)
        """
        return None

    def retract(self, accumulator: ACC, *args):
        """
        Retracts the input values from the accumulator instance. The current design assumes the
        inputs are the values that have been previously accumulated. The default implementation
        does nothing.

        :param accumulator: the accumulator which contains the current aggregated results
        :param args: the input value (usually obtained from new arrived data).
        """
        return None

    def merge(self, accumulator: ACC, accumulators):
        """
        Merges a group of accumulator instances into one accumulator instance. This method must be
        implemented for unbounded session window grouping aggregates and bounded grouping
        aggregates. The default implementation does nothing.

        :param accumulator: the accumulator which will keep the merged aggregate results. It should
                            be noted that the accumulator may contain the previous aggregated
                            results. Therefore user should not replace or clean this instance in the
                            custom merge method.
        :param accumulators: a group of accumulators that will be merged.
        """
        return None

    def get_result_type(self) -> DataType:
        """
        Returns the DataType of the AggregateFunction's result. The default implementation
        returns None.

        :return: The :class:`~pyflink.table.types.DataType` of the AggregateFunction's result.

        """
        return None

    def get_accumulator_type(self) -> DataType:
        """
        Returns the DataType of the AggregateFunction's accumulator. The default implementation
        returns None.

        :return: The :class:`~pyflink.table.types.DataType` of the AggregateFunction's accumulator.

        """
        return None
| |
| |
class AggregateFunction(ImperativeAggregateFunction):
    """
    Base interface for user-defined aggregate function. A user-defined aggregate function maps
    scalar values of multiple rows to a new scalar value.

    .. versionadded:: 1.12.0
    """

    @abc.abstractmethod
    def get_value(self, accumulator: ACC) -> T:
        """
        Called every time when an aggregation result should be materialized. The returned value
        could be either an early and incomplete result (periodically emitted as data arrives) or
        the final result of the aggregation. Concrete subclasses have to override it.

        :param accumulator: the accumulator which contains the current intermediate results
        :return: the aggregation result
        """
        return None
| |
| |
class TableAggregateFunction(ImperativeAggregateFunction):
    """
    Base class for a user-defined table aggregate function. A user-defined table aggregate function
    maps scalar values of multiple rows to zero, one, or multiple rows (or structured types). If an
    output record consists of only one field, the structured record can be omitted, and a scalar
    value can be emitted that will be implicitly wrapped into a row by the runtime.

    .. versionadded:: 1.13.0
    """

    @abc.abstractmethod
    def emit_value(self, accumulator: ACC) -> Iterable[T]:
        """
        Called every time when an aggregation result should be materialized. The returned value
        could be either an early and incomplete result (periodically emitted as data arrives) or the
        final result of the aggregation. Concrete subclasses have to override it.

        :param accumulator: the accumulator which contains the current aggregated results.
        :return: multiple aggregated result
        """
        return None
| |
| |
class DelegatingScalarFunction(ScalarFunction):
    """
    Adapter which exposes a lambda expression or a plain python function as a
    :class:`ScalarFunction`. It's for internal use only.
    """

    def __init__(self, func):
        """
        :param func: the callable that :func:`eval` forwards to.
        """
        self.func = func

    def eval(self, *args):
        """
        Calls the wrapped callable with the given arguments and returns its result.
        """
        delegate = self.func
        return delegate(*args)
| |
| |
class DelegationTableFunction(TableFunction):
    """
    Adapter which exposes a lambda expression or a plain python function as a
    :class:`TableFunction`. It's for internal use only.
    """

    def __init__(self, func):
        """
        :param func: the callable that :func:`eval` forwards to.
        """
        self.func = func

    def eval(self, *args):
        """
        Calls the wrapped callable with the given arguments and returns its result.
        """
        delegate = self.func
        return delegate(*args)
| |
| |
class DelegatingPandasAggregateFunction(AggregateFunction):
    """
    Adapter which exposes a lambda expression or a plain python function as a pandas
    :class:`AggregateFunction`. It's for internal use only.

    The accumulator is a list which collects the results of the wrapped callable; the
    aggregation result is the first element of that list.
    """

    def __init__(self, func):
        """
        :param func: the callable that computes the aggregation result from the inputs.
        """
        self.func = func

    def get_value(self, accumulator):
        """
        Returns the first result stored in the accumulator.
        """
        first_result = accumulator[0]
        return first_result

    def create_accumulator(self):
        """
        Creates a new, empty list to be used as accumulator.
        """
        return list()

    def accumulate(self, accumulator, *args):
        """
        Calls the wrapped callable with the inputs and appends its result to the accumulator.
        """
        result = self.func(*args)
        accumulator.append(result)
| |
| |
class PandasAggregateFunctionWrapper(object):
    """
    Wrapper for Pandas Aggregate function. It exposes an aggregate function through an
    ``open`` / ``eval`` / ``close`` interface, where a single ``eval`` call runs a complete
    aggregation.
    """

    def __init__(self, func: AggregateFunction):
        """
        :param func: the aggregate function to wrap.
        """
        self.func = func

    def open(self, function_context: FunctionContext):
        """
        Forwards the initialization call to the wrapped aggregate function.
        """
        self.func.open(function_context)

    def eval(self, *args):
        """
        Creates a fresh accumulator, accumulates the given inputs into it once and returns
        the value the wrapped aggregate function computes from that accumulator.
        """
        aggregate = self.func
        acc = aggregate.create_accumulator()
        aggregate.accumulate(acc, *args)
        return aggregate.get_value(acc)

    def close(self):
        """
        Forwards the tear-down call to the wrapped aggregate function.
        """
        self.func.close()
| |
| |
class UserDefinedFunctionWrapper(object):
    """
    Base Wrapper for Python user-defined function. It handles things like converting lambda
    functions to user-defined functions, creating the Java user-defined function representation,
    etc. It's for internal use only.

    Subclasses provide :func:`_create_delegate_function` and :func:`_create_judf`.
    """

    def __init__(self, func, input_types, func_type, deterministic=None, name=None):
        """
        Validates the given arguments and stores them on the wrapper.

        :param func: a :class:`UserDefinedFunction` instance or any other callable object.
                     Classes are rejected, an instance has to be given.
        :param input_types: optional, a single DataType or an iterable of DataTypes.
        :param func_type: the type of the python function, "general" or "pandas".
        :param deterministic: optional, the determinism of the function's results. If func is a
                              :class:`UserDefinedFunction`, it has to agree with
                              ``func.is_deterministic()``.
        :param name: optional, the function name. Defaults to ``func.__name__`` if present,
                     otherwise to the name of the class of func.
        :raises TypeError: if func is a class or is not callable, or if input_types contains
                           something which is not a DataType.
        :raises ValueError: if deterministic contradicts ``func.is_deterministic()``.
        """
        if inspect.isclass(func) or (
                not isinstance(func, UserDefinedFunction) and not callable(func)):
            raise TypeError(
                "Invalid function: not a function or callable (__call__ is not defined): {0}"
                .format(type(func)))

        if input_types is not None:
            from pyflink.table.types import RowType
            if not isinstance(input_types, collections.abc.Iterable) \
                    or isinstance(input_types, RowType):
                # a single type (a RowType is iterable but still denotes one type)
                input_types = [input_types]
            else:
                # Materialize the iterable before validating it: a one-shot iterable such as
                # a generator would otherwise be exhausted by the validation loop below and
                # yield no input types when it is converted to Java types later on.
                input_types = list(input_types)

            for input_type in input_types:
                if not isinstance(input_type, DataType):
                    raise TypeError(
                        "Invalid input_type: input_type should be DataType but contains {}".format(
                            input_type))

        self._func = func
        self._input_types = input_types
        self._name = name or (
            func.__name__ if hasattr(func, '__name__') else func.__class__.__name__)

        if deterministic is not None and isinstance(func, UserDefinedFunction) and deterministic \
                != func.is_deterministic():
            raise ValueError("Inconsistent deterministic: {} and {}".format(
                deterministic, func.is_deterministic()))

        # default deterministic is True
        self._deterministic = deterministic if deterministic is not None else (
            func.is_deterministic() if isinstance(func, UserDefinedFunction) else True)
        self._func_type = func_type
        # cache for the Java function object, created lazily by _java_user_defined_function
        self._judf_placeholder = None
        self._takes_row_as_input = False

    def __call__(self, *args) -> Expression:
        """
        Creates a call expression which applies this function to the given arguments.
        """
        from pyflink.table import expressions as expr
        return expr.call(self, *args)

    def alias(self, *alias_names: str):
        """
        Stores the given alias names on this wrapper and returns the wrapper itself.

        Note that the ``_alias_names`` attribute only exists once this method has been called.
        """
        self._alias_names = alias_names
        return self

    def _set_takes_row_as_input(self):
        """
        Marks this function as taking a row as input and returns the wrapper itself.
        """
        self._takes_row_as_input = True
        return self

    def _java_user_defined_function(self):
        """
        Returns the Java representation of this function. It is created on the first call and
        cached for all subsequent calls.

        :raises TypeError: if the func_type of this wrapper is neither "general" nor "pandas".
        """
        if self._judf_placeholder is None:
            gateway = get_gateway()

            def get_python_function_kind():
                JPythonFunctionKind = gateway.jvm.org.apache.flink.table.functions.python. \
                    PythonFunctionKind
                if self._func_type == "general":
                    return JPythonFunctionKind.GENERAL
                elif self._func_type == "pandas":
                    return JPythonFunctionKind.PANDAS
                else:
                    raise TypeError("Unsupported func_type: %s." % self._func_type)

            if self._input_types is not None:
                j_input_types = java_utils.to_jarray(
                    gateway.jvm.TypeInformation, [_to_java_type(i) for i in self._input_types])
            else:
                j_input_types = None
            j_function_kind = get_python_function_kind()
            func = self._func
            if not isinstance(self._func, UserDefinedFunction):
                # plain callables are wrapped into a UserDefinedFunction before serialization
                func = self._create_delegate_function()

            import cloudpickle
            serialized_func = cloudpickle.dumps(func)
            self._judf_placeholder = \
                self._create_judf(serialized_func, j_input_types, j_function_kind)
        return self._judf_placeholder

    def _create_delegate_function(self) -> UserDefinedFunction:
        """
        Hook for subclasses: wraps the plain callable of this wrapper into a
        :class:`UserDefinedFunction`. The base implementation returns None.
        """
        pass

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        """
        Hook for subclasses: creates the Java function object from the serialized python
        function. The base implementation returns None.
        """
        pass
| |
| |
class UserDefinedScalarFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper for Python user-defined scalar function.
    """

    def __init__(self, func, input_types, result_type, func_type, deterministic, name):
        """
        :param func: the scalar function instance or plain callable to wrap.
        :param input_types: optional, the input data types.
        :param result_type: the result data type, has to be a DataType.
        :param func_type: the type of the python function, "general" or "pandas".
        :param deterministic: the determinism of the function's results.
        :param name: the function name.
        :raises TypeError: if result_type is not a DataType.
        """
        super().__init__(func, input_types, func_type, deterministic, name)

        if isinstance(result_type, DataType):
            self._result_type = result_type
            self._judf_placeholder = None
        else:
            raise TypeError(
                "Invalid returnType: returnType should be DataType but is {}".format(result_type))

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        """
        Creates the Java ``PythonScalarFunction`` for the serialized python function.
        """
        jvm = get_gateway().jvm
        java_result_type = _to_java_type(self._result_type)
        scalar_function_class = jvm.org.apache.flink.table.functions.python.PythonScalarFunction
        return scalar_function_class(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            java_result_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env())

    def _create_delegate_function(self) -> UserDefinedFunction:
        """
        Wraps the plain callable of this wrapper into a :class:`DelegatingScalarFunction`.
        """
        plain_callable = self._func
        return DelegatingScalarFunction(plain_callable)
| |
| |
class UserDefinedTableFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper for Python user-defined table function.
    """

    def __init__(self, func, input_types, result_types, deterministic=None, name=None):
        """
        :param func: the table function instance or plain callable to wrap.
        :param input_types: optional, the input data types.
        :param result_types: a single DataType or an iterable of DataTypes describing the
                             fields of the produced rows.
        :param deterministic: optional, the determinism of the function's results.
        :param name: optional, the function name.
        :raises TypeError: if result_types contains something which is not a DataType.
        """
        super(UserDefinedTableFunctionWrapper, self).__init__(
            func, input_types, "general", deterministic, name)

        from pyflink.table.types import RowType
        if not isinstance(result_types, collections.abc.Iterable) \
                or isinstance(result_types, RowType):
            # a single type (a RowType is iterable but still denotes one type)
            result_types = [result_types]
        else:
            # Materialize the iterable before validating it: a one-shot iterable such as a
            # generator would otherwise be exhausted by the validation loop below and yield
            # no result types when the Java function is created.
            result_types = list(result_types)

        for result_type in result_types:
            if not isinstance(result_type, DataType):
                raise TypeError(
                    "Invalid result_type: result_type should be DataType but contains {}".format(
                        result_type))

        self._result_types = result_types

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        """
        Creates the Java ``PythonTableFunction`` for the serialized python function. The result
        types are combined into a Java ``RowTypeInfo``.
        """
        gateway = get_gateway()
        j_result_types = java_utils.to_jarray(gateway.jvm.TypeInformation,
                                              [_to_java_type(i) for i in self._result_types])
        j_result_type = gateway.jvm.org.apache.flink.api.java.typeutils.RowTypeInfo(j_result_types)
        PythonTableFunction = gateway.jvm \
            .org.apache.flink.table.functions.python.PythonTableFunction
        j_table_function = PythonTableFunction(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            j_result_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env())
        return j_table_function

    def _create_delegate_function(self) -> UserDefinedFunction:
        """
        Wraps the plain callable of this wrapper into a :class:`DelegationTableFunction`.
        """
        return DelegationTableFunction(self._func)
| |
| |
class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper for Python user-defined aggregate function or user-defined table aggregate function.
    """

    def __init__(self, func, input_types, result_type, accumulator_type, func_type,
                 deterministic, name, is_table_aggregate=False):
        """
        :param func: the aggregate function instance or plain callable to wrap.
        :param input_types: optional, the input data types.
        :param result_type: the result data type. If None, ``func.get_result_type()`` is used.
        :param accumulator_type: optional, the accumulator data type. If None and func_type is
                                 "general", ``func.get_accumulator_type()`` is used.
        :param func_type: the type of the python function, "general" or "pandas".
        :param deterministic: the determinism of the function's results.
        :param name: the function name.
        :param is_table_aggregate: whether func is a table aggregate function.
        :raises TypeError: if the result type or the accumulator type is not a DataType, or if
                           a MapType result is requested for a pandas function.
        """
        super(UserDefinedAggregateFunctionWrapper, self).__init__(
            func, input_types, func_type, deterministic, name)

        if accumulator_type is None and func_type == "general":
            accumulator_type = func.get_accumulator_type()
        if result_type is None:
            result_type = func.get_result_type()
        if not isinstance(result_type, DataType):
            raise TypeError(
                "Invalid returnType: returnType should be DataType but is {}".format(result_type))
        from pyflink.table.types import MapType
        if func_type == 'pandas' and isinstance(result_type, MapType):
            raise TypeError(
                "Invalid returnType: Pandas UDAF doesn't support DataType type {} currently"
                .format(result_type))
        if accumulator_type is not None and not isinstance(accumulator_type, DataType):
            raise TypeError(
                "Invalid accumulator_type: accumulator_type should be DataType but is {}".format(
                    accumulator_type))
        self._result_type = result_type
        self._accumulator_type = accumulator_type
        self._is_table_aggregate = is_table_aggregate

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        """
        Creates the Java ``PythonAggregateFunction`` or, for table aggregates, the Java
        ``PythonTableAggregateFunction`` for the serialized python function.

        For pandas functions the accumulator type is replaced by an array of the result type.
        The given j_input_types are only used to detect whether input types were declared; if
        so, they are recomputed from the python input types as Java ``DataType`` objects.
        """
        if self._func_type == "pandas":
            from pyflink.table.types import DataTypes
            self._accumulator_type = DataTypes.ARRAY(self._result_type)

        # one gateway lookup serves both the input type conversion and the class lookup below
        gateway = get_gateway()
        if j_input_types is not None:
            j_input_types = java_utils.to_jarray(
                gateway.jvm.DataType, [_to_java_data_type(i) for i in self._input_types])
        j_result_type = _to_java_data_type(self._result_type)
        j_accumulator_type = _to_java_data_type(self._accumulator_type)

        if self._is_table_aggregate:
            PythonAggregateFunction = gateway.jvm \
                .org.apache.flink.table.functions.python.PythonTableAggregateFunction
        else:
            PythonAggregateFunction = gateway.jvm \
                .org.apache.flink.table.functions.python.PythonAggregateFunction
        j_aggregate_function = PythonAggregateFunction(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            j_result_type,
            j_accumulator_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env())
        return j_aggregate_function

    def _create_delegate_function(self) -> UserDefinedFunction:
        """
        Wraps the plain callable of this wrapper into a
        :class:`DelegatingPandasAggregateFunction`.

        :raises TypeError: if the func_type of this wrapper is not "pandas".
        """
        # An explicit check instead of an assert: asserts are removed when Python runs with
        # -O, which would silently wrap a non-pandas callable into the pandas delegate.
        if self._func_type != 'pandas':
            raise TypeError(
                "Invalid function: a plain Python callable is only supported as aggregate "
                "function for func_type 'pandas', but func_type is '{}'".format(self._func_type))
        return DelegatingPandasAggregateFunction(self._func)
| |
| |
# TODO: support to configure the python execution environment
def _get_python_env():
    """
    Creates the Java ``PythonEnv`` object which is passed to the Java function wrappers.
    The execution type is always ``PROCESS``.
    """
    jvm = get_gateway().jvm
    process_exec_type = jvm.org.apache.flink.table.functions.python.PythonEnv.ExecType.PROCESS
    python_env_class = jvm.org.apache.flink.table.functions.python.PythonEnv
    return python_env_class(process_exec_type)
| |
| |
def _create_udf(f, input_types, result_type, func_type, deterministic, name):
    """
    Creates the wrapper for a user-defined scalar function. Having the function as first
    parameter allows to use it via ``functools.partial`` as a decorator.
    """
    return UserDefinedScalarFunctionWrapper(
        func=f,
        input_types=input_types,
        result_type=result_type,
        func_type=func_type,
        deterministic=deterministic,
        name=name)
| |
| |
def _create_udtf(f, input_types, result_types, deterministic, name):
    """
    Creates the wrapper for a user-defined table function. Having the function as first
    parameter allows to use it via ``functools.partial`` as a decorator.
    """
    return UserDefinedTableFunctionWrapper(
        func=f,
        input_types=input_types,
        result_types=result_types,
        deterministic=deterministic,
        name=name)
| |
| |
def _create_udaf(f, input_types, result_type, accumulator_type, func_type, deterministic, name):
    """
    Creates the wrapper for a user-defined aggregate function. Having the function as first
    parameter allows to use it via ``functools.partial`` as a decorator.
    """
    return UserDefinedAggregateFunctionWrapper(
        func=f,
        input_types=input_types,
        result_type=result_type,
        accumulator_type=accumulator_type,
        func_type=func_type,
        deterministic=deterministic,
        name=name)
| |
| |
def _create_udtaf(f, input_types, result_type, accumulator_type, func_type, deterministic, name):
    """
    Creates the wrapper for a user-defined table aggregate function. It is the same wrapper
    class as for aggregate functions, created with ``is_table_aggregate`` set to True.
    """
    return UserDefinedAggregateFunctionWrapper(
        func=f,
        input_types=input_types,
        result_type=result_type,
        accumulator_type=accumulator_type,
        func_type=func_type,
        deterministic=deterministic,
        name=name,
        is_table_aggregate=True)
| |
| |
def udf(f: Union[Callable, ScalarFunction, Type] = None,
        input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
        deterministic: bool = None, name: str = None, func_type: str = "general",
        udf_type: str = None) -> Union[UserDefinedScalarFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined function.

    Example:
        ::

            >>> add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())

            >>> # The input_types is optional.
            >>> @udf(result_type=DataTypes.BIGINT())
            ... def add(i, j):
            ...     return i + j

            >>> class SubtractOne(ScalarFunction):
            ...     def eval(self, i):
            ...         return i - 1
            >>> subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())

    :param f: lambda function or user-defined function.
    :param input_types: optional, the input data types.
    :param result_type: the result data type.
    :param deterministic: the determinism of the function's results. True if and only if a call to
                          this function is guaranteed to always return the same result given the
                          same parameters. (default True)
    :param name: the function name.
    :param func_type: the type of the python function, available value: general, pandas,
                      (default: general)
    :param udf_type: deprecated, use func_type instead. If given, a warning is issued and its
                     value takes precedence over func_type.
    :return: UserDefinedScalarFunctionWrapper or function.
    :raises ValueError: if the effective func_type is neither 'general' nor 'pandas'.

    .. versionadded:: 1.10.0
    """
    if udf_type:
        import warnings
        # stacklevel=2 attributes the warning to the code calling udf(), which is the place
        # that has to be changed, instead of to this line.
        warnings.warn("The param udf_type is deprecated in 1.12. Use func_type instead.",
                      stacklevel=2)
        func_type = udf_type

    if func_type not in ('general', 'pandas'):
        raise ValueError("The func_type must be one of 'general, pandas', got %s."
                         % func_type)

    # decorator
    if f is None:
        return functools.partial(_create_udf, input_types=input_types, result_type=result_type,
                                 func_type=func_type, deterministic=deterministic,
                                 name=name)
    else:
        return _create_udf(f, input_types, result_type, func_type, deterministic, name)
| |
| |
def udtf(f: Union[Callable, TableFunction, Type] = None,
         input_types: Union[List[DataType], DataType] = None,
         result_types: Union[List[DataType], DataType] = None, deterministic: bool = None,
         name: str = None) -> Union[UserDefinedTableFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined table function. It can be called directly with
    the function or be used as a decorator with arguments.

    Example:
        ::

            >>> # The input_types is optional.
            >>> @udtf(result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
            ... def range_emit(s, e):
            ...     for i in range(e):
            ...         yield s, i

            >>> class MultiEmit(TableFunction):
            ...     def eval(self, i):
            ...         return range(i)
            >>> multi_emit = udtf(MultiEmit(), DataTypes.BIGINT(), DataTypes.BIGINT())

    :param f: user-defined table function.
    :param input_types: optional, the input data types.
    :param result_types: the result data types.
    :param name: the function name.
    :param deterministic: the determinism of the function's results. True if and only if a call to
                          this function is guaranteed to always return the same result given the
                          same parameters. (default True)
    :return: UserDefinedTableFunctionWrapper or function.

    .. versionadded:: 1.11.0
    """
    if f is not None:
        return _create_udtf(f, input_types, result_types, deterministic, name)

    # Used as decorator with arguments: the wrapper is created once the decorated function
    # is passed to the returned callable.
    decorator = functools.partial(
        _create_udtf,
        input_types=input_types,
        result_types=result_types,
        deterministic=deterministic,
        name=name)
    return decorator
| |
| |
def udaf(f: Union[Callable, AggregateFunction, Type] = None,
         input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
         accumulator_type: DataType = None, deterministic: bool = None, name: str = None,
         func_type: str = "general") -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined aggregate function. It can be called directly
    with the function or be used as a decorator with arguments.

    Example:
        ::

            >>> # The input_types is optional.
            >>> @udaf(result_type=DataTypes.FLOAT(), func_type="pandas")
            ... def mean_udaf(v):
            ...     return v.mean()

    :param f: user-defined aggregate function.
    :param input_types: optional, the input data types.
    :param result_type: the result data type.
    :param accumulator_type: optional, the accumulator data type.
    :param deterministic: the determinism of the function's results. True if and only if a call to
                          this function is guaranteed to always return the same result given the
                          same parameters. (default True)
    :param name: the function name.
    :param func_type: the type of the python function, available value: general, pandas,
                      (default: general)
    :return: UserDefinedAggregateFunctionWrapper or function.
    :raises ValueError: if func_type is neither 'general' nor 'pandas'.

    .. versionadded:: 1.12.0
    """
    supported_func_types = ('general', 'pandas')
    if func_type not in supported_func_types:
        raise ValueError("The func_type must be one of 'general, pandas', got %s."
                         % func_type)

    if f is not None:
        return _create_udaf(f, input_types, result_type, accumulator_type, func_type,
                            deterministic, name)

    # Used as decorator with arguments: the wrapper is created once the decorated function
    # is passed to the returned callable.
    decorator = functools.partial(
        _create_udaf,
        input_types=input_types,
        result_type=result_type,
        accumulator_type=accumulator_type,
        func_type=func_type,
        deterministic=deterministic,
        name=name)
    return decorator
| |
| |
def udtaf(f: Union[Callable, TableAggregateFunction, Type] = None,
          input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
          accumulator_type: DataType = None, deterministic: bool = None, name: str = None,
          func_type: str = 'general') -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined table aggregate function. It can be called
    directly with the function or be used as a decorator with arguments.

    Example:
        ::

            >>> # The input_types is optional.
            >>> class Top2(TableAggregateFunction):
            ...     def emit_value(self, accumulator):
            ...         yield Row(accumulator[0])
            ...         yield Row(accumulator[1])
            ...
            ...     def create_accumulator(self):
            ...         return [None, None]
            ...
            ...     def accumulate(self, accumulator, *args):
            ...         if args[0] is not None:
            ...             if accumulator[0] is None or args[0] > accumulator[0]:
            ...                 accumulator[1] = accumulator[0]
            ...                 accumulator[0] = args[0]
            ...             elif accumulator[1] is None or args[0] > accumulator[1]:
            ...                 accumulator[1] = args[0]
            ...
            ...     def retract(self, accumulator, *args):
            ...         accumulator[0] = accumulator[0] - 1
            ...
            ...     def merge(self, accumulator, accumulators):
            ...         for other_acc in accumulators:
            ...             self.accumulate(accumulator, other_acc[0])
            ...             self.accumulate(accumulator, other_acc[1])
            ...
            ...     def get_accumulator_type(self):
            ...         return DataTypes.ARRAY(DataTypes.BIGINT())
            ...
            ...     def get_result_type(self):
            ...         return DataTypes.ROW(
            ...             [DataTypes.FIELD("a", DataTypes.BIGINT())])
            >>> top2 = udtaf(Top2())

    :param f: user-defined table aggregate function.
    :param input_types: optional, the input data types.
    :param result_type: the result data type.
    :param accumulator_type: optional, the accumulator data type.
    :param deterministic: the determinism of the function's results. True if and only if a call to
                          this function is guaranteed to always return the same result given the
                          same parameters. (default True)
    :param name: the function name.
    :param func_type: the type of the python function, available value: general
                      (default: general)
    :return: UserDefinedAggregateFunctionWrapper or function.
    :raises ValueError: if func_type is not 'general'.

    .. versionadded:: 1.13.0
    """
    if func_type != 'general':
        raise ValueError("The func_type must be 'general', got %s."
                         % func_type)

    if f is not None:
        return _create_udtaf(f, input_types, result_type, accumulator_type, func_type,
                             deterministic, name)

    # Used as decorator with arguments: the wrapper is created once the decorated function
    # is passed to the returned callable.
    decorator = functools.partial(
        _create_udtaf,
        input_types=input_types,
        result_type=result_type,
        accumulator_type=accumulator_type,
        func_type=func_type,
        deterministic=deterministic,
        name=name)
    return decorator