blob: 2056bae1f9155b2a769e0b52eb9c06d39c3c1363 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example of invisibly propagating a request ID with middleware."""
import argparse
import sys
import threading
import uuid
import pyarrow as pa
import pyarrow.flight as flight
class TraceContext:
    """Hold the trace ID of the current thread in thread-local storage."""

    _locals = threading.local()
    # This assignment only affects the thread that defines the class; other
    # threads fall back to the getattr() default in current_trace_id().
    _locals.trace_id = None

    @classmethod
    def current_trace_id(cls):
        """Return this thread's trace ID, generating a new one if unset.

        A missing or falsy ID is replaced by a random UUID4 hex string,
        which is stored so later calls on the same thread return it too.
        """
        existing = getattr(cls._locals, "trace_id", None)
        if existing:
            return existing
        cls.set_trace_id(uuid.uuid4().hex)
        return cls._locals.trace_id

    @classmethod
    def set_trace_id(cls, trace_id):
        """Store ``trace_id`` as the trace ID of the calling thread."""
        cls._locals.trace_id = trace_id
# Name of the header that carries the trace ID between client and server.
TRACE_HEADER = "x-tracing-id"
class TracingServerMiddleware(flight.ServerMiddleware):
    """Server-side middleware instance that carries one call's trace ID."""

    def __init__(self, trace_id):
        """Remember the trace ID associated with this call."""
        self.trace_id = trace_id

    def sending_headers(self):
        """Return the headers to send, containing this call's trace ID."""
        outgoing = {TRACE_HEADER: self.trace_id}
        return outgoing
class TracingServerMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Create a TracingServerMiddleware for each incoming call."""

    def start_call(self, info, headers):
        """Pick the trace ID for a new call and wrap it in middleware.

        If the request carries the trace header, its first value becomes
        the trace ID. Otherwise a fresh ID is generated.

        Parameters
        ----------
        info
            Call information supplied by Flight; only printed here.
        headers
            Mapping of header name to a list of values.

        Returns
        -------
        TracingServerMiddleware
            Middleware holding the trace ID chosen for this call.
        """
        print("Starting new call:", info)
        if TRACE_HEADER in headers:
            trace_id = headers[TRACE_HEADER][0]
            print("Found trace header with value:", trace_id)
            TraceContext.set_trace_id(trace_id)
        else:
            # Drop any ID left in this thread's storage by an earlier call;
            # otherwise unrelated requests served by the same thread would
            # share a trace ID. current_trace_id() then generates a new one.
            TraceContext.set_trace_id(None)
        return TracingServerMiddleware(TraceContext.current_trace_id())
class TracingClientMiddleware(flight.ClientMiddleware):
    """Client-side middleware that attaches the current trace ID to a call."""

    def sending_headers(self):
        """Return the request headers, carrying the current trace ID."""
        # Look the ID up once so the printed and the sent value always match,
        # and use the shared constant so client and server agree on the
        # header name.
        trace_id = TraceContext.current_trace_id()
        print("Sending trace ID:", trace_id)
        return {
            TRACE_HEADER: trace_id,
        }

    def received_headers(self, headers):
        """Report the trace ID found in the response headers, if any.

        ``headers`` maps a header name to a list of values; only the first
        value of the trace header is printed.
        """
        if TRACE_HEADER in headers:
            trace_id = headers[TRACE_HEADER][0]
            print("Found trace header with value:", trace_id)
            # Don't overwrite our trace ID
class TracingClientMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Create a TracingClientMiddleware for each outgoing call."""

    def start_call(self, info):
        """Log the call information and return a new client middleware."""
        print("Starting new call:", info)
        middleware = TracingClientMiddleware()
        return middleware
class FlightServer(flight.FlightServerBase):
    """Flight server exposing a single ``get-trace-id`` action.

    When a delegate location is given, the action is forwarded to that
    server instead of being answered locally.
    """

    def __init__(self, delegate, **kwargs):
        """Start the base server and optionally connect to a delegate.

        ``delegate`` is a location to forward actions to, or a falsy value
        for none. Remaining keyword arguments go to ``FlightServerBase``.
        """
        super().__init__(**kwargs)
        self.delegate = (
            flight.connect(
                delegate,
                middleware=(TracingClientMiddlewareFactory(),),
            )
            if delegate
            else None
        )

    def list_actions(self, context):
        """Return the (name, description) pairs of the supported actions."""
        supported = [("get-trace-id", "Get the trace context ID.")]
        return supported

    def do_action(self, context, action):
        """Yield the results of ``action``.

        Raises ``KeyError`` for any action type other than
        ``get-trace-id``.
        """
        tracing = context.get_middleware("trace")
        if tracing:
            # NOTE(review): presumably needed because this handler may run on
            # a different thread than the middleware's start_call - confirm.
            TraceContext.set_trace_id(tracing.trace_id)

        if action.type != "get-trace-id":
            raise KeyError(f"Unknown action {action.type!r}")

        if self.delegate:
            # Forward to the delegate server and relay its results unchanged.
            yield from self.delegate.do_action(action)
            return

        trace_id = TraceContext.current_trace_id().encode("utf-8")
        print("Returning trace ID:", trace_id)
        yield flight.Result(pa.py_buffer(trace_id))
def main():
    """Parse the command line and run either the server or the client.

    Returns 1 after printing the help text when no sub-command is given;
    otherwise returns None.
    """
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest="command")

    client_cmd = commands.add_parser("client", help="Run the client.")
    client_cmd.add_argument("server")
    client_cmd.add_argument("--request-id", default=None)

    server_cmd = commands.add_parser("server", help="Run the server.")
    server_cmd.add_argument(
        "--listen",
        required=True,
        help="The location to listen on (example: grpc://localhost:5050)",
    )
    server_cmd.add_argument(
        "--delegate",
        required=False,
        default=None,
        help=("A location to delegate to. That is, this server will "
              "simply call the given server for the response. Demonstrates "
              "propagation of the trace ID between servers."),
    )

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        return 1

    if args.command == "server":
        flight_server = FlightServer(
            args.delegate,
            location=args.listen,
            middleware={"trace": TracingServerMiddlewareFactory()},
        )
        flight_server.serve()
    elif args.command == "client":
        connection = flight.connect(
            args.server,
            middleware=(TracingClientMiddlewareFactory(),),
        )
        # Use the requested ID if one was given, else a fixed placeholder.
        TraceContext.set_trace_id(args.request_id or "client-chosen-id")
        request = flight.Action("get-trace-id", b"")
        for result in connection.do_action(request):
            print(result.body.to_pybytes())
if __name__ == "__main__":
    # main() returns None on success, so map any falsy result to status 0.
    exit_status = main() or 0
    sys.exit(exit_status)