blob: 672308452cf13f842c5d39d2468cc86d1c0bf7c5 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/compute/kernels/test_util.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "arrow/array.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/function.h"
#include "arrow/compute/registry.h"
#include "arrow/datum.h"
#include "arrow/result.h"
#include "arrow/testing/gtest_util.h"
namespace arrow {
namespace compute {
namespace {
// Wraps every element of `inputs` in a Datum, keeping the input order.
// Used below with ArrayVector, ScalarVector and vectors of ChunkedArray.
template <typename T>
std::vector<Datum> GetDatums(const std::vector<T>& inputs) {
  // The iterator-pair constructor direct-initializes one Datum per element.
  return std::vector<Datum>(inputs.begin(), inputs.end());
}
void CheckScalarNonRecursive(const std::string& func_name, const ArrayVector& inputs,
const std::shared_ptr<Array>& expected,
const FunctionOptions* options) {
ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options));
std::shared_ptr<Array> actual = std::move(out).make_array();
ASSERT_OK(actual->ValidateFull());
AssertArraysEqual(*expected, *actual, /*verbose=*/true);
}
// Returns a vector holding input->Slice(slice_args...) for every array in
// `inputs`, in order. The slice arguments are passed through as given, so both
// the Slice(offset) and Slice(offset, length) forms can be used.
template <typename... SliceArgs>
ArrayVector SliceAll(const ArrayVector& inputs, SliceArgs... slice_args) {
  ArrayVector sliced(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    sliced[i] = inputs[i]->Slice(slice_args...);
  }
  return sliced;
}
// Returns the element at position `index` of each array in `inputs` as a
// Scalar, in input order (i.e. one "row" of arguments). A GetScalar() error,
// e.g. an out-of-range index, is not turned into a test failure: ValueOrDie()
// aborts instead.
ScalarVector GetScalars(const ArrayVector& inputs, int64_t index) {
  ScalarVector row;
  row.reserve(inputs.size());
  for (const std::shared_ptr<Array>& array : inputs) {
    row.push_back(array->GetScalar(index).ValueOrDie());
  }
  return row;
}
void CheckScalar(std::string func_name, const ScalarVector& inputs,
std::shared_ptr<Scalar> expected, const FunctionOptions* options) {
ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options));
if (!out.scalar()->Equals(expected)) {
std::string summary = func_name + "(";
for (const auto& input : inputs) {
summary += input->ToString() + ",";
}
summary.back() = ')';
summary += " = " + out.scalar()->ToString() + " != " + expected->ToString();
if (!out.type()->Equals(expected->type)) {
summary += " (types differed: " + out.type()->ToString() + " vs " +
expected->type->ToString() + ")";
}
FAIL() << summary;
}
}
void CheckScalar(std::string func_name, const ArrayVector& inputs,
std::shared_ptr<Array> expected, const FunctionOptions* options) {
CheckScalarNonRecursive(func_name, inputs, expected, options);
// Check all the input scalars, if scalars are implemented
if (std::none_of(inputs.begin(), inputs.end(), [](const std::shared_ptr<Array>& array) {
return array->type_id() == Type::EXTENSION;
})) {
for (int64_t i = 0; i < inputs[0]->length(); ++i) {
CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options);
}
}
// Since it's a scalar function, calling it on sliced inputs should
// result in the sliced expected output.
const auto slice_length = inputs[0]->length() / 3;
if (slice_length > 0) {
CheckScalarNonRecursive(func_name, SliceAll(inputs, 0, slice_length),
expected->Slice(0, slice_length), options);
CheckScalarNonRecursive(func_name, SliceAll(inputs, slice_length, slice_length),
expected->Slice(slice_length, slice_length), options);
CheckScalarNonRecursive(func_name, SliceAll(inputs, 2 * slice_length),
expected->Slice(2 * slice_length), options);
}
// Should also work with an empty slice
CheckScalarNonRecursive(func_name, SliceAll(inputs, 0, 0), expected->Slice(0, 0),
options);
// Ditto with ChunkedArray inputs
if (slice_length > 0) {
std::vector<std::shared_ptr<ChunkedArray>> chunked_inputs;
chunked_inputs.reserve(inputs.size());
for (const auto& input : inputs) {
chunked_inputs.push_back(std::make_shared<ChunkedArray>(
ArrayVector{input->Slice(0, slice_length), input->Slice(slice_length)}));
}
ArrayVector expected_chunks{expected->Slice(0, slice_length),
expected->Slice(slice_length)};
ASSERT_OK_AND_ASSIGN(Datum out,
CallFunction(func_name, GetDatums(chunked_inputs), options));
ASSERT_OK(out.chunked_array()->ValidateFull());
AssertDatumsEqual(std::make_shared<ChunkedArray>(expected_chunks), out);
}
}
} // namespace
void CheckScalarUnary(std::string func_name, std::shared_ptr<Array> input,
std::shared_ptr<Array> expected, const FunctionOptions* options) {
CheckScalar(std::move(func_name), {input}, expected, options);
}
void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty,
std::string json_input, std::shared_ptr<DataType> out_ty,
std::string json_expected, const FunctionOptions* options) {
CheckScalarUnary(std::move(func_name), ArrayFromJSON(in_ty, json_input),
ArrayFromJSON(out_ty, json_expected), options);
}
void CheckScalarUnary(std::string func_name, std::shared_ptr<Scalar> input,
std::shared_ptr<Scalar> expected, const FunctionOptions* options) {
CheckScalar(std::move(func_name), {input}, expected, options);
}
void CheckVectorUnary(std::string func_name, Datum input, std::shared_ptr<Array> expected,
const FunctionOptions* options) {
ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {input}, options));
std::shared_ptr<Array> actual = std::move(out).make_array();
ASSERT_OK(actual->ValidateFull());
AssertArraysEqual(*expected, *actual, /*verbose=*/true);
}
void CheckScalarBinary(std::string func_name, std::shared_ptr<Scalar> left_input,
std::shared_ptr<Scalar> right_input,
std::shared_ptr<Scalar> expected, const FunctionOptions* options) {
CheckScalar(std::move(func_name), {left_input, right_input}, expected, options);
}
void CheckScalarBinary(std::string func_name, std::shared_ptr<Array> left_input,
std::shared_ptr<Array> right_input,
std::shared_ptr<Array> expected, const FunctionOptions* options) {
CheckScalar(std::move(func_name), {left_input, right_input}, expected, options);
}
void CheckDispatchBest(std::string func_name, std::vector<ValueDescr> original_values,
std::vector<ValueDescr> expected_equivalent_values) {
ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
auto values = original_values;
ASSERT_OK_AND_ASSIGN(auto actual_kernel, function->DispatchBest(&values));
ASSERT_OK_AND_ASSIGN(auto expected_kernel,
function->DispatchExact(expected_equivalent_values));
EXPECT_EQ(actual_kernel, expected_kernel)
<< " DispatchBest" << ValueDescr::ToString(original_values) << " => "
<< actual_kernel->signature->ToString() << "\n"
<< " DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => "
<< expected_kernel->signature->ToString();
}
void CheckDispatchFails(std::string func_name, std::vector<ValueDescr> values) {
ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
ASSERT_NOT_OK(function->DispatchBest(&values));
ASSERT_NOT_OK(function->DispatchExact(values));
}
} // namespace compute
} // namespace arrow