blob: 6eaf5483086a20ace6c56eaf55e73aa8bd6a1628 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samoa;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.io.input.Tailer;
import org.apache.commons.io.input.TailerListenerAdapter;
import org.junit.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.lang.reflect.InvocationTargetException;
import java.util.Iterator;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertTrue;
/**
 * Helpers for integration tests that launch a SAMOA task in-process (via its {@code main} method)
 * and then verify the CSV results file and, optionally, the label/prediction file it writes.
 */
public class TestUtils {

  private static final Logger LOG = LoggerFactory.getLogger(TestUtils.class.getName());

  /** Line the task writes to the results file once it has finished producing output. */
  private static final String COMPLETED_MARKER = "# COMPLETED";

  /**
   * Runs the task described by {@code testParams}, waits until its results file contains the
   * completion marker, and then checks the results (and the label file, when one is expected).
   *
   * <p>The task is started on a background thread by reflectively calling
   * {@code main(String[])} on the class named by {@code testParams.getTaskClassName()}. Its
   * arguments come from formatting {@code testParams.getCliStringTemplate()} with, in order:
   * the results file path, input instances, sampling size, input delay (microseconds), the
   * label file path and the label sampling size. The formatted string is then split on single
   * spaces.
   *
   * @param testParams the task under test and the thresholds its output must meet
   * @throws IOException if a temporary file cannot be created or an output file cannot be read
   * @throws InterruptedException if interrupted while waiting for the task to complete
   */
  public static void test(final TestParams testParams) throws IOException, ClassNotFoundException,
      NoSuchMethodException, InvocationTargetException, IllegalAccessException, InterruptedException {
    final File tempFile = File.createTempFile("test", "test");
    final File labelFile = File.createTempFile("result", "result");
    LOG.info("Starting test, output file is {}, test config is \n{}", tempFile.getAbsolutePath(), testParams.toString());

    // Created before the task starts so that a failing task can release the waiting thread.
    final CountDownLatch signalComplete = new CountDownLatch(1);

    ExecutorService executor = Executors.newSingleThreadExecutor();
    try {
      executor.submit(new Callable<Void>() {
        @Override
        public Void call() throws Exception {
          try {
            Class.forName(testParams.getTaskClassName())
                .getMethod("main", String[].class)
                .invoke(null, (Object) String.format(
                    testParams.getCliStringTemplate(),
                    tempFile.getAbsolutePath(),
                    testParams.getInputInstances(),
                    testParams.getSamplingSize(),
                    testParams.getInputDelayMicroSec(),
                    labelFile.getAbsolutePath(),
                    testParams.getLabelSamplingSize()
                ).split("[ ]"));
          } catch (Exception e) {
            // Log the exception itself: getCause() is null for e.g. ClassNotFoundException, and
            // the stack trace already includes any wrapped cause (InvocationTargetException).
            LOG.error("Cannot execute test {}", testParams.getTaskClassName(), e);
            // The completion marker will never be written; stop waiting so the assertions
            // below report the failure instead of the test hanging forever.
            signalComplete.countDown();
          }
          return null;
        }
      });
    } finally {
      // Already-submitted work still runs to completion; this only lets the worker thread exit.
      executor.shutdown();
    }

    Thread.sleep(TimeUnit.SECONDS.toMillis(testParams.getPrePollWaitSeconds()));

    // Tailer.create() starts the tailer on its own daemon thread, so it must not also be run
    // on a second thread (that would tail the same file twice, concurrently).
    final Tailer tailer = Tailer.create(tempFile, new TestResultsTailerAdapter(signalComplete), 1000);
    try {
      signalComplete.await();
    } finally {
      tailer.stop();
    }

    assertResults(tempFile, testParams);
    if (testParams.getLabelFileCreated()) {
      assertLabels(labelFile, testParams);
    }
  }

