blob: dd990c70ed7d934d102c2b73acebcc8a63c1a257 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hama.ml.regression;
import org.apache.hadoop.fs.Path;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.FunctionFactory;
import org.apache.hama.ml.ann.SmallLayeredNeuralNetwork;
import org.apache.hama.ml.util.FeatureTransformer;
import java.io.IOException;
import java.util.Map;
/**
 * The logistic regression model. It can be used to conduct 2-class
 * classification.
 *
 * <p>
 * The model is a thin wrapper around a {@link SmallLayeredNeuralNetwork}
 * configured with an input layer of the requested dimension, a single-unit
 * output layer, the "Sigmoid" squashing function and the "CrossEntropy" cost
 * function. Every operation is delegated to that network.
 */
public class LogisticRegression {

  /** The underlying network that all calls are delegated to. */
  private final SmallLayeredNeuralNetwork ann;

  /**
   * Create a new, untrained logistic regression model.
   *
   * @param dimension The size of the input layer, i.e. the number of features
   *          of an input instance.
   */
  public LogisticRegression(int dimension) {
    this.ann = new SmallLayeredNeuralNetwork();
    // NOTE(review): the boolean flag presumably marks the final layer --
    // confirm against SmallLayeredNeuralNetwork#addLayer.
    this.ann.addLayer(dimension, false,
        FunctionFactory.createDoubleFunction("Sigmoid"));
    this.ann.addLayer(1, true,
        FunctionFactory.createDoubleFunction("Sigmoid"));
    this.ann.setCostFunction(FunctionFactory
        .createDoubleDoubleFunction("CrossEntropy"));
  }

  /**
   * Create a logistic regression model backed by a network constructed from
   * the given model path.
   *
   * @param modelPath The path of the model, handed to the underlying network.
   */
  public LogisticRegression(String modelPath) {
    this.ann = new SmallLayeredNeuralNetwork(modelPath);
  }

  /**
   * Set the learning rate, recommended in range (0, 0.01]. Note that the
   * logistic regression model easily diverges if the learning rate is not
   * small enough.
   *
   * @param learningRate The learning rate to use.
   * @return this model, so that setter calls can be chained.
   */
  public LogisticRegression setLearningRate(double learningRate) {
    ann.setLearningRate(learningRate);
    return this;
  }

  /**
   * Get the learning rate.
   *
   * @return the learning rate value.
   */
  public double getLearningRate() {
    return ann.getLearningRate();
  }

  /**
   * Set the weight of the momentum, recommended in range [0, 1.0]. A too large
   * momentum weight may make the model hard to converge.
   *
   * @param momemtumWeight The momentum weight to use.
   * @return this model, so that setter calls can be chained.
   */
  public LogisticRegression setMomemtumWeight(double momemtumWeight) {
    ann.setMomemtumWeight(momemtumWeight);
    return this;
  }

  /**
   * Get the weight of the momentum.
   *
   * @return the momentum weight value.
   */
  public double getMomemtumWeight() {
    return ann.getMomemtumWeight();
  }

  /**
   * Set the weight of regularization, recommended in range [0, 0.1]. A too
   * large regularization weight will mislead the model.
   *
   * @param regularizationWeight The regularization weight to use.
   * @return this model, so that setter calls can be chained.
   */
  public LogisticRegression setRegularizationWeight(double regularizationWeight) {
    ann.setRegularizationWeight(regularizationWeight);
    return this;
  }

  /**
   * Get the weight of regularization.
   *
   * @return the regularization weight value.
   */
  public double getRegularizationWeight() {
    return ann.getRegularizationWeight();
  }

  /**
   * Train the logistic regression model with one instance. It is HIGHLY
   * RECOMMENDED to normalize the data first.
   *
   * @param trainingInstance The training instance, passed unchanged to the
   *          underlying network.
   */
  public void trainOnline(DoubleVector trainingInstance) {
    ann.trainOnline(trainingInstance);
  }

  /**
   * Train the model with given data. It is HIGHLY RECOMMENDED to normalize the
   * data first.
   *
   * @param dataInputPath The file path that contains the training instance.
   * @param trainingParams The training parameters.
   */
  public void train(Path dataInputPath, Map<String, String> trainingParams) {
    ann.train(dataInputPath, trainingParams);
  }

  /**
   * Get the output according to given input instance.
   *
   * @param instance The input instance.
   * @return the output vector computed by the underlying network.
   */
  public DoubleVector getOutput(DoubleVector instance) {
    return ann.getOutput(instance);
  }

  /**
   * Set the path to store the model. Note that this only sets the path, it
   * does not save the model. Call {@link #writeModelToFile()} to save the
   * model.
   *
   * @param modelPath The path the model should be stored at.
   */
  public void setModelPath(String modelPath) {
    ann.setModelPath(modelPath);
  }

  /**
   * Save the model to the specified model path.
   *
   * @throws RuntimeException if the underlying network fails to write the
   *           model; the original {@link IOException} is attached as the
   *           cause.
   */
  public void writeModelToFile() {
    try {
      ann.writeModelToFile();
    } catch (IOException e) {
      // Do not swallow the failure: a model that was silently not saved is
      // indistinguishable from a saved one for the caller. The signature
      // cannot declare a checked exception without breaking callers, so the
      // error is rethrown unchecked with its cause preserved.
      throw new RuntimeException("Failed to write the model to file", e);
    }
  }

  /**
   * Get the weights of the model.
   *
   * @return row 0 of the weight matrix of layer 0 of the underlying network.
   */
  public DoubleVector getWeights() {
    return ann.getWeightsByLayer(0).getRowVector(0);
  }

  /**
   * Set the feature transformer.
   *
   * @param featureTransformer The feature transformer handed to the
   *          underlying network.
   */
  public void setFeatureTransformer(FeatureTransformer featureTransformer) {
    this.ann.setFeatureTransformer(featureTransformer);
  }
}