# blob: d943d8d0ab4c0b5c75f998dd0f88a21184d1b8ab [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# mllib_utils.R: Utilities for MLlib integration
# Integration with R's standard functions.
# Most of MLlib's algorithms are provided in two flavours:
# - a specialization of the default R methods (glm). These methods try to follow
# the inputs and outputs of the corresponding R methods as closely as possible,
# although some small differences may exist.
# - a set of methods that reflect the arguments of the other languages supported by Spark. These
# methods are prefixed with the `spark.` prefix: spark.glm, spark.kmeans, etc.
#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
#' MLlib model below.
#' @rdname write.ml
#' @name write.ml
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
#' \link{spark.gaussianMixture}, \link{spark.gbt},
#' \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' \link{spark.kmeans},
#' \link{spark.lda}, \link{spark.logit},
#' \link{spark.mlp}, \link{spark.naiveBayes},
#' \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear},
#' \link{read.ml}
NULL
#' Makes predictions from an MLlib model
#'
#' Makes predictions from an MLlib model. For more information, see the specific
#' MLlib model below.
#' @rdname predict
#' @name predict
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
#' \link{spark.gaussianMixture}, \link{spark.gbt},
#' \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' \link{spark.kmeans},
#' \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
#' \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear}
NULL
# Internal helper shared by the model-specific write.ml methods.
#
# Asks the JVM object stored in `object@jobj` for its writer (method "write"),
# optionally switches that writer to overwrite mode (method "overwrite"), and
# then invokes "save" on it with `path`.
#
# object:    a model object exposing its JVM handle in the `jobj` slot.
# path:      destination handed unchanged to the JVM writer's "save" method.
# overwrite: when TRUE, the writer's "overwrite" method is called before saving.
#
# Returns the result of the "save" call invisibly.
write_internal <- function(object, path, overwrite = FALSE) {
  mlWriter <- callJMethod(object@jobj, "write")
  # Only request overwrite mode when asked for; otherwise use the writer as-is.
  activeWriter <- if (overwrite) callJMethod(mlWriter, "overwrite") else mlWriter
  invisible(callJMethod(activeWriter, "save", path))
}
# Internal helper shared by the model-specific predict methods.
#
# Calls "transform" on the JVM object held in `object@jobj`, passing the JVM
# handle stored in `newData@sdf`, and wraps the returned JVM object with
# dataFrame().
#
# object:  a model object exposing its JVM handle in the `jobj` slot.
# newData: an object exposing a JVM handle in its `sdf` slot
#          (presumably a SparkDataFrame -- confirm against callers).
#
# Returns whatever dataFrame() builds from the transformed JVM object.
predict_internal <- function(object, newData) {
  transformedSdf <- callJMethod(object@jobj, "transform", newData@sdf)
  dataFrame(transformedSdf)
}
#' Load a fitted MLlib model from the input path.
#'
#' The model is loaded through the JVM-side \code{RWrappers} loader and the
#' returned wrapper object is wrapped in the matching SparkR model class. An
#' error is raised when the loaded object is none of the supported wrapper types.
#'
#' @param path path of the model to read.
#' @return A fitted MLlib model.
#' @rdname read.ml
#' @name read.ml
#' @seealso \link{write.ml}
#' @examples
#' \dontrun{
#' path <- "path/to/model"
#' model <- read.ml(path)
#' }
#' @note read.ml since 2.0.0
read.ml <- function(path) {
  # normalizePath() warns when it cannot resolve the path; those warnings are
  # deliberately silenced here.
  path <- suppressWarnings(normalizePath(path))

  # Register the current session with the JVM-side wrappers, then load.
  activeSession <- getSparkSession()
  callJStatic("org.apache.spark.ml.r.RWrappers", "session", activeSession)
  jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)

  # JVM wrapper class (relative to org.apache.spark.ml.r) -> SparkR model class.
  # Entries are probed in the order listed; the first match wins.
  wrapperToModel <- c(
    NaiveBayesWrapper = "NaiveBayesModel",
    AFTSurvivalRegressionWrapper = "AFTSurvivalRegressionModel",
    GeneralizedLinearRegressionWrapper = "GeneralizedLinearRegressionModel",
    KMeansWrapper = "KMeansModel",
    LDAWrapper = "LDAModel",
    MultilayerPerceptronClassifierWrapper = "MultilayerPerceptronClassificationModel",
    IsotonicRegressionWrapper = "IsotonicRegressionModel",
    GaussianMixtureWrapper = "GaussianMixtureModel",
    ALSWrapper = "ALSModel",
    LogisticRegressionWrapper = "LogisticRegressionModel",
    RandomForestRegressorWrapper = "RandomForestRegressionModel",
    RandomForestClassifierWrapper = "RandomForestClassificationModel",
    DecisionTreeRegressorWrapper = "DecisionTreeRegressionModel",
    DecisionTreeClassifierWrapper = "DecisionTreeClassificationModel",
    GBTRegressorWrapper = "GBTRegressionModel",
    GBTClassifierWrapper = "GBTClassificationModel",
    BisectingKMeansWrapper = "BisectingKMeansModel",
    LinearSVCWrapper = "LinearSVCModel",
    FPGrowthWrapper = "FPGrowthModel",
    FMClassifierWrapper = "FMClassificationModel",
    LinearRegressionWrapper = "LinearRegressionModel",
    FMRegressorWrapper = "FMRegressionModel"
  )

  for (wrapperName in names(wrapperToModel)) {
    if (isInstanceOf(jobj, paste0("org.apache.spark.ml.r.", wrapperName))) {
      return(new(wrapperToModel[[wrapperName]], jobj = jobj))
    }
  }

  # No known wrapper type matched the loaded object.
  stop("Unsupported model: ", jobj)
}