blob: 3cb6a44068306d90c5705e936aedff516342506c [file]
"""FastAPI routes for Supervisor Agent"""
import json
import uuid
import time
import asyncio
from pathlib import Path
from typing import AsyncGenerator, Dict, Any
from fastapi import APIRouter
from fastapi.responses import HTMLResponse, JSONResponse
from sse_starlette.sse import EventSourceResponse
from common.rocketmq.rocketmq_utils import logger
from supervisor_agent.utils.constants.constants import (
SESSION_KEY_TRACE_ID,
SESSION_KEY_USER_INPUT,
SESSION_KEY_CREATED_AT,
SESSION_KEY_STATUS,
SESSION_KEY_DISCONNECTED_AT,
SESSION_KEY_WEATHER_TRACE_ID,
SESSION_KEY_TRAVEL_TRACE_ID,
SESSION_KEY_INTENT,
SESSION_STATUS_ACTIVE,
SESSION_STATUS_DISCONNECTED,
SESSION_STATUS_COMPLETED,
STATE_KEY_WEATHER_TRACE_ID,
STATE_KEY_TRAVEL_TRACE_ID,
MSG_METADATA_IS_FINAL,
MSG_METADATA_ERROR,
MSG_METADATA_CHUNK_INDEX,
SSE_EVENT_TYPE_START,
SSE_EVENT_TYPE_CHUNK,
SSE_EVENT_TYPE_ERROR,
SSE_EVENT_TYPE_RECONNECTED,
SSE_EVENT_DONE,
TRACE_PREFIX_MAIN,
NODE_CHAT, STATE_KEY_TRACE_ID, STATE_KEY_SESSION_ID, STATE_KEY_USER_INPUT, STATE_KEY_INTENT, STATE_KEY_CITY,
STATE_KEY_DATE_INFO, STATE_KEY_WEATHER_DATA, STATE_KEY_FINAL_RESPONSE, STATE_KEY_WEATHER_COMPLETE
)
from supervisor_agent.utils.rocketmq.mq_service import subscribe_lite_topic, unsubscribe_lite_topic
from supervisor_agent.utils.stream.stream_manager import stream_queue_manager
from supervisor_agent.utils.workflow.workflow_graph import build_workflow
from supervisor_agent.utils.session.session_manager import session_manager
# Router that every endpoint in this module is attached to via @router.get / @router.post.
router = APIRouter()
# The workflow graph is built once, at import time; execute_langgraph_workflow
# streams it with app_graph.astream(...).
app_graph = build_workflow()
# Get the project root directory (the parent of the directory that contains this file)
PROJECT_ROOT = Path(__file__).parent.parent
# Frontend page returned by the "/" route.
STATIC_FILE = PROJECT_ROOT / "static" / "index.html"
@router.get("/", response_class=HTMLResponse)
async def read_root():
    """Return the text of the static ``index.html`` file as the HTML response body."""
    page_html = STATIC_FILE.read_text(encoding="utf-8")
    return page_html
