`leftDf.top_k_join(k: Column, rightDf: DataFrame, joinExprs: Column, score: Column)` joins only the top-k records of `rightDf` to each `leftDf` record that satisfies the join condition `joinExprs`. The output schema is the joined schema of `leftDf` and `rightDf` plus two extra columns: `rank: Int` and `score`, whose type is the type of the `score` column. In the example output below, these two columns come first.
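To make the semantics concrete, here is a minimal sketch on plain Scala collections (no Spark). `TopKJoinSemantics`, `Joined`, and `topKJoin` are hypothetical names, not Hivemall's API. The sketch assumes positive `k` and breaks score ties by input order; how Hivemall ranks ties is not covered here.

```scala
object TopKJoinSemantics {
  // One output row: rank and score first, then the matched left and right records.
  final case class Joined[L, R](rank: Int, score: Double, left: L, right: R)

  /** Hypothetical local stand-in for leftDf.top_k_join(k, rightDf, joinExprs, score),
    * covering only k > 0 (highest scores first). */
  def topKJoin[L, R](left: Seq[L], right: Seq[R], k: Int)(
      joinCond: (L, R) => Boolean)(score: (L, R) => Double): Seq[Joined[L, R]] = {
    require(k > 0, "this sketch only handles positive k")
    left.flatMap { l =>
      right
        .filter(r => joinCond(l, r))      // join condition (plays the role of joinExprs)
        .map(r => (score(l, r), r))       // score each matching right record
        .sortWith((a, b) => a._1 > b._1)  // highest score first; ties keep input order
        .take(k)                          // drop everything outside the top-k
        .zipWithIndex
        .map { case ((s, r), i) => Joined(i + 1, s, l, r) } // ranks start at 1
    }
  }
}
```

Each left record contributes at most `k` output rows, which is what distinguishes this from a plain join.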
`top_k_join` is much more I/O-efficient than a regular join followed by a ranking operation, because it drops non-qualifying records during the join and writes only the top-k records to disk.
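One way to see why keeping only the top-k records saves I/O: a bounded min-heap of size k can consume the join candidates of one left record as a stream and discard each candidate as soon as it cannot enter the top-k. The sketch below illustrates that idea with `scala.collection.mutable.PriorityQueue`; it is not Hivemall's internal implementation, and `BoundedTopK`/`topK` are hypothetical names.

```scala
import scala.collection.mutable

object BoundedTopK {
  /** Streams over (score, item) candidates and keeps at most k of them with the
    * highest scores, so buffered state stays O(k) instead of O(#candidates). */
  def topK[T](candidates: Iterator[(Double, T)], k: Int): List[(Double, T)] = {
    require(k > 0, "k must be positive in this sketch")
    // PriorityQueue dequeues its "largest" element first; defining a < b as
    // "a has the larger score" makes `head` the weakest candidate kept so far.
    val minFirst = Ordering.fromLessThan[(Double, T)]((a, b) => a._1 > b._1)
    val heap = mutable.PriorityQueue.empty[(Double, T)](minFirst)
    candidates.foreach { c =>
      if (heap.size < k) heap.enqueue(c)
      else if (c._1 > heap.head._1) { // beats the weakest kept candidate
        heap.dequeue()
        heap.enqueue(c)
      } // otherwise the candidate is dropped immediately
    }
    heap.toList.sortWith((a, b) => a._1 > b._1) // highest score first
  }
}
```

For example, `BoundedTopK.topK(Iterator((0.1, "a"), (0.9, "b"), (0.5, "c"), (0.7, "d")), 2)` never holds more than two candidates at a time.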
Caution

- `top_k_join` is supported in the DataFrame API of Spark v2.1.0 or later.
- The type of `score` must be `ByteType`, `ShortType`, `IntegerType`, `LongType`, `FloatType`, `DoubleType`, or `DecimalType`.
- If `k` is less than 0, the order is reversed and `top_k_join` joins the tail-k records of `rightDf`.
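The sign convention for `k` can be sketched on plain collections as follows. `SignedK.select` is a hypothetical helper, not part of Hivemall; it assumes `k = 0` yields no rows, which the list above does not specify.

```scala
object SignedK {
  /** Hypothetical helper mirroring the sign convention of `k`:
    * k > 0 keeps the k highest scores; k < 0 keeps the |k| lowest ("tail-k").
    * Returns (rank, score, item) triples. */
  def select[T](scored: Seq[(T, Double)], k: Int): Seq[(Int, Double, T)] = {
    val ordered =
      if (k >= 0) scored.sortWith((a, b) => a._2 > b._2) // descending: top-k
      else scored.sortWith((a, b) => a._2 < b._2)        // ascending: tail-k
    // Assumption: k = 0 simply yields no rows in this sketch.
    ordered.take(math.abs(k)).zipWithIndex.map { case ((item, s), i) => (i + 1, s, item) }
  }
}
```

With scores `a = 0.3`, `b = 0.1`, `c = 0.5`, `k = 1` selects `c`, while `k = -2` selects `b` and then `a`.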
For example, consider the two tables below:
`leftDf`:

```scala
scala> :paste
val leftDf = Seq(
  (1, "b", 0.3, 0.3),
  (2, "a", 0.5, 0.4),
  (3, "a", 0.1, 0.8),
  (4, "c", 0.2, 0.2),
  (5, "a", 0.1, 0.4),
  (6, "b", 0.8, 0.8)
).toDF("userId", "group", "x", "y")

scala> leftDf.show
+------+-----+---+---+
|userId|group|  x|  y|
+------+-----+---+---+
|     1|    b|0.3|0.3|
|     2|    a|0.5|0.4|
|     3|    a|0.1|0.8|
|     4|    c|0.2|0.2|
|     5|    a|0.1|0.4|
|     6|    b|0.8|0.8|
+------+-----+---+---+
```
`rightDf`:

```scala
scala> :paste
val rightDf = Seq(
  ("a", "pos1", 0.0, 0.1),
  ("a", "pos2", 0.9, 0.3),
  ("a", "pos3", 0.3, 0.2),
  ("b", "pos4", 0.5, 0.7),
  ("b", "pos5", 0.4, 0.2),
  ("c", "pos6", 0.8, 0.7),
  ("c", "pos7", 0.3, 0.3),
  ("c", "pos8", 0.4, 0.2),
  ("c", "pos9", 0.3, 0.8)
).toDF("group", "position", "x", "y")

scala> rightDf.show
+-----+--------+---+---+
|group|position|  x|  y|
+-----+--------+---+---+
|    a|    pos1|0.0|0.1|
|    a|    pos2|0.9|0.3|
|    a|    pos3|0.3|0.2|
|    b|    pos4|0.5|0.7|
|    b|    pos5|0.4|0.2|
|    c|    pos6|0.8|0.7|
|    c|    pos7|0.3|0.3|
|    c|    pos8|0.4|0.2|
|    c|    pos9|0.3|0.8|
+-----+--------+---+---+
```
Using these two tables, the example computes the nearest `position` for each `userId` within the same `group`. The standard approach with DataFrame window functions would be as follows:
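```scala
scala> :paste
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Euclidean distance between the (x, y) points of a leftDf row and a rightDf row
val computeDistanceFunc = sqrt(
  pow(leftDf("x") - rightDf("x"), lit(2.0)) + pow(leftDf("y") - rightDf("y"), lit(2.0)))

val resultDf = leftDf.join(
    right = rightDf,
    joinExprs = leftDf("group") === rightDf("group")
  )
  .select(leftDf("group"), $"userId", $"position", computeDistanceFunc.as("score"))
  // Rank candidate positions per user by ascending distance and keep the nearest
  .withColumn("rank", rank().over(Window.partitionBy($"group", $"userId").orderBy($"score".asc)))
  .where($"rank" <= 1)
```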
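The join-partition-rank pipeline above can be mirrored on plain Scala collections to make each step explicit. This is a sketch, not Spark code: `NearestPositionSketch`, `User`, `Position`, and `nearest` are hypothetical names. Like `rank()`, it keeps every row tied for rank 1.

```scala
object NearestPositionSketch {
  // Hypothetical row types mirroring the columns of leftDf and rightDf.
  final case class User(userId: Int, group: String, x: Double, y: Double)
  final case class Position(group: String, position: String, x: Double, y: Double)

  // Same formula as computeDistanceFunc: Euclidean distance on (x, y).
  def distance(u: User, p: Position): Double =
    math.sqrt(math.pow(u.x - p.x, 2.0) + math.pow(u.y - p.y, 2.0))

  /** Joins on `group`, partitions by (group, userId), and keeps every row whose
    * rank under ascending distance is 1 (rank() gives tied rows the same rank). */
  def nearest(users: Seq[User], positions: Seq[Position]): Seq[(User, Position, Double)] = {
    val joined = for {
      u <- users
      p <- positions if p.group == u.group // the equi-join condition
    } yield (u, p, distance(u, p))

    joined
      .groupBy { case (u, _, _) => (u.group, u.userId) } // Window.partitionBy($"group", $"userId")
      .values
      .flatMap { rows =>
        val best = rows.map(_._3).reduceLeft((a, b) => math.min(a, b))
        rows.filter(_._3 == best) // where rank <= 1
      }
      .toSeq
      .sortBy { case (u, _, _) => u.userId } // deterministic output order
  }
}
```

The window version has to materialize every joined candidate before ranking. That is the cost `top_k_join` avoids.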
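You can use `top_k_join` instead, as follows:

```scala
scala> :paste
import org.apache.spark.sql.hive.HivemallOps._

// k = -1 reverses the order, so each user is joined with the single
// position that has the smallest distance (the nearest one)
val resultDf = leftDf.top_k_join(
  lit(-1),                               // k
  rightDf,                               // right-hand DataFrame
  leftDf("group") === rightDf("group"),  // join condition
  computeDistanceFunc.as("score")        // score column
)
```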
The result is as follows:
```scala
scala> resultDf.show
+----+-------------------+------+-----+---+---+-----+--------+---+---+
|rank|              score|userId|group|  x|  y|group|position|  x|  y|
+----+-------------------+------+-----+---+---+-----+--------+---+---+
|   1|0.09999999999999998|     4|    c|0.2|0.2|    c|    pos9|0.3|0.8|
|   1|0.10000000000000003|     1|    b|0.3|0.3|    b|    pos5|0.4|0.2|
|   1|0.30000000000000004|     6|    b|0.8|0.8|    b|    pos4|0.5|0.7|
|   1|                0.2|     2|    a|0.5|0.4|    a|    pos3|0.3|0.2|
|   1|                0.1|     3|    a|0.1|0.8|    a|    pos1|0.0|0.1|
|   1|                0.1|     5|    a|0.1|0.4|    a|    pos1|0.0|0.1|
+----+-------------------+------+-----+---+---+-----+--------+---+---+
```
`top_k_join` is also useful for Spark `Vector` users. To keep, for each left record, the right record with the smallest squared distance between their vectors, you can use `top_k_join` as follows:
```scala
scala> import org.apache.spark.ml.linalg._
scala> import org.apache.spark.sql.functions._
scala> import org.apache.spark.sql.hive.HivemallOps._

scala> :paste
val leftDf = Seq(
  (1, "a", Vectors.dense(Array(1.0, 0.5, 0.6, 0.2))),
  (2, "b", Vectors.dense(Array(0.2, 0.3, 0.4, 0.1))),
  (3, "a", Vectors.dense(Array(0.8, 0.4, 0.2, 0.6))),
  (4, "a", Vectors.dense(Array(0.2, 0.7, 0.4, 0.8))),
  (5, "c", Vectors.dense(Array(0.4, 0.5, 0.6, 0.2))),
  (6, "c", Vectors.dense(Array(0.3, 0.9, 1.0, 0.1)))
).toDF("userId", "group", "vector")

scala> leftDf.show
+------+-----+-----------------+
|userId|group|           vector|
+------+-----+-----------------+
|     1|    a|[1.0,0.5,0.6,0.2]|
|     2|    b|[0.2,0.3,0.4,0.1]|
|     3|    a|[0.8,0.4,0.2,0.6]|
|     4|    a|[0.2,0.7,0.4,0.8]|
|     5|    c|[0.4,0.5,0.6,0.2]|
|     6|    c|[0.3,0.9,1.0,0.1]|
+------+-----+-----------------+

scala> :paste
val rightDf = Seq(
  ("a", "pos-1", Vectors.dense(Array(0.3, 0.4, 0.3, 0.5))),
  ("a", "pos-2", Vectors.dense(Array(0.9, 0.2, 0.8, 0.3))),
  ("a", "pos-3", Vectors.dense(Array(1.0, 0.0, 0.3, 0.1))),
  ("a", "pos-4", Vectors.dense(Array(0.1, 0.8, 0.5, 0.7))),
  ("b", "pos-5", Vectors.dense(Array(0.3, 0.3, 0.3, 0.8))),
  ("b", "pos-6", Vectors.dense(Array(0.0, 0.7, 0.5, 0.6))),
  ("b", "pos-7", Vectors.dense(Array(0.1, 0.8, 0.4, 0.5))),
  ("c", "pos-8", Vectors.dense(Array(0.8, 0.3, 0.2, 0.1))),
  ("c", "pos-9", Vectors.dense(Array(0.7, 0.5, 0.8, 0.3)))
).toDF("group", "position", "vector")

scala> rightDf.show
+-----+--------+-----------------+
|group|position|           vector|
+-----+--------+-----------------+
|    a|   pos-1|[0.3,0.4,0.3,0.5]|
|    a|   pos-2|[0.9,0.2,0.8,0.3]|
|    a|   pos-3|[1.0,0.0,0.3,0.1]|
|    a|   pos-4|[0.1,0.8,0.5,0.7]|
|    b|   pos-5|[0.3,0.3,0.3,0.8]|
|    b|   pos-6|[0.0,0.7,0.5,0.6]|
|    b|   pos-7|[0.1,0.8,0.4,0.5]|
|    c|   pos-8|[0.8,0.3,0.2,0.1]|
|    c|   pos-9|[0.7,0.5,0.8,0.3]|
+-----+--------+-----------------+

scala> :paste
// UDF computing the squared Euclidean distance between two ml Vectors
val sqDistFunc = udf { (v1: Vector, v2: Vector) => Vectors.sqdist(v1, v2) }

// k = -1 keeps, for each left record, the right record with the smallest distance
val resultDf = leftDf.top_k_join(
  lit(-1),                                                    // k
  rightDf,                                                    // right-hand DataFrame
  leftDf("group") === rightDf("group"),                       // join condition
  sqDistFunc(leftDf("vector"), rightDf("vector")).as("score") // score column
)

scala> resultDf.show
+----+-------------------+------+-----+-----------------+-----+--------+-----------------+
|rank|              score|userId|group|           vector|group|position|           vector|
+----+-------------------+------+-----+-----------------+-----+--------+-----------------+
|   1|0.13999999999999996|     5|    c|[0.4,0.5,0.6,0.2]|    c|   pos-9|[0.7,0.5,0.8,0.3]|
|   1|0.39999999999999997|     6|    c|[0.3,0.9,1.0,0.1]|    c|   pos-9|[0.7,0.5,0.8,0.3]|
|   1|0.42000000000000004|     2|    b|[0.2,0.3,0.4,0.1]|    b|   pos-7|[0.1,0.8,0.4,0.5]|
|   1|0.15000000000000002|     1|    a|[1.0,0.5,0.6,0.2]|    a|   pos-2|[0.9,0.2,0.8,0.3]|
|   1|               0.27|     3|    a|[0.8,0.4,0.2,0.6]|    a|   pos-1|[0.3,0.4,0.3,0.5]|
|   1|0.04000000000000003|     4|    a|[0.2,0.7,0.4,0.8]|    a|   pos-4|[0.1,0.8,0.5,0.7]|
+----+-------------------+------+-----+-----------------+-----+--------+-----------------+
```
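To cross-check the vector example, the same nearest-by-squared-distance selection can be written over plain arrays. `VectorNearestSketch`, `sqdist`, and `nearest` here are hypothetical local helpers, not Spark's `Vectors.sqdist`. The sketch assumes that a left record with no same-group match produces no output row.

```scala
object VectorNearestSketch {
  // Squared Euclidean distance: sum over i of (a(i) - b(i))^2.
  def sqdist(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, "vectors must have the same size")
    var sum = 0.0
    var i = 0
    while (i < a.length) {
      val d = a(i) - b(i)
      sum += d * d
      i += 1
    }
    sum
  }

  /** For each (userId, group, vector), picks the position in the same group whose
    * vector has the smallest squared distance (the k = -1 case). */
  def nearest(
      users: Seq[(Int, String, Array[Double])],
      positions: Seq[(String, String, Array[Double])]): Seq[(Int, String, Double)] =
    users.flatMap { case (userId, group, v) =>
      val candidates = positions.collect {
        case (g, pos, w) if g == group => (pos, sqdist(v, w))
      }
      // Assumption: no same-group candidate means no output row for this user.
      if (candidates.isEmpty) None
      else {
        val (pos, score) = candidates.reduceLeft((a, b) => if (b._2 < a._2) b else a)
        Some((userId, pos, score))
      }
    }
}
```

Feeding it the `leftDf` and `rightDf` rows above should select the same positions as `resultDf.show`.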