SAMZA-2749: Startpoint bug fix (#1615)

Symptom:
Using startpoints to trigger full bootstrapping is not reliable in the current implementation: we observed that bootstrapping happened on only some of the expected partitions.

Cause:
Within Samza (the main class to pay attention to is OffsetManager.scala), there is a bug in which a startpoint can be deleted before the startpoint actually gets used for message consumption. If a container gets into this situation, then the result is that the startpoint is ignored and consumption will continue from the previous processed message from before the startpoint was applied.

Sequence of events that triggers the bug:

1. Load last processed offsets and startpoints.
2. Use startpoints to register starting offsets for consumers.
3. Message processing starts, but messages for only some of the partitions are received.
4. Write checkpoint using last processed offsets.
   - If a partition did not get messages, then the last processed offset is still the offset from before the startpoint.
5. Delete startpoints.
6. Container dies (e.g. due to running out of memory).
7. On restart, load last processed offsets (startpoints have been deleted).
   - The partitions that did have messages in the previous deployment will have the correct checkpoint.
   - The partitions that did not have messages will have the checkpoint set to the offset from before the startpoint was applied. This is unexpected, and it means that bootstrapping is not happening for this partition.

Changes:

Keep track of the partitions which have updated processed offsets, and only delete the startpoint for those partitions after checkpointing.

API Changes:

Added a new API, removeFanOutForTaskSSPs, to StartpointManager to allow cleaning up the fan outs at partition granularity.
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java
index 083b483..2c04ea1 100644
--- a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java
+++ b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java
@@ -317,15 +317,58 @@
    * @return fanned out Startpoints
    */
   public Map<SystemStreamPartition, Startpoint> getFanOutForTask(TaskName taskName) throws IOException {
+    return getStartpointFanOutPerTask(taskName)
+        .map(startpointFanOutPerTask -> ImmutableMap.copyOf(startpointFanOutPerTask.getFanOuts())).orElse(ImmutableMap.of());
+  }
+
+  private Optional<StartpointFanOutPerTask> getStartpointFanOutPerTask(TaskName taskName) throws IOException {
     Preconditions.checkState(!stopped, "Underlying metadata store not available");
     Preconditions.checkNotNull(taskName, "TaskName cannot be null");
 
     byte[] fanOutBytes = fanOutStore.get(toFanOutStoreKey(taskName));
     if (ArrayUtils.isEmpty(fanOutBytes)) {
-      return ImmutableMap.of();
+      return Optional.empty();
     }
-    StartpointFanOutPerTask startpointFanOutPerTask = objectMapper.readValue(fanOutBytes, StartpointFanOutPerTask.class);
-    return ImmutableMap.copyOf(startpointFanOutPerTask.getFanOuts());
+    return Optional.of(objectMapper.readValue(fanOutBytes, StartpointFanOutPerTask.class));
+  }
+
+  /**
+   * Removes the fanned out startpoints for the specified system stream partitions of the given task. This method
+   * allows the fanned out startpoints of a task to be removed partially.
+   *
+   * The whole task fan out is removed from the store if the fan outs of all system stream partitions of the task
+   * are removed. No action is taken if no system stream partitions are specified.
+   *
+   * @param taskName to (partially) remove the fanned out startpoints for
+   * @param ssps to remove the fanned out startpoints for
+   */
+  public void removeFanOutForTaskSSPs(TaskName taskName, Set<SystemStreamPartition> ssps) {
+    Preconditions.checkState(!stopped, "Underlying metadata store not available");
+    Preconditions.checkNotNull(taskName, "TaskName cannot be null");
+    if (ssps == null || ssps.isEmpty()) {
+      return;
+    }
+    try {
+      getStartpointFanOutPerTask(taskName).ifPresent(fanOutPerTask -> {
+        Map<SystemStreamPartition, Startpoint> fanOuts = fanOutPerTask.getFanOuts();
+        fanOuts.entrySet().removeIf(e -> ssps.contains(e.getKey()));
+        if (fanOuts.isEmpty()) {
+          removeFanOutForTask(taskName);
+          LOG.debug("Deleted the fanned out startpoints for the task {}", taskName);
+        } else {
+          try {
+            fanOutStore.put(toFanOutStoreKey(taskName), objectMapper.writeValueAsBytes(fanOutPerTask));
+          } catch (IOException e) {
+            LOG.error("Can't update the fanned out startpoints for task {}", taskName, e);
+            throw new SamzaException(e);
+          }
+          fanOutStore.flush();
+        }
+      });
+    } catch (IOException e) {
+      LOG.error("Can't get the fanned out startpoints for task {}", taskName, e);
+      throw new SamzaException(e);
+    }
   }
 
   /**
diff --git a/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala b/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
index bcb59e0..311dc6a 100644
--- a/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
+++ b/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
@@ -171,6 +171,11 @@
   val lastProcessedOffsets = new ConcurrentHashMap[TaskName, ConcurrentHashMap[SystemStreamPartition, String]]()
 
   /**
+   * The SSPs of each task that have received new messages and had their processed offsets updated
+   */
+  val taskSSPsWithProcessedOffsetUpdated = new ConcurrentHashMap[TaskName, ConcurrentHashMap[SystemStreamPartition, Boolean]]()
+
+  /**
    * Offsets to start reading from for each SystemStreamPartition. This
    * variable is populated after all checkpoints have been restored.
    */
@@ -221,8 +226,15 @@
       .toIterator.next()
 
     lastProcessedOffsets.putIfAbsent(taskName, new ConcurrentHashMap[SystemStreamPartition, String]())
-    if (offset != null && !offset.equals(IncomingMessageEnvelope.END_OF_STREAM_OFFSET)) {
-      lastProcessedOffsets.get(taskName).put(sspWithKeyBucket, offset)
+    taskSSPsWithProcessedOffsetUpdated.putIfAbsent(taskName, new ConcurrentHashMap[SystemStreamPartition, Boolean]())
+
+    if (offset != null) {
+      if (!offset.equals(IncomingMessageEnvelope.END_OF_STREAM_OFFSET)) {
+        lastProcessedOffsets.get(taskName).put(sspWithKeyBucket, offset)
+      }
+      // Record the ssps that have received new messages. The startpoint for each ssp should only be deleted once the
+      // ssp has received new messages. More details in SAMZA-2749.
+      taskSSPsWithProcessedOffsetUpdated.get(taskName).putIfAbsent(sspWithKeyBucket, true)
     }
   }
 
@@ -394,10 +406,24 @@
     }
 
     // delete corresponding startpoints after checkpoint is supposed to be committed
-    if (startpointManager != null && startpoints.contains(taskName)) {
-      info("%d startpoint(s) for taskName: %s have been committed to the checkpoint." format (startpoints.get(taskName).size, taskName.getTaskName))
-      startpointManager.removeFanOutForTask(taskName)
-      startpoints -= taskName
+    if (startpointManager != null && startpoints.contains(taskName) && taskSSPsWithProcessedOffsetUpdated.containsKey(taskName)) {
+      val sspsWithProcessedOffsetUpdated = taskSSPsWithProcessedOffsetUpdated.get(taskName).keySet()
+      startpointManager.removeFanOutForTaskSSPs(taskName, sspsWithProcessedOffsetUpdated)
+      // Remove the startpoints for the ssps whose processed offsets have been updated. If all ssps of the
+      // task have had their processed offsets updated, remove the whole task's startpoints.
+      startpoints.get(taskName) match {
+        case Some(sspToStartpoint) => {
+          val newSspToStartpoint = sspToStartpoint.filterKeys(ssp => !sspsWithProcessedOffsetUpdated.contains(ssp)).toMap
+          if (newSspToStartpoint.isEmpty) {
+            startpoints -= taskName
+            info("All startpoints for the taskName: %s have been committed to the checkpoint." format(taskName))
+          } else {
+            startpoints += taskName -> newSspToStartpoint
+            debug("Updated the startpoints and the latest startpoints for the task %s: %s" format(taskName, newSspToStartpoint))
+          }
+        }
+        case None => {}
+      }
 
       if (startpoints.isEmpty) {
         info("All outstanding startpoints have been committed to the checkpoint.")
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java
index d5274a8..e9f3ad6 100644
--- a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java
+++ b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java
@@ -142,6 +142,11 @@
       startpointManager.removeFanOutForTask(new TaskName("t0"));
       Assert.fail("Expected precondition exception.");
     } catch (IllegalStateException ex) { }
+
+    try {
+      startpointManager.removeFanOutForTaskSSPs(new TaskName("t0"), ImmutableSet.of(ssp));
+      Assert.fail("Expected precondition exception.");
+    } catch (IllegalStateException ex) { }
   }
 
   @Test
@@ -350,6 +355,33 @@
   }
 
   @Test
+  public void testRemoveFanOutForTaskSSPs() throws Exception {
+    SystemStreamPartition ssp0 = new SystemStreamPartition("mockSystem", "mockStream", new Partition(0));
+    SystemStreamPartition ssp1 = new SystemStreamPartition("mockSystem", "mockStream", new Partition(1));
+    TaskName taskName = new TaskName("mockTask");
+    Map<TaskName, Set<SystemStreamPartition>> taskToSSPs = ImmutableMap.of(taskName, ImmutableSet.of(ssp0, ssp1));
+    StartpointSpecific startpoint42 = new StartpointSpecific("42");
+    startpointManager.writeStartpoint(ssp0, startpoint42);
+    startpointManager.writeStartpoint(ssp1, startpoint42);
+    Map<TaskName, Map<SystemStreamPartition, Startpoint>> tasksFannedOutTo = startpointManager.fanOut(taskToSSPs);
+    Assert.assertEquals(ImmutableSet.of(taskName), tasksFannedOutTo.keySet());
+    Assert.assertFalse("Should be deleted after fan out", startpointManager.readStartpoint(ssp0).isPresent());
+    Assert.assertFalse("Should be deleted after fan out", startpointManager.readStartpoint(ssp1).isPresent());
+
+    // no action is taken if no system stream partition is specified
+    startpointManager.removeFanOutForTaskSSPs(taskName, ImmutableSet.of());
+    Assert.assertEquals(ImmutableMap.of(ssp0, startpoint42, ssp1, startpoint42), startpointManager.getFanOutForTask(taskName));
+
+    // partial removal: remove the fanned out startpoint for the specified system stream partition only
+    startpointManager.removeFanOutForTaskSSPs(taskName, ImmutableSet.of(ssp0));
+    Assert.assertEquals(ImmutableMap.of(ssp1, startpoint42), startpointManager.getFanOutForTask(taskName));
+
+    // remove the whole task's fan out once the startpoints of all of the task's partitions are removed
+    startpointManager.removeFanOutForTaskSSPs(taskName, ImmutableSet.of(ssp1));
+    Assert.assertEquals(ImmutableMap.of(), startpointManager.getFanOutForTask(taskName));
+  }
+
+  @Test
   public void testDeleteAllStartpoints() throws IOException {
     SystemStreamPartition sspBroadcast = new SystemStreamPartition("mockSystem1", "mockStream1", new Partition(2));
     SystemStreamPartition sspSingle = new SystemStreamPartition("mockSystem2", "mockStream2", new Partition(3));
diff --git a/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala b/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
index 3c226c4..d0039a9 100644
--- a/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
+++ b/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
@@ -233,17 +233,18 @@
     // Should get offset 45 back from the checkpoint manager, which is last processed, and system admin should return 46 as starting offset.
     assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition))
     checkpoint(offsetManager, taskName)
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should not delete if the partition's processed offset is not updated
+    assertEquals("45", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
+
+    offsetManager.update(taskName, systemStreamPartition, "46")
+    offsetManager.update(taskName, systemStreamPartition, "47")
+    checkpoint(offsetManager, taskName)
     intercept[IllegalStateException] {
       // StartpointManager should stop after last fan out is removed
       startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
     }
     startpointManagerUtil.getStartpointManager.start
-    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should delete after checkpoint commit
-    assertEquals("45", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
-    offsetManager.update(taskName, systemStreamPartition, "46")
-
-    offsetManager.update(taskName, systemStreamPartition, "47")
-    checkpoint(offsetManager, taskName)
+    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should be deleted by the checkpoint commit once the partition's processed offset is updated
     assertEquals("47", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
 
     offsetManager.update(taskName, systemStreamPartition, "48")
@@ -426,12 +427,7 @@
     offsetsToCheckpoint.put(unregisteredSystemStreamPartition, "50")
     offsetManager.writeCheckpoint(taskName, new CheckpointV1(offsetsToCheckpoint))
 
-    intercept[IllegalStateException] {
-      // StartpointManager should stop after last fan out is removed
-      startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
-    }
-    startpointManagerUtil.getStartpointManager.start
-    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint be deleted at first checkpoint
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should not delete if the partition's processed offset is not updated
 
     assertEquals("45", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     assertEquals("100", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
@@ -445,6 +441,13 @@
     offsetManager.update(taskName, systemStreamPartition, "46")
     offsetManager.update(taskName, systemStreamPartition, "47")
     checkpoint(offsetManager, taskName)
+    intercept[IllegalStateException] {
+      // StartpointManager should stop after last fan out is removed
+      startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+    }
+    startpointManagerUtil.getStartpointManager.start
+    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should delete if the partition's processed offset is updated
+
     assertEquals("47", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     assertEquals("100", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
     assertEquals("47", consumer.recentCheckpoint.get(systemStreamPartition))
@@ -488,12 +491,8 @@
     // Should get offset 45 back from the checkpoint manager, which is last processed, and system admin should return 46 as starting offset.
     assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition))
     checkpoint(offsetManager, taskName)
-    intercept[IllegalStateException] {
-      // StartpointManager should stop after last fan out is removed
-      startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
-    }
-    startpointManagerUtil.getStartpointManager.start
-    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint be deleted at first checkpoint
+
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should not delete if the partition's processed offset is not updated
     assertEquals("45", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
 
     offsetManager.update(taskName, systemStreamPartition, "46")
@@ -501,6 +500,12 @@
     val checkpoint46 = offsetManager.getLastProcessedOffsets(taskName)
     offsetManager.update(taskName, systemStreamPartition, "47") // Offset updated before checkpoint
     offsetManager.writeCheckpoint(taskName, new CheckpointV1(checkpoint46))
+    intercept[IllegalStateException] {
+      // StartpointManager should stop after last fan out is removed
+      startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+    }
+    startpointManagerUtil.getStartpointManager.start
+    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should delete if the partition's processed offset is updated
     assertEquals(Some("47"), offsetManager.getLastProcessedOffset(taskName, systemStreamPartition))
     assertEquals("46", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
 
@@ -620,6 +625,85 @@
     assertEquals("47", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartitionWithKeyBucket).getValue)
   }
 
+  @Test
+  def testStartpointUpdate: Unit = {
+    val taskName = new TaskName("c")
+    val systemStream = new SystemStream("test-system", "test-stream")
+    val p0 = new Partition(0)
+    val p1 = new Partition(1)
+    val p2 = new Partition(2)
+    val ssp0 = new SystemStreamPartition(systemStream, p0)
+    val ssp1 = new SystemStreamPartition(systemStream, p1)
+    val ssp2 = new SystemStreamPartition(systemStream, p2)
+//    val unregisteredSystemStreamPartition = new SystemStreamPartition(systemStream3, partition)
+    val testStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(
+      p0 -> new SystemStreamPartitionMetadata("0", "1", "2"),
+      p1 -> new SystemStreamPartitionMetadata("0", "1", "2"),
+      p2 -> new SystemStreamPartitionMetadata("0", "1", "2")).asJava)
+    val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
+    val config = new MapConfig
+    val checkpointManager = getCheckpointManager1(new CheckpointV1(Map(ssp0 -> "45", ssp1 -> "100", ssp1 -> "200").asJava),
+      taskName)
+    val startpointManagerUtil = getStartpointManagerUtil()
+    val systemAdmins = mock(classOf[SystemAdmins])
+    when(systemAdmins.getSystemAdmin(systemStream.getSystem)).thenReturn(getSystemAdmin)
+    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
+    offsetManager.register(taskName, Set(ssp0, ssp1, ssp2))
+    val startpoint0 = new StartpointUpcoming
+    val startpoint1 = new StartpointOldest
+    val startpoint2 = new StartpointOldest
+    startpointManagerUtil.getStartpointManager.writeStartpoint(ssp0, taskName, startpoint0)
+    startpointManagerUtil.getStartpointManager.writeStartpoint(ssp1, taskName, startpoint1)
+    startpointManagerUtil.getStartpointManager.writeStartpoint(ssp2, taskName, startpoint2)
+    assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName, ssp0, ssp1, ssp2)).keySet().contains(taskName))
+    offsetManager.start
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp0))
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp1))
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp2))
+
+    checkpoint(offsetManager, taskName)
+    // Startpoint should not delete if the partition's processed offset is not updated
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp0))
+    assertEquals(Option(startpoint0), offsetManager.getStartpoint(taskName, ssp0))
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp1))
+    assertEquals(Option(startpoint1), offsetManager.getStartpoint(taskName, ssp1))
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp2))
+    assertEquals(Option(startpoint2), offsetManager.getStartpoint(taskName, ssp2))
+
+    offsetManager.update(taskName, ssp0, "46")
+    checkpoint(offsetManager, taskName)
+    // Startpoint should delete if the partition's processed offset is updated
+    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp0))
+    assertTrue(offsetManager.getStartpoint(taskName, ssp0).isEmpty)
+    // Startpoint should not delete if the partition's processed offset is not updated
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp1))
+    assertEquals(Option(startpoint1), offsetManager.getStartpoint(taskName, ssp1))
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp2))
+    assertEquals(Option(startpoint2), offsetManager.getStartpoint(taskName, ssp2))
+
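+    // update ssp1 with offset "100", intended to equal its last checkpointed offset. NOTE(review): the checkpoint fixture above lists ssp1 twice ("100", then "200") - confirm whether the second entry should be ssp2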
+    offsetManager.update(taskName, ssp1, "100")
+    checkpoint(offsetManager, taskName)
+    // Startpoint should delete if the partition's processed offset is updated
+    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp1))
+    assertTrue(offsetManager.getStartpoint(taskName, ssp1).isEmpty)
+    // Startpoint should not delete if the partition's processed offset is not updated
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(ssp2))
+    assertEquals(Option(startpoint2), offsetManager.getStartpoint(taskName, ssp2))
+
+    offsetManager.update(taskName, ssp2, "201")
+    checkpoint(offsetManager, taskName)
+    intercept[IllegalStateException] {
+      // StartpointManager should stop after last fan out is removed
+      startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+    }
+    startpointManagerUtil.getStartpointManager.start
+    // Startpoint should delete if the partition's processed offset is updated
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).isEmpty)
+    assertTrue(offsetManager.getStartpoint(taskName, ssp1).isEmpty)
+    startpointManagerUtil.stop
+  }
+
   // Utility method to create and write checkpoint in one statement
   def checkpoint(offsetManager: OffsetManager, taskName: TaskName): Unit = {
     offsetManager.writeCheckpoint(taskName, new CheckpointV1(offsetManager.getLastProcessedOffsets(taskName)))