// blob: a5ef9d44e1878a810a65f0f239c7940e1accac8a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <memory>
#include <string>
#include <vector>
#include <gtest/gtest.h>
#include "arrow/compute/kernel.h"
#include "arrow/status.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/type.h"
#include "arrow/util/key_value_metadata.h"
namespace arrow {
namespace compute {
// ----------------------------------------------------------------------
// TypeMatcher
TEST(TypeMatcher, SameTypeId) {
  // Matcher keyed on the DECIMAL type id
  std::shared_ptr<TypeMatcher> decimal_matcher = match::SameTypeId(Type::DECIMAL);

  // Matching: a decimal type is accepted, an integer type is rejected
  ASSERT_TRUE(decimal_matcher->Matches(*decimal(12, 2)));
  ASSERT_FALSE(decimal_matcher->Matches(*int8()));

  // String form
  ASSERT_EQ("Type::DECIMAL128", decimal_matcher->ToString());

  // Equality: with itself, with a fresh matcher for the same id, and with a
  // matcher for a different id
  ASSERT_TRUE(decimal_matcher->Equals(*decimal_matcher));
  ASSERT_TRUE(decimal_matcher->Equals(*match::SameTypeId(Type::DECIMAL)));
  ASSERT_FALSE(decimal_matcher->Equals(*match::SameTypeId(Type::TIMESTAMP)));
}
TEST(TypeMatcher, TimestampTypeUnit) {
  auto ts_milli = match::TimestampTypeUnit(TimeUnit::MILLI);
  auto time32_milli = match::Time32TypeUnit(TimeUnit::MILLI);

  // The timestamp matcher accepts millisecond timestamps (with or without a
  // time zone) and rejects other units and other temporal types
  ASSERT_TRUE(ts_milli->Matches(*timestamp(TimeUnit::MILLI)));
  ASSERT_TRUE(ts_milli->Matches(*timestamp(TimeUnit::MILLI, "utc")));
  ASSERT_FALSE(ts_milli->Matches(*timestamp(TimeUnit::SECOND)));
  ASSERT_FALSE(ts_milli->Matches(*time32(TimeUnit::MILLI)));

  // The time32 matcher accepts a time32 type of the same unit
  ASSERT_TRUE(time32_milli->Matches(*time32(TimeUnit::MILLI)));

  // String form for every time unit
  auto repr_for = [](TimeUnit::type unit) {
    return match::TimestampTypeUnit(unit)->ToString();
  };
  ASSERT_EQ("timestamp(s)", repr_for(TimeUnit::SECOND));
  ASSERT_EQ("timestamp(ms)", repr_for(TimeUnit::MILLI));
  ASSERT_EQ("timestamp(us)", repr_for(TimeUnit::MICRO));
  ASSERT_EQ("timestamp(ns)", repr_for(TimeUnit::NANO));

  // Equality: same matcher, same unit, different unit, different matcher kind
  ASSERT_TRUE(ts_milli->Equals(*ts_milli));
  ASSERT_TRUE(ts_milli->Equals(*match::TimestampTypeUnit(TimeUnit::MILLI)));
  ASSERT_FALSE(ts_milli->Equals(*match::TimestampTypeUnit(TimeUnit::MICRO)));
  ASSERT_FALSE(ts_milli->Equals(*match::Time32TypeUnit(TimeUnit::MILLI)));
}
// ----------------------------------------------------------------------
// InputType
TEST(InputType, AnyTypeConstructor) {
  // Default construction yields ANY_TYPE with shape ANY
  InputType any_input;
  ASSERT_EQ(InputType::ANY_TYPE, any_input.kind());
  ASSERT_EQ(ValueDescr::ANY, any_input.shape());

  // Shape-only construction, assigned over the existing object
  any_input = InputType(ValueDescr::SCALAR);
  ASSERT_EQ(ValueDescr::SCALAR, any_input.shape());

  any_input = InputType(ValueDescr::ARRAY);
  ASSERT_EQ(ValueDescr::ARRAY, any_input.shape());
}
TEST(InputType, Constructors) {
  // ---- Construction from an exact data type ----
  InputType exact_any(int8());
  ASSERT_EQ(InputType::EXACT_TYPE, exact_any.kind());
  ASSERT_EQ(ValueDescr::ANY, exact_any.shape());
  AssertTypeEqual(*int8(), *exact_any.type());

  // Implicit conversion from a data type
  InputType exact_implicit = int8();
  ASSERT_TRUE(exact_any.Equals(exact_implicit));

  // Explicit shapes
  InputType exact_array(int8(), ValueDescr::ARRAY);
  ASSERT_EQ(ValueDescr::ARRAY, exact_array.shape());
  InputType exact_scalar(int8(), ValueDescr::SCALAR);
  ASSERT_EQ(ValueDescr::SCALAR, exact_scalar.shape());

  // ---- Construction from a type id ----
  InputType id_any(Type::DECIMAL);
  ASSERT_EQ(InputType::USE_TYPE_MATCHER, id_any.kind());
  ASSERT_EQ("any[Type::DECIMAL128]", id_any.ToString());
  ASSERT_TRUE(id_any.type_matcher().Matches(*decimal(12, 2)));
  ASSERT_FALSE(id_any.type_matcher().Matches(*int16()));

  InputType id_array(Type::DECIMAL, ValueDescr::ARRAY);
  ASSERT_EQ(ValueDescr::ARRAY, id_array.shape());
  InputType id_scalar(Type::DECIMAL, ValueDescr::SCALAR);
  ASSERT_EQ(ValueDescr::SCALAR, id_scalar.shape());

  // ---- Implicit construction as vector elements ----
  std::vector<InputType> from_list = {int8(), InputType(Type::DECIMAL)};
  ASSERT_TRUE(from_list[0].Equals(exact_any));
  ASSERT_TRUE(from_list[1].Equals(id_any));

  // ---- Copy construction ----
  InputType exact_copy = exact_any;
  InputType id_copy = id_any;
  ASSERT_TRUE(exact_copy.Equals(exact_any));
  ASSERT_TRUE(id_copy.Equals(id_any));

  // ---- Move construction (the copies above are the move sources) ----
  InputType exact_moved = std::move(exact_copy);
  InputType id_moved = std::move(id_copy);
  ASSERT_TRUE(exact_moved.Equals(exact_any));
  ASSERT_TRUE(id_moved.Equals(id_any));

  // ---- String form: "<shape>[<type or matcher>]" ----
  ASSERT_EQ("any[int8]", exact_any.ToString());
  ASSERT_EQ("array[int8]", exact_array.ToString());
  ASSERT_EQ("scalar[int8]", exact_scalar.ToString());

  ASSERT_EQ("any[Type::DECIMAL128]", id_any.ToString());
  ASSERT_EQ("array[Type::DECIMAL128]", id_array.ToString());
  ASSERT_EQ("scalar[Type::DECIMAL128]", id_scalar.ToString());

  // Custom type matcher
  InputType ts_micro(match::TimestampTypeUnit(TimeUnit::MICRO));
  ASSERT_EQ("any[timestamp(us)]", ts_micro.ToString());

  // Any-type inputs, default-constructed and with each shape
  InputType wildcard_default;
  InputType wildcard_any(ValueDescr::ANY);
  InputType wildcard_array(ValueDescr::ARRAY);
  InputType wildcard_scalar(ValueDescr::SCALAR);
  ASSERT_EQ("any[any]", wildcard_default.ToString());
  ASSERT_EQ("any[any]", wildcard_any.ToString());
  ASSERT_EQ("array[any]", wildcard_array.ToString());
  ASSERT_EQ("scalar[any]", wildcard_scalar.ToString());
}
TEST(InputType, Equals) {
  // Exercises InputType equality across exact types, type-id based inputs and
  // the ANY / ARRAY / SCALAR shapes.

  // Exact-type inputs
  InputType t1 = int8();
  InputType t2 = int8();
  InputType t3(int8(), ValueDescr::ARRAY);
  InputType t3_i32(int32(), ValueDescr::ARRAY);
  InputType t3_scalar(int8(), ValueDescr::SCALAR);
  InputType t4(int8(), ValueDescr::ARRAY);
  InputType t4_i32(int32(), ValueDescr::ARRAY);

  // Type-id based inputs
  InputType t5(Type::DECIMAL);
  InputType t6(Type::DECIMAL);
  InputType t7(Type::DECIMAL, ValueDescr::SCALAR);
  InputType t7_i32(Type::INT32, ValueDescr::SCALAR);
  InputType t8(Type::DECIMAL, ValueDescr::SCALAR);
  InputType t8_i32(Type::INT32, ValueDescr::SCALAR);

  ASSERT_TRUE(t1.Equals(t2));
  ASSERT_EQ(t1, t2);

  // ANY vs ARRAY
  ASSERT_NE(t1, t3);
  ASSERT_EQ(t3, t4);

  // both ARRAY, but different type
  ASSERT_NE(t3, t3_i32);

  // ARRAY vs SCALAR
  ASSERT_NE(t3, t3_scalar);

  ASSERT_EQ(t3_i32, t4_i32);

  // Exact type vs type id
  ASSERT_FALSE(t1.Equals(t5));
  ASSERT_NE(t1, t5);

  // Type-id inputs: same id and shape are equal; differing shape or id are not
  ASSERT_EQ(t5, t5);
  ASSERT_EQ(t5, t6);
  ASSERT_NE(t5, t7);
  ASSERT_EQ(t7, t8);
  ASSERT_NE(t7, t7_i32);
  ASSERT_EQ(t7_i32, t8_i32);

  // NOTE: For the time being, we treat int32() and Type::INT32 as being
  // different. This could obviously be fixed later to make these equivalent
  ASSERT_NE(InputType(int32()), InputType(Type::INT32));

  // Check that field metadata is excluded from equality checks
  InputType t9 = list(
      field("item", utf8(), /*nullable=*/true, key_value_metadata({"foo"}, {"bar"})));
  InputType t10 = list(field("item", utf8()));
  ASSERT_TRUE(t9.Equals(t10));
}
TEST(InputType, Hash) {
  // Any-type inputs, one per shape
  InputType wildcard;
  InputType wildcard_scalar(ValueDescr::SCALAR);
  InputType wildcard_array(ValueDescr::ARRAY);

  // An exact-type input and a type-id based input
  InputType exact_int8 = int8();
  InputType by_decimal_id(Type::DECIMAL);

  // Hashing the same object twice must give the same value
  ASSERT_EQ(wildcard.Hash(), wildcard.Hash());
  ASSERT_EQ(exact_int8.Hash(), exact_int8.Hash());
  ASSERT_EQ(by_decimal_id.Hash(), by_decimal_id.Hash());

  // The shape must be incorporated into the hash
  ASSERT_NE(wildcard.Hash(), wildcard_scalar.Hash());
  ASSERT_NE(wildcard.Hash(), wildcard_array.Hash());
  ASSERT_NE(wildcard_scalar.Hash(), wildcard_array.Hash());

  // The type / matcher must be incorporated into the hash
  ASSERT_NE(wildcard.Hash(), exact_int8.Hash());
  ASSERT_NE(wildcard.Hash(), by_decimal_id.Hash());
  ASSERT_NE(exact_int8.Hash(), by_decimal_id.Hash());
}
TEST(InputType, Matches) {
  // Exact type with ANY shape: every shape of int8 matches, other types don't
  InputType exact_int8 = int8();
  ASSERT_TRUE(exact_int8.Matches(ValueDescr::Scalar(int8())));
  ASSERT_TRUE(exact_int8.Matches(ValueDescr::Array(int8())));
  ASSERT_TRUE(exact_int8.Matches(ValueDescr::Any(int8())));
  ASSERT_FALSE(exact_int8.Matches(ValueDescr::Any(int16())));

  // Type id with ANY shape
  InputType by_decimal_id(Type::DECIMAL);
  ASSERT_TRUE(by_decimal_id.Matches(ValueDescr::Scalar(decimal(12, 2))));
  ASSERT_TRUE(by_decimal_id.Matches(ValueDescr::Array(decimal(12, 2))));
  ASSERT_FALSE(by_decimal_id.Matches(ValueDescr::Any(float64())));

  // Exact type restricted to SCALAR shape: both type and shape must agree
  InputType scalar_int64(int64(), ValueDescr::SCALAR);
  ASSERT_FALSE(scalar_int64.Matches(ValueDescr::Array(int64())));
  ASSERT_TRUE(scalar_int64.Matches(ValueDescr::Scalar(int64())));
  ASSERT_FALSE(scalar_int64.Matches(ValueDescr::Scalar(int32())));
  ASSERT_FALSE(scalar_int64.Matches(ValueDescr::Any(int64())));
}
// ----------------------------------------------------------------------
// OutputType
TEST(OutputType, Constructors) {
  // Fixed output type, built implicitly from a data type
  OutputType fixed_out = int8();
  ASSERT_EQ(OutputType::FIXED, fixed_out.kind());
  AssertTypeEqual(*int8(), *fixed_out.type());

  // Computed output type, built from a resolver callable
  auto int32_resolver = [](KernelContext*,
                           const std::vector<ValueDescr>& descrs) -> Result<ValueDescr> {
    return ValueDescr(int32(), GetBroadcastShape(descrs));
  };
  OutputType computed_out(int32_resolver);
  ASSERT_EQ(OutputType::COMPUTED, computed_out.kind());
  ASSERT_OK_AND_ASSIGN(ValueDescr from_original, computed_out.Resolve(nullptr, {}));
  ASSERT_EQ(ValueDescr::Scalar(int32()), from_original);

  // Copy construction keeps kind, type and resolver
  OutputType fixed_copy = fixed_out;
  ASSERT_EQ(OutputType::FIXED, fixed_copy.kind());
  AssertTypeEqual(*fixed_out.type(), *fixed_copy.type());

  OutputType computed_copy = computed_out;
  ASSERT_EQ(OutputType::COMPUTED, computed_copy.kind());
  ASSERT_OK_AND_ASSIGN(ValueDescr from_copy, computed_copy.Resolve(nullptr, {}));
  ASSERT_EQ(ValueDescr::Scalar(int32()), from_copy);

  // Move construction; fixed_out and computed_copy are not used afterwards
  OutputType fixed_moved = std::move(fixed_out);
  ASSERT_EQ(OutputType::FIXED, fixed_moved.kind());
  AssertTypeEqual(*int8(), *fixed_moved.type());

  OutputType computed_moved = std::move(computed_copy);
  ASSERT_EQ(OutputType::COMPUTED, computed_moved.kind());
  ASSERT_OK_AND_ASSIGN(ValueDescr from_moved, computed_moved.Resolve(nullptr, {}));
  ASSERT_EQ(ValueDescr::Scalar(int32()), from_moved);

  // String form. The fixed type is checked through its copy because the
  // original has been moved from.
  ASSERT_EQ("int8", fixed_copy.ToString());
  ASSERT_EQ("computed", computed_out.ToString());
}
TEST(OutputType, Resolve) {
  // ---- FIXED kind: the type is constant, the shape follows the arguments ----
  OutputType fixed_int32(int32());

  // No arguments -> scalar
  ASSERT_OK_AND_ASSIGN(ValueDescr resolved, fixed_int32.Resolve(nullptr, {}));
  ASSERT_EQ(ValueDescr::Scalar(int32()), resolved);

  // Only scalar arguments -> scalar
  ASSERT_OK_AND_ASSIGN(
      resolved, fixed_int32.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR)}));
  ASSERT_EQ(ValueDescr::Scalar(int32()), resolved);

  // A scalar and an array argument -> array
  ASSERT_OK_AND_ASSIGN(
      resolved, fixed_int32.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR),
                                              ValueDescr(int8(), ValueDescr::ARRAY)}));
  ASSERT_EQ(ValueDescr::Array(int32()), resolved);

  // ---- Resolver that echoes the first argument's type ----
  OutputType first_arg_type([](KernelContext*, const std::vector<ValueDescr>& descrs) {
    return ValueDescr(descrs[0].type, GetBroadcastShape(descrs));
  });
  ASSERT_OK_AND_ASSIGN(resolved,
                       first_arg_type.Resolve(nullptr, {ValueDescr::Array(utf8())}));
  ASSERT_EQ(ValueDescr::Array(utf8()), resolved);

  // ---- Resolver that reports an error ----
  OutputType failing(
      [](KernelContext*, const std::vector<ValueDescr>& descrs) -> Result<ValueDescr> {
        // NB: checking the value types versus the function arity should be
        // validated elsewhere, so this is just for illustration purposes
        if (descrs.size() == 0) {
          return Status::Invalid("Need at least one argument");
        }
        return ValueDescr(descrs[0]);
      });
  ASSERT_RAISES(Invalid, failing.Resolve(nullptr, {}));

  // ---- Resolver that returns ValueDescr::ANY and needs type promotion ----
  OutputType shapeless_int32(
      [](KernelContext*, const std::vector<ValueDescr>&) -> Result<ValueDescr> {
        return int32();
      });
  ASSERT_OK_AND_ASSIGN(resolved,
                       shapeless_int32.Resolve(nullptr, {ValueDescr::Array(int8())}));
  ASSERT_EQ(ValueDescr::Array(int32()), resolved);
  ASSERT_OK_AND_ASSIGN(resolved,
                       shapeless_int32.Resolve(nullptr, {ValueDescr::Scalar(int8())}));
  ASSERT_EQ(ValueDescr::Scalar(int32()), resolved);
}
TEST(OutputType, ResolveDescr) {
  // Output types built from a full descriptor (type + shape)
  ValueDescr scalar_descr = ValueDescr::Scalar(int32());
  ValueDescr array_descr = ValueDescr::Array(int32());

  OutputType scalar_out(scalar_descr);
  OutputType array_out(array_descr);

  // The shape is taken from the descriptor
  ASSERT_EQ(ValueDescr::SCALAR, scalar_out.shape());
  ASSERT_EQ(ValueDescr::ARRAY, array_out.shape());

  // Resolving with no arguments gives back the original descriptor
  {
    ASSERT_OK_AND_ASSIGN(ValueDescr resolved, scalar_out.Resolve(nullptr, {}));
    ASSERT_EQ(scalar_descr, resolved);
  }
  {
    ASSERT_OK_AND_ASSIGN(ValueDescr resolved, array_out.Resolve(nullptr, {}));
    ASSERT_EQ(array_descr, resolved);
  }
}
// ----------------------------------------------------------------------
// KernelSignature
TEST(KernelSignature, Basics) {
  // Signature under test: (any[int8], scalar[decimal]) -> utf8
  std::vector<InputType> params({int8(), InputType(Type::DECIMAL, ValueDescr::SCALAR)});
  OutputType result_type(utf8());
  KernelSignature signature(params, result_type);

  ASSERT_EQ(2, signature.in_types().size());

  // First parameter: exact int8, any shape
  ASSERT_TRUE(signature.in_types()[0].type()->Equals(*int8()));
  ASSERT_TRUE(signature.in_types()[0].Matches(ValueDescr::Scalar(int8())));
  ASSERT_TRUE(signature.in_types()[0].Matches(ValueDescr::Array(int8())));

  // Second parameter: decimal, scalar only
  ASSERT_TRUE(signature.in_types()[1].Matches(ValueDescr::Scalar(decimal(12, 2))));
  ASSERT_FALSE(signature.in_types()[1].Matches(ValueDescr::Array(decimal(12, 2))));
}
TEST(KernelSignature, Equals) {
  // No parameters
  KernelSignature nullary({}, utf8());
  KernelSignature nullary_twin({}, utf8());

  // One parameter; the second differs only in output type, which doesn't
  // matter (for now)
  KernelSignature unary_utf8({int8()}, utf8());
  KernelSignature unary_int32({int8()}, int32());

  // Two and three parameters
  KernelSignature binary({int8(), int16()}, utf8());
  KernelSignature binary_twin({int8(), int16()}, utf8());
  KernelSignature ternary({int8(), int16(), int32()}, utf8());

  // Same parameter type, different shape
  KernelSignature unary_scalar({ValueDescr::Scalar(int8())}, utf8());
  KernelSignature unary_array({ValueDescr::Array(int8())}, utf8());

  ASSERT_EQ(nullary, nullary);
  ASSERT_EQ(unary_utf8, unary_int32);
  ASSERT_NE(unary_int32, binary);

  // Distinct objects describing the same signature
  ASSERT_EQ(nullary, nullary_twin);
  ASSERT_EQ(binary, binary_twin);

  // First two parameters agree, but there is an extra third one
  ASSERT_NE(binary, ternary);

  // Shapes differ
  ASSERT_NE(unary_scalar, unary_array);
}
TEST(KernelSignature, VarArgsEquals) {
  KernelSignature varargs_a({int8()}, utf8(), /*is_varargs=*/true);
  KernelSignature varargs_b({int8()}, utf8(), /*is_varargs=*/true);
  KernelSignature fixed_arity({int8()}, utf8());

  // Two varargs signatures with identical types are equal
  ASSERT_EQ(varargs_a, varargs_b);
  // The varargs flag takes part in the comparison
  ASSERT_NE(varargs_b, fixed_arity);
}
TEST(KernelSignature, Hash) {
  // Basic checks that hashing is deterministic and that every input argument
  // contributes to the result
  KernelSignature nullary({}, utf8());
  KernelSignature unary({int8()}, utf8());
  KernelSignature binary({int8(), int32()}, utf8());

  // Deterministic
  ASSERT_EQ(nullary.Hash(), nullary.Hash());
  ASSERT_EQ(unary.Hash(), unary.Hash());

  // Adding an argument changes the hash
  ASSERT_NE(nullary.Hash(), unary.Hash());
  ASSERT_NE(unary.Hash(), binary.Hash());
}
TEST(KernelSignature, MatchesInputs) {
  // Signature: () -> boolean
  KernelSignature nullary({}, boolean());
  ASSERT_TRUE(nullary.MatchesInputs({}));
  ASSERT_FALSE(nullary.MatchesInputs({int8()}));

  // Signature: (any[int8], any[decimal]) -> boolean
  KernelSignature any_shaped({int8(), InputType(Type::DECIMAL)}, boolean());

  // Wrong number of arguments
  ASSERT_FALSE(any_shaped.MatchesInputs({}));
  ASSERT_FALSE(any_shaped.MatchesInputs({int8()}));

  // Right types, in every shape
  ASSERT_TRUE(any_shaped.MatchesInputs({int8(), decimal(12, 2)}));
  ASSERT_TRUE(any_shaped.MatchesInputs(
      {ValueDescr::Scalar(int8()), ValueDescr::Scalar(decimal(12, 2))}));
  ASSERT_TRUE(any_shaped.MatchesInputs(
      {ValueDescr::Array(int8()), ValueDescr::Array(decimal(12, 2))}));

  // Signature: (scalar[int8], array[int32]) -> boolean
  KernelSignature shape_bound({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())},
                              boolean());
  ASSERT_FALSE(shape_bound.MatchesInputs({}));

  // Unqualified, these are ANY type and do not match because the kernel
  // requires a scalar and an array
  ASSERT_FALSE(shape_bound.MatchesInputs({int8(), int32()}));

  // Shapes exactly as declared
  ASSERT_TRUE(shape_bound.MatchesInputs(
      {ValueDescr::Scalar(int8()), ValueDescr::Array(int32())}));

  // First argument is an array where a scalar is required
  ASSERT_FALSE(shape_bound.MatchesInputs(
      {ValueDescr::Array(int8()), ValueDescr::Array(int32())}));
}
TEST(KernelSignature, VarArgsMatchesInputs) {
  KernelSignature varargs_sig({int8()}, utf8(), /*is_varargs=*/true);

  // A single int8 argument matches
  std::vector<ValueDescr> call_args = {int8()};
  ASSERT_TRUE(varargs_sig.MatchesInputs(call_args));

  // More int8 arguments, of mixed shapes, still match
  call_args.push_back(ValueDescr::Scalar(int8()));
  call_args.push_back(ValueDescr::Array(int8()));
  ASSERT_TRUE(varargs_sig.MatchesInputs(call_args));

  // An argument of another type makes the whole list fail to match
  call_args.push_back(int32());
  ASSERT_FALSE(varargs_sig.MatchesInputs(call_args));
}
TEST(KernelSignature, ToString) {
  // Fixed output type, with one parameter per kind of input
  std::vector<InputType> params = {InputType(int8(), ValueDescr::SCALAR),
                                   InputType(Type::DECIMAL, ValueDescr::ARRAY),
                                   InputType(utf8())};
  KernelSignature fixed_sig(params, utf8());
  ASSERT_EQ("(scalar[int8], array[Type::DECIMAL128], any[string]) -> string",
            fixed_sig.ToString());

  // Computed output type; the resolver is never invoked by ToString here, it
  // only needs to exist
  OutputType computed_out([](KernelContext*, const std::vector<ValueDescr>&) {
    return Status::Invalid("NYI");
  });
  KernelSignature computed_sig({int8(), InputType(Type::DECIMAL)}, computed_out);
  ASSERT_EQ("(any[int8], any[Type::DECIMAL128]) -> computed", computed_sig.ToString());
}
TEST(KernelSignature, VarArgsToString) {
  // A varargs signature prints its parameter inside "varargs[...]"
  KernelSignature varargs_sig({int8()}, utf8(), /*is_varargs=*/true);
  ASSERT_EQ("varargs[any[int8]] -> string", varargs_sig.ToString());
}
} // namespace compute
} // namespace arrow