Merge pull request #11749 from boyuanzz/kafka

[BEAM-9977] Implement ReadFromKafkaViaSDF 
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 08847f6..e1d0d2c 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -47,20 +47,28 @@
 import org.apache.beam.sdk.io.Read.Unbounded;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark;
+import org.apache.beam.sdk.options.ExperimentalOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.schemas.transforms.Convert;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ExternalTransformBuilder;
+import org.apache.beam.sdk.transforms.Impulse;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.SimpleFunction;
 import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.Manual;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.MonotonicallyIncreasing;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.WallTime;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.Row;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
@@ -72,6 +80,7 @@
 import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.serialization.Deserializer;
@@ -87,6 +96,8 @@
 /**
  * An unbounded source and a sink for <a href="http://kafka.apache.org/">Kafka</a> topics.
  *
+ * <h2>Read from Kafka as {@link UnboundedSource}</h2>
+ *
  * <h3>Reading from Kafka topics</h3>
  *
  * <p>KafkaIO source returns unbounded collection of Kafka records as {@code
@@ -153,7 +164,7 @@
  *
  * <p>When the pipeline starts for the first time, or without any checkpoint, the source starts
  * consuming from the <em>latest</em> offsets. You can override this behavior to consume from the
- * beginning by setting appropriate appropriate properties in {@link ConsumerConfig}, through {@link
+ * beginning by setting properties appropriately in {@link ConsumerConfig}, through {@link
  * Read#withConsumerConfigUpdates(Map)}. You can also enable offset auto_commit in Kafka to resume
  * from last committed.
  *
@@ -198,6 +209,107 @@
  *    ...
  * }</pre>
  *
+ * <h2>Read from Kafka as a {@link DoFn}</h2>
+ *
+ * <p>{@link ReadSourceDescriptors} is the {@link PTransform} that takes a PCollection of {@link
+ * KafkaSourceDescriptor} as input and outputs a PCollection of {@link KafkaRecord}. The core
+ * implementation is based on {@code SplittableDoFn}. For more details about the concept of {@code
+ * SplittableDoFn}, please refer to the <a
+ * href="https://beam.apache.org/blog/splittable-do-fn/">blog post</a> and <a
+ * href="https://s.apache.org/beam-fn-api">design doc</a>. The major difference from {@link
+ * KafkaIO.Read} is that {@link ReadSourceDescriptors} doesn't require source descriptions (e.g.,
+ * {@link KafkaIO.Read#getTopicPartitions()}, {@link KafkaIO.Read#getTopics()}, {@link
+ * KafkaIO.Read#getStartReadTime()}, etc.) at pipeline construction time. Instead, the pipeline
+ * can populate these source descriptions at runtime. For example, the pipeline can query Kafka
+ * topics from a BigQuery table and read these topics via {@link ReadSourceDescriptors}.
+ *
+ * <h3>Common Kafka Consumer Configurations</h3>
+ *
+ * <p>Most Kafka consumer configurations are similar to {@link KafkaIO.Read}:
+ *
+ * <ul>
+ *   <li>{@link ReadSourceDescriptors#getConsumerConfig()} is the same as {@link
+ *       KafkaIO.Read#getConsumerConfig()}.
+ *   <li>{@link ReadSourceDescriptors#getConsumerFactoryFn()} is the same as {@link
+ *       KafkaIO.Read#getConsumerFactoryFn()}.
+ *   <li>{@link ReadSourceDescriptors#getOffsetConsumerConfig()} is the same as {@link
+ *       KafkaIO.Read#getOffsetConsumerConfig()}.
+ *   <li>{@link ReadSourceDescriptors#getKeyCoder()} is the same as {@link
+ *       KafkaIO.Read#getKeyCoder()}.
+ *   <li>{@link ReadSourceDescriptors#getValueCoder()} is the same as {@link
+ *       KafkaIO.Read#getValueCoder()}.
+ *   <li>{@link ReadSourceDescriptors#getKeyDeserializerProvider()} is the same as {@link
+ *       KafkaIO.Read#getKeyDeserializerProvider()}.
+ *   <li>{@link ReadSourceDescriptors#getValueDeserializerProvider()} is the same as {@link
+ *       KafkaIO.Read#getValueDeserializerProvider()}.
+ *   <li>{@link ReadSourceDescriptors#isCommitOffsetEnabled()} has the same meaning as {@link
+ *       KafkaIO.Read#isCommitOffsetsInFinalizeEnabled()}.
+ * </ul>
+ *
+ * <p>For example, to create a basic {@link ReadSourceDescriptors} transform:
+ *
+ * <pre>{@code
+ * pipeline
+ *  .apply(Create.of(KafkaSourceDescriptor.of(new TopicPartition("topic", 1), null, null, null)))
+ *  .apply(KafkaIO.readSourceDescriptors()
+ *          .withBootstrapServers("broker_1:9092,broker_2:9092")
+ *          .withKeyDeserializer(LongDeserializer.class)
+ *          .withValueDeserializer(StringDeserializer.class));
+ * }</pre>
+ *
+ * <p>Note that the {@code bootstrapServers} can also be populated from the {@link
+ * KafkaSourceDescriptor}:
+ *
+ * <pre>{@code
+ * pipeline
+ *  .apply(Create.of(
+ *    KafkaSourceDescriptor.of(
+ *      new TopicPartition("topic", 1),
+ *      null,
+ *      null,
+ *      ImmutableList.of("broker_1:9092", "broker_2:9092"))))
+ *  .apply(KafkaIO.readSourceDescriptors()
+ *         .withKeyDeserializer(LongDeserializer.class)
+ *         .withValueDeserializer(StringDeserializer.class));
+ * }</pre>
+ *
+ * <h3>Configurations of {@link ReadSourceDescriptors}</h3>
+ *
+ * <p>Apart from the Kafka consumer configuration, there are some other settings related to
+ * processing records.
+ *
+ * <p>{@link ReadSourceDescriptors#commitOffsets()} enables committing offsets after processing
+ * records. Note that if the {@code isolation.level} is set to "read_committed" or {@link
+ * ConsumerConfig#ENABLE_AUTO_COMMIT_CONFIG} is set in the consumer config, the {@link
+ * ReadSourceDescriptors#commitOffsets()} will be ignored.
+ *
+ * <p>{@link ReadSourceDescriptors#withExtractOutputTimestampFn(SerializableFunction)} is used to
+ * compute the {@code output timestamp} for a given {@link KafkaRecord} and controls the watermark
+ * advancement. There are three built-in types:
+ *
+ * <ul>
+ *   <li>{@link ReadSourceDescriptors#withProcessingTime()}
+ *   <li>{@link ReadSourceDescriptors#withCreateTime()}
+ *   <li>{@link ReadSourceDescriptors#withLogAppendTime()}
+ * </ul>
+ *
+ * <p>For example, to create a {@link ReadSourceDescriptors} with this additional configuration:
+ *
+ * <pre>{@code
+ * pipeline
+ * .apply(Create.of(
+ *    KafkaSourceDescriptor.of(
+ *      new TopicPartition("topic", 1),
+ *      null,
+ *      null,
+ *      ImmutableList.of("broker_1:9092", "broker_2:9092"))))
+ * .apply(KafkaIO.readSourceDescriptors()
+ *          .withKeyDeserializer(LongDeserializer.class)
+ *          .withValueDeserializer(StringDeserializer.class)
+ *          .withProcessingTime()
+ *          .commitOffsets());
+ * }</pre>
+ *
  * <h3>Writing to Kafka</h3>
  *
  * <p>KafkaIO sink supports writing key-value pairs to a Kafka topic. Users can also write just the
@@ -295,15 +407,15 @@
   /**
    * Creates an uninitialized {@link Read} {@link PTransform}. Before use, basic Kafka configuration
    * should set with {@link Read#withBootstrapServers(String)} and {@link Read#withTopics(List)}.
-   * Other optional settings include key and value {@link Deserializer}s, custom timestamp and
+   * Other optional settings include key and value {@link Deserializer}s, and custom timestamp and
    * watermark functions.
    */
   public static <K, V> Read<K, V> read() {
     return new AutoValue_KafkaIO_Read.Builder<K, V>()
         .setTopics(new ArrayList<>())
         .setTopicPartitions(new ArrayList<>())
-        .setConsumerFactoryFn(Read.KAFKA_CONSUMER_FACTORY_FN)
-        .setConsumerConfig(Read.DEFAULT_CONSUMER_PROPERTIES)
+        .setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN)
+        .setConsumerConfig(KafkaIOUtils.DEFAULT_CONSUMER_PROPERTIES)
         .setMaxNumRecords(Long.MAX_VALUE)
         .setCommitOffsetsInFinalizeEnabled(false)
         .setTimestampPolicyFactory(TimestampPolicyFactory.withProcessingTime())
@@ -311,6 +423,17 @@
   }
 
   /**
+   * Creates an uninitialized {@link ReadSourceDescriptors} {@link PTransform}. Unlike {@link
+   * Read}, setting up {@code topics} and {@code bootstrapServers} is not required at construction
+   * time, although {@code bootstrapServers} can still be configured via {@link
+   * ReadSourceDescriptors#withBootstrapServers(String)}. Please refer to {@link
+   * ReadSourceDescriptors} for more details.
+   */
+  public static <K, V> ReadSourceDescriptors<K, V> readSourceDescriptors() {
+    return ReadSourceDescriptors.<K, V>read();
+  }
+
+  /**
    * Creates an uninitialized {@link Write} {@link PTransform}. Before use, Kafka configuration
    * should be set with {@link Write#withBootstrapServers(String)} and {@link Write#withTopic} along
    * with {@link Deserializer}s for (optional) key and values.
@@ -322,7 +445,7 @@
                 .setProducerConfig(WriteRecords.DEFAULT_PRODUCER_PROPERTIES)
                 .setEOS(false)
                 .setNumShards(0)
-                .setConsumerFactoryFn(Read.KAFKA_CONSUMER_FACTORY_FN)
+                .setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN)
                 .build())
         .build();
   }
@@ -337,7 +460,7 @@
         .setProducerConfig(WriteRecords.DEFAULT_PRODUCER_PROPERTIES)
         .setEOS(false)
         .setNumShards(0)
-        .setConsumerFactoryFn(Read.KAFKA_CONSUMER_FACTORY_FN)
+        .setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN)
         .build();
   }
 
@@ -451,7 +574,9 @@
 
         // Set required defaults
         setTopicPartitions(Collections.emptyList());
-        setConsumerFactoryFn(Read.KAFKA_CONSUMER_FACTORY_FN);
+        setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN);
+        setMaxNumRecords(Long.MAX_VALUE);
+        setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN);
         if (config.maxReadTime != null) {
           setMaxReadTime(Duration.standardSeconds(config.maxReadTime));
         }
@@ -660,7 +785,7 @@
     @Deprecated
     public Read<K, V> updateConsumerProperties(Map<String, Object> configUpdates) {
       Map<String, Object> config =
-          updateKafkaProperties(getConsumerConfig(), IGNORED_CONSUMER_PROPERTIES, configUpdates);
+          KafkaIOUtils.updateKafkaProperties(getConsumerConfig(), configUpdates);
       return toBuilder().setConsumerConfig(config).build();
     }
 
@@ -861,11 +986,11 @@
      * offset;<br>
      *
      * <p>By default, main consumer uses the configuration from {@link
-     * #DEFAULT_CONSUMER_PROPERTIES}.
+     * KafkaIOUtils#DEFAULT_CONSUMER_PROPERTIES}.
      */
     public Read<K, V> withConsumerConfigUpdates(Map<String, Object> configUpdates) {
       Map<String, Object> config =
-          updateKafkaProperties(getConsumerConfig(), IGNORED_CONSUMER_PROPERTIES, configUpdates);
+          KafkaIOUtils.updateKafkaProperties(getConsumerConfig(), configUpdates);
       return toBuilder().setConsumerConfig(config).build();
     }
 
