IGNITE-8403: [ML] Add Binary Logistic Regression based on
partitioned datasets and MLP
this closes #3924
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
index 276d43f..04d1778 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
@@ -34,7 +34,7 @@
import java.util.UUID;
/**
- * Run linear regression model over distributed matrix.
+ * Run linear regression model over cached dataset.
*
* @see LinearRegressionLSQRTrainer
*/
@@ -104,8 +104,6 @@
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
- // because we create ignite cache internally.
IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
LinearRegressionLSQRTrainerExample.class.getSimpleName(), () -> {
IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java
index 0358f44..6c9273c 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java
@@ -24,7 +24,6 @@
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor;
@@ -38,7 +37,7 @@
import java.util.UUID;
/**
- * Run linear regression model over distributed matrix.
+ * Run linear regression model over cached dataset.
*
* @see LinearRegressionLSQRTrainer
* @see NormalizationTrainer
@@ -105,15 +104,13 @@
/** Run example. */
public static void main(String[] args) throws InterruptedException {
System.out.println();
- System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started.");
+ System.out.println(">>> Linear regression model over cached dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
- // because we create ignite cache internally.
IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- SparseDistributedMatrixExample.class.getSimpleName(), () -> {
+ LinearRegressionLSQRTrainerWithNormalizationExample.class.getSimpleName(), () -> {
IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
System.out.println(">>> Create new normalization trainer object.");
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
index ce6ad3b..da5f942 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
@@ -25,11 +25,11 @@
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
-import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.thread.IgniteThread;
import javax.cache.Cache;
@@ -37,7 +37,7 @@
import java.util.UUID;
/**
- * Run linear regression model over distributed matrix.
+ * Run linear regression model over cached dataset.
*
* @see LinearRegressionSGDTrainer
*/
@@ -106,8 +106,6 @@
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
- // because we create ignite cache internally.
IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
LinearRegressionSGDTrainerExample.class.getSimpleName(), () -> {
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java
new file mode 100644
index 0000000..0505ddd
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.regression.logistic;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.apache.ignite.thread.IgniteThread;
+
+import javax.cache.Cache;
+import java.util.Arrays;
+import java.util.UUID;
+
+/**
+ * Run logistic regression model over distributed cache.
+ *
+ * @see LogisticRegressionSGDTrainer
+ */
+public class LogisticRegressionSGDTrainerSample {
+    /** Trains the model on a cached dataset, then prints per-row predictions, accuracy and a confusion matrix. */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> Logistic regression model over partitioned dataset usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                LogisticRegressionSGDTrainerSample.class.getSimpleName(), () -> {
+                // Each cached row holds the label first, followed by the feature values.
+                IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
+
+                System.out.println(">>> Create new logistic regression trainer object.");
+                LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
+                    new SimpleGDUpdateCalculator(0.2),
+                    SimpleGDParameterUpdate::sumLocal,
+                    SimpleGDParameterUpdate::avg
+                ), 100000, 10, 100, 123L);
+
+                System.out.println(">>> Perform the training to get the model.");
+                LogisticRegressionModel mdl = trainer.fit(
+                    ignite,
+                    dataCache,
+                    (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+                    (k, v) -> v[0]
+                ).withRawLabels(false);
+                // Thresholded 0/1 labels are needed below: raw sigmoid output cannot be compared or used as an index.
+                System.out.println(">>> Logistic regression model: " + mdl);
+
+                int amountOfErrors = 0;
+                int totalAmount = 0;
+
+                // Confusion matrix indexed as [predicted label][actual label]. See https://en.wikipedia.org/wiki/Confusion_matrix
+                int[][] confusionMtx = {{0, 0}, {0, 0}};
+
+                try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+                    for (Cache.Entry<Integer, double[]> observation : observations) {
+                        double[] val = observation.getValue();
+                        double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+                        double groundTruth = val[0];
+
+                        double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
+
+                        totalAmount++;
+                        if (groundTruth != prediction)
+                            amountOfErrors++;
+
+                        int idx1 = (int)prediction;
+                        int idx2 = (int)groundTruth;
+
+                        confusionMtx[idx1][idx2]++;
+
+                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+                    }
+
+                    System.out.println(">>> ---------------------------------");
+
+                    System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+                    System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+                }
+
+                System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+                System.out.println(">>> ---------------------------------");
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+    /**
+     * Creates a uniquely named cache with a rendezvous affinity function and loads every row of {@code data} into it.
+     *
+     * @param ignite Ignite instance used to create the cache.
+     * @return Cache that maps each row index to the corresponding row.
+     */
+    private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
+        CacheConfiguration<Integer, double[]> cfg = new CacheConfiguration<>();
+        cfg.setName("TEST_" + UUID.randomUUID());
+        cfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+        IgniteCache<Integer, double[]> filled = ignite.createCache(cfg);
+
+        int key = 0;
+        for (double[] row : data)
+            filled.put(key++, row);
+        return filled;
+    }
+
+
+ /** The 1st and 2nd classes from the Iris dataset; each row holds the class label first, then four feature values. */
+ private static final double[][] data = {
+ {0, 5.1, 3.5, 1.4, 0.2},
+ {0, 4.9, 3, 1.4, 0.2},
+ {0, 4.7, 3.2, 1.3, 0.2},
+ {0, 4.6, 3.1, 1.5, 0.2},
+ {0, 5, 3.6, 1.4, 0.2},
+ {0, 5.4, 3.9, 1.7, 0.4},
+ {0, 4.6, 3.4, 1.4, 0.3},
+ {0, 5, 3.4, 1.5, 0.2},
+ {0, 4.4, 2.9, 1.4, 0.2},
+ {0, 4.9, 3.1, 1.5, 0.1},
+ {0, 5.4, 3.7, 1.5, 0.2},
+ {0, 4.8, 3.4, 1.6, 0.2},
+ {0, 4.8, 3, 1.4, 0.1},
+ {0, 4.3, 3, 1.1, 0.1},
+ {0, 5.8, 4, 1.2, 0.2},
+ {0, 5.7, 4.4, 1.5, 0.4},
+ {0, 5.4, 3.9, 1.3, 0.4},
+ {0, 5.1, 3.5, 1.4, 0.3},
+ {0, 5.7, 3.8, 1.7, 0.3},
+ {0, 5.1, 3.8, 1.5, 0.3},
+ {0, 5.4, 3.4, 1.7, 0.2},
+ {0, 5.1, 3.7, 1.5, 0.4},
+ {0, 4.6, 3.6, 1, 0.2},
+ {0, 5.1, 3.3, 1.7, 0.5},
+ {0, 4.8, 3.4, 1.9, 0.2},
+ {0, 5, 3, 1.6, 0.2},
+ {0, 5, 3.4, 1.6, 0.4},
+ {0, 5.2, 3.5, 1.5, 0.2},
+ {0, 5.2, 3.4, 1.4, 0.2},
+ {0, 4.7, 3.2, 1.6, 0.2},
+ {0, 4.8, 3.1, 1.6, 0.2},
+ {0, 5.4, 3.4, 1.5, 0.4},
+ {0, 5.2, 4.1, 1.5, 0.1},
+ {0, 5.5, 4.2, 1.4, 0.2},
+ {0, 4.9, 3.1, 1.5, 0.1},
+ {0, 5, 3.2, 1.2, 0.2},
+ {0, 5.5, 3.5, 1.3, 0.2},
+ {0, 4.9, 3.1, 1.5, 0.1},
+ {0, 4.4, 3, 1.3, 0.2},
+ {0, 5.1, 3.4, 1.5, 0.2},
+ {0, 5, 3.5, 1.3, 0.3},
+ {0, 4.5, 2.3, 1.3, 0.3},
+ {0, 4.4, 3.2, 1.3, 0.2},
+ {0, 5, 3.5, 1.6, 0.6},
+ {0, 5.1, 3.8, 1.9, 0.4},
+ {0, 4.8, 3, 1.4, 0.3},
+ {0, 5.1, 3.8, 1.6, 0.2},
+ {0, 4.6, 3.2, 1.4, 0.2},
+ {0, 5.3, 3.7, 1.5, 0.2},
+ {0, 5, 3.3, 1.4, 0.2},
+ {1, 7, 3.2, 4.7, 1.4},
+ {1, 6.4, 3.2, 4.5, 1.5},
+ {1, 6.9, 3.1, 4.9, 1.5},
+ {1, 5.5, 2.3, 4, 1.3},
+ {1, 6.5, 2.8, 4.6, 1.5},
+ {1, 5.7, 2.8, 4.5, 1.3},
+ {1, 6.3, 3.3, 4.7, 1.6},
+ {1, 4.9, 2.4, 3.3, 1},
+ {1, 6.6, 2.9, 4.6, 1.3},
+ {1, 5.2, 2.7, 3.9, 1.4},
+ {1, 5, 2, 3.5, 1},
+ {1, 5.9, 3, 4.2, 1.5},
+ {1, 6, 2.2, 4, 1},
+ {1, 6.1, 2.9, 4.7, 1.4},
+ {1, 5.6, 2.9, 3.6, 1.3},
+ {1, 6.7, 3.1, 4.4, 1.4},
+ {1, 5.6, 3, 4.5, 1.5},
+ {1, 5.8, 2.7, 4.1, 1},
+ {1, 6.2, 2.2, 4.5, 1.5},
+ {1, 5.6, 2.5, 3.9, 1.1},
+ {1, 5.9, 3.2, 4.8, 1.8},
+ {1, 6.1, 2.8, 4, 1.3},
+ {1, 6.3, 2.5, 4.9, 1.5},
+ {1, 6.1, 2.8, 4.7, 1.2},
+ {1, 6.4, 2.9, 4.3, 1.3},
+ {1, 6.6, 3, 4.4, 1.4},
+ {1, 6.8, 2.8, 4.8, 1.4},
+ {1, 6.7, 3, 5, 1.7},
+ {1, 6, 2.9, 4.5, 1.5},
+ {1, 5.7, 2.6, 3.5, 1},
+ {1, 5.5, 2.4, 3.8, 1.1},
+ {1, 5.5, 2.4, 3.7, 1},
+ {1, 5.8, 2.7, 3.9, 1.2},
+ {1, 6, 2.7, 5.1, 1.6},
+ {1, 5.4, 3, 4.5, 1.5},
+ {1, 6, 3.4, 4.5, 1.6},
+ {1, 6.7, 3.1, 4.7, 1.5},
+ {1, 6.3, 2.3, 4.4, 1.3},
+ {1, 5.6, 3, 4.1, 1.3},
+ {1, 5.5, 2.5, 4, 1.3},
+ {1, 5.5, 2.6, 4.4, 1.2},
+ {1, 6.1, 3, 4.6, 1.4},
+ {1, 5.8, 2.6, 4, 1.2},
+ {1, 5, 2.3, 3.3, 1},
+ {1, 5.6, 2.7, 4.2, 1.3},
+ {1, 5.7, 3, 4.2, 1.2},
+ {1, 5.7, 2.9, 4.2, 1.3},
+ {1, 6.2, 2.9, 4.3, 1.3},
+ {1, 5.1, 2.5, 3, 1.1},
+ {1, 5.7, 2.8, 4.1, 1.3},
+ };
+
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/package-info.java
new file mode 100644
index 0000000..cf27a94
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML logistic regression examples.
+ */
+package org.apache.ignite.examples.ml.regression.logistic;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java
index 13fcb60..a0e8c66 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java
@@ -44,4 +44,79 @@
}).sum() / (vector.size());
}
};
+ /**
+ * Log loss function.
+ */
+ public static IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> LOG = groundTruth ->
+ new IgniteDifferentiableVectorToDoubleFunction() {
+ /** {@inheritDoc} */
+ @Override public Vector differential(Vector pnt) {
+ double multiplier = 2.0 / pnt.size();
+ return pnt.minus(groundTruth).times(multiplier);
+ }
+
+ /** {@inheritDoc} */
+ @Override public Double apply(Vector vector) {
+ return groundTruth.copy().map(vector,
+ (a, b) -> a == 1 ? - Math.log(b) : -Math.log(1 - b)
+ ).sum();
+ }
+ };
+
+ /**
+ * L2 loss function.
+ */
+ public static IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> L2 = groundTruth ->
+ new IgniteDifferentiableVectorToDoubleFunction() {
+ /** {@inheritDoc} */
+ @Override public Vector differential(Vector pnt) {
+ double multiplier = 2.0 / pnt.size();
+ return pnt.minus(groundTruth).times(multiplier);
+ }
+
+ /** {@inheritDoc} */
+ @Override public Double apply(Vector vector) {
+ return groundTruth.copy().map(vector, (a, b) -> {
+ double diff = a - b;
+ return diff * diff;
+ }).sum();
+ }
+ };
+
+ /**
+ * L1 loss function.
+ */
+ public static IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> L1 = groundTruth ->
+ new IgniteDifferentiableVectorToDoubleFunction() {
+ /** {@inheritDoc} */
+ @Override public Vector differential(Vector pnt) {
+ double multiplier = 2.0 / pnt.size();
+ return pnt.minus(groundTruth).times(multiplier);
+ }
+
+ /** {@inheritDoc} */
+ @Override public Double apply(Vector vector) {
+ return groundTruth.copy().map(vector, (a, b) -> {
+ double diff = a - b;
+ return Math.abs(diff);
+ }).sum();
+ }
+ };
+
+ /**
+ * Hinge loss function.
+ */
+ public static IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> HINGE = groundTruth ->
+ new IgniteDifferentiableVectorToDoubleFunction() {
+ /** {@inheritDoc} */
+ @Override public Vector differential(Vector pnt) {
+ double multiplier = 2.0 / pnt.size();
+ return pnt.minus(groundTruth).times(multiplier);
+ }
+
+ /** {@inheritDoc} */
+ @Override public Double apply(Vector vector) {
+ return Math.max(0, 1 - groundTruth.dot(vector));
+ }
+ };
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java
new file mode 100644
index 0000000..8ea1490
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.logistic.binomial;
+
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.Vector;
+
+import java.io.Serializable;
+import java.util.Objects;
+
+/**
+ * Logistic regression (logit model) is a generalized linear model used for binomial regression.
+ */
+public class LogisticRegressionModel implements Model<Vector, Double>, Exportable<LogisticRegressionModel>, Serializable {
+    /** Serialization version identifier. */
+    private static final long serialVersionUID = -133984600091550776L;
+
+    /** Multiplier of the object's vector required to make prediction. */
+    private Vector weights;
+
+    /** Intercept of the logistic regression model. */
+    private double intercept;
+
+    /** Output label format: raw sigmoid value if {@code true}, 0 or 1 chosen by the threshold otherwise. */
+    private boolean isKeepingRawLabels = false;
+
+    /** Threshold to assign '1' label to the observation if raw value more than this threshold. */
+    private double threshold = 0.5;
+
+    /** Creates a model with the given weights and intercept, 0/1 output labels and a threshold of 0.5. */
+    public LogisticRegressionModel(Vector weights, double intercept) {
+        this.weights = weights;
+        this.intercept = intercept;
+    }
+
+    /**
+     * Set up the output label format.
+     *
+     * @param isKeepingRawLabels The parameter value.
+     * @return This model with new isKeepingRawLabels parameter value.
+     */
+    public LogisticRegressionModel withRawLabels(boolean isKeepingRawLabels) {
+        this.isKeepingRawLabels = isKeepingRawLabels;
+        return this;
+    }
+
+    /**
+     * Set up the threshold.
+     *
+     * @param threshold The parameter value.
+     * @return This model with new threshold parameter value.
+     */
+    public LogisticRegressionModel withThreshold(double threshold) {
+        this.threshold = threshold;
+        return this;
+    }
+
+    /**
+     * Set up the weights.
+     *
+     * @param weights The parameter value.
+     * @return This model with new weights parameter value.
+     */
+    public LogisticRegressionModel withWeights(Vector weights) {
+        this.weights = weights;
+        return this;
+    }
+
+    /**
+     * Set up the intercept.
+     *
+     * @param intercept The parameter value.
+     * @return This model with new intercept parameter value.
+     */
+    public LogisticRegressionModel withIntercept(double intercept) {
+        this.intercept = intercept;
+        return this;
+    }
+
+    /**
+     * Gets the output label format mode.
+     *
+     * @return The parameter value.
+     */
+    public boolean isKeepingRawLabels() {
+        return isKeepingRawLabels;
+    }
+
+    /**
+     * Gets the threshold.
+     *
+     * @return The parameter value.
+     */
+    public double threshold() {
+        return threshold;
+    }
+
+    /**
+     * Gets the weights.
+     *
+     * @return The parameter value.
+     */
+    public Vector weights() {
+        return weights;
+    }
+
+    /**
+     * Gets the intercept.
+     *
+     * @return The parameter value.
+     */
+    public double intercept() {
+        return intercept;
+    }
+
+    /** {@inheritDoc} */
+    @Override public Double apply(Vector input) {
+        final double res = sigmoid(input.dot(weights) + intercept);
+        // Raw mode returns the sigmoid value itself; otherwise it is turned into a 0/1 label by the threshold.
+        if (isKeepingRawLabels)
+            return res;
+        else
+            return res - threshold > 0 ? 1.0 : 0;
+    }
+
+    /**
+     * Sigmoid function.
+     * @param z The regression value.
+     * @return The result.
+     */
+    private static double sigmoid(double z) {
+        return 1.0 / (1.0 + Math.exp(-z));
+    }
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<LogisticRegressionModel, P> exporter, P path) {
+        exporter.save(this, path);
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+        if (o == null || getClass() != o.getClass())
+            return false;
+
+        LogisticRegressionModel mdl = (LogisticRegressionModel)o;
+
+        return Double.compare(mdl.intercept, intercept) == 0
+            && Double.compare(mdl.threshold, threshold) == 0
+            && Boolean.compare(mdl.isKeepingRawLabels, isKeepingRawLabels) == 0
+            && Objects.equals(weights, mdl.weights);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        return Objects.hash(weights, intercept, isKeepingRawLabels, threshold);
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        if (weights.size() < 20) {
+            // Each separator carries the sign of the term that follows it, so the first term needs its own sign.
+            double first = weights.size() > 0 ? weights.get(0) : intercept;
+            StringBuilder builder = new StringBuilder(first < 0 ? "-" : "");
+            for (int i = 0; i < weights.size(); i++) {
+                double nextItem = i == weights.size() - 1 ? intercept : weights.get(i + 1);
+                builder.append(String.format("%.4f", Math.abs(weights.get(i))))
+                    .append("*x")
+                    .append(i)
+                    .append(nextItem < 0 ? " - " : " + ");
+            }
+
+            builder.append(String.format("%.4f", Math.abs(intercept)));
+            return builder.toString();
+        }
+
+        return "LogisticRegressionModel{" +
+            "weights=" + weights +
+            ", intercept=" + intercept +
+            '}';
+    }
+}
\ No newline at end of file
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
new file mode 100644
index 0000000..8fe57cf
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.logistic.binomial;
+
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.Activators;
+import org.apache.ignite.ml.nn.MLPTrainer;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+/**
+ * Trainer of the logistic regression model based on stochastic gradient descent algorithm.
+ */
+public class LogisticRegressionSGDTrainer<P extends Serializable> implements SingleLabelDatasetTrainer<LogisticRegressionModel> {
+    /** Update strategy. */
+    private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
+
+    /** Max number of iterations. */
+    private final int maxIterations;
+
+    /** Batch size. */
+    private final int batchSize;
+
+    /** Number of local iterations. */
+    private final int locIterations;
+
+    /** Seed for random generator. */
+    private final long seed;
+
+    /**
+     * Constructs a new instance of logistic regression SGD trainer.
+     *
+     * @param updatesStgy Update strategy.
+     * @param maxIterations Max number of iterations.
+     * @param batchSize Batch size.
+     * @param locIterations Number of local iterations.
+     * @param seed Seed for random generator.
+     */
+    public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations,
+        int batchSize, int locIterations, long seed) {
+        this.updatesStgy = updatesStgy;
+        this.maxIterations = maxIterations;
+        this.batchSize = batchSize;
+        this.locIterations = locIterations;
+        this.seed = seed;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+        // The model is trained as a perceptron with one added layer that uses the sigmoid activator.
+        IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
+            // Features per row taken from any partition that has features; null when no partition has any.
+            Integer cols = dataset.compute(data -> {
+                if (data.getFeatures() == null)
+                    return null;
+                return data.getFeatures().length / data.getRows();
+            }, (a, b) -> a == null ? b : a);
+
+            if (cols == null)
+                throw new IllegalStateException("Cannot train logistic regression: the dataset has no features");
+
+            return new MLPArchitecture(cols).withAddedLayer(1, true, Activators.SIGMOID);
+        };
+
+        MLPTrainer<?> trainer = new MLPTrainer<>(
+            archSupplier,
+            LossFunctions.L2,
+            updatesStgy,
+            maxIterations,
+            batchSize,
+            locIterations,
+            seed
+        );
+
+        MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, (k, v) -> new double[]{lbExtractor.apply(k, v)});
+
+        double[] params = mlp.parameters().getStorage().data();
+        // The last parameter is used as the intercept, all preceding ones as the weights.
+        return new LogisticRegressionModel(new DenseLocalOnHeapVector(Arrays.copyOf(params, params.length - 1)),
+            params[params.length - 1]
+        );
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java
new file mode 100644
index 0000000..d32b1ee
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains binomial logistic regression.
+ */
+package org.apache.ignite.ml.regressions.logistic.binomial;
\ No newline at end of file
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/package-info.java
new file mode 100644
index 0000000..b1f8331
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains various logistic regressions.
+ */
+package org.apache.ignite.ml.regressions.logistic;
\ No newline at end of file
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
index 5005ef2..2d21d3b 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
@@ -17,7 +17,11 @@
package org.apache.ignite.ml.regressions;
-import org.apache.ignite.ml.regressions.linear.*;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainerTest;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModelTest;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainerTest;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
@@ -28,7 +32,9 @@
@Suite.SuiteClasses({
LinearRegressionModelTest.class,
LinearRegressionLSQRTrainerTest.class,
- LinearRegressionSGDTrainerTest.class
+ LinearRegressionSGDTrainerTest.class,
+ LogisticRegressionModelTest.class,
+ LogisticRegressionSGDTrainerTest.class
})
public class RegressionsTestSuite {
// No-op.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
new file mode 100644
index 0000000..1268a7d
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.logistic;
+
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.exceptions.CardinalityException;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+import org.junit.Test;
+
+/**
+ * Tests for {@link LogisticRegressionModel}.
+ */
+public class LogisticRegressionModelTest {
+ /** Precision passed to the {@code TestUtils.assertEquals} checks. */
+ private static final double PRECISION = 1e-6;
+
+ /** Checks raw-label predictions against the sigmoid of intercept plus weighted features. */
+ @Test
+ public void testPredict() {
+ Vector weights = new DenseLocalOnHeapVector(new double[]{2.0, 3.0});
+ LogisticRegressionModel mdl = new LogisticRegressionModel(weights, 1.0).withRawLabels(true);
+
+ Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0});
+ TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION);
+
+ observation = new DenseLocalOnHeapVector(new double[]{2.0, 1.0});
+ TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION);
+
+ observation = new DenseLocalOnHeapVector(new double[]{1.0, 2.0});
+ TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 2.0), mdl.apply(observation), PRECISION);
+
+ observation = new DenseLocalOnHeapVector(new double[]{-2.0, 1.0});
+ TestUtils.assertEquals(sigmoid(1.0 - 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION);
+
+ observation = new DenseLocalOnHeapVector(new double[]{1.0, -2.0});
+ TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 - 3.0 * 2.0), mdl.apply(observation), PRECISION);
+ }
+
+ /** Expects a {@link CardinalityException} for a one-element observation against two weights. */
+ @Test(expected = CardinalityException.class)
+ public void testPredictOnAnObservationWithWrongCardinality() {
+ Vector weights = new DenseLocalOnHeapVector(new double[]{2.0, 3.0});
+
+ LogisticRegressionModel mdl = new LogisticRegressionModel(weights, 1.0);
+
+ Vector observation = new DenseLocalOnHeapVector(new double[]{1.0});
+
+ mdl.apply(observation);
+ }
+
+ /**
+ * Sigmoid function.
+ * @param z The regression value.
+ * @return The result.
+ */
+ private static double sigmoid(double z) {
+ return 1.0 / (1.0 + Math.exp(-z));
+ }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
new file mode 100644
index 0000000..27d3a30e
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.logistic;
+
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * Tests for {@link LogisticRegressionSGDTrainer}.
+ */
+@RunWith(Parameterized.class)
+public class LogisticRegressionSGDTrainerTest {
+ /** Fixed size of Dataset. */
+ private static final int AMOUNT_OF_OBSERVATIONS = 1000;
+
+ /** Fixed size of columns in Dataset. */
+ private static final int AMOUNT_OF_FEATURES = 2;
+
+ /** Precision in test checks. */
+ private static final double PRECISION = 1e-2;
+
+ /** Partition counts to run the test with, one value per parameterized run. */
+ @Parameterized.Parameters(name = "Data divided on {0} partitions")
+ public static Iterable<Integer[]> data() {
+ return Arrays.asList(
+ new Integer[] {1},
+ new Integer[] {2},
+ new Integer[] {3},
+ new Integer[] {5},
+ new Integer[] {7},
+ new Integer[] {100}
+ );
+ }
+
+ /** Number of partitions. */
+ @Parameterized.Parameter
+ public int parts;
+
+ /**
+ * Trains on points labeled by the sign of (y - x), split into {@link #parts} partitions, and checks both classes.
+ */
+ @Test
+ public void trainWithTheLinearlySeparableCase() {
+ Map<Integer, double[]> data = new HashMap<>();
+
+ // ThreadLocalRandom.current() yields the current thread's generator, so one reference is enough.
+ ThreadLocalRandom rnd = ThreadLocalRandom.current();
+
+ for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
+ double x = rnd.nextDouble(-1000, 1000);
+ double y = rnd.nextDouble(-1000, 1000);
+ double[] vec = new double[AMOUNT_OF_FEATURES + 1];
+ vec[0] = y - x > 0 ? 1 : 0; // assign label.
+ vec[1] = x;
+ vec[2] = y;
+ data.put(i, vec);
+ }
+
+ LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ ), 100000, 10, 100, 123L);
+
+ LogisticRegressionModel mdl = trainer.fit(
+ data,
+ parts,
+ (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+ (k, v) -> v[0]
+ );
+
+ TestUtils.assertEquals(0, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION);
+ TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION);
+ }
+}