SAMZA-2751: RunLoop and High-Level API changes for Drain (#1616)

diff --git a/gradle.properties b/gradle.properties
index 1a6ee97..cab76e0 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -16,7 +16,7 @@
 # under the License.
 group=org.apache.samza
 version=1.7.0-SNAPSHOT
-scalaSuffix=2.11
+scalaSuffix=2.12
 
 # after changing this value, run `$ ./gradlew wrapper` and commit the resulting changed files
 gradleVersion=5.2.1
diff --git a/samza-api/src/main/java/org/apache/samza/system/DrainMessage.java b/samza-api/src/main/java/org/apache/samza/system/DrainMessage.java
new file mode 100644
index 0000000..4373d51
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/system/DrainMessage.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.system;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/**
+ * The DrainMessage is a control message that is sent to the next stage
+ * once the task has consumed all messages up to the drain point of a stream.
+ */
+public class DrainMessage extends ControlMessage {
+  /**
+   * Id used to invalidate DrainMessages between runs. Ties to app.run.id from config.
+   */
+  private final String runId;
+
+  public DrainMessage(String runId) {
+    this(null, runId);
+  }
+
+  public DrainMessage(@JsonProperty("task-name") String taskName, @JsonProperty("run-id") String runId) {
+    super(taskName);
+    this.runId = runId;
+  }
+
+  public String getRunId() {
+    return runId;
+  }
+
+  @Override
+  public int hashCode() {
+    final int prime = 31;
+    final int result = prime * super.hashCode() + (this.runId != null ? this.runId.hashCode() : 0);
+    return result;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj)
+      return true;
+    if (obj == null)
+      return false;
+    if (getClass() != obj.getClass())
+      return false;
+
+    final DrainMessage other = (DrainMessage) obj;
+    if (!super.equals(other)) {
+      return false;
+    }
+    if (runId != null ? !runId.equals(other.runId) : other.runId != null) {
+      return false;
+    }
+    return true;
+  }
+}
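For context, a minimal sketch of how the run id on DrainMessage is meant to be used: consumers compare it against the live app.run.id and drop messages left over from a previous deployment. The run-id values below are illustrative, not from this patch.

    import org.apache.samza.system.DrainMessage;

    public class DrainRunIdSketch {
      public static void main(String[] args) {
        String liveRunId = "run-2"; // app.run.id of the current deployment (illustrative)
        DrainMessage fromOldRun = new DrainMessage("task-0", "run-1");
        // A message whose run id does not match the live run id is stale and gets dropped.
        System.out.println(liveRunId.equals(fromOldRun.getRunId())); // false
      }
    }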
diff --git a/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java b/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java
index 2112463..1f4a740 100644
--- a/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java
+++ b/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java
@@ -132,7 +132,6 @@
     if (envelopeKeyorOffset == null) {
       return new SystemStreamPartition(systemStreamPartition, 0);
     }
-
     // modulo 31 first to best spread out the hashcode and then modulo elasticityFactor for actual keyBucket
     // Note: elasticityFactor <= 16 so modulo 31 is safe to do.
     int keyBucket = (Math.abs(envelopeKeyorOffset.hashCode()) % 31) % elasticityFactor;
@@ -162,6 +161,10 @@
     return END_OF_STREAM_OFFSET.equals(offset);
   }
 
+  /**
+   * Checks whether this envelope carries a {@link DrainMessage} payload.
+   */
+  public boolean isDrain() {
+    return message != null && DrainMessage.class.isAssignableFrom(message.getClass());
+  }
+
   /**
    * This method is deprecated in favor of WatermarkManager.buildEndOfStreamEnvelope(SystemStreamPartition ssp).
    *
@@ -172,6 +175,10 @@
     return new IncomingMessageEnvelope(ssp, END_OF_STREAM_OFFSET, null, new EndOfStreamMessage(null));
   }
 
+  public static IncomingMessageEnvelope buildDrainMessage(SystemStreamPartition ssp, String runId) {
+    return new IncomingMessageEnvelope(ssp, null, null, new DrainMessage(runId));
+  }
+
   public static IncomingMessageEnvelope buildWatermarkEnvelope(SystemStreamPartition ssp, long watermark) {
     return new IncomingMessageEnvelope(ssp, null, null, new WatermarkMessage(watermark, null));
   }
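A minimal sketch of how the new drain helpers on IncomingMessageEnvelope behave; the system, stream and partition names are illustrative.

    import org.apache.samza.Partition;
    import org.apache.samza.system.IncomingMessageEnvelope;
    import org.apache.samza.system.SystemStreamPartition;

    public class DrainEnvelopeSketch {
      public static void main(String[] args) {
        SystemStreamPartition ssp = new SystemStreamPartition("kafka", "input", new Partition(0));
        // Drain envelopes carry a DrainMessage payload and a null offset.
        IncomingMessageEnvelope envelope = IncomingMessageEnvelope.buildDrainMessage(ssp, "run-1");
        System.out.println(envelope.isDrain());       // true: keyed off the payload type
        System.out.println(envelope.isEndOfStream()); // false: no END_OF_STREAM_OFFSET sentinel
      }
    }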
diff --git a/samza-api/src/main/java/org/apache/samza/system/MessageType.java b/samza-api/src/main/java/org/apache/samza/system/MessageType.java
index 7129d00..9f58621 100644
--- a/samza-api/src/main/java/org/apache/samza/system/MessageType.java
+++ b/samza-api/src/main/java/org/apache/samza/system/MessageType.java
@@ -26,7 +26,8 @@
 public enum MessageType {
   USER_MESSAGE,
   WATERMARK,
-  END_OF_STREAM;
+  END_OF_STREAM,
+  DRAIN;
 
   /**
    * Returns the {@link MessageType} of a particular intermediate stream message.
@@ -38,6 +39,8 @@
       return WATERMARK;
     } else if (message instanceof EndOfStreamMessage) {
       return END_OF_STREAM;
+    } else if (message instanceof DrainMessage) {
+      return DRAIN;
     } else {
       return USER_MESSAGE;
     }
diff --git a/samza-api/src/main/java/org/apache/samza/task/DrainListenerTask.java b/samza-api/src/main/java/org/apache/samza/task/DrainListenerTask.java
new file mode 100644
index 0000000..c5ca7d5
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/task/DrainListenerTask.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+/**
+ * The DrainListenerTask augments {@link StreamTask}, allowing the implementor to specify code to be
+ * executed when drain is reached for a task.
+ */
+public interface DrainListenerTask {
+  /**
+   * Guaranteed to be invoked when all SSPs processed by this task have drained.
+   *
+   * @param collector Contains the means of sending message envelopes to an output stream.
+   * @param coordinator Manages execution of tasks.
+   *
+   * @throws Exception Any exception types encountered during the execution of the processing task.
+   */
+  void onDrain(MessageCollector collector, TaskCoordinator coordinator) throws Exception;
+}
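As a usage sketch, a low-level task opts into the drain callback by implementing DrainListenerTask alongside StreamTask; the class name and method bodies below are hypothetical.

    import org.apache.samza.system.IncomingMessageEnvelope;
    import org.apache.samza.task.DrainListenerTask;
    import org.apache.samza.task.MessageCollector;
    import org.apache.samza.task.StreamTask;
    import org.apache.samza.task.TaskCoordinator;

    public class MyDrainAwareTask implements StreamTask, DrainListenerTask {
      @Override
      public void process(IncomingMessageEnvelope envelope, MessageCollector collector,
          TaskCoordinator coordinator) {
        // regular per-message processing
      }

      @Override
      public void onDrain(MessageCollector collector, TaskCoordinator coordinator) {
        // flush any buffered output or state before the task shuts down
      }
    }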
diff --git a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
index 145f4cd..d36ac66 100644
--- a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
@@ -166,9 +166,12 @@
 
   // Enable DrainMonitor in Samza Containers
   // Default is false for now. Will be turned on after testing
-  public static final String DRAIN_MONITOR_ENABLED = "samza.drain-monitor.enabled";
+  public static final String DRAIN_MONITOR_ENABLED = "job.drain-monitor.enabled";
   public static final boolean DRAIN_MONITOR_ENABLED_DEFAULT = false;
 
+  public static final String DRAIN_MONITOR_POLL_INTERVAL_MILLIS = "job.drain-monitor.poll.interval.ms";
+  public static final long DRAIN_MONITOR_POLL_INTERVAL_MILLIS_DEFAULT = 60_000;
+
   // Enable ClusterBasedJobCoordinator aka ApplicationMaster High Availability (AM-HA).
   // High availability allows new AM to establish connection with already running containers
   public static final String YARN_AM_HIGH_AVAILABILITY_ENABLED = "yarn.am.high-availability.enabled";
@@ -479,6 +482,10 @@
     return getBoolean(DRAIN_MONITOR_ENABLED, DRAIN_MONITOR_ENABLED_DEFAULT);
   }
 
+  public long getDrainMonitorPollIntervalMillis() {
+    return getLong(DRAIN_MONITOR_POLL_INTERVAL_MILLIS, DRAIN_MONITOR_POLL_INTERVAL_MILLIS_DEFAULT);
+  }
+
   public long getContainerHeartbeatRetryCount() {
     return getLong(YARN_CONTAINER_HEARTBEAT_RETRY_COUNT, YARN_CONTAINER_HEARTBEAT_RETRY_COUNT_DEFAULT);
   }
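A sketch of wiring the new keys through config; the interval value is illustrative.

    import java.util.HashMap;
    import java.util.Map;
    import org.apache.samza.config.JobConfig;
    import org.apache.samza.config.MapConfig;

    public class DrainConfigSketch {
      public static void main(String[] args) {
        Map<String, String> props = new HashMap<>();
        props.put(JobConfig.DRAIN_MONITOR_ENABLED, "true");               // job.drain-monitor.enabled
        props.put(JobConfig.DRAIN_MONITOR_POLL_INTERVAL_MILLIS, "30000"); // job.drain-monitor.poll.interval.ms
        JobConfig jobConfig = new JobConfig(new MapConfig(props));
        System.out.println(jobConfig.getDrainMonitorEnabled());            // true
        System.out.println(jobConfig.getDrainMonitorPollIntervalMillis()); // 30000
      }
    }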
diff --git a/samza-core/src/main/java/org/apache/samza/config/StreamConfig.java b/samza-core/src/main/java/org/apache/samza/config/StreamConfig.java
index 9f774b8..9fc74a7 100644
--- a/samza-core/src/main/java/org/apache/samza/config/StreamConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/StreamConfig.java
@@ -167,7 +167,7 @@
    * @param systemStream system stream to map to stream id
    * @return             stream id corresponding to the system stream
    */
-  private String systemStreamToStreamId(SystemStream systemStream) {
+  public String systemStreamToStreamId(SystemStream systemStream) {
     List<String> streamIds = getStreamIdsForSystem(systemStream.getSystem()).stream()
       .filter(streamId -> systemStream.getStream().equals(getPhysicalName(streamId))).collect(Collectors.toList());
     if (streamIds.size() > 1) {
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoop.java b/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
index b7a80a3..5bbcbae 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
@@ -24,6 +24,7 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -35,6 +36,7 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 import org.apache.samza.SamzaException;
+import org.apache.samza.system.DrainMessage;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.MessageType;
 import org.apache.samza.system.SystemConsumers;
@@ -70,7 +72,6 @@
   private final List<AsyncTaskWorker> taskWorkers;
   private final SystemConsumers consumerMultiplexer;
   private final Map<SystemStreamPartition, List<AsyncTaskWorker>> sspToTaskWorkerMapping;
-
   private final ExecutorService threadPool;
   private final CoordinatorRequests coordinatorRequests;
   private final Object latch;
@@ -86,9 +87,11 @@
   private volatile boolean shutdownNow = false;
   private volatile Throwable throwable = null;
   private final HighResolutionClock clock;
-  private final boolean isAsyncCommitEnabled;
+  private boolean isAsyncCommitEnabled;
   private volatile boolean runLoopResumedSinceLastChecked;
   private final int elasticityFactor;
+  private final String runId;
+  // Whether the run loop is in drain mode; read across worker threads, hence volatile.
+  private volatile boolean isDraining = false;
 
   public RunLoop(Map<TaskName, RunLoopTask> runLoopTasks,
       ExecutorService threadPool,
@@ -103,7 +106,7 @@
       HighResolutionClock clock,
       boolean isAsyncCommitEnabled) {
     this(runLoopTasks, threadPool, consumerMultiplexer, maxConcurrency, windowMs, commitMs, callbackTimeoutMs,
-        maxThrottlingDelayMs, maxIdleMs, containerMetrics, clock, isAsyncCommitEnabled, 1);
+        maxThrottlingDelayMs, maxIdleMs, containerMetrics, clock, isAsyncCommitEnabled, 1, null);
   }
 
   public RunLoop(Map<TaskName, RunLoopTask> runLoopTasks,
@@ -118,7 +121,8 @@
       SamzaContainerMetrics containerMetrics,
       HighResolutionClock clock,
       boolean isAsyncCommitEnabled,
-      int elasticityFactor) {
+      int elasticityFactor,
+      String runId) {
 
     this.threadPool = threadPool;
     this.consumerMultiplexer = consumerMultiplexer;
@@ -134,18 +138,30 @@
     this.latch = new Object();
     this.workerTimer = Executors.newSingleThreadScheduledExecutor();
     this.clock = clock;
-    Map<TaskName, AsyncTaskWorker> workers = new HashMap<>();
+    // Assign runId before creating workers. As the inner AsyncTaskWorker class is not static, it relies on
+    // the outer class fields being initialized first.
+    this.runId = runId;
+    Map<TaskName, AsyncTaskWorker> workers = new HashMap<>();
     for (RunLoopTask task : runLoopTasks.values()) {
       workers.put(task.taskName(), new AsyncTaskWorker(task));
     }
     // Partitions and tasks assigned to the container will not change during the run loop lifetime
     this.sspToTaskWorkerMapping = Collections.unmodifiableMap(getSspToAsyncTaskWorkerMap(runLoopTasks, workers));
+
     this.taskWorkers = Collections.unmodifiableList(new ArrayList<>(workers.values()));
     this.isAsyncCommitEnabled = isAsyncCommitEnabled;
     this.elasticityFactor = elasticityFactor;
   }
 
   /**
+   * Sets the RunLoop to drain mode.
+   */
+  private void drain() {
+    isDraining = true;
+    isAsyncCommitEnabled = false;
+  }
+
+  /**
    * Returns mapping of the SystemStreamPartition to the AsyncTaskWorkers to efficiently route the envelopes
    */
   private static Map<SystemStreamPartition, List<AsyncTaskWorker>> getSspToAsyncTaskWorkerMap(
@@ -297,7 +313,8 @@
    * when elasticity is enabled,
    *       sspToTaskWorkerMapping has workers for a SSP which has keyBucket
    *       hence need to use envelop.getSSP(elasticityFactor)
-   *       Additionally, when envelope is EnofStream or Watermark, it needs to be sent to all works for the ssp irrespective of keyBucket
+   *       Additionally, when envelope is EndOfStream, Watermark or Drain, it needs to be sent to all workers for the ssp
+   *       irrespective of keyBucket
    * @param envelope
    * @return list of workers for the envelope
    */
@@ -309,9 +326,14 @@
     final SystemStreamPartition sspOfEnvelope = envelope.getSystemStreamPartition(elasticityFactor);
     List<AsyncTaskWorker> listOfWorkersForEnvelope = null;
 
-    // if envelope is end of stream or watermark, it needs to be routed to all tasks consuming the ssp irresp of keybucket
+    // if envelope is end-of-stream, watermark or drain, it needs to be routed to all tasks consuming the ssp,
+    // irrespective of keyBucket
     MessageType messageType = MessageType.of(envelope.getMessage());
-    if (envelope.isEndOfStream() || MessageType.END_OF_STREAM == messageType || MessageType.WATERMARK == messageType) {
+    if (envelope.isEndOfStream()
+        || envelope.isDrain()
+        || messageType == MessageType.END_OF_STREAM
+        || messageType == MessageType.DRAIN
+        || messageType == MessageType.WATERMARK) {
 
       //sspToTaskWorkerMapping has SSPs with keyBucket, so extract and check only system, stream and partition, ignoring the keyBucket
       listOfWorkersForEnvelope = sspToTaskWorkerMapping.entrySet()
@@ -391,6 +413,20 @@
   }
 
   /**
+   * Resume the runloop thread. This method is triggered after a task has completed drain.
+   */
+  private void resumeAfterDrain() {
+    log.trace("Resume loop thread");
+    if (coordinatorRequests.shouldShutdownNow()) {
+      shutdownNow = true;
+    }
+    synchronized (latch) {
+      latch.notifyAll();
+      runLoopResumedSinceLastChecked = true;
+    }
+  }
+
+  /**
    * Set the throwable and abort run loop. The throwable will be thrown from the run loop thread
    * @param t throwable
    */
@@ -426,6 +462,7 @@
     COMMIT,
     PROCESS,
     END_OF_STREAM,
+    DRAIN,
     SCHEDULER,
     NO_OP
   }
@@ -443,7 +480,8 @@
       this.task = task;
       this.callbackManager = new TaskCallbackManager(this, callbackTimer, callbackTimeoutMs, maxConcurrency, clock);
       Set<SystemStreamPartition> sspSet = getWorkingSSPSet(task);
-      this.state = new AsyncTaskState(task.taskName(), task.metrics(), sspSet, !task.intermediateStreams().isEmpty());
+      this.state = new AsyncTaskState(task.taskName(), task.metrics(), sspSet, !task.intermediateStreams().isEmpty(),
+          runId);
     }
 
     private void init() {
@@ -514,18 +552,46 @@
         case END_OF_STREAM:
           endOfStream();
           break;
+        case DRAIN:
+          drain();
+          break;
         default:
           //no op
           break;
       }
     }
 
+    /**
+     * Called when a task has drained, i.e., all SSPs for the task have received a drain message.
+     */
+    private void drain() {
+      state.complete = true;
+      state.startDrain();
+      try {
+        ReadableCoordinator coordinator = new ReadableCoordinator(task.taskName());
+
+        task.drain(coordinator);
+
+        // issue a shutdown request for the task
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        coordinatorRequests.update(coordinator);
+
+        // issue a commit explicitly before we shutdown the task
+        // Adding commit to coordinator will not work as the state is marked complete and NO_OP will always be the
+        // next operation for this task
+        commit();
+      } finally {
+        resumeAfterDrain();
+      }
+    }
+
     private void endOfStream() {
       state.complete = true;
       try {
         ReadableCoordinator coordinator = new ReadableCoordinator(task.taskName());
 
         task.endOfStream(coordinator);
+
         // issue a request for shutdown of the task
         coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
         coordinatorRequests.update(coordinator);
@@ -538,7 +604,6 @@
       } finally {
         resume();
       }
-
     }
 
     /**
@@ -638,9 +703,12 @@
         }
       };
 
-      if (threadPool != null) {
+      if (threadPool != null && !isDraining) {
         log.trace("Task {} commits on the thread pool", task.taskName());
         threadPool.submit(commitWorker);
+      } else if (isDraining) {
+        log.trace("Task {} commits on the run loop thread as task is draining", task.taskName());
+        commitWorker.run();
       } else {
         log.trace("Task {} commits on the run loop thread", task.taskName());
         commitWorker.run();
@@ -759,24 +827,31 @@
     private volatile boolean needScheduler = false;
     private volatile boolean complete = false;
     private volatile boolean endOfStream = false;
+    private volatile boolean shouldDrain = false;
     private volatile boolean windowInFlight = false;
     private volatile boolean commitInFlight = false;
     private volatile boolean schedulerInFlight = false;
     private final AtomicInteger messagesInFlight = new AtomicInteger(0);
     private final ArrayDeque<PendingEnvelope> pendingEnvelopeQueue;
 
     //Set of SSPs that we are currently processing for this task instance
     private final Set<SystemStreamPartition> processingSspSet;
+    //Set of SSPs assigned to this task that have not yet received a drain message
+    private final Set<SystemStreamPartition> processingSspSetToDrain;
     private final TaskName taskName;
     private final TaskInstanceMetrics taskMetrics;
     private final boolean hasIntermediateStreams;
+    private final String runId;
 
-    AsyncTaskState(TaskName taskName, TaskInstanceMetrics taskMetrics, Set<SystemStreamPartition> sspSet, boolean hasIntermediateStreams) {
+    AsyncTaskState(TaskName taskName, TaskInstanceMetrics taskMetrics, Set<SystemStreamPartition> sspSet,
+        boolean hasIntermediateStreams, String runId) {
       this.taskName = taskName;
       this.taskMetrics = taskMetrics;
       this.pendingEnvelopeQueue = new ArrayDeque<>();
       this.processingSspSet = sspSet;
+      this.processingSspSetToDrain = new HashSet<>(sspSet);
       this.hasIntermediateStreams = hasIntermediateStreams;
+      this.runId = runId;
     }
 
     private boolean checkEndOfStream() {
@@ -796,6 +871,7 @@
                     && sspInSet.getPartition().equals(sspOfEnvelope.getPartition()))
                 .findFirst();
             ssp.ifPresent(processingSspSet::remove);
+            ssp.ifPresent(processingSspSetToDrain::remove);
           }
           if (!hasIntermediateStreams) {
             pendingEnvelopeQueue.remove();
@@ -805,6 +881,57 @@
       return processingSspSet.isEmpty();
     }
 
+    private boolean shouldDrain() {
+      if (endOfStream) {
+        return false;
+      }
+
+      if (!pendingEnvelopeQueue.isEmpty()) {
+        PendingEnvelope pendingEnvelope = pendingEnvelopeQueue.peek();
+        IncomingMessageEnvelope envelope = pendingEnvelope.envelope;
+
+        if (envelope.isDrain()) {
+          final DrainMessage message = (DrainMessage) envelope.getMessage();
+          if (!message.getRunId().equals(runId)) {
+            // Removing the drain message from the pending queue as it doesn't match with the current runId
+            // Removing it will ensure that it is not picked up by process()
+            pendingEnvelopeQueue.remove();
+          } else {
+            // set the RunLoop to drain mode
+            if (!isDraining) {
+              drain();
+            }
+
+            if (elasticityFactor <= 1) {
+              SystemStreamPartition ssp = envelope.getSystemStreamPartition();
+              processingSspSetToDrain.remove(ssp);
+            } else {
+              // SystemConsumers will write only one envelope (enclosing DrainMessage) per SSP in its buffer.
+              // This envelope doesn't have keyBucket info in its SSP. With elasticity, the same SSP can be processed by
+              // multiple tasks. Therefore, if envelope contains drain message, the ssp of envelope should be removed
+              // from task's processing set irrespective of keyBucket.
+              SystemStreamPartition sspOfEnvelope = envelope.getSystemStreamPartition();
+              Optional<SystemStreamPartition> ssp = processingSspSetToDrain.stream()
+                  .filter(sspInSet -> sspInSet.getSystemStream().equals(sspOfEnvelope.getSystemStream())
+                      && sspInSet.getPartition().equals(sspOfEnvelope.getPartition()))
+                  .findFirst();
+              ssp.ifPresent(processingSspSetToDrain::remove);
+            }
+
+            if (!hasIntermediateStreams) {
+              // Remove the drain message from the pending queue. When the task has intermediate streams, the
+              // message is instead left in the queue so the DAG picks it up and propagates it downstream.
+              pendingEnvelopeQueue.remove();
+            }
+          }
+        }
+        return processingSspSetToDrain.isEmpty();
+      }
+      // if no messages are in the queue, the task has probably already drained or there are no messages from
+      // the chooser
+      return false;
+    }
+
     /**
      * Returns whether the task is ready to do process/window/commit.
      *
@@ -813,6 +940,11 @@
       if (checkEndOfStream()) {
         endOfStream = true;
       }
+
+      if (shouldDrain()) {
+        shouldDrain = true;
+      }
+
       if (coordinatorRequests.commitRequests().remove(taskName)) {
         needCommit = true;
       }
@@ -826,9 +958,9 @@
        */
       if (needCommit) {
         return (messagesInFlight.get() == 0 || isAsyncCommitEnabled) && !opInFlight;
-      } else if (needWindow || needScheduler || endOfStream) {
+      } else if (needWindow || needScheduler || endOfStream || shouldDrain) {
         /*
-         * A task is ready for window, scheduler or end-of-stream operation.
+         * A task is ready for window, scheduler, drain or end-of-stream operation.
          */
         return messagesInFlight.get() == 0 && !opInFlight;
       } else {
@@ -847,14 +979,24 @@
      */
     private WorkerOp nextOp() {
 
-      if (complete) return WorkerOp.NO_OP;
+      if (complete) {
+        return WorkerOp.NO_OP;
+      }
 
       if (isReady()) {
-        if (needCommit) return WorkerOp.COMMIT;
-        else if (needWindow) return WorkerOp.WINDOW;
-        else if (needScheduler) return WorkerOp.SCHEDULER;
-        else if (endOfStream && pendingEnvelopeQueue.isEmpty()) return WorkerOp.END_OF_STREAM;
-        else if (!pendingEnvelopeQueue.isEmpty()) return WorkerOp.PROCESS;
+        if (needCommit) {
+          return WorkerOp.COMMIT;
+        } else if (needWindow) {
+          return WorkerOp.WINDOW;
+        } else if (needScheduler) {
+          return WorkerOp.SCHEDULER;
+        } else if (endOfStream && pendingEnvelopeQueue.isEmpty()) {
+          return WorkerOp.END_OF_STREAM;
+        } else if (shouldDrain && pendingEnvelopeQueue.isEmpty()) {
+          return WorkerOp.DRAIN;
+        } else if (!pendingEnvelopeQueue.isEmpty()) {
+          return WorkerOp.PROCESS;
+        }
       }
       return WorkerOp.NO_OP;
     }
@@ -876,6 +1018,10 @@
       windowInFlight = true;
     }
 
+    private void startDrain() {
+      shouldDrain = false;
+    }
+
     private void startCommit() {
       needCommit = false;
       commitInFlight = true;
@@ -918,6 +1064,7 @@
       int queueSize = pendingEnvelopeQueue.size();
       taskMetrics.pendingMessages().set(queueSize);
       log.trace("Insert envelope to task {} queue.", taskName);
+      log.trace("Insert envelope to task {} queue.", taskName);
       log.debug("Task {} pending envelope count is {} after insertion.", taskName, queueSize);
     }
 
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
index 9d20950..cb021d0 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
@@ -41,7 +41,8 @@
       SamzaContainerMetrics containerMetrics,
       TaskConfig taskConfig,
       HighResolutionClock clock,
-      int elasticityFactor) {
+      int elasticityFactor,
+      String runId) {
 
     long taskWindowMs = taskConfig.getWindowMs();
 
@@ -65,6 +66,8 @@
 
     log.info("Got elasticity factor: {}.", elasticityFactor);
 
+    log.info("Got current run Id: {}.", runId);
+
     log.info("Run loop in asynchronous mode.");
 
     return new RunLoop(
@@ -80,6 +83,7 @@
       containerMetrics,
       clock,
       isAsyncCommitEnabled,
-      elasticityFactor);
+      elasticityFactor,
+      runId);
   }
 }
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java b/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
index 551da88..af63494 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
@@ -99,6 +99,15 @@
   void endOfStream(ReadableCoordinator coordinator);
 
   /**
+   * Called when all {@link SystemStreamPartition}s processed by a task have drained. This is called only
+   * once per task. {@link RunLoop} will issue a shutdown request to the coordinator immediately following the
+   * invocation of this method.
+   *
+   * @param coordinator manages execution of tasks.
+   */
+  void drain(ReadableCoordinator coordinator);
+
+  /**
    * Indicates whether {@link #window} should be invoked on this task. If true, {@link RunLoop}
    * will schedule window to execute periodically according to its windowMs.
    *
diff --git a/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java b/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java
index 6b5c98e..b46fc3f 100644
--- a/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java
@@ -33,6 +33,7 @@
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
 import org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
 import org.apache.samza.metadatastore.MetadataStore;
 import org.slf4j.Logger;
@@ -65,7 +66,6 @@
     STOPPED
   }
 
-  private static final int POLLING_INTERVAL_MILLIS = 60_000;
   private static final int INITIAL_POLL_DELAY_MILLIS = 0;
 
   private final ScheduledExecutorService schedulerService =
@@ -75,7 +75,6 @@
               .setDaemon(true)
               .build());
   private final String appRunId;
-  private final long pollingIntervalMillis;
   private final NamespaceAwareCoordinatorStreamStore drainMetadataStore;
   // Used to guard write access to state.
   private final Object lock = new Object();
@@ -83,20 +82,23 @@
   @GuardedBy("lock")
   private State state = State.INIT;
   private DrainCallback callback;
+  private long pollingIntervalMillis;
 
   public DrainMonitor(MetadataStore metadataStore, Config config) {
-    this(metadataStore, config, POLLING_INTERVAL_MILLIS);
-  }
-
-  public DrainMonitor(MetadataStore metadataStore, Config config, long pollingIntervalMillis) {
     Preconditions.checkNotNull(metadataStore, "MetadataStore parameter cannot be null.");
     Preconditions.checkNotNull(config, "Config parameter cannot be null.");
-    Preconditions.checkArgument(pollingIntervalMillis > 0,
-        String.format("Polling interval specified is %d ms. It should be greater than 0.", pollingIntervalMillis));
     this.drainMetadataStore =
         new NamespaceAwareCoordinatorStreamStore(metadataStore, DrainUtils.DRAIN_METADATA_STORE_NAMESPACE);
     ApplicationConfig applicationConfig = new ApplicationConfig(config);
     this.appRunId = applicationConfig.getRunId();
+    JobConfig jobConfig = new JobConfig(config);
+    this.pollingIntervalMillis = jobConfig.getDrainMonitorPollIntervalMillis();
+  }
+
+  public DrainMonitor(MetadataStore metadataStore, Config config, long pollingIntervalMillis) {
+    this(metadataStore, config);
+    Preconditions.checkArgument(pollingIntervalMillis > 0,
+        String.format("Polling interval specified is %d ms. It should be greater than 0.", pollingIntervalMillis));
     this.pollingIntervalMillis = pollingIntervalMillis;
   }
 
@@ -207,12 +209,12 @@
   * One-time check to see if there are any DrainNotification messages available in the
   * metadata store for the current deployment.
   */
-  static boolean shouldDrain(NamespaceAwareCoordinatorStreamStore drainMetadataStore, String deploymentId) {
+  static boolean shouldDrain(NamespaceAwareCoordinatorStreamStore drainMetadataStore, String runId) {
     final Optional<List<DrainNotification>> drainNotifications = readDrainNotificationMessages(drainMetadataStore);
     if (drainNotifications.isPresent()) {
       final ImmutableList<DrainNotification> filteredDrainNotifications = drainNotifications.get()
           .stream()
-          .filter(notification -> deploymentId.equals(notification.getDeploymentId()))
+          .filter(notification -> runId.equals(notification.getRunId()))
           .collect(ImmutableList.toImmutableList());
       return !filteredDrainNotifications.isEmpty();
     }
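With the constructor split above, the config-driven path picks up job.drain-monitor.poll.interval.ms while the explicit-interval constructor stays available for tests. A fragment, assuming metadataStore and config are already wired up by the container:

    // Poll interval comes from job.drain-monitor.poll.interval.ms (default 60s).
    DrainMonitor monitor = new DrainMonitor(metadataStore, config);

    // Explicit override, e.g. for tests; the precondition rejects intervals <= 0.
    DrainMonitor fastMonitor = new DrainMonitor(metadataStore, config, 1_000L);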
diff --git a/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java b/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java
index a16595e..dd97cd7 100644
--- a/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java
@@ -33,26 +33,26 @@
   /**
   * Unique identifier for a run so that drain notification messages can be invalidated across job restarts.
    */
-  private final String deploymentId;
+  private final String runId;
 
-  public DrainNotification(UUID uuid, String deploymentId) {
+  public DrainNotification(UUID uuid, String runId) {
     this.uuid = uuid;
-    this.deploymentId = deploymentId;
+    this.runId = runId;
   }
 
   public UUID getUuid() {
     return this.uuid;
   }
 
-  public String getDeploymentId() {
-    return deploymentId;
+  public String getRunId() {
+    return runId;
   }
 
   @Override
   public String toString() {
     final StringBuilder sb = new StringBuilder("DrainNotification{");
     sb.append(" UUID: ").append(uuid);
-    sb.append(", deploymentId: '").append(deploymentId).append('\'');
+    sb.append(", runId: '").append(runId).append('\'');
     sb.append('}');
     return sb.toString();
   }
@@ -67,11 +67,11 @@
     }
     DrainNotification that = (DrainNotification) o;
     return Objects.equal(uuid, that.uuid)
-        && Objects.equal(deploymentId, that.deploymentId);
+        && Objects.equal(runId, that.runId);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(uuid, deploymentId);
+    return Objects.hashCode(uuid, runId);
   }
 }
diff --git a/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java b/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java
index 48e6db1..72cdfbb 100644
--- a/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java
@@ -67,7 +67,7 @@
         throws IOException {
       Map<String, Object> drainMessageMap = new HashMap<>();
       drainMessageMap.put("uuid", value.getUuid().toString());
-      drainMessageMap.put("deploymentId", value.getDeploymentId());
+      drainMessageMap.put("runId", value.getRunId());
       jsonGenerator.writeObject(drainMessageMap);
     }
   }
@@ -79,8 +79,8 @@
       ObjectCodec oc = jsonParser.getCodec();
       JsonNode node = oc.readTree(jsonParser);
       UUID uuid = UUID.fromString(node.get("uuid").textValue());
-      String deploymentId = node.get("deploymentId").textValue();
-      return new DrainNotification(uuid, deploymentId);
+      String runId = node.get("runId").textValue();
+      return new DrainNotification(uuid, runId);
     }
   }
 }
diff --git a/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java b/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java
index c100a47..f712381 100644
--- a/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java
@@ -47,23 +47,23 @@
    * Writes a {@link DrainNotification} to the underlying metastore. This method should be used by external controllers
    * to issue a DrainNotification to the JobCoordinator and Samza Containers.
    * @param metadataStore Metadata store to write drain notification to.
-   * @param deploymentId deploymentId for the DrainNotification
+   * @param runId runId for the DrainNotification
    *
    * @return generated uuid for the DrainNotification
    */
-  public static UUID writeDrainNotification(MetadataStore metadataStore, String deploymentId) {
+  public static UUID writeDrainNotification(MetadataStore metadataStore, String runId) {
     Preconditions.checkArgument(metadataStore != null, "MetadataStore cannot be null.");
-    Preconditions.checkArgument(!Strings.isNullOrEmpty(deploymentId), "deploymentId should be non-null.");
-    LOG.info("Attempting to write DrainNotification to metadata-store for the deployment ID {}", deploymentId);
+    Preconditions.checkArgument(!Strings.isNullOrEmpty(runId), "runId should be non-null and non-empty.");
+    LOG.info("Attempting to write DrainNotification to metadata-store for the run id {}", runId);
     final NamespaceAwareCoordinatorStreamStore drainMetadataStore =
         new NamespaceAwareCoordinatorStreamStore(metadataStore, DRAIN_METADATA_STORE_NAMESPACE);
     final ObjectMapper objectMapper = DrainNotificationObjectMapper.getObjectMapper();
     final UUID uuid = UUID.randomUUID();
-    final DrainNotification message = new DrainNotification(uuid, deploymentId);
+    final DrainNotification message = new DrainNotification(uuid, runId);
     try {
       drainMetadataStore.put(message.getUuid().toString(), objectMapper.writeValueAsBytes(message));
       drainMetadataStore.flush();
-      LOG.info("DrainNotification with id {} written to metadata-store for the deployment ID {}", uuid, deploymentId);
+      LOG.info("DrainNotification with id {} written to metadata-store for the deployment ID {}", uuid, runId);
     } catch (Exception ex) {
       throw new SamzaException(
           String.format("DrainNotification might have been not written to metastore %s", message), ex);
@@ -73,23 +73,23 @@
 
   /**
    * Cleans up DrainNotifications for the current deployment from the underlying metadata store.
-   * The current deploymentId is extracted from the config.
+   * The current runId is extracted from the config.
    *
    * @param metadataStore underlying metadata store
-   * @param config Config for the job. Used to extract the deploymentId of the job.
+   * @param config Config for the job. Used to extract the runId of the job.
   */
   public static void cleanup(MetadataStore metadataStore, Config config) {
     Preconditions.checkArgument(metadataStore != null, "MetadataStore cannot be null.");
     Preconditions.checkNotNull(config, "Config parameter cannot be null.");
 
     final ApplicationConfig applicationConfig = new ApplicationConfig(config);
-    final String deploymentId = applicationConfig.getRunId();
+    final String runId = applicationConfig.getRunId();
     final ObjectMapper objectMapper = DrainNotificationObjectMapper.getObjectMapper();
     final NamespaceAwareCoordinatorStreamStore drainMetadataStore =
         new NamespaceAwareCoordinatorStreamStore(metadataStore, DRAIN_METADATA_STORE_NAMESPACE);
 
-    if (DrainMonitor.shouldDrain(drainMetadataStore, deploymentId)) {
-      LOG.info("Attempting to clean up DrainNotifications from the metadata-store for the current deployment {}", deploymentId);
+    if (DrainMonitor.shouldDrain(drainMetadataStore, runId)) {
+      LOG.info("Attempting to clean up DrainNotifications from the metadata-store for the current deployment {}", runId);
       drainMetadataStore.all()
           .values()
           .stream()
@@ -101,19 +101,19 @@
               throw new SamzaException(e);
             }
           })
-          .filter(notification -> deploymentId.equals(notification.getDeploymentId()))
+          .filter(notification -> runId.equals(notification.getRunId()))
           .forEach(notification -> drainMetadataStore.delete(notification.getUuid().toString()));
 
       drainMetadataStore.flush();
-      LOG.info("Successfully cleaned up DrainNotifications from the metadata-store for the current deployment {}", deploymentId);
+      LOG.info("Successfully cleaned up DrainNotifications from the metadata-store for the current deployment {}", runId);
     } else {
       LOG.info("No DrainNotification found in the metadata-store for the current deployment {}. No need to cleanup.",
-          deploymentId);
+          runId);
     }
   }
 
   /**
-   * Cleans up all DrainNotifications irrespective of the deploymentId.
+   * Cleans up all DrainNotifications irrespective of the runId.
   */
   public static void cleanupAll(MetadataStore metadataStore) {
     Preconditions.checkArgument(metadataStore != null, "MetadataStore cannot be null.");
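End to end, an external controller would pair the write and cleanup calls roughly as below; metadataStore and config are assumed to be available, and the imports (ApplicationConfig, DrainUtils, UUID) are elided.

    // Signal a drain for the currently running deployment.
    String runId = new ApplicationConfig(config).getRunId();
    UUID notificationId = DrainUtils.writeDrainNotification(metadataStore, runId);

    // After the job has drained and stopped, remove the notifications for that run.
    DrainUtils.cleanup(metadataStore, config);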
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java
index 4f93f2c..7dd5d19 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java
@@ -21,10 +21,12 @@
 
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
+import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.context.Context;
 import org.apache.samza.operators.spec.BroadcastOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.system.ControlMessage;
+import org.apache.samza.system.DrainMessage;
 import org.apache.samza.system.EndOfStreamMessage;
 import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.SystemStream;
@@ -40,11 +42,13 @@
   private final BroadcastOperatorSpec<M> broadcastOpSpec;
   private final SystemStream systemStream;
   private final String taskName;
+  private final String runId;
 
   BroadcastOperatorImpl(BroadcastOperatorSpec<M> broadcastOpSpec, SystemStream systemStream, Context context) {
     this.broadcastOpSpec = broadcastOpSpec;
     this.systemStream = systemStream;
     this.taskName = context.getTaskContext().getTaskModel().getTaskName().getTaskName();
+    this.runId = new ApplicationConfig(context.getJobContext().getConfig()).getRunId();
   }
 
   @Override
@@ -74,6 +78,12 @@
   }
 
   @Override
+  protected Collection<Void> handleDrain(MessageCollector collector, TaskCoordinator coordinator) {
+    sendControlMessage(new DrainMessage(taskName, runId), collector);
+    return Collections.emptyList();
+  }
+
+  @Override
   protected Collection<Void> handleWatermark(long watermark, MessageCollector collector, TaskCoordinator coordinator) {
     sendControlMessage(new WatermarkMessage(watermark, taskName), collector);
     return Collections.emptyList();
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/DrainStates.java b/samza-core/src/main/java/org/apache/samza/operators/impl/DrainStates.java
new file mode 100644
index 0000000..30a341f
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/DrainStates.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.operators.impl;
+
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.samza.system.DrainMessage;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
+
+
+/**
+ * This class tracks the drain state for streams in a task. Internally it keeps track of Drain messages received
+ * from upstream tasks for each system stream partition (ssp). Once messages have been received from all upstream
+ * tasks, it marks the ssp as drained. For a stream to be drained, all its partitions assigned to
+ * the task need to be drained.
+ *
+ * This class is thread-safe.
+ */
+public class DrainStates {
+  private static final class DrainState {
+    // set of upstream tasks
+    private final Set<String> tasks = new HashSet<>();
+    private final int expectedTotal;
+    private volatile boolean drained = false;
+
+    DrainState(int expectedTotal) {
+      this.expectedTotal = expectedTotal;
+    }
+
+    synchronized void update(String taskName) {
+      if (taskName != null) {
+        // aggregate the drain messages from upstream tasks
+        tasks.add(taskName);
+        drained = tasks.size() == expectedTotal;
+      } else {
+        // the drain message is coming from either a source or an aggregator task
+        drained = true;
+      }
+    }
+
+    boolean isDrained() {
+      return drained;
+    }
+
+    @Override
+    public String toString() {
+      return "DrainState: [Tasks : "
+          + tasks
+          + ", isDrained : "
+          + drained
+          + "]";
+    }
+  }
+
+  private final Map<SystemStreamPartition, DrainState> drainStates;
+
+  /**
+   * Constructing the drain states for a task.
+   * @param ssps all the ssps assigned to this task
+   * @param producerTaskCounts mapping from a stream to the number of upstream tasks that produce to it
+   */
+  DrainStates(Set<SystemStreamPartition> ssps, Map<SystemStream, Integer> producerTaskCounts) {
+    this.drainStates = ssps.stream()
+        .collect(Collectors.toMap(
+          ssp -> ssp,
+          ssp -> new DrainState(producerTaskCounts.getOrDefault(ssp.getSystemStream(), 0))));
+  }
+
+  /**
+   * Update the state upon receiving a drain message.
+   * @param message the {@link DrainMessage} received
+   * @param ssp system stream partition
+   */
+  void update(DrainMessage message, SystemStreamPartition ssp) {
+    DrainState state = drainStates.get(ssp);
+    state.update(message.getTaskName());
+  }
+
+  /**
+   * Checks if the system-stream is drained.
+   */
+  boolean isDrained(SystemStream systemStream) {
+    return drainStates.entrySet().stream()
+        .filter(entry -> entry.getKey().getSystemStream().equals(systemStream))
+        .allMatch(entry -> entry.getValue().isDrained());
+  }
+
+  /**
+   * Checks if all streams (input SSPs) for the task have drained.
+   */
+  boolean areAllStreamsDrained() {
+    return drainStates.values().stream().allMatch(DrainState::isDrained);
+  }
+
+  @Override
+  public String toString() {
+    return drainStates.toString();
+  }
+}
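To illustrate the aggregation: with one intermediate SSP produced to by two upstream tasks, the SSP only counts as drained after a DrainMessage from each producer. DrainStates and its methods are package-private, so this fragment assumes code in the same package; all names are illustrative.

    SystemStreamPartition ssp = new SystemStreamPartition("kafka", "intermediate", new Partition(0));
    DrainStates states = new DrainStates(
        Collections.singleton(ssp),
        Collections.singletonMap(ssp.getSystemStream(), 2));

    states.update(new DrainMessage("upstream-task-0", "run-1"), ssp);
    // still waiting on the second producer
    System.out.println(states.areAllStreamsDrained()); // false

    states.update(new DrainMessage("upstream-task-1", "run-1"), ssp);
    System.out.println(states.areAllStreamsDrained()); // true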
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
index c404475..5723d91 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
@@ -39,6 +39,7 @@
 import org.apache.samza.operators.functions.WatermarkFunction;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.scheduler.CallbackScheduler;
+import org.apache.samza.system.DrainMessage;
 import org.apache.samza.system.EndOfStreamMessage;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamPartition;
@@ -86,6 +87,8 @@
   private TaskModel taskModel;
   // end-of-stream states
   private EndOfStreamStates eosStates;
+  // drain states
+  private DrainStates drainStates;
   // watermark states
   private WatermarkStates watermarkStates;
   private CallbackScheduler callbackScheduler;
@@ -125,6 +128,7 @@
     this.taskName = taskContext.getTaskModel().getTaskName();
     this.eosStates = (EndOfStreamStates) internalTaskContext.fetchObject(EndOfStreamStates.class.getName());
     this.watermarkStates = (WatermarkStates) internalTaskContext.fetchObject(WatermarkStates.class.getName());
+    this.drainStates = (DrainStates) internalTaskContext.fetchObject(DrainStates.class.getName());
     this.controlMessageSender = new ControlMessageSender(internalTaskContext.getStreamMetadataCache());
     this.taskModel = taskContext.getTaskModel();
     this.callbackScheduler = taskContext.getCallbackScheduler();
@@ -362,6 +366,88 @@
   }
 
   /**
+   * This method is invoked when all input streams to this operator have encountered the drain control message.
+   * Subclasses should handle drain by overriding this function.
+   * The default no-op implementation is sufficient for in-memory operators; output operators need to
+   * override this to actually propagate drain over the wire.
+   * @param collector message collector
+   * @param coordinator task coordinator
+   * @return results to be emitted when this operator drains
+   */
+  protected Collection<RM> handleDrain(MessageCollector collector, TaskCoordinator coordinator) {
+    return Collections.emptyList();
+  }
+
+  /**
+   * Aggregate {@link DrainMessage} from each ssp of the stream.
+   * Invoke {@link #onDrainOfStream(MessageCollector, TaskCoordinator)} once the stream has fully drained.
+   * @param drainMessage {@link DrainMessage} object
+   * @param ssp system stream partition
+   * @param collector message collector
+   * @param coordinator task coordinator
+   */
+  public final CompletionStage<Void> aggregateDrainMessages(DrainMessage drainMessage, SystemStreamPartition ssp,
+      MessageCollector collector, TaskCoordinator coordinator) {
+    LOG.info("Received drain message from task {} in {}", drainMessage.getTaskName(), ssp);
+    drainStates.update(drainMessage, ssp);
+
+    SystemStream stream = ssp.getSystemStream();
+    CompletionStage<Void> drainFuture = CompletableFuture.completedFuture(null);
+
+    if (drainStates.isDrained(stream)) {
+      LOG.info("Input {} reaches the end for task {}", stream.toString(), taskName.getTaskName());
+      if (drainMessage.getTaskName() != null && shouldTaskBroadcastToOtherPartitions(ssp)) {
+        // This is the aggregation task, which has already received all the drain messages from upstream.
+        // Broadcast the drain message to all the peer partitions. Additionally, if elasticity is enabled,
+        // only one of the elastic tasks of the ssp will broadcast.
+        controlMessageSender.broadcastToOtherPartitions(new DrainMessage(drainMessage.getRunId()), ssp, collector);
+      }
+
+      drainFuture = onDrainOfStream(collector, coordinator)
+          .thenAccept(result -> {
+            if (drainStates.areAllStreamsDrained()) {
+              // all input streams have been drained, shut down the task
+              LOG.info("All input streams have been drained for task {}", taskName.getTaskName());
+              coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+            }
+          });
+    }
+
+    return drainFuture;
+  }
+
+  /**
+   * Invoke {@link #handleDrain(MessageCollector, TaskCoordinator)} if all the input streams to the current operator
+   * have encountered drain message.
+   * Propagate the drain to downstream operators.
+   * @param collector message collector
+   * @param coordinator task coordinator
+   */
+  private CompletionStage<Void> onDrainOfStream(MessageCollector collector, TaskCoordinator coordinator) {
+    CompletionStage<Void> drainFuture = CompletableFuture.completedFuture(null);
+
+    if (inputStreams.stream().allMatch(input -> drainStates.isDrained(input))) {
+      Collection<RM> results = handleDrain(collector, coordinator);
+
+      CompletionStage<Void> resultFuture = CompletableFuture.allOf(
+          results.stream()
+              .flatMap(r -> this.registeredOperators.stream()
+                  .map(op -> op.onMessageAsync(r, collector, coordinator)))
+              .toArray(CompletableFuture[]::new));
+
+      // propagate DrainMessage to downstream operators
+      drainFuture = resultFuture.thenCompose(x ->
+          CompletableFuture.allOf(this.registeredOperators.stream()
+              .map(op -> op.onDrainOfStream(collector, coordinator))
+              .toArray(CompletableFuture[]::new)));
+    }
+
+    return drainFuture;
+  }
+
+  /**
    * Aggregate the {@link WatermarkMessage} from each ssp into a watermark. Then call onWatermark() if
   * a new watermark exists.
    * @param watermarkMessage a {@link WatermarkMessage} object
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
index 1aa67d8..c62b0b2 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
@@ -115,11 +115,16 @@
     // set states for end-of-stream; don't include side inputs (see SAMZA-2303)
     internalTaskContext.registerObject(EndOfStreamStates.class.getName(),
         new EndOfStreamStates(internalTaskContext.getSspsExcludingSideInputs(), producerTaskCounts));
+
     // set states for watermark; don't include side inputs (see SAMZA-2303)
     internalTaskContext.registerObject(WatermarkStates.class.getName(),
         new WatermarkStates(internalTaskContext.getSspsExcludingSideInputs(), producerTaskCounts,
             context.getContainerContext().getContainerMetricsRegistry()));
 
+    // set states for drain; don't include side inputs (see SAMZA-2303)
+    internalTaskContext.registerObject(DrainStates.class.getName(),
+        new DrainStates(internalTaskContext.getSspsExcludingSideInputs(), producerTaskCounts));
+
     specGraph.getInputOperators().forEach((streamId, inputOpSpec) -> {
       SystemStream systemStream = streamConfig.streamIdToSystemStream(streamId);
       InputOperatorImpl inputOperatorImpl =
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java
index 47ad4f6..643abdf 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java
@@ -20,12 +20,14 @@
 
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
+import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.context.Context;
 import org.apache.samza.context.InternalTaskContext;
 import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.PartitionByOperatorSpec;
 import org.apache.samza.system.ControlMessage;
+import org.apache.samza.system.DrainMessage;
 import org.apache.samza.system.EndOfStreamMessage;
 import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.StreamMetadataCache;
@@ -48,6 +50,7 @@
   private final MapFunction<? super M, ? extends K> keyFunction;
   private final MapFunction<? super M, ? extends V> valueFunction;
   private final String taskName;
+  private final String runId;
   private final ControlMessageSender controlMessageSender;
 
   PartitionByOperatorImpl(PartitionByOperatorSpec<M, K, V> partitionByOpSpec,
@@ -57,6 +60,7 @@
     this.keyFunction = partitionByOpSpec.getKeyFunction();
     this.valueFunction = partitionByOpSpec.getValueFunction();
     this.taskName = internalTaskContext.getContext().getTaskContext().getTaskModel().getTaskName().getTaskName();
+    this.runId = new ApplicationConfig(internalTaskContext.getContext().getJobContext().getConfig()).getRunId();
     StreamMetadataCache streamMetadataCache = internalTaskContext.getStreamMetadataCache();
     this.controlMessageSender = new ControlMessageSender(streamMetadataCache);
   }
@@ -95,6 +99,12 @@
   }
 
   @Override
+  protected Collection<Void> handleDrain(MessageCollector collector, TaskCoordinator coordinator) {
+    sendControlMessage(new DrainMessage(taskName, runId), collector);
+    return Collections.emptyList();
+  }
+
+  @Override
   protected Collection<Void> handleWatermark(long watermark, MessageCollector collector, TaskCoordinator coordinator) {
     sendControlMessage(new WatermarkMessage(watermark, taskName), collector);
     return Collections.emptyList();
diff --git a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
index 654d9a0..cda04377 100644
--- a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
+++ b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
@@ -48,6 +48,7 @@
 import org.apache.samza.coordinator.JobCoordinatorFactory;
 import org.apache.samza.coordinator.JobCoordinatorListener;
 import org.apache.samza.diagnostics.DiagnosticsManager;
+import org.apache.samza.drain.DrainMonitor;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.metadatastore.MetadataStore;
 import org.apache.samza.metrics.MetricsRegistry;
@@ -402,12 +403,18 @@
      */
     MetricsRegistryMap metricsRegistryMap = new MetricsRegistryMap();
 
+    DrainMonitor drainMonitor = null;
+    JobConfig jobConfig = new JobConfig(config);
+    if (metadataStore != null && jobConfig.getDrainMonitorEnabled()) {
+      drainMonitor = new DrainMonitor(metadataStore, config);
+    }
+
     return SamzaContainer.apply(processorId, jobModel, ScalaJavaUtil.toScalaMap(this.customMetricsReporter),
         metricsRegistryMap, this.taskFactory, JobContextImpl.fromConfigWithDefaults(this.config, jobModel),
         Option.apply(this.applicationDefinedContainerContextFactoryOptional.orElse(null)),
         Option.apply(this.applicationDefinedTaskContextFactoryOptional.orElse(null)),
         Option.apply(this.externalContextOptional.orElse(null)), null, startpointManager,
-        Option.apply(diagnosticsManager.orElse(null)), null);
+        Option.apply(diagnosticsManager.orElse(null)), drainMonitor);
   }
 
   private static JobCoordinator createJobCoordinator(Config config, String processorId, MetricsRegistry metricsRegistry, MetadataStore metadataStore) {
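
For reference, the monitor wiring looks roughly like the sketch below. The constructor and the registerDrainCallback/start calls match their use in DrainMonitorTests further down; the wire helper itself is illustrative:

    import org.apache.samza.config.Config;
    import org.apache.samza.drain.DrainMonitor;
    import org.apache.samza.metadatastore.MetadataStore;

    final class DrainWiringSketch {
      // Illustrative helper: watch the metadata store for a DrainNotification matching the
      // current run and fire the supplied action once it appears.
      static DrainMonitor wire(MetadataStore store, Config config, Runnable onDrain) {
        DrainMonitor monitor = new DrainMonitor(store, config);
        monitor.registerDrainCallback(() -> onDrain.run());
        monitor.start();
        return monitor;
      }
    }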
diff --git a/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java b/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java
index f499eb3..cd153b5 100644
--- a/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java
+++ b/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java
@@ -143,7 +143,8 @@
       MetricsRegistryMap metricsRegistryMap = new MetricsRegistryMap();
 
       DrainMonitor drainMonitor = null;
-      if (new JobConfig(config).getDrainMonitorEnabled()) {
+      JobConfig jobConfig = new JobConfig(config);
+      if (jobConfig.getDrainMonitorEnabled()) {
         drainMonitor = new DrainMonitor(coordinatorStreamStore, config);
       }
 
diff --git a/samza-core/src/main/java/org/apache/samza/serializers/IntermediateMessageSerde.java b/samza-core/src/main/java/org/apache/samza/serializers/IntermediateMessageSerde.java
index 83a0a35..a7afecd 100644
--- a/samza-core/src/main/java/org/apache/samza/serializers/IntermediateMessageSerde.java
+++ b/samza-core/src/main/java/org/apache/samza/serializers/IntermediateMessageSerde.java
@@ -22,6 +22,7 @@
 import java.util.Arrays;
 
 import org.apache.samza.SamzaException;
+import org.apache.samza.system.DrainMessage;
 import org.apache.samza.system.EndOfStreamMessage;
 import org.apache.samza.system.MessageType;
 import org.apache.samza.system.WatermarkMessage;
@@ -56,11 +57,13 @@
   private final Serde userMessageSerde;
   private final Serde<WatermarkMessage> watermarkSerde;
   private final Serde<EndOfStreamMessage> eosSerde;
+  private final Serde<DrainMessage> drainMessageSerde;
 
   public IntermediateMessageSerde(Serde userMessageSerde) {
     this.userMessageSerde = userMessageSerde;
     this.watermarkSerde = new JsonSerdeV2<>(WatermarkMessage.class);
     this.eosSerde = new JsonSerdeV2<>(EndOfStreamMessage.class);
+    this.drainMessageSerde = new JsonSerdeV2<>(DrainMessage.class);
   }
 
   @Override
@@ -93,6 +96,9 @@
         case END_OF_STREAM:
           object = eosSerde.fromBytes(data);
           break;
+        case DRAIN:
+          object = drainMessageSerde.fromBytes(data);
+          break;
         default:
           throw new UnsupportedOperationException(String.format("Message type %s is not supported", type.name()));
       }
@@ -118,6 +124,9 @@
       case END_OF_STREAM:
         data = eosSerde.toBytes((EndOfStreamMessage) object);
         break;
+      case DRAIN:
+        data = drainMessageSerde.toBytes((DrainMessage) object);
+        break;
       default:
         throw new SamzaException("Unknown message type: " + type.name());
     }
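
JsonSerdeV2 is Jackson-backed, so the DRAIN payload carried on the intermediate stream is plain JSON keyed by the @JsonProperty names on DrainMessage. A small standalone sketch of decoding such a payload (the JSON literal is illustrative, not a captured wire message):

    import com.fasterxml.jackson.databind.ObjectMapper;
    import java.util.Map;

    public final class DrainPayloadSketch {
      public static void main(String[] args) throws Exception {
        // Field names follow DrainMessage's @JsonProperty annotations ("task-name", "run-id").
        String json = "{\"task-name\":\"Partition 0\",\"run-id\":\"run-42\"}";
        Map<?, ?> decoded = new ObjectMapper().readValue(json, Map.class);
        System.out.println(decoded.get("run-id")); // run-42
      }
    }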
diff --git a/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java b/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java
index 6274c15..981e847 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java
@@ -97,6 +97,11 @@
   }
 
   @Override
+  public void drain(ReadableCoordinator coordinator) {
+    LOG.info("Task {} has drained", this.taskName);
+  }
+
+  @Override
   public boolean isWindowableTask() {
     return false;
   }
diff --git a/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java b/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
index 0079fab..0db960a 100644
--- a/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
+++ b/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
@@ -31,6 +31,7 @@
+import org.apache.samza.system.DrainMessage;
 import org.apache.samza.system.EndOfStreamMessage;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.MessageType;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.WatermarkMessage;
 import org.apache.samza.util.Clock;
@@ -129,6 +130,12 @@
                   inputOpImpl.aggregateEndOfStream(eosMessage, ime.getSystemStreamPartition(), collector, coordinator);
               break;
 
+            case DRAIN:
+              DrainMessage drainMessage = (DrainMessage) ime.getMessage();
+              processFuture =
+                  inputOpImpl.aggregateDrainMessages(drainMessage, ime.getSystemStreamPartition(), collector, coordinator);
+              break;
+
             case WATERMARK:
               WatermarkMessage watermarkMessage = (WatermarkMessage) ime.getMessage();
               processFuture = inputOpImpl.aggregateWatermark(watermarkMessage, ime.getSystemStreamPartition(), collector,
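
The DRAIN case mirrors the end-of-stream path: per-SSP DrainMessages are aggregated until every upstream producer has reported, at which point the input is considered drained. A toy model of that aggregation (the real bookkeeping lives in the input operator implementation; the names below are illustrative):

    import java.util.HashSet;
    import java.util.Set;

    final class DrainAggregationSketch {
      private final int expectedProducerTasks;
      private final Set<String> drainedTasks = new HashSet<>();

      DrainAggregationSketch(int expectedProducerTasks) {
        this.expectedProducerTasks = expectedProducerTasks;
      }

      // Returns true once a DrainMessage has been seen from every upstream producer task.
      boolean onDrainMessage(String producerTaskName) {
        drainedTasks.add(producerTaskName);
        return drainedTasks.size() == expectedProducerTasks;
      }
    }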
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index c157b87..f0aa20b 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -27,7 +27,6 @@
 import java.util.concurrent._
 import java.util.function.Consumer
 import java.util.{Base64, Optional}
-import com.google.common.annotations.VisibleForTesting
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 import org.apache.samza.SamzaException
 import org.apache.samza.checkpoint.{CheckpointListener, OffsetManager, OffsetManagerMetrics}
@@ -426,6 +425,8 @@
 
     val pollIntervalMs = taskConfig.getPollIntervalMs
 
+    val appConfig = new ApplicationConfig(config)
+
     val consumerMultiplexer = new SystemConsumers(
       chooser = chooser,
       consumers = consumers,
@@ -435,7 +436,8 @@
       dropDeserializationError = dropDeserializationError,
       pollIntervalMs = pollIntervalMs,
       clock = () => clock.nanoTime(),
-      elasticityFactor = jobConfig.getElasticityFactor)
+      elasticityFactor = jobConfig.getElasticityFactor,
+      runId = appConfig.getRunId)
 
     val producerMultiplexer = new SystemProducers(
       producers = producers,
@@ -622,7 +624,7 @@
 
     val maxThrottlingDelayMs = config.getLong("container.disk.quota.delay.max.ms", TimeUnit.SECONDS.toMillis(1))
 
-    val runLoop = RunLoopFactory.createRunLoop(
+    val runLoop: Runnable = RunLoopFactory.createRunLoop(
       taskInstances,
       consumerMultiplexer,
       taskThreadPool,
@@ -630,7 +632,8 @@
       samzaContainerMetrics,
       taskConfig,
       clock,
-      jobConfig.getElasticityFactor)
+      jobConfig.getElasticityFactor,
+      appConfig.getRunId)
 
     val containerMemoryMb : Int = new ClusterManagerConfig(config).getContainerMemoryMb
 
@@ -747,6 +750,7 @@
   def getStatus(): SamzaContainerStatus = status
 
   def drain() {
+    // put the SystemConsumers multiplexer into drain mode; it stops the input consumers and
+    // injects Drain control messages that the run loop then delivers to every task
     consumerMultiplexer.drain
   }
 
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index d75d911..606f32a 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -72,6 +72,7 @@
 
   val taskName: TaskName = taskModel.getTaskName
   val isInitableTask = task.isInstanceOf[InitableTask]
+  val isDrainTask = task.isInstanceOf[DrainListenerTask]
   val isEndOfStreamListenerTask = task.isInstanceOf[EndOfStreamListenerTask]
   val isClosableTask = task.isInstanceOf[ClosableTask]
 
@@ -111,6 +112,12 @@
   val streamConfig: StreamConfig = new StreamConfig(config)
   override val intermediateStreams: java.util.Set[String] = JavaConverters.setAsJavaSetConverter(streamConfig.getStreamIds.filter(streamConfig.getIsIntermediateStream)).asJava
 
+  val intermediateSSPs: Set[SystemStreamPartition] = systemStreamPartitions.filter(ssp => {
+    val systemStream = ssp.getSystemStream
+    val streamId = streamConfig.systemStreamToStreamId(systemStream)
+    intermediateStreams.contains(streamId)
+  }).toSet
+
   val streamsToDeleteCommittedMessages: Set[String] = streamConfig.getStreamIds.filter(streamConfig.getDeleteCommittedMessages).map(streamConfig.getPhysicalName).toSet
 
   val checkpointWriteVersions = new TaskConfig(config).getCheckpointWriteVersions
@@ -189,20 +196,27 @@
   }
 
   /**
-    * Computes the starting offset for the partitions assigned to the task and registers them with the underlying {@see SystemConsumers}.
-    *
-    * Starting offset for a partition of the task is computed in the following manner:
-    *
-    * 1. If a startpoint exists for a task, system stream partition and it resolves to a offset, then the resolved offset is used as the starting offset.
-    * 2. Else, the checkpointed offset for the system stream partition is used as the starting offset.
+   * This method registers the following with the underlying {@see SystemConsumers}:
+   * a) starting offsets for all SSPs assigned to the task
+   * b) intermediate SSPs assigned to the task
+   *
+   * Starting offset for a partition of the task is computed in the following manner:
+   *
+   * 1. If a startpoint exists for the task's system stream partition and it resolves to an offset, then the resolved offset is used as the starting offset.
+   * 2. Else, the checkpointed offset for the system stream partition is used as the starting offset.
     */
   def registerConsumers() {
     debug("Registering consumers for taskName: %s" format taskName)
+
     systemStreamPartitions.foreach(systemStreamPartition => {
       val startingOffset: String = getStartingOffset(systemStreamPartition)
       consumerMultiplexer.register(systemStreamPartition, startingOffset)
       metrics.addOffsetGauge(systemStreamPartition, () => offsetManager.getLastProcessedOffset(taskName, systemStreamPartition).orNull)
     })
+
+    intermediateSSPs.foreach(ssp => {
+      consumerMultiplexer.registerIntermediateSSP(ssp)
+    })
   }
 
   def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator,
@@ -241,6 +255,16 @@
     }
   }
 
+  def drain(coordinator: ReadableCoordinator): Unit = {
+    task match {
+      case drainListener: DrainListenerTask =>
+        exceptionHandler.maybeHandle {
+          drainListener.onDrain(collector, coordinator)
+        }
+      case _ =>
+    }
+  }
+
   def window(coordinator: ReadableCoordinator) {
     if (isWindowableTask) {
       trace("Windowing for taskName: %s" format taskName)
@@ -613,14 +637,19 @@
     // if elasticityFactor > 1, find the SSP with keyBucket
     var incomingMessageSsp = envelope.getSystemStreamPartition(elasticityFactor)
 
-    // if envelope is end of stream or watermark, it needs to be routed to all tasks consuming the ssp irresp of keyBucket
+    // if envelope is end of stream, watermark or drain, it needs to be routed
+    // to all tasks consuming the ssp irrespective of the keyBucket
     val messageType = MessageType.of(envelope.getMessage)
-    if (envelope.isEndOfStream() || MessageType.END_OF_STREAM == messageType || MessageType.WATERMARK == messageType) {
+    if (envelope.isEndOfStream()
+      || envelope.isDrain()
+      || messageType == MessageType.END_OF_STREAM
+      || messageType == MessageType.WATERMARK) {
+
       incomingMessageSsp = systemStreamPartitions
         .filter(ssp => ssp.getSystemStream.equals(incomingMessageSsp.getSystemStream)
           && ssp.getPartition.equals(incomingMessageSsp.getPartition))
         .toIterator.next()
-      debug("for watermark or end of stream envelope, found incoming ssp as {}" format incomingMessageSsp)
+      debug("for watermark or end-of-stream or drain envelope, found incoming ssp as {}".format(incomingMessageSsp))
     }
     incomingMessageSsp
   }
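
A low-level task opts into the new callback by implementing DrainListenerTask; the dispatch above pattern-matches on the trait and invokes onDrain under the task's exception handler. A minimal sketch, with the onDrain signature inferred from that dispatch and the interface's package assumed to be org.apache.samza.task:

    import org.apache.samza.system.IncomingMessageEnvelope;
    import org.apache.samza.task.DrainListenerTask; // assumed package
    import org.apache.samza.task.MessageCollector;
    import org.apache.samza.task.StreamTask;
    import org.apache.samza.task.TaskCoordinator;

    public class DrainAwareTask implements StreamTask, DrainListenerTask {
      @Override
      public void process(IncomingMessageEnvelope envelope, MessageCollector collector,
          TaskCoordinator coordinator) {
        // regular per-message processing
      }

      @Override
      public void onDrain(MessageCollector collector, TaskCoordinator coordinator) {
        // flush any buffered output or state; a final commit follows the drain callback
      }
    }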
diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
index f1c476a..ac1a558 100644
--- a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
+++ b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
@@ -89,7 +89,7 @@
     }
     var drainMonitor: DrainMonitor = null
     if (jobConfig.getDrainMonitorEnabled()) {
-      drainMonitor = new DrainMonitor(coordinatorStreamStore, config)
+      drainMonitor = new DrainMonitor(coordinatorStreamStore, config, jobConfig.getDrainMonitorPollIntervalMillis)
     }
 
     val containerId = "0"
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
index c2e02a4..40b2999 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
+++ b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
@@ -45,6 +45,7 @@
 import org.apache.samza.checkpoint.Checkpoint;
 import org.apache.samza.checkpoint.CheckpointManager;
 import org.apache.samza.checkpoint.CheckpointV2;
+import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.StorageConfig;
@@ -120,6 +121,8 @@
 
   private static final int RESTORE_THREAD_POOL_SHUTDOWN_TIMEOUT_SECONDS = 60;
 
+  private static final int DEFAULT_SIDE_INPUT_ELASTICITY_FACTOR = 1;
+
   /** Maps containing relevant per-task objects */
   private final Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics;
   private final Map<TaskName, TaskInstanceCollector> taskInstanceCollectors;
@@ -297,11 +300,13 @@
       MessageChooser chooser = DefaultChooser.apply(inputStreamMetadata, new RoundRobinChooserFactory(), config,
           sideInputSystemConsumersMetrics.registry(), systemAdmins);
 
+      ApplicationConfig applicationConfig = new ApplicationConfig(config);
+
       sideInputSystemConsumers =
           new SystemConsumers(chooser, ScalaJavaUtil.toScalaMap(sideInputConsumers), systemAdmins, serdeManager,
               sideInputSystemConsumersMetrics, SystemConsumers.DEFAULT_NO_NEW_MESSAGES_TIMEOUT(), SystemConsumers.DEFAULT_DROP_SERIALIZATION_ERROR(),
               TaskConfig.DEFAULT_POLL_INTERVAL_MS, ScalaJavaUtil.toScalaFunction(() -> System.nanoTime()),
-              JobConfig.DEFAULT_JOB_ELASTICITY_FACTOR);
+              JobConfig.DEFAULT_JOB_ELASTICITY_FACTOR, applicationConfig.getRunId());
     }
 
   }
@@ -922,6 +927,8 @@
         new SamzaContainerMetrics(SIDEINPUTS_METRICS_PREFIX + this.samzaContainerMetrics.source(),
             this.samzaContainerMetrics.registry(), SIDEINPUTS_METRICS_PREFIX);
 
+    ApplicationConfig applicationConfig = new ApplicationConfig(config);
+
     this.sideInputRunLoop = new RunLoop(sideInputTasks,
         null, // all operations are executed in the main runloop thread
         this.sideInputSystemConsumers,
@@ -934,7 +941,9 @@
         taskConfig.getMaxIdleMs(),
         sideInputContainerMetrics,
         System::nanoTime,
-        false); // commit must be synchronous to ensure integrity of state flush
+        false, // commit must be synchronous to ensure integrity of state flush
+        DEFAULT_SIDE_INPUT_ELASTICITY_FACTOR,
+        applicationConfig.getRunId());
 
     try {
       sideInputsExecutor.submit(() -> {
diff --git a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
index bcecebb..d1b639c 100644
--- a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
+++ b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
@@ -28,9 +28,8 @@
 import java.util.HashSet
 import java.util.Queue
 import java.util.Set
-import java.util.function.Predicate
+import java.util.function.Consumer
 import java.util.stream.Collectors
-
 import scala.collection.JavaConverters._
 import org.apache.samza.serializers.SerdeManager
 import org.apache.samza.util.{Logging, TimerUtil}
@@ -38,6 +37,7 @@
 import org.apache.samza.SamzaException
 import org.apache.samza.config.TaskConfig
 
+
 object SystemConsumers {
   val DEFAULT_NO_NEW_MESSAGES_TIMEOUT = 10
   val DEFAULT_DROP_SERIALIZATION_ERROR = false
@@ -116,7 +116,12 @@
    */
   val clock: () => Long = () => System.nanoTime(),
 
-  val elasticityFactor: Int = 1) extends Logging with TimerUtil {
+  val elasticityFactor: Int = 1,
+
+  /**
+   * Identifier of the current run of the app (ties to app.run.id); used to stamp Drain
+   * control messages so that messages from a previous run are ignored.
+   */
+  val runId: String = null) extends Logging with TimerUtil {
 
   /**
    * Mapping from the {@see SystemStreamPartition} to the registered offsets.
@@ -130,6 +135,10 @@
    */
   private val sspKeyBucketsRegistered = new HashSet[SystemStreamPartition] ()
 
+  private val intermediateSSPs = new HashSet[SystemStreamPartition]()
+
+  private val intermediateSystems = new HashSet[String]()
+
   /**
    * A buffer of incoming messages grouped by SystemStreamPartition. These
    * messages are handed out to the MessageChooser as it needs them.
@@ -155,10 +164,8 @@
     */
   private var started = false
 
-  /**
-   * Denotes if the SystemConsumers is in drain mode.
-   * */
-  private var draining = false
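+  /**
+   * Denotes if the SystemConsumers multiplexer is in drain mode. Volatile since drain
+   * is signalled from a thread other than the run loop thread.
+   */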
+  @volatile
+  private var isDraining = false
 
   /**
    * Default timeout to noNewMessagesTimeout. Every time SystemConsumers
@@ -193,8 +200,6 @@
       // but the actual systemConsumer which consumes from the input does not know about KeyBucket.
       // hence, use an SSP without KeyBucket
       consumer.register(removeKeyBucket(systemStreamPartition), offset)
-      chooser.register(removeKeyBucket(systemStreamPartition), offset)
-      debug("consumer.register and chooser.register for ssp: %s with offset %s" format (systemStreamPartition, offset))
     }
 
     debug("Starting consumers.")
@@ -214,15 +219,12 @@
 
     chooser.start
 
+
     started = true
 
     refresh
   }
 
-  def drain: Unit = {
-    draining = true
-  }
-
   def stop {
     if (started) {
       debug("Stopping consumers.")
@@ -237,6 +239,14 @@
     }
   }
 
+  def drain(): Unit = {
+    if (!isDraining) {
+      isDraining = true
+      info("SystemConsumers is set to drain mode.")
+      consumers.values.foreach(_.stop)
+      writeDrainControlMessageToSspQueue()
+    }
+  }
 
   def register(ssp: SystemStreamPartition, offset: String) {
     // If elasticity is enabled then the RunLoop gives SSP with keybucket
@@ -255,6 +265,8 @@
     metrics.registerSystemStreamPartition(systemStreamPartition)
     unprocessedMessagesBySSP.put(systemStreamPartition, new ArrayDeque[IncomingMessageEnvelope]())
 
+    chooser.register(systemStreamPartition, offset)
+
     try {
       val consumer = consumers(systemStreamPartition.getSystem)
       val existingOffset = sspToRegisteredOffsets.get(systemStreamPartition)
@@ -268,12 +280,17 @@
     }
   }
 
+  def registerIntermediateSSP(ssp: SystemStreamPartition): Unit = {
+    debug("Registering intermediate stream: %s" format ssp)
+    intermediateSSPs.add(ssp)
+    intermediateSystems.add(ssp.getSystem)
+  }
 
   def isEndOfStream(systemStreamPartition: SystemStreamPartition) = {
     endOfStreamSSPs.contains(removeKeyBucket(systemStreamPartition))
   }
 
-  def choose (updateChooser: Boolean = true): IncomingMessageEnvelope = {
+  def choose(updateChooser: Boolean = true): IncomingMessageEnvelope = {
     val envelopeFromChooser = chooser.choose
 
     updateTimer(metrics.deserializationNs) {
@@ -398,16 +415,54 @@
   }
 
   private def refresh {
-    if (draining) {
-      trace("Skipping refresh of chooser as the multiplexer is in drain mode.")
-      return
-    }
-    trace("Refreshing chooser with new messages.")
-
     // Update last poll time so we don't poll too frequently.
     lastPollNs = clock()
-    // Poll every system for new messages.
-    consumers.keys.map(poll(_))
+
+    if (isDraining) {
+      trace("Refreshing chooser with new messages from intermediate systems.")
+
+      // intermediateSystems is a java.util.HashSet; Scala 2.11 lacks SAM conversion, so an
+      // explicit java.util.function.Consumer is needed instead of intermediateSystems.forEach(poll(_))
+      intermediateSystems.forEach(new Consumer[String] {
+        override def accept(system: String): Unit = poll(system)
+      })
+    } else {
+      trace("Refreshing chooser with new messages.")
+      consumers.keys.foreach(poll(_))
+    }
+  }
+
+  private def writeDrainControlMessageToSspQueue(): Unit = {
+    val sspsToDrain = new HashSet(sspKeyBucketsRegistered)
+
+    // only write Drain ControlMessages to source SSPs
+    // sspsToDrain = allSSPs - intermediateSSPs - eosSSPs
+    sspsToDrain.removeAll(intermediateSSPs)
+    sspsToDrain.removeAll(endOfStreamSSPs)
+
+    sspsToDrain.forEach(new Consumer[SystemStreamPartition] {
+      override def accept(ssp: SystemStreamPartition): Unit = {
+        val envelopes: Queue[IncomingMessageEnvelope] =
+          if (unprocessedMessagesBySSP.containsKey(ssp)) {
+            unprocessedMessagesBySSP.get(ssp)
+          } else {
+            new util.ArrayDeque[IncomingMessageEnvelope]()
+          }
+
+        // Add a Watermark ControlMessage only if intermediate SSPs exist, as a low-level API
+        // task doesn't process WatermarkMessages
+        if (!intermediateSSPs.isEmpty) {
+          envelopes.add(IncomingMessageEnvelope.buildWatermarkEnvelope(ssp, Long.MaxValue))
+          totalUnprocessedMessages += 1
+        }
+        // Add Drain ControlMessage
+        envelopes.add(IncomingMessageEnvelope.buildDrainMessage(ssp, runId))
+        totalUnprocessedMessages += 1
+        unprocessedMessagesBySSP.put(ssp, envelopes)
+
+        // update the chooser with the messages
+        tryUpdate(ssp)
+      }
+    })
   }
 
   /**
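
The injection above is the heart of drain: consumption stops, and each source SSP's unprocessed-message queue is terminated with a max watermark (only when intermediate SSPs exist) followed by a drain marker, so the chooser hands every task a deterministic tail. A self-contained toy model of that queue termination (the types here are stand-ins, not Samza classes):

    import java.util.ArrayDeque;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.Queue;

    final class DrainInjectionSketch {
      static void inject(Map<String, Queue<String>> unprocessedBySsp,
          List<String> sourceSsps, boolean hasIntermediateSsps) {
        for (String ssp : sourceSsps) {
          Queue<String> queue = unprocessedBySsp.computeIfAbsent(ssp, k -> new ArrayDeque<>());
          if (hasIntermediateSsps) {
            queue.add("WATERMARK(Long.MAX_VALUE)"); // only the high-level API consumes watermarks
          }
          queue.add("DRAIN(runId)"); // every task sees a drain marker as its final envelope
        }
      }

      public static void main(String[] args) {
        Map<String, Queue<String>> queues = new HashMap<>();
        inject(queues, List.of("kafka.pageviews.0"), true);
        System.out.println(queues); // {kafka.pageviews.0=[WATERMARK(Long.MAX_VALUE), DRAIN(runId)]}
      }
    }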
diff --git a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
index 7bd5f9a..e584974 100644
--- a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
+++ b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
@@ -20,8 +20,9 @@
 package org.apache.samza.container;
 
 import com.google.common.collect.ImmutableMap;
-import java.util.Collections;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
@@ -49,6 +50,7 @@
 
 public class TestRunLoop {
   // Immutable objects shared by all test methods.
+  private final String runId = "foo";
   private final ExecutorService executor = null;
   private final SamzaContainerMetrics containerMetrics = new SamzaContainerMetrics("container", new MetricsRegistryMap(), "");
   private final long windowMs = -1;
@@ -60,13 +62,23 @@
   private final Partition p1 = new Partition(1);
   private final TaskName taskName0 = new TaskName(p0.toString());
   private final TaskName taskName1 = new TaskName(p1.toString());
-  private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
-  private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope11 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
-  private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
-  private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
+  private final SystemStreamPartition sspA0 = new SystemStreamPartition("testSystem", "testStreamA", p0);
+  private final SystemStreamPartition sspA1 = new SystemStreamPartition("testSystem", "testStreamA", p1);
+  private final SystemStreamPartition sspB0 = new SystemStreamPartition("testSystem", "testStreamB", p0);
+  private final SystemStreamPartition sspB1 = new SystemStreamPartition("testSystem", "testStreamB", p1);
+  private final IncomingMessageEnvelope envelopeA00 = new IncomingMessageEnvelope(sspA0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelopeA11 = new IncomingMessageEnvelope(sspA1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelopeA01 = new IncomingMessageEnvelope(sspA0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelopeB00 = new IncomingMessageEnvelope(sspB0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelopeB11 = new IncomingMessageEnvelope(sspB1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope sspA0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(sspA0);
+  private final IncomingMessageEnvelope sspA1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(sspA1);
+  private final IncomingMessageEnvelope sspB0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(sspB0);
+  private final IncomingMessageEnvelope sspB1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(sspB1);
+  private final IncomingMessageEnvelope sspA0Drain = IncomingMessageEnvelope.buildDrainMessage(sspA0, runId);
+  private final IncomingMessageEnvelope sspA1Drain = IncomingMessageEnvelope.buildDrainMessage(sspA1, runId);
+  private final IncomingMessageEnvelope sspB0Drain = IncomingMessageEnvelope.buildDrainMessage(sspB0, runId);
+  private final IncomingMessageEnvelope sspB1Drain = IncomingMessageEnvelope.buildDrainMessage(sspB1, runId);
 
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
@@ -75,8 +87,8 @@
   public void testProcessMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
 
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
-    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1);
 
     Map<TaskName, RunLoopTask> tasks = new HashMap<>();
     tasks.put(taskName0, task0);
@@ -84,12 +96,13 @@
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false, 1, runId);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(sspA0EndOfStream).thenReturn(
+        sspA1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    verify(task0).process(eq(envelope00), any(), any());
-    verify(task1).process(eq(envelope11), any(), any());
+    verify(task0).process(eq(envelopeA00), any(), any());
+    verify(task1).process(eq(envelopeA11), any(), any());
 
     assertEquals(4L, containerMetrics.envelopes().getCount());
   }
@@ -97,9 +110,9 @@
   @Test
   public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(sspA0EndOfStream).thenReturn(null);
 
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
 
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
@@ -108,8 +121,8 @@
     runLoop.run();
 
     InOrder inOrder = inOrder(task0);
-    inOrder.verify(task0).process(eq(envelope00), any(), any());
-    inOrder.verify(task0).process(eq(envelope01), any(), any());
+    inOrder.verify(task0).process(eq(envelopeA00), any(), any());
+    inOrder.verify(task0).process(eq(envelopeA01), any(), any());
   }
 
   @Test
@@ -119,7 +132,7 @@
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
     when(task0.offsetManager()).thenReturn(offsetManager);
     CountDownLatch firstMessageBarrier = new CountDownLatch(1);
     doAnswer(invocation -> {
@@ -134,7 +147,7 @@
         return null;
       });
       return null;
-    }).when(task0).process(eq(envelope00), any(), any());
+    }).when(task0).process(eq(envelopeA00), any(), any());
 
     doAnswer(invocation -> {
       assertEquals(1, task0.metrics().messagesInFlight().getValue());
@@ -145,21 +158,21 @@
       callback.complete();
       firstMessageBarrier.countDown();
       return null;
-    }).when(task0).process(eq(envelope01), any(), any());
+    }).when(task0).process(eq(envelopeA01), any(), any());
 
     Map<TaskName, RunLoopTask> tasks = new HashMap<>();
     tasks.put(taskName0, task0);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(null);
     runLoop.run();
 
     InOrder inOrder = inOrder(task0);
-    inOrder.verify(task0).process(eq(envelope00), any(), any());
-    inOrder.verify(task0).process(eq(envelope01), any(), any());
+    inOrder.verify(task0).process(eq(envelopeA00), any(), any());
+    inOrder.verify(task0).process(eq(envelopeA01), any(), any());
 
-    verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+    verify(offsetManager).update(eq(taskName0), eq(sspA0), eq(envelopeA00.getOffset()));
 
     assertEquals(2L, containerMetrics.processes().getCount());
   }
@@ -168,9 +181,9 @@
   public void testProcessElasticityEnabled() {
 
     TaskName taskName0 = new TaskName(p0.toString() + " 0");
-    SystemStreamPartition ssp = new SystemStreamPartition("testSystem", "testStream", p0);
-    SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0, 0);
-    SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p0, 1);
+    SystemStreamPartition ssp = new SystemStreamPartition("testSystem", "testStreamA", p0);
+    SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStreamA", p0, 0);
+    SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStreamA", p0, 1);
 
     // create two IME such that one of their ssp keybucket maps to ssp0 and the other one maps to ssp1
     // task in the runloop should process only the first ime (aka the one whose ssp keybucket is ssp0)
@@ -197,12 +210,12 @@
     }).when(task0).process(eq(envelope01), any(), any());
 
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(sspA0EndOfStream).thenReturn(null);
 
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2);
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2, null);
     runLoop.run();
 
     verify(task0).process(eq(envelope00), any(), any());
@@ -216,9 +229,9 @@
 
     TaskName taskName0 = new TaskName(p0.toString() + " 0");
     TaskName taskName1 = new TaskName(p0.toString() + " 1");
-    SystemStreamPartition ssp = new SystemStreamPartition("testSystem", "testStream", p0);
-    SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0, 0);
-    SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p0, 1);
+    SystemStreamPartition ssp = new SystemStreamPartition("testSystem", "testStreamA", p0);
+    SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStreamA", p0, 0);
+    SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStreamA", p0, 1);
 
     // create EOS IME such that its ssp keybucket maps to ssp0 and not to ssp1
     // task in the runloop should give this ime to both it tasks
@@ -237,7 +250,7 @@
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0, taskName1, task1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2);
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2, null);
     runLoop.run();
 
     verify(task0).endOfStream(any());
@@ -245,12 +258,104 @@
   }
 
   @Test
+  public void testDrainWithElasticityEnabled() {
+    TaskName taskName0 = new TaskName(p0.toString() + " 0");
+    TaskName taskName1 = new TaskName(p0.toString() + " 1");
+    SystemStreamPartition ssp = new SystemStreamPartition("testSystem", "testStreamA", p0);
+    SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStreamA", p0, 0);
+    SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStreamA", p0, 1);
+
+    // create a Drain IME such that its ssp keyBucket maps to ssp0 and not to ssp1;
+    // the run loop should still deliver it to both tasks
+    IncomingMessageEnvelope envelopeDrain = spy(IncomingMessageEnvelope.buildDrainMessage(ssp, runId));
+    when(envelopeDrain.getSystemStreamPartition(2)).thenReturn(ssp0);
+
+    // two tasks in the run loop: one processes ssp0 (the 0th keyBucket of ssp), the other ssp1 (the 1st)
+    // the Drain IME should reach both tasks irrespective of the keyBucket
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeDrain).thenReturn(null);
+
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0, taskName1, task1);
+    int maxMessagesInFlight = 1;
+    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2, runId);
+    runLoop.run();
+
+    verify(task0).drain(any());
+    verify(task1).drain(any());
+  }
+
+  @Test
+  public void testDrainForTasksWithSingleSSP() {
+    TaskName taskName0 = new TaskName(p0.toString() + " 0");
+    TaskName taskName1 = new TaskName(p1.toString() + " 1");
+
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1);
+
+    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
+    // insert all envelopes followed by drain messages
+    when(consumerMultiplexer.choose(false))
+        .thenReturn(envelopeA00).thenReturn(envelopeA11)
+        .thenReturn(sspA0Drain).thenReturn(sspA1Drain);
+
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0, taskName1, task1);
+    int maxMessagesInFlight = 1;
+    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 1, runId);
+    runLoop.run();
+
+    // check if process was called once for each task
+    verify(task0, times(1)).process(any(), any(), any());
+    verify(task1, times(1)).process(any(), any(), any());
+    // check if drain was called once for each task followed by commit
+    verify(task0, times(1)).drain(any());
+    verify(task1, times(1)).drain(any());
+    verify(task0, times(1)).commit();
+    verify(task1, times(1)).commit();
+  }
+
+  @Test
+  public void testDrainForTasksWithMultipleSSP() {
+    TaskName taskName0 = new TaskName(p0.toString() + " 0");
+    TaskName taskName1 = new TaskName(p1.toString() + " 1");
+
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0, sspB0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1, sspB1);
+
+    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
+    // insert all envelopes followed by drain messages
+    when(consumerMultiplexer.choose(false))
+        .thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(envelopeB00).thenReturn(envelopeB11)
+        .thenReturn(sspA0Drain).thenReturn(sspA1Drain).thenReturn(sspB0Drain).thenReturn(sspB1Drain);
+
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0, taskName1, task1);
+    int maxMessagesInFlight = 1;
+    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 1, runId);
+    runLoop.run();
+
+    // check if process was called twice for each task
+    verify(task0, times(2)).process(any(), any(), any());
+    verify(task1, times(2)).process(any(), any(), any());
+    // check if drain was called once for each task followed by commit
+    verify(task0, times(1)).drain(any());
+    verify(task1, times(1)).drain(any());
+    verify(task0, times(1)).commit();
+    verify(task1, times(1)).commit();
+  }
+
+  @Test
   public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
 
     int maxMessagesInFlight = 1;
     long windowMs = 1;
-    RunLoopTask task = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task = getMockRunLoopTask(taskName0, sspA0);
     when(task.isWindowableTask()).thenReturn(true);
 
     final AtomicInteger windowCount = new AtomicInteger(0);
@@ -277,7 +382,7 @@
   public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
 
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
     doAnswer(invocation -> {
       ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
       TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
@@ -288,9 +393,9 @@
 
       callback.complete();
       return null;
-    }).when(task0).process(eq(envelope00), any(), any());
+    }).when(task0).process(eq(envelopeA00), any(), any());
 
-    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1);
 
     Map<TaskName, RunLoopTask> tasks = new HashMap<>();
     tasks.put(this.taskName0, task0);
@@ -300,7 +405,7 @@
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(null);
 
     runLoop.run();
 
@@ -315,7 +420,7 @@
   public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
 
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
     doAnswer(invocation -> {
       ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
       TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
@@ -326,9 +431,9 @@
 
       callback.complete();
       return null;
-    }).when(task0).process(eq(envelope00), any(), any());
+    }).when(task0).process(eq(envelopeA00), any(), any());
 
-    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1);
 
     Map<TaskName, RunLoopTask> tasks = new HashMap<>();
     tasks.put(this.taskName0, task0);
@@ -338,7 +443,7 @@
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
         callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(null);
 
     runLoop.run();
 
@@ -354,7 +459,7 @@
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
 
     int maxMessagesInFlight = 1;
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
     doAnswer(invocation -> {
       ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
       TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
@@ -363,9 +468,9 @@
       coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
       callback.complete();
       return null;
-    }).when(task0).process(eq(envelope00), any(), any());
+    }).when(task0).process(eq(envelopeA00), any(), any());
 
-    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1);
     doAnswer(invocation -> {
       ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
       TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
@@ -374,7 +479,7 @@
       coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
       callback.complete();
       return null;
-    }).when(task1).process(eq(envelope11), any(), any());
+    }).when(task1).process(eq(envelopeA11), any(), any());
 
     Map<TaskName, RunLoopTask> tasks = new HashMap<>();
     tasks.put(taskName0, task0);
@@ -383,7 +488,7 @@
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
         callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(null);
     runLoop.run();
 
     verify(task0).process(any(), any(), any());
@@ -397,8 +502,8 @@
   public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
 
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
-    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0, sspB0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1, sspB1);
 
     Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
@@ -409,21 +514,27 @@
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
         callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope00)
-      .thenReturn(envelope11)
-      .thenReturn(ssp0EndOfStream)
-      .thenReturn(ssp1EndOfStream)
+      .thenReturn(envelopeA00)
+      .thenReturn(envelopeA11)
+      .thenReturn(envelopeB00)
+      .thenReturn(envelopeB11)
+      .thenReturn(sspA0EndOfStream)
+      .thenReturn(sspB0EndOfStream)
+      .thenReturn(sspB1EndOfStream)
+      .thenReturn(sspA1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).process(eq(envelopeA00), any(), any());
+    verify(task0).process(eq(envelopeB00), any(), any());
     verify(task0).endOfStream(any());
 
-    verify(task1).process(eq(envelope11), any(), any());
+    verify(task1).process(eq(envelopeA11), any(), any());
+    verify(task1).process(eq(envelopeB11), any(), any());
     verify(task1).endOfStream(any());
 
-    assertEquals(4L, containerMetrics.envelopes().getCount());
+    assertEquals(8L, containerMetrics.envelopes().getCount());
   }
 
   @Test
@@ -433,7 +544,7 @@
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
     when(task0.offsetManager()).thenReturn(offsetManager);
     CountDownLatch firstMessageBarrier = new CountDownLatch(2);
     doAnswer(invocation -> {
@@ -445,7 +556,7 @@
         return null;
       });
       return null;
-    }).when(task0).process(eq(envelope00), any(), any());
+    }).when(task0).process(eq(envelopeA00), any(), any());
 
     doAnswer(invocation -> {
       assertEquals(1, task0.metrics().messagesInFlight().getValue());
@@ -455,7 +566,7 @@
       callback.complete();
       firstMessageBarrier.countDown();
       return null;
-    }).when(task0).process(eq(envelope01), any(), any());
+    }).when(task0).process(eq(envelopeA01), any(), any());
 
     doAnswer(invocation -> {
       assertEquals(0, task0.metrics().messagesInFlight().getValue());
@@ -469,7 +580,7 @@
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(sspA0EndOfStream)
         .thenAnswer(invocation -> {
           // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
           firstMessageBarrier.countDown();
@@ -482,10 +593,67 @@
   }
 
   @Test
+  public void testDrainWaitsForInFlightMessages() {
+    int maxMessagesInFlight = 2;
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
+    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
+    OffsetManager offsetManager = mock(OffsetManager.class);
+
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+      TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+      TaskCallback callback = callbackFactory.createCallback();
+      taskExecutor.submit(() -> {
+        firstMessageBarrier.await();
+        callback.complete();
+        return null;
+      });
+      return null;
+    }).when(task0).process(eq(envelopeA00), any(), any());
+
+    doAnswer(invocation -> {
+      assertEquals(1, task0.metrics().messagesInFlight().getValue());
+
+      TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+      TaskCallback callback = callbackFactory.createCallback();
+      callback.complete();
+      firstMessageBarrier.countDown();
+      return null;
+    }).when(task0).process(eq(envelopeA01), any(), any());
+
+    doAnswer(invocation -> {
+      assertEquals(0, task0.metrics().messagesInFlight().getValue());
+      assertEquals(2, task0.metrics().asyncCallbackCompleted().getCount());
+
+      return null;
+    }).when(task0).drain(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+
+    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false,
+        1, runId);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(sspA0Drain)
+        .thenAnswer(invocation -> {
+          // this ensures that the drain message has passed through the run loop BEFORE the last
+          // remaining in-flight message completes
+          firstMessageBarrier.countDown();
+          return null;
+        });
+
+    runLoop.run();
+
+    verify(task0).drain(any());
+  }
+
+  @Test
   public void testEndOfStreamCommitBehavior() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
 
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
     doAnswer(invocation -> {
       ReadableCoordinator coordinator = invocation.getArgumentAt(0, ReadableCoordinator.class);
 
@@ -500,7 +668,7 @@
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(ssp0EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(sspA0EndOfStream).thenReturn(null);
 
     runLoop.run();
 
@@ -511,13 +679,36 @@
   }
 
   @Test
+  public void testDrainCommitBehavior() {
+    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
+
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+
+    tasks.put(taskName0, task0);
+
+    int maxMessagesInFlight = 1;
+    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false,
+        1, runId);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(sspA0Drain).thenReturn(null);
+
+    runLoop.run();
+
+    InOrder inOrder = inOrder(task0);
+
+    inOrder.verify(task0).drain(any());
+    inOrder.verify(task0).commit();
+  }
+
+  @Test
   public void testCommitWithMessageInFlightWhenAsyncCommitIsEnabled() {
     int maxMessagesInFlight = 2;
     ExecutorService taskExecutor = Executors.newFixedThreadPool(2);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
     when(task0.offsetManager()).thenReturn(offsetManager);
     CountDownLatch firstMessageBarrier = new CountDownLatch(1);
     doAnswer(invocation -> {
@@ -532,7 +723,7 @@
         return null;
       });
       return null;
-    }).when(task0).process(eq(envelope00), any(), any());
+    }).when(task0).process(eq(envelopeA00), any(), any());
 
     CountDownLatch secondMessageBarrier = new CountDownLatch(1);
     doAnswer(invocation -> {
@@ -550,7 +741,7 @@
         return null;
       });
       return null;
-    }).when(task0).process(eq(envelope01), any(), any());
+    }).when(task0).process(eq(envelopeA01), any(), any());
 
     doAnswer(invocation -> {
       assertEquals(1, task0.metrics().asyncCallbackCompleted().getCount());
@@ -565,12 +756,12 @@
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
         callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, true);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(null);
     runLoop.run();
 
     InOrder inOrder = inOrder(task0);
-    inOrder.verify(task0).process(eq(envelope00), any(), any());
-    inOrder.verify(task0).process(eq(envelope01), any(), any());
+    inOrder.verify(task0).process(eq(envelopeA00), any(), any());
+    inOrder.verify(task0).process(eq(envelopeA01), any(), any());
     inOrder.verify(task0).commit();
   }
 
@@ -578,12 +769,12 @@
   public void testExceptionIsPropagated() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
 
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
     doAnswer(invocation -> {
       TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
       callbackFactory.createCallback().failure(new Exception("Intentional failure"));
       return null;
-    }).when(task0).process(eq(envelope00), any(), any());
+    }).when(task0).process(eq(envelopeA00), any(), any());
 
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
 
@@ -592,16 +783,16 @@
         callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
 
     when(consumerMultiplexer.choose(false))
-        .thenReturn(envelope00)
-        .thenReturn(ssp0EndOfStream)
+        .thenReturn(envelopeA00)
+        .thenReturn(sspA0EndOfStream)
         .thenReturn(null);
 
     runLoop.run();
   }
 
-  private RunLoopTask getMockRunLoopTask(TaskName taskName, SystemStreamPartition ssp0) {
+  private RunLoopTask getMockRunLoopTask(TaskName taskName, SystemStreamPartition ... ssps) {
     RunLoopTask task0 = mock(RunLoopTask.class);
-    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.systemStreamPartitions()).thenReturn(new HashSet<>(Arrays.asList(ssps)));
     when(task0.metrics()).thenReturn(new TaskInstanceMetrics("test", new MetricsRegistryMap(), ""));
     when(task0.taskName()).thenReturn(taskName);
     return task0;
diff --git a/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java b/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java
index e7666aa..80061b3 100644
--- a/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java
+++ b/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java
@@ -41,13 +41,13 @@
  * Tests for {@link DrainMonitor}
  * */
 public class DrainMonitorTests {
-  private static final String TEST_DEPLOYMENT_ID = "foo";
+  private static final String TEST_RUN_ID = "foo";
 
   private static final Config
       CONFIG = new MapConfig(ImmutableMap.of(
           "job.name", "test-job",
       "job.coordinator.system", "test-kafka",
-      ApplicationConfig.APP_RUN_ID, TEST_DEPLOYMENT_ID));
+      ApplicationConfig.APP_RUN_ID, TEST_RUN_ID));
 
   private CoordinatorStreamStore coordinatorStreamStore;
 
@@ -114,7 +114,7 @@
     final AtomicInteger numCallbacks = new AtomicInteger(0);
     final CountDownLatch latch = new CountDownLatch(1);
     // write drain before monitor start
-    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
     DrainMonitor drainMonitor = new DrainMonitor(coordinatorStreamStore, CONFIG);
     drainMonitor.registerDrainCallback(() -> {
       numCallbacks.incrementAndGet();
@@ -141,7 +141,7 @@
       latch.countDown();
     });
     drainMonitor.start();
-    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
     if (!latch.await(2, TimeUnit.SECONDS)) {
       Assert.fail("Timed out waiting for drain callback to complete");
     }
@@ -150,8 +150,8 @@
   }
 
   @Test
-  public void testCallbackNotCalledDueToMismatchedDeploymentId() throws InterruptedException {
-    // The test fails due to timeout as the published DrainNotification's deploymentId doesn't match deploymentId
+  public void testCallbackNotCalledDueToMismatchedRunId() throws InterruptedException {
+    // The test fails due to timeout, as the published DrainNotification's runId doesn't match the runId
     // in the config
     exceptionRule.expect(AssertionError.class);
     exceptionRule.expectMessage("Timed out waiting for drain callback to complete.");
@@ -166,8 +166,8 @@
     });
 
     drainMonitor.start();
-    final String mismatchedDeploymentId = "bar";
-    DrainUtils.writeDrainNotification(coordinatorStreamStore, mismatchedDeploymentId);
+    final String mismatchedRunId = "bar";
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, mismatchedRunId);
     if (!latch.await(2, TimeUnit.SECONDS)) {
       Assert.fail("Timed out waiting for drain callback to complete.");
     }
@@ -184,16 +184,16 @@
 
   @Test
   public void testShouldDrain() {
-    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
     NamespaceAwareCoordinatorStreamStore drainStore =
         new NamespaceAwareCoordinatorStreamStore(coordinatorStreamStore, DrainUtils.DRAIN_METADATA_STORE_NAMESPACE);
-    Assert.assertTrue(DrainMonitor.shouldDrain(drainStore, TEST_DEPLOYMENT_ID));
+    Assert.assertTrue(DrainMonitor.shouldDrain(drainStore, TEST_RUN_ID));
 
     // Cleanup old drain message
     DrainUtils.cleanup(coordinatorStreamStore, CONFIG);
 
-    final String mismatchedDeploymentId = "bar";
-    DrainUtils.writeDrainNotification(coordinatorStreamStore, mismatchedDeploymentId);
-    Assert.assertFalse(DrainMonitor.shouldDrain(drainStore, TEST_DEPLOYMENT_ID));
+    final String mismatchedRunId = "bar";
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, mismatchedRunId);
+    Assert.assertFalse(DrainMonitor.shouldDrain(drainStore, TEST_RUN_ID));
   }
 }
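
Taken together, these tests exercise one wiring pattern; a minimal sketch, reusing the store, config, and run-id names from the tests above:

    // Start a monitor, then publish a notification with the matching run id to fire the callback.
    DrainMonitor drainMonitor = new DrainMonitor(coordinatorStreamStore, CONFIG);
    drainMonitor.registerDrainCallback(() -> {
      // React to drain here, e.g. count down a latch as the tests do.
    });
    drainMonitor.start();
    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
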
diff --git a/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java b/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java
index 1726535..43b91ad 100644
--- a/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java
+++ b/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java
@@ -45,11 +45,11 @@
  * Tests for {@link DrainUtils}
  * */
 public class DrainUtilsTests {
-  private static final String TEST_DEPLOYMENT_ID = "foo";
+  private static final String TEST_RUN_ID = "foo";
   private static final Config CONFIG = new MapConfig(ImmutableMap.of(
       "job.name", "test-job",
       "job.coordinator.system", "test-kafka",
-      ApplicationConfig.APP_RUN_ID, TEST_DEPLOYMENT_ID));
+      ApplicationConfig.APP_RUN_ID, TEST_RUN_ID));
 
   private CoordinatorStreamStore coordinatorStreamStore;
 
@@ -68,17 +68,17 @@
 
   @Test
   public void testWrites() {
-    String deploymentId1 = "foo1";
-    String deploymentId2 = "foo2";
-    String deploymentId3 = "foo3";
+    String runId1 = "foo1";
+    String runId2 = "foo2";
+    String runId3 = "foo3";
 
-    UUID uuid1 = DrainUtils.writeDrainNotification(coordinatorStreamStore, deploymentId1);
-    UUID uuid2 = DrainUtils.writeDrainNotification(coordinatorStreamStore, deploymentId2);
-    UUID uuid3 = DrainUtils.writeDrainNotification(coordinatorStreamStore, deploymentId3);
+    UUID uuid1 = DrainUtils.writeDrainNotification(coordinatorStreamStore, runId1);
+    UUID uuid2 = DrainUtils.writeDrainNotification(coordinatorStreamStore, runId2);
+    UUID uuid3 = DrainUtils.writeDrainNotification(coordinatorStreamStore, runId3);
 
-    DrainNotification expectedDrainNotification1 = new DrainNotification(uuid1, deploymentId1);
-    DrainNotification expectedDrainNotification2 = new DrainNotification(uuid2, deploymentId2);
-    DrainNotification expectedDrainNotification3 = new DrainNotification(uuid3, deploymentId3);
+    DrainNotification expectedDrainNotification1 = new DrainNotification(uuid1, runId1);
+    DrainNotification expectedDrainNotification2 = new DrainNotification(uuid2, runId2);
+    DrainNotification expectedDrainNotification3 = new DrainNotification(uuid3, runId3);
     Set<DrainNotification> expectedDrainNotifications = new HashSet<>(Arrays.asList(expectedDrainNotification1,
         expectedDrainNotification2, expectedDrainNotification3));
 
@@ -90,23 +90,23 @@
 
   @Test
   public void testCleanup() {
-    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
     DrainUtils.cleanup(coordinatorStreamStore, CONFIG);
     final Optional<List<DrainNotification>> drainNotifications1 = readDrainNotificationMessages(coordinatorStreamStore);
     Assert.assertFalse(drainNotifications1.isPresent());
 
-    final String deploymentId = "bar";
-    DrainUtils.writeDrainNotification(coordinatorStreamStore, deploymentId);
+    final String runId = "bar";
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, runId);
     DrainUtils.cleanup(coordinatorStreamStore, CONFIG);
     final Optional<List<DrainNotification>> drainNotifications2 = readDrainNotificationMessages(coordinatorStreamStore);
     Assert.assertTrue(drainNotifications2.isPresent());
-    Assert.assertEquals(deploymentId, drainNotifications2.get().get(0).getDeploymentId());
+    Assert.assertEquals(runId, drainNotifications2.get().get(0).getRunId());
   }
 
   @Test
   public void testCleanupAll() {
-    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
-    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
     DrainUtils.writeDrainNotification(coordinatorStreamStore, "bar");
     DrainUtils.cleanupAll(coordinatorStreamStore);
     final Optional<List<DrainNotification>> drainNotifications = readDrainNotificationMessages(coordinatorStreamStore);
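
For reference, the lifecycle these tests cover, sketched with the same store and config names:

    // writeDrainNotification returns the UUID of the notification it wrote.
    UUID id = DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
    // cleanup removes only notifications whose runId matches the one in CONFIG...
    DrainUtils.cleanup(coordinatorStreamStore, CONFIG);
    // ...while cleanupAll removes every drain notification regardless of runId.
    DrainUtils.cleanupAll(coordinatorStreamStore);
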
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
index d9bb916..6b2154e 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
@@ -57,7 +57,7 @@
   @Mock
   private var taskInstance: TaskInstance = null
   @Mock
-  private var runLoop: Runnable = null
+  private var runLoop: RunLoop = null
   @Mock
   private var systemAdmins: SystemAdmins = null
   @Mock
diff --git a/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java b/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
index 943a8ca..60323d6 100644
--- a/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
+++ b/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
@@ -20,6 +20,7 @@
 package org.apache.samza.test.framework;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
 import java.io.File;
 import java.time.Duration;
 import java.util.HashMap;
@@ -28,6 +29,9 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 import org.apache.commons.lang3.RandomStringUtils;
 import org.apache.samza.SamzaException;
@@ -93,10 +97,12 @@
   private static final Logger LOG = LoggerFactory.getLogger(TestRunner.class);
   private static final String JOB_DEFAULT_SYSTEM = "default-samza-system";
   private static final String APP_NAME = "samza-test";
+  private static final long DEFAULT_DELAY_BETWEEN_MESSAGES = 0L;
 
   private Map<String, String> configs;
   private SamzaApplication app;
   private ExternalContext externalContext;
+  private InMemoryMetadataStoreFactory inMemoryMetadataStoreFactory;
   /*
    * inMemoryScope is a unique global key per TestRunner, this key when configured with {@link InMemorySystemDescriptor}
    * provides an isolated state to run with in memory system
@@ -201,6 +207,13 @@
   }
 
   /**
+   * Returns a snapshot of this TestRunner's configs as a {@link Config}.
+   */
+  public Config getConfig() {
+    return new MapConfig(configs);
+  }
+
+  /**
    * Passes the user provided external context to {@link LocalApplicationRunner}
    *
    * @param externalContext external context provided by user
@@ -213,6 +226,18 @@
   }
 
   /**
+   * Sets an {@link InMemoryMetadataStoreFactory} to be used by {@link LocalApplicationRunner}.
+   *
+   * @param inMemoryMetadataStoreFactory the in-memory metadata store factory to use
+   * @return this {@link TestRunner}
+   */
+  public TestRunner setInMemoryMetadataFactory(InMemoryMetadataStoreFactory inMemoryMetadataStoreFactory) {
+    Preconditions.checkNotNull(inMemoryMetadataStoreFactory);
+    this.inMemoryMetadataStoreFactory = inMemoryMetadataStoreFactory;
+    return this;
+  }
+
+  /**
    * Adds the provided input stream with mock data to the test application.
    *
    * @param descriptor describes the stream that is supposed to be input to Samza application
@@ -224,10 +249,29 @@
    */
   public <StreamMessageType> TestRunner addInputStream(InMemoryInputDescriptor descriptor,
       List<StreamMessageType> messages) {
+    return addInputStream(descriptor, messages, DEFAULT_DELAY_BETWEEN_MESSAGES);
+  }
+
+  /**
+   * Adds the provided input stream with mock data to the test application. The mock messages are added one at a
+   * time, with a delay between successive messages, instead of all at once.
+   * Default configs and user-added configs have
+   * a higher precedence over system and stream descriptor generated configs.
+   * @param descriptor describes the stream that is supposed to be input to the Samza application
+   * @param messages list of messages for the stream's single partition (partition 0). These messages should
+   *                 always be deserialized
+   * @param delayBetweenMessagesInMillis delay between messages, in milliseconds
+   * @param <StreamMessageType> either a plain message with a null key, or a KV {@link org.apache.samza.operators.KV}
+   *                           whose key becomes the key of the {@link org.apache.samza.system.IncomingMessageEnvelope} or
+   *                           {@link org.apache.samza.system.OutgoingMessageEnvelope} and whose value is the message
+   * @return this {@link TestRunner}
+   */
+  public <StreamMessageType> TestRunner addInputStream(InMemoryInputDescriptor descriptor,
+      List<StreamMessageType> messages, long delayBetweenMessagesInMillis) {
     Preconditions.checkNotNull(descriptor, messages);
-    Map<Integer, Iterable<StreamMessageType>> partitionData = new HashMap<Integer, Iterable<StreamMessageType>>();
+    Map<Integer, Iterable<StreamMessageType>> partitionData = new HashMap<>();
     partitionData.put(0, messages);
-    initializeInMemoryInputStream(descriptor, partitionData);
+    initializeInMemoryInputStream(descriptor, partitionData, delayBetweenMessagesInMillis);
     return this;
   }
 
@@ -244,10 +288,28 @@
    */
   public <StreamMessageType> TestRunner addInputStream(InMemoryInputDescriptor descriptor,
       Map<Integer, ? extends Iterable<StreamMessageType>> messages) {
+    return addInputStream(descriptor, messages, DEFAULT_DELAY_BETWEEN_MESSAGES);
+  }
+
+  /**
+   * Adds the provided input stream with mock data to the test application. The mock messages are added one at a
+   * time, with a delay between successive messages, instead of all at once. TestRunner cycles through all partitions,
+   * adding one message per partition, and repeats periodically until all messages are exhausted.
+   * Default configs and user-added configs have a higher precedence over system and stream descriptor generated configs.
+   * @param descriptor describes the stream that is supposed to be input to the Samza application
+   * @param messages map whose key is partitionId and value is messages in the partition. These messages should
+   *                 always be deserialized
+   * @param delayBetweenMessagesInMillis delay between messages, in milliseconds
+   * @param <StreamMessageType> either a plain message with a null key, or a KV {@link org.apache.samza.operators.KV}
+   *                           whose key becomes the key of the {@link org.apache.samza.system.IncomingMessageEnvelope} or
+   *                           {@link org.apache.samza.system.OutgoingMessageEnvelope} and whose value is the message
+   * @return this {@link TestRunner}
+   */
+  public <StreamMessageType> TestRunner addInputStream(InMemoryInputDescriptor descriptor,
+      Map<Integer, ? extends Iterable<StreamMessageType>> messages, long delayBetweenMessagesInMillis) {
     Preconditions.checkNotNull(descriptor, messages);
-    Map<Integer, Iterable<StreamMessageType>> partitionData = new HashMap<Integer, Iterable<StreamMessageType>>();
-    partitionData.putAll(messages);
-    initializeInMemoryInputStream(descriptor, partitionData);
+    Map<Integer, Iterable<StreamMessageType>> partitionData = new HashMap<>(messages);
+    initializeInMemoryInputStream(descriptor, partitionData, delayBetweenMessagesInMillis);
     return this;
   }
 
@@ -291,7 +353,10 @@
     // Cleaning store directories to ensure current run does not pick up state from previous run
     deleteStoreDirectories();
     Config config = new MapConfig(JobPlanner.generateSingleJobConfig(configs));
-    final LocalApplicationRunner runner = new LocalApplicationRunner(app, config, new InMemoryMetadataStoreFactory());
+    InMemoryMetadataStoreFactory metadataStoreFactory = inMemoryMetadataStoreFactory != null
+        ? inMemoryMetadataStoreFactory
+        : new InMemoryMetadataStoreFactory();
+    final LocalApplicationRunner runner = new LocalApplicationRunner(app, config, metadataStoreFactory);
     runner.run(externalContext);
     if (!runner.waitForFinish(timeout)) {
       throw new SamzaException("Timed out waiting for application to finish");
@@ -378,7 +443,7 @@
    * @param descriptor describes a stream to initialize with the in memory system
    */
   private <StreamMessageType> void initializeInMemoryInputStream(InMemoryInputDescriptor<?> descriptor,
-      Map<Integer, Iterable<StreamMessageType>> partitionData) {
+      Map<Integer, Iterable<StreamMessageType>> partitionData, long delayInMillis) {
     String systemName = descriptor.getSystemName();
     String streamName = (String) descriptor.getPhysicalName().orElse(descriptor.getStreamId());
     if (this.app instanceof LegacyTaskApplication) {
@@ -402,19 +467,64 @@
     factory.getAdmin(systemName, config).createStream(spec);
     InMemorySystemProducer producer = (InMemorySystemProducer) factory.getProducer(systemName, config, null);
     SystemStream sysStream = new SystemStream(systemName, streamName);
-    partitionData.forEach((partitionId, partition) -> {
-      partition.forEach(e -> {
-        Object key = e instanceof KV ? ((KV) e).getKey() : null;
-        Object value = e instanceof KV ? ((KV) e).getValue() : e;
-        if (value instanceof IncomingMessageEnvelope) {
-          producer.send((IncomingMessageEnvelope) value);
-        } else {
-          producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), key, value));
-        }
+    if (delayInMillis > 0) {
+      delayedInitialization(partitionData, producer, sysStream, delayInMillis);
+    } else {
+      partitionData.forEach((partitionId, partition) -> {
+        partition.forEach(e -> {
+          Object key = e instanceof KV ? ((KV) e).getKey() : null;
+          Object value = e instanceof KV ? ((KV) e).getValue() : e;
+          if (value instanceof IncomingMessageEnvelope) {
+            producer.send((IncomingMessageEnvelope) value);
+          } else {
+            producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), key, value));
+          }
+        });
+        producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), null,
+            new EndOfStreamMessage(null)));
       });
-      producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), null,
-          new EndOfStreamMessage(null)));
-    });
+    }
+  }
+
+  private <StreamMessageType> void delayedInitialization(Map<Integer, Iterable<StreamMessageType>> partitionData,
+      InMemorySystemProducer producer, SystemStream systemStream, long delayInMillis) {
+    final Set<Integer> endOfStreamPartitions = new HashSet<>();
+    final int numPartitions = partitionData.size();
+    Map<Integer, LinkedList<StreamMessageType>> messageQueuesByPartition = partitionData.entrySet()
+        .stream()
+        .collect(Collectors.toMap(
+          Map.Entry::getKey,
+          entry -> Lists.newLinkedList(entry.getValue())));
+
+    final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
+    executor.scheduleAtFixedRate(new Runnable() {
+      @Override
+      public void run() {
+        if (endOfStreamPartitions.size() == numPartitions) {
+          // shutdown scheduled executor as EOS is reached for all partitions
+          executor.shutdownNow();
+        } else {
+          messageQueuesByPartition.forEach((partitionId, messages) -> {
+            if (messages.isEmpty() && !endOfStreamPartitions.contains(partitionId)) {
+              // if end of input is reached, send EOS for the partition and add the partition to the EOS partition set
+              // to ensure we don't send EOS again for the partition
+              producer.send(systemStream.getSystem(),
+                  new OutgoingMessageEnvelope(systemStream, Integer.valueOf(partitionId), null, new EndOfStreamMessage(null)));
+              endOfStreamPartitions.add(partitionId);
+            } else if (!messages.isEmpty()) {
+              final StreamMessageType e = messageQueuesByPartition.get(partitionId).removeFirst();
+              final Object key = e instanceof KV ? ((KV) e).getKey() : null;
+              final Object value = e instanceof KV ? ((KV) e).getValue() : e;
+              if (value instanceof IncomingMessageEnvelope) {
+                producer.send((IncomingMessageEnvelope) value);
+              } else {
+                producer.send(systemStream.getSystem(), new OutgoingMessageEnvelope(systemStream, Integer.valueOf(partitionId), key, value));
+              }
+            }
+          });
+        }
+      }
+    }, 0L, delayInMillis, TimeUnit.MILLISECONDS);
   }
 
   private void deleteStoreDirectories() {
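
The new overloads make paced input straightforward to set up from a test. A hedged sketch, where app, inputDescriptor, and messagesByPartition are hypothetical placeholders:

    // Pace mock input at 500 ms per message per partition, leaving time to
    // trigger a drain (or other control flow) while the job is mid-stream.
    TestRunner.of(app)
        .addInputStream(inputDescriptor, messagesByPartition, 500L)
        .run(Duration.ofSeconds(10));
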
diff --git a/samza-test/src/test/java/org/apache/samza/test/controlmessages/TestData.java b/samza-test/src/test/java/org/apache/samza/test/TestData.java
similarity index 97%
rename from samza-test/src/test/java/org/apache/samza/test/controlmessages/TestData.java
rename to samza-test/src/test/java/org/apache/samza/test/TestData.java
index 52d8380..5b2f137 100644
--- a/samza-test/src/test/java/org/apache/samza/test/controlmessages/TestData.java
+++ b/samza-test/src/test/java/org/apache/samza/test/TestData.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.samza.test.controlmessages;
+package org.apache.samza.test;
 
 import java.io.Serializable;
 import org.apache.samza.SamzaException;
diff --git a/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java
index 849f64c..bbe03ae 100644
--- a/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java
+++ b/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java
@@ -67,8 +67,8 @@
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.task.StreamOperatorTask;
 import org.apache.samza.task.TestStreamOperatorTask;
-import org.apache.samza.test.controlmessages.TestData.PageView;
-import org.apache.samza.test.controlmessages.TestData.PageViewJsonSerdeFactory;
+import org.apache.samza.test.TestData.PageView;
+import org.apache.samza.test.TestData.PageViewJsonSerdeFactory;
 import org.apache.samza.test.harness.IntegrationTestHarness;
 import org.apache.samza.test.util.SimpleSystemAdmin;
 import org.apache.samza.test.util.TestStreamConsumer;
diff --git a/samza-test/src/test/java/org/apache/samza/test/drain/DrainHighLevelApiIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/drain/DrainHighLevelApiIntegrationTest.java
new file mode 100644
index 0000000..c1a6f91
--- /dev/null
+++ b/samza-test/src/test/java/org/apache/samza/test/drain/DrainHighLevelApiIntegrationTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.drain;
+
+import com.google.common.collect.ImmutableMap;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import org.apache.samza.application.StreamApplication;
+import org.apache.samza.application.descriptors.StreamApplicationDescriptor;
+import org.apache.samza.config.ApplicationConfig;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.drain.DrainUtils;
+import org.apache.samza.metadatastore.InMemoryMetadataStoreFactory;
+import org.apache.samza.metadatastore.MetadataStore;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.operators.KV;
+import org.apache.samza.serializers.IntegerSerde;
+import org.apache.samza.serializers.KVSerde;
+import org.apache.samza.serializers.NoOpSerde;
+import org.apache.samza.system.descriptors.DelegatingSystemDescriptor;
+import org.apache.samza.system.descriptors.GenericInputDescriptor;
+import org.apache.samza.test.framework.TestRunner;
+import org.apache.samza.test.framework.system.descriptors.InMemoryInputDescriptor;
+import org.apache.samza.test.framework.system.descriptors.InMemorySystemDescriptor;
+import org.apache.samza.test.table.TestTableData;
+import org.apache.samza.test.table.TestTableData.PageView;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+
+/**
+ * End to end integration test to check drain functionality with samza high-level API.
+ */
+public class DrainHighLevelApiIntegrationTest {
+  private static final List<PageView> RECEIVED = new ArrayList<>();
+  private static final String SYSTEM_NAME = "test";
+  private static final String STREAM_ID = "PageView";
+
+  private static class PageViewEventCountHighLevelApplication implements StreamApplication {
+    @Override
+    public void describe(StreamApplicationDescriptor appDescriptor) {
+      DelegatingSystemDescriptor sd = new DelegatingSystemDescriptor(SYSTEM_NAME);
+      GenericInputDescriptor<KV<String, PageView>> isd =
+          sd.getInputDescriptor(STREAM_ID, KVSerde.of(new NoOpSerde<>(), new NoOpSerde<>()));
+      appDescriptor.getInputStream(isd)
+          .map(KV::getValue)
+          .partitionBy(PageView::getMemberId, pv -> pv,
+              KVSerde.of(new IntegerSerde(), new TestTableData.PageViewJsonSerde()), "p1")
+          .sink((m, collector, coordinator) -> {
+            RECEIVED.add(m.getValue());
+          });
+    }
+  }
+
+  // The test can occasionally be flaky, so it is annotated with @Ignore.
+  // Remove the @Ignore annotation and run the test as follows:
+  // ./gradlew :samza-test:test --tests org.apache.samza.test.drain.DrainHighLevelApiIntegrationTest -PscalaSuffix=2.12
+  @Ignore
+  @Test
+  public void testPipeline() {
+    String runId = "DrainTestId";
+    int numPageViews = 40;
+
+    InMemorySystemDescriptor isd = new InMemorySystemDescriptor(SYSTEM_NAME);
+    InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor(STREAM_ID, new NoOpSerde<>());
+    InMemoryMetadataStoreFactory metadataStoreFactory = new InMemoryMetadataStoreFactory();
+
+    Map<String, String> customConfig = ImmutableMap.of(
+        ApplicationConfig.APP_RUN_ID, runId,
+        JobConfig.DRAIN_MONITOR_POLL_INTERVAL_MILLIS, "100",
+        JobConfig.DRAIN_MONITOR_ENABLED, "true");
+
+    // Create a TestRunner.
+    // Set an InMemoryMetadataStoreFactory. The test uses this factory to create a metadata store and
+    // write a drain notification to it.
+    // Mock data consists of 40 messages across 4 partitions. TestRunner adds a 1-second delay between messages
+    // per partition when writing messages to the input stream.
+    TestRunner testRunner = TestRunner.of(new PageViewEventCountHighLevelApplication())
+        .setInMemoryMetadataFactory(metadataStoreFactory)
+        .addConfig(customConfig)
+        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, 4), 1000L);
+
+    Config configFromRunner = testRunner.getConfig();
+    MetadataStore metadataStore = metadataStoreFactory.getMetadataStore("NoOp", configFromRunner, new MetricsRegistryMap());
+
+    // Write a drain notification after a delay.
+    ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor();
+    executorService.schedule(new Callable<String>() {
+      @Override
+      public String call() throws Exception {
+        UUID uuid = DrainUtils.writeDrainNotification(metadataStore, runId);
+        return uuid.toString();
+      }
+    }, 2000L, TimeUnit.MILLISECONDS);
+
+    testRunner.run(Duration.ofSeconds(20));
+
+    assertTrue(RECEIVED.size() < numPageViews && RECEIVED.size() > 0);
+  }
+}
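
The timing works out as follows: 40 page views across 4 partitions is 10 messages per partition, emitted at one per second, so each partition's input takes roughly 10 seconds. The drain notification is written at 2 seconds and, with the 100 ms poll interval, is picked up almost immediately, so only a few messages per partition are processed before the drain; hence the assertion that RECEIVED is non-empty but strictly smaller than numPageViews.
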
diff --git a/samza-test/src/test/java/org/apache/samza/test/drain/DrainLowLevelApiIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/drain/DrainLowLevelApiIntegrationTest.java
new file mode 100644
index 0000000..7579c6e
--- /dev/null
+++ b/samza-test/src/test/java/org/apache/samza/test/drain/DrainLowLevelApiIntegrationTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.drain;
+
+import com.google.common.collect.ImmutableMap;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import org.apache.samza.application.TaskApplication;
+import org.apache.samza.application.descriptors.TaskApplicationDescriptor;
+import org.apache.samza.config.ApplicationConfig;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.context.Context;
+import org.apache.samza.drain.DrainUtils;
+import org.apache.samza.metadatastore.InMemoryMetadataStoreFactory;
+import org.apache.samza.metadatastore.MetadataStore;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.serializers.NoOpSerde;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.descriptors.DelegatingSystemDescriptor;
+import org.apache.samza.system.descriptors.GenericInputDescriptor;
+import org.apache.samza.task.DrainListenerTask;
+import org.apache.samza.task.EndOfStreamListenerTask;
+import org.apache.samza.task.InitableTask;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.StreamTask;
+import org.apache.samza.task.StreamTaskFactory;
+import org.apache.samza.task.TaskCoordinator;
+import org.apache.samza.test.framework.TestRunner;
+import org.apache.samza.test.framework.system.descriptors.InMemoryInputDescriptor;
+import org.apache.samza.test.framework.system.descriptors.InMemorySystemDescriptor;
+import org.apache.samza.test.table.TestTableData;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+
+/**
+ * End to end integration test to check drain functionality with samza low-level API.
+ * */
+public class DrainLowLevelApiIntegrationTest {
+  private static final List<TestTableData.PageView> RECEIVED = new ArrayList<>();
+
+  private static Integer drainCounter = 0;
+  private static Integer eosCounter = 0;
+
+  private static final String SYSTEM_NAME = "test";
+  private static final String STREAM_ID = "PageView";
+
+  private static class PageViewEventCountLowLevelApplication implements TaskApplication {
+    @Override
+    public void describe(TaskApplicationDescriptor appDescriptor) {
+      DelegatingSystemDescriptor ksd = new DelegatingSystemDescriptor(SYSTEM_NAME);
+      GenericInputDescriptor<TestTableData.PageView> pageViewIsd = ksd.getInputDescriptor(STREAM_ID, new NoOpSerde<>());
+      appDescriptor
+          .withInputStream(pageViewIsd)
+          .withTaskFactory((StreamTaskFactory) PageViewEventCountStreamTask::new);
+    }
+  }
+
+  private static class PageViewEventCountStreamTask implements StreamTask, InitableTask, DrainListenerTask, EndOfStreamListenerTask {
+    public PageViewEventCountStreamTask() {
+    }
+
+    @Override
+    public void init(Context context) {
+    }
+
+    @Override
+    public void process(IncomingMessageEnvelope message, MessageCollector collector, TaskCoordinator coordinator) {
+      TestTableData.PageView pv = (TestTableData.PageView) message.getMessage();
+      RECEIVED.add(pv);
+    }
+
+    @Override
+    public void onDrain(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+      drainCounter++;
+    }
+
+    @Override
+    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+      eosCounter++;
+    }
+  }
+
+  // The test can occasionally be flaky, so it is annotated with @Ignore.
+  // Remove the @Ignore annotation and run the test as follows:
+  // ./gradlew :samza-test:test --tests org.apache.samza.test.drain.DrainLowLevelApiIntegrationTest -PscalaSuffix=2.12
+  @Ignore
+  @Test
+  public void testPipeline() {
+    int numPageViews = 40;
+    InMemorySystemDescriptor isd = new InMemorySystemDescriptor("test");
+    InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor("PageView", new NoOpSerde<>());
+    InMemoryMetadataStoreFactory metadataStoreFactory = new InMemoryMetadataStoreFactory();
+
+    String runId = "DrainTestId";
+    Map<String, String> customConfig = ImmutableMap.of(
+        ApplicationConfig.APP_RUN_ID, runId,
+        JobConfig.DRAIN_MONITOR_POLL_INTERVAL_MILLIS, "100",
+        JobConfig.DRAIN_MONITOR_ENABLED, "true");
+
+    // Create a TestRunner.
+    // Set an InMemoryMetadataStoreFactory. The test uses this factory to create a metadata store and
+    // write a drain notification to it.
+    // Mock data consists of 40 messages across 4 partitions. TestRunner adds a 1-second delay between messages
+    // per partition when writing messages to the input stream.
+    TestRunner testRunner = TestRunner.of(new PageViewEventCountLowLevelApplication())
+        .setInMemoryMetadataFactory(metadataStoreFactory)
+        .addConfig(customConfig)
+        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, 4), 1000L);
+
+    Config configFromRunner = testRunner.getConfig();
+    MetadataStore metadataStore =
+        metadataStoreFactory.getMetadataStore("NoOp", configFromRunner, new MetricsRegistryMap());
+
+    // Write a drain notification after a delay.
+    ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor();
+    executorService.schedule(new Callable<String>() {
+      @Override
+      public String call() throws Exception {
+        UUID uuid = DrainUtils.writeDrainNotification(metadataStore, runId);
+        return uuid.toString();
+      }
+    }, 2000L, TimeUnit.MILLISECONDS);
+
+    testRunner.run(Duration.ofSeconds(25));
+
+    assertTrue(RECEIVED.size() < numPageViews && RECEIVED.size() > 0);
+  }
+}
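
The task also counts onDrain and onEndOfStream invocations, though the test does not assert on them. A hypothetical follow-up assertion (not part of the original test) could check that the drain path, rather than end-of-stream, was taken:

    // Hypothetical: a drain-triggered shutdown should have fired onDrain at least once.
    assertTrue(drainCounter > 0);
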
diff --git a/samza-test/src/test/java/org/apache/samza/test/framework/StreamApplicationIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/framework/StreamApplicationIntegrationTest.java
index 6afc77c..2a6e70c 100644
--- a/samza-test/src/test/java/org/apache/samza/test/framework/StreamApplicationIntegrationTest.java
+++ b/samza-test/src/test/java/org/apache/samza/test/framework/StreamApplicationIntegrationTest.java
@@ -42,7 +42,7 @@
 import org.apache.samza.system.kafka.descriptors.KafkaOutputDescriptor;
 import org.apache.samza.system.kafka.descriptors.KafkaSystemDescriptor;
 import org.apache.samza.table.Table;
-import org.apache.samza.test.controlmessages.TestData;
+import org.apache.samza.test.TestData;
 import org.apache.samza.test.framework.system.descriptors.InMemoryInputDescriptor;
 import org.apache.samza.test.framework.system.descriptors.InMemoryOutputDescriptor;
 import org.apache.samza.test.framework.system.descriptors.InMemorySystemDescriptor;
@@ -53,7 +53,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import static org.apache.samza.test.controlmessages.TestData.PageView;
+import static org.apache.samza.test.TestData.PageView;
 
 public class StreamApplicationIntegrationTest {
   private static final Logger LOG = LoggerFactory.getLogger(StreamApplicationIntegrationTest.class);