| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Example of invisibly propagating a request ID with middleware.""" |
| |
| import argparse |
| import sys |
| import threading |
| import uuid |
| |
| import pyarrow as pa |
| import pyarrow.flight as flight |
| |
| |
class TraceContext:
    """Thread-local storage for the trace ID of the work in progress.

    Every thread has an independent slot. Asking for the current ID on a
    thread whose slot is empty (missing, ``None`` or otherwise falsy)
    generates a random hex UUID and stores it first.
    """

    # Attributes set on a threading.local instance are per-thread.
    _locals = threading.local()
    # This only seeds the slot for the thread that executes the class body;
    # other threads fall back to the getattr default in current_trace_id.
    _locals.trace_id = None

    @classmethod
    def current_trace_id(cls):
        """Return the calling thread's trace ID, minting one if absent."""
        stored = getattr(cls._locals, "trace_id", None)
        if stored:
            return stored
        cls.set_trace_id(uuid.uuid4().hex)
        return cls._locals.trace_id

    @classmethod
    def set_trace_id(cls, trace_id):
        """Store ``trace_id`` as the calling thread's trace ID."""
        cls._locals.trace_id = trace_id
| |
| |
# Header name under which the trace ID travels between client and server.
# Incoming headers are looked up by this exact key.
TRACE_HEADER = "x-tracing-id"
| |
| |
class TracingServerMiddleware(flight.ServerMiddleware):
    """Per-call server middleware that echoes the call's trace ID back.

    ``trace_id`` is exposed as an attribute so handlers can retrieve it via
    ``context.get_middleware(...)``.
    """

    def __init__(self, trace_id):
        self.trace_id = trace_id

    def sending_headers(self):
        """Return the response headers: just the trace ID."""
        outgoing = {}
        outgoing[TRACE_HEADER] = self.trace_id
        return outgoing
| |
| |
class TracingServerMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Create a TracingServerMiddleware for every incoming call.

    The trace ID is taken from the client's TRACE_HEADER when present;
    otherwise (or if the header value is empty) a fresh random ID is minted
    for the call.
    """

    def start_call(self, info, headers):
        """Build the tracing middleware for one call.

        :param info: call information, only logged here.
        :param headers: incoming headers mapping a name to a list of values.
        :returns: a TracingServerMiddleware holding this call's trace ID.
        """
        print("Starting new call:", info)
        if TRACE_HEADER in headers:
            trace_id = headers[TRACE_HEADER][0]
            print("Found trace header with value:", trace_id)
        else:
            # TraceContext is thread-local: if this thread previously handled
            # a traced call, its ID is still stored. Clearing the slot makes
            # current_trace_id() mint a fresh ID instead of reusing a stale one.
            trace_id = None
        TraceContext.set_trace_id(trace_id)
        return TracingServerMiddleware(TraceContext.current_trace_id())
| |
| |
class TracingClientMiddleware(flight.ClientMiddleware):
    """Client-side middleware that propagates the current trace ID.

    The ID is read from TraceContext when request headers are built. It is
    whatever ID the thread running sending_headers holds, or a freshly minted
    one (presumably the thread issuing the RPC; confirm against pyarrow).
    Any trace ID echoed back by the server is only logged, never adopted.
    """

    def sending_headers(self):
        """Return request headers carrying the current trace ID."""
        # Read once so the logged value and the sent value are the same call.
        trace_id = TraceContext.current_trace_id()
        print("Sending trace ID:", trace_id)
        return {
            TRACE_HEADER: trace_id,
        }

    def received_headers(self, headers):
        """Log the trace ID returned by the server, if any."""
        if TRACE_HEADER in headers:
            trace_id = headers[TRACE_HEADER][0]
            print("Found trace header with value:", trace_id)
            # Don't overwrite our trace ID: the client keeps the one it sent.
| |
| |
class TracingClientMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Attach a new TracingClientMiddleware to each outgoing call."""

    def start_call(self, info):
        """Log the call and return its tracing middleware."""
        print("Starting new call:", info)
        middleware = TracingClientMiddleware()
        return middleware
