| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from collections import defaultdict |
| import json |
| import logging |
| from random import randint |
| import socket |
| import threading |
| import traceback |
| import time |
| import urllib2 |
| import uuid |
| |
| from Types.ttypes import TNetworkAddress |
| from thrift.protocol import TBinaryProtocol |
| from thrift.server.TServer import TServer |
| from thrift.transport import TSocket |
| from thrift.transport import TTransport |
| |
| import StatestoreService.StatestoreSubscriber as Subscriber |
| import StatestoreService.StatestoreService as Statestore |
| from StatestoreService.StatestoreSubscriber import TUpdateStateResponse |
| from StatestoreService.StatestoreSubscriber import TTopicRegistration |
| from ErrorCodes.ttypes import TErrorCode |
| from Status.ttypes import TStatus |
| |
| from tests.common.environ import build_flavor_timeout |
| from tests.common.skip import SkipIfDockerizedCluster |
| |
| LOG = logging.getLogger('test_statestore') |
| |
| # Tests for the statestore. The StatestoreSubscriber class is a skeleton implementation of |
| # a Python-based statestore subscriber with additional hooks to allow testing. Each |
| # StatestoreSubscriber runs its own server so that the statestore may contact it. |
| # |
| # All tests in this file may be run in parallel. They assume that a statestore instance is |
| # already running, and is configured with out-of-the-box defaults (as is the case in our |
| # usual test environment) which govern failure-detector timeouts etc. |
| # |
| # These tests do not yet provide sufficient coverage. |
| # If no topic entries, do the first and second subscribers always get a callback? |
| # Adding topic entries to non-existant topic |
| # Test for from_version and to_version behavior |
| # Test with many concurrent subscribers |
| # Test that only the subscribed-to topics are sent |
| # Test that topic deletions take effect correctly. |
| |
| def get_statestore_subscribers(host='localhost', port=25010): |
| response = urllib2.urlopen("http://{0}:{1}/subscribers?json".format(host, port)) |
| page = response.read() |
| return json.loads(page) |
| |
| STATUS_OK = TStatus(TErrorCode.OK) |
| DEFAULT_UPDATE_STATE_RESPONSE = TUpdateStateResponse(status=STATUS_OK, topic_updates=[], |
| skipped=False) |
| |
| # IMPALA-3501: the timeout needs to be higher in code coverage builds |
| WAIT_FOR_FAILURE_TIMEOUT = build_flavor_timeout(40, code_coverage_build_timeout=60) |
| WAIT_FOR_HEARTBEAT_TIMEOUT = build_flavor_timeout( |
| 40, code_coverage_build_timeout=60) |
| WAIT_FOR_UPDATE_TIMEOUT = build_flavor_timeout(40, code_coverage_build_timeout=60) |
| |
| class WildcardServerSocket(TSocket.TSocketBase, TTransport.TServerTransportBase): |
| """Specialised server socket that binds to a random port at construction""" |
| def __init__(self, host=None, port=0): |
| self.host = host |
| self.handle = None |
| self.handle = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| self.handle.bind(('localhost', 0)) |
| _, self.port = self.handle.getsockname() |
| |
| def listen(self): |
| self.handle.listen(128) |
| |
| def accept(self): |
| client, addr = self.handle.accept() |
| result = TSocket.TSocket() |
| result.setHandle(client) |
| return result |
| |
| |
| class KillableThreadedServer(TServer): |
| """Based on TServer.TThreadedServer, this server may be shutdown (by calling |
| shutdown()), after which no new connections may be made. Most of the implementation is |
| directly copied from Thrift.""" |
| def __init__(self, *args, **kwargs): |
| TServer.__init__(self, *args) |
| self.daemon = kwargs.get("daemon", False) |
| self.is_shutdown = False |
| self.port = self.serverTransport.port |
| |
| def shutdown(self): |
| self.is_shutdown = True |
| self.serverTransport.close() |
| self.wait_until_down() |
| # The processor contains a reference to a StatestoreSubscriber. Clean up that |
| # reference to avoid a circular reference that would prevent object deletion. |
| self.processor = None |
| |
| def wait_until_up(self, num_tries=10): |
| for i in xrange(num_tries): |
| cnxn = TSocket.TSocket('localhost', self.port) |
| try: |
| cnxn.open() |
| return |
| except Exception, e: |
| if i == num_tries - 1: raise |
| time.sleep(0.1) |
| |
| def wait_until_down(self, num_tries=10): |
| for i in xrange(num_tries): |
| cnxn = TSocket.TSocket('localhost', self.port) |
| try: |
| cnxn.open() |
| time.sleep(0.1) |
| except Exception, e: |
| return |
| raise Exception("Server did not stop") |
| |
| def serve(self): |
| self.serverTransport.listen() |
| while not self.is_shutdown: |
| client = self.serverTransport.accept() |
| # Since accept() can take a while, check again if the server is shutdown to avoid |
| # starting an unnecessary thread. |
| if self.is_shutdown: return |
| t = threading.Thread(target=self.handle, args=(client,)) |
| t.setDaemon(self.daemon) |
| t.start() |
| |
| def handle(self, client): |
| itrans = self.inputTransportFactory.getTransport(client) |
| otrans = self.outputTransportFactory.getTransport(client) |
| iprot = self.inputProtocolFactory.getProtocol(itrans) |
| oprot = self.outputProtocolFactory.getProtocol(otrans) |
| try: |
| while not self.is_shutdown: |
| self.processor.process(iprot, oprot) |
| except TTransport.TTransportException, tx: |
| pass |
| except Exception, x: |
| print x |
| |
| itrans.close() |
| otrans.close() |
| |
| class StatestoreSubscriber(object): |
| """A bare-bones subscriber skeleton. Tests should create a new StatestoreSubscriber(), |
| call start() and then register(). The subscriber will run a Thrift server on an unused |
| port, and after registration the statestore will call Heartbeat() and UpdateState() via |
| RPC. Tests can provide callbacks to the constructor that will be called during those |
| RPCs, and this is the easiest way to check that the statestore protocol is being |
| correctly followed. Tests should use wait_for_* methods to confirm that some event (like |
| an RPC call) has happened asynchronously. |
| |
| Since RPC callbacks will execute on a different thread from the main one, any assertions |
| there will not trigger a test failure without extra plumbing. What we do is simple: any |
| exceptions during an RPC are caught and stored, and the check_thread_exceptions() method |
| will re-raise them. |
| |
| The methods that may be called by a test deliberately return 'self' to allow for |
| chaining, see test_failure_detected() for an example of how this makes the test flow |
| more readable.""" |
| def __init__(self, heartbeat_cb=None, update_cb=None): |
| self.heartbeat_event, self.heartbeat_count = threading.Condition(), 0 |
| # Track the number of updates received per topic. |
| self.update_counts = defaultdict(lambda : 0) |
| # Variables to notify for updates on each topic. |
| self.update_event = threading.Condition() |
| self.heartbeat_cb, self.update_cb = heartbeat_cb, update_cb |
| self.subscriber_id = "python-test-client-%s" % uuid.uuid1() |
| self.exception = None |
| |
| def __enter__(self): |
| return self |
| |
| def __exit__(self, *args): |
| self.kill() |
| self.wait_for_failure() |
| |
| def Heartbeat(self, args): |
| """Heartbeat RPC handler. Calls heartbeat callback if one exists.""" |
| self.heartbeat_event.acquire() |
| try: |
| self.heartbeat_count += 1 |
| response = Subscriber.THeartbeatResponse() |
| if self.heartbeat_cb is not None and self.exception is None: |
| try: |
| response = self.heartbeat_cb(self, args) |
| except Exception, e: |
| self.exception = e |
| self.heartbeat_event.notify() |
| finally: |
| self.heartbeat_event.release() |
| return response |
| |
| def UpdateState(self, args): |
| """UpdateState RPC handler. Calls update callback if one exists.""" |
| self.update_event.acquire() |
| try: |
| for topic_name in args.topic_deltas: self.update_counts[topic_name] += 1 |
| response = DEFAULT_UPDATE_STATE_RESPONSE |
| if self.update_cb is not None and self.exception is None: |
| try: |
| response = self.update_cb(self, args) |
| except Exception, e: |
| # Print the original backtrace so it doesn't get lost. |
| traceback.print_exc() |
| self.exception = e |
| self.update_event.notify() |
| finally: |
| self.update_event.release() |
| return response |
| |
| def __init_server(self): |
| processor = Subscriber.Processor(self) |
| transport = WildcardServerSocket() |
| tfactory = TTransport.TBufferedTransportFactory() |
| pfactory = TBinaryProtocol.TBinaryProtocolFactory() |
| self.server = KillableThreadedServer(processor, transport, tfactory, pfactory, |
| daemon=True) |
| self.server_thread = threading.Thread(target=self.server.serve) |
| self.server_thread.setDaemon(True) |
| self.server_thread.start() |
| self.server.wait_until_up() |
| self.port = self.server.port |
| |
| def __init_client(self): |
| self.client_transport = \ |
| TTransport.TBufferedTransport(TSocket.TSocket('localhost', 24000)) |
| self.protocol = TBinaryProtocol.TBinaryProtocol(self.client_transport) |
| self.client = Statestore.Client(self.protocol) |
| self.client_transport.open() |
| |
| def check_thread_exceptions(self): |
| """Checks if an exception was raised and stored in a callback thread""" |
| if self.exception is not None: raise self.exception |
| |
| def kill(self): |
| """Closes both the server and client sockets, and waits for the server to become |
| unavailable""" |
| if self.client_transport: |
| self.client_transport.close() |
| if self.server: |
| self.server.shutdown() |
| return self |
| |
| def start(self): |
| """Starts a subscriber server, and opens a client to the statestore. Returns only when |
| the server is running.""" |
| self.__init_server() |
| self.__init_client() |
| return self |
| |
| def register(self, topics=None): |
| """Call the Register() RPC""" |
| if topics is None: topics = [] |
| request = Subscriber.TRegisterSubscriberRequest( |
| topic_registrations=topics, |
| subscriber_location=TNetworkAddress("localhost", self.port), |
| subscriber_id=self.subscriber_id) |
| response = self.client.RegisterSubscriber(request) |
| if response.status.status_code == TErrorCode.OK: |
| self.registration_id = response.registration_id |
| else: |
| raise Exception("Registration failed: %s, %s" % |
| (response.status.status_code, |
| '\n'.join(response.status.error_msgs))) |
| return self |
| |
| def wait_for_heartbeat(self, count=None): |
| """Waits for some number of heartbeats. If 'count' is provided, waits until the number |
| of heartbeats seen by this subscriber exceeds count, otherwise waits for one further |
| heartbeat.""" |
| self.heartbeat_event.acquire() |
| try: |
| if count is not None and self.heartbeat_count >= count: return self |
| if count is None: count = self.heartbeat_count + 1 |
| while count > self.heartbeat_count: |
| self.check_thread_exceptions() |
| last_count = self.heartbeat_count |
| self.heartbeat_event.wait(WAIT_FOR_HEARTBEAT_TIMEOUT) |
| if last_count == self.heartbeat_count: |
| raise Exception( |
| "Heartbeat not received within {0}s (heartbeat count: {1})".format( |
| WAIT_FOR_HEARTBEAT_TIMEOUT, self.heartbeat_count)) |
| self.check_thread_exceptions() |
| return self |
| finally: |
| self.heartbeat_event.release() |
| |
| def wait_for_update(self, topic_name, count=None): |
| """Waits for some number of updates of 'topic_name'. If 'count' is provided, waits |
| until the number updates seen by this subscriber exceeds count, otherwise waits |
| for one further update.""" |
| self.update_event.acquire() |
| start_time = time.time() |
| try: |
| if count is not None and self.update_counts[topic_name] >= count: return self |
| if count is None: count = self.update_counts[topic_name] + 1 |
| while count > self.update_counts[topic_name]: |
| self.check_thread_exceptions() |
| last_count = self.update_counts[topic_name] |
| self.update_event.wait(WAIT_FOR_UPDATE_TIMEOUT) |
| if (time.time() > start_time + WAIT_FOR_UPDATE_TIMEOUT and |
| last_count == self.update_counts[topic_name]): |
| raise Exception( |
| "Update not received for {0} within {1} (update count: {2})".format( |
| topic_name, WAIT_FOR_UPDATE_TIMEOUT, last_count)) |
| self.check_thread_exceptions() |
| return self |
| finally: |
| self.update_event.release() |
| |
| |
| def wait_for_failure(self, timeout=WAIT_FOR_FAILURE_TIMEOUT): |
| """Waits until this subscriber no longer appears in the statestore's subscriber |
| list. If 'timeout' seconds pass, throws an exception.""" |
| start = time.time() |
| while True: |
| subs = [s["id"] for s in get_statestore_subscribers()["subscribers"]] |
| if self.subscriber_id not in subs: return self |
| if time.time() - start > timeout: |
| raise Exception("Subscriber {0} did not fail in {1}s".format( |
| self.subscriber_id, timeout)) |
| time.sleep(0.2) |
| |
| |
| @SkipIfDockerizedCluster.statestore_not_exposed |
| class TestStatestore(): |
| def make_topic_update(self, topic_name, key_template="foo", value_template="bar", |
| num_updates=1, clear_topic_entries=False): |
| topic_entries = [ |
| Subscriber.TTopicItem(key=key_template + str(x), value=value_template + str(x)) |
| for x in xrange(num_updates)] |
| return Subscriber.TTopicDelta(topic_name=topic_name, |
| topic_entries=topic_entries, |
| is_delta=False, |
| clear_topic_entries=clear_topic_entries) |
| |
| def test_registration_ids_different(self): |
| """Test that if a subscriber with the same id registers twice, the registration ID is |
| different""" |
| with StatestoreSubscriber() as sub: |
| sub.start().register() |
| old_reg_id = sub.registration_id |
| sub.register() |
| assert old_reg_id != sub.registration_id |
| |
| def test_receive_heartbeats(self): |
| """Smoke test to confirm that heartbeats get sent to a correctly registered |
| subscriber""" |
| with StatestoreSubscriber() as sub: |
| sub.start().register().wait_for_heartbeat(5) |
| |
| def test_receive_updates(self): |
| """Test that updates are correctly received when a subscriber alters a topic""" |
| topic_name = "topic_delta_%s" % uuid.uuid1() |
| |
| def topic_update_correct(sub, args): |
| delta = self.make_topic_update(topic_name) |
| update_count = sub.update_counts[topic_name] |
| if topic_name not in args.topic_deltas: |
| # The update doesn't contain our topic. |
| pass |
| elif update_count == 1: |
| return TUpdateStateResponse(status=STATUS_OK, topic_updates=[delta], |
| skipped=False) |
| elif update_count == 2: |
| assert len(args.topic_deltas) == 1, args.topic_deltas |
| assert args.topic_deltas[topic_name].topic_entries == delta.topic_entries |
| assert args.topic_deltas[topic_name].topic_name == delta.topic_name |
| elif update_count == 3: |
| # After the content-bearing update was processed, the next delta should be empty |
| assert len(args.topic_deltas[topic_name].topic_entries) == 0 |
| |
| return DEFAULT_UPDATE_STATE_RESPONSE |
| |
| with StatestoreSubscriber(update_cb=topic_update_correct) as sub: |
| reg = TTopicRegistration(topic_name=topic_name, is_transient=False) |
| ( |
| sub.start() |
| .register(topics=[reg]) |
| .wait_for_update(topic_name, 3) |
| ) |
| |
| def test_filter_prefix(self): |
| topic_name = "topic_delta_%s" % uuid.uuid1() |
| |
| def topic_update_correct(sub, args): |
| foo_delta = self.make_topic_update(topic_name, num_updates=1) |
| bar_delta = self.make_topic_update(topic_name, num_updates=2, key_template='bar') |
| |
| update_count = sub.update_counts[topic_name] |
| if topic_name not in args.topic_deltas: |
| # The update doesn't contain our topic. |
| pass |
| elif update_count == 1: |
| # Send some values with both prefixes. |
| return TUpdateStateResponse(status=STATUS_OK, |
| topic_updates=[foo_delta, bar_delta], |
| skipped=False) |
| elif update_count == 2: |
| # We should only get the 'bar' entries back. |
| assert len(args.topic_deltas) == 1, args.topic_deltas |
| assert args.topic_deltas[topic_name].topic_entries == bar_delta.topic_entries |
| assert args.topic_deltas[topic_name].topic_name == bar_delta.topic_name |
| elif update_count == 3: |
| # Send some more updates that only have 'foo' prefixes. |
| return TUpdateStateResponse(status=STATUS_OK, |
| topic_updates=[foo_delta], |
| skipped=False) |
| elif update_count == 4: |
| # We shouldn't see any entries from the above update, but we should still see |
| # the version number change due to the new entries in the topic. |
| assert len(args.topic_deltas[topic_name].topic_entries) == 0 |
| assert args.topic_deltas[topic_name].from_version == 3 |
| assert args.topic_deltas[topic_name].to_version == 4 |
| elif update_count == 5: |
| # After the content-bearing update was processed, the next delta should be empty |
| assert len(args.topic_deltas[topic_name].topic_entries) == 0 |
| assert args.topic_deltas[topic_name].from_version == 4 |
| assert args.topic_deltas[topic_name].to_version == 4 |
| |
| return DEFAULT_UPDATE_STATE_RESPONSE |
| |
| with StatestoreSubscriber(update_cb=topic_update_correct) as sub: |
| reg = TTopicRegistration(topic_name=topic_name, is_transient=False, |
| filter_prefix="bar") |
| ( |
| sub.start() |
| .register(topics=[reg]) |
| .wait_for_update(topic_name, 5) |
| ) |
| |
| def test_update_is_delta(self): |
| """Test that the 'is_delta' flag is correctly set. The first update for a topic should |
| always not be a delta, and so should all subsequent updates until the subscriber says |
| it has not skipped the update.""" |
| topic_name = "test_update_is_delta_%s" % uuid.uuid1() |
| |
| def check_delta(sub, args): |
| update_count = sub.update_counts[topic_name] |
| if topic_name not in args.topic_deltas: |
| # The update doesn't contain our topic. |
| pass |
| elif update_count == 1: |
| assert args.topic_deltas[topic_name].is_delta == False |
| delta = self.make_topic_update(topic_name) |
| return TUpdateStateResponse(status=STATUS_OK, topic_updates=[delta], |
| skipped=False) |
| elif update_count == 2: |
| assert args.topic_deltas[topic_name].is_delta == False |
| elif update_count == 3: |
| assert args.topic_deltas[topic_name].is_delta == True |
| assert len(args.topic_deltas[topic_name].topic_entries) == 0 |
| assert args.topic_deltas[topic_name].to_version == 1 |
| |
| return DEFAULT_UPDATE_STATE_RESPONSE |
| |
| with StatestoreSubscriber(update_cb=check_delta) as sub: |
| reg = TTopicRegistration(topic_name=topic_name, is_transient=False) |
| ( |
| sub.start() |
| .register(topics=[reg]) |
| .wait_for_update(topic_name, 3) |
| ) |
| |
| def test_skipped(self): |
| """Test that skipping an update causes it to be resent""" |
| topic_name = "test_skipped_%s" % uuid.uuid1() |
| |
| def check_skipped(sub, args): |
| # Ignore responses that don't contain our topic. |
| if topic_name not in args.topic_deltas: return DEFAULT_UPDATE_STATE_RESPONSE |
| update_count = sub.update_counts[topic_name] |
| if update_count == 1: |
| update = self.make_topic_update(topic_name) |
| return TUpdateStateResponse(status=STATUS_OK, topic_updates=[update], |
| skipped=False) |
| # All subsequent updates: set skipped=True and expected the full topic to be resent |
| # every time |
| assert args.topic_deltas[topic_name].is_delta == False |
| assert len(args.topic_deltas[topic_name].topic_entries) == 1 |
| return TUpdateStateResponse(status=STATUS_OK, skipped=True) |
| |
| with StatestoreSubscriber(update_cb=check_skipped) as sub: |
| reg = TTopicRegistration(topic_name=topic_name, is_transient=False) |
| ( |
| sub.start() |
| .register(topics=[reg]) |
| .wait_for_update(topic_name, 3) |
| ) |
| |
| def test_failure_detected(self): |
| with StatestoreSubscriber() as sub: |
| topic_name = "test_failure_detected" |
| reg = TTopicRegistration(topic_name=topic_name, is_transient=True) |
| ( |
| sub.start() |
| .register(topics=[reg]) |
| .wait_for_update(topic_name, 1) |
| .kill() |
| .wait_for_failure() |
| ) |
| |
| def test_hung_heartbeat(self): |
| """Test for IMPALA-1712: If heartbeats hang (which we simulate by sleeping for five |
| minutes) the statestore should time them out every 3s and then eventually fail after |
| 40s (10 times (3 + 1), where the 1 is the inter-heartbeat delay)""" |
| with StatestoreSubscriber(heartbeat_cb=lambda sub, args: time.sleep(300)) as sub: |
| topic_name = "test_hung_heartbeat" |
| reg = TTopicRegistration(topic_name=topic_name, is_transient=True) |
| ( |
| sub.start() |
| .register(topics=[reg]) |
| .wait_for_update(topic_name, 1) |
| .wait_for_failure(timeout=60) |
| ) |
| |
| def test_intermittent_hung_heartbeats(self): |
| """Heartbeats that occasionally time out should not cause a failure to be detected.""" |
| heartbeat_count = [0] # Use array to allow mutating from inside callback. |
| |
| def heartbeat_cb(sub, args): |
| heartbeat_count[0] += 1 |
| # Delay every second heartbeat. |
| if (heartbeat_count[0] % 2 == 1): |
| time.sleep(4) |
| return Subscriber.THeartbeatResponse() |
| |
| with StatestoreSubscriber(heartbeat_cb=heartbeat_cb) as sub: |
| topic_name = "test_intermittent_hung_heartbeats" |
| reg = TTopicRegistration(topic_name=topic_name, is_transient=True) |
| ( |
| sub.start() |
| .register(topics=[reg]) |
| .wait_for_update(topic_name, 30) |
| .kill() |
| .wait_for_failure() |
| ) |
| |
| def test_slow_subscriber(self): |
| """Test for IMPALA-6644: This test kills a healthy subscriber and sleeps for multiple |
| intervals of about 1 second each, this lets the heartbeats to the subscriber fail. |
| It polls the subscribers page of the statestore to ensure that the |
| 'secs_since_heartbeat' field is updated with an acceptable value. This test only |
| checks for a strictly increasing value since the actual value of time might depend |
| on the system load. It stops polling the page once the subscriber is removed from |
| the set of active subscribers. It also checks that a valid heartbeat record of the |
| subscriber is found at least once.""" |
| sub = StatestoreSubscriber() |
| sub.start().register().wait_for_heartbeat(1) |
| sub.kill() |
| # secs_since_heartbeat is initially unknown. |
| secs_since_heartbeat = -1 |
| valid_heartbeat_record = False |
| while secs_since_heartbeat != 0: |
| sleep_start_time = time.time() |
| while time.time() - sleep_start_time < 1: |
| time.sleep(0.1) |
| prev_secs_since_heartbeat = secs_since_heartbeat |
| secs_since_heartbeat = 0 |
| subscribers = get_statestore_subscribers()["subscribers"] |
| for s in subscribers: |
| if str(s["id"]) == sub.subscriber_id: |
| secs_since_heartbeat = float(s["secs_since_heartbeat"]) |
| assert (secs_since_heartbeat > prev_secs_since_heartbeat) |
| valid_heartbeat_record = True |
| assert valid_heartbeat_record |
| |
| def test_topic_persistence(self): |
| """Test that persistent topic entries survive subscriber failure, but transent topic |
| entries are erased when the associated subscriber fails""" |
| topic_id = str(uuid.uuid1()) |
| persistent_topic_name = "test_topic_persistence_persistent_%s" % topic_id |
| transient_topic_name = "test_topic_persistence_transient_%s" % topic_id |
| |
| def add_entries(sub, args): |
| # None of, one or both of the topics may be in the update. |
| updates = [] |
| if (persistent_topic_name in args.topic_deltas and |
| sub.update_counts[persistent_topic_name] == 1): |
| updates.append(self.make_topic_update(persistent_topic_name)) |
| |
| if (transient_topic_name in args.topic_deltas and |
| sub.update_counts[transient_topic_name] == 1): |
| updates.append(self.make_topic_update(transient_topic_name)) |
| |
| if len(updates) > 0: |
| return TUpdateStateResponse(status=STATUS_OK, topic_updates=updates, |
| skipped=False) |
| return DEFAULT_UPDATE_STATE_RESPONSE |
| |
| def check_entries(sub, args): |
| # None of, one or both of the topics may be in the update. |
| if (persistent_topic_name in args.topic_deltas and |
| sub.update_counts[persistent_topic_name] == 1): |
| assert len(args.topic_deltas[persistent_topic_name].topic_entries) == 1 |
| # Statestore should not send deletions when the update is not a delta, see |
| # IMPALA-1891 |
| assert args.topic_deltas[persistent_topic_name].topic_entries[0].deleted == False |
| if (transient_topic_name in args.topic_deltas and |
| sub.update_counts[persistent_topic_name] == 1): |
| assert len(args.topic_deltas[transient_topic_name].topic_entries) == 0 |
| return DEFAULT_UPDATE_STATE_RESPONSE |
| |
| reg = [TTopicRegistration(topic_name=persistent_topic_name, is_transient=False), |
| TTopicRegistration(topic_name=transient_topic_name, is_transient=True)] |
| |
| with StatestoreSubscriber(update_cb=add_entries) as sub: |
| ( |
| sub.start() |
| .register(topics=reg) |
| .wait_for_update(persistent_topic_name, 2) |
| .wait_for_update(transient_topic_name, 2) |
| .kill() |
| .wait_for_failure() |
| ) |
| |
| with StatestoreSubscriber(update_cb=check_entries) as sub2: |
| ( |
| sub2.start() |
| .register(topics=reg) |
| .wait_for_update(persistent_topic_name, 1) |
| .wait_for_update(transient_topic_name, 1) |
| ) |
| |
| def test_update_with_clear_entries_flag(self): |
| """Test that the statestore clears all topic entries when a subscriber |
| sets the clear_topic_entries flag in a topic update message (IMPALA-6948).""" |
| topic_name = "test_topic_%s" % str(uuid.uuid1()) |
| |
| def add_entries(sub, args): |
| updates = [] |
| if (topic_name in args.topic_deltas and sub.update_counts[topic_name] == 1): |
| updates.append(self.make_topic_update(topic_name, num_updates=2, |
| key_template="old")) |
| |
| if (topic_name in args.topic_deltas and sub.update_counts[topic_name] == 2): |
| updates.append(self.make_topic_update(topic_name, num_updates=1, |
| key_template="new", clear_topic_entries=True)) |
| |
| if len(updates) > 0: |
| return TUpdateStateResponse(status=STATUS_OK, topic_updates=updates, |
| skipped=False) |
| |
| return DEFAULT_UPDATE_STATE_RESPONSE |
| |
| def check_entries(sub, args): |
| if (topic_name in args.topic_deltas and sub.update_counts[topic_name] == 1): |
| assert len(args.topic_deltas[topic_name].topic_entries) == 1 |
| assert args.topic_deltas[topic_name].topic_entries[0].key == "new0" |
| |
| return DEFAULT_UPDATE_STATE_RESPONSE |
| |
| reg = [TTopicRegistration(topic_name=topic_name, is_transient=False)] |
| with StatestoreSubscriber(update_cb=add_entries) as sub1: |
| ( |
| sub1.start() |
| .register(topics=reg) |
| .wait_for_update(topic_name, 1) |
| .kill() |
| .wait_for_failure() |
| .start() |
| .register(topics=reg) |
| .wait_for_update(topic_name, 2) |
| ) |
| |
| with StatestoreSubscriber(update_cb=check_entries) as sub2: |
| ( |
| sub2.start() |
| .register(topics=reg) |
| .wait_for_update(topic_name, 2) |
| ) |
| |
| def test_heartbeat_failure_reset(self): |
| """Regression test for IMPALA-6785: the heartbeat failure count for the subscriber ID |
| should be reset when it resubscribes, not after the first successful heartbeat. Delay |
| the heartbeat to force the topic update to finish first.""" |
| |
| with StatestoreSubscriber(heartbeat_cb=lambda sub, args: time.sleep(0.5)) as sub: |
| topic_name = "test_heartbeat_failure_reset" |
| reg = TTopicRegistration(topic_name=topic_name, is_transient=True) |
| sub.start() |
| sub.register(topics=[reg]) |
| LOG.info("Registered with id {0}".format(sub.subscriber_id)) |
| sub.wait_for_heartbeat(1) |
| sub.kill() |
| LOG.info("Killed, waiting for statestore to detect failure via heartbeats") |
| sub.wait_for_failure() |
| # IMPALA-6785 caused only one topic update to be send. Wait for multiple updates to |
| # be received to confirm that the subsequent updates are being scheduled repeatedly. |
| target_updates = sub.update_counts[topic_name] + 5 |
| sub.start() |
| sub.register(topics=[reg]) |
| LOG.info("Re-registered with id {0}, waiting for update".format(sub.subscriber_id)) |
| sub.wait_for_update(topic_name, target_updates) |
| |
| def test_min_subscriber_topic_version(self): |
| self._do_test_min_subscriber_topic_version(False) |
| |
| def test_min_subscriber_topic_version_with_straggler(self): |
| self._do_test_min_subscriber_topic_version(True) |
| |
| def _do_test_min_subscriber_topic_version(self, simulate_straggler): |
| """Implementation of test that the 'min_subscriber_topic_version' flag is correctly |
| set when requested. This tests runs two subscribers concurrently and tracks the |
| minimum version each has processed. If 'simulate_straggler' is true, one subscriber |
| rejects updates so that its version is not advanced.""" |
| topic_name = "test_min_subscriber_topic_version_%s" % uuid.uuid1() |
| |
| # This lock is held while processing the update to protect last_to_versions. |
| update_lock = threading.Lock() |
| last_to_versions = {} |
| TOTAL_SUBSCRIBERS = 2 |
| def callback(sub, args, is_producer, sub_name): |
| """Callback for subscriber to verify min_subscriber_topic_version behaviour. |
| If 'is_producer' is true, this acts as the producer, otherwise it acts as the |
| consumer. 'sub_name' is a name used to index into last_to_versions.""" |
| if topic_name not in args.topic_deltas: |
| # The update doesn't contain our topic. |
| pass |
| with update_lock: |
| LOG.info("{0} got update {1}".format(sub_name, |
| repr(args.topic_deltas[topic_name]))) |
| LOG.info("Versions: {0}".format(last_to_versions)) |
| to_version = args.topic_deltas[topic_name].to_version |
| from_version = args.topic_deltas[topic_name].from_version |
| min_subscriber_topic_version = \ |
| args.topic_deltas[topic_name].min_subscriber_topic_version |
| |
| if is_producer: |
| assert min_subscriber_topic_version is not None |
| assert (to_version == 0 and min_subscriber_topic_version == 0) or\ |
| min_subscriber_topic_version < to_version,\ |
| "'to_version' hasn't been created yet by this subscriber." |
| # Only validate version once all subscribers have processed an update. |
| if len(last_to_versions) == TOTAL_SUBSCRIBERS: |
| min_to_version = min(last_to_versions.values()) |
| assert min_subscriber_topic_version <= min_to_version,\ |
| "The minimum subscriber topic version seen by the producer cannot get " +\ |
| "ahead of the minimum version seem by the consumer, by definition." |
| assert min_subscriber_topic_version >= min_to_version - 2,\ |
| "The min topic version can be two behind the last version seen by " + \ |
| "this subscriber because the updates for both subscribers are " + \ |
| "prepared in parallel and because it's possible that the producer " + \ |
| "processes two updates in-between consumer updates. This is not " + \ |
| "absolute but depends on updates not being delayed a large amount." |
| else: |
| # Consumer did not request topic version. |
| assert min_subscriber_topic_version is None |
| |
| # Check the 'to_version' and update 'last_to_versions'. |
| last_to_version = last_to_versions.get(sub_name, 0) |
| if to_version > 0: |
| # Non-empty update. |
| assert from_version == last_to_version |
| # Stragglers should accept the first update then skip later ones. |
| skip_update = simulate_straggler and not is_producer and last_to_version > 0 |
| if not skip_update: last_to_versions[sub_name] = to_version |
| |
| if is_producer: |
| delta = self.make_topic_update(topic_name) |
| return TUpdateStateResponse(status=STATUS_OK, topic_updates=[delta], |
| skipped=False) |
| elif skip_update: |
| return TUpdateStateResponse(status=STATUS_OK, topic_updates=[], skipped=True) |
| else: |
| return DEFAULT_UPDATE_STATE_RESPONSE |
| |
| # Two concurrent subscribers, which pushes out updates and checks the minimum |
| # version, the other which just consumes the updates. |
| def producer_callback(sub, args): return callback(sub, args, True, "producer") |
| def consumer_callback(sub, args): return callback(sub, args, False, "consumer") |
| with StatestoreSubscriber(update_cb=consumer_callback) as consumer_sub: |
| with StatestoreSubscriber(update_cb=producer_callback) as producer_sub: |
| consumer_reg = TTopicRegistration(topic_name=topic_name, is_transient=True) |
| producer_reg = TTopicRegistration(topic_name=topic_name, is_transient=True, |
| populate_min_subscriber_topic_version=True) |
| NUM_UPDATES = 6 |
| ( |
| consumer_sub.start() |
| .register(topics=[consumer_reg]) |
| ) |
| ( |
| producer_sub.start() |
| .register(topics=[producer_reg]) |
| .wait_for_update(topic_name, NUM_UPDATES) |
| ) |
| consumer_sub.wait_for_update(topic_name, NUM_UPDATES) |
| |
| def test_transient_entry_removal_race(self): |
| """IMPALA-7306: transient entries were not deleted if the subscriber is unregistered |
| while it is in the middle of a callback. This test exercises that case by blocking |
| the update callback so that it is still running when the statestore unregisters the |
| subscriber for failed heartbeats. It also confirms that non-transient entries are not |
| removed.""" |
| transient_topic_name = "test_transient_entry_removal_race_transient" |
| non_transient_topic_name = "test_transient_entry_removal_race_non_transient" |
| topic_regs = [TTopicRegistration(topic_name=transient_topic_name, is_transient=True), |
| TTopicRegistration(topic_name=non_transient_topic_name, is_transient=False)] |
| # The heartbeat timeout is 3s, so sleep for long enough for it to expire |
| HEARTBEAT_DELAY = 10 |
| |
| def delayed_heartbeat(sub, args): |
| LOG.info("Heartbeat callback called") |
| time.sleep(HEARTBEAT_DELAY) |
| LOG.debug("Heartbeat callback about to return") |
| |
| def add_transient_entries_after_hb_failure(sub, args): |
| LOG.info("Update callback called") |
| # Add an additional delay so that this returns after the heartbeat. |
| time.sleep(WAIT_FOR_FAILURE_TIMEOUT) |
| updates = [self.make_topic_update(transient_topic_name, "k", "v"), |
| self.make_topic_update(non_transient_topic_name, "k", "v")] |
| LOG.debug("Update callback about to return") |
| return TUpdateStateResponse(status=STATUS_OK, topic_updates=updates, skipped=False) |
| |
| # Subscriber with delay creates a transient entry, which should not be added since |
| # the subscriber failed and was unregistered. |
| with StatestoreSubscriber(heartbeat_cb=delayed_heartbeat, |
| update_cb=add_transient_entries_after_hb_failure) as sub: |
| # Wait for the first update (which should happen after failure), then confirm |
| # that the failure occurred. |
| ( |
| sub.start() |
| .register(topics=topic_regs) |
| .wait_for_update(transient_topic_name, 1) |
| .wait_for_failure(timeout=WAIT_FOR_FAILURE_TIMEOUT) |
| ) |
| |
| def verify_transient_entry_removed(sub, args): |
| transient_delta = args.topic_deltas[transient_topic_name] |
| assert len(transient_delta.topic_entries) == 0, args |
| non_transient_delta = args.topic_deltas[non_transient_topic_name] |
| # Non-transient update should include topic that was not removed |
| assert len(non_transient_delta.topic_entries) == 1, args |
| entry = non_transient_delta.topic_entries[0] |
| assert entry.key == "k0" |
| assert entry.value == "v0" |
| assert not entry.deleted |
| # Skip updates so that statestore will re-send non-transient entries and the above |
| # assertions remain valid on subsequent callbacks. |
| return TUpdateStateResponse(status=STATUS_OK, topic_updates=[], skipped=True) |
| |
| # Verify that the transient entry for the failed subscriber is not present. |
| with StatestoreSubscriber(update_cb=verify_transient_entry_removed) as sub: |
| ( |
| sub.start() |
| .register(topics=topic_regs) |
| .wait_for_update(transient_topic_name, 1) |
| ) |