| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| from apache_beam.portability import common_urns |
| from apache_beam.portability.api import beam_runner_api_pb2 |
| from apache_beam.runners.worker import bundle_processor, operation_specs |
| from apache_beam.utils import proto_utils |
| |
| from pyflink.fn_execution import flink_fn_execution_pb2 |
| from pyflink.fn_execution.coders import from_proto, from_type_info_proto, TimeWindowCoder, \ |
| CountWindowCoder, FlattenRowCoder |
| from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend |
| |
| import pyflink.fn_execution.datastream.operations as datastream_operations |
| import pyflink.fn_execution.table.operations as table_operations |
| |
| try: |
| import pyflink.fn_execution.beam.beam_operations_fast as beam_operations |
| except ImportError: |
| import pyflink.fn_execution.beam.beam_operations_slow as beam_operations |
| |
| |
| # ----------------- UDF -------------------- |
| |
| |
@bundle_processor.BeamTransformFactory.register_urn(
    table_operations.SCALAR_FUNCTION_URN, flink_fn_execution_pb2.UserDefinedFunctions)
def create_scalar_function(factory, transform_id, transform_proto, parameter, consumers):
    """Build the Beam operation for a Table API scalar function transform.

    The parsed ``UserDefinedFunctions`` payload (``parameter``) is run by a stateless
    Beam operation that delegates to ``ScalarFunctionOperation``. ``transform_id`` is
    not used.
    """
    return _create_user_defined_function_operation(
        factory=factory,
        transform_proto=transform_proto,
        consumers=consumers,
        udfs_proto=parameter,
        beam_operation_cls=beam_operations.StatelessFunctionOperation,
        internal_operation_cls=table_operations.ScalarFunctionOperation)
| |
| |
| # ----------------- UDTF -------------------- |
| |
| |
@bundle_processor.BeamTransformFactory.register_urn(
    table_operations.TABLE_FUNCTION_URN, flink_fn_execution_pb2.UserDefinedFunctions)
def create_table_function(factory, transform_id, transform_proto, parameter, consumers):
    """Build the Beam operation for a Table API table function (UDTF) transform.

    The parsed ``UserDefinedFunctions`` payload is executed by a stateless Beam
    operation delegating to ``TableFunctionOperation``. ``transform_id`` is not used.
    """
    beam_cls = beam_operations.StatelessFunctionOperation
    flink_cls = table_operations.TableFunctionOperation
    return _create_user_defined_function_operation(
        factory, transform_proto, consumers, parameter, beam_cls, flink_cls)
| |
| |
| # ----------------- UDAF -------------------- |
| |
| |
@bundle_processor.BeamTransformFactory.register_urn(
    table_operations.STREAM_GROUP_AGGREGATE_URN,
    flink_fn_execution_pb2.UserDefinedAggregateFunctions)
def create_aggregate_function(factory, transform_id, transform_proto, parameter, consumers):
    """Build the Beam operation for a streaming group aggregation (UDAF) transform.

    The parsed ``UserDefinedAggregateFunctions`` payload is executed by a stateful
    Beam operation delegating to ``StreamGroupAggregateOperation``.
    ``transform_id`` is not used.
    """
    return _create_user_defined_function_operation(
        factory=factory,
        transform_proto=transform_proto,
        consumers=consumers,
        udfs_proto=parameter,
        beam_operation_cls=beam_operations.StatefulFunctionOperation,
        internal_operation_cls=table_operations.StreamGroupAggregateOperation)
| |
| |
@bundle_processor.BeamTransformFactory.register_urn(
    table_operations.STREAM_GROUP_TABLE_AGGREGATE_URN,
    flink_fn_execution_pb2.UserDefinedAggregateFunctions)
def create_table_aggregate_function(factory, transform_id, transform_proto, parameter, consumers):
    """Build the Beam operation for a streaming group table-aggregation transform.

    The parsed ``UserDefinedAggregateFunctions`` payload is executed by a stateful
    Beam operation delegating to ``StreamGroupTableAggregateOperation``.
    ``transform_id`` is not used.
    """
    beam_cls = beam_operations.StatefulFunctionOperation
    flink_cls = table_operations.StreamGroupTableAggregateOperation
    return _create_user_defined_function_operation(
        factory, transform_proto, consumers, parameter, beam_cls, flink_cls)
