blob: cd3757ba945e1e792f5684f8e668150f344d916e [file] [log] [blame]
use std::alloc::Layout;
use std::error::Error;
use std::io::Error as IoError;
use std::io::ErrorKind as IoErrorKind;
use std::sync::Arc;
use std::time::Instant;
use datafusion::arrow::array::{export_array_into_raw, StructArray};
use datafusion::arrow::datatypes::{Field, Schema};
use datafusion::arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::{displayable, ExecutionPlan};
use futures::StreamExt;
use jni::objects::JObject;
use jni::objects::{JClass, JString};
use jni::JNIEnv;
use log::{debug, error, info};
use prost::Message;
use datafusion_ext::jni_bridge::JavaClasses;
use datafusion_ext::*;
use plan_serde::protobuf::TaskDefinition;
use crate::{init_logging, init_session_ctx, BlazeIter};
/// JNI entry point for `JniBridge.callNative`: decodes the task and starts its
/// native execution via `blaze_call_native`.
///
/// Returns the raw `BlazeIter` pointer produced by `blaze_call_native` (as an
/// `i64`, later passed to `loadNext` / `deallocIter`), or `-1` on failure.
///
/// Failure handling: any error or panic inside the native call is caught with
/// `catch_unwind`. Then:
/// * if the JVM has a pending `java.lang.InterruptedException`, it is cleared
///   and only logged;
/// * otherwise a `JavaRuntimeException` object (built with
///   `jni_bridge_new_object!`) carrying the panic message, with any pending Java
///   exception as its cause, is handed to `JniBridge.raiseThrowable`.
///
/// If that recovery itself fails, the error is logged and the process aborts.
#[allow(non_snake_case)]
#[no_mangle]
pub extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_callNative(
env: JNIEnv,
_: JClass,
taskDefinition: JObject,
poolSize: i64,
batch_size: i64,
nativeMemory: i64,
memoryFraction: f64,
tmpDirs: JString,
) -> i64 {
let start_time = Instant::now();
init_logging();
// Catch panics so they are reported to the JVM below instead of unwinding
// out of this extern function.
let result = match std::panic::catch_unwind(|| {
info!("Blaze native computing started");
let iter_ptr = blaze_call_native(
&env,
taskDefinition,
poolSize,
batch_size,
nativeMemory,
memoryFraction,
tmpDirs,
)
// Errors are deliberately turned into panics so the Err arm below handles
// both failure kinds the same way.
.unwrap();
info!("Blaze native computing finished");
iter_ptr
}) {
Err(e) => {
// Any failure inside `recover` (a JNI call error) is fatal: see the
// `unwrap_or_else` below.
let recover = || {
// A JVM interruption is not reported as an error; just clear it.
if is_jvm_interrupted(&env)? {
env.exception_clear()?;
info!("Blaze native computing interrupted by JVM");
return Ok(());
}
let panic_message = panic_message::panic_message(&e);
// throw jvm runtime exception
// Use the currently pending Java exception (if any) as the cause, clearing
// it so the new exception can be raised.
let cause = if env.exception_check()? {
let throwable = env.exception_occurred()?.into();
env.exception_clear()?;
throwable
} else {
JObject::null()
};
let msg = env.new_string(&panic_message)?;
let _throw = jni_bridge_call_static_method_no_check_java_exception!(
env,
JniBridge.raiseThrowable,
jni_bridge_new_object!(env, JavaRuntimeException, msg, cause)?
);
Ok(())
};
recover().unwrap_or_else(|err: Box<dyn Error>| {
error!("Error recovering from panic, cannot resume: {:?}", err);
std::process::abort();
});
// Failure sentinel returned to the Java caller.
-1
}
Ok(ptr) => ptr,
};
let time_cost_sec = Instant::now().duration_since(start_time).as_secs_f64();
info!("blaze_call_native() time cost: {} sec", time_cost_sec);
result
}
#[allow(non_snake_case)]
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_loadNext(
_: JNIEnv,
_: JClass,
iter_ptr: i64,
schema_ptr: i64,
array_ptr: i64,
) -> i64 {
// loadNext is always called after callNative, therefore a tokio runtime already
tokio::runtime::Builder::new_current_thread()
.build()
.unwrap()
.block_on(async {
let blaze_iter = &mut *(iter_ptr as *mut BlazeIter);
loop {
return match blaze_iter.stream.next().await {
Some(Ok(batch)) => {
let num_rows = batch.num_rows();
if num_rows == 0 {
continue;
}
let renamed_batch = RecordBatch::try_new(
blaze_iter.renamed_schema.clone(),
batch.columns().to_vec(),
)
.unwrap();
let array: StructArray = renamed_batch.into();
let out_schema = schema_ptr as *mut FFI_ArrowSchema;
let out_array = array_ptr as *mut FFI_ArrowArray;
export_array_into_raw(Arc::new(array), out_array, out_schema)
.unwrap();
num_rows as i64
}
_ => -1,
};
}
})
}
/// JNI entry point for `JniBridge.deallocIter`: releases a `BlazeIter` whose raw
/// pointer was returned by `callNative`.
///
/// The iterator's contents (result stream, execution plan and renamed schema) are
/// dropped first, and only then is the allocation (`Layout::new::<BlazeIter>()`)
/// freed. Must be called at most once per pointer, and never while `loadNext`
/// may still use it.
///
/// `0` (null) and `-1` (the failure value returned by `callNative`) never denote
/// a live iterator, so they are ignored.
#[allow(non_snake_case)]
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_deallocIter(
    _: JNIEnv,
    _: JClass,
    iter_ptr: i64,
) {
    if iter_ptr == 0 || iter_ptr == -1 {
        return;
    }
    let blaze_iter = iter_ptr as *mut BlazeIter;
    // SAFETY: `iter_ptr` came from `callNative`, so it points to an initialized
    // `BlazeIter` allocated by the global allocator with
    // `Layout::new::<BlazeIter>()`. Running the destructor first releases what
    // the iterator owns (a bare `dealloc` would leak it); the memory is then
    // freed with the same layout.
    std::ptr::drop_in_place(blaze_iter);
    std::alloc::dealloc(blaze_iter as *mut u8, Layout::new::<BlazeIter>());
}
/// Builds and starts native execution for one Spark task.
///
/// Steps:
/// 1. Decode the protobuf `TaskDefinition` held in the Java byte array
///    `task_definition`.
/// 2. Convert its plan into a DataFusion `ExecutionPlan`.
/// 3. Execute the task's partition.
/// 4. Return a raw pointer (as `i64`) to a heap-allocated `BlazeIter` that owns
///    the result stream.
///
/// The pointer is read by `loadNext` and released by `deallocIter`.
///
/// # Errors
/// Returns an error in any of these cases:
/// * a JNI call fails;
/// * the task definition cannot be decoded;
/// * the task id or plan is missing;
/// * the plan cannot be converted;
/// * `batch_size` is not positive;
/// * execution cannot be started.
#[allow(clippy::too_many_arguments)]
fn blaze_call_native(
    env: &JNIEnv,
    task_definition: JObject,
    _pool_size: i64,
    batch_size: i64,
    native_memory: i64,
    memory_fraction: f64,
    tmp_dirs: JString,
) -> Result<i64, Box<dyn Error>> {
    debug!("Initializing JavaClasses");
    JavaClasses::init(env)?;
    // From here on, use the JNIEnv that JavaClasses provides for this thread.
    let env = JavaClasses::get_thread_jnienv();
    debug!("Initializing JavaClasses succeeded");

    debug!("Decoding task definition");
    let task_definition_raw = env.convert_byte_array(task_definition.into_inner())?;
    let task_definition: TaskDefinition = TaskDefinition::decode(&*task_definition_raw)?;
    debug!("Decoding task definition succeeded");

    debug!("Creating native execution plan");
    let task_id = task_definition
        .task_id
        .ok_or_else(|| IoError::new(IoErrorKind::InvalidInput, "task id is empty"))?;
    let plan = &task_definition
        .plan
        .ok_or_else(|| IoError::new(IoErrorKind::InvalidInput, "task plan is empty"))?;
    let execution_plan: Arc<dyn ExecutionPlan> = plan.try_into()?;
    let execution_plan_displayable =
        displayable(execution_plan.as_ref()).indent().to_string();
    info!("Creating native execution plan succeeded");
    info!(" task_id={:?}", task_id);
    info!(" execution plan:\n{}", execution_plan_displayable);

    let dirs = env.get_string(tmp_dirs)?.into();

    // Validate before casting: a negative i64 would wrap to a huge usize and
    // would pass a `> 0` check made on the cast value.
    if batch_size <= 0 {
        return Err(IoError::new(
            IoErrorKind::InvalidInput,
            format!("batch size must be positive, got {}", batch_size),
        )
        .into());
    }
    let batch_size = batch_size as usize;

    // NOTE(review): `native_memory` is cast without a sign check; a negative value
    // would wrap — confirm the JVM side never passes one.
    let session_ctx =
        init_session_ctx(native_memory as usize, memory_fraction, batch_size, dirs);
    let task_ctx = session_ctx.task_ctx();

    // execute
    let result_stream =
        execution_plan.execute(task_id.partition_id as usize, task_ctx)?;

    // Rename all fields to `_c0`, `_c1`, ... so the exported schema never has
    // duplicated names. This is safe because the consumer does not rely on
    // these names.
    let renamed_schema: Arc<Schema> = Arc::new(Schema::new(
        execution_plan
            .schema()
            .fields()
            .iter()
            .enumerate()
            .map(|(index, field)| {
                Field::new(
                    &format!("_c{}", index),
                    field.data_type().clone(),
                    field.is_nullable(),
                )
            })
            .collect(),
    ));

    // Hand ownership to the JVM side as a raw pointer. `Box` allocates with the
    // global allocator and `Layout::new::<BlazeIter>()`, matching how
    // `deallocIter` frees it. Unlike an unchecked `std::alloc::alloc`, it never
    // writes through a null pointer on OOM.
    let blaze_iter = Box::new(BlazeIter {
        stream: result_stream,
        execution_plan,
        renamed_schema,
    });
    Ok(Box::into_raw(blaze_iter) as i64)
}
/// Reports whether the JVM currently has a pending exception whose class name is
/// exactly `java.lang.InterruptedException`.
///
/// The pending exception is only inspected, never cleared. Returns `Ok(false)`
/// when nothing is pending or the pending exception has another class. JNI
/// failures while inspecting it are propagated.
fn is_jvm_interrupted(env: &JNIEnv) -> jni::errors::Result<bool> {
    const INTERRUPTED_EXCEPTION_CLASS: &str = "java.lang.InterruptedException";

    // Guard: no pending exception means no interruption.
    if !env.exception_check()? {
        return Ok(false);
    }

    let pending = env.exception_occurred()?;
    let pending_class = env.get_object_class(pending)?;
    let name_obj = jni_bridge_call_method!(env, Class.getName, pending_class)?.l()?;
    let name = env.get_string(name_obj.into())?;
    Ok(name.to_string_lossy().as_ref() == INTERRUPTED_EXCEPTION_CLASS)
}