Merge pull request #10090 Update container image tags used by Dataflow runner for Beam master

diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index ed2f013..0ddc48e 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -42,6 +42,7 @@
 import "endpoints.proto";
 import "google/protobuf/descriptor.proto";
 import "google/protobuf/timestamp.proto";
+import "google/protobuf/duration.proto";
 import "google/protobuf/wrappers.proto";
 import "metrics.proto";
 
@@ -203,13 +204,21 @@
 }
 
 // An Application should be scheduled for execution after a delay.
+// Either an absolute timestamp or a relative timestamp can represent a
+// scheduled execution time.
 message DelayedBundleApplication {
   // Recommended time at which the application should be scheduled to execute
   // by the runner. Times in the past may be scheduled to execute immediately.
+  // TODO(BEAM-8536): Migrate usage of absolute time to requested_time_delay.
   google.protobuf.Timestamp requested_execution_time = 1;
 
   // (Required) The application that should be scheduled.
   BundleApplication application = 2;
+
+  // Recommended time delay after which the application should be scheduled to
+  // execute by the runner. A time delay of 0 may be scheduled to execute
+  // immediately. The unit of the time delay should be microseconds.
+  google.protobuf.Duration requested_time_delay = 3;
 }
 
 // A request to process a given bundle.
diff --git a/runners/spark/job-server/build.gradle b/runners/spark/job-server/build.gradle
index 2b27a88..6fb7581 100644
--- a/runners/spark/job-server/build.gradle
+++ b/runners/spark/job-server/build.gradle
@@ -114,6 +114,7 @@
       excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
       excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs'
       excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
+      excludeCategories 'org.apache.beam.sdk.testing.UsesStrictTimerOrdering'
     },
   )
 }
diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptionsTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptionsTest.java
index ab03de0..0f0b5b3 100644
--- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptionsTest.java
+++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptionsTest.java
@@ -159,21 +159,6 @@
       options.getGcpTempLocation();
     }
 
-    @Test
-    public void testDefaultGcpTempLocationDoesNotExist() {
-      GcpOptions options = PipelineOptionsFactory.as(GcpOptions.class);
-      String tempLocation = "gs://does/not/exist";
-      options.setTempLocation(tempLocation);
-      thrown.expect(IllegalArgumentException.class);
-      thrown.expectMessage(
-          "Error constructing default value for gcpTempLocation: tempLocation is not"
-              + " a valid GCS path");
-      thrown.expectCause(
-          hasMessage(containsString("Output path does not exist or is not writeable")));
-
-      options.getGcpTempLocation();
-    }
-
     private static void makePropertiesFileWithProject(File path, String projectId)
         throws IOException {
       String properties =
@@ -221,6 +206,21 @@
     }
 
     @Test
