blob: 12991b94aeba733a52e97fae8633218d392a01ea [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Functions for comparing Arrow data structures
#include "arrow/compare.h"
#include <climits>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "arrow/array.h"
#include "arrow/buffer.h"
#include "arrow/scalar.h"
#include "arrow/sparse_tensor.h"
#include "arrow/status.h"
#include "arrow/tensor.h"
#include "arrow/type.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
#include "arrow/util/memory.h"
#include "arrow/visitor_inline.h"
namespace arrow {
using internal::BitmapEquals;
using internal::checked_cast;
// ----------------------------------------------------------------------
// Public method implementations
namespace internal {
// These helper functions assume we already checked the arrays have equal
// sizes and null bitmaps.
// Compares the raw floating-point values of two equally-sized arrays with the
// predicate `equals`. When `left` has nulls, slots that are null in `left` are
// skipped (the caller has already verified that the validity bitmaps match).
template <typename ArrowType, typename EqualityFunc>
inline bool BaseFloatingEquals(const NumericArray<ArrowType>& left,
                               const NumericArray<ArrowType>& right,
                               EqualityFunc&& equals) {
  using CType = typename ArrowType::c_type;
  const CType* lhs = left.raw_values();
  const CType* rhs = right.raw_values();
  const bool skip_nulls = left.null_count() > 0;
  for (int64_t k = 0; k < left.length(); ++k) {
    if (skip_nulls && left.IsNull(k)) {
      continue;
    }
    if (!equals(lhs[k], rhs[k])) {
      return false;
    }
  }
  return true;
}
// Exact element-wise comparison of two floating-point arrays. When
// opts.nans_equal() is set, a NaN on both sides also counts as equal.
template <typename ArrowType>
inline bool FloatingEquals(const NumericArray<ArrowType>& left,
                           const NumericArray<ArrowType>& right,
                           const EqualOptions& opts) {
  using CType = typename ArrowType::c_type;
  if (!opts.nans_equal()) {
    auto exact = [](CType a, CType b) -> bool { return a == b; };
    return BaseFloatingEquals<ArrowType>(left, right, exact);
  }
  auto exact_or_both_nan = [](CType a, CType b) -> bool {
    return (a == b) || (std::isnan(a) && std::isnan(b));
  };
  return BaseFloatingEquals<ArrowType>(left, right, exact_or_both_nan);
}
// Element-wise approximate comparison of two floating-point arrays: values are
// equal when |x - y| <= opts.atol(). When opts.nans_equal() is set, a NaN on
// both sides also counts as equal.
template <typename ArrowType>
inline bool FloatingApproxEquals(const NumericArray<ArrowType>& left,
                                 const NumericArray<ArrowType>& right,
                                 const EqualOptions& opts) {
  using T = typename ArrowType::c_type;
  const T epsilon = static_cast<T>(opts.atol());
  // std::fabs (not the global ::fabs) is guaranteed by <cmath> and selects the
  // overload matching T, so float inputs are not promoted to double.
  if (opts.nans_equal()) {
    return BaseFloatingEquals<ArrowType>(left, right, [epsilon](T x, T y) -> bool {
      return (std::fabs(x - y) <= epsilon) || (std::isnan(x) && std::isnan(y));
    });
  } else {
    return BaseFloatingEquals<ArrowType>(left, right, [epsilon](T x, T y) -> bool {
      return std::fabs(x - y) <= epsilon;
    });
  }
}
// RangeEqualsVisitor assumes the range sizes are equal
class RangeEqualsVisitor {
public:
RangeEqualsVisitor(const Array& right, int64_t left_start_idx, int64_t left_end_idx,
int64_t right_start_idx)
: right_(right),
left_start_idx_(left_start_idx),
left_end_idx_(left_end_idx),
right_start_idx_(right_start_idx),
result_(false) {}
template <typename ArrayType>
inline Status CompareValues(const ArrayType& left) {
const auto& right = checked_cast<const ArrayType&>(right_);
for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
++i, ++o_i) {
const bool is_null = left.IsNull(i);
if (is_null != right.IsNull(o_i) ||
(!is_null && left.Value(i) != right.Value(o_i))) {
result_ = false;
return Status::OK();
}
}
result_ = true;
return Status::OK();
}
bool CompareBinaryRange(const BinaryArray& left) const {
const auto& right = checked_cast<const BinaryArray&>(right_);
for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
++i, ++o_i) {
const bool is_null = left.IsNull(i);
if (is_null != right.IsNull(o_i)) {
return false;
}
if (is_null) continue;
const int32_t begin_offset = left.value_offset(i);
const int32_t end_offset = left.value_offset(i + 1);
const int32_t right_begin_offset = right.value_offset(o_i);
const int32_t right_end_offset = right.value_offset(o_i + 1);
// Underlying can't be equal if the size isn't equal
if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
return false;
}
if (end_offset - begin_offset > 0 &&
std::memcmp(left.value_data()->data() + begin_offset,
right.value_data()->data() + right_begin_offset,
static_cast<size_t>(end_offset - begin_offset))) {
return false;
}
}
return true;
}
bool CompareLists(const ListArray& left) {
const auto& right = checked_cast<const ListArray&>(right_);
const std::shared_ptr<Array>& left_values = left.values();
const std::shared_ptr<Array>& right_values = right.values();
for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
++i, ++o_i) {
const bool is_null = left.IsNull(i);
if (is_null != right.IsNull(o_i)) {
return false;
}
if (is_null) continue;
const int32_t begin_offset = left.value_offset(i);
const int32_t end_offset = left.value_offset(i + 1);
const int32_t right_begin_offset = right.value_offset(o_i);
const int32_t right_end_offset = right.value_offset(o_i + 1);
// Underlying can't be equal if the size isn't equal
if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
return false;
}
if (!left_values->RangeEquals(begin_offset, end_offset, right_begin_offset,
right_values)) {
return false;
}
}
return true;
}
bool CompareStructs(const StructArray& left) {
const auto& right = checked_cast<const StructArray&>(right_);
bool equal_fields = true;
for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
++i, ++o_i) {
if (left.IsNull(i) != right.IsNull(o_i)) {
return false;
}
if (left.IsNull(i)) continue;
for (int j = 0; j < left.num_fields(); ++j) {
// TODO: really we should be comparing stretches of non-null data rather
// than looking at one value at a time.
equal_fields = left.field(j)->RangeEquals(i, i + 1, o_i, right.field(j));
if (!equal_fields) {
return false;
}
}
}
return true;
}
bool CompareUnions(const UnionArray& left) const {
const auto& right = checked_cast<const UnionArray&>(right_);
const UnionMode::type union_mode = left.mode();
if (union_mode != right.mode()) {
return false;
}
const auto& left_type = checked_cast<const UnionType&>(*left.type());
// Define a mapping from the type id to child number
uint8_t max_code = 0;
const std::vector<uint8_t>& type_codes = left_type.type_codes();
for (size_t i = 0; i < type_codes.size(); ++i) {
const uint8_t code = type_codes[i];
if (code > max_code) {
max_code = code;
}
}
// Store mapping in a vector for constant time lookups
std::vector<uint8_t> type_id_to_child_num(max_code + 1);
for (uint8_t i = 0; i < static_cast<uint8_t>(type_codes.size()); ++i) {
type_id_to_child_num[type_codes[i]] = i;
}
const uint8_t* left_ids = left.raw_type_ids();
const uint8_t* right_ids = right.raw_type_ids();
uint8_t id, child_num;
for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
++i, ++o_i) {
if (left.IsNull(i) != right.IsNull(o_i)) {
return false;
}
if (left.IsNull(i)) continue;
if (left_ids[i] != right_ids[o_i]) {
return false;
}
id = left_ids[i];
child_num = type_id_to_child_num[id];
// TODO(wesm): really we should be comparing stretches of non-null data
// rather than looking at one value at a time.
if (union_mode == UnionMode::SPARSE) {
if (!left.child(child_num)->RangeEquals(i, i + 1, o_i, right.child(child_num))) {
return false;
}
} else {
const int32_t offset = left.raw_value_offsets()[i];
const int32_t o_offset = right.raw_value_offsets()[o_i];
if (!left.child(child_num)->RangeEquals(offset, offset + 1, o_offset,
right.child(child_num))) {
return false;
}
}
}
return true;
}
Status Visit(const BinaryArray& left) {
result_ = CompareBinaryRange(left);
return Status::OK();
}
Status Visit(const FixedSizeBinaryArray& left) {
const auto& right = checked_cast<const FixedSizeBinaryArray&>(right_);
int32_t width = left.byte_width();
const uint8_t* left_data = nullptr;
const uint8_t* right_data = nullptr;
if (left.values()) {
left_data = left.raw_values();
}
if (right.values()) {
right_data = right.raw_values();
}
for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
++i, ++o_i) {
const bool is_null = left.IsNull(i);
if (is_null != right.IsNull(o_i)) {
result_ = false;
return Status::OK();
}
if (is_null) continue;
if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) {
result_ = false;
return Status::OK();
}
}
result_ = true;
return Status::OK();
}
Status Visit(const Decimal128Array& left) {
return Visit(checked_cast<const FixedSizeBinaryArray&>(left));
}
Status Visit(const NullArray& left) {
ARROW_UNUSED(left);
result_ = true;
return Status::OK();
}
template <typename T>
typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value, Status>::type Visit(
const T& left) {
return CompareValues<T>(left);
}
Status Visit(const ListArray& left) {
result_ = CompareLists(left);
return Status::OK();
}
Status Visit(const FixedSizeListArray& left) {
const auto& right = checked_cast<const FixedSizeListArray&>(right_);
result_ = left.values()->RangeEquals(
left.value_offset(left_start_idx_), left.value_offset(left_end_idx_),
right.value_offset(right_start_idx_), right.values());
return Status::OK();
}
Status Visit(const StructArray& left) {
result_ = CompareStructs(left);
return Status::OK();
}
Status Visit(const UnionArray& left) {
result_ = CompareUnions(left);
return Status::OK();
}
Status Visit(const DictionaryArray& left) {
const auto& right = checked_cast<const DictionaryArray&>(right_);
if (!left.dictionary()->Equals(right.dictionary())) {
result_ = false;
return Status::OK();
}
result_ = left.indices()->RangeEquals(left_start_idx_, left_end_idx_,
right_start_idx_, right.indices());
return Status::OK();
}
Status Visit(const ExtensionArray& left) {
result_ = (right_.type()->Equals(*left.type()) &&
ArrayRangeEquals(*left.storage(),
*static_cast<const ExtensionArray&>(right_).storage(),
left_start_idx_, left_end_idx_, right_start_idx_));
return Status::OK();
}
bool result() const { return result_; }
protected:
const Array& right_;
int64_t left_start_idx_;
int64_t left_end_idx_;
int64_t right_start_idx_;
bool result_;
};
// Compares the fixed-width value buffers of two primitive arrays that are
// already known to have equal length and equal validity bitmaps. Values under
// null slots are ignored. For zero-width types, only validity is compared.
static bool IsEqualPrimitive(const PrimitiveArray& left, const PrimitiveArray& right) {
  const auto& fixed_width = checked_cast<const FixedWidthType&>(*left.type());
  const int width = fixed_width.bit_width() / CHAR_BIT;
  const int64_t length = left.length();

  if (width == 0) {
    // Special case 0-width data, as the data pointers may be null
    for (int64_t k = 0; k < length; ++k) {
      if (left.IsNull(k) != right.IsNull(k)) {
        return false;
      }
    }
    return true;
  }

  // Point at the first logical element of each (possibly sliced) array.
  const uint8_t* lhs =
      left.values() ? left.values()->data() + left.offset() * width : nullptr;
  const uint8_t* rhs =
      right.values() ? right.values()->data() + right.offset() * width : nullptr;

  if (left.null_count() <= 0) {
    // No nulls: the whole value region can be compared in one go.
    const auto byte_count = static_cast<size_t>(width * length);
    return memcmp(lhs, rhs, byte_count) == 0;
  }

  for (int64_t k = 0; k < length; ++k, lhs += width, rhs += width) {
    const bool lhs_null = left.IsNull(k);
    if (lhs_null != right.IsNull(k)) {
      return false;
    }
    if (!lhs_null && memcmp(lhs, rhs, width) != 0) {
      return false;
    }
  }
  return true;
}
// ArrayEqualsVisitor inherits from RangeEqualsVisitor (reusing right_, result_
// and, for otherwise-unhandled nested types, the range-based comparisons) but
// does NOT share its preconditions:
// - When RangeEqualsVisitor is called, we only know the range sizes are equal.
// - When ArrayEqualsVisitor is called (from ArrayEqualsImpl, after
//   BaseDataEquals has passed), the lengths, null counts, types and null
//   bitmaps are already known to be equal, so the overloads below mostly only
//   need to compare values.
// The inherited range is always the whole array: [0, right.length()).
class ArrayEqualsVisitor : public RangeEqualsVisitor {
public:
// `opts` is copied into opts_; in this class only the float/double overloads
// consult it.
explicit ArrayEqualsVisitor(const Array& right, const EqualOptions& opts)
: RangeEqualsVisitor(right, 0, right.length(), 0), opts_(opts) {}
// Null arrays carry no values, so equal lengths (already checked) suffice.
Status Visit(const NullArray& left) {
ARROW_UNUSED(left);
result_ = true;
return Status::OK();
}
// Booleans are bit-packed. With nulls present, compare bit by bit only at
// slots valid in `left`; otherwise compare the whole value bitmap at once.
Status Visit(const BooleanArray& left) {
const auto& right = checked_cast<const BooleanArray&>(right_);
if (left.null_count() > 0) {
const uint8_t* left_data = left.values()->data();
const uint8_t* right_data = right.values()->data();
for (int64_t i = 0; i < left.length(); ++i) {
if (left.IsValid(i) && BitUtil::GetBit(left_data, i + left.offset()) !=
BitUtil::GetBit(right_data, i + right.offset())) {
result_ = false;
return Status::OK();
}
}
result_ = true;
} else {
result_ = BitmapEquals(left.values()->data(), left.offset(), right.values()->data(),
right.offset(), left.length());
}
return Status::OK();
}
// All other primitive arrays (float, double and boolean have dedicated
// overloads) are compared bytewise by IsEqualPrimitive.
template <typename T>
typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value &&
!std::is_base_of<FloatArray, T>::value &&
!std::is_base_of<DoubleArray, T>::value &&
!std::is_base_of<BooleanArray, T>::value,
Status>::type
Visit(const T& left) {
result_ = IsEqualPrimitive(left, checked_cast<const PrimitiveArray&>(right_));
return Status::OK();
}
// TODO nan-aware specialization for half-floats
// (presumably half-float arrays currently reach the bytewise template above;
// confirm.)
// Float/double: exact comparison, with NaN == NaN when opts_.nans_equal().
Status Visit(const FloatArray& left) {
result_ =
FloatingEquals<FloatType>(left, checked_cast<const FloatArray&>(right_), opts_);
return Status::OK();
}
Status Visit(const DoubleArray& left) {
result_ =
FloatingEquals<DoubleType>(left, checked_cast<const DoubleArray&>(right_), opts_);
return Status::OK();
}
// Returns true if the int32 value offsets of `left` and right_ describe the
// same sequence of value lengths. Unsliced arrays compare the raw offset
// buffers directly. When either array is sliced, offsets are compared relative
// to each array's first offset.
template <typename ArrayType>
bool ValueOffsetsEqual(const ArrayType& left) {
const auto& right = checked_cast<const ArrayType&>(right_);
if (left.offset() == 0 && right.offset() == 0) {
return left.value_offsets()->Equals(*right.value_offsets(),
(left.length() + 1) * sizeof(int32_t));
} else {
// One of the arrays is sliced; logic is more complicated because the
// value offsets are not both 0-based
auto left_offsets =
reinterpret_cast<const int32_t*>(left.value_offsets()->data()) + left.offset();
auto right_offsets =
reinterpret_cast<const int32_t*>(right.value_offsets()->data()) +
right.offset();
for (int64_t i = 0; i < left.length() + 1; ++i) {
if (left_offsets[i] - left_offsets[0] != right_offsets[i] - right_offsets[0]) {
return false;
}
}
return true;
}
}
// Binary/string equality: the offsets must describe the same value lengths,
// then the value bytes are compared. Without nulls this is a single memcmp
// over the referenced region. With nulls, each non-null slot is compared
// separately.
bool CompareBinary(const BinaryArray& left) {
const auto& right = checked_cast<const BinaryArray&>(right_);
bool equal_offsets = ValueOffsetsEqual<BinaryArray>(left);
if (!equal_offsets) {
return false;
}
if (!left.value_data() && !(right.value_data())) {
return true;
}
// Every value is empty, so there are no bytes to compare.
if (left.value_offset(left.length()) == left.value_offset(0)) {
return true;
}
const uint8_t* left_data = left.value_data()->data();
const uint8_t* right_data = right.value_data()->data();
if (left.null_count() == 0) {
// Fast path for null count 0, single memcmp
if (left.offset() == 0 && right.offset() == 0) {
return std::memcmp(left_data, right_data,
left.raw_value_offsets()[left.length()]) == 0;
} else {
const int64_t total_bytes =
left.value_offset(left.length()) - left.value_offset(0);
return std::memcmp(left_data + left.value_offset(0),
right_data + right.value_offset(0),
static_cast<size_t>(total_bytes)) == 0;
}
} else {
// ARROW-537: Only compare data in non-null slots
const int32_t* left_offsets = left.raw_value_offsets();
const int32_t* right_offsets = right.raw_value_offsets();
for (int64_t i = 0; i < left.length(); ++i) {
if (left.IsNull(i)) {
continue;
}
if (std::memcmp(left_data + left_offsets[i], right_data + right_offsets[i],
left.value_length(i))) {
return false;
}
}
return true;
}
}
Status Visit(const BinaryArray& left) {
result_ = CompareBinary(left);
return Status::OK();
}
// Lists: the offsets must describe the same list lengths. Then the child
// values spanned by the whole array are compared in a single RangeEquals call.
Status Visit(const ListArray& left) {
const auto& right = checked_cast<const ListArray&>(right_);
bool equal_offsets = ValueOffsetsEqual<ListArray>(left);
if (!equal_offsets) {
result_ = false;
return Status::OK();
}
result_ =
left.values()->RangeEquals(left.value_offset(0), left.value_offset(left.length()),
right.value_offset(0), right.values());
return Status::OK();
}
// Fixed-size lists: compares the entire child-value span covered by the
// array in one RangeEquals call. This includes child values under null list
// slots.
Status Visit(const FixedSizeListArray& left) {
const auto& right = checked_cast<const FixedSizeListArray&>(right_);
result_ =
left.values()->RangeEquals(left.value_offset(0), left.value_offset(left.length()),
right.value_offset(0), right.values());
return Status::OK();
}
// Dictionaries must be equal, and then the index arrays must be equal.
// NOTE(review): both comparisons use Array::Equals' default options rather
// than opts_ — confirm this is intended.
Status Visit(const DictionaryArray& left) {
const auto& right = checked_cast<const DictionaryArray&>(right_);
if (!left.dictionary()->Equals(right.dictionary())) {
result_ = false;
} else {
result_ = left.indices()->Equals(right.indices());
}
return Status::OK();
}
// Remaining nested types without a dedicated overload here (e.g. struct,
// union) fall back to the RangeEqualsVisitor comparison over the full range
// set up by the constructor.
template <typename T>
typename std::enable_if<std::is_base_of<NestedType, typename T::TypeClass>::value,
Status>::type
Visit(const T& left) {
return RangeEqualsVisitor::Visit(left);
}
// Extension arrays: the extension types must be equal, and then the storage
// arrays must be equal.
// NOTE(review): the storage is compared with ArrayEquals' default options,
// not opts_ — confirm this is intended.
Status Visit(const ExtensionArray& left) {
result_ = (right_.type()->Equals(*left.type()) &&
ArrayEquals(*left.storage(),
*static_cast<const ExtensionArray&>(right_).storage()));
return Status::OK();
}
protected:
// Comparison options supplied at construction (NaN handling, tolerance).
const EqualOptions opts_;
};
// Variant of ArrayEqualsVisitor in which float and double arrays are compared
// within the absolute tolerance opts.atol() (see FloatingApproxEquals). Every
// other array type is compared exactly through the inherited overloads.
class ApproxEqualsVisitor : public ArrayEqualsVisitor {
 public:
  explicit ApproxEqualsVisitor(const Array& right, const EqualOptions& opts)
      : ArrayEqualsVisitor(right, opts) {}

