| { |
| "cells": [ |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Execute this cell to install dependencies\n", |
| "%pip install sf-hamilton[visualization] pandas scikit-learn numpy" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# MPG Simple Advanced Target [Open in Colab](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/hamilton-tutorials/mpg-translation/MPGSimpleAdvancedTarget.ipynb) [View source on GitHub](https://github.com/dagworks-inc/hamilton/blob/main/examples/hamilton-tutorials/mpg-translation/MPGSimpleAdvancedTarget.ipynb)\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "metadata": { |
| "ExecuteTime": { |
| "start_time": "2024-07-20T17:55:33.133151Z" |
| }, |
| "application/vnd.databricks.v1+cell": { |
| "cellMetadata": { |
| "byteLimit": 2048000, |
| "rowLimit": 10000 |
| }, |
| "inputWidgets": {}, |
| "nuid": "f7ca0a2e-99c4-49de-af45-c8c4bddf5685", |
| "showTitle": false, |
| "title": "" |
| }, |
| "jupyter": { |
| "is_executing": true |
| } |
| }, |
| "outputs": [ |
| { |
| "name": "stderr", |
| "output_type": "stream", |
| "text": [ |
| "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", |
| " warnings.warn(\n" |
| ] |
| } |
| ], |
| "source": [ |
| "from hamilton import driver\n", |
| "from IPython.display import HTML, display" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-07-20T17:32:57.962238Z", |
| "start_time": "2024-07-20T17:32:57.947142Z" |
| }, |
| "application/vnd.databricks.v1+cell": { |
| "cellMetadata": { |
| "byteLimit": 2048000, |
| "rowLimit": 10000 |
| }, |
| "inputWidgets": {}, |
| "nuid": "45fcd1cf-5dee-4d3c-b598-823c82654805", |
| "showTitle": false, |
| "title": "" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "%load_ext hamilton.plugins.jupyter_magic" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-07-20T17:37:00.825770Z", |
| "start_time": "2024-07-20T17:37:00.183488Z" |
| }, |
| "application/vnd.databricks.v1+cell": { |
| "cellMetadata": { |
| "byteLimit": 2048000, |
| "rowLimit": 10000 |
| }, |
| "inputWidgets": {}, |
| "nuid": "155ea802-aef6-4d5c-b264-d9ec5b57c733", |
| "showTitle": false, |
| "title": "" |
| } |
| }, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"865pt\" height=\"304pt\"\n", |
| " viewBox=\"0.00 0.00 864.60 303.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 299.8)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-299.8 860.6,-299.8 860.6,4 -4,4\"/>\n", |
| "<g id=\"clust1\" class=\"cluster\">\n", |
| "<title>cluster__legend</title>\n", |
| "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"31.38,-157.8 31.38,-287.8 116.22,-287.8 116.22,-157.8 31.38,-157.8\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-270.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n", |
| "</g>\n", |
| "<!-- target_column -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>target_column</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M392.8,-256.6C392.8,-256.6 301.45,-256.6 301.45,-256.6 295.45,-256.6 289.45,-250.6 289.45,-244.6 289.45,-244.6 289.45,-205 289.45,-205 289.45,-199 295.45,-193 301.45,-193 301.45,-193 392.8,-193 392.8,-193 398.8,-193 404.8,-199 404.8,-205 404.8,-205 404.8,-244.6 404.8,-244.6 404.8,-250.6 398.8,-256.6 392.8,-256.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"300.25\" y=\"-233.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">target_column</text>\n", |
| "<text text-anchor=\"start\" x=\"339.62\" y=\"-205.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n", |
| "</g>\n", |
| "<!-- evaluated_model -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>evaluated_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M844.6,-160.6C844.6,-160.6 736.75,-160.6 736.75,-160.6 730.75,-160.6 724.75,-154.6 724.75,-148.6 724.75,-148.6 724.75,-109 724.75,-109 724.75,-103 730.75,-97 736.75,-97 736.75,-97 844.6,-97 844.6,-97 850.6,-97 856.6,-103 856.6,-109 856.6,-109 856.6,-148.6 856.6,-148.6 856.6,-154.6 850.6,-160.6 844.6,-160.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"735.55\" y=\"-137.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">evaluated_model</text>\n", |
| "<text text-anchor=\"start\" x=\"780.18\" y=\"-109.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- target_column->evaluated_model -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>target_column->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M405.29,-233.76C476.21,-242.38 599.51,-249.12 695.75,-210.8 718.39,-201.78 739.28,-184.84 755.58,-168.68\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"757.68,-171.55 762.15,-161.94 752.67,-166.66 757.68,-171.55\"/>\n", |
| "</g>\n", |
| "<!-- linear_model -->\n", |
| "<g id=\"node9\" class=\"node\">\n", |
| "<title>linear_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M527.4,-183.6C527.4,-183.6 445.8,-183.6 445.8,-183.6 439.8,-183.6 433.8,-177.6 433.8,-171.6 433.8,-171.6 433.8,-132 433.8,-132 433.8,-126 439.8,-120 445.8,-120 445.8,-120 527.4,-120 527.4,-120 533.4,-120 539.4,-126 539.4,-132 539.4,-132 539.4,-171.6 539.4,-171.6 539.4,-177.6 533.4,-183.6 527.4,-183.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"444.6\" y=\"-160.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">linear_model</text>\n", |
| "<text text-anchor=\"start\" x=\"476.1\" y=\"-132.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- target_column->linear_model -->\n", |
| "<g id=\"edge12\" class=\"edge\">\n", |
| "<title>target_column->linear_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M405.23,-194.48C411.28,-191.26 417.45,-187.99 423.56,-184.75\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"425.2,-187.84 432.39,-180.06 421.91,-181.65 425.2,-187.84\"/>\n", |
| "</g>\n", |
| "<!-- lr_model -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>lr_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M683.75,-201.6C683.75,-201.6 580.4,-201.6 580.4,-201.6 574.4,-201.6 568.4,-195.6 568.4,-189.6 568.4,-189.6 568.4,-150 568.4,-150 568.4,-144 574.4,-138 580.4,-138 580.4,-138 683.75,-138 683.75,-138 689.75,-138 695.75,-144 695.75,-150 695.75,-150 695.75,-189.6 695.75,-189.6 695.75,-195.6 689.75,-201.6 683.75,-201.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"603.58\" y=\"-178.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">lr_model</text>\n", |
| "<text text-anchor=\"start\" x=\"579.2\" y=\"-150.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">LinearRegression</text>\n", |
| "</g>\n", |
| "<!-- lr_model->evaluated_model -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>lr_model->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M695.87,-153.36C701.57,-151.87 707.39,-150.34 713.19,-148.82\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"714.06,-152.22 722.84,-146.3 712.28,-145.44 714.06,-152.22\"/>\n", |
| "</g>\n", |
| "<!-- mpg_df -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>mpg_df</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M106.72,-147.6C106.72,-147.6 40.87,-147.6 40.87,-147.6 34.87,-147.6 28.87,-141.6 28.87,-135.6 28.87,-135.6 28.87,-96 28.87,-96 28.87,-90 34.87,-84 40.87,-84 40.87,-84 106.72,-84 106.72,-84 112.72,-84 118.72,-90 118.72,-96 118.72,-96 118.72,-135.6 118.72,-135.6 118.72,-141.6 112.72,-147.6 106.72,-147.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"49.05\" y=\"-124.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">mpg_df</text>\n", |
| "<text text-anchor=\"start\" x=\"39.67\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n", |
| "</g>\n", |
| "<!-- data_sets -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>data_sets</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M248.45,-105.6C248.45,-105.6 188.6,-105.6 188.6,-105.6 182.6,-105.6 176.6,-99.6 176.6,-93.6 176.6,-93.6 176.6,-54 176.6,-54 176.6,-48 182.6,-42 188.6,-42 188.6,-42 248.45,-42 248.45,-42 254.45,-42 260.45,-48 260.45,-54 260.45,-54 260.45,-93.6 260.45,-93.6 260.45,-99.6 254.45,-105.6 248.45,-105.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"187.4\" y=\"-82.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">data_sets</text>\n", |
| "<text text-anchor=\"start\" x=\"208.02\" y=\"-54.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- mpg_df->data_sets -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>mpg_df->data_sets</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M119.18,-102.74C133.93,-98.4 150.48,-93.53 165.74,-89.04\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"166.52,-92.46 175.13,-86.28 164.54,-85.74 166.52,-92.46\"/>\n", |
| "</g>\n", |
| "<!-- test_dataset -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>test_dataset</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M385.3,-64.6C385.3,-64.6 308.95,-64.6 308.95,-64.6 302.95,-64.6 296.95,-58.6 296.95,-52.6 296.95,-52.6 296.95,-13 296.95,-13 296.95,-7 302.95,-1 308.95,-1 308.95,-1 385.3,-1 385.3,-1 391.3,-1 397.3,-7 397.3,-13 397.3,-13 397.3,-52.6 397.3,-52.6 397.3,-58.6 391.3,-64.6 385.3,-64.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"307.75\" y=\"-41.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">test_dataset</text>\n", |
| "<text text-anchor=\"start\" x=\"313\" y=\"-13.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n", |
| "</g>\n", |
| "<!-- data_sets->test_dataset -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>data_sets->test_dataset</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M260.63,-60.49C268.75,-57.86 277.43,-55.05 286.03,-52.26\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"286.96,-55.64 295.4,-49.23 284.81,-48.98 286.96,-55.64\"/>\n", |
| "</g>\n", |
| "<!-- train_dataset -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>train_dataset</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M387.93,-165.6C387.93,-165.6 306.32,-165.6 306.32,-165.6 300.32,-165.6 294.32,-159.6 294.32,-153.6 294.32,-153.6 294.32,-114 294.32,-114 294.32,-108 300.32,-102 306.32,-102 306.32,-102 387.93,-102 387.93,-102 393.93,-102 399.93,-108 399.93,-114 399.93,-114 399.93,-153.6 399.93,-153.6 399.93,-159.6 393.93,-165.6 387.93,-165.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"305.12\" y=\"-142.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">train_dataset</text>\n", |
| "<text text-anchor=\"start\" x=\"313\" y=\"-114.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n", |
| "</g>\n", |
| "<!-- data_sets->train_dataset -->\n", |
| "<g id=\"edge10\" class=\"edge\">\n", |
| "<title>data_sets->train_dataset</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M260.63,-93.28C268.06,-96.8 275.96,-100.54 283.83,-104.28\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"282.23,-107.39 292.76,-108.51 285.23,-101.07 282.23,-107.39\"/>\n", |
| "</g>\n", |
| "<!-- scaler -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>scaler</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M677.38,-119.6C677.38,-119.6 586.78,-119.6 586.78,-119.6 580.78,-119.6 574.78,-113.6 574.78,-107.6 574.78,-107.6 574.78,-68 574.78,-68 574.78,-62 580.78,-56 586.78,-56 586.78,-56 677.38,-56 677.38,-56 683.38,-56 689.38,-62 689.38,-68 689.38,-68 689.38,-107.6 689.38,-107.6 689.38,-113.6 683.38,-119.6 677.38,-119.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"612.58\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scaler</text>\n", |
| "<text text-anchor=\"start\" x=\"585.58\" y=\"-68.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">StandardScaler</text>\n", |
| "</g>\n", |
| "<!-- scaler->evaluated_model -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>scaler->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M689.61,-102.6C697.38,-104.64 705.46,-106.75 713.51,-108.86\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"712.41,-112.19 722.97,-111.34 714.19,-105.42 712.41,-112.19\"/>\n", |
| "</g>\n", |
| "<!-- test_dataset->evaluated_model -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>test_dataset->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M397.75,-24.79C467.18,-15.77 596.01,-7.08 695.75,-46.8 718.39,-55.82 739.28,-72.76 755.58,-88.92\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"752.67,-90.94 762.15,-95.66 757.68,-86.05 752.67,-90.94\"/>\n", |
| "</g>\n", |
| "<!-- train_dataset->linear_model -->\n", |
| "<g id=\"edge11\" class=\"edge\">\n", |
| "<title>train_dataset->linear_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M400.1,-140.61C407.37,-141.56 414.91,-142.54 422.36,-143.52\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"421.68,-146.96 432.05,-144.79 422.59,-140.02 421.68,-146.96\"/>\n", |
| "</g>\n", |
| "<!-- linear_model->lr_model -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>linear_model->lr_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M539.8,-158.35C545.37,-159.05 551.11,-159.77 556.86,-160.49\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"556.06,-163.92 566.42,-161.69 556.93,-156.97 556.06,-163.92\"/>\n", |
| "</g>\n", |
| "<!-- linear_model->scaler -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>linear_model->scaler</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M539.8,-128.51C547.69,-125 555.92,-121.32 564.07,-117.69\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"565.22,-121.01 572.93,-113.74 562.37,-114.62 565.22,-121.01\"/>\n", |
| "</g>\n", |
| "<!-- _data_sets_inputs -->\n", |
| "<g id=\"node10\" class=\"node\">\n", |
| "<title>_data_sets_inputs</title>\n", |
| "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"147.6,-65.6 0,-65.6 0,0 147.6,0 147.6,-65.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"43.67\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n", |
| "<text text-anchor=\"start\" x=\"113.17\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "<text text-anchor=\"start\" x=\"14.8\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">train_test_split</text>\n", |
| "<text text-anchor=\"start\" x=\"107.55\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">float</text>\n", |
| "</g>\n", |
| "<!-- _data_sets_inputs->data_sets -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>_data_sets_inputs->data_sets</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M147.83,-53.78C153.77,-55.49 159.69,-57.19 165.43,-58.83\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"164.17,-62.11 174.75,-61.51 166.1,-55.39 164.17,-62.11\"/>\n", |
| "</g>\n", |
| "<!-- input -->\n", |
| "<g id=\"node11\" class=\"node\">\n", |
| "<title>input</title>\n", |
| "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"100.8,-202.1 46.8,-202.1 46.8,-165.5 100.8,-165.5 100.8,-202.1\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-178\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n", |
| "</g>\n", |
| "<!-- function -->\n", |
| "<g id=\"node12\" class=\"node\">\n", |
| "<title>function</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M96.22,-257.1C96.22,-257.1 51.37,-257.1 51.37,-257.1 45.37,-257.1 39.37,-251.1 39.37,-245.1 39.37,-245.1 39.37,-232.5 39.37,-232.5 39.37,-226.5 45.37,-220.5 51.37,-220.5 51.37,-220.5 96.22,-220.5 96.22,-220.5 102.22,-220.5 108.22,-226.5 108.22,-232.5 108.22,-232.5 108.22,-245.1 108.22,-245.1 108.22,-251.1 102.22,-257.1 96.22,-257.1\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-233\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<graphviz.graphs.Digraph at 0x1516f2d60>" |
| ] |
| }, |
| "metadata": {}, |
| "output_type": "display_data" |
| } |
| ], |
| "source": [ |
| "%%cell_to_module pipeline --display\n", |
| "# when done you can write to file and then load it as a module normally\n", |
| "# add -w to do so\n", |
| "import numpy as np\n", |
| "import pandas as pd\n", |
| "from sklearn.linear_model import LinearRegression\n", |
| "from sklearn.preprocessing import StandardScaler\n", |
| "from sklearn.metrics import mean_absolute_error\n", |
| "\n", |
| "from hamilton.function_modifiers import extract_fields\n", |
| "\n", |
| "\n", |
| "def mpg_df() -> pd.DataFrame:\n", |
| "    \"\"\"Download the UCI Auto MPG data and return it as a DataFrame.\n", |
| "\n", |
| "    The file is read as space-separated values: '?' marks a missing value and\n", |
| "    anything after a tab character is skipped as a comment. The \"Model Year\"\n", |
| "    column is renamed to \"ModelYear\".\n", |
| "    \"\"\"\n", |
| "    columns = [\n", |
| "        'MPG', 'Cylinders', 'Displacement', 'Horsepower',\n", |
| "        'Weight', 'Acceleration', 'Model Year', 'Origin',\n", |
| "    ]\n", |
| "    source_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'\n", |
| "\n", |
| "    loaded = pd.read_csv(\n", |
| "        source_url,\n", |
| "        names=columns,\n", |
| "        na_values='?',\n", |
| "        comment='\\t',\n", |
| "        sep=' ',\n", |
| "        skipinitialspace=True,\n", |
| "    )\n", |
| "\n", |
| "    # Schema manipulation: remove the space from the column name.\n", |
| "    return loaded.rename(columns={\"Model Year\": \"ModelYear\"})\n", |
| "\n", |
| "\n", |
| "@extract_fields({\"train_dataset\": pd.DataFrame, \"test_dataset\": pd.DataFrame})\n", |
| "def data_sets(mpg_df: pd.DataFrame, train_test_split: float = 0.8, seed: int = 123) -> dict:\n", |
| "    \"\"\"Build the train and test data sets from the raw MPG frame.\n", |
| "\n", |
| "    :param mpg_df: raw data with an \"Origin\" column coded 1/2/3. Not modified.\n", |
| "    :param train_test_split: fraction of the cleaned rows sampled into the training set.\n", |
| "    :param seed: random state used for the sample, so the split is reproducible.\n", |
| "    :return: dict with the \"train_dataset\" and \"test_dataset\" DataFrames.\n", |
| "    \"\"\"\n", |
| "    # One hot encode -- we know the encoding here. `assign` returns a new frame,\n", |
| "    # so the DataFrame produced by the upstream `mpg_df` node is left untouched\n", |
| "    # for any other node that consumes it.\n", |
| "    origin_codes = {1: \"USA\", 2: \"Europe\", 3: \"Japan\"}\n", |
| "    encoded = mpg_df.assign(\n", |
| "        **{\n", |
| "            country: np.where(mpg_df[\"Origin\"] == code, 1, 0)\n", |
| "            for code, country in origin_codes.items()\n", |
| "        }\n", |
| "    )\n", |
| "    cleaned = encoded.dropna()\n", |
| "    # Split into train and test: sample the training rows, the remainder is test.\n", |
| "    train_dataset = cleaned.sample(frac=train_test_split, random_state=seed)\n", |
| "    test_dataset = cleaned.drop(train_dataset.index)\n", |
| "\n", |
| "    return {\"train_dataset\": train_dataset, \"test_dataset\": test_dataset}\n", |
| "\n", |
| "\n", |
| "def target_column() -> str:\n", |
| "    \"\"\"Name of the column used as the prediction target.\"\"\"\n", |
| "    label_name = \"MPG\"\n", |
| "    return label_name\n", |
| "\n", |
| "\n", |
| "@extract_fields({\"lr_model\": LinearRegression, \"scaler\": StandardScaler})\n", |
| "def linear_model(train_dataset: pd.DataFrame, target_column: str) -> dict:\n", |
| "    \"\"\"Fit a feature scaler and a linear regression on the training data.\n", |
| "\n", |
| "    :param train_dataset: training rows, including the target column. Not modified.\n", |
| "    :param target_column: name of the column holding the labels.\n", |
| "    :return: dict with the fitted \"lr_model\" and the fitted \"scaler\".\n", |
| "    \"\"\"\n", |
| "    # Separate labels from features without touching the input frame: `pop`\n", |
| "    # would remove the target from the shared `train_dataset` node output.\n", |
| "    train_labels = train_dataset[target_column]\n", |
| "    train_features = train_dataset.drop(columns=[target_column])\n", |
| "    # Convert boolean columns to integers for the model\n", |
| "    bool_columns = train_features.select_dtypes(include=[bool]).columns\n", |
| "    train_features = train_features.astype({column: int for column in bool_columns})\n", |
| "    # Normalize the features for the model\n", |
| "    scaler = StandardScaler()\n", |
| "    train_features_scaled = scaler.fit_transform(train_features)\n", |
| "\n", |
| "    # Initialize and fit the Linear Regression model\n", |
| "    # (named lr_model so the local does not shadow this function's own name)\n", |
| "    lr_model = LinearRegression()\n", |
| "    lr_model.fit(train_features_scaled, train_labels)\n", |
| "    return {\"lr_model\": lr_model, \"scaler\": scaler}\n", |
| "\n", |
| "\n", |
| "def evaluated_model(lr_model: LinearRegression,\n", |
| "                    scaler: StandardScaler,\n", |
| "                    test_dataset: pd.DataFrame, target_column: str) -> dict:\n", |
| "    \"\"\"Score the fitted model on the test data.\n", |
| "\n", |
| "    :param lr_model: fitted regression model used for prediction.\n", |
| "    :param scaler: fitted scaler applied to the test features before prediction.\n", |
| "    :param test_dataset: test rows, including the target column. Not modified.\n", |
| "    :param target_column: name of the column holding the labels.\n", |
| "    :return: dict mapping \"linear_model\" to the mean absolute error on the test data.\n", |
| "    \"\"\"\n", |
| "    # Separate labels from features without touching the input frame: `pop`\n", |
| "    # would remove the target from the shared `test_dataset` node output.\n", |
| "    test_labels = test_dataset[target_column]\n", |
| "    test_features = test_dataset.drop(columns=[target_column])\n", |
| "    # Convert boolean columns to integers for the model\n", |
| "    bool_columns = test_features.select_dtypes(include=[bool]).columns\n", |
| "    test_features = test_features.astype({column: int for column in bool_columns})\n", |
| "    test_features_scaled = scaler.transform(test_features)\n", |
| "\n", |
| "    # Predict and evaluate the model\n", |
| "    test_pred = lr_model.predict(test_features_scaled)\n", |
| "    mae = mean_absolute_error(test_labels, test_pred)\n", |
| "    return {\"linear_model\": mae}\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-07-20T17:37:04.415885Z", |
| "start_time": "2024-07-20T17:37:04.097149Z" |
| }, |
| "application/vnd.databricks.v1+cell": { |
| "cellMetadata": { |
| "byteLimit": 2048000, |
| "rowLimit": 10000 |
| }, |
| "inputWidgets": {}, |
| "nuid": "17b10355-4e25-4e75-84c5-fd95c0bd3dfb", |
| "showTitle": false, |
| "title": "" |
| } |
| }, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"865pt\" height=\"304pt\"\n", |
| " viewBox=\"0.00 0.00 864.60 303.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 299.8)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-299.8 860.6,-299.8 860.6,4 -4,4\"/>\n", |
| "<g id=\"clust1\" class=\"cluster\">\n", |
| "<title>cluster__legend</title>\n", |
| "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"31.38,-157.8 31.38,-287.8 116.22,-287.8 116.22,-157.8 31.38,-157.8\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-270.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n", |
| "</g>\n", |
| "<!-- target_column -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>target_column</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M392.8,-256.6C392.8,-256.6 301.45,-256.6 301.45,-256.6 295.45,-256.6 289.45,-250.6 289.45,-244.6 289.45,-244.6 289.45,-205 289.45,-205 289.45,-199 295.45,-193 301.45,-193 301.45,-193 392.8,-193 392.8,-193 398.8,-193 404.8,-199 404.8,-205 404.8,-205 404.8,-244.6 404.8,-244.6 404.8,-250.6 398.8,-256.6 392.8,-256.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"300.25\" y=\"-233.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">target_column</text>\n", |
| "<text text-anchor=\"start\" x=\"339.62\" y=\"-205.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n", |
| "</g>\n", |
| "<!-- evaluated_model -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>evaluated_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M844.6,-160.6C844.6,-160.6 736.75,-160.6 736.75,-160.6 730.75,-160.6 724.75,-154.6 724.75,-148.6 724.75,-148.6 724.75,-109 724.75,-109 724.75,-103 730.75,-97 736.75,-97 736.75,-97 844.6,-97 844.6,-97 850.6,-97 856.6,-103 856.6,-109 856.6,-109 856.6,-148.6 856.6,-148.6 856.6,-154.6 850.6,-160.6 844.6,-160.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"735.55\" y=\"-137.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">evaluated_model</text>\n", |
| "<text text-anchor=\"start\" x=\"780.18\" y=\"-109.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- target_column->evaluated_model -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>target_column->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M405.29,-233.76C476.21,-242.38 599.51,-249.12 695.75,-210.8 718.39,-201.78 739.28,-184.84 755.58,-168.68\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"757.68,-171.55 762.15,-161.94 752.67,-166.66 757.68,-171.55\"/>\n", |
| "</g>\n", |
| "<!-- linear_model -->\n", |
| "<g id=\"node9\" class=\"node\">\n", |
| "<title>linear_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M527.4,-183.6C527.4,-183.6 445.8,-183.6 445.8,-183.6 439.8,-183.6 433.8,-177.6 433.8,-171.6 433.8,-171.6 433.8,-132 433.8,-132 433.8,-126 439.8,-120 445.8,-120 445.8,-120 527.4,-120 527.4,-120 533.4,-120 539.4,-126 539.4,-132 539.4,-132 539.4,-171.6 539.4,-171.6 539.4,-177.6 533.4,-183.6 527.4,-183.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"444.6\" y=\"-160.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">linear_model</text>\n", |
| "<text text-anchor=\"start\" x=\"476.1\" y=\"-132.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- target_column->linear_model -->\n", |
| "<g id=\"edge12\" class=\"edge\">\n", |
| "<title>target_column->linear_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M405.23,-194.48C411.28,-191.26 417.45,-187.99 423.56,-184.75\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"425.2,-187.84 432.39,-180.06 421.91,-181.65 425.2,-187.84\"/>\n", |
| "</g>\n", |
| "<!-- lr_model -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>lr_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M683.75,-201.6C683.75,-201.6 580.4,-201.6 580.4,-201.6 574.4,-201.6 568.4,-195.6 568.4,-189.6 568.4,-189.6 568.4,-150 568.4,-150 568.4,-144 574.4,-138 580.4,-138 580.4,-138 683.75,-138 683.75,-138 689.75,-138 695.75,-144 695.75,-150 695.75,-150 695.75,-189.6 695.75,-189.6 695.75,-195.6 689.75,-201.6 683.75,-201.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"603.58\" y=\"-178.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">lr_model</text>\n", |
| "<text text-anchor=\"start\" x=\"579.2\" y=\"-150.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">LinearRegression</text>\n", |
| "</g>\n", |
| "<!-- lr_model->evaluated_model -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>lr_model->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M695.87,-153.36C701.57,-151.87 707.39,-150.34 713.19,-148.82\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"714.06,-152.22 722.84,-146.3 712.28,-145.44 714.06,-152.22\"/>\n", |
| "</g>\n", |
| "<!-- mpg_df -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>mpg_df</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M106.72,-147.6C106.72,-147.6 40.87,-147.6 40.87,-147.6 34.87,-147.6 28.87,-141.6 28.87,-135.6 28.87,-135.6 28.87,-96 28.87,-96 28.87,-90 34.87,-84 40.87,-84 40.87,-84 106.72,-84 106.72,-84 112.72,-84 118.72,-90 118.72,-96 118.72,-96 118.72,-135.6 118.72,-135.6 118.72,-141.6 112.72,-147.6 106.72,-147.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"49.05\" y=\"-124.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">mpg_df</text>\n", |
| "<text text-anchor=\"start\" x=\"39.67\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n", |
| "</g>\n", |
| "<!-- data_sets -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>data_sets</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M248.45,-105.6C248.45,-105.6 188.6,-105.6 188.6,-105.6 182.6,-105.6 176.6,-99.6 176.6,-93.6 176.6,-93.6 176.6,-54 176.6,-54 176.6,-48 182.6,-42 188.6,-42 188.6,-42 248.45,-42 248.45,-42 254.45,-42 260.45,-48 260.45,-54 260.45,-54 260.45,-93.6 260.45,-93.6 260.45,-99.6 254.45,-105.6 248.45,-105.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"187.4\" y=\"-82.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">data_sets</text>\n", |
| "<text text-anchor=\"start\" x=\"208.02\" y=\"-54.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- mpg_df->data_sets -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>mpg_df->data_sets</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M119.18,-102.74C133.93,-98.4 150.48,-93.53 165.74,-89.04\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"166.52,-92.46 175.13,-86.28 164.54,-85.74 166.52,-92.46\"/>\n", |
| "</g>\n", |
| "<!-- test_dataset -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>test_dataset</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M385.3,-64.6C385.3,-64.6 308.95,-64.6 308.95,-64.6 302.95,-64.6 296.95,-58.6 296.95,-52.6 296.95,-52.6 296.95,-13 296.95,-13 296.95,-7 302.95,-1 308.95,-1 308.95,-1 385.3,-1 385.3,-1 391.3,-1 397.3,-7 397.3,-13 397.3,-13 397.3,-52.6 397.3,-52.6 397.3,-58.6 391.3,-64.6 385.3,-64.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"307.75\" y=\"-41.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">test_dataset</text>\n", |
| "<text text-anchor=\"start\" x=\"313\" y=\"-13.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n", |
| "</g>\n", |
| "<!-- data_sets->test_dataset -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>data_sets->test_dataset</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M260.63,-60.49C268.75,-57.86 277.43,-55.05 286.03,-52.26\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"286.96,-55.64 295.4,-49.23 284.81,-48.98 286.96,-55.64\"/>\n", |
| "</g>\n", |
| "<!-- train_dataset -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>train_dataset</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M387.93,-165.6C387.93,-165.6 306.32,-165.6 306.32,-165.6 300.32,-165.6 294.32,-159.6 294.32,-153.6 294.32,-153.6 294.32,-114 294.32,-114 294.32,-108 300.32,-102 306.32,-102 306.32,-102 387.93,-102 387.93,-102 393.93,-102 399.93,-108 399.93,-114 399.93,-114 399.93,-153.6 399.93,-153.6 399.93,-159.6 393.93,-165.6 387.93,-165.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"305.12\" y=\"-142.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">train_dataset</text>\n", |
| "<text text-anchor=\"start\" x=\"313\" y=\"-114.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n", |
| "</g>\n", |
| "<!-- data_sets->train_dataset -->\n", |
| "<g id=\"edge10\" class=\"edge\">\n", |
| "<title>data_sets->train_dataset</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M260.63,-93.28C268.06,-96.8 275.96,-100.54 283.83,-104.28\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"282.23,-107.39 292.76,-108.51 285.23,-101.07 282.23,-107.39\"/>\n", |
| "</g>\n", |
| "<!-- scaler -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>scaler</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M677.38,-119.6C677.38,-119.6 586.78,-119.6 586.78,-119.6 580.78,-119.6 574.78,-113.6 574.78,-107.6 574.78,-107.6 574.78,-68 574.78,-68 574.78,-62 580.78,-56 586.78,-56 586.78,-56 677.38,-56 677.38,-56 683.38,-56 689.38,-62 689.38,-68 689.38,-68 689.38,-107.6 689.38,-107.6 689.38,-113.6 683.38,-119.6 677.38,-119.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"612.58\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scaler</text>\n", |
| "<text text-anchor=\"start\" x=\"585.58\" y=\"-68.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">StandardScaler</text>\n", |
| "</g>\n", |
| "<!-- scaler->evaluated_model -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>scaler->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M689.61,-102.6C697.38,-104.64 705.46,-106.75 713.51,-108.86\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"712.41,-112.19 722.97,-111.34 714.19,-105.42 712.41,-112.19\"/>\n", |
| "</g>\n", |
| "<!-- test_dataset->evaluated_model -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>test_dataset->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M397.75,-24.79C467.18,-15.77 596.01,-7.08 695.75,-46.8 718.39,-55.82 739.28,-72.76 755.58,-88.92\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"752.67,-90.94 762.15,-95.66 757.68,-86.05 752.67,-90.94\"/>\n", |
| "</g>\n", |
| "<!-- train_dataset->linear_model -->\n", |
| "<g id=\"edge11\" class=\"edge\">\n", |
| "<title>train_dataset->linear_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M400.1,-140.61C407.37,-141.56 414.91,-142.54 422.36,-143.52\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"421.68,-146.96 432.05,-144.79 422.59,-140.02 421.68,-146.96\"/>\n", |
| "</g>\n", |
| "<!-- linear_model->lr_model -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>linear_model->lr_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M539.8,-158.35C545.37,-159.05 551.11,-159.77 556.86,-160.49\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"556.06,-163.92 566.42,-161.69 556.93,-156.97 556.06,-163.92\"/>\n", |
| "</g>\n", |
| "<!-- linear_model->scaler -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>linear_model->scaler</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M539.8,-128.51C547.69,-125 555.92,-121.32 564.07,-117.69\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"565.22,-121.01 572.93,-113.74 562.37,-114.62 565.22,-121.01\"/>\n", |
| "</g>\n", |
| "<!-- _data_sets_inputs -->\n", |
| "<g id=\"node10\" class=\"node\">\n", |
| "<title>_data_sets_inputs</title>\n", |
| "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"147.6,-65.6 0,-65.6 0,0 147.6,0 147.6,-65.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"43.67\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n", |
| "<text text-anchor=\"start\" x=\"113.17\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "<text text-anchor=\"start\" x=\"14.8\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">train_test_split</text>\n", |
| "<text text-anchor=\"start\" x=\"107.55\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">float</text>\n", |
| "</g>\n", |
| "<!-- _data_sets_inputs->data_sets -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>_data_sets_inputs->data_sets</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M147.83,-53.78C153.77,-55.49 159.69,-57.19 165.43,-58.83\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"164.17,-62.11 174.75,-61.51 166.1,-55.39 164.17,-62.11\"/>\n", |
| "</g>\n", |
| "<!-- input -->\n", |
| "<g id=\"node11\" class=\"node\">\n", |
| "<title>input</title>\n", |
| "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"100.8,-202.1 46.8,-202.1 46.8,-165.5 100.8,-165.5 100.8,-202.1\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-178\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n", |
| "</g>\n", |
| "<!-- function -->\n", |
| "<g id=\"node12\" class=\"node\">\n", |
| "<title>function</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M96.22,-257.1C96.22,-257.1 51.37,-257.1 51.37,-257.1 45.37,-257.1 39.37,-251.1 39.37,-245.1 39.37,-245.1 39.37,-232.5 39.37,-232.5 39.37,-226.5 45.37,-220.5 51.37,-220.5 51.37,-220.5 96.22,-220.5 96.22,-220.5 102.22,-220.5 108.22,-226.5 108.22,-232.5 108.22,-232.5 108.22,-245.1 108.22,-245.1 108.22,-251.1 102.22,-257.1 96.22,-257.1\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-233\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<hamilton.driver.Driver at 0x1516d7100>" |
| ] |
| }, |
| "execution_count": 4, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "dr = driver.Builder().with_modules(pipeline).build()\n", |
| "dr" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-07-20T17:37:06.029923Z", |
| "start_time": "2024-07-20T17:37:05.934165Z" |
| }, |
| "application/vnd.databricks.v1+cell": { |
| "cellMetadata": { |
| "byteLimit": 2048000, |
| "rowLimit": 10000 |
| }, |
| "inputWidgets": {}, |
| "nuid": "923eff9a-ce20-484e-a4c7-acfbebb58e16", |
| "showTitle": false, |
| "title": "" |
| } |
| }, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "{'evaluated_model': {'linear_model': 2.4926580150007007},\n", |
| " 'linear_model': {'lr_model': LinearRegression(), 'scaler': StandardScaler()}}" |
| ] |
| }, |
| "execution_count": 5, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "result = dr.execute([\"evaluated_model\", \"linear_model\"])\n", |
| "result" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-07-20T17:37:36.535977Z", |
| "start_time": "2024-07-20T17:37:36.241822Z" |
| }, |
| "application/vnd.databricks.v1+cell": { |
| "cellMetadata": { |
| "byteLimit": 2048000, |
| "rowLimit": 10000 |
| }, |
| "inputWidgets": {}, |
| "nuid": "399b0164-819e-4c8a-a46d-021742ace28e", |
| "showTitle": false, |
| "title": "" |
| } |
| }, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"607pt\" height=\"566pt\"\n", |
| " viewBox=\"0.00 0.00 607.40 565.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 561.8)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-561.8 603.4,-561.8 603.4,4 -4,4\"/>\n", |
| "<g id=\"clust1\" class=\"cluster\">\n", |
| "<title>cluster__legend</title>\n", |
| "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"30.62,-309.8 30.62,-549.8 116.97,-549.8 116.97,-309.8 30.62,-309.8\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-532.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n", |
| "</g>\n", |
| "<!-- target_column -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>target_column</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M119.47,-299.6C119.47,-299.6 28.12,-299.6 28.12,-299.6 22.12,-299.6 16.12,-293.6 16.12,-287.6 16.12,-287.6 16.12,-248 16.12,-248 16.12,-242 22.12,-236 28.12,-236 28.12,-236 119.47,-236 119.47,-236 125.47,-236 131.47,-242 131.47,-248 131.47,-248 131.47,-287.6 131.47,-287.6 131.47,-293.6 125.47,-299.6 119.47,-299.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"26.92\" y=\"-276.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">target_column</text>\n", |
| "<text text-anchor=\"start\" x=\"66.3\" y=\"-248.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n", |
| "</g>\n", |
| "<!-- evaluated_model -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>evaluated_model</title>\n", |
| "<path fill=\"#ffc857\" stroke=\"black\" d=\"M587.4,-228.6C587.4,-228.6 479.55,-228.6 479.55,-228.6 473.55,-228.6 467.55,-222.6 467.55,-216.6 467.55,-216.6 467.55,-177 467.55,-177 467.55,-171 473.55,-165 479.55,-165 479.55,-165 587.4,-165 587.4,-165 593.4,-165 599.4,-171 599.4,-177 599.4,-177 599.4,-216.6 599.4,-216.6 599.4,-222.6 593.4,-228.6 587.4,-228.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"478.35\" y=\"-205.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">evaluated_model</text>\n", |
| "<text text-anchor=\"start\" x=\"522.98\" y=\"-177.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- target_column->evaluated_model -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>target_column->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M131.81,-282.49C205.32,-298.56 335.92,-317.3 438.55,-278.8 461.54,-270.18 482.6,-253.09 498.92,-236.74\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"501.09,-239.54 505.49,-229.91 496.04,-234.69 501.09,-239.54\"/>\n", |
| "</g>\n", |
| "<!-- linear_model -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>linear_model</title>\n", |
| "<polygon fill=\"#b4d8e4\" stroke=\"black\" points=\"282.2,-269.6 176.6,-269.6 176.6,-206 282.2,-206 282.2,-269.6\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"188.6,-269.6 176.6,-257.6\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"176.6,-218 188.6,-206\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"270.2,-206 282.2,-218\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"282.2,-257.6 270.2,-269.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"187.4\" y=\"-246.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">linear_model</text>\n", |
| "<text text-anchor=\"start\" x=\"218.9\" y=\"-218.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- target_column->linear_model -->\n", |
| "<g id=\"edge10\" class=\"edge\">\n", |
| "<title>target_column->linear_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M131.56,-256.71C142.52,-254.57 154.06,-252.32 165.21,-250.14\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"165.57,-253.64 174.71,-248.29 164.23,-246.77 165.57,-253.64\"/>\n", |
| "</g>\n", |
| "<!-- lr_model -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>lr_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M426.55,-269.6C426.55,-269.6 323.2,-269.6 323.2,-269.6 317.2,-269.6 311.2,-263.6 311.2,-257.6 311.2,-257.6 311.2,-218 311.2,-218 311.2,-212 317.2,-206 323.2,-206 323.2,-206 426.55,-206 426.55,-206 432.55,-206 438.55,-212 438.55,-218 438.55,-218 438.55,-257.6 438.55,-257.6 438.55,-263.6 432.55,-269.6 426.55,-269.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"346.38\" y=\"-246.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">lr_model</text>\n", |
| "<text text-anchor=\"start\" x=\"322\" y=\"-218.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">LinearRegression</text>\n", |
| "</g>\n", |
| "<!-- lr_model->evaluated_model -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>lr_model->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M438.67,-221.36C444.37,-219.87 450.19,-218.34 455.99,-216.82\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"456.86,-220.22 465.64,-214.3 455.08,-213.44 456.86,-220.22\"/>\n", |
| "</g>\n", |
| "<!-- mpg_df -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>mpg_df</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M106.72,-147.6C106.72,-147.6 40.87,-147.6 40.87,-147.6 34.87,-147.6 28.87,-141.6 28.87,-135.6 28.87,-135.6 28.87,-96 28.87,-96 28.87,-90 34.87,-84 40.87,-84 40.87,-84 106.72,-84 106.72,-84 112.72,-84 118.72,-90 118.72,-96 118.72,-96 118.72,-135.6 118.72,-135.6 118.72,-141.6 112.72,-147.6 106.72,-147.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"49.05\" y=\"-124.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">mpg_df</text>\n", |
| "<text text-anchor=\"start\" x=\"39.67\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n", |
| "</g>\n", |
| "<!-- data_sets -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>data_sets</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M259.32,-105.6C259.32,-105.6 199.47,-105.6 199.47,-105.6 193.47,-105.6 187.47,-99.6 187.47,-93.6 187.47,-93.6 187.47,-54 187.47,-54 187.47,-48 193.47,-42 199.47,-42 199.47,-42 259.32,-42 259.32,-42 265.32,-42 271.32,-48 271.32,-54 271.32,-54 271.32,-93.6 271.32,-93.6 271.32,-99.6 265.32,-105.6 259.32,-105.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"198.27\" y=\"-82.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">data_sets</text>\n", |
| "<text text-anchor=\"start\" x=\"218.9\" y=\"-54.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- mpg_df->data_sets -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>mpg_df->data_sets</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M118.82,-103.76C136.67,-98.88 157.42,-93.21 176.04,-88.12\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"176.83,-91.53 185.56,-85.52 174.99,-84.78 176.83,-91.53\"/>\n", |
| "</g>\n", |
| "<!-- test_dataset -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>test_dataset</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M413.05,-105.6C413.05,-105.6 336.7,-105.6 336.7,-105.6 330.7,-105.6 324.7,-99.6 324.7,-93.6 324.7,-93.6 324.7,-54 324.7,-54 324.7,-48 330.7,-42 336.7,-42 336.7,-42 413.05,-42 413.05,-42 419.05,-42 425.05,-48 425.05,-54 425.05,-54 425.05,-93.6 425.05,-93.6 425.05,-99.6 419.05,-105.6 413.05,-105.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"335.5\" y=\"-82.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">test_dataset</text>\n", |
| "<text text-anchor=\"start\" x=\"340.75\" y=\"-54.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n", |
| "</g>\n", |
| "<!-- data_sets->test_dataset -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>data_sets->test_dataset</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M271.51,-73.8C284.52,-73.8 299.17,-73.8 313.24,-73.8\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"312.85,-77.3 322.85,-73.8 312.85,-70.3 312.85,-77.3\"/>\n", |
| "</g>\n", |
| "<!-- scaler -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>scaler</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M420.18,-187.6C420.18,-187.6 329.57,-187.6 329.57,-187.6 323.57,-187.6 317.57,-181.6 317.57,-175.6 317.57,-175.6 317.57,-136 317.57,-136 317.57,-130 323.57,-124 329.57,-124 329.57,-124 420.18,-124 420.18,-124 426.18,-124 432.18,-130 432.18,-136 432.18,-136 432.18,-175.6 432.18,-175.6 432.18,-181.6 426.18,-187.6 420.18,-187.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"355.38\" y=\"-164.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scaler</text>\n", |
| "<text text-anchor=\"start\" x=\"328.38\" y=\"-136.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">StandardScaler</text>\n", |
| "</g>\n", |
| "<!-- scaler->evaluated_model -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>scaler->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M432.41,-170.6C440.18,-172.64 448.26,-174.75 456.31,-176.86\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"455.21,-180.19 465.77,-179.34 456.99,-173.42 455.21,-180.19\"/>\n", |
| "</g>\n", |
| "<!-- test_dataset->evaluated_model -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>test_dataset->evaluated_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M425.25,-105.17C429.81,-108.34 434.3,-111.58 438.55,-114.8 455.87,-127.93 474.05,-143.37 489.76,-157.3\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"487.29,-159.78 497.08,-163.84 491.96,-154.57 487.29,-159.78\"/>\n", |
| "</g>\n", |
| "<!-- linear_model->lr_model -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>linear_model->lr_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M282.6,-237.8C288.03,-237.8 293.61,-237.8 299.22,-237.8\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"299.21,-241.3 309.21,-237.8 299.21,-234.3 299.21,-241.3\"/>\n", |
| "</g>\n", |
| "<!-- linear_model->scaler -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>linear_model->scaler</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M282.6,-207.96C290.78,-203.29 299.33,-198.41 307.76,-193.58\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"309.34,-196.72 316.28,-188.72 305.86,-190.64 309.34,-196.72\"/>\n", |
| "</g>\n", |
| "<!-- _data_sets_inputs -->\n", |
| "<g id=\"node9\" class=\"node\">\n", |
| "<title>_data_sets_inputs</title>\n", |
| "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"147.6,-65.6 0,-65.6 0,0 147.6,0 147.6,-65.6\"/>\n", |
| "<text text-anchor=\"start\" x=\"43.67\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n", |
| "<text text-anchor=\"start\" x=\"113.17\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "<text text-anchor=\"start\" x=\"14.8\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">train_test_split</text>\n", |
| "<text text-anchor=\"start\" x=\"107.55\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">float</text>\n", |
| "</g>\n", |
| "<!-- _data_sets_inputs->data_sets -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>_data_sets_inputs->data_sets</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M148,-52.34C157.45,-54.86 166.98,-57.41 176.01,-59.81\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"175.1,-63.19 185.66,-62.39 176.9,-56.43 175.1,-63.19\"/>\n", |
| "</g>\n", |
| "<!-- input -->\n", |
| "<g id=\"node10\" class=\"node\">\n", |
| "<title>input</title>\n", |
| "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"100.8,-354.1 46.8,-354.1 46.8,-317.5 100.8,-317.5 100.8,-354.1\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-330\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n", |
| "</g>\n", |
| "<!-- function -->\n", |
| "<g id=\"node11\" class=\"node\">\n", |
| "<title>function</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M96.22,-409.1C96.22,-409.1 51.37,-409.1 51.37,-409.1 45.37,-409.1 39.37,-403.1 39.37,-397.1 39.37,-397.1 39.37,-384.5 39.37,-384.5 39.37,-378.5 45.37,-372.5 51.37,-372.5 51.37,-372.5 96.22,-372.5 96.22,-372.5 102.22,-372.5 108.22,-378.5 108.22,-384.5 108.22,-384.5 108.22,-397.1 108.22,-397.1 108.22,-403.1 102.22,-409.1 96.22,-409.1\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-385\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n", |
| "</g>\n", |
| "<!-- output -->\n", |
| "<g id=\"node12\" class=\"node\">\n", |
| "<title>output</title>\n", |
| "<path fill=\"#ffc857\" stroke=\"black\" d=\"M91.35,-464.1C91.35,-464.1 56.25,-464.1 56.25,-464.1 50.25,-464.1 44.25,-458.1 44.25,-452.1 44.25,-452.1 44.25,-439.5 44.25,-439.5 44.25,-433.5 50.25,-427.5 56.25,-427.5 56.25,-427.5 91.35,-427.5 91.35,-427.5 97.35,-427.5 103.35,-433.5 103.35,-439.5 103.35,-439.5 103.35,-452.1 103.35,-452.1 103.35,-458.1 97.35,-464.1 91.35,-464.1\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-440\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">output</text>\n", |
| "</g>\n", |
| "<!-- override -->\n", |
| "<g id=\"node13\" class=\"node\">\n", |
| "<title>override</title>\n", |
| "<polygon fill=\"#b4d8e4\" stroke=\"black\" points=\"108.97,-519.1 38.62,-519.1 38.62,-482.5 108.97,-482.5 108.97,-519.1\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"50.62,-519.1 38.62,-507.1\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"38.62,-494.5 50.62,-482.5\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"96.97,-482.5 108.97,-494.5\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"108.97,-507.1 96.97,-519.1\"/>\n", |
| "<text text-anchor=\"middle\" x=\"73.8\" y=\"-495\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">override</text>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<graphviz.graphs.Digraph at 0x151753040>" |
| ] |
| }, |
| "execution_count": 6, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# Visualize Overrides\n", |
| "dr.visualize_execution([\"evaluated_model\"], \n", |
| " overrides={\"linear_model\": result[\"linear_model\"]})" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-07-20T17:37:37.088527Z", |
| "start_time": "2024-07-20T17:37:37.025195Z" |
| }, |
| "application/vnd.databricks.v1+cell": { |
| "cellMetadata": {}, |
| "inputWidgets": {}, |
| "nuid": "b617038c-2ffa-498c-b6c9-bbd6bb79bb1f", |
| "showTitle": false, |
| "title": "" |
| } |
| }, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "{'evaluated_model': {'linear_model': 2.4926580150007007}}" |
| ] |
| }, |
| "execution_count": 7, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# Execute with overrides\n", |
| "dr.execute([\"evaluated_model\"], overrides={\"linear_model\": result[\"linear_model\"]})" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [] |
| } |
| ], |
| "metadata": { |
| "application/vnd.databricks.v1+notebook": { |
| "dashboards": [], |
| "language": "python", |
| "notebookMetadata": { |
| "mostRecentlyExecutedCommandWithImplicitDF": { |
| "commandId": 2746022128672016, |
| "dataframes": [ |
| "_sqldf" |
| ] |
| }, |
| "pythonIndentUnit": 4 |
| }, |
| "notebookName": "MPG Simple Advanced Target", |
| "widgets": {} |
| }, |
| "kernelspec": { |
| "display_name": "Python 3 (ipykernel)", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.9.13" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 4 |
| } |