+    public void testDefaultGcpTempLocationDoesNotExist() throws IOException {
+      String tempLocation = "gs://does/not/exist";
+      options.setTempLocation(tempLocation);
+      when(mockGcsUtil.bucketAccessible(any(GcsPath.class))).thenReturn(false);
+      thrown.expect(IllegalArgumentException.class);
+      thrown.expectMessage(
+          "Error constructing default value for gcpTempLocation: tempLocation is not"
+              + " a valid GCS path");
+      thrown.expectCause(
+          hasMessage(containsString("Output path does not exist or is not writeable")));
+
+      options.as(GcpOptions.class).getGcpTempLocation();
+    }
+
+    @Test
     public void testCreateBucket() throws Exception {
       doReturn(fakeProject).when(mockGet).execute();
       when(mockGcsUtil.bucketOwner(any(GcsPath.class))).thenReturn(1L);
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 5b66730..e21052f 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -35,6 +35,7 @@
 import logging
 import math
 import random
+import threading
 import uuid
 from builtins import object
 from builtins import range
@@ -1104,13 +1105,17 @@
 class RestrictionTracker(object):
   """Manages concurrent access to a restriction.
 
-  Experimental; no backwards-compatibility guarantees.
-
   Keeps track of the restrictions claimed part for a Splittable DoFn.
 
+  The restriction may be modified by different threads, however the system will
+  ensure sufficient locking such that no methods on the restriction tracker
+  will be called concurrently.
+
   See following documents for more details.
   * https://s.apache.org/splittable-do-fn
   * https://s.apache.org/splittable-do-fn-python-sdk
+
+  Experimental; no backwards-compatibility guarantees.
   """
 
   def current_restriction(self):
@@ -1121,54 +1126,22 @@
 
     The current restriction returned by method may be updated dynamically due
     to due to concurrent invocation of other methods of the
-    ``RestrictionTracker``, For example, ``checkpoint()``.
+    ``RestrictionTracker``, for example, ``try_split()``.
 
-    ** Thread safety **
+    This API is required to be implemented.
 
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+    Returns: a restriction object.
     """
     raise NotImplementedError
 
   def current_progress(self):
     """Returns a RestrictionProgress object representing the current progress.
+
+    This API is recommended to be implemented. The runner can do a better job
+    at parallel processing with better progress signals.
     """
     raise NotImplementedError
 
-  def current_watermark(self):
-    """Returns current watermark. By default, not report watermark.
-
-    TODO(BEAM-7473): Provide synchronization guarantee by using a wrapper.
-    """
-    return None
-
-  def checkpoint(self):
-    """Performs a checkpoint of the current restriction.
-
-    Signals that the current ``DoFn.process()`` call should terminate as soon as
-    possible. After this method returns, the tracker MUST refuse all future
-    claim calls, and ``RestrictionTracker.check_done()`` MUST succeed.
-
-    This invocation modifies the value returned by ``current_restriction()``
-    invocation and returns a restriction representing the rest of the work. The
-    old value of ``current_restriction()`` is equivalent to the new value of
-    ``current_restriction()`` and the return value of this method invocation
-    combined.
-
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
-    """
-
-    raise NotImplementedError
-
   def check_done(self):
     """Checks whether the restriction has been fully processed.
 
@@ -1179,13 +1152,8 @@
     remaining in the restriction when this method is invoked. Exception raised
     must have an informative error message.
 
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+    This API is required to be implemented in order to make sure no data is
+    lost during SDK processing.
 
     Returns: ``True`` if current restriction has been fully processed.
     Raises:
@@ -1215,8 +1183,12 @@
     restrictions returned would be [100, 179), [179, 200) (note: current_offset
     + fraction_of_remainder * remaining_work = 130 + 0.7 * 70 = 179).
 
-    It is very important for pipeline scaling and end to end pipeline execution
-    that try_split is implemented well.
+    ``fraction_of_remainder`` = 0 means a checkpoint is required.
+
+    The API is recommended to be implemented for batch pipelines given that it
+    is very important for pipeline scaling and end to end pipeline execution.
+
+    The API is required to be implemented for a streaming pipeline.
 
     Args:
       fraction_of_remainder: A hint as to the fraction of work the primary
@@ -1226,19 +1198,11 @@
     Returns:
       (primary_restriction, residual_restriction) if a split was possible,
       otherwise returns ``None``.
-
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
     """
     raise NotImplementedError
 
   def try_claim(self, position):
-    """ Attempts to claim the block of work in the current restriction
+    """Attempts to claim the block of work in the current restriction
     identified by the given position.
 
     If this succeeds, the DoFn MUST execute the entire block of work. If it
@@ -1247,40 +1211,137 @@
     work from ``DoFn.process()`` is also not allowed before the first call of
     this method).
 
+    The API is required to be implemented.
+
     Args:
       position: current position that wants to be claimed.
 
     Returns: ``True`` if the position can be claimed as current_position.
     Otherwise, returns ``False``.
-
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
     """
     raise NotImplementedError
 
-  def defer_remainder(self, watermark=None):
-    """ Invokes checkpoint() in an SDF.process().
 
-    TODO(BEAM-7472): Remove defer_remainder() once SDF.process() uses
-    ``ProcessContinuation``.
+class ThreadsafeRestrictionTracker(object):
+  """A thread-safe wrapper which wraps a `RestrictionTracker`.
+
+  This wrapper guarantees that modifications to the restriction are
+  synchronized across multiple threads.
+  """
+
+  def __init__(self, restriction_tracker):
+    if not isinstance(restriction_tracker, RestrictionTracker):
+      raise ValueError(
+          'Initialize ThreadsafeRestrictionTracker requires'
+          'RestrictionTracker.')
+    self._restriction_tracker = restriction_tracker
+    # Records an absolute timestamp when defer_remainder is called.
+    self._deferred_timestamp = None
+    self._lock = threading.RLock()
+    self._deferred_residual = None
+    self._deferred_watermark = None
+
+  def current_restriction(self):
+    with self._lock:
+      return self._restriction_tracker.current_restriction()
+
+  def try_claim(self, position):
+    with self._lock:
+      return self._restriction_tracker.try_claim(position)
+
+  def defer_remainder(self, deferred_time=None):
+    """Performs self-checkpoint on current processing restriction with an
+    expected resuming time.
+
+    A self-checkpoint can happen while processing elements. When executing a
+    DoFn.process(), you may want to stop processing an element and resume
+    later if the current element has been processed for quite a long time or
+    you also want some outputs from other elements. ``defer_remainder()`` can
+    be called on a per-element basis if needed.
 
     Args:
-      watermark
+      deferred_time: A relative ``timestamp.Duration`` that indicates the ideal
+      time gap between now and resuming, or an absolute ``timestamp.Timestamp``
+      at which to resume execution. If deferred_time is None, the deferred work
+      will be executed as soon as possible.
     """
-    raise NotImplementedError
+
+    # Record current time for calculating deferred_time later.
+    self._deferred_timestamp = timestamp.Timestamp.now()
+    if (deferred_time and
+        not isinstance(deferred_time, timestamp.Duration) and
+        not isinstance(deferred_time, timestamp.Timestamp)):
+      raise ValueError('The timestamp of deter_remainder() should be a '
+                       'Duration or a Timestamp, or None.')
+    self._deferred_watermark = deferred_time
+    checkpoint = self.try_split(0)
+    if checkpoint:
+      _, self._deferred_residual = checkpoint
+
+  def check_done(self):
+    with self._lock:
+      return self._restriction_tracker.check_done()
+
+  def current_progress(self):
+    with self._lock:
+      return self._restriction_tracker.current_progress()
+
+  def try_split(self, fraction_of_remainder):
+    with self._lock:
+      return self._restriction_tracker.try_split(fraction_of_remainder)
 
   def deferred_status(self):
-    """ Returns deferred_residual with deferred_watermark.
+    """Returns deferred work which is produced by ``defer_remainder()``.
 
-    TODO(BEAM-7472): Remove defer_status() once SDF.process() uses
-    ``ProcessContinuation``.
+    When a self-checkpoint has been performed, the system needs to populate the
+    DelayedBundleApplication with deferred work for a ProcessBundleResponse.
+    The system calls this API to get the deferred_residual together with the
+    time delay, which helps the runner schedule the future work.
+
+    Returns: (deferred_residual, time_delay) if having any residual, else None.
     """
-    raise NotImplementedError
+    if self._deferred_residual:
+      # If _deferred_watermark is None, create Duration(0).
+      if not self._deferred_watermark:
+        self._deferred_watermark = timestamp.Duration()
+      # If an absolute timestamp is provided, calculate the delta between
+      # the absolute time and the time deferred_status() is called.
+      elif isinstance(self._deferred_watermark, timestamp.Timestamp):
+        self._deferred_watermark = (self._deferred_watermark -
+                                    timestamp.Timestamp.now())
+      # If a Duration is provided, the deferred time should be:
+      # provided duration - the spent time since the defer_remainder() is
+      # called.
+      elif isinstance(self._deferred_watermark, timestamp.Duration):
+        self._deferred_watermark -= (timestamp.Timestamp.now() -
+                                     self._deferred_timestamp)
+      return self._deferred_residual, self._deferred_watermark
+
+
+class RestrictionTrackerView(object):
+  """A DoFn view of thread-safe RestrictionTracker.
+
+  The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
+  exposes APIs that will be called by a ``DoFn.process()``. During execution
+  time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a
+  restriction_tracker.
+  """
+
+  def __init__(self, threadsafe_restriction_tracker):
+    if not isinstance(threadsafe_restriction_tracker,
+                      ThreadsafeRestrictionTracker):
+      raise ValueError('Initialize RestrictionTrackerView requires '
+                       'ThreadsafeRestrictionTracker.')
+    self._threadsafe_restriction_tracker = threadsafe_restriction_tracker
+
+  def current_restriction(self):
+    return self._threadsafe_restriction_tracker.current_restriction()
+
+  def try_claim(self, position):
+    return self._threadsafe_restriction_tracker.try_claim(position)
+
+  def defer_remainder(self, deferred_time=None):
+    self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
 
 
 class RestrictionProgress(object):
@@ -1400,17 +1461,8 @@
                 SourceBundle(residual_weight, self._source, split_pos,
                              stop_pos))
 