  // Re-expose the base overloads, which the declarations below would hide.
  using ArrayEqualsVisitor::Visit;

  // TODO half-floats
  Status Visit(const FloatArray& left) {
    const auto& other = checked_cast<const FloatArray&>(right_);
    result_ = FloatingApproxEquals<FloatType>(left, other, opts_);
    return Status::OK();
  }

  Status Visit(const DoubleArray& left) {
    const auto& other = checked_cast<const DoubleArray&>(right_);
    result_ = FloatingApproxEquals<DoubleType>(left, other, opts_);
    return Status::OK();
  }
};
// Cheap structural checks shared by every whole-array comparison:
// - same length, null count and type (metadata ignored);
// - identical validity bitmaps when the arrays are only partially null.
// All-valid or all-null arrays need no bitmap comparison.
static bool BaseDataEquals(const Array& left, const Array& right) {
  const bool same_layout = left.length() == right.length() &&
                           left.null_count() == right.null_count() &&
                           left.type_id() == right.type_id();
  if (!same_layout) {
    return false;
  }
  // ARROW-2567: Ensure that not only the type id but also the type equality
  // itself is checked.
  if (!TypeEquals(*left.type(), *right.type(), false /* check_metadata */)) {
    return false;
  }
  const int64_t nulls = left.null_count();
  const bool partially_null = nulls > 0 && nulls < left.length();
  if (!partially_null) {
    return true;
  }
  return BitmapEquals(left.null_bitmap()->data(), left.offset(),
                      right.null_bitmap()->data(), right.offset(), left.length());
}
template <typename VISITOR, typename... Extra>
inline bool ArrayEqualsImpl(const Array& left, const Array& right, Extra&&... extra) {
bool are_equal;
// The arrays are the same object
if (&left == &right) {
are_equal = true;
} else if (!BaseDataEquals(left, right)) {
are_equal = false;
} else if (left.length() == 0) {
are_equal = true;
} else if (left.null_count() == left.length()) {
are_equal = true;
} else {
VISITOR visitor(right, std::forward<Extra>(extra)...);
auto error = VisitArrayInline(left, &visitor);
if (!error.ok()) {
DCHECK(false) << "Arrays are not comparable: " << error.ToString();
}
are_equal = visitor.result();
}
return are_equal;
}
// Visitor behind TypeEquals(): compares the parameters of the visited "left"
// type with those of `right`. TypeEquals dispatches here only when the two type
// ids already match, which is what makes the checked_casts of right_ valid.
// Nested comparisons honor check_metadata_.
class TypeEqualsVisitor {
 public:
  explicit TypeEqualsVisitor(const DataType& right, bool check_metadata)
      : right_(right), check_metadata_(check_metadata), result_(false) {}

