| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include <algorithm> |
| #include <memory> |
| #include <string> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| #include <gtest/gtest.h> |
| |
| #include "arrow/array.h" |
| #include "arrow/compute/api.h" |
| #include "arrow/compute/kernels/test_util.h" |
| #include "arrow/testing/gtest_common.h" |
| #include "arrow/testing/gtest_util.h" |
| #include "arrow/testing/random.h" |
| #include "arrow/type.h" |
| #include "arrow/type_traits.h" |
| #include "arrow/util/bitmap_reader.h" |
| #include "arrow/util/checked_cast.h" |
| |
| namespace arrow { |
| |
| using internal::BitmapReader; |
| |
| namespace compute { |
| |
| using util::string_view; |
| |
| template <typename ArrowType> |
| static void ValidateCompare(CompareOptions options, const Datum& lhs, const Datum& rhs, |
| const Datum& expected) { |
| ASSERT_OK_AND_ASSIGN(Datum result, Compare(lhs, rhs, options)); |
| AssertArraysEqual(*expected.make_array(), *result.make_array(), |
| /*verbose=*/true); |
| } |
| |
| template <typename ArrowType> |
| static void ValidateCompare(CompareOptions options, const char* lhs_str, const Datum& rhs, |
| const char* expected_str) { |
| auto lhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), lhs_str); |
| auto expected = ArrayFromJSON(TypeTraits<BooleanType>::type_singleton(), expected_str); |
| ValidateCompare<ArrowType>(options, lhs, rhs, expected); |
| } |
| |
| template <typename ArrowType> |
| static void ValidateCompare(CompareOptions options, const Datum& lhs, const char* rhs_str, |
| const char* expected_str) { |
| auto rhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), rhs_str); |
| auto expected = ArrayFromJSON(TypeTraits<BooleanType>::type_singleton(), expected_str); |
| ValidateCompare<ArrowType>(options, lhs, rhs, expected); |
| } |
| |
| template <typename ArrowType> |
| static void ValidateCompare(CompareOptions options, const char* lhs_str, |
| const char* rhs_str, const char* expected_str) { |
| auto lhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), lhs_str); |
| auto rhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), rhs_str); |
| auto expected = ArrayFromJSON(TypeTraits<BooleanType>::type_singleton(), expected_str); |
| ValidateCompare<ArrowType>(options, lhs, rhs, expected); |
| } |
| |
| template <typename T> |
| static inline bool SlowCompare(CompareOperator op, const T& lhs, const T& rhs) { |
| switch (op) { |
| case EQUAL: |
| return lhs == rhs; |
| case NOT_EQUAL: |
| return lhs != rhs; |
| case GREATER: |
| return lhs > rhs; |
| case GREATER_EQUAL: |
| return lhs >= rhs; |
| case LESS: |
| return lhs < rhs; |
| case LESS_EQUAL: |
| return lhs <= rhs; |
| default: |
| return false; |
| } |
| } |
| |
| template <typename ArrowType> |
| Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& lhs, |
| const Datum& rhs) { |
| using ArrayType = typename TypeTraits<ArrowType>::ArrayType; |
| using ScalarType = typename TypeTraits<ArrowType>::ScalarType; |
| |
| bool swap = lhs.is_array(); |
| auto array = std::static_pointer_cast<ArrayType>((swap ? lhs : rhs).make_array()); |
| auto value = std::static_pointer_cast<ScalarType>((swap ? rhs : lhs).scalar())->value; |
| |
| std::vector<bool> bitmap(array->length()); |
| for (int64_t i = 0; i < array->length(); i++) { |
| bitmap[i] = swap ? SlowCompare(options.op, array->Value(i), value) |
| : SlowCompare(options.op, value, array->Value(i)); |
| } |
| |
| std::shared_ptr<Array> result; |
| |
| if (array->null_count() == 0) { |
| ArrayFromVector<BooleanType>(bitmap, &result); |
| } else { |
| std::vector<bool> null_bitmap(array->length()); |
| auto reader = |
| BitmapReader(array->null_bitmap_data(), array->offset(), array->length()); |
| for (int64_t i = 0; i < array->length(); i++, reader.Next()) { |
| null_bitmap[i] = reader.IsSet(); |
| } |
| ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result); |
| } |
| |
| return Datum(result); |
| } |
| |
| template <> |
| Datum SimpleScalarArrayCompare<StringType>(CompareOptions options, const Datum& lhs, |
| const Datum& rhs) { |
| bool swap = lhs.is_array(); |
| auto array = std::static_pointer_cast<StringArray>((swap ? lhs : rhs).make_array()); |
| auto value = util::string_view( |
| *std::static_pointer_cast<StringScalar>((swap ? rhs : lhs).scalar())->value); |
| |
| std::vector<bool> bitmap(array->length()); |
| for (int64_t i = 0; i < array->length(); i++) { |
| bitmap[i] = swap ? SlowCompare(options.op, array->GetView(i), value) |
| : SlowCompare(options.op, value, array->GetView(i)); |
| } |
| |
| std::shared_ptr<Array> result; |
| |
| if (array->null_count() == 0) { |
| ArrayFromVector<BooleanType>(bitmap, &result); |
| } else { |
| std::vector<bool> null_bitmap(array->length()); |
| auto reader = |
| BitmapReader(array->null_bitmap_data(), array->offset(), array->length()); |
| for (int64_t i = 0; i < array->length(); i++, reader.Next()) { |
| null_bitmap[i] = reader.IsSet(); |
| } |
| ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result); |
| } |
| |
| return Datum(result); |
| } |
| |
| template <typename ArrayType> |
| std::vector<bool> NullBitmapFromArrays(const ArrayType& lhs, const ArrayType& rhs) { |
| auto left_lambda = [&lhs](int64_t i) { |
| return lhs.null_count() == 0 ? true : lhs.IsValid(i); |
| }; |
| |
| auto right_lambda = [&rhs](int64_t i) { |
| return rhs.null_count() == 0 ? true : rhs.IsValid(i); |
| }; |
| |
| const int64_t length = lhs.length(); |
| std::vector<bool> null_bitmap(length); |
| |
| for (int64_t i = 0; i < length; i++) { |
| null_bitmap[i] = left_lambda(i) && right_lambda(i); |
| } |
| |
| return null_bitmap; |
| } |
| |
| template <typename ArrowType> |
| Datum SimpleArrayArrayCompare(CompareOptions options, const Datum& lhs, |
| const Datum& rhs) { |
| using ArrayType = typename TypeTraits<ArrowType>::ArrayType; |
| |
| auto l_array = std::static_pointer_cast<ArrayType>(lhs.make_array()); |
| auto r_array = std::static_pointer_cast<ArrayType>(rhs.make_array()); |
| const int64_t length = l_array->length(); |
| |
| std::vector<bool> bitmap(length); |
| for (int64_t i = 0; i < length; i++) { |
| bitmap[i] = SlowCompare(options.op, l_array->Value(i), r_array->Value(i)); |
| } |
| |
| std::shared_ptr<Array> result; |
| |
| if (l_array->null_count() == 0 && r_array->null_count() == 0) { |
| ArrayFromVector<BooleanType>(bitmap, &result); |
| } else { |
| std::vector<bool> null_bitmap = NullBitmapFromArrays(*l_array, *r_array); |
| ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result); |
| } |
| |
| return Datum(result); |
| } |
| |
| template <> |
| Datum SimpleArrayArrayCompare<StringType>(CompareOptions options, const Datum& lhs, |
| const Datum& rhs) { |
| auto l_array = std::static_pointer_cast<StringArray>(lhs.make_array()); |
| auto r_array = std::static_pointer_cast<StringArray>(rhs.make_array()); |
| const int64_t length = l_array->length(); |
| |
| std::vector<bool> bitmap(length); |
| for (int64_t i = 0; i < length; i++) { |
| bitmap[i] = SlowCompare(options.op, l_array->GetView(i), r_array->GetView(i)); |
| } |
| |
| std::shared_ptr<Array> result; |
| |
| if (l_array->null_count() == 0 && r_array->null_count() == 0) { |
| ArrayFromVector<BooleanType>(bitmap, &result); |
| } else { |
| std::vector<bool> null_bitmap = NullBitmapFromArrays(*l_array, *r_array); |
| ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result); |
| } |
| |
| return Datum(result); |
| } |
| |
| template <typename ArrowType> |
| void ValidateCompare(CompareOptions options, const Datum& lhs, const Datum& rhs) { |
| Datum result; |
| |
| bool has_scalar = lhs.is_scalar() || rhs.is_scalar(); |
| Datum expected = has_scalar ? SimpleScalarArrayCompare<ArrowType>(options, lhs, rhs) |
| : SimpleArrayArrayCompare<ArrowType>(options, lhs, rhs); |
| |
| ValidateCompare<ArrowType>(options, lhs, rhs, expected); |
| } |
| |
| template <typename ArrowType> |
| class TestNumericCompareKernel : public ::testing::Test {}; |
| |
| TYPED_TEST_SUITE(TestNumericCompareKernel, NumericArrowTypes); |
| TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayScalar) { |
| using ScalarType = typename TypeTraits<TypeParam>::ScalarType; |
| using CType = typename TypeTraits<TypeParam>::CType; |
| |
| Datum one(std::make_shared<ScalarType>(CType(1))); |
| |
| CompareOptions eq(CompareOperator::EQUAL); |
| ValidateCompare<TypeParam>(eq, "[]", one, "[]"); |
| ValidateCompare<TypeParam>(eq, "[null]", one, "[null]"); |
| ValidateCompare<TypeParam>(eq, "[0,0,1,1,2,2]", one, "[0,0,1,1,0,0]"); |
| ValidateCompare<TypeParam>(eq, "[0,1,2,3,4,5]", one, "[0,1,0,0,0,0]"); |
| ValidateCompare<TypeParam>(eq, "[5,4,3,2,1,0]", one, "[0,0,0,0,1,0]"); |
| ValidateCompare<TypeParam>(eq, "[null,0,1,1]", one, "[null,0,1,1]"); |
| |
| CompareOptions neq(CompareOperator::NOT_EQUAL); |
| ValidateCompare<TypeParam>(neq, "[]", one, "[]"); |
| ValidateCompare<TypeParam>(neq, "[null]", one, "[null]"); |
| ValidateCompare<TypeParam>(neq, "[0,0,1,1,2,2]", one, "[1,1,0,0,1,1]"); |
| ValidateCompare<TypeParam>(neq, "[0,1,2,3,4,5]", one, "[1,0,1,1,1,1]"); |
| ValidateCompare<TypeParam>(neq, "[5,4,3,2,1,0]", one, "[1,1,1,1,0,1]"); |
| ValidateCompare<TypeParam>(neq, "[null,0,1,1]", one, "[null,1,0,0]"); |
| |
| CompareOptions gt(CompareOperator::GREATER); |
| ValidateCompare<TypeParam>(gt, "[]", one, "[]"); |
| ValidateCompare<TypeParam>(gt, "[null]", one, "[null]"); |
| ValidateCompare<TypeParam>(gt, "[0,0,1,1,2,2]", one, "[0,0,0,0,1,1]"); |
| ValidateCompare<TypeParam>(gt, "[0,1,2,3,4,5]", one, "[0,0,1,1,1,1]"); |
| ValidateCompare<TypeParam>(gt, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); |
| ValidateCompare<TypeParam>(gt, "[null,0,1,1]", one, "[null,0,0,0]"); |
| |
| CompareOptions gte(CompareOperator::GREATER_EQUAL); |
| ValidateCompare<TypeParam>(gte, "[]", one, "[]"); |
| ValidateCompare<TypeParam>(gte, "[null]", one, "[null]"); |
| ValidateCompare<TypeParam>(gte, "[0,0,1,1,2,2]", one, "[0,0,1,1,1,1]"); |
| ValidateCompare<TypeParam>(gte, "[0,1,2,3,4,5]", one, "[0,1,1,1,1,1]"); |
| ValidateCompare<TypeParam>(gte, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); |
| ValidateCompare<TypeParam>(gte, "[null,0,1,1]", one, "[null,0,1,1]"); |
| |
| CompareOptions lt(CompareOperator::LESS); |
| ValidateCompare<TypeParam>(lt, "[]", one, "[]"); |
| ValidateCompare<TypeParam>(lt, "[null]", one, "[null]"); |
| ValidateCompare<TypeParam>(lt, "[0,0,1,1,2,2]", one, "[1,1,0,0,0,0]"); |
| ValidateCompare<TypeParam>(lt, "[0,1,2,3,4,5]", one, "[1,0,0,0,0,0]"); |
| ValidateCompare<TypeParam>(lt, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); |
| ValidateCompare<TypeParam>(lt, "[null,0,1,1]", one, "[null,1,0,0]"); |
| |
| CompareOptions lte(CompareOperator::LESS_EQUAL); |
| ValidateCompare<TypeParam>(lte, "[]", one, "[]"); |
| ValidateCompare<TypeParam>(lte, "[null]", one, "[null]"); |
| ValidateCompare<TypeParam>(lte, "[0,0,1,1,2,2]", one, "[1,1,1,1,0,0]"); |
| ValidateCompare<TypeParam>(lte, "[0,1,2,3,4,5]", one, "[1,1,0,0,0,0]"); |
| ValidateCompare<TypeParam>(lte, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); |
| ValidateCompare<TypeParam>(lte, "[null,0,1,1]", one, "[null,1,1,1]"); |
| } |
| |
| TYPED_TEST(TestNumericCompareKernel, SimpleCompareScalarArray) { |
| using ScalarType = typename TypeTraits<TypeParam>::ScalarType; |
| using CType = typename TypeTraits<TypeParam>::CType; |
| |
| Datum one(std::make_shared<ScalarType>(CType(1))); |
| |
| CompareOptions eq(CompareOperator::EQUAL); |
| ValidateCompare<TypeParam>(eq, one, "[]", "[]"); |
| ValidateCompare<TypeParam>(eq, one, "[null]", "[null]"); |
| ValidateCompare<TypeParam>(eq, one, "[0,0,1,1,2,2]", "[0,0,1,1,0,0]"); |
| ValidateCompare<TypeParam>(eq, one, "[0,1,2,3,4,5]", "[0,1,0,0,0,0]"); |
| ValidateCompare<TypeParam>(eq, one, "[5,4,3,2,1,0]", "[0,0,0,0,1,0]"); |
| ValidateCompare<TypeParam>(eq, one, "[null,0,1,1]", "[null,0,1,1]"); |
| |
| CompareOptions neq(CompareOperator::NOT_EQUAL); |
| ValidateCompare<TypeParam>(neq, one, "[]", "[]"); |
| ValidateCompare<TypeParam>(neq, one, "[null]", "[null]"); |
| ValidateCompare<TypeParam>(neq, one, "[0,0,1,1,2,2]", "[1,1,0,0,1,1]"); |
| ValidateCompare<TypeParam>(neq, one, "[0,1,2,3,4,5]", "[1,0,1,1,1,1]"); |
| ValidateCompare<TypeParam>(neq, one, "[5,4,3,2,1,0]", "[1,1,1,1,0,1]"); |
| ValidateCompare<TypeParam>(neq, one, "[null,0,1,1]", "[null,1,0,0]"); |
| |
| CompareOptions gt(CompareOperator::GREATER); |
| ValidateCompare<TypeParam>(gt, one, "[]", "[]"); |
| ValidateCompare<TypeParam>(gt, one, "[null]", "[null]"); |
| ValidateCompare<TypeParam>(gt, one, "[0,0,1,1,2,2]", "[1,1,0,0,0,0]"); |
| ValidateCompare<TypeParam>(gt, one, "[0,1,2,3,4,5]", "[1,0,0,0,0,0]"); |
| ValidateCompare<TypeParam>(gt, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]"); |
| ValidateCompare<TypeParam>(gt, one, "[null,0,1,1]", "[null,1,0,0]"); |
| |
| CompareOptions gte(CompareOperator::GREATER_EQUAL); |
| ValidateCompare<TypeParam>(gte, one, "[]", "[]"); |
| ValidateCompare<TypeParam>(gte, one, "[null]", "[null]"); |
| ValidateCompare<TypeParam>(gte, one, "[0,0,1,1,2,2]", "[1,1,1,1,0,0]"); |
| ValidateCompare<TypeParam>(gte, one, "[0,1,2,3,4,5]", "[1,1,0,0,0,0]"); |
| ValidateCompare<TypeParam>(gte, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]"); |
| ValidateCompare<TypeParam>(gte, one, "[null,0,1,1]", "[null,1,1,1]"); |
| |
| CompareOptions lt(CompareOperator::LESS); |
| ValidateCompare<TypeParam>(lt, one, "[]", "[]"); |
| ValidateCompare<TypeParam>(lt, one, "[null]", "[null]"); |
| ValidateCompare<TypeParam>(lt, one, "[0,0,1,1,2,2]", "[0,0,0,0,1,1]"); |
| ValidateCompare<TypeParam>(lt, one, "[0,1,2,3,4,5]", "[0,0,1,1,1,1]"); |
| ValidateCompare<TypeParam>(lt, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]"); |
| ValidateCompare<TypeParam>(lt, one, "[null,0,1,1]", "[null,0,0,0]"); |
| |
| CompareOptions lte(CompareOperator::LESS_EQUAL); |
| ValidateCompare<TypeParam>(lte, one, "[]", "[]"); |
| ValidateCompare<TypeParam>(lte, one, "[null]", "[null]"); |
| ValidateCompare<TypeParam>(lte, one, "[0,0,1,1,2,2]", "[0,0,1,1,1,1]"); |
| ValidateCompare<TypeParam>(lte, one, "[0,1,2,3,4,5]", "[0,1,1,1,1,1]"); |
| ValidateCompare<TypeParam>(lte, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]"); |
| ValidateCompare<TypeParam>(lte, one, "[null,0,1,1]", "[null,0,1,1]"); |
| } |
| |
| TYPED_TEST(TestNumericCompareKernel, TestNullScalar) { |
| /* Ensure that null scalar broadcast to all null results. */ |
| using ScalarType = typename TypeTraits<TypeParam>::ScalarType; |
| |
| Datum null(std::make_shared<ScalarType>()); |
| EXPECT_FALSE(null.scalar()->is_valid); |
| |
| CompareOptions eq(CompareOperator::EQUAL); |
| ValidateCompare<TypeParam>(eq, "[]", null, "[]"); |
| ValidateCompare<TypeParam>(eq, null, "[]", "[]"); |
| ValidateCompare<TypeParam>(eq, "[null]", null, "[null]"); |
| ValidateCompare<TypeParam>(eq, null, "[null]", "[null]"); |
| ValidateCompare<TypeParam>(eq, null, "[1,2,3]", "[null, null, null]"); |
| } |
| |
| TYPED_TEST_SUITE(TestNumericCompareKernel, NumericArrowTypes); |
| |
| template <typename Type> |
| struct CompareRandomNumeric { |
| static void Test(const std::shared_ptr<DataType>& type) { |
| using ScalarType = typename TypeTraits<Type>::ScalarType; |
| using CType = typename TypeTraits<Type>::CType; |
| auto rand = random::RandomArrayGenerator(0x5416447); |
| const int64_t length = 1000; |
| for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { |
| for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { |
| auto data = |
| rand.Numeric<typename Type::PhysicalType>(length, 0, 100, null_probability); |
| |
| auto data1 = |
| rand.Numeric<typename Type::PhysicalType>(length, 0, 100, null_probability); |
| auto data2 = |
| rand.Numeric<typename Type::PhysicalType>(length, 0, 100, null_probability); |
| |
| // Create view of data as the type (e.g. timestamp) |
| auto array1 = Datum(*data1->View(type)); |
| auto array2 = Datum(*data2->View(type)); |
| auto fifty = Datum(std::make_shared<ScalarType>(CType(50), type)); |
| auto options = CompareOptions(op); |
| |
| ValidateCompare<Type>(options, array1, fifty); |
| ValidateCompare<Type>(options, fifty, array1); |
| ValidateCompare<Type>(options, array1, array2); |
| } |
| } |
| } |
| }; |
| |
| TEST(TestCompareKernel, PrimitiveRandomTests) { |
| TestRandomPrimitiveCTypes<CompareRandomNumeric>(); |
| } |
| |
| TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayArray) { |
| /* Ensure that null scalar broadcast to all null results. */ |
| CompareOptions eq(CompareOperator::EQUAL); |
| ValidateCompare<TypeParam>(eq, "[]", "[]", "[]"); |
| ValidateCompare<TypeParam>(eq, "[null]", "[null]", "[null]"); |
| ValidateCompare<TypeParam>(eq, "[1]", "[1]", "[1]"); |
| ValidateCompare<TypeParam>(eq, "[1]", "[2]", "[0]"); |
| ValidateCompare<TypeParam>(eq, "[null]", "[1]", "[null]"); |
| ValidateCompare<TypeParam>(eq, "[1]", "[null]", "[null]"); |
| |
| CompareOptions lte(CompareOperator::LESS_EQUAL); |
| ValidateCompare<TypeParam>(lte, "[1,2,3,4,5]", "[2,3,4,5,6]", "[1,1,1,1,1]"); |
| } |
| |
| TEST(TestCompareTimestamps, Basics) { |
| const char* example1_json = R"(["1970-01-01","2000-02-29","1900-02-28"])"; |
| const char* example2_json = R"(["1970-01-02","2000-02-01","1900-02-28"])"; |
| |
| auto CheckArrayCase = [&](std::shared_ptr<DataType> type, CompareOperator op, |
| const char* expected_json) { |
| auto lhs = ArrayFromJSON(type, example1_json); |
| auto rhs = ArrayFromJSON(type, example2_json); |
| auto expected = ArrayFromJSON(boolean(), expected_json); |
| ASSERT_OK_AND_ASSIGN(Datum result, Compare(lhs, rhs, CompareOptions(op))); |
| AssertArraysEqual(*expected, *result.make_array(), /*verbose=*/true); |
| }; |
| |
| auto seconds = timestamp(TimeUnit::SECOND); |
| auto millis = timestamp(TimeUnit::MILLI); |
| auto micros = timestamp(TimeUnit::MICRO); |
| auto nanos = timestamp(TimeUnit::NANO); |
| |
| CheckArrayCase(seconds, CompareOperator::EQUAL, "[false, false, true]"); |
| CheckArrayCase(seconds, CompareOperator::NOT_EQUAL, "[true, true, false]"); |
| CheckArrayCase(seconds, CompareOperator::LESS, "[true, false, false]"); |
| CheckArrayCase(seconds, CompareOperator::LESS_EQUAL, "[true, false, true]"); |
| CheckArrayCase(seconds, CompareOperator::GREATER, "[false, true, false]"); |
| CheckArrayCase(seconds, CompareOperator::GREATER_EQUAL, "[false, true, true]"); |
| |
| // Check that comparisons with tz-aware timestamps work fine |
| auto seconds_utc = timestamp(TimeUnit::SECOND, "utc"); |
| CheckArrayCase(seconds_utc, CompareOperator::EQUAL, "[false, false, true]"); |
| } |
| |
| TEST(TestCompareKernel, DispatchBest) { |
| for (std::string name : |
| {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { |
| CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); |
| CheckDispatchBest(name, {int32(), null()}, {int32(), int32()}); |
| CheckDispatchBest(name, {null(), int32()}, {int32(), int32()}); |
| |
| CheckDispatchBest(name, {int32(), int8()}, {int32(), int32()}); |
| CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()}); |
| CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); |
| CheckDispatchBest(name, {int32(), int64()}, {int64(), int64()}); |
| |
| CheckDispatchBest(name, {int32(), uint8()}, {int32(), int32()}); |
| CheckDispatchBest(name, {int32(), uint16()}, {int32(), int32()}); |
| CheckDispatchBest(name, {int32(), uint32()}, {int64(), int64()}); |
| CheckDispatchBest(name, {int32(), uint64()}, {int64(), int64()}); |
| |
| CheckDispatchBest(name, {uint8(), uint8()}, {uint8(), uint8()}); |
| CheckDispatchBest(name, {uint8(), uint16()}, {uint16(), uint16()}); |
| |
| CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); |
| CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); |
| CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()}); |
| |
| CheckDispatchBest(name, {dictionary(int8(), float64()), float64()}, |
| {float64(), float64()}); |
| CheckDispatchBest(name, {dictionary(int8(), float64()), int16()}, |
| {float64(), float64()}); |
| |
| CheckDispatchBest(name, {timestamp(TimeUnit::MICRO), date64()}, |
| {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}); |
| |
| CheckDispatchBest(name, {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MICRO)}, |
| {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}); |
| |
| CheckDispatchBest(name, {utf8(), binary()}, {binary(), binary()}); |
| CheckDispatchBest(name, {large_utf8(), binary()}, {large_binary(), large_binary()}); |
| } |
| } |
| |
| TEST(TestCompareKernel, GreaterWithImplicitCasts) { |
| CheckScalarBinary("greater", ArrayFromJSON(int32(), "[0, 1, 2, null]"), |
| ArrayFromJSON(float64(), "[0.5, 1.0, 1.5, 2.0]"), |
| ArrayFromJSON(boolean(), "[false, false, true, null]")); |
| |
| CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-16, 0, 16, null]"), |
| ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), |
| ArrayFromJSON(boolean(), "[false, false, true, null]")); |
| |
| CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-16, 0, 16, null]"), |
| ArrayFromJSON(uint8(), "[255, 254, 1, 0]"), |
| ArrayFromJSON(boolean(), "[false, false, true, null]")); |
| |
| CheckScalarBinary("greater", |
| ArrayFromJSON(dictionary(int32(), int32()), "[0, 1, 2, null]"), |
| ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), |
| ArrayFromJSON(boolean(), "[false, false, false, null]")); |
| |
| CheckScalarBinary("greater", ArrayFromJSON(int32(), "[0, 1, 2, null]"), |
| std::make_shared<NullArray>(4), |
| ArrayFromJSON(boolean(), "[null, null, null, null]")); |
| |
| CheckScalarBinary("greater", |
| ArrayFromJSON(timestamp(TimeUnit::SECOND), |
| R"(["1970-01-01","2000-02-29","1900-02-28"])"), |
| ArrayFromJSON(date64(), "[86400000, 0, 86400000]"), |
| ArrayFromJSON(boolean(), "[false, true, false]")); |
| |
| CheckScalarBinary("greater", |
| ArrayFromJSON(dictionary(int32(), int8()), "[3, -3, -28, null]"), |
| ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), |
| ArrayFromJSON(boolean(), "[false, false, false, null]")); |
| } |
| |
| TEST(TestCompareKernel, GreaterWithImplicitCastsUint64EdgeCase) { |
| // int64 is as wide as we can promote |
| CheckDispatchBest("greater", {int8(), uint64()}, {int64(), int64()}); |
| |
| // this works sometimes |
| CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-1]"), |
| ArrayFromJSON(uint64(), "[0]"), ArrayFromJSON(boolean(), "[false]")); |
| |
| // ... but it can result in impossible implicit casts in the presence of uint64, since |
| // some uint64 values cannot be cast to int64: |
| ASSERT_RAISES( |
| Invalid, |
| CallFunction("greater", {ArrayFromJSON(int64(), "[-1]"), |
| ArrayFromJSON(uint64(), "[18446744073709551615]")})); |
| } |
| |
| class TestStringCompareKernel : public ::testing::Test {}; |
| |
| TEST_F(TestStringCompareKernel, SimpleCompareArrayScalar) { |
| Datum one(std::make_shared<StringScalar>("one")); |
| |
| CompareOptions eq(CompareOperator::EQUAL); |
| ValidateCompare<StringType>(eq, "[]", one, "[]"); |
| ValidateCompare<StringType>(eq, "[null]", one, "[null]"); |
| ValidateCompare<StringType>(eq, R"(["zero","zero","one","one","two","two"])", one, |
| "[0,0,1,1,0,0]"); |
| ValidateCompare<StringType>(eq, R"(["zero","one","two","three","four","five"])", one, |
| "[0,1,0,0,0,0]"); |
| ValidateCompare<StringType>(eq, R"(["five","four","three","two","one","zero"])", one, |
| "[0,0,0,0,1,0]"); |
| ValidateCompare<StringType>(eq, R"([null,"zero","one","one"])", one, "[null,0,1,1]"); |
| |
| Datum na(std::make_shared<StringScalar>()); |
| ValidateCompare<StringType>(eq, R"([null,"zero","one","one"])", na, |
| "[null,null,null,null]"); |
| ValidateCompare<StringType>(eq, na, R"([null,"zero","one","one"])", |
| "[null,null,null,null]"); |
| |
| CompareOptions neq(CompareOperator::NOT_EQUAL); |
| ValidateCompare<StringType>(neq, "[]", one, "[]"); |
| ValidateCompare<StringType>(neq, "[null]", one, "[null]"); |
| ValidateCompare<StringType>(neq, R"(["zero","zero","one","one","two","two"])", one, |
| "[1,1,0,0,1,1]"); |
| ValidateCompare<StringType>(neq, R"(["zero","one","two","three","four","five"])", one, |
| "[1,0,1,1,1,1]"); |
| ValidateCompare<StringType>(neq, R"(["five","four","three","two","one","zero"])", one, |
| "[1,1,1,1,0,1]"); |
| ValidateCompare<StringType>(neq, R"([null,"zero","one","one"])", one, "[null,1,0,0]"); |
| |
| CompareOptions gt(CompareOperator::GREATER); |
| ValidateCompare<StringType>(gt, "[]", one, "[]"); |
| ValidateCompare<StringType>(gt, "[null]", one, "[null]"); |
| ValidateCompare<StringType>(gt, R"(["zero","zero","one","one","two","two"])", one, |
| "[1,1,0,0,1,1]"); |
| ValidateCompare<StringType>(gt, R"(["zero","one","two","three","four","five"])", one, |
| "[1,0,1,1,0,0]"); |
| ValidateCompare<StringType>(gt, R"(["four","five","six","seven","eight","nine"])", one, |
| "[0,0,1,1,0,0]"); |
| ValidateCompare<StringType>(gt, R"([null,"zero","one","one"])", one, "[null,1,0,0]"); |
| |
| CompareOptions gte(CompareOperator::GREATER_EQUAL); |
| ValidateCompare<StringType>(gte, "[]", one, "[]"); |
| ValidateCompare<StringType>(gte, "[null]", one, "[null]"); |
| ValidateCompare<StringType>(gte, R"(["zero","zero","one","one","two","two"])", one, |
| "[1,1,1,1,1,1]"); |
| ValidateCompare<StringType>(gte, R"(["zero","one","two","three","four","five"])", one, |
| "[1,1,1,1,0,0]"); |
| ValidateCompare<StringType>(gte, R"(["four","five","six","seven","eight","nine"])", one, |
| "[0,0,1,1,0,0]"); |
| ValidateCompare<StringType>(gte, R"([null,"zero","one","one"])", one, "[null,1,1,1]"); |
| |
| CompareOptions lt(CompareOperator::LESS); |
| ValidateCompare<StringType>(lt, "[]", one, "[]"); |
| ValidateCompare<StringType>(lt, "[null]", one, "[null]"); |
| ValidateCompare<StringType>(lt, R"(["zero","zero","one","one","two","two"])", one, |
| "[0,0,0,0,0,0]"); |
| ValidateCompare<StringType>(lt, R"(["zero","one","two","three","four","five"])", one, |
| "[0,0,0,0,1,1]"); |
| ValidateCompare<StringType>(lt, R"(["four","five","six","seven","eight","nine"])", one, |
| "[1,1,0,0,1,1]"); |
| ValidateCompare<StringType>(lt, R"([null,"zero","one","one"])", one, "[null,0,0,0]"); |
| |
| CompareOptions lte(CompareOperator::LESS_EQUAL); |
| ValidateCompare<StringType>(lte, "[]", one, "[]"); |
| ValidateCompare<StringType>(lte, "[null]", one, "[null]"); |
| ValidateCompare<StringType>(lte, R"(["zero","zero","one","one","two","two"])", one, |
| "[0,0,1,1,0,0]"); |
| ValidateCompare<StringType>(lte, R"(["zero","one","two","three","four","five"])", one, |
| "[0,1,0,0,1,1]"); |
| ValidateCompare<StringType>(lte, R"(["four","five","six","seven","eight","nine"])", one, |
| "[1,1,0,0,1,1]"); |
| ValidateCompare<StringType>(lte, R"([null,"zero","one","one"])", one, "[null,0,1,1]"); |
| } |
| |
| TEST_F(TestStringCompareKernel, RandomCompareArrayScalar) { |
| using ScalarType = typename TypeTraits<StringType>::ScalarType; |
| |
| auto rand = random::RandomArrayGenerator(0x5416447); |
| for (size_t i = 3; i < 10; i++) { |
| for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { |
| for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { |
| const int64_t length = static_cast<int64_t>(1ULL << i); |
| auto array = Datum(rand.String(length, 0, 16, null_probability)); |
| auto hello = Datum(std::make_shared<ScalarType>("hello")); |
| auto options = CompareOptions(op); |
| ValidateCompare<StringType>(options, array, hello); |
| ValidateCompare<StringType>(options, hello, array); |
| } |
| } |
| } |
| } |
| |
| TEST_F(TestStringCompareKernel, RandomCompareArrayArray) { |
| auto rand = random::RandomArrayGenerator(0x5416447); |
| for (size_t i = 3; i < 5; i++) { |
| for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { |
| for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { |
| auto length = static_cast<int64_t>(1ULL << i); |
| auto lhs = Datum(rand.String(length << i, 0, 16, null_probability)); |
| auto rhs = Datum(rand.String(length << i, 0, 16, null_probability)); |
| auto options = CompareOptions(op); |
| ValidateCompare<StringType>(options, lhs, rhs); |
| } |
| } |
| } |
| } |
| |
| } // namespace compute |
| } // namespace arrow |