/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| package org.apache.lucene.classification.utils; |
| |
| |
| import java.io.IOException; |
| import java.util.List; |
| |
| import org.apache.lucene.analysis.MockAnalyzer; |
| import org.apache.lucene.classification.BM25NBClassifier; |
| import org.apache.lucene.classification.BooleanPerceptronClassifier; |
| import org.apache.lucene.classification.CachingNaiveBayesClassifier; |
| import org.apache.lucene.classification.ClassificationResult; |
| import org.apache.lucene.classification.ClassificationTestBase; |
| import org.apache.lucene.classification.Classifier; |
| import org.apache.lucene.classification.KNearestFuzzyClassifier; |
| import org.apache.lucene.classification.KNearestNeighborClassifier; |
| import org.apache.lucene.classification.SimpleNaiveBayesClassifier; |
| import org.apache.lucene.index.LeafReader; |
| import org.apache.lucene.util.BytesRef; |
| import org.junit.Test; |
| |
| /** |
| * Tests for {@link ConfusionMatrixGenerator} |
| */ |
| public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object> { |
| |
| @Test |
| public void testGetConfusionMatrix() throws Exception { |
| LeafReader reader = null; |
| try { |
| MockAnalyzer analyzer = new MockAnalyzer(random()); |
| reader = getSampleIndex(analyzer); |
| Classifier<BytesRef> classifier = new Classifier<BytesRef>() { |
| @Override |
| public ClassificationResult<BytesRef> assignClass(String text) throws IOException { |
| return new ClassificationResult<>(new BytesRef(), 1 / (1 + Math.exp(-random().nextInt()))); |
| } |
| |
| @Override |
| public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException { |
| return null; |
| } |
| |
| @Override |
| public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException { |
| return null; |
| } |
| }; |
| ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, |
| classifier, categoryFieldName, textFieldName, -1); |
| assertNotNull(confusionMatrix); |
| assertNotNull(confusionMatrix.getLinearizedMatrix()); |
| assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); |
| double avgClassificationTime = confusionMatrix.getAvgClassificationTime(); |
| assertTrue(avgClassificationTime >= 0d ); |
| double accuracy = confusionMatrix.getAccuracy(); |
| assertTrue(accuracy >= 0d); |
| assertTrue(accuracy <= 1d); |
| double precision = confusionMatrix.getPrecision(); |
| assertTrue(precision >= 0d); |
| assertTrue(precision <= 1d); |
| double recall = confusionMatrix.getRecall(); |
| assertTrue(recall >= 0d); |
| assertTrue(recall <= 1d); |
| double f1Measure = confusionMatrix.getF1Measure(); |
| assertTrue(f1Measure >= 0d); |
| assertTrue(f1Measure <= 1d); |
| } finally { |
| if (reader != null) { |
| reader.close(); |
| } |
| } |
| } |
| |
| @Test |
| public void testGetConfusionMatrixWithSNB() throws Exception { |
| LeafReader reader = null; |
| try { |
| MockAnalyzer analyzer = new MockAnalyzer(random()); |
| reader = getSampleIndex(analyzer); |
| Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName); |
| ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, |
| classifier, categoryFieldName, textFieldName, -1); |
| checkCM(confusionMatrix); |
| } finally { |
| if (reader != null) { |
| reader.close(); |
| } |
| } |
| } |
| |
| private void checkCM(ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix) { |
| assertNotNull(confusionMatrix); |
| assertNotNull(confusionMatrix.getLinearizedMatrix()); |
| assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); |
| assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d); |
| double accuracy = confusionMatrix.getAccuracy(); |
| assertTrue(accuracy >= 0d); |
| assertTrue(accuracy <= 1d); |
| double precision = confusionMatrix.getPrecision(); |
| assertTrue(precision >= 0d); |
| assertTrue(precision <= 1d); |
| double recall = confusionMatrix.getRecall(); |
| assertTrue(recall >= 0d); |
| assertTrue(recall <= 1d); |
| double f1Measure = confusionMatrix.getF1Measure(); |
| assertTrue(f1Measure >= 0d); |
| assertTrue(f1Measure <= 1d); |
| } |
| |
| @Test |
| public void testGetConfusionMatrixWithBM25NB() throws Exception { |
| LeafReader reader = null; |
| try { |
| MockAnalyzer analyzer = new MockAnalyzer(random()); |
| reader = getSampleIndex(analyzer); |
| Classifier<BytesRef> classifier = new BM25NBClassifier(reader, analyzer, null, categoryFieldName, textFieldName); |
| ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, |
| classifier, categoryFieldName, textFieldName, -1); |
| checkCM(confusionMatrix); |
| } finally { |
| if (reader != null) { |
| reader.close(); |
| } |
| } |
| } |
| |
| @Test |
| public void testGetConfusionMatrixWithCNB() throws Exception { |
| LeafReader reader = null; |
| try { |
| MockAnalyzer analyzer = new MockAnalyzer(random()); |
| reader = getSampleIndex(analyzer); |
| Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName); |
| ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, |
| classifier, categoryFieldName, textFieldName, -1); |
| checkCM(confusionMatrix); |
| } finally { |
| if (reader != null) { |
| reader.close(); |
| } |
| } |
| } |
| |
| @Test |
| public void testGetConfusionMatrixWithKNN() throws Exception { |
| LeafReader reader = null; |
| try { |
| MockAnalyzer analyzer = new MockAnalyzer(random()); |
| reader = getSampleIndex(analyzer); |
| Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName); |
| ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, |
| classifier, categoryFieldName, textFieldName, -1); |
| checkCM(confusionMatrix); |
| } finally { |
| if (reader != null) { |
| reader.close(); |
| } |
| } |
| } |
| |
| @Test |
| public void testGetConfusionMatrixWithFLTKNN() throws Exception { |
| LeafReader reader = null; |
| try { |
| MockAnalyzer analyzer = new MockAnalyzer(random()); |
| reader = getSampleIndex(analyzer); |
| Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(reader, null, analyzer, null, 1, categoryFieldName, textFieldName); |
| ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, |
| classifier, categoryFieldName, textFieldName, -1); |
| checkCM(confusionMatrix); |
| } finally { |
| if (reader != null) { |
| reader.close(); |
| } |
| } |
| } |
| |
| @Test |
| public void testGetConfusionMatrixWithBP() throws Exception { |
| LeafReader reader = null; |
| try { |
| MockAnalyzer analyzer = new MockAnalyzer(random()); |
| reader = getSampleIndex(analyzer); |
| Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName); |
| ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, |
| classifier, booleanFieldName, textFieldName, -1); |
| checkCM(confusionMatrix); |
| assertTrue(confusionMatrix.getPrecision("true") >= 0d); |
| assertTrue(confusionMatrix.getPrecision("true") <= 1d); |
| assertTrue(confusionMatrix.getPrecision("false") >= 0d); |
| assertTrue(confusionMatrix.getPrecision("false") <= 1d); |
| assertTrue(confusionMatrix.getRecall("true") >= 0d); |
| assertTrue(confusionMatrix.getRecall("true") <= 1d); |
| assertTrue(confusionMatrix.getRecall("false") >= 0d); |
| assertTrue(confusionMatrix.getRecall("false") <= 1d); |
| assertTrue(confusionMatrix.getF1Measure("true") >= 0d); |
| assertTrue(confusionMatrix.getF1Measure("true") <= 1d); |
| assertTrue(confusionMatrix.getF1Measure("false") >= 0d); |
| assertTrue(confusionMatrix.getF1Measure("false") <= 1d); |
| } finally { |
| if (reader != null) { |
| reader.close(); |
| } |
| } |
| } |
| } |