blob: 9996daef8fd3d269ec02edfde3dc05a95ef7bb37 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.join;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
import org.apache.lucene.document.DoublePoint;
import org.apache.lucene.document.FloatPoint;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.index.PrefixCodedTerms;
import org.apache.lucene.index.PrefixCodedTerms.TermIterator;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.PointInSetQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.RamUsageEstimator;
// A TermsIncludingScoreQuery variant for point values:
/**
 * Matches documents whose one-dimensional point value in {@code field} equals one of a sorted set
 * of packed points, and gives each hit the join score that was supplied together with the
 * matching point.
 *
 * <p>Points and scores are read once, in the constructor, from a {@link Stream}. They are stored
 * in two parallel structures ({@code sortedPackedPoints} and {@code aggregatedJoinScores}) that
 * are always walked in lock-step.
 */
abstract class PointInSetIncludingScoreQuery extends Query implements Accountable {

  protected static final long BASE_RAM_BYTES =
      RamUsageEstimator.shallowSizeOfInstance(PointInSetIncludingScoreQuery.class);

  /**
   * Renders an encoded point as text by decoding dimension 0 with the point class that matches
   * {@code numericType} ({@code Integer}, {@code Long}, {@code Float} or {@code Double}).
   * Any other type yields the literal {@code "unsupported"}.
   */
  static BiFunction<byte[], Class<? extends Number>, String> toString = (value, numericType) -> {
    if (Integer.class.equals(numericType)) {
      return Integer.toString(IntPoint.decodeDimension(value, 0));
    } else if (Long.class.equals(numericType)) {
      return Long.toString(LongPoint.decodeDimension(value, 0));
    } else if (Float.class.equals(numericType)) {
      return Float.toString(FloatPoint.decodeDimension(value, 0));
    } else if (Double.class.equals(numericType)) {
      return Double.toString(DoublePoint.decodeDimension(value, 0));
    } else {
      return "unsupported";
    }
  };

  final ScoreMode scoreMode;
  final Query originalQuery;
  /**
   * When {@code true}, the first query point that matches a document fixes its score; when
   * {@code false}, every match overwrites the score previously stored for that document.
   */
  final boolean multipleValuesPerDocument;
  /** The query points, in the strictly increasing order enforced by the constructor. */
  final PrefixCodedTerms sortedPackedPoints;
  /** Hash of {@link #sortedPackedPoints}, computed once in the constructor. */
  final int sortedPackedPointsHashCode;
  final String field;
  final int bytesPerDim;
  /** One score per entry of {@link #sortedPackedPoints}, in the same order. */
  final List<Float> aggregatedJoinScores;
  // NOTE(review): the cached value below is built from field, originalQuery and
  // sortedPackedPoints only; aggregatedJoinScores is not part of the sum.
  private final long ramBytesUsed; // cache

  /**
   * Source of packed points for the constructor. The constructor reads {@link #score} right after
   * each call to {@code next()} and records it as the score of the point just returned.
   */
  static abstract class Stream extends PointInSetQuery.Stream {
    float score;
  }

  /**
   * Drains {@code packedPoints} into the sorted point set and the parallel score list.
   *
   * @param scoreMode join score mode; used for equality and hashing
   * @param originalQuery query this one was derived from; used for equality, hashing and RAM
   *     accounting
   * @param multipleValuesPerDocument see {@link #multipleValuesPerDocument}
   * @param field name of the point field to match against
   * @param bytesPerDim encoded width of one point value
   * @param packedPoints points in strictly increasing order, each exactly {@code bytesPerDim}
   *     bytes long
   * @throws IllegalArgumentException if {@code bytesPerDim} is out of range, a point has the wrong
   *     length, a point is duplicated, or the points are not in increasing order
   */
  PointInSetIncludingScoreQuery(ScoreMode scoreMode, Query originalQuery, boolean multipleValuesPerDocument,
                                String field, int bytesPerDim, Stream packedPoints) {
    this.scoreMode = scoreMode;
    this.originalQuery = originalQuery;
    this.multipleValuesPerDocument = multipleValuesPerDocument;
    this.field = field;
    if (bytesPerDim < 1 || bytesPerDim > PointValues.MAX_NUM_BYTES) {
      throw new IllegalArgumentException("bytesPerDim must be > 0 and <= " + PointValues.MAX_NUM_BYTES + "; got " + bytesPerDim);
    }
    this.bytesPerDim = bytesPerDim;

    aggregatedJoinScores = new ArrayList<>();
    PrefixCodedTerms.Builder builder = new PrefixCodedTerms.Builder();
    BytesRefBuilder previous = null;
    BytesRef current;
    while ((current = packedPoints.next()) != null) {
      if (current.length != bytesPerDim) {
        throw new IllegalArgumentException("packed point length should be " + (bytesPerDim) + " but got " + current.length + "; field=\"" + field + "\"bytesPerDim=" + bytesPerDim);
      }
      if (previous == null) {
        previous = new BytesRefBuilder();
      } else {
        int cmp = previous.get().compareTo(current);
        if (cmp == 0) {
          throw new IllegalArgumentException("unexpected duplicated value: " + current);
        } else if (cmp > 0) {
          // Format the BytesRef held by the builder, not the builder itself, so the message
          // shows the offending bytes and building it cannot fail.
          throw new IllegalArgumentException("values are out of order: saw " + previous.get() + " before " + current);
        }
      }
      builder.add(field, current);
      aggregatedJoinScores.add(packedPoints.score);
      // Copy, because the stream's BytesRef is only inspected before the next call to next().
      previous.copyBytes(current);
    }
    sortedPackedPoints = builder.finish();
    sortedPackedPointsHashCode = sortedPackedPoints.hashCode();

    this.ramBytesUsed = BASE_RAM_BYTES
        + RamUsageEstimator.sizeOfObject(this.field)
        + RamUsageEstimator.sizeOfObject(this.originalQuery, RamUsageEstimator.QUERY_DEFAULT_RAM_BYTES_USED)
        + RamUsageEstimator.sizeOfObject(this.sortedPackedPoints);
  }

