blob: 5364b1ba22a84105eec5c0b2bf7eed7c84cd3bcb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search;
import java.io.IOException;
import java.util.TreeSet;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.request.SolrQueryRequest;
/**
 * Query parser plugin named {@value #NAME} that scores the terms of a field by their
 * information gain with respect to a binary outcome, computed over the documents matched by
 * the main query.
 *
 * <p>Parameters read by the parser:
 * <ul>
 *   <li>{@code field} - the field whose terms are enumerated and scored;</li>
 *   <li>{@code outcome} - a numeric doc-values field holding each document's label;</li>
 *   <li>{@code numTerms} - how many of the highest-scoring terms to report;</li>
 *   <li>{@code positiveLabel} - the outcome value treated as the positive class (all other
 *       values, and documents without a value, are treated as negative).</li>
 * </ul>
 *
 * <p>The collector adds three entries to the response: {@code featuredTerms} (term to
 * information-gain score), {@code docFreq} (term to number of collected documents containing
 * it, for the reported terms) and {@code numDocs} (number of collected documents).
 */
public class IGainTermsQParserPlugin extends QParserPlugin {

  public static final String NAME = "igain";

  @Override
  public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
    return new IGainTermsQParser(qstr, localParams, params, req);
  }

  /** Parses the plugin parameters and builds an {@link IGainTermsQuery}. */
  private static class IGainTermsQParser extends QParser {

    public IGainTermsQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
      super(qstr, localParams, params, req);
    }

    /**
     * Builds the analytics query from the request parameters.
     *
     * @throws SyntaxError if a required parameter is missing or an integer parameter is malformed
     */
    @Override
    public Query parse() throws SyntaxError {
      String field = requiredParam("field");
      String outcome = requiredParam("outcome");
      int numTerms = requiredIntParam("numTerms");
      int positiveLabel = requiredIntParam("positiveLabel");
      return new IGainTermsQuery(field, outcome, positiveLabel, numTerms);
    }

    /** Returns the named parameter, failing with a client error instead of a later NPE. */
    private String requiredParam(String name) throws SyntaxError {
      String value = getParam(name);
      if (value == null) {
        throw new SyntaxError("Missing required parameter '" + name + "' for the " + NAME + " query parser");
      }
      return value;
    }

    /** Returns the named parameter parsed as an int, failing with a client error if malformed. */
    private int requiredIntParam(String name) throws SyntaxError {
      String value = requiredParam(name);
      try {
        return Integer.parseInt(value);
      } catch (NumberFormatException e) {
        throw new SyntaxError("Parameter '" + name + "' must be an integer but was '" + value + "'", e);
      }
    }
  }

  /** Analytics query whose only job is to hand out an {@link IGainTermsCollector}. */
  private static class IGainTermsQuery extends AnalyticsQuery {

    private final String field;
    private final String outcome;
    private final int numTerms;
    private final int positiveLabel;

    public IGainTermsQuery(String field, String outcome, int positiveLabel, int numTerms) {
      this.field = field;
      this.outcome = outcome;
      this.numTerms = numTerms;
      this.positiveLabel = positiveLabel;
    }

    @Override
    public DelegatingCollector getAnalyticsCollector(ResponseBuilder rb, IndexSearcher searcher) {
      return new IGainTermsCollector(rb, searcher, field, outcome, positiveLabel, numTerms);
    }
  }

  /**
   * Splits the collected documents into positive and negative sets by their outcome label,
   * then, in {@link #finish()}, scores every term of {@code field} by information gain and
   * writes the top {@code numTerms} terms to the response.
   */
  private static class IGainTermsCollector extends DelegatingCollector {

    private final String field;
    private final String outcome;
    private final IndexSearcher searcher;
    private final ResponseBuilder rb;
    private final int positiveLabel;
    private final int numTerms;
    // Global (top-level) doc ids of collected documents, split by outcome label.
    private final SparseFixedBitSet positiveSet;
    private final SparseFixedBitSet negativeSet;
    private int count;
    private int numPositiveDocs;
    // Outcome values of the current segment; null when the segment has no such field.
    private NumericDocValues leafOutcomeValue;

    public IGainTermsCollector(ResponseBuilder rb, IndexSearcher searcher, String field, String outcome, int positiveLabel, int numTerms) {
      this.rb = rb;
      this.searcher = searcher;
      this.field = field;
      this.outcome = outcome;
      int maxDoc = searcher.getIndexReader().maxDoc();
      this.positiveSet = new SparseFixedBitSet(maxDoc);
      this.negativeSet = new SparseFixedBitSet(maxDoc);
      this.numTerms = numTerms;
      this.positiveLabel = positiveLabel;
    }

    @Override
    protected void doSetNextReader(LeafReaderContext context) throws IOException {
      super.doSetNextReader(context);
      LeafReader reader = context.reader();
      leafOutcomeValue = reader.getNumericDocValues(outcome);
    }

    @Override
    public void collect(int doc) throws IOException {
      super.collect(doc);
      ++count;
      int globalDoc = context.docBase + doc;
      if (outcomeValue(doc) == positiveLabel) {
        positiveSet.set(globalDoc);
        numPositiveDocs++;
      } else {
        negativeSet.set(globalDoc);
      }
    }

    /**
     * Returns the outcome label of a segment-local document, or 0 when the document (or the
     * whole segment) has no value for the outcome field.
     */
    private int outcomeValue(int doc) throws IOException {
      if (leafOutcomeValue != null && leafOutcomeValue.advanceExact(doc)) {
        return (int) leafOutcomeValue.longValue();
      }
      return 0;
    }

    /**
     * Scores every term of {@code field} over the collected documents, adds the results to the
     * response and then finishes the delegate chain.
     */
    @Override
    public void finish() throws IOException {
      NamedList<Double> analytics = new NamedList<>();
      NamedList<Integer> topFreq = new NamedList<>();
      rb.rsp.add("featuredTerms", analytics);
      rb.rsp.add("docFreq", topFreq);
      rb.rsp.add("numDocs", count);

      // Ordered by ascending score, so first() is always the weakest retained term.
      TreeSet<TermWithScore> topTerms = new TreeSet<>();
      double numDocs = count;
      double entropyC = binaryEntropy(numPositiveDocs / numDocs);

      Terms terms = ((SolrIndexSearcher) searcher).getSlowAtomicReader().terms(field);
      TermsEnum termsEnum = terms == null ? TermsEnum.EMPTY : terms.iterator();
      BytesRef term;
      PostingsEnum postingsEnum = null;
      while ((term = termsEnum.next()) != null) {
        // Only doc ids are needed, so skip decoding term frequencies.
        postingsEnum = termsEnum.postings(postingsEnum, PostingsEnum.NONE);
        int xc = 0;
        int nc = 0;
        for (int docId = postingsEnum.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS; docId = postingsEnum.nextDoc()) {
          if (positiveSet.get(docId)) {
            xc++;
          } else if (negativeSet.get(docId)) {
            nc++;
          }
        }
        int docFreq = xc + nc;
        if (docFreq == 0) {
          // The term occurs in no collected document: P(positive | term) is 0/0 and the score
          // would be NaN, which Double.compare would rank above every real score.
          continue;
        }
        double score = informationGain(entropyC, numDocs, xc, docFreq);
        if (topTerms.size() < numTerms) {
          topTerms.add(new TermWithScore(term.utf8ToString(), score, docFreq));
        } else if (!topTerms.isEmpty() && topTerms.first().score < score) {
          // isEmpty() guard: with numTerms <= 0 nothing is ever retained.
          topTerms.pollFirst();
          topTerms.add(new TermWithScore(term.utf8ToString(), score, docFreq));
        }
      }

      // Reported in ascending score order (TreeSet iteration order).
      for (TermWithScore topTerm : topTerms) {
        analytics.add(topTerm.term, topTerm.score);
        topFreq.add(topTerm.term, topTerm.docFreq);
      }

      if (this.delegate instanceof DelegatingCollector) {
        ((DelegatingCollector) this.delegate).finish();
      }
    }

    /**
     * Information gain of the class label from knowing whether a document contains the term:
     * {@code H(C) - [P(t) H(C|t) + P(!t) H(C|!t)]}.
     *
     * @param entropyC entropy of the label over all collected documents
     * @param numDocs  number of collected documents
     * @param xc       collected positive documents containing the term
     * @param docFreq  collected documents containing the term (must be > 0)
     */
    private double informationGain(double entropyC, double numDocs, int xc, int docFreq) {
      double pTerm = docFreq / numDocs;
      double entropyContainsTerm = binaryEntropy((double) xc / docFreq);
      double docsWithoutTerm = numDocs - docFreq;
      // When every collected doc contains the term, H(C|!t) is undefined but has weight 0.
      double entropyNotContainsTerm = docsWithoutTerm > 0
          ? binaryEntropy((numPositiveDocs - xc) / docsWithoutTerm)
          : 0.0;
      return entropyC - (pTerm * entropyContainsTerm + (1.0 - pTerm) * entropyNotContainsTerm);
    }

    /** Entropy (natural log) of a Bernoulli variable with success probability {@code prob}. */
    private static double binaryEntropy(double prob) {
      if (prob == 0 || prob == 1) return 0;
      return (-1 * prob * Math.log(prob)) + (-1 * (1.0 - prob) * Math.log(1.0 - prob));
    }
  }

  /**
   * A term with its information-gain score and the number of collected documents containing
   * it. Ordered by score, then term text.
   */
  private static class TermWithScore implements Comparable<TermWithScore> {
    public final String term;
    public final double score;
    public final int docFreq;

    public TermWithScore(String term, double score, int docFreq) {
      this.term = term;
      this.score = score;
      this.docFreq = docFreq;
    }

    @Override
    public int hashCode() {
      return term.hashCode();
    }

    /** Equal when term and score match, consistent with {@link #compareTo}. */
    @Override
    public boolean equals(Object obj) {
      if (this == obj) return true;
      if (obj == null || obj.getClass() != getClass()) return false;
      TermWithScore other = (TermWithScore) obj;
      return Double.compare(other.score, this.score) == 0 && other.term.equals(this.term);
    }

    @Override
    public int compareTo(TermWithScore o) {
      int cmp = Double.compare(this.score, o.score);
      return cmp != 0 ? cmp : this.term.compareTo(o.term);
    }
  }
}