blob: 42275c852ac11c6829f3d0996c5180f5fbb81e78 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tempfile
import unittest
import numpy as np
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import (
LinearSVC,
LinearSVCModel,
OneVsRest,
OneVsRestModel,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
class OneVsRestTestsMixin:
    """OneVsRest test logic that relies on the host test case providing ``self.spark``."""

    def test_one_vs_rest(self):
        """Fit OneVsRest over LinearSVC on a four-row dataset and verify the outcome.

        Checks, in order: the estimator params round-trip through their getters,
        the fitted model holds three LinearSVCModel submodels whose intercepts and
        coefficients match recorded values (atol=1e-4), ``transform`` yields the
        expected columns and row count, and both the estimator and the fitted
        model keep the same string form after a save/load cycle.
        """
        rows = [
            (0, 1.0, Vectors.dense(0.0, 5.0)),
            (1, 0.0, Vectors.dense(1.0, 2.0)),
            (2, 1.0, Vectors.dense(2.0, 1.0)),
            (3, 2.0, Vectors.dense(3.0, 3.0)),
        ]
        spark = self.spark
        indexed = spark.createDataFrame(rows, ["index", "label", "features"])
        # One partition, sorted by index, then the helper column is dropped
        # (presumably to keep the row order, and thus the fit, reproducible).
        df = indexed.coalesce(1).sortWithinPartitions("index").select("label", "features")

        # Base binary classifier and its params.
        svc = LinearSVC(maxIter=1, regParam=1.0)
        self.assertEqual(svc.getMaxIter(), 1)
        self.assertEqual(svc.getRegParam(), 1.0)

        # Multiclass wrapper around the base classifier.
        ovr = OneVsRest(classifier=svc, parallelism=1)
        self.assertEqual(ovr.getParallelism(), 1)

        model = ovr.fit(df)
        self.assertIsInstance(model, OneVsRestModel)
        self.assertEqual(len(model.models), 3)
        for submodel in model.models:
            self.assertIsInstance(submodel, LinearSVCModel)

        # Recorded (intercept, coefficients) per submodel, in ``model.models`` order.
        recorded_params = [
            (0.06279247869226989, [-0.1198765502306968, -0.1027513287691687]),
            (0.025877458475338313, [-0.0362284418654736, 0.010350983390135305]),
            (-0.37024065419409624, [0.12886829400126, 0.012273170857262873]),
        ]
        for fitted, (intercept, coefficients) in zip(model.models, recorded_params):
            # The actual value is passed as the failure message to ease debugging.
            self.assertTrue(
                np.allclose(fitted.intercept, intercept, atol=1e-4),
                fitted.intercept,
            )
            self.assertTrue(
                np.allclose(fitted.coefficients.toArray(), coefficients, atol=1e-4),
                fitted.coefficients,
            )

        output = model.transform(df)
        self.assertEqual(
            output.columns,
            ["label", "features", "rawPrediction", "prediction"],
        )
        self.assertEqual(output.count(), 4)

        # Model save & load: the same directory is reused, each save overwrites it.
        with tempfile.TemporaryDirectory(prefix="linear_svc") as d:
            ovr.write().overwrite().save(d)
            reloaded_ovr = OneVsRest.load(d)
            self.assertEqual(str(ovr), str(reloaded_ovr))

            model.write().overwrite().save(d)
            reloaded_model = OneVsRestModel.load(d)
            self.assertEqual(str(model), str(reloaded_model))
class OneVsRestTests(OneVsRestTestsMixin, ReusedSQLTestCase):
    """Concrete test case: the OneVsRest mixin tests combined with ReusedSQLTestCase."""
if __name__ == "__main__":
    # Pull in the tests by package path (presumably this very module) so they
    # are visible in ``__main__`` when the file is run as a script.
    from pyspark.ml.tests.test_ovr import *  # noqa: F401,F403

    # Default to unittest's own runner; switch to XML reports only when the
    # optional xmlrunner package can be imported.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is not installed: keep the default runner.
        pass

    unittest.main(testRunner=testRunner, verbosity=2)