  /** Reports this query as a leaf to {@code visitor} if the visitor accepts {@link #field}. */
  @Override
  public void visit(QueryVisitor visitor) {
    if (visitor.acceptField(field)) {
      visitor.visitLeaf(this);
    }
  }

  /**
   * Builds a weight whose scorer intersects the segment's point values with the query points and
   * returns the stored join score for each matching document.
   *
   * <p>NOTE(review): neither {@code scoreMode} nor {@code boost} is read by this method, so the
   * boost is not applied to the returned scores -- confirm this is intended.
   */
  @Override
  public final Weight createWeight(IndexSearcher searcher, org.apache.lucene.search.ScoreMode scoreMode, float boost) throws IOException {
    return new Weight(this) {

      @Override
      public void extractTerms(Set<Term> terms) {
        // Intentionally empty: no terms are reported for this query.
      }

      @Override
      public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        Scorer scorer = scorer(context);
        if (scorer != null) {
          int target = scorer.iterator().advance(doc);
          if (doc == target) {
            return Explanation.match(scorer.score(), "A match");
          }
        }
        return Explanation.noMatch("Not a match");
      }

      @Override
      public Scorer scorer(LeafReaderContext context) throws IOException {
        LeafReader reader = context.reader();

        FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(field);
        if (fieldInfo == null) {
          // The field does not exist in this segment.
          return null;
        }
        if (fieldInfo.getPointDimensionCount() != 1) {
          throw new IllegalArgumentException("field=\"" + field + "\" was indexed with numDims=" + fieldInfo.getPointDimensionCount() + " but this query has numDims=1");
        }
        if (fieldInfo.getPointNumBytes() != bytesPerDim) {
          throw new IllegalArgumentException("field=\"" + field + "\" was indexed with bytesPerDim=" + fieldInfo.getPointNumBytes() + " but this query has bytesPerDim=" + bytesPerDim);
        }

        PointValues values = reader.getPointValues(field);
        if (values == null) {
          return null;
        }

        // Matches and their scores are materialised eagerly, both sized by maxDoc.
        FixedBitSet result = new FixedBitSet(reader.maxDoc());
        float[] scores = new float[reader.maxDoc()];
        values.intersect(new MergePointVisitor(sortedPackedPoints, result, scores));

        return new Scorer(this) {

          final DocIdSetIterator disi = new BitSetIterator(result, 10L);

          @Override
          public float score() throws IOException {
            return scores[docID()];
          }

          @Override
          public float getMaxScore(int upTo) throws IOException {
            return Float.POSITIVE_INFINITY;
          }

          @Override
          public int docID() {
            return disi.docID();
          }

          @Override
          public DocIdSetIterator iterator() {
            return disi;
          }
        };
      }

      @Override
      public boolean isCacheable(LeafReaderContext ctx) {
        return true;
      }
    };
  }

  /**
   * Intersect visitor that merges the sorted query points with the point values it is shown,
   * setting a bit and storing a score for every document whose value equals a query point.
   *
   * <p>The visitor only ever moves forward through the query points, so it relies on values and
   * cells being presented in non-decreasing order.
   */
  private class MergePointVisitor implements IntersectVisitor {

    private final FixedBitSet result;
    private final float[] scores;
    private final TermIterator iterator;
    private final Iterator<Float> scoreIterator;
    /** Current query point, or {@code null} once all query points are consumed. */
    private BytesRef nextQueryPoint;
    /** Score that belongs to {@link #nextQueryPoint}. */
    private float nextScore;
    /** Reusable view over the byte[] handed to the visitor; only {@code bytes} is swapped. */
    private final BytesRef scratch = new BytesRef();

    private MergePointVisitor(PrefixCodedTerms sortedPackedPoints, FixedBitSet result, float[] scores) throws IOException {
      this.result = result;
      this.scores = scores;
      scratch.length = bytesPerDim;
      this.iterator = sortedPackedPoints.iterator();
      this.scoreIterator = aggregatedJoinScores.iterator();
      advanceQueryPoint();
    }

    /** Steps to the next query point and, when one remains, to its score. */
    private void advanceQueryPoint() {
      nextQueryPoint = iterator.next();
      if (scoreIterator.hasNext()) {
        nextScore = scoreIterator.next();
      }
    }

    @Override
    public void visit(int docID) {
      throw new IllegalStateException("shouldn't get here, since CELL_INSIDE_QUERY isn't emitted");
    }

    @Override
    public void visit(int docID, byte[] packedValue) {
      scratch.bytes = packedValue;
      while (nextQueryPoint != null) {
        int cmp = nextQueryPoint.compareTo(scratch);
        if (cmp == 0) {
          // Query point equals index point: collect. With multiple values per document the
          // first recorded score is kept; otherwise the score is simply (re)written.
          if (multipleValuesPerDocument == false || result.get(docID) == false) {
            result.set(docID);
            scores[docID] = nextScore;
          }
          return;
        } else if (cmp < 0) {
          // Query point is before index point, so move to the next query point.
          advanceQueryPoint();
        } else {
          // Query point is after index point: nothing to collect for this value.
          return;
        }
      }
    }

    @Override
    public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
      while (nextQueryPoint != null) {
        scratch.bytes = minPackedValue;
        int cmpMin = nextQueryPoint.compareTo(scratch);
        if (cmpMin < 0) {
          // Query point is before the start of this cell.
          advanceQueryPoint();
          continue;
        }
        scratch.bytes = maxPackedValue;
        int cmpMax = nextQueryPoint.compareTo(scratch);
        if (cmpMax > 0) {
          // Query point is after the end of this cell.
          return Relation.CELL_OUTSIDE_QUERY;
        }
        return Relation.CELL_CROSSES_QUERY;
      }
      // We exhausted all points in the query:
      return Relation.CELL_OUTSIDE_QUERY;
    }
  }

  // NOTE(review): hashCode/equals do not look at multipleValuesPerDocument or
  // aggregatedJoinScores -- confirm that two queries differing only in those may be equal.
  @Override
  public final int hashCode() {
    int hash = classHash();
    hash = 31 * hash + scoreMode.hashCode();
    hash = 31 * hash + field.hashCode();
    hash = 31 * hash + originalQuery.hashCode();
    hash = 31 * hash + sortedPackedPointsHashCode;
    hash = 31 * hash + bytesPerDim;
    return hash;
  }

  @Override
  public final boolean equals(Object other) {
    return sameClassAs(other) &&
        equalsTo(getClass().cast(other));
  }

  /** Field-by-field comparison; the cached hash is compared before the full point set. */
  private boolean equalsTo(PointInSetIncludingScoreQuery other) {
    return other.scoreMode.equals(scoreMode) &&
        other.field.equals(field) &&
        other.originalQuery.equals(originalQuery) &&
        other.bytesPerDim == bytesPerDim &&
        other.sortedPackedPointsHashCode == sortedPackedPointsHashCode &&
        other.sortedPackedPoints.equals(sortedPackedPoints);
  }

  /**
   * Formats the query as {@code field:{v1 v2 ...}}, omitting the {@code field:} prefix when the
   * given default field equals this query's field. Each value is rendered by
   * {@link #toString(byte[])}.
   */
  @Override
  public final String toString(String field) {
    final StringBuilder sb = new StringBuilder();
    if (this.field.equals(field) == false) {
      sb.append(this.field);
      sb.append(':');
    }

    sb.append("{");

    TermIterator iterator = sortedPackedPoints.iterator();
    byte[] pointBytes = new byte[bytesPerDim];
    boolean first = true;
    for (BytesRef point = iterator.next(); point != null; point = iterator.next()) {
      if (first == false) {
        sb.append(" ");
      }
      first = false;
      // Copy into a buffer that starts at offset 0 before handing it to toString(byte[]).
      System.arraycopy(point.bytes, point.offset, pointBytes, 0, pointBytes.length);
      sb.append(toString(pointBytes));
    }
    sb.append("}");
    return sb.toString();
  }

  /**
   * Renders one encoded point value as text.
   *
   * @param value the encoded point, {@code bytesPerDim} bytes long, starting at index 0
   */
  protected abstract String toString(byte[] value);

  /** Returns the RAM estimate computed once in the constructor. */
  @Override
  public long ramBytesUsed() {
    return ramBytesUsed;
  }
}