blob: 27254d8579f2ad788be2cb2242ec89f89d9d26d3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.util.LuceneTestCase;
import static org.hamcrest.CoreMatchers.equalTo;
/**
 * Tests for {@link QueryVisitor}.
 *
 * <p>Covers term extraction (with and without MUST_NOT clauses), restricting extraction to a
 * single field, propagating boosts through sub-visitors, counting leaf query types, and a toy
 * "minimum matching term set" extractor built on
 * {@link QueryVisitor#getSubVisitor(BooleanClause.Occur, Query)}.
 */
public class TestQueryVisitor extends LuceneTestCase {

  /**
   * Shared query tree exercised by the tests below.
   *
   * <p>It mixes MUST / SHOULD / MUST_NOT clauses, nested boolean queries, boosts (including a
   * boost nested inside another boost), a phrase query, a span query and a prefix query. Every
   * term is on "field1" except the doubly-boosted "field2:term10".
   */
  private static final Query query = new BooleanQuery.Builder()
      .add(new TermQuery(new Term("field1", "t1")), BooleanClause.Occur.MUST)
      .add(new BooleanQuery.Builder()
          .add(new TermQuery(new Term("field1", "tm2")), BooleanClause.Occur.SHOULD)
          .add(new BoostQuery(new TermQuery(new Term("field1", "tm3")), 2), BooleanClause.Occur.SHOULD)
          .build(), BooleanClause.Occur.MUST)
      .add(new BoostQuery(new PhraseQuery.Builder()
          .add(new Term("field1", "term4"))
          .add(new Term("field1", "term5"))
          .build(), 3), BooleanClause.Occur.MUST)
      .add(new SpanNearQuery(new SpanQuery[]{
          new SpanTermQuery(new Term("field1", "term6")),
          new SpanTermQuery(new Term("field1", "term7"))
      }, 2, true), BooleanClause.Occur.MUST)
      .add(new TermQuery(new Term("field1", "term8")), BooleanClause.Occur.MUST_NOT)
      .add(new PrefixQuery(new Term("field1", "term9")), BooleanClause.Occur.SHOULD)
      .add(new BoostQuery(new BooleanQuery.Builder()
          .add(new BoostQuery(new TermQuery(new Term("field2", "term10")), 3), BooleanClause.Occur.MUST)
          .build(), 2), BooleanClause.Occur.SHOULD)
      .build();

  /**
   * {@link QueryVisitor#termCollector(Set)} collects every term except the MUST_NOT clause
   * (term8) and the prefix query's term9.
   */
  public void testExtractTermsEquivalent() {
    Set<Term> terms = new HashSet<>();
    Set<Term> expected = new HashSet<>(Arrays.asList(
        new Term("field1", "t1"), new Term("field1", "tm2"),
        new Term("field1", "tm3"), new Term("field1", "term4"),
        new Term("field1", "term5"), new Term("field1", "term6"),
        new Term("field1", "term7"), new Term("field2", "term10")
    ));
    query.visit(QueryVisitor.termCollector(terms));
    assertThat(terms, equalTo(expected));
  }

