Wrapping the kernel object in a Python class, and tweaking the SparkContext creation
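
The gist of the change: the runner now wraps bridge.kernel() in a Python Kernel class whose createSparkContext(config) copies a pyspark SparkConf into a JVM SparkConf and asks the JVM-side kernel to build the context, and whose refreshContext() rebuilds the module-level conf, sc and sqlContext globals from the JVM contexts, reusing the already-connected Py4J gateway so a second JVM is never launched. A rough usage sketch from code run through the interpreter (hedged: `kernel` and `sc` are the globals exposed by pyspark_runner.py, and the conf values below are placeholders):

    # Hypothetical snippet executed by the interpreter loop below.
    from pyspark import SparkConf

    config = SparkConf().setAppName("example").set("spark.executor.memory", "1g")
    kernel.createSparkContext(config)        # copies the conf into a JVM SparkConf, builds the
                                             # JVM SparkContext, then refreshes conf/sc/sqlContext
    print(sc.parallelize(range(10)).sum())   # 45
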
diff --git a/pyspark-interpreter/src/main/resources/PySpark/pyspark_runner.py b/pyspark-interpreter/src/main/resources/PySpark/pyspark_runner.py
index 5f3c079..3759165 100644
--- a/pyspark-interpreter/src/main/resources/PySpark/pyspark_runner.py
+++ b/pyspark-interpreter/src/main/resources/PySpark/pyspark_runner.py
@@ -37,9 +37,9 @@
 sparkVersion = sys.argv[2]
 
 if re.match("^1\.[456]\..*$", sparkVersion):
-  gateway = JavaGateway(client, auto_convert = True)
+    gateway = JavaGateway(client, auto_convert=True)
 else:
-  gateway = JavaGateway(client)
+    gateway = JavaGateway(client)
 
 java_import(gateway.jvm, "org.apache.spark.SparkEnv")
 java_import(gateway.jvm, "org.apache.spark.SparkConf")
@@ -51,112 +51,133 @@
 state = bridge.state()
 state.markReady()
 
-#jsc = bridge.javaSparkContext()
-
 if sparkVersion.startswith("1.2"):
-  java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
-  java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
-  java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
-  java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
+    java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
+    java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
+    java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
+    java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
 elif sparkVersion.startswith("1.3"):
-  java_import(gateway.jvm, "org.apache.spark.sql.*")
-  java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
+    java_import(gateway.jvm, "org.apache.spark.sql.*")
+    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
 elif re.match("^1\.[456]\..*$", sparkVersion):
-  java_import(gateway.jvm, "org.apache.spark.sql.*")
-  java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
+    java_import(gateway.jvm, "org.apache.spark.sql.*")
+    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
 
 java_import(gateway.jvm, "scala.Tuple2")
 
-
+conf = None
 sc = None
 sqlContext = None
 code_info = None
 
-kernel = bridge.kernel()
 
 class Logger(object):
-  def __init__(self):
-    self.out = ""
+    def __init__(self):
+        self.out = ""
 
-  def write(self, message):
-    state.sendOutput(code_info.codeId(), message)
-    self.out = self.out + message
+    def write(self, message):
+        state.sendOutput(code_info.codeId(), message)
+        self.out = self.out + message
 
-  def get(self):
-    return self.out
+    def get(self):
+        return self.out
 
-  def reset(self):
-    self.out = ""
+    def reset(self):
+        self.out = ""
 
 output = Logger()
 sys.stdout = output
 sys.stderr = output
 
-while True :
-  try:
-    global code_info
-    code_info = state.nextCode()
 
-    # If code is not available, try again later
-    if (code_info is None):
-      sleep(1)
-      continue
+class Kernel(object):
+    def __init__(self, jkernel):
+        self._jvm_kernel = jkernel
 
-    code_lines = code_info.code().split("\n")
-    #jobGroup = req.jobGroup()
-    final_code = None
+    def createSparkContext(self, config):
+        jconf = gateway.jvm.org.apache.spark.SparkConf(False)
+        for key, value in config.getAll():
+            jconf.set(key, value)
+        self._jvm_kernel.createSparkContext(jconf)
+        self.refreshContext()
 
-    for s in code_lines:
-      if s == None or len(s.strip()) == 0:
-        continue
+    def refreshContext(self):
+        global conf, sc, sqlContext
 
-      # skip comment
-      if s.strip().startswith("#"):
-        continue
+        # Reuse the bridge's Py4J gateway so SparkContext does not instantiate a second gateway (and JVM)
+        with SparkContext._lock:
+            if not SparkContext._gateway:
+                SparkContext._gateway = gateway
+                SparkContext._jvm = gateway.jvm
 
-      if final_code:
-        final_code += "\n" + s
-      else:
-        final_code = s
+        if sc is None:
+            jsc = self._jvm_kernel.javaSparkContext()
+            if jsc is not None:
+                jconf = self._jvm_kernel.sparkConf()
+                conf = SparkConf(_jvm=gateway.jvm, _jconf=jconf)
+                sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
 
-    if sc is None:
-      jsc = kernel.javaSparkContext()
-      if jsc is not None:
-        jconf = kernel.sparkConf()
-        conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf)
-        sc = SparkContext(jsc = jsc, gateway = gateway, conf = conf)
+        if sqlContext is None:
+            jsqlContext = self._jvm_kernel.sqlContext()
+            if jsqlContext is not None and sc is not None:
+                sqlContext = SQLContext(sc, sqlContext=jsqlContext)
 
-    if sqlContext is None:
-      jsqlContext = kernel.sqlContext()
-      if jsqlContext is not None and sc is not None:
-        sqlContext = SQLContext(sc, sqlContext=jsqlContext)
+kernel = Kernel(bridge.kernel())
 
-    if final_code:
-      '''Parse the final_code to an AST parse tree.  If the last node is an expression (where an expression
-      can be a print function or an operation like 1+1) turn it into an assignment where temp_val = last expression.
-      The modified parse tree will get executed.  If the variable temp_val introduced is not none then we have the
-      result of the last expression and should return it as an execute result.  The sys.stdout sendOutput logic
-      gets triggered on each logger message to support long running code blocks instead of bulk'''
-      ast_parsed = ast.parse(final_code)
-      the_last_expression_to_assign_temp_value = None
-      if isinstance(ast_parsed.body[-1], ast.Expr):
-        new_node = (ast.Assign(targets=[ast.Name(id='the_last_expression_to_assign_temp_value', ctx=ast.Store())], value=ast_parsed.body[-1].value))
-        ast_parsed.body[-1] = ast.fix_missing_locations(new_node)
-      compiled_code = compile(ast_parsed, "<string>", "exec")
-      eval(compiled_code)
-      if the_last_expression_to_assign_temp_value is not None:
-        state.markSuccess(code_info.codeId(), str(the_last_expression_to_assign_temp_value))
-      else:
-        state.markSuccess(code_info.codeId(), "")
-      del the_last_expression_to_assign_temp_value
+while True:
+    try:
+        code_info = state.nextCode()
 
-  except Py4JJavaError:
-    excInnerError = traceback.format_exc() # format_tb() does not return the inner exception
-    innerErrorStart = excInnerError.find("Py4JJavaError:")
-    if innerErrorStart > -1:
-       excInnerError = excInnerError[innerErrorStart:]
-    state.markFailure(code_info.codeId(), excInnerError + str(sys.exc_info()))
-  except:
-    state.markFailure(code_info.codeId(), traceback.format_exc())
+        # If code is not available, try again later
+        if code_info is None:
+            sleep(1)
+            continue
 
-  output.reset()
+        code_lines = code_info.code().split("\n")
+        final_code = None
+
+        for s in code_lines:
+            if s is None or len(s.strip()) == 0:
+                continue
+
+            # skip comment
+            if s.strip().startswith("#"):
+                continue
+
+            if final_code:
+                final_code += "\n" + s
+            else:
+                final_code = s
+
+        # Ensure the appropriate variables are set in the module namespace
+        kernel.refreshContext()
+
+        if final_code:
+            '''Parse final_code into an AST.  If the last node is an expression (e.g. a print call or an
+            operation like 1+1), rewrite it as an assignment to the_last_expression_to_assign_temp_value
+            and execute the modified tree.  If that variable ends up non-None, report its value as the
+            execute result.  The sys.stdout sendOutput logic fires on each logger message so long-running
+            code blocks stream their output instead of sending it in bulk.'''
+            ast_parsed = ast.parse(final_code)
+            the_last_expression_to_assign_temp_value = None
+            if isinstance(ast_parsed.body[-1], ast.Expr):
+                new_node = (ast.Assign(targets=[ast.Name(id='the_last_expression_to_assign_temp_value', ctx=ast.Store())], value=ast_parsed.body[-1].value))
+                ast_parsed.body[-1] = ast.fix_missing_locations(new_node)
+            compiled_code = compile(ast_parsed, "<string>", "exec")
+            eval(compiled_code)
+            if the_last_expression_to_assign_temp_value is not None:
+                state.markSuccess(code_info.codeId(), str(the_last_expression_to_assign_temp_value))
+            else:
+                state.markSuccess(code_info.codeId(), "")
+            del the_last_expression_to_assign_temp_value
+
+    except Py4JJavaError:
+        excInnerError = traceback.format_exc() # format_tb() does not return the inner exception
+        innerErrorStart = excInnerError.find("Py4JJavaError:")
+        if innerErrorStart > -1:
+            excInnerError = excInnerError[innerErrorStart:]
+        state.markFailure(code_info.codeId(), excInnerError + str(sys.exc_info()))
+    except:
+        state.markFailure(code_info.codeId(), traceback.format_exc())
+
+    output.reset()
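
For reference, the last-expression capture carried over from the old loop reduces to a small standalone sketch: parse the block, and if the final statement is a bare expression, rewrite it as an assignment so its value can be reported as the execute result. The runner uses the variable the_last_expression_to_assign_temp_value and evaluates in its own scope; the names below are illustrative only.

    # Standalone illustration of the AST rewrite; result_name is an arbitrary choice.
    import ast

    source = "x = 20\nx + 22"
    tree = ast.parse(source)
    result_name = "_last_expr_value"
    if isinstance(tree.body[-1], ast.Expr):
        assign = ast.Assign(
            targets=[ast.Name(id=result_name, ctx=ast.Store())],
            value=tree.body[-1].value,
        )
        tree.body[-1] = ast.fix_missing_locations(assign)

    namespace = {}
    exec(compile(tree, "<string>", "exec"), namespace)
    print(namespace.get(result_name))  # 42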