blob: b944a4a82c146a8ef8a709734affd54ce6431322 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""PubSub verifier used for end-to-end test."""
# pytype: skip-file
import logging
import time
from collections import Counter
from hamcrest.core.base_matcher import BaseMatcher
from apache_beam.io.gcp.pubsub import PubsubMessage
__all__ = ['PubSubMessageMatcher']
# Protect against environments where pubsub library is not available.
try:
from google.cloud import pubsub
except ImportError:
pubsub = None
# Default overall wait, in seconds, for all expected messages (5 minutes).
DEFAULT_TIMEOUT = 300
# Default pause, in seconds, between consecutive pulls.
DEFAULT_SLEEP_TIME = 1
# Default upper bound on messages requested by a single pull call.
DEFAULT_MAX_MESSAGES_IN_ONE_PULL = 50
# Default per-pull timeout, in seconds, forwarded to the subscriber client.
DEFAULT_PULL_TIMEOUT = 30.0

_LOGGER = logging.getLogger(name=__name__)
class PubSubMessageMatcher(BaseMatcher):
  """Matcher that verifies messages from a given subscription.

  This matcher can block the test and keep pulling messages from the given
  subscription until all expected messages have arrived or the timeout
  expires. Every pulled message is acknowledged. The pulled messages are
  cached on the first evaluation, so ``describe_mismatch`` reports on exactly
  the messages that ``_matches`` compared.
  """
  def __init__(
      self,
      project,
      sub_name,
      expected_msg=None,
      expected_msg_len=None,
      timeout=DEFAULT_TIMEOUT,
      with_attributes=False,
      strip_attributes=None,
      sleep_time=DEFAULT_SLEEP_TIME,
      max_messages_in_one_pull=DEFAULT_MAX_MESSAGES_IN_ONE_PULL,
      pull_timeout=DEFAULT_PULL_TIMEOUT):
    """Initialize PubSubMessageMatcher object.

    Args:
      project: A name string of project.
      sub_name: A name string of subscription which is attached to output.
      expected_msg: A string list that contains expected message data pulled
        from the subscription. See also: with_attributes.
      expected_msg_len: Number of expected messages pulled from the
        subscription. Defaults to ``len(expected_msg)``.
      timeout: Timeout in seconds to wait for all expected messages to appear.
      with_attributes: If True, will match against both message data and
        attributes. If True, expected_msg should be a list of ``PubsubMessage``
        objects. Otherwise, it should be a list of ``bytes``.
      strip_attributes: List of strings. If with_attributes==True, strip the
        attributes keyed by these values from incoming messages.
        If a key is missing, will add an attribute with an error message as
        value to prevent a successful match.
      sleep_time: Time in seconds to sleep between pulls from pubsub.
      max_messages_in_one_pull: Maximum number of messages pulled from pubsub
        at once.
      pull_timeout: Timeout in seconds passed to each individual pull call.

    Raises:
      ImportError: if the google-cloud-pubsub library is not installed.
      ValueError: if project or sub_name is empty, if neither expected_msg nor
        expected_msg_len is given, or if either has the wrong type.
    """
    if pubsub is None:
      raise ImportError('PubSub dependencies are not installed.')
    # Values are wrapped in 1-tuples so that tuple-valued arguments are
    # formatted instead of being unpacked by '%' (which would raise TypeError
    # rather than the intended ValueError).
    if not project:
      raise ValueError('Invalid project %s.' % (project,))
    if not sub_name:
      raise ValueError('Invalid subscription %s.' % (sub_name,))
    if not expected_msg_len and not expected_msg:
      raise ValueError(
          'Required expected_msg: {} or expected_msg_len: {}.'.format(
              expected_msg, expected_msg_len))
    if expected_msg and not isinstance(expected_msg, list):
      raise ValueError('Invalid expected messages %s.' % (expected_msg,))
    if expected_msg_len and not isinstance(expected_msg_len, int):
      raise ValueError(
          'Invalid expected message length %s.' % (expected_msg_len,))

    self.project = project
    self.sub_name = sub_name
    self.expected_msg = expected_msg
    self.expected_msg_len = expected_msg_len or len(self.expected_msg)
    self.timeout = timeout
    # Filled lazily by _matches() and reused by describe_mismatch().
    self.messages = None
    self.with_attributes = with_attributes
    self.strip_attributes = strip_attributes
    self.sleep_time = sleep_time
    self.max_messages_in_one_pull = max_messages_in_one_pull
    self.pull_timeout = pull_timeout

  def _matches(self, _):
    """Pull messages once, then compare them with the expectation.

    When expected_msg is set, the comparison is order-insensitive but
    multiplicity-sensitive (multiset equality). Otherwise only the number of
    received messages is checked against expected_msg_len.
    """
    if self.messages is None:
      self.messages = self._wait_for_messages(
          self.expected_msg_len, self.timeout)
    if self.expected_msg:
      return Counter(self.messages) == Counter(self.expected_msg)
    return len(self.messages) == self.expected_msg_len

  def _to_comparable(self, message):
    """Convert one pulled message into the form compared with expected_msg.

    Returns ``msg.data`` when with_attributes is False. Otherwise returns the
    ``PubsubMessage`` with every key in strip_attributes removed from its
    attributes. A listed key that is absent is set to an error marker instead,
    so that the match fails visibly.
    """
    msg = PubsubMessage._from_message(message)
    if not self.with_attributes:
      return msg.data
    for attr in self.strip_attributes or ():
      try:
        del msg.attributes[attr]
      except KeyError:
        msg.attributes[attr] = (
            'PubSubMessageMatcher error: '
            'expected attribute not found.')
    return msg

  def _wait_for_messages(self, expected_num, timeout):
    """Wait for messages from the given subscription.

    Repeatedly pulls and acknowledges messages until at least expected_num
    have been collected or ``timeout`` seconds have elapsed.

    Returns:
      The list of converted messages (see _to_comparable), in pull order.
    """
    total_messages = []
    sub_client = pubsub.SubscriberClient()
    try:
      start_time = time.time()
      while time.time() - start_time <= timeout:
        # Keyword arguments are accepted by both the older positional-style
        # and the newer request-style SubscriberClient signatures.
        response = sub_client.pull(
            subscription=self.sub_name,
            max_messages=self.max_messages_in_one_pull,
            return_immediately=True,
            timeout=self.pull_timeout)
        received = response.received_messages
        total_messages.extend(self._to_comparable(rm.message) for rm in received)
        ack_ids = [rm.ack_id for rm in received]
        if ack_ids:
          sub_client.acknowledge(subscription=self.sub_name, ack_ids=ack_ids)
        if len(total_messages) >= expected_num:
          break
        time.sleep(self.sleep_time)
      else:
        # Reached only when the deadline passed without collecting enough
        # messages. A successful pull that merely overran the deadline breaks
        # out above and is not reported as a timeout.
        _LOGGER.error(
            'Timeout after %d sec. Received %d messages from %s.',
            timeout,
            len(total_messages),
            self.sub_name)
    finally:
      # Release resources held by the client. Guarded because older
      # google-cloud-pubsub releases may not expose close().
      # TODO(review): confirm the minimum supported client version.
      close = getattr(sub_client, 'close', None)
      if close is not None:
        close()
    return total_messages

  def describe_to(self, description):
    description.append_text('Expected %d messages.' % self.expected_msg_len)

  def describe_mismatch(self, _, mismatch_description):
    """Report the received count and, if known, the multiset difference."""
    mismatch_description.append_text("Got %d messages. " % (len(self.messages)))
    if self.expected_msg:
      c_expected = Counter(self.expected_msg)
      c_actual = Counter(self.messages)
      mismatch_description.append_text(
          "Diffs (item, count):\n"
          " Expected but not in actual: %s\n"
          " Unexpected: %s" % ((c_expected - c_actual).items(),
                               (c_actual - c_expected).items()))
    if self.with_attributes and self.strip_attributes:
      mismatch_description.append_text(
          '\n Stripped attributes: %r' % self.strip_attributes)