blob: 65a43514930f0b0216e31bd3462350884c286b6f [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# mllib_fpm.R: Provides methods for MLlib frequent pattern mining algorithms integration
#' S4 class wrapping a fitted FP-growth model
#'
#' @param jobj reference to the JVM-side Scala FPGrowthModel that backs this R object
#' @note FPGrowthModel since 2.2.0
setClass("FPGrowthModel", slots = c(jobj = "jobj"))
#' S4 class wrapping a PrefixSpan instance
#'
#' @param jobj reference to the JVM-side Scala PrefixSpan that backs this R object
#' @note PrefixSpan since 3.0.0
setClass("PrefixSpan", slots = c(jobj = "jobj"))
#' FP-growth
#'
#' A parallel FP-growth algorithm to mine frequent itemsets.
#' \code{spark.fpGrowth} fits a FP-growth model on a SparkDataFrame. Users can use
#' \code{spark.freqItemsets} to get frequent itemsets, \code{spark.associationRules} to get
#' association rules, \code{predict} to make predictions on new data based on generated association
#' rules, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#' For more details, see
#' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{
#' FP-growth}.
#'
#' @param data A SparkDataFrame for training.
#' @param minSupport Minimal support level, a single number between 0 and 1 (inclusive).
#' @param minConfidence Minimal confidence level, a single number between 0 and 1 (inclusive).
#' @param itemsCol Features column name.
#' @param numPartitions Number of partitions used for fitting, a single positive integer
#'                      (non-integer values are truncated by \code{as.integer}).
#'                      If \code{NULL} (the default), no explicit partition count is requested.
#' @param ... additional argument(s) passed to the method.
#' @return \code{spark.fpGrowth} returns a fitted FPGrowth model.
#' @rdname spark.fpGrowth
#' @name spark.fpGrowth
#' @aliases spark.fpGrowth,SparkDataFrame-method
#' @examples
#' \dontrun{
#' raw_data <- read.df(
#'   "data/mllib/sample_fpgrowth.txt",
#'   source = "csv",
#'   schema = structType(structField("raw_items", "string")))
#'
#' data <- selectExpr(raw_data, "split(raw_items, ' ') as items")
#' model <- spark.fpGrowth(data)
#'
#' # Show frequent itemsets
#' frequent_itemsets <- spark.freqItemsets(model)
#' showDF(frequent_itemsets)
#'
#' # Show association rules
#' association_rules <- spark.associationRules(model)
#' showDF(association_rules)
#'
#' # Predict on new data
#' new_itemsets <- data.frame(items = c("t", "t,s"))
#' new_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as items")
#' predict(model, new_data)
#'
#' # Save and load model
#' path <- "/path/to/model"
#' write.ml(model, path)
#' read.ml(path)
#'
#' # Optional arguments
#' baskets_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as baskets")
#' another_model <- spark.fpGrowth(baskets_data, minSupport = 0.1, minConfidence = 0.5,
#'                                 itemsCol = "baskets", numPartitions = 10)
#' }
#' @note spark.fpGrowth since 2.2.0
setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"),
          function(data, minSupport = 0.3, minConfidence = 0.8,
                   itemsCol = "items", numPartitions = NULL) {
            # Type and length are checked before any comparison. NA, empty or
            # vector input then reaches the intended stop() message rather than
            # failing inside `||` / `if` with an unrelated error.
            if (!is.numeric(minSupport) || length(minSupport) != 1 || is.na(minSupport) ||
                minSupport < 0 || minSupport > 1) {
              stop("minSupport should be a number [0, 1].")
            }
            if (!is.numeric(minConfidence) || length(minConfidence) != 1 ||
                is.na(minConfidence) || minConfidence < 0 || minConfidence > 1) {
              stop("minConfidence should be a number [0, 1].")
            }
            if (!is.null(numPartitions)) {
              numPartitions <- as.integer(numPartitions)
              # Require exactly one positive value. as.integer() yields NA for
              # non-numeric strings, and a vector must not be forwarded to the JVM.
              stopifnot(length(numPartitions) == 1, !is.na(numPartitions),
                        numPartitions > 0)
            }
            jobj <- callJStatic("org.apache.spark.ml.r.FPGrowthWrapper", "fit",
                                data@sdf, as.numeric(minSupport), as.numeric(minConfidence),
                                itemsCol, numPartitions)
            new("FPGrowthModel", jobj = jobj)
          })
# Get frequent itemsets.
#' @param object a fitted FPGrowth model.
#' @return A \code{SparkDataFrame} with frequent itemsets.
#'         The \code{SparkDataFrame} contains two columns:
#'         \code{items} (an array of the same type as the input column)
#'         and \code{freq} (frequency of the itemset).
#' @rdname spark.fpGrowth
#' @aliases spark.freqItemsets,FPGrowthModel-method freqItemsets,FPGrowthModel-method
#' @note spark.freqItemsets(FPGrowthModel) since 2.2.0
setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"),
          function(object) {
            # Call "freqItemsets" on the wrapped JVM model and expose the result
            # as a SparkDataFrame.
            dataFrame(callJMethod(object@jobj, "freqItemsets"))
          })
