| from dataclasses import dataclass, field |
| from pathlib import Path |
| |
| from execute.test.fix_agent import FixAgent |
| from execute.test.smoke_tester import SmokeTester |
| |
| # Import database function - adjust path as needed |
| from db import store_simulation_script |
| |
| |
@dataclass
class SimulationRefiner:
    """
    Iteratively smoke-test a simulate.py script and use an agent to repair it.

    Each failing smoke test (up to ``max_iterations`` of them) triggers one
    call to ``agent.propose_fix``. The proposal is saved to an intermediate
    ``<stem>.iter{i}.py`` file next to the script and also overwrites
    ``script_path`` itself. Every proposed fix is smoke-tested before the
    refiner gives up, so at most ``max_iterations + 1`` smoke tests run.

    ``refine`` returns the model id produced by ``store_simulation_script``
    once the script passes, and raises ``RuntimeError`` if it never does.
    """

    script_path: Path  # script under test; overwritten in place by each fix
    model_name: str  # passed through to store_simulation_script
    max_iterations: int = 3  # maximum number of agent repair attempts
    smoke_tester: SmokeTester = field(default_factory=SmokeTester)
    agent: FixAgent = field(default_factory=FixAgent)

    def refine(self) -> str:
        """
        Smoke-test the script, repairing it with the agent until it passes.

        Returns:
            The model id returned by ``store_simulation_script`` for the
            passing script.

        Raises:
            RuntimeError: if the script still fails after ``max_iterations``
                repair attempts (the last repair is tested as well).
        """
        # One test before each repair attempt, plus one final test so the
        # last proposed fix is actually verified instead of discarded.
        # With max_iterations == 0 the script is still tested once.
        for i in range(1, self.max_iterations + 2):
            res = self.smoke_tester.test(self.script_path)
            if res.ok:
                print(f"[✓] simulate.py passed smoke test on iteration {i}")
                return store_simulation_script(
                    model_name=self.model_name,
                    metadata={},  # no extra metadata is recorded for the script
                    script_path=str(self.script_path),
                )

            print(f"[!] simulate.py failed on iteration {i}:\n{res.log.strip()}")
            if i > self.max_iterations:
                # Repair budget exhausted; don't request a fix we would never test.
                break

            # Python source is UTF-8 by default (PEP 3120); don't rely on the
            # platform's locale encoding when reading/writing the script.
            current_src = self.script_path.read_text(encoding="utf-8")
            corrected_code = self.agent.propose_fix(res.log, current_src)

            # Keep a copy of each proposal, then replace the script under test.
            iter_path = self.script_path.with_name(
                f"{self.script_path.stem}.iter{i}.py"
            )
            iter_path.write_text(corrected_code, encoding="utf-8")
            self.script_path.write_text(corrected_code, encoding="utf-8")

        raise RuntimeError("simulate.py still failing after all correction attempts.")
| |