// blob: 2e404df40e6eaddebd2fc55f69ef436357402c66 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.util;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.payloads.DelimitedPayloadTokenFilterFactory;
import org.apache.lucene.analysis.payloads.NumericPayloadTokenFilterFactory;
import org.apache.lucene.analysis.payloads.PayloadHelper;
import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute;
import org.apache.lucene.analysis.util.TokenFilterFactory;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.payloads.AveragePayloadFunction;
import org.apache.lucene.queries.payloads.MaxPayloadFunction;
import org.apache.lucene.queries.payloads.MinPayloadFunction;
import org.apache.lucene.queries.payloads.PayloadDecoder;
import org.apache.lucene.queries.payloads.PayloadFunction;
import org.apache.lucene.queries.payloads.SumPayloadFunction;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanOrQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.analysis.TokenizerChain;
import org.apache.solr.schema.FieldType;
import org.apache.solr.search.PayloadScoreQParserPlugin;
/**
 * Static helpers for payload-aware querying: discovering how a field type encodes payloads at
 * index time, picking a matching decoder and payload scoring function, and turning analyzed text
 * into span queries.
 */
public class PayloadUtils {

  /**
   * Determines the payload encoder name from the index-time analysis chain of a field type.
   *
   * <p>Only analyzers that are a {@link TokenizerChain} are inspected. Its token filter factories
   * are scanned in order and the first payload-related factory decides the result:
   * <ul>
   *   <li>{@link DelimitedPayloadTokenFilterFactory}: whatever its original arguments hold under
   *       {@link DelimitedPayloadTokenFilterFactory#ENCODER_ATTR}</li>
   *   <li>{@link NumericPayloadTokenFilterFactory}: the fixed name {@code "float"}</li>
   * </ul>
   *
   * @param fieldType field type whose index analyzer is examined
   * @return the encoder name, or {@code null} when the analyzer is not a {@link TokenizerChain}
   *         or contains neither of the two factories
   */
  public static String getPayloadEncoder(FieldType fieldType) {
    // TODO: support custom payload encoding fields too somehow - maybe someone has a custom component that encodes payloads as floats
    Analyzer indexAnalyzer = fieldType.getIndexAnalyzer();
    if (!(indexAnalyzer instanceof TokenizerChain)) {
      return null;
    }

    // walk the indexing analysis chain looking for DelimitedPayloadTokenFilterFactory or NumericPayloadTokenFilterFactory
    for (TokenFilterFactory filterFactory : ((TokenizerChain) indexAnalyzer).getTokenFilterFactories()) {
      if (filterFactory instanceof DelimitedPayloadTokenFilterFactory) {
        return filterFactory.getOriginalArgs().get(DelimitedPayloadTokenFilterFactory.ENCODER_ATTR);
      }
      if (filterFactory instanceof NumericPayloadTokenFilterFactory) {
        // encodes using `PayloadHelper.encodeFloat(payload)`
        return "float";
      }
    }
    return null;
  }

  /**
   * Builds a decoder matching the encoder reported by {@link #getPayloadEncoder(FieldType)}.
   *
   * <p>Both supported decoders yield {@code 1} for a {@code null} payload; otherwise they read the
   * value from the payload bytes starting at the payload's offset.
   *
   * @param fieldType field type whose payload encoding should be decoded
   * @return a decoder for the {@code "integer"} or {@code "float"} encoder, or {@code null} for
   *         any other (or missing) encoder
   */
  public static PayloadDecoder getPayloadDecoder(FieldType fieldType) {
    String encoding = getPayloadEncoder(fieldType);

    if ("integer".equals(encoding)) {
      return payload -> payload == null ? 1 : PayloadHelper.decodeInt(payload.bytes, payload.offset);
    }
    if ("float".equals(encoding)) {
      return payload -> payload == null ? 1 : PayloadHelper.decodeFloat(payload.bytes, payload.offset);
    }

    // the encoder could be "identity" here, in the case of DelimitedTokenFilterFactory encoder="identity"
    // TODO: support pluggable payload decoders?
    return null;
  }

  /**
   * Maps a function name to the corresponding {@link PayloadFunction}.
   *
   * @param func one of {@code "min"}, {@code "max"}, {@code "average"} or {@code "sum"}
   *             (matched exactly, case-sensitive)
   * @return a new function instance, or {@code null} when {@code func} is {@code null} or not one
   *         of the recognized names
   */
  public static PayloadFunction getPayloadFunction(String func) {
    if (func == null) {
      // a switch on a null String would throw; unknown and missing names both map to null
      return null;
    }
    switch (func) {
      case "min":
        return new MinPayloadFunction();
      case "max":
        return new MaxPayloadFunction();
      case "average":
        return new AveragePayloadFunction();
      case "sum":
        return new SumPayloadFunction();
      default:
        return null;
    }
  }

  /**
   * Same as {@link #createSpanQuery(String, String, Analyzer, String)}, passing
   * {@link PayloadScoreQParserPlugin#DEFAULT_OPERATOR} as the operator.
   *
   * @param field field the generated terms belong to
   * @param value text to analyze
   * @param analyzer analyzer used to tokenize {@code value}
   * @return the generated query, or {@code null} if analysis emits no tokens
   * @throws IOException if the token stream fails
   */
  public static SpanQuery createSpanQuery(String field, String value, Analyzer analyzer) throws IOException {
    return createSpanQuery(field, value, analyzer, PayloadScoreQParserPlugin.DEFAULT_OPERATOR);
  }

  /**
   * Analyzes {@code value} and builds a span query over the emitted tokens.
   *
   * <p>The shape of the result depends on how many tokens are emitted:
   * <ul>
   *   <li>no tokens: {@code null}</li>
   *   <li>one token: a single {@link SpanTermQuery}</li>
   *   <li>several tokens and {@code operator} equals {@code "or"} ignoring case:
   *       a {@link SpanOrQuery} over all terms</li>
   *   <li>several tokens otherwise (including a {@code null} operator): an ordered, zero slop
   *       {@link SpanNearQuery}</li>
   * </ul>
   *
   * @param field field the generated terms belong to
   * @param value text to analyze
   * @param analyzer analyzer used to tokenize {@code value}
   * @param operator {@code "or"} (any case) to combine multiple terms with a {@link SpanOrQuery};
   *                 anything else selects the {@link SpanNearQuery}
   * @return the generated query, or {@code null} if analysis emits no tokens
   * @throws IOException if the token stream fails
   */
  public static SpanQuery createSpanQuery(String field, String value, Analyzer analyzer, String operator) throws IOException {
    // adapted from QueryBuilder.createSpanQuery (which isn't currently public), with reset(), end(), and close() calls added
    List<SpanTermQuery> clauses = new ArrayList<>();
    try (TokenStream stream = analyzer.tokenStream(field, value)) {
      stream.reset();
      TermToBytesRefAttribute bytesAtt = stream.getAttribute(TermToBytesRefAttribute.class);
      while (stream.incrementToken()) {
        Term term = new Term(field, bytesAtt.getBytesRef());
        clauses.add(new SpanTermQuery(term));
      }
      stream.end();
    }

    if (clauses.isEmpty()) {
      return null;
    }
    if (clauses.size() == 1) {
      return clauses.get(0);
    }

    SpanTermQuery[] clauseArray = clauses.toArray(new SpanTermQuery[0]);
    if ("or".equalsIgnoreCase(operator)) {
      return new SpanOrQuery(clauseArray);
    }
    // slop 0, in order: the terms must appear adjacent and in the analyzed sequence
    return new SpanNearQuery(clauseArray, 0, true);
  }
}