using Lucene.Net.Analysis;
using Lucene.Net.Analysis.TokenAttributes;
using Lucene.Net.Index;
using Lucene.Net.Search;
using Lucene.Net.Util;
using System;
using System.Collections.Generic;
using System.IO;
namespace Lucene.Net.Classification
{
    /// <summary>
    /// A simplistic Lucene based NaiveBayes classifier.
    /// @lucene.experimental
    /// </summary>
    public class SimpleNaiveBayesClassifier : IClassifier<BytesRef>
    {
        // Reader over the training index; null until Train(...) has been called.
        private AtomicReader atomicReader;
        // Names of the fields whose text is compared against the input document.
        private string[] textFieldNames;
        // Name of the field holding the class assigned to each training document.
        private string classFieldName;
        // Number of documents carrying a class value, computed once in Train(...).
        private int docsWithClassSize;
        // Analyzer used to tokenize the text passed to AssignClass(...).
        private Analyzer analyzer;
        // Searcher built on top of atomicReader in Train(...).
        private IndexSearcher indexSearcher;
        // Optional filter restricting the training documents; may be null.
        private Query query;

        /// <summary>
        /// Creates a new NaiveBayes classifier.
        /// Note that you must call <see cref="Train(AtomicReader, string, string, Analyzer)"/> before you can
        /// classify any documents.
        /// </summary>
        public SimpleNaiveBayesClassifier()
        {
        }

        /// <summary>
        /// Train the classifier using the underlying Lucene index, without any filtering query.
        /// </summary>
        /// <param name="atomicReader">the reader to use to access the Lucene index</param>
        /// <param name="textFieldName">the name of the field used to compare documents</param>
        /// <param name="classFieldName">the name of the field containing the class assigned to documents</param>
        /// <param name="analyzer">the analyzer used to tokenize / filter the unseen text</param>
        public virtual void Train(AtomicReader atomicReader, string textFieldName, string classFieldName, Analyzer analyzer)
        {
            Train(atomicReader, textFieldName, classFieldName, analyzer, null);
        }

        /// <summary>
        /// Train the classifier using the underlying Lucene index, comparing documents on a single text field.
        /// </summary>
        /// <param name="atomicReader">the reader to use to access the Lucene index</param>
        /// <param name="textFieldName">the name of the field used to compare documents</param>
        /// <param name="classFieldName">the name of the field containing the class assigned to documents</param>
        /// <param name="analyzer">the analyzer used to tokenize / filter the unseen text</param>
        /// <param name="query">the query to filter which documents use for training; may be <c>null</c></param>
        public virtual void Train(AtomicReader atomicReader, string textFieldName, string classFieldName, Analyzer analyzer, Query query)
        {
            Train(atomicReader, new string[] { textFieldName }, classFieldName, analyzer, query);
        }

        /// <summary>
        /// Train the classifier using the underlying Lucene index. Stores the given arguments, creates an
        /// <see cref="IndexSearcher"/> over the reader and counts the documents that carry a class value.
        /// </summary>
        /// <param name="atomicReader">the reader to use to access the Lucene index</param>
        /// <param name="textFieldNames">the names of the fields to be used to compare documents</param>
        /// <param name="classFieldName">the name of the field containing the class assigned to documents</param>
        /// <param name="analyzer">the analyzer used to tokenize / filter the unseen text</param>
        /// <param name="query">the query to filter which documents use for training; may be <c>null</c></param>
        public virtual void Train(AtomicReader atomicReader, string[] textFieldNames, string classFieldName, Analyzer analyzer, Query query)
        {
            this.atomicReader = atomicReader;
            indexSearcher = new IndexSearcher(this.atomicReader);
            this.textFieldNames = textFieldNames;
            this.classFieldName = classFieldName;
            this.analyzer = analyzer;
            this.query = query;
            // Must run last: it reads atomicReader, classFieldName, query and indexSearcher.
            docsWithClassSize = CountDocsWithClass();
        }

        /// <summary>
        /// Counts the documents that have a value in the class field.
        /// </summary>
        /// <remarks>
        /// The count is taken from the class field's term statistics. Only when those report <c>-1</c> is a
        /// wildcard search executed instead; the training <see cref="query"/> filter is applied on that
        /// fallback path only.
        /// </remarks>
        /// <returns>the number of documents with a class value</returns>
        private int CountDocsWithClass()
        {
            int docCount = MultiFields.GetTerms(atomicReader, classFieldName).DocCount;
            if (docCount == -1)
            {
                // in case codec doesn't support getDocCount
                TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
                BooleanQuery q = new BooleanQuery
                {
                    new BooleanClause(new WildcardQuery(new Term(classFieldName, WildcardQuery.WILDCARD_STRING.ToString())), Occur.MUST)
                };
                if (query != null)
                {
                    q.Add(query, Occur.MUST);
                }
                indexSearcher.Search(q, totalHitCountCollector);
                docCount = totalHitCountCollector.TotalHits;
            }
            return docCount;
        }

        /// <summary>
        /// Runs the configured analyzer over <paramref name="doc"/> once per text field and collects
        /// every produced term, in emission order.
        /// </summary>
        /// <param name="doc">the text to tokenize</param>
        /// <returns>the terms produced for all text fields, concatenated field after field</returns>
        private string[] TokenizeDoc(string doc)
        {
            List<string> tokens = new List<string>();
            foreach (string textFieldName in textFieldNames)
            {
                // The using block guarantees the stream is disposed even if analysis throws.
                using (TokenStream tokenStream = analyzer.GetTokenStream(textFieldName, new StringReader(doc)))
                {
                    ICharTermAttribute charTermAttribute = tokenStream.AddAttribute<ICharTermAttribute>();
                    // TokenStream consumer workflow: Reset, IncrementToken until false, End, Dispose.
                    tokenStream.Reset();
                    while (tokenStream.IncrementToken())
                    {
                        tokens.Add(charTermAttribute.ToString());
                    }
                    tokenStream.End();
                }
            }
            return tokens.ToArray();
        }

