blob: 9b0b24abdb82a5b2b224a3ca7a76aba50a61aba5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.intervals;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.MatchesIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.util.PriorityQueue;
class DisjunctionIntervalsSource extends IntervalsSource {
/**
 * The alternative sources of this disjunction. The constructor stores the result of
 * {@code simplify()}, which collects the sources into a {@code HashSet}, so duplicates
 * (by {@code equals}) are dropped and iteration order is not defined.
 */
final Collection<IntervalsSource> subSources;
/**
 * Selects what {@code pullUpDisjunctions()} returns: when {@code true} the individual
 * sub-sources, otherwise a singleton list holding this source itself.
 */
final boolean pullUpDisjunctions;
/**
 * Builds a source matching any of the given sources.
 *
 * <p>The input is first run through {@code simplify}; when exactly one source remains it is
 * returned as-is rather than being wrapped in a disjunction.
 *
 * @param subSources the alternative sources
 * @param pullUpDisjunctions flag stored on the created disjunction; it controls what the
 *     disjunction's {@code pullUpDisjunctions()} method returns
 * @return the single remaining source, or a new {@code DisjunctionIntervalsSource}
 */
static IntervalsSource create(Collection<IntervalsSource> subSources, boolean pullUpDisjunctions) {
  Collection<IntervalsSource> simplified = simplify(subSources);
  return simplified.size() == 1
      ? simplified.iterator().next()
      : new DisjunctionIntervalsSource(simplified, pullUpDisjunctions);
}
/**
 * Creates a disjunction over the simplified form of {@code subSources}.
 *
 * @param subSources the alternative sources; passed through {@code simplify} before being stored
 * @param pullUpDisjunctions value of the {@code pullUpDisjunctions} flag for this source
 */
private DisjunctionIntervalsSource(Collection<IntervalsSource> subSources, boolean pullUpDisjunctions) {
  this.pullUpDisjunctions = pullUpDisjunctions;
  this.subSources = simplify(subSources);
}
/**
 * Flattens and de-duplicates a collection of sources.
 *
 * <p>Each nested {@code DisjunctionIntervalsSource} contributes whatever its
 * {@code pullUpDisjunctions()} returns; every other source contributes itself. The result is a
 * {@code HashSet}, so equal sources collapse into one entry.
 *
 * @param sources the sources to simplify
 * @return a new set holding the flattened sources
 */
private static Collection<IntervalsSource> simplify(Collection<IntervalsSource> sources) {
  Set<IntervalsSource> flattened = new HashSet<>();
  for (IntervalsSource candidate : sources) {
    Collection<IntervalsSource> contribution =
        candidate instanceof DisjunctionIntervalsSource
            ? candidate.pullUpDisjunctions()
            : Collections.singleton(candidate);
    flattened.addAll(contribution);
  }
  return flattened;
}
/**
 * Returns an iterator over the intervals of all sub-sources for the given field and segment.
 *
 * <p>Sub-sources that return {@code null} are skipped.
 *
 * @return a {@code DisjunctionIntervalIterator} over the non-null sub-iterators, or {@code null}
 *     if every sub-source returned {@code null}
 */
@Override
public IntervalIterator intervals(String field, LeafReaderContext ctx) throws IOException {
  List<IntervalIterator> present = new ArrayList<>();
  for (IntervalsSource source : subSources) {
    IntervalIterator candidate = source.intervals(field, ctx);
    if (candidate == null) {
      continue;
    }
    present.add(candidate);
  }
  return present.isEmpty() ? null : new DisjunctionIntervalIterator(present);
}
/**
 * Returns a matches iterator over the sub-sources that match the given document.
 *
 * <p>The non-null sub-matches are wrapped as interval iterators and combined in a
 * {@code DisjunctionIntervalIterator}, which is then advanced to {@code doc}.
 *
 * @return the combined matches, or {@code null} if no sub-source produced matches or the
 *     combined iterator does not land on {@code doc}
 */
@Override
public IntervalMatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException {
  List<IntervalMatchesIterator> found = new ArrayList<>();
  for (IntervalsSource source : subSources) {
    IntervalMatchesIterator candidate = source.matches(field, ctx, doc);
    if (candidate == null) {
      continue;
    }
    found.add(candidate);
  }
  if (found.isEmpty()) {
    return null;
  }
  // Wrap in the same order as 'found' so that iterator ordinals index into 'found'.
  List<IntervalIterator> wrapped = new ArrayList<>(found.size());
  for (IntervalMatchesIterator match : found) {
    wrapped.add(IntervalMatches.wrapMatches(match, doc));
  }
  DisjunctionIntervalIterator union = new DisjunctionIntervalIterator(wrapped);
  return union.advance(doc) == doc ? new DisjunctionMatchesIterator(union, found) : null;
}
/**
 * Two disjunctions are equal when they are of exactly the same class and their
 * {@code subSources} collections are equal. The {@code pullUpDisjunctions} flag is not compared.
 */
@Override
public boolean equals(Object o) {
  if (o == this) {
    return true;
  }
  if (o == null) {
    return false;
  }
  if (o.getClass() != getClass()) {
    return false;
  }
  DisjunctionIntervalsSource other = (DisjunctionIntervalsSource) o;
  return Objects.equals(subSources, other.subSources);
}
/** Hash derived from {@code subSources} only, matching the fields compared by {@code equals}. */
@Override
public int hashCode() {
  // Same value as Objects.hash(subSources): a one-element hash is 31 * 1 + elementHash.
  return 31 + Objects.hashCode(subSources);
}
/**
 * Renders this source as {@code or(a,b,...)}, with the sub-source strings sorted so the output
 * does not depend on the iteration order of {@code subSources}.
 */
@Override
public String toString() {
  List<String> rendered = new ArrayList<>();
  for (IntervalsSource source : subSources) {
    rendered.add(source.toString());
  }
  Collections.sort(rendered);
  return "or(" + String.join(",", rendered) + ")";
}
/**
 * Visits every sub-source through a {@code SHOULD} sub-visitor obtained from {@code visitor},
 * using an {@code IntervalQuery} over this source as the parent query.
 */
@Override
public void visit(String field, QueryVisitor visitor) {
  QueryVisitor shouldVisitor =
      visitor.getSubVisitor(BooleanClause.Occur.SHOULD, new IntervalQuery(field, this));
  subSources.forEach(source -> source.visit(field, shouldVisitor));
}
/**
 * Returns the smallest {@code minExtent()} among the sub-sources, or
 * {@code Integer.MAX_VALUE} when there are no sub-sources.
 */
@Override
public int minExtent() {
  return subSources.stream()
      .mapToInt(IntervalsSource::minExtent)
      .min()
      .orElse(Integer.MAX_VALUE);
}
/**
 * Returns the sources a caller should use in place of this disjunction: the individual
 * sub-sources when the {@code pullUpDisjunctions} flag is set, otherwise a singleton list
 * containing this source.
 */
@Override
public Collection<IntervalsSource> pullUpDisjunctions() {
  if (!pullUpDisjunctions) {
    return Collections.singletonList(this);
  }
  return subSources;
}
/**
 * Interval iterator over the union of several sub-iterators.
 *
 * <p>Document iteration is delegated to a {@code DisjunctionDISIApproximation} built over all
 * sub-iterators. Within a document, the sub-iterators taken from {@code disiQueue.topList()}
 * (presumably those positioned on the current document; confirm against
 * {@code DisiPriorityQueue}) are kept in {@code intervalQueue}, and {@code current} refers to
 * the sub-iterator whose interval is currently exposed through {@code start()}, {@code end()}
 * and {@code gaps()}.
 */
static class DisjunctionIntervalIterator extends IntervalIterator {
// Document-level view over all sub-iterators; docID/nextDoc/advance/cost delegate to it.
final DocIdSetIterator approximation;
// Sub-iterators participating in the current document, ordered by their current interval.
final PriorityQueue<IntervalIterator> intervalQueue;
// All sub-iterators wrapped in DisiWrapper, backing the approximation.
final DisiPriorityQueue disiQueue;
// The sub-iterators in constructor order; currentOrd() returns an index into this list.
final List<IntervalIterator> iterators;
// Precomputed in the constructor and returned by matchCost().
final float matchCost;
// EMPTY until the first nextInterval() after a reset; EXHAUSTED once the queue has run dry.
IntervalIterator current = EMPTY;
/**
 * Creates a disjunction over the given sub-iterators.
 *
 * @param iterators the sub-iterators; the list is retained (not copied) for ordinal lookups
 */
DisjunctionIntervalIterator(List<IntervalIterator> iterators) {
this.disiQueue = new DisiPriorityQueue(iterators.size());
for (IntervalIterator it : iterators) {
disiQueue.add(new DisiWrapper(it));
}
this.approximation = new DisjunctionDISIApproximation(disiQueue);
this.iterators = iterators;
this.intervalQueue = new PriorityQueue<IntervalIterator>(iterators.size()) {
@Override
protected boolean lessThan(IntervalIterator a, IntervalIterator b) {
// a sorts first if it ends earlier, or ends at the same position and starts no earlier
// (i.e. is not wider) than b.
return a.end() < b.end() || (a.end() == b.end() && a.start() >= b.start());
}
};
// NOTE(review): this sums cost() of the sub-iterators, not matchCost() - confirm intended.
float costsum = 0;
for (IntervalIterator it : iterators) {
costsum += it.cost();
}
this.matchCost = costsum;
}
/** Returns the value precomputed in the constructor. */
@Override
public float matchCost() {
return matchCost;
}
/** Start of the current interval, delegated to {@code current}. */
@Override
public int start() {
return current.start();
}
/** End of the current interval, delegated to {@code current}. */
@Override
public int end() {
return current.end();
}
/** Gaps of the current interval, delegated to {@code current}. */
@Override
public int gaps() {
return current.gaps();
}
/**
 * Rebuilds {@code intervalQueue} after the approximation has moved: every iterator in
 * {@code disiQueue.topList()} is advanced to its next interval and queued (regardless of what
 * {@code nextInterval()} returned), and {@code current} is set back to {@code EMPTY}.
 */
private void reset() throws IOException {
intervalQueue.clear();
for (DisiWrapper dw = disiQueue.topList(); dw != null; dw = dw.next) {
dw.intervals.nextInterval();
intervalQueue.add(dw.intervals);
}
current = EMPTY;
}
/**
 * Returns the index in {@code iterators} of the sub-iterator that is {@code current},
 * compared by identity.
 *
 * @throws IllegalStateException if {@code current} is not one of {@code iterators}
 */
int currentOrd() {
assert current != EMPTY && current != EXHAUSTED;
for (int i = 0; i < iterators.size(); i++) {
if (iterators.get(i) == current) {
return i;
}
}
throw new IllegalStateException();
}
/**
 * Moves to the next interval of the disjunction in the current document.
 *
 * <p>If {@code current} is {@code EMPTY} or {@code EXHAUSTED}, nothing is advanced: the top of
 * the queue (if any) becomes {@code current} and its start is returned. Otherwise every queued
 * iterator whose interval contains the interval just returned is popped and advanced, and is
 * re-queued only if it still has an interval. The new top of the queue then becomes
 * {@code current}.
 *
 * @return the start of the new current interval, or {@code NO_MORE_INTERVALS} once the queue
 *     is empty
 */
@Override
public int nextInterval() throws IOException {
if (current == EMPTY || current == EXHAUSTED) {
if (intervalQueue.size() > 0) {
current = intervalQueue.top();
}
return current.start();
}
// Capture the bounds first: advancing the queue top below may advance 'current' itself.
int start = current.start(), end = current.end();
while (intervalQueue.size() > 0 && contains(intervalQueue.top(), start, end)) {
IntervalIterator it = intervalQueue.pop();
if (it != null && it.nextInterval() != NO_MORE_INTERVALS) {
intervalQueue.add(it);
}
}
if (intervalQueue.size() == 0) {
current = EXHAUSTED;
return NO_MORE_INTERVALS;
}
current = intervalQueue.top();
return current.start();
}
/** Returns true if both {@code start} and {@code end} lie within {@code [it.start(), it.end()]}. */
private boolean contains(IntervalIterator it, int start, int end) {
return start >= it.start() && start <= it.end() && end >= it.start() && end <= it.end();
}
/** Current document of the approximation. */
@Override
public int docID() {
return approximation.docID();
}
/** Advances the approximation to its next document, then resets the interval state. */
@Override
public int nextDoc() throws IOException {
int doc = approximation.nextDoc();
reset();
return doc;
}
/** Advances the approximation to {@code target}, then resets the interval state. */
@Override
public int advance(int target) throws IOException {
int doc = approximation.advance(target);
reset();
return doc;
}
/** Cost of the approximation. */
@Override
public long cost() {
return approximation.cost();
}
}
/**
 * Placeholder used as the {@code current} iterator of a {@code DisjunctionIntervalIterator}
 * before the first interval of a document has been requested. Reports {@code -1} for both
 * interval bounds and has no intervals; document-level operations and {@code gaps()} are
 * unsupported.
 */
private static final IntervalIterator EMPTY = new IntervalIterator() {

  // --- interval-level view ---

  @Override
  public int start() {
    return -1;
  }

  @Override
  public int end() {
    return -1;
  }

  @Override
  public int nextInterval() {
    return NO_MORE_INTERVALS;
  }

  @Override
  public int gaps() {
    throw new UnsupportedOperationException();
  }

  @Override
  public float matchCost() {
    return 0;
  }

  // --- document-level view: never valid on this placeholder ---

  @Override
  public int docID() {
    throw new UnsupportedOperationException();
  }

  @Override
  public int nextDoc() {
    throw new UnsupportedOperationException();
  }

  @Override
  public int advance(int target) {
    throw new UnsupportedOperationException();
  }

  @Override
  public long cost() {
    throw new UnsupportedOperationException();
  }
};
/**
 * Placeholder used as the {@code current} iterator of a {@code DisjunctionIntervalIterator}
 * once all intervals of a document have been consumed. Reports {@code NO_MORE_INTERVALS} for
 * both interval bounds; document-level operations and {@code gaps()} are unsupported.
 */
private static final IntervalIterator EXHAUSTED = new IntervalIterator() {

  // --- interval-level view ---

  @Override
  public int start() {
    return NO_MORE_INTERVALS;
  }

  @Override
  public int end() {
    return NO_MORE_INTERVALS;
  }

  @Override
  public int nextInterval() {
    return NO_MORE_INTERVALS;
  }

  @Override
  public int gaps() {
    throw new UnsupportedOperationException();
  }

  @Override
  public float matchCost() {
    return 0;
  }

  // --- document-level view: never valid on this placeholder ---

  @Override
  public int docID() {
    throw new UnsupportedOperationException();
  }

  @Override
  public int nextDoc() {
    throw new UnsupportedOperationException();
  }

  @Override
  public int advance(int target) {
    throw new UnsupportedOperationException();
  }

  @Override
  public long cost() {
    throw new UnsupportedOperationException();
  }
};
/**
 * Matches iterator backed by a {@code DisjunctionIntervalIterator}.
 *
 * <p>Positions, gaps and width come from the interval iterator; offsets, sub-matches and the
 * query come from the sub-match whose ordinal the interval iterator currently reports.
 */
private static class DisjunctionMatchesIterator implements IntervalMatchesIterator {
  final DisjunctionIntervalIterator it;
  final List<IntervalMatchesIterator> subs;