| |
| |
@bundle_processor.BeamTransformFactory.register_urn(
    table_operations.STREAM_GROUP_WINDOW_AGGREGATE_URN,
    flink_fn_execution_pb2.UserDefinedAggregateFunctions)
def create_group_window_aggregate_function(factory, transform_id, transform_proto, parameter,
                                           consumers):
    """Build the Beam operation for a streaming group-window aggregation transform.

    The parsed ``UserDefinedAggregateFunctions`` payload is executed by a stateful
    Beam operation delegating to ``StreamGroupWindowAggregateOperation``.
    ``transform_id`` is not used.
    """
    return _create_user_defined_function_operation(
        factory=factory,
        transform_proto=transform_proto,
        consumers=consumers,
        udfs_proto=parameter,
        beam_operation_cls=beam_operations.StatefulFunctionOperation,
        internal_operation_cls=table_operations.StreamGroupWindowAggregateOperation)
| |
| |
| # ----------------- Pandas UDAF -------------------- |
| |
| |
@bundle_processor.BeamTransformFactory.register_urn(
    table_operations.PANDAS_AGGREGATE_FUNCTION_URN, flink_fn_execution_pb2.UserDefinedFunctions)
def create_pandas_aggregate_function(factory, transform_id, transform_proto, parameter, consumers):
    """Build the Beam operation for a Pandas aggregate function transform.

    The parsed ``UserDefinedFunctions`` payload is executed by a stateless Beam
    operation delegating to ``PandasAggregateFunctionOperation``.
    ``transform_id`` is not used.
    """
    beam_cls = beam_operations.StatelessFunctionOperation
    flink_cls = table_operations.PandasAggregateFunctionOperation
    return _create_user_defined_function_operation(
        factory, transform_proto, consumers, parameter, beam_cls, flink_cls)
| |
| |
@bundle_processor.BeamTransformFactory.register_urn(
    table_operations.PANDAS_BATCH_OVER_WINDOW_AGGREGATE_FUNCTION_URN,
    flink_fn_execution_pb2.UserDefinedFunctions)
def create_pandas_over_window_aggregate_function(
        factory, transform_id, transform_proto, parameter, consumers):
    """Build the Beam operation for a batch Pandas over-window aggregation transform.

    The parsed ``UserDefinedFunctions`` payload is executed by a stateless Beam
    operation delegating to ``PandasBatchOverWindowAggregateFunctionOperation``.
    ``transform_id`` is not used.
    """
    return _create_user_defined_function_operation(
        factory=factory,
        transform_proto=transform_proto,
        consumers=consumers,
        udfs_proto=parameter,
        beam_operation_cls=beam_operations.StatelessFunctionOperation,
        internal_operation_cls=table_operations.PandasBatchOverWindowAggregateFunctionOperation)
| |
| |
| # ----------------- DataStream -------------------- |
| |
| |
@bundle_processor.BeamTransformFactory.register_urn(
    common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload)
def create_data_stream_keyed_process_function(factory, transform_id, transform_proto, parameter,
                                              consumers):
    """Build the Beam operation for a DataStream API function carried in a ParDo transform.

    This factory is registered for Beam's generic ParDo URN. The ``do_fn`` payload
    is parsed as a ``UserDefinedDataStreamFunction``. The ``do_fn`` URN selects the
    operation pair:

    * ``DATA_STREAM_STATELESS_FUNCTION_URN`` -> stateless Beam operation wrapping
      ``datastream_operations.StatelessOperation``;
    * any other URN -> stateful Beam operation wrapping
      ``datastream_operations.StatefulOperation``.

    ``transform_id`` is not used.
    """
    do_fn = parameter.do_fn
    do_fn_urn = do_fn.urn
    udf_payload = proto_utils.parse_Bytes(
        do_fn.payload, flink_fn_execution_pb2.UserDefinedDataStreamFunction)

    if do_fn_urn == datastream_operations.DATA_STREAM_STATELESS_FUNCTION_URN:
        beam_cls = beam_operations.StatelessFunctionOperation
        flink_cls = datastream_operations.StatelessOperation
    else:
        # Every non-stateless URN falls through to the stateful implementation.
        beam_cls = beam_operations.StatefulFunctionOperation
        flink_cls = datastream_operations.StatefulOperation

    return _create_user_defined_function_operation(
        factory, transform_proto, consumers, udf_payload, beam_cls, flink_cls)
