blob: 7bf9a5262cdd797cce852a0803a6be9f19ac78cc [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::any::Any;
use std::error::Error;
use std::panic::AssertUnwindSafe;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use datafusion::arrow::array::{export_array_into_raw, StructArray};
use datafusion::arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::memory_manager::MemoryManagerConfig;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::physical_plan::{displayable, ExecutionPlan};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_ext::jni_bridge::JavaClasses;
use datafusion_ext::*;
use futures::{FutureExt, StreamExt};
use jni::objects::{JClass, JString};
use jni::objects::{JObject, JThrowable};
use jni::sys::{jboolean, jlong, JNI_FALSE, JNI_TRUE};
use jni::JNIEnv;
use log::LevelFilter;
use once_cell::sync::OnceCell;
use plan_serde::protobuf::TaskDefinition;
use prost::Message;
use simplelog::{ColorChoice, ConfigBuilder, TermLogger, TerminalMode, ThreadLogMode};
use tokio::runtime::Runtime;
use crate::metrics::update_spark_metric_node;
/// One-shot guard used by `initNative` so the terminal logger is installed at
/// most once per process, even if `initNative` is invoked repeatedly.
static LOGGING_INIT: OnceCell<()> = OnceCell::new();
/// Process-wide DataFusion session context. It is created lazily by
/// `initNative` and read by `callNative` to obtain a task context.
static SESSIONCTX: OnceCell<SessionContext> = OnceCell::new();
/// JNI entry point behind `JniBridge.initNative`.
///
/// Performs the one-time native setup:
/// 1. installs the stderr logger (guarded by `LOGGING_INIT`),
/// 2. initializes the cached Java class/method handles,
/// 3. builds the global DataFusion `SessionContext` (guarded by `SESSIONCTX`).
///
/// * `batch_size` - batch size handed to the DataFusion session config.
/// * `native_memory` - `max_memory` for the DataFusion memory manager.
/// * `memory_fraction` - `memory_fraction` for the DataFusion memory manager.
/// * `tmp_dirs` - comma-separated list of directories for the disk manager.
///
/// Any panic raised during setup is caught and forwarded to
/// `handle_unwinded` instead of unwinding across the FFI boundary.
#[allow(non_snake_case)]
#[no_mangle]
pub extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_initNative(
    env: JNIEnv,
    _: JClass,
    batch_size: i64,
    native_memory: i64,
    memory_fraction: f64,
    tmp_dirs: JString,
) {
    let initialize = || {
        // logging: installed only on the first call
        LOGGING_INIT.get_or_init(|| {
            let log_config = ConfigBuilder::new()
                .set_thread_mode(ThreadLogMode::Both)
                .build();
            TermLogger::init(
                LevelFilter::Info,
                log_config,
                TerminalMode::Stderr,
                ColorChoice::Never,
            )
            .unwrap();
        });

        // cached jni class/method handles
        JavaClasses::init(&env);

        // datafusion session context: built only on the first call
        SESSIONCTX.get_or_init(|| {
            let spill_dirs: Vec<PathBuf> = jni_get_string!(tmp_dirs)
                .unwrap()
                .split(',')
                .map(PathBuf::from)
                .collect();

            let memory_manager = MemoryManagerConfig::New {
                max_memory: native_memory as usize,
                memory_fraction,
            };
            let disk_manager = DiskManagerConfig::NewSpecified(spill_dirs);
            let runtime_env = RuntimeEnv::new(
                RuntimeConfig::new()
                    .with_memory_manager(memory_manager)
                    .with_disk_manager(disk_manager),
            )
            .unwrap();

            let session_config =
                SessionConfig::new().with_batch_size(batch_size as usize);
            SessionContext::with_config_rt(session_config, Arc::new(runtime_env))
        });
    };

    if let Err(panic_payload) = std::panic::catch_unwind(initialize) {
        handle_unwinded(panic_payload);
    }
}
/// JNI entry point behind `JniBridge.callNative(wrapper)`.
///
/// Decodes the task definition held by `wrapper` (a `BlazeCallNativeWrapper`),
/// builds the native execution plan and starts executing the task's partition
/// on a dedicated single-worker tokio runtime. This function returns as soon
/// as the background task has been spawned; batches are handed to the JVM
/// through the wrapper's queue methods:
///
/// * for every non-empty batch, the task dequeues a `(schema_ptr, array_ptr)`
///   tuple from the wrapper, exports the batch into those FFI structs and
///   enqueues `Boolean(true)`;
/// * when the stream is exhausted it dequeues (and discards) one more request
///   and enqueues `Boolean(false)`, then updates the Spark metrics;
/// * if the task panics, the failure is reported through `enqueueError`.
///
/// Every wait loop also polls `isFinished()` so the task stops as soon as the
/// JVM side marks the wrapper finished.
///
/// Panics raised before the task is spawned are caught and forwarded to
/// `handle_unwinded` instead of unwinding across the FFI boundary.
#[allow(non_snake_case)]
#[no_mangle]
pub extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_callNative(
    _: JNIEnv,
    _: JClass,
    wrapper: JObject,
) {
    if let Err(err) = std::panic::catch_unwind(|| {
        log::info!("Entering blaze callNative()");

        // global refs: these objects are used from the spawned task, which
        // outlives this JNI call
        let wrapper = Arc::new(jni_new_global_ref!(wrapper).unwrap());
        let wrapper_clone = wrapper.clone();
        let obj_true =
            jni_new_global_ref!(jni_new_object!(JavaBoolean, JNI_TRUE).unwrap()).unwrap();
        let obj_false =
            jni_new_global_ref!(jni_new_object!(JavaBoolean, JNI_FALSE).unwrap())
                .unwrap();

        // decode plan
        let raw_task_definition: JObject = jni_call!(
            BlazeCallNativeWrapper(wrapper.as_obj()).getRawTaskDefinition() -> JObject
        )
        .unwrap();
        let task_definition = TaskDefinition::decode(
            jni_convert_byte_array!(raw_task_definition.into_inner())
                .unwrap()
                .as_slice(),
        )
        .unwrap();
        let task_id = &task_definition.task_id.expect("task_id is empty");
        let plan = &task_definition.plan.expect("plan is empty");

        // get execution plan
        let execution_plan: Arc<dyn ExecutionPlan> = plan.try_into().unwrap();
        let execution_plan_displayable =
            displayable(execution_plan.as_ref()).indent().to_string();
        log::info!("Creating native execution plan succeeded");
        log::info!(" task_id={:?}", task_id);
        log::info!(" execution plan:\n{}", execution_plan_displayable);

        // execute
        let session_ctx = SESSIONCTX.get().unwrap();
        let task_ctx = session_ctx.task_ctx();
        let mut stream = execution_plan
            .execute(task_id.partition_id as usize, task_ctx)
            .unwrap();
        let task_context = jni_new_global_ref!(
            jni_call_static!(JniBridge.getTaskContext() -> JObject).unwrap()
        )
        .unwrap();

        // a runtime wrapper that calls shutdown_background on dropping
        struct RuntimeWrapper {
            runtime: Option<Runtime>,
        }
        impl Drop for RuntimeWrapper {
            fn drop(&mut self) {
                if let Some(rt) = self.runtime.take() {
                    rt.shutdown_background();
                }
            }
        }

        // spawn a thread to poll batches
        let runtime = Arc::new(RuntimeWrapper {
            runtime: Some(
                tokio::runtime::Builder::new_multi_thread()
                    .worker_threads(1)
                    .thread_keep_alive(Duration::MAX) // always use same thread
                    .build()
                    .unwrap(),
            ),
        });
        let runtime_clone = runtime.clone();

        // `runtime` is moved into the task (normal path) and `runtime_clone`
        // into the panic handler, so the runtime is shut down once the task
        // ends either way. The temporary clone only serves as the receiver
        // of `spawn`.
        runtime.clone().runtime.as_ref().unwrap().spawn(async move {
            AssertUnwindSafe(async move {
                let mut total_batches = 0;
                let mut total_rows = 0;

                // propagate task context to spawned children threads
                jni_call_static!(JniBridge.setTaskContext(task_context.as_obj()) -> ()).unwrap();

                // load batches
                while let Some(r) = stream.next().await {
                    match r {
                        Ok(batch) => {
                            let num_rows = batch.num_rows();
                            if num_rows == 0 {
                                continue;
                            }
                            total_batches += 1;
                            total_rows += num_rows;

                            // value_queue -> (schema_ptr, array_ptr)
                            let mut input = JObject::null();
                            while jni_call!(
                                BlazeCallNativeWrapper(wrapper.as_obj()).isFinished() -> jboolean
                            ).unwrap() != JNI_TRUE {
                                input = jni_call!(
                                    BlazeCallNativeWrapper(wrapper.as_obj()).dequeueWithTimeout() -> JObject
                                ).unwrap();
                                if !input.is_null() {
                                    break;
                                }
                            }
                            if input.is_null() { // wrapper.isFinished = true
                                break;
                            }

                            // unbox the two java.lang.Long addresses of the tuple
                            let schema_ptr = jni_call!(ScalaTuple2(input)._1() -> JObject).unwrap();
                            let schema_ptr = jni_call!(JavaLong(schema_ptr).longValue() -> jlong).unwrap();
                            let array_ptr = jni_call!(ScalaTuple2(input)._2() -> JObject).unwrap();
                            let array_ptr = jni_call!(JavaLong(array_ptr).longValue() -> jlong).unwrap();
                            let out_schema = schema_ptr as *mut FFI_ArrowSchema;
                            let out_array = array_ptr as *mut FFI_ArrowArray;

                            // export the whole batch as one struct array
                            let batch: Arc<StructArray> = Arc::new(batch.into());
                            // SAFETY: relies on the JVM side passing addresses of
                            // valid, writable FFI_ArrowSchema / FFI_ArrowArray
                            // structs; this cannot be checked from here.
                            unsafe {
                                export_array_into_raw(
                                    batch,
                                    out_array,
                                    out_schema,
                                )
                                .expect("export_array_into_raw error");
                            }

                            // value_queue <- hasNext=true
                            while {
                                jni_call!(BlazeCallNativeWrapper(wrapper.as_obj()).isFinished() -> jboolean).unwrap() != JNI_TRUE &&
                                jni_call!(BlazeCallNativeWrapper(wrapper.as_obj()).enqueueWithTimeout(obj_true.as_obj()) -> jboolean).unwrap() != JNI_TRUE
                            } {}
                        }
                        Err(e) => {
                            panic!("stream.next() error: {:?}", e);
                        }
                    }
                }

                // value_queue -> (discard)
                while jni_call!(
                    BlazeCallNativeWrapper(wrapper.as_obj()).isFinished() -> jboolean
                ).unwrap() != JNI_TRUE {
                    let input = jni_call!(
                        BlazeCallNativeWrapper(wrapper.as_obj()).dequeueWithTimeout() -> JObject
                    ).unwrap();
                    if !input.is_null() {
                        break;
                    }
                }

                // value_queue <- hasNext=false
                while {
                    jni_call!(BlazeCallNativeWrapper(wrapper.as_obj()).isFinished() -> jboolean).unwrap() != JNI_TRUE &&
                    jni_call!(BlazeCallNativeWrapper(wrapper.as_obj()).enqueueWithTimeout(obj_false.as_obj()) -> jboolean).unwrap() != JNI_TRUE
                } {}

                log::info!("Updating blaze exec metrics ...");
                let metrics = jni_call!(
                    BlazeCallNativeWrapper(wrapper.as_obj()).getMetrics() -> JObject
                ).unwrap();
                update_spark_metric_node(
                    metrics,
                    execution_plan.clone(),
                ).unwrap();

                log::info!("Blaze native executing finished.");
                log::info!(" total loaded batches: {}", total_batches);
                log::info!(" total loaded rows: {}", total_rows);
                std::mem::drop(runtime);
            })
            .catch_unwind()
            .await
            .map_err(|err| {
                // The task panicked: turn the panic into a Java throwable and
                // hand it to the JVM through the wrapper's error queue.
                let panic_message = panic_message::panic_message(&err);
                let e = if jni_exception_check!()? {
                    log::error!("native execution panics with an java exception");
                    log::error!("panic message: {}", panic_message);
                    let throwable = jni_exception_occurred!()?.into();
                    // The exception must be cleared before making the JNI
                    // calls below: calling Java methods while an exception is
                    // pending is not permitted. `throwable` stays valid.
                    jni_exception_clear!()?;
                    throwable
                } else {
                    log::error!("native execution panics");
                    log::error!("panic message: {}", panic_message);
                    jni_new_object!(
                        JavaRuntimeException,
                        jni_new_string!("blaze native panics")?,
                        JObject::null()
                    )?
                };

                // error_queue <- exception
                while jni_call!(
                    BlazeCallNativeWrapper(wrapper_clone.as_obj()).isFinished() -> jboolean
                ).unwrap() != JNI_TRUE {
                    let enqueued = jni_call!(
                        BlazeCallNativeWrapper(wrapper_clone.as_obj()).enqueueError(e) -> jboolean
                    )?;
                    if enqueued == JNI_TRUE {
                        break;
                    }
                }
                log::info!("Blaze native executing exited with error.");
                std::mem::drop(runtime_clone);
                datafusion::error::Result::Ok(())
            })
            // `Err(reported)` means the task panicked and `reported` is the
            // outcome of reporting it. Panic again only if the reporting
            // itself failed; a successfully reported panic ends the task
            // quietly.
            .unwrap_or_else(|reported| reported.unwrap());
        });
        log::info!("Blaze native thread created");
    }) {
        handle_unwinded(err);
    }
}
/// Tells whether a Java exception is currently pending and its class is
/// exactly `java.lang.InterruptedException`.
///
/// Returns `Ok(false)` when no exception is pending or when the pending
/// exception has a different class name. The pending exception is only
/// inspected, never cleared. Any JNI failure is propagated as `Err`.
fn is_jvm_interrupted() -> datafusion::error::Result<bool> {
    const INTERRUPTED_EXCEPTION_CLASS: &str = "java.lang.InterruptedException";

    // nothing pending: cannot be an interruption
    if !jni_exception_check!()? {
        return Ok(false);
    }

    // resolve the class name of the pending throwable
    let pending: JObject = jni_exception_occurred!()?.into();
    let pending_class = jni_get_object_class!(pending)?;
    let name_obj = jni_call!(Class(pending_class).getName() -> JObject)?;
    let name = jni_get_string!(name_obj.into())?;

    Ok(name == INTERRUPTED_EXCEPTION_CLASS)
}
/// Creates a `java.lang.RuntimeException(msg, cause)` and throws it in the JVM.
///
/// * `msg` - message of the new exception.
/// * `cause` - throwable to attach as the cause, or a null object for none.
///
/// Returns `Err` if the Java string or the exception object cannot be
/// created. If the exception is created but throwing it fails, there is no
/// way left to report the problem to Java, so a JNI fatal error is raised.
fn throw_runtime_exception(msg: &str, cause: JObject) -> datafusion::error::Result<()> {
    let msg = jni_new_string!(msg)?;
    let e = jni_new_object!(JavaRuntimeException, msg, cause)?;
    if let Err(err) = jni_throw!(JThrowable::from(e)) {
        // wording kept in line with the fatal error raised by `handle_unwinded`
        jni_fatal_error!(format!(
            "Error throwing RuntimeException, cannot resume: {:?}",
            err
        ));
    }
    Ok(())
}
fn handle_unwinded(err: Box<dyn Any + Send>) {
// default handling:
// * caused by InterruptedException: do nothing but just print a message.
// * other reasons: wrap it into a RuntimeException and throw.
// * if another error happens during handling, kill the whole JVM instance.
let recover = || {
if is_jvm_interrupted()? {
jni_exception_clear!()?;
log::info!("native execution interrupted by JVM");
return Ok(());
}
let panic_message = panic_message::panic_message(&err);
// throw jvm runtime exception
let cause = if jni_exception_check!()? {
let throwable = jni_exception_occurred!()?.into();
jni_exception_clear!()?;
throwable
} else {
JObject::null()
};
throw_runtime_exception(panic_message, cause)?;
Ok(())
};
recover().unwrap_or_else(|err: Box<dyn Error>| {
jni_fatal_error!(format!(
"Error recovering from panic, cannot resume: {:?}",
err
));
});
}