-    def deferred_status(self):
-      return None
-
-    def current_watermark(self):
-      return None
-
-    def get_delegate_range_tracker(self):
-      return self._delegate_range_tracker
-
-    def get_tracking_source(self):
-      return self._source
+    def check_done(self):
+      return self._delegate_range_tracker.fraction_consumed() >= 1.0
 
   class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
     """A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
@@ -1463,8 +1515,13 @@
           restriction_tracker=core.DoFn.RestrictionParam(
               _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionProvider(
                   source, chunk_size))):
-        return restriction_tracker.get_tracking_source().read(
-            restriction_tracker.get_delegate_range_tracker())
+        current_restriction = restriction_tracker.current_restriction()
+        assert isinstance(current_restriction, SourceBundle)
+        tracking_source = current_restriction.source
+        start = current_restriction.start_position
+        stop = current_restriction.stop_position
+        return tracking_source.read(tracking_source.get_range_tracker(start,
+                                                                      stop))
 
     return SDFBoundedSourceDoFn(self.source)
 
diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py
index 7adb764..0a6afae 100644
--- a/sdks/python/apache_beam/io/iobase_test.py
+++ b/sdks/python/apache_beam/io/iobase_test.py
@@ -19,6 +19,7 @@
 
 from __future__ import absolute_import
 
+import time
 import unittest
 
 import mock
@@ -28,6 +29,9 @@
 from apache_beam.io.concat_source_test import RangeSource
 from apache_beam.io import iobase
 from apache_beam.io.iobase import SourceBundle
+from apache_beam.io.restriction_trackers import OffsetRange
+from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
+from apache_beam.utils import timestamp
 from apache_beam.options.pipeline_options import DebugOptions
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -191,5 +195,87 @@
     self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3])
 
 
+class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
+
+  def test_initialization(self):
+    with self.assertRaises(ValueError):
+      iobase.ThreadsafeRestrictionTracker(RangeSource(0, 1))
+
+  def test_defer_remainder_with_wrong_time_type(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    with self.assertRaises(ValueError):
+      threadsafe_tracker.defer_remainder(10)
+
+  def test_self_checkpoint_immediately(self):
+    restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        restriction_tracker)
+    threadsafe_tracker.defer_remainder()
+    deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
+    expected_residual = OffsetRange(0, 10)
+    self.assertEqual(deferred_residual, expected_residual)
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    self.assertEqual(deferred_time, 0)
+
+  def test_self_checkpoint_with_relative_time(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    threadsafe_tracker.defer_remainder(timestamp.Duration(100))
+    time.sleep(2)
+    _, deferred_time = threadsafe_tracker.deferred_status()
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    # The expectation = 100 - 2 - some_delta
+    self.assertTrue(deferred_time <= 98)
+
+  def test_self_checkpoint_with_absolute_time(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    now = timestamp.Timestamp.now()
+    schedule_time = now + timestamp.Duration(100)
+    self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
+    threadsafe_tracker.defer_remainder(schedule_time)
+    time.sleep(2)
+    _, deferred_time = threadsafe_tracker.deferred_status()
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    # The expectation =
+    # schedule_time - the time when deferred_status is called - some_delta
+    self.assertTrue(deferred_time <= 98)
+
+
+class RestrictionTrackerViewTest(unittest.TestCase):
+
+  def test_initialization(self):
+    with self.assertRaises(ValueError):
+      iobase.RestrictionTrackerView(
+          OffsetRestrictionTracker(OffsetRange(0, 10)))
+
+  def test_api_expose(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
+    current_restriction = tracker_view.current_restriction()
+    self.assertEqual(current_restriction, OffsetRange(0, 10))
+    self.assertTrue(tracker_view.try_claim(0))
+    tracker_view.defer_remainder()
+    deferred_remainder, deferred_watermark = (
+        threadsafe_tracker.deferred_status())
+    self.assertEqual(deferred_remainder, OffsetRange(1, 10))
+    self.assertEqual(deferred_watermark, timestamp.Duration())
+
+  def test_non_expose_apis(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
+    with self.assertRaises(AttributeError):
+      tracker_view.check_done()
+    with self.assertRaises(AttributeError):
+      tracker_view.current_progress()
+    with self.assertRaises(AttributeError):
+      tracker_view.try_split()
+    with self.assertRaises(AttributeError):
+      tracker_view.deferred_status()
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
index 0ba5b23..20bb5c1 100644
--- a/sdks/python/apache_beam/io/restriction_trackers.py
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -19,7 +19,6 @@
 from __future__ import absolute_import
 from __future__ import division
 
-import threading
 from builtins import object
 
 from apache_beam.io.iobase import RestrictionProgress
@@ -86,104 +85,69 @@
     assert isinstance(offset_range, OffsetRange)
     self._range = offset_range
     self._current_position = None
-    self._current_watermark = None
     self._last_claim_attempt = None
-    self._deferred_residual = None
     self._checkpointed = False
-    self._lock = threading.RLock()
 
   def check_done(self):
-    with self._lock:
-      if self._last_claim_attempt < self._range.stop - 1:
-        raise ValueError(
-            'OffsetRestrictionTracker is not done since work in range [%s, %s) '
-            'has not been claimed.'
-            % (self._last_claim_attempt if self._last_claim_attempt is not None
-               else self._range.start,
-               self._range.stop))
+    if self._last_claim_attempt < self._range.stop - 1:
+      raise ValueError(
+          'OffsetRestrictionTracker is not done since work in range [%s, %s) '
+          'has not been claimed.'
+          % (self._last_claim_attempt if self._last_claim_attempt is not None
+             else self._range.start,
+             self._range.stop))
 
   def current_restriction(self):
-    with self._lock:
-      return self._range
-
-  def current_watermark(self):
-    return self._current_watermark
+    return self._range
 
   def current_progress(self):
-    with self._lock:
-      if self._current_position is None:
-        fraction = 0.0
-      elif self._range.stop == self._range.start:
-        # If self._current_position is not None, we must be done.
-        fraction = 1.0
-      else:
-        fraction = (
-            float(self._current_position - self._range.start)
-            / (self._range.stop - self._range.start))
+    if self._current_position is None:
+      fraction = 0.0
+    elif self._range.stop == self._range.start:
+      # If self._current_position is not None, we must be done.
+      fraction = 1.0
+    else:
+      fraction = (
+          float(self._current_position - self._range.start)
+          / (self._range.stop - self._range.start))
     return RestrictionProgress(fraction=fraction)
 
   def start_position(self):
-    with self._lock:
-      return self._range.start
+    return self._range.start
 
   def stop_position(self):
-    with self._lock:
-      return self._range.stop
-
-  def default_size(self):
-    return self._range.size()
+    return self._range.stop
 
   def try_claim(self, position):
-    with self._lock:
-      if self._last_claim_attempt and position <= self._last_claim_attempt:
-        raise ValueError(
-            'Positions claimed should strictly increase. Trying to claim '
-            'position %d while last claim attempt was %d.'
-            % (position, self._last_claim_attempt))
+    if self._last_claim_attempt and position <= self._last_claim_attempt:
+      raise ValueError(
+          'Positions claimed should strictly increase. Trying to claim '
+          'position %d while last claim attempt was %d.'
+          % (position, self._last_claim_attempt))
 
-      self._last_claim_attempt = position
-      if position < self._range.start:
-        raise ValueError(
-            'Position to be claimed cannot be smaller than the start position '
-            'of the range. Tried to claim position %r for the range [%r, %r)'
-            % (position, self._range.start, self._range.stop))
+    self._last_claim_attempt = position
+    if position < self._range.start:
+      raise ValueError(
+          'Position to be claimed cannot be smaller than the start position '
+          'of the range. Tried to claim position %r for the range [%r, %r)'
+          % (position, self._range.start, self._range.stop))
 
-      if position >= self._range.start and position < self._range.stop:
-        self._current_position = position
-        return True
+    if position >= self._range.start and position < self._range.stop:
+      self._current_position = position
+      return True
 
-      return False
+    return False
 
   def try_split(self, fraction_of_remainder):
-    with self._lock:
-      if not self._checkpointed:
-        if self._current_position is None:
-          cur = self._range.start - 1
-        else:
-          cur = self._current_position
-        split_point = (
-            cur + int(max(1, (self._range.stop - cur) * fraction_of_remainder)))
-        if split_point < self._range.stop:
-          self._range, residual_range = self._range.split_at(split_point)
-          return self._range, residual_range
-
-  # TODO(SDF): Replace all calls with try_claim(0).
-  def checkpoint(self):
-    with self._lock:
-      # If self._current_position is 'None' no records have been claimed so
-      # residual should start from self._range.start.
+    if not self._checkpointed:
       if self._current_position is None:
-        end_position = self._range.start
+        cur = self._range.start - 1
       else:
-        end_position = self._current_position + 1
-      self._range, residual_range = self._range.split_at(end_position)
-      return residual_range
-
-  def defer_remainder(self, watermark=None):
-    with self._lock:
-      self._deferred_watermark = watermark or self._current_watermark
-      self._deferred_residual = self.checkpoint()
-
-  def deferred_status(self):
-    if self._deferred_residual:
-      return (self._deferred_residual, self._deferred_watermark)
+        cur = self._current_position
+      split_point = (
+          cur + int(max(1, (self._range.stop - cur) * fraction_of_remainder)))
+      if split_point < self._range.stop:
+        if fraction_of_remainder == 0:
+          self._checkpointed = True
+        self._range, residual_range = self._range.split_at(split_point)
+        return self._range, residual_range
diff --git a/sdks/python/apache_beam/io/restriction_trackers_test.py b/sdks/python/apache_beam/io/restriction_trackers_test.py
index 459b039..4a57d98 100644
--- a/sdks/python/apache_beam/io/restriction_trackers_test.py
+++ b/sdks/python/apache_beam/io/restriction_trackers_test.py
@@ -81,14 +81,14 @@
 
   def test_checkpoint_unstarted(self):
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 100), tracker.current_restriction())
     self.assertEqual(OffsetRange(100, 200), checkpoint)
 
   def test_checkpoint_just_started(self):
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
     self.assertTrue(tracker.try_claim(100))
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 101), tracker.current_restriction())
     self.assertEqual(OffsetRange(101, 200), checkpoint)
 
@@ -96,7 +96,7 @@
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
     self.assertTrue(tracker.try_claim(105))
     self.assertTrue(tracker.try_claim(110))
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 111), tracker.current_restriction())
     self.assertEqual(OffsetRange(111, 200), checkpoint)
 
@@ -105,9 +105,9 @@
     self.assertTrue(tracker.try_claim(105))
     self.assertTrue(tracker.try_claim(110))
     self.assertTrue(tracker.try_claim(199))
-    checkpoint = tracker.checkpoint()
+    checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 200), tracker.current_restriction())
-    self.assertEqual(OffsetRange(200, 200), checkpoint)
+    self.assertEqual(None, checkpoint)
 
   def test_checkpoint_after_failed_claim(self):
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
@@ -116,7 +116,7 @@
     self.assertTrue(tracker.try_claim(160))
     self.assertFalse(tracker.try_claim(240))
 
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertTrue(OffsetRange(100, 161), tracker.current_restriction())
     self.assertTrue(OffsetRange(161, 200), checkpoint)
 
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index 2ffe432..37e05bf 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -42,6 +42,8 @@
   cdef object key_arg_name
   cdef object restriction_provider
   cdef object restriction_provider_arg_name
+  cdef object watermark_estimator
+  cdef object watermark_estimator_arg_name
 
 
 cdef class DoFnSignature(object):
@@ -91,7 +93,9 @@
   cdef bint cache_globally_windowed_args
   cdef object process_method
   cdef bint is_splittable
-  cdef object restriction_tracker
+  cdef object threadsafe_restriction_tracker
+  cdef object watermark_estimator
+  cdef object watermark_estimator_param
   cdef WindowedValue current_windowed_value
   cdef bint is_key_param_required
 
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 3e14f3b..8632cfd 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 # cython: profile=True
 
 """Worker operations executor.
