blob: df06f2afb805e4c2223504e6fec219e62d57f356 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.namefind;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import opennlp.tools.ml.BeamSearch;
import opennlp.tools.ml.EventModelSequenceTrainer;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.SequenceTrainer;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.TrainerFactory.TrainerType;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.Sequence;
import opennlp.tools.util.SequenceCodec;
import opennlp.tools.util.SequenceValidator;
import opennlp.tools.util.Span;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.featuregen.AdaptiveFeatureGenerator;
import opennlp.tools.util.featuregen.AdditionalContextFeatureGenerator;
import opennlp.tools.util.featuregen.BigramNameFeatureGenerator;
import opennlp.tools.util.featuregen.CachedFeatureGenerator;
import opennlp.tools.util.featuregen.FeatureGeneratorResourceProvider;
import opennlp.tools.util.featuregen.GeneratorFactory;
import opennlp.tools.util.featuregen.OutcomePriorFeatureGenerator;
import opennlp.tools.util.featuregen.PreviousMapFeatureGenerator;
import opennlp.tools.util.featuregen.SentenceFeatureGenerator;
import opennlp.tools.util.featuregen.TokenClassFeatureGenerator;
import opennlp.tools.util.featuregen.TokenFeatureGenerator;
import opennlp.tools.util.featuregen.WindowFeatureGenerator;
import opennlp.tools.util.model.ModelUtil;
/**
 * Class for creating a maximum-entropy-based name finder.
 */
public class NameFinderME implements TokenNameFinder {

  // Shared empty additional-context array used when the caller supplies none.
  private static final String[][] EMPTY = new String[0][0];

  public static final int DEFAULT_BEAM_SIZE = 3;

  // Matches typed outcomes such as "person-start"; group 1 captures the name type.
  private static final Pattern TYPED_OUTCOME_PATTERN = Pattern.compile("(.+)-\\w+");

  public static final String START = "start";
  public static final String CONTINUE = "cont";
  public static final String OTHER = "other";

  private SequenceCodec<String> seqCodec = new BioCodec();

  protected SequenceClassificationModel<String> model;

  protected NameContextGenerator contextGenerator;
  private Sequence bestSequence;

  private AdditionalContextFeatureGenerator additionalContextFeatureGenerator
      = new AdditionalContextFeatureGenerator();

  private SequenceValidator<String> sequenceValidator;

  /**
   * Initializes the name finder from a trained model. The sequence codec,
   * sequence validator, and context generator are all created by the model's
   * factory.
   *
   * @param model the trained name finder model
   */
  public NameFinderME(TokenNameFinderModel model) {

    TokenNameFinderFactory factory = model.getFactory();

    seqCodec = factory.createSequenceCodec();
    sequenceValidator = seqCodec.createSequenceValidator();
    this.model = model.getNameFinderSequenceModel();
    contextGenerator = factory.createContextGenerator();

    // TODO: We should deprecate this. And come up with a better solution!
    contextGenerator.addFeatureGenerator(
        new WindowFeatureGenerator(additionalContextFeatureGenerator, 8, 8));
  }

  /**
   * Creates the default feature generators used when no descriptor is supplied.
   *
   * @deprecated the default feature generation is now always included in the models and loaded
   *     if not by the factory. Subclasses using this methods should do the same.
   */
  @Deprecated
  static AdaptiveFeatureGenerator createFeatureGenerator() {
    return new CachedFeatureGenerator(
        new AdaptiveFeatureGenerator[]{
            new WindowFeatureGenerator(new TokenFeatureGenerator(), 2, 2),
            new WindowFeatureGenerator(new TokenClassFeatureGenerator(true), 2, 2),
            new OutcomePriorFeatureGenerator(),
            new PreviousMapFeatureGenerator(),
            new BigramNameFeatureGenerator(),
            new SentenceFeatureGenerator(true, false)
        });
  }

