| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include <memory> |
| #include <string> |
| #include <vector> |
| |
| #include <gtest/gtest.h> |
| |
| #include "arrow/compute/kernel.h" |
| #include "arrow/status.h" |
| #include "arrow/testing/gtest_util.h" |
| #include "arrow/type.h" |
| #include "arrow/util/key_value_metadata.h" |
| |
| namespace arrow { |
| namespace compute { |
| |
| // ---------------------------------------------------------------------- |
| // TypeMatcher |
| |
| TEST(TypeMatcher, SameTypeId) { |
| std::shared_ptr<TypeMatcher> matcher = match::SameTypeId(Type::DECIMAL); |
| ASSERT_TRUE(matcher->Matches(*decimal(12, 2))); |
| ASSERT_FALSE(matcher->Matches(*int8())); |
| |
| ASSERT_EQ("Type::DECIMAL128", matcher->ToString()); |
| |
| ASSERT_TRUE(matcher->Equals(*matcher)); |
| ASSERT_TRUE(matcher->Equals(*match::SameTypeId(Type::DECIMAL))); |
| ASSERT_FALSE(matcher->Equals(*match::SameTypeId(Type::TIMESTAMP))); |
| } |
| |
| TEST(TypeMatcher, TimestampTypeUnit) { |
| auto matcher = match::TimestampTypeUnit(TimeUnit::MILLI); |
| auto matcher2 = match::Time32TypeUnit(TimeUnit::MILLI); |
| |
| ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI))); |
| ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI, "utc"))); |
| ASSERT_FALSE(matcher->Matches(*timestamp(TimeUnit::SECOND))); |
| ASSERT_FALSE(matcher->Matches(*time32(TimeUnit::MILLI))); |
| ASSERT_TRUE(matcher2->Matches(*time32(TimeUnit::MILLI))); |
| |
| // Check ToString representation |
| ASSERT_EQ("timestamp(s)", match::TimestampTypeUnit(TimeUnit::SECOND)->ToString()); |
| ASSERT_EQ("timestamp(ms)", match::TimestampTypeUnit(TimeUnit::MILLI)->ToString()); |
| ASSERT_EQ("timestamp(us)", match::TimestampTypeUnit(TimeUnit::MICRO)->ToString()); |
| ASSERT_EQ("timestamp(ns)", match::TimestampTypeUnit(TimeUnit::NANO)->ToString()); |
| |
| // Equals implementation |
| ASSERT_TRUE(matcher->Equals(*matcher)); |
| ASSERT_TRUE(matcher->Equals(*match::TimestampTypeUnit(TimeUnit::MILLI))); |
| ASSERT_FALSE(matcher->Equals(*match::TimestampTypeUnit(TimeUnit::MICRO))); |
| ASSERT_FALSE(matcher->Equals(*match::Time32TypeUnit(TimeUnit::MILLI))); |
| } |
| |
| // ---------------------------------------------------------------------- |
| // InputType |
| |
| TEST(InputType, AnyTypeConstructor) { |
| // Check the ANY_TYPE ctors |
| InputType ty; |
| ASSERT_EQ(InputType::ANY_TYPE, ty.kind()); |
| ASSERT_EQ(ValueDescr::ANY, ty.shape()); |
| |
| ty = InputType(ValueDescr::SCALAR); |
| ASSERT_EQ(ValueDescr::SCALAR, ty.shape()); |
| |
| ty = InputType(ValueDescr::ARRAY); |
| ASSERT_EQ(ValueDescr::ARRAY, ty.shape()); |
| } |
| |
| TEST(InputType, Constructors) { |
| // Exact type constructor |
| InputType ty1(int8()); |
| ASSERT_EQ(InputType::EXACT_TYPE, ty1.kind()); |
| ASSERT_EQ(ValueDescr::ANY, ty1.shape()); |
| AssertTypeEqual(*int8(), *ty1.type()); |
| |
| InputType ty1_implicit = int8(); |
| ASSERT_TRUE(ty1.Equals(ty1_implicit)); |
| |
| InputType ty1_array(int8(), ValueDescr::ARRAY); |
| ASSERT_EQ(ValueDescr::ARRAY, ty1_array.shape()); |
| |
| InputType ty1_scalar(int8(), ValueDescr::SCALAR); |
| ASSERT_EQ(ValueDescr::SCALAR, ty1_scalar.shape()); |
| |
| // Same type id constructor |
| InputType ty2(Type::DECIMAL); |
| ASSERT_EQ(InputType::USE_TYPE_MATCHER, ty2.kind()); |
| ASSERT_EQ("any[Type::DECIMAL128]", ty2.ToString()); |
| ASSERT_TRUE(ty2.type_matcher().Matches(*decimal(12, 2))); |
| ASSERT_FALSE(ty2.type_matcher().Matches(*int16())); |
| |
| InputType ty2_array(Type::DECIMAL, ValueDescr::ARRAY); |
| ASSERT_EQ(ValueDescr::ARRAY, ty2_array.shape()); |
| |
| InputType ty2_scalar(Type::DECIMAL, ValueDescr::SCALAR); |
| ASSERT_EQ(ValueDescr::SCALAR, ty2_scalar.shape()); |
| |
| // Implicit construction in a vector |
| std::vector<InputType> types = {int8(), InputType(Type::DECIMAL)}; |
| ASSERT_TRUE(types[0].Equals(ty1)); |
| ASSERT_TRUE(types[1].Equals(ty2)); |
| |
| // Copy constructor |
| InputType ty3 = ty1; |
| InputType ty4 = ty2; |
| ASSERT_TRUE(ty3.Equals(ty1)); |
| ASSERT_TRUE(ty4.Equals(ty2)); |
| |
| // Move constructor |
| InputType ty5 = std::move(ty3); |
| InputType ty6 = std::move(ty4); |
| ASSERT_TRUE(ty5.Equals(ty1)); |
| ASSERT_TRUE(ty6.Equals(ty2)); |
| |
| // ToString |
| ASSERT_EQ("any[int8]", ty1.ToString()); |
| ASSERT_EQ("array[int8]", ty1_array.ToString()); |
| ASSERT_EQ("scalar[int8]", ty1_scalar.ToString()); |
| |
| ASSERT_EQ("any[Type::DECIMAL128]", ty2.ToString()); |
| ASSERT_EQ("array[Type::DECIMAL128]", ty2_array.ToString()); |
| ASSERT_EQ("scalar[Type::DECIMAL128]", ty2_scalar.ToString()); |
| |
| InputType ty7(match::TimestampTypeUnit(TimeUnit::MICRO)); |
| ASSERT_EQ("any[timestamp(us)]", ty7.ToString()); |
| |
| InputType ty8; |
| InputType ty9(ValueDescr::ANY); |
| InputType ty10(ValueDescr::ARRAY); |
| InputType ty11(ValueDescr::SCALAR); |
| ASSERT_EQ("any[any]", ty8.ToString()); |
| ASSERT_EQ("any[any]", ty9.ToString()); |
| ASSERT_EQ("array[any]", ty10.ToString()); |
| ASSERT_EQ("scalar[any]", ty11.ToString()); |
| } |
| |
| TEST(InputType, Equals) { |
| InputType t1 = int8(); |
| InputType t2 = int8(); |
| InputType t3(int8(), ValueDescr::ARRAY); |
| InputType t3_i32(int32(), ValueDescr::ARRAY); |
| InputType t3_scalar(int8(), ValueDescr::SCALAR); |
| InputType t4(int8(), ValueDescr::ARRAY); |
| InputType t4_i32(int32(), ValueDescr::ARRAY); |
| |
| InputType t5(Type::DECIMAL); |
| InputType t6(Type::DECIMAL); |
| InputType t7(Type::DECIMAL, ValueDescr::SCALAR); |
| InputType t7_i32(Type::INT32, ValueDescr::SCALAR); |
| InputType t8(Type::DECIMAL, ValueDescr::SCALAR); |
| InputType t8_i32(Type::INT32, ValueDescr::SCALAR); |
| |
| ASSERT_TRUE(t1.Equals(t2)); |
| ASSERT_EQ(t1, t2); |
| |
| // ANY vs SCALAR |
| ASSERT_NE(t1, t3); |
| |
| ASSERT_EQ(t3, t4); |
| |
| // both ARRAY, but different type |
| ASSERT_NE(t3, t3_i32); |
| |
| // ARRAY vs SCALAR |
| ASSERT_NE(t3, t3_scalar); |
| |
| ASSERT_EQ(t3_i32, t4_i32); |
| |
| ASSERT_FALSE(t1.Equals(t5)); |
| ASSERT_NE(t1, t5); |
| |
| ASSERT_EQ(t5, t5); |
| ASSERT_EQ(t5, t6); |
| ASSERT_NE(t5, t7); |
| ASSERT_EQ(t7, t8); |
| ASSERT_EQ(t7, t8); |
| ASSERT_NE(t7, t7_i32); |
| ASSERT_EQ(t7_i32, t8_i32); |
| |
| // NOTE: For the time being, we treat int32() and Type::INT32 as being |
| // different. This could obviously be fixed later to make these equivalent |
| ASSERT_NE(InputType(int8()), InputType(Type::INT32)); |
| |
| // Check that field metadata excluded from equality checks |
| InputType t9 = list( |
| field("item", utf8(), /*nullable=*/true, key_value_metadata({"foo"}, {"bar"}))); |
| InputType t10 = list(field("item", utf8())); |
| ASSERT_TRUE(t9.Equals(t10)); |
| } |
| |
| TEST(InputType, Hash) { |
| InputType t0; |
| InputType t0_scalar(ValueDescr::SCALAR); |
| InputType t0_array(ValueDescr::ARRAY); |
| |
| InputType t1 = int8(); |
| InputType t2(Type::DECIMAL); |
| |
| // These checks try to determine first of all whether Hash always returns the |
| // same value, and whether the elements of the type are all incorporated into |
| // the Hash |
| ASSERT_EQ(t0.Hash(), t0.Hash()); |
| ASSERT_NE(t0.Hash(), t0_scalar.Hash()); |
| ASSERT_NE(t0.Hash(), t0_array.Hash()); |
| ASSERT_NE(t0_scalar.Hash(), t0_array.Hash()); |
| |
| ASSERT_EQ(t1.Hash(), t1.Hash()); |
| ASSERT_EQ(t2.Hash(), t2.Hash()); |
| |
| ASSERT_NE(t0.Hash(), t1.Hash()); |
| ASSERT_NE(t0.Hash(), t2.Hash()); |
| ASSERT_NE(t1.Hash(), t2.Hash()); |
| } |
| |
| TEST(InputType, Matches) { |
| InputType ty1 = int8(); |
| |
| ASSERT_TRUE(ty1.Matches(ValueDescr::Scalar(int8()))); |
| ASSERT_TRUE(ty1.Matches(ValueDescr::Array(int8()))); |
| ASSERT_TRUE(ty1.Matches(ValueDescr::Any(int8()))); |
| ASSERT_FALSE(ty1.Matches(ValueDescr::Any(int16()))); |
| |
| InputType ty2(Type::DECIMAL); |
| ASSERT_TRUE(ty2.Matches(ValueDescr::Scalar(decimal(12, 2)))); |
| ASSERT_TRUE(ty2.Matches(ValueDescr::Array(decimal(12, 2)))); |
| ASSERT_FALSE(ty2.Matches(ValueDescr::Any(float64()))); |
| |
| InputType ty3(int64(), ValueDescr::SCALAR); |
| ASSERT_FALSE(ty3.Matches(ValueDescr::Array(int64()))); |
| ASSERT_TRUE(ty3.Matches(ValueDescr::Scalar(int64()))); |
| ASSERT_FALSE(ty3.Matches(ValueDescr::Scalar(int32()))); |
| ASSERT_FALSE(ty3.Matches(ValueDescr::Any(int64()))); |
| } |
| |
| // ---------------------------------------------------------------------- |
| // OutputType |
| |
| TEST(OutputType, Constructors) { |
| OutputType ty1 = int8(); |
| ASSERT_EQ(OutputType::FIXED, ty1.kind()); |
| AssertTypeEqual(*int8(), *ty1.type()); |
| |
| auto DummyResolver = [](KernelContext*, |
| const std::vector<ValueDescr>& args) -> Result<ValueDescr> { |
| return ValueDescr(int32(), GetBroadcastShape(args)); |
| }; |
| OutputType ty2(DummyResolver); |
| ASSERT_EQ(OutputType::COMPUTED, ty2.kind()); |
| |
| ASSERT_OK_AND_ASSIGN(ValueDescr out_descr2, ty2.Resolve(nullptr, {})); |
| ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr2); |
| |
| // Copy constructor |
| OutputType ty3 = ty1; |
| ASSERT_EQ(OutputType::FIXED, ty3.kind()); |
| AssertTypeEqual(*ty1.type(), *ty3.type()); |
| |
| OutputType ty4 = ty2; |
| ASSERT_EQ(OutputType::COMPUTED, ty4.kind()); |
| ASSERT_OK_AND_ASSIGN(ValueDescr out_descr4, ty4.Resolve(nullptr, {})); |
| ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr4); |
| |
| // Move constructor |
| OutputType ty5 = std::move(ty1); |
| ASSERT_EQ(OutputType::FIXED, ty5.kind()); |
| AssertTypeEqual(*int8(), *ty5.type()); |
| |
| OutputType ty6 = std::move(ty4); |
| ASSERT_EQ(OutputType::COMPUTED, ty6.kind()); |
| ASSERT_OK_AND_ASSIGN(ValueDescr out_descr6, ty6.Resolve(nullptr, {})); |
| ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr6); |
| |
| // ToString |
| |
| // ty1 was copied to ty3 |
| ASSERT_EQ("int8", ty3.ToString()); |
| ASSERT_EQ("computed", ty2.ToString()); |
| } |
| |
| TEST(OutputType, Resolve) { |
| // Check shape promotion rules for FIXED kind |
| OutputType ty1(int32()); |
| |
| ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve(nullptr, {})); |
| ASSERT_EQ(ValueDescr::Scalar(int32()), descr); |
| |
| ASSERT_OK_AND_ASSIGN(descr, |
| ty1.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR)})); |
| ASSERT_EQ(ValueDescr::Scalar(int32()), descr); |
| |
| ASSERT_OK_AND_ASSIGN(descr, |
| ty1.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR), |
| ValueDescr(int8(), ValueDescr::ARRAY)})); |
| ASSERT_EQ(ValueDescr::Array(int32()), descr); |
| |
| OutputType ty2([](KernelContext*, const std::vector<ValueDescr>& args) { |
| return ValueDescr(args[0].type, GetBroadcastShape(args)); |
| }); |
| |
| ASSERT_OK_AND_ASSIGN(descr, ty2.Resolve(nullptr, {ValueDescr::Array(utf8())})); |
| ASSERT_EQ(ValueDescr::Array(utf8()), descr); |
| |
| // Type resolver that returns an error |
| OutputType ty3( |
| [](KernelContext* ctx, const std::vector<ValueDescr>& args) -> Result<ValueDescr> { |
| // NB: checking the value types versus the function arity should be |
| // validated elsewhere, so this is just for illustration purposes |
| if (args.size() == 0) { |
| return Status::Invalid("Need at least one argument"); |
| } |
| return ValueDescr(args[0]); |
| }); |
| ASSERT_RAISES(Invalid, ty3.Resolve(nullptr, {})); |
| |
| // Type resolver that returns ValueDescr::ANY and needs type promotion |
| OutputType ty4( |
| [](KernelContext* ctx, const std::vector<ValueDescr>& args) -> Result<ValueDescr> { |
| return int32(); |
| }); |
| |
| ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Array(int8())})); |
| ASSERT_EQ(ValueDescr::Array(int32()), descr); |
| ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Scalar(int8())})); |
| ASSERT_EQ(ValueDescr::Scalar(int32()), descr); |
| } |
| |
| TEST(OutputType, ResolveDescr) { |
| ValueDescr d1 = ValueDescr::Scalar(int32()); |
| ValueDescr d2 = ValueDescr::Array(int32()); |
| |
| OutputType ty1(d1); |
| OutputType ty2(d2); |
| |
| ASSERT_EQ(ValueDescr::SCALAR, ty1.shape()); |
| ASSERT_EQ(ValueDescr::ARRAY, ty2.shape()); |
| |
| { |
| ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve(nullptr, {})); |
| ASSERT_EQ(d1, descr); |
| } |
| |
| { |
| ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty2.Resolve(nullptr, {})); |
| ASSERT_EQ(d2, descr); |
| } |
| } |
| |
| // ---------------------------------------------------------------------- |
| // KernelSignature |
| |
| TEST(KernelSignature, Basics) { |
| // (any[int8], scalar[decimal]) -> utf8 |
| std::vector<InputType> in_types({int8(), InputType(Type::DECIMAL, ValueDescr::SCALAR)}); |
| OutputType out_type(utf8()); |
| |
| KernelSignature sig(in_types, out_type); |
| ASSERT_EQ(2, sig.in_types().size()); |
| ASSERT_TRUE(sig.in_types()[0].type()->Equals(*int8())); |
| ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Scalar(int8()))); |
| ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Array(int8()))); |
| |
| ASSERT_TRUE(sig.in_types()[1].Matches(ValueDescr::Scalar(decimal(12, 2)))); |
| ASSERT_FALSE(sig.in_types()[1].Matches(ValueDescr::Array(decimal(12, 2)))); |
| } |
| |
| TEST(KernelSignature, Equals) { |
| KernelSignature sig1({}, utf8()); |
| KernelSignature sig1_copy({}, utf8()); |
| KernelSignature sig2({int8()}, utf8()); |
| |
| // Output type doesn't matter (for now) |
| KernelSignature sig3({int8()}, int32()); |
| |
| KernelSignature sig4({int8(), int16()}, utf8()); |
| KernelSignature sig4_copy({int8(), int16()}, utf8()); |
| KernelSignature sig5({int8(), int16(), int32()}, utf8()); |
| |
| // Differ in shape |
| KernelSignature sig6({ValueDescr::Scalar(int8())}, utf8()); |
| KernelSignature sig7({ValueDescr::Array(int8())}, utf8()); |
| |
| ASSERT_EQ(sig1, sig1); |
| |
| ASSERT_EQ(sig2, sig3); |
| ASSERT_NE(sig3, sig4); |
| |
| // Different sig objects, but same sig |
| ASSERT_EQ(sig1, sig1_copy); |
| ASSERT_EQ(sig4, sig4_copy); |
| |
| // Match first 2 args, but not third |
| ASSERT_NE(sig4, sig5); |
| |
| ASSERT_NE(sig6, sig7); |
| } |
| |
| TEST(KernelSignature, VarArgsEquals) { |
| KernelSignature sig1({int8()}, utf8(), /*is_varargs=*/true); |
| KernelSignature sig2({int8()}, utf8(), /*is_varargs=*/true); |
| KernelSignature sig3({int8()}, utf8()); |
| |
| ASSERT_EQ(sig1, sig2); |
| ASSERT_NE(sig2, sig3); |
| } |
| |
| TEST(KernelSignature, Hash) { |
| // Some basic tests to ensure that the hashes are deterministic and that all |
| // input arguments are incorporated |
| KernelSignature sig1({}, utf8()); |
| KernelSignature sig2({int8()}, utf8()); |
| KernelSignature sig3({int8(), int32()}, utf8()); |
| |
| ASSERT_EQ(sig1.Hash(), sig1.Hash()); |
| ASSERT_EQ(sig2.Hash(), sig2.Hash()); |
| ASSERT_NE(sig1.Hash(), sig2.Hash()); |
| ASSERT_NE(sig2.Hash(), sig3.Hash()); |
| } |
| |
| TEST(KernelSignature, MatchesInputs) { |
| // () -> boolean |
| KernelSignature sig1({}, boolean()); |
| |
| ASSERT_TRUE(sig1.MatchesInputs({})); |
| ASSERT_FALSE(sig1.MatchesInputs({int8()})); |
| |
| // (any[int8], any[decimal]) -> boolean |
| KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, boolean()); |
| |
| ASSERT_FALSE(sig2.MatchesInputs({})); |
| ASSERT_FALSE(sig2.MatchesInputs({int8()})); |
| ASSERT_TRUE(sig2.MatchesInputs({int8(), decimal(12, 2)})); |
| ASSERT_TRUE(sig2.MatchesInputs( |
| {ValueDescr::Scalar(int8()), ValueDescr::Scalar(decimal(12, 2))})); |
| ASSERT_TRUE( |
| sig2.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(decimal(12, 2))})); |
| |
| // (scalar[int8], array[int32]) -> boolean |
| KernelSignature sig3({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())}, |
| boolean()); |
| |
| ASSERT_FALSE(sig3.MatchesInputs({})); |
| |
| // Unqualified, these are ANY type and do not match because the kernel |
| // requires a scalar and an array |
| ASSERT_FALSE(sig3.MatchesInputs({int8(), int32()})); |
| ASSERT_TRUE( |
| sig3.MatchesInputs({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())})); |
| ASSERT_FALSE( |
| sig3.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(int32())})); |
| } |
| |
| TEST(KernelSignature, VarArgsMatchesInputs) { |
| KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true); |
| |
| std::vector<ValueDescr> args = {int8()}; |
| ASSERT_TRUE(sig.MatchesInputs(args)); |
| args.push_back(ValueDescr::Scalar(int8())); |
| args.push_back(ValueDescr::Array(int8())); |
| ASSERT_TRUE(sig.MatchesInputs(args)); |
| args.push_back(int32()); |
| ASSERT_FALSE(sig.MatchesInputs(args)); |
| } |
| |
| TEST(KernelSignature, ToString) { |
| std::vector<InputType> in_types = {InputType(int8(), ValueDescr::SCALAR), |
| InputType(Type::DECIMAL, ValueDescr::ARRAY), |
| InputType(utf8())}; |
| KernelSignature sig(in_types, utf8()); |
| ASSERT_EQ("(scalar[int8], array[Type::DECIMAL128], any[string]) -> string", |
| sig.ToString()); |
| |
| OutputType out_type([](KernelContext*, const std::vector<ValueDescr>& args) { |
| return Status::Invalid("NYI"); |
| }); |
| KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, out_type); |
| ASSERT_EQ("(any[int8], any[Type::DECIMAL128]) -> computed", sig2.ToString()); |
| } |
| |
| TEST(KernelSignature, VarArgsToString) { |
| KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true); |
| ASSERT_EQ("varargs[any[int8]] -> string", sig.ToString()); |
| } |
| |
| } // namespace compute |
| } // namespace arrow |