@@ -167,6 +166,8 @@
     self.key_arg_name = None
     self.restriction_provider = None
     self.restriction_provider_arg_name = None
+    self.watermark_estimator = None
+    self.watermark_estimator_arg_name = None
 
     for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
       if isinstance(v, core.DoFn.StateParam):
@@ -184,6 +185,9 @@
       elif isinstance(v, core.DoFn.RestrictionParam):
         self.restriction_provider = v.restriction_provider
         self.restriction_provider_arg_name = kw
+      elif isinstance(v, core.DoFn.WatermarkEstimatorParam):
+        self.watermark_estimator = v.watermark_estimator
+        self.watermark_estimator_arg_name = kw
 
   def invoke_timer_callback(self,
                             user_state_context,
@@ -264,6 +268,9 @@
   def get_restriction_provider(self):
     return self.process_method.restriction_provider
 
+  def get_watermark_estimator(self):
+    return self.process_method.watermark_estimator
+
   def _validate(self):
     self._validate_process()
     self._validate_bundle_method(self.start_bundle_method)
@@ -458,7 +465,11 @@
         signature.is_stateful_dofn())
     self.user_state_context = user_state_context
     self.is_splittable = signature.is_splittable_dofn()
-    self.restriction_tracker = None
+    self.watermark_estimator = self.signature.get_watermark_estimator()
+    self.watermark_estimator_param = (
+        self.signature.process_method.watermark_estimator_arg_name
+        if self.watermark_estimator else None)
+    self.threadsafe_restriction_tracker = None
     self.current_windowed_value = None
     self.bundle_finalizer_param = bundle_finalizer_param
     self.is_key_param_required = False
@@ -569,15 +580,24 @@
         raise ValueError(
             'A RestrictionTracker %r was provided but DoFn does not have a '
             'RestrictionTrackerParam defined' % restriction_tracker)
-      additional_kwargs[restriction_tracker_param] = restriction_tracker
+      from apache_beam.io import iobase
+      self.threadsafe_restriction_tracker = iobase.ThreadsafeRestrictionTracker(
+          restriction_tracker)
+      additional_kwargs[restriction_tracker_param] = (
+          iobase.RestrictionTrackerView(self.threadsafe_restriction_tracker))
+
+      if self.watermark_estimator:
+        # The watermark estimator needs to be reset for every element.
+        self.watermark_estimator.reset()
+        additional_kwargs[self.watermark_estimator_param] = (
+            self.watermark_estimator)
       try:
         self.current_windowed_value = windowed_value
-        self.restriction_tracker = restriction_tracker
         return self._invoke_process_per_window(
             windowed_value, additional_args, additional_kwargs,
             output_processor)
       finally:
-        self.restriction_tracker = None
+        self.threadsafe_restriction_tracker = None
         self.current_windowed_value = windowed_value
 
     elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
@@ -664,24 +684,34 @@
           windowed_value, self.process_method(*args_for_process))
 
     if self.is_splittable:
-      deferred_status = self.restriction_tracker.deferred_status()
+      # TODO: Consider calling check_done right after SDF.Process() finishes.
+      # To do this, we need to know that the DoFn currently being invoked is
+      # ProcessSizedElementAndRestriction.
+      self.threadsafe_restriction_tracker.check_done()
+      deferred_status = self.threadsafe_restriction_tracker.deferred_status()
+      output_watermark = None
+      if self.watermark_estimator:
+        output_watermark = self.watermark_estimator.current_watermark()
       if deferred_status:
         deferred_restriction, deferred_watermark = deferred_status
         element = windowed_value.value
         size = self.signature.get_restriction_provider().restriction_size(
             element, deferred_restriction)
-        return (
+        return ((
             windowed_value.with_value(((element, deferred_restriction), size)),
-            deferred_watermark)
+            output_watermark), deferred_watermark)
 
   def try_split(self, fraction):
-    restriction_tracker = self.restriction_tracker
+    restriction_tracker = self.threadsafe_restriction_tracker
     current_windowed_value = self.current_windowed_value
     if restriction_tracker and current_windowed_value:
       # Temporary workaround for [BEAM-7473]: get current_watermark before
       # split, in case watermark gets advanced before getting split results.
       # In worst case, current_watermark is always stale, which is ok.
-      current_watermark = restriction_tracker.current_watermark()
+      if self.watermark_estimator:
+        current_watermark = self.watermark_estimator.current_watermark()
+      else:
+        current_watermark = None
       split = restriction_tracker.try_split(fraction)
       if split:
         primary, residual = split
