#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import print_function
import ast
from collections import OrderedDict
import datetime
import decimal
import io
import json
import logging
import sys
import traceback
import base64
import os
import re
import threading
import tempfile
import shutil
import pickle
import textwrap
if sys.version >= '3':
unicode = str
else:
import cStringIO
import StringIO
if sys.version_info >= (3, 8):
    from ast import Module
else:
    # Mock the new two-argument API on older interpreters, ignoring the
    # second (type_ignores) argument.
    # See https://github.com/ipython/ipython/issues/11590
    from ast import Module as OriginalModule
    Module = lambda nodelist, type_ignores: OriginalModule(nodelist)
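# A minimal sketch of why the shim above exists: on Python 3.8+ a synthesized
# AST must be wrapped with the two-argument Module(body, type_ignores) before
# it can be compiled, e.g. (assuming `node` is a parsed statement):
#
#   mod = Module([node], [])
#   code = compile(mod, '<stdin>', 'exec')
#
# Older interpreters only accept Module(body), so the lambda adapts the old
# constructor to the new call shape used in NormalNode.execute below.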
logging.basicConfig()
LOG = logging.getLogger('fake_shell')
global_dict = {}
job_context = None
local_tmp_dir_path = None
TOP_FRAME_REGEX = re.compile(r'\s*File "<stdin>".*in <module>')
def execute_reply(status, content):
return {
'msg_type': 'execute_reply',
'content': dict(
content,
status=status,
)
}
def execute_reply_ok(data):
return execute_reply('ok', {
'data': data,
})
def execute_reply_error(exc_type, exc_value, tb):
LOG.error('execute_reply', exc_info=True)
if sys.version >= '3':
formatted_tb = traceback.format_exception(exc_type, exc_value, tb, chain=False)
else:
formatted_tb = traceback.format_exception(exc_type, exc_value, tb)
for i in range(len(formatted_tb)):
if TOP_FRAME_REGEX.match(formatted_tb[i]):
formatted_tb = formatted_tb[:1] + formatted_tb[i + 1:]
break
return execute_reply('error', {
'ename': unicode(exc_type.__name__),
'evalue': unicode(exc_value),
'traceback': formatted_tb,
})
def execute_reply_internal_error(message, exc_info=None):
LOG.error('execute_reply_internal_error', exc_info=exc_info)
return execute_reply('error', {
'ename': 'InternalError',
'evalue': message,
'traceback': [],
})
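# For reference, a successful reply assembled by the helpers above has this
# shape (values here are illustrative):
#
#   {'msg_type': 'execute_reply',
#    'content': {'status': 'ok', 'data': {'text/plain': '2'}}}
#
# while an error reply carries 'ename', 'evalue' and 'traceback' in place of
# 'data'.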
class JobContextImpl(object):
def __init__(self):
self.lock = threading.Lock()
self.sc = global_dict['sc']
self.sql_ctx = global_dict['sqlContext']
self.hive_ctx = None
self.streaming_ctx = None
self.local_tmp_dir_path = local_tmp_dir_path
self.spark_session = global_dict.get('spark')
self.shared_variables = OrderedDict()
self.max_var_size = 100
def sc(self):
return self.sc
def sql_ctx(self):
return self.sql_ctx
def hive_ctx(self):
if self.hive_ctx is None:
with self.lock:
if self.hive_ctx is None:
if isinstance(self.sql_ctx, global_dict['HiveContext']):
self.hive_ctx = self.sql_ctx
else:
self.hive_ctx = global_dict['HiveContext'](self.sc)
return self.hive_ctx
def create_streaming_ctx(self, batch_duration):
with self.lock:
if self.streaming_ctx is not None:
raise ValueError("Streaming context already exists")
self.streaming_ctx = global_dict['StreamingContext'](self.sc, batch_duration)
def streaming_ctx(self):
with self.lock:
if self.streaming_ctx is None:
raise ValueError("create_streaming_ctx function should be called first")
return self.streaming_ctx
def stop_streaming_ctx(self):
with self.lock:
if self.streaming_ctx is None:
raise ValueError("Cannot stop streaming context. Streaming context is None")
self.streaming_ctx.stop()
self.streaming_ctx = None
def get_local_tmp_dir_path(self):
return self.local_tmp_dir_path
def stop(self):
with self.lock:
if self.streaming_ctx is not None:
self.stop_streaming_ctx()
if self.sc is not None:
self.sc.stop()
def spark_session(self):
return self.spark_session
def get_shared_object(self, name):
with self.lock:
try:
var = self.shared_variables[name]
del self.shared_variables[name]
self.shared_variables[name] = var
except:
var = None
return var
def set_shared_object(self, name, object):
with self.lock:
self.shared_variables[name] = object
            while len(self.shared_variables) > self.max_var_size:
                self.shared_variables.popitem(last=False)
def remove_shared_object(self, name):
with self.lock:
try:
del self.shared_variables[name]
except:
pass
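# Note on the methods above: shared_variables doubles as a small LRU cache.
# get_shared_object deletes and re-inserts the entry it reads, moving it to
# the most-recently-used end of the OrderedDict, while set_shared_object
# evicts from the least-recently-used end (popitem(last=False)) once more
# than max_var_size entries accumulate.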
class PySparkJobProcessorImpl(object):
def processBypassJob(self, serialized_job):
try:
if sys.version >= '3':
deserialized_job = pickle.loads(serialized_job, encoding="bytes")
else:
deserialized_job = pickle.loads(serialized_job)
result = deserialized_job(job_context)
serialized_result = global_dict['cloudpickle'].dumps(result)
response = bytearray(base64.b64encode(serialized_result))
except:
            response = bytearray('Client job error: ' + traceback.format_exc(), 'utf-8')
return response
def addFile(self, uri_path):
job_context.sc.addFile(uri_path)
def addPyFile(self, uri_path):
job_context.sc.addPyFile(uri_path)
def getLocalTmpDirPath(self):
return os.path.join(job_context.get_local_tmp_dir_path(), '__livy__')
class Scala:
extends = ['org.apache.livy.repl.PySparkJobProcessor']
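# The nested marker class above records the JVM interface this callback
# object is meant to satisfy, mirroring py4j's `class Java: implements = [...]`
# convention for callback proxies; the instance itself is handed to the JVM
# by registering it under ENTRY_POINT_OBJECT_ID in main() below.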
class ExecutionError(Exception):
def __init__(self, exc_info):
self.exc_info = exc_info
class NormalNode(object):
def __init__(self, code):
self.code = compile(code, '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1)
def execute(self):
to_run_exec, to_run_single = self.code.body[:-1], self.code.body[-1:]
try:
for node in to_run_exec:
mod = Module([node], [])
code = compile(mod, '<stdin>', 'exec')
exec(code, global_dict)
for node in to_run_single:
mod = ast.Interactive([node])
code = compile(mod, '<stdin>', 'single')
exec(code, global_dict)
except:
# We don't need to log the exception because we're just executing user
# code and passing the error along.
raise ExecutionError(sys.exc_info())
class UnknownMagic(Exception):
pass
class MagicNode(object):
def __init__(self, line):
parts = line[1:].split(' ', 1)
if len(parts) == 1:
self.magic, self.rest = parts[0], ()
else:
self.magic, self.rest = parts[0], (parts[1],)
def execute(self):
if not self.magic:
raise UnknownMagic('magic command not specified')
try:
handler = magic_router[self.magic]
except KeyError:
raise UnknownMagic("unknown magic command '%s'" % self.magic)
return handler(*self.rest)
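# A quick illustration of the parsing above (not executed here):
#
#   MagicNode('%table df').magic == 'table'    # .rest == ('df',)
#   MagicNode('%json x').execute()             # dispatches to magic_json('x')
#   MagicNode('%nosuch').execute()             # raises UnknownMagic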
def parse_code_into_nodes(code):
nodes = []
try:
nodes.append(NormalNode(code))
except SyntaxError:
        # It's possible we hit a syntax error because of a magic command. Split
        # the code into chunks of normal code and possibly-magic code (lines
        # that start with a '%'), then try to parse each chunk on its own.
normal = []
chunks = []
for i, line in enumerate(code.rstrip().split('\n')):
if line.startswith('%'):
if normal:
chunks.append('\n'.join(normal))
normal = []
chunks.append(line)
else:
normal.append(line)
if normal:
chunks.append('\n'.join(normal))
# Convert the chunks into AST nodes. Let exceptions propagate.
for chunk in chunks:
if chunk.startswith('%'):
nodes.append(MagicNode(chunk))
else:
nodes.append(NormalNode(chunk))
return nodes
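# Example of the chunking above (illustrative): the input
#
#   x = 1
#   %json x
#
# fails to compile as a whole, so the SyntaxError path splits it into the
# chunks ['x = 1', '%json x'] and returns [NormalNode, MagicNode] instances.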
def execute_request(content):
try:
code = content['code']
except KeyError:
return execute_reply_internal_error(
'Malformed message: content object missing "code"', sys.exc_info()
)
try:
nodes = parse_code_into_nodes(code)
except SyntaxError:
exc_type, exc_value, tb = sys.exc_info()
return execute_reply_error(exc_type, exc_value, None)
result = None
try:
for node in nodes:
result = node.execute()
except UnknownMagic:
exc_type, exc_value, tb = sys.exc_info()
return execute_reply_error(exc_type, exc_value, None)
except ExecutionError as e:
return execute_reply_error(*e.exc_info)
if result is None:
result = {}
stdout = sys.stdout.getvalue()
stderr = sys.stderr.getvalue()
clearOutputs()
output = result.pop('text/plain', '')
if stdout:
output += stdout
if stderr:
output += stderr
output = output.rstrip()
# Only add the output if it exists, or if there are no other mimetypes in the result.
if output or not result:
        result['text/plain'] = output
return execute_reply_ok(result)
def magic_table_convert(value):
try:
converter = magic_table_types[type(value)]
except KeyError:
converter = magic_table_types[str]
return converter(value)
def magic_table_convert_seq(items):
last_item_type = None
converted_items = []
for item in items:
item_type, item = magic_table_convert(item)
if last_item_type is None:
last_item_type = item_type
elif last_item_type != item_type:
raise ValueError('value has inconsistent types')
converted_items.append(item)
return 'ARRAY_TYPE', converted_items
def magic_table_convert_map(m):
last_key_type = None
last_value_type = None
converted_items = {}
    for key, value in m.items():
        key_type, key = magic_table_convert(key)
        value_type, value = magic_table_convert(value)
        if last_key_type is None:
            last_key_type = key_type
        elif last_key_type != key_type:
            raise ValueError('value has inconsistent types')
if last_value_type is None:
last_value_type = value_type
elif last_value_type != value_type:
raise ValueError('value has inconsistent types')
converted_items[key] = value
return 'MAP_TYPE', converted_items
magic_table_types = {
type(None): lambda x: ('NULL_TYPE', x),
bool: lambda x: ('BOOLEAN_TYPE', x),
int: lambda x: ('INT_TYPE', x),
float: lambda x: ('DOUBLE_TYPE', x),
str: lambda x: ('STRING_TYPE', str(x)),
datetime.date: lambda x: ('DATE_TYPE', str(x)),
datetime.datetime: lambda x: ('TIMESTAMP_TYPE', str(x)),
decimal.Decimal: lambda x: ('DECIMAL_TYPE', str(x)),
tuple: magic_table_convert_seq,
list: magic_table_convert_seq,
dict: magic_table_convert_map,
}
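# Illustrative conversions under the mapping above:
#
#   magic_table_convert(3)                          -> ('INT_TYPE', 3)
#   magic_table_convert(datetime.date(2020, 1, 1))  -> ('DATE_TYPE', '2020-01-01')
#   magic_table_convert([1, 2])                     -> ('ARRAY_TYPE', [1, 2])
#
# Unknown types fall back to the str converter, i.e. ('STRING_TYPE', str(x)).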
# python 2.x only
if sys.version < '3':
magic_table_types.update({
long: lambda x: ('BIGINT_TYPE', x),
unicode: lambda x: ('STRING_TYPE', x.encode('utf-8'))
})
def magic_table(name):
try:
value = global_dict[name]
except KeyError:
exc_type, exc_value, tb = sys.exc_info()
return execute_reply_error(exc_type, exc_value, None)
if not isinstance(value, (list, tuple)):
value = [value]
headers = {}
data = []
for row in value:
cols = []
data.append(cols)
        if row.__class__.__name__ == 'Row':
row = row.asDict()
if not isinstance(row, (list, tuple, dict)):
row = [row]
if isinstance(row, (list, tuple)):
iterator = enumerate(row)
else:
iterator = sorted(row.items())
for name, col in iterator:
col_type, col = magic_table_convert(col)
try:
header = headers[name]
except KeyError:
header = {
'name': str(name),
'type': col_type,
}
headers[name] = header
else:
# Reject columns that have a different type. (allow none value)
if col_type != "NULL_TYPE" and header['type'] != col_type:
if header['type'] == "NULL_TYPE":
header['type'] = col_type
else:
exc_type = Exception
exc_value = 'table rows have different types'
return execute_reply_error(exc_type, exc_value, None)
cols.append(col)
headers = [v for k, v in sorted(headers.items())]
return {
'application/vnd.livy.table.v1+json': {
'headers': headers,
'data': data,
}
}
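# For example, with `df = [{'a': 1}, {'a': 2}]` in the user namespace,
# `%table df` would produce roughly:
#
#   {'application/vnd.livy.table.v1+json': {
#       'headers': [{'name': 'a', 'type': 'INT_TYPE'}],
#       'data': [[1], [2]]}}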
def magic_json(name):
try:
value = global_dict[name]
except KeyError:
exc_type, exc_value, tb = sys.exc_info()
return execute_reply_error(exc_type, exc_value, None)
return {
'application/json': value,
}
def magic_matplot(name):
try:
value = global_dict[name]
fig = value.gcf()
imgdata = io.BytesIO()
fig.savefig(imgdata, format='png')
imgdata.seek(0)
encode = base64.b64encode(imgdata.getvalue())
if sys.version >= '3':
encode = encode.decode()
except:
exc_type, exc_value, tb = sys.exc_info()
return execute_reply_error(exc_type, exc_value, None)
return {
'image/png': encode,
'text/plain': "",
}
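# Typical use (illustrative): after the user draws a figure with
# `import matplotlib.pyplot as plt; plt.plot(...)`, the magic `%matplot plt`
# looks up `plt` in global_dict, grabs the current figure via gcf(), and
# returns it base64-encoded under the 'image/png' mimetype.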
def shutdown_request(_content):
sys.exit()
magic_router = {
'table': magic_table,
'json': magic_json,
'matplot': magic_matplot,
}
msg_type_router = {
'execute_request': execute_request,
'shutdown_request': shutdown_request,
}
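# The shell speaks a line-delimited JSON protocol over stdin/stdout. An
# incoming request dispatched through the router above looks like (one line):
#
#   {"msg_type": "execute_request", "content": {"code": "1 + 1"}}
#
# and the matching reply is printed back as a single JSON line (see main()).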
class UnicodeDecodingStringIO(io.StringIO):
def write(self, s):
if isinstance(s, bytes):
s = s.decode("utf-8")
super(UnicodeDecodingStringIO, self).write(s)
def clearOutputs():
sys.stdout.close()
sys.stderr.close()
sys.stdout = UnicodeDecodingStringIO()
sys.stderr = UnicodeDecodingStringIO()
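# sys.stdout and sys.stderr are swapped for in-memory buffers in main() so
# that anything user code prints can be harvested into the 'text/plain'
# field of a reply (see execute_request); clearOutputs() installs fresh
# buffers once the captured text has been read.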
def main():
sys_stdin = sys.stdin
sys_stdout = sys.stdout
sys_stderr = sys.stderr
if sys.version >= '3':
sys.stdin = io.StringIO()
else:
sys.stdin = cStringIO.StringIO()
sys.stdout = UnicodeDecodingStringIO()
sys.stderr = UnicodeDecodingStringIO()
spark_major_version = os.getenv("LIVY_SPARK_MAJOR_VERSION")
try:
listening_port = 0
if os.environ.get("LIVY_TEST") != "true":
            # Load Spark into the context
exec('from pyspark.sql import HiveContext', global_dict)
exec('from pyspark.streaming import StreamingContext', global_dict)
exec('import pyspark.cloudpickle as cloudpickle', global_dict)
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SQLContext, HiveContext, Row
# Connect to the gateway
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
try:
from py4j.java_gateway import GatewayParameters
gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
gateway = JavaGateway(gateway_parameters=GatewayParameters(
port=gateway_port, auth_token=gateway_secret, auto_convert=True))
except:
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.ml.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
java_import(gateway.jvm, "org.apache.spark.resource.*")
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")
jsc = gateway.entry_point.sc()
jconf = gateway.entry_point.sc().getConf()
jsqlc = gateway.entry_point.hivectx() if gateway.entry_point.hivectx() is not None \
else gateway.entry_point.sqlctx()
conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf)
sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
global_dict['sc'] = sc
if spark_major_version >= "2":
from pyspark.sql import SparkSession
spark_session = SparkSession(sc, gateway.entry_point.sparkSession())
sqlc = SQLContext(sc, spark_session, jsqlc)
global_dict['sqlContext'] = sqlc
global_dict['spark'] = spark_session
else:
sqlc = SQLContext(sc, jsqlc)
global_dict['sqlContext'] = sqlc
# LIVY-294, need to check whether HiveContext can work properly,
# fallback to SQLContext if HiveContext can not be initialized successfully.
# Only for spark-1.
code = textwrap.dedent("""
import py4j
from pyspark.sql import SQLContext
try:
sqlContext.tables()
except py4j.protocol.Py4JError:
sqlContext = SQLContext(sc)""")
exec(code, global_dict)
            # Start the py4j callback server
from py4j.protocol import ENTRY_POINT_OBJECT_ID
from py4j.java_gateway import CallbackServerParameters
try:
gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
gateway.start_callback_server(
callback_server_parameters=CallbackServerParameters(
port=0, auth_token=gateway_secret))
except:
gateway.start_callback_server(
callback_server_parameters=CallbackServerParameters(port=0))
socket_info = gateway._callback_server.server_socket.getsockname()
listening_port = socket_info[1]
pyspark_job_processor = PySparkJobProcessorImpl()
gateway.gateway_property.pool.dict[ENTRY_POINT_OBJECT_ID] = pyspark_job_processor
global local_tmp_dir_path, job_context
local_tmp_dir_path = tempfile.mkdtemp()
job_context = JobContextImpl()
print(sys.stdout.getvalue(), file=sys_stderr)
print(sys.stderr.getvalue(), file=sys_stderr)
clearOutputs()
print('READY(port=' + str(listening_port) + ')', file=sys_stdout)
sys_stdout.flush()
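        # Handshake: the parent JVM process is expected to read this
        # READY(port=...) line to learn the py4j callback-server port before
        # it starts streaming JSON requests on stdin.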
while True:
line = sys_stdin.readline()
if line == '':
break
elif line == '\n':
continue
try:
msg = json.loads(line)
except ValueError:
LOG.error('failed to parse message', exc_info=True)
continue
try:
msg_type = msg['msg_type']
except KeyError:
LOG.error('missing message type', exc_info=True)
continue
try:
content = msg['content']
except KeyError:
LOG.error('missing content', exc_info=True)
continue
if not isinstance(content, dict):
LOG.error('content is not a dictionary')
continue
try:
handler = msg_type_router[msg_type]
except KeyError:
LOG.error('unknown message type: %s', msg_type)
continue
response = handler(content)
try:
response = json.dumps(response)
except ValueError:
response = json.dumps({
'msg_type': 'inspect_reply',
'content': {
'status': 'error',
'ename': 'ValueError',
'evalue': 'cannot json-ify %s' % response,
'traceback': [],
}
})
print(response, file=sys_stdout)
sys_stdout.flush()
finally:
if os.environ.get("LIVY_TEST") != "true" and 'sc' in global_dict:
gateway.shutdown_callback_server()
shutil.rmtree(local_tmp_dir_path)
global_dict['sc'].stop()
sys.stdin = sys_stdin
sys.stdout = sys_stdout
sys.stderr = sys_stderr
if __name__ == '__main__':
sys.exit(main())