| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "7fb27b941602401d91542211134fc71a", |
| "metadata": {}, |
| "source": [ |
| "Licensed to the Apache Software Foundation (ASF) under one\nor more contributor license agreements. See the NOTICE file\ndistributed with this work for additional information\nregarding copyright ownership. The ASF licenses this file\nto you under the Apache License, Version 2.0 (the\n\"License\"); you may not use this file except in compliance\nwith the License. You may obtain a copy of the License at\n\n http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing,\nsoftware distributed under the License is distributed on an\n\"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\nKIND, either express or implied. See the License for the\nspecific language governing permissions and limitations\nunder the License." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "66178523", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Execute this cell to install dependencies.\n", |
| "# The requirement is quoted so shells such as zsh don't treat [visualization] as a glob pattern.\n", |
| "%pip install \"sf-hamilton[visualization]\"" |
| ] |
| }, |
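| { |
| "cell_type": "markdown", |
| "id": "97f4cc58", |
| "metadata": {}, |
| "source": [ |
| "# Materialization [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/apache/hamilton/blob/main/examples/materialization/notebook.ipynb) [![View source on GitHub](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/apache/hamilton/blob/main/examples/materialization/notebook.ipynb)\n", |
| "\n", |
| "This notebook shows how to use Hamilton materializers (`to.json`, `to.file`, `to.pickle`, `to.csv`) to save DAG outputs to disk while also returning selected results in memory via `additional_vars`.\n" |
| ] |
| }, |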
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "id": "6e2fc99a", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2023-09-04T17:30:02.674284Z", |
| "start_time": "2023-09-04T17:29:59.371085Z" |
| } |
| }, |
| "outputs": [ |
| { |
| "name": "stderr", |
| "output_type": "stream", |
| "text": [ |
| "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", |
| " warnings.warn(\n" |
| ] |
| } |
| ], |
| "source": [ |
| "import os\n", |
| "\n", |
| "import data_loaders\n", |
| "import model_training\n", |
| "\n", |
| "from hamilton import base, driver\n", |
| "from hamilton.io.materialization import to" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "id": "222756ff", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2023-09-04T17:30:02.682841Z", |
| "start_time": "2023-09-04T17:30:02.679637Z" |
| } |
| }, |
| "outputs": [ |
| { |
| "name": "stderr", |
| "output_type": "stream", |
| "text": [ |
| "Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/apache/hamilton#usage-analytics--data-privacy for details.\n" |
| ] |
| } |
| ], |
| "source": [ |
| "# The classification reports printed later in this notebook (labels 0-9, 899 test rows\n", |
| "# at a 0.5 split) match the 10-class digits dataset, so select it here to keep\n", |
| "# Restart & Run All consistent with those saved outputs.\n", |
| "dag_config = {\n", |
| "    \"test_size_fraction\": 0.5,\n", |
| "    \"shuffle_train_test_split\": True,\n", |
| "    \"data_loader\": \"digits\",\n", |
| "    \"clf\": \"logistic\",\n", |
| "    \"penalty\": \"l2\",\n", |
| "}\n", |
| "# Build the driver over the local data_loaders and model_training modules with this config.\n", |
| "dr = (\n", |
| "    driver.Builder()\n", |
| "    .with_adapter(base.DefaultAdapter())\n", |
| "    .with_config(dag_config)\n", |
| "    .with_modules(data_loaders, model_training)\n", |
| "    .build()\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "id": "2f0e75e3", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2023-09-04T17:30:03.208108Z", |
| "start_time": "2023-09-04T17:30:02.696138Z" |
| } |
| }, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 8.0.5 (20230430.1635)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"787pt\" height=\"620pt\"\n", |
| " viewBox=\"0.00 0.00 787.01 620.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 616)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-616 783.01,-616 783.01,4 -4,4\"/>\n", |
| "<!-- classification_report_to_txt -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>classification_report_to_txt</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"746.7,-36 582.2,-36 582.2,0 746.7,0 746.7,-36\"/>\n", |
| "<text text-anchor=\"middle\" x=\"664.45\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report_to_txt</text>\n", |
| "</g>\n", |
| "<!-- target_names -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>target_names</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"718.45\" cy=\"-378\" rx=\"60.56\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"718.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">target_names</text>\n", |
| "</g>\n", |
| "<!-- y_test_with_labels -->\n", |
| "<g id=\"node14\" class=\"node\">\n", |
| "<title>y_test_with_labels</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"627.45\" cy=\"-306\" rx=\"80.01\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"627.45\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test_with_labels</text>\n", |
| "</g>\n", |
| "<!-- target_names->y_test_with_labels -->\n", |
| "<g id=\"edge19\" class=\"edge\">\n", |
| "<title>target_names->y_test_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M697.34,-360.76C685.57,-351.71 670.7,-340.27 657.7,-330.28\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"660.13,-326.96 650.07,-323.63 655.86,-332.5 660.13,-326.96\"/>\n", |
| "</g>\n", |
| "<!-- predicted_output_with_labels -->\n", |
| "<g id=\"node20\" class=\"node\">\n", |
| "<title>predicted_output_with_labels</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"516.45\" cy=\"-162\" rx=\"120.45\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"516.45\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels</text>\n", |
| "</g>\n", |
| "<!-- target_names->predicted_output_with_labels -->\n", |
| "<g id=\"edge24\" class=\"edge\">\n", |
| "<title>target_names->predicted_output_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M722.65,-359.66C726.24,-340.73 729.06,-310.16 716.45,-288 686.35,-235.12 623.81,-201.51 576.63,-182.62\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"578.1,-179.05 567.51,-178.7 575.57,-185.58 578.1,-179.05\"/>\n", |
| "</g>\n", |
| "<!-- X_train -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>X_train</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"427.45\" cy=\"-378\" rx=\"39.07\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"427.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_train</text>\n", |
| "</g>\n", |
| "<!-- fit_clf -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>fit_clf</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"331.45\" cy=\"-306\" rx=\"33.44\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"331.45\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">fit_clf</text>\n", |
| "</g>\n", |
| "<!-- X_train->fit_clf -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>X_train->fit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M407.09,-362.15C393.19,-352.02 374.61,-338.47 359.34,-327.34\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"361.94,-324.17 351.79,-321.11 357.81,-329.83 361.94,-324.17\"/>\n", |
| "</g>\n", |
| "<!-- train_test_split_func -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>train_test_split_func</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"460.45\" cy=\"-450\" rx=\"86.67\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"460.45\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">train_test_split_func</text>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->X_train -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>train_test_split_func->X_train</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M452.29,-431.7C448.61,-423.9 444.19,-414.51 440.09,-405.83\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"442.94,-404.66 435.51,-397.1 436.61,-407.64 442.94,-404.66\"/>\n", |
| "</g>\n", |
| "<!-- y_train -->\n", |
| "<g id=\"node9\" class=\"node\">\n", |
| "<title>y_train</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"333.45\" cy=\"-378\" rx=\"37.02\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"333.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_train</text>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->y_train -->\n", |
| "<g id=\"edge12\" class=\"edge\">\n", |
| "<title>train_test_split_func->y_train</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M430.99,-432.76C411.81,-422.19 386.74,-408.38 366.82,-397.4\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"368.67,-393.87 358.22,-392.1 365.29,-400 368.67,-393.87\"/>\n", |
| "</g>\n", |
| "<!-- y_test -->\n", |
| "<g id=\"node12\" class=\"node\">\n", |
| "<title>y_test</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"573.45\" cy=\"-378\" rx=\"32.93\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"573.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test</text>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->y_test -->\n", |
| "<g id=\"edge16\" class=\"edge\">\n", |
| "<title>train_test_split_func->y_test</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M487.22,-432.41C503.97,-422.04 525.57,-408.66 542.97,-397.88\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"544.48,-400.44 551.14,-392.2 540.8,-394.49 544.48,-400.44\"/>\n", |
| "</g>\n", |
| "<!-- X_test -->\n", |
| "<g id=\"node21\" class=\"node\">\n", |
| "<title>X_test</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"494.45\" cy=\"-306\" rx=\"34.97\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"494.45\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_test</text>\n", |
| "</g>\n", |
| "<!-- train_test_split_func->X_test -->\n", |
| "<g id=\"edge25\" class=\"edge\">\n", |
| "<title>train_test_split_func->X_test</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M465.6,-431.89C468.63,-421.54 472.45,-408.07 475.45,-396 480.5,-375.62 485.43,-352.42 489,-334.81\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"492.6,-335.63 491.13,-325.14 485.74,-334.26 492.6,-335.63\"/>\n", |
| "</g>\n", |
| "<!-- shuffle_train_test_split -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>shuffle_train_test_split</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"120.45\" cy=\"-522\" rx=\"120.45\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"120.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: shuffle_train_test_split</text>\n", |
| "</g>\n", |
| "<!-- shuffle_train_test_split->train_test_split_func -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>shuffle_train_test_split->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M189.25,-506.83C247.61,-494.82 331,-477.65 389.86,-465.53\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"390.47,-468.77 399.56,-463.33 389.06,-461.92 390.47,-468.77\"/>\n", |
| "</g>\n", |
| "<!-- clf_to_pickle -->\n", |
| "<g id=\"node10\" class=\"node\">\n", |
| "<title>clf_to_pickle</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"228.07,-252 140.82,-252 140.82,-216 228.07,-216 228.07,-252\"/>\n", |
| "<text text-anchor=\"middle\" x=\"184.45\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">clf_to_pickle</text>\n", |
| "</g>\n", |
| "<!-- fit_clf->clf_to_pickle -->\n", |
| "<g id=\"edge13\" class=\"edge\">\n", |
| "<title>fit_clf->clf_to_pickle</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M306.84,-293.28C286.13,-283.42 255.95,-269.05 230.73,-257.04\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"232.52,-253.54 221.99,-252.4 229.52,-259.86 232.52,-253.54\"/>\n", |
| "</g>\n", |
| "<!-- model_parameters -->\n", |
| "<g id=\"node13\" class=\"node\">\n", |
| "<title>model_parameters</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"326.45\" cy=\"-234\" rx=\"80.01\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"326.45\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">model_parameters</text>\n", |
| "</g>\n", |
| "<!-- fit_clf->model_parameters -->\n", |
| "<g id=\"edge17\" class=\"edge\">\n", |
| "<title>fit_clf->model_parameters</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M330.21,-287.7C329.68,-280.24 329.04,-271.32 328.44,-262.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"331.87,-262.83 327.67,-253.1 324.89,-263.33 331.87,-262.83\"/>\n", |
| "</g>\n", |
| "<!-- predicted_output -->\n", |
| "<g id=\"node22\" class=\"node\">\n", |
| "<title>predicted_output</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"497.45\" cy=\"-234\" rx=\"73.36\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"497.45\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output</text>\n", |
| "</g>\n", |
| "<!-- fit_clf->predicted_output -->\n", |
| "<g id=\"edge26\" class=\"edge\">\n", |
| "<title>fit_clf->predicted_output</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M357.02,-294.22C382.19,-283.6 421.17,-267.17 451.68,-254.3\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"452.82,-257.2 460.67,-250.09 450.1,-250.75 452.82,-257.2\"/>\n", |
| "</g>\n", |
| "<!-- data -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>data</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"658.45\" cy=\"-594\" rx=\"27\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"658.45\" y=\"-588.95\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n", |
| "</g>\n", |
| "<!-- data->target_names -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>data->target_names</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M673.08,-578.64C682.26,-568.69 693.46,-554.65 699.45,-540 717.45,-495.95 719.96,-439.86 719.59,-406.85\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"723.07,-407.11 719.36,-397.19 716.07,-407.27 723.07,-407.11\"/>\n", |
| "</g>\n", |
| "<!-- feature_matrix -->\n", |
| "<g id=\"node15\" class=\"node\">\n", |
| "<title>feature_matrix</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"324.45\" cy=\"-522\" rx=\"65.68\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"324.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">feature_matrix</text>\n", |
| "</g>\n", |
| "<!-- data->feature_matrix -->\n", |
| "<g id=\"edge20\" class=\"edge\">\n", |
| "<title>data->feature_matrix</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M632.67,-587.91C586.07,-578.67 484.69,-558.38 399.45,-540 394.49,-538.93 389.35,-537.8 384.19,-536.66\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"385.15,-533.07 374.63,-534.3 383.63,-539.9 385.15,-533.07\"/>\n", |
| "</g>\n", |
| "<!-- target -->\n", |
| "<g id=\"node17\" class=\"node\">\n", |
| "<title>target</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"658.45\" cy=\"-522\" rx=\"31.9\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"658.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">target</text>\n", |
| "</g>\n", |
| "<!-- data->target -->\n", |
| "<g id=\"edge21\" class=\"edge\">\n", |
| "<title>data->target</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M658.45,-575.7C658.45,-568.24 658.45,-559.32 658.45,-550.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"661.95,-551.1 658.45,-541.1 654.95,-551.1 661.95,-551.1\"/>\n", |
| "</g>\n", |
| "<!-- prefit_clf -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>prefit_clf</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"233.45\" cy=\"-378\" rx=\"45.21\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"233.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">prefit_clf</text>\n", |
| "</g>\n", |
| "<!-- prefit_clf->fit_clf -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>prefit_clf->fit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M254.71,-361.81C268.92,-351.66 287.77,-338.19 303.24,-327.15\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"304.78,-329.63 310.89,-320.97 300.72,-323.93 304.78,-329.63\"/>\n", |
| "</g>\n", |
| "<!-- y_train->fit_clf -->\n", |
| "<g id=\"edge10\" class=\"edge\">\n", |
| "<title>y_train->fit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M332.95,-359.7C332.74,-352.24 332.48,-343.32 332.24,-334.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"335.72,-335 331.93,-325.1 328.72,-335.2 335.72,-335\"/>\n", |
| "</g>\n", |
| "<!-- classification_report -->\n", |
| "<g id=\"node11\" class=\"node\">\n", |
| "<title>classification_report</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"727.57,-108 601.32,-108 601.32,-72 727.57,-72 727.57,-108\"/>\n", |
| "<text text-anchor=\"middle\" x=\"664.45\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report</text>\n", |
| "</g>\n", |
| "<!-- classification_report->classification_report_to_txt -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>classification_report->classification_report_to_txt</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M664.45,-71.7C664.45,-64.24 664.45,-55.32 664.45,-46.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"667.95,-47.1 664.45,-37.1 660.95,-47.1 667.95,-47.1\"/>\n", |
| "</g>\n", |
| "<!-- y_test->y_test_with_labels -->\n", |
| "<g id=\"edge18\" class=\"edge\">\n", |
| "<title>y_test->y_test_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M585.7,-361.12C592.27,-352.59 600.53,-341.89 607.97,-332.25\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"611.15,-334.85 614.49,-324.8 605.61,-330.58 611.15,-334.85\"/>\n", |
| "</g>\n", |
| "<!-- model_params_to_json -->\n", |
| "<g id=\"node23\" class=\"node\">\n", |
| "<title>model_params_to_json</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"378.2,-180 234.7,-180 234.7,-144 378.2,-144 378.2,-180\"/>\n", |
| "<text text-anchor=\"middle\" x=\"306.45\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">model_params_to_json</text>\n", |
| "</g>\n", |
| "<!-- model_parameters->model_params_to_json -->\n", |
| "<g id=\"edge28\" class=\"edge\">\n", |
| "<title>model_parameters->model_params_to_json</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M321.5,-215.7C319.35,-208.15 316.77,-199.12 314.35,-190.68\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"317.45,-189.76 311.33,-181.1 310.71,-191.68 317.45,-189.76\"/>\n", |
| "</g>\n", |
| "<!-- y_test_with_labels->classification_report -->\n", |
| "<g id=\"edge15\" class=\"edge\">\n", |
| "<title>y_test_with_labels->classification_report</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M630.41,-287.85C636.78,-250.99 651.84,-163.92 659.61,-118.96\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"663.21,-119.68 661.47,-109.23 656.31,-118.49 663.21,-119.68\"/>\n", |
| "</g>\n", |
| "<!-- feature_matrix->train_test_split_func -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>feature_matrix->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M354.29,-505.64C373.46,-495.77 398.61,-482.83 419.62,-472.01\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"420.96,-474.74 428.25,-467.06 417.76,-468.52 420.96,-474.74\"/>\n", |
| "</g>\n", |
| "<!-- penalty -->\n", |
| "<g id=\"node16\" class=\"node\">\n", |
| "<title>penalty</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"233.45\" cy=\"-450\" rx=\"62.61\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"233.45\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: penalty</text>\n", |
| "</g>\n", |
| "<!-- penalty->prefit_clf -->\n", |
| "<g id=\"edge11\" class=\"edge\">\n", |
| "<title>penalty->prefit_clf</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M233.45,-431.7C233.45,-424.24 233.45,-415.32 233.45,-406.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"236.95,-407.1 233.45,-397.1 229.95,-407.1 236.95,-407.1\"/>\n", |
| "</g>\n", |
| "<!-- target->train_test_split_func -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>target->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M633.4,-510.3C628.15,-508.15 622.64,-505.95 617.45,-504 584.32,-491.56 546.78,-478.84 516.72,-468.98\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"518.01,-465.39 507.42,-465.61 515.84,-472.04 518.01,-465.39\"/>\n", |
| "</g>\n", |
| "<!-- predicted_output_with_labels_to_csv -->\n", |
| "<g id=\"node18\" class=\"node\">\n", |
| "<title>predicted_output_with_labels_to_csv</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"583.7,-108 365.2,-108 365.2,-72 583.7,-72 583.7,-108\"/>\n", |
| "<text text-anchor=\"middle\" x=\"474.45\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels_to_csv</text>\n", |
| "</g>\n", |
| "<!-- test_size_fraction -->\n", |
| "<g id=\"node19\" class=\"node\">\n", |
| "<title>test_size_fraction</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"508.45\" cy=\"-522\" rx=\"100.48\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"508.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: test_size_fraction</text>\n", |
| "</g>\n", |
| "<!-- test_size_fraction->train_test_split_func -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>test_size_fraction->train_test_split_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M496.58,-503.7C491.06,-495.64 484.37,-485.89 478.26,-476.98\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"480.72,-475.37 472.17,-469.1 474.94,-479.33 480.72,-475.37\"/>\n", |
| "</g>\n", |
| "<!-- predicted_output_with_labels->classification_report -->\n", |
| "<g id=\"edge14\" class=\"edge\">\n", |
| "<title>predicted_output_with_labels->classification_report</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M551.52,-144.41C571.3,-135.06 596.26,-123.25 617.68,-113.12\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"619.01,-115.89 626.56,-108.45 616.02,-109.56 619.01,-115.89\"/>\n", |
| "</g>\n", |
| "<!-- predicted_output_with_labels->predicted_output_with_labels_to_csv -->\n", |
| "<g id=\"edge22\" class=\"edge\">\n", |
| "<title>predicted_output_with_labels->predicted_output_with_labels_to_csv</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M506.06,-143.7C501.28,-135.73 495.51,-126.1 490.2,-117.26\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"492.85,-115.88 484.71,-109.1 486.85,-119.48 492.85,-115.88\"/>\n", |
| "</g>\n", |
| "<!-- X_test->predicted_output -->\n", |
| "<g id=\"edge27\" class=\"edge\">\n", |
| "<title>X_test->predicted_output</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M495.19,-287.7C495.51,-280.24 495.89,-271.32 496.25,-262.97\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"499.78,-263.25 496.71,-253.1 492.79,-262.95 499.78,-263.25\"/>\n", |
| "</g>\n", |
| "<!-- predicted_output->predicted_output_with_labels -->\n", |
| "<g id=\"edge23\" class=\"edge\">\n", |
| "<title>predicted_output->predicted_output_with_labels</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M502.14,-215.7C504.19,-208.15 506.64,-199.12 508.93,-190.68\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"512.56,-191.67 511.8,-181.1 505.81,-189.84 512.56,-191.67\"/>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<graphviz.graphs.Digraph at 0x140ea8f10>" |
| ] |
| }, |
| "execution_count": 3, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# Directory every materializer below writes into. Create it up front in case the\n", |
| "# savers do not create missing parent directories on a fresh checkout.\n", |
| "DATA_DIR = \"./data\"\n", |
| "os.makedirs(DATA_DIR, exist_ok=True)\n", |
| "\n", |
| "materializers = [\n", |
| "    # materialize the model parameters to a json file\n", |
| "    to.json(\n", |
| "        dependencies=[\"model_parameters\"],\n", |
| "        id=\"model_params_to_json\",\n", |
| "        path=f\"{DATA_DIR}/params.json\",\n", |
| "    ),\n", |
| "    # classification report to .txt file\n", |
| "    to.file(\n", |
| "        dependencies=[\"classification_report\"],\n", |
| "        id=\"classification_report_to_txt\",\n", |
| "        path=f\"{DATA_DIR}/classification_report.txt\",\n", |
| "    ),\n", |
| "    # materialize the model to a pickle file\n", |
| "    to.pickle(dependencies=[\"fit_clf\"], id=\"clf_to_pickle\", path=f\"{DATA_DIR}/clf.pkl\"),\n", |
| "    # materialize the predictions we made to a csv file\n", |
| "    to.csv(\n", |
| "        dependencies=[\"predicted_output_with_labels\"],\n", |
| "        id=\"predicted_output_with_labels_to_csv\",\n", |
| "        path=f\"{DATA_DIR}/predicted_output_with_labels.csv\",\n", |
| "    ),\n", |
| "]\n", |
| "\n", |
| "# Visualize the DAG with the materializer nodes added.\n", |
| "dr.visualize_materialization(\n", |
| "    *materializers,\n", |
| "    additional_vars=[\"classification_report\"],\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "bce54351", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Run the DAG: each materializer saves its dependency to disk and reports metadata\n", |
| "# (such as the saved path) keyed by its id in materialization_results, while the\n", |
| "# values named in additional_vars are returned in memory.\n", |
| "materialization_results, additional_vars = dr.materialize(\n", |
| "    *materializers,\n", |
| "    additional_vars=[\"classification_report\"],\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "id": "282f9688", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| " precision recall f1-score support\n", |
| "\n", |
| " 0 1.00 1.00 1.00 94\n", |
| " 1 0.91 0.93 0.92 85\n", |
| " 2 0.97 0.99 0.98 96\n", |
| " 3 0.99 0.97 0.98 93\n", |
| " 4 0.99 0.92 0.95 88\n", |
| " 5 0.95 0.95 0.95 85\n", |
| " 6 0.99 0.97 0.98 97\n", |
| " 7 0.97 0.97 0.97 89\n", |
| " 8 0.88 0.88 0.88 82\n", |
| " 9 0.91 0.97 0.94 90\n", |
| "\n", |
| " accuracy 0.96 899\n", |
| " macro avg 0.95 0.95 0.95 899\n", |
| "weighted avg 0.96 0.96 0.96 899\n", |
| "\n" |
| ] |
| } |
| ], |
| "source": [ |
| "# print() shows the in-memory report as a readable text table (see output below).\n", |
| "report_text = additional_vars[\"classification_report\"]\n", |
| "print(report_text)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 14, |
| "id": "3df08628", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| " precision recall f1-score support\n", |
| "\n", |
| " 0 1.00 1.00 1.00 94\n", |
| " 1 0.91 0.93 0.92 85\n", |
| " 2 0.97 0.99 0.98 96\n", |
| " 3 0.99 0.97 0.98 93\n", |
| " 4 0.99 0.92 0.95 88\n", |
| " 5 0.95 0.95 0.95 85\n", |
| " 6 0.99 0.97 0.98 97\n", |
| " 7 0.97 0.97 0.97 89\n", |
| " 8 0.88 0.88 0.88 82\n", |
| " 9 0.91 0.97 0.94 90\n", |
| "\n", |
| " accuracy 0.96 899\n", |
| " macro avg 0.95 0.95 0.95 899\n", |
| "weighted avg 0.96 0.96 0.96 899\n", |
| "\n" |
| ] |
| } |
| ], |
| "source": [ |
| "# Read back the text file written by the classification_report_to_txt materializer;\n", |
| "# the with-block closes the file handle once it has been read.\n", |
| "with open(materialization_results[\"classification_report_to_txt\"][\"path\"]) as report_file:\n", |
| "    print(report_file.read())" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3 (ipykernel)", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.9.10" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |