blob: 6dc359105af7da8a91bc64bfaa0949904cbbdfb0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.lemmatizer;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import opennlp.tools.ml.BeamSearch;
import opennlp.tools.ml.EventModelSequenceTrainer;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.SequenceTrainer;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.TrainerFactory.TrainerType;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.Sequence;
import opennlp.tools.util.SequenceValidator;
import opennlp.tools.util.StringUtil;
import opennlp.tools.util.TrainingParameters;
/**
* A probabilistic lemmatizer. Tries to predict the induced permutation class
* for each word depending on its surrounding context. Based on
* Grzegorz Chrupała. 2008. Towards a Machine-Learning Architecture
* for Lexical Functional Grammar Parsing. PhD dissertation, Dublin City University.
* http://grzegorz.chrupala.me/papers/phd-single.pdf
*/
public class LemmatizerME implements Lemmatizer {

  /** Number of candidate sequences requested by {@link #lemmatize(List, List)}. */
  public static final int LEMMA_NUMBER = 29;

  /** Beam size used when neither the model manifest nor the training parameters provide one. */
  public static final int DEFAULT_BEAM_SIZE = 3;

  protected int beamSize;

  // Last sequence decoded by predictSES(); stays null until that method has been called.
  private Sequence bestSequence;

  private final SequenceClassificationModel<String> model;
  private final LemmatizerContextGenerator contextGenerator;
  private final SequenceValidator<String> sequenceValidator;

  /**
   * Initializes the current instance with the provided model.
   * The beam size is read from the model manifest entry
   * {@link BeamSearch#BEAM_SIZE_PARAMETER}; if that entry is absent the
   * default beam size of {@value #DEFAULT_BEAM_SIZE} is used.
   *
   * @param model the model
   * @throws IllegalArgumentException if the model does not provide a
   *     sequence classification model
   */
  public LemmatizerME(LemmatizerModel model) {
    LemmatizerFactory factory = model.getFactory();

    int effectiveBeamSize = LemmatizerME.DEFAULT_BEAM_SIZE;
    String beamSizeString = model.getManifestProperty(BeamSearch.BEAM_SIZE_PARAMETER);
    if (beamSizeString != null) {
      effectiveBeamSize = Integer.parseInt(beamSizeString);
    }

    contextGenerator = factory.getContextGenerator();
    beamSize = effectiveBeamSize;
    sequenceValidator = factory.getSequenceValidator();

    // Fetch the sequence model once. The previous fallback branch wrapped the very same
    // (null) value in a BeamSearch, which could never produce a usable lemmatizer;
    // fail fast with a clear message instead.
    SequenceClassificationModel<String> sequenceModel = model.getLemmatizerSequenceModel();
    if (sequenceModel == null) {
      throw new IllegalArgumentException(
          "The provided LemmatizerModel does not contain a sequence classification model");
    }
    this.model = sequenceModel;
  }

  /**
   * Lemmatizes the tokens by predicting the lemma class of each token with
   * {@link #predictSES(String[], String[])} and decoding the classes with
   * {@link #decodeLemmas(String[], String[])}.
   *
   * @param toks the array of tokens
   * @param tags the array of pos tags
   * @return the array of decoded lemmas, one per token
   */
  public String[] lemmatize(String[] toks, String[] tags) {
    String[] ses = predictSES(toks, tags);
    return decodeLemmas(toks, ses);
  }

  /**
   * Produces the decoded lemmas of the {@link #LEMMA_NUMBER} best sequences
   * requested from the model via {@link #predictLemmas(int, String[], String[])}.
   * This method does not update the sequence reported by {@link #probs()}.
   *
   * @param toks the list of tokens
   * @param tags the list of pos tags
   * @return one list per sequence returned by the model; each inner list holds
   *     the decoded lemmas of that sequence
   */
  @Override public List<List<String>> lemmatize(List<String> toks,
      List<String> tags) {
    String[] tokens = toks.toArray(new String[0]);
    String[] posTags = tags.toArray(new String[0]);
    String[][] allLemmas = predictLemmas(LEMMA_NUMBER, tokens, posTags);
    List<List<String>> predictedLemmas = new ArrayList<>(allLemmas.length);
    for (String[] sequenceLemmas : allLemmas) {
      predictedLemmas.add(Arrays.asList(sequenceLemmas));
    }
    return predictedLemmas;
  }

