blob: 8c6ed10da039a66cad58dfb85b32a2d591f42c15 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;
/**
 * A k-Nearest Neighbor classifier (see <a
 * href="http://en.wikipedia.org/wiki/K-nearest_neighbors">http://en.wikipedia.org/wiki/K-nearest_neighbors</a>)
 * based on {@link MoreLikeThis}.
 *
 * <p>The text to classify is turned into one MLT query per text field. The {@code k} best matching
 * documents that have a value in the class field are retrieved, and their class values are scored.
 *
 * <p>NOTE(review): every classification call changes the boost settings of the single shared
 * {@link #mlt} instance. Concurrent use of one classifier from several threads is therefore
 * presumably unsafe without external synchronization — confirm before sharing an instance.
 *
 * @lucene.experimental
 */
public class KNearestNeighborClassifier implements Classifier<BytesRef> {

  /** a {@link MoreLikeThis} instance used to perform MLT queries */
  protected final MoreLikeThis mlt;

  /** the name of the fields used as the input text */
  protected final String[] textFieldNames;

  /** the name of the field used as the output text */
  protected final String classFieldName;

  /** an {@link IndexSearcher} used to perform queries */
  protected final IndexSearcher indexSearcher;

  /** the no. of docs to compare in order to find the nearest neighbor to the input text */
  protected final int k;

  /**
   * a {@link Query} used to filter the documents that should be used from this classifier's
   * underlying {@link LeafReader}
   */
  protected final Query query;

  /**
   * Creates a {@link KNearestNeighborClassifier}.
   *
   * @param indexReader the reader on the index to be used for classification
   * @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or
   *     {@code null} (defaults to {@link org.apache.lucene.search.similarities.BM25Similarity})
   * @param analyzer an {@link Analyzer} used to analyze unseen text
   * @param query a {@link Query} to eventually filter the docs used for training the classifier, or
   *     {@code null} if all the indexed docs should be used
   * @param k the no. of docs to select in the MLT results to find the nearest neighbor
   * @param minDocsFreq value passed to {@link MoreLikeThis#setMinDocFreq(int)}; values {@code <= 0}
   *     keep the {@link MoreLikeThis} default
   * @param minTermFreq value passed to {@link MoreLikeThis#setMinTermFreq(int)}; values {@code <=
   *     0} keep the {@link MoreLikeThis} default
   * @param classFieldName the name of the field used as the output for the classifier
   * @param textFieldNames the name of the fields used as the inputs for the classifier, they can
   *     contain boosting indication e.g. title^10
   */
  public KNearestNeighborClassifier(
      IndexReader indexReader,
      Similarity similarity,
      Analyzer analyzer,
      Query query,
      int k,
      int minDocsFreq,
      int minTermFreq,
      String classFieldName,
      String... textFieldNames) {
    this.textFieldNames = textFieldNames;
    this.classFieldName = classFieldName;
    this.mlt = new MoreLikeThis(indexReader);
    this.mlt.setAnalyzer(analyzer);
    // NOTE(review): entries may still carry a "^boost" suffix here. knnSearch passes the stripped
    // field name to like(String, Reader) explicitly; confirm MLT does not rely on these names
    // elsewhere.
    this.mlt.setFieldNames(textFieldNames);
    this.indexSearcher = new IndexSearcher(indexReader);
    if (similarity != null) {
      this.indexSearcher.setSimilarity(similarity);
    } else {
      this.indexSearcher.setSimilarity(new BM25Similarity());
    }
    // only override the MoreLikeThis defaults when an explicit positive value is given
    if (minDocsFreq > 0) {
      mlt.setMinDocFreq(minDocsFreq);
    }
    if (minTermFreq > 0) {
      mlt.setMinTermFreq(minTermFreq);
    }
    this.query = query;
    this.k = k;
  }

