{
"cells": [
{
"cell_type": "markdown",
"id": "47ed8323-e689-464c-83ec-1ee98d2c2585",
"metadata": {},
"source": [
"# Hamilton for ML dataflows\n",
"\n",
"#### Requirements:\n",
"\n",
"- Install dependencies (listed in `requirements.txt`)\n",
"\n",
"More details [here](https://github.com/DAGWorks-Inc/hamilton/blob/main/examples/model_examples/scikit-learn/README.md#using-hamilton-for-ml-dataflows).\n",
"\n",
"***\n",
"\n",
"Uncomment and run the cell below if you are in a Google Colab environment. It will:\n",
"1. Mount google drive. You will be asked to authenticate and give permissions.\n",
"2. Change directory to google drive.\n",
"3. Make a directory \"hamilton-tutorials\"\n",
"4. Change directory to it.\n",
"5. Clone this repository to your google drive\n",
"6. Move your current directory to the hello_world example\n",
"7. Install requirements.\n",
"\n",
"This means that any modifications will be saved, and you won't lose them if you close your browser."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e12e1c-a8b2-477a-a9ff-6257ab587734",
"metadata": {},
"outputs": [],
"source": [
"## 1. Mount google drive\n",
"# from google.colab import drive\n",
"# drive.mount('/content/drive')\n",
"## 2. Change directory to google drive.\n",
"# %cd /content/drive/MyDrive\n",
"## 3. Make a directory \"hamilton-tutorials\"\n",
"# !mkdir hamilton-tutorials\n",
"## 4. Change directory to it.\n",
"# %cd hamilton-tutorials\n",
"## 5. Clone this repository to your google drive\n",
"# !git clone https://github.com/DAGWorks-Inc/hamilton/\n",
"## 6. Move your current directory to the hello_world example\n",
"# %cd hamilton/examples/hello_world\n",
"## 7. Install requirements.\n",
"# %pip install -r requirements.txt\n",
"# clear_output() # optionally clear outputs\n",
"# To check your current working directory you can type `!pwd` in a cell and run it."
]
},
{
"cell_type": "markdown",
"id": "9115ca99-cb3b-4dc3-8218-fa26b00d2199",
"metadata": {},
"source": [
"***\n",
"Here we have a simple example showing how you can write a ML training and evaluation workflow with Hamilton. \n",
"***"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "04fa1ff7-74f7-4193-9e1f-c17d9e68efc5",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Example script showing how one might setup a generic model training pipeline that is quickly configurable.\n",
"\"\"\"\n",
"\n",
"import digit_loader\n",
"import iris_loader\n",
"import my_train_evaluate_logic\n",
"\n",
"from hamilton import base, driver\n",
"\n",
"\n",
"def get_data_loader(data_set: str):\n",
" \"\"\"Returns the module to load that will procur data -- the data loaders all have to define the same functions.\"\"\"\n",
" if data_set == \"iris\":\n",
" return iris_loader\n",
" elif data_set == \"digits\":\n",
" return digit_loader\n",
" else:\n",
" raise ValueError(f\"Unknown data_name {data_set}.\")\n",
"\n",
"\n",
"def get_model_config(model_type: str) -> dict:\n",
" \"\"\"Returns model type specific configuration\"\"\"\n",
" if model_type == \"svm\":\n",
" return {\"clf\": \"svm\", \"gamma\": 0.001}\n",
" elif model_type == \"logistic\":\n",
" return {\"clf\": \"logistic\", \"penalty\": \"l2\"}\n",
" else:\n",
" raise ValueError(f\"Unsupported model {model_type}.\")"
]
},