@@ -922,19 +1047,86 @@
       Coder<K> keyCoder = getKeyCoder(coderRegistry);
       Coder<V> valueCoder = getValueCoder(coderRegistry);
 
-      // Handles unbounded source to bounded conversion if maxNumRecords or maxReadTime is set.
-      Unbounded<KafkaRecord<K, V>> unbounded =
-          org.apache.beam.sdk.io.Read.from(
-              toBuilder().setKeyCoder(keyCoder).setValueCoder(valueCoder).build().makeSource());
+      // The Read is expanded into an SDF-based transform only when "beam_fn_api" is enabled and
+      // "beam_fn_api_use_deprecated_read" is not enabled; otherwise it uses the UnboundedSource.
+      if (!ExperimentalOptions.hasExperiment(input.getPipeline().getOptions(), "beam_fn_api")
+          || ExperimentalOptions.hasExperiment(
+              input.getPipeline().getOptions(), "beam_fn_api_use_deprecated_read")) {
+        // Handles unbounded source to bounded conversion if maxNumRecords or maxReadTime is set.
+        Unbounded<KafkaRecord<K, V>> unbounded =
+            org.apache.beam.sdk.io.Read.from(
+                toBuilder().setKeyCoder(keyCoder).setValueCoder(valueCoder).build().makeSource());
 
-      PTransform<PBegin, PCollection<KafkaRecord<K, V>>> transform = unbounded;
+        PTransform<PBegin, PCollection<KafkaRecord<K, V>>> transform = unbounded;
 
-      if (getMaxNumRecords() < Long.MAX_VALUE || getMaxReadTime() != null) {
-        transform =
-            unbounded.withMaxReadTime(getMaxReadTime()).withMaxNumRecords(getMaxNumRecords());
+        if (getMaxNumRecords() < Long.MAX_VALUE || getMaxReadTime() != null) {
+          transform =
+              unbounded.withMaxReadTime(getMaxReadTime()).withMaxNumRecords(getMaxNumRecords());
+        }
+
+        return input.getPipeline().apply(transform);
+      }
+      ReadSourceDescriptors<K, V> readTransform =
+          ReadSourceDescriptors.<K, V>read()
+              .withConsumerConfigOverrides(getConsumerConfig())
+              .withOffsetConsumerConfigOverrides(getOffsetConsumerConfig())
+              .withConsumerFactoryFn(getConsumerFactoryFn())
+              .withKeyDeserializerProvider(getKeyDeserializerProvider())
+              .withValueDeserializerProvider(getValueDeserializerProvider())
+              .withManualWatermarkEstimator()
+              .withTimestampPolicyFactory(getTimestampPolicyFactory());
+      if (isCommitOffsetsInFinalizeEnabled()) {
+        readTransform = readTransform.commitOffsets();
+      }
+      PCollection<KafkaSourceDescriptor> output =
+          input
+              .getPipeline()
+              .apply(Impulse.create())
+              .apply(ParDo.of(new GenerateKafkaSourceDescriptor(this)));
+      return output.apply(readTransform).setCoder(KafkaRecordCoder.of(keyCoder, valueCoder));
+    }
+
+    /**
+     * A DoFn that emits one {@link KafkaSourceDescriptor} per configured topic partition or, when
+     * none are configured, per partition of each configured topic as reported by a Kafka consumer.
+     */
+    @VisibleForTesting
+    static class GenerateKafkaSourceDescriptor extends DoFn<byte[], KafkaSourceDescriptor> {
+      GenerateKafkaSourceDescriptor(Read read) {
+        this.consumerConfig = read.getConsumerConfig();
+        this.consumerFactoryFn = read.getConsumerFactoryFn();
+        this.topics = read.getTopics();
+        this.topicPartitions = read.getTopicPartitions();
+        this.startReadTime = read.getStartReadTime();
       }
 
-      return input.getPipeline().apply(transform);
+      private final SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>
+          consumerFactoryFn;
+
+      private final List<TopicPartition> topicPartitions;
+
+      private final Instant startReadTime;
+
+      @VisibleForTesting final Map<String, Object> consumerConfig;
+
+      @VisibleForTesting final List<String> topics;
+
+      @ProcessElement
+      public void processElement(OutputReceiver<KafkaSourceDescriptor> receiver) {
+        List<TopicPartition> partitions = new ArrayList<>(topicPartitions);
+        if (partitions.isEmpty()) {
+          try (Consumer<?, ?> consumer = consumerFactoryFn.apply(consumerConfig)) {
+            for (String topic : topics) {
+              for (PartitionInfo p : consumer.partitionsFor(topic)) {
+                partitions.add(new TopicPartition(p.topic(), p.partition()));
+              }
+            }
+          }
+        }
+        for (TopicPartition topicPartition : partitions) {
+          receiver.output(KafkaSourceDescriptor.of(topicPartition, null, startReadTime, null));
+        }
+      }
     }
 
     private Coder<K> getKeyCoder(CoderRegistry coderRegistry) {
@@ -965,45 +1157,6 @@
             final SerializableFunction<KV<KeyT, ValueT>, OutT> fn) {
       return record -> fn.apply(record.getKV());
     }
-    ///////////////////////////////////////////////////////////////////////////////////////
-
-    /** A set of properties that are not required or don't make sense for our consumer. */
-    private static final Map<String, String> IGNORED_CONSUMER_PROPERTIES =
-        ImmutableMap.of(
-            ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "Set keyDeserializer instead",
-            ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "Set valueDeserializer instead"
-            // "group.id", "enable.auto.commit", "auto.commit.interval.ms" :
-            //     lets allow these, applications can have better resume point for restarts.
-            );
-
-    // set config defaults
-    private static final Map<String, Object> DEFAULT_CONSUMER_PROPERTIES =
-        ImmutableMap.of(
-            ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG,
-            ByteArrayDeserializer.class.getName(),
-            ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
-            ByteArrayDeserializer.class.getName(),
-
-            // Use large receive buffer. Once KAFKA-3135 is fixed, this _may_ not be required.
-            // with default value of of 32K, It takes multiple seconds between successful polls.
-            // All the consumer work is done inside poll(), with smaller send buffer size, it
-            // takes many polls before a 1MB chunk from the server is fully read. In my testing
-            // about half of the time select() inside kafka consumer waited for 20-30ms, though
-            // the server had lots of data in tcp send buffers on its side. Compared to default,
-            // this setting increased throughput by many fold (3-4x).
-            ConsumerConfig.RECEIVE_BUFFER_CONFIG,
-            512 * 1024,
-
-            // default to latest offset when we are not resuming.
-            ConsumerConfig.AUTO_OFFSET_RESET_CONFIG,
-            "latest",
-            // disable auto commit of offsets. we don't require group_id. could be enabled by user.
-            ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG,
-            false);
-
-    // default Kafka 0.9 Consumer supplier.
-    private static final SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>
-        KAFKA_CONSUMER_FACTORY_FN = KafkaConsumer::new;
 
     @SuppressWarnings("unchecked")
     @Override
@@ -1018,10 +1171,11 @@
             DisplayData.item("topicPartitions", Joiner.on(",").join(topicPartitions))
                 .withLabel("Topic Partition/s"));
       }
