| |
| <!DOCTYPE html> |
| |
| <html lang="en"> |
| <head> |
| <meta charset="utf-8" /> |
| <title>pyspark.ml.tuning — PySpark 3.5.3 documentation</title> |
| |
| <link href="../../../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet"> |
| <link href="../../../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet"> |
| |
| |
| <link rel="stylesheet" |
| href="../../../_static/vendor/fontawesome/5.13.0/css/all.min.css"> |
| <link rel="preload" as="font" type="font/woff2" crossorigin |
| href="../../../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2"> |
| <link rel="preload" as="font" type="font/woff2" crossorigin |
| href="../../../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2"> |
| |
| |
| |
| |
| |
| <link rel="stylesheet" href="../../../_static/styles/pydata-sphinx-theme.css" type="text/css" /> |
| <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" /> |
| <link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" /> |
| <link rel="stylesheet" type="text/css" href="../../../_static/css/pyspark.css" /> |
| |
| <link rel="preload" as="script" href="../../../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"> |
| |
| <script id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script> |
| <script src="../../../_static/jquery.js"></script> |
| <script src="../../../_static/underscore.js"></script> |
| <script src="../../../_static/doctools.js"></script> |
| <script src="../../../_static/language_data.js"></script> |
| <script src="../../../_static/clipboard.min.js"></script> |
| <script src="../../../_static/copybutton.js"></script> |
| <script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script> |
| <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script> |
| <script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script> |
| <link rel="canonical" href="https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/tuning.html" /> |
| <link rel="search" title="Search" href="../../../search.html" /> |
| <meta name="viewport" content="width=device-width, initial-scale=1" /> |
| <meta name="docsearch:language" content="en"> |
| |
| |
| <!-- Google Analytics --> |
| |
| </head> |
| <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80"> |
| |
| <div class="container-fluid" id="banner"></div> |
| |
| |
| <nav class="navbar navbar-light navbar-expand-lg bg-light fixed-top bd-navbar" id="navbar-main"><div class="container-xl"> |
| |
| <div id="navbar-start"> |
| |
| |
| |
| <a class="navbar-brand" href="../../../index.html"> |
| <img src="../../../_static/spark-logo-reverse.png" class="logo" alt="logo"> |
| </a> |
| |
| |
| |
| </div> |
| |
| <button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbar-collapsible" aria-controls="navbar-collapsible" aria-expanded="false" aria-label="Toggle navigation"> |
| <span class="navbar-toggler-icon"></span> |
| </button> |
| |
| |
| <div id="navbar-collapsible" class="col-lg-9 collapse navbar-collapse"> |
| <div id="navbar-center" class="mr-auto"> |
| |
| <div class="navbar-center-item"> |
| <ul id="navbar-main-elements" class="navbar-nav"> |
| <li class="toctree-l1 nav-item"> |
| <a class="reference internal nav-link" href="../../../index.html"> |
| Overview |
| </a> |
| </li> |
| |
| <li class="toctree-l1 nav-item"> |
| <a class="reference internal nav-link" href="../../../getting_started/index.html"> |
| Getting Started |
| </a> |
| </li> |
| |
| <li class="toctree-l1 nav-item"> |
| <a class="reference internal nav-link" href="../../../user_guide/index.html"> |
| User Guides |
| </a> |
| </li> |
| |
| <li class="toctree-l1 nav-item"> |
| <a class="reference internal nav-link" href="../../../reference/index.html"> |
| API Reference |
| </a> |
| </li> |
| |
| <li class="toctree-l1 nav-item"> |
| <a class="reference internal nav-link" href="../../../development/index.html"> |
| Development |
| </a> |
| </li> |
| |
| <li class="toctree-l1 nav-item"> |
| <a class="reference internal nav-link" href="../../../migration_guide/index.html"> |
| Migration Guides |
| </a> |
| </li> |
| |
| |
| </ul> |
| </div> |
| |
| </div> |
| |
| <div id="navbar-end"> |
| |
| <div class="navbar-end-item"> |
| <!-- |
| Licensed to the Apache Software Foundation (ASF) under one or more |
| contributor license agreements. See the NOTICE file distributed with |
| this work for additional information regarding copyright ownership. |
| The ASF licenses this file to You under the Apache License, Version 2.0 |
| (the "License"); you may not use this file except in compliance with |
| the License. You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| --> |
| |
| <div id="version-button" class="dropdown"> |
| <button type="button" class="btn btn-secondary btn-sm navbar-btn dropdown-toggle" id="version_switcher_button" data-toggle="dropdown"> |
| 3.5.3 |
| <span class="caret"></span> |
| </button> |
| <div id="version_switcher" class="dropdown-menu list-group-flush py-0" aria-labelledby="version_switcher_button"> |
| <!-- dropdown will be populated by javascript on page load --> |
| </div> |
| </div> |
| |
| <script type="text/javascript"> |
| // Function to construct the target URL from the JSON components |
| function buildURL(entry) { |
| var template = "https://spark.apache.org/docs/{version}/api/python/index.html"; // supplied by jinja |
| template = template.replace("{version}", entry.version); |
| return template; |
| } |
| |
| // Function to check if corresponding page path exists in other version of docs |
| // and, if so, go there instead of the homepage of the other docs version |
| function checkPageExistsAndRedirect(event) { |
| const currentFilePath = "_modules/pyspark/ml/tuning.html", |
| otherDocsHomepage = event.target.getAttribute("href"); |
| let tryUrl = `${otherDocsHomepage}${currentFilePath}`; |
| $.ajax({ |
| type: 'HEAD', |
| url: tryUrl, |
| // if the page exists, go there |
| success: function() { |
| location.href = tryUrl; |
| } |
| }).fail(function() { |
| location.href = otherDocsHomepage; |
| }); |
| return false; |
| } |
| |
| // Function to populate the version switcher |
| (function () { |
| // get JSON config |
| $.getJSON("https://spark.apache.org/static/versions.json", function(data, textStatus, jqXHR) { |
| // create the nodes first (before AJAX calls) to ensure the order is |
| // correct (for now, links will go to doc version homepage) |
| $.each(data, function(index, entry) { |
| // if no custom name specified (e.g., "latest"), use version string |
| if (!("name" in entry)) { |
| entry.name = entry.version; |
| } |
| // construct the appropriate URL, and add it to the dropdown |
| entry.url = buildURL(entry); |
| const node = document.createElement("a"); |
| node.setAttribute("class", "list-group-item list-group-item-action py-1"); |
| node.setAttribute("href", `${entry.url}`); |
| node.textContent = `${entry.name}`; |
| node.onclick = checkPageExistsAndRedirect; |
| $("#version_switcher").append(node); |
| }); |
| }); |
| })(); |
| </script> |
| </div> |
| |
| </div> |
| </div> |
| </div> |
| </nav> |
| |
| |
| <div class="container-xl"> |
| <div class="row"> |
| |
| |
| <!-- Only show if we have sidebars configured, else just a small margin --> |
| <div class="col-12 col-md-3 bd-sidebar"> |
| <div class="sidebar-start-items"><form class="bd-search d-flex align-items-center" action="../../../search.html" method="get"> |
| <i class="icon fas fa-search"></i> |
| <input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" > |
| </form><nav class="bd-links" id="bd-docs-nav" aria-label="Main navigation"> |
| <div class="bd-toc-item active"> |
| |
| </div> |
| </nav> |
| </div> |
| <div class="sidebar-end-items"> |
| </div> |
| </div> |
| |
| |
| |
| |
| <div class="d-none d-xl-block col-xl-2 bd-toc"> |
| |
| </div> |
| |
| |
| |
| |
| |
| |
| <main class="col-12 col-md-9 col-xl-7 py-md-5 pl-md-5 pr-md-4 bd-content" role="main"> |
| |
| <div> |
| |
| <h1>Source code for pyspark.ml.tuning</h1><div class="highlight"><pre> |
| <span></span><span class="c1">#</span> |
| <span class="c1"># Licensed to the Apache Software Foundation (ASF) under one or more</span> |
| <span class="c1"># contributor license agreements. See the NOTICE file distributed with</span> |
| <span class="c1"># this work for additional information regarding copyright ownership.</span> |
| <span class="c1"># The ASF licenses this file to You under the Apache License, Version 2.0</span> |
| <span class="c1"># (the "License"); you may not use this file except in compliance with</span> |
| <span class="c1"># the License. You may obtain a copy of the License at</span> |
| <span class="c1">#</span> |
| <span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span> |
| <span class="c1">#</span> |
| <span class="c1"># Unless required by applicable law or agreed to in writing, software</span> |
| <span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span> |
| <span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span> |
| <span class="c1"># See the License for the specific language governing permissions and</span> |
| <span class="c1"># limitations under the License.</span> |
| <span class="c1">#</span> |
| |
| <span class="kn">import</span> <span class="nn">os</span> |
| <span class="kn">import</span> <span class="nn">sys</span> |
| <span class="kn">import</span> <span class="nn">itertools</span> |
| <span class="kn">from</span> <span class="nn">multiprocessing.pool</span> <span class="kn">import</span> <span class="n">ThreadPool</span> |
| |
| <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span> |
| <span class="n">Any</span><span class="p">,</span> |
| <span class="n">Callable</span><span class="p">,</span> |
| <span class="n">Dict</span><span class="p">,</span> |
| <span class="n">Iterable</span><span class="p">,</span> |
| <span class="n">List</span><span class="p">,</span> |
| <span class="n">Optional</span><span class="p">,</span> |
| <span class="n">Sequence</span><span class="p">,</span> |
| <span class="n">Tuple</span><span class="p">,</span> |
| <span class="n">Type</span><span class="p">,</span> |
| <span class="n">Union</span><span class="p">,</span> |
| <span class="n">cast</span><span class="p">,</span> |
| <span class="n">overload</span><span class="p">,</span> |
| <span class="n">TYPE_CHECKING</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> |
| |
| <span class="kn">from</span> <span class="nn">pyspark</span> <span class="kn">import</span> <span class="n">keyword_only</span><span class="p">,</span> <span class="n">since</span><span class="p">,</span> <span class="n">SparkContext</span><span class="p">,</span> <span class="n">inheritable_thread_target</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Estimator</span><span class="p">,</span> <span class="n">Transformer</span><span class="p">,</span> <span class="n">Model</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.common</span> <span class="kn">import</span> <span class="n">inherit_doc</span><span class="p">,</span> <span class="n">_py2java</span><span class="p">,</span> <span class="n">_java2py</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">Evaluator</span><span class="p">,</span> <span class="n">JavaEvaluator</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.param</span> <span class="kn">import</span> <span class="n">Params</span><span class="p">,</span> <span class="n">Param</span><span class="p">,</span> <span class="n">TypeConverters</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.param.shared</span> <span class="kn">import</span> <span class="n">HasCollectSubModels</span><span class="p">,</span> <span class="n">HasParallelism</span><span class="p">,</span> <span class="n">HasSeed</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.util</span> <span class="kn">import</span> <span class="p">(</span> |
| <span class="n">DefaultParamsReader</span><span class="p">,</span> |
| <span class="n">DefaultParamsWriter</span><span class="p">,</span> |
| <span class="n">MetaAlgorithmReadWrite</span><span class="p">,</span> |
| <span class="n">MLReadable</span><span class="p">,</span> |
| <span class="n">MLReader</span><span class="p">,</span> |
| <span class="n">MLWritable</span><span class="p">,</span> |
| <span class="n">MLWriter</span><span class="p">,</span> |
| <span class="n">JavaMLReader</span><span class="p">,</span> |
| <span class="n">JavaMLWriter</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.wrapper</span> <span class="kn">import</span> <span class="n">JavaParams</span><span class="p">,</span> <span class="n">JavaEstimator</span><span class="p">,</span> <span class="n">JavaWrapper</span> |
| <span class="kn">from</span> <span class="nn">pyspark.sql.functions</span> <span class="kn">import</span> <span class="n">col</span><span class="p">,</span> <span class="n">lit</span><span class="p">,</span> <span class="n">rand</span><span class="p">,</span> <span class="n">UserDefinedFunction</span> |
| <span class="kn">from</span> <span class="nn">pyspark.sql.types</span> <span class="kn">import</span> <span class="n">BooleanType</span> |
| |
| <span class="kn">from</span> <span class="nn">pyspark.sql.dataframe</span> <span class="kn">import</span> <span class="n">DataFrame</span> |
| |
| <span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml._typing</span> <span class="kn">import</span> <span class="n">ParamMap</span> |
| <span class="kn">from</span> <span class="nn">py4j.java_gateway</span> <span class="kn">import</span> <span class="n">JavaObject</span> |
| <span class="kn">from</span> <span class="nn">py4j.java_collections</span> <span class="kn">import</span> <span class="n">JavaArray</span> |
| |
| <span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span> |
| <span class="s2">"ParamGridBuilder"</span><span class="p">,</span> |
| <span class="s2">"CrossValidator"</span><span class="p">,</span> |
| <span class="s2">"CrossValidatorModel"</span><span class="p">,</span> |
| <span class="s2">"TrainValidationSplit"</span><span class="p">,</span> |
| <span class="s2">"TrainValidationSplitModel"</span><span class="p">,</span> |
| <span class="p">]</span> |
| |
| |
| <span class="k">def</span> <span class="nf">_parallelFitTasks</span><span class="p">(</span> |
| <span class="n">est</span><span class="p">:</span> <span class="n">Estimator</span><span class="p">,</span> |
| <span class="n">train</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">,</span> |
| <span class="n">eva</span><span class="p">:</span> <span class="n">Evaluator</span><span class="p">,</span> |
| <span class="n">validation</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">,</span> |
| <span class="n">epm</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">],</span> |
| <span class="n">collectSubModel</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Callable</span><span class="p">[[],</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">Transformer</span><span class="p">]]]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Creates a list of callables which can be called from different threads to fit and evaluate</span> |
| <span class="sd"> an estimator in parallel. Each callable returns an `(index, metric, subModel)` tuple.</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> est : :py:class:`pyspark.ml.base.Estimator`</span> |
| <span class="sd"> The estimator to be fit.</span> |
| <span class="sd"> train : :py:class:`pyspark.sql.DataFrame`</span> |
| <span class="sd"> DataFrame, training data set, used for fitting.</span> |
| <span class="sd"> eva : :py:class:`pyspark.ml.evaluation.Evaluator`</span> |
| <span class="sd"> used to compute `metric`</span> |
| <span class="sd"> validation : :py:class:`pyspark.sql.DataFrame`</span> |
| <span class="sd"> DataFrame, validation data set, used for evaluation.</span> |
| <span class="sd"> epm : :py:class:`collections.abc.Sequence`</span> |
| <span class="sd"> Sequence of ParamMap, params maps to be used during fitting & evaluation.</span> |
| <span class="sd"> collectSubModel : bool</span> |
| <span class="sd"> Whether to collect sub model.</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> tuple</span> |
| <span class="sd"> (int, float, subModel), an index into `epm` and the associated metric value.</span> |
| <span class="sd"> """</span> |
| <span class="n">modelIter</span> <span class="o">=</span> <span class="n">est</span><span class="o">.</span><span class="n">fitMultiple</span><span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">epm</span><span class="p">)</span> |
| |
| <span class="k">def</span> <span class="nf">singleTask</span><span class="p">()</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">Transformer</span><span class="p">]:</span> |
| <span class="n">index</span><span class="p">,</span> <span class="n">model</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">modelIter</span><span class="p">)</span> |
| <span class="c1"># TODO: duplicate evaluator to take extra params from input</span> |
| <span class="c1"># Note: Supporting tuning params in evaluator need update method</span> |
| <span class="c1"># `MetaAlgorithmReadWrite.getAllNestedStages`, make it return</span> |
| <span class="c1"># all nested stages and evaluators</span> |
| <span class="n">metric</span> <span class="o">=</span> <span class="n">eva</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">validation</span><span class="p">,</span> <span class="n">epm</span><span class="p">[</span><span class="n">index</span><span class="p">]))</span> |
| <span class="k">return</span> <span class="n">index</span><span class="p">,</span> <span class="n">metric</span><span class="p">,</span> <span class="n">model</span> <span class="k">if</span> <span class="n">collectSubModel</span> <span class="k">else</span> <span class="kc">None</span> |
| |
| <span class="k">return</span> <span class="p">[</span><span class="n">singleTask</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">epm</span><span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="ParamGridBuilder"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.ParamGridBuilder.html#pyspark.ml.tuning.ParamGridBuilder">[docs]</a><span class="k">class</span> <span class="nc">ParamGridBuilder</span><span class="p">:</span> |
| <span class="w"> </span><span class="sa">r</span><span class="sd">"""</span> |
| <span class="sd"> Builder for a param grid used in grid search-based model selection.</span> |
| |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from pyspark.ml.classification import LogisticRegression</span> |
| <span class="sd"> >>> lr = LogisticRegression()</span> |
| <span class="sd"> >>> output = ParamGridBuilder() \</span> |
| <span class="sd"> ... .baseOn({lr.labelCol: 'l'}) \</span> |
| <span class="sd"> ... .baseOn([lr.predictionCol, 'p']) \</span> |
| <span class="sd"> ... .addGrid(lr.regParam, [1.0, 2.0]) \</span> |
| <span class="sd"> ... .addGrid(lr.maxIter, [1, 5]) \</span> |
| <span class="sd"> ... .build()</span> |
| <span class="sd"> >>> expected = [</span> |
| <span class="sd"> ... {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},</span> |
| <span class="sd"> ... {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},</span> |
| <span class="sd"> ... {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},</span> |
| <span class="sd"> ... {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]</span> |
| <span class="sd"> >>> len(output) == len(expected)</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> all([m in expected for m in output])</span> |
| <span class="sd"> True</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_param_grid</span><span class="p">:</span> <span class="s2">"ParamMap"</span> <span class="o">=</span> <span class="p">{}</span> |
| |
| <div class="viewcode-block" id="ParamGridBuilder.addGrid"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.ParamGridBuilder.html#pyspark.ml.tuning.ParamGridBuilder.addGrid">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">addGrid</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Any</span><span class="p">],</span> <span class="n">values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Any</span><span class="p">])</span> <span class="o">-></span> <span class="s2">"ParamGridBuilder"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Adds the given parameter to this grid with the list of values to try for it.</span> |
| |
| <span class="sd"> param must be an instance of Param associated with an instance of Params</span> |
| <span class="sd"> (such as Estimator or Transformer).</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">Param</span><span class="p">):</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_param_grid</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="n">values</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"param must be an instance of Param"</span><span class="p">)</span> |
| |
| <span class="k">return</span> <span class="bp">self</span></div> |
| |
| <span class="nd">@overload</span> |
| <span class="k">def</span> <span class="nf">baseOn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">__args</span><span class="p">:</span> <span class="s2">"ParamMap"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"ParamGridBuilder"</span><span class="p">:</span> |
| <span class="o">...</span> |
| |
| <span class="nd">@overload</span> |
| <span class="k">def</span> <span class="nf">baseOn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Param</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-></span> <span class="s2">"ParamGridBuilder"</span><span class="p">:</span> |
| <span class="o">...</span> |
| |
| <div class="viewcode-block" id="ParamGridBuilder.baseOn"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.ParamGridBuilder.html#pyspark.ml.tuning.ParamGridBuilder.baseOn">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">baseOn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Param</span><span class="p">,</span> <span class="n">Any</span><span class="p">]])</span> <span class="o">-></span> <span class="s2">"ParamGridBuilder"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the given parameters in this grid to fixed values.</span> |
| <span class="sd"> Accepts either a parameter dictionary or a list of (parameter, value) pairs.</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">dict</span><span class="p">):</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">baseOn</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">items</span><span class="p">())</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">for</span> <span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span> <span class="ow">in</span> <span class="n">args</span><span class="p">:</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">addGrid</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="p">[</span><span class="n">value</span><span class="p">])</span> |
| |
| <span class="k">return</span> <span class="bp">self</span></div> |
| |
| <div class="viewcode-block" id="ParamGridBuilder.build"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.ParamGridBuilder.html#pyspark.ml.tuning.ParamGridBuilder.build">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Builds and returns all combinations of parameters specified</span> |
| <span class="sd"> by the param grid.</span> |
| <span class="sd"> """</span> |
| <span class="n">keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_param_grid</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> |
| <span class="n">grid_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_param_grid</span><span class="o">.</span><span class="n">values</span><span class="p">()</span> |
| |
| <span class="k">def</span> <span class="nf">to_key_value_pairs</span><span class="p">(</span> |
| <span class="n">keys</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="n">Param</span><span class="p">],</span> <span class="n">values</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="n">Any</span><span class="p">]</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Param</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span> |
| <span class="k">return</span> <span class="p">[(</span><span class="n">key</span><span class="p">,</span> <span class="n">key</span><span class="o">.</span><span class="n">typeConverter</span><span class="p">(</span><span class="n">value</span><span class="p">))</span> <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">)]</span> |
| |
| <span class="k">return</span> <span class="p">[</span><span class="nb">dict</span><span class="p">(</span><span class="n">to_key_value_pairs</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="n">prod</span><span class="p">))</span> <span class="k">for</span> <span class="n">prod</span> <span class="ow">in</span> <span class="n">itertools</span><span class="o">.</span><span class="n">product</span><span class="p">(</span><span class="o">*</span><span class="n">grid_values</span><span class="p">)]</span></div></div> |
| |
| |
| <span class="k">class</span> <span class="nc">_ValidatorParams</span><span class="p">(</span><span class="n">HasSeed</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Common params for TrainValidationSplit and CrossValidator.</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">estimator</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Estimator</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> <span class="s2">"estimator"</span><span class="p">,</span> <span class="s2">"estimator to be cross-validated"</span> |
| <span class="p">)</span> |
| <span class="n">estimatorParamMaps</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> <span class="s2">"estimatorParamMaps"</span><span class="p">,</span> <span class="s2">"estimator param maps"</span> |
| <span class="p">)</span> |
| <span class="n">evaluator</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Evaluator</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"evaluator"</span><span class="p">,</span> |
| <span class="s2">"evaluator used to select hyper-parameters that maximize the validator metric"</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getEstimator</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Estimator</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of estimator or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimator</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getEstimatorParamMaps</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of estimatorParamMaps or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimatorParamMaps</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getEvaluator</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Evaluator</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of evaluator or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">evaluator</span><span class="p">)</span> |
| |
| <span class="nd">@classmethod</span> |
| <span class="k">def</span> <span class="nf">_from_java_impl</span><span class="p">(</span> |
| <span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">"JavaObject"</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Estimator</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">],</span> <span class="n">Evaluator</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.</span> |
| <span class="sd"> """</span> |
| |
| <span class="c1"># Load information from java_stage to the instance.</span> |
| <span class="n">estimator</span><span class="p">:</span> <span class="n">Estimator</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">())</span> |
| <span class="n">evaluator</span><span class="p">:</span> <span class="n">Evaluator</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">())</span> |
| <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">estimator</span><span class="p">,</span> <span class="n">JavaEstimator</span><span class="p">):</span> |
| <span class="n">epms</span> <span class="o">=</span> <span class="p">[</span> |
| <span class="n">estimator</span><span class="o">.</span><span class="n">_transfer_param_map_from_java</span><span class="p">(</span><span class="n">epm</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">epm</span> <span class="ow">in</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()</span> |
| <span class="p">]</span> |
| <span class="k">elif</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">isMetaEstimator</span><span class="p">(</span><span class="n">estimator</span><span class="p">):</span> |
| <span class="c1"># Meta estimator such as Pipeline, OneVsRest</span> |
| <span class="n">epms</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">meta_estimator_transfer_param_maps_from_java</span><span class="p">(</span> |
| <span class="n">estimator</span><span class="p">,</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()</span> |
| <span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Unsupported estimator used in tuning: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">estimator</span><span class="p">))</span> |
| |
| <span class="k">return</span> <span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> |
| |
| <span class="k">def</span> <span class="nf">_to_java_impl</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="s2">"JavaObject"</span><span class="p">,</span> <span class="s2">"JavaObject"</span><span class="p">,</span> <span class="s2">"JavaObject"</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">gateway</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_gateway</span> |
| <span class="k">assert</span> <span class="n">gateway</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_jvm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| |
| <span class="bp">cls</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_jvm</span><span class="o">.</span><span class="n">org</span><span class="o">.</span><span class="n">apache</span><span class="o">.</span><span class="n">spark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">param</span><span class="o">.</span><span class="n">ParamMap</span> |
| |
| <span class="n">estimator</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">()</span> |
| <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">estimator</span><span class="p">,</span> <span class="n">JavaEstimator</span><span class="p">):</span> |
| <span class="n">java_epms</span> <span class="o">=</span> <span class="n">gateway</span><span class="o">.</span><span class="n">new_array</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()))</span> |
| <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">epm</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()):</span> |
| <span class="n">java_epms</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">estimator</span><span class="o">.</span><span class="n">_transfer_param_map_to_java</span><span class="p">(</span><span class="n">epm</span><span class="p">)</span> |
| <span class="k">elif</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">isMetaEstimator</span><span class="p">(</span><span class="n">estimator</span><span class="p">):</span> |
| <span class="c1"># Meta estimator such as Pipeline, OneVsRest</span> |
| <span class="n">java_epms</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">meta_estimator_transfer_param_maps_to_java</span><span class="p">(</span> |
| <span class="n">estimator</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()</span> |
| <span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Unsupported estimator used in tuning: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">estimator</span><span class="p">))</span> |
| |
| <span class="n">java_estimator</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">JavaEstimator</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">())</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> |
| <span class="n">java_evaluator</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">JavaEvaluator</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">())</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> |
| <span class="k">return</span> <span class="n">java_estimator</span><span class="p">,</span> <span class="n">java_epms</span><span class="p">,</span> <span class="n">java_evaluator</span> |
| |
| |
| <span class="k">class</span> <span class="nc">_ValidatorSharedReadWrite</span><span class="p">:</span> |
| <span class="nd">@staticmethod</span> |
| <span class="k">def</span> <span class="nf">meta_estimator_transfer_param_maps_to_java</span><span class="p">(</span> |
| <span class="n">pyEstimator</span><span class="p">:</span> <span class="n">Estimator</span><span class="p">,</span> <span class="n">pyParamMaps</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"JavaArray"</span><span class="p">:</span> |
| <span class="n">pyStages</span> <span class="o">=</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">getAllNestedStages</span><span class="p">(</span><span class="n">pyEstimator</span><span class="p">)</span> |
| <span class="n">stagePairs</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">stage</span><span class="p">:</span> <span class="p">(</span><span class="n">stage</span><span class="p">,</span> <span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="n">stage</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()),</span> <span class="n">pyStages</span><span class="p">))</span> |
| <span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span> |
| |
| <span class="k">assert</span> <span class="p">(</span> |
| <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_jvm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_gateway</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| <span class="p">)</span> |
| |
| <span class="n">paramMapCls</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_jvm</span><span class="o">.</span><span class="n">org</span><span class="o">.</span><span class="n">apache</span><span class="o">.</span><span class="n">spark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">param</span><span class="o">.</span><span class="n">ParamMap</span> |
| <span class="n">javaParamMaps</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_gateway</span><span class="o">.</span><span class="n">new_array</span><span class="p">(</span><span class="n">paramMapCls</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">pyParamMaps</span><span class="p">))</span> |
| |
| <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">pyParamMap</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">pyParamMaps</span><span class="p">):</span> |
| <span class="n">javaParamMap</span> <span class="o">=</span> <span class="n">JavaWrapper</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span><span class="s2">"org.apache.spark.ml.param.ParamMap"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">pyParam</span><span class="p">,</span> <span class="n">pyValue</span> <span class="ow">in</span> <span class="n">pyParamMap</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> |
| <span class="n">javaParam</span> <span class="o">=</span> <span class="kc">None</span> |
| <span class="k">for</span> <span class="n">pyStage</span><span class="p">,</span> <span class="n">javaStage</span> <span class="ow">in</span> <span class="n">stagePairs</span><span class="p">:</span> |
| <span class="k">if</span> <span class="n">pyStage</span><span class="o">.</span><span class="n">_testOwnParam</span><span class="p">(</span><span class="n">pyParam</span><span class="o">.</span><span class="n">parent</span><span class="p">,</span> <span class="n">pyParam</span><span class="o">.</span><span class="n">name</span><span class="p">):</span> |
| <span class="n">javaParam</span> <span class="o">=</span> <span class="n">javaStage</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="n">pyParam</span><span class="o">.</span><span class="n">name</span><span class="p">)</span> |
| <span class="k">break</span> |
| <span class="k">if</span> <span class="n">javaParam</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Resolve param in estimatorParamMaps failed: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">pyParam</span><span class="p">))</span> |
| <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pyValue</span><span class="p">,</span> <span class="n">Params</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">pyValue</span><span class="p">,</span> <span class="s2">"_to_java"</span><span class="p">):</span> |
| <span class="n">javaValue</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="n">pyValue</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">javaValue</span> <span class="o">=</span> <span class="n">_py2java</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="n">pyValue</span><span class="p">)</span> |
| <span class="n">pair</span> <span class="o">=</span> <span class="n">javaParam</span><span class="o">.</span><span class="n">w</span><span class="p">(</span><span class="n">javaValue</span><span class="p">)</span> |
| <span class="n">javaParamMap</span><span class="o">.</span><span class="n">put</span><span class="p">([</span><span class="n">pair</span><span class="p">])</span> |
| <span class="n">javaParamMaps</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">javaParamMap</span> |
| <span class="k">return</span> <span class="n">javaParamMaps</span> |
| |
| <span class="nd">@staticmethod</span> |
| <span class="k">def</span> <span class="nf">meta_estimator_transfer_param_maps_from_java</span><span class="p">(</span> |
| <span class="n">pyEstimator</span><span class="p">:</span> <span class="n">Estimator</span><span class="p">,</span> <span class="n">javaParamMaps</span><span class="p">:</span> <span class="s2">"JavaArray"</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]:</span> |
| <span class="n">pyStages</span> <span class="o">=</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">getAllNestedStages</span><span class="p">(</span><span class="n">pyEstimator</span><span class="p">)</span> |
| <span class="n">stagePairs</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">stage</span><span class="p">:</span> <span class="p">(</span><span class="n">stage</span><span class="p">,</span> <span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="n">stage</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()),</span> <span class="n">pyStages</span><span class="p">))</span> |
| <span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span> |
| |
| <span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">sc</span><span class="o">.</span><span class="n">_jvm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| |
| <span class="n">pyParamMaps</span> <span class="o">=</span> <span class="p">[]</span> |
| <span class="k">for</span> <span class="n">javaParamMap</span> <span class="ow">in</span> <span class="n">javaParamMaps</span><span class="p">:</span> |
| <span class="n">pyParamMap</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span> |
| <span class="k">for</span> <span class="n">javaPair</span> <span class="ow">in</span> <span class="n">javaParamMap</span><span class="o">.</span><span class="n">toList</span><span class="p">():</span> |
| <span class="n">javaParam</span> <span class="o">=</span> <span class="n">javaPair</span><span class="o">.</span><span class="n">param</span><span class="p">()</span> |
| <span class="n">pyParam</span> <span class="o">=</span> <span class="kc">None</span> |
| <span class="k">for</span> <span class="n">pyStage</span><span class="p">,</span> <span class="n">javaStage</span> <span class="ow">in</span> <span class="n">stagePairs</span><span class="p">:</span> |
| <span class="k">if</span> <span class="n">pyStage</span><span class="o">.</span><span class="n">_testOwnParam</span><span class="p">(</span><span class="n">javaParam</span><span class="o">.</span><span class="n">parent</span><span class="p">(),</span> <span class="n">javaParam</span><span class="o">.</span><span class="n">name</span><span class="p">()):</span> |
| <span class="n">pyParam</span> <span class="o">=</span> <span class="n">pyStage</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="n">javaParam</span><span class="o">.</span><span class="n">name</span><span class="p">())</span> |
| <span class="k">if</span> <span class="n">pyParam</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> |
| <span class="s2">"Resolve param in estimatorParamMaps failed: "</span> |
| <span class="o">+</span> <span class="n">javaParam</span><span class="o">.</span><span class="n">parent</span><span class="p">()</span> |
| <span class="o">+</span> <span class="s2">"."</span> |
| <span class="o">+</span> <span class="n">javaParam</span><span class="o">.</span><span class="n">name</span><span class="p">()</span> |
| <span class="p">)</span> |
| <span class="n">javaValue</span> <span class="o">=</span> <span class="n">javaPair</span><span class="o">.</span><span class="n">value</span><span class="p">()</span> |
| <span class="n">pyValue</span><span class="p">:</span> <span class="n">Any</span> |
| <span class="k">if</span> <span class="n">sc</span><span class="o">.</span><span class="n">_jvm</span><span class="o">.</span><span class="n">Class</span><span class="o">.</span><span class="n">forName</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.util.DefaultParamsWritable"</span> |
| <span class="p">)</span><span class="o">.</span><span class="n">isInstance</span><span class="p">(</span><span class="n">javaValue</span><span class="p">):</span> |
| <span class="n">pyValue</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">javaValue</span><span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">pyValue</span> <span class="o">=</span> <span class="n">_java2py</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="n">javaValue</span><span class="p">)</span> |
| <span class="n">pyParamMap</span><span class="p">[</span><span class="n">pyParam</span><span class="p">]</span> <span class="o">=</span> <span class="n">pyValue</span> |
| <span class="n">pyParamMaps</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pyParamMap</span><span class="p">)</span> |
| <span class="k">return</span> <span class="n">pyParamMaps</span> |
| |
| <span class="nd">@staticmethod</span> |
| <span class="k">def</span> <span class="nf">is_java_convertible</span><span class="p">(</span><span class="n">instance</span><span class="p">:</span> <span class="n">_ValidatorParams</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span> |
| <span class="n">allNestedStages</span> <span class="o">=</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">getAllNestedStages</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">())</span> |
| <span class="n">evaluator_convertible</span> <span class="o">=</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">(),</span> <span class="n">JavaParams</span><span class="p">)</span> |
| <span class="n">estimator_convertible</span> <span class="o">=</span> <span class="nb">all</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">stage</span><span class="p">:</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">stage</span><span class="p">,</span> <span class="s2">"_to_java"</span><span class="p">),</span> <span class="n">allNestedStages</span><span class="p">))</span> |
| <span class="k">return</span> <span class="n">estimator_convertible</span> <span class="ow">and</span> <span class="n">evaluator_convertible</span> |
| |
| <span class="nd">@staticmethod</span> |
| <span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span> |
| <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> |
| <span class="n">instance</span><span class="p">:</span> <span class="n">_ValidatorParams</span><span class="p">,</span> |
| <span class="n">sc</span><span class="p">:</span> <span class="n">SparkContext</span><span class="p">,</span> |
| <span class="n">extraMetadata</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">numParamsNotJson</span> <span class="o">=</span> <span class="mi">0</span> |
| <span class="n">jsonEstimatorParamMaps</span> <span class="o">=</span> <span class="p">[]</span> |
| <span class="k">for</span> <span class="n">paramMap</span> <span class="ow">in</span> <span class="n">instance</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">():</span> |
| <span class="n">jsonParamMap</span> <span class="o">=</span> <span class="p">[]</span> |
| <span class="k">for</span> <span class="n">p</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">paramMap</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> |
| <span class="n">jsonParam</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"parent"</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">parent</span><span class="p">,</span> <span class="s2">"name"</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">name</span><span class="p">}</span> |
| <span class="k">if</span> <span class="p">(</span> |
| <span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Estimator</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">isMetaEstimator</span><span class="p">(</span><span class="n">v</span><span class="p">))</span> |
| <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Transformer</span><span class="p">)</span> |
| <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Evaluator</span><span class="p">)</span> |
| <span class="p">):</span> |
| <span class="n">relative_path</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"epm_</span><span class="si">{</span><span class="n">p</span><span class="o">.</span><span class="n">name</span><span class="si">}{</span><span class="n">numParamsNotJson</span><span class="si">}</span><span class="s2">"</span> |
| <span class="n">param_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">relative_path</span><span class="p">)</span> |
| <span class="n">numParamsNotJson</span> <span class="o">+=</span> <span class="mi">1</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">param_path</span><span class="p">)</span> |
| <span class="n">jsonParam</span><span class="p">[</span><span class="s2">"value"</span><span class="p">]</span> <span class="o">=</span> <span class="n">relative_path</span> |
| <span class="n">jsonParam</span><span class="p">[</span><span class="s2">"isJson"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">False</span> |
| <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">MLWritable</span><span class="p">):</span> |
| <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span> |
| <span class="s2">"ValidatorSharedReadWrite.saveImpl does not handle parameters of type: "</span> |
| <span class="s2">"MLWritable that are not Estimator/Evaluator/Transformer, and if parameter "</span> |
| <span class="s2">"is estimator, it cannot be meta estimator such as Validator or OneVsRest"</span> |
| <span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">jsonParam</span><span class="p">[</span><span class="s2">"value"</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span> |
| <span class="n">jsonParam</span><span class="p">[</span><span class="s2">"isJson"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span> |
| <span class="n">jsonParamMap</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">jsonParam</span><span class="p">)</span> |
| <span class="n">jsonEstimatorParamMaps</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">jsonParamMap</span><span class="p">)</span> |
| |
| <span class="n">skipParams</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"estimator"</span><span class="p">,</span> <span class="s2">"evaluator"</span><span class="p">,</span> <span class="s2">"estimatorParamMaps"</span><span class="p">]</span> |
| <span class="n">jsonParams</span> <span class="o">=</span> <span class="n">DefaultParamsWriter</span><span class="o">.</span><span class="n">extractJsonParams</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">skipParams</span><span class="p">)</span> |
| <span class="n">jsonParams</span><span class="p">[</span><span class="s2">"estimatorParamMaps"</span><span class="p">]</span> <span class="o">=</span> <span class="n">jsonEstimatorParamMaps</span> |
| |
| <span class="n">DefaultParamsWriter</span><span class="o">.</span><span class="n">saveMetadata</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">path</span><span class="p">,</span> <span class="n">sc</span><span class="p">,</span> <span class="n">extraMetadata</span><span class="p">,</span> <span class="n">jsonParams</span><span class="p">)</span> |
| <span class="n">evaluatorPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"evaluator"</span><span class="p">)</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">())</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">evaluatorPath</span><span class="p">)</span> |
| <span class="n">estimatorPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"estimator"</span><span class="p">)</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">())</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">estimatorPath</span><span class="p">)</span> |
| |
@staticmethod
def load(
    path: str, sc: SparkContext, metadata: Dict[str, Any]
) -> Tuple[Dict[str, Any], Estimator, Evaluator, List["ParamMap"]]:
    """
    Load the parts of a saved validator that are common to all tuning classes.

    Parameters
    ----------
    path : str
        Root directory of the saved validator. The evaluator and estimator are
        read from its ``evaluator`` and ``estimator`` sub-directories.
    sc : SparkContext
        Context handed through to ``DefaultParamsReader.loadParamsInstance``.
    metadata : dict
        Already-loaded metadata; ``metadata["paramMap"]["estimatorParamMaps"]``
        must hold the JSON-encoded param grid.

    Returns
    -------
    tuple
        ``(metadata, estimator, evaluator, estimatorParamMaps)`` where
        ``metadata`` is the object passed in, unchanged.
    """
    # The evaluator is read before the estimator, as in the save layout.
    evaluator: Evaluator = DefaultParamsReader.loadParamsInstance(
        os.path.join(path, "evaluator"), sc
    )
    estimator: Estimator = DefaultParamsReader.loadParamsInstance(
        os.path.join(path, "estimator"), sc
    )

    # Each serialized param names its owner through "parent"; the evaluator is
    # registered as well so that params owned by it can be resolved too.
    ownersByUid = MetaAlgorithmReadWrite.getUidMap(estimator)
    ownersByUid[evaluator.uid] = evaluator

    def decodeEntry(entry):  # type: ignore[no-untyped-def]
        """Turn one serialized entry into a ``(param, value)`` pair."""
        owner = ownersByUid[entry["parent"]]
        param = getattr(owner, entry["name"])
        # Entries without an "isJson" flag are treated as plain JSON values.
        if entry.get("isJson", True):
            return param, entry["value"]
        # Otherwise "value" is a path, relative to `path`, of a saved instance.
        savedAt = os.path.join(path, entry["value"])
        return param, DefaultParamsReader.loadParamsInstance(savedAt, sc)

    serializedGrid = metadata["paramMap"]["estimatorParamMaps"]
    estimatorParamMaps = [
        dict(decodeEntry(entry) for entry in serializedMap)
        for serializedMap in serializedGrid
    ]

    return metadata, estimator, evaluator, estimatorParamMaps