        /// <summary>
        /// Assign a class (with score) to the given text string. Every term of the class field is scored as
        /// log prior + log likelihood and the highest scoring class wins.
        /// </summary>
        /// <param name="inputDocument">a string containing text to be classified</param>
        /// <returns>a <see cref="ClassificationResult{BytesRef}"/> holding assigned class of type <see cref="BytesRef"/> and score</returns>
        /// <exception cref="IOException">if the classifier has not been trained yet</exception>
        public virtual ClassificationResult<BytesRef> AssignClass(string inputDocument)
        {
            if (atomicReader == null)
            {
                throw new IOException("You must first call Classifier#train");
            }

            double max = -double.MaxValue;
            BytesRef foundClass = new BytesRef();

            Terms terms = MultiFields.GetTerms(atomicReader, classFieldName);
            TermsEnum termsEnum = terms.GetEnumerator();
            string[] tokenizedDoc = TokenizeDoc(inputDocument);
            while (termsEnum.MoveNext())
            {
                BytesRef next = termsEnum.Term;
                double clVal = CalculateLogPrior(next) + CalculateLogLikelihood(tokenizedDoc, next);
                if (clVal > max)
                {
                    max = clVal;
                    // Deep copy so the winning class is kept independently of the enumerator's current term.
                    foundClass = BytesRef.DeepCopyOf(next);
                }
            }

            // The best log value is mapped to a positive score as 10 / |max|.
            double score = 10 / Math.Abs(max);
            return new ClassificationResult<BytesRef>(foundClass, score);
        }

        /// <summary>
        /// Computes log(P(d|c)) = log(P(w1|c)) + ... + log(P(wn|c)) for the tokenized document.
        /// </summary>
        /// <param name="tokenizedDoc">the terms of the document being classified</param>
        /// <param name="c">the candidate class</param>
        /// <returns>the sum of the per-word log probabilities; <c>0</c> for an empty document</returns>
        private double CalculateLogLikelihood(string[] tokenizedDoc, BytesRef c)
        {
            if (tokenizedDoc.Length == 0)
            {
                // Nothing to score; return before touching the index.
                return 0d;
            }

            // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
            // It does not depend on the word, so it is computed once per class instead of once per token.
            double den = GetTextTermFreqForClass(c) + docsWithClassSize;

            double result = 0d;
            foreach (string word in tokenizedDoc)
            {
                // search with text:word AND class:c
                int hits = GetWordFreqForClass(word, c);

                // num : count the no of times the word appears in documents of class c (+1)
                double num = hits + 1; // +1 is added because of add 1 smoothing

                // P(w|c) = num/den
                double wordProbability = num / den;
                result += Math.Log(wordProbability);
            }
            return result;
        }

        /// <summary>
        /// Estimates the number of term occurrences in the text fields for documents of class
        /// <paramref name="c"/> as (average number of unique terms per document, summed over the text fields)
        /// multiplied by the number of documents with that class.
        /// </summary>
        /// <param name="c">the class to estimate for</param>
        /// <returns>the estimated term count for the class</returns>
        private double GetTextTermFreqForClass(BytesRef c)
        {
            double avgNumberOfUniqueTerms = 0;
            foreach (string textFieldName in textFieldNames)
            {
                Terms terms = MultiFields.GetTerms(atomicReader, textFieldName);
                long numPostings = terms.SumDocFreq; // number of term/doc pairs
                avgNumberOfUniqueTerms += numPostings / (double)terms.DocCount; // avg # of unique terms per doc
            }
            int docsWithC = DocCount(c);
            return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
        }

        /// <summary>
        /// Counts the documents of class <paramref name="c"/> that contain <paramref name="word"/> in at
        /// least one of the text fields, also applying the training query filter when one was supplied.
        /// </summary>
        /// <param name="word">the term to look for</param>
        /// <param name="c">the class the documents must belong to</param>
        /// <returns>the number of matching documents</returns>
        private int GetWordFreqForClass(string word, BytesRef c)
        {
            // The word may appear in any of the text fields.
            BooleanQuery subQuery = new BooleanQuery();
            foreach (string textFieldName in textFieldNames)
            {
                subQuery.Add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), Occur.SHOULD));
            }

            BooleanQuery booleanQuery = new BooleanQuery();
            booleanQuery.Add(new BooleanClause(subQuery, Occur.MUST));
            booleanQuery.Add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), Occur.MUST));
            if (query != null)
            {
                booleanQuery.Add(query, Occur.MUST);
            }

            TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
            indexSearcher.Search(booleanQuery, totalHitCountCollector);
            return totalHitCountCollector.TotalHits;
        }

        /// <summary>
        /// Computes the log prior of a class: log(docs with this class) - log(docs with any class).
        /// </summary>
        /// <param name="currentClass">the class to compute the prior for</param>
        /// <returns>the natural logarithm of the class prior</returns>
        private double CalculateLogPrior(BytesRef currentClass)
        {
            return Math.Log((double)DocCount(currentClass)) - Math.Log(docsWithClassSize);
        }

        /// <summary>
        /// Returns the document frequency of the given class term in the class field.
        /// </summary>
        /// <param name="countedClass">the class to count</param>
        /// <returns>the number of documents reported by the reader for that class term</returns>
        private int DocCount(BytesRef countedClass)
        {
            return atomicReader.DocFreq(new Term(classFieldName, countedClass));
        }
    }
}