  /**
   * Assigns the best scoring class to the given text.
   *
   * @param text the text to classify
   * @return the class with the highest score among the k nearest neighbors, or {@code null} if no
   *     neighbor with a class value was found
   * @throws IOException if searching the index or loading stored fields fails
   */
  @Override
  public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
    return classifyFromTopDocs(knnSearch(text));
  }

  /**
   * Picks the highest scoring class out of the given nearest-neighbor search results.
   *
   * @param knnResults the search results, as produced by the k-NN search
   * @return the {@link ClassificationResult} with the strictly greatest score (the first one
   *     encountered wins ties), or {@code null} when no class could be built from the results
   * @throws IOException if the stored class field values cannot be loaded
   */
  protected ClassificationResult<BytesRef> classifyFromTopDocs(TopDocs knnResults)
      throws IOException {
    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
    ClassificationResult<BytesRef> assignedClass = null;
    double maxscore = -Double.MAX_VALUE;
    for (ClassificationResult<BytesRef> cl : assignedClasses) {
      if (cl.getScore() > maxscore) {
        assignedClass = cl;
        maxscore = cl.getScore();
      }
    }
    return assignedClass;
  }

  /**
   * Returns all classes found among the k nearest neighbors of the given text.
   *
   * @param text the text to classify
   * @return a mutable list of class results, sorted with {@link Collections#sort(List)}
   * @throws IOException if searching the index or loading stored fields fails
   */
  @Override
  public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
    TopDocs knnResults = knnSearch(text);
    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
    Collections.sort(assignedClasses);
    return assignedClasses;
  }

  /**
   * Returns at most {@code max} classes found among the k nearest neighbors of the given text.
   *
   * @param text the text to classify
   * @param max the maximum number of results to return; when fewer classes were found, all of them
   *     are returned
   * @return a view over the first {@code min(max, found)} sorted class results
   * @throws IOException if searching the index or loading stored fields fails
   */
  @Override
  public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
    TopDocs knnResults = knnSearch(text);
    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
    Collections.sort(assignedClasses);
    // subList throws if the upper bound exceeds the size, so never ask for more than was found
    return assignedClasses.subList(0, Math.min(max, assignedClasses.size()));
  }

  /**
   * Runs the MLT-based nearest neighbor search for the given text.
   *
   * <p>One SHOULD clause is built per text field, honoring an optional {@code field^boost} suffix.
   * A MUST clause requires the class field to be present. The optional filter {@link #query} is
   * added as a MUST clause as well.
   */
  private TopDocs knnSearch(String text) throws IOException {
    BooleanQuery.Builder mltQuery = new BooleanQuery.Builder();
    mlt.setBoost(true); // terms boost actually helps in MLT queries
    for (String textFieldName : textFieldNames) {
      String fieldName = textFieldName;
      if (fieldName.contains("^")) {
        String[] field2boost = fieldName.split("\\^");
        fieldName = field2boost[0];
        mlt.setBoostFactor(Float.parseFloat(field2boost[1])); // if we have a field boost, we add it
      }
      try {
        mltQuery.add(
            new BooleanClause(
                mlt.like(fieldName, new StringReader(text)), BooleanClause.Occur.SHOULD));
      } finally {
        // restore neutral boost for the next field (and the next call), even if like() failed
        mlt.setBoostFactor(1);
      }
    }
    Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
    mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
    if (query != null) {
      mltQuery.add(query, BooleanClause.Occur.MUST);
    }
    return indexSearcher.search(mltQuery.build(), k);
  }

  // ranking of classes must be taken in consideration
  /**
   * build a list of classification results from search results
   *
   * <p>Each class value scores the sum of its hits' scores relative to the best hit, divided by
   * {@code k}. When fewer than {@code k} class values were collected, the scores are rescaled by
   * {@code k / collected}.
   *
   * @param topDocs the search results as a {@link TopDocs} object; the first hit is taken as the
   *     best score (assumes descending score order, as {@link IndexSearcher#search(Query, int)}
   *     returns)
   * @return a mutable {@link List} of {@link ClassificationResult}, one for each existing class;
   *     empty when there are no hits
   * @throws IOException if it's not possible to get the stored value of class field
   */
  protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs)
      throws IOException {
    ScoreDoc[] scoreDocs = topDocs.scoreDocs;
    if (scoreDocs.length == 0) {
      return new ArrayList<>();
    }
    float maxScore = scoreDocs[0].score;

    Map<BytesRef, Integer> classCounts = new HashMap<>();
    // this is a boost based on class ranking positions in topDocs
    Map<BytesRef, Double> classBoosts = new HashMap<>();
    for (ScoreDoc scoreDoc : scoreDocs) {
      IndexableField[] storableFields = indexSearcher.doc(scoreDoc.doc).getFields(classFieldName);
      // the boost is based on the best score
      double singleBoost = scoreDoc.score / maxScore;
      for (IndexableField singleStorableField : storableFields) {
        if (singleStorableField != null) {
          BytesRef cl = new BytesRef(singleStorableField.stringValue());
          classCounts.merge(cl, 1, Integer::sum);
          classBoosts.merge(cl, singleBoost, Double::sum);
        }
      }
    }

    List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>(classCounts.size());
    int sumdoc = 0;
    for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
      int count = entry.getValue();
      // average relative boost of this class's hits
      double normBoost = classBoosts.get(entry.getKey()) / count;
      temporaryList.add(
          new ClassificationResult<>(entry.getKey().clone(), (count * normBoost) / (double) k));
      sumdoc += count;
    }

    if (sumdoc >= k) {
      return temporaryList;
    }
    // correction: fewer than k class values were collected, rescale relative to what was found
    List<ClassificationResult<BytesRef>> returnList = new ArrayList<>(temporaryList.size());
    for (ClassificationResult<BytesRef> cr : temporaryList) {
      returnList.add(
          new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc));
    }
    return returnList;
  }

  @Override
  public String toString() {
    return "KNearestNeighborClassifier{"
        + "textFieldNames="
        + Arrays.toString(textFieldNames)
        + ", classFieldName='"
        + classFieldName
        + '\''
        + ", k="
        + k
        + ", query="
        + query
        + ", similarity="
        + indexSearcher.getSimilarity()
        + '}';
  }
}