blob: 86c1b3a45de546baf78971c54c644c353c273d20 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::planner::{make_execution_graph, PyExecutionGraph};
use crate::shuffle::ShuffleCodec;
use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::TaskContext;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::{displayable, ExecutionPlan};
use datafusion::prelude::*;
use datafusion_proto::physical_plan::AsExecutionPlan;
use datafusion_proto::protobuf;
use futures::StreamExt;
use prost::Message;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyTuple};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
/// A result set handed back to Python: one `PyObject` per Arrow
/// `RecordBatch`, each converted via `to_pyarrow`.
type PyResultSet = Vec<PyObject>;
/// Python-visible wrapper around a Python-side session context object.
///
/// `py_ctx` is an opaque Python object (presumably a datafusion-python
/// `SessionContext` — it must expose a `sql` method; see `run_sql`).
#[pyclass(name = "Context", module = "datafusion_ray", subclass)]
pub struct PyContext {
// The wrapped Python context; all SQL calls are delegated to it.
pub(crate) py_ctx: PyObject,
}
/// Convert a datafusion-python plan object into a native `ExecutionPlan`.
///
/// The Python object must expose a `to_proto` method returning protobuf
/// bytes of a `PhysicalPlanNode`; those bytes are decoded and materialized
/// into a physical plan using a throwaway `SessionContext`/`RuntimeEnv`
/// and the Ray `ShuffleCodec`.
pub(crate) fn execution_plan_from_pyany(
    py_plan: &Bound<PyAny>,
) -> PyResult<Arc<dyn ExecutionPlan>> {
    // Ask the Python side for its protobuf encoding of the plan.
    let proto_obj = py_plan.call_method0("to_proto")?;
    let encoded: &[u8] = proto_obj.extract()?;

    let node = protobuf::PhysicalPlanNode::try_decode(encoded).map_err(|e| {
        PyRuntimeError::new_err(format!(
            "Unable to decode physical plan protobuf message: {}",
            e
        ))
    })?;

    // A fresh context/runtime is sufficient for rebuilding the plan tree.
    let session = SessionContext::new();
    let env = RuntimeEnv::default();
    let shuffle_codec = ShuffleCodec {};
    node.try_into_physical_plan(&session, &env, &shuffle_codec)
        .map_err(Into::into)
}
#[pymethods]
impl PyContext {
    /// Wrap an existing Python session context object.
    #[new]
    pub fn new(session_ctx: PyObject) -> Result<Self> {
        Ok(Self {
            py_ctx: session_ctx,
        })
    }

    /// Execute SQL directly against the DataFusion context. Useful for statements
    /// such as "create view" or "drop view".
    pub fn sql(&self, query: &str, py: Python) -> PyResult<()> {
        println!("Executing {}", query);
        // Propagate failures from the underlying context: previously the
        // Result was bound to `_df` without `?`, silently swallowing errors.
        let _df = self.run_sql(query, py)?;
        Ok(())
    }

    /// Delegate a SQL string to the wrapped Python context's `sql` method.
    fn run_sql(&self, query: &str, py: Python) -> PyResult<Py<PyAny>> {
        let args = PyTuple::new_bound(py, [query]);
        self.py_ctx.call_method1(py, "sql", args)
    }

    /// Plan a distributed SELECT query for executing against the Ray workers.
    ///
    /// `plan` is a datafusion-python physical plan object exposing `to_proto`;
    /// it is decoded into a native plan and split into query stages.
    pub fn plan(&self, plan: &Bound<PyAny>) -> PyResult<PyExecutionGraph> {
        let plan = execution_plan_from_pyany(plan)?;
        // `plan` is not used again, so hand the Arc over without cloning.
        let graph = make_execution_graph(plan)?;

        // Debug logging: print each query stage in deterministic id order.
        let mut stages = graph.query_stages.values().collect::<Vec<_>>();
        stages.sort_by_key(|s| s.id);
        for stage in stages {
            println!(
                "Query stage #{}:\n{}",
                stage.id,
                displayable(stage.plan.as_ref()).indent(false)
            );
        }

        Ok(PyExecutionGraph::new(graph))
    }

    /// Execute a partition of a query plan. This will typically be executing a shuffle write and write the results to disk
    pub fn execute_partition(
        &self,
        plan: &Bound<'_, PyBytes>,
        part: usize,
        py: Python,
    ) -> PyResult<PyResultSet> {
        execute_partition(plan, part, py)
    }
}
#[pyfunction]
pub fn execute_partition(
plan_bytes: &Bound<'_, PyBytes>,
part: usize,
py: Python,
) -> PyResult<PyResultSet> {
let plan = deserialize_execution_plan(plan_bytes)?;
_execute_partition(plan, part)
.unwrap()
.into_iter()
.map(|batch| batch.to_pyarrow(py))
.collect()
}
/// Serialize a physical plan to protobuf bytes using the Ray shuffle codec,
/// returning them as a Python `bytes` object.
pub fn serialize_execution_plan(
    plan: Arc<dyn ExecutionPlan>,
    py: Python,
) -> PyResult<Bound<'_, PyBytes>> {
    let codec = ShuffleCodec {};
    // `plan` is already owned here, so it can be consumed directly;
    // the previous `plan.clone()` was a redundant Arc refcount bump.
    let proto =
        datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan(plan, &codec)?;
    let bytes = proto.encode_to_vec();
    Ok(PyBytes::new_bound(py, &bytes))
}
/// Deserialize protobuf bytes produced by `serialize_execution_plan` back
/// into an executable physical plan, using a fresh `SessionContext` and the
/// Ray shuffle codec to resolve extension nodes.
pub fn deserialize_execution_plan(proto_msg: &Bound<PyBytes>) -> PyResult<Arc<dyn ExecutionPlan>> {
    let bytes: &[u8] = proto_msg.extract()?;
    let proto_plan =
        datafusion_proto::protobuf::PhysicalPlanNode::try_decode(bytes).map_err(|e| {
            // The previous message said "logical node", but this function
            // decodes a *physical* plan node.
            PyRuntimeError::new_err(format!(
                "Unable to decode physical plan from serialized bytes: {}",
                e
            ))
        })?;
    let ctx = SessionContext::new();
    let codec = ShuffleCodec {};
    let plan = proto_plan
        .try_into_physical_plan(&ctx, &ctx.runtime_env(), &codec)
        .map_err(DataFusionError::from)?;
    Ok(plan)
}
/// Execute a partition of a query plan. This will typically be executing a shuffle write and
/// write the results to disk, except for the final query stage, which will return the data.
/// inputs is a list of tuples of (stage_id, partition_id, bytes) for each input partition.
fn _execute_partition(plan: Arc<dyn ExecutionPlan>, part: usize) -> Result<Vec<RecordBatch>> {
let ctx = Arc::new(TaskContext::new(
Some("task_id".to_string()),
"session_id".to_string(),
SessionConfig::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
Arc::new(RuntimeEnv::default()),
));
// create a Tokio runtime to run the async code
let rt = Runtime::new().unwrap();
let fut: JoinHandle<Result<Vec<RecordBatch>>> = rt.spawn(async move {
let mut stream = plan.execute(part, ctx)?;
let mut results = vec![];
while let Some(result) = stream.next().await {
results.push(result?);
}
Ok(results)
});
// block and wait on future
let results = rt.block_on(fut).unwrap()?;
Ok(results)
}