# blob: a5b073a26ba0799f2b5a637c8e5f97776fdcbfaf
import sqlite3
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Iterator, Optional
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker


@dataclass
class DatabaseConfig:
    """Database connection settings plus a lazily created SQLAlchemy engine.

    Supported dialects are ``"sqlite"`` (``database_path`` is a file path) and
    ``"postgresql"`` (``database_path`` is the database name).
    """

    dialect: str = "sqlite"
    database_path: str = "mcp.db"
    echo: bool = False
    host: Optional[str] = None
    port: Optional[int] = None
    username: Optional[str] = None
    # Excluded from repr() so the secret does not leak into logs/tracebacks.
    password: Optional[str] = field(default=None, repr=False)
    # Lazily built runtime caches, not configuration: excluded from repr()
    # and from equality so identical settings always compare equal.
    _engine: Optional[Engine] = field(default=None, repr=False, compare=False)
    _session_factory: Optional[sessionmaker] = field(
        default=None, repr=False, compare=False
    )

    @property
    def connection_string(self) -> str:
        """Return the SQLAlchemy database URL for this configuration.

        Raises:
            ValueError: if ``dialect`` is not ``"sqlite"`` or ``"postgresql"``.
        """
        if self.dialect == "sqlite":
            return f"sqlite:///{self.database_path}"
        if self.dialect == "postgresql":
            return self._postgresql_url()
        raise ValueError(f"Unsupported dialect: {self.dialect}")

    def _postgresql_url(self) -> str:
        """Build a ``postgresql://`` URL with percent-encoded credentials."""
        # Percent-encode credentials so characters such as '@', ':' or '/' in
        # a password cannot corrupt the URL (SQLAlchemy decodes them again
        # with urllib.parse.unquote). Unset parts are omitted rather than
        # being rendered as the literal text "None".
        userinfo = ""
        if self.username is not None or self.password is not None:
            userinfo = quote(self.username or "", safe="")
            if self.password is not None:
                userinfo += ":" + quote(self.password, safe="")
            userinfo += "@"
        host = self.host or ""
        port = f":{self.port}" if self.port is not None else ""
        return f"postgresql://{userinfo}{host}{port}/{self.database_path}"

    def get_engine(self) -> Engine:
        """Return the cached engine, creating it on first use."""
        if self._engine is None:
            self._engine = create_engine(self.connection_string, echo=self.echo)
        return self._engine

    def get_session_factory(self) -> sessionmaker:
        """Return the cached session factory bound to ``get_engine()``."""
        if self._session_factory is None:
            self._session_factory = sessionmaker(bind=self.get_engine())
        return self._session_factory

    @contextmanager
    def get_session(self) -> Iterator[Session]:
        """Yield a session; commit on success, roll back and re-raise on error.

        The session is always closed when the ``with`` block exits.
        """
        session = self.get_session_factory()()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    @contextmanager
    def get_sqlite_connection(self) -> Iterator[sqlite3.Connection]:
        """Get a SQLite connection with row factory set to dict-like rows.

        Commits on normal exit, rolls back and re-raises on an exception, and
        always closes the connection.

        Raises:
            ValueError: if ``dialect`` is not ``"sqlite"``. Otherwise
                ``database_path`` (a database *name* for other dialects) would
                be opened as a local SQLite file.
        """
        if self.dialect != "sqlite":
            raise ValueError(
                f"get_sqlite_connection requires the sqlite dialect, got: {self.dialect}"
            )
        conn = sqlite3.connect(str(self.database_path))
        conn.row_factory = sqlite3.Row
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()