@@ -690,15 +720,13 @@
         primary_size = restriction_provider.restriction_size(element, primary)
         residual_size = restriction_provider.restriction_size(element, residual)
         return (
-            (self.current_windowed_value.with_value(
-                ((element, primary), primary_size)),
-             None),
-            (self.current_windowed_value.with_value(
-                ((element, residual), residual_size)),
-             current_watermark))
+            ((self.current_windowed_value.with_value((
+                (element, primary), primary_size)), None), None),
+            ((self.current_windowed_value.with_value((
+                (element, residual), residual_size)), current_watermark), None))
 
   def current_element_progress(self):
-    restriction_tracker = self.restriction_tracker
+    restriction_tracker = self.threadsafe_restriction_tracker
     if restriction_tracker:
       return restriction_tracker.current_progress()
 
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
index 946ef34..fd04d4c 100644
--- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
@@ -51,6 +51,9 @@
   def create_tracker(self, restriction):
     return OffsetRestrictionTracker(restriction)
 
+  def restriction_size(self, element, restriction):
+    return restriction.size()
+
 
 class ReadFiles(DoFn):
 
@@ -63,12 +66,11 @@
       restriction_tracker=DoFn.RestrictionParam(ReadFilesProvider()),
       *args, **kwargs):
     file_name = element
-    assert isinstance(restriction_tracker, OffsetRestrictionTracker)
 
     with open(file_name, 'rb') as file:
-      pos = restriction_tracker.start_position()
-      if restriction_tracker.start_position() > 0:
-        file.seek(restriction_tracker.start_position() - 1)
+      pos = restriction_tracker.current_restriction().start
+      if restriction_tracker.current_restriction().start > 0:
+        file.seek(restriction_tracker.current_restriction().start - 1)
         line = file.readline()
         pos = pos - 1 + len(line)
 
@@ -104,6 +106,9 @@
   def split(self, element, restriction):
     return [restriction,]
 
+  def restriction_size(self, element, restriction):
+    return restriction.size()
+
 
 class ExpandStrings(DoFn):
 
@@ -118,10 +123,9 @@
     side.extend(side1)
     side.extend(side2)
     side.extend(side3)
-    assert isinstance(restriction_tracker, OffsetRestrictionTracker)
     side = list(side)
-    for i in range(restriction_tracker.start_position(),
-                   restriction_tracker.stop_position()):
+    for i in range(restriction_tracker.current_restriction().start,
+                   restriction_tracker.current_restriction().stop):
       if restriction_tracker.try_claim(i):
         if not side:
           yield (
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index e97d65e..377ceb7 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -319,6 +319,9 @@
           line = f.readline()
       self.assertSetEqual(lines_actual, lines_expected)
 
+    def test_sdf_with_watermark_tracking(self):
+      raise unittest.SkipTest("BEAM-2939")
+
     def test_sdf_with_sdf_initiated_checkpointing(self):
       raise unittest.SkipTest("BEAM-2939")
 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 2204a24..b7929cb 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -41,6 +41,7 @@
 from tenacity import stop_after_attempt
 
 import apache_beam as beam
+from apache_beam.io import iobase
 from apache_beam.io import restriction_trackers
 from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.execution import MetricKey
@@ -56,9 +57,11 @@
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.tools import utils
+from apache_beam.transforms import core
 from apache_beam.transforms import environments
 from apache_beam.transforms import userstate
 from apache_beam.transforms import window
+from apache_beam.utils import timestamp
 
 if statesampler.FAST_SAMPLER:
   DEFAULT_SAMPLING_PERIOD_MS = statesampler.DEFAULT_SAMPLING_PERIOD_MS
@@ -423,17 +426,14 @@
       assert_that(actual, is_buffered_correctly)
 
   def test_sdf(self):
-
     class ExpandingStringsDoFn(beam.DoFn):
       def process(
           self,
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(
               ExpandStringsProvider())):
-        assert isinstance(
-            restriction_tracker,
-            restriction_trackers.OffsetRestrictionTracker), restriction_tracker
-        cur = restriction_tracker.start_position()
+        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           yield element[cur]
           cur += 1
@@ -446,6 +446,56 @@
           | beam.ParDo(ExpandingStringsDoFn()))
       assert_that(actual, equal_to(list(''.join(data))))
 
+  def test_sdf_with_check_done_failed(self):
+    class ExpandingStringsDoFn(beam.DoFn):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              ExpandStringsProvider())):
+        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
+        while restriction_tracker.try_claim(cur):
+          yield element[cur]
+          cur += 1
+          return
+    with self.assertRaises(Exception):
+      with self.create_pipeline() as p:
+        data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+        _ = (
+            p
+            | beam.Create(data)
+            | beam.ParDo(ExpandingStringsDoFn()))
+
+  def test_sdf_with_watermark_tracking(self):
+
+    class ExpandingStringsDoFn(beam.DoFn):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              ExpandStringsProvider()),
+          watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
+              core.WatermarkEstimator())):
+        cur = restriction_tracker.current_restriction().start
+        start = cur
+        while restriction_tracker.try_claim(cur):
+          watermark_estimator.set_watermark(timestamp.Timestamp(micros=cur))
+          assert watermark_estimator.current_watermark().micros == start
+          yield element[cur]
+          if cur % 2 == 1:
+            restriction_tracker.defer_remainder(timestamp.Duration(micros=5))
+            return
+          cur += 1
+
+    with self.create_pipeline() as p:
+      data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+      actual = (
+          p
+          | beam.Create(data)
+          | beam.ParDo(ExpandingStringsDoFn()))
+      assert_that(actual, equal_to(list(''.join(data))))
+
   def test_sdf_with_sdf_initiated_checkpointing(self):
 
     counter = beam.metrics.Metrics.counter('ns', 'my_counter')
@@ -456,10 +506,8 @@
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(
               ExpandStringsProvider())):
-        assert isinstance(
-            restriction_tracker,
-            restriction_trackers.OffsetRestrictionTracker), restriction_tracker
-        cur = restriction_tracker.start_position()
+        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           counter.inc()
           yield element[cur]
@@ -1123,6 +1171,9 @@
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_sdf_with_watermark_tracking(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
 
 class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
 
@@ -1142,6 +1193,9 @@
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_sdf_with_watermark_tracking(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
 
 class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
 
@@ -1172,6 +1226,9 @@
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_sdf_with_watermark_tracking(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
 
 class FnApiRunnerSplitTest(unittest.TestCase):
 
@@ -1340,7 +1397,7 @@
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider())):
         to_emit = []
-        cur = restriction_tracker.start_position()
+        cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           to_emit.append((element, cur))
           element_counter.increment()
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 8439c8f..b3440df 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -32,6 +32,7 @@
 from builtins import object
 
 from future.utils import itervalues
+from google.protobuf import duration_pb2
 from google.protobuf import timestamp_pb2
 
 import apache_beam as beam
@@ -704,8 +705,7 @@
               ) = split
               if element_primary:
                 split_response.primary_roots.add().CopyFrom(
-                    self.delayed_bundle_application(
-                        *element_primary).application)
+                    self.bundle_application(*element_primary))
               if element_residual:
                 split_response.residual_roots.add().CopyFrom(
                     self.delayed_bundle_application(*element_residual))
@@ -718,22 +718,39 @@
     return split_response
 
   def delayed_bundle_application(self, op, deferred_remainder):
-    transform_id, main_input_tag, main_input_coder, outputs = op.input_info
     # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
