blob: 7e6de666f24a963723790412c81dbc6a1f03b7e8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package opennlp.addons.mallet;
import java.io.IOException;
import java.util.Map;
import java.util.regex.Pattern;
import opennlp.tools.ml.AbstractSequenceTrainer;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.Sequence;
import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.ml.model.SequenceStream;
import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFOptimizableByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByValueGradients;
import cc.mallet.fst.Transducer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
// TODO: Train against the abstract Transducer type, since Mallet offers both CRF and HMM.
// An HMM does not need any generated features; it is still open how to model that cleanly
// (perhaps with a dummy feature generator).
/**
 * Trains a Mallet {@link CRF} from an OpenNLP {@link SequenceStream} and returns it wrapped in a
 * {@link TransducerModel}.
 */
public class CRFTrainer extends AbstractSequenceTrainer {

  /** Label passed to Mallet as the start label; its state is the only allowed initial state. */
  private static final String OTHER_LABEL = "other";

  /**
   * Transitions to exclude from the generated state machine: any move from "other" straight
   * into a "*-cont" label.
   *
   * <p>This is a regular expression, not a glob. The previous value {@code "other,*-cont"} meant
   * "other, zero or more commas, -cont" and could never match a label pair such as
   * {@code "other,person-cont"}. NOTE(review): this assumes Mallet matches the pattern against
   * consecutive labels joined by ',' — confirm against the Mallet version in use.
   */
  private static final Pattern FORBIDDEN_TRANSITIONS = Pattern.compile("other,.*-cont");

  /** Returns the Markov orders used when generating CRF states (orders 0 and 1). */
  private static int[] getOrders() {
    return new int[] {0, 1};
  }

  /**
   * Converts every sequence in {@code sequences} into Mallet training data, builds the CRF
   * topology and optimises its weights until the trainer stops.
   *
   * @param sequences the training sequences; read until {@code read()} returns {@code null}
   * @return the trained CRF wrapped in a {@link TransducerModel}
   * @throws IOException if reading from {@code sequences} fails
   */
  // TODO: the trainer interface may need to change here (original note was left incomplete).
  // TODO: unlike the other sequence models, the feature generation still needs to be integrated.
  @Override
  public SequenceClassificationModel<String> doTrain(SequenceStream sequences)
      throws IOException {
    InstanceList trainingData = toInstanceList(sequences);
    CRF crf = buildCrf(trainingData);
    fitCrf(crf, trainingData);
    return new TransducerModel(crf);
  }

  /** Reads all sequences and converts each into a Mallet feature/label sequence instance. */
  private static InstanceList toInstanceList(SequenceStream sequences) throws IOException {
    Alphabet dataAlphabet = new Alphabet();
    LabelAlphabet targetAlphabet = new LabelAlphabet();
    InstanceList trainingData = new InstanceList(dataAlphabet, targetAlphabet);

    int nameIndex = 0;
    Sequence sequence;
    while ((sequence = sequences.read()) != null) {
      Event[] events = sequence.getEvents();
      FeatureVector[] featureVectors = new FeatureVector[events.length];
      Label[] outcomes = new Label[events.length];

      for (int eventIndex = 0; eventIndex < events.length; eventIndex++) {
        Event event = events[eventIndex];
        String[] context = event.getContext();
        int[] featureIndices = new int[context.length];
        for (int featureIndex = 0; featureIndex < context.length; featureIndex++) {
          // Grow the alphabet on the fly so that every seen feature gets an index.
          featureIndices[featureIndex] = dataAlphabet.lookupIndex(context[featureIndex], true);
        }
        // NOTE(review): a context may list the same feature more than once; it is unverified
        // whether FeatureVector(Alphabet, int[]) handles duplicates as intended — confirm.
        featureVectors[eventIndex] = new FeatureVector(dataAlphabet, featureIndices);
        outcomes[eventIndex] = targetAlphabet.lookupLabel(event.getOutcome(), true);
      }

      trainingData.add(new Instance(new FeatureVectorSequence(featureVectors),
          new LabelSequence(outcomes), "name" + nameIndex++, "source"));
    }
    return trainingData;
  }

  /** Creates the CRF states/transitions and restricts decoding to start in the start state. */
  private static CRF buildCrf(InstanceList trainingData) {
    CRF crf = new CRF(trainingData.getDataAlphabet(), trainingData.getTargetAlphabet());
    String startStateName = crf.addOrderNStates(trainingData, getOrders(),
        (boolean[]) null,
        OTHER_LABEL,
        FORBIDDEN_TRANSITIONS,
        null, // no allowed-transition pattern
        true);

    // Every state is impossible as an initial state except the start state.
    for (int i = 0; i < crf.numStates(); i++) {
      crf.getState(i).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT);
    }
    crf.getState(startStateName).setInitialWeight(0.0);

    crf.setWeightsDimensionAsIn(trainingData, true);
    return crf;
  }

  /** Optimises the CRF weights on the label-likelihood objective. */
  private static void fitCrf(CRF crf, InstanceList trainingData) {
    CRFOptimizableByLabelLikelihood labelLikelihood =
        new CRFOptimizableByLabelLikelihood(crf, trainingData);
    Optimizable.ByGradientValue[] objectives =
        new Optimizable.ByGradientValue[] {labelLikelihood};

    // Per the original note, CRFTrainerByValueGradients uses L-BFGS by default.
    CRFTrainerByValueGradients crfTrainer = new CRFTrainerByValueGradients(crf, objectives);
    crfTrainer.setMaxResets(0);
    // Integer.MAX_VALUE: no iteration cap; training runs until the trainer stops on its own.
    crfTrainer.train(trainingData, Integer.MAX_VALUE);
  }
}