blob: c41b14b396ec41814811ce78ed4a45f8c3de22f1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "SourceFromJavaIter.h"
#include <cstdint>
#include <optional>
#include <utility>
#include <Interpreters/castColumn.h>
#include <Processors/Transforms/AggregatingTransform.h>
#include <jni/jni_common.h>
#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
#include <Common/JNIUtils.h>
/// Error codes used in this file. They are only declared here (extern); the definitions live elsewhere
/// (presumably ClickHouse's central ErrorCodes table).
namespace DB::ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace local_engine
{
/// Cached JNI handles for the Java-side serialized record batch iterator, used by peekBlock() and generate().
/// They start out null here; NOTE(review): they are expected to be populated elsewhere (not visible in this file)
/// before any SourceFromJavaIter is used.
jclass SourceFromJavaIter::serialized_record_batch_iterator_class{nullptr};
jmethodID SourceFromJavaIter::serialized_record_batch_iterator_hasNext{nullptr};
jmethodID SourceFromJavaIter::serialized_record_batch_iterator_next{nullptr};
/// Computes the header this source actually exposes on its output port.
///
/// - An empty `header` means only row counts are needed, so a row-count-only header is returned.
/// - Without a peeked first block, `header` is returned unchanged.
/// - Otherwise names come from `header`; types come from `header` too, except AggregateFunction-typed columns,
///   whose exact type is taken from `first_block` (see the comment inside the loop).
///
/// @throws DB::Exception (LOGICAL_ERROR) if `header` and `first_block` have a different number of columns.
static DB::Block getRealHeader(const DB::Block & header, const std::optional<DB::Block> & first_block)
{
    if (header.empty())
        return BlockUtil::buildRowCountHeader();
    if (!first_block.has_value())
        return header;

    const DB::Block & first = *first_block;
    if (header.columns() != first.columns())
        throw DB::Exception(
            DB::ErrorCodes::LOGICAL_ERROR,
            "Header and first block have different number of columns, header:{} first_block:{}",
            header.dumpStructure(),
            first.dumpStructure());

    DB::Block result;
    const size_t column_size = header.columns();
    for (size_t i = 0; i < column_size; ++i)
    {
        const auto & header_column = header.getByPosition(i);
        const auto & input_column = first.getByPosition(i);
        chassert(header_column.name == input_column.name);

        /// Some AggregateFunctions may have parameters, so we need to use the exact type from the first block.
        /// e.g. spark approx_percentile -> CH quantilesGK(accuracy, level1, level2, ...), the intermediate result type
        /// parsed from substrait plan is always AggregateFunction(10000, 1)(quantilesGK, arg_type), which may be different
        /// from the actual intermediate result type from input block. So we need to use the exact type from the input block.
        const bool is_aggregate_function = DB::WhichDataType(input_column.type).isAggregateFunction();
        const auto & type = is_aggregate_function ? input_column.type : header_column.type;
        result.insert(DB::ColumnWithTypeAndName(type, header_column.name));
    }
    return result;
}
/// Pulls one block from the Java iterator ahead of time, so its exact column types can be inspected
/// (see getRealHeader) before the source starts generating.
///
/// @param env        JNI environment of the calling thread.
/// @param java_iter  Java iterator whose `next` returns a byte array holding the address of a DB::Block.
/// @return a copy of the next block (columns AND block info), or std::nullopt if the iterator is exhausted or the
///         next block has no columns. In the latter case the column-less block has still been consumed.
std::optional<DB::Block> SourceFromJavaIter::peekBlock(JNIEnv * env, jobject java_iter)
{
    const jboolean has_next = safeCallBooleanMethod(env, java_iter, serialized_record_batch_iterator_hasNext);
    if (!has_next)
        return std::nullopt;

    jbyteArray block_addr = static_cast<jbyteArray>(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next));
    const auto * block = reinterpret_cast<const DB::Block *>(byteArrayToLong(env, block_addr));
    /// Only the decoded address is needed; release the local reference instead of leaving it to the JNI frame.
    env->DeleteLocalRef(block_addr);

    if (!block->columns())
        return std::nullopt;