# ==================== Core Business Logic ====================
async def execute_langgraph_workflow(session_id: str, user_input: str, main_trace_id: str) -> Dict[str, Any]:
    """
    Execute LangGraph workflow and register sub-traces.

    Streams ``app_graph`` from a fresh initial state. Whenever a node's output
    carries a weather or travel trace id, that id is registered as a sub-trace
    of ``main_trace_id`` and stored in the session metadata (when the session
    has metadata) so a later reconnect can look it up.

    Args:
        session_id: Session identifier
        user_input: User's input message
        main_trace_id: Main trace ID for this conversation

    Returns:
        Dict with ``"active_traces"`` (set of every sub-trace id seen) and
        ``"is_chat_mode"`` (True if the chat node was executed).

    Raises:
        Exception: Any error raised while streaming the graph is logged and re-raised.
    """
    initial_state = {
        STATE_KEY_TRACE_ID: main_trace_id,
        STATE_KEY_SESSION_ID: session_id,
        STATE_KEY_USER_INPUT: user_input,
        STATE_KEY_INTENT: "",
        STATE_KEY_CITY: "",
        STATE_KEY_DATE_INFO: "",
        STATE_KEY_WEATHER_DATA: "",
        STATE_KEY_FINAL_RESPONSE: "",
        STATE_KEY_WEATHER_TRACE_ID: None,
        STATE_KEY_TRAVEL_TRACE_ID: None,
        STATE_KEY_WEATHER_COMPLETE: False
    }
    # (key in node output, key in session metadata, label used in log lines)
    sub_trace_specs = (
        (STATE_KEY_WEATHER_TRACE_ID, SESSION_KEY_WEATHER_TRACE_ID, "weather"),
        (STATE_KEY_TRAVEL_TRACE_ID, SESSION_KEY_TRAVEL_TRACE_ID, "travel"),
    )
    active_traces = set()
    is_chat_mode = False
    try:
        async for event in app_graph.astream(initial_state):
            for node_name, output in event.items():
                logger.info(f"Graph node executed: {node_name}, output: {output}")
                # Detect chat mode
                if node_name == NODE_CHAT:
                    is_chat_mode = True
                    logger.info("Detected chat mode, will wait for streaming completion")
                # A node that produced no update has nothing to register; without
                # this guard the membership test below raises TypeError on None.
                if output is None:
                    continue
                for state_key, session_key, label in sub_trace_specs:
                    sub_tid = output[state_key] if state_key in output else None
                    if not sub_tid:
                        continue
                    # Register the sub-trace for streaming
                    active_traces.add(sub_tid)
                    stream_queue_manager.register_sub_trace(sub_tid, main_trace_id)
                    logger.info(f"Registered {label} trace: {sub_tid} -> {main_trace_id}")
                    # Save to session metadata for reconnection
                    metadata = session_manager.get_session_metadata(session_id)
                    if metadata:
                        metadata[session_key] = sub_tid
                        session_manager.add_session(session_id, metadata)
                        logger.info(f"[Session] Saved {label}_trace_id to metadata: {sub_tid}")
    except Exception as e:
        logger.error(f"Graph execution error: {e}", exc_info=True)
        raise
    return {
        "active_traces": active_traces,
        "is_chat_mode": is_chat_mode
    }
