blob: d10d0a7695a708018d2d30591b5522f769c2248e [file] [log] [blame]
#!/usr/bin/env python
# coding=utf-8
try:
import mock
except ImportError:
import unittest.mock as mock
from marvin_product_classifier_engine.training import Trainer
@mock.patch('marvin_product_classifier_engine.training.trainer.MultinomialNB.fit')
def test_execute(fit_mocked, mocked_params):
    """Check that Trainer.execute calls MultinomialNB.fit on the training split.

    ``MultinomialNB.fit`` is patched, so no real model is trained; the test
    inspects the arguments handed to the patched ``fit`` and the state left
    on the trainer afterwards. ``mocked_params`` is a pytest fixture defined
    outside this file (presumably in a conftest).
    """
    # Arrange: a minimal dataset with one placeholder entry per split.
    dataset = dict(
        X_train=["train datas"],
        X_test=["test datas"],
        y_train=["train labels"],
        y_test=["test labels"],
        vect="test",
    )
    trainer = Trainer(dataset=dataset)

    # Act.
    trainer.execute(params=mocked_params)

    # fit() must be called exactly once, with the training features and labels.
    fit_mocked.assert_called_once_with(['train datas'], ['train labels'])
    # The dataset's "vect" entry is expected to appear in the produced model.
    assert str(trainer.marvin_model["vect"]) == "test"
    # NOTE(review): execute() is expected to leave _params falsy; whether that
    # holds depends on what the mocked_params fixture supplies — confirm.
    assert not trainer._params