Feat/parameterized sql queries (#964)
* Intermediate work on parameterizing queries
* Reworking to do token parsing of sql query instead of string manipulation
* Switching to explicit param_values or named parameters that will perform string replacement via parsed tokens
* Add additional unit tests for parameterized queries
* merge conflict
* license text
* Add documentation
* cargo clippy and fmt
* Need at least pyarrow 16 now
* add type hints
* Minor docstring update
diff --git a/docs/source/user-guide/configuration.rst b/docs/source/user-guide/configuration.rst
index 5425a04..f8e613c 100644
--- a/docs/source/user-guide/configuration.rst
+++ b/docs/source/user-guide/configuration.rst
@@ -15,6 +15,8 @@
.. specific language governing permissions and limitations
.. under the License.
+.. _configuration:
+
Configuration
=============
diff --git a/docs/source/user-guide/sql.rst b/docs/source/user-guide/sql.rst
index 6fa7f0c..b4bfb96 100644
--- a/docs/source/user-guide/sql.rst
+++ b/docs/source/user-guide/sql.rst
@@ -23,17 +23,100 @@
.. ipython:: python
import datafusion
- from datafusion import col
- import pyarrow
+ from datafusion import DataFrame, SessionContext
# create a context
ctx = datafusion.SessionContext()
# register a CSV
- ctx.register_csv('pokemon', 'pokemon.csv')
+ ctx.register_csv("pokemon", "pokemon.csv")
# create a new statement via SQL
df = ctx.sql('SELECT "Attack"+"Defense", "Attack"-"Defense" FROM pokemon')
# collect and convert to pandas DataFrame
- df.to_pandas()
\ No newline at end of file
+ df.to_pandas()
+
+Parameterized queries
+---------------------
+
+In DataFusion-Python 51.0.0 we introduced the ability to pass parameters
+in a SQL query. These are similar in concept to
+`prepared statements <https://datafusion.apache.org/user-guide/sql/prepared_statements.html>`_,
+but allow passing named parameters into a SQL query. Consider this simple
+example.
+
+.. ipython:: python
+
+ def show_attacks(ctx: SessionContext, threshold: int) -> None:
+ ctx.sql(
+ 'SELECT "Name", "Attack" FROM pokemon WHERE "Attack" > $val', val=threshold
+ ).show(num=5)
+ show_attacks(ctx, 75)
+
+When passing parameters like the example above we convert the Python objects
+into their string representation. We also have special case handling
+for :py:class:`~datafusion.dataframe.DataFrame` objects, since they cannot simply
+be turned into string representations for an SQL query. In these cases we
+will register a temporary view in the :py:class:`~datafusion.context.SessionContext`
+using a generated table name.
+
+To mark a parameter for string replacement, precede its name in the query
+with a single ``$``. This works for all dialects in the SQL parser except
+``hive`` and ``mysql``. Because these two dialects do not support named
+placeholders, this type of replacement is unavailable for them.
+We recommend either switching to another dialect or using Python
+f-string style replacement.
+
+.. warning::
+
+ To support DataFrame parameterized queries, your session must support
+ registration of temporary views. The default
+ :py:class:`~datafusion.catalog.CatalogProvider` and
+ :py:class:`~datafusion.catalog.SchemaProvider` have this capability.
+ If you have implemented custom providers, it is important that temporary
+ views do not persist across :py:class:`~datafusion.context.SessionContext`
+ instances, or you may get unintended consequences.
+
+The following example shows passing in both a :py:class:`~datafusion.dataframe.DataFrame`
+object as well as a Python object to be used in parameterized replacement.
+
+.. ipython:: python
+
+ def show_column(
+ ctx: SessionContext, column: str, df: DataFrame, threshold: int
+ ) -> None:
+ ctx.sql(
+ 'SELECT "Name", $col FROM $df WHERE $col > $val',
+ col=column,
+ df=df,
+ val=threshold,
+ ).show(num=5)
+ df = ctx.table("pokemon")
+ show_column(ctx, '"Defense"', df, 75)
+
+The named-parameter approach described above relies on converting each
+value to a string. This can lose information, for example with
+floating point numbers. If you need to pass
+variables into a parameterized query and it is important to maintain the
+original value without conversion to a string, then you can use the
+optional parameter ``param_values`` to specify these. This parameter
+expects a dictionary mapping from the parameter name to a Python
+object. Those objects will be converted into a
+`PyArrow Scalar Value <https://arrow.apache.org/docs/python/generated/pyarrow.Scalar.html>`_.
+
+Using ``param_values`` will rely on the SQL dialect you have configured
+for your session. This can be set using the :ref:`configuration options <configuration>`
+of your :py:class:`~datafusion.context.SessionContext`. Similar to how
+`prepared statements <https://datafusion.apache.org/user-guide/sql/prepared_statements.html>`_
+work, these parameters are limited to places where you would pass in a
+scalar value, such as a comparison.
+
+.. ipython:: python
+
+ def param_attacks(ctx: SessionContext, threshold: int) -> None:
+ ctx.sql(
+ 'SELECT "Name", "Attack" FROM pokemon WHERE "Attack" > $val',
+ param_values={"val": threshold},
+ ).show(num=5)
+ param_attacks(ctx, 75)
diff --git a/pyproject.toml b/pyproject.toml
index 25f30b8..9ad7dab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,7 @@
"Programming Language :: Rust",
]
dependencies = [
- "pyarrow>=11.0.0;python_version<'3.14'",
+ "pyarrow>=16.0.0;python_version<'3.14'",
"pyarrow>=22.0.0;python_version>='3.14'",
"typing-extensions;python_version<'3.13'"
]
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 0aa2f27..7dc06eb 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+import uuid
import warnings
from typing import TYPE_CHECKING, Any, Protocol
@@ -27,6 +28,7 @@
except ImportError:
from typing_extensions import deprecated # Python 3.12
+
import pyarrow as pa
from datafusion.catalog import Catalog
@@ -592,9 +594,19 @@
self._convert_file_sort_order(file_sort_order),
)
- def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
+ def sql(
+ self,
+ query: str,
+ options: SQLOptions | None = None,
+ param_values: dict[str, Any] | None = None,
+ **named_params: Any,
+ ) -> DataFrame:
"""Create a :py:class:`~datafusion.DataFrame` from SQL query text.
+ See the online documentation for a description of how to perform
+ parameterized substitution via either the ``param_values`` option
+ or passing in ``named_params``.
+
Note: This API implements DDL statements such as ``CREATE TABLE`` and
``CREATE VIEW`` and DML statements such as ``INSERT INTO`` with in-memory
default implementation.See
@@ -603,15 +615,57 @@
Args:
query: SQL query text.
options: If provided, the query will be validated against these options.
+ param_values: Provides substitution of scalar values in the query
+ after parsing.
+ named_params: Provides string or DataFrame substitution in the query string.
Returns:
DataFrame representation of the SQL query.
"""
- if options is None:
- return DataFrame(self.ctx.sql(query))
- return DataFrame(self.ctx.sql_with_options(query, options.options_internal))
- def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
+ def value_to_scalar(value: Any) -> pa.Scalar:
+ if isinstance(value, pa.Scalar):
+ return value
+ return pa.scalar(value)
+
+ def value_to_string(value: Any) -> str:
+ if isinstance(value, DataFrame):
+ view_name = str(uuid.uuid4()).replace("-", "_")
+ view_name = f"view_{view_name}"
+ view = value.df.into_view(temporary=True)
+ self.ctx.register_table(view_name, view)
+ return view_name
+ return str(value)
+
+ param_values = (
+ {name: value_to_scalar(value) for (name, value) in param_values.items()}
+ if param_values is not None
+ else {}
+ )
+ param_strings = (
+ {name: value_to_string(value) for (name, value) in named_params.items()}
+ if named_params is not None
+ else {}
+ )
+
+ options_raw = options.options_internal if options is not None else None
+
+ return DataFrame(
+ self.ctx.sql_with_options(
+ query,
+ options=options_raw,
+ param_values=param_values,
+ param_strings=param_strings,
+ )
+ )
+
+ def sql_with_options(
+ self,
+ query: str,
+ options: SQLOptions,
+ param_values: dict[str, Any] | None = None,
+ **named_params: Any,
+ ) -> DataFrame:
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.
This function will first validate that the query is allowed by the
@@ -620,11 +674,16 @@
Args:
query: SQL query text.
options: SQL options.
+ param_values: Provides substitution of scalar values in the query
+ after parsing.
+ named_params: Provides string or DataFrame substitution in the query string.
Returns:
DataFrame representation of the SQL query.
"""
- return self.sql(query, options)
+ return self.sql(
+ query, options=options, param_values=param_values, **named_params
+ )
def create_dataframe(
self,
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index 8d1f300..85afd02 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -21,7 +21,7 @@
import pyarrow as pa
import pyarrow.dataset as ds
import pytest
-from datafusion import col, udf
+from datafusion import SessionContext, col, udf
from datafusion.object_store import Http
from pyarrow.csv import write_csv
@@ -550,3 +550,46 @@
rd = result.to_pydict()
assert dict(zip(rd["grp"], rd["count"], strict=False)) == {"a": 3, "b": 2}
+
+
+def test_parameterized_named_params(ctx, tmp_path) -> None:
+ path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
+
+ df = ctx.read_parquet(path)
+ result = ctx.sql(
+ "SELECT COUNT(a) AS cnt, $lit_val as lit_val FROM $replaced_df",
+ lit_val=3,
+ replaced_df=df,
+ ).collect()
+ result = pa.Table.from_batches(result)
+ assert result.to_pydict() == {"cnt": [100], "lit_val": [3]}
+
+
+def test_parameterized_param_values(ctx: SessionContext) -> None:
+ # Test the parameters that should be handled by the parser rather
+ # than our manipulation of the query string by searching for tokens
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3, 4])],
+ names=["a"],
+ )
+
+ ctx.register_record_batches("t", [[batch]])
+ result = ctx.sql("SELECT a FROM t WHERE a < $val", param_values={"val": 3})
+ assert result.to_pydict() == {"a": [1, 2]}
+
+
+def test_parameterized_mixed_query(ctx: SessionContext) -> None:
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3, 4])],
+ names=["a"],
+ )
+ ctx.register_record_batches("t", [[batch]])
+ registered_df = ctx.table("t")
+
+ result = ctx.sql(
+ "SELECT $col_name FROM $df WHERE a < $val",
+ param_values={"val": 3},
+ df=registered_df,
+ col_name="a",
+ )
+ assert result.to_pydict() == {"a": [1, 2]}
diff --git a/src/context.rs b/src/context.rs
index f64cc16..adc6f15 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -55,14 +55,16 @@
use uuid::Uuid;
use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
+use crate::common::data_type::PyScalarValue;
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
-use crate::errors::{py_datafusion_err, PyDataFusionResult};
+use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
+use crate::sql::util::replace_placeholders_with_strings;
use crate::store::StorageContexts;
use crate::table::PyTable;
use crate::udaf::PyAggregateUDF;
@@ -422,27 +424,41 @@
self.ctx.register_udtf(&name, func);
}
- /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
- pub fn sql(&self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
- let result = self.ctx.sql(query);
- let df = wait_for_future(py, result)??;
- Ok(PyDataFrame::new(df))
- }
-
- #[pyo3(signature = (query, options=None))]
+ #[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))]
pub fn sql_with_options(
&self,
- query: &str,
- options: Option<PySQLOptions>,
py: Python,
+ mut query: String,
+ options: Option<PySQLOptions>,
+ param_values: HashMap<String, PyScalarValue>,
+ param_strings: HashMap<String, String>,
) -> PyDataFusionResult<PyDataFrame> {
let options = if let Some(options) = options {
options.options
} else {
SQLOptions::new()
};
- let result = self.ctx.sql_with_options(query, options);
- let df = wait_for_future(py, result)??;
+
+ let param_values = param_values
+ .into_iter()
+ .map(|(name, value)| (name, ScalarValue::from(value)))
+ .collect::<HashMap<_, _>>();
+
+ let state = self.ctx.state();
+ let dialect = state.config().options().sql_parser.dialect.as_str();
+
+ if !param_strings.is_empty() {
+ query = replace_placeholders_with_strings(&query, dialect, param_strings)?;
+ }
+
+ let mut df = wait_for_future(py, async {
+ self.ctx.sql_with_options(&query, options).await
+ })??;
+
+ if !param_values.is_empty() {
+ df = df.with_param_values(param_values)?;
+ }
+
Ok(PyDataFrame::new(df))
}
@@ -550,7 +566,7 @@
(array.schema().as_ref().to_owned(), vec![array])
} else {
- return Err(crate::errors::PyDataFusionError::Common(
+ return Err(PyDataFusionError::Common(
"Expected either a Arrow Array or Arrow Stream in from_arrow().".to_string(),
));
};
@@ -714,7 +730,7 @@
) -> PyDataFusionResult<()> {
let delimiter = delimiter.as_bytes();
if delimiter.len() != 1 {
- return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
+ return Err(PyDataFusionError::PythonError(py_value_err(
"Delimiter must be a single character",
)));
}
@@ -968,7 +984,7 @@
) -> PyDataFusionResult<PyDataFrame> {
let delimiter = delimiter.as_bytes();
if delimiter.len() != 1 {
- return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
+ return Err(PyDataFusionError::PythonError(py_value_err(
"Delimiter must be a single character",
)));
};
diff --git a/src/sql.rs b/src/sql.rs
index 9f1fe81..dea9b56 100644
--- a/src/sql.rs
+++ b/src/sql.rs
@@ -17,3 +17,4 @@
pub mod exceptions;
pub mod logical;
+pub(crate) mod util;
diff --git a/src/sql/util.rs b/src/sql/util.rs
new file mode 100644
index 0000000..5edff00
--- /dev/null
+++ b/src/sql/util.rs
@@ -0,0 +1,87 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+
+use datafusion::common::{exec_err, plan_datafusion_err, DataFusionError};
+use datafusion::logical_expr::sqlparser::dialect::dialect_from_str;
+use datafusion::sql::sqlparser::dialect::Dialect;
+use datafusion::sql::sqlparser::parser::Parser;
+use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer};
+
+/// Looks up the replacement tokens for a placeholder such as `$name`.
+///
+/// Returns `None` when the placeholder has no `$` prefix or no matching entry.
+fn tokens_from_replacements(
+    placeholder: &str,
+    replacements: &HashMap<String, Vec<Token>>,
+) -> Option<Vec<Token>> {
+    let name = placeholder.strip_prefix('$')?;
+    replacements.get(name).cloned()
+}
+
+/// Tokenizes every replacement value with `dialect`, keyed by parameter name.
+fn get_tokens_for_string_replacement(
+    dialect: &dyn Dialect,
+    replacements: HashMap<String, String>,
+) -> Result<HashMap<String, Vec<Token>>, DataFusionError> {
+    let mut tokenized = HashMap::with_capacity(replacements.len());
+    for (name, value) in replacements {
+        let tokens = Tokenizer::new(dialect, &value)
+            .tokenize()
+            .map_err(|err| DataFusionError::External(err.into()))?;
+        tokenized.insert(name, tokens);
+    }
+    Ok(tokenized)
+}
+
+/// Substitutes `$name` placeholders in `query` with the tokenized text from
+/// `replacements`, leaving placeholders without a matching entry untouched.
+/// Tokenizing and parsing use the named `dialect`; the rewritten token stream
+/// must parse to exactly one statement, which is returned as a SQL string.
+pub(crate) fn replace_placeholders_with_strings(
+    query: &str,
+    dialect: &str,
+    replacements: HashMap<String, String>,
+) -> Result<String, DataFusionError> {
+    let Some(boxed_dialect) = dialect_from_str(dialect) else {
+        return Err(plan_datafusion_err!("Unsupported SQL dialect: {dialect}."));
+    };
+    let dialect: &dyn Dialect = boxed_dialect.as_ref();
+    let replacements = get_tokens_for_string_replacement(dialect, replacements)?;
+    let query_tokens = Tokenizer::new(dialect, query)
+        .tokenize()
+        .map_err(|err| DataFusionError::External(err.into()))?;
+
+    let mut substituted = Vec::with_capacity(query_tokens.len());
+    for token in query_tokens {
+        let swap = match &token {
+            Token::Placeholder(name) => tokens_from_replacements(name, &replacements),
+            _ => None,
+        };
+        substituted.extend(swap.unwrap_or_else(|| vec![token]));
+    }
+
+    let mut statements = Parser::new(dialect)
+        .with_tokens(substituted)
+        .parse_statements()
+        .map_err(|err| DataFusionError::External(Box::new(err)))?;
+    match (statements.pop(), statements.is_empty()) {
+        (Some(statement), true) => Ok(statement.to_string()),
+        _ => exec_err!("placeholder replacement should return exactly one statement"),
+    }
+}