blob: cb6e2115dc8514665af5f957d47301651c1926e7 [file] [log] [blame]
#!/usr/bin/env python
# coding=utf-8
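"""Trainer engine action.

Fits scikit-learn classifiers (SVC, logistic regression, decision tree and
k-nearest neighbours) on the petals, sepals and joined datasets and stores
the fitted models on the engine.
"""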
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from .._compatibility import six
from .._logging import get_logger
from marvin_python_toolbox.engine_base import EngineBaseTraining
# Public API of this module: only the Trainer engine action is exported.
__all__ = ["Trainer"]

# Module-level logger named "trainer", created via the package's logging helper.
logger = get_logger("trainer")
class Trainer(EngineBaseTraining):
    """Engine action that fits one classifier per (algorithm, dataset) pair.

    The fitted estimators are published as a dict in ``self.marvin_model``.
    """

    # Names of the feature sets read from ``self.marvin_dataset``; each entry
    # must provide ``train_X`` and ``train_y``. Subclasses may override this.
    _DATASET_NAMES = ('petals', 'sepals', 'joined')

    def __init__(self, **kwargs):
        super(Trainer, self).__init__(**kwargs)

    def execute(self, params, **kwargs):
        """Train every algorithm on every dataset and publish the models.

        Each classifier is constructed with its scikit-learn default
        hyper-parameters and fitted on the ``train_X`` / ``train_y`` entries of
        ``self.marvin_dataset[<dataset>]``. ``self.marvin_dataset`` is
        presumably populated by an earlier engine action (confirm against the
        pipeline).

        :param params: engine parameters; not used by this action.
        :param kwargs: extra engine keyword arguments; not used.

        Sets ``self.marvin_model`` to a dict mapping ``'<algo>_<dataset>'``
        (e.g. ``'svm_petals'``, ``'knn_joined'``) to the fitted estimator.
        """
        algorithms = {
            'svm': svm.SVC,
            'lr': LogisticRegression,
            'dt': DecisionTreeClassifier,
            'knn': KNeighborsClassifier,
        }

        models = {}
        for algo_name, estimator_cls in algorithms.items():
            for dataset_name in self._DATASET_NAMES:
                split = self.marvin_dataset[dataset_name]
                # A fresh estimator per dataset: fit() refits the same instance
                # and returns it, so sharing one would overwrite earlier fits.
                models[algo_name + '_' + dataset_name] = estimator_cls().fit(
                    split['train_X'], split['train_y'])

        # Publish only once every fit has succeeded, so an exception mid-way
        # never leaves a partially populated model dict on the engine.
        self.marvin_model = models