blob: 8ccc28f019322cc3cc4032fa4ad0957110a9a64d [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
import copy
import logging
import unittest
from parameterized import parameterized
from apache_beam.ml.anomaly.base import AggregationFn
from apache_beam.ml.anomaly.base import AnomalyDetector
from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector
from apache_beam.ml.anomaly.base import ThresholdFn
from apache_beam.ml.anomaly.specifiable import _KNOWN_SPECIFIABLE
from apache_beam.ml.anomaly.specifiable import Spec
from apache_beam.ml.anomaly.specifiable import Specifiable
from apache_beam.ml.anomaly.specifiable import specifiable
class TestAnomalyDetector(unittest.TestCase):
  """Model-id defaults, recorded init kwargs and spec round-trips for
  `AnomalyDetector` subclasses under every `@specifiable` init mode."""
  def setUp(self) -> None:
    # Snapshot the global specifiable registry; the `@specifiable` classes
    # defined inside the test presumably add entries that tearDown discards.
    self.saved_specifiable = copy.deepcopy(_KNOWN_SPECIFIABLE)

  def tearDown(self) -> None:
    # Restore the registry in place (clear + update) rather than rebinding
    # the name, so every module holding a reference sees the restored state.
    _KNOWN_SPECIFIABLE.clear()
    _KNOWN_SPECIFIABLE.update(self.saved_specifiable)

  # Every combination of (on_demand_init, just_in_time_init).
  @parameterized.expand([(False, False), (True, False), (False, True),
                         (True, True)])
  def test_model_id_and_spec(self, on_demand_init, just_in_time_init):
    @specifiable(
        on_demand_init=on_demand_init, just_in_time_init=just_in_time_init)
    class DummyThreshold(ThresholdFn):
      def __init__(self, my_threshold_arg=None):
        ...

      def is_stateful(self):
        return False

      def threshold(self):
        ...

      def apply(self, x):
        ...

    @specifiable(
        on_demand_init=on_demand_init, just_in_time_init=just_in_time_init)
    class Dummy(AnomalyDetector):
      def __init__(self, my_arg=None, **kwargs):
        self._my_arg = my_arg
        super().__init__(**kwargs)

      def learn_one(self):
        ...

      def score_one(self):
        ...

      def __eq__(self, value) -> bool:
        # Compare every attribute the spec round-trip must restore, not just
        # the subclass-specific one, so that losing `model_id` or `target`
        # in to_spec/from_spec makes the round-trip assertion fail.
        return (
            isinstance(value, Dummy) and self._my_arg == value._my_arg and
            self._model_id == value._model_id and
            self._target == value._target)

    a = Dummy(
        my_arg="abc",
        target="ABC",
        threshold_criterion=(t1 := DummyThreshold(2)))

    # The instance attributes can only be read when either
    # (1) on_demand_init == False and just_in_time_init == False:
    #     the original __init__ runs immediately when the object is created;
    # or
    # (2) just_in_time_init == True:
    #     regardless of on_demand_init, the original __init__ is deferred and
    #     runs the first time one of the object's attributes is accessed.
    if just_in_time_init or not on_demand_init:
      self.assertEqual(a._model_id, "Dummy")
      self.assertEqual(a._target, "ABC")
      self.assertEqual(a._my_arg, "abc")

    # The isinstance assert lets static type checkers accept the
    # Specifiable-only `init_kwargs` attribute below.
    assert isinstance(a, Specifiable)
    # Only the explicitly passed arguments are recorded (no `model_id` here).
    self.assertEqual(
        a.init_kwargs, {
            "my_arg": "abc",
            "target": "ABC",
            "threshold_criterion": t1,
        })

    b = Dummy(
        my_arg="efg",
        model_id="my_dummy",
        target="EFG",
        threshold_criterion=(t2 := DummyThreshold(3)))

    # See the comment above for more details.
    if just_in_time_init or not on_demand_init:
      self.assertEqual(b._model_id, "my_dummy")
      self.assertEqual(b._target, "EFG")
      self.assertEqual(b._my_arg, "efg")

    assert isinstance(b, Specifiable)
    self.assertEqual(
        b.init_kwargs,
        {
            "model_id": "my_dummy",
            "my_arg": "efg",
            "target": "EFG",
            "threshold_criterion": t2,
        })

    # Nested specifiable arguments are serialized as nested Specs.
    spec = b.to_spec()
    expected_spec = Spec(
        type="Dummy",
        config={
            "my_arg": "efg",
            "model_id": "my_dummy",
            "target": "EFG",
            "threshold_criterion": Spec(
                type="DummyThreshold", config={"my_threshold_arg": 3}),
        })
    self.assertEqual(spec, expected_spec)

    b_dup = Specifiable.from_spec(spec)
    # With on_demand_init and without just_in_time_init, nothing triggers the
    # original __init__ of `b` automatically, so run it by hand before the
    # comparison reads its attributes. See the comment above for details.
    if on_demand_init and not just_in_time_init:
      b.run_original_init()
    self.assertEqual(b, b_dup)
class TestEnsembleAnomalyDetector(unittest.TestCase):
  """Model-id defaults and spec round-trips for `EnsembleAnomalyDetector`
  subclasses under every `@specifiable` init mode."""
  def setUp(self) -> None:
    # Private copy of the global registry, restored after each test.
    self.saved_specifiable = copy.deepcopy(_KNOWN_SPECIFIABLE)

  def tearDown(self) -> None:
    # Mutate the shared registry object in place instead of rebinding it.
    registry = _KNOWN_SPECIFIABLE
    registry.clear()
    registry.update(self.saved_specifiable)

  # Every combination of (on_demand_init, just_in_time_init).
  @parameterized.expand([(False, False), (True, False), (False, True),
                         (True, True)])
  def test_model_id_and_spec(self, on_demand_init, just_in_time_init):
    @specifiable(
        on_demand_init=on_demand_init, just_in_time_init=just_in_time_init)
    class DummyAggregation(AggregationFn):
      def apply(self, x):
        ...

    @specifiable(
        on_demand_init=on_demand_init, just_in_time_init=just_in_time_init)
    class DummyEnsemble(EnsembleAnomalyDetector):
      def __init__(self, my_ensemble_arg=None, **kwargs):
        super().__init__(**kwargs)
        self._my_ensemble_arg = my_ensemble_arg

      def learn_one(self):
        ...

      def score_one(self):
        ...

      def __eq__(self, value) -> bool:
        if not isinstance(value, DummyEnsemble):
          return False
        return self._my_ensemble_arg == value._my_ensemble_arg

    @specifiable(
        on_demand_init=on_demand_init, just_in_time_init=just_in_time_init)
    class DummyWeakLearner(AnomalyDetector):
      def __init__(self, my_arg=None, **kwargs):
        super().__init__(**kwargs)
        self._my_arg = my_arg

      def learn_one(self):
        ...

      def score_one(self):
        ...

      def __eq__(self, value) -> bool:
        if not isinstance(value, DummyWeakLearner):
          return False
        return self._my_arg == value._my_arg

    # Attributes are only readable in these init modes; see the comment in
    # TestAnomalyDetector for more details.
    if just_in_time_init or not on_demand_init:
      # (detector class, explicit model_id or None for "not passed",
      #  expected resulting model id)
      model_id_cases = (
          (DummyEnsemble, None, "DummyEnsemble"),
          (DummyEnsemble, "my_dummy_ensemble", "my_dummy_ensemble"),
          (EnsembleAnomalyDetector, None, "custom"),
          (
              EnsembleAnomalyDetector,
              "my_dummy_ensemble_2",
              "my_dummy_ensemble_2"),
      )
      for detector_cls, model_id, expected_id in model_id_cases:
        if model_id is None:
          detector = detector_cls()
        else:
          detector = detector_cls(model_id=model_id)
        self.assertEqual(detector._model_id, expected_id)

    learner_args = (1, 2)
    ensemble = DummyEnsemble(
        my_ensemble_arg=123,
        learners=[DummyWeakLearner(my_arg=arg) for arg in learner_args],
        aggregation_strategy=DummyAggregation())

    # Each learner and the aggregation strategy serialize as nested Specs.
    learner_specs = [
        Spec(type="DummyWeakLearner", config={"my_arg": arg})
        for arg in learner_args
    ]
    aggregation_spec = Spec(type="DummyAggregation", config={})
    expected_spec = Spec(
        type="DummyEnsemble",
        config={
            "my_ensemble_arg": 123,
            "learners": learner_specs,
            "aggregation_strategy": aggregation_spec,
        })

    assert isinstance(ensemble, Specifiable)
    spec = ensemble.to_spec()
    self.assertEqual(spec, expected_spec)

    ensemble_dup = Specifiable.from_spec(spec)
    # Only on-demand-without-just-in-time leaves `ensemble` uninitialized
    # here; see the comment in TestAnomalyDetector for more details.
    needs_manual_init = on_demand_init and not just_in_time_init
    if needs_manual_init:
      ensemble.run_original_init()
    self.assertEqual(ensemble, ensemble_dup)
if __name__ == '__main__':
  # When run directly, let INFO-level records from the root logger through
  # before handing control to the unittest runner.
  logging.root.setLevel(logging.INFO)
  unittest.main()