// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

// Functions for comparing Arrow data structures

#include "arrow/compare.h"

#include <climits>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "arrow/array.h"
#include "arrow/buffer.h"
#include "arrow/scalar.h"
#include "arrow/sparse_tensor.h"
#include "arrow/status.h"
#include "arrow/tensor.h"
#include "arrow/type.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
#include "arrow/util/memory.h"
#include "arrow/visitor_inline.h"

namespace arrow {

using internal::BitmapEquals;
using internal::checked_cast;

// ----------------------------------------------------------------------
// Public method implementations

namespace internal {

// These helper functions assume we already checked the arrays have equal
// sizes and null bitmaps.

template <typename ArrowType, typename EqualityFunc>
inline bool BaseFloatingEquals(const NumericArray<ArrowType>& left,
                               const NumericArray<ArrowType>& right,
                               EqualityFunc&& equals) {
  // Applies `equals` to every pair of values at the same index in `left` and
  // `right`, stopping at the first mismatch. Slots that are null in `left` are
  // skipped; the caller has already verified that both arrays share the same
  // length and validity bitmap, so `right` is null in exactly the same slots.
  using T = typename ArrowType::c_type;

  const T* lhs = left.raw_values();
  const T* rhs = right.raw_values();

  // Only consult the validity bitmap when there is something to skip.
  const bool skip_nulls = left.null_count() > 0;

  for (int64_t idx = 0; idx < left.length(); ++idx) {
    if (skip_nulls && left.IsNull(idx)) {
      continue;
    }
    if (!equals(lhs[idx], rhs[idx])) {
      return false;
    }
  }
  return true;
}

template <typename ArrowType>
inline bool FloatingEquals(const NumericArray<ArrowType>& left,
                           const NumericArray<ArrowType>& right,
                           const EqualOptions& opts) {
  // Exact floating point comparison of two arrays with identical validity.
  // By IEEE-754 rules NaN != NaN; when opts.nans_equal() is set, two NaNs in
  // the same slot are treated as equal instead.
  using T = typename ArrowType::c_type;

  if (!opts.nans_equal()) {
    return BaseFloatingEquals<ArrowType>(left, right,
                                         [](T x, T y) -> bool { return x == y; });
  }

  return BaseFloatingEquals<ArrowType>(left, right, [](T x, T y) -> bool {
    return (x == y) || (std::isnan(x) && std::isnan(y));
  });
}

template <typename ArrowType>
inline bool FloatingApproxEquals(const NumericArray<ArrowType>& left,
                                 const NumericArray<ArrowType>& right,
                                 const EqualOptions& opts) {
  // Approximate comparison of two floating point arrays with identical
  // validity. Two values match when any of these holds:
  //  - they are exactly equal. This check is required for infinities: inf - inf
  //    is NaN, so the tolerance test alone would reject two equal infinities;
  //  - they differ by at most opts.atol();
  //  - opts.nans_equal() is set and both are NaN.
  using T = typename ArrowType::c_type;
  const T epsilon = static_cast<T>(opts.atol());

  if (opts.nans_equal()) {
    return BaseFloatingEquals<ArrowType>(left, right, [epsilon](T x, T y) -> bool {
      return (x == y) || (std::fabs(x - y) <= epsilon) ||
             (std::isnan(x) && std::isnan(y));
    });
  } else {
    return BaseFloatingEquals<ArrowType>(left, right, [epsilon](T x, T y) -> bool {
      return (x == y) || (std::fabs(x - y) <= epsilon);
    });
  }
}

// RangeEqualsVisitor assumes the range sizes are equal

class RangeEqualsVisitor {
 public:
  RangeEqualsVisitor(const Array& right, int64_t left_start_idx, int64_t left_end_idx,
                     int64_t right_start_idx)
      : right_(right),
        left_start_idx_(left_start_idx),
        left_end_idx_(left_end_idx),
        right_start_idx_(right_start_idx),
        result_(false) {}

  template <typename ArrayType>
  inline Status CompareValues(const ArrayType& left) {
    const auto& right = checked_cast<const ArrayType&>(right_);

    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      const bool is_null = left.IsNull(i);
      if (is_null != right.IsNull(o_i) ||
          (!is_null && left.Value(i) != right.Value(o_i))) {
        result_ = false;
        return Status::OK();
      }
    }
    result_ = true;
    return Status::OK();
  }

  bool CompareBinaryRange(const BinaryArray& left) const {
    const auto& right = checked_cast<const BinaryArray&>(right_);

    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      const bool is_null = left.IsNull(i);
      if (is_null != right.IsNull(o_i)) {
        return false;
      }
      if (is_null) continue;
      const int32_t begin_offset = left.value_offset(i);
      const int32_t end_offset = left.value_offset(i + 1);
      const int32_t right_begin_offset = right.value_offset(o_i);
      const int32_t right_end_offset = right.value_offset(o_i + 1);
      // Underlying can't be equal if the size isn't equal
      if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
        return false;
      }