-    element_and_restriction, watermark = deferred_remainder
-    if watermark:
-      proto_watermark = timestamp_pb2.Timestamp()
-      proto_watermark.FromMicroseconds(watermark.micros)
-      output_watermarks = {output: proto_watermark for output in outputs}
+    ((element_and_restriction, output_watermark),
+     deferred_watermark) = deferred_remainder
+    if deferred_watermark:
+      assert isinstance(deferred_watermark, timestamp.Duration)
+      proto_deferred_watermark = duration_pb2.Duration()
+      proto_deferred_watermark.FromMicroseconds(deferred_watermark.micros)
+    else:
+      proto_deferred_watermark = None
+    return beam_fn_api_pb2.DelayedBundleApplication(
+        requested_time_delay=proto_deferred_watermark,
+        application=self.construct_bundle_application(
+            op, output_watermark, element_and_restriction))
+
+  def bundle_application(self, op, primary):
+    ((element_and_restriction, output_watermark),
+     _) = primary
+    return self.construct_bundle_application(
+        op, output_watermark, element_and_restriction)
+
+  def construct_bundle_application(self, op, output_watermark, element):
+    transform_id, main_input_tag, main_input_coder, outputs = op.input_info
+    if output_watermark:
+      proto_output_watermark = timestamp_pb2.Timestamp()
+      proto_output_watermark.FromMicroseconds(output_watermark.micros)
+      output_watermarks = {output: proto_output_watermark for output in outputs}
     else:
       output_watermarks = None
-    return beam_fn_api_pb2.DelayedBundleApplication(
-        application=beam_fn_api_pb2.BundleApplication(
-            transform_id=transform_id,
-            input_id=main_input_tag,
-            output_watermarks=output_watermarks,
-            element=main_input_coder.get_impl().encode_nested(
-                element_and_restriction)))
+    return beam_fn_api_pb2.BundleApplication(
+        transform_id=transform_id,
+        input_id=main_input_tag,
+        output_watermarks=output_watermarks,
+        element=main_input_coder.get_impl().encode_nested(element))
 
   def metrics(self):
     # DEPRECATED
diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline.py b/sdks/python/apache_beam/testing/synthetic_pipeline.py
index 50740ba..fbef112 100644
--- a/sdks/python/apache_beam/testing/synthetic_pipeline.py
+++ b/sdks/python/apache_beam/testing/synthetic_pipeline.py
@@ -523,7 +523,7 @@
       element,
       restriction_tracker=beam.DoFn.RestrictionParam(
           SyntheticSDFSourceRestrictionProvider())):
-    cur = restriction_tracker.start_position()
+    cur = restriction_tracker.current_restriction().start
     while restriction_tracker.try_claim(cur):
       r = np.random.RandomState(cur)
       time.sleep(element['sleep_per_input_record_sec'])
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 148caae..06fd201 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -63,6 +63,7 @@
 from apache_beam.typehints.decorators import get_type_hints
 from apache_beam.typehints.trivial_inference import element_type
 from apache_beam.typehints.typehints import is_consistent_with
+from apache_beam.utils import timestamp
 from apache_beam.utils import urns
 
 try:
@@ -91,7 +92,8 @@
     'Flatten',
     'Create',
     'Impulse',
-    'RestrictionProvider'
+    'RestrictionProvider',
+    'WatermarkEstimator'
     ]
 
 # Type variables
@@ -242,6 +244,8 @@
   def create_tracker(self, restriction):
     """Produces a new ``RestrictionTracker`` for the given restriction.
 
+    This API is required to be implemented.
+
     Args:
       restriction: an object that defines a restriction as identified by a
         Splittable ``DoFn`` that utilizes the current ``RestrictionProvider``.
@@ -252,7 +256,10 @@
     raise NotImplementedError
 
   def initial_restriction(self, element):
-    """Produces an initial restriction for the given element."""
+    """Produces an initial restriction for the given element.
+
+    This API is required to be implemented.
+    """
     raise NotImplementedError
 
   def split(self, element, restriction):
@@ -262,6 +269,9 @@
     reading input element for each of the returned restrictions should be the
     same as the total set of elements produced by reading the input element for
     the input restriction.
+
+    This API is optional if ``split_and_size`` has been implemented.
+
     """
     yield restriction
 
@@ -281,11 +291,16 @@
 
     By default, asks a newly-created restriction tracker for the default size
     of the restriction.
+
+    This API is required to be implemented.
     """
-    return self.create_tracker(restriction).default_size()
+    raise NotImplementedError
 
   def split_and_size(self, element, restriction):
     """Like split, but also does sizing, returning (restriction, size) pairs.
+
+    This API is optional if ``split`` and ``restriction_size`` have been
+    implemented.
     """
     for part in self.split(element, restriction):
       yield part, self.restriction_size(element, part)
@@ -379,6 +394,43 @@
     return None
 
 
+class WatermarkEstimator(object):
+  """Tracks the output_watermark inside a DoFn.process() call, typically once
+  per <element, restriction> pair when a splittable DoFn (SDF) runs in
+  streaming.
+
+  There are 3 APIs in this class: set_watermark, current_watermark and reset
+  with default implementations.
+
+  TODO(BEAM-8537): Create WatermarkEstimatorProvider to support different types.
+  """
+  def __init__(self):
+    self._watermark = None
+
+  def set_watermark(self, watermark):
+    """Update tracking output_watermark with latest output_watermark.
+    This function is called inside an SDF.Process() to track the watermark of
+    output element.
+
+    Args:
+      watermark: the `timestamp.Timestamp` of current output element.
+    """
+    if not isinstance(watermark, timestamp.Timestamp):
+      raise ValueError('watermark should be a object of timestamp.Timestamp')
+    if self._watermark is None:
+      self._watermark = watermark
+    else:
+      self._watermark = min(self._watermark, watermark)
+
+  def current_watermark(self):
+    """Get current output_watermark. This function is called by system."""
+    return self._watermark
+
+  def reset(self):
+    """Reset the currently tracked watermark to None."""
+    self._watermark = None
+
+
 class _DoFnParam(object):
   """DoFn parameter."""
 
@@ -459,6 +511,17 @@
     del self._callbacks[:]
 
 
+class _WatermarkEstimatorParam(_DoFnParam):
+  """WatermarkEstimator DoFn parameter."""
+
+  def __init__(self, watermark_estimator):
+    if not isinstance(watermark_estimator, WatermarkEstimator):
+      raise ValueError('DoFn.WatermarkEstimatorParam expected'
+                       'WatermarkEstimator object.')
+    self.watermark_estimator = watermark_estimator
+    self.param_id = 'WatermarkEstimator'
+
+
 class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   """A function object used by a transform with custom processing.
 
@@ -477,7 +540,7 @@
   TimestampParam = _DoFnParam('TimestampParam')
   WindowParam = _DoFnParam('WindowParam')
   PaneInfoParam = _DoFnParam('PaneInfoParam')
-  WatermarkReporterParam = _DoFnParam('WatermarkReporterParam')
+  WatermarkEstimatorParam = _WatermarkEstimatorParam
   BundleFinalizerParam = _BundleFinalizerParam
   KeyParam = _DoFnParam('KeyParam')
 
@@ -489,7 +552,7 @@
   TimerParam = _TimerDoFnParam
 
   DoFnProcessParams = [ElementParam, SideInputParam, TimestampParam,
-                       WindowParam, WatermarkReporterParam, PaneInfoParam,
+                       WindowParam, WatermarkEstimatorParam, PaneInfoParam,
                        BundleFinalizerParam, KeyParam, StateParam, TimerParam]
 
   RestrictionParam = _RestrictionDoFnParam
@@ -522,7 +585,7 @@
     ``DoFn.RestrictionParam``: an ``iobase.RestrictionTracker`` will be
     provided here to allow treatment as a Splittable ``DoFn``. The restriction
     tracker will be derived from the restriction provider in the parameter.
