blob: 2f688d97a3091221ac592f6c7946043b53143aaa [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for Recommendations AI transforms."""
from __future__ import absolute_import
import unittest
import mock
import apache_beam as beam
from apache_beam.metrics import MetricsFilter
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
from google.cloud import recommendationengine
from apache_beam.ml.gcp import recommendations_ai
except ImportError:
recommendationengine = None
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
@unittest.skipIf(
    recommendationengine is None,
    "Recommendations AI dependencies not installed.")
class RecommendationsAICatalogItemTest(unittest.TestCase):
  """Tests for the catalog-item transforms in ``recommendations_ai``.

  The catalog client factory is patched to return a mock, so no real service
  is contacted. Each test runs a small pipeline and checks the ``api_calls``
  counter that the pipeline reports.
  """
  def setUp(self):
    self._mock_client = mock.Mock()
    # create_catalog_item() answers with an empty CatalogItem message.
    self._mock_client.create_catalog_item.return_value = (
        recommendationengine.CatalogItem())
    # import_catalog_items() returns a mock whose result() yields None.
    self.m2 = mock.Mock()
    self.m2.result.return_value = None
    self._mock_client.import_catalog_items.return_value = self.m2
    self._catalog_item = {
        "id": "12345",
        "title": "Sample laptop",
        "description": "Indisputably the most fantastic laptop ever created.",
        "language_code": "en",
        "category_hierarchies": [{
            "categories": ["Electronic", "Computers"]
        }]
    }

  def _assert_api_calls(self, result, expected_counter):
    """Assert that a finished pipeline reported ``expected_counter`` API calls.

    Args:
      result: The PipelineResult of a pipeline that has already finished.
      expected_counter: Expected value of the ``api_calls`` counter.
    """
    counters = result.metrics().query(
        MetricsFilter().with_name('api_calls'))['counters']
    # Fail loudly instead of silently skipping the check when the counter
    # is missing.
    self.assertTrue(counters, "No 'api_calls' counter was reported.")
    self.assertEqual(counters[0].result, expected_counter)

  def test_CreateCatalogItem(self):
    expected_counter = 1
    with mock.patch.object(recommendations_ai,
                           'get_recommendation_catalog_client',
                           return_value=self._mock_client):
      p = beam.Pipeline()
      _ = (
          p | "Create data" >> beam.Create([self._catalog_item])
          | "Create CatalogItem" >>
          recommendations_ai.CreateCatalogItem(project="test"))
      result = p.run()
      result.wait_until_finish()
      self._assert_api_calls(result, expected_counter)

  def test_ImportCatalogItems(self):
    # Two keyed elements are imported, and a single API call is expected.
    expected_counter = 1
    with mock.patch.object(recommendations_ai,
                           'get_recommendation_catalog_client',
                           return_value=self._mock_client):
      p = beam.Pipeline()
      _ = (
          p | "Create data" >> beam.Create([
              (self._catalog_item["id"], self._catalog_item),
              (self._catalog_item["id"], self._catalog_item)
          ]) | "Create CatalogItems" >>
          recommendations_ai.ImportCatalogItems(project="test"))
      result = p.run()
      result.wait_until_finish()
      self._assert_api_calls(result, expected_counter)
@unittest.skipIf(
    recommendationengine is None,
    "Recommendations AI dependencies not installed.")
class RecommendationsAIUserEventTest(unittest.TestCase):
  """Tests for the user-event transforms in ``recommendations_ai``.

  The user-event client factory is patched to return a mock, so no real
  service is contacted. Each test runs a small pipeline and checks the
  ``api_calls`` counter that the pipeline reports.
  """
  def setUp(self):
    self._mock_client = mock.Mock()
    # write_user_event() answers with an empty UserEvent message.
    self._mock_client.write_user_event.return_value = (
        recommendationengine.UserEvent())
    # import_user_events() returns a mock whose result() yields None.
    self.m2 = mock.Mock()
    self.m2.result.return_value = None
    self._mock_client.import_user_events.return_value = self.m2
    self._user_event = {
        "event_type": "page-visit", "user_info": {
            "visitor_id": "1"
        }
    }

  def _assert_api_calls(self, result, expected_counter):
    """Assert that a finished pipeline reported ``expected_counter`` API calls.

    Args:
      result: The PipelineResult of a pipeline that has already finished.
      expected_counter: Expected value of the ``api_calls`` counter.
    """
    counters = result.metrics().query(
        MetricsFilter().with_name('api_calls'))['counters']
    # Fail loudly instead of silently skipping the check when the counter
    # is missing.
    self.assertTrue(counters, "No 'api_calls' counter was reported.")
    self.assertEqual(counters[0].result, expected_counter)

  def test_CreateUserEvent(self):
    expected_counter = 1
    with mock.patch.object(recommendations_ai,
                           'get_recommendation_user_event_client',
                           return_value=self._mock_client):
      p = beam.Pipeline()
      _ = (
          p | "Create data" >> beam.Create([self._user_event])
          | "Create UserEvent" >>
          recommendations_ai.WriteUserEvent(project="test"))
      result = p.run()
      result.wait_until_finish()
      self._assert_api_calls(result, expected_counter)

  def test_ImportUserEvents(self):
    # Two events keyed by visitor_id are imported, and a single API call is
    # expected.
    expected_counter = 1
    with mock.patch.object(recommendations_ai,
                           'get_recommendation_user_event_client',
                           return_value=self._mock_client):
      p = beam.Pipeline()
      _ = (
          p | "Create data" >> beam.Create([
              (self._user_event["user_info"]["visitor_id"], self._user_event),
              (self._user_event["user_info"]["visitor_id"], self._user_event)
          ]) | "Create UserEvents" >>
          recommendations_ai.ImportUserEvents(project="test"))
      result = p.run()
      result.wait_until_finish()
      self._assert_api_calls(result, expected_counter)
@unittest.skipIf(
    recommendationengine is None,
    "Recommendations AI dependencies not installed.")
class RecommendationsAIPredictTest(unittest.TestCase):
  """Tests for ``recommendations_ai.PredictUserEvent``.

  The prediction client factory is patched to return a mock, so no real
  service is contacted. The test runs a small pipeline and checks the
  ``api_calls`` counter that the pipeline reports.
  """
  def setUp(self):
    self._mock_client = mock.Mock()
    # predict() answers with a one-element list holding an empty response.
    self._mock_client.predict.return_value = [
        recommendationengine.PredictResponse()
    ]
    self._user_event = {
        "event_type": "page-visit", "user_info": {
            "visitor_id": "1"
        }
    }

  def _assert_api_calls(self, result, expected_counter):
    """Assert that a finished pipeline reported ``expected_counter`` API calls.

    Args:
      result: The PipelineResult of a pipeline that has already finished.
      expected_counter: Expected value of the ``api_calls`` counter.
    """
    counters = result.metrics().query(
        MetricsFilter().with_name('api_calls'))['counters']
    # Fail loudly instead of silently skipping the check when the counter
    # is missing.
    self.assertTrue(counters, "No 'api_calls' counter was reported.")
    self.assertEqual(counters[0].result, expected_counter)

  def test_Predict(self):
    expected_counter = 1
    with mock.patch.object(recommendations_ai,
                           'get_recommendation_prediction_client',
                           return_value=self._mock_client):
      p = beam.Pipeline()
      _ = (
          p | "Create data" >> beam.Create([self._user_event])
          | "Prediction UserEvents" >> recommendations_ai.PredictUserEvent(
              project="test", placement_id="recently_viewed_default"))
      result = p.run()
      result.wait_until_finish()
      self._assert_api_calls(result, expected_counter)
if __name__ == "__main__":
  # Running this module directly executes every TestCase defined above.
  unittest.main()