| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #include "SerializedPlanBuilder.h" |
| #include <DataTypes/DataTypeArray.h> |
| #include <DataTypes/DataTypeDateTime64.h> |
| #include <DataTypes/DataTypeLowCardinality.h> |
| #include <DataTypes/DataTypeMap.h> |
| #include <DataTypes/DataTypeNullable.h> |
| #include <DataTypes/DataTypeTuple.h> |
| #include <DataTypes/DataTypesDecimal.h> |
| #include <Functions/FunctionHelpers.h> |
| |
namespace DB
{
namespace ErrorCodes
{
    /// Declared `extern`: the definition lives in another translation unit.
    /// Used below by buildType() when a ClickHouse type has no Spark/substrait equivalent.
    extern const int UNKNOWN_TYPE;
}
}
| |
| namespace dbms |
| { |
| using namespace DB; |
| SchemaPtr SerializedSchemaBuilder::build() |
| { |
| for (const auto & [name, type] : this->type_map) |
| { |
| this->schema->add_names(name); |
| auto * type_struct = this->schema->mutable_struct_(); |
| if (type == "I8") |
| { |
| auto * t = type_struct->mutable_types()->Add(); |
| t->mutable_i8()->set_nullability( |
| this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE |
| : substrait::Type_Nullability_NULLABILITY_REQUIRED); |
| } |
| else if (type == "I32") |
| { |
| auto * t = type_struct->mutable_types()->Add(); |
| t->mutable_i32()->set_nullability( |
| this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE |
| : substrait::Type_Nullability_NULLABILITY_REQUIRED); |
| } |
| else if (type == "I64") |
| { |
| auto * t = type_struct->mutable_types()->Add(); |
| t->mutable_i64()->set_nullability( |
| this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE |
| : substrait::Type_Nullability_NULLABILITY_REQUIRED); |
| } |
| else if (type == "Boolean") |
| { |
| auto * t = type_struct->mutable_types()->Add(); |
| t->mutable_bool_()->set_nullability( |
| this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE |
| : substrait::Type_Nullability_NULLABILITY_REQUIRED); |
| } |
| else if (type == "I16") |
| { |
| auto * t = type_struct->mutable_types()->Add(); |
| t->mutable_i16()->set_nullability( |
| this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE |
| : substrait::Type_Nullability_NULLABILITY_REQUIRED); |
| } |
| else if (type == "String") |
| { |
| auto * t = type_struct->mutable_types()->Add(); |
| t->mutable_string()->set_nullability( |
| this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE |
| : substrait::Type_Nullability_NULLABILITY_REQUIRED); |
| } |
| else if (type == "FP32") |
| { |
| auto * t = type_struct->mutable_types()->Add(); |
| t->mutable_fp32()->set_nullability( |
| this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE |
| : substrait::Type_Nullability_NULLABILITY_REQUIRED); |
| } |
| else if (type == "FP64") |
| { |
| auto * t = type_struct->mutable_types()->Add(); |
| t->mutable_fp64()->set_nullability( |
| this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE |
| : substrait::Type_Nullability_NULLABILITY_REQUIRED); |
| } |
| else if (type == "Date") |
| { |
| auto * t = type_struct->mutable_types()->Add(); |
| t->mutable_date()->set_nullability( |
| this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE |
| : substrait::Type_Nullability_NULLABILITY_REQUIRED); |
| } |
| else if (type == "Timestamp") |
| { |
| auto * t = type_struct->mutable_types()->Add(); |
| t->mutable_timestamp_tz()->set_nullability( |
| this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE |
| : substrait::Type_Nullability_NULLABILITY_REQUIRED); |
| } |
| else |
| { |
| throw std::runtime_error("doesn't support type " + type); |
| } |
| } |
| return std::move(this->schema); |
| } |
| SerializedSchemaBuilder & SerializedSchemaBuilder::column(const std::string & name, const std::string & type, bool nullable) |
| { |
| this->type_map.emplace(name, type); |
| this->nullability_map.emplace(name, nullable); |
| return *this; |
| } |
| SerializedSchemaBuilder::SerializedSchemaBuilder() : schema(new substrait::NamedStruct()) |
| { |
| } |
| SerializedPlanBuilder & SerializedPlanBuilder::registerFunction(int id, const std::string & name) |
| { |
| auto * extension = this->plan->mutable_extensions()->Add(); |
| auto * function_mapping = extension->mutable_extension_function(); |
| function_mapping->set_function_anchor(id); |
| function_mapping->set_name(name); |
| return *this; |
| } |
| |
| void SerializedPlanBuilder::setInputToPrev(substrait::Rel * input) |
| { |
| if (!this->prev_rel) |
| { |
| auto * root = this->plan->mutable_relations()->Add()->mutable_root(); |
| root->set_allocated_input(input); |
| return; |
| } |
| if (this->prev_rel->has_filter()) |
| { |
| this->prev_rel->mutable_filter()->set_allocated_input(input); |
| } |
| else if (this->prev_rel->has_aggregate()) |
| { |
| this->prev_rel->mutable_aggregate()->set_allocated_input(input); |
| } |
| else if (this->prev_rel->has_project()) |
| { |
| this->prev_rel->mutable_project()->set_allocated_input(input); |
| } |
| else |
| { |
| throw std::runtime_error("does support rel type"); |
| } |
| } |
| |
| SerializedPlanBuilder & SerializedPlanBuilder::filter(substrait::Expression * condition) |
| { |
| substrait::Rel * filter = new substrait::Rel(); |
| filter->mutable_filter()->set_allocated_condition(condition); |
| setInputToPrev(filter); |
| this->prev_rel = filter; |
| return *this; |
| } |
| |
| SerializedPlanBuilder & SerializedPlanBuilder::read(const std::string & path, SchemaPtr schema) |
| { |
| substrait::Rel * rel = new substrait::Rel(); |
| auto * read = rel->mutable_read(); |
| read->mutable_local_files()->add_items()->set_uri_file(path); |
| read->set_allocated_base_schema(schema); |
| setInputToPrev(rel); |
| this->prev_rel = rel; |
| return *this; |
| } |
| |
| std::unique_ptr<substrait::Plan> SerializedPlanBuilder::build() |
| { |
| return std::move(this->plan); |
| } |
| |
| SerializedPlanBuilder::SerializedPlanBuilder() : plan(std::make_unique<substrait::Plan>()) |
| { |
| } |
| |
| SerializedPlanBuilder & |
| SerializedPlanBuilder::aggregate(const std::vector<int32_t> & /*keys*/, const std::vector<substrait::AggregateRel_Measure *> & aggregates) |
| { |
| substrait::Rel * rel = new substrait::Rel(); |
| auto * agg = rel->mutable_aggregate(); |
| // TODO support group |
| auto * measures = agg->mutable_measures(); |
| for (auto * measure : aggregates) |
| { |
| measures->AddAllocated(measure); |
| } |
| setInputToPrev(rel); |
| this->prev_rel = rel; |
| return *this; |
| } |
| |
| SerializedPlanBuilder & SerializedPlanBuilder::project(const std::vector<substrait::Expression *> & projections) |
| { |
| substrait::Rel * project = new substrait::Rel(); |
| for (auto * expr : projections) |
| { |
| project->mutable_project()->mutable_expressions()->AddAllocated(expr); |
| } |
| setInputToPrev(project); |
| this->prev_rel = project; |
| return *this; |
| } |
| |
| std::shared_ptr<substrait::Type> SerializedPlanBuilder::buildType(const DB::DataTypePtr & ch_type) |
| { |
| const auto ch_type_wo_lowcardinality = DB::removeLowCardinality(ch_type); |
| |
| const auto * ch_type_nullable = checkAndGetDataType<DataTypeNullable>(ch_type_wo_lowcardinality.get()); |
| |
| const bool is_nullable = (ch_type_nullable != nullptr); |
| auto type_nullability |
| = is_nullable ? substrait::Type_Nullability_NULLABILITY_NULLABLE : substrait::Type_Nullability_NULLABILITY_REQUIRED; |
| |
| const auto ch_type_without_nullable = DB::removeNullable(ch_type_wo_lowcardinality); |
| const DB::WhichDataType which(ch_type_without_nullable); |
| |
| auto res = std::make_shared<substrait::Type>(); |
| if (which.isUInt8()) |
| res->mutable_bool_()->set_nullability(type_nullability); |
| else if (which.isInt8()) |
| res->mutable_i8()->set_nullability(type_nullability); |
| else if (which.isInt16()) |
| res->mutable_i16()->set_nullability(type_nullability); |
| else if (which.isInt32()) |
| res->mutable_i32()->set_nullability(type_nullability); |
| else if (which.isInt64()) |
| res->mutable_i64()->set_nullability(type_nullability); |
| else if (which.isStringOrFixedString() || which.isAggregateFunction()) |
| res->mutable_binary()->set_nullability(type_nullability); /// Spark Binary type is more similiar to CH String type |
| else if (which.isFloat32()) |
| res->mutable_fp32()->set_nullability(type_nullability); |
| else if (which.isFloat64()) |
| res->mutable_fp64()->set_nullability(type_nullability); |
| else if (which.isFloat64()) |
| res->mutable_fp64()->set_nullability(type_nullability); |
| else if (which.isDateTime64()) |
| { |
| const auto * ch_type_datetime64 = checkAndGetDataType<DataTypeDateTime64>(ch_type_without_nullable.get()); |
| if (ch_type_datetime64->getScale() != 6) |
| throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); |
| res->mutable_timestamp_tz()->set_nullability(type_nullability); |
| } |
| else if (which.isDate32()) |
| res->mutable_date()->set_nullability(type_nullability); |
| else if (which.isDecimal()) |
| { |
| if (which.isDecimal256()) |
| throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); |
| |
| const auto scale = getDecimalScale(*ch_type_without_nullable); |
| const auto precision = getDecimalPrecision(*ch_type_without_nullable); |
| if (scale == 0 && precision == 0) |
| throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); |
| |
| res->mutable_decimal()->set_nullability(type_nullability); |
| res->mutable_decimal()->set_scale(scale); |
| res->mutable_decimal()->set_precision(precision); |
| } |
| else if (which.isTuple()) |
| { |
| const auto * ch_tuple_type = checkAndGetDataType<DataTypeTuple>(ch_type_without_nullable.get()); |
| const auto & ch_field_types = ch_tuple_type->getElements(); |
| res->mutable_struct_()->set_nullability(type_nullability); |
| for (const auto & ch_field_type : ch_field_types) |
| res->mutable_struct_()->mutable_types()->Add(std::move(*buildType(ch_field_type))); |
| } |
| else if (which.isArray()) |
| { |
| const auto * ch_array_type = checkAndGetDataType<DataTypeArray>(ch_type_without_nullable.get()); |
| const auto & ch_nested_type = ch_array_type->getNestedType(); |
| res->mutable_list()->set_nullability(type_nullability); |
| *(res->mutable_list()->mutable_type()) = *buildType(ch_nested_type); |
| } |
| else if (which.isMap()) |
| { |
| const auto & ch_map_type = checkAndGetDataType<DataTypeMap>(ch_type_without_nullable.get()); |
| const auto & ch_key_type = ch_map_type->getKeyType(); |
| const auto & ch_val_type = ch_map_type->getValueType(); |
| res->mutable_map()->set_nullability(type_nullability); |
| *(res->mutable_map()->mutable_key()) = *buildType(ch_key_type); |
| *(res->mutable_map()->mutable_value()) = *buildType(ch_val_type); |
| } |
| else if (which.isNothing()) |
| res->mutable_nothing(); |
| else |
| throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); |
| |
| return res; |
| } |
| |
| void SerializedPlanBuilder::buildType(const DB::DataTypePtr & ch_type, String & substrait_type) |
| { |
| auto pb = buildType(ch_type); |
| substrait_type = pb->SerializeAsString(); |
| } |
| |
| substrait::Expression * scalarFunction(int32_t id, ExpressionList args) |
| { |
| substrait::Expression * rel = new substrait::Expression(); |
| auto * function = rel->mutable_scalar_function(); |
| function->set_function_reference(id); |
| std::for_each(args.begin(), args.end(), [function](auto * expr) { function->mutable_args()->AddAllocated(expr); }); |
| return rel; |
| } |
| substrait::AggregateRel_Measure * measureFunction(int32_t id, ExpressionList args) |
| { |
| substrait::AggregateRel_Measure * rel = new substrait::AggregateRel_Measure(); |
| auto * measure = rel->mutable_measure(); |
| measure->set_function_reference(id); |
| std::for_each(args.begin(), args.end(), [measure](auto * expr) { measure->mutable_args()->AddAllocated(expr); }); |
| return rel; |
| } |
| substrait::Expression * literal(double_t value) |
| { |
| substrait::Expression * rel = new substrait::Expression(); |
| auto * literal = rel->mutable_literal(); |
| literal->set_fp64(value); |
| return rel; |
| } |
| |
| substrait::Expression * literal(int32_t value) |
| { |
| substrait::Expression * rel = new substrait::Expression(); |
| auto * literal = rel->mutable_literal(); |
| literal->set_i32(value); |
| return rel; |
| } |
| |
| substrait::Expression * literal(const std::string & value) |
| { |
| substrait::Expression * rel = new substrait::Expression(); |
| auto * literal = rel->mutable_literal(); |
| literal->set_string(value); |
| return rel; |
| } |
| |
| substrait::Expression * literalDate(int32_t value) |
| { |
| substrait::Expression * rel = new substrait::Expression(); |
| auto * literal = rel->mutable_literal(); |
| literal->set_date(value); |
| return rel; |
| } |
| |
| /// Timestamp in units of microseconds since the UNIX epoch. |
| substrait::Expression * literalTimestamp(int64_t value) |
| { |
| substrait::Expression * rel = new substrait::Expression(); |
| auto * literal = rel->mutable_literal(); |
| literal->set_timestamp_tz(value); |
| return rel; |
| } |
| |
| } |