| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #include "SourceFromJavaIter.h" |
| #include <Interpreters/castColumn.h> |
| #include <Processors/Transforms/AggregatingTransform.h> |
| #include <jni/jni_common.h> |
| #include <Common/BlockTypeUtils.h> |
| #include <Common/CHUtil.h> |
| #include <Common/Exception.h> |
| #include <Common/JNIUtils.h> |
| |
| namespace DB |
| { |
| namespace ErrorCodes |
| { |
| extern const int LOGICAL_ERROR; |
| } |
| } |
| |
| namespace local_engine |
| { |
| jclass SourceFromJavaIter::serialized_record_batch_iterator_class = nullptr; |
| jmethodID SourceFromJavaIter::serialized_record_batch_iterator_hasNext = nullptr; |
| jmethodID SourceFromJavaIter::serialized_record_batch_iterator_next = nullptr; |
| |
| static DB::Block getRealHeader(const DB::Block & header, const std::optional<DB::Block> & first_block) |
| { |
| if (header.empty()) |
| return BlockUtil::buildRowCountHeader(); |
| if (!first_block.has_value()) |
| return header; |
| if (header.columns() != first_block.value().columns()) |
| throw DB::Exception( |
| DB::ErrorCodes::LOGICAL_ERROR, |
| "Header first block have different number of columns, header:{} first_block:{}", |
| header.dumpStructure(), |
| first_block.value().dumpStructure()); |
| |
| DB::Block result; |
| const size_t column_size = header.columns(); |
| for (size_t i = 0; i < column_size; ++i) |
| { |
| const auto & header_column = header.getByPosition(i); |
| const auto & input_column = first_block.value().getByPosition(i); |
| chassert(header_column.name == input_column.name); |
| |
| DB::WhichDataType input_which(input_column.type); |
| /// Some AggregateFunctions may have parameters, so we need to use the exact type from the first block. |
| /// e.g. spark approx_percentile -> CH quantilesGK(accuracy, level1, level2, ...), the intermediate result type |
| /// parsed from substrait plan is always AggregateFunction(10000, 1)(quantilesGK, arg_type), which maybe different |
| /// from the actual intermediate result type from input block. So we need to use the exact type from the input block. |
| auto type = input_which.isAggregateFunction() ? input_column.type : header_column.type; |
| result.insert(DB::ColumnWithTypeAndName(type, header_column.name)); |
| } |
| return result; |
| } |
| |
| |
| std::optional<DB::Block> SourceFromJavaIter::peekBlock(JNIEnv * env, jobject java_iter) |
| { |
| jboolean has_next = safeCallBooleanMethod(env, java_iter, serialized_record_batch_iterator_hasNext); |
| if (!has_next) |
| return std::nullopt; |
| |
| jbyteArray block_addr = static_cast<jbyteArray>(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next)); |
| auto * block = reinterpret_cast<DB::Block *>(byteArrayToLong(env, block_addr)); |
| if (block->columns()) |
| return std::optional(DB::Block(block->getColumnsWithTypeAndName())); |
| else |
| return std::nullopt; |
| } |
| |
| |
| SourceFromJavaIter::SourceFromJavaIter( |
| DB::ContextPtr context_, const DB::Block& header, jobject java_iter_, bool materialize_input_, std::optional<DB::Block> && first_block_) |
| : DB::ISource(toShared(getRealHeader(header, first_block_))) |
| , context(context_) |
| , original_header(header) |
| , java_iter(java_iter_) |
| , materialize_input(materialize_input_) |
| , first_block(first_block_) |
| { |
| } |
| |
| DB::Chunk SourceFromJavaIter::generate() |
| { |
| if (isCancelled()) |
| return {}; |
| |
| GET_JNIENV(env) |
| SCOPE_EXIT({CLEAN_JNIENV}); |
| |
| DB::Block * input_block = nullptr; |
| if (first_block.has_value()) [[unlikely]] |
| { |
| input_block = &first_block.value(); |
| } |
| else if (jboolean has_next = safeCallBooleanMethod(env, java_iter, serialized_record_batch_iterator_hasNext)) |
| { |
| jbyteArray block = static_cast<jbyteArray>(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next)); |
| input_block = reinterpret_cast<DB::Block *>(byteArrayToLong(env, block)); |
| } |
| else |
| return {}; |
| |
| DB::Chunk result; |
| if (!original_header.empty()) |
| { |
| const auto & header = getPort().getHeader(); |
| chassert(header.columns() == input_block->columns()); |
| /// Cast all input columns in data to expected data types in header |
| for (size_t i = 0; i < header.columns(); ++i) |
| { |
| auto & input_column = input_block->getByPosition(i); |
| const auto & expected_type = header.getByPosition(i).type; |
| auto column = DB::castColumn(input_column, expected_type); |
| input_column.column = column; |
| input_column.type = expected_type; |
| } |
| |
| /// Do materializing after casting is faster than materializing before casting |
| if (materialize_input) |
| materializeBlockInplace(*input_block); |
| |
| auto info = std::make_shared<DB::AggregatedChunkInfo>(); |
| info->is_overflows = input_block->info.is_overflows; |
| info->bucket_num = input_block->info.bucket_num; |
| result.getChunkInfos().add(std::move(info)); |
| result.setColumns(input_block->getColumns(), input_block->rows()); |
| } |
| else |
| { |
| result = BlockUtil::buildRowCountChunk(input_block->rows()); |
| auto info = std::make_shared<DB::AggregatedChunkInfo>(); |
| result.getChunkInfos().add(std::move(info)); |
| } |
| first_block = std::nullopt; |
| return result; |
| } |
| |
| SourceFromJavaIter::~SourceFromJavaIter() |
| { |
| GET_JNIENV(env) |
| env->DeleteGlobalRef(java_iter); |
| CLEAN_JNIENV |
| } |
| |
| Int64 SourceFromJavaIter::byteArrayToLong(JNIEnv * env, jbyteArray arr) |
| { |
| jsize len = env->GetArrayLength(arr); |
| assert(len == sizeof(Int64)); |
| char * c_arr = new char[len]; |
| env->GetByteArrayRegion(arr, 0, len, reinterpret_cast<jbyte *>(c_arr)); |
| std::reverse(c_arr, c_arr + 8); |
| Int64 result = reinterpret_cast<Int64 *>(c_arr)[0]; |
| delete[] c_arr; |
| return result; |
| } |
| |
| } |