blob: 0fd0f6ab0442dbdccd914709c9ec0aaa4de60dc2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.request.SolrQueryRequest;
/**
* Returns an AnalyticsQuery implementation that performs
* one Gradient Descent iteration of a result set to train a
* logistic regression model
*
* The TextLogitStream provides the parallel iterative framework for this class.
**/
public class TextLogisticRegressionQParserPlugin extends QParserPlugin {
public static final String NAME = "tlogit";
@Override
public void init(@SuppressWarnings({"rawtypes"})NamedList args) {
}
@Override
public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
return new TextLogisticRegressionQParser(qstr, localParams, params, req);
}
private static class TextLogisticRegressionQParser extends QParser{
TextLogisticRegressionQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
super(qstr, localParams, params, req);
}
public Query parse() {
String fs = params.get("feature");
String[] terms = params.get("terms").split(",");
String ws = params.get("weights");
String dfsStr = params.get("idfs");
int iteration = params.getInt("iteration", 0);
String outcome = params.get("outcome");
int positiveLabel = params.getInt("positiveLabel", 1);
double threshold = params.getDouble("threshold", 0.5);
double alpha = params.getDouble("alpha", 0.01);
double[] idfs = new double[terms.length];
String[] idfsArr = dfsStr.split(",");
for (int i = 0; i < idfsArr.length; i++) {
idfs[i] = Double.parseDouble(idfsArr[i]);
}
double[] weights = new double[terms.length+1];
if(ws != null) {
String[] wa = ws.split(",");
for (int i = 0; i < wa.length; i++) {
weights[i] = Double.parseDouble(wa[i]);
}
} else {
for(int i=0; i<weights.length; i++) {
weights[i]= 1.0d;
}
}
TrainingParams input = new TrainingParams(fs, terms, idfs, outcome, weights, iteration, alpha, positiveLabel, threshold);
return new TextLogisticRegressionQuery(input);
}
}
private static class TextLogisticRegressionQuery extends AnalyticsQuery {
private TrainingParams trainingParams;
public TextLogisticRegressionQuery(TrainingParams trainingParams) {
this.trainingParams = trainingParams;
}
public DelegatingCollector getAnalyticsCollector(ResponseBuilder rbsp, IndexSearcher indexSearcher) {
return new TextLogisticRegressionCollector(rbsp, indexSearcher, trainingParams);
}
}
private static class TextLogisticRegressionCollector extends DelegatingCollector {
private TrainingParams trainingParams;
private LeafReader leafReader;
private double[] workingDeltas;
private ClassificationEvaluation classificationEvaluation;
private double[] weights;
private ResponseBuilder rbsp;
private NumericDocValues leafOutcomeValue;
private double totalError;
private SparseFixedBitSet positiveDocsSet;
private SparseFixedBitSet docsSet;
private IndexSearcher searcher;
TextLogisticRegressionCollector(ResponseBuilder rbsp, IndexSearcher searcher,
TrainingParams trainingParams) {
this.trainingParams = trainingParams;
this.workingDeltas = new double[trainingParams.weights.length];
this.weights = Arrays.copyOf(trainingParams.weights, trainingParams.weights.length);
this.rbsp = rbsp;
this.classificationEvaluation = new ClassificationEvaluation();
this.searcher = searcher;
positiveDocsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs());
docsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs());
}
public void doSetNextReader(LeafReaderContext context) throws IOException {
super.doSetNextReader(context);
leafReader = context.reader();
leafOutcomeValue = leafReader.getNumericDocValues(trainingParams.outcome);
}
public void collect(int doc) throws IOException{
int outcome;
if (leafOutcomeValue.advanceExact(doc)) {
outcome = (int) leafOutcomeValue.longValue();
} else {
outcome = 0;
}
outcome = trainingParams.positiveLabel == outcome? 1 : 0;
if (outcome == 1) {
positiveDocsSet.set(context.docBase + doc);
}
docsSet.set(context.docBase+doc);
}
@SuppressWarnings({"unchecked"})
public void finish() throws IOException {
Map<Integer, double[]> docVectors = new HashMap<>();
Terms terms = ((SolrIndexSearcher)searcher).getSlowAtomicReader().terms(trainingParams.feature);
TermsEnum termsEnum = terms == null ? TermsEnum.EMPTY : terms.iterator();
PostingsEnum postingsEnum = null;
int termIndex = 0;
for (String termStr : trainingParams.terms) {
BytesRef term = new BytesRef(termStr);
if (termsEnum.seekExact(term)) {
postingsEnum = termsEnum.postings(postingsEnum);
while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
int docId = postingsEnum.docID();
if (docsSet.get(docId)) {
double[] vector = docVectors.get(docId);
if (vector == null) {
vector = new double[trainingParams.terms.length+1];
vector[0] = 1.0;
docVectors.put(docId, vector);
}
vector[termIndex + 1] = trainingParams.idfs[termIndex] * (1.0 + Math.log(postingsEnum.freq()));
}
}
}
termIndex++;
}
for (Map.Entry<Integer, double[]> entry : docVectors.entrySet()) {
double[] vector = entry.getValue();
int outcome = 0;
if (positiveDocsSet.get(entry.getKey())) {
outcome = 1;
}
double sig = sigmoid(sum(multiply(vector, weights)));
double error = sig - outcome;
double lastSig = sigmoid(sum(multiply(vector, trainingParams.weights)));
totalError += Math.abs(lastSig - outcome);
classificationEvaluation.count(outcome, lastSig >= trainingParams.threshold ? 1 : 0);
workingDeltas = multiply(error * trainingParams.alpha, vector);
for(int i = 0; i< workingDeltas.length; i++) {
weights[i] -= workingDeltas[i];
}
}
@SuppressWarnings({"rawtypes"})
NamedList analytics = new NamedList();
rbsp.rsp.add("logit", analytics);
List<Double> outWeights = new ArrayList<>();
for(Double d : weights) {
outWeights.add(d);
}
analytics.add("weights", outWeights);
analytics.add("error", totalError);
analytics.add("evaluation", classificationEvaluation.toMap());
analytics.add("feature", trainingParams.feature);
analytics.add("positiveLabel", trainingParams.positiveLabel);
if(this.delegate instanceof DelegatingCollector) {
((DelegatingCollector)this.delegate).finish();
}
}
private double sigmoid(double in) {
double d = 1.0 / (1+Math.exp(-in));
return d;
}
private double[] multiply(double[] vals, double[] weights) {
for(int i = 0; i < vals.length; ++i) {
workingDeltas[i] = vals[i] * weights[i];
}
return workingDeltas;
}
private double[] multiply(double d, double[] vals) {
for(int i = 0; i<vals.length; ++i) {
workingDeltas[i] = vals[i] * d;
}
return workingDeltas;
}
private double sum(double[] vals) {
double d = 0.0d;
for(double val : vals) {
d += val;
}
return d;
}
}
private static class TrainingParams {
public final String feature;
public final String[] terms;
public final double[] idfs;
public final String outcome;
public final double[] weights;
public final int interation;
public final int positiveLabel;
public final double threshold;
public final double alpha;
public TrainingParams(String feature, String[] terms, double[] idfs, String outcome, double[] weights, int interation, double alpha, int positiveLabel, double threshold) {
this.feature = feature;
this.terms = terms;
this.idfs = idfs;
this.outcome = outcome;
this.weights = weights;
this.alpha = alpha;
this.interation = interation;
this.positiveLabel = positiveLabel;
this.threshold = threshold;
}
}
}