-      Set<String> ignoredConsumerPropertiesKeys = IGNORED_CONSUMER_PROPERTIES.keySet();
+      Set<String> disallowedConsumerPropertiesKeys =
+          KafkaIOUtils.DISALLOWED_CONSUMER_PROPERTIES.keySet();
       for (Map.Entry<String, Object> conf : getConsumerConfig().entrySet()) {
         String key = conf.getKey();
-        if (!ignoredConsumerPropertiesKeys.contains(key)) {
+        if (!disallowedConsumerPropertiesKeys.contains(key)) {
           Object value =
               DisplayData.inferType(conf.getValue()) != null
                   ? conf.getValue()
@@ -1067,33 +1221,475 @@
     }
   }
 
+  /**
+   * A {@link PTransform} to read from {@link KafkaSourceDescriptor}. See {@link KafkaIO} for more
+   * information on usage and configuration. See {@link ReadFromKafkaDoFn} for more implementation
+   * details.
+   */
+  @Experimental(Kind.PORTABILITY)
+  @AutoValue
+  public abstract static class ReadSourceDescriptors<K, V>
+      extends PTransform<PCollection<KafkaSourceDescriptor>, PCollection<KafkaRecord<K, V>>> {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ReadSourceDescriptors.class);
+
+    abstract Map<String, Object> getConsumerConfig();
+
+    @Nullable
+    abstract Map<String, Object> getOffsetConsumerConfig();
+
+    @Nullable
+    abstract DeserializerProvider getKeyDeserializerProvider();
+
+    @Nullable
+    abstract DeserializerProvider getValueDeserializerProvider();
+
+    @Nullable
+    abstract Coder<K> getKeyCoder();
+
+    @Nullable
+    abstract Coder<V> getValueCoder();
+
+    abstract SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>
+        getConsumerFactoryFn();
+
+    @Nullable
+    abstract SerializableFunction<KafkaRecord<K, V>, Instant> getExtractOutputTimestampFn();
+
+    @Nullable
+    abstract SerializableFunction<Instant, WatermarkEstimator<Instant>>
+        getCreateWatermarkEstimatorFn();
+
+    abstract boolean isCommitOffsetEnabled();
+
+    @Nullable
+    abstract TimestampPolicyFactory<K, V> getTimestampPolicyFactory();
+
+    abstract ReadSourceDescriptors.Builder<K, V> toBuilder();
+
+    @AutoValue.Builder
+    abstract static class Builder<K, V> {
+      abstract ReadSourceDescriptors.Builder<K, V> setConsumerConfig(Map<String, Object> config);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setOffsetConsumerConfig(
+          Map<String, Object> offsetConsumerConfig);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setConsumerFactoryFn(
+          SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setKeyDeserializerProvider(
+          DeserializerProvider deserializerProvider);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setValueDeserializerProvider(
+          DeserializerProvider deserializerProvider);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setKeyCoder(Coder<K> keyCoder);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setValueCoder(Coder<V> valueCoder);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setExtractOutputTimestampFn(
+          SerializableFunction<KafkaRecord<K, V>, Instant> fn);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setCreateWatermarkEstimatorFn(
+          SerializableFunction<Instant, WatermarkEstimator<Instant>> fn);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setCommitOffsetEnabled(
+          boolean commitOffsetEnabled);
+
+      abstract ReadSourceDescriptors.Builder<K, V> setTimestampPolicyFactory(
+          TimestampPolicyFactory<K, V> policy);
+
+      abstract ReadSourceDescriptors<K, V> build();
+    }
+
+    public static <K, V> ReadSourceDescriptors<K, V> read() {
+      return new AutoValue_KafkaIO_ReadSourceDescriptors.Builder<K, V>()
+          .setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN)
+          .setConsumerConfig(KafkaIOUtils.DEFAULT_CONSUMER_PROPERTIES)
+          .setCommitOffsetEnabled(false)
+          .build()
+          .withProcessingTime()
+          .withMonotonicallyIncreasingWatermarkEstimator();
+    }
+
+    /**
+     * Sets the bootstrap servers to use for the Kafka consumer if unspecified via {@code
+     * KafkaSourceDescriptor#getBootStrapServers()}.
+     */
+    public ReadSourceDescriptors<K, V> withBootstrapServers(String bootstrapServers) {
+      return withConsumerConfigUpdates(
+          ImmutableMap.of(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers));
+    }
+
+    public ReadSourceDescriptors<K, V> withKeyDeserializerProvider(
+        DeserializerProvider<K> deserializerProvider) {
+      return toBuilder().setKeyDeserializerProvider(deserializerProvider).build();
+    }
+
+    public ReadSourceDescriptors<K, V> withValueDeserializerProvider(
+        DeserializerProvider<V> deserializerProvider) {
+      return toBuilder().setValueDeserializerProvider(deserializerProvider).build();
+    }
+
+    /**
+     * Sets a Kafka {@link Deserializer} to interpret key bytes read from Kafka.
+     *
+     * <p>In addition, Beam also needs a {@link Coder} to serialize and deserialize key objects at
+     * runtime. KafkaIO tries to infer a coder for the key based on the {@link Deserializer} class,
+     * however in case that fails, you can use {@link #withKeyDeserializerAndCoder(Class, Coder)} to
+     * provide the key coder explicitly.
+     */
+    public ReadSourceDescriptors<K, V> withKeyDeserializer(
+        Class<? extends Deserializer<K>> keyDeserializer) {
+      return withKeyDeserializerProvider(LocalDeserializerProvider.of(keyDeserializer));
+    }
+
+    /**
+     * Sets a Kafka {@link Deserializer} to interpret value bytes read from Kafka.
+     *
+     * <p>In addition, Beam also needs a {@link Coder} to serialize and deserialize value objects at
+     * runtime. KafkaIO tries to infer a coder for the value based on the {@link Deserializer}
+     * class, however in case that fails, you can use {@link #withValueDeserializerAndCoder(Class,
+     * Coder)} to provide the value coder explicitly.
+     */
+    public ReadSourceDescriptors<K, V> withValueDeserializer(
+        Class<? extends Deserializer<V>> valueDeserializer) {
+      return withValueDeserializerProvider(LocalDeserializerProvider.of(valueDeserializer));
+    }
+
+    /**
+     * Sets a Kafka {@link Deserializer} for interpreting key bytes read from Kafka along with a
+     * {@link Coder} for helping the Beam runner materialize key objects at runtime if necessary.
+     *
+     * <p>Use this method to override the coder inference performed within {@link
+     * #withKeyDeserializer(Class)}.
+     */
+    public ReadSourceDescriptors<K, V> withKeyDeserializerAndCoder(
+        Class<? extends Deserializer<K>> keyDeserializer, Coder<K> keyCoder) {
+      return withKeyDeserializer(keyDeserializer).toBuilder().setKeyCoder(keyCoder).build();
+    }
+
+    /**
+     * Sets a Kafka {@link Deserializer} for interpreting value bytes read from Kafka along with a
+     * {@link Coder} for helping the Beam runner materialize value objects at runtime if necessary.
+     *
+     * <p>Use this method to override the coder inference performed within {@link
+     * #withValueDeserializer(Class)}.
+     */
+    public ReadSourceDescriptors<K, V> withValueDeserializerAndCoder(
+        Class<? extends Deserializer<V>> valueDeserializer, Coder<V> valueCoder) {
+      return withValueDeserializer(valueDeserializer).toBuilder().setValueCoder(valueCoder).build();
+    }
+
+    /**
+     * A factory to create Kafka {@link Consumer} from consumer configuration. This is useful for
+     * supporting another version of Kafka consumer. Default is {@link KafkaConsumer}.
+     */
+    public ReadSourceDescriptors<K, V> withConsumerFactoryFn(
+        SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn) {
+      return toBuilder().setConsumerFactoryFn(consumerFactoryFn).build();
+    }
+
+    /**
+     * Updates configuration for the main consumer. This method merges updates from the provided map
+     * with any prior updates using {@link KafkaIOUtils#DEFAULT_CONSUMER_PROPERTIES} as the
+     * starting configuration.
+     *
+     * <p>In {@link ReadFromKafkaDoFn}, there are two consumers running in the backend:
+     *
+     * <ol>
+     *   <li>the main consumer which reads data from kafka.
+     *   <li>the secondary offset consumer which is used to estimate the backlog by fetching the
+     *       latest offset.
+     * </ol>
+     *
+     * <p>See {@link #withConsumerConfigOverrides} for overriding the configuration instead of
+     * updating it.
+     *
+     * <p>See {@link #withOffsetConsumerConfigOverrides} for configuring the secondary offset
+     * consumer.
+     */
+    public ReadSourceDescriptors<K, V> withConsumerConfigUpdates(
+        Map<String, Object> configUpdates) {
+      Map<String, Object> config =
+          KafkaIOUtils.updateKafkaProperties(getConsumerConfig(), configUpdates);
+      return toBuilder().setConsumerConfig(config).build();
+    }
+
+    /**
+     * A function to calculate output timestamp for a given {@link KafkaRecord}. The default value
+     * is {@link #withProcessingTime()}.
+     */
+    public ReadSourceDescriptors<K, V> withExtractOutputTimestampFn(
+        SerializableFunction<KafkaRecord<K, V>, Instant> fn) {
+      return toBuilder().setExtractOutputTimestampFn(fn).build();
+    }
+
+    /**
+     * A function to create a {@link WatermarkEstimator}. The default value is {@link
+     * MonotonicallyIncreasing}.
+     */
+    public ReadSourceDescriptors<K, V> withCreatWatermarkEstimatorFn(
+        SerializableFunction<Instant, WatermarkEstimator<Instant>> fn) {
+      return toBuilder().setCreateWatermarkEstimatorFn(fn).build();
+    }
+
+    /** Use the log append time as the output timestamp. */
+    public ReadSourceDescriptors<K, V> withLogAppendTime() {
+      return withExtractOutputTimestampFn(
+          ReadSourceDescriptors.ExtractOutputTimestampFns.useLogAppendTime());
+    }
+
+    /** Use the processing time as the output timestamp. */
+    public ReadSourceDescriptors<K, V> withProcessingTime() {
+      return withExtractOutputTimestampFn(
+          ReadSourceDescriptors.ExtractOutputTimestampFns.useProcessingTime());
+    }
+
+    /** Use the creation time of {@link KafkaRecord} as the output timestamp. */
+    public ReadSourceDescriptors<K, V> withCreateTime() {
+      return withExtractOutputTimestampFn(
+          ReadSourceDescriptors.ExtractOutputTimestampFns.useCreateTime());
+    }
+
+    /** Use the {@link WallTime} as the watermark estimator. */
+    public ReadSourceDescriptors<K, V> withWallTimeWatermarkEstimator() {
+      return withCreatWatermarkEstimatorFn(
+          state -> {
+            return new WallTime(state);
+          });
+    }
+
+    /** Use the {@link MonotonicallyIncreasing} as the watermark estimator. */
+    public ReadSourceDescriptors<K, V> withMonotonicallyIncreasingWatermarkEstimator() {
+      return withCreatWatermarkEstimatorFn(
+          state -> {
+            return new MonotonicallyIncreasing(state);
+          });
+    }
+
+    /** Use the {@link Manual} as the watermark estimator. */
+    public ReadSourceDescriptors<K, V> withManualWatermarkEstimator() {
+      return withCreatWatermarkEstimatorFn(
+          state -> {
+            return new Manual(state);
+          });
+    }
+
+    /**
+     * Sets "isolation.level" to "read_committed" in Kafka consumer configuration. This ensures that
+     * the consumer does not read uncommitted messages. Kafka version 0.11 introduced transactional
+     * writes. Applications requiring end-to-end exactly-once semantics should only read committed
+     * messages. See JavaDoc for {@link KafkaConsumer} for more description.
+     */
+    public ReadSourceDescriptors<K, V> withReadCommitted() {
+      return withConsumerConfigUpdates(ImmutableMap.of("isolation.level", "read_committed"));
+    }
+
+    /**
+     * Enable committing record offset. If {@link #withReadCommitted()} or {@link
+     * ConsumerConfig#ENABLE_AUTO_COMMIT_CONFIG} is set together with {@link #commitOffsets()},
+     * {@link #commitOffsets()} will be ignored.
+     */
+    public ReadSourceDescriptors<K, V> commitOffsets() {
+      return toBuilder().setCommitOffsetEnabled(true).build();
+    }
+
+    /**
+     * Supplies extra configuration for the offset consumer. A secured Kafka cluster may need this,
+     * in particular when logs show warnings like {@code exception while fetching latest offset for
+     * partition {}. will be retried}.
+     *
+     * <p>{@link ReadFromKafkaDoFn} runs two consumers:
+     *
+     * <ol>
+     *   <li>the main consumer, which reads the data from Kafka;
+     *   <li>the secondary offset consumer, which fetches the latest offset so the backlog can be
+     *       estimated.
+     * </ol>
+     *
+     * <p>Unless overridden, the offset consumer reuses the main consumer's configuration together
+     * with an auto-generated {@link ConsumerConfig#GROUP_ID_CONFIG}, which may not be enough for a
+     * secured cluster that requires additional settings.
+     *
+     * <p>Use {@link #withConsumerConfigUpdates} to configure the main consumer.
+     */
+    public ReadSourceDescriptors<K, V> withOffsetConsumerConfigOverrides(
+        Map<String, Object> offsetConsumerConfig) {
+      return toBuilder()
+          .setOffsetConsumerConfig(offsetConsumerConfig)
+          .build();
+    }
+
+    /**
+     * Replaces the whole configuration of the main consumer with {@code consumerConfig}.
+     *
+     * <p>{@link ReadFromKafkaDoFn} runs two consumers:
+     *
+     * <ol>
+     *   <li>the main consumer, which reads the data from Kafka;
+     *   <li>the secondary offset consumer, which fetches the latest offset so the backlog can be
+     *       estimated.
+     * </ol>
+     *
+     * <p>Without an override, the main consumer is configured with {@link
+     * KafkaIOUtils#DEFAULT_CONSUMER_PROPERTIES}.
+     *
+     * <p>To merge settings into the existing configuration rather than replace it, use {@link
+     * #withConsumerConfigUpdates}.
+     */
+    public ReadSourceDescriptors<K, V> withConsumerConfigOverrides(
+        Map<String, Object> consumerConfig) {
+      return toBuilder()
+          .setConsumerConfig(consumerConfig)
+          .build();
+    }
+
+    /**
+     * Returns the cross-language entry point: a transform that reads from {@link Row}-encoded
+     * {@link KafkaSourceDescriptor}s using this configuration and emits key/value pairs.
+     */
+    ReadAllFromRow<K, V> forExternalBuild() {
+      // Parameterized instead of raw so callers keep the K/V type information.
+      return new ReadAllFromRow<>(this);
+    }
+
+    /**
+     * A transform used in the cross-language case. Each input {@link Row} must be encoded with a
+     * schema equivalent to that of {@link KafkaSourceDescriptor}; rows are converted back into
+     * descriptors, read with the wrapped {@link ReadSourceDescriptors}, and each resulting {@link
+     * KafkaRecord} is emitted as its {@link KV}.
+     */
+    private static class ReadAllFromRow<K, V>
+        extends PTransform<PCollection<Row>, PCollection<KV<K, V>>> {
+
+      private final ReadSourceDescriptors<K, V> readViaSDF;
+
+      ReadAllFromRow(ReadSourceDescriptors<K, V> read) {
+        readViaSDF = read;
+      }
+
+      @Override
+      public PCollection<KV<K, V>> expand(PCollection<Row> input) {
+        // Resolve coders exactly like ReadSourceDescriptors#expand: an explicitly configured coder
+        // wins, otherwise it is inferred from the deserializer provider. The plain getters return
+        // null when only deserializers were configured, which would break KvCoder construction.
+        CoderRegistry coderRegistry = input.getPipeline().getCoderRegistry();
+        Coder<K> keyCoder = readViaSDF.getKeyCoder(coderRegistry);
+        Coder<V> valueCoder = readViaSDF.getValueCoder(coderRegistry);
+        return input
+            .apply(Convert.fromRows(KafkaSourceDescriptor.class))
+            .apply(readViaSDF)
+            .apply(
+                ParDo.of(
+                    new DoFn<KafkaRecord<K, V>, KV<K, V>>() {
+                      // The @Element type matches the DoFn's declared input type exactly.
+                      @ProcessElement
+                      public void processElement(
+                          @Element KafkaRecord<K, V> element,
+                          OutputReceiver<KV<K, V>> outputReceiver) {
+                        outputReceiver.output(element.getKV());
+                      }
+                    }))
+            .setCoder(KvCoder.of(keyCoder, valueCoder));
+      }
+    }
+
+    /**
+     * Installs a {@link TimestampPolicyFactory}. With a factory present, output timestamps are
+     * meant to be computed by the policy from {@link
+     * TimestampPolicyFactory#createTimestampPolicy(TopicPartition, Optional)}; this method also
+     * switches the watermark estimator to {@link Manual}.
+     */
+    ReadSourceDescriptors<K, V> withTimestampPolicyFactory(
+        TimestampPolicyFactory<K, V> timestampPolicyFactory) {
+      ReadSourceDescriptors<K, V> withPolicy =
+          toBuilder().setTimestampPolicyFactory(timestampPolicyFactory).build();
+      return withPolicy.withManualWatermarkEstimator();
+    }
+
+    /**
+     * Expands into a {@link ReadFromKafkaDoFn} that reads every {@link KafkaSourceDescriptor} of
+     * {@code input}, with the output coder built from the resolved key and value coders.
+     *
+     * @throws IllegalArgumentException if the {@code beam_fn_api} experiment is not enabled, or a
+     *     key or value deserializer is missing
+     * @throws IllegalStateException if offset committing is enabled while neither read_committed
+     *     nor auto commit is configured, because the commit step is not implemented yet
+     */
+    @Override
+    public PCollection<KafkaRecord<K, V>> expand(PCollection<KafkaSourceDescriptor> input) {
+      checkArgument(
+          ExperimentalOptions.hasExperiment(input.getPipeline().getOptions(), "beam_fn_api"),
+          "The ReadSourceDescriptors can only be used when beam_fn_api is enabled.");
+
+      checkArgument(getKeyDeserializerProvider() != null, "withKeyDeserializer() is required");
+      checkArgument(getValueDeserializerProvider() != null, "withValueDeserializer() is required");
+
+      ConsumerSpEL consumerSpEL = new ConsumerSpEL();
+      if (!consumerSpEL.hasOffsetsForTimes()) {
+        LOG.warn(
+            "Kafka client version {} is too old. Versions before 0.10.1.0 are deprecated and "
+                + "may not be supported in next release of Apache Beam. "
+                + "Please upgrade your Kafka client version.",
+            AppInfoParser.getVersion());
+      }
+
+      if (isCommitOffsetEnabled() && configuredKafkaCommit()) {
+        LOG.info(
+            "Either read_committed or auto_commit is set together with commitOffsetEnabled but you "
+                + "only need one of them. The commitOffsetEnabled is going to be ignored");
+      }
+
+      if (getConsumerConfig().get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) == null) {
+        LOG.warn(
+            "The bootstrapServers is not set. It must be populated through the "
+                + "KafkaSourceDescriptor during runtime otherwise the pipeline will fail.");
+      }
+
+      // TODO(BEAM-10123): Add CommitOffsetTransform to expansion.
+      // Reject the unsupported configuration before anything is applied to the pipeline, so a
+      // failed expansion does not leave a dangling ParDo behind in the graph.
+      if (isCommitOffsetEnabled() && !configuredKafkaCommit()) {
+        throw new IllegalStateException("Offset committed is not supported yet");
+      }
+
+      CoderRegistry coderRegistry = input.getPipeline().getCoderRegistry();
+      Coder<K> keyCoder = getKeyCoder(coderRegistry);
+      Coder<V> valueCoder = getValueCoder(coderRegistry);
+      Coder<KafkaRecord<K, V>> outputCoder = KafkaRecordCoder.of(keyCoder, valueCoder);
+      return input.apply(ParDo.of(new ReadFromKafkaDoFn<K, V>(this))).setCoder(outputCoder);
+    }
+
+    /**
+     * Returns the explicitly configured key coder if there is one, otherwise the coder that the
+     * key deserializer provider derives from {@code coderRegistry}.
+     */
+    private Coder<K> getKeyCoder(CoderRegistry coderRegistry) {
+      Coder<K> configured = getKeyCoder();
+      if (configured != null) {
+        return configured;
+      }
+      return getKeyDeserializerProvider().getCoder(coderRegistry);
+    }
+
+    /**
+     * Returns the explicitly configured value coder if there is one, otherwise the coder that the
+     * value deserializer provider derives from {@code coderRegistry}.
+     */
+    private Coder<V> getValueCoder(CoderRegistry coderRegistry) {
+      Coder<V> configured = getValueCoder();
+      if (configured != null) {
+        return configured;
+      }
+      return getValueDeserializerProvider().getCoder(coderRegistry);
+    }
+
+    /**
+     * Returns whether the consumer configuration sets {@code isolation.level} to {@code
+     * read_committed} or enables {@link ConsumerConfig#ENABLE_AUTO_COMMIT_CONFIG}; in either case
+     * {@link #commitOffsets()} is ignored during expansion.
+     */
+    private boolean configuredKafkaCommit() {
+      Map<String, Object> config = getConsumerConfig();
+      // Compare by value: '==' on Strings only matched when the config held the very same interned
+      // literal, and silently failed for values built at runtime (e.g. decoded from external config).
+      boolean readCommitted = "read_committed".equals(config.get("isolation.level"));
+      // Kafka accepts boolean configs either as Boolean or as a case-insensitive "true"/"false"
+      // string, so recognize both forms.
+      Object autoCommit = config.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG);
+      boolean autoCommitEnabled = autoCommit != null && Boolean.parseBoolean(autoCommit.toString());
+      return readCommitted || autoCommitEnabled;
+    }
+
+    /** Factories for functions that choose the output timestamp of a {@link KafkaRecord}. */
+    static class ExtractOutputTimestampFns<K, V> {
+      /** Stamps each record with {@link Instant#now()} at the moment the function is applied. */
+      public static <K, V> SerializableFunction<KafkaRecord<K, V>, Instant> useProcessingTime() {
+        return record -> Instant.now();
+      }
+
+      /** Uses the record's own timestamp, which must be of type {@code CREATE_TIME}. */
+      public static <K, V> SerializableFunction<KafkaRecord<K, V>, Instant> useCreateTime() {
+        return record ->
+            requireTimestampType(record, KafkaTimestampType.CREATE_TIME, "CREATE_TIME");
+      }
+
+      /** Uses the record's own timestamp, which must be of type {@code LOG_APPEND_TIME}. */
+      public static <K, V> SerializableFunction<KafkaRecord<K, V>, Instant> useLogAppendTime() {
+        return record ->
+            requireTimestampType(record, KafkaTimestampType.LOG_APPEND_TIME, "LOG_APPEND_TIME");
+      }
+
+      /**
+       * Returns the timestamp of {@code record} as an {@link Instant} after checking its type.
+       *
+       * @param expectedName label of {@code expected} used in the error message
+       * @throws IllegalArgumentException if the record's timestamp type is not {@code expected}
+       */
+      private static Instant requireTimestampType(
+          KafkaRecord<?, ?> record, KafkaTimestampType expected, String expectedName) {
+        checkArgument(
+            record.getTimestampType() == expected,
+            "Kafka record's timestamp is not '%s' "
+                + "(topic: %s, partition %s, offset %s, timestamp type '%s')",
+            expectedName,
+            record.getTopic(),
+            record.getPartition(),
+            record.getOffset(),
+            record.getTimestampType());
+        return new Instant(record.getTimestamp());
+      }
+    }
+  }
+
   ////////////////////////////////////////////////////////////////////////////////////////////////
 
   private static final Logger LOG = LoggerFactory.getLogger(KafkaIO.class);
 
