blob: 8357836181db7ae9952f1b86b9b64ef220304c7a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.opennlp.ml.perceptron;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.opennlp.ml.model.AbstractModel;
import org.apache.opennlp.ml.model.AbstractModelWriter;
import org.apache.opennlp.ml.model.ComparablePredicate;
import org.apache.opennlp.ml.model.Context;
import org.apache.opennlp.ml.model.IndexHashTable;
/**
* Abstract parent class for Perceptron writers. It provides the persist method
* which takes care of the structure of a stored document, and requires an
* extending class to define precisely how the data should be stored.
*
*/
public abstract class PerceptronModelWriter extends AbstractModelWriter {
protected Context[] PARAMS;
protected String[] OUTCOME_LABELS;
protected String[] PRED_LABELS;
int numOutcomes;
public PerceptronModelWriter (AbstractModel model) {
Object[] data = model.getDataStructures();
this.numOutcomes = model.getNumOutcomes();
PARAMS = (Context[]) data[0];
IndexHashTable<String> pmap = (IndexHashTable<String>) data[1];
OUTCOME_LABELS = (String[])data[2];
PRED_LABELS = new String[pmap.size()];
pmap.toArray(PRED_LABELS);
}
protected ComparablePredicate[] sortValues () {
ComparablePredicate[] sortPreds;
ComparablePredicate[] tmpPreds = new ComparablePredicate[PARAMS.length];
int[] tmpOutcomes = new int[numOutcomes];
double[] tmpParams = new double[numOutcomes];
int numPreds = 0;
//remove parameters with 0 weight and predicates with no parameters
for (int pid=0; pid<PARAMS.length; pid++) {
int numParams = 0;
double[] predParams = PARAMS[pid].getParameters();
int[] outcomePattern = PARAMS[pid].getOutcomes();
for (int pi=0;pi<predParams.length;pi++) {
if (predParams[pi] != 0d) {
tmpOutcomes[numParams]=outcomePattern[pi];
tmpParams[numParams]=predParams[pi];
numParams++;
}
}
int[] activeOutcomes = new int[numParams];
double[] activeParams = new double[numParams];
for (int pi=0;pi<numParams;pi++) {
activeOutcomes[pi] = tmpOutcomes[pi];
activeParams[pi] = tmpParams[pi];
}
if (numParams != 0) {
tmpPreds[numPreds] = new ComparablePredicate(PRED_LABELS[pid],activeOutcomes,activeParams);
numPreds++;
}
}
System.err.println("Compressed "+PARAMS.length+" parameters to "+numPreds);
sortPreds = new ComparablePredicate[numPreds];
System.arraycopy(tmpPreds, 0, sortPreds, 0, numPreds);
Arrays.sort(sortPreds);
return sortPreds;
}
protected List<List<ComparablePredicate>> computeOutcomePatterns(ComparablePredicate[] sorted) {
ComparablePredicate cp = sorted[0];
List<List<ComparablePredicate>> outcomePatterns = new ArrayList<List<ComparablePredicate>>();
List<ComparablePredicate> newGroup = new ArrayList<ComparablePredicate>();
for (ComparablePredicate predicate : sorted) {
if (cp.compareTo(predicate) == 0) {
newGroup.add(predicate);
} else {
cp = predicate;
outcomePatterns.add(newGroup);
newGroup = new ArrayList<ComparablePredicate>();
newGroup.add(predicate);
}
}
outcomePatterns.add(newGroup);
System.err.println(outcomePatterns.size()+" outcome patterns");
return outcomePatterns;
}
/**
* Writes the model to disk, using the <code>writeX()</code> methods
* provided by extending classes.
*
* <p>If you wish to create a PerceptronModelWriter which uses a different
* structure, it will be necessary to override the persist method in
* addition to implementing the <code>writeX()</code> methods.
*/
public void persist() throws IOException {
// the type of model (Perceptron)
writeUTF("Perceptron");
// the mapping from outcomes to their integer indexes
writeInt(OUTCOME_LABELS.length);
for (String label : OUTCOME_LABELS) {
writeUTF(label);
}
// the mapping from predicates to the outcomes they contributed to.
// The sorting is done so that we actually can write this out more
// compactly than as the entire list.
ComparablePredicate[] sorted = sortValues();
List<List<ComparablePredicate>> compressed = computeOutcomePatterns(sorted);
writeInt(compressed.size());
for (List<ComparablePredicate> a : compressed) {
writeUTF(a.size() + a.get(0).toString());
}
// the mapping from predicate names to their integer indexes
writeInt(sorted.length);
for (ComparablePredicate s : sorted) {
writeUTF(s.name);
}
// write out the parameters
for (int i=0; i<sorted.length; i++)
for (int j=0; j<sorted[i].params.length; j++)
writeDouble(sorted[i].params[j]);
close();
}
}