[Fix][Plugin] sw_flask general exceptions handled (#93)
* sw_flask fix will handle errors like returning the wrong type from a handler or other internal errors.
* Updated StackedSpan to track depth, and made depth variable instance instead of class level (this was a bug).
* Changed how SpanContext decides when all spans finished to write Segment data, now counts span start / stops which should work better across different async scenarios.
* Changed new_exit_span() with span.inject() to work simpler like the NodeJS agent, now plugins inject directly themselves if they need to.
* Removed carrier from plugins which didn't actually use it.
diff --git a/Makefile b/Makefile
index 1a8e9fe..8dd8fee 100644
--- a/Makefile
+++ b/Makefile
@@ -32,7 +32,7 @@
lint: clean
flake8 --version || python3 -m pip install flake8
flake8 . --count --select=E9,F63,F7,F82 --show-source
- flake8 . --count --max-complexity=12 --max-line-length=120
+ flake8 . --count --max-complexity=13 --max-line-length=120
license: clean
python3 tools/check-license-header.py skywalking tests tools
diff --git a/skywalking/plugins/sw_flask.py b/skywalking/plugins/sw_flask.py
index 3b7e16c..037f6cf 100644
--- a/skywalking/plugins/sw_flask.py
+++ b/skywalking/plugins/sw_flask.py
@@ -26,8 +26,8 @@
def install():
from flask import Flask
_full_dispatch_request = Flask.full_dispatch_request
-
_handle_user_exception = Flask.handle_user_exception
+ _handle_exception = Flask.handle_exception
def params_tostring(params):
return "\n".join([k + '=[' + ",".join(params.getlist(k)) + ']' for k, _ in params.items()])
@@ -66,5 +66,14 @@
return _handle_user_exception(this, e)
+ def _sw_handle_exception(this: Flask, e):
+ if e is not None:
+ entry_span = get_context().active_span()
+ if entry_span is not None and type(entry_span) is not NoopSpan:
+ entry_span.raised()
+
+ return _handle_exception(this, e)
+
Flask.full_dispatch_request = _sw_full_dispatch_request
Flask.handle_user_exception = _sw_handle_user_exception
+ Flask.handle_exception = _sw_handle_exception
diff --git a/skywalking/plugins/sw_kafka.py b/skywalking/plugins/sw_kafka.py
index 7ca1f5f..201f6a4 100644
--- a/skywalking/plugins/sw_kafka.py
+++ b/skywalking/plugins/sw_kafka.py
@@ -72,18 +72,15 @@
peer = ";".join(this.config["bootstrap_servers"])
context = get_context()
- carrier = Carrier()
- with context.new_exit_span(op="Kafka/" + topic + "/Producer" or "/", peer=peer, carrier=carrier) as span:
+ with context.new_exit_span(op="Kafka/" + topic + "/Producer" or "/", peer=peer) as span:
+ carrier = span.inject()
span.layer = Layer.MQ
span.component = Component.KafkaProducer
if headers is None:
headers = []
- for item in carrier:
- headers.append((item.key, item.val.encode("utf-8")))
- else:
- for item in carrier:
- headers.append((item.key, item.val.encode("utf-8")))
+ for item in carrier:
+ headers.append((item.key, item.val.encode("utf-8")))
res = _send(this, topic, value=value, key=key, headers=headers, partition=partition,
timestamp_ms=timestamp_ms)
diff --git a/skywalking/plugins/sw_pymongo.py b/skywalking/plugins/sw_pymongo.py
index 01e3763..8aae500 100644
--- a/skywalking/plugins/sw_pymongo.py
+++ b/skywalking/plugins/sw_pymongo.py
@@ -17,7 +17,6 @@
from skywalking import Layer, Component, config
from skywalking.trace import tags
-from skywalking.trace.carrier import Carrier
from skywalking.trace.context import get_context
from skywalking.trace.tags import Tag
@@ -56,11 +55,10 @@
address = this.sock.getpeername()
peer = "%s:%s" % address
context = get_context()
- carrier = Carrier()
operation = list(spec.keys())[0]
sw_op = operation.capitalize() + "Operation"
- with context.new_exit_span(op="MongoDB/" + sw_op, peer=peer, carrier=carrier) as span:
+ with context.new_exit_span(op="MongoDB/" + sw_op, peer=peer) as span:
result = _command(this, dbname, spec, *args, **kwargs)
span.layer = Layer.Database
@@ -108,10 +106,9 @@
address = this.collection.database.client.address
peer = "%s:%s" % address
context = get_context()
- carrier = Carrier()
sw_op = "MixedBulkWriteOperation"
- with context.new_exit_span(op="MongoDB/"+sw_op, peer=peer, carrier=carrier) as span:
+ with context.new_exit_span(op="MongoDB/"+sw_op, peer=peer) as span:
span.layer = Layer.Database
span.component = Component.MongoDB
@@ -144,10 +141,9 @@
peer = "%s:%s" % address
context = get_context()
- carrier = Carrier()
op = "FindOperation"
- with context.new_exit_span(op="MongoDB/"+op, peer=peer, carrier=carrier) as span:
+ with context.new_exit_span(op="MongoDB/"+op, peer=peer) as span:
span.layer = Layer.Database
span.component = Component.MongoDB
diff --git a/skywalking/plugins/sw_pymysql.py b/skywalking/plugins/sw_pymysql.py
index 4adcabf..cf32160 100644
--- a/skywalking/plugins/sw_pymysql.py
+++ b/skywalking/plugins/sw_pymysql.py
@@ -17,7 +17,6 @@
from skywalking import Layer, Component, config
from skywalking.trace import tags
-from skywalking.trace.carrier import Carrier
from skywalking.trace.context import get_context
from skywalking.trace.tags import Tag
@@ -31,8 +30,7 @@
peer = "%s:%s" % (this.connection.host, this.connection.port)
context = get_context()
- carrier = Carrier()
- with context.new_exit_span(op="Mysql/PyMsql/execute", peer=peer, carrier=carrier) as span:
+ with context.new_exit_span(op="Mysql/PyMsql/execute", peer=peer) as span:
span.layer = Layer.Database
span.component = Component.PyMysql
res = _execute(this, query, args)
diff --git a/skywalking/plugins/sw_rabbitmq.py b/skywalking/plugins/sw_rabbitmq.py
index 9e87642..e595d74 100644
--- a/skywalking/plugins/sw_rabbitmq.py
+++ b/skywalking/plugins/sw_rabbitmq.py
@@ -39,22 +39,18 @@
mandatory=False):
peer = '%s:%s' % (this.connection.params.host, this.connection.params.port)
context = get_context()
- carrier = Carrier()
import pika
with context.new_exit_span(op="RabbitMQ/Topic/" + exchange + "/Queue/" + routing_key + "/Producer" or "/",
- peer=peer, carrier=carrier) as span:
+ peer=peer) as span:
+ carrier = span.inject()
span.layer = Layer.MQ
span.component = Component.RabbitmqProducer
properties = pika.BasicProperties() if properties is None else properties
if properties.headers is None:
- headers = {}
- for item in carrier:
- headers[item.key] = item.val
- properties.headers = headers
- else:
- for item in carrier:
- properties.headers[item.key] = item.val
+ properties.headers = {}
+ for item in carrier:
+ properties.headers[item.key] = item.val
res = _basic_publish(this, exchange,
routing_key,
diff --git a/skywalking/plugins/sw_requests.py b/skywalking/plugins/sw_requests.py
index 52937a5..69397b5 100644
--- a/skywalking/plugins/sw_requests.py
+++ b/skywalking/plugins/sw_requests.py
@@ -17,7 +17,6 @@
from skywalking import Layer, Component
from skywalking.trace import tags
-from skywalking.trace.carrier import Carrier
from skywalking.trace.context import get_context
from skywalking.trace.tags import Tag
from skywalking import config
@@ -44,18 +43,15 @@
hooks, stream, verify, cert, json)
context = get_context()
- carrier = Carrier()
- with context.new_exit_span(op=url_param.path or "/", peer=url_param.netloc, carrier=carrier) as span:
+ with context.new_exit_span(op=url_param.path or "/", peer=url_param.netloc) as span:
+ carrier = span.inject()
span.layer = Layer.Http
span.component = Component.Requests
if headers is None:
headers = {}
- for item in carrier:
- headers[item.key] = item.val
- else:
- for item in carrier:
- headers[item.key] = item.val
+ for item in carrier:
+ headers[item.key] = item.val
span.tag(Tag(key=tags.HttpMethod, val=method.upper()))
span.tag(Tag(key=tags.HttpUrl, val=url))
diff --git a/skywalking/plugins/sw_urllib3.py b/skywalking/plugins/sw_urllib3.py
index 446d453..9f3f1cc 100644
--- a/skywalking/plugins/sw_urllib3.py
+++ b/skywalking/plugins/sw_urllib3.py
@@ -17,7 +17,6 @@
from skywalking import Layer, Component
from skywalking.trace import tags
-from skywalking.trace.carrier import Carrier
from skywalking.trace.context import get_context
from skywalking.trace.tags import Tag
@@ -31,19 +30,16 @@
from urllib.parse import urlparse
url_param = urlparse(url)
- carrier = Carrier()
context = get_context()
- with context.new_exit_span(op=url_param.path or "/", peer=url_param.netloc, carrier=carrier) as span:
+ with context.new_exit_span(op=url_param.path or "/", peer=url_param.netloc) as span:
+ carrier = span.inject()
span.layer = Layer.Http
span.component = Component.Urllib3
if headers is None:
headers = {}
- for item in carrier:
- headers[item.key] = item.val
- else:
- for item in carrier:
- headers[item.key] = item.val
+ for item in carrier:
+ headers[item.key] = item.val
span.tag(Tag(key=tags.HttpMethod, val=method.upper()))
span.tag(Tag(key=tags.HttpUrl, val=url))
diff --git a/skywalking/plugins/sw_urllib_request.py b/skywalking/plugins/sw_urllib_request.py
index 8aec5cd..66d138e 100644
--- a/skywalking/plugins/sw_urllib_request.py
+++ b/skywalking/plugins/sw_urllib_request.py
@@ -19,7 +19,6 @@
from skywalking import Layer, Component
from skywalking.trace import tags
-from skywalking.trace.carrier import Carrier
from skywalking.trace.context import get_context
from skywalking.trace.tags import Tag
@@ -36,9 +35,9 @@
fullurl = Request(fullurl, data)
context = get_context()
- carrier = Carrier()
url = fullurl.selector.split("?")[0] if fullurl.selector else '/'
- with context.new_exit_span(op=url, peer=fullurl.host, carrier=carrier) as span:
+ with context.new_exit_span(op=url, peer=fullurl.host) as span:
+ carrier = span.inject()
span.layer = Layer.Http
span.component = Component.General
code = None
diff --git a/skywalking/trace/carrier.py b/skywalking/trace/carrier.py
index d485ca4..d4866d2 100644
--- a/skywalking/trace/carrier.py
+++ b/skywalking/trace/carrier.py
@@ -44,18 +44,22 @@
class Carrier(CarrierItem):
- def __init__(self):
+ def __init__(self, trace_id: str = '', segment_id: str = '', span_id: str = '', service: str = '',
+ service_instance: str = '', endpoint: str = '', client_address: str = '',
+ correlation: dict = None): # pyre-ignore
super(Carrier, self).__init__(key='sw8')
- self.trace_id = '' # type: str
- self.segment_id = '' # type: str
- self.span_id = '' # type: str
- self.service = '' # type: str
- self.service_instance = '' # type: str
- self.endpoint = '' # type: str
- self.client_address = '' # type: str
+ self.trace_id = trace_id # type: str
+ self.segment_id = segment_id # type: str
+ self.span_id = span_id # type: str
+ self.service = service # type: str
+ self.service_instance = service_instance # type: str
+ self.endpoint = endpoint # type: str
+ self.client_address = client_address # type: str
self.correlation_carrier = SW8CorrelationCarrier()
self.items = [self.correlation_carrier, self] # type: List[CarrierItem]
self.__iter_index = 0 # type: int
+ if correlation is not None:
+ self.correlation_carrier.correlation = correlation
@property
def val(self) -> str:
diff --git a/skywalking/trace/context.py b/skywalking/trace/context.py
index 814fca5..af7d9fa 100644
--- a/skywalking/trace/context.py
+++ b/skywalking/trace/context.py
@@ -74,6 +74,7 @@
self.segment = Segment() # type: Segment
self._sid = Counter()
self._correlation = {} # type: dict
+ self._nspans = 0
def new_local_span(self, op: str) -> Span:
span = self.ignore_check(op, Kind.Local)
@@ -111,7 +112,7 @@
return span
- def new_exit_span(self, op: str, peer: str, carrier: 'Carrier' = None) -> Span:
+ def new_exit_span(self, op: str, peer: str) -> Span:
span = self.ignore_check(op, Kind.Exit)
if span is not None:
return span
@@ -127,9 +128,6 @@
peer=peer,
)
- if carrier is not None:
- span.inject(carrier=carrier)
-
return span
def ignore_check(self, op: str, kind: Kind):
@@ -150,22 +148,23 @@
return None
def start(self, span: Span):
+ self._nspans += 1
spans = _spans()
if span not in spans:
spans.append(span)
def stop(self, span: Span) -> bool:
spans = _spans()
- idx = spans.index(span) # span SHOULD now always be at end even in async-world, but just in case
+ span.finish(self.segment)
+ del spans[spans.index(span)]
- if span.finish(self.segment):
- del spans[idx]
-
- if len(spans) == 0:
+ self._nspans -= 1
+ if self._nspans == 0:
_local().context = None
agent.archive(self.segment)
+ return True
- return len(spans) == 0
+ return False
def active_span(self):
spans = _spans()
@@ -231,10 +230,7 @@
self._noop_span.extract(carrier)
return self._noop_span
- def new_exit_span(self, op: str, peer: str, carrier: 'Carrier' = None) -> Span:
- if carrier is not None:
- self._noop_span.inject(carrier)
-
+ def new_exit_span(self, op: str, peer: str) -> Span:
return self._noop_span
def start(self, span: Span):
diff --git a/skywalking/trace/span.py b/skywalking/trace/span.py
index 2d2fcf5..c193168 100644
--- a/skywalking/trace/span.py
+++ b/skywalking/trace/span.py
@@ -97,7 +97,7 @@
return self
- def inject(self, carrier: 'Carrier') -> 'Span':
+ def inject(self) -> 'Carrier':
raise RuntimeWarning(
'can only inject context carrier into ExitSpan, this may be a potential bug in the agent, '
'please report this in https://github.com/apache/skywalking/issues if you encounter this. '
@@ -126,11 +126,19 @@
@tostring
class StackedSpan(Span):
- _depth = 0
+ def __init__(self, *args, **kwargs):
+ Span.__init__(self, *args, **kwargs)
+ self._depth = 0
- def finish(self, segment: 'Segment') -> bool:
+ def start(self):
+ self._depth += 1
+ if self._depth == 1:
+ Span.start(self)
+
+ def stop(self):
self._depth -= 1
- return self._depth == 0 and Span.finish(self, segment)
+ if self._depth == 0:
+ Span.stop(self)
@tostring
@@ -159,10 +167,8 @@
self._max_depth = 0
def start(self):
- self._depth += 1
+ StackedSpan.start(self)
self._max_depth = self._depth
- if self._max_depth == 1:
- StackedSpan.start(self)
self.component = 0
self.layer = Layer.Unknown
self.logs = []
@@ -206,20 +212,17 @@
layer,
)
- def inject(self, carrier: 'Carrier') -> 'Span':
- carrier.trace_id = str(self.context.segment.related_traces[0])
- carrier.segment_id = str(self.context.segment.segment_id)
- carrier.span_id = str(self.sid)
- carrier.service = config.service_name
- carrier.service_instance = config.service_instance
- carrier.endpoint = self.op
- carrier.client_address = self.peer
- carrier.correlation_carrier.correlation = self.context._correlation
- return self
-
- def start(self):
- self._depth += 1
- StackedSpan.start(self)
+ def inject(self) -> 'Carrier':
+ return Carrier(
+ trace_id=str(self.context.segment.related_traces[0]),
+ segment_id=str(self.context.segment.segment_id),
+ span_id=str(self.sid),
+ service=config.service_name,
+ service_instance=config.service_instance,
+ endpoint=self.op,
+ client_address=self.peer,
+ correlation=self.context._correlation,
+ )
@tostring
@@ -231,6 +234,5 @@
if carrier is not None:
self.context._correlation = carrier.correlation_carrier.correlation
- def inject(self, carrier: 'Carrier') -> 'Span':
- carrier.correlation_carrier.correlation = self.context._correlation
- return self
+ def inject(self) -> 'Carrier':
+ return Carrier(correlation=self.context._correlation)