blob: 6c4e056dead441da279d71bb45aef645b6e0df7e [file]
"""This module contains specific incarnations of models."""
from sklearn import base
from hamilton.function_modifiers import source, subdag
try:
import model_fitting
except ImportError:
from . import model_fitting
@subdag(
model_fitting,
inputs={
"data_set": source("data_set_v1"),
},
config={"clf": "svm", "shuffle_train_test_split": True, "test_size_fraction": 0.2},
)
def svm_model(
fit_clf: base.ClassifierMixin, training_accuracy: float, testing_accuracy: float
) -> dict:
return {
"svm": fit_clf,
"training_accuracy": training_accuracy,
"testing_accuracy": testing_accuracy,
}
@subdag(
model_fitting,
inputs={
"data_set": source("data_set_v1"),
},
config={
"clf": "logistic",
"shuffle_train_test_split": True,
"test_size_fraction": 0.2,
"penalty": "l2",
},
)
def lr_model(
fit_clf: base.ClassifierMixin, training_accuracy: float, testing_accuracy: float
) -> dict:
return {
"logistic": fit_clf,
"training_accuracy": training_accuracy,
"testing_accuracy": testing_accuracy,
}
def best_model(svm_model: dict, lr_model: dict) -> dict:
"""Returns the best model based on the testing accuracy."""
if svm_model["testing_accuracy"] > lr_model["testing_accuracy"]:
return svm_model
else:
return lr_model