blob: d4b52b6d403319907a3a2766e4ecfeb4d893a3a3 [file] [log] [blame]
import dataclasses
from os import PathLike
from typing import Any, Collection, Dict, Tuple, Type, Union
import xgboost
from hamilton.io import utils
from hamilton.io.data_adapters import DataLoader
@dataclasses.dataclass
class XGBoostJsonReader(DataLoader):
path: Union[str, bytearray, PathLike]
@classmethod
def applicable_types(cls) -> Collection[Type]:
return [xgboost.XGBModel]
def load_data(self, type_: Type) -> Tuple[xgboost.XGBModel, Dict[str, Any]]:
# uses the XGBoost library
model = type_()
model.load_model(self.path)
metadata = utils.get_file_metadata(self.path)
return model, metadata
@classmethod
def name(cls) -> str:
return "json" # the name for `from_.{name}`