| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import os, sys, traceback, json, re |
| |
| from py4j.java_gateway import java_import, JavaGateway, GatewayClient |
| from py4j.protocol import Py4JJavaError |
| |
| import ast |
| |
| class Logger(object): |
| def __init__(self): |
| pass |
| |
| def write(self, message): |
| intp.appendOutput(message) |
| |
| def reset(self): |
| pass |
| |
| def flush(self): |
| pass |
| |
| |
| class PythonCompletion: |
| def __init__(self, interpreter, userNameSpace): |
| self.interpreter = interpreter |
| self.userNameSpace = userNameSpace |
| |
| def getObjectCompletion(self, text_value): |
| completions = [completion for completion in list(self.userNameSpace.keys()) if completion.startswith(text_value)] |
| builtinCompletions = [completion for completion in dir(__builtins__) if completion.startswith(text_value)] |
| return completions + builtinCompletions |
| |
| def getMethodCompletion(self, objName, methodName): |
| execResult = locals() |
| try: |
| exec("{} = dir({})".format("objectDefList", objName), _zcUserQueryNameSpace, execResult) |
| except: |
| self.interpreter.logPythonOutput("Fail to run dir on " + objName) |
| self.interpreter.logPythonOutput(traceback.format_exc()) |
| return None |
| else: |
| objectDefList = execResult['objectDefList'] |
| return [completion for completion in execResult['objectDefList'] if completion.startswith(methodName)] |
| |
| def getCompletion(self, text_value): |
| if text_value == None: |
| return None |
| |
| dotPos = text_value.find(".") |
| if dotPos == -1: |
| objName = text_value |
| completionList = self.getObjectCompletion(objName) |
| else: |
| objName = text_value[:dotPos] |
| methodName = text_value[dotPos + 1:] |
| completionList = self.getMethodCompletion(objName, methodName) |
| |
| if completionList is None or len(completionList) <= 0: |
| self.interpreter.setStatementsFinished("", False) |
| else: |
| result = json.dumps(list(filter(lambda x : not re.match("^__.*", x), list(completionList)))) |
| self.interpreter.setStatementsFinished(result, False) |
| |
| # ast api is changed after python 3.8, see https://github.com/ipython/ipython/pull/11593 |
| if sys.version_info > (3,8): |
| from ast import Module |
| else : |
| # mock the new API, ignore second argument |
| # see https://github.com/ipython/ipython/issues/11590 |
| from ast import Module as OriginalModule |
| Module = lambda nodelist, type_ignores: OriginalModule(nodelist) |
| |
| host = sys.argv[1] |
| port = int(sys.argv[2]) |
| |
| if "PY4J_GATEWAY_SECRET" in os.environ: |
| from py4j.java_gateway import GatewayParameters |
| gateway_secret = os.environ["PY4J_GATEWAY_SECRET"] |
| gateway = JavaGateway(gateway_parameters=GatewayParameters( |
| address=host, port=port, auth_token=gateway_secret, auto_convert=True)) |
| else: |
| gateway = JavaGateway(GatewayClient(address=host, port=port), auto_convert=True) |
| |
| intp = gateway.entry_point |
| _zcUserQueryNameSpace = {} |
| |
| completion = PythonCompletion(intp, _zcUserQueryNameSpace) |
| _zcUserQueryNameSpace["__zeppelin_completion__"] = completion |
| _zcUserQueryNameSpace["gateway"] = gateway |
| |
| from zeppelin_context import PyZeppelinContext |
| if intp.getZeppelinContext(): |
| z = __zeppelin__ = PyZeppelinContext(intp.getZeppelinContext(), gateway) |
| __zeppelin__._setup_matplotlib() |
| _zcUserQueryNameSpace["z"] = z |
| _zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__ |
| |
| intp.onPythonScriptInitialized(os.getpid()) |
| # redirect stdout/stderr to java side so that PythonInterpreter can capture the python execution result |
| output = Logger() |
| sys.stdout = output |
| sys.stderr = output |
| |
| while True : |
| req = intp.getStatements() |
| try: |
| stmts = req.statements().split("\n") |
| isForCompletion = req.isForCompletion() |
| |
| # Get post-execute hooks |
| try: |
| if req.isCallHooks(): |
| global_hook = intp.getHook('post_exec_dev') |
| else: |
| global_hook = None |
| except: |
| global_hook = None |
| |
| try: |
| if req.isCallHooks(): |
| user_hook = __zeppelin__.getHook('post_exec') |
| else: |
| user_hook = None |
| except: |
| user_hook = None |
| |
| nhooks = 0 |
| if not isForCompletion: |
| for hook in (global_hook, user_hook): |
| if hook: |
| nhooks += 1 |
| |
| if stmts: |
| # use exec mode to compile the statements except the last statement, |
| # so that the last statement's evaluation will be printed to stdout |
| code = compile('\n'.join(stmts), '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1) |
| to_run_hooks = [] |
| if (nhooks > 0): |
| to_run_hooks = code.body[-nhooks:] |
| to_run_exec, to_run_single = (code.body[:-(nhooks + 1)], |
| [code.body[-(nhooks + 1)]] if len(code.body) > nhooks else []) |
| try: |
| for node in to_run_exec: |
| mod = Module([node], []) |
| code = compile(mod, '<stdin>', 'exec') |
| exec(code, _zcUserQueryNameSpace) |
| |
| for node in to_run_single: |
| mod = ast.Interactive([node]) |
| code = compile(mod, '<stdin>', 'single') |
| exec(code, _zcUserQueryNameSpace) |
| |
| for node in to_run_hooks: |
| mod = Module([node], []) |
| code = compile(mod, '<stdin>', 'exec') |
| exec(code, _zcUserQueryNameSpace) |
| |
| if not isForCompletion: |
| # only call it when it is not for code completion. code completion will call it in |
| # PythonCompletion.getCompletion |
| intp.setStatementsFinished("", False) |
| except Py4JJavaError: |
| # raise it to outside try except |
| raise |
| except: |
| if not isForCompletion: |
| # extract which line incur error from error message. e.g. |
| # Traceback (most recent call last): |
| # File "<stdin>", line 1, in <module> |
| # ZeroDivisionError: integer division or modulo by zero |
| exception = traceback.format_exc() |
| m = re.search("File \"<stdin>\", line (\d+).*", exception) |
| if m: |
| line_no = int(m.group(1)) |
| intp.setStatementsFinished( |
| "Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True) |
| else: |
| intp.setStatementsFinished(exception, True) |
| else: |
| intp.setStatementsFinished("", False) |
| |
| except Py4JJavaError: |
| excInnerError = traceback.format_exc() # format_tb() does not return the inner exception |
| innerErrorStart = excInnerError.find("Py4JJavaError:") |
| if innerErrorStart > -1: |
| excInnerError = excInnerError[innerErrorStart:] |
| intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True) |
| except: |
| intp.setStatementsFinished(traceback.format_exc(), True) |
| |
| output.reset() |