blob: 29dc194b7f84c8412fc353d842c9dbadc13913c8 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, Tuple, Union, TYPE_CHECKING
from pyspark.sql import Row
from pyspark.sql.types import StructType
if TYPE_CHECKING:
from pyspark.sql._typing import OptionalPrimitiveType
__all__ = ["DataSource", "DataSourceReader"]
class DataSource(ABC):
"""
A base class for data sources.
This class represents a custom data source that allows for reading from and/or
writing to it. The data source provides methods to create readers and writers
for reading and writing data, respectively. At least one of the methods ``reader``
or ``writer`` must be implemented by any subclass to make the data source either
readable or writable (or both).
After implementing this interface, you can start to load your data source using
``spark.read.format(...).load()`` and save data using ``df.write.format(...).save()``.
"""
def __init__(self, options: Dict[str, "OptionalPrimitiveType"]):
"""
Initializes the data source with user-provided options.
Parameters
----------
options : dict
A dictionary representing the options for this data source.
Notes
-----
This method should not contain any non-serializable objects.
"""
self.options = options
@property
def name(self) -> str:
"""
Returns a string represents the format name of this data source.
By default, it is the class name of the data source. It can be overridden to
provide a customized short name for the data source.
Examples
--------
>>> def name(self):
... return "my_data_source"
"""
return self.__class__.__name__
def schema(self) -> Union[StructType, str]:
"""
Returns the schema of the data source.
It can reference the ``options`` field to infer the data source's schema when
users do not explicitly specify it. This method is invoked once when calling
``spark.read.format(...).load()`` to get the schema for a data source read
operation. If this method is not implemented, and a user does not provide a
schema when reading the data source, an exception will be thrown.
Returns
-------
schema : StructType or str
The schema of this data source or a DDL string represents the schema
Examples
--------
Returns a DDL string:
>>> def schema(self):
... return "a INT, b STRING"
Returns a StructType:
>>> def schema(self):
... return StructType().add("a", "int").add("b", "string")
"""
raise NotImplementedError
def reader(self, schema: StructType) -> "DataSourceReader":
"""
Returns a ``DataSourceReader`` instance for reading data.
The implementation is required for readable data sources.
Parameters
----------
schema : StructType
The schema of the data to be read.
Returns
-------
reader : DataSourceReader
A reader instance for this data source.
"""
raise NotImplementedError
class DataSourceReader(ABC):
"""
A base class for data source readers. Data source readers are responsible for
outputting data from a data source.
"""
def partitions(self) -> Iterator[Any]:
"""
Returns an iterator of partitions for this data source.
Partitions are used to split data reading operations into parallel tasks.
If this method returns N partitions, the query planner will create N tasks.
Each task will execute ``read(partition)`` in parallel, using the respective
partition value to read the data.
This method is called once during query planning. By default, it returns a
single partition with the value ``None``. Subclasses can override this method
to return multiple partitions.
It's recommended to override this method for better performance when reading
large datasets.
Returns
-------
Iterator[Any]
An iterator of partitions for this data source. The partition value can be
any serializable objects.
Notes
-----
This method should not return any un-picklable objects.
Examples
--------
Returns a list of integers:
>>> def partitions(self):
... return [1, 2, 3]
Returns a list of string:
>>> def partitions(self):
... return ["a", "b", "c"]
Returns a list of tuples:
>>> def partitions(self):
... return [("a", 1), ("b", 2), ("c", 3)]
Returns a list of dictionaries:
>>> def partitions(self):
... return [{"a": 1}, {"b": 2}, {"c": 3}]
"""
yield None
@abstractmethod
def read(self, partition: Any) -> Iterator[Union[Tuple, Row]]:
"""
Generates data for a given partition and returns an iterator of tuples or rows.
This method is invoked once per partition to read the data. Implementing
this method is required for readable data sources. You can initialize any
non-serializable resources required for reading data from the data source
within this method.
Parameters
----------
partition : object
The partition to read. It must be one of the partition values returned by
``partitions()``.
Returns
-------
Iterator[Tuple] or Iterator[Row]
An iterator of tuples or rows. Each tuple or row will be converted to a row
in the final DataFrame.
Examples
--------
Yields a list of tuples:
>>> def read(self, partition):
... yield (partition, 0)
... yield (partition, 1)
Yields a list of rows:
>>> def read(self, partition):
... yield Row(partition=partition, value=0)
... yield Row(partition=partition, value=1)
"""
...