blob: 1479ed3b3469c0b676ff06624469839c55700b78 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator that finds the {@code k} nearest neighbors of a vector among the rows of a matrix.
 *
 * <p>Parameters, in order: a {@link Matrix} of observations, a numeric array holding the query
 * vector, the number of neighbors {@code k}, and optionally a {@link DistanceMeasure}. When the
 * call does not have exactly four parameters, {@link EuclideanDistance} is used.
 */
public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
  protected static final long serialVersionUID = 1L;

  public KnnEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
    super(expression, factory);
  }

  /**
   * Validates the evaluated parameters and delegates to
   * {@link #search(Matrix, double[], int, DistanceMeasure)}.
   *
   * @param values the evaluated parameters: matrix, numeric array, k and, optionally, a distance
   *     measure
   * @return the matrix produced by {@code search}
   * @throws IOException if fewer than three parameters are supplied or a parameter has the wrong
   *     type
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    if (values.length < 3) {
      throw new IOException("knn expects three parameters a Matrix, numeric array and k");
    }

    if (!(values[0] instanceof Matrix)) {
      throw new IOException("The first parameter for knn should be a matrix.");
    }
    Matrix matrix = (Matrix) values[0];

    if (!(values[1] instanceof List)) {
      throw new IOException("The second parameter for knn should be a numeric array.");
    }
    List<?> nums = (List<?>) values[1];
    double[] vec = new double[nums.size()];
    int pos = 0;
    for (Object num : nums) {
      // Check each element instead of relying on an unchecked cast, so that a null or
      // non-numeric entry is reported as a parameter error rather than a runtime exception.
      if (!(num instanceof Number)) {
        throw new IOException("The second parameter for knn should be a numeric array.");
      }
      vec[pos++] = ((Number) num).doubleValue();
    }

    if (!(values[2] instanceof Number)) {
      throw new IOException("The third parameter for knn should be k.");
    }
    int k = ((Number) values[2]).intValue();

    // Only a call with exactly four parameters supplies its own distance measure;
    // every other arity uses the Euclidean distance.
    DistanceMeasure distanceMeasure;
    if (values.length == 4) {
      if (!(values[3] instanceof DistanceMeasure)) {
        throw new IOException("The fourth parameter for knn should be a distance measure.");
      }
      distanceMeasure = (DistanceMeasure) values[3];
    } else {
      distanceMeasure = new EuclideanDistance();
    }

    return search(matrix, vec, k, distanceMeasure);
  }

  /**
   * Selects the rows of {@code observations} that are closest to {@code vec}.
   *
   * <p>The returned matrix holds at most {@code k} rows ordered by increasing distance, with ties
   * ordered by increasing row index. Its rows are the same array instances returned by
   * {@code observations.getData()}; they are not copied. Row labels are carried over for the
   * selected rows when the source matrix has them, and the column labels are passed through
   * unchanged. The attributes {@code "distances"} and {@code "indexes"} hold, in row order, the
   * distance and the original row index of every selected row.
   *
   * @param observations the matrix whose rows are the candidate neighbors
   * @param vec the query vector
   * @param k the maximum number of neighbors to keep; no rows are kept when it is not positive
   * @param distanceMeasure the measure used to compare {@code vec} with each row
   * @return a matrix containing the nearest rows together with their distances and indexes
   */
  public static Matrix search(Matrix observations,
                              double[] vec,
                              int k,
                              DistanceMeasure distanceMeasure) {
    double[][] data = observations.getData();

    // Sorted set bounded to k entries: whenever it grows past k, the farthest entry is evicted.
    TreeSet<Neighbor> neighbors = new TreeSet<>();
    for (int row = 0; row < data.length; row++) {
      neighbors.add(new Neighbor(row, distanceMeasure.compute(vec, data[row])));
      if (neighbors.size() > k) {
        neighbors.pollLast();
      }
    }

    int found = neighbors.size();
    double[][] out = new double[found][];
    List<String> rowLabels = observations.getRowLabels();
    List<String> newRowLabels = new ArrayList<>(found);
    List<Number> indexes = new ArrayList<>(found);
    List<Number> distances = new ArrayList<>(found);

    // The set iterates in ascending order, so the output rows come out nearest first.
    int next = 0;
    for (Neighbor neighbor : neighbors) {
      int rowIndex = neighbor.getRow();
      if (rowLabels != null) {
        newRowLabels.add(rowLabels.get(rowIndex));
      }
      out[next++] = data[rowIndex];
      distances.add(neighbor.getDistance());
      indexes.add(rowIndex);
    }

    Matrix knn = new Matrix(out);
    if (rowLabels != null) {
      knn.setRowLabels(newRowLabels);
    }
    knn.setColumnLabels(observations.getColumnLabels());
    knn.setAttribute("distances", distances);
    knn.setAttribute("indexes", indexes);
    return knn;
  }

  /**
   * A candidate row paired with its distance from the query vector. Instances are ordered by
   * distance first and by row index second.
   */
  public static class Neighbor implements Comparable<Neighbor> {
    private final double distance;
    private final int row;

    public Neighbor(int row, double distance) {
      this.distance = distance;
      this.row = row;
    }

    /** Returns the index of the row this neighbor was created for. */
    public int getRow() {
      return this.row;
    }

    /** Returns the distance this neighbor was created with. */
    public Double getDistance() {
      return distance;
    }

    /**
     * Orders by ascending distance, then by ascending row index.
     *
     * @param neighbor the neighbor to compare against
     * @return a negative value, zero, or a positive value as this neighbor sorts before, equal
     *     to, or after {@code neighbor}
     */
    @Override
    public int compareTo(Neighbor neighbor) {
      // Double.compare gives the same total order as Double.compareTo, without boxing.
      int byDistance = Double.compare(this.distance, neighbor.distance);
      if (byDistance != 0) {
        return byDistance;
      }
      // Integer.compare avoids the overflow risk of subtracting the two row indexes.
      return Integer.compare(this.row, neighbor.row);
    }

    /** Two neighbors are equal exactly when {@link #compareTo(Neighbor)} returns zero. */
    @Override
    public boolean equals(Object other) {
      if (this == other) {
        return true;
      }
      if (!(other instanceof Neighbor)) {
        return false;
      }
      Neighbor that = (Neighbor) other;
      return this.row == that.row && Double.compare(this.distance, that.distance) == 0;
    }

    @Override
    public int hashCode() {
      return 31 * Integer.hashCode(row) + Double.hashCode(distance);
    }
  }
}