// blob: 047c31ed3d787ecaf62ea898bb515332269362df [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search.grouping.distributed.shardresultserializer;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRefBuilder;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.ShardDoc;
import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.grouping.Command;
import org.apache.solr.search.grouping.distributed.command.QueryCommand;
import org.apache.solr.search.grouping.distributed.command.QueryCommandResult;
import org.apache.solr.search.grouping.distributed.command.TopGroupsFieldCommand;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.apache.solr.common.params.CommonParams.ID;
/**
 * Implementation for transforming {@link TopGroups} and {@link TopDocs} into a {@link NamedList} structure and
 * vice versa.
 *
 * <p>{@link #transform(List)} serializes the results of executed grouping commands into nested
 * {@link NamedList}s, and {@link #transformToNative(NamedList, Sort, Sort, String)} rebuilds native
 * objects from such a structure. Both directions share the same entry names
 * ({@code "totalGroupedHitCount"}, {@code "totalHitCount"}, {@code "totalGroupCount"}, {@code "totalHits"},
 * {@code "matches"}, {@code "maxScore"}, {@code "documents"}, {@code "score"}, {@code "sortValues"} and
 * {@link org.apache.solr.common.params.CommonParams#ID}), so any change to one side must be mirrored
 * on the other.
 */
@SuppressWarnings({"rawtypes"})
public class TopGroupsResultTransformer implements ShardResultTransformer<List<Command>, Map<String, ?>> {

  /** Gives access to the current request's searcher, schema and grouping specification. */
  private final ResponseBuilder rb;

  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  /**
   * @param rb the response builder of the request whose grouping results are being transformed
   */
  public TopGroupsResultTransformer(ResponseBuilder rb) {
    this.rb = rb;
  }

  /**
   * Serializes the result of every command into a {@link NamedList}, keyed by the command's key.
   *
   * <p>{@link TopGroupsFieldCommand}s are serialized with {@link #serializeTopGroups}, and
   * {@link QueryCommand}s with {@link #serializeTopDocs}. Any other command type is added with a
   * {@code null} value.
   *
   * @param data the executed commands
   * @return one entry per command, in the iteration order of {@code data}
   * @throws IOException if a stored document cannot be read while serializing
   */
  @Override
  public NamedList transform(List<Command> data) throws IOException {
    NamedList<NamedList> result = new NamedList<>();
    final IndexSchema schema = rb.req.getSearcher().getSchema();
    for (Command command : data) {
      final NamedList commandResult;
      if (command instanceof TopGroupsFieldCommand) {
        TopGroupsFieldCommand fieldCommand = (TopGroupsFieldCommand) command;
        SchemaField groupField = schema.getField(fieldCommand.getKey());
        commandResult = serializeTopGroups(fieldCommand.result(), groupField);
      } else if (command instanceof QueryCommand) {
        QueryCommand queryCommand = (QueryCommand) command;
        commandResult = serializeTopDocs(queryCommand.result());
      } else {
        commandResult = null;
      }
      result.add(command.getKey(), commandResult);
    }
    return result;
  }

  /**
   * Rebuilds native results from a structure produced by {@link #transform(List)}.
   *
   * <p>An entry that carries a numeric {@code "totalHits"} value is treated as a query command result and
   * becomes a {@link QueryCommandResult}. Every other non-null entry is treated as a field command result
   * and becomes a {@link TopGroups}. Entries with a {@code null} value (which {@link #transform(List)}
   * emits for unsupported command types) are skipped.
   *
   * @param shardResponse   the serialized command results, keyed by command key
   * @param groupSort       the sort between groups
   * @param withinGroupSort the sort inside a group
   * @param shard           the shard identifier recorded on every created {@link ShardDoc}
   * @return a map from command key to {@link QueryCommandResult} or {@link TopGroups}
   */
  @Override
  public Map<String, ?> transformToNative(NamedList<NamedList> shardResponse, Sort groupSort, Sort withinGroupSort, String shard) {
    Map<String, Object> result = new HashMap<>();
    final IndexSchema schema = rb.req.getSearcher().getSchema();

    for (Map.Entry<String, NamedList> entry : shardResponse) {
      final String key = entry.getKey();
      final NamedList commandResult = entry.getValue();
      if (commandResult == null) {
        // Nothing was serialized for this command; there is nothing to rebuild.
        continue;
      }

      // Group entries are stored under their readable group value, so a group that happens to be named
      // "totalHits" would also be returned here. Requiring a Number keeps such a response on the
      // grouped path instead of failing with a ClassCastException.
      final Object totalHits = commandResult.get("totalHits"); // previously Integer now Long
      if (totalHits instanceof Number) {
        result.put(key, toQueryCommandResult(commandResult, (Number) totalHits, groupSort, withinGroupSort, shard, schema));
      } else {
        result.put(key, toTopGroups(commandResult, groupSort, withinGroupSort, shard, schema));
      }
    }
    return result;
  }

