// blob: 91e7d74888929a4e6a69bcf9e3ed536650fafd60 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search.join;
import java.io.IOException;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.eq.FieldEqualitor;
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
import org.apache.solr.client.solrj.io.stream.SolrStream;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.UniqueStream;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.cloud.CloudDescriptor;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.cloud.ClusterState;
import org.apache.solr.common.cloud.DocRouter;
import org.apache.solr.common.cloud.Slice;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.schema.FieldType;
import org.apache.solr.search.BitDocSet;
import org.apache.solr.search.DocSet;
import org.apache.solr.search.DocSetUtil;
import org.apache.solr.search.Filter;
import org.apache.solr.search.SolrIndexSearcher;
/**
 * A query that matches local documents by joining against the results of a query executed on
 * another collection.
 *
 * <p>The remote {@link #query} is run against {@link #collection}; the distinct values of
 * {@link #fromField} it returns are streamed back and every local document whose {@link #toField}
 * holds one of those values is matched with a constant score.
 *
 * <p>Equality is time-bounded: two instances with identical parameters are only considered equal
 * when they were created less than {@code min(ttl, other.ttl)} seconds apart (see
 * {@link #equals(Object)}).
 */
public class CrossCollectionJoinQuery extends Query {

  /** Query string sent to the remote collection as the {@code q} parameter. */
  protected final String query;
  /** ZooKeeper address of the remote cluster; when null the local ZooKeeper address may be used. */
  protected final String zkHost;
  /** Base URL of a remote Solr node, used only when {@link #zkHost} is null and this is non-null. */
  protected final String solrUrl;
  /** Name of the remote collection to query. */
  protected final String collection;
  /** Field in the remote collection whose values are the join keys. */
  protected final String fromField;
  /** Field in the local index that is matched against the join keys. */
  protected final String toField;
  /** When true, a {@code hash_range} filter derived from the local shard range is added to the remote query. */
  protected final boolean routedByJoinKey;
  /** Creation time from {@link System#nanoTime()}; only used to compare instances in equals. */
  protected final long timestamp;
  /** Maximum age difference, in seconds, for two instances to be considered equal. */
  protected final int ttl;

  /** Extra parameters forwarded to the remote request; may be null. */
  protected SolrParams otherParams;
  /** String form of {@link #otherParams} used for equals/hashCode; null when otherParams is null. */
  protected String otherParamsString;

  /**
   * Creates a cross-collection join query.
   *
   * @param query the query string to run on the remote collection
   * @param zkHost ZooKeeper address of the remote cluster, or null
   * @param solrUrl base URL of a remote Solr node, or null
   * @param collection remote collection name
   * @param fromField join-key field in the remote collection
   * @param toField join-key field in the local index
   * @param routedByJoinKey whether to restrict the remote query with a hash range filter
   * @param ttl number of seconds within which two otherwise identical instances compare equal
   * @param otherParams additional parameters for the remote request, or null
   */
  public CrossCollectionJoinQuery(String query, String zkHost, String solrUrl,
                                  String collection, String fromField, String toField,
                                  boolean routedByJoinKey, int ttl, SolrParams otherParams) {
    this.query = query;
    this.zkHost = zkHost;
    this.solrUrl = solrUrl;
    this.collection = collection;
    this.fromField = fromField;
    this.toField = toField;
    this.routedByJoinKey = routedByJoinKey;
    this.timestamp = System.nanoTime();
    this.ttl = ttl;
    this.otherParams = otherParams;
    // SolrParams doesn't implement equals(), so use this string to compare them
    if (otherParams != null) {
      this.otherParamsString = otherParams.toString();
    }
  }

  /** Accumulates join-key values read from the remote stream and turns them into a local DocSet. */
  private interface JoinKeyCollector {
    /** Records one join-key value read from the remote stream. */
    void collect(Object value) throws IOException;

    /** Returns the local documents matching the values collected so far. */
    DocSet getDocSet() throws IOException;
  }

  /**
   * Collector for non-point fields: looks each value up in the terms dictionary of the local
   * field and ORs the postings of every term found into a bit set.
   */
  private static class TermsJoinKeyCollector implements JoinKeyCollector {

