| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.search.suggest.analyzing; |
| |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| import java.nio.file.Files; |
| import java.nio.file.Path; |
| import java.util.ArrayList; |
| import java.util.Collections; |
| import java.util.Comparator; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Locale; |
| import java.util.Map; |
| import java.util.Set; |
| |
| import org.apache.lucene.analysis.Analyzer; |
| import org.apache.lucene.analysis.CharArraySet; |
| import org.apache.lucene.analysis.MockAnalyzer; |
| import org.apache.lucene.analysis.MockTokenizer; |
| import org.apache.lucene.analysis.StopFilter; |
| import org.apache.lucene.analysis.Tokenizer; |
| import org.apache.lucene.document.Document; |
| import org.apache.lucene.search.suggest.Input; |
| import org.apache.lucene.search.suggest.InputArrayIterator; |
| import org.apache.lucene.search.suggest.InputIterator; |
| import org.apache.lucene.search.suggest.Lookup.LookupResult; |
| import org.apache.lucene.util.BytesRef; |
| import org.apache.lucene.util.LineFileDocs; |
| import org.apache.lucene.util.LuceneTestCase; |
| import org.apache.lucene.util.TestUtil; |
| import org.junit.Ignore; |
| |
| public class TestFreeTextSuggester extends LuceneTestCase { |
| |
| public void testBasic() throws Exception { |
| Iterable<Input> keys = AnalyzingSuggesterTest.shuffle( |
| new Input("foo bar baz blah", 50), |
| new Input("boo foo bar foo bee", 20) |
| ); |
| |
| Analyzer a = new MockAnalyzer(random()); |
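| // Bigram (grams=2) suggester with ' ' (0x20) as the ngram separator byte: |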
| FreeTextSuggester sug = new FreeTextSuggester(a, a, 2, (byte) 0x20); |
| sug.build(new InputArrayIterator(keys)); |
| assertEquals(2, sug.getCount()); |
| |
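| // Two passes: the store/load at the end of the first pass means the second pass |
| // verifies the suggester again after it has been reloaded from disk. |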
| for(int i=0;i<2;i++) { |
| |
| // Uses bigram model and unigram backoff: |
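| // Context "foo" occurs 3 times as a bigram head ("foo bar" x2, "foo bee" x1), giving |
| // 2/3 and 1/3; the unigrams baz/blah/boo (1 of 9 tokens each) are backed off by |
| // ALPHA=0.4, so 0.4 * 1/9 ~= 0.04: |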
| assertEquals("foo bar/0.67 foo bee/0.33 baz/0.04 blah/0.04 boo/0.04", |
| toString(sug.lookup("foo b", 10))); |
| |
| // Uses only bigram model: |
| assertEquals("foo bar/0.67 foo bee/0.33", |
| toString(sug.lookup("foo ", 10))); |
| |
| // Uses only unigram model: |
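| // "foo" is 3 of the 9 tokens across both inputs: 3/9 ~= 0.33: |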
| assertEquals("foo/0.33", |
| toString(sug.lookup("foo", 10))); |
| |
| // Uses only unigram model: |
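| // "bar" is 2 of the 9 tokens (0.22); baz, bee, blah and boo occur once each (0.11): |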
| assertEquals("bar/0.22 baz/0.11 bee/0.11 blah/0.11 boo/0.11", |
| toString(sug.lookup("b", 10))); |
| |
| // Try again after save/load: |
| Path tmpDir = createTempDir("FreeTextSuggesterTest"); |
| |
| Path path = tmpDir.resolve("suggester"); |
| |
| OutputStream os = Files.newOutputStream(path); |
| sug.store(os); |
| os.close(); |
| |
| InputStream is = Files.newInputStream(path); |
| sug = new FreeTextSuggester(a, a, 2, (byte) 0x20); |
| sug.load(is); |
| is.close(); |
| assertEquals(2, sug.getCount()); |
| } |
| a.close(); |
| } |
| |
| public void testIllegalByteDuringBuild() throws Exception { |
| // Default separator is INFORMATION SEPARATOR TWO |
| // (0x1e), so no input token is allowed to contain it |
| Iterable<Input> keys = AnalyzingSuggesterTest.shuffle( |
| new Input("foo\u001ebar baz", 50) |
| ); |
| Analyzer analyzer = new MockAnalyzer(random()); |
| FreeTextSuggester sug = new FreeTextSuggester(analyzer); |
| expectThrows(IllegalArgumentException.class, () -> { |
| sug.build(new InputArrayIterator(keys)); |
| }); |
| |
| analyzer.close(); |
| } |
| |
| public void testIllegalByteDuringQuery() throws Exception { |
| // Default separator is INFORMATION SEPARATOR TWO |
| // (0x1e), so no input token is allowed to contain it |
| Iterable<Input> keys = AnalyzingSuggesterTest.shuffle( |
| new Input("foo bar baz", 50) |
| ); |
| Analyzer analyzer = new MockAnalyzer(random()); |
| FreeTextSuggester sug = new FreeTextSuggester(analyzer); |
| sug.build(new InputArrayIterator(keys)); |
| |
| expectThrows(IllegalArgumentException.class, () -> { |
| sug.lookup("foo\u001eb", 10); |
| }); |
| |
| analyzer.close(); |
| } |
| |
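| // Needs a local enwiki line file at the hard-coded path below, so it stays disabled: |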
| @Ignore |
| public void testWiki() throws Exception { |
| final LineFileDocs lfd = new LineFileDocs(null, "/lucenedata/enwiki/enwiki-20120502-lines-1k.txt"); |
| // Skip header: |
| lfd.nextDoc(); |
| Analyzer analyzer = new MockAnalyzer(random()); |
| FreeTextSuggester sug = new FreeTextSuggester(analyzer); |
| sug.build(new InputIterator() { |
| |
| private int count; |
| |
| @Override |
| public long weight() { |
| return 1; |
| } |
| |
| @Override |
| public BytesRef next() { |
| Document doc; |
| try { |
| doc = lfd.nextDoc(); |
| } catch (IOException ioe) { |
| throw new RuntimeException(ioe); |
| } |
| if (doc == null) { |
| return null; |
| } |
| if (count++ == 10000) { |
| return null; |
| } |
| return new BytesRef(doc.get("body")); |
| } |
| |
| @Override |
| public BytesRef payload() { |
| return null; |
| } |
| |
| @Override |
| public boolean hasPayloads() { |
| return false; |
| } |
| |
| @Override |
| public Set<BytesRef> contexts() { |
| return null; |
| } |
| |
| @Override |
| public boolean hasContexts() { |
| return false; |
| } |
| }); |
| if (VERBOSE) { |
| System.out.println(sug.ramBytesUsed() + " bytes"); |
| |
| List<LookupResult> results = sug.lookup("general r", 10); |
| System.out.println("results:"); |
| for(LookupResult result : results) { |
| System.out.println(" " + result); |
| } |
| } |
| analyzer.close(); |
| lfd.close(); |
| } |
| |
| // Make sure you can suggest based only on the unigram model: |
| public void testUnigrams() throws Exception { |
| Iterable<Input> keys = AnalyzingSuggesterTest.shuffle( |
| new Input("foo bar baz blah boo foo bar foo bee", 50) |
| ); |
| |
| Analyzer a = new MockAnalyzer(random()); |
| FreeTextSuggester sug = new FreeTextSuggester(a, a, 1, (byte) 0x20); |
| sug.build(new InputArrayIterator(keys)); |
| // Results are sorted first by count, descending, then by term, ascending: |
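| // 9 tokens total: "bar" occurs twice (2/9 ~= 0.22), baz/bee/blah/boo once each (1/9 ~= 0.11): |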
| assertEquals("bar/0.22 baz/0.11 bee/0.11 blah/0.11 boo/0.11", |
| toString(sug.lookup("b", 10))); |
| a.close(); |
| } |
| |
| // Make sure the last token is not duplicated across gram orders: once "foo bar" is |
| // suggested, the unigram "bar" must not be suggested as well |
| public void testNoDupsAcrossGrams() throws Exception { |
| Iterable<Input> keys = AnalyzingSuggesterTest.shuffle( |
| new Input("foo bar bar bar bar", 50) |
| ); |
| Analyzer a = new MockAnalyzer(random()); |
| FreeTextSuggester sug = new FreeTextSuggester(a, a, 2, (byte) 0x20); |
| sug.build(new InputArrayIterator(keys)); |
| assertEquals("foo bar/1.00", |
| toString(sug.lookup("foo b", 10))); |
| a.close(); |
| } |
| |
| // Lookup of just the empty string is rejected with an IllegalArgumentException: |
| public void testEmptyString() throws Exception { |
| Iterable<Input> keys = AnalyzingSuggesterTest.shuffle( |
| new Input("foo bar bar bar bar", 50) |
| ); |
| Analyzer a = new MockAnalyzer(random()); |
| FreeTextSuggester sug = new FreeTextSuggester(a, a, 2, (byte) 0x20); |
| sug.build(new InputArrayIterator(keys)); |
| expectThrows(IllegalArgumentException.class, () -> { |
| sug.lookup("", 10); |
| }); |
| |
| a.close(); |
| } |
| |
| // With one ending hole, ShingleFilter produces "of _" and |
| // we should properly predict from that: |
| public void testEndingHole() throws Exception { |
| // Just deletes "of" |
| Analyzer a = new Analyzer() { |
| @Override |
| public TokenStreamComponents createComponents(String field) { |
| Tokenizer tokenizer = new MockTokenizer(); |
| CharArraySet stopSet = StopFilter.makeStopSet("of"); |
| return new TokenStreamComponents(tokenizer, new StopFilter(tokenizer, stopSet)); |
| } |
| }; |
| |
| Iterable<Input> keys = AnalyzingSuggesterTest.shuffle( |
| new Input("wizard of oz", 50) |
| ); |
| FreeTextSuggester sug = new FreeTextSuggester(a, a, 3, (byte) 0x20); |
| sug.build(new InputArrayIterator(keys)); |
| assertEquals("wizard _ oz/1.00", |
| toString(sug.lookup("wizard of", 10))); |
| |
| // Falls back to the unigram model, with backoff 0.4 times |
| // prob 0.5 ("oz" is one of the two tokens that survive the stop filter): |
| assertEquals("oz/0.20", |
| toString(sug.lookup("wizard o", 10))); |
| a.close(); |
| } |
| |
| // If the number of ending holes exceeds the ngram window |
| // then there are no predictions, because ShingleFilter |
| // does not produce, e.g., a hole-only "_ _" token: |
| public void testTwoEndingHoles() throws Exception { |
| // Just deletes "of" |
| Analyzer a = new Analyzer() { |
| @Override |
| public TokenStreamComponents createComponents(String field) { |
| Tokenizer tokenizer = new MockTokenizer(); |
| CharArraySet stopSet = StopFilter.makeStopSet("of"); |
| return new TokenStreamComponents(tokenizer, new StopFilter(tokenizer, stopSet)); |
| } |
| }; |
| |
| Iterable<Input> keys = AnalyzingSuggesterTest.shuffle( |
| new Input("wizard of of oz", 50) |
| ); |
| FreeTextSuggester sug = new FreeTextSuggester(a, a, 3, (byte) 0x20); |
| sug.build(new InputArrayIterator(keys)); |
| assertEquals("", |
| toString(sug.lookup("wizard of of", 10))); |
| a.close(); |
| } |
| |
| private static final Comparator<LookupResult> byScoreThenKey = new Comparator<LookupResult>() { |
| @Override |
| public int compare(LookupResult a, LookupResult b) { |
| if (a.value > b.value) { |
| return -1; |
| } else if (a.value < b.value) { |
| return 1; |
| } else { |
| // Tie break by UTF16 sort order: |
| return ((String) a.key).compareTo((String) b.key); |
| } |
| } |
| }; |
| |
| public void testRandom() throws IOException { |
| String[] terms = new String[TestUtil.nextInt(random(), 2, 10)]; |
| Set<String> seen = new HashSet<>(); |
| while (seen.size() < terms.length) { |
| String token = TestUtil.randomSimpleString(random(), 1, 5); |
| if (!seen.contains(token)) { |
| terms[seen.size()] = token; |
| seen.add(token); |
| } |
| } |
| |
| Analyzer a = new MockAnalyzer(random()); |
| |
| int numDocs = atLeast(10); |
| long totTokens = 0; |
| final String[][] docs = new String[numDocs][]; |
| for(int i=0;i<numDocs;i++) { |
| docs[i] = new String[atLeast(100)]; |
| if (VERBOSE) { |
| System.out.print(" doc " + i + ":"); |
| } |
| for(int j=0;j<docs[i].length;j++) { |
| docs[i][j] = getZipfToken(terms); |
| if (VERBOSE) { |
| System.out.print(" " + docs[i][j]); |
| } |
| } |
| if (VERBOSE) { |
| System.out.println(); |
| } |
| totTokens += docs[i].length; |
| } |
| |
| int grams = TestUtil.nextInt(random(), 1, 4); |
| |
| if (VERBOSE) { |
| System.out.println("TEST: " + terms.length + " terms; " + numDocs + " docs; " + grams + " grams"); |
| } |
| |
| // Build suggester model: |
| FreeTextSuggester sug = new FreeTextSuggester(a, a, grams, (byte) 0x20); |
| sug.build(new InputIterator() { |
| int upto; |
| |
| @Override |
| public BytesRef next() { |
| if (upto == docs.length) { |
| return null; |
| } else { |
| StringBuilder b = new StringBuilder(); |
| for(String token : docs[upto]) { |
| b.append(' '); |
| b.append(token); |
| } |
| upto++; |
| return new BytesRef(b.toString()); |
| } |
| } |
| |
| @Override |
| public long weight() { |
| return random().nextLong(); |
| } |
| |
| @Override |
| public BytesRef payload() { |
| return null; |
| } |
| |
| @Override |
| public boolean hasPayloads() { |
| return false; |
| } |
| |
| @Override |
| public Set<BytesRef> contexts() { |
| return null; |
| } |
| |
| @Override |
| public boolean hasContexts() { |
| return false; |
| } |
| }); |
| |
| // Build inefficient but hopefully correct model: |
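| // gramCounts.get(g) maps each (g+1)-gram, with its tokens joined by single spaces, |
| // to the number of times it occurs across all docs: |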
| List<Map<String,Integer>> gramCounts = new ArrayList<>(grams); |
| for(int gram=0;gram<grams;gram++) { |
| if (VERBOSE) { |
| System.out.println("TEST: build model for gram=" + gram); |
| } |
| Map<String,Integer> model = new HashMap<>(); |
| gramCounts.add(model); |
| for(String[] doc : docs) { |
| for(int i=0;i<doc.length-gram;i++) { |
| StringBuilder b = new StringBuilder(); |
| for(int j=i;j<=i+gram;j++) { |
| if (j > i) { |
| b.append(' '); |
| } |
| b.append(doc[j]); |
| } |
| String token = b.toString(); |
| Integer curCount = model.get(token); |
| if (curCount == null) { |
| model.put(token, 1); |
| } else { |
| model.put(token, 1 + curCount); |
| } |
| if (VERBOSE) { |
| System.out.println(" add '" + token + "' -> count=" + model.get(token)); |
| } |
| } |
| } |
| } |
| |
| int lookups = atLeast(100); |
| for(int iter=0;iter<lookups;iter++) { |
| String[] tokens = new String[TestUtil.nextInt(random(), 1, 5)]; |
| for(int i=0;i<tokens.length;i++) { |
| tokens[i] = getZipfToken(terms); |
| } |
| |
| // Maybe trim last token; be sure not to create the |
| // empty string: |
| int trimStart; |
| if (tokens.length == 1) { |
| trimStart = 1; |
| } else { |
| trimStart = 0; |
| } |
| int trimAt = TestUtil.nextInt(random(), trimStart, tokens[tokens.length - 1].length()); |
| tokens[tokens.length-1] = tokens[tokens.length-1].substring(0, trimAt); |
| |
| int num = TestUtil.nextInt(random(), 1, 100); |
| StringBuilder b = new StringBuilder(); |
| for(String token : tokens) { |
| b.append(' '); |
| b.append(token); |
| } |
| String query = b.toString(); |
| query = query.substring(1); |
| |
| if (VERBOSE) { |
| System.out.println("\nTEST: iter=" + iter + " query='" + query + "' num=" + num); |
| } |
| |
| // Expected: |
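| // Scores mirror the suggester's backoff: backoff * count(ngram) / count(context), |
| // scaled by Long.MAX_VALUE; backoff starts at 1.0 and is multiplied by |
| // FreeTextSuggester.ALPHA each time we fall through to a lower-order model. |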
| List<LookupResult> expected = new ArrayList<>(); |
| double backoff = 1.0; |
| seen = new HashSet<>(); |
| |
| if (VERBOSE) { |
| System.out.println(" compute expected"); |
| } |
| for(int i=grams-1;i>=0;i--) { |
| if (VERBOSE) { |
| System.out.println(" grams=" + i); |
| } |
| |
| if (tokens.length < i+1) { |
| // Don't have enough tokens to use this model |
| if (VERBOSE) { |
| System.out.println(" skip"); |
| } |
| continue; |
| } |
| |
| if (i == 0 && tokens[tokens.length-1].length() == 0) { |
| // Never suggest unigrams from empty string: |
| if (VERBOSE) { |
| System.out.println(" skip unigram priors only"); |
| } |
| continue; |
| } |
| |
| // Build up "context" ngram: |
| b = new StringBuilder(); |
| for(int j=tokens.length-i-1;j<tokens.length-1;j++) { |
| b.append(' '); |
| b.append(tokens[j]); |
| } |
| String context = b.toString(); |
| if (context.length() > 0) { |
| context = context.substring(1); |
| } |
| if (VERBOSE) { |
| System.out.println(" context='" + context + "'"); |
| } |
| long contextCount; |
| if (context.length() == 0) { |
| contextCount = totTokens; |
| } else { |
| Integer count = gramCounts.get(i-1).get(context); |
| if (count == null) { |
| // We never saw this context: |
| backoff *= FreeTextSuggester.ALPHA; |
| if (VERBOSE) { |
| System.out.println(" skip: never saw context"); |
| } |
| continue; |
| } |
| contextCount = count; |
| } |
| if (VERBOSE) { |
| System.out.println(" contextCount=" + contextCount); |
| } |
| Map<String,Integer> model = gramCounts.get(i); |
| |
| // First pass, gather all predictions for this model: |
| if (VERBOSE) { |
| System.out.println(" find terms w/ prefix=" + tokens[tokens.length-1]); |
| } |
| List<LookupResult> tmp = new ArrayList<>(); |
| for(String term : terms) { |
| if (term.startsWith(tokens[tokens.length-1])) { |
| if (VERBOSE) { |
| System.out.println(" term=" + term); |
| } |
| if (seen.contains(term)) { |
| if (VERBOSE) { |
| System.out.println(" skip seen"); |
| } |
| continue; |
| } |
| String ngram = (context + " " + term).trim(); |
| Integer count = model.get(ngram); |
| if (count != null) { |
| LookupResult lr = new LookupResult(ngram, (long) (Long.MAX_VALUE * (backoff * (double) count / contextCount))); |
| tmp.add(lr); |
| if (VERBOSE) { |
| System.out.println(" add tmp key='" + lr.key + "' score=" + lr.value); |
| } |
| } |
| } |
| } |
| |
| // Second pass, trim to only top N, and fold those |
| // into overall suggestions: |
| Collections.sort(tmp, byScoreThenKey); |
| if (tmp.size() > num) { |
| tmp.subList(num, tmp.size()).clear(); |
| } |
| for(LookupResult result : tmp) { |
| String key = result.key.toString(); |
| int idx = key.lastIndexOf(' '); |
| String lastToken; |
| if (idx != -1) { |
| lastToken = key.substring(idx+1); |
| } else { |
| lastToken = key; |
| } |
| if (!seen.contains(lastToken)) { |
| seen.add(lastToken); |
| expected.add(result); |
| if (VERBOSE) { |
| System.out.println(" keep key='" + result.key + "' score=" + result.value); |
| } |
| } |
| } |
| |
| backoff *= FreeTextSuggester.ALPHA; |
| } |
| |
| Collections.sort(expected, byScoreThenKey); |
| |
| if (expected.size() > num) { |
| expected.subList(num, expected.size()).clear(); |
| } |
| |
| // Actual: |
| List<LookupResult> actual = sug.lookup(query, num); |
| |
| if (VERBOSE) { |
| System.out.println(" expected: " + expected); |
| System.out.println(" actual: " + actual); |
| } |
| |
| assertEquals(expected.toString(), actual.toString()); |
| } |
| a.close(); |
| } |
| |
| private static String getZipfToken(String[] tokens) { |
| // Zipf-like distribution: |
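| // token k is returned with probability 1/2^(k+1); the last token absorbs the rest |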
| for(int k=0;k<tokens.length;k++) { |
| if (random().nextBoolean() || k == tokens.length-1) { |
| return tokens[k]; |
| } |
| } |
| assert false; |
| return null; |
| } |
| |
| private static String toString(List<LookupResult> results) { |
| StringBuilder b = new StringBuilder(); |
| for(LookupResult result : results) { |
| b.append(' '); |
| b.append(result.key); |
| b.append('/'); |
| b.append(String.format(Locale.ROOT, "%.2f", ((double) result.value)/Long.MAX_VALUE)); |
| } |
| return b.toString().trim(); |
| } |
| } |
| |