|  | { | 
|  | "cells": [ | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": {}, | 
|  | "source": [ | 
|  | "Uncomment and run the cell below if you are in a Google Colab environment. It will:\n", | 
|  | "1. Mount Google Drive. You will be asked to authenticate and grant permissions.\n", | 
|  | "2. Change directory to Google Drive.\n", | 
|  | "3. Make a directory \"hamilton-tutorials\".\n", | 
|  | "4. Change directory to it.\n", | 
|  | "5. Clone this repository to your Google Drive.\n", | 
|  | "6. Move your current directory to the scikit-learn example.\n", | 
|  | "7. Install requirements.\n", | 
|  | "\n", | 
|  | "This means that any modifications will be saved, and you won't lose them if you close your browser." | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 1, | 
|  | "metadata": {}, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "## 1. Mount Google Drive\n", | 
|  | "# from google.colab import drive\n", | 
|  | "# drive.mount('/content/drive')\n", | 
|  | "## 2. Change directory to Google Drive.\n", | 
|  | "# %cd /content/drive/MyDrive\n", | 
|  | "## 3. Make a directory \"hamilton-tutorials\" (-p: no error if it already exists)\n", | 
|  | "# !mkdir -p hamilton-tutorials\n", | 
|  | "## 4. Change directory to it.\n", | 
|  | "# %cd hamilton-tutorials\n", | 
|  | "## 5. Clone this repository to your Google Drive\n", | 
|  | "# !git clone https://github.com/DAGWorks-Inc/hamilton/\n", | 
|  | "## 6. Move your current directory to the scikit-learn example (this notebook)\n", | 
|  | "# %cd hamilton/examples/scikit-learn\n", | 
|  | "## 7. Install requirements.\n", | 
|  | "# %pip install -r requirements.txt\n", | 
|  | "# from IPython.display import clear_output\n", | 
|  | "# clear_output()  # optionally clear outputs\n", | 
|  | "# To check your current working directory you can type `!pwd` in a cell and run it." | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 4, | 
|  | "metadata": {}, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "# Load IPython's autoreload extension so that edits to imported modules can be picked up\n", | 
|  | "# without restarting the kernel.\n", | 
|  | "%load_ext autoreload\n", | 
|  | "# Mode 1: reload only the modules explicitly registered with `%aimport`.\n", | 
|  | "# NOTE(review): the Hamilton modules in this notebook are built in-cell with\n", | 
|  | "# `ad_hoc_utils.create_temporary_module`, so after editing them re-run their cells\n", | 
|  | "# (and the driver cell) rather than relying on autoreload.\n", | 
|  | "%autoreload 1" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 9, | 
|  | "metadata": {}, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "from __future__ import annotations\n", | 
|  | "\n", | 
|  | "import importlib\n", | 
|  | "import logging\n", | 
|  | "import sys\n", | 
|  | "from types import ModuleType\n", | 
|  | "from typing import Any, Dict, List\n", | 
|  | "\n", | 
|  | "import numpy as np\n", | 
|  | "import pandas as pd\n", | 
|  | "from sklearn.base import BaseEstimator, TransformerMixin\n", | 
|  | "from sklearn.pipeline import Pipeline\n", | 
|  | "from sklearn.preprocessing import StandardScaler\n", | 
|  | "from sklearn.utils.validation import check_array, check_is_fitted\n", | 
|  | "\n", | 
|  | "from hamilton import base, driver, log_setup, ad_hoc_utils\n", | 
|  | "\n", | 
|  | "logger = logging.getLogger(name=__name__)  # notebook-level logger used by the checks below\n", | 
|  | "log_setup.setup_logging()" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 5, | 
|  | "metadata": {}, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "# We'll place the spend calculations into a new module\n", | 
|  | "\n", | 
|  | "def avg_3wk_spend(spend: pd.Series) -> pd.Series:\n", | 
|  | "    \"\"\"Mean spend over a trailing 3-row window (assumes one row per week); the first 2 rows are NaN.\"\"\"\n", | 
|  | "    three_week_window = spend.rolling(window=3)\n", | 
|  | "    return three_week_window.mean()\n", | 
|  | "\n", | 
|  | "\n", | 
|  | "def spend_per_signup(spend: pd.Series, signups: pd.Series) -> pd.Series:\n", | 
|  | "    \"\"\"Cost per signup: spend divided element-wise by signups.\"\"\"\n", | 
|  | "    cost_per_signup = spend.div(signups)\n", | 
|  | "    return cost_per_signup\n", | 
|  | "\n", | 
|  | "\n", | 
|  | "# Bundle the functions above into a temporary module that can be handed to the Hamilton driver.\n", | 
|  | "spend_calculations = ad_hoc_utils.create_temporary_module(\n", | 
|  | "    avg_3wk_spend,\n", | 
|  | "    spend_per_signup,\n", | 
|  | "    module_name=\"spend_calculations\",\n", | 
|  | ")\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 6, | 
|  | "metadata": {}, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "# We'll place the spend statistics calculations into a new module\n", | 
|  | "\n", | 
|  | "def spend_mean(spend: pd.Series) -> float:\n", | 
|  | "    \"\"\"Example of a scalar node: the mean of the entire spend column.\"\"\"\n", | 
|  | "    column_mean = spend.mean()\n", | 
|  | "    return column_mean\n", | 
|  | "\n", | 
|  | "\n", | 
|  | "def spend_zero_mean(spend: pd.Series, spend_mean: float) -> pd.Series:\n", | 
|  | "    \"\"\"Example of a node consuming a scalar: centers spend around zero using `spend_mean`.\"\"\"\n", | 
|  | "    centered = spend.sub(spend_mean)\n", | 
|  | "    return centered\n", | 
|  | "\n", | 
|  | "\n", | 
|  | "def spend_std_dev(spend: pd.Series) -> float:\n", | 
|  | "    \"\"\"Standard deviation of the spend column (pandas `Series.std` default).\"\"\"\n", | 
|  | "    deviation = spend.std()\n", | 
|  | "    return deviation\n", | 
|  | "\n", | 
|  | "\n", | 
|  | "def spend_zero_mean_unit_variance(spend_zero_mean: pd.Series, spend_std_dev: float) -> pd.Series:\n", | 
|  | "    \"\"\"Divide the zero-mean spend by its standard deviation so it also has unit variance.\"\"\"\n", | 
|  | "    standardized = spend_zero_mean.div(spend_std_dev)\n", | 
|  | "    return standardized\n", | 
|  | "\n", | 
|  | "\n", | 
|  | "# Bundle the statistics functions into a second temporary module for the Hamilton driver.\n", | 
|  | "spend_statistics = ad_hoc_utils.create_temporary_module(\n", | 
|  | "    spend_mean,\n", | 
|  | "    spend_zero_mean,\n", | 
|  | "    spend_std_dev,\n", | 
|  | "    spend_zero_mean_unit_variance,\n", | 
|  | "    module_name=\"spend_statistics\",\n", | 
|  | ")\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": {}, | 
|  | "source": [ | 
|  | "In this example we show a custom scikit-learn `Transformer` class. It is meant to comply with the [scikit-learn specifications for transformers](https://scikit-learn.org/stable/developers/develop.html) and to be used as part of broader scikit-learn pipelines. Scikit-learn estimators and pipelines allow for stateful objects, which are notably helpful when applying transformations to train-test splits (fit on the training split, then transform both splits). Also, all pipeline, estimator, and transformer objects should be picklable, enabling reproducible pipelines." | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 8, | 
|  | "metadata": {}, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "class HamiltonTransformer(BaseEstimator, TransformerMixin):\n", | 
|  | "    \"\"\"Scikit-learn compatible Transformer implementing Hamilton behavior.\n", | 
|  | "\n", | 
|  | "    `fit` builds a Hamilton `driver.Driver` from `config`, `modules` and `adapter`;\n", | 
|  | "    `transform` executes it to compute `final_vars` from the input columns.\n", | 
|  | "    \"\"\"\n", | 
|  | "\n", | 
|  | "    def __init__(\n", | 
|  | "        self,\n", | 
|  | "        config: dict = None,\n", | 
|  | "        modules: List[ModuleType] = None,\n", | 
|  | "        adapter: base.HamiltonGraphAdapter = None,\n", | 
|  | "        final_vars: List[str] = None,\n", | 
|  | "    ):\n", | 
|  | "        \"\"\"\n", | 
|  | "        :param config: Hamilton config passed to driver.Driver (None -> {}).\n", | 
|  | "        :param modules: modules holding the Hamilton functions that define the DAG (None -> []).\n", | 
|  | "        :param adapter: optional graph adapter passed to driver.Driver.\n", | 
|  | "        :param final_vars: names of the nodes computed by .transform() (None -> []).\n", | 
|  | "        \"\"\"\n", | 
|  | "        self.config = {} if config is None else config\n", | 
|  | "        self.modules = [] if modules is None else modules\n", | 
|  | "        self.adapter = adapter\n", | 
|  | "        self.final_vars = [] if final_vars is None else final_vars\n", | 
|  | "\n", | 
|  | "    def get_params(self, deep: bool = True) -> dict:\n", | 
|  | "        \"\"\"Get parameters for this estimator.\n", | 
|  | "\n", | 
|  | "        :param deep: accepted for scikit-learn API compatibility (`clone()` and the estimator\n", | 
|  | "            repr call `get_params(deep=False)`). Parameters are never expanded recursively,\n", | 
|  | "            so the result is the same for both values.\n", | 
|  | "        :return: Current parameters of the estimator\n", | 
|  | "        \"\"\"\n", | 
|  | "        return {\n", | 
|  | "            \"config\": self.config,\n", | 
|  | "            \"modules\": self.modules,\n", | 
|  | "            \"adapter\": self.adapter,\n", | 
|  | "            \"final_vars\": self.final_vars,\n", | 
|  | "        }\n", | 
|  | "\n", | 
|  | "    def set_params(self, **parameters) -> HamiltonTransformer:\n", | 
|  | "        \"\"\"Set parameters for this estimator.\n", | 
|  | "\n", | 
|  | "        :param parameters: Estimator parameters.\n", | 
|  | "        :return: self\n", | 
|  | "        \"\"\"\n", | 
|  | "        for parameter, value in parameters.items():\n", | 
|  | "            setattr(self, parameter, value)\n", | 
|  | "        return self\n", | 
|  | "\n", | 
|  | "    def get_feature_names_out(self, input_features=None) -> np.ndarray:\n", | 
|  | "        \"\"\"Get the names of the columns produced by the last call to .transform().\n", | 
|  | "\n", | 
|  | "        Implementing this scikit-learn method also makes `set_output` available.\n", | 
|  | "\n", | 
|  | "        :param input_features: ignored; accepted for scikit-learn API compatibility.\n", | 
|  | "        :return: array of output column names.\n", | 
|  | "        :raises NotFittedError: if .transform() has not been called yet.\n", | 
|  | "        \"\"\"\n", | 
|  | "        check_is_fitted(self, \"feature_names_out_\")\n", | 
|  | "        return np.asarray(self.feature_names_out_, dtype=object)\n", | 
|  | "\n", | 
|  | "    def get_features_names_out(self):\n", | 
|  | "        \"\"\"Misspelled alias of .get_feature_names_out(), kept for backward compatibility.\n", | 
|  | "\n", | 
|  | "        :return: list of output column names, or None when there are none.\n", | 
|  | "        \"\"\"\n", | 
|  | "        return self.get_feature_names_out().tolist() or None\n", | 
|  | "\n", | 
|  | "    def _get_tags(self) -> dict:\n", | 
|  | "        \"\"\"Get scikit-learn compatible estimator tags for introspection.\n", | 
|  | "\n", | 
|  | "        Starts from the tags computed by the parent class (when it provides `_get_tags`)\n", | 
|  | "        and overrides only `requires_fit` and `requires_y`. Returning only those two keys\n", | 
|  | "        can make scikit-learn lookups of any other tag fail.\n", | 
|  | "\n", | 
|  | "        ref: https://scikit-learn.org/stable/developers/develop.html#estimator-tags\n", | 
|  | "        \"\"\"\n", | 
|  | "        parent_get_tags = getattr(super(), \"_get_tags\", None)\n", | 
|  | "        tags = dict(parent_get_tags()) if callable(parent_get_tags) else {}\n", | 
|  | "        tags.update({\"requires_fit\": True, \"requires_y\": False})\n", | 
|  | "        return tags\n", | 
|  | "\n", | 
|  | "    def fit(self, X, y=None, overrides: Dict[str, Any] = None) -> HamiltonTransformer:\n", | 
|  | "        \"\"\"Instantiate the Hamilton driver.Driver object.\n", | 
|  | "\n", | 
|  | "        :param X: Input 2D array-like; it is validated and its number of columns is recorded.\n", | 
|  | "        :param y: ignored; present for scikit-learn API compatibility.\n", | 
|  | "        :param overrides: dictionary of override values passed to driver.execute() during .transform()\n", | 
|  | "        :return: self\n", | 
|  | "        \"\"\"\n", | 
|  | "        X_checked = check_array(X, accept_sparse=True)\n", | 
|  | "        self.overrides_ = {} if overrides is None else overrides\n", | 
|  | "\n", | 
|  | "        self.driver_ = driver.Driver(self.config, *self.modules, adapter=self.adapter)\n", | 
|  | "        # Read the shape from the validated array so list-like inputs work as well.\n", | 
|  | "        self.n_features_in_: int = X_checked.shape[1]\n", | 
|  | "\n", | 
|  | "        return self\n", | 
|  | "\n", | 
|  | "    def transform(self, X, y=None, **kwargs) -> pd.DataFrame:\n", | 
|  | "        \"\"\"Execute the Hamilton Driver on X and return the computed `final_vars`.\n", | 
|  | "\n", | 
|  | "        Requires a prior call to .fit() to instantiate the Hamilton Driver.\n", | 
|  | "\n", | 
|  | "        :param X: Input DataFrame; each column is passed to the driver as an input of the same\n", | 
|  | "            name. Non-DataFrame values are passed to driver.execute() unchanged.\n", | 
|  | "        :param y: ignored; present for scikit-learn API compatibility.\n", | 
|  | "        :return: Hamilton Driver output (assumed to be a DataFrame, as with the default adapter)\n", | 
|  | "        :raises ValueError: if X is a DataFrame whose column count differs from `fit`.\n", | 
|  | "        \"\"\"\n", | 
|  | "        check_is_fitted(self, \"n_features_in_\")\n", | 
|  | "\n", | 
|  | "        if isinstance(X, pd.DataFrame):\n", | 
|  | "            check_array(X, accept_sparse=True)\n", | 
|  | "            if X.shape[1] != self.n_features_in_:\n", | 
|  | "                raise ValueError(\"Shape of input is different from what was seen in `fit`\")\n", | 
|  | "\n", | 
|  | "            X = X.to_dict(orient=\"series\")\n", | 
|  | "\n", | 
|  | "        X_t = self.driver_.execute(final_vars=self.final_vars, overrides=self.overrides_, inputs=X)\n", | 
|  | "        self.n_features_out_ = len(self.final_vars)\n", | 
|  | "        self.feature_names_out_ = X_t.columns.to_list()\n", | 
|  | "        return X_t\n", | 
|  | "\n", | 
|  | "    def fit_transform(self, X, y=None, **fit_params) -> pd.DataFrame:\n", | 
|  | "        \"\"\"Fit the transformer on X, then transform X.\n", | 
|  | "\n", | 
|  | "        :param X: Input DataFrame\n", | 
|  | "        :param y: ignored; forwarded to .fit() for scikit-learn API compatibility.\n", | 
|  | "        :param fit_params: forwarded to .fit() (e.g. `overrides`).\n", | 
|  | "        :return: Hamilton Driver output DataFrame\n", | 
|  | "        \"\"\"\n", | 
|  | "        return self.fit(X, y, **fit_params).transform(X)\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 10, | 
|  | "metadata": {}, | 
|  | "outputs": [ | 
|  | { | 
|  | "name": "stdout", | 
|  | "output_type": "stream", | 
|  | "text": [ | 
|  | "WARNING:hamilton.telemetry:Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/dagworks-inc/hamilton#usage-analytics--data-privacy for details.\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "# Set up the driver, input and output columns\n", | 
|  | "initial_df = pd.DataFrame(\n", | 
|  | "    {\n", | 
|  | "        \"signups\": [1, 10, 50, 100, 200, 400],\n", | 
|  | "        \"spend\": [10, 10, 20, 40, 40, 50],\n", | 
|  | "    }\n", | 
|  | ")\n", | 
|  | "\n", | 
|  | "# Nodes requested from the DAG: the two inputs plus the derived spend features.\n", | 
|  | "output_columns = [\n", | 
|  | "    \"spend\",\n", | 
|  | "    \"signups\",\n", | 
|  | "    \"avg_3wk_spend\",\n", | 
|  | "    \"spend_per_signup\",\n", | 
|  | "    \"spend_zero_mean_unit_variance\",\n", | 
|  | "]\n", | 
|  | "\n", | 
|  | "dr = driver.Driver({}, spend_calculations, spend_statistics)\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 14, | 
|  | "metadata": {}, | 
|  | "outputs": [ | 
|  | { | 
|  | "data": { | 
|  | "image/svg+xml": [ | 
|  | "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | 
|  | "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | 
|  | " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | 
|  | "<!-- Generated by graphviz version 2.43.0 (0)\n", | 
|  | " -->\n", | 
|  | "<!-- Title: %3 Pages: 1 -->\n", | 
|  | "<svg width=\"610pt\" height=\"260pt\"\n", | 
|  | " viewBox=\"0.00 0.00 610.43 260.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | 
|  | "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n", | 
|  | "<title>%3</title>\n", | 
|  | "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 606.43,-256 606.43,4 -4,4\"/>\n", | 
|  | "<!-- spend_per_signup -->\n", | 
|  | "<g id=\"node1\" class=\"node\">\n", | 
|  | "<title>spend_per_signup</title>\n", | 
|  | "<ellipse fill=\"none\" stroke=\"black\" cx=\"94.24\" cy=\"-162\" rx=\"94.48\" ry=\"18\"/>\n", | 
|  | "<text text-anchor=\"middle\" x=\"94.24\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_per_signup</text>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend_mean -->\n", | 
|  | "<g id=\"node2\" class=\"node\">\n", | 
|  | "<title>spend_mean</title>\n", | 
|  | "<ellipse fill=\"none\" stroke=\"black\" cx=\"274.24\" cy=\"-162\" rx=\"68.49\" ry=\"18\"/>\n", | 
|  | "<text text-anchor=\"middle\" x=\"274.24\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_mean</text>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend_zero_mean -->\n", | 
|  | "<g id=\"node5\" class=\"node\">\n", | 
|  | "<title>spend_zero_mean</title>\n", | 
|  | "<ellipse fill=\"none\" stroke=\"black\" cx=\"274.24\" cy=\"-90\" rx=\"92.88\" ry=\"18\"/>\n", | 
|  | "<text text-anchor=\"middle\" x=\"274.24\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_zero_mean</text>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend_mean->spend_zero_mean -->\n", | 
|  | "<g id=\"edge6\" class=\"edge\">\n", | 
|  | "<title>spend_mean->spend_zero_mean</title>\n", | 
|  | "<path fill=\"none\" stroke=\"black\" d=\"M274.24,-143.7C274.24,-135.98 274.24,-126.71 274.24,-118.11\"/>\n", | 
|  | "<polygon fill=\"black\" stroke=\"black\" points=\"277.74,-118.1 274.24,-108.1 270.74,-118.1 277.74,-118.1\"/>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend -->\n", | 
|  | "<g id=\"node3\" class=\"node\">\n", | 
|  | "<title>spend</title>\n", | 
|  | "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"370.24\" cy=\"-234\" rx=\"69.59\" ry=\"18\"/>\n", | 
|  | "<text text-anchor=\"middle\" x=\"370.24\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">Input: spend</text>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend->spend_per_signup -->\n", | 
|  | "<g id=\"edge1\" class=\"edge\">\n", | 
|  | "<title>spend->spend_per_signup</title>\n", | 
|  | "<path fill=\"none\" stroke=\"black\" d=\"M322.48,-220.89C277.06,-209.37 208.42,-191.96 158.28,-179.24\"/>\n", | 
|  | "<polygon fill=\"black\" stroke=\"black\" points=\"159.06,-175.83 148.51,-176.76 157.34,-182.61 159.06,-175.83\"/>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend->spend_mean -->\n", | 
|  | "<g id=\"edge3\" class=\"edge\">\n", | 
|  | "<title>spend->spend_mean</title>\n", | 
|  | "<path fill=\"none\" stroke=\"black\" d=\"M347.98,-216.76C335.13,-207.4 318.8,-195.49 304.77,-185.26\"/>\n", | 
|  | "<polygon fill=\"black\" stroke=\"black\" points=\"306.53,-182.21 296.39,-179.15 302.4,-187.87 306.53,-182.21\"/>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend_std_dev -->\n", | 
|  | "<g id=\"node4\" class=\"node\">\n", | 
|  | "<title>spend_std_dev</title>\n", | 
|  | "<ellipse fill=\"none\" stroke=\"black\" cx=\"464.24\" cy=\"-90\" rx=\"78.79\" ry=\"18\"/>\n", | 
|  | "<text text-anchor=\"middle\" x=\"464.24\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_std_dev</text>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend->spend_std_dev -->\n", | 
|  | "<g id=\"edge4\" class=\"edge\">\n", | 
|  | "<title>spend->spend_std_dev</title>\n", | 
|  | "<path fill=\"none\" stroke=\"black\" d=\"M380.76,-216.14C392.05,-198.14 410.61,-168.86 427.24,-144 433.32,-134.91 440.14,-125.05 446.25,-116.33\"/>\n", | 
|  | "<polygon fill=\"black\" stroke=\"black\" points=\"449.18,-118.26 452.07,-108.07 443.45,-114.23 449.18,-118.26\"/>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend->spend_zero_mean -->\n", | 
|  | "<g id=\"edge5\" class=\"edge\">\n", | 
|  | "<title>spend->spend_zero_mean</title>\n", | 
|  | "<path fill=\"none\" stroke=\"black\" d=\"M370.25,-215.82C369.48,-196.75 365.81,-165.81 351.24,-144 342.67,-131.17 329.85,-120.56 317.11,-112.3\"/>\n", | 
|  | "<polygon fill=\"black\" stroke=\"black\" points=\"318.72,-109.18 308.36,-106.95 315.07,-115.15 318.72,-109.18\"/>\n", | 
|  | "</g>\n", | 
|  | "<!-- avg_3wk_spend -->\n", | 
|  | "<g id=\"node6\" class=\"node\">\n", | 
|  | "<title>avg_3wk_spend</title>\n", | 
|  | "<ellipse fill=\"none\" stroke=\"black\" cx=\"519.24\" cy=\"-162\" rx=\"83.39\" ry=\"18\"/>\n", | 
|  | "<text text-anchor=\"middle\" x=\"519.24\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">avg_3wk_spend</text>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend->avg_3wk_spend -->\n", | 
|  | "<g id=\"edge7\" class=\"edge\">\n", | 
|  | "<title>spend->avg_3wk_spend</title>\n", | 
|  | "<path fill=\"none\" stroke=\"black\" d=\"M402.21,-217.98C424.02,-207.73 453.16,-194.05 476.9,-182.89\"/>\n", | 
|  | "<polygon fill=\"black\" stroke=\"black\" points=\"478.48,-186.02 486.04,-178.6 475.5,-179.68 478.48,-186.02\"/>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend_zero_mean_unit_variance -->\n", | 
|  | "<g id=\"node8\" class=\"node\">\n", | 
|  | "<title>spend_zero_mean_unit_variance</title>\n", | 
|  | "<ellipse fill=\"none\" stroke=\"black\" cx=\"369.24\" cy=\"-18\" rx=\"159.47\" ry=\"18\"/>\n", | 
|  | "<text text-anchor=\"middle\" x=\"369.24\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_zero_mean_unit_variance</text>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend_std_dev->spend_zero_mean_unit_variance -->\n", | 
|  | "<g id=\"edge9\" class=\"edge\">\n", | 
|  | "<title>spend_std_dev->spend_zero_mean_unit_variance</title>\n", | 
|  | "<path fill=\"none\" stroke=\"black\" d=\"M441.73,-72.41C429.37,-63.3 413.85,-51.87 400.35,-41.92\"/>\n", | 
|  | "<polygon fill=\"black\" stroke=\"black\" points=\"402.38,-39.07 392.25,-35.96 398.23,-44.71 402.38,-39.07\"/>\n", | 
|  | "</g>\n", | 
|  | "<!-- spend_zero_mean->spend_zero_mean_unit_variance -->\n", | 
|  | "<g id=\"edge8\" class=\"edge\">\n", | 
|  | "<title>spend_zero_mean->spend_zero_mean_unit_variance</title>\n", | 
|  | "<path fill=\"none\" stroke=\"black\" d=\"M296.76,-72.41C309.11,-63.3 324.63,-51.87 338.13,-41.92\"/>\n", | 
|  | "<polygon fill=\"black\" stroke=\"black\" points=\"340.26,-44.71 346.23,-35.96 336.1,-39.07 340.26,-44.71\"/>\n", | 
|  | "</g>\n", | 
|  | "<!-- signups -->\n", | 
|  | "<g id=\"node7\" class=\"node\">\n", | 
|  | "<title>signups</title>\n", | 
|  | "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"94.24\" cy=\"-234\" rx=\"77.19\" ry=\"18\"/>\n", | 
|  | "<text text-anchor=\"middle\" x=\"94.24\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">Input: signups</text>\n", | 
|  | "</g>\n", | 
|  | "<!-- signups->spend_per_signup -->\n", | 
|  | "<g id=\"edge2\" class=\"edge\">\n", | 
|  | "<title>signups->spend_per_signup</title>\n", | 
|  | "<path fill=\"none\" stroke=\"black\" d=\"M94.24,-215.7C94.24,-207.98 94.24,-198.71 94.24,-190.11\"/>\n", | 
|  | "<polygon fill=\"black\" stroke=\"black\" points=\"97.74,-190.1 94.24,-180.1 90.74,-190.1 97.74,-190.1\"/>\n", | 
|  | "</g>\n", | 
|  | "</g>\n", | 
|  | "</svg>\n" | 
|  | ], | 
|  | "text/plain": [ | 
|  | "<graphviz.graphs.Digraph at 0x7fc5dc81e550>" | 
|  | ] | 
|  | }, | 
|  | "execution_count": 14, | 
|  | "metadata": {}, | 
|  | "output_type": "execute_result" | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "# Visualize every function the driver knows about, i.e. the whole DAG.\n", | 
|  | "# This needs the optional visualization dependencies: pip install \"sf-hamilton[visualization]\"\n", | 
|  | "dr.display_all_functions(\n", | 
|  | "    output_file_path=None,  # no file path given; the graph is shown as the cell output\n", | 
|  | ")" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": {}, | 
|  | "source": [ | 
|  | "Check 1: output of `vanilla driver` == `custom transformer`" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 16, | 
|  | "metadata": {}, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "# Run the same DAG with the vanilla driver and with the scikit-learn transformer.\n", | 
|  | "hamilton_df = dr.execute(final_vars=output_columns, inputs=initial_df.to_dict(orient=\"series\"))\n", | 
|  | "\n", | 
|  | "custom_transformer = HamiltonTransformer(\n", | 
|  | "    config={}, modules=[spend_calculations, spend_statistics], final_vars=output_columns\n", | 
|  | ")\n", | 
|  | "sklearn_df = custom_transformer.fit_transform(initial_df)\n", | 
|  | "\n", | 
|  | "try:\n", | 
|  | "    pd.testing.assert_frame_equal(sklearn_df, hamilton_df)\n", | 
|  | "except AssertionError:\n", | 
|  | "    # assert_frame_equal reports a mismatch with AssertionError (not ValueError)\n", | 
|  | "    logger.warning(\"Check 1 failed; `sklearn_df` and `hamilton_df` are unequal\")\n", | 
|  | "    raise\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": {}, | 
|  | "source": [ | 
|  | "Check 2: output of `vanilla driver >> transformation` == `scikit-learn pipeline`" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 17, | 
|  | "metadata": {}, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "scaler = StandardScaler()\n", | 
|  | "\n", | 
|  | "hamilton_df = dr.execute(final_vars=output_columns, inputs=initial_df.to_dict(orient=\"series\"))\n", | 
|  | "hamilton_then_sklearn = scaler.fit_transform(hamilton_df)\n", | 
|  | "\n", | 
|  | "pipeline1 = Pipeline(steps=[(\"hamilton\", custom_transformer), (\"scaler\", scaler)])\n", | 
|  | "pipe_custom_then_sklearn = pipeline1.fit_transform(initial_df)\n", | 
|  | "\n", | 
|  | "try:\n", | 
|  | "    assert isinstance(hamilton_then_sklearn, np.ndarray)\n", | 
|  | "    assert isinstance(pipe_custom_then_sklearn, np.ndarray)\n", | 
|  | "\n", | 
|  | "    # Float outputs: compare with a tolerance; NaNs (from the rolling window) compare equal.\n", | 
|  | "    np.testing.assert_allclose(pipe_custom_then_sklearn, hamilton_then_sklearn)\n", | 
|  | "except AssertionError:\n", | 
|  | "    # `assert` and numpy's testing helpers signal failure with AssertionError (not ValueError)\n", | 
|  | "    logger.warning(\n", | 
|  | "        \"Check 2 failed; `pipe_custom_then_sklearn` and `hamilton_then_sklearn` are unequal\"\n", | 
|  | "    )\n", | 
|  | "    raise\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": {}, | 
|  | "source": [ | 
|  | "Check 3: output of `transformation >> vanilla driver` == `scikit-learn pipeline`\n", | 
|  | "\n", | 
|  | "The custom transformer requires a DataFrame as input, so we use the `.set_output()` API introduced in scikit-learn v1.2 to make `StandardScaler` return one (see [SLEP018](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html))." | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 18, | 
|  | "metadata": {}, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "scaler = StandardScaler().set_output(transform=\"pandas\")\n", | 
|  | "\n", | 
|  | "scaled_df = scaler.fit_transform(initial_df)\n", | 
|  | "sklearn_then_hamilton = dr.execute(\n", | 
|  | "    final_vars=output_columns, inputs=scaled_df.to_dict(orient=\"series\")\n", | 
|  | ")\n", | 
|  | "\n", | 
|  | "pipeline2 = Pipeline(steps=[(\"scaler\", scaler), (\"hamilton\", custom_transformer)])\n", | 
|  | "pipe_sklearn_then_custom = pipeline2.fit_transform(initial_df)\n", | 
|  | "\n", | 
|  | "try:\n", | 
|  | "    assert isinstance(sklearn_then_hamilton, pd.DataFrame)\n", | 
|  | "    assert isinstance(pipe_sklearn_then_custom, pd.DataFrame)\n", | 
|  | "\n", | 
|  | "    pd.testing.assert_frame_equal(pipe_sklearn_then_custom, sklearn_then_hamilton)\n", | 
|  | "except AssertionError:\n", | 
|  | "    # `assert` and assert_frame_equal signal failure with AssertionError (not ValueError)\n", | 
|  | "    logger.warning(\n", | 
|  | "        \"Check 3 failed; `pipe_sklearn_then_custom` and `sklearn_then_hamilton` are unequal\"\n", | 
|  | "    )\n", | 
|  | "    raise\n", | 
|  | "\n", | 
|  | "logger.info(\"All checks passed. `HamiltonTransformer` behaves properly\")\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": {}, | 
|  | "source": [ | 
|  | "Before using Hamilton with scikit-learn any further, please be aware of its possible limitations, listed [here](https://github.com/DAGWorks-Inc/hamilton/tree/main/examples/scikit-learn#limitations-and-todos)." | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "metadata": { | 
|  | "kernelspec": { | 
|  | "display_name": "Python 3", | 
|  | "language": "python", | 
|  | "name": "python3" | 
|  | }, | 
|  | "language_info": { | 
|  | "codemirror_mode": { | 
|  | "name": "ipython", | 
|  | "version": 3 | 
|  | }, | 
|  | "file_extension": ".py", | 
|  | "mimetype": "text/x-python", | 
|  | "name": "python", | 
|  | "nbconvert_exporter": "python", | 
|  | "pygments_lexer": "ipython3", | 
|  | "version": "3.11.1" | 
|  | }, | 
|  | "orig_nbformat": 4 | 
|  | }, | 
|  | "nbformat": 4, | 
|  | "nbformat_minor": 2 | 
|  | } |