blob: 051e5532af81ba41c2dfff8762b829a05ac13bd7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.opennlp.ml.model;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.opennlp.ml.perceptron.PerceptronTrainer;
import org.apache.opennlp.ml.perceptron.SimplePerceptronSequenceTrainer;
/**
 * Static helpers for validating training parameters and for training an
 * {@link AbstractModel} from an {@link EventStream} or a {@link SequenceStream}.
 *
 * <p>Training parameters are passed as a {@code Map<String, String>}; entries that are
 * absent fall back to built-in defaults. When a non-null {@code reportMap} is supplied to
 * a {@code train} method, the effective value (explicit or default) of every parameter
 * read during training is written into it.
 */
public class TrainUtil {

  public static final String ALGORITHM_PARAM = "Algorithm";

  public static final String MAXENT_VALUE = "MAXENT";
  public static final String PERCEPTRON_VALUE = "PERCEPTRON";
  public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";

  public static final String CUTOFF_PARAM = "Cutoff";
  private static final int CUTOFF_DEFAULT = 5;

  public static final String ITERATIONS_PARAM = "Iterations";
  private static final int ITERATIONS_DEFAULT = 100;

  public static final String DATA_INDEXER_PARAM = "DataIndexer";
  public static final String DATA_INDEXER_ONE_PASS_VALUE = "OnePass";
  public static final String DATA_INDEXER_TWO_PASS_VALUE = "TwoPass";

  /**
   * Stores {@code value} under {@code key} in {@code reportMap} when a report map was given.
   */
  private static void report(Map<String, String> reportMap, String key, String value) {
    if (reportMap != null) {
      reportMap.put(key, value);
    }
  }

  /**
   * Returns the string parameter {@code key}, or {@code defaultValue} when it is absent,
   * and records the returned value in {@code reportMap}.
   */
  private static String getStringParam(Map<String, String> trainParams, String key,
      String defaultValue, Map<String, String> reportMap) {
    String valueString = trainParams.get(key);
    if (valueString == null) {
      valueString = defaultValue;
    }
    report(reportMap, key, valueString);
    return valueString;
  }

  /**
   * Returns the int parameter {@code key}, or {@code defaultValue} when it is absent,
   * and records the returned value in {@code reportMap}.
   *
   * @throws NumberFormatException if the stored value is not a valid int
   */
  private static int getIntParam(Map<String, String> trainParams, String key,
      int defaultValue, Map<String, String> reportMap) {
    String valueString = trainParams.get(key);
    int value = (valueString != null) ? Integer.parseInt(valueString) : defaultValue;
    report(reportMap, key, Integer.toString(value));
    return value;
  }

  /**
   * Returns the double parameter {@code key}, or {@code defaultValue} when it is absent,
   * and records the returned value in {@code reportMap}.
   *
   * @throws NumberFormatException if the stored value is not a valid double
   */
  private static double getDoubleParam(Map<String, String> trainParams, String key,
      double defaultValue, Map<String, String> reportMap) {
    String valueString = trainParams.get(key);
    double value = (valueString != null) ? Double.parseDouble(valueString) : defaultValue;
    report(reportMap, key, Double.toString(value));
    return value;
  }

  /**
   * Returns the boolean parameter {@code key}, or {@code defaultValue} when it is absent,
   * and records the returned value in {@code reportMap}. Any stored value other than
   * (case-insensitive) {@code "true"} yields {@code false}, per {@link Boolean#parseBoolean}.
   */
  private static boolean getBooleanParam(Map<String, String> trainParams, String key,
      boolean defaultValue, Map<String, String> reportMap) {
    String valueString = trainParams.get(key);
    boolean value = (valueString != null) ? Boolean.parseBoolean(valueString) : defaultValue;
    report(reportMap, key, Boolean.toString(value));
    return value;
  }

