SAMZA-2493: Keep checkpoint manager consumer open for repeated polling (#1327)
diff --git a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
index b02f6c9..468d9c9 100644
--- a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
@@ -105,6 +105,8 @@
private static final String BROADCAST_STREAM_PATTERN = "^[\\d]+$";
private static final String BROADCAST_STREAM_RANGE_PATTERN = "^\\[[\\d]+\\-[\\d]+\\]$";
public static final String CHECKPOINT_MANAGER_FACTORY = "task.checkpoint.factory";
+ // standby containers use this flag to indicate that checkpoints will be polled continually, rather than only once at startup like in an active container
+ public static final String INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ = "samza.internal.task.checkpoint.consumer.stop.after.first.read";
public static final String TRANSACTIONAL_STATE_CHECKPOINT_ENABLED = "task.transactional.state.checkpoint.enabled";
private static final boolean DEFAULT_TRANSACTIONAL_STATE_CHECKPOINT_ENABLED = true;
@@ -214,6 +216,14 @@
}
/**
+ * Internal config to indicate whether the SystemConsumer underlying a CheckpointManager should be stopped after
+ * initial read of checkpoints.
+ */
+ public boolean getCheckpointManagerConsumerStopAfterFirstRead() {
+ return getBoolean(INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, true);
+ }
+
+ /**
* Get the systemStreamPartitions of the broadcast stream. Specifying
* one partition for one stream or a range of the partitions for one
* stream is allowed.
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index c78e841..6fab351 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -50,6 +50,7 @@
import org.apache.samza.util.ScalaJavaUtil.JavaOptionals
import org.apache.samza.util.{Util, _}
import org.apache.samza.SamzaException
+import org.apache.samza.clustermanager.StandbyTaskUtil
import scala.collection.JavaConverters._
@@ -132,7 +133,14 @@
localityManager: LocalityManager = null,
startpointManager: StartpointManager = null,
diagnosticsManager: Option[DiagnosticsManager] = Option.empty) = {
- val config = jobContext.getConfig
+ val config = if (StandbyTaskUtil.isStandbyContainer(containerId)) {
+ // standby containers will need to continually poll checkpoint messages
+ val newConfig = new util.HashMap[String, String](jobContext.getConfig)
+ newConfig.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, java.lang.Boolean.FALSE.toString)
+ new MapConfig(newConfig)
+ } else {
+ jobContext.getConfig
+ }
val jobConfig = new JobConfig(config)
val systemConfig = new SystemConfig(config)
val containerModel = jobModel.getContainers.get(containerId)
diff --git a/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManager.scala b/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManager.scala
index 87c84aa..1c3531f 100644
--- a/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManager.scala
+++ b/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManager.scala
@@ -26,7 +26,7 @@
import com.google.common.annotations.VisibleForTesting
import com.google.common.base.Preconditions
import org.apache.samza.checkpoint.{Checkpoint, CheckpointManager}
-import org.apache.samza.config.{Config, JobConfig}
+import org.apache.samza.config.{Config, JobConfig, TaskConfig}
import org.apache.samza.container.TaskName
import org.apache.samza.serializers.Serde
import org.apache.samza.metrics.MetricsRegistry
@@ -76,6 +76,11 @@
val producerRef: AtomicReference[SystemProducer] = new AtomicReference[SystemProducer](getSystemProducer())
val producerCreationLock: Object = new Object
+ // if true, systemConsumer can be safely closed after the first call to readLastCheckpoint.
+ // if false, it must be left open until KafkaCheckpointManager::stop is called.
+ // for active containers, this will be set to true, while false for standby containers.
+ val stopConsumerAfterFirstRead: Boolean = new TaskConfig(config).getCheckpointManagerConsumerStopAfterFirstRead
+
/**
* Create checkpoint stream prior to start.
*/
@@ -107,7 +112,6 @@
info(s"Starting the checkpoint SystemConsumer from oldest offset $oldestOffset")
systemConsumer.register(checkpointSsp, oldestOffset)
systemConsumer.start()
- // the consumer will be closed after first time reading the checkpoint
}
/**
@@ -132,9 +136,12 @@
if (taskNamesToCheckpoints == null) {
info("Reading checkpoints for the first time")
taskNamesToCheckpoints = readCheckpoints()
- // Stop the system consumer since we only need to read checkpoints once
- info("Stopping system consumer.")
- systemConsumer.stop()
+ if (stopConsumerAfterFirstRead) {
+ info("Stopping system consumer")
+ systemConsumer.stop()
+ }
+ } else if (!stopConsumerAfterFirstRead) {
+ taskNamesToCheckpoints ++= readCheckpoints()
}
val checkpoint: Checkpoint = taskNamesToCheckpoints.getOrElse(taskName, null)
@@ -220,6 +227,11 @@
info ("Stopping system producer.")
producerRef.get().stop()
+ if (!stopConsumerAfterFirstRead) {
+ info("Stopping system consumer")
+ systemConsumer.stop()
+ }
+
info("CheckpointManager stopped.")
}
diff --git a/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala b/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
index 9766ce8..2e7a7e4 100644
--- a/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
+++ b/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
@@ -38,6 +38,7 @@
import org.junit.Assert._
import org.junit._
import org.mockito.Mockito
+import org.mockito.Matchers
class TestKafkaCheckpointManager extends KafkaServerTestHarness {
@@ -129,18 +130,13 @@
def testWriteCheckpointShouldRetryFiniteTimesOnFailure(): Unit = {
val checkpointTopic = "checkpoint-topic-2"
val mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
-
- class MockSystemFactory extends KafkaSystemFactory {
- override def getProducer(systemName: String, config: Config, registry: MetricsRegistry): SystemProducer = {
- mockKafkaProducer
- }
- }
+ val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
Mockito.doThrow(new RuntimeException()).when(mockKafkaProducer).flush(taskName.getTaskName)
val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
- val checkPointManager = new KafkaCheckpointManager(spec, new MockSystemFactory, false, config, new NoOpMetricsRegistry)
+ val checkPointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer, mockKafkaProducer), false, config, new NoOpMetricsRegistry)
checkPointManager.MaxRetryDurationInMillis = 1
try {
@@ -186,6 +182,55 @@
kcm.stop()
}
+ @Test
+ def testConsumerStopsAfterInitialReadIfConfigSetTrue(): Unit = {
+ val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
+
+ val checkpointTopic = "checkpoint-topic-test"
+ val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
+ val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
+
+ val configMapWithOverride = new java.util.HashMap[String, String](config)
+ configMapWithOverride.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, "true")
+ val kafkaCheckpointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer), false, new MapConfig(configMapWithOverride), new NoOpMetricsRegistry)
+
+ kafkaCheckpointManager.register(taskName)
+ kafkaCheckpointManager.start()
+ kafkaCheckpointManager.readLastCheckpoint(taskName)
+
+ Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).register(Matchers.any(), Matchers.any())
+ Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).start()
+ Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).poll(Matchers.any(), Matchers.any())
+ Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).stop()
+
+ kafkaCheckpointManager.stop()
+
+ Mockito.verifyNoMoreInteractions(mockKafkaSystemConsumer)
+ }
+
+ @Test
+ def testConsumerDoesNotStopAfterInitialReadIfConfigSetFalse(): Unit = {
+ val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
+
+ val checkpointTopic = "checkpoint-topic-test"
+ val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
+ val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
+
+ val configMapWithOverride = new java.util.HashMap[String, String](config)
+ configMapWithOverride.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, "false")
+ val kafkaCheckpointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer), false, new MapConfig(configMapWithOverride), new NoOpMetricsRegistry)
+
+ kafkaCheckpointManager.register(taskName)
+ kafkaCheckpointManager.start()
+ kafkaCheckpointManager.readLastCheckpoint(taskName)
+
+ Mockito.verify(mockKafkaSystemConsumer, Mockito.times(0)).stop()
+
+ kafkaCheckpointManager.stop()
+
+ Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).stop()
+ }
+
@After
override def tearDown(): Unit = {
if (servers != null) {
@@ -243,4 +288,16 @@
adminZkClient.createTopic(cpTopic, partNum, 1, props)
}
+ class MockSystemFactory(
+ mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer]),
+ mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])) extends KafkaSystemFactory {
+ override def getProducer(systemName: String, config: Config, registry: MetricsRegistry): SystemProducer = {
+ mockKafkaProducer
+ }
+
+ override def getConsumer(systemName: String, config: Config, registry: MetricsRegistry): SystemConsumer = {
+ mockKafkaSystemConsumer
+ }
+ }
+
}