| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "47ed8323-e689-464c-83ec-1ee98d2c2585", |
| "metadata": {}, |
| "source": [ |
| "# Hamilton for ML dataflows\n", |
| "\n", |
| "#### Requirements:\n", |
| "\n", |
| "- Install dependencies (listed in `requirements.txt`)\n", |
| "\n", |
| "More details [here](https://github.com/DAGWorks-Inc/hamilton/blob/main/examples/model_examples/scikit-learn/README.md#using-hamilton-for-ml-dataflows).\n", |
| "\n", |
| "***\n", |
| "\n", |
| "Uncomment and run the cell below if you are in a Google Colab environment. It will:\n", |
| "1. Mount Google Drive. You will be asked to authenticate and grant permissions.\n", |
| "2. Change directory to Google Drive.\n", |
| "3. Create a directory named `hamilton-tutorials`.\n", |
| "4. Change directory to it.\n", |
| "5. Clone this repository to your Google Drive.\n", |
| "6. Change directory to the scikit-learn model example.\n", |
| "7. Install requirements.\n", |
| "\n", |
| "Because the repository is cloned to your Google Drive, any modifications are saved and you won't lose them if you close your browser." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "d5e12e1c-a8b2-477a-a9ff-6257ab587734", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "## 1. Mount google drive\n", |
| "# from google.colab import drive\n", |
| "# drive.mount('/content/drive')\n", |
| "## 2. Change directory to google drive.\n", |
| "# %cd /content/drive/MyDrive\n", |
| "## 3. Make a directory \"hamilton-tutorials\"\n", |
| "# !mkdir hamilton-tutorials\n", |
| "## 4. Change directory to it.\n", |
| "# %cd hamilton-tutorials\n", |
| "## 5. Clone this repository to your google drive\n", |
| "# !git clone https://github.com/DAGWorks-Inc/hamilton/\n", |
| "## 6. Move your current directory to the scikit-learn model example\n", |
| "# %cd hamilton/examples/model_examples/scikit-learn\n", |
| "## 7. Install requirements.\n", |
| "# %pip install -r requirements.txt\n", |
| "# from IPython.display import clear_output; clear_output()  # optionally clear outputs\n", |
| "# To check your current working directory you can type `!pwd` in a cell and run it." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "9115ca99-cb3b-4dc3-8218-fa26b00d2199", |
| "metadata": {}, |
| "source": [ |
| "***\n", |
| "Here is a simple example showing how you can write an ML training and evaluation workflow with Hamilton.\n", |
| "***" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "id": "04fa1ff7-74f7-4193-9e1f-c17d9e68efc5", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "\"\"\"\n", |
| "Example script showing how one might set up a generic model training pipeline that is quickly configurable.\n", |
| "\"\"\"\n", |
| "\n", |
| "import digit_loader\n", |
| "import iris_loader\n", |
| "import my_train_evaluate_logic\n", |
| "\n", |
| "from hamilton import base, driver\n", |
| "\n", |
| "\n", |
| "def get_data_loader(data_set: str):\n", |
| "    \"\"\"Returns the module that loads the data. All data loaders must define the same functions.\"\"\"\n", |
| "    if data_set == \"iris\":\n", |
| "        return iris_loader\n", |
| "    elif data_set == \"digits\":\n", |
| "        return digit_loader\n", |
| "    else:\n", |
| "        raise ValueError(f\"Unknown data_set {data_set}.\")\n", |
| "\n", |
| "\n", |
| "def get_model_config(model_type: str) -> dict:\n", |
| "    \"\"\"Returns model-type-specific configuration.\"\"\"\n", |
| "    if model_type == \"svm\":\n", |
| "        return {\"clf\": \"svm\", \"gamma\": 0.001}\n", |
| "    elif model_type == \"logistic\":\n", |
| "        return {\"clf\": \"logistic\", \"penalty\": \"l2\"}\n", |
| "    else:\n", |
| "        raise ValueError(f\"Unsupported model {model_type}.\")" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "88ccbc7c-f265-47fa-a921-f26ac3ed7094", |
| "metadata": {}, |
| "source": [ |
| "***\n", |
| "For this experiment, let's apply the following configuration:\n", |
| "\n", |
| "- `_data_set` = 'digits'\n", |
| "- `_model_type` = 'logistic'\n", |
| "\n", |
| "More details [here](https://github.com/DAGWorks-Inc/hamilton/blob/main/examples/model_examples/scikit-learn/README.md).\n", |
| "***" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "id": "9e5e7282-8286-4055-847f-adb168420da0", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stderr", |
| "output_type": "stream", |
| "text": [ |
| "Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/dagworks-inc/hamilton#usage-analytics--data-privacy for details.\n" |
| ] |
| }, |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "classification_report :\n", |
| "               precision    recall  f1-score   support\n", |
| "\n", |
| "           0       1.00      0.99      0.99        91\n", |
| "           1       0.92      0.95      0.94        84\n", |
| "           2       0.98      1.00      0.99        83\n", |
| "           3       0.99      0.98      0.98        81\n", |
| "           4       0.95      0.99      0.97        95\n", |
| "           5       0.98      0.94      0.96        97\n", |
| "           6       0.97      0.98      0.97        85\n", |
| "           7       0.98      0.98      0.98        96\n", |
| "           8       0.91      0.90      0.91        96\n", |
| "           9       0.96      0.93      0.94        91\n", |
| "\n", |
| "    accuracy                           0.96       899\n", |
| "   macro avg       0.96      0.96      0.96       899\n", |
| "weighted avg       0.96      0.96      0.96       899\n", |
| "\n", |
| "confusion_matrix :\n", |
| " [[90  0  0  0  1  0  0  0  0  0]\n", |
| " [ 0 80  0  0  1  0  1  0  2  0]\n", |
| " [ 0  0 83  0  0  0  0  0  0  0]\n", |
| " [ 0  0  0 79  0  0  0  1  0  1]\n", |
| " [ 0  1  0  0 94  0  0  0  0  0]\n", |
| " [ 0  1  0  1  1 91  0  1  0  2]\n", |
| " [ 0  0  0  0  0  0 83  0  2  0]\n", |
| " [ 0  0  0  0  1  0  0 94  0  1]\n", |
| " [ 0  5  2  0  0  1  2  0 86  0]\n", |
| " [ 0  0  0  0  1  1  0  0  4 85]]\n", |
| "fit_clf :\n", |
| " LogisticRegression()\n" |
| ] |
| }, |
| { |
| "name": "stderr", |
| "output_type": "stream", |
| "text": [ |
| "/Users/flaviassantos/github/hamilton/venv/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n", |
| "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", |
| "\n", |
| "Increase the number of iterations (max_iter) or scale the data as shown in:\n", |
| " https://scikit-learn.org/stable/modules/preprocessing.html\n", |
| "Please also refer to the documentation for alternative solver options:\n", |
| " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", |
| " n_iter_i = _check_optimize_result(\n" |
| ] |
| } |
| ], |
| "source": [ |
| "_data_set = 'digits' # the data set to load\n", |
| "_model_type = 'logistic' # the model type to fit and evaluate with\n", |
| "\n", |
| "dag_config = {\n", |
| " \"test_size_fraction\": 0.5,\n", |
| " \"shuffle_train_test_split\": True,\n", |
| "}\n", |
| "# augment config\n", |
| "dag_config.update(get_model_config(_model_type))\n", |
| "# get module with functions to load data\n", |
| "data_module = get_data_loader(_data_set)\n", |
| "# set the desired result container we want\n", |
| "adapter = base.DefaultAdapter()\n", |
| "\"\"\"\n", |
| "What's cool about this is that, by simply changing the `dag_config` and the `data_module`, we can\n", |
| "reuse the logic in the `my_train_evaluate_logic` module for different contexts and purposes\n", |
| "whenever we want to set up a generic model fitting and prediction dataflow!\n", |
| "E.g. to support a new data set, we just need to add a new data loading module.\n", |
| "E.g. to support a new model type, we just need to add a single conditional function\n", |
| " to `my_train_evaluate_logic`.\n", |
| "\"\"\"\n", |
| "dr = driver.Driver(dag_config, data_module, my_train_evaluate_logic, adapter=adapter)\n", |
| "# Run `pip install \"sf-hamilton[visualization]\"` first for the following to work:\n", |
| "# dr.visualize_execution(['classification_report', 'confusion_matrix', 'fit_clf'],\n", |
| "# f'./model_dag_{_data_set}_{_model_type}.dot', {\"format\": \"png\"})\n", |
| "results = dr.execute([\"classification_report\", \"confusion_matrix\", \"fit_clf\"])\n", |
| "for k, v in results.items():\n", |
| " print(k, \":\\n\", v)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "2035065c-c409-4c21-bd11-733e74623226", |
| "metadata": {}, |
| "source": [ |
| "***\n", |
| "Here is the graph of execution for the digits data set:\n", |
| "***" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "id": "17522217-8a09-46da-8b8d-0ba97d278bdc", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 9.0.0 (20230911.1827)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"729pt\" height=\"548pt\"\n", |
| " viewBox=\"0.00 0.00 729.01 548.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 544)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-544 725.01,-544 725.01,4 -4,4\"/>\n", |
| "<!-- confusion_matrix -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>confusion_matrix</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"350.52,-36 238.52,-36 238.52,0 350.52,0 350.52,-36\"/>\n", |
| "<text text-anchor=\"middle\" x=\"294.52\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">confusion_matrix</text>\n", |
| "</g>\n", |
| "<!-- classification_report -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>classification_report</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"200.65,-36 74.4,-36 74.4,0 200.65,0 200.65,-36\"/>\n", |
| "<text text-anchor=\"middle\" x=\"137.52\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report</text>\n", |
| "</g>\n", |
| "<!-- y_train -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>y_train</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"414.52\" cy=\"-306\" rx=\"37.02\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"414.52\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_train</text>\n", |
| "</g>\n", |
| "<!-- fit_clf -->\n", |
| "<g id=\"node11\" class=\"node\">\n", |
| "<title>fit_clf</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"403.52,-252 349.52,-252 349.52,-216 403.52,-216 403.52,-252\"/>\n", |
| "<text text-anchor=\"middle\" x=\"376.52\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">fit_clf</text>\n", |
| "</g>\n", |
| "<!-- y_train->fit_clf -->\n", |
| "<g id=\"edge18\" class=\"edge\">\n", |
| "<title>y_train->fit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M405.52,-288.41C401.29,-280.62 396.14,-271.14 391.36,-262.33\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"394.56,-260.9 386.72,-253.78 388.41,-264.24 394.56,-260.9\"/>\n", |
| "</g>\n", |
| "<!-- y_test -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>y_test</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"154.52\" cy=\"-234\" rx=\"32.93\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"154.52\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test</text>\n", |
| "</g>\n", |
| "<!-- y_test_with_labels -->\n", |
| "<g id=\"node13\" class=\"node\">\n", |
| "<title>y_test_with_labels</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"97.52\" cy=\"-90\" rx=\"80.01\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"97.52\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test_with_labels</text>\n", |
| "</g>\n", |
| "<!-- y_test->y_test_with_labels -->\n", |
| "<g id=\"edge20\" class=\"edge\">\n", |
| "<title>y_test->y_test_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M136.39,-218.82C125.44,-209.15 112.34,-195.35 105.52,-180 97.09,-161.02 95.29,-137.54 95.49,-119.47\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"98.98,-119.84 95.83,-109.72 91.99,-119.6 98.98,-119.84\"/>\n", |
| "</g>\n", |
| "<!-- feature_matrix -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>feature_matrix</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"436.52\" cy=\"-450\" rx=\"65.68\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"436.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">feature_matrix</text>\n", |
| "</g>\n", |
| "<!-- train_test_split_func -->\n", |
| "<g id=\"node10\" class=\"node\">\n", |
| "<title>train_test_split_func</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-378\" rx=\"86.67\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"320.52\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">train_test_split_func</text>\n", |
| "</g>\n", |
| "<!-- feature_matrix->train_test_split_func -->\n", |
| "<g id=\"edge12\" class=\"edge\">\n", |
| "<title>feature_matrix->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M410.2,-433.12C394.67,-423.74 374.76,-411.73 357.65,-401.4\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"359.62,-398.5 349.25,-396.33 356,-404.5 359.62,-398.5\"/>\n", |
| "</g>\n", |
| "<!-- target -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>target</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-450\" rx=\"31.9\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"320.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">target</text>\n", |
| "</g>\n", |
| "<!-- target->train_test_split_func -->\n", |
| "<g id=\"edge13\" class=\"edge\">\n", |
| "<title>target->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M320.52,-431.7C320.52,-424.41 320.52,-415.73 320.52,-407.54\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"324.02,-407.62 320.52,-397.62 317.02,-407.62 324.02,-407.62\"/>\n", |
| "</g>\n", |
| "<!-- test_size_fraction -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>test_size_fraction</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"620.52\" cy=\"-450\" rx=\"100.48\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"620.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: test_size_fraction</text>\n", |
| "</g>\n", |
| "<!-- test_size_fraction->train_test_split_func -->\n", |
| "<g id=\"edge14\" class=\"edge\">\n", |
| "<title>test_size_fraction->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M560.89,-435.09C510.99,-423.44 439.8,-406.83 387.85,-394.71\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"388.75,-391.33 378.22,-392.46 387.16,-398.14 388.75,-391.33\"/>\n", |
| "</g>\n", |
| "<!-- X_test -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>X_test</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"277.52\" cy=\"-234\" rx=\"34.97\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"277.52\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_test</text>\n", |
| "</g>\n", |
| "<!-- predicted_output -->\n", |
| "<g id=\"node19\" class=\"node\">\n", |
| "<title>predicted_output</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"327.52\" cy=\"-162\" rx=\"73.36\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"327.52\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output</text>\n", |
| "</g>\n", |
| "<!-- X_test->predicted_output -->\n", |
| "<g id=\"edge25\" class=\"edge\">\n", |
| "<title>X_test->predicted_output</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M289.12,-216.76C294.94,-208.61 302.15,-198.53 308.73,-189.31\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"311.45,-191.53 314.41,-181.36 305.75,-187.46 311.45,-191.53\"/>\n", |
| "</g>\n", |
| "<!-- predicted_output_with_labels -->\n", |
| "<g id=\"node9\" class=\"node\">\n", |
| "<title>predicted_output_with_labels</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"315.52\" cy=\"-90\" rx=\"120.45\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"315.52\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels</text>\n", |
| "</g>\n", |
| "<!-- predicted_output_with_labels->confusion_matrix -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>predicted_output_with_labels->confusion_matrix</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M310.33,-71.7C308.09,-64.24 305.42,-55.32 302.91,-46.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"306.32,-46.13 300.09,-37.55 299.61,-48.14 306.32,-46.13\"/>\n", |
| "</g>\n", |
| "<!-- predicted_output_with_labels->classification_report -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>predicted_output_with_labels->classification_report</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M274.24,-72.76C249.96,-63.22 218.95,-51.02 192.6,-40.66\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"194.01,-37.45 183.42,-37.05 191.44,-43.97 194.01,-37.45\"/>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->y_train -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>train_test_split_func->y_train</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M342.8,-360.41C355.68,-350.82 372.02,-338.65 385.86,-328.35\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"387.81,-331.25 393.74,-322.47 383.63,-325.64 387.81,-331.25\"/>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->y_test -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>train_test_split_func->y_test</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M300.83,-360.15C270.84,-334.5 213.71,-285.63 180.4,-257.14\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"182.71,-254.5 172.84,-250.66 178.16,-259.82 182.71,-254.5\"/>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->X_test -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>train_test_split_func->X_test</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M299.32,-360.25C289.13,-350.83 277.97,-338.13 272.52,-324 265.07,-304.67 266.78,-281.21 270.11,-263.23\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"273.52,-264.04 272.2,-253.53 266.67,-262.57 273.52,-264.04\"/>\n", |
| "</g>\n", |
| "<!-- X_train -->\n", |
| "<g id=\"node16\" class=\"node\">\n", |
| "<title>X_train</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-306\" rx=\"39.07\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"320.52\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_train</text>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->X_train -->\n", |
| "<g id=\"edge23\" class=\"edge\">\n", |
| "<title>train_test_split_func->X_train</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M320.52,-359.7C320.52,-352.41 320.52,-343.73 320.52,-335.54\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"324.02,-335.62 320.52,-325.62 317.02,-335.62 324.02,-335.62\"/>\n", |
| "</g>\n", |
| "<!-- fit_clf->predicted_output -->\n", |
| "<g id=\"edge24\" class=\"edge\">\n", |
| "<title>fit_clf->predicted_output</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M364.41,-215.7C358.83,-207.73 352.09,-198.1 345.91,-189.26\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"348.97,-187.53 340.36,-181.34 343.23,-191.54 348.97,-187.53\"/>\n", |
| "</g>\n", |
| "<!-- target_names -->\n", |
| "<g id=\"node12\" class=\"node\">\n", |
| "<title>target_names</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"175.52\" cy=\"-162\" rx=\"60.56\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"175.52\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">target_names</text>\n", |
| "</g>\n", |
| "<!-- target_names->predicted_output_with_labels -->\n", |
| "<g id=\"edge11\" class=\"edge\">\n", |
| "<title>target_names->predicted_output_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M205.56,-145.98C224.89,-136.32 250.34,-123.59 271.89,-112.82\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"273.37,-115.99 280.75,-108.39 270.24,-109.73 273.37,-115.99\"/>\n", |
| "</g>\n", |
| "<!-- target_names->y_test_with_labels -->\n", |
| "<g id=\"edge21\" class=\"edge\">\n", |
| "<title>target_names->y_test_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M157.04,-144.41C147.47,-135.82 135.59,-125.16 124.97,-115.63\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"127.32,-113.04 117.54,-108.97 122.65,-118.25 127.32,-113.04\"/>\n", |
| "</g>\n", |
| "<!-- y_test_with_labels->confusion_matrix -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>y_test_with_labels->confusion_matrix</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M138.82,-74.33C166.68,-64.43 203.87,-51.21 234.92,-40.18\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"235.95,-43.53 244.2,-36.88 233.6,-36.93 235.95,-43.53\"/>\n", |
| "</g>\n", |
| "<!-- y_test_with_labels->classification_report -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>y_test_with_labels->classification_report</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M107.41,-71.7C111.87,-63.9 117.23,-54.51 122.19,-45.83\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"125.08,-47.84 127,-37.42 119,-44.36 125.08,-47.84\"/>\n", |
| "</g>\n", |
| "<!-- prefit_clf -->\n", |
| "<g id=\"node14\" class=\"node\">\n", |
| "<title>prefit_clf</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"514.52\" cy=\"-306\" rx=\"45.21\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"514.52\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">prefit_clf</text>\n", |
| "</g>\n", |
| "<!-- prefit_clf->fit_clf -->\n", |
| "<g id=\"edge16\" class=\"edge\">\n", |
| "<title>prefit_clf->fit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M487.26,-291.17C466.28,-280.53 437.05,-265.7 413.89,-253.96\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"415.58,-250.89 405.08,-249.49 412.42,-257.13 415.58,-250.89\"/>\n", |
| "</g>\n", |
| "<!-- penalty -->\n", |
| "<g id=\"node15\" class=\"node\">\n", |
| "<title>penalty</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"514.52\" cy=\"-378\" rx=\"62.61\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"514.52\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: penalty</text>\n", |
| "</g>\n", |
| "<!-- penalty->prefit_clf -->\n", |
| "<g id=\"edge22\" class=\"edge\">\n", |
| "<title>penalty->prefit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M514.52,-359.7C514.52,-352.41 514.52,-343.73 514.52,-335.54\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"518.02,-335.62 514.52,-325.62 511.02,-335.62 518.02,-335.62\"/>\n", |
| "</g>\n", |
| "<!-- X_train->fit_clf -->\n", |
| "<g id=\"edge17\" class=\"edge\">\n", |
| "<title>X_train->fit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M333.51,-288.76C340.1,-280.53 348.27,-270.32 355.71,-261.02\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"358.37,-263.3 361.88,-253.31 352.9,-258.93 358.37,-263.3\"/>\n", |
| "</g>\n", |
| "<!-- shuffle_train_test_split -->\n", |
| "<g id=\"node17\" class=\"node\">\n", |
| "<title>shuffle_train_test_split</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"150.52\" cy=\"-450\" rx=\"120.45\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"150.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: shuffle_train_test_split</text>\n", |
| "</g>\n", |
| "<!-- shuffle_train_test_split->train_test_split_func -->\n", |
| "<g id=\"edge15\" class=\"edge\">\n", |
| "<title>shuffle_train_test_split->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M190.38,-432.59C214.88,-422.5 246.42,-409.51 272.32,-398.85\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"273.35,-402.21 281.27,-395.16 270.69,-395.74 273.35,-402.21\"/>\n", |
| "</g>\n", |
| "<!-- digit_data -->\n", |
| "<g id=\"node18\" class=\"node\">\n", |
| "<title>digit_data</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-522\" rx=\"47.77\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"320.52\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">digit_data</text>\n", |
| "</g>\n", |
| "<!-- digit_data->feature_matrix -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>digit_data->feature_matrix</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M345.12,-506.15C361.29,-496.4 382.69,-483.48 400.73,-472.6\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"402.3,-475.74 409.05,-467.58 398.68,-469.75 402.3,-475.74\"/>\n", |
| "</g>\n", |
| "<!-- digit_data->target -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>digit_data->target</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M320.52,-503.7C320.52,-496.41 320.52,-487.73 320.52,-479.54\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"324.02,-479.62 320.52,-469.62 317.02,-479.62 324.02,-479.62\"/>\n", |
| "</g>\n", |
| "<!-- digit_data->target_names -->\n", |
| "<g id=\"edge19\" class=\"edge\">\n", |
| "<title>digit_data->target_names</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M272.99,-518.88C197.89,-514.54 56.7,-501.97 21.52,-468 -7.57,-439.9 2.52,-419.45 2.52,-379 2.52,-379 2.52,-379 2.52,-305 2.52,-241.15 73.79,-200.77 124.85,-180.03\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"125.95,-183.35 133.99,-176.45 123.4,-176.83 125.95,-183.35\"/>\n", |
| "</g>\n", |
| "<!-- predicted_output->predicted_output_with_labels -->\n", |
| "<g id=\"edge10\" class=\"edge\">\n", |
| "<title>predicted_output->predicted_output_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M324.56,-143.7C323.29,-136.32 321.79,-127.52 320.37,-119.25\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"323.85,-118.86 318.71,-109.6 316.95,-120.04 323.85,-118.86\"/>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<graphviz.graphs.Digraph at 0x129594590>" |
| ] |
| }, |
| "execution_count": 5, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "dr.visualize_execution(['classification_report', 'confusion_matrix', 'fit_clf'],\n", |
| " f'./model_dag_{_data_set}_{_model_type}.dot', {\"format\": \"png\"})" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "hamilton", |
| "language": "python", |
| "name": "hamilton" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.11.3" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |