# blob: 11466801342acf550e1c0fb4debef0c1c1d32e31
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Brute-force k-nearest-neighbour search.
#
# For every query row of T, computes the squared Euclidean distance to every
# row of X and returns the row indices of the k_value closest rows of X.
#
# INPUT:
#   X        records to search, one record per row
#   T        query points, one query per row (same number of columns as X)
#   k_value  number of neighbours returned per query (default 5)
#
# OUTPUT:
#   NNR      nrow(T) x k_value matrix; row i holds the row indices into X of
#            the neighbours of T[i, ], ordered by ascending distance
#
# Stops with an error (raised in sortAndGetK) when X has fewer than k_value rows.
m_knnbf = function(
    Matrix[Double] X,
    Matrix[Double] T,
    Integer k_value = 5
  ) return(
    Matrix[Double] NNR
  )
{
  num_queries = nrow(T);
  NNR = matrix(0, rows = num_queries, cols = k_value);

  parfor(i in 1 : num_queries) {
    # Each iteration only needs its own distance column, so it is kept in a
    # loop-local variable instead of being written into a shared
    # nrow(X) x nrow(T) matrix that is never read after the loop.
    dists_i = calculateDistance(X, T[i, ]);
    NNR[i, ] = sortAndGetK(dists_i, k_value);
  }
}
# Squared Euclidean distance between every row of R and every row of Q.
#
# Uses the expansion |r - q|^2 = |r|^2 + |q|^2 - 2 * (r . q), so the whole
# result is obtained from row sums and a single matrix product.
#
# R: one point per row
# Q: one point per row (same number of columns as R)
# distances: nrow(R) x nrow(Q) matrix; entry (a, b) is the squared distance
#            between R[a, ] and Q[b, ] (no square root is taken)
calculateDistance = function(Matrix[Double] R, Matrix[Double] Q)
  return(Matrix[Double] distances)
{
  num_r = nrow(R);
  num_q = nrow(Q);

  # squared norms of the rows of R, replicated across one column per row of Q
  sq_norm_r = rowSums(R ^ 2);
  norm_r_grid = sq_norm_r %*% matrix(1, rows = 1, cols = num_q);

  # squared norms of the rows of Q, replicated down one row per row of R
  sq_norm_q = rowSums(Q ^ 2);
  norm_q_grid = matrix(1, rows = num_r, cols = 1) %*% t(sq_norm_q);

  # pairwise dot products between rows of R and rows of Q
  cross_term = R %*% t(Q);

  distances = norm_r_grid + norm_q_grid - 2.0 * cross_term;
}
# Returns the row positions of the k smallest entries of D.
#
# D: distances, one candidate per row; rows are ranked by the first column
# k: number of positions to return
# knn_: row vector holding the first k row positions of D in ascending order
#       of distance
#
# Stops with an error when D has fewer than k rows.
sortAndGetK = function(Matrix[Double] D, Integer k)
  return (Matrix[Double] knn_)
{
  num_candidates = nrow(D);
  if( k > num_candidates ) {
    stop("can not pick "+k+" nearest neighbours from "+num_candidates+" total instances");
  }

  # index.return = TRUE yields the original row positions in sorted order
  # rather than the sorted values themselves
  ranked_idx = order(target = D, by = 1, decreasing = FALSE, index.return = TRUE);
  nearest = ranked_idx[1:k, ];
  knn_ = t(nearest);
}