blob: a7035d00368eeab52e112a8c02444a717b316922 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Dict, List, Tuple
from unittest import TestCase
from unittest.mock import AsyncMock, MagicMock, patch
from streampipes.client.client import StreamPipesClient, StreamPipesClientConfig
from streampipes.client.credential_provider import StreamPipesApiKeyCredentials
from streampipes.function_zoo.river_function import OnlineML
from streampipes.functions.utils.data_stream_generator import (
RuntimeType,
create_data_stream,
)
from tests.functions.test_function_handler import TestMessageIterator
class TestUnsupervisedModel:
    """Minimal stand-in for an unsupervised online model.

    It exposes the ``learn_one`` / ``predict_one`` pair that the function under
    test calls, and records every sample passed to ``predict_one`` so that a
    test can inspect what the model was shown.
    """

    # The class name starts with ``Test`` but this is a helper, not a test
    # case. Opting out of collection keeps pytest from emitting a
    # PytestCollectionWarning about the ``__init__`` constructor.
    __test__ = False

    def __init__(self) -> None:
        # Samples received by ``predict_one``, in call order.
        self.data_x: List[Dict[str, Any]] = []

    def learn_one(self, x: Dict[str, Any]) -> None:
        """Accept a training sample and deliberately do nothing with it."""

    def predict_one(self, x: Dict[str, Any]) -> int:
        """Record ``x`` and return its ``"bool"`` entry converted to ``int``.

        Raises:
            KeyError: if ``x`` has no ``"bool"`` key.
        """
        self.data_x.append(x)
        return int(x["bool"])
class TestSupervisedModel:
    """Minimal stand-in for a supervised online model.

    It records the labels passed to ``learn_one`` and the samples passed to
    ``predict_one`` so that a test can inspect both afterwards.
    """

    # The class name starts with ``Test`` but this is a helper, not a test
    # case. Opting out of collection keeps pytest from emitting a
    # PytestCollectionWarning about the ``__init__`` constructor.
    __test__ = False

    def __init__(self) -> None:
        # Samples received by ``predict_one``, in call order.
        self.data_x: List[Dict[str, Any]] = []
        # Labels received by ``learn_one``, in call order. These are the raw
        # target values (e.g. booleans), not dictionaries.
        self.data_y: List[Any] = []

    def learn_one(self, x: Dict[str, Any], y: Any) -> None:
        """Record the label ``y``; the sample ``x`` itself is ignored."""
        self.data_y.append(y)

    def predict_one(self, x: Dict[str, Any]) -> bool:
        """Record ``x`` and predict ``True`` when ``x["number"]`` is below 13.

        Raises:
            KeyError: if ``x`` has no ``"number"`` key.
        """
        self.data_x.append(x)
        return x["number"] < 13
