blob: 56d2d29d539e6a15c4e3650d1cddd855156873da [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.regressions.logistic.multiclass;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
/** Base class for multi-classification model for set of Logistic Regression classifiers. */
public class LogRegressionMultiClassModel implements Model<Vector, Double>, Exportable<LogRegressionMultiClassModel>, Serializable {
/** */
private static final long serialVersionUID = -114986533350117L;
/** List of models associated with each class. */
private Map<Double, LogisticRegressionModel> models;
/** */
public LogRegressionMultiClassModel() {
this.models = new HashMap<>();
}
/** {@inheritDoc} */
@Override public Double apply(Vector input) {
TreeMap<Double, Double> maxMargins = new TreeMap<>();
models.forEach((k, v) -> maxMargins.put(1.0 / (1.0 + Math.exp(-(input.dot(v.weights()) + v.intercept()))), k));
return maxMargins.lastEntry().getValue();
}
/** {@inheritDoc} */
@Override public <P> void saveModel(Exporter<LogRegressionMultiClassModel, P> exporter, P path) {
exporter.save(this, path);
}
/** {@inheritDoc} */
@Override public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
LogRegressionMultiClassModel mdl = (LogRegressionMultiClassModel)o;
return Objects.equals(models, mdl.models);
}
/** {@inheritDoc} */
@Override public int hashCode() {
return Objects.hash(models);
}
/** {@inheritDoc} */
@Override public String toString() {
StringBuilder wholeStr = new StringBuilder();
models.forEach((clsLb, mdl) ->
wholeStr
.append("The class with label ")
.append(clsLb)
.append(" has classifier: ")
.append(mdl.toString())
.append(System.lineSeparator())
);
return wholeStr.toString();
}
/** {@inheritDoc} */
@Override public String toString(boolean pretty) {
return toString();
}
/**
* Adds a specific Log Regression binary classifier to the bunch of same classifiers.
*
* @param clsLb The class label for the added model.
* @param mdl The model.
*/
public void add(double clsLb, LogisticRegressionModel mdl) {
models.put(clsLb, mdl);
}
}