SAMZA-2493: Keep checkpoint manager consumer open for repeated polling (#1327)

diff --git a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
index b02f6c9..468d9c9 100644
--- a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
@@ -105,6 +105,8 @@
   private static final String BROADCAST_STREAM_PATTERN = "^[\\d]+$";
   private static final String BROADCAST_STREAM_RANGE_PATTERN = "^\\[[\\d]+\\-[\\d]+\\]$";
   public static final String CHECKPOINT_MANAGER_FACTORY = "task.checkpoint.factory";
+  // standby containers use this flag to indicate that checkpoints will be polled continually, rather than only once at startup like in an active container
+  public static final String INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ = "samza.internal.task.checkpoint.consumer.stop.after.first.read";
 
   public static final String TRANSACTIONAL_STATE_CHECKPOINT_ENABLED = "task.transactional.state.checkpoint.enabled";
   private static final boolean DEFAULT_TRANSACTIONAL_STATE_CHECKPOINT_ENABLED = true;
@@ -214,6 +216,14 @@
   }
 
   /**
+   * Internal config to indicate whether the SystemConsumer underlying a CheckpointManager should be stopped after
+   * initial read of checkpoints.
+   */
+  public boolean getCheckpointManagerConsumerStopAfterFirstRead() {
+    return getBoolean(INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, true);
+  }
+
+  /**
    * Get the systemStreamPartitions of the broadcast stream. Specifying
    * one partition for one stream or a range of the partitions for one
    * stream is allowed.
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index c78e841..6fab351 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -50,6 +50,7 @@
 import org.apache.samza.util.ScalaJavaUtil.JavaOptionals
 import org.apache.samza.util.{Util, _}
 import org.apache.samza.SamzaException
+import org.apache.samza.clustermanager.StandbyTaskUtil
 
 import scala.collection.JavaConverters._
 
@@ -132,7 +133,14 @@
     localityManager: LocalityManager = null,
     startpointManager: StartpointManager = null,
     diagnosticsManager: Option[DiagnosticsManager] = Option.empty) = {
-    val config = jobContext.getConfig
+    val config = if (StandbyTaskUtil.isStandbyContainer(containerId)) {
+      // standby containers will need to continually poll checkpoint messages
+      val newConfig = new util.HashMap[String, String](jobContext.getConfig)
+      newConfig.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, java.lang.Boolean.FALSE.toString)
+      new MapConfig(newConfig)
+    } else {
+      jobContext.getConfig
+    }
     val jobConfig = new JobConfig(config)
     val systemConfig = new SystemConfig(config)
     val containerModel = jobModel.getContainers.get(containerId)
diff --git a/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManager.scala b/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManager.scala
index 87c84aa..1c3531f 100644
--- a/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManager.scala
+++ b/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManager.scala
@@ -26,7 +26,7 @@
 import com.google.common.annotations.VisibleForTesting
 import com.google.common.base.Preconditions
 import org.apache.samza.checkpoint.{Checkpoint, CheckpointManager}
-import org.apache.samza.config.{Config, JobConfig}
+import org.apache.samza.config.{Config, JobConfig, TaskConfig}
 import org.apache.samza.container.TaskName
 import org.apache.samza.serializers.Serde
 import org.apache.samza.metrics.MetricsRegistry
@@ -76,6 +76,11 @@
   val producerRef: AtomicReference[SystemProducer] = new AtomicReference[SystemProducer](getSystemProducer())
   val producerCreationLock: Object = new Object
 
+  // if true, systemConsumer can be safely closed after the first call to readLastCheckpoint.
+  // if false, it must be left open until KafkaCheckpointManager::stop is called.
+  // for active containers, this will be set to true, while false for standby containers.
+  val stopConsumerAfterFirstRead: Boolean = new TaskConfig(config).getCheckpointManagerConsumerStopAfterFirstRead
+
   /**
     * Create checkpoint stream prior to start.
     */
@@ -107,7 +112,6 @@
     info(s"Starting the checkpoint SystemConsumer from oldest offset $oldestOffset")
     systemConsumer.register(checkpointSsp, oldestOffset)
     systemConsumer.start()
