[SPARK-47793][SS][PYTHON] Implement SimpleDataSourceStreamReader for python streaming data source

### What changes were proposed in this pull request?
SimpleDataSourceStreamReader is a simplified version of the DataSourceStreamReader interface.

There are 3 functions that needs to be defined

1. Read data and return the end offset.
_def read(self, start: Offset) -> (Iterator[Tuple], Offset)_

2. Read data between start and end offset, this is required for exactly once read.
_def readBetweenOffset(self, start: Offset, end: Offset) -> Iterator[Tuple]_

3. initial start offset of the streaming query.
_def initialOffset() -> dict_

The implementation wrap the SimpleDataSourceStreamReader instance in a DataSourceStreamReader that prefetch and cache data in latestOffset. The record prefetched in python process will be sent to JVM as arrow record batches in planInputPartitions() and cached by block manager and read by partition reader from executor later..

### Why are the changes needed?
Compared to DataSourceStreamReader interface, the simplified interface has some advantages.
It doesn’t require developers to reason about data partitioning.
It doesn’t require getting the latest offset before reading data.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Add unit test and integration test.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #45977 from chaoqin-li1123/simple_reader_impl.

Lead-authored-by: Chaoqin Li <chaoqin.li@databricks.com>
Co-authored-by: chaoqin-li1123 <55518381+chaoqin-li1123@users.noreply.github.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 585d9a8..6eb015d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -170,6 +170,11 @@
   override def name: String = "input-" + streamId + "-" + uniqueId
 }
 
+@DeveloperApi
+case class PythonStreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
+  override def name: String = "python-stream-" + streamId + "-" + uniqueId
+}
+
 /** Id associated with temporary local data managed as blocks. Not serializable. */
 private[spark] case class TempLocalBlockId(id: UUID) extends BlockId {
   override def name: String = "temp_local_" + id
@@ -213,6 +218,7 @@
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
+  val PYTHON_STREAM = "python-stream-([0-9]+)-([0-9]+)".r
   val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r
   val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r
   val TEST = "test_(.*)".r
@@ -250,6 +256,8 @@
       TaskResultBlockId(taskId.toLong)
     case STREAM(streamId, uniqueId) =>
       StreamBlockId(streamId.toInt, uniqueId.toLong)
+    case PYTHON_STREAM(streamId, uniqueId) =>
+      PythonStreamBlockId(streamId.toInt, uniqueId.toLong)
     case TEMP_LOCAL(uuid) =>
       TempLocalBlockId(UUID.fromString(uuid))
     case TEMP_SHUFFLE(uuid) =>
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index c08b5b7..6cac7e3 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -183,11 +183,36 @@
             message_parameters={"feature": "streamWriter"},
         )
 
+    def simpleStreamReader(self, schema: StructType) -> "SimpleDataSourceStreamReader":
+        """
+        Returns a :class:`SimpleDataSourceStreamReader` instance for reading data.
+
+        One of simpleStreamReader() and streamReader() must be implemented for readable streaming
+        data source. Spark will check whether streamReader() is implemented, if yes, create a
+        DataSourceStreamReader to read data. simpleStreamReader() will only be invoked when
+        streamReader() is not implemented.
+
+        Parameters
+        ----------
+        schema : :class:`StructType`
+            The schema of the data to be read.
+
+        Returns
+        -------
+        reader : :class:`SimpleDataSourceStreamReader`
+            A reader instance for this data source.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "simpleStreamReader"},
+        )
+
     def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
         """
         Returns a :class:`DataSourceStreamReader` instance for reading streaming data.
 
-        The implementation is required for readable streaming data sources.
+        One of simpleStreamReader() and streamReader() must be implemented for readable streaming
+        data source.
 
         Parameters
         ----------