  // Equal iff both types have the same number of child fields and each pair of
  // fields is equal (honoring check_metadata_).
  Status VisitChildren(const DataType& left) {
    if (left.num_children() != right_.num_children()) {
      result_ = false;
      return Status::OK();
    }
    for (int i = 0; i < left.num_children(); ++i) {
      if (!left.child(i)->Equals(right_.child(i), check_metadata_)) {
        result_ = false;
        return Status::OK();
      }
    }
    result_ = true;
    return Status::OK();
  }

  // Parameter-free types: matching type ids are sufficient.
  template <typename T>
  typename std::enable_if<std::is_base_of<NoExtraMeta, T>::value ||
                              std::is_base_of<PrimitiveCType, T>::value,
                          Status>::type
  Visit(const T&) {
    result_ = true;
    return Status::OK();
  }

  template <typename T>
  typename std::enable_if<std::is_base_of<IntervalType, T>::value, Status>::type Visit(
      const T& left) {
    const auto& right = checked_cast<const IntervalType&>(right_);
    result_ = right.interval_type() == left.interval_type();
    return Status::OK();
  }

  // Time, date and duration types are parameterized only by their unit.
  template <typename T>
  typename std::enable_if<std::is_base_of<TimeType, T>::value ||
                              std::is_base_of<DateType, T>::value ||
                              std::is_base_of<DurationType, T>::value,
                          Status>::type
  Visit(const T& left) {
    const auto& right = checked_cast<const T&>(right_);
    result_ = left.unit() == right.unit();
    return Status::OK();
  }

