blob: d05519b8801c1476e21f37f2cecf681ae864f6d7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
import DataSourceV2Implicits._
/**
 * Runs the push-down phases of this rule in a fixed order: wrap every
 * [[DataSourceV2Relation]] in a [[ScanBuilderHolder]], push filters, push aggregates, and
 * finally prune columns (which also turns the remaining holders into scan relations).
 */
def apply(plan: LogicalPlan): LogicalPlan = {
  val withScanBuilders = createScanBuilder(plan)
  val withPushedFilters = pushDownFilters(withScanBuilders)
  val withPushedAggregates = pushDownAggregates(withPushedFilters)
  applyColumnPruning(withPushedAggregates)
}
/**
 * Replaces every [[DataSourceV2Relation]] in the plan with a [[ScanBuilderHolder]] that keeps
 * the relation, its output, and a scan builder freshly created from the relation's table
 * and options.
 */
private def createScanBuilder(plan: LogicalPlan) = plan.transform {
  case relation: DataSourceV2Relation =>
    val readableTable = relation.table.asReadable
    val scanBuilder = readableTable.newScanBuilder(relation.options)
    ScanBuilderHolder(relation.output, relation, scanBuilder)
}
/**
 * Offers the predicates of a [[Filter]] that sits directly on a [[ScanBuilderHolder]] to the
 * holder's scan builder.
 *
 * The condition is split into conjuncts and normalized against the relation's output.
 * Conjuncts containing a subquery are never handed to `PushDownUtils.pushFilters`; they are
 * always kept as post-scan filters. If no post-scan filter remains, the [[Filter]] node is
 * dropped and the bare holder is returned.
 */
private def pushDownFilters(plan: LogicalPlan) = plan.transform {
  // update the scan builder with filter push down and return a new plan with filter pushed
  case Filter(condition, sHolder: ScanBuilderHolder) =>
    val conjuncts = DataSourceStrategy.normalizeExprs(
      splitConjunctivePredicates(condition), sHolder.relation.output)
    val (subqueryConjuncts, pushableConjuncts) =
      conjuncts.partition(SubqueryExpression.hasSubquery)

    // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
    // `postScanFilters` need to be evaluated after the scan.
    // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
    val (pushedFilters, remainingConjuncts) =
      PushDownUtils.pushFilters(sHolder.builder, pushableConjuncts)
    val postScanFilters = remainingConjuncts ++ subqueryConjuncts

    logInfo(
      s"""
         |Pushing operators to ${sHolder.relation.name}
         |Pushed Filters: ${pushedFilters.mkString(", ")}
         |Post-Scan Filters: ${postScanFilters.mkString(",")}
         |""".stripMargin)

    postScanFilters.reduceLeftOption(And) match {
      case Some(residualCondition) => Filter(residualCondition, sHolder)
      case None => sHolder
    }
}
/**
 * Pushes aggregate functions and group-by expressions into the scan builder.
 *
 * An [[Aggregate]] is considered only when its child matches a [[ScanOperation]] over a
 * [[ScanBuilderHolder]] with no remaining filters and a projection made solely of attribute
 * references, and only when the builder implements [[SupportsPushDownAggregates]]. If
 * `PushDownUtils.pushAggregates` returns nothing, the node is returned unchanged.
 *
 * Otherwise the scan is built immediately and the holder is replaced by a
 * [[DataSourceV2ScanRelation]] whose output is the scan's read schema. That schema is
 * expected to list the group-by columns first, followed by one column per distinct pushed
 * aggregate. The [[Aggregate]] stays on top and is rewritten to aggregate over the pushed
 * columns.
 *
 * @param plan the plan to rewrite
 * @return the plan with eligible aggregates pushed into their scans
 */
