| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package opennlp.tools.doc_classifier; |
| |
| import java.io.File; |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Scanner; |
| |
| import opennlp.tools.similarity.apps.utils.CountItemsList; |
| import opennlp.tools.similarity.apps.utils.ValueSortMap; |
| import opennlp.tools.textsimilarity.TextProcessor; |
| |
| import org.apache.commons.lang.StringUtils; |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.lucene.analysis.Analyzer; |
| import org.apache.lucene.analysis.standard.StandardAnalyzer; |
| import org.apache.lucene.document.Document; |
| import org.apache.lucene.index.DirectoryReader; |
| import org.apache.lucene.index.IndexReader; |
| import org.apache.lucene.queryparser.classic.ParseException; |
| import org.apache.lucene.queryparser.classic.QueryParser; |
| import org.apache.lucene.search.IndexSearcher; |
| import org.apache.lucene.search.Query; |
| import org.apache.lucene.search.ScoreDoc; |
| import org.apache.lucene.search.TopDocs; |
| import org.apache.lucene.store.Directory; |
| import org.apache.lucene.store.FSDirectory; |
| import org.apache.lucene.util.Version; |
| import org.json.JSONObject; |
| |
| public class DocClassifier { |
| public static final String DOC_CLASSIFIER_KEY = "doc_class"; |
| public static String resourceDir = null; |
| public static final Log logger = LogFactory.getLog(DocClassifier.class); |
| private Map<String, Float> scoredClasses = new HashMap<String, Float>(); |
| |
| |
| public static Float MIN_TOTAL_SCORE_FOR_CATEGORY = 0.3f; //3.0f; |
| protected static IndexReader indexReader = null; |
| protected static IndexSearcher indexSearcher = null; |
| // resource directory plus the index folder |
| private static final String INDEX_PATH = resourceDir |
| + ClassifierTrainingSetIndexer.INDEX_PATH; |
| |
| // http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm |
| private static final int MAX_DOCS_TO_USE_FOR_CLASSIFY = 10, // 10 similar |
| // docs for |
| // nearest |
| // neighbor |
| // settings |
| |
| MAX_CATEG_RESULTS = 2; |
| private static final float BEST_TO_NEX_BEST_RATIO = 2.0f; |
| // to accumulate classif results |
| private CountItemsList<String> localCats = new CountItemsList<String>(); |
| private int MAX_TOKENS_TO_FORM = 30; |
| private String CAT_COMPUTING = "computing"; |
| public static final String DOC_CLASSIFIER_MAP = "doc_classifier_map"; |
| private static final int MIN_SENTENCE_LENGTH_TO_CATEGORIZE = 60; // if |
| // sentence |
| // is |
| // shorter, |
| // should |
| // not |
| // be |
| // used |
| // for |
| // classification |
| private static final int MIN_CHARS_IN_QUERY = 30; // if combination of |
| // keywords is shorter, |
| // should not be used |
| // for classification |
| |
| // these are categories from the index |
| public static final String[] categories = new String[] { "legal", "health", |
| "finance", "computing", "engineering", "business" }; |
| |
| static { |
| synchronized (DocClassifier.class) { |
| Directory indexDirectory = null; |
| |
| try { |
| indexDirectory = FSDirectory.open(new File(INDEX_PATH)); |
| } catch (IOException e2) { |
| logger.error("problem opening index " + e2); |
| } |
| try { |
| indexReader = DirectoryReader.open(indexDirectory); |
| indexSearcher = new IndexSearcher(indexReader); |
| } catch (IOException e2) { |
| logger.error("problem reading index \n" + e2); |
| } |
| } |
| } |
| |
| public DocClassifier(String inputFilename, JSONObject inputJSON) { |
| scoredClasses = new HashMap<String, Float>(); |
| } |
| |
| /* returns the class name for a sentence */ |
| private List<String> classifySentence(String queryStr) { |
| |
| List<String> results = new ArrayList<String>(); |
| // too short of a query |
| if (queryStr.length() < MIN_CHARS_IN_QUERY) { |
| return results; |
| } |
| |
| Analyzer std = new StandardAnalyzer(Version.LUCENE_46); |
| QueryParser parser = new QueryParser(Version.LUCENE_46, "text", std); |
| parser.setDefaultOperator(QueryParser.Operator.OR); |
| Query query = null; |
| try { |
| query = parser.parse(queryStr); |
| |
| } catch (ParseException e2) { |
| |
| return results; |
| } |
| TopDocs hits = null; // TopDocs search(Query query, int n) |
| // Finds the top n hits for query. |
| try { |
| hits = indexSearcher |
| .search(query, MAX_DOCS_TO_USE_FOR_CLASSIFY + 2); |
| } catch (IOException e1) { |
| logger.error("problem searching index \n" + e1); |
| } |
| logger.debug("Found " + hits.totalHits + " hits for " + queryStr); |
| int count = 0; |
| |
| |
| for (ScoreDoc scoreDoc : hits.scoreDocs) { |
| Document doc = null; |
| try { |
| doc = indexSearcher.doc(scoreDoc.doc); |
| } catch (IOException e) { |
| logger.error("Problem searching training set for classif \n" |
| + e); |
| continue; |
| } |
| String flag = doc.get("class"); |
| |
| Float scoreForClass = scoredClasses.get(flag); |
| if (scoreForClass == null) |
| scoredClasses.put(flag, scoreDoc.score); |
| else |
| scoredClasses.put(flag, scoreForClass + scoreDoc.score); |
| |
| logger.debug(" <<categorized as>> " + flag + " | score=" |
| + scoreDoc.score + " \n text =" + doc.get("text") + "\n"); |
| |
| if (count > MAX_DOCS_TO_USE_FOR_CLASSIFY) { |
| break; |
| } |
| count++; |
| } |
| try { |
| scoredClasses = ValueSortMap.sortMapByValue(scoredClasses, false); |
| List<String> resultsAll = new ArrayList<String>( |
| scoredClasses.keySet()), resultsAboveThresh = new ArrayList<String>(); |
| for (String key : resultsAll) { |
| if (scoredClasses.get(key) > MIN_TOTAL_SCORE_FOR_CATEGORY) |
| resultsAboveThresh.add(key); |
| else |
| logger.debug("Too low score of " + scoredClasses.get(key) |
| + " for category = " + key); |
| } |
| |
| int len = resultsAboveThresh.size(); |
| if (len > MAX_CATEG_RESULTS) |
| results = resultsAboveThresh.subList(0, MAX_CATEG_RESULTS); // get |
| // maxRes |
| // elements |
| else |
| results = resultsAboveThresh; |
| } catch (Exception e) { |
| logger.error("Problem aggregating search results\n" + e); |
| } |
| if (results.size() < 2) |
| return results; |
| |
| // if two categories, one is very high and another is relatively low |
| if (scoredClasses.get(results.get(0)) |
| / scoredClasses.get(results.get(1)) > BEST_TO_NEX_BEST_RATIO) // second |
| // best |
| // is |
| // much |
| // worse |
| return results.subList(0, 1); |
| else |
| return results; |
| |
| } |
| |
| |
| |
| |
| public static String formClassifQuery(String pageContentReader, int maxRes) { |
| |
| // We want to control which delimiters we substitute. For example '_' & |
| // \n we retain |
| pageContentReader = pageContentReader.replaceAll("[^A-Za-z0-9 _\\n]", |
| ""); |
| |
| Scanner in = new Scanner(pageContentReader); |
| in.useDelimiter("\\s+"); |
| Map<String, Integer> words = new HashMap<String, Integer>(); |
| |
| while (in.hasNext()) { |
| String word = in.next(); |
| if (!StringUtils.isAlpha(word) || word.length() < 4) |
| continue; |
| |
| if (!words.containsKey(word)) { |
| words.put(word, 1); |
| } else { |
| words.put(word, words.get(word) + 1); |
| } |
| } |
| in.close(); |
| words = ValueSortMap.sortMapByValue(words, false); |
| List<String> resultsAll = new ArrayList<String>(words.keySet()), results = null; |
| |
| int len = resultsAll.size(); |
| if (len > maxRes) |
| results = resultsAll.subList(len - maxRes, len - 1); // get maxRes |
| // elements |
| else |
| results = resultsAll; |
| |
| return results.toString().replaceAll("(\\[|\\]|,)", " ").trim(); |
| } |
| |
| public void close() { |
| try { |
| indexReader.close(); |
| } catch (IOException e) { |
| logger.error("Problem closing index \n" + e); |
| } |
| } |
| |
| |
| /* |
| * Main entry point for classifying sentences |
| */ |
| |
| public List<String> getEntityOrClassFromText(String content) { |
| |
| List<String> sentences = TextProcessor.splitToSentences(content); |
| List<String> classifResults; |
| |
| try { |
| for (String sentence : sentences) { |
| // If sentence is too short, there is a chance it is not form a |
| // main text area, |
| // but from somewhere else, so it is safer not to use this |
| // portion of text for classification |
| |
| if (sentence.length() < MIN_SENTENCE_LENGTH_TO_CATEGORIZE) |
| continue; |
| String query = formClassifQuery(sentence, MAX_TOKENS_TO_FORM); |
| classifResults = classifySentence(query); |
| if (classifResults != null && classifResults.size() > 0) { |
| for (String c : classifResults) { |
| localCats.add(c); |
| } |
| logger.debug(sentence + " => " + classifResults); |
| } |
| } |
| |
| } catch (Exception e) { |
| logger.error("Problem classifying sentence\n " + e); |
| } |
| |
| List<String> aggrResults = new ArrayList<String>(); |
| try { |
| |
| aggrResults = localCats.getFrequentTags(); |
| |
| logger.debug(localCats.getFrequentTags()); |
| } catch (Exception e) { |
| logger.error("Problem aggregating search results\n" + e); |
| } |
| return aggrResults; |
| } |
| |
| } |