blob: 14c7387c3b2ba720bb14a8a9e4da678bb90cd1eb [file] [log] [blame]
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "7bf6a40d",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-04T17:30:02.674284Z",
"start_time": "2023-09-04T17:29:59.371085Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n",
" warnings.warn(\n"
]
}
],
"source": [
"import json\n",
"import os\n",
"\n",
"import data_loaders\n",
"import model_training\n",
"\n",
"from hamilton import base, driver\n",
"from hamilton.io.materialization import to\n",
"import pandas as pd\n",
"\n",
"import custom_materializers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7a449245",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-04T17:30:02.682841Z",
"start_time": "2023-09-04T17:30:02.679637Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/dagworks-inc/hamilton#usage-analytics--data-privacy for details.\n"
]
}
],
"source": [
"# Configuration handed to the driver builder below.\n",
"dag_config = {\n",
"    \"test_size_fraction\": 0.5,\n",
"    \"shuffle_train_test_split\": True,\n",
"    # NOTE(review): presumably these two select the dataset / classifier variants\n",
"    # defined in `data_loaders` and `model_training` -- confirm against those modules.\n",
"    \"data_loader\": \"iris\",\n",
"    \"clf\": \"logistic\",\n",
"    \"penalty\": \"l2\",\n",
"}\n",
"\n",
"# Build the driver from the two imported modules, the config above and the default adapter.\n",
"dr = (\n",
"    driver.Builder()\n",
"    .with_adapter(base.DefaultAdapter())\n",
"    .with_config(dag_config)\n",
"    .with_modules(data_loaders, model_training)\n",
"    .build()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "397b09bc",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-04T17:30:03.208108Z",
"start_time": "2023-09-04T17:30:02.696138Z"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 8.0.5 (20230430.1635)\n -->\n<!-- Pages: 1 -->\n<svg width=\"787pt\" height=\"620pt\"\n viewBox=\"0.00 0.00 787.01 620.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 616)\">\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-616 783.01,-616 783.01,4 -4,4\"/>\n<!-- classification_report_to_txt -->\n<g id=\"node1\" class=\"node\">\n<title>classification_report_to_txt</title>\n<polygon fill=\"none\" stroke=\"black\" points=\"746.7,-36 582.2,-36 582.2,0 746.7,0 746.7,-36\"/>\n<text text-anchor=\"middle\" x=\"664.45\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report_to_txt</text>\n</g>\n<!-- target_names -->\n<g id=\"node2\" class=\"node\">\n<title>target_names</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"718.45\" cy=\"-378\" rx=\"60.56\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"718.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">target_names</text>\n</g>\n<!-- y_test_with_labels -->\n<g id=\"node14\" class=\"node\">\n<title>y_test_with_labels</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"627.45\" cy=\"-306\" rx=\"80.01\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"627.45\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test_with_labels</text>\n</g>\n<!-- target_names&#45;&gt;y_test_with_labels -->\n<g id=\"edge19\" class=\"edge\">\n<title>target_names&#45;&gt;y_test_with_labels</title>\n<path fill=\"none\" stroke=\"black\" d=\"M697.34,-360.76C685.57,-351.71 670.7,-340.27 657.7,-330.28\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"660.13,-326.96 650.07,-323.63 655.86,-332.5 660.13,-326.96\"/>\n</g>\n<!-- 
predicted_output_with_labels -->\n<g id=\"node20\" class=\"node\">\n<title>predicted_output_with_labels</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"516.45\" cy=\"-162\" rx=\"120.45\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"516.45\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels</text>\n</g>\n<!-- target_names&#45;&gt;predicted_output_with_labels -->\n<g id=\"edge24\" class=\"edge\">\n<title>target_names&#45;&gt;predicted_output_with_labels</title>\n<path fill=\"none\" stroke=\"black\" d=\"M722.65,-359.66C726.24,-340.73 729.06,-310.16 716.45,-288 686.35,-235.12 623.81,-201.51 576.63,-182.62\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"578.1,-179.05 567.51,-178.7 575.57,-185.58 578.1,-179.05\"/>\n</g>\n<!-- X_train -->\n<g id=\"node3\" class=\"node\">\n<title>X_train</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"427.45\" cy=\"-378\" rx=\"39.07\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"427.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_train</text>\n</g>\n<!-- fit_clf -->\n<g id=\"node6\" class=\"node\">\n<title>fit_clf</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"331.45\" cy=\"-306\" rx=\"33.44\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"331.45\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">fit_clf</text>\n</g>\n<!-- X_train&#45;&gt;fit_clf -->\n<g id=\"edge9\" class=\"edge\">\n<title>X_train&#45;&gt;fit_clf</title>\n<path fill=\"none\" stroke=\"black\" d=\"M407.09,-362.15C393.19,-352.02 374.61,-338.47 359.34,-327.34\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"361.94,-324.17 351.79,-321.11 357.81,-329.83 361.94,-324.17\"/>\n</g>\n<!-- train_test_split_func -->\n<g id=\"node4\" class=\"node\">\n<title>train_test_split_func</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"460.45\" cy=\"-450\" rx=\"86.67\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"460.45\" y=\"-444.95\" font-family=\"Times,serif\" 
font-size=\"14.00\">train_test_split_func</text>\n</g>\n<!-- train_test_split_func&#45;&gt;X_train -->\n<g id=\"edge3\" class=\"edge\">\n<title>train_test_split_func&#45;&gt;X_train</title>\n<path fill=\"none\" stroke=\"black\" d=\"M452.29,-431.7C448.61,-423.9 444.19,-414.51 440.09,-405.83\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"442.94,-404.66 435.51,-397.1 436.61,-407.64 442.94,-404.66\"/>\n</g>\n<!-- y_train -->\n<g id=\"node9\" class=\"node\">\n<title>y_train</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"333.45\" cy=\"-378\" rx=\"37.02\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"333.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_train</text>\n</g>\n<!-- train_test_split_func&#45;&gt;y_train -->\n<g id=\"edge12\" class=\"edge\">\n<title>train_test_split_func&#45;&gt;y_train</title>\n<path fill=\"none\" stroke=\"black\" d=\"M430.99,-432.76C411.81,-422.19 386.74,-408.38 366.82,-397.4\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"368.67,-393.87 358.22,-392.1 365.29,-400 368.67,-393.87\"/>\n</g>\n<!-- y_test -->\n<g id=\"node12\" class=\"node\">\n<title>y_test</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"573.45\" cy=\"-378\" rx=\"32.93\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"573.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test</text>\n</g>\n<!-- train_test_split_func&#45;&gt;y_test -->\n<g id=\"edge16\" class=\"edge\">\n<title>train_test_split_func&#45;&gt;y_test</title>\n<path fill=\"none\" stroke=\"black\" d=\"M487.22,-432.41C503.97,-422.04 525.57,-408.66 542.97,-397.88\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"544.48,-400.44 551.14,-392.2 540.8,-394.49 544.48,-400.44\"/>\n</g>\n<!-- X_test -->\n<g id=\"node21\" class=\"node\">\n<title>X_test</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"494.45\" cy=\"-306\" rx=\"34.97\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"494.45\" y=\"-300.95\" font-family=\"Times,serif\" 
font-size=\"14.00\">X_test</text>\n</g>\n<!-- train_test_split_func&#45;&gt;X_test -->\n<g id=\"edge25\" class=\"edge\">\n<title>train_test_split_func&#45;&gt;X_test</title>\n<path fill=\"none\" stroke=\"black\" d=\"M465.6,-431.89C468.63,-421.54 472.45,-408.07 475.45,-396 480.5,-375.62 485.43,-352.42 489,-334.81\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"492.6,-335.63 491.13,-325.14 485.74,-334.26 492.6,-335.63\"/>\n</g>\n<!-- shuffle_train_test_split -->\n<g id=\"node5\" class=\"node\">\n<title>shuffle_train_test_split</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"120.45\" cy=\"-522\" rx=\"120.45\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"120.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: shuffle_train_test_split</text>\n</g>\n<!-- shuffle_train_test_split&#45;&gt;train_test_split_func -->\n<g id=\"edge7\" class=\"edge\">\n<title>shuffle_train_test_split&#45;&gt;train_test_split_func</title>\n<path fill=\"none\" stroke=\"black\" d=\"M189.25,-506.83C247.61,-494.82 331,-477.65 389.86,-465.53\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"390.47,-468.77 399.56,-463.33 389.06,-461.92 390.47,-468.77\"/>\n</g>\n<!-- clf_to_pickle -->\n<g id=\"node10\" class=\"node\">\n<title>clf_to_pickle</title>\n<polygon fill=\"none\" stroke=\"black\" points=\"228.07,-252 140.82,-252 140.82,-216 228.07,-216 228.07,-252\"/>\n<text text-anchor=\"middle\" x=\"184.45\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">clf_to_pickle</text>\n</g>\n<!-- fit_clf&#45;&gt;clf_to_pickle -->\n<g id=\"edge13\" class=\"edge\">\n<title>fit_clf&#45;&gt;clf_to_pickle</title>\n<path fill=\"none\" stroke=\"black\" d=\"M306.84,-293.28C286.13,-283.42 255.95,-269.05 230.73,-257.04\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"232.52,-253.54 221.99,-252.4 229.52,-259.86 232.52,-253.54\"/>\n</g>\n<!-- model_parameters -->\n<g id=\"node13\" class=\"node\">\n<title>model_parameters</title>\n<ellipse 
fill=\"none\" stroke=\"black\" cx=\"326.45\" cy=\"-234\" rx=\"80.01\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"326.45\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">model_parameters</text>\n</g>\n<!-- fit_clf&#45;&gt;model_parameters -->\n<g id=\"edge17\" class=\"edge\">\n<title>fit_clf&#45;&gt;model_parameters</title>\n<path fill=\"none\" stroke=\"black\" d=\"M330.21,-287.7C329.68,-280.24 329.04,-271.32 328.44,-262.97\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"331.87,-262.83 327.67,-253.1 324.89,-263.33 331.87,-262.83\"/>\n</g>\n<!-- predicted_output -->\n<g id=\"node22\" class=\"node\">\n<title>predicted_output</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"497.45\" cy=\"-234\" rx=\"73.36\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"497.45\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output</text>\n</g>\n<!-- fit_clf&#45;&gt;predicted_output -->\n<g id=\"edge26\" class=\"edge\">\n<title>fit_clf&#45;&gt;predicted_output</title>\n<path fill=\"none\" stroke=\"black\" d=\"M357.02,-294.22C382.19,-283.6 421.17,-267.17 451.68,-254.3\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"452.82,-257.2 460.67,-250.09 450.1,-250.75 452.82,-257.2\"/>\n</g>\n<!-- data -->\n<g id=\"node7\" class=\"node\">\n<title>data</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"658.45\" cy=\"-594\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"658.45\" y=\"-588.95\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n</g>\n<!-- data&#45;&gt;target_names -->\n<g id=\"edge2\" class=\"edge\">\n<title>data&#45;&gt;target_names</title>\n<path fill=\"none\" stroke=\"black\" d=\"M673.08,-578.64C682.26,-568.69 693.46,-554.65 699.45,-540 717.45,-495.95 719.96,-439.86 719.59,-406.85\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"723.07,-407.11 719.36,-397.19 716.07,-407.27 723.07,-407.11\"/>\n</g>\n<!-- feature_matrix -->\n<g id=\"node15\" 
class=\"node\">\n<title>feature_matrix</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"324.45\" cy=\"-522\" rx=\"65.68\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"324.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">feature_matrix</text>\n</g>\n<!-- data&#45;&gt;feature_matrix -->\n<g id=\"edge20\" class=\"edge\">\n<title>data&#45;&gt;feature_matrix</title>\n<path fill=\"none\" stroke=\"black\" d=\"M632.67,-587.91C586.07,-578.67 484.69,-558.38 399.45,-540 394.49,-538.93 389.35,-537.8 384.19,-536.66\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"385.15,-533.07 374.63,-534.3 383.63,-539.9 385.15,-533.07\"/>\n</g>\n<!-- target -->\n<g id=\"node17\" class=\"node\">\n<title>target</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"658.45\" cy=\"-522\" rx=\"31.9\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"658.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">target</text>\n</g>\n<!-- data&#45;&gt;target -->\n<g id=\"edge21\" class=\"edge\">\n<title>data&#45;&gt;target</title>\n<path fill=\"none\" stroke=\"black\" d=\"M658.45,-575.7C658.45,-568.24 658.45,-559.32 658.45,-550.97\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"661.95,-551.1 658.45,-541.1 654.95,-551.1 661.95,-551.1\"/>\n</g>\n<!-- prefit_clf -->\n<g id=\"node8\" class=\"node\">\n<title>prefit_clf</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"233.45\" cy=\"-378\" rx=\"45.21\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"233.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">prefit_clf</text>\n</g>\n<!-- prefit_clf&#45;&gt;fit_clf -->\n<g id=\"edge8\" class=\"edge\">\n<title>prefit_clf&#45;&gt;fit_clf</title>\n<path fill=\"none\" stroke=\"black\" d=\"M254.71,-361.81C268.92,-351.66 287.77,-338.19 303.24,-327.15\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"304.78,-329.63 310.89,-320.97 300.72,-323.93 304.78,-329.63\"/>\n</g>\n<!-- y_train&#45;&gt;fit_clf -->\n<g id=\"edge10\" 
class=\"edge\">\n<title>y_train&#45;&gt;fit_clf</title>\n<path fill=\"none\" stroke=\"black\" d=\"M332.95,-359.7C332.74,-352.24 332.48,-343.32 332.24,-334.97\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"335.72,-335 331.93,-325.1 328.72,-335.2 335.72,-335\"/>\n</g>\n<!-- classification_report -->\n<g id=\"node11\" class=\"node\">\n<title>classification_report</title>\n<polygon fill=\"none\" stroke=\"black\" points=\"727.57,-108 601.32,-108 601.32,-72 727.57,-72 727.57,-108\"/>\n<text text-anchor=\"middle\" x=\"664.45\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report</text>\n</g>\n<!-- classification_report&#45;&gt;classification_report_to_txt -->\n<g id=\"edge1\" class=\"edge\">\n<title>classification_report&#45;&gt;classification_report_to_txt</title>\n<path fill=\"none\" stroke=\"black\" d=\"M664.45,-71.7C664.45,-64.24 664.45,-55.32 664.45,-46.97\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"667.95,-47.1 664.45,-37.1 660.95,-47.1 667.95,-47.1\"/>\n</g>\n<!-- y_test&#45;&gt;y_test_with_labels -->\n<g id=\"edge18\" class=\"edge\">\n<title>y_test&#45;&gt;y_test_with_labels</title>\n<path fill=\"none\" stroke=\"black\" d=\"M585.7,-361.12C592.27,-352.59 600.53,-341.89 607.97,-332.25\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"611.15,-334.85 614.49,-324.8 605.61,-330.58 611.15,-334.85\"/>\n</g>\n<!-- model_params_to_json -->\n<g id=\"node23\" class=\"node\">\n<title>model_params_to_json</title>\n<polygon fill=\"none\" stroke=\"black\" points=\"378.2,-180 234.7,-180 234.7,-144 378.2,-144 378.2,-180\"/>\n<text text-anchor=\"middle\" x=\"306.45\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">model_params_to_json</text>\n</g>\n<!-- model_parameters&#45;&gt;model_params_to_json -->\n<g id=\"edge28\" class=\"edge\">\n<title>model_parameters&#45;&gt;model_params_to_json</title>\n<path fill=\"none\" stroke=\"black\" d=\"M321.5,-215.7C319.35,-208.15 316.77,-199.12 314.35,-190.68\"/>\n<polygon 
fill=\"black\" stroke=\"black\" points=\"317.45,-189.76 311.33,-181.1 310.71,-191.68 317.45,-189.76\"/>\n</g>\n<!-- y_test_with_labels&#45;&gt;classification_report -->\n<g id=\"edge15\" class=\"edge\">\n<title>y_test_with_labels&#45;&gt;classification_report</title>\n<path fill=\"none\" stroke=\"black\" d=\"M630.41,-287.85C636.78,-250.99 651.84,-163.92 659.61,-118.96\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"663.21,-119.68 661.47,-109.23 656.31,-118.49 663.21,-119.68\"/>\n</g>\n<!-- feature_matrix&#45;&gt;train_test_split_func -->\n<g id=\"edge4\" class=\"edge\">\n<title>feature_matrix&#45;&gt;train_test_split_func</title>\n<path fill=\"none\" stroke=\"black\" d=\"M354.29,-505.64C373.46,-495.77 398.61,-482.83 419.62,-472.01\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"420.96,-474.74 428.25,-467.06 417.76,-468.52 420.96,-474.74\"/>\n</g>\n<!-- penalty -->\n<g id=\"node16\" class=\"node\">\n<title>penalty</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"233.45\" cy=\"-450\" rx=\"62.61\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"233.45\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: penalty</text>\n</g>\n<!-- penalty&#45;&gt;prefit_clf -->\n<g id=\"edge11\" class=\"edge\">\n<title>penalty&#45;&gt;prefit_clf</title>\n<path fill=\"none\" stroke=\"black\" d=\"M233.45,-431.7C233.45,-424.24 233.45,-415.32 233.45,-406.97\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"236.95,-407.1 233.45,-397.1 229.95,-407.1 236.95,-407.1\"/>\n</g>\n<!-- target&#45;&gt;train_test_split_func -->\n<g id=\"edge5\" class=\"edge\">\n<title>target&#45;&gt;train_test_split_func</title>\n<path fill=\"none\" stroke=\"black\" d=\"M633.4,-510.3C628.15,-508.15 622.64,-505.95 617.45,-504 584.32,-491.56 546.78,-478.84 516.72,-468.98\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"518.01,-465.39 507.42,-465.61 515.84,-472.04 518.01,-465.39\"/>\n</g>\n<!-- predicted_output_with_labels_to_csv -->\n<g 
id=\"node18\" class=\"node\">\n<title>predicted_output_with_labels_to_csv</title>\n<polygon fill=\"none\" stroke=\"black\" points=\"583.7,-108 365.2,-108 365.2,-72 583.7,-72 583.7,-108\"/>\n<text text-anchor=\"middle\" x=\"474.45\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels_to_csv</text>\n</g>\n<!-- test_size_fraction -->\n<g id=\"node19\" class=\"node\">\n<title>test_size_fraction</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"508.45\" cy=\"-522\" rx=\"100.48\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"508.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: test_size_fraction</text>\n</g>\n<!-- test_size_fraction&#45;&gt;train_test_split_func -->\n<g id=\"edge6\" class=\"edge\">\n<title>test_size_fraction&#45;&gt;train_test_split_func</title>\n<path fill=\"none\" stroke=\"black\" d=\"M496.58,-503.7C491.06,-495.64 484.37,-485.89 478.26,-476.98\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"480.72,-475.37 472.17,-469.1 474.94,-479.33 480.72,-475.37\"/>\n</g>\n<!-- predicted_output_with_labels&#45;&gt;classification_report -->\n<g id=\"edge14\" class=\"edge\">\n<title>predicted_output_with_labels&#45;&gt;classification_report</title>\n<path fill=\"none\" stroke=\"black\" d=\"M551.52,-144.41C571.3,-135.06 596.26,-123.25 617.68,-113.12\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"619.01,-115.89 626.56,-108.45 616.02,-109.56 619.01,-115.89\"/>\n</g>\n<!-- predicted_output_with_labels&#45;&gt;predicted_output_with_labels_to_csv -->\n<g id=\"edge22\" class=\"edge\">\n<title>predicted_output_with_labels&#45;&gt;predicted_output_with_labels_to_csv</title>\n<path fill=\"none\" stroke=\"black\" d=\"M506.06,-143.7C501.28,-135.73 495.51,-126.1 490.2,-117.26\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"492.85,-115.88 484.71,-109.1 486.85,-119.48 492.85,-115.88\"/>\n</g>\n<!-- X_test&#45;&gt;predicted_output -->\n<g id=\"edge27\" 
class=\"edge\">\n<title>X_test&#45;&gt;predicted_output</title>\n<path fill=\"none\" stroke=\"black\" d=\"M495.19,-287.7C495.51,-280.24 495.89,-271.32 496.25,-262.97\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"499.78,-263.25 496.71,-253.1 492.79,-262.95 499.78,-263.25\"/>\n</g>\n<!-- predicted_output&#45;&gt;predicted_output_with_labels -->\n<g id=\"edge23\" class=\"edge\">\n<title>predicted_output&#45;&gt;predicted_output_with_labels</title>\n<path fill=\"none\" stroke=\"black\" d=\"M502.14,-215.7C504.19,-208.15 506.64,-199.12 508.93,-190.68\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"512.56,-191.67 511.8,-181.1 505.81,-189.84 512.56,-191.67\"/>\n</g>\n</g>\n</svg>\n",
"text/plain": "<graphviz.graphs.Digraph at 0x140ea8f10>"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One entry per artifact to save. `id` names the materializer node; it is also the key\n",
"# used later to look up that entry in `materialization_results`.\n",
"materializers = [\n",
"    # model parameters -> JSON file\n",
"    to.json(\n",
"        id=\"model_params_to_json\",\n",
"        dependencies=[\"model_parameters\"],\n",
"        path=\"./data/params.json\",\n",
"    ),\n",
"    # classification report -> .txt file\n",
"    to.file(\n",
"        id=\"classification_report_to_txt\",\n",
"        dependencies=[\"classification_report\"],\n",
"        path=\"./data/classification_report.txt\",\n",
"    ),\n",
"    # fitted model -> pickle file\n",
"    to.pickle(\n",
"        id=\"clf_to_pickle\",\n",
"        dependencies=[\"fit_clf\"],\n",
"        path=\"./data/clf.pkl\",\n",
"    ),\n",
"    # predictions we made -> csv file\n",
"    to.csv(\n",
"        id=\"predicted_output_with_labels_to_csv\",\n",
"        dependencies=[\"predicted_output_with_labels\"],\n",
"        path=\"./data/predicted_output_with_labels.csv\",\n",
"    ),\n",
"]\n",
"\n",
"# Draw the graph for these materializers plus the extra requested variable.\n",
"# Kept as the last expression so the notebook displays the returned graph object.\n",
"dr.visualize_materialization(*materializers, additional_vars=[\"classification_report\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5727b54",
"metadata": {},
"outputs": [],
"source": [
"# Every materializer defined earlier writes under ./data. Create the directory first so\n",
"# a fresh checkout does not fail on the first save (no-op when it already exists).\n",
"os.makedirs(\"./data\", exist_ok=True)\n",
"\n",
"# Run all materializers, and also request the classification report itself so it can be\n",
"# inspected in memory via `additional_vars`.\n",
"materialization_results, additional_vars = dr.materialize(\n",
"    *materializers,\n",
"    additional_vars=[\"classification_report\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8bdfde70",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 94\n",
" 1 0.91 0.93 0.92 85\n",
" 2 0.97 0.99 0.98 96\n",
" 3 0.99 0.97 0.98 93\n",
" 4 0.99 0.92 0.95 88\n",
" 5 0.95 0.95 0.95 85\n",
" 6 0.99 0.97 0.98 97\n",
" 7 0.97 0.97 0.97 89\n",
" 8 0.88 0.88 0.88 82\n",
" 9 0.91 0.97 0.94 90\n",
"\n",
" accuracy 0.96 899\n",
" macro avg 0.95 0.95 0.95 899\n",
"weighted avg 0.96 0.96 0.96 899\n",
"\n"
]
}
],
"source": [
"# Show the report that was returned in memory through `additional_vars`.\n",
"print(additional_vars[\"classification_report\"])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a6f5fe83",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 94\n",
" 1 0.91 0.93 0.92 85\n",
" 2 0.97 0.99 0.98 96\n",
" 3 0.99 0.97 0.98 93\n",
" 4 0.99 0.92 0.95 88\n",
" 5 0.95 0.95 0.95 85\n",
" 6 0.99 0.97 0.98 97\n",
" 7 0.97 0.97 0.97 89\n",
" 8 0.88 0.88 0.88 82\n",
" 9 0.91 0.97 0.94 90\n",
"\n",
" accuracy 0.96 899\n",
" macro avg 0.95 0.95 0.95 899\n",
"weighted avg 0.96 0.96 0.96 899\n",
"\n"
]
}
],
"source": [
"# Read the report back from the path recorded for the txt materializer.\n",
"# The `with` block closes the file handle once the contents have been printed.\n",
"report_path = materialization_results[\"classification_report_to_txt\"][\"path\"]\n",
"with open(report_path) as report_file:\n",
"    print(report_file.read())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}