@@ -396,8 +421,10 @@
 
     def partitions(self, start: dict, end: dict) -> Sequence[InputPartition]:
         """
-        Returns a list of InputPartition  given the start and end offsets. Each InputPartition
-        represents a data split that can be processed by one Spark task.
+        Returns a list of InputPartition given the start and end offsets. Each InputPartition
+        represents a data split that can be processed by one Spark task. This may be called with
+        an empty offset range when start == end, in that case the method should return
+        an empty sequence of InputPartition.
 
         Parameters
         ----------
@@ -469,6 +496,102 @@
         ...
 
 
+class SimpleDataSourceStreamReader(ABC):
+    """
+    A base class for simplified streaming data source readers.
+    Compared to :class:`DataSourceStreamReader`, :class:`SimpleDataSourceStreamReader` doesn't
+    require planning data partition. Also, the read api of :class:`SimpleDataSourceStreamReader`
+    allows reading data and planning the latest offset at the same time.
+
+    Because  :class:`SimpleDataSourceStreamReader` read records in Spark driver node to determine
+    end offset of each batch without partitioning, it is only supposed to be used in
+    lightweight use cases where input rate and batch size is small.
+    Use :class:`DataSourceStreamReader` when read throughput is high and can't be handled
+    by a single process.
+
+    .. versionadded: 4.0.0
+    """
+
+    def initialOffset(self) -> dict:
+        """
+        Return the initial offset of the streaming data source.
+        A new streaming query starts reading data from the initial offset.
+        If Spark is restarting an existing query, it will restart from the check-pointed offset
+        rather than the initial one.
+
+        Returns
+        -------
+        dict
+            A dict or recursive dict whose key and value are primitive types, which includes
+            Integer, String and Boolean.
+
+        Examples
+        --------
+        >>> def initialOffset(self):
+        ...     return {"parititon-1": {"index": 3, "closed": True}, "partition-2": {"index": 5}}
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "initialOffset"},
+        )
+
+    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
+        """
+        Read all available data from start offset and return the offset that next read attempt
+        starts from.
+
+        Parameters
+        ----------
+        start : dict
+            The start offset to start reading from.
+
+        Returns
+        -------
+        A :class:`Tuple` of an iterator of :class:`Tuple` and a dict\\s
+            The iterator contains all the available records after start offset.
+            The dict is the end offset of this read attempt and the start of next read attempt.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "read"},
+        )
+
+    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
+        """
+        Read all available data from specific start offset and end offset.
+        This is invoked during failure recovery to re-read a batch deterministically.
+
+        Parameters
+        ----------
+        start : dict
+            The start offset to start reading from.
+
+        end : dict
+            The offset where the reading stop.
+
+        Returns
+        -------
+        iterator of :class:`Tuple`\\s
+            All the records between start offset and end offset.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "readBetweenOffsets"},
+        )
+
+    def commit(self, end: dict) -> None:
+        """
+        Informs the source that Spark has completed processing all data for offsets less than or
+        equal to `end` and will only request offsets greater than `end` in the future.
+
+        Parameters
+        ----------
+        end : dict
+            The latest offset that the streaming query has processed for this source.
+        """
+        ...
+
+
 class DataSourceWriter(ABC):
     """
     A base class for data source writers. Data source writers are responsible for saving
diff --git a/python/pyspark/sql/datasource_internal.py b/python/pyspark/sql/datasource_internal.py
new file mode 100644
index 0000000..6df0be4
--- /dev/null
+++ b/python/pyspark/sql/datasource_internal.py
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    InputPartition,
+    SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) -> "DataSourceStreamReader":
+    """
+    Fallback to simpleStreamReader() method when streamReader() is not implemented.
+    This should be invoked whenever a DataSourceStreamReader needs to be created instead of
+    invoking datasource.streamReader() directly.
+    """
+    try:
+        return datasource.streamReader(schema=schema)
+    except PySparkNotImplementedError:
+        return _SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+    def __init__(self, start: dict, end: dict):
+        self.start = start
+        self.end = end
+
+
+class PrefetchedCacheEntry:
+    def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+        self.start = start
+        self.end = end
+        self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+    """
+    A private class that wrap :class:`SimpleDataSourceStreamReader` in prefetch and cache pattern,
+    so that :class:`SimpleDataSourceStreamReader` can integrate with streaming engine like an
+    ordinary :class:`DataSourceStreamReader`.
+
+    current_offset tracks the latest progress of the record prefetching, it is initialized to be
+    initialOffset() when query start for the first time or initialized to be the end offset of
+    the last planned batch when query restarts.
+
+    When streaming engine calls latestOffset(), the wrapper calls read() that starts from
+    current_offset, prefetches and cache the data, then updates the current_offset to be
+    the end offset of the new data.
+
+    When streaming engine call planInputPartitions(start, end), the wrapper get the prefetched data
+    from cache and send it to JVM along with the input partitions.
+
+    When query restart, batches in write ahead offset log that has not been committed will be
+    replayed by reading data between start and end offset through readBetweenOffsets(start, end).
+    """
+
+    def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+        self.simple_reader = simple_reader
+        self.initial_offset: Optional[dict] = None
+        self.current_offset: Optional[dict] = None
+        self.cache: List[PrefetchedCacheEntry] = []
+
+    def initialOffset(self) -> dict:
+        if self.initial_offset is None:
+            self.initial_offset = self.simple_reader.initialOffset()
+        return self.initial_offset
+
+    def latestOffset(self) -> dict:
+        # when query start for the first time, use initial offset as the start offset.
+        if self.current_offset is None:
+            self.current_offset = self.initialOffset()
+        (iter, end) = self.simple_reader.read(self.current_offset)
+        self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
+        self.current_offset = end
+        return end
+
+    def commit(self, end: dict) -> None:
+        if self.current_offset is None:
+            self.current_offset = end
+
+        end_idx = -1
+        for idx, entry in enumerate(self.cache):
+            if json.dumps(entry.end) == json.dumps(end):
+                end_idx = idx
+                break
+        if end_idx > 0:
+            # Drop prefetched data for batch that has been committed.
+            self.cache = self.cache[end_idx:]
+        self.simple_reader.commit(end)
+
+    def partitions(self, start: dict, end: dict) -> Sequence["InputPartition"]:
+        # when query restart from checkpoint, use the last committed offset as the start offset.
+        # This depends on the streaming engine calling planInputPartitions() of the last batch
+        # in offset log when query restart.
+        if self.current_offset is None:
+            self.current_offset = end
+        if len(self.cache) > 0:
+            assert self.cache[-1].end == end
+        return [SimpleInputPartition(start, end)]
+
+    def getCache(self, start: dict, end: dict) -> Iterator[Tuple]:
+        start_idx = -1
+        end_idx = -1
+        for idx, entry in enumerate(self.cache):
+            # There is no convenient way to compare 2 offsets.
+            # Serialize into json string before comparison.
+            if json.dumps(entry.start) == json.dumps(start):
+                start_idx = idx
+            if json.dumps(entry.end) == json.dumps(end):
+                end_idx = idx
+                break
+        if start_idx == -1 or end_idx == -1:
+            return None  # type: ignore[return-value]
+        # Chain all the data iterator between start offset and end offset
+        # need to copy here to avoid exhausting the original data iterator.
+        entries = [copy.copy(entry.iterator) for entry in self.cache[start_idx : end_idx + 1]]
+        it = chain(*entries)
+        return it
+
+    def read(
+        self, input_partition: SimpleInputPartition  # type: ignore[override]
+    ) -> Iterator[Tuple]:
+        return self.simple_reader.readBetweenOffsets(input_partition.start, input_partition.end)
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index 8109403..946344f 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -18,9 +18,10 @@
 import os
 import sys
 import json
-from typing import IO
+from typing import IO, Iterator, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.errors import IllegalArgumentException, PySparkAssertionError, PySparkRuntimeError
 from pyspark.serializers import (
     read_int,
@@ -29,11 +30,14 @@
     SpecialLengths,
 )
 from pyspark.sql.datasource import DataSource, DataSourceStreamReader
+from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper, _streamReader
+from pyspark.sql.pandas.serializers import ArrowStreamSerializer
 from pyspark.sql.types import (
     _parse_datatype_json_string,
     StructType,
 )
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.sql.worker.plan_data_source_read import records_to_arrow_batches
+from pyspark.util import handle_worker_exception
 from pyspark.worker_util import (
     check_python_version,
     read_command,
@@ -49,6 +53,10 @@
 PARTITIONS_FUNC_ID = 886
 COMMIT_FUNC_ID = 887
 
+PREFETCHED_RECORDS_NOT_FOUND = 0
+NON_EMPTY_PYARROW_RECORD_BATCHES = 1
+EMPTY_PYARROW_RECORD_BATCHES = 2
+
 
 def initial_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
     offset = reader.initialOffset()
@@ -60,7 +68,14 @@
     write_with_length(json.dumps(offset).encode("utf-8"), outfile)
 
 
-def partitions_func(reader: DataSourceStreamReader, infile: IO, outfile: IO) -> None:
+def partitions_func(
+    reader: DataSourceStreamReader,
+    data_source: DataSource,
+    schema: StructType,
+    max_arrow_batch_size: int,
+    infile: IO,
+    outfile: IO,
+) -> None:
     start_offset = json.loads(utf8_deserializer.loads(infile))
     end_offset = json.loads(utf8_deserializer.loads(infile))
     partitions = reader.partitions(start_offset, end_offset)
@@ -68,6 +83,14 @@
     write_int(len(partitions), outfile)
     for partition in partitions:
         pickleSer._write_with_length(partition, outfile)
+    if isinstance(reader, _SimpleStreamReaderWrapper):
+        it = reader.getCache(start_offset, end_offset)
+        if it is None:
+            write_int(PREFETCHED_RECORDS_NOT_FOUND, outfile)
+        else:
+            send_batch_func(it, outfile, schema, max_arrow_batch_size, data_source)
+    else:
+        write_int(PREFETCHED_RECORDS_NOT_FOUND, outfile)
 
 
 def commit_func(reader: DataSourceStreamReader, infile: IO, outfile: IO) -> None:
@@ -76,6 +99,23 @@
     write_int(0, outfile)
 
 
+def send_batch_func(
+    rows: Iterator[Tuple],
+    outfile: IO,
+    schema: StructType,
+    max_arrow_batch_size: int,
+    data_source: DataSource,
+) -> None:
+    batches = list(records_to_arrow_batches(rows, max_arrow_batch_size, schema, data_source))
+    if len(batches) != 0:
+        write_int(NON_EMPTY_PYARROW_RECORD_BATCHES, outfile)
+        write_int(SpecialLengths.START_ARROW_STREAM, outfile)
+        serializer = ArrowStreamSerializer()
+        serializer.dump_stream(batches, outfile)
+    else:
+        write_int(EMPTY_PYARROW_RECORD_BATCHES, outfile)
+
+
 def main(infile: IO, outfile: IO) -> None:
     try:
         check_python_version(infile)
@@ -110,9 +150,15 @@
                 },
             )
 
+        max_arrow_batch_size = read_int(infile)
+        assert max_arrow_batch_size > 0, (
+            "The maximum arrow batch size should be greater than 0, but got "
+            f"'{max_arrow_batch_size}'"
+        )
+
         # Instantiate data source reader.
         try:
-            reader = data_source.streamReader(schema=schema)
+            reader = _streamReader(data_source, schema)
             # Initialization succeed.
             write_int(0, outfile)
             outfile.flush()
@@ -125,7 +171,9 @@
                 elif func_id == LATEST_OFFSET_FUNC_ID:
                     latest_offset_func(reader, outfile)
                 elif func_id == PARTITIONS_FUNC_ID:
-                    partitions_func(reader, infile, outfile)
+                    partitions_func(
+                        reader, data_source, schema, max_arrow_batch_size, infile, outfile
+                    )
                 elif func_id == COMMIT_FUNC_ID:
                     commit_func(reader, infile, outfile)
                 else:
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index 8a8b2ca..be7ebd2 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -18,8 +18,9 @@
 import os
 import sys
 import functools
+import pyarrow as pa
 from itertools import islice
-from typing import IO, List, Iterator, Iterable
+from typing import IO, List, Iterator, Iterable, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
@@ -32,6 +33,7 @@
 from pyspark.sql import Row
 from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
 from pyspark.sql.datasource import DataSource, InputPartition
+from pyspark.sql.datasource_internal import _streamReader
 from pyspark.sql.pandas.types import to_arrow_schema
 from pyspark.sql.types import (
     _parse_datatype_json_string,
@@ -51,6 +53,78 @@
 )
 
 
+def records_to_arrow_batches(
+    output_iter: Iterator[Tuple],
+    max_arrow_batch_size: int,
+    return_type: StructType,
+    data_source: DataSource,
+) -> Iterable[pa.RecordBatch]:
+    """
+    Convert an iterator of Python tuples to an iterator of pyarrow record batches.
+
+    For each python tuple, check the types of each field and append it to the records batch.
+
+    """
+
+    def batched(iterator: Iterator, n: int) -> Iterator:
+        return iter(functools.partial(lambda it: list(islice(it, n)), iterator), [])
+
+    pa_schema = to_arrow_schema(return_type)
+    column_names = return_type.fieldNames()
+    column_converters = [
+        LocalDataToArrowConversion._create_converter(field.dataType) for field in return_type.fields
+    ]
+    # Convert the results from the `reader.read` method to an iterator of arrow batches.
+    num_cols = len(column_names)
+    col_mapping = {name: i for i, name in enumerate(column_names)}
+    col_name_set = set(column_names)
+    for batch in batched(output_iter, max_arrow_batch_size):
+        pylist: List[List] = [[] for _ in range(num_cols)]
+        for result in batch:
+            # Validate the output row schema.
+            if hasattr(result, "__len__") and len(result) != num_cols:
+                raise PySparkRuntimeError(
+                    error_class="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
+                    message_parameters={
+                        "expected": str(num_cols),
+                        "actual": str(len(result)),
+                    },
+                )
+
+            # Validate the output row type.
+            if not isinstance(result, (list, tuple)):
+                raise PySparkRuntimeError(
+                    error_class="DATA_SOURCE_INVALID_RETURN_TYPE",
+                    message_parameters={
+                        "type": type(result).__name__,
+                        "name": data_source.name(),
+                        "supported_types": "tuple, list, `pyspark.sql.types.Row`",
+                    },
+                )
+
+            # Assign output values by name of the field, not position, if the result is a
+            # named `Row` object.
+            if isinstance(result, Row) and hasattr(result, "__fields__"):
+                # Check if the names are the same as the schema.
+                if set(result.__fields__) != col_name_set:
+                    raise PySparkRuntimeError(
+                        error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+                        message_parameters={
+                            "expected": str(column_names),
+                            "actual": str(result.__fields__),
+                        },
+                    )
+                # Assign the values by name.
+                for name in column_names:
+                    idx = col_mapping[name]
+                    pylist[idx].append(column_converters[idx](result[name]))
+            else:
+                for col in range(num_cols):
+                    pylist[col].append(column_converters[col](result[col]))
+        batch = pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
+        yield batch
+
+
 def main(infile: IO, outfile: IO) -> None:
     """
     Main method for planning a data source read.