def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
  // update the scan builder with agg pushdown and return a new plan with agg pushed
  case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
    child match {
      case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
          if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
        sHolder.builder match {
          case _: SupportsPushDownAggregates =>
            // Maps each canonicalized aggregate to its position among the pushed aggregates.
            val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
            val aggregates = resultExpressions.flatMap { expr =>
              expr.collect {
                // Do not push down duplicated aggregate expressions. For example,
                // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
                // `max(a)` to the data source.
                case agg: AggregateExpression
                    if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
                  // The next free ordinal equals the number of aggregates recorded so far;
                  // the size is read before the new entry is inserted.
                  aggExprToOutputOrdinal(agg.canonicalized) = aggExprToOutputOrdinal.size
                  agg
              }
            }
            val pushedAggregates = PushDownUtils
              .pushAggregates(sHolder.builder, aggregates, groupingExpressions)
            pushedAggregates match {
              case None =>
                aggNode // return original plan node
              case Some(pushedAggregation) =>
                // No need to do column pruning because only the aggregate columns are used as
                // DataSourceV2ScanRelation output columns. All the other columns are not
                // included in the output.
                val scan = sHolder.builder.build()

                // scalastyle:off
                // use the group by columns and aggregate columns as the output columns
                // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
                // SELECT min(c1), max(c1) FROM t GROUP BY c2;
                // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation
                // We want to have the following logical plan:
                // == Optimized Logical Plan ==
                // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
                // +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
                // scalastyle:on
                val newOutput = scan.readSchema().toAttributes
                assert(
                  newOutput.length == groupingExpressions.length + aggregates.length,
                  s"Scan read schema has ${newOutput.length} columns but " +
                    s"${groupingExpressions.length} group-by and ${aggregates.length} " +
                    "aggregate columns were expected after aggregate push down")
                // Reuse the original expression ids for plain group-by attributes so that
                // references to them in `resultExpressions` still resolve.
                val groupAttrs = groupingExpressions.zip(newOutput).map {
                  case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
                  case (_, b) => b
                }
                val output = groupAttrs ++ newOutput.drop(groupAttrs.length)
                logInfo(
                  s"""
                     |Pushing operators to ${sHolder.relation.name}
                     |Pushed Aggregate Functions:
                     | ${pushedAggregation.aggregateExpressions.mkString(", ")}
                     |Pushed Group by:
                     | ${pushedAggregation.groupByColumns.mkString(", ")}
                     |Output: ${output.mkString(", ")}
                     |""".stripMargin)

                val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
                val scanRelation =
                  DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
                val plan = Aggregate(
                  output.take(groupingExpressions.length), resultExpressions, scanRelation)

                // scalastyle:off
                // Change the optimized logical plan to reflect the pushed down aggregate
                // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
                // SELECT min(c1), max(c1) FROM t GROUP BY c2;
                // The original logical plan is
                // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
                // +- RelationV2[c1#9, c2#10] ...
                //
                // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
                // we have the following
                // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
                // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
                //
                // We want to change it to
                // == Optimized Logical Plan ==
                // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
                // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
                // scalastyle:on
                val aggOutput = output.drop(groupAttrs.length)
                // NOTE(review): the rewritten AggregateExpression keeps every other field of
                // the original (only `aggregateFunction` is replaced). Confirm that
                // PushDownUtils.pushAggregates never accepts a DISTINCT or FILTER aggregate
                // for which re-aggregating the pushed column would give a different result.
                plan.transformExpressions {
                  case agg: AggregateExpression =>
                    val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
                    val aggFunction: aggregate.AggregateFunction =
                      agg.aggregateFunction match {
                        case max: aggregate.Max => max.copy(child = aggOutput(ordinal))
                        case min: aggregate.Min => min.copy(child = aggOutput(ordinal))
                        case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal))
                        // Counts coming back from the scan are combined by summing them.
                        case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal))
                        case other => other
                      }
                    agg.copy(aggregateFunction = aggFunction)
                }
            }
          case _ => aggNode
        }
      case _ => aggNode
    }
}
/**
 * Finalizes every remaining [[ScanBuilderHolder]]: prunes columns through
 * `PushDownUtils.pruneColumns`, builds the [[DataSourceV2ScanRelation]], and re-applies the
 * collected filters and projection on top of it.
 *
 * Filter and project expressions are rewritten with [[ProjectionOverSchema]] against the
 * pruned output. A [[Project]] node is added only when the output of the (optionally
 * filtered) scan differs from the requested project list.
 *
 * @param plan the plan to rewrite
 * @return the plan with scan builder holders replaced by scan relations
 */
def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform {
  case ScanOperation(project, filters, sHolder: ScanBuilderHolder) =>
    // column pruning
    val normalizedProjects = DataSourceStrategy
      .normalizeExprs(project, sHolder.output)
      .asInstanceOf[Seq[NamedExpression]]
    val (scan, output) = PushDownUtils.pruneColumns(
      sHolder.builder, sHolder.relation, normalizedProjects, filters)
    logInfo(
      s"""
         |Output: ${output.mkString(", ")}
         |""".stripMargin)

    // No aggregation is attached to a scan produced by this step.
    val scanRelation = DataSourceV2ScanRelation(
      sHolder.relation, getWrappedScan(scan, sHolder, None), output)

    val projectionOverSchema = ProjectionOverSchema(output.toStructType)
    def rewriteOverPrunedSchema(expr: Expression): Expression = expr.transformDown {
      case projectionOverSchema(newExpr) => newExpr
    }

    val withFilter: LogicalPlan =
      filters.reduceLeftOption(And).map(rewriteOverPrunedSchema) match {
        case Some(prunedCondition) => Filter(prunedCondition, scanRelation)
        case None => scanRelation
      }

    if (withFilter.output == project) {
      withFilter
    } else {
      val newProjects = normalizedProjects
        .map(rewriteOverPrunedSchema)
        .asInstanceOf[Seq[NamedExpression]]
      Project(newProjects, withFilter)
    }
}
/**
 * Wraps a [[V1Scan]] in a [[V1ScanWrapper]] carrying the filters reported as pushed by the
 * holder's builder (empty when the builder is not a [[SupportsPushDownFilters]]) together
 * with the given aggregation. Any other kind of scan is returned untouched.
 */
private def getWrappedScan(
    scan: Scan,
    sHolder: ScanBuilderHolder,
    aggregation: Option[Aggregation]): Scan = scan match {
  case v1Scan: V1Scan =>
    val handledFilters = sHolder.builder match {
      case filterBuilder: SupportsPushDownFilters => filterBuilder.pushedFilters()
      case _ => Array.empty[sources.Filter]
    }
    V1ScanWrapper(v1Scan, handledFilters, aggregation)
  case nonV1Scan => nonV1Scan
}
}
/**
 * Leaf plan node that stands in for a [[DataSourceV2Relation]] while `V2ScanRelationPushDown`
 * runs, carrying the [[ScanBuilder]] that the push-down steps operate on. Holders are created
 * by `createScanBuilder` and replaced by a scan relation in `pushDownAggregates` or
 * `applyColumnPruning`.
 *
 * @param output attributes exposed by this node; `createScanBuilder` passes the relation's output
 * @param relation the relation this holder was created from
 * @param builder the scan builder that filters, aggregates and column pruning are pushed into
 */
case class ScanBuilderHolder(
output: Seq[AttributeReference],
relation: DataSourceV2Relation,
builder: ScanBuilder) extends LeafNode
/**
 * A wrapper for a v1 scan that carries the filters handled by the source and the aggregation
 * pushed to it, if any. This is required by the physical v1 scan node.
 *
 * @param v1Scan the wrapped v1 scan
 * @param handledFilters filters the scan builder reported as pushed
 * @param pushedAggregate the aggregation pushed into the scan, or `None` when there is none
 */
case class V1ScanWrapper(
v1Scan: V1Scan,
handledFilters: Seq[sources.Filter],
pushedAggregate: Option[Aggregation]) extends Scan {
// The read schema is exactly that of the wrapped scan.
override def readSchema(): StructType = v1Scan.readSchema()
}