blob: 5f6661da318a24bdd685647f0b967490df356125 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package opennlp.addons.mallet;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.util.model.SerializableArtifact;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
/**
 * Adapts a Mallet {@link Classifier} to the OpenNLP {@link MaxentModel}
 * interface so that it can be evaluated with OpenNLP string contexts and
 * stored as a {@link SerializableArtifact}.
 *
 * <p>Outcome indices used by this class are the indices of the classifier's
 * label alphabet. Evaluation and index lookups never add entries to the
 * classifier's alphabets.
 */
class ClassifierModel implements MaxentModel, SerializableArtifact {

  /** The wrapped Mallet classifier that performs the actual classification. */
  private final Classifier classifier;

  /**
   * Creates a model backed by the given Mallet classifier.
   *
   * @param classifier the classifier to delegate to
   */
  public ClassifierModel(Classifier classifier) {
    this.classifier = classifier;
  }

  /**
   * Returns the wrapped Mallet classifier.
   *
   * <p>The method name keeps its historical spelling so existing callers in
   * this package continue to compile.
   *
   * @return the classifier passed to the constructor
   */
  Classifier getClassifer() {
    return classifier;
  }

  /**
   * Classifies the given context features.
   *
   * <p>Features that are not contained in the classifier's data alphabet are
   * skipped.
   *
   * @param features the context feature names
   * @return one value per outcome, indexed by the outcome's position in the
   *         classifier's label alphabet
   */
  @Override
  public double[] eval(String[] features) {
    Alphabet dataAlphabet = classifier.getAlphabet();

    List<Integer> malletFeatureList = new ArrayList<>(features.length);
    for (String feature : features) {
      // Look up without adding: the single-argument lookupIndex would insert
      // unseen features into the alphabet, so the -1 check below could never
      // filter them and every evaluation would mutate the model's alphabet.
      int featureId = dataAlphabet.lookupIndex(feature, false);
      if (featureId != -1) {
        malletFeatureList.add(featureId);
      }
    }

    int[] malletFeatures = new int[malletFeatureList.size()];
    for (int i = 0; i < malletFeatures.length; i++) {
      malletFeatures[i] = malletFeatureList.get(i);
    }

    FeatureVector fv = new FeatureVector(dataAlphabet, malletFeatures);
    Instance instance = new Instance(fv, null, null, null);

    Classification result = classifier.classify(instance);
    LabelVector labeling = result.getLabelVector();

    // Re-order the classifier's scores so that position i corresponds to
    // label i of the label alphabet (the index space used by getOutcome).
    LabelAlphabet targetAlphabet = classifier.getLabelAlphabet();
    double[] outcomes = new double[targetAlphabet.size()];
    for (int i = 0; i < outcomes.length; i++) {
      Label label = targetAlphabet.lookupLabel(i);
      int rank = labeling.getRank(label);
      outcomes[i] = labeling.getValueAtRank(rank);
    }

    return outcomes;
  }

  /**
   * Classifies the given context and additionally writes the result into
   * {@code probs}.
   *
   * @param context the context feature names
   * @param probs   array to populate with the outcome values; it is filled
   *                only when it is non-null and has room for every outcome,
   *                otherwise it is left untouched
   * @return a newly allocated array with one value per outcome
   */
  @Override
  public double[] eval(String[] context, double[] probs) {
    double[] outcomes = eval(context);
    if (probs != null && probs.length >= outcomes.length) {
      System.arraycopy(outcomes, 0, probs, 0, outcomes.length);
    }
    return outcomes;
  }

  /**
   * Classifies the given context. The feature {@code values} are not used;
   * the result is identical to {@link #eval(String[])}.
   *
   * @param context the context feature names
   * @param values  feature values, ignored by this implementation
   * @return one value per outcome
   */
  @Override
  public double[] eval(String[] context, float[] values) {
    return eval(context);
  }

  /**
   * Returns the name of the outcome with the highest value.
   *
   * <p>When several outcomes share the highest value the one with the lowest
   * index is returned.
   *
   * @param ocs outcome values as returned by one of the {@code eval} methods
   * @return the name of the best outcome
   */
  @Override
  public String getBestOutcome(double[] ocs) {
    int best = 0;
    for (int i = 1; i < ocs.length; i++) {
      if (ocs[i] > ocs[best]) {
        best = i;
      }
    }
    return getOutcome(best);
  }

  /**
   * Formats every outcome name together with its value, for example
   * {@code "A[0.7000]  B[0.3000]"}.
   *
   * @param outcomes outcome values as returned by one of the {@code eval}
   *                 methods
   * @return the outcome names paired with their values, in index order
   * @throws IllegalArgumentException if {@code outcomes} does not contain
   *         exactly one value per outcome of this model
   */
  @Override
  public String getAllOutcomes(double[] outcomes) {
    int numOutcomes = getNumOutcomes();
    if (outcomes.length != numOutcomes) {
      throw new IllegalArgumentException("Expected " + numOutcomes
          + " outcome values but got " + outcomes.length);
    }

    StringBuilder sb = new StringBuilder(outcomes.length * 16);
    for (int i = 0; i < outcomes.length; i++) {
      if (i > 0) {
        sb.append("  ");
      }
      sb.append(getOutcome(i))
          .append('[')
          // Locale.ROOT keeps the decimal separator stable across platforms.
          .append(String.format(Locale.ROOT, "%.4f", outcomes[i]))
          .append(']');
    }
    return sb.toString();
  }

  /**
   * Returns the name of the outcome with the given index.
   *
   * @param i index of the outcome in the classifier's label alphabet
   * @return the string form of the label entry at that index
   */
  @Override
  public String getOutcome(int i) {
    return classifier.getLabelAlphabet().lookupLabel(i).getEntry().toString();
  }

  /**
   * Returns the index of the given outcome name.
   *
   * @param outcome the outcome name to look up
   * @return the index of the outcome, or {@code -1} if this model has no such
   *         outcome
   */
  @Override
  public int getIndex(String outcome) {
    // Look up without adding, so querying an unknown outcome cannot grow the
    // label alphabet and silently change getNumOutcomes().
    return classifier.getLabelAlphabet().lookupIndex(outcome, false);
  }

  /**
   * Returns the number of outcomes, i.e. the size of the classifier's label
   * alphabet.
   *
   * @return the number of outcomes
   */
  @Override
  public int getNumOutcomes() {
    return classifier.getLabelAlphabet().size();
  }

  /**
   * Returns the serializer class used to persist this artifact.
   *
   * @return {@code ClassifierModelSerializer.class}
   */
  @Override
  public Class<?> getArtifactSerializerClass() {
    return ClassifierModelSerializer.class;
  }
}