import sys, getopt, traceback, json, re
import os
# Allow connecting to the gateway without py4j authentication.
os.environ['PYSPARK_ALLOW_INSECURE_GATEWAY'] = '1'
from py4j.protocol import Py4JJavaError, Py4JNetworkError
import matplotlib
# Use the non-interactive Agg backend: this process has no display.
matplotlib.use('Agg')
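# Prepend the PySpark and py4j zip archives (third argument, colon-separated)
# to sys.path so the pyspark imports below resolve.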
zipPaths = sys.argv[3]
paths = zipPaths.split(':')
for path in paths:
    sys.path.insert(0, path)
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.rdd import RDD
from pyspark.files import SparkFiles
from pyspark.storagelevel import StorageLevel
from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
from pyspark.serializers import MarshalSerializer, PickleSerializer
import base64
from io import BytesIO
try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO
# for backward compatibility
from pyspark.sql import SQLContext, HiveContext, Row
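# Stream this process's stdout/stderr back to the Java-side interpreter through py4j.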
class Logger(object):
    def __init__(self):
        self.out = ""
    def write(self, message):
        intp.appendOutput(message)
    def reset(self):
        self.out = ""
    def flush(self):
        pass

class ErrorLogger(object):
    def __init__(self):
        self.out = ""
    def write(self, message):
        intp.appendErrorOutput(message)
    def reset(self):
        self.out = ""
    def flush(self):
        pass
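# The Spark version (second argument, e.g. 140 for 1.4.0) gates version-dependent behavior.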
class SparkVersion(object):
    SPARK_1_4_0 = 140
    SPARK_1_3_0 = 130
    def __init__(self, versionNumber):
        self.version = versionNumber
    def isAutoConvertEnabled(self):
        return self.version >= self.SPARK_1_4_0
    def isImportAllPackageUnderSparkSql(self):
        return self.version >= self.SPARK_1_3_0
output = Logger()
errorOutput = ErrorLogger()
sys.stdout = output
sys.stderr = errorOutput
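# Connect to the Java gateway server listening on the port passed as the first argument.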
client = GatewayClient(port=int(sys.argv[1]))
sparkVersion = SparkVersion(int(sys.argv[2]))
if sparkVersion.isAutoConvertEnabled():
    gateway = JavaGateway(client, auto_convert=True)
else:
    gateway = JavaGateway(client)
java_import(gateway.jvm, "org.apache.spark.SparkEnv")
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
intp = gateway.entry_point
if sparkVersion.isImportAllPackageUnderSparkSql():
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
else:
    java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
jobGroup = ""
def show(obj):
    from pyspark.sql import DataFrame
    if isinstance(obj, DataFrame):
        intp.showDF(jobGroup, obj._jdf)
    else:
        print(str(obj))

def printlog(obj):
    try:
        intp.printLog(obj)
    except Exception as e:
        print("send log failed: " + str(e))

def showAlias(obj, alias):
    from pyspark.sql import DataFrame
    if isinstance(obj, DataFrame):
        intp.showAliasDF(jobGroup, obj._jdf, alias)
    else:
        print(str(obj))
def show_matplotlib(p, fmt="png", width="auto", height="auto", **kwargs):
    """Matplotlib show function: renders the figure as inline HTML."""
    if fmt == "png":
        img = BytesIO()
        p.savefig(img, format=fmt)
        img_str = b"data:image/png;base64,"
        img_str += base64.b64encode(img.getvalue().strip())
        img_tag = "<img src='{img}' style='width:{width};height:{height}'>"
        # Decoding is necessary for Python 3 compatibility
        img_str = img_str.decode("utf-8")
        img_str = img_tag.format(img=img_str, width=width, height=height)
    elif fmt == "svg":
        img = StringIO()
        p.savefig(img, format=fmt)
        img_str = img.getvalue()
    else:
        raise ValueError("fmt must be 'png' or 'svg'")
    html = "<div style='width:{width};height:{height}'>{img}</div>"
    intp.showHTML(jobGroup, html.format(width=width, height=height, img=img_str))
    img.close()
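# Hypothetical usage from a notebook paragraph:
#   import matplotlib.pyplot as plt
#   plt.plot([1, 2, 3])
#   show_matplotlib(plt)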
def saveDFToCsv(df, path, hasheader=True, isOverwrite=False, option=None):
    from pyspark.sql import DataFrame
    from py4j.java_collections import MapConverter
    # Avoid the mutable-default-argument pitfall: build the option map per call.
    if option is None:
        option = {}
    if isinstance(df, DataFrame):
        intp.saveDFToCsv(df._jdf, path, hasheader, isOverwrite,
                         MapConverter().convert(option, gateway._gateway_client))
    else:
        print(str(df))
java_import(gateway.jvm, "scala.Tuple2")
jsc = intp.getJavaSparkContext()
jconf = intp.getSparkConf()
conf = SparkConf(_jvm=gateway.jvm, _jconf=jconf)
sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
sqlc = HiveContext(sc, intp.sqlContext())
sqlContext = sqlc
spark = SparkSession(sc, intp.getSparkSession())
# Add user-supplied Python dependencies (fourth argument, comma-separated).
try:
    pyfile = sys.argv[4]
    pyfiles = pyfile.split(',')
    for f in pyfiles:
        if f != "":
            sc.addPyFile(f)
except Exception as e:
    print("add pyfile error: " + str(e))
class UDF(object):
    def __init__(self, intp, sqlc):
        self.intp = intp
        self.sqlc = sqlc
    def register(self, udfName, udf):
        self.sqlc.registerFunction(udfName, udf)
    def listUDFs(self):
        return self.intp.listUDFs()
    def existsUDF(self, name):
        return self.intp.existsUDF(name)
udf = UDF(intp, sqlc)
intp.onPythonScriptInitialized(os.getpid())
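# Main loop: block until the Java interpreter delivers the next statements,
# execute them, and report success or failure back.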
while True:
    req = intp.getStatements()
    try:
        stmts = req.statements().split("\n")
        jobGroup = req.jobGroup()
        final_code = None
        for bdp_dwc_s in stmts:
            if bdp_dwc_s is None:
                continue
            # skip empty lines and comment-only lines
            s_stripped = bdp_dwc_s.strip()
            if len(s_stripped) == 0 or s_stripped.startswith("#"):
                continue
            if final_code:
                final_code += "\n" + bdp_dwc_s
            else:
                final_code = bdp_dwc_s
        if final_code:
            compiledCode = compile(final_code, "<string>", "exec")
            sc.setJobGroup(jobGroup, final_code)
            eval(compiledCode)
        intp.setStatementsFinished("", False)
    except Py4JJavaError:
        # format_tb() does not return the inner exception, so use format_exc()
        excInnerError = traceback.format_exc()
        innerErrorStart = excInnerError.find("Py4JJavaError:")
        if innerErrorStart > -1:
            excInnerError = excInnerError[innerErrorStart:]
        intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True)
    except Py4JNetworkError:
        # Lost connection to the gateway server: report the traceback and exit.
        intp.setStatementsFinished(traceback.format_exc(), True)
        sys.exit(1)
    except:
        msg = traceback.format_exc()
        intp.setStatementsFinished(msg, True)
    output.reset()