blob: 57393d8b97c53ab328531f2a208956af76f81b3f [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Support for Dataflow triggers.
Triggers control when in processing time windows get emitted.
"""
# pytype: skip-file
from __future__ import absolute_import
import collections
import copy
import logging
import numbers
from abc import ABCMeta
from abc import abstractmethod
from builtins import object
from future.moves.itertools import zip_longest
from future.utils import iteritems
from future.utils import with_metaclass
from apache_beam.coders import coder_impl
from apache_beam.coders import observable
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import combiners
from apache_beam.transforms import core
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import WindowedValue
from apache_beam.transforms.window import WindowFn
from apache_beam.utils import windowed_value
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.timestamp import TIME_GRANULARITY
# AfterCount is experimental. No backwards compatibility guarantees.
__all__ = [
'AccumulationMode',
'TriggerFn',
'DefaultTrigger',
'AfterWatermark',
'AfterProcessingTime',
'AfterCount',
'Repeatedly',
'AfterAny',
'AfterAll',
'AfterEach',
'OrFinally',
]
_LOGGER = logging.getLogger(__name__)
class AccumulationMode(object):
"""Controls what to do with data when a trigger fires multiple times."""
DISCARDING = beam_runner_api_pb2.AccumulationMode.DISCARDING
ACCUMULATING = beam_runner_api_pb2.AccumulationMode.ACCUMULATING
# TODO(robertwb): Provide retractions of previous outputs.
# RETRACTING = 3
class _StateTag(with_metaclass(ABCMeta, object)): # type: ignore[misc]
"""An identifier used to store and retrieve typed, combinable state.
The given tag must be unique for this step."""
def __init__(self, tag):
self.tag = tag
class _ValueStateTag(_StateTag):
"""StateTag pointing to an element."""
def __repr__(self):
return 'ValueStateTag(%s)' % (self.tag)
def with_prefix(self, prefix):
return _ValueStateTag(prefix + self.tag)
class _SetStateTag(_StateTag):
"""StateTag pointing to an element."""
def __repr__(self):
return 'SetStateTag({tag})'.format(tag=self.tag)
def with_prefix(self, prefix):
return _SetStateTag(prefix + self.tag)
class _CombiningValueStateTag(_StateTag):
"""StateTag pointing to an element, accumulated with a combiner.
The given tag must be unique for this step. The given CombineFn will be
applied (possibly incrementally and eagerly) when adding elements."""
# TODO(robertwb): Also store the coder (perhaps extracted from the combine_fn)
def __init__(self, tag, combine_fn):
super(_CombiningValueStateTag, self).__init__(tag)
if not combine_fn:
raise ValueError('combine_fn must be specified.')
if not isinstance(combine_fn, core.CombineFn):
combine_fn = core.CombineFn.from_callable(combine_fn)
self.combine_fn = combine_fn
def __repr__(self):
return 'CombiningValueStateTag(%s, %s)' % (self.tag, self.combine_fn)
def with_prefix(self, prefix):
return _CombiningValueStateTag(prefix + self.tag, self.combine_fn)
def without_extraction(self):
class NoExtractionCombineFn(core.CombineFn):
create_accumulator = self.combine_fn.create_accumulator
add_input = self.combine_fn.add_input
merge_accumulators = self.combine_fn.merge_accumulators
compact = self.combine_fn.compact
extract_output = staticmethod(lambda x: x)
return _CombiningValueStateTag(self.tag, NoExtractionCombineFn())
class _ListStateTag(_StateTag):
"""StateTag pointing to a list of elements."""
def __repr__(self):
return 'ListStateTag(%s)' % self.tag
def with_prefix(self, prefix):
return _ListStateTag(prefix + self.tag)
class _WatermarkHoldStateTag(_StateTag):
def __init__(self, tag, timestamp_combiner_impl):
super(_WatermarkHoldStateTag, self).__init__(tag)
self.timestamp_combiner_impl = timestamp_combiner_impl
def __repr__(self):
return 'WatermarkHoldStateTag(%s, %s)' % (
self.tag, self.timestamp_combiner_impl)
def with_prefix(self, prefix):
return _WatermarkHoldStateTag(
prefix + self.tag, self.timestamp_combiner_impl)
# pylint: disable=unused-argument
# TODO(robertwb): Provisional API, Java likely to change as well.
class TriggerFn(with_metaclass(ABCMeta, object)): # type: ignore[misc]
"""A TriggerFn determines when window (panes) are emitted.
See https://beam.apache.org/documentation/programming-guide/#triggers
"""
@abstractmethod
def on_element(self, element, window, context):
"""Called when a new element arrives in a window.
Args:
element: the element being added
window: the window to which the element is being added
context: a context (e.g. a TriggerContext instance) for managing state
and setting timers
"""
pass
@abstractmethod
def on_merge(self, to_be_merged, merge_result, context):
"""Called when multiple windows are merged.
Args:
to_be_merged: the set of windows to be merged
merge_result: the window into which the windows are being merged
context: a context (e.g. a TriggerContext instance) for managing state
and setting timers
"""
pass
@abstractmethod
def should_fire(self, time_domain, timestamp, window, context):
"""Whether this trigger should cause the window to fire.
Args:
time_domain: WATERMARK for event-time timers and REAL_TIME for
processing-time timers.
timestamp: for time_domain WATERMARK, it represents the
watermark: (a lower bound on) the watermark of the system
and for time_domain REAL_TIME, it represents the
trigger: timestamp of the processing-time timer.
window: the window whose trigger is being considered
context: a context (e.g. a TriggerContext instance) for managing state
and setting timers
Returns:
whether this trigger should cause a firing
"""
pass
@abstractmethod
def has_ontime_pane(self):
"""Whether this trigger creates an empty pane even if there are no elements.
Returns:
True if this trigger guarantees that there will always be an ON_TIME pane
even if there are no elements in that pane.
"""
pass
@abstractmethod
def on_fire(self, watermark, window, context):
"""Called when a trigger actually fires.
Args:
watermark: (a lower bound on) the watermark of the system
window: the window whose trigger is being fired
context: a context (e.g. a TriggerContext instance) for managing state
and setting timers
Returns:
whether this trigger is finished
"""
pass
@abstractmethod
def reset(self, window, context):
"""Clear any state and timers used by this TriggerFn."""
pass
# pylint: enable=unused-argument
@staticmethod
def from_runner_api(proto, context):
return {
'after_all': AfterAll,
'after_any': AfterAny,
'after_each': AfterEach,
'after_end_of_window': AfterWatermark,
'after_processing_time': AfterProcessingTime,
# after_processing_time, after_synchronized_processing_time
'always': Always,
'default': DefaultTrigger,
'element_count': AfterCount,
# never
'or_finally': OrFinally,
'repeat': Repeatedly,
}[proto.WhichOneof('trigger')].from_runner_api(proto, context)
@abstractmethod
def to_runner_api(self, unused_context):
pass
def __ne__(self, other):
# TODO(BEAM-5949): Needed for Python 2 compatibility.
return not self == other
class DefaultTrigger(TriggerFn):
"""Semantically Repeatedly(AfterWatermark()), but more optimized."""
def __init__(self):
pass
def __repr__(self):
return 'DefaultTrigger()'
def on_element(self, element, window, context):
context.set_timer('', TimeDomain.WATERMARK, window.end)
def on_merge(self, to_be_merged, merge_result, context):
# Note: Timer clearing solely an optimization.
for window in to_be_merged:
if window.end != merge_result.end:
context.clear_timer('', TimeDomain.WATERMARK)
def should_fire(self, time_domain, watermark, window, context):
if watermark >= window.end:
# Explicitly clear the timer so that late elements are not emitted again
# when the timer is fired.
context.clear_timer('', TimeDomain.WATERMARK)
return watermark >= window.end
def on_fire(self, watermark, window, context):
return False
def reset(self, window, context):
context.clear_timer('', TimeDomain.WATERMARK)
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
@staticmethod
def from_runner_api(proto, context):
return DefaultTrigger()
def to_runner_api(self, unused_context):
return beam_runner_api_pb2.Trigger(
default=beam_runner_api_pb2.Trigger.Default())
def has_ontime_pane(self):
return True
class AfterProcessingTime(TriggerFn):
"""Fire exactly once after a specified delay from processing time.
AfterProcessingTime is experimental. No backwards compatibility guarantees.
"""
def __init__(self, delay=0):
"""Initialize a processing time trigger with a delay in seconds."""
self.delay = delay
def __repr__(self):
return 'AfterProcessingTime(delay=%d)' % self.delay
def on_element(self, element, window, context):
context.set_timer(
'', TimeDomain.REAL_TIME, context.get_current_time() + self.delay)
def on_merge(self, to_be_merged, merge_result, context):
# timers will be kept through merging
pass
def should_fire(self, time_domain, timestamp, window, context):
if time_domain == TimeDomain.REAL_TIME:
return True
def on_fire(self, timestamp, window, context):
return True
def reset(self, window, context):
pass
@staticmethod
def from_runner_api(proto, context):
return AfterProcessingTime(
delay=(
proto.after_processing_time.timestamp_transforms[0].delay.
delay_millis) // 1000)
def to_runner_api(self, context):
delay_proto = beam_runner_api_pb2.TimestampTransform(
delay=beam_runner_api_pb2.TimestampTransform.Delay(
delay_millis=self.delay * 1000))
return beam_runner_api_pb2.Trigger(
after_processing_time=beam_runner_api_pb2.Trigger.AfterProcessingTime(
timestamp_transforms=[delay_proto]))
def has_ontime_pane(self):
return False
class Always(TriggerFn):
"""Repeatedly invoke the given trigger, never finishing."""
def __init__(self):
pass
def __repr__(self):
return 'Always'
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return 1
def on_element(self, element, window, context):
pass
def on_merge(self, to_be_merged, merge_result, context):
pass
def has_ontime_pane(self):
False
def reset(self, window, context):
pass
def should_fire(self, time_domain, watermark, window, context):
return True
def on_fire(self, watermark, window, context):
return False
@staticmethod
def from_runner_api(proto, context):
return Always()
def to_runner_api(self, context):
return beam_runner_api_pb2.Trigger(
always=beam_runner_api_pb2.Trigger.Always())
class AfterWatermark(TriggerFn):
"""Fire exactly once when the watermark passes the end of the window.
Args:
early: if not None, a speculative trigger to repeatedly evaluate before
the watermark passes the end of the window
late: if not None, a speculative trigger to repeatedly evaluate after
the watermark passes the end of the window
"""
LATE_TAG = _CombiningValueStateTag('is_late', any)
def __init__(self, early=None, late=None):
self.early = Repeatedly(early) if early else None
self.late = Repeatedly(late) if late else None
def __repr__(self):
qualifiers = []
if self.early:
qualifiers.append('early=%s' % self.early.underlying)
if self.late:
qualifiers.append('late=%s' % self.late.underlying)
return 'AfterWatermark(%s)' % ', '.join(qualifiers)
def is_late(self, context):
return self.late and context.get_state(self.LATE_TAG)
def on_element(self, element, window, context):
if self.is_late(context):
self.late.on_element(element, window, NestedContext(context, 'late'))
else:
context.set_timer('', TimeDomain.WATERMARK, window.end)
if self.early:
self.early.on_element(element, window, NestedContext(context, 'early'))
def on_merge(self, to_be_merged, merge_result, context):
# TODO(robertwb): Figure out whether the 'rewind' semantics could be used
# here.
if self.is_late(context):
self.late.on_merge(
to_be_merged, merge_result, NestedContext(context, 'late'))
else:
# Note: Timer clearing solely an optimization.
for window in to_be_merged:
if window.end != merge_result.end:
context.clear_timer('', TimeDomain.WATERMARK)
if self.early:
self.early.on_merge(
to_be_merged, merge_result, NestedContext(context, 'early'))
def should_fire(self, time_domain, watermark, window, context):
if self.is_late(context):
return self.late.should_fire(
time_domain, watermark, window, NestedContext(context, 'late'))
elif watermark >= window.end:
# Explicitly clear the timer so that late elements are not emitted again
# when the timer is fired.
context.clear_timer('', TimeDomain.WATERMARK)
return True
elif self.early:
return self.early.should_fire(
time_domain, watermark, window, NestedContext(context, 'early'))
return False
def on_fire(self, watermark, window, context):
if self.is_late(context):
return self.late.on_fire(
watermark, window, NestedContext(context, 'late'))
elif watermark >= window.end:
context.add_state(self.LATE_TAG, True)
return not self.late
elif self.early:
self.early.on_fire(watermark, window, NestedContext(context, 'early'))
return False
def reset(self, window, context):
if self.late:
context.clear_state(self.LATE_TAG)
if self.early:
self.early.reset(window, NestedContext(context, 'early'))
if self.late:
self.late.reset(window, NestedContext(context, 'late'))
def __eq__(self, other):
return (
type(self) == type(other) and self.early == other.early and
self.late == other.late)
def __hash__(self):
return hash((type(self), self.early, self.late))
@staticmethod
def from_runner_api(proto, context):
return AfterWatermark(
early=TriggerFn.from_runner_api(
proto.after_end_of_window.early_firings, context)
if proto.after_end_of_window.HasField('early_firings') else None,
late=TriggerFn.from_runner_api(
proto.after_end_of_window.late_firings, context)
if proto.after_end_of_window.HasField('late_firings') else None)
def to_runner_api(self, context):
early_proto = self.early.underlying.to_runner_api(
context) if self.early else None
late_proto = self.late.underlying.to_runner_api(
context) if self.late else None
return beam_runner_api_pb2.Trigger(
after_end_of_window=beam_runner_api_pb2.Trigger.AfterEndOfWindow(
early_firings=early_proto, late_firings=late_proto))
def has_ontime_pane(self):
return True
class AfterCount(TriggerFn):
"""Fire when there are at least count elements in this window pane.
AfterCount is experimental. No backwards compatibility guarantees.
"""
COUNT_TAG = _CombiningValueStateTag('count', combiners.CountCombineFn())
def __init__(self, count):
if not isinstance(count, numbers.Integral) or count < 1:
raise ValueError("count (%d) must be a positive integer." % count)
self.count = count
def __repr__(self):
return 'AfterCount(%s)' % self.count
def __eq__(self, other):
return type(self) == type(other) and self.count == other.count
def __hash__(self):
return hash(self.count)
def on_element(self, element, window, context):
context.add_state(self.COUNT_TAG, 1)
def on_merge(self, to_be_merged, merge_result, context):
# states automatically merged
pass
def should_fire(self, time_domain, watermark, window, context):
return context.get_state(self.COUNT_TAG) >= self.count
def on_fire(self, watermark, window, context):
return True
def reset(self, window, context):
context.clear_state(self.COUNT_TAG)
@staticmethod
def from_runner_api(proto, unused_context):
return AfterCount(proto.element_count.element_count)
def to_runner_api(self, unused_context):
return beam_runner_api_pb2.Trigger(
element_count=beam_runner_api_pb2.Trigger.ElementCount(
element_count=self.count))
def has_ontime_pane(self):
return False
class Repeatedly(TriggerFn):
"""Repeatedly invoke the given trigger, never finishing."""
def __init__(self, underlying):
self.underlying = underlying
def __repr__(self):
return 'Repeatedly(%s)' % self.underlying
def __eq__(self, other):
return type(self) == type(other) and self.underlying == other.underlying
def __hash__(self):
return hash(self.underlying)
def on_element(self, element, window, context):
self.underlying.on_element(element, window, context)
def on_merge(self, to_be_merged, merge_result, context):
self.underlying.on_merge(to_be_merged, merge_result, context)
def should_fire(self, time_domain, watermark, window, context):
return self.underlying.should_fire(time_domain, watermark, window, context)
def on_fire(self, watermark, window, context):
if self.underlying.on_fire(watermark, window, context):
self.underlying.reset(window, context)
return False
def reset(self, window, context):
self.underlying.reset(window, context)
@staticmethod
def from_runner_api(proto, context):
return Repeatedly(
TriggerFn.from_runner_api(proto.repeat.subtrigger, context))
def to_runner_api(self, context):
return beam_runner_api_pb2.Trigger(
repeat=beam_runner_api_pb2.Trigger.Repeat(
subtrigger=self.underlying.to_runner_api(context)))
def has_ontime_pane(self):
return self.underlying.has_ontime_pane()
class _ParallelTriggerFn(with_metaclass(ABCMeta,
TriggerFn)): # type: ignore[misc]
def __init__(self, *triggers):
self.triggers = triggers
def __repr__(self):
return '%s(%s)' % (
self.__class__.__name__, ', '.join(str(t) for t in self.triggers))
def __eq__(self, other):
return type(self) == type(other) and self.triggers == other.triggers
def __hash__(self):
return hash(self.triggers)
@abstractmethod
def combine_op(self, trigger_results):
pass
def on_element(self, element, window, context):
for ix, trigger in enumerate(self.triggers):
trigger.on_element(element, window, self._sub_context(context, ix))
def on_merge(self, to_be_merged, merge_result, context):
for ix, trigger in enumerate(self.triggers):
trigger.on_merge(
to_be_merged, merge_result, self._sub_context(context, ix))
def should_fire(self, time_domain, watermark, window, context):
self._time_domain = time_domain
return self.combine_op(
trigger.should_fire(
time_domain, watermark, window, self._sub_context(context, ix))
for ix,
trigger in enumerate(self.triggers))
def on_fire(self, watermark, window, context):
finished = []
for ix, trigger in enumerate(self.triggers):
nested_context = self._sub_context(context, ix)
if trigger.should_fire(TimeDomain.WATERMARK,
watermark,
window,
nested_context):
finished.append(trigger.on_fire(watermark, window, nested_context))
return self.combine_op(finished)
def reset(self, window, context):
for ix, trigger in enumerate(self.triggers):
trigger.reset(window, self._sub_context(context, ix))
@staticmethod
def _sub_context(context, index):
return NestedContext(context, '%d/' % index)
@staticmethod
def from_runner_api(proto, context):
subtriggers = [
TriggerFn.from_runner_api(subtrigger, context) for subtrigger in
proto.after_all.subtriggers or proto.after_any.subtriggers
]
if proto.after_all.subtriggers:
return AfterAll(*subtriggers)
else:
return AfterAny(*subtriggers)
def to_runner_api(self, context):
subtriggers = [
subtrigger.to_runner_api(context) for subtrigger in self.triggers
]
if self.combine_op == all:
return beam_runner_api_pb2.Trigger(
after_all=beam_runner_api_pb2.Trigger.AfterAll(
subtriggers=subtriggers))
elif self.combine_op == any:
return beam_runner_api_pb2.Trigger(
after_any=beam_runner_api_pb2.Trigger.AfterAny(
subtriggers=subtriggers))
else:
raise NotImplementedError(self)
def has_ontime_pane(self):
return any(t.has_ontime_pane() for t in self.triggers)
class AfterAny(_ParallelTriggerFn):
"""Fires when any subtrigger fires.
Also finishes when any subtrigger finishes.
"""
combine_op = any
class AfterAll(_ParallelTriggerFn):
"""Fires when all subtriggers have fired.
Also finishes when all subtriggers have finished.
"""
combine_op = all
class AfterEach(TriggerFn):
INDEX_TAG = _CombiningValueStateTag(
'index', (lambda indices: 0 if not indices else max(indices)))
def __init__(self, *triggers):
self.triggers = triggers
def __repr__(self):
return '%s(%s)' % (
self.__class__.__name__, ', '.join(str(t) for t in self.triggers))
def __eq__(self, other):
return type(self) == type(other) and self.triggers == other.triggers
def __hash__(self):
return hash(self.triggers)
def on_element(self, element, window, context):
ix = context.get_state(self.INDEX_TAG)
if ix < len(self.triggers):
self.triggers[ix].on_element(
element, window, self._sub_context(context, ix))
def on_merge(self, to_be_merged, merge_result, context):
# This takes the furthest window on merging.
# TODO(robertwb): Revisit this when merging windows logic is settled for
# all possible merging situations.
ix = context.get_state(self.INDEX_TAG)
if ix < len(self.triggers):
self.triggers[ix].on_merge(
to_be_merged, merge_result, self._sub_context(context, ix))
def should_fire(self, time_domain, watermark, window, context):
ix = context.get_state(self.INDEX_TAG)
if ix < len(self.triggers):
return self.triggers[ix].should_fire(
time_domain, watermark, window, self._sub_context(context, ix))
def on_fire(self, watermark, window, context):
ix = context.get_state(self.INDEX_TAG)
if ix < len(self.triggers):
if self.triggers[ix].on_fire(watermark,
window,
self._sub_context(context, ix)):
ix += 1
context.add_state(self.INDEX_TAG, ix)
return ix == len(self.triggers)
def reset(self, window, context):
context.clear_state(self.INDEX_TAG)
for ix, trigger in enumerate(self.triggers):
trigger.reset(window, self._sub_context(context, ix))
@staticmethod
def _sub_context(context, index):
return NestedContext(context, '%d/' % index)
@staticmethod
def from_runner_api(proto, context):
return AfterEach(
*[
TriggerFn.from_runner_api(subtrigger, context)
for subtrigger in proto.after_each.subtriggers
])
def to_runner_api(self, context):
return beam_runner_api_pb2.Trigger(
after_each=beam_runner_api_pb2.Trigger.AfterEach(
subtriggers=[
subtrigger.to_runner_api(context)
for subtrigger in self.triggers
]))
def has_ontime_pane(self):
return any(t.has_ontime_pane() for t in self.triggers)
class OrFinally(AfterAny):
@staticmethod
def from_runner_api(proto, context):
return OrFinally(
TriggerFn.from_runner_api(proto.or_finally.main, context),
# getattr is used as finally is a keyword in Python
TriggerFn.from_runner_api(
getattr(proto.or_finally, 'finally'), context))
def to_runner_api(self, context):
return beam_runner_api_pb2.Trigger(
or_finally=beam_runner_api_pb2.Trigger.OrFinally(
main=self.triggers[0].to_runner_api(context),
# dict keyword argument is used as finally is a keyword in Python
**{'finally': self.triggers[1].to_runner_api(context)}))
class TriggerContext(object):
def __init__(self, outer, window, clock):
self._outer = outer
self._window = window
self._clock = clock
def get_current_time(self):
return self._clock.time()
def set_timer(self, name, time_domain, timestamp):
self._outer.set_timer(self._window, name, time_domain, timestamp)
def clear_timer(self, name, time_domain):
self._outer.clear_timer(self._window, name, time_domain)
def add_state(self, tag, value):
self._outer.add_state(self._window, tag, value)
def get_state(self, tag):
return self._outer.get_state(self._window, tag)
def clear_state(self, tag):
return self._outer.clear_state(self._window, tag)
class NestedContext(object):
"""Namespaced context useful for defining composite triggers."""
def __init__(self, outer, prefix):
self._outer = outer
self._prefix = prefix
def get_current_time(self):
return self._outer.get_current_time()
def set_timer(self, name, time_domain, timestamp):
self._outer.set_timer(self._prefix + name, time_domain, timestamp)
def clear_timer(self, name, time_domain):
self._outer.clear_timer(self._prefix + name, time_domain)
def add_state(self, tag, value):
self._outer.add_state(tag.with_prefix(self._prefix), value)
def get_state(self, tag):
return self._outer.get_state(tag.with_prefix(self._prefix))
def clear_state(self, tag):
self._outer.clear_state(tag.with_prefix(self._prefix))
# pylint: disable=unused-argument
class SimpleState(with_metaclass(ABCMeta, object)): # type: ignore[misc]
"""Basic state storage interface used for triggering.
Only timers must hold the watermark (by their timestamp).
"""
@abstractmethod
def set_timer(self, window, name, time_domain, timestamp):
pass
@abstractmethod
def get_window(self, window_id):
pass
@abstractmethod
def clear_timer(self, window, name, time_domain):
pass
@abstractmethod
def add_state(self, window, tag, value):
pass
@abstractmethod
def get_state(self, window, tag):
pass
@abstractmethod
def clear_state(self, window, tag):
pass
def at(self, window, clock):
return NestedContext(TriggerContext(self, window, clock), 'trigger')
class UnmergedState(SimpleState):
"""State suitable for use in TriggerDriver.
This class must be implemented by each backend.
"""
@abstractmethod
def set_global_state(self, tag, value):
pass
@abstractmethod
def get_global_state(self, tag, default=None):
pass
# pylint: enable=unused-argument
class MergeableStateAdapter(SimpleState):
"""Wraps an UnmergedState, tracking merged windows."""
# TODO(robertwb): A similar indirection could be used for sliding windows
# or other window_fns when a single element typically belongs to many windows.
WINDOW_IDS = _ValueStateTag('window_ids')
def __init__(self, raw_state):
self.raw_state = raw_state
self.window_ids = self.raw_state.get_global_state(self.WINDOW_IDS, {})
self.counter = None
def set_timer(self, window, name, time_domain, timestamp):
self.raw_state.set_timer(self._get_id(window), name, time_domain, timestamp)
def clear_timer(self, window, name, time_domain):
for window_id in self._get_ids(window):
self.raw_state.clear_timer(window_id, name, time_domain)
def add_state(self, window, tag, value):
if isinstance(tag, _ValueStateTag):
raise ValueError(
'Merging requested for non-mergeable state tag: %r.' % tag)
elif isinstance(tag, _CombiningValueStateTag):
tag = tag.without_extraction()
self.raw_state.add_state(self._get_id(window), tag, value)
def get_state(self, window, tag):
if isinstance(tag, _CombiningValueStateTag):
original_tag, tag = tag, tag.without_extraction()
values = [
self.raw_state.get_state(window_id, tag)
for window_id in self._get_ids(window)
]
if isinstance(tag, _ValueStateTag):
raise ValueError(
'Merging requested for non-mergeable state tag: %r.' % tag)
elif isinstance(tag, _CombiningValueStateTag):
return original_tag.combine_fn.extract_output(
original_tag.combine_fn.merge_accumulators(values))
elif isinstance(tag, _ListStateTag):
return [v for vs in values for v in vs]
elif isinstance(tag, _SetStateTag):
return {v for vs in values for v in vs}
elif isinstance(tag, _WatermarkHoldStateTag):
return tag.timestamp_combiner_impl.combine_all(values)
else:
raise ValueError('Invalid tag.', tag)
def clear_state(self, window, tag):
for window_id in self._get_ids(window):
self.raw_state.clear_state(window_id, tag)
if tag is None:
del self.window_ids[window]
self._persist_window_ids()
def merge(self, to_be_merged, merge_result):
for window in to_be_merged:
if window != merge_result:
if window in self.window_ids:
if merge_result in self.window_ids:
merge_window_ids = self.window_ids[merge_result]
else:
merge_window_ids = self.window_ids[merge_result] = []
merge_window_ids.extend(self.window_ids.pop(window))
self._persist_window_ids()
def known_windows(self):
return list(self.window_ids)
def get_window(self, window_id):
for window, ids in self.window_ids.items():
if window_id in ids:
return window
raise ValueError('No window for %s' % window_id)
def _get_id(self, window):
if window in self.window_ids:
return self.window_ids[window][0]
window_id = self._get_next_counter()
self.window_ids[window] = [window_id]
self._persist_window_ids()
return window_id
def _get_ids(self, window):
return self.window_ids.get(window, [])
def _get_next_counter(self):
if not self.window_ids:
self.counter = 0
elif self.counter is None:
self.counter = max(k for ids in self.window_ids.values() for k in ids)
self.counter += 1
return self.counter
def _persist_window_ids(self):
self.raw_state.set_global_state(self.WINDOW_IDS, self.window_ids)
def __repr__(self):
return '\n\t'.join([repr(self.window_ids)] +
repr(self.raw_state).split('\n'))
def create_trigger_driver(
windowing, is_batch=False, phased_combine_fn=None, clock=None):
"""Create the TriggerDriver for the given windowing and options."""
# TODO(robertwb): We can do more if we know elements are in timestamp
# sorted order.
if windowing.is_default() and is_batch:
driver = BatchGlobalTriggerDriver()
elif (windowing.windowfn == GlobalWindows() and
(windowing.triggerfn in [AfterCount(1), Always()]) and is_batch):
# Here we also just pass through all the values exactly once.
driver = BatchGlobalTriggerDriver()
else:
driver = GeneralTriggerDriver(windowing, clock)
if phased_combine_fn:
# TODO(ccy): Refactor GeneralTriggerDriver to combine values eagerly using
# the known phased_combine_fn here.
driver = CombiningTriggerDriver(phased_combine_fn, driver)
return driver
class TriggerDriver(with_metaclass(ABCMeta, object)): # type: ignore[misc]
"""Breaks a series of bundle and timer firings into window (pane)s."""
@abstractmethod
def process_elements(
self,
state,
windowed_values,
output_watermark,
input_watermark=MIN_TIMESTAMP):
pass
@abstractmethod
def process_timer(
self,
window_id,
name,
time_domain,
timestamp,
state,
input_watermark=None):
pass
def process_entire_key(
self,
key,
windowed_values,
unused_output_watermark=None,
unused_input_watermark=None):
state = InMemoryUnmergedState()
for wvalue in self.process_elements(state,
windowed_values,
MIN_TIMESTAMP,
MIN_TIMESTAMP):
yield wvalue.with_value((key, wvalue.value))
while state.timers:
fired = state.get_and_clear_timers()
for timer_window, (name, time_domain, fire_time) in fired:
for wvalue in self.process_timer(timer_window,
name,
time_domain,
fire_time,
state):
yield wvalue.with_value((key, wvalue.value))
class _UnwindowedValues(observable.ObservableMixin):
"""Exposes iterable of windowed values as iterable of unwindowed values."""
def __init__(self, windowed_values):
super(_UnwindowedValues, self).__init__()
self._windowed_values = windowed_values
def __iter__(self):
for wv in self._windowed_values:
unwindowed_value = wv.value
self.notify_observers(unwindowed_value)
yield unwindowed_value
def __repr__(self):
return '<_UnwindowedValues of %s>' % self._windowed_values
def __reduce__(self):
return list, (list(self), )
def __eq__(self, other):
if isinstance(other, collections.Iterable):
return all(
a == b for a, b in zip_longest(self, other, fillvalue=object()))
else:
return NotImplemented
def __hash__(self):
return hash(tuple(self))
def __ne__(self, other):
# TODO(BEAM-5949): Needed for Python 2 compatibility.
return not self == other
coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type(
_UnwindowedValues)
class BatchGlobalTriggerDriver(TriggerDriver):
"""Groups all received values together.
"""
GLOBAL_WINDOW_TUPLE = (GlobalWindow(), )
ONLY_FIRING = windowed_value.PaneInfo(
is_first=True,
is_last=True,
timing=windowed_value.PaneInfoTiming.ON_TIME,
index=0,
nonspeculative_index=0)
def process_elements(
self,
state,
windowed_values,
unused_output_watermark,
unused_input_watermark=MIN_TIMESTAMP):
yield WindowedValue(
_UnwindowedValues(windowed_values),
MIN_TIMESTAMP,
self.GLOBAL_WINDOW_TUPLE,
self.ONLY_FIRING)
def process_timer(
self,
window_id,
name,
time_domain,
timestamp,
state,
input_watermark=None):
raise TypeError('Triggers never set or called for batch default windowing.')
class CombiningTriggerDriver(TriggerDriver):
"""Uses a phased_combine_fn to process output of wrapped TriggerDriver."""
def __init__(self, phased_combine_fn, underlying):
self.phased_combine_fn = phased_combine_fn
self.underlying = underlying
def process_elements(
self,
state,
windowed_values,
output_watermark,
input_watermark=MIN_TIMESTAMP):
uncombined = self.underlying.process_elements(
state, windowed_values, output_watermark, input_watermark)
for output in uncombined:
yield output.with_value(self.phased_combine_fn.apply(output.value))
def process_timer(
self,
window_id,
name,
time_domain,
timestamp,
state,
input_watermark=None):
uncombined = self.underlying.process_timer(
window_id, name, time_domain, timestamp, state, input_watermark)
for output in uncombined:
yield output.with_value(self.phased_combine_fn.apply(output.value))
class GeneralTriggerDriver(TriggerDriver):
"""Breaks a series of bundle and timer firings into window (pane)s.
Suitable for all variants of Windowing.
"""
ELEMENTS = _ListStateTag('elements')
TOMBSTONE = _CombiningValueStateTag('tombstone', combiners.CountCombineFn())
INDEX = _CombiningValueStateTag('index', combiners.CountCombineFn())
NONSPECULATIVE_INDEX = _CombiningValueStateTag(
'nonspeculative_index', combiners.CountCombineFn())
def __init__(self, windowing, clock):
self.clock = clock
self.allowed_lateness = windowing.allowed_lateness
self.window_fn = windowing.windowfn
self.timestamp_combiner_impl = TimestampCombiner.get_impl(
windowing.timestamp_combiner, self.window_fn)
# pylint: disable=invalid-name
self.WATERMARK_HOLD = _WatermarkHoldStateTag(
'watermark', self.timestamp_combiner_impl)
# pylint: enable=invalid-name
self.trigger_fn = windowing.triggerfn
self.accumulation_mode = windowing.accumulation_mode
self.is_merging = True
def process_elements(
self,
state,
windowed_values,
output_watermark,
input_watermark=MIN_TIMESTAMP):
if self.is_merging:
state = MergeableStateAdapter(state)
windows_to_elements = collections.defaultdict(list)
for wv in windowed_values:
for window in wv.windows:
# ignore expired windows
if input_watermark > window.end + self.allowed_lateness:
continue
windows_to_elements[window].append((wv.value, wv.timestamp))
# First handle merging.
if self.is_merging:
old_windows = set(state.known_windows())
all_windows = old_windows.union(list(windows_to_elements))
if all_windows != old_windows:
merged_away = {}
class TriggerMergeContext(WindowFn.MergeContext):
def merge(_, to_be_merged, merge_result): # pylint: disable=no-self-argument
for window in to_be_merged:
if window != merge_result:
merged_away[window] = merge_result
# Clear state associated with PaneInfo since it is
# not preserved across merges.
state.clear_state(window, self.INDEX)
state.clear_state(window, self.NONSPECULATIVE_INDEX)
state.merge(to_be_merged, merge_result)
# using the outer self argument.
self.trigger_fn.on_merge(
to_be_merged, merge_result, state.at(merge_result, self.clock))
self.window_fn.merge(TriggerMergeContext(all_windows))
merged_windows_to_elements = collections.defaultdict(list)
for window, values in windows_to_elements.items():
while window in merged_away:
window = merged_away[window]
merged_windows_to_elements[window].extend(values)
windows_to_elements = merged_windows_to_elements
for window in merged_away:
state.clear_state(window, self.WATERMARK_HOLD)
# Next handle element adding.
for window, elements in windows_to_elements.items():
if state.get_state(window, self.TOMBSTONE):
continue
# Add watermark hold.
# TODO(ccy): Add late data and garbage-collection hold support.
output_time = self.timestamp_combiner_impl.merge(
window,
(
element_output_time for element_output_time in (
self.timestamp_combiner_impl.assign_output_time(
window, timestamp) for unused_value,
timestamp in elements)
if element_output_time >= output_watermark))
if output_time is not None:
state.add_state(window, self.WATERMARK_HOLD, output_time)
context = state.at(window, self.clock)
for value, unused_timestamp in elements:
state.add_state(window, self.ELEMENTS, value)
self.trigger_fn.on_element(value, window, context)
# Maybe fire this window.
if self.trigger_fn.should_fire(TimeDomain.WATERMARK,
input_watermark,
window,
context):
finished = self.trigger_fn.on_fire(input_watermark, window, context)
yield self._output(
window, finished, state, input_watermark, output_watermark, False)
def process_timer(
self,
window_id,
unused_name,
time_domain,
timestamp,
state,
input_watermark=None):
if input_watermark is None:
input_watermark = timestamp
if self.is_merging:
state = MergeableStateAdapter(state)
window = state.get_window(window_id)
if state.get_state(window, self.TOMBSTONE):
return
if time_domain in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME):
if not self.is_merging or window in state.known_windows():
context = state.at(window, self.clock)
if self.trigger_fn.should_fire(time_domain, timestamp, window, context):
finished = self.trigger_fn.on_fire(timestamp, window, context)
yield self._output(
window,
finished,
state,
input_watermark,
timestamp,
time_domain == TimeDomain.WATERMARK)
else:
raise Exception('Unexpected time domain: %s' % time_domain)
def _output(
self,
window,
finished,
state,
input_watermark,
output_watermark,
maybe_ontime):
"""Output window and clean up if appropriate."""
index = state.get_state(window, self.INDEX)
state.add_state(window, self.INDEX, 1)
if output_watermark <= window.max_timestamp():
nonspeculative_index = -1
timing = windowed_value.PaneInfoTiming.EARLY
if state.get_state(window, self.NONSPECULATIVE_INDEX):
nonspeculative_index = state.get_state(
window, self.NONSPECULATIVE_INDEX)
state.add_state(window, self.NONSPECULATIVE_INDEX, 1)
_LOGGER.warning(
'Watermark moved backwards in time '
'or late data moved window end forward.')
else:
nonspeculative_index = state.get_state(window, self.NONSPECULATIVE_INDEX)
state.add_state(window, self.NONSPECULATIVE_INDEX, 1)
timing = (
windowed_value.PaneInfoTiming.ON_TIME if maybe_ontime and
nonspeculative_index == 0 else windowed_value.PaneInfoTiming.LATE)
pane_info = windowed_value.PaneInfo(
index == 0, finished, timing, index, nonspeculative_index)
values = state.get_state(window, self.ELEMENTS)
if finished:
# TODO(robertwb): allowed lateness
state.clear_state(window, self.ELEMENTS)
state.add_state(window, self.TOMBSTONE, 1)
elif self.accumulation_mode == AccumulationMode.DISCARDING:
state.clear_state(window, self.ELEMENTS)
timestamp = state.get_state(window, self.WATERMARK_HOLD)
if timestamp is None:
# If no watermark hold was set, output at end of window.
timestamp = window.max_timestamp()
elif input_watermark < window.end and self.trigger_fn.has_ontime_pane():
# Hold the watermark in case there is an empty pane that needs to be fired
# at the end of the window.
pass
else:
state.clear_state(window, self.WATERMARK_HOLD)
return WindowedValue(values, timestamp, (window, ), pane_info)
class InMemoryUnmergedState(UnmergedState):
"""In-memory implementation of UnmergedState.
Used for batch and testing.
"""
def __init__(self, defensive_copy=True):
# TODO(robertwb): Skip defensive_copy in production if it's too expensive.
self.timers = collections.defaultdict(dict)
self.state = collections.defaultdict(lambda: collections.defaultdict(list))
self.global_state = {}
self.defensive_copy = defensive_copy
def copy(self):
cloned_object = InMemoryUnmergedState(defensive_copy=self.defensive_copy)
cloned_object.timers = copy.deepcopy(self.timers)
cloned_object.global_state = copy.deepcopy(self.global_state)
for window in self.state:
for tag in self.state[window]:
cloned_object.state[window][tag] = copy.copy(self.state[window][tag])
return cloned_object
def set_global_state(self, tag, value):
assert isinstance(tag, _ValueStateTag)
if self.defensive_copy:
value = copy.deepcopy(value)
self.global_state[tag.tag] = value
def get_global_state(self, tag, default=None):
return self.global_state.get(tag.tag, default)
def set_timer(self, window, name, time_domain, timestamp):
self.timers[window][(name, time_domain)] = timestamp
def clear_timer(self, window, name, time_domain):
self.timers[window].pop((name, time_domain), None)
if not self.timers[window]:
del self.timers[window]
def get_window(self, window_id):
return window_id
def add_state(self, window, tag, value):
if self.defensive_copy:
value = copy.deepcopy(value)
if isinstance(tag, _ValueStateTag):
self.state[window][tag.tag] = value
elif isinstance(tag, _CombiningValueStateTag):
# TODO(robertwb): Store merged accumulators.
self.state[window][tag.tag].append(value)
elif isinstance(tag, _ListStateTag):
self.state[window][tag.tag].append(value)
elif isinstance(tag, _SetStateTag):
self.state[window][tag.tag].append(value)
elif isinstance(tag, _WatermarkHoldStateTag):
self.state[window][tag.tag].append(value)
else:
raise ValueError('Invalid tag.', tag)
def get_state(self, window, tag):
values = self.state[window][tag.tag]
if isinstance(tag, _ValueStateTag):
return values
elif isinstance(tag, _CombiningValueStateTag):
return tag.combine_fn.apply(values)
elif isinstance(tag, _ListStateTag):
return values
elif isinstance(tag, _SetStateTag):
return values
elif isinstance(tag, _WatermarkHoldStateTag):
return tag.timestamp_combiner_impl.combine_all(values)
else:
raise ValueError('Invalid tag.', tag)
def clear_state(self, window, tag):
self.state[window].pop(tag.tag, None)
if not self.state[window]:
self.state.pop(window, None)
def get_timers(
self, clear=False, watermark=MAX_TIMESTAMP, processing_time=None):
"""Gets expired timers and reports if there
are any realtime timers set per state.
Expiration is measured against the watermark for event-time timers,
and against a wall clock for processing-time timers.
"""
expired = []
has_realtime_timer = False
for window, timers in list(self.timers.items()):
for (name, time_domain), timestamp in list(timers.items()):
if time_domain == TimeDomain.REAL_TIME:
time_marker = processing_time
has_realtime_timer = True
elif time_domain == TimeDomain.WATERMARK:
time_marker = watermark
else:
_LOGGER.error(
'TimeDomain error: No timers defined for time domain %s.',
time_domain)
if timestamp <= time_marker:
expired.append((window, (name, time_domain, timestamp)))
if clear:
del timers[(name, time_domain)]
if not timers and clear:
del self.timers[window]
return expired, has_realtime_timer
def get_and_clear_timers(self, watermark=MAX_TIMESTAMP):
return self.get_timers(clear=True, watermark=watermark)[0]
def get_earliest_hold(self):
earliest_hold = MAX_TIMESTAMP
for unused_window, tagged_states in iteritems(self.state):
# TODO(BEAM-2519): currently, this assumes that the watermark hold tag is
# named "watermark". This is currently only true because the only place
# watermark holds are set is in the GeneralTriggerDriver, where we use
# this name. We should fix this by allowing enumeration of the tag types
# used in adding state.
if 'watermark' in tagged_states and tagged_states['watermark']:
hold = min(tagged_states['watermark']) - TIME_GRANULARITY
earliest_hold = min(earliest_hold, hold)
return earliest_hold
def __repr__(self):
state_str = '\n'.join(
'%s: %s' % (key, dict(state)) for key, state in self.state.items())
return 'timers: %s\nstate: %s' % (dict(self.timers), state_str)