KAFKA-10199: Accept only one task per element in output queue for failed tasks (#15849)

Currently, the state updater writes multiple tasks per exception in the output
queue for failed tasks. To add the functionality to remove tasks synchronously
from the state updater, it is simpler that each element of the output queue for
failed tasks holds one single task.

This commit refactors the class that holds exceptions and failed tasks
in the state updater -- i.e., ExceptionAndTasks -- to just hold one single
task.

Reviewers: Lucas Brutschy <lbrutschy@confluent.io>
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
index 9629a3b..07d28e2 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -305,7 +305,7 @@
 
         private void handleRuntimeException(final RuntimeException runtimeException) {
             log.error("An unexpected error occurred within the state updater thread: " + runtimeException);
-            addToExceptionsAndFailedTasksThenClearUpdatingTasks(new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), runtimeException));
+            addToExceptionsAndFailedTasksThenClearUpdatingAndPausedTasks(runtimeException);
             isRunning.set(false);
         }
 
@@ -324,7 +324,9 @@
                 changelogsOfCorruptedTasks.addAll(corruptedTask.changelogPartitions());
             }
             changelogReader.unregister(changelogsOfCorruptedTasks);
-            addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(new ExceptionAndTasks(corruptedTasks, taskCorruptedException));
+            corruptedTasks.forEach(
+                task -> addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(new ExceptionAndTask(taskCorruptedException, task))
+            );
         }
 
         // TODO: we can let the exception encode the actual corrupted changelog partitions and only
@@ -347,33 +349,66 @@
 
         private void handleStreamsExceptionWithTask(final StreamsException streamsException) {
             final TaskId failedTaskId = streamsException.taskId().get();
-            if (!updatingTasks.containsKey(failedTaskId)) {
-                throw new IllegalStateException("Task " + failedTaskId + " failed but is not updating. " + BUG_ERROR_MESSAGE);
+            if (updatingTasks.containsKey(failedTaskId)) {
+                addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(
+                    new ExceptionAndTask(streamsException, updatingTasks.get(failedTaskId))
+                );
+            } else if (pausedTasks.containsKey(failedTaskId)) {
+                addToExceptionsAndFailedTasksThenRemoveFromPausedTasks(
+                    new ExceptionAndTask(streamsException, pausedTasks.get(failedTaskId))
+                );
+            } else {
+                throw new IllegalStateException("Task " + failedTaskId + " failed but is not updating or paused. " + BUG_ERROR_MESSAGE);
             }
-            final Set<Task> failedTask = new HashSet<>();
-            failedTask.add(updatingTasks.get(failedTaskId));
-            addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(new ExceptionAndTasks(failedTask, streamsException));
         }
 
         private void handleStreamsExceptionWithoutTask(final StreamsException streamsException) {
-            addToExceptionsAndFailedTasksThenClearUpdatingTasks(
-                new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), streamsException));
+            addToExceptionsAndFailedTasksThenClearUpdatingAndPausedTasks(streamsException);
         }
 
         // It is important to remove the corrupted tasks from the updating tasks after they were added to the
         // failed tasks.
         // This ensures that all tasks are found in DefaultStateUpdater#getTasks().
