// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <algorithm>
#include <limits>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <gtest/gtest.h>
#include "arrow/array.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/cast.h"
#include "arrow/compute/kernels/aggregate_internal.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/compute/registry.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_common.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/bitmap_reader.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/int_util_internal.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
using testing::HasSubstr;
namespace arrow {
using internal::BitmapReader;
using internal::checked_cast;
using internal::checked_pointer_cast;
namespace compute {
namespace {
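// A naive reference implementation of group-by: encode the key columns into
// group ids with a Grouper, partition each argument column by group id, then
// apply the equivalent scalar aggregate (the "hash_"-prefixed name minus its
// prefix) to every per-group slice. The unique key values are appended as
// trailing "key_N" columns, mirroring the layout produced by internal::GroupBy.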
Result<Datum> NaiveGroupBy(std::vector<Datum> arguments, std::vector<Datum> keys,
const std::vector<internal::Aggregate>& aggregates) {
ARROW_ASSIGN_OR_RAISE(auto key_batch, ExecBatch::Make(std::move(keys)));
ARROW_ASSIGN_OR_RAISE(auto grouper,
internal::Grouper::Make(key_batch.GetDescriptors()));
ARROW_ASSIGN_OR_RAISE(Datum id_batch, grouper->Consume(key_batch));
ARROW_ASSIGN_OR_RAISE(
auto groupings, internal::Grouper::MakeGroupings(*id_batch.array_as<UInt32Array>(),
grouper->num_groups()));
ArrayVector out_columns;
std::vector<std::string> out_names;
for (size_t i = 0; i < arguments.size(); ++i) {
out_names.push_back(aggregates[i].function);
// trim "hash_" prefix
auto scalar_agg_function = aggregates[i].function.substr(5);
ARROW_ASSIGN_OR_RAISE(
auto grouped_argument,
internal::Grouper::ApplyGroupings(*groupings, *arguments[i].make_array()));
ScalarVector aggregated_scalars;
for (int64_t i_group = 0; i_group < grouper->num_groups(); ++i_group) {
auto slice = grouped_argument->value_slice(i_group);
if (slice->length() == 0) continue;
ARROW_ASSIGN_OR_RAISE(
Datum d, CallFunction(scalar_agg_function, {slice}, aggregates[i].options));
aggregated_scalars.push_back(d.scalar());
}
ARROW_ASSIGN_OR_RAISE(Datum aggregated_column,
ScalarVectorToArray(aggregated_scalars));
out_columns.push_back(aggregated_column.make_array());
}
int i = 0;
ARROW_ASSIGN_OR_RAISE(auto uniques, grouper->GetUniques());
for (const Datum& key : uniques.values) {
out_columns.push_back(key.make_array());
out_names.push_back("key_" + std::to_string(i++));
}
return StructArray::Make(std::move(out_columns), std::move(out_names));
}
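
// Cross-check the hash aggregate kernels against the naive reference
// implementation above: both paths must yield identical, fully valid arrays.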
void ValidateGroupBy(const std::vector<internal::Aggregate>& aggregates,
std::vector<Datum> arguments, std::vector<Datum> keys) {
ASSERT_OK_AND_ASSIGN(Datum expected, NaiveGroupBy(arguments, keys, aggregates));
ASSERT_OK_AND_ASSIGN(Datum actual, GroupBy(arguments, keys, aggregates));
ASSERT_OK(expected.make_array()->ValidateFull());
ASSERT_OK(actual.make_array()->ValidateFull());
AssertDatumsEqual(expected, actual, /*verbose=*/true);
}
} // namespace
TEST(Grouper, SupportedKeys) {
ASSERT_OK(internal::Grouper::Make({boolean()}));
ASSERT_OK(internal::Grouper::Make({int8(), uint16(), int32(), uint64()}));
ASSERT_OK(internal::Grouper::Make({dictionary(int64(), utf8())}));
ASSERT_OK(internal::Grouper::Make({float16(), float32(), float64()}));
ASSERT_OK(internal::Grouper::Make({utf8(), binary(), large_utf8(), large_binary()}));
ASSERT_OK(internal::Grouper::Make({fixed_size_binary(16), fixed_size_binary(32)}));
ASSERT_OK(internal::Grouper::Make({decimal128(32, 10), decimal256(76, 20)}));
ASSERT_OK(internal::Grouper::Make({date32(), date64()}));
for (auto unit : {
TimeUnit::SECOND,
TimeUnit::MILLI,
TimeUnit::MICRO,
TimeUnit::NANO,
}) {
ASSERT_OK(internal::Grouper::Make({timestamp(unit), duration(unit)}));
}
ASSERT_OK(internal::Grouper::Make({day_time_interval(), month_interval()}));
ASSERT_RAISES(NotImplemented, internal::Grouper::Make({struct_({field("", int64())})}));
ASSERT_RAISES(NotImplemented, internal::Grouper::Make({struct_({})}));
ASSERT_RAISES(NotImplemented, internal::Grouper::Make({list(int32())}));
ASSERT_RAISES(NotImplemented, internal::Grouper::Make({fixed_size_list(int32(), 5)}));
ASSERT_RAISES(NotImplemented,
internal::Grouper::Make({dense_union({field("", int32())})}));
}
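
// Fixture wrapping a Grouper: it consumes batches of keys and verifies that
// the emitted group ids round-trip (taking the uniques by id reproduces the
// original keys) and that ids assigned by earlier batches remain stable as
// later batches extend the set of unique keys.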
struct TestGrouper {
explicit TestGrouper(std::vector<ValueDescr> descrs) : descrs_(std::move(descrs)) {
grouper_ = internal::Grouper::Make(descrs_).ValueOrDie();
FieldVector fields;
for (const auto& descr : descrs_) {
fields.push_back(field("", descr.type));
}
key_schema_ = schema(std::move(fields));
}
void ExpectConsume(const std::string& key_json, const std::string& expected) {
ExpectConsume(ExecBatch(*RecordBatchFromJSON(key_schema_, key_json)),
ArrayFromJSON(uint32(), expected));
}
void ExpectConsume(const std::vector<Datum>& key_batch, Datum expected) {
ExpectConsume(*ExecBatch::Make(key_batch), expected);
}
void ExpectConsume(const ExecBatch& key_batch, Datum expected) {
Datum ids;
ConsumeAndValidate(key_batch, &ids);
AssertDatumsEqual(expected, ids, /*verbose=*/true);
}
void ConsumeAndValidate(const ExecBatch& key_batch, Datum* ids = nullptr) {
ASSERT_OK_AND_ASSIGN(Datum id_batch, grouper_->Consume(key_batch));
ValidateConsume(key_batch, id_batch);
if (ids) {
*ids = std::move(id_batch);
}
}
void ValidateConsume(const ExecBatch& key_batch, const Datum& id_batch) {
if (uniques_.length == -1) {
ASSERT_OK_AND_ASSIGN(uniques_, grouper_->GetUniques());
} else if (static_cast<int64_t>(grouper_->num_groups()) > uniques_.length) {
ASSERT_OK_AND_ASSIGN(ExecBatch new_uniques, grouper_->GetUniques());
      // check that each column of uniques_ is a prefix of the corresponding
      // column of new_uniques
for (int i = 0; i < uniques_.num_values(); ++i) {
auto new_unique = new_uniques[i].make_array();
ASSERT_OK(new_unique->ValidateFull());
AssertDatumsEqual(uniques_[i], new_unique->Slice(0, uniques_.length),
/*verbose=*/true);
}
uniques_ = std::move(new_uniques);
}
// check that the ids encode an equivalent key sequence
auto ids = id_batch.make_array();
ASSERT_OK(ids->ValidateFull());
for (int i = 0; i < key_batch.num_values(); ++i) {
SCOPED_TRACE(std::to_string(i) + "th key array");
auto original = key_batch[i].make_array();
ASSERT_OK_AND_ASSIGN(auto encoded, Take(*uniques_[i].make_array(), *ids));
AssertArraysEqual(*original, *encoded, /*verbose=*/true,
EqualOptions().nans_equal(true));
}
}
std::vector<ValueDescr> descrs_;
std::shared_ptr<Schema> key_schema_;
std::unique_ptr<internal::Grouper> grouper_;
ExecBatch uniques_ = ExecBatch({}, -1);
};
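
// Group ids are assigned densely in order of first appearance, starting at 0;
// e.g. consuming the keys [true, false, true, null] yields ids [0, 1, 0, 2].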
TEST(Grouper, BooleanKey) {
TestGrouper g({boolean()});
g.ExpectConsume("[[true], [true]]", "[0, 0]");
g.ExpectConsume("[[true], [true]]", "[0, 0]");
g.ExpectConsume("[[false], [null]]", "[1, 2]");
g.ExpectConsume("[[true], [false], [true], [false], [null], [false], [null]]",
"[0, 1, 0, 1, 2, 1, 2]");
}
TEST(Grouper, NumericKey) {
for (auto ty : {
uint8(),
int8(),
uint16(),
int16(),
uint32(),
int32(),
uint64(),
int64(),
float16(),
float32(),
float64(),
}) {
SCOPED_TRACE("key type: " + ty->ToString());
TestGrouper g({ty});
g.ExpectConsume("[[3], [3]]", "[0, 0]");
g.ExpectConsume("[[3], [3]]", "[0, 0]");
g.ExpectConsume("[[27], [81]]", "[1, 2]");
g.ExpectConsume("[[3], [27], [3], [27], [null], [81], [27], [81]]",
"[0, 1, 0, 1, 3, 2, 1, 2]");
}
}
TEST(Grouper, FloatingPointKey) {
TestGrouper g({float32()});
// -0.0 hashes differently from 0.0
g.ExpectConsume("[[0.0], [-0.0]]", "[0, 1]");
g.ExpectConsume("[[Inf], [-Inf]]", "[2, 3]");
  // NaN != NaN under IEEE comparison, but identical NaNs must not produce
  // spurious new groups
g.ExpectConsume("[[NaN], [NaN]]", "[4, 4]");
// TODO(bkietz) test denormal numbers, more NaNs
}
TEST(Grouper, StringKey) {
for (auto ty : {utf8(), large_utf8(), fixed_size_binary(2)}) {
SCOPED_TRACE("key type: " + ty->ToString());
TestGrouper g({ty});
g.ExpectConsume(R"([["eh"], ["eh"]])", "[0, 0]");
g.ExpectConsume(R"([["eh"], ["eh"]])", "[0, 0]");
g.ExpectConsume(R"([["be"], [null]])", "[1, 2]");
}
}
TEST(Grouper, DictKey) {
TestGrouper g({dictionary(int32(), utf8())});
// For dictionary keys, all batches must share a single dictionary.
// Eventually, differing dictionaries will be unified and indices transposed
  // during encoding to lift this restriction.
const auto dict = ArrayFromJSON(utf8(), R"(["ex", "why", "zee", null])");
auto WithIndices = [&](const std::string& indices) {
return Datum(*DictionaryArray::FromArrays(ArrayFromJSON(int32(), indices), dict));
};
// NB: null index is not considered equivalent to index=3 (which encodes null in dict)
g.ExpectConsume({WithIndices(" [3, 1, null, 0, 2]")},
ArrayFromJSON(uint32(), "[0, 1, 2, 3, 4]"));
g = TestGrouper({dictionary(int32(), utf8())});
g.ExpectConsume({WithIndices(" [0, 1, 2, 3, null]")},
ArrayFromJSON(uint32(), "[0, 1, 2, 3, 4]"));
g.ExpectConsume({WithIndices(" [3, 1, null, 0, 2]")},
ArrayFromJSON(uint32(), "[3, 1, 4, 0, 2]"));
EXPECT_RAISES_WITH_MESSAGE_THAT(
NotImplemented, HasSubstr("Unifying differing dictionaries"),
g.grouper_->Consume(*ExecBatch::Make({*DictionaryArray::FromArrays(
ArrayFromJSON(int32(), "[0, 1]"),
ArrayFromJSON(utf8(), R"(["different", "dictionary"])"))})));
}
TEST(Grouper, StringInt64Key) {
TestGrouper g({utf8(), int64()});
g.ExpectConsume(R"([["eh", 0], ["eh", 0]])", "[0, 0]");
g.ExpectConsume(R"([["eh", 0], ["eh", null]])", "[0, 1]");
g.ExpectConsume(R"([["eh", 1], ["bee", 1]])", "[2, 3]");
g.ExpectConsume(R"([["eh", null], ["bee", 1]])", "[1, 3]");
g = TestGrouper({utf8(), int64()});
g.ExpectConsume(R"([
["ex", 0],
["ex", 0],
["why", 0],
["ex", 1],
["why", 0],
["ex", 1],
["ex", 0],
["why", 1]
])",
"[0, 0, 1, 2, 1, 2, 0, 3]");
g.ExpectConsume(R"([
["ex", 0],
[null, 0],
[null, 0],
["ex", 1],
[null, null],
["ex", 1],
["ex", 0],
["why", null]
])",
"[0, 4, 4, 2, 5, 2, 0, 6]");
}
TEST(Grouper, DoubleStringInt64Key) {
TestGrouper g({float64(), utf8(), int64()});
g.ExpectConsume(R"([[1.5, "eh", 0], [1.5, "eh", 0]])", "[0, 0]");
g.ExpectConsume(R"([[1.5, "eh", 0], [1.5, "eh", 0]])", "[0, 0]");
g.ExpectConsume(R"([[1.0, "eh", 0], [1.0, "be", null]])", "[1, 2]");
// note: -0 and +0 hash differently
g.ExpectConsume(R"([[-0.0, "be", 7], [0.0, "be", 7]])", "[3, 4]");
}
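
// Randomized stress tests: consume several batches of 4096 random rows each
// and validate (via ConsumeAndValidate) that the ids round-trip to the keys.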
TEST(Grouper, RandomInt64Keys) {
TestGrouper g({int64()});
for (int i = 0; i < 4; ++i) {
SCOPED_TRACE(std::to_string(i) + "th key batch");
ExecBatch key_batch{
*random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)};
g.ConsumeAndValidate(key_batch);
}
}
TEST(Grouper, RandomStringInt64Keys) {
TestGrouper g({utf8(), int64()});
for (int i = 0; i < 4; ++i) {
SCOPED_TRACE(std::to_string(i) + "th key batch");
ExecBatch key_batch{
*random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)};
g.ConsumeAndValidate(key_batch);
}
}
TEST(Grouper, RandomStringInt64DoubleInt32Keys) {
TestGrouper g({utf8(), int64(), float64(), int32()});
for (int i = 0; i < 4; ++i) {
SCOPED_TRACE(std::to_string(i) + "th key batch");
ExecBatch key_batch{
*random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)};
g.ConsumeAndValidate(key_batch);
}
}
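
// MakeGroupings inverts an array of group ids: it returns a list<int32> array
// with num_groups slots, where slot g holds the row indices assigned id g
// (an empty list for any id that never occurs). ApplyGroupings then gathers a
// values array into per-group slices using those indices.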
TEST(Grouper, MakeGroupings) {
auto ExpectGroupings = [](std::string ids_json, std::string expected_json) {
auto ids = checked_pointer_cast<UInt32Array>(ArrayFromJSON(uint32(), ids_json));
auto expected = ArrayFromJSON(list(int32()), expected_json);
auto num_groups = static_cast<uint32_t>(expected->length());
ASSERT_OK_AND_ASSIGN(auto actual, internal::Grouper::MakeGroupings(*ids, num_groups));
AssertArraysEqual(*expected, *actual, /*verbose=*/true);
// validate ApplyGroupings
ASSERT_OK_AND_ASSIGN(auto grouped_ids,
internal::Grouper::ApplyGroupings(*actual, *ids));
for (uint32_t group = 0; group < num_groups; ++group) {
auto ids_slice = checked_pointer_cast<UInt32Array>(grouped_ids->value_slice(group));
for (auto slot : *ids_slice) {
EXPECT_EQ(slot, group);
}
}
};
ExpectGroupings("[]", "[[]]");
ExpectGroupings("[0, 0, 0]", "[[0, 1, 2]]");
ExpectGroupings("[0, 0, 0, 1, 1, 2]", "[[0, 1, 2], [3, 4], [5], []]");
ExpectGroupings("[2, 1, 2, 1, 1, 2]", "[[], [1, 3, 4], [0, 2, 5], [], []]");
ExpectGroupings("[2, 2, 5, 5, 2, 3]", "[[], [], [0, 1, 4], [5], [], [2, 3], [], []]");
auto ids = checked_pointer_cast<UInt32Array>(ArrayFromJSON(uint32(), "[0, null, 1]"));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("MakeGroupings with null ids"),
internal::Grouper::MakeGroupings(*ids, 5));
}
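
// HASH_AGGREGATE functions cannot be invoked directly through CallFunction;
// they are only usable through the GroupBy machinery.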
TEST(GroupBy, Errors) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("group_id", uint32())}), R"([
[1.0, 1],
[null, 1],
[0.0, 2],
[null, 3],
[4.0, 0],
[3.25, 1],
[0.125, 2],
[-0.25, 2],
[0.75, 0],
[null, 3]
])");
EXPECT_RAISES_WITH_MESSAGE_THAT(
NotImplemented, HasSubstr("Direct execution of HASH_AGGREGATE functions"),
CallFunction("hash_sum", {batch->GetColumnByName("argument"),
batch->GetColumnByName("group_id"), Datum(uint32_t(4))}));
}
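
// internal::GroupBy returns a StructArray with one row per group: one field
// per requested aggregate (named after its function) followed by one "key_N"
// field per key column. Rows with a null key form their own group.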
TEST(GroupBy, SumOnly) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
[1.0, 1],
[null, 1],
[0.0, 2],
[null, 3],
[4.0, null],
[3.25, 1],
[0.125, 2],
[-0.25, 2],
[0.75, null],
[null, 3]
])");
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
internal::GroupBy({batch->GetColumnByName("argument")},
{batch->GetColumnByName("key")},
{
{"hash_sum", nullptr},
}));
AssertDatumsEqual(ArrayFromJSON(struct_({
field("hash_sum", float64()),
field("key_0", int64()),
}),
R"([
[4.25, 1],
[-0.125, 2],
[null, 3],
[4.75, null]
])"),
aggregated_and_grouped,
/*verbose=*/true);
}
TEST(GroupBy, MinMaxOnly) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
[1.0, 1],
[null, 1],
[0.0, 2],
[null, 3],
[4.0, null],
[3.25, 1],
[0.125, 2],
[-0.25, 2],
[0.75, null],
[null, 3]
])");
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
internal::GroupBy({batch->GetColumnByName("argument")},
{batch->GetColumnByName("key")},
{
{"hash_min_max", nullptr},
}));
AssertDatumsEqual(ArrayFromJSON(struct_({
field("hash_min_max", struct_({
field("min", float64()),
field("max", float64()),
})),
field("key_0", int64()),
}),
R"([
[{"min": 1.0, "max": 3.25}, 1],
[{"min": -0.25, "max": 0.125}, 2],
[{"min": null, "max": null}, 3],
[{"min": 0.75, "max": 4.0}, null]
])"),
aggregated_and_grouped,
/*verbose=*/true);
}
TEST(GroupBy, CountAndSum) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
[1.0, 1],
[null, 1],
[0.0, 2],
[null, 3],
[4.0, null],
[3.25, 1],
[0.125, 2],
[-0.25, 2],
[0.75, null],
[null, 3]
])");
CountOptions count_options;
ASSERT_OK_AND_ASSIGN(
Datum aggregated_and_grouped,
internal::GroupBy(
{
// NB: passing an argument twice or also using it as a key is legal
batch->GetColumnByName("argument"),
batch->GetColumnByName("argument"),
batch->GetColumnByName("key"),
},
{
batch->GetColumnByName("key"),
},
{
{"hash_count", &count_options},
{"hash_sum", nullptr},
{"hash_sum", nullptr},
}));
AssertDatumsEqual(
ArrayFromJSON(struct_({
field("hash_count", int64()),
// NB: summing a float32 array results in float64 sums
field("hash_sum", float64()),
field("hash_sum", int64()),
field("key_0", int64()),
}),
R"([
[2, 4.25, 3, 1],
[3, -0.125, 6, 2],
[0, null, 6, 3],
[2, 4.75, null, null]
])"),
aggregated_and_grouped,
/*verbose=*/true);
}
TEST(GroupBy, SumOnlyStringAndDictKeys) {
for (auto key_type : {utf8(), dictionary(int32(), utf8())}) {
SCOPED_TRACE("key type: " + key_type->ToString());
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", key_type)}), R"([
[1.0, "alfa"],
[null, "alfa"],
[0.0, "beta"],
[null, "gama"],
[4.0, null ],
[3.25, "alfa"],
[0.125, "beta"],
[-0.25, "beta"],
[0.75, null ],
[null, "gama"]
])");
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
internal::GroupBy({batch->GetColumnByName("argument")},
{batch->GetColumnByName("key")},
{
{"hash_sum", nullptr},
}));
AssertDatumsEqual(ArrayFromJSON(struct_({
field("hash_sum", float64()),
field("key_0", key_type),
}),
R"([
[4.25, "alfa"],
[-0.125, "beta"],
[null, "gama"],
[4.75, null ]
])"),
aggregated_and_grouped,
/*verbose=*/true);
}
}
TEST(GroupBy, ConcreteCaseWithValidateGroupBy) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", utf8())}), R"([
[1.0, "alfa"],
[null, "alfa"],
[0.0, "beta"],
[null, "gama"],
[4.0, null ],
[3.25, "alfa"],
[0.125, "beta"],
[-0.25, "beta"],
[0.75, null ],
[null, "gama"]
])");
CountOptions count_non_null{CountOptions::COUNT_NON_NULL},
count_null{CountOptions::COUNT_NULL};
MinMaxOptions emit_null{MinMaxOptions::EMIT_NULL};
using internal::Aggregate;
for (auto agg : {
Aggregate{"hash_sum", nullptr},
Aggregate{"hash_count", &count_non_null},
Aggregate{"hash_count", &count_null},
Aggregate{"hash_min_max", nullptr},
Aggregate{"hash_min_max", &emit_null},
}) {
SCOPED_TRACE(agg.function);
ValidateGroupBy({agg}, {batch->GetColumnByName("argument")},
{batch->GetColumnByName("key")});
}
}
// Count nulls and non-nulls in a record batch that contains no nulls
TEST(GroupBy, CountNull) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", utf8())}), R"([
[1.0, "alfa"],
[2.0, "beta"],
[3.0, "gama"]
])");
CountOptions count_non_null{CountOptions::COUNT_NON_NULL},
count_null{CountOptions::COUNT_NULL};
using internal::Aggregate;
for (auto agg : {
Aggregate{"hash_count", &count_non_null},
Aggregate{"hash_count", &count_null},
}) {
SCOPED_TRACE(agg.function);
ValidateGroupBy({agg}, {batch->GetColumnByName("argument")},
{batch->GetColumnByName("key")});
}
}
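
// Randomized stress test for hash_sum across batch lengths and null
// densities; the key_value_metadata entries parameterize the random
// generator's null_probability and the key range.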
TEST(GroupBy, RandomArraySum) {
for (int64_t length : {1 << 10, 1 << 12, 1 << 15}) {
for (auto null_probability : {0.0, 0.01, 0.5, 1.0}) {
auto batch = random::GenerateBatch(
{
field("argument", float32(),
key_value_metadata(
{{"null_probability", std::to_string(null_probability)}})),
field("key", int64(), key_value_metadata({{"min", "0"}, {"max", "100"}})),
},
length, 0xDEADBEEF);
ValidateGroupBy(
{
{"hash_sum", nullptr},
},
{batch->GetColumnByName("argument")}, {batch->GetColumnByName("key")});
}
}
}
} // namespace compute
} // namespace arrow