| { |
| "cells": [ |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "id": "7bf6a40d", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import json\n", |
| "import os\n", |
| "\n", |
| "import data_loaders\n", |
| "import model_training\n", |
| "\n", |
| "from hamilton import base, driver\n", |
| "from hamilton.io.materialization import to\n", |
| "import pandas as pd\n", |
| "\n", |
| "import custom_materializers" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "id": "7a449245", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "dag_config = {\n", |
| " \"test_size_fraction\": 0.5,\n", |
| " \"shuffle_train_test_split\": True,\n", |
| " \"data_loader\" : \"iris\",\n", |
| " \"clf\" : \"logistic\",\n", |
| " \"penalty\" : \"l2\"\n", |
| "}\n", |
| "dr = (\n", |
| " driver.Builder()\n", |
| " .with_adapter(base.DefaultAdapter())\n", |
| " .with_config(dag_config)\n", |
| " .with_modules(data_loaders, model_training)\n", |
| " .build()\n", |
| " )" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "id": "397b09bc", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 8.0.5 (20230430.1635)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"786pt\" height=\"620pt\"\n", |
| " viewBox=\"0.00 0.00 786.01 620.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 616)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-616 782.01,-616 782.01,4 -4,4\"/>\n", |
| "<!-- predicted_output_with_labels -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>predicted_output_with_labels</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"281.56\" cy=\"-162\" rx=\"120.45\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"281.56\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels</text>\n", |
| "</g>\n", |
| "<!-- predicted_output_with_labels_to_csv -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>predicted_output_with_labels_to_csv</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"432.81,-108 214.31,-108 214.31,-72 432.81,-72 432.81,-108\"/>\n", |
| "<text text-anchor=\"middle\" x=\"323.56\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels_to_csv</text>\n", |
| "</g>\n", |
| "<!-- predicted_output_with_labels->predicted_output_with_labels_to_csv -->\n", |
| "<g id=\"edge10\" class=\"edge\">\n", |
| "<title>predicted_output_with_labels->predicted_output_with_labels_to_csv</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M291.95,-143.7C296.73,-135.73 302.5,-126.1 307.81,-117.26\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"311.16,-119.48 313.3,-109.1 305.16,-115.88 311.16,-119.48\"/>\n", |
| "</g>\n", |
| "<!-- classification_report -->\n", |
| "<g id=\"node15\" class=\"node\">\n", |
| "<title>classification_report</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"196.69,-108 70.44,-108 70.44,-72 196.69,-72 196.69,-108\"/>\n", |
| "<text text-anchor=\"middle\" x=\"133.56\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report</text>\n", |
| "</g>\n", |
| "<!-- predicted_output_with_labels->classification_report -->\n", |
| "<g id=\"edge19\" class=\"edge\">\n", |
| "<title>predicted_output_with_labels->classification_report</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M246.49,-144.41C226.71,-135.06 201.75,-123.25 180.33,-113.12\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"181.99,-109.56 171.45,-108.45 178.99,-115.89 181.99,-109.56\"/>\n", |
| "</g>\n", |
| "<!-- train_test_split_func -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>train_test_split_func</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"371.56\" cy=\"-450\" rx=\"86.67\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"371.56\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">train_test_split_func</text>\n", |
| "</g>\n", |
| "<!-- y_test -->\n", |
| "<g id=\"node10\" class=\"node\">\n", |
| "<title>y_test</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"207.56\" cy=\"-378\" rx=\"32.93\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"207.56\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test</text>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->y_test -->\n", |
| "<g id=\"edge14\" class=\"edge\">\n", |
| "<title>train_test_split_func->y_test</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M334.76,-433.29C307.29,-421.57 269.99,-405.64 242.88,-394.08\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"244.61,-390.58 234.04,-389.87 241.86,-397.02 244.61,-390.58\"/>\n", |
| "</g>\n", |
| "<!-- X_train -->\n", |
| "<g id=\"node11\" class=\"node\">\n", |
| "<title>X_train</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"371.56\" cy=\"-378\" rx=\"39.07\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"371.56\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_train</text>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->X_train -->\n", |
| "<g id=\"edge15\" class=\"edge\">\n", |
| "<title>train_test_split_func->X_train</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M371.56,-431.7C371.56,-424.24 371.56,-415.32 371.56,-406.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"375.06,-407.1 371.56,-397.1 368.06,-407.1 375.06,-407.1\"/>\n", |
| "</g>\n", |
| "<!-- X_test -->\n", |
| "<g id=\"node17\" class=\"node\">\n", |
| "<title>X_test</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"303.56\" cy=\"-306\" rx=\"34.97\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"303.56\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_test</text>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->X_test -->\n", |
| "<g id=\"edge22\" class=\"edge\">\n", |
| "<title>train_test_split_func->X_test</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M352.09,-432.16C342.13,-422.53 330.65,-409.66 323.56,-396 313.71,-377 308.68,-353.25 306.14,-335.06\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"309.51,-334.79 304.82,-325.29 302.56,-335.64 309.51,-334.79\"/>\n", |
| "</g>\n", |
| "<!-- y_train -->\n", |
| "<g id=\"node22\" class=\"node\">\n", |
| "<title>y_train</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"465.56\" cy=\"-378\" rx=\"37.02\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"465.56\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_train</text>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->y_train -->\n", |
| "<g id=\"edge27\" class=\"edge\">\n", |
| "<title>train_test_split_func->y_train</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M393.84,-432.41C406.85,-422.73 423.38,-410.41 437.3,-400.05\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"439.07,-402.35 445,-393.57 434.89,-396.74 439.07,-402.35\"/>\n", |
| "</g>\n", |
| "<!-- fit_clf -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>fit_clf</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"467.56\" cy=\"-306\" rx=\"33.44\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"467.56\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">fit_clf</text>\n", |
| "</g>\n", |
| "<!-- predicted_output -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>predicted_output</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"300.56\" cy=\"-234\" rx=\"73.36\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"300.56\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output</text>\n", |
| "</g>\n", |
| "<!-- fit_clf->predicted_output -->\n", |
| "<g id=\"edge11\" class=\"edge\">\n", |
| "<title>fit_clf->predicted_output</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M441.84,-294.22C416.52,-283.6 377.3,-267.17 346.6,-254.3\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"348.13,-250.72 337.55,-250.09 345.42,-257.18 348.13,-250.72\"/>\n", |
| "</g>\n", |
| "<!-- clf_to_pickle -->\n", |
| "<g id=\"node13\" class=\"node\">\n", |
| "<title>clf_to_pickle</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"657.19,-252 569.94,-252 569.94,-216 657.19,-216 657.19,-252\"/>\n", |
| "<text text-anchor=\"middle\" x=\"613.56\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">clf_to_pickle</text>\n", |
| "</g>\n", |
| "<!-- fit_clf->clf_to_pickle -->\n", |
| "<g id=\"edge17\" class=\"edge\">\n", |
| "<title>fit_clf->clf_to_pickle</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M492,-293.28C512.57,-283.42 542.55,-269.05 567.59,-257.04\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"568.76,-259.88 576.27,-252.4 565.74,-253.57 568.76,-259.88\"/>\n", |
| "</g>\n", |
| "<!-- model_parameters -->\n", |
| "<g id=\"node21\" class=\"node\">\n", |
| "<title>model_parameters</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"471.56\" cy=\"-234\" rx=\"80.01\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"471.56\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">model_parameters</text>\n", |
| "</g>\n", |
| "<!-- fit_clf->model_parameters -->\n", |
| "<g id=\"edge26\" class=\"edge\">\n", |
| "<title>fit_clf->model_parameters</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M468.55,-287.7C468.98,-280.24 469.49,-271.32 469.97,-262.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"473.51,-263.29 470.59,-253.1 466.52,-262.89 473.51,-263.29\"/>\n", |
| "</g>\n", |
| "<!-- data -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>data</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"371.56\" cy=\"-594\" rx=\"27\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"371.56\" y=\"-588.95\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n", |
| "</g>\n", |
| "<!-- target_names -->\n", |
| "<g id=\"node9\" class=\"node\">\n", |
| "<title>target_names</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"60.56\" cy=\"-378\" rx=\"60.56\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"60.56\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">target_names</text>\n", |
| "</g>\n", |
| "<!-- data->target_names -->\n", |
| "<g id=\"edge13\" class=\"edge\">\n", |
| "<title>data->target_names</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M344.66,-592.1C279.25,-589.31 114.57,-578.52 78.56,-540 45.12,-504.21 48.85,-442.61 54.65,-406.87\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"58.24,-407.71 56.56,-397.25 51.35,-406.47 58.24,-407.71\"/>\n", |
| "</g>\n", |
| "<!-- feature_matrix -->\n", |
| "<g id=\"node14\" class=\"node\">\n", |
| "<title>feature_matrix</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"371.56\" cy=\"-522\" rx=\"65.68\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"371.56\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">feature_matrix</text>\n", |
| "</g>\n", |
| "<!-- data->feature_matrix -->\n", |
| "<g id=\"edge18\" class=\"edge\">\n", |
| "<title>data->feature_matrix</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M371.56,-575.7C371.56,-568.24 371.56,-559.32 371.56,-550.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"375.06,-551.1 371.56,-541.1 368.06,-551.1 375.06,-551.1\"/>\n", |
| "</g>\n", |
| "<!-- target -->\n", |
| "<g id=\"node23\" class=\"node\">\n", |
| "<title>target</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"487.56\" cy=\"-522\" rx=\"31.9\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"487.56\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">target</text>\n", |
| "</g>\n", |
| "<!-- data->target -->\n", |
| "<g id=\"edge28\" class=\"edge\">\n", |
| "<title>data->target</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M391.24,-581.13C409.21,-570.28 436.06,-554.08 456.85,-541.53\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"458.46,-544.05 465.21,-535.88 454.84,-538.05 458.46,-544.05\"/>\n", |
| "</g>\n", |
| "<!-- penalty -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>penalty</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"565.56\" cy=\"-450\" rx=\"62.61\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"565.56\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: penalty</text>\n", |
| "</g>\n", |
| "<!-- prefit_clf -->\n", |
| "<g id=\"node16\" class=\"node\">\n", |
| "<title>prefit_clf</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"565.56\" cy=\"-378\" rx=\"45.21\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"565.56\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">prefit_clf</text>\n", |
| "</g>\n", |
| "<!-- penalty->prefit_clf -->\n", |
| "<g id=\"edge21\" class=\"edge\">\n", |
| "<title>penalty->prefit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M565.56,-431.7C565.56,-424.24 565.56,-415.32 565.56,-406.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"569.06,-407.1 565.56,-397.1 562.06,-407.1 569.06,-407.1\"/>\n", |
| "</g>\n", |
| "<!-- predicted_output->predicted_output_with_labels -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>predicted_output->predicted_output_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M295.87,-215.7C293.82,-208.15 291.37,-199.12 289.08,-190.68\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"292.2,-189.84 286.21,-181.1 285.45,-191.67 292.2,-189.84\"/>\n", |
| "</g>\n", |
| "<!-- shuffle_train_test_split -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>shuffle_train_test_split</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"657.56\" cy=\"-522\" rx=\"120.45\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"657.56\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: shuffle_train_test_split</text>\n", |
| "</g>\n", |
| "<!-- shuffle_train_test_split->train_test_split_func -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>shuffle_train_test_split->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M596.56,-506.07C549.59,-494.57 484.78,-478.71 436.78,-466.96\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"437.74,-463.35 427.19,-464.37 436.07,-470.15 437.74,-463.35\"/>\n", |
| "</g>\n", |
| "<!-- target_names->predicted_output_with_labels -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>target_names->predicted_output_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M60.36,-359.69C60.96,-340.5 64.46,-309.44 79.56,-288 115.4,-237.11 179.19,-202.78 225.52,-183.17\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"226.51,-186.13 234.42,-179.09 223.84,-179.66 226.51,-186.13\"/>\n", |
| "</g>\n", |
| "<!-- y_test_with_labels -->\n", |
| "<g id=\"node19\" class=\"node\">\n", |
| "<title>y_test_with_labels</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"168.56\" cy=\"-306\" rx=\"80.01\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"168.56\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test_with_labels</text>\n", |
| "</g>\n", |
| "<!-- target_names->y_test_with_labels -->\n", |
| "<g id=\"edge25\" class=\"edge\">\n", |
| "<title>target_names->y_test_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M85.07,-361.12C99.64,-351.67 118.34,-339.56 134.34,-329.18\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"135.92,-331.68 142.4,-323.31 132.11,-325.81 135.92,-331.68\"/>\n", |
| "</g>\n", |
| "<!-- y_test->y_test_with_labels -->\n", |
| "<g id=\"edge24\" class=\"edge\">\n", |
| "<title>y_test->y_test_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M198.32,-360.41C193.82,-352.34 188.3,-342.43 183.25,-333.35\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"185.94,-331.99 178.01,-324.96 179.82,-335.4 185.94,-331.99\"/>\n", |
| "</g>\n", |
| "<!-- X_train->fit_clf -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>X_train->fit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M391.92,-362.15C405.82,-352.02 424.4,-338.47 439.67,-327.34\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"441.2,-329.83 447.22,-321.11 437.07,-324.17 441.2,-329.83\"/>\n", |
| "</g>\n", |
| "<!-- model_params_to_json -->\n", |
| "<g id=\"node12\" class=\"node\">\n", |
| "<title>model_params_to_json</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"563.31,-180 419.81,-180 419.81,-144 563.31,-144 563.31,-180\"/>\n", |
| "<text text-anchor=\"middle\" x=\"491.56\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">model_params_to_json</text>\n", |
| "</g>\n", |
| "<!-- feature_matrix->train_test_split_func -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>feature_matrix->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M371.56,-503.7C371.56,-496.24 371.56,-487.32 371.56,-478.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"375.06,-479.1 371.56,-469.1 368.06,-479.1 375.06,-479.1\"/>\n", |
| "</g>\n", |
| "<!-- classification_report_to_txt -->\n", |
| "<g id=\"node18\" class=\"node\">\n", |
| "<title>classification_report_to_txt</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"215.81,-36 51.31,-36 51.31,0 215.81,0 215.81,-36\"/>\n", |
| "<text text-anchor=\"middle\" x=\"133.56\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report_to_txt</text>\n", |
| "</g>\n", |
| "<!-- classification_report->classification_report_to_txt -->\n", |
| "<g id=\"edge23\" class=\"edge\">\n", |
| "<title>classification_report->classification_report_to_txt</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M133.56,-71.7C133.56,-64.24 133.56,-55.32 133.56,-46.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"137.06,-47.1 133.56,-37.1 130.06,-47.1 137.06,-47.1\"/>\n", |
| "</g>\n", |
| "<!-- prefit_clf->fit_clf -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>prefit_clf->fit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M544.3,-361.81C530.09,-351.66 511.23,-338.19 495.77,-327.15\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"498.29,-323.93 488.12,-320.97 494.22,-329.63 498.29,-323.93\"/>\n", |
| "</g>\n", |
| "<!-- X_test->predicted_output -->\n", |
| "<g id=\"edge12\" class=\"edge\">\n", |
| "<title>X_test->predicted_output</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M302.82,-287.7C302.5,-280.24 302.12,-271.32 301.76,-262.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"305.22,-262.95 301.3,-253.1 298.23,-263.25 305.22,-262.95\"/>\n", |
| "</g>\n", |
| "<!-- y_test_with_labels->classification_report -->\n", |
| "<g id=\"edge20\" class=\"edge\">\n", |
| "<title>y_test_with_labels->classification_report</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M165.76,-287.85C159.73,-250.99 145.49,-163.92 138.14,-118.96\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"141.45,-118.54 136.38,-109.23 134.54,-119.67 141.45,-118.54\"/>\n", |
| "</g>\n", |
| "<!-- test_size_fraction -->\n", |
| "<g id=\"node20\" class=\"node\">\n", |
| "<title>test_size_fraction</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"187.56\" cy=\"-522\" rx=\"100.48\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"187.56\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: test_size_fraction</text>\n", |
| "</g>\n", |
| "<!-- test_size_fraction->train_test_split_func -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>test_size_fraction->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M228.86,-505.29C256.25,-494.87 292.36,-481.13 321.41,-470.08\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"322.26,-473.12 330.36,-466.3 319.77,-466.58 322.26,-473.12\"/>\n", |
| "</g>\n", |
| "<!-- model_parameters->model_params_to_json -->\n", |
| "<g id=\"edge16\" class=\"edge\">\n", |
| "<title>model_parameters->model_params_to_json</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M476.51,-215.7C478.66,-208.15 481.24,-199.12 483.66,-190.68\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"487.3,-191.68 486.68,-181.1 480.56,-189.76 487.3,-191.68\"/>\n", |
| "</g>\n", |
| "<!-- y_train->fit_clf -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>y_train->fit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M466.06,-359.7C466.27,-352.24 466.53,-343.32 466.76,-334.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"470.29,-335.2 467.08,-325.1 463.29,-335 470.29,-335.2\"/>\n", |
| "</g>\n", |
| "<!-- target->train_test_split_func -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>target->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M466.29,-508.16C450.06,-498.37 427.32,-484.65 408.19,-473.1\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"410.22,-469.63 399.85,-467.46 406.6,-475.63 410.22,-469.63\"/>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<graphviz.graphs.Digraph at 0x16a2c7730>" |
| ] |
| }, |
| "execution_count": 10, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "materializers = [\n", |
| " to.json(\n", |
| " dependencies=[\"model_parameters\"],\n", |
| " id=\"model_params_to_json\",\n", |
| " path=\"./data/params.json\"\n", |
| " ),\n", |
| " # classification report to .txt file\n", |
| " to.file(\n", |
| " dependencies=[\"classification_report\"],\n", |
| " id=\"classification_report_to_txt\",\n", |
| " path=\"./data/classification_report.txt\",\n", |
| " ),\n", |
| " # materialize the model to a pickle file\n", |
| " to.pickle(\n", |
| " dependencies=[\"fit_clf\"], id=\"clf_to_pickle\", path=\"./data/clf.pkl\"\n", |
| " ),\n", |
| " # materialize the predictions we made to a csv file\n", |
| " to.csv(\n", |
| " dependencies=[\"predicted_output_with_labels\"],\n", |
| " id=\"predicted_output_with_labels_to_csv\",\n", |
| " path=\"./data/predicted_output_with_labels.csv\",\n", |
| " ),\n", |
| " ]\n", |
| "\n", |
| "dr.visualize_materialization(\n", |
| " *materializers,\n", |
| " additional_vars=[\"classification_report\"],\n", |
| " output_file_path=None,\n", |
| " render_kwargs={},\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "id": "f5727b54", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stderr", |
| "output_type": "stream", |
| "text": [ |
| "/Users/elijahbenizzy/.pyenv/versions/3.9.10/envs/hamilton/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n", |
| "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", |
| "\n", |
| "Increase the number of iterations (max_iter) or scale the data as shown in:\n", |
| " https://scikit-learn.org/stable/modules/preprocessing.html\n", |
| "Please also refer to the documentation for alternative solver options:\n", |
| " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", |
| " n_iter_i = _check_optimize_result(\n" |
| ] |
| } |
| ], |
| "source": [ |
| "materialization_results, additional_vars = dr.materialize(\n", |
| " # materialize model parameters to json\n", |
| " *materializers,\n", |
| " additional_vars=[\"classification_report\"],\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "id": "8bdfde70", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| " precision recall f1-score support\n", |
| "\n", |
| " 0 1.00 1.00 1.00 94\n", |
| " 1 0.91 0.93 0.92 85\n", |
| " 2 0.97 0.99 0.98 96\n", |
| " 3 0.99 0.97 0.98 93\n", |
| " 4 0.99 0.92 0.95 88\n", |
| " 5 0.95 0.95 0.95 85\n", |
| " 6 0.99 0.97 0.98 97\n", |
| " 7 0.97 0.97 0.97 89\n", |
| " 8 0.88 0.88 0.88 82\n", |
| " 9 0.91 0.97 0.94 90\n", |
| "\n", |
| " accuracy 0.96 899\n", |
| " macro avg 0.95 0.95 0.95 899\n", |
| "weighted avg 0.96 0.96 0.96 899\n", |
| "\n" |
| ] |
| } |
| ], |
| "source": [ |
| "print(additional_vars['classification_report'])" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 14, |
| "id": "a6f5fe83", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| " precision recall f1-score support\n", |
| "\n", |
| " 0 1.00 1.00 1.00 94\n", |
| " 1 0.91 0.93 0.92 85\n", |
| " 2 0.97 0.99 0.98 96\n", |
| " 3 0.99 0.97 0.98 93\n", |
| " 4 0.99 0.92 0.95 88\n", |
| " 5 0.95 0.95 0.95 85\n", |
| " 6 0.99 0.97 0.98 97\n", |
| " 7 0.97 0.97 0.97 89\n", |
| " 8 0.88 0.88 0.88 82\n", |
| " 9 0.91 0.97 0.94 90\n", |
| "\n", |
| " accuracy 0.96 899\n", |
| " macro avg 0.95 0.95 0.95 899\n", |
| "weighted avg 0.96 0.96 0.96 899\n", |
| "\n" |
| ] |
| } |
| ], |
| "source": [ |
| "print(open((materialization_results['classification_report_to_txt']['path'])).read())" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3 (ipykernel)", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.9.10" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |