# blob: f62df8aab92822c817b534a9884a1e5c2d035405 [file] [log] [blame]
from pathlib import Path
from typing import List, Dict, Optional, Any
import json
from ..config.database import DatabaseConfig
from ..base import BaseRepository as AbstractBaseRepository
class SimulationRepository(AbstractBaseRepository):
    """Repository for simulation scripts and their stored run results.

    Reads from the ``simulations`` table and writes to the ``results`` table
    using connections obtained from ``db_config.get_sqlite_connection()``.
    """

    # Columns that ``list`` accepts as equality filters. Column names are
    # interpolated into the SQL text, so only names from this fixed set may
    # reach the query; filter values are always bound as parameters.
    _FILTERABLE_COLUMNS = ("id", "name")

    # Underscore-prefixed row keys that describe run status. They are kept
    # out of the serialized parameters and surface (on failure) in the
    # serialized outputs instead.
    _STATUS_KEYS = ("_ok", "_error_msg", "_error_type")

    def __init__(self, db_config: Optional[DatabaseConfig] = None):
        """Create the repository.

        Args:
            db_config: Database configuration to use. When omitted (or
                falsy), a default ``DatabaseConfig()`` is created.
        """
        config = db_config or DatabaseConfig()
        super().__init__(db_config=config)
        # For backward compatibility
        self.db_config = config

    def get_simulation_path(self, model_id: str) -> str:
        """Return the simulation script path stored for the given model_id.

        The value is returned exactly as stored in the ``script_path`` column
        of the ``simulations`` table; it is not resolved or validated here.

        Args:
            model_id: The ID of the model to find

        Returns:
            str: Path to the simulation script

        Raises:
            KeyError: If the model_id is unknown
        """
        with self.db_config.get_sqlite_connection() as conn:
            row = conn.execute(
                "SELECT script_path FROM simulations WHERE id = ?",
                (model_id,)
            ).fetchone()
            if row is None:
                raise KeyError(f"model_id '{model_id}' not found in DB {self.db_config.database_path}")
            # NOTE(review): access by column name assumes the connection uses
            # a mapping-style row factory (e.g. sqlite3.Row) -- confirm in
            # DatabaseConfig.
            return row["script_path"]

    def store_simulation_results(self, model_id: str, rows: List[dict], param_keys: List[str]) -> None:
        """Store simulation results in the results table.

        Each row is split into a JSON ``params`` document and a JSON
        ``outputs`` document and inserted with the current timestamp.

        Args:
            model_id: The ID of the model
            rows: List of result dictionaries from simulation runs
            param_keys: List of parameter names used in the simulation

        Raises:
            TypeError: If a row contains a value ``json.dumps`` cannot
                serialize. Serialization happens before any insert is
                issued, so no row is written in that case.
        """
        # Serialize everything up front so a bad row fails before the
        # database is touched.
        records = [
            (
                model_id,
                self._extract_parameters(row, param_keys),
                self._extract_results(row, param_keys),
            )
            for row in rows
        ]
        with self.db_config.get_sqlite_connection() as conn:
            # Committing is left to the connection's context manager, as in
            # the per-row version this replaces.
            conn.executemany("""
                INSERT INTO results (model_id, params, outputs, ts)
                VALUES (?, ?, ?, CURRENT_TIMESTAMP)
            """, records)

    @staticmethod
    def _extract_parameters(row: dict, param_keys: List[str]) -> str:
        """Extract and serialize parameters from result row.

        The result contains every ``param_keys`` entry present in ``row``
        plus any underscore-prefixed field other than the status keys
        (``_ok``, ``_error_msg``, ``_error_type``).

        Args:
            row: One result dictionary from a simulation run
            param_keys: Names of the keys in ``row`` that are parameters

        Returns:
            str: JSON object text
        """
        params = {k: row[k] for k in param_keys if k in row}
        # Include any special fields that start with underscore
        params.update({
            k: v for k, v in row.items()
            if k.startswith('_') and k not in SimulationRepository._STATUS_KEYS
        })
        return json.dumps(params)

    @staticmethod
    def _extract_results(row: dict, param_keys: List[str]) -> str:
        """Extract and serialize results, excluding parameters and special fields.

        When ``row['_ok']`` is present and falsy, an ``error`` entry holding
        the error type and message is added to the output.

        Args:
            row: One result dictionary from a simulation run
            param_keys: Names of the keys in ``row`` that are parameters

        Returns:
            str: JSON object text
        """
        # Build the lookup once instead of scanning the list for every key.
        excluded = frozenset(param_keys)
        results = {
            k: v for k, v in row.items()
            if not k.startswith('_') and k not in excluded
        }
        # Include error information if present
        if not row.get('_ok', True):
            results['error'] = {
                'type': row.get('_error_type', ''),
                'message': row.get('_error_msg', '')
            }
        return json.dumps(results)

    # Implement abstract methods from BaseRepository
    def get(self, id: Any) -> Optional[Any]:
        """Return the script path for the simulation with this ID.

        Args:
            id: Simulation ID; converted with ``str()`` before the lookup

        Returns:
            The stored script path, or None when the ID is unknown
        """
        try:
            return self.get_simulation_path(str(id))
        except KeyError:
            return None

    def list(self, filters: Optional[Dict[str, Any]] = None) -> List[Any]:
        """List simulations with optional filters.

        Args:
            filters: Optional mapping of column name to required value.
                Only ``id`` and ``name`` are honoured; any other key is
                ignored. Multiple filters are combined with AND.

        Returns:
            One dict per matching row with the selected columns
            ``id``, ``name``, ``metadata`` and ``script_path``.
        """
        query = "SELECT id, name, metadata, script_path FROM simulations"
        where_clauses = []
        params = []
        for key, value in (filters or {}).items():
            # The key is interpolated into SQL, so it must come from the
            # fixed whitelist; the value is always a bound parameter.
            if key in self._FILTERABLE_COLUMNS:
                where_clauses.append(f"{key} = ?")
                params.append(value)
        if where_clauses:
            query += " WHERE " + " AND ".join(where_clauses)
        with self.db_config.get_sqlite_connection() as conn:
            rows = conn.execute(query, params).fetchall()
            # NOTE(review): dict(row) assumes a mapping-style row factory
            # (e.g. sqlite3.Row) -- confirm in DatabaseConfig.
            return [dict(row) for row in rows]