@@ -131,25 +205,16 @@
 
         # Instantiate data source reader.
         reader = (
-            data_source.streamReader(schema=schema)
+            _streamReader(data_source, schema)
             if is_streaming
             else data_source.reader(schema=schema)
         )
 
-        # Wrap the data source read logic in an mapInArrow UDF.
-        import pyarrow as pa
-
         # Create input converter.
         converter = ArrowTableToRowsConversion._create_converter(BinaryType())
 
         # Create output converter.
         return_type = schema
-        pa_schema = to_arrow_schema(return_type)
-        column_names = return_type.fieldNames()
-        column_converters = [
-            LocalDataToArrowConversion._create_converter(field.dataType)
-            for field in return_type.fields
-        ]
 
         def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
             partition_bytes = None
@@ -189,58 +254,9 @@
                     },
                 )
 
-            def batched(iterator: Iterator, n: int) -> Iterator:
-                return iter(functools.partial(lambda it: list(islice(it, n)), iterator), [])
-
-            # Convert the results from the `reader.read` method to an iterator of arrow batches.
-            num_cols = len(column_names)
-            col_mapping = {name: i for i, name in enumerate(column_names)}
-            col_name_set = set(column_names)
-            for batch in batched(output_iter, max_arrow_batch_size):
-                pylist: List[List] = [[] for _ in range(num_cols)]
-                for result in batch:
-                    # Validate the output row schema.
-                    if hasattr(result, "__len__") and len(result) != num_cols:
-                        raise PySparkRuntimeError(
-                            error_class="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
-                            message_parameters={
-                                "expected": str(num_cols),
-                                "actual": str(len(result)),
-                            },
-                        )
-
-                    # Validate the output row type.
-                    if not isinstance(result, (list, tuple)):
-                        raise PySparkRuntimeError(
-                            error_class="DATA_SOURCE_INVALID_RETURN_TYPE",
-                            message_parameters={
-                                "type": type(result).__name__,
-                                "name": data_source.name(),
-                                "supported_types": "tuple, list, `pyspark.sql.types.Row`",
-                            },
-                        )
-
-                    # Assign output values by name of the field, not position, if the result is a
-                    # named `Row` object.
-                    if isinstance(result, Row) and hasattr(result, "__fields__"):
-                        # Check if the names are the same as the schema.
-                        if set(result.__fields__) != col_name_set:
-                            raise PySparkRuntimeError(
-                                error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
-                                message_parameters={
-                                    "expected": str(column_names),
-                                    "actual": str(result.__fields__),
-                                },
-                            )
-                        # Assign the values by name.
-                        for name in column_names:
-                            idx = col_mapping[name]
-                            pylist[idx].append(column_converters[idx](result[name]))
-                    else:
-                        for col in range(num_cols):
-                            pylist[col].append(column_converters[col](result[col]))
-
-                yield pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
+            return records_to_arrow_batches(
+                output_iter, max_arrow_batch_size, return_type, data_source
+            )
 
         command = (data_source_read_func, return_type)
         pickleSer._write_with_length(command, outfile)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index 71e6c29..0fc1df4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -16,12 +16,15 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.python
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
-import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
+import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, MicroBatchStream, Offset}
+import org.apache.spark.sql.execution.datasources.v2.python.PythonMicroBatchStream.nextStreamId
 import org.apache.spark.sql.execution.python.PythonStreamingSourceRunner
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.storage.{PythonStreamBlockId, StorageLevel}
 
 case class PythonStreamingSourceOffset(json: String) extends Offset
 