      if (end_offset - begin_offset > 0 &&
          std::memcmp(left.value_data()->data() + begin_offset,
                      right.value_data()->data() + right_begin_offset,
                      static_cast<size_t>(end_offset - begin_offset))) {
        return false;
      }
    }
    return true;
  }

  bool CompareLists(const ListArray& left) {
    const auto& right = checked_cast<const ListArray&>(right_);

    const std::shared_ptr<Array>& left_values = left.values();
    const std::shared_ptr<Array>& right_values = right.values();

    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      const bool is_null = left.IsNull(i);
      if (is_null != right.IsNull(o_i)) {
        return false;
      }
      if (is_null) continue;
      const int32_t begin_offset = left.value_offset(i);
      const int32_t end_offset = left.value_offset(i + 1);
      const int32_t right_begin_offset = right.value_offset(o_i);
      const int32_t right_end_offset = right.value_offset(o_i + 1);
      // Underlying can't be equal if the size isn't equal
      if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
        return false;
      }
      if (!left_values->RangeEquals(begin_offset, end_offset, right_begin_offset,
                                    right_values)) {
        return false;
      }
    }
    return true;
  }

  bool CompareStructs(const StructArray& left) {
    const auto& right = checked_cast<const StructArray&>(right_);
    bool equal_fields = true;
    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      if (left.IsNull(i) != right.IsNull(o_i)) {
        return false;
      }
      if (left.IsNull(i)) continue;
      for (int j = 0; j < left.num_fields(); ++j) {
        // TODO: really we should be comparing stretches of non-null data rather
        // than looking at one value at a time.
        equal_fields = left.field(j)->RangeEquals(i, i + 1, o_i, right.field(j));
        if (!equal_fields) {
          return false;
        }
      }
    }
    return true;
  }

  bool CompareUnions(const UnionArray& left) const {
    const auto& right = checked_cast<const UnionArray&>(right_);

    const UnionMode::type union_mode = left.mode();
    if (union_mode != right.mode()) {
      return false;
    }

    const auto& left_type = checked_cast<const UnionType&>(*left.type());

    // Define a mapping from the type id to child number
    uint8_t max_code = 0;

    const std::vector<uint8_t>& type_codes = left_type.type_codes();
    for (size_t i = 0; i < type_codes.size(); ++i) {
      const uint8_t code = type_codes[i];
      if (code > max_code) {
        max_code = code;
      }
    }

    // Store mapping in a vector for constant time lookups
    std::vector<uint8_t> type_id_to_child_num(max_code + 1);
    for (uint8_t i = 0; i < static_cast<uint8_t>(type_codes.size()); ++i) {
      type_id_to_child_num[type_codes[i]] = i;
    }

    const uint8_t* left_ids = left.raw_type_ids();
    const uint8_t* right_ids = right.raw_type_ids();

    uint8_t id, child_num;
    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      if (left.IsNull(i) != right.IsNull(o_i)) {
        return false;
      }
      if (left.IsNull(i)) continue;
      if (left_ids[i] != right_ids[o_i]) {
        return false;
      }

      id = left_ids[i];
      child_num = type_id_to_child_num[id];

      // TODO(wesm): really we should be comparing stretches of non-null data
      // rather than looking at one value at a time.
      if (union_mode == UnionMode::SPARSE) {
        if (!left.child(child_num)->RangeEquals(i, i + 1, o_i, right.child(child_num))) {
          return false;
        }
      } else {
        const int32_t offset = left.raw_value_offsets()[i];
        const int32_t o_offset = right.raw_value_offsets()[o_i];
        if (!left.child(child_num)->RangeEquals(offset, offset + 1, o_offset,
                                                right.child(child_num))) {
          return false;
        }
      }
    }
    return true;
  }

  Status Visit(const BinaryArray& left) {
    result_ = CompareBinaryRange(left);
    return Status::OK();
  }

  Status Visit(const FixedSizeBinaryArray& left) {
    const auto& right = checked_cast<const FixedSizeBinaryArray&>(right_);

    int32_t width = left.byte_width();

    const uint8_t* left_data = nullptr;
    const uint8_t* right_data = nullptr;

    if (left.values()) {
      left_data = left.raw_values();
    }

    if (right.values()) {
      right_data = right.raw_values();
    }

    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      const bool is_null = left.IsNull(i);
      if (is_null != right.IsNull(o_i)) {
        result_ = false;
        return Status::OK();
      }
      if (is_null) continue;

      if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) {
        result_ = false;
        return Status::OK();
      }
    }
    result_ = true;
    return Status::OK();
  }

  Status Visit(const Decimal128Array& left) {
    return Visit(checked_cast<const FixedSizeBinaryArray&>(left));
  }

  Status Visit(const NullArray& left) {
    ARROW_UNUSED(left);
    result_ = true;
    return Status::OK();
  }

  template <typename T>
  typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value, Status>::type Visit(
      const T& left) {
    return CompareValues<T>(left);
  }

  Status Visit(const ListArray& left) {
    result_ = CompareLists(left);
    return Status::OK();
  }

  Status Visit(const FixedSizeListArray& left) {
    const auto& right = checked_cast<const FixedSizeListArray&>(right_);
    result_ = left.values()->RangeEquals(
        left.value_offset(left_start_idx_), left.value_offset(left_end_idx_),
        right.value_offset(right_start_idx_), right.values());
    return Status::OK();
  }

  Status Visit(const StructArray& left) {
    result_ = CompareStructs(left);
    return Status::OK();
  }

  Status Visit(const UnionArray& left) {
    result_ = CompareUnions(left);
    return Status::OK();
  }

  Status Visit(const DictionaryArray& left) {
    const auto& right = checked_cast<const DictionaryArray&>(right_);
    if (!left.dictionary()->Equals(right.dictionary())) {
      result_ = false;
      return Status::OK();
    }
    result_ = left.indices()->RangeEquals(left_start_idx_, left_end_idx_,
                                          right_start_idx_, right.indices());
    return Status::OK();
  }

  Status Visit(const ExtensionArray& left) {
    result_ = (right_.type()->Equals(*left.type()) &&
               ArrayRangeEquals(*left.storage(),
                                *static_cast<const ExtensionArray&>(right_).storage(),
                                left_start_idx_, left_end_idx_, right_start_idx_));
    return Status::OK();
  }

  bool result() const { return result_; }

 protected:
  const Array& right_;
  int64_t left_start_idx_;
  int64_t left_end_idx_;
  int64_t right_start_idx_;

  bool result_;
};