  /**
   * Creates a feature generator from a serialized descriptor, resolving the
   * resources it references through the given map.
   *
   * @param generatorDescriptor the serialized feature generator descriptor, or null
   * @param resources resources referenced by the descriptor, or null
   *
   * @return the feature generator, or null if generatorDescriptor was null
   *
   * @throws IOException if the descriptor cannot be parsed
   */
  private static AdaptiveFeatureGenerator createFeatureGenerator(
      byte[] generatorDescriptor, final Map<String, Object> resources)
      throws IOException {
    AdaptiveFeatureGenerator featureGenerator;

    if (generatorDescriptor != null) {
      featureGenerator = GeneratorFactory.create(new ByteArrayInputStream(
          generatorDescriptor), new FeatureGeneratorResourceProvider() {

        public Object getResource(String key) {
          if (resources != null) {
            return resources.get(key);
          }
          return null;
        }
      });
    } else {
      featureGenerator = null;
    }

    return featureGenerator;
  }

  public Span[] find(String[] tokens) {
    return find(tokens, EMPTY);
  }

  /**
   * Generates name tags for the given sequence, typically a sentence, returning
   * token spans for any identified names.
   *
   * @param tokens an array of the tokens or words of the sequence, typically a
   *     sentence.
   * @param additionalContext features which are based on context outside of the
   *     sentence but which should also be used.
   *
   * @return an array of spans for each of the names identified.
   */
  public Span[] find(String[] tokens, String[][] additionalContext) {

    additionalContextFeatureGenerator.setCurrentContext(additionalContext);

    bestSequence = model.bestSequence(tokens, additionalContext, contextGenerator, sequenceValidator);

    List<String> c = bestSequence.getOutcomes();

    contextGenerator.updateAdaptiveData(tokens, c.toArray(new String[c.size()]));
    Span[] spans = seqCodec.decode(c);
    spans = setProbs(spans);
    return spans;
  }

  /**
   * Forgets all adaptive data which was collected during previous calls to one
   * of the find methods.
   *
   * This method is typically called at the end of a document.
   */
  public void clearAdaptiveData() {
    contextGenerator.clearAdaptiveData();
  }

  /**
   * Populates the specified array with the probabilities of the last decoded
   * sequence. The sequence was determined based on the previous call to
   * <code>find</code>. The specified array should be at least as large as the
   * number of tokens in the previous call to <code>find</code>.
   *
   * @param probs An array used to hold the probabilities of the last decoded
   *     sequence.
   */
  public void probs(double[] probs) {
    bestSequence.getProbs(probs);
  }

  /**
   * Returns an array with the probabilities of the last decoded sequence. The
   * sequence was determined based on the previous call to <code>find</code>.
   *
   * @return An array with the same number of probabilities as tokens were sent
   *     to <code>find</code> when it was last called.
   */
  public double[] probs() {
    return bestSequence.getProbs();
  }

  /**
   * Replaces each span with a copy that carries its probability, computed as
   * the mean outcome probability over the span.
   *
   * @param spans the spans to annotate with probabilities
   * @return the same array, with each element replaced by a probability-carrying span
   */
  private Span[] setProbs(Span[] spans) {
    double[] probs = probs(spans);
    if (probs != null) {

      for (int i = 0; i < probs.length; i++) {
        double prob = probs[i];
        spans[i] = new Span(spans[i], prob);
      }
    }
    return spans;
  }

  /**
   * Returns an array of probabilities for each of the specified spans which is
   * the arithmetic mean of the probabilities for each of the outcomes which
   * make up the span.
   *
   * @param spans The spans of the names for which probabilities are desired.
   *
   * @return an array of probabilities for each of the specified spans.
   */
  public double[] probs(Span[] spans) {

    double[] sprobs = new double[spans.length];
    double[] probs = bestSequence.getProbs();

    for (int si = 0; si < spans.length; si++) {
      double p = 0;

      for (int oi = spans[si].getStart(); oi < spans[si].getEnd(); oi++) {
        p += probs[oi];
      }

      p /= spans[si].length();

      sprobs[si] = p;
    }

    return sprobs;
  }

