blob: 456a885b49a0fde7b6cb3f02a7c59ca0b822b8a7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.join;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
/**
* kNN byte vector query that joins matching children vector documents with their parent doc id. The
* top documents returned are the child document ids and the calculated scores. Here is how to use
* this in conjunction with {@link ToParentBlockJoinQuery}.
*
* <pre class="prettyprint">
* Query knnQuery = new DiversifyingChildrenByteKnnVectorQuery(fieldName, queryVector, ...);
* // Rewrite executes kNN search and collects nearest children docIds and their scores
* Query rewrittenKnnQuery = searcher.rewrite(knnQuery);
* // Join the scored children docs with their parents and score the parents
* Query childrenToParents = new ToParentBlockJoinQuery(rewrittenKnnQuery, parentsFilter, ScoreMode.MAX);
* </pre>
*/
public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
  /** Shared empty result returned whenever a segment cannot contribute any hits. */
  private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

  private final BitSetProducer parentsFilter;
  private final Query childFilter;
  private final int k;
  private final byte[] query;

  /**
   * Create a DiversifyingChildrenByteKnnVectorQuery.
   *
   * @param field the query field
   * @param query the vector query
   * @param childFilter the child filter
   * @param k how many parent documents to return given the matching children
   * @param parentsFilter Filter identifying the parent documents.
   */
  public DiversifyingChildrenByteKnnVectorQuery(
      String field, byte[] query, Query childFilter, int k, BitSetProducer parentsFilter) {
    super(field, query, k, childFilter);
    this.childFilter = childFilter;
    this.parentsFilter = parentsFilter;
    this.k = k;
    this.query = query;
  }

  /**
   * Exhaustively scores the accepted documents of one segment, parent by parent, and keeps the
   * best {@code k} entries. Each collected hit carries the doc id reported by {@code bestChild()}
   * together with the score reported for the current parent.
   *
   * @param context the segment to search
   * @param acceptIterator iterator over the documents that are allowed to match
   * @param queryTimeout optional timeout; when it fires, collection stops early and the total hit
   *     relation becomes {@link TotalHits.Relation#GREATER_THAN_OR_EQUAL_TO}
   * @return the collected hits, or an empty result when the segment has no byte vectors for the
   *     field, no parent bit set, or no scorer for the query vector
   * @throws IOException if reading from the segment fails
   */
  @Override
  protected TopDocs exactSearch(
      LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
      throws IOException {
    ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
    if (byteVectorValues == null) {
      // NOTE(review): presumably raises if the field exists but is not a byte-vector field.
      ByteVectorValues.checkField(context.reader(), field);
      return NO_RESULTS;
    }
    BitSet parentBitSet = parentsFilter.getBitSet(context);
    if (parentBitSet == null) {
      return NO_RESULTS;
    }
    VectorScorer scorer = byteVectorValues.scorer(query);
    if (scorer == null) {
      return NO_RESULTS;
    }
    DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer vectorScorer =
        new DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer(
            acceptIterator, parentBitSet, scorer);
    // cost() is only an estimate and is not guaranteed to fit in an int. The minimum with k always
    // does, so clamp in long space and narrow afterwards instead of calling Math.toIntExact on the
    // raw cost, which would throw ArithmeticException for large estimates.
    final int queueSize = (int) Math.min(k, acceptIterator.cost());
    // The queue is pre-populated with sentinel entries, so the loop below only ever overwrites
    // the current top instead of inserting.
    HitQueue queue = new HitQueue(queueSize, true);
    TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
    ScoreDoc topDoc = queue.top();
    while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) {
      // Mark results as partial if timeout is met
      if (queryTimeout != null && queryTimeout.shouldExit()) {
        relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
        break;
      }
      float score = vectorScorer.score();
      if (score > topDoc.score) {
        // Reuse the current top entry in place, then let the queue re-establish its ordering.
        topDoc.score = score;
        topDoc.doc = vectorScorer.bestChild();
        topDoc = queue.updateTop();
      }
    }
    // Remove any remaining sentinel values
    while (queue.size() > 0 && queue.top().score < 0) {
      queue.pop();
    }
    // Fill the array from the back so the entry popped last ends up at index 0.
    ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
    for (int i = topScoreDocs.length - 1; i >= 0; i--) {
      topScoreDocs[i] = queue.pop();
    }
    TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
    return new TopDocs(totalHits, topScoreDocs);
  }

  /**
   * Supplies the collector manager used for approximate search.
   *
   * @param k the number of results requested
   * @param searcher the searcher executing the query
   * @return a {@link DiversifyingNearestChildrenKnnCollectorManager} built from {@code k}, this
   *     query's parents filter and {@code searcher}
   */
  @Override
  protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
    return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher);
  }

  /**
   * Runs the reader's nearest-vector search for one segment using a collector obtained from
   * {@code knnCollectorManager}.
   *
   * @param context the segment to search
   * @param acceptDocs the documents that are allowed to match
   * @param visitedLimit limit handed to the collector manager when creating the collector
   * @param knnCollectorManager factory for the per-segment collector
   * @return the collector's top docs, or an empty result when no collector is created for this
   *     segment
   * @throws IOException if reading from the segment fails
   */
  @Override
  protected TopDocs approximateSearch(
      LeafReaderContext context,
      Bits acceptDocs,
      int visitedLimit,
      KnnCollectorManager knnCollectorManager)
      throws IOException {
    ByteVectorValues.checkField(context.reader(), field);
    KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
    if (collector == null) {
      return NO_RESULTS;
    }
    context.reader().searchNearestVectors(field, query, collector, acceptDocs);
    return collector.topDocs();
  }

  /**
   * Renders the query as {@code SimpleClassName:field[firstComponent,...][k]}. The {@code field}
   * argument is not used; the query's own field is always printed.
   *
   * <p>A zero-length query vector is rendered as {@code []} rather than failing with an
   * {@link ArrayIndexOutOfBoundsException}, so the query can still be printed in logs and error
   * messages.
   */
  @Override
  public String toString(String field) {
    String vectorSummary = query.length == 0 ? "[]" : "[" + query[0] + ",...]";
    return getClass().getSimpleName() + ":" + this.field + vectorSummary + "[" + k + "]";
  }

  /**
   * Two instances are equal when they have the same runtime class, the superclass considers them
   * equal, and {@code k}, the parents filter, the child filter and the query vector contents all
   * match.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    if (!super.equals(o)) return false;
    DiversifyingChildrenByteKnnVectorQuery that = (DiversifyingChildrenByteKnnVectorQuery) o;
    return k == that.k
        && Objects.equals(parentsFilter, that.parentsFilter)
        && Objects.equals(childFilter, that.childFilter)
        && Arrays.equals(query, that.query);
  }

  /** Hash code built from the same components that {@link #equals(Object)} compares. */
  @Override
  public int hashCode() {
    int result = Objects.hash(super.hashCode(), parentsFilter, childFilter, k);
    result = 31 * result + Arrays.hashCode(query);
    return result;
  }
}