| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.solr.search; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| |
| import org.apache.lucene.index.LeafReader; |
| import org.apache.lucene.index.LeafReaderContext; |
| import org.apache.lucene.index.NumericDocValues; |
| import org.apache.lucene.index.PostingsEnum; |
| import org.apache.lucene.index.Terms; |
| import org.apache.lucene.index.TermsEnum; |
| import org.apache.lucene.search.DocIdSetIterator; |
| import org.apache.lucene.search.IndexSearcher; |
| import org.apache.lucene.search.Query; |
| import org.apache.lucene.util.BytesRef; |
| import org.apache.lucene.util.SparseFixedBitSet; |
| import org.apache.solr.client.solrj.io.ClassificationEvaluation; |
| import org.apache.solr.common.params.SolrParams; |
| import org.apache.solr.common.util.NamedList; |
| import org.apache.solr.handler.component.ResponseBuilder; |
| import org.apache.solr.request.SolrQueryRequest; |
| |
/**
 * Returns an AnalyticsQuery implementation that performs
 * one Gradient Descent iteration over a result set to train a
 * logistic regression model.
 *
 * <p>The TextLogitStream provides the parallel iterative framework for this class.
 */
| |
| public class TextLogisticRegressionQParserPlugin extends QParserPlugin { |
| public static final String NAME = "tlogit"; |
| |
| @Override |
| public void init(@SuppressWarnings({"rawtypes"})NamedList args) { |
| } |
| |
| @Override |
| public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) { |
| return new TextLogisticRegressionQParser(qstr, localParams, params, req); |
| } |
| |
| private static class TextLogisticRegressionQParser extends QParser{ |
| |
| TextLogisticRegressionQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) { |
| super(qstr, localParams, params, req); |
| } |
| |
| public Query parse() { |
| |
| String fs = params.get("feature"); |
| String[] terms = params.get("terms").split(","); |
| String ws = params.get("weights"); |
| String dfsStr = params.get("idfs"); |
| int iteration = params.getInt("iteration", 0); |
| String outcome = params.get("outcome"); |
| int positiveLabel = params.getInt("positiveLabel", 1); |
| double threshold = params.getDouble("threshold", 0.5); |
| double alpha = params.getDouble("alpha", 0.01); |
| |
| double[] idfs = new double[terms.length]; |
| String[] idfsArr = dfsStr.split(","); |
| for (int i = 0; i < idfsArr.length; i++) { |
| idfs[i] = Double.parseDouble(idfsArr[i]); |
| } |
| |
| double[] weights = new double[terms.length+1]; |
| |
| if(ws != null) { |
| String[] wa = ws.split(","); |
| for (int i = 0; i < wa.length; i++) { |
| weights[i] = Double.parseDouble(wa[i]); |
| } |
| } else { |
| for(int i=0; i<weights.length; i++) { |
| weights[i]= 1.0d; |
| } |
| } |
| |
| TrainingParams input = new TrainingParams(fs, terms, idfs, outcome, weights, iteration, alpha, positiveLabel, threshold); |
| |
| return new TextLogisticRegressionQuery(input); |
| } |
| } |
| |
| private static class TextLogisticRegressionQuery extends AnalyticsQuery { |
| private TrainingParams trainingParams; |
| |
| public TextLogisticRegressionQuery(TrainingParams trainingParams) { |
| this.trainingParams = trainingParams; |
| } |
| |
| public DelegatingCollector getAnalyticsCollector(ResponseBuilder rbsp, IndexSearcher indexSearcher) { |
| return new TextLogisticRegressionCollector(rbsp, indexSearcher, trainingParams); |
| } |
| } |
| |
| private static class TextLogisticRegressionCollector extends DelegatingCollector { |
| private TrainingParams trainingParams; |
| private LeafReader leafReader; |
| |
| private double[] workingDeltas; |
| private ClassificationEvaluation classificationEvaluation; |
| private double[] weights; |
| |
| private ResponseBuilder rbsp; |
| private NumericDocValues leafOutcomeValue; |
| private double totalError; |
| private SparseFixedBitSet positiveDocsSet; |
| private SparseFixedBitSet docsSet; |
| private IndexSearcher searcher; |
| |
| TextLogisticRegressionCollector(ResponseBuilder rbsp, IndexSearcher searcher, |
| TrainingParams trainingParams) { |
| this.trainingParams = trainingParams; |
| this.workingDeltas = new double[trainingParams.weights.length]; |
| this.weights = Arrays.copyOf(trainingParams.weights, trainingParams.weights.length); |
| this.rbsp = rbsp; |
| this.classificationEvaluation = new ClassificationEvaluation(); |
| this.searcher = searcher; |
| positiveDocsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs()); |
| docsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs()); |
| } |
| |
| public void doSetNextReader(LeafReaderContext context) throws IOException { |
| super.doSetNextReader(context); |
| leafReader = context.reader(); |
| leafOutcomeValue = leafReader.getNumericDocValues(trainingParams.outcome); |
| } |
| |
| public void collect(int doc) throws IOException{ |
| int outcome; |
| if (leafOutcomeValue.advanceExact(doc)) { |
| outcome = (int) leafOutcomeValue.longValue(); |
| } else { |
| outcome = 0; |
| } |
| |
| outcome = trainingParams.positiveLabel == outcome? 1 : 0; |
| if (outcome == 1) { |
| positiveDocsSet.set(context.docBase + doc); |
| } |
| docsSet.set(context.docBase+doc); |
| |
| } |
| |
| @SuppressWarnings({"unchecked"}) |
| public void finish() throws IOException { |
| |
| Map<Integer, double[]> docVectors = new HashMap<>(); |
| Terms terms = ((SolrIndexSearcher)searcher).getSlowAtomicReader().terms(trainingParams.feature); |
| TermsEnum termsEnum = terms == null ? TermsEnum.EMPTY : terms.iterator(); |
| PostingsEnum postingsEnum = null; |
| int termIndex = 0; |
| for (String termStr : trainingParams.terms) { |
| BytesRef term = new BytesRef(termStr); |
| if (termsEnum.seekExact(term)) { |
| postingsEnum = termsEnum.postings(postingsEnum); |
| while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { |
| int docId = postingsEnum.docID(); |
| if (docsSet.get(docId)) { |
| double[] vector = docVectors.get(docId); |
| if (vector == null) { |
| vector = new double[trainingParams.terms.length+1]; |
| vector[0] = 1.0; |
| docVectors.put(docId, vector); |
| } |
| vector[termIndex + 1] = trainingParams.idfs[termIndex] * (1.0 + Math.log(postingsEnum.freq())); |
| } |
| } |
| } |
| termIndex++; |
| } |
| |
| for (Map.Entry<Integer, double[]> entry : docVectors.entrySet()) { |
| double[] vector = entry.getValue(); |
| int outcome = 0; |
| if (positiveDocsSet.get(entry.getKey())) { |
| outcome = 1; |
| } |
| double sig = sigmoid(sum(multiply(vector, weights))); |
| double error = sig - outcome; |
| double lastSig = sigmoid(sum(multiply(vector, trainingParams.weights))); |
| totalError += Math.abs(lastSig - outcome); |
| classificationEvaluation.count(outcome, lastSig >= trainingParams.threshold ? 1 : 0); |
| |
| workingDeltas = multiply(error * trainingParams.alpha, vector); |
| |
| for(int i = 0; i< workingDeltas.length; i++) { |
| weights[i] -= workingDeltas[i]; |
| } |
| } |
| |
| @SuppressWarnings({"rawtypes"}) |
| NamedList analytics = new NamedList(); |
| rbsp.rsp.add("logit", analytics); |
| |
| List<Double> outWeights = new ArrayList<>(); |
| for(Double d : weights) { |
| outWeights.add(d); |
| } |
| |
| analytics.add("weights", outWeights); |
| analytics.add("error", totalError); |
| analytics.add("evaluation", classificationEvaluation.toMap()); |
| analytics.add("feature", trainingParams.feature); |
| analytics.add("positiveLabel", trainingParams.positiveLabel); |
| if(this.delegate instanceof DelegatingCollector) { |
| ((DelegatingCollector)this.delegate).finish(); |
| } |
| } |
| |
| private double sigmoid(double in) { |
| double d = 1.0 / (1+Math.exp(-in)); |
| return d; |
| } |
| |
| private double[] multiply(double[] vals, double[] weights) { |
| for(int i = 0; i < vals.length; ++i) { |
| workingDeltas[i] = vals[i] * weights[i]; |
| } |
| |
| return workingDeltas; |
| } |
| |
| private double[] multiply(double d, double[] vals) { |
| for(int i = 0; i<vals.length; ++i) { |
| workingDeltas[i] = vals[i] * d; |
| } |
| |
| return workingDeltas; |
| } |
| |
| private double sum(double[] vals) { |
| double d = 0.0d; |
| for(double val : vals) { |
| d += val; |
| } |
| |
| return d; |
| } |
| |
| } |
| |
| private static class TrainingParams { |
| public final String feature; |
| public final String[] terms; |
| public final double[] idfs; |
| public final String outcome; |
| public final double[] weights; |
| public final int interation; |
| public final int positiveLabel; |
| public final double threshold; |
| public final double alpha; |
| |
| public TrainingParams(String feature, String[] terms, double[] idfs, String outcome, double[] weights, int interation, double alpha, int positiveLabel, double threshold) { |
| this.feature = feature; |
| this.terms = terms; |
| this.idfs = idfs; |
| this.outcome = outcome; |
| this.weights = weights; |
| this.alpha = alpha; |
| this.interation = interation; |
| this.positiveLabel = positiveLabel; |
| this.threshold = threshold; |
| } |
| } |
| } |