| // Copyright 2022 The Blaze Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| use crate::{ |
| jni_call, jni_call_static, jni_delete_local_ref, jni_new_direct_byte_buffer, |
| jni_new_global_ref, jni_new_string, |
| }; |
| use async_trait::async_trait; |
| use datafusion::arrow::datatypes::SchemaRef; |
| use datafusion::arrow::error::ArrowError; |
| use datafusion::arrow::record_batch::RecordBatch; |
| use datafusion::error::DataFusionError; |
| use datafusion::error::Result; |
| use datafusion::execution::context::TaskContext; |
| use datafusion::physical_expr::PhysicalSortExpr; |
| use datafusion::physical_plan::memory::MemoryStream; |
| use datafusion::physical_plan::metrics::{ |
| BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, |
| }; |
| use datafusion::physical_plan::stream::RecordBatchStreamAdapter; |
| use datafusion::physical_plan::{ |
| DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, |
| }; |
| use futures::StreamExt; |
| use futures::TryFutureExt; |
| use futures::TryStreamExt; |
| use jni::objects::{GlobalRef, JObject}; |
| use std::any::Any; |
| use std::fmt::Formatter; |
| use std::io::Cursor; |
| use std::sync::Arc; |
| |
/// Physical operator that serializes its input batches into compressed Arrow
/// IPC buffers and hands each buffer to a JVM-side consumer (looked up via
/// `JniBridge.getResource`). Produces no output rows itself.
#[derive(Debug)]
pub struct IpcWriterExec {
    // Upstream plan whose record batches are written out as IPC.
    input: Arc<dyn ExecutionPlan>,
    // Key used to resolve the JVM-side `Scala Function1` consumer through
    // `JniBridge.getResource` at execution time.
    ipc_consumer_resource_id: String,
    // Per-operator execution metrics (elapsed compute, output rows, ...).
    metrics: ExecutionPlanMetricsSet,
}
| |
| impl IpcWriterExec { |
| pub fn new(input: Arc<dyn ExecutionPlan>, ipc_consumer_resource_id: String) -> Self { |
| Self { |
| input, |
| ipc_consumer_resource_id, |
| metrics: ExecutionPlanMetricsSet::new(), |
| } |
| } |
| } |
| |
| #[async_trait] |
| impl ExecutionPlan for IpcWriterExec { |
| fn as_any(&self) -> &dyn Any { |
| self |
| } |
| |
| fn schema(&self) -> SchemaRef { |
| self.input.schema() |
| } |
| |
| fn output_partitioning(&self) -> Partitioning { |
| self.input.output_partitioning() |
| } |
| |
| fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { |
| self.input.output_ordering() |
| } |
| |
| fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { |
| vec![self.input.clone()] |
| } |
| |
| fn with_new_children( |
| self: Arc<Self>, |
| children: Vec<Arc<dyn ExecutionPlan>>, |
| ) -> Result<Arc<dyn ExecutionPlan>> { |
| if children.len() != 1 { |
| return Err(DataFusionError::Plan( |
| "IpcWriterExec expects one children".to_string(), |
| )); |
| } |
| Ok(Arc::new(IpcWriterExec::new( |
| self.input.clone(), |
| self.ipc_consumer_resource_id.clone(), |
| ))) |
| } |
| |
| fn execute( |
| &self, |
| partition: usize, |
| context: Arc<TaskContext>, |
| ) -> Result<SendableRecordBatchStream> { |
| let baseline_metrics = BaselineMetrics::new(&self.metrics, 0); |
| let ipc_consumer = jni_new_global_ref!(jni_call_static!( |
| JniBridge.getResource( |
| jni_new_string!(&self.ipc_consumer_resource_id)? |
| ) -> JObject |
| )?)?; |
| let input = self.input.execute(partition, context.clone())?; |
| |
| Ok(Box::pin(RecordBatchStreamAdapter::new( |
| self.schema(), |
| futures::stream::once( |
| write_ipc( |
| input, |
| context.session_config().batch_size(), |
| ipc_consumer, |
| baseline_metrics, |
| ) |
| .map_err(|e| ArrowError::ExternalError(Box::new(e))), |
| ) |
| .try_flatten(), |
| ))) |
| } |
| |
| fn metrics(&self) -> Option<MetricsSet> { |
| Some(self.metrics.clone_inner()) |
| } |
| |
| fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { |
| match t { |
| DisplayFormatType::Default => { |
| write!(f, "IpcWriterExec") |
| } |
| } |
| } |
| |
| fn statistics(&self) -> Statistics { |
| todo!() |
| } |
| } |
| |
| pub async fn write_ipc( |
| mut input: SendableRecordBatchStream, |
| batch_size: usize, |
| ipc_consumer: GlobalRef, |
| metrics: BaselineMetrics, |
| ) -> Result<SendableRecordBatchStream> { |
| let schema = input.schema(); |
| let mut batches = vec![]; |
| let mut num_rows = 0; |
| |
| macro_rules! flush_batches { |
| () => {{ |
| let timer = metrics.elapsed_compute().timer(); |
| let batch = RecordBatch::concat(&schema, &batches)?; |
| metrics.record_output(num_rows); |
| batches.clear(); |
| num_rows = 0; |
| |
| let mut buffer = vec![]; |
| crate::util::ipc::write_ipc_compressed( |
| &batch, |
| &mut Cursor::new(&mut buffer), |
| )?; |
| std::mem::drop(timer); |
| |
| let jbuf = jni_new_direct_byte_buffer!(&mut buffer)?; |
| let consumed = jni_call!( |
| ScalaFunction1(ipc_consumer.as_obj()).apply(jbuf) -> JObject |
| )?; |
| jni_delete_local_ref!(consumed)?; |
| jni_delete_local_ref!(jbuf.into())?; |
| }} |
| } |
| |
| while let Some(batch) = input.next().await { |
| let batch = batch?; |
| |
| if batch.num_rows() == 0 { |
| continue; |
| } |
| if num_rows + batch.num_rows() > batch_size { |
| flush_batches!(); |
| } |
| num_rows += batch.num_rows(); |
| batches.push(batch); |
| } |
| if num_rows > 0 { |
| flush_batches!(); |
| } |
| assert_eq!(num_rows, 0); |
| |
| // ipc writer always has empty output |
| Ok(Box::pin(MemoryStream::try_new(vec![], schema, None)?)) |
| } |