blob: fb474a6b8b3631749aabea394f8c1b7542c9cf36 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/kernels/aggregate_internal.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/util/bit_run_reader.h"
#include "arrow/util/tdigest.h"
namespace arrow {
namespace compute {
namespace internal {
namespace {
using arrow::internal::TDigest;
using arrow::internal::VisitSetBitRunsVoid;
template <typename ArrowType>
struct TDigestImpl : public ScalarAggregator {
using ThisType = TDigestImpl<ArrowType>;
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
using CType = typename ArrowType::c_type;
explicit TDigestImpl(const TDigestOptions& options)
: q{options.q}, tdigest{options.delta, options.buffer_size} {}
Status Consume(KernelContext*, const ExecBatch& batch) override {
const ArrayData& data = *batch[0].array();
const CType* values = data.GetValues<CType>(1);
if (data.length > data.GetNullCount()) {
VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length,
[&](int64_t pos, int64_t len) {
for (int64_t i = 0; i < len; ++i) {
this->tdigest.NanAdd(values[pos + i]);
}
});
}
return Status::OK();
}
Status MergeFrom(KernelContext*, KernelState&& src) override {
auto& other = checked_cast<ThisType&>(src);
std::vector<TDigest> other_tdigest;
other_tdigest.push_back(std::move(other.tdigest));
this->tdigest.Merge(&other_tdigest);
return Status::OK();
}
Status Finalize(KernelContext* ctx, Datum* out) override {
const int64_t out_length = this->tdigest.is_empty() ? 0 : this->q.size();
auto out_data = ArrayData::Make(float64(), out_length, 0);
out_data->buffers.resize(2, nullptr);
if (out_length > 0) {
ARROW_ASSIGN_OR_RAISE(out_data->buffers[1],
ctx->Allocate(out_length * sizeof(double)));
double* out_buffer = out_data->template GetMutableValues<double>(1);
for (int64_t i = 0; i < out_length; ++i) {
out_buffer[i] = this->tdigest.Quantile(this->q[i]);
}
}
*out = Datum(std::move(out_data));
return Status::OK();
}
const std::vector<double>& q;
TDigest tdigest;
};
struct TDigestInitState {
std::unique_ptr<KernelState> state;
KernelContext* ctx;
const DataType& in_type;
const TDigestOptions& options;
TDigestInitState(KernelContext* ctx, const DataType& in_type,
const TDigestOptions& options)
: ctx(ctx), in_type(in_type), options(options) {}
Status Visit(const DataType&) {
return Status::NotImplemented("No tdigest implemented");
}
Status Visit(const HalfFloatType&) {
return Status::NotImplemented("No tdigest implemented");
}
template <typename Type>
enable_if_t<is_number_type<Type>::value, Status> Visit(const Type&) {
state.reset(new TDigestImpl<Type>(options));
return Status::OK();
}
Result<std::unique_ptr<KernelState>> Create() {
RETURN_NOT_OK(VisitTypeInline(in_type, this));
return std::move(state);
}
};
Result<std::unique_ptr<KernelState>> TDigestInit(KernelContext* ctx,
const KernelInitArgs& args) {
TDigestInitState visitor(ctx, *args.inputs[0].type,
static_cast<const TDigestOptions&>(*args.options));
return visitor.Create();
}
void AddTDigestKernels(KernelInit init,
const std::vector<std::shared_ptr<DataType>>& types,
ScalarAggregateFunction* func) {
for (const auto& ty : types) {
auto sig = KernelSignature::Make({InputType::Array(ty)}, float64());
AddAggKernel(std::move(sig), init, func);
}
}
const FunctionDoc tdigest_doc{
"Approximate quantiles of a numeric array with T-Digest algorithm",
("By default, 0.5 quantile (median) is returned.\n"
"Nulls and NaNs are ignored.\n"
"An empty array is returned if there is no valid data point."),
{"array"},
"TDigestOptions"};
std::shared_ptr<ScalarAggregateFunction> AddTDigestAggKernels() {
static auto default_tdigest_options = TDigestOptions::Defaults();
auto func = std::make_shared<ScalarAggregateFunction>(
"tdigest", Arity::Unary(), &tdigest_doc, &default_tdigest_options);
AddTDigestKernels(TDigestInit, NumericTypes(), func.get());
return func;
}
} // namespace
void RegisterScalarAggregateTDigest(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(AddTDigestAggKernels()));
}
} // namespace internal
} // namespace compute
} // namespace arrow