"""
Example script showing how one might setup a generic model training pipeline that is quickly configurable.
"""
import digit_loader
import iris_loader
import my_train_evaluate_logic

from hamilton import base, driver


def get_data_loader(data_set: str):
    """Returns the module to load that will procure data -- the data loaders all have to define the same functions."""
    if data_set == "iris":
        return iris_loader
    elif data_set == "digits":
        return digit_loader
    else:
        raise ValueError(f"Unknown data_set {data_set}.")
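

# For illustration only -- a hypothetical sketch of what a data loading module could look like
# (the real iris_loader/digit_loader may differ; the key point is that every loader exposes the
# same function names, so my_train_evaluate_logic can depend on them interchangeably):
#
#     from sklearn import datasets, utils
#
#     def data_source() -> utils.Bunch:
#         """Loads the raw scikit-learn dataset for this data set."""
#         return datasets.load_iris()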


def get_model_config(model_type: str) -> dict:
    """Returns model type specific configuration"""
    if model_type == "svm":
        return {"clf": "svm", "gamma": 0.001}
    elif model_type == "logistic":
        return {"clf": "logistic", "penalty": "l2"}
    else:
        raise ValueError(f"Unsupported model {model_type}.")
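

# Following the same pattern, supporting another model type would just mean adding a branch above,
# e.g. (hypothetical model type and values):
#
#     elif model_type == "random_forest":
#         return {"clf": "random_forest", "n_estimators": 100}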
if __name__ == "__main__":
import sys
if len(sys.argv) < 3:
print("Error: required arguments are [iris|digits] [svm|logistic]")
sys.exit(1)
_data_set = sys.argv[1] # the data set to load
_model_type = sys.argv[2] # the model type to fit and evaluate with
    dag_config = {
        "test_size_fraction": 0.5,
        "shuffle_train_test_split": True,
    }
    # augment config
    dag_config.update(get_model_config(_model_type))
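    # e.g. running with "svm", dag_config now holds:
    # {"test_size_fraction": 0.5, "shuffle_train_test_split": True, "clf": "svm", "gamma": 0.001}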
    # get module with functions to load data
    data_module = get_data_loader(_data_set)
    # set the desired result container we want
    adapter = base.DefaultAdapter()
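    # The adapter determines the shape of the results container; DefaultAdapter returns a plain dict.
    # A hypothetical swap if a dataframe result were desired instead (assuming this adapter is
    # available in your Hamilton version):
    #     adapter = base.SimplePythonDataFrameGraphAdapter()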
"""
What's cool about this, is that by simply changing the `dag_config` and the `data_module` we can
reuse the logic in the `my_train_evaluate_logic` module very easily for different contexts and purposes if
want to setup a generic model fitting and prediction dataflow!
E.g. if we want to support a new data set, then we just need to add a new data loading module.
E.g. if we want to support a new model type, then we just need to add a single conditional function
to my_train_evaluate_logic.
"""
    dr = driver.Driver(dag_config, data_module, my_train_evaluate_logic, adapter=adapter)
    # ensure you have run `pip install "sf-hamilton[visualization]"` for the following to work:
    # dr.visualize_execution(['classification_report', 'confusion_matrix', 'fit_clf'],
    #                        f'./model_dag_{_data_set}_{_model_type}.dot', {"format": "png"})
results = dr.execute(["classification_report", "confusion_matrix", "fit_clf"])
for k, v in results.items():
print(k, ":\n", v)
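
# Example invocation (script filename assumed):
#   > python run.py iris svm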