| using Lucene.Net.Analysis; |
| using Lucene.Net.Analysis.TokenAttributes; |
| using Lucene.Net.Index; |
| using Lucene.Net.Search; |
| using Lucene.Net.Util; |
| using System; |
| using System.Collections.Generic; |
| using System.IO; |
| |
| namespace Lucene.Net.Classification |
| { |
| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| /// <summary> |
| /// A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code> |
| /// |
| /// @lucene.experimental |
| /// </summary> |
| public class SimpleNaiveBayesClassifier : IClassifier<BytesRef> |
| { |
// Index access, assigned by Train. AssignClass treats a null reader as "not trained yet".
private AtomicReader _atomicReader;
private IndexSearcher _indexSearcher;

// Names of the fields holding the text to compare and the class label.
private string[] _textFieldNames;
private string _classFieldName;

// Analyzer applied to the text being classified, and the optional filter query
// (null means no filter clause is added to the searches).
private Analyzer _analyzer;
private Query _query;

// Number of documents counted by CountDocsWithClass during training.
private int _docsWithClassSize;

/// <summary>
/// Initializes an untrained NaiveBayes classifier.
/// One of the <c>Train</c> overloads, for example
/// <see cref="Train(AtomicReader, string, string, Analyzer)"/>, has to be called
/// before any document can be classified.
/// </summary>
public SimpleNaiveBayesClassifier()
{
}
| |
/// <summary>
/// Trains the classifier on one text field of the given index, without any
/// filter query.
/// </summary>
/// <param name="atomicReader">the reader to use to access the Lucene index</param>
/// <param name="textFieldName">the name of the field used to compare documents</param>
/// <param name="classFieldName">the name of the field containing the class assigned to documents</param>
/// <param name="analyzer">the analyzer used to tokenize / filter the unseen text</param>
public virtual void Train(AtomicReader atomicReader, string textFieldName, string classFieldName, Analyzer analyzer)
    => Train(atomicReader, textFieldName, classFieldName, analyzer, query: null);
| |
/// <summary>
/// Trains the classifier on one text field of the given index, optionally
/// restricted by a query.
/// </summary>
/// <param name="atomicReader">the reader to use to access the Lucene index</param>
/// <param name="textFieldName">the name of the field used to compare documents</param>
/// <param name="classFieldName">the name of the field containing the class assigned to documents</param>
/// <param name="analyzer">the analyzer used to tokenize / filter the unseen text</param>
/// <param name="query">the query to filter which documents use for training</param>
public virtual void Train(AtomicReader atomicReader, string textFieldName, string classFieldName, Analyzer analyzer, Query query)
{
    // Wrap the single field name so the multi-field overload does the actual work.
    string[] singleField = { textFieldName };
    Train(atomicReader, singleField, classFieldName, analyzer, query);
}
| |
/// <summary>
/// Trains the classifier on several text fields of the given index, optionally
/// restricted by a query. The arguments are stored on the instance and the
/// number of documents carrying a class is counted once.
/// </summary>
/// <param name="atomicReader">the reader to use to access the Lucene index</param>
/// <param name="textFieldNames">the names of the fields to be used to compare documents</param>
/// <param name="classFieldName">the name of the field containing the class assigned to documents</param>
/// <param name="analyzer">the analyzer used to tokenize / filter the unseen text</param>
/// <param name="query">the query to filter which documents use for training</param>
public virtual void Train(AtomicReader atomicReader, string[] textFieldNames, string classFieldName, Analyzer analyzer, Query query)
{
    // Index access first: the searcher wraps the very same reader.
    _atomicReader = atomicReader;
    _indexSearcher = new IndexSearcher(atomicReader);

    // Field names, analyzer and optional filter.
    _textFieldNames = textFieldNames;
    _classFieldName = classFieldName;
    _analyzer = analyzer;
    _query = query;

    // Has to come last: the count reads the reader, searcher, class field and query set above.
    _docsWithClassSize = CountDocsWithClass();
}
| |
/// <summary>
/// Counts the documents that have at least one term in the class field.
/// </summary>
/// <remarks>
/// The count is taken from the term statistics of the class field when they are
/// available. When the field has no terms at all, or the codec reports -1 for
/// the document count, the documents are counted with a wildcard query on the
/// class field instead, combined with the training filter query if one was given.
/// </remarks>
/// <returns>the number of documents with a class</returns>
private int CountDocsWithClass()
{
    Terms classTerms = MultiFields.GetTerms(_atomicReader, _classFieldName);

    // GetTerms may return null when the field does not exist in the index;
    // treat that like "statistic not available" instead of dereferencing null.
    int docCount = classTerms == null ? -1 : classTerms.DocCount;
    if (docCount != -1)
    {
        return docCount;
    }

    // Fallback: count every document matching <classField>:* (and the filter query, if any).
    BooleanQuery q = new BooleanQuery();
    q.Add(new BooleanClause(new WildcardQuery(new Term(_classFieldName, WildcardQuery.WILDCARD_STRING.ToString())), Occur.MUST));
    if (_query != null)
    {
        q.Add(_query, Occur.MUST);
    }

    TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
    _indexSearcher.Search(q, totalHitCountCollector);
    return totalHitCountCollector.TotalHits;
}
| |
/// <summary>
/// Runs the analyzer over the given text once per configured text field and
/// collects every produced term, in order, into a single array.
/// </summary>
/// <param name="doc">the text to tokenize</param>
/// <returns>the terms emitted for all text fields, concatenated field by field</returns>
private string[] TokenizeDoc(string doc)
{
    List<string> tokens = new List<string>();
    foreach (string fieldName in _textFieldNames)
    {
        TokenStream stream = _analyzer.GetTokenStream(fieldName, new StringReader(doc));
        try
        {
            ICharTermAttribute termAttribute = stream.AddAttribute<ICharTermAttribute>();
            stream.Reset();
            while (stream.IncrementToken())
            {
                tokens.Add(termAttribute.ToString());
            }
            stream.End();
        }
        finally
        {
            // Always release the stream, even when tokenizing failed.
            IOUtils.DisposeWhileHandlingException(stream);
        }
    }
    return tokens.ToArray();
}
| |
/// <summary>
/// Assign a class (with score) to the given text string.
/// </summary>
/// <remarks>
/// Every term of the class field is treated as a candidate class. The candidate
/// with the highest sum of log prior and log likelihood wins, and the returned
/// score is <c>10 / |best value|</c>. When the class field has no terms, the
/// result carries an empty <see cref="BytesRef"/>.
/// </remarks>
/// <param name="inputDocument">a string containing text to be classified</param>
/// <returns>a <see cref="ClassificationResult{BytesRef}"/> holding assigned class of type <see cref="BytesRef"/> and score</returns>
/// <exception cref="IOException">if the classifier has not been trained yet</exception>
public virtual ClassificationResult<BytesRef> AssignClass(string inputDocument)
{
    if (_atomicReader == null)
    {
        throw new IOException("You must first call Classifier#train");
    }

    double max = -double.MaxValue;
    BytesRef foundClass = new BytesRef();

    // GetTerms may return null when the class field does not exist in the index.
    // That case is handled like an empty list of classes rather than failing on null.
    Terms terms = MultiFields.GetTerms(_atomicReader, _classFieldName);
    if (terms != null)
    {
        TermsEnum termsEnum = terms.GetEnumerator();
        string[] tokenizedDoc = TokenizeDoc(inputDocument);
        while (termsEnum.MoveNext())
        {
            BytesRef candidate = termsEnum.Term;
            double clVal = CalculateLogPrior(candidate) + CalculateLogLikelihood(tokenizedDoc, candidate);
            if (clVal > max)
            {
                max = clVal;
                // Copy the bytes: the enumerator hands out the term it is currently positioned on.
                foundClass = BytesRef.DeepCopyOf(candidate);
            }
        }
    }

    double score = 10 / Math.Abs(max);
    return new ClassificationResult<BytesRef>(foundClass, score);
}
| |
| |
/// <summary>
/// Computes log(P(d|c)) = log(P(w1|c)) + ... + log(P(wn|c)) for the given
/// tokens and class, using add-one smoothing.
/// </summary>
/// <param name="tokenizedDoc">the tokens of the document to classify</param>
/// <param name="c">the candidate class</param>
/// <returns>the sum of the log probabilities of all tokens; 0 for an empty document</returns>
private double CalculateLogLikelihood(string[] tokenizedDoc, BytesRef c)
{
    // Nothing to sum, and no index lookup needed either.
    if (tokenizedDoc.Length == 0)
    {
        return 0d;
    }

    // den : estimated number of term occurrences in documents of class c, plus the
    // number of documents with a class. It depends only on the class, so it is
    // computed once instead of once per token.
    double den = GetTextTermFreqForClass(c) + _docsWithClassSize;

    double result = 0d;
    foreach (string word in tokenizedDoc)
    {
        // search with text:word AND class:c
        int hits = GetWordFreqForClass(word, c);

        // num : number of documents of class c matching the word, +1 for add-one smoothing
        double num = hits + 1;

        // P(w|c) = num/den
        double wordProbability = num / den;
        result += Math.Log(wordProbability);
    }

    return result;
}
| |
/// <summary>
/// Estimates how many term/document pairs the text fields contribute for
/// documents of the given class: the average number of unique terms per
/// document, summed over all text fields, multiplied by the number of
/// documents having that class.
/// </summary>
/// <param name="c">the class to estimate for</param>
/// <returns>the estimated term count for the class</returns>
private double GetTextTermFreqForClass(BytesRef c)
{
    double avgNumberOfUniqueTerms = 0;
    foreach (string textFieldName in _textFieldNames)
    {
        Terms terms = MultiFields.GetTerms(_atomicReader, textFieldName);
        if (terms == null)
        {
            // The field is not present in the index: it contributes no terms.
            continue;
        }

        int docsWithField = terms.DocCount;
        if (docsWithField <= 0)
        {
            // -1 means the codec does not provide the statistic, 0 would divide by
            // zero; either way no meaningful average can be derived for this field.
            continue;
        }

        long numPostings = terms.SumDocFreq; // number of term/doc pairs
        avgNumberOfUniqueTerms += numPostings / (double)docsWithField; // avg # of unique terms per doc
    }

    int docsWithC = _atomicReader.DocFreq(new Term(_classFieldName, c));
    return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
| |
/// <summary>
/// Counts the documents of the given class that contain the word in at least
/// one of the text fields, honoring the training filter query when one is set.
/// </summary>
/// <param name="word">the term to look for in the text fields</param>
/// <param name="c">the class the documents must have</param>
/// <returns>the number of matching documents</returns>
private int GetWordFreqForClass(string word, BytesRef c)
{
    // (text1:word OR text2:word OR ...)
    BooleanQuery anyTextField = new BooleanQuery();
    foreach (string fieldName in _textFieldNames)
    {
        Term wordTerm = new Term(fieldName, word);
        anyTextField.Add(new BooleanClause(new TermQuery(wordTerm), Occur.SHOULD));
    }

    // ... AND class:c [AND filter]
    BooleanQuery wordInClass = new BooleanQuery();
    wordInClass.Add(new BooleanClause(anyTextField, Occur.MUST));
    Term classTerm = new Term(_classFieldName, c);
    wordInClass.Add(new BooleanClause(new TermQuery(classTerm), Occur.MUST));
    if (_query != null)
    {
        wordInClass.Add(_query, Occur.MUST);
    }

    TotalHitCountCollector hitCounter = new TotalHitCountCollector();
    _indexSearcher.Search(wordInClass, hitCounter);
    return hitCounter.TotalHits;
}
| |
/// <summary>
/// Computes the log prior of a class: log(# docs with the class) minus
/// log(# docs with any class).
/// </summary>
/// <param name="currentClass">the class to compute the prior for</param>
/// <returns>the difference of the two logarithms</returns>
private double CalculateLogPrior(BytesRef currentClass)
{
    double logDocsInClass = Math.Log((double)DocCount(currentClass));
    double logDocsWithAnyClass = Math.Log(_docsWithClassSize);
    return logDocsInClass - logDocsWithAnyClass;
}
| |
/// <summary>
/// Returns the document frequency of the given class value in the class field.
/// </summary>
/// <param name="countedClass">the class value to look up</param>
/// <returns>the number of documents reported by the reader for that term</returns>
private int DocCount(BytesRef countedClass)
{
    Term classTerm = new Term(_classFieldName, countedClass);
    return _atomicReader.DocFreq(classTerm);
}
| } |
| } |