| |
@staticmethod
def validateParams(instance: _ValidatorParams) -> None:
    """
    Check that a validator can be written before any saving starts.

    Parameters
    ----------
    instance : _ValidatorParams
        The validator (or validator model) about to be saved.

    Raises
    ------
    ValueError
        If the evaluator, or any value in the map returned by
        ``MetaAlgorithmReadWrite.getUidMap`` for the estimator, is not an
        ``MLWritable``; or if a param in ``estimatorParamMaps`` has a
        ``parent`` that is not a key of that map.
    """
    estimator = instance.getEstimator()
    evaluator = instance.getEvaluator()
    stagesByUid = MetaAlgorithmReadWrite.getUidMap(estimator)

    # The evaluator is checked first, then every entry of the uid map.
    for component in [evaluator, *stagesByUid.values()]:
        if isinstance(component, MLWritable):
            continue
        raise ValueError(
            f"Validator write will fail because it contains {component.uid} "
            "which is not writable."
        )

    grid = instance.getEstimatorParamMaps()
    extraneousPrefix = (
        "Validator save requires all Params in estimatorParamMaps to apply to "
        "its Estimator, An extraneous Param was found: "
    )
    # Walk every param of every map in order and stop at the first stranger.
    for param in (p for paramMap in grid for p in paramMap):
        if param.parent not in stagesByUid:
            raise ValueError(extraneousPrefix + repr(param))
