[SPARK-47365][PYTHON] Add toArrow() DataFrame method to PySpark
### What changes were proposed in this pull request?
- Add a PySpark DataFrame method `toArrow()` which returns the contents of the DataFrame as a [PyArrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html), for both local Spark and Spark Connect.
- Add a new entry to the **Apache Arrow in PySpark** user guide page describing usage of the `toArrow()` method.
- Add a new option to the method `_collect_as_arrow()` to provide more useful output when zero records are returned. (This keeps the implementation of `toArrow()` simpler.)
### Why are the changes needed?
In the Apache Arrow community, we hear from many users who want to return the contents of a PySpark DataFrame as a PyArrow Table. Currently, the only documented way to do this is to collect the contents as a pandas DataFrame and then convert it to a PyArrow Table with PyArrow (imported here as `pa`):
```py
pa.Table.from_pandas(df.toPandas())
```
But going through pandas adds significant overhead that is easily avoided: when `spark.sql.execution.arrow.pyspark.enabled` is `true`, `toPandas()` already converts the contents of the Spark DataFrame to Arrow format as an intermediate step internally.
Currently it is also possible to use the experimental `_collect_as_arrow()` method to return the contents of a PySpark DataFrame as a list of PyArrow RecordBatches. This PR adds a new, non-experimental method `toArrow()` that returns the more user-friendly PyArrow Table object instead.
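For illustration, a minimal sketch of the new direct path (the example DataFrame here is made up; any DataFrame with Arrow-supported column types works):
```py
import pyarrow as pa

df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])

# Convert directly to a PyArrow Table; no intermediate pandas DataFrame
table = df.toArrow()
assert isinstance(table, pa.Table)
```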
This PR also adds a new argument `empty_list_if_zero_records` to the experimental method `_collect_as_arrow()` to control what the method returns when the result data has zero rows. If set to `True` (the default), the existing behavior is preserved and the method returns an empty Python list. If set to `False`, the method returns a length-one list containing an empty Arrow RecordBatch that includes the schema. `toArrow()` uses this because it requires the schema even when the data has zero rows.
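A minimal sketch of the difference, using a zero-row DataFrame (`_collect_as_arrow()` is experimental and internal):
```py
empty_df = spark.range(0)  # zero rows, single "id" column (long -> Arrow int64)

# Default behavior: zero rows yield an empty list, so the schema is lost
batches = empty_df._collect_as_arrow()
# batches == []

# New option: zero rows yield one empty RecordBatch that carries the schema
batches = empty_df._collect_as_arrow(empty_list_if_zero_records=False)
# batches[0].num_rows == 0
# batches[0].schema -> id: int64
```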
For Spark Connect, there is already a `SparkSession.client.to_table()` method that returns a PyArrow Table. This PR uses that method to expose `toArrow()` for Spark Connect.
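For reference, the Spark Connect implementation is essentially the following (a simplified view of the hunk in `python/pyspark/sql/connect/dataframe.py` below):
```py
# Simplified sketch of the Spark Connect DataFrame method added by this PR:
def toArrow(self) -> "pa.Table":
    # _to_table() executes the plan via SparkSession.client.to_table()
    # and returns a (pyarrow.Table, schema) pair; only the Table is kept
    table, _ = self._to_table()
    return table
```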
### Does this PR introduce _any_ user-facing change?
- It adds a DataFrame method `toArrow()` to the PySpark SQL DataFrame API.
- It adds a new argument `empty_list_if_zero_records` to the experimental DataFrame method `_collect_as_arrow()`, with a default value that preserves the method's existing behavior.
- It exposes `toArrow()` for Spark Connect via the existing `SparkSession.client.to_table()` method.
- It does not introduce any other user-facing changes.
### How was this patch tested?
This PR adds a new test, along with a helper function used by the test, in `pyspark/sql/tests/test_arrow.py`.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #45481 from ianmcook/SPARK-47365.
Lead-authored-by: Ian Cook <ianmcook@gmail.com>
Co-authored-by: Hyukjin Kwon <gurwls223@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py
index 03daf18..48aee48 100644
--- a/examples/src/main/python/sql/arrow.py
+++ b/examples/src/main/python/sql/arrow.py
@@ -33,6 +33,22 @@
require_minimum_pyarrow_version()
+def dataframe_to_arrow_table_example(spark: SparkSession) -> None:
+ import pyarrow as pa # noqa: F401
+ from pyspark.sql.functions import rand
+
+ # Create a Spark DataFrame
+ df = spark.range(100).drop("id").withColumns({"0": rand(), "1": rand(), "2": rand()})
+
+ # Convert the Spark DataFrame to a PyArrow Table
+ table = df.select("*").toArrow()
+
+ print(table.schema)
+ # 0: double not null
+ # 1: double not null
+ # 2: double not null
+
+
def dataframe_with_arrow_example(spark: SparkSession) -> None:
import numpy as np
import pandas as pd
@@ -302,6 +318,8 @@
.appName("Python Arrow-in-Spark example") \
.getOrCreate()
+ print("Running Arrow conversion example: DataFrame to Table")
+ dataframe_to_arrow_table_example(spark)
print("Running Pandas to/from conversion example")
dataframe_with_arrow_example(spark)
print("Running pandas_udf example: Series to Frame")
diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst
index b69a277..ec39b64 100644
--- a/python/docs/source/reference/pyspark.sql/dataframe.rst
+++ b/python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -109,6 +109,7 @@
DataFrame.tail
DataFrame.take
DataFrame.to
+ DataFrame.toArrow
DataFrame.toDF
DataFrame.toJSON
DataFrame.toLocalIterator
diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst b/python/docs/source/user_guide/sql/arrow_pandas.rst
index 1d6a4df..0a527d8 100644
--- a/python/docs/source/user_guide/sql/arrow_pandas.rst
+++ b/python/docs/source/user_guide/sql/arrow_pandas.rst
@@ -39,6 +39,20 @@
You can install it using pip or conda from the conda-forge channel. See PyArrow
`installation <https://arrow.apache.org/docs/python/install.html>`_ for details.
+Conversion to Arrow Table
+-------------------------
+
+You can call :meth:`DataFrame.toArrow` to convert a Spark DataFrame to a PyArrow Table.
+
+.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
+ :language: python
+ :lines: 37-49
+ :dedent: 4
+
+Note that :meth:`DataFrame.toArrow` results in the collection of all records in the DataFrame to
+the driver program and should be done on a small subset of the data. Not all Spark data types are
+currently supported and an error can be raised if a column has an unsupported type.
+
Enabling for Conversion to/from Pandas
--------------------------------------
@@ -53,7 +67,7 @@
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
- :lines: 37-52
+ :lines: 53-68
:dedent: 4
Using the above optimizations with Arrow will produce the same results as when Arrow is not
@@ -90,7 +104,7 @@
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
- :lines: 56-80
+ :lines: 72-96
:dedent: 4
In the following sections, it describes the combinations of the supported type hints. For simplicity,
@@ -113,7 +127,7 @@
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
- :lines: 84-114
+ :lines: 100-130
:dedent: 4
For detailed usage, please see :func:`pandas_udf`.
@@ -152,7 +166,7 @@
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
- :lines: 118-140
+ :lines: 134-156
:dedent: 4
For detailed usage, please see :func:`pandas_udf`.
@@ -174,7 +188,7 @@
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
- :lines: 144-167
+ :lines: 160-183
:dedent: 4
For detailed usage, please see :func:`pandas_udf`.
@@ -205,7 +219,7 @@
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
- :lines: 171-212
+ :lines: 187-228
:dedent: 4
.. currentmodule:: pyspark.sql.functions
@@ -270,7 +284,7 @@
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
- :lines: 216-234
+ :lines: 232-250
:dedent: 4
For detailed usage, please see :meth:`GroupedData.applyInPandas`
@@ -288,7 +302,7 @@
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
- :lines: 238-249
+ :lines: 254-265
:dedent: 4
For detailed usage, please see :meth:`DataFrame.mapInPandas`.
@@ -327,7 +341,7 @@
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
- :lines: 253-275
+ :lines: 269-291
:dedent: 4
@@ -349,7 +363,7 @@
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
- :lines: 279-297
+ :lines: 295-313
:dedent: 4
Compared to the default, pickled Python UDFs, Arrow Python UDFs provide a more coherent type coercion mechanism. UDF
@@ -421,9 +435,12 @@
Setting Arrow ``self_destruct`` for memory savings
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled`` can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas DataFrame.
-This option is experimental, and some operations may fail on the resulting Pandas DataFrame due to immutable backing arrays.
-Typically, you would see the error ``ValueError: buffer source array is read-only``.
-Newer versions of Pandas may fix these errors by improving support for such cases.
-You can work around this error by copying the column(s) beforehand.
-Additionally, this conversion may be slower because it is single-threaded.
+Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled``
+can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a
+Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas
+DataFrame. This option can also save memory when creating a PyArrow Table via ``toArrow``.
+This option is experimental. When used with ``toPandas``, some operations may fail on the resulting
+Pandas DataFrame due to immutable backing arrays. Typically, you would see the error
+``ValueError: buffer source array is read-only``. Newer versions of Pandas may fix these errors by
+improving support for such cases. You can work around this error by copying the column(s)
+beforehand. Additionally, this conversion may be slower because it is single-threaded.
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index db9f225..9b6790d 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -74,6 +74,7 @@
if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
+ import pyarrow as pa
from pyspark.core.rdd import RDD
from pyspark.core.context import SparkContext
from pyspark._typing import PrimitiveType
@@ -1825,6 +1826,9 @@
) -> ParentDataFrame:
return PandasMapOpsMixin.mapInArrow(self, func, schema, barrier, profile)
+ def toArrow(self) -> "pa.Table":
+ return PandasConversionMixin.toArrow(self)
+
def toPandas(self) -> "PandasDataFrameLike":
return PandasConversionMixin.toPandas(self)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 843c92a..3c9415a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1768,6 +1768,10 @@
assert table is not None
return (table, schema)
+ def toArrow(self) -> "pa.Table":
+ table, _ = self._to_table()
+ return table
+
def toPandas(self) -> "PandasDataFrameLike":
query = self._plan.to_proto(self._session.client)
return self._session.client.to_pandas(query, self._plan.observations)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index e3d52c4..886f72c 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -44,6 +44,7 @@
if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
+ import pyarrow as pa
from pyspark.core.context import SparkContext
from pyspark.core.rdd import RDD
from pyspark._typing import PrimitiveType
@@ -1200,6 +1201,7 @@
DataFrame.take : Returns the first `n` rows.
DataFrame.head : Returns the first `n` rows.
DataFrame.toPandas : Returns the data as a pandas DataFrame.
+ DataFrame.toArrow : Returns the data as a PyArrow Table.
Notes
-----
@@ -6213,6 +6215,34 @@
"""
...
+ @dispatch_df_method
+ def toArrow(self) -> "pa.Table":
+ """
+ Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``.
+
+ This is only available if PyArrow is installed and available.
+
+ .. versionadded:: 4.0.0
+
+ Notes
+ -----
+ This method should only be used if the resulting PyArrow ``pyarrow.Table`` is
+ expected to be small, as all the data is loaded into the driver's memory.
+
+ This API is a developer API.
+
+ Examples
+ --------
+ >>> df.toArrow() # doctest: +SKIP
+ pyarrow.Table
+ age: int64
+ name: string
+ ----
+ age: [[2,5]]
+ name: [["Alice","Bob"]]
+ """
+ ...
+
def toPandas(self) -> "PandasDataFrameLike":
"""
Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index ec4e21d..3446083 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -225,15 +225,48 @@
else:
return pdf
- def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch"]:
+ def toArrow(self) -> "pa.Table":
+ from pyspark.sql.dataframe import DataFrame
+
+ assert isinstance(self, DataFrame)
+
+ jconf = self.sparkSession._jconf
+
+ from pyspark.sql.pandas.types import to_arrow_schema
+ from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
+
+ require_minimum_pyarrow_version()
+ to_arrow_schema(self.schema)
+
+ import pyarrow as pa
+
+ self_destruct = jconf.arrowPySparkSelfDestructEnabled()
+ batches = self._collect_as_arrow(
+ split_batches=self_destruct, empty_list_if_zero_records=False
+ )
+ table = pa.Table.from_batches(batches)
+ # Ensure only the table has a reference to the batches, so that
+ # self_destruct (if enabled) is effective
+ del batches
+ return table
+
+ def _collect_as_arrow(
+ self,
+ split_batches: bool = False,
+ empty_list_if_zero_records: bool = True,
+ ) -> List["pa.RecordBatch"]:
"""
- Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
+ Returns all records as a list of Arrow RecordBatches. PyArrow must be installed
and available on driver and worker Python environments.
This is an experimental feature.
:param split_batches: split batches such that each column is in its own allocation, so
that the selfDestruct optimization is effective; default False.
+ :param empty_list_if_zero_records: If True (the default), returns an empty list if the
+ result has 0 records. Otherwise, returns a list of length 1 containing an empty
+ Arrow RecordBatch which includes the schema.
+
.. note:: Experimental.
"""
from pyspark.sql.dataframe import DataFrame
@@ -282,8 +315,15 @@
batches = results[:-1]
batch_order = results[-1]
- # Re-order the batch list using the correct order
- return [batches[i] for i in batch_order]
+ if len(batches) or empty_list_if_zero_records:
+ # Re-order the batch list using the correct order
+ return [batches[i] for i in batch_order]
+ else:
+ from pyspark.sql.pandas.types import to_arrow_schema
+
+ schema = to_arrow_schema(self.schema)
+ empty_arrays = [pa.array([], type=field.type) for field in schema]
+ return [pa.RecordBatch.from_arrays(empty_arrays, schema=schema)]
class SparkConversionMixin:
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 8636e95..71d3c46 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -179,6 +179,35 @@
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
return pd.DataFrame(data=data_dict)
+ def create_arrow_table(self):
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ data_dict = {}
+ for j, name in enumerate(self.schema.names):
+ data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+ t = pa.Table.from_pydict(data_dict)
+ # convert these to Arrow types
+ new_schema = t.schema.set(
+ t.schema.get_field_index("2_int_t"), pa.field("2_int_t", pa.int32())
+ )
+ new_schema = new_schema.set(
+ new_schema.get_field_index("4_float_t"), pa.field("4_float_t", pa.float32())
+ )
+ new_schema = new_schema.set(
+ new_schema.get_field_index("6_decimal_t"),
+ pa.field("6_decimal_t", pa.decimal128(38, 18)),
+ )
+ t = t.cast(new_schema)
+ # convert timestamp to local timezone
+ timezone = self.spark.conf.get("spark.sql.session.timeZone")
+ t = t.set_column(
+ t.schema.get_field_index("8_timestamp_t"),
+ "8_timestamp_t",
+ pc.assume_timezone(t["8_timestamp_t"], timezone),
+ )
+ return t
+
@property
def create_np_arrs(self):
import numpy as np
@@ -339,6 +368,12 @@
pdf_arrow = df.toPandas()
assert_frame_equal(pdf_arrow, pdf)
+ def test_arrow_round_trip(self):
+ t_in = self.create_arrow_table()
+ df = self.spark.createDataFrame(self.data, schema=self.schema)
+ t_out = df.toArrow()
+ self.assertTrue(t_out.equals(t_in))
+
def test_pandas_self_destruct(self):
import pyarrow as pa