-  /**
-   * Returns a new config map which is merge of current config and updates. Verifies the updates do
-   * not includes ignored properties.
-   */
-  private static Map<String, Object> updateKafkaProperties(
-      Map<String, Object> currentConfig,
-      Map<String, String> ignoredProperties,
-      Map<String, Object> updates) {
-
-    for (String key : updates.keySet()) {
-      checkArgument(
-          !ignoredProperties.containsKey(key),
-          "No need to configure '%s'. %s",
-          key,
-          ignoredProperties.get(key));
-    }
-
-    Map<String, Object> config = new HashMap<>(currentConfig);
-    config.putAll(updates);
-
-    return config;
-  }
-
   /** Static class, prevent instantiation. */
   private KafkaIO() {}
 
@@ -1205,7 +1801,7 @@
     @Deprecated
     public WriteRecords<K, V> updateProducerProperties(Map<String, Object> configUpdates) {
+      // NOTE(review): KafkaIOUtils.updateKafkaProperties validates update keys against
+      // KafkaIOUtils.DISALLOWED_CONSUMER_PROPERTIES (consumer deserializer keys). Before this change
+      // producer updates were checked against IGNORED_PRODUCER_PROPERTIES; confirm that dropping
+      // the producer-specific check is intended.
       Map<String, Object> config =
-          updateKafkaProperties(getProducerConfig(), IGNORED_PRODUCER_PROPERTIES, configUpdates);
+          KafkaIOUtils.updateKafkaProperties(getProducerConfig(), configUpdates);
       return toBuilder().setProducerConfig(config).build();
     }
 
