blob: 0d70cb1153d10d4e88d94f0dcf427d46b29f1653 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from abc import abstractmethod
from apache_beam.runners.worker.operations import Operation
from apache_beam.utils.windowed_value import WindowedValue
class FunctionOperation(Operation):
    """
    Base class of function operation that will execute StatelessFunction or StatefulFunction for
    each input element.

    Subclasses must override :meth:`generate_operation`, which is called from the
    constructor to build the wrapped operation out of ``self.operation_cls``.
    """

    def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls):
        """
        Create the wrapped operation and open it.

        :param name: forwarded unchanged to the base ``Operation`` constructor
        :param spec: forwarded unchanged to the base constructor; subclasses read it back
                     as ``self.spec`` when creating the wrapped operation
        :param counter_factory: forwarded unchanged to the base constructor
        :param sampler: forwarded unchanged to the base constructor
        :param consumers: mapping of output tags to consumer lists; only the first consumer
                          registered under the ``'output'`` key is used
        :param operation_cls: callable used by :meth:`generate_operation` to create the
                              wrapped operation
        """
        super().__init__(name, spec, counter_factory, sampler)
        self.consumer = consumers['output'][0]
        # Coder used to serialize every result written to the consumer's output stream.
        self._value_coder_impl = self.consumer.windowed_coder.wrapped_value_coder.get_impl()
        self.operation_cls = operation_cls
        self.operation = self.generate_operation()
        # Cache the callable once so that process() does not look it up per element.
        self.func = self.operation.func
        self.operation.open()

    def setup(self):
        """Delegate to the base class setup."""
        super().setup()

    def start(self):
        """Delegate to the base class start inside the start state scope."""
        with self.scoped_start_state:
            super().start()

    def finish(self):
        """Finish the base operation first, then the wrapped operation."""
        with self.scoped_finish_state:
            super().finish()
            self.operation.finish()

    def needs_finalization(self):
        """This operation never requests finalization."""
        return False

    def reset(self):
        """Delegate to the base class reset."""
        super().reset()

    def teardown(self):
        """Close the wrapped operation; this runs under the finish state scope."""
        with self.scoped_finish_state:
            self.operation.close()

    def progress_metrics(self):
        """
        Return the base class metrics with the output element counts replaced by a single
        entry that holds the element counter of the first receiver.
        """
        metrics = super().progress_metrics()
        metrics.processed_elements.measured.output_element_counts.clear()
        # The only output is reported under the key str(None), i.e. 'None'.
        tag = None
        receiver = self.receivers[0]
        metrics.processed_elements.measured.output_element_counts[
            str(tag)] = receiver.opcounter.element_counter.value()
        return metrics

    def process(self, o: WindowedValue):
        """
        Apply ``self.func`` to every element of ``o.value`` and encode each result to the
        output stream of the consumer.

        :param o: the windowed value whose ``value`` attribute is an iterable of input elements
        """
        with self.scoped_process_state:
            output_stream = self.consumer.output_stream
            for value in o.value:
                self._value_coder_impl.encode_to_stream(self.func(value), output_stream, True)
                # Give the stream a chance to flush after every encoded result.
                output_stream.maybe_flush()

    def monitoring_infos(self, transform_id, tag_to_pcollection_id):
        """
        Only pass user metric to Java

        :param transform_id: forwarded to ``user_monitoring_infos`` of the base class
        :param tag_to_pcollection_id: useless for user metric
        """
        return super().user_monitoring_infos(transform_id)

    @abstractmethod
    def generate_operation(self):
        """
        Create the wrapped operation; it must expose ``func``, ``open``, ``finish`` and
        ``close`` because this class uses all of them.

        :raises NotImplementedError: when the subclass does not override this method
        """
        # NOTE(review): @abstractmethod is only enforced for classes whose metaclass derives
        # from ABCMeta, which is not visible here. Fail loudly instead of returning None so
        # that the constructor does not die later with an AttributeError on NoneType.
        raise NotImplementedError(
            "%s must implement generate_operation()" % type(self).__name__)
class StatelessFunctionOperation(FunctionOperation):
    """
    Function operation whose wrapped operation is created from the spec alone.
    """

    def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls):
        """Forward every argument unchanged to ``FunctionOperation``."""
        super().__init__(
            name,
            spec,
            counter_factory,
            sampler,
            consumers,
            operation_cls)

    def generate_operation(self):
        """Instantiate ``self.operation_cls`` with ``self.spec`` as its only argument."""
        operation_factory = self.operation_cls
        return operation_factory(self.spec)
class StatefulFunctionOperation(FunctionOperation):
    """
    Function operation whose wrapped operation is created from the spec together with a
    keyed state backend.
    """

    def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cls,
                 keyed_state_backend):
        """
        Remember the keyed state backend, then run the ``FunctionOperation`` constructor
        with the remaining arguments.
        """
        # Must be assigned before the base constructor runs: the base constructor calls
        # generate_operation(), which reads this attribute.
        self.keyed_state_backend = keyed_state_backend
        super().__init__(
            name,
            spec,
            counter_factory,
            sampler,
            consumers,
            operation_cls)

    def generate_operation(self):
        """Instantiate ``self.operation_cls`` with the spec and the keyed state backend."""
        operation_factory = self.operation_cls
        return operation_factory(self.spec, self.keyed_state_backend)