| |
| |
class FlightServer(flight.FlightServerBase):
    """Flight server exposing a single "get-trace-id" action.

    With a delegate location, the action is forwarded to that server through
    a client that carries TracingClientMiddlewareFactory, demonstrating trace
    ID propagation between servers. Without one, the server answers with the
    current call's trace ID.
    """

    def __init__(self, delegate, **kwargs):
        """Create the server.

        :param delegate: location of another Flight server to forward actions
            to, or a falsy value to answer locally.
        :param kwargs: forwarded to ``FlightServerBase`` (e.g. ``location``,
            ``middleware``).
        """
        super().__init__(**kwargs)
        if delegate:
            self.delegate = flight.connect(
                delegate,
                middleware=(TracingClientMiddlewareFactory(),))
        else:
            self.delegate = None

    def list_actions(self, context):
        """Advertise the actions this server supports."""
        return [
            ("get-trace-id", "Get the trace context ID."),
        ]

    def do_action(self, context, action):
        """Handle an action, yielding result objects.

        :raises KeyError: if ``action.type`` is not "get-trace-id".
        """
        # TraceContext is thread-local, so re-establish this call's ID on the
        # handler thread from the middleware registered under "trace".
        trace_middleware = context.get_middleware("trace")
        if trace_middleware:
            TraceContext.set_trace_id(trace_middleware.trace_id)
        if action.type != "get-trace-id":
            raise KeyError(f"Unknown action {action.type!r}")
        if self.delegate:
            # The delegate client's TracingClientMiddleware reads TraceContext
            # when sending headers, carrying the ID downstream (assumes it
            # runs on this thread; confirm against pyarrow).
            yield from self.delegate.do_action(action)
        else:
            trace_id = TraceContext.current_trace_id().encode("utf-8")
            print("Returning trace ID:", trace_id)
            buf = pa.py_buffer(trace_id)
            # Use the imported `flight` alias rather than relying on
            # `pa.flight` having been bound as a side effect of that import.
            yield flight.Result(buf)
| |
| |
def main():
    """Run this example as a Flight client or server.

    Prints help and returns 1 when no sub-command is given. Otherwise it
    returns None, which the script entry point turns into exit status 0.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command")

    client_parser = subparsers.add_parser("client", help="Run the client.")
    client_parser.add_argument("server")
    client_parser.add_argument("--request-id", default=None)

    server_parser = subparsers.add_parser("server", help="Run the server.")
    server_parser.add_argument(
        "--listen",
        required=True,
        help="The location to listen on (example: grpc://localhost:5050)",
    )
    server_parser.add_argument(
        "--delegate",
        required=False,
        default=None,
        help=("A location to delegate to. That is, this server will "
              "simply call the given server for the response. Demonstrates "
              "propagation of the trace ID between servers."),
    )

    args = parser.parse_args()
    command = args.command
    if not command:
        parser.print_help()
        return 1

    if command == "server":
        flight_server = FlightServer(
            args.delegate,
            location=args.listen,
            middleware={"trace": TracingServerMiddlewareFactory()},
        )
        flight_server.serve()
    elif command == "client":
        flight_client = flight.connect(
            args.server,
            middleware=(TracingClientMiddlewareFactory(),),
        )
        # Store the ID before the call so TracingClientMiddleware picks it
        # up from TraceContext when building request headers.
        TraceContext.set_trace_id(args.request_id or "client-chosen-id")

        request = flight.Action("get-trace-id", b"")
        for result in flight_client.do_action(request):
            print(result.body.to_pybytes())
| |
| |
if __name__ == "__main__":
    # main() returns None on success; report that as exit status 0.
    exit_status = main()
    sys.exit(exit_status if exit_status else 0)