/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.{Aggregation, FieldReference}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType

object PushDownUtils extends PredicateHelper {
  /**
   * Pushes down filters to the data source reader.
   *
   * @return pushed filters and post-scan filters.
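   *
   * For example (illustrative only), given the filters `a > 1` and `f(b) = 2` where `f` has no
   * data source translation, `a > 1` is offered to the source for pushdown while `f(b) = 2` is
   * returned as a post-scan filter that Spark must still evaluate.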
   */
  def pushFilters(
      scanBuilder: ScanBuilder,
      filters: Seq[Expression]): (Seq[sources.Filter], Seq[Expression]) = {
    scanBuilder match {
      case r: SupportsPushDownFilters =>
        // A map from translated data source leaf node filters to original catalyst filter
        // expressions. For an `And`/`Or` predicate, it is possible that the predicate is
        // partially pushed down. This map can be used to construct a catalyst filter expression
        // from the input filter, or a superset (partial push down filter) of the input filter.
        val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression]
        val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter]
        // Catalyst filter expressions that can't be translated to data source filters.
        val untranslatableExprs = mutable.ArrayBuffer.empty[Expression]

        for (filterExpr <- filters) {
          val translated =
            DataSourceStrategy.translateFilterWithMapping(filterExpr, Some(translatedFilterToExpr),
              nestedPredicatePushdownEnabled = true)
          if (translated.isEmpty) {
            untranslatableExprs += filterExpr
          } else {
            translatedFilters += translated.get
          }
        }

        // Data source filters that need to be evaluated again after scanning, which means
        // the data source cannot guarantee the rows returned can pass these filters.
        // As a result we must return them so Spark can plan an extra filter operator.
        val postScanFilters = r.pushFilters(translatedFilters.toArray).map { filter =>
          DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr)
        }
        (r.pushedFilters(), (untranslatableExprs ++ postScanFilters).toSeq)

      case _ => (Nil, filters)
    }
  }

  /**
   * Pushes down aggregates to the data source reader.
   *
   * @return pushed aggregation.
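   *
   * For example (illustrative only), aggregates such as `min(a)` and `max(a)` grouped by `b` may
   * be pushed down as a single `Aggregation` when every aggregate function and grouping column
   * can be translated to the V2 API; otherwise nothing is pushed and `None` is returned.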
   */
  def pushAggregates(
      scanBuilder: ScanBuilder,
      aggregates: Seq[AggregateExpression],
      groupBy: Seq[Expression]): Option[Aggregation] = {
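    // Translates a top-level column reference into a V2 `FieldReference`; nested columns are not
    // supported here, so any other expression maps to `None`.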
    def columnAsString(e: Expression): Option[FieldReference] = e match {
      case PushableColumnWithoutNestedColumn(name) =>
        Some(FieldReference(name).asInstanceOf[FieldReference])
      case _ => None
    }

    scanBuilder match {
      case r: SupportsPushDownAggregates if aggregates.nonEmpty =>
        val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
        val translatedGroupBys = groupBy.flatMap(columnAsString)
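
        // Aggregate push-down is all-or-nothing: if any aggregate function or group-by column
        // cannot be translated, push nothing down.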
        if (translatedAggregates.length != aggregates.length ||
          translatedGroupBys.length != groupBy.length) {
          return None
        }

        val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
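        // Keep the aggregation only if the data source reports that it can handle it.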
        Some(agg).filter(r.pushAggregation)

      case _ => None
    }
  }

  /**
   * Applies column pruning to the data source, w.r.t. the references of the given expressions.
   *
   * @return the `Scan` instance (since column pruning is the last step of operator pushdown),
   *         and new output attributes after column pruning.
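   *
   * For example (illustrative only), if the relation has columns `a`, `b` and `c` but only `a`
   * and `b` are referenced by the given projects and filters, the scan is asked to read just
   * `a` and `b`.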
   */
  def pruneColumns(
      scanBuilder: ScanBuilder,
      relation: DataSourceV2Relation,
      projects: Seq[NamedExpression],
      filters: Seq[Expression]): (Scan, Seq[AttributeReference]) = {
    val exprs = projects ++ filters
    val requiredColumns = AttributeSet(exprs.flatMap(_.references))
    val neededOutput = relation.output.filter(requiredColumns.contains)

    scanBuilder match {
      case r: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled =>
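        // Nested schema pruning is enabled: work out which root-level and nested fields are
        // actually referenced, and ask the source to read only those.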
        val rootFields = SchemaPruning.identifyRootFields(projects, filters)
        val prunedSchema = if (rootFields.nonEmpty) {
          SchemaPruning.pruneDataSchema(relation.schema, rootFields)
        } else {
          new StructType()
        }
        val neededFieldNames = neededOutput.map(_.name).toSet
        r.pruneColumns(StructType(prunedSchema.filter(f => neededFieldNames.contains(f.name))))
        val scan = r.build()
        scan -> toOutputAttrs(scan.readSchema(), relation)

      case r: SupportsPushDownRequiredColumns =>
        r.pruneColumns(neededOutput.toStructType)
        val scan = r.build()
        // always project, in case the relation's output has been updated and doesn't match
        // the underlying table schema
        scan -> toOutputAttrs(scan.readSchema(), relation)

      case _ => scanBuilder.build() -> relation.output
    }
  }
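
  /**
   * Converts a scan's read schema back into output attributes: char/varchar types are replaced
   * with string, and the expression IDs of the relation's original attributes are reused by
   * matching on the attribute name.
   */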
  private def toOutputAttrs(
      schema: StructType,
      relation: DataSourceV2Relation): Seq[AttributeReference] = {
    val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
    val cleaned = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)
    cleaned.toAttributes.map {
      // we have to keep the attribute id during transformation
      a => a.withExprId(nameToAttr(a.name).exprId)
    }
  }
}