    /// Copy the block info together with the columns: generate() forwards info.is_overflows / info.bucket_num into
    /// DB::AggregatedChunkInfo, and copying only the columns would silently reset them for the first block.
    std::optional<DB::Block> peeked(std::in_place, block->getColumnsWithTypeAndName());
    peeked->info = block->info;
    return peeked;
}
/// Builds a source that reads blocks from a Java iterator.
///
/// @param context_           query context.
/// @param header             header expected by the plan; an empty header means only row counts are produced.
/// @param java_iter_         Java iterator handle; the destructor releases it with DeleteGlobalRef, so it is
///                           expected to be a JNI global reference.
/// @param materialize_input_ whether generate() calls materializeBlockInplace() on each block after casting.
/// @param first_block_       optional block already pulled from the iterator (see peekBlock()); it refines the output
///                           header and is emitted as the first chunk.
SourceFromJavaIter::SourceFromJavaIter(
    DB::ContextPtr context_, const DB::Block & header, jobject java_iter_, bool materialize_input_, std::optional<DB::Block> && first_block_)
    : DB::ISource(toShared(getRealHeader(header, first_block_)))
    , context(std::move(context_))
    , original_header(header)
    , java_iter(java_iter_)
    , materialize_input(materialize_input_)
    /// Base classes are initialized before members, so getRealHeader() above has already read first_block_;
    /// it can be moved from here instead of copied.
    , first_block(std::move(first_block_))
{
}
/// Produces the next chunk of this source.
///
/// The data comes either from `first_block` (a block pulled earlier via peekBlock() and handed to the constructor)
/// or from the Java iterator, whose `next` result is a byte array holding the address of a DB::Block.
/// An empty chunk is returned when the source is cancelled or the iterator has no more data
/// (presumably treated as end of stream by DB::ISource — confirm there).
///
/// If `original_header` is non-empty, each input column is cast in place to the type of the output port header,
/// optionally materialized, and the block's is_overflows / bucket_num are forwarded as DB::AggregatedChunkInfo.
/// If it is empty, only the row count is emitted, together with a default-constructed DB::AggregatedChunkInfo.
///
/// @throws DB::Exception (LOGICAL_ERROR) if the input block and the output header have a different number of columns.
DB::Chunk SourceFromJavaIter::generate()
{
    if (isCancelled())
        return {};

    GET_JNIENV(env)
    SCOPE_EXIT({CLEAN_JNIENV});

    DB::Block * input_block = nullptr;
    if (first_block.has_value()) [[unlikely]]
    {
        input_block = &first_block.value();
    }
    else if (safeCallBooleanMethod(env, java_iter, serialized_record_batch_iterator_hasNext))
    {
        jbyteArray block = static_cast<jbyteArray>(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next));
        input_block = reinterpret_cast<DB::Block *>(byteArrayToLong(env, block));
        /// Only the decoded address is needed. Local references are not freed until control returns to Java or the
        /// thread detaches, so release this one now rather than accumulating one per generate() call.
        env->DeleteLocalRef(block);
    }
    else
        return {};

    DB::Chunk result;
    if (!original_header.empty())
    {
        const auto & header = getPort().getHeader();
        /// Checked in all builds (not only via chassert): on a mismatch the loop below would read columns that do not
        /// exist, or extra input columns would be forwarded beyond the header by setColumns().
        if (header.columns() != input_block->columns())
            throw DB::Exception(
                DB::ErrorCodes::LOGICAL_ERROR,
                "Input block and header have different number of columns, header:{} input_block:{}",
                header.dumpStructure(),
                input_block->dumpStructure());

        /// Cast all input columns in data to expected data types in header
        for (size_t i = 0; i < header.columns(); ++i)
        {
            auto & input_column = input_block->getByPosition(i);
            const auto & expected_type = header.getByPosition(i).type;
            auto column = DB::castColumn(input_column, expected_type);
            input_column.column = column;
            input_column.type = expected_type;
        }

        /// Materializing after casting is faster than materializing before casting
        if (materialize_input)
            materializeBlockInplace(*input_block);

        auto info = std::make_shared<DB::AggregatedChunkInfo>();
        info->is_overflows = input_block->info.is_overflows;
        info->bucket_num = input_block->info.bucket_num;
        result.getChunkInfos().add(std::move(info));
        result.setColumns(input_block->getColumns(), input_block->rows());
    }
    else
    {
        result = BlockUtil::buildRowCountChunk(input_block->rows());
        auto info = std::make_shared<DB::AggregatedChunkInfo>();
        result.getChunkInfos().add(std::move(info));
    }

    /// input_block may point into first_block, so it is only reset after the columns have been handed to `result`.
    first_block = std::nullopt;
    return result;
}
/// Destructor: releases the JNI reference to the Java iterator held by this source.
/// DeleteGlobalRef implies `java_iter` is a JNI global reference.
/// NOTE(review): confirm that whoever creates this source passes a global ref and does not delete it itself.
SourceFromJavaIter::~SourceFromJavaIter()
{
/// GET_JNIENV / CLEAN_JNIENV are project macros defined outside this file. Presumably they obtain a JNIEnv for the
/// current thread (attaching it if needed) and undo any such attachment; they are used as a pair, in this order.
GET_JNIENV(env)
env->DeleteGlobalRef(java_iter);
CLEAN_JNIENV
}
/// Decodes an 8-byte Java byte array into an Int64. In this file the value is the address of a DB::Block.
///
/// Bytes are read most-significant first (big-endian). Decoding with shifts keeps the result independent of the
/// host byte order and avoids type-punning a char buffer through reinterpret_cast.
/// NOTE(review): assumes the Java side writes the value big-endian — confirm against the Java encoder.
///
/// @throws DB::Exception (LOGICAL_ERROR) if the array is not exactly sizeof(Int64) bytes long. This is checked in all
///         builds because reading 8 bytes from a shorter array would overrun the buffer.
Int64 SourceFromJavaIter::byteArrayToLong(JNIEnv * env, jbyteArray arr)
{
    const jsize len = env->GetArrayLength(arr);
    if (len != static_cast<jsize>(sizeof(Int64)))
        throw DB::Exception(
            DB::ErrorCodes::LOGICAL_ERROR, "Expected a {}-byte array to decode an Int64, got {} bytes", sizeof(Int64), len);

    jbyte bytes[sizeof(Int64)];
    env->GetByteArrayRegion(arr, 0, len, bytes);

    uint64_t value = 0;
    for (const jbyte byte : bytes)
        value = (value << 8) | static_cast<uint8_t>(byte);
    return static_cast<Int64>(value);
}
}