SAMZA-2550: Move side input processing to use RunLoop (#1385)

diff --git a/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java b/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java
new file mode 100644
index 0000000..6274c15
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.storage;
+
+import java.util.Collections;
+import java.util.Set;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.container.RunLoopTask;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallback;
+import org.apache.samza.task.TaskCallbackFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * A {@link RunLoopTask} that applies side input envelopes to a {@link TaskSideInputHandler}. It is executed by
+ * {@link org.apache.samza.container.RunLoop}; windowing and scheduling are not supported.
+ */
+public class SideInputTask implements RunLoopTask {
+  private static final Logger LOG = LoggerFactory.getLogger(SideInputTask.class);
+
+  private final TaskName taskName;
+  private final Set<SystemStreamPartition> taskSSPs;
+  private final TaskSideInputHandler taskSideInputHandler;
+  private final TaskInstanceMetrics metrics;
+
+  public SideInputTask(TaskName taskName, Set<SystemStreamPartition> taskSSPs,
+      TaskSideInputHandler taskSideInputHandler, TaskInstanceMetrics metrics) {
+    this.taskName = taskName;
+    this.taskSSPs = taskSSPs;
+    this.taskSideInputHandler = taskSideInputHandler;
+    this.metrics = metrics;
+  }
+
+  @Override
+  public TaskName taskName() {
+    return taskName;
+  }
+
+  @Override
+  public Set<SystemStreamPartition> systemStreamPartitions() {
+    return taskSSPs;
+  }
+
+  @Override
+  public TaskInstanceMetrics metrics() {
+    return metrics;
+  }
+
+  @Override
+  public Set<String> intermediateStreams() {
+    return Collections.emptySet();
+  }
+
+  @Override
+  public boolean isWindowableTask() {
+    return false;
+  }
+
+  @Override
+  public OffsetManager offsetManager() {
+    return null;
+  }
+
+  @Override
+  public EpochTimeScheduler epochTimeScheduler() {
+    return null;
+  }
+
+  @Override
+  public synchronized void process(IncomingMessageEnvelope envelope, ReadableCoordinator coordinator,
+      TaskCallbackFactory callbackFactory) {
+    final TaskCallback taskCallback = callbackFactory.createCallback();
+    metrics.processes().inc();
+    try {
+      taskSideInputHandler.process(envelope);
+      metrics.messagesActuallyProcessed().inc();
+      taskCallback.complete();
+    } catch (Exception processingFailure) {
+      // report the failure through the callback instead of propagating it to the caller
+      taskCallback.failure(processingFailure);
+    }
+  }
+
+  @Override
+  public synchronized void commit() {
+    // committing a side input task means flushing its handler
+    taskSideInputHandler.flush();
+    metrics.commits().inc();
+  }
+
+  @Override
+  public void endOfStream(ReadableCoordinator coordinator) {
+    LOG.info("Task {} has reached end of stream", taskName);
+  }
+
+  @Override
+  public void window(ReadableCoordinator coordinator) {
+    throw new UnsupportedOperationException("Windowing is not applicable for side input tasks.");
+  }
+
+  @Override
+  public void scheduler(ReadableCoordinator coordinator) {
+    throw new UnsupportedOperationException("Scheduling is not applicable for side input tasks.");
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java b/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
index e292df9..f352bd0 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
@@ -248,7 +248,7 @@
               this.getSerdes(),
               jobConfig,
               new HashMap<>(),
-              new SamzaContainerMetrics(containerModel.getId(), new MetricsRegistryMap()),
+              new SamzaContainerMetrics(containerModel.getId(), new MetricsRegistryMap(), ""),
               JobContextImpl.fromConfigWithDefaults(jobConfig),
               containerContext,
               new HashMap<>(),
diff --git a/samza-core/src/main/java/org/apache/samza/storage/TaskSideInputHandler.java b/samza-core/src/main/java/org/apache/samza/storage/TaskSideInputHandler.java
index 7ab4036..767b9ce 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/TaskSideInputHandler.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/TaskSideInputHandler.java
@@ -29,6 +29,7 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
 import java.util.stream.Collectors;
 import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
@@ -38,6 +39,7 @@
 import org.apache.samza.storage.kv.KeyValueStore;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.SystemAdmin;
 import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
@@ -63,19 +65,23 @@
   private final Map<String, SideInputsProcessor> storeToProcessor;
   private final SystemAdmins systemAdmins;
   private final StreamMetadataCache streamMetadataCache;
+  // indicates to ContainerStorageManager that all side input ssps in this task are caught up
+  private final CountDownLatch taskCaughtUpLatch;
 
+  private Map<SystemStreamPartition, SystemStreamMetadata.SystemStreamPartitionMetadata> initialSideInputSSPMetadata;
   private Map<SystemStreamPartition, String> startingOffsets;
 
   public TaskSideInputHandler(TaskName taskName, TaskMode taskMode, File storeBaseDir,
       Map<String, StorageEngine> storeToStorageEngines, Map<String, Set<SystemStreamPartition>> storeToSSPs,
       Map<String, SideInputsProcessor> storeToProcessor, SystemAdmins systemAdmins,
-      StreamMetadataCache streamMetadataCache, Clock clock) {
+      StreamMetadataCache streamMetadataCache, CountDownLatch taskCaughtUpLatch, Clock clock) {
     validateProcessorConfiguration(storeToSSPs.keySet(), storeToProcessor);
 
     this.taskName = taskName;
     this.systemAdmins = systemAdmins;
     this.streamMetadataCache = streamMetadataCache;
     this.storeToProcessor = storeToProcessor;
+    this.taskCaughtUpLatch = taskCaughtUpLatch;
 
     this.sspToStores = new HashMap<>();
     storeToSSPs.forEach((store, ssps) -> {
@@ -119,6 +125,30 @@
 
     this.startingOffsets = getStartingOffsets(fileOffsets, getOldestOffsets());
     LOG.info("Starting offsets for the task {}: {}", taskName, startingOffsets);
+
+    this.initialSideInputSSPMetadata = getInitialSideInputSSPMetadata();
+    LOG.info("Task {} will catch up to offsets {}", this.taskName, this.initialSideInputSSPMetadata);
+
+    this.startingOffsets.forEach((ssp, offset) -> checkCaughtUp(ssp, offset, SystemStreamMetadata.OffsetType.UPCOMING));
+  }
+
+  /**
+   * Looks up the partition metadata of every side input SSP handled here from the stream metadata cache.
+   * SSPs for which the cache returns no stream metadata are left out of the result.
+   *
+   * @return a map of SSP to its partition metadata
+   */
+  private Map<SystemStreamPartition, SystemStreamMetadata.SystemStreamPartitionMetadata> getInitialSideInputSSPMetadata() {
+    Map<SystemStreamPartition, SystemStreamMetadata.SystemStreamPartitionMetadata> metadataBySSP = new HashMap<>();
+    for (SystemStreamPartition ssp : this.sspToStores.keySet()) {
+      // false: request full metadata rather than partitions-only metadata
+      SystemStreamMetadata streamMetadata = this.streamMetadataCache.getSystemStreamMetadata(ssp.getSystemStream(), false);
+      if (streamMetadata == null) {
+        continue;
+      }
+      metadataBySSP.put(ssp, streamMetadata.getSystemStreamPartitionMetadata().get(ssp.getPartition()));
+    }
+    return metadataBySSP;
   }
 
   /**
@@ -150,6 +180,7 @@
     }
 
     this.lastProcessedOffsets.put(envelopeSSP, envelopeOffset);
+    checkCaughtUp(envelopeSSP, envelopeOffset, SystemStreamMetadata.OffsetType.NEWEST);
   }
 
   /**
@@ -193,9 +224,9 @@
   }
 
   /**
-   * Gets the starting offsets for the {@link SystemStreamPartition}s belonging to all the side input stores.
-   * If the local file offset is available and is greater than the oldest available offset from source, uses it,
-   * else falls back to oldest offset in the source.
+   * Gets the starting offsets for the {@link SystemStreamPartition}s belonging to all the side input stores. See doc
+   * of {@link StorageManagerUtil#getStartingOffset} for how file offsets and oldest offsets for each SSP are
+   * reconciled.
    *
    * @param fileOffsets offsets from the local offset file
    * @param oldestOffsets oldest offsets from the source
@@ -260,6 +291,43 @@
   }
 
   /**
+   * Checks whether the given SSP has reached the offset recorded for it in {@link #initialSideInputSSPMetadata}.
+   * A caught-up SSP is removed from that map; when the removal leaves the map empty, the task latch is counted
+   * down, notifying {@link ContainerStorageManager} that the task is caught up. An SSP with no recorded offset
+   * of the requested type, or a null {@code currentOffset}, is treated as not caught up.
+   *
+   * @param ssp The SSP to be checked
+   * @param currentOffset The offset to be checked
+   * @param offsetTypeToCheck The type of offset to compare {@code currentOffset} to.
+   */
+  private void checkCaughtUp(SystemStreamPartition ssp, String currentOffset, SystemStreamMetadata.OffsetType offsetTypeToCheck) {
+    SystemStreamMetadata.SystemStreamPartitionMetadata targetMetadata = this.initialSideInputSSPMetadata.get(ssp);
+    String targetOffset = targetMetadata != null ? targetMetadata.getOffset(offsetTypeToCheck) : null;
+
+    LOG.trace("Checking offset {} against {} offset {} for {}.", currentOffset, targetOffset, offsetTypeToCheck, ssp);
+
+    // without both offsets there is nothing to compare, so the ssp is considered to be lagging
+    if (currentOffset == null || targetOffset == null) {
+      return;
+    }
+
+    SystemAdmin admin = systemAdmins.getSystemAdmin(ssp.getSystem());
+    Integer comparison = admin.offsetComparator(currentOffset, targetOffset);
+    // the SSP is still lagging unless the current offset is at or past the target offset
+    if (comparison == null || comparison.intValue() < 0) {
+      return;
+    }
+
+    LOG.info("Side input ssp {} has caught up to offset {}.", ssp, targetOffset);
+    // a caught-up ssp no longer needs tracking
+    this.initialSideInputSSPMetadata.remove(ssp);
+    if (this.initialSideInputSSPMetadata.isEmpty()) {
+      // the last ssp just caught up, so release the latch; later calls find no metadata and return early
+      this.taskCaughtUpLatch.countDown();
+    }
+  }
+
+  /**
    * Validates that each store has an associated {@link SideInputsProcessor}
    */
   private void validateProcessorConfiguration(Set<String> stores, Map<String, SideInputsProcessor> storeToProcessor) {
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
index e0b2fdc..73dcd20 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
@@ -26,7 +26,8 @@
 
 class SamzaContainerMetrics(
   val source: String = "unknown",
-  val registry: ReadableMetricsRegistry = new MetricsRegistryMap) extends MetricsHelper {
+  val registry: ReadableMetricsRegistry = new MetricsRegistryMap,
+  val prefix: String = "") extends MetricsHelper {
 
   val commits = newCounter("commit-calls")
   val windows = newCounter("window-calls")
@@ -54,4 +55,5 @@
     taskStoreRestorationMetrics.put(taskName, newGauge("%s-restore-time" format(taskName.toString), -1L))
   }
 
+  override def getPrefix: String = prefix
 }
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
index 94cfbdc..bdd773c 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
@@ -26,7 +26,8 @@
 
 class TaskInstanceMetrics(
   val source: String = "unknown",
-  val registry: ReadableMetricsRegistry = new MetricsRegistryMap) extends MetricsHelper {
+  val registry: ReadableMetricsRegistry = new MetricsRegistryMap,
+  val prefix: String = "") extends MetricsHelper {
 
   val commits = newCounter("commit-calls")
   val windows = newCounter("window-calls")
@@ -41,4 +42,6 @@
   def addOffsetGauge(systemStreamPartition: SystemStreamPartition, getValue: () => String) {
     newGauge("%s-%s-%d-offset" format (systemStreamPartition.getSystem, systemStreamPartition.getStream, systemStreamPartition.getPartition.getPartitionId), getValue)
   }
+
+  override def getPrefix: String = prefix
 }
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
index 19411b4..412a3c1 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
+++ b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
@@ -23,7 +23,6 @@
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import java.io.File;
 import java.nio.file.Path;
-import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
@@ -34,14 +33,12 @@
 import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.Callable;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 import org.apache.commons.collections4.MapUtils;
 import org.apache.samza.SamzaException;
@@ -50,6 +47,8 @@
 import org.apache.samza.config.Config;
 import org.apache.samza.config.StorageConfig;
 import org.apache.samza.config.TaskConfig;
+import org.apache.samza.container.RunLoop;
+import org.apache.samza.container.RunLoopTask;
 import org.apache.samza.container.SamzaContainerMetrics;
 import org.apache.samza.container.TaskInstanceMetrics;
 import org.apache.samza.container.TaskName;
@@ -68,7 +67,6 @@
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SSPMetadataCache;
 import org.apache.samza.system.StreamMetadataCache;
-import org.apache.samza.system.SystemAdmin;
 import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemConsumer;
 import org.apache.samza.system.SystemConsumers;
@@ -110,14 +108,13 @@
 public class ContainerStorageManager {
   private static final Logger LOG = LoggerFactory.getLogger(ContainerStorageManager.class);
   private static final String RESTORE_THREAD_NAME = "Samza Restore Thread-%d";
-  private static final String SIDEINPUTS_READ_THREAD_NAME = "SideInputs Read Thread";
-  private static final String SIDEINPUTS_FLUSH_THREAD_NAME = "SideInputs Flush Thread";
+  private static final String SIDEINPUTS_THREAD_NAME = "SideInputs Thread";
   private static final String SIDEINPUTS_METRICS_PREFIX = "side-inputs-";
   // We use a prefix to differentiate the SystemConsumersMetrics for sideInputs from the ones in SamzaContainer
 
-  private static final int SIDE_INPUT_READ_THREAD_TIMEOUT_SECONDS = 10; // Timeout with which sideinput read thread checks for exceptions
-  private static final Duration SIDE_INPUT_FLUSH_TIMEOUT = Duration.ofMinutes(1); // Period with which sideinputs are flushed
-
+  // Timeout with which the sideinput thread checks for exceptions and for whether SSPs are caught up
+  private static final int SIDE_INPUT_CHECK_TIMEOUT_SECONDS = 10;
+  private static final int SIDE_INPUT_SHUTDOWN_TIMEOUT_SECONDS = 60;
 
   /** Maps containing relevant per-task objects */
   private final Map<TaskName, Map<String, StorageEngine>> taskStores;
@@ -154,17 +151,13 @@
   private final Map<TaskName, Map<String, Set<SystemStreamPartition>>> taskSideInputStoreSSPs;
   private final Map<SystemStreamPartition, TaskSideInputHandler> sspSideInputHandlers;
   private SystemConsumers sideInputSystemConsumers;
-  private final Map<SystemStreamPartition, SystemStreamMetadata.SystemStreamPartitionMetadata> initialSideInputSSPMetadata
-      = new ConcurrentHashMap<>(); // Recorded sspMetadata of the taskSideInputSSPs recorded at start, used to determine when sideInputs are caughtup and container init can proceed
-  private volatile CountDownLatch sideInputsCaughtUp; // Used by the sideInput-read thread to signal to the main thread
+  private volatile Map<TaskName, CountDownLatch> sideInputTaskLatches; // Used by the sideInput-read thread to signal to the main thread
   private volatile boolean shouldShutdown = false;
+  private RunLoop sideInputRunLoop;
 
-  private final ExecutorService sideInputsReadExecutor = Executors.newSingleThreadExecutor(
-      new ThreadFactoryBuilder().setDaemon(true).setNameFormat(SIDEINPUTS_READ_THREAD_NAME).build());
+  private final ExecutorService sideInputsExecutor = Executors.newSingleThreadExecutor(
+      new ThreadFactoryBuilder().setDaemon(true).setNameFormat(SIDEINPUTS_THREAD_NAME).build());
 
-  private final ScheduledExecutorService sideInputsFlushExecutor = Executors.newSingleThreadScheduledExecutor(
-      new ThreadFactoryBuilder().setDaemon(true).setNameFormat(SIDEINPUTS_FLUSH_THREAD_NAME).build());
-  private ScheduledFuture sideInputsFlushFuture;
   private volatile Throwable sideInputException = null;
 
   private final Config config;
@@ -195,6 +188,7 @@
     this.checkpointManager = checkpointManager;
     this.containerModel = containerModel;
     this.taskSideInputStoreSSPs = getTaskSideInputSSPs(containerModel, sideInputSystemStreams);
+    this.sideInputTaskLatches = new HashMap<>();
     this.hasSideInputs = this.taskSideInputStoreSSPs.values().stream()
         .flatMap(m -> m.values().stream())
         .flatMap(Collection::stream)
@@ -603,28 +597,35 @@
 
         Map<String, StorageEngine> sideInputStores = getSideInputStores(taskName);
         Map<String, Set<SystemStreamPartition>> sideInputStoresToSSPs = new HashMap<>();
-
+        boolean taskHasSideInputs = false;
         for (String storeName : sideInputStores.keySet()) {
           Set<SystemStreamPartition> storeSSPs = this.taskSideInputStoreSSPs.get(taskName).get(storeName);
+          taskHasSideInputs = taskHasSideInputs || !storeSSPs.isEmpty();
           sideInputStoresToSSPs.put(storeName, storeSSPs);
         }
 
-        TaskSideInputHandler taskSideInputHandler = new TaskSideInputHandler(taskName,
-            taskModel.getTaskMode(),
-            loggedStoreBaseDirectory,
-            sideInputStores,
-            sideInputStoresToSSPs,
-            taskSideInputProcessors.get(taskName),
-            this.systemAdmins,
-            this.streamMetadataCache,
-            clock);
+        if (taskHasSideInputs) {
+          CountDownLatch taskCountDownLatch = new CountDownLatch(1);
+          this.sideInputTaskLatches.put(taskName, taskCountDownLatch);
 
-        sideInputStoresToSSPs.values().stream().flatMap(Set::stream).forEach(ssp -> {
-          handlers.put(ssp, taskSideInputHandler);
-        });
+          TaskSideInputHandler taskSideInputHandler = new TaskSideInputHandler(taskName,
+              taskModel.getTaskMode(),
+              loggedStoreBaseDirectory,
+              sideInputStores,
+              sideInputStoresToSSPs,
+              taskSideInputProcessors.get(taskName),
+              this.systemAdmins,
+              this.streamMetadataCache,
+              taskCountDownLatch,
+              clock);
 
-        LOG.info("Created TaskSideInputHandler for task {}, sideInputStores {} and loggedStoreBaseDirectory {}",
-            taskName, sideInputStores, loggedStoreBaseDirectory);
+          sideInputStoresToSSPs.values().stream().flatMap(Set::stream).forEach(ssp -> {
+            handlers.put(ssp, taskSideInputHandler);
+          });
+
+          LOG.info("Created TaskSideInputHandler for task {}, sideInputStores {} and loggedStoreBaseDirectory {}",
+              taskName, sideInputStores, loggedStoreBaseDirectory);
+        }
       });
     }
     return handlers;
@@ -728,22 +729,26 @@
     // initialize the sideInputStorageManagers
     getSideInputHandlers().forEach(TaskSideInputHandler::init);
 
-    // start the checkpointing thread at the commit-ms frequency
-    TaskConfig taskConfig = new TaskConfig(config);
-    sideInputsFlushFuture = sideInputsFlushExecutor.scheduleWithFixedDelay(new Runnable() {
-      @Override
-      public void run() {
-        try {
-          getSideInputHandlers().forEach(TaskSideInputHandler::flush);
-        } catch (Exception e) {
-          LOG.error("Exception during flushing sideInputs", e);
-          sideInputException = e;
-        }
-      }
-    }, 0, taskConfig.getCommitMs(), TimeUnit.MILLISECONDS);
+    Map<TaskName, TaskSideInputHandler> taskSideInputHandlers = this.sspSideInputHandlers.values().stream()
+        .distinct()
+        .collect(Collectors.toMap(TaskSideInputHandler::getTaskName, Function.identity()));
 
-    // set the latch to the number of sideInput SSPs
-    this.sideInputsCaughtUp = new CountDownLatch(this.sspSideInputHandlers.keySet().size());
+    Map<TaskName, TaskInstanceMetrics> sideInputTaskMetrics = new HashMap<>();
+    Map<TaskName, RunLoopTask> sideInputTasks = new HashMap<>();
+    this.taskSideInputStoreSSPs.forEach((taskName, storesToSSPs) -> {
+      Set<SystemStreamPartition> taskSSPs = this.taskSideInputStoreSSPs.get(taskName).values().stream()
+          .flatMap(Set::stream)
+          .collect(Collectors.toSet());
+
+      if (!taskSSPs.isEmpty()) {
+        String sideInputSource = SIDEINPUTS_METRICS_PREFIX + this.taskInstanceMetrics.get(taskName).source();
+        TaskInstanceMetrics sideInputMetrics = new TaskInstanceMetrics(sideInputSource, this.taskInstanceMetrics.get(taskName).registry(), SIDEINPUTS_METRICS_PREFIX);
+        sideInputTaskMetrics.put(taskName, sideInputMetrics);
+
+        RunLoopTask sideInputTask = new SideInputTask(taskName, taskSSPs, taskSideInputHandlers.get(taskName), sideInputTaskMetrics.get(taskName));
+        sideInputTasks.put(taskName, sideInputTask);
+      }
+    });
 
     // register all sideInput SSPs with the consumers
     for (SystemStreamPartition ssp : this.sspSideInputHandlers.keySet()) {
@@ -758,41 +763,36 @@
       sideInputSystemConsumers.register(ssp, startingOffset);
       taskInstanceMetrics.get(this.sspSideInputHandlers.get(ssp).getTaskName()).addOffsetGauge(
           ssp, ScalaJavaUtil.toScalaFunction(() -> this.sspSideInputHandlers.get(ssp).getLastProcessedOffset(ssp)));
-
-      SystemStreamMetadata systemStreamMetadata = streamMetadataCache.getSystemStreamMetadata(ssp.getSystemStream(), false);
-      SystemStreamMetadata.SystemStreamPartitionMetadata sspMetadata =
-          (systemStreamMetadata == null) ? null : systemStreamMetadata.getSystemStreamPartitionMetadata().get(ssp.getPartition());
-
-      // record a copy of the sspMetadata, to later check if its caught up
-      initialSideInputSSPMetadata.put(ssp, sspMetadata);
-
-      // check if the ssp is caught to upcoming, even at start
-      checkSideInputCaughtUp(ssp, startingOffset, SystemStreamMetadata.OffsetType.UPCOMING, false);
+      sideInputTaskMetrics.get(this.sspSideInputHandlers.get(ssp).getTaskName()).addOffsetGauge(
+          ssp, ScalaJavaUtil.toScalaFunction(() -> this.sspSideInputHandlers.get(ssp).getLastProcessedOffset(ssp)));
     }
 
     // start the systemConsumers for consuming input
     this.sideInputSystemConsumers.start();
 
+    TaskConfig taskConfig = new TaskConfig(this.config);
+    SamzaContainerMetrics sideInputContainerMetrics =
+        new SamzaContainerMetrics(SIDEINPUTS_METRICS_PREFIX + this.samzaContainerMetrics.source(),
+            this.samzaContainerMetrics.registry(), SIDEINPUTS_METRICS_PREFIX);
+
+    this.sideInputRunLoop = new RunLoop(sideInputTasks,
+        null, // all operations are executed in the main runloop thread
+        this.sideInputSystemConsumers,
+        1, // single message in flight per task
+        -1, // no windowing
+        taskConfig.getCommitMs(),
+        taskConfig.getCallbackTimeoutMs(),
+        // TODO consolidate these container configs SAMZA-2275
+        this.config.getLong("container.disk.quota.delay.max.ms", TimeUnit.SECONDS.toMillis(1)),
+        taskConfig.getMaxIdleMs(),
+        sideInputContainerMetrics,
+        System::nanoTime,
+        false); // commit must be synchronous to ensure integrity of state flush
 
     try {
-
-    // submit the sideInput read runnable
-      sideInputsReadExecutor.submit(() -> {
+      sideInputsExecutor.submit(() -> {
         try {
-          while (!shouldShutdown) {
-            IncomingMessageEnvelope envelope = sideInputSystemConsumers.choose(true);
-
-            if (envelope != null) {
-              if (!envelope.isEndOfStream()) {
-                this.sspSideInputHandlers.get(envelope.getSystemStreamPartition()).process(envelope);
-              }
-
-              checkSideInputCaughtUp(envelope.getSystemStreamPartition(), envelope.getOffset(),
-                  SystemStreamMetadata.OffsetType.NEWEST, envelope.isEndOfStream());
-            } else {
-              LOG.trace("No incoming message was available");
-            }
-          }
+          sideInputRunLoop.run();
         } catch (Exception e) {
           LOG.error("Exception in reading sideInputs", e);
           sideInputException = e;
@@ -801,7 +801,7 @@
 
       // Make the main thread wait until all sideInputs have been caughtup or an exception was thrown
       while (!shouldShutdown && sideInputException == null &&
-          !this.sideInputsCaughtUp.await(SIDE_INPUT_READ_THREAD_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
+          !awaitSideInputTasks()) {
         LOG.debug("Waiting for SideInput bootstrap to complete");
       }
 
@@ -824,39 +824,22 @@
     LOG.info("SideInput Restore complete");
   }
 
-  // Method to check if the given offset means the stream is caught up for reads
-  private void checkSideInputCaughtUp(SystemStreamPartition ssp, String offset, SystemStreamMetadata.OffsetType offsetType, boolean isEndOfStream) {
-
-    if (isEndOfStream) {
-      this.initialSideInputSSPMetadata.remove(ssp);
-      this.sideInputsCaughtUp.countDown();
-      LOG.info("Side input ssp {} has caught up to offset {} ({}).", ssp, offset, offsetType);
-      return;
+  /**
+   * Waits for all side input tasks to catch up until a timeout.
+   *
+   * @return False if waiting on any latch timed out, true otherwise
+   *
+   * @throws InterruptedException if waiting on any of the latches is interrupted
+   */
+  private boolean awaitSideInputTasks() throws InterruptedException {
+    long endTime = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(SIDE_INPUT_CHECK_TIMEOUT_SECONDS);
+    for (CountDownLatch latch : this.sideInputTaskLatches.values()) {
+      long remainingMillisToWait = endTime - System.currentTimeMillis();
+      if (remainingMillisToWait <= 0 || !latch.await(remainingMillisToWait, TimeUnit.MILLISECONDS)) {
+        return false;
+      }
     }
-
-    SystemStreamMetadata.SystemStreamPartitionMetadata sspMetadata = this.initialSideInputSSPMetadata.get(ssp);
-    String offsetToCheck = sspMetadata == null ? null : sspMetadata.getOffset(offsetType);
-    LOG.trace("Checking {} offset {} against {} for {}.", offsetType, offset, offsetToCheck, ssp);
-
-    // Let's compare offset of the chosen message with offsetToCheck.
-    Integer comparatorResult;
-    if (offset == null || offsetToCheck == null) {
-      comparatorResult = -1;
-    } else {
-      SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(ssp.getSystem());
-      comparatorResult = systemAdmin.offsetComparator(offset, offsetToCheck);
-    }
-
-    // The SSP is no longer lagging if the envelope's offset is greater than or equal to the
-    // latest offset.
-    if (comparatorResult != null && comparatorResult.intValue() >= 0) {
-
-      LOG.info("Side input ssp {} has caught up to offset {} ({}).", ssp, offset, offsetType);
-      // if its caught up, we remove the ssp from the map, and countDown the latch
-      this.initialSideInputSSPMetadata.remove(ssp);
-      this.sideInputsCaughtUp.countDown();
-      return;
-    }
+    return true;
   }
 
   /**
@@ -901,19 +884,16 @@
 
     // stop all sideinput consumers and stores
     if (this.hasSideInputs) {
-      sideInputsReadExecutor.shutdownNow();
-
-      this.sideInputSystemConsumers.stop();
-
-      // cancel all future sideInput flushes, shutdown the executor, and await for finish
-      sideInputsFlushFuture.cancel(false);
-      sideInputsFlushExecutor.shutdown();
+      this.sideInputRunLoop.shutdown();
+      this.sideInputsExecutor.shutdown();
       try {
-        sideInputsFlushExecutor.awaitTermination(SIDE_INPUT_FLUSH_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS);
+        this.sideInputsExecutor.awaitTermination(SIDE_INPUT_SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS);
       } catch (InterruptedException e) {
         throw new SamzaException("Exception while shutting down sideInputs", e);
       }
 
+      this.sideInputSystemConsumers.stop();
+
       // stop all sideInputStores -- this will perform one last flush on the KV stores, and write the offset file
       this.getSideInputHandlers().forEach(TaskSideInputHandler::stop);
     }
diff --git a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
index da855a1..90d4c33 100644
--- a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
+++ b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
@@ -50,7 +50,7 @@
 public class TestRunLoop {
   // Immutable objects shared by all test methods.
   private final ExecutorService executor = null;
-  private final SamzaContainerMetrics containerMetrics = new SamzaContainerMetrics("container", new MetricsRegistryMap());
+  private final SamzaContainerMetrics containerMetrics = new SamzaContainerMetrics("container", new MetricsRegistryMap(), "");
   private final long windowMs = -1;
   private final long commitMs = -1;
   private final long callbackTimeoutMs = 0;
@@ -522,7 +522,7 @@
   private RunLoopTask getMockRunLoopTask(TaskName taskName, SystemStreamPartition ssp0) {
     RunLoopTask task0 = mock(RunLoopTask.class);
     when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
-    when(task0.metrics()).thenReturn(new TaskInstanceMetrics("test", new MetricsRegistryMap()));
+    when(task0.metrics()).thenReturn(new TaskInstanceMetrics("test", new MetricsRegistryMap(), ""));
     when(task0.taskName()).thenReturn(taskName);
     return task0;
   }
diff --git a/samza-core/src/test/java/org/apache/samza/storage/TestTaskSideInputHandler.java b/samza-core/src/test/java/org/apache/samza/storage/TestTaskSideInputHandler.java
index 656b2ef..2ef206e 100644
--- a/samza-core/src/test/java/org/apache/samza/storage/TestTaskSideInputHandler.java
+++ b/samza-core/src/test/java/org/apache/samza/storage/TestTaskSideInputHandler.java
@@ -23,6 +23,7 @@
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.CountDownLatch;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
@@ -191,6 +192,7 @@
           storeToProcessor,
           systemAdmins,
           streamMetadataCache,
+          new CountDownLatch(1),
           clock));
     }
   }
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
index ab8a29e..c7b3f47 100644
--- a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
+++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
@@ -38,7 +38,7 @@
 
   @Before
   public void setup() {
-    TaskInstanceMetrics metrics = new TaskInstanceMetrics("Partition 0", new MetricsRegistryMap());
+    TaskInstanceMetrics metrics = new TaskInstanceMetrics("Partition 0", new MetricsRegistryMap(), "");
     listener = new TaskCallbackListener() {
       @Override
       public void onComplete(TaskCallback callback) {
diff --git a/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala b/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
index a9c0c09..fc6cb53 100644
--- a/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
+++ b/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
@@ -24,7 +24,7 @@
 import org.apache.samza.util.Logging
 import org.apache.samza.storage.{StorageEngine, StoreProperties}
 import org.apache.samza.system.{ChangelogSSPIterator, OutgoingMessageEnvelope, SystemStreamPartition}
-import org.apache.samza.task.MessageCollector
+import org.apache.samza.task.{MessageCollector, TaskInstanceCollector}
 import org.apache.samza.util.TimerUtil
 import java.nio.file.Path
 import java.util.Optional