# blob: d420a96db84039eb51c824b3872e520a4dd78b96 [file] [log] [blame]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from threading import Event
from twitter.common.quantity import Amount, Time
from .task_util import StatusHelper
from gen.apache.aurora.api.constants import LIVE_STATES, TERMINAL_STATES
from gen.apache.aurora.api.ttypes import JobKey, TaskQuery
class JobMonitor(object):
    """Polls a scheduler for the task states of a single job.

    The monitor repeatedly queries task status through a StatusHelper and
    blocks (see wait_until) until every observed instance satisfies a
    caller-supplied predicate, the monitor is terminated, or the polling
    back-off reaches its configured maximum.
    """

    # Delay before the second poll; the delay doubles after every
    # unsuccessful poll.
    MIN_POLL_INTERVAL = Amount(2, Time.SECONDS)
    # Upper bound for the delay between two polls.
    MAX_POLL_INTERVAL = Amount(150, Time.SECONDS)

    @classmethod
    def running_or_finished(cls, status):
        """Return True if status is in LIVE_STATES or in TERMINAL_STATES."""
        return status in (LIVE_STATES | TERMINAL_STATES)

    @classmethod
    def terminal(cls, status):
        """Return True if status is in TERMINAL_STATES."""
        return status in TERMINAL_STATES

    def __init__(self, scheduler, job_key, terminating_event=None,
                 min_poll_interval=MIN_POLL_INTERVAL,
                 max_poll_interval=MAX_POLL_INTERVAL):
        """Create a monitor for one job.

        Arguments:
        scheduler -- scheduler handle, passed through to StatusHelper.
        job_key -- object exposing role, env and name attributes.
        terminating_event -- optional event object used to interrupt waiting;
                             a new threading.Event is created when omitted.
        min_poll_interval -- initial delay between polls (an Amount of Time).
        max_poll_interval -- largest delay between polls (an Amount of Time).
        """
        self._scheduler = scheduler
        self._job_key = job_key
        self._min_poll_interval = min_poll_interval
        self._max_poll_interval = max_poll_interval
        self._terminating = terminating_event or Event()
        # The helper receives create_query as a callback, presumably so it can
        # build the TaskQuery for whichever instances it is asked about.
        self._status_helper = StatusHelper(self._scheduler, self.create_query)

    def iter_tasks(self, instances):
        """Yield the tasks the status helper returns for the given instances.

        This is a generator: the status helper is only queried once iteration
        starts.
        """
        tasks = self._status_helper.get_tasks(instances)
        for task in tasks:
            yield task

    def states(self, instance_ids):
        """Return a dict mapping instance id to the status of its newest task.

        Several tasks may be reported for the same instance id. The one whose
        first task event carries the greatest timestamp wins; on equal
        timestamps the task seen first is kept.
        """
        # instance id -> (timestamp of the task's first event, task status)
        newest = {}
        for task in self.iter_tasks(instance_ids):
            status, instance_id = task.status, task.assignedTask.instanceId
            first_timestamp = task.taskEvents[0].timestamp
            if instance_id not in newest or first_timestamp > newest[instance_id][0]:
                newest[instance_id] = (first_timestamp, status)
        # Drop the timestamps; callers only need the status per instance.
        return dict((instance_id, entry[1]) for (instance_id, entry) in newest.items())

    def create_query(self, instances=None):
        """Build a TaskQuery for this monitor's job.

        Arguments:
        instances -- optional iterable of instance ids (anything int() accepts).
                     When falsy (None or empty), instanceIds is set to None,
                     which presumably means "no instance filter" -- confirm
                     against the scheduler API.
        """
        return TaskQuery(
            jobKeys=[JobKey(role=self._job_key.role,
                            environment=self._job_key.env,
                            name=self._job_key.name)],
            instanceIds=frozenset([int(s) for s in instances]) if instances else None)

    def wait_until(self, predicate, instances=None, with_timeout=False):
        """Block until all requested instances satisfy the predicate.

        The job state is polled with an exponential back-off: the delay starts
        at the minimum poll interval and doubles after every unsuccessful poll,
        never exceeding the maximum poll interval.

        Arguments:
        predicate -- callable taking a task status and returning a boolean.
        instances -- optional subset of job instances to wait for.
        with_timeout -- if set, give up as soon as the back-off delay has grown
                        to the maximum poll interval instead of polling forever.

        Returns: False if with_timeout is set and the back-off delay reached
        the maximum poll interval before the predicate was met. True otherwise,
        i.e. when the predicate holds for every reported instance (including
        the case where no instances are reported) or when the terminating
        event has been set.
        """
        poll_interval = self._min_poll_interval
        # The terminating event is checked first so that no status query is
        # issued once termination has been requested.
        while not self._terminating.is_set() and not all(
                predicate(state) for state in self.states(instances).values()):
            if with_timeout and poll_interval >= self._max_poll_interval:
                return False
            # Waiting on the event (rather than sleeping) lets terminate()
            # interrupt the delay immediately.
            self._terminating.wait(poll_interval.as_(Time.SECONDS))
            poll_interval = min(self._max_poll_interval, 2 * poll_interval)
        return True

    def terminate(self):
        """Requests immediate termination of the wait cycle."""
        self._terminating.set()