blob: 238cb7e1cf3d99a7e09ff33c2707738dfda915f0 [file] [log] [blame]
#!/usr/bin/env python
# coding=utf-8
"""Trainer engine action.
Use this module to add the project's main training code.
"""
from .._compatibility import six
from .._logging import get_logger
from sklearn import svm
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler, scale
from sklearn.linear_model import LogisticRegression
from marvin_python_toolbox.engine_base import EngineBaseTraining
__all__ = ['Trainer']
logger = get_logger('trainer')
class Trainer(EngineBaseTraining):
    """Tune an SVM and a random-forest classifier with grid search.

    Reads ``X_train``, ``y_train`` and the cross-validation splitter ``sss``
    from ``self.marvin_dataset`` and stores both fitted ``GridSearchCV``
    objects in ``self.marvin_model`` under the keys ``'svm'`` and ``'rf'``.
    """

    def __init__(self, **kwargs):
        super(Trainer, self).__init__(**kwargs)

    def _run_grid_search(self, estimator, param_grid, model_type, **grid_kwargs):
        """Fit a ``GridSearchCV`` over *estimator* and print the best result.

        :param estimator: unfitted scikit-learn estimator to tune.
        :param param_grid: parameter grid handed to ``GridSearchCV``.
        :param model_type: short label used in the printed report.
        :param grid_kwargs: extra ``GridSearchCV`` options (e.g. ``n_jobs``);
            omitted options keep the library defaults.
        :return: the fitted ``GridSearchCV`` instance.
        """
        grid = GridSearchCV(
            estimator=estimator,
            param_grid=param_grid,
            cv=self.marvin_dataset["sss"],
            **grid_kwargs
        )
        grid.fit(
            self.marvin_dataset['X_train'],
            self.marvin_dataset['y_train']
        )
        print("Model Type: {}\n{}".format(model_type, grid.best_estimator_.get_params()))
        # best_score_ is a fraction in [0, 1]; scale it so the "%" suffix is truthful
        # (the previous code printed e.g. "0.9532%" for a 95.32% score).
        print("Accuracy Score: {:.2f}%".format(grid.best_score_ * 100))
        return grid

    def execute(self, params, **kwargs):
        """Grid-search an SVM and a random forest and store both searches.

        :param params: mapping with ``"svm"`` and ``"rf"`` entries, each used
            as the ``param_grid`` of the corresponding ``GridSearchCV``.
        :param kwargs: extra keyword arguments; ignored by this method.
        """
        print("\n\nStarting grid search using SVM!")
        # Only the SVM search runs in parallel across all cores, as before.
        svm_grid = self._run_grid_search(svm.SVC(), params["svm"], "SVM", n_jobs=-1)

        print("\n\nStarting grid search using RandomForestClassifier!")
        rf_grid = self._run_grid_search(RandomForestClassifier(), params["rf"], "RF")

        self.marvin_model = {
            'svm': svm_grid,
            'rf': rf_grid
        }