@@ -1217,7 +1813,7 @@
      */
     public WriteRecords<K, V> withProducerConfigUpdates(Map<String, Object> configUpdates) {
+      // NOTE(review): KafkaIOUtils.updateKafkaProperties validates update keys against
+      // KafkaIOUtils.DISALLOWED_CONSUMER_PROPERTIES (consumer deserializer keys). Before this change
+      // producer updates were checked against IGNORED_PRODUCER_PROPERTIES; confirm that dropping
+      // the producer-specific check is intended.
       Map<String, Object> config =
-          updateKafkaProperties(getProducerConfig(), IGNORED_PRODUCER_PROPERTIES, configUpdates);
+          KafkaIOUtils.updateKafkaProperties(getProducerConfig(), configUpdates);
       return toBuilder().setProducerConfig(config).build();
     }
 
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java
new file mode 100644
index 0000000..0589a05
--- /dev/null
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.common.serialization.ByteArrayDeserializer;
+
+/**
+ * Common utility functions and default configurations for {@link KafkaIO.Read} and {@link
+ * KafkaIO.ReadSourceDescriptors}.
+ */
+final class KafkaIOUtils {
+  // A set of config defaults.
+  static final Map<String, Object> DEFAULT_CONSUMER_PROPERTIES =
+      ImmutableMap.of(
+          ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG,
+          ByteArrayDeserializer.class.getName(),
+          ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
+          ByteArrayDeserializer.class.getName(),
+
+          // Use large receive buffer. Once KAFKA-3135 is fixed, this _may_ not be required.
+          // with default value of 32K, It takes multiple seconds between successful polls.
+          // All the consumer work is done inside poll(), with smaller send buffer size, it
+          // takes many polls before a 1MB chunk from the server is fully read. In my testing
+          // about half of the time select() inside kafka consumer waited for 20-30ms, though
+          // the server had lots of data in tcp send buffers on its side. Compared to default,
+          // this setting increased throughput by many fold (3-4x).
+          ConsumerConfig.RECEIVE_BUFFER_CONFIG,
+          512 * 1024,
+
+          // default to latest offset when we are not resuming.
+          ConsumerConfig.AUTO_OFFSET_RESET_CONFIG,
+          "latest",
+          // disable auto commit of offsets. we don't require group_id. could be enabled by user.
+          ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG,
+          false);
+
+  // A set of properties that are not required or don't make sense for our consumer.
+  static final Map<String, String> DISALLOWED_CONSUMER_PROPERTIES =
+      ImmutableMap.of(
+          ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "Set keyDeserializer instead",
+          ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "Set valueDeserializer instead"
+          // "group.id", "enable.auto.commit", "auto.commit.interval.ms" :
+          //     lets allow these, applications can have better resume point for restarts.
+          );
+
+  // default Kafka 0.9 Consumer supplier.
+  static final SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>
+      KAFKA_CONSUMER_FACTORY_FN = KafkaConsumer::new;
+
+  /** Static utility holder; not meant to be instantiated. */
+  private KafkaIOUtils() {}
+
+  /**
+   * Returns a new config map that merges {@code updates} into {@code currentConfig}, after
+   * verifying that no update key is one of {@link #DISALLOWED_CONSUMER_PROPERTIES}. Neither input
+   * map is modified.
+   *
+   * @throws IllegalArgumentException if an update key is a disallowed consumer property
+   */
+  static Map<String, Object> updateKafkaProperties(
+      Map<String, Object> currentConfig, Map<String, Object> updates) {
+    return updateKafkaProperties(currentConfig, DISALLOWED_CONSUMER_PROPERTIES, updates);
+  }
+
+  /**
+   * Returns a new config map that merges {@code updates} into {@code currentConfig}, after
+   * verifying that no update key appears in {@code disallowedProperties}. This lets callers that
+   * are not configuring a consumer (e.g. a producer) validate against their own set of keys.
+   * Neither input map is modified.
+   *
+   * @param currentConfig the configuration to start from
+   * @param disallowedProperties maps each rejected key to a hint appended to the error message
+   * @param updates the entries to add or overwrite
+   * @throws IllegalArgumentException if an update key is contained in {@code disallowedProperties}
+   */
+  static Map<String, Object> updateKafkaProperties(
+      Map<String, Object> currentConfig,
+      Map<String, String> disallowedProperties,
+      Map<String, Object> updates) {
+
+    for (String key : updates.keySet()) {
+      checkArgument(
+          !disallowedProperties.containsKey(key),
+          "No need to configure '%s'. %s",
+          key,
+          disallowedProperties.get(key));
+    }
+
+    Map<String, Object> config = new HashMap<>(currentConfig);
+    config.putAll(updates);
+
+    return config;
+  }
+
+  /**
+   * Builds the configuration of the secondary offset consumer.
+   *
+   * <p>Starts from a copy of {@code consumerConfig}, disables auto commit and assigns a randomized
+   * group id built from {@code name} and the main consumer's group id, so that the offset consumer
+   * does not interfere with the main consumer. Then applies {@code offsetConfig} (which may
+   * override those settings), and finally forces {@code read_uncommitted} isolation.
+   *
+   * @param name prefix of the generated group id
+   * @param offsetConfig user overrides for the offset consumer; may be null
+   * @param consumerConfig configuration of the main consumer; not modified
+   * @return a new mutable configuration map
+   */
+  static Map<String, Object> getOffsetConsumerConfig(
+      String name, Map<String, Object> offsetConfig, Map<String, Object> consumerConfig) {
+    // override group_id and disable auto_commit so that it does not interfere with main consumer
+    Map<String, Object> offsetConsumerConfig = new HashMap<>(consumerConfig);
+    offsetConsumerConfig.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false);
+
+    Object groupId = consumerConfig.get(ConsumerConfig.GROUP_ID_CONFIG);
+    String offsetGroupId =
+        String.format(
+            "%s_offset_consumer_%d_%s",
+            name, (new Random()).nextInt(Integer.MAX_VALUE), (groupId == null ? "none" : groupId));
+    offsetConsumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, offsetGroupId);
+
+    if (offsetConfig != null) {
+      offsetConsumerConfig.putAll(offsetConfig);
+    }
+
+    // Force read isolation level to 'read_uncommitted' for offset consumer. This consumer
+    // fetches latest offset for two reasons : (a) to calculate backlog (number of records
+    // yet to be consumed) (b) to advance watermark if the backlog is zero. The right thing to do
+    // for (a) is to leave this config unchanged from the main config (i.e. if there are records
+    // that can't be read because of uncommitted records before them, they shouldn't
+    // ideally count towards backlog when "read_committed" is enabled. But (b)
+    // requires finding out if there are any records left to be read (committed or uncommitted).
+    // Rather than using two separate consumers we will go with better support for (b). If we do
+    // hit a case where a lot of records are not readable (due to some stuck transactions), the
+    // pipeline would report more backlog, but would not be able to consume it. It might be ok
+    // since CPU consumed on the workers would be low and will likely avoid unnecessary upscale.
+    offsetConsumerConfig.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_uncommitted");
+
+    return offsetConsumerConfig;
+  }
+
+  /**
+   * Maintains an approximate average over the last {@value #MOVING_AVG_WINDOW} updates: until
+   * that many values have been seen it is the exact running mean, afterwards each new value is
+   * weighted by {@code 1 / MOVING_AVG_WINDOW}.
+   */
+  static class MovingAvg {
+    private static final int MOVING_AVG_WINDOW = 1000;
+    private double avg = 0;
+    private long numUpdates = 0;
+
+    void update(double quantity) {
+      numUpdates++;
+      avg += (quantity - avg) / Math.min(MOVING_AVG_WINDOW, numUpdates);
+    }
+
+    double get() {
+      return avg;
+    }
+  }
+}
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSourceDescriptor.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSourceDescriptor.java
new file mode 100644
index 0000000..d2027eb
--- /dev/null
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSourceDescriptor.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import com.google.auto.value.AutoValue;
+import java.io.Serializable;
+import java.util.List;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.schemas.AutoValueSchema;
+import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
+import org.apache.beam.sdk.schemas.annotations.SchemaFieldName;
+import org.apache.beam.sdk.schemas.annotations.SchemaIgnore;
+import org.apache.kafka.common.TopicPartition;
+import org.joda.time.Instant;
+
+/**
+ * Describes one Kafka topic partition to read, optionally with a start offset, a start time and
+ * the bootstrap servers to use. The schema is inferred by {@link AutoValueSchema}, with field
+ * names taken from the {@link SchemaFieldName} annotations.
+ */
+@DefaultSchema(AutoValueSchema.class)
+@AutoValue
+public abstract class KafkaSourceDescriptor implements Serializable {
+  @SchemaFieldName("topic")
+  abstract String getTopic();
+
+  @SchemaFieldName("partition")
+  abstract Integer getPartition();
+
+  @SchemaFieldName("start_read_offset")
+  @Nullable
+  abstract Long getStartReadOffset();
+
+  @SchemaFieldName("start_read_time")
+  @Nullable
+  abstract Instant getStartReadTime();
+
+  @SchemaFieldName("bootstrap_servers")
+  @Nullable
+  abstract List<String> getBootStrapServers();
+
+  // Lazily built cache of (topic, partition). It is null until first use and is excluded from
+  // Java serialization because it can always be rebuilt from the two schema fields.
+  @Nullable private transient TopicPartition topicPartition;
+
+  /** Returns the {@link TopicPartition} formed by {@link #getTopic()} and {@link #getPartition()}. */
+  @SchemaIgnore
+  public TopicPartition getTopicPartition() {
+    if (topicPartition == null) {
+      topicPartition = new TopicPartition(getTopic(), getPartition());
+    }
+    return topicPartition;
+  }
+
+  /**
+   * Creates a descriptor for {@code topicPartition}.
+   *
+   * @param topicPartition the topic and partition to read; must not be null
+   * @param startReadOffset offset to start reading from, or null
+   * @param startReadTime time to start reading from, or null
+   * @param bootstrapServers bootstrap servers for this partition, or null
+   */
+  public static KafkaSourceDescriptor of(
+      TopicPartition topicPartition,
+      @Nullable Long startReadOffset,
+      @Nullable Instant startReadTime,
+      @Nullable List<String> bootstrapServers) {
+    return new AutoValue_KafkaSourceDescriptor(
+        topicPartition.topic(),
+        topicPartition.partition(),
+        startReadOffset,
+        startReadTime,
+        bootstrapServers);
+  }
+}
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java
index 2a89549..d21e967 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java
@@ -23,13 +23,11 @@
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
-import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
 import java.util.Optional;
-import java.util.Random;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -123,7 +121,9 @@
     consumerPollThread.submit(this::consumerPollLoop);
 
     // offsetConsumer setup :
