# blob: 73f6cfd67cee9847297814f8bfe1b92c7d4989f4 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Tests for MLlib recommendation algorithms in SparkR
library(testthat)

context("MLlib recommendation algorithms")

# A Spark session must exist before any SparkR API is exercised below.
sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
test_that("spark.als", {
  # Small explicit-feedback rating triples: (user, item, score).
  ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0),
                  list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0))
  training <- createDataFrame(ratings, c("user", "item", "score"))

  model <- spark.als(training, ratingCol = "score", userCol = "user",
                     itemCol = "item", rank = 10, maxIter = 15, seed = 0,
                     regParam = 0.1)
  stats <- summary(model)
  expect_equal(stats$rank, 10)

  # Predict on unseen (user, item) pairs and pin the expected scores.
  testData <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)),
                              c("user", "item"))
  predicted <- predict(model, testData)
  ordered <- collect(arrange(predicted, desc(predicted$item), predicted$user))
  expect_equal(ordered$prediction, c(0.6324540, 3.6218479, -0.4568263),
               tolerance = 1e-4)

  # Test model save/load: round-trip through write.ml/read.ml and verify the
  # reloaded model matches the original.
  if (windows_with_hadoop()) {
    modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp")
    write.ml(model, modelPath)
    # Writing to an existing path must fail unless overwrite is requested.
    expect_error(write.ml(model, modelPath))
    write.ml(model, modelPath, overwrite = TRUE)
    model2 <- read.ml(modelPath)
    stats2 <- summary(model2)
    expect_equal(stats2$rating, "score")

    userFactors <- collect(stats$userFactors)
    itemFactors <- collect(stats$itemFactors)
    userFactors2 <- collect(stats2$userFactors)
    itemFactors2 <- collect(stats2$itemFactors)

    # Factor rows may come back in any order; align by id before comparing.
    orderUser <- order(userFactors$id)
    orderUser2 <- order(userFactors2$id)
    expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2])
    expect_equal(userFactors$features[orderUser],
                 userFactors2$features[orderUser2])
    orderItem <- order(itemFactors$id)
    orderItem2 <- order(itemFactors2$id)
    expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2])
    expect_equal(itemFactors$features[orderItem],
                 itemFactors2$features[orderItem2])

    unlink(modelPath)
  }
})
# Tear down the Spark session created at the top of this test file.
sparkR.session.stop()