// blob: 7c42cf5b499a2a4eadfc1740cd288acb744c1c2d
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "operators/hashjoin/HashTableBuilder.h"
#include "velox/exec/OperatorUtils.h"
namespace gluten {
namespace {
// Returns the row type stored in the join hash table for 'inputType'.
//
// The layout is: one column per join key, in join-key order (a column appears
// once per key that references it), followed by every input column that is not
// referenced by any key, in input order.
facebook::velox::RowTypePtr hashJoinTableType(
    const std::vector<facebook::velox::core::FieldAccessTypedExprPtr>& joinKeys,
    const facebook::velox::RowTypePtr& inputType) {
  const auto columnCount = inputType->size();

  std::vector<std::string> columnNames;
  columnNames.reserve(columnCount);
  std::vector<facebook::velox::TypePtr> columnTypes;
  columnTypes.reserve(columnCount);
  std::unordered_set<uint32_t> keyColumns;
  keyColumns.reserve(columnCount);

  // Copies the name and type of input column 'column' into the output layout.
  auto appendColumn = [&](uint32_t column) {
    columnNames.emplace_back(inputType->nameOf(column));
    columnTypes.emplace_back(inputType->childAt(column));
  };

  // Key columns first.
  for (const auto& key : joinKeys) {
    const auto column = facebook::velox::exec::exprToChannel(key.get(), inputType);
    keyColumns.insert(column);
    appendColumn(column);
  }

  // Then the dependent (non-key) columns.
  for (uint32_t column = 0; column < columnCount; ++column) {
    if (keyColumns.count(column) == 0) {
      appendColumn(column);
    }
  }

  return ROW(std::move(columnNames), std::move(columnTypes));
}
// True for a null-aware anti, left-semi-project or left-semi-filter join that
// also carries an extra filter. Such joins must keep build rows with null keys.
bool isLeftNullAwareJoinWithFilter(facebook::velox::core::JoinType joinType, bool nullAware, bool withFilter) {
  if (!nullAware || !withFilter) {
    return false;
  }
  return isAntiJoin(joinType) || isLeftSemiProjectJoin(joinType) || isLeftSemiFilterJoin(joinType);
}
} // namespace
// Prepares a builder for the build side of a hash join and creates the
// (still empty) hash table via setupTable().
//
// @param joinType                Join type; decides null-key handling and
//                                duplicate handling in setupTable()/addInput().
// @param nullAware               Whether the join has null-aware semantics.
// @param withFilter              Whether the join carries an extra filter.
// @param bloomFilterPushdownSize Forwarded to the hash table created for join
//                                types that ignore null keys.
// @param joinKeys                Build-side key expressions; each is resolved to
//                                an input column with exprToChannel.
// @param inputType               Row type of the build-side input batches.
// @param pool                    Memory pool handed to the hash table.
HashTableBuilder::HashTableBuilder(
    facebook::velox::core::JoinType joinType,
    bool nullAware,
    bool withFilter,
    int64_t bloomFilterPushdownSize,
    const std::vector<facebook::velox::core::FieldAccessTypedExprPtr>& joinKeys,
    const facebook::velox::RowTypePtr& inputType,
    facebook::velox::memory::MemoryPool* pool)
    : joinType_{joinType},
      nullAware_{nullAware},
      withFilter_(withFilter),
      keyChannelMap_(joinKeys.size()),
      inputType_(inputType),
      bloomFilterPushdownSize_(bloomFilterPushdownSize),
      pool_(pool) {
  const auto numKeys = joinKeys.size();
  keyChannels_.reserve(numKeys);
  for (size_t i = 0; i < numKeys; ++i) {
    const auto channel = facebook::velox::exec::exprToChannel(joinKeys[i].get(), inputType);
    // If several keys resolve to the same column, the last key index wins.
    keyChannelMap_[channel] = i;
    keyChannels_.emplace_back(channel);
  }

  // Identify the non-key build side columns and make a decoder for each.
  //
  // The number of join keys (numKeys) may be greater than the number of input
  // columns (inputType->size()). This happens when different probe side keys
  // are joined with the same build side key:
  //   SELECT * FROM t LEFT JOIN u ON t.k1 = u.k AND t.k2 = u.k
  // In that case numDependents is negative and must not be passed to
  // 'reserve'. The subtraction is done in signed arithmetic so the negative
  // value is represented directly rather than through unsigned wrap-around.
  const int64_t numDependents = static_cast<int64_t>(inputType->size()) - static_cast<int64_t>(numKeys);
  if (numDependents > 0) {
    dependentChannels_.reserve(numDependents);
    decoders_.reserve(numDependents);
  }
  for (uint32_t i = 0; i < inputType->size(); ++i) {
    if (keyChannelMap_.find(i) == keyChannelMap_.end()) {
      dependentChannels_.emplace_back(i);
      decoders_.emplace_back(std::make_unique<facebook::velox::DecodedVector>());
    }
  }

  tableType_ = hashJoinTableType(joinKeys, inputType);
  setupTable();
}
// Creates 'uniqueTable_' for the build side. It then records in 'analyzeKeys_'
// whether the table's hash mode is something other than kHash.
//
// The first keyChannels_.size() children of 'tableType_' are the key columns;
// the remaining children are the dependent columns.
//
// Null-key handling by join type:
//  - right, full and right-semi-project joins do not ignore null keys
//    (HashTable<false>), always allow duplicates and keep a probed flag;
//  - null-aware anti / left-semi joins with a filter also do not ignore null
//    keys, so null-key build rows can be checked;
//  - all other joins ignore null keys (HashTable<true>).
// For the last two cases, duplicates are dropped ('dropDuplicates_') for
// left-semi (filter or project) and anti joins without an extra filter. The
// probed flag is kept only for right-semi-filter joins.
void HashTableBuilder::setupTable() {
  VELOX_CHECK_NULL(uniqueTable_);

  // Stand-in for
  // operatorCtx_->driverCtx()->queryConfig().minTableRowsForParallelJoinBuild();
  // this builder has no operator context to read the query config from.
  constexpr uint32_t kMinTableRowsForParallelJoinBuild = 1'000;

  const auto keyCount = keyChannels_.size();
  std::vector<std::unique_ptr<facebook::velox::exec::VectorHasher>> keyHashers;
  keyHashers.reserve(keyCount);
  for (size_t key = 0; key < keyCount; ++key) {
    keyHashers.emplace_back(
        facebook::velox::exec::VectorHasher::create(tableType_->childAt(key), keyChannels_[key]));
  }

  std::vector<facebook::velox::TypePtr> dependentTypes;
  dependentTypes.reserve(tableType_->size() - keyCount);
  for (auto column = keyCount; column < tableType_->size(); ++column) {
    dependentTypes.emplace_back(tableType_->childAt(column));
  }

  const bool keepNullKeys = isRightJoin(joinType_) || isFullJoin(joinType_) || isRightSemiProjectJoin(joinType_);
  if (keepNullKeys) {
    // NOTE(review): the last argument is `true` here but
    // 'bloomFilterPushdownSize_' in the null-ignoring branch below; confirm
    // against createForJoin's signature in the Velox version in use.
    uniqueTable_ = facebook::velox::exec::HashTable<false>::createForJoin(
        std::move(keyHashers),
        dependentTypes,
        /*allowDuplicates=*/true,
        /*hasProbedFlag=*/true,
        kMinTableRowsForParallelJoinBuild,
        pool_,
        true);
  } else {
    // A semi/anti join without an extra filter only needs to know whether a
    // key matches, so entries with duplicate keys need not be stored.
    dropDuplicates_ =
        !withFilter_ && (isLeftSemiFilterJoin(joinType_) || isLeftSemiProjectJoin(joinType_) || isAntiJoin(joinType_));
    const bool allowDuplicates = !dropDuplicates_;
    // Right semi join tags the build rows that were probed.
    const bool hasProbedFlag = isRightSemiFilterJoin(joinType_);

    if (isLeftNullAwareJoinWithFilter(joinType_, nullAware_, withFilter_)) {
      uniqueTable_ = facebook::velox::exec::HashTable<false>::createForJoin(
          std::move(keyHashers),
          dependentTypes,
          allowDuplicates,
          hasProbedFlag,
          kMinTableRowsForParallelJoinBuild,
          pool_,
          true);
    } else {
      uniqueTable_ = facebook::velox::exec::HashTable<true>::createForJoin(
          std::move(keyHashers),
          dependentTypes,
          allowDuplicates,
          hasProbedFlag,
          kMinTableRowsForParallelJoinBuild,
          pool_,
          bloomFilterPushdownSize_);
    }
  }

  analyzeKeys_ = uniqueTable_->hashMode() != facebook::velox::exec::BaseHashTable::HashMode::kHash;
}
// Appends one batch of build-side rows to the hash table's row container.
//
// Key columns are decoded with the table's VectorHashers, and the remaining
// (dependent) columns with 'decoders_'. Rows with a null in any key column are
// skipped, except for join types that must keep them: right, full and
// right-semi-project joins, and null-aware anti/left-semi joins with a filter.
// For null-aware joins, 'joinHasNullKeys_' is set the first time a null key is
// seen. While 'analyzeKeys_' is true, the keys are also passed through
// VectorHasher::computeValueIds, and analysis stops once a hasher reports that
// value ids can no longer be used.
void HashTableBuilder::addInput(facebook::velox::RowVectorPtr input) {
  activeRows_.resize(input->size());
  activeRows_.setAll();

  auto& hashers = uniqueTable_->hashers();
  for (size_t i = 0; i < hashers.size(); ++i) {
    auto key = input->childAt(hashers[i]->channel())->loadedVector();
    hashers[i]->decode(*key, activeRows_);
  }

  // Filter out null-key rows only for join types that ignore null keys; the
  // other join types keep every row, so 'activeRows_' stays all-selected.
  const bool ignoreNullKeys = !isRightJoin(joinType_) && !isFullJoin(joinType_) &&
      !isRightSemiProjectJoin(joinType_) && !isLeftNullAwareJoinWithFilter(joinType_, nullAware_, withFilter_);
  if (ignoreNullKeys) {
    deselectRowsWithNulls(hashers, activeRows_);
    // Every row deselected above had a null key.
    if (nullAware_ && !joinHasNullKeys_ && activeRows_.countSelected() < input->size()) {
      joinHasNullKeys_ = true;
    }
  } else if (nullAware_ && !joinHasNullKeys_) {
    // Null-key rows are kept here, so look for nulls in the decoded keys.
    for (auto& hasher : hashers) {
      auto& decoded = hasher->decodedVector();
      if (decoded.mayHaveNulls()) {
        auto* nulls = decoded.nulls(&activeRows_);
        if (nulls && facebook::velox::bits::countNulls(nulls, 0, activeRows_.end()) > 0) {
          joinHasNullKeys_ = true;
          break;
        }
      }
    }
  }

  for (size_t i = 0; i < dependentChannels_.size(); ++i) {
    decoders_[i]->decode(*input->childAt(dependentChannels_[i])->loadedVector(), activeRows_);
  }

  if (!activeRows_.hasSelections()) {
    return;
  }

  if (analyzeKeys_ && hashes_.size() < activeRows_.end()) {
    hashes_.resize(activeRows_.end());
  }

  // As long as analyzeKeys is true, we keep running the keys through
  // the VectorHashers so that we get a possible mapping of the keys
  // to small ints for array or normalized key. When mayUseValueIds is
  // false for the first time we stop. We do not retain the value ids
  // since the final ones will only be known after all data is
  // received.
  for (auto& hasher : hashers) {
    // TODO: Load only for active rows, except if right/full outer join.
    if (analyzeKeys_) {
      hasher->computeValueIds(activeRows_, hashes_);
      analyzeKeys_ = hasher->mayUseValueIds();
    }
  }

  auto rows = uniqueTable_->rows();
  auto nextOffset = rows->nextOffset();
  activeRows_.applyToSelected([&](auto rowIndex) {
    char* newRow = rows->newRow();
    if (nextOffset) {
      // Clear the row's next-row pointer slot.
      *reinterpret_cast<char**>(newRow + nextOffset) = nullptr;
    }
    // Store the columns for each row in sequence. At probe time
    // strings of the row will probably be in consecutive places, so
    // reading one will prime the cache for the next.
    for (size_t i = 0; i < hashers.size(); ++i) {
      rows->store(hashers[i]->decodedVector(), rowIndex, newRow, i);
    }
    for (size_t i = 0; i < dependentChannels_.size(); ++i) {
      rows->store(*decoders_[i], rowIndex, newRow, i + hashers.size());
    }
  });
}
} // namespace gluten