| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.search.join; |
| |
| import java.io.IOException; |
| import java.util.HashMap; |
| import java.util.Iterator; |
| import java.util.Locale; |
| import java.util.Map; |
| import java.util.TreeSet; |
| import java.util.function.BiConsumer; |
| import java.util.function.LongFunction; |
| import org.apache.lucene.document.DoublePoint; |
| import org.apache.lucene.document.FloatPoint; |
| import org.apache.lucene.document.IntPoint; |
| import org.apache.lucene.document.LongPoint; |
| import org.apache.lucene.index.BinaryDocValues; |
| import org.apache.lucene.index.DocValues; |
| import org.apache.lucene.index.DocValuesType; |
| import org.apache.lucene.index.LeafReader; |
| import org.apache.lucene.index.LeafReaderContext; |
| import org.apache.lucene.index.NumericDocValues; |
| import org.apache.lucene.index.OrdinalMap; |
| import org.apache.lucene.index.SortedDocValues; |
| import org.apache.lucene.index.SortedNumericDocValues; |
| import org.apache.lucene.index.SortedSetDocValues; |
| import org.apache.lucene.search.Collector; |
| import org.apache.lucene.search.IndexSearcher; |
| import org.apache.lucene.search.MatchNoDocsQuery; |
| import org.apache.lucene.search.PointInSetQuery; |
| import org.apache.lucene.search.Query; |
| import org.apache.lucene.search.Scorable; |
| import org.apache.lucene.search.SimpleCollector; |
| import org.apache.lucene.search.join.DocValuesTermsCollector.Function; |
| import org.apache.lucene.util.BytesRef; |
| |
| /** |
| * Utility for query time joining. |
| * |
| * @lucene.experimental |
| */ |
| public final class JoinUtil { |
| |
| // No instances allowed: this class is final and only offers static methods, so the |
| // constructor is private to keep callers from instantiating it. |
| private JoinUtil() {} |
| |
| /** |
| * Method for query time joining. |
| * |
| * <p>Execute the returned query with a {@link IndexSearcher} to retrieve all documents that have |
| * the same terms in the to field that match with documents matching the specified fromQuery and |
| * have the same terms in the from field. |
| * |
| * <p>In the case a single document relates to more than one document the <code> |
| * multipleValuesPerDocument</code> option should be set to true. When the <code> |
| * multipleValuesPerDocument</code> is set to <code>true</code> only the the score from the first |
| * encountered join value originating from the 'from' side is mapped into the 'to' side. Even in |
| * the case when a second join value related to a specific document yields a higher score. |
| * Obviously this doesn't apply in the case that {@link ScoreMode#None} is used, since no scores |
| * are computed at all. |
| * |
| * <p>Memory considerations: During joining all unique join values are kept in memory. On top of |
| * that when the scoreMode isn't set to {@link ScoreMode#None} a float value per unique join value |
| * is kept in memory for computing scores. When scoreMode is set to {@link ScoreMode#Avg} also an |
| * additional integer value is kept in memory per unique join value. |
| * |
| * @param fromField The from field to join from |
| * @param multipleValuesPerDocument Whether the from field has multiple terms per document |
| * @param toField The to field to join to |
| * @param fromQuery The query to match documents on the from side |
| * @param fromSearcher The searcher that executed the specified fromQuery |
| * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query |
| * @return a {@link Query} instance that can be used to join documents based on the terms in the |
| * from and to field |
| * @throws IOException If I/O related errors occur |
| */ |
| public static Query createJoinQuery( |
| String fromField, |
| boolean multipleValuesPerDocument, |
| String toField, |
| Query fromQuery, |
| IndexSearcher fromSearcher, |
| ScoreMode scoreMode) |
| throws IOException { |
| |
| final GenericTermsCollector termsWithScoreCollector; |
| |
| if (multipleValuesPerDocument) { |
| Function<SortedSetDocValues> mvFunction = |
| DocValuesTermsCollector.sortedSetDocValues(fromField); |
| termsWithScoreCollector = GenericTermsCollector.createCollectorMV(mvFunction, scoreMode); |
| } else { |
| Function<BinaryDocValues> svFunction = DocValuesTermsCollector.binaryDocValues(fromField); |
| termsWithScoreCollector = GenericTermsCollector.createCollectorSV(svFunction, scoreMode); |
| } |
| |
| return createJoinQuery( |
| multipleValuesPerDocument, |
| toField, |
| fromQuery, |
| fromField, |
| fromSearcher, |
| scoreMode, |
| termsWithScoreCollector); |
| } |
| |
| /** |
| * Method for query time joining for numeric fields. It supports multi- and single- values longs, |
| * ints, floats and longs. All considerations from {@link JoinUtil#createJoinQuery(String, |
| * boolean, String, Query, IndexSearcher, ScoreMode)} are applicable here too, though memory |
| * consumption might be higher. |
| * |
| * <p> |
| * |
| * @param fromField The from field to join from |
| * @param multipleValuesPerDocument Whether the from field has multiple terms per document when |
| * true fromField might be {@link DocValuesType#SORTED_NUMERIC}, otherwise fromField should be |
| * {@link DocValuesType#NUMERIC} |
| * @param toField The to field to join to, should be {@link IntPoint}, {@link LongPoint}, {@link |
| * FloatPoint} or {@link DoublePoint}. |
| * @param numericType either {@link java.lang.Integer}, {@link java.lang.Long}, {@link |
| * java.lang.Float} or {@link java.lang.Double} it should correspond to toField types |
| * @param fromQuery The query to match documents on the from side |
| * @param fromSearcher The searcher that executed the specified fromQuery |
| * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query |
| * @return a {@link Query} instance that can be used to join documents based on the terms in the |
| * from and to field |
| * @throws IOException If I/O related errors occur |
| */ |
| public static Query createJoinQuery( |
| String fromField, |
| boolean multipleValuesPerDocument, |
| String toField, |
| Class<? extends Number> numericType, |
| Query fromQuery, |
| IndexSearcher fromSearcher, |
| ScoreMode scoreMode) |
| throws IOException { |
| TreeSet<Long> joinValues = new TreeSet<>(); |
| Map<Long, Float> aggregatedScores = new HashMap<>(); |
| Map<Long, Integer> occurrences = new HashMap<>(); |
| boolean needsScore = scoreMode != ScoreMode.None; |
| BiConsumer<Long, Float> scoreAggregator; |
| if (scoreMode == ScoreMode.Max) { |
| scoreAggregator = |
| (key, score) -> { |
| Float currentValue = aggregatedScores.putIfAbsent(key, score); |
| if (currentValue != null) { |
| aggregatedScores.put(key, Math.max(currentValue, score)); |
| } |
| }; |
| } else if (scoreMode == ScoreMode.Min) { |
| scoreAggregator = |
| (key, score) -> { |
| Float currentValue = aggregatedScores.putIfAbsent(key, score); |
| if (currentValue != null) { |
| aggregatedScores.put(key, Math.min(currentValue, score)); |
| } |
| }; |
| } else if (scoreMode == ScoreMode.Total) { |
| scoreAggregator = |
| (key, score) -> { |
| Float currentValue = aggregatedScores.putIfAbsent(key, score); |
| if (currentValue != null) { |
| aggregatedScores.put(key, currentValue + score); |
| } |
| }; |
| } else if (scoreMode == ScoreMode.Avg) { |
| scoreAggregator = |
| (key, score) -> { |
| Float currentSore = aggregatedScores.putIfAbsent(key, score); |
| if (currentSore != null) { |
| aggregatedScores.put(key, currentSore + score); |
| } |
| Integer currentOccurrence = occurrences.putIfAbsent(key, 1); |
| if (currentOccurrence != null) { |
| occurrences.put(key, ++currentOccurrence); |
| } |
| }; |
| } else { |
| scoreAggregator = |
| (key, score) -> { |
| throw new UnsupportedOperationException(); |
| }; |
| } |
| |
| LongFunction<Float> joinScorer; |
| if (scoreMode == ScoreMode.Avg) { |
| joinScorer = |
| (joinValue) -> { |
| Float aggregatedScore = aggregatedScores.get(joinValue); |
| Integer occurrence = occurrences.get(joinValue); |
| return aggregatedScore / occurrence; |
| }; |
| } else { |
| joinScorer = aggregatedScores::get; |
| } |
| |
| Collector collector; |
| if (multipleValuesPerDocument) { |
| collector = |
| new SimpleCollector() { |
| |
| SortedNumericDocValues sortedNumericDocValues; |
| Scorable scorer; |
| |
| @Override |
| public void collect(int doc) throws IOException { |
| if (sortedNumericDocValues.advanceExact(doc)) { |
| for (int i = 0, count = sortedNumericDocValues.docValueCount(); i < count; i++) { |
| long value = sortedNumericDocValues.nextValue(); |
| joinValues.add(value); |
| if (needsScore) { |
| scoreAggregator.accept(value, scorer.score()); |
| } |
| } |
| } |
| } |
| |
| @Override |
| protected void doSetNextReader(LeafReaderContext context) throws IOException { |
| sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField); |
| } |
| |
| @Override |
| public void setScorer(Scorable scorer) throws IOException { |
| this.scorer = scorer; |
| } |
| |
| @Override |
| public org.apache.lucene.search.ScoreMode scoreMode() { |
| return needsScore |
| ? org.apache.lucene.search.ScoreMode.COMPLETE |
| : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES; |
| } |
| }; |
| } else { |
| collector = |
| new SimpleCollector() { |
| |
| NumericDocValues numericDocValues; |
| Scorable scorer; |
| private int lastDocID = -1; |
| |
| private boolean docsInOrder(int docID) { |
| if (docID < lastDocID) { |
| throw new AssertionError( |
| "docs out of order: lastDocID=" + lastDocID + " vs docID=" + docID); |
| } |
| lastDocID = docID; |
| return true; |
| } |
| |
| @Override |
| public void collect(int doc) throws IOException { |
| assert docsInOrder(doc); |
| long value = 0; |
| if (numericDocValues.advanceExact(doc)) { |
| value = numericDocValues.longValue(); |
| } |
| joinValues.add(value); |
| if (needsScore) { |
| scoreAggregator.accept(value, scorer.score()); |
| } |
| } |
| |
| @Override |
| protected void doSetNextReader(LeafReaderContext context) throws IOException { |
| numericDocValues = DocValues.getNumeric(context.reader(), fromField); |
| lastDocID = -1; |
| } |
| |
| @Override |
| public void setScorer(Scorable scorer) throws IOException { |
| this.scorer = scorer; |
| } |
| |
| @Override |
| public org.apache.lucene.search.ScoreMode scoreMode() { |
| return needsScore |
| ? org.apache.lucene.search.ScoreMode.COMPLETE |
| : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES; |
| } |
| }; |
| } |
| fromSearcher.search(fromQuery, collector); |
| |
| Iterator<Long> iterator = joinValues.iterator(); |
| |
| final int bytesPerDim; |
| final BytesRef encoded = new BytesRef(); |
| final PointInSetIncludingScoreQuery.Stream stream; |
| if (Integer.class.equals(numericType)) { |
| bytesPerDim = Integer.BYTES; |
| stream = |
| new PointInSetIncludingScoreQuery.Stream() { |
| @Override |
| public BytesRef next() { |
| if (iterator.hasNext()) { |
| long value = iterator.next(); |
| IntPoint.encodeDimension((int) value, encoded.bytes, 0); |
| if (needsScore) { |
| score = joinScorer.apply(value); |
| } |
| return encoded; |
| } else { |
| return null; |
| } |
| } |
| }; |
| } else if (Long.class.equals(numericType)) { |
| bytesPerDim = Long.BYTES; |
| stream = |
| new PointInSetIncludingScoreQuery.Stream() { |
| @Override |
| public BytesRef next() { |
| if (iterator.hasNext()) { |
| long value = iterator.next(); |
| LongPoint.encodeDimension(value, encoded.bytes, 0); |
| if (needsScore) { |
| score = joinScorer.apply(value); |
| } |
| return encoded; |
| } else { |
| return null; |
| } |
| } |
| }; |
| } else if (Float.class.equals(numericType)) { |
| bytesPerDim = Float.BYTES; |
| stream = |
| new PointInSetIncludingScoreQuery.Stream() { |
| @Override |
| public BytesRef next() { |
| if (iterator.hasNext()) { |
| long value = iterator.next(); |
| FloatPoint.encodeDimension(Float.intBitsToFloat((int) value), encoded.bytes, 0); |
| if (needsScore) { |
| score = joinScorer.apply(value); |
| } |
| return encoded; |
| } else { |
| return null; |
| } |
| } |
| }; |
| } else if (Double.class.equals(numericType)) { |
| bytesPerDim = Double.BYTES; |
| stream = |
| new PointInSetIncludingScoreQuery.Stream() { |
| @Override |
| public BytesRef next() { |
| if (iterator.hasNext()) { |
| long value = iterator.next(); |
| DoublePoint.encodeDimension(Double.longBitsToDouble(value), encoded.bytes, 0); |
| if (needsScore) { |
| score = joinScorer.apply(value); |
| } |
| return encoded; |
| } else { |
| return null; |
| } |
| } |
| }; |
| } else { |
| throw new IllegalArgumentException( |
| "unsupported numeric type, only Integer, Long, Float and Double are supported"); |
| } |
| |
| encoded.bytes = new byte[bytesPerDim]; |
| encoded.length = bytesPerDim; |
| |
| if (needsScore) { |
| return new PointInSetIncludingScoreQuery( |
| scoreMode, fromQuery, multipleValuesPerDocument, toField, bytesPerDim, stream) { |
| |
| @Override |
| protected String toString(byte[] value) { |
| return toString.apply(value, numericType); |
| } |
| }; |
| } else { |
| return new PointInSetQuery(toField, 1, bytesPerDim, stream) { |
| @Override |
| protected String toString(byte[] value) { |
| return PointInSetIncludingScoreQuery.toString.apply(value, numericType); |
| } |
| }; |
| } |
| } |
| |
| private static Query createJoinQuery( |
| boolean multipleValuesPerDocument, |
| String toField, |
| Query fromQuery, |
| String fromField, |
| IndexSearcher fromSearcher, |
| ScoreMode scoreMode, |
| final GenericTermsCollector collector) |
| throws IOException { |
| |
| fromSearcher.search(fromQuery, collector); |
| switch (scoreMode) { |
| case None: |
| return new TermsQuery( |
| toField, |
| collector.getCollectedTerms(), |
| fromField, |
| fromQuery, |
| fromSearcher.getTopReaderContext().id()); |
| case Total: |
| case Max: |
| case Min: |
| case Avg: |
| return new TermsIncludingScoreQuery( |
| scoreMode, |
| toField, |
| multipleValuesPerDocument, |
| collector.getCollectedTerms(), |
| collector.getScoresPerTerm(), |
| fromField, |
| fromQuery, |
| fromSearcher.getTopReaderContext().id()); |
| default: |
| throw new IllegalArgumentException( |
| String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode)); |
| } |
| } |
| |
| /** |
| * Delegates to {@link #createJoinQuery(String, Query, Query, IndexSearcher, ScoreMode, |
| * OrdinalMap, int, int)}, but disables the min and max filtering. |
| * |
| * @param joinField The {@link SortedDocValues} field containing the join values |
| * @param fromQuery The query containing the actual user query. Also the fromQuery can only match |
| * "from" documents. |
| * @param toQuery The query identifying all documents on the "to" side. |
| * @param searcher The index searcher used to execute the from query |
| * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query |
| * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment |
| * index, no ordinal map needs to be provided. |
| * @return a {@link Query} instance that can be used to join documents based on the join field |
| * @throws IOException If I/O related errors occur |
| */ |
| public static Query createJoinQuery( |
| String joinField, |
| Query fromQuery, |
| Query toQuery, |
| IndexSearcher searcher, |
| ScoreMode scoreMode, |
| OrdinalMap ordinalMap) |
| throws IOException { |
| return createJoinQuery( |
| joinField, fromQuery, toQuery, searcher, scoreMode, ordinalMap, 0, Integer.MAX_VALUE); |
| } |
| |
| /** |
| * A query time join using global ordinals over a dedicated join field. |
| * |
| * <p>This join has certain restrictions and requirements: 1) A document can only refer to one |
| * other document. (but can be referred by one or more documents) 2) Documents on each side of the |
| * join must be distinguishable. Typically this can be done by adding an extra field that |
| * identifies the "from" and "to" side and then the fromQuery and toQuery must take the this into |
| * account. 3) There must be a single sorted doc values join field used by both the "from" and |
| * "to" documents. This join field should store the join values as UTF-8 strings. 4) An ordinal |
| * map must be provided that is created on top of the join field. |
| * |
| * <p>Note: min and max filtering and the avg score mode will require this join to keep track of |
| * the number of times a document matches per join value. This will increase the per join cost in |
| * terms of execution time and memory. |
| * |
| * @param joinField The {@link SortedDocValues} field containing the join values |
| * @param fromQuery The query containing the actual user query. Also the fromQuery can only match |
| * "from" documents. |
| * @param toQuery The query identifying all documents on the "to" side. |
| * @param searcher The index searcher used to execute the from query |
| * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query |
| * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment |
| * index, no ordinal map needs to be provided. |
| * @param min Optionally the minimum number of "from" documents that are required to match for a |
| * "to" document to be a match. The min is inclusive. Setting min to 0 and max to <code> |
| * Interger.MAX_VALUE</code> disables the min and max "from" documents filtering |
| * @param max Optionally the maximum number of "from" documents that are allowed to match for a |
| * "to" document to be a match. The max is inclusive. Setting min to 0 and max to <code> |
| * Interger.MAX_VALUE</code> disables the min and max "from" documents filtering |
| * @return a {@link Query} instance that can be used to join documents based on the join field |
| * @throws IOException If I/O related errors occur |
| */ |
| public static Query createJoinQuery( |
| String joinField, |
| Query fromQuery, |
| Query toQuery, |
| IndexSearcher searcher, |
| ScoreMode scoreMode, |
| OrdinalMap ordinalMap, |
| int min, |
| int max) |
| throws IOException { |
| int numSegments = searcher.getIndexReader().leaves().size(); |
| final long valueCount; |
| if (numSegments == 0) { |
| return new MatchNoDocsQuery("JoinUtil.createJoinQuery with no segments"); |
| } else if (numSegments == 1) { |
| // No need to use the ordinal map, because there is just one segment. |
| ordinalMap = null; |
| LeafReader leafReader = searcher.getIndexReader().leaves().get(0).reader(); |
| SortedDocValues joinSortedDocValues = leafReader.getSortedDocValues(joinField); |
| if (joinSortedDocValues != null) { |
| valueCount = joinSortedDocValues.getValueCount(); |
| } else { |
| return new MatchNoDocsQuery("JoinUtil.createJoinQuery: no join values"); |
| } |
| } else { |
| if (ordinalMap == null) { |
| throw new IllegalArgumentException( |
| "OrdinalMap is required, because there is more than 1 segment"); |
| } |
| valueCount = ordinalMap.getValueCount(); |
| } |
| |
| final Query rewrittenFromQuery = searcher.rewrite(fromQuery); |
| final Query rewrittenToQuery = searcher.rewrite(toQuery); |
| GlobalOrdinalsWithScoreCollector globalOrdinalsWithScoreCollector; |
| switch (scoreMode) { |
| case Total: |
| globalOrdinalsWithScoreCollector = |
| new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount, min, max); |
| break; |
| case Min: |
| globalOrdinalsWithScoreCollector = |
| new GlobalOrdinalsWithScoreCollector.Min(joinField, ordinalMap, valueCount, min, max); |
| break; |
| case Max: |
| globalOrdinalsWithScoreCollector = |
| new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount, min, max); |
| break; |
| case Avg: |
| globalOrdinalsWithScoreCollector = |
| new GlobalOrdinalsWithScoreCollector.Avg(joinField, ordinalMap, valueCount, min, max); |
| break; |
| case None: |
| if (min <= 1 && max == Integer.MAX_VALUE) { |
| GlobalOrdinalsCollector globalOrdinalsCollector = |
| new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount); |
| searcher.search(rewrittenFromQuery, globalOrdinalsCollector); |
| return new GlobalOrdinalsQuery( |
| globalOrdinalsCollector.getCollectorOrdinals(), |
| joinField, |
| ordinalMap, |
| rewrittenToQuery, |
| rewrittenFromQuery, |
| searcher.getTopReaderContext().id()); |
| } else { |
| globalOrdinalsWithScoreCollector = |
| new GlobalOrdinalsWithScoreCollector.NoScore( |
| joinField, ordinalMap, valueCount, min, max); |
| break; |
| } |
| default: |
| throw new IllegalArgumentException( |
| String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode)); |
| } |
| searcher.search(rewrittenFromQuery, globalOrdinalsWithScoreCollector); |
| return new GlobalOrdinalsWithScoreQuery( |
| globalOrdinalsWithScoreCollector, |
| scoreMode, |
| joinField, |
| ordinalMap, |
| rewrittenToQuery, |
| rewrittenFromQuery, |
| min, |
| max, |
| searcher.getTopReaderContext().id()); |
| } |
| } |