  /**
   * Checks the header of a results CSV file, then checks that its last row meets every
   * threshold in {@code testParams}.
   *
   * <p>A 9-column header is treated as a cross-validation result. In that case every column
   * after the first is prefixed with {@code "[avg] "}, and the checked metrics sit at columns
   * 3, 5 and 7 instead of 2, 3 and 4.
   *
   * @param outputFile results file written by the task; lines starting with '#' are ignored
   * @param testParams expected column names are constants on this class; thresholds come from
   *     its getters
   * @throws IOException if the file cannot be read
   */
  public static void assertResults(File outputFile, org.apache.samoa.TestParams testParams) throws IOException {
    LOG.info("Checking results file " + outputFile.getAbsolutePath());
    // The CSV iterator reads lazily, so all parsing must happen before the reader is closed.
    try (Reader in = new FileReader(outputFile)) {
      // 1. parse result file with csv parser
      Iterator<CSVRecord> iterator = CSVFormat.EXCEL.withSkipHeaderRecord(false)
          .withIgnoreEmptyLines(true).withDelimiter(',').withCommentMarker('#').parse(in).iterator();
      assertTrue("Results file " + outputFile.getAbsolutePath() + " has no header", iterator.hasNext());
      CSVRecord header = iterator.next();

      // Results of Standard Evaluation have 5 columns, and cv Evaluation have 9 columns
      boolean cvEvaluation = header.size() == 9;
      String cvText = cvEvaluation ? "[avg] " : "";
      int correctColumn = cvEvaluation ? 3 : 2;
      int kappaColumn = cvEvaluation ? 5 : 3;
      int kappaTempColumn = cvEvaluation ? 7 : 4;

      Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.EVALUATION_INSTANCES,
          header.get(0).trim());
      Assert.assertEquals("Unexpected column", cvText + org.apache.samoa.TestParams.CLASSIFIED_INSTANCES,
          header.get(1).trim());
      Assert.assertEquals("Unexpected column", cvText + org.apache.samoa.TestParams.CLASSIFICATIONS_CORRECT,
          header.get(correctColumn).trim());
      Assert.assertEquals("Unexpected column", cvText + org.apache.samoa.TestParams.KAPPA_STAT,
          header.get(kappaColumn).trim());
      Assert.assertEquals("Unexpected column", cvText + org.apache.samoa.TestParams.KAPPA_TEMP_STAT,
          header.get(kappaTempColumn).trim());

      // 2. check last line result
      CSVRecord last = null;
      while (iterator.hasNext()) {
        last = iterator.next();
      }
      Assert.assertNotNull("Results file " + outputFile.getAbsolutePath() + " has no result rows", last);

      float evaluationInstances = Float.parseFloat(last.get(0));
      float classifiedInstances = Float.parseFloat(last.get(1));
      float classificationsCorrect = Float.parseFloat(last.get(correctColumn));
      float kappaStat = Float.parseFloat(last.get(kappaColumn));
      float kappaTempStat = Float.parseFloat(last.get(kappaTempColumn));

      assertTrue(String.format("Unmet threshold expected %d got %f",
          testParams.getEvaluationInstances(), evaluationInstances),
          testParams.getEvaluationInstances() <= evaluationInstances);
      assertTrue(String.format("Unmet threshold expected %d got %f",
          testParams.getClassifiedInstances(), classifiedInstances),
          testParams.getClassifiedInstances() <= classifiedInstances);
      assertTrue(String.format("Unmet threshold expected %f got %f",
          testParams.getClassificationsCorrect(), classificationsCorrect),
          testParams.getClassificationsCorrect() <= classificationsCorrect);
      assertTrue(String.format("Unmet threshold expected %f got %f",
          testParams.getKappaStat(), kappaStat),
          testParams.getKappaStat() <= kappaStat);
      assertTrue(String.format("Unmet threshold expected %f got %f",
          testParams.getKappaTempStat(), kappaTempStat),
          testParams.getKappaTempStat() <= kappaTempStat);
    }
  }

  /**
   * Checks the header of a label/prediction CSV file and its number of data rows.
   *
   * <p>Expected columns are the instance id, the true class value, the predicted class value,
   * then any number of columns whose names start with {@code TestParams.VOTES}. The expected
   * row count is {@code getInputInstances() / getLabelSamplingSize()}.
   *
   * @param labelFile label file written by the task; lines starting with '#' are ignored
   * @param testParams source of the expected row count
   * @throws IOException if the file cannot be read
   */
  public static void assertLabels(File labelFile, org.apache.samoa.TestParams testParams) throws IOException {
    LOG.info("Checking labels file " + labelFile.getAbsolutePath());
    long expectedLineCount = testParams.getInputInstances() / testParams.getLabelSamplingSize();

    CSVRecord header;
    long lineCount = 0;
    try (Reader in = new FileReader(labelFile)) {
      //1. parse result file with csv parser
      Iterator<CSVRecord> iterator = CSVFormat.EXCEL.withSkipHeaderRecord(false)
          .withIgnoreEmptyLines(true).withDelimiter(',').withCommentMarker('#').parse(in).iterator();
      assertTrue("Labels file " + labelFile.getAbsolutePath() + " has no header", iterator.hasNext());
      header = iterator.next();
      while (iterator.hasNext()) {
        iterator.next();
        lineCount++;
      }
    }

    Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.INSTANCE_ID, header.get(0).trim());
    Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.TRUE_CLASS_VALUE, header.get(1).trim());
    Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.PREDICTED_CLASS_VALUE, header.get(2).trim());
    String votesPrefix = org.apache.samoa.TestParams.VOTES;
    for (int i = 3; i < header.size(); i++) {
      String column = header.get(i).trim();
      // Bound the prefix length so a short column name fails the assertion instead of
      // throwing StringIndexOutOfBoundsException.
      Assert.assertEquals("Unexpected column", votesPrefix,
          column.substring(0, Math.min(column.length(), votesPrefix.length())));
    }
    Assert.assertEquals("Wrong number of lines in prediction file", expectedLineCount, lineCount);
  }

  /**
   * Tailer listener that counts down {@code signalComplete} when it sees a line equal to the
   * completion marker, ignoring surrounding whitespace.
   */
  private static class TestResultsTailerAdapter extends TailerListenerAdapter {
    private final CountDownLatch signalComplete;

    public TestResultsTailerAdapter(CountDownLatch signalComplete) {
      this.signalComplete = signalComplete;
    }

    @Override
    public void handle(String line) {
      if (COMPLETED_MARKER.equals(line.trim())) {
        signalComplete.countDown();
      }
    }
  }
}