  /**
   * Trains a name finder model.
   *
   * @param languageCode the language of the training data
   * @param type null or an override type for all types in the training data
   * @param samples the training data
   * @param trainParams machine learning train parameters
   * @param factory the factory providing the context generator and sequence codec
   *
   * @return the newly trained model
   *
   * @throws IOException if reading from the sample stream fails
   */
  public static TokenNameFinderModel train(String languageCode, String type,
      ObjectStream<NameSample> samples, TrainingParameters trainParams,
      TokenNameFinderFactory factory) throws IOException {

    String beamSizeString = trainParams.getSettings().get(BeamSearch.BEAM_SIZE_PARAMETER);

    int beamSize = NameFinderME.DEFAULT_BEAM_SIZE;
    if (beamSizeString != null) {
      beamSize = Integer.parseInt(beamSizeString);
    }

    Map<String, String> manifestInfoEntries = new HashMap<String, String>();

    MaxentModel nameFinderModel = null;

    SequenceClassificationModel<String> seqModel = null;

    TrainerType trainerType = TrainerFactory.getTrainerType(trainParams.getSettings());

    if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
      ObjectStream<Event> eventStream = new NameFinderEventStream(samples, type,
          factory.createContextGenerator(), factory.createSequenceCodec());

      EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams.getSettings(), manifestInfoEntries);
      nameFinderModel = trainer.train(eventStream);
    } // TODO: Maybe it is not a good idea, that these two don't use the context generator ?!
    // These also don't use the sequence codec ?!
    else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
      NameSampleSequenceStream ss = new NameSampleSequenceStream(samples, factory.createContextGenerator());

      EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer(
          trainParams.getSettings(), manifestInfoEntries);
      nameFinderModel = trainer.train(ss);
    } else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
      SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
          trainParams.getSettings(), manifestInfoEntries);

      NameSampleSequenceStream ss = new NameSampleSequenceStream(samples, factory.createContextGenerator(), false);
      seqModel = trainer.train(ss);
    } else {
      throw new IllegalStateException("Unexpected trainer type!");
    }

    if (seqModel != null) {
      return new TokenNameFinderModel(languageCode, seqModel, factory.getFeatureGenerator(),
          factory.getResources(), manifestInfoEntries, factory.getSequenceCodec(), factory);
    } else {
      return new TokenNameFinderModel(languageCode, nameFinderModel, beamSize, factory.getFeatureGenerator(),
          factory.getResources(), manifestInfoEntries, factory.getSequenceCodec(), factory);
    }
  }

  /**
   * Trains a name finder model.
   *
   * @param languageCode the language of the training data
   * @param type null or an override type for all types in the training data
   * @param samples the training data
   * @param trainParams machine learning train parameters
   * @param generator null or the feature generator
   * @param resources the resources for the name finder or null if none
   *
   * @return the newly trained model
   *
   * @throws IOException if reading from the sample stream fails
   * @deprecated use
   *     {@link NameFinderME#train(String, String, ObjectStream, TrainingParameters, TokenNameFinderFactory)}
   *     instead.
   */
  @Deprecated
  static TokenNameFinderModel train(String languageCode, String type, ObjectStream<NameSample> samples,
      TrainingParameters trainParams, AdaptiveFeatureGenerator generator, final Map<String, Object> resources)
      throws IOException {

    if (languageCode == null) {
      throw new IllegalArgumentException("languageCode must not be null!");
    }

    String beamSizeString = trainParams.getSettings().get(BeamSearch.BEAM_SIZE_PARAMETER);

    int beamSize = NameFinderME.DEFAULT_BEAM_SIZE;
    if (beamSizeString != null) {
      beamSize = Integer.parseInt(beamSizeString);
    }

    Map<String, String> manifestInfoEntries = new HashMap<String, String>();

    AdaptiveFeatureGenerator featureGenerator;

    if (generator != null) {
      featureGenerator = generator;
    } else {
      featureGenerator = createFeatureGenerator();
    }

    MaxentModel nameFinderModel = null;

    SequenceClassificationModel<String> seqModel = null;

    TrainerType trainerType = TrainerFactory.getTrainerType(trainParams.getSettings());

    if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
      ObjectStream<Event> eventStream = new NameFinderEventStream(samples, type,
          new DefaultNameContextGenerator(featureGenerator), new BioCodec());

      EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams.getSettings(), manifestInfoEntries);
      nameFinderModel = trainer.train(eventStream);
    } else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
      NameSampleSequenceStream ss = new NameSampleSequenceStream(samples, featureGenerator);

      EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer(
          trainParams.getSettings(), manifestInfoEntries);
      nameFinderModel = trainer.train(ss);
    } else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
      SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
          trainParams.getSettings(), manifestInfoEntries);

      NameSampleSequenceStream ss = new NameSampleSequenceStream(samples, featureGenerator, false);
      seqModel = trainer.train(ss);
    } else {
      throw new IllegalStateException("Unexpected trainer type!");
    }

    // TODO: Pass the sequence codec down to the model! We will just store the class
    // name in the model, and then always use the extension loader to create it!
    // The cmd line interface, will replace shortcuts with actual class names.
    // depending on which one is not null!
    if (seqModel != null) {
      return new TokenNameFinderModel(languageCode, seqModel, null,
          resources, manifestInfoEntries, new BioCodec(), new TokenNameFinderFactory());
    } else {
      return new TokenNameFinderModel(languageCode, nameFinderModel, beamSize, null,
          resources, manifestInfoEntries, new BioCodec(), new TokenNameFinderFactory());
    }
  }

  /**
   * Trains a name finder model.
   *
   * @param languageCode the language of the training data
   * @param type null or an override type for all types in the training data
   * @param samples the training data
   * @param trainParams machine learning train parameters
   * @param featureGeneratorBytes descriptor to configure the feature generation
   *     or null
   * @param resources the resources for the name finder or null if none
   *
   * @return the newly trained model
   *
   * @throws IOException if reading from the sample stream fails
   * @deprecated use
   *     {@link NameFinderME#train(String, String, ObjectStream, TrainingParameters, TokenNameFinderFactory)}
   *     instead.
   */
  @Deprecated
  static TokenNameFinderModel train(String languageCode, String type,
      ObjectStream<NameSample> samples, TrainingParameters trainParams,
      byte[] featureGeneratorBytes, final Map<String, Object> resources)
      throws IOException {

    TokenNameFinderModel model = train(languageCode, type, samples, trainParams,
        createFeatureGenerator(featureGeneratorBytes, resources), resources);

    if (featureGeneratorBytes != null) {
      model = model.updateFeatureGenerator(featureGeneratorBytes);
    }

    return model;
  }

  /**
   * Gets the name type from the outcome.
   *
   * @param outcome the outcome
   * @return the name type, or null if not set
   */
  static final String extractNameType(String outcome) {
    Matcher matcher = TYPED_OUTCOME_PATTERN.matcher(outcome);
    if (matcher.matches()) {
      String nameType = matcher.group(1);
      return nameType;
    }

    return null;
  }

  /**
   * Removes spans which are intersecting or crossing in any way.
   *
   * <p>
   * The following rules are used to remove the spans:<br>
   * Identical spans: The first span in the array after sorting it remains<br>
   * Intersecting spans: The first span after sorting remains<br>
   * Contained spans: All spans which are contained by another are removed<br>
   *
   * @param spans the spans to filter
   *
   * @return non-overlapping spans
   */
  public static Span[] dropOverlappingSpans(Span spans[]) {

    List<Span> sortedSpans = new ArrayList<Span>(spans.length);
    Collections.addAll(sortedSpans, spans);
    Collections.sort(sortedSpans);

    Iterator<Span> it = sortedSpans.iterator();

    Span lastSpan = null;

    while (it.hasNext()) {
      Span span = it.next();

      if (lastSpan != null) {
        if (lastSpan.intersects(span)) {
          it.remove();
          // keep comparing subsequent spans against the retained (earlier) span
          span = lastSpan;
        }
      }

      lastSpan = span;
    }

    return sortedSpans.toArray(new Span[sortedSpans.size()]);
  }
}