  Status Visit(const TimestampType& left) {
    const auto& right = checked_cast<const TimestampType&>(right_);
    result_ = left.unit() == right.unit() && left.timezone() == right.timezone();
    return Status::OK();
  }

  Status Visit(const FixedSizeBinaryType& left) {
    const auto& right = checked_cast<const FixedSizeBinaryType&>(right_);
    result_ = left.byte_width() == right.byte_width();
    return Status::OK();
  }

  Status Visit(const Decimal128Type& left) {
    const auto& right = checked_cast<const Decimal128Type&>(right_);
    result_ = left.precision() == right.precision() && left.scale() == right.scale();
    return Status::OK();
  }

  Status Visit(const ListType& left) { return VisitChildren(left); }

  Status Visit(const MapType& left) {
    const auto& right = checked_cast<const MapType&>(right_);
    if (left.keys_sorted() != right.keys_sorted()) {
      result_ = false;
      return Status::OK();
    }
    return VisitChildren(left);
  }

  Status Visit(const FixedSizeListType& left) { return VisitChildren(left); }

  Status Visit(const StructType& left) { return VisitChildren(left); }

  // Unions must share their mode and the full, ordered list of type codes.
  // Then their children are compared like any other nested type.
  Status Visit(const UnionType& left) {
    const auto& right = checked_cast<const UnionType&>(right_);
    if (left.mode() != right.mode() || left.type_codes() != right.type_codes()) {
      result_ = false;
      return Status::OK();
    }
    return VisitChildren(left);
  }

  // Dictionary types: index type, value type and orderedness must all match.
  // TypeEquals is used (rather than DataType::Equals) so that check_metadata_
  // propagates into the value type, as it does for every other nested type here.
  Status Visit(const DictionaryType& left) {
    const auto& right = checked_cast<const DictionaryType&>(right_);
    result_ = TypeEquals(*left.index_type(), *right.index_type(), check_metadata_) &&
              TypeEquals(*left.value_type(), *right.value_type(), check_metadata_) &&
              (left.ordered() == right.ordered());
    return Status::OK();
  }

  // Extension types define their own equality.
  Status Visit(const ExtensionType& left) {
    result_ = left.ExtensionEquals(static_cast<const ExtensionType&>(right_));
    return Status::OK();
  }

  bool result() const { return result_; }

 protected:
  const DataType& right_;
  bool check_metadata_;
  bool result_;
};
// Compares the payload of the visited "left" scalar with `right`. ScalarEquals
// has already verified that both scalars have equal types and equal validity.
// The outcome is read back via result().
class ScalarEqualsVisitor {
 public:
  explicit ScalarEqualsVisitor(const Scalar& right) : right_(right), result_(false) {}

  Status Visit(const NullScalar& left) {
    result_ = true;
    return Status::OK();
  }

  // Primitive scalars hold a plain C value compared with ==.
  template <typename T>
  typename std::enable_if<std::is_base_of<internal::PrimitiveScalar, T>::value,
                          Status>::type
  Visit(const T& left) {
    const T& other = checked_cast<const T&>(right_);
    result_ = other.value == left.value;
    return Status::OK();
  }

