| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| // Functions for comparing Arrow data structures |
| |
| #include "arrow/compare.h" |
| |
| #include <climits> |
| #include <cmath> |
| #include <cstdint> |
| #include <cstring> |
| #include <memory> |
| #include <string> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| #include "arrow/array.h" |
| #include "arrow/array/diff.h" |
| #include "arrow/buffer.h" |
| #include "arrow/scalar.h" |
| #include "arrow/sparse_tensor.h" |
| #include "arrow/status.h" |
| #include "arrow/tensor.h" |
| #include "arrow/type.h" |
| #include "arrow/type_traits.h" |
| #include "arrow/util/bit_run_reader.h" |
| #include "arrow/util/bit_util.h" |
| #include "arrow/util/bitmap_ops.h" |
| #include "arrow/util/bitmap_reader.h" |
| #include "arrow/util/checked_cast.h" |
| #include "arrow/util/logging.h" |
| #include "arrow/util/macros.h" |
| #include "arrow/util/memory.h" |
| #include "arrow/visitor_inline.h" |
| |
| namespace arrow { |
| |
| using internal::BitmapEquals; |
| using internal::BitmapReader; |
| using internal::BitmapUInt64Reader; |
| using internal::checked_cast; |
| using internal::OptionalBitmapEquals; |
| |
| // ---------------------------------------------------------------------- |
| // Public method implementations |
| |
| namespace { |
| |
| // TODO also handle HALF_FLOAT NaNs |
| |
| enum FloatingEqualityFlags : int8_t { Approximate = 1, NansEqual = 2 }; |
| |
| template <typename T, int8_t Flags> |
| struct FloatingEquality { |
| bool operator()(T x, T y) { return x == y; } |
| }; |
| |
| template <typename T> |
| struct FloatingEquality<T, NansEqual> { |
| bool operator()(T x, T y) { return (x == y) || (std::isnan(x) && std::isnan(y)); } |
| }; |
| |
| template <typename T> |
| struct FloatingEquality<T, Approximate> { |
| explicit FloatingEquality(const EqualOptions& options) |
| : epsilon(static_cast<T>(options.atol())) {} |
| |
| bool operator()(T x, T y) { return (fabs(x - y) <= epsilon) || (x == y); } |
| |
| const T epsilon; |
| }; |
| |
| template <typename T> |
| struct FloatingEquality<T, Approximate | NansEqual> { |
| explicit FloatingEquality(const EqualOptions& options) |
| : epsilon(static_cast<T>(options.atol())) {} |
| |
| bool operator()(T x, T y) { |
| return (fabs(x - y) <= epsilon) || (x == y) || (std::isnan(x) && std::isnan(y)); |
| } |
| |
| const T epsilon; |
| }; |
| |
| template <typename T, typename Visitor> |
| void VisitFloatingEquality(const EqualOptions& options, bool floating_approximate, |
| Visitor&& visit) { |
| if (options.nans_equal()) { |
| if (floating_approximate) { |
| visit(FloatingEquality<T, NansEqual | Approximate>{options}); |
| } else { |
| visit(FloatingEquality<T, NansEqual>{}); |
| } |
| } else { |
| if (floating_approximate) { |
| visit(FloatingEquality<T, Approximate>{options}); |
| } else { |
| visit(FloatingEquality<T, 0>{}); |
| } |
| } |
| } |
| |
| inline bool IdentityImpliesEqualityNansNotEqual(const DataType& type) { |
| if (type.id() == Type::FLOAT || type.id() == Type::DOUBLE) { |
| return false; |
| } |
| for (const auto& child : type.fields()) { |
| if (!IdentityImpliesEqualityNansNotEqual(*child->type())) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| inline bool IdentityImpliesEquality(const DataType& type, const EqualOptions& options) { |
| if (options.nans_equal()) { |
| return true; |
| } |
| return IdentityImpliesEqualityNansNotEqual(type); |
| } |
| |
| bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, |
| int64_t left_start_idx, int64_t left_end_idx, |
| int64_t right_start_idx, const EqualOptions& options, |
| bool floating_approximate); |
| |
| class RangeDataEqualsImpl { |
| public: |
| // PRE-CONDITIONS: |
| // - the types are equal |
| // - the ranges are in bounds |
| RangeDataEqualsImpl(const EqualOptions& options, bool floating_approximate, |
| const ArrayData& left, const ArrayData& right, |
| int64_t left_start_idx, int64_t right_start_idx, |
| int64_t range_length) |
| : options_(options), |
| floating_approximate_(floating_approximate), |
| left_(left), |
| right_(right), |
| left_start_idx_(left_start_idx), |
| right_start_idx_(right_start_idx), |
| range_length_(range_length), |
| result_(false) {} |
| |
| bool Compare() { |
| // Compare null bitmaps |
| if (left_start_idx_ == 0 && right_start_idx_ == 0 && range_length_ == left_.length && |
| range_length_ == right_.length) { |
| // If we're comparing entire arrays, we can first compare the cached null counts |
| if (left_.GetNullCount() != right_.GetNullCount()) { |
| return false; |
| } |
| } |
| if (!OptionalBitmapEquals(left_.buffers[0], left_.offset + left_start_idx_, |
| right_.buffers[0], right_.offset + right_start_idx_, |
| range_length_)) { |
| return false; |
| } |
| // Compare values |
| return CompareWithType(*left_.type); |
| } |
| |
| bool CompareWithType(const DataType& type) { |
| result_ = true; |
| if (range_length_ != 0) { |
| ARROW_CHECK_OK(VisitTypeInline(type, this)); |
| } |
| return result_; |
| } |
| |
| Status Visit(const NullType&) { return Status::OK(); } |
| |
| template <typename TypeClass> |
| enable_if_primitive_ctype<TypeClass, Status> Visit(const TypeClass& type) { |
| return ComparePrimitive(type); |
| } |
| |
| template <typename TypeClass> |
| enable_if_t<is_temporal_type<TypeClass>::value, Status> Visit(const TypeClass& type) { |
| return ComparePrimitive(type); |
| } |
| |
| Status Visit(const BooleanType&) { |
| const uint8_t* left_bits = left_.GetValues<uint8_t>(1, 0); |
| const uint8_t* right_bits = right_.GetValues<uint8_t>(1, 0); |
| auto compare_runs = [&](int64_t i, int64_t length) -> bool { |
| if (length <= 8) { |
| // Avoid the BitmapUInt64Reader overhead for very small runs |
| for (int64_t j = i; j < i + length; ++j) { |
| if (BitUtil::GetBit(left_bits, left_start_idx_ + left_.offset + j) != |
| BitUtil::GetBit(right_bits, right_start_idx_ + right_.offset + j)) { |
| return false; |
| } |
| } |
| return true; |
| } else if (length <= 1024) { |
| BitmapUInt64Reader left_reader(left_bits, left_start_idx_ + left_.offset + i, |
| length); |
| BitmapUInt64Reader right_reader(right_bits, right_start_idx_ + right_.offset + i, |
| length); |
| while (left_reader.position() < length) { |
| if (left_reader.NextWord() != right_reader.NextWord()) { |
| return false; |
| } |
| } |
| DCHECK_EQ(right_reader.position(), length); |
| } else { |
| // BitmapEquals is the fastest method on large runs |
| return BitmapEquals(left_bits, left_start_idx_ + left_.offset + i, right_bits, |
| right_start_idx_ + right_.offset + i, length); |
| } |
| return true; |
| }; |
| VisitValidRuns(compare_runs); |
| return Status::OK(); |
| } |
| |
| Status Visit(const FloatType& type) { return CompareFloating(type); } |
| |
| Status Visit(const DoubleType& type) { return CompareFloating(type); } |
| |
| // Also matches StringType |
| Status Visit(const BinaryType& type) { return CompareBinary(type); } |
| |
| // Also matches LargeStringType |
| Status Visit(const LargeBinaryType& type) { return CompareBinary(type); } |
| |
| Status Visit(const FixedSizeBinaryType& type) { |
| const auto byte_width = type.byte_width(); |
| const uint8_t* left_data = left_.GetValues<uint8_t>(1, 0); |
| const uint8_t* right_data = right_.GetValues<uint8_t>(1, 0); |
| |
| if (left_data != nullptr && right_data != nullptr) { |
| auto compare_runs = [&](int64_t i, int64_t length) -> bool { |
| return memcmp(left_data + (left_start_idx_ + left_.offset + i) * byte_width, |
| right_data + (right_start_idx_ + right_.offset + i) * byte_width, |
| length * byte_width) == 0; |
| }; |
| VisitValidRuns(compare_runs); |
| } else { |
| auto compare_runs = [&](int64_t i, int64_t length) -> bool { return true; }; |
| VisitValidRuns(compare_runs); |
| } |
| return Status::OK(); |
| } |
| |
| // Also matches MapType |
| Status Visit(const ListType& type) { return CompareList(type); } |
| |
| Status Visit(const LargeListType& type) { return CompareList(type); } |
| |
| Status Visit(const FixedSizeListType& type) { |
| const auto list_size = type.list_size(); |
| const ArrayData& left_data = *left_.child_data[0]; |
| const ArrayData& right_data = *right_.child_data[0]; |
| |
| auto compare_runs = [&](int64_t i, int64_t length) -> bool { |
| RangeDataEqualsImpl impl(options_, floating_approximate_, left_data, right_data, |
| (left_start_idx_ + left_.offset + i) * list_size, |
| (right_start_idx_ + right_.offset + i) * list_size, |
| length * list_size); |
| return impl.Compare(); |
| }; |
| VisitValidRuns(compare_runs); |
| return Status::OK(); |
| } |
| |
| Status Visit(const StructType& type) { |
| const int32_t num_fields = type.num_fields(); |
| |
| auto compare_runs = [&](int64_t i, int64_t length) -> bool { |
| for (int32_t f = 0; f < num_fields; ++f) { |
| RangeDataEqualsImpl impl(options_, floating_approximate_, *left_.child_data[f], |
| *right_.child_data[f], |
| left_start_idx_ + left_.offset + i, |
| right_start_idx_ + right_.offset + i, length); |
| if (!impl.Compare()) { |
| return false; |
| } |
| } |
| return true; |
| }; |
| VisitValidRuns(compare_runs); |
| return Status::OK(); |
| } |
| |
| Status Visit(const SparseUnionType& type) { |
| const auto& child_ids = type.child_ids(); |
| const int8_t* left_codes = left_.GetValues<int8_t>(1); |
| const int8_t* right_codes = right_.GetValues<int8_t>(1); |
| |
| // Unions don't have a null bitmap |
| for (int64_t i = 0; i < range_length_; ++i) { |
| const auto type_id = left_codes[left_start_idx_ + i]; |
| if (type_id != right_codes[right_start_idx_ + i]) { |
| result_ = false; |
| break; |
| } |
| const auto child_num = child_ids[type_id]; |
| // XXX can we instead detect runs of same-child union values? |
| RangeDataEqualsImpl impl( |
| options_, floating_approximate_, *left_.child_data[child_num], |
| *right_.child_data[child_num], left_start_idx_ + left_.offset + i, |
| right_start_idx_ + right_.offset + i, 1); |
| if (!impl.Compare()) { |
| result_ = false; |
| break; |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status Visit(const DenseUnionType& type) { |
| const auto& child_ids = type.child_ids(); |
| const int8_t* left_codes = left_.GetValues<int8_t>(1); |
| const int8_t* right_codes = right_.GetValues<int8_t>(1); |
| const int32_t* left_offsets = left_.GetValues<int32_t>(2); |
| const int32_t* right_offsets = right_.GetValues<int32_t>(2); |
| |
| for (int64_t i = 0; i < range_length_; ++i) { |
| const auto type_id = left_codes[left_start_idx_ + i]; |
| if (type_id != right_codes[right_start_idx_ + i]) { |
| result_ = false; |
| break; |
| } |
| const auto child_num = child_ids[type_id]; |
| RangeDataEqualsImpl impl( |
| options_, floating_approximate_, *left_.child_data[child_num], |
| *right_.child_data[child_num], left_offsets[left_start_idx_ + i], |
| right_offsets[right_start_idx_ + i], 1); |
| if (!impl.Compare()) { |
| result_ = false; |
| break; |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status Visit(const DictionaryType& type) { |
| // Compare dictionaries |
| result_ &= CompareArrayRanges( |
| *left_.dictionary, *right_.dictionary, |
| /*left_start_idx=*/0, |
| /*left_end_idx=*/std::max(left_.dictionary->length, right_.dictionary->length), |
| /*right_start_idx=*/0, options_, floating_approximate_); |
| if (result_) { |
| // Compare indices |
| result_ &= CompareWithType(*type.index_type()); |
| } |
| return Status::OK(); |
| } |
| |
| Status Visit(const ExtensionType& type) { |
| // Compare storages |
| result_ &= CompareWithType(*type.storage_type()); |
| return Status::OK(); |
| } |
| |
| protected: |
| // For CompareFloating (templated local classes or lambdas not supported in C++11) |
| template <typename CType> |
| struct ComparatorVisitor { |
| RangeDataEqualsImpl* impl; |
| const CType* left_values; |
| const CType* right_values; |
| |
| template <typename CompareFunction> |
| void operator()(CompareFunction&& compare) { |
| impl->VisitValues([&](int64_t i) { |
| const CType x = left_values[i + impl->left_start_idx_]; |
| const CType y = right_values[i + impl->right_start_idx_]; |
| return compare(x, y); |
| }); |
| } |
| }; |
| |
| template <typename CType> |
| friend struct ComparatorVisitor; |
| |
| template <typename TypeClass, typename CType = typename TypeClass::c_type> |
| Status ComparePrimitive(const TypeClass&) { |
| const CType* left_values = left_.GetValues<CType>(1); |
| const CType* right_values = right_.GetValues<CType>(1); |
| VisitValidRuns([&](int64_t i, int64_t length) { |
| return memcmp(left_values + left_start_idx_ + i, |
| right_values + right_start_idx_ + i, length * sizeof(CType)) == 0; |
| }); |
| return Status::OK(); |
| } |
| |
| template <typename TypeClass> |
| Status CompareFloating(const TypeClass&) { |
| using CType = typename TypeClass::c_type; |
| const CType* left_values = left_.GetValues<CType>(1); |
| const CType* right_values = right_.GetValues<CType>(1); |
| |
| ComparatorVisitor<CType> visitor{this, left_values, right_values}; |
| VisitFloatingEquality<CType>(options_, floating_approximate_, visitor); |
| return Status::OK(); |
| } |
| |
| template <typename TypeClass> |
| Status CompareBinary(const TypeClass&) { |
| const uint8_t* left_data = left_.GetValues<uint8_t>(2, 0); |
| const uint8_t* right_data = right_.GetValues<uint8_t>(2, 0); |
| |
| if (left_data != nullptr && right_data != nullptr) { |
| const auto compare_ranges = [&](int64_t left_offset, int64_t right_offset, |
| int64_t length) -> bool { |
| return memcmp(left_data + left_offset, right_data + right_offset, length) == 0; |
| }; |
| CompareWithOffsets<typename TypeClass::offset_type>(1, compare_ranges); |
| } else { |
| // One of the arrays is an array of empty strings and nulls. |
| // We just need to compare the offsets. |
| // (note we must not call memcmp() with null data pointers) |
| CompareWithOffsets<typename TypeClass::offset_type>(1, [](...) { return true; }); |
| } |
| return Status::OK(); |
| } |
| |
| template <typename TypeClass> |
| Status CompareList(const TypeClass&) { |
| const ArrayData& left_data = *left_.child_data[0]; |
| const ArrayData& right_data = *right_.child_data[0]; |
| |
| const auto compare_ranges = [&](int64_t left_offset, int64_t right_offset, |
| int64_t length) -> bool { |
| RangeDataEqualsImpl impl(options_, floating_approximate_, left_data, right_data, |
| left_offset, right_offset, length); |
| return impl.Compare(); |
| }; |
| |
| CompareWithOffsets<typename TypeClass::offset_type>(1, compare_ranges); |
| return Status::OK(); |
| } |
| |
| template <typename offset_type, typename CompareRanges> |
| void CompareWithOffsets(int offsets_buffer_index, CompareRanges&& compare_ranges) { |
| const offset_type* left_offsets = |
| left_.GetValues<offset_type>(offsets_buffer_index) + left_start_idx_; |
| const offset_type* right_offsets = |
| right_.GetValues<offset_type>(offsets_buffer_index) + right_start_idx_; |
| |
| const auto compare_runs = [&](int64_t i, int64_t length) { |
| for (int64_t j = i; j < i + length; ++j) { |
| if (left_offsets[j + 1] - left_offsets[j] != |
| right_offsets[j + 1] - right_offsets[j]) { |
| return false; |
| } |
| } |
| if (!compare_ranges(left_offsets[i], right_offsets[i], |
| left_offsets[i + length] - left_offsets[i])) { |
| return false; |
| } |
| return true; |
| }; |
| |
| VisitValidRuns(compare_runs); |
| } |
| |
| template <typename CompareValues> |
| void VisitValues(CompareValues&& compare_values) { |
| internal::VisitSetBitRunsVoid(left_.buffers[0], left_.offset + left_start_idx_, |
| range_length_, [&](int64_t position, int64_t length) { |
| for (int64_t i = 0; i < length; ++i) { |
| result_ &= compare_values(position + i); |
| } |
| }); |
| } |
| |
| // Visit and compare runs of non-null values |
| template <typename CompareRuns> |
| void VisitValidRuns(CompareRuns&& compare_runs) { |
| const uint8_t* left_null_bitmap = left_.GetValues<uint8_t>(0, 0); |
| if (left_null_bitmap == nullptr) { |
| result_ = compare_runs(0, range_length_); |
| return; |
| } |
| internal::SetBitRunReader reader(left_null_bitmap, left_.offset + left_start_idx_, |
| range_length_); |
| while (true) { |
| const auto run = reader.NextRun(); |
| if (run.length == 0) { |
| return; |
| } |
| if (!compare_runs(run.position, run.length)) { |
| result_ = false; |
| return; |
| } |
| } |
| } |
| |
| const EqualOptions& options_; |
| const bool floating_approximate_; |
| const ArrayData& left_; |
| const ArrayData& right_; |
| const int64_t left_start_idx_; |
| const int64_t right_start_idx_; |
| const int64_t range_length_; |
| |
| bool result_; |
| }; |
| |
| bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, |
| int64_t left_start_idx, int64_t left_end_idx, |
| int64_t right_start_idx, const EqualOptions& options, |
| bool floating_approximate) { |
| if (left.type->id() != right.type->id() || |
| !TypeEquals(*left.type, *right.type, false /* check_metadata */)) { |
| return false; |
| } |
| |
| const int64_t range_length = left_end_idx - left_start_idx; |
| DCHECK_GE(range_length, 0); |
| if (left_start_idx + range_length > left.length) { |
| // Left range too small |
| return false; |
| } |
| if (right_start_idx + range_length > right.length) { |
| // Right range too small |
| return false; |
| } |
| if (&left == &right && left_start_idx == right_start_idx && |
| IdentityImpliesEquality(*left.type, options)) { |
| return true; |
| } |
| // Compare values |
| RangeDataEqualsImpl impl(options, floating_approximate, left, right, left_start_idx, |
| right_start_idx, range_length); |
| return impl.Compare(); |
| } |
| |
| class TypeEqualsVisitor { |
| public: |
| explicit TypeEqualsVisitor(const DataType& right, bool check_metadata) |
| : right_(right), check_metadata_(check_metadata), result_(false) {} |
| |
| Status VisitChildren(const DataType& left) { |
| if (left.num_fields() != right_.num_fields()) { |
| result_ = false; |
| return Status::OK(); |
| } |
| |
| for (int i = 0; i < left.num_fields(); ++i) { |
| if (!left.field(i)->Equals(right_.field(i), check_metadata_)) { |
| result_ = false; |
| return Status::OK(); |
| } |
| } |
| result_ = true; |
| return Status::OK(); |
| } |
| |
| template <typename T> |
| enable_if_t<is_null_type<T>::value || is_primitive_ctype<T>::value || |
| is_base_binary_type<T>::value, |
| Status> |
| Visit(const T&) { |
| result_ = true; |
| return Status::OK(); |
| } |
| |
| template <typename T> |
| enable_if_interval<T, Status> Visit(const T& left) { |
| const auto& right = checked_cast<const IntervalType&>(right_); |
| result_ = right.interval_type() == left.interval_type(); |
| return Status::OK(); |
| } |
| |
| template <typename T> |
| enable_if_t<is_time_type<T>::value || is_date_type<T>::value || |
| is_duration_type<T>::value, |
| Status> |
| Visit(const T& left) { |
| const auto& right = checked_cast<const T&>(right_); |
| result_ = left.unit() == right.unit(); |
| return Status::OK(); |
| } |
| |
| Status Visit(const TimestampType& left) { |
| const auto& right = checked_cast<const TimestampType&>(right_); |
| result_ = left.unit() == right.unit() && left.timezone() == right.timezone(); |
| return Status::OK(); |
| } |
| |
| Status Visit(const FixedSizeBinaryType& left) { |
| const auto& right = checked_cast<const FixedSizeBinaryType&>(right_); |
| result_ = left.byte_width() == right.byte_width(); |
| return Status::OK(); |
| } |
| |
| Status Visit(const Decimal128Type& left) { |
| const auto& right = checked_cast<const Decimal128Type&>(right_); |
| result_ = left.precision() == right.precision() && left.scale() == right.scale(); |
| return Status::OK(); |
| } |
| |
| Status Visit(const Decimal256Type& left) { |
| const auto& right = checked_cast<const Decimal256Type&>(right_); |
| result_ = left.precision() == right.precision() && left.scale() == right.scale(); |
| return Status::OK(); |
| } |
| |
| template <typename T> |
| enable_if_t<is_list_like_type<T>::value || is_struct_type<T>::value, Status> Visit( |
| const T& left) { |
| return VisitChildren(left); |
| } |
| |
| Status Visit(const MapType& left) { |
| const auto& right = checked_cast<const MapType&>(right_); |
| if (left.keys_sorted() != right.keys_sorted()) { |
| result_ = false; |
| return Status::OK(); |
| } |
| result_ = left.key_type()->Equals(*right.key_type(), check_metadata_) && |
| left.item_type()->Equals(*right.item_type(), check_metadata_); |
| return Status::OK(); |
| } |
| |
| Status Visit(const UnionType& left) { |
| const auto& right = checked_cast<const UnionType&>(right_); |
| |
| if (left.mode() != right.mode() || left.type_codes() != right.type_codes()) { |
| result_ = false; |
| return Status::OK(); |
| } |
| |
| result_ = std::equal( |
| left.fields().begin(), left.fields().end(), right.fields().begin(), |
| [this](const std::shared_ptr<Field>& l, const std::shared_ptr<Field>& r) { |
| return l->Equals(r, check_metadata_); |
| }); |
| return Status::OK(); |
| } |
| |
| Status Visit(const DictionaryType& left) { |
| const auto& right = checked_cast<const DictionaryType&>(right_); |
| result_ = left.index_type()->Equals(right.index_type()) && |
| left.value_type()->Equals(right.value_type()) && |
| (left.ordered() == right.ordered()); |
| return Status::OK(); |
| } |
| |
| Status Visit(const ExtensionType& left) { |
| result_ = left.ExtensionEquals(static_cast<const ExtensionType&>(right_)); |
| return Status::OK(); |
| } |
| |
| bool result() const { return result_; } |
| |
| protected: |
| const DataType& right_; |
| bool check_metadata_; |
| bool result_; |
| }; |
| |
| bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts, |
| bool floating_approximate); |
| bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options, |
| bool floating_approximate); |
| |
| class ScalarEqualsVisitor { |
| public: |
| // PRE-CONDITIONS: |
| // - the types are equal |
| // - the scalars are non-null |
| explicit ScalarEqualsVisitor(const Scalar& right, const EqualOptions& opts, |
| bool floating_approximate) |
| : right_(right), |
| options_(opts), |
| floating_approximate_(floating_approximate), |
| result_(false) {} |
| |
| Status Visit(const NullScalar& left) { |
| result_ = true; |
| return Status::OK(); |
| } |
| |
| Status Visit(const BooleanScalar& left) { |
| const auto& right = checked_cast<const BooleanScalar&>(right_); |
| result_ = left.value == right.value; |
| return Status::OK(); |
| } |
| |
| template <typename T> |
| typename std::enable_if<(is_primitive_ctype<typename T::TypeClass>::value || |
| is_temporal_type<typename T::TypeClass>::value), |
| Status>::type |
| Visit(const T& left_) { |
| const auto& right = checked_cast<const T&>(right_); |
| result_ = right.value == left_.value; |
| return Status::OK(); |
| } |
| |
| Status Visit(const FloatScalar& left) { return CompareFloating(left); } |
| |
| Status Visit(const DoubleScalar& left) { return CompareFloating(left); } |
| |
| template <typename T> |
| typename std::enable_if<std::is_base_of<BaseBinaryScalar, T>::value, Status>::type |
| Visit(const T& left) { |
| const auto& right = checked_cast<const BaseBinaryScalar&>(right_); |
| result_ = internal::SharedPtrEquals(left.value, right.value); |
| return Status::OK(); |
| } |
| |
| Status Visit(const Decimal128Scalar& left) { |
| const auto& right = checked_cast<const Decimal128Scalar&>(right_); |
| result_ = left.value == right.value; |
| return Status::OK(); |
| } |
| |
| Status Visit(const Decimal256Scalar& left) { |
| const auto& right = checked_cast<const Decimal256Scalar&>(right_); |
| result_ = left.value == right.value; |
| return Status::OK(); |
| } |
| |
| Status Visit(const ListScalar& left) { |
| const auto& right = checked_cast<const ListScalar&>(right_); |
| result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); |
| return Status::OK(); |
| } |
| |
| Status Visit(const LargeListScalar& left) { |
| const auto& right = checked_cast<const LargeListScalar&>(right_); |
| result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); |
| return Status::OK(); |
| } |
| |
| Status Visit(const MapScalar& left) { |
| const auto& right = checked_cast<const MapScalar&>(right_); |
| result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); |
| return Status::OK(); |
| } |
| |
| Status Visit(const FixedSizeListScalar& left) { |
| const auto& right = checked_cast<const FixedSizeListScalar&>(right_); |
| result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); |
| return Status::OK(); |
| } |
| |
| Status Visit(const StructScalar& left) { |
| const auto& right = checked_cast<const StructScalar&>(right_); |
| |
| if (right.value.size() != left.value.size()) { |
| result_ = false; |
| } else { |
| bool all_equals = true; |
| for (size_t i = 0; i < left.value.size() && all_equals; i++) { |
| all_equals &= ScalarEquals(*left.value[i], *right.value[i], options_, |
| floating_approximate_); |
| } |
| result_ = all_equals; |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status Visit(const UnionScalar& left) { |
| const auto& right = checked_cast<const UnionScalar&>(right_); |
| if (left.is_valid && right.is_valid) { |
| result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); |
| } else if (!left.is_valid && !right.is_valid) { |
| result_ = true; |
| } else { |
| result_ = false; |
| } |
| return Status::OK(); |
| } |
| |
| Status Visit(const DictionaryScalar& left) { |
| const auto& right = checked_cast<const DictionaryScalar&>(right_); |
| result_ = ScalarEquals(*left.value.index, *right.value.index, options_, |
| floating_approximate_) && |
| ArrayEquals(*left.value.dictionary, *right.value.dictionary, options_, |
| floating_approximate_); |
| return Status::OK(); |
| } |
| |
| Status Visit(const ExtensionScalar& left) { |
| return Status::NotImplemented("extension"); |
| } |
| |
| bool result() const { return result_; } |
| |
| protected: |
| // For CompareFloating (templated local classes or lambdas not supported in C++11) |
| template <typename ScalarType> |
| struct ComparatorVisitor { |
| const ScalarType& left; |
| const ScalarType& right; |
| bool* result; |
| |
| template <typename CompareFunction> |
| void operator()(CompareFunction&& compare) { |
| *result = compare(left.value, right.value); |
| } |
| }; |
| |
| template <typename ScalarType> |
| Status CompareFloating(const ScalarType& left) { |
| using CType = decltype(left.value); |
| |
| ComparatorVisitor<ScalarType> visitor{left, checked_cast<const ScalarType&>(right_), |
| &result_}; |
| VisitFloatingEquality<CType>(options_, floating_approximate_, visitor); |
| return Status::OK(); |
| } |
| |
| const Scalar& right_; |
| const EqualOptions options_; |
| const bool floating_approximate_; |
| bool result_; |
| }; |
| |
| Status PrintDiff(const Array& left, const Array& right, std::ostream* os); |
| |
| Status PrintDiff(const Array& left, const Array& right, int64_t left_offset, |
| int64_t left_length, int64_t right_offset, int64_t right_length, |
| std::ostream* os) { |
| if (os == nullptr) { |
| return Status::OK(); |
| } |
| |
| if (!left.type()->Equals(right.type())) { |
| *os << "# Array types differed: " << *left.type() << " vs " << *right.type() |
| << std::endl; |
| return Status::OK(); |
| } |
| |
| if (left.type()->id() == Type::DICTIONARY) { |
| *os << "# Dictionary arrays differed" << std::endl; |
| |
| const auto& left_dict = checked_cast<const DictionaryArray&>(left); |
| const auto& right_dict = checked_cast<const DictionaryArray&>(right); |
| |
| *os << "## dictionary diff"; |
| auto pos = os->tellp(); |
| RETURN_NOT_OK(PrintDiff(*left_dict.dictionary(), *right_dict.dictionary(), os)); |
| if (os->tellp() == pos) { |
| *os << std::endl; |
| } |
| |
| *os << "## indices diff"; |
| pos = os->tellp(); |
| RETURN_NOT_OK(PrintDiff(*left_dict.indices(), *right_dict.indices(), os)); |
| if (os->tellp() == pos) { |
| *os << std::endl; |
| } |
| return Status::OK(); |
| } |
| |
| const auto left_slice = left.Slice(left_offset, left_length); |
| const auto right_slice = right.Slice(right_offset, right_length); |
| ARROW_ASSIGN_OR_RAISE(auto edits, |
| Diff(*left_slice, *right_slice, default_memory_pool())); |
| ARROW_ASSIGN_OR_RAISE(auto formatter, MakeUnifiedDiffFormatter(*left.type(), os)); |
| return formatter(*edits, *left_slice, *right_slice); |
| } |
| |
| Status PrintDiff(const Array& left, const Array& right, std::ostream* os) { |
| return PrintDiff(left, right, 0, left.length(), 0, right.length(), os); |
| } |
| |
| bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx, |
| int64_t left_end_idx, int64_t right_start_idx, |
| const EqualOptions& options, bool floating_approximate) { |
| bool are_equal = |
| CompareArrayRanges(*left.data(), *right.data(), left_start_idx, left_end_idx, |
| right_start_idx, options, floating_approximate); |
| if (!are_equal) { |
| ARROW_IGNORE_EXPR(PrintDiff( |
| left, right, left_start_idx, left_end_idx, right_start_idx, |
| right_start_idx + (left_end_idx - left_start_idx), options.diff_sink())); |
| } |
| return are_equal; |
| } |
| |
| bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts, |
| bool floating_approximate) { |
| if (left.length() != right.length()) { |
| ARROW_IGNORE_EXPR(PrintDiff(left, right, opts.diff_sink())); |
| return false; |
| } |
| return ArrayRangeEquals(left, right, 0, left.length(), 0, opts, floating_approximate); |
| } |
| |
| bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options, |
| bool floating_approximate) { |
| if (&left == &right && IdentityImpliesEquality(*left.type, options)) { |
| return true; |
| } |
| if (!left.type->Equals(right.type)) { |
| return false; |
| } |
| if (left.is_valid != right.is_valid) { |
| return false; |
| } |
| if (!left.is_valid) { |
| return true; |
| } |
| ScalarEqualsVisitor visitor(right, options, floating_approximate); |
| auto error = VisitScalarInline(left, &visitor); |
| DCHECK_OK(error); |
| return visitor.result(); |
| } |
| |
| } // namespace |
| |
| bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx, |
| int64_t left_end_idx, int64_t right_start_idx, |
| const EqualOptions& options) { |
| const bool floating_approximate = false; |
| return ArrayRangeEquals(left, right, left_start_idx, left_end_idx, right_start_idx, |
| options, floating_approximate); |
| } |
| |
| bool ArrayRangeApproxEquals(const Array& left, const Array& right, int64_t left_start_idx, |
| int64_t left_end_idx, int64_t right_start_idx, |
| const EqualOptions& options) { |
| const bool floating_approximate = true; |
| return ArrayRangeEquals(left, right, left_start_idx, left_end_idx, right_start_idx, |
| options, floating_approximate); |
| } |
| |
| bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts) { |
| const bool floating_approximate = false; |
| return ArrayEquals(left, right, opts, floating_approximate); |
| } |
| |
| bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions& opts) { |
| const bool floating_approximate = true; |
| return ArrayEquals(left, right, opts, floating_approximate); |
| } |
| |
| bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options) { |
| const bool floating_approximate = false; |
| return ScalarEquals(left, right, options, floating_approximate); |
| } |
| |
| bool ScalarApproxEquals(const Scalar& left, const Scalar& right, |
| const EqualOptions& options) { |
| const bool floating_approximate = true; |
| return ScalarEquals(left, right, options, floating_approximate); |
| } |
| |
| namespace { |
| |
| bool StridedIntegerTensorContentEquals(const int dim_index, int64_t left_offset, |
| int64_t right_offset, int elem_size, |
| const Tensor& left, const Tensor& right) { |
| const auto n = left.shape()[dim_index]; |
| const auto left_stride = left.strides()[dim_index]; |
| const auto right_stride = right.strides()[dim_index]; |
| if (dim_index == left.ndim() - 1) { |
| for (int64_t i = 0; i < n; ++i) { |
| if (memcmp(left.raw_data() + left_offset + i * left_stride, |
| right.raw_data() + right_offset + i * right_stride, elem_size) != 0) { |
| return false; |
| } |
| } |
| return true; |
| } |
| for (int64_t i = 0; i < n; ++i) { |
| if (!StridedIntegerTensorContentEquals(dim_index + 1, left_offset, right_offset, |
| elem_size, left, right)) { |
| return false; |
| } |
| left_offset += left_stride; |
| right_offset += right_stride; |
| } |
| return true; |
| } |
| |
| bool IntegerTensorEquals(const Tensor& left, const Tensor& right) { |
| bool are_equal; |
| // The arrays are the same object |
| if (&left == &right) { |
| are_equal = true; |
| } else { |
| const bool left_row_major_p = left.is_row_major(); |
| const bool left_column_major_p = left.is_column_major(); |
| const bool right_row_major_p = right.is_row_major(); |
| const bool right_column_major_p = right.is_column_major(); |
| |
| if (!(left_row_major_p && right_row_major_p) && |
| !(left_column_major_p && right_column_major_p)) { |
| const auto& type = checked_cast<const FixedWidthType&>(*left.type()); |
| are_equal = StridedIntegerTensorContentEquals(0, 0, 0, internal::GetByteWidth(type), |
| left, right); |
| } else { |
| const int byte_width = internal::GetByteWidth(*left.type()); |
| DCHECK_GT(byte_width, 0); |
| |
| const uint8_t* left_data = left.data()->data(); |
| const uint8_t* right_data = right.data()->data(); |
| |
| are_equal = memcmp(left_data, right_data, |
| static_cast<size_t>(byte_width * left.size())) == 0; |
| } |
| } |
| return are_equal; |
| } |
| |
| template <typename DataType> |
| bool StridedFloatTensorContentEquals(const int dim_index, int64_t left_offset, |
| int64_t right_offset, const Tensor& left, |
| const Tensor& right, const EqualOptions& opts) { |
| using c_type = typename DataType::c_type; |
| static_assert(std::is_floating_point<c_type>::value, |
| "DataType must be a floating point type"); |
| |
| const auto n = left.shape()[dim_index]; |
| const auto left_stride = left.strides()[dim_index]; |
| const auto right_stride = right.strides()[dim_index]; |
| if (dim_index == left.ndim() - 1) { |
| auto left_data = left.raw_data(); |
| auto right_data = right.raw_data(); |
| if (opts.nans_equal()) { |
| for (int64_t i = 0; i < n; ++i) { |
| c_type left_value = |
| *reinterpret_cast<const c_type*>(left_data + left_offset + i * left_stride); |
| c_type right_value = *reinterpret_cast<const c_type*>(right_data + right_offset + |
| i * right_stride); |
| if (left_value != right_value && |
| !(std::isnan(left_value) && std::isnan(right_value))) { |
| return false; |
| } |
| } |
| } else { |
| for (int64_t i = 0; i < n; ++i) { |
| c_type left_value = |
| *reinterpret_cast<const c_type*>(left_data + left_offset + i * left_stride); |
| c_type right_value = *reinterpret_cast<const c_type*>(right_data + right_offset + |
| i * right_stride); |
| if (left_value != right_value) { |
| return false; |
| } |
| } |
| } |
| return true; |
| } |
| for (int64_t i = 0; i < n; ++i) { |
| if (!StridedFloatTensorContentEquals<DataType>(dim_index + 1, left_offset, |
| right_offset, left, right, opts)) { |
| return false; |
| } |
| left_offset += left_stride; |
| right_offset += right_stride; |
| } |
| return true; |
| } |
| |
| template <typename DataType> |
| bool FloatTensorEquals(const Tensor& left, const Tensor& right, |
| const EqualOptions& opts) { |
| return StridedFloatTensorContentEquals<DataType>(0, 0, 0, left, right, opts); |
| } |
| |
| } // namespace |
| |
| bool TensorEquals(const Tensor& left, const Tensor& right, const EqualOptions& opts) { |
| if (left.type_id() != right.type_id()) { |
| return false; |
| } else if (left.size() == 0 && right.size() == 0) { |
| return true; |
| } else if (left.shape() != right.shape()) { |
| return false; |
| } |
| |
| switch (left.type_id()) { |
| // TODO: Support half-float tensors |
| // case Type::HALF_FLOAT: |
| case Type::FLOAT: |
| return FloatTensorEquals<FloatType>(left, right, opts); |
| |
| case Type::DOUBLE: |
| return FloatTensorEquals<DoubleType>(left, right, opts); |
| |
| default: |
| return IntegerTensorEquals(left, right); |
| } |
| } |
| |
| namespace { |
| |
| template <typename LeftSparseIndexType, typename RightSparseIndexType> |
| struct SparseTensorEqualsImpl { |
| static bool Compare(const SparseTensorImpl<LeftSparseIndexType>& left, |
| const SparseTensorImpl<RightSparseIndexType>& right, |
| const EqualOptions&) { |
| // TODO(mrkn): should we support the equality among different formats? |
| return false; |
| } |
| }; |
| |
| bool IntegerSparseTensorDataEquals(const uint8_t* left_data, const uint8_t* right_data, |
| const int byte_width, const int64_t length) { |
| if (left_data == right_data) { |
| return true; |
| } |
| return memcmp(left_data, right_data, static_cast<size_t>(byte_width * length)) == 0; |
| } |
| |
| template <typename DataType> |
| bool FloatSparseTensorDataEquals(const typename DataType::c_type* left_data, |
| const typename DataType::c_type* right_data, |
| const int64_t length, const EqualOptions& opts) { |
| using c_type = typename DataType::c_type; |
| static_assert(std::is_floating_point<c_type>::value, |
| "DataType must be a floating point type"); |
| if (opts.nans_equal()) { |
| if (left_data == right_data) { |
| return true; |
| } |
| |
| for (int64_t i = 0; i < length; ++i) { |
| const auto left = left_data[i]; |
| const auto right = right_data[i]; |
| if (left != right && !(std::isnan(left) && std::isnan(right))) { |
| return false; |
| } |
| } |
| } else { |
| for (int64_t i = 0; i < length; ++i) { |
| if (left_data[i] != right_data[i]) { |
| return false; |
| } |
| } |
| } |
| return true; |
| } |
| |
| template <typename SparseIndexType> |
| struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> { |
| static bool Compare(const SparseTensorImpl<SparseIndexType>& left, |
| const SparseTensorImpl<SparseIndexType>& right, |
| const EqualOptions& opts) { |
| DCHECK(left.type()->id() == right.type()->id()); |
| DCHECK(left.shape() == right.shape()); |
| |
| const auto length = left.non_zero_length(); |
| DCHECK(length == right.non_zero_length()); |
| |
| const auto& left_index = checked_cast<const SparseIndexType&>(*left.sparse_index()); |
| const auto& right_index = checked_cast<const SparseIndexType&>(*right.sparse_index()); |
| |
| if (!left_index.Equals(right_index)) { |
| return false; |
| } |
| |
| const int byte_width = internal::GetByteWidth(*left.type()); |
| DCHECK_GT(byte_width, 0); |
| |
| const uint8_t* left_data = left.data()->data(); |
| const uint8_t* right_data = right.data()->data(); |
| switch (left.type()->id()) { |
| // TODO: Support half-float tensors |
| // case Type::HALF_FLOAT: |
| case Type::FLOAT: |
| return FloatSparseTensorDataEquals<FloatType>( |
| reinterpret_cast<const float*>(left_data), |
| reinterpret_cast<const float*>(right_data), length, opts); |
| |
| case Type::DOUBLE: |
| return FloatSparseTensorDataEquals<DoubleType>( |
| reinterpret_cast<const double*>(left_data), |
| reinterpret_cast<const double*>(right_data), length, opts); |
| |
| default: // Integer cases |
| return IntegerSparseTensorDataEquals(left_data, right_data, byte_width, length); |
| } |
| } |
| }; |
| |
| template <typename SparseIndexType> |
| inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType>& left, |
| const SparseTensor& right, |
| const EqualOptions& opts) { |
| switch (right.format_id()) { |
| case SparseTensorFormat::COO: { |
| const auto& right_coo = |
| checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(right); |
| return SparseTensorEqualsImpl<SparseIndexType, SparseCOOIndex>::Compare( |
| left, right_coo, opts); |
| } |
| |
| case SparseTensorFormat::CSR: { |
| const auto& right_csr = |
| checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(right); |
| return SparseTensorEqualsImpl<SparseIndexType, SparseCSRIndex>::Compare( |
| left, right_csr, opts); |
| } |
| |
| case SparseTensorFormat::CSC: { |
| const auto& right_csc = |
| checked_cast<const SparseTensorImpl<SparseCSCIndex>&>(right); |
| return SparseTensorEqualsImpl<SparseIndexType, SparseCSCIndex>::Compare( |
| left, right_csc, opts); |
| } |
| |
| case SparseTensorFormat::CSF: { |
| const auto& right_csf = |
| checked_cast<const SparseTensorImpl<SparseCSFIndex>&>(right); |
| return SparseTensorEqualsImpl<SparseIndexType, SparseCSFIndex>::Compare( |
| left, right_csf, opts); |
| } |
| |
| default: |
| return false; |
| } |
| } |
| |
| } // namespace |
| |
| bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right, |
| const EqualOptions& opts) { |
| if (left.type()->id() != right.type()->id()) { |
| return false; |
| } else if (left.size() == 0 && right.size() == 0) { |
| return true; |
| } else if (left.shape() != right.shape()) { |
| return false; |
| } else if (left.non_zero_length() != right.non_zero_length()) { |
| return false; |
| } |
| |
| switch (left.format_id()) { |
| case SparseTensorFormat::COO: { |
| const auto& left_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(left); |
| return SparseTensorEqualsImplDispatch(left_coo, right, opts); |
| } |
| |
| case SparseTensorFormat::CSR: { |
| const auto& left_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(left); |
| return SparseTensorEqualsImplDispatch(left_csr, right, opts); |
| } |
| |
| case SparseTensorFormat::CSC: { |
| const auto& left_csc = checked_cast<const SparseTensorImpl<SparseCSCIndex>&>(left); |
| return SparseTensorEqualsImplDispatch(left_csc, right, opts); |
| } |
| |
| case SparseTensorFormat::CSF: { |
| const auto& left_csf = checked_cast<const SparseTensorImpl<SparseCSFIndex>&>(left); |
| return SparseTensorEqualsImplDispatch(left_csf, right, opts); |
| } |
| |
| default: |
| return false; |
| } |
| } |
| |
| bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata) { |
| // The arrays are the same object |
| if (&left == &right) { |
| return true; |
| } else if (left.id() != right.id()) { |
| return false; |
| } else { |
| // First try to compute fingerprints |
| if (check_metadata) { |
| const auto& left_metadata_fp = left.metadata_fingerprint(); |
| const auto& right_metadata_fp = right.metadata_fingerprint(); |
| if (left_metadata_fp != right_metadata_fp) { |
| return false; |
| } |
| } |
| |
| const auto& left_fp = left.fingerprint(); |
| const auto& right_fp = right.fingerprint(); |
| if (!left_fp.empty() && !right_fp.empty()) { |
| return left_fp == right_fp; |
| } |
| |
| // TODO remove check_metadata here? |
| TypeEqualsVisitor visitor(right, check_metadata); |
| auto error = VisitTypeInline(left, &visitor); |
| if (!error.ok()) { |
| DCHECK(false) << "Types are not comparable: " << error.ToString(); |
| } |
| return visitor.result(); |
| } |
| } |
| |
| } // namespace arrow |