| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #pragma once |
| |
| #include <jni.h> |
| #include <unistd.h> |
| |
| #include <cstdint> |
| #include <memory> |
| |
| #include "absl/strings/substitute.h" |
| #include "common/cast_set.h" |
| #include "common/compiler_util.h" |
| #include "common/exception.h" |
| #include "common/logging.h" |
| #include "common/status.h" |
| #include "runtime/user_function_cache.h" |
| #include "util/jni-util.h" |
| #include "vec/aggregate_functions/aggregate_function.h" |
| #include "vec/columns/column_array.h" |
| #include "vec/columns/column_map.h" |
| #include "vec/columns/column_string.h" |
| #include "vec/common/string_ref.h" |
| #include "vec/core/field.h" |
| #include "vec/core/types.h" |
| #include "vec/exec/jni_connector.h" |
| #include "vec/io/io_helper.h" |
| |
| namespace doris::vectorized { |
| #include "common/compile_check_begin.h" |
| |
// JNI class name and method descriptors for org.apache.doris.udf.UdafExecutor.
// A method descriptor has the form "(argument-types)return-type"; for example "(J)[B" takes a
// long and returns a byte[], and "(ZIIJILjava/util/Map;)V" takes
// (boolean, int, int, long, int, java.util.Map) and returns void. See the "Type Signatures"
// section of the JNI specification:
// https://docs.oracle.com/javase/8/docs/technotes/guides/jni/spec/types.html
// `inline constexpr` gives each constant exactly one definition program-wide, so this header can
// be included from several translation units without multiple-definition link errors.
inline constexpr const char* UDAF_EXECUTOR_CLASS = "org/apache/doris/udf/UdafExecutor";
inline constexpr const char* UDAF_EXECUTOR_CTOR_SIGNATURE = "([B)V";
inline constexpr const char* UDAF_EXECUTOR_CLOSE_SIGNATURE = "()V";
inline constexpr const char* UDAF_EXECUTOR_DESTROY_SIGNATURE = "()V";
inline constexpr const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZIIJILjava/util/Map;)V";
inline constexpr const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B";
inline constexpr const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V";
inline constexpr const char* UDAF_EXECUTOR_GET_SIGNATURE = "(JLjava/util/Map;)J";
inline constexpr const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V";
| |
| struct AggregateJavaUdafData { |
| public: |
| AggregateJavaUdafData() = default; |
| AggregateJavaUdafData(int64_t num_args) { cast_set(argument_size, num_args); } |
| |
| ~AggregateJavaUdafData() = default; |
| |
| Status close_and_delete_object() { |
| JNIEnv* env = nullptr; |
| Defer defer {[&]() { |
| if (env != nullptr) { |
| env->DeleteGlobalRef(executor_cl); |
| env->DeleteGlobalRef(executor_obj); |
| } |
| }}; |
| Status st = JniUtil::GetJNIEnv(&env); |
| if (!st.ok()) { |
| LOG(WARNING) << "Failed to get JNIEnv"; |
| return st; |
| } |
| env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_close_id); |
| st = JniUtil::GetJniExceptionMsg(env); |
| if (!st.ok()) { |
| LOG(WARNING) << "Failed to close JAVA UDAF: " << st.to_string(); |
| return st; |
| } |
| return Status::OK(); |
| } |
| |
| Status init_udaf(const TFunction& fn, const std::string& local_location) { |
| JNIEnv* env = nullptr; |
| RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf init_udaf function"); |
| RETURN_IF_ERROR(JniUtil::GetGlobalClassRef(env, UDAF_EXECUTOR_CLASS, &executor_cl)); |
| RETURN_NOT_OK_STATUS_WITH_WARN(register_func_id(env), |
| "Java-Udaf register_func_id function"); |
| |
| // Add a scoped cleanup jni reference object. This cleans up local refs made below. |
| JniLocalFrame jni_frame; |
| { |
| TJavaUdfExecutorCtorParams ctor_params; |
| ctor_params.__set_fn(fn); |
| if (!fn.hdfs_location.empty() && !fn.checksum.empty()) { |
| ctor_params.__set_location(local_location); |
| } |
| jbyteArray ctor_params_bytes; |
| |
| // Pushed frame will be popped when jni_frame goes out-of-scope. |
| RETURN_IF_ERROR(jni_frame.push(env)); |
| RETURN_IF_ERROR(SerializeThriftMsg(env, &ctor_params, &ctor_params_bytes)); |
| executor_obj = env->NewObject(executor_cl, executor_ctor_id, ctor_params_bytes); |
| |
| jbyte* pBytes = env->GetByteArrayElements(ctor_params_bytes, nullptr); |
| env->ReleaseByteArrayElements(ctor_params_bytes, pBytes, JNI_ABORT); |
| env->DeleteLocalRef(ctor_params_bytes); |
| } |
| RETURN_ERROR_IF_EXC(env); |
| RETURN_IF_ERROR(JniUtil::LocalToGlobalRef(env, executor_obj, &executor_obj)); |
| return Status::OK(); |
| } |
| |
| Status add(int64_t places_address, bool is_single_place, const IColumn** columns, |
| int64_t row_num_start, int64_t row_num_end, const DataTypes& argument_types, |
| int64_t place_offset) { |
| JNIEnv* env = nullptr; |
| RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf add function"); |
| |
| Block input_block; |
| for (size_t i = 0; i < argument_size; ++i) { |
| input_block.insert(ColumnWithTypeAndName(columns[i]->get_ptr(), argument_types[i], |
| std::to_string(i))); |
| } |
| std::unique_ptr<long[]> input_table; |
| RETURN_IF_ERROR(JniConnector::to_java_table(&input_block, input_table)); |
| auto input_table_schema = JniConnector::parse_table_schema(&input_block); |
| std::map<String, String> input_params = { |
| {"meta_address", std::to_string((long)input_table.get())}, |
| {"required_fields", input_table_schema.first}, |
| {"columns_types", input_table_schema.second}}; |
| jobject input_map = nullptr; |
| RETURN_IF_ERROR(JniUtil::convert_to_java_map(env, input_params, &input_map)); |
| // invoke add batch |
| // Keep consistent with the function signature of executor_add_batch_id. |
| env->CallObjectMethod(executor_obj, executor_add_batch_id, is_single_place, |
| cast_set<int>(row_num_start), cast_set<int>(row_num_end), |
| places_address, cast_set<int>(place_offset), input_map); |
| RETURN_ERROR_IF_EXC(env); |
| env->DeleteGlobalRef(input_map); |
| return JniUtil::GetJniExceptionMsg(env); |
| } |
| |
| Status merge(const AggregateJavaUdafData& rhs, int64_t place) { |
| JNIEnv* env = nullptr; |
| RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf merge function"); |
| serialize_data = rhs.serialize_data; |
| jsize len = cast_set<jsize>(serialize_data.length()); // jsize needs to be used. |
| jbyteArray arr = env->NewByteArray(len); |
| env->SetByteArrayRegion(arr, 0, len, reinterpret_cast<jbyte*>(serialize_data.data())); |
| env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_merge_id, place, arr); |
| RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); |
| jbyte* pBytes = env->GetByteArrayElements(arr, nullptr); |
| env->ReleaseByteArrayElements(arr, pBytes, JNI_ABORT); |
| env->DeleteLocalRef(arr); |
| return JniUtil::GetJniExceptionMsg(env); |
| } |
| |
| Status write(BufferWritable& buf, int64_t place) { |
| JNIEnv* env = nullptr; |
| RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf write function"); |
| // TODO: Here get a byte[] from FE serialize, and then allocate the same length bytes to |
| // save it in BE, Because i'm not sure there is a way to use the byte[] not allocate again. |
| jbyteArray arr = (jbyteArray)(env->CallNonvirtualObjectMethod( |
| executor_obj, executor_cl, executor_serialize_id, place)); |
| RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); |
| int len = env->GetArrayLength(arr); |
| serialize_data.resize(len); |
| env->GetByteArrayRegion(arr, 0, len, reinterpret_cast<jbyte*>(serialize_data.data())); |
| buf.write_binary(serialize_data); |
| jbyte* pBytes = env->GetByteArrayElements(arr, nullptr); |
| env->ReleaseByteArrayElements(arr, pBytes, JNI_ABORT); |
| env->DeleteLocalRef(arr); |
| return JniUtil::GetJniExceptionMsg(env); |
| } |
| |
| Status reset(int64_t place) { |
| JNIEnv* env = nullptr; |
| RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf reset function"); |
| env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_reset_id, place); |
| return JniUtil::GetJniExceptionMsg(env); |
| } |
| |
| void read(BufferReadable& buf) { buf.read_binary(serialize_data); } |
| |
| Status destroy() { |
| JNIEnv* env = nullptr; |
| RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf destroy function"); |
| env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_destroy_id); |
| return JniUtil::GetJniExceptionMsg(env); |
| } |
| |
| Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const { |
| JNIEnv* env = nullptr; |
| RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf get value function"); |
| |
| Block output_block; |
| output_block.insert(ColumnWithTypeAndName(to.get_ptr(), result_type, "_result_")); |
| auto output_table_schema = JniConnector::parse_table_schema(&output_block); |
| std::string output_nullable = result_type->is_nullable() ? "true" : "false"; |
| std::map<String, String> output_params = {{"is_nullable", output_nullable}, |
| {"required_fields", output_table_schema.first}, |
| {"columns_types", output_table_schema.second}}; |
| jobject output_map = nullptr; |
| RETURN_IF_ERROR(JniUtil::convert_to_java_map(env, output_params, &output_map)); |
| long output_address = |
| env->CallLongMethod(executor_obj, executor_get_value_id, place, output_map); |
| RETURN_ERROR_IF_EXC(env); |
| env->DeleteGlobalRef(output_map); |
| RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); |
| return JniConnector::fill_block(&output_block, {0}, output_address); |
| } |
| |
| private: |
| Status register_func_id(JNIEnv* env) { |
| auto register_id = [&](const char* func_name, const char* func_sign, jmethodID& func_id) { |
| func_id = env->GetMethodID(executor_cl, func_name, func_sign); |
| Status s = JniUtil::GetJniExceptionMsg(env); |
| if (!s.ok()) { |
| LOG(WARNING) << "Failed to register function " << func_name << ": " |
| << s.to_string(); |
| return Status::InternalError(absl::Substitute( |
| "Java-Udaf register_func_id meet error and error is $0", s.to_string())); |
| } |
| return s; |
| }; |
| RETURN_IF_ERROR(register_id("<init>", UDAF_EXECUTOR_CTOR_SIGNATURE, executor_ctor_id)); |
| RETURN_IF_ERROR(register_id("reset", UDAF_EXECUTOR_RESET_SIGNATURE, executor_reset_id)); |
| RETURN_IF_ERROR(register_id("close", UDAF_EXECUTOR_CLOSE_SIGNATURE, executor_close_id)); |
| RETURN_IF_ERROR(register_id("merge", UDAF_EXECUTOR_MERGE_SIGNATURE, executor_merge_id)); |
| RETURN_IF_ERROR( |
| register_id("serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE, executor_serialize_id)); |
| RETURN_IF_ERROR( |
| register_id("getValue", UDAF_EXECUTOR_GET_SIGNATURE, executor_get_value_id)); |
| RETURN_IF_ERROR( |
| register_id("destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE, executor_destroy_id)); |
| RETURN_IF_ERROR( |
| register_id("addBatch", UDAF_EXECUTOR_ADD_SIGNATURE, executor_add_batch_id)); |
| return Status::OK(); |
| } |
| |
| private: |
| // TODO: too many variables are hold, it's causing a lot of memory waste |
| // it's time to refactor it. |
| jclass executor_cl; |
| jobject executor_obj; |
| jmethodID executor_ctor_id; |
| |
| jmethodID executor_add_batch_id; |
| jmethodID executor_merge_id; |
| jmethodID executor_serialize_id; |
| jmethodID executor_get_value_id; |
| jmethodID executor_reset_id; |
| jmethodID executor_close_id; |
| jmethodID executor_destroy_id; |
| int argument_size = 0; |
| std::string serialize_data; |
| }; |
| |
| class AggregateJavaUdaf final |
| : public IAggregateFunctionDataHelper<AggregateJavaUdafData, AggregateJavaUdaf>, |
| VarargsExpression, |
| NullableAggregateFunction { |
| public: |
| ENABLE_FACTORY_CREATOR(AggregateJavaUdaf); |
| AggregateJavaUdaf(const TFunction& fn, const DataTypes& argument_types_, |
| const DataTypePtr& return_type) |
| : IAggregateFunctionDataHelper(argument_types_), |
| _fn(fn), |
| _return_type(return_type), |
| _first_created(true), |
| _exec_place(nullptr) {} |
| ~AggregateJavaUdaf() override = default; |
| |
| static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_, |
| const DataTypePtr& return_type) { |
| return std::make_shared<AggregateJavaUdaf>(fn, argument_types_, return_type); |
| } |
| //Note: The condition is added because maybe the BE can't find java-udaf impl jar |
| //So need to check as soon as possible, before call Data function |
| Status check_udaf(const TFunction& fn) { |
| auto function_cache = UserFunctionCache::instance(); |
| // get jar path if both file path location and checksum are null |
| if (!fn.hdfs_location.empty() && !fn.checksum.empty()) { |
| return function_cache->get_jarpath(fn.id, fn.hdfs_location, fn.checksum, |
| &_local_location); |
| } else { |
| return Status::OK(); |
| } |
| } |
| |
| void create(AggregateDataPtr __restrict place) const override { |
| new (place) Data(argument_types.size()); |
| if (_first_created) { |
| Status status = this->data(place).init_udaf(_fn, _local_location); |
| _first_created = false; |
| _exec_place = place; |
| if (UNLIKELY(!status.ok())) { |
| static_cast<void>(this->data(place).destroy()); |
| this->data(place).~Data(); |
| throw doris::Exception(ErrorCode::INTERNAL_ERROR, status.to_string()); |
| } |
| } |
| } |
| |
| // To avoid multiple times JNI call, Here will destroy all data at once |
| void destroy(AggregateDataPtr __restrict place) const noexcept override { |
| if (place == _exec_place) { |
| Status status = Status::OK(); |
| status = this->data(_exec_place).destroy(); |
| status = this->data(_exec_place).close_and_delete_object(); |
| _first_created = true; |
| if (UNLIKELY(!status.ok())) { |
| LOG(WARNING) << "Failed to destroy function: " << status.to_string(); |
| } |
| } |
| this->data(place).~Data(); |
| } |
| |
| String get_name() const override { return _fn.name.function_name; } |
| |
| DataTypePtr get_return_type() const override { return _return_type; } |
| |
| void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, |
| Arena&) const override { |
| int64_t places_address = reinterpret_cast<int64_t>(place); |
| Status st = this->data(_exec_place) |
| .add(places_address, true, columns, row_num, row_num + 1, |
| argument_types, 0); |
| if (UNLIKELY(!st.ok())) { |
| throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
| } |
| } |
| |
| void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset, |
| const IColumn** columns, Arena&, bool /*agg_many*/) const override { |
| int64_t places_address = reinterpret_cast<int64_t>(places); |
| Status st = this->data(_exec_place) |
| .add(places_address, false, columns, 0, batch_size, argument_types, |
| place_offset); |
| if (UNLIKELY(!st.ok())) { |
| throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
| } |
| } |
| |
| void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, |
| Arena&) const override { |
| int64_t places_address = reinterpret_cast<int64_t>(place); |
| Status st = this->data(_exec_place) |
| .add(places_address, true, columns, 0, batch_size, argument_types, 0); |
| if (UNLIKELY(!st.ok())) { |
| throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
| } |
| } |
| |
| void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start, |
| int64_t frame_end, AggregateDataPtr place, const IColumn** columns, |
| Arena&, UInt8* current_window_empty, |
| UInt8* current_window_has_inited) const override { |
| frame_start = std::max<int64_t>(frame_start, partition_start); |
| frame_end = std::min<int64_t>(frame_end, partition_end); |
| int64_t places_address = reinterpret_cast<int64_t>(place); |
| Status st = this->data(_exec_place) |
| .add(places_address, true, columns, frame_start, frame_end, |
| argument_types, 0); |
| if (frame_start >= frame_end) { |
| if (!*current_window_has_inited) { |
| *current_window_empty = true; |
| } |
| } else { |
| *current_window_empty = false; |
| *current_window_has_inited = true; |
| } |
| if (UNLIKELY(!st.ok())) { |
| throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
| } |
| } |
| |
| void reset(AggregateDataPtr place) const override { |
| Status st = this->data(_exec_place).reset(reinterpret_cast<int64_t>(place)); |
| if (UNLIKELY(!st.ok())) { |
| throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
| } |
| } |
| |
| void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, |
| Arena&) const override { |
| Status st = |
| this->data(_exec_place).merge(this->data(rhs), reinterpret_cast<int64_t>(place)); |
| if (UNLIKELY(!st.ok())) { |
| throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
| } |
| } |
| |
| void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { |
| Status st = this->data(_exec_place).write(buf, reinterpret_cast<int64_t>(place)); |
| if (UNLIKELY(!st.ok())) { |
| throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
| } |
| } |
| |
| // during merge-finalized phase, for deserialize and merge firstly, |
| // will call create --- deserialize --- merge --- destory for each rows , |
| // so need doing new (place), to create Data and read to buf, then call merge , |
| // and during destory about deserialize, because haven't done init_udaf, |
| // so it's can't call ~Data, only to change _destory_deserialize flag. |
| void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, |
| Arena&) const override { |
| this->data(place).read(buf); |
| } |
| |
| void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { |
| Status st = this->data(_exec_place).get(to, _return_type, reinterpret_cast<int64_t>(place)); |
| if (UNLIKELY(!st.ok())) { |
| throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
| } |
| } |
| |
| private: |
| TFunction _fn; |
| DataTypePtr _return_type; |
| mutable bool _first_created; |
| mutable AggregateDataPtr _exec_place; |
| std::string _local_location; |
| }; |
| |
| } // namespace doris::vectorized |
| |
| #include "common/compile_check_end.h" |