blob: d24fca5a286305dadb6b84d035c48cb6f9f24ab4 [file] [log] [blame]
package org.apache.solr.search.grouping.distributed.shardresultserializer;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldSelector;
import org.apache.lucene.document.FieldSelectorResult;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.ShardDoc;
import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.grouping.Command;
import org.apache.solr.search.grouping.distributed.command.QueryCommand;
import org.apache.solr.search.grouping.distributed.command.QueryCommandResult;
import org.apache.solr.search.grouping.distributed.command.TopGroupsFieldCommand;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Implementation for transforming {@link TopGroups} and {@link TopDocs} into a {@link NamedList} structure and
 * vice versa.
 * <p>
 * On the shard side {@link #transform(List)} serializes every command result into a {@link NamedList}. On the
 * coordinating side {@link #transformToNative(NamedList, Sort, Sort, String)} rebuilds them:
 * <ul>
 *   <li>field commands become {@link TopGroups};</li>
 *   <li>query commands become {@link QueryCommandResult}s.</li>
 * </ul>
 * In both cases the documents are {@link ShardDoc}s tagged with the shard they came from.
 */
public class TopGroupsResultTransformer implements ShardResultTransformer<List<Command>, Map<String, ?>> {

  private final ResponseBuilder rb;

  /**
   * @param rb the response builder of the current request; used for the searcher, the schema and the
   *           grouping spec
   */
  public TopGroupsResultTransformer(ResponseBuilder rb) {
    this.rb = rb;
  }

  /**
   * {@inheritDoc}
   * <p>
   * Commands that are neither a {@link TopGroupsFieldCommand} nor a {@link QueryCommand} are added under their
   * key with a {@code null} value.
   */
  public NamedList transform(List<Command> data) throws IOException {
    NamedList<NamedList> result = new NamedList<NamedList>();
    for (Command command : data) {
      NamedList commandResult;
      if (command instanceof TopGroupsFieldCommand) {
        TopGroupsFieldCommand fieldCommand = (TopGroupsFieldCommand) command;
        SchemaField groupField = rb.req.getSearcher().getSchema().getField(fieldCommand.getKey());
        commandResult = serializeTopGroups(fieldCommand.result(), groupField);
      } else if (command instanceof QueryCommand) {
        QueryCommand queryCommand = (QueryCommand) command;
        commandResult = serializeTopDocs(queryCommand.result());
      } else {
        commandResult = null;
      }
      result.add(command.getKey(), commandResult);
    }
    return result;
  }

  /**
   * {@inheritDoc}
   * <p>
   * An entry with an integer {@code totalHits} value is rebuilt as a {@link QueryCommandResult}; any other
   * non-null entry is rebuilt as {@link TopGroups}. Null entries, which {@link #transform(List)} emits for
   * unsupported commands, are skipped.
   */
  public Map<String, ?> transformToNative(NamedList<NamedList> shardResponse, Sort groupSort, Sort sortWithinGroup, String shard) {
    Map<String, Object> result = new HashMap<String, Object>();
    for (Map.Entry<String, NamedList> entry : shardResponse) {
      String key = entry.getKey();
      NamedList commandResult = entry.getValue();
      if (commandResult == null) {
        // Nothing was serialized for this command; dereferencing it would throw an NPE.
        continue;
      }
      // Only query command results carry an integer "totalHits" header. A group whose value happens to be
      // "totalHits" maps to a NamedList and is therefore not mistaken for it.
      if (getIntegerOrNull(commandResult, "totalHits") != null) {
        result.put(key, deserializeQueryCommandResult(commandResult, shard));
      } else {
        result.put(key, deserializeTopGroups(commandResult, groupSort, sortWithinGroup, shard));
      }
    }
    return result;
  }

  /**
   * Rebuilds a {@link QueryCommandResult} from the structure written by {@link #serializeTopDocs}.
   */
  private QueryCommandResult deserializeQueryCommandResult(NamedList commandResult, String shard) {
    Integer totalHits = (Integer) commandResult.get("totalHits");
    Integer matches = (Integer) commandResult.get("matches");
    Float maxScore = (Float) commandResult.get("maxScore");
    if (maxScore == null) {
      // maxScore is only serialized when scores were requested.
      maxScore = Float.NaN;
    }
    @SuppressWarnings("unchecked")
    List<NamedList<Object>> documents = (List<NamedList<Object>>) commandResult.get("documents");
    ScoreDoc[] scoreDocs = deserializeScoreDocs(documents, shard);
    return new QueryCommandResult(new TopDocs(totalHits, scoreDocs, maxScore), matches);
  }

  /**
   * Rebuilds {@link TopGroups} from the structure written by {@link #serializeTopGroups}.
   */
  private TopGroups<String> deserializeTopGroups(NamedList commandResult, Sort groupSort, Sort sortWithinGroup, String shard) {
    Integer totalGroupedHitCount = (Integer) commandResult.get("totalGroupedHitCount");
    Integer totalHitCount = (Integer) commandResult.get("totalHitCount");
    // The instanceof check guards against a group literally named "totalGroupCount" when the header is absent.
    Integer totalGroupCount = getIntegerOrNull(commandResult, "totalGroupCount");

    // serializeTopGroups writes two header entries (three when totalGroupCount is present) before the groups.
    int firstGroupIndex = totalGroupCount == null ? 2 : 3;
    List<GroupDocs<String>> groupDocs = new ArrayList<GroupDocs<String>>();
    for (int i = firstGroupIndex; i < commandResult.size(); i++) {
      String groupValue = commandResult.getName(i);
      @SuppressWarnings("unchecked")
      NamedList<Object> groupResult = (NamedList<Object>) commandResult.getVal(i);
      Integer totalGroupHits = (Integer) groupResult.get("totalHits");
      Float maxScore = (Float) groupResult.get("maxScore");
      if (maxScore == null) {
        // maxScore is omitted by the serializer when it is NaN.
        maxScore = Float.NaN;
      }
      @SuppressWarnings("unchecked")
      List<NamedList<Object>> documents = (List<NamedList<Object>>) groupResult.get("documents");
      ScoreDoc[] scoreDocs = deserializeScoreDocs(documents, shard);
      groupDocs.add(new GroupDocs<String>(maxScore, totalGroupHits, scoreDocs, groupValue, null));
    }

    @SuppressWarnings("unchecked")
    GroupDocs<String>[] groupDocsArr = groupDocs.toArray(new GroupDocs[groupDocs.size()]);
    TopGroups<String> topGroups = new TopGroups<String>(
        groupSort.getSort(), sortWithinGroup.getSort(), totalHitCount, totalGroupedHitCount, groupDocsArr
    );
    if (totalGroupCount != null) {
      topGroups = new TopGroups<String>(topGroups, totalGroupCount);
    }
    return topGroups;
  }

  /**
   * Converts serialized documents back into {@link ShardDoc}s tagged with {@code shard}.
   * <p>
   * Missing entries do not throw:
   * <ul>
   *   <li>a missing score becomes {@code NaN};</li>
   *   <li>a missing id becomes {@code null};</li>
   *   <li>missing sort values (written only for {@link FieldDoc} hits) become {@code null}.</li>
   * </ul>
   */
  private ScoreDoc[] deserializeScoreDocs(List<NamedList<Object>> documents, String shard) {
    ScoreDoc[] scoreDocs = new ScoreDoc[documents.size()];
    int j = 0;
    for (NamedList<Object> document : documents) {
      Object id = document.get("id");
      Object uniqueId = id == null ? null : id.toString();
      Float score = (Float) document.get("score");
      if (score == null) {
        score = Float.NaN;
      }
      Object[] sortValues = toSortValuesArray(document.get("sortValues"));
      scoreDocs[j++] = new ShardDoc(score, sortValues, uniqueId, shard);
    }
    return scoreDocs;
  }

  /**
   * Normalizes a serialized "sortValues" entry to an array.
   * <p>
   * The serializer stores an {@code Object[]}; after transport it may arrive as a {@link List}. Returns
   * {@code null} when the entry is absent.
   */
  private static Object[] toSortValuesArray(Object sortValues) {
    if (sortValues == null) {
      return null;
    }
    if (sortValues instanceof Object[]) {
      return (Object[]) sortValues;
    }
    return ((List<?>) sortValues).toArray();
  }

  /**
   * Returns the first value named {@code name} if it is an {@link Integer}, otherwise {@code null}.
   */
  private static Integer getIntegerOrNull(NamedList list, String name) {
    Object value = list.get(name);
    return value instanceof Integer ? (Integer) value : null;
  }

  /**
   * Serializes {@link TopGroups} into a {@link NamedList}.
   * <p>
   * The result starts with the headers {@code totalGroupedHitCount}, {@code totalHitCount} and, when present,
   * {@code totalGroupCount}. Each group follows as an entry keyed by its readable group value.
   *
   * @param data       the groups to serialize
   * @param groupField the schema field used to convert indexed group values to readable form
   * @return the serialized groups
   * @throws IOException if a stored document cannot be loaded
   */
  protected NamedList serializeTopGroups(TopGroups<String> data, SchemaField groupField) throws IOException {
    NamedList<Object> result = new NamedList<Object>();
    // transformToNative relies on these headers preceding the groups, and on their count.
    result.add("totalGroupedHitCount", data.totalGroupedHitCount);
    result.add("totalHitCount", data.totalHitCount);
    if (data.totalGroupCount != null) {
      result.add("totalGroupCount", data.totalGroupCount);
    }

    SchemaField uniqueField = rb.req.getSearcher().getSchema().getUniqueKeyField();
    Sort sortWithinGroup = rb.getGroupingSpec().getSortWithinGroup();
    for (GroupDocs<String> searchGroup : data.groups) {
      NamedList<Object> groupResult = new NamedList<Object>();
      groupResult.add("totalHits", searchGroup.totalHits);
      if (!Float.isNaN(searchGroup.maxScore)) {
        groupResult.add("maxScore", searchGroup.maxScore);
      }

      List<NamedList<Object>> documents = new ArrayList<NamedList<Object>>();
      for (ScoreDoc scoreDoc : searchGroup.scoreDocs) {
        NamedList<Object> document = new NamedList<Object>();
        documents.add(document);
        Document doc = retrieveDocument(uniqueField, scoreDoc.doc);
        document.add("id", uniqueField.getType().toObject(doc.getFieldable(uniqueField.getName())));
        if (!Float.isNaN(scoreDoc.score)) {
          document.add("score", scoreDoc.score);
        }
        // Only field-sorted hits carry sort values; plain ScoreDocs are sent without them.
        if (scoreDoc instanceof FieldDoc) {
          document.add("sortValues", convertSortValues((FieldDoc) scoreDoc, sortWithinGroup));
        }
      }
      groupResult.add("documents", documents);

      String groupValue = searchGroup.groupValue != null ? groupField.getType().indexedToReadable(searchGroup.groupValue) : null;
      result.add(groupValue, groupResult);
    }
    return result;
  }

  /**
   * Serializes a query command's {@link QueryCommandResult} into a {@link NamedList}.
   * <p>
   * The result holds {@code matches} and {@code totalHits}, plus {@code maxScore} when scores are needed. It
   * also holds a {@code documents} list with each hit's id, its score (when needed) and its sort values (for
   * {@link FieldDoc} hits).
   *
   * @param result the query command result to serialize
   * @return the serialized result
   * @throws IOException if a stored document cannot be loaded
   */
  protected NamedList serializeTopDocs(QueryCommandResult result) throws IOException {
    TopDocs topDocs = result.getTopDocs();
    boolean needScore = rb.getGroupingSpec().isNeedScore();

    NamedList<Object> queryResult = new NamedList<Object>();
    queryResult.add("matches", result.getMatches());
    queryResult.add("totalHits", topDocs.totalHits);
    if (needScore) {
      queryResult.add("maxScore", topDocs.getMaxScore());
    }
    List<NamedList> documents = new ArrayList<NamedList>();
    queryResult.add("documents", documents);

    SchemaField uniqueField = rb.req.getSearcher().getSchema().getUniqueKeyField();
    // Query commands are sorted by the group sort, so their FieldDoc values line up with it.
    Sort groupSort = rb.getGroupingSpec().getGroupSort();
    for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
      NamedList<Object> document = new NamedList<Object>();
      documents.add(document);
      Document doc = retrieveDocument(uniqueField, scoreDoc.doc);
      document.add("id", uniqueField.getType().toObject(doc.getFieldable(uniqueField.getName())));
      if (needScore) {
        document.add("score", scoreDoc.score);
      }
      if (scoreDoc instanceof FieldDoc) {
        document.add("sortValues", convertSortValues((FieldDoc) scoreDoc, groupSort));
      }
    }
    return queryResult;
  }

  /**
   * Converts a hit's sort values into external (readable) objects.
   * <p>
   * A value is converted only when it is a {@link String} and {@code sort}'s field at the same position
   * resolves to a schema field. That string is assumed to be in indexed form. All other values are copied
   * unchanged.
   */
  private Object[] convertSortValues(FieldDoc fieldDoc, Sort sort) {
    Object[] convertedSortValues = new Object[fieldDoc.fields.length];
    for (int j = 0; j < fieldDoc.fields.length; j++) {
      Object sortValue = fieldDoc.fields[j];
      String sortFieldName = sort.getSort()[j].getField();
      SchemaField field = sortFieldName != null ? rb.req.getSearcher().getSchema().getFieldOrNull(sortFieldName) : null;
      if (field != null && sortValue instanceof String) {
        FieldType fieldType = field.getType();
        sortValue = fieldType.toObject(field.createField(fieldType.indexedToReadable((String) sortValue), 0.0f));
      }
      convertedSortValues[j] = sortValue;
    }
    return convertedSortValues;
  }

  /**
   * Loads only the unique-key field of the stored document {@code doc}.
   * <p>
   * Loading stops as soon as that field has been read.
   */
  private Document retrieveDocument(final SchemaField uniqueField, int doc) throws IOException {
    FieldSelector fieldSelectorVisitor = new FieldSelector() {
      public FieldSelectorResult accept(String fieldName) {
        if (uniqueField.getName().equals(fieldName)) {
          return FieldSelectorResult.LOAD_AND_BREAK;
        }
        return FieldSelectorResult.NO_LOAD;
      }
    };
    return rb.req.getSearcher().doc(doc, fieldSelectorVisitor);
  }
}