blob: 44fbb01b39a65bcc5a806ee0ac030f40a0b54d56 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "VeloxToSubstraitPlan.h"
#include <google/protobuf/wrappers.pb.h>
#include "utils/Exception.h"
namespace gluten {
namespace {
// Result of analysing an aggregate's function name: the name that is used to
// look up the Substrait function reference, and the aggregation step that is
// used to choose the Substrait aggregation phase.
struct AggregateCompanion {
// Function name passed to the extension collector.
std::string functionName;
// Aggregation step translated into a Substrait phase by toAggregationPhase().
core::AggregationNode::Step step;
};
// Splits an aggregate's function name into the underlying function name and
// the aggregation step implied by a companion-function suffix.
//
// Recognized suffixes:
//   "<name>_partial"       -> {<name>, kPartial}
//   "<name>_merge"         -> {<name>, kIntermediate}
//   "<name>_merge_extract" -> {<name>, kFinal}
// Any other name (including one that consists of a suffix only) is returned
// unchanged together with kSingle.
AggregateCompanion toAggregateCompanion(const core::AggregationNode::Aggregate& aggregate) {
  using Step = core::AggregationNode::Step;

  struct CompanionSuffix {
    const char* text;
    Step step;
  };
  // The suffixes cannot be recognized from the text after the last '_' alone:
  // that text never contains '_' and "_merge_extract" contains two of them.
  // Each suffix is therefore matched against the tail of the whole name.
  const CompanionSuffix companionSuffixes[] = {
      {"_merge_extract", Step::kFinal},
      {"_partial", Step::kPartial},
      {"_merge", Step::kIntermediate},
  };

  const auto& companionName = aggregate.call->name();
  for (const auto& candidate : companionSuffixes) {
    const std::string suffix(candidate.text);
    // Require a non-empty base name in front of the suffix.
    if (companionName.size() <= suffix.size()) {
      continue;
    }
    const auto baseLength = companionName.size() - suffix.size();
    if (companionName.compare(baseLength, suffix.size(), suffix) == 0) {
      return {companionName.substr(0, baseLength), candidate.step};
    }
  }
  // Not a companion function.
  return {companionName, Step::kSingle};
}
// Translates a Velox aggregation step into the corresponding Substrait
// aggregation phase. A step outside the four handled values is reported
// through VELOX_UNSUPPORTED.
::substrait::AggregationPhase toAggregationPhase(const core::AggregationNode::Step& step) {
  using Step = core::AggregationNode::Step;
  switch (step) {
    case Step::kSingle:
      return ::substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT;
    case Step::kPartial:
      return ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE;
    case Step::kIntermediate:
      return ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE;
    case Step::kFinal:
      return ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT;
    default:
      VELOX_UNSUPPORTED("Unsupported Aggregate Step '{}' in Substrait ", mapAggregationStepToName(step));
  }
}
// Maps a Velox sort order (ascending/descending combined with nulls
// first/last) onto the matching Substrait sort direction.
::substrait::SortField_SortDirection toSortDirection(core::SortOrder sortOrder) {
  const bool nullsFirst = sortOrder.isNullsFirst();
  const bool ascending = sortOrder.isAscending();
  if (nullsFirst) {
    return ascending ? ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST
                     : ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST;
  }
  return ascending ? ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST
                   : ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST;
}
} // namespace
// Converts a Velox plan into a Substrait plan.
//
// The plan message is created on `arena` and returned by reference. Every
// call replaces extensionCollector_ and exprConvertor_ with fresh instances.
::substrait::Plan& VeloxToSubstraitPlanConvertor::toSubstrait(
    google::protobuf::Arena& arena,
    const core::PlanNodePtr& plan) {
  // Fresh extension collector, shared with a fresh expression converter.
  extensionCollector_ = std::make_shared<SubstraitExtensionCollector>();
  exprConvertor_ = std::make_unique<VeloxToSubstraitExprConvertor>(extensionCollector_);

  ::substrait::Plan* planMessage = google::protobuf::Arena::CreateMessage<::substrait::Plan>(&arena);

  // Register the "UNKNOWN" type as the first extension entry (anchor 0).
  auto* unknownTypeExtension = planMessage->add_extensions()->mutable_extension_type();
  unknownTypeExtension->set_extension_uri_reference(0);
  unknownTypeExtension->set_type_anchor(0);
  unknownTypeExtension->set_name("UNKNOWN");

  // Convert the plan tree below the root relation.
  ::substrait::RelRoot* root = planMessage->add_relations()->mutable_root();
  toSubstrait(arena, plan, root->mutable_input());

  // Emit the extensions for every function and type collected during the conversion.
  extensionCollector_->addExtensionsToPlan(planMessage);

  // The root relation carries the output column names of the plan.
  for (const auto& columnName : plan->outputType()->names()) {
    root->add_names(columnName);
  }
  return *planMessage;
}
// Dispatches on the dynamic type of `planNode` and fills the matching member
// of `rel` (filter, read, project, aggregate, sort or fetch). A node type that
// is not handled here is reported through VELOX_UNSUPPORTED.
void VeloxToSubstraitPlanConvertor::toSubstrait(
    google::protobuf::Arena& arena,
    const core::PlanNodePtr& planNode,
    ::substrait::Rel* rel) {
  if (const auto filter = std::dynamic_pointer_cast<const core::FilterNode>(planNode)) {
    toSubstrait(arena, filter, rel->mutable_filter());
    return;
  }
  if (const auto values = std::dynamic_pointer_cast<const core::ValuesNode>(planNode)) {
    toSubstrait(arena, values, rel->mutable_read());
    return;
  }
  if (const auto project = std::dynamic_pointer_cast<const core::ProjectNode>(planNode)) {
    toSubstrait(arena, project, rel->mutable_project());
    return;
  }
  if (const auto aggregation = std::dynamic_pointer_cast<const core::AggregationNode>(planNode)) {
    toSubstrait(arena, aggregation, rel->mutable_aggregate());
    return;
  }
  if (const auto orderBy = std::dynamic_pointer_cast<const core::OrderByNode>(planNode)) {
    toSubstrait(arena, orderBy, rel->mutable_sort());
    return;
  }
  // TopN and Limit are both written as a FetchRel.
  if (const auto topN = std::dynamic_pointer_cast<const core::TopNNode>(planNode)) {
    toSubstrait(arena, topN, rel->mutable_fetch());
    return;
  }
  if (const auto limit = std::dynamic_pointer_cast<const core::LimitNode>(planNode)) {
    toSubstrait(arena, limit, rel->mutable_fetch());
    return;
  }
  VELOX_UNSUPPORTED("Unsupported plan node '{}' .", planNode->name());
}
// Converts a Velox FilterNode into a Substrait FilterRel: the single source
// becomes the input, the filter expression becomes the condition, and the
// output is emitted directly.
void VeloxToSubstraitPlanConvertor::toSubstrait(
    google::protobuf::Arena& arena,
    const std::shared_ptr<const core::FilterNode>& filterNode,
    ::substrait::FilterRel* filterRel) {
  // Convert the child first; it becomes the input of the filter relation.
  const auto& child = getSingleSource(filterNode);
  toSubstrait(arena, child, filterRel->mutable_input());

  // The condition is resolved against the child's output row type.
  const auto& predicate = filterNode->filter();
  const auto& childRowType = child->outputType();
  const auto& condition = exprConvertor_->toSubstraitExpr(arena, predicate, childRowType);
  filterRel->mutable_condition()->MergeFrom(condition);

  filterRel->mutable_common()->mutable_direct();
}
// Converts a Velox ValuesNode into a Substrait ReadRel backed by a virtual
// table. One literal struct is added per input vector, the node's output type
// becomes the base schema, and the output is emitted directly.
void VeloxToSubstraitPlanConvertor::toSubstrait(
    google::protobuf::Arena& arena,
    const std::shared_ptr<const core::ValuesNode>& valuesNode,
    ::substrait::ReadRel* readRel) {
  const auto& rowType = valuesNode->outputType();
  ::substrait::ReadRel_VirtualTable* table = readRel->mutable_virtual_table();

  for (const auto& rowVector : valuesNode->values()) {
    ::substrait::Expression_Literal_Struct* literalStruct = table->add_values();
    for (const auto& child : rowVector->children()) {
      // NOTE(review): this arena-allocated copy is not attached to anything
      // visible here; presumably toSubstraitLiteral() fills `literalStruct`
      // itself - confirm against VeloxToSubstraitExprConvertor.
      auto* literalCopy = google::protobuf::Arena::CreateMessage<::substrait::Expression_Literal>(&arena);
      literalCopy->MergeFrom(exprConvertor_->toSubstraitLiteral(arena, child, literalStruct));
    }
  }

  readRel->mutable_base_schema()->MergeFrom(typeConvertor_->toSubstraitNamedStruct(arena, rowType));
  readRel->mutable_common()->mutable_direct();
}
// Converts a Velox ProjectNode into a Substrait ProjectRel. Every projection
// is added as an expression, and the emit mapping selects exactly those
// expressions: their output indices start right after the input columns.
void VeloxToSubstraitPlanConvertor::toSubstrait(
    google::protobuf::Arena& arena,
    const std::shared_ptr<const core::ProjectNode>& projectNode,
    ::substrait::ProjectRel* projectRel) {
  const auto& projections = projectNode->projections();

  // Convert the child first; it becomes the input of the project relation.
  const auto& child = getSingleSource(projectNode);
  toSubstrait(arena, child, projectRel->mutable_input());

  ::substrait::RelCommon_Emit* emit = projectRel->mutable_common()->mutable_emit();
  const auto& childRowType = child->outputType();

  // Index of the next projected column in the emit mapping.
  int64_t emitIndex = childRowType->size();
  for (const auto& projection : projections) {
    projectRel->add_expressions()->MergeFrom(exprConvertor_->toSubstraitExpr(arena, projection, childRowType));
    emit->add_output_mapping(emitIndex);
    ++emitIndex;
  }
}
// Converts a Velox AggregationNode into a Substrait AggregateRel.
//
// The single source becomes the input, all grouping keys are written into one
// grouping, and every aggregate becomes one measure. A measure gets a filter
// only when the aggregate has a mask. Raises VELOX_USER_FAIL when a measure is
// converted for a kFinal or kIntermediate node, and VELOX_NYI when an
// aggregate argument is a CallTypedExpr.
void VeloxToSubstraitPlanConvertor::toSubstrait(
    google::protobuf::Arena& arena,
    const std::shared_ptr<const core::AggregationNode>& aggregateNode,
    ::substrait::AggregateRel* aggregateRel) {
  // Process the source Node.
  const auto& source = getSingleSource(aggregateNode);
  toSubstrait(arena, source, aggregateRel->mutable_input());

  // Convert aggregate grouping keys, such as: group by key1, key2.
  const auto& inputType = source->outputType();
  ::substrait::AggregateRel_Grouping* aggGroupings = aggregateRel->add_groupings();
  for (const auto& groupingKey : aggregateNode->groupingKeys()) {
    aggGroupings->add_grouping_expressions()->mutable_selection()->MergeFrom(
        exprConvertor_->toSubstraitExpr(arena, groupingKey, inputType));
  }

  for (const auto& aggregate : aggregateNode->aggregates()) {
    ::substrait::AggregateRel_Measure* aggMeasures = aggregateRel->add_measures();

    // Only materialize the filter when there is a mask. Calling
    // mutable_filter() unconditionally would leave an empty filter message on
    // measures that have no mask.
    if (const auto& mask = aggregate.mask) {
      aggMeasures->mutable_filter()->mutable_selection()->MergeFrom(
          exprConvertor_->toSubstraitExpr(arena, mask, inputType));
    }

    // Process measure, eg:sum(a).
    ::substrait::AggregateFunction* aggFunction = aggMeasures->mutable_measure();

    // Use aggregate node's step information to write advanced extension 'allowFlush'.
    // NOTE(review): the step belongs to the node, not to the measure, but it
    // is inspected once per measure. A node without aggregates therefore
    // neither writes 'allowFlush' nor rejects kFinal/kIntermediate; kept as is
    // to preserve that behaviour.
    switch (aggregateNode->step()) {
      case core::AggregationNode::Step::kPartial: {
        substrait::extensions::AdvancedExtension ae{};
        google::protobuf::StringValue msg;
        msg.set_value("allowFlush=1");
        ae.mutable_optimization()->PackFrom(msg);
        aggregateRel->mutable_advanced_extension()->MergeFrom(ae);
        break;
      }
      case core::AggregationNode::Step::kSingle:
        break;
      case core::AggregationNode::Step::kFinal:
      case core::AggregationNode::Step::kIntermediate:
        VELOX_USER_FAIL("Step not supported");
        break;
    }

    // Set aggFunction args.
    for (const auto& expr : aggregate.call->inputs()) {
      // If the expr is CallTypedExpr, people need to do project firstly.
      if (std::dynamic_pointer_cast<const core::CallTypedExpr>(expr)) {
        VELOX_NYI("In Velox Plan, the aggregates type cannot be CallTypedExpr");
      }
      aggFunction->add_arguments()->mutable_value()->MergeFrom(
          exprConvertor_->toSubstraitExpr(arena, expr, inputType));
    }

    // The function reference is looked up by the companion-resolved name and
    // the aggregate's raw input types.
    const auto& aggregateCompanion = toAggregateCompanion(aggregate);
    auto referenceNumber =
        extensionCollector_->getReferenceNumber(aggregateCompanion.functionName, aggregate.rawInputTypes);
    aggFunction->set_function_reference(referenceNumber);
    aggFunction->mutable_output_type()->MergeFrom(typeConvertor_->toSubstraitType(arena, aggregate.call->type()));

    // Set substrait aggregate Function phase.
    aggFunction->set_phase(toAggregationPhase(aggregateCompanion.step));
  }

  // Direct output.
  aggregateRel->mutable_common()->mutable_direct();
}
// Converts a Velox OrderByNode into a Substrait SortRel. A partial order-by
// fails the VELOX_CHECK, which runs after the input and sort fields have been
// written.
void VeloxToSubstraitPlanConvertor::toSubstrait(
    google::protobuf::Arena& arena,
    const std::shared_ptr<const core::OrderByNode>& orderByNode,
    ::substrait::SortRel* sortRel) {
  // Convert the child first; it becomes the input of the sort relation.
  const auto& child = getSingleSource(orderByNode);
  toSubstrait(arena, child, sortRel->mutable_input());

  // Merge the sort fields built from the node's keys and orders.
  const ::substrait::SortRel& sortFields =
      processSortFields(arena, orderByNode->sortingKeys(), orderByNode->sortingOrders(), child->outputType());
  sortRel->MergeFrom(sortFields);

  VELOX_CHECK(!orderByNode->isPartial(), "Substrait doesn't support partial order by yet");
  sortRel->mutable_common()->mutable_direct();
}
// Converts a Velox TopNNode into a Substrait FetchRel whose input is a
// SortRel: the sort carries the node's keys and orders, the fetch carries
// offset 0 and the node's count. A partial topN fails the VELOX_CHECK.
void VeloxToSubstraitPlanConvertor::toSubstrait(
    google::protobuf::Arena& arena,
    const std::shared_ptr<const core::TopNNode>& topNNode,
    ::substrait::FetchRel* fetchRel) {
  const auto& child = getSingleSource(topNNode);

  // The sort relation sits between the fetch relation and the converted child.
  ::substrait::SortRel* innerSort = fetchRel->mutable_input()->mutable_sort();
  toSubstrait(arena, child, innerSort->mutable_input());

  const ::substrait::SortRel& sortFields =
      processSortFields(arena, topNNode->sortingKeys(), topNNode->sortingOrders(), child->outputType());
  innerSort->MergeFrom(sortFields);
  innerSort->mutable_common()->mutable_direct();

  VELOX_CHECK(!topNNode->isPartial(), "Substrait doesn't support partial topN yet");

  fetchRel->set_offset(0);
  fetchRel->set_count(topNNode->count());
  fetchRel->mutable_common()->mutable_direct();
}
// Builds a SortRel on `arena` that contains one sort field per sorting key.
// Each field holds the key as a selection expression resolved against
// `inputType`, plus the direction derived from the order at the same
// position. The number of keys and orders must match (VELOX_CHECK_EQ).
const ::substrait::SortRel& VeloxToSubstraitPlanConvertor::processSortFields(
    google::protobuf::Arena& arena,
    const std::vector<core::FieldAccessTypedExprPtr>& sortingKeys,
    const std::vector<core::SortOrder>& sortingOrders,
    const facebook::velox::RowTypePtr& inputType) {
  ::substrait::SortRel* result = google::protobuf::Arena::CreateMessage<::substrait::SortRel>(&arena);
  VELOX_CHECK_EQ(
      sortingKeys.size(), sortingOrders.size(), "Number of sorting keys and sorting orders must be the same");

  const auto fieldCount = sortingKeys.size();
  for (std::vector<core::SortOrder>::size_type index = 0; index < fieldCount; ++index) {
    ::substrait::SortField* field = result->add_sorts();
    field->mutable_expr()->mutable_selection()->MergeFrom(
        exprConvertor_->toSubstraitExpr(arena, sortingKeys[index], inputType));
    field->set_direction(toSortDirection(sortingOrders[index]));
  }
  return *result;
}
// Converts a Velox LimitNode into a Substrait FetchRel that carries the
// node's offset and count. A partial limit fails the VELOX_CHECK, which runs
// after the input, offset and count have been written.
void VeloxToSubstraitPlanConvertor::toSubstrait(
    google::protobuf::Arena& arena,
    const std::shared_ptr<const core::LimitNode>& limitNode,
    ::substrait::FetchRel* fetchRel) {
  // Convert the child first; it becomes the input of the fetch relation.
  const auto& child = getSingleSource(limitNode);
  toSubstrait(arena, child, fetchRel->mutable_input());

  fetchRel->set_offset(limitNode->offset());
  fetchRel->set_count(limitNode->count());

  VELOX_CHECK(!limitNode->isPartial(), "Substrait doesn't support partial limit yet");
  fetchRel->mutable_common()->mutable_direct();
}
// Returns the only source of `node`. Raises a user error
// (VELOX_USER_CHECK_EQ) unless the node has exactly one source.
const core::PlanNodePtr& VeloxToSubstraitPlanConvertor::getSingleSource(const core::PlanNodePtr& node) {
  const auto& children = node->sources();
  VELOX_USER_CHECK_EQ(1, children.size(), "Plan node must have exactly one source.");
  return children.front();
}
} // namespace gluten