| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.search; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collection; |
| import java.util.Collections; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Objects; |
| import java.util.OptionalLong; |
| import java.util.stream.Stream; |
| import org.apache.lucene.search.BooleanClause.Occur; |
| import org.apache.lucene.search.Weight.DefaultBulkScorer; |
| import org.apache.lucene.util.Bits; |
| |
| final class BooleanScorerSupplier extends ScorerSupplier { |
| private final Map<BooleanClause.Occur, Collection<ScorerSupplier>> subs; |
| private final ScoreMode scoreMode; |
| private final int minShouldMatch; |
| private final int maxDoc; |
| private long cost = -1; |
| private boolean topLevelScoringClause; |
| |
| BooleanScorerSupplier( |
| Weight weight, |
| Map<Occur, Collection<ScorerSupplier>> subs, |
| ScoreMode scoreMode, |
| int minShouldMatch, |
| int maxDoc) { |
| if (minShouldMatch < 0) { |
| throw new IllegalArgumentException( |
| "minShouldMatch must be positive, but got: " + minShouldMatch); |
| } |
| if (minShouldMatch != 0 && minShouldMatch >= subs.get(Occur.SHOULD).size()) { |
| throw new IllegalArgumentException( |
| "minShouldMatch must be strictly less than the number of SHOULD clauses"); |
| } |
| if (scoreMode.needsScores() == false |
| && minShouldMatch == 0 |
| && subs.get(Occur.SHOULD).size() > 0 |
| && subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() > 0) { |
| throw new IllegalArgumentException( |
| "Cannot pass purely optional clauses if scores are not needed"); |
| } |
| if (subs.get(Occur.SHOULD).size() + subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() |
| == 0) { |
| throw new IllegalArgumentException("There should be at least one positive clause"); |
| } |
| this.subs = subs; |
| this.scoreMode = scoreMode; |
| this.minShouldMatch = minShouldMatch; |
| this.maxDoc = maxDoc; |
| } |
| |
| private long computeCost() { |
| OptionalLong minRequiredCost = |
| Stream.concat(subs.get(Occur.MUST).stream(), subs.get(Occur.FILTER).stream()) |
| .mapToLong(ScorerSupplier::cost) |
| .min(); |
| if (minRequiredCost.isPresent() && minShouldMatch == 0) { |
| return minRequiredCost.getAsLong(); |
| } else { |
| final Collection<ScorerSupplier> optionalScorers = subs.get(Occur.SHOULD); |
| final long shouldCost = |
| ScorerUtil.costWithMinShouldMatch( |
| optionalScorers.stream().mapToLong(ScorerSupplier::cost), |
| optionalScorers.size(), |
| minShouldMatch); |
| return Math.min(minRequiredCost.orElse(Long.MAX_VALUE), shouldCost); |
| } |
| } |
| |
| @Override |
| public void setTopLevelScoringClause() throws IOException { |
| topLevelScoringClause = true; |
| if (subs.get(Occur.SHOULD).size() + subs.get(Occur.MUST).size() == 1) { |
| // If there is a single scoring clause, propagate the call. |
| for (ScorerSupplier ss : subs.get(Occur.SHOULD)) { |
| ss.setTopLevelScoringClause(); |
| } |
| for (ScorerSupplier ss : subs.get(Occur.MUST)) { |
| ss.setTopLevelScoringClause(); |
| } |
| } |
| } |
| |
| @Override |
| public long cost() { |
| if (cost == -1) { |
| cost = computeCost(); |
| } |
| return cost; |
| } |
| |
| @Override |
| public Scorer get(long leadCost) throws IOException { |
| Scorer scorer = getInternal(leadCost); |
| if (scoreMode == ScoreMode.TOP_SCORES |
| && subs.get(Occur.SHOULD).isEmpty() |
| && subs.get(Occur.MUST).isEmpty()) { |
| // no scoring clauses but scores are needed so we wrap the scorer in |
| // a constant score in order to allow early termination |
| return scorer.twoPhaseIterator() != null |
| ? new ConstantScoreScorer(0f, scoreMode, scorer.twoPhaseIterator()) |
| : new ConstantScoreScorer(0f, scoreMode, scorer.iterator()); |
| } |
| return scorer; |
| } |
| |
| private Scorer getInternal(long leadCost) throws IOException { |
| // three cases: conjunction, disjunction, or mix |
| leadCost = Math.min(leadCost, cost()); |
| |
| // pure conjunction |
| if (subs.get(Occur.SHOULD).isEmpty()) { |
| return excl( |
| req(subs.get(Occur.FILTER), subs.get(Occur.MUST), leadCost, topLevelScoringClause), |
| subs.get(Occur.MUST_NOT), |
| leadCost); |
| } |
| |
| // pure disjunction |
| if (subs.get(Occur.FILTER).isEmpty() && subs.get(Occur.MUST).isEmpty()) { |
| return excl( |
| opt(subs.get(Occur.SHOULD), minShouldMatch, scoreMode, leadCost, topLevelScoringClause), |
| subs.get(Occur.MUST_NOT), |
| leadCost); |
| } |
| |
| // conjunction-disjunction mix: |
| // we create the required and optional pieces, and then |
| // combine the two: if minNrShouldMatch > 0, then it's a conjunction: because the |
| // optional side must match. otherwise it's required + optional |
| |
| if (minShouldMatch > 0) { |
| Scorer req = |
| excl( |
| req(subs.get(Occur.FILTER), subs.get(Occur.MUST), leadCost, false), |
| subs.get(Occur.MUST_NOT), |
| leadCost); |
| Scorer opt = opt(subs.get(Occur.SHOULD), minShouldMatch, scoreMode, leadCost, false); |
| return new ConjunctionScorer(Arrays.asList(req, opt), Arrays.asList(req, opt)); |
| } else { |
| assert scoreMode.needsScores(); |
| return new ReqOptSumScorer( |
| excl( |
| req(subs.get(Occur.FILTER), subs.get(Occur.MUST), leadCost, false), |
| subs.get(Occur.MUST_NOT), |
| leadCost), |
| opt(subs.get(Occur.SHOULD), minShouldMatch, scoreMode, leadCost, false), |
| scoreMode); |
| } |
| } |
| |
| @Override |
| public BulkScorer bulkScorer() throws IOException { |
| final BulkScorer bulkScorer = booleanScorer(); |
| if (bulkScorer != null) { |
| // bulk scoring is applicable, use it |
| return bulkScorer; |
| } else { |
| // use a Scorer-based impl (BS2) |
| return super.bulkScorer(); |
| } |
| } |
| |
| BulkScorer booleanScorer() throws IOException { |
| final int numOptionalClauses = subs.get(Occur.SHOULD).size(); |
| final int numRequiredClauses = subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size(); |
| |
| BulkScorer positiveScorer; |
| if (numRequiredClauses == 0) { |
| // TODO: what is the right heuristic here? |
| final long costThreshold; |
| if (minShouldMatch <= 1) { |
| // when all clauses are optional, use BooleanScorer aggressively |
| // TODO: is there actually a threshold under which we should rather |
| // use the regular scorer? |
| costThreshold = -1; |
| } else { |
| // when a minimum number of clauses should match, BooleanScorer is |
| // going to score all windows that have at least minNrShouldMatch |
| // matches in the window. But there is no way to know if there is |
| // an intersection (all clauses might match a different doc ID and |
| // there will be no matches in the end) so we should only use |
| // BooleanScorer if matches are very dense |
| costThreshold = maxDoc / 3; |
| } |
| |
| if (cost() < costThreshold) { |
| return null; |
| } |
| |
| positiveScorer = optionalBulkScorer(); |
| } else if (numRequiredClauses > 0 && numOptionalClauses == 0 && minShouldMatch == 0) { |
| positiveScorer = requiredBulkScorer(); |
| } else { |
| // TODO: there are some cases where BooleanScorer |
| // would handle conjunctions faster than |
| // BooleanScorer2... |
| return null; |
| } |
| |
| if (positiveScorer == null) { |
| return null; |
| } |
| final long positiveScorerCost = positiveScorer.cost(); |
| |
| List<Scorer> prohibited = new ArrayList<>(); |
| for (ScorerSupplier ss : subs.get(Occur.MUST_NOT)) { |
| prohibited.add(ss.get(positiveScorerCost)); |
| } |
| |
| if (prohibited.isEmpty()) { |
| return positiveScorer; |
| } else { |
| Scorer prohibitedScorer = |
| prohibited.size() == 1 |
| ? prohibited.get(0) |
| : new DisjunctionSumScorer(prohibited, ScoreMode.COMPLETE_NO_SCORES); |
| return new ReqExclBulkScorer(positiveScorer, prohibitedScorer); |
| } |
| } |
| |
| static BulkScorer disableScoring(final BulkScorer scorer) { |
| Objects.requireNonNull(scorer); |
| return new BulkScorer() { |
| |
| @Override |
| public int score(final LeafCollector collector, Bits acceptDocs, int min, int max) |
| throws IOException { |
| final LeafCollector noScoreCollector = |
| new LeafCollector() { |
| Score fake = new Score(); |
| |
| @Override |
| public void setScorer(Scorable scorer) throws IOException { |
| collector.setScorer(fake); |
| } |
| |
| @Override |
| public void collect(int doc) throws IOException { |
| collector.collect(doc); |
| } |
| }; |
| return scorer.score(noScoreCollector, acceptDocs, min, max); |
| } |
| |
| @Override |
| public long cost() { |
| return scorer.cost(); |
| } |
| }; |
| } |
| |
| // Return a BulkScorer for the optional clauses only, |
| // or null if it is not applicable |
| // pkg-private for forcing use of BooleanScorer in tests |
| BulkScorer optionalBulkScorer() throws IOException { |
| if (subs.get(Occur.SHOULD).size() == 0) { |
| return null; |
| } else if (subs.get(Occur.SHOULD).size() == 1 && minShouldMatch <= 1) { |
| return subs.get(Occur.SHOULD).iterator().next().bulkScorer(); |
| } |
| |
| if (scoreMode == ScoreMode.TOP_SCORES && minShouldMatch <= 1) { |
| List<Scorer> optionalScorers = new ArrayList<>(); |
| for (ScorerSupplier ss : subs.get(Occur.SHOULD)) { |
| optionalScorers.add(ss.get(Long.MAX_VALUE)); |
| } |
| |
| return new MaxScoreBulkScorer(maxDoc, optionalScorers); |
| } |
| |
| List<BulkScorer> optional = new ArrayList<BulkScorer>(); |
| for (ScorerSupplier ss : subs.get(Occur.SHOULD)) { |
| optional.add(ss.bulkScorer()); |
| } |
| |
| return new BooleanScorer(optional, Math.max(1, minShouldMatch), scoreMode.needsScores()); |
| } |
| |
| // Return a BulkScorer for the required clauses only |
| private BulkScorer requiredBulkScorer() throws IOException { |
| if (subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() == 0) { |
| // No required clauses at all. |
| return null; |
| } else if (subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() == 1) { |
| BulkScorer scorer; |
| if (subs.get(Occur.MUST).isEmpty() == false) { |
| scorer = subs.get(Occur.MUST).iterator().next().bulkScorer(); |
| } else { |
| scorer = subs.get(Occur.FILTER).iterator().next().bulkScorer(); |
| if (scoreMode.needsScores()) { |
| scorer = disableScoring(scorer); |
| } |
| } |
| return scorer; |
| } |
| |
| long leadCost = |
| subs.get(Occur.MUST).stream().mapToLong(ScorerSupplier::cost).min().orElse(Long.MAX_VALUE); |
| leadCost = |
| subs.get(Occur.FILTER).stream().mapToLong(ScorerSupplier::cost).min().orElse(leadCost); |
| |
| List<Scorer> requiredNoScoring = new ArrayList<>(); |
| for (ScorerSupplier ss : subs.get(Occur.FILTER)) { |
| requiredNoScoring.add(ss.get(leadCost)); |
| } |
| List<Scorer> requiredScoring = new ArrayList<>(); |
| Collection<ScorerSupplier> requiredScoringSupplier = subs.get(Occur.MUST); |
| for (ScorerSupplier ss : requiredScoringSupplier) { |
| if (requiredScoringSupplier.size() == 1) { |
| ss.setTopLevelScoringClause(); |
| } |
| requiredScoring.add(ss.get(leadCost)); |
| } |
| if (scoreMode == ScoreMode.TOP_SCORES |
| && requiredNoScoring.isEmpty() |
| && requiredScoring.size() > 1 |
| // Only specialize top-level conjunctions for clauses that don't have a two-phase iterator. |
| && requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) { |
| return new BlockMaxConjunctionBulkScorer(maxDoc, requiredScoring); |
| } |
| if (scoreMode != ScoreMode.TOP_SCORES |
| && requiredScoring.size() + requiredNoScoring.size() >= 2 |
| && requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull) |
| && requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) { |
| return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring); |
| } |
| if (scoreMode == ScoreMode.TOP_SCORES && requiredScoring.size() > 1) { |
| requiredScoring = Collections.singletonList(new BlockMaxConjunctionScorer(requiredScoring)); |
| } |
| Scorer conjunctionScorer; |
| if (requiredNoScoring.size() + requiredScoring.size() == 1) { |
| if (requiredScoring.size() == 1) { |
| conjunctionScorer = requiredScoring.get(0); |
| } else { |
| conjunctionScorer = requiredNoScoring.get(0); |
| if (scoreMode.needsScores()) { |
| Scorer inner = conjunctionScorer; |
| conjunctionScorer = |
| new FilterScorer(inner) { |
| @Override |
| public float score() throws IOException { |
| return 0f; |
| } |
| |
| @Override |
| public float getMaxScore(int upTo) throws IOException { |
| return 0f; |
| } |
| }; |
| } |
| } |
| } else { |
| List<Scorer> required = new ArrayList<>(); |
| required.addAll(requiredScoring); |
| required.addAll(requiredNoScoring); |
| conjunctionScorer = new ConjunctionScorer(required, requiredScoring); |
| } |
| return new DefaultBulkScorer(conjunctionScorer); |
| } |
| |
| /** |
| * Create a new scorer for the given required clauses. Note that {@code requiredScoring} is a |
| * subset of {@code required} containing required clauses that should participate in scoring. |
| */ |
| private Scorer req( |
| Collection<ScorerSupplier> requiredNoScoring, |
| Collection<ScorerSupplier> requiredScoring, |
| long leadCost, |
| boolean topLevelScoringClause) |
| throws IOException { |
| if (requiredNoScoring.size() + requiredScoring.size() == 1) { |
| Scorer req = |
| (requiredNoScoring.isEmpty() ? requiredScoring : requiredNoScoring) |
| .iterator() |
| .next() |
| .get(leadCost); |
| |
| if (scoreMode.needsScores() == false) { |
| return req; |
| } |
| |
| if (requiredScoring.isEmpty()) { |
| // Scores are needed but we only have a filter clause |
| // BooleanWeight expects that calling score() is ok so we need to wrap |
| // to prevent score() from being propagated |
| return new FilterScorer(req) { |
| @Override |
| public float score() throws IOException { |
| return 0f; |
| } |
| |
| @Override |
| public float getMaxScore(int upTo) throws IOException { |
| return 0f; |
| } |
| }; |
| } |
| |
| return req; |
| } else { |
| List<Scorer> requiredScorers = new ArrayList<>(); |
| List<Scorer> scoringScorers = new ArrayList<>(); |
| for (ScorerSupplier s : requiredNoScoring) { |
| requiredScorers.add(s.get(leadCost)); |
| } |
| for (ScorerSupplier s : requiredScoring) { |
| Scorer scorer = s.get(leadCost); |
| scoringScorers.add(scorer); |
| } |
| if (scoreMode == ScoreMode.TOP_SCORES && scoringScorers.size() > 1 && topLevelScoringClause) { |
| Scorer blockMaxScorer = new BlockMaxConjunctionScorer(scoringScorers); |
| if (requiredScorers.isEmpty()) { |
| return blockMaxScorer; |
| } |
| scoringScorers = Collections.singletonList(blockMaxScorer); |
| } |
| requiredScorers.addAll(scoringScorers); |
| return new ConjunctionScorer(requiredScorers, scoringScorers); |
| } |
| } |
| |
| private Scorer excl(Scorer main, Collection<ScorerSupplier> prohibited, long leadCost) |
| throws IOException { |
| if (prohibited.isEmpty()) { |
| return main; |
| } else { |
| return new ReqExclScorer( |
| main, opt(prohibited, 1, ScoreMode.COMPLETE_NO_SCORES, leadCost, false)); |
| } |
| } |
| |
| private Scorer opt( |
| Collection<ScorerSupplier> optional, |
| int minShouldMatch, |
| ScoreMode scoreMode, |
| long leadCost, |
| boolean topLevelScoringClause) |
| throws IOException { |
| if (optional.size() == 1) { |
| return optional.iterator().next().get(leadCost); |
| } else { |
| final List<Scorer> optionalScorers = new ArrayList<>(); |
| for (ScorerSupplier scorer : optional) { |
| optionalScorers.add(scorer.get(leadCost)); |
| } |
| |
| // Technically speaking, WANDScorer should be able to handle the following 3 conditions now |
| // 1. Any ScoreMode (with scoring or not) |
| // 2. Any minCompetitiveScore ( >= 0 ) |
| // 3. Any minShouldMatch ( >= 0 ) |
| // |
| // However, as WANDScorer uses more complex algorithm and data structure, we would like to |
| // still use DisjunctionSumScorer to handle exhaustive pure disjunctions, which may be faster |
| if ((scoreMode == ScoreMode.TOP_SCORES && topLevelScoringClause) || minShouldMatch > 1) { |
| return new WANDScorer(optionalScorers, minShouldMatch, scoreMode); |
| } else { |
| return new DisjunctionSumScorer(optionalScorers, scoreMode); |
| } |
| } |
| } |
| } |