  // Binary-like scalars compare their value buffers (null-aware).
  template <typename T>
  typename std::enable_if<std::is_base_of<BinaryScalar, T>::value, Status>::type Visit(
      const T& left) {
    const auto& lhs = checked_cast<const BinaryScalar&>(left);
    const auto& rhs = checked_cast<const BinaryScalar&>(right_);
    result_ = internal::SharedPtrEquals(lhs.value, rhs.value);
    return Status::OK();
  }

  Status Visit(const Decimal128Scalar& left) {
    const auto& other = checked_cast<const Decimal128Scalar&>(right_);
    result_ = left.value == other.value;
    return Status::OK();
  }

  Status Visit(const ListScalar& left) {
    const auto& other = checked_cast<const ListScalar&>(right_);
    result_ = internal::SharedPtrEquals(left.value, other.value);
    return Status::OK();
  }

  Status Visit(const MapScalar& left) {
    const auto& other = checked_cast<const MapScalar&>(right_);
    const bool keys_match = internal::SharedPtrEquals(left.keys, other.keys);
    result_ = keys_match && internal::SharedPtrEquals(left.items, other.items);
    return Status::OK();
  }

  Status Visit(const FixedSizeListScalar& left) {
    const auto& other = checked_cast<const FixedSizeListScalar&>(right_);
    result_ = internal::SharedPtrEquals(left.value, other.value);
    return Status::OK();
  }

