fix: mangled errors (#1377)
closes #1226
diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py
index dd4c824..71c08da 100644
--- a/python/tests/test_catalog.py
+++ b/python/tests/test_catalog.py
@@ -81,6 +81,12 @@
return name in self.tables
+class CustomErrorSchemaProvider(CustomSchemaProvider):
+ def table(self, name: str) -> Table | None:
+ message = f"{name} is not an acceptable name"
+ raise ValueError(message)
+
+
class CustomCatalogProvider(dfn.catalog.CatalogProvider):
def __init__(self):
self.schemas = {"my_schema": CustomSchemaProvider()}
@@ -219,6 +225,33 @@
schema.deregister_table(table_name)
+def test_exception_not_mangled(ctx: SessionContext):
+ """Test registering all python providers and running a query against them."""
+
+ catalog_name = "custom_catalog"
+ schema_name = "custom_schema"
+
+ ctx.register_catalog_provider(catalog_name, CustomCatalogProvider())
+
+ catalog = ctx.catalog(catalog_name)
+
+ # Clean out previous schemas if they exist so we can start clean
+ for schema_name in catalog.schema_names():
+ catalog.deregister_schema(schema_name, cascade=False)
+
+ catalog.register_schema(schema_name, CustomErrorSchemaProvider())
+
+ schema = catalog.schema(schema_name)
+
+ for table_name in schema.table_names():
+ schema.deregister_table(table_name)
+
+ schema.register_table("test_table", create_dataset())
+
+ with pytest.raises(ValueError, match="^test_table is not an acceptable name$"):
+ ctx.sql(f"select * from {catalog_name}.{schema_name}.test_table")
+
+
def test_in_end_to_end_python_providers(ctx: SessionContext):
"""Test registering all python providers and running a query against them."""
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index 48c3746..12710cf 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -29,7 +29,10 @@
def test_no_table(ctx):
- with pytest.raises(Exception, match="DataFusion error"):
+ with pytest.raises(
+ ValueError,
+ match="^Error during planning: table 'datafusion.public.b' not found$",
+ ):
ctx.sql("SELECT a FROM b").collect()
diff --git a/src/catalog.rs b/src/catalog.rs
index b5b9839..d10d5b8 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -364,7 +364,8 @@
&self,
name: &str,
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
- self.table_inner(name).map_err(to_datafusion_err)
+ self.table_inner(name)
+ .map_err(|e| DataFusionError::External(Box::new(e)))
}
fn register_table(
diff --git a/src/context.rs b/src/context.rs
index 89bbe93..fc3d595 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -65,7 +65,9 @@
use crate::common::data_type::PyScalarValue;
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
-use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
+use crate::errors::{
+ from_datafusion_error, py_datafusion_err, PyDataFusionError, PyDataFusionResult,
+};
use crate::expr::sort_expr::PySortExpr;
use crate::options::PyCsvReadOptions;
use crate::physical_plan::PyExecutionPlan;
@@ -465,7 +467,8 @@
let mut df = wait_for_future(py, async {
self.ctx.sql_with_options(&query, options).await
- })??;
+ })?
+ .map_err(from_datafusion_error)?;
if !param_values.is_empty() {
df = df.with_param_values(param_values)?;
diff --git a/src/errors.rs b/src/errors.rs
index d1b5180..1080721 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -22,7 +22,7 @@
use datafusion::arrow::error::ArrowError;
use datafusion::error::DataFusionError as InnerDataFusionError;
use prost::EncodeError;
-use pyo3::exceptions::PyException;
+use pyo3::exceptions::{PyException, PyValueError};
use pyo3::PyErr;
pub type PyDataFusionResult<T> = std::result::Result<T, PyDataFusionError>;
@@ -96,3 +96,13 @@
pub fn to_datafusion_err(e: impl Debug) -> InnerDataFusionError {
InnerDataFusionError::Execution(format!("{e:?}"))
}
+
+pub fn from_datafusion_error(err: InnerDataFusionError) -> PyErr {
+ match err {
+ InnerDataFusionError::External(boxed) => match boxed.downcast::<PyErr>() {
+ Ok(py_err) => *py_err,
+ Err(original_boxed) => PyValueError::new_err(format!("{original_boxed}")),
+ },
+ _ => PyValueError::new_err(format!("{err}")),
+ }
+}