blob: ca7474168a01c9da501ec023277e5993093d79ac [file] [log] [blame]
#!/usr/bin/env python
# coding=utf-8
try:
import mock
except ImportError:
import unittest.mock as mock
import pandas as pd
from marvin_iris_species_engine.training import Trainer
@mock.patch('marvin_iris_species_engine.training.trainer.svm.SVC')
@mock.patch('marvin_iris_species_engine.training.trainer.LogisticRegression')
@mock.patch('marvin_iris_species_engine.training.trainer.DecisionTreeClassifier')
@mock.patch('marvin_iris_species_engine.training.trainer.KNeighborsClassifier')
def test_execute(knn_mocked, dt_mocked, lr_mocked, svc_mocked, mocked_params):
    """Check that ``Trainer.execute`` calls every patched classifier.

    The four classifier names in the trainer module are replaced by mocks.
    ``mock.patch`` decorators are applied bottom-up, so the lowest decorator
    (KNeighborsClassifier) supplies the first mock argument and the topmost
    one (svm.SVC) supplies the last.

    ``mocked_params`` is not defined in this module -- presumably a pytest
    fixture provided elsewhere (e.g. a conftest); confirm against the suite.
    """
    # Feature columns for each of the three dataset groups the trainer is given.
    features_by_group = (
        ('petals', {
            'PetalLengthCm': [5, 6],
            'PetalWidthCm': [7, 8],
        }),
        ('sepals', {
            'SepalLengthCm': [1, 2],
            'SepalWidthCm': [3, 4],
        }),
        ('joined', {
            'SepalLengthCm': [1, 2],
            'SepalWidthCm': [3, 4],
            'PetalLengthCm': [5, 6],
            'PetalWidthCm': [7, 8],
        }),
    )
    # Single-row label column shared by every group. Note the feature frames
    # have two rows; this only works because the classifiers are mocked.
    labels = {'Species': ['species1']}

    data_source = {}
    for group_name, feature_columns in features_by_group:
        # Each split gets its own DataFrame instance, built from the same data.
        data_source[group_name] = {
            'train_X': pd.DataFrame(data=feature_columns),
            'train_y': pd.DataFrame(data=labels),
            'test_X': pd.DataFrame(data=feature_columns),
            'test_y': pd.DataFrame(data=labels),
        }

    trainer = Trainer(dataset=data_source)
    trainer.execute(params=mocked_params)

    # Every patched classifier must have been called at least once.
    for classifier_mock in (knn_mocked, dt_mocked, lr_mocked, svc_mocked):
        classifier_mock.assert_called()