class TestRiverFunction(TestCase):
    """Run ``OnlineML`` against a mocked HTTP session and a mocked NATS broker.

    Both tests replay the same five events through the function and compare
    the events handed to ``NatsBroker.publish_event`` and the data seen by the
    model with the expected values.
    """

    def setUp(self) -> None:
        """Create the data stream description and the events replayed to the function."""
        self.data_stream = create_data_stream(
            "test", attributes={"number": RuntimeType.FLOAT.value, "bool": RuntimeType.BOOLEAN.value}
        ).to_dict()
        self.test_stream_data = [
            {"number": 10.3, "bool": True, "timestamp": 1670000001000},
            {"number": 13.4, "bool": False, "timestamp": 1670000002000},
            {"number": 12.6, "bool": True, "timestamp": 1670000003000},
            {"number": 9.0, "bool": True, "timestamp": 1670000004000},
            {"number": 10.6, "bool": False, "timestamp": 1670000005000},
        ]

    def _run_online_ml(
        self,
        http_session: MagicMock,
        publish_event: MagicMock,
        get_message: MagicMock,
        time: MagicMock,
        model: Any,
        **online_ml_kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """Wire up the mocks, run ``OnlineML`` with ``model`` and return the published events.

        Args:
            http_session: Mock replacing ``streampipes.client.client.Session``.
            publish_event: Mock replacing ``NatsBroker.publish_event``.
            get_message: Mock replacing ``NatsBroker.get_message``.
            time: Mock replacing ``time`` in ``streampipes.functions.streampipes_function``.
            model: Model instance handed to ``OnlineML``.
            **online_ml_kwargs: Extra keyword arguments forwarded to ``OnlineML``
                (e.g. ``prediction_type``, ``supervised``, ``target_label``).

        Returns:
            The events passed to ``publish_event``, in call order.
        """
        # Every GET issued through the session answers with the test data stream.
        session = MagicMock()
        session.get.return_value.json.return_value = self.data_stream
        http_session.return_value = session

        output_events: List[Dict[str, Any]] = []

        # ``publish_event`` is patched with autospec on the class, so the broker
        # instance arrives as the first argument; it is not needed here.
        def save_event(_broker, event: Dict[str, Any]):
            output_events.append(event)

        publish_event.side_effect = save_event
        get_message.return_value = TestMessageIterator(self.test_stream_data)
        # A constant clock value; the expected events all carry ``"timestamp": 0``.
        time.side_effect = lambda: 0

        client = StreamPipesClient(
            client_config=StreamPipesClientConfig(
                credential_provider=StreamPipesApiKeyCredentials(username="user", api_key="key"),
                host_address="localhost",
            )
        )

        # The callback receives the running function as its first argument.
        # Once the model has been shown three samples, learning is switched off.
        def on_event(function, event, stream_id):
            if len(function.model.data_x) >= 3:
                function.learning = False

        online_learning = OnlineML(
            client=client,
            stream_ids=["sp:spdatastream:xboBFK"],
            model=model,
            on_event=on_event,
            **online_ml_kwargs,
        )
        online_learning.start()
        return output_events

    # Decorators are applied bottom-up, so the mocks arrive in the order
    # Session, publish_event, get_message, time, followed by the remaining
    # broker patches, which are collected in ``*args`` and not used directly.
    @patch("streampipes.functions.broker.nats_broker.NatsBroker.disconnect", autospec=True)
    @patch("streampipes.functions.broker.nats_broker.NatsBroker.createSubscription", autospec=True)
    @patch("streampipes.functions.broker.nats_broker.NatsBroker._makeConnection", autospec=True)
    @patch("streampipes.functions.streampipes_function.time", autospec=True)
    @patch("streampipes.functions.broker.nats_broker.NatsBroker.get_message", autospec=True)
    @patch("streampipes.functions.broker.nats_broker.NatsBroker.publish_event", autospec=True)
    @patch("streampipes.client.client.Session", autospec=True)
    def test_river_function_unsupervised(
        self,
        http_session: MagicMock,
        publish_event: MagicMock,
        get_message: MagicMock,
        time: MagicMock,
        *args: Tuple[AsyncMock]
    ) -> None:
        """An unsupervised model predicts on every event; learning is flagged off from the fourth event on."""
        model = TestUnsupervisedModel()
        output_events = self._run_online_ml(
            http_session,
            publish_event,
            get_message,
            time,
            model,
            prediction_type=RuntimeType.INTEGER.value,
        )

        self.assertListEqual(
            output_events,
            [
                {"learning": True, "prediction": 1, "timestamp": 0},
                {"learning": True, "prediction": 0, "timestamp": 0},
                {"learning": True, "prediction": 1, "timestamp": 0},
                {"learning": False, "prediction": 1, "timestamp": 0},
                {"learning": False, "prediction": 0, "timestamp": 0},
            ],
        )
        self.assertListEqual(model.data_x, self.test_stream_data)

    # Same patch set and argument order as the unsupervised test above.
    @patch("streampipes.functions.broker.nats_broker.NatsBroker.disconnect", autospec=True)
    @patch("streampipes.functions.broker.nats_broker.NatsBroker.createSubscription", autospec=True)
    @patch("streampipes.functions.broker.nats_broker.NatsBroker._makeConnection", autospec=True)
    @patch("streampipes.functions.streampipes_function.time", autospec=True)
    @patch("streampipes.functions.broker.nats_broker.NatsBroker.get_message", autospec=True)
    @patch("streampipes.functions.broker.nats_broker.NatsBroker.publish_event", autospec=True)
    @patch("streampipes.client.client.Session", autospec=True)
    def test_river_function_supervised(
        self,
        http_session: MagicMock,
        publish_event: MagicMock,
        get_message: MagicMock,
        time: MagicMock,
        *args: Tuple[AsyncMock]
    ) -> None:
        """A supervised model gets features without the target label and learns only from the first three events."""
        model = TestSupervisedModel()
        output_events = self._run_online_ml(
            http_session,
            publish_event,
            get_message,
            time,
            model,
            prediction_type=RuntimeType.BOOLEAN.value,
            supervised=True,
            target_label="bool",
        )

        self.assertListEqual(
            output_events,
            [
                {"truth": True, "learning": True, "prediction": True, "timestamp": 0},
                {"truth": False, "learning": True, "prediction": False, "timestamp": 0},
                {"truth": True, "learning": True, "prediction": True, "timestamp": 0},
                {"truth": True, "learning": False, "prediction": True, "timestamp": 0},
                {"truth": False, "learning": False, "prediction": True, "timestamp": 0},
            ],
        )
        # The samples that reached the model are expected without the "bool" target.
        self.assertListEqual(
            model.data_x,
            [
                {"number": 10.3, "timestamp": 1670000001000},
                {"number": 13.4, "timestamp": 1670000002000},
                {"number": 12.6, "timestamp": 1670000003000},
                {"number": 9.0, "timestamp": 1670000004000},
                {"number": 10.6, "timestamp": 1670000005000},
            ],
        )
        # Only the labels of the first three events are expected to have been learned.
        self.assertListEqual(model.data_y, [True, False, True])