blob: 251c5b1c72a8eb5bc774390fd241e5734efb252f [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/compute/kernels/hash.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "arrow/array.h"
#include "arrow/array/dict_internal.h"
#include "arrow/buffer.h"
#include "arrow/builder.h"
#include "arrow/compute/context.h"
#include "arrow/compute/kernel.h"
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/hashing.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
#include "arrow/util/string_view.h"
#include "arrow/visitor_inline.h"
namespace arrow {
class MemoryPool;
using internal::checked_cast;
using internal::DictionaryTraits;
using internal::HashTraits;
namespace compute {
namespace {
// Returns Status::NotImplemented from the enclosing function when KERNEL
// evaluates to false (e.g. a null std::unique_ptr).
//
// KERNEL   - expression tested with operator!.
// FUNCNAME - kernel name placed at the start of the error message.
// TYPE     - pointer-like handle whose ToString() result ends the message.
//
// The macro uses its TYPE argument instead of silently depending on a variable
// named `type` being in scope at the call site, and parenthesizes both
// expression arguments so that compound expressions expand safely.
#define CHECK_IMPLEMENTED(KERNEL, FUNCNAME, TYPE)                    \
  if (!(KERNEL)) {                                                   \
    return Status::NotImplemented(FUNCNAME, " not implemented for ", \
                                  (TYPE)->ToString());               \
  }
// ----------------------------------------------------------------------
// Unique implementation
// Shared state for the hash "Action" classes below: the data type passed at
// construction and the memory pool handed to the action.
class ActionBase {
 public:
  ActionBase(const std::shared_ptr<DataType>& type, MemoryPool* pool)
      : type_{type}, pool_{pool} {}

 protected:
  // Type supplied by whoever constructed the action.
  std::shared_ptr<DataType> type_;
  // Pool supplied at construction; stored as a raw, non-owning pointer.
  MemoryPool* pool_;
};
// Action for the "unique" kernel. Every hook is a no-op: the action keeps no
// state of its own and neither flush method writes to `out`.
class UniqueAction final : public ActionBase {
 public:
  using ActionBase::ActionBase;

  // Nothing to reset or pre-allocate.
  Status Reset() { return Status::OK(); }
  Status Reserve(const int64_t /*length*/) { return Status::OK(); }

  // Observation hooks: all intentionally empty.
  template <class Index>
  void ObserveNullFound(Index /*index*/) {}

  template <class Index>
  void ObserveNullNotFound(Index /*index*/) {}

  template <class Index>
  void ObserveFound(Index /*index*/) {}

  template <class Index>
  void ObserveNotFound(Index /*index*/) {}

  // Leaves `out` untouched.
  Status Flush(Datum* /*out*/) { return Status::OK(); }

  // Reports the type this action was constructed with.
  std::shared_ptr<DataType> out_type() const { return type_; }

  // Leaves `out` untouched.
  Status FlushFinal(Datum* /*out*/) { return Status::OK(); }
};
// ----------------------------------------------------------------------
// Count values implementation (see HashKernel for description of methods)
class ValueCountsAction final : ActionBase {
public:
using ActionBase::ActionBase;
ValueCountsAction(const std::shared_ptr<DataType>& type, MemoryPool* pool)
: ActionBase(type, pool), count_builder_(pool) {}
Status Reserve(const int64_t length) {
// builder size is independent of input array size.
return Status::OK();
}
Status Reset() {
count_builder_.Reset();
return Status::OK();
}
// Don't do anything on flush because we don't want to finalize the builder
// or incur the cost of memory copies.
Status Flush(Datum* out) { return Status::OK(); }
std::shared_ptr<DataType> out_type() const { return type_; }
// Return the counts corresponding the MemoTable keys.
Status FlushFinal(Datum* out) {
std::shared_ptr<ArrayData> result;
RETURN_NOT_OK(count_builder_.FinishInternal(&result));
out->value = std::move(result);
return Status::OK();
}
template <class Index>
void ObserveNullFound(Index index) {
count_builder_[index]++;
}
template <class Index>
void ObserveNullNotFound(Index index) {
ARROW_LOG(FATAL) << "ObserveNullNotFound without err_status should not be called";
}
template <class Index>
void ObserveNullNotFound(Index index, Status* status) {
Status s = count_builder_.Append(1);
if (ARROW_PREDICT_FALSE(!s.ok())) {
*status = s;
}
}
template <class Index>
void ObserveFound(Index slot) {
count_builder_[slot]++;
}
template <class Index>
void ObserveNotFound(Index slot, Status* status) {
Status s = count_builder_.Append(1);
if (ARROW_PREDICT_FALSE(!s.ok())) {
*status = s;
}
}
private:
Int64Builder count_builder_;
};
// ----------------------------------------------------------------------
// Dictionary encode implementation (see HashKernel for description of methods)
// Action for the dictionary-encode kernel. It records one int32 index per
// visited element (null for null elements) and emits the accumulated indices
// on every Flush.
class DictEncodeAction final : public ActionBase {
 public:
  using ActionBase::ActionBase;

  DictEncodeAction(const std::shared_ptr<DataType>& type, MemoryPool* pool)
      : ActionBase(type, pool), indices_builder_(pool) {}

  // Drops any indices accumulated so far.
  Status Reset() {
    indices_builder_.Reset();
    return Status::OK();
  }

  // Reserves room for `length` indices; the Observe* hooks rely on this since
  // they use the unchecked append methods.
  Status Reserve(const int64_t length) { return indices_builder_.Reserve(length); }

  // Nulls are emitted as null indices whether or not they were seen before.
  template <class Index>
  void ObserveNullFound(Index /*index*/) {
    indices_builder_.UnsafeAppendNull();
  }

  template <class Index>
  void ObserveNullNotFound(Index /*index*/) {
    indices_builder_.UnsafeAppendNull();
  }

  // Values are emitted as the index passed in, whether new or already known.
  template <class Index>
  void ObserveFound(Index index) {
    indices_builder_.UnsafeAppend(index);
  }

  template <class Index>
  void ObserveNotFound(Index index) {
    indices_builder_.UnsafeAppend(index);
  }

  // Finishes the builder and stores the indices in `out`.
  Status Flush(Datum* out) {
    std::shared_ptr<ArrayData> indices;
    RETURN_NOT_OK(indices_builder_.FinishInternal(&indices));
    out->value = std::move(indices);
    return Status::OK();
  }

  // Indices are always int32.
  std::shared_ptr<DataType> out_type() const { return int32(); }

  // Nothing is carried across calls, so the final flush writes nothing.
  Status FlushFinal(Datum* /*out*/) { return Status::OK(); }

 private:
  Int32Builder indices_builder_;
};
/// \brief Invoke a hash table kernel on an input array, returning any output
/// values. Implementations should be thread-safe.
///
/// The interface is implemented below with a visitor over "Action" classes.
/// It is kept as a separate abstract class, rather than being consolidated
/// with the implementations, to keep the contract clearer.
class HashKernel : public UnaryKernel {
 public:
  /// Reset the kernel so that it can be used for another run.
  virtual Status Reset() = 0;

  /// Prepare the Action for the given input (e.g. reserve appropriately sized
  /// data structures) and visit the input with the Action.
  virtual Status Append(FunctionContext* ctx, const ArrayData& input) = 0;

  /// Flush out the results accumulated by the last invocation of Call.
  virtual Status Flush(Datum* out) = 0;

  /// Flush out the results accumulated across all invocations of Call. The
  /// kernel should not be used again until Reset() has been called.
  virtual Status FlushFinal(Datum* out) = 0;

  /// Get the values (keys) accumulated in the dictionary so far.
  virtual Status GetDictionary(std::shared_ptr<ArrayData>* out) = 0;
};
// ----------------------------------------------------------------------
// Base class for all hash kernel implementations
// Shared plumbing for the concrete hash kernels: implements the UnaryKernel
// entry point in terms of Append + Flush and takes `lock_` around the
// subclass-provided Append(const ArrayData&).
class HashKernelImpl : public HashKernel {
 public:
  // Appends the array held by `input`, then flushes the per-call result into
  // `out`. `input` is expected to be of kind Datum::ARRAY (checked in debug
  // builds only).
  Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override {
    DCHECK_EQ(Datum::ARRAY, input.kind());
    const auto& array_data = input.array();
    RETURN_NOT_OK(Append(ctx, *array_data));
    return Flush(out);
  }

  // Holds `lock_` for the duration of the subclass Append. `ctx` is unused.
  Status Append(FunctionContext* /*ctx*/, const ArrayData& input) override {
    std::lock_guard<std::mutex> held(lock_);
    return Append(input);
  }

  // Implemented by subclasses; called with `lock_` held when reached through
  // Append(FunctionContext*, const ArrayData&).
  virtual Status Append(const ArrayData& arr) = 0;

 protected:
  std::mutex lock_;
};
// ----------------------------------------------------------------------
// Base class for all "regular" hash kernel implementations
// (NullType has a separate implementation)
// Local shorthand for `typename std::enable_if<Condition, Result>::type`:
// names `Result` (void by default) when `Condition` is true and is ill-formed
// otherwise, which removes the enclosing overload via SFINAE.
template <bool Condition, typename Result = void>
using enable_if_t = typename std::enable_if<Condition, Result>::type;
// Hash kernel for every type that is hashed through a memo table.
//
// Template parameters:
//   Type                 - Arrow type class; selects the visitor and memo table.
//   Scalar               - value type handed to VisitValue by the visitor.
//   Action               - policy object receiving the Observe* callbacks.
//   with_error_status    - when true, "not found" observations go through the
//                          Action overloads taking a Status*, and the visit
//                          methods return that status.
//   with_memo_visit_null - when true, nulls are looked up / inserted in the
//                          memo table; when false the Action is told about the
//                          null directly with index -1.
//
// Hooks required from Action:
//   values: ObserveFound(i) and ObserveNotFound(i) or ObserveNotFound(i, Status*)
//   nulls:  ObserveNullFound(i) and ObserveNullNotFound(i) or
//           ObserveNullNotFound(i, Status*)
// where the Status* forms are used iff with_error_status is true.
template <typename Type, typename Scalar, typename Action, bool with_error_status = false,
          bool with_memo_visit_null = true>
class RegularHashKernelImpl : public HashKernelImpl {
 public:
  RegularHashKernelImpl(const std::shared_ptr<DataType>& type, MemoryPool* pool)
      : pool_(pool), type_(type), action_(type, pool) {}

  // Replaces the memo table with a fresh one and resets the action. The memo
  // table does not exist until this has been called once.
  Status Reset() override {
    memo_table_.reset(new MemoTable(pool_, 0));
    return action_.Reset();
  }

  // Lets the action reserve space for `arr.length` observations, then visits
  // every element of `arr` (dispatching to VisitNull / VisitValue below).
  Status Append(const ArrayData& arr) override {
    RETURN_NOT_OK(action_.Reserve(arr.length));
    return ArrayDataVisitor<Type>::Visit(arr, this);
  }

  Status Flush(Datum* out) override { return action_.Flush(out); }

  Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); }

  // Materializes the memo table contents, starting from its first entry.
  Status GetDictionary(std::shared_ptr<ArrayData>* out) override {
    return DictionaryTraits<Type>::GetDictionaryArrayData(pool_, type_, *memo_table_,
                                                          0 /* start_offset */, out);
  }

  // Null element, Action hooks cannot fail.
  template <typename Enable = Status>
  auto VisitNull() -> enable_if_t<!with_error_status, Enable> {
    auto on_found = [this](int32_t memo_index) { action_.ObserveNullFound(memo_index); };
    auto on_not_found = [this](int32_t memo_index) {
      action_.ObserveNullNotFound(memo_index);
    };

    if (with_memo_visit_null) {
      memo_table_->GetOrInsertNull(on_found, on_not_found);
    } else {
      // Null is not tracked in the memo table: report it with index -1.
      action_.ObserveNullNotFound(-1);
    }
    return Status::OK();
  }

  // Null element, Action hooks may fail.
  //
  // Nulls are reported through the dedicated null hooks and the failure
  // status is threaded through every "not found" path. Previously this
  // overload used the value hooks (ObserveFound / ObserveNotFound) and its
  // non-memo branch called the status-less ObserveNullNotFound, which both
  // dropped the error and, for ValueCountsAction, ended in a FATAL log.
  template <typename Enable = Status>
  auto VisitNull() -> enable_if_t<with_error_status, Enable> {
    Status s = Status::OK();
    auto on_found = [this](int32_t memo_index) { action_.ObserveNullFound(memo_index); };
    auto on_not_found = [this, &s](int32_t memo_index) {
      action_.ObserveNullNotFound(memo_index, &s);
    };

    if (with_memo_visit_null) {
      memo_table_->GetOrInsertNull(on_found, on_not_found);
    } else {
      // Null is not tracked in the memo table: report it with index -1.
      action_.ObserveNullNotFound(-1, &s);
    }
    return s;
  }

  // Non-null element, Action hooks cannot fail.
  template <typename Enable = Status>
  auto VisitValue(const Scalar& value) -> enable_if_t<!with_error_status, Enable> {
    auto on_found = [this](int32_t memo_index) { action_.ObserveFound(memo_index); };
    auto on_not_found = [this](int32_t memo_index) {
      action_.ObserveNotFound(memo_index);
    };

    memo_table_->GetOrInsert(value, on_found, on_not_found);
    return Status::OK();
  }

  // Non-null element, Action hooks may fail; the status set by the
  // "not found" hook (if any) is returned.
  template <typename Enable = Status>
  auto VisitValue(const Scalar& value) -> enable_if_t<with_error_status, Enable> {
    Status s = Status::OK();
    auto on_found = [this](int32_t memo_index) { action_.ObserveFound(memo_index); };
    auto on_not_found = [this, &s](int32_t memo_index) {
      action_.ObserveNotFound(memo_index, &s);
    };

    memo_table_->GetOrInsert(value, on_found, on_not_found);
    return s;
  }

  std::shared_ptr<DataType> out_type() const override { return action_.out_type(); }

 protected:
  using MemoTable = typename HashTraits<Type>::MemoTableType;

  MemoryPool* pool_;
  std::shared_ptr<DataType> type_;
  Action action_;
  // Created by Reset(); null before the first Reset().
  std::unique_ptr<MemoTable> memo_table_;
};
// ----------------------------------------------------------------------
// Hash kernel implementation for nulls
template <typename Action>
class NullHashKernelImpl : public HashKernelImpl {
public:
NullHashKernelImpl(const std::shared_ptr<DataType>& type, MemoryPool* pool)
: pool_(pool), type_(type), action_(type, pool) {}
Status Reset() override { return action_.Reset(); }
Status Append(const ArrayData& arr) override {
RETURN_NOT_OK(action_.Reserve(arr.length));
for (int64_t i = 0; i < arr.length; ++i) {
if (i == 0) {
action_.ObserveNullNotFound(0);
} else {
action_.ObserveNullFound(0);
}
}
return Status::OK();
}
Status Flush(Datum* out) override { return action_.Flush(out); }
Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); }
Status GetDictionary(std::shared_ptr<ArrayData>* out) override {
// TODO(wesm): handle null being a valid dictionary value
auto null_array = std::make_shared<NullArray>(0);
*out = null_array->data();
return Status::OK();
}
std::shared_ptr<DataType> out_type() const override { return null(); }
protected:
MemoryPool* pool_;
std::shared_ptr<DataType> type_;
Action action_;
};
// ----------------------------------------------------------------------
// Kernel wrapper for generic hash table kernels
// Maps (Arrow type, Action, flags) to the concrete kernel class, exposed as
// the nested alias `HashKernelImpl`. The primary template is empty, so a type
// matching none of the specializations below has no such alias.
template <typename ArrowType, typename ActionType, bool kWithErrorStatus,
          bool kWithMemoVisitNull, typename Enable = void>