-        private void addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(final ExceptionAndTasks exceptionAndTasks) {
-            exceptionsAndFailedTasks.add(exceptionAndTasks);
-            exceptionAndTasks.getTasks().stream().map(Task::id).forEach(updatingTasks::remove);
-            if (exceptionAndTasks.getTasks().stream().anyMatch(Task::isActive)) {
-                transitToUpdateStandbysIfOnlyStandbysLeft();
+        private void addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(final ExceptionAndTask exceptionAndTask) {
+            exceptionsAndFailedTasksLock.lock();
+            try {
+                exceptionsAndFailedTasks.add(exceptionAndTask);
+                updatingTasks.remove(exceptionAndTask.task().id());
+                if (exceptionAndTask.task().isActive()) {
+                    transitToUpdateStandbysIfOnlyStandbysLeft();
+                }
+            } finally {
+                exceptionsAndFailedTasksLock.unlock();
             }
         }
 
-        private void addToExceptionsAndFailedTasksThenClearUpdatingTasks(final ExceptionAndTasks exceptionAndTasks) {
-            exceptionsAndFailedTasks.add(exceptionAndTasks);
-            updatingTasks.clear();
+        private void addToExceptionsAndFailedTasksThenRemoveFromPausedTasks(final ExceptionAndTask exceptionAndTask) {
+            exceptionsAndFailedTasksLock.lock();
+            try {
+                exceptionsAndFailedTasks.add(exceptionAndTask);
+                pausedTasks.remove(exceptionAndTask.task().id());
+                if (exceptionAndTask.task().isActive()) {
+                    transitToUpdateStandbysIfOnlyStandbysLeft();
+                }
+            } finally {
+                exceptionsAndFailedTasksLock.unlock();
+            }
+        }
+
+        private void addToExceptionsAndFailedTasksThenClearUpdatingAndPausedTasks(final RuntimeException runtimeException) {
+            exceptionsAndFailedTasksLock.lock();
+            try {
+                updatingTasks.values().forEach(
+                    task -> exceptionsAndFailedTasks.add(new ExceptionAndTask(runtimeException, task))
+                );
+                updatingTasks.clear();
+                pausedTasks.values().forEach(
+                    task -> exceptionsAndFailedTasks.add(new ExceptionAndTask(runtimeException, task))
+                );
+                pausedTasks.clear();
+            } finally {
+                exceptionsAndFailedTasksLock.unlock();
+            }
         }
 
         private void waitIfAllChangelogsCompletelyRead() {
@@ -610,7 +645,8 @@
     private final Queue<StreamTask> restoredActiveTasks = new LinkedList<>();
     private final Lock restoredActiveTasksLock = new ReentrantLock();
     private final Condition restoredActiveTasksCondition = restoredActiveTasksLock.newCondition();
-    private final BlockingQueue<ExceptionAndTasks> exceptionsAndFailedTasks = new LinkedBlockingQueue<>();
+    private final Lock exceptionsAndFailedTasksLock = new ReentrantLock();
+    private final Queue<ExceptionAndTask> exceptionsAndFailedTasks = new LinkedList<>();
     private final BlockingQueue<Task> removedTasks = new LinkedBlockingQueue<>();
     private final AtomicBoolean isTopologyResumed = new AtomicBoolean(false);
 
@@ -780,15 +816,26 @@
     }
 
     @Override
-    public List<ExceptionAndTasks> drainExceptionsAndFailedTasks() {
-        final List<ExceptionAndTasks> result = new ArrayList<>();
-        exceptionsAndFailedTasks.drainTo(result);
+    public List<ExceptionAndTask> drainExceptionsAndFailedTasks() {
+        final List<ExceptionAndTask> result = new ArrayList<>();
+        exceptionsAndFailedTasksLock.lock();
+        try {
+            result.addAll(exceptionsAndFailedTasks);
+            exceptionsAndFailedTasks.clear();
+        } finally {
+            exceptionsAndFailedTasksLock.unlock();
+        }
         return result;
     }
 
     @Override
     public boolean hasExceptionsAndFailedTasks() {
-        return !exceptionsAndFailedTasks.isEmpty();
+        exceptionsAndFailedTasksLock.lock();
+        try {
+            return !exceptionsAndFailedTasks.isEmpty();
+        } finally {
+            exceptionsAndFailedTasksLock.unlock();
+        }
     }
 
     public Set<StandbyTask> getUpdatingStandbyTasks() {
@@ -813,8 +860,13 @@
         }
     }
 
-    public List<ExceptionAndTasks> getExceptionsAndFailedTasks() {
-        return Collections.unmodifiableList(new ArrayList<>(exceptionsAndFailedTasks));
+    public List<ExceptionAndTask> getExceptionsAndFailedTasks() {
+        exceptionsAndFailedTasksLock.lock();
+        try {
+            return Collections.unmodifiableList(new ArrayList<>(exceptionsAndFailedTasks));
+        } finally {
+            exceptionsAndFailedTasksLock.unlock();
+        }
     }
 
     public Set<Task> getRemovedTasks() {
@@ -868,9 +920,11 @@
     private <T> Set<T> executeWithQueuesLocked(final Supplier<Set<T>> action) {
         tasksAndActionsLock.lock();
         restoredActiveTasksLock.lock();
+        exceptionsAndFailedTasksLock.lock();
         try {
             return action.get();
         } finally {
+            exceptionsAndFailedTasksLock.unlock();
             restoredActiveTasksLock.unlock();
             tasksAndActionsLock.unlock();
         }
@@ -895,7 +949,7 @@
                     Stream.concat(
                         restoredActiveTasks.stream(),
                         Stream.concat(
-                            exceptionsAndFailedTasks.stream().flatMap(exceptionAndTasks -> exceptionAndTasks.getTasks().stream()),
+                            exceptionsAndFailedTasks.stream().map(ExceptionAndTask::task),
                             removedTasks.stream()))));
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
index 445d72a..861c1f0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
@@ -21,24 +21,23 @@
 import org.apache.kafka.streams.processor.TaskId;
 
 import java.time.Duration;
-import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
 
 public interface StateUpdater {
 
-    class ExceptionAndTasks {
-        private final Set<Task> tasks;
+    class ExceptionAndTask {
+        private final Task task;
         private final RuntimeException exception;
 
-        public ExceptionAndTasks(final Set<Task> tasks, final RuntimeException exception) {
-            this.tasks = Objects.requireNonNull(tasks);
+        public ExceptionAndTask(final RuntimeException exception, final Task task) {
             this.exception = Objects.requireNonNull(exception);
+            this.task = Objects.requireNonNull(task);
         }
 
-        public Set<Task> getTasks() {
-            return Collections.unmodifiableSet(tasks);
+        public Task task() {
+            return task;
         }
 
         public RuntimeException exception() {
@@ -48,14 +47,22 @@
         @Override
         public boolean equals(final Object o) {
             if (this == o) return true;
-            if (!(o instanceof ExceptionAndTasks)) return false;
-            final ExceptionAndTasks that = (ExceptionAndTasks) o;
-            return tasks.equals(that.tasks) && exception.equals(that.exception);
+            if (!(o instanceof ExceptionAndTask)) return false;
+            final ExceptionAndTask that = (ExceptionAndTask) o;
+            return task.id().equals(that.task.id()) && exception.equals(that.exception);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(tasks, exception);
+            return Objects.hash(task, exception);
+        }
+
+        @Override
+        public String toString() {
+            return "ExceptionAndTask{" +
+                "task=" + task.id() +
+                ", exception=" + exception +
+                '}';
         }
     }
 
@@ -142,7 +149,7 @@
      *
      * @return list of failed tasks and the corresponding exceptions
      */
-    List<ExceptionAndTasks> drainExceptionsAndFailedTasks();
+    List<ExceptionAndTask> drainExceptionsAndFailedTasks();
 
     /**
      * Checks if the state updater has any failed tasks that should be returned to the StreamThread
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 4c7f3a8..ff055b7 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -914,15 +914,12 @@
     public void handleExceptionsFromStateUpdater() {
         final Map<TaskId, RuntimeException> taskExceptions = new LinkedHashMap<>();
 
-        for (final StateUpdater.ExceptionAndTasks exceptionAndTasks : stateUpdater.drainExceptionsAndFailedTasks()) {
-            final RuntimeException exception = exceptionAndTasks.exception();
-            final Set<Task> failedTasks = exceptionAndTasks.getTasks();
-
-            for (final Task failedTask : failedTasks) {
-                // need to add task back to the bookkeeping to be handled by the stream thread
-                tasks.addTask(failedTask);
-                taskExceptions.put(failedTask.id(), exception);
-            }
+        for (final StateUpdater.ExceptionAndTask exceptionAndTask : stateUpdater.drainExceptionsAndFailedTasks()) {
+            final RuntimeException exception = exceptionAndTask.exception();
+            final Task failedTask = exceptionAndTask.task();
+            // need to add task back to the bookkeeping to be handled by the stream thread
+            tasks.addTask(failedTask);
+            taskExceptions.put(failedTask.id(), exception);
         }
 
         maybeThrowTaskExceptions(taskExceptions);
@@ -1440,7 +1437,7 @@
 
     private void closeFailedTasksFromStateUpdater() {
         final Set<Task> tasksToCloseDirty = stateUpdater.drainExceptionsAndFailedTasks().stream()
-            .flatMap(exAndTasks -> exAndTasks.getTasks().stream()).collect(Collectors.toSet());
+            .map(StateUpdater.ExceptionAndTask::task).collect(Collectors.toSet());
 
         for (final Task task : tasksToCloseDirty) {
             try {
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
index 4f6a659..31460d4 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -26,7 +26,7 @@
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTasks;
+import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTask;
 import org.apache.kafka.streams.processor.internals.Task.State;
 import org.hamcrest.Matcher;
 import org.junit.jupiter.api.AfterEach;
@@ -458,7 +458,7 @@
         stateUpdater.add(task);
         verifyRestoredActiveTasks();
         verifyUpdatingTasks();
-        verifyExceptionsAndFailedTasks(new ExceptionAndTasks(mkSet(task), taskCorruptedException));
+        verifyExceptionsAndFailedTasks(new ExceptionAndTask(taskCorruptedException, task));
         verifyRemovedTasks();
         verifyPausedTasks();
 
@@ -666,9 +666,9 @@
         stateUpdater.add(activeTask2);
         stateUpdater.add(standbyTask);
 
-        final ExceptionAndTasks expectedExceptionAndTasks =
-            new ExceptionAndTasks(mkSet(activeTask1, activeTask2), taskCorruptedException);
-        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
+        final ExceptionAndTask expectedExceptionAndTask1 = new ExceptionAndTask(taskCorruptedException, activeTask1);
+        final ExceptionAndTask expectedExceptionAndTask2 = new ExceptionAndTask(taskCorruptedException, activeTask2);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTask1, expectedExceptionAndTask2);
         final InOrder orderVerifier = inOrder(changelogReader);
         orderVerifier.verify(changelogReader, atLeast(1)).enforceRestoreActive();
         orderVerifier.verify(changelogReader).transitToUpdateStandby();
@@ -683,7 +683,7 @@
                 mkEntry(task2.id(), task2)
         );
         final TaskCorruptedException taskCorruptedException = new TaskCorruptedException(mkSet(task1.id()));
-        final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task1), taskCorruptedException);
+        final ExceptionAndTask expectedExceptionAndTasks = new ExceptionAndTask(taskCorruptedException, task1);
         when(changelogReader.allChangelogsCompleted()).thenReturn(false);
         doThrow(taskCorruptedException).doReturn(0L).when(changelogReader).restore(updatingTasks);
 
@@ -849,7 +849,7 @@
 
         stateUpdater.add(task);
         stateUpdater.add(controlTask);
-        final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task), streamsException);
+        final ExceptionAndTask expectedExceptionAndTasks = new ExceptionAndTask(streamsException, task);
         verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
 
         stateUpdater.remove(task.id());
@@ -991,7 +991,7 @@
 
         stateUpdater.add(task);
         stateUpdater.add(controlTask);
-        final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task), streamsException);
+        final ExceptionAndTask expectedExceptionAndTasks = new ExceptionAndTask(streamsException, task);
         verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
         verifyUpdatingTasks(controlTask);
 
@@ -1172,7 +1172,7 @@
 
         stateUpdater.add(task);
         stateUpdater.add(controlTask);
-        final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task), streamsException);
+        final ExceptionAndTask expectedExceptionAndTasks = new ExceptionAndTask(streamsException, task);
         verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
         verifyUpdatingTasks(controlTask);
 
@@ -1223,8 +1223,9 @@
         stateUpdater.add(task1);
         stateUpdater.add(task2);
 
-        final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task1, task2), streamsException);
-        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
+        final ExceptionAndTask expectedExceptionAndTask1 = new ExceptionAndTask(streamsException, task1);
+        final ExceptionAndTask expectedExceptionAndTask2 = new ExceptionAndTask(streamsException, task2);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTask1, expectedExceptionAndTask2);
         verifyRemovedTasks();
         verifyPausedTasks();
         verifyUpdatingTasks();
@@ -1260,8 +1261,8 @@
         stateUpdater.add(task2);
         stateUpdater.add(task3);
 
-        final ExceptionAndTasks expectedExceptionAndTasks1 = new ExceptionAndTasks(mkSet(task1), streamsException1);
-        final ExceptionAndTasks expectedExceptionAndTasks2 = new ExceptionAndTasks(mkSet(task3), streamsException2);
+        final ExceptionAndTask expectedExceptionAndTasks1 = new ExceptionAndTask(streamsException1, task1);
+        final ExceptionAndTask expectedExceptionAndTasks2 = new ExceptionAndTask(streamsException2, task3);
         verifyExceptionsAndFailedTasks(expectedExceptionAndTasks1, expectedExceptionAndTasks2);
         verifyUpdatingTasks(task2);
         verifyRestoredActiveTasks();
@@ -1288,8 +1289,9 @@
         stateUpdater.add(task2);
         stateUpdater.add(task3);
 
-        final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task1, task2), taskCorruptedException);
-        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
+        final ExceptionAndTask expectedExceptionAndTask1 = new ExceptionAndTask(taskCorruptedException, task1);
+        final ExceptionAndTask expectedExceptionAndTask2 = new ExceptionAndTask(taskCorruptedException, task2);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTask1, expectedExceptionAndTask2);
         verifyUpdatingTasks(task3);
         verifyRestoredActiveTasks();
         verifyRemovedTasks();
@@ -1313,8 +1315,9 @@
         stateUpdater.add(task1);
         stateUpdater.add(task2);
 
-        final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task1, task2), illegalStateException);
-        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
+        final ExceptionAndTask expectedExceptionAndTask1 = new ExceptionAndTask(illegalStateException, task1);
+        final ExceptionAndTask expectedExceptionAndTask2 = new ExceptionAndTask(illegalStateException, task2);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTask1, expectedExceptionAndTask2);
         verifyUpdatingTasks();
         verifyRestoredActiveTasks();
         verifyRemovedTasks();
@@ -1359,16 +1362,16 @@
 
         stateUpdater.add(task1);
 
-        final ExceptionAndTasks expectedExceptionAndTasks1 = new ExceptionAndTasks(mkSet(task1), streamsException1);
+        final ExceptionAndTask expectedExceptionAndTasks1 = new ExceptionAndTask(streamsException1, task1);
         verifyDrainingExceptionsAndFailedTasks(expectedExceptionAndTasks1);
 
         stateUpdater.add(task2);
         stateUpdater.add(task3);
         stateUpdater.add(task4);
 
-        final ExceptionAndTasks expectedExceptionAndTasks2 = new ExceptionAndTasks(mkSet(task2), streamsException2);
-        final ExceptionAndTasks expectedExceptionAndTasks3 = new ExceptionAndTasks(mkSet(task3), streamsException3);
-        final ExceptionAndTasks expectedExceptionAndTasks4 = new ExceptionAndTasks(mkSet(task4), streamsException4);
+        final ExceptionAndTask expectedExceptionAndTasks2 = new ExceptionAndTask(streamsException2, task2);
+        final ExceptionAndTask expectedExceptionAndTasks3 = new ExceptionAndTask(streamsException3, task3);
+        final ExceptionAndTask expectedExceptionAndTasks4 = new ExceptionAndTask(streamsException4, task4);
         verifyDrainingExceptionsAndFailedTasks(expectedExceptionAndTasks2, expectedExceptionAndTasks3, expectedExceptionAndTasks4);
     }
 
@@ -1511,10 +1514,10 @@
         stateUpdater.add(standbyTask1);
         stateUpdater.add(activeTask1);
         stateUpdater.add(standbyTask2);
-        final ExceptionAndTasks expectedExceptionAndTasks1 =
-            new ExceptionAndTasks(mkSet(standbyTask1, standbyTask2), taskCorruptedException);
-        final ExceptionAndTasks expectedExceptionAndTasks2 = new ExceptionAndTasks(mkSet(activeTask1), streamsException);
-        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks1, expectedExceptionAndTasks2);
+        final ExceptionAndTask expectedExceptionAndTasks1 = new ExceptionAndTask(taskCorruptedException, standbyTask1);
+        final ExceptionAndTask expectedExceptionAndTasks2 = new ExceptionAndTask(taskCorruptedException, standbyTask2);
+        final ExceptionAndTask expectedExceptionAndTasks3 = new ExceptionAndTask(streamsException, activeTask1);
+        verifyExceptionsAndFailedTasks(expectedExceptionAndTasks1, expectedExceptionAndTasks2, expectedExceptionAndTasks3);
 
         verifyGetTasks(mkSet(activeTask1), mkSet(standbyTask1, standbyTask2));
 
@@ -1838,9 +1841,9 @@
         assertTrue(stateUpdater.drainRemovedTasks().isEmpty());
     }
 
-    private void verifyExceptionsAndFailedTasks(final ExceptionAndTasks... exceptionsAndTasks) throws Exception {
-        final List<ExceptionAndTasks> expectedExceptionAndTasks = Arrays.asList(exceptionsAndTasks);
-        final Set<ExceptionAndTasks> failedTasks = new HashSet<>();
+    private void verifyExceptionsAndFailedTasks(final ExceptionAndTask... exceptionsAndTasks) throws Exception {
+        final List<ExceptionAndTask> expectedExceptionAndTasks = Arrays.asList(exceptionsAndTasks);
+        final Set<ExceptionAndTask> failedTasks = new HashSet<>();
         waitForCondition(
             () -> {
                 failedTasks.addAll(stateUpdater.getExceptionsAndFailedTasks());
@@ -1856,27 +1859,27 @@
         final List<Task> expectedFailedTasks = Arrays.asList(tasks);
         final Set<Task> failedTasks = new HashSet<>();
         waitForCondition(
-                () -> {
-                    for (final ExceptionAndTasks exceptionsAndTasks : stateUpdater.getExceptionsAndFailedTasks()) {
-                        if (clazz.isInstance(exceptionsAndTasks.exception())) {
-                            failedTasks.addAll(exceptionsAndTasks.getTasks());
-                        }
+            () -> {
+                for (final ExceptionAndTask exceptionAndTask : stateUpdater.getExceptionsAndFailedTasks()) {
+                    if (clazz.isInstance(exceptionAndTask.exception())) {
+                        failedTasks.add(exceptionAndTask.task());
                     }
-                    return failedTasks.containsAll(expectedFailedTasks)
-                            && failedTasks.size() == expectedFailedTasks.size();
-                },
-                VERIFICATION_TIMEOUT,
-                "Did not get all exceptions and failed tasks within the given timeout!"
+                }
+                return failedTasks.containsAll(expectedFailedTasks)
+                    && failedTasks.size() == expectedFailedTasks.size();
+            },
+            VERIFICATION_TIMEOUT,
+            "Did not get all exceptions and failed tasks within the given timeout!"
         );
     }
 
-    private void verifyDrainingExceptionsAndFailedTasks(final ExceptionAndTasks... exceptionsAndTasks) throws Exception {
-        final List<ExceptionAndTasks> expectedExceptionAndTasks = Arrays.asList(exceptionsAndTasks);
-        final List<ExceptionAndTasks> failedTasks = new ArrayList<>();
+    private void verifyDrainingExceptionsAndFailedTasks(final ExceptionAndTask... exceptionsAndTasks) throws Exception {
+        final List<ExceptionAndTask> expectedExceptionAndTasks = Arrays.asList(exceptionsAndTasks);
+        final List<ExceptionAndTask> failedTasks = new ArrayList<>();
         waitForCondition(
             () -> {
                 if (stateUpdater.hasExceptionsAndFailedTasks()) {
-                    final List<ExceptionAndTasks> exceptionAndTasks = stateUpdater.drainExceptionsAndFailedTasks();
+                    final List<ExceptionAndTask> exceptionAndTasks = stateUpdater.drainExceptionsAndFailedTasks();
                     assertFalse(exceptionAndTasks.isEmpty());
                     failedTasks.addAll(exceptionAndTasks);
                 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 64ad0d1..566e838 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -47,7 +47,7 @@
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory;
-import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTasks;
+import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTask;
 import org.apache.kafka.streams.processor.internals.Task.State;
 import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager;
 import org.apache.kafka.streams.processor.internals.testutil.DummyStreamsConfig;
@@ -1790,10 +1790,7 @@
             .inState(State.RESTORING)
             .withInputPartitions(taskId00Partitions).build();
         final StreamsException exception = new StreamsException("boom!");
-        final StateUpdater.ExceptionAndTasks exceptionAndTasks = new StateUpdater.ExceptionAndTasks(
-            Collections.singleton(statefulTask),
-            exception
-        );
+        final ExceptionAndTask exceptionAndTasks = new ExceptionAndTask(exception, statefulTask);
         when(stateUpdater.hasExceptionsAndFailedTasks()).thenReturn(true);
         when(stateUpdater.drainExceptionsAndFailedTasks()).thenReturn(Collections.singletonList(exceptionAndTasks));
 
@@ -1815,10 +1812,7 @@
             .inState(State.RESTORING)
             .withInputPartitions(taskId00Partitions).build();
         final RuntimeException exception = new RuntimeException("boom!");
-        final StateUpdater.ExceptionAndTasks exceptionAndTasks = new StateUpdater.ExceptionAndTasks(
-            Collections.singleton(statefulTask),
-            exception
-        );
+        final ExceptionAndTask exceptionAndTasks = new ExceptionAndTask(exception, statefulTask);
         when(stateUpdater.hasExceptionsAndFailedTasks()).thenReturn(true);
         when(stateUpdater.drainExceptionsAndFailedTasks()).thenReturn(Collections.singletonList(exceptionAndTasks));
 
@@ -1843,13 +1837,13 @@
         final StreamTask statefulTask1 = statefulTask(taskId01, taskId01ChangelogPartitions)
             .inState(State.RESTORING)
             .withInputPartitions(taskId01Partitions).build();
-        final StateUpdater.ExceptionAndTasks exceptionAndTasks0 = new StateUpdater.ExceptionAndTasks(
-            Collections.singleton(statefulTask0),
-            new TaskCorruptedException(Collections.singleton(taskId00))
+        final ExceptionAndTask exceptionAndTasks0 = new ExceptionAndTask(
+            new TaskCorruptedException(Collections.singleton(taskId00)),
+            statefulTask0
         );
-        final StateUpdater.ExceptionAndTasks exceptionAndTasks1 = new StateUpdater.ExceptionAndTasks(
-            Collections.singleton(statefulTask1),
-            new TaskCorruptedException(Collections.singleton(taskId01))
+        final ExceptionAndTask exceptionAndTasks1 = new ExceptionAndTask(
+            new TaskCorruptedException(Collections.singleton(taskId01)),
+            statefulTask1
         );
         when(stateUpdater.hasExceptionsAndFailedTasks()).thenReturn(true);
         when(stateUpdater.drainExceptionsAndFailedTasks()).thenReturn(Arrays.asList(exceptionAndTasks0, exceptionAndTasks1));
@@ -3749,8 +3743,8 @@
             .inState(State.RUNNING).build();
         when(stateUpdater.drainExceptionsAndFailedTasks())
             .thenReturn(Arrays.asList(
-                new ExceptionAndTasks(mkSet(failedStatefulTask), new RuntimeException()),
-                new ExceptionAndTasks(mkSet(failedStandbyTask), new RuntimeException()))
+                new ExceptionAndTask(new RuntimeException(), failedStatefulTask),
+                new ExceptionAndTask(new RuntimeException(), failedStandbyTask))
             );
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);