blob: 9f9c59b550c66b0771935d92b217bdc3beff03e5 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::any::Any;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::io::Read;
use std::io::Seek;
use std::io::SeekFrom;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::error::Result as ArrowResult;
use datafusion::arrow::ipc::reader::FileReader;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::expressions::PhysicalSortExpr;
use datafusion::physical_plan::metrics::BaselineMetrics;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use datafusion::physical_plan::metrics::MetricsSet;
use datafusion::physical_plan::DisplayFormatType;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::Partitioning;
use datafusion::physical_plan::Partitioning::UnknownPartitioning;
use datafusion::physical_plan::RecordBatchStream;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::physical_plan::Statistics;
use futures::Stream;
use jni::errors::Result as JniResult;
use jni::objects::{GlobalRef, JObject};
use crate::jni_bridge::JavaClasses;
use crate::jni_bridge_call_method;
use crate::jni_bridge_call_static_method;
use crate::util::Util;
/// Leaf physical plan node whose record batches come from shuffle segments
/// obtained from the JVM through the JNI bridge (see `execute`).
#[derive(Debug, Clone)]
pub struct ShuffleReaderExec {
/// Partition count reported by `output_partitioning`.
pub num_partitions: usize,
/// Key handed to `JniBridge.getResource` to look up the segments provider.
pub native_shuffle_id: String,
/// Schema reported by `schema()` and passed on to the produced stream.
pub schema: SchemaRef,
/// Metrics set in which `execute` registers its baseline metrics.
pub metrics: ExecutionPlanMetricsSet,
}
impl ShuffleReaderExec {
    /// Builds a shuffle reader node for `native_shuffle_id` that reports
    /// `num_partitions` output partitions and produces batches of `schema`.
    ///
    /// The node starts with an empty metrics set.
    pub fn new(
        num_partitions: usize,
        native_shuffle_id: String,
        schema: SchemaRef,
    ) -> ShuffleReaderExec {
        let metrics = ExecutionPlanMetricsSet::new();
        Self {
            num_partitions,
            native_shuffle_id,
            schema,
            metrics,
        }
    }
}
#[async_trait]
impl ExecutionPlan for ShuffleReaderExec {
    /// Returns `self` as `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Schema of the record batches produced by this node.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Reports `num_partitions` partitions with no known partitioning scheme.
    fn output_partitioning(&self) -> Partitioning {
        UnknownPartitioning(self.num_partitions)
    }

    /// This node makes no ordering guarantee about its output.
    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
        None
    }

    /// This is a leaf node: it has no child plans.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![]
    }

    /// Always fails: replacing children is not supported by this node.
    ///
    /// # Errors
    /// Returns `DataFusionError::Plan` unconditionally.
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Err(DataFusionError::Plan(
            "Blaze ShuffleReaderExec does not support with_new_children()".to_owned(),
        ))
    }

    /// Starts reading the shuffle data and returns it as a record batch stream.
    ///
    /// The JVM resource registered under `native_shuffle_id` is fetched with
    /// `JniBridge.getResource`, invoked through `ScalaFunction0.apply`, and the
    /// resulting object is pinned with a JNI global reference that the
    /// returned stream owns.
    ///
    /// # Errors
    /// Any JNI failure is converted by `Util::to_datafusion_external_result`
    /// and returned.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        // Register the metrics under the partition that is actually being
        // executed; a hard-coded 0 would attribute every partition's numbers
        // to partition 0.
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
        let _timer = elapsed_compute.timer();

        // `Ok(()).and_then(..)` gives the JNI calls a scope in which `?` can
        // be used with `JniResult` before converting to a DataFusion error.
        let segments = Util::to_datafusion_external_result(Ok(()).and_then(|_| {
            let env = JavaClasses::get_thread_jnienv();
            let segments_provider = jni_bridge_call_static_method!(
                env,
                JniBridge.getResource,
                env.new_string(&self.native_shuffle_id)?
            )?
            .l()?;
            let segments =
                jni_bridge_call_method!(env, ScalaFunction0.apply, segments_provider)?
                    .l()?;
            // The local reference would not outlive this call; the stream
            // needs a global one.
            JniResult::Ok(env.new_global_ref(segments)?)
        }))?;

        let schema = self.schema.clone();
        Ok(Box::pin(ShuffleReaderStream::new(
            schema,
            segments,
            baseline_metrics,
        )))
    }

    /// Snapshot of the metrics recorded so far.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Formats the node using its `Debug` representation, for every display
    /// format type.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }

    /// No statistics are known for shuffle input; returns the default value.
    fn statistics(&self) -> Statistics {
        Statistics::default()
    }
}
/// Record batch stream that walks a JVM-side iterator of segments and reads
/// each segment with an Arrow IPC `FileReader`.
struct ShuffleReaderStream {
/// Schema reported through `RecordBatchStream::schema`.
schema: SchemaRef,
/// Global reference to the JVM object polled with `ScalaIterator.hasNext` /
/// `ScalaIterator.next`.
segments: GlobalRef,
/// Global reference to the segment currently being read; `None` before the
/// first segment is fetched and after the iterator is exhausted.
current_segment: Option<GlobalRef>,
/// Reader over the current segment. Its `SeekableByteChannelReader` wraps the
/// object owned by `current_segment`, so it must not be used after that
/// reference is released.
arrow_file_reader: Option<FileReader<SeekableByteChannelReader>>,
/// Metrics updated from `poll_next` (elapsed compute, per-poll recording).
baseline_metrics: BaselineMetrics,
}
// Manual `Sync`/`Send` assertions for the stream. The stream is returned from
// `execute` as a `SendableRecordBatchStream`.
// NOTE(review): besides `segments`, the stream also holds a raw `JObject`
// inside `arrow_file_reader` (via `SeekableByteChannelReader`); presumably
// that is why these impls are not derived automatically. Confirm that the
// object is only ever touched by the thread currently polling the stream.
unsafe impl Sync for ShuffleReaderStream {} // safety: segments is safe to be shared
// NOTE(review): the clippy lint is silenced because a field is not `Send` on
// its own — presumably the raw `JObject` mentioned above; verify.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl Send for ShuffleReaderStream {}
impl ShuffleReaderStream {
    /// Creates a stream over `segments` that has not opened any segment yet.
    ///
    /// `segments` is the JVM object that is polled with
    /// `ScalaIterator.hasNext` / `ScalaIterator.next`.
    pub fn new(
        schema: SchemaRef,
        segments: GlobalRef,
        baseline_metrics: BaselineMetrics,
    ) -> ShuffleReaderStream {
        ShuffleReaderStream {
            schema,
            segments,
            current_segment: None,
            arrow_file_reader: None,
            baseline_metrics,
        }
    }

