blob: 8e861356ac9aee17982b13286bf7005cf296c3d9 [file] [log] [blame]
import dataclasses
from os import PathLike
from typing import Any, Collection, Dict, Optional, Type, Union

# This is not necessary once this PR gets merged: https://github.com/DAGWorks-Inc/hamilton/pull/467.
try:
    import sklearn.calibration
    import sklearn.inspection
    import sklearn.metrics
    import sklearn.model_selection
except ImportError:
    raise NotImplementedError("scikit-learn is not installed.")

from hamilton import registry
from hamilton.io import utils
from hamilton.io.data_adapters import DataSaver
# Names of the scikit-learn ``*Display`` classes this extension tries to support.
# A name that the installed scikit-learn does not provide is silently skipped.
display_classes = [
    "ConfusionMatrixDisplay",
    "DetCurveDisplay",
    "PrecisionRecallDisplay",
    "PredictionErrorDisplay",
    "RocCurveDisplay",
    "CalibrationDisplay",
    "DecisionBoundaryDisplay",
    "LearningCurveDisplay",
    "PartialDependenceDisplay",
    "ValidationCurveDisplay",
]

# Modules searched, in this order, for every name above.
# ``sklearn.calibration`` is listed because that is where scikit-learn defines
# ``CalibrationDisplay``; searching only the other three modules never finds it.
_DISPLAY_MODULES = (
    sklearn.metrics,
    sklearn.inspection,
    sklearn.model_selection,
    sklearn.calibration,
)

SKLEARN_PLOT_TYPES = []
for class_name in display_classes:
    for module in _DISPLAY_MODULES:
        # Look the class up by name; a module that lacks it yields None.
        display_class = getattr(module, class_name, None)
        if display_class is None:
            continue
        # Guard against the same class being reachable from two modules.
        if display_class not in SKLEARN_PLOT_TYPES:
            SKLEARN_PLOT_TYPES.append(display_class)

# Union of every discovered Display class, used to annotate ``save_data``.
SKLEARN_PLOT_TYPES_ANNOTATION = Union[tuple(SKLEARN_PLOT_TYPES)]
@dataclasses.dataclass
class SklearnPlotSaver(DataSaver):
    """Data saver that writes a scikit-learn ``*Display`` plot to a file.

    ``path`` is the destination. Every other field is forwarded as a keyword
    argument to ``data.figure_.savefig`` when it is not ``None``; a ``None``
    field is left out of the call entirely, so the callee's own default applies.

    NOTE(review): ``figure_`` is presumably a matplotlib ``Figure``, which
    would make these fields the options of ``Figure.savefig`` -- confirm
    against the scikit-learn / matplotlib versions in use.
    """

    path: Union[str, PathLike]
    # kwargs
    dpi: float = 200
    format: str = "png"
    metadata: Optional[dict] = None
    bbox_inches: Optional[str] = None
    pad_inches: float = 0.1
    backend: Optional[str] = None
    papertype: Optional[str] = None
    transparent: Optional[bool] = None
    bbox_extra_artists: Optional[list] = None
    pil_kwargs: Optional[dict] = None

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the Display classes this saver accepts (``SKLEARN_PLOT_TYPES``)."""
        return SKLEARN_PLOT_TYPES

    def _get_saving_kwargs(self) -> Dict[str, Any]:
        """Collect the save options that are set (i.e. not ``None``).

        :return: Mapping of option name to value, in field-declaration order,
            containing only the options whose value is not ``None``.
        """
        # Listed explicitly (rather than derived from the dataclass fields) so
        # that ``path`` and any field added by a subclass are never forwarded.
        option_names = (
            "dpi",
            "format",
            "metadata",
            "bbox_inches",
            "pad_inches",
            "backend",
            "papertype",
            "transparent",
            "bbox_extra_artists",
            "pil_kwargs",
        )
        candidates = {option: getattr(self, option) for option in option_names}
        return {option: value for option, value in candidates.items() if value is not None}

    def save_data(self, data: SKLEARN_PLOT_TYPES_ANNOTATION) -> Dict[str, Any]:
        """Render ``data`` and write its figure to ``self.path``.

        :param data: A scikit-learn Display object; it must provide ``plot()``
            and, after plotting, a ``figure_`` attribute with ``savefig``.
        :return: Whatever ``utils.get_file_metadata`` reports for ``self.path``.
        """
        # plot() is called first; ``figure_`` is read only afterwards.
        data.plot()
        data.figure_.savefig(self.path, **self._get_saving_kwargs())
        return utils.get_file_metadata(self.path)

    @classmethod
    def name(cls) -> str:
        """Return this adapter's name, always ``"png"`` regardless of ``format``."""
        return "png"
def register_data_loaders():
    """Register the adapters defined by this extension with the hamilton registry."""
    # Only one adapter today; the tuple keeps the call site uniform if more are added.
    for adapter in (SklearnPlotSaver,):
        registry.register_adapter(adapter)
# Executed at import time: importing this module registers SklearnPlotSaver.
register_data_loaders()