blob: cb62f9764913dad211ba6baeb26b0a8ed497a463 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "substrait/VeloxSubstraitSignature.h"
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
using namespace facebook::velox;
namespace gluten {
class VeloxSubstraitSignatureTest : public ::testing::Test {
protected:
void SetUp() override {
Test::SetUp();
functions::prestosql::registerAllScalarFunctions();
}
static TypePtr fromSubstraitSignature(const std::string& signature) {
return VeloxSubstraitSignature::fromSubstraitSignature(signature);
}
static std::string toSubstraitSignature(const TypePtr& type) {
return VeloxSubstraitSignature::toSubstraitSignature(type);
}
static std::string toSubstraitSignature(const std::string& functionName, const std::vector<TypePtr>& arguments) {
return VeloxSubstraitSignature::toSubstraitSignature(functionName, arguments);
}
};
TEST_F(VeloxSubstraitSignatureTest, toSubstraitSignatureWithType) {
ASSERT_EQ(toSubstraitSignature(BOOLEAN()), "bool");
ASSERT_EQ(toSubstraitSignature(TINYINT()), "i8");
ASSERT_EQ(toSubstraitSignature(SMALLINT()), "i16");
ASSERT_EQ(toSubstraitSignature(INTEGER()), "i32");
ASSERT_EQ(toSubstraitSignature(BIGINT()), "i64");
ASSERT_EQ(toSubstraitSignature(REAL()), "fp32");
ASSERT_EQ(toSubstraitSignature(DOUBLE()), "fp64");
ASSERT_EQ(toSubstraitSignature(VARCHAR()), "str");
ASSERT_EQ(toSubstraitSignature(VARBINARY()), "vbin");
ASSERT_EQ(toSubstraitSignature(TIMESTAMP()), "ts");
ASSERT_EQ(toSubstraitSignature(DATE()), "date");
ASSERT_EQ(toSubstraitSignature(ARRAY(BOOLEAN())), "list");
ASSERT_EQ(toSubstraitSignature(ARRAY(INTEGER())), "list");
ASSERT_EQ(toSubstraitSignature(MAP(INTEGER(), BIGINT())), "map");
ASSERT_EQ(toSubstraitSignature(ROW({INTEGER(), BIGINT()})), "struct<i32,i64>");
ASSERT_EQ(toSubstraitSignature(ROW({ARRAY(INTEGER())})), "struct<list>");
ASSERT_EQ(toSubstraitSignature(ROW({MAP(INTEGER(), INTEGER())})), "struct<map>");
ASSERT_EQ(toSubstraitSignature(ROW({ROW({INTEGER()})})), "struct<struct<i32>>");
ASSERT_EQ(toSubstraitSignature(UNKNOWN()), "u!name");
}
TEST_F(VeloxSubstraitSignatureTest, toSubstraitSignatureWithFunctionNameAndArguments) {
ASSERT_EQ(toSubstraitSignature("eq", {INTEGER(), INTEGER()}), "eq:i32_i32");
ASSERT_EQ(toSubstraitSignature("gt", {INTEGER(), INTEGER()}), "gt:i32_i32");
ASSERT_EQ(toSubstraitSignature("lt", {INTEGER(), INTEGER()}), "lt:i32_i32");
ASSERT_EQ(toSubstraitSignature("gte", {INTEGER(), INTEGER()}), "gte:i32_i32");
ASSERT_EQ(toSubstraitSignature("lte", {INTEGER(), INTEGER()}), "lte:i32_i32");
ASSERT_EQ(toSubstraitSignature("and", {BOOLEAN(), BOOLEAN()}), "and:bool_bool");
ASSERT_EQ(toSubstraitSignature("or", {BOOLEAN(), BOOLEAN()}), "or:bool_bool");
ASSERT_EQ(toSubstraitSignature("not", {BOOLEAN()}), "not:bool");
ASSERT_EQ(toSubstraitSignature("xor", {BOOLEAN(), BOOLEAN()}), "xor:bool_bool");
ASSERT_EQ(toSubstraitSignature("between", {INTEGER(), INTEGER(), INTEGER()}), "between:i32_i32_i32");
ASSERT_EQ(toSubstraitSignature("plus", {INTEGER(), INTEGER()}), "plus:i32_i32");
ASSERT_EQ(toSubstraitSignature("divide", {INTEGER(), INTEGER()}), "divide:i32_i32");
ASSERT_EQ(toSubstraitSignature("cardinality", {ARRAY(INTEGER())}), "cardinality:list");
ASSERT_EQ(toSubstraitSignature("array_sum", {ARRAY(INTEGER())}), "array_sum:list");
ASSERT_EQ(toSubstraitSignature("sum", {INTEGER()}), "sum:i32");
ASSERT_EQ(toSubstraitSignature("avg", {INTEGER()}), "avg:i32");
ASSERT_EQ(toSubstraitSignature("count", {INTEGER()}), "count:i32");
auto functionType = std::make_shared<const FunctionType>(std::vector<TypePtr>{INTEGER(), VARCHAR()}, BIGINT());
std::vector<TypePtr> types = {MAP(INTEGER(), VARCHAR()), functionType};
ASSERT_ANY_THROW(toSubstraitSignature("transform_keys", std::move(types)));
}
TEST_F(VeloxSubstraitSignatureTest, fromSubstraitSignature) {
ASSERT_EQ(fromSubstraitSignature("bool")->kind(), TypeKind::BOOLEAN);
ASSERT_EQ(fromSubstraitSignature("i8")->kind(), TypeKind::TINYINT);
ASSERT_EQ(fromSubstraitSignature("i16")->kind(), TypeKind::SMALLINT);
ASSERT_EQ(fromSubstraitSignature("i32")->kind(), TypeKind::INTEGER);
ASSERT_EQ(fromSubstraitSignature("i64")->kind(), TypeKind::BIGINT);
ASSERT_EQ(fromSubstraitSignature("fp32")->kind(), TypeKind::REAL);
ASSERT_EQ(fromSubstraitSignature("fp64")->kind(), TypeKind::DOUBLE);
ASSERT_EQ(fromSubstraitSignature("str")->kind(), TypeKind::VARCHAR);
ASSERT_EQ(fromSubstraitSignature("vbin")->kind(), TypeKind::VARBINARY);
ASSERT_EQ(fromSubstraitSignature("ts")->kind(), TypeKind::TIMESTAMP);
ASSERT_EQ(fromSubstraitSignature("date")->kind(), TypeKind::INTEGER);
ASSERT_EQ(fromSubstraitSignature("dec<18,2>")->kind(), TypeKind::BIGINT);
ASSERT_EQ(fromSubstraitSignature("dec<19,2>")->kind(), TypeKind::HUGEINT);
// Struct type test.
auto type = fromSubstraitSignature("struct<bool,vbin,dec<8,2>>");
ASSERT_EQ(type->kind(), TypeKind::ROW);
ASSERT_EQ(type->childAt(0)->kind(), TypeKind::BOOLEAN);
ASSERT_EQ(type->childAt(1)->kind(), TypeKind::VARBINARY);
ASSERT_EQ(type->childAt(2)->kind(), TypeKind::BIGINT);
type = fromSubstraitSignature("struct<bool,struct<i8,fp32>>");
ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ROW);
ASSERT_EQ(type->childAt(1)->childAt(0)->kind(), TypeKind::TINYINT);
ASSERT_EQ(type->childAt(1)->childAt(1)->kind(), TypeKind::REAL);
type = fromSubstraitSignature("struct<bool,struct<i8,fp32>,vbin,ts,dec<9,2>>");
ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ROW);
type = fromSubstraitSignature("struct<struct<ts>,i16>");
ASSERT_EQ(type->childAt(0)->kind(), TypeKind::ROW);
ASSERT_EQ(type->childAt(0)->childAt(0)->kind(), TypeKind::TIMESTAMP);
type = fromSubstraitSignature("struct<struct<dec<12,9>>>");
ASSERT_EQ(type->childAt(0)->kind(), TypeKind::ROW);
ASSERT_EQ(type->childAt(0)->childAt(0)->kind(), TypeKind::BIGINT);
type = fromSubstraitSignature("struct<struct<struct<i8>>>");
ASSERT_EQ(type->kind(), TypeKind::ROW);
ASSERT_EQ(type->childAt(0)->kind(), TypeKind::ROW);
ASSERT_EQ(type->childAt(0)->childAt(0)->kind(), TypeKind::ROW);
ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(0)->kind(), TypeKind::TINYINT);
type = fromSubstraitSignature("struct<struct<struct<i8,str>>>");
ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(0)->kind(), TypeKind::TINYINT);
ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(1)->kind(), TypeKind::VARCHAR);
type = fromSubstraitSignature("struct<struct<struct<i8,dec<19,2>>>>");
ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(1)->kind(), TypeKind::HUGEINT);
type = fromSubstraitSignature("struct<fp32,struct<struct<i8,dec<19,2>,list<i32>,map<i16,i32>>>>");
ASSERT_EQ(type->childAt(0)->kind(), TypeKind::REAL);
ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(0)->kind(), TypeKind::TINYINT);
ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(1)->kind(), TypeKind::HUGEINT);
ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(2)->childAt(0)->kind(), TypeKind::INTEGER);
ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(3)->childAt(0)->kind(), TypeKind::SMALLINT);
ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(3)->childAt(1)->kind(), TypeKind::INTEGER);
ASSERT_ANY_THROW(fromSubstraitSignature("other")->kind());
// Map type test.
type = fromSubstraitSignature("map<bool,list<map<str,i32>>>");
ASSERT_EQ(type->kind(), TypeKind::MAP);
ASSERT_EQ(type->childAt(0)->kind(), TypeKind::BOOLEAN);
ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ARRAY);
ASSERT_EQ(type->childAt(1)->childAt(0)->kind(), TypeKind::MAP);
type = fromSubstraitSignature("struct<map<bool,i8>,list<map<str,i32>>>");
ASSERT_EQ(type->kind(), TypeKind::ROW);
ASSERT_EQ(type->childAt(0)->kind(), TypeKind::MAP);
ASSERT_EQ(type->childAt(0)->childAt(0)->kind(), TypeKind::BOOLEAN);
ASSERT_EQ(type->childAt(0)->childAt(1)->kind(), TypeKind::TINYINT);
ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ARRAY);
ASSERT_EQ(type->childAt(1)->childAt(0)->kind(), TypeKind::MAP);
ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(0)->kind(), TypeKind::VARCHAR);
ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(1)->kind(), TypeKind::INTEGER);
}
} // namespace gluten