    /// Advances to the next segment and opens an Arrow IPC reader on it.
    ///
    /// Returns `Ok(true)` when a new reader is ready and `Ok(false)` when the
    /// JVM iterator has no more segments. In both the exhausted case and every
    /// error case no reader is left installed.
    ///
    /// # Errors
    /// Returns an error when a JNI call fails or when the Arrow reader cannot
    /// be created for the new segment.
    fn next_segment(&mut self) -> Result<bool> {
        // The reader holds a raw `JObject` that is only valid while the global
        // reference in `current_segment` is alive. Drop the reader *before*
        // releasing that reference so no stale handle survives an early return
        // (for example when opening the next reader fails).
        self.arrow_file_reader = None;
        self.current_segment = None;

        // `Ok(()).and_then(..)` gives the JNI calls a scope in which `?` can
        // be used with `JniResult` before converting to a DataFusion error.
        Util::to_datafusion_external_result(Ok(()).and_then(|_| {
            let env = JavaClasses::get_thread_jnienv();
            let has_next = jni_bridge_call_method!(
                env,
                ScalaIterator.hasNext,
                self.segments.as_obj()
            )?
            .z()?;
            if !has_next {
                return JniResult::Ok(false);
            }
            let next_segment =
                jni_bridge_call_method!(env, ScalaIterator.next, self.segments.as_obj())?
                    .l()?;
            // Pin the segment so it outlives this JNI call.
            self.current_segment = Some(env.new_global_ref(next_segment)?);
            JniResult::Ok(true)
        }))?;

        let current_segment = match &self.current_segment {
            Some(segment) => segment,
            None => return Ok(false),
        };

        // SAFETY: the `'static` lifetime is a promise that the object stays
        // valid for as long as the reader exists. That holds because the
        // object is kept alive by the global reference in
        // `self.current_segment`, and this function is the only place that
        // replaces that reference — always after dropping the reader first.
        let channel = unsafe {
            std::mem::transmute::<_, JObject<'static>>(current_segment.as_obj())
        };
        self.arrow_file_reader = Some(FileReader::try_new(
            SeekableByteChannelReader(channel),
            None,
        )?);
        Ok(true)
    }
}
impl Stream for ShuffleReaderStream {
    type Item = ArrowResult<RecordBatch>;

