Merge pull request #1599 from lakshmi-manasa-g/elasticity-checkpoint-readWrite

SAMZA-2734: [Elasticity] Update last processed offset after an envelope is finished processing when elasticity is enabled
diff --git a/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala b/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
index 7a12625..1bd5561 100644
--- a/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
+++ b/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
@@ -216,9 +216,19 @@
    * Set the last processed offset for a given SystemStreamPartition.
    */
   def update(taskName: TaskName, systemStreamPartition: SystemStreamPartition, offset: String) {
+    // Callers pass the SSP without a keyBucket. Without elasticity, a task has exactly one registered entry for
+    // that SSP; with elasticity, it consumes exactly one keyBucket of it, so there is still exactly one entry.
+    // Look up that registered entry (matching on system stream and partition) and record the offset under it.
+    // Fail with a descriptive SamzaException rather than an opaque NoSuchElementException when nothing matches.
+    val sspWithKeyBucket = systemStreamPartitions.getOrElse(taskName,
+      throw new SamzaException("No SSPs registered for task: " + taskName))
+      .find(ssp => ssp.getSystemStream.equals(systemStreamPartition.getSystemStream)
+        && ssp.getPartition.equals(systemStreamPartition.getPartition))
+      .getOrElse(throw new SamzaException("SSP " + systemStreamPartition + " is not registered for task: " + taskName))
+
     lastProcessedOffsets.putIfAbsent(taskName, new ConcurrentHashMap[SystemStreamPartition, String]())
     if (offset != null && !offset.equals(IncomingMessageEnvelope.END_OF_STREAM_OFFSET)) {
-      lastProcessedOffsets.get(taskName).put(systemStreamPartition, offset)
+      lastProcessedOffsets.get(taskName).put(sspWithKeyBucket, offset)
     }
   }
 
diff --git a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
index 9367cd7..4a24303 100644
--- a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
+++ b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
@@ -188,6 +188,8 @@
       // but the actual systemConsumer which consumes from the input does not know about KeyBucket.
       // hence, use an SSP without KeyBucket
       consumer.register(removeKeyBucket(systemStreamPartition), offset)
+      chooser.register(removeKeyBucket(systemStreamPartition), offset)
+      debug("consumer.register and chooser.register for ssp: %s with offset %s" format (systemStreamPartition, offset))
     }
 
     debug("Starting consumers.")
@@ -244,8 +246,6 @@
     metrics.registerSystemStreamPartition(systemStreamPartition)
     unprocessedMessagesBySSP.put(systemStreamPartition, new ArrayDeque[IncomingMessageEnvelope]())
 
-    chooser.register(systemStreamPartition, offset)
-
     try {
       val consumer = consumers(systemStreamPartition.getSystem)
       val existingOffset = sspToRegisteredOffsets.get(systemStreamPartition)
diff --git a/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala b/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
index 3949ecf..3c226c4 100644
--- a/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
+++ b/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
@@ -579,6 +579,47 @@
     assertEquals("60", modifiedOffsets.get(ssp6))
   }
 
+  @Test
+  def testElasticityUpdateWithoutKeyBucket: Unit = {
+    // When elasticity is enabled, task consumes a part of the full SSP represented by SSP With KeyBucket.
+    // OffsetManager tracks the set of SSP with KeyBucket consumed by a task.
+    // However, after an IME processing is complete, OffsetManager.update is called without KeyBucket.
+    // OffsetManager has to find and update the last processed offset for the task correctly for its SSP with KeyBucket.
+    val taskName = new TaskName("c")
+    val systemStream = new SystemStream("test-system", "test-stream")
+    val partition = new Partition(0)
+    val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
+    val systemStreamPartitionWithKeyBucket = new SystemStreamPartition(systemStreamPartition, 0);
+    val testStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")).asJava)
+    val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
+    val checkpointManager = getCheckpointManager(systemStreamPartition, taskName)
+    val startpointManagerUtil = getStartpointManagerUtil()
+    val systemAdmins = mock(classOf[SystemAdmins])
+    when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
+    val offsetManager = OffsetManager(systemStreamMetadata, new MapConfig, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
+    // register task and its input SSP with KeyBucket
+    offsetManager.register(taskName, Set(systemStreamPartitionWithKeyBucket))
+
+    offsetManager.start
+
+    // update is called with only the full SSP and no keyBucket information.
+    offsetManager.update(taskName, systemStreamPartition, "46")
+    // Get checkpoint snapshot like we do at the beginning of TaskInstance.commit()
+    val checkpoint46 = offsetManager.getLastProcessedOffsets(taskName)
+    offsetManager.update(taskName, systemStreamPartition, "47") // Offset updated before checkpoint
+    offsetManager.writeCheckpoint(taskName, new CheckpointV1(checkpoint46))
+    // OffsetManager correctly updates the lastProcessedOffset and checkpoint for the task and input SSP with KeyBucket.
+    assertEquals(Some("47"), offsetManager.getLastProcessedOffset(taskName, systemStreamPartitionWithKeyBucket))
+    assertEquals("46", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartitionWithKeyBucket).getValue)
+
+    // Now write the checkpoint for the latest offset
+    val checkpoint47 = offsetManager.getLastProcessedOffsets(taskName)
+    offsetManager.writeCheckpoint(taskName, new CheckpointV1(checkpoint47))
+
+    assertEquals(Some("47"), offsetManager.getLastProcessedOffset(taskName, systemStreamPartitionWithKeyBucket))
+    assertEquals("47", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartitionWithKeyBucket).getValue)
+  }
+
   // Utility method to create and write checkpoint in one statement
   def checkpoint(offsetManager: OffsetManager, taskName: TaskName): Unit = {
     offsetManager.writeCheckpoint(taskName, new CheckpointV1(offsetManager.getLastProcessedOffsets(taskName)))