# Source blob: 1b690c54f44572de14a7e47281503bfbc0965c7e (repository-viewer header; not part of the module)
import contextlib
import io
import traceback
import textwrap
import uuid
import os
import time
from pathlib import Path
from typing import Dict, Optional, Union, Any, List, Literal
import matplotlib
from matplotlib import pyplot as plt
# Temporary imports - prefer the project helpers from `core.script_utils`;
# when that module cannot be imported, define minimal stand-ins below.
try:
    from core.script_utils import _capture_show, _media_dir_for, sanitize_metadata
except ImportError:
    # Mock implementations for now

    def _capture_show(images_list):
        """Fallback: return the current ``plt.show`` without installing a hook.

        ``images_list`` is accepted for signature compatibility and left untouched.
        """
        current_show = plt.show
        return current_show

    def _media_dir_for(model_id):
        """Fallback: return ``media/<model_id>``, or ``media`` for a falsy id."""
        if not model_id:
            return Path("media")
        return Path(f"media/{model_id}")

    def sanitize_metadata(data, media_dir, media_paths, prefix=""):
        """Fallback: hand ``data`` back unchanged; the other arguments are ignored."""
        return data
matplotlib.use("Agg") # headless matplotlib
import pandas as pd, json
from pandas import DataFrame
# tools.py
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, PrivateAttr
# The two execution modes the tool understands.
ExecMode = Literal["analysis", "simulate"]


class PythonExecArgs(BaseModel):
    # Validated argument payload; PythonExecTool._run builds one of these
    # from the raw dict it receives. Every field has a default, so an empty
    # dict validates (and is then rejected by _run for having no code).

    # Source text to run; None/empty means "nothing to execute".
    code: Optional[str] = Field(
        default=None,
        description="Python source to execute.",
    )
    # Selects which execution path _run dispatches to.
    mode: ExecMode = Field(
        default="analysis",
        description="'analysis' or 'simulate'",
    )
    # Keyword arguments forwarded to the user's simulate() function.
    params: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Kwargs for simulate(**params) when mode='simulate'",
    )
    # Key handed to _media_dir_for to pick the media folder.
    model_id: Optional[str] = Field(
        default=None,
        description="Folder key for media grouping",
    )
    # Time budget in seconds; per its description this is a soft check.
    timeout_s: float = Field(
        default=30.0,
        description="Per-run timeout for simulate mode (soft check)",
    )
