[FLINK-31049] [flink-connector-kafka] Add support for Kafka record headers to KafkaSink

Co-Authored-By: Tzu-Li (Gordon) Tai <tzulitai@apache.org>

This closes #18.
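For reference, a minimal usage sketch of the new API (the topic name, bootstrap
servers, header key/value, and class name below are illustrative, not part of
this change):

    import org.apache.flink.api.common.serialization.SimpleStringSchema;
    import org.apache.flink.connector.kafka.sink.HeaderProvider;
    import org.apache.flink.connector.kafka.sink.KafkaRecordSerializationSchema;
    import org.apache.flink.connector.kafka.sink.KafkaSink;

    import org.apache.kafka.common.header.internals.RecordHeader;
    import org.apache.kafka.common.header.internals.RecordHeaders;

    import java.nio.charset.StandardCharsets;
    import java.util.Collections;

    public class HeaderProviderUsageExample {
        public static KafkaSink<String> buildSink() {
            // Hypothetical: tag every record with a constant "source" header.
            HeaderProvider<String> headerProvider =
                    element ->
                            new RecordHeaders(
                                    Collections.singletonList(
                                            new RecordHeader(
                                                    "source",
                                                    "flink".getBytes(StandardCharsets.UTF_8))));

            return KafkaSink.<String>builder()
                    .setBootstrapServers("localhost:9092")
                    .setRecordSerializer(
                            KafkaRecordSerializationSchema.builder()
                                    .setTopic("example-topic")
                                    .setValueSerializationSchema(new SimpleStringSchema())
                                    .setHeaderProvider(headerProvider)
                                    .build())
                    .build();
        }
    }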
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/HeaderProvider.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/HeaderProvider.java
new file mode 100644
index 0000000..2c0c080
--- /dev/null
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/HeaderProvider.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.kafka.sink;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+import org.apache.kafka.common.header.Header;
+import org.apache.kafka.common.header.Headers;
+
+import java.io.Serializable;
+
+/** Creates an {@link Iterable} of {@link Header}s from the input element. */
+@PublicEvolving
+public interface HeaderProvider<IN> extends Serializable {
+    Headers getHeaders(IN input);
+}
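For illustration, a hypothetical standalone implementation of the new interface
(the class name and header key are invented for this sketch); because
org.apache.kafka.common.header.Headers extends Iterable<Header>, returning a
RecordHeaders instance satisfies the contract:

    import org.apache.kafka.common.header.Headers;
    import org.apache.kafka.common.header.internals.RecordHeader;
    import org.apache.kafka.common.header.internals.RecordHeaders;

    import java.nio.charset.StandardCharsets;
    import java.util.Collections;

    /** Illustrative provider that tags each record with the length of its element. */
    public class ElementLengthHeaderProvider implements HeaderProvider<String> {
        @Override
        public Headers getHeaders(String input) {
            // Derive a per-element header value; byte[] is the Kafka header payload type.
            byte[] length = String.valueOf(input.length()).getBytes(StandardCharsets.UTF_8);
            return new RecordHeaders(
                    Collections.singletonList(new RecordHeader("element-length", length)));
        }
    }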
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/KafkaRecordSerializationSchemaBuilder.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/KafkaRecordSerializationSchemaBuilder.java
index 59864a3..1cc9220 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/KafkaRecordSerializationSchemaBuilder.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/KafkaRecordSerializationSchemaBuilder.java
@@ -84,6 +84,7 @@
     @Nullable private SerializationSchema<? super IN> valueSerializationSchema;
     @Nullable private FlinkKafkaPartitioner<? super IN> partitioner;
     @Nullable private SerializationSchema<? super IN> keySerializationSchema;
+    @Nullable private HeaderProvider<? super IN> headerProvider;
 
     /**
      * Sets a custom partitioner determining the target partition of the target topic.
@@ -190,6 +191,20 @@
         return self;
     }
 
+    /**
+     * Sets a {@link HeaderProvider} that is used to add headers to the {@link ProducerRecord} for
+     * the current element.
+     *
+     * @param headerProvider the {@link HeaderProvider} used to create the headers for each record
+     * @return {@code this}
+     */
+    public <T extends IN> KafkaRecordSerializationSchemaBuilder<T> setHeaderProvider(
+            HeaderProvider<? super T> headerProvider) {
+        KafkaRecordSerializationSchemaBuilder<T> self = self();
+        self.headerProvider = checkNotNull(headerProvider);
+        return self;
+    }
+
     @SuppressWarnings("unchecked")
     private <T extends IN> KafkaRecordSerializationSchemaBuilder<T> self() {
         return (KafkaRecordSerializationSchemaBuilder<T>) this;
@@ -239,7 +254,11 @@
         checkState(valueSerializationSchema != null, "No value serializer is configured.");
         checkState(topicSelector != null, "No topic selector is configured.");
         return new KafkaRecordSerializationSchemaWrapper<>(
-                topicSelector, valueSerializationSchema, keySerializationSchema, partitioner);
+                topicSelector,
+                valueSerializationSchema,
+                keySerializationSchema,
+                partitioner,
+                headerProvider);
     }
 
     private void checkValueSerializerNotSet() {
@@ -278,16 +297,19 @@
         private final Function<? super IN, String> topicSelector;
         private final FlinkKafkaPartitioner<? super IN> partitioner;
         private final SerializationSchema<? super IN> keySerializationSchema;
+        private final HeaderProvider<? super IN> headerProvider;
 
         KafkaRecordSerializationSchemaWrapper(
                 Function<? super IN, String> topicSelector,
                 SerializationSchema<? super IN> valueSerializationSchema,
                 @Nullable SerializationSchema<? super IN> keySerializationSchema,
-                @Nullable FlinkKafkaPartitioner<? super IN> partitioner) {
+                @Nullable FlinkKafkaPartitioner<? super IN> partitioner,
+                @Nullable HeaderProvider<? super IN> headerProvider) {
             this.topicSelector = checkNotNull(topicSelector);
             this.valueSerializationSchema = checkNotNull(valueSerializationSchema);
             this.partitioner = partitioner;
             this.keySerializationSchema = keySerializationSchema;
+            this.headerProvider = headerProvider;
         }
 
         @Override
@@ -325,12 +347,14 @@
                                             context.getPartitionsForTopic(targetTopic)))
                             : OptionalInt.empty();
 
-            return new ProducerRecord<>(
-                    targetTopic,
-                    partition.isPresent() ? partition.getAsInt() : null,
-                    timestamp == null || timestamp < 0L ? null : timestamp,
-                    key,
-                    value);
+            // ProducerRecord accepts null headers, so one constructor call covers both cases.
+            return new ProducerRecord<>(
+                    targetTopic,
+                    partition.isPresent() ? partition.getAsInt() : null,
+                    timestamp == null || timestamp < 0L ? null : timestamp,
+                    key,
+                    value,
+                    headerProvider != null ? headerProvider.getHeaders(element) : null);
         }
     }
 }
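Worth noting for users of setHeaderProvider: HeaderProvider extends Serializable,
so a provider (including one written as a lambda) is serialized into the job
graph and must only capture serializable state. A hypothetical sketch, reusing
the imports from the examples above (the captured tenant variable is illustrative):

    // Safe: the lambda captures only a String, which is serializable.
    final String tenant = "tenant-42";
    HeaderProvider<String> provider =
            element ->
                    new RecordHeaders(
                            Collections.singletonList(
                                    new RecordHeader(
                                            "tenant",
                                            tenant.getBytes(StandardCharsets.UTF_8))));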
diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/sink/KafkaRecordSerializationSchemaBuilderTest.java b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/sink/KafkaRecordSerializationSchemaBuilderTest.java
index 614624e..6dd5bae 100644
--- a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/sink/KafkaRecordSerializationSchemaBuilderTest.java
+++ b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/sink/KafkaRecordSerializationSchemaBuilderTest.java
@@ -28,12 +28,16 @@
 
 import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.Configurable;
+import org.apache.kafka.common.header.Header;
+import org.apache.kafka.common.header.internals.RecordHeader;
+import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.StringDeserializer;
 import org.apache.kafka.common.serialization.StringSerializer;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.nio.charset.StandardCharsets;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -146,6 +150,29 @@
     }
 
     @Test
+    public void testSerializeRecordWithHeaderProvider() throws Exception {
+        final HeaderProvider<String> headerProvider =
+                (ignored) ->
+                        new RecordHeaders(
+                                ImmutableList.of(
+                                        new RecordHeader(
+                                                "a", "a".getBytes(StandardCharsets.UTF_8))));
+
+        final KafkaRecordSerializationSchema<String> schema =
+                KafkaRecordSerializationSchema.builder()
+                        .setTopic(DEFAULT_TOPIC)
+                        .setValueSerializationSchema(new SimpleStringSchema())
+                        .setHeaderProvider(headerProvider)
+                        .build();
+        final ProducerRecord<byte[], byte[]> record = schema.serialize("a", null, null);
+        assertThat(record).isNotNull();
+        assertThat(record.headers())
+                .singleElement()
+                .extracting(Header::key, Header::value)
+                .containsExactly("a", "a".getBytes(StandardCharsets.UTF_8));
+    }
+
+    @Test
     public void testSerializeRecordWithKey() {
         final SerializationSchema<String> serializationSchema = new SimpleStringSchema();
         final KafkaRecordSerializationSchema<String> schema =