| |
| |
| # ----------------- Utilities -------------------- |
| |
| |
# Cache sizes (state cache, map-state read cache, map-state write cache) passed to the
# RemoteKeyedStateBackend of DataStream stateful operations; this code path does not
# read them from the function payload.
_DATA_STREAM_DEFAULT_CACHE_SIZE = 1000


def _create_user_defined_function_operation(factory, transform_proto, consumers, udfs_proto,
                                            beam_operation_cls, internal_operation_cls):
    """Instantiate the Beam operation that runs a Flink user-defined function.

    :param factory: Beam transform factory; supplies output coders, the state handler,
        the counter factory and the state sampler.
    :param transform_proto: the Beam PTransform proto being translated; its outputs and
        unique name are used.
    :param consumers: downstream consumers passed through to the operation.
    :param udfs_proto: parsed Flink function payload, stored as the spec's
        ``serialized_fn``.
    :param beam_operation_cls: Beam-side operation class to instantiate.
    :param internal_operation_cls: Flink-side operation class handed to the Beam
        operation.
    :return: an instance of ``beam_operation_cls``. For keyed operations a
        ``RemoteKeyedStateBackend`` is passed as the final constructor argument;
        otherwise that argument is omitted.
    """
    output_tags = list(transform_proto.outputs.keys())
    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=output_tags,
        input=None,
        side_inputs=None,
        output_coders=[output_coders[tag] for tag in output_tags])

    keyed_state_backend = _create_keyed_state_backend(factory, spec, internal_operation_cls)

    operation_args = [
        transform_proto.unique_name,
        spec,
        factory.counter_factory,
        factory.state_sampler,
        consumers,
        internal_operation_cls,
    ]
    # Stateless Beam operations are constructed without a state backend argument.
    if keyed_state_backend is not None:
        operation_args.append(keyed_state_backend)
    return beam_operation_cls(*operation_args)


def _create_keyed_state_backend(factory, spec, internal_operation_cls):
    """Return a RemoteKeyedStateBackend for keyed operations, or None when none is needed."""
    serialized_fn = spec.serialized_fn

    if hasattr(serialized_fn, "key_type"):
        # Table API keyed operation: key coder from the key row schema, cache sizes and
        # optional window coder from the payload.
        key_row_coder = FlattenRowCoder(
            [from_proto(f.type) for f in serialized_fn.key_type.row_schema.fields])
        if not serialized_fn.HasField('group_window'):
            window_coder = None
        elif serialized_fn.group_window.is_time_window:
            window_coder = TimeWindowCoder()
        else:
            window_coder = CountWindowCoder()
        return RemoteKeyedStateBackend(
            factory.state_handler,
            key_row_coder,
            window_coder,
            serialized_fn.state_cache_size,
            serialized_fn.map_state_read_cache_size,
            serialized_fn.map_state_write_cache_size)

    if internal_operation_cls is datastream_operations.StatefulOperation:
        # DataStream keyed operation: key coder from the key TypeInformation proto,
        # no window coder, fixed cache sizes.
        key_coder = from_type_info_proto(serialized_fn.key_type_info)
        return RemoteKeyedStateBackend(
            factory.state_handler,
            key_coder,
            None,
            _DATA_STREAM_DEFAULT_CACHE_SIZE,
            _DATA_STREAM_DEFAULT_CACHE_SIZE,
            _DATA_STREAM_DEFAULT_CACHE_SIZE)

    return None