blob: e713d8332c272ec9811dda113a85046cdd1d2f80 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package opennlp.addons.mallet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.util.BeamSearchContextGenerator;
import opennlp.tools.util.SequenceValidator;
import opennlp.tools.util.model.SerializableArtifact;
import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Sequence;
public class TransducerModel<T> implements SequenceClassificationModel<T>, SerializableArtifact {
private Transducer model;
public TransducerModel(Transducer model) {
this.model = model;
}
Transducer getModel() {
return model;
}
public opennlp.tools.util.Sequence bestSequence(T[] sequence,
Object[] additionalContext, BeamSearchContextGenerator<T> cg,
SequenceValidator<T> validator) {
return bestSequences(1, sequence, additionalContext, cg, validator)[0];
}
@Override
public opennlp.tools.util.Sequence[] bestSequences(int numSequences,
T[] sequence, Object[] additionalContext, double minSequenceScore,
BeamSearchContextGenerator<T> cg, SequenceValidator<T> validator) {
// TODO: How to implement min score filtering here?
return bestSequences(numSequences, sequence, additionalContext, cg, validator);
}
public opennlp.tools.util.Sequence[] bestSequences(int numSequences,
T[] sequence, Object[] additionalContext,
BeamSearchContextGenerator<T> cg, SequenceValidator<T> validator) {
// TODO: CRF.getInputAlphabet
Alphabet dataAlphabet = model.getInputPipe().getAlphabet();
FeatureVector featureVectors[] = new FeatureVector[sequence.length];
// TODO:: The feature generator needs to get the detected sequence in the end
// to update the adaptive data!
String prior[] = new String[sequence.length];
Arrays.fill(prior, "s"); // <- HACK, this will degrade performance!
// TODO: Put together a feature generator which doesn't fail if outcomes is null!
for (int i = 0; i < sequence.length; i++) {
String features[] = cg.getContext(i, sequence, null, additionalContext);
List<Integer> malletFeatureList = new ArrayList<>(features.length);
for (int featureIndex = 0; featureIndex < features.length; featureIndex++) {
if (dataAlphabet.contains(features[featureIndex])) {
malletFeatureList.add(dataAlphabet.lookupIndex(features[featureIndex]));
}
}
int malletFeatures[] = new int[malletFeatureList.size()];
for (int k = 0; k < malletFeatureList.size(); k++) {
malletFeatures[k] = malletFeatureList.get(k);
}
// Note: Might contain a feature more than once ... will that work ?!
featureVectors[i] = new FeatureVector(dataAlphabet, malletFeatures);
}
FeatureVectorSequence malletSequence = new FeatureVectorSequence(featureVectors);
Sequence[] answers = null;
if (numSequences == 1) {
answers = new Sequence[1];
answers[0] = model.transduce(malletSequence);
} else {
MaxLatticeDefault lattice = new MaxLatticeDefault(model, malletSequence, null, 3);
answers = lattice.bestOutputSequences(numSequences).toArray(new Sequence[0]);
}
opennlp.tools.util.Sequence[] outcomeSequences = new opennlp.tools.util.Sequence[answers.length];
for (int i = 0; i < answers.length; i++) {
Sequence seq = answers[i];
List<String> outcomes = new ArrayList<>(seq.size());
for (int j = 0; j < seq.size(); j++) {
outcomes.add(seq.get(j).toString());
}
outcomeSequences[i] = new opennlp.tools.util.Sequence(outcomes);
}
return outcomeSequences;
}
@Override
public Class<?> getArtifactSerializerClass() {
return TransducerModelSerializer.class;
}
@Override
public String[] getOutcomes() {
Alphabet targetAlphabet = model.getInputPipe().getTargetAlphabet();
String outcomes[] = new String[targetAlphabet.size()];
for (int i = 0; i < targetAlphabet.size(); i++) {
outcomes[i] = targetAlphabet.lookupObject(i).toString();
}
return outcomes;
}
}