KAFKA-17057: Add RETRY option to ProductionExceptionHandler (#17163)

Implements KIP-1065
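
With this change, a ProductionExceptionHandler can return RETRY for retriable
production errors instead of failing the StreamThread. A minimal sketch of such
a handler (the class name MyRetryProductionExceptionHandler is hypothetical and
only illustrates the new API; it is not part of this patch):

    import java.util.Map;

    import org.apache.kafka.clients.producer.ProducerRecord;
    import org.apache.kafka.common.errors.RetriableException;
    import org.apache.kafka.streams.errors.ErrorHandlerContext;
    import org.apache.kafka.streams.errors.ProductionExceptionHandler;

    public class MyRetryProductionExceptionHandler implements ProductionExceptionHandler {

        @Override
        public ProductionExceptionHandlerResponse handle(final ErrorHandlerContext context,
                                                         final ProducerRecord<byte[], byte[]> record,
                                                         final Exception exception) {
            // RETRY is only honored for RetriableException; for any other exception
            // it is treated as FAIL (see RecordCollectorImpl#maybeFailResponse below).
            if (exception instanceof RetriableException) {
                return ProductionExceptionHandlerResponse.RETRY;
            }
            return ProductionExceptionHandlerResponse.FAIL;
        }

        @Override
        public void configure(final Map<String, ?> configs) {
            // no configuration needed for this sketch
        }
    }

This mirrors what the updated DefaultProductionExceptionHandler now does out of
the box; a real handler would add its own logic (e.g. CONTINUE for records it
can afford to drop). Such a handler is registered via
StreamsConfig.PRODUCTION_EXCEPTION_HANDLER_CLASS_CONFIG (props below being the
application's Streams configuration Properties), the same config the updated
integration test uses:

    props.put(StreamsConfig.PRODUCTION_EXCEPTION_HANDLER_CLASS_CONFIG, MyRetryProductionExceptionHandler.class);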

Reviewers: Alieh Saeedi <asaeedi@confluent.io>, Bill Bejeck <bill@confluent.io>
diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/DefaultProductionExceptionHandler.java b/streams/src/main/java/org/apache/kafka/streams/errors/DefaultProductionExceptionHandler.java
index 0896114..d6cc8e9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/errors/DefaultProductionExceptionHandler.java
+++ b/streams/src/main/java/org/apache/kafka/streams/errors/DefaultProductionExceptionHandler.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.streams.errors;
 
 import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.common.errors.RetriableException;
 
 import java.util.Map;
 
@@ -29,18 +30,22 @@
     @Override
     public ProductionExceptionHandlerResponse handle(final ProducerRecord<byte[], byte[]> record,
                                                      final Exception exception) {
-        return ProductionExceptionHandlerResponse.FAIL;
+        return exception instanceof RetriableException ?
+            ProductionExceptionHandlerResponse.RETRY :
+            ProductionExceptionHandlerResponse.FAIL;
     }
 
     @Override
     public ProductionExceptionHandlerResponse handle(final ErrorHandlerContext context,
                                                      final ProducerRecord<byte[], byte[]> record,
                                                      final Exception exception) {
-        return ProductionExceptionHandlerResponse.FAIL;
+        return exception instanceof RetriableException ?
+            ProductionExceptionHandlerResponse.RETRY :
+            ProductionExceptionHandlerResponse.FAIL;
     }
 
     @Override
     public void configure(final Map<String, ?> configs) {
         // ignore
     }
-}
+}
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java b/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java
index 02837b9..9512788 100644
--- a/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java
+++ b/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java
@@ -83,18 +83,37 @@
     }
 
     enum ProductionExceptionHandlerResponse {
-        /* continue processing */
+        /** Continue processing.
+         *
+         * <p> For this case, output records which could not be written successfully are lost.
+         * Use this option only if you can tolerate data loss.
+         */
         CONTINUE(0, "CONTINUE"),
-        /* fail processing */
-        FAIL(1, "FAIL");
+        /** Fail processing.
+         *
+         * <p> Kafka Streams will raise an exception and the {@code StreamThread} will fail.
+         * No offsets (for {@link org.apache.kafka.streams.StreamsConfig#AT_LEAST_ONCE at-least-once}) or transactions
+         * (for {@link org.apache.kafka.streams.StreamsConfig#EXACTLY_ONCE_V2 exactly-once}) will be committed.
+         */
+        FAIL(1, "FAIL"),
+        /** Retry the failed operation.
+         *
+         * <p> Retrying might imply that a {@link TaskCorruptedException} is thrown, and that the retry
+         * starts from the last committed offset.
+         *
+         * <p> <b>NOTE:</b> {@code RETRY} is only a valid return value for
+         * {@link org.apache.kafka.common.errors.RetriableException retriable exceptions}.
+         * If {@code RETRY} is returned for a non-retriable exception it will be interpreted as {@link #FAIL}.
+         */
+        RETRY(2, "RETRY");
 
         /**
-         * an english description of the api--this is for debugging and can change
+         * An English description of the option. This is for debugging only and may change.
          */
         public final String name;
 
         /**
-         * the permanent and immutable id of an API--this can't change ever
+         * The permanent and immutable id of the option. This can never change.
          */
         public final int id;
 
@@ -106,9 +125,9 @@
     }
 
     enum SerializationExceptionOrigin {
-        /* serialization exception occurred during serialization of the key */
+        /** Serialization exception occurred during serialization of the key. */
         KEY,
-        /* serialization exception occurred during serialization of the value */
+        /** Serialization exception occurred during serialization of the value. */
         VALUE
     }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
index bd589a8..a79e407 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
@@ -34,7 +34,6 @@
 import org.apache.kafka.common.errors.SerializationException;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.UnknownServerException;
-import org.apache.kafka.common.errors.UnknownTopicOrPartitionException;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.metrics.Sensor;
@@ -160,7 +159,7 @@
                     fatal
                 );
             }
-            if (partitions.size() > 0) {
+            if (!partitions.isEmpty()) {
                 final Optional<Set<Integer>> maybeMulticastPartitions = partitioner.partitions(topic, key, value, partitions.size());
                 if (!maybeMulticastPartitions.isPresent()) {
                     // A null//empty partition indicates we should use the default partitioner
@@ -342,7 +341,7 @@
             throw new FailedProcessingException("Fatal user code error in production error callback", fatalUserException);
         }
 
-        if (response == ProductionExceptionHandlerResponse.FAIL) {
+        if (maybeFailResponse(response) == ProductionExceptionHandlerResponse.FAIL) {
             throw new StreamsException(
                 String.format(
                     "Unable to serialize record. ProducerRecord(topic=[%s], partition=[%d], timestamp=[%d]",
@@ -430,55 +429,53 @@
                 "indicating the task may be migrated out";
             sendException.set(new TaskMigratedException(errorMessage, productionException));
         } else {
-            if (isRetriable(productionException)) {
+            final ProductionExceptionHandlerResponse response;
+            try {
+                response = Objects.requireNonNull(
+                    productionExceptionHandler.handle(
+                        errorHandlerContext(context, processorNodeId),
+                        serializedRecord,
+                        productionException
+                    ),
+                    "Invalid ProductionExceptionHandler response."
+                );
+            } catch (final RuntimeException fatalUserException) {
+                log.error(
+                    "Production error callback failed after production error for record {}",
+                    serializedRecord,
+                    productionException
+                );
+                sendException.set(new FailedProcessingException("Fatal user code error in production error callback", fatalUserException));
+                return;
+            }
+
+            if (productionException instanceof RetriableException && response == ProductionExceptionHandlerResponse.RETRY) {
                 errorMessage += "\nThe broker is either slow or in bad state (like not having enough replicas) in responding the request, " +
                     "or the connection to broker was interrupted sending the request or receiving the response. " +
                     "\nConsider overwriting `max.block.ms` and /or " +
                     "`delivery.timeout.ms` to a larger value to wait longer for such scenarios and avoid timeout errors";
                 sendException.set(new TaskCorruptedException(Collections.singleton(taskId)));
             } else {
-                final ProductionExceptionHandlerResponse response;
-                try {
-                    response = Objects.requireNonNull(
-                        productionExceptionHandler.handle(
-                            errorHandlerContext(context, processorNodeId),
-                            serializedRecord,
-                            productionException
-                        ),
-                        "Invalid ProductionExceptionHandler response."
-                    );
-                } catch (final RuntimeException fatalUserException) {
-                    log.error(
-                        "Production error callback failed after production error for record {}",
-                        serializedRecord,
-                        productionException
-                    );
-                    sendException.set(new FailedProcessingException("Fatal user code error in production error callback", fatalUserException));
-                    return;
-                }
-
-                if (response == ProductionExceptionHandlerResponse.FAIL) {
+                if (maybeFailResponse(response) == ProductionExceptionHandlerResponse.FAIL) {
                     errorMessage += "\nException handler choose to FAIL the processing, no more records would be sent.";
                     sendException.set(new StreamsException(errorMessage, productionException));
                 } else {
                     errorMessage += "\nException handler choose to CONTINUE processing in spite of this error but written offsets would not be recorded.";
                     droppedRecordsSensor.record();
                 }
-
             }
         }
 
         log.error(errorMessage, productionException);
     }
 
-    /**
-     * The `TimeoutException` with root cause `UnknownTopicOrPartitionException` is considered as non-retriable
-     * (despite `TimeoutException` being a subclass of `RetriableException`, this particular case is explicitly excluded).
-    */
-    private boolean isRetriable(final Exception exception) {
-        return exception instanceof RetriableException &&
-                (!(exception instanceof TimeoutException) || exception.getCause() == null
-                        || !(exception.getCause() instanceof UnknownTopicOrPartitionException));
+    private ProductionExceptionHandlerResponse maybeFailResponse(final ProductionExceptionHandlerResponse response) {
+        if (response == ProductionExceptionHandlerResponse.RETRY) {
+            log.warn("ProductionExceptionHandler returned RETRY for a non-retriable exception. Will treat it as FAIL.");
+            return ProductionExceptionHandlerResponse.FAIL;
+        } else {
+            return response;
+        }
     }
 
     private boolean isFatalException(final Exception exception) {
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/CustomHandlerIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/SwallowUnknownTopicErrorIntegrationTest.java
similarity index 65%
rename from streams/src/test/java/org/apache/kafka/streams/integration/CustomHandlerIntegrationTest.java
rename to streams/src/test/java/org/apache/kafka/streams/integration/SwallowUnknownTopicErrorIntegrationTest.java
index 873b2eb..25f9afd 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/CustomHandlerIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/SwallowUnknownTopicErrorIntegrationTest.java
@@ -16,24 +16,28 @@
  */
 package org.apache.kafka.streams.integration;
 
-import org.apache.kafka.clients.producer.ProducerConfig;
+import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.UnknownTopicOrPartitionException;
+import org.apache.kafka.common.serialization.IntegerDeserializer;
 import org.apache.kafka.common.serialization.IntegerSerializer;
 import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringDeserializer;
 import org.apache.kafka.common.serialization.StringSerializer;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.KafkaStreams;
 import org.apache.kafka.streams.KafkaStreams.State;
 import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.KeyValueTimestamp;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.Topology;
-import org.apache.kafka.streams.errors.StreamsException;
-import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler;
+import org.apache.kafka.streams.errors.ErrorHandlerContext;
+import org.apache.kafka.streams.errors.ProductionExceptionHandler;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.kstream.Consumed;
+import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.kstream.Produced;
 import org.apache.kafka.test.TestUtils;
 
@@ -48,16 +52,14 @@
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.Map;
 import java.util.Properties;
-import java.util.concurrent.atomic.AtomicReference;
 
 import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName;
-import static org.junit.jupiter.api.Assertions.assertInstanceOf;
-
 
 @Timeout(600)
 @Tag("integration")
-public class CustomHandlerIntegrationTest {
+public class SwallowUnknownTopicErrorIntegrationTest {
     private static final int NUM_BROKERS = 1;
     public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS,
             Utils.mkProperties(Collections.singletonMap("auto.create.topics.enable", "false")));
@@ -77,8 +79,7 @@
     // topic name
     private static final String STREAM_INPUT = "STREAM_INPUT";
     private static final String NON_EXISTING_TOPIC = "non_existing_topic";
-
-    private final AtomicReference<Throwable> caughtException = new AtomicReference<>();
+    private static final String STREAM_OUTPUT = "STREAM_OUTPUT";
 
     private KafkaStreams kafkaStreams;
     private Topology topology;
@@ -87,19 +88,19 @@
     @BeforeEach
     public void before(final TestInfo testInfo) throws InterruptedException {
         final StreamsBuilder builder = new StreamsBuilder();
-        CLUSTER.createTopics(STREAM_INPUT);
+        CLUSTER.createTopics(STREAM_INPUT, STREAM_OUTPUT);
         final String safeTestName = safeUniqueTestName(testInfo);
         appId = "app-" + safeTestName;
 
-        builder.stream(STREAM_INPUT, Consumed.with(Serdes.Integer(), Serdes.String()))
-            .to(NON_EXISTING_TOPIC, Produced.with(Serdes.Integer(), Serdes.String()));
-        produceRecords();
+        final KStream<Integer, String> stream = builder.stream(STREAM_INPUT, Consumed.with(Serdes.Integer(), Serdes.String()));
+        stream.to(NON_EXISTING_TOPIC, Produced.with(Serdes.Integer(), Serdes.String()));
+        stream.to(STREAM_OUTPUT, Produced.with(Serdes.Integer(), Serdes.String()));
         topology = builder.build();
     }
 
     @AfterEach
     public void after() throws InterruptedException {
-        CLUSTER.deleteTopics(STREAM_INPUT);
+        CLUSTER.deleteTopics(STREAM_INPUT, STREAM_OUTPUT);
         if (kafkaStreams != null) {
             kafkaStreams.close();
             kafkaStreams.cleanUp();
@@ -108,15 +109,30 @@
 
     private void produceRecords() {
         final Properties props = TestUtils.producerConfig(
-                CLUSTER.bootstrapServers(),
-                IntegerSerializer.class,
-                StringSerializer.class,
-                new Properties());
+            CLUSTER.bootstrapServers(),
+            IntegerSerializer.class,
+            StringSerializer.class
+        );
         IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp(
-                STREAM_INPUT,
-                Collections.singletonList(new KeyValue<>(1, "A")),
-                props,
-                CLUSTER.time.milliseconds() + 2
+            STREAM_INPUT,
+            Collections.singletonList(new KeyValue<>(1, "A")),
+            props,
+            CLUSTER.time.milliseconds() + 2
+        );
+    }
+
+    private void verifyResult() {
+        final Properties props = TestUtils.consumerConfig(
+            CLUSTER.bootstrapServers(),
+            "consumer",
+            IntegerDeserializer.class,
+            StringDeserializer.class
+        );
+
+        IntegrationTestUtils.verifyKeyValueTimestamps(
+            props,
+            STREAM_OUTPUT,
+            Collections.singletonList(new KeyValueTimestamp<>(1, "A", CLUSTER.time.milliseconds() + 2))
         );
     }
 
@@ -124,12 +140,33 @@
         final Properties streamsConfiguration = new Properties();
         streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, appId);
         streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
+        streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1);
         streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.IntegerSerde.class);
         streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class);
-        streamsConfiguration.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 10_000);
+        streamsConfiguration.put(StreamsConfig.PRODUCTION_EXCEPTION_HANDLER_CLASS_CONFIG, TestHandler.class);
         return streamsConfiguration;
     }
 
+    public static class TestHandler implements ProductionExceptionHandler {
+
+        public TestHandler() { }
+
+        @Override
+        public void configure(final Map<String, ?> configs) { }
+
+        @Override
+        public ProductionExceptionHandlerResponse handle(final ErrorHandlerContext context,
+                                                         final ProducerRecord<byte[], byte[]> record,
+                                                         final Exception exception) {
+            if (exception instanceof TimeoutException &&
+                exception.getCause() != null &&
+                exception.getCause() instanceof UnknownTopicOrPartitionException) {
+                return ProductionExceptionHandlerResponse.CONTINUE;
+            }
+            return ProductionExceptionHandler.super.handle(context, record, exception);
+        }
+    }
+
     private void closeApplication(final Properties streamsConfiguration) throws Exception {
         kafkaStreams.close();
         kafkaStreams.cleanUp();
@@ -140,10 +177,6 @@
     public void shouldThrowStreamsExceptionWithMissingTopicAndDefaultExceptionHandler() throws Exception {
         final Properties streamsConfiguration = getCommonProperties();
         kafkaStreams = new KafkaStreams(topology, streamsConfiguration);
-        kafkaStreams.setUncaughtExceptionHandler(e -> {
-            caughtException.set(e);
-            return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT;
-        });
         kafkaStreams.start();
         TestUtils.waitForCondition(
             () -> kafkaStreams.state() == State.RUNNING,
@@ -151,29 +184,9 @@
             () -> "Kafka Streams application did not reach state RUNNING in " + timeoutMs + " ms"
         );
 
-        TestUtils.waitForCondition(
-            this::receivedUnknownTopicOrPartitionException,
-            timeoutMs,
-            () -> "Did not receive UnknownTopicOrPartitionException"
-        );
+        produceRecords();
+        verifyResult();
 
-        TestUtils.waitForCondition(
-            () -> kafkaStreams.state() == State.ERROR,
-            timeoutMs,
-            () -> "Kafka Streams application did not reach state ERROR in " + timeoutMs + " ms"
-        );
         closeApplication(streamsConfiguration);
     }
-
-    private boolean receivedUnknownTopicOrPartitionException() {
-        if (caughtException.get() == null) {
-            return false;
-        }
-
-        assertInstanceOf(StreamsException.class, caughtException.get());
-        assertInstanceOf(TimeoutException.class, caughtException.get().getCause());
-        assertInstanceOf(UnknownTopicOrPartitionException.class, caughtException.get().getCause().getCause());
-
-        return true;
-    }
 }
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
index dc8e668..353289f 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
@@ -54,6 +54,7 @@
 import org.apache.kafka.streams.errors.ProductionExceptionHandler.ProductionExceptionHandlerResponse;
 import org.apache.kafka.streams.errors.ProductionExceptionHandler.SerializationExceptionOrigin;
 import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.processor.TaskId;
@@ -1346,15 +1347,40 @@
 
         collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, sinkNodeName, context, streamPartitioner);
 
-        // With default handler which returns FAIL, flush() throws StreamsException with TimeoutException cause,
-        // otherwise it would throw a TaskCorruptedException with null cause
+        final TaskCorruptedException thrown = assertThrows(TaskCorruptedException.class, collector::flush);
+        assertThat(
+            thrown.getMessage(),
+            equalTo("Tasks [0_0] are corrupted and hence need to be re-initialized")
+        );
+    }
+
+    @Test
+    public void shouldThrowStreamsExceptionOnUnknownTopicOrPartitionExceptionWhenExceptionHandlerReturnsFail() {
+        final KafkaException exception = new TimeoutException("KABOOM!", new UnknownTopicOrPartitionException());
+        final RecordCollector collector = new RecordCollectorImpl(
+            logContext,
+            taskId,
+            getExceptionalStreamsProducerOnSend(exception),
+            new ProductionExceptionHandlerMock(
+                Optional.of(ProductionExceptionHandlerResponse.FAIL),
+                context,
+                sinkNodeName,
+                taskId
+            ),
+            streamsMetrics,
+            topology
+        );
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, sinkNodeName, context, streamPartitioner);
+
+        // With a custom handler that returns FAIL, flush() throws a StreamsException with a TimeoutException cause
         final StreamsException thrown = assertThrows(StreamsException.class, collector::flush);
         assertEquals(exception, thrown.getCause());
         assertThat(
             thrown.getMessage(),
             equalTo("Error encountered sending record to topic topic for task 0_0 due to:" +
-                    "\norg.apache.kafka.common.errors.TimeoutException: KABOOM!" +
-                    "\nException handler choose to FAIL the processing, no more records would be sent.")
+                "\norg.apache.kafka.common.errors.TimeoutException: KABOOM!" +
+                "\nException handler choose to FAIL the processing, no more records would be sent.")
         );
     }
 
@@ -1381,6 +1407,42 @@
     }
 
     @Test
+    public void shouldTreatRetryAsFailForNonRetriableException() {
+        try (final LogCaptureAppender logCaptureAppender = LogCaptureAppender.createAndRegister(RecordCollectorImpl.class)) {
+            final RuntimeException exception = new RuntimeException("KABOOM!");
+            final RecordCollector collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                getExceptionalStreamsProducerOnSend(exception),
+                new ProductionExceptionHandlerMock(
+                    Optional.of(ProductionExceptionHandlerResponse.RETRY),
+                    context,
+                    sinkNodeName,
+                    taskId
+                ),
+                streamsMetrics,
+                topology
+            );
+
+            collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, sinkNodeName, context, streamPartitioner);
+
+            final StreamsException thrown = assertThrows(StreamsException.class, collector::flush);
+            assertEquals(exception, thrown.getCause());
+            assertThat(
+                thrown.getMessage(),
+                equalTo("Error encountered sending record to topic topic for task 0_0 due to:" +
+                    "\njava.lang.RuntimeException: KABOOM!" +
+                    "\nException handler choose to FAIL the processing, no more records would be sent.")
+            );
+
+            assertThat(
+                logCaptureAppender.getMessages().get(0),
+                equalTo("test ProductionExceptionHandler returned RETRY for a non-retriable exception. Will treat it as FAIL.")
+            );
+        }
+    }
+
+    @Test
     public void shouldNotAbortTxnOnEOSCloseDirtyIfNothingSent() {
         final AtomicBoolean functionCalled = new AtomicBoolean(false);
         final RecordCollector collector = new RecordCollectorImpl(
@@ -1986,6 +2048,7 @@
             return response.orElse(null);
         }
 
+        @SuppressWarnings("rawtypes")
         @Override
         public ProductionExceptionHandlerResponse handleSerializationException(final ErrorHandlerContext context,
                                                                                final ProducerRecord record,