blob: 4e80f010c6601ab25daa70817ad2db32a9b8ae96 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;
import de.bwaldvogel.liblinear.Train;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.MaxentModel;
/**
 * Event trainer that converts the events of a {@link DataIndexer} into a
 * liblinear {@link Problem}, trains a liblinear {@link Model} on it and wraps
 * the result in a {@code LiblinearModel}.
 *
 * <p>The solver type and its parameters are currently hard-coded (see the TODO
 * in the constructor).
 */
public class LiblinearTrainer extends AbstractEventTrainer {

  /**
   * Bias value handed to liblinear. Any value {@code >= 0} makes this trainer
   * reserve one extra trailing feature slot per event for the bias node.
   */
  private static final double BIAS = 0;

  /**
   * Creates a trainer.
   *
   * @param trainParams training parameters, passed through to the superclass
   * @param reportMap report map, passed through to the superclass
   */
  public LiblinearTrainer(Map<String, String> trainParams,
      Map<String, String> reportMap) {
    super(trainParams, reportMap);

    // TODO: Extract solver type here
    // depending on it, extract parameters
    // e.g. bias, C, eps for L1_LR
  }

  /**
   * Assembles a liblinear {@link Problem} from already converted events.
   *
   * <p>When {@code bias >= 0} the LAST slot of every feature array in
   * {@code vx} is overwritten with the bias node, so callers must have
   * reserved that slot. The arrays in {@code vx} are used (and modified) in
   * place, not copied.
   *
   * @param vy the label of each event
   * @param vx the feature nodes of each event, parallel to {@code vy}
   * @param maxIndex the largest feature index used in {@code vx}
   * @param bias the bias value; a negative value means no bias node is added
   * @return the populated problem
   */
  private static Problem constructProblem(List<Double> vy, List<Feature[]> vx, int maxIndex, double bias) {
    boolean withBias = bias >= 0;

    Problem problem = new Problem();
    problem.l = vy.size();
    // The bias node occupies one additional feature index after maxIndex.
    problem.n = withBias ? maxIndex + 1 : maxIndex;
    problem.bias = bias;

    problem.x = new Feature[problem.l][];
    problem.y = new double[problem.l];
    for (int i = 0; i < problem.l; i++) {
      Feature[] nodes = vx.get(i);
      if (withBias) {
        nodes[nodes.length - 1] = new FeatureNode(maxIndex + 1, bias);
      }
      problem.x[i] = nodes;
      problem.y[i] = vy.get(i).doubleValue();
    }

    return problem;
  }

  /**
   * Trains a liblinear model on the events provided by the indexer.
   *
   * <p>Each event becomes one liblinear instance: its outcome index is the
   * label, and every predicate index {@code p} in its context becomes a
   * feature node with index {@code p + 1} (the {@code + 1} shifts the indexer's
   * indices into the 1-based range used for liblinear feature nodes here).
   *
   * @param indexer the indexed training events
   * @return a {@code LiblinearModel} wrapping the trained liblinear model
   * @throws IOException declared by the overridden method
   */
  @Override
  public MaxentModel doTrain(DataIndexer indexer) throws IOException {
    // Fetch the indexer arrays once instead of on every loop iteration.
    int[][] contexts = indexer.getContexts();
    int[] outcomes = indexer.getOutcomeList();
    int[] timesSeen = indexer.getNumTimesEventsSeen();

    List<Double> vy = new ArrayList<Double>(contexts.length);
    List<Feature[]> vx = new ArrayList<Feature[]>(contexts.length);
    int maxIndex = 0;

    // For each event ...
    for (int i = 0; i < contexts.length; i++) {
      vy.add(Double.valueOf(outcomes[i]));

      int[] features = contexts[i];
      // Reserve a trailing slot for the bias node; constructProblem fills it.
      Feature[] x = new Feature[BIAS >= 0 ? features.length + 1 : features.length];

      // ... for each feature of that event.
      for (int fi = 0; fi < features.length; fi++) {
        int index = features[fi] + 1;
        // The times-seen array is indexed by EVENT (i), not by the position of
        // the feature inside the event (fi); using fi read the counts of
        // unrelated events and could run past the end of the array.
        // NOTE(review): whether the event count is the intended feature value
        // (rather than 1 or a real-valued feature) should be confirmed.
        x[fi] = new FeatureNode(index, timesSeen[i]);
        // Track the maximum over all features so the result does not depend on
        // the context array being sorted.
        maxIndex = Math.max(maxIndex, index);
      }

      vx.add(x);
    }

    Problem problem = constructProblem(vy, vx, maxIndex, BIAS);
    // Parameter arguments: solver type, cost C, stopping tolerance eps.
    Parameter parameter = new Parameter(SolverType.L1R_LR, 1d, 0.001d);

    Model liblinearModel = Linear.train(problem, parameter);

    // Map each predicate label to its index for lookup by the model.
    Map<String, Integer> predMap = new HashMap<String, Integer>();
    String[] predLabels = indexer.getPredLabels();
    for (int i = 0; i < predLabels.length; i++) {
      predMap.put(predLabels[i], i);
    }

    return new LiblinearModel(liblinearModel, indexer.getOutcomeLabels(), predMap);
  }

  /**
   * {@inheritDoc}
   *
   * @return always {@code true}
   */
  @Override
  public boolean isSortAndMerge() {
    return true;
  }

  /**
   * Small manual smoke test: writes a tiny data set to a temporary file,
   * trains a liblinear model on it and prints the predicted label for one
   * sample instance.
   *
   * @param args ignored
   * @throws Exception if the temporary file cannot be written or training fails
   */
  public static void main(String[] args) throws Exception {
    File file = File.createTempFile("svm", "test");
    file.deleteOnExit();

    Collection<String> lines = new ArrayList<String>();
    lines.add("1 1:1  3:1  4:1   6:1");
    lines.add("2 2:1  3:1  5:1   7:1");
    lines.add("1 3:1  5:1");
    lines.add("1 1:1  4:1  7:1");
    lines.add("2 4:1  5:1  7:1");
    lines.add("1 1:1  4:1  7:1");
    lines.add("2 4:1  5:1  7:1");

    BufferedWriter writer = new BufferedWriter(new FileWriter(file));
    try {
      for (String line : lines) {
        writer.append(line).append("\n");
      }
    } finally {
      writer.close();
    }

    Train train = new Train();
    Problem problem = train.readProblem(file, 0d);

    Model model = Linear.train(problem, new Parameter(SolverType.L1R_LR, 10d,
        0.02d));

    Feature[] instance = new Feature[]{new FeatureNode(4, 1d), new FeatureNode(1, 1d)};
    double result = Linear.predict(model, instance);

    // Also exercise the probability API; its outputs are not printed.
    double[] probabilities = new double[2];
    Linear.predictProbability(model, instance, probabilities);

    System.out.println(result);
  }
}