Merge pull request #12411: [BEAM-10603] Add ElementLimiters, which allow the cache to stop reading prematurely based on the elements read.
[BEAM-10603] Add ElementLimiters, which allow the cache to stop reading prematurely based on the elements read.
diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
index a25aba0..2c84f80 100644
--- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
+++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
@@ -24,6 +24,8 @@
import threading
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
from apache_beam.runners.interactive import interactive_environment as ie
@@ -36,6 +38,20 @@
raise NotImplementedError
+class ElementLimiter(Limiter):
+  """Base class for a `Limiter` driven by the elements read from cache.
+
+  Subclasses inspect each element as it is read and use what they have seen
+  to decide whether the limit has been reached.
+  """
+  def update(self, e):
+    # type: (Any) -> None
+    """Folds a single element into the limiter's internal state.
+
+    Called once for every element read from cache. Subclasses must override
+    this method.
+
+    Raises:
+      NotImplementedError: always, in this base class.
+    """
+    raise NotImplementedError()
+
+
class SizeLimiter(Limiter):
"""Limits the cache size to a specified byte limit."""
def __init__(
@@ -71,3 +87,56 @@
def is_triggered(self):
return self._triggered
+
+
+class CountLimiter(ElementLimiter):
+  """Limits by counting the number of elements seen."""
+  def __init__(self, max_count):
+    """Initialize the CountLimiter.
+
+    Args:
+      max_count: the number of elements after which this limiter triggers.
+    """
+    self._max_count = max_count
+    self._count = 0
+
+  def update(self, e):
+    """Adds the number of elements represented by `e` to the running count."""
+    # A TestStreamFileRecord can contain many elements at once. If e is a file
+    # record, then count the number of elements in the bundle.
+    if isinstance(e, TestStreamFileRecord):
+      # Reading a singular message field in Python protobuf yields a (default)
+      # message object, which is never falsy, so a plain truthiness test would
+      # never skip anything. Presence must be checked with HasField.
+      if not e.recorded_event.HasField('element_event'):
+        return
+      self._count += len(e.recorded_event.element_event.elements)
+
+    # Otherwise, count everything else but the header of the file since it is
+    # not an element.
+    elif not isinstance(e, TestStreamFileHeader):
+      self._count += 1
+
+  def is_triggered(self):
+    """Returns True once at least `max_count` elements have been counted."""
+    return self._count >= self._max_count
+
+
+class ProcessingTimeLimiter(ElementLimiter):
+  """Limits by how long the ProcessingTime passed in the element stream.
+
+  This measures the duration from the first processing time event in the
+  stream. Each subsequent event has a delta "advance_duration" that moves the
+  internal clock forward. This triggers when the duration between the internal
+  clock and the start reaches the given duration.
+  """
+  def __init__(self, max_duration_secs):
+    """Initialize the ProcessingTimeLimiter.
+
+    Args:
+      max_duration_secs: the amount of processing time, in seconds, after
+        which this limiter triggers.
+    """
+    self._max_duration_us = max_duration_secs * 1e6
+    self._start_us = 0
+    self._cur_time_us = 0
+    # Whether the first processing time event has been seen. Tracked
+    # separately because 0 is a valid advance_duration and therefore cannot
+    # double as an "unset" marker for _start_us.
+    self._started = False
+
+  def update(self, e):
+    """Advances the internal clock if `e` carries a processing time event."""
+    # Only look at TestStreamFileRecords which hold the processing time.
+    if not isinstance(e, TestStreamFileRecord):
+      return
+
+    # A singular message field read from Python protobuf is never falsy, so
+    # presence must be checked with HasField rather than truthiness.
+    if not e.recorded_event.HasField('processing_time_event'):
+      return
+
+    advance_us = e.recorded_event.processing_time_event.advance_duration
+    if not self._started:
+      # The first advance marks the start of the measured window; only the
+      # advances after it count toward the elapsed duration.
+      self._started = True
+      self._start_us = advance_us
+    self._cur_time_us += advance_us
+
+  def is_triggered(self):
+    """Returns True once the elapsed processing time reaches the maximum."""
+    return self._cur_time_us - self._start_us >= self._max_duration_us
diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
new file mode 100644
index 0000000..850c56e2c
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
@@ -0,0 +1,53 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import absolute_import
+
+import unittest
+
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
+from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
+from apache_beam.runners.interactive.options.capture_limiters import ProcessingTimeLimiter
+
+
+class CaptureLimitersTest(unittest.TestCase):
+  """Tests for the element-based capture limiters."""
+  def test_count_limiter(self):
+    limiter = CountLimiter(5)
+
+    for e in range(4):
+      limiter.update(e)
+
+    self.assertFalse(limiter.is_triggered())
+    limiter.update(5)
+    self.assertTrue(limiter.is_triggered())
+
+  def test_count_limiter_with_file_records(self):
+    limiter = CountLimiter(5)
+
+    # A single record carrying a bundle of 3 elements counts as 3 elements.
+    r = TestStreamFileRecord()
+    for _ in range(3):
+      r.recorded_event.element_event.elements.add()
+    limiter.update(r)
+    self.assertFalse(limiter.is_triggered())
+
+    # A record that only advances processing time carries no elements and
+    # must not move the count.
+    r = TestStreamFileRecord()
+    r.recorded_event.processing_time_event.advance_duration = int(1 * 1e6)
+    limiter.update(r)
+    self.assertFalse(limiter.is_triggered())
+
+    r = TestStreamFileRecord()
+    for _ in range(2):
+      r.recorded_event.element_event.elements.add()
+    limiter.update(r)
+    self.assertTrue(limiter.is_triggered())
+
+  def test_processing_time_limiter(self):
+    limiter = ProcessingTimeLimiter(max_duration_secs=2)
+
+    r = TestStreamFileRecord()
+    r.recorded_event.processing_time_event.advance_duration = int(1 * 1e6)
+    limiter.update(r)
+    self.assertFalse(limiter.is_triggered())
+
+    r = TestStreamFileRecord()
+    r.recorded_event.processing_time_event.advance_duration = int(2 * 1e6)
+    limiter.update(r)
+    self.assertTrue(limiter.is_triggered())
+
+
+if __name__ == '__main__':
+  # Allow running this test module directly, outside of a test runner.
+  unittest.main()