static bool IsEqualPrimitive(const PrimitiveArray& left, const PrimitiveArray& right) {
  // Byte-wise equality of two fixed-width primitive arrays that are already
  // known to share type, length and validity bitmap. Values under null slots
  // are ignored.
  const auto& fixed_width = checked_cast<const FixedWidthType&>(*left.type());
  const int width = fixed_width.bit_width() / CHAR_BIT;

  if (width == 0) {
    // 0-width values carry no bytes (and their data pointers may be null):
    // only validity can differ.
    for (int64_t idx = 0; idx < left.length(); ++idx) {
      if (left.IsNull(idx) != right.IsNull(idx)) {
        return false;
      }
    }
    return true;
  }

  // Point at the first logical value of each (possibly sliced) array.
  const uint8_t* lhs =
      left.values() ? left.values()->data() + left.offset() * width : nullptr;
  const uint8_t* rhs =
      right.values() ? right.values()->data() + right.offset() * width : nullptr;

  if (left.null_count() == 0) {
    // No nulls: the whole value range can be compared in one go.
    const auto total_bytes = static_cast<size_t>(width * left.length());
    return memcmp(lhs, rhs, total_bytes) == 0;
  }

  for (int64_t idx = 0; idx < left.length(); ++idx, lhs += width, rhs += width) {
    const bool is_null = left.IsNull(idx);
    if (is_null != right.IsNull(idx)) {
      return false;
    }
    if (!is_null && memcmp(lhs, rhs, width) != 0) {
      return false;
    }
  }
  return true;
}

// A bit confusing: ArrayEqualsVisitor inherits from RangeEqualsVisitor but
// doesn't share the same preconditions.
// When RangeEqualsVisitor is called, we only know the range sizes equal.
// When ArrayEqualsVisitor is called, we know the sizes and null bitmaps are equal.

// Whole-array equality visitor. Callers (ArrayEqualsImpl) only visit it after
// BaseDataEquals has established equal type, length, null count and validity
// bitmap, so the overloads below compare values only. Where a faster
// whole-array comparison exists it replaces the slot-by-slot logic inherited
// from RangeEqualsVisitor. `opts_` controls floating point NaN handling.
class ArrayEqualsVisitor : public RangeEqualsVisitor {
 public:
  // The inherited range is all of `right`: [0, right.length()) on both sides.
  explicit ArrayEqualsVisitor(const Array& right, const EqualOptions& opts)
      : RangeEqualsVisitor(right, 0, right.length(), 0), opts_(opts) {}

  // Null arrays hold no values; matching length and type already imply equality.
  Status Visit(const NullArray& left) {
    ARROW_UNUSED(left);
    result_ = true;
    return Status::OK();
  }

  // Booleans are bit-packed, so they get their own comparison instead of the
  // byte-wise IsEqualPrimitive.
  Status Visit(const BooleanArray& left) {
    const auto& right = checked_cast<const BooleanArray&>(right_);

    if (left.null_count() > 0) {
      // Compare bit by bit, skipping null slots whose value bits are ignored.
      const uint8_t* left_data = left.values()->data();
      const uint8_t* right_data = right.values()->data();

      for (int64_t i = 0; i < left.length(); ++i) {
        if (left.IsValid(i) && BitUtil::GetBit(left_data, i + left.offset()) !=
                                   BitUtil::GetBit(right_data, i + right.offset())) {
          result_ = false;
          return Status::OK();
        }
      }
      result_ = true;
    } else {
      // No nulls: compare the value bitmaps (which may start at different bit
      // offsets) in one call.
      result_ = BitmapEquals(left.values()->data(), left.offset(), right.values()->data(),
                             right.offset(), left.length());
    }
    return Status::OK();
  }

