/* | |
* Licensed to the Apache Software Foundation (ASF) under one | |
* or more contributor license agreements. See the NOTICE file | |
* distributed with this work for additional information | |
* regarding copyright ownership. The ASF licenses this file | |
* to you under the Apache License, Version 2.0 (the | |
* "License"); you may not use this file except in compliance | |
* with the License. You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, | |
* software distributed under the License is distributed on an | |
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
* KIND, either express or implied. See the License for the | |
* specific language governing permissions and limitations | |
* under the License. | |
*/ | |
package org.apache.uima.ruta.textruler.learner.trabal; | |
import java.io.BufferedWriter; | |
import java.io.File; | |
import java.io.FileWriter; | |
import java.util.ArrayList; | |
import java.util.Collections; | |
import java.util.Comparator; | |
import java.util.HashMap; | |
import java.util.HashSet; | |
import java.util.Iterator; | |
import java.util.List; | |
import java.util.Map; | |
import java.util.Set; | |
import java.util.TreeMap; | |
import org.apache.commons.lang3.StringUtils; | |
import org.apache.uima.cas.CAS; | |
import org.apache.uima.cas.FSIterator; | |
import org.apache.uima.cas.Feature; | |
import org.apache.uima.cas.FeatureStructure; | |
import org.apache.uima.cas.Type; | |
import org.apache.uima.cas.text.AnnotationFS; | |
import org.apache.uima.cas.text.AnnotationIndex; | |
import org.apache.uima.ruta.textruler.core.TextRulerBasicLearner; | |
import org.apache.uima.ruta.textruler.core.TextRulerExample; | |
import org.apache.uima.ruta.textruler.core.TextRulerExampleDocument; | |
import org.apache.uima.ruta.textruler.core.TextRulerExampleDocumentSet; | |
import org.apache.uima.ruta.textruler.core.TextRulerStatisticsCollector; | |
import org.apache.uima.ruta.textruler.core.TextRulerTarget; | |
import org.apache.uima.ruta.textruler.core.TextRulerToolkit; | |
import org.apache.uima.ruta.textruler.extension.TextRulerLearnerDelegate; | |
import org.apache.uima.ruta.textruler.tools.MemoryWatch; | |
import org.apache.uima.util.CasCopier; | |
public class TrabalLearner extends TextRulerBasicLearner { | |
/** Type name of the Ruta basic annotations; annotations of this type are dropped by removeBasics(). */
public final static String ANNOTATION_TYPE_BASIC = "org.apache.uima.ruta.type.RutaBasic";

public final static String ANNOTATION_TYPE_ANY = "org.apache.uima.ruta.type.ANY";

public final static String ANNOTATION_TYPE_FRAME = "org.apache.uima.ruta.type.RutaFrame";

public final static String PACKAGE = "org.apache.uima";

public final static String TYPESYSTEM = "de.uniwue.tm.citie.CompleteTypeSystemTypeSystem";

/**
 * Short names of the features that are skipped when the features of an annotation are listed
 * (see listFeatures()). The list is read-only: it is a shared public constant and must not be
 * changed at runtime.
 */
public final static List<String> FILTERED_FEATURES;

static {
  List<String> tmp = new ArrayList<String>();
  tmp.add("sofa");
  tmp.add("begin");
  tmp.add("end");
  // Wrap the list so that no client can alter the shared constant.
  FILTERED_FEATURES = Collections.unmodifiableList(tmp);
}

// Names of the learner parameters (NOTE(review): presumably the keys of the parameter map
// handed to the learner; the code reading them is not part of this section).
public final static String MAX_NUMBER_OF_BASIC_RULES_KEY = "maxNumberOfBasicRules";

public final static String MAX_NUMBER_OF_RULES_KEY = "maxNumberOfRules";

public final static String MAX_NUMBER_OF_ITERATIONS_KEY = "maxNumberOfIterations";

public final static String ALGORITHM_ITERATIONS_KEY = "algorithmIterations";

public final static String MAX_ERROR_RATE_KEY = "maxErrorRate";

public final static String ENABLE_FEATURES_KEY = "enableFeatures";

// Default values; the instance fields below are initialized with them.
public final static int MAX_NUMBER_OF_BASIC_RULES = 80;

public final static int MAX_NUMBER_OF_RULES = 150;

public final static int MAX_NUMBER_OF_ITERATIONS = 2;

public final static int ALGORITHM_ITERATIONS = 3;

public final static double MAX_ERROR_RATE = 0.4;

public final static boolean ENABLE_FEATURES = false;

/** Upper bound for the number of basic rules kept per error group. */
private int maxNumberOfBasicRules = MAX_NUMBER_OF_BASIC_RULES;

/** Upper bound for the number of enhanced/learnt rules kept per error group. */
private int maxNumberOfRules = MAX_NUMBER_OF_RULES;

/** Recursion depth used when conditions are added to rules (see enhanceRules()). */
private int maxNumberOfIterations = MAX_NUMBER_OF_ITERATIONS;

/** Number of passes of the whole algorithm performed by doRun(). */
private int algorithmIterations = ALGORITHM_ITERATIONS;

/** Rules with a higher error rate than this are discarded. */
private double maxErrorRate = MAX_ERROR_RATE;

/** Whether annotation features are taken into account when annotations are compared and grouped. */
public boolean enableFeatures = ENABLE_FEATURES;

private String result = ""; // actualResult + tempRules

/** Script text of the rules that have been finally chosen so far. */
private String actualResult = "";

/** Currently best rule of each processed error group; only used for intermediate results. */
private List<TrabalRule> bestRulesForStatus = new ArrayList<TrabalRule>();

private String inputDirectory;

private String additionalFolderPath;

/** The documents that are compared against the training documents. */
private TextRulerExampleDocumentSet additionalDocuments;

/** The errors found by the most recent call of createErrorGroups(). */
private List<AnnotationError> errors;

/** Conditions of correctly annotated examples, keyed by the short name of the annotation type. */
private Map<String, RankedList> positiveExamples;

/** Weight per annotation type short name as computed by createIDF(). */
private Map<String, Double> idf;

private Map<String, TextRulerStatisticsCollector> inducedRules = new TreeMap<String, TextRulerStatisticsCollector>();
/**
 * Creates a new learner. Every argument except {@code additionalFolderPath} is handed to the
 * constructor of {@link TextRulerBasicLearner}; the two folder paths are additionally stored in
 * this instance.
 *
 * @param inputFolderPath
 *          path of the folder with the training documents
 * @param additionalFolderPath
 *          path of the folder with the additional documents
 * @param preprocessorTMfile
 *          passed to the super constructor
 * @param tempFolderPath
 *          passed to the super constructor
 * @param fullSlotTypeNames
 *          passed to the super constructor
 * @param filterSet
 *          passed to the super constructor
 * @param skip
 *          passed to the super constructor
 * @param delegate
 *          passed to the super constructor
 */
public TrabalLearner(String inputFolderPath, String additionalFolderPath,
        String preprocessorTMfile, String tempFolderPath, String[] fullSlotTypeNames,
        Set<String> filterSet, boolean skip, TextRulerLearnerDelegate delegate) {
  super(inputFolderPath, preprocessorTMfile, tempFolderPath, fullSlotTypeNames, filterSet, skip,
          delegate);
  this.additionalFolderPath = additionalFolderPath;
  this.inputDirectory = inputFolderPath;
}
/** | |
* Main method. Starts the algorithm. | |
*/ | |
protected void doRun() { | |
try { | |
getAnalysisEngine(); | |
getAdditionalDocuments(); | |
if (additionalDocuments == null) | |
throw new Exception("Error: Additional data is missing!"); | |
if (exampleDocuments.getDocuments().size() != additionalDocuments.getDocuments().size()) | |
throw new Exception("Error: Training data doesn't match additional data!"); | |
sendStatusUpdateToDelegate("Loading documents...", TextRulerLearnerState.ML_INITIALIZING, | |
true); | |
sendStatusUpdateToDelegate("Comparing documents...", TextRulerLearnerState.ML_RUNNING, true); | |
for (int i = 0; i < algorithmIterations; i++) { | |
actualResult += "BLOCK(Iteration_" + (i + 1) + ") Document{} {\n"; | |
sendStatusUpdateToDelegate("Comparing documents...", TextRulerLearnerState.ML_RUNNING, true); | |
idf = createIDF(); | |
Map<String, List<AnnotationError>> errorGrps = createErrorGroups(); | |
List<TrabalRule> rules = runAlgorithm(errorGrps); | |
if (rules.size() == 0) { | |
actualResult = actualResult.replace("BLOCK(Iteration_" + (i + 1) + ") Document{} {", ""); | |
break; | |
} | |
actualResult += "}\n\n"; | |
result = actualResult; | |
sendStatusUpdateToDelegate("Finished " + (i + 1) + ". iteration.", | |
TextRulerLearnerState.ML_DONE, true); | |
} | |
Map<String, List<AnnotationError>> errorGrps = createErrorGroups(); | |
if (errorGrps.size() > 0) { | |
actualResult += "// Remaining errors \n"; | |
for (List<AnnotationError> list : errorGrps.values()) { | |
for (AnnotationError err : list) { | |
actualResult += "\n// " + err; | |
} | |
} | |
} | |
if (new File(tempDirectory).isDirectory()) | |
new File(tempDirectory).delete(); | |
result = actualResult; | |
sendStatusUpdateToDelegate("Done!", TextRulerLearnerState.ML_DONE, true); | |
} catch (Exception e) { | |
sendStatusUpdateToDelegate(e.getLocalizedMessage(), TextRulerLearnerState.ML_ERROR, true); | |
e.printStackTrace(); | |
} | |
} | |
private void removeBasics() { | |
for (TextRulerExampleDocument doc : additionalDocuments.getDocuments()) { | |
CAS cas = doc.getCAS(); | |
List<AnnotationFS> basics = new ArrayList<AnnotationFS>(); | |
for (AnnotationFS a : cas.getAnnotationIndex()) { | |
if (a.getType().getName().equals(ANNOTATION_TYPE_BASIC)) { | |
basics.add(a); | |
} | |
} | |
for (AnnotationFS a : basics) | |
cas.removeFsFromIndexes(a); | |
} | |
} | |
private void removeBasics(CAS cas) { | |
List<AnnotationFS> basics = new ArrayList<AnnotationFS>(); | |
for (AnnotationFS a : cas.getAnnotationIndex()) { | |
if (a.getType().getName().equals(ANNOTATION_TYPE_BASIC)) { | |
basics.add(a); | |
} | |
} | |
for (AnnotationFS a : basics) | |
cas.removeFsFromIndexes(a); | |
} | |
private Map<String, Double> createIDF() { | |
Map<String, Double> result = new HashMap<String, Double>(); | |
List<TextRulerExampleDocument> documents = additionalDocuments.getDocuments(); | |
// int counter = 0; | |
Map<String, Integer> annotations = new HashMap<String, Integer>(); | |
for (int i = 0; i < documents.size(); i++) { | |
AnnotationIndex<AnnotationFS> index = documents.get(i).getCAS().getAnnotationIndex(); | |
FSIterator<AnnotationFS> iterator = index.iterator(); | |
while (iterator.isValid()) { | |
AnnotationFS annotation = iterator.next(); | |
if (isSlotType(annotation.getType())) { | |
TrabalAnnotation a = new TrabalAnnotation(annotation, enableFeatures); | |
if (annotations.containsKey(a.getType().getShortName())) { | |
annotations.put(a.getType().getShortName(), | |
new Integer(annotations.get(a.getType().getShortName()).intValue() + 1)); | |
} else { | |
annotations.put(a.getType().getShortName(), new Integer(1)); | |
} | |
// counter++; | |
} | |
} | |
for (String key : annotations.keySet()) { | |
double d = annotations.get(key).doubleValue(); | |
result.put(key, new Double(1 / d)); | |
} | |
} | |
return result; | |
} | |
/** | |
* Compares the training data to the additional data and creates a mapping of the contained | |
* errors, where the key is a combination of the error type and the names of the concerned | |
* annotations, and the values are lists of errors that can be corrected through the same type of | |
* rule. | |
* | |
* @return map of errors | |
*/ | |
private Map<String, List<AnnotationError>> createErrorGroups() { | |
errors = createErrorList(); | |
Collections.sort(errors); | |
Map<String, List<AnnotationError>> result = new HashMap<String, List<AnnotationError>>(); | |
for (AnnotationError error : errors) { | |
String truthFeatures = ""; | |
String errorFeatures = ""; | |
if (error.getError() != null && error.getTruth() != null) { | |
if (enableFeatures) { | |
truthFeatures = listFeatures(error.getTruth()); | |
errorFeatures = listFeatures(error.getError()); | |
} | |
String key = error.getType() + "_" + error.getError().getAnnotation().getType().toString() | |
+ errorFeatures + "_" + error.getTruth().getAnnotation().getType().toString() | |
+ truthFeatures; | |
if (result.get(key) == null) { | |
result.put(key, new ArrayList<AnnotationError>()); | |
} | |
result.get(key).add(error); | |
} else if (error.getTruth() != null) { | |
if (enableFeatures) | |
truthFeatures = listFeatures(error.getTruth()); | |
String key = error.getType() + "_" + error.getTruth().getAnnotation().getType().toString() | |
+ truthFeatures; | |
if (result.get(key) == null) { | |
result.put(key, new ArrayList<AnnotationError>()); | |
} | |
result.get(key).add(error); | |
} else { | |
if (enableFeatures) | |
errorFeatures = listFeatures(error.getError()); | |
String key = error.getType() + "_" + error.getError().getAnnotation().getType().toString() | |
+ errorFeatures; | |
if (result.get(key) == null) { | |
result.put(key, new ArrayList<AnnotationError>()); | |
} | |
result.get(key).add(error); | |
} | |
} | |
return result; | |
} | |
private String listFeatures(TextRulerExample example) { | |
String result = ""; | |
for (Feature f : example.getAnnotation().getType().getFeatures()) { | |
if (!FILTERED_FEATURES.contains(f.getShortName())) | |
result += "_" + f.toString(); | |
} | |
return result; | |
} | |
/** | |
* Creates a list of annotation pairs and the fitting type of transformation action. | |
*/ | |
private List<AnnotationError> createErrorList() { | |
positiveExamples = new HashMap<String, RankedList>(); | |
List<AnnotationError> result = new ArrayList<AnnotationError>(); | |
List<TrabalAnnotation> matches; | |
Iterator<TrabalAnnotation> iterator; | |
Iterator<TrabalAnnotation> docIterator; | |
List<TextRulerExampleDocument> documents; | |
List<TextRulerExampleDocument> goldStandard; | |
documents = additionalDocuments.getDocuments(); | |
goldStandard = exampleDocuments.getDocuments(); | |
for (int i = 0; i < goldStandard.size(); i++) { | |
if (shouldAbort()) | |
break; | |
matches = new ArrayList<TrabalAnnotation>(); | |
AnnotationIndex<AnnotationFS> index = goldStandard.get(i).getCAS().getAnnotationIndex(); | |
List<TrabalAnnotation> gold = new ArrayList<TrabalAnnotation>(); | |
List<TrabalAnnotation> docs = new ArrayList<TrabalAnnotation>(); | |
for (AnnotationFS a : index) { | |
if (isSlotType(a.getType())) { | |
gold.add(new TrabalAnnotation(a, goldStandard.get(i), enableFeatures)); | |
} | |
} | |
AnnotationIndex<AnnotationFS> docIndex = documents.get(i).getCAS().getAnnotationIndex(); | |
for (AnnotationFS b : docIndex) { | |
if (isSlotType(b.getType())) { | |
docs.add(new TrabalAnnotation(b, documents.get(i), enableFeatures)); | |
} | |
} | |
iterator = gold.iterator(); | |
docIterator = docs.iterator(); | |
TrabalAnnotation a; | |
TrabalAnnotation b; | |
// find correct annotated elements | |
int exampleIndex = 0; | |
while (iterator.hasNext()) { | |
if (shouldAbort()) | |
break; | |
exampleIndex++; | |
sendStatusUpdateToDelegate("Comparing documents " + (i + 1) + " of " + goldStandard.size() | |
+ ": example " + exampleIndex + " of " + gold.size(), | |
TextRulerLearnerState.ML_RUNNING, false); | |
a = iterator.next(); | |
docIterator = docs.iterator(); | |
while (docIterator.hasNext()) { | |
b = docIterator.next(); | |
if (b.equals(a)) { | |
matches.add(a); | |
matches.add(b); | |
if (positiveExamples.containsKey(a.getType().getShortName())) { | |
RankedList list = positiveExamples.get(a.getType().getShortName()); | |
list.addAll(createConditions(a)); | |
positiveExamples.put(a.getType().getShortName(), list); | |
} else { | |
RankedList list = new RankedList(idf); | |
list.addAll(createConditions(a)); | |
positiveExamples.put(a.getType().getShortName(), list); | |
} | |
break; | |
} | |
} | |
} | |
// create correction type errors | |
iterator = gold.iterator(); | |
docIterator = docs.iterator(); | |
while (iterator.hasNext()) { | |
a = iterator.next(); | |
docIterator = docs.iterator(); | |
while (!matches.contains(a) && docIterator.hasNext()) { | |
b = docIterator.next(); | |
if (!matches.contains(b)) { | |
if (b.getBegin() == a.getBegin() && b.getEnd() == a.getEnd()) { | |
TextRulerTarget target = new TextRulerTarget(b.getType().getName(), this); | |
TextRulerExample error = new TextRulerExample(b.getDocument(), b, false, target); | |
TextRulerExample truth = new TextRulerExample(a.getDocument(), a, true, target); | |
result.add(new AnnotationError(error, truth, AnnotationErrorType.CORRECTION)); | |
matches.add(a); | |
matches.add(b); | |
break; | |
} | |
} | |
} | |
} | |
// create shifting type errors | |
iterator = gold.iterator(); | |
docIterator = docs.iterator(); | |
List<AnnotationError> tempErrors; | |
int distance; | |
while (iterator.hasNext()) { | |
a = iterator.next(); | |
docIterator = docs.iterator(); | |
tempErrors = new ArrayList<AnnotationError>(); | |
while (!matches.contains(a) && docIterator.hasNext()) { | |
b = docIterator.next(); | |
if (!matches.contains(b)) { | |
if (b.getType().getShortName().equals(a.getType().getShortName()) | |
&& b.getEnd() >= a.getBegin() && b.getBegin() <= a.getEnd()) { | |
TextRulerTarget target = new TextRulerTarget(a.getType().getName(), this); | |
TextRulerExample error = new TextRulerExample(b.getDocument(), b, false, target); | |
TextRulerExample truth = new TextRulerExample(a.getDocument(), a, true, target); | |
AnnotationErrorType type; | |
if ((b.getBegin() == a.getBegin() && b.getEnd() < a.getEnd()) | |
|| (b.getBegin() > a.getBegin() && b.getEnd() == a.getEnd())) { | |
type = AnnotationErrorType.EXPANSION; | |
} else if (b.getBegin() > a.getBegin() || b.getEnd() > a.getEnd()) { | |
type = AnnotationErrorType.SHIFTING_LEFT; | |
} else { | |
type = AnnotationErrorType.SHIFTING_RIGHT; | |
} | |
AnnotationError err = new AnnotationError(error, truth, type); | |
if (!tempErrors.contains(err)) | |
tempErrors.add(err); | |
} | |
} | |
} | |
if (tempErrors.size() > 0) { | |
AnnotationError err = tempErrors.get(0); | |
int begin = err.getError().getAnnotation().getBegin(); | |
int end = err.getError().getAnnotation().getEnd(); | |
distance = Math.abs(begin - a.getBegin()) + Math.abs(end - a.getEnd()); | |
for (int j = 1; j < tempErrors.size(); j++) { | |
int begin2 = tempErrors.get(j).getError().getAnnotation().getBegin(); | |
int end2 = tempErrors.get(j).getError().getAnnotation().getEnd(); | |
if (Math.abs(begin2 - a.getBegin()) + Math.abs(end2 - a.getEnd()) < distance) { | |
distance = Math.abs(begin2 - a.getBegin()) + Math.abs(end2 - a.getEnd()); | |
err = tempErrors.get(j); | |
} | |
} | |
result.add(err); | |
matches.add((TrabalAnnotation) err.getTruth().getAnnotation()); | |
matches.add((TrabalAnnotation) err.getError().getAnnotation()); | |
} | |
} | |
// creating shifting + correction errors | |
iterator = gold.iterator(); | |
docIterator = docs.iterator(); | |
while (iterator.hasNext()) { | |
a = iterator.next(); | |
docIterator = docs.iterator(); | |
tempErrors = new ArrayList<AnnotationError>(); | |
while (!matches.contains(a) && docIterator.hasNext()) { | |
b = docIterator.next(); | |
if (!matches.contains(b)) { | |
if (b.getEnd() >= a.getBegin() && b.getBegin() <= a.getEnd()) { | |
TextRulerTarget target = new TextRulerTarget(b.getType().getName(), this); | |
TextRulerExample error = new TextRulerExample(b.getDocument(), b, false, target); | |
TextRulerExample truth = new TextRulerExample(a.getDocument(), a, true, target); | |
AnnotationErrorType type; | |
if ((b.getBegin() == a.getBegin() && b.getEnd() < a.getEnd()) | |
|| (b.getBegin() > a.getBegin() && b.getEnd() == a.getEnd())) { | |
type = AnnotationErrorType.EXPANSION; | |
} else if (b.getBegin() > a.getBegin() || b.getEnd() > a.getEnd()) { | |
type = AnnotationErrorType.SHIFTING_LEFT; | |
} else { | |
type = AnnotationErrorType.SHIFTING_RIGHT; | |
} | |
AnnotationError err = new AnnotationError(error, truth, type); | |
if (!tempErrors.contains(err)) | |
tempErrors.add(err); | |
} | |
} | |
} | |
if (tempErrors.size() > 0) { | |
AnnotationError err = tempErrors.get(0); | |
int begin = err.getError().getAnnotation().getBegin(); | |
int end = err.getError().getAnnotation().getEnd(); | |
distance = Math.abs(begin - a.getBegin()) + Math.abs(end - a.getEnd()); | |
for (int j = 1; j < tempErrors.size(); j++) { | |
int begin2 = tempErrors.get(j).getError().getAnnotation().getBegin(); | |
int end2 = tempErrors.get(j).getError().getAnnotation().getEnd(); | |
if (Math.abs(begin2 - a.getBegin()) + Math.abs(end2 - a.getEnd()) < distance) { | |
distance = Math.abs(begin2 - a.getBegin()) + Math.abs(end2 - a.getEnd()); | |
err = tempErrors.get(j); | |
} | |
} | |
result.add(err); | |
matches.add((TrabalAnnotation) err.getTruth().getAnnotation()); | |
matches.add((TrabalAnnotation) err.getError().getAnnotation()); | |
} | |
} | |
// create deleting type errors | |
docIterator = docs.iterator(); | |
while (docIterator.hasNext()) { | |
b = docIterator.next(); | |
if (!matches.contains(b)) { | |
TextRulerTarget target = new TextRulerTarget(b.getType().getName(), this); | |
TextRulerExample error = new TextRulerExample(b.getDocument(), b, false, target); | |
result.add(new AnnotationError(error, null, AnnotationErrorType.DELETION)); | |
matches.add(b); | |
} | |
} | |
// create annotation type errors | |
iterator = gold.iterator(); | |
while (iterator.hasNext()) { | |
a = iterator.next(); | |
if (!matches.contains(a)) { | |
TextRulerTarget target = new TextRulerTarget(a.getType().getName(), this); | |
TextRulerExample truth = new TextRulerExample(a.getDocument(), a, true, target); | |
result.add(new AnnotationError(null, truth, AnnotationErrorType.ANNOTATION)); | |
matches.add(a); | |
} | |
} | |
} | |
return result; | |
} | |
/** | |
* Is called by doRun(). Creates and tests rules for a given list of error groups created by | |
* createErrorGroups(). | |
* | |
* @param errorGrps | |
* @return The best found list of correction rules. | |
*/ | |
private List<TrabalRule> runAlgorithm(Map<String, List<AnnotationError>> errorGrps) { | |
removeBasics(); | |
inducedRules.clear(); | |
List<TrabalRule> rules = new ArrayList<TrabalRule>(); | |
bestRulesForStatus.clear(); | |
int i = 1; | |
for (List<AnnotationError> each : errorGrps.values()) { | |
if (shouldAbort()) { | |
break; | |
} | |
Collections.sort(each); | |
String status = each.get(0).toString(); | |
sendStatusUpdateToDelegate("Creating basic rules: " + status, | |
TextRulerLearnerState.ML_RUNNING, false); | |
List<TrabalRule> basicRules = removeDuplicateRules(createBasicRules(each)); | |
if (basicRules.size() > maxNumberOfBasicRules) { | |
basicRules = basicRules.subList(0, maxNumberOfBasicRules); | |
} | |
sendStatusUpdateToDelegate("Testing basic rules: " + status, | |
TextRulerLearnerState.ML_RUNNING, false); | |
basicRules = testTrabalRulesOnDocumentSet(basicRules, exampleDocuments, additionalDocuments, | |
"basic rules (" + i + " of " + errorGrps.size() + ")"); | |
if (basicRules.size() > 0) { | |
Collections.sort(basicRules, basicComparator); | |
bestRulesForStatus.add(basicRules.get(0)); | |
} | |
result = actualResult + getRuleStrings(bestRulesForStatus); | |
sendStatusUpdateToDelegate("Testing basic rules: " + status, | |
TextRulerLearnerState.ML_RUNNING, true); | |
List<TrabalRule> learntRules = new ArrayList<TrabalRule>(); | |
for (TrabalRule rule : basicRules) { | |
if (rule.getCoveringStatistics().getCoveredPositivesCount() > 0 | |
&& rule.getCoveringStatistics().getCoveredNegativesCount() == 0) { | |
learntRules.add(rule); | |
} | |
} | |
// Collections.sort(learntRules, basicComparator); | |
List<TrabalRule> enhancedRules = new ArrayList<TrabalRule>(); | |
int rank = 1; | |
for (TrabalRule rule : basicRules) { | |
if (rule.getCoveringStatistics().getCoveredPositivesCount() > 0 | |
&& rule.getCoveringStatistics().getCoveredNegativesCount() > 0) { | |
if (learntRules.size() > 0) { | |
if (rule.getCoveringStatistics().getCoveredPositivesCount() > learntRules.get(0) | |
.getCoveringStatistics().getCoveredPositivesCount()) { | |
rule.setRating(rank); | |
enhancedRules.add(rule); | |
rank++; | |
} | |
} else { | |
rule.setRating(rank); | |
enhancedRules.add(rule); | |
rank++; | |
} | |
} | |
} | |
basicRules.clear(); | |
try { | |
enhancedRules = enhanceRules(enhancedRules, maxNumberOfIterations, new RankedList(idf)); | |
Collections.sort(enhancedRules); | |
if (enhancedRules.size() > maxNumberOfRules) { | |
enhancedRules = enhancedRules.subList(0, maxNumberOfRules); | |
} | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
sendStatusUpdateToDelegate("Testing enhanced rules: " + status, | |
TextRulerLearnerState.ML_RUNNING, false); | |
enhancedRules = testTrabalRulesOnDocumentSet(enhancedRules, exampleDocuments, | |
additionalDocuments, "enhanced rules (" + i + " of " + errorGrps.size() + ")"); | |
for (TrabalRule rule : enhancedRules) { | |
if (rule.getErrorRate() <= maxErrorRate) { | |
learntRules.add(rule); | |
} | |
} | |
enhancedRules.clear(); | |
learntRules = removeDuplicateRules(learntRules); | |
if (learntRules.size() > 0) { | |
Collections.sort(learntRules, enhancedComparator); | |
bestRulesForStatus.remove(bestRulesForStatus.size() - 1); // TODO | |
bestRulesForStatus.add(learntRules.get(0)); | |
result = actualResult + getRuleStrings(bestRulesForStatus); | |
} | |
sendStatusUpdateToDelegate("Testing optimized rules: " + status, | |
TextRulerLearnerState.ML_RUNNING, true); | |
if (learntRules.size() > maxNumberOfRules) | |
learntRules = learntRules.subList(0, maxNumberOfRules); | |
rules.addAll(learntRules); | |
i++; | |
} | |
return getBest(rules); | |
} | |
private List<TrabalRule> enhanceRules(List<TrabalRule> learntRules, int iterations, | |
RankedList conditions) throws Exception { | |
if (shouldAbort() || learntRules.size() == 0) | |
return learntRules; | |
sendStatusUpdateToDelegate("Adding conditions to rules (" + iterations + ". Iteration)... ", | |
TextRulerLearnerState.ML_RUNNING, false); | |
if (iterations > 0) { | |
List<TrabalRule> rules = new ArrayList<TrabalRule>(); | |
List<TrabalRule> newRules = new ArrayList<TrabalRule>(); | |
if (conditions.size() == 0) | |
conditions = createConditions(learntRules); | |
for (int i = 0; i < learntRules.size(); i++) { | |
rules.add(learntRules.get(i)); | |
// TODO amount of conditions? parameter for 50! | |
for (int j = 0; j < conditions.size() && j < 50; j++) { | |
TrabalRule newRule = learntRules.get(i).copy(); | |
if (!newRule.getConditions().contains(conditions.get(j))) { | |
newRule.addCondition(conditions.get(j), (j + 1)); | |
newRule.getRuleString(); | |
newRules.add(newRule); | |
} | |
} | |
} | |
rules.addAll(enhanceRules(newRules, iterations - 1, conditions)); | |
return rules; | |
} | |
return learntRules; | |
} | |
/** | |
* Chooses the best final rules from the results of runAlgorithm(). | |
* | |
* @param rules | |
* @return A non redundant list of correction rules. | |
*/ | |
private List<TrabalRule> getBest(List<TrabalRule> rules) { | |
List<TrabalRule> result = new ArrayList<TrabalRule>(); | |
Collections.sort(rules, enhancedComparator); | |
while (rules.size() > 0) { | |
List<TrabalRule> testRules = new ArrayList<TrabalRule>(); | |
List<TrabalRule> newRules = new ArrayList<TrabalRule>(); | |
TrabalRule bestRule = rules.get(0); | |
if (bestRule.getErrorRate() > maxErrorRate) { | |
rules.remove(0); | |
continue; | |
} | |
if (bestRule.getCoveringStatistics().getCoveredPositivesCount() == 0) { | |
break; | |
} | |
sendStatusUpdateToDelegate("Chosen rule: " + bestRule, TextRulerLearnerState.ML_RUNNING, | |
false); | |
result.add(bestRule); | |
rules.remove(0); | |
Set<TextRulerExample> coveredExamples = bestRule.getCoveringStatistics() | |
.getCoveredPositiveExamples(); | |
for (TrabalRule rule : rules) { | |
TextRulerStatisticsCollector collector = rule.getCoveringStatistics(); | |
Iterator<TextRulerExample> iterator = coveredExamples.iterator(); | |
while (iterator.hasNext()) { | |
TextRulerExample example = iterator.next(); | |
if (collector.getCoveredPositiveExamples().contains(example)) { | |
collector.getCoveredPositiveExamples().remove(example); | |
// if (rule.getErrorType() == ErrorType.ANNOTATION | |
// || rule.getErrorType() == ErrorType.DELETING) { | |
// collector.incCoveredPositives(-1); | |
// } else { | |
collector.incCoveredPositives(-1); | |
// } | |
} | |
if (collector.getCoveredPositivesCount() <= 0) { | |
break; | |
} | |
} | |
if (rule.getErrorRate() > maxErrorRate | |
|| rule.getCoveringStatistics().getCoveredPositivesCount() <= 0) { | |
if (bestRulesForStatus.contains(rule)) | |
bestRulesForStatus.remove(rule); | |
this.result = actualResult + getRuleStrings(bestRulesForStatus); | |
continue; | |
} else if (rule.contains(bestRule.getAnnotation()) | |
|| rule.contains(bestRule.getTargetAnnotation())) { | |
testRules.add(rule.copy()); | |
} else { | |
newRules.add(rule); | |
} | |
} | |
rules = new ArrayList<TrabalRule>(); | |
updateDocumentData(bestRule); | |
removeBasics(); | |
testRules = testTrabalRulesOnDocumentSet(testRules, exampleDocuments, additionalDocuments, | |
"final rules "); | |
for (TrabalRule rule : testRules) { | |
if (rule.getErrorRate() <= maxErrorRate | |
&& rule.getCoveringStatistics().getCoveredPositivesCount() > 0) { | |
rules.add(rule); | |
} | |
} | |
rules.addAll(newRules); | |
Collections.sort(rules, enhancedComparator); | |
} | |
this.result = actualResult; | |
return result; | |
} | |
/** | |
* Is called from getBest(). Applies the chosen rules to the additional documents, so the learner | |
* can ignore the already corrected errors in the next interation. | |
* | |
* @param rule | |
*/ | |
private void updateDocumentData(TrabalRule rule) { | |
try { | |
sendStatusUpdateToDelegate("Writing rules...", TextRulerLearnerState.ML_RUNNING, false); | |
rule.saveToRulesFile(getTempRulesFileName()); | |
for (TextRulerExampleDocument doc : additionalDocuments.getDocuments()) { | |
ae.process(doc.getCAS()); | |
} | |
if (bestRulesForStatus.contains(rule)) { | |
bestRulesForStatus.remove(rule); | |
} | |
actualResult += rule.toString(); | |
result = actualResult + getRuleStrings(bestRulesForStatus); | |
sendStatusUpdateToDelegate("Writing rules...", TextRulerLearnerState.ML_RUNNING, true); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
/** | |
* Creates basic rules for the given types of errors. | |
* | |
* @param errors | |
* @return list of basic rules | |
*/ | |
public List<TrabalRule> createBasicRules(List<AnnotationError> errors) { | |
List<TrabalRule> result = new ArrayList<TrabalRule>(); | |
for (AnnotationError each : errors) { | |
switch (each.getType()) { | |
case SHIFTING_LEFT: | |
case SHIFTING_RIGHT: | |
case EXPANSION: | |
result.addAll(createShiftingRules(each)); | |
break; | |
case CORRECTION: | |
result.addAll(createCorrectionRules(each)); | |
break; | |
case ANNOTATION: | |
result.addAll(createAnnotationRules(each)); | |
break; | |
case DELETION: | |
result.addAll(createDeletionRules(each)); | |
break; | |
} | |
} | |
return getBestBasicRule(result); | |
} | |
/** | |
* Chooses the rules which are expected to be the best. | |
* | |
* @param rules | |
* @return list of best basic rules | |
*/ | |
private List<TrabalRule> getBestBasicRule(List<TrabalRule> rules) { | |
if (rules.size() < maxNumberOfBasicRules) | |
return rules; | |
List<TrabalRule> result = new ArrayList<TrabalRule>(); | |
sendStatusUpdateToDelegate("Choosing basic rules... ", TextRulerLearnerState.ML_RUNNING, false); | |
Map<Integer, Integer> distribution = new HashMap<Integer, Integer>(); | |
for (TrabalRule rule : rules) { | |
Integer key = new Integer(rule.hashCode()); | |
if (distribution.containsKey(key)) { | |
distribution.put(key, new Integer(distribution.get(key).intValue() + 1)); | |
} else { | |
distribution.put(key, new Integer(1)); | |
} | |
} | |
List<TrabalRule> newRules; | |
while (result.size() < maxNumberOfBasicRules && distribution.size() > 0) { | |
if (shouldAbort()) | |
break; | |
int max = 0; | |
Integer maxKey = new Integer(0); | |
for (Integer key : distribution.keySet()) { | |
if (distribution.get(key).intValue() > max) { | |
max = distribution.get(key).intValue(); | |
maxKey = key; | |
} | |
} | |
distribution.remove(maxKey); | |
newRules = new ArrayList<TrabalRule>(); | |
TrabalRule example = null; | |
for (TrabalRule rule : rules) { | |
Integer key = new Integer(rule.hashCode()); | |
if (distribution.containsKey(key)) { | |
if (distribution.get(key).intValue() == max) { | |
result.add(rule); | |
if (example == null) | |
example = rule; | |
} | |
} | |
} | |
for (TrabalRule rule : rules) { | |
newRules.add(rule); | |
} | |
rules = newRules; | |
} | |
if (result.size() > maxNumberOfBasicRules) | |
return result.subList(0, maxNumberOfBasicRules); | |
return result; | |
} | |
/** | |
* Creates basic rules for shifting type errors. | |
* | |
* @param error | |
* @return list of shifting rules | |
*/ | |
private List<TrabalRule> createShiftingRules(AnnotationError each) { | |
List<TrabalRule> result = new ArrayList<TrabalRule>(); | |
TextRulerExample error = each.getError(); | |
TextRulerExample truth = each.getTruth(); | |
int errorBegin = error.getAnnotation().getBegin(); | |
int errorEnd = error.getAnnotation().getEnd(); | |
int truthBegin = truth.getAnnotation().getBegin(); | |
int truthEnd = truth.getAnnotation().getEnd(); | |
if ((errorBegin > truthBegin && errorEnd > truthEnd) // shift both borders to left | |
|| (errorBegin < truthBegin && errorEnd < truthEnd)) { // shift both borders to right | |
result.addAll(createShiftAllRules(error, truth)); | |
} else if (errorBegin > truthBegin // expand to left | |
|| errorEnd < truthEnd) { // expand to right | |
result.addAll(createExpansionRules(error, truth)); | |
} else { | |
result.addAll(createShiftAllRules(error, truth)); // reduce the annotation | |
} | |
return result; | |
} | |
/** | |
* Creates basic rules for shifting errors, where both borders are shifted to the same side or the
* size of the annotation is reduced. | |
* | |
* @param error | |
* The error example | |
* @param truth | |
* The truth example | |
*/ | |
private List<TrabalRule> createShiftAllRules(TextRulerExample error, TextRulerExample truth) { | |
List<TrabalRule> result = new ArrayList<TrabalRule>(); | |
TextRulerExampleDocument document = error.getDocument(); | |
CAS cas = error.getDocumentCAS(); | |
List<TrabalRuleItem> truthLeftBorder = getBorderOfExample(truth, document, cas, true); | |
List<TrabalRuleItem> truthRightBorder = getBorderOfExample(truth, document, cas, false); | |
List<TrabalRuleItem> errorLeftBorder = getBorderOfExample(error, document, cas, true); | |
TextRulerTarget target = error.getTarget(); | |
TrabalAnnotation errorAnnotation = (TrabalAnnotation) error.getAnnotation(); | |
TrabalAnnotation truthAnnotation = (TrabalAnnotation) truth.getAnnotation(); | |
// boolean shiftToLeft = error.getAnnotation().getBegin() > truth.getAnnotation().getBegin(); | |
// if (shiftToLeft) { | |
for (TrabalRuleItem a : truthLeftBorder) { | |
for (TrabalRuleItem b : errorLeftBorder) { | |
for (TrabalRuleItem c : truthRightBorder) { | |
ShiftAllRule newRule; | |
if (truth.getAnnotation().getEnd() < error.getAnnotation().getEnd()) | |
newRule = new ShiftAllRule(this, target, errorAnnotation, truthAnnotation, | |
AnnotationErrorType.SHIFTING_LEFT); | |
else | |
newRule = new ShiftAllRule(this, target, errorAnnotation, truthAnnotation, | |
AnnotationErrorType.SHIFTING_RIGHT); | |
newRule.setFrontBoundaryItem(a); | |
if (b.getAnnotation().getBegin() != a.getAnnotation().getBegin()) { | |
newRule.setErrorBoundaryItem(b); | |
} | |
newRule.setRearBoundaryItem(c); | |
result.add(newRule); | |
} | |
} | |
} | |
// } else { | |
// for (TrabalRuleItem a : errorLeftBorder) { | |
// for (TrabalRuleItem b : truthLeftBorder) { | |
// for (TrabalRuleItem c : truthRightBorder) { | |
// ShiftingRule newRule = new ShiftingRule(this, target, errorAnnotation, truthAnnotation, | |
// AnnotationErrorType.SHIFTING_RIGHT); | |
// newRule.addFillerItem(a); | |
// if (b.getAnnotation().getBegin() != a.getAnnotation().getBegin()) { | |
// newRule.addFillerItem(b); | |
// } | |
// newRule.addFillerItem(c); | |
// result.add(newRule); | |
// } | |
// } | |
// } | |
// } | |
return result; | |
} | |
/** | |
* Creates basic rules for expansion type errors. | |
* | |
* @param error | |
* @param truth | |
* @return list of expansion rules | |
*/ | |
private List<TrabalRule> createExpansionRules(TextRulerExample error, TextRulerExample truth) { | |
List<TrabalRule> result = new ArrayList<TrabalRule>(); | |
TextRulerExampleDocument document = error.getDocument(); | |
CAS cas = error.getDocumentCAS(); | |
boolean shiftToLeft = error.getAnnotation().getBegin() > truth.getAnnotation().getBegin(); | |
List<TrabalRuleItem> border; | |
if (shiftToLeft) { | |
border = getBorderOfExample(truth, document, cas, true); | |
} else { | |
border = getBorderOfExample(truth, document, cas, false); | |
} | |
for (TrabalRuleItem item : border) { | |
ExpansionRule newRule = new ExpansionRule(this, error.getTarget(), | |
(TrabalAnnotation) error.getAnnotation(), (TrabalAnnotation) truth.getAnnotation(), | |
AnnotationErrorType.EXPANSION); | |
if (shiftToLeft) | |
newRule.setFrontBoundaryItem(item); | |
else | |
newRule.setRearBoundaryItem(item); | |
if (!result.contains(newRule)) | |
result.add(newRule); | |
} | |
return result; | |
} | |
/** | |
* Creates basic rules for annotation type errors. | |
* | |
* @param error | |
* @return list of annotation rules | |
*/ | |
private List<TrabalRule> createAnnotationRules(AnnotationError each) { | |
List<TrabalRule> result = new ArrayList<TrabalRule>(); | |
TextRulerExample truth = each.getTruth(); | |
TextRulerExampleDocument document = additionalDocuments.getDocuments().get( | |
exampleDocuments.getDocuments().indexOf(truth.getDocument())); | |
CAS cas = document.getCAS(); | |
List<TrabalRuleItem> truthLeftBorder = getBorderOfExample(truth, document, cas, true); | |
List<TrabalRuleItem> truthRightBorder = getBorderOfExample(truth, document, cas, false); | |
TextRulerTarget target = truth.getTarget(); | |
TrabalAnnotation truthAnnotation = (TrabalAnnotation) truth.getAnnotation(); | |
for (TrabalRuleItem front : truthLeftBorder) { | |
for (TrabalRuleItem rear : truthRightBorder) { | |
AnnotationRule newRule = new AnnotationRule(this, target, truthAnnotation); | |
newRule.setFrontBoundary(front); | |
newRule.setRearBoundary(rear); | |
if (!result.contains(newRule)) | |
result.add(newRule); | |
} | |
} | |
return result; | |
} | |
/** | |
* Creates basic rules for deletion type errors. | |
* | |
* @param error | |
* @return list of deletion rules | |
*/ | |
private List<TrabalRule> createDeletionRules(AnnotationError each) { | |
List<TrabalRule> result = new ArrayList<TrabalRule>(); | |
TextRulerExample error = each.getError(); | |
TrabalRule blankRule = new DeletionRule(this, error.getTarget(), | |
(TrabalAnnotation) error.getAnnotation()); | |
result.add(blankRule); | |
return result; | |
} | |
/** | |
* Creates basic rules for correction type errors. | |
* | |
* @param error | |
* @return list of correction rules | |
*/ | |
private List<TrabalRule> createCorrectionRules(AnnotationError each) { | |
List<TrabalRule> result = new ArrayList<TrabalRule>(); | |
TextRulerExample error = each.getError(); | |
TextRulerExample truth = each.getTruth(); | |
TrabalRule blankRule = new CorrectionRule(this, error.getTarget(), | |
(TrabalAnnotation) error.getAnnotation(), (TrabalAnnotation) truth.getAnnotation()); | |
result.add(blankRule); | |
return result; | |
} | |
/** | |
* Examines the border of an example annotation and returns all annotations before and on that | |
* border if looking at the left border or returns all annotations on and after the example if | |
* looking at the right border. | |
* | |
* @param example | |
* The example that should be examined | |
* @param examineLeftBorder | |
* set true for the left border or false for the right border to be examined | |
*/ | |
private List<TrabalRuleItem> getBorderOfExample(TextRulerExample example, | |
TextRulerExampleDocument document, CAS cas, boolean examineLeftBorder) { | |
List<TrabalRuleItem> result = new ArrayList<TrabalRuleItem>(); | |
int begin = example.getAnnotation().getBegin(); | |
int end = example.getAnnotation().getEnd(); | |
TrabalRuleItem ruleItem = new TrabalRuleItem(begin, end, cas, enableFeatures); | |
Type frameType = cas.getTypeSystem().getType(ANNOTATION_TYPE_FRAME); | |
AnnotationFS pointer = cas.createAnnotation(frameType, begin, end); | |
FSIterator<AnnotationFS> iterator = cas.getAnnotationIndex().subiterator(pointer); | |
if (examineLeftBorder) { | |
result.addAll(getTermsBefore(ruleItem, document)); | |
iterator.moveToFirst(); | |
while (iterator.isValid()) { | |
FeatureStructure fs = iterator.get(); | |
AnnotationFS a = (AnnotationFS) fs; | |
if (a.getBegin() == begin && a.getEnd() < end) { | |
if (!filterSet.contains(a.getType().getName()) | |
&& !a.getType().getName().equals(ANNOTATION_TYPE_BASIC)) { | |
result.add(new TrabalRuleItem(new TrabalAnnotation(a, document, enableFeatures))); | |
} | |
} else { | |
if (a.getBegin() > begin) { | |
break; | |
} | |
} | |
iterator.moveToNext(); | |
} | |
} else { | |
iterator.moveToLast(); | |
while (iterator.isValid()) { | |
FeatureStructure fs = iterator.get(); | |
AnnotationFS a = (AnnotationFS) fs; | |
if (a.getBegin() > begin && a.getEnd() == end) { | |
if (!filterSet.contains(a.getType().getName()) | |
&& !a.getType().getName().equals(ANNOTATION_TYPE_BASIC)) { | |
result.add(new TrabalRuleItem(new TrabalAnnotation(a, document, enableFeatures))); | |
} | |
} else { | |
break; | |
} | |
iterator.moveToPrevious(); | |
} | |
result.addAll(getTermsAfter(ruleItem, document)); | |
} | |
return result; | |
} | |
/** | |
* Creates conditions for a given rule. | |
* | |
* @param rules | |
* @return ranked list of conditions | |
*/ | |
public RankedList createConditions(List<TrabalRule> rules) { | |
if (rules.size() == 0) | |
return new RankedList(idf); | |
RankedList result; | |
RankedList error = new RankedList(idf); | |
RankedList truth = new RankedList(idf); | |
switch (rules.get(0).getErrorType()) { | |
case SHIFTING_LEFT: | |
case SHIFTING_RIGHT: | |
case EXPANSION: | |
for (TrabalRule rule : rules) { | |
error.addAll(createConditions(rule.getAnnotation())); | |
truth.addAll(createConditions(rule.getTargetAnnotation())); | |
} | |
result = error.subtract(truth.unite(getPositiveExamplesFor(rules.get(0) | |
.getTargetAnnotation().getType()))); | |
setNegative(result); | |
result.addAll(truth.cut( | |
getPositiveExamplesFor(rules.get(0).getTargetAnnotation().getType())).subtract( | |
error)); | |
return result; | |
case CORRECTION: | |
for (TrabalRule rule : rules) { | |
truth.addAll(createConditions(rule.getAnnotation())); | |
} | |
result = getPositiveExamplesFor(rules.get(0).getAnnotation().getType()).subtract( | |
truth.unite(getPositiveExamplesFor(rules.get(0).getTargetAnnotation().getType()))); | |
setNegative(result); | |
if (getPositiveExamplesFor(rules.get(0).getTargetAnnotation().getType()).size() > 0) { | |
result.addAll(truth.cut( | |
getPositiveExamplesFor(rules.get(0).getTargetAnnotation().getType())).subtract( | |
getPositiveExamplesFor(rules.get(0).getAnnotation().getType()))); | |
} else { | |
result.addAll(truth); | |
} | |
return result; | |
case ANNOTATION: | |
for (TrabalRule rule : rules) { | |
truth.addAll(createConditions(rule.getFrontBoundary(), rule.getRearBoundary(), | |
rule.getTargetAnnotation())); | |
} | |
if (getPositiveExamplesFor(rules.get(0).getTargetAnnotation().getType()).size() > 0) { | |
return truth.cut(getPositiveExamplesFor(rules.get(0).getTargetAnnotation().getType())); | |
} | |
return truth; | |
case DELETION: | |
for (TrabalRule rule : rules) { | |
error.addAll(createConditions(rule.getAnnotation())); | |
} | |
result = getPositiveExamplesFor(rules.get(0).getAnnotation().getType()).subtract(error); | |
setNegative(result); | |
result.addAll(createConditions(rules.get(0).getAnnotation()).subtract( | |
getPositiveExamplesFor(rules.get(0).getAnnotation().getType()))); | |
return result; | |
} | |
return null; | |
} | |
public RankedList getPositiveExamplesFor(Type type) { | |
if (positiveExamples.containsKey(type.getShortName())) { | |
return positiveExamples.get(type.getShortName()); | |
} else { | |
return new RankedList(idf); | |
} | |
} | |
public void setNegative(RankedList list) { | |
for (Condition element : list.getRanking().keySet()) { | |
element.setNegative(); | |
} | |
for (Condition element : list) { | |
element.setNegative(); | |
} | |
} | |
/** | |
* Creates conditions for a rule given by its boundaries and annotation type. | |
* | |
* @param frontBoundary | |
* @param rearBoundary | |
* @param truth | |
* @return ranked list of conditions | |
*/ | |
private RankedList createConditions(TrabalRuleItem frontBoundary, TrabalRuleItem rearBoundary, | |
TrabalAnnotation truth) { | |
RankedList result = new RankedList(idf); | |
TextRulerExampleDocument doc = additionalDocuments.getDocuments().get( | |
exampleDocuments.getDocuments().indexOf(truth.getDocument())); | |
CAS cas = doc.getCAS(); | |
for (TrabalRuleItem item : getTermsBefore(frontBoundary, doc)) { | |
result.add(new Condition(ConditionType.AFTER, item)); | |
} | |
for (TrabalRuleItem item : getTermsAfter(rearBoundary, doc)) { | |
result.add(new Condition(ConditionType.BEFORE, item)); | |
} | |
for (TrabalRuleItem item : getConsumingTerms(frontBoundary, doc)) { | |
if (!item.getAnnotation().getType().equals(cas.getDocumentAnnotation().getType())) { | |
result.add(new Condition(ConditionType.PARTOF, item)); | |
} | |
} | |
if (frontBoundary.getAnnotation().getBegin() < truth.getBegin()) { | |
for (TrabalRuleItem item : getTermsAfter(frontBoundary, doc)) { | |
result.add(new Condition(ConditionType.STARTSWITH, item)); | |
} | |
} else { | |
for (List<TrabalRuleItem> list : getFirstTermsWithinBounds(frontBoundary.getAnnotation() | |
.getBegin(), rearBoundary.getAnnotation().getEnd(), doc, cas, 1)) { | |
result.add(new Condition(ConditionType.STARTSWITH, list.get(0))); | |
} | |
} | |
if (rearBoundary.getAnnotation().getEnd() > truth.getEnd()) { | |
for (TrabalRuleItem item : getTermsBefore(rearBoundary, doc)) { | |
result.add(new Condition(ConditionType.ENDSWITH, item)); | |
} | |
} else { | |
for (List<TrabalRuleItem> list : getLastTermsWithinBounds(rearBoundary.getAnnotation() | |
.getBegin(), frontBoundary.getAnnotation().getEnd(), doc, cas, 1)) { | |
result.add(new Condition(ConditionType.ENDSWITH, list.get(0))); | |
} | |
} | |
return result; | |
} | |
/** | |
* Creates conditions for a rule, given by its annotation. | |
* | |
* @param error | |
* @return ranked list of conditions | |
*/ | |
private RankedList createConditions(TrabalAnnotation annotation) { | |
RankedList result = new RankedList(idf); | |
TrabalRuleItem ruleItem = new TrabalRuleItem(annotation); | |
TextRulerExampleDocument doc = annotation.getDocument(); | |
CAS cas = doc.getCAS(); | |
for (TrabalRuleItem item : getTermsBefore(ruleItem, doc)) { | |
result.add(new Condition(ConditionType.AFTER, item)); | |
} | |
for (TrabalRuleItem item : getTermsAfter(ruleItem, doc)) { | |
result.add(new Condition(ConditionType.BEFORE, item)); | |
} | |
for (TrabalRuleItem item : getConsumingTerms(ruleItem, doc)) { | |
if (!item.getAnnotation().getType().equals(cas.getDocumentAnnotation().getType())) { | |
result.add(new Condition(ConditionType.PARTOF, item)); | |
} | |
} | |
for (List<TrabalRuleItem> list : getFirstTermsWithinBounds(ruleItem.getAnnotation().getBegin(), | |
ruleItem.getAnnotation().getEnd(), doc, cas, 1)) { | |
result.add(new Condition(ConditionType.STARTSWITH, list.get(0))); | |
} | |
for (List<TrabalRuleItem> list : getLastTermsWithinBounds(ruleItem.getAnnotation().getBegin(), | |
ruleItem.getAnnotation().getEnd(), doc, cas, 1)) { | |
result.add(new Condition(ConditionType.ENDSWITH, list.get(0))); | |
} | |
for (TrabalRuleItem item : getSingleTermsWithinBounds(ruleItem.getAnnotation().getBegin(), | |
ruleItem.getAnnotation().getEnd(), doc, cas)) { | |
result.add(new Condition(ConditionType.CONTAINS, item)); | |
} | |
return result; | |
} | |
/** | |
* Returns all items directly preceding the given example rule item. | |
* | |
* @param ruleItem | |
* Examined rule item | |
* @param document | |
* Example document | |
* @param cas | |
* The CAS | |
* @return List of rule items that precede the given example | |
*/ | |
private List<TrabalRuleItem> getTermsBefore(TrabalRuleItem ruleItem, | |
TextRulerExampleDocument document) { | |
List<TrabalRuleItem> result = new ArrayList<TrabalRuleItem>(); | |
int begin = ruleItem.getAnnotation().getBegin(); | |
CAS cas = document.getCAS(); | |
Type frameType = cas.getTypeSystem().getType(ANNOTATION_TYPE_FRAME); | |
AnnotationFS pointer = cas.createAnnotation(frameType, begin, begin); | |
FSIterator<AnnotationFS> iterator = cas.getAnnotationIndex().iterator(pointer); | |
int nextEnd = -1; | |
while (iterator.isValid()) { | |
FeatureStructure fs = iterator.get(); | |
if (fs instanceof AnnotationFS) { | |
AnnotationFS a = (AnnotationFS) fs; | |
if (!filterSet.contains(a.getType().getName()) | |
&& !a.getType().getName().equals(ANNOTATION_TYPE_BASIC)) { | |
if (a.getEnd() > begin) { | |
iterator.moveToPrevious(); | |
continue; | |
} | |
if (nextEnd == -1) { | |
nextEnd = a.getEnd(); | |
} | |
if (a.getEnd() >= nextEnd && a.getEnd() <= begin) { | |
TrabalRuleItem term = new TrabalRuleItem(new TrabalAnnotation(a, document, | |
enableFeatures)); | |
result.add(term); | |
} | |
} | |
} | |
iterator.moveToPrevious(); | |
} | |
return result; | |
} | |
/** | |
* Returns all items directly following the given example rule item. | |
* | |
* @param ruleItem | |
* Examined rule item | |
* @param document | |
* Example document | |
* @param cas | |
* The CAS | |
* @return List of rule items that follow the given example | |
*/ | |
private List<TrabalRuleItem> getTermsAfter(TrabalRuleItem ruleItem, | |
TextRulerExampleDocument document) { | |
List<TrabalRuleItem> result = new ArrayList<TrabalRuleItem>(); | |
int end = ruleItem.getAnnotation().getEnd(); | |
CAS cas = document.getCAS(); | |
Type frameType = cas.getTypeSystem().getType(ANNOTATION_TYPE_FRAME); | |
AnnotationFS pointer = cas.createAnnotation(frameType, end, Integer.MAX_VALUE); | |
FSIterator<AnnotationFS> iterator = cas.getAnnotationIndex().iterator(pointer); | |
int nextBegin = -1; | |
while (iterator.isValid()) { | |
FeatureStructure fs = iterator.get(); | |
if (fs instanceof AnnotationFS) { | |
AnnotationFS a = (AnnotationFS) fs; | |
if (!filterSet.contains(a.getType().getName()) | |
&& !a.getType().getName().equals(ANNOTATION_TYPE_BASIC)) { | |
if (nextBegin == -1) { | |
nextBegin = a.getBegin(); | |
} | |
if (a.getBegin() <= nextBegin && a.getBegin() >= end) { | |
TrabalRuleItem term = new TrabalRuleItem(new TrabalAnnotation(a, document, | |
enableFeatures)); | |
result.add(term); | |
} | |
} | |
} | |
iterator.moveToNext(); | |
} | |
return result; | |
} | |
/** | |
* Gets the left border within the bounds, that means a list of elements that start with the | |
* border and those that end right before it. | |
* | |
* @param startPos | |
* @param endPos | |
* @param document | |
* @param cas | |
* @param numberOfSideItems | |
* @return first term | |
*/ | |
private List<List<TrabalRuleItem>> getFirstTermsWithinBounds(int startPos, int endPos, | |
TextRulerExampleDocument document, CAS cas, int numberOfSideItems) { | |
List<List<TrabalRuleItem>> preItems = new ArrayList<List<TrabalRuleItem>>(); | |
Type frameType = cas.getTypeSystem().getType(ANNOTATION_TYPE_FRAME); | |
AnnotationFS pointer = cas.createAnnotation(frameType, startPos, endPos); | |
FSIterator<AnnotationFS> iterator = cas.getAnnotationIndex().iterator(pointer); | |
List<AnnotationFS> startAs = new ArrayList<AnnotationFS>(); | |
int firstBegin = -1; | |
while (iterator.isValid()) { | |
FeatureStructure fs = iterator.get(); | |
AnnotationFS a = (AnnotationFS) fs; | |
if (a.getBegin() >= startPos && a.getEnd() <= endPos) { | |
if (!filterSet.contains(a.getType().getName()) | |
&& !a.getType().getName().equals(ANNOTATION_TYPE_BASIC)) { | |
if (firstBegin == -1) { | |
firstBegin = a.getBegin(); | |
} | |
if (a.getBegin() == firstBegin) | |
startAs.add(a); | |
} | |
iterator.moveToNext(); | |
} else { | |
iterator.moveToNext(); | |
} | |
} | |
for (AnnotationFS each : startAs) { | |
List<TrabalRuleItem> startList = new ArrayList<TrabalRuleItem>(); | |
TrabalRuleItem term = new TrabalRuleItem(new TrabalAnnotation(each, document, enableFeatures)); | |
startList.add(term); | |
preItems.add(startList); | |
} | |
preItems = addFollowing(preItems, endPos, document, cas, 1, numberOfSideItems); | |
return preItems; | |
} | |
/** | |
* Gets the right border within the bounds, that means a list of elements that end with the border | |
* and those that start right behind it. | |
* | |
* @param startPos | |
* @param endPos | |
* @param document | |
* @param cas | |
* @param numberOfSideItems | |
* @return last term | |
*/ | |
private List<List<TrabalRuleItem>> getLastTermsWithinBounds(int startPos, int endPos, | |
TextRulerExampleDocument document, CAS cas, int numberOfSideItems) { | |
List<List<TrabalRuleItem>> postItems = new ArrayList<List<TrabalRuleItem>>(); | |
Type frameType = cas.getTypeSystem().getType(ANNOTATION_TYPE_FRAME); | |
AnnotationFS pointer = cas.createAnnotation(frameType, startPos, endPos); | |
FSIterator<AnnotationFS> iterator = cas.getAnnotationIndex().iterator(pointer); | |
iterator.moveToLast(); | |
List<AnnotationFS> endAs = new ArrayList<AnnotationFS>(); | |
int lastEnd = -1; | |
while (iterator.isValid()) { | |
FeatureStructure fs = iterator.get(); | |
AnnotationFS a = (AnnotationFS) fs; | |
if (a.getEnd() <= endPos && a.getBegin() >= startPos) { | |
if (!filterSet.contains(a.getType().getName()) | |
&& !a.getType().getName().equals(ANNOTATION_TYPE_BASIC)) { | |
if (lastEnd == -1) { | |
lastEnd = a.getEnd(); | |
} | |
if (a.getEnd() == lastEnd) | |
endAs.add(a); | |
} | |
iterator.moveToPrevious(); | |
} else { | |
iterator.moveToPrevious(); | |
} | |
} | |
for (AnnotationFS each : endAs) { | |
List<TrabalRuleItem> endList = new ArrayList<TrabalRuleItem>(); | |
TrabalRuleItem term = new TrabalRuleItem(new TrabalAnnotation(each, document, enableFeatures)); | |
endList.add(term); | |
postItems.add(endList); | |
} | |
postItems = addPreceding(postItems, startPos, document, cas, 1, numberOfSideItems); | |
return postItems; | |
} | |
/** | |
* Returns a list of rule elements that precede the given elements. | |
* | |
* @param lists | |
* @param till | |
* @param document | |
* @param cas | |
* @param index | |
* @param maxNumberOfItems | |
* @return updated list of items | |
*/ | |
private List<List<TrabalRuleItem>> addPreceding(List<List<TrabalRuleItem>> lists, int till, | |
TextRulerExampleDocument document, CAS cas, int index, int maxNumberOfItems) { | |
if (index >= maxNumberOfItems) { | |
return lists; | |
} | |
List<List<TrabalRuleItem>> result = new ArrayList<List<TrabalRuleItem>>(); | |
for (List<TrabalRuleItem> list : lists) { | |
if (shouldAbort()) { | |
break; | |
} | |
List<List<TrabalRuleItem>> newLists = new ArrayList<List<TrabalRuleItem>>(); | |
TrabalRuleItem first = list.get(0); | |
List<TrabalRuleItem> termsBefore = getTermsBefore(first, document); | |
if (termsBefore.isEmpty()) { | |
continue; | |
} | |
for (TrabalRuleItem eachBefore : termsBefore) { | |
if (eachBefore.getAnnotation().getBegin() >= till) { | |
List<TrabalRuleItem> newList = new ArrayList<TrabalRuleItem>(); | |
newList.addAll(list); | |
newList.add(0, eachBefore); | |
newLists.add(newList); | |
newLists = addPreceding(newLists, till, document, cas, index + 1, maxNumberOfItems); | |
} | |
} | |
result.addAll(newLists); | |
} | |
if (result.isEmpty()) { | |
return lists; | |
} | |
return result; | |
} | |
/** | |
* Returns a list of rule elements that follow on the given elements. | |
* | |
* @param lists | |
* @param till | |
* @param document | |
* @param cas | |
* @param index | |
* @param maxNumberOfItems | |
* @return updated list of items | |
*/ | |
private List<List<TrabalRuleItem>> addFollowing(List<List<TrabalRuleItem>> lists, int till, | |
TextRulerExampleDocument document, CAS cas, int index, int maxNumberOfItems) { | |
if (index >= maxNumberOfItems) { | |
return lists; | |
} | |
List<List<TrabalRuleItem>> result = new ArrayList<List<TrabalRuleItem>>(); | |
for (List<TrabalRuleItem> list : lists) { | |
if (shouldAbort()) { | |
break; | |
} | |
List<List<TrabalRuleItem>> newLists = new ArrayList<List<TrabalRuleItem>>(); | |
TrabalRuleItem last = list.get(list.size() - 1); | |
List<TrabalRuleItem> termsAfter = getTermsAfter(last, document); | |
if (termsAfter.isEmpty()) { | |
continue; | |
} | |
for (TrabalRuleItem eachAfter : termsAfter) { | |
if (eachAfter.getAnnotation().getEnd() <= till) { | |
List<TrabalRuleItem> newList = new ArrayList<TrabalRuleItem>(); | |
newList.addAll(list); | |
newList.add(eachAfter); | |
newLists.add(newList); | |
newLists = addFollowing(newLists, till, document, cas, index + 1, maxNumberOfItems); | |
} | |
} | |
result.addAll(newLists); | |
} | |
if (result.isEmpty()) { | |
return lists; | |
} | |
return result; | |
} | |
/** | |
* Returns a list of all items, consuming the given RuleItem. | |
*/ | |
private List<TrabalRuleItem> getConsumingTerms(TrabalRuleItem ruleItem, | |
TextRulerExampleDocument document) { | |
CAS cas = document.getCAS(); | |
List<TrabalRuleItem> result = new ArrayList<TrabalRuleItem>(); | |
int begin = ruleItem.getAnnotation().getBegin(); | |
Type frameType = cas.getTypeSystem().getType(ANNOTATION_TYPE_FRAME); | |
AnnotationFS pointer = cas.createAnnotation(frameType, 0, begin); | |
FSIterator<AnnotationFS> iterator = cas.getAnnotationIndex().iterator(pointer); | |
while (iterator.hasNext()) { | |
TrabalAnnotation annotation = new TrabalAnnotation(iterator.next(), enableFeatures); | |
if (annotation.getBegin() >= begin) { | |
break; | |
} | |
if (annotation.getEnd() >= ruleItem.getAnnotation().getEnd()) { | |
result.add(new TrabalRuleItem(annotation)); | |
} | |
} | |
return result; | |
} | |
/** | |
* Returns a list of rule items within the given bounds. | |
* | |
* @param begin | |
* @param end | |
* @param doc | |
* @param cas | |
* @return list of terms | |
*/ | |
private List<TrabalRuleItem> getSingleTermsWithinBounds(int begin, int end, | |
TextRulerExampleDocument doc, CAS cas) { | |
Set<TrabalRuleItem> result = new HashSet<TrabalRuleItem>(); | |
Type frameType = cas.getTypeSystem().getType(ANNOTATION_TYPE_FRAME); | |
AnnotationFS pointer = cas.createAnnotation(frameType, begin, end); | |
FSIterator<AnnotationFS> iterator = cas.getAnnotationIndex().iterator(pointer); | |
while (iterator.hasNext()) { | |
TrabalAnnotation annotation = new TrabalAnnotation(iterator.next(), enableFeatures); | |
if (annotation.getBegin() > end) { | |
break; | |
} | |
if (!filterSet.contains(annotation.getType().getName()) | |
&& !annotation.getType().getName().equals(ANNOTATION_TYPE_BASIC)) { | |
result.add(new TrabalRuleItem(annotation)); | |
} | |
} | |
return new ArrayList<TrabalRuleItem>(result); | |
} | |
/** | |
* Applies a set of rules to the additional data and compares it to the training data to rate | |
* them. | |
* | |
* @param rules | |
* @param documents | |
* @param additionalDocuments | |
* @param ruleSet | |
* @return list of rules | |
*/ | |
public List<TrabalRule> testTrabalRulesOnDocumentSet(List<TrabalRule> rules, | |
final TextRulerExampleDocumentSet documents, | |
final TextRulerExampleDocumentSet additionalDocuments, String ruleSet) { | |
if (rules.isEmpty()) | |
return rules; | |
List<TextRulerStatisticsCollector> sums = new ArrayList<TextRulerStatisticsCollector>(); | |
for (TrabalRule each : rules) { | |
sums.add(new TextRulerStatisticsCollector()); | |
} | |
List<TextRulerExampleDocument> goldDocs; | |
List<TextRulerExampleDocument> additionalDocs; | |
goldDocs = documents.getDocuments(); | |
additionalDocs = additionalDocuments.getDocuments(); | |
CAS theTestCAS = getTestCAS(); | |
int counter = 0; | |
for (TrabalRule rule : rules) { | |
counter++; | |
String ruleString = rule.getRuleString(); | |
String ruleInfo = getRuleInfo(rule); | |
System.out.println("testing: " + ruleString); | |
if (inducedRules.containsKey(ruleString)) { | |
rule.setCoveringStatistics(inducedRules.get(ruleString)); | |
System.out.println("skipped with " + inducedRules.get(ruleString)); | |
} else { | |
for (int i = 0; i < goldDocs.size(); i++) { | |
TextRulerExampleDocument goldDoc = goldDocs.get(i); | |
TextRulerExampleDocument additionalDoc = additionalDocs.get(i); | |
sendStatusUpdateToDelegate("Testing " + ruleSet + ruleInfo + " on document " + (i + 1) | |
+ " of " + goldDocs.size() + " : rule " + counter + " of " + rules.size(), | |
TextRulerLearnerState.ML_RUNNING, false); | |
TextRulerStatisticsCollector sumC = new TextRulerStatisticsCollector(); | |
prepareTestCas(theTestCAS, goldDoc, additionalDoc); | |
testRuleOnDocument(rule, goldDoc, additionalDoc, sumC, theTestCAS); | |
sums.get(counter - 1).add(sumC); | |
int n = sumC.getCoveredNegativesCount(); | |
int p = sumC.getCoveredPositivesCount(); | |
int pnorm = p; | |
if (pnorm == 0) { | |
pnorm = 1; | |
} | |
if (n / pnorm > maxErrorRate) { | |
System.out.println("stopped:" + sumC); | |
break; | |
} | |
if (shouldAbort()) | |
return rules; | |
} | |
TextRulerStatisticsCollector c = sums.get(counter - 1); | |
rule.setCoveringStatistics(sums.get(counter - 1)); | |
inducedRules.put(ruleString, c); | |
} | |
} | |
for (int ruleIndex = 0; ruleIndex < rules.size(); ruleIndex++) { | |
rules.get(ruleIndex).setCoveringStatistics(sums.get(ruleIndex)); | |
} | |
// GlobalCASSource.releaseCAS(theTestCAS); | |
sums.clear(); | |
return rules; | |
} | |
private String getRuleInfo(TrabalRule rule) { | |
String ruleInfo; | |
if (rule.getAnnotation() != null && rule.getTargetAnnotation() != null) { | |
ruleInfo = " " + rule.getAnnotation().getType().getShortName() + "(" | |
+ rule.getAnnotation().getBegin() + "," + rule.getAnnotation().getEnd() + ") -> " | |
+ rule.getTargetAnnotation().getType().getShortName() + "(" | |
+ rule.getTargetAnnotation().getBegin() + "," + rule.getTargetAnnotation().getEnd() | |
+ ")"; | |
} else if (rule.getTargetAnnotation() != null) { | |
ruleInfo = " Annotate " + rule.getTargetAnnotation().getType().getShortName() + "(" | |
+ rule.getTargetAnnotation().getBegin() + "," + rule.getTargetAnnotation().getEnd() | |
+ ")"; | |
} else { | |
ruleInfo = " Delete " + rule.getAnnotation().getType().getShortName() + "(" | |
+ rule.getAnnotation().getBegin() + "," + rule.getAnnotation().getEnd() + ")"; | |
} | |
return ruleInfo; | |
} | |
/** | |
* Resets the test CAS and adds the annotations of the additional data. | |
* | |
* @param testCas | |
* @param goldDoc | |
* @param additionalDoc | |
*/ | |
private void prepareTestCas(CAS testCas, TextRulerExampleDocument goldDoc, | |
TextRulerExampleDocument additionalDoc) { | |
testCas.reset(); | |
CAS goldCas = goldDoc.getCAS(); | |
CAS additionalCas = additionalDoc.getCAS(); | |
testCas.setDocumentText(goldCas.getDocumentText()); | |
CasCopier.copyCas(additionalCas, testCas, testCas.getDocumentText() == null); | |
// for (AnnotationFS fs : additionalCas.getAnnotationIndex()) { | |
// Type t = testCas.getTypeSystem().getType(fs.getType().getName()); | |
// if (t != null) { | |
// // TODO what about the features!! | |
// AnnotationFS createAnnotation = testCas.createAnnotation(t, fs.getBegin(), fs.getEnd()); | |
// testCas.addFsToIndexes(createAnnotation); | |
// } else { | |
// TextRulerToolkit.log("Type " + fs.getType().getName() + "is unknown in test CAS"); | |
// } | |
// } | |
} | |
/** | |
* Applies a rule to an element of the additional data and compares the result to the | |
* corresponding element of the training data. | |
* | |
* @param rule | |
* @param goldDoc | |
* @param additionalDoc | |
* @param c | |
* @param testCas | |
*/ | |
private void testRuleOnDocument(final TrabalRule rule, final TextRulerExampleDocument goldDoc, | |
final TextRulerExampleDocument additionalDoc, final TextRulerStatisticsCollector c, | |
CAS testCas) { | |
if (TextRulerToolkit.DEBUG) { | |
MemoryWatch.watch(); | |
} | |
try { | |
rule.saveToRulesFile(getTempRulesFileName()); | |
if (TextRulerToolkit.DEBUG) { | |
TextRulerToolkit.writeCAStoXMIFile(testCas, tempDirectory() + "testCas.xmi"); | |
} | |
ae.process(testCas); | |
removeBasics(testCas); | |
if (TextRulerToolkit.DEBUG) { | |
TextRulerToolkit.writeCAStoXMIFile(testCas, tempDirectory() + "testCasProcessed.xmi"); | |
} | |
if (rule.getAnnotation() != null && rule.getTargetAnnotation() != null) { | |
compareOriginalDocumentWithTestCAS(goldDoc, additionalDoc, testCas, new TextRulerTarget( | |
rule.getAnnotation().getType().getName(), this), c, false); | |
if (rule.getAnnotation().getType() != rule.getTargetAnnotation().getType()) { | |
compareOriginalDocumentWithTestCAS(goldDoc, additionalDoc, testCas, new TextRulerTarget( | |
rule.getTargetAnnotation().getType().getName(), this), c, false); | |
} | |
} else if (rule.getTargetAnnotation() != null) { | |
compareOriginalDocumentWithTestCAS(goldDoc, additionalDoc, testCas, new TextRulerTarget( | |
rule.getTargetAnnotation().getType().getName(), this), c, false); | |
} else { | |
compareOriginalDocumentWithTestCAS(goldDoc, additionalDoc, testCas, new TextRulerTarget( | |
rule.getAnnotation().getType().getName(), this), c, false); | |
} | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
/** | |
* Compares two documents to rate the results of an applied rule. | |
* | |
* @param goldDoc | |
* @param additionalDoc | |
* @param testCas | |
* @param target | |
* @param c | |
* @param collectNegativeExamples | |
*/ | |
private void compareOriginalDocumentWithTestCAS(TextRulerExampleDocument goldDoc, | |
TextRulerExampleDocument additionalDoc, CAS testCas, TextRulerTarget target, | |
TextRulerStatisticsCollector c, boolean collectNegativeExamples) { | |
List<TextRulerExample> goldPositives = goldDoc.createSlotInstancesForCAS(goldDoc.getCAS(), | |
target, true); | |
List<TextRulerExample> additionalPositives = additionalDoc.createSlotInstancesForCAS( | |
additionalDoc.getCAS(), target, true); | |
List<TextRulerExample> testPositives = goldDoc | |
.createSlotInstancesForCAS(testCas, target, false); | |
List<TextRulerExample> baseFP = new ArrayList<TextRulerExample>(); | |
for (TextRulerExample e : additionalPositives) { | |
TextRulerExample example = TextRulerToolkit.exampleListContainsAnnotation(goldPositives, | |
e.getAnnotation()); | |
if (example == null) { | |
baseFP.add(e); | |
} | |
} | |
List<TextRulerExample> baseFN = new ArrayList<TextRulerExample>(); | |
for (TextRulerExample e : goldPositives) { | |
TextRulerExample example = TextRulerToolkit.exampleListContainsAnnotation( | |
additionalPositives, e.getAnnotation()); | |
if (example == null) { | |
baseFN.add(e); | |
} | |
} | |
List<TextRulerExample> testFP = new ArrayList<TextRulerExample>(); | |
for (TextRulerExample e : testPositives) { | |
TextRulerExample example = TextRulerToolkit.exampleListContainsAnnotation(goldPositives, | |
e.getAnnotation()); | |
if (example == null) { | |
testFP.add(e); | |
} | |
} | |
List<TextRulerExample> testFN = new ArrayList<TextRulerExample>(); | |
for (TextRulerExample e : goldPositives) { | |
TextRulerExample example = TextRulerToolkit.exampleListContainsAnnotation(testPositives, | |
e.getAnnotation()); | |
if (example == null) { | |
testFN.add(e); | |
} | |
} | |
for (TextRulerExample e : baseFP) { | |
TextRulerExample example = TextRulerToolkit.exampleListContainsAnnotation(testFP, | |
e.getAnnotation()); | |
if (example == null) { | |
c.addCoveredPositive(e); | |
} | |
} | |
for (TextRulerExample e : baseFN) { | |
TextRulerExample example = TextRulerToolkit.exampleListContainsAnnotation(testFN, | |
e.getAnnotation()); | |
TextRulerExample coveredExample = TextRulerToolkit.exampleListContainsAnnotation( | |
goldPositives, e.getAnnotation()); | |
if (example == null) { | |
c.addCoveredPositive(coveredExample); | |
} | |
} | |
for (TextRulerExample e : testFN) { | |
TextRulerExample example = TextRulerToolkit.exampleListContainsAnnotation(baseFN, | |
e.getAnnotation()); | |
if (example == null) { | |
if (collectNegativeExamples) { | |
e.setPositive(false); | |
c.addCoveredNegative(e); | |
} else { | |
c.incCoveredNegatives(1); | |
} | |
} | |
} | |
for (TextRulerExample e : testFP) { | |
TextRulerExample example = TextRulerToolkit.exampleListContainsAnnotation(baseFP, | |
e.getAnnotation()); | |
if (example == null) { | |
c.incCoveredNegatives(1); | |
} | |
} | |
} | |
/** | |
* Removes duplicate rules by their hash value. | |
* | |
* @param rules | |
* @return updated list | |
*/ | |
public static List<TrabalRule> removeDuplicateRules(List<TrabalRule> rules) { | |
return new ArrayList<TrabalRule>(new HashSet<TrabalRule>(rules)); | |
} | |
/** | |
* Returns a string, containing the rules of the given list. | |
* | |
* @param rules | |
* @return string representation of given rules | |
*/ | |
public static String getRuleStrings(List<TrabalRule> rules) { | |
String result = ""; | |
for (TrabalRule r : rules) { | |
result += r.toString(); | |
} | |
return result; | |
} | |
/** | |
* Tests, if the given type is a slot type. | |
* | |
* @param type | |
* @return true, if given type is a slot type | |
*/ | |
public boolean isSlotType(Type type) { | |
for (String slot : slotNames) { | |
if (slot.equals(type.getName())) { | |
return true; | |
} | |
} | |
return false; | |
} | |
public boolean collectNegativeCoveredInstancesWhenTesting() { | |
return false; | |
} | |
public String getResultString() { | |
return getFileHeaderString(true) + result; | |
} | |
// // @Override | |
// public AnalysisEngine getAnalysisEngine() { | |
// if (ae == null) { | |
// String descriptorFile = TextRulerToolkit.getEngineDescriptorFromTMSourceFile(new Path( | |
// preprocessorTMFile)); | |
// sendStatusUpdateToDelegate("loading AE...", TextRulerLearnerState.ML_INITIALIZING, false); | |
// ae = TextRulerToolkit.loadAnalysisEngine(descriptorFile); | |
// | |
// // set filters to NO filtering so that we can add it manually with | |
// // the FILTERTYPE expression! | |
// String tempRulesFileName = getTempRulesFileName(); | |
// IPath path = new Path(tempRulesFileName); | |
// ae.setConfigParameterValue(RutaEngine.MAIN_SCRIPT, path.removeFileExtension() | |
// .lastSegment()); | |
// String portableString = path.removeLastSegments(1).toPortableString(); | |
// ae.setConfigParameterValue(RutaEngine.SCRIPT_PATHS, new String[] { portableString }); | |
// ae.setConfigParameterValue(RutaEngine.ADDITIONAL_SCRIPTS, new String[0]); | |
// ae.setConfigParameterValue(RutaEngine.RELOAD_SCRIPT, true); | |
// ae.setConfigParameterValue(RutaEngine.REMOVE_BASICS, true); | |
// | |
// try { | |
// ae.reconfigure(); | |
// } catch (ResourceConfigurationException e) { | |
// TextRulerPlugin.error(e); | |
// return null; | |
// } | |
// } | |
// return ae; | |
// } | |
/** | |
* Return the set of additional documents and initializes it, if needed. | |
*/ | |
public TextRulerExampleDocumentSet getAdditionalDocuments() { | |
if (additionalDocuments == null) { | |
if (!StringUtils.isBlank(additionalFolderPath)) { | |
additionalDocuments = new TextRulerExampleDocumentSet(additionalFolderPath, casCache); | |
} | |
} | |
return additionalDocuments; | |
} | |
public boolean getEnableFeatures() { | |
return enableFeatures; | |
} | |
// @Override | |
public void setParameters(Map<String, Object> params) { | |
if (TextRulerToolkit.DEBUG) | |
saveParametersToTempFolder(params); | |
if (params.containsKey(MAX_NUMBER_OF_BASIC_RULES_KEY)) | |
maxNumberOfBasicRules = (Integer) params.get(MAX_NUMBER_OF_BASIC_RULES_KEY); | |
if (params.containsKey(MAX_NUMBER_OF_RULES_KEY)) | |
maxNumberOfRules = (Integer) params.get(MAX_NUMBER_OF_RULES_KEY); | |
if (params.containsKey(MAX_NUMBER_OF_ITERATIONS_KEY)) | |
maxNumberOfIterations = (Integer) params.get(MAX_NUMBER_OF_ITERATIONS_KEY); | |
if (params.containsKey(ALGORITHM_ITERATIONS_KEY)) | |
algorithmIterations = (Integer) params.get(ALGORITHM_ITERATIONS_KEY); | |
if (params.containsKey(MAX_ERROR_RATE_KEY)) | |
maxErrorRate = (Double) params.get(MAX_ERROR_RATE_KEY); | |
if (params.containsKey(ENABLE_FEATURES_KEY)) | |
enableFeatures = (Boolean) params.get(ENABLE_FEATURES_KEY); | |
} | |
/** | |
* Creates a file in the *.csv format, containing TraBaL errors (error type, file name, false | |
* annotation and true annotation). | |
* | |
* @throws Exception | |
*/ | |
public void getErrorsAsCSV(String filePath) throws Exception { | |
String result = ""; | |
if (exampleDocuments == null) { | |
exampleDocuments = new TextRulerExampleDocumentSet(inputDirectory, casCache); | |
} | |
getAdditionalDocuments(); | |
errors = createErrorList(); | |
result += "'ErrorType';'FileName';'AnnotationType';'Annotation';'TargetAnnotationType';'TargetAnnotation';\n"; | |
for (AnnotationError each : errors) { | |
if (each.getAnnotation() != null && each.getTargetAnnotation() != null) { | |
result += "'" + each.getType() + "';'" | |
+ new File(each.getAnnotation().getDocument().getCasFileName()).getName() + "';'" | |
+ each.getAnnotation().getType().getShortName() + "';'" | |
+ each.getAnnotation().getCoveredText() + "';'" | |
+ each.getTargetAnnotation().getType().getShortName() + "';'" | |
+ each.getTargetAnnotation().getCoveredText() + "';\n"; | |
} else if (each.getAnnotation() != null) { | |
result += "'" + each.getType() + "';'" | |
+ new File(each.getAnnotation().getDocument().getCasFileName()).getName() + "';'" | |
+ each.getAnnotation().getType().getShortName() + "';'" | |
+ each.getAnnotation().getCoveredText() + "';'';'';\n"; | |
} else { | |
result += "'" + each.getType() + "';'" | |
+ new File(each.getTargetAnnotation().getDocument().getCasFileName()).getName() | |
+ "';'';'';'" + each.getTargetAnnotation().getType().getShortName() + "';'" | |
+ each.getTargetAnnotation().getCoveredText() + "';\n"; | |
} | |
} | |
if (errors.size() > 0) { | |
File csv = new File(filePath); | |
FileWriter fstream = new FileWriter(csv); | |
BufferedWriter out = new BufferedWriter(fstream); | |
out.write(result); | |
out.close(); | |
} | |
} | |
/** | |
* Compares trabal rules, disregarding the error rate, so basic rules with maximum number of | |
* covered positives are chosen. | |
*/ | |
protected Comparator<TrabalRule> basicComparator = new Comparator<TrabalRule>() { | |
public int compare(TrabalRule o1, TrabalRule o2) { | |
// coveredPositives | |
if (o1.getCoveringStatistics().getCoveredPositivesCount() > o2.getCoveringStatistics() | |
.getCoveredPositivesCount()) | |
return -1; | |
if (o1.getCoveringStatistics().getCoveredPositivesCount() < o2.getCoveringStatistics() | |
.getCoveredPositivesCount()) | |
return 1; | |
// coveredNegatives | |
if (o1.getCoveringStatistics().getCoveredNegativesCount() < o2.getCoveringStatistics() | |
.getCoveredNegativesCount()) | |
return -1; | |
if (o1.getCoveringStatistics().getCoveredNegativesCount() > o2.getCoveringStatistics() | |
.getCoveredNegativesCount()) | |
return 1; | |
// complexity of rules | |
if (o1.getConditions().size() < o2.getConditions().size()) | |
return -1; | |
if (o1.getConditions().size() > o2.getConditions().size()) | |
return 1; | |
// ruleString | |
return o1.getRuleString().compareTo(o2.getRuleString()); | |
} | |
}; | |
/** | |
* Compares trabal rules, starting with the error rate. | |
*/ | |
protected Comparator<TrabalRule> enhancedComparator = new Comparator<TrabalRule>() { | |
public int compare(TrabalRule o1, TrabalRule o2) { | |
// positives - negatives | |
if (o1.getCoveringStatistics().getCoveredPositivesCount() | |
- o1.getCoveringStatistics().getCoveredNegativesCount() > o2.getCoveringStatistics() | |
.getCoveredPositivesCount() - o2.getCoveringStatistics().getCoveredNegativesCount()) | |
return -1; | |
if (o1.getCoveringStatistics().getCoveredPositivesCount() | |
- o1.getCoveringStatistics().getCoveredNegativesCount() < o2.getCoveringStatistics() | |
.getCoveredPositivesCount() - o2.getCoveringStatistics().getCoveredNegativesCount()) | |
return 1; | |
// coveredPositives | |
if (o1.getCoveringStatistics().getCoveredPositivesCount() > o2.getCoveringStatistics() | |
.getCoveredPositivesCount()) | |
return -1; | |
if (o1.getCoveringStatistics().getCoveredPositivesCount() < o2.getCoveringStatistics() | |
.getCoveredPositivesCount()) | |
return 1; | |
// coveredNegatives | |
if (o1.getCoveringStatistics().getCoveredNegativesCount() < o2.getCoveringStatistics() | |
.getCoveredNegativesCount()) | |
return -1; | |
if (o1.getCoveringStatistics().getCoveredNegativesCount() > o2.getCoveringStatistics() | |
.getCoveredNegativesCount()) | |
return 1; | |
// complexity of rules | |
if (o1.getConditions().size() < o2.getConditions().size()) | |
return -1; | |
if (o1.getConditions().size() > o2.getConditions().size()) | |
return 1; | |
// ruleString | |
return o1.getRuleString().compareTo(o2.getRuleString()); | |
} | |
}; | |
} |