class PythonExecTool(BaseTool):
    """
    Unified Python execution tool with two modes.

    - analysis mode: executes arbitrary Python against an optional `df`
      and returns {ok, stdout, stderr, images}.
    - simulate mode: executes code that defines `def simulate(**params) -> dict`,
      calls it, and returns {ok, stdout, stderr, images, outputs, media_paths}.

    Backwards compatible with the previous usage when only `code` is passed
    (the mode then defaults to 'analysis').

    SECURITY NOTE: the submitted source is run with `exec` inside this
    process, with no sandboxing. Only feed it code you trust, or run the
    tool inside an isolated environment.
    """

    name: str = Field("python_exec")
    description: str = (
        "Execute Python. In 'analysis' mode, runs code against DataFrame `df` if provided. "
        "In 'simulate' mode, runs code that defines simulate(**params)->dict and returns "
        "JSON-safe outputs with media paths. Keys: ok, stdout, stderr, images, [outputs, media_paths]."
    )
    # DataFrame exposed to analysis-mode code as the global `df` (may be None).
    _df: Optional[DataFrame] = PrivateAttr(default=None)

    def __init__(self, df: Optional[DataFrame] = None):
        """Create the tool, optionally binding a DataFrame for analysis mode."""
        super().__init__()  # ensure BaseModel init
        self._df = df

    def _run(self, args: Dict[str, Any]) -> dict:
        """Validate `args` and dispatch to the requested execution mode.

        Args:
            args: Raw tool arguments; validated through `PythonExecArgs`.
                NOTE(review): this signature expects the arguments as one
                dict - confirm this matches how the tool is invoked.

        Returns:
            The result dict of `run_python` or `run_simulation`, or a
            failure dict (`ok=False`) when no code was supplied.
        """
        payload = PythonExecArgs(**args)
        if not payload.code:
            return {"ok": False, "stdout": "", "stderr": "No code provided.", "images": []}

        if payload.mode == "simulate":
            return self.run_simulation(
                code=payload.code,
                params=payload.params or {},
                model_id=payload.model_id,
                timeout_s=payload.timeout_s,
            )

        # default: analysis
        return self.run_python(code=payload.code, df=self._df)

    @staticmethod
    def _png_files() -> set:
        """Return the names of the .png files in the current working directory."""
        return {f for f in os.listdir() if f.lower().endswith(".png")}

    # -------- analysis mode (unchanged behavior) --------
    def run_python(self, code: str, df: Optional[pd.DataFrame]) -> Dict[str, Any]:
        """Execute `code` with `plt`, `pd`, `np` (and `df`, if given) as globals.

        Args:
            code: Python source; it is dedented before execution.
            df: Optional DataFrame exposed to the code as `df`.

        Returns:
            Dict with `ok` (False if the code raised an `Exception`), the
            captured `stdout` / `stderr` (traceback appended on failure), and
            `images`: new .png files found in the working directory plus any
            paths collected through the `plt.show` hook.
        """
        # Snapshot existing PNGs so files created by the code can be detected.
        before = self._png_files()
        images: List[str] = []
        old_show = _capture_show(images)
        stdout_buf, stderr_buf = io.StringIO(), io.StringIO()
        ok = True
        g = {"plt": plt, "pd": pd, "np": __import__("numpy")}
        if df is not None:
            g["df"] = df
        try:
            with contextlib.redirect_stdout(stdout_buf), contextlib.redirect_stderr(stderr_buf):
                # Runs caller-supplied code in-process; see the class security note.
                exec(textwrap.dedent(code), g)
        except Exception:
            stderr_buf.write(traceback.format_exc())
            ok = False
        finally:
            # Always restore whatever `plt.show` was before the hook.
            plt.show = old_show

        new_images = sorted(self._png_files() - before)
        # also include images saved via our plt.show hook
        for p in images:
            if p not in new_images:
                new_images.append(p)
        return {
            "ok": ok,
            "stdout": stdout_buf.getvalue(),
            "stderr": stderr_buf.getvalue(),
            "images": new_images,
        }

    # -------- simulate mode --------
    def run_simulation(
        self,
        code: str,
        params: Dict[str, Any],
        model_id: Optional[str] = None,
        timeout_s: float = 30.0,
    ) -> Dict[str, Any]:
        """Execute `code`, then call the `simulate(**params)` it defines.

        Args:
            code: Python source that must define a callable `simulate`.
            params: Keyword arguments forwarded to `simulate`.
            model_id: Key passed to `_media_dir_for` to select the media folder.
            timeout_s: Soft time budget in seconds. It is checked only after
                the run finishes - a long run is reported as failed, not
                interrupted.

        Returns:
            Dict with `ok`, captured `stdout` / `stderr` (traceback and any
            timeout note appended), `images` (hook-captured paths plus new
            .png files in the working directory), `outputs` (the return value
            of `simulate` passed through `sanitize_metadata`) and
            `media_paths` (the list handed to `sanitize_metadata`).
        """
        media_dir = _media_dir_for(model_id)

        before = self._png_files()
        images: List[str] = []
        old_show = _capture_show(images)

        stdout_buf, stderr_buf = io.StringIO(), io.StringIO()
        ok, ret = True, None
        # simulation shouldn't need df/pd by default
        g = {"plt": plt, "np": __import__("numpy")}

        start = time.monotonic()
        try:
            with contextlib.redirect_stdout(stdout_buf), contextlib.redirect_stderr(stderr_buf):
                # Runs caller-supplied code in-process; see the class security note.
                exec(textwrap.dedent(code), g)
                sim = g.get("simulate")
                if not callable(sim):
                    raise RuntimeError("No callable `simulate(**params)` found.")
                ret = sim(**params)
        except Exception:
            # Same convention as run_python: keep captured stderr and append
            # the traceback rather than replacing one with the other.
            stderr_buf.write(traceback.format_exc())
            ok = False
        finally:
            plt.show = old_show

        elapsed = time.monotonic() - start
        if elapsed > timeout_s:
            ok = False
            captured = stderr_buf.getvalue()
            if captured and not captured.endswith("\n"):
                stderr_buf.write("\n")
            stderr_buf.write(f"Timeout exceeded: {elapsed:.1f}s > {timeout_s:.1f}s")

        # Add PNGs written to disk that the plt.show hook did not already record.
        for f in sorted(self._png_files() - before):
            if f not in images:
                images.append(f)

        media_paths: List[str] = []
        outputs = sanitize_metadata(ret, media_dir, media_paths, prefix="ret")

        return {
            "ok": ok,
            "stdout": stdout_buf.getvalue(),
            "stderr": stderr_buf.getvalue(),
            "images": images,
            "outputs": outputs,
            "media_paths": media_paths,
        }