| |
@staticmethod
def getValidatorModelWriterPersistSubModelsParam(writer: MLWriter) -> bool:
    """
    Decide whether a validator-model writer should persist sub-models.

    Parameters
    ----------
    writer : MLWriter
        Writer whose ``optionMap`` may contain a ``"persistsubmodels"`` entry
        and whose ``instance`` exposes ``subModels``.

    Returns
    -------
    bool
        The parsed option when it is present; otherwise whether the writer's
        instance has sub-models at all.

    Raises
    ------
    ValueError
        If the option is present but is neither "true" nor "false"
        (compared case-insensitively).
    """
    # No explicit option: persist exactly when sub-models were collected.
    if "persistsubmodels" not in writer.optionMap:
        return writer.instance.subModels is not None  # type: ignore[attr-defined]

    persistSubModelsParam = writer.optionMap["persistsubmodels"].lower()
    recognised = {"true": True, "false": False}
    if persistSubModelsParam in recognised:
        return recognised[persistSubModelsParam]

    raise ValueError(
        f"persistSubModels option value {persistSubModelsParam} is invalid, "
        "the possible values are True, 'True' or False, 'False'"
    )
| |
| |
| <span class="n">_save_with_persist_submodels_no_submodels_found_err</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="p">(</span> |
| <span class="s2">"When persisting tuning models, you can only set persistSubModels to true if the tuning "</span> |
| <span class="s2">"was done with collectSubModels set to true. To save the sub-models, try rerunning fitting "</span> |
| <span class="s2">"with collectSubModels set to true."</span> |
| <span class="p">)</span> |
| |
| |
@inherit_doc
class CrossValidatorReader(MLReader["CrossValidator"]):
    """
    Reader for saved :py:class:`CrossValidator` instances.

    Metadata that is not recognised as a Python params instance is delegated
    to :py:class:`JavaMLReader`; otherwise the validator is rebuilt from the
    parts returned by ``_ValidatorSharedReadWrite.load``.
    """

    def __init__(self, cls: Type["CrossValidator"]):
        super(CrossValidatorReader, self).__init__()
        # Kept so the Java fallback in `load` knows which class to build.
        self.cls = cls

    def load(self, path: str) -> "CrossValidator":
        savedMetadata = DefaultParamsReader.loadMetadata(path, self.sc)

        # Guard: anything that is not a Python params instance goes to the Java reader.
        if not DefaultParamsReader.isPythonParamsInstance(savedMetadata):
            return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]

        parts = _ValidatorSharedReadWrite.load(path, self.sc, savedMetadata)
        metadata, estimator, evaluator, estimatorParamMaps = parts

        validator = CrossValidator(
            estimator=estimator,
            estimatorParamMaps=estimatorParamMaps,
            evaluator=evaluator,
        )
        # Restore the saved uid, then the remaining params; the param grid was
        # already passed to the constructor, so it is skipped here.
        validator = validator._resetUid(metadata["uid"])
        DefaultParamsReader.getAndSetParams(
            validator, metadata, skipParams=["estimatorParamMaps"]
        )
        return validator
