| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Uncomment and run the cell below if you are in a Google Colab environment. It will:\n", |
| "1. Mount Google Drive. You will be asked to authenticate and grant permissions.\n", |
| "2. Change directory to Google Drive.\n", |
| "3. Make a directory \"hamilton-tutorials\".\n", |
| "4. Change directory to it.\n", |
| "5. Clone this repository to your Google Drive.\n", |
| "6. Move your current directory to the hello_world example.\n", |
| "7. Install requirements.\n", |
| "\n", |
| "This means that any modifications will be saved, and you won't lose them if you close your browser." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "## 1. Mount google drive\n", |
| "# from google.colab import drive\n", |
| "# drive.mount('/content/drive')\n", |
| "## 2. Change directory to google drive.\n", |
| "# %cd /content/drive/MyDrive\n", |
| "## 3. Make a directory \"hamilton-tutorials\"\n", |
| "# !mkdir hamilton-tutorials\n", |
| "## 4. Change directory to it.\n", |
| "# %cd hamilton-tutorials\n", |
| "## 5. Clone this repository to your google drive\n", |
| "# !git clone https://github.com/DAGWorks-Inc/hamilton/\n", |
| "## 6. Move your current directory to the hello_world example\n", |
| "# %cd hamilton/examples/hello_world\n", |
| "## 7. Install requirements.\n", |
| "# %pip install -r requirements.txt\n", |
| "# from IPython.display import clear_output\n", |
| "# clear_output() # optionally clear outputs\n", |
| "# To check your current working directory you can type `!pwd` in a cell and run it." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# We use the autoreload extension that comes with IPython to automatically reload modules\n", |
| "# when the code in them changes.\n", |
| "\n", |
| "# import the jupyter extension\n", |
| "%load_ext autoreload\n", |
| "# set it to only reload modules imported with %aimport\n", |
| "%autoreload 1" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from __future__ import annotations\n", |
| "\n", |
| "import logging\n", |
| "from types import ModuleType\n", |
| "from typing import Any, Dict, List\n", |
| "\n", |
| "import numpy as np\n", |
| "import pandas as pd\n", |
| "from sklearn.base import BaseEstimator, TransformerMixin\n", |
| "from sklearn.pipeline import Pipeline\n", |
| "from sklearn.preprocessing import StandardScaler\n", |
| "from sklearn.utils.validation import check_array, check_is_fitted\n", |
| "\n", |
| "from hamilton import base, driver, log_setup, ad_hoc_utils\n", |
| "\n", |
| "logger = logging.getLogger(__name__)\n", |
| "log_setup.setup_logging()" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# We'll place the spend calculations into a new module\n", |
| "\n", |
| "def avg_3wk_spend(spend: pd.Series) -> pd.Series:\n", |
| " \"\"\"Rolling 3 week average spend.\"\"\"\n", |
| " return spend.rolling(3).mean()\n", |
| "\n", |
| "\n", |
| "def spend_per_signup(spend: pd.Series, signups: pd.Series) -> pd.Series:\n", |
| " \"\"\"The cost per signup in relation to spend.\"\"\"\n", |
| " return spend / signups\n", |
| "\n", |
| "\n", |
| "spend_calculations = ad_hoc_utils.create_temporary_module(\n", |
| " avg_3wk_spend, spend_per_signup, module_name=\"spend_calculations\"\n", |
| ")\n" |
| ] |
| }, |
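| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Optionally, a quick sanity check (this cell is an illustrative addition, not part of the original example): these are plain Python functions, so we can call them directly, outside of Hamilton." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Optional sanity check: call the functions directly on small sample series.\n", |
| "sample_spend = pd.Series([10.0, 10.0, 20.0])\n", |
| "sample_signups = pd.Series([1.0, 10.0, 50.0])\n", |
| "assert spend_per_signup(sample_spend, sample_signups).iloc[1] == 1.0\n", |
| "# With exactly 3 rows, the last rolling 3-week average equals the overall mean.\n", |
| "assert avg_3wk_spend(sample_spend).iloc[-1] == sample_spend.mean()" |
| ] |
| }, |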
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# We'll place the spend statistics calculations into a new module\n", |
| "\n", |
| "def spend_mean(spend: pd.Series) -> float:\n", |
| " \"\"\"Shows function creating a scalar. In this case it computes the mean of the entire column.\"\"\"\n", |
| " return spend.mean()\n", |
| "\n", |
| "\n", |
| "def spend_zero_mean(spend: pd.Series, spend_mean: float) -> pd.Series:\n", |
| " \"\"\"Shows function that takes a scalar. In this case to zero mean spend.\"\"\"\n", |
| " return spend - spend_mean\n", |
| "\n", |
| "\n", |
| "def spend_std_dev(spend: pd.Series) -> float:\n", |
| " \"\"\"Function that computes the standard deviation of the spend column.\"\"\"\n", |
| " return spend.std()\n", |
| "\n", |
| "\n", |
| "def spend_zero_mean_unit_variance(spend_zero_mean: pd.Series, spend_std_dev: float) -> pd.Series:\n", |
| " \"\"\"Function showing one way to make spend have zero mean and unit variance.\"\"\"\n", |
| " return spend_zero_mean / spend_std_dev\n", |
| "\n", |
| "\n", |
| "spend_statistics = ad_hoc_utils.create_temporary_module(\n", |
| " spend_mean, spend_zero_mean, spend_std_dev, spend_zero_mean_unit_variance, module_name=\"spend_statistics\"\n", |
| ")\n" |
| ] |
| }, |
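| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Another optional sanity check (an illustrative addition, not part of the original example): composing the statistics functions by hand should produce a series with (approximately) zero mean and unit variance." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Optional sanity check: compose the statistics functions directly.\n", |
| "sample = pd.Series([10.0, 10.0, 20.0, 40.0, 40.0, 50.0])\n", |
| "zmuv = spend_zero_mean_unit_variance(\n", |
| "    spend_zero_mean(sample, spend_mean(sample)), spend_std_dev(sample)\n", |
| ")\n", |
| "assert abs(zmuv.mean()) < 1e-9\n", |
| "assert abs(zmuv.std() - 1.0) < 1e-9" |
| ] |
| }, |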
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "In this example we show a custom scikit-learn `Transformer` class. It should comply with the [scikit-learn transformer specifications](https://scikit-learn.org/stable/developers/develop.html) and is meant to be used as part of broader scikit-learn pipelines. Scikit-learn estimators and pipelines allow for stateful objects, which are notably helpful when applying transformations across train-test splits. Also, all pipeline, estimator, and transformer objects should be picklable, enabling reproducible pipelines." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "class HamiltonTransformer(BaseEstimator, TransformerMixin):\n", |
| " \"\"\"Scikit-learn compatible Transformer implementing Hamilton behavior\"\"\"\n", |
| "\n", |
| " def __init__(\n", |
| " self,\n", |
| " config: dict = None,\n", |
| " modules: List[ModuleType] = None,\n", |
| " adapter: base.HamiltonGraphAdapter = None,\n", |
| " final_vars: List[str] = None,\n", |
| " ):\n", |
| " self.config = {} if config is None else config\n", |
| " self.modules = [] if modules is None else modules\n", |
| " self.adapter = adapter\n", |
| " self.final_vars = [] if final_vars is None else final_vars\n", |
| "\n", |
| "    def get_params(self, deep: bool = True) -> dict:\n", |
| "        \"\"\"Get parameters for this estimator.\n", |
| "\n", |
| "        :param deep: Present for scikit-learn API compatibility (e.g. `clone`); nested objects are returned as-is.\n", |
| "        :return: Current parameters of the estimator\n", |
| "        \"\"\"\n", |
| " return {\n", |
| " \"config\": self.config,\n", |
| " \"modules\": self.modules,\n", |
| " \"adapter\": self.adapter,\n", |
| " \"final_vars\": self.final_vars,\n", |
| " }\n", |
| "\n", |
| " def set_params(self, **parameters) -> HamiltonTransformer:\n", |
| "        \"\"\"Set parameters for this estimator.\n", |
| "\n", |
| " :param parameters: Estimator parameters.\n", |
| " :return: self\n", |
| " \"\"\"\n", |
| " for parameter, value in parameters.items():\n", |
| " setattr(self, parameter, value)\n", |
| " return self\n", |
| "\n", |
| "    def get_feature_names_out(self, input_features=None):\n", |
| "        \"\"\"Get output feature names; requires a prior call to .transform().\"\"\"\n", |
| "        check_is_fitted(self, \"feature_names_out_\")\n", |
| "        return self.feature_names_out_\n", |
| "\n", |
| " def _get_tags(self) -> dict:\n", |
| " \"\"\"Get scikit-learn compatible estimator tags for introspection\n", |
| "\n", |
| " ref: https://scikit-learn.org/stable/developers/develop.html#estimator-tags\n", |
| " \"\"\"\n", |
| " return {\"requires_fit\": True, \"requires_y\": False}\n", |
| "\n", |
| " def fit(self, X, y=None, overrides: Dict[str, Any] = None) -> HamiltonTransformer:\n", |
| " \"\"\"Instantiate Hamilton driver.Driver object\n", |
| "\n", |
| "        :param X: Input 2D array\n", |
| "        :param y: Ignored; present for scikit-learn API compatibility\n", |
| "        :param overrides: dictionary of override values passed to driver.execute() during .transform()\n", |
| " :return: self\n", |
| " \"\"\"\n", |
| "\n", |
| " check_array(X, accept_sparse=True)\n", |
| " self.overrides_ = {} if overrides is None else overrides\n", |
| "\n", |
| " self.driver_ = driver.Driver(self.config, *self.modules, adapter=self.adapter)\n", |
| " self.n_features_in_: int = X.shape[1]\n", |
| "\n", |
| " return self\n", |
| "\n", |
| " def transform(self, X, y=None, **kwargs) -> pd.DataFrame:\n", |
| "        \"\"\"Execute the Hamilton driver on X and return a transformed version of X.\n", |
| "        Requires a prior call to .fit() to instantiate the Hamilton driver.\n", |
| "\n", |
| " :param X: Input 2D array\n", |
| " :return: Hamilton Driver output 2D array\n", |
| " \"\"\"\n", |
| "\n", |
| " check_is_fitted(self, \"n_features_in_\")\n", |
| "\n", |
| " if isinstance(X, pd.DataFrame):\n", |
| " check_array(X, accept_sparse=True)\n", |
| " if X.shape[1] != self.n_features_in_:\n", |
| " raise ValueError(\"Shape of input is different from what was seen in `fit`\")\n", |
| "\n", |
| " X = X.to_dict(orient=\"series\")\n", |
| "\n", |
| " X_t = self.driver_.execute(final_vars=self.final_vars, overrides=self.overrides_, inputs=X)\n", |
| " # self.driver_.visualize_execution(final_vars=self.final_vars,\n", |
| " # output_file_path=\"./scikit_transformer\",\n", |
| " # render_kwargs={\"format\": \"png\"},\n", |
| " # inputs=X)\n", |
| " self.n_features_out_ = len(self.final_vars)\n", |
| " self.feature_names_out_ = X_t.columns.to_list()\n", |
| " return X_t\n", |
| "\n", |
| " def fit_transform(self, X, y=None, **fit_params) -> pd.DataFrame:\n", |
| " \"\"\"Execute Hamilton Driver on X with optional parameters fit_params and returns a\n", |
| " transformed version of X.\n", |
| "\n", |
| " :param X: Input 2D array\n", |
| " :return: Hamilton Driver output 2D array\n", |
| " \"\"\"\n", |
| " return self.fit(X, **fit_params).transform(X)\n" |
| ] |
| }, |
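| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "An optional sanity check (an illustrative addition, not part of the original example): because `transform` guards with `check_is_fitted`, calling it before `fit` should raise scikit-learn's `NotFittedError`." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from sklearn.exceptions import NotFittedError\n", |
| "\n", |
| "# Calling transform before fit should raise NotFittedError.\n", |
| "unfitted = HamiltonTransformer(final_vars=[\"avg_3wk_spend\"])\n", |
| "try:\n", |
| "    unfitted.transform(pd.DataFrame({\"signups\": [1, 2], \"spend\": [10, 20]}))\n", |
| "    raise RuntimeError(\"expected NotFittedError\")\n", |
| "except NotFittedError:\n", |
| "    print(\"transform before fit correctly raises NotFittedError\")" |
| ] |
| }, |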
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "WARNING:hamilton.telemetry:Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/dagworks-inc/hamilton#usage-analytics--data-privacy for details.\n" |
| ] |
| } |
| ], |
| "source": [ |
| "# Set up the driver, input and output columns\n", |
| "initial_df = pd.DataFrame(\n", |
| " {\"signups\": [1, 10, 50, 100, 200, 400], \"spend\": [10, 10, 20, 40, 40, 50]}\n", |
| ")\n", |
| "\n", |
| "output_columns = [\n", |
| " \"spend\",\n", |
| " \"signups\",\n", |
| " \"avg_3wk_spend\",\n", |
| " \"spend_per_signup\",\n", |
| " \"spend_zero_mean_unit_variance\",\n", |
| "]\n", |
| "\n", |
| "\n", |
| "dr = driver.Driver({}, spend_calculations, spend_statistics)\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 14, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 2.43.0 (0)\n", |
| " -->\n", |
| "<!-- Title: %3 Pages: 1 -->\n", |
| "<svg width=\"610pt\" height=\"260pt\"\n", |
| " viewBox=\"0.00 0.00 610.43 260.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n", |
| "<title>%3</title>\n", |
| "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 606.43,-256 606.43,4 -4,4\"/>\n", |
| "<!-- spend_per_signup -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>spend_per_signup</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"94.24\" cy=\"-162\" rx=\"94.48\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"94.24\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_per_signup</text>\n", |
| "</g>\n", |
| "<!-- spend_mean -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>spend_mean</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"274.24\" cy=\"-162\" rx=\"68.49\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"274.24\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_mean</text>\n", |
| "</g>\n", |
| "<!-- spend_zero_mean -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>spend_zero_mean</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"274.24\" cy=\"-90\" rx=\"92.88\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"274.24\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_zero_mean</text>\n", |
| "</g>\n", |
| "<!-- spend_mean->spend_zero_mean -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>spend_mean->spend_zero_mean</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M274.24,-143.7C274.24,-135.98 274.24,-126.71 274.24,-118.11\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"277.74,-118.1 274.24,-108.1 270.74,-118.1 277.74,-118.1\"/>\n", |
| "</g>\n", |
| "<!-- spend -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>spend</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"370.24\" cy=\"-234\" rx=\"69.59\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"370.24\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">Input: spend</text>\n", |
| "</g>\n", |
| "<!-- spend->spend_per_signup -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>spend->spend_per_signup</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M322.48,-220.89C277.06,-209.37 208.42,-191.96 158.28,-179.24\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"159.06,-175.83 148.51,-176.76 157.34,-182.61 159.06,-175.83\"/>\n", |
| "</g>\n", |
| "<!-- spend->spend_mean -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>spend->spend_mean</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M347.98,-216.76C335.13,-207.4 318.8,-195.49 304.77,-185.26\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"306.53,-182.21 296.39,-179.15 302.4,-187.87 306.53,-182.21\"/>\n", |
| "</g>\n", |
| "<!-- spend_std_dev -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>spend_std_dev</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"464.24\" cy=\"-90\" rx=\"78.79\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"464.24\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_std_dev</text>\n", |
| "</g>\n", |
| "<!-- spend->spend_std_dev -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>spend->spend_std_dev</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M380.76,-216.14C392.05,-198.14 410.61,-168.86 427.24,-144 433.32,-134.91 440.14,-125.05 446.25,-116.33\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"449.18,-118.26 452.07,-108.07 443.45,-114.23 449.18,-118.26\"/>\n", |
| "</g>\n", |
| "<!-- spend->spend_zero_mean -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>spend->spend_zero_mean</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M370.25,-215.82C369.48,-196.75 365.81,-165.81 351.24,-144 342.67,-131.17 329.85,-120.56 317.11,-112.3\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"318.72,-109.18 308.36,-106.95 315.07,-115.15 318.72,-109.18\"/>\n", |
| "</g>\n", |
| "<!-- avg_3wk_spend -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>avg_3wk_spend</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"519.24\" cy=\"-162\" rx=\"83.39\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"519.24\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">avg_3wk_spend</text>\n", |
| "</g>\n", |
| "<!-- spend->avg_3wk_spend -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>spend->avg_3wk_spend</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M402.21,-217.98C424.02,-207.73 453.16,-194.05 476.9,-182.89\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"478.48,-186.02 486.04,-178.6 475.5,-179.68 478.48,-186.02\"/>\n", |
| "</g>\n", |
| "<!-- spend_zero_mean_unit_variance -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>spend_zero_mean_unit_variance</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" cx=\"369.24\" cy=\"-18\" rx=\"159.47\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"369.24\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_zero_mean_unit_variance</text>\n", |
| "</g>\n", |
| "<!-- spend_std_dev->spend_zero_mean_unit_variance -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>spend_std_dev->spend_zero_mean_unit_variance</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M441.73,-72.41C429.37,-63.3 413.85,-51.87 400.35,-41.92\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"402.38,-39.07 392.25,-35.96 398.23,-44.71 402.38,-39.07\"/>\n", |
| "</g>\n", |
| "<!-- spend_zero_mean->spend_zero_mean_unit_variance -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>spend_zero_mean->spend_zero_mean_unit_variance</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M296.76,-72.41C309.11,-63.3 324.63,-51.87 338.13,-41.92\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"340.26,-44.71 346.23,-35.96 336.1,-39.07 340.26,-44.71\"/>\n", |
| "</g>\n", |
| "<!-- signups -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>signups</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"94.24\" cy=\"-234\" rx=\"77.19\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"94.24\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">Input: signups</text>\n", |
| "</g>\n", |
| "<!-- signups->spend_per_signup -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>signups->spend_per_signup</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M94.24,-215.7C94.24,-207.98 94.24,-198.71 94.24,-190.11\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"97.74,-190.1 94.24,-180.1 90.74,-190.1 97.74,-190.1\"/>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<graphviz.graphs.Digraph at 0x7fc5dc81e550>" |
| ] |
| }, |
| "execution_count": 14, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# Visualize the DAG\n", |
| "# Run `pip install \"sf-hamilton[visualization]\"` first for this cell to work.\n", |
| "\n", |
| "# visualize all possible functions\n", |
| "dr.display_all_functions(output_file_path=None)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Check 1: output of `vanilla driver` == `custom transformer`" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 16, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "hamilton_df = dr.execute(final_vars=output_columns, inputs=initial_df.to_dict(orient=\"series\"))\n", |
| "\n", |
| "custom_transformer = HamiltonTransformer(\n", |
| " config={}, modules=[spend_calculations, spend_statistics], final_vars=output_columns)\n", |
| "sklearn_df = custom_transformer.fit_transform(initial_df)\n", |
| "\n", |
| "try:\n", |
| " pd.testing.assert_frame_equal(sklearn_df, hamilton_df)\n", |
| "\n", |
| "except AssertionError as e:\n", |
| " logger.warning(\"Check 1 failed; `sklearn_df` and `hamilton_df` are unequal\")\n", |
| " raise e\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Check 2: output of `vanilla driver >> transformation` == `scikit-learn pipeline`" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 17, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "scaler = StandardScaler()\n", |
| "\n", |
| "hamilton_df = dr.execute(final_vars=output_columns, inputs=initial_df.to_dict(orient=\"series\"))\n", |
| "hamilton_then_sklearn = scaler.fit_transform(hamilton_df)\n", |
| "\n", |
| "pipeline1 = Pipeline(steps=[(\"hamilton\", custom_transformer), (\"scaler\", scaler)])\n", |
| "pipe_custom_then_sklearn = pipeline1.fit_transform(initial_df)\n", |
| "try:\n", |
| " assert isinstance(hamilton_then_sklearn, np.ndarray)\n", |
| " assert isinstance(pipe_custom_then_sklearn, np.ndarray)\n", |
| "\n", |
| " np.testing.assert_equal(pipe_custom_then_sklearn, hamilton_then_sklearn)\n", |
| "\n", |
| "except AssertionError as e:\n", |
| " logger.warning(\n", |
| " \"Check 2 failed; `pipe_custom_then_sklearn` and `hamilton_then_sklearn` are unequal\"\n", |
| " )\n", |
| " raise e\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Check 3: output of `transformation >> vanilla driver` == `scikit-learn pipeline`\n", |
| "The custom transformer requires a DataFrame as input, so we leverage the `.set_output` API introduced in scikit-learn v1.2.\n", |
| "ref: https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 18, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "scaler = StandardScaler().set_output(transform=\"pandas\")\n", |
| "\n", |
| "scaled_df = scaler.fit_transform(initial_df)\n", |
| "sklearn_then_hamilton = dr.execute(\n", |
| " final_vars=output_columns, inputs=scaled_df.to_dict(orient=\"series\")\n", |
| ")\n", |
| "\n", |
| "pipeline2 = Pipeline(steps=[(\"scaler\", scaler), (\"hamilton\", custom_transformer)])\n", |
| "pipe_sklearn_then_custom = pipeline2.fit_transform(initial_df)\n", |
| "\n", |
| "try:\n", |
| " assert isinstance(sklearn_then_hamilton, pd.DataFrame)\n", |
| " assert isinstance(pipe_sklearn_then_custom, pd.DataFrame)\n", |
| "\n", |
| " pd.testing.assert_frame_equal(pipe_sklearn_then_custom, sklearn_then_hamilton)\n", |
| "except AssertionError as e:\n", |
| " logger.warning(\n", |
| " \"Check 3 failed; `pipe_sklearn_then_custom` and `sklearn_then_hamilton` are unequal\"\n", |
| " )\n", |
| " raise e\n", |
| "\n", |
| "logger.info(\"All checks passed. `HamiltonTransformer` behaves properly\")\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Before going further with Hamilton and scikit-learn, please be aware of the known limitations listed [here](https://github.com/DAGWorks-Inc/hamilton/tree/main/examples/scikit-learn#limitations-and-todos)." |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.11.1" |
| }, |
| "orig_nbformat": 4 |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 2 |
| } |