    /// Yields the next record batch, moving on to the next segment whenever
    /// the current reader is exhausted.
    ///
    /// The stream never returns `Poll::Pending`: all work is done
    /// synchronously inside the poll. Failures while advancing to the next
    /// segment are yielded as an `Err` item instead of panicking.
    fn poll_next(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        // One timer for the whole poll. Iterating (rather than recursing per
        // segment) keeps the stack flat and avoids nested timers that would
        // count the same wall time more than once.
        let elapsed_compute = this.baseline_metrics.elapsed_compute().clone();
        let _timer = elapsed_compute.timer();

        loop {
            if let Some(arrow_file_reader) = this.arrow_file_reader.as_mut() {
                if let Some(record_batch) = arrow_file_reader.next() {
                    return this
                        .baseline_metrics
                        .record_poll(Poll::Ready(Some(record_batch)));
                }
            }

            // No reader yet, or the current one reached EOF: try the next
            // segment.
            match this.next_segment() {
                Ok(true) => continue,
                Ok(false) => {
                    return this.baseline_metrics.record_poll(Poll::Ready(None));
                }
                Err(e) => {
                    // Surface the failure to the consumer as a stream item.
                    return this
                        .baseline_metrics
                        .record_poll(Poll::Ready(Some(Err(e.into()))));
                }
            }
        }
    }
}
impl RecordBatchStream for ShuffleReaderStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
/// Adapter that exposes a JVM channel object through `std::io::Read` and
/// `std::io::Seek` by calling `JavaNioSeekableByteChannel.read`,
/// `.setPosition` and `.size` over JNI.
///
/// The `'static` lifetime on the wrapped object is not checked by the
/// compiler; whoever constructs the adapter must keep the object alive for as
/// long as the adapter is used.
struct SeekableByteChannelReader(JObject<'static>);
impl Read for SeekableByteChannelReader {
    /// Reads from the JVM channel into `buf` and returns the number of bytes
    /// read, with `Ok(0)` meaning end-of-stream (or an empty `buf`).
    ///
    /// `buf` is exposed to the JVM as a direct byte buffer, so the channel
    /// writes straight into it.
    ///
    /// # Errors
    /// Any JNI failure is reported as an `std::io::ErrorKind::Other` error.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Nothing can be read into an empty buffer; skip the JNI round trip.
        if buf.is_empty() {
            return Ok(0);
        }

        Ok(())
            .and_then(|_| {
                let env = JavaClasses::get_thread_jnienv();
                let num_read: i32 = jni_bridge_call_method!(
                    env,
                    JavaNioSeekableByteChannel.read,
                    self.0,
                    env.new_direct_byte_buffer(buf)?
                )?
                .i()?;
                // Java channels signal end-of-stream with a negative count
                // (-1), while `Read` signals it with `Ok(0)`. Casting -1
                // directly would yield `usize::MAX`.
                JniResult::Ok(num_read.max(0) as usize)
            })
            .map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "JNI error: SeekableByteChannelReader.jni_read",
                )
            })
    }
}
impl Seek for SeekableByteChannelReader {
    /// Moves the JVM channel to the requested position and returns the new
    /// absolute position.
    ///
    /// `SeekFrom::Start` and `SeekFrom::End` are supported. `SeekFrom::End`
    /// queries the channel size first, so negative offsets address bytes
    /// before the end of the channel.
    ///
    /// # Errors
    /// * `ErrorKind::Unsupported` for `SeekFrom::Current`.
    /// * `ErrorKind::InvalidInput` when the target position is negative or
    ///   does not fit in a signed 64-bit integer.
    /// * `ErrorKind::Other` when a JNI call fails.
    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
        // Relative seeks would need the channel's current position, which is
        // not queried here; report that as an error instead of panicking.
        if let SeekFrom::Current(_) = pos {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Unsupported,
                "SeekableByteChannelReader does not support SeekFrom::Current",
            ));
        }

        let new_position = Ok(())
            .and_then(|_| {
                let env = JavaClasses::get_thread_jnienv();

                // Resolve the absolute target as a signed value, since the
                // position is passed to the JVM as an `i64`. `None` marks a
                // target that cannot be represented.
                let target: Option<i64> = match pos {
                    SeekFrom::Start(position) => {
                        if position > i64::MAX as u64 {
                            None
                        } else {
                            Some(position as i64)
                        }
                    }
                    SeekFrom::End(offset) => {
                        let size: i64 = jni_bridge_call_method!(
                            env,
                            JavaNioSeekableByteChannel.size,
                            self.0
                        )?
                        .j()?;
                        // Signed, checked arithmetic: a negative offset must
                        // subtract from the size rather than wrap around.
                        size.checked_add(offset)
                    }
                    // Rejected before any JNI call is made.
                    SeekFrom::Current(_) => None,
                };

                match target {
                    Some(position) if position >= 0 => {
                        jni_bridge_call_method!(
                            env,
                            JavaNioSeekableByteChannel.setPosition,
                            self.0,
                            position
                        )?;
                        JniResult::Ok(Some(position as u64))
                    }
                    _ => JniResult::Ok(None),
                }
            })
            .map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "JNI error: SeekableByteChannelReader.jni_seek",
                )
            })?;

        new_position.ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )
        })
    }
}