blob: 67ad8dcc1d44658c1523824f65fb1b116edf2bb6 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::{collections::HashMap, sync::Arc};
use datafusion_expr::ScalarUDFImpl;
use pyo3::prelude::*;
use sedona::context::SedonaContext;
use tokio::runtime::Runtime;
use crate::{
dataframe::InternalDataFrame,
datasource::PyExternalFormat,
error::PySedonaError,
import_from::{import_ffi_scalar_udf, import_table_provider_from_any},
runtime::wait_for_future,
udf::PySedonaScalarUdf,
};
/// Python-visible wrapper around a [`SedonaContext`] and the Tokio runtime
/// used to drive its asynchronous operations.
#[pyclass]
pub struct InternalContext {
    /// The wrapped Sedona context that the methods of this class operate on.
    pub inner: SedonaContext,
    /// Runtime passed to `wait_for_future` to run futures; a clone of this
    /// `Arc` is handed to every `InternalDataFrame` created by this context.
    pub runtime: Arc<Runtime>,
}
#[pymethods]
impl InternalContext {
/// Creates a context backed by a freshly built multi-threaded Tokio runtime.
///
/// The runtime is built with `enable_all()` and is then used to run
/// `SedonaContext::new_local_interactive()` to completion.
///
/// # Errors
///
/// Returns `PySedonaError::SedonaPython` if the runtime cannot be built, or
/// the error produced while waiting for / creating the Sedona context.
#[new]
fn new(py: Python) -> Result<Self, PySedonaError> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();

    let runtime = match builder.build() {
        Ok(built) => built,
        Err(e) => {
            return Err(PySedonaError::SedonaPython(format!(
                "Failed to build multithreaded runtime: {e}"
            )));
        }
    };

    // The context is created on the runtime before the runtime is shared.
    let context = wait_for_future(py, &runtime, SedonaContext::new_local_interactive())??;