  // Every other PrimitiveArray subclass: byte-wise comparison. Float, double
  // and boolean arrays are excluded because they have dedicated overloads.
  template <typename T>
  typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value &&
                              !std::is_base_of<FloatArray, T>::value &&
                              !std::is_base_of<DoubleArray, T>::value &&
                              !std::is_base_of<BooleanArray, T>::value,
                          Status>::type
  Visit(const T& left) {
    result_ = IsEqualPrimitive(left, checked_cast<const PrimitiveArray&>(right_));
    return Status::OK();
  }

  // TODO nan-aware specialization for half-floats

  // Exact floating point comparison; NaN handling follows opts_.nans_equal().
  Status Visit(const FloatArray& left) {
    result_ =
        FloatingEquals<FloatType>(left, checked_cast<const FloatArray&>(right_), opts_);
    return Status::OK();
  }

  Status Visit(const DoubleArray& left) {
    result_ =
        FloatingEquals<DoubleType>(left, checked_cast<const DoubleArray&>(right_), opts_);
    return Status::OK();
  }

  // Returns whether the 32-bit value offsets of `left` and `right_` describe
  // the same sequence of value lengths. Unsliced arrays are compared as raw
  // offset buffers; sliced arrays are compared after rebasing each array's
  // offsets to its own first offset.
  template <typename ArrayType>
  bool ValueOffsetsEqual(const ArrayType& left) {
    const auto& right = checked_cast<const ArrayType&>(right_);

    if (left.offset() == 0 && right.offset() == 0) {
      return left.value_offsets()->Equals(*right.value_offsets(),
                                          (left.length() + 1) * sizeof(int32_t));
    } else {
      // One of the arrays is sliced; logic is more complicated because the
      // value offsets are not both 0-based
      auto left_offsets =
          reinterpret_cast<const int32_t*>(left.value_offsets()->data()) + left.offset();
      auto right_offsets =
          reinterpret_cast<const int32_t*>(right.value_offsets()->data()) +
          right.offset();

      for (int64_t i = 0; i < left.length() + 1; ++i) {
        if (left_offsets[i] - left_offsets[0] != right_offsets[i] - right_offsets[0]) {
          return false;
        }
      }
      return true;
    }
  }

  // Binary/string comparison: equal value lengths (via the offsets), then equal
  // value bytes. Bytes of null slots are only skipped when nulls are present.
  bool CompareBinary(const BinaryArray& left) {
    const auto& right = checked_cast<const BinaryArray&>(right_);

    bool equal_offsets = ValueOffsetsEqual<BinaryArray>(left);
    if (!equal_offsets) {
      return false;
    }

    // Neither array has a data buffer: nothing left to compare.
    if (!left.value_data() && !(right.value_data())) {
      return true;
    }
    // All values are empty (offsets already matched, so the same holds for right).
    if (left.value_offset(left.length()) == left.value_offset(0)) {
      return true;
    }

    const uint8_t* left_data = left.value_data()->data();
    const uint8_t* right_data = right.value_data()->data();

    if (left.null_count() == 0) {
      // Fast path for null count 0, single memcmp
      if (left.offset() == 0 && right.offset() == 0) {
        return std::memcmp(left_data, right_data,
                           left.raw_value_offsets()[left.length()]) == 0;
      } else {
        const int64_t total_bytes =
            left.value_offset(left.length()) - left.value_offset(0);
        return std::memcmp(left_data + left.value_offset(0),
                           right_data + right.value_offset(0),
                           static_cast<size_t>(total_bytes)) == 0;
      }
    } else {
      // ARROW-537: Only compare data in non-null slots
      const int32_t* left_offsets = left.raw_value_offsets();
      const int32_t* right_offsets = right.raw_value_offsets();
      for (int64_t i = 0; i < left.length(); ++i) {
        if (left.IsNull(i)) {
          continue;
        }
        if (std::memcmp(left_data + left_offsets[i], right_data + right_offsets[i],
                        left.value_length(i))) {
          return false;
        }
      }
      return true;
    }
  }

  Status Visit(const BinaryArray& left) {
    result_ = CompareBinary(left);
    return Status::OK();
  }

  // Lists: matching offsets imply matching list lengths; then the child values
  // spanned by the whole array are compared as one range.
  Status Visit(const ListArray& left) {
    const auto& right = checked_cast<const ListArray&>(right_);
    bool equal_offsets = ValueOffsetsEqual<ListArray>(left);
    if (!equal_offsets) {
      result_ = false;
      return Status::OK();
    }

    result_ =
        left.values()->RangeEquals(left.value_offset(0), left.value_offset(left.length()),
                                   right.value_offset(0), right.values());
    return Status::OK();
  }

