blob: 527e6365ff627dcc281fa07cc899f7dcbadfe7b5 [file] [log] [blame]
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* See the License for the specific language governing permissions and
* limitations under the License.
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.utils._
import{DenseVector, Vector => FlinkVector}
import{FitOperation, PredictDataSetOperation, Predictor}
import org.apache.flink.util.Collector
import scala.collection.immutable.Vector
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
/** Implements a `k`-nearest neighbor join.
* Calculates the `k`-nearest neighbor points in the training set for each point in the test set.
* @example
* {{{
* val trainingDS: DataSet[Vector] = ...
* val testingDS: DataSet[Vector] = ...
* val knn = KNN()
* .setK(10)
* .setBlocks(5)
* .setDistanceMetric(EuclideanDistanceMetric())
* knn.fit(trainingDS)
* val predictionDS: DataSet[(Vector, Array[Vector])] = knn.predict(testingDS)
* }}}
* =Parameters=
* - [[KNN.K]]
* Sets the K which is the number of selected points as neighbors. (Default value: '''5''')
* - [[KNN.DistanceMetric]]
* Sets the distance metric we use to calculate the distance between two points. If no metric is
* specified, then [[EuclideanDistanceMetric]] is used.
* (Default value: '''EuclideanDistanceMetric()''')
* - [[KNN.Blocks]]
* Sets the number of blocks into which the input data will be split. This number should be set
* at least to the degree of parallelism. If no value is specified, then the parallelism of the
* input [[DataSet]] is used as the number of blocks. (Default value: '''None''')
* - [[KNN.UseQuadTree]]
* A boolean variable that determines whether or not to use a quadtree to partition the training set
* to potentially simplify the KNN search. If no value is specified, the code will
* automatically decide whether or not to use a quadtree. Use of a quadtree scales well
* with the number of training and testing points, though poorly with the dimension.
* (Default value: '''None''')
* - [[KNN.SizeHint]]
* Specifies whether the training set or the test set is small, so that the cross
* product operation needed for the KNN search can be optimized. Use
* `CrossHint.FIRST_IS_SMALL` if the training set is small and
* `CrossHint.SECOND_IS_SMALL` if the test set is small.
* (Default value: '''None''')
class KNN extends Predictor[KNN] {
// Brings the parameter keys declared in the companion object (K, DistanceMetric, Blocks,
// UseQuadTree, SizeHint) into scope for the setters.
import KNN._
// Blocked training data. It stays None until the fit operation assigns it; the predict
// operation throws a RuntimeException while it is still None.
var trainingSet: Option[DataSet[Block[FlinkVector]]] = None
/** Sets K
* @param k the number of selected points as neighbors
def setK(k: Int): KNN = {
// Fail fast on a non-positive neighbor count (require throws IllegalArgumentException).
require(k > 0, "K must be positive.")
// Record k under the K parameter key.
parameters.add(K, k)
/** Sets the distance metric
* @param metric the distance metric to calculate distance between two points
def setDistanceMetric(metric: DistanceMetric): KNN = {
// Record the metric under the DistanceMetric parameter key; no validation happens here.
// NOTE(review): setUseQuadTree validates the metric only at the time it is called, so a metric
// set afterwards is not re-checked by this method - confirm whether that ordering matters.
parameters.add(DistanceMetric, metric)
/** Sets the number of data blocks/partitions
* @param n the number of data blocks
def setBlocks(n: Int): KNN = {
// Fail fast on a non-positive block count (require throws IllegalArgumentException).
require(n > 0, "Number of blocks must be positive.")
// Record n under the Blocks parameter key; fit and predict fall back to the input's
// parallelism when this key is unset.
parameters.add(Blocks, n)
/** Sets the Boolean variable that decides whether to use the QuadTree or not */
def setUseQuadTree(useQuadTree: Boolean): KNN = {
// Only an explicit opt-in is validated, against the distance metric configured at call time.
if (useQuadTree) {
// NOTE(review): this require(...) ends on a dangling `||`; its remaining operand(s) and the
// failure message are not visible here - confirm against the full source.
require(parameters(DistanceMetric).isInstanceOf[SquaredEuclideanDistanceMetric] ||
// Record the flag under the UseQuadTree parameter key.
parameters.add(UseQuadTree, useQuadTree)
/** Parameter a user can specify if either the training set or the test set is small
* @param sizeHint cross hint tells the system which sizes to expect from the data sets
def setSizeHint(sizeHint: CrossHint): KNN = {
// Record the hint under the SizeHint parameter key; the predict operation reads it when it
// builds the cross of training blocks and test blocks.
parameters.add(SizeHint, sizeHint)
object KNN {
// Parameter key for the number of neighbors; defaults to 5.
case object K extends Parameter[Int] {
val defaultValue: Option[Int] = Some(5)
// Parameter key for the distance metric; defaults to EuclideanDistanceMetric().
case object DistanceMetric extends Parameter[DistanceMetric] {
val defaultValue: Option[DistanceMetric] = Some(EuclideanDistanceMetric())
// Parameter key for the number of blocks; no default, so fit and predict fall back to the
// input's parallelism.
case object Blocks extends Parameter[Int] {
val defaultValue: Option[Int] = None
// Parameter key for forcing or forbidding the quadtree; no default, so predict decides per
// block pair from a size check and a metric check.
case object UseQuadTree extends Parameter[Boolean] {
val defaultValue: Option[Boolean] = None
// Parameter key for the cross size hint; no default, so predict uses a plain cross.
case object SizeHint extends Parameter[CrossHint] {
val defaultValue: Option[CrossHint] = None
// Factory so callers can write KNN() instead of new KNN().
def apply(): KNN = {
new KNN()
/** [[FitOperation]] which trains a KNN based on the given training data set.
* @tparam T Subtype of [[FlinkVector]]
implicit def fitKNN[T <: FlinkVector : TypeInformation] = new FitOperation[KNN, T] {
override def fit(
instance: KNN,
fitParameters: ParameterMap,
input: DataSet[T]
): Unit = {
// Combine the instance's parameters with the ones passed to this call.
// NOTE(review): presumably the call-site values win on conflict - confirm `++` semantics.
val resultParameters = instance.parameters ++ fitParameters
require(resultParameters.get(K).isDefined, "K is needed for calculation")
// Block count falls back to the input's parallelism when Blocks is unset.
val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
val partitioner = FlinkMLTools.ModuloKeyPartitioner
// Widen the element type to FlinkVector so the blocked data matches the declared type of
// instance.trainingSet.
val inputAsVector = input.asInstanceOf[DataSet[FlinkVector]]
// Fitting only splits the training data into blocks and stores it on the instance; all
// distance computation happens in the predict operation.
instance.trainingSet = Some(FlinkMLTools.block(inputAsVector, blocks, Some(partitioner)))
/** [[PredictDataSetOperation]] which calculates k-nearest neighbors of the given testing data
* set.
* @tparam T Subtype of [[FlinkVector]]
* @return The given testing data set with k-nearest neighbors
implicit def predictValues[T <: FlinkVector : ClassTag : TypeInformation] = {
new PredictDataSetOperation[KNN, T, (FlinkVector, Array[FlinkVector])] {
override def predictDataSet(
instance: KNN,
predictParameters: ParameterMap,
input: DataSet[T]
): DataSet[(FlinkVector, Array[FlinkVector])] = {
val resultParameters = instance.parameters ++ predictParameters
// Prediction needs the blocked training set that the fit operation stores on the instance.
instance.trainingSet match {
case Some(trainingSet) =>
// K and DistanceMetric declare Some(...) defaults in the companion object.
// NOTE(review): the `.get` calls assume the ParameterMap applies those defaults - confirm.
val k = resultParameters.get(K).get
val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
val metric = resultParameters.get(DistanceMetric).get
val partitioner = FlinkMLTools.ModuloKeyPartitioner
// attach unique id for each data
val inputWithId: DataSet[(Long, T)] = input.zipWithUniqueId
// split data into multiple blocks
val inputSplit = FlinkMLTools.block(inputWithId, blocks, Some(partitioner))
val sizeHint = resultParameters.get(SizeHint)
// Pair every training block with every test block, honoring the optional size hint.
// NOTE(review): the right-hand sides of the two hint cases are not visible here - confirm
// against the full source.
val crossTuned = sizeHint match {
case Some(hint) if hint == CrossHint.FIRST_IS_SMALL =>
case Some(hint) if hint == CrossHint.SECOND_IS_SMALL =>
case _ => trainingSet.cross(inputSplit)
// join input and training set
// Emits tuples of (training vector, input vector, input id, distance) for each block pair.
val crossed = crossTuned.mapPartition {
(iter, out: Collector[(FlinkVector, FlinkVector, Long, Double)]) => {
for ((training, testing) <- iter) {
// use a quadtree if (4 ^ dim) * Ntest * log(Ntrain)
// < Ntest * Ntrain, and distance is Euclidean
// The test is evaluated in log space with Ntest cancelled:
// dim * log(4) + log(log(Ntrain)) < log(Ntrain), Ntrain being this block's size.
// NOTE(review): `training.values.head` assumes the training block is non-empty - confirm.
val checkSize = math.log(4.0) * training.values.head.size +
math.log(math.log(training.values.length)) < math.log(training.values.length)
val checkMetric = metric match {
case _: EuclideanDistanceMetric => true
case _: SquaredEuclideanDistanceMetric => true
case _ => false
// An explicitly set UseQuadTree value bypasses both the size check and the metric check.
val useQuadTree = resultParameters.get(UseQuadTree)
.getOrElse(checkSize && checkMetric)
if (useQuadTree) {
knnQueryWithQuadTree(training.values, testing.values, k, metric, out)
} else {
knnQueryBasic(training.values, testing.values, k, metric, out)
// group by input vector id and pick k nearest neighbor for each group
// Tuple field 2 is the input id and field 3 the distance, so each group arrives sorted by
// ascending distance.
val result = crossed.groupBy(2).sortGroup(3, Order.ASCENDING).reduceGroup {
(iter, out: Collector[(FlinkVector, Array[FlinkVector])]) => {
if (iter.hasNext) {
// NOTE(review): the right-hand side of `head` is not visible here; presumably it takes the
// first element of `iter` - confirm.
val head =
// head._2 is the input (test) vector, head._1 its closest training vector.
val key = head._2
val neighbors: ArrayBuffer[FlinkVector] = ArrayBuffer(head._1)
for ((vector, _, _, _) <- iter.take(k - 1)) {
// we already took a first element
neighbors += vector
out.collect(key, neighbors.toArray)
// NOTE(review): the two literals below are concatenated without a separating space, so the
// message reads "...trained.Call first fit..." at runtime.
case None => throw new RuntimeException("The KNN model has not been trained." +
"Call first fit before calling the predict operation.")
// Quadtree-assisted kNN query for one (training block, test block) pair; results are written to
// `out` as (training vector, input vector, input id, distance).
private def knnQueryWithQuadTree[T <: FlinkVector](
training: Vector[T],
testing: Vector[(Long, T)],
k: Int,
metric: DistanceMetric,
out: Collector[(FlinkVector, FlinkVector, Long, Double)]
): Unit = {
// find a bounding box
// MinArr and MaxArr both hold the dimension indices 0 until dim (identical contents).
// NOTE(review): `training.head` assumes a non-empty training block - confirm.
val MinArr = Array.tabulate(training.head.size)(x => x)
val MaxArr = Array.tabulate(training.head.size)(x => x)
// NOTE(review): the next six right-hand sides are truncated in this view. What remains suggests
// per-dimension minima and maxima over the training and test vectors, padded by 0.01 and then
// combined with math.min / math.max - confirm against the full source.
val minVecTrain = => => x(i)).min - 0.01)
val minVecTest = => => x._2(i)).min - 0.01)
val maxVecTrain = => => x(i)).max + 0.01)
val maxVecTest = => => x._2(i)).max + 0.01)
val MinVec = DenseVector( => math.min(minVecTrain(i), minVecTest(i))))
val MaxVec = DenseVector( => math.max(maxVecTrain(i), maxVecTest(i))))
// default value of max elements/box is set to max(20,k)
val maxPerBox = math.max(k, 20)
val trainingQuadTree = new QuadTree(MinVec, MaxVec, metric, maxPerBox)
// NOTE(review): the Ordering argument of this PriorityQueue is not visible here - confirm that
// it keeps the largest distance at the head.
val queue = mutable.PriorityQueue[(FlinkVector, FlinkVector, Long, Double)]()(
// NOTE(review): loop body not visible here; presumably each training vector is inserted into
// trainingQuadTree - confirm.
for (v <- training) {
for ((id, vector) <- testing) {
// Find siblings' objects and do local kNN there
val siblingObjects = trainingQuadTree.searchNeighborsSiblingQueue(vector)
// do KNN query on siblingObjects and get max distance of kNN then rad is good choice
// for a neighborhood to do a refined local kNN search
// NOTE(review): the start of the next expression is truncated; the visible part computes
// distances to `vector`, sorts them ascending and keeps the k smallest.
val knnSiblings = => metric.distance(vector, v)).sortWith(_ < _).take(k)
// Largest of the (at most k) smallest sibling distances.
// NOTE(review): `.last` throws if knnSiblings is empty - confirm siblingObjects cannot be empty.
val rad = knnSiblings.last
val trainingFiltered = trainingQuadTree.searchNeighbors(vector, rad)
for (b <- trainingFiltered) {
// (training vector, input vector, input key, distance)
queue.enqueue((b, vector, id, metric.distance(b, vector)))
// NOTE(review): body not visible here; presumably one element is evicted so that at most k
// candidates are retained - confirm.
if (queue.size > k) {
// NOTE(review): loop body not visible here; presumably the retained candidates are written to
// `out` - confirm.
for (v <- queue) {
// Brute-force kNN query for one (training block, test block) pair: every test point is compared
// against every training point of the block.
private def knnQueryBasic[T <: FlinkVector](
training: Vector[T],
testing: Vector[(Long, T)],
k: Int,
metric: DistanceMetric,
out: Collector[(FlinkVector, FlinkVector, Long, Double)]
): Unit = {
// NOTE(review): the Ordering argument of this PriorityQueue is not visible here - confirm that
// it keeps the largest distance at the head.
val queue = mutable.PriorityQueue[(FlinkVector, FlinkVector, Long, Double)]()(
for ((id, vector) <- testing) {
for (b <- training) {
// (training vector, input vector, input key, distance)
queue.enqueue((b, vector, id, metric.distance(b, vector)))
// NOTE(review): body not visible here; presumably one element is evicted so that at most k
// candidates are retained - confirm.
if (queue.size > k) {
// NOTE(review): loop body not visible here; presumably the retained candidates are written to
// `out` - confirm.
for (v <- queue) {