  /**
   * Checks the subset of training parameters this class validates.
   *
   * <ul>
   *   <li>{@link #ALGORITHM_PARAM}, if present, must be {@link #MAXENT_VALUE},
   *       {@link #PERCEPTRON_VALUE} or {@link #PERCEPTRON_SEQUENCE_VALUE};</li>
   *   <li>{@link #CUTOFF_PARAM} and {@link #ITERATIONS_PARAM}, if present, must parse
   *       as {@code int};</li>
   *   <li>{@link #DATA_INDEXER_PARAM}, if present, must be
   *       {@link #DATA_INDEXER_ONE_PASS_VALUE} or {@link #DATA_INDEXER_TWO_PASS_VALUE}.</li>
   * </ul>
   *
   * <p>Other parameters (e.g. {@code "Threads"}, {@code "Tolerance"}) are not checked here.
   *
   * @param trainParams the training parameters to check
   * @return {@code true} if the checked parameters are acceptable
   */
  public static boolean isValid(Map<String, String> trainParams) {
    // TODO: Need to validate all parameters correctly ... error prone?!
    String algorithmName = trainParams.get(ALGORITHM_PARAM);
    if (algorithmName != null && !(MAXENT_VALUE.equals(algorithmName)
        || PERCEPTRON_VALUE.equals(algorithmName)
        || PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName))) {
      return false;
    }

    try {
      String cutoffString = trainParams.get(CUTOFF_PARAM);
      if (cutoffString != null) {
        Integer.parseInt(cutoffString);
      }
      String iterationsString = trainParams.get(ITERATIONS_PARAM);
      if (iterationsString != null) {
        Integer.parseInt(iterationsString);
      }
    } catch (NumberFormatException ignored) {
      // An unparsable number simply makes the parameter set invalid.
      return false;
    }

