| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.solr.search.grouping.distributed.shardresultserializer; |
| |
| import java.io.IOException; |
| import java.lang.invoke.MethodHandles; |
| import java.util.ArrayList; |
| import java.util.Collections; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| |
| import org.apache.lucene.document.Document; |
| import org.apache.lucene.search.FieldDoc; |
| import org.apache.lucene.search.ScoreDoc; |
| import org.apache.lucene.search.Sort; |
| import org.apache.lucene.search.TopDocs; |
| import org.apache.lucene.search.TopFieldDocs; |
| import org.apache.lucene.search.TotalHits; |
| import org.apache.lucene.search.grouping.GroupDocs; |
| import org.apache.lucene.search.grouping.TopGroups; |
| import org.apache.lucene.util.BytesRef; |
| import org.apache.lucene.util.CharsRefBuilder; |
| import org.apache.solr.common.util.NamedList; |
| import org.apache.solr.handler.component.ResponseBuilder; |
| import org.apache.solr.handler.component.ShardDoc; |
| import org.apache.solr.schema.FieldType; |
| import org.apache.solr.schema.IndexSchema; |
| import org.apache.solr.schema.SchemaField; |
| import org.apache.solr.search.grouping.Command; |
| import org.apache.solr.search.grouping.distributed.command.QueryCommand; |
| import org.apache.solr.search.grouping.distributed.command.QueryCommandResult; |
| import org.apache.solr.search.grouping.distributed.command.TopGroupsFieldCommand; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import static org.apache.solr.common.params.CommonParams.ID; |
| |
| /** |
| * Implementation for transforming {@link TopGroups} and {@link TopDocs} into a {@link NamedList} structure and |
| * vice versa. |
| */ |
| @SuppressWarnings({"rawtypes"}) |
| public class TopGroupsResultTransformer implements ShardResultTransformer<List<Command>, Map<String, ?>> { |
| |
| private final ResponseBuilder rb; |
| |
| private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); |
| |
| public TopGroupsResultTransformer(ResponseBuilder rb) { |
| this.rb = rb; |
| } |
| |
| @Override |
| public NamedList transform(List<Command> data) throws IOException { |
| NamedList<NamedList> result = new NamedList<>(); |
| final IndexSchema schema = rb.req.getSearcher().getSchema(); |
| for (Command command : data) { |
| NamedList commandResult; |
| if (TopGroupsFieldCommand.class.isInstance(command)) { |
| TopGroupsFieldCommand fieldCommand = (TopGroupsFieldCommand) command; |
| SchemaField groupField = schema.getField(fieldCommand.getKey()); |
| commandResult = serializeTopGroups(fieldCommand.result(), groupField); |
| } else if (QueryCommand.class.isInstance(command)) { |
| QueryCommand queryCommand = (QueryCommand) command; |
| commandResult = serializeTopDocs(queryCommand.result()); |
| } else { |
| commandResult = null; |
| } |
| |
| result.add(command.getKey(), commandResult); |
| } |
| return result; |
| } |
| |
| @Override |
| public Map<String, ?> transformToNative(NamedList<NamedList> shardResponse, Sort groupSort, Sort withinGroupSort, String shard) { |
| Map<String, Object> result = new HashMap<>(); |
| |
| final IndexSchema schema = rb.req.getSearcher().getSchema(); |
| |
| for (Map.Entry<String, NamedList> entry : shardResponse) { |
| String key = entry.getKey(); |
| NamedList commandResult = entry.getValue(); |
| Integer totalGroupedHitCount = (Integer) commandResult.get("totalGroupedHitCount"); |
| Number totalHits = (Number) commandResult.get("totalHits"); // previously Integer now Long |
| if (totalHits != null) { |
| Integer matches = (Integer) commandResult.get("matches"); |
| Float maxScore = (Float) commandResult.get("maxScore"); |
| if (maxScore == null) { |
| maxScore = Float.NaN; |
| } |
| |
| @SuppressWarnings("unchecked") |
| List<NamedList<Object>> documents = (List<NamedList<Object>>) commandResult.get("documents"); |
| ScoreDoc[] scoreDocs = transformToNativeShardDoc(documents, groupSort, shard, schema); |
| final TopDocs topDocs; |
| if (withinGroupSort.equals(Sort.RELEVANCE)) { |
| topDocs = new TopDocs(new TotalHits(totalHits.longValue(), TotalHits.Relation.EQUAL_TO), scoreDocs); |
| } else { |
| topDocs = new TopFieldDocs(new TotalHits(totalHits.longValue(), TotalHits.Relation.EQUAL_TO), scoreDocs, withinGroupSort.getSort()); |
| } |
| result.put(key, new QueryCommandResult(topDocs, matches, maxScore)); |
| continue; |
| } |
| |
| Integer totalHitCount = (Integer) commandResult.get("totalHitCount"); |
| |
| List<GroupDocs<BytesRef>> groupDocs = new ArrayList<>(); |
| for (int i = 2; i < commandResult.size(); i++) { |
| String groupValue = commandResult.getName(i); |
| @SuppressWarnings("unchecked") |
| NamedList<Object> groupResult = (NamedList<Object>) commandResult.getVal(i); |
| Number totalGroupHits = (Number) groupResult.get("totalHits"); // // previously Integer now Long |
| Float maxScore = (Float) groupResult.get("maxScore"); |
| if (maxScore == null) { |
| maxScore = Float.NaN; |
| } |
| |
| @SuppressWarnings("unchecked") |
| List<NamedList<Object>> documents = (List<NamedList<Object>>) groupResult.get("documents"); |
| ScoreDoc[] scoreDocs = transformToNativeShardDoc(documents, withinGroupSort, shard, schema); |
| |
| BytesRef groupValueRef = groupValue != null ? new BytesRef(groupValue) : null; |
| groupDocs.add(new GroupDocs<>(Float.NaN, maxScore, new TotalHits(totalGroupHits.longValue(), TotalHits.Relation.EQUAL_TO), scoreDocs, groupValueRef, null)); |
| } |
| |
| @SuppressWarnings({"unchecked"}) |
| GroupDocs<BytesRef>[] groupDocsArr = groupDocs.toArray(new GroupDocs[groupDocs.size()]); |
| TopGroups<BytesRef> topGroups = new TopGroups<>( |
| groupSort.getSort(), withinGroupSort.getSort(), totalHitCount, totalGroupedHitCount, groupDocsArr, Float.NaN |
| ); |
| |
| result.put(key, topGroups); |
| } |
| |
| return result; |
| } |
| |
| protected ScoreDoc[] transformToNativeShardDoc(List<NamedList<Object>> documents, Sort groupSort, String shard, |
| IndexSchema schema) { |
| ScoreDoc[] scoreDocs = new ScoreDoc[documents.size()]; |
| int j = 0; |
| for (NamedList<Object> document : documents) { |
| Object docId = document.get(ID); |
| if (docId != null) { |
| docId = docId.toString(); |
| } else { |
| log.error("doc {} has null 'id'", document); |
| } |
| Float score = (Float) document.get("score"); |
| if (score == null) { |
| score = Float.NaN; |
| } |
| Object[] sortValues = null; |
| Object sortValuesVal = document.get("sortValues"); |
| if (sortValuesVal != null) { |
| sortValues = ((List) sortValuesVal).toArray(); |
| for (int k = 0; k < sortValues.length; k++) { |
| SchemaField field = groupSort.getSort()[k].getField() != null |
| ? schema.getFieldOrNull(groupSort.getSort()[k].getField()) : null; |
| sortValues[k] = ShardResultTransformerUtils.unmarshalSortValue(sortValues[k], field); |
| } |
| } else { |
| log.debug("doc {} has null 'sortValues'", document); |
| } |
| scoreDocs[j++] = new ShardDoc(score, sortValues, docId, shard); |
| } |
| return scoreDocs; |
| } |
| |
| protected NamedList serializeTopGroups(TopGroups<BytesRef> data, SchemaField groupField) throws IOException { |
| NamedList<Object> result = new NamedList<>(); |
| result.add("totalGroupedHitCount", data.totalGroupedHitCount); |
| result.add("totalHitCount", data.totalHitCount); |
| if (data.totalGroupCount != null) { |
| result.add("totalGroupCount", data.totalGroupCount); |
| } |
| |
| final IndexSchema schema = rb.req.getSearcher().getSchema(); |
| SchemaField uniqueField = schema.getUniqueKeyField(); |
| for (GroupDocs<BytesRef> searchGroup : data.groups) { |
| NamedList<Object> groupResult = new NamedList<>(); |
| assert searchGroup.totalHits.relation == TotalHits.Relation.EQUAL_TO; |
| groupResult.add("totalHits", searchGroup.totalHits.value); |
| if (!Float.isNaN(searchGroup.maxScore)) { |
| groupResult.add("maxScore", searchGroup.maxScore); |
| } |
| |
| List<NamedList<Object>> documents = new ArrayList<>(); |
| for (int i = 0; i < searchGroup.scoreDocs.length; i++) { |
| NamedList<Object> document = new NamedList<>(); |
| documents.add(document); |
| |
| Document doc = retrieveDocument(uniqueField, searchGroup.scoreDocs[i].doc); |
| document.add(ID, uniqueField.getType().toExternal(doc.getField(uniqueField.getName()))); |
| if (!Float.isNaN(searchGroup.scoreDocs[i].score)) { |
| document.add("score", searchGroup.scoreDocs[i].score); |
| } |
| if (!(searchGroup.scoreDocs[i] instanceof FieldDoc)) { |
| continue; // thus don't add sortValues below |
| } |
| |
| FieldDoc fieldDoc = (FieldDoc) searchGroup.scoreDocs[i]; |
| Object[] convertedSortValues = new Object[fieldDoc.fields.length]; |
| for (int j = 0; j < fieldDoc.fields.length; j++) { |
| Object sortValue = fieldDoc.fields[j]; |
| Sort withinGroupSort = rb.getGroupingSpec().getWithinGroupSortSpec().getSort(); |
| SchemaField field = withinGroupSort.getSort()[j].getField() != null ? schema.getFieldOrNull(withinGroupSort.getSort()[j].getField()) : null; |
| if (field != null) { |
| FieldType fieldType = field.getType(); |
| if (sortValue != null) { |
| sortValue = fieldType.marshalSortValue(sortValue); |
| } |
| } |
| convertedSortValues[j] = sortValue; |
| } |
| document.add("sortValues", convertedSortValues); |
| } |
| groupResult.add("documents", documents); |
| String groupValue = searchGroup.groupValue != null ? |
| groupField.getType().indexedToReadable(searchGroup.groupValue, new CharsRefBuilder()).toString(): null; |
| result.add(groupValue, groupResult); |
| } |
| |
| return result; |
| } |
| |
| protected NamedList serializeTopDocs(QueryCommandResult result) throws IOException { |
| NamedList<Object> queryResult = new NamedList<>(); |
| queryResult.add("matches", result.getMatches()); |
| TopDocs topDocs = result.getTopDocs(); |
| assert topDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO; |
| queryResult.add("totalHits", topDocs.totalHits.value); |
| // debug: assert !Float.isNaN(result.getTopDocs().getMaxScore()) == rb.getGroupingSpec().isNeedScore(); |
| if (!Float.isNaN(result.getMaxScore())) { |
| queryResult.add("maxScore", result.getMaxScore()); |
| } |
| List<NamedList> documents = new ArrayList<>(); |
| queryResult.add("documents", documents); |
| |
| final IndexSchema schema = rb.req.getSearcher().getSchema(); |
| SchemaField uniqueField = schema.getUniqueKeyField(); |
| for (ScoreDoc scoreDoc : result.getTopDocs().scoreDocs) { |
| NamedList<Object> document = new NamedList<>(); |
| documents.add(document); |
| |
| Document doc = retrieveDocument(uniqueField, scoreDoc.doc); |
| document.add(ID, uniqueField.getType().toExternal(doc.getField(uniqueField.getName()))); |
| if (!Float.isNaN(scoreDoc.score)) { |
| document.add("score", scoreDoc.score); |
| } |
| if (!FieldDoc.class.isInstance(scoreDoc)) { |
| continue; // thus don't add sortValues below |
| } |
| |
| FieldDoc fieldDoc = (FieldDoc) scoreDoc; |
| Object[] convertedSortValues = new Object[fieldDoc.fields.length]; |
| for (int j = 0; j < fieldDoc.fields.length; j++) { |
| Object sortValue = fieldDoc.fields[j]; |
| Sort groupSort = rb.getGroupingSpec().getGroupSortSpec().getSort(); |
| SchemaField field = groupSort.getSort()[j].getField() != null |
| ? schema.getFieldOrNull(groupSort.getSort()[j].getField()) : null; |
| convertedSortValues[j] = ShardResultTransformerUtils.marshalSortValue(sortValue, field); |
| } |
| document.add("sortValues", convertedSortValues); |
| } |
| |
| return queryResult; |
| } |
| |
| private Document retrieveDocument(final SchemaField uniqueField, int doc) throws IOException { |
| return rb.req.getSearcher().doc(doc, Collections.singleton(uniqueField.getName())); |
| } |
| |
| } |