-    Map<String, Object> offsetConsumerConfig = getOffsetConsumerConfig();
+    Map<String, Object> offsetConsumerConfig =
+        KafkaIOUtils.getOffsetConsumerConfig(
+            name, spec.getOffsetConsumerConfig(), spec.getConsumerConfig());
 
     offsetConsumer = spec.getConsumerFactoryFn().apply(offsetConsumerConfig);
     consumerSpEL.evaluateAssign(offsetConsumer, spec.getTopicPartitions());
@@ -364,23 +364,7 @@
     return name;
   }
 
-  // Maintains approximate average over last 1000 elements
-  private static class MovingAvg {
-    private static final int MOVING_AVG_WINDOW = 1000;
-    private double avg = 0;
-    private long numUpdates = 0;
-
-    void update(double quantity) {
-      numUpdates++;
-      avg += (quantity - avg) / Math.min(MOVING_AVG_WINDOW, numUpdates);
-    }
-
-    double get() {
-      return avg;
-    }
-  }
-
-  private static class TimestampPolicyContext extends TimestampPolicy.PartitionContext {
+  static class TimestampPolicyContext extends TimestampPolicy.PartitionContext {
 
     private final long messageBacklog;
     private final Instant backlogCheckTime;
@@ -412,8 +396,9 @@
 
     private Iterator<ConsumerRecord<byte[], byte[]>> recordIter = Collections.emptyIterator();
 
-    private MovingAvg avgRecordSize = new MovingAvg();
-    private MovingAvg avgOffsetGap = new MovingAvg(); // > 0 only when log compaction is enabled.
+    private KafkaIOUtils.MovingAvg avgRecordSize = new KafkaIOUtils.MovingAvg();
+    private KafkaIOUtils.MovingAvg avgOffsetGap =
+        new KafkaIOUtils.MovingAvg(); // > 0 only when log compaction is enabled.
 
     PartitionState(
         TopicPartition partition, long nextOffset, TimestampPolicy<K, V> timestampPolicy) {
@@ -687,39 +672,6 @@
     return backlogCount;
   }
 
-  @VisibleForTesting
-  Map<String, Object> getOffsetConsumerConfig() {
-    Map<String, Object> offsetConsumerConfig = new HashMap<>(source.getSpec().getConsumerConfig());
-    offsetConsumerConfig.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false);
-
-    Object groupId = source.getSpec().getConsumerConfig().get(ConsumerConfig.GROUP_ID_CONFIG);
-    // override group_id and disable auto_commit so that it does not interfere with main consumer
-    String offsetGroupId =
-        String.format(
-            "%s_offset_consumer_%d_%s",
-            name, (new Random()).nextInt(Integer.MAX_VALUE), (groupId == null ? "none" : groupId));
-    offsetConsumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, offsetGroupId);
-
-    if (source.getSpec().getOffsetConsumerConfig() != null) {
-      offsetConsumerConfig.putAll(source.getSpec().getOffsetConsumerConfig());
-    }
-
-    // Force read isolation level to 'read_uncommitted' for offset consumer. This consumer
-    // fetches latest offset for two reasons : (a) to calculate backlog (number of records
-    // yet to be consumed) (b) to advance watermark if the backlog is zero. The right thing to do
-    // for (a) is to leave this config unchanged from the main config (i.e. if there are records
-    // that can't be read because of uncommitted records before them, they shouldn't
-    // ideally count towards backlog when "read_committed" is enabled. But (b)
-    // requires finding out if there are any records left to be read (committed or uncommitted).
-    // Rather than using two separate consumers we will go with better support for (b). If we do
-    // hit a case where a lot of records are not readable (due to some stuck transactions), the
-    // pipeline would report more backlog, but would not be able to consume it. It might be ok
-    // since CPU consumed on the workers would be low and will likely avoid unnecessary upscale.
-    offsetConsumerConfig.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_uncommitted");
-
-    return offsetConsumerConfig;
-  }
-
   @Override
   public void close() throws IOException {
     closed.set(true);
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
new file mode 100644
index 0000000..f42dd23
--- /dev/null
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors;
+import org.apache.beam.sdk.io.kafka.KafkaIOUtils.MovingAvg;
+import org.apache.beam.sdk.io.kafka.KafkaUnboundedReader.TimestampPolicyContext;
+import org.apache.beam.sdk.io.range.OffsetRange;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.splittabledofn.GrowableOffsetRangeTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.MonotonicallyIncreasing;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Supplier;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheLoader;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LoadingCache;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Closeables;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.serialization.Deserializer;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A SplittableDoFn which reads from {@link KafkaSourceDescriptor} and outputs {@link KafkaRecord}.
+ * By default, a {@link MonotonicallyIncreasing} watermark estimator is used to track watermark.
+ *
+ * <p>{@link ReadFromKafkaDoFn} implements the logic of reading from Kafka. The element is a {@link
+ * KafkaSourceDescriptor}, and the restriction is an {@link OffsetRange} which represents record
+ * offsets. A {@link GrowableOffsetRangeTracker} is used to track an {@link OffsetRange} ending at
+ * {@code Long.MAX_VALUE}. For a finite range, an {@link OffsetRangeTracker} is created.
+ *
+ * <h4>Initial Restriction</h4>
+ *
+ * <p>The initial range for a {@link KafkaSourceDescriptor} is defined by {@code [startOffset,
+ * Long.MAX_VALUE)} where {@code startOffset} is defined as:
+ *
+ * <ul>
+ *   <li>the {@code startReadOffset} if {@link KafkaSourceDescriptor#getStartReadOffset} is set.
+ *   <li>the first offset with a greater or equivalent timestamp if {@link
+ *       KafkaSourceDescriptor#getStartReadTime()} is set.
+ *   <li>the {@code last committed offset + 1} for the {@link Consumer#position(TopicPartition)
+ *       topic partition}.
+ * </ul>
+ *
+ * <h4>Splitting</h4>
+ *
+ * <p>TODO(BEAM-10319): Add support for initial splitting.
+ *
+ * <h4>Checkpoint and Resume Processing</h4>
+ *
+ * <p>There are 2 types of checkpoint here: self-checkpoint, which is initiated by the DoFn, and
+ * system-checkpoint, which is issued by the runner via {@link
+ * org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitRequest}. Every time the
+ * consumer gets an empty response from {@link Consumer#poll(long)}, {@link ReadFromKafkaDoFn} will
+ * checkpoint the current {@link KafkaSourceDescriptor} and move on to process the next element.
+ * These deferred elements will be resumed by the runner as soon as possible.
+ *
+ * <h4>Progress and Size</h4>
+ *
+ * <p>The progress is provided by the {@link GrowableOffsetRangeTracker} or {@link
+ * OffsetRangeTracker} created for each {@link KafkaSourceDescriptor}. For an infinite {@link
+ * OffsetRange}, a Kafka {@link Consumer} is used in the {@link GrowableOffsetRangeTracker} as the
+ * {@link GrowableOffsetRangeTracker.RangeEndEstimator} to poll the latest offset. Please refer to
+ * {@link ReadFromKafkaDoFn#restrictionTracker(KafkaSourceDescriptor, OffsetRange)} for details.
+ *
+ * <p>The size is computed by {@link ReadFromKafkaDoFn#getSize(KafkaSourceDescriptor, OffsetRange)}.
+ * A {@link KafkaIOUtils.MovingAvg} is used to track the average size of Kafka records.
+ *
+ * <h4>Track Watermark</h4>
+ *
+ * <p>The {@link WatermarkEstimator} is created by {@link
+ * ReadSourceDescriptors#getCreateWatermarkEstimatorFn()}. The estimated watermark is computed by
+ * this {@link WatermarkEstimator} based on output timestamps computed by {@link
+ * ReadSourceDescriptors#getExtractOutputTimestampFn()}. The default configuration is using {@link
+ * ReadSourceDescriptors#withProcessingTime()} as the {@code extractTimestampFn} and {@link
+ * ReadSourceDescriptors#withMonotonicallyIncreasingWatermarkEstimator()} as the {@link
+ * WatermarkEstimator}. When a {@link TimestampPolicyFactory} is configured, output timestamps and
+ * the watermark are computed by the resulting {@link TimestampPolicy} instead, which requires a
+ * {@link ManualWatermarkEstimator}.
+ */
+@UnboundedPerElement
+class ReadFromKafkaDoFn<K, V> extends DoFn<KafkaSourceDescriptor, KafkaRecord<K, V>> {
+
+  ReadFromKafkaDoFn(ReadSourceDescriptors transform) {
+    this.consumerConfig = transform.getConsumerConfig();
+    this.offsetConsumerConfig = transform.getOffsetConsumerConfig();
+    this.keyDeserializerProvider = transform.getKeyDeserializerProvider();
+    this.valueDeserializerProvider = transform.getValueDeserializerProvider();
+    this.consumerFactoryFn = transform.getConsumerFactoryFn();
+    this.extractOutputTimestampFn = transform.getExtractOutputTimestampFn();
+    this.createWatermarkEstimatorFn = transform.getCreateWatermarkEstimatorFn();
+    this.timestampPolicyFactory = transform.getTimestampPolicyFactory();
+  }
+
+  private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class);
+
+  private final Map<String, Object> offsetConsumerConfig;
+
+  private final SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>
+      consumerFactoryFn;
+  private final SerializableFunction<KafkaRecord<K, V>, Instant> extractOutputTimestampFn;
+  private final SerializableFunction<Instant, WatermarkEstimator<Instant>>
+      createWatermarkEstimatorFn;
+  private final TimestampPolicyFactory<K, V> timestampPolicyFactory;
+
+  // Initialized in @Setup and valid until @Teardown; the deserializers are closed in @Teardown.
+  private transient ConsumerSpEL consumerSpEL = null;
+  private transient Deserializer<K> keyDeserializerInstance = null;
+  private transient Deserializer<V> valueDeserializerInstance = null;
+
+  // Per-partition moving averages of record size and offset gap, consumed by getSize().
+  private transient LoadingCache<TopicPartition, AverageRecordSize> avgRecordSize;
+
+  private static final Duration KAFKA_POLL_TIMEOUT = Duration.millis(1000);
+
+  @VisibleForTesting final DeserializerProvider keyDeserializerProvider;
+  @VisibleForTesting final DeserializerProvider valueDeserializerProvider;
+  @VisibleForTesting final Map<String, Object> consumerConfig;
+
+  /**
+   * A {@link GrowableOffsetRangeTracker.RangeEndEstimator} which uses a Kafka {@link Consumer} to
+   * fetch the latest offset of a single {@link TopicPartition}. The fetched value is memoized for
+   * 5 seconds so repeated progress queries do not hit the broker every time.
+   *
+   * <p>The estimator owns its consumer. Callers that know when the estimator is no longer needed
+   * should {@link #close()} it; {@link #finalize()} remains as a fallback for estimators handed
+   * over to a tracker whose lifetime is controlled by the runner.
+   */
+  private static class KafkaLatestOffsetEstimator
+      implements GrowableOffsetRangeTracker.RangeEndEstimator, AutoCloseable {
+
+    private final Consumer<byte[], byte[]> offsetConsumer;
+    private final TopicPartition topicPartition;
+    private final ConsumerSpEL consumerSpEL;
+    private final Supplier<Long> memoizedBacklog;
+    // Guarded by "this"; makes close() idempotent so close() followed by finalize() is safe.
+    private boolean closed = false;
+
+    KafkaLatestOffsetEstimator(
+        Consumer<byte[], byte[]> offsetConsumer, TopicPartition topicPartition) {
+      this.offsetConsumer = offsetConsumer;
+      this.topicPartition = topicPartition;
+      this.consumerSpEL = new ConsumerSpEL();
+      this.consumerSpEL.evaluateAssign(this.offsetConsumer, ImmutableList.of(this.topicPartition));
+      memoizedBacklog =
+          Suppliers.memoizeWithExpiration(
+              () -> {
+                consumerSpEL.evaluateSeek2End(offsetConsumer, topicPartition);
+                return offsetConsumer.position(topicPartition);
+              },
+              5,
+              TimeUnit.SECONDS);
+    }
+
+    @Override
+    public long estimate() {
+      return memoizedBacklog.get();
+    }
+
+    /** Closes the underlying offset consumer. Calls after the first one are no-ops. */
+    @Override
+    public synchronized void close() {
+      if (closed) {
+        return;
+      }
+      closed = true;
+      try {
+        Closeables.close(offsetConsumer, true);
+      } catch (Exception anyException) {
+        LOG.warn("Failed to close offset consumer for {}", topicPartition, anyException);
+      }
+    }
+
+    @Override
+    protected void finalize() {
+      close();
+    }
+  }
+
+  /**
+   * Computes the initial {@code [startOffset, Long.MAX_VALUE)} restriction for the partition of
+   * {@code kafkaSourceDescriptor}, using an offset consumer that is closed before returning.
+   */
+  @GetInitialRestriction
+  public OffsetRange initialRestriction(@Element KafkaSourceDescriptor kafkaSourceDescriptor) {
+    Map<String, Object> updatedConsumerConfig =
+        overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor);
+    try (Consumer<byte[], byte[]> offsetConsumer =
+        consumerFactoryFn.apply(
+            KafkaIOUtils.getOffsetConsumerConfig(
+                "initialOffset", offsetConsumerConfig, updatedConsumerConfig))) {
+      consumerSpEL.evaluateAssign(
+          offsetConsumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition()));
+      long startOffset;
+      if (kafkaSourceDescriptor.getStartReadOffset() != null) {
+        startOffset = kafkaSourceDescriptor.getStartReadOffset();
+      } else if (kafkaSourceDescriptor.getStartReadTime() != null) {
+        startOffset =
+            consumerSpEL.offsetForTime(
+                offsetConsumer,
+                kafkaSourceDescriptor.getTopicPartition(),
+                kafkaSourceDescriptor.getStartReadTime());
+      } else {
+        startOffset = offsetConsumer.position(kafkaSourceDescriptor.getTopicPartition());
+      }
+      return new OffsetRange(startOffset, Long.MAX_VALUE);
+    }
+  }
+
+  /** Seeds the watermark estimator state with the timestamp of the input element. */
+  @GetInitialWatermarkEstimatorState
+  public Instant getInitialWatermarkEstimatorState(@Timestamp Instant currentElementTimestamp) {
+    return currentElementTimestamp;
+  }
+
+  /** Creates the watermark estimator via the configured {@code createWatermarkEstimatorFn}. */
+  @NewWatermarkEstimator
+  public WatermarkEstimator<Instant> newWatermarkEstimator(
+      @WatermarkEstimatorState Instant watermarkEstimatorState) {
+    return createWatermarkEstimatorFn.apply(watermarkEstimatorState);
+  }
+
+  /**
+   * Returns the remaining work of {@code offsetRange}: the number of remaining offsets, scaled to
+   * bytes once at least one record of the partition has been observed.
+   */
+  @GetSize
+  public double getSize(
+      @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange offsetRange)
+      throws Exception {
+    double numRecords = estimateRemainingOffsets(kafkaSourceDescriptor, offsetRange);
+    TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition();
+    // Before processing elements, we don't have a good estimated size of records and offset gap.
+    if (!avgRecordSize.asMap().containsKey(topicPartition)) {
+      return numRecords;
+    }
+    return avgRecordSize.get(topicPartition).getTotalSize(numRecords);
+  }
+
+  /**
+   * Returns the number of offsets left in {@code offsetRange}. For an unbounded range the latest
+   * offset is fetched with a short-lived consumer that is closed before returning, so that each
+   * size request does not leave an open consumer behind.
+   */
+  private double estimateRemainingOffsets(
+      KafkaSourceDescriptor kafkaSourceDescriptor, OffsetRange offsetRange) {
+    if (offsetRange.getTo() < Long.MAX_VALUE) {
+      return new OffsetRangeTracker(offsetRange).getProgress().getWorkRemaining();
+    }
+    try (KafkaLatestOffsetEstimator offsetEstimator =
+        newLatestOffsetEstimator(kafkaSourceDescriptor)) {
+      return new GrowableOffsetRangeTracker(offsetRange.getFrom(), offsetEstimator)
+          .getProgress()
+          .getWorkRemaining();
+    }
+  }
+
+  /**
+   * Creates the tracker for {@code restriction}. A bounded restriction (for example the primary
+   * half of a split) gets an {@link OffsetRangeTracker} so that it never claims offsets past its
+   * end; an unbounded one gets a {@link GrowableOffsetRangeTracker} backed by a {@link
+   * KafkaLatestOffsetEstimator}.
+   */
+  @NewTracker
+  public OffsetRangeTracker restrictionTracker(
+      @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange restriction) {
+    if (restriction.getTo() < Long.MAX_VALUE) {
+      return new OffsetRangeTracker(restriction);
+    }
+    return new GrowableOffsetRangeTracker(
+        restriction.getFrom(), newLatestOffsetEstimator(kafkaSourceDescriptor));
+  }
+
+  /** Builds a latest-offset estimator with its own offset consumer for the descriptor's partition. */
+  private KafkaLatestOffsetEstimator newLatestOffsetEstimator(
+      KafkaSourceDescriptor kafkaSourceDescriptor) {
+    Map<String, Object> updatedConsumerConfig =
+        overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor);
+    return new KafkaLatestOffsetEstimator(
+        consumerFactoryFn.apply(
+            KafkaIOUtils.getOffsetConsumerConfig(
+                "tracker-" + kafkaSourceDescriptor.getTopicPartition(),
+                offsetConsumerConfig,
+                updatedConsumerConfig)),
+        kafkaSourceDescriptor.getTopicPartition());
+  }
+
+  /**
+   * Polls records of the descriptor's partition starting at the restriction's start offset and
+   * outputs them until a claim fails ({@code stop}) or a poll returns nothing ({@code resume}).
+   */
+  @ProcessElement
+  public ProcessContinuation processElement(
+      @Element KafkaSourceDescriptor kafkaSourceDescriptor,
+      RestrictionTracker<OffsetRange, Long> tracker,
+      WatermarkEstimator watermarkEstimator,
+      OutputReceiver<KafkaRecord<K, V>> receiver) {
+    TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition();
+    Map<String, Object> updatedConsumerConfig =
+        overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor);
+    // If there is a timestampPolicyFactory, create the TimestampPolicy for the current
+    // TopicPartition. The policy then drives the watermark, which needs a manual estimator.
+    TimestampPolicy<K, V> timestampPolicy = null;
+    if (timestampPolicyFactory != null) {
+      checkState(
+          watermarkEstimator instanceof ManualWatermarkEstimator,
+          "A TimestampPolicyFactory requires a ManualWatermarkEstimator, but got %s",
+          watermarkEstimator);
+      timestampPolicy =
+          timestampPolicyFactory.createTimestampPolicy(
+              topicPartition, Optional.ofNullable(watermarkEstimator.currentWatermark()));
+    }
+    try (Consumer<byte[], byte[]> consumer = consumerFactoryFn.apply(updatedConsumerConfig)) {
+      consumerSpEL.evaluateAssign(consumer, ImmutableList.of(topicPartition));
+      long expectedOffset = tracker.currentRestriction().getFrom();
+      consumer.seek(topicPartition, expectedOffset);
+
+      while (true) {
+        ConsumerRecords<byte[], byte[]> rawRecords = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis());
+        // When there are no records available for the current TopicPartition, self-checkpoint
+        // and move to process the next element.
+        if (rawRecords.isEmpty()) {
+          return ProcessContinuation.resume();
+        }
+        for (ConsumerRecord<byte[], byte[]> rawRecord : rawRecords) {
+          if (!tracker.tryClaim(rawRecord.offset())) {
+            return ProcessContinuation.stop();
+          }
+          KafkaRecord<K, V> kafkaRecord =
+              new KafkaRecord<>(
+                  rawRecord.topic(),
+                  rawRecord.partition(),
+                  rawRecord.offset(),
+                  consumerSpEL.getRecordTimestamp(rawRecord),
+                  consumerSpEL.getRecordTimestampType(rawRecord),
+                  ConsumerSpEL.hasHeaders() ? rawRecord.headers() : null,
+                  keyDeserializerInstance.deserialize(rawRecord.topic(), rawRecord.key()),
+                  valueDeserializerInstance.deserialize(rawRecord.topic(), rawRecord.value()));
+          int recordSize =
+              (rawRecord.key() == null ? 0 : rawRecord.key().length)
+                  + (rawRecord.value() == null ? 0 : rawRecord.value().length);
+          // The gap counts offsets skipped since the previous record (e.g. compacted records).
+          avgRecordSize
+              .getUnchecked(topicPartition)
+              .update(recordSize, rawRecord.offset() - expectedOffset);
+          expectedOffset = rawRecord.offset() + 1;
+          Instant outputTimestamp;
+          if (timestampPolicy != null) {
+            TimestampPolicyContext context =
+                new TimestampPolicyContext(
+                    (long) ((HasProgress) tracker).getProgress().getWorkRemaining(), Instant.now());
+            outputTimestamp = timestampPolicy.getTimestampForRecord(context, kafkaRecord);
+            ((ManualWatermarkEstimator<?>) watermarkEstimator)
+                .setWatermark(timestampPolicy.getWatermark(context));
+          } else {
+            outputTimestamp = extractOutputTimestampFn.apply(kafkaRecord);
+          }
+          receiver.outputWithTimestamp(kafkaRecord, outputTimestamp);
+        }
+      }
+    }
+  }
+
+  @GetRestrictionCoder
+  public Coder<OffsetRange> restrictionCoder() {
+    return new OffsetRange.Coder();
+  }
+
+  /** Creates the per-instance record-size cache, the SpEL helper and the deserializers. */
+  @Setup
+  public void setup() throws Exception {
+    // Track record size and offset gap per TopicPartition for the lifetime of this instance.
+    avgRecordSize =
+        CacheBuilder.newBuilder()
+            .maximumSize(1000L)
+            .build(
+                new CacheLoader<TopicPartition, AverageRecordSize>() {
+                  @Override
+                  public AverageRecordSize load(TopicPartition topicPartition) throws Exception {
+                    return new AverageRecordSize();
+                  }
+                });
+    consumerSpEL = new ConsumerSpEL();
+    keyDeserializerInstance = keyDeserializerProvider.getDeserializer(consumerConfig, true);
+    valueDeserializerInstance = valueDeserializerProvider.getDeserializer(consumerConfig, false);
+  }
+
+  /** Closes both deserializers; a failure closing one does not prevent closing the other. */
+  @Teardown
+  public void teardown() throws Exception {
+    closeDeserializer(keyDeserializerInstance);
+    closeDeserializer(valueDeserializerInstance);
+  }
+
+  /** Closes {@code deserializer} (no-op when null), logging instead of propagating failures. */
+  private static void closeDeserializer(Deserializer<?> deserializer) {
+    try {
+      Closeables.close(deserializer, true);
+    } catch (Exception anyException) {
+      LOG.warn("Fail to close resource during teardown.", anyException);
+    }
+  }
+
+  /**
+   * Returns a copy of {@code currentConfig} whose bootstrap servers are replaced by the ones in
+   * {@code description} when that list is non-empty.
+   *
+   * @throws IllegalStateException if neither the config nor the descriptor provides servers
+   */
+  private Map<String, Object> overrideBootstrapServersConfig(
+      Map<String, Object> currentConfig, KafkaSourceDescriptor description) {
+    boolean descriptorHasServers =
+        description.getBootStrapServers() != null && description.getBootStrapServers().size() > 0;
+    checkState(
+        currentConfig.containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) || descriptorHasServers,
+        "Bootstrap servers must be set in the consumer config or in the KafkaSourceDescriptor");
+    Map<String, Object> config = new HashMap<>(currentConfig);
+    if (descriptorHasServers) {
+      config.put(
+          ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG,
+          String.join(",", description.getBootStrapServers()));
+    }
+    return config;
+  }
+
+  /** Moving averages of record size (bytes of key + value) and offset gap for one partition. */
+  private static class AverageRecordSize {
+    private final MovingAvg avgRecordSize;
+    private final MovingAvg avgRecordGap;
+
+    public AverageRecordSize() {
+      this.avgRecordSize = new MovingAvg();
+      this.avgRecordGap = new MovingAvg();
+    }
+
+    public void update(int recordSize, long gap) {
+      avgRecordSize.update(recordSize);
+      avgRecordGap.update(gap);
+    }
+
+    /**
+     * Converts a number of offsets into estimated bytes: offsets are scaled down by the average
+     * gap (each record consumes {@code 1 + gap} offsets) and multiplied by the average size.
+     */
+    public double getTotalSize(double numRecords) {
+      return avgRecordSize.get() * numRecords / (1 + avgRecordGap.get());
+    }
+  }
+}
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
index 3e44c17..aa9fa4b 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
@@ -29,7 +29,6 @@
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.core.construction.PipelineTranslation;
-import org.apache.beam.runners.core.construction.ReadTranslation;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
@@ -137,25 +136,60 @@
 
     RunnerApi.PTransform kafkaComposite =
         result.getComponents().getTransformsOrThrow(transform.getSubtransforms(0));