  // Fixed-size lists: compare the child values spanned by the whole array.
  Status Visit(const FixedSizeListArray& left) {
    const auto& right = checked_cast<const FixedSizeListArray&>(right_);
    result_ =
        left.values()->RangeEquals(left.value_offset(0), left.value_offset(left.length()),
                                   right.value_offset(0), right.values());
    return Status::OK();
  }

  // Dictionary arrays: equal dictionaries and equal index arrays.
  Status Visit(const DictionaryArray& left) {
    const auto& right = checked_cast<const DictionaryArray&>(right_);
    if (!left.dictionary()->Equals(right.dictionary())) {
      result_ = false;
    } else {
      result_ = left.indices()->Equals(right.indices());
    }
    return Status::OK();
  }

  // Nested array types without a more specific overload in this class fall
  // back to the slot-by-slot range comparison of RangeEqualsVisitor.
  template <typename T>
  typename std::enable_if<std::is_base_of<NestedType, typename T::TypeClass>::value,
                          Status>::type
  Visit(const T& left) {
    return RangeEqualsVisitor::Visit(left);
  }

  // Extension arrays: equal extension types, then equal storage arrays.
  Status Visit(const ExtensionArray& left) {
    result_ = (right_.type()->Equals(*left.type()) &&
               ArrayEquals(*left.storage(),
                           *static_cast<const ExtensionArray&>(right_).storage()));
    return Status::OK();
  }

 protected:
  // Copy of the caller's options (NaN handling, absolute tolerance).
  const EqualOptions opts_;
};

// Like ArrayEqualsVisitor, except that float and double arrays are compared
// within the absolute tolerance opts.atol().
class ApproxEqualsVisitor : public ArrayEqualsVisitor {
 public:
  explicit ApproxEqualsVisitor(const Array& right, const EqualOptions& opts)
      : ArrayEqualsVisitor(right, opts) {}

  // Reuse every exact-comparison overload; only the floating point ones below
  // are replaced.
  using ArrayEqualsVisitor::Visit;

  // TODO half-floats

  Status Visit(const FloatArray& left) {
    const auto& other = checked_cast<const FloatArray&>(right_);
    result_ = FloatingApproxEquals<FloatType>(left, other, opts_);
    return Status::OK();
  }

  Status Visit(const DoubleArray& left) {
    const auto& other = checked_cast<const DoubleArray&>(right_);
    result_ = FloatingApproxEquals<DoubleType>(left, other, opts_);
    return Status::OK();
  }
};

static bool BaseDataEquals(const Array& left, const Array& right) {
  // Cheap structural checks shared by all whole-array comparisons: length,
  // null count, full type equality (ignoring metadata), and validity bitmap.
  // A true result means only the values still need comparing.
  const bool same_layout = left.length() == right.length() &&
                           left.null_count() == right.null_count() &&
                           left.type_id() == right.type_id();
  if (!same_layout) {
    return false;
  }

  // ARROW-2567: Ensure that not only the type id but also the type equality
  // itself is checked.
  if (!TypeEquals(*left.type(), *right.type(), false /* check_metadata */)) {
    return false;
  }

  // The bitmaps only need comparing when some, but not all, slots are null;
  // with equal null counts, "none null" and "all null" already agree.
  const int64_t nulls = left.null_count();
  const bool partially_null = nulls > 0 && nulls < left.length();
  if (!partially_null) {
    return true;
  }
  return BitmapEquals(left.null_bitmap()->data(), left.offset(),
                      right.null_bitmap()->data(), right.offset(), left.length());
}

template <typename VISITOR, typename... Extra>
inline bool ArrayEqualsImpl(const Array& left, const Array& right, Extra&&... extra) {
  bool are_equal;
  // The arrays are the same object
  if (&left == &right) {
    are_equal = true;
  } else if (!BaseDataEquals(left, right)) {
    are_equal = false;
  } else if (left.length() == 0) {
    are_equal = true;
  } else if (left.null_count() == left.length()) {
    are_equal = true;
  } else {
    VISITOR visitor(right, std::forward<Extra>(extra)...);
    auto error = VisitArrayInline(left, &visitor);
    if (!error.ok()) {
      DCHECK(false) << "Arrays are not comparable: " << error.ToString();
    }
    are_equal = visitor.result();
  }
  return are_equal;
}

// Visitor deciding whether a visited "left" DataType equals `right_`. The
// caller is expected to dispatch it only when both types have the same id, so
// each overload checks only the parameters specific to that type.
class TypeEqualsVisitor {
 public:
  // `check_metadata` is forwarded to child field comparisons.
  explicit TypeEqualsVisitor(const DataType& right, bool check_metadata)
      : right_(right), check_metadata_(check_metadata), result_(false) {}

  // Equal when both types have the same number of children and every child
  // field compares equal.
  Status VisitChildren(const DataType& left) {
    result_ = left.num_children() == right_.num_children() && ChildFieldsEqual(left);
    return Status::OK();
  }

  // Types described entirely by their id carry nothing else to compare.
  template <typename T>
  typename std::enable_if<std::is_base_of<NoExtraMeta, T>::value ||
                              std::is_base_of<PrimitiveCType, T>::value,
                          Status>::type
  Visit(const T&) {
    result_ = true;
    return Status::OK();
  }

