blob: b1a9ae9444f45d4d7145edc6848108c28d2340f6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "substrait/SubstraitExtensionCollector.h"
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/core/PlanNode.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
using namespace facebook::velox;
namespace gluten {
class SubstraitExtensionCollectorTest : public ::testing::Test {
protected:
void SetUp() override {
Test::SetUp();
functions::prestosql::registerAllScalarFunctions();
}
int getReferenceNumber(const std::string& functionName, std::vector<TypePtr>&& arguments) {
int referenceNumber1 = extensionCollector_->getReferenceNumber(functionName, arguments);
// Repeat the call to make sure properly de-duplicated.
int referenceNumber2 = extensionCollector_->getReferenceNumber(functionName, arguments);
EXPECT_EQ(referenceNumber1, referenceNumber2);
return referenceNumber1;
}
int getReferenceNumber(
const std::string& functionName,
std::vector<TypePtr>&& arguments,
core::AggregationNode::Step step) {
int referenceNumber1 = extensionCollector_->getReferenceNumber(functionName, arguments);
// Repeat the call to make sure properly de-duplicated.
int referenceNumber2 = extensionCollector_->getReferenceNumber(functionName, arguments);
EXPECT_EQ(referenceNumber1, referenceNumber2);
return referenceNumber2;
}
/// Given a substrait plan, return the sorted extension functions by the
/// function anchor.
::google::protobuf::RepeatedPtrField<::substrait::extensions::SimpleExtensionDeclaration> getSortedSubstraitExtension(
const ::substrait::Plan* substraitPlan) {
auto substraitExtensions = substraitPlan->extensions();
std::sort(substraitExtensions.begin(), substraitExtensions.end(), [](const auto& a, const auto& b) {
return a.extension_function().function_anchor() < b.extension_function().function_anchor();
});
return substraitExtensions;
}
SubstraitExtensionCollectorPtr extensionCollector_ = std::make_shared<SubstraitExtensionCollector>();
};
TEST_F(SubstraitExtensionCollectorTest, getReferenceNumberForScalarFunction) {
ASSERT_EQ(getReferenceNumber("plus", {INTEGER(), INTEGER()}), 0);
ASSERT_EQ(getReferenceNumber("divide", {INTEGER(), INTEGER()}), 1);
ASSERT_EQ(getReferenceNumber("cardinality", {ARRAY(INTEGER())}), 2);
ASSERT_EQ(getReferenceNumber("array_sum", {ARRAY(INTEGER())}), 3);
auto functionType = std::make_shared<const FunctionType>(std::vector<TypePtr>{INTEGER(), VARCHAR()}, BIGINT());
std::vector<TypePtr> types = {MAP(INTEGER(), VARCHAR()), functionType};
ASSERT_ANY_THROW(getReferenceNumber("transform_keys", std::move(types)));
}
TEST_F(SubstraitExtensionCollectorTest, getReferenceNumberForAggregateFunction) {
// Sum aggregate function have same argument type for each aggregation step.
ASSERT_EQ(getReferenceNumber("sum", {INTEGER()}, core::AggregationNode::Step::kSingle), 0);
// Partial avg aggregate function should use primitive integral type.
ASSERT_EQ(getReferenceNumber("avg", {INTEGER()}, core::AggregationNode::Step::kPartial), 1);
// Final avg aggregate function should use struct type, like
// 'ROW<DOUBLE,BIGINT>'
ASSERT_EQ(getReferenceNumber("avg", {ROW({DOUBLE(), BIGINT()})}, core::AggregationNode::Step::kFinal), 2);
// Count aggregate function have same argument type for each aggregation step.
ASSERT_EQ(getReferenceNumber("count", {INTEGER()}, core::AggregationNode::Step::kFinal), 3);
}
TEST_F(SubstraitExtensionCollectorTest, addExtensionsToPlan) {
getReferenceNumber("plus", {INTEGER(), INTEGER()});
getReferenceNumber("divide", {INTEGER(), INTEGER()});
getReferenceNumber("cardinality", {ARRAY(INTEGER())});
getReferenceNumber("array_sum", {ARRAY(INTEGER())});
getReferenceNumber("sum", {INTEGER()});
getReferenceNumber("avg", {INTEGER()});
getReferenceNumber("avg", {ROW({DOUBLE(), BIGINT()})});
getReferenceNumber("count", {INTEGER()});
google::protobuf::Arena arena;
auto* substraitPlan = google::protobuf::Arena::CreateMessage<::substrait::Plan>(&arena);
extensionCollector_->addExtensionsToPlan(substraitPlan);
const auto& substraitExtensions = getSortedSubstraitExtension(substraitPlan);
auto getFunctionName = [&](auto id) { return substraitExtensions[id].extension_function().name(); };
ASSERT_EQ(substraitPlan->extensions().size(), 8);
ASSERT_EQ(getFunctionName(0), "plus:i32_i32");
ASSERT_EQ(getFunctionName(1), "divide:i32_i32");
ASSERT_EQ(getFunctionName(2), "cardinality:list");
ASSERT_EQ(getFunctionName(3), "array_sum:list");
ASSERT_EQ(getFunctionName(4), "sum:i32");
ASSERT_EQ(getFunctionName(5), "avg:i32");
ASSERT_EQ(getFunctionName(6), "avg:struct<fp64,i64>");
ASSERT_EQ(getFunctionName(7), "count:i32");
}
} // namespace gluten