-    ``DoFn.WatermarkReporterParam``: a function that can be used to report
+    ``DoFn.WatermarkEstimatorParam``: an estimator that can be used to track
     output watermark of Splittable ``DoFn`` implementations.
 
     Args:
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
new file mode 100644
index 0000000..1a27bd2
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for core module."""
+
+from __future__ import absolute_import
+
+import unittest
+
+from apache_beam.transforms.core import WatermarkEstimator
+from apache_beam.utils.timestamp import Timestamp
+
+
+class WatermarkEstimatorTest(unittest.TestCase):
+
+  def test_set_watermark(self):
+    watermark_estimator = WatermarkEstimator()
+    self.assertEqual(watermark_estimator.current_watermark(), None)
+    # set_watermark should only accept timestamp.Timestamp.
+    with self.assertRaises(ValueError):
+      watermark_estimator.set_watermark(0)
+
+    # watermark_estimator should always keep minimal timestamp.
+    watermark_estimator.set_watermark(Timestamp(100))
+    self.assertEqual(watermark_estimator.current_watermark(), 100)
+    watermark_estimator.set_watermark(Timestamp(150))
+    self.assertEqual(watermark_estimator.current_watermark(), 100)
+    watermark_estimator.set_watermark(Timestamp(50))
+    self.assertEqual(watermark_estimator.current_watermark(), 50)
+
+  def test_reset(self):
+    watermark_estimator = WatermarkEstimator()
+    watermark_estimator.set_watermark(Timestamp(100))
+    self.assertEqual(watermark_estimator.current_watermark(), 100)
+    watermark_estimator.reset()
+    self.assertEqual(watermark_estimator.current_watermark(), None)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index bb7e522..7a87e60 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -241,7 +241,8 @@
                target_batch_overhead=.1,
                target_batch_duration_secs=1,
                variance=0.25,
-               clock=time.time):
+               clock=time.time,
+               ignore_first_n_seen_per_batch_size=0):
     if min_batch_size > max_batch_size:
       raise ValueError("Minimum (%s) must not be greater than maximum (%s)" % (
           min_batch_size, max_batch_size))
@@ -254,6 +255,9 @@
     if not (target_batch_overhead or target_batch_duration_secs):
       raise ValueError("At least one of target_batch_overhead or "
                        "target_batch_duration_secs must be positive.")
+    if ignore_first_n_seen_per_batch_size < 0:
+      raise ValueError('ignore_first_n_seen_per_batch_size (%s) must be non '
+                       'negative' % (ignore_first_n_seen_per_batch_size))
     self._min_batch_size = min_batch_size
     self._max_batch_size = max_batch_size
     self._target_batch_overhead = target_batch_overhead
@@ -262,6 +266,10 @@
     self._clock = clock
     self._data = []
     self._ignore_next_timing = False
+    self._ignore_first_n_seen_per_batch_size = (
+        ignore_first_n_seen_per_batch_size)
+    self._batch_size_num_seen = {}
+    self._replay_last_batch_size = None
 
     self._size_distribution = Metrics.distribution(
         'BatchElements', 'batch_size')
@@ -279,7 +287,7 @@
     For example, the first emit of a ParDo operation is known to be anomalous
     due to setup that may occur.
     """
-    self._ignore_next_timing = False
+    self._ignore_next_timing = True
 
   @contextlib.contextmanager
   def record_time(self, batch_size):
@@ -290,8 +298,11 @@
     self._size_distribution.update(batch_size)
     self._time_distribution.update(int(elapsed_msec))
     self._remainder_msecs = elapsed_msec - int(elapsed_msec)
+    # If we ignore the next timing, replay the batch size to get accurate
+    # timing.
     if self._ignore_next_timing:
       self._ignore_next_timing = False
+      self._replay_last_batch_size = batch_size
     else:
       self._data.append((batch_size, elapsed))
       if len(self._data) >= self._MAX_DATA_POINTS:
@@ -364,7 +375,7 @@
   except ImportError:
     linear_regression = linear_regression_no_numpy
 
-  def next_batch_size(self):
+  def _calculate_next_batch_size(self):
     if self._min_batch_size == self._max_batch_size:
       return self._min_batch_size
     elif len(self._data) < 1:
@@ -414,6 +425,21 @@
 
     return int(max(self._min_batch_size + jitter, min(target, cap)))
 
+  def next_batch_size(self):
+    # Check if we should replay a previous batch size due to it not being
+    # recorded.
+    if self._replay_last_batch_size:
+      result = self._replay_last_batch_size
+      self._replay_last_batch_size = None
+    else:
+      result = self._calculate_next_batch_size()
+
+    seen_count = self._batch_size_num_seen.get(result, 0) + 1
+    if seen_count <= self._ignore_first_n_seen_per_batch_size:
+      self.ignore_next_timing()
+    self._batch_size_num_seen[result] = seen_count
+    return result
+
 
 class _GlobalWindowsBatchingDoFn(DoFn):
   def __init__(self, batch_size_estimator):
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 4588c32..6ac05d0 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -157,6 +157,60 @@
     self.assertLess(
         max(stable_set), expected_target + expected_target * variance)
 