  // Struct scalars: same number of fields, and each field value equal.
  Status Visit(const StructScalar& left) {
    const auto& other = checked_cast<const StructScalar&>(right_);
    result_ = other.value.size() == left.value.size();
    for (size_t idx = 0; result_ && idx < left.value.size(); ++idx) {
      result_ = internal::SharedPtrEquals(left.value[idx], other.value[idx]);
    }
    return Status::OK();
  }

  Status Visit(const UnionScalar& left) { return Status::NotImplemented("union"); }

  Status Visit(const DictionaryScalar& left) {
    return Status::NotImplemented("dictionary");
  }

  Status Visit(const ExtensionScalar& left) {
    return Status::NotImplemented("extension");
  }

  bool result() const { return result_; }

 protected:
  const Scalar& right_;
  bool result_;
};
} // namespace internal
bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts) {
return internal::ArrayEqualsImpl<internal::ArrayEqualsVisitor>(left, right, opts);
}
bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions& opts) {
return internal::ArrayEqualsImpl<internal::ApproxEqualsVisitor>(left, right, opts);
}
bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx,
int64_t left_end_idx, int64_t right_start_idx) {
bool are_equal;
if (&left == &right) {
are_equal = true;
} else if (left.type_id() != right.type_id()) {
are_equal = false;
} else if (left.length() == 0) {
are_equal = true;
} else {
internal::RangeEqualsVisitor visitor(right, left_start_idx, left_end_idx,
right_start_idx);
auto error = VisitArrayInline(left, &visitor);
if (!error.ok()) {
DCHECK(false) << "Arrays are not comparable: " << error.ToString();
}
are_equal = visitor.result();
}
return are_equal;
}
// Recursively walks dimension `dim_index` of both tensors. Each tensor uses its
// own strides, starting from the given byte offsets. At the innermost dimension,
// individual elements of `elem_size` bytes are memcmp'd.
bool StridedTensorContentEquals(int dim_index, int64_t left_offset, int64_t right_offset,
                                int elem_size, const Tensor& left, const Tensor& right) {
  const bool innermost = dim_index == left.ndim() - 1;
  const int64_t extent = left.shape()[dim_index];
  for (int64_t i = 0; i < extent; ++i) {
    const int64_t lpos = left_offset + i * left.strides()[dim_index];
    const int64_t rpos = right_offset + i * right.strides()[dim_index];
    if (innermost) {
      if (memcmp(left.raw_data() + lpos, right.raw_data() + rpos, elem_size) != 0) {
        return false;
      }
    } else if (!StridedTensorContentEquals(dim_index + 1, lpos, rpos, elem_size, left,
                                           right)) {
      return false;
    }
  }
  return true;
}
bool TensorEquals(const Tensor& left, const Tensor& right) {
bool are_equal;
// The arrays are the same object
if (&left == &right) {
are_equal = true;
} else if (left.type_id() != right.type_id()) {
are_equal = false;
} else if (left.size() == 0) {
are_equal = true;
} else {
if (!left.is_contiguous() || !right.is_contiguous()) {
const auto& shape = left.shape();
if (shape != right.shape()) {
are_equal = false;
} else {
const auto& type = checked_cast<const FixedWidthType&>(*left.type());
are_equal =
StridedTensorContentEquals(0, 0, 0, type.bit_width() / 8, left, right);
}
} else {
const auto& size_meta = checked_cast<const FixedWidthType&>(*left.type());
const int byte_width = size_meta.bit_width() / CHAR_BIT;
DCHECK_GT(byte_width, 0);
const uint8_t* left_data = left.data()->data();
const uint8_t* right_data = right.data()->data();
are_equal = memcmp(left_data, right_data,
static_cast<size_t>(byte_width * left.size())) == 0;
}
}
return are_equal;
}
namespace {
// Fallback for sparse tensors stored in different index formats. Such tensors
// are currently never considered equal, whatever their logical contents.
template <typename LeftSparseIndexType, typename RightSparseIndexType>
struct SparseTensorEqualsImpl {
  static bool Compare(const SparseTensorImpl<LeftSparseIndexType>& /*left*/,
                      const SparseTensorImpl<RightSparseIndexType>& /*right*/) {
    // TODO(mrkn): should we support the equality among different formats?
    constexpr bool kCrossFormatEqual = false;
    return kCrossFormatEqual;
  }
};
// Same-format comparison: the sparse indices must be equal, and the non-zero
// value buffers must hold identical bytes. The caller (SparseTensorEquals) has
// already checked the type, shape and non-zero count.
template <typename SparseIndexType>
struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> {
  static bool Compare(const SparseTensorImpl<SparseIndexType>& left,
                      const SparseTensorImpl<SparseIndexType>& right) {
    DCHECK(left.type()->id() == right.type()->id());
    DCHECK(left.shape() == right.shape());
    DCHECK(left.non_zero_length() == right.non_zero_length());
    const auto& left_index = checked_cast<const SparseIndexType&>(*left.sparse_index());
    const auto& right_index = checked_cast<const SparseIndexType&>(*right.sparse_index());
    if (!left_index.Equals(right_index)) {
      return false;
    }
    const auto& size_meta = checked_cast<const FixedWidthType&>(*left.type());
    const int byte_width = size_meta.bit_width() / CHAR_BIT;
    DCHECK_GT(byte_width, 0);
    const uint8_t* left_data = left.data()->data();
    const uint8_t* right_data = right.data()->data();
    // memcmp returns 0 when the bytes are identical. Its raw result must not be
    // returned as a bool, because that would report equality for differing data.
    return memcmp(left_data, right_data,
                  static_cast<size_t>(byte_width * left.non_zero_length())) == 0;
  }
};
// Recovers the concrete index type of `right` from its format id. It then lets
// SparseTensorEqualsImpl choose the same-format or cross-format comparison.
// Unknown formats compare unequal.
template <typename SparseIndexType>
inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType>& left,
                                           const SparseTensor& right) {
  const auto format = right.format_id();
  if (format == SparseTensorFormat::COO) {
    using RightImpl = SparseTensorImpl<SparseCOOIndex>;
    return SparseTensorEqualsImpl<SparseIndexType, SparseCOOIndex>::Compare(
        left, checked_cast<const RightImpl&>(right));
  }
  if (format == SparseTensorFormat::CSR) {
    using RightImpl = SparseTensorImpl<SparseCSRIndex>;
    return SparseTensorEqualsImpl<SparseIndexType, SparseCSRIndex>::Compare(
        left, checked_cast<const RightImpl&>(right));
  }
  return false;
}
} // namespace
// Returns true if two sparse tensors have the same type, shape, non-zero count,
// index format and contents.
bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right) {
  if (&left == &right) {
    return true;
  } else if (left.type()->id() != right.type()->id()) {
    return false;
  } else if (left.shape() != right.shape()) {
    // Shapes are checked before the empty-tensor shortcut. Otherwise an empty
    // `left` would compare equal to a `right` of any other shape.
    return false;
  } else if (left.size() == 0) {
    return true;
  } else if (left.non_zero_length() != right.non_zero_length()) {
    return false;
  }
  switch (left.format_id()) {
    case SparseTensorFormat::COO: {
      const auto& left_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(left);
      return SparseTensorEqualsImplDispatch(left_coo, right);
    }
    case SparseTensorFormat::CSR: {
      const auto& left_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(left);
      return SparseTensorEqualsImplDispatch(left_csr, right);
    }
    default:
      return false;
  }
}
bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata) {
bool are_equal;
// The arrays are the same object
if (&left == &right) {
are_equal = true;
} else if (left.id() != right.id()) {
are_equal = false;
} else {
internal::TypeEqualsVisitor visitor(right, check_metadata);
auto error = VisitTypeInline(left, &visitor);
if (!error.ok()) {
DCHECK(false) << "Types are not comparable: " << error.ToString();
}
are_equal = visitor.result();
}
return are_equal;
}
// Returns true if two scalars have equal types, equal validity and (when
// valid) equal values.
bool ScalarEquals(const Scalar& left, const Scalar& right) {
  if (&left == &right) {
    return true;
  }
  if (!left.type->Equals(right.type) || left.is_valid != right.is_valid) {
    return false;
  }
  if (!left.is_valid) {
    // Two null scalars of the same type are equal. Their payload is not
    // meaningful, and for some scalar kinds (union, dictionary, extension) the
    // visitor would return NotImplemented and trip the DCHECK below.
    return true;
  }
  internal::ScalarEqualsVisitor visitor(right);
  auto error = VisitScalarInline(left, &visitor);
  DCHECK_OK(error);
  return visitor.result();
}
} // namespace arrow