    Ok(Self {
        inner: context,
        runtime: Arc::new(runtime),
    })
}
pub fn view<'py>(
&self,
py: Python<'py>,
name: &str,
) -> Result<InternalDataFrame, PySedonaError> {
let df = wait_for_future(py, &self.runtime, self.inner.ctx.table(name))??;
Ok(InternalDataFrame::new(df, self.runtime.clone()))
}
pub fn create_data_frame<'py>(
&self,
py: Python<'py>,
obj: &Bound<PyAny>,
requested_schema: Option<&Bound<PyAny>>,
) -> Result<InternalDataFrame, PySedonaError> {
let provider = import_table_provider_from_any(py, obj, requested_schema)?;
let df = self.inner.ctx.read_table(provider)?;
Ok(InternalDataFrame::new(df, self.runtime.clone()))
}
/// Reads the Parquet data at `table_paths` into a data frame.
///
/// `options` maps option names to arbitrary Python values. Entries whose
/// value is `None` are skipped; every other value is converted with
/// Python's `str()` before being passed to
/// `GeoParquetReadOptions::from_table_options`.
///
/// # Errors
///
/// Returns an error if an option value cannot be converted to a string,
/// if the converted options are rejected (`"Invalid table options: ..."`),
/// or if waiting for / performing the read fails.
pub fn read_parquet<'py>(
    &self,
    py: Python<'py>,
    table_paths: Vec<String>,
    options: HashMap<String, PyObject>,
) -> Result<InternalDataFrame, PySedonaError> {
    // Convert Python options to strings, filtering out None values
    let mut rust_options: HashMap<String, String> = HashMap::with_capacity(options.len());
    for (key, value) in options {
        if value.is_none(py) {
            continue;
        }
        // A failing `str()` is reported to the caller rather than silently
        // dropping the option, which would hide a misconfigured read.
        let rendered = value.bind(py).str()?.extract::<String>()?;
        rust_options.insert(key, rendered);
    }

    let geo_options =
        sedona_geoparquet::provider::GeoParquetReadOptions::from_table_options(rust_options)
            .map_err(|e| PySedonaError::SedonaPython(format!("Invalid table options: {e}")))?;

    let df = wait_for_future(
        py,
        &self.runtime,
        self.inner.read_parquet(table_paths, geo_options),
    )??;
    Ok(InternalDataFrame::new(df, self.runtime.clone()))
}
/// Reads `table_paths` using a format described by a Python object.
///
/// `format_spec` must provide a `__sedona_external_format__()` method whose
/// result can be extracted as a `PyExternalFormat`. `check_extension` is
/// forwarded unchanged to the underlying reader.
///
/// # Errors
///
/// Returns an error if the method call or extraction fails, or if waiting
/// for / performing the read fails.
pub fn read_external_format<'py>(
    &self,
    py: Python<'py>,
    format_spec: Bound<PyAny>,
    table_paths: Vec<String>,
    check_extension: bool,
) -> Result<InternalDataFrame, PySedonaError> {
    let exported = format_spec.call_method0("__sedona_external_format__")?;
    let spec = exported.extract::<PyExternalFormat>()?;

    let pending =
        self.inner
            .read_external_format(Arc::new(spec), table_paths, None, check_extension);
    let table = wait_for_future(py, &self.runtime, pending)??;

    Ok(InternalDataFrame::new(table, Arc::clone(&self.runtime)))
}
pub fn sql<'py>(
&self,
py: Python<'py>,
query: &str,
) -> Result<InternalDataFrame, PySedonaError> {
let df = wait_for_future(py, &self.runtime, self.inner.sql(query))??;
Ok(InternalDataFrame::new(df, self.runtime.clone()))
}
/// Deregisters `table_ref` from the session context.
///
/// The value returned by the deregistration call is discarded.
///
/// # Errors
///
/// Returns the converted error if deregistration fails.
pub fn drop_view(&self, table_ref: &str) -> Result<(), PySedonaError> {
    match self.inner.ctx.deregister_table(table_ref) {
        Ok(_previous) => Ok(()),
        Err(err) => Err(err.into()),
    }
}
/// Looks up a Sedona scalar UDF by `name` and returns a Python wrapper
/// holding a clone of it.
///
/// # Errors
///
/// Returns `PySedonaError::SedonaPython` if no scalar UDF with that name
/// is present in the context's function set.
pub fn scalar_udf(&self, name: &str) -> Result<PySedonaScalarUdf, PySedonaError> {
    match self.inner.functions.scalar_udf(name) {
        Some(found) => Ok(PySedonaScalarUdf {
            inner: found.clone(),
        }),
        None => Err(PySedonaError::SedonaPython(format!(
            "Sedona scalar UDF with name {name} was not found"
        ))),
    }
}
/// Registers a scalar UDF supplied from Python.
///
/// Two kinds of objects are accepted, checked in this order:
///
/// * objects with a `__sedona_internal_udf__` attribute: the attribute is
///   called and its result extracted as a `PySedonaScalarUdf`, which is
///   inserted into the context's function set; the entry then stored under
///   that name is registered with the session context;
/// * objects with a `__datafusion_scalar_udf__` attribute: imported via
///   `import_ffi_scalar_udf` and registered with the session context only.
///
/// # Errors
///
/// Returns an error if attribute access, the call, extraction or the FFI
/// import fails, if the UDF cannot be found again right after insertion,
/// or if `udf` implements neither protocol.
pub fn register_udf(&mut self, udf: Bound<PyAny>) -> Result<(), PySedonaError> {
    if udf.hasattr("__sedona_internal_udf__")? {
        let py_scalar_udf = udf
            .getattr("__sedona_internal_udf__")?
            .call0()?
            .extract::<PySedonaScalarUdf>()?;
        let name = py_scalar_udf.inner.name();
        self.inner
            .functions
            .insert_scalar_udf(py_scalar_udf.inner.clone());

        // Register whatever the function set now holds under this name.
        // A missing entry is reported as an error instead of panicking
        // across the Python boundary.
        let registered = self
            .inner
            .functions
            .scalar_udf(name)
            .ok_or_else(|| {
                PySedonaError::SedonaPython(format!(
                    "Sedona scalar UDF with name {name} was not found after registration"
                ))
            })?
            .clone();
        self.inner.ctx.register_udf(registered.into());
        Ok(())
    } else if udf.hasattr("__datafusion_scalar_udf__")? {
        let scalar_udf = import_ffi_scalar_udf(&udf)?;
        self.inner.ctx.register_udf(scalar_udf);
        Ok(())
    } else {
        Err(PySedonaError::SedonaPython(
            "Expected an object implementing __sedona_internal_udf__ or __datafusion_scalar_udf__"
                .to_string(),
        ))
    }
}
}