  /**
   * Predict Short Edit Script (automatically induced lemma class).
   * The decoded sequence is remembered so that {@link #probs()} and
   * {@link #probs(double[])} can report its probabilities.
   *
   * @param toks the array of tokens
   * @param tags the array of pos tags
   * @return an array containing the lemma classes
   */
  public String[] predictSES(String[] toks, String[] tags) {
    bestSequence = model.bestSequence(toks, new Object[] {tags}, contextGenerator, sequenceValidator);
    List<String> ses = bestSequence.getOutcomes();
    return ses.toArray(new String[0]);
  }

  /**
   * Predict all possible lemmas (using a default upper bound).
   *
   * @param numLemmas the number of sequences to request from the model
   * @param toks the tokens
   * @param tags the postags
   * @return a two-dimensional array with one row per sequence returned by the
   *     model; each row holds the decoded lemmas of that sequence
   */
  public String[][] predictLemmas(int numLemmas, String[] toks, String[] tags) {
    Sequence[] bestSequences = model.bestSequences(numLemmas, toks, new Object[] {tags},
        contextGenerator, sequenceValidator);
    String[][] allLemmas = new String[bestSequences.length][];
    for (int i = 0; i < allLemmas.length; i++) {
      List<String> ses = bestSequences[i].getOutcomes();
      allLemmas[i] = decodeLemmas(toks, ses.toArray(new String[0]));
    }
    return allLemmas;
  }

  /**
   * Decodes the lemma from the word and the induced lemma class.
   * Each token is lower-cased before the edit script is applied; an empty
   * decoding result is replaced by {@code "_"}.
   *
   * @param toks the array of tokens
   * @param preds the predicted lemma classes; must contain at least as many
   *     entries as {@code toks}
   * @return the array of decoded lemmas
   */
  public static String[] decodeLemmas(String[] toks, String[] preds) {
    String[] lemmas = new String[toks.length];
    for (int i = 0; i < toks.length; i++) {
      // NOTE(review): toLowerCase() without a Locale uses the JVM default locale.
      String lemma = StringUtil.decodeShortestEditScript(toks[i].toLowerCase(), preds[i]);
      lemmas[i] = lemma.isEmpty() ? "_" : lemma;
    }
    return lemmas;
  }

  /**
   * Encodes each token/lemma pair as a shortest edit script (lemma class).
   * An empty edit script is replaced by {@code "_"}.
   *
   * @param toks the array of tokens
   * @param lemmas the array of lemmas; must contain at least as many entries
   *     as {@code toks}
   * @return the array of lemma classes, one per token
   */
  public static String[] encodeLemmas(String[] toks, String[] lemmas) {
    String[] sesArray = new String[toks.length];
    for (int i = 0; i < toks.length; i++) {
      String ses = StringUtil.getShortestEditScript(toks[i], lemmas[i]);
      sesArray[i] = ses.isEmpty() ? "_" : ses;
    }
    return sesArray;
  }

  /**
   * Requests the {@link #DEFAULT_BEAM_SIZE} best sequences from the model.
   * Note that the constant is used here, not the {@link #beamSize} field.
   *
   * @param sentence the array of tokens
   * @param tags the array of pos tags
   * @return the sequences returned by the model
   */
  public Sequence[] topKSequences(String[] sentence, String[] tags) {
    return model.bestSequences(DEFAULT_BEAM_SIZE, sentence,
        new Object[] { tags }, contextGenerator, sequenceValidator);
  }

  /**
   * Requests the {@link #DEFAULT_BEAM_SIZE} best sequences from the model,
   * passing a minimum sequence score through to the model.
   *
   * @param sentence the array of tokens
   * @param tags the array of pos tags
   * @param minSequenceScore the minimum sequence score handed to the model
   * @return the sequences returned by the model
   */
  public Sequence[] topKSequences(String[] sentence, String[] tags, double minSequenceScore) {
    return model.bestSequences(DEFAULT_BEAM_SIZE, sentence, new Object[] { tags }, minSequenceScore,
        contextGenerator, sequenceValidator);
  }

  /**
   * Populates the specified array with the probabilities of the last decoded sequence. The
   * sequence was determined based on the previous call to
   * {@link #lemmatize(String[], String[])} or {@link #predictSES(String[], String[])}. The
   * specified array should be at least as large as the number of tokens in that call.
   * Calling this method before any sequence has been decoded results in a
   * {@link NullPointerException}.
   *
   * @param probs An array used to hold the probabilities of the last decoded sequence.
   */
  public void probs(double[] probs) {
    bestSequence.getProbs(probs);
  }

