/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification.document;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.util.BytesRef;
/**
* A simplistic Lucene-based Naive Bayes document classifier, see {@code http://en.wikipedia.org/wiki/Naive_Bayes_classifier}
*
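* <p>
* A minimal usage sketch (the index reader, field names and analyzer below are illustrative
* assumptions, not part of this class):
* <pre>{@code
* Map<String, Analyzer> field2analyzer = Collections.singletonMap("body", new StandardAnalyzer());
* SimpleNaiveBayesDocumentClassifier classifier = new SimpleNaiveBayesDocumentClassifier(
*     reader, null, "category", field2analyzer, "body");
* ClassificationResult<BytesRef> result = classifier.assignClass(unseenDocument);
* }</pre>
*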
* @lucene.experimental
*/
public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifier implements DocumentClassifier<BytesRef> {
/**
 * Mapping from field name to the {@link org.apache.lucene.analysis.Analyzer} used for tokenizing that field's content
 */
protected Map<String, Analyzer> field2analyzer;
/**
 * Creates a new Naive Bayes document classifier.
 *
 * @param indexReader the reader on the index to be used for classification
 * @param query a {@link org.apache.lucene.search.Query} to optionally filter the docs used for training the classifier, or {@code null}
 * if all the indexed docs should be used
 * @param classFieldName the name of the field used as the output for the classifier. NOTE: it must not be heavily analyzed,
 * as the returned class will be a token indexed for this field
 * @param field2analyzer the mapping from field names to the {@link org.apache.lucene.analysis.Analyzer}s used to tokenize them
 * @param textFieldNames the names of the fields used as inputs for the classifier; they can carry a boost suffix, e.g. {@code title^10}
 */
public SimpleNaiveBayesDocumentClassifier(IndexReader indexReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
super(indexReader, null, query, classFieldName, textFieldNames);
this.field2analyzer = field2analyzer;
}
@Override
public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
ClassificationResult<BytesRef> assignedClass = null;
double maxScore = -Double.MAX_VALUE;
for (ClassificationResult<BytesRef> c : assignedClasses) {
if (c.getScore() > maxScore) {
assignedClass = c;
maxScore = c.getScore();
}
}
return assignedClass;
}
@Override
public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
Collections.sort(assignedClasses);
return assignedClasses;
}
@Override
public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
Collections.sort(assignedClasses);
return assignedClasses.subList(0, Math.min(max, assignedClasses.size())); // guard against max exceeding the number of classes
}
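/**
 * Scores the input document against every class term found in the index: the seed document is
 * analyzed once, then each class accumulates, over all field values, its log prior plus the
 * boosted, normalized log likelihood of the field's tokens; results are normalized before returning.
 *
 * @param inputDocument the unseen document to classify
 * @return the list of normalized {@link ClassificationResult}s, one per class
 * @throws IOException If there is a low-level I/O error
 */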
private List<ClassificationResult<BytesRef>> assignNormClasses(Document inputDocument) throws IOException {
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
Map<String, Float> fieldName2boost = new LinkedHashMap<>();
Terms classes = MultiTerms.getTerms(indexReader, classFieldName);
if (classes != null) {
TermsEnum classesEnum = classes.iterator();
BytesRef c;
analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost);
int docsWithClassSize = countDocsWithClass();
while ((c = classesEnum.next()) != null) {
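// per field value: add log(P(c)) + boost * normalized log likelihood of its tokens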
double classScore = 0;
Term term = new Term(this.classFieldName, c);
for (String fieldName : textFieldNames) {
List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName);
double fieldScore = 0;
for (String[] fieldTokensArray : tokensArrays) {
fieldScore += calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, term, docsWithClassSize) * fieldName2boost.get(fieldName);
}
classScore += fieldScore;
}
assignedClasses.add(new ClassificationResult<>(term.bytes(), classScore));
}
}
return normClassificationResults(assignedClasses);
}
/**
 * This method performs the analysis of the seed document and extracts the field boosts if present.
 * This is done only once per seed document.
 *
 * @param inputDocument the seed unseen document
 * @param fieldName2tokensArray a map associating each field name with the list of token arrays for all its values
 * @param fieldName2boost a map associating each field name with its boost
 * @throws IOException If there is a low-level I/O error
 */
private void analyzeSeedDocument(Document inputDocument, Map<String, List<String[]>> fieldName2tokensArray, Map<String, Float> fieldName2boost) throws IOException {
for (int i = 0; i < textFieldNames.length; i++) {
String fieldName = textFieldNames[i];
float boost = 1;
List<String[]> tokenizedValues = new LinkedList<>();
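// a field name may carry an optional boost suffix, e.g. title^10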
if (fieldName.contains("^")) {
String[] field2boost = fieldName.split("\\^");
fieldName = field2boost[0];
boost = Float.parseFloat(field2boost[1]);
}
IndexableField[] fieldValues = inputDocument.getFields(fieldName);
for (IndexableField fieldValue : fieldValues) {
TokenStream fieldTokens = fieldValue.tokenStream(field2analyzer.get(fieldName), null);
String[] fieldTokensArray = getTokenArray(fieldTokens);
tokenizedValues.add(fieldTokensArray);
}
fieldName2tokensArray.put(fieldName, tokenizedValues);
fieldName2boost.put(fieldName, boost);
textFieldNames[i] = fieldName;
}
}
/**
 * Returns a token array from the given {@link org.apache.lucene.analysis.TokenStream}
 *
 * @param tokenizedText the tokenized content of a field
 * @return a {@code String} array of the resulting tokens
 * @throws java.io.IOException If tokenization fails because there is a low-level I/O error
 */
protected String[] getTokenArray(TokenStream tokenizedText) throws IOException {
Collection<String> tokens = new LinkedList<>();
CharTermAttribute charTermAttribute = tokenizedText.addAttribute(CharTermAttribute.class);
try {
tokenizedText.reset();
while (tokenizedText.incrementToken()) {
tokens.add(charTermAttribute.toString());
}
tokenizedText.end();
} finally {
// ensure the stream is closed even if tokenization fails
tokenizedText.close();
}
return tokens.toArray(new String[tokens.size()]);
}
/**
 * Calculates the normalized log likelihood log(P(d|c)) of the tokenized text for the class referred to by the input term.
 *
 * @param tokenizedText the tokenized content of a field
 * @param fieldName the input field name
 * @param term the {@link Term} referring to the class to calculate the score of
 * @param docsWithClass the total number of docs that have a class
 * @return a normalized score for the class
 * @throws IOException If there is a low-level I/O error
 */
private double calculateLogLikelihood(String[] tokenizedText, String fieldName, Term term, int docsWithClass) throws IOException {
// for each word
double result = 0d;
for (String word : tokenizedText) {
// search with text:word AND class:c
int hits = getWordFreqForClass(word, fieldName, term);
// num : the number of times the word appears in documents of class c (+1)
double num = hits + 1; // +1 for add-one (Laplace) smoothing
// den : for the whole dictionary, the number of times any word appears in documents of class c (+ docsWithClass standing in for |V|)
double den = getTextTermFreqForClass(term, fieldName) + docsWithClass;
// P(w|c) = num/den
double wordProbability = num / den;
result += Math.log(wordProbability);
}
// log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c))
double normScore = result / (tokenizedText.length); // normalize by the number of tokens, otherwise long text fields would always dominate short ones
return normScore;
}
/**
 * Returns the average number of unique terms per doc times the number of docs belonging to the input class
 *
 * @param term the class term
 * @param fieldName the text field to compute the average number of unique terms on
 * @return the average number of unique terms per doc times the number of docs with the input class
 * @throws java.io.IOException If there is a low-level I/O error
 */
private double getTextTermFreqForClass(Term term, String fieldName) throws IOException {
double avgNumberOfUniqueTerms;
Terms terms = MultiTerms.getTerms(indexReader, fieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
int docsWithC = indexReader.docFreq(term);
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
/**
 * Returns the number of documents of the input class (from the whole index, or from the subset
 * matched by the optional query) that contain the word in the given field
 *
 * @param word the token produced by the analyzer
 * @param fieldName the field the word is coming from
 * @param term the class term
 * @return the number of documents of the input class that contain the word
 * @throws java.io.IOException If there is a low-level I/O error
 */
private int getWordFreqForClass(String word, String fieldName, Term term) throws IOException {
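// count docs matching (fieldName:word) AND (classFieldName:class), optionally restricted to the training query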
BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
subQuery.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD));
booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST));
if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST);
}
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
return totalHitCountCollector.getTotalHits();
}
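/**
 * Returns the log prior log(P(c)) of the class referred to by the input term,
 * computed as log(docCount(c)) - log(docsWithClassSize)
 *
 * @param term the class term
 * @param docsWithClassSize the total number of docs that have a class
 * @return the log prior probability of the class
 * @throws IOException If there is a low-level I/O error
 */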
private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
return Math.log((double) docCount(term)) - Math.log(docsWithClassSize);
}
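/**
 * Returns the number of documents in the index containing the given class term
 */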
private int docCount(Term term) throws IOException {
return indexReader.docFreq(term);
}
}