| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "VeloxToSubstraitPlan.h" |
| #include <google/protobuf/wrappers.pb.h> |
| #include "utils/Exception.h" |
| |
| namespace gluten { |
| namespace { |
| |
| struct AggregateCompanion { |
| std::string functionName; |
| core::AggregationNode::Step step; |
| }; |
| |
| AggregateCompanion toAggregateCompanion(const core::AggregationNode::Aggregate& aggregate) { |
| const auto& companionName = aggregate.call->name(); |
| auto offset = companionName.find_last_of('_'); |
| if (offset == std::string::npos) { |
| return {companionName, core::AggregationNode::Step::kSingle}; |
| } |
| // found '_' |
| const auto& suffix = companionName.substr(offset + 1); |
| if (suffix.empty()) { |
| // the last char is '_' |
| return {companionName, core::AggregationNode::Step::kSingle}; |
| } |
| const auto& functionName = companionName.substr(0, offset); |
| if (suffix == "_partial") { |
| return {functionName, core::AggregationNode::Step::kPartial}; |
| } |
| if (suffix == "_merge_extract") { |
| return {functionName, core::AggregationNode::Step::kFinal}; |
| } |
| if (suffix == "_merge") { |
| return {functionName, core::AggregationNode::Step::kIntermediate}; |
| } |
| // others, not a companion function |
| return {companionName, core::AggregationNode::Step::kSingle}; |
| } |
| |
| ::substrait::AggregationPhase toAggregationPhase(const core::AggregationNode::Step& step) { |
| switch (step) { |
| case core::AggregationNode::Step::kPartial: { |
| return ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE; |
| } |
| case core::AggregationNode::Step::kIntermediate: { |
| return ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE; |
| } |
| case core::AggregationNode::Step::kSingle: { |
| return ::substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT; |
| } |
| case core::AggregationNode::Step::kFinal: { |
| return ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT; |
| } |
| default: |
| VELOX_UNSUPPORTED("Unsupported Aggregate Step '{}' in Substrait ", mapAggregationStepToName(step)); |
| } |
| } |
| |
| ::substrait::SortField_SortDirection toSortDirection(core::SortOrder sortOrder) { |
| if (sortOrder.isNullsFirst()) { |
| if (sortOrder.isAscending()) { |
| return ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST; |
| } else { |
| return ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST; |
| } |
| } else { |
| if (sortOrder.isAscending()) { |
| return ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST; |
| } else { |
| return ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST; |
| } |
| } |
| } |
| |
| } // namespace |
| |
| ::substrait::Plan& VeloxToSubstraitPlanConvertor::toSubstrait( |
| google::protobuf::Arena& arena, |
| const core::PlanNodePtr& plan) { |
| // Construct the extension colllector. |
| extensionCollector_ = std::make_shared<SubstraitExtensionCollector>(); |
| // Construct the expression converter. |
| exprConvertor_ = std::make_unique<VeloxToSubstraitExprConvertor>(extensionCollector_); |
| |
| auto substraitPlan = google::protobuf::Arena::CreateMessage<::substrait::Plan>(&arena); |
| |
| // Add unknown type in extension. |
| auto unknownType = substraitPlan->add_extensions()->mutable_extension_type(); |
| |
| unknownType->set_extension_uri_reference(0); |
| unknownType->set_type_anchor(0); |
| unknownType->set_name("UNKNOWN"); |
| |
| // Do conversion. |
| ::substrait::RelRoot* rootRel = substraitPlan->add_relations()->mutable_root(); |
| |
| toSubstrait(arena, plan, rootRel->mutable_input()); |
| |
| // Add extensions for all functions and types seen in the plan. |
| extensionCollector_->addExtensionsToPlan(substraitPlan); |
| |
| // Set RootRel names. |
| for (const auto& name : plan->outputType()->names()) { |
| rootRel->add_names(name); |
| } |
| |
| return *substraitPlan; |
| } |
| |
| void VeloxToSubstraitPlanConvertor::toSubstrait( |
| google::protobuf::Arena& arena, |
| const core::PlanNodePtr& planNode, |
| ::substrait::Rel* rel) { |
| if (auto filterNode = std::dynamic_pointer_cast<const core::FilterNode>(planNode)) { |
| auto filterRel = rel->mutable_filter(); |
| toSubstrait(arena, filterNode, filterRel); |
| return; |
| } |
| if (auto valuesNode = std::dynamic_pointer_cast<const core::ValuesNode>(planNode)) { |
| ::substrait::ReadRel* readRel = rel->mutable_read(); |
| toSubstrait(arena, valuesNode, readRel); |
| return; |
| } |
| if (auto projectNode = std::dynamic_pointer_cast<const core::ProjectNode>(planNode)) { |
| ::substrait::ProjectRel* projectRel = rel->mutable_project(); |
| toSubstrait(arena, projectNode, projectRel); |
| return; |
| } |
| if (auto aggregationNode = std::dynamic_pointer_cast<const core::AggregationNode>(planNode)) { |
| ::substrait::AggregateRel* aggregateRel = rel->mutable_aggregate(); |
| toSubstrait(arena, aggregationNode, aggregateRel); |
| return; |
| } |
| if (auto orderbyNode = std::dynamic_pointer_cast<const core::OrderByNode>(planNode)) { |
| toSubstrait(arena, orderbyNode, rel->mutable_sort()); |
| return; |
| } |
| if (auto topNNode = std::dynamic_pointer_cast<const core::TopNNode>(planNode)) { |
| toSubstrait(arena, topNNode, rel->mutable_fetch()); |
| return; |
| } |
| if (auto limitNode = std::dynamic_pointer_cast<const core::LimitNode>(planNode)) { |
| toSubstrait(arena, limitNode, rel->mutable_fetch()); |
| return; |
| } |
| VELOX_UNSUPPORTED("Unsupported plan node '{}' .", planNode->name()); |
| } |
| |
| void VeloxToSubstraitPlanConvertor::toSubstrait( |
| google::protobuf::Arena& arena, |
| const std::shared_ptr<const core::FilterNode>& filterNode, |
| ::substrait::FilterRel* filterRel) { |
| const auto& source = getSingleSource(filterNode); |
| |
| toSubstrait(arena, source, filterRel->mutable_input()); |
| |
| // Construct substrait expr(Filter condition). |
| auto filterCondition = filterNode->filter(); |
| auto inputType = source->outputType(); |
| filterRel->mutable_condition()->MergeFrom(exprConvertor_->toSubstraitExpr(arena, filterCondition, inputType)); |
| |
| filterRel->mutable_common()->mutable_direct(); |
| } |
| |
| void VeloxToSubstraitPlanConvertor::toSubstrait( |
| google::protobuf::Arena& arena, |
| const std::shared_ptr<const core::ValuesNode>& valuesNode, |
| ::substrait::ReadRel* readRel) { |
| const auto& outputType = valuesNode->outputType(); |
| |
| ::substrait::ReadRel_VirtualTable* virtualTable = readRel->mutable_virtual_table(); |
| |
| for (const auto& vector : valuesNode->values()) { |
| ::substrait::Expression_Literal_Struct* litValue = virtualTable->add_values(); |
| |
| for (const auto& column : vector->children()) { |
| ::substrait::Expression_Literal* substraitField = |
| google::protobuf::Arena::CreateMessage<::substrait::Expression_Literal>(&arena); |
| |
| substraitField->MergeFrom(exprConvertor_->toSubstraitLiteral(arena, column, litValue)); |
| } |
| } |
| |
| readRel->mutable_base_schema()->MergeFrom(typeConvertor_->toSubstraitNamedStruct(arena, outputType)); |
| |
| readRel->mutable_common()->mutable_direct(); |
| } |
| |
| void VeloxToSubstraitPlanConvertor::toSubstrait( |
| google::protobuf::Arena& arena, |
| const std::shared_ptr<const core::ProjectNode>& projectNode, |
| ::substrait::ProjectRel* projectRel) { |
| const auto& projections = projectNode->projections(); |
| |
| const auto& source = getSingleSource(projectNode); |
| |
| // Process the source Node. |
| toSubstrait(arena, source, projectRel->mutable_input()); |
| |
| // Remap the output. |
| ::substrait::RelCommon_Emit* projRelEmit = projectRel->mutable_common()->mutable_emit(); |
| |
| int64_t projectionSize = projections.size(); |
| |
| auto inputType = source->outputType(); |
| int64_t inputTypeSize = inputType->size(); |
| |
| for (int64_t i = 0; i < projectionSize; i++) { |
| const auto& veloxExpr = projections.at(i); |
| |
| projectRel->add_expressions()->MergeFrom(exprConvertor_->toSubstraitExpr(arena, veloxExpr, inputType)); |
| |
| // Add outputMapping for each expression. |
| projRelEmit->add_output_mapping(inputTypeSize + i); |
| } |
| |
| return; |
| } |
| |
| void VeloxToSubstraitPlanConvertor::toSubstrait( |
| google::protobuf::Arena& arena, |
| const std::shared_ptr<const core::AggregationNode>& aggregateNode, |
| ::substrait::AggregateRel* aggregateRel) { |
| // Process the source Node. |
| const auto& source = getSingleSource(aggregateNode); |
| toSubstrait(arena, source, aggregateRel->mutable_input()); |
| |
| // Convert aggregate grouping keys, such as: group by key1, key2. |
| auto inputType = source->outputType(); |
| auto groupingKeys = aggregateNode->groupingKeys(); |
| int64_t groupingKeySize = groupingKeys.size(); |
| ::substrait::AggregateRel_Grouping* aggGroupings = aggregateRel->add_groupings(); |
| |
| for (int64_t i = 0; i < groupingKeySize; i++) { |
| aggGroupings->add_grouping_expressions()->mutable_selection()->MergeFrom( |
| exprConvertor_->toSubstraitExpr(arena, groupingKeys.at(i), inputType)); |
| } |
| |
| // AggregatesSize should be equal to or greater than the aggregateMasks Size. |
| // Two cases: 1. aggregateMasksSize = 0, aggregatesSize > aggregateMasksSize. |
| // 2. aggregateMasksSize != 0, aggregatesSize = aggregateMasksSize. |
| auto aggregates = aggregateNode->aggregates(); |
| int64_t aggregatesSize = aggregates.size(); |
| |
| for (int64_t i = 0; i < aggregatesSize; i++) { |
| const auto& aggregate = aggregates.at(i); |
| |
| ::substrait::AggregateRel_Measure* aggMeasures = aggregateRel->add_measures(); |
| |
| // Set substrait filter. |
| ::substrait::Expression* aggFilter = aggMeasures->mutable_filter(); |
| if (const auto& mask = aggregate.mask) { |
| aggFilter->mutable_selection()->MergeFrom(exprConvertor_->toSubstraitExpr(arena, mask, inputType)); |
| } else { |
| // Set null. |
| aggFilter = nullptr; |
| } |
| |
| // Process measure, eg:sum(a). |
| ::substrait::AggregateFunction* aggFunction = aggMeasures->mutable_measure(); |
| |
| // Use aggregate node's step information to write advanced extension 'allowFlush'. |
| const auto& step = aggregateNode->step(); |
| switch (step) { |
| case core::AggregationNode::Step::kPartial: { |
| substrait::extensions::AdvancedExtension ae{}; |
| google::protobuf::StringValue msg; |
| msg.set_value("allowFlush=1"); |
| ae.mutable_optimization()->PackFrom(msg); |
| aggregateRel->mutable_advanced_extension()->MergeFrom(ae); |
| break; |
| } |
| case core::AggregationNode::Step::kSingle: |
| break; |
| case core::AggregationNode::Step::kFinal: |
| case core::AggregationNode::Step::kIntermediate: |
| VELOX_USER_FAIL("Step not supported"); |
| break; |
| } |
| |
| // Set aggFunction args. |
| std::vector<TypePtr> arguments; |
| arguments.reserve(aggregate.call->inputs().size()); |
| for (const auto& expr : aggregate.call->inputs()) { |
| // If the expr is CallTypedExpr, people need to do project firstly. |
| if (auto aggregatesExprInput = std::dynamic_pointer_cast<const core::CallTypedExpr>(expr)) { |
| VELOX_NYI("In Velox Plan, the aggregates type cannot be CallTypedExpr"); |
| } else { |
| aggFunction->add_arguments()->mutable_value()->MergeFrom( |
| exprConvertor_->toSubstraitExpr(arena, expr, inputType)); |
| |
| arguments.emplace_back(expr->type()); |
| } |
| } |
| |
| const auto& aggregateCompanion = toAggregateCompanion(aggregate); |
| auto referenceNumber = |
| extensionCollector_->getReferenceNumber(aggregateCompanion.functionName, aggregate.rawInputTypes); |
| |
| aggFunction->set_function_reference(referenceNumber); |
| |
| aggFunction->mutable_output_type()->MergeFrom(typeConvertor_->toSubstraitType(arena, aggregate.call->type())); |
| |
| // Set substrait aggregate Function phase. |
| aggFunction->set_phase(toAggregationPhase(aggregateCompanion.step)); |
| } |
| |
| // Direct output. |
| aggregateRel->mutable_common()->mutable_direct(); |
| } |
| |
| void VeloxToSubstraitPlanConvertor::toSubstrait( |
| google::protobuf::Arena& arena, |
| const std::shared_ptr<const core::OrderByNode>& orderByNode, |
| ::substrait::SortRel* sortRel) { |
| const auto& source = getSingleSource(orderByNode); |
| toSubstrait(arena, source, sortRel->mutable_input()); |
| |
| sortRel->MergeFrom( |
| processSortFields(arena, orderByNode->sortingKeys(), orderByNode->sortingOrders(), source->outputType())); |
| |
| VELOX_CHECK(!orderByNode->isPartial(), "Substrait doesn't support partial order by yet"); |
| sortRel->mutable_common()->mutable_direct(); |
| } |
| |
| void VeloxToSubstraitPlanConvertor::toSubstrait( |
| google::protobuf::Arena& arena, |
| const std::shared_ptr<const core::TopNNode>& topNNode, |
| ::substrait::FetchRel* fetchRel) { |
| const auto& source = getSingleSource(topNNode); |
| |
| // Construct the sortRel as the FetchRel input. |
| ::substrait::SortRel* sortRel = fetchRel->mutable_input()->mutable_sort(); |
| toSubstrait(arena, source, sortRel->mutable_input()); |
| |
| sortRel->MergeFrom( |
| processSortFields(arena, topNNode->sortingKeys(), topNNode->sortingOrders(), source->outputType())); |
| |
| sortRel->mutable_common()->mutable_direct(); |
| |
| VELOX_CHECK(!topNNode->isPartial(), "Substrait doesn't support partial topN yet"); |
| |
| fetchRel->set_offset(0); |
| fetchRel->set_count(topNNode->count()); |
| fetchRel->mutable_common()->mutable_direct(); |
| } |
| |
| const ::substrait::SortRel& VeloxToSubstraitPlanConvertor::processSortFields( |
| google::protobuf::Arena& arena, |
| const std::vector<core::FieldAccessTypedExprPtr>& sortingKeys, |
| const std::vector<core::SortOrder>& sortingOrders, |
| const facebook::velox::RowTypePtr& inputType) { |
| ::substrait::SortRel* sortRel = google::protobuf::Arena::CreateMessage<::substrait::SortRel>(&arena); |
| |
| VELOX_CHECK_EQ( |
| sortingKeys.size(), sortingOrders.size(), "Number of sorting keys and sorting orders must be the same"); |
| |
| for (int64_t i = 0; i < sortingKeys.size(); i++) { |
| ::substrait::SortField* sortField = sortRel->add_sorts(); |
| sortField->mutable_expr()->mutable_selection()->MergeFrom( |
| exprConvertor_->toSubstraitExpr(arena, sortingKeys[i], inputType)); |
| |
| sortField->set_direction(toSortDirection(sortingOrders[i])); |
| } |
| return *sortRel; |
| } |
| |
| void VeloxToSubstraitPlanConvertor::toSubstrait( |
| google::protobuf::Arena& arena, |
| const std::shared_ptr<const core::LimitNode>& limitNode, |
| ::substrait::FetchRel* fetchRel) { |
| const auto& source = getSingleSource(limitNode); |
| toSubstrait(arena, source, fetchRel->mutable_input()); |
| |
| fetchRel->set_offset(limitNode->offset()); |
| fetchRel->set_count(limitNode->count()); |
| |
| VELOX_CHECK(!limitNode->isPartial(), "Substrait doesn't support partial limit yet"); |
| |
| fetchRel->mutable_common()->mutable_direct(); |
| } |
| |
| const core::PlanNodePtr& VeloxToSubstraitPlanConvertor::getSingleSource(const core::PlanNodePtr& node) { |
| const auto& sources = node->sources(); |
| |
| VELOX_USER_CHECK_EQ(1, sources.size(), "Plan node must have exactly one source."); |
| return sources[0]; |
| } |
| |
| } // namespace gluten |