| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.opennlp.ml.perceptron; |
| |
| import org.apache.opennlp.ml.model.AbstractModel; |
| import org.apache.opennlp.ml.model.DataIndexer; |
| import org.apache.opennlp.ml.model.EvalParameters; |
| import org.apache.opennlp.ml.model.MutableContext; |
| |
| /** |
| * Trains models using the perceptron algorithm. Each outcome is represented as |
| * a binary perceptron classifier. This supports standard (integer) weighting as well |
| * average weighting as described in: |
| * Discriminative Training Methods for Hidden Markov Models: Theory and Experiments |
| * with the Perceptron Algorithm. Michael Collins, EMNLP 2002. |
| * |
| */ |
| public class PerceptronTrainer { |
| |
| public static final double TOLERANCE_DEFAULT = .00001; |
| |
| /** Number of unique events which occurred in the event set. */ |
| private int numUniqueEvents; |
| /** Number of events in the event set. */ |
| private int numEvents; |
| |
| /** Number of predicates. */ |
| private int numPreds; |
| /** Number of outcomes. */ |
| private int numOutcomes; |
| /** Records the array of predicates seen in each event. */ |
| private int[][] contexts; |
| |
| /** The value associates with each context. If null then context values are assumes to be 1. */ |
| private float[][] values; |
| |
| /** List of outcomes for each event i, in context[i]. */ |
| private int[] outcomeList; |
| |
| /** Records the num of times an event has been seen for each event i, in context[i]. */ |
| private int[] numTimesEventsSeen; |
| |
| /** Stores the String names of the outcomes. The GIS only tracks outcomes |
| as ints, and so this array is needed to save the model to disk and |
| thereby allow users to know what the outcome was in human |
| understandable terms. */ |
| private String[] outcomeLabels; |
| |
| /** Stores the String names of the predicates. The GIS only tracks |
| predicates as ints, and so this array is needed to save the model to |
| disk and thereby allow users to know what the outcome was in human |
| understandable terms. */ |
| private String[] predLabels; |
| |
| private boolean printMessages = true; |
| |
| private double tolerance = TOLERANCE_DEFAULT; |
| |
| private Double stepSizeDecrease; |
| |
| private boolean useSkippedlAveraging; |
| |
| /** |
| * Specifies the tolerance. If the change in training set accuracy |
| * is less than this, stop iterating. |
| * |
| * @param tolerance |
| */ |
| public void setTolerance(double tolerance) { |
| |
| if (tolerance < 0) |
| throw new IllegalArgumentException("tolerance must be a positive number!"); |
| |
| this.tolerance = tolerance; |
| } |
| |
| /** |
| * Enables and sets step size decrease. The step size is |
| * decreased every iteration by the specified value. |
| * |
| * @param decrease - step size decrease in percent |
| */ |
| public void setStepSizeDecrease(double decrease) { |
| |
| if (decrease < 0 || decrease > 100) |
| throw new IllegalArgumentException("decrease must be between 0 and 100"); |
| |
| stepSizeDecrease = decrease; |
| } |
| |
| /** |
| * Enables skipped averaging, this flag changes the standard |
| * averaging to special averaging instead. |
| * <p> |
| * If we are doing averaging, and the current iteration is one |
| * of the first 20 or it is a perfect square, then updated the |
| * summed parameters. |
| * <p> |
| * The reason we don't take all of them is that the parameters change |
| * less toward the end of training, so they drown out the contributions |
| * of the more volatile early iterations. The use of perfect |
| * squares allows us to sample from successively farther apart iterations. |
| * |
| * @param averaging averaging flag |
| */ |
| public void setSkippedAveraging(boolean averaging) { |
| useSkippedlAveraging = averaging; |
| } |
| |
| public AbstractModel trainModel(int iterations, DataIndexer di, int cutoff) { |
| return trainModel(iterations,di,cutoff,true); |
| } |
| |
| public AbstractModel trainModel(int iterations, DataIndexer di, int cutoff, boolean useAverage) { |
| display("Incorporating indexed data for training... \n"); |
| contexts = di.getContexts(); |
| values = di.getValues(); |
| numTimesEventsSeen = di.getNumTimesEventsSeen(); |
| numEvents = di.getNumEvents(); |
| numUniqueEvents = contexts.length; |
| |
| outcomeLabels = di.getOutcomeLabels(); |
| outcomeList = di.getOutcomeList(); |
| |
| predLabels = di.getPredLabels(); |
| numPreds = predLabels.length; |
| numOutcomes = outcomeLabels.length; |
| |
| display("done.\n"); |
| |
| display("\tNumber of Event Tokens: " + numUniqueEvents + "\n"); |
| display("\t Number of Outcomes: " + numOutcomes + "\n"); |
| display("\t Number of Predicates: " + numPreds + "\n"); |
| |
| display("Computing model parameters...\n"); |
| |
| MutableContext[] finalParameters = findParameters(iterations, useAverage); |
| |
| display("...done.\n"); |
| |
| /*************** Create and return the model ******************/ |
| return new PerceptronModel(finalParameters, predLabels, outcomeLabels); |
| } |
| |
| private MutableContext[] findParameters (int iterations, boolean useAverage) { |
| |
| display("Performing " + iterations + " iterations.\n"); |
| |
| int[] allOutcomesPattern= new int[numOutcomes]; |
| for (int oi = 0; oi < numOutcomes; oi++) |
| allOutcomesPattern[oi] = oi; |
| |
| /** Stores the estimated parameter value of each predicate during iteration. */ |
| MutableContext[] params = new MutableContext[numPreds]; |
| for (int pi = 0; pi < numPreds; pi++) { |
| params[pi] = new MutableContext(allOutcomesPattern,new double[numOutcomes]); |
| for (int aoi=0;aoi<numOutcomes;aoi++) |
| params[pi].setParameter(aoi, 0.0); |
| } |
| |
| EvalParameters evalParams = new EvalParameters(params,numOutcomes); |
| |
| /** Stores the sum of parameter values of each predicate over many iterations. */ |
| MutableContext[] summedParams = new MutableContext[numPreds]; |
| if (useAverage) { |
| for (int pi = 0; pi < numPreds; pi++) { |
| summedParams[pi] = new MutableContext(allOutcomesPattern,new double[numOutcomes]); |
| for (int aoi=0;aoi<numOutcomes;aoi++) |
| summedParams[pi].setParameter(aoi, 0.0); |
| } |
| } |
| |
| // Keep track of the previous three accuracies. The difference of |
| // the mean of these and the current training set accuracy is used |
| // with tolerance to decide whether to stop. |
| double prevAccuracy1 = 0.0; |
| double prevAccuracy2 = 0.0; |
| double prevAccuracy3 = 0.0; |
| |
| // A counter for the denominator for averaging. |
| int numTimesSummed = 0; |
| |
| double stepsize = 1; |
| for (int i = 1; i <= iterations; i++) { |
| |
| // Decrease the stepsize by a small amount. |
| if (stepSizeDecrease != null) |
| stepsize *= 1 - stepSizeDecrease; |
| |
| displayIteration(i); |
| |
| int numCorrect = 0; |
| |
| for (int ei = 0; ei < numUniqueEvents; ei++) { |
| int targetOutcome = outcomeList[ei]; |
| |
| for (int ni=0; ni<this.numTimesEventsSeen[ei]; ni++) { |
| |
| // Compute the model's prediction according to the current parameters. |
| double[] modelDistribution = new double[numOutcomes]; |
| if (values != null) |
| PerceptronModel.eval(contexts[ei], values[ei], modelDistribution, evalParams, false); |
| else |
| PerceptronModel.eval(contexts[ei], null, modelDistribution, evalParams, false); |
| |
| int maxOutcome = maxIndex(modelDistribution); |
| |
| // If the predicted outcome is different from the target |
| // outcome, do the standard update: boost the parameters |
| // associated with the target and reduce those associated |
| // with the incorrect predicted outcome. |
| if (maxOutcome != targetOutcome) { |
| for (int ci = 0; ci < contexts[ei].length; ci++) { |
| int pi = contexts[ei][ci]; |
| if (values == null) { |
| params[pi].updateParameter(targetOutcome, stepsize); |
| params[pi].updateParameter(maxOutcome, -stepsize); |
| } else { |
| params[pi].updateParameter(targetOutcome, stepsize*values[ei][ci]); |
| params[pi].updateParameter(maxOutcome, -stepsize*values[ei][ci]); |
| } |
| } |
| } |
| |
| // Update the counts for accuracy. |
| if (maxOutcome == targetOutcome) |
| numCorrect++; |
| } |
| } |
| |
| // Calculate the training accuracy and display. |
| double trainingAccuracy = (double) numCorrect / numEvents; |
| if (i < 10 || (i%10) == 0) |
| display(". (" + numCorrect + "/" + numEvents+") " + trainingAccuracy + "\n"); |
| |
| // TODO: Make averaging configurable !!! |
| |
| boolean doAveraging; |
| |
| if (useAverage && useSkippedlAveraging && (i < 20 || isPerfectSquare(i))) { |
| doAveraging = true; |
| } |
| else if (useAverage) { |
| doAveraging = true; |
| } |
| else { |
| doAveraging = false; |
| } |
| |
| if (doAveraging) { |
| numTimesSummed++; |
| for (int pi = 0; pi < numPreds; pi++) |
| for (int aoi=0;aoi<numOutcomes;aoi++) |
| summedParams[pi].updateParameter(aoi, params[pi].getParameters()[aoi]); |
| } |
| |
| // If the tolerance is greater than the difference between the |
| // current training accuracy and all of the previous three |
| // training accuracies, stop training. |
| if (Math.abs(prevAccuracy1-trainingAccuracy) < tolerance |
| && Math.abs(prevAccuracy2-trainingAccuracy) < tolerance |
| && Math.abs(prevAccuracy3-trainingAccuracy) < tolerance) { |
| display("Stopping: change in training set accuracy less than " + tolerance + "\n"); |
| break; |
| } |
| |
| // Update the previous training accuracies. |
| prevAccuracy1 = prevAccuracy2; |
| prevAccuracy2 = prevAccuracy3; |
| prevAccuracy3 = trainingAccuracy; |
| } |
| |
| // Output the final training stats. |
| trainingStats(evalParams); |
| |
| // Create averaged parameters |
| if (useAverage) { |
| for (int pi = 0; pi < numPreds; pi++) |
| for (int aoi=0;aoi<numOutcomes;aoi++) |
| summedParams[pi].setParameter(aoi, summedParams[pi].getParameters()[aoi]/numTimesSummed); |
| |
| return summedParams; |
| |
| } else { |
| |
| return params; |
| |
| } |
| |
| } |
| |
| private double trainingStats (EvalParameters evalParams) { |
| int numCorrect = 0; |
| |
| for (int ei = 0; ei < numUniqueEvents; ei++) { |
| for (int ni=0;ni<this.numTimesEventsSeen[ei];ni++) { |
| |
| double[] modelDistribution = new double[numOutcomes]; |
| |
| if (values != null) |
| PerceptronModel.eval(contexts[ei], values[ei], modelDistribution, evalParams,false); |
| else |
| PerceptronModel.eval(contexts[ei], null, modelDistribution, evalParams, false); |
| |
| int max = maxIndex(modelDistribution); |
| if (max == outcomeList[ei]) |
| numCorrect++; |
| } |
| } |
| double trainingAccuracy = (double) numCorrect / numEvents; |
| display("Stats: (" + numCorrect + "/" + numEvents+") " + trainingAccuracy + "\n"); |
| return trainingAccuracy; |
| } |
| |
| |
| private int maxIndex (double[] values) { |
| int max = 0; |
| for (int i = 1; i < values.length; i++) |
| if (values[i] > values[max]) |
| max = i; |
| return max; |
| } |
| |
| private void display (String s) { |
| if (printMessages) |
| System.out.print(s); |
| } |
| |
| private void displayIteration (int i) { |
| if (i > 10 && (i%10) != 0) |
| return; |
| |
| if (i < 10) |
| display(" " + i + ": "); |
| else if (i < 100) |
| display(" " + i + ": "); |
| else |
| display(i + ": "); |
| } |
| |
| // See whether a number is a perfect square. Inefficient, but fine |
| // for our purposes. |
| private final static boolean isPerfectSquare (int n) { |
| int root = (int)Math.sqrt(n); |
| return root*root == n; |
| } |
| |
| } |