| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.search.join; |
| |
| import java.io.IOException; |
| import java.util.HashMap; |
| import java.util.Iterator; |
| import java.util.Locale; |
| import java.util.Map; |
| import java.util.TreeSet; |
| import java.util.function.BiConsumer; |
| import java.util.function.LongFunction; |
| |
| import org.apache.lucene.document.DoublePoint; |
| import org.apache.lucene.document.FloatPoint; |
| import org.apache.lucene.document.IntPoint; |
| import org.apache.lucene.document.LongPoint; |
| import org.apache.lucene.index.BinaryDocValues; |
| import org.apache.lucene.index.DocValues; |
| import org.apache.lucene.index.DocValuesType; |
| import org.apache.lucene.index.LeafReader; |
| import org.apache.lucene.index.LeafReaderContext; |
| import org.apache.lucene.index.NumericDocValues; |
| import org.apache.lucene.index.OrdinalMap; |
| import org.apache.lucene.index.SortedDocValues; |
| import org.apache.lucene.index.SortedNumericDocValues; |
| import org.apache.lucene.index.SortedSetDocValues; |
| import org.apache.lucene.search.Collector; |
| import org.apache.lucene.search.IndexSearcher; |
| import org.apache.lucene.search.MatchNoDocsQuery; |
| import org.apache.lucene.search.PointInSetQuery; |
| import org.apache.lucene.search.Query; |
| import org.apache.lucene.search.Scorable; |
| import org.apache.lucene.search.SimpleCollector; |
| import org.apache.lucene.search.join.DocValuesTermsCollector.Function; |
| import org.apache.lucene.util.BytesRef; |
| |
| /** |
| * Utility for query time joining. |
| * |
| * @lucene.experimental |
| */ |
| public final class JoinUtil { |
| |
| // No instances allowed |
| private JoinUtil() { |
| } |
| |
| /** |
| * Method for query time joining. |
| * <p> |
| * Execute the returned query with a {@link IndexSearcher} to retrieve all documents that have the same terms in the |
| * to field that match with documents matching the specified fromQuery and have the same terms in the from field. |
| * <p> |
| * In the case a single document relates to more than one document the <code>multipleValuesPerDocument</code> option |
| * should be set to true. When the <code>multipleValuesPerDocument</code> is set to <code>true</code> only the |
| * the score from the first encountered join value originating from the 'from' side is mapped into the 'to' side. |
| * Even in the case when a second join value related to a specific document yields a higher score. Obviously this |
| * doesn't apply in the case that {@link ScoreMode#None} is used, since no scores are computed at all. |
| * <p> |
| * Memory considerations: During joining all unique join values are kept in memory. On top of that when the scoreMode |
| * isn't set to {@link ScoreMode#None} a float value per unique join value is kept in memory for computing scores. |
| * When scoreMode is set to {@link ScoreMode#Avg} also an additional integer value is kept in memory per unique |
| * join value. |
| * |
| * @param fromField The from field to join from |
| * @param multipleValuesPerDocument Whether the from field has multiple terms per document |
| * @param toField The to field to join to |
| * @param fromQuery The query to match documents on the from side |
| * @param fromSearcher The searcher that executed the specified fromQuery |
| * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query |
| * @return a {@link Query} instance that can be used to join documents based on the |
| * terms in the from and to field |
| * @throws IOException If I/O related errors occur |
| */ |
| public static Query createJoinQuery(String fromField, |
| boolean multipleValuesPerDocument, |
| String toField, |
| Query fromQuery, |
| IndexSearcher fromSearcher, |
| ScoreMode scoreMode) throws IOException { |
| |
| final GenericTermsCollector termsWithScoreCollector; |
| |
| if (multipleValuesPerDocument) { |
| Function<SortedSetDocValues> mvFunction = DocValuesTermsCollector.sortedSetDocValues(fromField); |
| termsWithScoreCollector = GenericTermsCollector.createCollectorMV(mvFunction, scoreMode); |
| } else { |
| Function<BinaryDocValues> svFunction = DocValuesTermsCollector.binaryDocValues(fromField); |
| termsWithScoreCollector = GenericTermsCollector.createCollectorSV(svFunction, scoreMode); |
| } |
| |
| return createJoinQuery(multipleValuesPerDocument, toField, fromQuery, fromField, fromSearcher, scoreMode, termsWithScoreCollector); |
| } |
| |
| /** |
| * Method for query time joining for numeric fields. It supports multi- and single- values longs, ints, floats and longs. |
| * All considerations from {@link JoinUtil#createJoinQuery(String, boolean, String, Query, IndexSearcher, ScoreMode)} are applicable here too, |
| * though memory consumption might be higher. |
| * @param fromField The from field to join from |
| * @param multipleValuesPerDocument Whether the from field has multiple terms per document |
| * when true fromField might be {@link DocValuesType#SORTED_NUMERIC}, |
| * otherwise fromField should be {@link DocValuesType#NUMERIC} |
| * @param toField The to field to join to, should be {@link IntPoint}, {@link LongPoint}, {@link FloatPoint} |
| * or {@link DoublePoint}. |
| * @param numericType either {@link java.lang.Integer}, {@link java.lang.Long}, {@link java.lang.Float} |
| * or {@link java.lang.Double} it should correspond to toField types |
| * @param fromQuery The query to match documents on the from side |
| * @param fromSearcher The searcher that executed the specified fromQuery |
| * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query |
| * @return a {@link Query} instance that can be used to join documents based on the |
| * terms in the from and to field |
| * @throws IOException If I/O related errors occur |
| */ |
| public static Query createJoinQuery(String fromField, |
| boolean multipleValuesPerDocument, |
| String toField, |
| Class<? extends Number> numericType, |
| Query fromQuery, |
| IndexSearcher fromSearcher, |
| ScoreMode scoreMode) throws IOException { |
| TreeSet<Long> joinValues = new TreeSet<>(); |
| Map<Long, Float> aggregatedScores = new HashMap<>(); |
| Map<Long, Integer> occurrences = new HashMap<>(); |
| boolean needsScore = scoreMode != ScoreMode.None; |
| BiConsumer<Long, Float> scoreAggregator; |
| if (scoreMode == ScoreMode.Max) { |
| scoreAggregator = (key, score) -> { |
| Float currentValue = aggregatedScores.putIfAbsent(key, score); |
| if (currentValue != null) { |
| aggregatedScores.put(key, Math.max(currentValue, score)); |
| } |
| }; |
| } else if (scoreMode == ScoreMode.Min) { |
| scoreAggregator = (key, score) -> { |
| Float currentValue = aggregatedScores.putIfAbsent(key, score); |
| if (currentValue != null) { |
| aggregatedScores.put(key, Math.min(currentValue, score)); |
| } |
| }; |
| } else if (scoreMode == ScoreMode.Total) { |
| scoreAggregator = (key, score) -> { |
| Float currentValue = aggregatedScores.putIfAbsent(key, score); |
| if (currentValue != null) { |
| aggregatedScores.put(key, currentValue + score); |
| } |
| }; |
| } else if (scoreMode == ScoreMode.Avg) { |
| scoreAggregator = (key, score) -> { |
| Float currentSore = aggregatedScores.putIfAbsent(key, score); |
| if (currentSore != null) { |
| aggregatedScores.put(key, currentSore + score); |
| } |
| Integer currentOccurrence = occurrences.putIfAbsent(key, 1); |
| if (currentOccurrence != null) { |
| occurrences.put(key, ++currentOccurrence); |
| } |
| |
| }; |
| } else { |
| scoreAggregator = (key, score) -> { |
| throw new UnsupportedOperationException(); |
| }; |
| } |
| |
| LongFunction<Float> joinScorer; |
| if (scoreMode == ScoreMode.Avg) { |
| joinScorer = (joinValue) -> { |
| Float aggregatedScore = aggregatedScores.get(joinValue); |
| Integer occurrence = occurrences.get(joinValue); |
| return aggregatedScore / occurrence; |
| }; |
| } else { |
| joinScorer = aggregatedScores::get; |
| } |
| |
| Collector collector; |
| if (multipleValuesPerDocument) { |
| collector = new SimpleCollector() { |
| |
| SortedNumericDocValues sortedNumericDocValues; |
| Scorable scorer; |
| |
| @Override |
| public void collect(int doc) throws IOException { |
| if (sortedNumericDocValues.advanceExact(doc)) { |
| for (int i = 0, count = sortedNumericDocValues.docValueCount(); i < count; i++) { |
| long value = sortedNumericDocValues.nextValue(); |
| joinValues.add(value); |
| if (needsScore) { |
| scoreAggregator.accept(value, scorer.score()); |
| } |
| } |
| } |
| } |
| |
| @Override |
| protected void doSetNextReader(LeafReaderContext context) throws IOException { |
| sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField); |
| } |
| |
| @Override |
| public void setScorer(Scorable scorer) throws IOException { |
| this.scorer = scorer; |
| } |
| |
| @Override |
| public org.apache.lucene.search.ScoreMode scoreMode() { |
| return needsScore ? org.apache.lucene.search.ScoreMode.COMPLETE : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES; |
| } |
| }; |
| } else { |
| collector = new SimpleCollector() { |
| |
| NumericDocValues numericDocValues; |
| Scorable scorer; |
| private int lastDocID = -1; |
| |
| private boolean docsInOrder(int docID) { |
| if (docID < lastDocID) { |
| throw new AssertionError("docs out of order: lastDocID=" + lastDocID + " vs docID=" + docID); |
| } |
| lastDocID = docID; |
| return true; |
| } |
| |
| @Override |
| public void collect(int doc) throws IOException { |
| assert docsInOrder(doc); |
| long value = 0; |
| if (numericDocValues.advanceExact(doc)) { |
| value = numericDocValues.longValue(); |
| } |
| joinValues.add(value); |
| if (needsScore) { |
| scoreAggregator.accept(value, scorer.score()); |
| } |
| } |
| |
| @Override |
| protected void doSetNextReader(LeafReaderContext context) throws IOException { |
| numericDocValues = DocValues.getNumeric(context.reader(), fromField); |
| lastDocID = -1; |
| } |
| |
| @Override |
| public void setScorer(Scorable scorer) throws IOException { |
| this.scorer = scorer; |
| } |
| |
| @Override |
| public org.apache.lucene.search.ScoreMode scoreMode() { |
| return needsScore ? org.apache.lucene.search.ScoreMode.COMPLETE : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES; |
| } |
| }; |
| } |
| fromSearcher.search(fromQuery, collector); |
| |
| Iterator<Long> iterator = joinValues.iterator(); |
| |
| final int bytesPerDim; |
| final BytesRef encoded = new BytesRef(); |
| final PointInSetIncludingScoreQuery.Stream stream; |
| if (Integer.class.equals(numericType)) { |
| bytesPerDim = Integer.BYTES; |
| stream = new PointInSetIncludingScoreQuery.Stream() { |
| @Override |
| public BytesRef next() { |
| if (iterator.hasNext()) { |
| long value = iterator.next(); |
| IntPoint.encodeDimension((int) value, encoded.bytes, 0); |
| if (needsScore) { |
| score = joinScorer.apply(value); |
| } |
| return encoded; |
| } else { |
| return null; |
| } |
| } |
| }; |
| } else if (Long.class.equals(numericType)) { |
| bytesPerDim = Long.BYTES; |
| stream = new PointInSetIncludingScoreQuery.Stream() { |
| @Override |
| public BytesRef next() { |
| if (iterator.hasNext()) { |
| long value = iterator.next(); |
| LongPoint.encodeDimension(value, encoded.bytes, 0); |
| if (needsScore) { |
| score = joinScorer.apply(value); |
| } |
| return encoded; |
| } else { |
| return null; |
| } |
| } |
| }; |
| } else if (Float.class.equals(numericType)) { |
| bytesPerDim = Float.BYTES; |
| stream = new PointInSetIncludingScoreQuery.Stream() { |
| @Override |
| public BytesRef next() { |
| if (iterator.hasNext()) { |
| long value = iterator.next(); |
| FloatPoint.encodeDimension(Float.intBitsToFloat((int) value), encoded.bytes, 0); |
| if (needsScore) { |
| score = joinScorer.apply(value); |
| } |
| return encoded; |
| } else { |
| return null; |
| } |
| } |
| }; |
| } else if (Double.class.equals(numericType)) { |
| bytesPerDim = Double.BYTES; |
| stream = new PointInSetIncludingScoreQuery.Stream() { |
| @Override |
| public BytesRef next() { |
| if (iterator.hasNext()) { |
| long value = iterator.next(); |
| DoublePoint.encodeDimension(Double.longBitsToDouble(value), encoded.bytes, 0); |
| if (needsScore) { |
| score = joinScorer.apply(value); |
| } |
| return encoded; |
| } else { |
| return null; |
| } |
| } |
| }; |
| } else { |
| throw new IllegalArgumentException("unsupported numeric type, only Integer, Long, Float and Double are supported"); |
| } |
| |
| encoded.bytes = new byte[bytesPerDim]; |
| encoded.length = bytesPerDim; |
| |
| if (needsScore) { |
| return new PointInSetIncludingScoreQuery(scoreMode, fromQuery, multipleValuesPerDocument, toField, bytesPerDim, stream) { |
| |
| @Override |
| protected String toString(byte[] value) { |
| return toString.apply(value, numericType); |
| } |
| }; |
| } else { |
| return new PointInSetQuery(toField, 1, bytesPerDim, stream) { |
| @Override |
| protected String toString(byte[] value) { |
| return PointInSetIncludingScoreQuery.toString.apply(value, numericType); |
| } |
| }; |
| } |
| } |
| |
| private static Query createJoinQuery(boolean multipleValuesPerDocument, String toField, Query fromQuery, String fromField, |
| IndexSearcher fromSearcher, ScoreMode scoreMode, final GenericTermsCollector collector) throws IOException { |
| |
| fromSearcher.search(fromQuery, collector); |
| switch (scoreMode) { |
| case None: |
| return new TermsQuery(toField, collector.getCollectedTerms(), fromField, fromQuery, fromSearcher.getTopReaderContext().id()); |
| case Total: |
| case Max: |
| case Min: |
| case Avg: |
| return new TermsIncludingScoreQuery( |
| scoreMode, |
| toField, |
| multipleValuesPerDocument, |
| collector.getCollectedTerms(), |
| collector.getScoresPerTerm(), |
| fromField, |
| fromQuery, |
| fromSearcher.getTopReaderContext().id() |
| ); |
| default: |
| throw new IllegalArgumentException(String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode)); |
| } |
| } |
| |
| |
| /** |
| * Delegates to {@link #createJoinQuery(String, Query, Query, IndexSearcher, ScoreMode, OrdinalMap, int, int)}, |
| * but disables the min and max filtering. |
| * |
| * @param joinField The {@link SortedDocValues} field containing the join values |
| * @param fromQuery The query containing the actual user query. Also the fromQuery can only match "from" documents. |
| * @param toQuery The query identifying all documents on the "to" side. |
| * @param searcher The index searcher used to execute the from query |
| * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query |
| * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment index, no ordinal map |
| * needs to be provided. |
| * @return a {@link Query} instance that can be used to join documents based on the join field |
| * @throws IOException If I/O related errors occur |
| */ |
| public static Query createJoinQuery(String joinField, |
| Query fromQuery, |
| Query toQuery, |
| IndexSearcher searcher, |
| ScoreMode scoreMode, |
| OrdinalMap ordinalMap) throws IOException { |
| return createJoinQuery(joinField, fromQuery, toQuery, searcher, scoreMode, ordinalMap, 0, Integer.MAX_VALUE); |
| } |
| |
| /** |
| * A query time join using global ordinals over a dedicated join field. |
| * |
| * This join has certain restrictions and requirements: |
| * 1) A document can only refer to one other document. (but can be referred by one or more documents) |
| * 2) Documents on each side of the join must be distinguishable. Typically this can be done by adding an extra field |
| * that identifies the "from" and "to" side and then the fromQuery and toQuery must take the this into account. |
| * 3) There must be a single sorted doc values join field used by both the "from" and "to" documents. This join field |
| * should store the join values as UTF-8 strings. |
| * 4) An ordinal map must be provided that is created on top of the join field. |
| * |
| * Note: min and max filtering and the avg score mode will require this join to keep track of the number of times |
| * a document matches per join value. This will increase the per join cost in terms of execution time and memory. |
| * |
| * @param joinField The {@link SortedDocValues} field containing the join values |
| * @param fromQuery The query containing the actual user query. Also the fromQuery can only match "from" documents. |
| * @param toQuery The query identifying all documents on the "to" side. |
| * @param searcher The index searcher used to execute the from query |
| * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query |
| * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment index, no ordinal map |
| * needs to be provided. |
| * @param min Optionally the minimum number of "from" documents that are required to match for a "to" document |
| * to be a match. The min is inclusive. Setting min to 0 and max to <code>Interger.MAX_VALUE</code> |
| * disables the min and max "from" documents filtering |
| * @param max Optionally the maximum number of "from" documents that are allowed to match for a "to" document |
| * to be a match. The max is inclusive. Setting min to 0 and max to <code>Interger.MAX_VALUE</code> |
| * disables the min and max "from" documents filtering |
| * @return a {@link Query} instance that can be used to join documents based on the join field |
| * @throws IOException If I/O related errors occur |
| */ |
| public static Query createJoinQuery(String joinField, |
| Query fromQuery, |
| Query toQuery, |
| IndexSearcher searcher, |
| ScoreMode scoreMode, |
| OrdinalMap ordinalMap, |
| int min, |
| int max) throws IOException { |
| int numSegments = searcher.getIndexReader().leaves().size(); |
| final long valueCount; |
| if (numSegments == 0) { |
| return new MatchNoDocsQuery("JoinUtil.createJoinQuery with no segments"); |
| } else if (numSegments == 1) { |
| // No need to use the ordinal map, because there is just one segment. |
| ordinalMap = null; |
| LeafReader leafReader = searcher.getIndexReader().leaves().get(0).reader(); |
| SortedDocValues joinSortedDocValues = leafReader.getSortedDocValues(joinField); |
| if (joinSortedDocValues != null) { |
| valueCount = joinSortedDocValues.getValueCount(); |
| } else { |
| return new MatchNoDocsQuery("JoinUtil.createJoinQuery: no join values"); |
| } |
| } else { |
| if (ordinalMap == null) { |
| throw new IllegalArgumentException("OrdinalMap is required, because there is more than 1 segment"); |
| } |
| valueCount = ordinalMap.getValueCount(); |
| } |
| |
| final Query rewrittenFromQuery = searcher.rewrite(fromQuery); |
| final Query rewrittenToQuery = searcher.rewrite(toQuery); |
| GlobalOrdinalsWithScoreCollector globalOrdinalsWithScoreCollector; |
| switch (scoreMode) { |
| case Total: |
| globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount, min, max); |
| break; |
| case Min: |
| globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Min(joinField, ordinalMap, valueCount, min, max); |
| break; |
| case Max: |
| globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount, min, max); |
| break; |
| case Avg: |
| globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Avg(joinField, ordinalMap, valueCount, min, max); |
| break; |
| case None: |
| if (min <= 1 && max == Integer.MAX_VALUE) { |
| GlobalOrdinalsCollector globalOrdinalsCollector = new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount); |
| searcher.search(rewrittenFromQuery, globalOrdinalsCollector); |
| return new GlobalOrdinalsQuery(globalOrdinalsCollector.getCollectorOrdinals(), joinField, ordinalMap, rewrittenToQuery, |
| rewrittenFromQuery, searcher.getTopReaderContext().id()); |
| } else { |
| globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.NoScore(joinField, ordinalMap, valueCount, min, max); |
| break; |
| } |
| default: |
| throw new IllegalArgumentException(String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode)); |
| } |
| searcher.search(rewrittenFromQuery, globalOrdinalsWithScoreCollector); |
| return new GlobalOrdinalsWithScoreQuery(globalOrdinalsWithScoreCollector, scoreMode, joinField, ordinalMap, rewrittenToQuery, |
| rewrittenFromQuery, min, max, searcher.getTopReaderContext().id()); |
| } |
| |
| } |