@@ -30,11 +33,22 @@
     shortName: String,
     outputSchema: StructType,
     options: CaseInsensitiveStringMap
-  ) extends MicroBatchStream with Logging {
+  )
+  extends MicroBatchStream
+  with Logging
+  with AcceptsLatestSeenOffset {
   private def createDataSourceFunc =
     ds.source.createPythonFunction(
       ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)).dataSource)
 
+  private val streamId = nextStreamId
+  private var nextBlockId = 0L
+
+  // planInputPartitions() maybe be called multiple times for the current microbatch.
+  // Cache the result of planInputPartitions() because it may involve sending data
+  // from python to JVM.
+  private var cachedInputPartition: Option[(String, String, PythonStreamingInputPartition)] = None
+
   private val runner: PythonStreamingSourceRunner =
     new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
   runner.init()
@@ -44,9 +58,35 @@
   override def latestOffset(): Offset = PythonStreamingSourceOffset(runner.latestOffset())
 
   override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = {
-    runner.partitions(start.asInstanceOf[PythonStreamingSourceOffset].json,
-      end.asInstanceOf[PythonStreamingSourceOffset].json)
-      .zipWithIndex.map(p => PythonInputPartition(p._2, p._1))
+    val startOffsetJson = start.asInstanceOf[PythonStreamingSourceOffset].json
+    val endOffsetJson = end.asInstanceOf[PythonStreamingSourceOffset].json
+
+    if (cachedInputPartition.exists(p => p._1 == startOffsetJson && p._2 == endOffsetJson)) {
+      return Array(cachedInputPartition.get._3)
+    }
+
+    val (partitions, rows) = runner.partitions(startOffsetJson, endOffsetJson)
+    if (rows.isDefined) {
+      // Only SimpleStreamReader without partitioning prefetch data.
+      assert(partitions.length == 1)
+      nextBlockId = nextBlockId + 1
+      val blockId = PythonStreamBlockId(streamId, nextBlockId)
+      SparkEnv.get.blockManager.putIterator(
+        blockId, rows.get, StorageLevel.MEMORY_AND_DISK_SER, true)
+      val partition = PythonStreamingInputPartition(0, partitions.head, Some(blockId))
+      cachedInputPartition.foreach(_._3.dropCache())
+      cachedInputPartition = Some((startOffsetJson, endOffsetJson, partition))
+      Array(partition)
+    } else {
+      partitions.zipWithIndex
+        .map(p => PythonStreamingInputPartition(p._2, p._1, None))
+    }
+  }
+
+  override def setLatestSeenOffset(offset: Offset): Unit = {
+    // Call planPartition on python with an empty offset range to initialize the start offset
+    // for the prefetching of simple reader.
+    runner.partitions(offset.json(), offset.json())
   }
 
   private lazy val readInfo: PythonDataSourceReadInfo = {
@@ -57,7 +97,7 @@
   }
 
   override def createReaderFactory(): PartitionReaderFactory = {
-    new PythonPartitionReaderFactory(
+    new PythonStreamingPartitionReaderFactory(
       ds.source, readInfo.func, outputSchema, None)
   }
 
@@ -66,9 +106,18 @@
   }
 
   override def stop(): Unit = {
+    cachedInputPartition.foreach(_._3.dropCache())
     runner.stop()
   }
 
   override def deserializeOffset(json: String): Offset = PythonStreamingSourceOffset(json)
 }
 