struct HashKernelTraits {};

// NullType: dedicated kernel; the two flags are not forwarded.
template <typename ArrowType, typename ActionType, bool kWithErrorStatus,
          bool kWithMemoVisitNull>
struct HashKernelTraits<ArrowType, ActionType, kWithErrorStatus, kWithMemoVisitNull,
                        enable_if_null<ArrowType>> {
  using HashKernelImpl = NullHashKernelImpl<ActionType>;
};

// Types selected by enable_if_has_c_type: hashed by their `c_type` value.
template <typename ArrowType, typename ActionType, bool kWithErrorStatus,
          bool kWithMemoVisitNull>
struct HashKernelTraits<ArrowType, ActionType, kWithErrorStatus, kWithMemoVisitNull,
                        enable_if_has_c_type<ArrowType>> {
  using HashKernelImpl =
      RegularHashKernelImpl<ArrowType, typename ArrowType::c_type, ActionType,
                            kWithErrorStatus, kWithMemoVisitNull>;
};

// Boolean: hashed as `bool`.
template <typename ArrowType, typename ActionType, bool kWithErrorStatus,
          bool kWithMemoVisitNull>
struct HashKernelTraits<ArrowType, ActionType, kWithErrorStatus, kWithMemoVisitNull,
                        enable_if_boolean<ArrowType>> {
  using HashKernelImpl = RegularHashKernelImpl<ArrowType, bool, ActionType,
                                               kWithErrorStatus, kWithMemoVisitNull>;
};

