#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A job server submitting portable pipelines as uber jars to Spark."""
# pytype: skip-file
from __future__ import absolute_import
from __future__ import print_function
import itertools
import logging
import os
import tempfile
import time
import urllib
import zipfile
import requests
from apache_beam.options import pipeline_options
from apache_beam.portability.api import beam_job_api_pb2
from apache_beam.runners.portability import abstract_job_service
from apache_beam.runners.portability import job_server
from apache_beam.utils.timestamp import Timestamp
_LOGGER = logging.getLogger(__name__)
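# A minimal construction sketch (hypothetical host and jar path; assumes a
# Spark master whose REST submission gateway is reachable, typically on
# port 6066):
#
#   options = pipeline_options.PipelineOptions(
#       ['--spark_job_server_jar=/path/to/beam-runners-spark-job-server.jar'])
#   server = SparkUberJarJobServer('http://spark-master:6066', options)
#   server.start()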
class SparkUberJarJobServer(abstract_job_service.AbstractJobServiceServicer):
  """A Job server which submits a self-contained Jar to a Spark cluster.

  The jar contains the Beam pipeline definition, dependencies, and
  the pipeline artifacts.
  """
  def __init__(self, rest_url, options):
    super(SparkUberJarJobServer, self).__init__()
    self._rest_url = rest_url
    self._executable_jar = (
        options.view_as(
            pipeline_options.SparkRunnerOptions).spark_job_server_jar)
    self._artifact_port = (
        options.view_as(pipeline_options.JobServerOptions).artifact_port)
    self._temp_dir = tempfile.mkdtemp(prefix='apache-beam-spark')

  def start(self):
    return self

  def stop(self):
    pass

  def executable_jar(self):
    if self._executable_jar:
      if not os.path.exists(self._executable_jar):
        parsed = urllib.parse.urlparse(self._executable_jar)
        if not parsed.scheme:
          raise ValueError(
              'Unable to parse jar URL "%s". If using a full URL, make sure '
              'the scheme is specified. If using a local file path, make sure '
              'the file exists; you may have to first build the job server '
              'using `./gradlew runners:spark:job-server:shadowJar`.' %
              self._executable_jar)
      url = self._executable_jar
    else:
      url = job_server.JavaJarJobServer.path_to_beam_jar(
          'runners:spark:job-server:shadowJar')
    return job_server.JavaJarJobServer.local_jar(url)

  def create_beam_job(self, job_id, job_name, pipeline, options):
    return SparkBeamJob(
        self._rest_url,
        self.executable_jar(),
        job_id,
        job_name,
        pipeline,
        options,
        artifact_port=self._artifact_port)


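# All Spark interactions below go through the standalone master's REST
# submission gateway via the endpoints used in this class:
# POST v1/submissions/create, GET v1/submissions/status/<submission-id>,
# and POST v1/submissions/kill/<submission-id>.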
class SparkBeamJob(abstract_job_service.UberJarBeamJob):
  """Runs a single Beam job on Spark by staging all contents into a Jar
  and uploading it via the Spark Rest API.

  Note that the Spark Rest API is not enabled by default. It must be enabled by
  setting the configuration property spark.master.rest.enabled to true."""
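  # In spark-defaults.conf (or via --conf on the master) that typically looks
  # like:
  #   spark.master.rest.enabled true
  # after which the master serves the REST gateway, by default on port 6066.
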
  def __init__(
      self,
      rest_url,
      executable_jar,
      job_id,
      job_name,
      pipeline,
      options,
      artifact_port=0):
    super(SparkBeamJob, self).__init__(
        executable_jar,
        job_id,
        job_name,
        pipeline,
        options,
        artifact_port=artifact_port)
    self._rest_url = rest_url
    # Message history is a superset of state history.
    self._message_history = self._state_history[:]

  def request(self, method, path, expected_status=200, **kwargs):
    url = '%s/%s' % (self._rest_url, path)
    response = method(url, **kwargs)
    if response.status_code != expected_status:
      raise RuntimeError(
          "Request to %s failed with status %d: %s" %
          (url, response.status_code, response.text))
    if response.text:
      return response.json()

  def get(self, path, **kwargs):
    return self.request(requests.get, path, **kwargs)

  def post(self, path, **kwargs):
    return self.request(requests.post, path, **kwargs)

  def delete(self, path, **kwargs):
    return self.request(requests.delete, path, **kwargs)

  def _get_server_spark_version(self):
    # Spark REST API doesn't seem to offer a dedicated endpoint for getting the
    # version, but it does include the version in all responses, even errors.
    return self.get('', expected_status=400)['serverSparkVersion']

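  # spark-version-info.properties is a small key=value file that Spark's build
  # embeds in its jars; the line of interest looks roughly like
  #   version=3.1.2
  # (the version value here is illustrative).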
  def _get_client_spark_version_from_properties(self, jar):
    """Parse Spark version from spark-version-info.properties file in the jar.

    https://github.com/apache/spark/blob/dddfeca175bdce5294debe00d4a993daef92ca60/build/spark-build-info#L30
    """
    with zipfile.ZipFile(jar, 'a', compression=zipfile.ZIP_DEFLATED) as z:
      with z.open('spark-version-info.properties') as fin:
        for line in fin.read().decode('utf-8').splitlines():
          split = list(map(lambda s: s.strip(), line.split('=')))
          if len(split) == 2 and split[0] == 'version' and split[1] != '':
            return split[1]
        raise ValueError(
            'Property "version" not found in spark-version-info.properties.')

  def _get_client_spark_version(self, jar):
    try:
      return self._get_client_spark_version_from_properties(jar)
    except Exception as e:
      _LOGGER.debug(e)
      server_version = self._get_server_spark_version()
      _LOGGER.warning(
          'Unable to parse Spark version from '
          'spark-version-info.properties. Defaulting to %s' % server_version)
      return server_version

  def _create_submission_request(self, jar, job_name):
    jar_url = "file:%s" % jar
    return {
        "action": "CreateSubmissionRequest",
        "appArgs": [],
        "appResource": jar_url,
        "clientSparkVersion": self._get_client_spark_version(jar),
        "environmentVariables": {},
        "mainClass": "org.apache.beam.runners.spark.SparkPipelineRunner",
        "sparkProperties": {
            "spark.jars": jar_url,
            "spark.app.name": job_name,
            "spark.submit.deployMode": "cluster",
        }
    }

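  # The create endpoint replies with a JSON body that includes the new
  # submission's id (plus ancillary fields we ignore); run() keeps only
  # 'submissionId' so the job can be polled and cancelled later.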
  def run(self):
    self._stop_artifact_service()
    # Upload the jar and start the job.
    self._spark_submission_id = self.post(
        'v1/submissions/create',
        json=self._create_submission_request(
            self._jar, self._job_name))['submissionId']
    _LOGGER.info('Submitted Spark job with ID %s' % self._spark_submission_id)

  def cancel(self):
    self.post('v1/submissions/kill/%s' % self._spark_submission_id)

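  # Spark reports the driver's lifecycle as one of its DriverState values; map
  # each onto the closest Beam JobState, falling back to UNSPECIFIED for
  # anything unrecognized.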
  @staticmethod
  def _get_beam_state(spark_response):
    return {
        'SUBMITTED': beam_job_api_pb2.JobState.STARTING,
        'RUNNING': beam_job_api_pb2.JobState.RUNNING,
        'FINISHED': beam_job_api_pb2.JobState.DONE,
        'RELAUNCHING': beam_job_api_pb2.JobState.RUNNING,
        'UNKNOWN': beam_job_api_pb2.JobState.UNSPECIFIED,
        'KILLED': beam_job_api_pb2.JobState.CANCELLED,
        'FAILED': beam_job_api_pb2.JobState.FAILED,
        'ERROR': beam_job_api_pb2.JobState.FAILED,
    }.get(spark_response['driverState'], beam_job_api_pb2.JobState.UNSPECIFIED)

  def _get_spark_status(self):
    return self.get('v1/submissions/status/%s' % self._spark_submission_id)

  def get_state(self):
    response = self._get_spark_status()
    state = self._get_beam_state(response)
    timestamp = self.set_state(state)
    if timestamp is None:
      # State has not changed since last check. Use previous timestamp.
      return super(SparkBeamJob, self).get_state()
    else:
      return state, timestamp

  def _with_message_history(self, message_stream):
    return itertools.chain(self._message_history[:], message_stream)

  def _get_message_iter(self):
    """Returns an iterator of messages from the Spark server.

    Note that while message history is de-duped, this function's returned
    iterator may contain duplicate values."""
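    # Poll the status endpoint with a gentle backoff: start at one second and
    # stretch the interval by 20% per iteration, capping at 60 seconds.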
    sleep_secs = 1.0
    message_ix = 0
    while True:
      response = self._get_spark_status()
      state = self._get_beam_state(response)
      timestamp = Timestamp.now()
      message = None
      if 'message' in response:
        importance = (
            beam_job_api_pb2.JobMessage.MessageImportance.JOB_MESSAGE_ERROR
            if state == beam_job_api_pb2.JobState.FAILED else
            beam_job_api_pb2.JobMessage.MessageImportance.JOB_MESSAGE_BASIC)
        message = beam_job_api_pb2.JobMessage(
            message_id='message%d' % message_ix,
            time=str(int(timestamp)),
            importance=importance,
            message_text=response['message'])
        yield message
        message_ix += 1
      # TODO(BEAM-8983) In the event of a failure, query
      # additional info from Spark master and/or workers.
      check_timestamp = self.set_state(state)
      if check_timestamp is not None:
        if message:
          self._message_history.append(message)
        self._message_history.append((state, check_timestamp))
      yield state, timestamp
      sleep_secs = min(60, sleep_secs * 1.2)
      time.sleep(sleep_secs)

  def get_state_stream(self):
    for msg in self._with_message_history(self._get_message_iter()):
      if isinstance(msg, tuple):
        state, timestamp = msg
        yield state, timestamp
        if self.is_terminal_state(state):
          break

  def get_message_stream(self):
    for msg in self._with_message_history(self._get_message_iter()):
      yield msg
      if isinstance(msg, tuple):
        state, _ = msg
        if self.is_terminal_state(state):
          break