  /**
   * Rebuilds a {@link QueryCommandResult} from the structure written by {@link #serializeTopDocs}.
   *
   * <p>NOTE(review): the documents' sort values are unmarshalled against {@code groupSort}, whereas the
   * choice between {@link TopDocs} and {@link TopFieldDocs} is driven by {@code withinGroupSort}. This
   * mirrors the previous behaviour and is kept because consumers outside this file may rely on it;
   * confirm before changing.
   */
  private QueryCommandResult toQueryCommandResult(NamedList commandResult, Number totalHits, Sort groupSort,
                                                  Sort withinGroupSort, String shard, IndexSchema schema) {
    Integer matches = (Integer) commandResult.get("matches");
    Float maxScore = (Float) commandResult.get("maxScore");
    if (maxScore == null) {
      // "maxScore" is omitted by the serializer when it is NaN
      maxScore = Float.NaN;
    }

    @SuppressWarnings("unchecked")
    List<NamedList<Object>> documents = (List<NamedList<Object>>) commandResult.get("documents");
    ScoreDoc[] scoreDocs = transformToNativeShardDoc(documents, groupSort, shard, schema);

    final TotalHits hits = new TotalHits(totalHits.longValue(), TotalHits.Relation.EQUAL_TO);
    final TopDocs topDocs;
    if (withinGroupSort.equals(Sort.RELEVANCE)) {
      topDocs = new TopDocs(hits, scoreDocs);
    } else {
      topDocs = new TopFieldDocs(hits, scoreDocs, withinGroupSort.getSort());
    }
    return new QueryCommandResult(topDocs, matches, maxScore);
  }

  /**
   * Rebuilds a {@link TopGroups} from the structure written by {@link #serializeTopGroups}.
   *
   * <p>Group entries are recognised by their value being a {@link NamedList}; the scalar metadata
   * entries ({@code "totalGroupedHitCount"}, {@code "totalHitCount"} and the optional
   * {@code "totalGroupCount"}) are skipped regardless of how many of them precede the groups.
   */
  private TopGroups<BytesRef> toTopGroups(NamedList commandResult, Sort groupSort, Sort withinGroupSort,
                                          String shard, IndexSchema schema) {
    Integer totalGroupedHitCount = (Integer) commandResult.get("totalGroupedHitCount");
    Integer totalHitCount = (Integer) commandResult.get("totalHitCount");

    List<GroupDocs<BytesRef>> groupDocs = new ArrayList<>();
    for (int i = 0; i < commandResult.size(); i++) {
      final Object value = commandResult.getVal(i);
      if (!(value instanceof NamedList)) {
        continue; // metadata entry, not a group
      }
      String groupValue = commandResult.getName(i);
      @SuppressWarnings("unchecked")
      NamedList<Object> groupResult = (NamedList<Object>) value;

      Number totalGroupHits = (Number) groupResult.get("totalHits"); // previously Integer now Long
      Float maxScore = (Float) groupResult.get("maxScore");
      if (maxScore == null) {
        // "maxScore" is omitted by the serializer when it is NaN
        maxScore = Float.NaN;
      }

      @SuppressWarnings("unchecked")
      List<NamedList<Object>> documents = (List<NamedList<Object>>) groupResult.get("documents");
      ScoreDoc[] scoreDocs = transformToNativeShardDoc(documents, withinGroupSort, shard, schema);

      // The null group is serialized under a null name.
      BytesRef groupValueRef = groupValue != null ? new BytesRef(groupValue) : null;
      groupDocs.add(new GroupDocs<>(Float.NaN, maxScore,
          new TotalHits(totalGroupHits.longValue(), TotalHits.Relation.EQUAL_TO), scoreDocs, groupValueRef, null));
    }

    @SuppressWarnings({"unchecked"})
    GroupDocs<BytesRef>[] groupDocsArr = groupDocs.toArray(new GroupDocs[groupDocs.size()]);
    return new TopGroups<>(
        groupSort.getSort(), withinGroupSort.getSort(), totalHitCount, totalGroupedHitCount, groupDocsArr, Float.NaN
    );
  }

