blob: 9a83f08b5438c61fa38517888a489e04e2c10e60 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "BroadCastJoinBuilder.h"
#include <Compression/CompressedReadBuffer.h>
#include <Interpreters/TableJoin.h>
#include <Join/StorageJoinFromReadBuffer.h>
#include <Parser/RelParsers/JoinRelParser.h>
#include <Parser/TypeParser.h>
#include <QueryPipeline/ProfileInfo.h>
#include <Shuffle/ShuffleReader.h>
#include <jni/SharedPointerWrapper.h>
#include <jni/jni_common.h>
#include <Poco/StringTokenizer.h>
#include <Common/CHUtil.h>
#include <Common/JNIUtils.h>
#include <Common/logger_useful.h>
namespace DB
{
namespace ErrorCodes
{
/// Declarations only: both constants are `extern`, so their definitions live in another translation unit.
extern const int UNKNOWN_TYPE;
/// Thrown by this file when a broadcast table cannot be resolved; declared here
/// explicitly so the file does not depend on another header declaring it.
extern const int LOGICAL_ERROR;
}
}
namespace local_engine
{
namespace BroadCastJoinBuilder
{
using namespace DB;
/// Class handle for org.apache.gluten.execution.CHBroadcastBuildSideCache.
/// Assigned by init() through CreateGlobalClassReference and released by destroy() through DeleteGlobalRef.
static jclass Java_CHBroadcastBuildSideCache = nullptr;
/// ID of the static Java method "get" with signature (Ljava/lang/String;)J on the class above; resolved in init().
static jmethodID Java_get = nullptr;
/// Calls the static Java method CHBroadcastBuildSideCache.get(String) for the given id.
///
/// @param id  key passed to the Java-side cache.
/// @return the jlong returned by the Java call; getJoin() treats 0 as "not found".
jlong callJavaGet(const std::string & id)
{
    GET_JNIENV(env)
    const jstring java_id = charTojstring(env, id.c_str());
    const auto result = safeCallStaticLongMethod(env, Java_CHBroadcastBuildSideCache, Java_get, java_id);
    /// Release the temporary Java string as soon as the call is done. Local references are
    /// otherwise only reclaimed when the native frame returns to Java or the thread detaches,
    /// so repeated lookups from a long-lived attached thread would pile them up.
    /// NOTE(review): relies on charTojstring returning a JNI local reference (defined outside this file).
    env->DeleteLocalRef(java_id);
    CLEAN_JNIENV
    return result;
}
/// Builds a new block from `block` in which every column name carries BlockUtil::RIHGT_COLUMN_PREFIX,
/// so build-side columns cannot clash with names of the left table.
///
/// The first occurrence of a name becomes `<prefix><name>`. Any repeated name becomes
/// `<prefix><n>_<name>`, where `n` is a counter starting at 0 that advances on each duplicate.
///
/// @param block     source block; its columns and types are reused as-is.
/// @param only_one  when true, only the first column of `block` is kept.
/// @return a block holding the renamed columns.
DB::Block resetBuildTableBlockName(Block & block, bool only_one = false)
{
    DB::ColumnsWithTypeAndName renamed_columns;
    std::set<std::string> seen_names;
    int32_t duplicate_index = 0;
    for (const auto & source_column : block)
    {
        /// insert() reports whether the name was new, so no separate find() is needed.
        const bool first_occurrence = seen_names.insert(source_column.name).second;

        std::stringstream renamed;
        renamed << BlockUtil::RIHGT_COLUMN_PREFIX;
        if (!first_occurrence)
            renamed << (duplicate_index++) << "_";
        renamed << source_column.name;

        renamed_columns.emplace_back(source_column.column, source_column.type, renamed.str());
        if (only_one)
            break;
    }
    return DB::Block(renamed_columns);
}
void cleanBuildHashTable(const std::string & hash_table_id, jlong instance)
{
auto clean_join = [&]
{
SharedPointerWrapper<StorageJoinFromReadBuffer>::dispose(instance);
};
/// Record memory usage in Total Memory Tracker
ThreadFromGlobalPoolNoTracingContextPropagation thread(clean_join);
thread.join();
LOG_DEBUG(&Poco::Logger::get("BroadCastJoinBuilder"), "Broadcast hash table {} is cleaned", hash_table_id);
}
std::shared_ptr<StorageJoinFromReadBuffer> getJoin(const std::string & key)
{
const jlong result = callJavaGet(key);
if (unlikely(result == 0))
throw Exception(ErrorCodes::LOGICAL_ERROR, "broadcast table {} not found in cache.", key);
auto wrapper = SharedPointerWrapper<StorageJoinFromReadBuffer>::sharedPtr(result);
if (unlikely(!wrapper))
throw Exception(ErrorCodes::LOGICAL_ERROR, "broadcast table {} not found, cache value is invalidated.", key);
return wrapper;
}
std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
const std::string & key,
DB::ReadBuffer & input,
jlong row_count,
const std::string & join_keys,
jint join_type,
bool has_mixed_join_condition,
bool is_existence_join,
const std::string & named_struct,
bool is_null_aware_anti_join,
bool has_null_key_values)
{
auto join_key_list = Poco::StringTokenizer(join_keys, ",");
Names key_names;
for (const auto & key_name : join_key_list)
key_names.emplace_back(BlockUtil::RIHGT_COLUMN_PREFIX + key_name);
DB::JoinKind kind;
DB::JoinStrictness strictness;
if (key.starts_with("BuiltBNLJBroadcastTable-"))
std::tie(kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness(static_cast<substrait::CrossRel_JoinType>(join_type));
else
std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast<substrait::JoinRel_JoinType>(join_type), is_existence_join);
substrait::NamedStruct substrait_struct;
substrait_struct.ParseFromString(named_struct);
Block header = TypeParser::buildBlockFromNamedStruct(substrait_struct);
header = resetBuildTableBlockName(header);
Blocks data;
auto collect_data = [&]
{
bool only_one_column = header.getNamesAndTypesList().empty();
if (only_one_column)
header = BlockUtil::buildRowCountBlock(0).getColumnsWithTypeAndName();
NativeReader block_stream(input);
ProfileInfo info;
Block block = block_stream.read();
while (!block.empty())
{
DB::ColumnsWithTypeAndName columns;
for (size_t i = 0; i < block.columns(); ++i)
{
const auto & column = block.getByPosition(i);
if (only_one_column)
{
auto virtual_block = BlockUtil::buildRowCountBlock(column.column->size()).getColumnsWithTypeAndName();
header = virtual_block;
columns.emplace_back(virtual_block.back());
break;
}
columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i)));
}
DB::Block final_block(columns);
info.update(final_block);
data.emplace_back(std::move(final_block));
block = block_stream.read();
}
};
/// Record memory usage in Total Memory Tracker
ThreadFromGlobalPoolNoTracingContextPropagation thread(collect_data);
thread.join();
ColumnsDescription columns_description(header.getNamesAndTypesList());
return make_shared<StorageJoinFromReadBuffer>(
data,
row_count,
key_names,
true,
kind,
strictness,
has_mixed_join_condition,
columns_description,
ConstraintsDescription(),
key,
true,
is_null_aware_anti_join,
has_null_key_values);
}
/// Resolves and caches the JNI handles used by callJavaGet().
///
/// @param env  JNI environment used to look up the class and its static method.
void init(JNIEnv * env)
{
    /// A Scala `object` is compiled into two classes: a regular class with a '$' suffix and a
    /// utility class that only exposes static methods. The utility class is the one used here.
    Java_CHBroadcastBuildSideCache
        = CreateGlobalClassReference(env, "Lorg/apache/gluten/execution/CHBroadcastBuildSideCache;");
    Java_get = GetStaticMethodID(env, Java_CHBroadcastBuildSideCache, "get", "(Ljava/lang/String;)J");
}
/// Releases the JNI handles cached by init().
///
/// After the call both cached handles are null, so calling destroy() again does nothing
/// and no stale class reference or method ID is left behind.
///
/// @param env  JNI environment used to delete the global class reference.
void destroy(JNIEnv * env)
{
    if (Java_CHBroadcastBuildSideCache != nullptr)
        env->DeleteGlobalRef(Java_CHBroadcastBuildSideCache);
    Java_CHBroadcastBuildSideCache = nullptr;
    /// The method ID belongs to the class released above, so drop it as well.
    Java_get = nullptr;
}
}
}