/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package opennlp.tools.parse_thicket.kernel_interface;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

import org.apache.commons.io.FileUtils;

import opennlp.tools.jsmlearning.ProfileReaderWriter;
import opennlp.tools.parse_thicket.ParseThicket;
import opennlp.tools.parse_thicket.VerbNetProcessor;
import opennlp.tools.parse_thicket.matching.Matcher;
public class TreeKernelBasedClassifier {

    protected static final Logger LOG = Logger
            .getLogger(TreeKernelBasedClassifier.class.getName());

    protected ArrayList<File> queuePos = new ArrayList<File>(), queueNeg = new ArrayList<File>();
    protected Matcher matcher = new Matcher();
    protected TreeKernelRunner tkRunner = new TreeKernelRunner();
    protected TreeExtenderByAnotherLinkedTree treeExtender = new TreeExtenderByAnotherLinkedTree();
    protected String path;

    public void setKernelPath(String path) {
        this.path = path;
    }

    protected static final String modelFileName = "model.txt";
    protected static final String trainingFileName = "training.txt";
    protected static final String unknownToBeClassified = "unknown.txt";
    protected static final String classifierOutput = "classifier_output.txt";
    protected static final float MIN_SVM_SCORE_TOBE_IN = 0.2f;
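
    /*
     * The training and classification files follow the SVM-light-TK
     * convention: one example per line, a numeric label followed by one or
     * more parse trees wrapped in |BT| ... |ET| markers. A hypothetical
     * training line, for illustration only:
     *
     *   1 |BT| (S (NP (DT the)(NN system))(VP (VBZ works))) |ET|
     */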
    /**
     * Main entry point to the SVM tree-kernel classifier.
     * Takes a file, extracts its longer paragraphs, and builds parse thickets
     * for them; the resulting parse thicket dump is then processed by
     * svm_classify. The decision is the average of the per-paragraph SVM
     * scores compared against MIN_SVM_SCORE_TOBE_IN.
     */
    public Boolean classifyText(File f) {
        FileUtils.deleteQuietly(new File(path + unknownToBeClassified));
        if (!(new File(path + modelFileName).exists())) {
            LOG.severe("Model file '" + modelFileName + "' is absent: skipping SVM classification");
            return null;
        }
        List<String> texts = DescriptiveParagraphFromDocExtractor.getLongParagraphsFromFile(f);
        List<String> treeBankBuffer = formTreeKernelStructuresMultiplePara(texts, "0");
        // write the list of samples to a file
        try {
            FileUtils.writeLines(new File(path + unknownToBeClassified), null, treeBankBuffer);
        } catch (IOException e) {
            LOG.severe("Problem creating parse thicket file '" + path + unknownToBeClassified
                    + "' to be classified\n" + e.getMessage());
        }
        tkRunner.runClassifier(path, unknownToBeClassified, modelFileName, classifierOutput);
        // read classification results
        List<String[]> classifResults = ProfileReaderWriter.readProfiles(path + classifierOutput, ' ');
        if (classifResults == null || classifResults.isEmpty()) {
            LOG.severe("No classifier output found in '" + path + classifierOutput + "'");
            return null;
        }
        // average the per-paragraph SVM scores
        float accum = 0;
        StringBuilder scores = new StringBuilder("svm scores per paragraph: ");
        for (String[] line : classifResults) {
            float val = Float.parseFloat(line[0]);
            scores.append(val).append(' ');
            accum += val;
        }
        float averaged = accum / (float) classifResults.size();
        LOG.info(scores.toString());
        LOG.info("average = " + averaged);
        return averaged > MIN_SVM_SCORE_TOBE_IN;
    }
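
    /*
     * A minimal usage sketch, assuming a local tree-kernel install and an
     * already trained model.txt in that directory (paths are hypothetical):
     *
     *   TreeKernelBasedClassifier clf = new TreeKernelBasedClassifier();
     *   clf.setKernelPath("/opt/tree_kernel/");
     *   Boolean accepted = clf.classifyText(new File("doc_to_check.pdf"));
     */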
    protected void addFilesPos(File file) {
        if (!file.exists()) {
            LOG.warning(file + " does not exist.");
            return;
        }
        if (file.isDirectory()) {
            File[] children = file.listFiles();
            if (children == null)
                return;
            for (File f : children) {
                // optionally filter by extension here (.txt / .pdf)
                addFilesPos(f);
                LOG.fine(f.getName());
            }
        } else {
            queuePos.add(file);
        }
    }

    protected void addFilesNeg(File file) {
        if (!file.exists()) {
            LOG.warning(file + " does not exist.");
            return;
        }
        if (file.isDirectory()) {
            File[] children = file.listFiles();
            if (children == null)
                return;
            for (File f : children) {
                // optionally filter by extension here (.txt / .pdf)
                addFilesNeg(f);
                LOG.fine(f.getName());
            }
        } else {
            queueNeg.add(file);
        }
    }
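
    /**
     * Builds a tree-kernel model from two directories of documents: positive
     * examples are labeled "1" and negative examples "-1". The labeled
     * |BT|...|ET| lines are written to training.txt, and svm_learn (via
     * TreeKernelRunner) produces model.txt in the kernel directory.
     */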
    protected void trainClassifier(String posDirectory, String negDirectory) {
        queuePos.clear();
        queueNeg.clear();
        addFilesPos(new File(posDirectory));
        addFilesNeg(new File(negDirectory));
        List<File> filesPos = new ArrayList<File>(queuePos), filesNeg = new ArrayList<File>(queueNeg);
        List<String[]> treeBankBuffer = new ArrayList<String[]>();
        for (File f : filesPos) {
            // get the first paragraph of text
            String text = DescriptiveParagraphFromDocExtractor.getFirstParagraphFromFile(f);
            treeBankBuffer.add(new String[] { formTreeKernelStructure(text, "1") });
        }
        for (File f : filesNeg) {
            // get the first paragraph of text
            String text = DescriptiveParagraphFromDocExtractor.getFirstParagraphFromFile(f);
            treeBankBuffer.add(new String[] { formTreeKernelStructure(text, "-1") });
        }
        // write the list of training samples to a file
        ProfileReaderWriter.writeReport(treeBankBuffer, path + trainingFileName, ' ');
        // build the model
        tkRunner.runLearner(path, trainingFileName, modelFileName);
    }
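
    /**
     * Classifies every file under the given directory against the current
     * model. Each result row is
     *   { file name, in/out decision, SVM score, absolute path },
     * e.g. a hypothetical row: { "spec.pdf", "true", "0.31", "/docs/spec.pdf" }.
     */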
    public List<String[]> classifyFilesInDirectory(String dirFilesToBeClassified) {
        List<String[]> treeBankBuffer = new ArrayList<String[]>();
        queuePos.clear();
        addFilesPos(new File(dirFilesToBeClassified));
        List<File> filesUnkn = new ArrayList<File>(queuePos);
        for (File f : filesUnkn) {
            String text = DescriptiveParagraphFromDocExtractor.getFirstParagraphFromFile(f);
            String line = formTreeKernelStructure(text, "0");
            treeBankBuffer.add(new String[] { line });
        }
        // form a file from the texts to be classified
        ProfileReaderWriter.writeReport(treeBankBuffer, path + unknownToBeClassified, ' ');
        tkRunner.runClassifier(path, unknownToBeClassified, modelFileName, classifierOutput);
        // read classification results
        List<String[]> classifResults = ProfileReaderWriter.readProfiles(path + classifierOutput, ' ');
        // iterate through classification results and set them as scores for hits
        List<String[]> results = new ArrayList<String[]>();
        int count = 0;
        for (String[] line : classifResults) {
            if (count >= filesUnkn.size())
                break;
            float val = Float.parseFloat(line[0]);
            boolean in = val > MIN_SVM_SCORE_TOBE_IN;
            String[] rline = new String[] { filesUnkn.get(count).getName(), Boolean.toString(in),
                    line[0], filesUnkn.get(count).getAbsolutePath() };
            results.add(rline);
            count++;
        }
        return results;
    }
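
    /**
     * Turns each paragraph into one or more labeled tree-kernel lines: a parse
     * thicket is built for the paragraph (with RST arcs), extended with
     * coreference arcs into a forest, and each tree dump is wrapped as
     * "<flag> |BT| <tree> |ET| ".
     */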
    protected List<String> formTreeKernelStructuresMultiplePara(List<String> texts, String flag) {
        List<String> extendedTreesDumpTotal = new ArrayList<String>();
        try {
            for (String text : texts) {
                // get the parses from the original documents and form the dataset
                LOG.info("About to build parse thicket from " + text);
                ParseThicket pt = matcher.buildParseThicketFromTextWithRST(text);
                LOG.info("About to build extended forest");
                List<String> extendedTreesDump = treeExtender.buildForestForCorefArcs(pt);
                for (String line : extendedTreesDump) {
                    extendedTreesDumpTotal.add(flag + " |BT| " + line + " |ET| ");
                }
                LOG.info("DONE");
            }
        } catch (Exception e) {
            LOG.severe("Problem forming the parse thicket flat file to be classified\n" + e.getMessage());
        }
        return extendedTreesDumpTotal;
    }
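
    /**
     * Same as above for a single text, producing one training/classification
     * line: the label, then every balanced tree dump wrapped in |BT| ... |ET|.
     * Unbalanced trees are reported and skipped; if no trees were produced,
     * an empty |BT| slot keeps the line well-formed for the kernel tool.
     */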
    protected String formTreeKernelStructure(String text, String flag) {
        StringBuilder treeBankBuffer = new StringBuilder();
        try {
            // get the parses from the original document and form the training sample
            LOG.info("About to build parse thicket from " + text);
            ParseThicket pt = matcher.buildParseThicketFromTextWithRST(text);
            LOG.info("About to build extended forest");
            List<String> extendedTreesDump = treeExtender.buildForestForCorefArcs(pt);
            LOG.info("DONE");
            treeBankBuffer.append(flag);
            // form the list of training samples
            for (String t : extendedTreesDump) {
                if (BracesProcessor.isBalanced(t)) {
                    treeBankBuffer.append(" |BT| ").append(t);
                } else {
                    LOG.severe("Wrong tree: " + t);
                }
            }
            if (extendedTreesDump.size() < 1) {
                treeBankBuffer.append(" |BT| ");
            }
        } catch (Exception e) {
            LOG.severe("Problem forming tree kernel structure\n" + e.getMessage());
        }
        return treeBankBuffer.append(" |ET|").toString();
    }
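
    /**
     * Command-line entry point. Expects three directories:
     *   args[0] - positive training documents
     *   args[1] - negative training documents
     *   args[2] - documents to classify
     * Note: the VerbNet resource path and the tree-kernel binary path below
     * are hard-coded to the original author's machine and will need to be
     * adjusted for other environments.
     */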
    public static void main(String[] args) {
        if (args.length < 3) {
            System.err.println("Usage: TreeKernelBasedClassifier <posDir> <negDir> <dirToClassify>");
            return;
        }
        VerbNetProcessor p = VerbNetProcessor
                .getInstance("/Users/borisgalitsky/Documents/workspace/deepContentInspection/src/test/resources");
        TreeKernelBasedClassifier proc = new TreeKernelBasedClassifier();
        proc.setKernelPath("/Users/borisgalitsky/Documents/tree_kernel/");
        proc.trainClassifier(args[0], args[1]);
        List<String[]> res = proc.classifyFilesInDirectory(args[2]);
        ProfileReaderWriter.writeReport(res, "svmDesignDocReport03minus.csv");
    }
}