+  def test_ignore_first_n_batch_size(self):
+    clock = FakeClock()
+    batch_estimator = util._BatchSizeEstimator(
+        clock=clock, ignore_first_n_seen_per_batch_size=2)
+
+    expected_sizes = [
+        1, 1, 1, 2, 2, 2, 4, 4, 4, 8, 8, 8, 16, 16, 16, 32, 32, 32, 64, 64, 64
+    ]
+    actual_sizes = []
+    for i in range(len(expected_sizes)):
+      actual_sizes.append(batch_estimator.next_batch_size())
+      with batch_estimator.record_time(actual_sizes[-1]):
+        if i % 3 == 2:
+          clock.sleep(0.01)
+        else:
+          clock.sleep(1)
+
+    self.assertEqual(expected_sizes, actual_sizes)
+
+    # Check we only record the third timing.
+    expected_data_batch_sizes = [1, 2, 4, 8, 16, 32, 64]
+    actual_data_batch_sizes = [x[0] for x in batch_estimator._data]
+    self.assertEqual(expected_data_batch_sizes, actual_data_batch_sizes)
+    expected_data_timing = [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
+    for i in range(len(expected_data_timing)):
+      self.assertAlmostEqual(
+          expected_data_timing[i], batch_estimator._data[i][1])
+
+  def test_ignore_next_timing(self):
+    clock = FakeClock()
+    batch_estimator = util._BatchSizeEstimator(clock=clock)
+    batch_estimator.ignore_next_timing()
+
+    expected_sizes = [1, 1, 2, 4, 8, 16]
+    actual_sizes = []
+    for i in range(len(expected_sizes)):
+      actual_sizes.append(batch_estimator.next_batch_size())
+      with batch_estimator.record_time(actual_sizes[-1]):
+        if i == 0:
+          clock.sleep(1)
+        else:
+          clock.sleep(0.01)
+
+    self.assertEqual(expected_sizes, actual_sizes)
+
+    # Check the first record_time was skipped.
+    expected_data_batch_sizes = [1, 2, 4, 8, 16]
+    actual_data_batch_sizes = [x[0] for x in batch_estimator._data]
+    self.assertEqual(expected_data_batch_sizes, actual_data_batch_sizes)
+    expected_data_timing = [0.01, 0.01, 0.01, 0.01, 0.01]
+    for i in range(len(expected_data_timing)):
+      self.assertAlmostEqual(
+          expected_data_timing[i], batch_estimator._data[i][1])
+
   def _run_regression_test(self, linear_regression_fn, test_outliers):
     xs = [random.random() for _ in range(10)]
     ys = [2*x + 1 for x in xs]
diff --git a/sdks/python/apache_beam/transforms/window_test.py b/sdks/python/apache_beam/transforms/window_test.py
index dda651d..0d5e14f 100644
--- a/sdks/python/apache_beam/transforms/window_test.py
+++ b/sdks/python/apache_beam/transforms/window_test.py
@@ -22,6 +22,8 @@
 import unittest
 from builtins import range
 
+from nose.plugins.attrib import attr
+
 import apache_beam as beam
 from apache_beam.runners import pipeline_context
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -281,6 +283,40 @@
       assert_that(mean_per_window, equal_to([(0, 2.0), (1, 7.0)]),
                   label='assert:mean')
 
+  @attr('ValidatesRunner')
+  def test_window_assignment_idempotency(self):
+    # Applies the identical FixedWindows(2) assignment three times before
+    # grouping; the repeats must not change which values end up together.
+    with TestPipeline() as p:
+      pcoll = self.timestamped_key_values(p, 'key', 0, 1, 2, 3, 4)
+      result = (pcoll
+                | 'window' >> WindowInto(FixedWindows(2))
+                | 'same window' >> WindowInto(FixedWindows(2))
+                | 'same window again' >> WindowInto(FixedWindows(2))
+                | GroupByKey())
+
+      assert_that(result, equal_to([('key', [0, 1]),
+                                    ('key', [2, 3]),
+                                    ('key', [4])]))
+
+  @attr('ValidatesRunner')
+  def test_window_assignment_through_multiple_gbk_idempotency(self):
+    # Same check with a GroupByKey after every WindowInto: the three
+    # GroupByKeys show up as three levels of list nesting per window.
+    with TestPipeline() as p:
+      pcoll = self.timestamped_key_values(p, 'key', 0, 1, 2, 3, 4)
+      result = (pcoll
+                | 'window' >> WindowInto(FixedWindows(2))
+                | 'gbk' >> GroupByKey()
+                | 'same window' >> WindowInto(FixedWindows(2))
+                | 'another gbk' >> GroupByKey()
+                | 'same window again' >> WindowInto(FixedWindows(2))
+                | 'gbk again' >> GroupByKey())
+
+      assert_that(result, equal_to([('key', [[[0, 1]]]),
+                                    ('key', [[[2, 3]]]),
+                                    ('key', [[[4]]])]))
+
 
 class RunnerApiTest(unittest.TestCase):
 
diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py
index 9bccdfd..a3f3abf 100644
--- a/sdks/python/apache_beam/utils/timestamp.py
+++ b/sdks/python/apache_beam/utils/timestamp.py
@@ -25,6 +25,7 @@
 
 import datetime
 import functools
+import time
 from builtins import object
 
 import dateutil.parser
@@ -76,6 +77,10 @@
     return Timestamp(seconds)
 
   @staticmethod
+  def now():
+    return Timestamp(seconds=time.time())
+
+  @staticmethod
   def _epoch_datetime_utc():
     return datetime.datetime.fromtimestamp(0, pytz.utc)
 
@@ -173,6 +178,8 @@
     return self + other
 
   def __sub__(self, other):
+    if isinstance(other, Timestamp):
+      return Duration(micros=self.micros - other.micros)
     other = Duration.of(other)
     return Timestamp(micros=self.micros - other.micros)
 
diff --git a/sdks/python/apache_beam/utils/timestamp_test.py b/sdks/python/apache_beam/utils/timestamp_test.py
index d26d561..2a4d454 100644
--- a/sdks/python/apache_beam/utils/timestamp_test.py
+++ b/sdks/python/apache_beam/utils/timestamp_test.py
@@ -100,6 +100,7 @@
     self.assertEqual(Timestamp(123) - Duration(456), -333)
     self.assertEqual(Timestamp(1230) % 456, 318)
     self.assertEqual(Timestamp(1230) % Duration(456), 318)
+    self.assertEqual(Timestamp(123) - Timestamp(100), 23)
 
     # Check that direct comparison of Timestamp and Duration is allowed.
     self.assertTrue(Duration(123) == Timestamp(123))
@@ -116,6 +117,7 @@
     self.assertEqual((Timestamp(123) - Duration(456)).__class__, Timestamp)
     self.assertEqual((Timestamp(1230) % 456).__class__, Duration)
     self.assertEqual((Timestamp(1230) % Duration(456)).__class__, Duration)
+    self.assertEqual((Timestamp(123) - Timestamp(100)).__class__, Duration)
 
     # Unsupported operations.
     with self.assertRaises(TypeError):
@@ -159,6 +161,10 @@
     self.assertEqual('Timestamp(-999999999)',
                      str(Timestamp(-999999999)))
 
+  def test_now(self):
+    # now() must return a Timestamp instance (its value is clock-dependent).
+    self.assertIsInstance(Timestamp.now(), Timestamp)
+
 
 class DurationTest(unittest.TestCase):
 
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index e3794ba..bc77fd8 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -183,6 +183,7 @@
   '_TimerDoFnParam',
   '_BundleFinalizerParam',
   '_RestrictionDoFnParam',
+  '_WatermarkEstimatorParam',
 
   # Sphinx cannot find this py:class reference target
   'typing.Generic',
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 7eea64c..9f1a9f3 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -228,7 +228,10 @@
     python_requires=python_requires,
     test_suite='nose.collector',
     setup_requires=['pytest_runner'],
-    tests_require=REQUIRED_TEST_PACKAGES,
+    # Concatenate into one flat list of requirement strings instead of nesting
+    # two lists. NOTE(review): assumes INTERACTIVE_BEAM is a list, as
+    # REQUIRED_TEST_PACKAGES appears to be -- confirm at its definition.
+    tests_require=REQUIRED_TEST_PACKAGES + INTERACTIVE_BEAM,
     extras_require={
         'docs': ['Sphinx>=1.5.2,<2.0'],
         'test': REQUIRED_TEST_PACKAGES,
diff --git a/sdks/python/test-suites/tox/py37/build.gradle b/sdks/python/test-suites/tox/py37/build.gradle
index 2a57ca9..c9c99e6 100644
--- a/sdks/python/test-suites/tox/py37/build.gradle
+++ b/sdks/python/test-suites/tox/py37/build.gradle
@@ -41,9 +41,6 @@
 toxTask "testPy37Cython", "py37-cython"
 test.dependsOn testPy37Cython
 
-toxTask "testPy37Interactive", "py37-interactive"
-test.dependsOn testPy37Interactive
-
 // Ensure that testPy37Cython runs exclusively to other tests. This line is not
 // actually required, since gradle doesn't do parallel execution within a
 // project.
@@ -60,7 +57,6 @@
 task preCommitPy37() {
     dependsOn "testPy37Gcp"
     dependsOn "testPy37Cython"
-    dependsOn "testPy37Interactive"
 }
 
 task preCommitPy37Pytest {
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index fe3f65a..d227519 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -307,10 +307,3 @@
   coverage report --skip-covered
   # Generate report in xml format
   coverage xml
-
-[testenv:py37-interactive]
-setenv =
-  RUN_SKIPPED_PY3_TESTS=0
-extras = test,interactive
-commands =
-  python setup.py nosetests --ignore-files '.*py3[8-9]\.py$' {posargs}