    private final FieldType fieldType;
    private final SolrIndexSearcher searcher;
    private final TermsEnum termsEnum;
    private final BytesRefBuilder bytes;
    private final FixedBitSet bitSet;
    // passed back into TermsEnum.postings() on each call so it can be reused
    private PostingsEnum postingsEnum;

    /**
     * @param fieldType type of the local join field, used to convert readable values to indexed form
     * @param terms terms of the local join field
     * @param searcher searcher over the local index; its maxDoc sizes the bit set
     */
    public TermsJoinKeyCollector(FieldType fieldType, Terms terms, SolrIndexSearcher searcher) throws IOException {
      this.fieldType = fieldType;
      this.searcher = searcher;
      this.termsEnum = terms.iterator();
      this.bytes = new BytesRefBuilder();
      this.bitSet = new FixedBitSet(searcher.maxDoc());
    }

    /**
     * Converts the value to its indexed form and, if that term exists locally, adds all of its
     * documents to the bit set. The value is cast to {@code String}.
     */
    @Override
    public void collect(Object value) throws IOException {
      fieldType.readableToIndexed((String) value, bytes);
      if (termsEnum.seekExact(bytes.get())) {
        postingsEnum = termsEnum.postings(postingsEnum, PostingsEnum.NONE);
        bitSet.or(postingsEnum);
      }
    }

    /** Returns the collected documents, intersected with the live docs when the index has deletions. */
    @Override
    public DocSet getDocSet() throws IOException {
      if (searcher.getIndexReader().hasDeletions()) {
        bitSet.and(searcher.getLiveDocSet().getBits());
      }
      return new BitDocSet(bitSet);
    }
  }

  /**
   * Collector for point fields: stores integral values in the inherited value set and resolves
   * them to documents by running the query built by {@code getResultQuery}.
   */
  private class PointJoinKeyCollector extends GraphPointsCollector implements JoinKeyCollector {

    private final SolrIndexSearcher searcher;

    /** @param searcher searcher over the local index; its schema supplies the {@code toField} definition */
    public PointJoinKeyCollector(SolrIndexSearcher searcher) {
      super(searcher.getSchema().getField(toField), null, null);
      this.searcher = searcher;
    }

    /**
     * Adds a {@code Long} or {@code Integer} value to the set as a long.
     *
     * @throws UnsupportedOperationException for any other value type
     */
    @Override
    public void collect(Object value) throws IOException {
      if (value instanceof Long || value instanceof Integer) {
        set.add(((Number) value).longValue());
      } else {
        throw new UnsupportedOperationException("Unsupported field type for XCJFQuery");
      }
    }

    /** Returns the documents matching the result query, or an empty set when no query was produced. */
    @Override
    public DocSet getDocSet() throws IOException {
      Query query = getResultQuery(searcher.getSchema().getField(toField), false);
      if (query == null) {
        return DocSet.EMPTY;
      }
      return DocSetUtil.createDocSet(searcher, query, null);
    }
  }

  /**
   * Constant-score weight that executes the remote query on first use, caches the resulting
   * filter for the lifetime of this weight, and serves per-segment scorers from it.
   */
  private class CrossCollectionJoinQueryWeight extends ConstantScoreWeight {

    private final SolrIndexSearcher searcher;
    private final ScoreMode scoreMode;
    // built lazily by the first call to scorer()
    private Filter filter;

    public CrossCollectionJoinQueryWeight(SolrIndexSearcher searcher, ScoreMode scoreMode, float score) {
      super(CrossCollectionJoinQuery.this, score);
      this.scoreMode = scoreMode;
      this.searcher = searcher;
    }

    /**
     * Builds a {@code hash_range} filter query over {@code fromField} covering the hash range of
     * the local shard, or returns null when {@code routedByJoinKey} is false.
     */
    private String createHashRangeFq() {
      if (!routedByJoinKey) {
        return null;
      }
      ClusterState clusterState = searcher.getCore().getCoreContainer().getZkController().getClusterState();
      CloudDescriptor desc = searcher.getCore().getCoreDescriptor().getCloudDescriptor();
      Slice slice = clusterState.getCollection(desc.getCollectionName()).getSlicesMap().get(desc.getShardId());
      DocRouter.Range range = slice.getRange();

      // In CompositeIdRouter, the routing prefix only affects the top 16 bits
      int min = range.min & 0xffff0000;
      int max = range.max | 0x0000ffff;

      return String.format(Locale.ROOT, "{!hash_range f=%s l=%d u=%d}", fromField, min, max);
    }

