blob: 42f6dd696019cf0d83fdc765038451003ddb9b32 [file]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Detect Thermos tasks on disk
This module contains PathDetector implementations, which enumerate checkpoint roots, and the
TaskDetector, which detects Thermos tasks within a given checkpoint root.
"""
import functools
import glob
import os
import re
from abc import abstractmethod
from twitter.common.lang import Compatibility, Interface
from apache.thermos.common.constants import DEFAULT_CHECKPOINT_ROOT
from apache.thermos.common.path import TaskPath
class PathDetector(Interface):
  """Interface for objects that can report checkpoint roots.

  Implementations must provide get_paths().
  """

  @abstractmethod
  def get_paths(self):
    """Return the list of valid checkpoint roots known to this detector."""
class FixedPathDetector(PathDetector):
  """PathDetector that always reports the single path it was constructed with."""

  def __init__(self, path=DEFAULT_CHECKPOINT_ROOT):
    """
    :param path: the checkpoint root to report; must be a string.
    :raises TypeError: if path is not a string.
    """
    if isinstance(path, Compatibility.string):
      self._paths = [path]
    else:
      raise TypeError('FixedPathDetector path should be a string, got %s' % type(path))

  def get_paths(self):
    """Return a fresh one-element list so callers cannot mutate the detector's own list."""
    return list(self._paths)
class ChainedPathDetector(PathDetector):
  """PathDetector that merges the paths reported by several other detectors."""

  def __init__(self, *detectors):
    """
    :param detectors: the PathDetector instances to chain, in priority order.
    :raises TypeError: if any argument is not a PathDetector.
    """
    for detector in detectors:
      if not isinstance(detector, PathDetector):
        raise TypeError('Expected detector %r to be a PathDetector, got %s' % (
            detector, type(detector)))
    self._detectors = detectors

  def get_paths(self):
    """Return the de-duplicated paths of all chained detectors.

    Paths are returned in first-seen order (detector order, then each detector's own order)
    so that the result is deterministic rather than depending on set iteration order.
    """
    seen = set()
    paths = []
    for detector in self._detectors:
      for path in detector.get_paths():
        if path not in seen:
          seen.add(path)
          paths.append(path)
    return paths
def memoized(fn):
  """Decorate a method so that its results are cached per instance and per argument tuple.

  The cache is a dict stored on the instance under '__memoized_<method name>'. Only positional
  arguments are accepted by the wrapper, and they must be hashable since they form the key.
  """
  slot = '__memoized_' + fn.__name__

  @functools.wraps(fn)
  def memoized_fn(self, *args):
    # Create the per-instance result table the first time this method runs on the instance.
    if not hasattr(self, slot):
      setattr(self, slot, {})
    results = getattr(self, slot)
    try:
      value = results[args]
    except KeyError:
      value = results[args] = fn(self, *args)
    return value

  return memoized_fn
class TaskDetector(object):
  """
  Helper class in front of TaskPath to detect active/finished/running tasks. Performs no
  introspection on the state of a task; merely detects based on file paths on disk.

  Every lookup follows the same two steps: glob the filesystem with a wildcard-filled path
  template, then run each hit through a regex built from the same template to read back the
  wildcarded components. The (glob, regex) pairs are memoized per instance and argument tuple.
  """

  class Error(Exception): pass
  class MatchingError(Error): pass

  def __init__(self, root):
    """
    :param root: directory substituted as 'root' into every path template.
    """
    self._root_dir = root
    self._pathspec = TaskPath()

  @memoized
  def __get_task_ids_patterns(self, state):
    """Return (glob pattern, compiled regex) for the 'task_path' template.

    The glob is narrowed to `state` when one is given; the regex always captures both the
    state and task id components.
    """
    path_glob = self._pathspec.given(
        root=self._root_dir,
        task_id="*",
        state=state or '*'
    ).getpath('task_path')
    path_regex = self._pathspec.given(
        root=re.escape(self._root_dir),
        task_id=r'(\S+)',
        state=r'(\S+)'
    ).getpath('task_path')
    return path_glob, re.compile(path_regex)

  def get_task_ids(self, state=None):
    """Yield (task_state, task_id) tuples for the task paths found on disk.

    :param state: when not None, only tasks whose state component equals it are yielded.
    """
    path_glob, path_regex = self.__get_task_ids_patterns(state)
    for path in glob.glob(path_glob):
      match = path_regex.match(path)
      if match is None:
        # The glob is looser than the regex; skip anything the regex cannot parse.
        continue
      # NOTE(review): relies on the 'task_path' template placing state before task_id.
      task_state, task_id = match.groups()
      if state is None or task_state == state:
        yield (task_state, task_id)

  @memoized
  def __get_process_runs_patterns(self, task_id, log_dir):
    """Return (glob pattern, compiled regex) for the 'process_logdir' template.

    The regex captures the process name and the numeric run.
    """
    path_glob = self._pathspec.given(
        root=self._root_dir,
        task_id=task_id,
        log_dir=log_dir,
        process='*',
        run='*'
    ).getpath('process_logdir')
    path_regex = self._pathspec.given(
        root=re.escape(self._root_dir),
        task_id=re.escape(task_id),
        # Escaped like root and task_id: log_dir is a literal directory, not a regex fragment.
        log_dir=re.escape(log_dir),
        process=r'(\S+)',
        run=r'(\d+)'
    ).getpath('process_logdir')
    # Anchor at the end so the whole globbed path must match; otherwise a directory such as
    # '12abc' would be reported as run 12.
    return path_glob, re.compile(path_regex + '$')

  def get_process_runs(self, task_id, log_dir):
    """Yield (process, run) pairs found under log_dir, with run converted to an int.

    :param task_id: task id substituted into the path template.
    :param log_dir: log directory substituted into the path template.
    """
    path_glob, path_regex = self.__get_process_runs_patterns(task_id, log_dir)
    for path in glob.glob(path_glob):
      match = path_regex.match(path)
      if match is None:
        continue
      process, run = match.groups()
      yield process, int(run)

  def get_process_logs(self, task_id, log_dir):
    """Yield the paths of the stdout/stderr files that exist for each detected process run."""
    for process, run in self.get_process_runs(task_id, log_dir):
      for logtype in ('stdout', 'stderr'):
        path = (self._pathspec.with_filename(logtype).given(root=self._root_dir,
                                                           task_id=task_id,
                                                           log_dir=log_dir,
                                                           process=process,
                                                           run=run)
                                                    .getpath('process_logdir'))
        if os.path.exists(path):
          yield path

  def get_checkpoint(self, task_id):
    """Return the 'runner_checkpoint' path for task_id; the path is not checked for existence."""
    return self._pathspec.given(root=self._root_dir, task_id=task_id).getpath('runner_checkpoint')

  @memoized
  def __get_process_checkpoints_patterns(self, task_id):
    """Return (glob pattern, compiled regex) for the 'process_checkpoint' template."""
    path_glob = self._pathspec.given(
        root=self._root_dir,
        task_id=task_id,
        process='*'
    ).getpath('process_checkpoint')
    path_regex = self._pathspec.given(
        root=re.escape(self._root_dir),
        task_id=re.escape(task_id),
        process=r'(\S+)',
    ).getpath('process_checkpoint')
    return path_glob, re.compile(path_regex)

  def get_process_checkpoints(self, task_id):
    """Yield the process checkpoint paths found on disk for task_id."""
    path_glob, path_regex = self.__get_process_checkpoints_patterns(task_id)
    for path in glob.glob(path_glob):
      # The regex only acts as a filter here; the captured process name is not needed.
      if path_regex.match(path) is None:
        continue
      yield path