// blob: 39684ee25e740214abfe5267c17d49350ab15535 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification.document;
import java.io.IOException;
import java.io.StringReader;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.KNearestNeighborClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;
/**
 * A k-Nearest Neighbor Document classifier (see <code>http://en.wikipedia.org/wiki/K-nearest_neighbors</code>) based
 * on {@link org.apache.lucene.queries.mlt.MoreLikeThis}.
 *
 * <p>The input to classify is a {@link Document}: for each configured text field, every value returned by
 * {@link Document#getValues(String)} is turned into a More Like This clause, using the {@link Analyzer} registered
 * for that field in {@link #field2analyzer}.
 *
 * @lucene.experimental
 */
public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifier implements DocumentClassifier<BytesRef> {

  /**
   * map of per field analyzers (key: field name without any boost suffix, value: the analyzer used for that field)
   */
  protected Map<String, Analyzer> field2analyzer;

  /**
   * Creates a {@link KNearestNeighborDocumentClassifier}.
   *
   * @param indexReader    the reader on the index to be used for classification
   * @param similarity     the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
   *                       (defaults to {@link org.apache.lucene.search.similarities.BM25Similarity})
   * @param query          a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
   *                       if all the indexed docs should be used
   * @param k              the no. of docs to select in the MLT results to find the nearest neighbor
   * @param minDocsFreq    {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq} parameter
   * @param minTermFreq    {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq} parameter
   * @param classFieldName the name of the field used as the output for the classifier
   * @param field2analyzer map with key a field name and the related {@link org.apache.lucene.analysis.Analyzer}
   * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
   */
  public KNearestNeighborDocumentClassifier(IndexReader indexReader, Similarity similarity, Query query, int k, int minDocsFreq,
                                            int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
    // the parent's single analyzer argument is passed as null: analysis is done per field through field2analyzer
    super(indexReader, similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
    this.field2analyzer = field2analyzer;
  }

  /**
   * Assigns a single class to the given document, derived from its k nearest neighbors.
   *
   * @param document the document to classify
   * @return the classification result computed from the nearest neighbors
   * @throws IOException If there is a low-level I/O error
   */
  @Override
  public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
    return classifyFromTopDocs(knnSearch(document));
  }

  /**
   * Returns all the classes found among the nearest neighbors of the given document, sorted by the natural ordering
   * of {@link ClassificationResult}.
   *
   * @param document the document to classify
   * @return the sorted list of candidate classes
   * @throws IOException If there is a low-level I/O error
   */
  @Override
  public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
    return sortedClasses(document);
  }

  /**
   * Returns at most {@code max} classes found among the nearest neighbors of the given document, sorted by the
   * natural ordering of {@link ClassificationResult}.
   *
   * @param document the document to classify
   * @param max      the maximum number of classes to return
   * @return the first {@code max} entries of the sorted list of candidate classes (fewer if less are available)
   * @throws IOException If there is a low-level I/O error
   */
  @Override
  public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
    List<ClassificationResult<BytesRef>> assignedClasses = sortedClasses(document);
    int limit = Math.min(max, assignedClasses.size());
    return assignedClasses.subList(0, limit);
  }

  /**
   * Runs the nearest neighbor search for the document and returns the resulting classes in sorted order.
   * Shared by both {@code getClasses} variants.
   *
   * @param document the document to classify
   * @return the sorted list of candidate classes
   * @throws IOException If there is a low-level I/O error
   */
  private List<ClassificationResult<BytesRef>> sortedClasses(Document document) throws IOException {
    TopDocs knnResults = knnSearch(document);
    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
    Collections.sort(assignedClasses);
    return assignedClasses;
  }

  /**
   * Returns the top k results from a More Like This query based on the input document
   *
   * @param document the document to use for More Like This search
   * @return the top results for the MLT query
   * @throws IOException If there is a low-level I/O error
   */
  private TopDocs knnSearch(Document document) throws IOException {
    BooleanQuery.Builder mltQuery = new BooleanQuery.Builder();

    for (String fieldName : textFieldNames) {
      // a field spec may carry a boost suffix, e.g. "title^10": split it into the bare field name and the boost
      String boost = null;
      if (fieldName.contains("^")) {
        String[] field2boost = fieldName.split("\\^");
        fieldName = field2boost[0];
        boost = field2boost[1];
      }
      String[] fieldValues = document.getValues(fieldName);
      mlt.setBoost(true); // we want always to use the boost coming from TF * IDF of the term
      try {
        if (boost != null) {
          mlt.setBoostFactor(Float.parseFloat(boost)); // this is an additional multiplicative boost coming from the field boost
        }
        mlt.setAnalyzer(field2analyzer.get(fieldName));
        // one optional clause per value of the field in the input document
        for (String fieldContent : fieldValues) {
          mltQuery.add(new BooleanClause(mlt.like(fieldName, new StringReader(fieldContent)), BooleanClause.Occur.SHOULD));
        }
      } finally {
        // restore neutral boost for next field; done in finally so that a failure while building the
        // clauses does not leave this field's boost set on the shared mlt instance
        mlt.setBoostFactor(1);
      }
    }

    // only documents that have some value in the class field can act as neighbors
    Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
    mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
    if (query != null) {
      // additionally restrict the candidate neighbors with the user supplied filter query
      mltQuery.add(query, BooleanClause.Occur.MUST);
    }
    return indexSearcher.search(mltQuery.build(), k);
  }
}