async def stream_messages_to_client(
    main_trace_id: str,
    response_queue: asyncio.Queue,
    graph_task: asyncio.Task = None,
    max_timeout: float = 120.0
) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Relay queued message payloads to the client as SSE events.

    The loop ends when a direct chat response is available, when the main
    trace reports its final chunk, when every known sub-trace has finished,
    when the graph task is done with nothing left to wait for, or when
    ``max_timeout`` seconds have elapsed.

    Args:
        main_trace_id: Main trace ID
        response_queue: Queue containing message payloads
        graph_task: Background workflow task; its result supplies the set of
            sub-traces that must finish before the stream may end
        max_timeout: Maximum timeout in seconds

    Yields:
        SSE event dictionaries of the form ``{"data": <json string>}``
    """
    expected_traces = set()
    finished_traces = set()
    chat_stream_completed = False
    started_at = time.time()

    def _chunk_event(role, content, chunk_index, is_final, sub_trace_id):
        """Build one chunk event; key order matches the wire format."""
        return {"data": json.dumps({
            "type": SSE_EVENT_TYPE_CHUNK,
            "role": role,
            "content": content,
            "chunk_index": chunk_index,
            "is_final": is_final,
            "sub_trace_id": sub_trace_id
        })}

    def _timeout_event():
        """Build the error event sent when the overall deadline is exceeded."""
        return {"data": json.dumps({"type": SSE_EVENT_TYPE_ERROR, "content": "响应超时"})}

    def _every_expected_trace_finished():
        """True only when at least one sub-trace is known and all of them are done."""
        return bool(expected_traces) and finished_traces >= expected_traces

    def _sync_active_traces_from_graph() -> None:
        """Copy the sub-trace set out of a finished graph_task (only if none is known yet)."""
        nonlocal expected_traces
        if expected_traces or not graph_task or not graph_task.done():
            return
        try:
            result = graph_task.result()
            if result and result.get("active_traces"):
                expected_traces = set(result["active_traces"])
                logger.info(f"Synced active_traces from graph_task: {expected_traces}")
        except Exception as e:
            logger.debug(f"Could not retrieve graph_task result yet: {e}")

    while True:
        # A direct chat response short-circuits the queue entirely.
        direct_reply = stream_queue_manager.get_chat_response(main_trace_id)
        if direct_reply:
            logger.info(f"Sending chat response: {direct_reply}")
            yield _chunk_event("assistant", direct_reply, 0, True, main_trace_id)
            break

        try:
            # Enforce the overall deadline before blocking on the queue.
            time_left = max_timeout - (time.time() - started_at)
            if time_left <= 0:
                logger.warning(f"Overall timeout reached for trace_id: {main_trace_id}")
                yield _timeout_event()
                break

            # Short poll (at most 1s) so a finished graph is noticed quickly.
            payload = await asyncio.wait_for(response_queue.get(), timeout=min(time_left, 1.0))

            meta = payload.metadata or {}
            is_final = meta.get(MSG_METADATA_IS_FINAL, False)
            is_error = meta.get(MSG_METADATA_ERROR, False)
            chunk_index = meta.get(MSG_METADATA_CHUNK_INDEX, 0)
            sub_trace_id = payload.trace_id
            role = payload.role.value if hasattr(payload.role, 'value') else str(payload.role)

            # An error message ends its trace; forward it and keep listening.
            if is_error:
                yield {"data": json.dumps({
                    "type": SSE_EVENT_TYPE_ERROR,
                    "role": role,
                    "content": payload.content,
                    "chunk_index": chunk_index,
                    "sub_trace_id": sub_trace_id
                })}
                finished_traces.add(sub_trace_id)
                continue

            # Payloads without content (e.g. a bare final marker) produce no event.
            if payload.content:
                yield _chunk_event(role, payload.content, chunk_index, is_final, sub_trace_id)

            if is_final:
                if sub_trace_id == main_trace_id:
                    # The main chat stream delivered its last chunk.
                    logger.info(f"Chat streaming completed for main trace: {main_trace_id}")
                    chat_stream_completed = True
                    break
                # A sub-trace finished; refresh the expected set if the graph is done.
                finished_traces.add(sub_trace_id)
                _sync_active_traces_from_graph()
                logger.info(
                    f"Sub-trace completed: {sub_trace_id}, total completed: {len(finished_traces)}/{len(expected_traces)}")
                if _every_expected_trace_finished():
                    logger.info(f"All sub-traces completed for main trace: {main_trace_id}")
                    break
        except asyncio.TimeoutError:
            # Queue was idle for the poll interval: look for a chat response again.
            direct_reply = stream_queue_manager.get_chat_response(main_trace_id)
            if direct_reply:
                yield _chunk_event("assistant", direct_reply, 0, True, main_trace_id)
                break

            # Exit conditions, checked in priority order.
            if chat_stream_completed:
                logger.info("Chat stream already completed, exiting loop")
                break
            if graph_task and graph_task.done():
                # Graph finished: stop once nothing is outstanding, otherwise keep
                # waiting for the remaining sub-trace messages.
                _sync_active_traces_from_graph()
                if not expected_traces or finished_traces >= expected_traces:
                    logger.info(f"Graph task completed, ending stream for trace_id: {main_trace_id}")
                    break
                continue
            if not expected_traces:
                continue
            if _every_expected_trace_finished():
                break
            if time.time() - started_at > max_timeout:
                logger.warning(f"Timeout waiting for messages for trace_id: {main_trace_id}")
                yield _timeout_event()
                break
async def stream_reconnected_messages(
    main_trace_id: str,
    response_queue: asyncio.Queue,
    intent: str,
    max_timeout: float = 120.0
) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Stream messages for reconnected client.

    Every payload taken from the queue is forwarded as one SSE event. The
    stream stops after a final message whose trace id starts with ``intent``,
    or once ``max_timeout`` seconds have passed.

    Args:
        main_trace_id: Main trace ID (used for logging only)
        response_queue: Queue containing message payloads
        intent: Intent type for determining final trace; when empty, no trace
            is treated as the last one and the stream runs until the timeout
        max_timeout: Maximum timeout in seconds

    Yields:
        SSE event dictionaries of the form ``{"data": <json string>}``
    """
    start_time = time.time()
    while True:
        remaining = max_timeout - (time.time() - start_time)
        if remaining <= 0:
            logger.warning(f"[Reconnect] Timeout waiting for messages for trace_id: {main_trace_id}")
            break
        try:
            # Wait for the next queued payload in slices of at most 5s, never
            # longer than what is left of the overall budget.
            payload = await asyncio.wait_for(response_queue.get(), timeout=min(remaining, 5.0))
        except asyncio.TimeoutError:
            # Nothing arrived in this slice; re-check the deadline and wait again.
            continue

        msg_metadata = payload.metadata or {}
        is_final = msg_metadata.get(MSG_METADATA_IS_FINAL, False)
        is_error = msg_metadata.get(MSG_METADATA_ERROR, False)
        chunk_index = msg_metadata.get(MSG_METADATA_CHUNK_INDEX, 0)
        trace_id = payload.trace_id
        # Check if this is the final trace based on intent (safe when either
        # the intent or the payload's trace id is missing).
        is_last_trace = bool(intent and trace_id and trace_id.startswith(intent))
        role = payload.role.value if hasattr(payload.role, 'value') else str(payload.role)

        # Send payload to frontend
        yield {"data": json.dumps({
            "type": SSE_EVENT_TYPE_ERROR if is_error else SSE_EVENT_TYPE_CHUNK,
            "role": role,
            "content": payload.content,
            "chunk_index": chunk_index,
            "is_final": is_final,
            "sub_trace_id": trace_id
        })}

        # Stop if final message received for the last trace
        if is_final and is_last_trace:
            logger.info(f"[Reconnect] Final message received for trace: {main_trace_id}")
            break
