Wrapping the kernel object in a Python class and tweaking the Spark context creation
diff --git a/pyspark-interpreter/src/main/resources/PySpark/pyspark_runner.py b/pyspark-interpreter/src/main/resources/PySpark/pyspark_runner.py
index 5f3c079..3759165 100644
--- a/pyspark-interpreter/src/main/resources/PySpark/pyspark_runner.py
+++ b/pyspark-interpreter/src/main/resources/PySpark/pyspark_runner.py
@@ -37,9 +37,9 @@
sparkVersion = sys.argv[2]
if re.match("^1\.[456]\..*$", sparkVersion):
- gateway = JavaGateway(client, auto_convert = True)
+ gateway = JavaGateway(client, auto_convert=True)
else:
- gateway = JavaGateway(client)
+ gateway = JavaGateway(client)
java_import(gateway.jvm, "org.apache.spark.SparkEnv")
java_import(gateway.jvm, "org.apache.spark.SparkConf")
@@ -51,112 +51,133 @@
state = bridge.state()
state.markReady()
-#jsc = bridge.javaSparkContext()
-
if sparkVersion.startswith("1.2"):
- java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
+ java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
elif sparkVersion.startswith("1.3"):
- java_import(gateway.jvm, "org.apache.spark.sql.*")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
elif re.match("^1\.[456]\..*$", sparkVersion):
- java_import(gateway.jvm, "org.apache.spark.sql.*")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")
-
+conf = None
sc = None
sqlContext = None
code_info = None
-kernel = bridge.kernel()
class Logger(object):
- def __init__(self):
- self.out = ""
+ def __init__(self):
+ self.out = ""
- def write(self, message):
- state.sendOutput(code_info.codeId(), message)
- self.out = self.out + message
+ def write(self, message):
+ state.sendOutput(code_info.codeId(), message)
+ self.out = self.out + message
- def get(self):
- return self.out
+ def get(self):
+ return self.out
- def reset(self):
- self.out = ""
+ def reset(self):
+ self.out = ""
output = Logger()
sys.stdout = output
sys.stderr = output
-while True :
- try:
- global code_info
- code_info = state.nextCode()
- # If code is not available, try again later
- if (code_info is None):
- sleep(1)
- continue
+class Kernel(object):
+ def __init__(self, jkernel):
+ self._jvm_kernel = jkernel
- code_lines = code_info.code().split("\n")
- #jobGroup = req.jobGroup()
- final_code = None
+ def createSparkContext(self, config):
+ jconf = gateway.jvm.org.apache.spark.SparkConf(False)
+ for key,value in config.getAll():
+ jconf.set(key, value)
+ self._jvm_kernel.createSparkContext(jconf)
+ self.refreshContext()
- for s in code_lines:
- if s == None or len(s.strip()) == 0:
- continue
+ def refreshContext(self):
+ global conf, sc, sqlContext
- # skip comment
- if s.strip().startswith("#"):
- continue
+ # Workaround: pre-register our existing gateway/JVM on SparkContext (under its lock) to prevent multiple gateways being instantiated
+ with SparkContext._lock:
+ if not SparkContext._gateway:
+ SparkContext._gateway = gateway
+ SparkContext._jvm = gateway.jvm
- if final_code:
- final_code += "\n" + s
- else:
- final_code = s
+ if sc is None:
+ jsc = self._jvm_kernel.javaSparkContext()
+ if jsc is not None:
+ jconf = self._jvm_kernel.sparkConf()
+ conf = SparkConf(_jvm=gateway.jvm, _jconf=jconf)
+ sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
- if sc is None:
- jsc = kernel.javaSparkContext()
- if jsc is not None:
- jconf = kernel.sparkConf()
- conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf)
- sc = SparkContext(jsc = jsc, gateway = gateway, conf = conf)
+ if sqlContext is None:
+ jsqlContext = self._jvm_kernel.sqlContext()
+ if jsqlContext is not None and sc is not None:
+ sqlContext = SQLContext(sc, sqlContext=jsqlContext)
- if sqlContext is None:
- jsqlContext = kernel.sqlContext()
- if jsqlContext is not None and sc is not None:
- sqlContext = SQLContext(sc, sqlContext=jsqlContext)
+kernel = Kernel(bridge.kernel())
- if final_code:
- '''Parse the final_code to an AST parse tree. If the last node is an expression (where an expression
- can be a print function or an operation like 1+1) turn it into an assignment where temp_val = last expression.
- The modified parse tree will get executed. If the variable temp_val introduced is not none then we have the
- result of the last expression and should return it as an execute result. The sys.stdout sendOutput logic
- gets triggered on each logger message to support long running code blocks instead of bulk'''
- ast_parsed = ast.parse(final_code)
- the_last_expression_to_assign_temp_value = None
- if isinstance(ast_parsed.body[-1], ast.Expr):
- new_node = (ast.Assign(targets=[ast.Name(id='the_last_expression_to_assign_temp_value', ctx=ast.Store())], value=ast_parsed.body[-1].value))
- ast_parsed.body[-1] = ast.fix_missing_locations(new_node)
- compiled_code = compile(ast_parsed, "<string>", "exec")
- eval(compiled_code)
- if the_last_expression_to_assign_temp_value is not None:
- state.markSuccess(code_info.codeId(), str(the_last_expression_to_assign_temp_value))
- else:
- state.markSuccess(code_info.codeId(), "")
- del the_last_expression_to_assign_temp_value
+while True:
+ try:
+ code_info = state.nextCode()
- except Py4JJavaError:
- excInnerError = traceback.format_exc() # format_tb() does not return the inner exception
- innerErrorStart = excInnerError.find("Py4JJavaError:")
- if innerErrorStart > -1:
- excInnerError = excInnerError[innerErrorStart:]
- state.markFailure(code_info.codeId(), excInnerError + str(sys.exc_info()))
- except:
- state.markFailure(code_info.codeId(), traceback.format_exc())
+ # If code is not available, try again later
+ if code_info is None:
+ sleep(1)
+ continue
- output.reset()
+ code_lines = code_info.code().split("\n")
+ final_code = None
+
+ for s in code_lines:
+ if s is None or len(s.strip()) == 0:
+ continue
+
+ # skip comment
+ if s.strip().startswith("#"):
+ continue
+
+ if final_code:
+ final_code += "\n" + s
+ else:
+ final_code = s
+
+ # Ensure the appropriate variables are set in the module namespace
+ kernel.refreshContext()
+
+ if final_code:
+ '''Parse the final_code to an AST parse tree. If the last node is an expression (where an expression
+ can be a print function or an operation like 1+1) turn it into an assignment where temp_val = last expression.
+ The modified parse tree will get executed. If the variable temp_val introduced is not none then we have the
+ result of the last expression and should return it as an execute result. The sys.stdout sendOutput logic
+ gets triggered on each logger message to support long running code blocks instead of bulk'''
+ ast_parsed = ast.parse(final_code)
+ the_last_expression_to_assign_temp_value = None
+ if isinstance(ast_parsed.body[-1], ast.Expr):
+ new_node = (ast.Assign(targets=[ast.Name(id='the_last_expression_to_assign_temp_value', ctx=ast.Store())], value=ast_parsed.body[-1].value))
+ ast_parsed.body[-1] = ast.fix_missing_locations(new_node)
+ compiled_code = compile(ast_parsed, "<string>", "exec")
+ eval(compiled_code)
+ if the_last_expression_to_assign_temp_value is not None:
+ state.markSuccess(code_info.codeId(), str(the_last_expression_to_assign_temp_value))
+ else:
+ state.markSuccess(code_info.codeId(), "")
+ del the_last_expression_to_assign_temp_value
+
+ except Py4JJavaError:
+ excInnerError = traceback.format_exc() # format_tb() does not return the inner exception
+ innerErrorStart = excInnerError.find("Py4JJavaError:")
+ if innerErrorStart > -1:
+ excInnerError = excInnerError[innerErrorStart:]
+ state.markFailure(code_info.codeId(), excInnerError + str(sys.exc_info()))
+ except:
+ state.markFailure(code_info.codeId(), traceback.format_exc())
+
+ output.reset()