// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::any::Any;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::io::Read;
use std::io::Seek;
use std::io::SeekFrom;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::error::Result as ArrowResult;
use datafusion::arrow::ipc::reader::FileReader;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::expressions::PhysicalSortExpr;
use datafusion::physical_plan::metrics::BaselineMetrics;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use datafusion::physical_plan::metrics::MetricsSet;
use datafusion::physical_plan::DisplayFormatType;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::Partitioning;
use datafusion::physical_plan::Partitioning::UnknownPartitioning;
use datafusion::physical_plan::RecordBatchStream;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::physical_plan::Statistics;
use futures::Stream;
use jni::objects::{GlobalRef, JObject};
use jni::sys::{jboolean, jint, jlong, JNI_TRUE};
use crate::jni_bridge::JavaClasses;
use crate::jni_bridge_call_static_method;
use crate::jni_global_ref;
use crate::jni_map_error;
use crate::{jni_bridge_call_method, ResultExt};
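
/// Reads shuffle segments that the JVM side registers under `native_shuffle_id`.
/// Each segment is an Arrow IPC file exposed through a
/// `java.nio.channels.SeekableByteChannel` and is decoded back into record batches.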
#[derive(Debug, Clone)]
pub struct ShuffleReaderExec {
pub num_partitions: usize,
pub native_shuffle_id: String,
pub schema: SchemaRef,
pub metrics: ExecutionPlanMetricsSet,
}
impl ShuffleReaderExec {
pub fn new(
num_partitions: usize,
native_shuffle_id: String,
schema: SchemaRef,
) -> ShuffleReaderExec {
ShuffleReaderExec {
num_partitions,
native_shuffle_id,
schema,
metrics: ExecutionPlanMetricsSet::new(),
}
}
}
#[async_trait]
impl ExecutionPlan for ShuffleReaderExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn output_partitioning(&self) -> Partitioning {
UnknownPartitioning(self.num_partitions)
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Err(DataFusionError::Plan(
"Blaze ShuffleReaderExec does not support with_new_children()".to_owned(),
))
}
fn execute(
&self,
_partition: usize,
_context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let baseline_metrics = BaselineMetrics::new(&self.metrics, 0);
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
let _timer = elapsed_compute.timer();
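        // look up the Scala Function0 that the JVM side registered under
        // `native_shuffle_id`, then invoke it to obtain the iterator of
        // shuffle segments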
let env = JavaClasses::get_thread_jnienv();
let segments_provider = jni_bridge_call_static_method!(
env,
JniBridge.getResource -> JObject,
jni_map_error!(env.new_string(&self.native_shuffle_id))?
)?;
let segments = jni_global_ref!(
env,
jni_bridge_call_method!(
env,
ScalaFunction0.apply -> JObject,
segments_provider
)?
)?;
let schema = self.schema.clone();
Ok(Box::pin(ShuffleReaderStream::new(
schema,
segments,
baseline_metrics,
)))
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
write!(f, "{:?}", self)
}
fn statistics(&self) -> Statistics {
Statistics::default()
}
}
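
/// Stream that lazily pulls segments from the JVM-side iterator and replays the
/// record batches contained in each segment's Arrow IPC file.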
struct ShuffleReaderStream {
schema: SchemaRef,
segments: GlobalRef,
arrow_file_reader: Option<FileReader<SeekableByteChannelReader>>,
baseline_metrics: BaselineMetrics,
}
unsafe impl Sync for ShuffleReaderStream {} // safety: `segments` is a JNI global ref that is safe to share across threads
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl Send for ShuffleReaderStream {}
impl ShuffleReaderStream {
pub fn new(
schema: SchemaRef,
segments: GlobalRef,
baseline_metrics: BaselineMetrics,
) -> ShuffleReaderStream {
ShuffleReaderStream {
schema,
segments,
arrow_file_reader: None,
baseline_metrics,
}
}
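
    /// Advances to the next segment produced by the JVM-side iterator, replacing
    /// the current Arrow IPC reader. Returns `Ok(false)` once the iterator is
    /// exhausted.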
fn next_segment(&mut self) -> Result<bool> {
let env = JavaClasses::get_thread_jnienv();
if jni_bridge_call_method!(
env,
ScalaIterator.hasNext -> jboolean,
self.segments.as_obj()
)? != JNI_TRUE
{
self.arrow_file_reader = None;
return Ok(false);
}
let channel = jni_bridge_call_method!(
env,
ScalaIterator.next -> JObject,
self.segments.as_obj()
)?;
self.arrow_file_reader = Some(FileReader::try_new(
SeekableByteChannelReader(jni_global_ref!(env, channel)?),
None,
)?);
// channel ref must be explicitly deleted to avoid OOM
jni_map_error!(env.delete_local_ref(channel))?;
Ok(true)
}
}
impl Stream for ShuffleReaderStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
let _timer = elapsed_compute.timer();
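        // replay batches from the current segment until it is exhausted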
if let Some(arrow_file_reader) = &mut self.arrow_file_reader {
if let Some(record_batch) = arrow_file_reader.next() {
return self
.baseline_metrics
.record_poll(Poll::Ready(Some(record_batch)));
}
}
        // the current segment is exhausted; try to open the next one
if self.next_segment()? {
return self.poll_next(cx);
}
Poll::Ready(None)
}
}
impl RecordBatchStream for ShuffleReaderStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
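
/// Adapts a JVM `java.nio.channels.SeekableByteChannel` to `Read` + `Seek` so
/// that Arrow's `FileReader` can consume it directly.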
struct SeekableByteChannelReader(GlobalRef);
impl Read for SeekableByteChannelReader {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let env = JavaClasses::get_thread_jnienv();
        // wrap the Rust buffer as a direct ByteBuffer so the JVM channel can
        // fill it without an extra copy
        let bytes_read = jni_bridge_call_method!(
            env,
            JavaNioSeekableByteChannel.read -> jint,
            self.0.as_obj(),
            jni_map_error!(env.new_direct_byte_buffer(buf)).to_io_result()?
        )
        .to_io_result()?;
        // SeekableByteChannel.read() signals end-of-stream with -1; map it to 0
        // to honor the std::io::Read contract
        Ok(bytes_read.max(0) as usize)
    }
}
impl Seek for SeekableByteChannelReader {
fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
let env = JavaClasses::get_thread_jnienv();
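        // translate the requested seek into an absolute offset and push it to
        // the channel via setPosition()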
match pos {
SeekFrom::Start(position) => {
jni_bridge_call_method!(
env,
JavaNioSeekableByteChannel.setPosition -> JObject,
self.0.as_obj(),
position as i64
)
.and_then(|channel| {
// channel ref must be explicitly deleted to avoid OOM
jni_map_error!(env.delete_local_ref(channel))
})
.to_io_result()?;
Ok(position)
}
SeekFrom::End(offset) => {
let size = jni_bridge_call_method!(
env,
JavaNioSeekableByteChannel.size -> jlong,
self.0.as_obj()
)
.to_io_result()? as u64;
                // use signed arithmetic so negative offsets from the end do not overflow
                let position = (size as i64 + offset) as u64;
jni_bridge_call_method!(
env,
JavaNioSeekableByteChannel.setPosition -> JObject,
self.0.as_obj(),
position as i64
)
.and_then(|channel| {
// channel ref must be explicitly deleted to avoid OOM
jni_map_error!(env.delete_local_ref(channel))
})
.to_io_result()?;
Ok(position)
}
SeekFrom::Current(_) => unimplemented!(),
}
}
}