# ==================== SSE Event Generators ====================
# Strong references to workflow tasks that may outlive the SSE generator that
# started them. The event loop itself only holds weak references to tasks, so
# without an owner a still-running task could be garbage-collected.
_background_graph_tasks: set = set()


async def create_chat_event_generator(session_id: str, user_input: str, main_trace_id: str) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Create event generator for new chat sessions.

    Emits a start event, runs the LangGraph workflow as a background task,
    relays its messages, marks the session completed, and finishes with the
    DONE sentinel.

    Args:
        session_id: Session identifier
        user_input: User's input message
        main_trace_id: Main trace ID for this conversation

    Yields:
        SSE event dictionaries of the form ``{"data": ...}``
    """
    yield {"data": json.dumps({"type": SSE_EVENT_TYPE_START, "trace_id": main_trace_id})}
    # Register response queue
    response_queue = stream_queue_manager.register_trace(main_trace_id)
    graph_task = None
    try:
        # Start LangGraph workflow in background and keep it referenced until done
        graph_task = asyncio.create_task(
            execute_langgraph_workflow(session_id, user_input, main_trace_id)
        )
        _background_graph_tasks.add(graph_task)
        graph_task.add_done_callback(_background_graph_tasks.discard)
        # Stream messages to client
        async for event in stream_messages_to_client(main_trace_id, response_queue, graph_task):
            yield event
        # Streaming completed normally - mark session as completed
        metadata = session_manager.get_session_metadata(session_id)
        if metadata:
            metadata[SESSION_KEY_STATUS] = SESSION_STATUS_COMPLETED
            session_manager.add_session(session_id, metadata)
            logger.info(f"[Chat] Session marked as completed: {session_id}")
        # Wait for graph task to finish. shield() is required: a bare wait_for
        # cancels the awaited task on timeout (and when this generator itself
        # is cancelled), which would stop the workflow we want to keep alive.
        try:
            await asyncio.wait_for(asyncio.shield(graph_task), timeout=10.0)
        except asyncio.TimeoutError:
            logger.warning("Graph task timeout, continuing...")
    except Exception as e:
        logger.error(f"Event generator error: {e}", exc_info=True)
        yield {"data": json.dumps({"type": SSE_EVENT_TYPE_ERROR, "content": str(e)})}
    finally:
        # NOTE: Do NOT cancel graph_task on client disconnect.
        # Keep the LangGraph workflow running in the background, blocked on
        # result_store polling. When the user calls /reconnect, the consumer
        # re-subscribes the per-session lite_topic, the RocketMQ broker
        # redelivers the messages buffered during the disconnected window
        # (resuming from the consumer-group offset), result_store gets
        # populated, and the workflow naturally advances (e.g. weather ->
        # travel). The node-level 300s timeout still bounds the worst case.
        if graph_task is not None and not graph_task.done():
            logger.info(
                f"[Disconnect] Client disconnected, keep graph_task running for trace_id: {main_trace_id}"
            )
        # Clean up response queue
        stream_queue_manager.unregister_trace(main_trace_id, response_queue)
    # Sent after cleanup and outside ``finally`` so that a cancelled or closed
    # generator never reaches a yield while it is being torn down.
    yield {"data": SSE_EVENT_DONE}
async def create_reconnect_event_generator(session_id: str, main_trace_id: str, intent: str) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Build the SSE event stream for a client that reconnects to a session.

    Sends a reconnection confirmation, re-subscribes the session's topic,
    relays queued messages through ``stream_reconnected_messages`` and ends
    with the DONE sentinel.

    Args:
        session_id: Session identifier
        main_trace_id: Main trace ID stored for the session
        intent: Intent value forwarded to ``stream_reconnected_messages``

    Yields:
        SSE event dictionaries of the form ``{"data": ...}``
    """
    # Tell the client the reconnection was accepted.
    confirmation = {
        "type": SSE_EVENT_TYPE_RECONNECTED,
        "session_id": session_id,
        "trace_id": main_trace_id
    }
    yield {"data": json.dumps(confirmation)}

    # The queue must exist before the subscription starts delivering messages.
    queue = stream_queue_manager.register_trace(main_trace_id)
    try:
        subscribe_lite_topic(session_id)
        logger.info(f"[Reconnect] Re-subscribed to session: {session_id}")
        message_stream = stream_reconnected_messages(main_trace_id, queue, intent)
        async for sse_event in message_stream:
            yield sse_event
    except Exception as e:
        logger.error(f"[Reconnect] Stream error: {e}", exc_info=True)
        failure = {"type": SSE_EVENT_TYPE_ERROR, "content": str(e)}
        yield {"data": json.dumps(failure)}
    finally:
        # Always release the queue, whether the stream ended, failed or was closed.
        logger.info(f"[Reconnect] Cleanup response queue for trace_id: {main_trace_id}")
        stream_queue_manager.unregister_trace(main_trace_id, queue)
    # Final sentinel, emitted once cleanup has run.
    yield {"data": SSE_EVENT_DONE}
