blob: 7cd3d36c4d9484c36f85b47ad258afca3d20ffe2 [file] [log] [blame]
import json
import logging
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, List, Dict, Any, Optional
from langchain_core.tools import BaseTool
from reasoning.base import BaseAgent
from reasoning.messages.llm_client import LLMClient
from reasoning.messages.openai_client import OpenAIChatClient
from reasoning.model.reasoning_result import ReasoningResult
from reasoning.tools.final_answer import FinalAnswerTool
from reasoning.tools.python_exec import PythonExecTool
from reasoning.tools.simulate_exec import SimulateTools
from reasoning.utils.load_results import load_results
from reasoning.helpers.prompts import _default_system_prompt, _append_tool_message
from reasoning.helpers.chat_utils import prune_history
from reasoning.config.tools import _openai_tools_spec
from db.config.database import DatabaseConfig
LOG_FMT = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FMT, stream=sys.stdout)
log = logging.getLogger("agent_loop")
@dataclass
class ReasoningAgent(BaseAgent):
model_id: str
db_config: DatabaseConfig = field(default_factory=lambda: DatabaseConfig())
db_path: str = ""
llm: LLMClient = field(default_factory=lambda: OpenAIChatClient(model="gpt-5-mini", temperature=1.0))
max_steps: int = 20
temperature: float = 1.0
# Hooks / callbacks (override as needed)
system_prompt_builder: Callable[[List[str]], str] = field(default=None)
history_pruner: Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]] = field(default=None)
report_store: Callable[[str, str, str, List[str]], None] = field(default=None) # (model_id, question, answer, images)
# Internal tool instances (bound to model/db and df when loaded)
_tools: Dict[str, BaseTool] = field(init=False, default_factory=dict)
def __post_init__(self) -> None:
# Set up database path
self.db_path = self.db_config.database_path
# Set defaults for optional callbacks
if self.system_prompt_builder is None:
self.system_prompt_builder = _default_system_prompt
if self.history_pruner is None:
self.history_pruner = prune_history
if self.report_store is None:
# Lazy import to avoid circulars; replace with your store_report
from db import store_report as _store_report # type: ignore
self.report_store = _store_report
# ── Public entrypoint ────────────────────────────────────────────────────
def ask(self, question: str, stop_flag: Optional[Callable[[], bool]] = None) -> ReasoningResult:
"""Main entry point for asking questions to the reasoning agent."""
log.info("=== Starting analysis for model_id=%s ===", self.model_id)
# Initialize session
df, schema = self._load_context()
self._build_tools(df)
history = self._initialize_conversation(question, schema)
# Initialize tracking
tracking_state = self._initialize_tracking()
# Main reasoning loop
return self._reasoning_loop(history, tracking_state, stop_flag, question)
def _load_context(self) -> tuple:
"""Load data context and schema."""
df = load_results(db_path=self.db_path, model_id=self.model_id)
schema = list(df.columns)
return df, schema
def _build_tools(self, df: Any) -> None:
"""Build tools bound to this session."""
self._tools = {
"python_exec": PythonExecTool(df=df),
"run_simulation_for_model": SimulateTools(db_config=self.db_config, default_model_id=self.model_id),
"run_batch_for_model": SimulateTools(db_config=self.db_config, default_model_id=self.model_id),
"final_answer": FinalAnswerTool(),
}
def _initialize_conversation(self, question: str, schema: List[str]) -> List[Dict[str, Any]]:
"""Initialize conversation history."""
return [
{"role": "system", "content": self.system_prompt_builder(schema)},
{"role": "user", "content": question},
]
def _initialize_tracking(self) -> Dict[str, Any]:
"""Initialize tracking state for images and code."""
return {
"seen_imgs": {p.name for p in Path.cwd().glob("*.png")},
"all_images": [],
"code_map": {},
"step_idx": 0
}
def _reasoning_loop(self, history: List[Dict[str, Any]], tracking_state: Dict[str, Any],
stop_flag: Optional[Callable[[], bool]], question: str) -> ReasoningResult:
"""Main reasoning loop."""
tools_spec = _openai_tools_spec()
for _ in range(self.max_steps):
if stop_flag and stop_flag():
return self._create_result(history, tracking_state, "(stopped)")
# Get LLM response
msg = self.llm.chat(messages=self.history_pruner(history), tools=tools_spec)
assistant_entry = self._create_assistant_entry(msg)
history.append(assistant_entry)
# Process tool calls if any
if assistant_entry.get("tool_calls"):
final_result = self._process_tool_calls(
history, assistant_entry["tool_calls"], tracking_state, question
)
if final_result: # final_answer was called
return final_result
continue
# Nudge if no tool call
self._nudge_for_tool_call(history)
# Loop exhausted
log.error("Agent loop exhausted without an answer")
return self._create_result(history, tracking_state, "(no answer)")
def _create_assistant_entry(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Create assistant entry from LLM message."""
assistant_entry = {"role": "assistant", "content": msg.get("content", "")}
if "tool_calls" in msg:
assistant_entry["tool_calls"] = msg["tool_calls"]
return assistant_entry
def _process_tool_calls(self, history: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]],
tracking_state: Dict[str, Any], question: str) -> Optional[ReasoningResult]:
"""Process all tool calls in the assistant message."""
for tc in tool_calls:
call_id = tc["id"]
fname = tc["function"]["name"]
raw_args = tc["function"]["arguments"] or "{}"
# Parse arguments
args = self._parse_tool_args(history, call_id, raw_args)
if args is None:
continue
# Handle code tracking for python_exec
if fname == "python_exec":
if not self._track_python_code(history, call_id, args, tracking_state):
continue
# Execute tool
result = self._execute_tool(history, call_id, fname, args)
if result is None:
continue
# Track images
self._track_images(result, tracking_state)
# Handle final answer specially
if fname == "final_answer":
return self._handle_final_answer(history, call_id, result, tracking_state, question)
# Append tool result and continue
self._append_tool_message(history, call_id, result)
return None # Continue loop
def _parse_tool_args(self, history: List[Dict[str, Any]], call_id: str, raw_args: str) -> Optional[Dict[str, Any]]:
"""Parse tool arguments from JSON string."""
try:
return json.loads(raw_args)
except json.JSONDecodeError:
self._append_tool_message(history, call_id, {"ok": False, "stderr": "Malformed tool arguments"})
return None
def _track_python_code(self, history: List[Dict[str, Any]], call_id: str,
args: Dict[str, Any], tracking_state: Dict[str, Any]) -> bool:
"""Track Python code for python_exec calls."""
code = args.get("code", "")
if not isinstance(code, str) or not code.strip():
self._append_tool_message(history, call_id, {"ok": False, "stderr": "No code provided"})
return False
tracking_state["code_map"][tracking_state["step_idx"]] = code
tracking_state["step_idx"] += 1
return True
def _execute_tool(self, history: List[Dict[str, Any]], call_id: str,
fname: str, args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Execute a tool and return its result."""
tool = self._tools.get(fname)
if not tool:
self._append_tool_message(history, call_id, {"ok": False, "stderr": f"Unknown tool '{fname}'"})
return None
try:
return tool._run(**args) # type: ignore[arg-type]
except TypeError as e:
result = {"ok": False, "stderr": f"Bad arguments: {e}"}
except Exception as e:
result = {"ok": False, "stderr": f"{type(e).__name__}: {e}"}
return result
def _track_images(self, result: Any, tracking_state: Dict[str, Any]) -> None:
"""Track new images from tool results."""
if not isinstance(result, dict):
return
for img in result.get("images", []):
if img not in tracking_state["seen_imgs"]:
tracking_state["seen_imgs"].add(img)
tracking_state["all_images"].append(img)
def _handle_final_answer(self, history: List[Dict[str, Any]], call_id: str,
result: Dict[str, Any], tracking_state: Dict[str, Any],
question: str) -> ReasoningResult:
"""Handle final_answer tool call."""
fa = result if isinstance(result, dict) else {}
answer_text = fa.get("answer", "")
merged_images = list({*tracking_state["all_images"], *fa.get("images", [])})
# Persist report
try:
self.report_store(self.model_id, question, answer_text, merged_images)
except Exception as e:
log.warning("store_report failed: %s", e)
# Echo tool message and return
self._append_tool_message(history, call_id, fa)
return ReasoningResult(
history=history,
code_map=tracking_state["code_map"],
answer=answer_text,
images=merged_images
)
def _nudge_for_tool_call(self, history: List[Dict[str, Any]]) -> None:
"""Nudge the model to use a tool call."""
history.append({
"role": "user",
"content": "Please respond with a tool call (python_exec / run_simulation_for_model / run_batch_for_model / final_answer)."
})
def _create_result(self, history: List[Dict[str, Any]], tracking_state: Dict[str, Any],
answer: str) -> ReasoningResult:
"""Create a ReasoningResult from current state."""
return ReasoningResult(
history=history,
code_map=tracking_state["code_map"],
answer=answer,
images=tracking_state["all_images"]
)
def _append_tool_message(self, history: List[Dict[str, Any]], call_id: str, payload: Any) -> None:
"""Helper method to append tool messages to history."""
_append_tool_message(history, call_id, payload)