blob: 06250239d416a9d180ef467fa066455fe1c920e5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification;
import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite;
import java.io.Closeable;
import java.io.IOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
import org.apache.lucene.classification.utils.DatasetSplitter;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.search.similarities.AfterEffectB;
import org.apache.lucene.search.similarities.AxiomaticF1EXP;
import org.apache.lucene.search.similarities.AxiomaticF1LOG;
import org.apache.lucene.search.similarities.BasicModelG;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.DFRSimilarity;
import org.apache.lucene.search.similarities.DistributionLL;
import org.apache.lucene.search.similarities.DistributionSPL;
import org.apache.lucene.search.similarities.IBSimilarity;
import org.apache.lucene.search.similarities.LMDirichletSimilarity;
import org.apache.lucene.search.similarities.LMJelinekMercerSimilarity;
import org.apache.lucene.search.similarities.LambdaDF;
import org.apache.lucene.search.similarities.LambdaTTF;
import org.apache.lucene.search.similarities.Normalization;
import org.apache.lucene.search.similarities.NormalizationH1;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.TestUtil;
import org.junit.Test;
@LuceneTestCase.SuppressSysoutChecks(bugUrl = "none")
@TimeoutSuite(millis = Integer.MAX_VALUE) // hopefully ~24 days is long enough ;)
@LuceneTestCase.Monster("takes a lot!")
public final class Test20NewsgroupsClassification extends LuceneTestCase {
// Names of the index fields written by buildIndex and read back by the classifiers.
private static final String CATEGORY_FIELD = "category";
private static final String BODY_FIELD = "body";
private static final String SUBJECT_FIELD = "subject";
// Placeholder path: buildIndex lists this directory for per-group sub-directories of post files.
// It has to be edited to point at a local copy of the corpus before the test can index anything.
private static final String INDEX_DIR = "/path/to/lucene-solr/lucene/classification/20n";
// Whether to (re)build the index; overridden by the "index" system property when it is set.
private static boolean index = true;
// Whether to split the index into train/test/cv parts; overridden by the "split" system property.
private static boolean split = true;
/**
 * Indexes the corpus found under {@link #INDEX_DIR}, optionally splits it into train / test /
 * cross-validation indexes, then evaluates a fixed list of classifiers and prints one report per
 * classifier to standard output.
 *
 * <p>The {@code index} and {@code split} system properties, when set, override the corresponding
 * static flags; any value other than {@code "true"} (case-insensitive) disables the step.
 *
 * @throws Exception if indexing, splitting or any evaluation task fails
 */
@Test
public void test20Newsgroups() throws Exception {
  // Boolean.parseBoolean never throws, so no exception handling is needed around it.
  String indexProperty = System.getProperty("index");
  if (indexProperty != null) {
    index = Boolean.parseBoolean(indexProperty);
  }
  String splitProperty = System.getProperty("split");
  if (splitProperty != null) {
    split = Boolean.parseBoolean(splitProperty);
  }

  Directory directory = newDirectory();
  Directory cv = null;
  Directory test = null;
  Directory train = null;
  if (split) {
    cv = newDirectory();
    test = newDirectory();
    train = newDirectory();
  }
  IndexReader reader = null;
  IndexReader testReader = null;
  Analyzer analyzer = null;
  List<Classifier<BytesRef>> classifiers = new LinkedList<>();
  try {
    analyzer = new StandardAnalyzer();
    if (index) {
      System.out.println("Indexing 20 Newsgroups...");
      long startIndex = System.currentTimeMillis();
      // try-with-resources: the writer is released even when buildIndex fails half-way.
      try (IndexWriter indexWriter =
          new IndexWriter(directory, new IndexWriterConfig(analyzer))) {
        int docsIndexed = buildIndex(Paths.get(INDEX_DIR), indexWriter);
        long endIndex = System.currentTimeMillis();
        System.out.println(
            "Indexed " + docsIndexed + " docs in " + (endIndex - startIndex) / 1000 + "s");
      }
    }

    // NOTE(review): when index == false both 'directory' and 'train' are directories created a
    // few lines above, so these opens presumably only succeed if the newDirectory() calls are
    // swapped for pre-populated directories - confirm how the no-index mode is meant to be used.
    if (split && !index) {
      reader = DirectoryReader.open(train);
    } else {
      reader = DirectoryReader.open(directory);
    }

    if (index && split) {
      System.out.println("Splitting the index...");
      long startSplit = System.currentTimeMillis();
      DatasetSplitter datasetSplitter = new DatasetSplitter(0.2, 0);
      datasetSplitter.split(
          reader,
          train,
          test,
          cv,
          analyzer,
          false,
          CATEGORY_FIELD,
          BODY_FIELD,
          SUBJECT_FIELD,
          CATEGORY_FIELD);
      reader.close();
      reader = DirectoryReader.open(train); // using the train index from now on
      long endSplit = System.currentTimeMillis();
      System.out.println("Splitting done in " + (endSplit - startSplit) / 1000 + "s");
    }

    addClassifiers(classifiers, reader, analyzer);

    int maxdoc;
    if (split) {
      testReader = DirectoryReader.open(test);
      maxdoc = testReader.maxDoc();
    } else {
      maxdoc = reader.maxDoc();
    }
    System.out.println("Starting evaluation on " + maxdoc + " docs...");

    // NOTE(review): with an unbounded LinkedBlockingQueue a ThreadPoolExecutor never grows past
    // its core size, so the random maximum below has no effect and the tasks run on one worker
    // thread - confirm whether parallel evaluation was intended before changing the core size.
    ExecutorService service =
        new ThreadPoolExecutor(
            1,
            TestUtil.nextInt(random(), 2, 6),
            Long.MAX_VALUE,
            TimeUnit.MILLISECONDS,
            new LinkedBlockingQueue<>(),
            new NamedThreadFactory(getClass().getName()));
    try {
      List<Future<String>> futures = new LinkedList<>();
      for (Classifier<BytesRef> classifier : classifiers) {
        testClassifier(reader, testReader, service, futures, classifier);
      }
      for (Future<String> f : futures) {
        System.out.println(f.get());
      }
    } finally {
      // On success every task has already completed, so this only stops the idle worker. On
      // failure it cancels the remaining tasks so they do not keep using readers that are about
      // to be closed. Waiting lets the worker thread exit before the test returns.
      service.shutdownNow();
      service.awaitTermination(10, TimeUnit.SECONDS);
    }
  } finally {
    // Same closing order as before, but one failing close() no longer skips the others.
    List<Closeable> resources = new LinkedList<>();
    resources.add(reader);
    resources.add(directory);
    resources.add(testReader);
    resources.add(test);
    resources.add(train);
    resources.add(cv);
    for (Classifier<BytesRef> c : classifiers) {
      if (c instanceof Closeable) {
        resources.add((Closeable) c);
      }
    }
    resources.add(analyzer);
    closeAllResources(resources);
  }
}

/**
 * Appends every classifier configuration under evaluation to {@code classifiers}. Adding to the
 * caller's list (instead of returning a new one) keeps already-built classifiers reachable for
 * cleanup if a later constructor throws.
 *
 * @param classifiers list receiving the classifiers, in evaluation order
 * @param reader reader over the training index
 * @param analyzer analyzer handed to every classifier
 * @throws Exception mirrors the calling test's throws clause so the constructor calls compile
 *     exactly as they did inline
 */
private static void addClassifiers(
    List<Classifier<BytesRef>> classifiers, IndexReader reader, Analyzer analyzer)
    throws Exception {
  // k-nearest-neighbour variants: (similarity, k, minDocsFreq, minTermFreq)
  classifiers.add(
      new KNearestNeighborClassifier(
          reader, new ClassicSimilarity(), analyzer, null, 1, 0, 0, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestNeighborClassifier(
          reader, null, analyzer, null, 1, 0, 0, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestNeighborClassifier(
          reader, new ClassicSimilarity(), analyzer, null, 3, 0, 0, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestNeighborClassifier(
          reader, new AxiomaticF1EXP(), analyzer, null, 3, 0, 0, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestNeighborClassifier(
          reader, new AxiomaticF1LOG(), analyzer, null, 3, 0, 0, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestNeighborClassifier(
          reader,
          new LMDirichletSimilarity(),
          analyzer,
          null,
          3,
          1,
          1,
          CATEGORY_FIELD,
          BODY_FIELD));
  classifiers.add(
      new KNearestNeighborClassifier(
          reader,
          new LMJelinekMercerSimilarity(0.3f),
          analyzer,
          null,
          3,
          1,
          1,
          CATEGORY_FIELD,
          BODY_FIELD));
  classifiers.add(
      new KNearestNeighborClassifier(
          reader, null, analyzer, null, 3, 1, 1, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestNeighborClassifier(
          reader,
          new DFRSimilarity(new BasicModelG(), new AfterEffectB(), new NormalizationH1()),
          analyzer,
          null,
          3,
          1,
          1,
          CATEGORY_FIELD,
          BODY_FIELD));
  classifiers.add(
      new KNearestNeighborClassifier(
          reader,
          new IBSimilarity(
              new DistributionSPL(), new LambdaDF(), new Normalization.NoNormalization()),
          analyzer,
          null,
          3,
          1,
          1,
          CATEGORY_FIELD,
          BODY_FIELD));
  classifiers.add(
      new KNearestNeighborClassifier(
          reader,
          new IBSimilarity(new DistributionLL(), new LambdaTTF(), new NormalizationH1()),
          analyzer,
          null,
          3,
          1,
          1,
          CATEGORY_FIELD,
          BODY_FIELD));

  // fuzzy k-nearest-neighbour variants: (similarity, k)
  classifiers.add(
      new KNearestFuzzyClassifier(
          reader,
          new LMJelinekMercerSimilarity(0.3f),
          analyzer,
          null,
          1,
          CATEGORY_FIELD,
          BODY_FIELD));
  classifiers.add(
      new KNearestFuzzyClassifier(
          reader,
          new IBSimilarity(new DistributionLL(), new LambdaTTF(), new NormalizationH1()),
          analyzer,
          null,
          1,
          CATEGORY_FIELD,
          BODY_FIELD));
  classifiers.add(
      new KNearestFuzzyClassifier(
          reader, new ClassicSimilarity(), analyzer, null, 1, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestFuzzyClassifier(
          reader, new ClassicSimilarity(), analyzer, null, 3, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestFuzzyClassifier(reader, null, analyzer, null, 1, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestFuzzyClassifier(reader, null, analyzer, null, 3, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestFuzzyClassifier(
          reader, new AxiomaticF1EXP(), analyzer, null, 3, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new KNearestFuzzyClassifier(
          reader, new AxiomaticF1LOG(), analyzer, null, 3, CATEGORY_FIELD, BODY_FIELD));

  // naive Bayes variants
  classifiers.add(new BM25NBClassifier(reader, analyzer, null, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new CachingNaiveBayesClassifier(reader, analyzer, null, CATEGORY_FIELD, BODY_FIELD));
  classifiers.add(
      new SimpleNaiveBayesClassifier(reader, analyzer, null, CATEGORY_FIELD, BODY_FIELD));
}

/**
 * Closes every non-null entry of {@code resources} in list order. A failure while closing one
 * entry does not prevent the remaining entries from being closed; the first failure is rethrown
 * at the end with any later ones attached as suppressed exceptions.
 *
 * @param resources resources to close; {@code null} entries are skipped
 * @throws Exception the first exception raised by a {@code close()} call
 */
private static void closeAllResources(List<Closeable> resources) throws Exception {
  Exception failure = null;
  for (Closeable resource : resources) {
    if (resource == null) {
      continue;
    }
    try {
      resource.close();
    } catch (IOException | RuntimeException e) {
      if (failure == null) {
        failure = e;
      } else {
        failure.addSuppressed(e);
      }
    }
  }
  if (failure != null) {
    throw failure;
  }
}
/**
 * Submits one evaluation task for {@code classifier} to {@code service} and records its future.
 *
 * <p>The task builds a confusion matrix over {@code testReader} when {@code split} is set and over
 * {@code ar} otherwise, and returns a multi-line report with accuracy, precision, recall,
 * F1-measure, average classification time and elapsed wall-clock seconds.
 *
 * @param ar reader used for evaluation when the index was not split
 * @param testReader reader over the test index, used when the index was split
 * @param service executor that runs the evaluation
 * @param futures list the submitted task's future is appended to
 * @param classifier classifier under evaluation
 */
private void testClassifier(
    final IndexReader ar,
    IndexReader testReader,
    ExecutorService service,
    List<Future<String>> futures,
    Classifier<BytesRef> classifier) {
  // Value passed as the last argument of getConfusionMatrix: 30 * 60000.
  final int timeoutMillis = 60000 * 30;
  futures.add(
      service.submit(
          () -> {
            final long startTime = System.currentTimeMillis();
            // 'split' is read when the task runs, exactly as before.
            final IndexReader evaluationReader = split ? testReader : ar;
            final ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix =
                ConfusionMatrixGenerator.getConfusionMatrix(
                    evaluationReader, classifier, CATEGORY_FIELD, BODY_FIELD, timeoutMillis);
            final long endTime = System.currentTimeMillis();
            // Divide in long arithmetic; the previous "(int) (end - start) / 1000" narrowed the
            // millisecond delta to int before dividing.
            final long elapse = (endTime - startTime) / 1000;
            return " * "
                + classifier
                + " \n * accuracy = "
                + confusionMatrix.getAccuracy()
                + "\n * precision = "
                + confusionMatrix.getPrecision()
                + "\n * recall = "
                + confusionMatrix.getRecall()
                + "\n * f1-measure = "
                + confusionMatrix.getF1Measure()
                + "\n * avgClassificationTime = "
                + confusionMatrix.getAvgClassificationTime()
                + "\n * time = "
                + elapse
                + " (sec)\n ";
          }));
}
/**
 * Walks a two-level corpus layout ({@code indexDir/<group>/<post file>}) and adds one document
 * per successfully parsed post, then commits the writer.
 *
 * <p>Each document carries the group name as both a stored string field and a sorted doc-values
 * field under {@link #CATEGORY_FIELD}, plus stored text fields for subject and body. Hidden
 * entries are skipped at both levels, as are top-level entries that are not directories and
 * files for which {@code parse} returns {@code null}.
 *
 * @param indexDir root directory containing one sub-directory per group
 * @param indexWriter writer receiving the documents; committed but not closed by this method
 * @return number of documents added
 * @throws IOException if listing a directory, adding a document or committing fails
 */
private int buildIndex(Path indexDir, IndexWriter indexWriter) throws IOException {
  int docsAdded = 0;
  try (DirectoryStream<Path> groupsStream = Files.newDirectoryStream(indexDir)) {
    for (Path groupsDir : groupsStream) {
      // A regular file directly under the root would make newDirectoryStream throw
      // NotDirectoryException and abort the whole run, so only visible directories are entered.
      if (Files.isHidden(groupsDir) || !Files.isDirectory(groupsDir)) {
        continue;
      }
      String groupName = groupsDir.getFileName().toString();
      try (DirectoryStream<Path> stream = Files.newDirectoryStream(groupsDir)) {
        for (Path p : stream) {
          if (Files.isHidden(p)) {
            continue;
          }
          NewsPost post = parse(p, groupName, p.getFileName().toString());
          if (post == null) {
            continue; // could not be read or parsed
          }
          Document d = new Document();
          d.add(new StringField(CATEGORY_FIELD, post.getGroup(), Field.Store.YES));
          d.add(new SortedDocValuesField(CATEGORY_FIELD, new BytesRef(post.getGroup())));
          d.add(new TextField(SUBJECT_FIELD, post.getSubject(), Field.Store.YES));
          d.add(new TextField(BODY_FIELD, post.getBody(), Field.Store.YES));
          indexWriter.addDocument(d);
          docsAdded++;
        }
      }
    }
  }
  indexWriter.commit();
  return docsAdded;
}
/**
 * Reads one post file and splits it into subject and body.
 *
 * <p>Lines before the first blank line are treated as headers: a line starting with {@code
 * "Subject:"} supplies the subject (everything after the 8-character prefix) and all other header
 * lines are dropped. Every line after that first blank line is appended to the body, joined with
 * {@code '\n'}; blank lines at the very start of the body are not kept.
 *
 * @param path file to read
 * @param groupName group the post belongs to
 * @param number identifier of the post (the caller passes the file name)
 * @return the parsed post, or {@code null} if the file is not readable or reading it fails
 */
private NewsPost parse(Path path, String groupName, String number) {
  // An unreadable file used to produce an empty post that was still indexed; skip it instead.
  if (!Files.isReadable(path)) {
    return null;
  }
  StringBuilder body = new StringBuilder();
  String subject = "";
  boolean inBody = false;
  try {
    // NOTE(review): Files.readAllLines(Path) decodes as UTF-8 and throws on malformed input, so
    // posts containing other 8-bit encodings are skipped via the catch below - confirm that this
    // is acceptable for the corpus in use.
    for (String line : Files.readAllLines(path)) {
      if (inBody) {
        // Once in the body, every line is content - including ones that start with "Subject:"
        // (e.g. quoted messages), which previously overwrote the subject and were dropped.
        if (body.length() > 0) {
          body.append("\n");
        }
        body.append(line);
      } else if (line.startsWith("Subject:")) {
        subject = line.substring(8);
      } else if (line.trim().isEmpty()) {
        inBody = true; // the first blank line ends the headers
      }
    }
  } catch (IOException | RuntimeException ignored) {
    // Best effort: a file that cannot be read or decoded is left out of the index. Errors such
    // as OutOfMemoryError are deliberately no longer swallowed.
    return null;
  }
  return new NewsPost(body.toString(), subject, groupName, number);
}
private class NewsPost {
private final String body;
private final String subject;
private final String group;
private final String number;
private NewsPost(String body, String subject, String group, String number) {
this.body = body;
this.subject = subject;
this.group = group;
this.number = number;
}
public String getBody() {
return body;
}
public String getSubject() {
return subject;
}
public String getGroup() {
return group;
}
public String getNumber() {
return number;
}
}
}