| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package opennlp.tools.postag; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Map.Entry; |
| import java.util.concurrent.atomic.AtomicInteger; |
| |
| import opennlp.tools.dictionary.Dictionary; |
| import opennlp.tools.ml.BeamSearch; |
| import opennlp.tools.ml.EventModelSequenceTrainer; |
| import opennlp.tools.ml.EventTrainer; |
| import opennlp.tools.ml.SequenceTrainer; |
| import opennlp.tools.ml.TrainerFactory; |
| import opennlp.tools.ml.TrainerFactory.TrainerType; |
| import opennlp.tools.ml.model.Event; |
| import opennlp.tools.ml.model.MaxentModel; |
| import opennlp.tools.ml.model.SequenceClassificationModel; |
| import opennlp.tools.ngram.NGramModel; |
| import opennlp.tools.util.ObjectStream; |
| import opennlp.tools.util.Sequence; |
| import opennlp.tools.util.SequenceValidator; |
| import opennlp.tools.util.StringList; |
| import opennlp.tools.util.StringUtil; |
| import opennlp.tools.util.TrainingParameters; |
| import opennlp.tools.util.featuregen.StringPattern; |
| |
| /** |
| * A part-of-speech tagger that uses maximum entropy. Tries to predict whether |
| * words are nouns, verbs, or any of 70 other POS tags depending on their |
| * surrounding context. |
| * |
| */ |
| public class POSTaggerME implements POSTagger { |
| |
| public static final int DEFAULT_BEAM_SIZE = 3; |
| |
| private POSModel modelPackage; |
| |
| /** |
| * The feature context generator. |
| */ |
| protected POSContextGenerator contextGen; |
| |
| /** |
| * Tag dictionary used for restricting words to a fixed set of tags. |
| */ |
| protected TagDictionary tagDictionary; |
| |
| protected Dictionary ngramDictionary; |
| |
| /** |
| * Says whether a filter should be used to check whether a tag assignment |
| * is to a word outside of a closed class. |
| */ |
| protected boolean useClosedClassTagsFilter = false; |
| |
| |
| /** |
| * The size of the beam to be used in determining the best sequence of pos tags. |
| */ |
| protected int size; |
| |
| private Sequence bestSequence; |
| |
| private SequenceClassificationModel<String> model; |
| |
| private SequenceValidator<String> sequenceValidator; |
| |
| /** |
| * Initializes the current instance with the provided model. |
| * |
| * @param model |
| */ |
| public POSTaggerME(POSModel model) { |
| POSTaggerFactory factory = model.getFactory(); |
| |
| int beamSize = POSTaggerME.DEFAULT_BEAM_SIZE; |
| |
| String beamSizeString = model.getManifestProperty(BeamSearch.BEAM_SIZE_PARAMETER); |
| |
| if (beamSizeString != null) { |
| beamSize = Integer.parseInt(beamSizeString); |
| } |
| |
| modelPackage = model; |
| |
| contextGen = factory.getPOSContextGenerator(beamSize); |
| tagDictionary = factory.getTagDictionary(); |
| size = beamSize; |
| |
| sequenceValidator = factory.getSequenceValidator(); |
| |
| if (model.getPosSequenceModel() != null) { |
| this.model = model.getPosSequenceModel(); |
| } |
| else { |
| this.model = new opennlp.tools.ml.BeamSearch<>(beamSize, |
| model.getPosModel(), 0); |
| } |
| |
| } |
| |
| /** |
| * Retrieves an array of all possible part-of-speech tags from the |
| * tagger. |
| * |
| * @return String[] |
| */ |
| public String[] getAllPosTags() { |
| return model.getOutcomes(); |
| } |
| |
| public String[] tag(String[] sentence) { |
| return this.tag(sentence, null); |
| } |
| |
| public String[] tag(String[] sentence, Object[] additionaContext) { |
| bestSequence = model.bestSequence(sentence, additionaContext, contextGen, sequenceValidator); |
| List<String> t = bestSequence.getOutcomes(); |
| return t.toArray(new String[t.size()]); |
| } |
| |
| /** |
| * Returns at most the specified number of taggings for the specified sentence. |
| * |
| * @param numTaggings The number of tagging to be returned. |
| * @param sentence An array of tokens which make up a sentence. |
| * |
| * @return At most the specified number of taggings for the specified sentence. |
| */ |
| public String[][] tag(int numTaggings, String[] sentence) { |
| Sequence[] bestSequences = model.bestSequences(numTaggings, sentence, null, |
| contextGen, sequenceValidator); |
| String[][] tags = new String[bestSequences.length][]; |
| for (int si = 0; si < tags.length; si++) { |
| List<String> t = bestSequences[si].getOutcomes(); |
| tags[si] = t.toArray(new String[t.size()]); |
| } |
| return tags; |
| } |
| |
| public Sequence[] topKSequences(String[] sentence) { |
| return this.topKSequences(sentence, null); |
| } |
| |
| public Sequence[] topKSequences(String[] sentence, Object[] additionaContext) { |
| return model.bestSequences(size, sentence, additionaContext, contextGen, sequenceValidator); |
| } |
| |
| /** |
| * Populates the specified array with the probabilities for each tag of the last tagged sentence. |
| * |
| * @param probs An array to put the probabilities into. |
| */ |
| public void probs(double[] probs) { |
| bestSequence.getProbs(probs); |
| } |
| |
| /** |
| * Returns an array with the probabilities for each tag of the last tagged sentence. |
| * |
| * @return an array with the probabilities for each tag of the last tagged sentence. |
| */ |
| public double[] probs() { |
| return bestSequence.getProbs(); |
| } |
| |
| public String[] getOrderedTags(List<String> words, List<String> tags, int index) { |
| return getOrderedTags(words,tags,index,null); |
| } |
| |
| public String[] getOrderedTags(List<String> words, List<String> tags, int index,double[] tprobs) { |
| |
| if (modelPackage.getPosModel() != null) { |
| |
| MaxentModel posModel = modelPackage.getPosModel(); |
| |
| double[] probs = posModel.eval(contextGen.getContext(index, |
| words.toArray(new String[words.size()]), |
| tags.toArray(new String[tags.size()]),null)); |
| |
| String[] orderedTags = new String[probs.length]; |
| for (int i = 0; i < probs.length; i++) { |
| int max = 0; |
| for (int ti = 1; ti < probs.length; ti++) { |
| if (probs[ti] > probs[max]) { |
| max = ti; |
| } |
| } |
| orderedTags[i] = posModel.getOutcome(max); |
| if (tprobs != null) { |
| tprobs[i] = probs[max]; |
| } |
| probs[max] = 0; |
| } |
| return orderedTags; |
| } |
| else { |
| throw new UnsupportedOperationException("This method can only be called if the " |
| + "classifcation model is an event model!"); |
| } |
| } |
| |
| public static POSModel train(String languageCode, |
| ObjectStream<POSSample> samples, TrainingParameters trainParams, |
| POSTaggerFactory posFactory) throws IOException { |
| |
| int beamSize = trainParams.getIntParameter(BeamSearch.BEAM_SIZE_PARAMETER, POSTaggerME.DEFAULT_BEAM_SIZE); |
| |
| POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator(); |
| |
| Map<String, String> manifestInfoEntries = new HashMap<>(); |
| |
| TrainerType trainerType = TrainerFactory.getTrainerType(trainParams); |
| |
| MaxentModel posModel = null; |
| SequenceClassificationModel<String> seqPosModel = null; |
| if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) { |
| ObjectStream<Event> es = new POSSampleEventStream(samples, contextGenerator); |
| |
| EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, |
| manifestInfoEntries); |
| posModel = trainer.train(es); |
| } |
| else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) { |
| POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator); |
| EventModelSequenceTrainer trainer = |
| TrainerFactory.getEventModelSequenceTrainer(trainParams, manifestInfoEntries); |
| posModel = trainer.train(ss); |
| } |
| else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) { |
| SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer( |
| trainParams, manifestInfoEntries); |
| |
| // TODO: This will probably cause issue, since the feature generator uses the outcomes array |
| |
| POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator); |
| seqPosModel = trainer.train(ss); |
| } |
| else { |
| throw new IllegalArgumentException("Trainer type is not supported: " + trainerType); |
| } |
| |
| if (posModel != null) { |
| return new POSModel(languageCode, posModel, beamSize, manifestInfoEntries, posFactory); |
| } |
| else { |
| return new POSModel(languageCode, seqPosModel, manifestInfoEntries, posFactory); |
| } |
| } |
| |
| public static Dictionary buildNGramDictionary(ObjectStream<POSSample> samples, int cutoff) |
| throws IOException { |
| |
| NGramModel ngramModel = new NGramModel(); |
| |
| POSSample sample; |
| while ((sample = samples.read()) != null) { |
| String[] words = sample.getSentence(); |
| |
| if (words.length > 0) |
| ngramModel.add(new StringList(words), 1, 1); |
| } |
| |
| ngramModel.cutoff(cutoff, Integer.MAX_VALUE); |
| |
| return ngramModel.toDictionary(true); |
| } |
| |
| public static void populatePOSDictionary(ObjectStream<POSSample> samples, |
| MutableTagDictionary dict, int cutoff) throws IOException { |
| System.out.println("Expanding POS Dictionary ..."); |
| long start = System.nanoTime(); |
| |
| // the data structure will store the word, the tag, and the number of |
| // occurrences |
| Map<String, Map<String, AtomicInteger>> newEntries = new HashMap<>(); |
| POSSample sample; |
| while ((sample = samples.read()) != null) { |
| String[] words = sample.getSentence(); |
| String[] tags = sample.getTags(); |
| |
| for (int i = 0; i < words.length; i++) { |
| // only store words |
| if (!StringPattern.recognize(words[i]).containsDigit()) { |
| String word; |
| if (dict.isCaseSensitive()) { |
| word = words[i]; |
| } else { |
| word = StringUtil.toLowerCase(words[i]); |
| } |
| |
| if (!newEntries.containsKey(word)) { |
| newEntries.put(word, new HashMap<>()); |
| } |
| |
| String[] dictTags = dict.getTags(word); |
| if (dictTags != null) { |
| for (String tag : dictTags) { |
| // for this tags we start with the cutoff |
| Map<String, AtomicInteger> value = newEntries.get(word); |
| if (!value.containsKey(tag)) { |
| value.put(tag, new AtomicInteger(cutoff)); |
| } |
| } |
| } |
| |
| if (!newEntries.get(word).containsKey(tags[i])) { |
| newEntries.get(word).put(tags[i], new AtomicInteger(1)); |
| } else { |
| newEntries.get(word).get(tags[i]).incrementAndGet(); |
| } |
| } |
| } |
| } |
| |
| // now we check if the word + tag pairs have enough occurrences, if yes we |
| // add it to the dictionary |
| for (Entry<String, Map<String, AtomicInteger>> wordEntry : newEntries |
| .entrySet()) { |
| List<String> tagsForWord = new ArrayList<>(); |
| for (Entry<String, AtomicInteger> entry : wordEntry.getValue().entrySet()) { |
| if (entry.getValue().get() >= cutoff) { |
| tagsForWord.add(entry.getKey()); |
| } |
| } |
| if (tagsForWord.size() > 0) { |
| dict.put(wordEntry.getKey(), |
| tagsForWord.toArray(new String[tagsForWord.size()])); |
| } |
| } |
| |
| System.out.println("... finished expanding POS Dictionary. [" |
| + (System.nanoTime() - start) / 1000000 + "ms]"); |
| } |
| } |