################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import datetime
from enum import Enum
from functools import partial
from typing import Any, Tuple, Dict, List

from pyflink.common import Row
from pyflink.datastream.time_domain import TimeDomain
from pyflink.fn_execution import flink_fn_execution_pb2, pickle
from pyflink.serializers import PickleSerializer
from pyflink.table import functions
from pyflink.table.udf import DelegationTableFunction, DelegatingScalarFunction, \
AggregateFunction, PandasAggregateFunctionWrapper

_func_num = 0
_constant_num = 0


def wrap_pandas_result(it):
    """Wraps each element of the given iterable into a one-element pandas.Series."""
    import pandas as pd
    return [pd.Series([result]) for result in it]
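
# Example: wrap_pandas_result([1, 2]) returns [pd.Series([1]), pd.Series([2])].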


def check_pandas_udf_result(f, *input_args):
    """
    Calls the Pandas UDF ``f`` and validates that its result is a pandas.Series or
    pandas.DataFrame of the same length as the input.
    """
    output = f(*input_args)
    import pandas as pd
    assert type(output) == pd.Series or type(output) == pd.DataFrame, \
        "The result type of Pandas UDF '%s' must be pandas.Series or pandas.DataFrame, got %s" \
        % (f.__name__, type(output))
    assert len(output) == len(input_args[0]), \
        "The result length '%d' of Pandas UDF '%s' is not equal to the input length '%d'" \
        % (len(output), f.__name__, len(input_args[0]))
    return output
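
# Note: this check is installed by wrapping the UDF, as done in
# extract_user_defined_function below via partial(check_pandas_udf_result, func).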


def extract_over_window_user_defined_function(user_defined_function_proto):
    """Extracts an over-window aggregate function and appends its window index."""
    window_index = user_defined_function_proto.window_index
    return (*extract_user_defined_function(user_defined_function_proto, True), window_index)


def extract_user_defined_function(user_defined_function_proto, pandas_udaf=False)\
        -> Tuple[str, Dict, List]:
    """
    Extracts a user-defined function from the proto representation of a
    :class:`UserDefinedFunction`.

    :param user_defined_function_proto: the proto representation of the Python
                                        :class:`UserDefinedFunction`
    :param pandas_udaf: whether the proto represents a Pandas UDAF
    """

    def _next_func_num():
global _func_num
_func_num = _func_num + 1
return _func_num

    def _extract_input(args) -> Tuple[str, Dict, List]:
local_variable_dict = {}
local_funcs = []
args_str = []
for arg in args:
if arg.HasField("udf"):
# for chaining Python UDF input: the input argument is a Python ScalarFunction
udf_arg, udf_variable_dict, udf_funcs = extract_user_defined_function(arg.udf)
args_str.append(udf_arg)
local_variable_dict.update(udf_variable_dict)
local_funcs.extend(udf_funcs)
elif arg.HasField("inputOffset"):
# the input argument is a column of the input row
args_str.append("value[%s]" % arg.inputOffset)
else:
# the input argument is a constant value
constant_value_name, parsed_constant_value = \
_parse_constant_value(arg.inputConstant)
args_str.append(constant_value_name)
local_variable_dict[constant_value_name] = parsed_constant_value
return ",".join(args_str), local_variable_dict, local_funcs

    variable_dict = {}
user_defined_funcs = []
user_defined_func = pickle.loads(user_defined_function_proto.payload)
if pandas_udaf:
user_defined_func = PandasAggregateFunctionWrapper(user_defined_func)
func_name = 'f%s' % _next_func_num()
    if isinstance(user_defined_func, (DelegatingScalarFunction, DelegationTableFunction)):
if user_defined_function_proto.is_pandas_udf:
variable_dict[func_name] = partial(check_pandas_udf_result, user_defined_func.func)
else:
variable_dict[func_name] = user_defined_func.func
else:
variable_dict[func_name] = user_defined_func.eval
user_defined_funcs.append(user_defined_func)
func_args, input_variable_dict, input_funcs = _extract_input(user_defined_function_proto.inputs)
variable_dict.update(input_variable_dict)
user_defined_funcs.extend(input_funcs)
return "%s(%s)" % (func_name, func_args), variable_dict, user_defined_funcs


def _parse_constant_value(constant_value) -> Tuple[str, Any]:
    """Deserializes a constant UDF argument and assigns it a unique variable name."""
j_type = constant_value[0]
serializer = PickleSerializer()
pickled_data = serializer.loads(constant_value[1:])
    # for TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE, DECIMAL, CHAR, VARCHAR,
    # NULL and BOOLEAN, the pickled data doesn't need to be converted to another
    # Python object
if j_type == 0:
parsed_constant_value = pickled_data
# the type is DATE
elif j_type == 1:
parsed_constant_value = \
datetime.date(year=1970, month=1, day=1) + datetime.timedelta(days=pickled_data)
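        # e.g. pickled_data = 18262 days -> datetime.date(2020, 1, 1)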
# the type is TIME
elif j_type == 2:
seconds, milliseconds = divmod(pickled_data, 1000)
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
parsed_constant_value = datetime.time(hours, minutes, seconds, milliseconds * 1000)
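        # e.g. pickled_data = 45296789 ms -> datetime.time(12, 34, 56, 789000)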
# the type is TIMESTAMP
elif j_type == 3:
parsed_constant_value = \
datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0) \
+ datetime.timedelta(milliseconds=pickled_data)
else:
raise Exception("Unknown type %s, should never happen" % str(j_type))

    def _next_constant_num():
        global _constant_num
        _constant_num = _constant_num + 1
        return _constant_num

    constant_value_name = 'c%s' % _next_constant_num()
return constant_value_name, parsed_constant_value


def extract_user_defined_aggregate_function(
        current_index,
        user_defined_function_proto,
        distinct_info_dict: Dict[Tuple[str, ...], Tuple[List[int], List[int]]]):
    """
    Extracts a user-defined aggregate function and an argument extractor from the proto
    representation, recording distinct aggregation information in ``distinct_info_dict``.
    """
user_defined_agg = load_aggregate_function(user_defined_function_proto.payload)
assert isinstance(user_defined_agg, AggregateFunction)
args_str = []
local_variable_dict = {}
for arg in user_defined_function_proto.inputs:
if arg.HasField("inputOffset"):
# the input argument is a column of the input row
args_str.append("value[%s]" % arg.inputOffset)
else:
# the input argument is a constant value
constant_value_name, parsed_constant_value = \
_parse_constant_value(arg.inputConstant)
for key, value in local_variable_dict.items():
if value == parsed_constant_value:
constant_value_name = key
break
if constant_value_name not in local_variable_dict:
local_variable_dict[constant_value_name] = parsed_constant_value
args_str.append(constant_value_name)
if user_defined_function_proto.distinct:
if tuple(args_str) in distinct_info_dict:
distinct_info_dict[tuple(args_str)][0].append(current_index)
distinct_info_dict[tuple(args_str)][1].append(user_defined_function_proto.filter_arg)
distinct_index = distinct_info_dict[tuple(args_str)][0][0]
else:
distinct_info_dict[tuple(args_str)] = \
([current_index], [user_defined_function_proto.filter_arg])
distinct_index = current_index
else:
distinct_index = -1
return user_defined_agg, \
eval("lambda value : (%s,)" % ",".join(args_str), local_variable_dict) \
if args_str else lambda v: tuple(), \
user_defined_function_proto.filter_arg, \
distinct_index
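
# Illustration: for args_str == ["value[0]", "c1"], the eval'd extractor behaves like
# lambda value: (value[0], c1), i.e. it builds the argument tuple passed to the UDAF.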


def is_built_in_function(payload):
    # The payload is either pickled bytes or the class name of a built-in function.
    # If it represents a built-in function, it starts with 0x00.
    # If it is pickled bytes, it starts with 0x80 (the pickle protocol marker).
    return payload[0] == 0
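
# Illustration (class name hypothetical): a payload of b"\x00" + b"MaxAggFunction"
# denotes a built-in function, while a pickled payload starts with b"\x80".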


def load_aggregate_function(payload):
    """Loads a built-in aggregate function by class name, or unpickles a user-defined one."""
if is_built_in_function(payload):
built_in_function_class_name = payload[1:].decode("utf-8")
cls = getattr(functions, built_in_function_class_name)
return cls()
else:
return pickle.loads(payload)


def extract_data_stream_stateless_function(udf_proto):
    """
    Extracts a user-defined function from the proto representation of a
    :class:`Function`.

    :param udf_proto: the proto representation of the Python :class:`Function`
    """
func_type = udf_proto.function_type
UserDefinedDataStreamFunction = flink_fn_execution_pb2.UserDefinedDataStreamFunction
func = None
user_defined_func = pickle.loads(udf_proto.payload)
    if func_type == UserDefinedDataStreamFunction.MAP:
        func = user_defined_func.map
    elif func_type == UserDefinedDataStreamFunction.FLAT_MAP:
        func = user_defined_func.flat_map
    elif func_type == UserDefinedDataStreamFunction.REDUCE:
        reduce_func = user_defined_func.reduce

        def wrapped_func(value):
            return reduce_func(value[0], value[1])

        func = wrapped_func
    elif func_type == UserDefinedDataStreamFunction.CO_MAP:
        co_map_func = user_defined_func

        def wrapped_func(value):
            # value is in the format [INPUT_FLAG, LEFT_VALUE, RIGHT_VALUE];
            # INPUT_FLAG is True for the left stream and False for the right stream
            return Row(CoMapFunctionOutputFlag.LEFT.value, co_map_func.map1(value[1])) \
                if value[0] else Row(CoMapFunctionOutputFlag.RIGHT.value,
                                     co_map_func.map2(value[2]))

        func = wrapped_func
    elif func_type == UserDefinedDataStreamFunction.CO_FLAT_MAP:
        co_flat_map_func = user_defined_func

        def wrapped_func(value):
            # value is in the format [INPUT_FLAG, LEFT_VALUE, RIGHT_VALUE]
            if value[0]:
                result = co_flat_map_func.flat_map1(value[1])
                if result:
                    for result_val in result:
                        yield Row(CoFlatMapFunctionOutputFlag.LEFT.value, result_val)
                yield Row(CoFlatMapFunctionOutputFlag.LEFT_END.value, None)
            else:
                result = co_flat_map_func.flat_map2(value[2])
                if result:
                    for result_val in result:
                        yield Row(CoFlatMapFunctionOutputFlag.RIGHT.value, result_val)
                yield Row(CoFlatMapFunctionOutputFlag.RIGHT_END.value, None)

        func = wrapped_func
    elif func_type == UserDefinedDataStreamFunction.TIMESTAMP_ASSIGNER:
        extract_timestamp = user_defined_func.extract_timestamp

        def wrapped_func(value):
            # value is in the format [PREVIOUS_TIMESTAMP, REAL_DATA]
            pre_timestamp = value[0]
            real_data = value[1]
            return extract_timestamp(real_data, pre_timestamp)

        func = wrapped_func

    return func, user_defined_func
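
# Illustration: the REDUCE wrapper above assumes a two-element input, so calling
# wrapped_func((acc, new)) is equivalent to reduce_func(acc, new).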


def extract_process_function(user_defined_function_proto, ctx):
    process_function = pickle.loads(user_defined_function_proto.payload)
    process_element = process_function.process_element

    def wrapped_process_function(value):
        # value is in the format [CURRENT_TIMESTAMP, CURRENT_WATERMARK, NORMAL_DATA]
        ctx.set_timestamp(value[0])
        ctx.timer_service().set_current_watermark(value[1])
        output_result = process_element(value[2], ctx)
        return output_result

    return wrapped_process_function, process_function


def extract_keyed_process_function(user_defined_function_proto, ctx, on_timer_ctx,
                                   collector, keyed_state_backend):
    process_function = pickle.loads(user_defined_function_proto.payload)
    process_element = process_function.process_element
    on_timer = process_function.on_timer

    def wrapped_keyed_process_function(value):
        if value[0] is not None:
            # it is timer data, in the format:
            # [TIMER_FLAG, TIMESTAMP_OF_TIMER, CURRENT_WATERMARK, CURRENT_KEY_OF_TIMER, None]
on_timer_ctx.set_timestamp(value[1])
on_timer_ctx.timer_service().set_current_watermark(value[2])
current_key = value[3]
on_timer_ctx.set_current_key(current_key)
keyed_state_backend.set_current_key(current_key)
if value[0] == KeyedProcessFunctionInputFlag.EVENT_TIME_TIMER.value:
on_timer_ctx.set_time_domain(TimeDomain.EVENT_TIME)
elif value[0] == KeyedProcessFunctionInputFlag.PROC_TIME_TIMER.value:
on_timer_ctx.set_time_domain(TimeDomain.PROCESSING_TIME)
else:
raise TypeError("TimeCharacteristic[%s] is not supported." % str(value[0]))
output_result = on_timer(value[1], on_timer_ctx)
        else:
            # it is normal data, in the format:
            # [TIMER_FLAG, CURRENT_TIMESTAMP, CURRENT_WATERMARK, None, NORMAL_DATA]
            # where NORMAL_DATA is (CURRENT_KEY, DATA)
ctx.set_timestamp(value[1])
ctx.timer_service().set_current_watermark(value[2])
current_key = value[4][0]
ctx.set_current_key(current_key)
keyed_state_backend.set_current_key(Row(current_key))
output_result = process_element(value[4][1], ctx)
if output_result:
for result in output_result:
yield Row(None, None, None, result)
        for result in collector.buf:
            # result_row: [TIMER_FLAG, TIMER_TYPE, TIMER_KEY, RESULT_DATA]
            # 0: proc time timer data
            # 1: event time timer data
            # 2: normal data
            yield Row(result[0], result[1], result[2], None)
        collector.clear()

    return wrapped_keyed_process_function, process_function
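
# Illustration of the rows emitted above: normal results are yielded as
# Row(None, None, None, DATA), while buffered timer operations are yielded as
# Row(TIMER_FLAG, TIMER_TYPE, TIMER_KEY, None).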
"""
All these Enum Classes MUST be in sync with
org.apache.flink.streaming.api.utils.PythonOperatorUtils if there are any changes.
"""


class KeyedProcessFunctionInputFlag(Enum):
EVENT_TIME_TIMER = 0
PROC_TIME_TIMER = 1
NORMAL_DATA = 2


class KeyedProcessFunctionOutputFlag(Enum):
REGISTER_EVENT_TIMER = 0
REGISTER_PROC_TIMER = 1
NORMAL_DATA = 2
DEL_EVENT_TIMER = 3
DEL_PROC_TIMER = 4


class CoFlatMapFunctionOutputFlag(Enum):
LEFT = 0
RIGHT = 1
LEFT_END = 2
RIGHT_END = 3


class CoMapFunctionOutputFlag(Enum):
LEFT = 0
RIGHT = 1