Merge pull request #10090 Update container image tags used by Dataflow runner for Beam master
diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index ed2f013..0ddc48e 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -42,6 +42,7 @@
import "endpoints.proto";
import "google/protobuf/descriptor.proto";
import "google/protobuf/timestamp.proto";
+import "google/protobuf/duration.proto";
import "google/protobuf/wrappers.proto";
import "metrics.proto";
@@ -203,13 +204,21 @@
}
// An Application should be scheduled for execution after a delay.
+// Either an absolute timestamp or a relative timestamp can represent a
+// scheduled execution time.
message DelayedBundleApplication {
// Recommended time at which the application should be scheduled to execute
// by the runner. Times in the past may be scheduled to execute immediately.
+ // TODO(BEAM-8536): Migrate usage of absolute time to requested_time_delay.
google.protobuf.Timestamp requested_execution_time = 1;
// (Required) The application that should be scheduled.
BundleApplication application = 2;
+
+ // Recommended time delay at which the application should be scheduled to
+ // execute by the runner. Time delay that equals 0 may be scheduled to execute
+ // immediately. The unit of time delay should be microsecond.
+ google.protobuf.Duration requested_time_delay = 3;
}
// A request to process a given bundle.
diff --git a/runners/spark/job-server/build.gradle b/runners/spark/job-server/build.gradle
index 2b27a88..6fb7581 100644
--- a/runners/spark/job-server/build.gradle
+++ b/runners/spark/job-server/build.gradle
@@ -114,6 +114,7 @@
excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs'
excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
+ excludeCategories 'org.apache.beam.sdk.testing.UsesStrictTimerOrdering'
},
)
}
diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptionsTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptionsTest.java
index ab03de0..0f0b5b3 100644
--- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptionsTest.java
+++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptionsTest.java
@@ -159,21 +159,6 @@
options.getGcpTempLocation();
}
- @Test
- public void testDefaultGcpTempLocationDoesNotExist() {
- GcpOptions options = PipelineOptionsFactory.as(GcpOptions.class);
- String tempLocation = "gs://does/not/exist";
- options.setTempLocation(tempLocation);
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage(
- "Error constructing default value for gcpTempLocation: tempLocation is not"
- + " a valid GCS path");
- thrown.expectCause(
- hasMessage(containsString("Output path does not exist or is not writeable")));
-
- options.getGcpTempLocation();
- }
-
private static void makePropertiesFileWithProject(File path, String projectId)
throws IOException {
String properties =
@@ -221,6 +206,21 @@
}
@Test
+ public void testDefaultGcpTempLocationDoesNotExist() throws IOException {
+ String tempLocation = "gs://does/not/exist";
+ options.setTempLocation(tempLocation);
+ when(mockGcsUtil.bucketAccessible(any(GcsPath.class))).thenReturn(false);
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage(
+ "Error constructing default value for gcpTempLocation: tempLocation is not"
+ + " a valid GCS path");
+ thrown.expectCause(
+ hasMessage(containsString("Output path does not exist or is not writeable")));
+
+ options.as(GcpOptions.class).getGcpTempLocation();
+ }
+
+ @Test
public void testCreateBucket() throws Exception {
doReturn(fakeProject).when(mockGet).execute();
when(mockGcsUtil.bucketOwner(any(GcsPath.class))).thenReturn(1L);
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 5b66730..e21052f 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -35,6 +35,7 @@
import logging
import math
import random
+import threading
import uuid
from builtins import object
from builtins import range
@@ -1104,13 +1105,17 @@
class RestrictionTracker(object):
"""Manages concurrent access to a restriction.
- Experimental; no backwards-compatibility guarantees.
-
Keeps track of the restrictions claimed part for a Splittable DoFn.
+ The restriction may be modified by different threads, however the system will
+ ensure sufficient locking such that no methods on the restriction tracker
+ will be called concurrently.
+
See following documents for more details.
* https://s.apache.org/splittable-do-fn
* https://s.apache.org/splittable-do-fn-python-sdk
+
+ Experimental; no backwards-compatibility guarantees.
"""
def current_restriction(self):
@@ -1121,54 +1126,22 @@
The current restriction returned by method may be updated dynamically due
to due to concurrent invocation of other methods of the
- ``RestrictionTracker``, For example, ``checkpoint()``.
+ ``RestrictionTracker``, for example, ``try_split()``.
- ** Thread safety **
+ This API is required to be implemented.
- Methods of the class ``RestrictionTracker`` including this method may get
- invoked by different threads, hence must be made thread-safe, e.g. by using
- a single lock object.
-
- TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+ Returns: a restriction object.
"""
raise NotImplementedError
def current_progress(self):
"""Returns a RestrictionProgress object representing the current progress.
+
+ This API is recommended to be implemented. The runner can do a better job
+ at parallel processing with better progress signals.
"""
raise NotImplementedError
- def current_watermark(self):
- """Returns current watermark. By default, not report watermark.
-
- TODO(BEAM-7473): Provide synchronization guarantee by using a wrapper.
- """
- return None
-
- def checkpoint(self):
- """Performs a checkpoint of the current restriction.
-
- Signals that the current ``DoFn.process()`` call should terminate as soon as
- possible. After this method returns, the tracker MUST refuse all future
- claim calls, and ``RestrictionTracker.check_done()`` MUST succeed.
-
- This invocation modifies the value returned by ``current_restriction()``
- invocation and returns a restriction representing the rest of the work. The
- old value of ``current_restriction()`` is equivalent to the new value of
- ``current_restriction()`` and the return value of this method invocation
- combined.
-
- ** Thread safety **
-
- Methods of the class ``RestrictionTracker`` including this method may get
- invoked by different threads, hence must be made thread-safe, e.g. by using
- a single lock object.
-
- TODO(BEAM-7473): Remove thread safety requirements from API implementation.
- """
-
- raise NotImplementedError
-
def check_done(self):
"""Checks whether the restriction has been fully processed.
@@ -1179,13 +1152,8 @@
remaining in the restriction when this method is invoked. Exception raised
must have an informative error message.
- ** Thread safety **
-
- Methods of the class ``RestrictionTracker`` including this method may get
- invoked by different threads, hence must be made thread-safe, e.g. by using
- a single lock object.
-
- TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+ This API is required to be implemented in order to make sure there is no
+ data loss during SDK processing.
Returns: ``True`` if current restriction has been fully processed.
Raises:
@@ -1215,8 +1183,12 @@
restrictions returned would be [100, 179), [179, 200) (note: current_offset
+ fraction_of_remainder * remaining_work = 130 + 0.7 * 70 = 179).
- It is very important for pipeline scaling and end to end pipeline execution
- that try_split is implemented well.
+ ``fraction_of_remainder`` = 0 means a checkpoint is required.
+
+ The API is recommended to be implemented for batch pipeline given that it is
+ very important for pipeline scaling and end to end pipeline execution.
+
+ The API is required to be implemented for a streaming pipeline.
Args:
fraction_of_remainder: A hint as to the fraction of work the primary
@@ -1226,19 +1198,11 @@
Returns:
(primary_restriction, residual_restriction) if a split was possible,
otherwise returns ``None``.
-
- ** Thread safety **
-
- Methods of the class ``RestrictionTracker`` including this method may get
- invoked by different threads, hence must be made thread-safe, e.g. by using
- a single lock object.
-
- TODO(BEAM-7473): Remove thread safety requirements from API implementation.
"""
raise NotImplementedError
def try_claim(self, position):
- """ Attempts to claim the block of work in the current restriction
+ """Attempts to claim the block of work in the current restriction
identified by the given position.
If this succeeds, the DoFn MUST execute the entire block of work. If it
@@ -1247,40 +1211,137 @@
work from ``DoFn.process()`` is also not allowed before the first call of
this method).
+ The API is required to be implemented.
+
Args:
position: current position that wants to be claimed.
Returns: ``True`` if the position can be claimed as current_position.
Otherwise, returns ``False``.
-
- ** Thread safety **
-
- Methods of the class ``RestrictionTracker`` including this method may get
- invoked by different threads, hence must be made thread-safe, e.g. by using
- a single lock object.
-
- TODO(BEAM-7473): Remove thread safety requirements from API implementation.
"""
raise NotImplementedError
- def defer_remainder(self, watermark=None):
- """ Invokes checkpoint() in an SDF.process().
- TODO(BEAM-7472): Remove defer_remainder() once SDF.process() uses
- ``ProcessContinuation``.
+class ThreadsafeRestrictionTracker(object):
+ """A thread-safe wrapper which wraps a `RestrictionTracker`.
+
+ This wrapper guarantees synchronization of modifying restrictions across
+ multiple threads.
+ """
+
+ def __init__(self, restriction_tracker):
+ if not isinstance(restriction_tracker, RestrictionTracker):
+ raise ValueError(
+ 'Initialize ThreadsafeRestrictionTracker requires'
+ 'RestrictionTracker.')
+ self._restriction_tracker = restriction_tracker
+ # Records an absolute timestamp when defer_remainder is called.
+ self._deferred_timestamp = None
+ self._lock = threading.RLock()
+ self._deferred_residual = None
+ self._deferred_watermark = None
+
+ def current_restriction(self):
+ with self._lock:
+ return self._restriction_tracker.current_restriction()
+
+ def try_claim(self, position):
+ with self._lock:
+ return self._restriction_tracker.try_claim(position)
+
+ def defer_remainder(self, deferred_time=None):
+ """Performs self-checkpoint on current processing restriction with an
+ expected resuming time.
+
+ Self-checkpoint could happen while processing elements. When executing a
+ DoFn.process(), you may want to stop processing an element and resume
+ later if the current element has been processed for quite a long time or
+ you also want to have some outputs from other elements.
+ ``defer_remainder()`` can be called on a per-element basis if needed.
Args:
- watermark
+ deferred_time: A relative ``timestamp.Duration`` that indicates the ideal
+ time gap between now and resuming, or an absolute ``timestamp.Timestamp``
+ for resuming execution time. If deferred_time is None, the deferred work
+ will be executed as soon as possible.
"""
- raise NotImplementedError
+
+ # Record current time for calculating deferred_time later.
+ self._deferred_timestamp = timestamp.Timestamp.now()
+ if (deferred_time and
+ not isinstance(deferred_time, timestamp.Duration) and
+ not isinstance(deferred_time, timestamp.Timestamp)):
+ raise ValueError('The timestamp of deter_remainder() should be a '
+ 'Duration or a Timestamp, or None.')
+ self._deferred_watermark = deferred_time
+ checkpoint = self.try_split(0)
+ if checkpoint:
+ _, self._deferred_residual = checkpoint
+
+ def check_done(self):
+ with self._lock:
+ return self._restriction_tracker.check_done()
+
+ def current_progress(self):
+ with self._lock:
+ return self._restriction_tracker.current_progress()
+
+ def try_split(self, fraction_of_remainder):
+ with self._lock:
+ return self._restriction_tracker.try_split(fraction_of_remainder)
def deferred_status(self):
- """ Returns deferred_residual with deferred_watermark.
+ """Returns deferred work which is produced by ``defer_remainder()``.
- TODO(BEAM-7472): Remove defer_status() once SDF.process() uses
- ``ProcessContinuation``.
+ When there is a self-checkpoint performed, the system needs to fulfill the
+ DelayedBundleApplication with deferred_work for a ProcessBundleResponse.
+ The system calls this API to get deferred_residual with watermark together
+ to help the runner to schedule a future work.
+
+ Returns: (deferred_residual, time_delay) if having any residual, else None.
"""
- raise NotImplementedError
+ if self._deferred_residual:
+ # If _deferred_watermark is None, create Duration(0).
+ if not self._deferred_watermark:
+ self._deferred_watermark = timestamp.Duration()
+ # If an absolute timestamp is provided, calculate the delta between
+ # the absolute time and the time deferred_status() is called.
+ elif isinstance(self._deferred_watermark, timestamp.Timestamp):
+ self._deferred_watermark = (self._deferred_watermark -
+ timestamp.Timestamp.now())
+ # If a Duration is provided, the deferred time should be:
+ # provided duration - the spent time since the defer_remainder() is
+ # called.
+ elif isinstance(self._deferred_watermark, timestamp.Duration):
+ self._deferred_watermark -= (timestamp.Timestamp.now() -
+ self._deferred_timestamp)
+ return self._deferred_residual, self._deferred_watermark
+
+
+class RestrictionTrackerView(object):
+ """A DoFn view of thread-safe RestrictionTracker.
+
+ The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
+ exposes APIs that will be called by a ``DoFn.process()``. During execution
+ time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a
+ restriction_tracker.
+ """
+
+ def __init__(self, threadsafe_restriction_tracker):
+ if not isinstance(threadsafe_restriction_tracker,
+ ThreadsafeRestrictionTracker):
+ raise ValueError('Initialize RestrictionTrackerView requires '
+ 'ThreadsafeRestrictionTracker.')
+ self._threadsafe_restriction_tracker = threadsafe_restriction_tracker
+
+ def current_restriction(self):
+ return self._threadsafe_restriction_tracker.current_restriction()
+
+ def try_claim(self, position):
+ return self._threadsafe_restriction_tracker.try_claim(position)
+
+ def defer_remainder(self, deferred_time=None):
+ self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
class RestrictionProgress(object):
@@ -1400,17 +1461,8 @@
SourceBundle(residual_weight, self._source, split_pos,
stop_pos))
- def deferred_status(self):
- return None
-
- def current_watermark(self):
- return None
-
- def get_delegate_range_tracker(self):
- return self._delegate_range_tracker
-
- def get_tracking_source(self):
- return self._source
+ def check_done(self):
+ return self._delegate_range_tracker.fraction_consumed() >= 1.0
class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
"""A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
@@ -1463,8 +1515,13 @@
restriction_tracker=core.DoFn.RestrictionParam(
_SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionProvider(
source, chunk_size))):
- return restriction_tracker.get_tracking_source().read(
- restriction_tracker.get_delegate_range_tracker())
+ current_restriction = restriction_tracker.current_restriction()
+ assert isinstance(current_restriction, SourceBundle)
+ tracking_source = current_restriction.source
+ start = current_restriction.start_position
+ stop = current_restriction.stop_position
+ return tracking_source.read(tracking_source.get_range_tracker(start,
+ stop))
return SDFBoundedSourceDoFn(self.source)
diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py
index 7adb764..0a6afae 100644
--- a/sdks/python/apache_beam/io/iobase_test.py
+++ b/sdks/python/apache_beam/io/iobase_test.py
@@ -19,6 +19,7 @@
from __future__ import absolute_import
+import time
import unittest
import mock
@@ -28,6 +29,9 @@
from apache_beam.io.concat_source_test import RangeSource
from apache_beam.io import iobase
from apache_beam.io.iobase import SourceBundle
+from apache_beam.io.restriction_trackers import OffsetRange
+from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
+from apache_beam.utils import timestamp
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
@@ -191,5 +195,87 @@
self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3])
+class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
+
+ def test_initialization(self):
+ with self.assertRaises(ValueError):
+ iobase.ThreadsafeRestrictionTracker(RangeSource(0, 1))
+
+ def test_defer_remainder_with_wrong_time_type(self):
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ with self.assertRaises(ValueError):
+ threadsafe_tracker.defer_remainder(10)
+
+ def test_self_checkpoint_immediately(self):
+ restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ restriction_tracker)
+ threadsafe_tracker.defer_remainder()
+ deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
+ expected_residual = OffsetRange(0, 10)
+ self.assertEqual(deferred_residual, expected_residual)
+ self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+ self.assertEqual(deferred_time, 0)
+
+ def test_self_checkpoint_with_relative_time(self):
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ threadsafe_tracker.defer_remainder(timestamp.Duration(100))
+ time.sleep(2)
+ _, deferred_time = threadsafe_tracker.deferred_status()
+ self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+ # The expectation = 100 - 2 - some_delta
+ self.assertTrue(deferred_time <= 98)
+
+ def test_self_checkpoint_with_absolute_time(self):
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ now = timestamp.Timestamp.now()
+ schedule_time = now + timestamp.Duration(100)
+ self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
+ threadsafe_tracker.defer_remainder(schedule_time)
+ time.sleep(2)
+ _, deferred_time = threadsafe_tracker.deferred_status()
+ self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+ # The expectation =
+ # schedule_time - the time when deferred_status is called - some_delta
+ self.assertTrue(deferred_time <= 98)
+
+
+class RestrictionTrackerViewTest(unittest.TestCase):
+
+ def test_initialization(self):
+ with self.assertRaises(ValueError):
+ iobase.RestrictionTrackerView(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+
+ def test_api_expose(self):
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
+ current_restriction = tracker_view.current_restriction()
+ self.assertEqual(current_restriction, OffsetRange(0, 10))
+ self.assertTrue(tracker_view.try_claim(0))
+ tracker_view.defer_remainder()
+ deferred_remainder, deferred_watermark = (
+ threadsafe_tracker.deferred_status())
+ self.assertEqual(deferred_remainder, OffsetRange(1, 10))
+ self.assertEqual(deferred_watermark, timestamp.Duration())
+
+ def test_non_expose_apis(self):
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
+ with self.assertRaises(AttributeError):
+ tracker_view.check_done()
+ with self.assertRaises(AttributeError):
+ tracker_view.current_progress()
+ with self.assertRaises(AttributeError):
+ tracker_view.try_split()
+ with self.assertRaises(AttributeError):
+ tracker_view.deferred_status()
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
index 0ba5b23..20bb5c1 100644
--- a/sdks/python/apache_beam/io/restriction_trackers.py
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -19,7 +19,6 @@
from __future__ import absolute_import
from __future__ import division
-import threading
from builtins import object
from apache_beam.io.iobase import RestrictionProgress
@@ -86,104 +85,69 @@
assert isinstance(offset_range, OffsetRange)
self._range = offset_range
self._current_position = None
- self._current_watermark = None
self._last_claim_attempt = None
- self._deferred_residual = None
self._checkpointed = False
- self._lock = threading.RLock()
def check_done(self):
- with self._lock:
- if self._last_claim_attempt < self._range.stop - 1:
- raise ValueError(
- 'OffsetRestrictionTracker is not done since work in range [%s, %s) '
- 'has not been claimed.'
- % (self._last_claim_attempt if self._last_claim_attempt is not None
- else self._range.start,
- self._range.stop))
+ if self._last_claim_attempt < self._range.stop - 1:
+ raise ValueError(
+ 'OffsetRestrictionTracker is not done since work in range [%s, %s) '
+ 'has not been claimed.'
+ % (self._last_claim_attempt if self._last_claim_attempt is not None
+ else self._range.start,
+ self._range.stop))
def current_restriction(self):
- with self._lock:
- return self._range
-
- def current_watermark(self):
- return self._current_watermark
+ return self._range
def current_progress(self):
- with self._lock:
- if self._current_position is None:
- fraction = 0.0
- elif self._range.stop == self._range.start:
- # If self._current_position is not None, we must be done.
- fraction = 1.0
- else:
- fraction = (
- float(self._current_position - self._range.start)
- / (self._range.stop - self._range.start))
+ if self._current_position is None:
+ fraction = 0.0
+ elif self._range.stop == self._range.start:
+ # If self._current_position is not None, we must be done.
+ fraction = 1.0
+ else:
+ fraction = (
+ float(self._current_position - self._range.start)
+ / (self._range.stop - self._range.start))
return RestrictionProgress(fraction=fraction)
def start_position(self):
- with self._lock:
- return self._range.start
+ return self._range.start
def stop_position(self):
- with self._lock:
- return self._range.stop
-
- def default_size(self):
- return self._range.size()
+ return self._range.stop
def try_claim(self, position):
- with self._lock:
- if self._last_claim_attempt and position <= self._last_claim_attempt:
- raise ValueError(
- 'Positions claimed should strictly increase. Trying to claim '
- 'position %d while last claim attempt was %d.'
- % (position, self._last_claim_attempt))
+ if self._last_claim_attempt and position <= self._last_claim_attempt:
+ raise ValueError(
+ 'Positions claimed should strictly increase. Trying to claim '
+ 'position %d while last claim attempt was %d.'
+ % (position, self._last_claim_attempt))
- self._last_claim_attempt = position
- if position < self._range.start:
- raise ValueError(
- 'Position to be claimed cannot be smaller than the start position '
- 'of the range. Tried to claim position %r for the range [%r, %r)'
- % (position, self._range.start, self._range.stop))
+ self._last_claim_attempt = position
+ if position < self._range.start:
+ raise ValueError(
+ 'Position to be claimed cannot be smaller than the start position '
+ 'of the range. Tried to claim position %r for the range [%r, %r)'
+ % (position, self._range.start, self._range.stop))
- if position >= self._range.start and position < self._range.stop:
- self._current_position = position
- return True
+ if position >= self._range.start and position < self._range.stop:
+ self._current_position = position
+ return True
- return False
+ return False
def try_split(self, fraction_of_remainder):
- with self._lock:
- if not self._checkpointed:
- if self._current_position is None:
- cur = self._range.start - 1
- else:
- cur = self._current_position
- split_point = (
- cur + int(max(1, (self._range.stop - cur) * fraction_of_remainder)))
- if split_point < self._range.stop:
- self._range, residual_range = self._range.split_at(split_point)
- return self._range, residual_range
-
- # TODO(SDF): Replace all calls with try_claim(0).
- def checkpoint(self):
- with self._lock:
- # If self._current_position is 'None' no records have been claimed so
- # residual should start from self._range.start.
+ if not self._checkpointed:
if self._current_position is None:
- end_position = self._range.start
+ cur = self._range.start - 1
else:
- end_position = self._current_position + 1
- self._range, residual_range = self._range.split_at(end_position)
- return residual_range
-
- def defer_remainder(self, watermark=None):
- with self._lock:
- self._deferred_watermark = watermark or self._current_watermark
- self._deferred_residual = self.checkpoint()
-
- def deferred_status(self):
- if self._deferred_residual:
- return (self._deferred_residual, self._deferred_watermark)
+ cur = self._current_position
+ split_point = (
+ cur + int(max(1, (self._range.stop - cur) * fraction_of_remainder)))
+ if split_point < self._range.stop:
+ if fraction_of_remainder == 0:
+ self._checkpointed = True
+ self._range, residual_range = self._range.split_at(split_point)
+ return self._range, residual_range
diff --git a/sdks/python/apache_beam/io/restriction_trackers_test.py b/sdks/python/apache_beam/io/restriction_trackers_test.py
index 459b039..4a57d98 100644
--- a/sdks/python/apache_beam/io/restriction_trackers_test.py
+++ b/sdks/python/apache_beam/io/restriction_trackers_test.py
@@ -81,14 +81,14 @@
def test_checkpoint_unstarted(self):
tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
- checkpoint = tracker.checkpoint()
+ _, checkpoint = tracker.try_split(0)
self.assertEqual(OffsetRange(100, 100), tracker.current_restriction())
self.assertEqual(OffsetRange(100, 200), checkpoint)
def test_checkpoint_just_started(self):
tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
self.assertTrue(tracker.try_claim(100))
- checkpoint = tracker.checkpoint()
+ _, checkpoint = tracker.try_split(0)
self.assertEqual(OffsetRange(100, 101), tracker.current_restriction())
self.assertEqual(OffsetRange(101, 200), checkpoint)
@@ -96,7 +96,7 @@
tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
self.assertTrue(tracker.try_claim(105))
self.assertTrue(tracker.try_claim(110))
- checkpoint = tracker.checkpoint()
+ _, checkpoint = tracker.try_split(0)
self.assertEqual(OffsetRange(100, 111), tracker.current_restriction())
self.assertEqual(OffsetRange(111, 200), checkpoint)
@@ -105,9 +105,9 @@
self.assertTrue(tracker.try_claim(105))
self.assertTrue(tracker.try_claim(110))
self.assertTrue(tracker.try_claim(199))
- checkpoint = tracker.checkpoint()
+ checkpoint = tracker.try_split(0)
self.assertEqual(OffsetRange(100, 200), tracker.current_restriction())
- self.assertEqual(OffsetRange(200, 200), checkpoint)
+ self.assertEqual(None, checkpoint)
def test_checkpoint_after_failed_claim(self):
tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
@@ -116,7 +116,7 @@
self.assertTrue(tracker.try_claim(160))
self.assertFalse(tracker.try_claim(240))
- checkpoint = tracker.checkpoint()
+ _, checkpoint = tracker.try_split(0)
self.assertTrue(OffsetRange(100, 161), tracker.current_restriction())
self.assertTrue(OffsetRange(161, 200), checkpoint)
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index 2ffe432..37e05bf 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -42,6 +42,8 @@
cdef object key_arg_name
cdef object restriction_provider
cdef object restriction_provider_arg_name
+ cdef object watermark_estimator
+ cdef object watermark_estimator_arg_name
cdef class DoFnSignature(object):
@@ -91,7 +93,9 @@
cdef bint cache_globally_windowed_args
cdef object process_method
cdef bint is_splittable
- cdef object restriction_tracker
+ cdef object threadsafe_restriction_tracker
+ cdef object watermark_estimator
+ cdef object watermark_estimator_param
cdef WindowedValue current_windowed_value
cdef bint is_key_param_required
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 3e14f3b..8632cfd 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
# cython: profile=True
"""Worker operations executor.
@@ -167,6 +166,8 @@
self.key_arg_name = None
self.restriction_provider = None
self.restriction_provider_arg_name = None
+ self.watermark_estimator = None
+ self.watermark_estimator_arg_name = None
for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
if isinstance(v, core.DoFn.StateParam):
@@ -184,6 +185,9 @@
elif isinstance(v, core.DoFn.RestrictionParam):
self.restriction_provider = v.restriction_provider
self.restriction_provider_arg_name = kw
+ elif isinstance(v, core.DoFn.WatermarkEstimatorParam):
+ self.watermark_estimator = v.watermark_estimator
+ self.watermark_estimator_arg_name = kw
def invoke_timer_callback(self,
user_state_context,
@@ -264,6 +268,9 @@
def get_restriction_provider(self):
return self.process_method.restriction_provider
+ def get_watermark_estimator(self):
+ return self.process_method.watermark_estimator
+
def _validate(self):
self._validate_process()
self._validate_bundle_method(self.start_bundle_method)
@@ -458,7 +465,11 @@
signature.is_stateful_dofn())
self.user_state_context = user_state_context
self.is_splittable = signature.is_splittable_dofn()
- self.restriction_tracker = None
+ self.watermark_estimator = self.signature.get_watermark_estimator()
+ self.watermark_estimator_param = (
+ self.signature.process_method.watermark_estimator_arg_name
+ if self.watermark_estimator else None)
+ self.threadsafe_restriction_tracker = None
self.current_windowed_value = None
self.bundle_finalizer_param = bundle_finalizer_param
self.is_key_param_required = False
@@ -569,15 +580,24 @@
raise ValueError(
'A RestrictionTracker %r was provided but DoFn does not have a '
'RestrictionTrackerParam defined' % restriction_tracker)
- additional_kwargs[restriction_tracker_param] = restriction_tracker
+ from apache_beam.io import iobase
+ self.threadsafe_restriction_tracker = iobase.ThreadsafeRestrictionTracker(
+ restriction_tracker)
+ additional_kwargs[restriction_tracker_param] = (
+ iobase.RestrictionTrackerView(self.threadsafe_restriction_tracker))
+
+ if self.watermark_estimator:
+ # The watermark estimator needs to be reset for every element.
+ self.watermark_estimator.reset()
+ additional_kwargs[self.watermark_estimator_param] = (
+ self.watermark_estimator)
try:
self.current_windowed_value = windowed_value
- self.restriction_tracker = restriction_tracker
return self._invoke_process_per_window(
windowed_value, additional_args, additional_kwargs,
output_processor)
finally:
- self.restriction_tracker = None
+ self.threadsafe_restriction_tracker = None
self.current_windowed_value = windowed_value
elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
@@ -664,24 +684,34 @@
windowed_value, self.process_method(*args_for_process))
if self.is_splittable:
- deferred_status = self.restriction_tracker.deferred_status()
+ # TODO: Consider calling check_done right after SDF.Process() finishing.
+ # In order to do this, we need to know that current invoking dofn is
+ # ProcessSizedElementAndRestriction.
+ self.threadsafe_restriction_tracker.check_done()
+ deferred_status = self.threadsafe_restriction_tracker.deferred_status()
+ output_watermark = None
+ if self.watermark_estimator:
+ output_watermark = self.watermark_estimator.current_watermark()
if deferred_status:
deferred_restriction, deferred_watermark = deferred_status
element = windowed_value.value
size = self.signature.get_restriction_provider().restriction_size(
element, deferred_restriction)
- return (
+ return ((
windowed_value.with_value(((element, deferred_restriction), size)),
- deferred_watermark)
+ output_watermark), deferred_watermark)
def try_split(self, fraction):
- restriction_tracker = self.restriction_tracker
+ restriction_tracker = self.threadsafe_restriction_tracker
current_windowed_value = self.current_windowed_value
if restriction_tracker and current_windowed_value:
# Temporary workaround for [BEAM-7473]: get current_watermark before
# split, in case watermark gets advanced before getting split results.
# In worst case, current_watermark is always stale, which is ok.
- current_watermark = restriction_tracker.current_watermark()
+ if self.watermark_estimator:
+ current_watermark = self.watermark_estimator.current_watermark()
+ else:
+ current_watermark = None
split = restriction_tracker.try_split(fraction)
if split:
primary, residual = split
@@ -690,15 +720,13 @@
primary_size = restriction_provider.restriction_size(element, primary)
residual_size = restriction_provider.restriction_size(element, residual)
return (
- (self.current_windowed_value.with_value(
- ((element, primary), primary_size)),
- None),
- (self.current_windowed_value.with_value(
- ((element, residual), residual_size)),
- current_watermark))
+ ((self.current_windowed_value.with_value((
+ (element, primary), primary_size)), None), None),
+ ((self.current_windowed_value.with_value((
+ (element, residual), residual_size)), current_watermark), None))
def current_element_progress(self):
- restriction_tracker = self.restriction_tracker
+ restriction_tracker = self.threadsafe_restriction_tracker
if restriction_tracker:
return restriction_tracker.current_progress()
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
index 946ef34..fd04d4c 100644
--- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
@@ -51,6 +51,9 @@
def create_tracker(self, restriction):
return OffsetRestrictionTracker(restriction)
+ def restriction_size(self, element, restriction):
+ return restriction.size()
+
class ReadFiles(DoFn):
@@ -63,12 +66,11 @@
restriction_tracker=DoFn.RestrictionParam(ReadFilesProvider()),
*args, **kwargs):
file_name = element
- assert isinstance(restriction_tracker, OffsetRestrictionTracker)
with open(file_name, 'rb') as file:
- pos = restriction_tracker.start_position()
- if restriction_tracker.start_position() > 0:
- file.seek(restriction_tracker.start_position() - 1)
+ pos = restriction_tracker.current_restriction().start
+ if restriction_tracker.current_restriction().start > 0:
+ file.seek(restriction_tracker.current_restriction().start - 1)
line = file.readline()
pos = pos - 1 + len(line)
@@ -104,6 +106,9 @@
def split(self, element, restriction):
return [restriction,]
+ def restriction_size(self, element, restriction):
+ return restriction.size()
+
class ExpandStrings(DoFn):
@@ -118,10 +123,9 @@
side.extend(side1)
side.extend(side2)
side.extend(side3)
- assert isinstance(restriction_tracker, OffsetRestrictionTracker)
side = list(side)
- for i in range(restriction_tracker.start_position(),
- restriction_tracker.stop_position()):
+ for i in range(restriction_tracker.current_restriction().start,
+ restriction_tracker.current_restriction().stop):
if restriction_tracker.try_claim(i):
if not side:
yield (
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index e97d65e..377ceb7 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -319,6 +319,9 @@
line = f.readline()
self.assertSetEqual(lines_actual, lines_expected)
+ def test_sdf_with_watermark_tracking(self):
+ raise unittest.SkipTest("BEAM-2939")
+
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("BEAM-2939")
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 2204a24..b7929cb 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -41,6 +41,7 @@
from tenacity import stop_after_attempt
import apache_beam as beam
+from apache_beam.io import iobase
from apache_beam.io import restriction_trackers
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.execution import MetricKey
@@ -56,9 +57,11 @@
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.tools import utils
+from apache_beam.transforms import core
from apache_beam.transforms import environments
from apache_beam.transforms import userstate
from apache_beam.transforms import window
+from apache_beam.utils import timestamp
if statesampler.FAST_SAMPLER:
DEFAULT_SAMPLING_PERIOD_MS = statesampler.DEFAULT_SAMPLING_PERIOD_MS
@@ -423,17 +426,14 @@
assert_that(actual, is_buffered_correctly)
def test_sdf(self):
-
class ExpandingStringsDoFn(beam.DoFn):
def process(
self,
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider())):
- assert isinstance(
- restriction_tracker,
- restriction_trackers.OffsetRestrictionTracker), restriction_tracker
- cur = restriction_tracker.start_position()
+ assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+ cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
yield element[cur]
cur += 1
@@ -446,6 +446,56 @@
| beam.ParDo(ExpandingStringsDoFn()))
assert_that(actual, equal_to(list(''.join(data))))
+ def test_sdf_with_check_done_failed(self):
+ class ExpandingStringsDoFn(beam.DoFn):
+ def process(
+ self,
+ element,
+ restriction_tracker=beam.DoFn.RestrictionParam(
+ ExpandStringsProvider())):
+ assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+ cur = restriction_tracker.current_restriction().start
+ while restriction_tracker.try_claim(cur):
+ yield element[cur]
+ cur += 1
+ return
+ with self.assertRaises(Exception):
+ with self.create_pipeline() as p:
+ data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+ _ = (
+ p
+ | beam.Create(data)
+ | beam.ParDo(ExpandingStringsDoFn()))
+
+ def test_sdf_with_watermark_tracking(self):
+
+ class ExpandingStringsDoFn(beam.DoFn):
+ def process(
+ self,
+ element,
+ restriction_tracker=beam.DoFn.RestrictionParam(
+ ExpandStringsProvider()),
+ watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
+ core.WatermarkEstimator())):
+ cur = restriction_tracker.current_restriction().start
+ start = cur
+ while restriction_tracker.try_claim(cur):
+ watermark_estimator.set_watermark(timestamp.Timestamp(micros=cur))
+ assert watermark_estimator.current_watermark().micros == start
+ yield element[cur]
+ if cur % 2 == 1:
+ restriction_tracker.defer_remainder(timestamp.Duration(micros=5))
+ return
+ cur += 1
+
+ with self.create_pipeline() as p:
+ data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+ actual = (
+ p
+ | beam.Create(data)
+ | beam.ParDo(ExpandingStringsDoFn()))
+ assert_that(actual, equal_to(list(''.join(data))))
+
def test_sdf_with_sdf_initiated_checkpointing(self):
counter = beam.metrics.Metrics.counter('ns', 'my_counter')
@@ -456,10 +506,8 @@
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider())):
- assert isinstance(
- restriction_tracker,
- restriction_trackers.OffsetRestrictionTracker), restriction_tracker
- cur = restriction_tracker.start_position()
+ assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+ cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
counter.inc()
yield element[cur]
@@ -1123,6 +1171,9 @@
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
+ def test_sdf_with_watermark_tracking(self):
+ raise unittest.SkipTest("This test is for a single worker only.")
+
class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
@@ -1142,6 +1193,9 @@
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
+ def test_sdf_with_watermark_tracking(self):
+ raise unittest.SkipTest("This test is for a single worker only.")
+
class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
@@ -1172,6 +1226,9 @@
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
+ def test_sdf_with_watermark_tracking(self):
+ raise unittest.SkipTest("This test is for a single worker only.")
+
class FnApiRunnerSplitTest(unittest.TestCase):
@@ -1340,7 +1397,7 @@
element,
restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider())):
to_emit = []
- cur = restriction_tracker.start_position()
+ cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
to_emit.append((element, cur))
element_counter.increment()
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 8439c8f..b3440df 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -32,6 +32,7 @@
from builtins import object
from future.utils import itervalues
+from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
import apache_beam as beam
@@ -704,8 +705,7 @@
) = split
if element_primary:
split_response.primary_roots.add().CopyFrom(
- self.delayed_bundle_application(
- *element_primary).application)
+ self.bundle_application(*element_primary))
if element_residual:
split_response.residual_roots.add().CopyFrom(
self.delayed_bundle_application(*element_residual))
@@ -718,22 +718,39 @@
return split_response
def delayed_bundle_application(self, op, deferred_remainder):
- transform_id, main_input_tag, main_input_coder, outputs = op.input_info
# TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
- element_and_restriction, watermark = deferred_remainder
- if watermark:
- proto_watermark = timestamp_pb2.Timestamp()
- proto_watermark.FromMicroseconds(watermark.micros)
- output_watermarks = {output: proto_watermark for output in outputs}
+ ((element_and_restriction, output_watermark),
+ deferred_watermark) = deferred_remainder
+ if deferred_watermark:
+ assert isinstance(deferred_watermark, timestamp.Duration)
+ proto_deferred_watermark = duration_pb2.Duration()
+ proto_deferred_watermark.FromMicroseconds(deferred_watermark.micros)
+ else:
+ proto_deferred_watermark = None
+ return beam_fn_api_pb2.DelayedBundleApplication(
+ requested_time_delay=proto_deferred_watermark,
+ application=self.construct_bundle_application(
+ op, output_watermark, element_and_restriction))
+
+ def bundle_application(self, op, primary):
+ ((element_and_restriction, output_watermark),
+ _) = primary
+ return self.construct_bundle_application(
+ op, output_watermark, element_and_restriction)
+
+ def construct_bundle_application(self, op, output_watermark, element):
+ transform_id, main_input_tag, main_input_coder, outputs = op.input_info
+ if output_watermark:
+ proto_output_watermark = timestamp_pb2.Timestamp()
+ proto_output_watermark.FromMicroseconds(output_watermark.micros)
+ output_watermarks = {output: proto_output_watermark for output in outputs}
else:
output_watermarks = None
- return beam_fn_api_pb2.DelayedBundleApplication(
- application=beam_fn_api_pb2.BundleApplication(
- transform_id=transform_id,
- input_id=main_input_tag,
- output_watermarks=output_watermarks,
- element=main_input_coder.get_impl().encode_nested(
- element_and_restriction)))
+ return beam_fn_api_pb2.BundleApplication(
+ transform_id=transform_id,
+ input_id=main_input_tag,
+ output_watermarks=output_watermarks,
+ element=main_input_coder.get_impl().encode_nested(element))
def metrics(self):
# DEPRECATED
diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline.py b/sdks/python/apache_beam/testing/synthetic_pipeline.py
index 50740ba..fbef112 100644
--- a/sdks/python/apache_beam/testing/synthetic_pipeline.py
+++ b/sdks/python/apache_beam/testing/synthetic_pipeline.py
@@ -523,7 +523,7 @@
element,
restriction_tracker=beam.DoFn.RestrictionParam(
SyntheticSDFSourceRestrictionProvider())):
- cur = restriction_tracker.start_position()
+ cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
r = np.random.RandomState(cur)
time.sleep(element['sleep_per_input_record_sec'])
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 148caae..06fd201 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -63,6 +63,7 @@
from apache_beam.typehints.decorators import get_type_hints
from apache_beam.typehints.trivial_inference import element_type
from apache_beam.typehints.typehints import is_consistent_with
+from apache_beam.utils import timestamp
from apache_beam.utils import urns
try:
@@ -91,7 +92,8 @@
'Flatten',
'Create',
'Impulse',
- 'RestrictionProvider'
+ 'RestrictionProvider',
+ 'WatermarkEstimator'
]
# Type variables
@@ -242,6 +244,8 @@
def create_tracker(self, restriction):
"""Produces a new ``RestrictionTracker`` for the given restriction.
+ This API is required to be implemented.
+
Args:
restriction: an object that defines a restriction as identified by a
Splittable ``DoFn`` that utilizes the current ``RestrictionProvider``.
@@ -252,7 +256,10 @@
raise NotImplementedError
def initial_restriction(self, element):
- """Produces an initial restriction for the given element."""
+ """Produces an initial restriction for the given element.
+
+ This API is required to be implemented.
+ """
raise NotImplementedError
def split(self, element, restriction):
@@ -262,6 +269,9 @@
reading input element for each of the returned restrictions should be the
same as the total set of elements produced by reading the input element for
the input restriction.
+
+ This API is optional if ``split_and_size`` has been implemented.
+
"""
yield restriction
@@ -281,11 +291,16 @@
By default, asks a newly-created restriction tracker for the default size
of the restriction.
+
+ This API is required to be implemented.
"""
- return self.create_tracker(restriction).default_size()
+ raise NotImplementedError
def split_and_size(self, element, restriction):
"""Like split, but also does sizing, returning (restriction, size) pairs.
+
+ This API is optional if ``split`` and ``restriction_size`` have been
+ implemented.
"""
for part in self.split(element, restriction):
yield part, self.restriction_size(element, part)
@@ -379,6 +394,43 @@
return None
+class WatermarkEstimator(object):
+ """Tracks the output_watermark of a DoFn.process() call, typically once per
+ <element, restriction> pair of a Splittable DoFn (SDF) in a streaming
+ pipeline.
+
+ The class exposes three methods, each with a default implementation:
+ set_watermark, current_watermark and reset.
+
+ TODO(BEAM-8537): Create WatermarkEstimatorProvider to support different types.
+ """
+ def __init__(self):
+ self._watermark = None
+
+ def set_watermark(self, watermark):
+ """Records an output timestamp, keeping the minimum seen since reset().
+
+ Args:
+ watermark: `timestamp.Timestamp` of an element output by SDF.Process().
+ Raises:
+ ValueError: if watermark is not a `timestamp.Timestamp`.
+ """
+ if not isinstance(watermark, timestamp.Timestamp):
+ raise ValueError('watermark should be a object of timestamp.Timestamp')
+ if self._watermark is None:
+ self._watermark = watermark
+ else:
+ self._watermark = min(self._watermark, watermark)
+
+ def current_watermark(self):
+ """Returns the tracked output_watermark, or None if none has been set."""
+ return self._watermark
+
+ def reset(self):
+ """Resets the tracked watermark to None."""
+ self._watermark = None
+
+
class _DoFnParam(object):
"""DoFn parameter."""
@@ -459,6 +511,17 @@
del self._callbacks[:]
+class _WatermarkEstimatorParam(_DoFnParam):
+ """WatermarkEstimator DoFn parameter."""
+
+ def __init__(self, watermark_estimator):
+ if not isinstance(watermark_estimator, WatermarkEstimator):
+ raise ValueError('DoFn.WatermarkEstimatorParam expected '
+ 'WatermarkEstimator object.')
+ self.watermark_estimator = watermark_estimator
+ self.param_id = 'WatermarkEstimator'
+
+
class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
"""A function object used by a transform with custom processing.
@@ -477,7 +540,7 @@
TimestampParam = _DoFnParam('TimestampParam')
WindowParam = _DoFnParam('WindowParam')
PaneInfoParam = _DoFnParam('PaneInfoParam')
- WatermarkReporterParam = _DoFnParam('WatermarkReporterParam')
+ WatermarkEstimatorParam = _WatermarkEstimatorParam
BundleFinalizerParam = _BundleFinalizerParam
KeyParam = _DoFnParam('KeyParam')
@@ -489,7 +552,7 @@
TimerParam = _TimerDoFnParam
DoFnProcessParams = [ElementParam, SideInputParam, TimestampParam,
- WindowParam, WatermarkReporterParam, PaneInfoParam,
+ WindowParam, WatermarkEstimatorParam, PaneInfoParam,
BundleFinalizerParam, KeyParam, StateParam, TimerParam]
RestrictionParam = _RestrictionDoFnParam
@@ -522,7 +585,7 @@
``DoFn.RestrictionParam``: an ``iobase.RestrictionTracker`` will be
provided here to allow treatment as a Splittable ``DoFn``. The restriction
tracker will be derived from the restriction provider in the parameter.
- ``DoFn.WatermarkReporterParam``: a function that can be used to report
+ ``DoFn.WatermarkEstimatorParam``: a ``WatermarkEstimator`` used to track
output watermark of Splittable ``DoFn`` implementations.
Args:
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
new file mode 100644
index 0000000..1a27bd2
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for core module."""
+
+from __future__ import absolute_import
+
+import unittest
+
+from apache_beam.transforms.core import WatermarkEstimator
+from apache_beam.utils.timestamp import Timestamp
+
+
+class WatermarkEstimatorTest(unittest.TestCase):
+
+ def test_set_watermark(self):
+ watermark_estimator = WatermarkEstimator()
+ self.assertEqual(watermark_estimator.current_watermark(), None)
+ # set_watermark should only accept timestamp.Timestamp.
+ with self.assertRaises(ValueError):
+ watermark_estimator.set_watermark(0)
+
+ # watermark_estimator should always keep minimal timestamp.
+ watermark_estimator.set_watermark(Timestamp(100))
+ self.assertEqual(watermark_estimator.current_watermark(), 100)
+ watermark_estimator.set_watermark(Timestamp(150))
+ self.assertEqual(watermark_estimator.current_watermark(), 100)
+ watermark_estimator.set_watermark(Timestamp(50))
+ self.assertEqual(watermark_estimator.current_watermark(), 50)
+
+ def test_reset(self):
+ watermark_estimator = WatermarkEstimator()
+ watermark_estimator.set_watermark(Timestamp(100))
+ self.assertEqual(watermark_estimator.current_watermark(), 100)
+ watermark_estimator.reset()
+ self.assertEqual(watermark_estimator.current_watermark(), None)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index bb7e522..7a87e60 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -241,7 +241,8 @@
target_batch_overhead=.1,
target_batch_duration_secs=1,
variance=0.25,
- clock=time.time):
+ clock=time.time,
+ ignore_first_n_seen_per_batch_size=0):
if min_batch_size > max_batch_size:
raise ValueError("Minimum (%s) must not be greater than maximum (%s)" % (
min_batch_size, max_batch_size))
@@ -254,6 +255,9 @@
if not (target_batch_overhead or target_batch_duration_secs):
raise ValueError("At least one of target_batch_overhead or "
"target_batch_duration_secs must be positive.")
+ if ignore_first_n_seen_per_batch_size < 0:
+ raise ValueError('ignore_first_n_seen_per_batch_size (%s) must be non '
+ 'negative' % (ignore_first_n_seen_per_batch_size))
self._min_batch_size = min_batch_size
self._max_batch_size = max_batch_size
self._target_batch_overhead = target_batch_overhead
@@ -262,6 +266,10 @@
self._clock = clock
self._data = []
self._ignore_next_timing = False
+ self._ignore_first_n_seen_per_batch_size = (
+ ignore_first_n_seen_per_batch_size)
+ self._batch_size_num_seen = {}
+ self._replay_last_batch_size = None
self._size_distribution = Metrics.distribution(
'BatchElements', 'batch_size')
@@ -279,7 +287,7 @@
For example, the first emit of a ParDo operation is known to be anomalous
due to setup that may occur.
"""
- self._ignore_next_timing = False
+ self._ignore_next_timing = True
@contextlib.contextmanager
def record_time(self, batch_size):
@@ -290,8 +298,11 @@
self._size_distribution.update(batch_size)
self._time_distribution.update(int(elapsed_msec))
self._remainder_msecs = elapsed_msec - int(elapsed_msec)
+ # If we ignore the next timing, replay the batch size to get accurate
+ # timing.
if self._ignore_next_timing:
self._ignore_next_timing = False
+ self._replay_last_batch_size = batch_size
else:
self._data.append((batch_size, elapsed))
if len(self._data) >= self._MAX_DATA_POINTS:
@@ -364,7 +375,7 @@
except ImportError:
linear_regression = linear_regression_no_numpy
- def next_batch_size(self):
+ def _calculate_next_batch_size(self):
if self._min_batch_size == self._max_batch_size:
return self._min_batch_size
elif len(self._data) < 1:
@@ -414,6 +425,21 @@
return int(max(self._min_batch_size + jitter, min(target, cap)))
+ def next_batch_size(self):
+ # Check if we should replay a previous batch size due to it not being
+ # recorded.
+ if self._replay_last_batch_size:
+ result = self._replay_last_batch_size
+ self._replay_last_batch_size = None
+ else:
+ result = self._calculate_next_batch_size()
+
+ seen_count = self._batch_size_num_seen.get(result, 0) + 1
+ if seen_count <= self._ignore_first_n_seen_per_batch_size:
+ self.ignore_next_timing()
+ self._batch_size_num_seen[result] = seen_count
+ return result
+
class _GlobalWindowsBatchingDoFn(DoFn):
def __init__(self, batch_size_estimator):
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 4588c32..6ac05d0 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -157,6 +157,60 @@
self.assertLess(
max(stable_set), expected_target + expected_target * variance)
+ def test_ignore_first_n_batch_size(self):
+ clock = FakeClock()
+ batch_estimator = util._BatchSizeEstimator(
+ clock=clock, ignore_first_n_seen_per_batch_size=2)
+
+ expected_sizes = [
+ 1, 1, 1, 2, 2, 2, 4, 4, 4, 8, 8, 8, 16, 16, 16, 32, 32, 32, 64, 64, 64
+ ]
+ actual_sizes = []
+ for i in range(len(expected_sizes)):
+ actual_sizes.append(batch_estimator.next_batch_size())
+ with batch_estimator.record_time(actual_sizes[-1]):
+ if i % 3 == 2:
+ clock.sleep(0.01)
+ else:
+ clock.sleep(1)
+
+ self.assertEqual(expected_sizes, actual_sizes)
+
+ # Check we only record the third timing.
+ expected_data_batch_sizes = [1, 2, 4, 8, 16, 32, 64]
+ actual_data_batch_sizes = [x[0] for x in batch_estimator._data]
+ self.assertEqual(expected_data_batch_sizes, actual_data_batch_sizes)
+ expected_data_timing = [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
+ for i in range(len(expected_data_timing)):
+ self.assertAlmostEqual(
+ expected_data_timing[i], batch_estimator._data[i][1])
+
+ def test_ignore_next_timing(self):
+ clock = FakeClock()
+ batch_estimator = util._BatchSizeEstimator(clock=clock)
+ batch_estimator.ignore_next_timing()
+
+ expected_sizes = [1, 1, 2, 4, 8, 16]
+ actual_sizes = []
+ for i in range(len(expected_sizes)):
+ actual_sizes.append(batch_estimator.next_batch_size())
+ with batch_estimator.record_time(actual_sizes[-1]):
+ if i == 0:
+ clock.sleep(1)
+ else:
+ clock.sleep(0.01)
+
+ self.assertEqual(expected_sizes, actual_sizes)
+
+ # Check the first record_time was skipped.
+ expected_data_batch_sizes = [1, 2, 4, 8, 16]
+ actual_data_batch_sizes = [x[0] for x in batch_estimator._data]
+ self.assertEqual(expected_data_batch_sizes, actual_data_batch_sizes)
+ expected_data_timing = [0.01, 0.01, 0.01, 0.01, 0.01]
+ for i in range(len(expected_data_timing)):
+ self.assertAlmostEqual(
+ expected_data_timing[i], batch_estimator._data[i][1])
+
def _run_regression_test(self, linear_regression_fn, test_outliers):
xs = [random.random() for _ in range(10)]
ys = [2*x + 1 for x in xs]
diff --git a/sdks/python/apache_beam/transforms/window_test.py b/sdks/python/apache_beam/transforms/window_test.py
index dda651d..0d5e14f 100644
--- a/sdks/python/apache_beam/transforms/window_test.py
+++ b/sdks/python/apache_beam/transforms/window_test.py
@@ -22,6 +22,8 @@
import unittest
from builtins import range
+from nose.plugins.attrib import attr
+
import apache_beam as beam
from apache_beam.runners import pipeline_context
from apache_beam.testing.test_pipeline import TestPipeline
@@ -281,6 +283,35 @@
assert_that(mean_per_window, equal_to([(0, 2.0), (1, 7.0)]),
label='assert:mean')
+ @attr('ValidatesRunner')
+ def test_window_assignment_idempotency(self):
+ with TestPipeline() as p:
+ pcoll = self.timestamped_key_values(p, 'key', 0, 1, 2, 3, 4)
+ result = (pcoll
+ | 'window' >> WindowInto(FixedWindows(2))
+ | 'same window' >> WindowInto(FixedWindows(2))
+ | 'same window again' >> WindowInto(FixedWindows(2))
+ | GroupByKey())
+
+ assert_that(result, equal_to([('key', [0, 1]),
+ ('key', [2, 3]),
+ ('key', [4])]))
+
+ @attr('ValidatesRunner')
+ def test_window_assignment_through_multiple_gbk_idempotency(self):
+ with TestPipeline() as p:
+ pcoll = self.timestamped_key_values(p, 'key', 0, 1, 2, 3, 4)
+ result = (pcoll
+ | 'window' >> WindowInto(FixedWindows(2))
+ | 'gbk' >> GroupByKey()
+ | 'same window' >> WindowInto(FixedWindows(2))
+ | 'another gbk' >> GroupByKey()
+ | 'same window again' >> WindowInto(FixedWindows(2))
+ | 'gbk again' >> GroupByKey())
+
+ assert_that(result, equal_to([('key', [[[0, 1]]]),
+ ('key', [[[2, 3]]]),
+ ('key', [[[4]]])]))
class RunnerApiTest(unittest.TestCase):
diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py
index 9bccdfd..a3f3abf 100644
--- a/sdks/python/apache_beam/utils/timestamp.py
+++ b/sdks/python/apache_beam/utils/timestamp.py
@@ -25,6 +25,7 @@
import datetime
import functools
+import time
from builtins import object
import dateutil.parser
@@ -76,6 +77,10 @@
return Timestamp(seconds)
@staticmethod
+ def now():
+ return Timestamp(seconds=time.time())
+
+ @staticmethod
def _epoch_datetime_utc():
return datetime.datetime.fromtimestamp(0, pytz.utc)
@@ -173,6 +178,8 @@
return self + other
def __sub__(self, other):
+ if isinstance(other, Timestamp):
+ return Duration(micros=self.micros - other.micros)
other = Duration.of(other)
return Timestamp(micros=self.micros - other.micros)
diff --git a/sdks/python/apache_beam/utils/timestamp_test.py b/sdks/python/apache_beam/utils/timestamp_test.py
index d26d561..2a4d454 100644
--- a/sdks/python/apache_beam/utils/timestamp_test.py
+++ b/sdks/python/apache_beam/utils/timestamp_test.py
@@ -100,6 +100,7 @@
self.assertEqual(Timestamp(123) - Duration(456), -333)
self.assertEqual(Timestamp(1230) % 456, 318)
self.assertEqual(Timestamp(1230) % Duration(456), 318)
+ self.assertEqual(Timestamp(123) - Timestamp(100), 23)
# Check that direct comparison of Timestamp and Duration is allowed.
self.assertTrue(Duration(123) == Timestamp(123))
@@ -116,6 +117,7 @@
self.assertEqual((Timestamp(123) - Duration(456)).__class__, Timestamp)
self.assertEqual((Timestamp(1230) % 456).__class__, Duration)
self.assertEqual((Timestamp(1230) % Duration(456)).__class__, Duration)
+ self.assertEqual((Timestamp(123) - Timestamp(100)).__class__, Duration)
# Unsupported operations.
with self.assertRaises(TypeError):
@@ -159,6 +161,10 @@
self.assertEqual('Timestamp(-999999999)',
str(Timestamp(-999999999)))
+ def test_now(self):
+ now = Timestamp.now()
+ self.assertTrue(isinstance(now, Timestamp))
+
class DurationTest(unittest.TestCase):
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index e3794ba..bc77fd8 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -183,6 +183,7 @@
'_TimerDoFnParam',
'_BundleFinalizerParam',
'_RestrictionDoFnParam',
+ '_WatermarkEstimatorParam',
# Sphinx cannot find this py:class reference target
'typing.Generic',
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 7eea64c..9f1a9f3 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -228,7 +228,10 @@
python_requires=python_requires,
test_suite='nose.collector',
setup_requires=['pytest_runner'],
- tests_require=REQUIRED_TEST_PACKAGES,
+ tests_require= [
+ REQUIRED_TEST_PACKAGES,
+ INTERACTIVE_BEAM,
+ ],
extras_require={
'docs': ['Sphinx>=1.5.2,<2.0'],
'test': REQUIRED_TEST_PACKAGES,
diff --git a/sdks/python/test-suites/tox/py37/build.gradle b/sdks/python/test-suites/tox/py37/build.gradle
index 2a57ca9..c9c99e6 100644
--- a/sdks/python/test-suites/tox/py37/build.gradle
+++ b/sdks/python/test-suites/tox/py37/build.gradle
@@ -41,9 +41,6 @@
toxTask "testPy37Cython", "py37-cython"
test.dependsOn testPy37Cython
-toxTask "testPy37Interactive", "py37-interactive"
-test.dependsOn testPy37Interactive
-
// Ensure that testPy37Cython runs exclusively to other tests. This line is not
// actually required, since gradle doesn't do parallel execution within a
// project.
@@ -60,7 +57,6 @@
task preCommitPy37() {
dependsOn "testPy37Gcp"
dependsOn "testPy37Cython"
- dependsOn "testPy37Interactive"
}
task preCommitPy37Pytest {
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index fe3f65a..d227519 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -307,10 +307,3 @@
coverage report --skip-covered
# Generate report in xml format
coverage xml
-
-[testenv:py37-interactive]
-setenv =
- RUN_SKIPPED_PY3_TESTS=0
-extras = test,interactive
-commands =
- python setup.py nosetests --ignore-files '.*py3[8-9]\.py$' {posargs}