blob: aea3d8360e68ed4fc8fe2dba1fe825cba6a3bd9b [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
// IWYU pragma: begin_exports
#include <memory>
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include "arrow/array.h"
#include "arrow/datum.h"
#include "arrow/memory_pool.h"
#include "arrow/pretty_print.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/testing/util.h"
#include "arrow/type.h"
#include "arrow/compute/kernel.h"
// IWYU pragma: end_exports
namespace arrow {
using internal::checked_cast;
namespace compute {
/// Build an Arrow array of logical type `type` from a vector of C values.
///
/// \param type      DataType of the resulting array.
/// \param values    Element values, one per slot.
/// \param is_valid  Optional per-slot validity flags; when empty, the
///                  validity-less ArrayFromVector overload is used instead.
/// \return The constructed array.
template <typename Type, typename T>
std::shared_ptr<Array> _MakeArray(const std::shared_ptr<DataType>& type,
                                  const std::vector<T>& values,
                                  const std::vector<bool>& is_valid) {
  std::shared_ptr<Array> out;
  if (is_valid.empty()) {
    // No validity vector supplied: build without a validity argument.
    ArrayFromVector<Type, T>(type, values, &out);
  } else {
    ArrayFromVector<Type, T>(type, is_valid, values, &out);
  }
  return out;
}
/// Test helper asserting that two Datums are equal for a given Arrow type.
///
/// The primary template is intentionally empty: only the floating-point and
/// integer specializations below define EnsureEqual(), so calling
/// DatumEqual<T>::EnsureEqual for any other T does not compile.
template <typename Type, typename Enable = void>
struct DatumEqual {};

/// Floating-point specialization: scalar values are compared with an absolute
/// tolerance rather than exact equality.
template <typename Type>
struct DatumEqual<Type, enable_if_floating_point<Type>> {
  // Deliberately loose absolute tolerance passed to ASSERT_NEAR below.
  static constexpr double kArbitraryDoubleErrorBound = 1.0;
  using ScalarType = typename TypeTraits<Type>::ScalarType;

  /// Assert that `lhs` and `rhs` have the same Datum kind. For SCALAR datums,
  /// also assert equal validity and type id, and, when both are valid, values
  /// within kArbitraryDoubleErrorBound of each other.
  /// NOTE(review): non-SCALAR datums are only checked for matching kind.
  static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
    ASSERT_EQ(lhs.kind(), rhs.kind());
    if (lhs.kind() == Datum::SCALAR) {
      auto left = checked_cast<const ScalarType*>(lhs.scalar().get());
      auto right = checked_cast<const ScalarType*>(rhs.scalar().get());
      ASSERT_EQ(left->is_valid, right->is_valid);
      ASSERT_EQ(left->type->id(), right->type->id());
      // The payload of a null scalar carries no meaning, so two nulls are
      // equal regardless of their stored values (matching the Scalar
      // equality used by the integer specialization below).
      if (left->is_valid) {
        ASSERT_NEAR(left->value, right->value, kArbitraryDoubleErrorBound);
      }
    }
  }
};

/// Integer specialization: scalar Datums are compared exactly via the
/// scalar equality operator.
template <typename Type>
struct DatumEqual<Type, enable_if_integer<Type>> {
  using ScalarType = typename TypeTraits<Type>::ScalarType;

  /// Assert that `lhs` and `rhs` have the same Datum kind and, for SCALAR
  /// datums, that the scalars compare equal.
  /// NOTE(review): non-SCALAR datums are only checked for matching kind.
  static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
    ASSERT_EQ(lhs.kind(), rhs.kind());
    if (lhs.kind() == Datum::SCALAR) {
      const auto* left = checked_cast<const ScalarType*>(lhs.scalar().get());
      const auto* right = checked_cast<const ScalarType*>(rhs.scalar().get());
      ASSERT_EQ(*left, *right);
    }
  }
};
// Compute-function test helpers. They are only declared in this header; the
// definitions live elsewhere.
//
// NOTE(review): judging by the names, each helper runs the function registered
// as `func_name` on the given input(s) and asserts the result equals
// `expected`, forwarding `options` (nullptr by default) — confirm against the
// definitions.

// Unary overload driven by JSON text. Presumably `json_input` is parsed as
// `in_ty` and `json_expected` as `out_ty` — TODO confirm.
void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty,
std::string json_input, std::shared_ptr<DataType> out_ty,
std::string json_expected,
const FunctionOptions* options = nullptr);
// Unary overload taking an Array input and Array expectation.
void CheckScalarUnary(std::string func_name, std::shared_ptr<Array> input,
std::shared_ptr<Array> expected,
const FunctionOptions* options = nullptr);
// Unary overload taking a Scalar input and Scalar expectation.
void CheckScalarUnary(std::string func_name, std::shared_ptr<Scalar> input,
std::shared_ptr<Scalar> expected,
const FunctionOptions* options = nullptr);
// Binary overload taking two Scalar inputs and a Scalar expectation.
void CheckScalarBinary(std::string func_name, std::shared_ptr<Scalar> left_input,
std::shared_ptr<Scalar> right_input,
std::shared_ptr<Scalar> expected,
const FunctionOptions* options = nullptr);
// Binary overload taking two Array inputs and an Array expectation.
void CheckScalarBinary(std::string func_name, std::shared_ptr<Array> left_input,
std::shared_ptr<Array> right_input,
std::shared_ptr<Array> expected,
const FunctionOptions* options = nullptr);
// Vector-function helper: `input` is a Datum, `expected` an Array.
void CheckVectorUnary(std::string func_name, Datum input, std::shared_ptr<Array> expected,
const FunctionOptions* options = nullptr);
/// gtest typed-test type list: the two string types.
using StringTypes = ::testing::Types<StringType, LargeStringType>;

