blob: 7b949763c39884080a4707b6be474b30d2d63bde [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tempfile
import unittest
from pyspark.sql import is_remote, Row
import pyspark.sql.functions as sf
from pyspark.ml.fpm import (
FPGrowth,
FPGrowthModel,
PrefixSpan,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
class FPMTestsMixin:
def test_fp_growth(self):
if is_remote():
self.skipTest("Do not support Spark Connect.")
df = self.spark.createDataFrame(
[
["r z h k p"],
["z y x w v u t s"],
["s x o n r"],
["x z y m t s q e"],
["z"],
["x z y r q t p"],
],
["items"],
).select(sf.split("items", " ").alias("items"))
fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
fp.setNumPartitions(1)
self.assertEqual(fp.getMinSupport(), 0.2)
self.assertEqual(fp.getMinConfidence(), 0.7)
self.assertEqual(fp.getNumPartitions(), 1)
model = fp.fit(df)
self.assertEqual(fp.uid, model.uid)
self.assertEqual(model.freqItemsets.columns, ["items", "freq"])
self.assertEqual(model.freqItemsets.count(), 54)
self.assertEqual(
model.associationRules.columns,
["antecedent", "consequent", "confidence", "lift", "support"],
)
self.assertEqual(model.associationRules.count(), 89)
output = model.transform(df)
self.assertEqual(output.columns, ["items", "prediction"])
self.assertEqual(output.count(), 6)
# save & load
with tempfile.TemporaryDirectory(prefix="fp_growth") as d:
fp.write().overwrite().save(d)
fp2 = FPGrowth.load(d)
self.assertEqual(str(fp), str(fp2))
model.write().overwrite().save(d)
model2 = FPGrowthModel.load(d)
self.assertEqual(str(model), str(model2))
def test_prefix_span(self):
spark = self.spark
df = spark.createDataFrame(
[
Row(sequence=[[1, 2], [3]]),
Row(sequence=[[1], [3, 2], [1, 2]]),
Row(sequence=[[1, 2], [5]]),
Row(sequence=[[6]]),
]
)
ps = PrefixSpan()
ps.setMinSupport(0.5)
ps.setMaxPatternLength(5)
self.assertEqual(ps.getMinSupport(), 0.5)
self.assertEqual(ps.getMaxPatternLength(), 5)
output = ps.findFrequentSequentialPatterns(df)
self.assertEqual(output.columns, ["sequence", "freq"])
self.assertEqual(output.count(), 5)
head = output.sort("sequence").head()
self.assertEqual(head.sequence, [[1]])
self.assertEqual(head.freq, 3)
class FPMTests(FPMTestsMixin, ReusedSQLTestCase):
pass
if __name__ == "__main__":
from pyspark.ml.tests.test_fpm import * # noqa: F401,F403
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)