    /**
     * Builds a stream that reads {@code fromField} from the remote collection through the
     * {@code /export} handler, sorted ascending on that field and de-duplicated.
     *
     * @param solrClientCache client cache installed on the stream context
     */
    private TupleStream createCloudSolrStream(SolrClientCache solrClientCache) throws IOException {
      String streamZkHost;
      if (zkHost != null) {
        streamZkHost = zkHost;
      } else {
        // no explicit ZooKeeper given: use the one this node is connected to
        streamZkHost = searcher.getCore().getCoreContainer().getZkController().getZkServerAddress();
      }

      ModifiableSolrParams params = new ModifiableSolrParams(otherParams);
      params.set(CommonParams.Q, query);
      String fq = createHashRangeFq();
      if (fq != null) {
        params.add(CommonParams.FQ, fq);
      }
      params.set(CommonParams.FL, fromField);
      params.set(CommonParams.SORT, fromField + " asc");
      params.set(CommonParams.QT, "/export");
      params.set(CommonParams.WT, CommonParams.JAVABIN);

      StreamContext streamContext = new StreamContext();
      streamContext.setSolrClientCache(solrClientCache);

      TupleStream cloudSolrStream = new CloudSolrStream(streamZkHost, collection, params);
      TupleStream uniqueStream = new UniqueStream(cloudSolrStream, new FieldEqualitor(fromField));
      uniqueStream.setStreamContext(streamContext);
      return uniqueStream;
    }

    /**
     * Builds a stream that sends a {@code unique(search(...))} streaming expression to the
     * {@code /stream} handler at {@code solrUrl/collection}.
     */
    private TupleStream createSolrStream() {
      StreamExpression searchExpr = new StreamExpression("search")
          .withParameter(collection)
          .withParameter(new StreamExpressionNamedParameter(CommonParams.Q, query));
      String fq = createHashRangeFq();
      if (fq != null) {
        searchExpr.withParameter(new StreamExpressionNamedParameter(CommonParams.FQ, fq));
      }
      searchExpr.withParameter(new StreamExpressionNamedParameter(CommonParams.FL, fromField))
          .withParameter(new StreamExpressionNamedParameter(CommonParams.SORT, fromField + " asc"))
          .withParameter(new StreamExpressionNamedParameter(CommonParams.QT, "/export"));

      // otherParams is allowed to be null (see the constructor), so guard before iterating
      if (otherParams != null) {
        for (Map.Entry<String, String[]> entry : otherParams) {
          for (String value : entry.getValue()) {
            searchExpr.withParameter(new StreamExpressionNamedParameter(entry.getKey(), value));
          }
        }
      }

      StreamExpression uniqueExpr = new StreamExpression("unique");
      uniqueExpr.withParameter(searchExpr)
          .withParameter(new StreamExpressionNamedParameter("over", fromField));

      ModifiableSolrParams params = new ModifiableSolrParams();
      params.set("expr", uniqueExpr.toString());
      params.set(CommonParams.QT, "/stream");
      params.set(CommonParams.WT, CommonParams.JAVABIN);

      return new SolrStream(solrUrl + "/" + collection, params);
    }