  /**
   * Converts serialized documents into {@link ShardDoc}s.
   *
   * <p>A document without an id is logged at error level and still converted (with a {@code null} id).
   * A missing {@code "score"} becomes {@link Float#NaN}, and a missing {@code "sortValues"} entry leaves
   * the sort values {@code null}.
   *
   * @param documents the serialized documents
   * @param groupSort the sort whose fields correspond positionally to each document's sort values; callers
   *                  in this class pass the group sort for query results and the within-group sort for
   *                  grouped documents
   * @param shard     the shard identifier to record on every {@link ShardDoc}
   * @param schema    the schema used to resolve sort fields
   * @return one {@link ShardDoc} per input document, in input order
   */
  protected ScoreDoc[] transformToNativeShardDoc(List<NamedList<Object>> documents, Sort groupSort, String shard,
                                                 IndexSchema schema) {
    ScoreDoc[] scoreDocs = new ScoreDoc[documents.size()];
    int j = 0;
    for (NamedList<Object> document : documents) {
      Object docId = document.get(ID);
      if (docId != null) {
        docId = docId.toString();
      } else {
        log.error("doc {} has null 'id'", document);
      }

      Float score = (Float) document.get("score");
      if (score == null) {
        score = Float.NaN;
      }

      Object[] sortValues = null;
      Object sortValuesVal = document.get("sortValues");
      if (sortValuesVal != null) {
        // NOTE(review): the serializers in this class write an Object[]; the cast presumes it is seen
        // as a List on this side - confirm against the transport format.
        sortValues = ((List) sortValuesVal).toArray();
        for (int k = 0; k < sortValues.length; k++) {
          SchemaField field = sortFieldOrNull(schema, groupSort, k);
          sortValues[k] = ShardResultTransformerUtils.unmarshalSortValue(sortValues[k], field);
        }
      } else {
        log.debug("doc {} has null 'sortValues'", document);
      }
      scoreDocs[j++] = new ShardDoc(score, sortValues, docId, shard);
    }
    return scoreDocs;
  }

  /**
   * Serializes grouped results.
   *
   * <p>The returned list starts with {@code "totalGroupedHitCount"} and {@code "totalHitCount"}, followed
   * by {@code "totalGroupCount"} when it is set, followed by one entry per group. Each group is named by
   * its readable group value ({@code null} for the null group) and holds {@code "totalHits"}, an optional
   * {@code "maxScore"} and its {@code "documents"}.
   *
   * @param data       the grouped results to serialize
   * @param groupField the field that was grouped on, used to render group values
   * @return the serialized form
   * @throws IOException if a stored document cannot be read
   */
  protected NamedList serializeTopGroups(TopGroups<BytesRef> data, SchemaField groupField) throws IOException {
    NamedList<Object> result = new NamedList<>();
    result.add("totalGroupedHitCount", data.totalGroupedHitCount);
    result.add("totalHitCount", data.totalHitCount);
    if (data.totalGroupCount != null) {
      result.add("totalGroupCount", data.totalGroupCount);
    }

    final IndexSchema schema = rb.req.getSearcher().getSchema();
    SchemaField uniqueField = schema.getUniqueKeyField();
    for (GroupDocs<BytesRef> searchGroup : data.groups) {
      NamedList<Object> groupResult = new NamedList<>();
      assert searchGroup.totalHits.relation == TotalHits.Relation.EQUAL_TO;
      groupResult.add("totalHits", searchGroup.totalHits.value);
      if (!Float.isNaN(searchGroup.maxScore)) {
        groupResult.add("maxScore", searchGroup.maxScore);
      }

      List<NamedList<Object>> documents = new ArrayList<>(searchGroup.scoreDocs.length);
      for (ScoreDoc scoreDoc : searchGroup.scoreDocs) {
        NamedList<Object> document = newDocumentEntry(uniqueField, scoreDoc);
        documents.add(document);
        if (!(scoreDoc instanceof FieldDoc)) {
          continue; // thus don't add sortValues below
        }

        FieldDoc fieldDoc = (FieldDoc) scoreDoc;
        // Resolved once per document rather than once per sort value.
        Sort withinGroupSort = rb.getGroupingSpec().getWithinGroupSortSpec().getSort();
        Object[] convertedSortValues = new Object[fieldDoc.fields.length];
        for (int j = 0; j < fieldDoc.fields.length; j++) {
          Object sortValue = fieldDoc.fields[j];
          SchemaField field = sortFieldOrNull(schema, withinGroupSort, j);
          if (field != null && sortValue != null) {
            FieldType fieldType = field.getType();
            sortValue = fieldType.marshalSortValue(sortValue);
          }
          convertedSortValues[j] = sortValue;
        }
        document.add("sortValues", convertedSortValues);
      }
      groupResult.add("documents", documents);

      String groupValue = searchGroup.groupValue != null
          ? groupField.getType().indexedToReadable(searchGroup.groupValue, new CharsRefBuilder()).toString()
          : null;
      result.add(groupValue, groupResult);
    }
    return result;
  }

