use std::alloc::Layout;
use std::any::Any;
use std::error::Error;
use std::panic::AssertUnwindSafe;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use datafusion::arrow::array::{export_array_into_raw, StructArray};
use datafusion::arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::memory_manager::MemoryManagerConfig;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::physical_plan::{displayable, ExecutionPlan};
use datafusion::prelude::{SessionConfig, SessionContext};
use futures::{FutureExt, StreamExt};
use jni::objects::{JClass, JObject, JString, JThrowable};
use jni::sys::{jbyteArray, jlong, JNI_FALSE, JNI_TRUE};
use jni::JNIEnv;
use log::LevelFilter;
use once_cell::sync::OnceCell;
use prost::Message;
use simplelog::{ColorChoice, ConfigBuilder, TermLogger, TerminalMode, ThreadLogMode};
use tokio::runtime::Runtime;

use datafusion_ext::jni_bridge::JavaClasses;
use datafusion_ext::*;
use plan_serde::protobuf::TaskDefinition;

use crate::BlazeIter;
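
// lazily initialized globals: the logging guard and the shared DataFusion session context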
static LOGGING_INIT: OnceCell<()> = OnceCell::new();
static SESSIONCTX: OnceCell<SessionContext> = OnceCell::new();
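
/// JNI entry point of `JniBridge.initNative()`: initializes logging, caches the
/// JNI classes used by the bridge (`JavaClasses::init`), and builds the global
/// DataFusion `SessionContext` from the given batch size, memory limit,
/// memory fraction and spill directories.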
#[allow(non_snake_case)]
#[allow(clippy::single_match)]
#[no_mangle]
pub extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_initNative(
env: JNIEnv,
_: JClass,
batch_size: i64,
native_memory: i64,
memory_fraction: f64,
tmp_dirs: JString,
) {
match std::panic::catch_unwind(|| {
// init logging
LOGGING_INIT.get_or_init(|| {
TermLogger::init(
LevelFilter::Info,
ConfigBuilder::new()
.set_thread_mode(ThreadLogMode::Both)
.build(),
TerminalMode::Stderr,
ColorChoice::Never,
)
.unwrap();
});
// init jni java classes
JavaClasses::init(&env);
// init datafusion session context
SESSIONCTX.get_or_init(|| {
let env = JavaClasses::get_thread_jnienv();
let dirs = jni_map_error!(env.get_string(tmp_dirs))
.unwrap()
.to_string_lossy()
.split(',')
.map(PathBuf::from)
.collect::<Vec<_>>();
let max_memory = native_memory as usize;
let batch_size = batch_size as usize;
let runtime_config = RuntimeConfig::new()
.with_memory_manager(MemoryManagerConfig::New {
max_memory,
memory_fraction,
})
.with_disk_manager(DiskManagerConfig::NewSpecified(dirs));
let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap());
let config = SessionConfig::new().with_batch_size(batch_size);
SessionContext::with_config_rt(config, runtime)
});
}) {
Err(err) => {
handle_unwinded(err);
}
Ok(()) => {}
}
}
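
/// JNI entry point of `JniBridge.callNative()`: decodes the protobuf
/// `TaskDefinition`, converts it into a native `ExecutionPlan`, starts
/// executing the task's partition and returns a raw pointer to a
/// heap-allocated `BlazeIter` (or -1 if a panic was caught).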
#[allow(non_snake_case)]
#[no_mangle]
pub extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_callNative(
env: JNIEnv,
_: JClass,
task_definition: jbyteArray,
) -> i64 {
match std::panic::catch_unwind(|| {
log::info!("Entering blaze callNative()");
let task_definition = TaskDefinition::decode(
env.convert_byte_array(task_definition).unwrap().as_slice(),
)
.unwrap();
let task_id = &task_definition.task_id.expect("task_id is empty");
let plan = &task_definition.plan.expect("plan is empty");
let execution_plan: Arc<dyn ExecutionPlan> = plan.try_into().unwrap();
let execution_plan_displayable =
displayable(execution_plan.as_ref()).indent().to_string();
log::info!("Creating native execution plan succeeded");
log::info!(" task_id={:?}", task_id);
log::info!(" execution plan:\n{}", execution_plan_displayable);
// execute
let session_ctx = SESSIONCTX.get().unwrap();
let task_ctx = session_ctx.task_ctx();
let stream = execution_plan
.execute(task_id.partition_id as usize, task_ctx)
.unwrap();
log::info!("Got stream");
// create tokio runtime used by loadBatches()
let runtime = tokio::runtime::Builder::new_current_thread()
.build()
.unwrap()
.block_on(async move {
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.worker_threads(1)
.thread_keep_alive(Duration::MAX) // always use same thread
.build()
.unwrap(),
);
// propagate task context to spawned child threads
let env = JavaClasses::get_thread_jnienv();
let task_context = jni_global_ref!(
env,
jni_bridge_call_static_method!(
env,
JniBridge.getTaskContext -> JObject
)
.unwrap()
)
.unwrap();
runtime.spawn(async move {
AssertUnwindSafe(async move {
let env = JavaClasses::get_thread_jnienv();
jni_bridge_call_static_method!(
env,
JniBridge.setTaskContext -> (),
task_context.as_obj(),
)
.unwrap();
})
.catch_unwind()
.await
.unwrap_or_else(|err| {
let panic_message = panic_message::panic_message(&err);
throw_runtime_exception(panic_message, JObject::null())
.unwrap_or_fatal();
});
});
runtime
});
// safety: the manually allocated BlazeIter is released by deallocIter() after the JVM finishes consuming the stream
log::info!("Got blaze iter");
unsafe {
let blaze_iter_ptr: *mut BlazeIter =
std::alloc::alloc(Layout::new::<BlazeIter>()) as *mut BlazeIter;
std::ptr::write(
blaze_iter_ptr,
BlazeIter {
stream,
execution_plan,
runtime,
},
);
blaze_iter_ptr as i64
}
}) {
Err(err) => {
handle_unwinded(err);
-1
}
Ok(ptr) => ptr,
}
}
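
/// JNI entry point of `JniBridge.loadBatches()`: spawns a task on the
/// iterator's runtime that polls record batches from the native stream and
/// exchanges them with the JVM through `ret_queue`/`err_queue` using the
/// Arrow FFI structures.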
#[allow(non_snake_case)]
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_loadBatches(
_: JNIEnv,
_: JClass,
iter_ptr: i64,
ret_queue: JObject,
err_queue: JObject,
) {
if let Err(err) = std::panic::catch_unwind(|| {
log::info!("Entering blaze loadBatches()");
let blaze_iter = &mut *(iter_ptr as *mut BlazeIter);
let env = JavaClasses::get_thread_jnienv();
let ret_queue = jni_global_ref!(env, ret_queue).unwrap();
let err_queue = jni_global_ref!(env, err_queue).unwrap();
// spawn a task onto the iterator's runtime to poll batches
blaze_iter.runtime.clone().spawn(async move {
AssertUnwindSafe(async move {
while let Some(r) = blaze_iter.stream.next().await {
match r {
Ok(batch) => {
let env = JavaClasses::get_thread_jnienv();
let num_rows = batch.num_rows();
if num_rows == 0 {
continue;
}
// ret_queue -> (schema_ptr, array_ptr)
let input = jni_bridge_call_method!(
env,
JavaSynchronousQueue.take -> JObject,
ret_queue.as_obj()
).unwrap();
let schema_ptr = jni_bridge_call_method!(env, ScalaTuple2._1 -> JObject, input).unwrap();
let schema_ptr = jni_bridge_call_method!(env, JavaLong.longValue -> jlong, schema_ptr).unwrap();
let array_ptr = jni_bridge_call_method!(env, ScalaTuple2._2 -> JObject, input).unwrap();
let array_ptr = jni_bridge_call_method!(env, JavaLong.longValue -> jlong, array_ptr).unwrap();
// a (0, 0) pair signals that the JVM side has stopped consuming
if schema_ptr == 0 && array_ptr == 0 {
return;
}
let out_schema = schema_ptr as *mut FFI_ArrowSchema;
let out_array = array_ptr as *mut FFI_ArrowArray;
let batch: Arc<StructArray> = Arc::new(batch.into());
export_array_into_raw(
batch,
out_array,
out_schema,
)
.expect("export_array_into_raw error");
// ret_queue <- hasNext=true
let r = jni_bridge_new_object!(env, JavaBoolean, JNI_TRUE).unwrap();
jni_bridge_call_method!(
env,
JavaSynchronousQueue.put -> (),
ret_queue.as_obj(),
r
)
.unwrap();
}
Err(e) => {
panic!("stream.next() error: {:?}", e);
}
}
}
let env = JavaClasses::get_thread_jnienv();
// unblock queue
jni_bridge_call_method!(env, JavaSynchronousQueue.take -> JObject, ret_queue.as_obj()).unwrap();
// ret_queue <- hasNext=false
let r = jni_bridge_new_object!(env, JavaBoolean, JNI_FALSE).unwrap();
jni_bridge_call_method!(
env,
JavaSynchronousQueue.put -> (),
ret_queue.as_obj(),
r
)
.unwrap();
})
.catch_unwind()
.await
.map_err(|err| {
log::error!("native execution panics");
let env = JavaClasses::get_thread_jnienv();
let panic_message = panic_message::panic_message(&err);
// err_queue <- RuntimeException
jni_bridge_call_method!(
env,
JavaSynchronousQueue.put -> (),
err_queue.as_obj(),
jni_bridge_new_object!(
env,
JavaRuntimeException,
jni_map_error!(env.new_string(&panic_message))?,
JObject::null()
)?
)?;
datafusion::error::Result::Ok(())
})
.unwrap();
});
}) {
handle_unwinded(err)
}
}
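
/// JNI entry point of `JniBridge.deallocIter()`: shuts down the iterator's
/// runtime and releases the `BlazeIter` allocated in `callNative()`.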
#[allow(non_snake_case)]
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_deallocIter(
_: JNIEnv,
_: JClass,
iter_ptr: i64,
) {
// shut down any background threads
// safety: safe to copy because Runtime::drop() does not do anything under ThreadPool mode
let runtime: Runtime =
std::mem::transmute_copy((*(iter_ptr as *mut BlazeIter)).runtime.as_ref());
runtime.shutdown_background();
// dealloc memory
std::alloc::dealloc(iter_ptr as *mut u8, Layout::new::<BlazeIter>());
}
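
/// Returns true if the currently pending JVM exception (if any) is a
/// `java.lang.InterruptedException`.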
fn is_jvm_interrupted(env: &JNIEnv) -> datafusion::error::Result<bool> {
let interrupted_exception_class = "java.lang.InterruptedException";
if env.exception_check().unwrap_or(false) {
let e: JObject = env
.exception_occurred()
.unwrap_or_else(|_| JThrowable::from(JObject::null()))
.into();
let class = jni_map_error!(env.get_object_class(e))?;
let classname = jni_bridge_call_method!(env, Class.getName -> JObject, class)?;
let classname = jni_map_error!(env.get_string(classname.into()))?;
if classname.to_string_lossy().as_ref() == interrupted_exception_class {
return Ok(true);
}
}
Ok(false)
}
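
/// Wraps `msg` and `cause` into a `java.lang.RuntimeException` and raises it
/// on the JVM side via `JniBridge.raiseThrowable()`.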
fn throw_runtime_exception(msg: &str, cause: JObject) -> datafusion::error::Result<()> {
let env = JavaClasses::get_thread_jnienv();
let msg = jni_map_error!(env.new_string(msg))?;
let e = jni_bridge_new_object!(env, JavaRuntimeException, msg, cause)?;
let _throw = jni_bridge_call_static_method!(
env,
JniBridge.raiseThrowable -> (),
e
);
Ok(())
}
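
/// Converts a caught Rust panic into a JVM-visible error (see the comments
/// inside for the handling rules).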
fn handle_unwinded(err: Box<dyn Any + Send>) {
let env = JavaClasses::get_thread_jnienv();
// default handling:
// * caused by InterruptedException: do nothing but just print a message.
// * other reasons: wrap it into a RuntimeException and throw.
// * if another error happens during handling, kill the whole JVM instance.
let recover = || {
if is_jvm_interrupted(&env)? {
env.exception_clear()?;
log::info!("native execution interrupted by JVM");
return Ok(());
}
let panic_message = panic_message::panic_message(&err);
// throw jvm runtime exception
let cause = if env.exception_check()? {
let throwable = env.exception_occurred()?.into();
env.exception_clear()?;
throwable
} else {
JObject::null()
};
throw_runtime_exception(panic_message, cause)?;
Ok(())
};
recover().unwrap_or_else(|err: Box<dyn Error>| {
env.fatal_error(format!(
"Error recovering from panic, cannot resume: {:?}",
err
));
});
}