/// gtest typed-test type list: all binary and string types.
using BinaryTypes = ::testing::Types<BinaryType, LargeBinaryType,
                                     StringType, LargeStringType>;

/// Fixed seed so randomized test data is reproducible across runs.
static constexpr random::SeedType kRandomSeed = 0x0ff1ce;
// Invoke DoTestFunctor<T>::Test(type) once for each type listed below, in this
// fixed order: integer, floating-point, then date/time/timestamp/duration
// types, each paired with a concrete DataType instance.
// DoTestFunctor must be a class template with a static Test() that accepts the
// DataType produced by the factory calls below.
// NOTE(review): despite the name, temporal types are included as well;
// presumably because they are backed by C integer storage — confirm.
template <template <typename> class DoTestFunctor>
void TestRandomPrimitiveCTypes() {
// Signed integers.
DoTestFunctor<Int8Type>::Test(int8());
DoTestFunctor<Int16Type>::Test(int16());
DoTestFunctor<Int32Type>::Test(int32());
DoTestFunctor<Int64Type>::Test(int64());
// Unsigned integers.
DoTestFunctor<UInt8Type>::Test(uint8());
DoTestFunctor<UInt16Type>::Test(uint16());
DoTestFunctor<UInt32Type>::Test(uint32());
DoTestFunctor<UInt64Type>::Test(uint64());
// Floating point.
DoTestFunctor<FloatType>::Test(float32());
DoTestFunctor<DoubleType>::Test(float64());
// Temporal types; parametric ones are exercised with the units shown.
DoTestFunctor<Date32Type>::Test(date32());
DoTestFunctor<Date64Type>::Test(date64());
DoTestFunctor<Time32Type>::Test(time32(TimeUnit::SECOND));
DoTestFunctor<Time64Type>::Test(time64(TimeUnit::MICRO));
DoTestFunctor<TimestampType>::Test(timestamp(TimeUnit::SECOND));
DoTestFunctor<TimestampType>::Test(timestamp(TimeUnit::MICRO));
DoTestFunctor<DurationType>::Test(duration(TimeUnit::MILLI));
}
// Check that DispatchBest on the function named `func_name`, given `descrs`,
// yields the same Kernel that DispatchExact produces for `exact_descrs`.
// Declared here; the definition lives outside this header.
void CheckDispatchBest(std::string func_name, std::vector<ValueDescr> descrs,
std::vector<ValueDescr> exact_descrs);
// Check that the function named `func_name` fails to produce a Kernel for
// `descrs`. Declared here; the definition lives outside this header.
void CheckDispatchFails(std::string func_name, std::vector<ValueDescr> descrs);
} // namespace compute
} // namespace arrow