#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import cloudpickle
import sys
import threading
import traceback
from concurrent.futures import Future
from threading import Timer
# Possible job states.
PENDING = 'PENDING'
RUNNING = 'RUNNING'
CANCELLED = 'CANCELLED'
FINISHED = 'FINISHED'
SENT = 'SENT'
QUEUED = 'QUEUED'
class JobHandle(Future):
"""A child class of concurrent.futures.Future. Allows for monitoring and
controlling of the running remote job
"""
# Poll intervals in seconds
_JOB_INITIAL_POLL_INTERVAL = 0.1
_JOB_MAX_POLL_INTERVAL = 5
def __init__(self, conn, session_id, executor):
Future.__init__(self)
self._conn = conn
self._session_id = session_id
self._executor = executor
self._cancelled = False
self._job_id = -1
self._done = False
self._job_handle_condition = threading.Condition()
self._running_callbacks = []
self._queued_callbacks = []
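    # Submission flow: _start hands the serialized job to the executor;
    # _send_job_task POSTs it to the session endpoint, records the job id
    # assigned by the server, and then starts _poll_result, which polls the
    # job status until it succeeds, fails, or is cancelled.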
def _start(self, command, serialized_job):
self._executor.submit(self._send_job_task, command, serialized_job)
def _send_job_task(self, command, job):
suffix_url = "/" + str(self._session_id) + "/" + command
job_status = self._conn.send_request('POST', suffix_url,
headers=self._conn._JSON_HEADERS, data=job)
self._job_id = job_status.json()['id']
self._poll_result()
def queued(self):
"""
Returns
-------
True if the job is currently queued.
"""
with self._job_handle_condition:
return self._state == QUEUED
    def _invoke_queued_callbacks(self):
        # Iterate over a snapshot of the list: removing entries from the list
        # that is being iterated would silently skip callbacks.
        for callback in list(self._queued_callbacks):
            try:
                callback(self)
            except Exception:
                traceback.print_exc()
            self._queued_callbacks.remove(callback)
    def _invoke_running_callbacks(self):
        # Same snapshot pattern as _invoke_queued_callbacks.
        for callback in list(self._running_callbacks):
            try:
                callback(self)
            except Exception:
                traceback.print_exc()
            self._running_callbacks.remove(callback)
def add_queued_callback(self, fn):
"""
Attaches a callable that will be called when the job is queued.
Parameters
----------
fn : Function
A callable that will be called with this future as its only
argument when the job is queued.
"""
with self._job_handle_condition:
if self._state == PENDING:
self._queued_callbacks.append(fn)
elif self._state == QUEUED:
fn(self)
def add_running_callback(self, fn):
"""
Attaches a callable that will be called when the job is running.
Parameters
----------
fn : Function
A callable that will be called with this future as its only
argument when the job is running.
"""
with self._job_handle_condition:
if self._state in [PENDING, QUEUED]:
self._running_callbacks.append(fn)
elif self._state == RUNNING:
fn(self)
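    # Example of registering state callbacks (illustrative only; `handle` is
    # assumed to be a JobHandle obtained from a client that submitted the
    # job):
    #
    #     def on_queued(h):
    #         print('job has been queued')
    #
    #     handle.add_queued_callback(on_queued)
    #     handle.add_running_callback(lambda h: print('job is now running'))
    #
    # A callback added while the job is already in the corresponding state is
    # invoked immediately; callbacks added after the job has finished are not
    # invoked.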
def _update_state(self, state):
with self._job_handle_condition:
if state == 'STARTED':
self._state = RUNNING
self._invoke_running_callbacks()
elif state == 'QUEUED':
self._state = QUEUED
self._invoke_queued_callbacks()
elif state == 'CANCELLED':
self._state = CANCELLED
self._invoke_callbacks()
else:
                raise RuntimeError('Future in unexpected state: ' + state)
            self._job_handle_condition.notify_all()
def cancel(self):
"""
Cancel the job if possible.
Returns
-------
True if the job was cancelled, False if the job has already finished.
"""
with self._job_handle_condition:
if self._state == FINISHED:
return False
if self._state == CANCELLED:
return True
if self._job_id > -1:
self._executor.submit(self._send_cancel_request)
self._job_handle_condition.notify_all()
return True
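    # Illustrative cancellation pattern (`handle` is again assumed to come
    # from a client submission):
    #
    #     if not handle.done():
    #         handle.cancel()  # False only if the job already finished
    #
    # Note that cancel() only schedules the HTTP cancel request on the
    # executor; the job is not guaranteed to stop until the server reports
    # the CANCELLED state back through the polling loop.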
def _send_cancel_request(self):
try:
end_point = "/" + str(self._session_id) + "/jobs/" + \
str(self._job_id) + "/cancel"
self._conn.send_json(None, end_point)
except Exception as err:
self.set_job_exception(err, traceback.format_exc())
traceback.print_exc()
def _poll_result(self):
def do_poll_result():
try:
suffix_url = "/" + str(self._session_id) + "/jobs/" + \
str(self._job_id)
job_status = self._conn.send_request('GET', suffix_url,
headers={'Accept': 'application/json'}).json()
job_state = job_status['state']
job_result = None
has_finished = False
job_error = None
if job_state == 'SUCCEEDED':
job_result = job_status['result']
has_finished = True
elif job_state == 'FAILED':
job_error = job_status['error']
has_finished = True
elif job_state == 'CANCELLED':
repeated_timer.stop()
else:
pass
if has_finished:
if job_result is not None:
b64_decoded = base64.b64decode(job_result)
b64_decoded_decoded = base64.b64decode(b64_decoded)
deserialized_object = cloudpickle.loads(
b64_decoded_decoded)
super(JobHandle, self).set_result(deserialized_object)
if job_error is not None:
self.set_job_exception(Exception(job_error))
repeated_timer.stop()
else:
self._update_state(job_state)
except Exception as err:
repeated_timer.stop()
traceback.print_exc()
self.set_job_exception(err, traceback.format_exc())
repeated_timer = self._RepeatedTimer(self._JOB_INITIAL_POLL_INTERVAL,
do_poll_result, self._executor)
repeated_timer.start()
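    # The decoding in do_poll_result implies the server returns a successful
    # result as base64(base64(cloudpickle bytes)). A sketch of the matching
    # encode step (an assumption about the producing side, not something
    # defined in this module):
    #
    #     payload = base64.b64encode(
    #         base64.b64encode(cloudpickle.dumps(value)))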
def set_running_or_notify_cancel(self):
raise NotImplementedError("This operation is not supported.")
def set_result(self, result):
raise NotImplementedError("This operation is not supported.")
def set_exception_info(self, exception, traceback):
raise NotImplementedError("This operation is not supported.")
def set_exception(self, exception):
raise NotImplementedError("This operation is not supported.")
def set_job_exception(self, exception, error_msg=None):
        if sys.version_info >= (3,):
super(JobHandle, self).set_exception(exception)
else:
super(JobHandle, self).set_exception_info(exception, error_msg)
class _RepeatedTimer(object):
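        # Helper that repeatedly schedules `polling_job` on the executor with
        # an exponentially growing delay: starting from the initial interval,
        # the delay doubles after every run (0.1s, 0.2s, 0.4s, ...) and is
        # capped at JobHandle._JOB_MAX_POLL_INTERVAL (5 seconds). stop()
        # cancels the pending timer and prevents any further scheduling.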
def __init__(self, interval, polling_job, executor):
self._timer = None
self.polling_job = polling_job
self.interval = interval
self.is_running = False
self.stop_called = False
self.executor = executor
def _run(self):
self.is_running = False
self.executor.submit(self.polling_job)
self.start()
def start(self):
if not self.is_running and not self.stop_called:
self._timer = Timer(self.interval, self._run)
self._timer.start()
self.interval = min(self.interval * 2,
JobHandle._JOB_MAX_POLL_INTERVAL)
self.is_running = True
def stop(self):
self._timer.cancel()
self.stop_called = True
self.is_running = False