blob: 73461c1d71bb5770d3ab9d9078e35aad3e006dd6 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use ballista_core::serde::{
BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec,
};
use datafusion::logical_expr::ScalarUDF;
use datafusion_proto::logical_plan::LogicalExtensionCodec;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use pyo3::types::{PyAnyMethods, PyBytes, PyBytesMethods};
use pyo3::{PyObject, PyResult, Python};
use std::fmt::Debug;
use std::sync::Arc;
static MODULE: &str = "cloudpickle";
static FUN_LOADS: &str = "loads";
static FUN_DUMPS: &str = "dumps";
/// Serde protocol for UD(a)F
#[derive(Debug)]
struct CloudPickle {
loads: PyObject,
dumps: PyObject,
}
impl CloudPickle {
pub fn try_new(py: Python<'_>) -> PyResult<Self> {
let module = py.import(MODULE)?;
let loads = module.getattr(FUN_LOADS)?.unbind();
let dumps = module.getattr(FUN_DUMPS)?.unbind();
Ok(Self { loads, dumps })
}
pub fn pickle(&self, py: Python<'_>, py_any: &PyObject) -> PyResult<Vec<u8>> {
let b: PyObject = self.dumps.call1(py, (py_any,))?.extract(py)?;
let blob = b.downcast_bound::<PyBytes>(py)?.clone();
Ok(blob.as_bytes().to_owned())
}
pub fn unpickle(&self, py: Python<'_>, blob: &[u8]) -> PyResult<PyObject> {
let t: PyObject = self.loads.call1(py, (blob,))?.extract(py)?;
Ok(t)
}
}
pub struct PyLogicalCodec {
inner: BallistaLogicalExtensionCodec,
cloudpickle: CloudPickle,
}
impl PyLogicalCodec {
pub fn try_new(py: Python<'_>) -> PyResult<Self> {
Ok(Self {
inner: BallistaLogicalExtensionCodec::default(),
cloudpickle: CloudPickle::try_new(py)?,
})
}
}
impl Debug for PyLogicalCodec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PyLogicalCodec").finish()
}
}
impl LogicalExtensionCodec for PyLogicalCodec {
fn try_decode(
&self,
buf: &[u8],
inputs: &[datafusion::logical_expr::LogicalPlan],
ctx: &datafusion::prelude::SessionContext,
) -> datafusion::error::Result<datafusion::logical_expr::Extension> {
self.inner.try_decode(buf, inputs, ctx)
}
fn try_encode(
&self,
node: &datafusion::logical_expr::Extension,
buf: &mut Vec<u8>,
) -> datafusion::error::Result<()> {
self.inner.try_encode(node, buf)
}
fn try_decode_table_provider(
&self,
buf: &[u8],
table_ref: &datafusion::sql::TableReference,
schema: datafusion::arrow::datatypes::SchemaRef,
ctx: &datafusion::prelude::SessionContext,
) -> datafusion::error::Result<std::sync::Arc<dyn datafusion::catalog::TableProvider>>
{
self.inner
.try_decode_table_provider(buf, table_ref, schema, ctx)
}
fn try_encode_table_provider(
&self,
table_ref: &datafusion::sql::TableReference,
node: std::sync::Arc<dyn datafusion::catalog::TableProvider>,
buf: &mut Vec<u8>,
) -> datafusion::error::Result<()> {
self.inner.try_encode_table_provider(table_ref, node, buf)
}
fn try_decode_file_format(
&self,
buf: &[u8],
ctx: &datafusion::prelude::SessionContext,
) -> datafusion::error::Result<
std::sync::Arc<dyn datafusion::datasource::file_format::FileFormatFactory>,
> {
self.inner.try_decode_file_format(buf, ctx)
}
fn try_encode_file_format(
&self,
buf: &mut Vec<u8>,
node: std::sync::Arc<dyn datafusion::datasource::file_format::FileFormatFactory>,
) -> datafusion::error::Result<()> {
self.inner.try_encode_file_format(buf, node)
}
fn try_decode_udf(
&self,
name: &str,
buf: &[u8],
) -> datafusion::error::Result<std::sync::Arc<datafusion::logical_expr::ScalarUDF>>
{
// use cloud pickle to decode udf
self.inner.try_decode_udf(name, buf)
}
fn try_encode_udf(
&self,
node: &datafusion::logical_expr::ScalarUDF,
buf: &mut Vec<u8>,
) -> datafusion::error::Result<()> {
// use cloud pickle to decode udf
self.inner.try_encode_udf(node, buf)
}
fn try_decode_udaf(
&self,
name: &str,
buf: &[u8],
) -> datafusion::error::Result<std::sync::Arc<datafusion::logical_expr::AggregateUDF>>
{
self.inner.try_decode_udaf(name, buf)
}
fn try_encode_udaf(
&self,
node: &datafusion::logical_expr::AggregateUDF,
buf: &mut Vec<u8>,
) -> datafusion::error::Result<()> {
self.inner.try_encode_udaf(node, buf)
}
fn try_decode_udwf(
&self,
name: &str,
buf: &[u8],
) -> datafusion::error::Result<std::sync::Arc<datafusion::logical_expr::WindowUDF>>
{
self.inner.try_decode_udwf(name, buf)
}
fn try_encode_udwf(
&self,
node: &datafusion::logical_expr::WindowUDF,
buf: &mut Vec<u8>,
) -> datafusion::error::Result<()> {
self.inner.try_encode_udwf(node, buf)
}
}
pub struct PyPhysicalCodec {
inner: BallistaPhysicalExtensionCodec,
cloudpickle: CloudPickle,
}
impl Debug for PyPhysicalCodec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PyPhysicalCodec").finish()
}
}
impl PyPhysicalCodec {
pub fn try_new(py: Python<'_>) -> PyResult<Self> {
Ok(Self {
inner: BallistaPhysicalExtensionCodec::default(),
cloudpickle: CloudPickle::try_new(py)?,
})
}
}
impl PhysicalExtensionCodec for PyPhysicalCodec {
fn try_decode(
&self,
buf: &[u8],
inputs: &[std::sync::Arc<dyn datafusion::physical_plan::ExecutionPlan>],
registry: &dyn datafusion::execution::FunctionRegistry,
) -> datafusion::error::Result<
std::sync::Arc<dyn datafusion::physical_plan::ExecutionPlan>,
> {
self.inner.try_decode(buf, inputs, registry)
}
fn try_encode(
&self,
node: std::sync::Arc<dyn datafusion::physical_plan::ExecutionPlan>,
buf: &mut Vec<u8>,
) -> datafusion::error::Result<()> {
self.inner.try_encode(node, buf)
}
fn try_decode_udf(
&self,
name: &str,
_buf: &[u8],
) -> datafusion::common::Result<Arc<ScalarUDF>> {
// use cloudpickle here
datafusion::common::not_impl_err!(
"PhysicalExtensionCodec is not provided for scalar function {name}"
)
}
fn try_encode_udf(
&self,
_node: &ScalarUDF,
_buf: &mut Vec<u8>,
) -> datafusion::common::Result<()> {
// use cloudpickle here
Ok(())
}
}