# ==================== API Endpoints ====================
@router.post("/chat")
async def chat(request: dict):
    """
    Chat endpoint with SSE streaming support.

    Reads ``message`` and ``session_id`` from the JSON body, creates a new
    main trace id, stores the session metadata, subscribes the session's
    topic, and returns the SSE stream for the conversation.

    Args:
        request: JSON body; ``message`` is required, ``session_id`` defaults to ""

    Returns:
        An ``EventSourceResponse`` streaming the chat events, or a 400
        ``JSONResponse`` when ``message`` is missing or blank.
    """
    user_input = request.get("message")
    session_id = request.get("session_id", "")
    # Reject unusable input up front instead of registering a session and
    # starting a workflow whose user input would be None or empty.
    if not isinstance(user_input, str) or not user_input.strip():
        logger.warning(f"[Chat] Rejected request without a message, session: {session_id}")
        return JSONResponse(content={
            "status": "error",
            "message": "Field 'message' is required and must be a non-empty string.",
            "session_id": session_id
        }, status_code=400)
    main_trace_id = TRACE_PREFIX_MAIN + str(uuid.uuid4())
    # Register session with metadata
    session_manager.add_session(session_id, {
        SESSION_KEY_TRACE_ID: main_trace_id,
        SESSION_KEY_USER_INPUT: user_input,
        SESSION_KEY_CREATED_AT: time.time(),
        SESSION_KEY_STATUS: SESSION_STATUS_ACTIVE
    })
    logger.info(f"[Chat] Session registered: {session_id}, trace_id: {main_trace_id}")
    # Subscribe to RocketMQ topic
    subscribe_lite_topic(session_id)
    return EventSourceResponse(
        create_chat_event_generator(session_id, user_input, main_trace_id)
    )
@router.post("/disconnect")
async def disconnect(request: dict):
    """
    Pause a session: stop message delivery but keep the session so it can be resumed.

    Args:
        request: JSON body; ``session_id`` defaults to ""

    Returns:
        A ``JSONResponse`` with status "success" and the session id. A failed
        unsubscribe is logged and does not change the response.
    """
    sid = request.get("session_id", "")
    logger.info(f"[Disconnect] Session ID: {sid}")

    # Best effort: an unsubscribe failure must not prevent the status update.
    try:
        unsubscribe_lite_topic(sid)
        logger.info(f"[Disconnect] Unsubscribed from session: {sid}")
    except Exception as e:
        logger.error(f"[Disconnect] Failed to unsubscribe: {e}", exc_info=True)

    # Flag the session as disconnected, unless it already ran to completion.
    session_meta = session_manager.get_session_metadata(sid)
    already_completed = bool(session_meta) and session_meta.get(SESSION_KEY_STATUS) == SESSION_STATUS_COMPLETED
    if session_meta and not already_completed:
        session_meta[SESSION_KEY_STATUS] = SESSION_STATUS_DISCONNECTED
        session_meta[SESSION_KEY_DISCONNECTED_AT] = time.time()
        session_manager.add_session(sid, session_meta)

    body = {
        "status": "success",
        "message": "Disconnected. Session kept for reconnection.",
        "session_id": sid
    }
    return JSONResponse(content=body)
