Add Python bindings for accessing ExecutionMetrics (#1381)
* feat: add Python bindings for accessing ExecutionMetrics
* test: improve tests
* first round of reviews
* plan caching
* address some concerns
* merge and address comments
* fix CI issues
* attempt to fix lint
* fix build
* fix docstring
* address some more comments
---------
Co-authored-by: ShreyeshArangath <shryeyesh.arangath@gmail.com>
diff --git a/Cargo.lock b/Cargo.lock
index 1cbb0ac..4efca3e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1667,6 +1667,7 @@
"arrow",
"arrow-select",
"async-trait",
+ "chrono",
"cstr",
"datafusion",
"datafusion-ffi",
diff --git a/Cargo.toml b/Cargo.toml
index 14408d2..d0e87a9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,6 +35,7 @@
pyo3 = { version = "0.28" }
pyo3-async-runtimes = { version = "0.28" }
pyo3-log = "0.13.3"
+chrono = { version = "0.4", default-features = false }
arrow = { version = "58" }
arrow-array = { version = "58" }
arrow-schema = { version = "58" }
diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml
index 3e2b01c..d714dc9 100644
--- a/crates/core/Cargo.toml
+++ b/crates/core/Cargo.toml
@@ -47,6 +47,7 @@
] }
pyo3-async-runtimes = { workspace = true, features = ["tokio-runtime"] }
pyo3-log = { workspace = true }
+chrono = { workspace = true }
arrow = { workspace = true, features = ["pyarrow"] }
arrow-select = { workspace = true }
datafusion = { workspace = true, features = ["avro", "unicode_expressions"] }
diff --git a/crates/core/src/dataframe.rs b/crates/core/src/dataframe.rs
index c067eac..2d815ec 100644
--- a/crates/core/src/dataframe.rs
+++ b/crates/core/src/dataframe.rs
@@ -37,9 +37,15 @@
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
+use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::SortExpr;
use datafusion::logical_expr::dml::InsertOp;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
+use datafusion::physical_plan::{
+ ExecutionPlan as DFExecutionPlan, collect as df_collect,
+ collect_partitioned as df_collect_partitioned, execute_stream as df_execute_stream,
+ execute_stream_partitioned as df_execute_stream_partitioned,
+};
use datafusion::prelude::*;
use datafusion_python_util::{is_ipython_env, spawn_future, wait_for_future};
use futures::{StreamExt, TryStreamExt};
@@ -308,6 +314,9 @@
// In IPython environment cache batches between __repr__ and _repr_html_ calls.
batches: SharedCachedBatches,
+
+ // Cache the last physical plan so that metrics are available after execution.
+ last_plan: Arc<Mutex<Option<Arc<dyn DFExecutionPlan>>>>,
}
impl PyDataFrame {
@@ -316,6 +325,7 @@
Self {
df: Arc::new(df),
batches: Arc::new(Mutex::new(None)),
+ last_plan: Arc::new(Mutex::new(None)),
}
}
@@ -387,6 +397,20 @@
Ok(html_str)
}
+ /// Create the physical plan, cache it in `last_plan`, and return the plan together
+ /// with a task context. Centralises the repeated three-line pattern that appears in
+ /// `collect`, `collect_partitioned`, `execute_stream`, and `execute_stream_partitioned`.
+ fn create_and_cache_plan(
+ &self,
+ py: Python,
+ ) -> PyDataFusionResult<(Arc<dyn DFExecutionPlan>, Arc<TaskContext>)> {
+ let df = self.df.as_ref().clone();
+ let new_plan = wait_for_future(py, df.create_physical_plan())??;
+ *self.last_plan.lock() = Some(Arc::clone(&new_plan));
+ let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+ Ok((new_plan, task_ctx))
+ }
+
async fn collect_column_inner(&self, column: &str) -> Result<ArrayRef, DataFusionError> {
let batches = self
.df
@@ -646,8 +670,9 @@
/// Unless some order is specified in the plan, there is no
/// guarantee of the order of the result.
fn collect<'py>(&self, py: Python<'py>) -> PyResult<Vec<Bound<'py, PyAny>>> {
- let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
- .map_err(PyDataFusionError::from)?;
+ let (plan, task_ctx) = self.create_and_cache_plan(py)?;
+ let batches =
+ wait_for_future(py, df_collect(plan, task_ctx))?.map_err(PyDataFusionError::from)?;
// cannot use PyResult<Vec<RecordBatch>> return type due to
// https://github.com/PyO3/pyo3/issues/1813
batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
@@ -662,7 +687,8 @@
/// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
/// maintaining the input partitioning.
fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult<Vec<Vec<Bound<'py, PyAny>>>> {
- let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
+ let (plan, task_ctx) = self.create_and_cache_plan(py)?;
+ let batches = wait_for_future(py, df_collect_partitioned(plan, task_ctx))?
.map_err(PyDataFusionError::from)?;
batches
@@ -840,7 +866,13 @@
}
/// Get the execution plan for this `DataFrame`
+ ///
+ /// If the DataFrame has already been executed (e.g. via `collect()`),
+ /// returns the cached plan which includes populated metrics.
fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
+ if let Some(plan) = self.last_plan.lock().as_ref() {
+ return Ok(PyExecutionPlan::new(Arc::clone(plan)));
+ }
let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
Ok(plan.into())
}
@@ -1198,14 +1230,17 @@
}
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
- let df = self.df.as_ref().clone();
- let stream = spawn_future(py, async move { df.execute_stream().await })?;
+ let (plan, task_ctx) = self.create_and_cache_plan(py)?;
+ let stream = spawn_future(py, async move { df_execute_stream(plan, task_ctx) })?;
Ok(PyRecordBatchStream::new(stream))
}
fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
- let df = self.df.as_ref().clone();
- let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
+ let (plan, task_ctx) = self.create_and_cache_plan(py)?;
+ let streams = spawn_future(
+ py,
+ async move { df_execute_stream_partitioned(plan, task_ctx) },
+ )?;
Ok(streams.into_iter().map(PyRecordBatchStream::new).collect())
}
diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs
index fc2d006..77d6991 100644
--- a/crates/core/src/lib.rs
+++ b/crates/core/src/lib.rs
@@ -43,6 +43,7 @@
pub mod expr;
#[allow(clippy::borrow_deref_ref)]
mod functions;
+pub mod metrics;
mod options;
pub mod physical_plan;
mod pyarrow_filter_expression;
@@ -92,6 +93,8 @@
m.add_class::<udtf::PyTableFunction>()?;
m.add_class::<config::PyConfig>()?;
m.add_class::<sql::logical::PyLogicalPlan>()?;
+ m.add_class::<metrics::PyMetricsSet>()?;
+ m.add_class::<metrics::PyMetric>()?;
m.add_class::<physical_plan::PyExecutionPlan>()?;
m.add_class::<record_batch::PyRecordBatch>()?;
m.add_class::<record_batch::PyRecordBatchStream>()?;
diff --git a/crates/core/src/metrics.rs b/crates/core/src/metrics.rs
new file mode 100644
index 0000000..ee0937e
--- /dev/null
+++ b/crates/core/src/metrics.rs
@@ -0,0 +1,169 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use chrono::{Datelike, Timelike};
+use datafusion::physical_plan::metrics::{Metric, MetricValue, MetricsSet, Timestamp};
+use pyo3::prelude::*;
+
+#[pyclass(from_py_object, frozen, name = "MetricsSet", module = "datafusion")]
+#[derive(Debug, Clone)]
+pub struct PyMetricsSet {
+ metrics: MetricsSet,
+}
+
+impl PyMetricsSet {
+ pub fn new(metrics: MetricsSet) -> Self {
+ Self { metrics }
+ }
+}
+
+#[pymethods]
+impl PyMetricsSet {
+ fn metrics(&self) -> Vec<PyMetric> {
+ self.metrics
+ .iter()
+ .map(|m| PyMetric::new(Arc::clone(m)))
+ .collect()
+ }
+
+ fn output_rows(&self) -> Option<usize> {
+ self.metrics.output_rows()
+ }
+
+ fn elapsed_compute(&self) -> Option<usize> {
+ self.metrics.elapsed_compute()
+ }
+
+ fn spill_count(&self) -> Option<usize> {
+ self.metrics.spill_count()
+ }
+
+ fn spilled_bytes(&self) -> Option<usize> {
+ self.metrics.spilled_bytes()
+ }
+
+ fn spilled_rows(&self) -> Option<usize> {
+ self.metrics.spilled_rows()
+ }
+
+ fn sum_by_name(&self, name: &str) -> Option<usize> {
+ self.metrics.sum_by_name(name).map(|v| v.as_usize())
+ }
+
+ fn __repr__(&self) -> String {
+ format!("{}", self.metrics)
+ }
+}
+
+#[pyclass(from_py_object, frozen, name = "Metric", module = "datafusion")]
+#[derive(Debug, Clone)]
+pub struct PyMetric {
+ metric: Arc<Metric>,
+}
+
+impl PyMetric {
+ pub fn new(metric: Arc<Metric>) -> Self {
+ Self { metric }
+ }
+
+ fn timestamp_to_pyobject<'py>(
+ py: Python<'py>,
+ ts: &Timestamp,
+ ) -> PyResult<Option<Bound<'py, PyAny>>> {
+ match ts.value() {
+ Some(dt) => {
+ let datetime_mod = py.import("datetime")?;
+ let datetime_cls = datetime_mod.getattr("datetime")?;
+ let tz_utc = datetime_mod.getattr("timezone")?.getattr("utc")?;
+ let result = datetime_cls.call1((
+ dt.year(),
+ dt.month(),
+ dt.day(),
+ dt.hour(),
+ dt.minute(),
+ dt.second(),
+ dt.timestamp_subsec_micros(),
+ tz_utc,
+ ))?;
+ Ok(Some(result))
+ }
+ None => Ok(None),
+ }
+ }
+}
+
+#[pymethods]
+impl PyMetric {
+ #[getter]
+ fn name(&self) -> String {
+ self.metric.value().name().to_string()
+ }
+
+ #[getter]
+ fn value<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
+ match self.metric.value() {
+ MetricValue::OutputRows(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())),
+ MetricValue::OutputBytes(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())),
+ MetricValue::ElapsedCompute(t) => Ok(Some(t.value().into_pyobject(py)?.into_any())),
+ MetricValue::SpillCount(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())),
+ MetricValue::SpilledBytes(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())),
+ MetricValue::SpilledRows(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())),
+ MetricValue::CurrentMemoryUsage(g) => Ok(Some(g.value().into_pyobject(py)?.into_any())),
+ MetricValue::Count { count, .. } => {
+ Ok(Some(count.value().into_pyobject(py)?.into_any()))
+ }
+ MetricValue::Gauge { gauge, .. } => {
+ Ok(Some(gauge.value().into_pyobject(py)?.into_any()))
+ }
+ MetricValue::Time { time, .. } => Ok(Some(time.value().into_pyobject(py)?.into_any())),
+ MetricValue::StartTimestamp(ts) | MetricValue::EndTimestamp(ts) => {
+ Self::timestamp_to_pyobject(py, ts)
+ }
+ _ => Ok(None),
+ }
+ }
+
+ #[getter]
+ fn value_as_datetime<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
+ match self.metric.value() {
+ MetricValue::StartTimestamp(ts) | MetricValue::EndTimestamp(ts) => {
+ Self::timestamp_to_pyobject(py, ts)
+ }
+ _ => Ok(None),
+ }
+ }
+
+ #[getter]
+ fn partition(&self) -> Option<usize> {
+ self.metric.partition()
+ }
+
+ fn labels(&self) -> HashMap<String, String> {
+ self.metric
+ .labels()
+ .iter()
+ .map(|l| (l.name().to_string(), l.value().to_string()))
+ .collect()
+ }
+
+ fn __repr__(&self) -> String {
+ format!("{}", self.metric.value())
+ }
+}
diff --git a/crates/core/src/physical_plan.rs b/crates/core/src/physical_plan.rs
index 8674a8b..fac9738 100644
--- a/crates/core/src/physical_plan.rs
+++ b/crates/core/src/physical_plan.rs
@@ -26,6 +26,7 @@
use crate::context::PySessionContext;
use crate::errors::PyDataFusionResult;
+use crate::metrics::PyMetricsSet;
#[pyclass(
from_py_object,
@@ -96,6 +97,10 @@
Ok(Self::new(plan))
}
+ pub fn metrics(&self) -> Option<PyMetricsSet> {
+ self.plan.metrics().map(PyMetricsSet::new)
+ }
+
fn __repr__(&self) -> String {
self.display_indent()
}
diff --git a/docs/source/user-guide/dataframe/execution-metrics.rst b/docs/source/user-guide/dataframe/execution-metrics.rst
new file mode 100644
index 0000000..764fa76
--- /dev/null
+++ b/docs/source/user-guide/dataframe/execution-metrics.rst
@@ -0,0 +1,215 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements. See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership. The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied. See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+
+.. _execution_metrics:
+
+Execution Metrics
+=================
+
+Overview
+--------
+
+When DataFusion executes a query it compiles the logical plan into a tree of
+*physical plan operators* (e.g. ``FilterExec``, ``ProjectionExec``,
+``AggregateExec``). Each operator can record runtime statistics while it
+runs. These statistics are called **execution metrics**.
+
+Typical metrics include:
+
+- **output_rows** – number of rows produced by the operator
+- **elapsed_compute** – total CPU time (nanoseconds) spent inside the operator
+- **spill_count** – number of times the operator spilled data to disk
+- **spilled_bytes** – total bytes written to disk during spills
+- **spilled_rows** – total rows written to disk during spills
+
+Metrics are collected *per-partition*: DataFusion may execute each operator
+in parallel across several partitions. The convenience properties on
+:py:class:`~datafusion.MetricsSet` (e.g. ``output_rows``, ``elapsed_compute``)
+automatically sum the named metric across **all** partitions, giving a single
+aggregate value for the operator as a whole. You can also access the raw
+per-partition :py:class:`~datafusion.Metric` objects via
+:py:meth:`~datafusion.MetricsSet.metrics`.
+
+When Are Metrics Available?
+---------------------------
+
+Some operators (for example ``DataSourceExec``) eagerly create a
+:py:class:`~datafusion.MetricsSet` when the physical plan is built, so
+:py:meth:`~datafusion.ExecutionPlan.metrics` may return a set even before any
+rows have been processed. However, metric **values** such as ``output_rows``
+are only meaningful **after** the DataFrame has been executed via one of the
+terminal operations:
+
+- :py:meth:`~datafusion.DataFrame.collect`
+- :py:meth:`~datafusion.DataFrame.collect_partitioned`
+- :py:meth:`~datafusion.DataFrame.execute_stream`
+ (metrics are available once the stream has been fully consumed)
+- :py:meth:`~datafusion.DataFrame.execute_stream_partitioned`
+ (metrics are available once all partition streams have been fully consumed)
+
+Before execution, metric values will be ``0`` or ``None``.
+
+.. note::
+
+ **display() does not populate metrics.**
+ When a DataFrame is displayed in a notebook (e.g. via ``display(df)`` or
+ automatic ``repr`` output), DataFusion runs a *limited* internal execution
+ to fetch preview rows. This internal execution does **not** cache the
+ physical plan used, so :py:meth:`~datafusion.ExecutionPlan.collect_metrics`
+ will not reflect the display execution. To access metrics you must call
+ one of the terminal operations listed above.
+
+If you call :py:meth:`~datafusion.DataFrame.collect` (or another terminal
+operation) multiple times on the same DataFrame, each call creates a fresh
+physical plan. Metrics from :py:meth:`~datafusion.DataFrame.execution_plan`
+always reflect the **most recent** execution.
+
+Reading the Physical Plan Tree
+--------------------------------
+
+:py:meth:`~datafusion.DataFrame.execution_plan` returns the root
+:py:class:`~datafusion.ExecutionPlan` node of the physical plan tree. The tree
+mirrors the operator pipeline: the root is typically a projection or
+coalescing node; its children are filters, aggregates, scans, etc.
+
+The ``operator_name`` string returned by
+:py:meth:`~datafusion.ExecutionPlan.collect_metrics` is the *display* name of
+the node, for example ``"FilterExec: column1@0 > 1"``. This is the same string
+you would see when calling ``plan.display()``.
+
+Aggregated vs Per-Partition Metrics
+------------------------------------
+
+DataFusion executes each operator across one or more **partitions** in
+parallel. The :py:class:`~datafusion.MetricsSet` convenience properties
+(``output_rows``, ``elapsed_compute``, etc.) automatically **sum** the named
+metric across all partitions, giving a single aggregate value.
+
+To inspect individual partitions — for example to detect data skew where one
+partition processes far more rows than others — iterate over the raw
+:py:class:`~datafusion.Metric` objects:
+
+.. code-block:: python
+
+ for metric in metrics_set.metrics():
+ print(f" partition={metric.partition} {metric.name}={metric.value}")
+
+The ``partition`` property is a 0-based index (``0``, ``1``, …) identifying
+which parallel slot processed this metric. It is ``None`` for metrics that
+apply globally (not tied to a specific partition).
+
+Available Metrics
+-----------------
+
+The following metrics are directly accessible as properties on
+:py:class:`~datafusion.MetricsSet`:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 25 75
+
+ * - Property
+ - Description
+ * - ``output_rows``
+ - Number of rows emitted by the operator (summed across partitions).
+ * - ``elapsed_compute``
+ - CPU time **in nanoseconds** spent inside the operator's
+ compute loop, excluding I/O wait. Useful for identifying which
+ operators are most expensive (summed across partitions).
+ * - ``spill_count``
+ - Number of spill-to-disk events triggered by memory pressure. This is
+ a unitless count of events, not a measure of data volume (summed across
+ partitions).
+ * - ``spilled_bytes``
+ - Total bytes written to disk during spill events (summed across
+ partitions).
+ * - ``spilled_rows``
+ - Total rows written to disk during spill events (summed across
+ partitions).
+
+Any metric not listed above can be accessed via
+:py:meth:`~datafusion.MetricsSet.sum_by_name`, or by iterating over the raw
+:py:class:`~datafusion.Metric` objects returned by
+:py:meth:`~datafusion.MetricsSet.metrics`.
+
+Labels
+------
+
+A :py:class:`~datafusion.Metric` may carry *labels*: key/value pairs that
+provide additional context. Labels are operator-specific; most metrics have
+an empty label dict.
+
+Some operators tag their metrics with labels to distinguish variants. For
+example, an ``AggregateExec`` may record separate ``output_rows`` metrics
+for intermediate and final output:
+
+.. code-block:: python
+
+ for metric in metrics_set.metrics():
+ print(metric.name, metric.labels())
+ # output_rows {'output_type': 'final'}
+ # output_rows {'output_type': 'intermediate'}
+
+When summing by name (via :py:attr:`~datafusion.MetricsSet.output_rows` or
+:py:meth:`~datafusion.MetricsSet.sum_by_name`), **all** metrics with that
+name are summed regardless of labels. To filter by label, iterate over the
+raw :py:class:`~datafusion.Metric` objects directly.
+
+End-to-End Example
+------------------
+
+.. code-block:: python
+
+ from datafusion import SessionContext
+
+ ctx = SessionContext()
+ ctx.sql("CREATE TABLE sales AS VALUES (1, 100), (2, 200), (3, 50)")
+
+ df = ctx.sql("SELECT * FROM sales WHERE column1 > 1")
+
+ # Execute the query — this populates the metrics
+ results = df.collect()
+
+ # Retrieve the physical plan with metrics
+ plan = df.execution_plan()
+
+ # Walk every operator and print its metrics
+ for operator_name, ms in plan.collect_metrics():
+ if ms.output_rows is not None:
+ print(f"{operator_name}")
+ print(f" output_rows = {ms.output_rows}")
+ print(f" elapsed_compute = {ms.elapsed_compute} ns")
+
+ # Access raw per-partition metrics
+ for operator_name, ms in plan.collect_metrics():
+ for metric in ms.metrics():
+ print(
+ f" partition={metric.partition} "
+ f"{metric.name}={metric.value} "
+ f"labels={metric.labels()}"
+ )
+
+API Reference
+-------------
+
+- :py:class:`datafusion.ExecutionPlan` — physical plan node
+- :py:meth:`datafusion.ExecutionPlan.collect_metrics` — walk the tree and
+ return ``(operator_name, MetricsSet)`` pairs
+- :py:meth:`datafusion.ExecutionPlan.metrics` — return the
+ :py:class:`~datafusion.MetricsSet` for a single node
+- :py:class:`datafusion.MetricsSet` — aggregated metrics for one operator
+- :py:class:`datafusion.Metric` — a single per-partition metric value
diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst
index 510bcbc..8475a7b 100644
--- a/docs/source/user-guide/dataframe/index.rst
+++ b/docs/source/user-guide/dataframe/index.rst
@@ -365,7 +365,16 @@
For a complete list of available functions, see the :py:mod:`datafusion.functions` module documentation.
+Execution Metrics
+-----------------
+
+After executing a DataFrame (via ``collect()``, ``execute_stream()``, etc.),
+DataFusion populates per-operator runtime statistics such as row counts and
+compute time. See :doc:`execution-metrics` for a full explanation and
+worked example.
+
.. toctree::
:maxdepth: 1
rendering
+ execution-metrics
diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index ee02c92..80dfa2f 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -56,7 +56,7 @@
from .expr import Expr, WindowFrame
from .io import read_avro, read_csv, read_json, read_parquet
from .options import CsvReadOptions
-from .plan import ExecutionPlan, LogicalPlan
+from .plan import ExecutionPlan, LogicalPlan, Metric, MetricsSet
from .record_batch import RecordBatch, RecordBatchStream
from .user_defined import (
Accumulator,
@@ -86,6 +86,8 @@
"Expr",
"InsertOp",
"LogicalPlan",
+ "Metric",
+ "MetricsSet",
"ParquetColumnOptions",
"ParquetWriterOptions",
"RecordBatch",
diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py
index 9c96a18..c0cfd52 100644
--- a/python/datafusion/plan.py
+++ b/python/datafusion/plan.py
@@ -24,11 +24,15 @@
import datafusion._internal as df_internal
if TYPE_CHECKING:
+ import datetime
+
from datafusion.context import SessionContext
__all__ = [
"ExecutionPlan",
"LogicalPlan",
+ "Metric",
+ "MetricsSet",
]
@@ -151,3 +155,176 @@
Tables created in memory from record batches are currently not supported.
"""
return self._raw_plan.to_proto()
+
+ def metrics(self) -> MetricsSet | None:
+ """Return metrics for this plan node, or None if this plan has no MetricsSet.
+
+ Some operators (e.g. DataSourceExec) eagerly initialize a MetricsSet
+ when the plan is created, so this may return a set even before
+ execution. Metric *values* (such as ``output_rows``) are only
+ meaningful after the DataFrame has been executed.
+ """
+ raw = self._raw_plan.metrics()
+ if raw is None:
+ return None
+ return MetricsSet(raw)
+
+ def collect_metrics(self) -> list[tuple[str, MetricsSet]]:
+ """Return runtime statistics for each step of the query execution.
+
+ DataFusion executes a query as a pipeline of operators — for example a
+ data source scan, followed by a filter, followed by a projection. After
+ the DataFrame has been executed (via
+ :py:meth:`~datafusion.DataFrame.collect`,
+ :py:meth:`~datafusion.DataFrame.execute_stream`, etc.), each operator
+ records statistics such as how many rows it produced and how much CPU
+ time it consumed.
+
+ Each entry in the returned list corresponds to one operator that
+ recorded metrics. The first element of the tuple is the operator's
+ description string — the same text returned by
+ :py:meth:`display` — which identifies both the operator type
+ and its key parameters, for example ``"FilterExec: column1@0 > 1"``
+ or ``"DataSourceExec: partitions=1"``.
+
+ Returns:
+ A list of ``(description, MetricsSet)`` tuples ordered from the
+ outermost operator (top of the execution tree) down to the
+ data-source leaves. Only operators whose :py:meth:`metrics` is
+ not ``None`` are included, so the list may be non-empty even
+ before execution; values are only meaningful afterwards.
+ """
+ result: list[tuple[str, MetricsSet]] = []
+
+ def _walk(node: ExecutionPlan) -> None:
+ ms = node.metrics()
+ if ms is not None:
+ result.append((node.display(), ms))
+ for child in node.children():
+ _walk(child)
+
+ _walk(self)
+ return result
+
+
+class MetricsSet:
+ """A set of metrics for a single execution plan operator.
+
+ A physical plan operator runs independently across one or more partitions.
+ :py:meth:`metrics` returns the raw per-partition :py:class:`Metric` objects.
+ The convenience properties (:py:attr:`output_rows`, :py:attr:`elapsed_compute`,
+ etc.) automatically sum the named metric across *all* partitions, giving a
+ single aggregate value for the operator as a whole.
+ """
+
+ def __init__(self, raw: df_internal.MetricsSet) -> None:
+ """This constructor should not be called by the end user."""
+ self._raw = raw
+
+ def metrics(self) -> list[Metric]:
+ """Return all individual metrics in this set."""
+ return [Metric(m) for m in self._raw.metrics()]
+
+ @property
+ def output_rows(self) -> int | None:
+ """Sum of output_rows across all partitions."""
+ return self._raw.output_rows()
+
+ @property
+ def elapsed_compute(self) -> int | None:
+ """Total CPU time (in nanoseconds) spent inside this operator's execute loop.
+
+ Summed across all partitions. Returns ``None`` if no ``elapsed_compute``
+ metric was recorded.
+ """
+ return self._raw.elapsed_compute()
+
+ @property
+ def spill_count(self) -> int | None:
+ """Number of times this operator spilled data to disk due to memory pressure.
+
+ This is a count of spill events, not a byte count. Summed across all
+ partitions. Returns ``None`` if no ``spill_count`` metric was recorded.
+ """
+ return self._raw.spill_count()
+
+ @property
+ def spilled_bytes(self) -> int | None:
+ """Sum of spilled_bytes across all partitions."""
+ return self._raw.spilled_bytes()
+
+ @property
+ def spilled_rows(self) -> int | None:
+ """Sum of spilled_rows across all partitions."""
+ return self._raw.spilled_rows()
+
+ def sum_by_name(self, name: str) -> int | None:
+ """Sum the named metric across all partitions.
+
+ Useful for accessing any metric not exposed as a first-class property.
+ Returns ``None`` if no metric with the given name was recorded.
+
+ Args:
+ name: The metric name, e.g. ``"output_rows"`` or ``"elapsed_compute"``.
+ """
+ return self._raw.sum_by_name(name)
+
+ def __repr__(self) -> str:
+ """Return a string representation of the metrics set."""
+ return repr(self._raw)
+
+
+class Metric:
+ """A single execution metric with name, value, partition, and labels."""
+
+ def __init__(self, raw: df_internal.Metric) -> None:
+ """This constructor should not be called by the end user."""
+ self._raw = raw
+
+ @property
+ def name(self) -> str:
+ """The name of this metric (e.g. ``output_rows``)."""
+ return self._raw.name
+
+ @property
+ def value(self) -> int | datetime.datetime | None:
+ """The value of this metric.
+
+ Returns an ``int`` for counters, gauges, and time-based metrics
+ (nanoseconds), a :py:class:`~datetime.datetime` (UTC) for
+ ``start_timestamp`` / ``end_timestamp`` metrics, or ``None``
+ when the value has not been set or is not representable.
+ """
+ return self._raw.value
+
+ @property
+ def value_as_datetime(self) -> datetime.datetime | None:
+ """The value as a UTC :py:class:`~datetime.datetime` for timestamp metrics.
+
+ Returns ``None`` for all non-timestamp metrics and for timestamp
+ metrics whose value has not been set (e.g. before execution).
+ """
+ return self._raw.value_as_datetime
+
+ @property
+ def partition(self) -> int | None:
+ """The 0-based partition index this metric applies to.
+
+ Returns ``None`` for metrics that are not partition-specific (i.e. they
+ apply globally across all partitions of the operator).
+ """
+ return self._raw.partition
+
+ def labels(self) -> dict[str, str]:
+ """Return the labels associated with this metric.
+
+ Labels provide additional context for a metric. For example::
+
+ metric.labels()
+ # {'output_type': 'final'}
+ """
+ return self._raw.labels()
+
+ def __repr__(self) -> str:
+ """Return a string representation of the metric."""
+ return repr(self._raw)
diff --git a/python/tests/test_plans.py b/python/tests/test_plans.py
index 396acbe..3705fc7 100644
--- a/python/tests/test_plans.py
+++ b/python/tests/test_plans.py
@@ -15,8 +15,16 @@
# specific language governing permissions and limitations
# under the License.
+import datetime
+
import pytest
-from datafusion import ExecutionPlan, LogicalPlan, SessionContext
+from datafusion import (
+ ExecutionPlan,
+ LogicalPlan,
+ Metric,
+ MetricsSet,
+ SessionContext,
+)
# Note: We must use CSV because memory tables are currently not supported for
@@ -40,3 +48,185 @@
execution_plan = ExecutionPlan.from_proto(ctx, execution_plan_bytes)
assert str(original_execution_plan) == str(execution_plan)
+
+
+def test_metrics_tree_walk() -> None:
+ ctx = SessionContext()
+ ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+ df.collect()
+ plan = df.execution_plan()
+
+ results = plan.collect_metrics()
+ assert len(results) >= 1
+ output_rows_by_op: dict[str, int] = {}
+ for name, ms in results:
+ assert isinstance(name, str)
+ assert isinstance(ms, MetricsSet)
+ if ms.output_rows is not None:
+ output_rows_by_op[name] = ms.output_rows
+
+ # The filter passes rows where column1 > 1, so exactly
+ # 2 rows from (1,'a'),(2,'b'),(3,'c').
+ # At least one operator must report exactly 2 output rows (the filter).
+ assert 2 in output_rows_by_op.values(), (
+ f"Expected an operator with output_rows=2, got {output_rows_by_op}"
+ )
+
+
+def test_metric_properties() -> None:
+ ctx = SessionContext()
+ ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+ df.collect()
+ plan = df.execution_plan()
+
+ found_any_metric = False
+ for _, ms in plan.collect_metrics():
+ r = repr(ms)
+ assert isinstance(r, str)
+ for metric in ms.metrics():
+ found_any_metric = True
+ assert isinstance(metric, Metric)
+ assert isinstance(metric.name, str)
+ assert len(metric.name) > 0
+ assert metric.partition is None or isinstance(metric.partition, int)
+ assert metric.value is None or isinstance(
+ metric.value, int | datetime.datetime
+ )
+ assert isinstance(metric.labels(), dict)
+ mr = repr(metric)
+ assert isinstance(mr, str)
+ assert len(mr) > 0
+ assert found_any_metric, "Expected at least one metric after execution"
+
+
+def test_no_meaningful_metrics_before_execution() -> None:
+ ctx = SessionContext()
+ ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+ plan_before = df.execution_plan()
+
+ # Some plan nodes (e.g. DataSourceExec) eagerly initialize a MetricsSet,
+ # so metrics() may return a set even before execution. However, no rows
+ # should have been processed yet — output_rows must be absent or zero.
+ for _, ms in plan_before.collect_metrics():
+ rows = ms.output_rows
+ assert rows is None or rows == 0, (
+ f"Expected 0 output_rows before execution, got {rows}"
+ )
+
+ # After execution, at least one operator must report rows processed.
+ df.collect()
+ plan_after = df.execution_plan()
+ output_rows_after = [
+ ms.output_rows
+ for _, ms in plan_after.collect_metrics()
+ if ms.output_rows is not None and ms.output_rows > 0
+ ]
+ assert len(output_rows_after) > 0, "Expected output_rows > 0 after execution"
+
+
+def test_collect_partitioned_metrics() -> None:
+ ctx = SessionContext()
+ ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+ df.collect_partitioned()
+ plan = df.execution_plan()
+
+ output_rows_values = [
+ ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None
+ ]
+ assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}"
+
+
+def test_execute_stream_metrics() -> None:
+ ctx = SessionContext()
+ ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+ for _ in df.execute_stream():
+ pass
+
+ plan = df.execution_plan()
+ output_rows_values = [
+ ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None
+ ]
+ assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}"
+
+
+def test_execute_stream_partitioned_metrics() -> None:
+ ctx = SessionContext()
+ ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+ for stream in df.execute_stream_partitioned():
+ for _ in stream:
+ pass
+
+ plan = df.execution_plan()
+ output_rows_values = [
+ ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None
+ ]
+ assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}"
+
+
+def test_value_as_datetime() -> None:
+ ctx = SessionContext()
+ ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+ df.collect()
+ plan = df.execution_plan()
+
+ for _, ms in plan.collect_metrics():
+ for metric in ms.metrics():
+ if metric.name in ("start_timestamp", "end_timestamp"):
+ dt = metric.value_as_datetime
+ assert dt is None or isinstance(dt, datetime.datetime)
+ if dt is not None:
+ assert dt.tzinfo is not None
+ else:
+ assert metric.value_as_datetime is None
+
+
+def test_metric_names_and_labels() -> None:
+ """Verify that known metric names appear and labels are well-formed."""
+ ctx = SessionContext()
+ ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+ df.collect()
+ plan = df.execution_plan()
+
+ all_metric_names: set[str] = set()
+ for _, ms in plan.collect_metrics():
+ for metric in ms.metrics():
+ all_metric_names.add(metric.name)
+ # Labels must be a dict of str->str
+ labels = metric.labels()
+ for k, v in labels.items():
+ assert isinstance(k, str)
+ assert isinstance(v, str)
+
+ # After a filter query, we expect at minimum these standard metric names.
+ assert "output_rows" in all_metric_names, (
+ f"Expected 'output_rows' in {all_metric_names}"
+ )
+ assert "elapsed_compute" in all_metric_names, (
+ f"Expected 'elapsed_compute' in {all_metric_names}"
+ )
+
+
+def test_collect_twice_has_metrics() -> None:
+ ctx = SessionContext()
+ ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+ df.collect()
+ df.collect()
+
+ plan = df.execution_plan()
+ output_rows_values = [
+ ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None
+ ]
+ assert len(output_rows_values) > 0