blob: 827124e069b7ca8ef7691a0a9b2af3a4bc0514b3 [file] [log] [blame]
import dataclasses
from os import PathLike
from typing import Any, Collection, Dict, Type, Union
import xgboost
from hamilton.io import utils
from hamilton.io.data_adapters import DataSaver
@dataclasses.dataclass
class XGBoostJsonWriter(DataSaver):
    """Save an XGBoost model to ``path`` using the model's own ``save_model``.

    Registered under the name ``json`` (i.e. usable as ``to.json``).

    NOTE: XGBoost decides the on-disk format from the extension of the file
    name it is given, so ``path`` should end in ``.json`` to actually produce
    JSON output.
    """

    # Destination file; passed unchanged to ``save_model`` and to
    # ``utils.get_file_metadata``.
    path: Union[str, PathLike]

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the model types this saver can write.

        Both the scikit-learn style ``XGBModel`` wrappers and raw ``Booster``
        objects expose ``save_model(fname)``, which is all ``save_data`` needs.
        """
        return [xgboost.XGBModel, xgboost.Booster]

    def save_data(
        self, data: Union[xgboost.XGBModel, xgboost.Booster]
    ) -> Dict[str, Any]:
        """Serialize ``data`` to ``self.path`` and return file metadata.

        :param data: the XGBoost model (estimator wrapper or Booster) to write.
        :return: metadata describing the written file, as produced by
            ``hamilton.io.utils.get_file_metadata``.
        """
        # Delegate serialization entirely to the XGBoost library.
        data.save_model(self.path)
        return utils.get_file_metadata(self.path)

    @classmethod
    def name(cls) -> str:
        """Return the registered saver name, used as ``to.{name}``."""
        return "json"