/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.suggest.analyzing;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.analysis.StopFilter;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.document.Document;
import org.apache.lucene.search.suggest.Input;
import org.apache.lucene.search.suggest.InputArrayIterator;
import org.apache.lucene.search.suggest.InputIterator;
import org.apache.lucene.search.suggest.Lookup.LookupResult;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LineFileDocs;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import org.junit.Ignore;
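/**
 * Tests for {@link FreeTextSuggester}, a "free text" suggester that predicts the
 * next token from an n-gram language model with backoff, rather than completing
 * whole stored suggestions.
 */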
public class TestFreeTextSuggester extends LuceneTestCase {
public void testBasic() throws Exception {
Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
new Input("foo bar baz blah", 50),
new Input("boo foo bar foo bee", 20)
);
Analyzer a = new MockAnalyzer(random());
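// Bigram model (grams=2), using an ordinary space (0x20) as the separator byte
// instead of the 0x1e default: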
FreeTextSuggester sug = new FreeTextSuggester(a, a, 2, (byte) 0x20);
sug.build(new InputArrayIterator(keys));
assertEquals(2, sug.getCount());
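// The expected scores below follow from raw token counts over both inputs (the
// input weights do not affect them): foo=3, bar=2, baz=blah=boo=bee=1 (9 tokens);
// bigrams "foo bar"=2, "foo bee"=1.  E.g. "foo bar" = 2/3 = 0.67, and the
// backed-off unigram "baz" = 0.4 * 1/9 = 0.04.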
for(int i=0;i<2;i++) {
// Uses bigram model and unigram backoff:
assertEquals("foo bar/0.67 foo bee/0.33 baz/0.04 blah/0.04 boo/0.04",
toString(sug.lookup("foo b", 10)));
// Uses only bigram model:
assertEquals("foo bar/0.67 foo bee/0.33",
toString(sug.lookup("foo ", 10)));
// Uses only unigram model:
assertEquals("foo/0.33",
toString(sug.lookup("foo", 10)));
// Uses only unigram model:
assertEquals("bar/0.22 baz/0.11 bee/0.11 blah/0.11 boo/0.11",
toString(sug.lookup("b", 10)));
// Try again after save/load:
Path tmpDir = createTempDir("FreeTextSuggesterTest");
Path path = tmpDir.resolve("suggester");
OutputStream os = Files.newOutputStream(path);
sug.store(os);
os.close();
InputStream is = Files.newInputStream(path);
sug = new FreeTextSuggester(a, a, 2, (byte) 0x20);
sug.load(is);
is.close();
assertEquals(2, sug.getCount());
}
a.close();
}
public void testIllegalByteDuringBuild() throws Exception {
// Default separator is INFORMATION SEPARATOR TWO
// (0x1e), so no input token is allowed to contain it
Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
new Input("foo\u001ebar baz", 50)
);
Analyzer analyzer = new MockAnalyzer(random());
FreeTextSuggester sug = new FreeTextSuggester(analyzer);
expectThrows(IllegalArgumentException.class, () -> {
sug.build(new InputArrayIterator(keys));
});
analyzer.close();
}
public void testIllegalByteDuringQuery() throws Exception {
// Default separator is INFORMATION SEPARATOR TWO
// (0x1e), so no input token is allowed to contain it
Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
new Input("foo bar baz", 50)
);
Analyzer analyzer = new MockAnalyzer(random());
FreeTextSuggester sug = new FreeTextSuggester(analyzer);
sug.build(new InputArrayIterator(keys));
expectThrows(IllegalArgumentException.class, () -> {
sug.lookup("foo\u001eb", 10);
});
analyzer.close();
}
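// Disabled by default: requires a local enwiki line-doc file at the hard-coded
// path below.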
@Ignore
public void testWiki() throws Exception {
final LineFileDocs lfd = new LineFileDocs(null, "/lucenedata/enwiki/enwiki-20120502-lines-1k.txt");
// Skip header:
lfd.nextDoc();
Analyzer analyzer = new MockAnalyzer(random());
FreeTextSuggester sug = new FreeTextSuggester(analyzer);
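// Feed at most 10,000 article bodies from the line file; the weight is a
// constant 1 and there are no payloads or contexts: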
sug.build(new InputIterator() {
private int count;
@Override
public long weight() {
return 1;
}
@Override
public BytesRef next() {
Document doc;
try {
doc = lfd.nextDoc();
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
if (doc == null) {
return null;
}
if (count++ == 10000) {
return null;
}
return new BytesRef(doc.get("body"));
}
@Override
public BytesRef payload() {
return null;
}
@Override
public boolean hasPayloads() {
return false;
}
@Override
public Set<BytesRef> contexts() {
return null;
}
@Override
public boolean hasContexts() {
return false;
}
});
if (VERBOSE) {
System.out.println(sug.ramBytesUsed() + " bytes");
List<LookupResult> results = sug.lookup("general r", 10);
System.out.println("results:");
for(LookupResult result : results) {
System.out.println(" " + result);
}
}
analyzer.close();
lfd.close();
}
// Make sure you can suggest based only on the unigram model:
public void testUnigrams() throws Exception {
Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
new Input("foo bar baz blah boo foo bar foo bee", 50)
);
Analyzer a = new MockAnalyzer(random());
FreeTextSuggester sug = new FreeTextSuggester(a, a, 1, (byte) 0x20);
sug.build(new InputArrayIterator(keys));
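// Unigram counts: bar=2, baz=blah=boo=bee=1 out of 9 tokens, so
// bar = 2/9 = 0.22 and the others 1/9 = 0.11: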
// Sorts first by count, descending, second by term, ascending
assertEquals("bar/0.22 baz/0.11 bee/0.11 blah/0.11 boo/0.11",
toString(sug.lookup("b", 10)));
a.close();
}
// Make sure the last token is not suggested twice (once by the bigram model,
// again by the unigram backoff):
public void testNoDupsAcrossGrams() throws Exception {
Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
new Input("foo bar bar bar bar", 50)
);
Analyzer a = new MockAnalyzer(random());
FreeTextSuggester sug = new FreeTextSuggester(a, a, 2, (byte) 0x20);
sug.build(new InputArrayIterator(keys));
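// "bar" matches both via the bigram "foo bar" (count 1 / context count 1 = 1.00)
// and via the unigram backoff; only the bigram suggestion should be returned: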
assertEquals("foo bar/1.00",
toString(sug.lookup("foo b", 10)));
a.close();
}
// Lookup of just the empty string is rejected with an IllegalArgumentException:
public void testEmptyString() throws Exception {
Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
new Input("foo bar bar bar bar", 50)
);
Analyzer a = new MockAnalyzer(random());
FreeTextSuggester sug = new FreeTextSuggester(a, a, 2, (byte) 0x20);
sug.build(new InputArrayIterator(keys));
expectThrows(IllegalArgumentException.class, () -> {
sug.lookup("", 10);
});
a.close();
}
// With one ending hole, ShingleFilter produces "of _" and
// we should properly predict from that:
public void testEndingHole() throws Exception {
// Just deletes "of"
Analyzer a = new Analyzer() {
@Override
public TokenStreamComponents createComponents(String field) {
Tokenizer tokenizer = new MockTokenizer();
CharArraySet stopSet = StopFilter.makeStopSet("of");
return new TokenStreamComponents(tokenizer, new StopFilter(tokenizer, stopSet));
}
};
Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
new Input("wizard of oz", 50)
);
FreeTextSuggester sug = new FreeTextSuggester(a, a, 3, (byte) 0x20);
sug.build(new InputArrayIterator(keys));
assertEquals("wizard _ oz/1.00",
toString(sug.lookup("wizard of", 10)));
// Falls back to the unigram model, with backoff 0.4 times
// prob 0.5 ("oz" is 1 of the 2 surviving tokens), giving 0.20:
assertEquals("oz/0.20",
toString(sug.lookup("wizard o", 10)));
a.close();
}
// If the number of ending holes exceeds the ngrams window
// then there are no predictions, because ShingleFilter
// does not produce e.g. a hole only "_ _" token:
public void testTwoEndingHoles() throws Exception {
// Just deletes "of"
Analyzer a = new Analyzer() {
@Override
public TokenStreamComponents createComponents(String field) {
Tokenizer tokenizer = new MockTokenizer();
CharArraySet stopSet = StopFilter.makeStopSet("of");
return new TokenStreamComponents(tokenizer, new StopFilter(tokenizer, stopSet));
}
};
Iterable<Input> keys = AnalyzingSuggesterTest.shuffle(
new Input("wizard of of oz", 50)
);
FreeTextSuggester sug = new FreeTextSuggester(a, a, 3, (byte) 0x20);
sug.build(new InputArrayIterator(keys));
assertEquals("",
toString(sug.lookup("wizard of of", 10)));
a.close();
}
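// Orders expected results the way they are compared against the suggester's
// output: score descending, ties broken by key: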
private static Comparator<LookupResult> byScoreThenKey = new Comparator<LookupResult>() {
@Override
public int compare(LookupResult a, LookupResult b) {
if (a.value > b.value) {
return -1;
} else if (a.value < b.value) {
return 1;
} else {
// Tie break by UTF16 sort order:
return ((String) a.key).compareTo((String) b.key);
}
}
};
public void testRandom() throws IOException {
String[] terms = new String[TestUtil.nextInt(random(), 2, 10)];
Set<String> seen = new HashSet<>();
while (seen.size() < terms.length) {
String token = TestUtil.randomSimpleString(random(), 1, 5);
if (!seen.contains(token)) {
terms[seen.size()] = token;
seen.add(token);
}
}
Analyzer a = new MockAnalyzer(random());
int numDocs = atLeast(10);
long totTokens = 0;
final String[][] docs = new String[numDocs][];
for(int i=0;i<numDocs;i++) {
docs[i] = new String[atLeast(100)];
if (VERBOSE) {
System.out.print(" doc " + i + ":");
}
for(int j=0;j<docs[i].length;j++) {
docs[i][j] = getZipfToken(terms);
if (VERBOSE) {
System.out.print(" " + docs[i][j]);
}
}
if (VERBOSE) {
System.out.println();
}
totTokens += docs[i].length;
}
int grams = TestUtil.nextInt(random(), 1, 4);
if (VERBOSE) {
System.out.println("TEST: " + terms.length + " terms; " + numDocs + " docs; " + grams + " grams");
}
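// The suggester is built from the random docs above and then checked against a
// brute-force n-gram model; both should produce the same ranked suggestions.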
// Build suggester model:
FreeTextSuggester sug = new FreeTextSuggester(a, a, grams, (byte) 0x20);
sug.build(new InputIterator() {
int upto;
@Override
public BytesRef next() {
if (upto == docs.length) {
return null;
} else {
StringBuilder b = new StringBuilder();
for(String token : docs[upto]) {
b.append(' ');
b.append(token);
}
upto++;
return new BytesRef(b.toString());
}
}
@Override
public long weight() {
return random().nextLong();
}
@Override
public BytesRef payload() {
return null;
}
@Override
public boolean hasPayloads() {
return false;
}
@Override
public Set<BytesRef> contexts() {
return null;
}
@Override
public boolean hasContexts() {
return false;
}
});
// Build inefficient but hopefully correct model:
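// gramCounts.get(g) maps each (g+1)-gram (tokens joined by single spaces) to
// the number of times it occurs across all docs: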
List<Map<String,Integer>> gramCounts = new ArrayList<>(grams);
for(int gram=0;gram<grams;gram++) {
if (VERBOSE) {
System.out.println("TEST: build model for gram=" + gram);
}
Map<String,Integer> model = new HashMap<>();
gramCounts.add(model);
for(String[] doc : docs) {
for(int i=0;i<doc.length-gram;i++) {
StringBuilder b = new StringBuilder();
for(int j=i;j<=i+gram;j++) {
if (j > i) {
b.append(' ');
}
b.append(doc[j]);
}
String token = b.toString();
Integer curCount = model.get(token);
if (curCount == null) {
model.put(token, 1);
} else {
model.put(token, 1 + curCount);
}
if (VERBOSE) {
System.out.println(" add '" + token + "' -> count=" + model.get(token));
}
}
}
}
int lookups = atLeast(100);
for(int iter=0;iter<lookups;iter++) {
String[] tokens = new String[TestUtil.nextInt(random(), 1, 5)];
for(int i=0;i<tokens.length;i++) {
tokens[i] = getZipfToken(terms);
}
// Maybe trim last token; be sure not to create the
// empty string:
int trimStart;
if (tokens.length == 1) {
trimStart = 1;
} else {
trimStart = 0;
}
int trimAt = TestUtil.nextInt(random(), trimStart, tokens[tokens.length - 1].length());
tokens[tokens.length-1] = tokens[tokens.length-1].substring(0, trimAt);
int num = TestUtil.nextInt(random(), 1, 100);
StringBuilder b = new StringBuilder();
for(String token : tokens) {
b.append(' ');
b.append(token);
}
String query = b.toString();
query = query.substring(1);
if (VERBOSE) {
System.out.println("\nTEST: iter=" + iter + " query='" + query + "' num=" + num);
}
// Expected:
List<LookupResult> expected = new ArrayList<>();
double backoff = 1.0;
seen = new HashSet<>();
if (VERBOSE) {
System.out.println(" compute expected");
}
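// Walk the models from highest order down to unigrams.  A candidate's score is
// backoff * count(context + term) / count(context), scaled by Long.MAX_VALUE;
// backoff starts at 1.0 and is multiplied by FreeTextSuggester.ALPHA after each
// model that could be consulted, mirroring the suggester's backoff smoothing.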
for(int i=grams-1;i>=0;i--) {
if (VERBOSE) {
System.out.println(" grams=" + i);
}
if (tokens.length < i+1) {
// Don't have enough tokens to use this model
if (VERBOSE) {
System.out.println(" skip");
}
continue;
}
if (i == 0 && tokens[tokens.length-1].length() == 0) {
// Never suggest unigrams from empty string:
if (VERBOSE) {
System.out.println(" skip unigram priors only");
}
continue;
}
// Build up "context" ngram:
b = new StringBuilder();
for(int j=tokens.length-i-1;j<tokens.length-1;j++) {
b.append(' ');
b.append(tokens[j]);
}
String context = b.toString();
if (context.length() > 0) {
context = context.substring(1);
}
if (VERBOSE) {
System.out.println(" context='" + context + "'");
}
long contextCount;
if (context.length() == 0) {
contextCount = totTokens;
} else {
Integer count = gramCounts.get(i-1).get(context);
if (count == null) {
// We never saw this context:
backoff *= FreeTextSuggester.ALPHA;
if (VERBOSE) {
System.out.println(" skip: never saw context");
}
continue;
}
contextCount = count;
}
if (VERBOSE) {
System.out.println(" contextCount=" + contextCount);
}
Map<String,Integer> model = gramCounts.get(i);
// First pass, gather all predictions for this model:
if (VERBOSE) {
System.out.println(" find terms w/ prefix=" + tokens[tokens.length-1]);
}
List<LookupResult> tmp = new ArrayList<>();
for(String term : terms) {
if (term.startsWith(tokens[tokens.length-1])) {
if (VERBOSE) {
System.out.println(" term=" + term);
}
if (seen.contains(term)) {
if (VERBOSE) {
System.out.println(" skip seen");
}
continue;
}
String ngram = (context + " " + term).trim();
Integer count = model.get(ngram);
if (count != null) {
LookupResult lr = new LookupResult(ngram, (long) (Long.MAX_VALUE * (backoff * (double) count / contextCount)));
tmp.add(lr);
if (VERBOSE) {
System.out.println(" add tmp key='" + lr.key + "' score=" + lr.value);
}
}
}
}
// Second pass, trim to only top N, and fold those
// into overall suggestions:
Collections.sort(tmp, byScoreThenKey);
if (tmp.size() > num) {
tmp.subList(num, tmp.size()).clear();
}
for(LookupResult result : tmp) {
String key = result.key.toString();
int idx = key.lastIndexOf(' ');
String lastToken;
if (idx != -1) {
lastToken = key.substring(idx+1);
} else {
lastToken = key;
}
if (!seen.contains(lastToken)) {
seen.add(lastToken);
expected.add(result);
if (VERBOSE) {
System.out.println(" keep key='" + result.key + "' score=" + result.value);
}
}
}
backoff *= FreeTextSuggester.ALPHA;
}
Collections.sort(expected, byScoreThenKey);
if (expected.size() > num) {
expected.subList(num, expected.size()).clear();
}
// Actual:
List<LookupResult> actual = sug.lookup(query, num);
if (VERBOSE) {
System.out.println(" expected: " + expected);
System.out.println(" actual: " + actual);
}
assertEquals(expected.toString(), actual.toString());
}
a.close();
}
private static String getZipfToken(String[] tokens) {
// Zipf-like distribution:
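// (token k is picked with probability ~1/2^(k+1); the last token absorbs the rest)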
for(int k=0;k<tokens.length;k++) {
if (random().nextBoolean() || k == tokens.length-1) {
return tokens[k];
}
}
assert false;
return null;
}
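// Renders results as "key/score key/score ...", normalizing scores back to
// [0, 1] by dividing by Long.MAX_VALUE: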
private static String toString(List<LookupResult> results) {
StringBuilder b = new StringBuilder();
for(LookupResult result : results) {
b.append(' ');
b.append(result.key);
b.append('/');
b.append(String.format(Locale.ROOT, "%.2f", ((double) result.value)/Long.MAX_VALUE));
}
return b.toString().trim();
}
}