  /**
   * A visitor that returns itself for every clause type, including MUST_NOT, also sees the
   * prohibited term8.
   *
   * <p>Named with the {@code test} prefix so the test runner actually executes it. It was
   * previously {@code extractAllTerms} and was silently skipped.
   */
  public void testExtractAllTerms() {
    Set<Term> terms = new HashSet<>();
    QueryVisitor visitor = new QueryVisitor() {
      @Override
      public void consumeTerms(Query query, Term... ts) {
        terms.addAll(Arrays.asList(ts));
      }

      @Override
      public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) {
        // Descend into every clause, MUST_NOT included.
        return this;
      }
    };
    Set<Term> expected = new HashSet<>(Arrays.asList(
        new Term("field1", "t1"), new Term("field1", "tm2"),
        new Term("field1", "tm3"), new Term("field1", "term4"),
        new Term("field1", "term5"), new Term("field1", "term6"),
        new Term("field1", "term7"), new Term("field1", "term8"),
        new Term("field2", "term10")
    ));
    query.visit(visitor);
    assertThat(terms, equalTo(expected));
  }

  /**
   * {@link QueryVisitor#acceptField(String)} restricts term extraction to a single field.
   *
   * <p>Named with the {@code test} prefix so the test runner actually executes it. It was
   * previously {@code extractTermsFromField} and was silently skipped.
   */
  public void testExtractTermsFromField() {
    final Set<Term> actual = new HashSet<>();
    Set<Term> expected = new HashSet<>(Arrays.asList(new Term("field2", "term10")));
    query.visit(new QueryVisitor() {
      @Override
      public boolean acceptField(String field) {
        return "field2".equals(field);
      }

      @Override
      public void consumeTerms(Query query, Term... terms) {
        actual.addAll(Arrays.asList(terms));
      }
    });
    assertThat(actual, equalTo(expected));
  }

  /**
   * Visitor that records, for every consumed term, the product of all {@link BoostQuery}
   * boosts on the path from the root to that term.
   */
  static class BoostedTermExtractor extends QueryVisitor {
    /** Accumulated boost for terms consumed directly by this visitor. */
    final float boost;
    /** Shared output map, written by this visitor and every sub-visitor it creates. */
    final Map<Term, Float> termsToBoosts;

    BoostedTermExtractor(float boost, Map<Term, Float> termsToBoosts) {
      this.boost = boost;
      this.termsToBoosts = termsToBoosts;
    }

    @Override
    public void consumeTerms(Query query, Term... terms) {
      for (Term term : terms) {
        termsToBoosts.put(term, boost);
      }
    }

    @Override
    public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) {
      if (parent instanceof BoostQuery) {
        // Multiply in this level's boost. Nested boosts compound, e.g. 2 * 3 = 6 for term10.
        return new BoostedTermExtractor(boost * ((BoostQuery) parent).getBoost(), termsToBoosts);
      }
      return super.getSubVisitor(occur, parent);
    }
  }

  /** Boosts are propagated and multiplied through nested {@link BoostQuery} wrappers. */
  public void testExtractTermsAndBoosts() {
    Map<Term, Float> termsToBoosts = new HashMap<>();
    query.visit(new BoostedTermExtractor(1, termsToBoosts));
    Map<Term, Float> expected = new HashMap<>();
    expected.put(new Term("field1", "t1"), 1f);
    expected.put(new Term("field1", "tm2"), 1f);
    expected.put(new Term("field1", "tm3"), 2f);
    expected.put(new Term("field1", "term4"), 3f);
    expected.put(new Term("field1", "term5"), 3f);
    expected.put(new Term("field1", "term6"), 1f);
    expected.put(new Term("field1", "term7"), 1f);
    expected.put(new Term("field2", "term10"), 6f);
    assertThat(termsToBoosts, equalTo(expected));
  }

  /** Counts, per concrete query class, how many leaves call consumeTerms or visitLeaf. */
  public void testLeafQueryTypeCounts() {
    Map<Class<? extends Query>, Integer> queryCounts = new HashMap<>();
    query.visit(new QueryVisitor() {
      private void countQuery(Query q) {
        queryCounts.merge(q.getClass(), 1, Integer::sum);
      }

      @Override
      public void consumeTerms(Query query, Term... terms) {
        countQuery(query);
      }

      @Override
      public void visitLeaf(Query query) {
        countQuery(query);
      }
    });
    assertEquals(4, queryCounts.get(TermQuery.class).intValue());
    assertEquals(1, queryCounts.get(PhraseQuery.class).intValue());
  }

  /**
   * A toy visitor that turns a query into a tree of AND/OR nodes.
   *
   * <p>Once built, the tree can pick a small set of terms that a matching document must
   * contain, and step to alternative sets via {@link #nextTermSet()}.
   *
   * <p>Clause handling:
   * <ul>
   *   <li>MUST and FILTER clauses become {@link ConjunctionNode} children.</li>
   *   <li>MUST_NOT clauses are ignored.</li>
   *   <li>SHOULD clauses become {@link DisjunctionNode} children, unless the parent
   *       {@link BooleanQuery} also has MUST or FILTER clauses. In that case they are ignored,
   *       since they are optional for a match. This does not look at minimumNumberShouldMatch.</li>
   * </ul>
   */
  static abstract class QueryNode extends QueryVisitor {
    /** Orders nodes by ascending {@link #getWeight()}. */
    static final Comparator<QueryNode> BY_WEIGHT = Comparator.comparingInt(QueryNode::getWeight);

    final List<QueryNode> children = new ArrayList<>();

    /**
     * Returns the weight used to rank this node against its siblings. Lower weights are
     * preferred by {@link ConjunctionNode}.
     */
    abstract int getWeight();

    /** Adds the terms of this node's current candidate term set to {@code terms}. */
    abstract void collectTerms(Set<Term> terms);

    /** Advances to the next candidate term set; returns false if there is none. */
    abstract boolean nextTermSet();

    @Override
    public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) {
      if (occur == BooleanClause.Occur.MUST || occur == BooleanClause.Occur.FILTER) {
        QueryNode n = new ConjunctionNode();
        children.add(n);
        return n;
      }
      if (occur == BooleanClause.Occur.MUST_NOT) {
        return QueryVisitor.EMPTY_VISITOR;
      }
      if (parent instanceof BooleanQuery) {
        BooleanQuery bq = (BooleanQuery) parent;
        if (!bq.getClauses(BooleanClause.Occur.MUST).isEmpty()
            || !bq.getClauses(BooleanClause.Occur.FILTER).isEmpty()) {
          // SHOULD clauses are optional when required clauses exist, so skip them.
          return QueryVisitor.EMPTY_VISITOR;
        }
      }
      DisjunctionNode n = new DisjunctionNode();
      children.add(n);
      return n;
    }
  }

  /** Leaf node holding a single term; its weight is the length of the term text. */
  static class TermNode extends QueryNode {
    final Term term;

    TermNode(Term term) {
      this.term = term;
    }

    @Override
    int getWeight() {
      return term.text().length();
    }

    @Override
    void collectTerms(Set<Term> terms) {
      terms.add(term);
    }

    @Override
    boolean nextTermSet() {
      // A single term has no alternative.
      return false;
    }

    @Override
    public String toString() {
      return "TERM(" + term.toString() + ")";
    }
  }

  /**
   * AND node. Its term set is that of its lowest-weight child.
   *
   * <p>NOTE(review): assumes at least one child. A node whose clause only called visitLeaf
   * (no terms) would stay empty, and {@code children.get(0)} would throw.
   */
  static class ConjunctionNode extends QueryNode {
    /** Sorts children ascending by weight (stable), so index 0 is the lowest-weight child. */
    private void sortChildrenByWeight() {
      children.sort(BY_WEIGHT);
    }

    @Override
    int getWeight() {
      sortChildrenByWeight();
      return children.get(0).getWeight();
    }

    @Override
    void collectTerms(Set<Term> terms) {
      sortChildrenByWeight();
      children.get(0).collectTerms(terms);
    }

    @Override
    boolean nextTermSet() {
      sortChildrenByWeight();
      if (children.get(0).nextTermSet()) {
        return true;
      }
      if (children.size() == 1) {
        return false;
      }
      // The lowest-weight child is exhausted: drop it so the next-lowest child is used.
      children.remove(0);
      return true;
    }

    @Override
    public void consumeTerms(Query query, Term... terms) {
      for (Term term : terms) {
        children.add(new TermNode(term));
      }
    }

    @Override
    public String toString() {
      return children.stream().map(QueryNode::toString).collect(Collectors.joining(",", "AND(", ")"));
    }
  }

  /**
   * OR node. Its term set is the union of all children's term sets, and its weight is the
   * highest child weight.
   */
  static class DisjunctionNode extends QueryNode {
    @Override
    int getWeight() {
      // Descending sort: index 0 holds the highest-weight child.
      children.sort(BY_WEIGHT.reversed());
      return children.get(0).getWeight();
    }

    @Override
    void collectTerms(Set<Term> terms) {
      for (QueryNode child : children) {
        child.collectTerms(terms);
      }
    }

    @Override
    boolean nextTermSet() {
      // Advance every child (no short-circuit); report whether any of them moved.
      boolean next = false;
      for (QueryNode child : children) {
        next |= child.nextTermSet();
      }
      return next;
    }

    @Override
    public void consumeTerms(Query query, Term... terms) {
      for (Term term : terms) {
        children.add(new TermNode(term));
      }
    }

    @Override
    public String toString() {
      return children.stream().map(QueryNode::toString).collect(Collectors.joining(",", "OR(", ")"));
    }
  }

  /** Exercises the {@link QueryNode} tree on the shared query and on a smaller ad-hoc query. */
  public void testExtractMatchingTermSet() {
    QueryNode extractor = new ConjunctionNode();
    query.visit(extractor);
    Set<Term> minimumTermSet = new HashSet<>();
    extractor.collectTerms(minimumTermSet);

    Set<Term> expected1 = new HashSet<>(Collections.singletonList(new Term("field1", "t1")));
    assertThat(minimumTermSet, equalTo(expected1));
    assertTrue(extractor.nextTermSet());
    Set<Term> expected2 = new HashSet<>(Arrays.asList(new Term("field1", "tm2"), new Term("field1", "tm3")));
    minimumTermSet.clear();
    extractor.collectTerms(minimumTermSet);
    assertThat(minimumTermSet, equalTo(expected2));

    BooleanQuery bq = new BooleanQuery.Builder()
        .add(new BooleanQuery.Builder()
            .add(new TermQuery(new Term("f", "1")), BooleanClause.Occur.MUST)
            .add(new TermQuery(new Term("f", "61")), BooleanClause.Occur.MUST)
            .add(new TermQuery(new Term("f", "211")), BooleanClause.Occur.FILTER)
            .add(new TermQuery(new Term("f", "5")), BooleanClause.Occur.SHOULD)
            .build(), BooleanClause.Occur.SHOULD)
        .add(new PhraseQuery("f", "3333", "44444"), BooleanClause.Occur.SHOULD)
        .build();
    QueryNode ex2 = new ConjunctionNode();
    bq.visit(ex2);
    Set<Term> expected3 = new HashSet<>(Arrays.asList(new Term("f", "1"), new Term("f", "3333")));
    minimumTermSet.clear();
    ex2.collectTerms(minimumTermSet);
    assertThat(minimumTermSet, equalTo(expected3));
    ex2.getWeight(); // force sort order
    assertThat(ex2.toString(), equalTo("AND(AND(OR(AND(TERM(f:3333),TERM(f:44444)),AND(TERM(f:1),TERM(f:61),AND(TERM(f:211))))))"));
  }
}