  template <typename T>
  typename std::enable_if<std::is_base_of<IntervalType, T>::value, Status>::type Visit(
      const T& left) {
    const auto& other = checked_cast<const IntervalType&>(right_);
    result_ = other.interval_type() == left.interval_type();
    return Status::OK();
  }

  // Time, date and duration types are parameterized by their unit only.
  template <typename T>
  typename std::enable_if<std::is_base_of<TimeType, T>::value ||
                              std::is_base_of<DateType, T>::value ||
                              std::is_base_of<DurationType, T>::value,
                          Status>::type
  Visit(const T& left) {
    const auto& other = checked_cast<const T&>(right_);
    result_ = left.unit() == other.unit();
    return Status::OK();
  }

  Status Visit(const TimestampType& left) {
    const auto& other = checked_cast<const TimestampType&>(right_);
    result_ = left.unit() == other.unit() && left.timezone() == other.timezone();
    return Status::OK();
  }

  Status Visit(const FixedSizeBinaryType& left) {
    const auto& other = checked_cast<const FixedSizeBinaryType&>(right_);
    result_ = left.byte_width() == other.byte_width();
    return Status::OK();
  }

  Status Visit(const Decimal128Type& left) {
    const auto& other = checked_cast<const Decimal128Type&>(right_);
    result_ = left.precision() == other.precision() && left.scale() == other.scale();
    return Status::OK();
  }

  Status Visit(const ListType& left) { return VisitChildren(left); }

  Status Visit(const MapType& left) {
    const auto& other = checked_cast<const MapType&>(right_);
    if (left.keys_sorted() == other.keys_sorted()) {
      return VisitChildren(left);
    }
    result_ = false;
    return Status::OK();
  }

  Status Visit(const FixedSizeListType& left) { return VisitChildren(left); }

  Status Visit(const StructType& left) { return VisitChildren(left); }

  // Unions: same mode, identical type code list (same length, same order), then
  // pairwise-equal child fields.
  Status Visit(const UnionType& left) {
    const auto& other = checked_cast<const UnionType&>(right_);
    result_ = left.mode() == other.mode() && left.type_codes() == other.type_codes() &&
              ChildFieldsEqual(left);
    return Status::OK();
  }

  // Dictionaries: same index type, same value type, same orderedness.
  Status Visit(const DictionaryType& left) {
    const auto& other = checked_cast<const DictionaryType&>(right_);
    const bool same_index = left.index_type()->Equals(other.index_type());
    result_ = same_index && left.value_type()->Equals(other.value_type()) &&
              (left.ordered() == other.ordered());
    return Status::OK();
  }

  // Extension types define their own notion of equality.
  Status Visit(const ExtensionType& left) {
    result_ = left.ExtensionEquals(static_cast<const ExtensionType&>(right_));
    return Status::OK();
  }

  bool result() const { return result_; }

 protected:
  // Compares child field i of `left` with child field i of `right_` for every
  // child of `left`; the callers establish that `right_` has as many children.
  bool ChildFieldsEqual(const DataType& left) const {
    for (int idx = 0; idx < left.num_children(); ++idx) {
      if (!left.child(idx)->Equals(right_.child(idx), check_metadata_)) {
        return false;
      }
    }
    return true;
  }

  const DataType& right_;
  bool check_metadata_;
  bool result_;
};

// Visitor deciding whether a visited "left" Scalar holds the same value as
// `right_`. Union, dictionary and extension scalars are not supported yet and
// report NotImplemented.
class ScalarEqualsVisitor {
 public:
  explicit ScalarEqualsVisitor(const Scalar& right) : right_(right), result_(false) {}

  // Null scalars carry no value.
  Status Visit(const NullScalar&) {
    result_ = true;
    return Status::OK();
  }

  // Primitive scalars: compare the stored C values with operator==.
  template <typename T>
  typename std::enable_if<std::is_base_of<internal::PrimitiveScalar, T>::value,
                          Status>::type
  Visit(const T& left) {
    const auto& other = checked_cast<const T&>(right_);
    result_ = other.value == left.value;
    return Status::OK();
  }

  // Binary-like scalars (any BinaryScalar subclass): compare the value buffers.
  template <typename T>
  typename std::enable_if<std::is_base_of<BinaryScalar, T>::value, Status>::type Visit(
      const T& left_scalar) {
    const auto& lhs = checked_cast<const BinaryScalar&>(left_scalar);
    const auto& rhs = checked_cast<const BinaryScalar&>(right_);
    result_ = internal::SharedPtrEquals(lhs.value, rhs.value);
    return Status::OK();
  }

  Status Visit(const Decimal128Scalar& left) {
    const auto& other = checked_cast<const Decimal128Scalar&>(right_);
    result_ = left.value == other.value;
    return Status::OK();
  }

  Status Visit(const ListScalar& left) {
    const auto& other = checked_cast<const ListScalar&>(right_);
    result_ = internal::SharedPtrEquals(left.value, other.value);
    return Status::OK();
  }

  Status Visit(const MapScalar& left) {
    const auto& other = checked_cast<const MapScalar&>(right_);
    const bool keys_equal = internal::SharedPtrEquals(left.keys, other.keys);
    result_ = keys_equal && internal::SharedPtrEquals(left.items, other.items);
    return Status::OK();
  }

  Status Visit(const FixedSizeListScalar& left) {
    const auto& other = checked_cast<const FixedSizeListScalar&>(right_);
    result_ = internal::SharedPtrEquals(left.value, other.value);
    return Status::OK();
  }

  // Struct scalars: same number of children and pairwise-equal children,
  // stopping at the first mismatch.
  Status Visit(const StructScalar& left) {
    const auto& other = checked_cast<const StructScalar&>(right_);

    result_ = other.value.size() == left.value.size();
    for (size_t idx = 0; result_ && idx < left.value.size(); ++idx) {
      result_ = internal::SharedPtrEquals(left.value[idx], other.value[idx]);
    }
    return Status::OK();
  }

  Status Visit(const UnionScalar&) { return Status::NotImplemented("union"); }

  Status Visit(const DictionaryScalar&) { return Status::NotImplemented("dictionary"); }

  Status Visit(const ExtensionScalar&) { return Status::NotImplemented("extension"); }

  bool result() const { return result_; }

 protected:
  const Scalar& right_;
  bool result_;
};

}  // namespace internal

bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts) {
  return internal::ArrayEqualsImpl<internal::ArrayEqualsVisitor>(left, right, opts);
}

bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions& opts) {
  return internal::ArrayEqualsImpl<internal::ApproxEqualsVisitor>(left, right, opts);
}

bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx,
                      int64_t left_end_idx, int64_t right_start_idx) {
  bool are_equal;
  if (&left == &right) {
    are_equal = true;
  } else if (left.type_id() != right.type_id()) {
    are_equal = false;
  } else if (left.length() == 0) {
    are_equal = true;
  } else {
    internal::RangeEqualsVisitor visitor(right, left_start_idx, left_end_idx,
                                         right_start_idx);
    auto error = VisitArrayInline(left, &visitor);
    if (!error.ok()) {
      DCHECK(false) << "Arrays are not comparable: " << error.ToString();
    }
    are_equal = visitor.result();
  }
  return are_equal;
}

bool StridedTensorContentEquals(int dim_index, int64_t left_offset, int64_t right_offset,
                                int elem_size, const Tensor& left, const Tensor& right) {
  // Recursively compares the elements of `left` and `right` reachable from the
  // given byte offsets, walking dimension `dim_index` and all inner dimensions
  // using each tensor's own strides. Elements are compared as raw
  // `elem_size`-byte blocks. Both tensors are expected to share one shape.
  const int64_t extent = left.shape()[dim_index];
  const int64_t left_stride = left.strides()[dim_index];
  const int64_t right_stride = right.strides()[dim_index];
  const bool innermost = dim_index == left.ndim() - 1;

  for (int64_t pos = 0; pos < extent; ++pos) {
    const int64_t lhs_pos = left_offset + pos * left_stride;
    const int64_t rhs_pos = right_offset + pos * right_stride;
    if (innermost) {
      if (memcmp(left.raw_data() + lhs_pos, right.raw_data() + rhs_pos, elem_size) != 0) {
        return false;
      }
    } else if (!StridedTensorContentEquals(dim_index + 1, lhs_pos, rhs_pos, elem_size,
                                           left, right)) {
      return false;
    }
  }
  return true;
}

bool TensorEquals(const Tensor& left, const Tensor& right) {
  bool are_equal;
  // The arrays are the same object
  if (&left == &right) {
    are_equal = true;
  } else if (left.type_id() != right.type_id()) {
    are_equal = false;
  } else if (left.size() == 0) {
    are_equal = true;
  } else {
    if (!left.is_contiguous() || !right.is_contiguous()) {
      const auto& shape = left.shape();
      if (shape != right.shape()) {
        are_equal = false;
      } else {
        const auto& type = checked_cast<const FixedWidthType&>(*left.type());
        are_equal =
            StridedTensorContentEquals(0, 0, 0, type.bit_width() / 8, left, right);
      }
    } else {
      const auto& size_meta = checked_cast<const FixedWidthType&>(*left.type());
      const int byte_width = size_meta.bit_width() / CHAR_BIT;
      DCHECK_GT(byte_width, 0);

      const uint8_t* left_data = left.data()->data();
      const uint8_t* right_data = right.data()->data();

      are_equal = memcmp(left_data, right_data,
                         static_cast<size_t>(byte_width * left.size())) == 0;
    }
  }
  return are_equal;
}

namespace {

// Primary template: the two sparse tensors use different sparse index formats.
// Cross-format comparison is not implemented, so such tensors always compare unequal.
template <typename LeftSparseIndexType, typename RightSparseIndexType>
struct SparseTensorEqualsImpl {
  static bool Compare(const SparseTensorImpl<LeftSparseIndexType>& left,
                      const SparseTensorImpl<RightSparseIndexType>& right) {
    // TODO(mrkn): should we support the equality among different formats?
    return false;
  }
};

// Specialization for two sparse tensors with the same index format. They are equal
// when their sparse indices are Equals() and their non-zero value buffers are
// bytewise identical over non_zero_length() elements.
//
// The caller must already have verified matching type ids, shapes and
// non_zero_length(). These preconditions are only DCHECKed here.
template <typename SparseIndexType>
struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> {
  static bool Compare(const SparseTensorImpl<SparseIndexType>& left,
                      const SparseTensorImpl<SparseIndexType>& right) {
    DCHECK(left.type()->id() == right.type()->id());
    DCHECK(left.shape() == right.shape());
    DCHECK(left.non_zero_length() == right.non_zero_length());

    const auto& left_index = checked_cast<const SparseIndexType&>(*left.sparse_index());
    const auto& right_index = checked_cast<const SparseIndexType&>(*right.sparse_index());

    if (!left_index.Equals(right_index)) {
      return false;
    }

    const auto& size_meta = checked_cast<const FixedWidthType&>(*left.type());
    const int byte_width = size_meta.bit_width() / CHAR_BIT;
    DCHECK_GT(byte_width, 0);

    const uint8_t* left_data = left.data()->data();
    const uint8_t* right_data = right.data()->data();

    // memcmp returns 0 when the buffers match. Converting its raw result to bool
    // would report matching data as unequal and mismatching data as equal.
    return memcmp(left_data, right_data,
                  static_cast<size_t>(byte_width * left.non_zero_length())) == 0;
  }
};

// Resolves the concrete index format of `right` and forwards to the matching
// SparseTensorEqualsImpl. Unknown formats compare unequal.
template <typename SparseIndexType>
inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType>& left,
                                           const SparseTensor& right) {
  switch (right.format_id()) {
    case SparseTensorFormat::COO: {
      const auto& right_coo =
          checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(right);
      return SparseTensorEqualsImpl<SparseIndexType, SparseCOOIndex>::Compare(left,
                                                                              right_coo);
    }

    case SparseTensorFormat::CSR: {
      const auto& right_csr =
          checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(right);
      return SparseTensorEqualsImpl<SparseIndexType, SparseCSRIndex>::Compare(left,
                                                                              right_csr);
    }

    default:
      return false;
  }
}

}  // namespace

// Returns true when `left` and `right` have the same type id, shape, number of
// non-zero values and sparse index format, equal sparse indices, and bytewise-identical
// non-zero values. Sparse tensors in different index formats currently always
// compare unequal.
bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right) {
  if (&left == &right) {
    return true;
  } else if (left.type()->id() != right.type()->id()) {
    return false;
  } else if (left.shape() != right.shape()) {
    // Checked before the empty-tensor shortcut so that an empty tensor does not
    // compare equal to a tensor of a different shape.
    return false;
  } else if (left.size() == 0) {
    return true;
  } else if (left.non_zero_length() != right.non_zero_length()) {
    return false;
  }

  // Resolve the concrete index format of `left`; the dispatcher resolves `right`.
  switch (left.format_id()) {
    case SparseTensorFormat::COO: {
      const auto& left_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(left);
      return SparseTensorEqualsImplDispatch(left_coo, right);
    }

    case SparseTensorFormat::CSR: {
      const auto& left_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(left);
      return SparseTensorEqualsImplDispatch(left_csr, right);
    }

    default:
      return false;
  }
}

// Returns true when `left` and `right` describe the same data type. The structural
// comparison is delegated to internal::TypeEqualsVisitor, which receives
// `check_metadata`. A failed visit trips a debug check; the visitor's result is
// returned either way.
bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata) {
  // Identical objects are trivially equal.
  if (&left == &right) {
    return true;
  }
  // Differing type ids can never be equal, so skip the visitor.
  if (left.id() != right.id()) {
    return false;
  }

  internal::TypeEqualsVisitor visitor(right, check_metadata);
  auto visit_status = VisitTypeInline(left, &visitor);
  if (!visit_status.ok()) {
    DCHECK(false) << "Types are not comparable: " << visit_status.ToString();
  }
  return visitor.result();
}

// Returns true when `left` and `right` have equal types, the same validity flag, and
// (as decided by internal::ScalarEqualsVisitor) equal values.
bool ScalarEquals(const Scalar& left, const Scalar& right) {
  // Identical objects are trivially equal.
  if (&left == &right) {
    return true;
  }
  // Scalars of different types, or where exactly one is valid, never compare equal.
  if (!left.type->Equals(right.type) || left.is_valid != right.is_valid) {
    return false;
  }

  internal::ScalarEqualsVisitor visitor(right);
  auto visit_status = VisitScalarInline(left, &visitor);
  DCHECK_OK(visit_status);
  return visitor.result();
}

}  // namespace arrow
