Merge pull request #12325: [BEAM-10543] Upgrade Kafka cross-language python tests

diff --git a/build.gradle b/build.gradle
index 93ca237..7f93fe2 100644
--- a/build.gradle
+++ b/build.gradle
@@ -247,7 +247,6 @@
 }
 
 task python2PostCommit() {
-  dependsOn ":sdks:python:test-suites:portable:py2:crossLanguagePythonJavaKafkaIOFlink"
   dependsOn ":sdks:python:test-suites:portable:py2:crossLanguageTests"
   dependsOn ":sdks:python:test-suites:dataflow:py2:postCommitIT"
   dependsOn ":sdks:python:test-suites:direct:py2:directRunnerIT"
@@ -275,7 +274,6 @@
 }
 
 task python38PostCommit() {
-  dependsOn ":sdks:python:test-suites:portable:py38:crossLanguagePythonJavaKafkaIOFlink"
   dependsOn ":sdks:python:test-suites:dataflow:py38:postCommitIT"
   dependsOn ":sdks:python:test-suites:direct:py38:postCommitIT"
   dependsOn ":sdks:python:test-suites:direct:py38:hdfsIntegrationTest"
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 13aabc8..08847f6 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -452,9 +452,15 @@
         // Set required defaults
         setTopicPartitions(Collections.emptyList());
         setConsumerFactoryFn(Read.KAFKA_CONSUMER_FACTORY_FN);
-        setMaxNumRecords(Long.MAX_VALUE);
+        if (config.maxReadTime != null) {
+          setMaxReadTime(Duration.standardSeconds(config.maxReadTime));
+        }
+        setMaxNumRecords(config.maxNumRecords == null ? Long.MAX_VALUE : config.maxNumRecords);
         setCommitOffsetsInFinalizeEnabled(false);
         setTimestampPolicyFactory(TimestampPolicyFactory.withProcessingTime());
+        if (config.startReadTime != null) {
+          setStartReadTime(Instant.ofEpochMilli(config.startReadTime));
+        }
         // We do not include Metadata until we can encode KafkaRecords cross-language
         return build().withoutMetadata();
       }
@@ -507,6 +513,9 @@
         private Iterable<String> topics;
         private String keyDeserializer;
         private String valueDeserializer;
+        private Long startReadTime;
+        private Long maxNumRecords;
+        private Long maxReadTime;
 
         public void setConsumerConfig(Iterable<KV<String, String>> consumerConfig) {
           this.consumerConfig = consumerConfig;
@@ -523,6 +532,18 @@
         public void setValueDeserializer(String valueDeserializer) {
           this.valueDeserializer = valueDeserializer;
         }
+
+        public void setStartReadTime(Long startReadTime) {
+          this.startReadTime = startReadTime;
+        }
+
+        public void setMaxNumRecords(Long maxNumRecords) {
+          this.maxNumRecords = maxNumRecords;
+        }
+
+        public void setMaxReadTime(Long maxReadTime) {
+          this.maxReadTime = maxReadTime;
+        }
       }
     }
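
A note on units for the builder changes above: maxReadTime is interpreted as seconds (Duration.standardSeconds) while startReadTime is epoch milliseconds (Instant.ofEpochMilli). A small Python sketch of producing matching values; the use of the time module and the one-hour lookback are illustrative assumptions, not part of this change:

    import time

    # start_read_time is epoch *milliseconds*; the Java side wraps it with
    # Instant.ofEpochMilli(config.startReadTime).
    start_read_time = int(time.time() * 1000) - 60 * 60 * 1000  # one hour ago

    # max_read_time is *seconds*; the Java side wraps it with
    # Duration.standardSeconds(config.maxReadTime).
    max_read_time = 120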
 
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
index d157c16..3e44c17 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
@@ -34,6 +34,7 @@
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.expansion.service.ExpansionService;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Impulse;
@@ -67,6 +68,7 @@
             .put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, keyDeserializer)
             .put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, valueDeserializer)
             .build();
+    Long startReadTime = 100L;
 
     ExternalTransforms.ExternalConfigurationPayload payload =
         ExternalTransforms.ExternalConfigurationPayload.newBuilder()
@@ -98,6 +100,12 @@
                     .addCoderUrn("beam:coder:string_utf8:v1")
                     .setPayload(ByteString.copyFrom(encodeString(valueDeserializer)))
                     .build())
+            .putConfiguration(
+                "start_read_time",
+                ExternalTransforms.ConfigValue.newBuilder()
+                    .addCoderUrn("beam:coder:varint:v1")
+                    .setPayload(ByteString.copyFrom(encodeLong(startReadTime)))
+                    .build())
             .build();
 
     RunnerApi.Components defaultInstance = RunnerApi.Components.getDefaultInstance();
@@ -280,6 +288,12 @@
     return baos.toByteArray();
   }
 
+  private static byte[] encodeLong(Long value) throws IOException {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    VarLongCoder.of().encode(value, baos);
+    return baos.toByteArray();
+  }
+
   private static class TestStreamObserver<T> implements StreamObserver<T> {
 
     private T result;
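
The hunk above encodes start_read_time with Java's VarLongCoder under the beam:coder:varint:v1 URN. For comparison, a rough sketch of producing the same payload bytes from the Python side, assuming the standard apache_beam.coders API (the value 100 simply mirrors the test's startReadTime):

    from apache_beam.coders import VarIntCoder

    # VarIntCoder is the Python coder behind beam:coder:varint:v1, so these
    # bytes should correspond to what the Java test produces via VarLongCoder.
    start_read_time = 100
    payload_bytes = VarIntCoder().encode(start_read_time)
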
diff --git a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
index 0dad234..cec1d9b 100644
--- a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
+++ b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
@@ -28,12 +28,17 @@
 import time
 import typing
 import unittest
+import uuid
 
 import apache_beam as beam
 from apache_beam.io.kafka import ReadFromKafka
 from apache_beam.io.kafka import WriteToKafka
 from apache_beam.metrics import Metrics
 from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+NUM_RECORDS = 1000
 
 
 class CrossLanguageKafkaIO(object):
@@ -47,7 +52,7 @@
     _ = (
         pipeline
         | 'Impulse' >> beam.Impulse()
-        | 'Generate' >> beam.FlatMap(lambda x: range(1000))  # pylint: disable=range-builtin-not-iterating
+        | 'Generate' >> beam.FlatMap(lambda x: range(NUM_RECORDS))  # pylint: disable=range-builtin-not-iterating
         | 'Reshuffle' >> beam.Reshuffle()
         | 'MakeKV' >> beam.Map(lambda x:
                                (b'', str(x).encode())).with_output_types(
@@ -57,8 +62,8 @@
             topic=self.topic,
             expansion_service=self.expansion_service))
 
-  def build_read_pipeline(self, pipeline):
-    _ = (
+  def build_read_pipeline(self, pipeline, max_num_records=None):
+    kafka_records = (
         pipeline
         | 'ReadFromKafka' >> ReadFromKafka(
             consumer_config={
@@ -66,7 +71,14 @@
                 'auto.offset.reset': 'earliest'
             },
             topics=[self.topic],
-            expansion_service=self.expansion_service)
+            max_num_records=max_num_records,
+            expansion_service=self.expansion_service))
+
+    if max_num_records:
+      return kafka_records
+
+    return (
+        kafka_records
         | 'Windowing' >> beam.WindowInto(
             beam.window.FixedWindows(300),
             trigger=beam.transforms.trigger.AfterProcessingTime(60),
@@ -86,6 +98,30 @@
     os.environ.get('LOCAL_KAFKA_JAR'),
     "LOCAL_KAFKA_JAR environment var is not provided.")
 class CrossLanguageKafkaIOTest(unittest.TestCase):
+  def test_kafkaio(self):
+    kafka_topic = 'xlang_kafkaio_test_{}'.format(uuid.uuid4())
+    local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR')
+    with self.local_kafka_service(local_kafka_jar) as kafka_port:
+      bootstrap_servers = '{}:{}'.format(
+          self.get_platform_localhost(), kafka_port)
+      pipeline_creator = CrossLanguageKafkaIO(bootstrap_servers, kafka_topic)
+
+      self.run_kafka_write(pipeline_creator)
+      self.run_kafka_read(pipeline_creator)
+
+  def run_kafka_write(self, pipeline_creator):
+    with TestPipeline() as pipeline:
+      pipeline.not_use_test_runner_api = True
+      pipeline_creator.build_write_pipeline(pipeline)
+
+  def run_kafka_read(self, pipeline_creator):
+    with TestPipeline() as pipeline:
+      pipeline.not_use_test_runner_api = True
+      result = pipeline_creator.build_read_pipeline(pipeline, NUM_RECORDS)
+      assert_that(
+          result,
+          equal_to([(b'', str(i).encode()) for i in range(NUM_RECORDS)]))
+
   def get_platform_localhost(self):
     if sys.platform == 'darwin':
       return 'host.docker.internal'
@@ -119,18 +155,6 @@
       if kafka_server:
         kafka_server.kill()
 
-  def test_kafkaio_write(self):
-    local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR')
-    with self.local_kafka_service(local_kafka_jar) as kafka_port:
-      p = TestPipeline()
-      p.not_use_test_runner_api = True
-      xlang_kafkaio = CrossLanguageKafkaIO(
-          '%s:%s' % (self.get_platform_localhost(), kafka_port),
-          'xlang_kafkaio_test')
-      xlang_kafkaio.build_write_pipeline(p)
-      job = p.run()
-      job.wait_until_finish()
-
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
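
For readers unfamiliar with the record shape, the cross-language read yields plain (key, value) byte pairs (the Java side builds with withoutMetadata()), so the expectation in run_kafka_read above expands to pairs with empty keys. A quick sketch:

    NUM_RECORDS = 1000
    expected = [(b'', str(i).encode()) for i in range(NUM_RECORDS)]
    # expected[:3] == [(b'', b'0'), (b'', b'1'), (b'', b'2')]
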
diff --git a/sdks/python/apache_beam/io/kafka.py b/sdks/python/apache_beam/io/kafka.py
index 4336bec..dc75b73 100644
--- a/sdks/python/apache_beam/io/kafka.py
+++ b/sdks/python/apache_beam/io/kafka.py
@@ -97,6 +97,9 @@
         ('topics', typing.List[unicode]),
         ('key_deserializer', unicode),
         ('value_deserializer', unicode),
+        ('start_read_time', typing.Optional[int]),
+        ('max_num_records', typing.Optional[int]),
+        ('max_read_time', typing.Optional[int]),
     ])
 
 
@@ -125,24 +128,30 @@
       topics,
       key_deserializer=byte_array_deserializer,
       value_deserializer=byte_array_deserializer,
-      expansion_service=None):
+      start_read_time=None,
+      max_num_records=None,
+      max_read_time=None,
+      expansion_service=None,
+  ):
     """
     Initializes a read operation from Kafka.
 
     :param consumer_config: A dictionary containing the consumer configuration.
     :param topics: A list of topic strings.
     :param key_deserializer: A fully-qualified Java class name of a Kafka
-                             Deserializer for the topic's key, e.g.
-                             'org.apache.kafka.common.
-                             serialization.LongDeserializer'.
-                             Default: 'org.apache.kafka.common.
-                             serialization.ByteArrayDeserializer'.
+        Deserializer for the topic's key, e.g.
+        'org.apache.kafka.common.serialization.LongDeserializer'.
+        Default: 'org.apache.kafka.common.serialization.ByteArrayDeserializer'.
     :param value_deserializer: A fully-qualified Java class name of a Kafka
-                               Deserializer for the topic's value, e.g.
-                               'org.apache.kafka.common.
-                               serialization.LongDeserializer'.
-                               Default: 'org.apache.kafka.common.
-                               serialization.ByteArrayDeserializer'.
+        Deserializer for the topic's value, e.g.
+        'org.apache.kafka.common.serialization.LongDeserializer'.
+        Default: 'org.apache.kafka.common.serialization.ByteArrayDeserializer'.
+    :param start_read_time: A timestamp, in milliseconds since the epoch, used
+        to determine the start offset to read from.
+    :param max_num_records: Maximum number of records to be read. Mainly used
+        for tests and demo applications.
+    :param max_read_time: Maximum amount of time in seconds the transform
+        executes. Mainly used for tests and demo applications.
     :param expansion_service: The address (host:port) of the ExpansionService.
     """
     super(ReadFromKafka, self).__init__(
@@ -153,6 +162,9 @@
                 topics=topics,
                 key_deserializer=key_deserializer,
                 value_deserializer=value_deserializer,
+                max_num_records=max_num_records,
+                max_read_time=max_read_time,
+                start_read_time=start_read_time,
             )),
         expansion_service or default_io_expansion_service())
 
@@ -195,17 +207,13 @@
     :param producer_config: A dictionary containing the producer configuration.
     :param topic: A Kafka topic name.
     :param key_serializer: A fully-qualified Java class name of a Kafka
-                             Serializer for the topic's key, e.g.
-                             'org.apache.kafka.common.
-                             serialization.LongSerializer'.
-                             Default: 'org.apache.kafka.common.
-                             serialization.ByteArraySerializer'.
+        Serializer for the topic's key, e.g.
+        'org.apache.kafka.common.serialization.LongSerializer'.
+        Default: 'org.apache.kafka.common.serialization.ByteArraySerializer'.
     :param value_serializer: A fully-qualified Java class name of a Kafka
-                               Serializer for the topic's value, e.g.
-                               'org.apache.kafka.common.
-                               serialization.LongSerializer'.
-                               Default: 'org.apache.kafka.common.
-                               serialization.ByteArraySerializer'.
+        Serializer for the topic's value, e.g.
+        'org.apache.kafka.common.serialization.LongSerializer'.
+        Default: 'org.apache.kafka.common.serialization.ByteArraySerializer'.
     :param expansion_service: The address (host:port) of the ExpansionService.
     """
     super(WriteToKafka, self).__init__(
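
Putting the new ReadFromKafka parameters together, a minimal bounded-read sketch; the broker address, topic name, and timestamp are placeholders rather than values from this change, and an expansion service (or the default expansion jar) still has to be available:

    import apache_beam as beam
    from apache_beam.io.kafka import ReadFromKafka

    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | 'ReadBounded' >> ReadFromKafka(
              consumer_config={
                  'bootstrap.servers': 'localhost:9092',  # placeholder broker
                  'auto.offset.reset': 'earliest',
              },
              topics=['example_topic'],       # placeholder topic
              start_read_time=1595000000000,  # epoch millis, placeholder
              max_num_records=1000,           # makes the read bounded
              max_read_time=60)               # seconds
          | 'Print' >> beam.Map(print))
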
diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle
index 2e60afa..48312a6 100644
--- a/sdks/python/test-suites/portable/common.gradle
+++ b/sdks/python/test-suites/portable/common.gradle
@@ -101,38 +101,6 @@
   }
 }
 
-task crossLanguagePythonJavaKafkaIOFlink {
-  dependsOn 'setupVirtualenv'
-  dependsOn ':runners:flink:1.10:job-server:shadowJar'
-  dependsOn ":sdks:python:container:py${pythonVersionSuffix}:docker"
-  dependsOn ':sdks:java:container:docker'
-  dependsOn ':sdks:java:io:expansion-service:shadowJar'
-  dependsOn ':sdks:java:testing:kafka-service:buildTestKafkaServiceJar'
-
-  doLast {
-    def kafkaJar = project(":sdks:java:testing:kafka-service:").buildTestKafkaServiceJar.archivePath
-    def options = [
-            "--runner=FlinkRunner",
-            "--parallelism=2",
-            "--environment_type=DOCKER",
-            "--environment_cache_millis=10000",
-            "--experiment=beam_fn_api",
-    ]
-    exec {
-      environment "LOCAL_KAFKA_JAR", kafkaJar
-      executable 'sh'
-      args '-c', """
-          . ${envdir}/bin/activate \\
-          && cd ${pythonRootDir} \\
-          && pip install -e .[test] \\
-          && python setup.py nosetests \\
-              --tests apache_beam.io.external.xlang_kafkaio_it_test:CrossLanguageKafkaIOTest \\
-              --test-pipeline-options='${options.join(' ')}'
-          """
-    }
-  }
-}
-
 task createProcessWorker {
   dependsOn ':sdks:python:container:build'
   dependsOn 'setupVirtualenv'
@@ -223,12 +191,14 @@
           ':runners:flink:1.10:job-server:shadowJar',
           ':sdks:java:container:docker',
           ':sdks:java:io:expansion-service:shadowJar',
+          ':sdks:java:testing:kafka-service:buildTestKafkaServiceJar'
   ]
 
   doLast {
     def tests = [
             "apache_beam.io.gcp.bigquery_read_it_test",
             "apache_beam.io.external.xlang_jdbcio_it_test",
+            "apache_beam.io.external.xlang_kafkaio_it_test",
     ]
     def testOpts = ["--tests=${tests.join(',')}"]
     def cmdArgs = mapToArgString([
@@ -236,7 +206,9 @@
             "suite": "postCommitIT-flink-py${pythonVersionSuffix}",
             "pipeline_opts": "--runner=FlinkRunner --project=apache-beam-testing --environment_type=LOOPBACK --temp_location=gs://temp-storage-for-end-to-end-tests/temp-it",
     ])
+    def kafkaJar = project(":sdks:java:testing:kafka-service:").buildTestKafkaServiceJar.archivePath
     exec {
+      environment "LOCAL_KAFKA_JAR", kafkaJar
       executable 'sh'
       args '-c', ". ${envdir}/bin/activate && ${pythonRootDir}/scripts/run_integration_test.sh $cmdArgs"
     }