/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| package org.apache.lucene.search; |
| |
| |
| import java.io.IOException; |
| import java.util.Arrays; |
| import java.util.Collections; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.Map; |
| import java.util.Set; |
| |
| import org.apache.lucene.analysis.MockAnalyzer; |
| import org.apache.lucene.document.Document; |
| import org.apache.lucene.document.Field; |
| import org.apache.lucene.index.FieldInvertState; |
| import org.apache.lucene.index.LeafReaderContext; |
| import org.apache.lucene.index.RandomIndexWriter; |
| import org.apache.lucene.index.Term; |
| import org.apache.lucene.search.BooleanClause.Occur; |
| import org.apache.lucene.search.similarities.Similarity; |
| import org.apache.lucene.store.Directory; |
| import org.apache.lucene.store.RAMDirectory; |
| import org.apache.lucene.util.LuceneTestCase; |
| import org.junit.AfterClass; |
| import org.junit.BeforeClass; |
| import org.junit.Test; |
| |
| public class TestSubScorerFreqs extends LuceneTestCase { |
| |
| private static Directory dir; |
| private static IndexSearcher s; |
| |
| @BeforeClass |
| public static void makeIndex() throws Exception { |
| dir = new RAMDirectory(); |
| RandomIndexWriter w = new RandomIndexWriter( |
| random(), dir, newIndexWriterConfig(new MockAnalyzer(random())).setMergePolicy(newLogMergePolicy())); |
| // make sure we have more than one segment occationally |
| int num = atLeast(31); |
| for (int i = 0; i < num; i++) { |
| Document doc = new Document(); |
| doc.add(newTextField("f", "a b c d b c d c d d", Field.Store.NO)); |
| w.addDocument(doc); |
| |
| doc = new Document(); |
| doc.add(newTextField("f", "a b c d", Field.Store.NO)); |
| w.addDocument(doc); |
| } |
| |
| s = newSearcher(w.getReader()); |
| s.setSimilarity(new CountingSimilarity()); |
| w.close(); |
| } |
| |
| @AfterClass |
| public static void finish() throws Exception { |
| s.getIndexReader().close(); |
| s = null; |
| dir.close(); |
| dir = null; |
| } |
| |
| private static class CountingCollector extends FilterCollector { |
| public final Map<Integer, Map<Query, Float>> docCounts = new HashMap<>(); |
| |
| private final Map<Query, Scorable> subScorers = new HashMap<>(); |
| private final Set<String> relationships; |
| |
| public CountingCollector(Collector other) { |
| this(other, new HashSet<>(Arrays.asList("MUST", "SHOULD", "MUST_NOT"))); |
| } |
| |
| public CountingCollector(Collector other, Set<String> relationships) { |
| super(other); |
| this.relationships = relationships; |
| } |
| |
| public void setSubScorers(Scorable scorer) throws IOException { |
| scorer = AssertingScorable.unwrap(scorer); |
| for (Scorable.ChildScorable child : scorer.getChildren()) { |
| if (relationships.contains(child.relationship)) { |
| setSubScorers(child.child); |
| } |
| } |
| subScorers.put(((Scorer)scorer).getWeight().getQuery(), scorer); |
| } |
| |
| public LeafCollector getLeafCollector(LeafReaderContext context) |
| throws IOException { |
| final int docBase = context.docBase; |
| return new FilterLeafCollector(super.getLeafCollector(context)) { |
| |
| @Override |
| public void collect(int doc) throws IOException { |
| final Map<Query, Float> freqs = new HashMap<Query, Float>(); |
| for (Map.Entry<Query, Scorable> ent : subScorers.entrySet()) { |
| Scorable value = ent.getValue(); |
| int matchId = value.docID(); |
| freqs.put(ent.getKey(), matchId == doc ? value.score() : 0.0f); |
| } |
| docCounts.put(doc + docBase, freqs); |
| super.collect(doc); |
| } |
| |
| @Override |
| public void setScorer(Scorable scorer) throws IOException { |
| super.setScorer(scorer); |
| subScorers.clear(); |
| setSubScorers(scorer); |
| } |
| |
| }; |
| } |
| |
| } |
| |
| private static final float FLOAT_TOLERANCE = 0.00001F; |
| |
| @Test |
| public void testTermQuery() throws Exception { |
| TermQuery q = new TermQuery(new Term("f", "d")); |
| CountingCollector c = new CountingCollector(TopScoreDocCollector.create(10, Integer.MAX_VALUE)); |
| s.search(q, c); |
| final int maxDocs = s.getIndexReader().maxDoc(); |
| assertEquals(maxDocs, c.docCounts.size()); |
| for (int i = 0; i < maxDocs; i++) { |
| Map<Query, Float> doc0 = c.docCounts.get(i); |
| assertEquals(1, doc0.size()); |
| assertEquals(4.0F, doc0.get(q), FLOAT_TOLERANCE); |
| |
| Map<Query, Float> doc1 = c.docCounts.get(++i); |
| assertEquals(1, doc1.size()); |
| assertEquals(1.0F, doc1.get(q), FLOAT_TOLERANCE); |
| } |
| } |
| |
| @Test |
| public void testBooleanQuery() throws Exception { |
| TermQuery aQuery = new TermQuery(new Term("f", "a")); |
| TermQuery dQuery = new TermQuery(new Term("f", "d")); |
| TermQuery cQuery = new TermQuery(new Term("f", "c")); |
| TermQuery yQuery = new TermQuery(new Term("f", "y")); |
| |
| BooleanQuery.Builder query = new BooleanQuery.Builder(); |
| BooleanQuery.Builder inner = new BooleanQuery.Builder(); |
| |
| inner.add(cQuery, Occur.SHOULD); |
| inner.add(yQuery, Occur.MUST_NOT); |
| query.add(inner.build(), Occur.MUST); |
| query.add(aQuery, Occur.MUST); |
| query.add(dQuery, Occur.MUST); |
| |
| // Only needed in Java6; Java7+ has a @SafeVarargs annotated Arrays#asList()! |
| // see http://docs.oracle.com/javase/7/docs/api/java/lang/SafeVarargs.html |
| @SuppressWarnings("unchecked") final Iterable<Set<String>> occurList = Arrays.asList( |
| Collections.singleton("MUST"), |
| new HashSet<>(Arrays.asList("MUST", "SHOULD")) |
| ); |
| |
| for (final Set<String> occur : occurList) { |
| CountingCollector c = new CountingCollector(TopScoreDocCollector.create( |
| 10, Integer.MAX_VALUE), occur); |
| s.search(query.build(), c); |
| final int maxDocs = s.getIndexReader().maxDoc(); |
| assertEquals(maxDocs, c.docCounts.size()); |
| boolean includeOptional = occur.contains("SHOULD"); |
| for (int i = 0; i < maxDocs; i++) { |
| Map<Query, Float> doc0 = c.docCounts.get(i); |
| // Y doesnt exist in the index, so it's not in the scorer tree |
| assertEquals(4, doc0.size()); |
| assertEquals(1.0F, doc0.get(aQuery), FLOAT_TOLERANCE); |
| assertEquals(4.0F, doc0.get(dQuery), FLOAT_TOLERANCE); |
| if (includeOptional) { |
| assertEquals(3.0F, doc0.get(cQuery), FLOAT_TOLERANCE); |
| } |
| |
| Map<Query, Float> doc1 = c.docCounts.get(++i); |
| // Y doesnt exist in the index, so it's not in the scorer tree |
| assertEquals(4, doc1.size()); |
| assertEquals(1.0F, doc1.get(aQuery), FLOAT_TOLERANCE); |
| assertEquals(1.0F, doc1.get(dQuery), FLOAT_TOLERANCE); |
| if (includeOptional) { |
| assertEquals(1.0F, doc1.get(cQuery), FLOAT_TOLERANCE); |
| } |
| } |
| } |
| } |
| |
| @Test |
| public void testPhraseQuery() throws Exception { |
| PhraseQuery q = new PhraseQuery("f", "b", "c"); |
| CountingCollector c = new CountingCollector(TopScoreDocCollector.create(10, Integer.MAX_VALUE)); |
| s.search(q, c); |
| final int maxDocs = s.getIndexReader().maxDoc(); |
| assertEquals(maxDocs, c.docCounts.size()); |
| for (int i = 0; i < maxDocs; i++) { |
| Map<Query, Float> doc0 = c.docCounts.get(i); |
| assertEquals(1, doc0.size()); |
| assertEquals(2.0F, doc0.get(q), FLOAT_TOLERANCE); |
| |
| Map<Query, Float> doc1 = c.docCounts.get(++i); |
| assertEquals(1, doc1.size()); |
| assertEquals(1.0F, doc1.get(q), FLOAT_TOLERANCE); |
| } |
| |
| } |
| |
| // Similarity that just returns the frequency as the score |
| private static class CountingSimilarity extends Similarity { |
| |
| @Override |
| public long computeNorm(FieldInvertState state) { |
| return 1; |
| } |
| |
| @Override |
| public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) { |
| return new SimScorer() { |
| @Override |
| public float score(float freq, long norm) { |
| return freq; |
| } |
| }; |
| } |
| } |
| } |