# blob: 6dbf28aec90ef1dcba0b94f9ec75b25b3b3bcfcd [file] [log] [blame]
#!/usr/bin/env python
# coding=utf-8
try:
import mock
except ImportError:
import unittest.mock as mock
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
from marvin_titanic_engine.training import Trainer
@mock.patch('marvin_titanic_engine.training.trainer.round')
@mock.patch('marvin_titanic_engine.training.trainer.GridSearchCV')
def test_execute(grid_mocked, round_mocked, mocked_params):
    """Check that ``Trainer.execute`` calls ``GridSearchCV`` and ``round``.

    Both names are patched inside ``marvin_titanic_engine.training.trainer``.
    Decorators are applied bottom-up, so the ``GridSearchCV`` mock arrives
    first (``grid_mocked``) and the ``round`` mock second (``round_mocked``).

    ``mocked_params`` is presumably a pytest fixture supplied elsewhere
    (TODO confirm); its value is not used, because the test passes its own
    parameter dictionary to ``execute``.
    """
    # Same small three-column frame is used for both features and target.
    frame_columns = {'Sex': [1, 2, 3], 'B': [4, 5, 6], 'C': [7, 8, 9]}
    test_dataset = {
        "X_train": pd.DataFrame(frame_columns),
        "y_train": pd.DataFrame(frame_columns),
        "sss": mock.MagicMock(),
    }

    # One SVM grid entry per kernel; every entry gets its own fresh lists.
    svm_grid = [
        {"C": [1, 10, 100], "gamma": [0.01, 0.001], "kernel": [kernel_name]}
        for kernel_name in ("linear", "rbf")
    ]
    rf_grid = {
        "max_depth": [3],
        "random_state": [0],
        "min_samples_split": [2],
        "min_samples_leaf": [1],
        "n_estimators": [20],
        "bootstrap": [True, False],
        "criterion": ["gini", "entropy"],
    }
    training_params = {
        "pred_cols": ["Sex", "B"],
        "dep_var": "C",
        "svm": svm_grid,
        "rf": rf_grid,
    }

    trainer = Trainer(dataset=test_dataset)
    trainer.execute(params=training_params)

    # Only the fact of the calls is verified, not their arguments.
    for patched in (grid_mocked, round_mocked):
        patched.assert_called()