    String dataIndexer = trainParams.get(DATA_INDEXER_PARAM);
    if (dataIndexer != null && !(DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexer)
        || DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexer))) {
      return false;
    }

    return true;
  }

  /**
   * Trains a non-sequence model ({@link #MAXENT_VALUE} or {@link #PERCEPTRON_VALUE})
   * from an event stream.
   *
   * <p>Defaults when a parameter is absent: algorithm {@code MAXENT}, 100 iterations,
   * cutoff 5, data indexer {@code TwoPass}.
   *
   * @param events the training events
   * @param trainParams the training parameters
   * @param reportMap if non-null, receives the effective value of every parameter read
   *     and, after training, the event hash under {@code "Training-Eventhash"}
   * @return the trained model
   * @throws IllegalArgumentException if {@code trainParams} fails {@link #isValid} or
   *     selects a sequence algorithm
   * @throws IOException if one is thrown while indexing the events or training
   */
  // TODO: Need a way to report results and settings back for inclusion in model ...
  public static AbstractModel train(EventStream events, Map<String, String> trainParams,
      Map<String, String> reportMap) throws IOException {

    if (!isValid(trainParams)) {
      throw new IllegalArgumentException("trainParams are not valid!");
    }
    if (isSequenceTraining(trainParams)) {
      throw new IllegalArgumentException("sequence training is not supported by this method!");
    }

    String algorithmName = getStringParam(trainParams, ALGORITHM_PARAM, MAXENT_VALUE, reportMap);
    int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
    int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);

    // MAXENT indexes with sortAndMerge enabled, PERCEPTRON with it disabled.
    boolean sortAndMerge;
    if (MAXENT_VALUE.equals(algorithmName)) {
      sortAndMerge = true;
    } else if (PERCEPTRON_VALUE.equals(algorithmName)) {
      sortAndMerge = false;
    } else {
      throw new IllegalStateException("Unexpected algorithm name: " + algorithmName);
    }

    // Wrapped so that its hash sum can be reported once training has finished.
    HashSumEventStream hses = new HashSumEventStream(events);

    String dataIndexerName = getStringParam(trainParams, DATA_INDEXER_PARAM,
        DATA_INDEXER_TWO_PASS_VALUE, reportMap);

    DataIndexer indexer;
    if (DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexerName)) {
      indexer = new OnePassDataIndexer(hses, cutoff, sortAndMerge);
    } else if (DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexerName)) {
      indexer = new TwoPassDataIndexer(hses, cutoff, sortAndMerge);
    } else {
      throw new IllegalStateException("Unexpected data indexer name: " + dataIndexerName);
    }

    AbstractModel model;
    if (MAXENT_VALUE.equals(algorithmName)) {
      int threads = getIntParam(trainParams, "Threads", 1, reportMap);
      model = org.apache.opennlp.ml.maxent.GIS.trainModel(iterations, indexer,
          true, false, null, 0, threads);
    } else if (PERCEPTRON_VALUE.equals(algorithmName)) {
      boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
      boolean useSkippedAveraging =
          getBooleanParam(trainParams, "UseSkippedAveraging", false, reportMap);

      // Skipped averaging is only meaningful with averaging enabled, so force it on.
      if (useSkippedAveraging) {
        useAverage = true;
      }

      double stepSizeDecrease = getDoubleParam(trainParams, "StepSizeDecrease", 0, reportMap);
      double tolerance = getDoubleParam(trainParams, "Tolerance",
          PerceptronTrainer.TOLERANCE_DEFAULT, reportMap);

      PerceptronTrainer perceptronTrainer = new PerceptronTrainer();
      perceptronTrainer.setSkippedAveraging(useSkippedAveraging);
      // Only a positive decrease is passed on; 0 (the default) leaves the trainer as is.
      if (stepSizeDecrease > 0) {
        perceptronTrainer.setStepSizeDecrease(stepSizeDecrease);
      }
      perceptronTrainer.setTolerance(tolerance);

      model = perceptronTrainer.trainModel(iterations, indexer, cutoff, useAverage);
    } else {
      throw new IllegalStateException("Algorithm not supported: " + algorithmName);
    }

    // Guarded explicitly so the hash sum is only computed when it will be reported.
    if (reportMap != null) {
      reportMap.put("Training-Eventhash", hses.calculateHashSum().toString(16));
    }

    return model;
  }

  /**
   * Detects if the training algorithm requires sequence based feature generation
   * or not.
   *
   * @param trainParams the training parameters
   * @return {@code true} iff {@link #ALGORITHM_PARAM} equals {@link #PERCEPTRON_SEQUENCE_VALUE}
   */
  public static boolean isSequenceTraining(Map<String, String> trainParams) {
    return PERCEPTRON_SEQUENCE_VALUE.equals(trainParams.get(ALGORITHM_PARAM));
  }

  /**
   * Trains a sequence model ({@link #PERCEPTRON_SEQUENCE_VALUE}) from a sequence stream.
   *
   * <p>Defaults when a parameter is absent: 100 iterations, cutoff 5,
   * {@code UseAverage = true}.
   *
   * @param events the training sequences
   * @param trainParams the training parameters; the algorithm must be a sequence algorithm
   * @param reportMap if non-null, receives the algorithm and the effective value of every
   *     parameter read
   * @return the trained model
   * @throws IllegalArgumentException if {@code trainParams} fails {@link #isValid} or does
   *     not select a sequence algorithm
   * @throws IOException if one is thrown while training
   */
  public static AbstractModel train(SequenceStream events, Map<String, String> trainParams,
      Map<String, String> reportMap) throws IOException {

    if (!isValid(trainParams)) {
      throw new IllegalArgumentException("trainParams are not valid!");
    }
    if (!isSequenceTraining(trainParams)) {
      throw new IllegalArgumentException("Algorithm must be a sequence algorithm!");
    }

    // The checks above guarantee the algorithm; report it like the event-based train does.
    report(reportMap, ALGORITHM_PARAM, PERCEPTRON_SEQUENCE_VALUE);

    int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
    int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);
    boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);

    return new SimplePerceptronSequenceTrainer().trainModel(
        iterations, events, cutoff, useAverage);
  }
}