    /**
     * Runs the remote query and returns the local documents whose {@code toField} matches one of
     * the returned {@code fromField} values.
     *
     * @throws SolrException with SERVER_ERROR when the stream reports an exception tuple or an
     *     {@link IOException} occurs while opening or reading the stream
     */
    private DocSet getDocSet() throws IOException {
      // Pick the collector first: when the local field has no terms there is nothing to join
      // against, so the remote stream does not need to be built at all.
      FieldType fieldType = searcher.getSchema().getFieldType(toField);
      JoinKeyCollector collector;
      if (fieldType.isPointField()) {
        collector = new PointJoinKeyCollector(searcher);
      } else {
        Terms terms = searcher.getSlowAtomicReader().terms(toField);
        if (terms == null) {
          return DocSet.EMPTY;
        }
        collector = new TermsJoinKeyCollector(fieldType, terms, searcher);
      }

      // An explicit zkHost wins; with neither zkHost nor solrUrl the cloud stream is used as well.
      TupleStream solrStream;
      if (zkHost != null || solrUrl == null) {
        SolrClientCache solrClientCache = searcher.getCore().getCoreContainer().getSolrClientCache();
        solrStream = createCloudSolrStream(solrClientCache);
      } else {
        solrStream = createSolrStream();
      }

      try {
        solrStream.open();
        while (true) {
          Tuple tuple = solrStream.read();
          if (tuple.EXCEPTION) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, tuple.getException());
          }
          if (tuple.EOF) {
            break;
          }

          // tuples without a value for the join field are skipped
          Object value = tuple.get(fromField);
          if (null != value) {
            collector.collect(value);
          }
        }
      } catch (IOException e) {
        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
      } finally {
        solrStream.close();
      }

      return collector.getDocSet();
    }

    /**
     * Returns a constant-score scorer over the joined documents in the given segment, or null
     * when the segment has none. The join itself is executed on the first call only.
     */
    @Override
    public Scorer scorer(LeafReaderContext context) throws IOException {
      if (filter == null) {
        filter = getDocSet().getTopFilter();
      }

      DocIdSet readerSet = filter.getDocIdSet(context, null);
      if (readerSet == null) {
        return null;
      }
      DocIdSetIterator readerSetIterator = readerSet.iterator();
      if (readerSetIterator == null) {
        return null;
      }
      return new ConstantScoreScorer(this, score(), scoreMode, readerSetIterator);
    }

    /** Always false: this weight is reported as not cacheable for any segment. */
    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
      return false;
    }
  }

  /**
   * Creates the weight for this query.
   *
   * @param searcher must be a {@link SolrIndexSearcher}; it is cast unconditionally
   */
  @Override
  public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
    return new CrossCollectionJoinQueryWeight((SolrIndexSearcher) searcher, scoreMode, boost);
  }

  @Override
  public void visit(QueryVisitor visitor) {
    visitor.visitLeaf(this);
  }

  /** Hash over all join parameters; {@code timestamp} and {@code ttl} are deliberately excluded. */
  @Override
  public int hashCode() {
    final int prime = 31;
    int result = classHash();
    result = prime * result + Objects.hashCode(query);
    result = prime * result + Objects.hashCode(zkHost);
    result = prime * result + Objects.hashCode(solrUrl);
    result = prime * result + Objects.hashCode(collection);
    result = prime * result + Objects.hashCode(fromField);
    result = prime * result + Objects.hashCode(toField);
    result = prime * result + Boolean.hashCode(routedByJoinKey);
    result = prime * result + Objects.hashCode(otherParamsString);
    // timestamp and ttl should not be included in hash code
    return result;
  }

  /**
   * Two instances are equal when all join parameters match and their creation times differ by
   * fewer whole seconds than the smaller of the two TTLs. As a consequence an instance whose
   * {@code ttl} is zero or negative is not equal to anything, including itself.
   */
  @Override
  public boolean equals(Object other) {
    return sameClassAs(other) &&
        equalsTo(getClass().cast(other));
  }

  private boolean equalsTo(CrossCollectionJoinQuery other) {
    return Objects.equals(query, other.query) &&
        Objects.equals(zkHost, other.zkHost) &&
        Objects.equals(solrUrl, other.solrUrl) &&
        Objects.equals(collection, other.collection) &&
        Objects.equals(fromField, other.fromField) &&
        Objects.equals(toField, other.toField) &&
        routedByJoinKey == other.routedByJoinKey &&
        Objects.equals(otherParamsString, other.otherParamsString) &&
        TimeUnit.SECONDS.convert(Math.abs(timestamp - other.timestamp), TimeUnit.NANOSECONDS) < Math.min(ttl, other.ttl);
  }

  /** Renders the query in {@code {!xcjf ...}} local-params form followed by the remote query string. */
  @Override
  public String toString(String field) {
    // pass query directly: %s renders a null query as "null" instead of throwing
    return String.format(Locale.ROOT, "{!xcjf collection=%s from=%s to=%s routed=%b ttl=%d}%s",
        collection, fromField, toField, routedByJoinKey, ttl, query);
  }
}