  /**
   * Serializes the result of a query command.
   *
   * <p>The returned list holds {@code "matches"}, {@code "totalHits"}, {@code "maxScore"} (omitted when it
   * is NaN) and {@code "documents"}, in that order.
   *
   * @param result the query command result to serialize
   * @return the serialized form
   * @throws IOException if a stored document cannot be read
   */
  protected NamedList serializeTopDocs(QueryCommandResult result) throws IOException {
    NamedList<Object> queryResult = new NamedList<>();
    queryResult.add("matches", result.getMatches());
    TopDocs topDocs = result.getTopDocs();
    assert topDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO;
    queryResult.add("totalHits", topDocs.totalHits.value);
    if (!Float.isNaN(result.getMaxScore())) {
      queryResult.add("maxScore", result.getMaxScore());
    }

    List<NamedList> documents = new ArrayList<>();
    queryResult.add("documents", documents);

    final IndexSchema schema = rb.req.getSearcher().getSchema();
    SchemaField uniqueField = schema.getUniqueKeyField();
    for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
      NamedList<Object> document = newDocumentEntry(uniqueField, scoreDoc);
      documents.add(document);
      if (!(scoreDoc instanceof FieldDoc)) {
        continue; // thus don't add sortValues below
      }

      FieldDoc fieldDoc = (FieldDoc) scoreDoc;
      // Resolved once per document rather than once per sort value.
      Sort groupSort = rb.getGroupingSpec().getGroupSortSpec().getSort();
      Object[] convertedSortValues = new Object[fieldDoc.fields.length];
      for (int j = 0; j < fieldDoc.fields.length; j++) {
        SchemaField field = sortFieldOrNull(schema, groupSort, j);
        convertedSortValues[j] = ShardResultTransformerUtils.marshalSortValue(fieldDoc.fields[j], field);
      }
      document.add("sortValues", convertedSortValues);
    }
    return queryResult;
  }

  /**
   * Creates the serialized entry of a single hit: its external unique key under
   * {@link org.apache.solr.common.params.CommonParams#ID} and, unless it is NaN, its {@code "score"}.
   */
  private NamedList<Object> newDocumentEntry(SchemaField uniqueField, ScoreDoc scoreDoc) throws IOException {
    NamedList<Object> document = new NamedList<>();
    Document doc = retrieveDocument(uniqueField, scoreDoc.doc);
    document.add(ID, uniqueField.getType().toExternal(doc.getField(uniqueField.getName())));
    if (!Float.isNaN(scoreDoc.score)) {
      document.add("score", scoreDoc.score);
    }
    return document;
  }

  /**
   * Returns the schema field behind the {@code index}-th criterion of {@code sort}, or {@code null} when
   * that criterion has no field name or the schema does not know the field.
   */
  private static SchemaField sortFieldOrNull(IndexSchema schema, Sort sort, int index) {
    final String fieldName = sort.getSort()[index].getField();
    return fieldName != null ? schema.getFieldOrNull(fieldName) : null;
  }

  /** Loads the document with the given internal id, requesting only the unique key field. */
  private Document retrieveDocument(final SchemaField uniqueField, int doc) throws IOException {
    return rb.req.getSearcher().doc(doc, Collections.singleton(uniqueField.getName()));
  }

}