blob: fdb6ab925e4e73377ab75a79e2a1218894f4a660 [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use datafusion::common::DataFusionError;
use datafusion::logical_expr::ScalarFunctionImplementation;
pub mod debug_exec;
pub mod empty_partitions_exec;
pub mod hdfs_object_store; // note: can be changed to priv once plan transforming is removed
pub mod ipc_reader_exec;
pub mod ipc_writer_exec;
pub mod jni_bridge;
pub mod rename_columns_exec;
pub mod shuffle_writer_exec;
pub mod sort_merge_join_exec;
mod spark_ext_function;
mod spark_hash;
mod util;
/// Extension methods for consuming a fallible value of success type `T`.
///
/// Both methods take `self` by value, so the receiver is consumed.
pub trait ResultExt<T> {
    /// Extracts the success value.
    ///
    /// The method has no way to report a failure to the caller (it returns a
    /// plain `T`), so an implementation has to deal with the failure case
    /// itself.
    fn unwrap_or_fatal(self) -> T;

    /// Converts the receiver into a [`std::io::Result`] carrying the same
    /// success type.
    fn to_io_result(self) -> std::io::Result<T>;
}
/// `ResultExt` for any `Result` whose error type implements
/// [`std::error::Error`].
impl<T, E> ResultExt<T> for Result<T, E>
where
    E: std::error::Error,
{
    /// Returns the `Ok` value, or escalates the `Err` through the JNI macros.
    ///
    /// On `Err`:
    /// 1. If `jni_exception_check!()` yields `true` (a failed check counts as
    ///    `false`), `jni_exception_describe!()` and `jni_exception_clear!()`
    ///    are invoked and their results are ignored.
    /// 2. The error is formatted with `{:?}` into a message that is handed to
    ///    `jni_fatal_error!`.
    /// 3. If that call panics, the panic is caught and the process is aborted.
    fn unwrap_or_fatal(self) -> T {
        // Happy path leaves immediately; everything below is failure handling.
        let err = match self {
            Ok(value) => return value,
            Err(err) => err,
        };

        let exception_pending = jni_exception_check!().unwrap_or(false);
        if exception_pending {
            let _ = jni_exception_describe!();
            let _ = jni_exception_clear!();
        }

        let errmsg = format!("uncaught error in native code: {:?}", err);
        // Kept as a single tail expression so the closure's result type is
        // inferred from the function's return type `T`.
        std::panic::catch_unwind(move || jni_fatal_error!(errmsg))
            .unwrap_or_else(|_| std::process::abort())
    }

    /// Maps the error, if any, to a [`std::io::Error`] of kind
    /// [`std::io::ErrorKind::Other`] whose payload is the `{:?}` rendering of
    /// the original error. `Ok` values pass through untouched.
    fn to_io_result(self) -> std::io::Result<T> {
        self.map_err(|err| {
            let rendered = format!("{:?}", err);
            std::io::Error::new(std::io::ErrorKind::Other, rendered)
        })
    }
}
/// Looks up the implementation of a Spark extension function by name.
///
/// Recognised names (matched exactly, case-sensitive):
/// - `"UnscaledValue"` -> `spark_ext_function::spark_unscaled_value`
/// - `"MakeDecimal"` -> `spark_ext_function::spark_make_decimal`
///
/// # Errors
///
/// Returns [`DataFusionError::NotImplemented`] naming the requested function
/// when `name` is not one of the entries above.
pub fn create_spark_ext_function(
    name: &str,
) -> datafusion::error::Result<ScalarFunctionImplementation> {
    // The annotation makes each arm coerce to the shared implementation type.
    let implementation: ScalarFunctionImplementation = match name {
        "UnscaledValue" => Arc::new(spark_ext_function::spark_unscaled_value),
        "MakeDecimal" => Arc::new(spark_ext_function::spark_make_decimal),
        _ => {
            let message = format!("spark ext function not implemented: {}", name);
            return Err(DataFusionError::NotImplemented(message));
        }
    };
    Ok(implementation)
}