| // Copyright 2022 The Blaze Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| use std::{ |
| future::Future, |
| panic::AssertUnwindSafe, |
| pin::Pin, |
| sync::{Arc, Weak}, |
| task::{ready, Context, Poll}, |
| time::Instant, |
| }; |
| |
| use arrow::{array::RecordBatch, datatypes::SchemaRef}; |
| use blaze_cudf_bridge::plans::{split_table::CudfSplitTablePlan, CudfPlan}; |
| use blaze_jni_bridge::{ |
| conf::{self, BooleanConf}, |
| is_task_running, |
| }; |
| use datafusion::{ |
| common::Result, |
| execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}, |
| physical_plan::{ |
| metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, Time}, |
| stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}, |
| ExecutionPlan, |
| }, |
| }; |
| use datafusion_ext_commons::{ |
| arrow::{array_size::BatchSize, coalesce::coalesce_batches_unchecked}, |
| batch_size, df_execution_err, suggested_batch_mem_size, |
| }; |
| use futures::{Stream, StreamExt}; |
| use futures_util::FutureExt; |
| use once_cell::sync::OnceCell; |
| use parking_lot::Mutex; |
| use tokio::sync::mpsc::Sender; |
| |
| use crate::{ |
| common::{column_pruning::ExecuteWithColumnPruning, timer_helper::TimerHelper}, |
| cudf::plan::convert_datafusion_plan_to_cudf, |
| memmgr::metrics::SpillMetrics, |
| }; |
| |
/// Per-partition execution state used by an operator: the task context, the
/// partition being executed, the schema of the produced stream, and the metrics
/// registered for this partition.
pub struct ExecutionContext {
    /// Task context handed to child plans in `execute` / `execute_projected`.
    task_ctx: Arc<TaskContext>,
    /// Index of the partition this context executes.
    partition_id: usize,
    /// Schema of the stream produced with this context (e.g. by `output_with_sender`).
    output_schema: SchemaRef,
    /// Metrics set in which timers and counters are registered.
    metrics: ExecutionPlanMetricsSet,
    /// Baseline metrics (elapsed compute, output rows) for this partition.
    baseline_metrics: BaselineMetrics,
    /// Spill metrics, created on first `spill_metrics()` call; the cell is shared
    /// with contexts derived via `with_new_output_schema`.
    spill_metrics: Arc<OnceCell<SpillMetrics>>,
    /// Input batch statistics, resolved on first `stat_input()` call; holds `None`
    /// when statistics are disabled. Shared like `spill_metrics`.
    input_stat_metrics: Arc<OnceCell<Option<InputBatchStatistics>>>,
}
| |
| impl ExecutionContext { |
| pub fn new( |
| task_ctx: Arc<TaskContext>, |
| partition_id: usize, |
| output_schema: SchemaRef, |
| metrics: &ExecutionPlanMetricsSet, |
| ) -> Arc<Self> { |
| Arc::new(Self { |
| task_ctx, |
| partition_id, |
| output_schema, |
| baseline_metrics: BaselineMetrics::new(&metrics, partition_id), |
| metrics: metrics.clone(), |
| spill_metrics: Arc::default(), |
| input_stat_metrics: Arc::default(), |
| }) |
| } |
| |
| pub fn with_new_output_schema(&self, output_schema: SchemaRef) -> Arc<Self> { |
| Arc::new(Self { |
| task_ctx: self.task_ctx.clone(), |
| partition_id: self.partition_id, |
| output_schema, |
| metrics: self.metrics.clone(), |
| baseline_metrics: self.baseline_metrics.clone(), |
| spill_metrics: self.spill_metrics.clone(), |
| input_stat_metrics: self.input_stat_metrics.clone(), |
| }) |
| } |
| |
| pub fn task_ctx(&self) -> Arc<TaskContext> { |
| self.task_ctx.clone() |
| } |
| |
| pub fn partition_id(&self) -> usize { |
| self.partition_id |
| } |
| |
| pub fn output_schema(&self) -> SchemaRef { |
| self.output_schema.clone() |
| } |
| |
| pub fn execution_plan_metrics(&self) -> &ExecutionPlanMetricsSet { |
| &self.metrics |
| } |
| |
| pub fn baseline_metrics(&self) -> &BaselineMetrics { |
| &self.baseline_metrics |
| } |
| |
| pub fn spill_metrics(&self) -> &SpillMetrics { |
| self.spill_metrics |
| .get_or_init(|| SpillMetrics::new(&self.metrics, self.partition_id)) |
| } |
| |
| pub fn register_timer_metric(&self, name: &str) -> Time { |
| MetricBuilder::new(self.execution_plan_metrics()) |
| .subset_time(name.to_owned(), self.partition_id) |
| } |
| |
| pub fn register_counter_metric(&self, name: &str) -> Count { |
| MetricBuilder::new(self.execution_plan_metrics()) |
| .counter(name.to_owned(), self.partition_id) |
| } |
| |
| pub fn coalesce_with_default_batch_size( |
| self: &Arc<Self>, |
| input: SendableRecordBatchStream, |
| ) -> SendableRecordBatchStream { |
| pub struct CoalesceStream { |
| input: SendableRecordBatchStream, |
| staging_batches: Vec<RecordBatch>, |
| staging_rows: usize, |
| staging_batches_mem_size: usize, |
| batch_size: usize, |
| elapsed_compute: Time, |
| } |
| |
| impl CoalesceStream { |
| fn coalesce(&mut self) -> Result<RecordBatch> { |
| // better concat_batches() implementation that releases old batch columns asap. |
| let schema = self.input.schema(); |
| let coalesced_batch = coalesce_batches_unchecked(schema, &self.staging_batches); |
| self.staging_batches.clear(); |
| self.staging_rows = 0; |
| self.staging_batches_mem_size = 0; |
| Ok(coalesced_batch) |
| } |
| |
| fn should_flush(&self) -> bool { |
| let size_limit = suggested_batch_mem_size(); |
| let (batch_size_limit, mem_size_limit) = if self.staging_batches.len() > 1 { |
| (self.batch_size, size_limit) |
| } else { |
| (self.batch_size / 2, size_limit / 2) |
| }; |
| self.staging_rows >= batch_size_limit |
| || self.staging_batches_mem_size > mem_size_limit |
| } |
| } |
| |
| impl RecordBatchStream for CoalesceStream { |
| fn schema(&self) -> SchemaRef { |
| self.input.schema() |
| } |
| } |
| |
| impl Stream for CoalesceStream { |
| type Item = Result<RecordBatch>; |
| |
| fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { |
| let elapsed_time = self.elapsed_compute.clone(); |
| loop { |
| match ready!(self.input.poll_next_unpin(cx)).transpose()? { |
| Some(batch) => { |
| let _timer = elapsed_time.timer(); |
| let num_rows = batch.num_rows(); |
| if num_rows > 0 { |
| self.staging_rows += batch.num_rows(); |
| self.staging_batches_mem_size += batch.get_batch_mem_size(); |
| self.staging_batches.push(batch); |
| if self.should_flush() { |
| let coalesced = self.coalesce()?; |
| return Poll::Ready(Some(Ok(coalesced))); |
| } |
| continue; |
| } |
| } |
| None if !self.staging_batches.is_empty() => { |
| let _timer = elapsed_time.timer(); |
| let coalesced = self.coalesce()?; |
| return Poll::Ready(Some(Ok(coalesced))); |
| } |
| None => { |
| return Poll::Ready(None); |
| } |
| } |
| } |
| } |
| } |
| |
| Box::pin(CoalesceStream { |
| input, |
| staging_batches: vec![], |
| staging_rows: 0, |
| staging_batches_mem_size: 0, |
| batch_size: batch_size(), |
| elapsed_compute: self.baseline_metrics().elapsed_compute().clone(), |
| }) |
| } |
| |
| pub fn execute_with_input_stats( |
| self: &Arc<Self>, |
| input: &Arc<dyn ExecutionPlan>, |
| ) -> Result<SendableRecordBatchStream> { |
| let executed = self.execute(input)?; |
| Ok(self.stat_input(executed)) |
| } |
| |
| pub fn execute_projected_with_input_stats( |
| self: &Arc<Self>, |
| input: &Arc<dyn ExecutionPlan>, |
| projection: &[usize], |
| ) -> Result<SendableRecordBatchStream> { |
| let executed = self.execute_projected(input, projection)?; |
| Ok(self.stat_input(executed)) |
| } |
| |
| pub fn execute( |
| self: &Arc<Self>, |
| input: &Arc<dyn ExecutionPlan>, |
| ) -> Result<SendableRecordBatchStream> { |
| input.execute(self.partition_id, self.task_ctx.clone()) |
| } |
| |
| pub fn execute_projected( |
| self: &Arc<Self>, |
| input: &Arc<dyn ExecutionPlan>, |
| projection: &[usize], |
| ) -> Result<SendableRecordBatchStream> { |
| input.execute_projected(self.partition_id, self.task_ctx.clone(), projection) |
| } |
| |
| pub fn execute_with_cudf( |
| self: &Arc<Self>, |
| plan: &dyn ExecutionPlan, |
| ) -> Result<SendableRecordBatchStream> { |
| if !conf::ENABLE_CUDA.value().unwrap_or(false) { |
| return df_execution_err!("blaze CUDA support is not enabled"); |
| } |
| let exec_ctx = self.clone(); |
| let cudf_elapsed_compute = exec_ctx.register_timer_metric("cudf_elapsed_compute"); |
| let cudf_plan = convert_datafusion_plan_to_cudf(plan).inspect_err(|err| { |
| log::info!("convert plan to cudf-bridge error: {err}"); |
| })?; |
| |
| Ok(exec_ctx |
| .clone() |
| .output_with_sender("CudfStream", move |sender| async move { |
| let _cudf_timer = cudf_elapsed_compute.timer(); |
| sender.exclude_time(&cudf_elapsed_compute); |
| |
| log::info!("****** executing with Blaze + CUDA (libcudf) ******"); |
| |
| let cudf_plan: Arc<dyn CudfPlan> = |
| Arc::new(CudfSplitTablePlan::new(cudf_plan, batch_size())); |
| let mut cudf_table_stream = Box::pin(cudf_plan.execute().inspect_err(|err| { |
| log::info!("executing cudf-bridge plan error: {err}"); |
| })?); |
| |
| while let Some(batch) = { |
| let output_schema = exec_ctx.output_schema(); |
| tokio::task::block_in_place(|| -> Result<Option<RecordBatch>> { |
| if let Some(cudf_table) = cudf_table_stream.as_mut().next().transpose()? { |
| return Ok(Some(cudf_table.to_arrow_record_batch(output_schema)?)); |
| } |
| Ok(None) |
| })? |
| } { |
| exec_ctx.baseline_metrics().record_output(batch.num_rows()); |
| sender.send(batch).await; |
| } |
| Ok(()) |
| })) |
| } |
| |
| pub fn stat_input( |
| self: &Arc<Self>, |
| input: SendableRecordBatchStream, |
| ) -> SendableRecordBatchStream { |
| let input_batch_statistics = self.input_stat_metrics.get_or_init(|| { |
| InputBatchStatistics::from_metrics_set_and_blaze_conf( |
| self.execution_plan_metrics(), |
| self.partition_id, |
| ) |
| .expect("error creating input batch statistics") |
| }); |
| |
| if let Some(input_batch_statistics) = input_batch_statistics.clone() { |
| let stat_input: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( |
| input.schema(), |
| input.inspect(move |batch_result| { |
| if let Ok(batch) = &batch_result { |
| input_batch_statistics.record_input_batch(batch); |
| } |
| }), |
| )); |
| return stat_input; |
| } |
| input |
| } |
| |
| pub fn stream_on_completion( |
| self: &Arc<Self>, |
| input: SendableRecordBatchStream, |
| on_completion: Box<dyn FnOnce() -> Result<()> + Send + 'static>, |
| ) -> SendableRecordBatchStream { |
| struct CompletionStream { |
| input: SendableRecordBatchStream, |
| on_completion: Option<Box<dyn FnOnce() -> Result<()> + Send + 'static>>, |
| } |
| |
| impl RecordBatchStream for CompletionStream { |
| fn schema(&self) -> SchemaRef { |
| self.input.schema() |
| } |
| } |
| |
| impl Stream for CompletionStream { |
| type Item = Result<RecordBatch>; |
| |
| fn poll_next( |
| mut self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| ) -> Poll<Option<Self::Item>> { |
| match ready!(self.as_mut().input.poll_next_unpin(cx)) { |
| Some(r) => Poll::Ready(Some(r)), |
| None => { |
| if let Some(on_completion) = self.as_mut().on_completion.take() { |
| if let Err(e) = on_completion() { |
| return Poll::Ready(Some(Err(e))); |
| } |
| } |
| Poll::Ready(None) |
| } |
| } |
| } |
| } |
| |
| Box::pin(CompletionStream { |
| input, |
| on_completion: Some(on_completion), |
| }) |
| } |
| |
| pub fn output_with_sender<Fut: Future<Output = Result<()>> + Send>( |
| self: &Arc<Self>, |
| desc: &'static str, |
| output: impl FnOnce(Arc<WrappedRecordBatchSender>) -> Fut + Send + 'static, |
| ) -> SendableRecordBatchStream { |
| let mut stream_builder = RecordBatchReceiverStream::builder(self.output_schema(), 1); |
| let err_sender = stream_builder.tx().clone(); |
| let wrapped_sender = |
| WrappedRecordBatchSender::new(self.clone(), stream_builder.tx().clone()); |
| |
| stream_builder.spawn(async move { |
| let result = AssertUnwindSafe(async move { |
| if let Err(err) = output(wrapped_sender).await { |
| panic!("output_with_sender[{desc}]: output() returns error: {err}"); |
| } |
| }) |
| .catch_unwind() |
| .await |
| .map(|_| Ok(())) |
| .unwrap_or_else(|err| { |
| let panic_message = |
| panic_message::get_panic_message(&err).unwrap_or("unknown error"); |
| df_execution_err!("{panic_message}") |
| }); |
| |
| if let Err(err) = result { |
| err_sender |
| .send(df_execution_err!("{err}")) |
| .await |
| .unwrap_or_default(); |
| |
| // panic current spawn |
| let task_running = is_task_running(); |
| if !task_running { |
| panic!("output_with_sender[{desc}] canceled due to task finished/killed"); |
| } else { |
| panic!("output_with_sender[{desc}] error: {}", err.to_string()); |
| } |
| } |
| Ok(()) |
| }); |
| stream_builder.build() |
| } |
| } |
| |
/// Per-partition counters describing the batches entering an operator,
/// updated by `record_input_batch`.
#[derive(Clone)]
pub struct InputBatchStatistics {
    /// Number of input batches seen (metric `input_batch_count`).
    input_batch_count: Count,
    /// Sum of `get_batch_mem_size()` over input batches (metric `input_batch_mem_size`).
    input_batch_mem_size: Count,
    /// Sum of row counts over input batches (metric `input_row_count`).
    input_row_count: Count,
}
| |
| impl InputBatchStatistics { |
| pub fn from_metrics_set_and_blaze_conf( |
| metrics_set: &ExecutionPlanMetricsSet, |
| partition: usize, |
| ) -> Result<Option<Self>> { |
| let enabled = conf::INPUT_BATCH_STATISTICS_ENABLE.value().unwrap_or(false); |
| Ok(enabled.then_some(Self::from_metrics_set(metrics_set, partition))) |
| } |
| |
| pub fn from_metrics_set(metrics_set: &ExecutionPlanMetricsSet, partition: usize) -> Self { |
| Self { |
| input_batch_count: MetricBuilder::new(metrics_set) |
| .counter("input_batch_count", partition), |
| input_batch_mem_size: MetricBuilder::new(metrics_set) |
| .counter("input_batch_mem_size", partition), |
| input_row_count: MetricBuilder::new(metrics_set).counter("input_row_count", partition), |
| } |
| } |
| |
| pub fn record_input_batch(&self, input_batch: &RecordBatch) { |
| let mem_size = input_batch.get_batch_mem_size(); |
| let num_rows = input_batch.num_rows(); |
| self.input_batch_count.add(1); |
| self.input_batch_mem_size.add(mem_size); |
| self.input_row_count.add(num_rows); |
| } |
| } |
| |
| fn working_senders() -> &'static Mutex<Vec<Weak<WrappedRecordBatchSender>>> { |
| static WORKING_SENDERS: OnceCell<Mutex<Vec<Weak<WrappedRecordBatchSender>>>> = OnceCell::new(); |
| WORKING_SENDERS.get_or_init(|| Mutex::default()) |
| } |
| |
/// Wrapper around the channel sender that feeds an output stream (for example
/// the one built by `ExecutionContext::output_with_sender`).
pub struct WrappedRecordBatchSender {
    /// Context of the producing operator; `cancel_all_tasks` matches senders by
    /// its task context.
    exec_ctx: Arc<ExecutionContext>,
    /// Channel into the output stream.
    sender: Sender<Result<RecordBatch>>,
    /// Optional timer from which the time spent blocked in `send` is subtracted.
    exclude_time: OnceCell<Time>,
}
| |
| impl WrappedRecordBatchSender { |
| pub fn new(exec_ctx: Arc<ExecutionContext>, sender: Sender<Result<RecordBatch>>) -> Arc<Self> { |
| let wrapped = Arc::new(Self { |
| exec_ctx, |
| sender, |
| exclude_time: OnceCell::new(), |
| }); |
| let mut working_senders = working_senders().lock(); |
| working_senders.push(Arc::downgrade(&wrapped)); |
| wrapped |
| } |
| |
| pub fn exclude_time(&self, exclude_time: &Time) { |
| assert!( |
| self.exclude_time.get().is_none(), |
| "already used a exclude_time" |
| ); |
| self.exclude_time.get_or_init(|| exclude_time.clone()); |
| } |
| |
| pub async fn send(&self, batch: RecordBatch) { |
| if batch.num_rows() == 0 { |
| return; |
| } |
| let exclude_time = self.exclude_time.get().cloned(); |
| let send_time = exclude_time.as_ref().map(|_| Instant::now()); |
| self.sender |
| .send(Ok(batch)) |
| .await |
| .unwrap_or_else(|err| panic!("output_with_sender: send error: {err}")); |
| |
| send_time.inspect(|send_time| { |
| exclude_time |
| .as_ref() |
| .unwrap() |
| .sub_duration(send_time.elapsed()); |
| }); |
| } |
| } |
| |
| pub fn cancel_all_tasks(task_ctx: &Arc<TaskContext>) { |
| let mut working_senders = working_senders().lock(); |
| *working_senders = std::mem::take(&mut *working_senders) |
| .into_iter() |
| .filter(|wrapped| match wrapped.upgrade() { |
| Some(wrapped) if Arc::ptr_eq(&wrapped.exec_ctx.task_ctx, task_ctx) => { |
| wrapped |
| .sender |
| .try_send(df_execution_err!("task completed/cancelled")) |
| .unwrap_or_default(); |
| false |
| } |
| Some(_) => true, // do not modify senders from other tasks |
| None => false, // already released |
| }) |
| .collect(); |
| } |