blob: f779cfd965c6b3aeabc0f652897788b7531847aa [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import datetime
from collections.abc import Generator
from functools import partial
from typing import Any, Tuple, Dict, List
from pyflink.common import Row
from pyflink.fn_execution import pickle
from pyflink.serializers import PickleSerializer
from pyflink.table import functions
from pyflink.table.udf import DelegationTableFunction, DelegatingScalarFunction, \
ImperativeAggregateFunction, PandasAggregateFunctionWrapper
_func_num = 0
_constant_num = 0
def normalize_table_function_result(it):
def normalize_one_row(value):
if isinstance(value, tuple):
# We assume that tuple is a single line output
return [*value]
elif isinstance(value, Row):
# We assume that tuple is a single line output
return value._values
else:
# single field value
return [value]
if it is None:
return []
if isinstance(it, (list, range, Generator)):
def func():
for item in it:
yield normalize_one_row(item)
return func()
else:
return [normalize_one_row(it)]
def normalize_pandas_result(it):
import pandas as pd
arrays = []
for result in it:
if isinstance(result, (Row, Tuple)):
arrays.append(pd.concat([pd.Series([item]) for item in result], axis=1))
else:
arrays.append(pd.Series([result]))
return arrays
def wrap_input_series_as_dataframe(*args):
import pandas as pd
return pd.concat(args, axis=1)
def check_pandas_udf_result(f, *input_args):
output = f(*input_args)
import pandas as pd
assert type(output) == pd.Series or type(output) == pd.DataFrame, \
"The result type of Pandas UDF '%s' must be pandas.Series or pandas.DataFrame, got %s" \
% (f.__name__, type(output))
assert len(output) == len(input_args[0]), \
"The result length '%d' of Pandas UDF '%s' is not equal to the input length '%d'" \
% (len(output), f.__name__, len(input_args[0]))
return output
def extract_over_window_user_defined_function(user_defined_function_proto):
window_index = user_defined_function_proto.window_index
return (*extract_user_defined_function(user_defined_function_proto, True), window_index)
def extract_user_defined_function(user_defined_function_proto, pandas_udaf=False)\
-> Tuple[str, Dict, List]:
"""
Extracts user-defined-function from the proto representation of a
:class:`UserDefinedFunction`.
:param user_defined_function_proto: the proto representation of the Python
:param pandas_udaf: whether the user_defined_function_proto is pandas udaf
:class:`UserDefinedFunction`
"""
def _next_func_num():
global _func_num
_func_num = _func_num + 1
return _func_num
def _extract_input(args) -> Tuple[str, Dict, List]:
local_variable_dict = {}
local_funcs = []
args_str = []
for arg in args:
if arg.HasField("udf"):
# for chaining Python UDF input: the input argument is a Python ScalarFunction
udf_arg, udf_variable_dict, udf_funcs = extract_user_defined_function(arg.udf)
args_str.append(udf_arg)
local_variable_dict.update(udf_variable_dict)
local_funcs.extend(udf_funcs)
elif arg.HasField("inputOffset"):
# the input argument is a column of the input row
args_str.append("value[%s]" % arg.inputOffset)
else:
# the input argument is a constant value
constant_value_name, parsed_constant_value = \
_parse_constant_value(arg.inputConstant)
args_str.append(constant_value_name)
local_variable_dict[constant_value_name] = parsed_constant_value
return ",".join(args_str), local_variable_dict, local_funcs
variable_dict = {}
user_defined_funcs = []
user_defined_func = pickle.loads(user_defined_function_proto.payload)
if pandas_udaf:
user_defined_func = PandasAggregateFunctionWrapper(user_defined_func)
func_name = 'f%s' % _next_func_num()
if isinstance(user_defined_func, DelegatingScalarFunction) \
or isinstance(user_defined_func, DelegationTableFunction):
if user_defined_function_proto.is_pandas_udf:
variable_dict[func_name] = partial(check_pandas_udf_result, user_defined_func.func)
else:
variable_dict[func_name] = user_defined_func.func
else:
variable_dict[func_name] = user_defined_func.eval
user_defined_funcs.append(user_defined_func)
func_args, input_variable_dict, input_funcs = _extract_input(user_defined_function_proto.inputs)
variable_dict.update(input_variable_dict)
user_defined_funcs.extend(input_funcs)
if user_defined_function_proto.takes_row_as_input:
if input_variable_dict:
# for constant or other udfs as input arguments.
func_str = "%s(%s)" % (func_name, func_args)
elif user_defined_function_proto.is_pandas_udf or pandas_udaf:
# for pandas udf/udaf, the input data structure is a List of Pandas.Series
# we need to merge these Pandas.Series into a Pandas.DataFrame
variable_dict['wrap_input_series_as_dataframe'] = wrap_input_series_as_dataframe
func_str = "%s(wrap_input_series_as_dataframe(%s))" % (func_name, func_args)
else:
# directly use `value` as input argument
# e.g.
# lambda value: Row(value[0], value[1])
# can be optimized to
# lambda value: value
func_str = "%s(value)" % func_name
else:
func_str = "%s(%s)" % (func_name, func_args)
return func_str, variable_dict, user_defined_funcs
def _parse_constant_value(constant_value) -> Tuple[str, Any]:
j_type = constant_value[0]
serializer = PickleSerializer()
pickled_data = serializer.loads(constant_value[1:])
# the type set contains
# TINYINT,SMALLINT,INTEGER,BIGINT,FLOAT,DOUBLE,DECIMAL,CHAR,VARCHAR,NULL,BOOLEAN
# the pickled_data doesn't need to transfer to anther python object
if j_type == 0:
parsed_constant_value = pickled_data
# the type is DATE
elif j_type == 1:
parsed_constant_value = \
datetime.date(year=1970, month=1, day=1) + datetime.timedelta(days=pickled_data)
# the type is TIME
elif j_type == 2:
seconds, milliseconds = divmod(pickled_data, 1000)
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
parsed_constant_value = datetime.time(hours, minutes, seconds, milliseconds * 1000)
# the type is TIMESTAMP
elif j_type == 3:
parsed_constant_value = \
datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0) \
+ datetime.timedelta(milliseconds=pickled_data)
else:
raise Exception("Unknown type %s, should never happen" % str(j_type))
def _next_constant_num():
global _constant_num
_constant_num = _constant_num + 1
return _constant_num
constant_value_name = 'c%s' % _next_constant_num()
return constant_value_name, parsed_constant_value
def extract_user_defined_aggregate_function(
current_index,
user_defined_function_proto,
distinct_info_dict: Dict[Tuple[List[str]], Tuple[List[int], List[int]]]):
user_defined_agg = load_aggregate_function(user_defined_function_proto.payload)
assert isinstance(user_defined_agg, ImperativeAggregateFunction)
args_str = []
local_variable_dict = {}
for arg in user_defined_function_proto.inputs:
if arg.HasField("inputOffset"):
# the input argument is a column of the input row
args_str.append("value[%s]" % arg.inputOffset)
else:
# the input argument is a constant value
constant_value_name, parsed_constant_value = \
_parse_constant_value(arg.inputConstant)
for key, value in local_variable_dict.items():
if value == parsed_constant_value:
constant_value_name = key
break
if constant_value_name not in local_variable_dict:
local_variable_dict[constant_value_name] = parsed_constant_value
args_str.append(constant_value_name)
if user_defined_function_proto.distinct:
if tuple(args_str) in distinct_info_dict:
distinct_info_dict[tuple(args_str)][0].append(current_index)
distinct_info_dict[tuple(args_str)][1].append(user_defined_function_proto.filter_arg)
distinct_index = distinct_info_dict[tuple(args_str)][0][0]
else:
distinct_info_dict[tuple(args_str)] = \
([current_index], [user_defined_function_proto.filter_arg])
distinct_index = current_index
else:
distinct_index = -1
if user_defined_function_proto.takes_row_as_input and not local_variable_dict:
# directly use `value` as input argument
# e.g.
# lambda value: Row(value[0], value[1])
# can be optimized to
# lambda value: value
func_str = "lambda value : [value]"
else:
func_str = "lambda value : (%s,)" % ",".join(args_str)
return user_defined_agg, \
eval(func_str, local_variable_dict) \
if args_str else lambda v: tuple(), \
user_defined_function_proto.filter_arg, \
distinct_index
def is_built_in_function(payload):
# The payload may be a pickled bytes or the class name of the built-in functions.
# If it represents a built-in function, it will start with 0x00.
# If it is a pickled bytes, it will start with 0x80.
return payload[0] == 0
def load_aggregate_function(payload):
if is_built_in_function(payload):
built_in_function_class_name = payload[1:].decode("utf-8")
cls = getattr(functions, built_in_function_class_name)
return cls()
else:
return pickle.loads(payload)