@router.post("/end-session")
async def end_session(request: dict):
    """
    End a session for good: drop it from the session manager and unsubscribe its topic.

    Args:
        request: JSON body; ``session_id`` defaults to ""

    Returns:
        A ``JSONResponse`` with status "success", the session id, and
        ``removed`` set to whatever ``session_manager.remove_session`` returned.
    """
    sid = request.get("session_id", "")
    logger.info(f"[End Session] Session ID: {sid}")

    was_removed = session_manager.remove_session(sid)

    # Only unsubscribe when a session was actually removed; failures are logged.
    if was_removed:
        try:
            unsubscribe_lite_topic(sid)
            logger.info(f"[End Session] Unsubscribed and removed: {sid}")
        except Exception as e:
            logger.error(f"[End Session] Failed to unsubscribe: {e}", exc_info=True)

    body = {
        "status": "success",
        "message": "Session ended and cleaned up",
        "session_id": sid,
        "removed": was_removed
    }
    return JSONResponse(content=body)
@router.post("/reconnect")
async def reconnect(request: dict):
    """
    Reopen the SSE stream for an existing session.

    Args:
        request: JSON body; ``session_id`` defaults to ""

    Returns:
        - a 404 ``JSONResponse`` when the session has no stored trace id;
        - an ``EventSourceResponse`` that only confirms the reconnection and
          finishes when the session is already completed;
        - otherwise an ``EventSourceResponse`` backed by
          ``create_reconnect_event_generator``.
    """
    session_id = request.get("session_id", "")
    logger.info(f"[Reconnect] Session ID: {session_id}")

    # Re-register session (update last_active timestamp)
    # NOTE(review): this runs before the existence check below, so an unknown
    # session id is passed to add_session as well - confirm that is intended.
    session_manager.add_session(session_id)

    session_meta = session_manager.get_session_metadata(session_id)
    main_trace_id = session_meta.get(SESSION_KEY_TRACE_ID) if session_meta else None
    if not main_trace_id:
        logger.warning(f"[Reconnect] No active trace found for session: {session_id}")
        not_found = {
            "status": "error",
            "message": "No active session found. Please start a new chat.",
            "session_id": session_id
        }
        return JSONResponse(content=not_found, status_code=404)

    # A finished session has nothing left to stream: confirm and close at once.
    if session_meta.get(SESSION_KEY_STATUS, SESSION_STATUS_ACTIVE) == SESSION_STATUS_COMPLETED:
        logger.info(f"[Reconnect] Session already completed: {session_id}")

        async def completed_event_generator():
            """Send immediate completion for already-finished session"""
            yield {"data": json.dumps({
                "type": SSE_EVENT_TYPE_RECONNECTED,
                "session_id": session_id,
                "trace_id": main_trace_id
            })}
            yield {"data": SSE_EVENT_DONE}

        return EventSourceResponse(completed_event_generator())

    # Register response queue FIRST (before subscribing)
    # NOTE(review): create_reconnect_event_generator calls register_trace again
    # for the same trace id and only unregisters the queue it obtained itself;
    # verify that the registration made here is not left behind.
    stream_queue_manager.register_trace(main_trace_id)
    logger.info(f"[Reconnect] Registered response queue for trace_id: {main_trace_id}")

    # Re-register any known sub-traces from session metadata
    known_sub_traces = (
        ("weather", session_meta.get(SESSION_KEY_WEATHER_TRACE_ID)),
        ("travel", session_meta.get(SESSION_KEY_TRAVEL_TRACE_ID)),
    )
    for label, sub_tid in known_sub_traces:
        if sub_tid:
            stream_queue_manager.register_sub_trace(sub_tid, main_trace_id)
            logger.info(f"[Reconnect] Re-registered {label} sub-trace: {sub_tid} -> {main_trace_id}")

    intent = session_meta.get(SESSION_KEY_INTENT, "")
    return EventSourceResponse(
        create_reconnect_event_generator(session_id, main_trace_id, intent)
    )