-    // the consumer will be closed after first time reading the checkpoint
   }
 
   /**
@@ -132,9 +136,12 @@
     if (taskNamesToCheckpoints == null) {
       info("Reading checkpoints for the first time")
       taskNamesToCheckpoints = readCheckpoints()
-      // Stop the system consumer since we only need to read checkpoints once
-      info("Stopping system consumer.")
-      systemConsumer.stop()
+      if (stopConsumerAfterFirstRead) {
+        info("Stopping system consumer")
+        systemConsumer.stop()
+      }
+    } else if (!stopConsumerAfterFirstRead) {
+      taskNamesToCheckpoints ++= readCheckpoints()
     }
 
     val checkpoint: Checkpoint = taskNamesToCheckpoints.getOrElse(taskName, null)
@@ -220,6 +227,11 @@
     info ("Stopping system producer.")
     producerRef.get().stop()
 
+    if (!stopConsumerAfterFirstRead) {
+      info("Stopping system consumer")
+      systemConsumer.stop()
+    }
+
     info("CheckpointManager stopped.")
   }
 
diff --git a/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala b/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
index 9766ce8..2e7a7e4 100644
--- a/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
+++ b/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
@@ -38,6 +38,7 @@
 import org.junit.Assert._
 import org.junit._
 import org.mockito.Mockito
+import org.mockito.Matchers
 
 class TestKafkaCheckpointManager extends KafkaServerTestHarness {
 
@@ -129,18 +130,13 @@
   def testWriteCheckpointShouldRetryFiniteTimesOnFailure(): Unit = {
     val checkpointTopic = "checkpoint-topic-2"
     val mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
-
-    class MockSystemFactory extends KafkaSystemFactory {
-      override def getProducer(systemName: String, config: Config, registry: MetricsRegistry): SystemProducer = {
-        mockKafkaProducer
-      }
-    }
+    val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
 
     Mockito.doThrow(new RuntimeException()).when(mockKafkaProducer).flush(taskName.getTaskName)
 
     val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
     val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
-    val checkPointManager = new KafkaCheckpointManager(spec, new MockSystemFactory, false, config, new NoOpMetricsRegistry)
+    val checkPointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer, mockKafkaProducer), false, config, new NoOpMetricsRegistry)
     checkPointManager.MaxRetryDurationInMillis = 1
 
     try {
@@ -186,6 +182,55 @@
     kcm.stop()
   }
 
+  @Test
+  def testConsumerStopsAfterInitialReadIfConfigSetTrue(): Unit = {
+    val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
+
+    val checkpointTopic = "checkpoint-topic-test"
+    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
+    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
+
+    val configMapWithOverride = new java.util.HashMap[String, String](config)
+    configMapWithOverride.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, "true")
+    val kafkaCheckpointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer), false, new MapConfig(configMapWithOverride), new NoOpMetricsRegistry)
+
+    kafkaCheckpointManager.register(taskName)
+    kafkaCheckpointManager.start()
+    kafkaCheckpointManager.readLastCheckpoint(taskName)
+
+    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).register(Matchers.any(), Matchers.any())
+    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).start()
+    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).poll(Matchers.any(), Matchers.any())
+    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).stop()
+
+    kafkaCheckpointManager.stop()
+
+    Mockito.verifyNoMoreInteractions(mockKafkaSystemConsumer)
+  }
+
+  @Test
+  def testConsumerDoesNotStopAfterInitialReadIfConfigSetFalse(): Unit = {
+    val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
+
+    val checkpointTopic = "checkpoint-topic-test"
+    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
+    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
+
+    val configMapWithOverride = new java.util.HashMap[String, String](config)
+    configMapWithOverride.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, "false")
+    val kafkaCheckpointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer), false, new MapConfig(configMapWithOverride), new NoOpMetricsRegistry)
+
+    kafkaCheckpointManager.register(taskName)
+    kafkaCheckpointManager.start()
+    kafkaCheckpointManager.readLastCheckpoint(taskName)
+
+    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(0)).stop()
+
+    kafkaCheckpointManager.stop()
+
+    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).stop()
+  }
+
   @After
   override def tearDown(): Unit = {
     if (servers != null) {
@@ -243,4 +288,16 @@
     adminZkClient.createTopic(cpTopic, partNum, 1, props)
   }
 
+  class MockSystemFactory(
+    mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer]),
+    mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])) extends KafkaSystemFactory {
+    override def getProducer(systemName: String, config: Config, registry: MetricsRegistry): SystemProducer = {
+      mockKafkaProducer
+    }
+
+    override def getConsumer(systemName: String, config: Config, registry: MetricsRegistry): SystemConsumer = {
+      mockKafkaSystemConsumer
+    }
+  }
+
 }