blob: 46c16500e839375783abf30dd5b08d793a9d8ad6 [file] [log] [blame]
package org.apache.solr.search.grouping.distributed.responseprocessor;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.ShardDoc;
import org.apache.solr.handler.component.ShardRequest;
import org.apache.solr.handler.component.ShardResponse;
import org.apache.solr.search.Grouping;
import org.apache.solr.search.grouping.distributed.ShardResponseProcessor;
import org.apache.solr.search.grouping.distributed.command.QueryCommandResult;
import org.apache.solr.search.grouping.distributed.shardresultserializer.TopGroupsResultTransformer;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Concrete implementation for merging {@link TopGroups} instances from shard responses.
*/
public class TopGroupsShardResponseProcessor implements ShardResponseProcessor {

  /**
   * {@inheritDoc}
   * <p>
   * Reads the {@code "secondPhase"} section of every shard response in {@code shardRequest}
   * and deserializes it with {@link TopGroupsResultTransformer}. It then merges the results per
   * grouping command:
   * <ul>
   *   <li>field commands (from {@code getFields()}) are merged with {@link TopGroups#merge} into
   *       {@code rb.mergedTopGroups};</li>
   *   <li>query commands (from {@code getQueries()}) are merged with {@link TopDocs#merge} into
   *       {@code rb.mergedQueryCommandResults}.</li>
   * </ul>
   * Finally, {@code rb.resultIds} is replaced with a map from document id to {@link ShardDoc}.
   * Each {@code ShardDoc}'s {@code positionInResponse} is set from a running counter.
   *
   * @param rb           response builder providing the grouping spec and receiving the merged results
   * @param shardRequest the shard request whose responses are merged
   * @throws SolrException with {@code SERVER_ERROR} if merging throws an {@link IOException}
   */
  @Override
  @SuppressWarnings("unchecked")
  public void process(ResponseBuilder rb, ShardRequest shardRequest) {
    Sort groupSort = rb.getGroupingSpec().getGroupSort();
    String[] fields = rb.getGroupingSpec().getFields();
    String[] queries = rb.getGroupingSpec().getQueries();
    Sort sortWithinGroup = rb.getGroupingSpec().getSortWithinGroup();

    // If group.format=simple (or isMain() is set), group.offset doesn't make sense, so no offset is applied.
    int groupOffsetDefault;
    if (rb.getGroupingSpec().getResponseFormat() == Grouping.Format.simple || rb.getGroupingSpec().isMain()) {
      groupOffsetDefault = 0;
    } else {
      groupOffsetDefault = rb.getGroupingSpec().getGroupOffset();
    }
    int docsPerGroupDefault = rb.getGroupingSpec().getGroupLimit();

    // Per-command accumulators, keyed by field / query string. Each command appears once even if
    // it was requested more than once.
    Map<String, List<TopGroups<BytesRef>>> commandTopGroups = new HashMap<String, List<TopGroups<BytesRef>>>();
    for (String field : fields) {
      commandTopGroups.put(field, new ArrayList<TopGroups<BytesRef>>());
    }
    Map<String, List<QueryCommandResult>> commandTopDocs = new HashMap<String, List<QueryCommandResult>>();
    for (String query : queries) {
      commandTopDocs.put(query, new ArrayList<QueryCommandResult>());
    }

    TopGroupsResultTransformer serializer = new TopGroupsResultTransformer(rb);
    for (ShardResponse srsp : shardRequest.responses) {
      NamedList<NamedList> secondPhaseResult = (NamedList<NamedList>) srsp.getSolrResponse().getResponse().get("secondPhase");
      if (secondPhaseResult == null) {
        // This shard returned no second-phase section, so there is nothing to contribute from it.
        continue;
      }
      Map<String, ?> result = serializer.transformToNative(secondPhaseResult, groupSort, sortWithinGroup, srsp.getShard());
      for (Map.Entry<String, List<TopGroups<BytesRef>>> entry : commandTopGroups.entrySet()) {
        TopGroups<BytesRef> topGroups = (TopGroups<BytesRef>) result.get(entry.getKey());
        if (topGroups != null) {
          entry.getValue().add(topGroups);
        }
      }
      // Iterate the map (not the raw queries array) so a query requested twice is not counted twice.
      for (Map.Entry<String, List<QueryCommandResult>> entry : commandTopDocs.entrySet()) {
        QueryCommandResult queryCommandResult = (QueryCommandResult) result.get(entry.getKey());
        // Skip missing results, as the field loop does; a null entry would cause an NPE during the merge below.
        if (queryCommandResult != null) {
          entry.getValue().add(queryCommandResult);
        }
      }
    }

    try {
      for (Map.Entry<String, List<TopGroups<BytesRef>>> entry : commandTopGroups.entrySet()) {
        List<TopGroups<BytesRef>> topGroups = entry.getValue();
        if (topGroups.isEmpty()) {
          continue;
        }
        TopGroups<BytesRef>[] topGroupsArr = new TopGroups[topGroups.size()];
        rb.mergedTopGroups.put(entry.getKey(), TopGroups.merge(topGroups.toArray(topGroupsArr), groupSort, sortWithinGroup, groupOffsetDefault, docsPerGroupDefault, TopGroups.ScoreMergeMode.None));
      }

      // Same for every query command. Computed in long and saturated so that a very large
      // offset + limit cannot wrap around to a negative int.
      int topN = (int) Math.min(Integer.MAX_VALUE, (long) rb.getGroupingSpec().getOffset() + rb.getGroupingSpec().getLimit());
      for (Map.Entry<String, List<QueryCommandResult>> entry : commandTopDocs.entrySet()) {
        List<QueryCommandResult> queryCommandResults = entry.getValue();
        List<TopDocs> topDocs = new ArrayList<TopDocs>(queryCommandResults.size());
        int mergedMatches = 0;
        for (QueryCommandResult queryCommandResult : queryCommandResults) {
          topDocs.add(queryCommandResult.getTopDocs());
          mergedMatches += queryCommandResult.getMatches();
        }
        TopDocs mergedTopDocs = TopDocs.merge(sortWithinGroup, topN, topDocs.toArray(new TopDocs[topDocs.size()]));
        rb.mergedQueryCommandResults.put(entry.getKey(), new QueryCommandResult(mergedTopDocs, mergedMatches));
      }

      // Register every merged document by id. Positions come from one counter: merged field groups
      // first, then query commands, each in their map's iteration order.
      Map<Object, ShardDoc> resultIds = new HashMap<Object, ShardDoc>();
      int i = 0;
      for (TopGroups<BytesRef> topGroups : rb.mergedTopGroups.values()) {
        for (GroupDocs<BytesRef> group : topGroups.groups) {
          for (ScoreDoc scoreDoc : group.scoreDocs) {
            ShardDoc solrDoc = (ShardDoc) scoreDoc;
            solrDoc.positionInResponse = i++;
            resultIds.put(solrDoc.id, solrDoc);
          }
        }
      }
      for (QueryCommandResult queryCommandResult : rb.mergedQueryCommandResults.values()) {
        for (ScoreDoc scoreDoc : queryCommandResult.getTopDocs().scoreDocs) {
          ShardDoc solrDoc = (ShardDoc) scoreDoc;
          solrDoc.positionInResponse = i++;
          resultIds.put(solrDoc.id, solrDoc);
        }
      }
      rb.resultIds = resultIds;
    } catch (IOException e) {
      throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
    }
  }
}