  /**
   * @param it the disjunction over the wrapped sub-matches
   * @param subs the sub-matches, indexed by the ordinals that {@code it.currentOrd()} returns
   */
  private DisjunctionMatchesIterator(DisjunctionIntervalIterator it, List<IntervalMatchesIterator> subs) {
    this.it = it;
    this.subs = subs;
  }

  /** Looks up the sub-match that supplies the interval currently exposed by {@code it}. */
  private IntervalMatchesIterator active() {
    return subs.get(it.currentOrd());
  }

  @Override
  public boolean next() throws IOException {
    int nextStart = it.nextInterval();
    return nextStart != IntervalIterator.NO_MORE_INTERVALS;
  }

  @Override
  public int startPosition() {
    return it.start();
  }

  @Override
  public int endPosition() {
    return it.end();
  }

  @Override
  public int startOffset() throws IOException {
    return active().startOffset();
  }

  @Override
  public int endOffset() throws IOException {
    return active().endOffset();
  }

  @Override
  public MatchesIterator getSubMatches() throws IOException {
    return active().getSubMatches();
  }

  @Override
  public Query getQuery() {
    return active().getQuery();
  }

  @Override
  public int gaps() {
    return it.gaps();
  }

  @Override
  public int width() {
    return it.width();
  }
}
}