# Get association rules.
#' @return A \code{SparkDataFrame} with association rules.
#'         The \code{SparkDataFrame} contains five columns:
#'         \code{antecedent} (an array of the same type as the input column),
#'         \code{consequent} (an array of the same type as the input column),
#'         \code{confidence} (confidence for the rule),
#'         \code{lift} (lift for the rule)
#'         and \code{support} (support for the rule).
#' @rdname spark.fpGrowth
#' @aliases spark.associationRules,FPGrowthModel-method associationRules,FPGrowthModel-method
#' @note spark.associationRules(FPGrowthModel) since 2.2.0
setMethod("spark.associationRules", signature(object = "FPGrowthModel"),
          function(object) {
            # Call "associationRules" on the wrapped JVM model and expose the
            # result as a SparkDataFrame.
            dataFrame(callJMethod(object@jobj, "associationRules"))
          })
# Score new data with the association rules mined by a fitted FPGrowth model.
#' @param newData a SparkDataFrame to make predictions on.
#' @return \code{predict} returns a SparkDataFrame containing predicted values.
#' @rdname spark.fpGrowth
#' @aliases predict,FPGrowthModel-method
#' @note predict(FPGrowthModel) since 2.2.0
setMethod(
  "predict",
  signature(object = "FPGrowthModel"),
  # All work is delegated to the shared predict_internal helper.
  function(object, newData) predict_internal(object, newData)
)
# Persist a fitted FPGrowth model to a directory.
#' @param path the directory where the model is saved.
#' @param overwrite logical value indicating whether an existing output path may be
#'                  overwritten. Defaults to FALSE, in which case an exception is thrown
#'                  if the output path already exists.
#' @rdname spark.fpGrowth
#' @aliases write.ml,FPGrowthModel,character-method
#' @seealso \link{read.ml}
#' @note write.ml(FPGrowthModel, character) since 2.2.0
setMethod(
  "write.ml",
  signature(object = "FPGrowthModel", path = "character"),
  # All work is delegated to the shared write_internal helper.
  function(object, path, overwrite = FALSE) write_internal(object, path, overwrite)
)
#' PrefixSpan
#'
#' A parallel PrefixSpan algorithm to mine frequent sequential patterns.
#' \code{spark.findFrequentSequentialPatterns} returns a complete set of frequent sequential
#' patterns.
#' For more details, see
#' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#prefixspan}{
#' PrefixSpan}.
#'
# Find frequent sequential patterns.
#' @param data A SparkDataFrame.
#' @param minSupport Minimal support level, a single number >= 0.
#' @param maxPatternLength Maximal pattern length, a single positive whole number
#'                         (an integer such as \code{5L} or a whole-valued double such
#'                         as \code{5}).
#' @param maxLocalProjDBSize Maximum number of items (including delimiters used in the internal
#'                           storage format) allowed in a projected database before local
#'                           processing; a single positive number.
#' @param sequenceCol name of the sequence column in dataset.
#' @param ... additional argument(s) passed to the method.
#' @return A complete set of frequent sequential patterns in the input sequences of itemsets.
#'         The returned \code{SparkDataFrame} contains columns of sequence and corresponding
#'         frequency. The schema of it will be:
#'         \code{sequence: ArrayType(ArrayType(T))}, \code{freq: long}
#'         where T is the item type
#' @rdname spark.prefixSpan
#' @aliases spark.findFrequentSequentialPatterns,SparkDataFrame-method
#'          findFrequentSequentialPatterns,PrefixSpan,SparkDataFrame-method
#' @examples
#' \dontrun{
#' df <- createDataFrame(list(list(list(list(1L, 2L), list(3L))),
#'                            list(list(list(1L), list(3L, 2L), list(1L, 2L))),
#'                            list(list(list(1L, 2L), list(5L))),
#'                            list(list(list(6L)))),
#'                       schema = c("sequence"))
#' frequency <- spark.findFrequentSequentialPatterns(df, minSupport = 0.5, maxPatternLength = 5L,
#'                                                   maxLocalProjDBSize = 32000000L)
#' showDF(frequency)
#' }
#' @note spark.findFrequentSequentialPatterns(SparkDataFrame) since 3.0.0
setMethod("spark.findFrequentSequentialPatterns",
          signature(data = "SparkDataFrame"),
          function(data, minSupport = 0.1, maxPatternLength = 10L,
                   maxLocalProjDBSize = 32000000L, sequenceCol = "sequence") {
            # Each check validates type and length first. NA, empty or vector
            # arguments then reach the intended stop() message instead of failing
            # inside `||` / `if` with an unrelated error.
            if (!is.numeric(minSupport) || length(minSupport) != 1 || is.na(minSupport) ||
                minSupport < 0) {
              stop("minSupport should be a number with value >= 0.")
            }
            # Accept whole-valued doubles (e.g. 5) as well as integers (5L). The
            # value must also fit in an R integer because it is converted with
            # as.integer() before being sent to the JVM.
            if (!is.numeric(maxPatternLength) || length(maxPatternLength) != 1 ||
                is.na(maxPatternLength) || maxPatternLength <= 0 ||
                maxPatternLength != trunc(maxPatternLength) ||
                maxPatternLength > .Machine$integer.max) {
              stop("maxPatternLength should be a number with value > 0.")
            }
            if (!is.numeric(maxLocalProjDBSize) || length(maxLocalProjDBSize) != 1 ||
                is.na(maxLocalProjDBSize) || maxLocalProjDBSize <= 0) {
              stop("maxLocalProjDBSize should be a number with value > 0.")
            }
            jobj <- callJStatic("org.apache.spark.ml.r.PrefixSpanWrapper", "getPrefixSpan",
                                as.numeric(minSupport), as.integer(maxPatternLength),
                                as.numeric(maxLocalProjDBSize), as.character(sequenceCol))
            object <- new("PrefixSpan", jobj = jobj)
            dataFrame(callJMethod(object@jobj, "findFrequentSequentialPatterns", data@sdf))
          }
)