
{
"cell_type": "markdown",
"id": "88ccbc7c-f265-47fa-a921-f26ac3ed7094",
"metadata": {},
"source": [
"***\n",
"For the purpose of this experiment, lets apply the following configuration:\n",
"\n",
"- `_data_set` = 'digits'\n",
"- `_model_type` = 'logistic'\n",
"\n",
"More details [here](https://github.com/DAGWorks-Inc/hamilton/blob/main/examples/model_examples/scikit-learn/README.md).\n",
"***"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9e5e7282-8286-4055-847f-adb168420da0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/dagworks-inc/hamilton#usage-analytics--data-privacy for details.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"classification_report :\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.99 0.99 91\n",
" 1 0.92 0.95 0.94 84\n",
" 2 0.98 1.00 0.99 83\n",
" 3 0.99 0.98 0.98 81\n",
" 4 0.95 0.99 0.97 95\n",
" 5 0.98 0.94 0.96 97\n",
" 6 0.97 0.98 0.97 85\n",
" 7 0.98 0.98 0.98 96\n",
" 8 0.91 0.90 0.91 96\n",
" 9 0.96 0.93 0.94 91\n",
"\n",
" accuracy 0.96 899\n",
" macro avg 0.96 0.96 0.96 899\n",
"weighted avg 0.96 0.96 0.96 899\n",
"\n",
"confusion_matrix :\n",
" [[90 0 0 0 1 0 0 0 0 0]\n",
" [ 0 80 0 0 1 0 1 0 2 0]\n",
" [ 0 0 83 0 0 0 0 0 0 0]\n",
" [ 0 0 0 79 0 0 0 1 0 1]\n",
" [ 0 1 0 0 94 0 0 0 0 0]\n",
" [ 0 1 0 1 1 91 0 1 0 2]\n",
" [ 0 0 0 0 0 0 83 0 2 0]\n",
" [ 0 0 0 0 1 0 0 94 0 1]\n",
" [ 0 5 2 0 0 1 2 0 86 0]\n",
" [ 0 0 0 0 1 1 0 0 4 85]]\n",
"fit_clf :\n",
" LogisticRegression()\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/flaviassantos/github/hamilton/venv/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
}
],
"source": [
"_data_set = 'digits' # the data set to load\n",
"_model_type = 'logistic' # the model type to fit and evaluate with\n",
"\n",
"dag_config = {\n",
" \"test_size_fraction\": 0.5,\n",
" \"shuffle_train_test_split\": True,\n",
"}\n",
"# augment config\n",
"dag_config.update(get_model_config(_model_type))\n",
"# get module with functions to load data\n",
"data_module = get_data_loader(_data_set)\n",
"# set the desired result container we want\n",
"adapter = base.DefaultAdapter()\n",
"\"\"\"\n",
"What's cool about this, is that by simply changing the `dag_config` and the `data_module` we can\n",
"reuse the logic in the `my_train_evaluate_logic` module very easily for different contexts and purposes if\n",
"want to setup a generic model fitting and prediction dataflow!\n",
"E.g. if we want to support a new data set, then we just need to add a new data loading module.\n",
"E.g. if we want to support a new model type, then we just need to add a single conditional function\n",
" to my_train_evaluate_logic.\n",
"\"\"\"\n",
"dr = driver.Driver(dag_config, data_module, my_train_evaluate_logic, adapter=adapter)\n",
"# ensure you have done \"pip install \"sf-hamilton[visualization]\"\" for the following to work:\n",
"# dr.visualize_execution(['classification_report', 'confusion_matrix', 'fit_clf'],\n",
"# f'./model_dag_{_data_set}_{_model_type}.dot', {\"format\": \"png\"})\n",
"results = dr.execute([\"classification_report\", \"confusion_matrix\", \"fit_clf\"])\n",
"for k, v in results.items():\n",
" print(k, \":\\n\", v)"
]
},