| { |
| "cells": [ |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from hamilton import driver, dataflows" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "# 1. Install from PyPI\n", |
| "Install the package from PyPI with `!pip install sf-hamilton-contrib`.\n", |
| "Then see run.py for details. " |
| ], |
| "metadata": { |
| "collapsed": false |
| } |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 2. Dynamic installation\n", |
| "Install the package dynamically without installing/upgrading sf-hamilton-contrib." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "xgboost_optuna = dataflows.import_module(\"xgboost_optuna\", \"zilto\")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 2.43.0 (0)\n", |
| " -->\n", |
| "<!-- Title: %3 Pages: 1 -->\n", |
| "<svg width=\"2759pt\" height=\"639pt\"\n", |
| " viewBox=\"0.00 0.00 2759.00 639.41\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 635.41)\">\n", |
| "<title>%3</title>\n", |
| "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-635.41 2755,-635.41 2755,4 -4,4\"/>\n", |
| "<g id=\"clust1\" class=\"cluster\">\n", |
| "<title>cluster__legend</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"8,-398.41 8,-584.41 104,-584.41 104,-398.41 8,-398.41\"/>\n", |
| "<text text-anchor=\"middle\" x=\"56\" y=\"-569.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n", |
| "</g>\n", |
| "<!-- y_test_pred -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>y_test_pred</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M2459,-388.41C2459,-388.41 2370,-388.41 2370,-388.41 2364,-388.41 2358,-382.41 2358,-376.41 2358,-376.41 2358,-336.41 2358,-336.41 2358,-330.41 2364,-324.41 2370,-324.41 2370,-324.41 2459,-324.41 2459,-324.41 2465,-324.41 2471,-330.41 2471,-336.41 2471,-336.41 2471,-376.41 2471,-376.41 2471,-382.41 2465,-388.41 2459,-388.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"2369\" y=\"-367.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">y_test_pred</text>\n", |
| "<text text-anchor=\"start\" x=\"2387\" y=\"-339.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n", |
| "</g>\n", |
| "<!-- test_score -->\n", |
| "<g id=\"node16\" class=\"node\">\n", |
| "<title>test_score</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M2739,-140.41C2739,-140.41 2659,-140.41 2659,-140.41 2653,-140.41 2647,-134.41 2647,-128.41 2647,-128.41 2647,-88.41 2647,-88.41 2647,-82.41 2653,-76.41 2659,-76.41 2659,-76.41 2739,-76.41 2739,-76.41 2745,-76.41 2751,-82.41 2751,-88.41 2751,-88.41 2751,-128.41 2751,-128.41 2751,-134.41 2745,-140.41 2739,-140.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"2658\" y=\"-119.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">test_score</text>\n", |
| "<text text-anchor=\"start\" x=\"2685\" y=\"-91.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Any</text>\n", |
| "</g>\n", |
| "<!-- y_test_pred->test_score -->\n", |
| "<g id=\"edge24\" class=\"edge\">\n", |
| "<title>y_test_pred->test_score</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M2452.14,-324.24C2503.9,-278.8 2598.29,-195.94 2653.85,-147.17\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"2656.22,-149.74 2661.42,-140.52 2651.6,-144.48 2656.22,-149.74\"/>\n", |
| "</g>\n", |
| "<!-- base_model -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>base_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M873,-603.41C873,-603.41 782,-603.41 782,-603.41 776,-603.41 770,-597.41 770,-591.41 770,-591.41 770,-551.41 770,-551.41 770,-545.41 776,-539.41 782,-539.41 782,-539.41 873,-539.41 873,-539.41 879,-539.41 885,-545.41 885,-551.41 885,-551.41 885,-591.41 885,-591.41 885,-597.41 879,-603.41 873,-603.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"781\" y=\"-582.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">base_model</text>\n", |
| "<text text-anchor=\"start\" x=\"800\" y=\"-554.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Callable</text>\n", |
| "</g>\n", |
| "<!-- best_model -->\n", |
| "<g id=\"node12\" class=\"node\">\n", |
| "<title>best_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M2022,-466.41C2022,-466.41 1934,-466.41 1934,-466.41 1928,-466.41 1922,-460.41 1922,-454.41 1922,-454.41 1922,-414.41 1922,-414.41 1922,-408.41 1928,-402.41 1934,-402.41 1934,-402.41 2022,-402.41 2022,-402.41 2028,-402.41 2034,-408.41 2034,-414.41 2034,-414.41 2034,-454.41 2034,-454.41 2034,-460.41 2028,-466.41 2022,-466.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1933\" y=\"-445.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">best_model</text>\n", |
| "<text text-anchor=\"start\" x=\"1942\" y=\"-417.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">XGBModel</text>\n", |
| "</g>\n", |
| "<!-- base_model->best_model -->\n", |
| "<g id=\"edge11\" class=\"edge\">\n", |
| "<title>base_model->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M885.16,-571.21C957.34,-571.01 1085.94,-570.79 1196,-571.41\"/>\n", |
| "</g>\n", |
| "<!-- hyperparameter_search -->\n", |
| "<g id=\"node14\" class=\"node\">\n", |
| "<title>hyperparameter_search</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1290,-357.41C1290,-357.41 1104,-357.41 1104,-357.41 1098,-357.41 1092,-351.41 1092,-345.41 1092,-345.41 1092,-305.41 1092,-305.41 1092,-299.41 1098,-293.41 1104,-293.41 1104,-293.41 1290,-293.41 1290,-293.41 1296,-293.41 1302,-299.41 1302,-305.41 1302,-305.41 1302,-345.41 1302,-345.41 1302,-351.41 1296,-357.41 1290,-357.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1103\" y=\"-336.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">hyperparameter_search</text>\n", |
| "<text text-anchor=\"start\" x=\"1184\" y=\"-308.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- base_model->hyperparameter_search -->\n", |
| "<g id=\"edge17\" class=\"edge\">\n", |
| "<title>base_model->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M885.23,-573.79C935.83,-573.22 1009.91,-565.51 1063,-530.41 1122.59,-491.01 1161.34,-414.74 1180.87,-367.15\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1184.17,-368.33 1184.64,-357.75 1177.67,-365.73 1184.17,-368.33\"/>\n", |
| "</g>\n", |
| "<!-- study -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>study</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M848.5,-439.41C848.5,-439.41 806.5,-439.41 806.5,-439.41 800.5,-439.41 794.5,-433.41 794.5,-427.41 794.5,-427.41 794.5,-387.41 794.5,-387.41 794.5,-381.41 800.5,-375.41 806.5,-375.41 806.5,-375.41 848.5,-375.41 848.5,-375.41 854.5,-375.41 860.5,-381.41 860.5,-387.41 860.5,-387.41 860.5,-427.41 860.5,-427.41 860.5,-433.41 854.5,-439.41 848.5,-439.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"805.5\" y=\"-418.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">study</text>\n", |
| "<text text-anchor=\"start\" x=\"807\" y=\"-390.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Study</text>\n", |
| "</g>\n", |
| "<!-- study->hyperparameter_search -->\n", |
| "<g id=\"edge21\" class=\"edge\">\n", |
| "<title>study->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M860.62,-403.1C906.04,-396.71 991.47,-383.61 1063,-366.41 1070.85,-364.52 1078.92,-362.42 1087.01,-360.21\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1088.19,-363.52 1096.88,-357.46 1086.31,-356.77 1088.19,-363.52\"/>\n", |
| "</g>\n", |
| "<!-- task -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>task</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"77,-388.41 29,-388.41 29,-338.41 83,-338.41 83,-382.41 77,-388.41\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"77,-388.41 77,-382.41 \"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"83,-382.41 77,-382.41 \"/>\n", |
| "<text text-anchor=\"start\" x=\"39\" y=\"-374.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">task</text>\n", |
| "<text text-anchor=\"start\" x=\"42\" y=\"-346.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Any</text>\n", |
| "</g>\n", |
| "<!-- scorer -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>scorer</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M80,-320.41C80,-320.41 32,-320.41 32,-320.41 26,-320.41 20,-314.41 20,-308.41 20,-308.41 20,-268.41 20,-268.41 20,-262.41 26,-256.41 32,-256.41 32,-256.41 80,-256.41 80,-256.41 86,-256.41 92,-262.41 92,-268.41 92,-268.41 92,-308.41 92,-308.41 92,-314.41 86,-320.41 80,-320.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"31\" y=\"-299.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scorer</text>\n", |
| "<text text-anchor=\"start\" x=\"43\" y=\"-271.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- higher_is_better -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>higher_is_better</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M406.5,-547.41C406.5,-547.41 281.5,-547.41 281.5,-547.41 275.5,-547.41 269.5,-541.41 269.5,-535.41 269.5,-535.41 269.5,-495.41 269.5,-495.41 269.5,-489.41 275.5,-483.41 281.5,-483.41 281.5,-483.41 406.5,-483.41 406.5,-483.41 412.5,-483.41 418.5,-489.41 418.5,-495.41 418.5,-495.41 418.5,-535.41 418.5,-535.41 418.5,-541.41 412.5,-547.41 406.5,-547.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"280.5\" y=\"-526.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">higher_is_better</text>\n", |
| "<text text-anchor=\"start\" x=\"329\" y=\"-498.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">bool</text>\n", |
| "</g>\n", |
| "<!-- scorer->higher_is_better -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>scorer->higher_is_better</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M92,-313.67C96.66,-318.42 100.89,-323.7 104,-329.41 135.14,-386.59 79.93,-427.4 125,-474.41 159.01,-509.88 213.27,-520.17 259.15,-521.58\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"259.35,-525.08 269.41,-521.75 259.47,-518.08 259.35,-525.08\"/>\n", |
| "</g>\n", |
| "<!-- scoring_func -->\n", |
| "<g id=\"node13\" class=\"node\">\n", |
| "<title>scoring_func</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M875.5,-87.41C875.5,-87.41 779.5,-87.41 779.5,-87.41 773.5,-87.41 767.5,-81.41 767.5,-75.41 767.5,-75.41 767.5,-35.41 767.5,-35.41 767.5,-29.41 773.5,-23.41 779.5,-23.41 779.5,-23.41 875.5,-23.41 875.5,-23.41 881.5,-23.41 887.5,-29.41 887.5,-35.41 887.5,-35.41 887.5,-75.41 887.5,-75.41 887.5,-81.41 881.5,-87.41 875.5,-87.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"778.5\" y=\"-66.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scoring_func</text>\n", |
| "<text text-anchor=\"start\" x=\"798.5\" y=\"-38.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">function</text>\n", |
| "</g>\n", |
| "<!-- scorer->scoring_func -->\n", |
| "<g id=\"edge15\" class=\"edge\">\n", |
| "<title>scorer->scoring_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M58.01,-256.15C61.32,-211.87 74.45,-133.4 125,-95.41 314.56,47.06 620.71,-3.77 757.09,-36.49\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"756.46,-39.93 767.01,-38.9 758.12,-33.13 756.46,-39.93\"/>\n", |
| "</g>\n", |
| "<!-- model_config -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>model_config</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M878,-521.41C878,-521.41 777,-521.41 777,-521.41 771,-521.41 765,-515.41 765,-509.41 765,-509.41 765,-469.41 765,-469.41 765,-463.41 771,-457.41 777,-457.41 777,-457.41 878,-457.41 878,-457.41 884,-457.41 890,-463.41 890,-469.41 890,-469.41 890,-509.41 890,-509.41 890,-515.41 884,-521.41 878,-521.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"776\" y=\"-500.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">model_config</text>\n", |
| "<text text-anchor=\"start\" x=\"814.5\" y=\"-472.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- model_config->best_model -->\n", |
| "<g id=\"edge12\" class=\"edge\">\n", |
| "<title>model_config->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M890.16,-498.02C937.71,-505.18 1004.94,-516.44 1063,-530.41 1123.14,-544.87 1134.24,-567.99 1196,-571.41\"/>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1198,-571.41C1466.19,-586.24 1779.34,-498.42 1912.17,-456.16\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1913.34,-459.46 1921.8,-453.08 1911.21,-452.79 1913.34,-459.46\"/>\n", |
| "</g>\n", |
| "<!-- model_config->hyperparameter_search -->\n", |
| "<g id=\"edge18\" class=\"edge\">\n", |
| "<title>model_config->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M890.22,-488.1C939.31,-484.88 1008.55,-475.36 1063,-448.41 1102.93,-428.64 1139.4,-393.07 1164.01,-365.26\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1166.71,-367.48 1170.63,-357.64 1161.43,-362.89 1166.71,-367.48\"/>\n", |
| "</g>\n", |
| "<!-- best_hyperparameters -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>best_hyperparameters</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1626,-471.41C1626,-471.41 1450,-471.41 1450,-471.41 1444,-471.41 1438,-465.41 1438,-459.41 1438,-459.41 1438,-419.41 1438,-419.41 1438,-413.41 1444,-407.41 1450,-407.41 1450,-407.41 1626,-407.41 1626,-407.41 1632,-407.41 1638,-413.41 1638,-419.41 1638,-419.41 1638,-459.41 1638,-459.41 1638,-465.41 1632,-471.41 1626,-471.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1449\" y=\"-450.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">best_hyperparameters</text>\n", |
| "<text text-anchor=\"start\" x=\"1525\" y=\"-422.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- best_hyperparameters->best_model -->\n", |
| "<g id=\"edge13\" class=\"edge\">\n", |
| "<title>best_hyperparameters->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1638.12,-438.28C1721.82,-437.32 1839.29,-435.98 1911.66,-435.15\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1912.04,-438.65 1922,-435.03 1911.96,-431.65 1912.04,-438.65\"/>\n", |
| "</g>\n", |
| "<!-- higher_is_better->study -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>higher_is_better->study</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M418.57,-511.51C462.38,-506.83 518.02,-496.6 563,-474.41 578.52,-466.75 576.43,-455.97 592,-448.41 654.35,-418.14 735.64,-409.93 784.3,-407.87\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"784.59,-411.36 794.46,-407.51 784.35,-404.37 784.59,-411.36\"/>\n", |
| "</g>\n", |
| "<!-- cross_validation_folds -->\n", |
| "<g id=\"node9\" class=\"node\">\n", |
| "<title>cross_validation_folds</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M912.5,-357.41C912.5,-357.41 742.5,-357.41 742.5,-357.41 736.5,-357.41 730.5,-351.41 730.5,-345.41 730.5,-345.41 730.5,-305.41 730.5,-305.41 730.5,-299.41 736.5,-293.41 742.5,-293.41 742.5,-293.41 912.5,-293.41 912.5,-293.41 918.5,-293.41 924.5,-299.41 924.5,-305.41 924.5,-305.41 924.5,-345.41 924.5,-345.41 924.5,-351.41 918.5,-357.41 912.5,-357.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"741.5\" y=\"-336.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">cross_validation_folds</text>\n", |
| "<text text-anchor=\"start\" x=\"793\" y=\"-308.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Sequence</text>\n", |
| "</g>\n", |
| "<!-- cross_validation_folds->hyperparameter_search -->\n", |
| "<g id=\"edge16\" class=\"edge\">\n", |
| "<title>cross_validation_folds->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M924.58,-325.41C972.63,-325.41 1031.12,-325.41 1081.49,-325.41\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1081.78,-328.91 1091.78,-325.41 1081.78,-321.91 1081.78,-328.91\"/>\n", |
| "</g>\n", |
| "<!-- optuna_distributions -->\n", |
| "<g id=\"node10\" class=\"node\">\n", |
| "<title>optuna_distributions</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M907.5,-275.41C907.5,-275.41 747.5,-275.41 747.5,-275.41 741.5,-275.41 735.5,-269.41 735.5,-263.41 735.5,-263.41 735.5,-223.41 735.5,-223.41 735.5,-217.41 741.5,-211.41 747.5,-211.41 747.5,-211.41 907.5,-211.41 907.5,-211.41 913.5,-211.41 919.5,-217.41 919.5,-223.41 919.5,-223.41 919.5,-263.41 919.5,-263.41 919.5,-269.41 913.5,-275.41 907.5,-275.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"746.5\" y=\"-254.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">optuna_distributions</text>\n", |
| "<text text-anchor=\"start\" x=\"814.5\" y=\"-226.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- optuna_distributions->hyperparameter_search -->\n", |
| "<g id=\"edge20\" class=\"edge\">\n", |
| "<title>optuna_distributions->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M919.53,-256.59C963.19,-263.6 1016.09,-273.12 1063,-284.41 1070.85,-286.3 1078.92,-288.39 1087.01,-290.6\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1086.31,-294.04 1096.88,-293.36 1088.19,-287.3 1086.31,-294.04\"/>\n", |
| "</g>\n", |
| "<!-- study_results_df -->\n", |
| "<g id=\"node11\" class=\"node\">\n", |
| "<title>study_results_df</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M2040.5,-305.41C2040.5,-305.41 1915.5,-305.41 1915.5,-305.41 1909.5,-305.41 1903.5,-299.41 1903.5,-293.41 1903.5,-293.41 1903.5,-253.41 1903.5,-253.41 1903.5,-247.41 1909.5,-241.41 1915.5,-241.41 1915.5,-241.41 2040.5,-241.41 2040.5,-241.41 2046.5,-241.41 2052.5,-247.41 2052.5,-253.41 2052.5,-253.41 2052.5,-293.41 2052.5,-293.41 2052.5,-299.41 2046.5,-305.41 2040.5,-305.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1914.5\" y=\"-284.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">study_results_df</text>\n", |
| "<text text-anchor=\"start\" x=\"1939.5\" y=\"-256.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n", |
| "</g>\n", |
| "<!-- best_model->y_test_pred -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>best_model->y_test_pred</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M2034.03,-424.53C2114.1,-410.16 2262.46,-383.52 2348.03,-368.16\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"2348.66,-371.6 2357.89,-366.39 2347.43,-364.71 2348.66,-371.6\"/>\n", |
| "</g>\n", |
| "<!-- scoring_func->hyperparameter_search -->\n", |
| "<g id=\"edge19\" class=\"edge\">\n", |
| "<title>scoring_func->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M887.65,-52.42C938.46,-52.76 1011.47,-60.48 1063,-96.41 1097.1,-120.18 1149.38,-224.63 1177.15,-283.91\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1174,-285.43 1181.39,-293.02 1180.34,-282.48 1174,-285.43\"/>\n", |
| "</g>\n", |
| "<!-- scoring_func->test_score -->\n", |
| "<g id=\"edge25\" class=\"edge\">\n", |
| "<title>scoring_func->test_score</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M887.64,-62.83C1009.1,-77.42 1295.55,-108.41 1537,-108.41 1537,-108.41 1537,-108.41 1979,-108.41 2222.43,-108.41 2512.23,-108.41 2636.48,-108.41\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"2636.56,-111.91 2646.56,-108.41 2636.56,-104.91 2636.56,-111.91\"/>\n", |
| "</g>\n", |
| "<!-- hyperparameter_search->best_hyperparameters -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>hyperparameter_search->best_hyperparameters</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1247.75,-357.46C1272.03,-371.85 1302.17,-387.91 1331,-398.41 1361.61,-409.55 1395.94,-417.91 1427.67,-424.07\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1427.39,-427.58 1437.87,-426 1428.69,-420.7 1427.39,-427.58\"/>\n", |
| "</g>\n", |
| "<!-- study_results -->\n", |
| "<g id=\"node15\" class=\"node\">\n", |
| "<title>study_results</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1589.5,-305.41C1589.5,-305.41 1486.5,-305.41 1486.5,-305.41 1480.5,-305.41 1474.5,-299.41 1474.5,-293.41 1474.5,-293.41 1474.5,-253.41 1474.5,-253.41 1474.5,-247.41 1480.5,-241.41 1486.5,-241.41 1486.5,-241.41 1589.5,-241.41 1589.5,-241.41 1595.5,-241.41 1601.5,-247.41 1601.5,-253.41 1601.5,-253.41 1601.5,-293.41 1601.5,-293.41 1601.5,-299.41 1595.5,-305.41 1589.5,-305.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1485.5\" y=\"-284.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">study_results</text>\n", |
| "<text text-anchor=\"start\" x=\"1517.5\" y=\"-256.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Study</text>\n", |
| "</g>\n", |
| "<!-- hyperparameter_search->study_results -->\n", |
| "<g id=\"edge23\" class=\"edge\">\n", |
| "<title>hyperparameter_search->study_results</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1302.01,-309.45C1354.42,-301.41 1416.69,-291.86 1463.96,-284.61\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1464.76,-288.03 1474.11,-283.05 1463.7,-281.11 1464.76,-288.03\"/>\n", |
| "</g>\n", |
| "<!-- study_results->study_results_df -->\n", |
| "<g id=\"edge10\" class=\"edge\">\n", |
| "<title>study_results->study_results_df</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1601.75,-273.41C1678.36,-273.41 1808.17,-273.41 1893.25,-273.41\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1893.39,-276.91 1903.39,-273.41 1893.39,-269.91 1893.39,-276.91\"/>\n", |
| "</g>\n", |
| "<!-- _y_test_pred_inputs -->\n", |
| "<g id=\"node17\" class=\"node\">\n", |
| "<title>_y_test_pred_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"2182,-378.91 1774,-378.91 1774,-333.91 2182,-333.91 2182,-378.91\"/>\n", |
| "<text text-anchor=\"start\" x=\"1789\" y=\"-352.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_test</text>\n", |
| "<text text-anchor=\"start\" x=\"1838\" y=\"-352.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "</g>\n", |
| "<!-- _y_test_pred_inputs->y_test_pred -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>_y_test_pred_inputs->y_test_pred</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M2182.26,-356.41C2242.09,-356.41 2303.34,-356.41 2347.76,-356.41\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"2347.83,-359.91 2357.83,-356.41 2347.83,-352.91 2347.83,-359.91\"/>\n", |
| "</g>\n", |
| "<!-- _study_inputs -->\n", |
| "<g id=\"node18\" class=\"node\">\n", |
| "<title>_study_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"441,-464.91 247,-464.91 247,-335.91 441,-335.91 441,-464.91\"/>\n", |
| "<text text-anchor=\"start\" x=\"283\" y=\"-438.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sampler</text>\n", |
| "<text text-anchor=\"start\" x=\"367\" y=\"-438.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "<text text-anchor=\"start\" x=\"287.5\" y=\"-417.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">pruner</text>\n", |
| "<text text-anchor=\"start\" x=\"367\" y=\"-417.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "<text text-anchor=\"start\" x=\"262\" y=\"-396.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">study_storage</text>\n", |
| "<text text-anchor=\"start\" x=\"367\" y=\"-396.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "<text text-anchor=\"start\" x=\"265\" y=\"-375.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">load_if_exists</text>\n", |
| "<text text-anchor=\"start\" x=\"381.5\" y=\"-375.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n", |
| "<text text-anchor=\"start\" x=\"269\" y=\"-354.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">study_name</text>\n", |
| "<text text-anchor=\"start\" x=\"367\" y=\"-354.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "</g>\n", |
| "<!-- _study_inputs->study -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>_study_inputs->study</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M441.28,-401.81C545.86,-403.33 708.17,-405.69 784.33,-406.79\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"784.44,-410.3 794.49,-406.94 784.54,-403.3 784.44,-410.3\"/>\n", |
| "</g>\n", |
| "<!-- _model_config_inputs -->\n", |
| "<g id=\"node19\" class=\"node\">\n", |
| "<title>_model_config_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"470,-631.41 218,-631.41 218,-565.41 470,-565.41 470,-631.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"294.5\" y=\"-605.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n", |
| "<text text-anchor=\"start\" x=\"416\" y=\"-605.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "<text text-anchor=\"start\" x=\"233\" y=\"-584.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">model_config_override</text>\n", |
| "<text text-anchor=\"start\" x=\"396\" y=\"-584.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "</g>\n", |
| "<!-- _model_config_inputs->model_config -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>_model_config_inputs->model_config</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M470.14,-586.21C501.55,-579.98 534.43,-570.58 563,-556.41 578.51,-548.71 576.43,-537.97 592,-530.41 643.02,-505.63 706.73,-495.64 754.86,-491.7\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"755.18,-495.18 764.89,-490.94 754.65,-488.2 755.18,-495.18\"/>\n", |
| "</g>\n", |
| "<!-- _cross_validation_folds_inputs -->\n", |
| "<g id=\"node20\" class=\"node\">\n", |
| "<title>_cross_validation_folds_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"563,-317.41 125,-317.41 125,-167.41 563,-167.41 563,-317.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"152\" y=\"-291.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_train</text>\n", |
| "<text text-anchor=\"start\" x=\"219\" y=\"-291.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"152\" y=\"-270.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">stratify</text>\n", |
| "<text text-anchor=\"start\" x=\"368.5\" y=\"-270.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n", |
| "<text text-anchor=\"start\" x=\"140\" y=\"-249.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">n_cv_folds</text>\n", |
| "<text text-anchor=\"start\" x=\"374\" y=\"-249.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "<text text-anchor=\"start\" x=\"152.5\" y=\"-228.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_train</text>\n", |
| "<text text-anchor=\"start\" x=\"219\" y=\"-228.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"152.5\" y=\"-207.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">shuffle</text>\n", |
| "<text text-anchor=\"start\" x=\"368.5\" y=\"-207.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n", |
| "<text text-anchor=\"start\" x=\"159.5\" y=\"-186.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n", |
| "<text text-anchor=\"start\" x=\"374\" y=\"-186.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "</g>\n", |
| "<!-- _cross_validation_folds_inputs->cross_validation_folds -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>_cross_validation_folds_inputs->cross_validation_folds</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M563.21,-280.02C617.8,-289.43 673.95,-299.11 720.3,-307.1\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"719.9,-310.58 730.34,-308.83 721.08,-303.68 719.9,-310.58\"/>\n", |
| "</g>\n", |
| "<!-- _optuna_distributions_inputs -->\n", |
| "<g id=\"node21\" class=\"node\">\n", |
| "<title>_optuna_distributions_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"495,-149.91 193,-149.91 193,-104.91 495,-104.91 495,-149.91\"/>\n", |
| "<text text-anchor=\"start\" x=\"208\" y=\"-123.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">optuna_distributions_override</text>\n", |
| "<text text-anchor=\"start\" x=\"421\" y=\"-123.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "</g>\n", |
| "<!-- _optuna_distributions_inputs->optuna_distributions -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>_optuna_distributions_inputs->optuna_distributions</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M495.08,-131.91C519.02,-137.26 542.59,-145.63 563,-158.41 582.54,-170.64 572.68,-188.83 592,-201.41 631.14,-226.9 681.4,-238.08 725.34,-242.6\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"725.2,-246.11 735.48,-243.54 725.85,-239.14 725.2,-246.11\"/>\n", |
| "</g>\n", |
| "<!-- _best_model_inputs -->\n", |
| "<g id=\"node22\" class=\"node\">\n", |
| "<title>_best_model_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1745,-389.41 1331,-389.41 1331,-323.41 1745,-323.41 1745,-389.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1346.5\" y=\"-363.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_train</text>\n", |
| "<text text-anchor=\"start\" x=\"1401\" y=\"-363.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"1346\" y=\"-342.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_train</text>\n", |
| "<text text-anchor=\"start\" x=\"1401\" y=\"-342.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "</g>\n", |
| "<!-- _best_model_inputs->best_model -->\n", |
| "<g id=\"edge14\" class=\"edge\">\n", |
| "<title>_best_model_inputs->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1724.43,-389.43C1791.23,-401.32 1862.29,-413.98 1911.96,-422.82\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1911.46,-426.29 1921.92,-424.6 1912.68,-419.4 1911.46,-426.29\"/>\n", |
| "</g>\n", |
| "<!-- _hyperparameter_search_inputs -->\n", |
| "<g id=\"node23\" class=\"node\">\n", |
| "<title>_hyperparameter_search_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1063,-192.91 592,-192.91 592,-105.91 1063,-105.91 1063,-192.91\"/>\n", |
| "<text text-anchor=\"start\" x=\"636.5\" y=\"-166.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_train</text>\n", |
| "<text text-anchor=\"start\" x=\"719.5\" y=\"-166.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"636\" y=\"-145.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_train</text>\n", |
| "<text text-anchor=\"start\" x=\"719.5\" y=\"-145.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"607.5\" y=\"-124.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">n_optuna_trials</text>\n", |
| "<text text-anchor=\"start\" x=\"874.5\" y=\"-124.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "</g>\n", |
| "<!-- _hyperparameter_search_inputs->hyperparameter_search -->\n", |
| "<g id=\"edge22\" class=\"edge\">\n", |
| "<title>_hyperparameter_search_inputs->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1043.32,-192.97C1050.05,-195.91 1056.63,-199.05 1063,-202.41 1102.31,-223.09 1138.72,-258.38 1163.47,-285.87\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1160.89,-288.23 1170.14,-293.4 1166.13,-283.59 1160.89,-288.23\"/>\n", |
| "</g>\n", |
| "<!-- _test_score_inputs -->\n", |
| "<g id=\"node24\" class=\"node\">\n", |
| "<title>_test_score_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"2618,-79.91 2211,-79.91 2211,-34.91 2618,-34.91 2618,-79.91\"/>\n", |
| "<text text-anchor=\"start\" x=\"2226.5\" y=\"-53.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_test</text>\n", |
| "<text text-anchor=\"start\" x=\"2274.5\" y=\"-53.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "</g>\n", |
| "<!-- _test_score_inputs->test_score -->\n", |
| "<g id=\"edge26\" class=\"edge\">\n", |
| "<title>_test_score_inputs->test_score</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M2540.23,-79.92C2573.67,-85.96 2608.36,-92.22 2636.66,-97.33\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"2636.33,-100.83 2646.79,-99.16 2637.57,-93.94 2636.33,-100.83\"/>\n", |
| "</g>\n", |
| "<!-- config -->\n", |
| "<g id=\"node25\" class=\"node\">\n", |
| "<title>config</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"79.5,-553.41 26.5,-553.41 26.5,-517.41 85.5,-517.41 85.5,-547.41 79.5,-553.41\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"79.5,-553.41 79.5,-547.41 \"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"85.5,-547.41 79.5,-547.41 \"/>\n", |
| "<text text-anchor=\"middle\" x=\"56\" y=\"-531.71\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">config</text>\n", |
| "</g>\n", |
| "<!-- input -->\n", |
| "<g id=\"node26\" class=\"node\">\n", |
| "<title>input</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"85.5,-498.91 26.5,-498.91 26.5,-461.91 85.5,-461.91 85.5,-498.91\"/>\n", |
| "<text text-anchor=\"middle\" x=\"56\" y=\"-476.71\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n", |
| "</g>\n", |
| "<!-- function -->\n", |
| "<g id=\"node27\" class=\"node\">\n", |
| "<title>function</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M84,-443.91C84,-443.91 28,-443.91 28,-443.91 22,-443.91 16,-437.91 16,-431.91 16,-431.91 16,-418.91 16,-418.91 16,-412.91 22,-406.91 28,-406.91 28,-406.91 84,-406.91 84,-406.91 90,-406.91 96,-412.91 96,-418.91 96,-418.91 96,-431.91 96,-431.91 96,-437.91 90,-443.91 84,-443.91\"/>\n", |
| "<text text-anchor=\"middle\" x=\"56\" y=\"-421.71\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<graphviz.graphs.Digraph at 0x7f6a2a3aefb0>" |
| ] |
| }, |
| "execution_count": 4, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "dr = (\n", |
| " driver.Builder()\n", |
| " .with_modules(xgboost_optuna)\n", |
| " .with_config(dict(task=\"classification\"))\n", |
| " .build()\n", |
| ")\n", |
| "\n", |
| "dr.display_all_functions()" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 2.43.0 (0)\n", |
| " -->\n", |
| "<!-- Title: %3 Pages: 1 -->\n", |
| "<svg width=\"985pt\" height=\"751pt\"\n", |
| " viewBox=\"0.00 0.00 984.95 750.50\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 746.5)\">\n", |
| "<title>%3</title>\n", |
| "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-746.5 980.95,-746.5 980.95,4 -4,4\"/>\n", |
| "<g id=\"clust1\" class=\"cluster\">\n", |
| "<title>cluster__legend</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"160,-658.5 160,-734.5 420,-734.5 420,-658.5 160,-658.5\"/>\n", |
| "<text text-anchor=\"middle\" x=\"290\" y=\"-719.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n", |
| "</g>\n", |
| "<!-- y_test_pred -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>y_test_pred</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M396.5,-157C396.5,-157 307.5,-157 307.5,-157 301.5,-157 295.5,-151 295.5,-145 295.5,-145 295.5,-105 295.5,-105 295.5,-99 301.5,-93 307.5,-93 307.5,-93 396.5,-93 396.5,-93 402.5,-93 408.5,-99 408.5,-105 408.5,-105 408.5,-145 408.5,-145 408.5,-151 402.5,-157 396.5,-157\"/>\n", |
| "<text text-anchor=\"start\" x=\"306.5\" y=\"-135.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">y_test_pred</text>\n", |
| "<text text-anchor=\"start\" x=\"324.5\" y=\"-107.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n", |
| "</g>\n", |
| "<!-- test_score -->\n", |
| "<g id=\"node13\" class=\"node\">\n", |
| "<title>test_score</title>\n", |
| "<path fill=\"#ffc857\" stroke=\"black\" d=\"M392,-64C392,-64 312,-64 312,-64 306,-64 300,-58 300,-52 300,-52 300,-12 300,-12 300,-6 306,0 312,0 312,0 392,0 392,0 398,0 404,-6 404,-12 404,-12 404,-52 404,-52 404,-58 398,-64 392,-64\"/>\n", |
| "<text text-anchor=\"start\" x=\"311\" y=\"-42.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">test_score</text>\n", |
| "<text text-anchor=\"start\" x=\"338\" y=\"-14.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Any</text>\n", |
| "</g>\n", |
| "<!-- y_test_pred->test_score -->\n", |
| "<g id=\"edge19\" class=\"edge\">\n", |
| "<title>y_test_pred->test_score</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M352,-92.94C352,-87 352,-80.7 352,-74.49\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"355.5,-74.23 352,-64.23 348.5,-74.23 355.5,-74.23\"/>\n", |
| "</g>\n", |
| "<!-- base_model -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>base_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M776.5,-529C776.5,-529 685.5,-529 685.5,-529 679.5,-529 673.5,-523 673.5,-517 673.5,-517 673.5,-477 673.5,-477 673.5,-471 679.5,-465 685.5,-465 685.5,-465 776.5,-465 776.5,-465 782.5,-465 788.5,-471 788.5,-477 788.5,-477 788.5,-517 788.5,-517 788.5,-523 782.5,-529 776.5,-529\"/>\n", |
| "<text text-anchor=\"start\" x=\"684.5\" y=\"-507.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">base_model</text>\n", |
| "<text text-anchor=\"start\" x=\"703.5\" y=\"-479.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Callable</text>\n", |
| "</g>\n", |
| "<!-- best_model -->\n", |
| "<g id=\"node9\" class=\"node\">\n", |
| "<title>best_model</title>\n", |
| "<path fill=\"#ffc857\" stroke=\"black\" d=\"M396,-250C396,-250 308,-250 308,-250 302,-250 296,-244 296,-238 296,-238 296,-198 296,-198 296,-192 302,-186 308,-186 308,-186 396,-186 396,-186 402,-186 408,-192 408,-198 408,-198 408,-238 408,-238 408,-244 402,-250 396,-250\"/>\n", |
| "<text text-anchor=\"start\" x=\"307\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">best_model</text>\n", |
| "<text text-anchor=\"start\" x=\"316\" y=\"-200.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">XGBModel</text>\n", |
| "</g>\n", |
| "<!-- base_model->best_model -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>base_model->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M788.62,-468.99C791.78,-467.62 794.92,-466.28 798,-465 868.9,-435.52 928.21,-474.89 960,-405\"/>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M960,-403C1062.44,-177.78 662.45,-334.04 417.83,-249.98\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"418.71,-246.58 408.12,-246.48 416.34,-253.16 418.71,-246.58\"/>\n", |
| "</g>\n", |
| "<!-- hyperparameter_search -->\n", |
| "<g id=\"node11\" class=\"node\">\n", |
| "<title>hyperparameter_search</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M445,-436C445,-436 259,-436 259,-436 253,-436 247,-430 247,-424 247,-424 247,-384 247,-384 247,-378 253,-372 259,-372 259,-372 445,-372 445,-372 451,-372 457,-378 457,-384 457,-384 457,-424 457,-424 457,-430 451,-436 445,-436\"/>\n", |
| "<text text-anchor=\"start\" x=\"258\" y=\"-414.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">hyperparameter_search</text>\n", |
| "<text text-anchor=\"start\" x=\"339\" y=\"-386.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- base_model->hyperparameter_search -->\n", |
| "<g id=\"edge12\" class=\"edge\">\n", |
| "<title>base_model->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M673.38,-467.94C670.57,-466.89 667.77,-465.9 665,-465 601.08,-444.22 527.35,-429.66 467.51,-420.08\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"467.69,-416.57 457.27,-418.47 466.6,-423.48 467.69,-416.57\"/>\n", |
| "</g>\n", |
| "<!-- cross_validation_folds -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>cross_validation_folds</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M644,-529C644,-529 474,-529 474,-529 468,-529 462,-523 462,-517 462,-517 462,-477 462,-477 462,-471 468,-465 474,-465 474,-465 644,-465 644,-465 650,-465 656,-471 656,-477 656,-477 656,-517 656,-517 656,-523 650,-529 644,-529\"/>\n", |
| "<text text-anchor=\"start\" x=\"473\" y=\"-507.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">cross_validation_folds</text>\n", |
| "<text text-anchor=\"start\" x=\"524.5\" y=\"-479.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Sequence</text>\n", |
| "</g>\n", |
| "<!-- cross_validation_folds->hyperparameter_search -->\n", |
| "<g id=\"edge11\" class=\"edge\">\n", |
| "<title>cross_validation_folds->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M488.34,-464.94C470.21,-456.97 450.57,-448.33 432,-440.17\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"433.25,-436.9 422.69,-436.08 430.44,-443.31 433.25,-436.9\"/>\n", |
| "</g>\n", |
| "<!-- optuna_distributions -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>optuna_distributions</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M432,-529C432,-529 272,-529 272,-529 266,-529 260,-523 260,-517 260,-517 260,-477 260,-477 260,-471 266,-465 272,-465 272,-465 432,-465 432,-465 438,-465 444,-471 444,-477 444,-477 444,-517 444,-517 444,-523 438,-529 432,-529\"/>\n", |
| "<text text-anchor=\"start\" x=\"271\" y=\"-507.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">optuna_distributions</text>\n", |
| "<text text-anchor=\"start\" x=\"339\" y=\"-479.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- optuna_distributions->hyperparameter_search -->\n", |
| "<g id=\"edge15\" class=\"edge\">\n", |
| "<title>optuna_distributions->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M352,-464.94C352,-459 352,-452.7 352,-446.49\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"355.5,-446.23 352,-436.23 348.5,-446.23 355.5,-446.23\"/>\n", |
| "</g>\n", |
| "<!-- higher_is_better -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>higher_is_better</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M229.5,-623C229.5,-623 104.5,-623 104.5,-623 98.5,-623 92.5,-617 92.5,-611 92.5,-611 92.5,-571 92.5,-571 92.5,-565 98.5,-559 104.5,-559 104.5,-559 229.5,-559 229.5,-559 235.5,-559 241.5,-565 241.5,-571 241.5,-571 241.5,-611 241.5,-611 241.5,-617 235.5,-623 229.5,-623\"/>\n", |
| "<text text-anchor=\"start\" x=\"103.5\" y=\"-601.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">higher_is_better</text>\n", |
| "<text text-anchor=\"start\" x=\"152\" y=\"-573.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">bool</text>\n", |
| "</g>\n", |
| "<!-- study -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>study</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M192,-529C192,-529 150,-529 150,-529 144,-529 138,-523 138,-517 138,-517 138,-477 138,-477 138,-471 144,-465 150,-465 150,-465 192,-465 192,-465 198,-465 204,-471 204,-477 204,-477 204,-517 204,-517 204,-523 198,-529 192,-529\"/>\n", |
| "<text text-anchor=\"start\" x=\"149\" y=\"-507.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">study</text>\n", |
| "<text text-anchor=\"start\" x=\"150.5\" y=\"-479.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Study</text>\n", |
| "</g>\n", |
| "<!-- higher_is_better->study -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>higher_is_better->study</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M168.35,-558.85C168.63,-552.56 168.92,-545.85 169.21,-539.27\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"172.71,-539.26 169.65,-529.11 165.72,-538.95 172.71,-539.26\"/>\n", |
| "</g>\n", |
| "<!-- study->hyperparameter_search -->\n", |
| "<g id=\"edge16\" class=\"edge\">\n", |
| "<title>study->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M204.22,-473.14C208.78,-470.28 213.46,-467.49 218,-465 233.74,-456.36 250.97,-447.93 267.67,-440.25\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"269.23,-443.39 276.88,-436.06 266.33,-437.01 269.23,-443.39\"/>\n", |
| "</g>\n", |
| "<!-- scorer -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>scorer</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M138,-717C138,-717 90,-717 90,-717 84,-717 78,-711 78,-705 78,-705 78,-665 78,-665 78,-659 84,-653 90,-653 90,-653 138,-653 138,-653 144,-653 150,-659 150,-665 150,-665 150,-705 150,-705 150,-711 144,-717 138,-717\"/>\n", |
| "<text text-anchor=\"start\" x=\"89\" y=\"-695.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scorer</text>\n", |
| "<text text-anchor=\"start\" x=\"101\" y=\"-667.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- scorer->higher_is_better -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>scorer->higher_is_better</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M131.94,-652.85C135.83,-646.11 139.99,-638.89 144.04,-631.86\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"147.12,-633.53 149.08,-623.11 141.05,-630.03 147.12,-633.53\"/>\n", |
| "</g>\n", |
| "<!-- scoring_func -->\n", |
| "<g id=\"node10\" class=\"node\">\n", |
| "<title>scoring_func</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M108,-529C108,-529 12,-529 12,-529 6,-529 0,-523 0,-517 0,-517 0,-477 0,-477 0,-471 6,-465 12,-465 12,-465 108,-465 108,-465 114,-465 120,-471 120,-477 120,-477 120,-517 120,-517 120,-523 114,-529 108,-529\"/>\n", |
| "<text text-anchor=\"start\" x=\"11\" y=\"-507.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scoring_func</text>\n", |
| "<text text-anchor=\"start\" x=\"31\" y=\"-479.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">function</text>\n", |
| "</g>\n", |
| "<!-- scorer->scoring_func -->\n", |
| "<g id=\"edge10\" class=\"edge\">\n", |
| "<title>scorer->scoring_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M96.44,-652.94C91.85,-643.83 87.28,-633.71 84,-624 74.7,-596.5 68.65,-564.33 64.95,-539.38\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"68.37,-538.57 63.51,-529.16 61.44,-539.55 68.37,-538.57\"/>\n", |
| "</g>\n", |
| "<!-- model_config -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>model_config</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M919.5,-529C919.5,-529 818.5,-529 818.5,-529 812.5,-529 806.5,-523 806.5,-517 806.5,-517 806.5,-477 806.5,-477 806.5,-471 812.5,-465 818.5,-465 818.5,-465 919.5,-465 919.5,-465 925.5,-465 931.5,-471 931.5,-477 931.5,-477 931.5,-517 931.5,-517 931.5,-523 925.5,-529 919.5,-529\"/>\n", |
| "<text text-anchor=\"start\" x=\"817.5\" y=\"-507.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">model_config</text>\n", |
| "<text text-anchor=\"start\" x=\"856\" y=\"-479.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- model_config->best_model -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>model_config->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M926.32,-464.84C945.87,-449.71 962.34,-429.47 960,-405\"/>\n", |
| "</g>\n", |
| "<!-- model_config->hyperparameter_search -->\n", |
| "<g id=\"edge13\" class=\"edge\">\n", |
| "<title>model_config->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M806.23,-467.51C803.47,-466.61 800.72,-465.76 798,-465 688.13,-434.27 558.49,-419.08 467.38,-411.69\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"467.54,-408.2 457.3,-410.9 466.99,-415.17 467.54,-408.2\"/>\n", |
| "</g>\n", |
| "<!-- best_model->y_test_pred -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>best_model->y_test_pred</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M352,-185.94C352,-180 352,-173.7 352,-167.49\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"355.5,-167.23 352,-157.23 348.5,-167.23 355.5,-167.23\"/>\n", |
| "</g>\n", |
| "<!-- scoring_func->hyperparameter_search -->\n", |
| "<g id=\"edge14\" class=\"edge\">\n", |
| "<title>scoring_func->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M120.36,-468.23C123.26,-467.1 126.15,-466.01 129,-465 163.6,-452.66 202.01,-441.47 237.06,-432.18\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"238.06,-435.53 246.85,-429.61 236.29,-428.76 238.06,-435.53\"/>\n", |
| "</g>\n", |
| "<!-- scoring_func->test_score -->\n", |
| "<g id=\"edge20\" class=\"edge\">\n", |
| "<title>scoring_func->test_score</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M82.92,-464.89C106.56,-429.66 140,-369.65 140,-312 140,-312 140,-312 140,-217 140,-136.69 227.85,-83.53 290.44,-55.84\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"292.07,-58.95 299.87,-51.78 289.3,-52.53 292.07,-58.95\"/>\n", |
| "</g>\n", |
| "<!-- best_hyperparameters -->\n", |
| "<g id=\"node12\" class=\"node\">\n", |
| "<title>best_hyperparameters</title>\n", |
| "<path fill=\"#ffc857\" stroke=\"black\" d=\"M440,-343C440,-343 264,-343 264,-343 258,-343 252,-337 252,-331 252,-331 252,-291 252,-291 252,-285 258,-279 264,-279 264,-279 440,-279 440,-279 446,-279 452,-285 452,-291 452,-291 452,-331 452,-331 452,-337 446,-343 440,-343\"/>\n", |
| "<text text-anchor=\"start\" x=\"263\" y=\"-321.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">best_hyperparameters</text>\n", |
| "<text text-anchor=\"start\" x=\"339\" y=\"-293.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- hyperparameter_search->best_hyperparameters -->\n", |
| "<g id=\"edge18\" class=\"edge\">\n", |
| "<title>hyperparameter_search->best_hyperparameters</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M352,-371.94C352,-366 352,-359.7 352,-353.49\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"355.5,-353.23 352,-343.23 348.5,-353.23 355.5,-353.23\"/>\n", |
| "</g>\n", |
| "<!-- best_hyperparameters->best_model -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>best_hyperparameters->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M352,-278.94C352,-273 352,-266.7 352,-260.49\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"355.5,-260.23 352,-250.23 348.5,-260.23 355.5,-260.23\"/>\n", |
| "</g>\n", |
| "<!-- _y_test_pred_inputs -->\n", |
| "<g id=\"node14\" class=\"node\">\n", |
| "<title>_y_test_pred_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"834,-240.5 426,-240.5 426,-195.5 834,-195.5 834,-240.5\"/>\n", |
| "<text text-anchor=\"start\" x=\"441\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_test</text>\n", |
| "<text text-anchor=\"start\" x=\"490\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "</g>\n", |
| "<!-- _y_test_pred_inputs->y_test_pred -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>_y_test_pred_inputs->y_test_pred</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M541.81,-195.5C503.32,-185.23 457.95,-171.92 418,-157 417.9,-156.96 417.81,-156.93 417.71,-156.89\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"419.22,-153.72 408.63,-153.32 416.65,-160.24 419.22,-153.72\"/>\n", |
| "</g>\n", |
| "<!-- _cross_validation_folds_inputs -->\n", |
| "<g id=\"node15\" class=\"node\">\n", |
| "<title>_cross_validation_folds_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"766,-624 352,-624 352,-558 766,-558 766,-624\"/>\n", |
| "<text text-anchor=\"start\" x=\"367.5\" y=\"-597.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_train</text>\n", |
| "<text text-anchor=\"start\" x=\"422\" y=\"-597.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"367\" y=\"-576.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_train</text>\n", |
| "<text text-anchor=\"start\" x=\"422\" y=\"-576.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "</g>\n", |
| "<!-- _cross_validation_folds_inputs->cross_validation_folds -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>_cross_validation_folds_inputs->cross_validation_folds</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M559,-557.82C559,-551.85 559,-545.53 559,-539.33\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"562.5,-539.07 559,-529.07 555.5,-539.07 562.5,-539.07\"/>\n", |
| "</g>\n", |
| "<!-- _cross_validation_folds_inputs->best_model -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>_cross_validation_folds_inputs->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M766.18,-578.5C841.78,-569.94 915.14,-555.02 941,-529 980.3,-489.45 965.3,-460.5 960,-405\"/>\n", |
| "</g>\n", |
| "<!-- _cross_validation_folds_inputs->hyperparameter_search -->\n", |
| "<g id=\"edge17\" class=\"edge\">\n", |
| "<title>_cross_validation_folds_inputs->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M351.61,-568.99C305.52,-559.94 265.88,-547.11 251,-529 232.95,-507.02 236.93,-489.72 251,-465 255.92,-456.36 262.63,-448.83 270.26,-442.31\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"272.46,-445.04 278.18,-436.12 268.14,-439.52 272.46,-445.04\"/>\n", |
| "</g>\n", |
| "<!-- _test_score_inputs -->\n", |
| "<g id=\"node16\" class=\"node\">\n", |
| "<title>_test_score_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"833.5,-147.5 426.5,-147.5 426.5,-102.5 833.5,-102.5 833.5,-147.5\"/>\n", |
| "<text text-anchor=\"start\" x=\"442\" y=\"-120.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_test</text>\n", |
| "<text text-anchor=\"start\" x=\"490\" y=\"-120.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "</g>\n", |
| "<!-- _test_score_inputs->test_score -->\n", |
| "<g id=\"edge21\" class=\"edge\">\n", |
| "<title>_test_score_inputs->test_score</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M564.12,-102.43C518.64,-87.55 458.57,-67.89 413.97,-53.28\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"415,-49.94 404.41,-50.16 412.83,-56.59 415,-49.94\"/>\n", |
| "</g>\n", |
| "<!-- input -->\n", |
| "<g id=\"node17\" class=\"node\">\n", |
| "<title>input</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"411.5,-703.5 352.5,-703.5 352.5,-666.5 411.5,-666.5 411.5,-703.5\"/>\n", |
| "<text text-anchor=\"middle\" x=\"382\" y=\"-681.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n", |
| "</g>\n", |
| "<!-- function -->\n", |
| "<g id=\"node18\" class=\"node\">\n", |
| "<title>function</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M322,-703.5C322,-703.5 266,-703.5 266,-703.5 260,-703.5 254,-697.5 254,-691.5 254,-691.5 254,-678.5 254,-678.5 254,-672.5 260,-666.5 266,-666.5 266,-666.5 322,-666.5 322,-666.5 328,-666.5 334,-672.5 334,-678.5 334,-678.5 334,-691.5 334,-691.5 334,-697.5 328,-703.5 322,-703.5\"/>\n", |
| "<text text-anchor=\"middle\" x=\"294\" y=\"-681.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n", |
| "</g>\n", |
| "<!-- output -->\n", |
| "<g id=\"node19\" class=\"node\">\n", |
| "<title>output</title>\n", |
| "<path fill=\"#ffc857\" stroke=\"black\" d=\"M224,-703.5C224,-703.5 180,-703.5 180,-703.5 174,-703.5 168,-697.5 168,-691.5 168,-691.5 168,-678.5 168,-678.5 168,-672.5 174,-666.5 180,-666.5 180,-666.5 224,-666.5 224,-666.5 230,-666.5 236,-672.5 236,-678.5 236,-678.5 236,-691.5 236,-691.5 236,-697.5 230,-703.5 224,-703.5\"/>\n", |
| "<text text-anchor=\"middle\" x=\"202\" y=\"-681.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">output</text>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<graphviz.graphs.Digraph at 0x7f6a2839ee30>" |
| ] |
| }, |
| "execution_count": 5, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "from sklearn.datasets import load_breast_cancer\n", |
| "from sklearn.model_selection import train_test_split\n", |
| "\n", |
| "# Load the breast cancer dataset (binary classification example)\n", |
| "data = load_breast_cancer()\n", |
| "X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=42)\n", |
| "\n", |
| "inputs = dict(\n", |
| " X_train=X_train,\n", |
| " y_train=y_train,\n", |
| " X_test=X_test,\n", |
| " y_test=y_test,\n", |
| ")\n", |
| "\n", |
| "final_vars = [\"best_model\", \"best_hyperparameters\", \"test_score\"]\n", |
| "\n", |
| "dr.visualize_execution(\n", |
| " final_vars=final_vars,\n", |
| " inputs=inputs,\n", |
| " output_file_path=None,\n", |
| " orient=\"TB\",\n", |
| " deduplicate_inputs=True,\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stderr", |
| "output_type": "stream", |
| "text": [ |
| "[I 2023-11-15 18:40:48,705] A new study created in memory with name: no-name-7bf4dece-e402-4b38-9818-9c564d2d68b2\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/plain": [ |
| "{'best_model': XGBClassifier(base_score=None, booster='gbtree', callbacks=None,\n", |
| " colsample_bylevel=None, colsample_bynode=None,\n", |
| " colsample_bytree=0.5256259560017136, device='cpu',\n", |
| " early_stopping_rounds=None, enable_categorical=True,\n", |
| " eval_metric=None, feature_types=None, gamma=0.11302140972445468,\n", |
| " grow_policy=None, importance_type=None,\n", |
| " interaction_constraints=None, learning_rate=0.032344180499316345,\n", |
| " max_bin=None, max_cat_threshold=None, max_cat_to_onehot=None,\n", |
| " max_delta_step=9, max_depth=4, max_leaves=None,\n", |
| " min_child_weight=3, missing=nan, monotone_constraints=None,\n", |
| " multi_strategy=None, n_estimators=550, n_jobs=None,\n", |
| " num_parallel_tree=None, random_state=None, ...),\n", |
| " 'best_hyperparameters': {'n_estimators': 550,\n", |
| " 'learning_rate': 0.032344180499316345,\n", |
| " 'max_depth': 4,\n", |
| " 'gamma': 0.11302140972445468,\n", |
| " 'colsample_bytree': 0.5256259560017136,\n", |
| " 'min_child_weight': 3,\n", |
| " 'max_delta_step': 9},\n", |
| " 'test_score': 0.9736842105263158}" |
| ] |
| }, |
| "execution_count": 6, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "dr.execute(\n", |
| " final_vars=final_vars,\n", |
| " inputs=inputs,\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 3. Local copy" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# first, import the dataflow dynamically (same as section 2. Dynamic installation)\n", |
| "xgboost_optuna = dataflows.import_module(\"xgboost_optuna\", \"zilto\")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# then create local copy\n", |
| "dataflows.copy(xgboost_optuna, destination_path=\"./my_copy\")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# import your local copy\n", |
| "from my_copy import xgboost_optuna" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 13, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 2.43.0 (0)\n", |
| " -->\n", |
| "<!-- Title: %3 Pages: 1 -->\n", |
| "<svg width=\"2759pt\" height=\"639pt\"\n", |
| " viewBox=\"0.00 0.00 2759.00 639.41\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 635.41)\">\n", |
| "<title>%3</title>\n", |
| "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-635.41 2755,-635.41 2755,4 -4,4\"/>\n", |
| "<g id=\"clust1\" class=\"cluster\">\n", |
| "<title>cluster__legend</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"8,-398.41 8,-584.41 104,-584.41 104,-398.41 8,-398.41\"/>\n", |
| "<text text-anchor=\"middle\" x=\"56\" y=\"-569.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n", |
| "</g>\n", |
| "<!-- y_test_pred -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>y_test_pred</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M2459,-388.41C2459,-388.41 2370,-388.41 2370,-388.41 2364,-388.41 2358,-382.41 2358,-376.41 2358,-376.41 2358,-336.41 2358,-336.41 2358,-330.41 2364,-324.41 2370,-324.41 2370,-324.41 2459,-324.41 2459,-324.41 2465,-324.41 2471,-330.41 2471,-336.41 2471,-336.41 2471,-376.41 2471,-376.41 2471,-382.41 2465,-388.41 2459,-388.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"2369\" y=\"-367.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">y_test_pred</text>\n", |
| "<text text-anchor=\"start\" x=\"2387\" y=\"-339.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n", |
| "</g>\n", |
| "<!-- test_score -->\n", |
| "<g id=\"node16\" class=\"node\">\n", |
| "<title>test_score</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M2739,-140.41C2739,-140.41 2659,-140.41 2659,-140.41 2653,-140.41 2647,-134.41 2647,-128.41 2647,-128.41 2647,-88.41 2647,-88.41 2647,-82.41 2653,-76.41 2659,-76.41 2659,-76.41 2739,-76.41 2739,-76.41 2745,-76.41 2751,-82.41 2751,-88.41 2751,-88.41 2751,-128.41 2751,-128.41 2751,-134.41 2745,-140.41 2739,-140.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"2658\" y=\"-119.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">test_score</text>\n", |
| "<text text-anchor=\"start\" x=\"2685\" y=\"-91.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Any</text>\n", |
| "</g>\n", |
| "<!-- y_test_pred->test_score -->\n", |
| "<g id=\"edge24\" class=\"edge\">\n", |
| "<title>y_test_pred->test_score</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M2452.14,-324.24C2503.9,-278.8 2598.29,-195.94 2653.85,-147.17\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"2656.22,-149.74 2661.42,-140.52 2651.6,-144.48 2656.22,-149.74\"/>\n", |
| "</g>\n", |
| "<!-- base_model -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>base_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M873,-603.41C873,-603.41 782,-603.41 782,-603.41 776,-603.41 770,-597.41 770,-591.41 770,-591.41 770,-551.41 770,-551.41 770,-545.41 776,-539.41 782,-539.41 782,-539.41 873,-539.41 873,-539.41 879,-539.41 885,-545.41 885,-551.41 885,-551.41 885,-591.41 885,-591.41 885,-597.41 879,-603.41 873,-603.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"781\" y=\"-582.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">base_model</text>\n", |
| "<text text-anchor=\"start\" x=\"800\" y=\"-554.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Callable</text>\n", |
| "</g>\n", |
| "<!-- best_model -->\n", |
| "<g id=\"node12\" class=\"node\">\n", |
| "<title>best_model</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M2022,-466.41C2022,-466.41 1934,-466.41 1934,-466.41 1928,-466.41 1922,-460.41 1922,-454.41 1922,-454.41 1922,-414.41 1922,-414.41 1922,-408.41 1928,-402.41 1934,-402.41 1934,-402.41 2022,-402.41 2022,-402.41 2028,-402.41 2034,-408.41 2034,-414.41 2034,-414.41 2034,-454.41 2034,-454.41 2034,-460.41 2028,-466.41 2022,-466.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1933\" y=\"-445.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">best_model</text>\n", |
| "<text text-anchor=\"start\" x=\"1942\" y=\"-417.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">XGBModel</text>\n", |
| "</g>\n", |
| "<!-- base_model->best_model -->\n", |
| "<g id=\"edge11\" class=\"edge\">\n", |
| "<title>base_model->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M885.16,-571.21C957.34,-571.01 1085.94,-570.79 1196,-571.41\"/>\n", |
| "</g>\n", |
| "<!-- hyperparameter_search -->\n", |
| "<g id=\"node14\" class=\"node\">\n", |
| "<title>hyperparameter_search</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1290,-357.41C1290,-357.41 1104,-357.41 1104,-357.41 1098,-357.41 1092,-351.41 1092,-345.41 1092,-345.41 1092,-305.41 1092,-305.41 1092,-299.41 1098,-293.41 1104,-293.41 1104,-293.41 1290,-293.41 1290,-293.41 1296,-293.41 1302,-299.41 1302,-305.41 1302,-305.41 1302,-345.41 1302,-345.41 1302,-351.41 1296,-357.41 1290,-357.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1103\" y=\"-336.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">hyperparameter_search</text>\n", |
| "<text text-anchor=\"start\" x=\"1184\" y=\"-308.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- base_model->hyperparameter_search -->\n", |
| "<g id=\"edge17\" class=\"edge\">\n", |
| "<title>base_model->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M885.23,-573.79C935.83,-573.22 1009.91,-565.51 1063,-530.41 1122.59,-491.01 1161.34,-414.74 1180.87,-367.15\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1184.17,-368.33 1184.64,-357.75 1177.67,-365.73 1184.17,-368.33\"/>\n", |
| "</g>\n", |
| "<!-- study -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>study</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M848.5,-439.41C848.5,-439.41 806.5,-439.41 806.5,-439.41 800.5,-439.41 794.5,-433.41 794.5,-427.41 794.5,-427.41 794.5,-387.41 794.5,-387.41 794.5,-381.41 800.5,-375.41 806.5,-375.41 806.5,-375.41 848.5,-375.41 848.5,-375.41 854.5,-375.41 860.5,-381.41 860.5,-387.41 860.5,-387.41 860.5,-427.41 860.5,-427.41 860.5,-433.41 854.5,-439.41 848.5,-439.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"805.5\" y=\"-418.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">study</text>\n", |
| "<text text-anchor=\"start\" x=\"807\" y=\"-390.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Study</text>\n", |
| "</g>\n", |
| "<!-- study->hyperparameter_search -->\n", |
| "<g id=\"edge21\" class=\"edge\">\n", |
| "<title>study->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M860.62,-403.1C906.04,-396.71 991.47,-383.61 1063,-366.41 1070.85,-364.52 1078.92,-362.42 1087.01,-360.21\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1088.19,-363.52 1096.88,-357.46 1086.31,-356.77 1088.19,-363.52\"/>\n", |
| "</g>\n", |
| "<!-- task -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>task</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"77,-388.41 29,-388.41 29,-338.41 83,-338.41 83,-382.41 77,-388.41\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"77,-388.41 77,-382.41 \"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"83,-382.41 77,-382.41 \"/>\n", |
| "<text text-anchor=\"start\" x=\"39\" y=\"-374.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">task</text>\n", |
| "<text text-anchor=\"start\" x=\"42\" y=\"-346.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Any</text>\n", |
| "</g>\n", |
| "<!-- scorer -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>scorer</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M80,-320.41C80,-320.41 32,-320.41 32,-320.41 26,-320.41 20,-314.41 20,-308.41 20,-308.41 20,-268.41 20,-268.41 20,-262.41 26,-256.41 32,-256.41 32,-256.41 80,-256.41 80,-256.41 86,-256.41 92,-262.41 92,-268.41 92,-268.41 92,-308.41 92,-308.41 92,-314.41 86,-320.41 80,-320.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"31\" y=\"-299.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scorer</text>\n", |
| "<text text-anchor=\"start\" x=\"43\" y=\"-271.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- higher_is_better -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>higher_is_better</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M406.5,-547.41C406.5,-547.41 281.5,-547.41 281.5,-547.41 275.5,-547.41 269.5,-541.41 269.5,-535.41 269.5,-535.41 269.5,-495.41 269.5,-495.41 269.5,-489.41 275.5,-483.41 281.5,-483.41 281.5,-483.41 406.5,-483.41 406.5,-483.41 412.5,-483.41 418.5,-489.41 418.5,-495.41 418.5,-495.41 418.5,-535.41 418.5,-535.41 418.5,-541.41 412.5,-547.41 406.5,-547.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"280.5\" y=\"-526.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">higher_is_better</text>\n", |
| "<text text-anchor=\"start\" x=\"329\" y=\"-498.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">bool</text>\n", |
| "</g>\n", |
| "<!-- scorer->higher_is_better -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>scorer->higher_is_better</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M92,-313.67C96.66,-318.42 100.89,-323.7 104,-329.41 135.14,-386.59 79.93,-427.4 125,-474.41 159.01,-509.88 213.27,-520.17 259.15,-521.58\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"259.35,-525.08 269.41,-521.75 259.47,-518.08 259.35,-525.08\"/>\n", |
| "</g>\n", |
| "<!-- scoring_func -->\n", |
| "<g id=\"node13\" class=\"node\">\n", |
| "<title>scoring_func</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M875.5,-87.41C875.5,-87.41 779.5,-87.41 779.5,-87.41 773.5,-87.41 767.5,-81.41 767.5,-75.41 767.5,-75.41 767.5,-35.41 767.5,-35.41 767.5,-29.41 773.5,-23.41 779.5,-23.41 779.5,-23.41 875.5,-23.41 875.5,-23.41 881.5,-23.41 887.5,-29.41 887.5,-35.41 887.5,-35.41 887.5,-75.41 887.5,-75.41 887.5,-81.41 881.5,-87.41 875.5,-87.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"778.5\" y=\"-66.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scoring_func</text>\n", |
| "<text text-anchor=\"start\" x=\"798.5\" y=\"-38.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">function</text>\n", |
| "</g>\n", |
| "<!-- scorer->scoring_func -->\n", |
| "<g id=\"edge15\" class=\"edge\">\n", |
| "<title>scorer->scoring_func</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M58.01,-256.15C61.32,-211.87 74.45,-133.4 125,-95.41 314.56,47.06 620.71,-3.77 757.09,-36.49\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"756.46,-39.93 767.01,-38.9 758.12,-33.13 756.46,-39.93\"/>\n", |
| "</g>\n", |
| "<!-- model_config -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>model_config</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M878,-521.41C878,-521.41 777,-521.41 777,-521.41 771,-521.41 765,-515.41 765,-509.41 765,-509.41 765,-469.41 765,-469.41 765,-463.41 771,-457.41 777,-457.41 777,-457.41 878,-457.41 878,-457.41 884,-457.41 890,-463.41 890,-469.41 890,-469.41 890,-509.41 890,-509.41 890,-515.41 884,-521.41 878,-521.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"776\" y=\"-500.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">model_config</text>\n", |
| "<text text-anchor=\"start\" x=\"814.5\" y=\"-472.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- model_config->best_model -->\n", |
| "<g id=\"edge12\" class=\"edge\">\n", |
| "<title>model_config->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M890.16,-498.02C937.71,-505.18 1004.94,-516.44 1063,-530.41 1123.14,-544.87 1134.24,-567.99 1196,-571.41\"/>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1198,-571.41C1466.19,-586.24 1779.34,-498.42 1912.17,-456.16\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1913.34,-459.46 1921.8,-453.08 1911.21,-452.79 1913.34,-459.46\"/>\n", |
| "</g>\n", |
| "<!-- model_config->hyperparameter_search -->\n", |
| "<g id=\"edge18\" class=\"edge\">\n", |
| "<title>model_config->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M890.22,-488.1C939.31,-484.88 1008.55,-475.36 1063,-448.41 1102.93,-428.64 1139.4,-393.07 1164.01,-365.26\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1166.71,-367.48 1170.63,-357.64 1161.43,-362.89 1166.71,-367.48\"/>\n", |
| "</g>\n", |
| "<!-- best_hyperparameters -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>best_hyperparameters</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1626,-471.41C1626,-471.41 1450,-471.41 1450,-471.41 1444,-471.41 1438,-465.41 1438,-459.41 1438,-459.41 1438,-419.41 1438,-419.41 1438,-413.41 1444,-407.41 1450,-407.41 1450,-407.41 1626,-407.41 1626,-407.41 1632,-407.41 1638,-413.41 1638,-419.41 1638,-419.41 1638,-459.41 1638,-459.41 1638,-465.41 1632,-471.41 1626,-471.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1449\" y=\"-450.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">best_hyperparameters</text>\n", |
| "<text text-anchor=\"start\" x=\"1525\" y=\"-422.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- best_hyperparameters->best_model -->\n", |
| "<g id=\"edge13\" class=\"edge\">\n", |
| "<title>best_hyperparameters->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1638.12,-438.28C1721.82,-437.32 1839.29,-435.98 1911.66,-435.15\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1912.04,-438.65 1922,-435.03 1911.96,-431.65 1912.04,-438.65\"/>\n", |
| "</g>\n", |
| "<!-- higher_is_better->study -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>higher_is_better->study</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M418.57,-511.51C462.38,-506.83 518.02,-496.6 563,-474.41 578.52,-466.75 576.43,-455.97 592,-448.41 654.35,-418.14 735.64,-409.93 784.3,-407.87\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"784.59,-411.36 794.46,-407.51 784.35,-404.37 784.59,-411.36\"/>\n", |
| "</g>\n", |
| "<!-- cross_validation_folds -->\n", |
| "<g id=\"node9\" class=\"node\">\n", |
| "<title>cross_validation_folds</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M912.5,-357.41C912.5,-357.41 742.5,-357.41 742.5,-357.41 736.5,-357.41 730.5,-351.41 730.5,-345.41 730.5,-345.41 730.5,-305.41 730.5,-305.41 730.5,-299.41 736.5,-293.41 742.5,-293.41 742.5,-293.41 912.5,-293.41 912.5,-293.41 918.5,-293.41 924.5,-299.41 924.5,-305.41 924.5,-305.41 924.5,-345.41 924.5,-345.41 924.5,-351.41 918.5,-357.41 912.5,-357.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"741.5\" y=\"-336.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">cross_validation_folds</text>\n", |
| "<text text-anchor=\"start\" x=\"793\" y=\"-308.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Sequence</text>\n", |
| "</g>\n", |
| "<!-- cross_validation_folds->hyperparameter_search -->\n", |
| "<g id=\"edge16\" class=\"edge\">\n", |
| "<title>cross_validation_folds->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M924.58,-325.41C972.63,-325.41 1031.12,-325.41 1081.49,-325.41\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1081.78,-328.91 1091.78,-325.41 1081.78,-321.91 1081.78,-328.91\"/>\n", |
| "</g>\n", |
| "<!-- optuna_distributions -->\n", |
| "<g id=\"node10\" class=\"node\">\n", |
| "<title>optuna_distributions</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M907.5,-275.41C907.5,-275.41 747.5,-275.41 747.5,-275.41 741.5,-275.41 735.5,-269.41 735.5,-263.41 735.5,-263.41 735.5,-223.41 735.5,-223.41 735.5,-217.41 741.5,-211.41 747.5,-211.41 747.5,-211.41 907.5,-211.41 907.5,-211.41 913.5,-211.41 919.5,-217.41 919.5,-223.41 919.5,-223.41 919.5,-263.41 919.5,-263.41 919.5,-269.41 913.5,-275.41 907.5,-275.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"746.5\" y=\"-254.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">optuna_distributions</text>\n", |
| "<text text-anchor=\"start\" x=\"814.5\" y=\"-226.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n", |
| "</g>\n", |
| "<!-- optuna_distributions->hyperparameter_search -->\n", |
| "<g id=\"edge20\" class=\"edge\">\n", |
| "<title>optuna_distributions->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M919.53,-256.59C963.19,-263.6 1016.09,-273.12 1063,-284.41 1070.85,-286.3 1078.92,-288.39 1087.01,-290.6\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1086.31,-294.04 1096.88,-293.36 1088.19,-287.3 1086.31,-294.04\"/>\n", |
| "</g>\n", |
| "<!-- study_results_df -->\n", |
| "<g id=\"node11\" class=\"node\">\n", |
| "<title>study_results_df</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M2040.5,-305.41C2040.5,-305.41 1915.5,-305.41 1915.5,-305.41 1909.5,-305.41 1903.5,-299.41 1903.5,-293.41 1903.5,-293.41 1903.5,-253.41 1903.5,-253.41 1903.5,-247.41 1909.5,-241.41 1915.5,-241.41 1915.5,-241.41 2040.5,-241.41 2040.5,-241.41 2046.5,-241.41 2052.5,-247.41 2052.5,-253.41 2052.5,-253.41 2052.5,-293.41 2052.5,-293.41 2052.5,-299.41 2046.5,-305.41 2040.5,-305.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1914.5\" y=\"-284.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">study_results_df</text>\n", |
| "<text text-anchor=\"start\" x=\"1939.5\" y=\"-256.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n", |
| "</g>\n", |
| "<!-- best_model->y_test_pred -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>best_model->y_test_pred</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M2034.03,-424.53C2114.1,-410.16 2262.46,-383.52 2348.03,-368.16\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"2348.66,-371.6 2357.89,-366.39 2347.43,-364.71 2348.66,-371.6\"/>\n", |
| "</g>\n", |
| "<!-- scoring_func->hyperparameter_search -->\n", |
| "<g id=\"edge19\" class=\"edge\">\n", |
| "<title>scoring_func->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M887.65,-52.42C938.46,-52.76 1011.47,-60.48 1063,-96.41 1097.1,-120.18 1149.38,-224.63 1177.15,-283.91\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1174,-285.43 1181.39,-293.02 1180.34,-282.48 1174,-285.43\"/>\n", |
| "</g>\n", |
| "<!-- scoring_func->test_score -->\n", |
| "<g id=\"edge25\" class=\"edge\">\n", |
| "<title>scoring_func->test_score</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M887.64,-62.83C1009.1,-77.42 1295.55,-108.41 1537,-108.41 1537,-108.41 1537,-108.41 1979,-108.41 2222.43,-108.41 2512.23,-108.41 2636.48,-108.41\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"2636.56,-111.91 2646.56,-108.41 2636.56,-104.91 2636.56,-111.91\"/>\n", |
| "</g>\n", |
| "<!-- hyperparameter_search->best_hyperparameters -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>hyperparameter_search->best_hyperparameters</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1247.75,-357.46C1272.03,-371.85 1302.17,-387.91 1331,-398.41 1361.61,-409.55 1395.94,-417.91 1427.67,-424.07\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1427.39,-427.58 1437.87,-426 1428.69,-420.7 1427.39,-427.58\"/>\n", |
| "</g>\n", |
| "<!-- study_results -->\n", |
| "<g id=\"node15\" class=\"node\">\n", |
| "<title>study_results</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1589.5,-305.41C1589.5,-305.41 1486.5,-305.41 1486.5,-305.41 1480.5,-305.41 1474.5,-299.41 1474.5,-293.41 1474.5,-293.41 1474.5,-253.41 1474.5,-253.41 1474.5,-247.41 1480.5,-241.41 1486.5,-241.41 1486.5,-241.41 1589.5,-241.41 1589.5,-241.41 1595.5,-241.41 1601.5,-247.41 1601.5,-253.41 1601.5,-253.41 1601.5,-293.41 1601.5,-293.41 1601.5,-299.41 1595.5,-305.41 1589.5,-305.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1485.5\" y=\"-284.21\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">study_results</text>\n", |
| "<text text-anchor=\"start\" x=\"1517.5\" y=\"-256.21\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Study</text>\n", |
| "</g>\n", |
| "<!-- hyperparameter_search->study_results -->\n", |
| "<g id=\"edge23\" class=\"edge\">\n", |
| "<title>hyperparameter_search->study_results</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1302.01,-309.45C1354.42,-301.41 1416.69,-291.86 1463.96,-284.61\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1464.76,-288.03 1474.11,-283.05 1463.7,-281.11 1464.76,-288.03\"/>\n", |
| "</g>\n", |
| "<!-- study_results->study_results_df -->\n", |
| "<g id=\"edge10\" class=\"edge\">\n", |
| "<title>study_results->study_results_df</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1601.75,-273.41C1678.36,-273.41 1808.17,-273.41 1893.25,-273.41\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1893.39,-276.91 1903.39,-273.41 1893.39,-269.91 1893.39,-276.91\"/>\n", |
| "</g>\n", |
| "<!-- _y_test_pred_inputs -->\n", |
| "<g id=\"node17\" class=\"node\">\n", |
| "<title>_y_test_pred_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"2182,-378.91 1774,-378.91 1774,-333.91 2182,-333.91 2182,-378.91\"/>\n", |
| "<text text-anchor=\"start\" x=\"1789\" y=\"-352.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_test</text>\n", |
| "<text text-anchor=\"start\" x=\"1838\" y=\"-352.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "</g>\n", |
| "<!-- _y_test_pred_inputs->y_test_pred -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>_y_test_pred_inputs->y_test_pred</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M2182.26,-356.41C2242.09,-356.41 2303.34,-356.41 2347.76,-356.41\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"2347.83,-359.91 2357.83,-356.41 2347.83,-352.91 2347.83,-359.91\"/>\n", |
| "</g>\n", |
| "<!-- _study_inputs -->\n", |
| "<g id=\"node18\" class=\"node\">\n", |
| "<title>_study_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"441,-464.91 247,-464.91 247,-335.91 441,-335.91 441,-464.91\"/>\n", |
| "<text text-anchor=\"start\" x=\"283\" y=\"-438.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sampler</text>\n", |
| "<text text-anchor=\"start\" x=\"367\" y=\"-438.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "<text text-anchor=\"start\" x=\"287.5\" y=\"-417.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">pruner</text>\n", |
| "<text text-anchor=\"start\" x=\"367\" y=\"-417.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "<text text-anchor=\"start\" x=\"262\" y=\"-396.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">study_storage</text>\n", |
| "<text text-anchor=\"start\" x=\"367\" y=\"-396.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "<text text-anchor=\"start\" x=\"265\" y=\"-375.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">load_if_exists</text>\n", |
| "<text text-anchor=\"start\" x=\"381.5\" y=\"-375.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n", |
| "<text text-anchor=\"start\" x=\"269\" y=\"-354.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">study_name</text>\n", |
| "<text text-anchor=\"start\" x=\"367\" y=\"-354.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "</g>\n", |
| "<!-- _study_inputs->study -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>_study_inputs->study</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M441.28,-401.81C545.86,-403.33 708.17,-405.69 784.33,-406.79\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"784.44,-410.3 794.49,-406.94 784.54,-403.3 784.44,-410.3\"/>\n", |
| "</g>\n", |
| "<!-- _model_config_inputs -->\n", |
| "<g id=\"node19\" class=\"node\">\n", |
| "<title>_model_config_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"470,-631.41 218,-631.41 218,-565.41 470,-565.41 470,-631.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"294.5\" y=\"-605.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n", |
| "<text text-anchor=\"start\" x=\"416\" y=\"-605.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "<text text-anchor=\"start\" x=\"233\" y=\"-584.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">model_config_override</text>\n", |
| "<text text-anchor=\"start\" x=\"396\" y=\"-584.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "</g>\n", |
| "<!-- _model_config_inputs->model_config -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>_model_config_inputs->model_config</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M470.14,-586.21C501.55,-579.98 534.43,-570.58 563,-556.41 578.51,-548.71 576.43,-537.97 592,-530.41 643.02,-505.63 706.73,-495.64 754.86,-491.7\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"755.18,-495.18 764.89,-490.94 754.65,-488.2 755.18,-495.18\"/>\n", |
| "</g>\n", |
| "<!-- _cross_validation_folds_inputs -->\n", |
| "<g id=\"node20\" class=\"node\">\n", |
| "<title>_cross_validation_folds_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"563,-317.41 125,-317.41 125,-167.41 563,-167.41 563,-317.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"152\" y=\"-291.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_train</text>\n", |
| "<text text-anchor=\"start\" x=\"219\" y=\"-291.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"152\" y=\"-270.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">stratify</text>\n", |
| "<text text-anchor=\"start\" x=\"368.5\" y=\"-270.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n", |
| "<text text-anchor=\"start\" x=\"140\" y=\"-249.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">n_cv_folds</text>\n", |
| "<text text-anchor=\"start\" x=\"374\" y=\"-249.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "<text text-anchor=\"start\" x=\"152.5\" y=\"-228.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_train</text>\n", |
| "<text text-anchor=\"start\" x=\"219\" y=\"-228.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"152.5\" y=\"-207.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">shuffle</text>\n", |
| "<text text-anchor=\"start\" x=\"368.5\" y=\"-207.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n", |
| "<text text-anchor=\"start\" x=\"159.5\" y=\"-186.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n", |
| "<text text-anchor=\"start\" x=\"374\" y=\"-186.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "</g>\n", |
| "<!-- _cross_validation_folds_inputs->cross_validation_folds -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>_cross_validation_folds_inputs->cross_validation_folds</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M563.21,-280.02C617.8,-289.43 673.95,-299.11 720.3,-307.1\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"719.9,-310.58 730.34,-308.83 721.08,-303.68 719.9,-310.58\"/>\n", |
| "</g>\n", |
| "<!-- _optuna_distributions_inputs -->\n", |
| "<g id=\"node21\" class=\"node\">\n", |
| "<title>_optuna_distributions_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"495,-149.91 193,-149.91 193,-104.91 495,-104.91 495,-149.91\"/>\n", |
| "<text text-anchor=\"start\" x=\"208\" y=\"-123.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">optuna_distributions_override</text>\n", |
| "<text text-anchor=\"start\" x=\"421\" y=\"-123.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n", |
| "</g>\n", |
| "<!-- _optuna_distributions_inputs->optuna_distributions -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>_optuna_distributions_inputs->optuna_distributions</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M495.08,-131.91C519.02,-137.26 542.59,-145.63 563,-158.41 582.54,-170.64 572.68,-188.83 592,-201.41 631.14,-226.9 681.4,-238.08 725.34,-242.6\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"725.2,-246.11 735.48,-243.54 725.85,-239.14 725.2,-246.11\"/>\n", |
| "</g>\n", |
| "<!-- _best_model_inputs -->\n", |
| "<g id=\"node22\" class=\"node\">\n", |
| "<title>_best_model_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1745,-389.41 1331,-389.41 1331,-323.41 1745,-323.41 1745,-389.41\"/>\n", |
| "<text text-anchor=\"start\" x=\"1346.5\" y=\"-363.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_train</text>\n", |
| "<text text-anchor=\"start\" x=\"1401\" y=\"-363.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"1346\" y=\"-342.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_train</text>\n", |
| "<text text-anchor=\"start\" x=\"1401\" y=\"-342.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "</g>\n", |
| "<!-- _best_model_inputs->best_model -->\n", |
| "<g id=\"edge14\" class=\"edge\">\n", |
| "<title>_best_model_inputs->best_model</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1724.43,-389.43C1791.23,-401.32 1862.29,-413.98 1911.96,-422.82\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1911.46,-426.29 1921.92,-424.6 1912.68,-419.4 1911.46,-426.29\"/>\n", |
| "</g>\n", |
| "<!-- _hyperparameter_search_inputs -->\n", |
| "<g id=\"node23\" class=\"node\">\n", |
| "<title>_hyperparameter_search_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1063,-192.91 592,-192.91 592,-105.91 1063,-105.91 1063,-192.91\"/>\n", |
| "<text text-anchor=\"start\" x=\"636.5\" y=\"-166.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_train</text>\n", |
| "<text text-anchor=\"start\" x=\"719.5\" y=\"-166.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"636\" y=\"-145.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_train</text>\n", |
| "<text text-anchor=\"start\" x=\"719.5\" y=\"-145.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "<text text-anchor=\"start\" x=\"607.5\" y=\"-124.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">n_optuna_trials</text>\n", |
| "<text text-anchor=\"start\" x=\"874.5\" y=\"-124.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n", |
| "</g>\n", |
| "<!-- _hyperparameter_search_inputs->hyperparameter_search -->\n", |
| "<g id=\"edge22\" class=\"edge\">\n", |
| "<title>_hyperparameter_search_inputs->hyperparameter_search</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M1043.32,-192.97C1050.05,-195.91 1056.63,-199.05 1063,-202.41 1102.31,-223.09 1138.72,-258.38 1163.47,-285.87\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"1160.89,-288.23 1170.14,-293.4 1166.13,-283.59 1160.89,-288.23\"/>\n", |
| "</g>\n", |
| "<!-- _test_score_inputs -->\n", |
| "<g id=\"node24\" class=\"node\">\n", |
| "<title>_test_score_inputs</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"2618,-79.91 2211,-79.91 2211,-34.91 2618,-34.91 2618,-79.91\"/>\n", |
| "<text text-anchor=\"start\" x=\"2226.5\" y=\"-53.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_test</text>\n", |
| "<text text-anchor=\"start\" x=\"2274.5\" y=\"-53.21\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n", |
| "</g>\n", |
| "<!-- _test_score_inputs->test_score -->\n", |
| "<g id=\"edge26\" class=\"edge\">\n", |
| "<title>_test_score_inputs->test_score</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M2540.23,-79.92C2573.67,-85.96 2608.36,-92.22 2636.66,-97.33\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"2636.33,-100.83 2646.79,-99.16 2637.57,-93.94 2636.33,-100.83\"/>\n", |
| "</g>\n", |
| "<!-- config -->\n", |
| "<g id=\"node25\" class=\"node\">\n", |
| "<title>config</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" points=\"79.5,-553.41 26.5,-553.41 26.5,-517.41 85.5,-517.41 85.5,-547.41 79.5,-553.41\"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"79.5,-553.41 79.5,-547.41 \"/>\n", |
| "<polyline fill=\"none\" stroke=\"black\" points=\"85.5,-547.41 79.5,-547.41 \"/>\n", |
| "<text text-anchor=\"middle\" x=\"56\" y=\"-531.71\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">config</text>\n", |
| "</g>\n", |
| "<!-- input -->\n", |
| "<g id=\"node26\" class=\"node\">\n", |
| "<title>input</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"85.5,-498.91 26.5,-498.91 26.5,-461.91 85.5,-461.91 85.5,-498.91\"/>\n", |
| "<text text-anchor=\"middle\" x=\"56\" y=\"-476.71\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n", |
| "</g>\n", |
| "<!-- function -->\n", |
| "<g id=\"node27\" class=\"node\">\n", |
| "<title>function</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M84,-443.91C84,-443.91 28,-443.91 28,-443.91 22,-443.91 16,-437.91 16,-431.91 16,-431.91 16,-418.91 16,-418.91 16,-412.91 22,-406.91 28,-406.91 28,-406.91 84,-406.91 84,-406.91 90,-406.91 96,-412.91 96,-418.91 96,-418.91 96,-431.91 96,-431.91 96,-437.91 90,-443.91 84,-443.91\"/>\n", |
| "<text text-anchor=\"middle\" x=\"56\" y=\"-421.71\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<graphviz.graphs.Digraph at 0x7f6a143b3010>" |
| ] |
| }, |
| "execution_count": 13, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "dr = (\n", |
| " driver.Builder()\n", |
| " .with_modules(xgboost_optuna)\n", |
| " .with_config(dict(task=\"classification\"))\n", |
| " .build()\n", |
| ")\n", |
| "\n", |
| "dr.display_all_functions()" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "venv", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.10.9" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 2 |
| } |