diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index a603d16..4def350 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -1148,6 +1148,15 @@
     }
   }
 
+  private Windmill.WorkItemCommitRequest.Builder initializeOutputBuilder(
+      final ByteString key, final Windmill.WorkItem workItem) {
+    // Identity fields every commit for this work item carries, whether full or truncated.
+    Windmill.WorkItemCommitRequest.Builder commit = Windmill.WorkItemCommitRequest.newBuilder();
+    commit.setKey(key).setShardingKey(workItem.getShardingKey());
+    commit.setWorkToken(workItem.getWorkToken()).setCacheToken(workItem.getCacheToken());
+    return commit;
+  }
+
   private void process(
       final SdkWorkerHarness worker,
       final ComputationState computationState,
@@ -1164,12 +1173,7 @@
     DataflowWorkerLoggingMDC.setStageName(computationId);
     LOG.debug("Starting processing for {}:\n{}", computationId, work);
 
-    Windmill.WorkItemCommitRequest.Builder outputBuilder =
-        Windmill.WorkItemCommitRequest.newBuilder()
-            .setKey(key)
-            .setShardingKey(workItem.getShardingKey())
-            .setWorkToken(workItem.getWorkToken())
-            .setCacheToken(workItem.getCacheToken());
+    Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem);
 
     // Before any processing starts, call any pending OnCommit callbacks.  Nothing that requires
     // cleanup should be done before this, since we might exit early here.
@@ -1334,20 +1338,22 @@
       WorkItemCommitRequest commitRequest = outputBuilder.build();
       int byteLimit = maxWorkItemCommitBytes;
       int commitSize = commitRequest.getSerializedSize();
+      int estimatedCommitSize = commitSize < 0 ? Integer.MAX_VALUE : commitSize;
+      // commitSize < 0 signals int overflow; MAX_VALUE alone would not exceed a MAX_VALUE limit.
       // Detect overflow of integer serialized size or if the byte limit was exceeded.
-      windmillMaxObservedWorkItemCommitBytes.addValue(
-          commitSize < 0 ? Integer.MAX_VALUE : commitSize);
-      if (commitSize < 0) {
-        throw KeyCommitTooLargeException.causedBy(computationId, byteLimit, commitRequest);
-      } else if (commitSize > byteLimit) {
-        // Once supported, we should communicate the desired truncation for the commit to the
-        // streaming engine. For now we report the error but attempt the commit so that it will be
-        // truncated by the streaming engine backend.
+      windmillMaxObservedWorkItemCommitBytes.addValue(estimatedCommitSize);
+      if (commitSize < 0 || estimatedCommitSize > byteLimit) {
         KeyCommitTooLargeException e =
             KeyCommitTooLargeException.causedBy(computationId, byteLimit, commitRequest);
         reportFailure(computationId, workItem, e);
         LOG.error(e.toString());
+
+        // Drop the current request in favor of a new, minimal one requesting truncation.
+        // Messages, timers, counters, and other commit content will not be used by the service,
+        // so we're purposefully dropping them here.
+        commitRequest = buildWorkItemTruncationRequest(key, workItem, estimatedCommitSize);
       }
+
       commitQueue.put(new Commit(commitRequest, computationState, work));
 
       // Compute shuffle and state byte statistics these will be flushed asynchronously.
@@ -1442,6 +1448,14 @@
     }
   }
 
+  private WorkItemCommitRequest buildWorkItemTruncationRequest(
+      final ByteString key, final Windmill.WorkItem workItem, final int estimatedCommitSize) {
+    return initializeOutputBuilder(key, workItem)
+        .setExceedsMaxWorkItemCommitBytes(true)
+        .setEstimatedWorkItemCommitBytes(estimatedCommitSize)
+        .build();
+  }
+
   private void commitLoop() {
     Map<ComputationState, Windmill.ComputationCommitWorkRequest.Builder> computationRequestMap =
         new HashMap<>();
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index 356dddb..2cb5ada 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -562,37 +562,70 @@
 
   private WorkItemCommitRequest.Builder makeExpectedOutput(
       int index, long timestamp, String key, String outKey) throws Exception {
+    StringBuilder expectedCommitRequestBuilder = initializeExpectedCommitRequest(key, index);
+    appendCommitOutputMessages(expectedCommitRequestBuilder, index, timestamp, outKey);
+
     return setMessagesMetadata(
         PaneInfo.NO_FIRING,
         intervalWindowBytes(DEFAULT_WINDOW),
-        parseCommitRequest(
-            "key: \""
-                + key
-                + "\" "
-                + "sharding_key: 17 "
-                + "work_token: "
-                + index
-                + " "
-                + "cache_token: 3 "
-                + "output_messages {"
-                + "  destination_stream_id: \""
-                + DEFAULT_DESTINATION_STREAM_ID
-                + "\""
-                + "  bundles {"
-                + "    key: \""
-                + outKey
-                + "\""
-                + "    messages {"
-                + "      timestamp: "
-                + timestamp
-                + "      data: \""
-                + dataStringForIndex(index)
-                + "\""
-                + "      metadata: \"\""
-                + "    }"
-                + "    messages_ids: \"\""
-                + "  }"
-                + "}"));
+        parseCommitRequest(expectedCommitRequestBuilder.toString()));
+  }
+
+  private WorkItemCommitRequest.Builder makeExpectedTruncationRequestOutput(
+      int index, String key, long estimatedSize) throws Exception {
+    // Common header fields first, then the two truncation markers; no payload is expected.
+    StringBuilder expectedText =
+        appendCommitTruncationFields(initializeExpectedCommitRequest(key, index), estimatedSize);
+    return parseCommitRequest(expectedText.toString());
+  }
+
+  private StringBuilder initializeExpectedCommitRequest(String key, int index) {
+    // Text-format prefix shared by the expected commits built in this test class. The
+    // sharding_key and cache_token values are fixed literals; the work_token is the
+    // caller-supplied index.
+    return new StringBuilder()
+        .append("key: \"")
+        .append(key)
+        .append("\" ")
+        .append("sharding_key: 17 ")
+        .append("work_token: ")
+        .append(index)
+        .append(" ")
+        .append("cache_token: 3 ");
+  }
+
+  private StringBuilder appendCommitOutputMessages(
+      StringBuilder requestBuilder, int index, long timestamp, String outKey) {
+    // Adds one output_messages entry (one bundle, one message) in proto text format.
+    return requestBuilder
+        .append("output_messages {")
+        .append("  destination_stream_id: \"")
+        .append(DEFAULT_DESTINATION_STREAM_ID)
+        .append("\"")
+        .append("  bundles {")
+        .append("    key: \"")
+        .append(outKey)
+        .append("\"")
+        .append("    messages {")
+        .append("      timestamp: ")
+        .append(timestamp)
+        .append("      data: \"")
+        .append(dataStringForIndex(index))
+        .append("\"")
+        .append("      metadata: \"\"")
+        .append("    }")
+        .append("    messages_ids: \"\"")
+        .append("  }")
+        .append("}");
+  }
+
+  private StringBuilder appendCommitTruncationFields(
+      StringBuilder requestBuilder, long estimatedSize) {
+    // Adds the over-limit flag and the size estimate in proto text format.
+    return requestBuilder
+        .append("exceeds_max_work_item_commit_bytes: true ")
+        .append("estimated_work_item_commit_bytes: ")
+        .append(estimatedSize);
   }
 
   private StreamingComputationConfig makeDefaultStreamingComputationConfig(
@@ -948,64 +981,19 @@
 
     assertEquals(2, result.size());
     assertEquals(makeExpectedOutput(2, 0, "key", "key").build(), result.get(2L));
+
     assertTrue(result.containsKey(1L));
-    assertEquals("large_key", result.get(1L).getKey().toStringUtf8());
-    assertTrue(result.get(1L).getSerializedSize() > 1000);
+    WorkItemCommitRequest largeCommit = result.get(1L);
+    assertEquals("large_key", largeCommit.getKey().toStringUtf8());
+    assertEquals(
+        makeExpectedTruncationRequestOutput(
+                1, "large_key", largeCommit.getEstimatedWorkItemCommitBytes())
+            .build(),
+        largeCommit);
 
-    // Spam worker updates a few times.
-    int maxTries = 10;
-    while (--maxTries > 0) {
-      worker.reportPeriodicWorkerUpdates();
-      Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS);
-    }
-
-    // We should see an exception reported for the large commit but not the small one.
-    ArgumentCaptor<WorkItemStatus> workItemStatusCaptor =
-        ArgumentCaptor.forClass(WorkItemStatus.class);
-    verify(mockWorkUnitClient, atLeast(2)).reportWorkItemStatus(workItemStatusCaptor.capture());
-    List<WorkItemStatus> capturedStatuses = workItemStatusCaptor.getAllValues();
-    boolean foundErrors = false;
-    for (WorkItemStatus status : capturedStatuses) {
-      if (!status.getErrors().isEmpty()) {
-        assertFalse(foundErrors);
-        foundErrors = true;
-        String errorMessage = status.getErrors().get(0).getMessage();
-        assertThat(errorMessage, Matchers.containsString("KeyCommitTooLargeException"));
-      }
-    }
-    assertTrue(foundErrors);
-  }
-
-  @Test
-  public void testKeyCommitTooLargeException_StreamingEngine() throws Exception {
-    KvCoder<String, String> kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of());
-
-    List<ParallelInstruction> instructions =
-        Arrays.asList(
-            makeSourceInstruction(kvCoder),
-            makeDoFnInstruction(new LargeCommitFn(), 0, kvCoder),
-            makeSinkInstruction(kvCoder, 1));
-
-    FakeWindmillServer server = new FakeWindmillServer(errorCollector);
-    server.setExpectedExceptionCount(1);
-
-    StreamingDataflowWorkerOptions options =
-        createTestingPipelineOptions(server, "--experiments=enable_streaming_engine");
-    StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */);
-    worker.setMaxWorkItemCommitBytes(1000);
-    worker.start();
-
-    server.addWorkToOffer(makeInput(1, 0, "large_key"));
-    server.addWorkToOffer(makeInput(2, 0, "key"));
-    server.waitForEmptyWorkQueue();
-
-    Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(1);
-
-    assertEquals(2, result.size());
-    assertEquals(makeExpectedOutput(2, 0, "key", "key").build(), result.get(2L));
-    assertTrue(result.containsKey(1L));
-    assertEquals("large_key", result.get(1L).getKey().toStringUtf8());
-    assertTrue(result.get(1L).getSerializedSize() > 1000);
+    // Check this explicitly since the estimated commit bytes weren't actually
+    // checked against an expected value in the previous step
+    assertTrue(largeCommit.getEstimatedWorkItemCommitBytes() > 1000);
 
     // Spam worker updates a few times.
     int maxTries = 10;
diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
index 5310902..598164e 100644
--- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
+++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
@@ -272,13 +272,16 @@
   optional string state_family = 3;
 }
 
