* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
package datafu.spark
import java.util.{List => JavaList}
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{LongType, SparkOverwriteUDAFs, StructType}
* Defined as a class so that this functionality can be exposed in PySpark
class SparkDFUtilsBridge {
/**
 * PySpark-facing variant of [[SparkDFUtils.dedupWithOrder]] that accepts the ordering
 * columns as a Java list.
 *
 * @param df        DataFrame to operate on
 * @param groupCol  column to group the records by
 * @param orderCols ordering columns, supplied as a Java list
 * @return the DataFrame returned by SparkDFUtils.dedupWithOrder
 */
def dedupWithOrder(df: DataFrame,
                   groupCol: Column,
                   orderCols: JavaList[Column]): DataFrame =
  // The Java list is converted to a Scala Seq and expanded into the varargs parameter.
  SparkDFUtils.dedupWithOrder(df, groupCol, convertJavaListToSeq(orderCols): _*)
/**
 * PySpark-facing variant of [[SparkDFUtils.dedupTopN]] that accepts the ordering
 * columns as a Java list.
 *
 * @param df        DataFrame to operate on
 * @param n         number of records to return from each group
 * @param groupCol  column to group the records by
 * @param orderCols ordering columns, supplied as a Java list
 * @return the DataFrame returned by SparkDFUtils.dedupTopN
 */
def dedupTopN(df: DataFrame,
              n: Int,
              groupCol: Column,
              orderCols: JavaList[Column]): DataFrame =
  // The Java list is converted to a Scala Seq and expanded into the varargs parameter.
  SparkDFUtils.dedupTopN(df, n, groupCol, convertJavaListToSeq(orderCols): _*)
// Bridge variant of dedupWithCombiner that accepts columnsFilter as a java.util.List.
// NOTE(review): the named arguments below line up one-to-one with the parameters of
// SparkDFUtils.dedupWithCombiner, which is presumably the call target - confirm.
def dedupWithCombiner(df: DataFrame,
groupCol: Column,
orderByCol: Column,
desc: Boolean,
columnsFilter: JavaList[String],
columnsFilterKeep: Boolean): DataFrame = {
// Convert the Java list of column names into a Scala Seq.
val columnsFilter_converted = convertJavaListToSeq(columnsFilter)
df = df,
groupCol = groupCol,
orderByCol = orderByCol,
desc = desc,
// This bridge never forwards additional aggregate functions.
moreAggFunctions = Nil,
columnsFilter = columnsFilter_converted,
columnsFilterKeep = columnsFilterKeep
/**
 * PySpark-facing variant of [[SparkDFUtils.changeSchema]] that accepts the new column
 * names as a Java list.
 *
 * @param df        DataFrame to operate on
 * @param newScheme new column names, supplied as a Java list
 * @return the DataFrame returned by SparkDFUtils.changeSchema
 */
def changeSchema(df: DataFrame, newScheme: JavaList[String]): DataFrame =
  // The Java list is converted to a Scala Seq and expanded into the varargs parameter.
  SparkDFUtils.changeSchema(df, convertJavaListToSeq(newScheme): _*)
/**
 * PySpark-facing variant of [[SparkDFUtils.joinSkewed]]; every argument is forwarded
 * unchanged, so callers must supply numShards and joinType explicitly.
 *
 * @param dfLeft    left DataFrame
 * @param dfRight   right DataFrame
 * @param joinExprs join expression
 * @param numShards number of shards
 * @param joinType  join type
 * @return the DataFrame returned by SparkDFUtils.joinSkewed
 */
def joinSkewed(dfLeft: DataFrame,
               dfRight: DataFrame,
               joinExprs: Column,
               numShards: Int,
               joinType: String): DataFrame =
  SparkDFUtils.joinSkewed(dfLeft, dfRight, joinExprs, numShards, joinType)
/**
 * PySpark-facing variant of [[SparkDFUtils.broadcastJoinSkewed]]; every argument is
 * forwarded unchanged.
 *
 * @param notSkewed          not skewed DataFrame
 * @param skewed             skewed DataFrame
 * @param joinCol            join column
 * @param numRowsToBroadcast num of rows to broadcast
 * @return the DataFrame returned by SparkDFUtils.broadcastJoinSkewed
 */
def broadcastJoinSkewed(notSkewed: DataFrame,
                        skewed: DataFrame,
                        joinCol: String,
                        numRowsToBroadcast: Int): DataFrame =
  SparkDFUtils.broadcastJoinSkewed(notSkewed, skewed, joinCol, numRowsToBroadcast)
// Bridge for SparkDFUtils.joinWithRange: forwards its arguments by name.
def joinWithRange(dfSingle: DataFrame,
colSingle: String,
dfRange: DataFrame,
colRangeStart: String,
colRangeEnd: String,
DECREASE_FACTOR: Long): DataFrame = {
SparkDFUtils.joinWithRange(dfSingle = dfSingle,
colSingle = colSingle,
dfRange = dfRange,
colRangeStart = colRangeStart,
colRangeEnd = colRangeEnd,
// Bridge variant of joinWithRangeAndDedup.
// NOTE(review): the named arguments below line up with the parameters of
// SparkDFUtils.joinWithRangeAndDedup, which is presumably the call target - confirm.
def joinWithRangeAndDedup(dfSingle: DataFrame,
colSingle: String,
dfRange: DataFrame,
colRangeStart: String,
colRangeEnd: String,
dedupSmallRange: Boolean): DataFrame = {
dfSingle = dfSingle,
colSingle = colSingle,
dfRange = dfRange,
colRangeStart = colRangeStart,
colRangeEnd = colRangeEnd,
dedupSmallRange = dedupSmallRange
// Converts a java.util.List into a Scala Seq with the same element type; the bridge
// methods use it to adapt their Java list arguments before delegating to SparkDFUtils.
private def convertJavaListToSeq[T](list: JavaList[T]): Seq[T] = {
object SparkDFUtils {
/**
 * Keeps a single record per group: the first one after sorting each group by the
 * provided order columns (i.e. the 'latest' record for a descending order).
 * Unlike {@link org.apache.spark.sql.Dataset#dropDuplicates}, the surviving row is
 * determined by the ordering.
 *
 * @param df DataFrame to operate on
 * @param groupCol column to group by the records
 * @param orderCols columns to order the records according to
 * @return DataFrame representing the data after the operation
 */
def dedupWithOrder(df: DataFrame, groupCol: Column, orderCols: Column*): DataFrame =
  // A dedup is simply a top-N selection with N fixed to 1.
  dedupTopN(df = df, n = 1, groupCol = groupCol, orderCols = orderCols: _*)
/**
 * Used to get the top N records (after ordering according to the provided order columns)
 * in each group.
 *
 * The row number is computed in a temporary helper column. Its name is chosen so that it
 * does not collide with any column already present in `df`, and the helper column is
 * dropped before returning, so every input column (including one named "rn") is preserved.
 *
 * @param df DataFrame to operate on
 * @param n number of records to return from each group
 * @param groupCol column to group by the records
 * @param orderCols columns to order the records according to
 * @return DataFrame representing the data after the operation
 */
def dedupTopN(df: DataFrame,
              n: Int,
              groupCol: Column,
              orderCols: Column*): DataFrame = {
  // Start from the historical helper name "rn" and append underscores until it is unused.
  // The comparison is case-insensitive because Spark resolves column names
  // case-insensitively by default, so an input column "RN" would otherwise be
  // overwritten by withColumn and then removed by drop.
  val existingNames = df.columns.map(_.toLowerCase(java.util.Locale.ROOT)).toSet
  val rowNumCol = Iterator.iterate("rn")(_ + "_").dropWhile(existingNames.contains).next()

  val w = Window.partitionBy(groupCol).orderBy(orderCols: _*)
  df.withColumn(rowNumCol, row_number().over(w))
    .where(col(rowNumCol) <= n)
    .drop(rowNumCol)
}
* Used to get the 'latest' record (after ordering according to the provided order columns)
* in each group.
* the same functionality as {@link #dedupWithOrder} but implemented using UDAF to utilize
* map side aggregation.
* this function should be used in cases when you expect a large number of rows to get combined,
* as they share the same group column.
* @param df DataFrame to operate on
* @param groupCol column to group by the records
* @param orderByCol column to order the records according to
* @param desc have the order as desc
* @param moreAggFunctions more aggregate functions
* @param columnsFilter columns to filter
* @param columnsFilterKeep indicates whether we should filter the selected columns 'out'
* or alternatively have only those columns in the result
* @return DataFrame representing the data after the operation
def dedupWithCombiner(df: DataFrame,
groupCol: Column,
orderByCol: Column,
desc: Boolean = true,
moreAggFunctions: Seq[Column] = Nil,
columnsFilter: Seq[String] = Nil,
columnsFilterKeep: Boolean = true): DataFrame = {
// Materialize the ordering expression as "sort_by_column" and apply the optional
// column filter.
val newDF =
// No filter given: keep every column.
if (columnsFilter == Nil) {
df.withColumn("sort_by_column", orderByCol)
} else {
if (columnsFilterKeep) {
// Keep mode: retain only the listed columns plus the sort column.
df.withColumn("sort_by_column", orderByCol)
.select("sort_by_column", columnsFilter: _*)
} else {
// Drop mode: column names listed in columnsFilter are excluded, then the sort
// column is added.
.filter(colName => !columnsFilter.contains(colName))
.map(colName => new Column(colName)): _*)
.withColumn("sort_by_column", orderByCol)
// desc selects SparkOverwriteUDAFs.maxValueByKey, otherwise minValueByKey.
val aggFunc =
if (desc) SparkOverwriteUDAFs.maxValueByKey(_: Column, _: Column)
else SparkOverwriteUDAFs.minValueByKey(_: Column, _: Column)
val df2 = newDF
// The key is the sort column; the value is a struct of the sort column and all columns.
.agg(aggFunc(expr("sort_by_column"), expr("struct(sort_by_column, *)"))
// A lit(1) placeholder is prepended to moreAggFunctions - presumably so the struct is
// never empty when no extra aggregates are supplied; confirm.
struct(lit(1).as("lit_placeholder_col") +: moreAggFunctions: _*)
// Expand both structs back into top-level columns, h2 fields first.
.selectExpr("h2.*", "h1.*")
* Returns a DataFrame with the given column (should be a StructType)
* replaced by its inner fields.
* This method only flattens a single level of nesting.
* +-------+----------+----------+----------+
* |id |s.sub_col1|s.sub_col2|s.sub_col3|
* +-------+----------+----------+----------+
* |123 |1 |2 |3 |
* +-------+----------+----------+----------+
* +-------+----------+----------+----------+
* |id |sub_col1 |sub_col2 |sub_col3 |
* +-------+----------+----------+----------+
* |123 |1 |2 |3 |
* +-------+----------+----------+----------+
* @param df DataFrame to operate on
* @param colName column name for a column of type StructType
* @return DataFrame representing the data after the operation
def flatten(df: DataFrame, colName: String): DataFrame = {
// Message for the case where colName is not a struct column.
// NOTE(review): the check that uses this message is not visible here - confirm.
s"Column $colName must be of type Struct")
val outerFields =
// Inner fields that are not already among outerFields are turned into back-quoted
// `colName`.`field` select expressions.
val flattenFields = df
.filter(f => !outerFields.contains(
.map("`" + colName + "`.`" + + "`")
// Select every existing column plus the flattened fields, then drop the struct column.
df.selectExpr("*" +: flattenFields: _*).drop(colName)
* Returns a DataFrame with the column names renamed to the column names in the new schema
* @param df DataFrame to operate on
* @param newScheme new column names
* @return DataFrame representing the data after the operation
def changeSchema(df: DataFrame, newScheme: String*): DataFrame = {
// Each (old name, new name) pair becomes the old column aliased to the new name.
case (oldCol: String, newCol: String) => col(oldCol).as(newCol)
}: _*)
* Used to perform a join when the right df is relatively small
* but still too big to fit in memory to perform map side broadcast join.
* Use cases:
* a. excluding keys that might be skewed from a medium size list.
* b. join a big skewed table with a table that has small number of very large rows.
* @param dfLeft left DataFrame
* @param dfRight right DataFrame
* @param joinExprs join expression
* @param numShards number of shards - number of times to duplicate the right DataFrame
* @param joinType join type
* @return joined DataFrame
def joinSkewed(dfLeft: DataFrame,
dfRight: DataFrame,
joinExprs: Column,
numShards: Int = 10,
joinType: String = "inner"): DataFrame = {
// skew join based on salting
// salts the left DF by adding another random column and join with the right DF after
// duplicating it
val ss = dfLeft.sparkSession
// Brings the $"col" string interpolator into scope.
import ss.implicits._
val shards ="shard")
// The salt column "randLeft" is ceil(rand() * numShards).
.withColumn("randLeft", ceil(rand() * numShards))
// Join condition: the caller's expression AND the salt equal to the shard value.
joinExprs and $"randLeft" === $"shard",
* Suitable to perform a join in cases when one DF is skewed and the other is not skewed.
* splits both of the DFs to two parts according to the skewed keys.
* 1. Map-join: broadcasts the skewed-keys part of the not skewed DF to the skewed-keys
* part of the skewed DF
* 2. Regular join: between the remaining two parts.
* @param notSkewed not skewed DataFrame
* @param skewed skewed DataFrame
* @param joinCol join column
* @param numRowsToBroadcast num of rows to broadcast
* @return DataFrame representing the data after the operation
def broadcastJoinSkewed(notSkewed: DataFrame,
skewed: DataFrame,
joinCol: String,
numRowsToBroadcast: Int): DataFrame = {
val ss = notSkewed.sparkSession
// Brings the $"col" string interpolator into scope.
import ss.implicits._
// The join column is renamed to "skew_join_key" so it can be told apart from joinCol
// in the joins below.
val skewedKeys = skewed
.withColumnRenamed(joinCol, "skew_join_key")
// Left-join the broadcast key set and flag rows whose key matched as skewed.
val notSkewedWithSkewIndicator = notSkewed
.join(broadcast(skewedKeys), $"skew_join_key" === col(joinCol), "left")
.withColumn("is_skewed_record", col("skew_join_key").isNotNull)
// broadcast map-join, sending the notSkewed data
val bigRecordsJnd =
.join(skewed, joinCol)
// regular join for the rest
// Rows of the skewed DataFrame whose key did not match the broadcast key set.
val skewedWithoutSkewedKeys = skewed
.join(broadcast(skewedKeys), $"skew_join_key" === col(joinCol), "left")
.where("skew_join_key is null")
// Rows not flagged as skewed are joined on joinCol with the rows selected above.
val smallRecordsJnd = notSkewedWithSkewIndicator
.filter("not is_skewed_record")
.join(skewedWithoutSkewedKeys, joinCol)
// Remove the helper columns introduced above.
.drop("is_skewed_record", "skew_join_key")
* Helper function to join a table with point column to a table with range column.
* For example, join a table that contains specific time in minutes with a table that
* contains time ranges.
* The main problem this function addresses is that doing naive explode on the ranges can result
* in a huge table.
* requires:
* 1. point table needs to be distinct on the point column. there could be a few corresponding
* ranges to each point, so we choose the minimal range.
* 2. the range and point columns need to be numeric.
* +-------+
* |time |
* +-------+
* |11:55 |
* +-------+
* +----------+---------+----------+
* |start_time|end_time |desc |
* +----------+---------+----------+
* |10:00 |12:00 | meeting |
* +----------+---------+----------+
* |11:50 |12:15 | lunch |
* +----------+---------+----------+
* +-------+----------+---------+---------+
* |time |start_time|end_time |desc |
* +-------+----------+---------+---------+
* |11:55 |10:00 |12:00 | meeting |
* +-------+----------+---------+---------+
* |11:55 |11:50 |12:15 | lunch |
* +-------+----------+---------+---------+
* @param dfSingle DataFrame that contains the point column
* @param colSingle the point column's name
* @param dfRange DataFrame that contains the range column
* @param colRangeStart the start range column's name
* @param colRangeEnd the end range column's name
* @param DECREASE_FACTOR resolution factor. instead of exploding the range column directly,
* we first decrease its resolution by this factor
* @return DataFrame representing the data after the operation
def joinWithRange(dfSingle: DataFrame,
colSingle: String,
dfRange: DataFrame,
colRangeStart: String,
colRangeEnd: String,
DECREASE_FACTOR: Long): DataFrame = {
// The join itself is delegated to joinWithRangeInternal.
val dfJoined = joinWithRangeInternal(dfSingle,
// Shared implementation of the point-to-range join used by the public joinWithRange*
// methods.
private def joinWithRangeInternal(dfSingle: DataFrame,
colSingle: String,
dfRange: DataFrame,
colRangeStart: String,
colRangeEnd: String,
DECREASE_FACTOR: Long): DataFrame = {
import org.apache.spark.sql.functions.udf
// Builds an array holding every Long from start to end, both inclusive.
val rangeUDF = udf((start: Long, end: Long) => (start to end).toArray)
// Range bounds are cast to Long under the fixed names "range_start" / "range_end".
val dfRange_exploded = dfRange
.withColumn("range_start", col(colRangeStart).cast(LongType))
.withColumn("range_end", col(colRangeEnd).cast(LongType))
// Both bounds are divided by DECREASE_FACTOR before being expanded by rangeUDF.
rangeUDF(col("range_start") / lit(DECREASE_FACTOR),
col("range_end") / lit(DECREASE_FACTOR))))
// "single" is the point column cast to Long.
.withColumn("single", floor(col(colSingle).cast(LongType)))
// The point value divided by DECREASE_FACTOR and floored.
floor(col(colSingle).cast(LongType) / lit(DECREASE_FACTOR)))
// Join condition on the reduced-resolution keys.
col("decreased_single") === col("decreased_range_single"),
// Number of values in the inclusive range.
.withColumn("range_size", expr("(range_end - range_start + 1)"))
// Keep only rows whose point actually lies within [range_start, range_end].
.filter("single>=range_start and single<=range_end")
* Run joinWithRange and afterwards run dedup
* @param dedupSmallRange when true, keep the smallest matching range for each point;
* when false, keep the largest matching range
* OUTPUT for dedupSmallRange = "true":
* +-------+----------+---------+---------+
* |time |start_time|end_time |desc |
* +-------+----------+---------+---------+
* |11:55 |11:50 |12:15 | lunch |
* +-------+----------+---------+---------+
* OUTPUT for dedupSmallRange = "false":
* +-------+----------+---------+---------+
* |time |start_time|end_time |desc |
* +-------+----------+---------+---------+
* |11:55 |10:00 |12:00 | meeting |
* +-------+----------+---------+---------+
def joinWithRangeAndDedup(dfSingle: DataFrame,
colSingle: String,
dfRange: DataFrame,
colRangeStart: String,
colRangeEnd: String,
dedupSmallRange: Boolean): DataFrame = {
val dfJoined = joinWithRangeInternal(dfSingle,
// "range_start" is here for consistency
val dfDeduped = if (dedupSmallRange) {
struct("range_size", "range_start"),
desc = false)
} else {
struct(expr("-range_size"), col("range_start")),
desc = true)