// Types selected by enable_if_binary: hashed as util::string_view.
template <typename ArrowType, typename ActionType, bool kWithErrorStatus,
          bool kWithMemoVisitNull>
struct HashKernelTraits<ArrowType, ActionType, kWithErrorStatus, kWithMemoVisitNull,
                        enable_if_binary<ArrowType>> {
  using HashKernelImpl = RegularHashKernelImpl<ArrowType, util::string_view, ActionType,
                                               kWithErrorStatus, kWithMemoVisitNull>;
};

// Types selected by enable_if_fixed_size_binary: hashed as util::string_view.
template <typename ArrowType, typename ActionType, bool kWithErrorStatus,
          bool kWithMemoVisitNull>
struct HashKernelTraits<ArrowType, ActionType, kWithErrorStatus, kWithMemoVisitNull,
                        enable_if_fixed_size_binary<ArrowType>> {
  using HashKernelImpl = RegularHashKernelImpl<ArrowType, util::string_view, ActionType,
                                               kWithErrorStatus, kWithMemoVisitNull>;
};
} // namespace
// X-macro listing the Arrow type classes for which the kernel factories below
// create a hash kernel. It invokes PROCESS once per type; each factory defines
// its own PROCESS that expands to a `case` label inside a switch on the type
// id. The macro is #undef'd at the end of this file.
#define PROCESS_SUPPORTED_HASH_TYPES(PROCESS) \
PROCESS(NullType) \
PROCESS(BooleanType) \
PROCESS(UInt8Type) \
PROCESS(Int8Type) \
PROCESS(UInt16Type) \
PROCESS(Int16Type) \
PROCESS(UInt32Type) \
PROCESS(Int32Type) \
PROCESS(UInt64Type) \
PROCESS(Int64Type) \
PROCESS(FloatType) \
PROCESS(DoubleType) \
PROCESS(Date32Type) \
PROCESS(Date64Type) \
PROCESS(Time32Type) \
PROCESS(Time64Type) \
PROCESS(TimestampType) \
PROCESS(BinaryType) \
PROCESS(StringType) \
PROCESS(FixedSizeBinaryType) \
PROCESS(Decimal128Type)
// Creates the "unique" hash kernel for `type`.
//
// On success `*out` owns a kernel that has already been Reset() and uses the
// memory pool of `ctx`. Returns NotImplemented when `type` is not one of the
// types in PROCESS_SUPPORTED_HASH_TYPES, or the error from Reset().
Status GetUniqueKernel(FunctionContext* ctx, const std::shared_ptr<DataType>& type,
                       std::unique_ptr<HashKernel>* out) {
  std::unique_ptr<HashKernel> created;

  // One `case` per supported type; flags: no error status, nulls go through
  // the memo table.
#define UNIQUE_KERNEL_CASE(InType)                                                    \
  case InType::type_id: {                                                             \
    using KernelType =                                                                \
        typename HashKernelTraits<InType, UniqueAction, false, true>::HashKernelImpl; \
    created.reset(new KernelType(type, ctx->memory_pool()));                          \
    break;                                                                            \
  }

  switch (type->id()) {
    PROCESS_SUPPORTED_HASH_TYPES(UNIQUE_KERNEL_CASE)
    default:
      break;
  }
#undef UNIQUE_KERNEL_CASE

  // `created` is still null if no case matched.
  CHECK_IMPLEMENTED(created, "unique", type);
  RETURN_NOT_OK(created->Reset());
  *out = std::move(created);
  return Status::OK();
}
// Creates the dictionary-encode hash kernel for `type`.
//
// On success `*out` owns a kernel that has already been Reset() and uses the
// memory pool of `ctx`. Returns NotImplemented when `type` is not one of the
// types in PROCESS_SUPPORTED_HASH_TYPES, or the error from Reset().
//
// The stale `#undef DICTIONARY_ENCODE_CASE` that used to follow the switch has
// been dropped: no macro of that name is defined anywhere in this file.
Status GetDictionaryEncodeKernel(FunctionContext* ctx,
                                 const std::shared_ptr<DataType>& type,
                                 std::unique_ptr<HashKernel>* out) {
  std::unique_ptr<HashKernel> kernel;

  switch (type->id()) {
    // One `case` per supported type; flags: no error status, nulls bypass the
    // memo table.
#define PROCESS(InType)                                                    \
  case InType::type_id:                                                    \
    kernel.reset(                                                          \
        new typename HashKernelTraits<InType, DictEncodeAction, false,     \
                                      false>::HashKernelImpl(              \
            type, ctx->memory_pool()));                                    \
    break;

    PROCESS_SUPPORTED_HASH_TYPES(PROCESS)
#undef PROCESS
    default:
      break;
  }

  // `kernel` is still null if no case matched.
  CHECK_IMPLEMENTED(kernel, "dictionary-encode", type);
  RETURN_NOT_OK(kernel->Reset());
  *out = std::move(kernel);
  return Status::OK();
}
// Creates the value-counts hash kernel for `type`.
//
// On success `*out` owns a kernel that has already been Reset() and uses the
// memory pool of `ctx`. Returns NotImplemented when `type` is not one of the
// types in PROCESS_SUPPORTED_HASH_TYPES, or the error from Reset().
Status GetValueCountsKernel(FunctionContext* ctx, const std::shared_ptr<DataType>& type,
                            std::unique_ptr<HashKernel>* out) {
  std::unique_ptr<HashKernel> created;

  // One `case` per supported type; flags: error status enabled, nulls go
  // through the memo table.
#define VALUE_COUNTS_KERNEL_CASE(InType)                                        \
  case InType::type_id: {                                                       \
    using KernelType = typename HashKernelTraits<InType, ValueCountsAction, true, \
                                                 true>::HashKernelImpl;         \
    created.reset(new KernelType(type, ctx->memory_pool()));                    \
    break;                                                                      \
  }

  switch (type->id()) {
    PROCESS_SUPPORTED_HASH_TYPES(VALUE_COUNTS_KERNEL_CASE)
    default:
      break;
  }
#undef VALUE_COUNTS_KERNEL_CASE

  // `created` is still null if no case matched.
  CHECK_IMPLEMENTED(created, "count-values", type);
  RETURN_NOT_OK(created->Reset());
  *out = std::move(created);
  return Status::OK();
}
namespace {
// Runs `func` over `value` and then fetches the kernel's dictionary.
//
// kernel_outputs - receives whatever detail::InvokeUnaryArrayKernel collects
//                  from the kernel invocations.
// dictionary     - receives the kernel's accumulated keys wrapped as an Array.
//
// Returns the first error from the invocation or from GetDictionary.
Status InvokeHash(FunctionContext* ctx, HashKernel* func, const Datum& value,
                  std::vector<Datum>* kernel_outputs,
                  std::shared_ptr<Array>* dictionary) {
  RETURN_NOT_OK(detail::InvokeUnaryArrayKernel(ctx, func, value, kernel_outputs));

  std::shared_ptr<ArrayData> accumulated_keys;
  RETURN_NOT_OK(func->GetDictionary(&accumulated_keys));
  *dictionary = MakeArray(accumulated_keys);
  return Status::OK();
}
} // namespace
// Computes the distinct values of `value` and stores them in `*out`. The
// result is the dictionary accumulated by the unique kernel; the per-call
// kernel outputs are collected only to be discarded.
Status Unique(FunctionContext* ctx, const Datum& value, std::shared_ptr<Array>* out) {
  std::unique_ptr<HashKernel> unique_kernel;
  RETURN_NOT_OK(GetUniqueKernel(ctx, value.type(), &unique_kernel));

  std::vector<Datum> discarded_outputs;
  return InvokeHash(ctx, unique_kernel.get(), value, &discarded_outputs, out);
}
// Dictionary-encodes `value`.
//
// Each indices array produced by the kernel is wrapped in a DictionaryArray
// that shares a single dictionary, and the pieces are combined with
// detail::WrapArraysLike. If the kernel produced no output at all, `*out`
// becomes an empty ChunkedArray of type dictionary(int32, <dictionary type>).
Status DictionaryEncode(FunctionContext* ctx, const Datum& value, Datum* out) {
  std::unique_ptr<HashKernel> encode_kernel;
  RETURN_NOT_OK(GetDictionaryEncodeKernel(ctx, value.type(), &encode_kernel));

  std::shared_ptr<Array> dictionary;
  std::vector<Datum> index_pieces;
  RETURN_NOT_OK(
      InvokeHash(ctx, encode_kernel.get(), value, &index_pieces, &dictionary));

  std::vector<std::shared_ptr<Array>> encoded_chunks;
  std::shared_ptr<DataType> encoded_type;

  if (index_pieces.empty()) {
    // Special case: the input was an empty chunked array, so there is no
    // indices array to take the index type from; int32 is used.
    DCHECK_EQ(value.kind(), Datum::CHUNKED_ARRAY);
    encoded_type = ::arrow::dictionary(int32(), dictionary->type());
    *out = std::make_shared<ChunkedArray>(encoded_chunks, encoded_type);
    return Status::OK();
  }

  // The index type is taken from the first piece.
  DCHECK_EQ(index_pieces[0].kind(), Datum::ARRAY);
  encoded_type =
      ::arrow::dictionary(index_pieces[0].array()->type, dictionary->type());

  // One DictionaryArray per piece yielded by the kernel invocations.
  for (const Datum& piece : index_pieces) {
    encoded_chunks.emplace_back(std::make_shared<DictionaryArray>(
        encoded_type, MakeArray(piece.array()), dictionary));
  }
  *out = detail::WrapArraysLike(value, encoded_chunks);
  return Status::OK();
}
// Names of the two fields of the struct array built by ValueCounts below.
const char kValuesFieldName[] = "values";
const char kCountsFieldName[] = "counts";
// Positions matching the order in which ValueCounts lists those fields
// (values first, counts second). Not referenced in this file; presumably
// exported for callers -- confirm against the header.
const int32_t kValuesFieldIndex = 0;
const int32_t kCountsFieldIndex = 1;
// Counts the occurrences of each distinct value in `value`.
//
// `*counts` receives a StructArray with two children: the distinct values
// (field "values") and their int64 counts (field "counts"). Its length is the
// number of distinct values reported by the kernel's dictionary.
Status ValueCounts(FunctionContext* ctx, const Datum& value,
                   std::shared_ptr<Array>* counts) {
  std::unique_ptr<HashKernel> counting_kernel;
  RETURN_NOT_OK(GetValueCountsKernel(ctx, value.type(), &counting_kernel));

  // The per-call outputs carry nothing for this kernel; the counts only
  // become available through FlushFinal.
  std::vector<Datum> ignored_outputs;
  std::shared_ptr<Array> uniques;
  RETURN_NOT_OK(
      InvokeHash(ctx, counting_kernel.get(), value, &ignored_outputs, &uniques));

  Datum value_counts;
  RETURN_NOT_OK(counting_kernel->FlushFinal(&value_counts));

  std::vector<std::shared_ptr<Field>> fields = {
      std::make_shared<Field>(kValuesFieldName, uniques->type()),
      std::make_shared<Field>(kCountsFieldName, int64())};
  auto struct_type = std::make_shared<StructType>(fields);

  std::vector<std::shared_ptr<Array>> children = {uniques,
                                                  MakeArray(value_counts.array())};
  *counts = std::make_shared<StructArray>(struct_type, uniques->length(), children);
  return Status::OK();
}
#undef PROCESS_SUPPORTED_HASH_TYPES
} // namespace compute
} // namespace arrow