  /**
   * Returns an array with the probabilities of the last decoded sequence. The
   * sequence was determined based on the previous call to
   * {@link #lemmatize(String[], String[])} or {@link #predictSES(String[], String[])}.
   * Calling this method before any sequence has been decoded results in a
   * {@link NullPointerException}.
   *
   * @return An array with the probabilities of the last decoded sequence.
   */
  public double[] probs() {
    return bestSequence.getProbs();
  }

  /**
   * Trains a lemmatizer model. The trainer is selected from the training
   * parameters via {@link TrainerFactory#getTrainerType(TrainingParameters)}.
   *
   * @param languageCode the language code stored in the resulting model
   * @param samples the training samples
   * @param trainParams the training parameters; the beam size is read from
   *     {@link BeamSearch#BEAM_SIZE_PARAMETER} and defaults to {@value #DEFAULT_BEAM_SIZE}
   * @param posFactory the factory providing the context generator; it is also
   *     stored in the resulting model
   * @return the trained model
   * @throws IOException if an I/O error occurs during training
   * @throws IllegalArgumentException if the trainer type is not supported
   */
  public static LemmatizerModel train(String languageCode,
      ObjectStream<LemmaSample> samples, TrainingParameters trainParams,
      LemmatizerFactory posFactory) throws IOException {

    int beamSize = trainParams.getIntParameter(BeamSearch.BEAM_SIZE_PARAMETER,
        LemmatizerME.DEFAULT_BEAM_SIZE);

    LemmatizerContextGenerator contextGenerator = posFactory.getContextGenerator();
    Map<String, String> manifestInfoEntries = new HashMap<>();
    TrainerType trainerType = TrainerFactory.getTrainerType(trainParams);

    // Exactly one of these is set, depending on the trainer type.
    MaxentModel lemmatizerModel = null;
    SequenceClassificationModel<String> seqLemmatizerModel = null;

    if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
      ObjectStream<Event> es = new LemmaSampleEventStream(samples, contextGenerator);
      EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams,
          manifestInfoEntries);
      lemmatizerModel = trainer.train(es);
    }
    else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
      LemmaSampleSequenceStream ss = new LemmaSampleSequenceStream(samples, contextGenerator);
      EventModelSequenceTrainer trainer =
          TrainerFactory.getEventModelSequenceTrainer(trainParams, manifestInfoEntries);
      lemmatizerModel = trainer.train(ss);
    }
    else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
      SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
          trainParams, manifestInfoEntries);
      // TODO: This will probably cause issue, since the feature generator uses the outcomes array
      LemmaSampleSequenceStream ss = new LemmaSampleSequenceStream(samples, contextGenerator);
      seqLemmatizerModel = trainer.train(ss);
    }
    else {
      throw new IllegalArgumentException("Trainer type is not supported: " + trainerType);
    }

    // Only the event-model variant of the LemmatizerModel constructor takes the beam size.
    if (lemmatizerModel != null) {
      return new LemmatizerModel(languageCode, lemmatizerModel, beamSize, manifestInfoEntries, posFactory);
    }
    else {
      return new LemmatizerModel(languageCode, seqLemmatizerModel, manifestInfoEntries, posFactory);
    }
  }

  /**
   * Requests the {@link #DEFAULT_BEAM_SIZE} best lemma class sequences from the model.
   * Performs the same model call as {@link #topKSequences(String[], String[])}.
   *
   * @param sentence the array of tokens
   * @param tags the array of pos tags
   * @return the sequences returned by the model
   */
  public Sequence[] topKLemmaClasses(String[] sentence, String[] tags) {
    return model.bestSequences(DEFAULT_BEAM_SIZE, sentence,
        new Object[] { tags }, contextGenerator, sequenceValidator);
  }

  /**
   * Requests the {@link #DEFAULT_BEAM_SIZE} best lemma class sequences from the model,
   * passing a minimum sequence score through to the model. Performs the same model
   * call as {@link #topKSequences(String[], String[], double)}.
   *
   * @param sentence the array of tokens
   * @param tags the array of pos tags
   * @param minSequenceScore the minimum sequence score handed to the model
   * @return the sequences returned by the model
   */
  public Sequence[] topKLemmaClasses(String[] sentence, String[] tags, double minSequenceScore) {
    return model.bestSequences(DEFAULT_BEAM_SIZE, sentence, new Object[] { tags }, minSequenceScore,
        contextGenerator, sequenceValidator);
  }
}