-    RunnerApi.PTransform kafkaRead =
-        result.getComponents().getTransformsOrThrow(kafkaComposite.getSubtransforms(0));
-    RunnerApi.ReadPayload readPayload =
-        RunnerApi.ReadPayload.parseFrom(kafkaRead.getSpec().getPayload());
-    KafkaUnboundedSource source =
-        (KafkaUnboundedSource) ReadTranslation.unboundedSourceFromProto(readPayload);
-    KafkaIO.Read spec = source.getSpec();
 
-    assertThat(spec.getConsumerConfig(), Matchers.is(consumerConfig));
-    assertThat(spec.getTopics(), Matchers.is(topics));
+    // KafkaIO.Read should be expanded into SDF transform.
     assertThat(
-        spec.getKeyDeserializerProvider()
-            .getDeserializer(spec.getConsumerConfig(), true)
+        kafkaComposite.getSubtransformsList(),
+        Matchers.contains(
+            "test_namespacetest/KafkaIO.Read/Impulse",
+            "test_namespacetest/KafkaIO.Read/ParDo(GenerateKafkaSourceDescriptor)",
+            "test_namespacetest/KafkaIO.Read/KafkaIO.ReadSourceDescriptors"));
+
+    // Verify the consumerConfig and topics are populated correctly to
+    // GenerateKafkaSourceDescriptor.
+    RunnerApi.PTransform generateParDo =
+        result.getComponents().getTransformsOrThrow(kafkaComposite.getSubtransforms(1));
+    KafkaIO.Read.GenerateKafkaSourceDescriptor generateDoFn =
+        (KafkaIO.Read.GenerateKafkaSourceDescriptor)
+            ParDoTranslation.getDoFn(
+                RunnerApi.ParDoPayload.parseFrom(
+                    result
+                        .getComponents()
+                        .getTransformsOrThrow(generateParDo.getSubtransforms(0))
+                        .getSpec()
+                        .getPayload()));
+    assertThat(generateDoFn.consumerConfig, Matchers.is(consumerConfig));
+    assertThat(generateDoFn.topics, Matchers.is(topics));
+
+    // Verify that the consumerConfig, keyDeserializerProvider, valueDeserializerProvider are
+    // populated correctly to the SDF.
+    RunnerApi.PTransform readViaSDF =
+        result.getComponents().getTransformsOrThrow(kafkaComposite.getSubtransforms(2));
+    RunnerApi.PTransform subTransform =
+        result.getComponents().getTransformsOrThrow(readViaSDF.getSubtransforms(0));
+
+    ReadFromKafkaDoFn readSDF =
+        (ReadFromKafkaDoFn)
+            ParDoTranslation.getDoFn(
+                RunnerApi.ParDoPayload.parseFrom(
+                    result
+                        .getComponents()
+                        .getTransformsOrThrow(subTransform.getSubtransforms(0))
+                        .getSpec()
+                        .getPayload()));
+
+    assertThat(readSDF.consumerConfig, Matchers.is(consumerConfig));
+    assertThat(
+        readSDF
+            .keyDeserializerProvider
+            .getDeserializer(readSDF.consumerConfig, true)
             .getClass()
             .getName(),
         Matchers.is(keyDeserializer));
     assertThat(
-        spec.getValueDeserializerProvider()
-            .getDeserializer(spec.getConsumerConfig(), false)
+        readSDF
+            .valueDeserializerProvider
+            .getDeserializer(readSDF.consumerConfig, false)
             .getClass()
             .getName(),
         Matchers.is(valueDeserializer));
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
index d5f9e08..bb2ea92 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
@@ -1594,49 +1594,6 @@
     }
   }
 
