Drain partial offset commit batches on upstream failure (#1058)

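Before this change, `Committer.batchFlow` was built on `groupedWeightedWithin` (see the removed lines in `Committer.scala` below), so offsets buffered in a partially filled batch were discarded when the upstream failed. The new `CommitCollectorStage` drains such a partial batch with an emergency commit before failing the stage. A minimal usage sketch of the intended behaviour follows; it is not part of the diff, the broker address, topic, and group id are placeholders, and a local Kafka broker plus the alpakka-kafka 2.0.x API are assumed.

```scala
// Sketch only: illustrates the scenario this PR addresses (placeholder names).
import akka.actor.ActorSystem
import akka.kafka.scaladsl.{Committer, Consumer}
import akka.kafka.{CommitterSettings, ConsumerSettings, Subscriptions}
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.{Keep, Sink}
import org.apache.kafka.common.serialization.StringDeserializer

object PartialBatchDrainSketch extends App {
  implicit val system: ActorSystem = ActorSystem("drain-sketch")
  implicit val materializer: ActorMaterializer = ActorMaterializer()

  val consumerSettings =
    ConsumerSettings(system, new StringDeserializer, new StringDeserializer)
      .withBootstrapServers("localhost:9092") // placeholder
      .withGroupId("drain-sketch-group") // placeholder

  // A large batch makes it likely that a failure hits a partially filled batch.
  val committerSettings = CommitterSettings(system).withMaxBatch(100)

  val (control, streamCompletion) = Consumer
    .committableSource(consumerSettings, Subscriptions.topics("drain-sketch-topic")) // placeholder topic
    .map { msg =>
      // Simulated business logic that fails on a particular record value.
      if (msg.record.value() == "poison") throw new RuntimeException("BOOM")
      msg.committableOffset
    }
    .via(Committer.batchFlow(committerSettings))
    // With this change, the offsets collected before the failure are sent to the
    // KafkaConsumerActor as an emergency commit before the stage fails, so they
    // are not re-processed after a restart.
    .toMat(Sink.ignore)(Keep.both)
    .run()
}
```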
diff --git a/core/src/main/mima-filters/2.0.2.backwards.excludes b/core/src/main/mima-filters/2.0.2.backwards.excludes
new file mode 100644
index 0000000..54c1f6a
--- /dev/null
+++ b/core/src/main/mima-filters/2.0.2.backwards.excludes
@@ -0,0 +1,3 @@
+# Internal API
+ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.kafka.ConsumerMessage#CommittableOffsetBatch.isEmpty")
+ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.kafka.ConsumerMessage#CommittableOffsetBatch.tellCommitEmergency")
\ No newline at end of file
diff --git a/core/src/main/scala/akka/kafka/ConsumerMessage.scala b/core/src/main/scala/akka/kafka/ConsumerMessage.scala
index 457b2de..590efbd 100644
--- a/core/src/main/scala/akka/kafka/ConsumerMessage.scala
+++ b/core/src/main/scala/akka/kafka/ConsumerMessage.scala
@@ -215,6 +215,14 @@
      */
     @InternalApi
     private[kafka] def tellCommit(): CommittableOffsetBatch
+
+    @InternalApi
+    private[kafka] def tellCommitEmergency(): CommittableOffsetBatch
+
+    /**
+     * @return true if the batch contains no commits.
+     */
+    def isEmpty: Boolean
   }
 
 }
diff --git a/core/src/main/scala/akka/kafka/internal/CommitCollectorStage.scala b/core/src/main/scala/akka/kafka/internal/CommitCollectorStage.scala
new file mode 100644
index 0000000..3833ab8
--- /dev/null
+++ b/core/src/main/scala/akka/kafka/internal/CommitCollectorStage.scala
@@ -0,0 +1,132 @@
+/*
+ * Copyright (C) 2014 - 2016 Softwaremill <https://softwaremill.com>
+ * Copyright (C) 2016 - 2020 Lightbend Inc. <https://www.lightbend.com>
+ */
+
+package akka.kafka.internal
+
+import akka.annotation.InternalApi
+import akka.kafka.CommitterSettings
+import akka.kafka.ConsumerMessage.{Committable, CommittableOffsetBatch}
+import akka.stream._
+import akka.stream.stage._
+
+/**
+ * INTERNAL API.
+ *
+ * Combined stage for committing incoming offsets in batches. Capable of emitting a dynamically sized (reduced)
+ * batch in case of upstream failure. Supports flushing on failure (for downstreams).
+ */
+@InternalApi
+private[kafka] final class CommitCollectorStage(val committerSettings: CommitterSettings)
+    extends GraphStage[FlowShape[Committable, CommittableOffsetBatch]] {
+
+  val in: Inlet[Committable] = Inlet[Committable]("FlowIn")
+  val out: Outlet[CommittableOffsetBatch] = Outlet[CommittableOffsetBatch]("FlowOut")
+  val shape: FlowShape[Committable, CommittableOffsetBatch] = FlowShape(in, out)
+
+  override def createLogic(
+      inheritedAttributes: Attributes
+  ): GraphStageLogic = {
+    new CommitCollectorStageLogic(this, inheritedAttributes)
+  }
+}
+
+private final class CommitCollectorStageLogic(
+    stage: CommitCollectorStage,
+    inheritedAttributes: Attributes
+) extends TimerGraphStageLogic(stage.shape)
+    with StageIdLogging {
+
+  import CommitCollectorStage._
+  import CommitTrigger._
+
+  override protected def logSource: Class[_] = classOf[CommitCollectorStageLogic]
+
+  // ---- initialization
+  override def preStart(): Unit = {
+    super.preStart()
+    scheduleCommit()
+    log.debug("CommitCollectorStage initialized")
+  }
+
+  /** Batches offsets until a commit is triggered. */
+  private var offsetBatch: CommittableOffsetBatch = CommittableOffsetBatch.empty
+
+  // ---- Consuming
+  private def consume(offset: Committable): Unit = {
+    log.debug("Consuming offset {}", offset)
+    offsetBatch = offsetBatch.updated(offset)
+    if (offsetBatch.batchSize >= stage.committerSettings.maxBatch) pushDownStream(BatchSize)(push)
+    else tryPull(stage.in) // accumulating the batch
+  }
+
+  private def scheduleCommit(): Unit =
+    scheduleOnce(CommitNow, stage.committerSettings.maxInterval)
+
+  override protected def onTimer(timerKey: Any): Unit = timerKey match {
+    case CommitCollectorStage.CommitNow => pushDownStream(Interval)(push)
+  }
+
+  private def pushDownStream(triggeredBy: TriggerdBy)(
+      emission: (Outlet[CommittableOffsetBatch], CommittableOffsetBatch) => Unit
+  ): Unit = {
+    if (activeBatchInProgress) {
+      log.debug("pushDownStream triggered by {}, outstanding batch {}", triggeredBy, offsetBatch)
+      emission(stage.out, offsetBatch)
+      offsetBatch = CommittableOffsetBatch.empty
+    }
+    scheduleCommit()
+  }
+
+  setHandler(
+    stage.in,
+    new InHandler {
+      def onPush(): Unit = {
+        consume(grab(stage.in))
+      }
+
+      override def onUpstreamFinish(): Unit = {
+        if (noActiveBatchInProgress) {
+          completeStage()
+        } else {
+          pushDownStream(UpstreamFinish)(emit[CommittableOffsetBatch])
+          completeStage()
+        }
+      }
+
+      override def onUpstreamFailure(ex: Throwable): Unit = {
+        if (noActiveBatchInProgress) {
+          log.debug("onUpstreamFailure with exception {} with empty offset batch", ex)
+          failStage(ex)
+        } else {
+          log.debug("onUpstreamFailure with exception {} with {}", ex, offsetBatch)
+          offsetBatch.tellCommitEmergency()
+          offsetBatch = CommittableOffsetBatch.empty
+          failStage(ex)
+        }
+      }
+    }
+  )
+
+  setHandler(
+    stage.out,
+    new OutHandler {
+      def onPull(): Unit = if (!hasBeenPulled(stage.in)) {
+        tryPull(stage.in)
+      }
+    }
+  )
+
+  override def postStop(): Unit = {
+    log.debug("CommitCollectorStage stopped")
+    super.postStop()
+  }
+
+  private def noActiveBatchInProgress: Boolean = offsetBatch.isEmpty
+  private def activeBatchInProgress: Boolean = !offsetBatch.isEmpty
+}
+
+private[akka] object CommitCollectorStage {
+  val CommitNow = "flowStageCommit"
+}
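For orientation, the stage above is wired into `Committer.batchFlow` (see the `Committer.scala` change further down in this diff). A sketch of that wiring, with the emission triggers summarised in comments; everything stated here is taken from the diff, only the wrapper object name is illustrative.

```scala
// Sketch: the collector stage wrapped in a Flow, as Committer.batchFlow does.
// Batches are emitted downstream when
//  - the batch reaches committerSettings.maxBatch elements,
//  - committerSettings.maxInterval elapses, or
//  - the upstream completes with a partial batch still buffered;
// on upstream *failure* the partial batch is drained via tellCommitEmergency()
// instead of being emitted.
import akka.NotUsed
import akka.kafka.CommitterSettings
import akka.kafka.ConsumerMessage.{Committable, CommittableOffsetBatch}
import akka.kafka.internal.CommitCollectorStage
import akka.stream.scaladsl.Flow

object CollectorFlowSketch {
  def collectorFlow(settings: CommitterSettings): Flow[Committable, CommittableOffsetBatch, NotUsed] =
    Flow.fromGraph(new CommitCollectorStage(settings))
}
```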
diff --git a/core/src/main/scala/akka/kafka/internal/CommitTrigger.scala b/core/src/main/scala/akka/kafka/internal/CommitTrigger.scala
new file mode 100644
index 0000000..9489247
--- /dev/null
+++ b/core/src/main/scala/akka/kafka/internal/CommitTrigger.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (C) 2014 - 2016 Softwaremill <https://softwaremill.com>
+ * Copyright (C) 2016 - 2020 Lightbend Inc. <https://www.lightbend.com>
+ */
+
+package akka.kafka.internal
+
+private[akka] object CommitTrigger {
+  sealed trait TriggerdBy
+  case object BatchSize extends TriggerdBy {
+    override def toString: String = "batch size"
+  }
+  case object Interval extends TriggerdBy {
+    override def toString: String = "interval"
+  }
+  case object UpstreamClosed extends TriggerdBy {
+    override def toString: String = "upstream closed"
+  }
+  case object UpstreamFinish extends TriggerdBy {
+    override def toString: String = "upstream finish"
+  }
+  case object UpstreamFailure extends TriggerdBy {
+    override def toString: String = "upstream failure"
+  }
+}
diff --git a/core/src/main/scala/akka/kafka/internal/CommittableSources.scala b/core/src/main/scala/akka/kafka/internal/CommittableSources.scala
index 2060587..72cc1b7 100644
--- a/core/src/main/scala/akka/kafka/internal/CommittableSources.scala
+++ b/core/src/main/scala/akka/kafka/internal/CommittableSources.scala
@@ -166,13 +166,13 @@
     case _ => failForUnexpectedImplementation(batch)
   }
 
-  def tellCommit(batch: CommittableOffsetBatch): Unit = batch match {
+  def tellCommit(batch: CommittableOffsetBatch, emergency: Boolean): Unit = batch match {
     case b: CommittableOffsetBatchImpl =>
       b.offsetsAndMetadata.foreach {
         case (groupTopicPartition, offset) =>
           // sends one message per partition, they are aggregated in the KafkaConsumerActor
           b.committerFor(groupTopicPartition)
-            .tellCommit(CommitWithoutReply(groupTopicPartition.topicPartition, offset))
+            .tellCommit(CommitWithoutReply(groupTopicPartition.topicPartition, offset, emergency))
       }
 
     case _ => failForUnexpectedImplementation(batch)
diff --git a/core/src/main/scala/akka/kafka/internal/CommittingProducerSinkStage.scala b/core/src/main/scala/akka/kafka/internal/CommittingProducerSinkStage.scala
index 5ee4713..1b2552a 100644
--- a/core/src/main/scala/akka/kafka/internal/CommittingProducerSinkStage.scala
+++ b/core/src/main/scala/akka/kafka/internal/CommittingProducerSinkStage.scala
@@ -51,6 +51,7 @@
     with DeferredProducer[K, V] {
 
   import CommittingProducerSinkStage._
+  import CommitTrigger._
 
   /** The promise behind the materialized future. */
   final val streamCompletion = Promise[Done]
@@ -201,6 +202,18 @@
       checkForCompletion()
   }
 
+  private def emergencyShutdown(ex: Throwable): Unit = {
+    log.debug("Emergency shutdown triggered by {} (awaitingProduceResult={} awaitingCommitResult={})",
+              ex,
+              awaitingProduceResult,
+              awaitingCommitResult)
+
+    offsetBatch.tellCommitEmergency()
+    upstreamCompletionState = Some(Failure(ex))
+    offsetBatch = CommittableOffsetBatch.empty
+    closeAndFailStage(ex)
+  }
+
   // ---- handler and completion
   /** Keeps track of upstream completion signals until this stage shuts down. */
   private var upstreamCompletionState: Option[Try[Done]] = None
@@ -227,9 +240,7 @@
         if (awaitingCommitResult == 0) {
           closeAndFailStage(ex)
         } else {
-          commit(UpstreamFailure)
-          setKeepGoing(true)
-          upstreamCompletionState = Some(Failure(ex))
+          emergencyShutdown(ex)
         }
     }
   )
@@ -261,21 +272,4 @@
 
 private object CommittingProducerSinkStage {
   val CommitNow = "commit"
-
-  sealed trait TriggerdBy
-  case object BatchSize extends TriggerdBy {
-    override def toString: String = "batch size"
-  }
-  case object Interval extends TriggerdBy {
-    override def toString: String = "interval"
-  }
-  case object UpstreamClosed extends TriggerdBy {
-    override def toString: String = "upstream closed"
-  }
-  case object UpstreamFinish extends TriggerdBy {
-    override def toString: String = "upstream finish"
-  }
-  case object UpstreamFailure extends TriggerdBy {
-    override def toString: String = "upstream failure"
-  }
 }
diff --git a/core/src/main/scala/akka/kafka/internal/KafkaConsumerActor.scala b/core/src/main/scala/akka/kafka/internal/KafkaConsumerActor.scala
index 57a257d..399aaad 100644
--- a/core/src/main/scala/akka/kafka/internal/KafkaConsumerActor.scala
+++ b/core/src/main/scala/akka/kafka/internal/KafkaConsumerActor.scala
@@ -65,7 +65,7 @@
     final case class StopFromStage(stageId: String) extends StopLike
     final case class Commit(tp: TopicPartition, offsetAndMetadata: OffsetAndMetadata)
         extends NoSerializationVerificationNeeded
-    final case class CommitWithoutReply(tp: TopicPartition, offsetAndMetadata: OffsetAndMetadata)
+    final case class CommitWithoutReply(tp: TopicPartition, offsetAndMetadata: OffsetAndMetadata, emergency: Boolean)
         extends NoSerializationVerificationNeeded
 
     /** Special case commit for non-batched committing. */
@@ -259,9 +259,12 @@
       commitMaps = tp -> offset :: commitMaps
       commitSenders = commitSenders :+ sender()
 
-    case CommitWithoutReply(tp, offset) =>
+    case CommitWithoutReply(tp, offset, emergency) =>
       // prepending, as later received offsets most likely are higher
       commitMaps = tp -> offset :: commitMaps
+      if (emergency) {
+        emergencyPoll()
+      }
 
     case CommitSingle(tp, offset) =>
       commitMaps = tp -> offset :: commitMaps
@@ -486,14 +489,14 @@
       self ! delayedPollMsg
     }
 
+  private def emergencyPoll(): Unit = {
+    log.debug("Performing emergency poll")
+    commitAndPoll()
+  }
+
   private def receivePoll(p: Poll[_, _]): Unit =
     if (p.target == this) {
-      val refreshOffsets = commitRefreshing.refreshOffsets
-      if (refreshOffsets.nonEmpty) {
-        log.debug("Refreshing committed offsets: {}", refreshOffsets)
-        commit(refreshOffsets, _ => ())
-      }
-      poll()
+      commitAndPoll()
       if (p.periodic)
         schedulePollTask()
       else
@@ -503,6 +506,15 @@
       log.debug("Ignoring Poll message with stale target ref")
     }
 
+  private def commitAndPoll(): Unit = {
+    val refreshOffsets = commitRefreshing.refreshOffsets
+    if (refreshOffsets.nonEmpty) {
+      log.debug("Refreshing committed offsets: {}", refreshOffsets)
+      commit(refreshOffsets, _ => ())
+    }
+    poll()
+  }
+
   def poll(): Unit = {
     try {
       val currentAssignmentsJava = consumer.assignment()
diff --git a/core/src/main/scala/akka/kafka/internal/MessageBuilder.scala b/core/src/main/scala/akka/kafka/internal/MessageBuilder.scala
index f5b9b2e..b3bde9e 100644
--- a/core/src/main/scala/akka/kafka/internal/MessageBuilder.scala
+++ b/core/src/main/scala/akka/kafka/internal/MessageBuilder.scala
@@ -229,19 +229,27 @@
   override def commitScaladsl(): Future[Done] = commitInternal()
 
   override def commitInternal(): Future[Done] =
-    if (batchSize == 0L)
+    if (isEmpty)
       Future.successful(Done)
     else {
       committers.head._2.commit(this)
     }
 
-  override def tellCommit(): CommittableOffsetBatch = {
+  override def tellCommit(): CommittableOffsetBatch = tellCommitWithPriority(emergency = false)
+
+  override def tellCommitEmergency(): CommittableOffsetBatch = tellCommitWithPriority(emergency = true)
+
+  private def tellCommitWithPriority(emergency: Boolean): CommittableOffsetBatch = {
     if (batchSize != 0L) {
-      committers.head._2.tellCommit(this)
+      committers.head._2.tellCommit(this, emergency = emergency)
     }
     this
   }
 
   override def commitJavadsl(): CompletionStage[Done] = commitInternal().toJava
 
+  /**
+   * @return true if the batch contains no commits.
+   */
+  def isEmpty: Boolean = batchSize == 0
 }
diff --git a/core/src/main/scala/akka/kafka/scaladsl/Committer.scala b/core/src/main/scala/akka/kafka/scaladsl/Committer.scala
index 630f8d8..d422ccc 100644
--- a/core/src/main/scala/akka/kafka/scaladsl/Committer.scala
+++ b/core/src/main/scala/akka/kafka/scaladsl/Committer.scala
@@ -5,12 +5,13 @@
 
 package akka.kafka.scaladsl
 
-import akka.dispatch.ExecutionContexts
 import akka.annotation.ApiMayChange
-import akka.{Done, NotUsed}
+import akka.dispatch.ExecutionContexts
 import akka.kafka.CommitterSettings
 import akka.kafka.ConsumerMessage.{Committable, CommittableOffsetBatch}
+import akka.kafka.internal.CommitCollectorStage
 import akka.stream.scaladsl.{Flow, FlowWithContext, Keep, Sink}
+import akka.{Done, NotUsed}
 
 import scala.concurrent.Future
 
@@ -26,20 +27,20 @@
    * Batches offsets and commits them to Kafka, emits `CommittableOffsetBatch` for every committed batch.
    */
   def batchFlow(settings: CommitterSettings): Flow[Committable, CommittableOffsetBatch, NotUsed] = {
-    val offsetBatches = Flow[Committable]
-      .groupedWeightedWithin(settings.maxBatch, settings.maxInterval)(_.batchSize)
-      .map(CommittableOffsetBatch.apply)
+    val offsetBatches: Flow[Committable, CommittableOffsetBatch, NotUsed] =
+      Flow
+        .fromGraph(new CommitCollectorStage(settings))
+
     // See https://github.com/akka/alpakka-kafka/issues/882
     import akka.kafka.CommitDelivery._
     settings.delivery match {
       case WaitForAck =>
         offsetBatches
-          .mapAsyncUnordered(settings.parallelism) { b =>
-            b.commitInternal().map(_ => b)(ExecutionContexts.sameThreadExecutionContext)
+          .mapAsyncUnordered(settings.parallelism) { batch =>
+            batch.commitInternal().map(_ => batch)(ExecutionContexts.sameThreadExecutionContext)
           }
       case SendAndForget =>
-        offsetBatches
-          .map(_.tellCommit())
+        offsetBatches.map(_.tellCommit())
     }
   }
 
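For reference, this is how the two delivery modes handled in the rewritten `batchFlow` are selected from user code; a sketch assuming `CommitterSettings.withDelivery` from the 2.0.x API, with illustrative value names.

```scala
// Sketch: selecting the delivery mode handled by batchFlow above.
import akka.actor.ActorSystem
import akka.kafka.{CommitDelivery, CommitterSettings}
import akka.kafka.scaladsl.Committer

object DeliveryModeSketch {
  val system: ActorSystem = ActorSystem("delivery-sketch")

  // WaitForAck (default): each emitted batch is committed with commitInternal()
  // and the flow waits for the broker acknowledgement, up to `parallelism` in flight.
  val waitForAck = Committer.batchFlow(
    CommitterSettings(system).withDelivery(CommitDelivery.WaitForAck).withParallelism(4)
  )

  // SendAndForget: each emitted batch is handed to the KafkaConsumerActor with
  // tellCommit() and the flow does not wait for the commit result.
  val sendAndForget = Committer.batchFlow(
    CommitterSettings(system).withDelivery(CommitDelivery.SendAndForget)
  )
}
```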
diff --git a/docs/src/main/paradox/consumer.md b/docs/src/main/paradox/consumer.md
index 713cfe4..88739f1 100644
--- a/docs/src/main/paradox/consumer.md
+++ b/docs/src/main/paradox/consumer.md
@@ -147,7 +147,7 @@
 Java
 : @@ snip [snip](/tests/src/test/java/docs/javadsl/ConsumerExampleTest.java) { #atLeastOnce }
 
-Committing the offset for each message (`withMaxBatch(1)`) as illustrated above is rather slow. It is recommended to batch the commits for better throughput, with the trade-off that more messages may be re-delivered in case of failures.
+Committing the offset for each message (`withMaxBatch(1)`) as illustrated above is rather slow. It is recommended to batch the commits for better throughput. If the upstream fails, the `Committer` will try to commit the offsets collected before the error.
 
 
 ### Committer sink
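The batching trade-off described in the updated paragraph is tuned through `CommitterSettings`; a short sketch with arbitrary values (all three setters also appear in this PR's tests).

```scala
// Sketch: the batching knobs referenced in the documentation change above.
import akka.actor.ActorSystem
import akka.kafka.CommitterSettings

import scala.concurrent.duration._

object CommitterTuningSketch {
  val system: ActorSystem = ActorSystem("tuning-sketch")

  val committerSettings =
    CommitterSettings(system)
      .withMaxBatch(1000)          // commit after at most 1000 accumulated offsets ...
      .withMaxInterval(10.seconds) // ... or after 10 seconds, whichever comes first
      .withParallelism(4)          // commits awaiting broker acks in WaitForAck mode
}
```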
diff --git a/testkit/src/main/scala/akka/kafka/testkit/ConsumerResultFactory.scala b/testkit/src/main/scala/akka/kafka/testkit/ConsumerResultFactory.scala
index f673921..f2c2e12 100644
--- a/testkit/src/main/scala/akka/kafka/testkit/ConsumerResultFactory.scala
+++ b/testkit/src/main/scala/akka/kafka/testkit/ConsumerResultFactory.scala
@@ -22,7 +22,8 @@
 
   val fakeCommitter: KafkaAsyncConsumerCommitterRef = new KafkaAsyncConsumerCommitterRef(null, null)(ec = null) {
     override def commitSingle(offset: CommittableOffsetImpl): Future[Done] = Future.successful(Done)
-    override def commit(batch: ConsumerMessage.CommittableOffsetBatch): Future[Done] = Future.successful(Done)
+    override def commit(batch: ConsumerMessage.CommittableOffsetBatch): Future[Done] =
+      Future.successful(Done)
   }
 
   def partitionOffset(groupId: String, topic: String, partition: Int, offset: Long): ConsumerMessage.PartitionOffset =
diff --git a/tests/src/test/scala/akka/kafka/internal/CommitCollectorStageSpec.scala b/tests/src/test/scala/akka/kafka/internal/CommitCollectorStageSpec.scala
new file mode 100644
index 0000000..200f0c2
--- /dev/null
+++ b/tests/src/test/scala/akka/kafka/internal/CommitCollectorStageSpec.scala
@@ -0,0 +1,291 @@
+/*
+ * Copyright (C) 2014 - 2016 Softwaremill <https://softwaremill.com>
+ * Copyright (C) 2016 - 2020 Lightbend Inc. <https://www.lightbend.com>
+ */
+
+package akka.kafka.internal
+
+import java.util.concurrent.atomic.AtomicLong
+
+import akka.Done
+import akka.actor.ActorSystem
+import akka.event.LoggingAdapter
+import akka.kafka.CommitterSettings
+import akka.kafka.ConsumerMessage.{CommittableOffset, CommittableOffsetBatch, PartitionOffset}
+import akka.kafka.scaladsl.{Committer, Consumer}
+import akka.kafka.testkit.ConsumerResultFactory
+import akka.kafka.testkit.scaladsl.{ConsumerControlFactory, Slf4jToAkkaLoggingAdapter}
+import akka.kafka.tests.scaladsl.LogCapturing
+import akka.stream.ActorMaterializer
+import akka.stream.scaladsl.Keep
+import akka.stream.testkit.scaladsl.StreamTestKit.assertAllStagesStopped
+import akka.stream.testkit.scaladsl.{TestSink, TestSource}
+import akka.stream.testkit.{TestPublisher, TestSubscriber}
+import akka.testkit.TestKit
+import org.scalatest.concurrent.{Eventually, IntegrationPatience, ScalaFutures}
+import org.scalatest.{AppendedClues, BeforeAndAfterAll, Matchers, WordSpecLike}
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.concurrent.duration.{FiniteDuration, _}
+import scala.concurrent.{ExecutionContext, Future, Promise}
+
+class CommitCollectorStageSpec(_system: ActorSystem)
+    extends TestKit(_system)
+    with WordSpecLike
+    with Matchers
+    with BeforeAndAfterAll
+    with Eventually
+    with IntegrationPatience
+    with AppendedClues
+    with ScalaFutures
+    with LogCapturing {
+
+  implicit lazy val materializer: ActorMaterializer = ActorMaterializer()
+  implicit lazy val executionContext: ExecutionContext = system.dispatcher
+
+  val DefaultCommitterSettings: CommitterSettings = CommitterSettings(system)
+  val msgAbsenceDuration: FiniteDuration = 2.seconds
+
+  val log: Logger = LoggerFactory.getLogger(getClass)
+  // used by the .log(...) stream operator
+  implicit val adapter: LoggingAdapter = new Slf4jToAkkaLoggingAdapter(log)
+
+  def this() = this(ActorSystem())
+
+  override def afterAll(): Unit = shutdown(system)
+
+  "The CommitCollectorStage" when {
+    "the batch is full" should {
+      val settings = DefaultCommitterSettings.withMaxBatch(2).withMaxInterval(10.hours)
+      "batch commit without errors" in assertAllStagesStopped {
+        val (sourceProbe, control, sinkProbe) = streamProbes(settings)
+        val committer = new TestBatchCommitter(settings)
+        val offsetFactory = TestOffsetFactory(committer)
+        val (msg1, msg2) = (offsetFactory.makeOffset(), offsetFactory.makeOffset())
+
+        sinkProbe.request(100)
+
+        // first message should not be committed but 'batched-up'
+        sourceProbe.sendNext(msg1)
+        sinkProbe.expectNoMessage(msgAbsenceDuration)
+        committer.commits shouldBe empty
+
+        // now send the message that fills up the batch
+        sourceProbe.sendNext(msg2)
+
+        val committedBatch = sinkProbe.expectNext()
+
+        committedBatch.batchSize shouldBe 2
+        committedBatch.offsets.values should have size 1
+        committedBatch.offsets.values.last shouldBe msg2.partitionOffset.offset
+        committer.commits.size shouldBe 1 withClue "expected only one batch commit"
+
+        control.shutdown().futureValue shouldBe Done
+      }
+    }
+
+    "batch duration has elapsed" should {
+      val settings = DefaultCommitterSettings.withMaxBatch(Integer.MAX_VALUE).withMaxInterval(1.milli)
+      "batch commit without errors" in assertAllStagesStopped {
+        val (sourceProbe, control, sinkProbe, factory) = streamProbesWithOffsetFactory(settings)
+
+        sinkProbe.request(100)
+
+        val msg = factory.makeOffset()
+
+        sourceProbe.sendNext(msg)
+        val committedBatch = sinkProbe.expectNext()
+
+        committedBatch.batchSize shouldBe 1
+        committedBatch.offsets.values should have size 1
+        committedBatch.offsets.values.last shouldBe msg.partitionOffset.offset
+        factory.committer.commits.size shouldBe 1 withClue "expected only one batch commit"
+
+        control.shutdown().futureValue shouldBe Done
+      }
+    }
+
+    "all offsets are in batch that is in flight" should {
+      val settings =
+        DefaultCommitterSettings.withMaxBatch(Integer.MAX_VALUE).withMaxInterval(10.hours).withParallelism(1)
+
+      "batch commit all buffered elements if upstream has suddenly completed" in assertAllStagesStopped {
+        val (sourceProbe, control, sinkProbe, factory) = streamProbesWithOffsetFactory(settings)
+
+        sinkProbe.ensureSubscription()
+        sinkProbe.request(100)
+
+        val msg = factory.makeOffset()
+        sourceProbe.sendNext(msg)
+        sourceProbe.sendComplete()
+
+        val committedBatch = sinkProbe.expectNext()
+
+        committedBatch.batchSize shouldBe 1
+        committedBatch.offsets.values should have size 1
+        committedBatch.offsets.values.last shouldBe msg.partitionOffset.offset
+        factory.committer.commits.size shouldBe 1 withClue "expected only one batch commit"
+
+        control.shutdown().futureValue shouldBe Done
+      }
+
+      "batch commit all buffered elements if upstream has suddenly completed with delayed commits" in assertAllStagesStopped {
+        val (sourceProbe, control, sinkProbe) = streamProbes(settings)
+        val committer = new TestBatchCommitter(settings, () => 50.millis)
+
+        val factory = TestOffsetFactory(committer)
+        sinkProbe.request(100)
+
+        val (msg1, msg2) = (factory.makeOffset(), factory.makeOffset())
+        sourceProbe.sendNext(msg1)
+        sourceProbe.sendNext(msg2)
+        sourceProbe.sendComplete()
+
+        val committedBatch = sinkProbe.expectNext()
+
+        committedBatch.batchSize shouldBe 2
+        committedBatch.offsets.values should have size 1
+        committedBatch.offsets.values.last shouldBe msg2.partitionOffset.offset
+        committer.commits.size shouldBe 1 withClue "expected only one batch commit"
+
+        control.shutdown().futureValue shouldBe Done
+      }
+
+      "batch commit all buffered elements if upstream has suddenly failed" in assertAllStagesStopped {
+        val settings = // special config to have more than one batch before failure
+          DefaultCommitterSettings.withMaxBatch(3).withMaxInterval(10.hours).withParallelism(100)
+
+        val (sourceProbe, control, sinkProbe, factory) = streamProbesWithOffsetFactory(settings)
+
+        sinkProbe.request(100)
+
+        val msgs = (1 to 10).map(_ => factory.makeOffset())
+
+        msgs.foreach(sourceProbe.sendNext)
+
+        val testError = new IllegalStateException("BOOM")
+        sourceProbe.sendError(testError)
+
+        val receivedError = pullTillFailure(sinkProbe, maxEvents = 4)
+
+        receivedError shouldBe testError
+
+        val commits = factory.committer.commits
+
+        commits.last.offset shouldBe 10 withClue "last offset commit should be exactly the one preceding the error"
+
+        control.shutdown().futureValue shouldBe Done
+      }
+    }
+  }
+
+  @scala.annotation.tailrec
+  private def pullTillFailure(
+      sinkProbe: TestSubscriber.Probe[CommittableOffsetBatch],
+      maxEvents: Int
+  ): Throwable = {
+    val nextOrError = sinkProbe.expectNextOrError()
+    if (maxEvents < 0) {
+      fail("Max number events has been read, no error encountered.")
+    }
+    nextOrError match {
+      case Left(ex) =>
+        log.debug("Received failure")
+        ex
+      case Right(batch) =>
+        log.debug("Received batch {}", batch)
+        pullTillFailure(sinkProbe, maxEvents - 1)
+    }
+  }
+
+  private def streamProbes(
+      committerSettings: CommitterSettings
+  ): (TestPublisher.Probe[CommittableOffset], Consumer.Control, TestSubscriber.Probe[CommittableOffsetBatch]) = {
+
+    val flow = Committer.batchFlow(committerSettings)
+
+    val ((source, control), sink) = TestSource
+      .probe[CommittableOffset]
+      .viaMat(ConsumerControlFactory.controlFlow())(Keep.both)
+      .via(flow)
+      .toMat(TestSink.probe)(Keep.both)
+      .run()
+
+    (source, control, sink)
+  }
+
+  private def streamProbesWithOffsetFactory(
+      committerSettings: CommitterSettings
+  ): (TestPublisher.Probe[CommittableOffset],
+      Consumer.Control,
+      TestSubscriber.Probe[CommittableOffsetBatch],
+      TestOffsetFactory) = {
+    val (source, control, sink) = streamProbes(committerSettings)
+    val factory = TestOffsetFactory(new TestBatchCommitter(committerSettings))
+    (source, control, sink, factory)
+  }
+
+  object TestCommittableOffset {
+
+    def apply(offsetCounter: AtomicLong,
+              committer: TestBatchCommitter,
+              failWith: Option[Throwable] = None): CommittableOffset = {
+      CommittableOffsetImpl(
+        ConsumerResultFactory
+          .partitionOffset(groupId = "group1",
+                           topic = "topic1",
+                           partition = 1,
+                           offset = offsetCounter.incrementAndGet()),
+        "metadata1"
+      )(committer.underlying)
+    }
+  }
+
+  class TestOffsetFactory(val committer: TestBatchCommitter) {
+    private val offsetCounter = new AtomicLong(0L)
+
+    def makeOffset(failWith: Option[Throwable] = None): CommittableOffset = {
+      TestCommittableOffset(offsetCounter, committer, failWith)
+    }
+  }
+
+  object TestOffsetFactory {
+
+    def apply(committer: TestBatchCommitter): TestOffsetFactory =
+      new TestOffsetFactory(committer)
+  }
+
+  class TestBatchCommitter(
+      commitSettings: CommitterSettings,
+      commitDelay: () => FiniteDuration = () => Duration.Zero
+  )(
+      implicit system: ActorSystem
+  ) {
+
+    var commits = List.empty[PartitionOffset]
+
+    private def completeCommit(): Future[Done] = {
+      val promisedCommit = Promise[Done]
+      system.scheduler.scheduleOnce(commitDelay()) {
+        promisedCommit.success(Done)
+      }
+      promisedCommit.future
+    }
+
+    private[akka] val underlying =
+      new KafkaAsyncConsumerCommitterRef(consumerActor = null, commitSettings.maxInterval) {
+        override def commitSingle(offset: CommittableOffsetImpl): Future[Done] = {
+          commits = commits :+ offset.partitionOffset
+          completeCommit()
+        }
+
+        override def commit(batch: CommittableOffsetBatch): Future[Done] = {
+          val offsets = batch.offsets.toList.map { case (partition, offset) => PartitionOffset(partition, offset) }
+          commits = commits ++ offsets
+          completeCommit()
+        }
+
+        override def tellCommit(batch: CommittableOffsetBatch, emergency: Boolean): Unit = commit(batch)
+      }
+  }
+}
diff --git a/tests/src/test/scala/akka/kafka/internal/CommittingProducerSinkSpec.scala b/tests/src/test/scala/akka/kafka/internal/CommittingProducerSinkSpec.scala
index 0d8ff3e..698436b 100644
--- a/tests/src/test/scala/akka/kafka/internal/CommittingProducerSinkSpec.scala
+++ b/tests/src/test/scala/akka/kafka/internal/CommittingProducerSinkSpec.scala
@@ -366,7 +366,7 @@
       .mapMaterializedValue(DrainingControl.apply)
       .run()
 
-    val commitMsg = consumer.actor.expectMsgClass(1.second, classOf[Internal.Commit])
+    val commitMsg = consumer.actor.expectMsgClass(1.second, classOf[Internal.CommitWithoutReply])
     commitMsg.tp shouldBe new TopicPartition(topic, partition)
     commitMsg.offsetAndMetadata.offset() shouldBe (consumer.startOffset + 2)
     consumer.actor.reply(Done)
diff --git a/tests/src/test/scala/akka/kafka/internal/CommittingWithMockSpec.scala b/tests/src/test/scala/akka/kafka/internal/CommittingWithMockSpec.scala
index d6e1411..005605a 100644
--- a/tests/src/test/scala/akka/kafka/internal/CommittingWithMockSpec.scala
+++ b/tests/src/test/scala/akka/kafka/internal/CommittingWithMockSpec.scala
@@ -213,7 +213,8 @@
     mock.enqueue(msgs.map(toRecord))
 
     probe.request(count.toLong)
-    val allCommits = Future.sequence(probe.expectNextN(count.toLong).map(_.committableOffset.commitInternal()))
+    val allCommits =
+      Future.sequence(probe.expectNextN(count.toLong).map(_.committableOffset.commitInternal()))
 
     withClue("the commits are aggregated to a low number of calls to commitAsync:") {
       awaitAssert {
diff --git a/tests/src/test/scala/akka/kafka/scaladsl/CommittingSpec.scala b/tests/src/test/scala/akka/kafka/scaladsl/CommittingSpec.scala
index dd4ee38..fc22bf5 100644
--- a/tests/src/test/scala/akka/kafka/scaladsl/CommittingSpec.scala
+++ b/tests/src/test/scala/akka/kafka/scaladsl/CommittingSpec.scala
@@ -359,6 +359,38 @@
       assert(element1.toInt >= failAt - committerSettings.maxBatch, "Should re-process at most maxBatch elements")
     }
 
+    "work with a committer batch flow even with upstream failure" in assertAllStagesStopped {
+      val topic = createTopic()
+      val group = createGroupId()
+
+      awaitProduce(produce(topic, 1 to 100))
+      val consumerSettings = consumerDefaults.withGroupId(group)
+      val committerSettings = committerDefaults.withMaxBatch(5)
+
+      def consumeAndCommitUntil(topic: String, failAt: String) =
+        Consumer
+          .committableSource(
+            consumerSettings,
+            Subscriptions.topics(topic)
+          )
+          .map {
+            case msg if msg.record.value() == failAt => throw new Exception
+            case other => other
+          }
+          .map(_.committableOffset)
+          .via(Committer.batchFlow(committerSettings))
+          .toMat(Sink.ignore)(Keep.right)
+          .run()
+
+      // Consume and fail in the middle of the commit batch
+      val failAt = 32
+      consumeAndCommitUntil(topic, failAt.toString).failed.futureValue shouldBe an[Exception]
+
+      val element1 = consumeFirstElement(topic, consumerSettings)
+      assert(element1.toInt == failAt,
+             "Should re-process exactly the last committed element from batch-in-flight in case of upstream failure")
+    }
+
   }
 
   "Multiple consumers to one committer" must {