
{
"cell_type": "markdown",
"id": "2035065c-c409-4c21-bd11-733e74623226",
"metadata": {},
"source": [
"***\n",
"Here is the graph of execution for the digits data set:\n",
"***"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "17522217-8a09-46da-8b8d-0ba97d278bdc",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 9.0.0 (20230911.1827)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"729pt\" height=\"548pt\"\n",
" viewBox=\"0.00 0.00 729.01 548.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 544)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-544 725.01,-544 725.01,4 -4,4\"/>\n",
"<!-- confusion_matrix -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>confusion_matrix</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"350.52,-36 238.52,-36 238.52,0 350.52,0 350.52,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"294.52\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">confusion_matrix</text>\n",
"</g>\n",
"<!-- classification_report -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>classification_report</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"200.65,-36 74.4,-36 74.4,0 200.65,0 200.65,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"137.52\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report</text>\n",
"</g>\n",
"<!-- y_train -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>y_train</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"414.52\" cy=\"-306\" rx=\"37.02\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"414.52\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_train</text>\n",
"</g>\n",
"<!-- fit_clf -->\n",
"<g id=\"node11\" class=\"node\">\n",
"<title>fit_clf</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"403.52,-252 349.52,-252 349.52,-216 403.52,-216 403.52,-252\"/>\n",
"<text text-anchor=\"middle\" x=\"376.52\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">fit_clf</text>\n",
"</g>\n",
"<!-- y_train&#45;&gt;fit_clf -->\n",
"<g id=\"edge18\" class=\"edge\">\n",
"<title>y_train&#45;&gt;fit_clf</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M405.52,-288.41C401.29,-280.62 396.14,-271.14 391.36,-262.33\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"394.56,-260.9 386.72,-253.78 388.41,-264.24 394.56,-260.9\"/>\n",
"</g>\n",
"<!-- y_test -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>y_test</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"154.52\" cy=\"-234\" rx=\"32.93\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"154.52\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test</text>\n",
"</g>\n",
"<!-- y_test_with_labels -->\n",
"<g id=\"node13\" class=\"node\">\n",
"<title>y_test_with_labels</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"97.52\" cy=\"-90\" rx=\"80.01\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"97.52\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test_with_labels</text>\n",
"</g>\n",
"<!-- y_test&#45;&gt;y_test_with_labels -->\n",
"<g id=\"edge20\" class=\"edge\">\n",
"<title>y_test&#45;&gt;y_test_with_labels</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M136.39,-218.82C125.44,-209.15 112.34,-195.35 105.52,-180 97.09,-161.02 95.29,-137.54 95.49,-119.47\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"98.98,-119.84 95.83,-109.72 91.99,-119.6 98.98,-119.84\"/>\n",
"</g>\n",
"<!-- feature_matrix -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>feature_matrix</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"436.52\" cy=\"-450\" rx=\"65.68\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"436.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">feature_matrix</text>\n",
"</g>\n",
"<!-- train_test_split_func -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>train_test_split_func</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-378\" rx=\"86.67\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"320.52\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">train_test_split_func</text>\n",
"</g>\n",
"<!-- feature_matrix&#45;&gt;train_test_split_func -->\n",
"<g id=\"edge12\" class=\"edge\">\n",
"<title>feature_matrix&#45;&gt;train_test_split_func</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M410.2,-433.12C394.67,-423.74 374.76,-411.73 357.65,-401.4\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"359.62,-398.5 349.25,-396.33 356,-404.5 359.62,-398.5\"/>\n",
"</g>\n",
"<!-- target -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>target</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-450\" rx=\"31.9\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"320.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">target</text>\n",
"</g>\n",
"<!-- target&#45;&gt;train_test_split_func -->\n",
"<g id=\"edge13\" class=\"edge\">\n",
"<title>target&#45;&gt;train_test_split_func</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M320.52,-431.7C320.52,-424.41 320.52,-415.73 320.52,-407.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"324.02,-407.62 320.52,-397.62 317.02,-407.62 324.02,-407.62\"/>\n",
"</g>\n",
"<!-- test_size_fraction -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>test_size_fraction</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"620.52\" cy=\"-450\" rx=\"100.48\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"620.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: test_size_fraction</text>\n",
"</g>\n",
"<!-- test_size_fraction&#45;&gt;train_test_split_func -->\n",
"<g id=\"edge14\" class=\"edge\">\n",
"<title>test_size_fraction&#45;&gt;train_test_split_func</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M560.89,-435.09C510.99,-423.44 439.8,-406.83 387.85,-394.71\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"388.75,-391.33 378.22,-392.46 387.16,-398.14 388.75,-391.33\"/>\n",
"</g>\n",
"<!-- X_test -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>X_test</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"277.52\" cy=\"-234\" rx=\"34.97\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"277.52\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_test</text>\n",
"</g>\n",
"<!-- predicted_output -->\n",
"<g id=\"node19\" class=\"node\">\n",
"<title>predicted_output</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"327.52\" cy=\"-162\" rx=\"73.36\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"327.52\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output</text>\n",
"</g>\n",
"<!-- X_test&#45;&gt;predicted_output -->\n",
"<g id=\"edge25\" class=\"edge\">\n",
"<title>X_test&#45;&gt;predicted_output</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M289.12,-216.76C294.94,-208.61 302.15,-198.53 308.73,-189.31\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"311.45,-191.53 314.41,-181.36 305.75,-187.46 311.45,-191.53\"/>\n",
"</g>\n",
"<!-- predicted_output_with_labels -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>predicted_output_with_labels</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"315.52\" cy=\"-90\" rx=\"120.45\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"315.52\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels</text>\n",
"</g>\n",
"<!-- predicted_output_with_labels&#45;&gt;confusion_matrix -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>predicted_output_with_labels&#45;&gt;confusion_matrix</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M310.33,-71.7C308.09,-64.24 305.42,-55.32 302.91,-46.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"306.32,-46.13 300.09,-37.55 299.61,-48.14 306.32,-46.13\"/>\n",
"</g>\n",
"<!-- predicted_output_with_labels&#45;&gt;classification_report -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>predicted_output_with_labels&#45;&gt;classification_report</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M274.24,-72.76C249.96,-63.22 218.95,-51.02 192.6,-40.66\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"194.01,-37.45 183.42,-37.05 191.44,-43.97 194.01,-37.45\"/>\n",
"</g>\n",
"<!-- train_test_split_func&#45;&gt;y_train -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>train_test_split_func&#45;&gt;y_train</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M342.8,-360.41C355.68,-350.82 372.02,-338.65 385.86,-328.35\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"387.81,-331.25 393.74,-322.47 383.63,-325.64 387.81,-331.25\"/>\n",
"</g>\n",
"<!-- train_test_split_func&#45;&gt;y_test -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>train_test_split_func&#45;&gt;y_test</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M300.83,-360.15C270.84,-334.5 213.71,-285.63 180.4,-257.14\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"182.71,-254.5 172.84,-250.66 178.16,-259.82 182.71,-254.5\"/>\n",
"</g>\n",
"<!-- train_test_split_func&#45;&gt;X_test -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>train_test_split_func&#45;&gt;X_test</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M299.32,-360.25C289.13,-350.83 277.97,-338.13 272.52,-324 265.07,-304.67 266.78,-281.21 270.11,-263.23\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"273.52,-264.04 272.2,-253.53 266.67,-262.57 273.52,-264.04\"/>\n",
"</g>\n",
"<!-- X_train -->\n",
"<g id=\"node16\" class=\"node\">\n",
"<title>X_train</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-306\" rx=\"39.07\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"320.52\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_train</text>\n",
"</g>\n",
"<!-- train_test_split_func&#45;&gt;X_train -->\n",
"<g id=\"edge23\" class=\"edge\">\n",
"<title>train_test_split_func&#45;&gt;X_train</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M320.52,-359.7C320.52,-352.41 320.52,-343.73 320.52,-335.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"324.02,-335.62 320.52,-325.62 317.02,-335.62 324.02,-335.62\"/>\n",
"</g>\n",
"<!-- fit_clf&#45;&gt;predicted_output -->\n",
"<g id=\"edge24\" class=\"edge\">\n",
"<title>fit_clf&#45;&gt;predicted_output</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M364.41,-215.7C358.83,-207.73 352.09,-198.1 345.91,-189.26\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"348.97,-187.53 340.36,-181.34 343.23,-191.54 348.97,-187.53\"/>\n",
"</g>\n",
"<!-- target_names -->\n",
"<g id=\"node12\" class=\"node\">\n",
"<title>target_names</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"175.52\" cy=\"-162\" rx=\"60.56\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"175.52\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">target_names</text>\n",
"</g>\n",
"<!-- target_names&#45;&gt;predicted_output_with_labels -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>target_names&#45;&gt;predicted_output_with_labels</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M205.56,-145.98C224.89,-136.32 250.34,-123.59 271.89,-112.82\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"273.37,-115.99 280.75,-108.39 270.24,-109.73 273.37,-115.99\"/>\n",
"</g>\n",
"<!-- target_names&#45;&gt;y_test_with_labels -->\n",
"<g id=\"edge21\" class=\"edge\">\n",
"<title>target_names&#45;&gt;y_test_with_labels</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M157.04,-144.41C147.47,-135.82 135.59,-125.16 124.97,-115.63\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"127.32,-113.04 117.54,-108.97 122.65,-118.25 127.32,-113.04\"/>\n",
"</g>\n",
"<!-- y_test_with_labels&#45;&gt;confusion_matrix -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>y_test_with_labels&#45;&gt;confusion_matrix</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M138.82,-74.33C166.68,-64.43 203.87,-51.21 234.92,-40.18\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"235.95,-43.53 244.2,-36.88 233.6,-36.93 235.95,-43.53\"/>\n",
"</g>\n",
"<!-- y_test_with_labels&#45;&gt;classification_report -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>y_test_with_labels&#45;&gt;classification_report</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M107.41,-71.7C111.87,-63.9 117.23,-54.51 122.19,-45.83\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"125.08,-47.84 127,-37.42 119,-44.36 125.08,-47.84\"/>\n",
"</g>\n",
"<!-- prefit_clf -->\n",
"<g id=\"node14\" class=\"node\">\n",
"<title>prefit_clf</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"514.52\" cy=\"-306\" rx=\"45.21\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"514.52\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">prefit_clf</text>\n",
"</g>\n",
"<!-- prefit_clf&#45;&gt;fit_clf -->\n",
"<g id=\"edge16\" class=\"edge\">\n",
"<title>prefit_clf&#45;&gt;fit_clf</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M487.26,-291.17C466.28,-280.53 437.05,-265.7 413.89,-253.96\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"415.58,-250.89 405.08,-249.49 412.42,-257.13 415.58,-250.89\"/>\n",
"</g>\n",
"<!-- penalty -->\n",
"<g id=\"node15\" class=\"node\">\n",
"<title>penalty</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"514.52\" cy=\"-378\" rx=\"62.61\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"514.52\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: penalty</text>\n",
"</g>\n",
"<!-- penalty&#45;&gt;prefit_clf -->\n",
"<g id=\"edge22\" class=\"edge\">\n",
"<title>penalty&#45;&gt;prefit_clf</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M514.52,-359.7C514.52,-352.41 514.52,-343.73 514.52,-335.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"518.02,-335.62 514.52,-325.62 511.02,-335.62 518.02,-335.62\"/>\n",
"</g>\n",
"<!-- X_train&#45;&gt;fit_clf -->\n",
"<g id=\"edge17\" class=\"edge\">\n",
"<title>X_train&#45;&gt;fit_clf</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M333.51,-288.76C340.1,-280.53 348.27,-270.32 355.71,-261.02\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"358.37,-263.3 361.88,-253.31 352.9,-258.93 358.37,-263.3\"/>\n",
"</g>\n",
"<!-- shuffle_train_test_split -->\n",
"<g id=\"node17\" class=\"node\">\n",
"<title>shuffle_train_test_split</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"150.52\" cy=\"-450\" rx=\"120.45\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"150.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: shuffle_train_test_split</text>\n",
"</g>\n",
"<!-- shuffle_train_test_split&#45;&gt;train_test_split_func -->\n",
"<g id=\"edge15\" class=\"edge\">\n",
"<title>shuffle_train_test_split&#45;&gt;train_test_split_func</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M190.38,-432.59C214.88,-422.5 246.42,-409.51 272.32,-398.85\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"273.35,-402.21 281.27,-395.16 270.69,-395.74 273.35,-402.21\"/>\n",
"</g>\n",
"<!-- digit_data -->\n",
"<g id=\"node18\" class=\"node\">\n",
"<title>digit_data</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-522\" rx=\"47.77\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"320.52\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">digit_data</text>\n",
"</g>\n",
"<!-- digit_data&#45;&gt;feature_matrix -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>digit_data&#45;&gt;feature_matrix</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M345.12,-506.15C361.29,-496.4 382.69,-483.48 400.73,-472.6\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"402.3,-475.74 409.05,-467.58 398.68,-469.75 402.3,-475.74\"/>\n",
"</g>\n",
"<!-- digit_data&#45;&gt;target -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>digit_data&#45;&gt;target</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M320.52,-503.7C320.52,-496.41 320.52,-487.73 320.52,-479.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"324.02,-479.62 320.52,-469.62 317.02,-479.62 324.02,-479.62\"/>\n",
"</g>\n",
"<!-- digit_data&#45;&gt;target_names -->\n",
"<g id=\"edge19\" class=\"edge\">\n",
"<title>digit_data&#45;&gt;target_names</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M272.99,-518.88C197.89,-514.54 56.7,-501.97 21.52,-468 -7.57,-439.9 2.52,-419.45 2.52,-379 2.52,-379 2.52,-379 2.52,-305 2.52,-241.15 73.79,-200.77 124.85,-180.03\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"125.95,-183.35 133.99,-176.45 123.4,-176.83 125.95,-183.35\"/>\n",
"</g>\n",
"<!-- predicted_output&#45;&gt;predicted_output_with_labels -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>predicted_output&#45;&gt;predicted_output_with_labels</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M324.56,-143.7C323.29,-136.32 321.79,-127.52 320.37,-119.25\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"323.85,-118.86 318.71,-109.6 316.95,-120.04 323.85,-118.86\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x129594590>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dr.visualize_execution(['classification_report', 'confusion_matrix', 'fit_clf'],\n",
" f'./model_dag_{_data_set}_{_model_type}.dot', {\"format\": \"png\"})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "hamilton",
"language": "python",
"name": "hamilton"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}