-  @Test
-  public void testOffsetConsumerConfigOverrides() throws Exception {
-    KafkaUnboundedReader reader1 =
-        new KafkaUnboundedReader(
-            new KafkaUnboundedSource(
-                KafkaIO.read()
-                    .withBootstrapServers("broker_1:9092,broker_2:9092")
-                    .withTopic("my_topic")
-                    .withOffsetConsumerConfigOverrides(null),
-                0),
-            null);
-    assertTrue(
-        reader1
-            .getOffsetConsumerConfig()
-            .get(ConsumerConfig.GROUP_ID_CONFIG)
-            .toString()
-            .matches(".*_offset_consumer_\\d+_none"));
-    assertEquals(
-        false, reader1.getOffsetConsumerConfig().get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG));
-    assertEquals(
-        "read_uncommitted",
-        reader1.getOffsetConsumerConfig().get(ConsumerConfig.ISOLATION_LEVEL_CONFIG));
-
-    String offsetGroupId = "group.offsetConsumer";
-    KafkaUnboundedReader reader2 =
-        new KafkaUnboundedReader(
-            new KafkaUnboundedSource(
-                KafkaIO.read()
-                    .withBootstrapServers("broker_1:9092,broker_2:9092")
-                    .withTopic("my_topic")
-                    .withOffsetConsumerConfigOverrides(
-                        ImmutableMap.of(ConsumerConfig.GROUP_ID_CONFIG, offsetGroupId)),
-                0),
-            null);
-    assertEquals(
-        offsetGroupId, reader2.getOffsetConsumerConfig().get(ConsumerConfig.GROUP_ID_CONFIG));
-    assertEquals(
-        false, reader2.getOffsetConsumerConfig().get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG));
-    assertEquals(
-        "read_uncommitted",
-        reader2.getOffsetConsumerConfig().get(ConsumerConfig.ISOLATION_LEVEL_CONFIG));
-  }
-
   private static void verifyProducerRecords(
       MockProducer<Integer, Long> mockProducer,
       String topic,
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOUtilsTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOUtilsTest.java
new file mode 100644
index 0000000..c913fa5
--- /dev/null
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOUtilsTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Map;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests of {@link KafkaIOUtils}. */
+@RunWith(JUnit4.class)
+public class KafkaIOUtilsTest {
+
+  /**
+   * Checks {@link KafkaIOUtils#getOffsetConsumerConfig}. Without overrides the group id is derived
+   * from the given name, and {@code none} stands in for the absent main group id. An explicit group
+   * id override is used verbatim. In both cases auto-commit stays disabled and the isolation level
+   * is {@code read_uncommitted}.
+   */
+  @Test
+  public void testOffsetConsumerConfigOverrides() throws Exception {
+    KafkaIO.Read defaultSpec =
+        KafkaIO.read()
+            .withBootstrapServers("broker_1:9092,broker_2:9092")
+            .withTopic("my_topic")
+            .withOffsetConsumerConfigOverrides(null);
+    Map<String, Object> defaultConfig =
+        KafkaIOUtils.getOffsetConsumerConfig(
+            "name", defaultSpec.getOffsetConsumerConfig(), defaultSpec.getConsumerConfig());
+    String generatedGroupId = defaultConfig.get(ConsumerConfig.GROUP_ID_CONFIG).toString();
+    assertTrue(generatedGroupId.matches("name_offset_consumer_\\d+_none"));
+    assertOffsetConsumerInvariants(defaultConfig);
+
+    // An explicit group id in the overrides replaces the generated one.
+    String offsetGroupId = "group.offsetConsumer";
+    KafkaIO.Read overriddenSpec =
+        KafkaIO.read()
+            .withBootstrapServers("broker_1:9092,broker_2:9092")
+            .withTopic("my_topic")
+            .withOffsetConsumerConfigOverrides(
+                ImmutableMap.of(ConsumerConfig.GROUP_ID_CONFIG, offsetGroupId));
+    Map<String, Object> overriddenConfig =
+        KafkaIOUtils.getOffsetConsumerConfig(
+            "name2", overriddenSpec.getOffsetConsumerConfig(), overriddenSpec.getConsumerConfig());
+    assertEquals(offsetGroupId, overriddenConfig.get(ConsumerConfig.GROUP_ID_CONFIG));
+    assertOffsetConsumerInvariants(overriddenConfig);
+  }
+
+  /** Asserts the settings every offset consumer config carries regardless of overrides. */
+  private static void assertOffsetConsumerInvariants(Map<String, Object> offsetConfig) {
+    assertEquals(false, offsetConfig.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG));
+    assertEquals("read_uncommitted", offsetConfig.get(ConsumerConfig.ISOLATION_LEVEL_CONFIG));
+  }
+}