Merge pull request #12412 from [BEAM-10603] Add ElementLimiters to all Cache Managers.

[BEAM-10603] Add ElementLimiters to all Cache Managers.
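Note on the contract these changes build on: each cache manager treats a
limiter as a small stateful object that observes every element read from the
cache and signals when reading should stop. A minimal sketch of that
interface, assuming only the update/is_triggered methods exercised in this
diff (the concrete classes live in
apache_beam/runners/interactive/options/capture_limiters.py):

    class ElementLimiter(object):
      """Limits how much data is read from a cache."""
      def update(self, e):
        # Folds one element read from the cache into the limiter's state.
        raise NotImplementedError

      def is_triggered(self):
        # Returns True once the limit is reached and reading should stop.
        raise NotImplementedError

CountLimiter(n) triggers once more than n elements have been counted;
ProcessingTimeLimiter(duration) triggers once the recorded stream advances
past the given processing-time duration.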
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py
index 9894535..48f1fc5 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py
@@ -67,13 +67,17 @@
     """Returns the latest version number of the PCollection cache."""
     raise NotImplementedError
 
-  def read(self, *labels):
-    # type (*str) -> Tuple[str, Generator[Any]]
+  def read(self, *labels, **args):
+    # type: (*str, **Any) -> Tuple[str, Generator[Any]]
 
     """Return the PCollection as a list as well as the version number.
 
     Args:
       *labels: List of labels for PCollection instance.
+      **args: Dict of additional arguments. Currently only supports 'limiters'
+        as a list of ElementLimiters, and 'tail' as a boolean. Limiters bound
+        both the number of elements read and the duration of the read with
+        respect to processing time.
 
     Returns:
       A tuple containing an iterator for the items in the PCollection and the
@@ -97,6 +101,17 @@
     """
     raise NotImplementedError
 
+  def clear(self, *labels):
+    # type: (*str) -> bool
+
+    """Clears the cache entry of the given labels and returns True on success.
+
+    Args:
+      value: An encodable (with corresponding PCoder) value
+      *labels: List of labels for PCollection instance
+    """
+    raise NotImplementedError
+
   def source(self, *labels):
     # type (*str) -> ptransform.PTransform
 
@@ -196,17 +211,35 @@
         self._default_pcoder if self._default_pcoder is not None else
         self._saved_pcoders[self._path(*labels)])
 
-  def read(self, *labels):
+  def read(self, *labels, **args):
     # Return an iterator to an empty list if it doesn't exist.
     if not self.exists(*labels):
       return iter([]), -1
 
+    limiters = args.pop('limiters', [])
+
     # Otherwise, return a generator to the cached PCollection.
     source = self.source(*labels)._source
     range_tracker = source.get_range_tracker(None, None)
     reader = source.read(range_tracker)
     version = self._latest_version(*labels)
-    return reader, version
+
+    # The return type is a generator, so to implement the limiters for the
+    # FileBasedCacheManager we wrap the original generator with logic that
+    # stops yielding elements once any limiter triggers.
+    def limit_reader(r):
+      for e in r:
+        # Update the limiters and break early out of reading from cache if any
+        # are triggered.
+        for l in limiters:
+          l.update(e)
+
+        if any(l.is_triggered() for l in limiters):
+          break
+
+        yield e
+
+    return limit_reader(reader), version
 
   def write(self, values, *labels):
     sink = self.sink(labels)._sink
@@ -218,6 +251,12 @@
       writer.write(v)
     writer.close()
 
+  def clear(self, *labels):
+    if self.exists(*labels):
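+      # Delete every file matched by the labels' glob path.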
+      filesystems.FileSystems.delete(self._match(*labels))
+      return True
+    return False
+
   def source(self, *labels):
     return self._reader_class(
         self._glob_path(*labels), coder=self.load_pcoder(*labels))
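
Caller-side, the new surface of FileBasedCacheManager looks roughly like this
(the manager instance and labels are hypothetical; only the read/clear
signatures come from the diff above):

    from apache_beam.runners.interactive.options.capture_limiters import CountLimiter

    # Read at most 100 elements from the cached PCollection.
    reader, version = cache_manager.read(
        'full', 'my-cache-label', limiters=[CountLimiter(100)])
    elements = list(reader)

    # Drop the cache entry; returns True when a matching entry existed and
    # was deleted.
    cache_manager.clear('full', 'my-cache-label')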
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
index 7868e90..e7dc936 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
@@ -30,6 +30,7 @@
 from apache_beam import coders
 from apache_beam.io import filesystems
 from apache_beam.runners.interactive import cache_manager as cache
+from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
 
 
 class FileBasedCacheManagerTest(object):
@@ -91,6 +92,18 @@
     self.mock_write_cache(cache_version_one, prefix, cache_label)
     self.assertTrue(self.cache_manager.exists(prefix, cache_label))
 
+  def test_clear(self):
+    """Test that CacheManager can correctly tell if the cache exists or not."""
+    prefix = 'full'
+    cache_label = 'some-cache-label'
+    cache_version_one = ['cache', 'version', 'one']
+
+    self.assertFalse(self.cache_manager.exists(prefix, cache_label))
+    self.mock_write_cache(cache_version_one, prefix, cache_label)
+    self.assertTrue(self.cache_manager.exists(prefix, cache_label))
+    self.assertTrue(self.cache_manager.clear(prefix, cache_label))
+    self.assertFalse(self.cache_manager.exists(prefix, cache_label))
+
   def test_read_basic(self):
     """Test the condition where the cache is read once after written once."""
     prefix = 'full'
@@ -180,6 +193,21 @@
     self.assertTrue(
         self.cache_manager.is_latest_version(version, prefix, cache_label))
 
+  def test_read_with_count_limiter(self):
+    """Test the condition where the cache is read once after written once."""
+    prefix = 'full'
+    cache_label = 'some-cache-label'
+    cache_version_one = ['cache', 'version', 'one']
+
+    self.mock_write_cache(cache_version_one, prefix, cache_label)
+    reader, version = self.cache_manager.read(
+        prefix, cache_label, limiters=[CountLimiter(2)])
+    pcoll_list = list(reader)
+    self.assertListEqual(pcoll_list, ['cache', 'version'])
+    self.assertEqual(version, 0)
+    self.assertTrue(
+        self.cache_manager.is_latest_version(version, prefix, cache_label))
+
 
 class TextFileBasedCacheManagerTest(
     FileBasedCacheManagerTest,
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index b2204cf..77f976d 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -153,13 +153,23 @@
       cache_dir,
       labels,
       is_cache_complete=None,
-      coder=SafeFastPrimitivesCoder()):
+      coder=None,
+      limiters=None):
+    if not coder:
+      coder = SafeFastPrimitivesCoder()
+
+    if not is_cache_complete:
+      is_cache_complete = lambda _: True
+
+    if not limiters:
+      limiters = []
+
     self._cache_dir = cache_dir
     self._coder = coder
     self._labels = labels
     self._path = os.path.join(self._cache_dir, *self._labels)
-    self._is_cache_complete = (
-        is_cache_complete if is_cache_complete else lambda _: True)
+    self._is_cache_complete = is_cache_complete
+    self._limiters = limiters
 
     from apache_beam.runners.interactive.pipeline_instrument import CacheKey
     self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id
@@ -193,7 +203,8 @@
 
       # Check if we are at EOF or if we have an incomplete line.
       if not line or (line and line[-1] != b'\n'[0]):
-        if not tail:
+        # Read at least the first line to get the header.
+        if not tail and pos != 0:
           break
 
         # Complete reading only when the cache is complete.
@@ -210,10 +221,16 @@
         proto_cls = TestStreamFileHeader if pos == 0 else TestStreamFileRecord
         msg = self._try_parse_as(proto_cls, to_decode)
         if msg:
-          yield msg
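+          # Update the limiters and stop reading from the cache early once
+          # any of them triggers.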
+          for l in self._limiters:
+            l.update(msg)
+
+          if any(l.is_triggered() for l in self._limiters):
+            break
         else:
           break
 
+        yield msg
+
   def _try_parse_as(self, proto_cls, to_decode):
     try:
       msg = proto_cls()
@@ -285,7 +302,7 @@
     return os.path.exists(path)
 
   # TODO(srohde): Modify this to return the correct version.
-  def read(self, *labels):
+  def read(self, *labels, **args):
     """Returns a generator to read all records from file.
 
     Does not tail.
@@ -293,8 +310,12 @@
     if not self.exists(*labels):
       return iter([]), -1
 
+    limiters = args.pop('limiters', [])
+    tail = args.pop('tail', False)
+
     reader = StreamingCacheSource(
-        self._cache_dir, labels, self._is_cache_complete).read(tail=False)
+        self._cache_dir, labels, self._is_cache_complete,
+        limiters=limiters).read(tail=tail)
 
     # Return an empty iterator if there is nothing in the file yet. This can
     # only happen when tail is False.
@@ -304,7 +325,7 @@
       return iter([]), -1
     return StreamingCache.Reader([header], [reader]).read(), 1
 
-  def read_multiple(self, labels):
+  def read_multiple(self, labels, limiters=None, tail=True):
     """Returns a generator to read all records from file.
 
     Does tail until the cache is complete. This is because it is used in the
@@ -312,9 +333,9 @@
     pipeline runtime which needs to block.
     """
     readers = [
-        StreamingCacheSource(self._cache_dir, l,
-                             self._is_cache_complete).read(tail=True)
-        for l in labels
+        StreamingCacheSource(
+            self._cache_dir, l, self._is_cache_complete,
+            limiters=limiters).read(tail=tail) for l in labels
     ]
     headers = [next(r) for r in readers]
     return StreamingCache.Reader(headers, readers).read()
@@ -334,6 +355,14 @@
           val = v
         f.write(self._default_pcoder.encode(val) + b'\n')
 
+  def clear(self, *labels):
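+    # The last label is the file name; any preceding labels are treated as
+    # subdirectories under the cache directory.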
+    directory = os.path.join(self._cache_dir, *labels[:-1])
+    filepath = os.path.join(directory, labels[-1])
+    if os.path.exists(filepath):
+      os.remove(filepath)
+      return True
+    return False
+
   def source(self, *labels):
     """Returns the StreamingCacheManager source.
 
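The streaming cache gains the same knobs in two places: read() now accepts
'limiters' and 'tail' through **args (tail defaults to False, where it was
previously hard-coded), while read_multiple() takes them as keyword
parameters and keeps tailing by default. A hedged sketch of reading a single
cached key with a processing-time bound (the cache instance and key are
hypothetical):

    from apache_beam.runners.interactive.options.capture_limiters import ProcessingTimeLimiter

    # Stop once the recorded stream has advanced two seconds of processing
    # time; with tail left as False this returns as soon as the bound is hit
    # or the file is exhausted.
    reader, version = streaming_cache.read(
        cache_key, limiters=[ProcessingTimeLimiter(2)])
    events = list(reader)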
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
index a56b851..23390cc 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
@@ -28,6 +28,8 @@
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
 from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
 from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
+from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
+from apache_beam.runners.interactive.options.capture_limiters import ProcessingTimeLimiter
 from apache_beam.runners.interactive.pipeline_instrument import CacheKey
 from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -64,6 +66,14 @@
     # Assert that an empty reader returns an empty list.
     self.assertFalse([e for e in reader])
 
+  def test_clear(self):
+    cache = StreamingCache(cache_dir=None)
+    self.assertFalse(cache.exists('my_label'))
+    cache.write([TestStreamFileRecord()], 'my_label')
+    self.assertTrue(cache.exists('my_label'))
+    self.assertTrue(cache.clear('my_label'))
+    self.assertFalse(cache.exists('my_label'))
+
   def test_single_reader(self):
     """Tests that we expect to see all the correctly emitted TestStreamPayloads.
     """
@@ -403,6 +413,106 @@
 
     self.assertListEqual(actual_events, expected_events)
 
+  def test_single_reader_with_count_limiter(self):
+    """Tests that we expect to see all the correctly emitted TestStreamPayloads.
+    """
+    CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
+
+    values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
+              .add_element(element=0, event_time_secs=0)
+              .advance_processing_time(1)
+              .add_element(element=1, event_time_secs=1)
+              .advance_processing_time(1)
+              .add_element(element=2, event_time_secs=2)
+              .build()) # yapf: disable
+
+    cache = StreamingCache(cache_dir=None)
+    cache.write(values, CACHED_PCOLLECTION_KEY)
+
+    reader, _ = cache.read(CACHED_PCOLLECTION_KEY, limiters=[CountLimiter(2)])
+    coder = coders.FastPrimitivesCoder()
+    events = list(reader)
+
+    # Units here are in microseconds.
+    # These are a slice of the original values such that we only get two
+    # elements. The trailing AdvanceProcessingTime is still emitted because
+    # CountLimiter only counts element records, so the limiter does not
+    # trigger until the third element is read.
+    expected = [
+        TestStreamPayload.Event(
+            element_event=TestStreamPayload.Event.AddElements(
+                elements=[
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode(0), timestamp=0)
+                ],
+                tag=CACHED_PCOLLECTION_KEY)),
+        TestStreamPayload.Event(
+            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=1 * 10**6)),
+        TestStreamPayload.Event(
+            element_event=TestStreamPayload.Event.AddElements(
+                elements=[
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode(1), timestamp=1 * 10**6)
+                ],
+                tag=CACHED_PCOLLECTION_KEY)),
+        TestStreamPayload.Event(
+            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=1 * 10**6)),
+    ]
+    self.assertSequenceEqual(events, expected)
+
+  def test_single_reader_with_processing_time_limiter(self):
+    """Tests that we expect to see all the correctly emitted TestStreamPayloads.
+    """
+    CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
+
+    values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
+              .advance_processing_time(1e-6)
+              .add_element(element=0, event_time_secs=0)
+              .advance_processing_time(1)
+              .add_element(element=1, event_time_secs=1)
+              .advance_processing_time(1)
+              .add_element(element=2, event_time_secs=2)
+              .advance_processing_time(1)
+              .add_element(element=3, event_time_secs=2)
+              .advance_processing_time(1)
+              .add_element(element=4, event_time_secs=2)
+              .build()) # yapf: disable
+
+    cache = StreamingCache(cache_dir=None)
+    cache.write(values, CACHED_PCOLLECTION_KEY)
+
+    reader, _ = cache.read(
+        CACHED_PCOLLECTION_KEY, limiters=[ProcessingTimeLimiter(2)])
+    coder = coders.FastPrimitivesCoder()
+    events = list(reader)
+
+    # Units here are in microseconds.
+    # Expects the elements to be the slice of the original values whose
+    # cumulative processing-time advance stays below the limiter's duration.
+    expected = [
+        TestStreamPayload.Event(
+            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=1)),
+        TestStreamPayload.Event(
+            element_event=TestStreamPayload.Event.AddElements(
+                elements=[
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode(0), timestamp=0)
+                ],
+                tag=CACHED_PCOLLECTION_KEY)),
+        TestStreamPayload.Event(
+            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=1 * 10**6)),
+        TestStreamPayload.Event(
+            element_event=TestStreamPayload.Event.AddElements(
+                elements=[
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode(1), timestamp=1 * 10**6)
+                ],
+                tag=CACHED_PCOLLECTION_KEY)),
+    ]
+    self.assertSequenceEqual(events, expected)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
index 2c84f80..c9888ab 100644
--- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
+++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
@@ -109,12 +109,14 @@
       self._count += 1
 
   def is_triggered(self):
-    return self._count >= self._max_count
+    return self._count > self._max_count
 
 
 class ProcessingTimeLimiter(ElementLimiter):
   """Limits by how long the ProcessingTime passed in the element stream.
 
+  Reads all elements from the timespan [start, start + duration).
+
   This measures the duration from the first element in the stream. Each
   subsequent element has a delta "advance_duration" that moves the internal
   clock forward. This triggers when the duration from the internal clock and
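
Changing the comparison from >= to > makes CountLimiter(n) inclusive: reading
exactly n elements no longer trips the limiter, and only the (n+1)-th counted
element does. Traced with the data used in cache_manager_test.py above:

    limiter = CountLimiter(2)
    for e in ['cache', 'version']:
      limiter.update(e)
    limiter.is_triggered()  # False: 2 > 2 does not hold.

    limiter.update('one')
    limiter.is_triggered()  # True: 3 > 2, so the reader stops after two
                            # elements.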
diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
index 850c56e2c..347cb8e 100644
--- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
+++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
@@ -28,7 +28,7 @@
   def test_count_limiter(self):
     limiter = CountLimiter(5)
 
-    for e in range(4):
+    for e in range(5):
       limiter.update(e)
 
     self.assertFalse(limiter.is_triggered())
diff --git a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
index f39f016..098f249 100644
--- a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
@@ -45,11 +45,23 @@
   def _latest_version(self, *labels):
     return True
 
-  def read(self, *labels):
+  def read(self, *labels, **args):
     if not self.exists(*labels):
       return itertools.chain([]), -1
-    ret = itertools.chain(self._cached[self._key(*labels)])
-    return ret, None
+
+    limiters = args.pop('limiters', [])
+
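+    # Wrap the in-memory elements so the limiters can cut the read short,
+    # mirroring FileBasedCacheManager.read.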
+    def limit_reader(r):
+      for e in r:
+        for l in limiters:
+          l.update(e)
+
+        if any(l.is_triggered() for l in limiters):
+          break
+
+        yield e
+
+    return limit_reader(itertools.chain(self._cached[self._key(*labels)])), None
 
   def write(self, value, *labels):
     if not self.exists(*labels):