+object PythonMicroBatchStream {
+  private var currentId = 0
+  def nextStreamId: Int = synchronized {
+    currentId = currentId + 1
+    currentId
+  }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index 8fefc8b..8ebb91c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -41,6 +41,9 @@
 
   override def supportedCustomMetrics(): Array[CustomMetric] =
     ds.source.createPythonMetrics()
+
+  override def columnarSupportMode(): Scan.ColumnarSupportMode =
+    Scan.ColumnarSupportMode.UNSUPPORTED
 }
 
 class PythonBatch(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
new file mode 100644
index 0000000..75a38b8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.storage.PythonStreamBlockId
+
+
+case class PythonStreamingInputPartition(
+    index: Int,
+    pickedPartition: Array[Byte],
+    blockId: Option[PythonStreamBlockId]) extends InputPartition {
+  def dropCache(): Unit = {
+    blockId.foreach(SparkEnv.get.blockManager.master.removeBlock(_))
+  }
+}
+
+class PythonStreamingPartitionReaderFactory(
+    source: UserDefinedPythonDataSource,
+    pickledReadFunc: Array[Byte],
+    outputSchema: StructType,
+    jobArtifactUUID: Option[String])
+  extends PartitionReaderFactory with Logging {
+
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+    val part = partition.asInstanceOf[PythonStreamingInputPartition]
+
+    // Maybe read from cached block prefetched by SimpleStreamReader
+    lazy val cachedBlock = if (part.blockId.isDefined) {
+      val block = SparkEnv.get.blockManager.get[InternalRow](part.blockId.get)
+        .map(_.data.asInstanceOf[Iterator[InternalRow]])
+      if (block.isEmpty) {
+        logWarning(s"Prefetched block ${part.blockId} for Python data source not found.")
+      }
+      block
+    } else None
+
+    new PartitionReader[InternalRow] {
+
+      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics
+
+      private val outputIter = if (cachedBlock.isEmpty) {
+        // Evaluate the python read UDF if the partition is not cached as block.
+        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+          pickledReadFunc,
+          "read_from_data_source",
+          UserDefinedPythonDataSource.readInputSchema,
+          outputSchema,
+          metrics,
+          jobArtifactUUID)
+
+        evaluatorFactory.createEvaluator().eval(
+          part.index, Iterator.single(InternalRow(part.pickedPartition)))
+      } else cachedBlock.get
+
+      override def next(): Boolean = outputIter.hasNext
+
+      override def get(): InternalRow = outputIter.next()
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
index 2ef046f..a512b34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
@@ -23,14 +23,20 @@
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 
+import org.apache.arrow.vector.ipc.ArrowStreamReader
+
 import org.apache.spark.SparkEnv
 import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.PYTHON_EXEC
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 
 object PythonStreamingSourceRunner {
   // When the python process for python_streaming_source_runner receives one of the
@@ -39,6 +45,11 @@
   val LATEST_OFFSET_FUNC_ID = 885
   val PARTITIONS_FUNC_ID = 886
   val COMMIT_FUNC_ID = 887
+  // Status code for JVM to decide how to receive prefetched record batches
+  // for simple stream reader.
+  val PREFETCHED_RECORDS_NOT_FOUND = 0
+  val NON_EMPTY_PYARROW_RECORD_BATCHES = 1
+  val EMPTY_PYARROW_RECORD_BATCHES = 2
 }
 
 /**
@@ -102,6 +113,8 @@
     // Send output schema
     PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
 
+    dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
+
     dataOut.flush()
 
     dataIn = new DataInputStream(
@@ -148,7 +161,8 @@
   /**
    * Invokes partitions(start, end) function of the stream reader and receive the return value.
    */
-  def partitions(start: String, end: String): Array[Array[Byte]] = {
+  def partitions(start: String, end: String): (Array[Array[Byte]], Option[Iterator[InternalRow]]) =
+  {
     dataOut.writeInt(PARTITIONS_FUNC_ID)
     PythonWorkerUtils.writeUTF(start, dataOut)
     PythonWorkerUtils.writeUTF(end, dataOut)
@@ -165,7 +179,20 @@
       val pickledPartition: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
       pickledPartitions.append(pickledPartition)
     }
-    pickledPartitions.toArray
+    val prefetchedRecordsStatus = dataIn.readInt()
+    val iter: Option[Iterator[InternalRow]] = prefetchedRecordsStatus match {
+      case NON_EMPTY_PYARROW_RECORD_BATCHES => Some(readArrowRecordBatches())
+      case PREFETCHED_RECORDS_NOT_FOUND => None
+      case EMPTY_PYARROW_RECORD_BATCHES => Some(Iterator.empty)
+      case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+        val msg = PythonWorkerUtils.readUTF(dataIn)
+        throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+          action = "planPartitions", msg)
+      case _ =>
+        throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+          action = "planPartitions", s"unknown status code $prefetchedRecordsStatus")
+    }
+    (pickledPartitions.toArray, iter)
   }
 
   /**
@@ -200,4 +227,30 @@
         logError("Exception when trying to kill worker", e)
     }
   }
+
+  private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+    s"stream reader for $pythonExec", 0, Long.MaxValue)
+
+  def readArrowRecordBatches(): Iterator[InternalRow] = {
+    assert(dataIn.readInt() == SpecialLengths.START_ARROW_STREAM)
+    val reader = new ArrowStreamReader(dataIn, allocator)
+    val root = reader.getVectorSchemaRoot()
+    // When input is empty schema can't be read.
+    val schema = ArrowUtils.fromArrowSchema(root.getSchema())
+    assert(schema == outputSchema)
+
+    val vectors = root.getFieldVectors().asScala.map { vector =>
+      new ArrowColumnVector(vector)
+    }.toArray[ColumnVector]
+    val rows = ArrayBuffer[InternalRow]()
+    while (reader.loadNextBatch()) {
+      val batch = new ColumnarBatch(vectors)
+      batch.setNumRows(root.getRowCount)
+      // Need to copy the row because the ColumnarBatch row iterator use
+      // the same underlying Internal row.
+      rows.appendAll(batch.rowIterator().asScala.map(_.copy()))
+    }
+    reader.close(false)
+    rows.iterator
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
index 6f4bd18..97e6467 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
@@ -25,7 +25,7 @@
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, shouldTestPandasUDFs}
 import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonStreamingSourceOffset}
-import org.apache.spark.sql.execution.streaming.{MemoryStream, ProcessingTimeTrigger}
+import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, OffsetSeqLog, ProcessingTimeTrigger}
 import org.apache.spark.sql.streaming.StreamingQueryException
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -36,11 +36,11 @@
 
   val waitTimeout = 15.seconds
 
-  protected def simpleDataStreamReaderScript: String =
+  protected def testDataStreamReaderScript: String =
     """
       |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
       |
-      |class SimpleDataStreamReader(DataSourceStreamReader):
+      |class TestDataStreamReader(DataSourceStreamReader):
       |    current = 0
       |    def initialOffset(self):
       |        return {"offset": {"partition-1": 0}}
@@ -57,6 +57,43 @@
       |        yield (partition.value,)
       |""".stripMargin
 
+  protected def simpleDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import SimpleDataSourceStreamReader
+      |
+      |class SimpleDataStreamReader(SimpleDataSourceStreamReader):
+      |    def initialOffset(self):
+      |        return {"partition-1": 0}
+      |    def read(self, start: dict):
+      |        start_idx = start["partition-1"]
+      |        it = iter([(i, ) for i in range(start_idx, start_idx + 2)])
+      |        return (it, {"partition-1": start_idx + 2})
+      |    def readBetweenOffsets(self, start: dict, end: dict):
+      |        start_idx = start["partition-1"]
+      |        end_idx = end["partition-1"]
+      |        return iter([(i, ) for i in range(start_idx, end_idx)])
+      |""".stripMargin
+
+  protected def simpleDataStreamReaderWithEmptyBatchScript: String =
+    """
+      |from pyspark.sql.datasource import SimpleDataSourceStreamReader
+      |
+      |class SimpleDataStreamReader(SimpleDataSourceStreamReader):
+      |    def initialOffset(self):
+      |        return {"partition-1": 0}
+      |    def read(self, start: dict):
+      |        start_idx = start["partition-1"]
+      |        if start_idx % 4 == 0:
+      |            it = iter([(i, ) for i in range(start_idx, start_idx + 2)])
+      |        else:
+      |            it = iter([])
+      |        return (it, {"partition-1": start_idx + 2})
+      |    def readBetweenOffsets(self, start: dict, end: dict):
+      |        start_idx = start["partition-1"]
+      |        end_idx = end["partition-1"]
+      |        return iter([(i, ) for i in range(start_idx, end_idx)])
+      |""".stripMargin
+
   protected def errorDataStreamReaderScript: String =
     """
       |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
@@ -117,11 +154,11 @@
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource
-         |$simpleDataStreamReaderScript
+         |$testDataStreamReaderScript
          |
          |class $dataSourceName(DataSource):
          |    def streamReader(self, schema):
-         |        return SimpleDataStreamReader()
+         |        return TestDataStreamReader()
          |""".stripMargin
     val inputSchema = StructType.fromDDL("input BINARY")
 
@@ -144,7 +181,7 @@
     stream.stop()
   }
 
-  test("Read from simple data stream source") {
+  test("SimpleDataSourceStreamReader run query and restart") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
@@ -154,9 +191,260 @@
          |class $dataSourceName(DataSource):
          |    def schema(self) -> str:
          |        return "id INT"
-         |    def streamReader(self, schema):
+         |    def simpleStreamReader(self, schema):
          |        return SimpleDataStreamReader()
          |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val df = spark.readStream.format(dataSourceName).load()
+
+      val stopSignal1 = new CountDownLatch(1)
+
+      val q1 = df
+        .writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .foreachBatch((df: DataFrame, batchId: Long) => {
+          df.cache()
+          checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+          if (batchId == 10) stopSignal1.countDown()
+        })
+        .start()
+      stopSignal1.await()
+      assert(q1.recentProgress.forall(_.numInputRows == 2))
+      q1.stop()
+      q1.awaitTermination()
+
+      val stopSignal2 = new CountDownLatch(1)
+      val q2 = df
+        .writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .foreachBatch((df: DataFrame, batchId: Long) => {
+          df.cache()
+          checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+          if (batchId == 20) stopSignal2.countDown()
+        }).start()
+      stopSignal2.await()
+      assert(q2.recentProgress.forall(_.numInputRows == 2))
+      q2.stop()
+      q2.awaitTermination()
+    }
+  }
+
+  // Verify prefetch and cache pattern of SimpleDataSourceStreamReader handle empty
+  // data batch correctly.
+  test("SimpleDataSourceStreamReader read empty batch") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$simpleDataStreamReaderWithEmptyBatchScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def simpleStreamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val df = spark.readStream.format(dataSourceName).load()
+
+      val stopSignal = new CountDownLatch(1)
+
+      val q = df
+        .writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .foreachBatch((df: DataFrame, batchId: Long) => {
+          df.cache()
+          if (batchId % 2 == 0) {
+            checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+          } else {
+            assert(df.isEmpty)
+          }
+          if (batchId == 10) stopSignal.countDown()
+        })
+        .start()
+      stopSignal.await()
+      q.stop()
+      q.awaitTermination()
+    }
+  }
+
+  test("SimpleDataSourceStreamReader read exactly once") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$simpleDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def simpleStreamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val outputDir = new File(path, "output")
+      val df = spark.readStream.format(dataSourceName).load()
+      var lastBatch = 0
+      // Restart streaming query multiple times to verify exactly once guarantee.
+      for (i <- 1 to 5) {
+
+        if (i % 2 == 0) {
+          // Remove the last entry of commit log to test replaying microbatch during restart.
+          val offsetLog = new OffsetSeqLog(
+            spark, new File(checkpointDir, "offsets").getCanonicalPath)
+          val commitLog = new CommitLog(
+            spark, new File(checkpointDir, "commits").getCanonicalPath)
+          commitLog.purgeAfter(offsetLog.getLatest().get._1 - 1)
+        }
+
+        val q = df
+          .writeStream
+          .option("checkpointLocation", checkpointDir.getAbsolutePath)
+          .format("json")
+          .start(outputDir.getAbsolutePath)
+
+        while (q.recentProgress.length < 5) {
+          Thread.sleep(200)
+        }
+        q.stop()
+        q.awaitTermination()
+        lastBatch = q.lastProgress.batchId.toInt
+      }
+      assert(lastBatch > 20)
+      checkAnswer(spark.read.format("json").load(outputDir.getAbsolutePath),
+        (0 to  2 * lastBatch + 1).map(Row(_)))
+    }
+  }
+
+  test("initialOffset() method not implemented in SimpleDataSourceStreamReader") {
+    assume(shouldTestPandasUDFs)
+    val initialOffsetNotImplementedScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |from pyspark.sql.datasource import SimpleDataSourceStreamReader
+         |class ErrorDataStreamReader(SimpleDataSourceStreamReader):
+         |    ...
+         |
+         |class $errorDataSourceName(DataSource):
+         |    def simpleStreamReader(self, schema):
+         |        return ErrorDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource =
+      createUserDefinedPythonDataSource(errorDataSourceName, initialOffsetNotImplementedScript)
+    spark.dataSource.registerPython(errorDataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("ErrorDataSource")
+
+    def testMicroBatchStreamError(action: String, msg: String)
+                                 (func: PythonMicroBatchStream => Unit): Unit = {
+      val stream = new PythonMicroBatchStream(
+        pythonDs, errorDataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+      val err = intercept[SparkException] {
+        func(stream)
+      }
+      checkErrorMatchPVals(err,
+        errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
+        parameters = Map(
+          "action" -> action,
+          "msg" -> "(.|\\n)*"
+        ))
+      assert(err.getMessage.contains(msg))
+      assert(err.getMessage.contains("ErrorDataSource"))
+      stream.stop()
+    }
+
+    testMicroBatchStreamError(
+      "initialOffset", "[NOT_IMPLEMENTED] initialOffset is not implemented") {
+      stream => stream.initialOffset()
+    }
+
+    // User don't need to implement latestOffset for SimpleDataSourceStreamReader.
+    // The latestOffset method of simple stream reader invokes initialOffset() and read()
+    // So the not implemented method is initialOffset.
+    testMicroBatchStreamError(
+      "latestOffset", "[NOT_IMPLEMENTED] initialOffset is not implemented") {
+      stream => stream.latestOffset()
+    }
+  }
+
+  test("read() method throw error in SimpleDataSourceStreamReader") {
+    assume(shouldTestPandasUDFs)
+    val initialOffsetNotImplementedScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |from pyspark.sql.datasource import SimpleDataSourceStreamReader
+         |class ErrorDataStreamReader(SimpleDataSourceStreamReader):
+         |    def initialOffset(self):
+         |        return {"partition": 1}
+         |    def read(self, start):
+         |        raise Exception("error reading available data")
+         |
+         |class $errorDataSourceName(DataSource):
+         |    def simpleStreamReader(self, schema):
+         |        return ErrorDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource =
+      createUserDefinedPythonDataSource(errorDataSourceName, initialOffsetNotImplementedScript)
+    spark.dataSource.registerPython(errorDataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("ErrorDataSource")
+
+    def testMicroBatchStreamError(action: String, msg: String)
+                                 (func: PythonMicroBatchStream => Unit): Unit = {
+      val stream = new PythonMicroBatchStream(
+        pythonDs, errorDataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+      val err = intercept[SparkException] {
+        func(stream)
+      }
+      checkErrorMatchPVals(err,
+        errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
+        parameters = Map(
+          "action" -> action,
+          "msg" -> "(.|\\n)*"
+        ))
+      assert(err.getMessage.contains(msg))
+      assert(err.getMessage.contains("ErrorDataSource"))
+      stream.stop()
+    }
+
+    testMicroBatchStreamError(
+      "latestOffset", "Exception: error reading available data") {
+      stream => stream.latestOffset()
+    }
+  }
+
+  test("Read from test data stream source") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$testDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def streamReader(self, schema):
+         |        return TestDataStreamReader()
+         |""".stripMargin
 
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
@@ -188,7 +476,7 @@
          |        self.start = start
          |        self.end = end
          |
-         |class SimpleDataStreamReader(DataSourceStreamReader):
+         |class TestDataStreamReader(DataSourceStreamReader):
          |    current = 0
          |    def initialOffset(self):
          |        return {"offset": 0}
@@ -210,7 +498,7 @@
          |        return "id INT"
          |
          |    def streamReader(self, schema):
-         |        return SimpleDataStreamReader()
+         |        return TestDataStreamReader()
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
@@ -303,7 +591,6 @@
     assert(err.getMessage.contains("error reading data"))
   }
 
-
   test("Method not implemented in stream reader") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =