// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/compute/kernels/test_util.h"

#include <cstdint>
#include <memory>
#include <string>

#include "arrow/array.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/function.h"
#include "arrow/compute/registry.h"
#include "arrow/datum.h"
#include "arrow/result.h"
#include "arrow/testing/gtest_util.h"

namespace arrow {
namespace compute {

namespace {

template <typename T>
std::vector<Datum> GetDatums(const std::vector<T>& inputs) {
  std::vector<Datum> datums;
  for (const auto& input : inputs) {
    datums.emplace_back(input);
  }
  return datums;
}

void CheckScalarNonRecursive(const std::string& func_name, const ArrayVector& inputs,
                             const std::shared_ptr<Array>& expected,
                             const FunctionOptions* options) {
  ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options));
  std::shared_ptr<Array> actual = std::move(out).make_array();
  ASSERT_OK(actual->ValidateFull());
  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
}

template <typename... SliceArgs>
ArrayVector SliceAll(const ArrayVector& inputs, SliceArgs... slice_args) {
  ArrayVector sliced;
  for (const auto& input : inputs) {
    sliced.push_back(input->Slice(slice_args...));
  }
  return sliced;
}

ScalarVector GetScalars(const ArrayVector& inputs, int64_t index) {
  ScalarVector scalars;
  for (const auto& input : inputs) {
    scalars.push_back(*input->GetScalar(index));
  }
  return scalars;
}

void CheckScalar(std::string func_name, const ScalarVector& inputs,
                 std::shared_ptr<Scalar> expected, const FunctionOptions* options) {
  ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options));
  if (!out.scalar()->Equals(expected)) {
    std::string summary = func_name + "(";
    for (const auto& input : inputs) {
      summary += input->ToString() + ",";
    }
    summary.back() = ')';

    summary += " = " + out.scalar()->ToString() + " != " + expected->ToString();

    if (!out.type()->Equals(expected->type)) {
      summary += " (types differed: " + out.type()->ToString() + " vs " +
                 expected->type->ToString() + ")";
    }

    FAIL() << summary;
  }
}

void CheckScalar(std::string func_name, const ArrayVector& inputs,
                 std::shared_ptr<Array> expected, const FunctionOptions* options) {
  CheckScalarNonRecursive(func_name, inputs, expected, options);

  // Check all the input scalars, if scalars are implemented
  if (std::none_of(inputs.begin(), inputs.end(), [](const std::shared_ptr<Array>& array) {
        return array->type_id() == Type::EXTENSION;
      })) {
    for (int64_t i = 0; i < inputs[0]->length(); ++i) {
      CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options);
    }
  }

  // Since it's a scalar function, calling it on sliced inputs should
  // result in the sliced expected output.
  const auto slice_length = inputs[0]->length() / 3;
  if (slice_length > 0) {
    CheckScalarNonRecursive(func_name, SliceAll(inputs, 0, slice_length),
                            expected->Slice(0, slice_length), options);

    CheckScalarNonRecursive(func_name, SliceAll(inputs, slice_length, slice_length),
                            expected->Slice(slice_length, slice_length), options);

    CheckScalarNonRecursive(func_name, SliceAll(inputs, 2 * slice_length),
                            expected->Slice(2 * slice_length), options);
  }

  // Should also work with an empty slice
  CheckScalarNonRecursive(func_name, SliceAll(inputs, 0, 0), expected->Slice(0, 0),
                          options);

  // Ditto with ChunkedArray inputs
  if (slice_length > 0) {
    std::vector<std::shared_ptr<ChunkedArray>> chunked_inputs;
    chunked_inputs.reserve(inputs.size());
    for (const auto& input : inputs) {
      chunked_inputs.push_back(std::make_shared<ChunkedArray>(
          ArrayVector{input->Slice(0, slice_length), input->Slice(slice_length)}));
    }
    ArrayVector expected_chunks{expected->Slice(0, slice_length),
                                expected->Slice(slice_length)};

    ASSERT_OK_AND_ASSIGN(Datum out,
                         CallFunction(func_name, GetDatums(chunked_inputs), options));
    ASSERT_OK(out.chunked_array()->ValidateFull());
    AssertDatumsEqual(std::make_shared<ChunkedArray>(expected_chunks), out);
  }
}

}  // namespace

void CheckScalarUnary(std::string func_name, std::shared_ptr<Array> input,
                      std::shared_ptr<Array> expected, const FunctionOptions* options) {
  CheckScalar(std::move(func_name), {input}, expected, options);
}

void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty,
                      std::string json_input, std::shared_ptr<DataType> out_ty,
                      std::string json_expected, const FunctionOptions* options) {
  CheckScalarUnary(std::move(func_name), ArrayFromJSON(in_ty, json_input),
                   ArrayFromJSON(out_ty, json_expected), options);
}

void CheckScalarUnary(std::string func_name, std::shared_ptr<Scalar> input,
                      std::shared_ptr<Scalar> expected, const FunctionOptions* options) {
  CheckScalar(std::move(func_name), {input}, expected, options);
}

void CheckVectorUnary(std::string func_name, Datum input, std::shared_ptr<Array> expected,
                      const FunctionOptions* options) {
  ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {input}, options));
  std::shared_ptr<Array> actual = std::move(out).make_array();
  ASSERT_OK(actual->ValidateFull());
  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
}

void CheckScalarBinary(std::string func_name, std::shared_ptr<Scalar> left_input,
                       std::shared_ptr<Scalar> right_input,
                       std::shared_ptr<Scalar> expected, const FunctionOptions* options) {
  CheckScalar(std::move(func_name), {left_input, right_input}, expected, options);
}

void CheckScalarBinary(std::string func_name, std::shared_ptr<Array> left_input,
                       std::shared_ptr<Array> right_input,
                       std::shared_ptr<Array> expected, const FunctionOptions* options) {
  CheckScalar(std::move(func_name), {left_input, right_input}, expected, options);
}

void CheckDispatchBest(std::string func_name, std::vector<ValueDescr> original_values,
                       std::vector<ValueDescr> expected_equivalent_values) {
  ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));

  auto values = original_values;
  ASSERT_OK_AND_ASSIGN(auto actual_kernel, function->DispatchBest(&values));

  ASSERT_OK_AND_ASSIGN(auto expected_kernel,
                       function->DispatchExact(expected_equivalent_values));

  EXPECT_EQ(actual_kernel, expected_kernel)
      << "  DispatchBest" << ValueDescr::ToString(original_values) << " => "
      << actual_kernel->signature->ToString() << "\n"
      << "  DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => "
      << expected_kernel->signature->ToString();
}

void CheckDispatchFails(std::string func_name, std::vector<ValueDescr> values) {
  ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
  ASSERT_NOT_OK(function->DispatchBest(&values));
  ASSERT_NOT_OK(function->DispatchExact(values));
}

}  // namespace compute
}  // namespace arrow
