blob: 0a414930aedf42741aa9355dad3f41e57376c02a [file] [log] [blame]
{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"id": "7bf6a40d",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"\n",
"import data_loaders\n",
"import model_training\n",
"\n",
"from hamilton import base, driver\n",
"from hamilton.io.materialization import to\n",
"import pandas as pd\n",
"\n",
"import custom_materializers"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7a449245",
"metadata": {},
"outputs": [],
"source": [
"# Configuration handed to the Hamilton driver when the DAG is built.\n",
"# NOTE(review): `data_loader` and `clf` presumably select which implementation\n",
"# in data_loaders / model_training is wired into the graph -- confirm in those modules.\n",
"dag_config = dict(\n",
"    test_size_fraction=0.5,\n",
"    shuffle_train_test_split=True,\n",
"    data_loader=\"iris\",\n",
"    clf=\"logistic\",\n",
"    penalty=\"l2\",\n",
")\n",
"\n",
"# Assemble the driver from the default adapter, the config above and the two\n",
"# function modules imported at the top of the notebook.\n",
"dr = (\n",
"    driver.Builder()\n",
"    .with_adapter(base.DefaultAdapter())\n",
"    .with_config(dag_config)\n",
"    .with_modules(data_loaders, model_training)\n",
"    .build()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "397b09bc",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 8.0.5 (20230430.1635)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"786pt\" height=\"620pt\"\n",
" viewBox=\"0.00 0.00 786.01 620.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 616)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-616 782.01,-616 782.01,4 -4,4\"/>\n",
"<!-- predicted_output_with_labels -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>predicted_output_with_labels</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"281.56\" cy=\"-162\" rx=\"120.45\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"281.56\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels</text>\n",
"</g>\n",
"<!-- predicted_output_with_labels_to_csv -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>predicted_output_with_labels_to_csv</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"432.81,-108 214.31,-108 214.31,-72 432.81,-72 432.81,-108\"/>\n",
"<text text-anchor=\"middle\" x=\"323.56\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels_to_csv</text>\n",
"</g>\n",
"<!-- predicted_output_with_labels&#45;&gt;predicted_output_with_labels_to_csv -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>predicted_output_with_labels&#45;&gt;predicted_output_with_labels_to_csv</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M291.95,-143.7C296.73,-135.73 302.5,-126.1 307.81,-117.26\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"311.16,-119.48 313.3,-109.1 305.16,-115.88 311.16,-119.48\"/>\n",
"</g>\n",
"<!-- classification_report -->\n",
"<g id=\"node15\" class=\"node\">\n",
"<title>classification_report</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"196.69,-108 70.44,-108 70.44,-72 196.69,-72 196.69,-108\"/>\n",
"<text text-anchor=\"middle\" x=\"133.56\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report</text>\n",
"</g>\n",
"<!-- predicted_output_with_labels&#45;&gt;classification_report -->\n",
"<g id=\"edge19\" class=\"edge\">\n",
"<title>predicted_output_with_labels&#45;&gt;classification_report</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M246.49,-144.41C226.71,-135.06 201.75,-123.25 180.33,-113.12\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"181.99,-109.56 171.45,-108.45 178.99,-115.89 181.99,-109.56\"/>\n",
"</g>\n",
"<!-- train_test_split_func -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>train_test_split_func</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"371.56\" cy=\"-450\" rx=\"86.67\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"371.56\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">train_test_split_func</text>\n",
"</g>\n",
"<!-- y_test -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>y_test</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"207.56\" cy=\"-378\" rx=\"32.93\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"207.56\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test</text>\n",
"</g>\n",
"<!-- train_test_split_func&#45;&gt;y_test -->\n",
"<g id=\"edge14\" class=\"edge\">\n",
"<title>train_test_split_func&#45;&gt;y_test</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M334.76,-433.29C307.29,-421.57 269.99,-405.64 242.88,-394.08\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"244.61,-390.58 234.04,-389.87 241.86,-397.02 244.61,-390.58\"/>\n",
"</g>\n",
"<!-- X_train -->\n",
"<g id=\"node11\" class=\"node\">\n",
"<title>X_train</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"371.56\" cy=\"-378\" rx=\"39.07\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"371.56\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_train</text>\n",
"</g>\n",
"<!-- train_test_split_func&#45;&gt;X_train -->\n",
"<g id=\"edge15\" class=\"edge\">\n",
"<title>train_test_split_func&#45;&gt;X_train</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M371.56,-431.7C371.56,-424.24 371.56,-415.32 371.56,-406.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"375.06,-407.1 371.56,-397.1 368.06,-407.1 375.06,-407.1\"/>\n",
"</g>\n",
"<!-- X_test -->\n",
"<g id=\"node17\" class=\"node\">\n",
"<title>X_test</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"303.56\" cy=\"-306\" rx=\"34.97\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"303.56\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_test</text>\n",
"</g>\n",
"<!-- train_test_split_func&#45;&gt;X_test -->\n",
"<g id=\"edge22\" class=\"edge\">\n",
"<title>train_test_split_func&#45;&gt;X_test</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M352.09,-432.16C342.13,-422.53 330.65,-409.66 323.56,-396 313.71,-377 308.68,-353.25 306.14,-335.06\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"309.51,-334.79 304.82,-325.29 302.56,-335.64 309.51,-334.79\"/>\n",
"</g>\n",
"<!-- y_train -->\n",
"<g id=\"node22\" class=\"node\">\n",
"<title>y_train</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"465.56\" cy=\"-378\" rx=\"37.02\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"465.56\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_train</text>\n",
"</g>\n",
"<!-- train_test_split_func&#45;&gt;y_train -->\n",
"<g id=\"edge27\" class=\"edge\">\n",
"<title>train_test_split_func&#45;&gt;y_train</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M393.84,-432.41C406.85,-422.73 423.38,-410.41 437.3,-400.05\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"439.07,-402.35 445,-393.57 434.89,-396.74 439.07,-402.35\"/>\n",
"</g>\n",
"<!-- fit_clf -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>fit_clf</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"467.56\" cy=\"-306\" rx=\"33.44\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"467.56\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">fit_clf</text>\n",
"</g>\n",
"<!-- predicted_output -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>predicted_output</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"300.56\" cy=\"-234\" rx=\"73.36\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"300.56\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output</text>\n",
"</g>\n",
"<!-- fit_clf&#45;&gt;predicted_output -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>fit_clf&#45;&gt;predicted_output</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M441.84,-294.22C416.52,-283.6 377.3,-267.17 346.6,-254.3\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"348.13,-250.72 337.55,-250.09 345.42,-257.18 348.13,-250.72\"/>\n",
"</g>\n",
"<!-- clf_to_pickle -->\n",
"<g id=\"node13\" class=\"node\">\n",
"<title>clf_to_pickle</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"657.19,-252 569.94,-252 569.94,-216 657.19,-216 657.19,-252\"/>\n",
"<text text-anchor=\"middle\" x=\"613.56\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">clf_to_pickle</text>\n",
"</g>\n",
"<!-- fit_clf&#45;&gt;clf_to_pickle -->\n",
"<g id=\"edge17\" class=\"edge\">\n",
"<title>fit_clf&#45;&gt;clf_to_pickle</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M492,-293.28C512.57,-283.42 542.55,-269.05 567.59,-257.04\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"568.76,-259.88 576.27,-252.4 565.74,-253.57 568.76,-259.88\"/>\n",
"</g>\n",
"<!-- model_parameters -->\n",
"<g id=\"node21\" class=\"node\">\n",
"<title>model_parameters</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"471.56\" cy=\"-234\" rx=\"80.01\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"471.56\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">model_parameters</text>\n",
"</g>\n",
"<!-- fit_clf&#45;&gt;model_parameters -->\n",
"<g id=\"edge26\" class=\"edge\">\n",
"<title>fit_clf&#45;&gt;model_parameters</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M468.55,-287.7C468.98,-280.24 469.49,-271.32 469.97,-262.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"473.51,-263.29 470.59,-253.1 466.52,-262.89 473.51,-263.29\"/>\n",
"</g>\n",
"<!-- data -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>data</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"371.56\" cy=\"-594\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"371.56\" y=\"-588.95\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n",
"</g>\n",
"<!-- target_names -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>target_names</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"60.56\" cy=\"-378\" rx=\"60.56\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"60.56\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">target_names</text>\n",
"</g>\n",
"<!-- data&#45;&gt;target_names -->\n",
"<g id=\"edge13\" class=\"edge\">\n",
"<title>data&#45;&gt;target_names</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M344.66,-592.1C279.25,-589.31 114.57,-578.52 78.56,-540 45.12,-504.21 48.85,-442.61 54.65,-406.87\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"58.24,-407.71 56.56,-397.25 51.35,-406.47 58.24,-407.71\"/>\n",
"</g>\n",
"<!-- feature_matrix -->\n",
"<g id=\"node14\" class=\"node\">\n",
"<title>feature_matrix</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"371.56\" cy=\"-522\" rx=\"65.68\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"371.56\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">feature_matrix</text>\n",
"</g>\n",
"<!-- data&#45;&gt;feature_matrix -->\n",
"<g id=\"edge18\" class=\"edge\">\n",
"<title>data&#45;&gt;feature_matrix</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M371.56,-575.7C371.56,-568.24 371.56,-559.32 371.56,-550.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"375.06,-551.1 371.56,-541.1 368.06,-551.1 375.06,-551.1\"/>\n",
"</g>\n",
"<!-- target -->\n",
"<g id=\"node23\" class=\"node\">\n",
"<title>target</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"487.56\" cy=\"-522\" rx=\"31.9\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"487.56\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">target</text>\n",
"</g>\n",
"<!-- data&#45;&gt;target -->\n",
"<g id=\"edge28\" class=\"edge\">\n",
"<title>data&#45;&gt;target</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M391.24,-581.13C409.21,-570.28 436.06,-554.08 456.85,-541.53\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"458.46,-544.05 465.21,-535.88 454.84,-538.05 458.46,-544.05\"/>\n",
"</g>\n",
"<!-- penalty -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>penalty</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"565.56\" cy=\"-450\" rx=\"62.61\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"565.56\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: penalty</text>\n",
"</g>\n",
"<!-- prefit_clf -->\n",
"<g id=\"node16\" class=\"node\">\n",
"<title>prefit_clf</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"565.56\" cy=\"-378\" rx=\"45.21\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"565.56\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">prefit_clf</text>\n",
"</g>\n",
"<!-- penalty&#45;&gt;prefit_clf -->\n",
"<g id=\"edge21\" class=\"edge\">\n",
"<title>penalty&#45;&gt;prefit_clf</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M565.56,-431.7C565.56,-424.24 565.56,-415.32 565.56,-406.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"569.06,-407.1 565.56,-397.1 562.06,-407.1 569.06,-407.1\"/>\n",
"</g>\n",
"<!-- predicted_output&#45;&gt;predicted_output_with_labels -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>predicted_output&#45;&gt;predicted_output_with_labels</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M295.87,-215.7C293.82,-208.15 291.37,-199.12 289.08,-190.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"292.2,-189.84 286.21,-181.1 285.45,-191.67 292.2,-189.84\"/>\n",
"</g>\n",
"<!-- shuffle_train_test_split -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>shuffle_train_test_split</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"657.56\" cy=\"-522\" rx=\"120.45\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"657.56\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: shuffle_train_test_split</text>\n",
"</g>\n",
"<!-- shuffle_train_test_split&#45;&gt;train_test_split_func -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>shuffle_train_test_split&#45;&gt;train_test_split_func</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M596.56,-506.07C549.59,-494.57 484.78,-478.71 436.78,-466.96\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"437.74,-463.35 427.19,-464.37 436.07,-470.15 437.74,-463.35\"/>\n",
"</g>\n",
"<!-- target_names&#45;&gt;predicted_output_with_labels -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>target_names&#45;&gt;predicted_output_with_labels</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M60.36,-359.69C60.96,-340.5 64.46,-309.44 79.56,-288 115.4,-237.11 179.19,-202.78 225.52,-183.17\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"226.51,-186.13 234.42,-179.09 223.84,-179.66 226.51,-186.13\"/>\n",
"</g>\n",
"<!-- y_test_with_labels -->\n",
"<g id=\"node19\" class=\"node\">\n",
"<title>y_test_with_labels</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"168.56\" cy=\"-306\" rx=\"80.01\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"168.56\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test_with_labels</text>\n",
"</g>\n",
"<!-- target_names&#45;&gt;y_test_with_labels -->\n",
"<g id=\"edge25\" class=\"edge\">\n",
"<title>target_names&#45;&gt;y_test_with_labels</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M85.07,-361.12C99.64,-351.67 118.34,-339.56 134.34,-329.18\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"135.92,-331.68 142.4,-323.31 132.11,-325.81 135.92,-331.68\"/>\n",
"</g>\n",
"<!-- y_test&#45;&gt;y_test_with_labels -->\n",
"<g id=\"edge24\" class=\"edge\">\n",
"<title>y_test&#45;&gt;y_test_with_labels</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M198.32,-360.41C193.82,-352.34 188.3,-342.43 183.25,-333.35\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"185.94,-331.99 178.01,-324.96 179.82,-335.4 185.94,-331.99\"/>\n",
"</g>\n",
"<!-- X_train&#45;&gt;fit_clf -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>X_train&#45;&gt;fit_clf</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M391.92,-362.15C405.82,-352.02 424.4,-338.47 439.67,-327.34\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"441.2,-329.83 447.22,-321.11 437.07,-324.17 441.2,-329.83\"/>\n",
"</g>\n",
"<!-- model_params_to_json -->\n",
"<g id=\"node12\" class=\"node\">\n",
"<title>model_params_to_json</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"563.31,-180 419.81,-180 419.81,-144 563.31,-144 563.31,-180\"/>\n",
"<text text-anchor=\"middle\" x=\"491.56\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">model_params_to_json</text>\n",
"</g>\n",
"<!-- feature_matrix&#45;&gt;train_test_split_func -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>feature_matrix&#45;&gt;train_test_split_func</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M371.56,-503.7C371.56,-496.24 371.56,-487.32 371.56,-478.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"375.06,-479.1 371.56,-469.1 368.06,-479.1 375.06,-479.1\"/>\n",
"</g>\n",
"<!-- classification_report_to_txt -->\n",
"<g id=\"node18\" class=\"node\">\n",
"<title>classification_report_to_txt</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"215.81,-36 51.31,-36 51.31,0 215.81,0 215.81,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"133.56\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report_to_txt</text>\n",
"</g>\n",
"<!-- classification_report&#45;&gt;classification_report_to_txt -->\n",
"<g id=\"edge23\" class=\"edge\">\n",
"<title>classification_report&#45;&gt;classification_report_to_txt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M133.56,-71.7C133.56,-64.24 133.56,-55.32 133.56,-46.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"137.06,-47.1 133.56,-37.1 130.06,-47.1 137.06,-47.1\"/>\n",
"</g>\n",
"<!-- prefit_clf&#45;&gt;fit_clf -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>prefit_clf&#45;&gt;fit_clf</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M544.3,-361.81C530.09,-351.66 511.23,-338.19 495.77,-327.15\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"498.29,-323.93 488.12,-320.97 494.22,-329.63 498.29,-323.93\"/>\n",
"</g>\n",
"<!-- X_test&#45;&gt;predicted_output -->\n",
"<g id=\"edge12\" class=\"edge\">\n",
"<title>X_test&#45;&gt;predicted_output</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M302.82,-287.7C302.5,-280.24 302.12,-271.32 301.76,-262.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"305.22,-262.95 301.3,-253.1 298.23,-263.25 305.22,-262.95\"/>\n",
"</g>\n",
"<!-- y_test_with_labels&#45;&gt;classification_report -->\n",
"<g id=\"edge20\" class=\"edge\">\n",
"<title>y_test_with_labels&#45;&gt;classification_report</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M165.76,-287.85C159.73,-250.99 145.49,-163.92 138.14,-118.96\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"141.45,-118.54 136.38,-109.23 134.54,-119.67 141.45,-118.54\"/>\n",
"</g>\n",
"<!-- test_size_fraction -->\n",
"<g id=\"node20\" class=\"node\">\n",
"<title>test_size_fraction</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"187.56\" cy=\"-522\" rx=\"100.48\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"187.56\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: test_size_fraction</text>\n",
"</g>\n",
"<!-- test_size_fraction&#45;&gt;train_test_split_func -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>test_size_fraction&#45;&gt;train_test_split_func</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M228.86,-505.29C256.25,-494.87 292.36,-481.13 321.41,-470.08\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"322.26,-473.12 330.36,-466.3 319.77,-466.58 322.26,-473.12\"/>\n",
"</g>\n",
"<!-- model_parameters&#45;&gt;model_params_to_json -->\n",
"<g id=\"edge16\" class=\"edge\">\n",
"<title>model_parameters&#45;&gt;model_params_to_json</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M476.51,-215.7C478.66,-208.15 481.24,-199.12 483.66,-190.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"487.3,-191.68 486.68,-181.1 480.56,-189.76 487.3,-191.68\"/>\n",
"</g>\n",
"<!-- y_train&#45;&gt;fit_clf -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>y_train&#45;&gt;fit_clf</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M466.06,-359.7C466.27,-352.24 466.53,-343.32 466.76,-334.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"470.29,-335.2 467.08,-325.1 463.29,-335 470.29,-335.2\"/>\n",
"</g>\n",
"<!-- target&#45;&gt;train_test_split_func -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>target&#45;&gt;train_test_split_func</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M466.29,-508.16C450.06,-498.37 427.32,-484.65 408.19,-473.1\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"410.22,-469.63 399.85,-467.46 406.6,-475.63 410.22,-469.63\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x16a2c7730>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"materializers = [\n",
" to.json(\n",
" dependencies=[\"model_parameters\"],\n",
" id=\"model_params_to_json\",\n",
" path=\"./data/params.json\"\n",
" ),\n",
" # classification report to .txt file\n",
" to.file(\n",
" dependencies=[\"classification_report\"],\n",
" id=\"classification_report_to_txt\",\n",
" path=\"./data/classification_report.txt\",\n",
" ),\n",
" # materialize the model to a pickle file\n",
" to.pickle(\n",
" dependencies=[\"fit_clf\"], id=\"clf_to_pickle\", path=\"./data/clf.pkl\"\n",
" ),\n",
" # materialize the predictions we made to a csv file\n",
" to.csv(\n",
" dependencies=[\"predicted_output_with_labels\"],\n",
" id=\"predicted_output_with_labels_to_csv\",\n",
" path=\"./data/predicted_output_with_labels.csv\",\n",
" ),\n",
" ]\n",
"\n",
"dr.visualize_materialization(\n",
" *materializers,\n",
" additional_vars=[\"classification_report\"],\n",
" output_file_path=None,\n",
" render_kwargs={},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f5727b54",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/elijahbenizzy/.pyenv/versions/3.9.10/envs/hamilton/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
}
],
"source": [
"materialization_results, additional_vars = dr.materialize(\n",
" # materialize model parameters to json\n",
" *materializers,\n",
" additional_vars=[\"classification_report\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8bdfde70",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 94\n",
" 1 0.91 0.93 0.92 85\n",
" 2 0.97 0.99 0.98 96\n",
" 3 0.99 0.97 0.98 93\n",
" 4 0.99 0.92 0.95 88\n",
" 5 0.95 0.95 0.95 85\n",
" 6 0.99 0.97 0.98 97\n",
" 7 0.97 0.97 0.97 89\n",
" 8 0.88 0.88 0.88 82\n",
" 9 0.91 0.97 0.94 90\n",
"\n",
" accuracy 0.96 899\n",
" macro avg 0.95 0.95 0.95 899\n",
"weighted avg 0.96 0.96 0.96 899\n",
"\n"
]
}
],
"source": [
"print(additional_vars['classification_report'])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a6f5fe83",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 94\n",
" 1 0.91 0.93 0.92 85\n",
" 2 0.97 0.99 0.98 96\n",
" 3 0.99 0.97 0.98 93\n",
" 4 0.99 0.92 0.95 88\n",
" 5 0.95 0.95 0.95 85\n",
" 6 0.99 0.97 0.98 97\n",
" 7 0.97 0.97 0.97 89\n",
" 8 0.88 0.88 0.88 82\n",
" 9 0.91 0.97 0.94 90\n",
"\n",
" accuracy 0.96 899\n",
" macro avg 0.95 0.95 0.95 899\n",
"weighted avg 0.96 0.96 0.96 899\n",
"\n"
]
}
],
"source": [
"print(open((materialization_results['classification_report_to_txt']['path'])).read())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}