blob: a4d00b59588e3c5b932acc5af0f1c2b63922d0d1 [file] [log] [blame]
import pathlib
import pytest
import xgboost
from sklearn.utils.validation import check_is_fitted
from hamilton.io.utils import FILE_METADATA
from hamilton.plugins.xgboost_extensions import XGBoostJsonReader, XGBoostJsonWriter
@pytest.fixture
def fitted_xgboost_model() -> xgboost.XGBModel:
model = xgboost.XGBRegressor()
model.fit([[0]], [[0]])
return model
@pytest.fixture
def fitted_xgboost_booster() -> xgboost.Booster:
dtrain = xgboost.DMatrix([[0]], label=[[0]])
booster = xgboost.train({"objective": "binary:logistic"}, dtrain, 1)
return booster
def test_xgboost_model_json_writer(
fitted_xgboost_model: xgboost.XGBModel, tmp_path: pathlib.Path
) -> None:
model_path = tmp_path / "model.json"
writer = XGBoostJsonWriter(path=model_path)
metadata = writer.save_data(fitted_xgboost_model)
assert model_path.exists()
assert metadata[FILE_METADATA]["path"] == str(model_path)
def test_xgboost_model_json_reader(
fitted_xgboost_model: xgboost.XGBModel, tmp_path: pathlib.Path
) -> None:
model_path = tmp_path / "model.json"
fitted_xgboost_model.save_model(model_path)
reader = XGBoostJsonReader(model_path)
model, metadata = reader.load_data(xgboost.XGBRegressor)
check_is_fitted(model)
assert XGBoostJsonReader.applicable_types() == [xgboost.XGBModel, xgboost.Booster]
def test_xgboost_booster_json_writer(
fitted_xgboost_booster: xgboost.Booster, tmp_path: pathlib.Path
) -> None:
booster_path = tmp_path / "booster.json"
writer = XGBoostJsonWriter(path=booster_path)
metadata = writer.save_data(fitted_xgboost_booster)
assert booster_path.exists()
assert metadata[FILE_METADATA]["path"] == str(booster_path)
def test_xgboost_booster_json_reader(
fitted_xgboost_booster: xgboost.Booster, tmp_path: pathlib.Path
) -> None:
booster_path = tmp_path / "booster.json"
fitted_xgboost_booster.save_model(booster_path)
reader = XGBoostJsonReader(booster_path)
booster, metadata = reader.load_data(xgboost.Booster)
assert len(booster.get_dump()) > 0
assert XGBoostJsonReader.applicable_types() == [xgboost.XGBModel, xgboost.Booster]