blob: 9f3e236826ac2458ab1a8859fc79e0f898f98c1d [file] [log] [blame]
import asyncio
import logging
from contextlib import asynccontextmanager

import aiohttp
import fastapi
from aiohttp import client_exceptions
from hamilton import async_driver
from hamilton_sdk import adapters

import async_module
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# can instantiate a driver once for the life of the app:
# Both names start as None and are assigned by `lifespan` at application startup.
# `dr_with_tracking` is left as None when the tracking server was not reachable
# at startup; `/execute_tracker` checks for that before using it.
dr_with_tracking: async_driver.AsyncDriver = None
dr_without_tracking: async_driver.AsyncDriver = None
async def _tracking_server_running(
    url: str = "http://localhost:8241/api/v1/ping",
    timeout_seconds: float = 5.0,
) -> bool:
    """Report whether the tracking server answers its ping endpoint.

    Args:
        url: Ping endpoint to probe. Defaults to the local tracking server.
        timeout_seconds: Upper bound, in seconds, for the whole request
            (connect + response). Keeps application startup from stalling
            when something accepts the connection but never replies.

    Returns:
        True only when the endpoint responds with HTTP 200. False for any
        other status, for any aiohttp client error (connection refused,
        malformed response, ...), or when the request times out.
    """
    timeout = aiohttp.ClientTimeout(total=timeout_seconds)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as response:
                return response.status == 200
    except (client_exceptions.ClientError, asyncio.TimeoutError):
        # This is a best-effort health probe: any failure to get a clean
        # answer means "treat the server as not running", not "crash startup".
        return False
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """FastAPI lifespan hook that builds the module-level Hamilton drivers.

    Building a driver is awaited, so the setup has to happen inside an async
    context rather than at import time. The untracked driver is always built;
    the tracked driver is only built when the tracking server responded to the
    startup probe, otherwise it is set to None and a warning is logged.
    """
    global dr_with_tracking, dr_without_tracking

    builder = async_driver.Builder().with_modules(async_module)
    is_server_running = await _tracking_server_running()

    # Build the plain driver first, before any adapter is attached to the builder.
    dr_without_tracking = await builder.build()

    if is_server_running:
        tracker = adapters.AsyncHamiltonTracker(
            project_id=1,
            username="elijah",
            dag_name="async_tracker",
        )
        dr_with_tracking = await builder.with_adapters(tracker).build()
    else:
        dr_with_tracking = None
        logger.warning(
            "Tracking server is not running, skipping telemetry. To run the telemetry server, run hamilton ui. "
            "Note you must have a project with ID 1 if it is running -- if not, you can change the project "
            "ID in this file or create a new one from the UI. Then make sure to restart this server."
        )

    # Hand control back to FastAPI; the app serves requests from here on.
    yield
# The application object; `lifespan` runs at startup to build the drivers.
app = fastapi.FastAPI(lifespan=lifespan)
@app.post("/execute")
async def call_without_tracker(request: fastapi.Request) -> dict:
    """Run the "pipeline" node on the untracked driver and return its result.

    The incoming request object is passed to the driver as the ``request``
    input. Nothing from this call is reported to the Hamilton UI.
    """
    # dr.visualize_execution(["pipeline"], "./pipeline.dot", {"format": "png"}, inputs=input_data)
    return await dr_without_tracking.execute(["pipeline"], inputs={"request": request})
@app.post("/execute_tracker")
async def call_with_tracker(request: fastapi.Request) -> dict:
    """Run the "pipeline" node on the tracked driver and return its result.

    The incoming request object is passed to the driver as the ``request``
    input, and the run is tracked in the Hamilton UI.

    Raises:
        fastapi.HTTPException: 503 when the tracking driver was not built,
            i.e. the tracking server was not reachable at app startup.
    """
    # Check availability first so no work is done for a request that cannot be served.
    if dr_with_tracking is None:
        # Answer with an explicit 503 and a readable message; an uncaught
        # ValueError would reach the client as an opaque 500.
        raise fastapi.HTTPException(
            status_code=503,
            detail=(
                "Tracking driver not initialized -- you must have the tracking server running at app startup to use this endpoint."
            ),
        )
    input_data = {"request": request}
    # dr.visualize_execution(["pipeline"], "./pipeline.dot", {"format": "png"}, inputs=input_data)
    return await dr_with_tracking.execute(["pipeline"], inputs=input_data)
if __name__ == "__main__":
    # If you run this as a script, the app is served by uvicorn on port 8000.
    # Host "0.0.0.0" binds all network interfaces, not only localhost.
    # uvicorn is imported here so it is only needed when running as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)