/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.stanbol.enhancer.engine.topic;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Hashtable;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

import org.apache.clerezza.rdf.core.Graph;
import org.apache.clerezza.rdf.core.serializedform.Parser;
import org.apache.clerezza.rdf.core.serializedform.SupportedFormat;
import org.apache.clerezza.rdf.jena.parser.JenaParserProvider;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.client.solrj.embedded.EmbeddedSolrServer;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.common.params.CommonParams;
import org.apache.stanbol.commons.solr.utils.StreamQueryRequest;
import org.apache.stanbol.enhancer.servicesapi.EnhancementEngine;
import org.apache.stanbol.enhancer.servicesapi.rdf.OntologicalClasses;
import org.apache.stanbol.enhancer.servicesapi.rdf.Properties;
import org.apache.stanbol.enhancer.topic.EmbeddedSolrHelper;
import org.apache.stanbol.enhancer.topic.api.ClassificationReport;
import org.apache.stanbol.enhancer.topic.api.ClassifierException;
import org.apache.stanbol.enhancer.topic.api.TopicSuggestion;
import org.apache.stanbol.enhancer.topic.api.training.TrainingSetException;
import org.apache.stanbol.enhancer.topic.training.SolrTrainingSet;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.osgi.service.cm.ConfigurationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

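/**
 * Tests for the {@link TopicClassificationEngine}: configuration handling, SKOS
 * import, programmatic thesaurus construction, training from labeled examples
 * and cross-validated performance estimates, all against embedded Solr cores.
 */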
public class TopicEngineTest extends EmbeddedSolrHelper {
private static final Logger log = LoggerFactory.getLogger(TopicEngineTest.class);
private static final String TEST_ROOT =
FilenameUtils.separatorsToSystem("/target/enginetest-files");
private static String userDir;
private static File testRoot;
private static int testCounter = 1;
EmbeddedSolrServer classifierSolrServer;
EmbeddedSolrServer trainingSetSolrServer;
File solrHome;
SolrTrainingSet trainingSet;
TopicClassificationEngine classifier;
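
    /**
     * Redirects the "user.dir" system property to a fresh folder below the build
     * directory so that embedded Solr data is written there instead of the project root.
     */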
@BeforeClass
    public static void initTestFolder() throws IOException {
        // basedir - if present - is the project base folder
        String baseDir = System.getProperty("basedir");
        if (baseDir == null) { // if missing, fall back to user.dir
            baseDir = System.getProperty("user.dir");
        }
        // store the current user.dir
        userDir = System.getProperty("user.dir");
        testRoot = new File(baseDir, TEST_ROOT);
        log.info("Topic Engine Test Folder : " + testRoot);
        if (testRoot.exists()) {
            log.info(" ... delete files from previous test");
            FileUtils.deleteDirectory(testRoot);
        }
        FileUtils.forceMkdir(testRoot);
System.setProperty("user.dir", testRoot.getName());
}

    /**
     * Resets the "user.dir" system property to its original value.
     */
@AfterClass
public static void cleanup(){
System.setProperty("user.dir", userDir);
}
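
    /**
     * Creates a per-test solr.home with one embedded core for the classifier and
     * one for the training set, and configures fresh engine instances on top of them.
     */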
@Before
public void setup() throws Exception {
solrHome = new File(testRoot, "test"+testCounter);
testCounter++;
assertTrue("Unable to create solr.home directory '"+solrHome+"'!",solrHome.mkdir());
classifierSolrServer = makeEmbeddedSolrServer(solrHome, "classifierserver", "test-topic-model",
"default-topic-model");
classifier = TopicClassificationEngine.fromParameters(getDefaultClassifierConfigParams());
//configure the directory used to create Embedded SolrServers for CVFold
classifier.configureEmbeddedSolrServerDir(solrHome);
trainingSetSolrServer = makeEmbeddedSolrServer(solrHome, "trainingsetserver",
"test-topic-trainingset", "default-topic-trainingset");
trainingSet = new SolrTrainingSet();
trainingSet.configure(getDefaultTrainingSetConfigParams());
}
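
    /**
     * Releases the references to the embedded servers and deactivates the engine
     * after each test.
     */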
@After
public void cleanupEmbeddedSolrServer() {
//FileUtils.deleteQuietly(solrHome);
solrHome = null;
classifierSolrServer = null;
trainingSetSolrServer = null;
trainingSet = null;
classifier.deactivate(null);
}
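
    /**
     * Indexes the bundled TSV snippet of topics into the classifier core using
     * Solr's CSV update handler.
     */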
protected void loadSampleTopicsFromTSV() throws IOException, SolrServerException {
assertNotNull(classifierSolrServer);
String topicSnippetsPath = "/topics_abstracts_snippet.tsv";
InputStream is = getClass().getResourceAsStream(topicSnippetsPath);
assertNotNull("Could not find test resource: " + topicSnippetsPath, is);
// Build a query for the CSV importer
SolrQuery query = new SolrQuery();
query.setQueryType("/update/csv");
query.set("commit", true);
query.set("separator", "\t");
query.set("headers", false);
query.set("fieldnames", "topic,popularity,broader,text");
        query.set(CommonParams.STREAM_CONTENTTYPE, "text/plain;charset=utf-8");
query.set(CommonParams.STREAM_BODY, IOUtils.toString(is, "utf-8"));
// Upload an index
QueryResponse response = new StreamQueryRequest(query).process(classifierSolrServer);
assertNotNull(response);
log.info(String.format("Indexed test topics in %dms", response.getElapsedTime()));
}
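
    /**
     * Checks the default and optional configuration parameters and that a missing
     * engine name is rejected.
     */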
@Test
public void testEngineConfiguration() throws ConfigurationException {
log.info(" --- testEngineConfiguration --- ");
Hashtable<String,Object> config = getDefaultClassifierConfigParams();
TopicClassificationEngine classifier = TopicClassificationEngine.fromParameters(config);
assertNotNull(classifier);
        assertEquals("test-engine", classifier.engineName);
        assertEquals(classifierSolrServer, classifier.getActiveSolrServer());
        assertEquals("concept", classifier.conceptUriField);
        assertEquals("classifier_features", classifier.similarityField);
        assertEquals(new ArrayList<String>(), classifier.acceptedLanguages);
// check some required attributes
        // NOTE: this is no longer a required field; a default value is used instead
// Hashtable<String,Object> configWithMissingTopicField = new Hashtable<String,Object>();
// configWithMissingTopicField.putAll(config);
// configWithMissingTopicField.remove(TopicClassificationEngine.CONCEPT_URI_FIELD);
// try {
// TopicClassificationEngine.fromParameters(configWithMissingTopicField);
// fail("Should have raised a ConfigurationException");
// } catch (ConfigurationException e) {}
Hashtable<String,Object> configWithMissingEngineName = new Hashtable<String,Object>();
configWithMissingEngineName.putAll(config);
configWithMissingEngineName.remove(EnhancementEngine.PROPERTY_NAME);
try {
TopicClassificationEngine.fromParameters(configWithMissingEngineName);
fail("Should have raised a ConfigurationException");
} catch (ConfigurationException e) {}
// check accept language optional param
        Hashtable<String,Object> configWithAcceptLanguage = new Hashtable<String,Object>();
        configWithAcceptLanguage.putAll(config);
        configWithAcceptLanguage.put(TopicClassificationEngine.LANGUAGES, "en, fr");
        classifier = TopicClassificationEngine.fromParameters(configWithAcceptLanguage);
assertNotNull(classifier);
        assertEquals(Arrays.asList("en", "fr"), classifier.acceptedLanguages);
}
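
    /**
     * Imports a small SKOS concept scheme from RDF/XML and checks the resulting
     * broader and root concept relationships.
     */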
@Test
public void testImportModelFromSKOS() throws Exception {
log.info(" --- testImportModelFromSKOS --- ");
Parser parser = Parser.getInstance();
parser.bindParsingProvider(new JenaParserProvider());
Graph graph = parser.parse(getClass().getResourceAsStream("/sample-scheme.skos.rdf.xml"),
SupportedFormat.RDF_XML);
int imported = classifier.importConceptsFromGraph(graph, OntologicalClasses.SKOS_CONCEPT,
Properties.SKOS_BROADER);
        assertEquals(4, imported);
assertEquals(0, classifier.getBroaderConcepts("http://example.com/ns#someconceptscheme/100").size());
assertEquals(0, classifier.getBroaderConcepts("http://example.com/ns#someconceptscheme/200").size());
assertEquals(1, classifier.getBroaderConcepts("http://example.com/ns#someconceptscheme/010").size());
assertEquals(1, classifier.getBroaderConcepts("http://example.com/ns#someconceptscheme/020").size());
assertEquals(2, classifier.getRootConcepts().size());
}
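
    /**
     * Builds a taxonomy through the classifier API and checks the derived
     * broader/narrower relationships and root detection.
     */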
@Test
public void testProgrammaticThesaurusConstruction() throws Exception {
log.info(" --- testProgrammaticThesaurusConstruction --- ");
// Register the roots of the taxonomy
classifier.addConcept("http://example.com/topics/root1", null);
classifier.addConcept("http://example.com/topics/root2", null);
classifier.addConcept("http://example.com/topics/root3", new ArrayList<String>());
assertEquals(0, classifier.getBroaderConcepts("http://example.com/topics/root1").size());
assertEquals(0, classifier.getBroaderConcepts("http://example.com/topics/root2").size());
assertEquals(0, classifier.getBroaderConcepts("http://example.com/topics/root3").size());
assertEquals(3, classifier.getRootConcepts().size());
// Register some non root nodes
classifier.addConcept("http://example.com/topics/node1",
Arrays.asList("http://example.com/topics/root1", "http://example.com/topics/root2"));
classifier.addConcept("http://example.com/topics/node2",
Arrays.asList("http://example.com/topics/root3"));
classifier.addConcept("http://example.com/topics/node3",
Arrays.asList("http://example.com/topics/node1", "http://example.com/topics/node2"));
        // the roots were not impacted
assertEquals(0, classifier.getBroaderConcepts("http://example.com/topics/root1").size());
assertEquals(0, classifier.getBroaderConcepts("http://example.com/topics/root2").size());
assertEquals(0, classifier.getBroaderConcepts("http://example.com/topics/root3").size());
assertEquals(3, classifier.getRootConcepts().size());
// the other nodes have the same broader topics as at creation time
assertEquals(2, classifier.getBroaderConcepts("http://example.com/topics/node1").size());
assertEquals(1, classifier.getBroaderConcepts("http://example.com/topics/node2").size());
assertEquals(2, classifier.getBroaderConcepts("http://example.com/topics/node3").size());
// check the induced narrower relationships
assertEquals(1, classifier.getNarrowerConcepts("http://example.com/topics/root1").size());
assertEquals(1, classifier.getNarrowerConcepts("http://example.com/topics/root2").size());
assertEquals(1, classifier.getNarrowerConcepts("http://example.com/topics/root3").size());
assertEquals(1, classifier.getNarrowerConcepts("http://example.com/topics/node1").size());
assertEquals(1, classifier.getNarrowerConcepts("http://example.com/topics/node2").size());
assertEquals(0, classifier.getNarrowerConcepts("http://example.com/topics/node3").size());
}
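
    /**
     * A classifier with an empty model should return no suggestions instead of failing.
     */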
@Test
public void testEmptyIndexTopicClassification() throws Exception {
log.info(" --- testEmptyIndexTopicClassification --- ");
TopicClassificationEngine engine = TopicClassificationEngine
.fromParameters(getDefaultClassifierConfigParams());
List<TopicSuggestion> suggestedTopics = engine.suggestTopics("This is a test.");
assertNotNull(suggestedTopics);
        assertEquals(0, suggestedTopics.size());
}

    // @Test
    // TODO: update to work with the new Solr schema and move the CSV import directly to the
    // classifier or training set API
public void testTopicClassification() throws Exception {
log.info(" --- testTopicClassification --- ");
loadSampleTopicsFromTSV();
List<TopicSuggestion> suggestedTopics = classifier
.suggestTopics("The Man Who Shot Liberty Valance is a 1962"
+ " American Western film directed by John Ford,"
+ " narrated by Charlton Heston and starring James"
+ " Stewart, John Wayne and Vivien Leigh.");
assertNotNull(suggestedTopics);
        assertEquals(10, suggestedTopics.size());
        TopicSuggestion bestSuggestion = suggestedTopics.get(0);
        assertEquals("Category:American_films", bestSuggestion.conceptUri);
}
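
    /**
     * End-to-end training scenario: builds a small taxonomy, registers labeled
     * examples and checks both full rebuilds and incremental model updates.
     */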
@Test
public void testTrainClassifierFromExamples() throws Exception {
log.info(" --- testTrainClassifierFromExamples --- ");
// mini taxonomy for news articles
String[] business = {"urn:topics/business", "http://dbpedia.org/resource/Business"};
String[] technology = {"urn:topics/technology", "http://dbpedia.org/resource/Technology"};
String apple = "urn:topics/apple";
String sport = "urn:topics/sport";
String football = "urn:topics/football";
String worldcup = "urn:topics/worldcup";
String music = "urn:topics/music";
String law = "urn:topics/law";
classifier.addConcept(business[0], business[1], null);
classifier.addConcept(technology[0], technology[1], null);
classifier.addConcept(sport, null);
classifier.addConcept(music, null);
classifier.addConcept(apple, Arrays.asList(business[0], technology[0]));
classifier.addConcept(football, Arrays.asList(sport));
classifier.addConcept(worldcup, Arrays.asList(football));
// train the classifier on an empty dataset
classifier.setTrainingSet(trainingSet);
assertEquals(7, classifier.updateModel(true));
// the model is updated but does not predict anything
List<TopicSuggestion> suggestions = classifier
.suggestTopics("I like the sound of vuvuzula in the morning!");
assertEquals(0, suggestions.size());
// check that updating the model incrementally without changing the dataset won't change anything.
assertEquals(0, classifier.updateModel(true));
        // let's register some examples, including stop words, to limit statistical artifacts caused
        // by the small size of the training set.
String STOP_WORDS = " the a is are be in at ";
trainingSet.registerExample(null, "Money, money, money is the root of all evil." + STOP_WORDS,
Arrays.asList(business));
trainingSet.registerExample(null, "VC invested more money in tech startups in 2011." + STOP_WORDS,
Arrays.asList(business[0], technology[0]));
trainingSet.registerExample(null, "Apple's iPad is a small handheld computer with a touch screen UI"
+ STOP_WORDS, Arrays.asList(apple, technology[0]));
trainingSet.registerExample(null, "Apple sold the iPad at a very high price"
+ " and made record profits." + STOP_WORDS,
Arrays.asList(apple, business[0]));
trainingSet.registerExample(null, "Manchester United won 3-2 against FC Barcelona." + STOP_WORDS,
Arrays.asList(football));
trainingSet.registerExample(null, "The 2012 Football Worldcup takes place in Brazil." + STOP_WORDS,
Arrays.asList(football, worldcup));
trainingSet.registerExample(null, "Vuvuzela made the soundtrack of the"
+ " football worldcup of 2010 in South Africa." + STOP_WORDS,
Arrays.asList(football, worldcup, music));
trainingSet.registerExample(null, "Amon Tobin will be live in Paris soon." + STOP_WORDS,
Arrays.asList(music));
// retrain the model: all topics are recomputed
assertEquals(7, classifier.updateModel(true));
// test the trained classifier
suggestions = classifier.suggestTopics("I like the sound of vuvuzula in the morning!");
assertTrue(suggestions.size() >= 4);
assertEquals(worldcup, suggestions.get(0).conceptUri);
assertEquals(music, suggestions.get(1).conceptUri);
assertEquals(football, suggestions.get(2).conceptUri);
assertEquals(sport, suggestions.get(3).conceptUri);
// check that the scores are decreasing:
assertTrue(suggestions.get(0).score >= suggestions.get(1).score);
assertTrue(suggestions.get(1).score >= suggestions.get(2).score);
assertTrue(suggestions.get(2).score >= suggestions.get(3).score);
suggestions = classifier.suggestTopics("Apple is no longer a startup.");
assertTrue(suggestions.size() >= 3);
assertEquals(apple, suggestions.get(0).conceptUri);
assertNull(suggestions.get(0).primaryTopicUri);
assertEquals(Arrays.asList(business[0], technology[0]), suggestions.get(0).broader);
assertEquals(technology[0], suggestions.get(1).conceptUri);
assertEquals(technology[1], suggestions.get(1).primaryTopicUri);
assertEquals(business[0], suggestions.get(2).conceptUri);
assertEquals(business[1], suggestions.get(2).primaryTopicUri);
suggestions = classifier.suggestTopics("You can watch the worldcup on your iPad.");
assertTrue(suggestions.size() >= 2);
assertEquals(apple, suggestions.get(0).conceptUri);
assertEquals(worldcup, suggestions.get(1).conceptUri);
// test incremental update of a single root node
Thread.sleep(10);
trainingSet.registerExample(null, "Dubstep is a broken beat musical style as are Hip-Hop,"
+ " Dancehall or Drum & Bass", Arrays.asList(music));
assertEquals(1, classifier.updateModel(true));
assertEquals(0, classifier.updateModel(true));
suggestions = classifier.suggestTopics("Glory box is best mixed as dubstep.");
assertTrue(suggestions.size() >= 1);
assertEquals(music, suggestions.get(0).conceptUri);
// test incremental update of a leaf node (the parent topic needs re-indexing too)
Thread.sleep(10);
trainingSet.registerExample(null, "The Brazil team has won the cup so many times.",
Arrays.asList(worldcup));
assertEquals(2, classifier.updateModel(true));
assertEquals(0, classifier.updateModel(true));
// it's always possible to rebuild all models from scratch
assertEquals(7, classifier.updateModel(false));
assertEquals(0, classifier.updateModel(true));
        // it's also possible to define new topics on an existing model and leverage incremental
        // indexing for them as long as they are actually registered on the classifier
trainingSet.registerExample(null,
"Under Belgian law, judges and prosecutors are judicial officers with equal rank and pay.",
Arrays.asList(law));
trainingSet.registerExample(null, "Prosecutors are typically lawyers who possess a law degree,"
+ " and are recognized as legal professionals by the court"
+ " in which they intend to represent the state.",
Arrays.asList(law));
assertEquals(0, classifier.updateModel(true));
classifier.addConcept(law, null);
assertEquals(1, classifier.updateModel(true));
assertEquals(0, classifier.updateModel(true));
        // registering new subtopics invalidates the models of the parents as well
classifier.addConcept("urn:topics/sportsmafia", Arrays.asList(football, business[0]));
assertEquals(3, classifier.updateModel(true));
assertEquals(0, classifier.updateModel(true));
}
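
    /**
     * Checks that manually written performance metadata is stored, read back and
     * averaged over successive folds.
     */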
@Test
public void testUpdatePerformanceEstimates() throws Exception {
log.info(" --- testUpdatePerformanceEstimates --- ");
ClassificationReport performanceEstimates;
// no registered topic
try {
classifier.getPerformanceEstimates("urn:t/001");
fail("Should have raised ClassifierException");
} catch (ClassifierException e) {
// expected
}
// register some topics
classifier.addConcept("urn:t/001", null);
classifier.addConcept("urn:t/002", Arrays.asList("urn:t/001"));
performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
assertFalse(performanceEstimates.uptodate);
// update the performance metadata manually
classifier.updatePerformanceMetadata("urn:t/002", 0.76f, 0.60f, 34, 32,
Arrays.asList("ex14", "ex78"), Arrays.asList("ex34", "ex23", "ex89"));
classifier.getActiveSolrServer().commit();
performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
assertTrue(performanceEstimates.uptodate);
assertEquals(0.76f, performanceEstimates.precision, 0.01);
assertEquals(0.60f, performanceEstimates.recall, 0.01);
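        // f1 is the harmonic mean of precision and recall: 2 * 0.76 * 0.60 / (0.76 + 0.60) ~= 0.67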
assertEquals(0.67f, performanceEstimates.f1, 0.01);
assertEquals(34, performanceEstimates.positiveSupport);
assertEquals(32, performanceEstimates.negativeSupport);
assertTrue(classifier.getBroaderConcepts("urn:t/002").contains("urn:t/001"));
// accumulate other folds statistics and compute means of statistics
classifier.updatePerformanceMetadata("urn:t/002", 0.79f, 0.63f, 10, 10, Arrays.asList("ex1", "ex5"),
Arrays.asList("ex3", "ex10", "ex11"));
classifier.getActiveSolrServer().commit();
performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
assertTrue(performanceEstimates.uptodate);
assertEquals(0.775f, performanceEstimates.precision, 0.01);
assertEquals(0.615f, performanceEstimates.recall, 0.01);
assertEquals(0.695f, performanceEstimates.f1, 0.01);
assertEquals(44, performanceEstimates.positiveSupport);
assertEquals(42, performanceEstimates.negativeSupport);
}
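
    /**
     * Cross-validates the classifier on an artificial data set and checks that the
     * reported precision/recall/F1 estimates clear low thresholds.
     */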
@Test
public void testCrossValidation() throws Exception {
log.info(" --- testCrossValidation --- ");
// seed a pseudo random number generator for reproducible tests
Random rng = new Random(0);
ClassificationReport performanceEstimates;
// build an artificial data set used for training models and evaluation
int numberOfTopics = 10;
int numberOfDocuments = 100;
int vocabSizeMin = 20;
int vocabSizeMax = 30;
initArtificialTrainingSet(numberOfTopics, numberOfDocuments, vocabSizeMin, vocabSizeMax, rng);
// by default the reports are not computed
performanceEstimates = classifier.getPerformanceEstimates("urn:t/001");
assertFalse(performanceEstimates.uptodate);
performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
assertFalse(performanceEstimates.uptodate);
performanceEstimates = classifier.getPerformanceEstimates("urn:t/003");
assertFalse(performanceEstimates.uptodate);
try {
classifier.getPerformanceEstimates("urn:doesnotexist");
fail("Should have raised a ClassifierException");
} catch (ClassifierException e) {
// expected
}
// launch an evaluation of the classifier according to the current state of the training set
assertEquals(numberOfTopics, classifier.updatePerformanceEstimates(true));
for (int i = 1; i <= numberOfTopics; i++) {
String topic = String.format("urn:t/%03d", i);
performanceEstimates = classifier.getPerformanceEstimates(topic);
assertTrue(performanceEstimates.uptodate);
assertGreater(performanceEstimates.precision, 0.45f);
assertNotNull(performanceEstimates.falsePositiveExampleIds);
assertNotNull(performanceEstimates.falseNegativeExampleIds);
if (performanceEstimates.precision < 1) {
assertFalse(performanceEstimates.falsePositiveExampleIds.isEmpty());
}
if (performanceEstimates.recall < 1) {
assertFalse(performanceEstimates.falseNegativeExampleIds.isEmpty());
}
assertGreater(performanceEstimates.recall, 0.45f);
assertGreater(performanceEstimates.f1, 0.55f);
            // very small support, hence the estimates are unstable; we therefore set low minimum
            // expectations, but we need this test to run reasonably fast...
assertGreater(performanceEstimates.positiveSupport, 4);
assertGreater(performanceEstimates.negativeSupport, 4);
assertNotNull(performanceEstimates.evaluationDate);
}
// TODO: test model invalidation by registering a sub topic manually
}
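
    /**
     * Fails unless {@code large} is greater than or equal to {@code small}.
     */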
protected void assertGreater(float large, float small) {
if (small > large) {
throw new AssertionError(String.format("Expected %f to be greater than %f.", large, small));
}
}
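
    /**
     * Builds a synthetic training set: each topic gets its own random vocabulary,
     * and each document mixes many terms from one to three dominant topics (its
     * labels) with a few terms from other topics plus shared topic-neutral words.
     */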
protected void initArtificialTrainingSet(int numberOfTopics,
int numberOfDocuments,
int vocabSizeMin,
int vocabSizeMax,
Random rng) throws ClassifierException, TrainingSetException {
// define some artificial topics and register them to the classifier
        // use a topic index outside the registered range so these shared "stop words" stay topic-neutral
        String[] stopWords = randomVocabulary(-1, 50, 50, rng);
String[] topics = new String[numberOfTopics];
Map<String,String[]> vocabularies = new TreeMap<String,String[]>();
for (int i = 0; i < numberOfTopics; i++) {
String topic = String.format("urn:t/%03d", i + 1);
topics[i] = topic;
classifier.addConcept(topic, null);
String[] terms = randomVocabulary(i, vocabSizeMin, vocabSizeMax, rng);
vocabularies.put(topic, terms);
}
classifier.setTrainingSet(trainingSet);
        // build a random data set where each example has a couple of dominant topics and other words
for (int i = 0; i < numberOfDocuments; i++) {
List<String> documentTerms = new ArrayList<String>();
            // add terms from some dominant topics that are used as the classification targets
int numberOfDominantTopics = rng.nextInt(3) + 1; // between 1 and 3 topics
List<String> documentTopics = new ArrayList<String>();
for (int j = 0; j < numberOfDominantTopics; j++) {
String topic = randomTopicAndTerms(topics, vocabularies, documentTerms, 100, 150, rng);
documentTopics.add(topic);
}
// add terms from some non-dominant topics
for (int j = 0; j < 5; j++) {
randomTopicAndTerms(topics, vocabularies, documentTerms, 5, 10, rng);
}
            // add some non-discriminative terms not linked to any topic
for (int k = 0; k < 100; k++) {
documentTerms.add(String.valueOf(stopWords[rng.nextInt(stopWords.length)]));
}
// register the generated example in the training set
String text = StringUtils.join(documentTerms, " ");
trainingSet.registerExample(String.format("example_%03d", i), text, documentTopics);
}
}
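
    /**
     * Generates between {@code vocabSizeMin} and {@code vocabSizeMax} synthetic
     * terms tagged with the topic index, e.g. "term_004_t002" for topic index 1.
     */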
protected String[] randomVocabulary(int topicIndex, int vocabSizeMin, int vocabSizeMax, Random rng) {
int vocSize = rng.nextInt(vocabSizeMax + 1 - vocabSizeMin) + vocabSizeMin;
String[] terms = new String[vocSize];
for (int j = 0; j < vocSize; j++) {
// define some artificial vocabulary for each topic to automatically generate random examples
// with some topic structure
terms[j] = String.format("term_%03d_t%03d", j, topicIndex + 1);
}
return terms;
}
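
    /**
     * Picks a random topic and adds between {@code min} and {@code max} of its
     * terms to {@code documentTerms}; returns the chosen topic URI.
     */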
protected String randomTopicAndTerms(String[] topics,
Map<String,String[]> vocabularies,
List<String> documentTerms,
int min,
int max,
Random rng) {
String topic = topics[rng.nextInt(topics.length)];
String[] terms = vocabularies.get(topic);
int numberOfDominantTopicTerm = rng.nextInt(max + 1 - min) + min;
for (int k = 0; k < numberOfDominantTopicTerm; k++) {
documentTerms.add(terms[rng.nextInt(terms.length)]);
}
return topic;
}
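
    /**
     * Minimal classifier configuration: only the engine name and the embedded Solr
     * core are required; all field names fall back to defaults.
     */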
protected Hashtable<String,Object> getDefaultClassifierConfigParams() {
Hashtable<String,Object> config = new Hashtable<String,Object>();
config.put(EnhancementEngine.PROPERTY_NAME, "test-engine");
config.put(TopicClassificationEngine.SOLR_CORE, classifierSolrServer);
        // these are now optional properties
// config.put(TopicClassificationEngine.ENTRY_ID_FIELD, "entry_id");
// config.put(TopicClassificationEngine.ENTRY_TYPE_FIELD, "entry_type");
// config.put(TopicClassificationEngine.MODEL_ENTRY_ID_FIELD, "model_entry_id");
// config.put(TopicClassificationEngine.CONCEPT_URI_FIELD, "concept");
// config.put(TopicClassificationEngine.PRIMARY_TOPIC_URI_FIELD, "primary_topic");
// config.put(TopicClassificationEngine.SIMILARTITY_FIELD, "classifier_features");
// config.put(TopicClassificationEngine.BROADER_FIELD, "broader");
// config.put(TopicClassificationEngine.MODEL_UPDATE_DATE_FIELD, "last_update_dt");
// config.put(TopicClassificationEngine.MODEL_EVALUATION_DATE_FIELD, "last_evaluation_dt");
// config.put(TopicClassificationEngine.PRECISION_FIELD, "precision");
// config.put(TopicClassificationEngine.RECALL_FIELD, "recall");
// config.put(TopicClassificationEngine.POSITIVE_SUPPORT_FIELD, "positive_support");
// config.put(TopicClassificationEngine.NEGATIVE_SUPPORT_FIELD, "negative_support");
// config.put(TopicClassificationEngine.FALSE_POSITIVES_FIELD, "false_positives");
// config.put(TopicClassificationEngine.FALSE_NEGATIVES_FIELD, "false_negatives");
return config;
}
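
    /**
     * Configuration for a {@link SolrTrainingSet} backed by the embedded training
     * set core.
     */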
protected Hashtable<String,Object> getDefaultTrainingSetConfigParams() {
Hashtable<String,Object> config = new Hashtable<String,Object>();
config.put(SolrTrainingSet.SOLR_CORE, trainingSetSolrServer);
config.put(SolrTrainingSet.TRAINING_SET_NAME, "test-training-set");
config.put(SolrTrainingSet.EXAMPLE_ID_FIELD, "id");
config.put(SolrTrainingSet.EXAMPLE_TEXT_FIELD, "text");
config.put(SolrTrainingSet.TOPICS_URI_FIELD, "topics");
config.put(SolrTrainingSet.MODIFICATION_DATE_FIELD, "modification_dt");
return config;
}
}