/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.ml.maxent.quasinewton;
import java.util.Arrays;

import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
/**
 * A maxent model of type {@code ModelType.MaxentQn}.
 * <p>
 * Evaluation sums, per outcome, the parameters of every known predicate in the
 * context (optionally scaled by a per-predicate value) and then normalizes the
 * resulting scores into probabilities with a log-sum-exp softmax.
 */
public class QNModel extends AbstractModel {

  /**
   * Creates a model from its parameters and labels and marks it as
   * {@code ModelType.MaxentQn}.
   *
   * @param params
   *          The per-predicate parameters, forwarded to {@link AbstractModel}.
   * @param predLabels
   *          The predicate names, forwarded to {@link AbstractModel}.
   * @param outcomeNames
   *          The outcome names, forwarded to {@link AbstractModel}.
   */
  public QNModel(Context[] params, String[] predLabels, String[] outcomeNames) {
    super(params, predLabels, outcomeNames);
    this.modelType = ModelType.MaxentQn;
  }

  /**
   * @return The number of outcome names held by this model.
   */
  public int getNumOutcomes() {
    return this.outcomeNames.length;
  }

  /**
   * Looks up the parameters stored for a predicate.
   *
   * @param predicate
   *          The predicate name.
   * @return Whatever {@code pmap} holds for the predicate; {@code eval} treats
   *         a {@code null} result as an unknown predicate and skips it.
   */
  private Context getPredIndex(String predicate) {
    return pmap.get(predicate);
  }

  /**
   * Evaluates a context in which every predicate has weight {@code 1.0}.
   *
   * @param context
   *          The predicates which have been observed at the present
   *          decision point.
   * @return Normalized probabilities for the outcomes given the context, in a
   *         newly allocated array.
   */
  public double[] eval(String[] context) {
    return eval(context, new double[evalParams.getNumOutcomes()]);
  }

  /**
   * Evaluates a context in which every predicate has weight {@code 1.0},
   * writing the result into a caller-supplied buffer.
   *
   * @param context
   *          The predicates which have been observed at the present
   *          decision point.
   * @param probs
   *          Output buffer for the outcome probabilities. Its previous contents
   *          are discarded, so the same array may be reused across calls.
   * @return {@code probs}, filled with the normalized probabilities.
   */
  public double[] eval(String[] context, double[] probs) {
    return eval(context, null, probs);
  }

  /**
   * Evaluates a context with explicit per-predicate weights.
   *
   * @param context
   *          The predicates which have been observed at the present
   *          decision point.
   * @param values
   *          Weights of the predicates, indexed in parallel with
   *          {@code context}; {@code null} means a weight of {@code 1.0} for
   *          every predicate.
   * @return Normalized probabilities for the outcomes given the context, in a
   *         newly allocated array.
   */
  public double[] eval(String[] context, float[] values) {
    return eval(context, values, new double[evalParams.getNumOutcomes()]);
  }

  /**
   * Model evaluation which should be used during inference.
   *
   * @param context
   *          The predicates which have been observed at the present
   *          decision point.
   * @param values
   *          Weights of the predicates which have been observed at
   *          the present decision point, or {@code null} for unit weights.
   * @param probs
   *          Output buffer for the outcome probabilities; it is reset to zero
   *          before any score is accumulated.
   * @return Normalized probabilities for the outcomes given the context.
   */
  private double[] eval(String[] context, float[] values, double[] probs) {
    // The buffer may be reused by the caller between calls. Scores are
    // accumulated with +=, so without this reset the previous call's
    // probabilities would leak into this call's scores.
    Arrays.fill(probs, 0.0);

    for (int ci = 0; ci < context.length; ci++) {
      Context pred = getPredIndex(context[ci]);
      if (pred == null) {
        // Unknown predicate: contributes nothing to any outcome.
        continue;
      }

      double predValue = values != null ? values[ci] : 1.0;
      double[] parameters = pred.getParameters();
      int[] outcomes = pred.getOutcomes();
      // parameters[i] belongs to outcome index outcomes[i].
      for (int i = 0; i < outcomes.length; i++) {
        probs[outcomes[i]] += predValue * parameters[i];
      }
    }

    // Normalize in log space: p = exp(score - logSumExp(scores)).
    double logSumExp = ArrayMath.logSumOfExps(probs);
    for (int oi = 0; oi < outcomeNames.length; oi++) {
      probs[oi] = StrictMath.exp(probs[oi] - logSumExp);
    }
    return probs;
  }

  /**
   * Model evaluation which should be used during training to report model accuracy.
   * <p>
   * Scores are added onto the existing contents of {@code probs}; this method
   * does not clear the array, so callers should pass a zero-filled array.
   *
   * @param context
   *          Indices of the predicates which have been observed at the present
   *          decision point.
   * @param values
   *          Weights of the predicates which have been observed at
   *          the present decision point, or {@code null} for unit weights.
   * @param probs
   *          Probability for outcomes; used as the accumulator and overwritten
   *          with the result for the first {@code nOutcomes} entries.
   * @param nOutcomes
   *          Number of outcomes
   * @param nPredLabels
   *          Number of unique predicates
   * @param parameters
   *          Model parameters, read at index
   *          {@code outcome * nPredLabels + predicate}.
   * @return Normalized probabilities for the outcomes given the context.
   */
  static double[] eval(int[] context, float[] values, double[] probs,
      int nOutcomes, int nPredLabels, double[] parameters) {

    for (int i = 0; i < context.length; i++) {
      int predIdx = context[i];
      double predValue = values != null ? values[i] : 1.0;
      for (int oi = 0; oi < nOutcomes; oi++) {
        probs[oi] += predValue * parameters[oi * nPredLabels + predIdx];
      }
    }

    // Normalize in log space: p = exp(score - logSumExp(scores)).
    double logSumExp = ArrayMath.logSumOfExps(probs);
    for (int oi = 0; oi < nOutcomes; oi++) {
      probs[oi] = StrictMath.exp(probs[oi] - logSumExp);
    }
    return probs;
  }
}