| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import atexit |
| import os |
| import sys |
| import select |
| import signal |
| import shlex |
| import socket |
| import platform |
| from subprocess import Popen, PIPE |
| |
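# Python 3 removed xrange; alias it to range so that can_convert_list below
# works under both interpreters.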
if sys.version_info >= (3,):
| xrange = range |
| |
| from py4j.java_gateway import java_import, JavaGateway, GatewayClient |
| from py4j.java_collections import ListConverter |
| |
| from pyspark.serializers import read_int |
| |
| |
# Patch ListConverter so it only matches lists, tuples, and (x)ranges;
# otherwise py4j would convert a bytearray into a Java ArrayList.
| def can_convert_list(self, obj): |
| return isinstance(obj, (list, tuple, xrange)) |
| |
| ListConverter.can_convert = can_convert_list |
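# With this patch, a bytearray passed to the JVM falls through to py4j's
# default handling, so it should arrive as a Java byte[] rather than an
# ArrayList of integers.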
| |
| |
| def launch_gateway(): |
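    """
    Launch a Py4J gateway to the Spark JVM and return the connected
    JavaGateway.

    If PYSPARK_GATEWAY_PORT is set, connect to an already-running gateway on
    that port. Otherwise, start a new JVM via SPARK_HOME's spark-submit
    script, passing along any PYSPARK_SUBMIT_ARGS, and wait for the JVM to
    report its gateway port back over a local callback socket.
    """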
| if "PYSPARK_GATEWAY_PORT" in os.environ: |
| gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) |
| else: |
| SPARK_HOME = os.environ["SPARK_HOME"] |
        # Launch the Py4J gateway using Spark's spark-submit script so that
        # we pick up the proper classpath and settings from spark-env.sh
| on_windows = platform.system() == "Windows" |
| script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" |
| submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") |
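        # Under Spark's test harness, disable the web UI (presumably to avoid
        # binding a UI port for every short-lived context spun up by tests).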
| if os.environ.get("SPARK_TESTING"): |
| submit_args = ' '.join([ |
| "--conf spark.ui.enabled=false", |
| submit_args |
| ]) |
| command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) |
| |
| # Start a socket that will be used by PythonGatewayServer to communicate its port to us |
| callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| callback_socket.bind(('127.0.0.1', 0)) |
| callback_socket.listen(1) |
| callback_host, callback_port = callback_socket.getsockname() |
| env = dict(os.environ) |
| env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host |
| env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) |
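
        # Handshake sketch, inferred from the code here: the JVM-side
        # PythonGatewayServer reads the two _PYSPARK_DRIVER_CALLBACK_* env
        # vars, connects back to callback_socket, and writes the port its
        # gateway is listening on as a single big-endian 32-bit int, which
        # read_int() below decodes.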
| |
        # Launch the Java gateway.
        # We open a pipe to its stdin so that the Java gateway can exit when
        # the pipe breaks (i.e. when this Python process dies).
| if not on_windows: |
| # Don't send ctrl-c / SIGINT to the Java gateway: |
| def preexec_func(): |
| signal.signal(signal.SIGINT, signal.SIG_IGN) |
| proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) |
| else: |
| # preexec_fn not supported on Windows |
| proc = Popen(command, stdin=PIPE, env=env) |
| |
| gateway_port = None |
| # We use select() here in order to avoid blocking indefinitely if the subprocess dies |
| # before connecting |
| while gateway_port is None and proc.poll() is None: |
| timeout = 1 # (seconds) |
| readable, _, _ = select.select([callback_socket], [], [], timeout) |
| if callback_socket in readable: |
| gateway_connection = callback_socket.accept()[0] |
| # Determine which ephemeral port the server started on: |
| gateway_port = read_int(gateway_connection.makefile(mode="rb")) |
| gateway_connection.close() |
| callback_socket.close() |
| if gateway_port is None: |
| raise Exception("Java gateway process exited before sending the driver its port number") |
| |
    # On Windows, ensure the Java child processes do not linger after Python has exited.
    # On UNIX-like systems, the child process can kill itself on broken pipe (i.e. when
    # the parent process's stdin sends an EOF). On Windows, however, this is not possible
    # because java.lang.Process reads directly from the parent process's stdin, racing
    # with any attempt to read an EOF from the parent. Note that this is only best
    # effort and will not take effect if the Python process is terminated abruptly.
    if on_windows:
        # On Windows, the child process here is "spark-submit.cmd", not the JVM itself
        # (because the UNIX "exec" command is not available). This means we cannot simply
        # call proc.kill(), which would kill only the "spark-submit.cmd" process but not
        # the JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate
        # all child processes in the tree
        # (http://technet.microsoft.com/en-us/library/bb491009.aspx)
        def kill_child():
            Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
        atexit.register(kill_child)
| |
| # Connect to the gateway |
| gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) |
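    # auto_convert=True tells py4j to convert Python collections to their
    # Java equivalents automatically when calling into the JVM, using
    # registered converters such as the ListConverter patched above.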
| |
| # Import the classes used by PySpark |
| java_import(gateway.jvm, "org.apache.spark.SparkConf") |
| java_import(gateway.jvm, "org.apache.spark.api.java.*") |
| java_import(gateway.jvm, "org.apache.spark.api.python.*") |
| java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") |
| # TODO(davies): move into sql |
| java_import(gateway.jvm, "org.apache.spark.sql.*") |
| java_import(gateway.jvm, "org.apache.spark.sql.hive.*") |
| java_import(gateway.jvm, "scala.Tuple2") |
| |
| return gateway |