blob: 8af60a5534e7fcc9209ebf46435970f9ec50b863 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <arrow/c/abi.h>
#include <jni/JniCommon.h>
#include "JniHashTable.h"
#include "folly/String.h"
#include "memory/ColumnarBatch.h"
#include "memory/VeloxColumnarBatch.h"
#include "substrait/algebra.pb.h"
#include "substrait/type.pb.h"
#include "velox/core/PlanNode.h"
#include "velox/type/Type.h"
namespace gluten {
// Caches everything callJavaGet() needs later on: the JavaVM pointer, a global
// reference to the Java class VeloxBroadcastBuildSideCache, and the id of its
// static method `get` with signature (Ljava/lang/String;)J.
void JniHashTableContext::initialize(JNIEnv* env, JavaVM* javaVm) {
  vm_ = javaVm;

  // The class reference is released again in finalize().
  jniVeloxBroadcastBuildSideCache_ =
      createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;");

  jniGet_ = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache_, "get", "(Ljava/lang/String;)J");
}
// Releases the global class reference acquired in initialize(). Calling this
// when no reference is held is a no-op.
void JniHashTableContext::finalize(JNIEnv* env) {
  if (jniVeloxBroadcastBuildSideCache_ == nullptr) {
    return;
  }

  env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache_);
  // Reset the handle so that a repeated finalize() does not delete it twice.
  jniVeloxBroadcastBuildSideCache_ = nullptr;
}
// Invokes the static Java method VeloxBroadcastBuildSideCache.get(String) with
// `id` and returns the jlong it produces.
//
// The calling thread must already be attached to the JVM; this function does
// not attach it.
//
// Throws gluten::GlutenException when:
//   - the context has not been initialized (no JavaVM recorded),
//   - no JNIEnv is available for the current thread,
//   - the Java string for `id` cannot be created,
//   - the Java method raised an exception.
jlong JniHashTableContext::callJavaGet(const std::string& id) const {
  if (vm_ == nullptr) {
    throw gluten::GlutenException("JniHashTableContext was not initialized");
  }

  JNIEnv* env = nullptr;
  if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
    throw gluten::GlutenException("JNIEnv was not attached to current thread");
  }

  const jstring jId = env->NewStringUTF(id.c_str());
  if (jId == nullptr) {
    // NewStringUTF signals failure by returning null with a Java exception
    // pending; clear it so the C++ exception is the only error in flight.
    env->ExceptionClear();
    throw gluten::GlutenException("Failed to create Java string for hash table id: " + id);
  }

  const jlong result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache_, jniGet_, jId);

  // Release the local reference eagerly: when this runs on a long-lived
  // attached native thread (or in a loop) local references are otherwise not
  // reclaimed. DeleteLocalRef is safe to call with an exception pending.
  env->DeleteLocalRef(jId);

  if (env->ExceptionCheck()) {
    // The value returned by a throwing Java call is undefined, so surface the
    // failure instead of handing `result` back to the caller.
    env->ExceptionDescribe();
    env->ExceptionClear();
    throw gluten::GlutenException("Java exception thrown by VeloxBroadcastBuildSideCache.get for id: " + id);
  }

  return result;
}
// Creates a HashTableBuilder for a Velox hash join and feeds it every batch in
// `batches`. The builder (not a finished hash table) is returned.
//
// joinKeys                 Names of the join key columns; each is looked up in
//                          the row type built from `names`/`veloxTypeList`.
// names, veloxTypeList     Column names and types of the input rows; consumed
//                          to construct the RowType.
// joinType                 Numeric value of substrait::JoinRel_JoinType.
// hasMixedJoinCondition    Forwarded to HashTableBuilder.
// isExistenceJoin          For LEFT_SEMI / RIGHT_SEMI, selects the
//                          "SemiProject" Velox join type instead of
//                          "SemiFilter". Ignored for other join types.
// isNullAwareAntiJoin      Forwarded to HashTableBuilder.
// bloomFilterPushdownSize  Forwarded to HashTableBuilder.
// batches                  Input batches; each is converted with
//                          VeloxColumnarBatch::from and added to the builder.
// memoryPool               Pool used for batch conversion and handed to the
//                          builder as a raw pointer.
//                          NOTE(review): presumably the caller must keep the
//                          pool alive as long as the builder - confirm.
//
// An unsupported `joinType` is rejected through VELOX_NYI.
std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
    const std::vector<std::string>& joinKeys,
    std::vector<std::string> names,
    std::vector<facebook::velox::TypePtr> veloxTypeList,
    int joinType,
    bool hasMixedJoinCondition,
    bool isExistenceJoin,
    bool isNullAwareAntiJoin,
    int64_t bloomFilterPushdownSize,
    std::vector<std::shared_ptr<ColumnarBatch>>& batches,
    std::shared_ptr<facebook::velox::memory::MemoryPool> memoryPool) {
  auto rowType = std::make_shared<facebook::velox::RowType>(std::move(names), std::move(veloxTypeList));

  // Translate the Substrait join type into the corresponding Velox join type.
  auto sJoin = static_cast<substrait::JoinRel_JoinType>(joinType);
  facebook::velox::core::JoinType vJoin;
  switch (sJoin) {
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER:
      vJoin = facebook::velox::core::JoinType::kInner;
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_OUTER:
      vJoin = facebook::velox::core::JoinType::kFull;
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT:
      vJoin = facebook::velox::core::JoinType::kLeft;
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT:
      vJoin = facebook::velox::core::JoinType::kRight;
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
      // Determine the semi join type based on extracted information.
      if (isExistenceJoin) {
        vJoin = facebook::velox::core::JoinType::kLeftSemiProject;
      } else {
        vJoin = facebook::velox::core::JoinType::kLeftSemiFilter;
      }
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
      // Determine the semi join type based on extracted information.
      if (isExistenceJoin) {
        vJoin = facebook::velox::core::JoinType::kRightSemiProject;
      } else {
        vJoin = facebook::velox::core::JoinType::kRightSemiFilter;
      }
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_ANTI: {
      // Determine the anti join type based on extracted information.
      vJoin = facebook::velox::core::JoinType::kAnti;
      break;
    }
    default:
      VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin));
  }

  // Build one typed field-access expression per join key, resolving each
  // key's type by name from the row type.
  std::vector<std::shared_ptr<const facebook::velox::core::FieldAccessTypedExpr>> joinKeyTypes;
  joinKeyTypes.reserve(joinKeys.size());
  for (const auto& name : joinKeys) {
    joinKeyTypes.emplace_back(
        std::make_shared<facebook::velox::core::FieldAccessTypedExpr>(rowType->findChild(name), name));
  }

  auto hashTableBuilder = std::make_shared<HashTableBuilder>(
      vJoin,
      isNullAwareAntiJoin,
      hasMixedJoinCondition,
      bloomFilterPushdownSize,
      joinKeyTypes,
      rowType,
      memoryPool.get());

  // Range-for avoids comparing a signed index against the unsigned size().
  for (const auto& batch : batches) {
    auto rowVector = VeloxColumnarBatch::from(memoryPool.get(), batch)->getRowVector();
    hashTableBuilder->addInput(rowVector);
  }
  return hashTableBuilder;
}
// Convenience wrapper: asks the process-wide JniHashTableContext to call the
// Java-side cache lookup for `hashTableId` and returns the resulting value.
// NOTE(review): presumably the value is a native handle/address stored by the
// Java cache - confirm against the Java implementation.
long getJoin(const std::string& hashTableId) {
  const auto& context = JniHashTableContext::getInstance();
  return context.callJavaGet(hashTableId);
}
} // namespace gluten