blob: 0ad261816373f917a59ebf3a34fba4d0cf7421c0 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
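# Builtin for k nearest neighbor graph construction
#
# INPUT:
# --------------------------
# X      Input matrix; pairwise distances are computed via dist(X)
# k      Number of nearest neighbors to connect each point to
# --------------------------
#
# OUTPUT:
# ------------------------
# graph  0/1 matrix with the dimensions of dist(X): cell (i,j) is 1 if the
#        distance between i and j is at most the (k+1)-th smallest distance
#        in row i (which includes the self-distance); the diagonal is zero
# ------------------------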
# Builds a k-nearest-neighbor graph from the pairwise distances of X.
#
# X      input matrix handed to dist() to obtain the pairwise distance matrix
# k      number of nearest neighbors per point
# graph  0/1 matrix of the same dimensions as dist(X); cell (i,j) is 1 when
#        distance(i,j) is within the k-th neighbor distance of row i.
#        Ties at the threshold are all kept, and the diagonal is forced to 0.
m_knnGraph = function(Matrix[double] X, integer k) return (Matrix[double] graph) {
  distances = dist(X);
  n = nrow(distances);

  # Select rank k + 1 because every row of the distance matrix also holds the
  # point's distance to itself. Clamp to the row length so that a k larger
  # than the number of other points links all of them instead of hitting the
  # helper's -1 "rank out of range" sentinel (which would empty the graph).
  kth = as.integer(min(k + 1, ncol(distances)));

  # Per-row distance threshold: the kth-smallest entry of each row.
  ksmall = matrix(0, rows=n, cols=1);
  for (row in 1:n) {
    ksmall[row, 1] = kthSmallest(distances[row, ], kth);
  }

  # Column vector ksmall is compared against every column of distances.
  graph = distances <= ksmall;

  # Mask that is 0 on the diagonal and 1 elsewhere; removes self-edges.
  diagonal = diag(matrix(1, rows=n, cols=1)) == 0;
  graph = graph * diagonal;
}
# # # TODO vectorize the below function
# Returns the k-th smallest value (1-based rank) of the first row of `array`,
# found with an iterative quickselect that partitions around the right-most
# element of the current search window.
#
# array  row vector of values; only row 1 is read, and the function reorders
#        its `array` argument while partitioning
# k      1-based rank of the value to select
# res    the selected value, or -1 when k lies outside 1..ncol(array).
#        Declared as double so that fractional values taken from the
#        Matrix[double] input are returned without being cast to an integer.
kthSmallest = function(Matrix[double] array, integer k)
  return (double res) {
  # Sentinel for "rank not found"; overwritten once the rank is located.
  res = -1.0;
  left = 1;
  right = ncol(array);
  found = FALSE;
  while ((left <= right) & !found) {
    # Partition array[1, left:right] around the value at position `right`.
    pivotValue = as.scalar(array[1, right]);
    i = left - 1;
    j = left;
    while (j < right) {
      if (as.scalar(array[1, j]) <= pivotValue) {
        # Grow the "<= pivot" prefix by swapping element j into it.
        i = i + 1;
        temp = as.scalar(array[1, i]);
        array[1, i] = array[1, j];
        array[1, j] = temp;
      }
      j = j + 1;
    }
    # Move the pivot value to its final sorted position, right after the prefix.
    temp = as.scalar(array[1, i + 1]);
    array[1, i + 1] = array[1, right];
    array[1, right] = temp;
    pivotPos = i + 1;

    if (pivotPos == k) {
      res = as.scalar(array[1, pivotPos]);
      found = TRUE;
    }
    else if (pivotPos > k) {
      # Requested rank lies in the left part.
      right = pivotPos - 1;
    }
    else {
      # Requested rank lies in the right part.
      left = pivotPos + 1;
    }
  }
}