Merge pull request #12412 from [BEAM-10603] Add ElementLimiters to all Cache Managers.
[BEAM-10603] Add ElementLimiters to all Cache Managers.
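
This change threads an optional list of ElementLimiters through every
CacheManager.read implementation: each limiter observes every element as it
streams out of the cache, and the read stops as soon as any limiter triggers.
A minimal, self-contained sketch of the pattern follows; the names mirror
Beam's capture_limiters module, but the bodies are illustrative stand-ins
rather than the actual Beam code.

class ElementLimiter(object):
  """Observes elements as they are read from a cache; decides when to stop."""
  def update(self, e):
    raise NotImplementedError

  def is_triggered(self):
    raise NotImplementedError


class CountLimiter(ElementLimiter):
  """Triggers once more than max_count elements have been seen."""
  def __init__(self, max_count):
    self._max_count = max_count
    self._count = 0

  def update(self, e):
    self._count += 1

  def is_triggered(self):
    # Strict greater-than: a reader that checks after updating still yields
    # exactly max_count elements (see the capture_limiters.py hunk below).
    return self._count > self._max_count


def limit_reader(reader, limiters):
  """Wraps a cache reader generator so it stops once any limiter triggers."""
  for e in reader:
    for l in limiters:
      l.update(e)
    if any(l.is_triggered() for l in limiters):
      break
    yield e

For example, list(limit_reader(iter(range(10)), [CountLimiter(3)])) evaluates
to [0, 1, 2]: the fourth element trips the limiter and is never yielded.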
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py
index 9894535..48f1fc5 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py
@@ -67,13 +67,17 @@
"""Returns the latest version number of the PCollection cache."""
raise NotImplementedError
- def read(self, *labels):
- # type (*str) -> Tuple[str, Generator[Any]]
+ def read(self, *labels, **args):
+ # type: (*str, **Any) -> Tuple[str, Generator[Any]]
"""Return the PCollection as a list as well as the version number.
Args:
*labels: List of labels for PCollection instance.
+ **args: Dict of additional arguments. Currently only supports 'limiters'
+ as a list of ElementLimiters and 'tail' as a boolean. Limiters bound how
+ many elements are read and for how long, with respect to processing
+ time.
Returns:
A tuple containing an iterator for the items in the PCollection and the
@@ -97,6 +101,17 @@
"""
raise NotImplementedError
+ def clear(self, *labels):
+ # type: (*str) -> bool
+
+ """Clears the cache entry of the given labels and returns True on success.
+
+ Args:
+ *labels: List of labels for PCollection instance.
+ """
+ raise NotImplementedError
+
def source(self, *labels):
# type (*str) -> ptransform.PTransform
@@ -196,17 +211,35 @@
self._default_pcoder if self._default_pcoder is not None else
self._saved_pcoders[self._path(*labels)])
- def read(self, *labels):
+ def read(self, *labels, **args):
# Return an iterator to an empty list if it doesn't exist.
if not self.exists(*labels):
return iter([]), -1
+ limiters = args.pop('limiters', [])
+
# Otherwise, return a generator to the cached PCollection.
source = self.source(*labels)._source
range_tracker = source.get_range_tracker(None, None)
reader = source.read(range_tracker)
version = self._latest_version(*labels)
- return reader, version
+
+ # The return type is a generator, so to implement limiting for the
+ # FileBasedCacheManager we wrap the original generator with logic that
+ # stops yielding elements once any limiter triggers.
+ def limit_reader(r):
+ for e in r:
+ # Update the limiters and break early out of reading from cache if any
+ # are triggered.
+ for l in limiters:
+ l.update(e)
+
+ if any(l.is_triggered() for l in limiters):
+ break
+
+ yield e
+
+ return limit_reader(reader), version
def write(self, values, *labels):
sink = self.sink(labels)._sink
@@ -218,6 +251,12 @@
writer.write(v)
writer.close()
+ def clear(self, *labels):
+ if self.exists(*labels):
+ filesystems.FileSystems.delete(self._match(*labels))
+ return True
+ return False
+
def source(self, *labels):
return self._reader_class(
self._glob_path(*labels), coder=self.load_pcoder(*labels))
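
Because the wrapper updates the limiters before the triggered check, the
element that trips a limiter is consumed from the underlying reader but never
yielded. A quick demonstration of that exact-count behavior, reusing the
illustrative limit_reader/CountLimiter sketch from above (not Beam's code):

cached = ['cache', 'version', 'one']
out = list(limit_reader(iter(cached), [CountLimiter(2)]))
assert out == ['cache', 'version']  # The third element trips the limiter.

This mirrors test_read_with_count_limiter in the test diff below.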
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
index 7868e90..e7dc936 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
@@ -30,6 +30,7 @@
from apache_beam import coders
from apache_beam.io import filesystems
from apache_beam.runners.interactive import cache_manager as cache
+from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
class FileBasedCacheManagerTest(object):
@@ -91,6 +92,18 @@
self.mock_write_cache(cache_version_one, prefix, cache_label)
self.assertTrue(self.cache_manager.exists(prefix, cache_label))
+ def test_clear(self):
+ """Test that CacheManager can correctly tell if the cache exists or not."""
+ prefix = 'full'
+ cache_label = 'some-cache-label'
+ cache_version_one = ['cache', 'version', 'one']
+
+ self.assertFalse(self.cache_manager.exists(prefix, cache_label))
+ self.mock_write_cache(cache_version_one, prefix, cache_label)
+ self.assertTrue(self.cache_manager.exists(prefix, cache_label))
+ self.assertTrue(self.cache_manager.clear(prefix, cache_label))
+ self.assertFalse(self.cache_manager.exists(prefix, cache_label))
+
def test_read_basic(self):
"""Test the condition where the cache is read once after written once."""
prefix = 'full'
@@ -180,6 +193,21 @@
self.assertTrue(
self.cache_manager.is_latest_version(version, prefix, cache_label))
+ def test_read_with_count_limiter(self):
+ """Test the condition where the cache is read once after written once."""
+ prefix = 'full'
+ cache_label = 'some-cache-label'
+ cache_version_one = ['cache', 'version', 'one']
+
+ self.mock_write_cache(cache_version_one, prefix, cache_label)
+ reader, version = self.cache_manager.read(
+ prefix, cache_label, limiters=[CountLimiter(2)])
+ pcoll_list = list(reader)
+ self.assertListEqual(pcoll_list, ['cache', 'version'])
+ self.assertEqual(version, 0)
+ self.assertTrue(
+ self.cache_manager.is_latest_version(version, prefix, cache_label))
+
class TextFileBasedCacheManagerTest(
FileBasedCacheManagerTest,
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index b2204cf..77f976d 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -153,13 +153,23 @@
cache_dir,
labels,
is_cache_complete=None,
- coder=SafeFastPrimitivesCoder()):
+ coder=None,
+ limiters=None):
+ if not coder:
+ coder = SafeFastPrimitivesCoder()
+
+ if not is_cache_complete:
+ is_cache_complete = lambda _: True
+
+ if not limiters:
+ limiters = []
+
self._cache_dir = cache_dir
self._coder = coder
self._labels = labels
self._path = os.path.join(self._cache_dir, *self._labels)
- self._is_cache_complete = (
- is_cache_complete if is_cache_complete else lambda _: True)
+ self._is_cache_complete = is_cache_complete
+ self._limiters = limiters
from apache_beam.runners.interactive.pipeline_instrument import CacheKey
self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id
@@ -193,7 +203,8 @@
# Check if we are at EOF or if we have an incomplete line.
if not line or (line and line[-1] != b'\n'[0]):
- if not tail:
+ # Read at least the first line to get the header.
+ if not tail and pos != 0:
break
# Complete reading only when the cache is complete.
@@ -210,10 +221,16 @@
proto_cls = TestStreamFileHeader if pos == 0 else TestStreamFileRecord
msg = self._try_parse_as(proto_cls, to_decode)
if msg:
- yield msg
+ for l in self._limiters:
+ l.update(msg)
+
+ if any(l.is_triggered() for l in self._limiters):
+ break
else:
break
+ yield msg
+
def _try_parse_as(self, proto_cls, to_decode):
try:
msg = proto_cls()
@@ -285,7 +302,7 @@
return os.path.exists(path)
# TODO(srohde): Modify this to return the correct version.
- def read(self, *labels):
+ def read(self, *labels, **args):
"""Returns a generator to read all records from file.
- Does not tail.
+ Does not tail unless 'tail' is passed as True in **args.
@@ -293,8 +310,12 @@
if not self.exists(*labels):
return iter([]), -1
+ limiters = args.pop('limiters', [])
+ tail = args.pop('tail', False)
+
reader = StreamingCacheSource(
- self._cache_dir, labels, self._is_cache_complete).read(tail=False)
+ self._cache_dir, labels, self._is_cache_complete,
+ limiters=limiters).read(tail=tail)
# Return an empty iterator if there is nothing in the file yet. This can
# only happen when tail is False.
@@ -304,7 +325,7 @@
return iter([]), -1
return StreamingCache.Reader([header], [reader]).read(), 1
- def read_multiple(self, labels):
+ def read_multiple(self, labels, limiters=None, tail=True):
"""Returns a generator to read all records from file.
Does tail until the cache is complete. This is because it is used in the
@@ -312,9 +333,9 @@
pipeline runtime which needs to block.
"""
readers = [
- StreamingCacheSource(self._cache_dir, l,
- self._is_cache_complete).read(tail=True)
- for l in labels
+ StreamingCacheSource(
+ self._cache_dir, l, self._is_cache_complete,
+ limiters=limiters).read(tail=tail) for l in labels
]
headers = [next(r) for r in readers]
return StreamingCache.Reader(headers, readers).read()
@@ -334,6 +355,14 @@
val = v
f.write(self._default_pcoder.encode(val) + b'\n')
+ def clear(self, *labels):
+ directory = os.path.join(self._cache_dir, *labels[:-1])
+ filepath = os.path.join(directory, labels[-1])
+ if os.path.exists(filepath):
+ os.remove(filepath)
+ return True
+ return False
+
def source(self, *labels):
"""Returns the StreamingCacheManager source.
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
index a56b851..23390cc 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
@@ -28,6 +28,8 @@
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
+from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
+from apache_beam.runners.interactive.options.capture_limiters import ProcessingTimeLimiter
from apache_beam.runners.interactive.pipeline_instrument import CacheKey
from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder
from apache_beam.testing.test_pipeline import TestPipeline
@@ -64,6 +66,14 @@
# Assert that an empty reader returns an empty list.
self.assertFalse([e for e in reader])
+ def test_clear(self):
+ cache = StreamingCache(cache_dir=None)
+ self.assertFalse(cache.exists('my_label'))
+ cache.write([TestStreamFileRecord()], 'my_label')
+ self.assertTrue(cache.exists('my_label'))
+ self.assertTrue(cache.clear('my_label'))
+ self.assertFalse(cache.exists('my_label'))
+
def test_single_reader(self):
"""Tests that we expect to see all the correctly emitted TestStreamPayloads.
"""
@@ -403,6 +413,106 @@
self.assertListEqual(actual_events, expected_events)
+ def test_single_reader_with_count_limiter(self):
+ """Tests that we expect to see all the correctly emitted TestStreamPayloads.
+ """
+ CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
+
+ values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
+ .add_element(element=0, event_time_secs=0)
+ .advance_processing_time(1)
+ .add_element(element=1, event_time_secs=1)
+ .advance_processing_time(1)
+ .add_element(element=2, event_time_secs=2)
+ .build()) # yapf: disable
+
+ cache = StreamingCache(cache_dir=None)
+ cache.write(values, CACHED_PCOLLECTION_KEY)
+
+ reader, _ = cache.read(CACHED_PCOLLECTION_KEY, limiters=[CountLimiter(2)])
+ coder = coders.FastPrimitivesCoder()
+ events = list(reader)
+
+ # Units here are in microseconds.
+ # These are a slice of the original values such that we only get two
+ # elements.
+ expected = [
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode(0), timestamp=0)
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=1 * 10**6)),
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode(1), timestamp=1 * 10**6)
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=1 * 10**6)),
+ ]
+ self.assertSequenceEqual(events, expected)
+
+ def test_single_reader_with_processing_time_limiter(self):
+ """Tests that we expect to see all the correctly emitted TestStreamPayloads.
+ """
+ CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
+
+ values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
+ .advance_processing_time(1e-6)
+ .add_element(element=0, event_time_secs=0)
+ .advance_processing_time(1)
+ .add_element(element=1, event_time_secs=1)
+ .advance_processing_time(1)
+ .add_element(element=2, event_time_secs=2)
+ .advance_processing_time(1)
+ .add_element(element=3, event_time_secs=2)
+ .advance_processing_time(1)
+ .add_element(element=4, event_time_secs=2)
+ .build()) # yapf: disable
+
+ cache = StreamingCache(cache_dir=None)
+ cache.write(values, CACHED_PCOLLECTION_KEY)
+
+ reader, _ = cache.read(
+ CACHED_PCOLLECTION_KEY, limiters=[ProcessingTimeLimiter(2)])
+ coder = coders.FastPrimitivesCoder()
+ events = list(reader)
+
+ # Units here are in microseconds.
+ # The expected events are the slice of the original values in which the
+ # accumulated processing time stays below the limiter's duration.
+ expected = [
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=1)),
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode(0), timestamp=0)
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=1 * 10**6)),
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode(1), timestamp=1 * 10**6)
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
+ ]
+ self.assertSequenceEqual(events, expected)
+
if __name__ == '__main__':
unittest.main()
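
The processing-time test above is easier to follow as arithmetic: each
processing-time event advances an internal clock by advance_duration
(microseconds), and once the total reaches the 2-second limit the triggering
advance event and everything after it are dropped. A standalone restatement
of the test's six records; the accumulate-and-compare logic is an assumption
consistent with the expected output, not a copy of the limiter's internals:

events = [('advance_us', 1), ('element', 0),
          ('advance_us', 10**6), ('element', 1),
          ('advance_us', 10**6), ('element', 2)]
limit_us = 2 * 10**6
elapsed_us = 0
emitted = []
for kind, value in events:
  if kind == 'advance_us':
    elapsed_us += value
  if elapsed_us >= limit_us:  # The triggering advance itself is not emitted.
    break
  emitted.append((kind, value))
assert emitted == events[:4]  # advance 1us, element 0, advance 1s, element 1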
diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
index 2c84f80..c9888ab 100644
--- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
+++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
@@ -109,12 +109,14 @@
self._count += 1
def is_triggered(self):
- return self._count >= self._max_count
+ return self._count > self._max_count
class ProcessingTimeLimiter(ElementLimiter):
"""Limits by how long the ProcessingTime passed in the element stream.
+ Reads all elements from the timespan [start, start + duration).
+
This measures the duration from the first element in the stream. Each
subsequent element has a delta "advance_duration" that moves the internal
clock forward. This triggers when the duration from the internal clock and
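
For reference, a minimal stand-in consistent with the amended docstring's
half-open timespan [start, start + duration); the real class parses
TestStreamFileRecord protos for their advance_duration, which this sketch
abstracts into a plain microseconds argument:

class ProcessingTimeLimiter(object):
  """Illustrative stand-in, not Beam's implementation."""
  def __init__(self, max_duration_secs):
    self._max_duration_us = int(max_duration_secs * 1e6)
    self._elapsed_us = 0

  def update(self, advance_duration_us):
    # Each processing-time event moves the internal clock forward.
    self._elapsed_us += advance_duration_us

  def is_triggered(self):
    # >= keeps the timespan half-open: an event landing exactly at
    # start + duration is excluded from the read.
    return self._elapsed_us >= self._max_duration_us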
diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
index 850c56e2c..347cb8e 100644
--- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
+++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
@@ -28,7 +28,7 @@
def test_count_limiter(self):
limiter = CountLimiter(5)
- for e in range(4):
+ for e in range(5):
limiter.update(e)
self.assertFalse(limiter.is_triggered())
diff --git a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
index f39f016..098f249 100644
--- a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
@@ -45,11 +45,23 @@
def _latest_version(self, *labels):
return True
- def read(self, *labels):
+ def read(self, *labels, **args):
if not self.exists(*labels):
return itertools.chain([]), -1
- ret = itertools.chain(self._cached[self._key(*labels)])
- return ret, None
+
+ limiters = args.pop('limiters', [])
+
+ def limit_reader(r):
+ for e in r:
+ for l in limiters:
+ l.update(e)
+
+ if any(l.is_triggered() for l in limiters):
+ break
+
+ yield e
+
+ return limit_reader(itertools.chain(self._cached[self._key(*labels)])), None
def write(self, value, *labels):
if not self.exists(*labels):
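
Hypothetical usage of the test-only InMemoryCache with a limiter, assuming a
no-argument constructor and that write() stores the iterable's elements
(both plausible from this diff but not shown in it):

from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
from apache_beam.runners.interactive.testing.test_cache_manager import InMemoryCache

cache = InMemoryCache()
cache.write(['a', 'b', 'c'], 'full', 'my-label')
reader, _ = cache.read('full', 'my-label', limiters=[CountLimiter(2)])
assert list(reader) == ['a', 'b']  # The limiter drops the third element.

As a design note, the same limiter-wrapping generator now appears in
FileBasedCacheManager.read, StreamingCacheSource (inlined in its tail loop),
and InMemoryCache.read; a follow-up could hoist it into one shared helper.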