-// next id: 19
+// next id: 24
 message WorkItemCommitRequest {
   required bytes key = 1;
   required fixed64 work_token = 2;
   optional fixed64 sharding_key = 15;
   optional fixed64 cache_token = 16;
 
+  optional bool exceeds_max_work_item_commit_bytes = 20;
+  optional int64 estimated_work_item_commit_bytes = 21;
+
   repeated OutputMessageBundle output_messages = 3;
   repeated PubSubMessageBundle pubsub_messages = 7;
   repeated Timer output_timers = 4;
@@ -290,12 +293,14 @@
   optional SourceState source_state_updates = 12;
   optional int64 source_watermark = 13 [default=-0x8000000000000000];
   optional int64 source_backlog_bytes = 17 [default=-1];
+  optional int64 source_bytes_processed = 22;
+
   repeated WatermarkHold watermark_holds = 14;
 
   // DEPRECATED
   repeated GlobalDataId global_data_id_requests = 9;
 
-  reserved 6;
+  reserved 6, 19, 23;
 }
 
 message ComputationCommitWorkRequest {
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClient.java
index f56e7f2..9051051 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClient.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClient.java
@@ -148,16 +148,18 @@
       LOG.debug("Received InstructionResponse {}", response);
       CompletableFuture<BeamFnApi.InstructionResponse> responseFuture =
           outstandingRequests.remove(response.getInstructionId());
-      if (responseFuture != null) {
-        if (response.getError().isEmpty()) {
-          responseFuture.complete(response);
-        } else {
-          responseFuture.completeExceptionally(
-              new RuntimeException(
-                  String.format(
-                      "Error received from SDK harness for instruction %s: %s",
-                      response.getInstructionId(), response.getError())));
-        }
+      if (responseFuture == null) {
+        LOG.warn("Dropped unknown InstructionResponse {}", response);
+        return;
+      }
+      if (response.getError().isEmpty()) {
+        responseFuture.complete(response);
+      } else {
+        responseFuture.completeExceptionally(
+            new RuntimeException(
+                String.format(
+                    "Error received from SDK harness for instruction %s: %s",
+                    response.getInstructionId(), response.getError())));
       }
     }
 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java
index 1619c22..f85b3c8 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java
@@ -138,12 +138,14 @@
       public void onNext(StateResponse value) {
         LOG.debug("Received StateResponse {}", value);
         CompletableFuture<StateResponse> responseFuture = outstandingRequests.remove(value.getId());
-        if (responseFuture != null) {
-          if (value.getError().isEmpty()) {
-            responseFuture.complete(value);
-          } else {
-            responseFuture.completeExceptionally(new IllegalStateException(value.getError()));
-          }
+        if (responseFuture == null) {
+          LOG.warn("Dropped unknown StateResponse {}", value);
+          return;
+        }
+        if (value.getError().isEmpty()) {
+          responseFuture.complete(value);
+        } else {
+          responseFuture.completeExceptionally(new IllegalStateException(value.getError()));
         }
       }
 
diff --git a/sdks/python/apache_beam/internal/gcp/auth.py b/sdks/python/apache_beam/internal/gcp/auth.py
index 5142a34..8a94acf 100644
--- a/sdks/python/apache_beam/internal/gcp/auth.py
+++ b/sdks/python/apache_beam/internal/gcp/auth.py
@@ -40,6 +40,10 @@
 # information.
 executing_project = None
 
+
+_LOGGER = logging.getLogger(__name__)
+
+
 if GceAssertionCredentials is not None:
   class _GceAssertionCredentials(GceAssertionCredentials):
     """GceAssertionCredentials with retry wrapper.
@@ -101,9 +105,9 @@
       # apitools use urllib with the global timeout. Set it to 60 seconds
       # to prevent network related stuckness issues.
       if not socket.getdefaulttimeout():
-        logging.info("Setting socket default timeout to 60 seconds.")
+        _LOGGER.info("Setting socket default timeout to 60 seconds.")
         socket.setdefaulttimeout(60)
-      logging.info(
+      _LOGGER.info(
           "socket default timeout is %s seconds.", socket.getdefaulttimeout())
 
       cls._credentials = cls._get_service_credentials()
@@ -131,7 +135,7 @@
                       'Credentials.')
         return credentials
       except Exception as e:
-        logging.warning(
+        _LOGGER.warning(
             'Unable to find default credentials to use: %s\n'
             'Connecting anonymously.', e)
         return None
diff --git a/sdks/python/apache_beam/io/filebasedsink.py b/sdks/python/apache_beam/io/filebasedsink.py
index 3ff420f..76143dd 100644
--- a/sdks/python/apache_beam/io/filebasedsink.py
+++ b/sdks/python/apache_beam/io/filebasedsink.py
@@ -45,6 +45,9 @@
 __all__ = ['FileBasedSink']
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class FileBasedSink(iobase.Sink):
   """A sink to a GCS or local files.
 
@@ -210,7 +213,7 @@
                       for file_metadata in mr.metadata_list]
 
     if dst_glob_files:
-      logging.warning('Deleting %d existing files in target path matching: %s',
+      _LOGGER.warning('Deleting %d existing files in target path matching: %s',
                       len(dst_glob_files), self.shard_name_glob_format)
       FileSystems.delete(dst_glob_files)
 
@@ -250,12 +253,12 @@
         raise BeamIOError('src and dst files do not exist. src: %s, dst: %s' % (
             src, dst))
       if not src_exists and dst_exists:
-        logging.debug('src: %s -> dst: %s already renamed, skipping', src, dst)
+        _LOGGER.debug('src: %s -> dst: %s already renamed, skipping', src, dst)
         num_skipped += 1
         continue
       if (src_exists and dst_exists and
           FileSystems.checksum(src) == FileSystems.checksum(dst)):
-        logging.debug('src: %s == dst: %s, deleting src', src, dst)
+        _LOGGER.debug('src: %s == dst: %s, deleting src', src, dst)
         delete_files.append(src)
         continue
 
@@ -284,7 +287,7 @@
                               for i in range(0, len(dst_files), chunk_size)]
 
     if num_shards_to_finalize:
-      logging.info(
+      _LOGGER.info(
           'Starting finalize_write threads with num_shards: %d (skipped: %d), '
           'batches: %d, num_threads: %d',
           num_shards_to_finalize, num_skipped, len(source_file_batch),
@@ -304,11 +307,11 @@
             raise
           for (src, dst), exception in iteritems(exp.exception_details):
             if exception:
-              logging.error(('Exception in _rename_batch. src: %s, '
+              _LOGGER.error(('Exception in _rename_batch. src: %s, '
                              'dst: %s, err: %s'), src, dst, exception)
               exceptions.append(exception)
             else:
-              logging.debug('Rename successful: %s -> %s', src, dst)
+              _LOGGER.debug('Rename successful: %s -> %s', src, dst)
           return exceptions
 
       exception_batches = util.run_using_threadpool(
@@ -324,10 +327,10 @@
       for final_name in dst_files:
         yield final_name
 
-      logging.info('Renamed %d shards in %.2f seconds.', num_shards_to_finalize,
+      _LOGGER.info('Renamed %d shards in %.2f seconds.', num_shards_to_finalize,
                    time.time() - start_time)
     else:
-      logging.warning(
+      _LOGGER.warning(
           'No shards found to finalize. num_shards: %d, skipped: %d',
           num_shards, num_skipped)
 
diff --git a/sdks/python/apache_beam/io/filebasedsink_test.py b/sdks/python/apache_beam/io/filebasedsink_test.py
index 07d0e8e..4c5ef6b 100644
--- a/sdks/python/apache_beam/io/filebasedsink_test.py
+++ b/sdks/python/apache_beam/io/filebasedsink_test.py
@@ -43,6 +43,8 @@
 from apache_beam.transforms.display import DisplayData
 from apache_beam.transforms.display_test import DisplayDataItemMatcher
 
+_LOGGER = logging.getLogger(__name__)
+
 
 # TODO: Refactor code so all io tests are using same library
 # TestCaseWithTempDirCleanup class.
@@ -247,7 +249,7 @@
           'gs://aaa/bbb', 'gs://aaa/bbb/', 'gs://aaa', 'gs://aaa/', 'gs://',
           '/')
     except ValueError:
-      logging.debug('Ignoring test since GCP module is not installed')
+      _LOGGER.debug('Ignoring test since GCP module is not installed')
 
   @mock.patch('apache_beam.io.localfilesystem.os')
   def test_temp_dir_local(self, filesystem_os_mock):
diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py
index 1d9fcdd..14c35bc 100644
--- a/sdks/python/apache_beam/io/fileio.py
+++ b/sdks/python/apache_beam/io/fileio.py
@@ -114,6 +114,9 @@
            'ReadMatches']
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class EmptyMatchTreatment(object):
   """How to treat empty matches in ``MatchAll`` and ``MatchFiles`` transforms.
 
@@ -479,7 +482,7 @@
           str,
           filesystems.FileSystems.join(temp_location,
                                        '.temp%s' % dir_uid))
-      logging.info('Added temporary directory %s', self._temp_directory.get())
+      _LOGGER.info('Added temporary directory %s', self._temp_directory.get())
 
     output = (pcoll
               | beam.ParDo(_WriteUnshardedRecordsFn(
@@ -557,7 +560,7 @@
                                             '',
                                             destination)
 
-      logging.info('Moving temporary file %s to dir: %s as %s. Res: %s',
+      _LOGGER.info('Moving temporary file %s to dir: %s as %s. Res: %s',
                    r.file_name, self.path.get(), final_file_name, r)
 
       final_full_path = filesystems.FileSystems.join(self.path.get(),
@@ -570,7 +573,7 @@
       except BeamIOError:
         # This error is not serious, because it may happen on a retry of the
         # bundle. We simply log it.
-        logging.debug('File %s failed to be copied. This may be due to a bundle'
+        _LOGGER.debug('File %s failed to be copied. This may be due to a bundle'
                       ' being retried.', r.file_name)
 
       yield FileResult(final_file_name,
@@ -580,7 +583,7 @@
                        r.pane,
                        destination)
 
-    logging.info('Cautiously removing temporary files for'
+    _LOGGER.info('Cautiously removing temporary files for'
                  ' destination %s and window %s', destination, w)
     writer_key = (destination, w)
     self._remove_temporary_files(writer_key)
@@ -592,10 +595,10 @@
       match_result = filesystems.FileSystems.match(['%s*' % prefix])
       orphaned_files = [m.path for m in match_result[0].metadata_list]
 
-      logging.debug('Deleting orphaned files: %s', orphaned_files)
+      _LOGGER.debug('Deleting orphaned files: %s', orphaned_files)
       filesystems.FileSystems.delete(orphaned_files)
     except BeamIOError as e:
-      logging.debug('Exceptions when deleting files: %s', e)
+      _LOGGER.debug('Exceptions when deleting files: %s', e)
 
 
 class _WriteShardedRecordsFn(beam.DoFn):
@@ -625,7 +628,7 @@
     sink.flush()
     writer.close()
 
-    logging.info('Writing file %s for destination %s and shard %s',
+    _LOGGER.info('Writing file %s for destination %s and shard %s',
                  full_file_name, destination, repr(shard))
 
     yield FileResult(full_file_name,
diff --git a/sdks/python/apache_beam/io/filesystemio_test.py b/sdks/python/apache_beam/io/filesystemio_test.py
index 72e7f0d..7797eb8 100644
--- a/sdks/python/apache_beam/io/filesystemio_test.py
+++ b/sdks/python/apache_beam/io/filesystemio_test.py
@@ -28,6 +28,8 @@
 
 from apache_beam.io import filesystemio
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class FakeDownloader(filesystemio.Downloader):
 
@@ -206,7 +208,7 @@
 
     for buffer_size in buffer_sizes:
       for target in [self._read_and_verify, self._read_and_seek]:
-        logging.info('buffer_size=%s, target=%s' % (buffer_size, target))
+        _LOGGER.info('buffer_size=%s, target=%s', buffer_size, target)
         parent_conn, child_conn = multiprocessing.Pipe()
         stream = filesystemio.PipeStream(child_conn)
         success = [False]
diff --git a/sdks/python/apache_beam/io/gcp/big_query_query_to_table_it_test.py b/sdks/python/apache_beam/io/gcp/big_query_query_to_table_it_test.py
index 21de828..d357946 100644
--- a/sdks/python/apache_beam/io/gcp/big_query_query_to_table_it_test.py
+++ b/sdks/python/apache_beam/io/gcp/big_query_query_to_table_it_test.py
@@ -45,6 +45,9 @@
 except ImportError:
   pass
 
+
+_LOGGER = logging.getLogger(__name__)
+
 WAIT_UNTIL_FINISH_DURATION_MS = 15 * 60 * 1000
 
 BIG_QUERY_DATASET_ID = 'python_query_to_table_'
@@ -90,7 +93,7 @@
     try:
       self.bigquery_client.client.datasets.Delete(request)
     except HttpError:
-      logging.debug('Failed to clean up dataset %s' % self.dataset_id)
+      _LOGGER.debug('Failed to clean up dataset %s', self.dataset_id)
 
   def _setup_new_types_env(self):
     table_schema = bigquery.TableSchema()
diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index d3ac5ca..0280c61 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -273,6 +273,9 @@
     ]
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 @deprecated(since='2.11.0', current="bigquery_tools.parse_table_reference")
 def _parse_table_reference(table, dataset=None, project=None):
   return bigquery_tools.parse_table_reference(table, dataset, project)
@@ -787,7 +790,7 @@
       # and avoid the get-or-create step.
       return
 
-    logging.debug('Creating or getting table %s with schema %s.',
+    _LOGGER.debug('Creating or getting table %s with schema %s.',
                   table_reference, schema)
 
     table_schema = self.get_table_schema(schema)
@@ -833,7 +836,7 @@
     return self._flush_all_batches()
 
   def _flush_all_batches(self):
-    logging.debug('Attempting to flush to all destinations. Total buffered: %s',
+    _LOGGER.debug('Attempting to flush to all destinations. Total buffered: %s',
                   self._total_buffered_rows)
 
     return itertools.chain(*[self._flush_batch(destination)
@@ -850,7 +853,7 @@
       table_reference.projectId = vp.RuntimeValueProvider.get_value(
           'project', str, '')
 
-    logging.debug('Flushing data to %s. Total %s rows.',
+    _LOGGER.debug('Flushing data to %s. Total %s rows.',
                   destination, len(rows_and_insert_ids))
 
     rows = [r[0] for r in rows_and_insert_ids]
@@ -865,7 +868,7 @@
           insert_ids=insert_ids,
           skip_invalid_rows=True)
 
-      logging.debug("Passed: %s. Errors are %s", passed, errors)
+      _LOGGER.debug("Passed: %s. Errors are %s", passed, errors)
       failed_rows = [rows[entry.index] for entry in errors]
       should_retry = any(
           bigquery_tools.RetryStrategy.should_retry(
@@ -877,7 +880,7 @@
         break
       else:
         retry_backoff = next(self._backoff_calculator)
-        logging.info('Sleeping %s seconds before retrying insertion.',
+        _LOGGER.info('Sleeping %s seconds before retrying insertion.',
                      retry_backoff)
         time.sleep(retry_backoff)
 
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
index cb285ea..06525cd 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
@@ -46,6 +46,8 @@
 from apache_beam.transforms import trigger
 from apache_beam.transforms.window import GlobalWindows
 
+_LOGGER = logging.getLogger(__name__)
+
 ONE_TERABYTE = (1 << 40)
 
 # The maximum file size for imports is 5TB. We keep our files under that.
@@ -320,7 +322,7 @@
                                copy_to_reference.datasetId,
                                copy_to_reference.tableId)))
 
-    logging.info("Triggering copy job from %s to %s",
+    _LOGGER.info("Triggering copy job from %s to %s",
                  copy_from_reference, copy_to_reference)
     job_reference = self.bq_wrapper._insert_copy_job(
         copy_to_reference.projectId,
@@ -407,7 +409,7 @@
     uid = _bq_uuid()
     job_name = '%s_%s_%s' % (
         load_job_name_prefix, destination_hash, uid)
-    logging.debug('Load job has %s files. Job name is %s.',
+    _LOGGER.debug('Load job has %s files. Job name is %s.',
                   len(files), job_name)
 
     if self.temporary_tables:
@@ -415,7 +417,7 @@
       table_reference.tableId = job_name
       yield pvalue.TaggedOutput(TriggerLoadJobs.TEMP_TABLES, table_reference)
 
-    logging.info('Triggering job %s to load data to BigQuery table %s.'
+    _LOGGER.info('Triggering job %s to load data to BigQuery table %s. '
                  'Schema: %s. Additional parameters: %s',
                  job_name, table_reference,
                  schema, additional_parameters)
@@ -519,9 +521,9 @@
                                     ref.jobId,
                                     ref.location)
 
-      logging.info("Job status: %s", job.status)
+      _LOGGER.info("Job status: %s", job.status)
       if job.status.state == 'DONE' and job.status.errorResult:
-        logging.warning("Job %s seems to have failed. Error Result: %s",
+        _LOGGER.warning("Job %s seems to have failed. Error Result: %s",
                         ref.jobId, job.status.errorResult)
         self._latest_error = job.status
         return WaitForBQJobs.FAILED
@@ -541,7 +543,7 @@
     self.bq_wrapper = bigquery_tools.BigQueryWrapper(client=self.test_client)
 
   def process(self, table_reference):
-    logging.info("Deleting table %s", table_reference)
+    _LOGGER.info("Deleting table %s", table_reference)
     table_reference = bigquery_tools.parse_table_reference(table_reference)
     self.bq_wrapper._delete_table(
         table_reference.projectId,
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
index 035be18..bbf8d3a 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
@@ -58,6 +58,8 @@
   HttpError = None
 
 
+_LOGGER = logging.getLogger(__name__)
+
 _DESTINATION_ELEMENT_PAIRS = [
     # DESTINATION 1
     ('project1:dataset1.table1', '{"name":"beam", "language":"py"}'),
@@ -609,7 +611,7 @@
     self.bigquery_client = bigquery_tools.BigQueryWrapper()
     self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
     self.output_table = "%s.output_table" % (self.dataset_id)
-    logging.info("Created dataset %s in project %s",
+    _LOGGER.info("Created dataset %s in project %s",
                  self.dataset_id, self.project)
 
   @attr('IT')
@@ -794,11 +796,11 @@
         projectId=self.project, datasetId=self.dataset_id,
         deleteContents=True)
     try:
-      logging.info("Deleting dataset %s in project %s",
+      _LOGGER.info("Deleting dataset %s in project %s",
                    self.dataset_id, self.project)
       self.bigquery_client.client.datasets.Delete(request)
     except HttpError:
-      logging.debug('Failed to clean up dataset %s in project %s',
+      _LOGGER.debug('Failed to clean up dataset %s in project %s',
                     self.dataset_id, self.project)
 
 
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py
index 246d2ce..ff63eda 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py
@@ -46,6 +46,9 @@
 # pylint: enable=wrong-import-order, wrong-import-position
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class BigQueryReadIntegrationTests(unittest.TestCase):
   BIG_QUERY_DATASET_ID = 'python_read_table_'
 
@@ -59,7 +62,7 @@
                                   str(int(time.time())),
                                   random.randint(0, 10000))
     self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
-    logging.info("Created dataset %s in project %s",
+    _LOGGER.info("Created dataset %s in project %s",
                  self.dataset_id, self.project)
 
   def tearDown(self):
@@ -67,11 +70,11 @@
         projectId=self.project, datasetId=self.dataset_id,
         deleteContents=True)
     try:
-      logging.info("Deleting dataset %s in project %s",
+      _LOGGER.info("Deleting dataset %s in project %s",
                    self.dataset_id, self.project)
       self.bigquery_client.client.datasets.Delete(request)
     except HttpError:
-      logging.debug('Failed to clean up dataset %s in project %s',
+      _LOGGER.debug('Failed to clean up dataset %s in project %s',
                     self.dataset_id, self.project)
 
   def create_table(self, tablename):
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py
index b8c8c1c..6cf4529 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -71,6 +71,9 @@
 # pylint: enable=wrong-import-order, wrong-import-position
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
 class TestTableRowJsonCoder(unittest.TestCase):
 
@@ -579,7 +582,7 @@
     self.bigquery_client = bigquery_tools.BigQueryWrapper()
     self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
     self.output_table = "%s.output_table" % (self.dataset_id)
-    logging.info("Created dataset %s in project %s",
+    _LOGGER.info("Created dataset %s in project %s",
                  self.dataset_id, self.project)
 
   @attr('IT')
@@ -741,11 +744,11 @@
         projectId=self.project, datasetId=self.dataset_id,
         deleteContents=True)
     try:
-      logging.info("Deleting dataset %s in project %s",
+      _LOGGER.info("Deleting dataset %s in project %s",
                    self.dataset_id, self.project)
       self.bigquery_client.client.datasets.Delete(request)
     except HttpError:
-      logging.debug('Failed to clean up dataset %s in project %s',
+      _LOGGER.debug('Failed to clean up dataset %s in project %s',
                     self.dataset_id, self.project)
 
 
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py
index 0649703..f2763ca 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -62,8 +62,9 @@
 # pylint: enable=wrong-import-order, wrong-import-position
 
 
-MAX_RETRIES = 3
+_LOGGER = logging.getLogger(__name__)
 
+MAX_RETRIES = 3
 
 JSON_COMPLIANCE_ERROR = 'NAN, INF and -INF values are not JSON compliant.'
 
@@ -262,7 +263,7 @@
 
     if response.statistics is None:
       # This behavior is only expected in tests
-      logging.warning(
+      _LOGGER.warning(
           "Unable to get location, missing response.statistics. Query: %s",
           query)
       return None
@@ -274,11 +275,11 @@
           table.projectId,
           table.datasetId,
           table.tableId)
-      logging.info("Using location %r from table %r referenced by query %s",
+      _LOGGER.info("Using location %r from table %r referenced by query %s",
                    location, table, query)
       return location
 
-    logging.debug("Query %s does not reference any tables.", query)
+    _LOGGER.debug("Query %s does not reference any tables.", query)
     return None
 
   @retry.with_exponential_backoff(
@@ -309,9 +310,9 @@
         )
     )
 
-    logging.info("Inserting job request: %s", request)
+    _LOGGER.info("Inserting job request: %s", request)
     response = self.client.jobs.Insert(request)
-    logging.info("Response was %s", response)
+    _LOGGER.info("Response was %s", response)
     return response.jobReference
 
   @retry.with_exponential_backoff(
@@ -442,7 +443,7 @@
     request = bigquery.BigqueryTablesInsertRequest(
         projectId=project_id, datasetId=dataset_id, table=table)
     response = self.client.tables.Insert(request)
-    logging.debug("Created the table with id %s", table_id)
+    _LOGGER.debug("Created the table with id %s", table_id)
     # The response is a bigquery.Table instance.
     return response
 
@@ -491,7 +492,7 @@
       self.client.tables.Delete(request)
     except HttpError as exn:
       if exn.status_code == 404:
-        logging.warning('Table %s:%s.%s does not exist', project_id,
+        _LOGGER.warning('Table %s:%s.%s does not exist', project_id,
                         dataset_id, table_id)
         return
       else:
@@ -508,7 +509,7 @@
       self.client.datasets.Delete(request)
     except HttpError as exn:
       if exn.status_code == 404:
-        logging.warning('Dataset %s:%s does not exist', project_id,
+        _LOGGER.warning('Dataset %s:%s does not exist', project_id,
                         dataset_id)
         return
       else:
@@ -537,7 +538,7 @@
             % (project_id, dataset_id))
     except HttpError as exn:
       if exn.status_code == 404:
-        logging.warning(
+        _LOGGER.warning(
             'Dataset %s:%s does not exist so we will create it as temporary '
             'with location=%s',
             project_id, dataset_id, location)
@@ -555,7 +556,7 @@
           projectId=project_id, datasetId=temp_table.datasetId))
     except HttpError as exn:
       if exn.status_code == 404:
-        logging.warning('Dataset %s:%s does not exist', project_id,
+        _LOGGER.warning('Dataset %s:%s does not exist', project_id,
                         temp_table.datasetId)
         return
       else:
@@ -669,12 +670,12 @@
             additional_parameters=additional_create_parameters)
       except HttpError as exn:
         if exn.status_code == 409:
-          logging.debug('Skipping Creation. Table %s:%s.%s already exists.'
+          _LOGGER.debug('Skipping Creation. Table %s:%s.%s already exists.'
                         % (project_id, dataset_id, table_id))
           created_table = self.get_table(project_id, dataset_id, table_id)
         else:
           raise
-      logging.info('Created table %s.%s.%s with schema %s. '
+      _LOGGER.info('Created table %s.%s.%s with schema %s. '
                    'Result: %s.',
                    project_id, dataset_id, table_id,
                    schema or found_table.schema,
@@ -684,7 +685,7 @@
       if write_disposition == BigQueryDisposition.WRITE_TRUNCATE:
         # BigQuery can route data to the old table for 2 mins max so wait
         # that much time before creating the table and writing it
-        logging.warning('Sleeping for 150 seconds before the write as ' +
+        _LOGGER.warning('Sleeping for 150 seconds before the write as ' +
                         'BigQuery inserts can be routed to deleted table ' +
                         'for 2 mins after the delete and create.')
         # TODO(BEAM-2673): Remove this sleep by migrating to load api
@@ -713,7 +714,7 @@
         # request not for the actual execution of the query in the service.  If
         # the request times out we keep trying. This situation is quite possible
         # if the query will return a large number of rows.
-        logging.info('Waiting on response from query: %s ...', query)
+        _LOGGER.info('Waiting on response from query: %s ...', query)
         time.sleep(1.0)
         continue
       # We got some results. The last page is signalled by a missing pageToken.
@@ -975,7 +976,7 @@
 
   def _flush_rows_buffer(self):
     if self.rows_buffer:
-      logging.info('Writing %d rows to %s:%s.%s table.', len(self.rows_buffer),
+      _LOGGER.info('Writing %d rows to %s:%s.%s table.', len(self.rows_buffer),
                    self.project_id, self.dataset_id, self.table_id)
       passed, errors = self.client.insert_rows(
           project_id=self.project_id, dataset_id=self.dataset_id,
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py
index 3658b9c..ae56e35 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py
@@ -48,6 +48,9 @@
 # pylint: enable=wrong-import-order, wrong-import-position
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class BigQueryWriteIntegrationTests(unittest.TestCase):
   BIG_QUERY_DATASET_ID = 'python_write_to_table_'
 
@@ -61,7 +64,7 @@
                                   str(int(time.time())),
                                   random.randint(0, 10000))
     self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
-    logging.info("Created dataset %s in project %s",
+    _LOGGER.info("Created dataset %s in project %s",
                  self.dataset_id, self.project)
 
   def tearDown(self):
@@ -69,11 +72,11 @@
         projectId=self.project, datasetId=self.dataset_id,
         deleteContents=True)
     try:
-      logging.info("Deleting dataset %s in project %s",
+      _LOGGER.info("Deleting dataset %s in project %s",
                    self.dataset_id, self.project)
       self.bigquery_client.client.datasets.Delete(request)
     except HttpError:
-      logging.debug('Failed to clean up dataset %s in project %s',
+      _LOGGER.debug('Failed to clean up dataset %s in project %s',
                     self.dataset_id, self.project)
 
   def create_table(self, table_name):
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio.py b/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio.py
index 3d32611..9af7674 100644
--- a/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio.py
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio.py
@@ -45,12 +45,15 @@
 from apache_beam.transforms import PTransform
 from apache_beam.transforms.util import Values
 
+_LOGGER = logging.getLogger(__name__)
+
+
 # Protect against environments where datastore library is not available.
 # pylint: disable=wrong-import-order, wrong-import-position
 try:
   from google.cloud.proto.datastore.v1 import datastore_pb2
   from googledatastore import helper as datastore_helper
-  logging.warning(
+  _LOGGER.warning(
       'Using deprecated Datastore client.\n'
       'This client will be removed in Beam 3.0 (next Beam major release).\n'
       'Please migrate to apache_beam.io.gcp.datastore.v1new.datastoreio.')
@@ -125,7 +128,7 @@
           'Google Cloud IO not available, '
           'please install apache_beam[gcp]')
 
-    logging.warning('datastoreio read transform is experimental.')
+    _LOGGER.warning('datastoreio read transform is experimental.')
     super(ReadFromDatastore, self).__init__()
 
     if not project:
@@ -213,13 +216,13 @@
       else:
         estimated_num_splits = self._num_splits
 
-      logging.info("Splitting the query into %d splits", estimated_num_splits)
+      _LOGGER.info("Splitting the query into %d splits", estimated_num_splits)
       try:
         query_splits = query_splitter.get_splits(
             self._datastore, query, estimated_num_splits,
             helper.make_partition(self._project, self._datastore_namespace))
       except Exception:
-        logging.warning("Unable to parallelize the given query: %s", query,
+        _LOGGER.warning("Unable to parallelize the given query: %s", query,
                         exc_info=True)
         query_splits = [query]
 
@@ -296,7 +299,7 @@
     kind = query.kind[0].name
     latest_timestamp = ReadFromDatastore.query_latest_statistics_timestamp(
         project, namespace, datastore)
-    logging.info('Latest stats timestamp for kind %s is %s',
+    _LOGGER.info('Latest stats timestamp for kind %s is %s',
                  kind, latest_timestamp)
 
     kind_stats_query = (
@@ -316,13 +319,13 @@
     try:
       estimated_size_bytes = ReadFromDatastore.get_estimated_size_bytes(
           project, namespace, query, datastore)
-      logging.info('Estimated size bytes for query: %s', estimated_size_bytes)
+      _LOGGER.info('Estimated size bytes for query: %s', estimated_size_bytes)
       num_splits = int(min(ReadFromDatastore._NUM_QUERY_SPLITS_MAX, round(
           (float(estimated_size_bytes) /
            ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES))))
 
     except Exception as e:
-      logging.warning('Failed to fetch estimated size bytes: %s', e)
+      _LOGGER.warning('Failed to fetch estimated size bytes: %s', e)
       # Fallback in case estimated size is unavailable.
       num_splits = ReadFromDatastore._NUM_QUERY_SPLITS_MIN
 
@@ -346,7 +349,7 @@
      """
     self._project = project
     self._mutation_fn = mutation_fn
-    logging.warning('datastoreio write transform is experimental.')
+    _LOGGER.warning('datastoreio write transform is experimental.')
 
   def expand(self, pcoll):
     return (pcoll
@@ -424,7 +427,7 @@
           self._datastore, self._project, self._mutations,
           self._throttler, self._update_rpc_stats,
           throttle_delay=util.WRITE_BATCH_TARGET_LATENCY_MS//1000)
-      logging.debug("Successfully wrote %d mutations in %dms.",
+      _LOGGER.debug("Successfully wrote %d mutations in %dms.",
                     len(self._mutations), latency_ms)
 
       if not self._fixed_batch_size:
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py b/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py
index a2bc521..4ea2898 100644
--- a/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py
@@ -54,6 +54,9 @@
 # pylint: enable=ungrouped-imports
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 def key_comparator(k1, k2):
   """A comparator for Datastore keys.
 
@@ -216,7 +219,7 @@
   def commit(request):
     # Client-side throttling.
     while throttler.throttle_request(time.time()*1000):
-      logging.info("Delaying request for %ds due to previous failures",
+      _LOGGER.info("Delaying request for %ds due to previous failures",
                    throttle_delay)
       time.sleep(throttle_delay)
       rpc_stats_callback(throttled_secs=throttle_delay)
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/datastore_write_it_pipeline.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/datastore_write_it_pipeline.py
index f5b1157..efe80c8 100644
--- a/sdks/python/apache_beam/io/gcp/datastore/v1new/datastore_write_it_pipeline.py
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/datastore_write_it_pipeline.py
@@ -47,6 +47,8 @@
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 
+_LOGGER = logging.getLogger(__name__)
+
 
 def new_pipeline_with_job_name(pipeline_options, job_name, suffix):
   """Create a pipeline with the given job_name and a suffix."""
@@ -108,7 +110,7 @@
   # Pipeline 1: Create and write the specified number of Entities to the
   # Cloud Datastore.
   ancestor_key = Key([kind, str(uuid.uuid4())], project=project)
-  logging.info('Writing %s entities to %s', num_entities, project)
+  _LOGGER.info('Writing %s entities to %s', num_entities, project)
   p = new_pipeline_with_job_name(pipeline_options, job_name, '-write')
   _ = (p
        | 'Input' >> beam.Create(list(range(num_entities)))
@@ -121,7 +123,7 @@
   # Optional Pipeline 2: If a read limit was provided, read it and confirm
   # that the expected entities were read.
   if known_args.limit is not None:
-    logging.info('Querying a limited set of %s entities and verifying count.',
+    _LOGGER.info('Querying a limited set of %s entities and verifying count.',
                  known_args.limit)
     p = new_pipeline_with_job_name(pipeline_options, job_name, '-verify-limit')
     query.limit = known_args.limit
@@ -134,7 +136,7 @@
     query.limit = None
 
   # Pipeline 3: Query the written Entities and verify result.
-  logging.info('Querying entities, asserting they match.')
+  _LOGGER.info('Querying entities, asserting they match.')
   p = new_pipeline_with_job_name(pipeline_options, job_name, '-verify')
   entities = p | 'read from datastore' >> ReadFromDatastore(query)
 
@@ -145,7 +147,7 @@
   p.run()
 
   # Pipeline 4: Delete Entities.
-  logging.info('Deleting entities.')
+  _LOGGER.info('Deleting entities.')
   p = new_pipeline_with_job_name(pipeline_options, job_name, '-delete')
   entities = p | 'read from datastore' >> ReadFromDatastore(query)
   _ = (entities
@@ -155,7 +157,7 @@
   p.run()
 
   # Pipeline 5: Query the written Entities, verify no results.
-  logging.info('Querying for the entities to make sure there are none present.')
+  _LOGGER.info('Querying for the entities to make sure there are none present.')
   p = new_pipeline_with_job_name(pipeline_options, job_name, '-verify-deleted')
   entities = p | 'read from datastore' >> ReadFromDatastore(query)
 
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio.py
index f71a801..a70ea95 100644
--- a/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio.py
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/datastoreio.py
@@ -51,6 +51,9 @@
 __all__ = ['ReadFromDatastore', 'WriteToDatastore', 'DeleteFromDatastore']
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 @typehints.with_output_types(types.Entity)
 class ReadFromDatastore(PTransform):
   """A ``PTransform`` for querying Google Cloud Datastore.
@@ -173,11 +176,11 @@
         else:
           estimated_num_splits = self._num_splits
 
-        logging.info("Splitting the query into %d splits", estimated_num_splits)
+        _LOGGER.info("Splitting the query into %d splits", estimated_num_splits)
         query_splits = query_splitter.get_splits(
             client, query, estimated_num_splits)
       except query_splitter.QuerySplitterError:
-        logging.info("Unable to parallelize the given query: %s", query,
+        _LOGGER.info("Unable to parallelize the given query: %s", query,
                      exc_info=True)
         query_splits = [query]
 
@@ -219,7 +222,7 @@
       latest_timestamp = (
           ReadFromDatastore._SplitQueryFn
           .query_latest_statistics_timestamp(client))
-      logging.info('Latest stats timestamp for kind %s is %s',
+      _LOGGER.info('Latest stats timestamp for kind %s is %s',
                    kind_name, latest_timestamp)
 
       if client.namespace is None:
@@ -243,12 +246,12 @@
         estimated_size_bytes = (
             ReadFromDatastore._SplitQueryFn
             .get_estimated_size_bytes(client, query))
-        logging.info('Estimated size bytes for query: %s', estimated_size_bytes)
+        _LOGGER.info('Estimated size bytes for query: %s', estimated_size_bytes)
         num_splits = int(min(ReadFromDatastore._NUM_QUERY_SPLITS_MAX, round(
             (float(estimated_size_bytes) /
              ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES))))
       except Exception as e:
-        logging.warning('Failed to fetch estimated size bytes: %s', e)
+        _LOGGER.warning('Failed to fetch estimated size bytes: %s', e)
         # Fallback in case estimated size is unavailable.
         num_splits = ReadFromDatastore._NUM_QUERY_SPLITS_MIN
 
@@ -360,7 +363,7 @@
       """
       # Client-side throttling.
       while throttler.throttle_request(time.time() * 1000):
-        logging.info("Delaying request for %ds due to previous failures",
+        _LOGGER.info("Delaying request for %ds due to previous failures",
                      throttle_delay)
         time.sleep(throttle_delay)
         rpc_stats_callback(throttled_secs=throttle_delay)
@@ -412,7 +415,7 @@
           self._throttler,
           rpc_stats_callback=self._update_rpc_stats,
           throttle_delay=util.WRITE_BATCH_TARGET_LATENCY_MS // 1000)
-      logging.debug("Successfully wrote %d mutations in %dms.",
+      _LOGGER.debug("Successfully wrote %d mutations in %dms.",
                     len(self._batch.mutations), latency_ms)
 
       now = time.time() * 1000
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/types_test.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/types_test.py
index 21633d9..c3bf8ef 100644
--- a/sdks/python/apache_beam/io/gcp/datastore/v1new/types_test.py
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/types_test.py
@@ -40,6 +40,9 @@
   client = None
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 @unittest.skipIf(client is None, 'Datastore dependencies are not installed')
 class TypesTest(unittest.TestCase):
   _PROJECT = 'project'
@@ -168,7 +171,7 @@
     self.assertEqual(order, cq.order)
     self.assertEqual(distinct_on, cq.distinct_on)
 
-    logging.info('query: %s', q)  # Test __repr__()
+    _LOGGER.info('query: %s', q)  # Test __repr__()
 
   def testValueProviderFilters(self):
     self.vp_filters = [
@@ -193,7 +196,7 @@
       cq = q._to_client_query(self._test_client)
       self.assertEqual(exp_filter, cq.filters)
 
-      logging.info('query: %s', q)  # Test __repr__()
+      _LOGGER.info('query: %s', q)  # Test __repr__()
 
   def testQueryEmptyNamespace(self):
     # Test that we can pass a namespace of None.
diff --git a/sdks/python/apache_beam/io/gcp/datastore_write_it_pipeline.py b/sdks/python/apache_beam/io/gcp/datastore_write_it_pipeline.py
index 67e375f..2d0be8f 100644
--- a/sdks/python/apache_beam/io/gcp/datastore_write_it_pipeline.py
+++ b/sdks/python/apache_beam/io/gcp/datastore_write_it_pipeline.py
@@ -57,6 +57,9 @@
 # pylint: enable=ungrouped-imports
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 def new_pipeline_with_job_name(pipeline_options, job_name, suffix):
   """Create a pipeline with the given job_name and a suffix."""
   gcp_options = pipeline_options.view_as(GoogleCloudOptions)
@@ -137,7 +140,7 @@
 
   # Pipeline 1: Create and write the specified number of Entities to the
   # Cloud Datastore.
-  logging.info('Writing %s entities to %s', num_entities, project)
+  _LOGGER.info('Writing %s entities to %s', num_entities, project)
   p = new_pipeline_with_job_name(pipeline_options, job_name, '-write')
 
   # pylint: disable=expression-not-assigned
@@ -152,7 +155,7 @@
   # Optional Pipeline 2: If a read limit was provided, read it and confirm
   # that the expected entities were read.
   if known_args.limit is not None:
-    logging.info('Querying a limited set of %s entities and verifying count.',
+    _LOGGER.info('Querying a limited set of %s entities and verifying count.',
                  known_args.limit)
     p = new_pipeline_with_job_name(pipeline_options, job_name, '-verify-limit')
     query_with_limit = query_pb2.Query()
@@ -167,7 +170,7 @@
     p.run()
 
   # Pipeline 3: Query the written Entities and verify result.
-  logging.info('Querying entities, asserting they match.')
+  _LOGGER.info('Querying entities, asserting they match.')
   p = new_pipeline_with_job_name(pipeline_options, job_name, '-verify')
   entities = p | 'read from datastore' >> ReadFromDatastore(project, query)
 
@@ -178,7 +181,7 @@
   p.run()
 
   # Pipeline 4: Delete Entities.
-  logging.info('Deleting entities.')
+  _LOGGER.info('Deleting entities.')
   p = new_pipeline_with_job_name(pipeline_options, job_name, '-delete')
   entities = p | 'read from datastore' >> ReadFromDatastore(project, query)
   # pylint: disable=expression-not-assigned
@@ -189,7 +192,7 @@
   p.run()
 
   # Pipeline 5: Query the written Entities, verify no results.
-  logging.info('Querying for the entities to make sure there are none present.')
+  _LOGGER.info('Querying for the entities to make sure there are none present.')
   p = new_pipeline_with_job_name(pipeline_options, job_name, '-verify-deleted')
   entities = p | 'read from datastore' >> ReadFromDatastore(project, query)
 
diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py
index dfdc29d..c1e0314 100644
--- a/sdks/python/apache_beam/io/gcp/gcsio.py
+++ b/sdks/python/apache_beam/io/gcp/gcsio.py
@@ -44,6 +44,9 @@
 __all__ = ['GcsIO']
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 # Issue a friendlier error message if the storage library is not available.
 # TODO(silviuc): Remove this guard when storage is available everywhere.
 try:
@@ -250,7 +253,7 @@
         maxBytesRewrittenPerCall=max_bytes_rewritten_per_call)
     response = self.client.objects.Rewrite(request)
     while not response.done:
-      logging.debug(
+      _LOGGER.debug(
           'Rewrite progress: %d of %d bytes, %s to %s',
           response.totalBytesRewritten, response.objectSize, src, dest)
       request.rewriteToken = response.rewriteToken
@@ -258,7 +261,7 @@
       if self._rewrite_cb is not None:
         self._rewrite_cb(response)
 
-    logging.debug('Rewrite done: %s to %s', src, dest)
+    _LOGGER.debug('Rewrite done: %s to %s', src, dest)
 
   # We intentionally do not decorate this method with a retry, as retrying is
   # handled in BatchApiRequest.Execute().
@@ -320,12 +323,12 @@
                 GcsIOError(errno.ENOENT, 'Source file not found: %s' % src))
           pair_to_status[pair] = exception
         elif not response.done:
-          logging.debug(
+          _LOGGER.debug(
               'Rewrite progress: %d of %d bytes, %s to %s',
               response.totalBytesRewritten, response.objectSize, src, dest)
           pair_to_request[pair].rewriteToken = response.rewriteToken
         else:
-          logging.debug('Rewrite done: %s to %s', src, dest)
+          _LOGGER.debug('Rewrite done: %s to %s', src, dest)
           pair_to_status[pair] = None
 
     return [(pair[0], pair[1], pair_to_status[pair]) for pair in src_dest_pairs]
@@ -458,7 +461,7 @@
     file_sizes = {}
     counter = 0
     start_time = time.time()
-    logging.info("Starting the size estimation of the input")
+    _LOGGER.info("Starting the size estimation of the input")
     while True:
       response = self.client.objects.List(request)
       for item in response.items:
@@ -466,12 +469,12 @@
         file_sizes[file_name] = item.size
         counter += 1
         if counter % 10000 == 0:
-          logging.info("Finished computing size of: %s files", len(file_sizes))
+          _LOGGER.info("Finished computing size of: %s files", len(file_sizes))
       if response.nextPageToken:
         request.pageToken = response.nextPageToken
       else:
         break
-    logging.info("Finished listing %s files in %s seconds.",
+    _LOGGER.info("Finished listing %s files in %s seconds.",
                  counter, time.time() - start_time)
     return file_sizes
 
@@ -492,7 +495,7 @@
       if http_error.status_code == 404:
         raise IOError(errno.ENOENT, 'Not found: %s' % self._path)
       else:
-        logging.error('HTTP error while requesting file %s: %s', self._path,
+        _LOGGER.error('HTTP error while requesting file %s: %s', self._path,
                       http_error)
         raise
     self._size = metadata.size
@@ -564,7 +567,7 @@
     try:
       self._client.objects.Insert(self._insert_request, upload=self._upload)
     except Exception as e:  # pylint: disable=broad-except
-      logging.error('Error in _start_upload while inserting file %s: %s',
+      _LOGGER.error('Error in _start_upload while inserting file %s: %s',
                     self._path, traceback.format_exc())
       self._upload_thread.last_error = e
     finally:
diff --git a/sdks/python/apache_beam/io/gcp/gcsio_overrides.py b/sdks/python/apache_beam/io/gcp/gcsio_overrides.py
index a5fc749..1be587d 100644
--- a/sdks/python/apache_beam/io/gcp/gcsio_overrides.py
+++ b/sdks/python/apache_beam/io/gcp/gcsio_overrides.py
@@ -26,6 +26,8 @@
 from apitools.base.py import http_wrapper
 from apitools.base.py import util
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class GcsIOOverrides(object):
   """Functions for overriding Google Cloud Storage I/O client."""
@@ -37,13 +39,13 @@
     # handling GCS download throttling errors (BEAM-7424)
     if (isinstance(retry_args.exc, exceptions.BadStatusCodeError) and
         retry_args.exc.status_code == http_wrapper.TOO_MANY_REQUESTS):
-      logging.debug(
+      _LOGGER.debug(
           'Caught GCS quota error (%s), retrying.', retry_args.exc.status_code)
     else:
       return http_wrapper.HandleExceptionsAndRebuildHttpConnections(retry_args)
 
     http_wrapper.RebuildHttpConnections(retry_args.http)
-    logging.debug('Retrying request to url %s after exception %s',
+    _LOGGER.debug('Retrying request to url %s after exception %s',
                   retry_args.http_request.url, retry_args.exc)
     sleep_seconds = util.CalculateWaitForRetry(
         retry_args.num_retries, max_wait=retry_args.max_retry_wait)
diff --git a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
index c3394a1..2e50763 100644
--- a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
+++ b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
@@ -45,6 +45,8 @@
 
 MAX_RETRIES = 5
 
+_LOGGER = logging.getLogger(__name__)
+
 
 def retry_on_http_and_value_error(exception):
   """Filter allowing retries on Bigquery errors and value error."""
@@ -83,10 +85,10 @@
   def _matches(self, _):
     if self.checksum is None:
       response = self._query_with_retry()
-      logging.info('Read from given query (%s), total rows %d',
+      _LOGGER.info('Read from given query (%s), total rows %d',
                    self.query, len(response))
       self.checksum = compute_hash(response)
-      logging.info('Generate checksum: %s', self.checksum)
+      _LOGGER.info('Generate checksum: %s', self.checksum)
 
     return self.checksum == self.expected_checksum
 
@@ -95,7 +97,7 @@
       retry_filter=retry_on_http_and_value_error)
   def _query_with_retry(self):
     """Run Bigquery query with retry if got error http response"""
-    logging.info('Attempting to perform query %s to BQ', self.query)
+    _LOGGER.info('Attempting to perform query %s to BQ', self.query)
     # Create client here since it throws an exception if pickled.
     bigquery_client = bigquery.Client(self.project)
     query_job = bigquery_client.query(self.query)
@@ -134,7 +136,7 @@
   def _matches(self, _):
     if self.actual_data is None:
       self.actual_data = self._get_query_result()
-      logging.info('Result of query is: %r', self.actual_data)
+      _LOGGER.info('Result of query is: %r', self.actual_data)
 
     try:
       equal_to(self.expected_data)(self.actual_data)
@@ -179,7 +181,7 @@
       response = self._query_with_retry()
       if len(response) >= len(self.expected_data):
         return response
-      logging.debug('Query result contains %d rows' % len(response))
+      _LOGGER.debug('Query result contains %d rows' % len(response))
       time.sleep(1)
     if sys.version_info >= (3,):
       raise TimeoutError('Timeout exceeded for matcher.') # noqa: F821
@@ -207,13 +209,13 @@
     return bigquery_wrapper.get_table(self.project, self.dataset, self.table)
 
   def _matches(self, _):
-    logging.info('Start verify Bigquery table properties.')
+    _LOGGER.info('Start verify Bigquery table properties.')
     # Run query
     bigquery_wrapper = bigquery_tools.BigQueryWrapper()
 
     self.actual_table = self._get_table_with_retry(bigquery_wrapper)
 
-    logging.info('Table proto is %s', self.actual_table)
+    _LOGGER.info('Table proto is %s', self.actual_table)
 
     return all(
         self._match_property(v, self._get_or_none(self.actual_table, k))
@@ -231,7 +233,7 @@
 
   @staticmethod
   def _match_property(expected, actual):
-    logging.info("Matching %s to %s", expected, actual)
+    _LOGGER.info("Matching %s to %s", expected, actual)
     if isinstance(expected, dict):
       return all(
           BigQueryTableMatcher._match_property(
diff --git a/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher.py b/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher.py
index 7a0b5c8..af94e02 100644
--- a/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher.py
+++ b/sdks/python/apache_beam/io/gcp/tests/pubsub_matcher.py
@@ -39,6 +39,8 @@
 DEFAULT_TIMEOUT = 5 * 60
 MAX_MESSAGES_IN_ONE_PULL = 50
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class PubSubMessageMatcher(BaseMatcher):
   """Matcher that verifies messages from given subscription.
@@ -123,7 +125,7 @@
       time.sleep(1)
 
     if time.time() - start_time > timeout:
-      logging.error('Timeout after %d sec. Received %d messages from %s.',
+      _LOGGER.error('Timeout after %d sec. Received %d messages from %s.',
                     timeout, len(total_messages), self.sub_name)
     return total_messages
 
diff --git a/sdks/python/apache_beam/io/gcp/tests/utils.py b/sdks/python/apache_beam/io/gcp/tests/utils.py
index 4ed9af3..dbf8ac9 100644
--- a/sdks/python/apache_beam/io/gcp/tests/utils.py
+++ b/sdks/python/apache_beam/io/gcp/tests/utils.py
@@ -37,6 +37,9 @@
   bigquery = None
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class GcpTestIOError(retry.PermanentException):
   """Basic GCP IO error for testing. Function that raises this error should
   not be retried."""
@@ -93,7 +96,7 @@
     dataset_id: Name of the dataset where table is.
     table_id: Name of the table.
   """
-  logging.info('Clean up a BigQuery table with project: %s, dataset: %s, '
+  _LOGGER.info('Clean up a BigQuery table with project: %s, dataset: %s, '
                'table: %s.', project, dataset_id, table_id)
   client = bigquery.Client(project=project)
   table_ref = client.dataset(dataset_id).table(table_id)
diff --git a/sdks/python/apache_beam/io/hadoopfilesystem.py b/sdks/python/apache_beam/io/hadoopfilesystem.py
index 71d74e8..0abdbaf 100644
--- a/sdks/python/apache_beam/io/hadoopfilesystem.py
+++ b/sdks/python/apache_beam/io/hadoopfilesystem.py
@@ -55,6 +55,8 @@
 _FILE_STATUS_TYPE_DIRECTORY = 'DIRECTORY'
 _FILE_STATUS_TYPE_FILE = 'FILE'
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class HdfsDownloader(filesystemio.Downloader):
 
@@ -196,7 +198,7 @@
   @staticmethod
   def _add_compression(stream, path, mime_type, compression_type):
     if mime_type != 'application/octet-stream':
-      logging.warning('Mime types are not supported. Got non-default mime_type:'
+      _LOGGER.warning('Mime types are not supported. Got non-default mime_type:'
                       ' %s', mime_type)
     if compression_type == CompressionTypes.AUTO:
       compression_type = CompressionTypes.detect_compression_type(path)
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index e21052f..dd97b7f 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -61,6 +61,9 @@
            'Sink', 'Write', 'Writer']
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 # Encapsulates information about a bundle of a source generated when method
 # BoundedSource.split() is invoked.
 # This is a named 4-tuple that has following fields.
@@ -1075,7 +1078,7 @@
   write_results = list(write_results)
   extra_shards = []
   if len(write_results) < min_shards:
-    logging.debug(
+    _LOGGER.debug(
         'Creating %s empty shard(s).', min_shards - len(write_results))
     for _ in range(min_shards - len(write_results)):
       writer = sink.open_writer(init_result, str(uuid.uuid4()))
diff --git a/sdks/python/apache_beam/io/mongodbio.py b/sdks/python/apache_beam/io/mongodbio.py
index 6004ca1..a89ccb1 100644
--- a/sdks/python/apache_beam/io/mongodbio.py
+++ b/sdks/python/apache_beam/io/mongodbio.py
@@ -66,6 +66,9 @@
 from apache_beam.transforms import Reshuffle
 from apache_beam.utils.annotations import experimental
 
+_LOGGER = logging.getLogger(__name__)
+
+
 try:
   # Mongodb has its own bundled bson, which is not compatible with bson pakcage.
   # (https://github.com/py-bson/bson/issues/82). Try to import objectid and if
@@ -80,7 +83,7 @@
   from pymongo import ReplaceOne
 except ImportError:
   objectid = None
-  logging.warning("Could not find a compatible bson package.")
+  _LOGGER.warning("Could not find a compatible bson package.")
 
 __all__ = ['ReadFromMongoDB', 'WriteToMongoDB']
 
@@ -497,7 +500,7 @@
                      replacement=doc,
                      upsert=True))
     resp = self.client[self.db][self.coll].bulk_write(requests)
-    logging.debug('BulkWrite to MongoDB result in nModified:%d, nUpserted:%d, '
+    _LOGGER.debug('BulkWrite to MongoDB result in nModified:%d, nUpserted:%d, '
                   'nMatched:%d, Errors:%s' %
                   (resp.modified_count, resp.upserted_count, resp.matched_count,
                    resp.bulk_api_result.get('writeErrors')))
diff --git a/sdks/python/apache_beam/io/mongodbio_it_test.py b/sdks/python/apache_beam/io/mongodbio_it_test.py
index bfc6099..b315562 100644
--- a/sdks/python/apache_beam/io/mongodbio_it_test.py
+++ b/sdks/python/apache_beam/io/mongodbio_it_test.py
@@ -27,6 +27,8 @@
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 
+_LOGGER = logging.getLogger(__name__)
+
 
 def run(argv=None):
   default_db = 'beam_mongodbio_it_db'
@@ -54,7 +56,7 @@
   # Test Write to MongoDB
   with TestPipeline(options=PipelineOptions(pipeline_args)) as p:
     start_time = time.time()
-    logging.info('Writing %d documents to mongodb' % known_args.num_documents)
+    _LOGGER.info('Writing %d documents to mongodb', known_args.num_documents)
     docs = [{
         'number': x,
         'number_mod_2': x % 2,
@@ -67,13 +69,13 @@
                                                        known_args.mongo_coll,
                                                        known_args.batch_size)
   elapsed = time.time() - start_time
-  logging.info('Writing %d documents to mongodb finished in %.3f seconds' %
+  _LOGGER.info('Writing %d documents to mongodb finished in %.3f seconds' %
                (known_args.num_documents, elapsed))
 
   # Test Read from MongoDB
   with TestPipeline(options=PipelineOptions(pipeline_args)) as p:
     start_time = time.time()
-    logging.info('Reading from mongodb %s:%s' %
+    _LOGGER.info('Reading from mongodb %s:%s' %
                  (known_args.mongo_db, known_args.mongo_coll))
     r = p | 'ReadFromMongoDB' >> \
                 beam.io.ReadFromMongoDB(known_args.mongo_uri,
@@ -85,7 +87,7 @@
         r, equal_to([number for number in range(known_args.num_documents)]))
 
   elapsed = time.time() - start_time
-  logging.info('Read %d documents from mongodb finished in %.3f seconds' %
+  _LOGGER.info('Read %d documents from mongodb finished in %.3f seconds' %
                (known_args.num_documents, elapsed))
 
 
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index c46f801..d4845fb 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -34,6 +34,9 @@
            'OrderedPositionRangeTracker', 'UnsplittableRangeTracker']
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class OffsetRangeTracker(iobase.RangeTracker):
   """A 'RangeTracker' for non-negative positions of type 'long'."""
 
@@ -137,27 +140,27 @@
     assert isinstance(split_offset, (int, long))
     with self._lock:
       if self._stop_offset == OffsetRangeTracker.OFFSET_INFINITY:
-        logging.debug('refusing to split %r at %d: stop position unspecified',
+        _LOGGER.debug('refusing to split %r at %d: stop position unspecified',
                       self, split_offset)
         return
       if self._last_record_start == -1:
-        logging.debug('Refusing to split %r at %d: unstarted', self,
+        _LOGGER.debug('Refusing to split %r at %d: unstarted', self,
                       split_offset)
         return
 
       if split_offset <= self._last_record_start:
-        logging.debug(
+        _LOGGER.debug(
             'Refusing to split %r at %d: already past proposed stop offset',
             self, split_offset)
         return
       if (split_offset < self.start_position()
           or split_offset >= self.stop_position()):
-        logging.debug(
+        _LOGGER.debug(
             'Refusing to split %r at %d: proposed split position out of range',
             self, split_offset)
         return
 
-      logging.debug('Agreeing to split %r at %d', self, split_offset)
+      _LOGGER.debug('Agreeing to split %r at %d', self, split_offset)
 
       split_fraction = (float(split_offset - self._start_offset) / (
           self._stop_offset - self._start_offset))
diff --git a/sdks/python/apache_beam/io/source_test_utils.py b/sdks/python/apache_beam/io/source_test_utils.py
index d83a62a..7291786 100644
--- a/sdks/python/apache_beam/io/source_test_utils.py
+++ b/sdks/python/apache_beam/io/source_test_utils.py
@@ -68,6 +68,9 @@
            'assert_split_at_fraction_succeeds_and_consistent']
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class ExpectedSplitOutcome(object):
   MUST_SUCCEED_AND_BE_CONSISTENT = 1
   MUST_FAIL = 2
@@ -588,7 +591,7 @@
         num_trials += 1
         if (num_trials >
             MAX_CONCURRENT_SPLITTING_TRIALS_PER_ITEM):
-          logging.warning(
+          _LOGGER.warning(
               'After %d concurrent splitting trials at item #%d, observed '
               'only %s, giving up on this item',
               num_trials,
@@ -604,7 +607,7 @@
           have_failure = True
 
         if have_success and have_failure:
-          logging.info('%d trials to observe both success and failure of '
+          _LOGGER.info('%d trials to observe both success and failure of '
                        'concurrent splitting at item #%d', num_trials, i)
           break
     finally:
@@ -613,11 +616,11 @@
     num_total_trials += num_trials
 
     if num_total_trials > MAX_CONCURRENT_SPLITTING_TRIALS_TOTAL:
-      logging.warning('After %d total concurrent splitting trials, considered '
+      _LOGGER.warning('After %d total concurrent splitting trials, considered '
                       'only %d items, giving up.', num_total_trials, i)
       break
 
-  logging.info('%d total concurrent splitting trials for %d items',
+  _LOGGER.info('%d total concurrent splitting trials for %d items',
                num_total_trials, len(expected_items))
 
 
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index 340449f..3b426cc 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -42,6 +42,9 @@
            'WriteToText']
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class _TextSource(filebasedsource.FileBasedSource):
   r"""A source for reading text files.
 
@@ -127,7 +130,7 @@
       raise ValueError('Cannot skip negative number of header lines: %d'
                        % skip_header_lines)
     elif skip_header_lines > 10:
-      logging.warning(
+      _LOGGER.warning(
           'Skipping %d header lines. Skipping large number of header '
           'lines might significantly slow down processing.')
     self._skip_header_lines = skip_header_lines
diff --git a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py
index 7b0bd87..ab7d2f5 100644
--- a/sdks/python/apache_beam/io/tfrecordio.py
+++ b/sdks/python/apache_beam/io/tfrecordio.py
@@ -38,6 +38,9 @@
 __all__ = ['ReadFromTFRecord', 'WriteToTFRecord']
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 def _default_crc32c_fn(value):
   """Calculates crc32c of a bytes object using either snappy or crcmod."""
 
@@ -54,7 +57,7 @@
       pass
 
     if not _default_crc32c_fn.fn:
-      logging.warning('Couldn\'t find python-snappy so the implementation of '
+      _LOGGER.warning('Couldn\'t find python-snappy so the implementation of '
                       '_TFRecordUtil._masked_crc32c is not as fast as it could '
                       'be.')
       _default_crc32c_fn.fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
diff --git a/sdks/python/apache_beam/io/vcfio.py b/sdks/python/apache_beam/io/vcfio.py
index 2a13bf8..aed3579 100644
--- a/sdks/python/apache_beam/io/vcfio.py
+++ b/sdks/python/apache_beam/io/vcfio.py
@@ -51,6 +51,9 @@
 __all__ = ['ReadFromVcf', 'Variant', 'VariantCall', 'VariantInfo',
            'MalformedVcfRecord']
 
+
+_LOGGER = logging.getLogger(__name__)
+
 # Stores data about variant INFO fields. The type of 'data' is specified in the
 # VCF headers. 'field_count' is a string that specifies the number of fields
 # that the data type contains. Its value can either be a number representing a
@@ -346,7 +349,7 @@
                                                self._vcf_reader.formats)
       except (LookupError, ValueError):
         if self._allow_malformed_records:
-          logging.warning(
+          _LOGGER.warning(
               'An exception was raised when reading record from VCF file '
               '%s. Invalid record was %s: %s',
               self._file_name, self._last_record, traceback.format_exc())
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index 41cd306..4de4b51 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -48,6 +48,9 @@
     ]
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 def _static_value_provider_of(value_type):
   """"Helper function to plug a ValueProvider into argparse.
 
@@ -262,7 +265,7 @@
       add_extra_args_fn(parser)
     known_args, unknown_args = parser.parse_known_args(self._flags)
     if unknown_args:
-      logging.warning("Discarding unparseable args: %s", unknown_args)
+      _LOGGER.warning("Discarding unparseable args: %s", unknown_args)
     result = vars(known_args)
 
     # Apply the overrides if any
@@ -510,7 +513,7 @@
     """
     environment_region = os.environ.get('CLOUDSDK_COMPUTE_REGION')
     if environment_region:
-      logging.info('Using default GCP region %s from $CLOUDSDK_COMPUTE_REGION',
+      _LOGGER.info('Using default GCP region %s from $CLOUDSDK_COMPUTE_REGION',
                    environment_region)
       return environment_region
     try:
@@ -523,12 +526,12 @@
       raw_output = processes.check_output(cmd, stderr=DEVNULL)
       formatted_output = raw_output.decode('utf-8').strip()
       if formatted_output:
-        logging.info('Using default GCP region %s from `%s`',
+        _LOGGER.info('Using default GCP region %s from `%s`',
                      formatted_output, ' '.join(cmd))
         return formatted_output
     except RuntimeError:
       pass
-    logging.warning(
+    _LOGGER.warning(
         '--region not set; will default to us-central1. Future releases of '
         'Beam will require the user to set --region explicitly, or else have a '
         'default set via the gcloud tool. '
@@ -872,8 +875,13 @@
   def _add_argparse_args(cls, parser):
     parser.add_argument(
         '--job_endpoint', default=None,
-        help=('Job service endpoint to use. Should be in the form of address '
-              'and port, e.g. localhost:3000'))
+        help=('Job service endpoint to use. Should be in the form of host '
+              'and port, e.g. localhost:8099.'))
+    parser.add_argument(
+        '--artifact_endpoint', default=None,
+        help=('Artifact staging endpoint to use. Should be in the form of host '
+              'and port, e.g. localhost:8098. If none is specified, the '
+              'artifact endpoint sent from the job server is used.'))
     parser.add_argument(
         '--job-server-timeout', default=60, type=int,
         help=('Job service request timeout in seconds. The timeout '
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index 201324a..c1c25d1 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -24,7 +24,6 @@
 import unittest
 from builtins import object
 from builtins import range
-from collections import defaultdict
 
 import mock
 
@@ -40,9 +39,6 @@
 from apache_beam.pvalue import AsSingleton
 from apache_beam.pvalue import TaggedOutput
 from apache_beam.runners.dataflow.native_io.iobase import NativeSource
-from apache_beam.runners.direct.evaluation_context import _ExecutionContext
-from apache_beam.runners.direct.transform_evaluator import _GroupByKeyOnlyEvaluator
-from apache_beam.runners.direct.transform_evaluator import _TransformEvaluator
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -812,84 +808,5 @@
                      p.transforms_stack[0])
 
 
-class DirectRunnerRetryTests(unittest.TestCase):
-
-  def test_retry_fork_graph(self):
-    # TODO(BEAM-3642): The FnApiRunner currently does not currently support
-    # retries.
-    p = beam.Pipeline(runner='BundleBasedDirectRunner')
-
-    # TODO(mariagh): Remove the use of globals from the test.
-    global count_b, count_c # pylint: disable=global-variable-undefined
-    count_b, count_c = 0, 0
-
-    def f_b(x):
-      global count_b  # pylint: disable=global-variable-undefined
-      count_b += 1
-      raise Exception('exception in f_b')
-
-    def f_c(x):
-      global count_c  # pylint: disable=global-variable-undefined
-      count_c += 1
-      raise Exception('exception in f_c')
-
-    names = p | 'CreateNodeA' >> beam.Create(['Ann', 'Joe'])
-
-    fork_b = names | 'SendToB' >> beam.Map(f_b) # pylint: disable=unused-variable
-    fork_c = names | 'SendToC' >> beam.Map(f_c) # pylint: disable=unused-variable
-
-    with self.assertRaises(Exception):
-      p.run().wait_until_finish()
-    assert count_b == count_c == 4
-
-  def test_no_partial_writeouts(self):
-
-    class TestTransformEvaluator(_TransformEvaluator):
-
-      def __init__(self):
-        self._execution_context = _ExecutionContext(None, {})
-
-      def start_bundle(self):
-        self.step_context = self._execution_context.get_step_context()
-
-      def process_element(self, element):
-        k, v = element
-        state = self.step_context.get_keyed_state(k)
-        state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)
-
-    # Create instance and add key/value, key/value2
-    evaluator = TestTransformEvaluator()
-    evaluator.start_bundle()
-    self.assertIsNone(evaluator.step_context.existing_keyed_state.get('key'))
-    self.assertIsNone(evaluator.step_context.partial_keyed_state.get('key'))
-
-    evaluator.process_element(['key', 'value'])
-    self.assertEqual(
-        evaluator.step_context.existing_keyed_state['key'].state,
-        defaultdict(lambda: defaultdict(list)))
-    self.assertEqual(
-        evaluator.step_context.partial_keyed_state['key'].state,
-        {None: {'elements':['value']}})
-
-    evaluator.process_element(['key', 'value2'])
-    self.assertEqual(
-        evaluator.step_context.existing_keyed_state['key'].state,
-        defaultdict(lambda: defaultdict(list)))
-    self.assertEqual(
-        evaluator.step_context.partial_keyed_state['key'].state,
-        {None: {'elements':['value', 'value2']}})
-
-    # Simulate an exception (redo key/value)
-    evaluator._execution_context.reset()
-    evaluator.start_bundle()
-    evaluator.process_element(['key', 'value'])
-    self.assertEqual(
-        evaluator.step_context.existing_keyed_state['key'].state,
-        defaultdict(lambda: defaultdict(list)))
-    self.assertEqual(
-        evaluator.step_context.partial_keyed_state['key'].state,
-        {None: {'elements':['value']}})
-
-
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/direct/bundle_factory.py b/sdks/python/apache_beam/runners/direct/bundle_factory.py
index 558e925..382cf52 100644
--- a/sdks/python/apache_beam/runners/direct/bundle_factory.py
+++ b/sdks/python/apache_beam/runners/direct/bundle_factory.py
@@ -99,6 +99,10 @@
     def windows(self):
       return self._initial_windowed_value.windows
 
+    @property
+    def pane_info(self):
+      return self._initial_windowed_value.pane_info
+
     def add_value(self, value):
       self._appended_values.append(value)
 
@@ -107,8 +111,7 @@
       # _appended_values to yield WindowedValue on the fly.
       yield self._initial_windowed_value
       for v in self._appended_values:
-        yield WindowedValue(v, self._initial_windowed_value.timestamp,
-                            self._initial_windowed_value.windows)
+        yield self._initial_windowed_value.with_value(v)
 
   def __init__(self, pcollection, stacked=True):
     assert isinstance(pcollection, (pvalue.PBegin, pvalue.PCollection))
@@ -178,7 +181,8 @@
         (isinstance(self._elements[-1], (WindowedValue,
                                          _Bundle._StackedWindowedValues))) and
         self._elements[-1].timestamp == element.timestamp and
-        self._elements[-1].windows == element.windows):
+        self._elements[-1].windows == element.windows and
+        self._elements[-1].pane_info == element.pane_info):
       if isinstance(self._elements[-1], WindowedValue):
         self._elements[-1] = _Bundle._StackedWindowedValues(self._elements[-1])
       self._elements[-1].add_value(element.value)
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 73063a1..3332e39 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -187,7 +187,7 @@
   class CombinePerKeyOverride(PTransformOverride):
     def matches(self, applied_ptransform):
       if isinstance(applied_ptransform.transform, CombinePerKey):
-        return True
+        return applied_ptransform.inputs[0].windowing.is_default()
 
     def get_replacement_transform(self, transform):
       # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner_test.py b/sdks/python/apache_beam/runners/direct/direct_runner_test.py
index 66f1845..22a930c 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner_test.py
@@ -19,6 +19,7 @@
 
 import threading
 import unittest
+from collections import defaultdict
 
 import hamcrest as hc
 
@@ -33,6 +34,9 @@
 from apache_beam.runners import DirectRunner
 from apache_beam.runners import TestDirectRunner
 from apache_beam.runners import create_runner
+from apache_beam.runners.direct.evaluation_context import _ExecutionContext
+from apache_beam.runners.direct.transform_evaluator import _GroupByKeyOnlyEvaluator
+from apache_beam.runners.direct.transform_evaluator import _TransformEvaluator
 from apache_beam.testing import test_pipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -129,5 +133,84 @@
            | beam.combiners.Count.Globally())
 
 
+class DirectRunnerRetryTests(unittest.TestCase):
+
+  def test_retry_fork_graph(self):
+    # TODO(BEAM-3642): The FnApiRunner does not currently support
+    # retries.
+    p = beam.Pipeline(runner='BundleBasedDirectRunner')
+
+    # TODO(mariagh): Remove the use of globals from the test.
+    global count_b, count_c # pylint: disable=global-variable-undefined
+    count_b, count_c = 0, 0
+
+    def f_b(x):
+      global count_b  # pylint: disable=global-variable-undefined
+      count_b += 1
+      raise Exception('exception in f_b')
+
+    def f_c(x):
+      global count_c  # pylint: disable=global-variable-undefined
+      count_c += 1
+      raise Exception('exception in f_c')
+
+    names = p | 'CreateNodeA' >> beam.Create(['Ann', 'Joe'])
+
+    fork_b = names | 'SendToB' >> beam.Map(f_b) # pylint: disable=unused-variable
+    fork_c = names | 'SendToC' >> beam.Map(f_c) # pylint: disable=unused-variable
+
+    with self.assertRaises(Exception):
+      p.run().wait_until_finish()
+    assert count_b == count_c == 4
+
+  def test_no_partial_writeouts(self):
+
+    class TestTransformEvaluator(_TransformEvaluator):
+
+      def __init__(self):
+        self._execution_context = _ExecutionContext(None, {})
+
+      def start_bundle(self):
+        self.step_context = self._execution_context.get_step_context()
+
+      def process_element(self, element):
+        k, v = element
+        state = self.step_context.get_keyed_state(k)
+        state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)
+
+    # Create an evaluator, then add ('key', 'value') and ('key', 'value2').
+    evaluator = TestTransformEvaluator()
+    evaluator.start_bundle()
+    self.assertIsNone(evaluator.step_context.existing_keyed_state.get('key'))
+    self.assertIsNone(evaluator.step_context.partial_keyed_state.get('key'))
+
+    evaluator.process_element(['key', 'value'])
+    self.assertEqual(
+        evaluator.step_context.existing_keyed_state['key'].state,
+        defaultdict(lambda: defaultdict(list)))
+    self.assertEqual(
+        evaluator.step_context.partial_keyed_state['key'].state,
+        {None: {'elements':['value']}})
+
+    evaluator.process_element(['key', 'value2'])
+    self.assertEqual(
+        evaluator.step_context.existing_keyed_state['key'].state,
+        defaultdict(lambda: defaultdict(list)))
+    self.assertEqual(
+        evaluator.step_context.partial_keyed_state['key'].state,
+        {None: {'elements':['value', 'value2']}})
+
+    # Simulate a failure: reset the context, then reprocess ('key', 'value').
+    evaluator._execution_context.reset()
+    evaluator.start_bundle()
+    evaluator.process_element(['key', 'value'])
+    self.assertEqual(
+        evaluator.step_context.existing_keyed_state['key'].state,
+        defaultdict(lambda: defaultdict(list)))
+    self.assertEqual(
+        evaluator.step_context.partial_keyed_state['key'].state,
+        {None: {'elements':['value']}})
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index b7929cb..846fe58 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -656,7 +656,7 @@
       pcoll_b = p | 'b' >> beam.Create(['b'])
       assert_that((pcoll_a, pcoll_b) | First(), equal_to(['a']))
 
-  def test_metrics(self):
+  def test_metrics(self, check_gauge=True):
     p = self.create_pipeline()
 
     counter = beam.metrics.Metrics.counter('ns', 'counter')
@@ -678,14 +678,17 @@
     c2, = res.metrics().query(beam.metrics.MetricsFilter().with_step('count2'))[
         'counters']
     self.assertEqual(c2.committed, 4)
+
     dist, = res.metrics().query(beam.metrics.MetricsFilter().with_step('dist'))[
         'distributions']
-    gaug, = res.metrics().query(
-        beam.metrics.MetricsFilter().with_step('gauge'))['gauges']
     self.assertEqual(
         dist.committed.data, beam.metrics.cells.DistributionData(4, 2, 1, 3))
     self.assertEqual(dist.committed.mean, 2.0)
-    self.assertEqual(gaug.committed.value, 3)
+
+    if check_gauge:
+      gaug, = res.metrics().query(
+          beam.metrics.MetricsFilter().with_step('gauge'))['gauges']
+      self.assertEqual(gaug.committed.value, 3)
 
   def test_callbacks_with_exception(self):
     elements_list = ['1', '2']
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
index 4f3e2f9..91f106f 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
@@ -588,122 +588,175 @@
 
   ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...
   """
+  def is_compatible_with_combiner_lifting(trigger):
+    '''Returns whether this trigger is compatible with combiner lifting.
+
+    Certain triggers, such as those that fire after a certain number of
+    elements, need to observe every element, and as such are incompatible
+    with combiner lifting (which may aggregate several elements into one
+    before they reach the triggering code after shuffle).
+    '''
+    if trigger is None:
+      return True
+    elif trigger.WhichOneof('trigger') in (
+        'default', 'always', 'never', 'after_processing_time',
+        'after_synchronized_processing_time'):
+      return True
+    elif trigger.HasField('element_count'):
+      return trigger.element_count.element_count == 1
+    elif trigger.HasField('after_end_of_window'):
+      return is_compatible_with_combiner_lifting(
+          trigger.after_end_of_window.early_firings
+          ) and is_compatible_with_combiner_lifting(
+              trigger.after_end_of_window.late_firings)
+    elif trigger.HasField('after_any'):
+      return all(
+          is_compatible_with_combiner_lifting(t)
+          for t in trigger.after_any.subtriggers)
+    elif trigger.HasField('repeat'):
+      return is_compatible_with_combiner_lifting(trigger.repeat.subtrigger)
+    else:
+      return False
+
+  def can_lift(combine_per_key_transform):
+    windowing = context.components.windowing_strategies[
+        context.components.pcollections[
+            only_element(list(combine_per_key_transform.inputs.values()))
+        ].windowing_strategy_id]
+    if windowing.output_time != beam_runner_api_pb2.OutputTime.END_OF_WINDOW:
+      # This depends on the spec of PartialGroupByKey.
+      return False
+    elif not is_compatible_with_combiner_lifting(windowing.trigger):
+      return False
+    else:
+      return True
+
+  def make_stage(base_stage, transform):
+    return Stage(
+        transform.unique_name,
+        [transform],
+        downstream_side_inputs=base_stage.downstream_side_inputs,
+        must_follow=base_stage.must_follow,
+        parent=base_stage,
+        environment=base_stage.environment)
+
+  def lifted_stages(stage):
+    transform = stage.transforms[0]
+    combine_payload = proto_utils.parse_Bytes(
+        transform.spec.payload, beam_runner_api_pb2.CombinePayload)
+
+    input_pcoll = context.components.pcollections[only_element(
+        list(transform.inputs.values()))]
+    output_pcoll = context.components.pcollections[only_element(
+        list(transform.outputs.values()))]
+
+    element_coder_id = input_pcoll.coder_id
+    element_coder = context.components.coders[element_coder_id]
+    key_coder_id, _ = element_coder.component_coder_ids
+    accumulator_coder_id = combine_payload.accumulator_coder_id
+
+    key_accumulator_coder = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.coders.KV.urn),
+        component_coder_ids=[key_coder_id, accumulator_coder_id])
+    key_accumulator_coder_id = context.add_or_get_coder_id(
+        key_accumulator_coder)
+
+    accumulator_iter_coder = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.coders.ITERABLE.urn),
+        component_coder_ids=[accumulator_coder_id])
+    accumulator_iter_coder_id = context.add_or_get_coder_id(
+        accumulator_iter_coder)
+
+    key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.coders.KV.urn),
+        component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
+    key_accumulator_iter_coder_id = context.add_or_get_coder_id(
+        key_accumulator_iter_coder)
+
+    precombined_pcoll_id = unique_name(
+        context.components.pcollections, 'pcollection')
+    context.components.pcollections[precombined_pcoll_id].CopyFrom(
+        beam_runner_api_pb2.PCollection(
+            unique_name=transform.unique_name + '/Precombine.out',
+            coder_id=key_accumulator_coder_id,
+            windowing_strategy_id=input_pcoll.windowing_strategy_id,
+            is_bounded=input_pcoll.is_bounded))
+
+    grouped_pcoll_id = unique_name(
+        context.components.pcollections, 'pcollection')
+    context.components.pcollections[grouped_pcoll_id].CopyFrom(
+        beam_runner_api_pb2.PCollection(
+            unique_name=transform.unique_name + '/Group.out',
+            coder_id=key_accumulator_iter_coder_id,
+            windowing_strategy_id=output_pcoll.windowing_strategy_id,
+            is_bounded=output_pcoll.is_bounded))
+
+    merged_pcoll_id = unique_name(
+        context.components.pcollections, 'pcollection')
+    context.components.pcollections[merged_pcoll_id].CopyFrom(
+        beam_runner_api_pb2.PCollection(
+            unique_name=transform.unique_name + '/Merge.out',
+            coder_id=key_accumulator_coder_id,
+            windowing_strategy_id=output_pcoll.windowing_strategy_id,
+            is_bounded=output_pcoll.is_bounded))
+
+    yield make_stage(
+        stage,
+        beam_runner_api_pb2.PTransform(
+            unique_name=transform.unique_name + '/Precombine',
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.combine_components
+                .COMBINE_PER_KEY_PRECOMBINE.urn,
+                payload=transform.spec.payload),
+            inputs=transform.inputs,
+            outputs={'out': precombined_pcoll_id}))
+
+    yield make_stage(
+        stage,
+        beam_runner_api_pb2.PTransform(
+            unique_name=transform.unique_name + '/Group',
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.primitives.GROUP_BY_KEY.urn),
+            inputs={'in': precombined_pcoll_id},
+            outputs={'out': grouped_pcoll_id}))
+
+    yield make_stage(
+        stage,
+        beam_runner_api_pb2.PTransform(
+            unique_name=transform.unique_name + '/Merge',
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.combine_components
+                .COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
+                payload=transform.spec.payload),
+            inputs={'in': grouped_pcoll_id},
+            outputs={'out': merged_pcoll_id}))
+
+    yield make_stage(
+        stage,
+        beam_runner_api_pb2.PTransform(
+            unique_name=transform.unique_name + '/ExtractOutputs',
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.combine_components
+                .COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
+                payload=transform.spec.payload),
+            inputs={'in': merged_pcoll_id},
+            outputs=transform.outputs))
+
+  def unlifted_stages(stage):
+    transform = stage.transforms[0]
+    for sub in transform.subtransforms:
+      yield make_stage(stage, context.components.transforms[sub])
+
   for stage in stages:
     assert len(stage.transforms) == 1
     transform = stage.transforms[0]
     if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn:
-      combine_payload = proto_utils.parse_Bytes(
-          transform.spec.payload, beam_runner_api_pb2.CombinePayload)
-
-      input_pcoll = context.components.pcollections[only_element(
-          list(transform.inputs.values()))]
-      output_pcoll = context.components.pcollections[only_element(
-          list(transform.outputs.values()))]
-
-      element_coder_id = input_pcoll.coder_id
-      element_coder = context.components.coders[element_coder_id]
-      key_coder_id, _ = element_coder.component_coder_ids
-      accumulator_coder_id = combine_payload.accumulator_coder_id
-
-      key_accumulator_coder = beam_runner_api_pb2.Coder(
-          spec=beam_runner_api_pb2.FunctionSpec(
-              urn=common_urns.coders.KV.urn),
-          component_coder_ids=[key_coder_id, accumulator_coder_id])
-      key_accumulator_coder_id = context.add_or_get_coder_id(
-          key_accumulator_coder)
-
-      accumulator_iter_coder = beam_runner_api_pb2.Coder(
-          spec=beam_runner_api_pb2.FunctionSpec(
-              urn=common_urns.coders.ITERABLE.urn),
-          component_coder_ids=[accumulator_coder_id])
-      accumulator_iter_coder_id = context.add_or_get_coder_id(
-          accumulator_iter_coder)
-
-      key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
-          spec=beam_runner_api_pb2.FunctionSpec(
-              urn=common_urns.coders.KV.urn),
-          component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
-      key_accumulator_iter_coder_id = context.add_or_get_coder_id(
-          key_accumulator_iter_coder)
-
-      precombined_pcoll_id = unique_name(
-          context.components.pcollections, 'pcollection')
-      context.components.pcollections[precombined_pcoll_id].CopyFrom(
-          beam_runner_api_pb2.PCollection(
-              unique_name=transform.unique_name + '/Precombine.out',
-              coder_id=key_accumulator_coder_id,
-              windowing_strategy_id=input_pcoll.windowing_strategy_id,
-              is_bounded=input_pcoll.is_bounded))
-
-      grouped_pcoll_id = unique_name(
-          context.components.pcollections, 'pcollection')
-      context.components.pcollections[grouped_pcoll_id].CopyFrom(
-          beam_runner_api_pb2.PCollection(
-              unique_name=transform.unique_name + '/Group.out',
-              coder_id=key_accumulator_iter_coder_id,
-              windowing_strategy_id=output_pcoll.windowing_strategy_id,
-              is_bounded=output_pcoll.is_bounded))
-
-      merged_pcoll_id = unique_name(
-          context.components.pcollections, 'pcollection')
-      context.components.pcollections[merged_pcoll_id].CopyFrom(
-          beam_runner_api_pb2.PCollection(
-              unique_name=transform.unique_name + '/Merge.out',
-              coder_id=key_accumulator_coder_id,
-              windowing_strategy_id=output_pcoll.windowing_strategy_id,
-              is_bounded=output_pcoll.is_bounded))
-
-      def make_stage(base_stage, transform):
-        return Stage(
-            transform.unique_name,
-            [transform],
-            downstream_side_inputs=base_stage.downstream_side_inputs,
-            must_follow=base_stage.must_follow,
-            parent=base_stage,
-            environment=base_stage.environment)
-
-      yield make_stage(
-          stage,
-          beam_runner_api_pb2.PTransform(
-              unique_name=transform.unique_name + '/Precombine',
-              spec=beam_runner_api_pb2.FunctionSpec(
-                  urn=common_urns.combine_components
-                  .COMBINE_PER_KEY_PRECOMBINE.urn,
-                  payload=transform.spec.payload),
-              inputs=transform.inputs,
-              outputs={'out': precombined_pcoll_id}))
-
-      yield make_stage(
-          stage,
-          beam_runner_api_pb2.PTransform(
-              unique_name=transform.unique_name + '/Group',
-              spec=beam_runner_api_pb2.FunctionSpec(
-                  urn=common_urns.primitives.GROUP_BY_KEY.urn),
-              inputs={'in': precombined_pcoll_id},
-              outputs={'out': grouped_pcoll_id}))
-
-      yield make_stage(
-          stage,
-          beam_runner_api_pb2.PTransform(
-              unique_name=transform.unique_name + '/Merge',
-              spec=beam_runner_api_pb2.FunctionSpec(
-                  urn=common_urns.combine_components
-                  .COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
-                  payload=transform.spec.payload),
-              inputs={'in': grouped_pcoll_id},
-              outputs={'out': merged_pcoll_id}))
-
-      yield make_stage(
-          stage,
-          beam_runner_api_pb2.PTransform(
-              unique_name=transform.unique_name + '/ExtractOutputs',
-              spec=beam_runner_api_pb2.FunctionSpec(
-                  urn=common_urns.combine_components
-                  .COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
-                  payload=transform.spec.payload),
-              inputs={'in': merged_pcoll_id},
-              outputs=transform.outputs))
-
+      expansion = lifted_stages if can_lift(transform) else unlifted_stages
+      for substage in expansion(stage):
+        yield substage
     else:
       yield stage
 
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py
index 6de9e5c..8b078a5 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -20,13 +20,11 @@
 import functools
 import itertools
 import logging
-import sys
 import threading
 import time
 
 import grpc
 
-from apache_beam import version as beam_version
 from apache_beam.metrics import metric
 from apache_beam.metrics.execution import MetricResult
 from apache_beam.options.pipeline_options import DebugOptions
@@ -79,22 +77,6 @@
     self._dockerized_job_server = None
 
   @staticmethod
-  def default_docker_image():
-    sdk_version = beam_version.__version__
-    version_suffix = '.'.join([str(i) for i in sys.version_info[0:2]])
-    logging.warning('Make sure that locally built Python SDK docker image '
-                    'has Python %d.%d interpreter.' % (
-                        sys.version_info[0], sys.version_info[1]))
-
-    image = ('apachebeam/python{version_suffix}_sdk:{tag}'.format(
-        version_suffix=version_suffix, tag=sdk_version))
-    logging.info(
-        'Using Python SDK docker image: %s. If the image is not '
-        'available at local, we will try to pull from hub.docker.com'
-        % (image))
-    return image
-
-  @staticmethod
   def _create_environment(options):
     portable_options = options.view_as(PortableOptions)
     # Do not set a Runner. Otherwise this can cause problems in Java's
@@ -280,9 +262,12 @@
     prepare_response = job_service.Prepare(
         prepare_request,
         timeout=portable_options.job_server_timeout)
-    if prepare_response.artifact_staging_endpoint.url:
+    artifact_endpoint = (portable_options.artifact_endpoint
+                         if portable_options.artifact_endpoint
+                         else prepare_response.artifact_staging_endpoint.url)
+    if artifact_endpoint:
       stager = portable_stager.PortableStager(
-          grpc.insecure_channel(prepare_response.artifact_staging_endpoint.url),
+          grpc.insecure_channel(artifact_endpoint),
           prepare_response.staging_session_token)
       retrieval_token, _ = stager.stage_job_resources(
           options,
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
index e2de26a..46dbad5 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
@@ -296,7 +296,7 @@
 
 class PortableRunnerInternalTest(unittest.TestCase):
   def test__create_default_environment(self):
-    docker_image = PortableRunner.default_docker_image()
+    docker_image = environments.DockerEnvironment.default_docker_image()
     self.assertEqual(
         PortableRunner._create_environment(PipelineOptions.from_dictionary({})),
         environments.DockerEnvironment(container_image=docker_image))
@@ -353,7 +353,7 @@
 
 
 def hasDockerImage():
-  image = PortableRunner.default_docker_image()
+  image = environments.DockerEnvironment.default_docker_image()
   try:
     check_image = subprocess.check_output("docker images -q %s" % image,
                                           shell=True)
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index 8324e6b..26e4b60 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -173,7 +173,7 @@
 
   def __init__(self):
     self._to_send = queue.Queue()
-    self._received = collections.defaultdict(queue.Queue)
+    self._received = collections.defaultdict(lambda: queue.Queue(maxsize=5))
     self._receive_lock = threading.Lock()
     self._reads_finished = threading.Event()
     self._closed = False
@@ -267,7 +267,6 @@
         yield beam_fn_api_pb2.Elements(data=data)
 
   def _read_inputs(self, elements_iterator):
-    # TODO(robertwb): Pushback/throttling to avoid unbounded buffering.
     try:
       for elements in elements_iterator:
         for data in elements.data:
diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py
index 688d602..916faa4 100644
--- a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py
+++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py
@@ -41,6 +41,8 @@
 import apache_beam as beam
 from apache_beam.testing.benchmarks.nexmark.models import nexmark_model
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class Command(object):
   def __init__(self, cmd, args):
@@ -53,7 +55,7 @@
                     timeout, self.cmd.__name__)
 
       self.cmd(*self.args)
-      logging.info('%d seconds elapsed. Thread (%s) finished.',
+      _LOGGER.info('%d seconds elapsed. Thread (%s) finished.',
                    timeout, self.cmd.__name__)
 
     thread = threading.Thread(target=thread_target, name='Thread-timeout')
diff --git a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py
index c5c5259..2f22b02 100644
--- a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py
+++ b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py
@@ -72,6 +72,8 @@
     }
 ]
 
+_LOGGER = logging.getLogger(__name__)
+
 
 def parse_step(step_name):
   """Replaces white spaces and removes 'Step:' label
@@ -330,13 +332,13 @@
     if len(results) > 0:
       log = "Load test results for test: %s and timestamp: %s:" \
             % (results[0][ID_LABEL], results[0][SUBMIT_TIMESTAMP_LABEL])
-      logging.info(log)
+      _LOGGER.info(log)
       for result in results:
         log = "Metric: %s Value: %d" \
               % (result[METRICS_TYPE_LABEL], result[VALUE_LABEL])
-        logging.info(log)
+        _LOGGER.info(log)
     else:
-      logging.info("No test results were collected.")
+      _LOGGER.info("No test results were collected.")
 
 
 class BigQueryMetricsPublisher(object):
@@ -351,7 +353,7 @@
       for output in outputs:
         errors = output['errors']
         for err in errors:
-          logging.error(err['message'])
+          _LOGGER.error(err['message'])
           raise ValueError(
               'Unable save rows in BigQuery: {}'.format(err['message']))
 
diff --git a/sdks/python/apache_beam/testing/pipeline_verifiers.py b/sdks/python/apache_beam/testing/pipeline_verifiers.py
index 1178672..cf99541 100644
--- a/sdks/python/apache_beam/testing/pipeline_verifiers.py
+++ b/sdks/python/apache_beam/testing/pipeline_verifiers.py
@@ -48,6 +48,8 @@
 
 MAX_RETRIES = 4
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class PipelineStateMatcher(BaseMatcher):
   """Matcher that verify pipeline job terminated in expected state
@@ -121,7 +123,7 @@
     if not matched_path:
       raise IOError('No such file or directory: %s' % self.file_path)
 
-    logging.info('Find %d files in %s: \n%s',
+    _LOGGER.info('Find %d files in %s: \n%s',
                  len(matched_path), self.file_path, '\n'.join(matched_path))
     for path in matched_path:
       with FileSystems.open(path, 'r') as f:
@@ -132,7 +134,7 @@
   def _matches(self, _):
     if self.sleep_secs:
       # Wait to have output file ready on FS
-      logging.info('Wait %d seconds...', self.sleep_secs)
+      _LOGGER.info('Wait %d seconds...', self.sleep_secs)
       time.sleep(self.sleep_secs)
 
     # Read from given file(s) path
@@ -140,7 +142,7 @@
 
     # Compute checksum
     self.checksum = utils.compute_hash(read_lines)
-    logging.info('Read from given path %s, %d lines, checksum: %s.',
+    _LOGGER.info('Read from given path %s, %d lines, checksum: %s.',
                  self.file_path, len(read_lines), self.checksum)
     return self.checksum == self.expected_checksum
 
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 06fd201..6c48b23 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -101,6 +101,8 @@
 K = typing.TypeVar('K')
 V = typing.TypeVar('V')
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class DoFnContext(object):
   """A context available to all methods of DoFn instance."""
@@ -377,16 +379,16 @@
       else:
         env1 = id_to_proto_map[env_id]
         env2 = context.environments[env_id]
-        assert env1.urn == env2.proto.urn, (
+        assert env1.urn == env2.to_runner_api(context).urn, (
             'Expected environments with the same ID to be equal but received '
             'environments with different URNs '
             '%r and %r',
-            env1.urn, env2.proto.urn)
-        assert env1.payload == env2.proto.payload, (
+            env1.urn, env2.to_runner_api(context).urn)
+        assert env1.payload == env2.to_runner_api(context).payload, (
             'Expected environments with the same ID to be equal but received '
             'environments with different payloads '
             '%r and %r',
-            env1.payload, env2.proto.payload)
+            env1.payload, env2.to_runner_api(context).payload)
     return self._proto
 
   def get_restriction_coder(self):
@@ -502,7 +504,7 @@
       try:
         callback()
       except Exception as e:
-        logging.warning("Got exception from finalization call: %s", e)
+        _LOGGER.warning("Got exception from finalization call: %s", e)
 
   def has_callbacks(self):
     return len(self._callbacks) > 0
@@ -747,7 +749,7 @@
       type_hints = type_hints.strip_iterable()
     except ValueError as e:
       # TODO(BEAM-8466): Raise exception here if using stricter type checking.
-      logging.warning('%s: %s', self.display_data()['fn'].value, e)
+      _LOGGER.warning('%s: %s', self.display_data()['fn'].value, e)
     return type_hints
 
   def infer_output_type(self, input_type):
@@ -1218,7 +1220,7 @@
         key_coder = coders.registry.get_coder(typehints.Any)
 
       if not key_coder.is_deterministic():
-        logging.warning(
+        _LOGGER.warning(
             'Key coder %s for transform %s with stateful DoFn may not '
             'be deterministic. This may cause incorrect behavior for complex '
             'key types. Consider adding an input type hint for this transform.',
diff --git a/sdks/python/apache_beam/transforms/environments.py b/sdks/python/apache_beam/transforms/environments.py
index 8758ab8..999647f 100644
--- a/sdks/python/apache_beam/transforms/environments.py
+++ b/sdks/python/apache_beam/transforms/environments.py
@@ -22,6 +22,8 @@
 from __future__ import absolute_import
 
 import json
+import logging
+import sys
 
 from google.protobuf import message
 
@@ -119,12 +121,10 @@
 class DockerEnvironment(Environment):
 
   def __init__(self, container_image=None):
-    from apache_beam.runners.portability.portable_runner import PortableRunner
-
     if container_image:
       self.container_image = container_image
     else:
-      self.container_image = PortableRunner.default_docker_image()
+      self.container_image = self.default_docker_image()
 
   def __eq__(self, other):
     return self.__class__ == other.__class__ \
@@ -153,6 +153,24 @@
   def from_options(cls, options):
     return cls(container_image=options.environment_config)
 
+  @staticmethod
+  def default_docker_image():
+    from apache_beam import version as beam_version
+
+    sdk_version = beam_version.__version__
+    version_suffix = '.'.join([str(i) for i in sys.version_info[0:2]])
+    logging.warning('Make sure that locally built Python SDK docker image '
+                    'has Python %d.%d interpreter.' % (
+                        sys.version_info[0], sys.version_info[1]))
+
+    image = ('apachebeam/python{version_suffix}_sdk:{tag}'.format(
+        version_suffix=version_suffix, tag=sdk_version))
+    logging.info(
+        'Using Python SDK docker image: %s. If the image is not '
+        'available at local, we will try to pull from hub.docker.com'
+        % (image))
+    return image
+
 
 @Environment.register_urn(common_urns.environments.PROCESS.urn,
                           beam_runner_api_pb2.ProcessPayload)
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index 9761da2..e5bc20d 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -67,6 +67,9 @@
     ]
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class AccumulationMode(object):
   """Controls what to do with data when a trigger fires multiple times."""
   DISCARDING = beam_runner_api_pb2.AccumulationMode.DISCARDING
@@ -1190,7 +1193,7 @@
             window, self.NONSPECULATIVE_INDEX)
         state.add_state(window, self.NONSPECULATIVE_INDEX, 1)
         windowed_value.PaneInfoTiming.LATE
-        logging.warning('Watermark moved backwards in time '
+        _LOGGER.warning('Watermark moved backwards in time '
                         'or late data moved window end forward.')
     else:
       nonspeculative_index = state.get_state(window, self.NONSPECULATIVE_INDEX)
@@ -1320,7 +1323,7 @@
         elif time_domain == TimeDomain.WATERMARK:
           time_marker = watermark
         else:
-          logging.error(
+          _LOGGER.error(
               'TimeDomain error: No timers defined for time domain %s.',
               time_domain)
         if timestamp <= time_marker:
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index 22ecda3..d1e5433 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -41,6 +41,7 @@
 from apache_beam.testing.test_stream import TestStream
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms import ptransform
 from apache_beam.transforms import trigger
 from apache_beam.transforms.core import Windowing
 from apache_beam.transforms.trigger import AccumulationMode
@@ -632,6 +633,54 @@
   }
 
 
+def _windowed_value_info_map_fn(
+    k, vs,
+    window=beam.DoFn.WindowParam,
+    t=beam.DoFn.TimestampParam,
+    p=beam.DoFn.PaneInfoParam):
+  return (
+      k,
+      _windowed_value_info(WindowedValue(
+          vs, windows=[window], timestamp=t, pane_info=p)))
+
+
+def _windowed_value_info_check(actual, expected):
+
+  def format(panes):
+    return '\n[%s]\n' % '\n '.join(str(pane) for pane in sorted(
+        panes, key=lambda pane: pane.get('timestamp', None)))
+
+  if len(actual) > len(expected):
+    raise AssertionError(
+        'Unexpected output: expected %s but got %s' % (
+            format(expected), format(actual)))
+  elif len(expected) > len(actual):
+    raise AssertionError(
+        'Unmatched output: expected %s but got %s' % (
+            format(expected), format(actual)))
+  else:
+
+    def diff(actual, expected):
+      for key in sorted(expected.keys(), reverse=True):
+        if key in actual:
+          if actual[key] != expected[key]:
+            return key
+
+    for output in actual:
+      diffs = [diff(output, pane) for pane in expected]
+      if all(diffs):
+        raise AssertionError(
+            'Unmatched output: %s not found in %s (diffs in %s)' % (
+                output, format(expected), diffs))
+
+
+class _ConcatCombineFn(beam.CombineFn):
+  create_accumulator = lambda self: []
+  add_input = lambda self, acc, element: acc.append(element) or acc
+  merge_accumulators = lambda self, accs: sum(accs, [])
+  extract_output = lambda self, acc: acc
+
+
 class TriggerDriverTranscriptTest(TranscriptTest):
 
   def _execute(
@@ -782,42 +831,23 @@
         elif action == 'expect':
           actual = list(seen.read())
           seen.clear()
-
-          if len(actual) > len(data):
-            raise AssertionError(
-                'Unexpected output: expected %s but got %s' % (data, actual))
-          elif len(data) > len(actual):
-            raise AssertionError(
-                'Unmatched output: expected %s but got %s' % (data, actual))
-          else:
-
-            def diff(actual, expected):
-              for key in sorted(expected.keys(), reverse=True):
-                if key in actual:
-                  if actual[key] != expected[key]:
-                    return key
-
-            for output in actual:
-              diffs = [diff(output, expected) for expected in data]
-              if all(diffs):
-                raise AssertionError(
-                    'Unmatched output: %s not found in %s (diffs in %s)' % (
-                        output, data, diffs))
+          _windowed_value_info_check(actual, data)
 
         else:
           raise ValueError('Unexpected action: %s' % action)
 
-    with TestPipeline() as p:
-      # TODO(BEAM-8601): Pass this during pipeline construction.
-      p.options.view_as(StandardOptions).streaming = True
+    @ptransform.ptransform_fn
+    def CheckAggregation(inputs_and_expected, aggregation):
       # Split the test stream into a branch of to-be-processed elements, and
       # a branch of expected results.
       inputs, expected = (
-          p
-          | read_test_stream
-          | beam.MapTuple(
-              lambda tag, value: beam.pvalue.TaggedOutput(tag, ('key', value))
-              ).with_outputs('input', 'expect'))
+          inputs_and_expected
+          | beam.FlatMapTuple(
+              lambda tag, value: [
+                  beam.pvalue.TaggedOutput(tag, ('key1', value)),
+                  beam.pvalue.TaggedOutput(tag, ('key2', value)),
+              ]).with_outputs('input', 'expect'))
+
       # Process the inputs with the given windowing to produce actual outputs.
       outputs = (
           inputs
@@ -828,15 +858,8 @@
               trigger=trigger_fn,
               accumulation_mode=accumulation_mode,
               timestamp_combiner=timestamp_combiner)
-          | beam.GroupByKey()
-          | beam.MapTuple(
-              lambda k, vs,
-                     window=beam.DoFn.WindowParam,
-                     t=beam.DoFn.TimestampParam,
-                     p=beam.DoFn.PaneInfoParam: (
-                         k,
-                         _windowed_value_info(WindowedValue(
-                             vs, windows=[window], timestamp=t, pane_info=p))))
+          | aggregation
+          | beam.MapTuple(_windowed_value_info_map_fn)
           # Place outputs back into the global window to allow flattening
           # and share a single state in Check.
           | 'Global' >> beam.WindowInto(beam.transforms.window.GlobalWindows()))
@@ -850,6 +873,16 @@
        | beam.Flatten()
        | beam.ParDo(Check(self.allow_out_of_order)))
 
+    with TestPipeline() as p:
+      # TODO(BEAM-8601): Pass this during pipeline construction.
+      p.options.view_as(StandardOptions).streaming = True
+
+      # We can have at most one test stream per pipeline, so we share it.
+      inputs_and_expected = p | read_test_stream
+      _ = inputs_and_expected | CheckAggregation(beam.GroupByKey())
+      _ = inputs_and_expected | CheckAggregation(beam.CombinePerKey(
+          _ConcatCombineFn()))
+
 
 class TestStreamTranscriptTest(BaseTestStreamTranscriptTest):
   allow_out_of_order = False
@@ -859,6 +892,83 @@
   allow_out_of_order = True
 
 
+class BatchTranscriptTest(TranscriptTest):
+  """Replays each transcript as a batch pipeline; checks final panes only."""
+
+  def _execute(
+      self, window_fn, trigger_fn, accumulation_mode, timestamp_combiner,
+      transcript, spec):
+    # Transcripts that cannot be expressed in this batch form are skipped.
+    if timestamp_combiner == TimestampCombiner.OUTPUT_AT_EARLIEST_TRANSFORMED:
+      self.skipTest(
+          'Non-fnapi timestamp combiner: %s' % spec.get('timestamp_combiner'))
+    if accumulation_mode != AccumulationMode.ACCUMULATING:
+      self.skipTest('Batch mode only makes sense for accumulating.')
+
+    # Any input at or before the watermark seen so far counts as late data.
+    watermark = MIN_TIMESTAMP
+    for action, params in transcript:
+      if action == 'watermark':
+        watermark = params
+      elif action == 'input' and any(t <= watermark for t in params):
+        self.skipTest('Batch mode never has late data.')
+
+    inputs = [
+        t for action, vs in transcript if action == 'input' for t in vs]
+
+    # Keep, per window, the last expected pane, trimmed to a fixed field set.
+    final_panes_by_window = {}
+    for action, params in transcript:
+      if action != 'expect':
+        continue
+      for expected in params:
+        final_panes_by_window[tuple(expected['window'])] = {
+            field: expected[field]
+            for field in ('window', 'values', 'timestamp')
+            if field in expected}
+    final_panes = list(final_panes_by_window.values())
+
+    if window_fn.is_merging():
+      # Drop panes of windows that merging folds into a different window.
+      merged_away = set()
+      class MergeContext(WindowFn.MergeContext):
+        def merge(self, to_be_merged, merge_result):
+          merged_away.update(w for w in to_be_merged if w != merge_result)
+      window_fn.merge(
+          MergeContext([IntervalWindow(*pn['window']) for pn in final_panes]))
+      final_panes = [
+          pn for pn in final_panes
+          if IntervalWindow(*pn['window']) not in merged_away]
+
+    with TestPipeline() as p:
+      windowed = (
+          p
+          | beam.Create(inputs)
+          | beam.Map(lambda t: TimestampedValue(('key', t), t))
+          | beam.WindowInto(
+              window_fn,
+              trigger=trigger_fn,
+              accumulation_mode=accumulation_mode,
+              timestamp_combiner=timestamp_combiner))
+
+      def pane_infos(aggregation):
+        # Shared post-processing; the last step keeps only the value.
+        return (
+            aggregation
+            | beam.MapTuple(_windowed_value_info_map_fn)
+            | beam.MapTuple(lambda _, value: value))
+
+      def check_final_panes(actual):
+        return _windowed_value_info_check(actual, final_panes)
+
+      grouped = windowed | 'Grouped' >> pane_infos(beam.GroupByKey())
+      combined = windowed | 'Combined' >> pane_infos(
+          beam.CombinePerKey(_ConcatCombineFn()))
+
+      assert_that(grouped, check_final_panes, label='CheckGrouped')
+      assert_that(combined, check_final_panes, label='CheckCombined')
+
+
 TRANSCRIPT_TEST_FILE = os.path.join(
     os.path.dirname(__file__), '..', 'testing', 'data',
     'trigger_transcripts.yaml')
@@ -866,6 +976,7 @@
   TriggerDriverTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
   TestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
   WeakTestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
+  BatchTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/transforms/window_test.py b/sdks/python/apache_beam/transforms/window_test.py
index 0d5e14f..3c45d83 100644
--- a/sdks/python/apache_beam/transforms/window_test.py
+++ b/sdks/python/apache_beam/transforms/window_test.py
@@ -286,21 +286,21 @@
   @attr('ValidatesRunner')
   def test_window_assignment_idempotency(self):
     with TestPipeline() as p:
-      pcoll = self.timestamped_key_values(p, 'key', 0, 1, 2, 3, 4)
+      pcoll = self.timestamped_key_values(p, 'key', 0, 2, 4)
       result = (pcoll
                 | 'window' >> WindowInto(FixedWindows(2))
                 | 'same window' >> WindowInto(FixedWindows(2))
                 | 'same window again' >> WindowInto(FixedWindows(2))
                 | GroupByKey())
 
-      assert_that(result, equal_to([('key', [0, 1]),
-                                    ('key', [2, 3]),
+      assert_that(result, equal_to([('key', [0]),
+                                    ('key', [2]),
                                     ('key', [4])]))
 
   @attr('ValidatesRunner')
   def test_window_assignment_through_multiple_gbk_idempotency(self):
     with TestPipeline() as p:
-      pcoll = self.timestamped_key_values(p, 'key', 0, 1, 2, 3, 4)
+      pcoll = self.timestamped_key_values(p, 'key', 0, 2, 4)
       result = (pcoll
                 | 'window' >> WindowInto(FixedWindows(2))
                 | 'gbk' >> GroupByKey()
@@ -309,8 +309,8 @@
                 | 'same window again' >> WindowInto(FixedWindows(2))
                 | 'gbk again' >> GroupByKey())
 
-      assert_that(result, equal_to([('key', [[[0, 1]]]),
-                                    ('key', [[[2, 3]]]),
+      assert_that(result, equal_to([('key', [[[0]]]),
+                                    ('key', [[[2]]]),
                                     ('key', [[[4]]])]))
 
 class RunnerApiTest(unittest.TestCase):
diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py
index 6062e6f..b64e020 100644
--- a/sdks/python/apache_beam/typehints/typehints.py
+++ b/sdks/python/apache_beam/typehints/typehints.py
@@ -96,6 +96,8 @@
 # to templated (upper-case) versions instead.
 DISALLOWED_PRIMITIVE_TYPES = (list, set, tuple, dict)
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class SimpleTypeHintError(TypeError):
   pass
@@ -1086,9 +1088,9 @@
     if isinstance(type_params, tuple) and len(type_params) == 3:
       yield_type, send_type, return_type = type_params
       if send_type is not None:
-        logging.warning('Ignoring send_type hint: %s' % send_type)
+        _LOGGER.warning('Ignoring send_type hint: %s' % send_type)
-      if send_type is not None:
-        logging.warning('Ignoring return_type hint: %s' % return_type)
+      if return_type is not None:
+        _LOGGER.warning('Ignoring return_type hint: %s' % return_type)
     else:
       yield_type = type_params
     return self.IteratorTypeConstraint(yield_type)
diff --git a/sdks/python/apache_beam/utils/profiler.py b/sdks/python/apache_beam/utils/profiler.py
index 0606744..c6f7295 100644
--- a/sdks/python/apache_beam/utils/profiler.py
+++ b/sdks/python/apache_beam/utils/profiler.py
@@ -36,6 +36,8 @@
 
 from apache_beam.io import filesystems
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class Profile(object):
   """cProfile wrapper context for saving and logging profiler results."""
@@ -53,14 +55,14 @@
     self.profile_output = None
 
   def __enter__(self):
-    logging.info('Start profiling: %s', self.profile_id)
+    _LOGGER.info('Start profiling: %s', self.profile_id)
     self.profile = cProfile.Profile()
     self.profile.enable()
     return self
 
   def __exit__(self, *args):
     self.profile.disable()
-    logging.info('Stop profiling: %s', self.profile_id)
+    _LOGGER.info('Stop profiling: %s', self.profile_id)
 
     if self.profile_location:
       dump_location = os.path.join(
@@ -70,7 +72,7 @@
       try:
         os.close(fd)
         self.profile.dump_stats(filename)
-        logging.info('Copying profiler data to: [%s]', dump_location)
+        _LOGGER.info('Copying profiler data to: [%s]', dump_location)
         self.file_copy_fn(filename, dump_location)
       finally:
         os.remove(filename)
@@ -81,7 +83,7 @@
       self.stats = pstats.Stats(
           self.profile, stream=s).sort_stats(Profile.SORTBY)
       self.stats.print_stats()
-      logging.info('Profiler data: [%s]', s.getvalue())
+      _LOGGER.info('Profiler data: [%s]', s.getvalue())
 
   @staticmethod
   def default_file_copy_fn(src, dest):
@@ -176,5 +178,5 @@
       return
     report_start_time = time.time()
     heap_profile = self._hpy().heap()
-    logging.info('*** MemoryReport Heap:\n %s\n MemoryReport took %.1f seconds',
+    _LOGGER.info('*** MemoryReport Heap:\n %s\n MemoryReport took %.1f seconds',
                  heap_profile, time.time() - report_start_time)
diff --git a/sdks/python/apache_beam/utils/retry.py b/sdks/python/apache_beam/utils/retry.py
index 59d8dec..e34e364 100644
--- a/sdks/python/apache_beam/utils/retry.py
+++ b/sdks/python/apache_beam/utils/retry.py
@@ -51,6 +51,9 @@
 # pylint: enable=wrong-import-order, wrong-import-position
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class PermanentException(Exception):
   """Base class for exceptions that should not be retried."""
   pass
@@ -153,7 +156,7 @@
 
 
 def with_exponential_backoff(
-    num_retries=7, initial_delay_secs=5.0, logger=logging.warning,
+    num_retries=7, initial_delay_secs=5.0, logger=_LOGGER.warning,
     retry_filter=retry_on_server_errors_filter,
     clock=Clock(), fuzz=True, factor=2, max_delay_secs=60 * 60):
   """Decorator with arguments that control the retry logic.
@@ -163,7 +166,7 @@
     initial_delay_secs: The delay before the first retry, in seconds.
     logger: A callable used to report an exception. Must have the same signature
       as functions in the standard logging module. The default is
-      logging.warning.
+      the warning method of this module's logger (_LOGGER.warning).
     retry_filter: A callable getting the exception raised and returning True
       if the retry should happen. For instance we do not want to retry on
       404 Http errors most of the time. The default value will return true
diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py
index 65dbcae..fd55f18 100644
--- a/sdks/python/apache_beam/utils/subprocess_server.py
+++ b/sdks/python/apache_beam/utils/subprocess_server.py
@@ -33,6 +33,8 @@
 
 from apache_beam.version import __version__ as beam_version
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class SubprocessServer(object):
   """An abstract base class for running GRPC Servers as an external process.
@@ -78,7 +80,7 @@
         port, = pick_port(None)
         cmd = [arg.replace('{{PORT}}', str(port)) for arg in self._cmd]
       endpoint = 'localhost:%s' % port
-      logging.warning("Starting service with %s", str(cmd).replace("',", "'"))
+      _LOGGER.warning("Starting service with %s", str(cmd).replace("',", "'"))
       try:
         self._process = subprocess.Popen(cmd)
         wait_secs = .1
@@ -86,7 +88,7 @@
         channel_ready = grpc.channel_ready_future(channel)
         while True:
           if self._process.poll() is not None:
-            logging.error("Starting job service with %s", cmd)
+            _LOGGER.error("Starting job service with %s", cmd)
             raise RuntimeError(
                 'Service failed to start up with error %s' %
                 self._process.poll())
@@ -100,7 +102,7 @@
                         endpoint)
         return self._stub_class(channel)
       except:  # pylint: disable=bare-except
-        logging.exception("Error bringing up service")
+        _LOGGER.exception("Error bringing up service")
         self.stop()
         raise
 
@@ -169,7 +171,7 @@
             classifier='SNAPSHOT',
             appendix=appendix))
     if os.path.exists(local_path):
-      logging.info('Using pre-built snapshot at %s', local_path)
+      _LOGGER.info('Using pre-built snapshot at %s', local_path)
       return local_path
     elif '.dev' in beam_version:
       # TODO: Attempt to use nightly snapshots?
@@ -187,7 +189,7 @@
     if os.path.exists(url):
       return url
     else:
-      logging.warning('Downloading job server jar from %s' % url)
+      _LOGGER.warning('Downloading job server jar from %s' % url)
       cached_jar = os.path.join(cls.JAR_CACHE, os.path.basename(url))
       if not os.path.exists(cached_jar):
         if not os.path.exists(cls.JAR_CACHE):
