blob: ad3ee04e812e21de9ffaee8851d836f4df154695 [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LinearSVC
from pyspark.testing.connectutils import ReusedConnectTestCase
class MLConnectCacheTests(ReusedConnectTestCase):
    """Tests for the server-side ML model cache used by Spark Connect.

    Each fitted model is held in a server-side cache; the client tracks it
    through a reference-counted proxy (``model._java_obj``). These tests
    verify that cache entries are released when the client-side references
    go away, and that the cache can be purged explicitly.
    """

    def _make_dataset(self):
        """Return a small 4-row (label, weight, features) DataFrame.

        Coalesced to one partition and sorted by "weight" so fitting is
        deterministic across runs.
        """
        return (
            self.spark.createDataFrame(
                [
                    (1.0, 1.0, Vectors.dense(0.0, 5.0)),
                    (0.0, 2.0, Vectors.dense(1.0, 2.0)),
                    (1.0, 3.0, Vectors.dense(2.0, 1.0)),
                    (0.0, 4.0, Vectors.dense(3.0, 3.0)),
                ],
                ["label", "weight", "features"],
            )
            .coalesce(1)
            .sortWithinPartitions("weight")
        )

    def test_delete_model(self):
        """Deleting the client-side references must drop the cache entry."""
        spark = self.spark
        df = self._make_dataset()
        svc = LinearSVC(maxIter=1, regParam=1.0)
        model = svc.fit(df)

        cache_info = spark.client._get_ml_cache_info()
        self.assertEqual(len(cache_info), 1)
        self.assertEqual(
            json.loads(cache_info[0])["class"],
            "org.apache.spark.ml.classification.LinearSVCModel",
            cache_info,
        )
        # the `model._summary` holds another ref to the remote model.
        self.assertEqual(model._java_obj._ref_count, 2)

        model_size = spark.client._query_model_size(model._java_obj.ref_id)
        self.assertIsInstance(model_size, int)
        self.assertGreater(model_size, 0)

        # Copying shares the same remote object: still a single cache
        # entry, but one more reference on it.
        model2 = model.copy()
        cache_info = spark.client._get_ml_cache_info()
        self.assertEqual(len(cache_info), 1)
        self.assertEqual(model._java_obj._ref_count, 3)
        self.assertEqual(model2._java_obj._ref_count, 3)

        # explicitly delete the model
        del model
        cache_info = spark.client._get_ml_cache_info()
        self.assertEqual(len(cache_info), 1)
        # Note the copied model 'model2' also holds the `_summary` object,
        # and the `_summary` object holds another ref to the remote model.
        # so the ref count is 2.
        self.assertEqual(model2._java_obj._ref_count, 2)

        # Dropping the last reference removes the cache entry entirely.
        del model2
        cache_info = spark.client._get_ml_cache_info()
        self.assertEqual(len(cache_info), 0)

    def test_cleanup_ml_cache(self):
        """``_cleanup_ml_cache`` must purge all remaining cache entries."""
        spark = self.spark
        df = self._make_dataset()
        svc = LinearSVC(maxIter=1, regParam=1.0)
        model1 = svc.fit(df)
        model2 = svc.fit(df)
        model3 = svc.fit(df)
        # Touch all three models so none is garbage-collected early.
        self.assertEqual(len([model1, model2, model3]), 3)

        cache_info = spark.client._get_ml_cache_info()
        self.assertEqual(len(cache_info), 3)
        self.assertTrue(
            all(
                json.loads(c)["class"] == "org.apache.spark.ml.classification.LinearSVCModel"
                for c in cache_info
            ),
            cache_info,
        )

        # Deleting one client reference removes exactly that entry.
        del model1
        cache_info = spark.client._get_ml_cache_info()
        self.assertEqual(len(cache_info), 2)

        # Explicit cleanup purges the rest even though model2/model3 are
        # still alive on the client side.
        spark.client._cleanup_ml_cache()
        cache_info = spark.client._get_ml_cache_info()
        self.assertEqual(len(cache_info), 0)
if __name__ == "__main__":
    # Allow running this test file directly with `python <file>`;
    # pyspark.testing.main discovers and runs the unittest cases above.
    from pyspark.testing import main
    main()