| |
| |
@inherit_doc
class CrossValidatorWriter(MLWriter):
    """
    Writer for :py:class:`CrossValidator` instances.

    Saving is delegated to ``_ValidatorSharedReadWrite``, with the instance
    validated first.
    """

    def __init__(self, instance: "CrossValidator"):
        super(CrossValidatorWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        validator = self.instance
        # Validation runs first and raises before any save call is made.
        _ValidatorSharedReadWrite.validateParams(validator)
        _ValidatorSharedReadWrite.saveImpl(path, validator, self.sc)
| |
| |
@inherit_doc
class CrossValidatorModelReader(MLReader["CrossValidatorModel"]):
    """
    Reader for saved :py:class:`CrossValidatorModel` instances.

    Metadata that is not recognised as a Python params instance is delegated
    to :py:class:`JavaMLReader`. Otherwise the model is rebuilt from the shared
    validator parts, the ``bestModel`` sub-directory, the metrics stored in the
    metadata and, when ``persistSubModels`` is set in the metadata, the
    per-fold sub-models under ``subModels/fold<i>/<j>``.
    """

    def __init__(self, cls: Type["CrossValidatorModel"]):
        super(CrossValidatorModelReader, self).__init__()
        # Kept so the Java fallback in `load` knows which class to build.
        self.cls = cls

    def load(self, path: str) -> "CrossValidatorModel":
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]

        metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
            path, self.sc, metadata
        )
        numFolds = metadata["paramMap"]["numFolds"]
        bestModelPath = os.path.join(path, "bestModel")
        bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
        avgMetrics = metadata["avgMetrics"]
        # "stdMetrics" and "persistSubModels" may be missing from the metadata.
        stdMetrics = metadata.get("stdMetrics")
        persistSubModels = metadata.get("persistSubModels", False)

        if persistSubModels:
            # Build a fresh inner list for every fold. The previous
            # `[[None] * n] * numFolds` repeated ONE inner list numFolds times,
            # so every fold row aliased the same object and all rows ended up
            # holding the models of the last fold.
            subModels = [
                [
                    DefaultParamsReader.loadParamsInstance(
                        os.path.join(path, "subModels", f"fold{splitIndex}", f"{paramIndex}"),
                        self.sc,
                    )
                    for paramIndex in range(len(estimatorParamMaps))
                ]
                for splitIndex in range(numFolds)
            ]
        else:
            subModels = None

        cvModel = CrossValidatorModel(
            bestModel,
            avgMetrics=avgMetrics,
            subModels=cast(List[List[Model]], subModels),
            stdMetrics=stdMetrics,
        )
        cvModel = cvModel._resetUid(metadata["uid"])
        cvModel.set(cvModel.estimator, estimator)
        cvModel.set(cvModel.estimatorParamMaps, estimatorParamMaps)
        cvModel.set(cvModel.evaluator, evaluator)
        # The param grid was set explicitly above, so it is skipped here.
        DefaultParamsReader.getAndSetParams(
            cvModel, metadata, skipParams=["estimatorParamMaps"]
        )
        return cvModel
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">CrossValidatorModelWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="s2">"CrossValidatorModel"</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorModelWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span> |
| |
| <span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span> |
| <span class="n">instance</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">instance</span> |
| <span class="n">persistSubModels</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">getValidatorModelWriterPersistSubModelsParam</span><span class="p">(</span> |
| <span class="bp">self</span> |
| <span class="p">)</span> |
| <span class="n">extraMetadata</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"avgMetrics"</span><span class="p">:</span> <span class="n">instance</span><span class="o">.</span><span class="n">avgMetrics</span><span class="p">,</span> <span class="s2">"persistSubModels"</span><span class="p">:</span> <span class="n">persistSubModels</span><span class="p">}</span> |
| <span class="k">if</span> <span class="n">instance</span><span class="o">.</span><span class="n">stdMetrics</span><span class="p">:</span> |
| <span class="n">extraMetadata</span><span class="p">[</span><span class="s2">"stdMetrics"</span><span class="p">]</span> <span class="o">=</span> <span class="n">instance</span><span class="o">.</span><span class="n">stdMetrics</span> |
| |
| <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">extraMetadata</span><span class="o">=</span><span class="n">extraMetadata</span><span class="p">)</span> |
| <span class="n">bestModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"bestModel"</span><span class="p">)</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">bestModel</span><span class="p">)</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">bestModelPath</span><span class="p">)</span> |
| <span class="k">if</span> <span class="n">persistSubModels</span><span class="p">:</span> |
| <span class="k">if</span> <span class="n">instance</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="n">_save_with_persist_submodels_no_submodels_found_err</span><span class="p">)</span> |
| <span class="n">subModelsPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"subModels"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">splitIndex</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">getNumFolds</span><span class="p">()):</span> |
| <span class="n">splitPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">subModelsPath</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"fold</span><span class="si">{</span><span class="n">splitIndex</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">paramIndex</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">())):</span> |
| <span class="n">modelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">splitPath</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">paramIndex</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">subModels</span><span class="p">[</span><span class="n">splitIndex</span><span class="p">][</span><span class="n">paramIndex</span><span class="p">])</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">modelPath</span><span class="p">)</span> |
| |
| |
| <span class="k">class</span> <span class="nc">_CrossValidatorParams</span><span class="p">(</span><span class="n">_ValidatorParams</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">numFolds</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"numFolds"</span><span class="p">,</span> |
| <span class="s2">"number of folds for cross validation"</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toInt</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="n">foldCol</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"foldCol"</span><span class="p">,</span> |
| <span class="s2">"Param for the column name of user "</span> |
| <span class="o">+</span> <span class="s2">"specified fold number. Once this is specified, :py:class:`CrossValidator` "</span> |
| <span class="o">+</span> <span class="s2">"won't do random k-fold split. Note that this column should be integer type "</span> |
| <span class="o">+</span> <span class="s2">"with range [0, numFolds) and Spark will throw exception on out-of-range "</span> |
| <span class="o">+</span> <span class="s2">"fold numbers."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toString</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">_CrossValidatorParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">numFolds</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">foldCol</span><span class="o">=</span><span class="s2">""</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getNumFolds</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of numFolds or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">numFolds</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getFoldCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of foldCol or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">foldCol</span><span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="CrossValidator"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator">[docs]</a><span class="k">class</span> <span class="nc">CrossValidator</span><span class="p">(</span> |
| <span class="n">Estimator</span><span class="p">[</span><span class="s2">"CrossValidatorModel"</span><span class="p">],</span> |
| <span class="n">_CrossValidatorParams</span><span class="p">,</span> |
| <span class="n">HasParallelism</span><span class="p">,</span> |
| <span class="n">HasCollectSubModels</span><span class="p">,</span> |
| <span class="n">MLReadable</span><span class="p">[</span><span class="s2">"CrossValidator"</span><span class="p">],</span> |
| <span class="n">MLWritable</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| |
| <span class="sd"> K-fold cross validation performs model selection by splitting the dataset into a set of</span> |
| <span class="sd"> non-overlapping randomly partitioned folds which are used as separate training and test datasets.</span> |
| <span class="sd"> E.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs,</span> |
| <span class="sd"> each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the</span> |
| <span class="sd"> test set exactly once.</span> |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from pyspark.ml.classification import LogisticRegression</span> |
| <span class="sd"> >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, CrossValidatorModel</span> |
| <span class="sd"> >>> import tempfile</span> |
| <span class="sd"> >>> dataset = spark.createDataFrame(</span> |
| <span class="sd"> ... [(Vectors.dense([0.0]), 0.0),</span> |
| <span class="sd"> ... (Vectors.dense([0.4]), 1.0),</span> |
| <span class="sd"> ... (Vectors.dense([0.5]), 0.0),</span> |
| <span class="sd"> ... (Vectors.dense([0.6]), 1.0),</span> |
| <span class="sd"> ... (Vectors.dense([1.0]), 1.0)] * 10,</span> |
| <span class="sd"> ... ["features", "label"])</span> |
| <span class="sd"> >>> lr = LogisticRegression()</span> |
| <span class="sd"> >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()</span> |
| <span class="sd"> >>> evaluator = BinaryClassificationEvaluator()</span> |
| <span class="sd"> >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,</span> |
| <span class="sd"> ... parallelism=2)</span> |
| <span class="sd"> >>> cvModel = cv.fit(dataset)</span> |
| <span class="sd"> >>> cvModel.getNumFolds()</span> |
| <span class="sd"> 3</span> |
| <span class="sd"> >>> cvModel.avgMetrics[0]</span> |
| <span class="sd"> 0.5</span> |
| <span class="sd"> >>> path = tempfile.mkdtemp()</span> |
| <span class="sd"> >>> model_path = path + "/model"</span> |
| <span class="sd"> >>> cvModel.write().save(model_path)</span> |
| <span class="sd"> >>> cvModelRead = CrossValidatorModel.read().load(model_path)</span> |
| <span class="sd"> >>> cvModelRead.avgMetrics</span> |
| <span class="sd"> [0.5, ...</span> |
| <span class="sd"> >>> evaluator.evaluate(cvModel.transform(dataset))</span> |
| <span class="sd"> 0.8333...</span> |
| <span class="sd"> >>> evaluator.evaluate(cvModelRead.transform(dataset))</span> |
| <span class="sd"> 0.8333...</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">estimator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Estimator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">estimatorParamMaps</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">evaluator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Evaluator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">numFolds</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="n">collectSubModels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> |
| <span class="n">foldCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\</span> |
| <span class="sd"> seed=None, parallelism=1, collectSubModels=False, foldCol="")</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidator</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="CrossValidator.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setParams">[docs]</a> <span class="nd">@keyword_only</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">estimator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Estimator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">estimatorParamMaps</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">evaluator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Evaluator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">numFolds</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="n">collectSubModels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> |
| <span class="n">foldCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\</span> |
| <span class="sd"> seed=None, parallelism=1, collectSubModels=False, foldCol=""):</span> |
| <span class="sd"> Sets params for cross validator.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidator.setEstimator"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setEstimator">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setEstimator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Estimator</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`estimator`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">estimator</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidator.setEstimatorParamMaps"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setEstimatorParamMaps">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setEstimatorParamMaps</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">])</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`estimatorParamMaps`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidator.setEvaluator"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setEvaluator">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setEvaluator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Evaluator</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`evaluator`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">evaluator</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidator.setNumFolds"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setNumFolds">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setNumFolds</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`numFolds`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">numFolds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidator.setFoldCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setFoldCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setFoldCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`foldCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">foldCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidator.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setSeed">[docs]</a> <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`seed`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidator.setParallelism"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setParallelism">[docs]</a> <span class="k">def</span> <span class="nf">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`parallelism`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidator.setCollectSubModels"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setCollectSubModels">[docs]</a> <span class="k">def</span> <span class="nf">setCollectSubModels</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`collectSubModels`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">collectSubModels</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <span class="nd">@staticmethod</span> |
| <span class="k">def</span> <span class="nf">_gen_avg_and_std_metrics</span><span class="p">(</span><span class="n">metrics_all</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]])</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]:</span> |
| <span class="n">avg_metrics</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">metrics_all</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> |
| <span class="n">std_metrics</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">metrics_all</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> |
| <span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">avg_metrics</span><span class="p">),</span> <span class="nb">list</span><span class="p">(</span><span class="n">std_metrics</span><span class="p">)</span> |
| |
| <span class="k">def</span> <span class="nf">_fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidatorModel"</span><span class="p">:</span> |
| <span class="n">est</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimator</span><span class="p">)</span> |
| <span class="n">epm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimatorParamMaps</span><span class="p">)</span> |
| <span class="n">numModels</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">epm</span><span class="p">)</span> |
| <span class="n">eva</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">evaluator</span><span class="p">)</span> |
| <span class="n">nFolds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">numFolds</span><span class="p">)</span> |
| <span class="n">metrics_all</span> <span class="o">=</span> <span class="p">[[</span><span class="mf">0.0</span><span class="p">]</span> <span class="o">*</span> <span class="n">numModels</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nFolds</span><span class="p">)]</span> |
| |
| <span class="n">pool</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">processes</span><span class="o">=</span><span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">(),</span> <span class="n">numModels</span><span class="p">))</span> |
| <span class="n">subModels</span> <span class="o">=</span> <span class="kc">None</span> |
| <span class="n">collectSubModelsParam</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">()</span> |
| <span class="k">if</span> <span class="n">collectSubModelsParam</span><span class="p">:</span> |
| <span class="n">subModels</span> <span class="o">=</span> <span class="p">[[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">numModels</span><span class="p">)]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nFolds</span><span class="p">)]</span> |
| |
| <span class="n">datasets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_kFold</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nFolds</span><span class="p">):</span> |
| <span class="n">validation</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> |
| <span class="n">train</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> |
| |
| <span class="n">tasks</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span> |
| <span class="n">inheritable_thread_target</span><span class="p">,</span> |
| <span class="n">_parallelFitTasks</span><span class="p">(</span><span class="n">est</span><span class="p">,</span> <span class="n">train</span><span class="p">,</span> <span class="n">eva</span><span class="p">,</span> <span class="n">validation</span><span class="p">,</span> <span class="n">epm</span><span class="p">,</span> <span class="n">collectSubModelsParam</span><span class="p">),</span> |
| <span class="p">)</span> |
| <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">metric</span><span class="p">,</span> <span class="n">subModel</span> <span class="ow">in</span> <span class="n">pool</span><span class="o">.</span><span class="n">imap_unordered</span><span class="p">(</span><span class="k">lambda</span> <span class="n">f</span><span class="p">:</span> <span class="n">f</span><span class="p">(),</span> <span class="n">tasks</span><span class="p">):</span> |
| <span class="n">metrics_all</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">metric</span> |
| <span class="k">if</span> <span class="n">collectSubModelsParam</span><span class="p">:</span> |
| <span class="k">assert</span> <span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| <span class="n">subModels</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">subModel</span> |
| |
| <span class="n">validation</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span> |
| <span class="n">train</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span> |
| |
| <span class="n">metrics</span><span class="p">,</span> <span class="n">std_metrics</span> <span class="o">=</span> <span class="n">CrossValidator</span><span class="o">.</span><span class="n">_gen_avg_and_std_metrics</span><span class="p">(</span><span class="n">metrics_all</span><span class="p">)</span> |
| |
| <span class="k">if</span> <span class="n">eva</span><span class="o">.</span><span class="n">isLargerBetter</span><span class="p">():</span> |
| <span class="n">bestIndex</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">metrics</span><span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">bestIndex</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmin</span><span class="p">(</span><span class="n">metrics</span><span class="p">)</span> |
| <span class="n">bestModel</span> <span class="o">=</span> <span class="n">est</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">epm</span><span class="p">[</span><span class="n">bestIndex</span><span class="p">])</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_copyValues</span><span class="p">(</span> |
| <span class="n">CrossValidatorModel</span><span class="p">(</span><span class="n">bestModel</span><span class="p">,</span> <span class="n">metrics</span><span class="p">,</span> <span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Model</span><span class="p">]],</span> <span class="n">subModels</span><span class="p">),</span> <span class="n">std_metrics</span><span class="p">)</span> |
| <span class="p">)</span> |
| |
| <span class="k">def</span> <span class="nf">_kFold</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">DataFrame</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">]]:</span> |
| <span class="n">nFolds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">numFolds</span><span class="p">)</span> |
| <span class="n">foldCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">foldCol</span><span class="p">)</span> |
| |
| <span class="n">datasets</span> <span class="o">=</span> <span class="p">[]</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="n">foldCol</span><span class="p">:</span> |
| <span class="c1"># Do random k-fold split.</span> |
| <span class="n">seed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span> |
| <span class="n">h</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">nFolds</span> |
| <span class="n">randCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> <span class="o">+</span> <span class="s2">"_rand"</span> |
| <span class="n">df</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">"*"</span><span class="p">,</span> <span class="n">rand</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span><span class="o">.</span><span class="n">alias</span><span class="p">(</span><span class="n">randCol</span><span class="p">))</span> |
| <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nFolds</span><span class="p">):</span> |
| <span class="n">validateLB</span> <span class="o">=</span> <span class="n">i</span> <span class="o">*</span> <span class="n">h</span> |
| <span class="n">validateUB</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">h</span> |
| <span class="n">condition</span> <span class="o">=</span> <span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="n">randCol</span><span class="p">]</span> <span class="o">>=</span> <span class="n">validateLB</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="n">randCol</span><span class="p">]</span> <span class="o"><</span> <span class="n">validateUB</span><span class="p">)</span> |
| <span class="n">validation</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">condition</span><span class="p">)</span> |
| <span class="n">train</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="o">~</span><span class="n">condition</span><span class="p">)</span> |
| <span class="n">datasets</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">train</span><span class="p">,</span> <span class="n">validation</span><span class="p">))</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="c1"># Use user-specified fold numbers.</span> |
| <span class="k">def</span> <span class="nf">checker</span><span class="p">(</span><span class="n">foldNum</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span> |
| <span class="k">if</span> <span class="n">foldNum</span> <span class="o"><</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">foldNum</span> <span class="o">>=</span> <span class="n">nFolds</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> |
| <span class="s2">"Fold number must be in range [0, </span><span class="si">%s</span><span class="s2">), but got </span><span class="si">%s</span><span class="s2">."</span> <span class="o">%</span> <span class="p">(</span><span class="n">nFolds</span><span class="p">,</span> <span class="n">foldNum</span><span class="p">)</span> |
| <span class="p">)</span> |
| <span class="k">return</span> <span class="kc">True</span> |
| |
| <span class="n">checker_udf</span> <span class="o">=</span> <span class="n">UserDefinedFunction</span><span class="p">(</span><span class="n">checker</span><span class="p">,</span> <span class="n">BooleanType</span><span class="p">())</span> |
| <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nFolds</span><span class="p">):</span> |
| <span class="n">training</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">checker_udf</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="n">foldCol</span><span class="p">])</span> <span class="o">&</span> <span class="p">(</span><span class="n">col</span><span class="p">(</span><span class="n">foldCol</span><span class="p">)</span> <span class="o">!=</span> <span class="n">lit</span><span class="p">(</span><span class="n">i</span><span class="p">)))</span> |
| <span class="n">validation</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span> |
| <span class="n">checker_udf</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="n">foldCol</span><span class="p">])</span> <span class="o">&</span> <span class="p">(</span><span class="n">col</span><span class="p">(</span><span class="n">foldCol</span><span class="p">)</span> <span class="o">==</span> <span class="n">lit</span><span class="p">(</span><span class="n">i</span><span class="p">))</span> |
| <span class="p">)</span> |
| <span class="k">if</span> <span class="n">training</span><span class="o">.</span><span class="n">rdd</span><span class="o">.</span><span class="n">getNumPartitions</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">training</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The training data at fold </span><span class="si">%s</span><span class="s2"> is empty."</span> <span class="o">%</span> <span class="n">i</span><span class="p">)</span> |
| <span class="k">if</span> <span class="n">validation</span><span class="o">.</span><span class="n">rdd</span><span class="o">.</span><span class="n">getNumPartitions</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">validation</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The validation data at fold </span><span class="si">%s</span><span class="s2"> is empty."</span> <span class="o">%</span> <span class="n">i</span><span class="p">)</span> |
| <span class="n">datasets</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">training</span><span class="p">,</span> <span class="n">validation</span><span class="p">))</span> |
| |
| <span class="k">return</span> <span class="n">datasets</span> |
| |
| <div class="viewcode-block" id="CrossValidator.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Creates a copy of this instance with a randomly generated uid</span> |
| <span class="sd"> and some extra params. This creates a deep copy of</span> |
| <span class="sd"> the embedded paramMap, and copies the embedded and extra parameters over.</span> |
| |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> extra : dict, optional</span> |
| <span class="sd"> Extra parameters to copy to the new instance</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> :py:class:`CrossValidator`</span> |
| <span class="sd"> Copy of this instance</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span> |
| <span class="n">newCV</span> <span class="o">=</span> <span class="n">Params</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">)</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimator</span><span class="p">):</span> |
| <span class="n">newCV</span><span class="o">.</span><span class="n">setEstimator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">))</span> |
| <span class="c1"># estimatorParamMaps remain the same</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">evaluator</span><span class="p">):</span> |
| <span class="n">newCV</span><span class="o">.</span><span class="n">setEvaluator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">))</span> |
| <span class="k">return</span> <span class="n">newCV</span></div> |
| |
| <div class="viewcode-block" id="CrossValidator.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.write">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">MLWriter</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""Returns an MLWriter instance for this ML instance."""</span> |
| <span class="k">if</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">is_java_convertible</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="k">return</span> <span class="n">CrossValidatorWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidator.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.read">[docs]</a> <span class="nd">@classmethod</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-></span> <span class="n">CrossValidatorReader</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""Returns an MLReader instance for this class."""</span> |
| <span class="k">return</span> <span class="n">CrossValidatorReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div> |
| |
| <span class="nd">@classmethod</span> |
| <span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidator"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Given a Java CrossValidator, create and return a Python wrapper of it.</span> |
| <span class="sd"> Used for ML persistence.</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidator</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="n">_from_java_impl</span><span class="p">(</span><span class="n">java_stage</span><span class="p">)</span> |
| <span class="n">numFolds</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getNumFolds</span><span class="p">()</span> |
| <span class="n">seed</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getSeed</span><span class="p">()</span> |
| <span class="n">parallelism</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">()</span> |
| <span class="n">collectSubModels</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">()</span> |
| <span class="n">foldCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getFoldCol</span><span class="p">()</span> |
| <span class="c1"># Create a new instance of this stage.</span> |
| <span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span> |
| <span class="n">estimator</span><span class="o">=</span><span class="n">estimator</span><span class="p">,</span> |
| <span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">epms</span><span class="p">,</span> |
| <span class="n">evaluator</span><span class="o">=</span><span class="n">evaluator</span><span class="p">,</span> |
| <span class="n">numFolds</span><span class="o">=</span><span class="n">numFolds</span><span class="p">,</span> |
| <span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span> |
| <span class="n">parallelism</span><span class="o">=</span><span class="n">parallelism</span><span class="p">,</span> |
| <span class="n">collectSubModels</span><span class="o">=</span><span class="n">collectSubModels</span><span class="p">,</span> |
| <span class="n">foldCol</span><span class="o">=</span><span class="n">foldCol</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span> |
| <span class="k">return</span> <span class="n">py_stage</span> |
| |
| <span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"JavaObject"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Transfer this instance to a Java CrossValidator. Used for ML persistence.</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> py4j.java_gateway.JavaObject</span> |
| <span class="sd"> Java object equivalent to this instance.</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidator</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java_impl</span><span class="p">()</span> |
| |
| <span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span><span class="s2">"org.apache.spark.ml.tuning.CrossValidator"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span><span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setEstimatorParamMaps</span><span class="p">(</span><span class="n">epms</span><span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setEvaluator</span><span class="p">(</span><span class="n">evaluator</span><span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setEstimator</span><span class="p">(</span><span class="n">estimator</span><span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getSeed</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setNumFolds</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getNumFolds</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setCollectSubModels</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setFoldCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getFoldCol</span><span class="p">())</span> |
| |
| <span class="k">return</span> <span class="n">_java_obj</span></div> |
| |
| |
| <div class="viewcode-block" id="CrossValidatorModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidatorModel.html#pyspark.ml.tuning.CrossValidatorModel">[docs]</a><span class="k">class</span> <span class="nc">CrossValidatorModel</span><span class="p">(</span> |
| <span class="n">Model</span><span class="p">,</span> <span class="n">_CrossValidatorParams</span><span class="p">,</span> <span class="n">MLReadable</span><span class="p">[</span><span class="s2">"CrossValidatorModel"</span><span class="p">],</span> <span class="n">MLWritable</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> CrossValidatorModel contains the model with the best average cross-validation</span> |
| <span class="sd"> metric across folds (highest or lowest, depending on whether the evaluator's</span> |
| <span class="sd"> metric is larger-is-better) and uses this model to transform input data.</span> |
| <span class="sd"> CrossValidatorModel also tracks the metrics for each param map evaluated.</span> |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| |
| <span class="sd"> Notes</span> |
| <span class="sd"> -----</span> |
| <span class="sd"> Since version 3.3.0, CrossValidatorModel contains a new attribute "stdMetrics",</span> |
| <span class="sd"> which represents the standard deviation of metrics for each paramMap in</span> |
| <span class="sd"> CrossValidator.estimatorParamMaps.</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span><span class="p">,</span> |
| <span class="n">avgMetrics</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">subModels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Model</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">stdMetrics</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="c1">#: best model from cross validation</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span> <span class="o">=</span> <span class="n">bestModel</span> |
| <span class="c1">#: Average cross-validation metrics for each paramMap in</span> |
| <span class="c1">#: CrossValidator.estimatorParamMaps, in the corresponding order.</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">avgMetrics</span> <span class="o">=</span> <span class="n">avgMetrics</span> <span class="ow">or</span> <span class="p">[]</span> |
| <span class="c1">#: sub model list from cross validation</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="o">=</span> <span class="n">subModels</span> |
| <span class="c1">#: standard deviation of metrics for each paramMap in</span> |
| <span class="c1">#: CrossValidator.estimatorParamMaps, in the corresponding order.</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">stdMetrics</span> <span class="o">=</span> <span class="n">stdMetrics</span> <span class="ow">or</span> <span class="p">[]</span> |
| |
| <span class="k">def</span> <span class="nf">_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="n">DataFrame</span><span class="p">:</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="CrossValidatorModel.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidatorModel.html#pyspark.ml.tuning.CrossValidatorModel.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidatorModel"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Creates a copy of this instance with a randomly generated uid</span> |
| <span class="sd"> and some extra params. This copies the underlying bestModel,</span> |
| <span class="sd"> creates a deep copy of the embedded paramMap, and</span> |
| <span class="sd"> copies the embedded and extra parameters over.</span> |
| <span class="sd"> It does not copy the extra Params into the subModels.</span> |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> extra : dict, optional</span> |
| <span class="sd"> Extra parameters to copy to the new instance</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> :py:class:`CrossValidatorModel`</span> |
| <span class="sd"> Copy of this instance</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span> |
| <span class="n">bestModel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">)</span> |
| <span class="n">avgMetrics</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">avgMetrics</span><span class="p">)</span> |
| <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| <span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span> |
| <span class="p">[</span><span class="n">sub_model</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> <span class="k">for</span> <span class="n">sub_model</span> <span class="ow">in</span> <span class="n">fold_sub_models</span><span class="p">]</span> |
| <span class="k">for</span> <span class="n">fold_sub_models</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> |
| <span class="p">]</span> |
| <span class="n">stdMetrics</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stdMetrics</span><span class="p">)</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_copyValues</span><span class="p">(</span> |
| <span class="n">CrossValidatorModel</span><span class="p">(</span><span class="n">bestModel</span><span class="p">,</span> <span class="n">avgMetrics</span><span class="p">,</span> <span class="n">subModels</span><span class="p">,</span> <span class="n">stdMetrics</span><span class="p">),</span> <span class="n">extra</span><span class="o">=</span><span class="n">extra</span> |
| <span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidatorModel.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidatorModel.html#pyspark.ml.tuning.CrossValidatorModel.write">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">MLWriter</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""Returns an MLWriter instance for this ML instance."""</span> |
| <span class="k">if</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">is_java_convertible</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="k">return</span> <span class="n">CrossValidatorModelWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="CrossValidatorModel.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidatorModel.html#pyspark.ml.tuning.CrossValidatorModel.read">[docs]</a> <span class="nd">@classmethod</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-></span> <span class="n">CrossValidatorModelReader</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""Returns an MLReader instance for this class."""</span> |
| <span class="k">return</span> <span class="n">CrossValidatorModelReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div> |
| |
| <span class="nd">@classmethod</span> |
| <span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"CrossValidatorModel"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Given a Java CrossValidatorModel, create and return a Python wrapper of it.</span> |
| <span class="sd"> Used for ML persistence.</span> |
| <span class="sd"> """</span> |
| <span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span> |
| <span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| |
| <span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">bestModel</span><span class="p">())</span> |
| <span class="n">avgMetrics</span> <span class="o">=</span> <span class="n">_java2py</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">avgMetrics</span><span class="p">())</span> |
| <span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorModel</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="n">_from_java_impl</span><span class="p">(</span><span class="n">java_stage</span><span class="p">)</span> |
| |
| <span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span><span class="n">bestModel</span><span class="o">=</span><span class="n">bestModel</span><span class="p">,</span> <span class="n">avgMetrics</span><span class="o">=</span><span class="n">avgMetrics</span><span class="p">)</span> |
| <span class="n">params</span> <span class="o">=</span> <span class="p">{</span> |
| <span class="s2">"evaluator"</span><span class="p">:</span> <span class="n">evaluator</span><span class="p">,</span> |
| <span class="s2">"estimator"</span><span class="p">:</span> <span class="n">estimator</span><span class="p">,</span> |
| <span class="s2">"estimatorParamMaps"</span><span class="p">:</span> <span class="n">epms</span><span class="p">,</span> |
| <span class="s2">"numFolds"</span><span class="p">:</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getNumFolds</span><span class="p">(),</span> |
| <span class="s2">"foldCol"</span><span class="p">:</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getFoldCol</span><span class="p">(),</span> |
| <span class="s2">"seed"</span><span class="p">:</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getSeed</span><span class="p">(),</span> |
| <span class="p">}</span> |
| <span class="k">for</span> <span class="n">param_name</span><span class="p">,</span> <span class="n">param_val</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> |
| <span class="n">py_stage</span> <span class="o">=</span> <span class="n">py_stage</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="p">{</span><span class="n">param_name</span><span class="p">:</span> <span class="n">param_val</span><span class="p">})</span> |
| |
| <span class="k">if</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">hasSubModels</span><span class="p">():</span> |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span> |
| <span class="p">[</span><span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">sub_model</span><span class="p">)</span> <span class="k">for</span> <span class="n">sub_model</span> <span class="ow">in</span> <span class="n">fold_sub_models</span><span class="p">]</span> |
| <span class="k">for</span> <span class="n">fold_sub_models</span> <span class="ow">in</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">subModels</span><span class="p">()</span> |
| <span class="p">]</span> |
| |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span> |
| <span class="k">return</span> <span class="n">py_stage</span> |
| |
| <span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"JavaObject"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> py4j.java_gateway.JavaObject</span> |
| <span class="sd"> Java object equivalent to this instance.</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span> |
| <span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| |
| <span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.tuning.CrossValidatorModel"</span><span class="p">,</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">uid</span><span class="p">,</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">(),</span> |
| <span class="n">_py2java</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">avgMetrics</span><span class="p">),</span> |
| <span class="p">)</span> |
| <span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java_impl</span><span class="p">()</span> |
| |
| <span class="n">params</span> <span class="o">=</span> <span class="p">{</span> |
| <span class="s2">"evaluator"</span><span class="p">:</span> <span class="n">evaluator</span><span class="p">,</span> |
| <span class="s2">"estimator"</span><span class="p">:</span> <span class="n">estimator</span><span class="p">,</span> |
| <span class="s2">"estimatorParamMaps"</span><span class="p">:</span> <span class="n">epms</span><span class="p">,</span> |
| <span class="s2">"numFolds"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">getNumFolds</span><span class="p">(),</span> |
| <span class="s2">"foldCol"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">getFoldCol</span><span class="p">(),</span> |
| <span class="s2">"seed"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">getSeed</span><span class="p">(),</span> |
| <span class="p">}</span> |
| <span class="k">for</span> <span class="n">param_name</span><span class="p">,</span> <span class="n">param_val</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> |
| <span class="n">java_param</span> <span class="o">=</span> <span class="n">_java_obj</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="n">param_name</span><span class="p">)</span> |
| <span class="n">pair</span> <span class="o">=</span> <span class="n">java_param</span><span class="o">.</span><span class="n">w</span><span class="p">(</span><span class="n">param_val</span><span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">pair</span><span class="p">)</span> |
| |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">java_sub_models</span> <span class="o">=</span> <span class="p">[</span> |
| <span class="p">[</span><span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="n">sub_model</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> <span class="k">for</span> <span class="n">sub_model</span> <span class="ow">in</span> <span class="n">fold_sub_models</span><span class="p">]</span> |
| <span class="k">for</span> <span class="n">fold_sub_models</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> |
| <span class="p">]</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setSubModels</span><span class="p">(</span><span class="n">java_sub_models</span><span class="p">)</span> |
| <span class="k">return</span> <span class="n">_java_obj</span></div> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">TrainValidationSplitReader</span><span class="p">(</span><span class="n">MLReader</span><span class="p">[</span><span class="s2">"TrainValidationSplit"</span><span class="p">]):</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="s2">"TrainValidationSplit"</span><span class="p">]):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitReader</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span> |
| |
| <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="n">metadata</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadMetadata</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">isPythonParamsInstance</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">JavaMLReader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">metadata</span><span class="p">,</span> <span class="n">estimator</span><span class="p">,</span> <span class="n">evaluator</span><span class="p">,</span> <span class="n">estimatorParamMaps</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">load</span><span class="p">(</span> |
| <span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">metadata</span> |
| <span class="p">)</span> |
| <span class="n">tvs</span> <span class="o">=</span> <span class="n">TrainValidationSplit</span><span class="p">(</span> |
| <span class="n">estimator</span><span class="o">=</span><span class="n">estimator</span><span class="p">,</span> <span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">estimatorParamMaps</span><span class="p">,</span> <span class="n">evaluator</span><span class="o">=</span><span class="n">evaluator</span> |
| <span class="p">)</span> |
| <span class="n">tvs</span> <span class="o">=</span> <span class="n">tvs</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">"uid"</span><span class="p">])</span> |
| <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">getAndSetParams</span><span class="p">(</span><span class="n">tvs</span><span class="p">,</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">skipParams</span><span class="o">=</span><span class="p">[</span><span class="s2">"estimatorParamMaps"</span><span class="p">])</span> |
| <span class="k">return</span> <span class="n">tvs</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">TrainValidationSplitWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="s2">"TrainValidationSplit"</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span> |
| |
| <span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span> |
| <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">TrainValidationSplitModelReader</span><span class="p">(</span><span class="n">MLReader</span><span class="p">[</span><span class="s2">"TrainValidationSplitModel"</span><span class="p">]):</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="s2">"TrainValidationSplitModel"</span><span class="p">]):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitModelReader</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span> |
| |
| <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplitModel"</span><span class="p">:</span> |
| <span class="n">metadata</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadMetadata</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">isPythonParamsInstance</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">JavaMLReader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">metadata</span><span class="p">,</span> <span class="n">estimator</span><span class="p">,</span> <span class="n">evaluator</span><span class="p">,</span> <span class="n">estimatorParamMaps</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">load</span><span class="p">(</span> |
| <span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">metadata</span> |
| <span class="p">)</span> |
| <span class="n">bestModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"bestModel"</span><span class="p">)</span> |
| <span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span><span class="n">bestModelPath</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span> |
| <span class="n">validationMetrics</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"validationMetrics"</span><span class="p">]</span> |
| <span class="n">persistSubModels</span> <span class="o">=</span> <span class="p">(</span><span class="s2">"persistSubModels"</span> <span class="ow">in</span> <span class="n">metadata</span><span class="p">)</span> <span class="ow">and</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"persistSubModels"</span><span class="p">]</span> |
| |
| <span class="k">if</span> <span class="n">persistSubModels</span><span class="p">:</span> |
| <span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">estimatorParamMaps</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">paramIndex</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">estimatorParamMaps</span><span class="p">)):</span> |
| <span class="n">modelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"subModels"</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">paramIndex</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> |
| <span class="n">subModels</span><span class="p">[</span><span class="n">paramIndex</span><span class="p">]</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span> |
| <span class="n">modelPath</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span> |
| <span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">subModels</span> <span class="o">=</span> <span class="kc">None</span> |
| |
| <span class="n">tvsModel</span> <span class="o">=</span> <span class="n">TrainValidationSplitModel</span><span class="p">(</span> |
| <span class="n">bestModel</span><span class="p">,</span> |
| <span class="n">validationMetrics</span><span class="o">=</span><span class="n">validationMetrics</span><span class="p">,</span> |
| <span class="n">subModels</span><span class="o">=</span><span class="n">cast</span><span class="p">(</span><span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Model</span><span class="p">]],</span> <span class="n">subModels</span><span class="p">),</span> |
| <span class="p">)</span> |
| <span class="n">tvsModel</span> <span class="o">=</span> <span class="n">tvsModel</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">"uid"</span><span class="p">])</span> |
| <span class="n">tvsModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">tvsModel</span><span class="o">.</span><span class="n">estimator</span><span class="p">,</span> <span class="n">estimator</span><span class="p">)</span> |
| <span class="n">tvsModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">tvsModel</span><span class="o">.</span><span class="n">estimatorParamMaps</span><span class="p">,</span> <span class="n">estimatorParamMaps</span><span class="p">)</span> |
| <span class="n">tvsModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">tvsModel</span><span class="o">.</span><span class="n">evaluator</span><span class="p">,</span> <span class="n">evaluator</span><span class="p">)</span> |
| <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">getAndSetParams</span><span class="p">(</span> |
| <span class="n">tvsModel</span><span class="p">,</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">skipParams</span><span class="o">=</span><span class="p">[</span><span class="s2">"estimatorParamMaps"</span><span class="p">]</span> |
| <span class="p">)</span> |
| <span class="k">return</span> <span class="n">tvsModel</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">TrainValidationSplitModelWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="s2">"TrainValidationSplitModel"</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitModelWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span> |
| |
| <span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span> |
| <span class="n">instance</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">instance</span> |
| <span class="n">persistSubModels</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">getValidatorModelWriterPersistSubModelsParam</span><span class="p">(</span> |
| <span class="bp">self</span> |
| <span class="p">)</span> |
| |
| <span class="n">extraMetadata</span> <span class="o">=</span> <span class="p">{</span> |
| <span class="s2">"validationMetrics"</span><span class="p">:</span> <span class="n">instance</span><span class="o">.</span><span class="n">validationMetrics</span><span class="p">,</span> |
| <span class="s2">"persistSubModels"</span><span class="p">:</span> <span class="n">persistSubModels</span><span class="p">,</span> |
| <span class="p">}</span> |
| <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">extraMetadata</span><span class="o">=</span><span class="n">extraMetadata</span><span class="p">)</span> |
| <span class="n">bestModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"bestModel"</span><span class="p">)</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">bestModel</span><span class="p">)</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">bestModelPath</span><span class="p">)</span> |
| <span class="k">if</span> <span class="n">persistSubModels</span><span class="p">:</span> |
| <span class="k">if</span> <span class="n">instance</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="n">_save_with_persist_submodels_no_submodels_found_err</span><span class="p">)</span> |
| <span class="n">subModelsPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"subModels"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">paramIndex</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">())):</span> |
| <span class="n">modelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">subModelsPath</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">paramIndex</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">subModels</span><span class="p">[</span><span class="n">paramIndex</span><span class="p">])</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">modelPath</span><span class="p">)</span> |
| |
| |
| <span class="k">class</span> <span class="nc">_TrainValidationSplitParams</span><span class="p">(</span><span class="n">_ValidatorParams</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`.</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">trainRatio</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"trainRatio"</span><span class="p">,</span> |
| <span class="s2">"Param for ratio between train and</span><span class="se">\</span> |
| <span class="s2"> validation data. Must be between 0 and 1."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toFloat</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">_TrainValidationSplitParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">trainRatio</span><span class="o">=</span><span class="mf">0.75</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getTrainRatio</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of trainRatio or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainRatio</span><span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="TrainValidationSplit"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit">[docs]</a><span class="k">class</span> <span class="nc">TrainValidationSplit</span><span class="p">(</span> |
| <span class="n">Estimator</span><span class="p">[</span><span class="s2">"TrainValidationSplitModel"</span><span class="p">],</span> |
| <span class="n">_TrainValidationSplitParams</span><span class="p">,</span> |
| <span class="n">HasParallelism</span><span class="p">,</span> |
| <span class="n">HasCollectSubModels</span><span class="p">,</span> |
| <span class="n">MLReadable</span><span class="p">[</span><span class="s2">"TrainValidationSplit"</span><span class="p">],</span> |
| <span class="n">MLWritable</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Validation for hyper-parameter tuning. Randomly splits the input dataset into train and</span> |
| <span class="sd"> validation sets, and uses the evaluation metric on the validation set to select the best model.</span> |
| <span class="sd"> Similar to :class:`CrossValidator`, but only splits the set once.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from pyspark.ml.classification import LogisticRegression</span> |
| <span class="sd"> >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder</span> |
| <span class="sd"> >>> from pyspark.ml.tuning import TrainValidationSplitModel</span> |
| <span class="sd"> >>> import tempfile</span> |
| <span class="sd"> >>> dataset = spark.createDataFrame(</span> |
| <span class="sd"> ... [(Vectors.dense([0.0]), 0.0),</span> |
| <span class="sd"> ... (Vectors.dense([0.4]), 1.0),</span> |
| <span class="sd"> ... (Vectors.dense([0.5]), 0.0),</span> |
| <span class="sd"> ... (Vectors.dense([0.6]), 1.0),</span> |
| <span class="sd"> ... (Vectors.dense([1.0]), 1.0)] * 10,</span> |
| <span class="sd"> ... ["features", "label"]).repartition(1)</span> |
| <span class="sd"> >>> lr = LogisticRegression()</span> |
| <span class="sd"> >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()</span> |
| <span class="sd"> >>> evaluator = BinaryClassificationEvaluator()</span> |
| <span class="sd"> >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,</span> |
| <span class="sd"> ... parallelism=1, seed=42)</span> |
| <span class="sd"> >>> tvsModel = tvs.fit(dataset)</span> |
| <span class="sd"> >>> tvsModel.getTrainRatio()</span> |
| <span class="sd"> 0.75</span> |
| <span class="sd"> >>> tvsModel.validationMetrics</span> |
| <span class="sd"> [0.5, ...</span> |
| <span class="sd"> >>> path = tempfile.mkdtemp()</span> |
| <span class="sd"> >>> model_path = path + "/model"</span> |
| <span class="sd"> >>> tvsModel.write().save(model_path)</span> |
| <span class="sd"> >>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)</span> |
| <span class="sd"> >>> tvsModelRead.validationMetrics</span> |
| <span class="sd"> [0.5, ...</span> |
| <span class="sd"> >>> evaluator.evaluate(tvsModel.transform(dataset))</span> |
| <span class="sd"> 0.833...</span> |
| <span class="sd"> >>> evaluator.evaluate(tvsModelRead.transform(dataset))</span> |
| <span class="sd"> 0.833...</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">estimator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Estimator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">estimatorParamMaps</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">evaluator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Evaluator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">trainRatio</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.75</span><span class="p">,</span> |
| <span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="n">collectSubModels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \</span> |
| <span class="sd"> trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None)</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplit</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setParams">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">estimator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Estimator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">estimatorParamMaps</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">evaluator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Evaluator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">trainRatio</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.75</span><span class="p">,</span> |
| <span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="n">collectSubModels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \</span> |
| <span class="sd"> trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None):</span> |
| <span class="sd"> Sets params for the train validation split.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.setEstimator"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setEstimator">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setEstimator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Estimator</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`estimator`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">estimator</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.setEstimatorParamMaps"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setEstimatorParamMaps">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setEstimatorParamMaps</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">])</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`estimatorParamMaps`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.setEvaluator"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setEvaluator">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setEvaluator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Evaluator</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`evaluator`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">evaluator</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.setTrainRatio"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setTrainRatio">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setTrainRatio</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`trainRatio`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">trainRatio</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setSeed">[docs]</a> <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`seed`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.setParallelism"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setParallelism">[docs]</a> <span class="k">def</span> <span class="nf">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`parallelism`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.setCollectSubModels"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setCollectSubModels">[docs]</a> <span class="k">def</span> <span class="nf">setCollectSubModels</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`collectSubModels`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">collectSubModels</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <span class="k">def</span> <span class="nf">_fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplitModel"</span><span class="p">:</span> |
| <span class="n">est</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimator</span><span class="p">)</span> |
| <span class="n">epm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimatorParamMaps</span><span class="p">)</span> |
| <span class="n">numModels</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">epm</span><span class="p">)</span> |
| <span class="n">eva</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">evaluator</span><span class="p">)</span> |
| <span class="n">tRatio</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainRatio</span><span class="p">)</span> |
| <span class="n">seed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span> |
| <span class="n">randCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> <span class="o">+</span> <span class="s2">"_rand"</span> |
| <span class="n">df</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">"*"</span><span class="p">,</span> <span class="n">rand</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span><span class="o">.</span><span class="n">alias</span><span class="p">(</span><span class="n">randCol</span><span class="p">))</span> |
| <span class="n">condition</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="n">randCol</span><span class="p">]</span> <span class="o">>=</span> <span class="n">tRatio</span> |
| <span class="n">validation</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">condition</span><span class="p">)</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> |
| <span class="n">train</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="o">~</span><span class="n">condition</span><span class="p">)</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> |
| |
| <span class="n">subModels</span> <span class="o">=</span> <span class="kc">None</span> |
| <span class="n">collectSubModelsParam</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">()</span> |
| <span class="k">if</span> <span class="n">collectSubModelsParam</span><span class="p">:</span> |
| <span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">numModels</span><span class="p">)]</span> |
| |
| <span class="n">tasks</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span> |
| <span class="n">inheritable_thread_target</span><span class="p">,</span> |
| <span class="n">_parallelFitTasks</span><span class="p">(</span><span class="n">est</span><span class="p">,</span> <span class="n">train</span><span class="p">,</span> <span class="n">eva</span><span class="p">,</span> <span class="n">validation</span><span class="p">,</span> <span class="n">epm</span><span class="p">,</span> <span class="n">collectSubModelsParam</span><span class="p">),</span> |
| <span class="p">)</span> |
| <span class="n">pool</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">processes</span><span class="o">=</span><span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">(),</span> <span class="n">numModels</span><span class="p">))</span> |
| <span class="n">metrics</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">numModels</span> |
| <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">metric</span><span class="p">,</span> <span class="n">subModel</span> <span class="ow">in</span> <span class="n">pool</span><span class="o">.</span><span class="n">imap_unordered</span><span class="p">(</span><span class="k">lambda</span> <span class="n">f</span><span class="p">:</span> <span class="n">f</span><span class="p">(),</span> <span class="n">tasks</span><span class="p">):</span> |
| <span class="n">metrics</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">metric</span> |
| <span class="k">if</span> <span class="n">collectSubModelsParam</span><span class="p">:</span> |
| <span class="k">assert</span> <span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| <span class="n">subModels</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">subModel</span> |
| |
| <span class="n">train</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span> |
| <span class="n">validation</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span> |
| |
| <span class="k">if</span> <span class="n">eva</span><span class="o">.</span><span class="n">isLargerBetter</span><span class="p">():</span> |
| <span class="n">bestIndex</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> <span class="n">metrics</span><span class="p">))</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">bestIndex</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmin</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> <span class="n">metrics</span><span class="p">))</span> |
| <span class="n">bestModel</span> <span class="o">=</span> <span class="n">est</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">epm</span><span class="p">[</span><span class="n">bestIndex</span><span class="p">])</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_copyValues</span><span class="p">(</span> |
| <span class="n">TrainValidationSplitModel</span><span class="p">(</span> |
| <span class="n">bestModel</span><span class="p">,</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> <span class="n">metrics</span><span class="p">),</span> |
| <span class="n">subModels</span><span class="p">,</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="p">)</span> |
| <span class="p">)</span> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Creates a copy of this instance with a randomly generated uid</span> |
| <span class="sd"> and some extra params. This creates a deep copy of</span> |
| <span class="sd"> the embedded paramMap, and copies the embedded and extra parameters over.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> extra : dict, optional</span> |
| <span class="sd"> Extra parameters to copy to the new instance</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> :py:class:`TrainValidationSplit`</span> |
| <span class="sd"> Copy of this instance</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span> |
| <span class="n">newTVS</span> <span class="o">=</span> <span class="n">Params</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">)</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimator</span><span class="p">):</span> |
| <span class="n">newTVS</span><span class="o">.</span><span class="n">setEstimator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">))</span> |
| <span class="c1"># estimatorParamMaps remain the same</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">evaluator</span><span class="p">):</span> |
| <span class="n">newTVS</span><span class="o">.</span><span class="n">setEvaluator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">))</span> |
| <span class="k">return</span> <span class="n">newTVS</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.write">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">MLWriter</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""Returns an MLWriter instance for this ML instance."""</span> |
| <span class="k">if</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">is_java_convertible</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="k">return</span> <span class="n">TrainValidationSplitWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplit.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.read">[docs]</a> <span class="nd">@classmethod</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-></span> <span class="n">TrainValidationSplitReader</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""Returns an MLReader instance for this class."""</span> |
| <span class="k">return</span> <span class="n">TrainValidationSplitReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div> |
| |
| <span class="nd">@classmethod</span> |
| <span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplit"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Given a Java TrainValidationSplit, create and return a Python wrapper of it.</span> |
| <span class="sd"> Used for ML persistence.</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplit</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="n">_from_java_impl</span><span class="p">(</span><span class="n">java_stage</span><span class="p">)</span> |
| <span class="n">trainRatio</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getTrainRatio</span><span class="p">()</span> |
| <span class="n">seed</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getSeed</span><span class="p">()</span> |
| <span class="n">parallelism</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">()</span> |
| <span class="n">collectSubModels</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">()</span> |
| <span class="c1"># Create a new instance of this stage.</span> |
| <span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span> |
| <span class="n">estimator</span><span class="o">=</span><span class="n">estimator</span><span class="p">,</span> |
| <span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">epms</span><span class="p">,</span> |
| <span class="n">evaluator</span><span class="o">=</span><span class="n">evaluator</span><span class="p">,</span> |
| <span class="n">trainRatio</span><span class="o">=</span><span class="n">trainRatio</span><span class="p">,</span> |
| <span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span> |
| <span class="n">parallelism</span><span class="o">=</span><span class="n">parallelism</span><span class="p">,</span> |
| <span class="n">collectSubModels</span><span class="o">=</span><span class="n">collectSubModels</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span> |
| <span class="k">return</span> <span class="n">py_stage</span> |
| |
| <span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"JavaObject"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> py4j.java_gateway.JavaObject</span> |
| <span class="sd"> Java object equivalent to this instance.</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplit</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java_impl</span><span class="p">()</span> |
| |
| <span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.tuning.TrainValidationSplit"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> |
| <span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setEstimatorParamMaps</span><span class="p">(</span><span class="n">epms</span><span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setEvaluator</span><span class="p">(</span><span class="n">evaluator</span><span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setEstimator</span><span class="p">(</span><span class="n">estimator</span><span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setTrainRatio</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getTrainRatio</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getSeed</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setCollectSubModels</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">())</span> |
| <span class="k">return</span> <span class="n">_java_obj</span></div> |
| |
| |
| <div class="viewcode-block" id="TrainValidationSplitModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplitModel.html#pyspark.ml.tuning.TrainValidationSplitModel">[docs]</a><span class="k">class</span> <span class="nc">TrainValidationSplitModel</span><span class="p">(</span> |
| <span class="n">Model</span><span class="p">,</span> <span class="n">_TrainValidationSplitParams</span><span class="p">,</span> <span class="n">MLReadable</span><span class="p">[</span><span class="s2">"TrainValidationSplitModel"</span><span class="p">],</span> <span class="n">MLWritable</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model from train validation split.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span><span class="p">,</span> |
| <span class="n">validationMetrics</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">subModels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Model</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="c1">#: best model from train validation split</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span> <span class="o">=</span> <span class="n">bestModel</span> |
| <span class="c1">#: evaluated validation metrics</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">validationMetrics</span> <span class="o">=</span> <span class="n">validationMetrics</span> <span class="ow">or</span> <span class="p">[]</span> |
| <span class="c1">#: sub models from train validation split</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="o">=</span> <span class="n">subModels</span> |
| |
| <span class="k">def</span> <span class="nf">_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="n">DataFrame</span><span class="p">:</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="TrainValidationSplitModel.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplitModel.html#pyspark.ml.tuning.TrainValidationSplitModel.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplitModel"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Creates a copy of this instance with a randomly generated uid</span> |
| <span class="sd"> and some extra params. This copies the underlying bestModel,</span> |
| <span class="sd"> creates a deep copy of the embedded paramMap, and</span> |
| <span class="sd"> copies the embedded and extra parameters over.</span> |
| <span class="sd"> It also creates a shallow copy of the validationMetrics.</span> |
| <span class="sd"> It does not copy the extra Params into the subModels.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> extra : dict, optional</span> |
| <span class="sd"> Extra parameters to copy to the new instance</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> :py:class:`TrainValidationSplitModel`</span> |
| <span class="sd"> Copy of this instance</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span> |
| <span class="n">bestModel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">)</span> |
| <span class="n">validationMetrics</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">validationMetrics</span><span class="p">)</span> |
| <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| <span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span><span class="n">model</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span><span class="p">]</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_copyValues</span><span class="p">(</span> |
| <span class="n">TrainValidationSplitModel</span><span class="p">(</span><span class="n">bestModel</span><span class="p">,</span> <span class="n">validationMetrics</span><span class="p">,</span> <span class="n">subModels</span><span class="p">),</span> <span class="n">extra</span><span class="o">=</span><span class="n">extra</span> |
| <span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplitModel.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplitModel.html#pyspark.ml.tuning.TrainValidationSplitModel.write">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">MLWriter</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""Returns an MLWriter instance for this ML instance."""</span> |
| <span class="k">if</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">is_java_convertible</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="k">return</span> <span class="n">TrainValidationSplitModelWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="TrainValidationSplitModel.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplitModel.html#pyspark.ml.tuning.TrainValidationSplitModel.read">[docs]</a> <span class="nd">@classmethod</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-></span> <span class="n">TrainValidationSplitModelReader</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""Returns an MLReader instance for this class."""</span> |
| <span class="k">return</span> <span class="n">TrainValidationSplitModelReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div> |
| |
| <span class="nd">@classmethod</span> |
| <span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"TrainValidationSplitModel"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.</span> |
| <span class="sd"> Used for ML persistence.</span> |
| <span class="sd"> """</span> |
| |
| <span class="c1"># Load information from java_stage to the instance.</span> |
| <span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span> |
| <span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| |
| <span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">bestModel</span><span class="p">())</span> |
| <span class="n">validationMetrics</span> <span class="o">=</span> <span class="n">_java2py</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">validationMetrics</span><span class="p">())</span> |
| <span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitModel</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="n">_from_java_impl</span><span class="p">(</span> |
| <span class="n">java_stage</span> |
| <span class="p">)</span> |
| <span class="c1"># Create a new instance of this stage.</span> |
| <span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span><span class="n">bestModel</span><span class="o">=</span><span class="n">bestModel</span><span class="p">,</span> <span class="n">validationMetrics</span><span class="o">=</span><span class="n">validationMetrics</span><span class="p">)</span> |
| <span class="n">params</span> <span class="o">=</span> <span class="p">{</span> |
| <span class="s2">"evaluator"</span><span class="p">:</span> <span class="n">evaluator</span><span class="p">,</span> |
| <span class="s2">"estimator"</span><span class="p">:</span> <span class="n">estimator</span><span class="p">,</span> |
| <span class="s2">"estimatorParamMaps"</span><span class="p">:</span> <span class="n">epms</span><span class="p">,</span> |
| <span class="s2">"trainRatio"</span><span class="p">:</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getTrainRatio</span><span class="p">(),</span> |
| <span class="s2">"seed"</span><span class="p">:</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getSeed</span><span class="p">(),</span> |
| <span class="p">}</span> |
| <span class="k">for</span> <span class="n">param_name</span><span class="p">,</span> <span class="n">param_val</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> |
| <span class="n">py_stage</span> <span class="o">=</span> <span class="n">py_stage</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="p">{</span><span class="n">param_name</span><span class="p">:</span> <span class="n">param_val</span><span class="p">})</span> |
| |
| <span class="k">if</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">hasSubModels</span><span class="p">():</span> |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span> |
| <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">sub_model</span><span class="p">)</span> <span class="k">for</span> <span class="n">sub_model</span> <span class="ow">in</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">subModels</span><span class="p">()</span> |
| <span class="p">]</span> |
| |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span> |
| <span class="k">return</span> <span class="n">py_stage</span> |
| |
| <span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"JavaObject"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> py4j.java_gateway.JavaObject</span> |
| <span class="sd"> Java object equivalent to this instance.</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span> |
| <span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| |
| <span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.tuning.TrainValidationSplitModel"</span><span class="p">,</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">uid</span><span class="p">,</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">(),</span> |
| <span class="n">_py2java</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">validationMetrics</span><span class="p">),</span> |
| <span class="p">)</span> |
| <span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java_impl</span><span class="p">()</span> |
| |
| <span class="n">params</span> <span class="o">=</span> <span class="p">{</span> |
| <span class="s2">"evaluator"</span><span class="p">:</span> <span class="n">evaluator</span><span class="p">,</span> |
| <span class="s2">"estimator"</span><span class="p">:</span> <span class="n">estimator</span><span class="p">,</span> |
| <span class="s2">"estimatorParamMaps"</span><span class="p">:</span> <span class="n">epms</span><span class="p">,</span> |
| <span class="s2">"trainRatio"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">getTrainRatio</span><span class="p">(),</span> |
| <span class="s2">"seed"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">getSeed</span><span class="p">(),</span> |
| <span class="p">}</span> |
| <span class="k">for</span> <span class="n">param_name</span><span class="p">,</span> <span class="n">param_val</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> |
| <span class="n">java_param</span> <span class="o">=</span> <span class="n">_java_obj</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="n">param_name</span><span class="p">)</span> |
| <span class="n">pair</span> <span class="o">=</span> <span class="n">java_param</span><span class="o">.</span><span class="n">w</span><span class="p">(</span><span class="n">param_val</span><span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">pair</span><span class="p">)</span> |
| |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">java_sub_models</span> <span class="o">=</span> <span class="p">[</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="n">sub_model</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> <span class="k">for</span> <span class="n">sub_model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> |
| <span class="p">]</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setSubModels</span><span class="p">(</span><span class="n">java_sub_models</span><span class="p">)</span> |
| |
| <span class="k">return</span> <span class="n">_java_obj</span></div> |
| |
| |
| <span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">"__main__"</span><span class="p">:</span> |
| <span class="kn">import</span> <span class="nn">doctest</span> |
| |
| <span class="kn">from</span> <span class="nn">pyspark.sql</span> <span class="kn">import</span> <span class="n">SparkSession</span> |
| |
| <span class="n">globs</span> <span class="o">=</span> <span class="nb">globals</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> |
| |
| <span class="c1"># Run the module doctests against a local[2] SparkSession that the examples</span> |
| <span class="c1"># reach through the `spark` and `sc` globals:</span> |
| <span class="n">spark</span> <span class="o">=</span> <span class="n">SparkSession</span><span class="o">.</span><span class="n">builder</span><span class="o">.</span><span class="n">master</span><span class="p">(</span><span class="s2">"local[2]"</span><span class="p">)</span><span class="o">.</span><span class="n">appName</span><span class="p">(</span><span class="s2">"ml.tuning tests"</span><span class="p">)</span><span class="o">.</span><span class="n">getOrCreate</span><span class="p">()</span> |
| <span class="n">sc</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">sparkContext</span> |
| <span class="n">globs</span><span class="p">[</span><span class="s2">"sc"</span><span class="p">]</span> <span class="o">=</span> <span class="n">sc</span> |
| <span class="n">globs</span><span class="p">[</span><span class="s2">"spark"</span><span class="p">]</span> <span class="o">=</span> <span class="n">spark</span> |
| <span class="p">(</span><span class="n">failure_count</span><span class="p">,</span> <span class="n">test_count</span><span class="p">)</span> <span class="o">=</span> <span class="n">doctest</span><span class="o">.</span><span class="n">testmod</span><span class="p">(</span><span class="n">globs</span><span class="o">=</span><span class="n">globs</span><span class="p">,</span> <span class="n">optionflags</span><span class="o">=</span><span class="n">doctest</span><span class="o">.</span><span class="n">ELLIPSIS</span><span class="p">)</span> |
| <span class="n">spark</span><span class="o">.</span><span class="n">stop</span><span class="p">()</span> |
| <span class="k">if</span> <span class="n">failure_count</span><span class="p">:</span> |
| <span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> |
| </pre></div> |
| |
| </div> |
| |
| |
| <!-- Previous / next buttons --> |
| <div class="prev-next-area"> |
| </div> |
| |
| </main> |
| |
| |
| </div> |
| </div> |
| |
| <script src="../../../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script> |
| <footer class="footer mt-5 mt-md-0"> |
| <div class="container"> |
| |
| <div class="footer-item"> |
| <!-- NOTE(review): the copyright holder is blank in the generated output; presumably the Sphinx `copyright` setting is empty. Confirm and set it at the source. --> |
| <p class="copyright"> |
| © Copyright . |
| </p> |
| </div> |
| |
| <div class="footer-item"> |
| <!-- Generator credit; links to the Sphinx project site over HTTPS. --> |
| <p class="sphinx-version"> |
| Created using <a href="https://www.sphinx-doc.org/">Sphinx</a> 3.0.4. |
| </p> |
| </div> |
| |
| </div> |
| </footer> |
| </body> |
| </html> |