# blob: bb8646df2b0bf2fa100279a2ee572ed465ea4d4b
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import namedtuple
import os
import traceback
# Lightweight record describing a call site: the function name plus the file
# and line number that are reported for it.
CallSite = namedtuple("CallSite", ["function", "file", "linenum"])
def first_spark_call():
    """
    Return a CallSite representing the first Spark call in the current call stack.

    The directory containing the file of the innermost stack frame (this
    function itself) is treated as the Spark path. The stack is scanned from
    the outermost frame inwards for the first frame whose file name starts
    with that path.

    Returns None when no stack could be extracted. If the first Spark frame is
    the outermost frame, the CallSite describes that frame itself. Otherwise
    it combines the Spark frame's function name with the file and line number
    of the frame immediately preceding it (its caller).
    """
    stack = traceback.extract_stack()
    if not stack:
        return None

    # The last entry is the most recent frame; its directory is the prefix
    # used to recognise Spark frames.
    innermost_file, _, _, _ = stack[-1]
    spark_dir = os.path.dirname(innermost_file)

    # Index of the first frame located under spark_dir, falling back to the
    # innermost frame when nothing matches.
    spark_idx = next(
        (
            idx
            for idx, (path, _, _, _) in enumerate(stack)
            if path.startswith(spark_dir)
        ),
        len(stack) - 1,
    )

    if spark_idx == 0:
        # No caller frame exists before the Spark frame; report it directly.
        path, lineno, func_name, _ = stack[0]
        return CallSite(function=func_name, file=path, linenum=lineno)

    _, _, spark_func, _ = stack[spark_idx]
    caller_file, caller_line, _, _ = stack[spark_idx - 1]
    return CallSite(function=spark_func, file=caller_file, linenum=caller_line)
class SCCallSiteSync(object):
    """
    Helper for setting the spark context call site.

    Example usage:
    from pyspark.context import SCCallSiteSync
    with SCCallSiteSync(<relevant SparkContext>) as css:
        <a Spark call>

    The call site string is computed once, when the object is constructed.
    Entering the context pushes it to ``sc._jsc.setCallSite`` and leaving the
    context resets it with ``setCallSite(None)``. Only the outermost of a set
    of nested ``with`` blocks touches the call site; inner blocks merely
    adjust the nesting counter.
    """

    # Nesting depth of active ``with`` blocks. This is a class attribute, so
    # it is shared by every instance; it is a plain integer with no locking.
    _spark_stack_depth = 0

    def __init__(self, sc):
        """
        Capture the current call site description for later use.

        :param sc: object exposing ``_jsc.setCallSite(...)``; it is stored and
            used when the context is entered and exited.
        """
        call_site = first_spark_call()
        if call_site is not None:
            self._call_site = "%s at %s:%s" % (
                call_site.function, call_site.file, call_site.linenum)
        else:
            self._call_site = "Error! Could not extract traceback info"
        self._context = sc

    def __enter__(self):
        """
        Set the call site if this is the outermost active block.

        Returns this instance so that ``with SCCallSiteSync(sc) as css`` binds
        ``css`` to the helper rather than to None.
        """
        if SCCallSiteSync._spark_stack_depth == 0:
            self._context._jsc.setCallSite(self._call_site)
        SCCallSiteSync._spark_stack_depth += 1
        return self

    def __exit__(self, type, value, tb):
        """
        Clear the call site once the outermost active block is left.

        Returns None, so an exception raised inside the block propagates.
        """
        SCCallSiteSync._spark_stack_depth -= 1
        if SCCallSiteSync._spark_stack_depth == 0:
            self._context._jsc.setCallSite(None)